From 9b0161c1edfcbb45809c32a935999d54b536cc90 Mon Sep 17 00:00:00 2001 From: "Maksimova, Viktoria" Date: Thu, 24 Oct 2024 05:59:21 -0700 Subject: [PATCH] Fix OpGroupAsyncCopy used with SPV_KHR_untyped_pointers --- lib/SPIRV/SPIRVWriter.cpp | 26 ++++++++++++++++++++++++-- test/transcoding/OpGroupAsyncCopy.ll | 9 +++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/lib/SPIRV/SPIRVWriter.cpp b/lib/SPIRV/SPIRVWriter.cpp index 489a090081..a0af967962 100644 --- a/lib/SPIRV/SPIRVWriter.cpp +++ b/lib/SPIRV/SPIRVWriter.cpp @@ -755,7 +755,8 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) { return transPointerType(ET, SPIRAS_Private); if (BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_untyped_pointers) && !(ET->isTypeArray() || ET->isTypeVector() || ET->isTypeStruct() || - ET->isTypeImage() || ET->isTypeSampler() || ET->isTypePipe())) { + ET->isTypeImage() || ET->isTypeSampler() || ET->isTypePipe() || + ET->isTypeEvent())) { TranslatedTy = BM->addUntypedPointerKHRType( SPIRSPIRVAddrSpaceMap::map(static_cast(AddrSpc))); } else { @@ -2486,7 +2487,8 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB, if (GetElementPtrInst *GEP = dyn_cast(V)) { auto *PointerOperand = GEP->getPointerOperand(); - std::vector Ops = {transValue(PointerOperand, BB)->getId()}; + auto *SPVPointerOperand = transValue(PointerOperand, BB); + std::vector Ops = {SPVPointerOperand->getId()}; for (unsigned I = 0, E = GEP->getNumIndices(); I != E; ++I) Ops.push_back(transValue(GEP->getOperand(I + 1), BB)->getId()); @@ -2536,7 +2538,15 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB, PtrTy = transType(TPT->getElementType()); Ops = getVec(PtrTy->getId(), Ops); } + } else if (SPVPointerOperand->getType()->isTypeUntypedPointerKHR()) { + // TranslatedTy defines whether ptr access chain instruction would be + // typed or untyped. + // Ensure that we don't create typed ptr access chain instruction with + // untyped pointer as an operand. + Ops[0] = BM->addUnaryInst(OpBitcast, TranslatedTy, SPVPointerOperand, BB) + ->getId(); } + return mapValue( V, BM->addPtrAccessChainInst(TranslatedTy, Ops, BB, GEP->isInBounds())); } @@ -6363,6 +6373,18 @@ LLVMToSPIRVBase::transBuiltinToInstWithoutDecoration(Op OC, CallInst *CI, } break; case OpGroupAsyncCopy: { auto BArgs = transValue(getArguments(CI), BB); + // Check that type of Source and Destination are the same, bitcast if not + // (this mostly needed for untyped pointers) + SPIRVType *SrcTy = BArgs[1]->getType(); + SPIRVType *DstTy = BArgs[2]->getType(); + if (SrcTy->getPointerElementType() != DstTy->getPointerElementType()) { + llvm::Type *Ty = Scavenger->getScavengedType(CI->getOperand(2)); + if (auto *TPT = dyn_cast(Ty)) { + SPIRVType *PtrTy = + transPointerType(TPT->getElementType(), TPT->getAddressSpace()); + BArgs[2] = BM->addUnaryInst(OpBitcast, PtrTy, BArgs[2], BB); + } + } return BM->addAsyncGroupCopy(BArgs[0], BArgs[1], BArgs[2], BArgs[3], BArgs[4], BArgs[5], BB); } break; diff --git a/test/transcoding/OpGroupAsyncCopy.ll b/test/transcoding/OpGroupAsyncCopy.ll index cbe858a6d6..66c9384427 100644 --- a/test/transcoding/OpGroupAsyncCopy.ll +++ b/test/transcoding/OpGroupAsyncCopy.ll @@ -8,6 +8,15 @@ ; RUN: llvm-spirv -r --spirv-target-env=SPV-IR %t.spv -o %t.rev.bc ; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-SPV-IR +; RUN: llvm-spirv %t.bc -spirv-text -o %t.txt --spirv-ext=+SPV_KHR_untyped_pointers +; RUN: FileCheck < %t.txt %s --check-prefix=CHECK-SPIRV +; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_KHR_untyped_pointers +; RUN: spirv-val %t.spv +; RUN: llvm-spirv -r %t.spv -o %t.rev.bc +; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM +; RUN: llvm-spirv -r --spirv-target-env=SPV-IR %t.spv -o %t.rev.bc +; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-SPV-IR + ; CHECK-LLVM: call spir_func ptr @_Z29async_work_group_strided_copyPU3AS1Dv2_hPU3AS3KS_jj9ocl_event( ; CHECK-LLVM: call spir_func void @_Z17wait_group_eventsiPU3AS49ocl_event( ; CHECK-LLVM: declare spir_func ptr @_Z29async_work_group_strided_copyPU3AS1Dv2_hPU3AS3KS_jj9ocl_event(ptr addrspace(1), ptr addrspace(3), i32, i32, ptr)