diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 881b6dad2b84d..99c5e3f46b04c 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -73,6 +73,12 @@ MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx); // Floating-point types. //===----------------------------------------------------------------------===// +/// Checks whether the given type is a floating-point type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type); + +/// Returns the bitwidth of a floating-point type. +MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type); + /// Returns the typeID of an Float8E5M2 type. MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 820992de65906..e1e4eb999b3aa 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -109,8 +109,22 @@ class PyIndexType : public PyConcreteType { } }; +class PyFloatType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat; + static constexpr const char *pyClassName = "FloatType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_property_readonly( + "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); }, + "Returns the width of the floating-point type"); + } +}; + /// Floating Point Type subclass - Float8E4M3FNType. -class PyFloat8E4M3FNType : public PyConcreteType { +class PyFloat8E4M3FNType + : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -130,7 +144,7 @@ class PyFloat8E4M3FNType : public PyConcreteType { }; /// Floating Point Type subclass - Float8M5E2Type. -class PyFloat8E5M2Type : public PyConcreteType { +class PyFloat8E5M2Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -150,7 +164,8 @@ class PyFloat8E5M2Type : public PyConcreteType { }; /// Floating Point Type subclass - Float8E4M3FNUZ. -class PyFloat8E4M3FNUZType : public PyConcreteType { +class PyFloat8E4M3FNUZType + : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -170,7 +185,8 @@ class PyFloat8E4M3FNUZType : public PyConcreteType { }; /// Floating Point Type subclass - Float8E4M3B11FNUZ. -class PyFloat8E4M3B11FNUZType : public PyConcreteType { +class PyFloat8E4M3B11FNUZType + : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -190,7 +206,8 @@ class PyFloat8E4M3B11FNUZType : public PyConcreteType { }; /// Floating Point Type subclass - Float8E5M2FNUZ. -class PyFloat8E5M2FNUZType : public PyConcreteType { +class PyFloat8E5M2FNUZType + : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -210,7 +227,7 @@ class PyFloat8E5M2FNUZType : public PyConcreteType { }; /// Floating Point Type subclass - BF16Type. -class PyBF16Type : public PyConcreteType { +class PyBF16Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -230,7 +247,7 @@ class PyBF16Type : public PyConcreteType { }; /// Floating Point Type subclass - F16Type. -class PyF16Type : public PyConcreteType { +class PyF16Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -250,7 +267,7 @@ class PyF16Type : public PyConcreteType { }; /// Floating Point Type subclass - TF32Type. -class PyTF32Type : public PyConcreteType { +class PyTF32Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -270,7 +287,7 @@ class PyTF32Type : public PyConcreteType { }; /// Floating Point Type subclass - F32Type. -class PyF32Type : public PyConcreteType { +class PyF32Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -290,7 +307,7 @@ class PyF32Type : public PyConcreteType { }; /// Floating Point Type subclass - F64Type. -class PyF64Type : public PyConcreteType { +class PyF64Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -819,6 +836,7 @@ class PyOpaqueType : public PyConcreteType { void mlir::python::populateIRTypes(py::module &m) { PyIntegerType::bind(m); + PyFloatType::bind(m); PyIndexType::bind(m); PyFloat8E4M3FNType::bind(m); PyFloat8E5M2Type::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 18c9414c5d0f3..e1a5d82587cf9 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -78,6 +78,14 @@ MlirType mlirIndexTypeGet(MlirContext ctx) { // Floating-point types. //===----------------------------------------------------------------------===// +bool mlirTypeIsAFloat(MlirType type) { + return llvm::isa(unwrap(type)); +} + +unsigned mlirFloatTypeGetWidth(MlirType type) { + return llvm::cast(unwrap(type)).getWidth(); +} + MlirTypeID mlirFloat8E5M2TypeGetTypeID() { return wrap(Float8E5M2Type::getTypeID()); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 344abb64a57d2..586bf7f8e93fb 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -1442,7 +1442,17 @@ class DictAttr(Attribute): @property def typeid(self) -> TypeID: ... -class F16Type(Type): +class FloatType(Type): + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def width(self) -> int: + """ + Returns the width of the floating-point type. + """ + +class F16Type(FloatType): static_typeid: ClassVar[TypeID] # value = @staticmethod def get(context: Optional[Context] = None) -> F16Type: @@ -1455,7 +1465,7 @@ class F16Type(Type): @property def typeid(self) -> TypeID: ... -class F32Type(Type): +class F32Type(FloatType): static_typeid: ClassVar[TypeID] # value = @staticmethod def get(context: Optional[Context] = None) -> F32Type: @@ -1468,7 +1478,7 @@ class F32Type(Type): @property def typeid(self) -> TypeID: ... -class F64Type(Type): +class F64Type(FloatType): static_typeid: ClassVar[TypeID] # value = @staticmethod def get(context: Optional[Context] = None) -> F64Type: @@ -1502,7 +1512,7 @@ class FlatSymbolRefAttr(Attribute): Returns the value of the FlatSymbolRef attribute as a string """ -class Float8E4M3B11FNUZType(Type): +class Float8E4M3B11FNUZType(FloatType): static_typeid: ClassVar[TypeID] # value = @staticmethod def get(context: Optional[Context] = None) -> Float8E4M3B11FNUZType: @@ -1515,7 +1525,7 @@ class Float8E4M3B11FNUZType(Type): @property def typeid(self) -> TypeID: ... -class Float8E4M3FNType(Type): +class Float8E4M3FNType(FloatType): static_typeid: ClassVar[TypeID] # value = @staticmethod def get(context: Optional[Context] = None) -> Float8E4M3FNType: @@ -1528,7 +1538,7 @@ class Float8E4M3FNType(Type): @property def typeid(self) -> TypeID: ... -class Float8E4M3FNUZType(Type): +class Float8E4M3FNUZType(FloatType): static_typeid: ClassVar[TypeID] # value = @staticmethod def get(context: Optional[Context] = None) -> Float8E4M3FNUZType: @@ -1541,7 +1551,7 @@ class Float8E4M3FNUZType(Type): @property def typeid(self) -> TypeID: ... -class Float8E5M2FNUZType(Type): +class Float8E5M2FNUZType(FloatType): static_typeid: ClassVar[TypeID] # value = @staticmethod def get(context: Optional[Context] = None) -> Float8E5M2FNUZType: @@ -1554,7 +1564,7 @@ class Float8E5M2FNUZType(Type): @property def typeid(self) -> TypeID: ... -class Float8E5M2Type(Type): +class Float8E5M2Type(FloatType): static_typeid: ClassVar[TypeID] # value = @staticmethod def get(context: Optional[Context] = None) -> Float8E5M2Type: @@ -1601,7 +1611,7 @@ class FloatAttr(Attribute): Returns the value of the float attribute """ -class FloatTF32Type(Type): +class FloatTF32Type(FloatType): static_typeid: ClassVar[TypeID] # value = @staticmethod def get(context: Optional[Context] = None) -> FloatTF32Type: diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py index 30a5054ada91a..4eea1a9c372ef 100644 --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -100,8 +100,38 @@ def testTypeIsInstance(): print(IntegerType.isinstance(t1)) # CHECK: False print(F32Type.isinstance(t1)) + # CHECK: False + print(FloatType.isinstance(t1)) # CHECK: True print(F32Type.isinstance(t2)) + # CHECK: True + print(FloatType.isinstance(t2)) + + +# CHECK-LABEL: TEST: testFloatTypeSubclasses +@run +def testFloatTypeSubclasses(): + ctx = Context() + # CHECK: True + print(isinstance(Type.parse("f8E4M3FN", ctx), FloatType)) + # CHECK: True + print(isinstance(Type.parse("f8E5M2", ctx), FloatType)) + # CHECK: True + print(isinstance(Type.parse("f8E4M3FNUZ", ctx), FloatType)) + # CHECK: True + print(isinstance(Type.parse("f8E4M3B11FNUZ", ctx), FloatType)) + # CHECK: True + print(isinstance(Type.parse("f8E5M2FNUZ", ctx), FloatType)) + # CHECK: True + print(isinstance(Type.parse("f16", ctx), FloatType)) + # CHECK: True + print(isinstance(Type.parse("bf16", ctx), FloatType)) + # CHECK: True + print(isinstance(Type.parse("f32", ctx), FloatType)) + # CHECK: True + print(isinstance(Type.parse("tf32", ctx), FloatType)) + # CHECK: True + print(isinstance(Type.parse("f64", ctx), FloatType)) # CHECK-LABEL: TEST: testTypeEqDoesNotRaise @@ -218,7 +248,10 @@ def testFloatType(): # CHECK: float: f32 print("float:", F32Type.get()) # CHECK: float: f64 - print("float:", F64Type.get()) + f64 = F64Type.get() + print("float:", f64) + # CHECK: f64 width: 64 + print("f64 width:", f64.width) # CHECK-LABEL: TEST: testNoneType