Hi @antiagainst,
Building on top of this review: https://reviews.llvm.org/D82384, I am trying to allow the CopyMemory
op to accept up to two memory access operands, so that it follows the SPIR-V specification more closely. I have an initial patch that parses both sets of memory operands correctly (not committed yet, since it doesn't fully work):
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
index c92af561faf..45e40993ade 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -215,7 +215,9 @@ def SPV_CopyMemoryOp : SPV_Op<"CopyMemory", []> {
SPV_AnyPtr:$target,
SPV_AnyPtr:$source,
OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
- OptionalAttr<I32Attr>:$alignment
+ OptionalAttr<I32Attr>:$alignment,
+ OptionalAttr<SPV_MemoryAccessAttr>:$source_memory_access,
+ OptionalAttr<I32Attr>:$source_alignment
);
let results = (outs);
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 58c96ea7a01..690b28f9bd1 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -29,6 +29,7 @@ using namespace mlir;
// TODO(antiagainst): generate these strings using ODS.
static constexpr const char kAlignmentAttrName[] = "alignment";
+static constexpr const char kSourceAlignmentAttrName[] = "source_alignment";
static constexpr const char kBranchWeightAttrName[] = "branch_weights";
static constexpr const char kCallee[] = "callee";
static constexpr const char kClusterSize[] = "cluster_size";
@@ -183,6 +184,35 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
return parser.parseRSquare();
}
+// Parses an optional `["<memory access>"(, <alignment>)?]` list into the
+// source_* attributes. TODO: Merge with parseMemoryAccessAttributes.
+static ParseResult parseMemoryAccessAttributes2(OpAsmParser &parser,
+ OperationState &state) {
+ // Parse an optional list of attributes starting with '['
+ if (parser.parseOptionalLSquare()) {
+ // No '[' means there is no source memory access operand.
+ return success();
+ }
+
+ spirv::MemoryAccess memoryAccessAttr;
+ if (parseEnumStrAttr(memoryAccessAttr, parser, state,
+ "source_memory_access")) {
+ return failure();
+ }
+
+ if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
+ // Aligned requires a trailing i32 alignment literal.
+ Attribute alignmentAttr;
+ Type i32Type = parser.getBuilder().getIntegerType(32);
+ if (parser.parseComma() ||
+ parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
+ state.attributes)) {
+ return failure();
+ }
+ }
+ return parser.parseRSquare();
+}
+
template <typename MemoryOpTy>
static void
printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer,
@@ -192,10 +222,12 @@ printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer,
elidedAttrs.push_back(spirv::attributeName<spirv::MemoryAccess>());
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
- // Print integer alignment attribute.
- if (auto alignment = memoryOp.alignment()) {
- elidedAttrs.push_back(kAlignmentAttrName);
- printer << ", " << alignment;
+ if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
+ // Print integer alignment attribute.
+ if (auto alignment = memoryOp.alignment()) {
+ elidedAttrs.push_back(kAlignmentAttrName);
+ printer << ", " << alignment;
+ }
}
printer << "]";
}
@@ -2862,9 +2894,24 @@ static ParseResult parseCopyMemoryOp(OpAsmParser &parser,
parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
parseEnumStrAttr(sourceStorageClass, parser) ||
parser.parseOperand(sourcePtrInfo) ||
- parseMemoryAccessAttributes(parser, state) ||
- parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() ||
- parser.parseType(elementType)) {
+ parseMemoryAccessAttributes(parser, state)) {
+ return failure();
+ }
+
+ if (parser.parseOptionalComma()) {
+ // No comma, hence, no 2nd memory access attributes.
+ } else {
+ // Parse 2nd memory access attributes.
+ if (parseMemoryAccessAttributes2(parser, state)) {
+ return failure();
+ }
+ }
+
+ if (parser.parseColon() || parser.parseType(elementType)) {
+ return failure();
+ }
+
+ if (parser.parseOptionalAttrDict(state.attributes)) {
return failure();
}
diff --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
index 25b54c05539..3a21da85f06 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
@@ -93,6 +93,9 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32
spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"] : f32
+ // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32
+ spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"], ["Volatile"] : f32
+
spv.Return
}
}
With that, the first leg of the round trip (parsing and serialization) works fine as far as I can tell. The default deserialization code, however, stands in the way. For example, for the test line in the diff above (`spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"], ["Volatile"] : f32`), this is what gets deserialized:
spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"] {alignment = 1 : i32} : f32
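So what is supposed to be interpreted as the second memory access attribute is instead interpreted as the alignment of the first one. In SPIR-V, an alignment literal only follows a memory access mask when that mask has the Aligned bit set. The code below, however, consumes the optional words purely by position. As a result, the second mask (Volatile, encoded as 1) ends up as `alignment = 1`. I believe the issue is in this section of the deserialization code: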
Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
...
if (wordIndex < words.size()) {
attributes.push_back(opBuilder.getNamedAttr("memory_access", opBuilder.getI32IntegerAttr(words[wordIndex++])));
}
if (wordIndex < words.size()) {
attributes.push_back(opBuilder.getNamedAttr("alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
}
if (wordIndex < words.size()) {
attributes.push_back(opBuilder.getNamedAttr("source_memory_access", opBuilder.getI32IntegerAttr(words[wordIndex++])));
}
if (wordIndex < words.size()) {
attributes.push_back(opBuilder.getNamedAttr("source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
}
...
}
I wanted to ask which of these approaches you would suggest:
- Customizing the deserialization logic to make it smarter (e.g., only reading an alignment word when the preceding mask has the Aligned bit set)?
- Modeling the second memory access operand differently? Maybe my lack of familiarity with ODS is keeping me from finding a simpler solution.
- Something else entirely?