Skip to content

Commit d039fdd

Browse files
committed
Add test
TODO: update spec
Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent 1e33c8b commit d039fdd

File tree

2 files changed

+164
-1
lines changed

2 files changed

+164
-1
lines changed

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3437,7 +3437,7 @@ class SPIRVCooperativeMatrixInvocationInstructionsINTELInstBase
34373437
SPIRVCooperativeMatrixInvocationInstructionsINTELInstBase, \
34383438
internal::Op##x##INTEL, __VA_ARGS__> \
34393439
SPIRV##x##INTEL;
3440-
_SPIRV_OP(CooperativeMatrixApplyFunction, true, 4, true)
3440+
_SPIRV_OP(CooperativeMatrixApplyFunction, true, 5)
34413441
#undef _SPIRV_OP
34423442

34433443
class SPIRVCooperativeMatrixKHRInstBase : public SPIRVInstTemplateBase {
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
;; Compiled from sycl/test-e2e/Matrix/joint_matrix_apply_bf16.cpp in intel/llvm, with some modifications.
2+
3+
; RUN: llvm-as < %s -o %t.bc
4+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_joint_matrix -o %t.spv
5+
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
6+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
7+
8+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
9+
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
10+
11+
; CHECK-SPIRV-DAG: Capability CooperativeMatrixKHR
12+
; CHECK-SPIRV-DAG: Capability CooperativeMatrixInvocationInstructionsINTEL
13+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_joint_matrix"
14+
; CHECK-SPIRV-DAG: Extension "SPV_KHR_cooperative_matrix"
15+
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy:]]
16+
; CHECK-SPIRV: CompositeConstruct [[#MatTy]] [[#Mat:]]
17+
; CHECK-SPIRV: PtrCastToGeneric [[#]] [[#Ptr:]] [[#]]
18+
; CHECK-SPIRV: CooperativeMatrixApplyFunctionINTEL [[#MatTy]] [[#Apply:]] [[#Ptr]] [[#Mat]]
19+
; CHECK-SPIRV: CooperativeMatrixStoreKHR [[#]] [[#Apply]]
20+
21+
; CHECK-LLVM: %[[Mat:[%0-9a-z.]+]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @"_Z26__spirv_CompositeConstructP38class.sycl::_V1::ext::oneapi::bfloat16"
22+
; CHECK-LLVM: %[[Apply:[%0-9a-z.]+]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @"_Z43__spirv_CooperativeMatrixApplyFunctionINTELPU3AS477class.sycl::_V1::ext::oneapi::experimental::matrix::helper::reference_wrapperPU3AS144__spirv_CooperativeMatrixKHR__short_8_16_0_3"(ptr addrspace(4) %ref.tmp.ascast.i21, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) %[[Mat]])
23+
; CHECK-LLVM: call spir_func void @"_Z33__spirv_CooperativeMatrixStoreKHRPU3AS138class.sycl::_V1::ext::oneapi::bfloat16PU3AS144__spirv_CooperativeMatrixKHR__short_8_16_0_3liii"(ptr addrspace(1) %{{.*}}, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) %[[Apply]], i64 32, i32 0, i32 3, i32 0)
24+
25+
26+
; ModuleID = 'matrix_apply.bc'
27+
source_filename = "../llvm/sycl/test-e2e/Matrix/joint_matrix_apply_bf16.cpp"
28+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
29+
target triple = "spir64-unknown-unknown"
30+
31+
%"class.sycl::_V1::range" = type { %"class.sycl::_V1::detail::array" }
32+
%"class.sycl::_V1::detail::array" = type { [2 x i64] }
33+
%"class.sycl::_V1::id" = type { %"class.sycl::_V1::detail::array" }
34+
%"class.sycl::_V1::ext::oneapi::experimental::matrix::helper::reference_wrapper" = type { ptr addrspace(4) }
35+
%"class.sycl::_V1::ext::oneapi::bfloat16" = type { i16 }
36+
%class.anon.0 = type <{ %"class.sycl::_V1::accessor", %class.anon, [7 x i8] }>
37+
%"class.sycl::_V1::accessor" = type { %"class.sycl::_V1::detail::AccessorImplDevice", %union.anon }
38+
%"class.sycl::_V1::detail::AccessorImplDevice" = type { %"class.sycl::_V1::id", %"class.sycl::_V1::range", %"class.sycl::_V1::range" }
39+
%union.anon = type { ptr addrspace(1) }
40+
%class.anon = type { i8 }
41+
42+
$_ZTSZZ17matrix_verify_addIN4sycl3_V13ext6oneapi8bfloat16ELm16ELm32EZ4mainEUlRS4_E_EvNS1_5queueER10big_matrixIT_XT0_EXT1_EERNS1_8nd_rangeILi2EEEfOT2_ENKUlRNS1_7handlerEE_clESI_EUlNS1_7nd_itemILi2EEEE_ = comdat any
43+
44+
@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
45+
@__spirv_BuiltInLocalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
46+
47+
; Function Attrs: convergent norecurse nounwind
48+
define weak_odr dso_local spir_kernel void @_ZTSZZ17matrix_verify_addIN4sycl3_V13ext6oneapi8bfloat16ELm16ELm32EZ4mainEUlRS4_E_EvNS1_5queueER10big_matrixIT_XT0_EXT1_EERNS1_8nd_rangeILi2EEEfOT2_ENKUlRNS1_7handlerEE_clESI_EUlNS1_7nd_itemILi2EEEE_(ptr addrspace(1) noundef align 2 %_arg_accA, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accA1, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accA2, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accA3) local_unnamed_addr {
; Kernel under test: builds a cooperative matrix with __spirv_CompositeConstruct,
; passes it through __spirv_CooperativeMatrixApplyFunctionINTEL, and writes the
; result out with __spirv_CooperativeMatrixStoreKHR. These three calls are what
; the CHECK-SPIRV / CHECK-LLVM lines of this test match; the rest is accessor
; and index bookkeeping.
49+
entry:
50+
call spir_func void @__itt_offload_wi_start_wrapper()
; Stack temporaries: the reference_wrapper handed to the apply call, a bfloat16
; slot and a float slot used to build the fill value, and the kernel object.
51+
%ref.tmp.i20 = alloca %"class.sycl::_V1::ext::oneapi::experimental::matrix::helper::reference_wrapper", align 8
52+
%agg.tmp.i17 = alloca %"class.sycl::_V1::ext::oneapi::bfloat16", align 2
53+
%ref.tmp6.i = alloca float, align 4
54+
%__SYCLKernel = alloca %class.anon.0, align 8
55+
%__SYCLKernel.ascast = addrspacecast ptr %__SYCLKernel to ptr addrspace(4)
56+
call void @llvm.lifetime.start.p0(i64 64, ptr nonnull %__SYCLKernel)
; Read both i64 elements (byte offsets 0 and 8) from each byval argument:
; %_arg_accA1, %_arg_accA2 and %_arg_accA3.
57+
%agg.tmp.sroa.0.sroa.0.0.copyload = load i64, ptr %_arg_accA1, align 8
58+
%agg.tmp.sroa.0.sroa.2.0._arg_accA1.ascast.sroa_idx = getelementptr inbounds i8, ptr %_arg_accA1, i64 8
59+
%agg.tmp.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp.sroa.0.sroa.2.0._arg_accA1.ascast.sroa_idx, align 8
60+
%agg.tmp5.sroa.0.sroa.0.0.copyload = load i64, ptr %_arg_accA2, align 8
61+
%agg.tmp5.sroa.0.sroa.2.0._arg_accA2.ascast.sroa_idx = getelementptr inbounds i8, ptr %_arg_accA2, i64 8
62+
%agg.tmp5.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp5.sroa.0.sroa.2.0._arg_accA2.ascast.sroa_idx, align 8
63+
%agg.tmp6.sroa.0.sroa.0.0.copyload = load i64, ptr %_arg_accA3, align 8
64+
%agg.tmp6.sroa.0.sroa.2.0._arg_accA3.ascast.sroa_idx = getelementptr inbounds i8, ptr %_arg_accA3, i64 8
65+
%agg.tmp6.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp6.sroa.0.sroa.2.0._arg_accA3.ascast.sroa_idx, align 8
; Populate the accessor at the start of the kernel object. AccessorImplDevice
; field 0 receives the %_arg_accA3 pair, field 1 (AccessRange) the %_arg_accA1
; pair, and field 2 (MemRange) the %_arg_accA2 pair. %0 addresses the accessor's
; second member, the union holding the addrspace(1) data pointer.
66+
%0 = getelementptr inbounds %"class.sycl::_V1::accessor", ptr %__SYCLKernel, i64 0, i32 1
67+
store i64 %agg.tmp6.sroa.0.sroa.0.0.copyload, ptr %__SYCLKernel, align 8
68+
%AccessRange.i.i.i.i.i = getelementptr inbounds %"class.sycl::_V1::detail::AccessorImplDevice", ptr %__SYCLKernel, i64 0, i32 1
69+
store i64 %agg.tmp.sroa.0.sroa.0.0.copyload, ptr %AccessRange.i.i.i.i.i, align 8
70+
%MemRange.i.i.i.i.i = getelementptr inbounds %"class.sycl::_V1::detail::AccessorImplDevice", ptr %__SYCLKernel, i64 0, i32 2
71+
store i64 %agg.tmp5.sroa.0.sroa.0.0.copyload, ptr %MemRange.i.i.i.i.i, align 8
72+
%arrayidx.i21.i.i.i.i = getelementptr inbounds [2 x i64], ptr %__SYCLKernel, i64 0, i64 1
73+
store i64 %agg.tmp6.sroa.0.sroa.2.0.copyload, ptr %arrayidx.i21.i.i.i.i, align 8
74+
%arrayidx.i25.i.i.i.i = getelementptr inbounds %"class.sycl::_V1::detail::AccessorImplDevice", ptr %__SYCLKernel, i64 0, i32 1, i32 0, i32 0, i64 1
75+
store i64 %agg.tmp.sroa.0.sroa.2.0.copyload, ptr %arrayidx.i25.i.i.i.i, align 8
76+
%arrayidx.i29.i.i.i.i = getelementptr inbounds %"class.sycl::_V1::detail::AccessorImplDevice", ptr %__SYCLKernel, i64 0, i32 2, i32 0, i32 0, i64 1
77+
store i64 %agg.tmp5.sroa.0.sroa.2.0.copyload, ptr %arrayidx.i29.i.i.i.i, align 8
; Stored data pointer = %_arg_accA advanced (in bfloat16 elements) by
; accA3[0] * accA2[1] + accA3[1].
78+
%mul.i6.i.i.i.i = mul i64 %agg.tmp6.sroa.0.sroa.0.0.copyload, %agg.tmp5.sroa.0.sroa.2.0.copyload
79+
%1 = getelementptr %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %_arg_accA, i64 %mul.i6.i.i.i.i
80+
%add.ptr.i = getelementptr %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %1, i64 %agg.tmp6.sroa.0.sroa.2.0.copyload
81+
store ptr addrspace(1) %add.ptr.i, ptr %0, align 8
; Load elements 1 and 0 of the global and local invocation-id builtins
; (%2/%3 global, %4/%5 local). Each value is passed to llvm.assume as being
; below 2147483648, and the per-dimension differences global - local are
; computed as %sub.i and %sub5.i.
82+
%2 = load i64, ptr addrspace(1) getelementptr inbounds (i8, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, i64 8), align 8
83+
%3 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
84+
%4 = load i64, ptr addrspace(1) getelementptr inbounds (i8, ptr addrspace(1) @__spirv_BuiltInLocalInvocationId, i64 8), align 8
85+
%5 = load i64, ptr addrspace(1) @__spirv_BuiltInLocalInvocationId, align 32
86+
%ref.tmp6.ascast.i = addrspacecast ptr %ref.tmp6.i to ptr addrspace(4)
87+
%cmp.i11 = icmp ult i64 %2, 2147483648
88+
tail call void @llvm.assume(i1 %cmp.i11)
89+
%cmp.i = icmp ult i64 %3, 2147483648
90+
tail call void @llvm.assume(i1 %cmp.i)
91+
%cmp.i15 = icmp ult i64 %4, 2147483648
92+
tail call void @llvm.assume(i1 %cmp.i15)
93+
%sub.i = sub nsw i64 %2, %4
94+
%cmp.i12 = icmp ult i64 %5, 2147483648
95+
tail call void @llvm.assume(i1 %cmp.i12)
96+
%sub5.i = sub nsw i64 %3, %5
; Build the input matrix: store float 5.0, convert it with
; __devicelib_ConvertFToBF16INTEL, place the i16 result in a bfloat16
; temporary, and pass that byval to __spirv_CompositeConstruct.
97+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %ref.tmp6.i)
98+
store float 5.000000e+00, ptr %ref.tmp6.i, align 4
99+
%call.i.i = call spir_func noundef zeroext i16 @__devicelib_ConvertFToBF16INTEL(ptr addrspace(4) noundef align 4 dereferenceable(4) %ref.tmp6.ascast.i)
100+
call void @llvm.lifetime.start.p0(i64 2, ptr nonnull %agg.tmp.i17)
101+
store i16 %call.i.i, ptr %agg.tmp.i17, align 2
102+
%call.i18 = call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @_Z26__spirv_CompositeConstruct(ptr noundef nonnull byval(%"class.sycl::_V1::ext::oneapi::bfloat16") align 2 %agg.tmp.i17)
103+
call void @llvm.lifetime.end.p0(i64 2, ptr nonnull %agg.tmp.i17)
104+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %ref.tmp6.i)
; Apply step: take the address of field 1 of the kernel object (the
; %class.anon member), store it into the reference_wrapper temporary, and
; call __spirv_CooperativeMatrixApplyFunctionINTEL with the addrspace(4)
; pointer to that wrapper followed by the matrix built above. The CHECK-LLVM
; line for this call matches the %ref.tmp.ascast.i21 name literally.
105+
%lambda.i = getelementptr inbounds %class.anon.0, ptr addrspace(4) %__SYCLKernel.ascast, i64 0, i32 1
106+
%ref.tmp.ascast.i21 = addrspacecast ptr %ref.tmp.i20 to ptr addrspace(4)
107+
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %ref.tmp.i20)
108+
store ptr addrspace(4) %lambda.i, ptr %ref.tmp.i20, align 8
109+
%call.i22 = call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @_Z43__spirv_CooperativeMatrixApplyFunctionINTEL(ptr addrspace(4) noundef align 8 dereferenceable(8) %ref.tmp.ascast.i21, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) noundef %call.i18)
110+
call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %ref.tmp.i20)
; Destination address: reload the stored data pointer and subtract the same
; field0[0] * field2[1] + field0[1] element offset that was added when it was
; stored, then advance by (%sub.i << 8) and by (%sub5.i & -16) bfloat16
; elements.
111+
%6 = load ptr addrspace(1), ptr %0, align 8
112+
%7 = load i64, ptr %__SYCLKernel, align 8
113+
%8 = load i64, ptr %arrayidx.i29.i.i.i.i, align 8
114+
%mul.i6.i.i.i.i.i = mul i64 %7, %8
115+
%9 = load i64, ptr %arrayidx.i21.i.i.i.i, align 8
116+
%add.i7.i.i.i.i.i = add i64 %mul.i6.i.i.i.i.i, %9
117+
%idx.neg.i.i = sub i64 0, %add.i7.i.i.i.i.i
118+
%add.ptr.i.i = getelementptr inbounds %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %6, i64 %idx.neg.i.i
119+
%mul12.i = shl nsw i64 %sub.i, 8
120+
%add.ptr.i43 = getelementptr inbounds %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %add.ptr.i.i, i64 %mul12.i
121+
%div14.i = and i64 %sub5.i, -16
122+
%add.ptr.i44 = getelementptr inbounds %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %add.ptr.i43, i64 %div14.i
; Store the apply result. The trailing constant operands (i64 32, i32 0,
; i32 3, i32 0) are the ones the CHECK-LLVM store line expects verbatim.
; NOTE(review): the mangled suffix of this callee (..._uint_3_12_12_3ili) does
; not describe the i16 8x16 matrix actually passed; the CHECK-LLVM line expects
; a differently mangled name after the round trip - confirm this is intended.
123+
call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHRPU3AS4iPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3ili(ptr addrspace(1) noundef %add.ptr.i44, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) noundef %call.i22, i64 noundef 32, i32 noundef 0, i32 noundef 3, i32 noundef 0)
124+
call void @llvm.lifetime.end.p0(i64 64, ptr nonnull %__SYCLKernel)
125+
call spir_func void @__itt_offload_wi_finish_wrapper()
126+
ret void
127+
}
128+
129+
; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: readwrite)
130+
declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture)
131+
132+
; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: readwrite)
133+
declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture)
134+
135+
; Function Attrs: nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: write)
136+
declare void @llvm.assume(i1 noundef)
137+
138+
; Function Attrs: convergent nounwind
139+
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @_Z26__spirv_CompositeConstruct(ptr noundef byval(%"class.sycl::_V1::ext::oneapi::bfloat16") align 2) local_unnamed_addr
140+
141+
; Function Attrs: convergent nounwind
142+
declare dso_local spir_func zeroext i16 @__devicelib_ConvertFToBF16INTEL(ptr addrspace(4) noundef align 4 dereferenceable(4)) local_unnamed_addr
143+
144+
; Function Attrs: convergent nounwind
145+
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) @_Z43__spirv_CooperativeMatrixApplyFunctionINTEL(ptr addrspace(4) noundef align 8 dereferenceable(8), target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) noundef) local_unnamed_addr
146+
147+
; Function Attrs: convergent nounwind
148+
declare dso_local spir_func void @_Z33__spirv_CooperativeMatrixStoreKHRPU3AS4iPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3ili(ptr addrspace(1) noundef, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 3) noundef, i64 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr
149+
150+
declare spir_func void @__itt_offload_wi_start_wrapper()
151+
152+
declare spir_func void @__itt_offload_wi_finish_wrapper()
153+
154+
!llvm.module.flags = !{!0, !1}
155+
!opencl.spir.version = !{!2}
156+
!spirv.Source = !{!3}
157+
!llvm.ident = !{!4}
158+
159+
!0 = !{i32 1, !"wchar_size", i32 4}
160+
!1 = !{i32 7, !"frame-pointer", i32 2}
161+
!2 = !{i32 1, i32 2}
162+
!3 = !{i32 4, i32 100000}
163+
!4 = !{!"clang version 18.0.0 (https://github.com/intel/llvm.git)"}

0 commit comments

Comments
 (0)