Skip to content

Conversation

YixingZhang007
Copy link
Contributor

@YixingZhang007 YixingZhang007 commented Sep 29, 2025

No description provided.

@YixingZhang007 YixingZhang007 marked this pull request as draft September 29, 2025 20:34
@llvmbot
Copy link
Member

llvmbot commented Sep 29, 2025

@llvm/pr-subscribers-backend-spir-v

Author: None (YixingZhang007)

Changes

spirv-backend


Full diff: https://github.com/llvm/llvm-project/pull/161270.diff

6 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp (+12-6)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+9-8)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+2-2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+19-20)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+6-5)
  • (modified) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll (+4)
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
index 776208bd3e693..dff9f699ebd6f 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
@@ -50,18 +50,24 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
   unsigned IsBitwidth16 = MI->getFlags() & SPIRV::INST_PRINTER_WIDTH16;
   const unsigned NumVarOps = MI->getNumOperands() - StartIndex;
 
-  assert((NumVarOps == 1 || NumVarOps == 2) &&
+  // we support integer up to 1024 bits
+  assert((NumVarOps <= 1024) &&
          "Unsupported number of bits for literal variable");
 
   O << ' ';
 
-  uint64_t Imm = MI->getOperand(StartIndex).getImm();
-
-  // Handle 64 bit literals.
-  if (NumVarOps == 2) {
-    Imm |= (MI->getOperand(StartIndex + 1).getImm() << 32);
+  // Handle arbitrary number of 32-bit words for the literal value.
+  if (MI->getOpcode() == SPIRV::OpConstantI){
+    APInt Val(NumVarOps * 32, 0);
+    for (unsigned i = 0; i < NumVarOps; ++i) {
+      Val |= (APInt(NumVarOps * 32, MI->getOperand(StartIndex + i).getImm()) << (i * 32));
+    }
+    O << Val;
+    return;
   }
 
+  uint64_t Imm = MI->getOperand(StartIndex).getImm();
+
   // Format and print float values.
   if (MI->getOpcode() == SPIRV::OpConstantF && IsBitwidth16 == 0) {
     APFloat FP = NumVarOps == 1 ? APFloat(APInt(32, Imm).bitsToFloat())
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 115766ce886c7..05b3371e97cdc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -149,7 +149,7 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
 }
 
 unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
-  if (Width > 64)
+  if (Width > 1024)
     report_fatal_error("Unsupported integer width!");
   const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
   if (ST.canUseExtension(
@@ -343,7 +343,7 @@ Register SPIRVGlobalRegistry::createConstFP(const ConstantFP *CF,
   return Res;
 }
 
-Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
+Register SPIRVGlobalRegistry::getOrCreateConstInt(APInt Val, MachineInstr &I,
                                                   SPIRVType *SpvType,
                                                   const SPIRVInstrInfo &TII,
                                                   bool ZeroAsNull) {
@@ -353,10 +353,11 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
   if (MI && (MI->getOpcode() == SPIRV::OpConstantNull ||
              MI->getOpcode() == SPIRV::OpConstantI))
     return MI->getOperand(0).getReg();
-  return createConstInt(CI, I, SpvType, TII, ZeroAsNull);
+  return createConstInt(CI, Val, I, SpvType, TII, ZeroAsNull);
 }
 
-Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
+Register SPIRVGlobalRegistry::createConstInt(const Constant *CI,
+                                             APInt Val,
                                              MachineInstr &I,
                                              SPIRVType *SpvType,
                                              const SPIRVInstrInfo &TII,
@@ -374,15 +375,15 @@ Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
         MachineInstrBuilder MIB;
         if (BitWidth == 1) {
           MIB = MIRBuilder
-                    .buildInstr(CI->isZero() ? SPIRV::OpConstantFalse
+                    .buildInstr(Val.isZero() ? SPIRV::OpConstantFalse
                                              : SPIRV::OpConstantTrue)
                     .addDef(Res)
                     .addUse(getSPIRVTypeID(SpvType));
-        } else if (!CI->isZero() || !ZeroAsNull) {
+        } else if (!Val.isZero() || !ZeroAsNull) {
           MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
                     .addDef(Res)
                     .addUse(getSPIRVTypeID(SpvType));
-          addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB);
+          addNumImm(Val, MIB);
         } else {
           MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
                     .addDef(Res)
@@ -491,7 +492,7 @@ Register SPIRVGlobalRegistry::getOrCreateBaseRegister(
   }
   assert(Type->getOpcode() == SPIRV::OpTypeInt);
   SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
-  return getOrCreateConstInt(Val->getUniqueInteger().getZExtValue(), I,
+  return getOrCreateConstInt(APInt(BitWidth, Val->getUniqueInteger().getZExtValue()), I,
                              SpvBaseType, TII, ZeroAsNull);
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index a648defa0a888..ee217f81fb416 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -515,10 +515,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
                             SPIRVType *SpvType, bool EmitIR,
                             bool ZeroAsNull = true);
-  Register getOrCreateConstInt(uint64_t Val, MachineInstr &I,
+  Register getOrCreateConstInt(APInt Val, MachineInstr &I,
                                SPIRVType *SpvType, const SPIRVInstrInfo &TII,
                                bool ZeroAsNull = true);
-  Register createConstInt(const ConstantInt *CI, MachineInstr &I,
+  Register createConstInt(const Constant *CI, APInt Val, MachineInstr &I,
                           SPIRVType *SpvType, const SPIRVInstrInfo &TII,
                           bool ZeroAsNull);
   Register getOrCreateConstFP(APFloat Val, MachineInstr &I, SPIRVType *SpvType,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 1aadd9df189a8..3e5566945ec0b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2252,8 +2252,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
             .addDef(AElt)
             .addUse(GR.getSPIRVTypeID(ResType))
             .addUse(X)
-            .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull))
-            .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
+            .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
+            .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
             .constrainAllUses(TII, TRI, RBI);
 
     // B[i]
@@ -2263,8 +2263,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
             .addDef(BElt)
             .addUse(GR.getSPIRVTypeID(ResType))
             .addUse(Y)
-            .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull))
-            .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
+            .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
+            .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
             .constrainAllUses(TII, TRI, RBI);
 
     // A[i] * B[i]
@@ -2283,8 +2283,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
             .addDef(MaskMul)
             .addUse(GR.getSPIRVTypeID(ResType))
             .addUse(Mul)
-            .addUse(GR.getOrCreateConstInt(0, I, EltType, TII, ZeroAsNull))
-            .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
+            .addUse(GR.getOrCreateConstInt(APInt(8, 0), I, EltType, TII, ZeroAsNull))
+            .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
             .constrainAllUses(TII, TRI, RBI);
 
     // Acc = Acc + A[i] * B[i]
@@ -2381,7 +2381,7 @@ bool SPIRVInstructionSelector::selectWaveOpInst(Register ResVReg,
   auto BMI = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
                  .addDef(ResVReg)
                  .addUse(GR.getSPIRVTypeID(ResType))
-                 .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I,
+                 .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
                                                 IntTy, TII, !STI.isShader()));
 
   for (unsigned J = 2; J < I.getNumOperands(); J++) {
@@ -2405,7 +2405,7 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
                     TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
                 .addDef(ResVReg)
                 .addUse(GR.getSPIRVTypeID(ResType))
-                .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy,
+                .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy,
                                                TII, !STI.isShader()))
                 .addImm(SPIRV::GroupOperation::Reduce)
                 .addUse(BallotReg)
@@ -2436,7 +2436,7 @@ bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg,
   return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
       .addDef(ResVReg)
       .addUse(GR.getSPIRVTypeID(ResType))
-      .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII,
+      .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
                                      !STI.isShader()))
       .addImm(SPIRV::GroupOperation::Reduce)
       .addUse(I.getOperand(2).getReg())
@@ -2463,7 +2463,7 @@ bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg,
   return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
       .addDef(ResVReg)
       .addUse(GR.getSPIRVTypeID(ResType))
-      .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII,
+      .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
                                      !STI.isShader()))
       .addImm(SPIRV::GroupOperation::Reduce)
       .addUse(I.getOperand(2).getReg());
