Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion lib/SPIRV/Mangler/ManglingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ static const char *PrimitiveNames[PRIMITIVE_NUM] = {
"intel_sub_group_avc_ime_result_dual_reference_streamin_t"
};

// clang-format off
const char *MangledTypes[PRIMITIVE_NUM] = {
"b", // BOOL
"h", // UCHAR
Expand All @@ -106,7 +107,7 @@ const char *MangledTypes[PRIMITIVE_NUM] = {
"Dh", // HALF
"f", // FLOAT
"d", // DOUBLE
"u6__bf16", // __BF16
"DF16b", // __BF16
"v", // VOID
"z", // VarArg
"14ocl_image1d_ro", // PRIMITIVE_IMAGE1D_RO_T
Expand Down Expand Up @@ -175,6 +176,7 @@ const char *MangledTypes[PRIMITIVE_NUM] = {
"55ocl_intel_sub_group_avc_ime_single_reference_streamin_t", // PRIMITIVE_SUB_GROUP_AVC_IME_SINGLE_REF_STREAMIN_T
"53ocl_intel_sub_group_avc_ime_dual_reference_streamin_t" // PRIMITIVE_SUB_GROUP_AVC_IME_DUAL_REF_STREAMIN_T
};
// clang-format on

const char *ReadableAttribute[ATTR_NUM] = {
"restrict", "volatile", "const", "__private",
Expand Down
4 changes: 2 additions & 2 deletions lib/SPIRV/SPIRVUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -551,8 +551,8 @@ ParamType lastFuncParamType(StringRef MangledName) {
char Mangled = Copy.back();
std::string Mangled2 = Copy.substr(Copy.size() - 2);

std::string Mangled6 = Copy.substr(Copy.size() - 6);
if (Mangled6 == "__bf16") {
std::string Mangled5 = Copy.substr(Copy.size() - 5);
if (Mangled5 == "DF16b") {
return ParamType::FLOAT;
}

Expand Down
106 changes: 53 additions & 53 deletions test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll
Original file line number Diff line number Diff line change
Expand Up @@ -126,44 +126,44 @@ target triple = "spirv64-unknown-unknown"
; CHECK-LLVM: %OpFUnordLessThanEqual = fcmp ule bfloat [[DATA1]], [[DATA2]]
; CHECK-LLVM: %OpFOrdGreaterThanEqual = fcmp oge bfloat [[DATA1]], [[DATA2]]
; CHECK-LLVM: %OpFUnordGreaterThanEqual = fcmp uge bfloat [[DATA1]], [[DATA2]]
; CHECK-LLVM: %fabs = call spir_func bfloat @_Z4fabsu6__bf16(bfloat [[DATA1]])
; CHECK-LLVM: %fclamp = call spir_func bfloat @_Z5clampu6__bf16u6__bf16u6__bf16(bfloat [[DATA1]], bfloat [[DATA2]], bfloat [[DATA3]])
; CHECK-LLVM: %fma = call spir_func bfloat @_Z3fmau6__bf16u6__bf16u6__bf16(bfloat [[DATA1]], bfloat [[DATA2]], bfloat [[DATA3]])
; CHECK-LLVM: %fmax = call spir_func bfloat @_Z4fmaxu6__bf16u6__bf16(bfloat [[DATA1]], bfloat [[DATA2]])
; CHECK-LLVM: %fmin = call spir_func bfloat @_Z4fminu6__bf16u6__bf16(bfloat [[DATA1]], bfloat [[DATA2]])
; CHECK-LLVM: %mad = call spir_func bfloat @_Z3madu6__bf16u6__bf16u6__bf16(bfloat [[DATA1]], bfloat [[DATA2]], bfloat [[DATA3]])
; CHECK-LLVM: %nan = call spir_func bfloat @_Z3nanu6__bf16(bfloat [[DATA1]])
; CHECK-LLVM: %native_cos = call spir_func bfloat @_Z10native_cosu6__bf16(bfloat [[DATA1]])
; CHECK-LLVM: %native_divide = call spir_func bfloat @_Z13native_divideu6__bf16u6__bf16(bfloat [[DATA1]], bfloat [[DATA2]])
; CHECK-LLVM: %native_exp = call spir_func bfloat @_Z10native_expu6__bf16(bfloat [[DATA1]])
; CHECK-LLVM: %native_exp10 = call spir_func bfloat @_Z12native_exp10u6__bf16(bfloat [[DATA1]])
; CHECK-LLVM: %native_exp2 = call spir_func bfloat @_Z11native_exp2u6__bf16(bfloat [[DATA1]])
; CHECK-LLVM: %native_log = call spir_func bfloat @_Z10native_logu6__bf16(bfloat [[DATA1]])
; CHECK-LLVM: %native_log10 = call spir_func bfloat @_Z12native_log10u6__bf16(bfloat [[DATA1]])
; CHECK-LLVM: %native_log2 = call spir_func bfloat @_Z11native_log2u6__bf16(bfloat [[DATA1]])
; CHECK-LLVM: %native_powr = call spir_func bfloat @_Z11native_powru6__bf16u6__bf16(bfloat [[DATA1]], bfloat [[DATA2]])
; CHECK-LLVM: %native_recip = call spir_func bfloat @_Z12native_recipu6__bf16(bfloat [[DATA1]])
; CHECK-LLVM: %native_rsqrt = call spir_func bfloat @_Z12native_rsqrtu6__bf16(bfloat [[DATA1]])
; CHECK-LLVM: %native_sin = call spir_func bfloat @_Z10native_sinu6__bf16(bfloat [[DATA1]])
; CHECK-LLVM: %native_sqrt = call spir_func bfloat @_Z11native_sqrtu6__bf16(bfloat [[DATA1]])
; CHECK-LLVM: %native_tan = call spir_func bfloat @_Z10native_tanu6__bf16(bfloat [[DATA1]])
; CHECK-LLVM: %fabs = call spir_func bfloat @_Z4fabsDF16b(bfloat [[DATA1]])
; CHECK-LLVM: %fclamp = call spir_func bfloat @_Z5clampDF16bDF16bDF16b(bfloat [[DATA1]], bfloat [[DATA2]], bfloat [[DATA3]])
; CHECK-LLVM: %fma = call spir_func bfloat @_Z3fmaDF16bDF16bDF16b(bfloat [[DATA1]], bfloat [[DATA2]], bfloat [[DATA3]])
; CHECK-LLVM: %fmax = call spir_func bfloat @_Z4fmaxDF16bDF16b(bfloat [[DATA1]], bfloat [[DATA2]])
; CHECK-LLVM: %fmin = call spir_func bfloat @_Z4fminDF16bDF16b(bfloat [[DATA1]], bfloat [[DATA2]])
; CHECK-LLVM: %mad = call spir_func bfloat @_Z3madDF16bDF16bDF16b(bfloat [[DATA1]], bfloat [[DATA2]], bfloat [[DATA3]])
; CHECK-LLVM: %nan = call spir_func bfloat @_Z3nanDF16b(bfloat [[DATA1]])
; CHECK-LLVM: %native_cos = call spir_func bfloat @_Z10native_cosDF16b(bfloat [[DATA1]])
; CHECK-LLVM: %native_divide = call spir_func bfloat @_Z13native_divideDF16bDF16b(bfloat [[DATA1]], bfloat [[DATA2]])
; CHECK-LLVM: %native_exp = call spir_func bfloat @_Z10native_expDF16b(bfloat [[DATA1]])
; CHECK-LLVM: %native_exp10 = call spir_func bfloat @_Z12native_exp10DF16b(bfloat [[DATA1]])
; CHECK-LLVM: %native_exp2 = call spir_func bfloat @_Z11native_exp2DF16b(bfloat [[DATA1]])
; CHECK-LLVM: %native_log = call spir_func bfloat @_Z10native_logDF16b(bfloat [[DATA1]])
; CHECK-LLVM: %native_log10 = call spir_func bfloat @_Z12native_log10DF16b(bfloat [[DATA1]])
; CHECK-LLVM: %native_log2 = call spir_func bfloat @_Z11native_log2DF16b(bfloat [[DATA1]])
; CHECK-LLVM: %native_powr = call spir_func bfloat @_Z11native_powrDF16bDF16b(bfloat [[DATA1]], bfloat [[DATA2]])
; CHECK-LLVM: %native_recip = call spir_func bfloat @_Z12native_recipDF16b(bfloat [[DATA1]])
; CHECK-LLVM: %native_rsqrt = call spir_func bfloat @_Z12native_rsqrtDF16b(bfloat [[DATA1]])
; CHECK-LLVM: %native_sin = call spir_func bfloat @_Z10native_sinDF16b(bfloat [[DATA1]])
; CHECK-LLVM: %native_sqrt = call spir_func bfloat @_Z11native_sqrtDF16b(bfloat [[DATA1]])
; CHECK-LLVM: %native_tan = call spir_func bfloat @_Z10native_tanDF16b(bfloat [[DATA1]])

declare spir_func bfloat @_Z5clampu6__bf16u6__bf16u6__bf16(bfloat, bfloat, bfloat)
declare spir_func bfloat @_Z3nanu6__bf16(bfloat)
declare spir_func bfloat @_Z10native_cosu6__bf16(bfloat)
declare spir_func bfloat @_Z13native_divideu6__bf16u6__bf16(bfloat, bfloat)
declare spir_func bfloat @_Z10native_expu6__bf16(bfloat)
declare spir_func bfloat @_Z12native_exp10u6__bf16(bfloat)
declare spir_func bfloat @_Z11native_exp2u6__bf16(bfloat)
declare spir_func bfloat @_Z10native_logu6__bf16(bfloat)
declare spir_func bfloat @_Z12native_log10u6__bf16(bfloat)
declare spir_func bfloat @_Z11native_log2u6__bf16(bfloat)
declare spir_func bfloat @_Z11native_powru6__bf16u6__bf16(bfloat, bfloat)
declare spir_func bfloat @_Z12native_recipu6__bf16(bfloat)
declare spir_func bfloat @_Z12native_rsqrtu6__bf16(bfloat)
declare spir_func bfloat @_Z10native_sinu6__bf16(bfloat)
declare spir_func bfloat @_Z11native_sqrtu6__bf16(bfloat)
declare spir_func bfloat @_Z10native_tanu6__bf16(bfloat)
declare spir_func bfloat @_Z5clampDF16bDF16bDF16b(bfloat, bfloat, bfloat)
declare spir_func bfloat @_Z3nanDF16b(bfloat)
declare spir_func bfloat @_Z10native_cosDF16b(bfloat)
declare spir_func bfloat @_Z13native_divideDF16bDF16b(bfloat, bfloat)
declare spir_func bfloat @_Z10native_expDF16b(bfloat)
declare spir_func bfloat @_Z12native_exp10DF16b(bfloat)
declare spir_func bfloat @_Z11native_exp2DF16b(bfloat)
declare spir_func bfloat @_Z10native_logDF16b(bfloat)
declare spir_func bfloat @_Z12native_log10DF16b(bfloat)
declare spir_func bfloat @_Z11native_log2DF16b(bfloat)
declare spir_func bfloat @_Z11native_powrDF16bDF16b(bfloat, bfloat)
declare spir_func bfloat @_Z12native_recipDF16b(bfloat)
declare spir_func bfloat @_Z12native_rsqrtDF16b(bfloat)
declare spir_func bfloat @_Z10native_sinDF16b(bfloat)
declare spir_func bfloat @_Z11native_sqrtDF16b(bfloat)
declare spir_func bfloat @_Z10native_tanDF16b(bfloat)

define spir_func void @OpPhi(bfloat %data1, bfloat %data2) {
br label %blockA
Expand Down Expand Up @@ -223,26 +223,26 @@ entry:
%OpFOrdGreaterThanEqual = fcmp oge bfloat %data1, %data2
%OpFUnordGreaterThanEqual = fcmp uge bfloat %data1, %data2
%fabs = call bfloat @llvm.fabs.bfloat(bfloat %data1)
%fclamp = call spir_func bfloat @_Z5clampu6__bf16u6__bf16u6__bf16(bfloat %data1, bfloat %data2, bfloat %data3)
%fclamp = call spir_func bfloat @_Z5clampDF16bDF16bDF16b(bfloat %data1, bfloat %data2, bfloat %data3)
%fma = call bfloat @llvm.fma.bfloat(bfloat %data1, bfloat %data2, bfloat %data3)
%fmax = call bfloat @llvm.maxnum.bfloat(bfloat %data1, bfloat %data2)
%fmin = call bfloat @llvm.minnum.bfloat(bfloat %data1, bfloat %data2)
%mad = call bfloat @llvm.fmuladd.bfloat(bfloat %data1, bfloat %data2, bfloat %data3)
%nan = call spir_func bfloat @_Z3nanu6__bf16(bfloat %data1)
%native_cos = call spir_func bfloat @_Z10native_cosu6__bf16(bfloat %data1)
%native_divide = call spir_func bfloat @_Z13native_divideu6__bf16u6__bf16(bfloat %data1, bfloat %data2)
%native_exp = call spir_func bfloat @_Z10native_expu6__bf16(bfloat %data1)
%native_exp10 = call spir_func bfloat @_Z12native_exp10u6__bf16(bfloat %data1)
%native_exp2 = call spir_func bfloat @_Z11native_exp2u6__bf16(bfloat %data1)
%native_log = call spir_func bfloat @_Z10native_logu6__bf16(bfloat %data1)
%native_log10 = call spir_func bfloat @_Z12native_log10u6__bf16(bfloat %data1)
%native_log2 = call spir_func bfloat @_Z11native_log2u6__bf16(bfloat %data1)
%native_powr = call spir_func bfloat @_Z11native_powru6__bf16u6__bf16(bfloat %data1, bfloat %data2)
%native_recip = call spir_func bfloat @_Z12native_recipu6__bf16(bfloat %data1)
%native_rsqrt = call spir_func bfloat @_Z12native_rsqrtu6__bf16(bfloat %data1)
%native_sin = call spir_func bfloat @_Z10native_sinu6__bf16(bfloat %data1)
%native_sqrt = call spir_func bfloat @_Z11native_sqrtu6__bf16(bfloat %data1)
%native_tan = call spir_func bfloat @_Z10native_tanu6__bf16(bfloat %data1)
%nan = call spir_func bfloat @_Z3nanDF16b(bfloat %data1)
%native_cos = call spir_func bfloat @_Z10native_cosDF16b(bfloat %data1)
%native_divide = call spir_func bfloat @_Z13native_divideDF16bDF16b(bfloat %data1, bfloat %data2)
%native_exp = call spir_func bfloat @_Z10native_expDF16b(bfloat %data1)
%native_exp10 = call spir_func bfloat @_Z12native_exp10DF16b(bfloat %data1)
%native_exp2 = call spir_func bfloat @_Z11native_exp2DF16b(bfloat %data1)
%native_log = call spir_func bfloat @_Z10native_logDF16b(bfloat %data1)
%native_log10 = call spir_func bfloat @_Z12native_log10DF16b(bfloat %data1)
%native_log2 = call spir_func bfloat @_Z11native_log2DF16b(bfloat %data1)
%native_powr = call spir_func bfloat @_Z11native_powrDF16bDF16b(bfloat %data1, bfloat %data2)
%native_recip = call spir_func bfloat @_Z12native_recipDF16b(bfloat %data1)
%native_rsqrt = call spir_func bfloat @_Z12native_rsqrtDF16b(bfloat %data1)
%native_sin = call spir_func bfloat @_Z10native_sinDF16b(bfloat %data1)
%native_sqrt = call spir_func bfloat @_Z11native_sqrtDF16b(bfloat %data1)
%native_tan = call spir_func bfloat @_Z10native_tanDF16b(bfloat %data1)
ret void
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ define dso_local spir_func bfloat @test_AtomicFAddEXT_bfloat(ptr addrspace(4) al
entry:
%0 = addrspacecast ptr addrspace(4) %Arg to ptr addrspace(1)
; CHECK-SPIRV: AtomicFAddEXT [[BFLOAT]]
; CHECK-LLVM-SPV: call spir_func bfloat @_Z21__spirv_AtomicFAddEXTPU3AS1u6__bf16iiu6__bf16({{.*}}bfloat
%ret = tail call spir_func bfloat @_Z21__spirv_AtomicFAddEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1) %0, i32 1, i32 896, bfloat 1.000000e+00)
; CHECK-LLVM-SPV: call spir_func bfloat @_Z21__spirv_AtomicFAddEXTPU3AS1DF16biiDF16b({{.*}}bfloat
%ret = tail call spir_func bfloat @_Z21__spirv_AtomicFAddEXTPU3AS1DF16biiDF16b(ptr addrspace(1) %0, i32 1, i32 896, bfloat 1.000000e+00)
ret bfloat %ret
}

; Function Attrs: convergent
declare dso_local spir_func bfloat @_Z21__spirv_AtomicFAddEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1), i32, i32, bfloat)
declare dso_local spir_func bfloat @_Z21__spirv_AtomicFAddEXTPU3AS1DF16biiDF16b(ptr addrspace(1), i32, i32, bfloat)
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ define dso_local spir_func bfloat @test_AtomicFMaxEXT_bfloat(ptr addrspace(4) al
entry:
%0 = addrspacecast ptr addrspace(4) %Arg to ptr addrspace(1)
; CHECK-SPIRV: AtomicFMaxEXT [[BFLOAT]]
; CHECK-LLVM-SPV: call spir_func bfloat @_Z21__spirv_AtomicFMaxEXTPU3AS1u6__bf16iiu6__bf16({{.*}}bfloat
%ret = tail call spir_func bfloat @_Z21__spirv_AtomicFMaxEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1) %0, i32 1, i32 896, bfloat 1.000000e+00)
; CHECK-LLVM-SPV: call spir_func bfloat @_Z21__spirv_AtomicFMaxEXTPU3AS1DF16biiDF16b({{.*}}bfloat
%ret = tail call spir_func bfloat @_Z21__spirv_AtomicFMaxEXTPU3AS1DF16biiDF16b(ptr addrspace(1) %0, i32 1, i32 896, bfloat 1.000000e+00)
ret bfloat %ret
}

; Function Attrs: convergent
declare dso_local spir_func bfloat @_Z21__spirv_AtomicFMaxEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1), i32, i32, bfloat)
declare dso_local spir_func bfloat @_Z21__spirv_AtomicFMaxEXTPU3AS1DF16biiDF16b(ptr addrspace(1), i32, i32, bfloat)
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ define dso_local spir_func bfloat @test_AtomicFMinEXT_bfloat(ptr addrspace(4) al
entry:
%0 = addrspacecast ptr addrspace(4) %Arg to ptr addrspace(1)
; CHECK-SPIRV: AtomicFMinEXT [[BFLOAT]]
; CHECK-LLVM-SPV: call spir_func bfloat @_Z21__spirv_AtomicFMinEXTPU3AS1u6__bf16iiu6__bf16({{.*}}bfloat
%ret = tail call spir_func bfloat @_Z21__spirv_AtomicFMinEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1) %0, i32 1, i32 896, bfloat 1.000000e+00)
; CHECK-LLVM-SPV: call spir_func bfloat @_Z21__spirv_AtomicFMinEXTPU3AS1DF16biiDF16b({{.*}}bfloat
%ret = tail call spir_func bfloat @_Z21__spirv_AtomicFMinEXTPU3AS1DF16biiDF16b(ptr addrspace(1) %0, i32 1, i32 896, bfloat 1.000000e+00)
ret bfloat %ret
}

; Function Attrs: convergent
declare dso_local spir_func bfloat @_Z21__spirv_AtomicFMinEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1), i32, i32, bfloat)
declare dso_local spir_func bfloat @_Z21__spirv_AtomicFMinEXTPU3AS1DF16biiDF16b(ptr addrspace(1), i32, i32, bfloat)
6 changes: 3 additions & 3 deletions test/extensions/KHR/SPV_KHR_bfloat16/bfloat16_dot.ll
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ target triple = "spirv64-unknown-unknown"
; CHECK-LLVM: %addrB = alloca <2 x bfloat>
; CHECK-LLVM: %dataA = load <2 x bfloat>, ptr %addrA
; CHECK-LLVM: %dataB = load <2 x bfloat>, ptr %addrB
; CHECK-LLVM: %call = call spir_func bfloat @_Z3dotDv2_u6__bf16S_(<2 x bfloat> %dataA, <2 x bfloat> %dataB)
; CHECK-LLVM: %call = call spir_func bfloat @_Z3dotDv2_DF16bS_(<2 x bfloat> %dataA, <2 x bfloat> %dataB)

