Skip to content

Commit 9600b55

Browse files
committed
[mlir][spirv] Support integer signedness
This commit updates SPIR-V dialect to support integer signedness by relaxing various checks for signless to just normal integers. The hack for spv.Bitcast can now be removed. Differential Revision: https://reviews.llvm.org/D75611
1 parent c72d60d commit 9600b55

File tree

11 files changed

+120
-111
lines changed

11 files changed

+120
-111
lines changed

mlir/docs/Dialects/SPIR-V.md

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,12 +252,18 @@ The SPIR-V dialect reuses standard integer, float, and vector types:
252252
Specification | Dialect
253253
:----------------------------------: | :-------------------------------:
254254
`OpTypeBool` | `i1`
255-
`OpTypeInt <bitwidth>` | `i<bitwidth>`
256255
`OpTypeFloat <bitwidth>` | `f<bitwidth>`
257256
`OpTypeVector <scalar-type> <count>` | `vector<<count> x <scalar-type>>`
258257

259-
Similarly, `mlir::NoneType` can be used for SPIR-V `OpTypeVoid`; builtin
260-
function types can be used for SPIR-V `OpTypeFunction` types.
258+
For integer types, the SPIR-V dialect supports all signedness semantics
259+
(signless, signed, unsigned) in order to ease transformations from higher level
260+
dialects. However, SPIR-V spec only defines two signedness semantics state: 0
261+
indicates unsigned, or no signedness semantics, 1 indicates signed semantics. So
262+
both `iN` and `uiN` are serialized into the same `OpTypeInt N 0`. For
263+
deserialization, we always treat `OpTypeInt N 0` as `iN`.
264+
265+
`mlir::NoneType` is used for SPIR-V `OpTypeVoid`; builtin function types are
266+
used for SPIR-V `OpTypeFunction` types.
261267

262268
The SPIR-V dialect and defines the following dialect-specific types:
263269

mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2945,6 +2945,17 @@ def SPV_SamplerUseAttr:
29452945
// SPIR-V type definitions
29462946
//===----------------------------------------------------------------------===//
29472947

2948+
class IOrUI<int width>
2949+
: Type<Or<[CPred<"$_self.isSignlessInteger(" # width # ")">,
2950+
CPred<"$_self.isUnsignedInteger(" # width # ")">]>,
2951+
width # "-bit signless/unsigned integer"> {
2952+
int bitwidth = width;
2953+
}
2954+
2955+
class SignlessOrUnsignedIntOfWidths<list<int> widths> :
2956+
AnyTypeOf<!foreach(w, widths, IOrUI<w>),
2957+
StrJoinInt<widths, "/">.result # "-bit signless/unsigned integer">;
2958+
29482959
def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
29492960
def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
29502961
def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
@@ -2953,8 +2964,8 @@ def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
29532964
// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
29542965
// for the definition of the following types and type categories.
29552966

2956-
def SPV_Void : TypeAlias<NoneType, "void type">;
2957-
def SPV_Bool : I<1>;
2967+
def SPV_Void : TypeAlias<NoneType, "void">;
2968+
def SPV_Bool : TypeAlias<I1, "bool">;
29582969
def SPV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
29592970
def SPV_Float : FloatOfWidths<[16, 32, 64]>;
29602971
def SPV_Float16or32 : FloatOfWidths<[16, 32]>;
@@ -2977,6 +2988,8 @@ def SPV_Type : AnyTypeOf<[
29772988
SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct
29782989
]>;
29792990

2991+
def SPV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>;
2992+
29802993
class SPV_ScalarOrVectorOf<Type type> :
29812994
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>]>;
29822995

@@ -2985,7 +2998,8 @@ def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
29852998

29862999
class SPV_Vec4<Type type> : VectorOfLengthAndType<[4], [type]>;
29873000
def SPV_IntVec4 : SPV_Vec4<SPV_Integer>;
2988-
def SPV_I32Vec4 : SPV_Vec4<I32>;
3001+
def SPV_IOrUIVec4 : SPV_Vec4<SPV_SignlessOrUnsignedInt>;
3002+
def SPV_Int32Vec4 : SPV_Vec4<AnyI32>;
29893003

29903004
// TODO(antiagainst): Use a more appropriate way to model optional operands
29913005
class SPV_Optional<Type type> : Variadic<type>;

mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
6161
);
6262

6363
let results = (outs
64-
SPV_I32Vec4:$result
64+
SPV_Int32Vec4:$result
6565
);
6666

6767
let verifier = [{ return success(); }];

mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> {
9595
);
9696

9797
let results = (outs
98-
SPV_IntVec4:$result
98+
SPV_IOrUIVec4:$result
9999
);
100100

