Skip to content

[MLIR][Python] Added a base class to all builtin floating point types #81720

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 14, 2024

Conversation

superbobry
Copy link
Contributor

This allows to

  • check if a given ir.Type is a floating point type via isinstance() or issubclass()
  • get the bitwidth of a floating point type

See motivation and discussion in https://discourse.llvm.org/t/add-floattype-to-mlir-python-bindings/76959.

This allows to

* check if a given ir.Type is a floating point type via isinstance() or issubclass()
* get the bitwidth of a floating point type

See motivation and discussion in https://discourse.llvm.org/t/add-floattype-to-mlir-python-bindings/76959.
@llvmbot llvmbot added mlir:python MLIR Python bindings mlir labels Feb 14, 2024
@llvmbot
Copy link
Member

llvmbot commented Feb 14, 2024

@llvm/pr-subscribers-mlir

Author: Sergei Lebedev (superbobry)

Changes

This allows to

  • check if a given ir.Type is a floating point type via isinstance() or issubclass()
  • get the bitwidth of a floating point type

See motivation and discussion in https://discourse.llvm.org/t/add-floattype-to-mlir-python-bindings/76959.


Full diff: https://github.com/llvm/llvm-project/pull/81720.diff

5 Files Affected:

  • (modified) mlir/include/mlir-c/BuiltinTypes.h (+6)
  • (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+28-10)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+8)
  • (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+19-9)
  • (modified) mlir/test/python/ir/builtin_types.py (+34-1)
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 881b6dad2b84d7..99c5e3f46b04c1 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -73,6 +73,12 @@ MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx);
 // Floating-point types.
 //===----------------------------------------------------------------------===//
 
+/// Checks whether the given type is a floating-point type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type);
+
+/// Returns the bitwidth of a floating-point type.
+MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type);
+
 /// Returns the typeID of an Float8E5M2 type.
 MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void);
 
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 820992de659068..e1e4eb999b3aa8 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -109,8 +109,22 @@ class PyIndexType : public PyConcreteType<PyIndexType> {
   }
 };
 
+class PyFloatType : public PyConcreteType<PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
+  static constexpr const char *pyClassName = "FloatType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly(
+        "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
+        "Returns the width of the floating-point type");
+  }
+};
+
 /// Floating Point Type subclass - Float8E4M3FNType.
-class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
+class PyFloat8E4M3FNType
+    : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -130,7 +144,7 @@ class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
 };
 
 /// Floating Point Type subclass - Float8M5E2Type.
-class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
+class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -150,7 +164,8 @@ class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
 };
 
 /// Floating Point Type subclass - Float8E4M3FNUZ.
-class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
+class PyFloat8E4M3FNUZType
+    : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -170,7 +185,8 @@ class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
 };
 
 /// Floating Point Type subclass - Float8E4M3B11FNUZ.
-class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
+class PyFloat8E4M3B11FNUZType
+    : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -190,7 +206,8 @@ class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
 };
 
 /// Floating Point Type subclass - Float8E5M2FNUZ.
-class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
+class PyFloat8E5M2FNUZType
+    : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -210,7 +227,7 @@ class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
 };
 
 /// Floating Point Type subclass - BF16Type.
