diff --git a/clang/lib/CodeGen/TargetBuiltins/DirectX.cpp b/clang/lib/CodeGen/TargetBuiltins/DirectX.cpp index 202601e257036..51202331bb779 100644 --- a/clang/lib/CodeGen/TargetBuiltins/DirectX.cpp +++ b/clang/lib/CodeGen/TargetBuiltins/DirectX.cpp @@ -25,12 +25,17 @@ Value *CodeGenFunction::EmitDirectXBuiltinExpr(unsigned BuiltinID, case DirectX::BI__builtin_dx_dot2add: { Value *A = EmitScalarExpr(E->getArg(0)); Value *B = EmitScalarExpr(E->getArg(1)); - Value *C = EmitScalarExpr(E->getArg(2)); + Value *Acc = EmitScalarExpr(E->getArg(2)); + + Value *AX = Builder.CreateExtractElement(A, Builder.getSize(0)); + Value *AY = Builder.CreateExtractElement(A, Builder.getSize(1)); + Value *BX = Builder.CreateExtractElement(B, Builder.getSize(0)); + Value *BY = Builder.CreateExtractElement(B, Builder.getSize(1)); Intrinsic::ID ID = llvm ::Intrinsic::dx_dot2add; return Builder.CreateIntrinsic( - /*ReturnType=*/C->getType(), ID, ArrayRef{A, B, C}, nullptr, - "dx.dot2add"); + /*ReturnType=*/Acc->getType(), ID, + ArrayRef{Acc, AX, AY, BX, BY}, nullptr, "dx.dot2add"); } } return nullptr; diff --git a/clang/test/CodeGenDirectX/Builtins/dot2add.c b/clang/test/CodeGenDirectX/Builtins/dot2add.c index 181f61fea1544..47c639b5986ce 100644 --- a/clang/test/CodeGenDirectX/Builtins/dot2add.c +++ b/clang/test/CodeGenDirectX/Builtins/dot2add.c @@ -17,7 +17,13 @@ typedef half half2 __attribute__((ext_vector_type(2))); // CHECK-NEXT: [[TMP0:%.*]] = load <2 x half>, ptr [[X_ADDR]], align 4 // CHECK-NEXT: [[TMP1:%.*]] = load <2 x half>, ptr [[Y_ADDR]], align 4 // CHECK-NEXT: [[TMP2:%.*]] = load float, ptr [[Z_ADDR]], align 4 -// CHECK-NEXT: [[DX_DOT2ADD:%.*]] = call float @llvm.dx.dot2add.v2f16(<2 x half> [[TMP0]], <2 x half> [[TMP1]], float [[TMP2]]) +// CHECK-NEXT: [[TMP3:%.*]] = extractelement <2 x half> [[TMP0]], i32 0 +// CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x half> [[TMP0]], i32 1 +// CHECK-NEXT: [[TMP5:%.*]] = extractelement <2 x half> [[TMP1]], i32 0 +// CHECK-NEXT: [[TMP6:%.*]] 
= extractelement <2 x half> [[TMP1]], i32 1 +// CHECK-NEXT: [[DX_DOT2ADD:%.*]] = call float @llvm.dx.dot2add(float [[TMP2]], half [[TMP3]], half [[TMP4]], half [[TMP5]], half [[TMP6]]) // CHECK-NEXT: ret float [[DX_DOT2ADD]] // -float test_dot2add(half2 X, half2 Y, float Z) { return __builtin_dx_dot2add(X, Y, Z); } +float test_dot2add(half2 X, half2 Y, float Z) { + return __builtin_dx_dot2add(X, Y, Z); +} diff --git a/clang/test/CodeGenHLSL/builtins/dot2add.hlsl b/clang/test/CodeGenHLSL/builtins/dot2add.hlsl index 2464607dd636c..c345e17476e08 100644 --- a/clang/test/CodeGenHLSL/builtins/dot2add.hlsl +++ b/clang/test/CodeGenHLSL/builtins/dot2add.hlsl @@ -13,7 +13,11 @@ float test_default_parameter_type(half2 p1, half2 p2, float p3) { // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 // CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] - // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 + // CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 + // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]]) // CHECK: ret float %[[RES]] return dot2add(p1, p2, p3); } @@ -25,7 +29,11 @@ float test_float_arg2_type(half2 p1, float2 p2, float p3) { // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 // CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] - // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> 
%{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 + // CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 + // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]]) // CHECK: ret float %[[RES]] return dot2add(p1, p2, p3); } @@ -37,7 +45,11 @@ float test_float_arg1_type(float2 p1, half2 p2, float p3) { // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 // CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] - // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 + // CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 + // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]]) // CHECK: ret float %[[RES]] return dot2add(p1, p2, p3); } @@ -49,7 +61,11 @@ float test_double_arg3_type(half2 p1, half2 p2, double p3) { // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 // CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] - // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 + // CHECK-DXIL: 
%[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 + // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]]) // CHECK: ret float %[[RES]] return dot2add(p1, p2, p3); } @@ -62,7 +78,11 @@ float test_float_arg1_arg2_type(float2 p1, float2 p2, float p3) { // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 // CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] - // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 + // CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 + // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]]) // CHECK: ret float %[[RES]] return dot2add(p1, p2, p3); } @@ -75,7 +95,11 @@ float test_double_arg1_arg2_type(double2 p1, double2 p2, float p3) { // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 // CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] - // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 + // CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 
+ // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]]) // CHECK: ret float %[[RES]] return dot2add(p1, p2, p3); } @@ -88,7 +112,11 @@ float test_int16_arg1_arg2_type(int16_t2 p1, int16_t2 p2, float p3) { // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 // CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] - // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 + // CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 + // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]]) // CHECK: ret float %[[RES]] return dot2add(p1, p2, p3); } @@ -101,7 +129,11 @@ float test_int32_arg1_arg2_type(int32_t2 p1, int32_t2 p2, float p3) { // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 // CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] - // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 + // CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 + // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} 
float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]]) // CHECK: ret float %[[RES]] return dot2add(p1, p2, p3); } @@ -114,7 +146,11 @@ float test_int64_arg1_arg2_type(int64_t2 p1, int64_t2 p2, float p3) { // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 // CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] - // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 + // CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 + // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]]) // CHECK: ret float %[[RES]] return dot2add(p1, p2, p3); } @@ -129,7 +165,11 @@ float test_bool_arg1_arg2_type(bool2 p1, bool2 p2, float p3) { // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 // CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] - // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 + // CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0 + // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1 + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]]) // CHECK: ret float %[[RES]] 
return dot2add(p1, p2, p3); } diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index 775d325feeb14..b1a27311e2a9c 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -76,18 +76,27 @@ def int_dx_nclamp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, def int_dx_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>; def int_dx_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>; -def int_dx_dot2 : - DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], - [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], - [IntrNoMem, Commutative] >; -def int_dx_dot3 : - DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], - [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], - [IntrNoMem, Commutative] >; -def int_dx_dot4 : - DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], - [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], - [IntrNoMem, Commutative] >; +def int_dx_dot2 : DefaultAttrsIntrinsic<[LLVMMatchType<0>], + [ + llvm_anyfloat_ty, LLVMMatchType<0>, + LLVMMatchType<0>, LLVMMatchType<0> + ], + [IntrNoMem]>; +def int_dx_dot3 : DefaultAttrsIntrinsic<[LLVMMatchType<0>], + [ + llvm_anyfloat_ty, LLVMMatchType<0>, + LLVMMatchType<0>, LLVMMatchType<0>, + LLVMMatchType<0>, LLVMMatchType<0> + ], + [IntrNoMem]>; +def int_dx_dot4 : DefaultAttrsIntrinsic<[LLVMMatchType<0>], + [ + llvm_anyfloat_ty, LLVMMatchType<0>, + LLVMMatchType<0>, LLVMMatchType<0>, + LLVMMatchType<0>, LLVMMatchType<0>, + LLVMMatchType<0>, LLVMMatchType<0> + ], + [IntrNoMem]>; def int_dx_fdot : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], @@ -100,9 +109,9 @@ def int_dx_udot : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], [IntrNoMem, Commutative] >; -def int_dx_dot2add : - DefaultAttrsIntrinsic<[llvm_float_ty], - [llvm_anyfloat_ty, LLVMMatchType<0>, llvm_float_ty], +def int_dx_dot2add : + DefaultAttrsIntrinsic<[llvm_float_ty], + [llvm_float_ty, llvm_half_ty, llvm_half_ty, llvm_half_ty, llvm_half_ty], - [IntrNoMem, Commutative]>; + [IntrNoMem]>; def int_dx_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>; def int_dx_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>; diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index b1e7406ead675..645105ade72b6 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -1078,8 +1078,7 @@ def RawBufferStore : DXILOp<140, rawBufferStore> { } def Dot2AddHalf : DXILOp<162, dot2AddHalf> { - let Doc = "dot product of 2 vectors of half having size = 2, returns " - "float"; + let Doc = "2D half dot product with accumulate to float"; let intrinsics = [IntrinSelect]; let arguments = [FloatTy, HalfTy, HalfTy, HalfTy, HalfTy]; let result = FloatTy; diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp index e44d3b70eb657..53ffcc3ebbdbe 100644 --- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp +++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp @@ -169,7 +169,8 @@ static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) { assert(ATy->getScalarType()->isFloatingPointTy()); Intrinsic::ID DotIntrinsic = Intrinsic::dx_dot4; - switch (AVec->getNumElements()) { + int NumElts = AVec->getNumElements(); + switch (NumElts) { case 2: DotIntrinsic = Intrinsic::dx_dot2; break; @@ -185,8 +186,14 @@ static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) { /* gen_crash_diag=*/false); return nullptr; } - return
Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic, - ArrayRef{A, B}, nullptr, "dot"); + + SmallVector<Value *> Args; + for (int I = 0; I < NumElts; ++I) + Args.push_back(Builder.CreateExtractElement(A, Builder.getInt32(I))); + for (int I = 0; I < NumElts; ++I) + Args.push_back(Builder.CreateExtractElement(B, Builder.getInt32(I))); + return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic, Args, + nullptr, "dot"); } // Create the appropriate DXIL float dot intrinsic for the operands of Orig diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index 41a9426998826..4574e5f7bbd96 100644 --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -33,52 +33,6 @@ using namespace llvm; using namespace llvm::dxil; -static bool isVectorArgExpansion(Function &F) { - switch (F.getIntrinsicID()) { - case Intrinsic::dx_dot2: - case Intrinsic::dx_dot3: - case Intrinsic::dx_dot4: - return true; - } - return false; -} - -static SmallVector populateOperands(Value *Arg, IRBuilder<> &Builder) { - SmallVector ExtractedElements; - auto *VecArg = dyn_cast(Arg->getType()); - for (unsigned I = 0; I < VecArg->getNumElements(); ++I) { - Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I); - Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index); - ExtractedElements.push_back(ExtractedElement); - } - return ExtractedElements; -} - -static SmallVector -argVectorFlatten(CallInst *Orig, IRBuilder<> &Builder, unsigned NumOperands) { - assert(NumOperands > 0); - Value *Arg0 = Orig->getOperand(0); - [[maybe_unused]] auto *VecArg0 = dyn_cast(Arg0->getType()); - assert(VecArg0); - SmallVector NewOperands = populateOperands(Arg0, Builder); - for (unsigned I = 1; I < NumOperands; ++I) { - Value *Arg = Orig->getOperand(I); - [[maybe_unused]] auto *VecArg = dyn_cast(Arg->getType()); - assert(VecArg); - assert(VecArg0->getElementType() == VecArg->getElementType()); -
assert(VecArg0->getNumElements() == VecArg->getNumElements()); - auto NextOperandList = populateOperands(Arg, Builder); - NewOperands.append(NextOperandList.begin(), NextOperandList.end()); - } - return NewOperands; -} - -static SmallVector argVectorFlatten(CallInst *Orig, - IRBuilder<> &Builder) { - // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening. - return argVectorFlatten(Orig, Builder, Orig->getNumOperands() - 1); -} - namespace { class OpLowerer { Module &M; @@ -150,9 +104,6 @@ class OpLowerer { [[nodiscard]] bool replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp, ArrayRef ArgSelects) { - bool IsVectorArgExpansion = isVectorArgExpansion(F); - assert(!(IsVectorArgExpansion && ArgSelects.size()) && - "Cann't do vector arg expansion when using arg selects."); return replaceFunction(F, [&](CallInst *CI) -> Error { OpBuilder.getIRB().SetInsertPoint(CI); SmallVector Args; @@ -170,15 +121,6 @@ class OpLowerer { break; } } - } else if (IsVectorArgExpansion) { - Args = argVectorFlatten(CI, OpBuilder.getIRB()); - } else if (F.getIntrinsicID() == Intrinsic::dx_dot2add) { - // arg[NumOperands-1] is a pointer and is not needed by our flattening. - // arg[NumOperands-2] also does not need to be flattened because it is a - // scalar. 
- unsigned NumOperands = CI->getNumOperands() - 2; - Args.push_back(CI->getArgOperand(NumOperands)); - Args.append(argVectorFlatten(CI, OpBuilder.getIRB(), NumOperands)); } else { Args.append(CI->arg_begin(), CI->arg_end()); } diff --git a/llvm/test/CodeGen/DirectX/dot2_error.ll b/llvm/test/CodeGen/DirectX/dot2_error.ll index 97b025d36f018..f2167aa516057 100644 --- a/llvm/test/CodeGen/DirectX/dot2_error.ll +++ b/llvm/test/CodeGen/DirectX/dot2_error.ll @@ -4,8 +4,9 @@ ; CHECK: in function dot_double2 ; CHECK-SAME: Cannot create Dot2 operation: Invalid overload type -define noundef double @dot_double2(<2 x double> noundef %a, <2 x double> noundef %b) { +define noundef double @dot_double2(double noundef %a1, double noundef %a2, + double noundef %b1, double noundef %b2) { entry: - %dx.dot = call double @llvm.dx.dot2.v2f64(<2 x double> %a, <2 x double> %b) + %dx.dot = call double @llvm.dx.dot2(double %a1, double %a2, double %b1, double %b2) ret double %dx.dot } diff --git a/llvm/test/CodeGen/DirectX/dot2add.ll b/llvm/test/CodeGen/DirectX/dot2add.ll index 40c6cdafc83da..3a2bbcc074f2d 100644 --- a/llvm/test/CodeGen/DirectX/dot2add.ll +++ b/llvm/test/CodeGen/DirectX/dot2add.ll @@ -1,8 +1,13 @@ ; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s -define noundef float @dot2add_simple(<2 x half> noundef %a, <2 x half> noundef %b, float %c) { +define noundef float @dot2add_simple(<2 x half> noundef %a, <2 x half> noundef %b, float %acc) { entry: -; CHECK: call float @dx.op.dot2AddHalf(i32 162, float %c, half %0, half %1, half %2, half %3) - %ret = call float @llvm.dx.dot2add(<2 x half> %a, <2 x half> %b, float %c) + %ax = extractelement <2 x half> %a, i32 0 + %ay = extractelement <2 x half> %a, i32 1 + %bx = extractelement <2 x half> %b, i32 0 + %by = extractelement <2 x half> %b, i32 1 + +; CHECK: call float @dx.op.dot2AddHalf(i32 162, float %acc, half %ax, half %ay, half %bx, half %by) + %ret = call float @llvm.dx.dot2add(float %acc, half 
%ax, half %ay, half %bx, half %by) ret float %ret } diff --git a/llvm/test/CodeGen/DirectX/dot3_error.ll b/llvm/test/CodeGen/DirectX/dot3_error.ll index 3b5dc41ebeb6b..69cfb32047f23 100644 --- a/llvm/test/CodeGen/DirectX/dot3_error.ll +++ b/llvm/test/CodeGen/DirectX/dot3_error.ll @@ -4,8 +4,10 @@ ; CHECK: in function dot_double3 ; CHECK-SAME: Cannot create Dot3 operation: Invalid overload type -define noundef double @dot_double3(<3 x double> noundef %a, <3 x double> noundef %b) { +define noundef double @dot_double3(double noundef %a1, double noundef %a2, + double noundef %a3, double noundef %b1, + double noundef %b2, double noundef %b3) { entry: - %dx.dot = call double @llvm.dx.dot3.v3f64(<3 x double> %a, <3 x double> %b) + %dx.dot = call double @llvm.dx.dot3(double %a1, double %a2, double %a3, double %b1, double %b2, double %b3) ret double %dx.dot } diff --git a/llvm/test/CodeGen/DirectX/dot4_error.ll b/llvm/test/CodeGen/DirectX/dot4_error.ll index 0a5969616220e..f6c7ad93bd136 100644 --- a/llvm/test/CodeGen/DirectX/dot4_error.ll +++ b/llvm/test/CodeGen/DirectX/dot4_error.ll @@ -4,8 +4,11 @@ ; CHECK: in function dot_double4 ; CHECK-SAME: Cannot create Dot4 operation: Invalid overload type -define noundef double @dot_double4(<4 x double> noundef %a, <4 x double> noundef %b) { +define noundef double @dot_double4(double noundef %a1, double noundef %a2, + double noundef %a3, double noundef %a4, + double noundef %b1, double noundef %b2, + double noundef %b3, double noundef %b4) { entry: - %dx.dot = call double @llvm.dx.dot4.v4f64(<4 x double> %a, <4 x double> %b) + %dx.dot = call double @llvm.dx.dot4(double %a1, double %a2, double %a3, double %a4, double %b1, double %b2, double %b3, double %b4) ret double %dx.dot } diff --git a/llvm/test/CodeGen/DirectX/fdot.ll b/llvm/test/CodeGen/DirectX/fdot.ll index c6f36087ba91d..a623321c2d346 100644 --- a/llvm/test/CodeGen/DirectX/fdot.ll +++ b/llvm/test/CodeGen/DirectX/fdot.ll @@ -6,12 +6,12 @@ ; CHECK-LABEL: dot_half2 define 
noundef half @dot_half2(<2 x half> noundef %a, <2 x half> noundef %b) { entry: -; DOPCHECK: extractelement <2 x half> %a, i32 0 -; DOPCHECK: extractelement <2 x half> %a, i32 1 -; DOPCHECK: extractelement <2 x half> %b, i32 0 -; DOPCHECK: extractelement <2 x half> %b, i32 1 -; DOPCHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) #[[#ATTR:]] -; EXPCHECK: call half @llvm.dx.dot2.v2f16(<2 x half> %a, <2 x half> %b) +; CHECK: [[A1:%.*]] = extractelement <2 x half> %a, i32 0 +; CHECK: [[A2:%.*]] = extractelement <2 x half> %a, i32 1 +; CHECK: [[B1:%.*]] = extractelement <2 x half> %b, i32 0 +; CHECK: [[B2:%.*]] = extractelement <2 x half> %b, i32 1 +; DOPCHECK: call half @dx.op.dot2.f16(i32 54, half [[A1]], half [[A2]], half [[B1]], half [[B2]]) #[[#ATTR:]] +; EXPCHECK: call half @llvm.dx.dot2.f16(half [[A1]], half [[A2]], half [[B1]], half [[B2]]) %dx.dot = call half @llvm.dx.fdot.v2f16(<2 x half> %a, <2 x half> %b) ret half %dx.dot } @@ -19,14 +19,14 @@ entry: ; CHECK-LABEL: dot_half3 define noundef half @dot_half3(<3 x half> noundef %a, <3 x half> noundef %b) { entry: -; DOPCHECK: extractelement <3 x half> %a, i32 0 -; DOPCHECK: extractelement <3 x half> %a, i32 1 -; DOPCHECK: extractelement <3 x half> %a, i32 2 -; DOPCHECK: extractelement <3 x half> %b, i32 0 -; DOPCHECK: extractelement <3 x half> %b, i32 1 -; DOPCHECK: extractelement <3 x half> %b, i32 2 +; CHECK: extractelement <3 x half> %a, i32 0 +; CHECK: extractelement <3 x half> %a, i32 1 +; CHECK: extractelement <3 x half> %a, i32 2 +; CHECK: extractelement <3 x half> %b, i32 0 +; CHECK: extractelement <3 x half> %b, i32 1 +; CHECK: extractelement <3 x half> %b, i32 2 ; DOPCHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) #[[#ATTR]] -; EXPCHECK: call half @llvm.dx.dot3.v3f16(<3 x half> %a, <3 x half> %b) +; EXPCHECK: call half @llvm.dx.dot3.f16(half %{{.*}}, half %{{.*}}, half %{{.*}}, 
half %{{.*}}, half %{{.*}}, half %{{.*}}) %dx.dot = call half @llvm.dx.fdot.v3f16(<3 x half> %a, <3 x half> %b) ret half %dx.dot } @@ -34,16 +34,16 @@ entry: ; CHECK-LABEL: dot_half4 define noundef half @dot_half4(<4 x half> noundef %a, <4 x half> noundef %b) { entry: -; DOPCHECK: extractelement <4 x half> %a, i32 0 -; DOPCHECK: extractelement <4 x half> %a, i32 1 -; DOPCHECK: extractelement <4 x half> %a, i32 2 -; DOPCHECK: extractelement <4 x half> %a, i32 3 -; DOPCHECK: extractelement <4 x half> %b, i32 0 -; DOPCHECK: extractelement <4 x half> %b, i32 1 -; DOPCHECK: extractelement <4 x half> %b, i32 2 -; DOPCHECK: extractelement <4 x half> %b, i32 3 +; CHECK: extractelement <4 x half> %a, i32 0 +; CHECK: extractelement <4 x half> %a, i32 1 +; CHECK: extractelement <4 x half> %a, i32 2 +; CHECK: extractelement <4 x half> %a, i32 3 +; CHECK: extractelement <4 x half> %b, i32 0 +; CHECK: extractelement <4 x half> %b, i32 1 +; CHECK: extractelement <4 x half> %b, i32 2 +; CHECK: extractelement <4 x half> %b, i32 3 ; DOPCHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) #[[#ATTR]] -; EXPCHECK: call half @llvm.dx.dot4.v4f16(<4 x half> %a, <4 x half> %b) +; EXPCHECK: call half @llvm.dx.dot4.f16(half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) %dx.dot = call half @llvm.dx.fdot.v4f16(<4 x half> %a, <4 x half> %b) ret half %dx.dot } @@ -51,12 +51,12 @@ entry: ; CHECK-LABEL: dot_float2 define noundef float @dot_float2(<2 x float> noundef %a, <2 x float> noundef %b) { entry: -; DOPCHECK: extractelement <2 x float> %a, i32 0 -; DOPCHECK: extractelement <2 x float> %a, i32 1 -; DOPCHECK: extractelement <2 x float> %b, i32 0 -; DOPCHECK: extractelement <2 x float> %b, i32 1 +; CHECK: extractelement <2 x float> %a, i32 0 +; CHECK: extractelement <2 x float> %a, i32 1 +; CHECK: extractelement <2 x float> %b, i32 0 +; 
CHECK: extractelement <2 x float> %b, i32 1 ; DOPCHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) #[[#ATTR]] -; EXPCHECK: call float @llvm.dx.dot2.v2f32(<2 x float> %a, <2 x float> %b) +; EXPCHECK: call float @llvm.dx.dot2.f32(float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) %dx.dot = call float @llvm.dx.fdot.v2f32(<2 x float> %a, <2 x float> %b) ret float %dx.dot } @@ -64,14 +64,14 @@ entry: ; CHECK-LABEL: dot_float3 define noundef float @dot_float3(<3 x float> noundef %a, <3 x float> noundef %b) { entry: -; DOPCHECK: extractelement <3 x float> %a, i32 0 -; DOPCHECK: extractelement <3 x float> %a, i32 1 -; DOPCHECK: extractelement <3 x float> %a, i32 2 -; DOPCHECK: extractelement <3 x float> %b, i32 0 -; DOPCHECK: extractelement <3 x float> %b, i32 1 -; DOPCHECK: extractelement <3 x float> %b, i32 2 +; CHECK: extractelement <3 x float> %a, i32 0 +; CHECK: extractelement <3 x float> %a, i32 1 +; CHECK: extractelement <3 x float> %a, i32 2 +; CHECK: extractelement <3 x float> %b, i32 0 +; CHECK: extractelement <3 x float> %b, i32 1 +; CHECK: extractelement <3 x float> %b, i32 2 ; DOPCHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) #[[#ATTR]] -; EXPCHECK: call float @llvm.dx.dot3.v3f32(<3 x float> %a, <3 x float> %b) +; EXPCHECK: call float @llvm.dx.dot3.f32(float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) %dx.dot = call float @llvm.dx.fdot.v3f32(<3 x float> %a, <3 x float> %b) ret float %dx.dot } @@ -79,16 +79,16 @@ entry: ; CHECK-LABEL: dot_float4 define noundef float @dot_float4(<4 x float> noundef %a, <4 x float> noundef %b) { entry: -; DOPCHECK: extractelement <4 x float> %a, i32 0 -; DOPCHECK: extractelement <4 x float> %a, i32 1 -; DOPCHECK: extractelement <4 x float> %a, i32 2 -; DOPCHECK: extractelement <4 x float> %a, i32 3 -; DOPCHECK: extractelement <4 x float> %b, i32 0 
-; DOPCHECK: extractelement <4 x float> %b, i32 1 -; DOPCHECK: extractelement <4 x float> %b, i32 2 -; DOPCHECK: extractelement <4 x float> %b, i32 3 +; CHECK: extractelement <4 x float> %a, i32 0 +; CHECK: extractelement <4 x float> %a, i32 1 +; CHECK: extractelement <4 x float> %a, i32 2 +; CHECK: extractelement <4 x float> %a, i32 3 +; CHECK: extractelement <4 x float> %b, i32 0 +; CHECK: extractelement <4 x float> %b, i32 1 +; CHECK: extractelement <4 x float> %b, i32 2 +; CHECK: extractelement <4 x float> %b, i32 3 ; DOPCHECK: call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) #[[#ATTR]] -; EXPCHECK: call float @llvm.dx.dot4.v4f32(<4 x float> %a, <4 x float> %b) +; EXPCHECK: call float @llvm.dx.dot4.f32(float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) %dx.dot = call float @llvm.dx.fdot.v4f32(<4 x float> %a, <4 x float> %b) ret float %dx.dot } diff --git a/llvm/test/CodeGen/DirectX/normalize.ll b/llvm/test/CodeGen/DirectX/normalize.ll index 2aba9d5f74d78..cde09dacf4742 100644 --- a/llvm/test/CodeGen/DirectX/normalize.ll +++ b/llvm/test/CodeGen/DirectX/normalize.ll @@ -22,7 +22,7 @@ entry: define noundef <2 x half> @test_normalize_half2(<2 x half> noundef %p0) { entry: - ; EXPCHECK: [[doth2:%.*]] = call half @llvm.dx.dot2.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}) + ; EXPCHECK: [[doth2:%.*]] = call half @llvm.dx.dot2.f16(half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) ; DOPCHECK: [[doth2:%.*]] = call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) ; EXPCHECK: [[rsqrt:%.*]] = call half @llvm.dx.rsqrt.f16(half [[doth2]]) ; DOPCHECK: [[rsqrt:%.*]] = call half @dx.op.unary.f16(i32 25, half [[doth2]]) @@ -36,7 +36,7 @@ entry: define noundef <3 x half> @test_normalize_half3(<3 x half> noundef %p0) { entry: - ; EXPCHECK: [[doth3:%.*]] = call half 
@llvm.dx.dot3.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}}) + ; EXPCHECK: [[doth3:%.*]] = call half @llvm.dx.dot3.f16(half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) ; DOPCHECK: [[doth3:%.*]] = call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) ; EXPCHECK: [[rsqrt:%.*]] = call half @llvm.dx.rsqrt.f16(half [[doth3]]) ; DOPCHECK: [[rsqrt:%.*]] = call half @dx.op.unary.f16(i32 25, half [[doth3]]) @@ -50,7 +50,7 @@ entry: define noundef <4 x half> @test_normalize_half4(<4 x half> noundef %p0) { entry: - ; EXPCHECK: [[doth4:%.*]] = call half @llvm.dx.dot4.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}) + ; EXPCHECK: [[doth4:%.*]] = call half @llvm.dx.dot4.f16(half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) ; DOPCHECK: [[doth4:%.*]] = call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) ; EXPCHECK: [[rsqrt:%.*]] = call half @llvm.dx.rsqrt.f16(half [[doth4]]) ; DOPCHECK: [[rsqrt:%.*]] = call half @dx.op.unary.f16(i32 25, half [[doth4]]) @@ -71,7 +71,7 @@ entry: define noundef <2 x float> @test_normalize_float2(<2 x float> noundef %p0) { entry: - ; EXPCHECK: [[dotf2:%.*]] = call float @llvm.dx.dot2.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}}) + ; EXPCHECK: [[dotf2:%.*]] = call float @llvm.dx.dot2.f32(float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) ; DOPCHECK: [[dotf2:%.*]] = call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) ; EXPCHECK: [[rsqrt:%.*]] = call float @llvm.dx.rsqrt.f32(float [[dotf2]]) ; DOPCHECK: [[rsqrt:%.*]] = call float @dx.op.unary.f32(i32 25, float [[dotf2]]) @@ -85,7 +85,7 @@ entry: define noundef <3 x float> @test_normalize_float3(<3 x float> noundef %p0) { entry: - ; EXPCHECK: [[dotf3:%.*]] = call float @llvm.dx.dot3.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}}) + ; EXPCHECK: [[dotf3:%.*]] = call float @llvm.dx.dot3.f32(float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) ; DOPCHECK: 
[[dotf3:%.*]] = call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) ; EXPCHECK: [[rsqrt:%.*]] = call float @llvm.dx.rsqrt.f32(float [[dotf3]]) ; DOPCHECK: [[rsqrt:%.*]] = call float @dx.op.unary.f32(i32 25, float [[dotf3]]) @@ -99,7 +99,7 @@ entry: define noundef <4 x float> @test_normalize_float4(<4 x float> noundef %p0) { entry: - ; EXPCHECK: [[dotf4:%.*]] = call float @llvm.dx.dot4.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}}) + ; EXPCHECK: [[dotf4:%.*]] = call float @llvm.dx.dot4.f32(float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) ; DOPCHECK: [[dotf4:%.*]] = call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) ; EXPCHECK: [[rsqrt:%.*]] = call float @llvm.dx.rsqrt.f32(float [[dotf4]]) ; DOPCHECK: [[rsqrt:%.*]] = call float @dx.op.unary.f32(i32 25, float [[dotf4]])