Skip to content

Commit 465eb3c

Browse files
authored
[OpaquePointers] Rewrite joint_matrix tests (KhronosGroup#2088)
This patch adds joint_matrix reverse translation to target extension type and starts rewriting all of the tests. Some tests are being removed as outdated Remaining tests to add after the patch: 1. tf32 test 2. element wise operations test Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent e661cb7 commit 465eb3c

10 files changed

+484
-1221
lines changed

lib/SPIRV/SPIRVReader.cpp

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -457,26 +457,17 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool UseTPT) {
457457
Params.push_back(static_cast<SPIRVConstant *>(Use)->getZExtIntValue());
458458
auto *CTI = MT->getComponentTypeInterpretation();
459459
if (!CTI)
460-
return mapType(T, getSPIRVType(internal::OpTypeJointMatrixINTEL,
461-
transTypeToOCLTypeName(MT->getCompType()),
462-
Params, !UseTPT));
463-
std::string ComponentTypeName;
464-
switch (static_cast<SPIRVConstant *>(CTI)->getZExtIntValue()) {
465-
case internal::InternalJointMatrixCTI::TF32:
466-
ComponentTypeName = "tf32";
467-
break;
468-
case internal::InternalJointMatrixCTI::Bfloat16:
469-
ComponentTypeName = "bfloat16";
470-
break;
471-
case internal::InternalJointMatrixCTI::PackedInt2:
472-
case internal::InternalJointMatrixCTI::PackedInt4:
473-
// Do nothing just now
474-
break;
475-
default:
476-
llvm_unreachable("Unexpected joint matrix component type");
477-
}
478-
return mapType(T, getSPIRVType(internal::OpTypeJointMatrixINTEL,
479-
ComponentTypeName, Params, !UseTPT));
460+
return mapType(
461+
T, llvm::TargetExtType::get(*Context, "spirv.JointMatrixINTEL",
462+
transType(MT->getCompType()), Params));
463+
const unsigned CTIValue =
464+
static_cast<SPIRVConstant *>(CTI)->getZExtIntValue();
465+
assert(CTIValue <= internal::InternalJointMatrixCTI::PackedInt4 &&
466+
"Unknown matrix component type interpretation");
467+
Params.push_back(CTIValue);
468+
return mapType(
469+
T, llvm::TargetExtType::get(*Context, "spirv.JointMatrixINTEL",
470+
transType(MT->getCompType()), Params));
480471
}
481472
case OpTypeForwardPointer: {
482473
SPIRVTypeForwardPointer *FP =

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 0 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -616,21 +616,6 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(Type *ET, unsigned AddrSpc) {
616616
transType(ET)));
617617
}
618618
} else {
619-
// JointMatrixINTEL type is not necessarily an opaque type, it can be
620-
// represented as a structure with pointer to a multidimensional array
621-
// member.
622-
if (ST && ST->hasName()) {
623-
StringRef STName = ST->getName();
624-
if (STName.startswith(kSPIRVTypeName::PrefixAndDelim)) {
625-
SmallVector<std::string, 8> Postfixes;
626-
auto TN = decodeSPIRVTypeName(STName, Postfixes);
627-
if (TN == kSPIRVTypeName::JointMatrixINTEL) {
628-
SPIRVType *TranslatedTy = transSPIRVJointMatrixINTELType(Postfixes);
629-
PointeeTypeMap[TypeKey] = TranslatedTy;
630-
return TranslatedTy;
631-
}
632-
}
633-
}
634619
SPIRVType *ElementType = transType(ET);
635620
// ET, as a recursive type, may contain exactly the same pointer T, so it
636621
// may happen that after translation of ET we already have translated T,
@@ -661,66 +646,6 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
661646
return TranslatedTy;
662647
}
663648

