Skip to content

Commit 21bc73c

Browse files
committed
Revert "[Backport to 16] Start preparing for TypeJointMatrixINTEL switch (KhronosGroup#1935)"
This reverts commit e781a91.
1 parent bcb764c commit 21bc73c

File tree

9 files changed

+18
-60
lines changed

9 files changed

+18
-60
lines changed

lib/SPIRV/OCLUtil.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -898,7 +898,6 @@ SPIRAddressSpace getOCLOpaqueTypeAddrSpace(Op OpCode) {
898898
case OpTypeSampler:
899899
return SPIRV_SAMPLER_T_ADDR_SPACE;
900900
case internal::OpTypeJointMatrixINTEL:
901-
case internal::OpTypeJointMatrixINTELv2:
902901
case OpTypeCooperativeMatrixKHR:
903902
return SPIRAS_Global;
904903
default:

lib/SPIRV/SPIRVReader.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -472,11 +472,10 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool UseTPT) {
472472
auto *MT = static_cast<SPIRVTypeJointMatrixINTEL *>(T);
473473
auto R = static_cast<SPIRVConstant *>(MT->getRows())->getZExtIntValue();
474474
auto C = static_cast<SPIRVConstant *>(MT->getColumns())->getZExtIntValue();
475-
std::vector<unsigned> Params = {(unsigned)R, (unsigned)C};
476-
if (auto *Layout = MT->getLayout())
477-
Params.push_back(static_cast<SPIRVConstant *>(Layout)->getZExtIntValue());
478-
Params.push_back(
479-
static_cast<SPIRVConstant *>(MT->getScope())->getZExtIntValue());
475+
auto L = static_cast<SPIRVConstant *>(MT->getLayout())->getZExtIntValue();
476+
auto S = static_cast<SPIRVConstant *>(MT->getScope())->getZExtIntValue();
477+
SmallVector<unsigned, 5> Params = {(unsigned)R, (unsigned)C, (unsigned)L,
478+
(unsigned)S};
480479
if (auto *Use = MT->getUse())
481480
Params.push_back(static_cast<SPIRVConstant *>(Use)->getZExtIntValue());
482481
auto *CTI = MT->getComponentTypeInterpretation();

lib/SPIRV/libSPIRV/SPIRVEntry.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,6 @@ SPIRVEntry *SPIRVEntry::create(Op OpCode) {
8484
static const OpToFactoryMapTy OpToFactoryMap(std::begin(Table),
8585
std::end(Table));
8686

87-
// TODO: To remove this when we make a switch to new version
88-
if (OpCode == internal::OpTypeJointMatrixINTELv2)
89-
OpCode = internal::OpTypeJointMatrixINTEL;
90-
9187
OpToFactoryMapTy::const_iterator Loc = OpToFactoryMap.find(OpCode);
9288
if (Loc != OpToFactoryMap.end())
9389
return Loc->second();

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1991,7 +1991,6 @@ class SPIRVCompositeConstruct : public SPIRVInstruction {
19911991
case OpTypeArray:
19921992
case OpTypeStruct:
19931993
case internal::OpTypeJointMatrixINTEL:
1994-
case internal::OpTypeJointMatrixINTELv2:
19951994
case OpTypeCooperativeMatrixKHR:
19961995
break;
19971996
default:
@@ -3517,10 +3516,10 @@ class SPIRVJointMatrixINTELInst : public SPIRVJointMatrixINTELInstBase {
35173516
SPIRV##x##INTEL;
35183517
_SPIRV_OP(JointMatrixLoad, true, 6, true)
35193518
_SPIRV_OP(JointMatrixStore, false, 5, true)
3520-
_SPIRV_OP(JointMatrixMad, true, 6, true)
3521-
_SPIRV_OP(JointMatrixSUMad, true, 6, true)
3522-
_SPIRV_OP(JointMatrixUSMad, true, 6, true)
3523-
_SPIRV_OP(JointMatrixUUMad, true, 6, true)
3519+
_SPIRV_OP(JointMatrixMad, true, 7)
3520+
_SPIRV_OP(JointMatrixSUMad, true, 7)
3521+
_SPIRV_OP(JointMatrixUSMad, true, 7)
3522+
_SPIRV_OP(JointMatrixUUMad, true, 7)
35243523
// TODO: move to SPIRVJointMatrixINTELWorkItemInst
35253524
_SPIRV_OP(JointMatrixWorkItemLength, true, 4)
35263525
#undef _SPIRV_OP

lib/SPIRV/libSPIRV/SPIRVOpCode.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,6 @@ inline bool isTypeOpCode(Op OpCode) {
230230
isSubgroupAvcINTELTypeOpCode(OpCode) || OC == OpTypeVmeImageINTEL ||
231231
isVCOpCode(OpCode) || OC == internal::OpTypeTokenINTEL ||
232232
OC == internal::OpTypeJointMatrixINTEL ||
233-
OC == internal::OpTypeJointMatrixINTELv2 ||
234233
OC == OpTypeCooperativeMatrixKHR;
235234
}
236235

lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ _SPIRV_OP_INTERNAL(ArithmeticFenceINTEL, internal::OpArithmeticFenceINTEL)
66
_SPIRV_OP_INTERNAL(ConvertFToBF16INTEL, internal::OpConvertFToBF16INTEL)
77
_SPIRV_OP_INTERNAL(ConvertBF16ToFINTEL, internal::OpConvertBF16ToFINTEL)
88
_SPIRV_OP_INTERNAL(TypeJointMatrixINTEL, internal::OpTypeJointMatrixINTEL)
9-
_SPIRV_OP_INTERNAL(TypeJointMatrixINTEL, internal::OpTypeJointMatrixINTEL)
109
_SPIRV_OP_INTERNAL(JointMatrixLoadINTEL, internal::OpJointMatrixLoadINTEL)
1110
_SPIRV_OP_INTERNAL(JointMatrixStoreINTEL, internal::OpJointMatrixStoreINTEL)
1211
_SPIRV_OP_INTERNAL(JointMatrixMadINTEL, internal::OpJointMatrixMadINTEL)

lib/SPIRV/libSPIRV/SPIRVType.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,7 @@ bool SPIRVType::isTypeStruct() const { return OpCode == OpTypeStruct; }
206206
bool SPIRVType::isTypeVector() const { return OpCode == OpTypeVector; }
207207

208208
bool SPIRVType::isTypeJointMatrixINTEL() const {
209-
return OpCode == internal::OpTypeJointMatrixINTEL ||
210-
OpCode == internal::OpTypeJointMatrixINTELv2;
209+
return OpCode == internal::OpTypeJointMatrixINTEL;
211210
}
212211

213212
bool SPIRVType::isTypeCooperativeMatrixKHR() const {
@@ -290,21 +289,14 @@ void SPIRVTypeForwardPointer::decode(std::istream &I) {
290289
Decoder >> PointerId >> SC;
291290
}
292291

293-
SPIRVTypeJointMatrixINTEL::SPIRVTypeJointMatrixINTEL(
294-
SPIRVModule *M, SPIRVId TheId, Op OC, SPIRVType *CompType,
295-
std::vector<SPIRVValue *> Args)
296-
: SPIRVType(M, FixedWC + Args.size(), OC, TheId), CompType(CompType),
297-
Args(std::move(Args)) {}
298-
299292
SPIRVTypeJointMatrixINTEL::SPIRVTypeJointMatrixINTEL(
300293
SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType,
301294
std::vector<SPIRVValue *> Args)
302-
: SPIRVType(M, FixedWC + Args.size(), internal::OpTypeJointMatrixINTEL,
303-
TheId),
304-
CompType(CompType), Args(std::move(Args)) {}
295+
: SPIRVType(M, FixedWC + Args.size(), OC, TheId), CompType(CompType),
296+
Args(Args) {}
305297

306298
SPIRVTypeJointMatrixINTEL::SPIRVTypeJointMatrixINTEL()
307-
: SPIRVType(internal::OpTypeJointMatrixINTEL), CompType(nullptr),
299+
: SPIRVType(OC), CompType(nullptr),
308300
Args({nullptr, nullptr, nullptr, nullptr}) {}
309301

310302
void SPIRVTypeJointMatrixINTEL::encode(spv_ostream &O) const {

lib/SPIRV/libSPIRV/SPIRVType.h

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,18 +1062,13 @@ class SPIRVTypeTokenINTEL : public SPIRVType {
10621062
};
10631063

10641064
class SPIRVTypeJointMatrixINTEL : public SPIRVType {
1065-
Op OC;
10661065
SPIRVType *CompType;
10671066
std::vector<SPIRVValue *> Args;
10681067

10691068
public:
1069+
const static Op OC = internal::OpTypeJointMatrixINTEL;
10701070
const static SPIRVWord FixedWC = 3;
1071-
// Complete constructor with non-default OC
1072-
SPIRVTypeJointMatrixINTEL(SPIRVModule *M, SPIRVId TheId, Op OC,
1073-
SPIRVType *CompType,
1074-
std::vector<SPIRVValue *> Args);
1075-
1076-
// Incomplete constructor for default OC
1071+
// Complete constructor
10771072
SPIRVTypeJointMatrixINTEL(SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType,
10781073
std::vector<SPIRVValue *> Args);
10791074
// Incomplete constructor
@@ -1092,29 +1087,11 @@ class SPIRVTypeJointMatrixINTEL : public SPIRVType {
10921087
SPIRVType *getCompType() const { return CompType; }
10931088
SPIRVValue *getRows() const { return Args[0]; }
10941089
SPIRVValue *getColumns() const { return Args[1]; }
1095-
1096-
SPIRVValue *getLayout() const {
1097-
if (this->getOpCode() == internal::OpTypeJointMatrixINTEL)
1098-
return Args[2];
1099-
return nullptr;
1100-
}
1101-
1102-
SPIRVValue *getScope() const {
1103-
if (this->getOpCode() == internal::OpTypeJointMatrixINTEL)
1104-
return Args[3];
1105-
return Args[2];
1106-
}
1107-
1108-
SPIRVValue *getUse() const {
1109-
if (this->getOpCode() == internal::OpTypeJointMatrixINTEL)
1110-
return Args.size() > 4 ? Args[4] : nullptr;
1111-
return Args[3];
1112-
}
1113-
1090+
SPIRVValue *getLayout() const { return Args[2]; }
1091+
SPIRVValue *getScope() const { return Args[3]; }
1092+
SPIRVValue *getUse() const { return Args.size() > 4 ? Args[4] : nullptr; }
11141093
SPIRVValue *getComponentTypeInterpretation() const {
1115-
if (this->getOpCode() == internal::OpTypeJointMatrixINTEL)
1116-
return Args.size() > 5 ? Args[5] : nullptr;
1117-
return Args.size() > 4 ? Args[4] : nullptr;
1094+
return Args.size() > 5 ? Args[5] : nullptr;
11181095
}
11191096

11201097
std::vector<SPIRVEntry *> getNonLiteralOperands() const override {

lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ enum InternalOp {
6868
IOpJointMatrixSUMadINTEL = 6128,
6969
IOpJointMatrixUSMadINTEL = 6129,
7070
IOpJointMatrixUUMadINTEL = 6130,
71-
IOpTypeJointMatrixINTELv2 = 6184,
7271
IOpArithmeticFenceINTEL = 6145,
7372
IOpCooperativeMatrixLoadCheckedINTEL = 6193,
7473
IOpCooperativeMatrixStoreCheckedINTEL = 6194,
@@ -180,7 +179,6 @@ _SPIRV_OP(Capability, JointMatrixBF16ComponentTypeINTEL)
180179
_SPIRV_OP(Capability, JointMatrixPackedInt2ComponentTypeINTEL)
181180
_SPIRV_OP(Capability, JointMatrixPackedInt4ComponentTypeINTEL)
182181
_SPIRV_OP(Op, TypeJointMatrixINTEL)
183-
_SPIRV_OP(Op, TypeJointMatrixINTELv2)
184182
_SPIRV_OP(Op, JointMatrixLoadINTEL)
185183
_SPIRV_OP(Op, JointMatrixStoreINTEL)
186184
_SPIRV_OP(Op, JointMatrixMadINTEL)

0 commit comments

Comments
 (0)