@@ -2689,7 +2689,7 @@ Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType,
   bool ZeroAsNull = !STI.isShader();
   if (ResType->getOpcode() == SPIRV::OpTypeVector)
     return GR.getOrCreateConstVector(0UL, I, ResType, TII, ZeroAsNull);
-  return GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
+  return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
 }
 
 Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType,
@@ -2720,7 +2720,7 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
       AllOnes ? APInt::getAllOnes(BitWidth) : APInt::getOneBitSet(BitWidth, 0);
   if (ResType->getOpcode() == SPIRV::OpTypeVector)
     return GR.getOrCreateConstVector(One.getZExtValue(), I, ResType, TII);
-  return GR.getOrCreateConstInt(One.getZExtValue(), I, ResType, TII);
+  return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), One.getZExtValue()), I, ResType, TII);
 }
 
 bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
@@ -2939,8 +2939,7 @@ bool SPIRVInstructionSelector::selectConst(Register ResVReg,
     Reg = GR.getOrCreateConstFP(I.getOperand(1).getFPImm()->getValue(), I,
                                 ResType, TII, !STI.isShader());
   } else {
-    Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getZExtValue(), I,
-                                 ResType, TII, !STI.isShader());
+    Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I, ResType, TII, !STI.isShader());
   }
   return Reg == ResVReg ? true : BuildCOPY(ResVReg, Reg, I);
 }
