Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,9 +478,28 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool UseTPT) {
(unsigned)S};
if (auto *Use = MT->getUse())
Params.push_back(static_cast<SPIRVConstant *>(Use)->getZExtIntValue());
auto *CTI = MT->getComponentTypeInterpretation();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This patch contradicts current SPV_INTEL_joint_matrix specification, where type interpretation is a part of MulAdd. Note, IGC also expects type interpretation be a part of MulAdd and not a part of the type. Feel free to IM me to discuss this.

Copy link
Contributor

@MrSidims MrSidims May 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mateuszchudyk please follow up with @vmaksimo about this comment (I'm OOO for the next 2 weeks), she will point out on spec changes and IGC patch, that adds type interpretation support.
Basically I'm fine to merge it as is, but please make sure, that we know, what we are doing with matrix special types. Please also note, that there is no such thing as Int4 interpretation (it is now a proper TypeInt 4 - see SPV_INTEL_int4 (this is to be backported by us soon)) and Int2 interpretation (and there won't be any counterpart).

if (!CTI)
return mapType(T, getSPIRVType(internal::OpTypeJointMatrixINTEL,
transTypeToOCLTypeName(MT->getCompType()),
Params, !UseTPT));
std::string ComponentTypeName;
switch (static_cast<SPIRVConstant *>(CTI)->getZExtIntValue()) {
case internal::InternalJointMatrixCTI::TF32:
ComponentTypeName = "tf32";
break;
case internal::InternalJointMatrixCTI::Bfloat16:
ComponentTypeName = "bfloat16";
break;
case internal::InternalJointMatrixCTI::PackedInt2:
case internal::InternalJointMatrixCTI::PackedInt4:
// Do nothing just now
break;
default:
llvm_unreachable("Unexpected joint matrix component type");
}
return mapType(T, getSPIRVType(internal::OpTypeJointMatrixINTEL,
transTypeToOCLTypeName(MT->getCompType()),
Params, !UseTPT));
ComponentTypeName, Params, !UseTPT));
}
case OpTypeCooperativeMatrixKHR: {
auto *MT = static_cast<SPIRVTypeCooperativeMatrixKHR *>(T);
Expand Down
52 changes: 31 additions & 21 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {

// Representation in LLVM IR before the translator is a pointer to an opaque
// structure:
// %spirv.JointMatrixINTEL._%element_type%_%rows%_%cols%_%scope%_%use%
// %spirv.JointMatrixINTEL._%element_type%_%rows%_%cols%_%layout%_%scope%_%use%
// Here we check the structure name yet again. Another option would be to
// check SPIR-V friendly function calls (by their name) and obtain return
// or their parameter types, assuming, that the appropriate types are Matrix
Expand All @@ -711,6 +711,18 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
// simply not true.
SPIRVType *LLVMToSPIRVBase::transSPIRVJointMatrixINTELType(
SmallVector<std::string, 8> Postfixes) {
auto ParseInteger = [this](StringRef Postfix) -> ConstantInt * {
unsigned long long N = 0;
if (consumeUnsignedInteger(Postfix, 10, N))
BM->getErrorLog().checkError(
false, SPIRVEC_InvalidLlvmModule,
"TypeJointMatrixINTEL expects integer parameters");
return getUInt32(M, N);
};
std::vector<SPIRVValue *> Args;
for (size_t I = 1; I != Postfixes.size(); ++I)
Args.emplace_back(transConstant(ParseInteger(Postfixes[I])));

Type *ElemTy = nullptr;
StringRef Ty{Postfixes[0]};
auto NumBits = llvm::StringSwitch<unsigned>(Ty)
Expand All @@ -719,32 +731,30 @@ SPIRVType *LLVMToSPIRVBase::transSPIRVJointMatrixINTELType(
.Case("int", 32)
.Case("long", 64)
.Default(0);
if (NumBits)
if (NumBits) {
ElemTy = IntegerType::get(M->getContext(), NumBits);
else if (Ty == "half")
} else if (Ty == "half") {
ElemTy = Type::getHalfTy(M->getContext());
else if (Ty == "float")
} else if (Ty == "float") {
ElemTy = Type::getFloatTy(M->getContext());
else if (Ty == "double")
} else if (Ty == "double") {
ElemTy = Type::getDoubleTy(M->getContext());
else if (Ty == "bfloat16")
} else if (Ty == "bfloat16") {
ElemTy = Type::getInt16Ty(M->getContext());
else
// TODO: add BF16 CTI when we do breaking change
// auto *CTI = transConstant(getUInt32(M, static_cast<uint64_t>(
// internal::InternalJointMatrixCTI::Bfloat16)));
// Args.push_back(CTI);
// BM->addCapability(internal::CapabilityJointMatrixBF16ComponentTypeINTEL);
} else if (Ty == "tf32") {
ElemTy = Type::getFloatTy(M->getContext());
auto *CTI = transConstant(getUInt32(
M, static_cast<uint64_t>(internal::InternalJointMatrixCTI::TF32)));
Args.push_back(CTI);
BM->addCapability(internal::CapabilityJointMatrixTF32ComponentTypeINTEL);
} else {
llvm_unreachable("Unexpected type for matrix!");

auto ParseInteger = [this](StringRef Postfix) -> ConstantInt * {
unsigned long long N = 0;
if (consumeUnsignedInteger(Postfix, 10, N)) {
BM->getErrorLog().checkError(
false, SPIRVEC_InvalidLlvmModule,
"TypeJointMatrixINTEL expects integer parameters");
return 0;
}
return getUInt32(M, N);
};
std::vector<SPIRVValue *> Args;
for (size_t I = 1; I != Postfixes.size(); ++I)
Args.emplace_back(transConstant(ParseInteger(Postfixes[I])));
}
return BM->addJointMatrixINTELType(transType(ElemTy), Args);
}

Expand Down
8 changes: 8 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,14 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
{CapabilitySubgroupAvcMotionEstimationIntraINTEL});
ADD_VEC_INIT(internal::CapabilityJointMatrixWIInstructionsINTEL,
{internal::CapabilityJointMatrixINTEL});
ADD_VEC_INIT(internal::CapabilityJointMatrixTF32ComponentTypeINTEL,
{internal::CapabilityJointMatrixINTEL});
ADD_VEC_INIT(internal::CapabilityJointMatrixBF16ComponentTypeINTEL,
{internal::CapabilityJointMatrixINTEL});
ADD_VEC_INIT(internal::CapabilityJointMatrixPackedInt2ComponentTypeINTEL,
{internal::CapabilityJointMatrixINTEL});
ADD_VEC_INIT(internal::CapabilityJointMatrixPackedInt4ComponentTypeINTEL,
{internal::CapabilityJointMatrixINTEL});
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixCheckedInstructionsINTEL,
{CapabilityCooperativeMatrixKHR});
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixPrefetchINTEL,
Expand Down
8 changes: 8 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,14 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
add(internal::CapabilityCacheControlsINTEL, "CacheControlsINTEL");
add(internal::CapabilityJointMatrixWIInstructionsINTEL,
"JointMatrixWIInstructionsINTEL");
add(internal::CapabilityJointMatrixTF32ComponentTypeINTEL,
"JointMatrixTF32ComponentTypeINTEL");
add(internal::CapabilityJointMatrixBF16ComponentTypeINTEL,
"JointMatrixBF16ComponentTypeINTEL");
add(internal::CapabilityJointMatrixPackedInt2ComponentTypeINTEL,
"JointMatrixPackedInt2ComponentTypeINTEL");
add(internal::CapabilityJointMatrixPackedInt4ComponentTypeINTEL,
"JointMatrixPackedInt4ComponentTypeINTEL");
add(internal::CapabilityCooperativeMatrixPrefetchINTEL,
"CooperativeMatrixPrefetchINTEL");
add(internal::CapabilityCooperativeMatrixCheckedInstructionsINTEL,
Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVType.h
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,9 @@ class SPIRVTypeJointMatrixINTEL : public SPIRVType {
SPIRVValue *getLayout() const { return Args[2]; }
SPIRVValue *getScope() const { return Args[3]; }
SPIRVValue *getUse() const { return Args.size() > 4 ? Args[4] : nullptr; }
SPIRVValue *getComponentTypeInterpretation() const {
return Args.size() > 5 ? Args[5] : nullptr;
}
};

class SPIRVTypeCooperativeMatrixKHR : public SPIRVType {
Expand Down
16 changes: 16 additions & 0 deletions lib/SPIRV/libSPIRV/spirv_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ enum InternalCapability {
ICapabilityTensorFloat32RoundingINTEL = 6425,
ICapabilityMaskedGatherScatterINTEL = 6427,
ICapabilityJointMatrixWIInstructionsINTEL = 6435,
ICapabilityJointMatrixTF32ComponentTypeINTEL = 6436,
ICapabilityJointMatrixBF16ComponentTypeINTEL = 6437,
ICapabilityJointMatrixPackedInt2ComponentTypeINTEL = 6438,
ICapabilityJointMatrixPackedInt4ComponentTypeINTEL = 6439,
ICapabilityCacheControlsINTEL = 6441,
ICapRegisterLimitsINTEL = 6460,
ICapabilityBindlessImagesINTEL = 6528
Expand All @@ -139,6 +143,14 @@ enum InternalJointMatrixLayout {

enum InternalJointMatrixUse { MatrixA = 0, MatrixB = 1, Accumulator = 2 };

enum InternalJointMatrixCTI {
None = 0,
TF32 = 1,
Bfloat16 = 2,
PackedInt2 = 3,
PackedInt4 = 4
};

enum InternalBuiltIn {
IBuiltInSubDeviceIDINTEL = 6135,
IBuiltInGlobalHWThreadIDINTEL = 6136,
Expand All @@ -162,6 +174,10 @@ enum class StoreCacheControlINTEL {
#define _SPIRV_OP(x, y) constexpr x x##y = static_cast<x>(I##x##y);
_SPIRV_OP(Capability, JointMatrixINTEL)
_SPIRV_OP(Capability, JointMatrixWIInstructionsINTEL)
_SPIRV_OP(Capability, JointMatrixTF32ComponentTypeINTEL)
_SPIRV_OP(Capability, JointMatrixBF16ComponentTypeINTEL)
_SPIRV_OP(Capability, JointMatrixPackedInt2ComponentTypeINTEL)
_SPIRV_OP(Capability, JointMatrixPackedInt4ComponentTypeINTEL)
_SPIRV_OP(Op, TypeJointMatrixINTEL)
_SPIRV_OP(Op, JointMatrixLoadINTEL)
_SPIRV_OP(Op, JointMatrixStoreINTEL)
Expand Down
Loading