Skip to content

Commit fa77950

Browse files
committed
[Backport to 13] Implement SPV_INTEL_bfloat16_arithmetic (#3290) (#3320)
The extension relaxes the rules for the bf16 type, allowing it to be used in some arithmetic operations. The spec is available here: intel/llvm#18352 Co-authored-by: Michael Aziz <[email protected]> --------- Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent 5d94357 commit fa77950

File tree

9 files changed

+314
-0
lines changed

9 files changed

+314
-0
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,4 @@ EXT(SPV_INTEL_hw_thread_queries)
5757
EXT(SPV_EXT_relaxed_printf_string_address_space)
5858
EXT(SPV_INTEL_maximum_registers)
5959
EXT(SPV_KHR_bfloat16)
60+
EXT(SPV_INTEL_bfloat16_arithmetic)

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,11 @@ ParamType lastFuncParamType(StringRef MangledName) {
575575
char Mangled = Copy.back();
576576
std::string Mangled2 = Copy.substr(Copy.size() - 2);
577577

578+
std::string Mangled6 = Copy.substr(Copy.size() - 6);
579+
if (Mangled6 == "__bf16") {
580+
return ParamType::FLOAT;
581+
}
582+
578583
if (isMangledTypeFP(Mangled) || isMangledTypeHalf(Mangled2)) {
579584
return ParamType::FLOAT;
580585
} else if (isMangledTypeUnsigned(Mangled)) {
@@ -1637,6 +1642,9 @@ bool checkTypeForSPIRVExtendedInstLowering(IntrinsicInst *II, SPIRVModule *BM) {
16371642
NumElems = VecTy->getNumElements();
16381643
Ty = VecTy->getElementType();
16391644
}
1645+
if (Ty->isBFloatTy() &&
1646+
BM->hasCapability(internal::CapabilityBFloat16ArithmeticINTEL))
1647+
return true;
16401648
if ((!Ty->isFloatTy() && !Ty->isDoubleTy() && !Ty->isHalfTy()) ||
16411649
(!BM->hasCapability(CapabilityVectorAnyINTEL) &&
16421650
((NumElems > 4) && (NumElems != 8) && (NumElems != 16)))) {
@@ -1653,6 +1661,9 @@ bool checkTypeForSPIRVExtendedInstLowering(IntrinsicInst *II, SPIRVModule *BM) {
16531661
NumElems = VecTy->getNumElements();
16541662
Ty = VecTy->getElementType();
16551663
}
1664+
if (Ty->isBFloatTy() &&
1665+
BM->hasCapability(internal::CapabilityBFloat16ArithmeticINTEL))
1666+
return true;
16561667
if ((!Ty->isIntegerTy()) ||
16571668
(!BM->hasCapability(CapabilityVectorAnyINTEL) &&
16581669
((NumElems > 4) && (NumElems != 8) && (NumElems != 16)))) {

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2728,6 +2728,20 @@ SPIRVValue *LLVMToSPIRVBase::transIntrinsicInst(IntrinsicInst *II,
27282728
// -spirv-allow-unknown-intrinsics work correctly.
27292729
auto IID = II->getIntrinsicID();
27302730
switch (IID) {
2731+
case Intrinsic::fabs:
2732+
case Intrinsic::fma:
2733+
case Intrinsic::maxnum:
2734+
case Intrinsic::minnum:
2735+
case Intrinsic::fmuladd: {
2736+
Type *Ty = II->getType();
2737+
if (Ty->isBFloatTy())
2738+
BM->addCapability(internal::CapabilityBFloat16ArithmeticINTEL);
2739+
break;
2740+
}
2741+
default:
2742+
break;
2743+
}
2744+
switch (IID) {
27312745
case Intrinsic::assume: {
27322746
// llvm.assume translation is currently supported only within
27332747
// SPV_KHR_expect_assume extension, ignore it otherwise, since it's
@@ -3363,6 +3377,11 @@ SPIRVValue *LLVMToSPIRVBase::transDirectCallInst(CallInst *CI,
33633377
SmallVector<std::string, 2> Dec;
33643378
if (isBuiltinTransToExtInst(CI->getCalledFunction(), &ExtSetKind, &ExtOp,
33653379
&Dec)) {
3380+
if (const auto *FirstArg = F->getArg(0)) {
3381+
const auto *Type = FirstArg->getType();
3382+
if (Type->isBFloatTy())
3383+
BM->addCapability(internal::CapabilityBFloat16ArithmeticINTEL);
3384+
}
33663385
if (DemangledName.find("__spirv_ocl_printf") != StringRef::npos) {
33673386
auto *FormatStrPtr = cast<PointerType>(CI->getArgOperand(0)->getType());
33683387
if (FormatStrPtr->getAddressSpace() !=

lib/SPIRV/libSPIRV/SPIRVEntry.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,8 @@ class SPIRVCapability : public SPIRVEntryNoId<OpCapability> {
896896
case CapabilityVectorComputeINTEL:
897897
case CapabilityVectorAnyINTEL:
898898
return ExtensionID::SPV_INTEL_vector_compute;
899+
case internal::CapabilityBFloat16ArithmeticINTEL:
900+
return ExtensionID::SPV_INTEL_bfloat16_arithmetic;
899901
default:
900902
return {};
901903
}

lib/SPIRV/libSPIRV/SPIRVEnum.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
200200
ADD_VEC_INIT(CapabilityBFloat16DotProductKHR, {CapabilityBFloat16TypeKHR});
201201
ADD_VEC_INIT(CapabilityBFloat16CooperativeMatrixKHR,
202202
{CapabilityBFloat16TypeKHR, CapabilityCooperativeMatrixKHR});
203+
ADD_VEC_INIT(internal::CapabilityBFloat16ArithmeticINTEL,
204+
{CapabilityBFloat16TypeKHR});
203205
}
204206

205207
template <> inline void SPIRVMap<SPIRVExecutionModelKind, SPIRVCapVec>::init() {

lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1357,6 +1357,8 @@ SPIRVInstruction *SPIRVModuleImpl::addBinaryInst(Op TheOpCode, SPIRVType *Type,
13571357
SPIRVValue *Op1,
13581358
SPIRVValue *Op2,
13591359
SPIRVBasicBlock *BB) {
1360+
if (Type->isTypeFloat(16, FPEncodingBFloat16KHR) && TheOpCode != OpDot)
1361+
addCapability(internal::CapabilityBFloat16ArithmeticINTEL);
13601362
return addInstruction(SPIRVInstTemplateBase::create(
13611363
TheOpCode, Type, getId(),
13621364
getVec(Op1->getId(), Op2->getId()), BB, this),
@@ -1380,6 +1382,8 @@ SPIRVInstruction *SPIRVModuleImpl::addUnaryInst(Op TheOpCode,
13801382
SPIRVType *TheType,
13811383
SPIRVValue *Op,
13821384
SPIRVBasicBlock *BB) {
1385+
if (TheType->isTypeFloat(16, FPEncodingBFloat16KHR) && TheOpCode != OpDot)
1386+
addCapability(internal::CapabilityBFloat16ArithmeticINTEL);
13831387
return addInstruction(
13841388
SPIRVInstTemplateBase::create(TheOpCode, TheType, getId(),
13851389
getVec(Op->getId()), BB, this),

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
588588
"TensorFloat32RoundingINTEL");
589589
add(internal::CapabilityHWThreadQueryINTEL, "HWThreadQueryINTEL");
590590
add(internal::CapabilityRegisterLimitsINTEL, "RegisterLimitsINTEL");
591+
add(internal::CapabilityBFloat16ArithmeticINTEL, "BFloat16ArithmeticINTEL");
591592
}
592593
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)
593594

lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ enum InternalCapability {
6969
ICapTokenTypeINTEL = 6112,
7070
ICapBfloat16ConversionINTEL = 6115,
7171
ICapFPArithmeticFenceINTEL = 6144,
72+
ICapabilityBFloat16ArithmeticINTEL = 6226,
7273
ICapabilityTensorFloat32RoundingINTEL = 6425,
7374
ICapabilityMaskedGatherScatterINTEL = 6427,
7475
ICapabilityHWThreadQueryINTEL = 6134,
@@ -160,6 +161,8 @@ constexpr Capability CapabilityRegisterLimitsINTEL =
160161

161162
constexpr FunctionControlMask FunctionControlOptNoneINTELMask =
162163
static_cast<FunctionControlMask>(IFunctionControlOptNoneINTELMask);
164+
constexpr Capability CapabilityBFloat16ArithmeticINTEL =
165+
static_cast<Capability>(ICapabilityBFloat16ArithmeticINTEL);
163166

164167
constexpr Decoration DecorationMathOpDSPModeINTEL =
165168
static_cast<Decoration>(IDecMathOpDSPModeINTEL);

0 commit comments

Comments
 (0)