@@ -3765,7 +3764,7 @@ bool SPIRVInstructionSelector::selectFirstBitSet64Overflow(
     bool ZeroAsNull = !STI.isShader();
     Register FinalElemReg = MRI->createVirtualRegister(GR.getRegClass(I64Type));
     Register ConstIntLastIdx = GR.getOrCreateConstInt(
-        ComponentCount - 1, I, BaseType, TII, ZeroAsNull);
+        APInt(GR.getScalarOrVectorBitWidth(BaseType), ComponentCount - 1), I, BaseType, TII, ZeroAsNull);
 
     if (!selectOpWithSrcs(FinalElemReg, I64Type, I, {SrcReg, ConstIntLastIdx},
                           SPIRV::OpVectorExtractDynamic))
@@ -3794,9 +3793,9 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
   SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType);
   bool ZeroAsNull = !STI.isShader();
   Register ConstIntZero =
-      GR.getOrCreateConstInt(0, I, BaseType, TII, ZeroAsNull);
+      GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 0), I, BaseType, TII, ZeroAsNull);
   Register ConstIntOne =
-      GR.getOrCreateConstInt(1, I, BaseType, TII, ZeroAsNull);
+      GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 1), I, BaseType, TII, ZeroAsNull);
 
   // SPIRV doesn't support vectors with more than 4 components. Since the
   // algoritm below converts i64 -> i32x2 and i64x4 -> i32x8 it can only
@@ -3881,9 +3880,9 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
 
   if (IsScalarRes) {
     NegOneReg =
-        GR.getOrCreateConstInt((unsigned)-1, I, ResType, TII, ZeroAsNull);
-    Reg0 = GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
-    Reg32 = GR.getOrCreateConstInt(32, I, ResType, TII, ZeroAsNull);
+        GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), (unsigned)-1), I, ResType, TII, ZeroAsNull);
+    Reg0 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
+    Reg32 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 32), I, ResType, TII, ZeroAsNull);
     SelectOp = SPIRV::OpSelectSISCond;
     AddOp = SPIRV::OpIAddS;
   } else {
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 820e56b362edc..e409234a83568 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -100,11 +100,12 @@ void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
     if (Bitwidth == 16)
       MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16);
     return;
-  } else if (Bitwidth <= 64) {
-    uint64_t FullImm = Imm.getZExtValue();
-    uint32_t LowBits = FullImm & 0xffffffff;
-    uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
-    MIB.addImm(LowBits).addImm(HighBits);
+  } else if (Bitwidth <= 1024) {
+    unsigned NumWords = (Bitwidth + 31) / 32;
+    for (unsigned i = 0; i < NumWords; ++i) {
+      uint32_t Word = Imm.extractBits(32, i * 32).getZExtValue();
+      MIB.addImm(Word);
+    }
     return;
   }
   report_fatal_error("Unsupported constant bitwidth");
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
index 41d4b58ed1157..17ba9b044842c 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
@@ -8,6 +8,10 @@ define i13 @getConstantI13() {
   ret i13 42
 }
 