declare spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat>, <2 x bfloat>)
declare spir_func bfloat @_Z3dotDv2_DF16bDv2_S_(<2 x bfloat>, <2 x bfloat>)

define spir_kernel void @test() {
entry:
%addrA = alloca <2 x bfloat>
%addrB = alloca <2 x bfloat>
%dataA = load <2 x bfloat>, ptr %addrA
%dataB = load <2 x bfloat>, ptr %addrB
%call = call spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat> %dataA, <2 x bfloat> %dataB)
%call = call spir_func bfloat @_Z3dotDv2_DF16bDv2_S_(<2 x bfloat> %dataA, <2 x bfloat> %dataB)
ret void
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
; CHECK-SPIRV-DAG: Constant [[#BFloatTy]] [[#]] 16256
; CHECK-SPIRV: CompositeConstruct [[#MatTy]]

; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructu6__bf16(bfloat 0xR3F80)
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructDF16b(bfloat 0xR3F80)

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

declare spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructu6__bf16(bfloat)
declare spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructDF16b(bfloat)

define spir_kernel void @test() {
%mat = call spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructu6__bf16(bfloat 1.0)
%mat = call spir_func target("spirv.CooperativeMatrixKHR", bfloat, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructDF16b(bfloat 1.0)
ret void
}
Loading