Hi,
I have the following arithmetic pattern:
void denseFma(float *a, float *b, float *c, int N) {
  float tmp = 0;
  for (int i = 0; i < N; ++i) {
    tmp = fma(a[i * 10 + 0], b[i * 8 + 0], tmp);
    tmp = fma(a[i * 10 + 1], b[i * 8 + 1], tmp);
    tmp = fma(a[i * 10 + 2], b[i * 8 + 2], tmp);
    ...
    tmp = fma(a[i * 10 + 7], b[i * 8 + 7], tmp);
  }
  *c = tmp;
}
Its LLVM IR is:
for.body:
%mul = mul i32 %i.0, 10
%idxprom = zext i32 %mul to i64
%arrayidx = getelementptr inbounds float, float* %a, i64 %idxprom
%0 = load float, float* %arrayidx
%mul1 = mul i32 %i.0, 8
%idxprom3 = zext i32 %mul1 to i64
%arrayidx4 = getelementptr inbounds float, float* %b, i64 %idxprom3
%1 = load float, float* %arrayidx4
%call = call float @llvm.fma.f32(float %0, float %1, float %tmp.0)
%add6 = add i32 %mul, 1
%idxprom7 = zext i32 %add6 to i64
%arrayidx8 = getelementptr inbounds float, float* %a, i64 %idxprom7
%2 = load float, float* %arrayidx8
%add10 = add i32 %mul1, 1
%idxprom11 = zext i32 %add10 to i64
%arrayidx12 = getelementptr inbounds float, float* %b, i64 %idxprom11
%3 = load float, float* %arrayidx12
%call13 = call float @llvm.fma.f32(float %2, float %3, float %call)
.....
I want to recognize this pattern and convert it into a call to a special builtin function:
void denseFma(float *a, float *b, float *c, int N) {
  float tmp = 0;
  for (int i = 0; i < N; ++i) {
    tmp = specialDenseFmaBuiltin_len8(&a[i * 10], &b[i * 8], tmp);
  }
  *c = tmp;
}
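To make the intended rewrite concrete, here is a minimal C++ sketch of the equivalence I am after. It rests on two assumptions that are not spelled out above: that the elided lines continue the same pattern for elements 3 to 6, and that specialDenseFmaBuiltin_len8 means "a chain of eight fused multiply-adds over consecutive elements". The builtin body and the names denseFmaScalar and denseFmaBuiltin are made up for illustration; they return the sum instead of writing through c only to keep the comparison short.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical reference semantics for the builtin (an assumption, not a
// real target intrinsic): eight chained fused multiply-adds over
// consecutive elements of a and b, starting from acc.
float specialDenseFmaBuiltin_len8(const float *a, const float *b, float acc) {
    assert(a != nullptr && b != nullptr);
    for (int k = 0; k < 8; ++k)
        acc = std::fma(a[k], b[k], acc);
    return acc;
}

// Original form: the eight fma calls per iteration, written as a loop over k.
// Assumes the elided calls use offsets 3..6 in the same way as 0..2 and 7.
float denseFmaScalar(const float *a, const float *b, int N) {
    float tmp = 0.0f;
    for (int i = 0; i < N; ++i)
        for (int k = 0; k < 8; ++k)
            tmp = std::fma(a[i * 10 + k], b[i * 8 + k], tmp);
    return tmp;
}

// Rewritten form: one builtin call per iteration.
float denseFmaBuiltin(const float *a, const float *b, int N) {
    float tmp = 0.0f;
    for (int i = 0; i < N; ++i)
        tmp = specialDenseFmaBuiltin_len8(&a[i * 10], &b[i * 8], tmp);
    return tmp;
}
```

Both functions perform the same fma operations in the same order, so under these assumptions they should produce identical results; the pass only has to prove that the eight loads per operand are consecutive and that the accumulator is threaded through the chain.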
I implemented a pass to recognize this pattern. For example, I search for an fma chain in a basic block; if I find one, I then check whether each pair of neighboring fma calls uses the correct elements of a and b, so I need to compare GEPs. For example:
%arrayidx = getelementptr inbounds float, float* %a, i64 %idxprom
...
%arrayidx8 = getelementptr inbounds float, float* %a, i64 %idxprom7
I need to check whether %arrayidx8 - %arrayidx == 1 (an offset of one float element), and so on.
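As an illustration of the check described above, here is a small self-contained C++ sketch. It is a toy model, not LLVM code: the Addr and Fma structs, their integer ids, and the function names are all hypothetical, and it assumes each address has already been decomposed into a base pointer, a variable index, and a constant element offset (that decomposition is exactly the part that is hard to do on real IR).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy stand-in for an address of the form base + (index + offset),
// e.g. %a + (%mul + 1). In a real pass these fields would have to be
// recovered from the GEP and its index expression; here they are given.
struct Addr {
    int base;        // hypothetical id of the base pointer (%a, %b, ...)
    int index;       // hypothetical id of the variable part (%mul, %mul1, ...)
    int64_t offset;  // constant part, counted in elements
};

// Toy stand-in for one call fma(load a, load b, acc).
struct Fma {
    Addr a, b;
    int acc;  // position of the fma feeding the accumulator, or -1 if none
};

// True when q is exactly `delta` elements past p with the same base and
// the same variable part, i.e. q - p == delta.
bool isOffsetBy(const Addr &p, const Addr &q, int64_t delta) {
    return p.base == q.base && p.index == q.index &&
           q.offset - p.offset == delta;
}

// Checks that chain[first .. first+len) is a dense fma chain: every fma
// accumulates into the previous one and both operands advance by one element.
bool isDenseFmaChain(const std::vector<Fma> &chain, std::size_t first,
                     std::size_t len) {
    assert(len > 0);
    if (first + len > chain.size()) return false;
    for (std::size_t k = 1; k < len; ++k) {
        const Fma &prev = chain[first + k - 1];
        const Fma &cur = chain[first + k];
        if (cur.acc != static_cast<int>(first + k - 1)) return false;
        if (!isOffsetBy(prev.a, cur.a, 1)) return false;
        if (!isOffsetBy(prev.b, cur.b, 1)) return false;
    }
    return true;
}
```

In this model the comparison is trivial because the constant part is stored explicitly; the open question is how to get from two GEPs to that (base, index, offset) form in a robust way.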
Things get more complex because developers can write the same computation in many different ways, for example:
void denseFma(float8 a, float8 b, float *c) {
  float tmp = 0;
  tmp = fma(a.s0, b.s0, tmp);
  tmp = fma(a.s1, b.s1, tmp);
  ...
So my implementation only supports a few patterns.
LLVM provides include/llvm/IR/PatternMatch.h to simplify matching LLVM IR, but I feel it does not fit my case, because I don't know how to match (GEP - GEP == a required offset).
I would like to know whether there is a systematic way in LLVM to handle this kind of “(long?!) pattern-matching” problem.
Any comments would be greatly appreciated!
Thanks
CY