+define i96 @getConstantI96() {
+  ret i96 18446744073709551620
+} 
+
 ;; Capabilities:
 ; CHECK-DAG: OpExtension "SPV_INTEL_arbitrary_precision_integers"
 ; CHECK-DAG: OpCapability ArbitraryPrecisionIntegersINTEL

Copy link

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff origin/main HEAD --extensions cpp,h -- llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp llvm/lib/Target/SPIRV/SPIRVUtils.cpp

⚠️
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing origin/main to the base branch/commit you want to compare against.
⚠️

View the diff from clang-format here.
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
index dff9f699e..f12e9db6a 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
@@ -57,10 +57,11 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
   O << ' ';
 
   // Handle arbitrary number of 32-bit words for the literal value.
-  if (MI->getOpcode() == SPIRV::OpConstantI){
+  if (MI->getOpcode() == SPIRV::OpConstantI) {
     APInt Val(NumVarOps * 32, 0);
     for (unsigned i = 0; i < NumVarOps; ++i) {
-      Val |= (APInt(NumVarOps * 32, MI->getOperand(StartIndex + i).getImm()) << (i * 32));
+      Val |= (APInt(NumVarOps * 32, MI->getOperand(StartIndex + i).getImm())
+              << (i * 32));
     }
     O << Val;
     return;
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 05b3371e9..c4f8565f3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -356,8 +356,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(APInt Val, MachineInstr &I,
   return createConstInt(CI, Val, I, SpvType, TII, ZeroAsNull);
 }
 
-Register SPIRVGlobalRegistry::createConstInt(const Constant *CI,
-                                             APInt Val,
+Register SPIRVGlobalRegistry::createConstInt(const Constant *CI, APInt Val,
                                              MachineInstr &I,
                                              SPIRVType *SpvType,
                                              const SPIRVInstrInfo &TII,
@@ -492,8 +491,9 @@ Register SPIRVGlobalRegistry::getOrCreateBaseRegister(
   }
   assert(Type->getOpcode() == SPIRV::OpTypeInt);
   SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
-  return getOrCreateConstInt(APInt(BitWidth, Val->getUniqueInteger().getZExtValue()), I,
-                             SpvBaseType, TII, ZeroAsNull);
+  return getOrCreateConstInt(
+      APInt(BitWidth, Val->getUniqueInteger().getZExtValue()), I, SpvBaseType,
+      TII, ZeroAsNull);
 }
 
 Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index ee217f81f..9cb7d982c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -515,8 +515,8 @@ public:
   Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
                             SPIRVType *SpvType, bool EmitIR,
                             bool ZeroAsNull = true);
-  Register getOrCreateConstInt(APInt Val, MachineInstr &I,
-                               SPIRVType *SpvType, const SPIRVInstrInfo &TII,
+  Register getOrCreateConstInt(APInt Val, MachineInstr &I, SPIRVType *SpvType,
+                               const SPIRVInstrInfo &TII,
                                bool ZeroAsNull = true);
   Register createConstInt(const Constant *CI, APInt Val, MachineInstr &I,
                           SPIRVType *SpvType, const SPIRVInstrInfo &TII,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 3e5566945..475d59516 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2247,25 +2247,27 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
   for (unsigned i = 0; i < 4; i++) {
     // A[i]
     Register AElt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-    Result &=
-        BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
-            .addDef(AElt)
-            .addUse(GR.getSPIRVTypeID(ResType))
-            .addUse(X)
-            .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
-            .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
-            .constrainAllUses(TII, TRI, RBI);
+    Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
+                  .addDef(AElt)
+                  .addUse(GR.getSPIRVTypeID(ResType))
+                  .addUse(X)
+                  .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType,
+                                                 TII, ZeroAsNull))
+                  .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII,
+                                                 ZeroAsNull))
+                  .constrainAllUses(TII, TRI, RBI);
 
     // B[i]
     Register BElt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-    Result &=
-        BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
-            .addDef(BElt)
-            .addUse(GR.getSPIRVTypeID(ResType))
-            .addUse(Y)
-            .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
-            .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
-            .constrainAllUses(TII, TRI, RBI);
+    Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
+                  .addDef(BElt)
+                  .addUse(GR.getSPIRVTypeID(ResType))
+                  .addUse(Y)
+                  .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType,
+                                                 TII, ZeroAsNull))
+                  .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII,
+                                                 ZeroAsNull))
+                  .constrainAllUses(TII, TRI, RBI);
 
     // A[i] * B[i]
     Register Mul = MRI->createVirtualRegister(&SPIRV::IDRegClass);