101101
let assemblyFormat = [{

mlir/lib/Dialect/SPIRV/SPIRVOps.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -363,11 +363,12 @@ static unsigned getBitWidth(Type type) {
363363
// TODO: Make sure not caller relies on the actual pointer width value.
364364
return 64;
365365
}
366-
if (type.isSignlessIntOrFloat()) {
366+
367+
if (type.isIntOrFloat())
367368
return type.getIntOrFloatBitWidth();
368-
}
369+
369370
if (auto vectorType = type.dyn_cast<VectorType>()) {
370-
assert(vectorType.getElementType().isSignlessIntOrFloat());
371+
assert(vectorType.getElementType().isIntOrFloat());
371372
return vectorType.getNumElements() *
372373
vectorType.getElementType().getIntOrFloatBitWidth();
373374
}
@@ -500,7 +501,7 @@ static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
500501
static LogicalResult verifyAtomicUpdateOp(Operation *op) {
501502
auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
502503
auto elementType = ptrType.getPointeeType();
503-
if (!elementType.isSignlessInteger())
504+
if (!elementType.isa<IntegerType>())
504505
return op->emitOpError(
505506
"pointer operand must point to an integer value, found ")
506507
<< elementType;
@@ -1265,7 +1266,7 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
12651266
numElements *= t.getNumElements();
12661267
opElemType = t.getElementType();
12671268
}
1268-
if (!opElemType.isSignlessIntOrFloat()) {
1269+
if (!opElemType.isIntOrFloat()) {
12691270
return constOp.emitOpError("only support nested array result type");
12701271
}
12711272

@@ -1769,8 +1770,6 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
17691770
//===----------------------------------------------------------------------===//
17701771

17711772
static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
1772-
// TODO(antiagainst): check the result integer type's signedness bit is 0.
1773-
17741773
spirv::Scope scope = ballotOp.execution_scope();
17751774
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
17761775
return ballotOp.emitOpError(

mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp

Lines changed: 20 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -344,9 +344,6 @@ class Deserializer {
344344
/// insertion point.
345345
LogicalResult processUndef(ArrayRef<uint32_t> operands);
346346

347-
/// Processes an OpBitcast instruction.
348-
LogicalResult processBitcast(ArrayRef<uint32_t> words);
349-
350347
/// Method to dispatch to the specialized deserialization function for an
351348
/// operation in SPIR-V dialect that is a mirror of an instruction in the
352349
/// SPIR-V spec. This is auto-generated from ODS. Dispatch is handled for
@@ -1045,30 +1042,35 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
10451042

10461043
switch (opcode) {
10471044
case spirv::Opcode::OpTypeVoid:
1048-
if (operands.size() != 1) {
1045+
if (operands.size() != 1)
10491046
return emitError(unknownLoc, "OpTypeVoid must have no parameters");
1050-
}
10511047
typeMap[operands[0]] = opBuilder.getNoneType();
10521048
break;
10531049
case spirv::Opcode::OpTypeBool:
1054-
if (operands.size() != 1) {
1050+
if (operands.size() != 1)
10551051
return emitError(unknownLoc, "OpTypeBool must have no parameters");
1056-
}
10571052
typeMap[operands[0]] = opBuilder.getI1Type();
10581053
break;
1059-
case spirv::Opcode::OpTypeInt:
1060-
if (operands.size() != 3) {
1054+
case spirv::Opcode::OpTypeInt: {
1055+
if (operands.size() != 3)
10611056
return emitError(
10621057
unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
1063-
}
1064-
// TODO: Ignoring the signedness right now. Need to handle this effectively
1065-
// in the MLIR representation.
1066-
typeMap[operands[0]] = opBuilder.getIntegerType(operands[1]);
1067-
break;
1058+
1059+
// SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
1060+
// to preserve or validate.
1061+
// 0 indicates unsigned, or no signedness semantics
1062+
// 1 indicates signed semantics."
1063+
//
1064+
// So we cannot differentiate signless and unsigned integers; always use
1065+
// signless semantics for such cases.
1066+
auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
1067+
: IntegerType::SignednessSemantics::Signless;
1068+
typeMap[operands[0]] = IntegerType::get(operands[1], sign, context);
1069+
} break;
10681070
case spirv::Opcode::OpTypeFloat: {
1069-
if (operands.size() != 2) {
1071+
if (operands.size() != 2)
10701072
return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
1071-
}
1073+
10721074
Type floatTy;
10731075
switch (operands[1]) {
10741076
case 16:
@@ -1146,7 +1148,7 @@ LogicalResult Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
11461148
}
11471149

11481150
if (auto intVal = countInfo->first.dyn_cast<IntegerAttr>()) {
1149-
count = intVal.getInt();
1151+
count = intVal.getValue().getZExtValue();
11501152
} else {
11511153
return emitError(unknownLoc, "OpTypeArray count must come from a "
11521154
"scalar integer constant instruction");
@@ -1451,8 +1453,7 @@ LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
14511453
}
14521454

14531455
auto resultID = operands[1];
1454-
if (resultType.isSignlessInteger() || resultType.isa<FloatType>() ||
1455-
resultType.isa<VectorType>()) {
1456+
if (resultType.isIntOrFloat() || resultType.isa<VectorType>()) {
14561457
auto attr = opBuilder.getZeroAttr(resultType);
14571458
// For normal constants, we just record the attribute (and its type) for
14581459
// later materialization at use sites.
@@ -2051,8 +2052,6 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
20512052
// First dispatch all the instructions whose opcode does not correspond to
20522053
// those that have a direct mirror in the SPIR-V dialect
20532054
switch (opcode) {
2054-
case spirv::Opcode::OpBitcast:
2055-
return processBitcast(operands);
20562055
case spirv::Opcode::OpCapability:
20572056
return processCapability(operands);
20582057
case spirv::Opcode::OpExtension:
@@ -2152,76 +2151,6 @@ LogicalResult Deserializer::processUndef(ArrayRef<uint32_t> operands) {
21522151
return success();
21532152
}
21542153

2155-
// TODO(b/130356985): This method is copied from the auto-generated
2156-
// deserialization function for OpBitcast instruction. This is to avoid
2157-
// generating a Bitcast operations for cast from signed integer to unsigned
2158-
// integer and viceversa. MLIR doesn't have native support for this so they both
2159-
// end up mapping to the same type right now which is illegal according to
2160-
// OpBitcast semantics (and enforced by the SPIR-V dialect).
2161-
LogicalResult Deserializer::processBitcast(ArrayRef<uint32_t> words) {
2162-
SmallVector<Type, 1> resultTypes;
2163-
size_t wordIndex = 0;
2164-
(void)wordIndex;
2165-
uint32_t valueID = 0;
2166-
(void)valueID;
2167-
{
2168-
if (wordIndex >= words.size()) {
2169-
return emitError(
2170-
unknownLoc,
2171-
"expected result type <id> while deserializing spirv::BitcastOp");
2172-
}
2173-
auto ty = getType(words[wordIndex]);
2174-
if (!ty) {
2175-
return emitError(unknownLoc, "unknown type result <id> : ")
2176-
<< words[wordIndex];
2177-
}
2178-
resultTypes.push_back(ty);
2179-
wordIndex++;
2180-
if (wordIndex >= words.size()) {
2181-
return emitError(
2182-
unknownLoc,
2183-
"expected result <id> while deserializing spirv::BitcastOp");
2184-
}
2185-
}
2186-
valueID = words[wordIndex++];
2187-
SmallVector<Value, 4> operands;
2188-
SmallVector<NamedAttribute, 4> attributes;
2189-
if (wordIndex < words.size()) {
2190-
auto arg = getValue(words[wordIndex]);
2191-
if (!arg) {
2192-
return emitError(unknownLoc, "unknown result <id> : ")
2193-
<< words[wordIndex];
2194-
}
2195-
operands.push_back(arg);
2196-
wordIndex++;
2197-
}
2198-
if (wordIndex != words.size()) {
2199-
return emitError(unknownLoc,
2200-
"found more operands than expected when deserializing "
2201-
"spirv::BitcastOp, only ")
2202-
<< wordIndex << " of " << words.size() << " processed";
2203-
}
2204-
if (resultTypes[0] == operands[0].getType() &&
2205-
resultTypes[0].isSignlessInteger()) {
2206-
// TODO(b/130356985): This check is added to ignore error in Op verification
2207-
// due to both signed and unsigned integers mapping to the same
2208-
// type. Without this check this method is same as what is auto-generated.
2209-
valueMap[valueID] = operands[0];
2210-
return success();
2211-
}
2212-
2213-
auto op = opBuilder.create<spirv::BitcastOp>(unknownLoc, resultTypes,
2214-
operands, attributes);
2215-
(void)op;
2216-
valueMap[valueID] = op.getResult();
2217-
2218-
if (decorations.count(valueID)) {
2219-
auto attrs = decorations[valueID].getAttrs();
2220-
attributes.append(attrs.begin(), attrs.end());
2221-
}
2222-
return success();
2223-
}
2224-
22252154
LogicalResult Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
22262155
if (operands.size() < 4) {
22272156
return emitError(unknownLoc,

mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -932,8 +932,11 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
932932

933933
typeEnum = spirv::Opcode::OpTypeInt;
934934
operands.push_back(intType.getWidth());
935-
// TODO(antiagainst): support unsigned integers
936-
operands.push_back(1);
935+
// SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
936+
// to preserve or validate.
937+
// 0 indicates unsigned, or no signedness semantics
938+
// 1 indicates signed semantics."
939+
operands.push_back(intType.isSigned() ? 1 : 0);
937940
return success();
938941
}
939942

mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ spv.module "Logical" "GLSL450" {
44
spv.func @bit_cast(%arg0 : f32) "None" {
55
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} : f32 to i32
66
%0 = spv.Bitcast %arg0 : f32 to i32
7+
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} : i32 to si32
8+
%1 = spv.Bitcast %0 : i32 to si32
9+
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} : si32 to i32
10+
%2 = spv.Bitcast %1 : si32 to ui32
711
spv.Return
812
}
913
}

mlir/test/Dialect/SPIRV/Serialization/constant.mlir

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,37 @@ spv.module "Logical" "GLSL450" {
2727
spv.Return
2828
}
2929

30+
// CHECK-LABEL: @si32_const
31+
spv.func @si32_const() -> () "None" {
32+
// CHECK: spv.constant 0 : si32
33+
%0 = spv.constant 0 : si32
34+
// CHECK: spv.constant 10 : si32
35+
%1 = spv.constant 10 : si32
36+
// CHECK: spv.constant -5 : si32
37+
%2 = spv.constant -5 : si32
38+
39+
%3 = spv.IAdd %0, %1 : si32
40+
%4 = spv.IAdd %2, %3 : si32
41+
spv.Return
42+
}
43+
44+
// CHECK-LABEL: @ui32_const
45+
// We cannot differentiate signless vs. unsigned integers in SPIR-V blob
46+
// because they all use 1 as the signedness bit. So we always treat them
47+
// as signless integers.
48+
spv.func @ui32_const() -> () "None" {
49+
// CHECK: spv.constant 0 : i32
50+
%0 = spv.constant 0 : ui32
51+
// CHECK: spv.constant 10 : i32
52+
%1 = spv.constant 10 : ui32
53+
// CHECK: spv.constant -5 : i32
54+
%2 = spv.constant 4294967291 : ui32
55+
56+
%3 = spv.IAdd %0, %1 : ui32
57+
%4 = spv.IAdd %2, %3 : ui32
58+
spv.Return
59+
}
60+
3061
// CHECK-LABEL: @i64_const
3162
spv.func @i64_const() -> () "None" {
3263
// CHECK: spv.constant 4294967296 : i64
@@ -141,8 +172,23 @@ spv.module "Logical" "GLSL450" {
141172
spv.Return
142173
}
143174

144-
// CHECK-LABEL: @array_const
145-
spv.func @array_const() -> (!spv.array<2 x vector<2xf32>>) "None" {
175+
// CHECK-LABEL: @ui64_array_const
176+
spv.func @ui64_array_const() -> (!spv.array<3xui64>) "None" {
177+
// CHECK: spv.constant [5, 6, 7] : !spv.array<3 x i64>
178+
%0 = spv.constant [5 : ui64, 6 : ui64, 7 : ui64] : !spv.array<3 x ui64>
179+
180+
spv.ReturnValue %0: !spv.array<3xui64>
181+
}
182+
183+
// CHECK-LABEL: @si32_array_const
184+
spv.func @si32_array_const() -> (!spv.array<3xsi32>) "None" {
185+
// CHECK: spv.constant [5 : si32, 6 : si32, 7 : si32] : !spv.array<3 x si32>
186+
%0 = spv.constant [5 : si32, 6 : si32, 7 : si32] : !spv.array<3 x si32>
187+
188+
spv.ReturnValue %0 : !spv.array<3xsi32>
189+
}
190+
// CHECK-LABEL: @float_array_const
191+
spv.func @float_array_const() -> (!spv.array<2 x vector<2xf32>>) "None" {
146192
// CHECK: spv.constant [dense<3.000000e+00> : vector<2xf32>, dense<[4.000000e+00, 5.000000e+00]> : vector<2xf32>] : !spv.array<2 x vector<2xf32>>
147193
%0 = spv.constant [dense<3.0> : vector<2xf32>, dense<[4., 5.]> : vector<2xf32>] : !spv.array<2 x vector<2xf32>>
148194

mlir/test/Dialect/SPIRV/non-uniform-ops.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
2020

2121
// -----
2222

23+
func @group_non_uniform_ballot(%predicate: i1) -> vector<4xsi32> {
24+
// expected-error @+1 {{op result #0 must be vector of 8/16/32/64-bit signless/unsigned integer values of length 4, but got 'vector<4xsi32>'}}
25+
%0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xsi32>
26+
return %0: vector<4xsi32>
27+
}
28+
29+
// -----
30+
2331
//===----------------------------------------------------------------------===//
2432
// spv.GroupNonUniformElect
2533
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)