-
Notifications
You must be signed in to change notification settings - Fork 14.2k
[MLIR][Python] Added a base class to all builtin floating point types #81720
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This allows to * check if a given ir.Type is a floating point type via isinstance() or issubclass() * get the bitwidth of a floating point type See motivation and discussion in https://discourse.llvm.org/t/add-floattype-to-mlir-python-bindings/76959.
@llvm/pr-subscribers-mlir Author: Sergei Lebedev (superbobry) ChangesThis allows to
See motivation and discussion in https://discourse.llvm.org/t/add-floattype-to-mlir-python-bindings/76959. Full diff: https://github.com/llvm/llvm-project/pull/81720.diff 5 Files Affected:
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 881b6dad2b84d7..99c5e3f46b04c1 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -73,6 +73,12 @@ MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx);
// Floating-point types.
//===----------------------------------------------------------------------===//
+/// Checks whether the given type is a floating-point type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type);
+
+/// Returns the bitwidth of a floating-point type.
+MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type);
+
/// Returns the typeID of an Float8E5M2 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 820992de659068..e1e4eb999b3aa8 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -109,8 +109,22 @@ class PyIndexType : public PyConcreteType<PyIndexType> {
}
};
+class PyFloatType : public PyConcreteType<PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
+ static constexpr const char *pyClassName = "FloatType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_property_readonly(
+ "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
+ "Returns the width of the floating-point type");
+ }
+};
+
/// Floating Point Type subclass - Float8E4M3FNType.
-class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
+class PyFloat8E4M3FNType
+ : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -130,7 +144,7 @@ class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
};
/// Floating Point Type subclass - Float8M5E2Type.
-class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
+class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -150,7 +164,8 @@ class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
};
/// Floating Point Type subclass - Float8E4M3FNUZ.
-class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
+class PyFloat8E4M3FNUZType
+ : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -170,7 +185,8 @@ class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
};
/// Floating Point Type subclass - Float8E4M3B11FNUZ.
-class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
+class PyFloat8E4M3B11FNUZType
+ : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -190,7 +206,8 @@ class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
};
/// Floating Point Type subclass - Float8E5M2FNUZ.
-class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
+class PyFloat8E5M2FNUZType
+ : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -210,7 +227,7 @@ class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
};
/// Floating Point Type subclass - BF16Type.
-class PyBF16Type : public PyConcreteType<PyBF16Type> {
+class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -230,7 +247,7 @@ class PyBF16Type : public PyConcreteType<PyBF16Type> {
};
/// Floating Point Type subclass - F16Type.
-class PyF16Type : public PyConcreteType<PyF16Type> {
+class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -250,7 +267,7 @@ class PyF16Type : public PyConcreteType<PyF16Type> {
};
/// Floating Point Type subclass - TF32Type.
-class PyTF32Type : public PyConcreteType<PyTF32Type> {
+class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -270,7 +287,7 @@ class PyTF32Type : public PyConcreteType<PyTF32Type> {
};
/// Floating Point Type subclass - F32Type.
-class PyF32Type : public PyConcreteType<PyF32Type> {
+class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -290,7 +307,7 @@ class PyF32Type : public PyConcreteType<PyF32Type> {
};
/// Floating Point Type subclass - F64Type.
-class PyF64Type : public PyConcreteType<PyF64Type> {
+class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -819,6 +836,7 @@ class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
void mlir::python::populateIRTypes(py::module &m) {
PyIntegerType::bind(m);
+ PyFloatType::bind(m);
PyIndexType::bind(m);
PyFloat8E4M3FNType::bind(m);
PyFloat8E5M2Type::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 18c9414c5d0f34..e1a5d82587cf9e 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -78,6 +78,14 @@ MlirType mlirIndexTypeGet(MlirContext ctx) {
// Floating-point types.
//===----------------------------------------------------------------------===//
+bool mlirTypeIsAFloat(MlirType type) {
+ return llvm::isa<FloatType>(unwrap(type));
+}
+
+unsigned mlirFloatTypeGetWidth(MlirType type) {
+ return llvm::cast<FloatType>(unwrap(type)).getWidth();
+}
+
MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
return wrap(Float8E5M2Type::getTypeID());
}
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 344abb64a57d23..586bf7f8e93fba 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -1442,7 +1442,17 @@ class DictAttr(Attribute):
@property
def typeid(self) -> TypeID: ...
-class F16Type(Type):
+class FloatType(Type):
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def width(self) -> int:
+ """
+ Returns the width of the floating-point type.
+ """
+
+class F16Type(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> F16Type:
@@ -1455,7 +1465,7 @@ class F16Type(Type):
@property
def typeid(self) -> TypeID: ...
-class F32Type(Type):
+class F32Type(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> F32Type:
@@ -1468,7 +1478,7 @@ class F32Type(Type):
@property
def typeid(self) -> TypeID: ...
-class F64Type(Type):
+class F64Type(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> F64Type:
@@ -1502,7 +1512,7 @@ class FlatSymbolRefAttr(Attribute):
Returns the value of the FlatSymbolRef attribute as a string
"""
-class Float8E4M3B11FNUZType(Type):
+class Float8E4M3B11FNUZType(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> Float8E4M3B11FNUZType:
@@ -1515,7 +1525,7 @@ class Float8E4M3B11FNUZType(Type):
@property
def typeid(self) -> TypeID: ...
-class Float8E4M3FNType(Type):
+class Float8E4M3FNType(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> Float8E4M3FNType:
@@ -1528,7 +1538,7 @@ class Float8E4M3FNType(Type):
@property
def typeid(self) -> TypeID: ...
-class Float8E4M3FNUZType(Type):
+class Float8E4M3FNUZType(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> Float8E4M3FNUZType:
@@ -1541,7 +1551,7 @@ class Float8E4M3FNUZType(Type):
@property
def typeid(self) -> TypeID: ...
-class Float8E5M2FNUZType(Type):
+class Float8E5M2FNUZType(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> Float8E5M2FNUZType:
@@ -1554,7 +1564,7 @@ class Float8E5M2FNUZType(Type):
@property
def typeid(self) -> TypeID: ...
-class Float8E5M2Type(Type):
+class Float8E5M2Type(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> Float8E5M2Type:
@@ -1601,7 +1611,7 @@ class FloatAttr(Attribute):
Returns the value of the float attribute
"""
-class FloatTF32Type(Type):
+class FloatTF32Type(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> FloatTF32Type:
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 30a5054ada91ac..4eea1a9c372ef7 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -100,8 +100,38 @@ def testTypeIsInstance():
print(IntegerType.isinstance(t1))
# CHECK: False
print(F32Type.isinstance(t1))
+ # CHECK: False
+ print(FloatType.isinstance(t1))
# CHECK: True
print(F32Type.isinstance(t2))
+ # CHECK: True
+ print(FloatType.isinstance(t2))
+
+
+# CHECK-LABEL: TEST: testFloatTypeSubclasses
+@run
+def testFloatTypeSubclasses():
+ ctx = Context()
+ # CHECK: True
+ print(isinstance(Type.parse("f8E4M3FN", ctx), FloatType))
+ # CHECK: True
+ print(isinstance(Type.parse("f8E5M2", ctx), FloatType))
+ # CHECK: True
+ print(isinstance(Type.parse("f8E4M3FNUZ", ctx), FloatType))
+ # CHECK: True
+ print(isinstance(Type.parse("f8E4M3B11FNUZ", ctx), FloatType))
+ # CHECK: True
+ print(isinstance(Type.parse("f8E5M2FNUZ", ctx), FloatType))
+ # CHECK: True
+ print(isinstance(Type.parse("f16", ctx), FloatType))
+ # CHECK: True
+ print(isinstance(Type.parse("bf16", ctx), FloatType))
+ # CHECK: True
+ print(isinstance(Type.parse("f32", ctx), FloatType))
+ # CHECK: True
+ print(isinstance(Type.parse("tf32", ctx), FloatType))
+ # CHECK: True
+ print(isinstance(Type.parse("f64", ctx), FloatType))
# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
@@ -218,7 +248,10 @@ def testFloatType():
# CHECK: float: f32
print("float:", F32Type.get())
# CHECK: float: f64
- print("float:", F64Type.get())
+ f64 = F64Type.get()
+ print("float:", f64)
+ # CHECK: f64 width: 64
+ print("f64 width:", f64.width)
# CHECK-LABEL: TEST: testNoneType
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
This allows to
See motivation and discussion in https://discourse.llvm.org/t/add-floattype-to-mlir-python-bindings/76959.