@@ -2278,14 +2280,15 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
 
     // Discard 24 highest-bits so that stored i32 register is i8 equivalent
     Register MaskMul = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-    Result &=
-        BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
-            .addDef(MaskMul)
-            .addUse(GR.getSPIRVTypeID(ResType))
-            .addUse(Mul)
-            .addUse(GR.getOrCreateConstInt(APInt(8, 0), I, EltType, TII, ZeroAsNull))
-            .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
-            .constrainAllUses(TII, TRI, RBI);
+    Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
+                  .addDef(MaskMul)
+                  .addUse(GR.getSPIRVTypeID(ResType))
+                  .addUse(Mul)
+                  .addUse(GR.getOrCreateConstInt(APInt(8, 0), I, EltType, TII,
+                                                 ZeroAsNull))
+                  .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII,
+                                                 ZeroAsNull))
+                  .constrainAllUses(TII, TRI, RBI);
 
     // Acc = Acc + A[i] * B[i]
     Register Sum =
@@ -2378,11 +2381,12 @@ bool SPIRVInstructionSelector::selectWaveOpInst(Register ResVReg,
   MachineBasicBlock &BB = *I.getParent();
   SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
 
-  auto BMI = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
-                 .addDef(ResVReg)
-                 .addUse(GR.getSPIRVTypeID(ResType))
-                 .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
-                                                IntTy, TII, !STI.isShader()));
+  auto BMI =
+      BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
+          .addDef(ResVReg)
+          .addUse(GR.getSPIRVTypeID(ResType))
+          .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
+                                         IntTy, TII, !STI.isShader()));
 
   for (unsigned J = 2; J < I.getNumOperands(); J++) {
     BMI.addUse(I.getOperand(J).getReg());
@@ -2401,15 +2405,16 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
                                  SPIRV::OpGroupNonUniformBallot);
 
   MachineBasicBlock &BB = *I.getParent();
-  Result &= BuildMI(BB, I, I.getDebugLoc(),
-                    TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
-                .addDef(ResVReg)
-                .addUse(GR.getSPIRVTypeID(ResType))
-                .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy,
-                                               TII, !STI.isShader()))
-                .addImm(SPIRV::GroupOperation::Reduce)
-                .addUse(BallotReg)
-                .constrainAllUses(TII, TRI, RBI);
+  Result &=
+      BuildMI(BB, I, I.getDebugLoc(),
+              TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
+          .addDef(ResVReg)
+          .addUse(GR.getSPIRVTypeID(ResType))
+          .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
+                                         IntTy, TII, !STI.isShader()))
+          .addImm(SPIRV::GroupOperation::Reduce)
+          .addUse(BallotReg)
+          .constrainAllUses(TII, TRI, RBI);
 
   return Result;
 }
@@ -2436,8 +2441,8 @@ bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg,
   return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
       .addDef(ResVReg)
       .addUse(GR.getSPIRVTypeID(ResType))
-      .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
-                                     !STI.isShader()))
+      .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
+                                     IntTy, TII, !STI.isShader()))
       .addImm(SPIRV::GroupOperation::Reduce)
       .addUse(I.getOperand(2).getReg())
       .constrainAllUses(TII, TRI, RBI);
@@ -2463,8 +2468,8 @@ bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg,
   return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
       .addDef(ResVReg)
       .addUse(GR.getSPIRVTypeID(ResType))
-      .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
-                                     !STI.isShader()))
+      .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
+                                     IntTy, TII, !STI.isShader()))
       .addImm(SPIRV::GroupOperation::Reduce)
       .addUse(I.getOperand(2).getReg());
 }
@@ -2689,7 +2694,8 @@ Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType,
   bool ZeroAsNull = !STI.isShader();
   if (ResType->getOpcode() == SPIRV::OpTypeVector)
     return GR.getOrCreateConstVector(0UL, I, ResType, TII, ZeroAsNull);
-  return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
+  return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0),
+                                I, ResType, TII, ZeroAsNull);
 }
 
 Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType,
@@ -2720,7 +2726,9 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
       AllOnes ? APInt::getAllOnes(BitWidth) : APInt::getOneBitSet(BitWidth, 0);
   if (ResType->getOpcode() == SPIRV::OpTypeVector)
     return GR.getOrCreateConstVector(One.getZExtValue(), I, ResType, TII);
-  return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), One.getZExtValue()), I, ResType, TII);
+  return GR.getOrCreateConstInt(
+      APInt(GR.getScalarOrVectorBitWidth(ResType), One.getZExtValue()), I,
+      ResType, TII);
 }
 
 bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
@@ -2939,7 +2947,8 @@ bool SPIRVInstructionSelector::selectConst(Register ResVReg,
     Reg = GR.getOrCreateConstFP(I.getOperand(1).getFPImm()->getValue(), I,
                                 ResType, TII, !STI.isShader());
   } else {
-    Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I, ResType, TII, !STI.isShader());
+    Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I,
+                                 ResType, TII, !STI.isShader());
   }
   return Reg == ResVReg ? true : BuildCOPY(ResVReg, Reg, I);
 }
@@ -3764,7 +3773,8 @@ bool SPIRVInstructionSelector::selectFirstBitSet64Overflow(
     bool ZeroAsNull = !STI.isShader();
     Register FinalElemReg = MRI->createVirtualRegister(GR.getRegClass(I64Type));
     Register ConstIntLastIdx = GR.getOrCreateConstInt(
-        APInt(GR.getScalarOrVectorBitWidth(BaseType), ComponentCount - 1), I, BaseType, TII, ZeroAsNull);
+        APInt(GR.getScalarOrVectorBitWidth(BaseType), ComponentCount - 1), I,
+        BaseType, TII, ZeroAsNull);
 
     if (!selectOpWithSrcs(FinalElemReg, I64Type, I, {SrcReg, ConstIntLastIdx},
                           SPIRV::OpVectorExtractDynamic))
@@ -3793,9 +3803,11 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
   SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType);
   bool ZeroAsNull = !STI.isShader();
   Register ConstIntZero =
-      GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 0), I, BaseType, TII, ZeroAsNull);
+      GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 0),
+                             I, BaseType, TII, ZeroAsNull);
   Register ConstIntOne =
-      GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 1), I, BaseType, TII, ZeroAsNull);
+      GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 1),
+                             I, BaseType, TII, ZeroAsNull);
 
   // SPIRV doesn't support vectors with more than 4 components. Since the
   // algoritm below converts i64 -> i32x2 and i64x4 -> i32x8 it can only
@@ -3879,10 +3891,15 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
   unsigned AddOp;
 
   if (IsScalarRes) {
-    NegOneReg =
-        GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), (unsigned)-1), I, ResType, TII, ZeroAsNull);
-    Reg0 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
-    Reg32 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 32), I, ResType, TII, ZeroAsNull);
+    NegOneReg = GR.getOrCreateConstInt(
+        APInt(GR.getScalarOrVectorBitWidth(ResType), (unsigned)-1), I, ResType,
+        TII, ZeroAsNull);
+    Reg0 =
+        GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0),
+                               I, ResType, TII, ZeroAsNull);
+    Reg32 =
+        GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 32),
+                               I, ResType, TII, ZeroAsNull);
     SelectOp = SPIRV::OpSelectSISCond;
     AddOp = SPIRV::OpIAddS;
   } else {

@YixingZhang007 YixingZhang007 changed the title Add support for arbitrary integer with bitwidth larger than 64 bits in [SPIRV] Add support for arbitrary-precision integers larger than 64 bits in SPIR-V backend Sep 30, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants