Detection of a constant diagonal matrix * vector

Hi!

Currently, InstCombine does not detect a constant diagonal matrix times a vector.
For example,
a.xx * [2 0] + a.yy * [0 3]
can be optimized to
a * [2 3]

I have implemented this for float. I know that this assumes x * 0 = 0, which is not
IEEE compliant (x * 0 is NaN when x is NaN or infinite), but I am posting it here in case
it is interesting to someone. Still on my wish list is an option for target-independent
optimizations to assume x * 0 = 0.

-Jochen

static void getIntVector(Value* value, SmallVector<int, 8>& values)
{
     if (llvm::ConstantVector* constantVector = llvm::dyn_cast<llvm::ConstantVector>(value))
     {
         // get components
         llvm::SmallVector<llvm::Constant*, 8> elements;
         constantVector->getVectorElements(elements);
         int numElements = int(elements.size());
         for (int i = 0; i < numElements; ++i)
         {
             if (llvm::ConstantInt* element = llvm::dyn_cast<llvm::ConstantInt>(elements[i]))
                 values[i] = int(element->getZExtValue());
         }
     }
}

Add the following at the end of InstCombiner::visitFAdd:

   // check for constant diagonal matrix * vector: a.xx * [2 0] + a.yy * [0 3] --> a * [2 3]
   BinaryOperator* leftMul = dyn_cast<BinaryOperator>(LHS);
   BinaryOperator* rightMul = dyn_cast<BinaryOperator>(RHS);
   if (leftMul != NULL && rightMul != NULL && leftMul->getOpcode() == Instruction::FMul && rightMul->getOpcode() == Instruction::FMul)
   {
     ShuffleVectorInst* leftShuffle = dyn_cast<ShuffleVectorInst>(leftMul->getOperand(0));
     ShuffleVectorInst* rightShuffle = dyn_cast<ShuffleVectorInst>(rightMul->getOperand(0));

     // get multiplication constant vectors (e.g. [0 1])
     ConstantVector* leftConstVector = llvm::dyn_cast<ConstantVector>(leftMul->getOperand(1));
     ConstantVector* rightConstVector = llvm::dyn_cast<ConstantVector>(rightMul->getOperand(1));

     if (leftShuffle != NULL && rightShuffle != NULL && leftConstVector != NULL && rightConstVector != NULL)
     {
         Value* value = leftShuffle->getOperand(0);
         if (value == rightShuffle->getOperand(0))
         {
             int numElements = cast<VectorType>(I.getType())->getNumElements();

             // get shuffle masks (e.g. .xx)
             SmallVector<int, 8> leftMask(numElements);
             SmallVector<int, 8> rightMask(numElements);
             getIntVector(leftShuffle->getOperand(2), leftMask);
             getIntVector(rightShuffle->getOperand(2), rightMask);

             SmallVector<Constant*, 8> leftConsts;
             SmallVector<Constant*, 8> rightConsts;
             leftConstVector->getVectorElements(leftConsts);
             rightConstVector->getVectorElements(rightConsts);

             if (leftConsts.size() == numElements && rightConsts.size() == numElements)
             {
                 SmallVector<Constant*, 8> newShuffleMask(numElements);
                 SmallVector<Constant*, 8> newConst(numElements);

                 int i;
                 bool noShuffle = true;
                 for (i = 0; i < numElements; ++i)
                 {
                     // get shuffle indices
                     int leftIndex = leftMask[i];
                     int rightIndex = rightMask[i];

                     // check if indices access the first vector
                     if (leftIndex >= numElements && rightIndex >= numElements)
                         break;

                     // get values from constant vectors
                     ConstantFP* leftConst = dyn_cast<ConstantFP>(leftConsts[i]);
                     ConstantFP* rightConst = dyn_cast<ConstantFP>(rightConsts[i]);

                     // check if valid
                     if (leftConst == NULL || rightConst == NULL)
                         break;

                     // check if at least one is zero
                     if (!leftConst->isZero() && !rightConst->isZero())
                         break;

                     // assign dependent on constant
                     int index = leftIndex;
                     ConstantFP* constant = leftConst;
                     if (!rightConst->isZero())
                     {
                         index = rightIndex;
                         constant = rightConst;
                     }

                     newShuffleMask[i] = Builder->getInt32(index);
                     newConst[i] = constant;

                     noShuffle &= index == i;
                 }

                 // check if we made it through
                 if (i == numElements)
                 {
                     Value* newShuffle = noShuffle ? value : Builder->CreateShuffleVector(
                         value, leftShuffle->getOperand(1), ConstantVector::get(newShuffleMask), "shuffle");
                     return BinaryOperator::CreateFMul(newShuffle, ConstantVector::get(newConst), "mul");
                 }
             }
         }
     }
   }