Skip to content

Commit fc9896b

Browse files
authored
Implement SPV_INTEL_task_sequence extension (#2340)
Spec: KhronosGroup/SPIRV-Registry#192
1 parent 43acfef commit fc9896b

19 files changed

+527
-8
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,4 @@ EXT(SPV_INTEL_fpga_latency_control)
7070
EXT(SPV_INTEL_fp_max_error)
7171
EXT(SPV_INTEL_cache_controls)
7272
EXT(SPV_INTEL_subgroup_requirements)
73+
EXT(SPV_INTEL_task_sequence)

lib/SPIRV/OCLUtil.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,7 @@ SPIRAddressSpace getOCLOpaqueTypeAddrSpace(Op OpCode) {
908908
case internal::OpTypeJointMatrixINTEL:
909909
case internal::OpTypeJointMatrixINTELv2:
910910
case OpTypeCooperativeMatrixKHR:
911+
case internal::OpTypeTaskSequenceINTEL:
911912
return SPIRAS_Global;
912913
default:
913914
if (isSubgroupAvcINTELTypeOpCode(OpCode))

lib/SPIRV/SPIRVInternal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -938,6 +938,7 @@ template <> inline void SPIRVMap<std::string, Op, SPIRVOpaqueType>::init() {
938938
_SPIRV_OP(CooperativeMatrixKHR)
939939
#undef _SPIRV_OP
940940
add("JointMatrixINTEL", internal::OpTypeJointMatrixINTEL);
941+
add("TaskSequenceINTEL", internal::OpTypeTaskSequenceINTEL);
941942
}
942943

943944
// Check if the module contains llvm.loop.* metadata

lib/SPIRV/SPIRVReader.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,9 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool UseTPT) {
491491
return mapType(T, transType(static_cast<SPIRVType *>(
492492
BM->getEntry(FP->getPointerId()))));
493493
}
494+
case internal::OpTypeTaskSequenceINTEL:
495+
return mapType(
496+
T, llvm::TargetExtType::get(*Context, "spirv.TaskSequenceINTEL"));
494497

495498
default: {
496499
auto OC = T->getOpCode();
@@ -2291,6 +2294,7 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
22912294
}
22922295
case internal::OpTypeJointMatrixINTEL:
22932296
case OpTypeCooperativeMatrixKHR:
2297+
case internal::OpTypeTaskSequenceINTEL:
22942298
return mapValue(BV, transSPIRVBuiltinFromInst(CC, BB));
22952299
default:
22962300
llvm_unreachable("Unhandled type!");
@@ -3392,6 +3396,7 @@ Instruction *SPIRVToLLVM::transSPIRVBuiltinFromInst(SPIRVInstruction *BI,
33923396
case OpSUDotAccSatKHR:
33933397
case internal::OpJointMatrixLoadINTEL:
33943398
case OpCooperativeMatrixLoadKHR:
3399+
case internal::OpTaskSequenceCreateINTEL:
33953400
AddRetTypePostfix = true;
33963401
break;
33973402
default: {

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,7 +1252,8 @@ SPIR::TypePrimitiveEnum getOCLTypePrimitiveEnum(StringRef TyName) {
12521252
/// \param Signed indicates integer type should be translated as signed.
12531253
/// \param VoidPtr indicates i8* should be translated as void*.
12541254
static SPIR::RefParamType transTypeDesc(Type *Ty,
1255-
const BuiltinArgTypeMangleInfo &Info) {
1255+
const BuiltinArgTypeMangleInfo &Info,
1256+
StringRef InstName = "") {
12561257
bool Signed = Info.IsSigned;
12571258
unsigned Attr = Info.Attr;
12581259
bool VoidPtr = Info.IsVoidPtr;
@@ -1361,8 +1362,14 @@ static SPIR::RefParamType transTypeDesc(Type *Ty,
13611362
auto *ET = TPT->getElementType();
13621363
SPIR::ParamType *EPT = nullptr;
13631364
if (isa<FunctionType>(ET)) {
1364-
assert(isVoidFuncTy(cast<FunctionType>(ET)) && "Not supported");
1365-
EPT = new SPIR::BlockType;
1365+
FunctionType *FT = cast<FunctionType>(ET);
1366+
if (InstName.consume_front(kSPIRVName::Prefix) &&
1367+
InstName.starts_with("TaskSequence")) {
1368+
EPT = new SPIR::PointerType(transTypeDesc(FT->getReturnType(), Info));
1369+
} else {
1370+
assert((isVoidFuncTy(FT)) && "Not supported");
1371+
EPT = new SPIR::BlockType;
1372+
}
13661373
} else if (auto *StructTy = dyn_cast<StructType>(ET)) {
13671374
LLVM_DEBUG(dbgs() << "ptr to struct: " << *Ty << '\n');
13681375
auto TyName = StructTy->getStructName();
@@ -1690,7 +1697,7 @@ std::string mangleBuiltin(StringRef UniqName, ArrayRef<Type *> ArgTypes,
16901697
T = MangleInfo.PointerTy;
16911698
}
16921699
FD.Parameters.emplace_back(
1693-
transTypeDesc(T, BtnInfo->getTypeMangleInfo(I)));
1700+
transTypeDesc(T, BtnInfo->getTypeMangleInfo(I), UniqName));
16941701
}
16951702
}
16961703
// Ellipsis must be the last argument of any function

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,8 @@ SPIRVType *LLVMToSPIRVBase::transType(Type *T) {
617617
Args.emplace_back(transConstant(getUInt32(M, Op)));
618618
return mapType(T, BM->addCooperativeMatrixKHRType(ElemTy, Args));
619619
}
620+
case internal::OpTypeTaskSequenceINTEL:
621+
return mapType(T, BM->addTaskSequenceINTELType());
620622
default:
621623
if (isSubgroupAvcINTELTypeOpCode(Opcode))
622624
return mapType(T, BM->addSubgroupAvcINTELType(Opcode));

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,8 @@ class SPIRVInstTemplateBase : public SPIRVInstruction {
226226
virtual void init() {}
227227
virtual void initImpl(Op OC, bool HasId = true, SPIRVWord WC = 0,
228228
bool VariWC = false, unsigned Lit1 = ~0U,
229-
unsigned Lit2 = ~0U, unsigned Lit3 = ~0U) {
229+
unsigned Lit2 = ~0U, unsigned Lit3 = ~0U,
230+
unsigned Lit4 = ~0U) {
230231
OpCode = OC;
231232
if (!HasId) {
232233
setHasNoId();
@@ -238,6 +239,7 @@ class SPIRVInstTemplateBase : public SPIRVInstruction {
238239
addLit(Lit1);
239240
addLit(Lit2);
240241
addLit(Lit3);
242+
addLit(Lit4);
241243
}
242244
bool isOperandLiteral(unsigned I) const override { return Lit.count(I); }
243245
void addLit(unsigned L) {
@@ -364,14 +366,16 @@ class SPIRVInstTemplateBase : public SPIRVInstruction {
364366

365367
template <typename BT = SPIRVInstTemplateBase, Op OC = OpNop, bool HasId = true,
366368
SPIRVWord WC = 0, bool HasVariableWC = false, unsigned Literal1 = ~0U,
367-
unsigned Literal2 = ~0U, unsigned Literal3 = ~0U>
369+
unsigned Literal2 = ~0U, unsigned Literal3 = ~0U,
370+
unsigned Literal4 = ~0U>
368371
class SPIRVInstTemplate : public BT {
369372
public:
370373
typedef BT BaseTy;
371374
SPIRVInstTemplate() { init(); }
372375
~SPIRVInstTemplate() override {}
373376
void init() override {
374-
this->initImpl(OC, HasId, WC, HasVariableWC, Literal1, Literal2, Literal3);
377+
this->initImpl(OC, HasId, WC, HasVariableWC, Literal1, Literal2, Literal3,
378+
Literal4);
375379
}
376380
};
377381

@@ -3854,5 +3858,85 @@ template <Op OC> class SPIRVReadClockKHRInstBase : public SPIRVUnaryInst<OC> {
38543858
_SPIRV_OP(ReadClockKHR)
38553859
#undef _SPIRV_OP
38563860

3861+
class SPIRVTaskSequenceINTELInstBase : public SPIRVInstTemplateBase {
3862+
public:
3863+
std::optional<ExtensionID> getRequiredExtension() const override {
3864+
return ExtensionID::SPV_INTEL_task_sequence;
3865+
}
3866+
};
3867+
3868+
class SPIRVTaskSequenceINTELInst : public SPIRVTaskSequenceINTELInstBase {
3869+
public:
3870+
SPIRVCapVec getRequiredCapability() const override {
3871+
return getVec(internal::CapabilityTaskSequenceINTEL);
3872+
}
3873+
};
3874+
3875+
class SPIRVTaskSequenceCreateINTELInst : public SPIRVTaskSequenceINTELInst {
3876+
protected:
3877+
void validate() const override {
3878+
SPIRVInstruction::validate();
3879+
std::string InstName = "TaskSequenceCreateINTEL";
3880+
SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();
3881+
3882+
SPIRVType *ResTy = this->getType();
3883+
SPVErrLog.checkError(
3884+
ResTy->isTypeTaskSequenceINTEL(), SPIRVEC_InvalidInstruction,
3885+
InstName + "\nResult must be TaskSequenceINTEL type\n");
3886+
3887+
SPIRVValue *Func =
3888+
const_cast<SPIRVTaskSequenceCreateINTELInst *>(this)->getOperand(0);
3889+
SPVErrLog.checkError(
3890+
Func->getOpCode() == OpFunction, SPIRVEC_InvalidInstruction,
3891+
InstName + "\nFirst argument is expected to be a function.\n");
3892+
3893+
SPIRVConstant *PipelinedConst = static_cast<SPIRVConstant *>(
3894+
const_cast<SPIRVTaskSequenceCreateINTELInst *>(this)->getOperand(1));
3895+
const int Pipelined = PipelinedConst->getZExtIntValue();
3896+
SPVErrLog.checkError(Pipelined >= -1, SPIRVEC_InvalidInstruction,
3897+
InstName + "\nPipeline must be a 32 bit integer with "
3898+
"the value bigger or equal to -1.\n");
3899+
3900+
const int ClusterMode =
3901+
static_cast<SPIRVConstant *>(
3902+
const_cast<SPIRVTaskSequenceCreateINTELInst *>(this)->getOperand(2))
3903+
->getZExtIntValue();
3904+
SPVErrLog.checkError(
3905+
ClusterMode >= -1 && ClusterMode <= 1, SPIRVEC_InvalidInstruction,
3906+
InstName + "\nClusterMode valid values are -1, 0, 1.\n");
3907+
3908+
const uint32_t GetCapacity =
3909+
static_cast<SPIRVConstant *>(
3910+
const_cast<SPIRVTaskSequenceCreateINTELInst *>(this)->getOperand(3))
3911+
->getZExtIntValue();
3912+
SPVErrLog.checkError(
3913+
GetCapacity, SPIRVEC_InvalidInstruction,
3914+
InstName + "\nGetCapacity must be unsigned 32 bit integer.\n");
3915+
3916+
const uint32_t AsyncCapacity =
3917+
static_cast<SPIRVConstant *>(
3918+
const_cast<SPIRVTaskSequenceCreateINTELInst *>(this)->getOperand(4))
3919+
->getZExtIntValue();
3920+
SPVErrLog.checkError(
3921+
AsyncCapacity, SPIRVEC_InvalidInstruction,
3922+
InstName + "\nAsyncCapacity must be unsigned 32 bit integer.\n");
3923+
}
3924+
};
3925+
3926+
#define _SPIRV_OP(x, ...) \
3927+
typedef SPIRVInstTemplate<SPIRVTaskSequenceINTELInst, \
3928+
internal::Op##x##INTEL, __VA_ARGS__> \
3929+
SPIRV##x##INTEL;
3930+
_SPIRV_OP(TaskSequenceAsync, false, 2, true)
3931+
_SPIRV_OP(TaskSequenceGet, true, 4, false)
3932+
_SPIRV_OP(TaskSequenceRelease, false, 2, false)
3933+
#undef _SPIRV_OP
3934+
#define _SPIRV_OP(x, ...) \
3935+
typedef SPIRVInstTemplate<SPIRVTaskSequenceCreateINTELInst, \
3936+
internal::Op##x##INTEL, __VA_ARGS__> \
3937+
SPIRV##x##INTEL;
3938+
_SPIRV_OP(TaskSequenceCreate, true, 8, false, 1, 2, 3, 4)
3939+
#undef _SPIRV_OP
3940+
38573941
} // namespace SPIRV
38583942
#endif // SPIRV_LIBSPIRV_SPIRVINSTRUCTION_H

lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ class SPIRVModuleImpl : public SPIRVModule {
247247
addJointMatrixINTELType(SPIRVType *, std::vector<SPIRVValue *>) override;
248248
SPIRVTypeCooperativeMatrixKHR *
249249
addCooperativeMatrixKHRType(SPIRVType *, std::vector<SPIRVValue *>) override;
250+
SPIRVTypeTaskSequenceINTEL *addTaskSequenceINTELType() override;
250251
SPIRVType *addOpaqueGenericType(Op) override;
251252
SPIRVTypeDeviceEvent *addDeviceEventType() override;
252253
SPIRVTypeQueue *addQueueType() override;
@@ -1021,6 +1022,10 @@ SPIRVModuleImpl::addCooperativeMatrixKHRType(SPIRVType *CompType,
10211022
new SPIRVTypeCooperativeMatrixKHR(this, getId(), CompType, Args));
10221023
}
10231024

1025+
SPIRVTypeTaskSequenceINTEL *SPIRVModuleImpl::addTaskSequenceINTELType() {
1026+
return addType(new SPIRVTypeTaskSequenceINTEL(this, getId()));
1027+
}
1028+
10241029
SPIRVType *SPIRVModuleImpl::addOpaqueGenericType(Op TheOpCode) {
10251030
return addType(new SPIRVTypeOpaqueGeneric(TheOpCode, this, getId()));
10261031
}

lib/SPIRV/libSPIRV/SPIRVModule.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ class SPIRVTypeBufferSurfaceINTEL;
9797
class SPIRVTypeTokenINTEL;
9898
class SPIRVTypeJointMatrixINTEL;
9999
class SPIRVTypeCooperativeMatrixKHR;
100+
class SPIRVTypeTaskSequenceINTEL;
100101

101102
typedef SPIRVBasicBlock SPIRVLabel;
102103
struct SPIRVTypeImageDescriptor;
@@ -264,6 +265,7 @@ class SPIRVModule {
264265
addJointMatrixINTELType(SPIRVType *, std::vector<SPIRVValue *>) = 0;
265266
virtual SPIRVTypeCooperativeMatrixKHR *
266267
addCooperativeMatrixKHRType(SPIRVType *, std::vector<SPIRVValue *>) = 0;
268+
virtual SPIRVTypeTaskSequenceINTEL *addTaskSequenceINTELType() = 0;
267269
virtual SPIRVTypeVoid *addVoidType() = 0;
268270
virtual SPIRVType *addOpaqueGenericType(Op) = 0;
269271
virtual SPIRVTypeDeviceEvent *addDeviceEventType() = 0;

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
670670
"CooperativeMatrixCheckedInstructionsINTEL");
671671
add(internal::CapabilitySubgroupRequirementsINTEL,
672672
"SubgroupRequirementsINTEL");
673+
add(internal::CapabilityTaskSequenceINTEL, "TaskSequenceINTEL");
673674
}
674675
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)
675676

0 commit comments

Comments
 (0)