-class PyBF16Type : public PyConcreteType<PyBF16Type> {
+class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -230,7 +247,7 @@ class PyBF16Type : public PyConcreteType<PyBF16Type> {
 };
 
 /// Floating Point Type subclass - F16Type.
-class PyF16Type : public PyConcreteType<PyF16Type> {
+class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -250,7 +267,7 @@ class PyF16Type : public PyConcreteType<PyF16Type> {
 };
 
 /// Floating Point Type subclass - TF32Type.
-class PyTF32Type : public PyConcreteType<PyTF32Type> {
+class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -270,7 +287,7 @@ class PyTF32Type : public PyConcreteType<PyTF32Type> {
 };
 
 /// Floating Point Type subclass - F32Type.
-class PyF32Type : public PyConcreteType<PyF32Type> {
+class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -290,7 +307,7 @@ class PyF32Type : public PyConcreteType<PyF32Type> {
 };
 
 /// Floating Point Type subclass - F64Type.
-class PyF64Type : public PyConcreteType<PyF64Type> {
+class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -819,6 +836,7 @@ class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
 
 void mlir::python::populateIRTypes(py::module &m) {
   PyIntegerType::bind(m);
+  PyFloatType::bind(m);
   PyIndexType::bind(m);
   PyFloat8E4M3FNType::bind(m);
   PyFloat8E5M2Type::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 18c9414c5d0f34..e1a5d82587cf9e 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -78,6 +78,14 @@ MlirType mlirIndexTypeGet(MlirContext ctx) {
 // Floating-point types.
 //===----------------------------------------------------------------------===//
 
+bool mlirTypeIsAFloat(MlirType type) {
+  return llvm::isa<FloatType>(unwrap(type));
+}
+
+unsigned mlirFloatTypeGetWidth(MlirType type) {
+  return llvm::cast<FloatType>(unwrap(type)).getWidth();
+}
+
 MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
   return wrap(Float8E5M2Type::getTypeID());
 }
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 344abb64a57d23..586bf7f8e93fba 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -1442,7 +1442,17 @@ class DictAttr(Attribute):
     @property
     def typeid(self) -> TypeID: ...
 
-class F16Type(Type):
+class FloatType(Type):
+    @staticmethod
+    def isinstance(other: Type) -> bool: ...
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @property
+    def width(self) -> int:
+        """
+        Returns the width of the floating-point type.
+        """
+
+class F16Type(FloatType):
     static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
     @staticmethod
     def get(context: Optional[Context] = None) -> F16Type:
@@ -1455,7 +1465,7 @@ class F16Type(Type):
     @property
     def typeid(self) -> TypeID: ...
 
-class F32Type(Type):
+class F32Type(FloatType):
     static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
     @staticmethod
     def get(context: Optional[Context] = None) -> F32Type:
@@ -1468,7 +1478,7 @@ class F32Type(Type):
     @property
     def typeid(self) -> TypeID: ...
 
-class F64Type(Type):
+class F64Type(FloatType):
     static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
     @staticmethod
     def get(context: Optional[Context] = None) -> F64Type:
@@ -1502,7 +1512,7 @@ class FlatSymbolRefAttr(Attribute):
         Returns the value of the FlatSymbolRef attribute as a string
         """
 
-class Float8E4M3B11FNUZType(Type):
+class Float8E4M3B11FNUZType(FloatType):
     static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
     @staticmethod
     def get(context: Optional[Context] = None) -> Float8E4M3B11FNUZType:
@@ -1515,7 +1525,7 @@ class Float8E4M3B11FNUZType(Type):
     @property
     def typeid(self) -> TypeID: ...
 
-class Float8E4M3FNType(Type):
+class Float8E4M3FNType(FloatType):
     static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
     @staticmethod
     def get(context: Optional[Context] = None) -> Float8E4M3FNType:
@@ -1528,7 +1538,7 @@ class Float8E4M3FNType(Type):
     @property
     def typeid(self) -> TypeID: ...
 
-class Float8E4M3FNUZType(Type):
+class Float8E4M3FNUZType(FloatType):
     static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
     @staticmethod
     def get(context: Optional[Context] = None) -> Float8E4M3FNUZType:
@@ -1541,7 +1551,7 @@ class Float8E4M3FNUZType(Type):
     @property
     def typeid(self) -> TypeID: ...
 
-class Float8E5M2FNUZType(Type):
+class Float8E5M2FNUZType(FloatType):
     static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
     @staticmethod
     def get(context: Optional[Context] = None) -> Float8E5M2FNUZType:
@@ -1554,7 +1564,7 @@ class Float8E5M2FNUZType(Type):
     @property
     def typeid(self) -> TypeID: ...
 
-class Float8E5M2Type(Type):
+class Float8E5M2Type(FloatType):
     static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
     @staticmethod
     def get(context: Optional[Context] = None) -> Float8E5M2Type:
@@ -1601,7 +1611,7 @@ class FloatAttr(Attribute):
         Returns the value of the float attribute
         """
 
-class FloatTF32Type(Type):
+class FloatTF32Type(FloatType):
     static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
     @staticmethod
     def get(context: Optional[Context] = None) -> FloatTF32Type:
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 30a5054ada91ac..4eea1a9c372ef7 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -100,8 +100,38 @@ def testTypeIsInstance():
     print(IntegerType.isinstance(t1))
     # CHECK: False
     print(F32Type.isinstance(t1))
+    # CHECK: False
+    print(FloatType.isinstance(t1))
     # CHECK: True
     print(F32Type.isinstance(t2))
+    # CHECK: True
+    print(FloatType.isinstance(t2))
+
+
+# CHECK-LABEL: TEST: testFloatTypeSubclasses
+@run
+def testFloatTypeSubclasses():
+    ctx = Context()
+    # CHECK: True
+    print(isinstance(Type.parse("f8E4M3FN", ctx), FloatType))
+    # CHECK: True
+    print(isinstance(Type.parse("f8E5M2", ctx), FloatType))
+    # CHECK: True
+    print(isinstance(Type.parse("f8E4M3FNUZ", ctx), FloatType))
+    # CHECK: True
+    print(isinstance(Type.parse("f8E4M3B11FNUZ", ctx), FloatType))
+    # CHECK: True
+    print(isinstance(Type.parse("f8E5M2FNUZ", ctx), FloatType))
+    # CHECK: True
+    print(isinstance(Type.parse("f16", ctx), FloatType))
+    # CHECK: True
+    print(isinstance(Type.parse("bf16", ctx), FloatType))
+    # CHECK: True
+    print(isinstance(Type.parse("f32", ctx), FloatType))
+    # CHECK: True
+    print(isinstance(Type.parse("tf32", ctx), FloatType))
+    # CHECK: True
+    print(isinstance(Type.parse("f64", ctx), FloatType))
 
 
 # CHECK-LABEL: TEST: testTypeEqDoesNotRaise
@@ -218,7 +248,10 @@ def testFloatType():
         # CHECK: float: f32
         print("float:", F32Type.get())
         # CHECK: float: f64
-        print("float:", F64Type.get())
+        f64 = F64Type.get()
+        print("float:", f64)
+        # CHECK: f64 width: 64
+        print("f64 width:", f64.width)
 
 
 # CHECK-LABEL: TEST: testNoneType

@ftynse ftynse merged commit 82f3cbc into llvm:main Feb 14, 2024
Copy link
Contributor

@stellaraccident stellaraccident left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@superbobry superbobry deleted the piper_export_cl_606904969 branch February 14, 2024 20:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants