-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[mlir][transform] Add an op for replacing values with function calls #78398
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir-func Author: Quinn Dawkins (qedawkins) ChangesAdds The idea with this operation is to allow users to author independent IR Additionally adds a mechanism for populating a type converter with a set Depends on #78397 Patch is 36.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78398.diff 10 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
index 7a7e991c786188..e5086c26c55a4f 100644
--- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
+++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
@@ -12,6 +12,8 @@
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/RegionKindInterface.td"
include "mlir/IR/OpBase.td"
def ApplyFuncToLLVMConversionPatternsOp : Op<Transform_Dialect,
@@ -26,4 +28,67 @@ def ApplyFuncToLLVMConversionPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def CastAndCallOp : Op<Transform_Dialect,
+ "func.cast_and_call",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ AttrSizedOperandSegments,
+ ReportTrackingListenerFailuresOpTrait]
+ # GraphRegionNoTerminator.traits> {
+ let summary = "Casts values to the signature of a function and replaces them "
+ "with a call";
+ let description = [{
+ This transform takes a set of |input| and |output| value handles and
+ attempts to cast them to the function signature of the attached function
+ op, then builds a call to the function and replaces the users of the
+ outputs. It is the responsibility of the user to ensure that the slice of
+ the program replaced by this operation makes sense, i.e. there is no
+ verification that the inputs to this operation have any relation to the
+ outputs outside of basic dominance requirements needed for the replacement.
+
+ The casting materialization functions are specified in the graph region of
+ this op. They must implement the `TypeConversionOpInterface`. The order of
+ ops within the region is irrelevant.
+
+ The target function can be specified by a symbol name or by a handle to the
+ operation.
+
+ This transform only reads the target handles and only replaces the users of
+ the outputs with the results of the call. No handles are consumed and no
+ operations are removed. Users are expected to run cleanup separately if
+ desired.
+
+ This transform will emit a silenceable failure if:
+ - The set of outputs isn't unique
+ - The handle for the insertion point does not include exactly one operation
+ - The insertion point op does not dominate any of the output users
+ - The insertion point op is not dominated by any of the inputs
+ - The function signature does not match the number of inputs/outputs
+ - Any of the input conversions fail to be materialized
+
+ This transform will emit a definite failure if it fails to resolve the
+ target function, or if it fails to materialize the conversion from the call
+ results to the output types.
+ }];
+
+ let arguments = (ins
+ TransformHandleTypeInterface:$insertion_point,
+ UnitAttr:$insert_after,
+ Optional<TransformValueHandleTypeInterface>:$inputs,
+ Optional<TransformValueHandleTypeInterface>:$outputs,
+ OptionalAttr<SymbolRefAttr>:$function_name,
+ Optional<TransformHandleTypeInterface>:$function);
+ let results = (outs TransformHandleTypeInterface:$result);
+ let regions = (region MaxSizedRegion<1>:$conversions);
+
+ let assemblyFormat = [{
+ ($function_name^)? ($function^)?
+ ( `(` $inputs^ `)` )?
+ ( `->` $outputs^ )?
+ (`after` $insert_after^):(`before`)? $insertion_point
+ ($conversions^)? attr-dict `:` functional-type(operands, results)
+ }];
+ let hasVerifier = 1;
+}
+
#endif // FUNC_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 8556d9570fd120..28e9249c82e309 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -169,4 +169,17 @@ def MakeLoopIndependentOp
}];
}
+def TypeConversionCastOp : Op<Transform_Dialect,
+ "type_conversion.tensor.cast",
+ [DeclareOpInterfaceMethods<TypeConversionOpInterface>]> {
+ let description = [{
+ Indicates that tensor ops (such as tensor.generate) should be replaced with
+ constants (arith.constant) when possible.
+ }];
+ let arguments = (ins UnitAttr:$ignore_dynamic_info);
+
+ let assemblyFormat =
+ "(`ignore_dynamic_info` $ignore_dynamic_info^)? attr-dict";
+}
+
#endif // TENSOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index f29efaee620d84..3b601f42a6452d 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -280,6 +280,28 @@ def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> {
];
}
+def TypeConversionOpInterface : OpInterface<"TypeConversionOpInterface"> {
+ let description = [{
+ This interface should be implemented by ops that populate type casting
+ of a `transform.cast_and_inline` op. It provides a method to populate a
+ type converter with source/target materialization patterns.
+ }];
+
+ let cppNamespace = "::mlir::transform";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Populate the given type converter with source/target materialization
+ functions.
+ }],
+ /*returnType=*/"void",
+ /*name=*/"populateTypeMaterializations",
+ /*arguments=*/(ins "::mlir::TypeConverter &":$converter)
+ >
+ ];
+}
+
def TypeConverterBuilderOpInterface
: OpInterface<"TypeConverterBuilderOpInterface"> {
let description = [{
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index fe2c28f45aea04..6637d81dab5e2a 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -725,22 +725,43 @@ def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
"functional-type(operands, results)";
}
+def GetOperandOp : TransformDialectOp<"get_operand",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
+ let summary = "Get a handle to the operand(s) of the targeted op";
+ let description = [{
+ The handle defined by this Transform op corresponds to the Operands of the
+ given `target` operation. Optionally `operand_number` can be specified to
+ select a specific operand.
+
+ This transform fails silently if the targeted operation does not have enough
+ operands. It reads the target handle and produces the result handle.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ OptionalAttr<I64Attr>:$operand_number);
+ let results = (outs TransformValueHandleTypeInterface:$result);
+ let assemblyFormat = "$target (`[` $operand_number^ `]`)? attr-dict `:` "
+ "functional-type(operands, results)";
+}
+
def GetResultOp : TransformDialectOp<"get_result",
[DeclareOpInterfaceMethods<TransformOpInterface>,
NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
- let summary = "Get handle to the a result of the targeted op";
+ let summary = "Get a handle to the result(s) of the targeted op";
let description = [{
- The handle defined by this Transform op corresponds to the OpResult with
- `result_number` that is defined by the given `target` operation.
+ The handle defined by this Transform op correspond to the OpResults of the
+ given `target` operation. Optionally `result_number` can be specified to
+ select a specific result.
This transform fails silently if the targeted operation does not have enough
results. It reads the target handle and produces the result handle.
}];
let arguments = (ins TransformHandleTypeInterface:$target,
- I64Attr:$result_number);
+ OptionalAttr<I64Attr>:$result_number);
let results = (outs TransformValueHandleTypeInterface:$result);
- let assemblyFormat = "$target `[` $result_number `]` attr-dict `:` "
+ let assemblyFormat = "$target (`[` $result_number^ `]`)? attr-dict `:` "
"functional-type(operands, results)";
}
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 9e9b6bcea790de..14b6e633520d6c 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
@@ -36,6 +37,202 @@ transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
return success();
}
+//===----------------------------------------------------------------------===//
+// CastAndCallOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ SmallVector<Value> inputs;
+ if (getInputs())
+ for (Value input : state.getPayloadValues(getInputs()))
+ inputs.push_back(input);
+ SmallVector<Value> outputs;
+ if (getOutputs())
+ for (Value output : state.getPayloadValues(getOutputs()))
+ outputs.push_back(output);
+
+ // Verify that the set of output values to be replaced is unique.
+ llvm::SmallDenseSet<Value> outputSet;
+ for (Value output : outputs) {
+ outputSet.insert(output);
+ }
+ if (outputSet.size() != outputs.size()) {
+ return emitSilenceableFailure(getLoc())
+ << "cast and call output values must be unique";
+ }
+
+ // Get the insertion point for the call.
+ auto insertionOps = state.getPayloadOps(getInsertionPoint());
+ if (!llvm::hasSingleElement(insertionOps)) {
+ return emitSilenceableFailure(getLoc())
+ << "Only one op can be specified as an insertion point";
+ }
+ bool insertAfter = getInsertAfter();
+ Operation *insertionPoint = *insertionOps.begin();
+
+ // Check that all inputs dominate the insertion point, and the insertion
+ // point dominates all users of the outputs.
+ DominanceInfo dom(insertionPoint);
+ for (Value output : outputs) {
+ for (Operation *user : output.getUsers()) {
+ // If we are inserting after the insertion point operation, the
+ // insertion point operation must properly dominate the user. Otherwise
+ // basic dominance is enough.
+ bool doesDominate = insertAfter
+ ? dom.properlyDominates(insertionPoint, user)
+ : dom.dominates(insertionPoint, user);
+ if (!doesDominate) {
+ return emitDefiniteFailure()
+ << "User " << user << " is not dominated by insertion point "
+ << insertionPoint;
+ }
+ }
+ }
+
+ for (Value input : inputs) {
+ // If we are inserting before the insertion point operation, the
+ // input must properly dominate the insertion point operation. Otherwise
+ // basic dominance is enough.
+ bool doesDominate = insertAfter
+ ? dom.dominates(input, insertionPoint)
+ : dom.properlyDominates(input, insertionPoint);
+ if (!doesDominate) {
+ return emitDefiniteFailure()
+ << "input " << input << " does not dominate insertion point "
+ << insertionPoint;
+ }
+ }
+
+ // Get the function to inline. This can either be specified by symbol or as a
+ // transform handle.
+ func::FuncOp targetFunction = nullptr;
+ if (getFunctionName()) {
+ targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
+ insertionPoint, *getFunctionName());
+ if (!targetFunction) {
+ return emitDefiniteFailure()
+ << "unresolved symbol " << *getFunctionName();
+ }
+ } else if (getFunction()) {
+ auto payloadOps = state.getPayloadOps(getFunction());
+ if (!llvm::hasSingleElement(payloadOps)) {
+ return emitDefiniteFailure() << "requires a single function to call";
+ }
+ targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
+ if (!targetFunction) {
+ return emitDefiniteFailure() << "invalid non-function callee";
+ }
+ } else {
+ llvm_unreachable("Invalid CastAndCall op without a function to call");
+ return emitDefiniteFailure();
+ }
+ assert(targetFunction && "no target function found");
+
+ // Verify that the function argument and result lengths match the inputs and
+ // outputs given to this op.
+ if (targetFunction.getNumArguments() != inputs.size()) {
+ return emitSilenceableFailure(targetFunction.getLoc())
+ << "mismatch between number of function arguments "
+ << targetFunction.getNumArguments() << " and number of inputs "
+ << inputs.size();
+ }
+ if (targetFunction.getNumResults() != outputs.size()) {
+ return emitSilenceableFailure(targetFunction.getLoc())
+ << "mismatch between number of function results "
+ << targetFunction->getNumResults() << " and number of outputs "
+ << outputs.size();
+ }
+
+ // Gather all specified converters.
+ MLIRContext *ctx = insertionPoint->getContext();
+ mlir::TypeConverter converter;
+ if (!getRegion().empty()) {
+ for (Operation &op : getRegion().front()) {
+ cast<transform::TypeConversionOpInterface>(&op)
+ .populateTypeMaterializations(converter);
+ }
+ }
+
+ OpBuilder builder(ctx);
+ if (insertAfter)
+ builder.setInsertionPointAfter(insertionPoint);
+ else
+ builder.setInsertionPoint(insertionPoint);
+
+ for (auto [input, type] :
+ llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
+ if (input.getType() != type) {
+ Value newInput = converter.materializeSourceConversion(
+ builder, input.getLoc(), type, input);
+ if (!newInput) {
+ return emitSilenceableFailure(input.getLoc())
+ << "Failed to materialize conversion of " << input << " to type "
+ << type;
+ }
+ input = newInput;
+ }
+ }
+
+ auto callOp = builder.create<func::CallOp>(insertionPoint->getLoc(),
+ targetFunction, inputs);
+
+ // Cast the call results back to the expected types. If any conversions fail
+ // this is a definite failure as the call has been constructed at this point.
+ for (auto [output, newOutput] :
+ llvm::zip_equal(outputs, callOp.getResults())) {
+ Value convertedOutput = newOutput;
+ if (output.getType() != newOutput.getType()) {
+ convertedOutput = converter.materializeTargetConversion(
+ builder, output.getLoc(), output.getType(), newOutput);
+ if (!convertedOutput) {
+ return emitSilenceableFailure(output.getLoc())
+ << "Failed to materialize conversion of " << newOutput
+ << " to type " << output.getType();
+ }
+ }
+ output.replaceAllUsesExcept(convertedOutput, callOp);
+ }
+ results.set(cast<OpResult>(getResult()), {callOp});
+ return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::CastAndCallOp::verify() {
+ if (!getRegion().empty()) {
+ for (Operation &op : getRegion().front()) {
+ if (!isa<transform::TypeConversionOpInterface>(&op)) {
+ InFlightDiagnostic diag = emitOpError()
+ << "expected children ops to implement "
+ "TypeConversionOpInterface";
+ diag.attachNote(op.getLoc()) << "op without interface";
+ return diag;
+ }
+ }
+ }
+ if (!getFunction() && !getFunctionName()) {
+ return emitOpError() << "expected a function handle or name to call";
+ }
+ if (getFunction() && getFunctionName()) {
+ return emitOpError() << "function handle and name are mutually exclusive";
+ }
+ return success();
+}
+
+void transform::CastAndCallOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(getInsertionPoint(), effects);
+ if (getInputs())
+ transform::onlyReadsHandle(getInputs(), effects);
+ if (getOutputs())
+ transform::onlyReadsHandle(getOutputs(), effects);
+ if (getFunction())
+ transform::onlyReadsHandle(getFunction(), effects);
+ transform::producesHandle(getResult(), effects);
+ transform::modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index ed274238704713..0c89ba2a1f1895 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -15,6 +15,8 @@
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace tensor;
@@ -128,6 +130,44 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
tensor::populateRewriteAsConstantPatterns(patterns);
}
+//===----------------------------------------------------------------------===//
+// TypeConversionCastOp
+//===----------------------------------------------------------------------===//
+
+void transform::TypeConversionCastOp::populateTypeMaterializations(
+ TypeConverter &converter) {
+ bool ignoreDynamicInfo = getIgnoreDynamicInfo();
+ converter.addSourceMaterialization([ignoreDynamicInfo](
+ OpBuilder &builder, Type resultType,
+ ValueRange inputs,
+ Location loc) -> std::optional<Value> {
+ if (inputs.size() != 1) {
+ return std::nullopt;
+ }
+ Value input = inputs[0];
+ if (!ignoreDynamicInfo &&
+ !tensor::preservesStaticInformation(resultType, input.getType())) {
+ return std::nullopt;
+ }
+ if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
+ return std::nullopt;
+ }
+ return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
+ });
+ converter.addTargetMaterialization([](OpBuilder &builder, Type resultType,
+ ValueRange inputs,
+ Location loc) -> std::optional<Value> {
+ if (inputs.size() != 1) {
+ return std::nullopt;
+ }
+ Value input = inputs[0];
+ if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
+ return std::nullopt;
+ }
+ return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
+ });
+}
+
//===----------------------------------------------------------------------===//
// MakeLoopIndependentOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index b80fc09751d2aa..59524c4c14d4fe 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -16,10 +16,12 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#i...
[truncated]
|
mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
Outdated
Show resolved
Hide resolved
mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
Outdated
Show resolved
Hide resolved
This transform only reads the target handles and only replaces the users of | ||
the outputs with the results of the call. No handles are consumed and no | ||
operations are removed. Users are expected to run cleanup separately if | ||
desired. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't have a memory effect to indicate "replaces the uses"... I am pondering whether we want to invalidate handles to those users after this transformation (and we would need to take the list of users as an additional handle). They keep existing and have the same signature, so maybe we can indeed get away with not invalidating them. Throughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm thinking the users are probably not worth invalidating, but I think it probably is worth invalidating the output value handles themselves. (Also users
there is a typo, it should be uses
). If I have a handle to some consumer I can't think of a case where I'd want to invalidate that handle because one of its producers changed. That said, the uses (outputs) probably are worth invalidating because at that point those values should have no users, unless they are used by the call itself, i.e.
func.func {
%0 = foo
%1 = bar %0
}
to
func.func {
%0 = foo
%1 = call @trace(%0)
%2 = bar %1
}
We might not want to invalidate %0
in this case, but I can't think of a way to "accidentally" do this, so the caller should be aware of the IR structure at this point and can always rematch %0
. You might still be right; I haven't worked with enough examples of this op to know that I'm getting the handle invalidation rules correct here.
I was also thinking of adding a cast_and_replace_with_call
op or something like that which takes a slice of the program and does the full replacement, so users might prefer that transform if they care about being precise with handle invalidation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Invalidation is primarily meant to avoid the dangling-reference situation and secondarily to catch invariant violations. If we had a type along the lines of "a payload operation using a value produced by arithmetic ops" (some affine operations are like this), we'd want to invalidate that one since the invariant no longer holds. Similarly, I don't think that we necessarily want to invalidate a handle to the %0
value (we don't have handles to OpOperands, those would have to be invalidated): the value still exists, and we haven't made any strong promises about it. Maybe what I'm arguing for here is that we can occasionally re-run the verifier that checks if the contents of a handle corresponds to its type, which is currently different from invalidation. And maybe we should merge this process with invalidation, i.e., failing the conditions specified by the handle type invalidates the handle.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, this makes more sense then. IIUC there are no in-tree examples of something like "a payload operation using a value produced by arithmetic ops" then? (people could have written anything downstream). So would it suffice for now to leave a warning in the op description that such cases currently won't properly track invalidation? And then adding additional verification/invalidation mechanisms can come as a follow up?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we are not using much of types upstream right now. And none of the downstreams I am aware of are using them either. So it's okay to leave a warning for now.
If you have an idea where to put a summary of our discussion above so it remains visible, it would be nice.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added a link to this discussion directly to the op description. We could also crystallize the warning here in the documentation as well; I suspect there are a number of other transform ops that could run into similar validation issues so adding a disclaimer for such cases could be useful. Maybe as a note here? https://mlir.llvm.org/docs/Dialects/Transform/#handle-invalidation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe as a note here? https://mlir.llvm.org/docs/Dialects/Transform/#handle-invalidation
Yes, good idea!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, I'll send it as a follow up.
mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
Outdated
Show resolved
Hide resolved
fccd540
to
b70c13d
Compare
Adds `transform.func.cast_and_call` that takes a set of inputs and outputs and replaces the uses of those outputs with a call to a function at a specified insertion point. The idea with this operation is to allow users to author independent IR outside of a to-be-compiled module, and then match and replace a slice of the program with a call to the external function. Additionally adds a mechanism for populating a type converter with a set of conversion materialization functions that allow insertion of casts on the inputs/outputs to and from the types of the function signature.
…d address comments
b70c13d
to
e621195
Compare
Adds
transform.func.cast_and_call
that takes a set of inputs andoutputs and replaces the uses of those outputs with a call to a function
at a specified insertion point.
The idea with this operation is to allow users to author independent IR
outside of a to-be-compiled module, and then match and replace a slice of
the program with a call to the external function.
Additionally adds a mechanism for populating a type converter with a set
of conversion materialization functions that allow insertion of
casts on the inputs/outputs to and from the types of the function
signature.
Depends on #78397