Skip to content

Commit 2d4f2e7

Browse files
vmaksimoMrSidims
authored andcommitted
[Backport to 15] Reverse translation of access + load/store operations for cooperative matrix (#2165)
Implement translation via SPIR-V friendly calls, as: the LLVM instructions are not capable to accept target extension types; cooperative matrix is an opaque object and accessing elements is implementation defined, hence we can't use GEP to which AccessChain naturally maps, since GEP has a different meaning. As for now some BE would need to recognize and define what to do with a call to __spirv_AccessChain(matrix, index). Better option is to map such SPIR-V to an intrinsic or define an appropriate type in LLVM (hence defining rules for GEP and other instructions) , but it's off the table now.
1 parent 350cd1d commit 2d4f2e7

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

lib/SPIRV/SPIRVReader.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2370,6 +2370,10 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
23702370
auto AC = static_cast<SPIRVAccessChainBase *>(BV);
23712371
auto Base = transValue(AC->getBase(), F, BB);
23722372
SPIRVType *BaseSPVTy = AC->getBase()->getType();
2373+
if (BaseSPVTy->isTypePointer() &&
2374+
BaseSPVTy->getPointerElementType()->isTypeCooperativeMatrixKHR()) {
2375+
return mapValue(BV, transSPIRVBuiltinFromInst(AC, BB));
2376+
}
23732377
Type *BaseTy =
23742378
BaseSPVTy->isTypeVector()
23752379
? transType(
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix -o %t.spv
3+
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
6+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
7+
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
8+
9+
; CHECK-SPIRV: TypeInt [[#TypeInt:]] 32 0
10+
; CHECK-SPIRV-DAG: Constant [[#TypeInt]] [[#Const0:]] 0
11+
; CHECK-SPIRV-DAG: Constant [[#TypeInt]] [[#Const1:]] 1 {{$}}
12+
; CHECK-SPIRV-DAG: Constant [[#TypeInt]] [[#Const3:]] 3
13+
; CHECK-SPIRV-DAG: Constant [[#TypeInt]] [[#Const12:]] 12
14+
; CHECK-SPIRV-DAG: Constant [[#TypeInt]] [[#Const42:]] 42
15+
16+
; CHECK-SPIRV: TypeCooperativeMatrixKHR [[#TypeMatrix:]] [[#TypeInt]] [[#Const3]] [[#Const12]] [[#Const12]] [[#Const0]]
17+
; CHECK-SPIRV: TypePointer [[#TypeMatrixPtr:]] 7 [[#TypeMatrix]]
18+
; CHECK-SPIRV: TypePointer [[#TypeIntPtr:]] 7 [[#TypeInt]]
19+
20+
; CHECK-SPIRV: Variable [[#TypeMatrixPtr]] [[#VarMatrixPtr:]] 7
21+
; CHECK-SPIRV: CompositeConstruct [[#TypeMatrix]] [[#Composite:]] [[#Const0]]
22+
; CHECK-SPIRV: Store [[#VarMatrixPtr]] [[#Composite]]
23+
; CHECK-SPIRV: AccessChain [[#TypeIntPtr]] [[#Res:]] [[#VarMatrixPtr]] [[#Const1]]
24+
; CHECK-SPIRV: Store [[#Res]] [[#Const42]]
25+
26+
; CHECK-LLVM: %0 = alloca %spirv.CooperativeMatrixKHR._int_3_12_12_0 addrspace(1)*
27+
; CHECK-LLVM: %Obj = call spir_func %spirv.CooperativeMatrixKHR._int_3_12_12_0 addrspace(1)* @_Z26__spirv_CompositeConstructi(i32 0)
28+
; CHECK-LLVM: store %spirv.CooperativeMatrixKHR._int_3_12_12_0 addrspace(1)* %Obj, %spirv.CooperativeMatrixKHR._int_3_12_12_0 addrspace(1)** %0
29+
; CHECK-LLVM: %call = call spir_func i32* @_Z19__spirv_AccessChainPP43__spirv_CooperativeMatrixKHR__int_3_12_12_0i(%spirv.CooperativeMatrixKHR._int_3_12_12_0 addrspace(1)** %0, i32 1)
30+
; CHECK-LLVM: store i32 42, i32* %call
31+
32+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
33+
target triple = "spir64-unknown-unknown"
34+
35+
%spirv.CooperativeMatrixKHR._int_3_12_12_0 = type { [12 x [12 x i32]]* }
36+
37+
; Function Attrs: mustprogress uwtable
38+
define dso_local void @_Z3fooi(i32 noundef %idx) local_unnamed_addr #0 {
39+
entry:
40+
%0 = alloca %spirv.CooperativeMatrixKHR._int_3_12_12_0*, align 8
41+
%Obj = tail call spir_func noundef %spirv.CooperativeMatrixKHR._int_3_12_12_0* @_Z26__spirv_CompositeConstruct(i32 noundef 0) #4
42+
store %spirv.CooperativeMatrixKHR._int_3_12_12_0* %Obj, %spirv.CooperativeMatrixKHR._int_3_12_12_0** %0, align 8
43+
%call = call noundef i32* @_Z19__spirv_AccessChainP6Matrixii(%spirv.CooperativeMatrixKHR._int_3_12_12_0** %0, i32 noundef 1)
44+
call void @_Z13__spirv_StorePii(i32* noundef %call, i32 noundef 42)
45+
ret void
46+
}
47+
48+
declare dso_local spir_func noundef %spirv.CooperativeMatrixKHR._int_3_12_12_0* @_Z26__spirv_CompositeConstruct(i32 noundef) local_unnamed_addr #2
49+
50+
declare noundef i32* @_Z19__spirv_AccessChainP6Matrixii(%spirv.CooperativeMatrixKHR._int_3_12_12_0** noundef, i32 noundef) local_unnamed_addr #2
51+
52+
declare void @_Z13__spirv_StorePii(i32* noundef, i32 noundef) local_unnamed_addr #2
53+
54+
attributes #0 = { mustprogress uwtable "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
55+
attributes #2 = { "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
56+
attributes #3 = { nounwind }
57+
58+
!llvm.module.flags = !{!0, !1,!3, !4}
59+
!llvm.ident = !{!5}
60+
61+
!0 = !{i32 7, !"Dwarf Version", i32 4}
62+
!1 = !{i32 1, !"wchar_size", i32 4}
63+
!3 = !{i32 7, !"PIE Level", i32 2}
64+
!4 = !{i32 7, !"uwtable", i32 2}
65+
!5 = !{!"clang version 16.0.0 (https://github.com/llvm/llvm-project.git 08d094a0e457360ad8b94b017d2dc277e697ca76)"}

0 commit comments

Comments
 (0)