Here’s my attempt at using LinalgInterchangePattern. Sorry if there is more code than there needs to be. I tried to trim down from a bigger file, but I’m kinda new to this.
I wrote a file called interchange.cpp with the following contents
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/MlirOptMain.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <filesystem>
#include <fstream>
#include <string>
#include <unistd.h>
using namespace mlir;
using namespace mlir::linalg;
using llvm::Error;
using llvm::Expected;
using llvm::StringError;
using llvm::Twine;
namespace mlir {
namespace {
namespace cl = llvm::cl;
struct Options {
cl::opt<std::string> inputFile{cl::Positional,
cl::desc("the input .mlir file"),
cl::init("")};
// Matrix multiplication sizes
cl::opt<int> M{"M", cl::Required,
cl::desc("Number of rows of first matrix"),
cl::init(2)};
cl::opt<int> N{"N", cl::Required,
cl::desc("Number of column of second matrix"),
cl::init(4)};
cl::opt<int> K{"K", cl::Required,
cl::desc("Number of columns of first matrix"),
cl::init(8)};
cl::opt<std::string> interchangeVector{"interchange-vector",
cl::desc("' '-separated tuple. Specifies the desired permutation"),
cl::init("0 1 2")};
};
}
template <class T>
static void convertToVector(const std::string &tileSizes, llvm::SmallVectorImpl<T> &sizes) {
std::stringstream ss(tileSizes);
int size;
while (ss >> size) {
sizes.push_back(size);
}
}
namespace {
struct InterchangeMatmulLoops : public PassWrapper<InterchangeMatmulLoops, FunctionPass> {
InterchangeMatmulLoops() = default;
InterchangeMatmulLoops(SmallVector<unsigned, 4> interchange) {
interchangeVector = interchange;
}
void runOnFunction() override {
MLIRContext *context = getFunction().getContext();
OwningRewritePatternList patterns;
patterns.insert<LinalgInterchangePattern<GenericOp>>(context, interchangeVector);
DimOp::getCanonicalizationPatterns(patterns, context); // not sure if i need this
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
SmallVector<unsigned, 4> interchangeVector = {0, 1, 2};
};
std::unique_ptr<OperationPass<FuncOp>> createInterchangeMatmulLoopsPass(std::string interchange_str) {
llvm::SmallVector<unsigned int, 4> interchange;
convertToVector(interchange_str, interchange);
return std::make_unique<InterchangeMatmulLoops>(interchange);
}
}
}
//===----------------------------------------------------------------------===//
/// Wrap a string into an llvm::StringError.
static Error make_string_error(const Twine &message) {
return llvm::make_error<StringError>(message.str(), llvm::inconvertibleErrorCode());
}
Error compile(Options &options, mlir::DialectRegistry ®istry) {
MLIRContext context;
registry.loadAll(&context);
llvm::errs() << "Read file: " << options.inputFile << "\n";
OwningModuleRef moduleRef = parseSourceFile(options.inputFile, &context);
if (!moduleRef)
return make_string_error(Twine("could not open ") + options.inputFile);
ModuleOp module = *moduleRef;
PassManager pm(module.getContext(), OpPassManager::Nesting::Implicit);
pm.addPass(createLinalgGeneralizationPass());
pm.addPass(createInterchangeMatmulLoopsPass(options.interchangeVector));
pm.addPass(createConvertLinalgToLoopsPass());
if (failed(pm.run(module))) {
return make_string_error(Twine("error compiling to llvm backend"));
}
std::string moduleStr;
llvm::raw_string_ostream ss(moduleStr);
ss << module;
std::string name = std::filesystem::path(std::string(options.inputFile)).stem();
name += "-interchanged.mlir";
std::ofstream output(name);
output << moduleStr;
output.close();
return Error::success();
}
int main(int argc, char **argv) {
mlir::DialectRegistry registry;
mlir::registerAllDialects(registry);
mlir::registerAllPasses();
llvm::InitializeNativeTarget();
mlir::registerPassManagerCLOptions();
Options options;
llvm::cl::ParseCommandLineOptions(argc, argv, "interchange\n");
auto error = compile(options, registry);
int exitCode = EXIT_SUCCESS;
llvm::handleAllErrors(std::move(error), [&exitCode](const llvm::ErrorInfoBase &info) {
llvm::errs() << "Error: ";
info.log(llvm::errs());
llvm::errs() << '\n';
exitCode = EXIT_FAILURE;
});
return exitCode;
}
I’ll be using the same starting mlir as before (a call to linalg.matmul of size 2x4x8).
Run the command
./interchange linalg-matmul.mlir -M 2 -N 4 -K 8 -interchange-vector "2 0 1"
This yields the following
module {
func @matmul(%arg0: memref<2x8xf64>, %arg1: memref<8x4xf64>, %arg2: memref<2x4xf64>) {
%c2 = constant 2 : index
%c8 = constant 8 : index
%c4 = constant 4 : index
%c0 = constant 0 : index
%c1 = constant 1 : index
scf.for %arg3 = %c0 to %c8 step %c1 {
scf.for %arg4 = %c0 to %c2 step %c1 {
scf.for %arg5 = %c0 to %c4 step %c1 {
%0 = load %arg0[%arg4, %arg3] : memref<2x8xf64>
%1 = load %arg1[%arg3, %arg5] : memref<8x4xf64>
%2 = load %arg2[%arg4, %arg5] : memref<2x4xf64>
%3 = mulf %0, %1 : f64
%4 = addf %2, %3 : f64
store %4, %arg2[%arg4, %arg5] : memref<2x4xf64>
}
}
}
return
}
}
It correctly permutes the loop order as well as the indices correspondingly.
Now run
./interchange linalg-matmul.mlir -M 2 -N 4 -K 8 -interchange-vector "2 1 0"
This produces the following
module {
func @matmul(%arg0: memref<2x8xf64>, %arg1: memref<8x4xf64>, %arg2: memref<2x4xf64>) {
%c2 = constant 2 : index
%c8 = constant 8 : index
%c4 = constant 4 : index
%c0 = constant 0 : index
%c1 = constant 1 : index
scf.for %arg3 = %c0 to %c2 step %c1 {
scf.for %arg4 = %c0 to %c4 step %c1 {
scf.for %arg5 = %c0 to %c8 step %c1 {
%0 = load %arg0[%arg3, %arg5] : memref<2x8xf64>
%1 = load %arg1[%arg5, %arg4] : memref<8x4xf64>
%2 = load %arg2[%arg3, %arg4] : memref<2x4xf64>
%3 = mulf %0, %1 : f64
%4 = addf %2, %3 : f64
store %4, %arg2[%arg3, %arg4] : memref<2x4xf64>
}
}
}
return
}
}
which isn’t correct. In fact, nothing has been permuted.
Am I doing something wrong here, or is this a bug? Thanks again for your help.