265 changes: 120 additions & 145 deletions llvm/lib/SYCLLowerIR/LowerESIMD.cpp
@@ -43,6 +43,8 @@ namespace id = itanium_demangle;

#define SLM_BTI 254

#define MAX_DIMS 3

namespace {
SmallPtrSet<Type *, 4> collectGenXVolatileTypes(Module &);
void generateKernelMetadata(Module &);
@@ -846,145 +848,131 @@ static Instruction *addCastInstIfNeeded(Instruction *OldI, Instruction *NewI) {
auto CastOpcode = CastInst::getCastOpcode(NewI, false, OITy, false);
NewI = CastInst::Create(CastOpcode, NewI, OITy,
NewI->getName() + ".cast.ty", OldI);
NewI->setDebugLoc(OldI->getDebugLoc());
}
return NewI;
}

static int getIndexForSuffix(StringRef Suff) {
return llvm::StringSwitch<int>(Suff)
.Case("x", 0)
.Case("y", 1)
.Case("z", 2)
.Default(-1);
/// Returns the index from the given extract element instruction \p EEI.
/// It is checked here that the index is either 0, 1, or 2.
static uint64_t getIndexFromExtract(ExtractElementInst *EEI) {
Value *IndexV = EEI->getIndexOperand();
uint64_t IndexValue = cast<ConstantInt>(IndexV)->getZExtValue();
assert(IndexValue < MAX_DIMS &&
"Extract element index should be either 0, 1, or 2");
return IndexValue;
}

// Helper function to convert extractelement instruction associated with the
// load from SPIRV builtin global, into the GenX intrinsic that returns vector
// of coordinates. It also generates required extractelement and cast
// instructions. Example:
// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast
// (<3 x i64> addrspace(1)* @__spirv_BuiltInLocalInvocationId
// to <3 x i64> addrspace(4)*), align 32
// %1 = extractelement <3 x i64> %0, i64 0
//
// =>
//
// %.esimd = call <3 x i32> @llvm.genx.local.id.v3i32()
// %local_id.x = extractelement <3 x i32> %.esimd, i32 0
// %local_id.x.cast.ty = zext i32 %local_id.x to i64
static Instruction *generateVectorGenXForSpirv(ExtractElementInst *EEI,
StringRef Suff,
const std::string &IntrinName,
StringRef ValueName) {
std::string IntrName =
std::string(GenXIntrinsic::getGenXIntrinsicPrefix()) + IntrinName;
auto ID = GenXIntrinsic::lookupGenXIntrinsicID(IntrName);
LLVMContext &Ctx = EEI->getModule()->getContext();
Type *I32Ty = Type::getInt32Ty(Ctx);
Function *NewFDecl = GenXIntrinsic::getGenXDeclaration(
EEI->getModule(), ID, {FixedVectorType::get(I32Ty, 3)});
Instruction *IntrI =
IntrinsicInst::Create(NewFDecl, {}, EEI->getName() + ".esimd", EEI);
int ExtractIndex = getIndexForSuffix(Suff);
assert(ExtractIndex != -1 && "Extract index is invalid.");
Twine ExtractName = ValueName + Suff;

Instruction *ExtrI = ExtractElementInst::Create(
IntrI, ConstantInt::get(I32Ty, ExtractIndex), ExtractName, EEI);
Instruction *CastI = addCastInstIfNeeded(EEI, ExtrI);
if (EEI->getDebugLoc()) {
IntrI->setDebugLoc(EEI->getDebugLoc());
ExtrI->setDebugLoc(EEI->getDebugLoc());
// It's OK if ExtrI and CastI are the same instruction
CastI->setDebugLoc(EEI->getDebugLoc());
/// Generates a call of the GenX intrinsic \p IntrinName and inserts it
/// right before the given extract element instruction \p EEI, which uses the
/// result of the vector load. The parameter \p IsVectorCall tells which
/// version of the GenX intrinsic (scalar or vector) to use to lower the load
/// from the SPIRV global.
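/// A sketch of the expected lowering (the value names below are illustrative;
/// the actual names are derived from the replaced instructions). Vector
/// variant:
///   %res = call <3 x i32> @llvm.genx.local.id.v3i32()
///   %res.ext.0 = extractelement <3 x i32> %res, i32 0
///   %res.ext.0.cast.ty = zext i32 %res.ext.0 to i64
/// Scalar variant:
///   %res = call i32 @llvm.genx.group.id.x()
///   %res.cast.ty = zext i32 %res to i64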
static Instruction *generateGenXCall(ExtractElementInst *EEI,
StringRef IntrinName, bool IsVectorCall) {
uint64_t IndexValue = getIndexFromExtract(EEI);
std::string Suffix =
IsVectorCall
? ".v3i32"
: (Twine(".") + Twine(static_cast<char>('x' + IndexValue))).str();
std::string FullIntrinName = (Twine(GenXIntrinsic::getGenXIntrinsicPrefix()) +
Twine(IntrinName) + Suffix)
.str();
auto ID = GenXIntrinsic::lookupGenXIntrinsicID(FullIntrinName);
Type *I32Ty = Type::getInt32Ty(EEI->getModule()->getContext());
Function *NewFDecl =
IsVectorCall
? GenXIntrinsic::getGenXDeclaration(
EEI->getModule(), ID, FixedVectorType::get(I32Ty, MAX_DIMS))
: GenXIntrinsic::getGenXDeclaration(EEI->getModule(), ID);

std::string ResultName =
(Twine(EEI->getNameOrAsOperand()) + "." + FullIntrinName).str();
Instruction *Inst = IntrinsicInst::Create(NewFDecl, {}, ResultName, EEI);
Inst->setDebugLoc(EEI->getDebugLoc());

if (IsVectorCall) {
Type *I32Ty = Type::getInt32Ty(EEI->getModule()->getContext());
std::string ExtractName =
(Twine(Inst->getNameOrAsOperand()) + ".ext." + Twine(IndexValue)).str();
Inst = ExtractElementInst::Create(Inst, ConstantInt::get(I32Ty, IndexValue),
ExtractName, EEI);
Inst->setDebugLoc(EEI->getDebugLoc());
}
return CastI;
Inst = addCastInstIfNeeded(EEI, Inst);
return Inst;
}

// Helper function to convert extractelement instruction associated with the
// load from SPIRV builtin global, into the GenX intrinsic. It also generates
// required cast instructions. Example:
// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64>
// addrspace(1)* @__spirv_BuiltInWorkgroupId to <3 x i64> addrspace(4)*), align
// 32 %1 = extractelement <3 x i64> %0, i64 0
// =>
// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64>
// addrspace(1)* @__spirv_BuiltInWorkgroupId to <3 x i64> addrspace(4)*), align
// 32 %group.id.x = call i32 @llvm.genx.group.id.x() %group.id.x.cast.ty = zext
// i32 %group.id.x to i64
static Instruction *generateGenXForSpirv(ExtractElementInst *EEI,
StringRef Suff,
const std::string &IntrinName) {
std::string IntrName = std::string(GenXIntrinsic::getGenXIntrinsicPrefix()) +
IntrinName + Suff.str();
auto ID = GenXIntrinsic::lookupGenXIntrinsicID(IntrName);
Function *NewFDecl =
GenXIntrinsic::getGenXDeclaration(EEI->getModule(), ID, {});

Instruction *IntrI =
IntrinsicInst::Create(NewFDecl, {}, IntrinName + Suff.str(), EEI);
Instruction *CastI = addCastInstIfNeeded(EEI, IntrI);
if (EEI->getDebugLoc()) {
IntrI->setDebugLoc(EEI->getDebugLoc());
// It's OK if IntrI and CastI are the same instruction
CastI->setDebugLoc(EEI->getDebugLoc());
/// Replaces the load \p LI of SPIRV global with corresponding call(s) of GenX
/// intrinsic(s). The users of \p LI may also be transformed if needed for
/// def/use type correctness.
/// The replaced instructions are stored into the given container
/// \p InstsToErase.
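/// For instance (an illustrative sketch), a load of <3 x i64> from
/// @__spirv_BuiltInWorkgroupSize whose user is "extractelement ... i64 0" is
/// rewritten into a call of @llvm.genx.local.size.v3i32 followed by the
/// extractelement/zext chain produced by generateGenXCall(); the original
/// extract and load are then queued in \p InstsToErase.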
static void
translateSpirvGlobalUses(LoadInst *LI, StringRef SpirvGlobalName,
SmallVectorImpl<Instruction *> &InstsToErase) {
// TODO: Implement support for the following intrinsics:
// uint32_t __spirv_BuiltIn NumSubgroups;
// uint32_t __spirv_BuiltIn SubgroupId;

// Translate those loads from _scalar_ SPIRV globals that can be replaced with
// a const value here.
// The loads from other scalar SPIRV globals may require insertion of GenX
// calls before each user, which is done in the loop over the users of 'LI'
// below.
Value *NewInst = nullptr;
if (SpirvGlobalName == "SubgroupLocalInvocationId") {
NewInst = llvm::Constant::getNullValue(LI->getType());
} else if (SpirvGlobalName == "SubgroupSize" ||
SpirvGlobalName == "SubgroupMaxSize") {
NewInst = llvm::Constant::getIntegerValue(LI->getType(),
llvm::APInt(32, 1, true));
}
if (NewInst) {
LI->replaceAllUsesWith(NewInst);
InstsToErase.push_back(LI);
return;
}
return CastI;
}

// This function translates one occurrence of a SPIRV builtin use into a GenX
// intrinsic.
static Value *translateSpirvGlobalUse(ExtractElementInst *EEI,
StringRef SpirvGlobalName) {
Value *IndexV = EEI->getIndexOperand();
assert(isa<ConstantInt>(IndexV) &&
"Extract element index should be a constant");
// Only loads from _vector_ SPIRV globals reach here now. Their users are
// expected to be ExtractElementInst only, and they are replaced in this loop.
// When loads from _scalar_ SPIRV globals are handled here as well, the users
// will not be replaced by new instructions, but the GenX call replacing the
// original load 'LI' should be inserted before each user.
for (User *LU : LI->users()) {
ExtractElementInst *EEI = cast<ExtractElementInst>(LU);
NewInst = nullptr;

if (SpirvGlobalName == "WorkgroupSize") {
NewInst = generateGenXCall(EEI, "local.size", true);
} else if (SpirvGlobalName == "LocalInvocationId") {
NewInst = generateGenXCall(EEI, "local.id", true);
} else if (SpirvGlobalName == "WorkgroupId") {
NewInst = generateGenXCall(EEI, "group.id", false);
} else if (SpirvGlobalName == "GlobalInvocationId") {
// GlobalId = LocalId + WorkGroupSize * GroupId
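// E.g. (illustrative numbers): LocalId = 3, WorkGroupSize = 16, GroupId = 2
// gives GlobalId = 3 + 16 * 2 = 35.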
Instruction *LocalIdI = generateGenXCall(EEI, "local.id", true);
Instruction *WGSizeI = generateGenXCall(EEI, "local.size", true);
Instruction *GroupIdI = generateGenXCall(EEI, "group.id", false);
Instruction *MulI =
BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", EEI);
NewInst = BinaryOperator::CreateAdd(LocalIdI, MulI, "add", EEI);
} else if (SpirvGlobalName == "GlobalSize") {
// GlobalSize = WorkGroupSize * NumWorkGroups
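// E.g. (illustrative numbers): WorkGroupSize = 16, NumWorkGroups = 8 gives
// GlobalSize = 16 * 8 = 128.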
Instruction *WGSizeI = generateGenXCall(EEI, "local.size", true);
Instruction *NumWGI = generateGenXCall(EEI, "group.count", true);
NewInst = BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", EEI);
} else if (SpirvGlobalName == "GlobalOffset") {
// TODO: Support GlobalOffset SPIRV intrinsics
// Currently all users of load of GlobalOffset are replaced with 0.
NewInst = llvm::Constant::getNullValue(EEI->getType());
} else if (SpirvGlobalName == "NumWorkgroups") {
NewInst = generateGenXCall(EEI, "group.count", true);
}

// Get the suffix based on the index of extractelement instruction
ConstantInt *IndexC = cast<ConstantInt>(IndexV);
std::string Suff;
if (IndexC->equalsInt(0))
Suff = 'x';
else if (IndexC->equalsInt(1))
Suff = 'y';
else if (IndexC->equalsInt(2))
Suff = 'z';
else
assert(false && "Extract element index should be either 0, 1, or 2");

// Translate SPIRV into GenX intrinsic.
if (SpirvGlobalName == "WorkgroupSize") {
return generateVectorGenXForSpirv(EEI, Suff, "local.size.v3i32", "wgsize.");
} else if (SpirvGlobalName == "LocalInvocationId") {
return generateVectorGenXForSpirv(EEI, Suff, "local.id.v3i32", "local_id.");
} else if (SpirvGlobalName == "WorkgroupId") {
return generateGenXForSpirv(EEI, Suff, "group.id.");
} else if (SpirvGlobalName == "GlobalInvocationId") {
// GlobalId = LocalId + WorkGroupSize * GroupId
Instruction *LocalIdI =
generateVectorGenXForSpirv(EEI, Suff, "local.id.v3i32", "local_id.");
Instruction *WGSizeI =
generateVectorGenXForSpirv(EEI, Suff, "local.size.v3i32", "wgsize.");
Instruction *GroupIdI = generateGenXForSpirv(EEI, Suff, "group.id.");
Instruction *MulI =
BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", EEI);
return BinaryOperator::CreateAdd(LocalIdI, MulI, "add", EEI);
} else if (SpirvGlobalName == "GlobalSize") {
// GlobalSize = WorkGroupSize * NumWorkGroups
Instruction *WGSizeI =
generateVectorGenXForSpirv(EEI, Suff, "local.size.v3i32", "wgsize.");
Instruction *NumWGI = generateVectorGenXForSpirv(
EEI, Suff, "group.count.v3i32", "group_count.");
return BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", EEI);
} else if (SpirvGlobalName == "GlobalOffset") {
// TODO: Support GlobalOffset SPIRV intrinsics
return llvm::Constant::getNullValue(EEI->getType());
} else if (SpirvGlobalName == "NumWorkgroups") {
return generateVectorGenXForSpirv(EEI, Suff, "group.count.v3i32",
"group_count.");
assert(NewInst && "Load from global SPIRV builtin was not translated");
EEI->replaceAllUsesWith(NewInst);
InstsToErase.push_back(EEI);
}

return nullptr;
InstsToErase.push_back(LI);
}

static void createESIMDIntrinsicArgs(const ESIMDIntrinDesc &Desc,
@@ -1370,8 +1358,7 @@ SmallPtrSet<Type *, 4> collectGenXVolatileTypes(Module &M) {

} // namespace

PreservedAnalyses SYCLLowerESIMDPass::run(Module &M,
ModuleAnalysisManager &) {
PreservedAnalyses SYCLLowerESIMDPass::run(Module &M, ModuleAnalysisManager &) {
generateKernelMetadata(M);
SmallPtrSet<Type *, 4> GVTS = collectGenXVolatileTypes(M);

@@ -1507,23 +1494,11 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,

auto PrefLen = StringRef(SPIRV_INTRIN_PREF).size();

// Go through all the uses of the load instruction from SPIRV builtin
// globals, which are required to be extractelement instructions.
// Translate each of them.
for (auto *LU : LI->users()) {
auto *EEI = dyn_cast<ExtractElementInst>(LU);
assert(EEI && "User of load from global SPIRV builtin is not an "
"extractelement instruction");
Value *TranslatedVal = translateSpirvGlobalUse(
EEI, SpirvGlobal->getName().drop_front(PrefLen));
assert(TranslatedVal &&
"Load from global SPIRV builtin was not translated");
EEI->replaceAllUsesWith(TranslatedVal);
ESIMDToErases.push_back(EEI);
}
// After all users of load were translated, we get rid of the load
// itself.
ESIMDToErases.push_back(LI);
// Translate all uses of the load instruction from the SPIRV builtin global.
// This replaces the original global load and its uses, and stores the old
// instructions in ESIMDToErases.
translateSpirvGlobalUses(LI, SpirvGlobal->getName().drop_front(PrefLen),
ESIMDToErases);
}
}
// Now demangle and translate found ESIMD intrinsic calls
30 changes: 30 additions & 0 deletions sycl/test/esimd/spirv_intrins_trans.cpp
@@ -18,12 +18,15 @@ size_t caller() {

size_t DoNotOpt;
cl::sycl::buffer<size_t, 1> buf(&DoNotOpt, 1);
uint32_t DoNotOpt32;
cl::sycl::buffer<uint32_t, 1> buf32(&DoNotOpt32, 1);

size_t DoNotOptXYZ[3];
cl::sycl::buffer<size_t, 1> bufXYZ(&DoNotOptXYZ[0], sycl::range<1>(3));

cl::sycl::queue().submit([&](cl::sycl::handler &cgh) {
auto DoNotOptimize = buf.get_access<cl::sycl::access::mode::write>(cgh);
auto DoNotOptimize32 = buf32.get_access<cl::sycl::access::mode::write>(cgh);

kernel<class kernel_GlobalInvocationId_x>([=]() SYCL_ESIMD_KERNEL {
*DoNotOptimize.get_pointer() = __spirv_GlobalInvocationId_x();
@@ -213,6 +216,33 @@ size_t caller() {
// CHECK: {{.*}} call i32 @llvm.genx.group.id.x()
// CHECK: {{.*}} call i32 @llvm.genx.group.id.y()
// CHECK: {{.*}} call i32 @llvm.genx.group.id.z()

kernel<class kernel_SubgroupLocalInvocationId>([=]() SYCL_ESIMD_KERNEL {
*DoNotOptimize.get_pointer() = __spirv_SubgroupLocalInvocationId();
*DoNotOptimize32.get_pointer() = __spirv_SubgroupLocalInvocationId() + 3;
});
// CHECK-LABEL: @{{.*}}kernel_SubgroupLocalInvocationId
// CHECK: [[ZEXT0:%.*]] = zext i32 0 to i64
// CHECK: store i64 [[ZEXT0]]
// CHECK: add i32 0, 3

kernel<class kernel_SubgroupSize>([=]() SYCL_ESIMD_KERNEL {
*DoNotOptimize.get_pointer() = __spirv_SubgroupSize();
*DoNotOptimize32.get_pointer() = __spirv_SubgroupSize() + 7;
});
// CHECK-LABEL: @{{.*}}kernel_SubgroupSize
// CHECK: [[ZEXT0:%.*]] = zext i32 1 to i64
// CHECK: store i64 [[ZEXT0]]
// CHECK: add i32 1, 7

kernel<class kernel_SubgroupMaxSize>([=]() SYCL_ESIMD_KERNEL {
*DoNotOptimize.get_pointer() = __spirv_SubgroupMaxSize();
*DoNotOptimize32.get_pointer() = __spirv_SubgroupMaxSize() + 9;
});
// CHECK-LABEL: @{{.*}}kernel_SubgroupMaxSize
// CHECK: [[ZEXT0:%.*]] = zext i32 1 to i64
// CHECK: store i64 [[ZEXT0]]
// CHECK: add i32 1, 9
});
return DoNotOpt;
}