664-
// Representation in LLVM IR before the translator is a pointer to an opaque
665-
// structure:
666-
// %spirv.JointMatrixINTEL._%element_type%_%rows%_%cols%_%layout%_%scope%_%use%
667-
// Here we check the structure name yet again. Another option would be to
668-
// check SPIR-V friendly function calls (by their name) and obtain return
669-
// or their parameter types, assuming, that the appropriate types are Matrix
670-
// structure type. But in the near future, we will reuse Composite
671-
// instructions to do, for example, matrix initialization directly on AMX
672-
// register by OpCompositeConstruct. And we can't claim, that the Result type
673-
// of OpCompositeConstruct instruction is always the joint matrix type, it's
674-
// simply not true.
675-
SPIRVType *LLVMToSPIRVBase::transSPIRVJointMatrixINTELType(
676-
SmallVector<std::string, 8> Postfixes) {
677-
auto ParseInteger = [this](StringRef Postfix) -> ConstantInt * {
678-
unsigned long long N = 0;
679-
if (consumeUnsignedInteger(Postfix, 10, N))
680-
BM->getErrorLog().checkError(
681-
false, SPIRVEC_InvalidLlvmModule,
682-
"TypeJointMatrixINTEL expects integer parameters");
683-
return getUInt32(M, N);
684-
};
685-
std::vector<SPIRVValue *> Args;
686-
for (size_t I = 1; I != Postfixes.size(); ++I)
687-
Args.emplace_back(transConstant(ParseInteger(Postfixes[I])));
688-
689-
Type *ElemTy = nullptr;
690-
StringRef Ty{Postfixes[0]};
691-
auto NumBits = llvm::StringSwitch<unsigned>(Ty)
692-
.Case("char", 8)
693-
.Case("short", 16)
694-
.Case("int", 32)
695-
.Case("long", 64)
696-
.Default(0);
697-
if (NumBits) {
698-
ElemTy = IntegerType::get(M->getContext(), NumBits);
699-
} else if (Ty == "half") {
700-
ElemTy = Type::getHalfTy(M->getContext());
701-
} else if (Ty == "float") {
702-
ElemTy = Type::getFloatTy(M->getContext());
703-
} else if (Ty == "double") {
704-
ElemTy = Type::getDoubleTy(M->getContext());
705-
} else if (Ty == "bfloat16") {
706-
ElemTy = Type::getInt16Ty(M->getContext());
707-
// TODO: add BF16 CTI when we do breaking change
708-
// auto *CTI = transConstant(getUInt32(M, static_cast<uint64_t>(
709-
// internal::InternalJointMatrixCTI::Bfloat16)));
710-
// Args.push_back(CTI);
711-
// BM->addCapability(internal::CapabilityJointMatrixBF16ComponentTypeINTEL);
712-
} else if (Ty == "tf32") {
713-
ElemTy = Type::getFloatTy(M->getContext());
714-
auto *CTI = transConstant(getUInt32(
715-
M, static_cast<uint64_t>(internal::InternalJointMatrixCTI::TF32)));
716-
Args.push_back(CTI);
717-
BM->addCapability(internal::CapabilityJointMatrixTF32ComponentTypeINTEL);
718-
} else {
719-
llvm_unreachable("Unexpected type for matrix!");
720-
}
721-
return BM->addJointMatrixINTELType(transType(ElemTy), Args);
722-
}
723-
724649
SPIRVType *LLVMToSPIRVBase::transSPIRVOpaqueType(StringRef STName,
725650
unsigned AddrSpace) {
726651
std::pair<StringRef, unsigned> Key = {STName, AddrSpace};
@@ -777,8 +702,6 @@ SPIRVType *LLVMToSPIRVBase::transSPIRVOpaqueType(StringRef STName,
777702
return SaveType(BM->addQueueType());
778703
else if (TN == kSPIRVTypeName::PipeStorage)
779704
return SaveType(BM->addPipeStorageType());
780-
else if (TN == kSPIRVTypeName::JointMatrixINTEL)
781-
return SaveType(transSPIRVJointMatrixINTELType(Postfixes));
782705
else if (BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_vector_compute) &&
783706
TN == kSPIRVTypeName::BufferSurfaceINTEL) {
784707
auto Access = getAccessQualifier(STName);

0 commit comments

Comments
 (0)