(Long?!) Pattern Matching

Hi,

I have the following arithmetic pattern:

void denseFma(float *a, float *b, float *c, int N) {
  float tmp = 0;
  for (int i = 0; i < N; ++i) {
    tmp = fma(a[i * 10 + 0], b[i * 8 + 0], tmp);
    tmp = fma(a[i * 10 + 1], b[i * 8 + 1], tmp);
    tmp = fma(a[i * 10 + 2], b[i * 8 + 2], tmp);
    ..
    tmp = fma(a[i * 10 + 7], b[i * 8 + 7], tmp);
  }
  *c = tmp;
}

Its LLVM IR looks like this:

for.body:
  %mul = mul i32 %i.0, 10
  %idxprom = zext i32 %mul to i64
  %arrayidx = getelementptr inbounds float, float* %a, i64 %idxprom
  %0 = load float, float* %arrayidx
  %mul1 = mul i32 %i.0, 8
  %idxprom3 = zext i32 %mul1 to i64
  %arrayidx4 = getelementptr inbounds float, float* %b, i64 %idxprom3
  %1 = load float, float* %arrayidx4
  %call = call float @llvm.fma.f32(float %0, float %1, float %tmp.0)
  %add6 = add i32 %mul, 1
  %idxprom7 = zext i32 %add6 to i64
  %arrayidx8 = getelementptr inbounds float, float* %a, i64 %idxprom7
  %2 = load float, float* %arrayidx8
  %add10 = add i32 %mul1, 1
  %idxprom11 = zext i32 %add10 to i64
  %arrayidx12 = getelementptr inbounds float, float* %b, i64 %idxprom11
  %3 = load float, float* %arrayidx12
  %call13 = call float @llvm.fma.f32(float %2, float %3, float %call)
  .....

I want to recognize this pattern and convert it into a call to a special builtin function:

void denseFma(float *a, float *b, float *c, int N) {
  float tmp = 0;
  for (int i = 0; i < N; ++i) {
    tmp = specialDenseFmaBuiltin_len8(&a[i * 10], &b[i * 8], tmp);
  }
  *c = tmp;
}

I implemented a pass to recognize this pattern. It searches a basic block for a chain of fma calls. If it finds one, it checks that each pair of neighboring fma calls reads the expected consecutive elements of a and b, which means I have to compare GEPs. For example:

  %arrayidx = getelementptr inbounds float, float* %a, i64 %idxprom
...
  %arrayidx8 = getelementptr inbounds float, float* %a, i64 %idxprom7

I need to check whether %arrayidx8 - %arrayidx equals an offset of exactly one float, and so on for every other pair in the chain.
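
To make the check concrete, here is a minimal standalone sketch in plain C. It is not LLVM code. It assumes each GEP index has already been reduced to the form scale * i + offset; in the real pass that split would have to be recovered from the IR, including looking through the zext. The type AffineIndex and both function names are hypothetical.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

/* Hypothetical model of one GEP index: index = scale * i + offset,
   e.g. a[i * 10 + 1] has scale 10 and offset 1. */
typedef struct {
  long scale;
  long offset;
} AffineIndex;

/* If both indices share the same symbolic part, store their constant
   distance (in elements) in *diff and return true. Otherwise the
   distance is not a known constant and the function returns false. */
static bool constantDistance(AffineIndex from, AffineIndex to, long *diff) {
  assert(diff != NULL);
  if (from.scale != to.scale)
    return false;
  *diff = to.offset - from.offset;
  return true;
}

/* A chain of n accesses is "dense" when every access sits exactly one
   element after the previous one. */
static bool isDenseChain(const AffineIndex *idx, size_t n) {
  for (size_t k = 1; k < n; ++k) {
    long diff;
    if (!constantDistance(idx[k - 1], idx[k], &diff) || diff != 1)
      return false;
  }
  return true;
}
```

For the loop above, the a operands would be modeled as {10, 0}, {10, 1}, ..., {10, 7} and the b operands as {8, 0}, ..., {8, 7}. The pattern matches only if both chains are dense.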

Things get more complex because developers can write the same computation in many different ways, for example:

void denseFma(float8 a, float8 b, float *c) {
  float tmp = 0;
  tmp = fma(a.s0, b.s0, tmp);
  tmp = fma(a.s1, b.s1, tmp);
  ...

As a result, my implementation supports only a few patterns.

LLVM provides include/llvm/IR/PatternMatch.h to simplify matching LLVM IR, but I don't think it fits my case, because I don't know how to use it to express a condition like (GEP - GEP == a required offset).
Is there a systematic way in LLVM to handle this kind of “(long?!) pattern matching” problem?

Any comments would be greatly appreciated!
Thanks :)
CY