diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h index 9f8da83ae9c20..86bcd580c1b44 100644 --- a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h +++ b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h @@ -52,6 +52,15 @@ class LLVMImportDialectInterface return failure(); } + /// Hook for derived dialect interfaces to implement the import of + /// instructions into MLIR. + virtual LogicalResult + convertInstruction(OpBuilder &builder, llvm::Instruction *inst, + ArrayRef llvmOperands, + LLVM::ModuleImport &moduleImport) const { + return failure(); + } + /// Hook for derived dialect interfaces to implement the import of metadata /// into MLIR. Attaches the converted metadata kind and node to the provided /// operation. @@ -66,6 +75,14 @@ class LLVMImportDialectInterface /// returns the list of supported intrinsic identifiers. virtual ArrayRef getSupportedIntrinsics() const { return {}; } + /// Hook for derived dialect interfaces to publish the supported instructions. + /// As every LLVM IR instruction has a unique integer identifier, the function + /// returns the list of supported instruction identifiers. These identifiers + /// will then be used to match LLVM instructions to the appropriate import + /// interface and `convertInstruction` method. It is an error to have multiple + /// interfaces overriding the same instruction. + virtual ArrayRef getSupportedInstructions() const { return {}; } + /// Hook for derived dialect interfaces to publish the supported metadata /// kinds. As every metadata kind has a unique integer identifier, the /// function returns the list of supported metadata identifiers. @@ -88,21 +105,40 @@ class LLVMImportInterface LogicalResult initializeImport() { for (const LLVMImportDialectInterface &iface : *this) { // Verify the supported intrinsics have not been mapped before. - const auto *it = + const auto *intrinsicIt = llvm::find_if(iface.getSupportedIntrinsics(), [&](unsigned id) { return intrinsicToDialect.count(id); }); - if (it != iface.getSupportedIntrinsics().end()) { + if (intrinsicIt != iface.getSupportedIntrinsics().end()) { + return emitError( + UnknownLoc::get(iface.getContext()), + llvm::formatv( + "expected unique conversion for intrinsic ({0}), but " + "got conflicting {1} and {2} conversions", + *intrinsicIt, iface.getDialect()->getNamespace(), + intrinsicToDialect.lookup(*intrinsicIt)->getNamespace())); + } + const auto *instructionIt = + llvm::find_if(iface.getSupportedInstructions(), [&](unsigned id) { + return instructionToDialect.count(id); + }); + if (instructionIt != iface.getSupportedInstructions().end()) { return emitError( UnknownLoc::get(iface.getContext()), - llvm::formatv("expected unique conversion for intrinsic ({0}), but " - "got conflicting {1} and {2} conversions", - *it, iface.getDialect()->getNamespace(), - intrinsicToDialect.lookup(*it)->getNamespace())); + llvm::formatv( + "expected unique conversion for instruction ({0}), but " + "got conflicting {1} and {2} conversions", + *intrinsicIt, iface.getDialect()->getNamespace(), + instructionToDialect.lookup(*intrinsicIt) + ->getDialect() + ->getNamespace())); } // Add a mapping for all supported intrinsic identifiers. for (unsigned id : iface.getSupportedIntrinsics()) intrinsicToDialect[id] = iface.getDialect(); + // Add a mapping for all supported instruction identifiers. + for (unsigned id : iface.getSupportedInstructions()) + instructionToDialect[id] = &iface; // Add a mapping for all supported metadata kinds. for (unsigned kind : iface.getSupportedMetadata()) metadataToDialect[kind].push_back(iface.getDialect()); @@ -132,6 +168,26 @@ class LLVMImportInterface return intrinsicToDialect.count(id); } + /// Converts the LLVM instruction to an MLIR operation if a conversion exists. + /// Returns failure otherwise. + LogicalResult convertInstruction(OpBuilder &builder, llvm::Instruction *inst, + ArrayRef llvmOperands, + LLVM::ModuleImport &moduleImport) const { + // Lookup the dialect interface for the given instruction. + const LLVMImportDialectInterface *iface = + instructionToDialect.lookup(inst->getOpcode()); + if (!iface) + return failure(); + + return iface->convertInstruction(builder, inst, llvmOperands, moduleImport); + } + + /// Returns true if the given LLVM IR instruction is convertible to an MLIR + /// operation. + bool isConvertibleInstruction(unsigned id) { + return instructionToDialect.count(id); + } + /// Attaches the given LLVM metadata to the imported operation if a conversion /// to one or more MLIR dialect attributes exists and succeeds. Returns /// success if at least one of the conversions is successful and failure if @@ -166,6 +222,7 @@ class LLVMImportInterface private: DenseMap intrinsicToDialect; + DenseMap instructionToDialect; DenseMap> metadataToDialect; }; diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 6e70d52fa760b..af998b99d511f 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -123,12 +123,18 @@ static SmallVector getPositionFromIndices(ArrayRef indices) { /// access to the private module import methods. static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder, llvm::Instruction *inst, - ModuleImport &moduleImport) { + ModuleImport &moduleImport, + LLVMImportInterface &iface) { // Copy the operands to an LLVM operands array reference for conversion. SmallVector operands(inst->operands()); ArrayRef llvmOperands(operands); // Convert all instructions that provide an MLIR builder. + if (iface.isConvertibleInstruction(inst->getOpcode())) + return iface.convertInstruction(odsBuilder, inst, llvmOperands, + moduleImport); + // TODO: Implement the `convertInstruction` hooks in the + // `LLVMDialectLLVMIRImportInterface` and move the following include there. #include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc" return failure(); } @@ -1596,7 +1602,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { } // Convert all instructions that have an mlirBuilder. - if (succeeded(convertInstructionImpl(builder, inst, *this))) + if (succeeded(convertInstructionImpl(builder, inst, *this, iface))) return success(); return emitError(loc) << "unhandled instruction: " << diag(*inst); diff --git a/mlir/test/Target/LLVMIR/Import/test.ll b/mlir/test/Target/LLVMIR/Import/test.ll new file mode 100644 index 0000000000000..a3165d6020104 --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/test.ll @@ -0,0 +1,11 @@ +; RUN: mlir-translate -test-import-llvmir %s | FileCheck %s + +; CHECK-LABEL: @custom_load +; CHECK-SAME: %[[PTR:[[:alnum:]]+]] +define double @custom_load(ptr %ptr) { + ; CHECK: %[[LOAD:[0-9]+]] = llvm.load %[[PTR]] : !llvm.ptr -> f64 + ; CHECK: %[[TEST:[0-9]+]] = "test.same_operand_element_type"(%[[LOAD]], %[[LOAD]]) : (f64, f64) -> f64 + %1 = load double, ptr %ptr + ; CHECK: llvm.return %[[TEST]] : f64 + ret double %1 +} diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt index b82b1631eead5..47ddcf6524748 100644 --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES TestDialect.cpp TestPatterns.cpp TestTraits.cpp + TestFromLLVMIRTranslation.cpp TestToLLVMIRTranslation.cpp ) @@ -86,6 +87,23 @@ add_mlir_library(MLIRTestDialect MLIRTransforms ) +add_mlir_translation_library(MLIRTestFromLLVMIRTranslation + TestFromLLVMIRTranslation.cpp + + EXCLUDE_FROM_LIBMLIR + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMDialect + MLIRTestDialect + MLIRSupport + MLIRTargetLLVMIRImport + MLIRLLVMIRToLLVMTranslation +) + add_mlir_translation_library(MLIRTestToLLVMIRTranslation TestToLLVMIRTranslation.cpp diff --git a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp new file mode 100644 index 0000000000000..3673d62bea2c9 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp @@ -0,0 +1,111 @@ +//===- TestFromLLVMIRTranslation.cpp - Import Test dialect from LLVM IR ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between LLVM IR and the MLIR Test dialect. +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h" +#include "mlir/Target/LLVMIR/Import.h" +#include "mlir/Target/LLVMIR/ModuleImport.h" +#include "mlir/Tools/mlir-translate/Translation.h" + +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Support/SourceMgr.h" + +using namespace mlir; +using namespace test; + +static ArrayRef getSupportedInstructionsImpl() { + static unsigned instructions[] = {llvm::Instruction::Load}; + return instructions; +} + +static LogicalResult convertLoad(OpBuilder &builder, llvm::Instruction *inst, + ArrayRef llvmOperands, + LLVM::ModuleImport &moduleImport) { + FailureOr addr = moduleImport.convertValue(llvmOperands[0]); + if (failed(addr)) + return failure(); + // Create the LoadOp + Value loadOp = builder.create( + moduleImport.translateLoc(inst->getDebugLoc()), + moduleImport.convertType(inst->getType()), *addr); + moduleImport.mapValue(inst) = builder.create( + loadOp.getLoc(), loadOp.getType(), loadOp, loadOp); + return success(); +} + +namespace { +class TestDialectLLVMImportDialectInterface + : public LLVMImportDialectInterface { +public: + using LLVMImportDialectInterface::LLVMImportDialectInterface; + + LogicalResult + convertInstruction(OpBuilder &builder, llvm::Instruction *inst, + ArrayRef llvmOperands, + LLVM::ModuleImport &moduleImport) const override { + switch (inst->getOpcode()) { + case llvm::Instruction::Load: + return convertLoad(builder, inst, llvmOperands, moduleImport); + default: + break; + } + return failure(); + } + + ArrayRef getSupportedInstructions() const override { + return getSupportedInstructionsImpl(); + } +}; +} // namespace + +namespace mlir { +void registerTestFromLLVMIR() { + TranslateToMLIRRegistration registration( + "test-import-llvmir", "test dialect from LLVM IR", + [](llvm::SourceMgr &sourceMgr, + MLIRContext *context) -> OwningOpRef { + llvm::SMDiagnostic err; + llvm::LLVMContext llvmContext; + std::unique_ptr llvmModule = + llvm::parseIR(*sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), + err, llvmContext); + if (!llvmModule) { + std::string errStr; + llvm::raw_string_ostream errStream(errStr); + err.print(/*ProgName=*/"", errStream); + emitError(UnknownLoc::get(context)) << errStream.str(); + return {}; + } + if (llvm::verifyModule(*llvmModule, &llvm::errs())) + return nullptr; + + return translateLLVMIRToModule(std::move(llvmModule), context, false); + }, + [](DialectRegistry ®istry) { + registry.insert(); + registry.insert(); + registerLLVMDialectImport(registry); + registry.addExtension( + +[](MLIRContext *ctx, test::TestDialect *dialect) { + dialect->addInterfaces(); + }); + }); +} +} // namespace mlir diff --git a/mlir/tools/mlir-translate/mlir-translate.cpp b/mlir/tools/mlir-translate/mlir-translate.cpp index 4f9661c058c2d..309def888a073 100644 --- a/mlir/tools/mlir-translate/mlir-translate.cpp +++ b/mlir/tools/mlir-translate/mlir-translate.cpp @@ -23,6 +23,7 @@ void registerTestRoundtripSPIRV(); void registerTestRoundtripDebugSPIRV(); #ifdef MLIR_INCLUDE_TESTS void registerTestToLLVMIR(); +void registerTestFromLLVMIR(); #endif } // namespace mlir @@ -31,6 +32,7 @@ static void registerTestTranslations() { registerTestRoundtripDebugSPIRV(); #ifdef MLIR_INCLUDE_TESTS registerTestToLLVMIR(); + registerTestFromLLVMIR(); #endif }