diff --git a/clang/test/Driver/clang-offload-wrapper-exe.cpp b/clang/test/Driver/clang-offload-wrapper-exe.cpp index f17a93f89d511..556a91d2667f9 100644 --- a/clang/test/Driver/clang-offload-wrapper-exe.cpp +++ b/clang/test/Driver/clang-offload-wrapper-exe.cpp @@ -4,12 +4,22 @@ // works, and that the tool generates data properly accessible at runtime. // --- Prepare test data -// RUN: echo -e -n 'device binary image\n' > %t.bin +// - create the first binary image +// RUN: echo -e -n 'device binary image1\n' > %t.bin // RUN: echo -e -n '[Category1]\nint_prop1=1|10\n[Category2]\nint_prop2=1|20\n' > %t.props // RUN: echo -e -n 'kernel1\nkernel2\n' > %t.sym // RUN: echo -e -n 'Manifest file - arbitrary data generated by the toolchain\n' > %t.mnf + +// - create the second binary image with byte array property values +// RUN: echo -e -n 'device binary image2\n' > %t_1.bin +// RUN: echo -e -n '[Category3]\n' > %t_1.props +// RUN: echo -e -n 'kernel1=2|IAAAAAAAAAQA\n' >> %t_1.props +// RUN: echo -e -n 'kernel2=2|oAAAAAAAAAw///3/wB\n' >> %t_1.props + +// - create the batch file input for the wrapper // RUN: echo '[Code|Properties|Symbols|Manifest]' > %t.batch // RUN: echo %t.bin"|"%t.props"|"%t.sym"|"%t.mnf >> %t.batch +// RUN: echo %t_1.bin"|"%t_1.props"|""|" >> %t.batch // --- Generate "gold" output // RUN: cat %t.bin %t.mnf %t.props %t.sym > %t.all // --- Create the wrapper object @@ -22,6 +32,8 @@ // RUN: %t.batch.exe > %t.batch.exe.out // RUN: diff -b %t.batch.exe.out %t.all +#include +#include #include #include @@ -105,31 +117,97 @@ static int getInt(void *Addr) { return Ptr[0] | (Ptr[1] << 8) | (Ptr[2] << 16) | (Ptr[3] << 24); } -int main(int argc, char **argv) { - ASSERT(BinDesc->NumDeviceBinaries == 1, "BinDesc->NumDeviceBinaries"); - ASSERT(BinDesc->Version == 1, "BinDesc->Version"); +using byte = unsigned char; + +static void printProp(const pi_device_binary_property &Prop) { + std::cerr << "Property " << Prop->Name << " {\n"; + std::cerr << " Type: " << 
Prop->Type << "\n"; + if (Prop->Type != 1) + std::cerr << " Size = " << Prop->ValSize << "\n"; + + std::cerr << " Value = "; + if (Prop->Type == 1) + std::cerr << getInt(&Prop->ValSize); + else { + std::cerr << " {\n "; - for (int I = 0; I < BinDesc->NumDeviceBinaries; ++I) { - pi_device_binary Bin = &BinDesc->DeviceBinaries[I]; - ASSERT(Bin->Kind == 4, "Bin->Kind"); - ASSERT(Bin->Format == 1, "Bin->Format"); - - // dump code - std::cout << getString(Bin->BinaryStart, Bin->BinaryEnd); - // dump manifest - std::cout << std::string(Bin->ManifestStart, Bin->ManifestEnd - Bin->ManifestStart); - // dump properties - for (pi_device_binary_property_set PropSet = Bin->PropertySetsBegin; PropSet != Bin->PropertySetsEnd; ++PropSet) { - std::cout << "[" << PropSet->Name << "]" - << "\n"; - - for (pi_device_binary_property Prop = PropSet->PropertiesBegin; Prop != PropSet->PropertiesEnd; ++Prop) - // values fitting into 64 bits are written into the 'ValAddr' field: - std::cout << Prop->Name << "=" << Prop->Type << "|" << getInt(&Prop->ValSize) << "\n"; + byte *Ptr = (byte *)Prop->ValAddr; + + for (auto I = 0; I < Prop->ValSize && I < 100; ++I) { + std::cerr << " 0x" << std::hex << (unsigned int)Ptr[I]; + std::cerr << std::dec; + } + std::cerr << "\n }"; + } + std::cerr << "\n"; + std::cerr << "}\n"; +} + +static int dumpBinary0() { + pi_device_binary Bin = &BinDesc->DeviceBinaries[0]; + ASSERT(Bin->Kind == 4, "Bin->Kind"); + ASSERT(Bin->Format == 1, "Bin->Format"); + + // dump code + std::cout << getString(Bin->BinaryStart, Bin->BinaryEnd); + // dump manifest + std::cout << std::string(Bin->ManifestStart, Bin->ManifestEnd - Bin->ManifestStart); + // dump properties + for (pi_device_binary_property_set PropSet = Bin->PropertySetsBegin; PropSet != Bin->PropertySetsEnd; ++PropSet) { + std::cout << "[" << PropSet->Name << "]" + << "\n"; + + for (pi_device_binary_property Prop = PropSet->PropertiesBegin; Prop != PropSet->PropertiesEnd; ++Prop) { + ASSERT(Prop->Type == 1, 
"Prop->Type"); + std::cout << Prop->Name << "=" << Prop->Type << "|" << getInt(&Prop->ValSize) << "\n"; } - // dump symbols - for (_pi_offload_entry Entry = Bin->EntriesBegin; Entry != Bin->EntriesEnd; ++Entry) - std::cout << Entry->name << "\n"; } + // dump symbols + for (_pi_offload_entry Entry = Bin->EntriesBegin; Entry != Bin->EntriesEnd; ++Entry) + std::cout << Entry->name << "\n"; return 0; -} \ No newline at end of file +} + +// Clang offload wrapper does Base64 decoding on byte array property values, so +// they can't be dumped as is and compared to the original. Instead, this +// testcase checks that the byte array in the property value is equal to the +// pre-decoded byte array. +static int checkBinary1() { + // Decoded from "IAAAAAAAAAQA": + const byte Arr0[] = {8, 0, 0, 0, 0, 0, 0, 0, 0x1}; + // Decoded from "oAAAAAAAAAw///3/wB": + const byte Arr1[] = {40, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0x7F, 0xFF, 0x70}; + + struct { + const byte *Ptr; + const size_t Size; + } GoldArrays[] = { + {Arr0, sizeof(Arr0)}, + {Arr1, sizeof(Arr1)}}; + pi_device_binary Bin = &BinDesc->DeviceBinaries[1]; + ASSERT(Bin->Kind == 4, "Bin->Kind"); + ASSERT(Bin->Format == 1, "Bin->Format"); + + for (pi_device_binary_property_set PropSet = Bin->PropertySetsBegin; PropSet != Bin->PropertySetsEnd; ++PropSet) { + int Cnt = 0; + + for (pi_device_binary_property Prop = PropSet->PropertiesBegin; Prop != PropSet->PropertiesEnd; ++Prop, ++Cnt) { + ASSERT(Prop->Type == 2, "Prop->Type"); // must be a byte array + char *Ptr = reinterpret_cast(Prop->ValAddr); + int Cmp = std::memcmp(Prop->ValAddr, GoldArrays[Cnt].Ptr, GoldArrays[Cnt].Size); + ASSERT(Cmp == 0, "byte array property"); + } + } + return 0; +} + +int main(int argc, char **argv) { + ASSERT(BinDesc->NumDeviceBinaries == 2, "BinDesc->NumDeviceBinaries"); + ASSERT(BinDesc->Version == 1, "BinDesc->Version"); + + if (dumpBinary0() != 0) + return 1; + if (checkBinary1() != 0) + return 1; + return 0; +} diff --git 
a/clang/tools/clang-offload-wrapper/ClangOffloadWrapper.cpp b/clang/tools/clang-offload-wrapper/ClangOffloadWrapper.cpp index c4267c65e8471..8e2a1d91a132e 100644 --- a/clang/tools/clang-offload-wrapper/ClangOffloadWrapper.cpp +++ b/clang/tools/clang-offload-wrapper/ClangOffloadWrapper.cpp @@ -610,6 +610,15 @@ class BinaryWrapper { return std::make_pair(ImageB, ImageE); } + // Adds given data buffer as constant byte array and returns a constant + // pointer to it. The pointer type does not carry size information. + Constant *addRawDataToModule(ArrayRef Data, const Twine &Name) { + auto *Var = addGlobalArrayVariable(Name, Data); + auto *DataPtr = ConstantExpr::getGetElementPtr(Var->getValueType(), Var, + getSizetConstPair(0u, 0u)); + return DataPtr; + } + // Creates all necessary data objects for the given image and returns a pair // of pointers that point to the beginning and end of the global variable that // contains the image data. @@ -718,6 +727,14 @@ class BinaryWrapper { ConstantInt::get(Type::getInt64Ty(C), Prop.second.asUint32()); break; } + case llvm::util::PropertyValue::BYTE_ARRAY: { + const char *Ptr = + reinterpret_cast(Prop.second.asRawByteArray()); + uint64_t Size = Prop.second.getRawByteArraySize(); + PropValSize = ConstantInt::get(Type::getInt64Ty(C), Size); + PropValAddr = addRawDataToModule(ArrayRef(Ptr, Size), "prop_val"); + break; + } default: llvm_unreachable_internal("unsupported property"); } diff --git a/llvm/include/llvm/Support/Base64.h b/llvm/include/llvm/Support/Base64.h index 62064a35aa344..178acc50e966b 100644 --- a/llvm/include/llvm/Support/Base64.h +++ b/llvm/include/llvm/Support/Base64.h @@ -13,6 +13,10 @@ #ifndef LLVM_SUPPORT_BASE64_H #define LLVM_SUPPORT_BASE64_H +#include "llvm/Support/Error.h" +#include "llvm/Support/raw_ostream.h" + +#include #include namespace llvm { @@ -51,6 +55,33 @@ template std::string encodeBase64(InputBytes const &Bytes) { return Buffer; } +// General-purpose Base64 encoder/decoder. 
+// TODO update WinCOFFObjectWriter.cpp to use this library. +class Base64 { +public: + using byte = uint8_t; + + // Get the size of the encoded byte sequence of given size. + static size_t getEncodedSize(size_t SrcSize); + + // Encode a byte sequence of given size into an output stream. + // Returns the number of bytes in the encoded result. + static size_t encode(const byte *Src, raw_ostream &Out, size_t SrcSize); + + // Get the size of the decoded byte sequence for given encoded size. + static size_t getDecodedSize(size_t SrcSize); + + // Decode a sequence of given size into a pre-allocated memory. + // Returns the number of bytes in the decoded result or an error on failure. + static Expected decode(const char *Src, byte *Dst, size_t SrcSize); + + // Allocate minimum required amount of memory and decode a sequence of given + // size into it. + // Returns the decoded result. The size can be obtained via getDecodedSize. + static Expected> decode(const char *Src, + size_t SrcSize); +}; + } // end namespace llvm -#endif +#endif // LLVM_SUPPORT_BASE64_H diff --git a/llvm/include/llvm/Support/PropertySetIO.h b/llvm/include/llvm/Support/PropertySetIO.h index f212c54b05919..15587584a784e 100644 --- a/llvm/include/llvm/Support/PropertySetIO.h +++ b/llvm/include/llvm/Support/PropertySetIO.h @@ -51,8 +51,14 @@ namespace util { // container. class PropertyValue { public: + // Type of the size of the value. Value size gets serialized along with the + // value data in some cases for later reading at runtime, so size_t is not + // suitable as its size varies. + using SizeTy = uint64_t; + using byte = uint8_t; + // Defines supported property types - enum Type { first = 0, NONE = first, UINT32, last = UINT32 }; + enum Type { first = 0, NONE = first, UINT32, BYTE_ARRAY, last = BYTE_ARRAY }; // Translates C++ type to the corresponding type tag.
template static Type getTypeTag(); @@ -64,37 +70,85 @@ class PropertyValue { return static_cast(T); } + ~PropertyValue() { + if ((getType() == BYTE_ARRAY) && Val.ByteArrayVal) + delete[] Val.ByteArrayVal; + } + PropertyValue() = default; PropertyValue(Type T) : Ty(T) {} PropertyValue(uint32_t Val) : Ty(UINT32), Val({Val}) {} - PropertyValue(const PropertyValue &P) = default; - PropertyValue(PropertyValue &&P) = default; + PropertyValue(const byte *Data, SizeTy DataBitSize); + PropertyValue(const PropertyValue &P); + PropertyValue(PropertyValue &&P); - PropertyValue &operator=(PropertyValue &&P) = default; + PropertyValue &operator=(PropertyValue &&P); - PropertyValue &operator=(const PropertyValue &P) = default; + PropertyValue &operator=(const PropertyValue &P); // get property value as unsigned 32-bit integer uint32_t asUint32() const { - assert(Ty == UINT32); + if (Ty != UINT32) + llvm_unreachable("must be UINT32 value"); return Val.UInt32Val; } + // Get raw data size in bits. + SizeTy getByteArraySizeInBits() const { + if (Ty != BYTE_ARRAY) + llvm_unreachable("must be BYTE_ARRAY value"); + SizeTy Res = 0; + + for (auto I = 0; I < sizeof(SizeTy); ++I) + Res |= (SizeTy)Val.ByteArrayVal[I] << (8 * I); + return Res; + } + + // Get byte array data size in bytes. + SizeTy getByteArraySize() const { + SizeTy SizeInBits = getByteArraySizeInBits(); + constexpr unsigned int MASK = 0x7; + return ((SizeInBits + MASK) & ~MASK) / 8; + } + + // Get byte array data size in bytes, including the leading bytes encoding the + // size. + SizeTy getRawByteArraySize() const { + return getByteArraySize() + sizeof(SizeTy); + } + + // Get byte array data including the leading bytes encoding the size. + const byte *asRawByteArray() const { + if (Ty != BYTE_ARRAY) + llvm_unreachable("must be BYTE_ARRAY value"); + return Val.ByteArrayVal; + } + + // Get byte array data excluding the leading bytes encoding the size. 
+ const byte *asByteArray() const { + if (Ty != BYTE_ARRAY) + llvm_unreachable("must be BYTE_ARRAY value"); + return Val.ByteArrayVal + sizeof(SizeTy); + } + bool isValid() const { return getType() != NONE; } // set property value; the 'T' type must be convertible to a property type tag template void set(T V) { - assert(getTypeTag() == Ty); + if (getTypeTag() != Ty) + llvm_unreachable("invalid type tag for this operation"); getValueRef() = V; } Type getType() const { return Ty; } - size_t size() const { + SizeTy size() const { switch (Ty) { case UINT32: return sizeof(Val.UInt32Val); + case BYTE_ARRAY: + return getRawByteArraySize(); default: llvm_unreachable_internal("unsupported property type"); } @@ -102,15 +156,17 @@ class PropertyValue { private: template T &getValueRef(); + void copy(const PropertyValue &P); Type Ty = NONE; + // TODO: replace this union with std::variant when uplifting to C++17 union { uint32_t UInt32Val; + // Holds first sizeof(SizeTy) bytes of size followed by actual raw data. + byte *ByteArrayVal; } Val; }; -std::ostream &operator<<(std::ostream &Out, const PropertyValue &V); - // A property set. Preserves insertion order when iterating elements. using PropertySet = MapVector; @@ -124,6 +180,7 @@ class PropertySetRegistry { static constexpr char SYCL_SPECIALIZATION_CONSTANTS[] = "SYCL/specialization constants"; static constexpr char SYCL_DEVICELIB_REQ_MASK[] = "SYCL/devicelib req mask"; + static constexpr char SYCL_KERNEL_PARAM_OPT_INFO[] = "SYCL/kernel param opt"; // Function for bulk addition of an entire property set under given category // (property set name). diff --git a/llvm/lib/Support/Base64.cpp b/llvm/lib/Support/Base64.cpp new file mode 100644 index 0000000000000..fcee900c7f455 --- /dev/null +++ b/llvm/lib/Support/Base64.cpp @@ -0,0 +1,188 @@ +//===--- Base64.cpp - Base64 Encoder/Decoder Implementation -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/Base64.h" + +#include + +using namespace llvm; + +namespace { + +using byte = Base64::byte; + +::llvm::Error makeError(const Twine &Msg) { + return createStringError(std::error_code{}, Msg); +} + +class Base64Impl { +private: + static constexpr char EncodingTable[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + + static_assert(sizeof(EncodingTable) == 65, ""); + + // Compose an index into the encoder table from two bytes and the number of + // significant bits in the lower byte until the byte boundary. + static inline int composeInd(byte ByteLo, byte ByteHi, int BitsLo) { + int Res = ((ByteHi << BitsLo) | (ByteLo >> (8 - BitsLo))) & 0x3F; + return Res; + } + + // Decode a single character. + static inline int decode(char Ch) { + if (Ch >= 'A' && Ch <= 'Z') // 0..25 + return Ch - 'A'; + else if (Ch >= 'a' && Ch <= 'z') // 26..51 + return Ch - 'a' + 26; + else if (Ch >= '0' && Ch <= '9') // 52..61 + return Ch - '0' + 52; + else if (Ch == '+') // 62 + return 62; + else if (Ch == '/') // 63 + return 63; + return -1; + } + + // Decode a quadruple of characters. 
+ static inline Expected decode4(const char *Src, byte *Dst) { + int BadCh = -1; + + for (auto I = 0; I < 4; ++I) { + char Ch = Src[I]; + int Byte = decode(Ch); + + if (Byte < 0) { + BadCh = Ch; + break; + } + Dst[I] = (byte)Byte; + } + if (BadCh == -1) + return true; + return makeError("invalid char in Base64Impl encoding: 0x" + Twine(BadCh)); + } + +public: + static size_t getEncodedSize(size_t SrcSize) { + constexpr int ByteSizeInBits = 8; + constexpr int EncBitsPerChar = 6; + return (SrcSize * ByteSizeInBits + (EncBitsPerChar - 1)) / EncBitsPerChar; + } + + static size_t encode(const byte *Src, raw_ostream &Out, size_t SrcSize) { + size_t Off = 0; + + // encode full byte triples + for (auto TriB = 0; TriB < SrcSize / 3; ++TriB) { + Off = TriB * 3; + byte Byte0 = Src[Off++]; + byte Byte1 = Src[Off++]; + byte Byte2 = Src[Off++]; + + Out << EncodingTable[Byte0 & 0x3F]; + Out << EncodingTable[composeInd(Byte0, Byte1, 2)]; + Out << EncodingTable[composeInd(Byte1, Byte2, 4)]; + Out << EncodingTable[(Byte2 >> 2) & 0x3F]; + } + // encode the remainder + int RemBytes = SrcSize - Off; + + if (RemBytes > 0) { + byte Byte0 = Src[Off + 0]; + Out << EncodingTable[Byte0 & 0x3F]; + + if (RemBytes > 1) { + byte Byte1 = Src[Off + 1]; + Out << EncodingTable[composeInd(Byte0, Byte1, 2)]; + Out << EncodingTable[(Byte1 >> 4) & 0x3F]; + } else { + Out << EncodingTable[(Byte0 >> 6) & 0x3F]; + } + } + return getEncodedSize(SrcSize); + } + + static size_t getDecodedSize(size_t SrcSize) { return (SrcSize * 3 + 3) / 4; } + + static Expected decode(const char *Src, byte *Dst, size_t SrcSize) { + size_t SrcOff = 0; + size_t DstOff = 0; + + // decode full quads + for (auto Qch = 0; Qch < SrcSize / 4; ++Qch, SrcOff += 4, DstOff += 3) { + byte Ch[4]; + Expected TrRes = decode4(Src + SrcOff, Ch); + + if (!TrRes) + return TrRes.takeError(); + // each quad of chars produces three bytes of output + Dst[DstOff + 0] = Ch[0] | (Ch[1] << 6); + Dst[DstOff + 1] = (Ch[1] >> 2) | (Ch[2] << 4); + 
Dst[DstOff + 2] = (Ch[2] >> 4) | (Ch[3] << 2); + } + auto RemChars = SrcSize - SrcOff; + + if (RemChars == 0) + return DstOff; + // decode the remainder; variants: + // 2 chars remain - produces single byte + // 3 chars remain - produces two bytes + + if (RemChars != 2 && RemChars != 3) + return makeError("invalid encoded sequence length"); + + int Ch0 = decode(Src[SrcOff++]); + int Ch1 = decode(Src[SrcOff++]); + int Ch2 = RemChars == 3 ? decode(Src[SrcOff]) : 0; + + if (Ch0 < 0 || Ch1 < 0 || Ch2 < 0) + return makeError("invalid characters in the encoded sequence remainder"); + Dst[DstOff++] = Ch0 | (Ch1 << 6); + + if (RemChars == 3) + Dst[DstOff++] = (Ch1 >> 2) | (Ch2 << 4); + return DstOff; + } + + static Expected> decode(const char *Src, + size_t SrcSize) { + size_t DstSize = getDecodedSize(SrcSize); + byte *Dst = new byte[DstSize]; + Expected Res = decode(Src, Dst, SrcSize); + if (!Res) + return Res.takeError(); + return std::unique_ptr(Dst); + } +}; + +constexpr char Base64Impl::EncodingTable[]; + +} // anonymous namespace + +size_t Base64::getEncodedSize(size_t SrcSize) { + return Base64Impl::getEncodedSize(SrcSize); +} + +size_t Base64::encode(const byte *Src, raw_ostream &Out, size_t SrcSize) { + return Base64Impl::encode(Src, Out, SrcSize); +} + +size_t Base64::getDecodedSize(size_t SrcSize) { + return Base64Impl::getDecodedSize(SrcSize); +} + +Expected Base64::decode(const char *Src, byte *Dst, size_t SrcSize) { + return Base64Impl::decode(Src, Dst, SrcSize); +} + +Expected> Base64::decode(const char *Src, + size_t SrcSize) { + return Base64Impl::decode(Src, SrcSize); +} diff --git a/llvm/lib/Support/CMakeLists.txt b/llvm/lib/Support/CMakeLists.txt index 53ad806d818f5..79aabdc07e3cd 100644 --- a/llvm/lib/Support/CMakeLists.txt +++ b/llvm/lib/Support/CMakeLists.txt @@ -63,6 +63,7 @@ add_llvm_component_library(LLVMSupport ARMAttributeParser.cpp ARMWinEH.cpp Allocator.cpp + Base64.cpp BinaryStreamError.cpp BinaryStreamReader.cpp BinaryStreamRef.cpp diff 
--git a/llvm/lib/Support/PropertySetIO.cpp b/llvm/lib/Support/PropertySetIO.cpp index c8cd4dbdd9625..9ddf3ca70caa0 100644 --- a/llvm/lib/Support/PropertySetIO.cpp +++ b/llvm/lib/Support/PropertySetIO.cpp @@ -9,17 +9,30 @@ #include "llvm/Support/PropertySetIO.h" #include "llvm/ADT/APInt.h" +#include "llvm/Support/Base64.h" #include "llvm/Support/Error.h" #include "llvm/Support/LineIterator.h" +#include +#include + using namespace llvm::util; using namespace llvm; +namespace { + +using byte = Base64::byte; + +::llvm::Error makeError(const Twine &Msg) { + return createStringError(std::error_code{}, Msg); +} + +} // anonymous namespace + Expected> PropertySetRegistry::read(const MemoryBuffer *Buf) { auto Res = std::make_unique(); PropertySet *CurPropSet = nullptr; - std::error_code EC; for (line_iterator LI(*Buf); !LI.is_at_end(); LI++) { // see if this line starts a new property set @@ -27,27 +40,27 @@ PropertySetRegistry::read(const MemoryBuffer *Buf) { // yes - parse the category (property name) auto EndPos = LI->rfind(']'); if (EndPos == StringRef::npos) - return createStringError(EC, "invalid line: " + *LI); + return makeError("invalid line: " + *LI); StringRef Category = LI->substr(1, EndPos - 1); CurPropSet = &(*Res)[Category]; continue; } if (!CurPropSet) - return createStringError(EC, "property category missing"); + return makeError("property category missing"); // parse name and type+value auto Parts = LI->split('='); if (Parts.first.empty() || Parts.second.empty()) - return createStringError(EC, "invalid property line: " + *LI); + return makeError("invalid property line: " + *LI); auto TypeVal = Parts.second.split('|'); if (TypeVal.first.empty() || TypeVal.second.empty()) - return createStringError(EC, "invalid property value: " + Parts.second); + return makeError("invalid property value: " + Parts.second); APInt Tint; // parse type if (TypeVal.first.getAsInteger(10, Tint)) - return createStringError(EC, "invalid property type: " + TypeVal.first); + return 
makeError("invalid property type: " + TypeVal.first); Expected Ttag = PropertyValue::getTypeTag(static_cast(Tint.getSExtValue())); StringRef Val = TypeVal.second; @@ -61,17 +74,27 @@ PropertySetRegistry::read(const MemoryBuffer *Buf) { case PropertyValue::Type::UINT32: { APInt ValV; if (Val.getAsInteger(10, ValV)) - return createStringError(EC, "invalid property value: ", Val.data()); + return createStringError(std::error_code{}, + "invalid property value: ", Val.data()); Prop.set(static_cast(ValV.getZExtValue())); break; } + case PropertyValue::Type::BYTE_ARRAY: { + Expected> DecArr = + Base64::decode(Val.data(), Val.size()); + if (!DecArr) + return DecArr.takeError(); + Prop.set(DecArr.get().release()); + break; + } default: - return createStringError(EC, "unsupported property type: ", Ttag.get()); + return createStringError(std::error_code{}, + "unsupported property type: ", Ttag.get()); } - (*CurPropSet)[Parts.first] = Prop; + (*CurPropSet)[Parts.first] = std::move(Prop); } if (!CurPropSet) - return createStringError(EC, "invalid property set registry"); + return makeError("invalid property set registry"); return Expected>(std::move(Res)); } @@ -84,6 +107,11 @@ raw_ostream &operator<<(raw_ostream &Out, const PropertyValue &Prop) { case PropertyValue::Type::UINT32: Out << Prop.asUint32(); break; + case PropertyValue::Type::BYTE_ARRAY: { + util::PropertyValue::SizeTy Size = Prop.getRawByteArraySize(); + Base64::encode(Prop.asRawByteArray(), Out, (size_t)Size); + break; + } default: llvm_unreachable( ("unsupported property type: " + utostr(Prop.getType())).c_str()); @@ -108,11 +136,67 @@ namespace util { template <> uint32_t &PropertyValue::getValueRef() { return Val.UInt32Val; } + +template <> byte *&PropertyValue::getValueRef() { + return Val.ByteArrayVal; +} + template <> PropertyValue::Type PropertyValue::getTypeTag() { return UINT32; } +template <> PropertyValue::Type PropertyValue::getTypeTag() { + return BYTE_ARRAY; +} + +PropertyValue::PropertyValue(const 
byte *Data, SizeTy DataBitSize) { + constexpr int ByteSizeInBits = 8; + Ty = BYTE_ARRAY; + SizeTy DataSize = (DataBitSize + (ByteSizeInBits - 1)) / ByteSizeInBits; + constexpr size_t SizeFieldSize = sizeof(SizeTy); + + // Allocate space for size and data. + Val.ByteArrayVal = new byte[SizeFieldSize + DataSize]; + + // Write the size into first bytes. + for (auto I = 0; I < SizeFieldSize; ++I) { + Val.ByteArrayVal[I] = (byte)DataBitSize; + DataBitSize >>= ByteSizeInBits; + } + // Append data. + std::memcpy(Val.ByteArrayVal + SizeFieldSize, Data, DataSize); +} + +PropertyValue::PropertyValue(const PropertyValue &P) { *this = P; } + +PropertyValue::PropertyValue(PropertyValue &&P) { *this = std::move(P); } + +PropertyValue &PropertyValue::operator=(PropertyValue &&P) { + copy(P); + + if (P.getType() == BYTE_ARRAY) + P.Val.ByteArrayVal = nullptr; + P.Ty = NONE; + return *this; +} + +PropertyValue &PropertyValue::operator=(const PropertyValue &P) { + if (P.getType() == BYTE_ARRAY) + *this = + std::move(PropertyValue(P.asByteArray(), P.getByteArraySizeInBits())); + else + copy(P); + return *this; +} + +void PropertyValue::copy(const PropertyValue &P) { + Ty = P.Ty; + Val = P.Val; +} + constexpr char PropertySetRegistry::SYCL_SPECIALIZATION_CONSTANTS[]; constexpr char PropertySetRegistry::SYCL_DEVICELIB_REQ_MASK[]; +constexpr char PropertySetRegistry::SYCL_KERNEL_PARAM_OPT_INFO[]; + } // namespace util } // namespace llvm diff --git a/llvm/test/tools/sycl-post-link/omit_kernel_args.ll b/llvm/test/tools/sycl-post-link/omit_kernel_args.ll new file mode 100644 index 0000000000000..ef80867ea2dca --- /dev/null +++ b/llvm/test/tools/sycl-post-link/omit_kernel_args.ll @@ -0,0 +1,44 @@ +; This test checks that the post-link tool generates correct kernel parameter +; optimization info into a property file if the source IR contained +; corresponding metadata. 
+; +; RUN: sycl-post-link -emit-param-info -S %s -o %t.files.table +; RUN: FileCheck %s -input-file=%t.files.table --check-prefixes CHECK-TABLE +; RUN: FileCheck %s -input-file=%t.files_0.prop --match-full-lines --check-prefixes CHECK-PROP + +target triple = "spir64-unknown-unknown-sycldevice" + +define weak_odr spir_kernel void @SpirKernel1(float %arg1) !spir_kernel_omit_args !0 { + call void @foo(float %arg1) + ret void +} + +define weak_odr spir_kernel void @SpirKernel2(i8 %arg1, i8 %arg2, i8 %arg3) !spir_kernel_omit_args !1 { + call void @bar(i8 %arg1) + call void @bar(i8 %arg2) + call void @bar(i8 %arg3) + ret void +} + +declare void @foo(float) +declare void @bar(i8) + +; // HEX: 0x01 (1 byte) +; // BIN: 01 - 2 arguments total, 1 remains +; // NOTE: In the metadata string below, unlike bitmask above, least significant +; // element goes first +!0 = !{i1 true, i1 false} + +; // HEX: 0x7E 0x06 (2 bytes) +; // BIN: 110 01111110 - 11 arguments total, 3 remain +!1 = !{i1 false, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true} + +; CHECK-PROP: [SYCL/kernel param opt] +; // Base64 encoding in the prop file (including 8 bytes length): +; CHECK-PROP-NEXT-DAG: SpirKernel1=2|CAAAAAAAAAQA +; // Base64 encoding in the prop file (including 8 bytes length): +; CHECK-PROP-NEXT-DAG: SpirKernel2=2|LAAAAAAAAAgfGA + +; CHECK-TABLE: [Code|Properties] +; CHECK-TABLE-NEXT: {{.*}}files_0.prop +; CHECK-TABLE-EMPTY: diff --git a/llvm/tools/sycl-post-link/CMakeLists.txt b/llvm/tools/sycl-post-link/CMakeLists.txt index 708ad048221c2..2aac0cf13047c 100644 --- a/llvm/tools/sycl-post-link/CMakeLists.txt +++ b/llvm/tools/sycl-post-link/CMakeLists.txt @@ -9,6 +9,7 @@ set(LLVM_LINK_COMPONENTS add_llvm_tool(sycl-post-link sycl-post-link.cpp + SPIRKernelParamOptInfo.cpp SpecConstants.cpp DEPENDS diff --git a/llvm/tools/sycl-post-link/SPIRKernelParamOptInfo.cpp b/llvm/tools/sycl-post-link/SPIRKernelParamOptInfo.cpp new file mode 100644 index 
0000000000000..bb404c6211239 --- /dev/null +++ b/llvm/tools/sycl-post-link/SPIRKernelParamOptInfo.cpp @@ -0,0 +1,51 @@ +//==-- SPIRKernelParamOptInfo.cpp -- get kernel param optimization info ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "SPIRKernelParamOptInfo.h" + +#include "llvm/IR/Constants.h" +#include "llvm/Support/Casting.h" + +namespace { + +// Must match the one produced by DeadArgumentElimination +static constexpr char MetaDataID[] = "spir_kernel_omit_args"; + +} // anonymous namespace + +namespace llvm { + +void SPIRKernelParamOptInfo::releaseMemory() { clear(); } + +SPIRKernelParamOptInfo +SPIRKernelParamOptInfoAnalysis::run(Module &M, ModuleAnalysisManager &AM) { + SPIRKernelParamOptInfo Res; + + for (const Function &F : M) { + MDNode *MD = F.getMetadata(MetaDataID); + if (!MD) + continue; + using BaseTy = SPIRKernelParamOptInfoBaseTy; + auto Ins = + Res.insert(BaseTy::value_type{F.getName(), BaseTy::mapped_type{}}); + assert(Ins.second && "duplicate kernel?"); + BitVector &ParamDropped = Ins.first->second; + + for (unsigned int I = 0; I < MD->getNumOperands(); ++I) { + const auto *MDInt1 = cast(MD->getOperand(I)); + unsigned ID = static_cast( + cast(MDInt1->getValue())->getValue().getZExtValue()); + ParamDropped.push_back(ID != 0); + } + } + return Res; +} + +AnalysisKey SPIRKernelParamOptInfoAnalysis::Key; + +} // namespace llvm \ No newline at end of file diff --git a/llvm/tools/sycl-post-link/SPIRKernelParamOptInfo.h b/llvm/tools/sycl-post-link/SPIRKernelParamOptInfo.h new file mode 100644 index 0000000000000..a5ba8f1d4eadb --- /dev/null +++ b/llvm/tools/sycl-post-link/SPIRKernelParamOptInfo.h @@ -0,0 +1,45 @@ +//==-- SPIRKernelParamOptInfo.h -- get kernel param optimization info 
------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Analysis pass which collects parameter optimization info for SPIR kernels. +// Currently the info is whether i'th kernel parameter as emitted by the FE has +// been optimized away by the LLVM optimizer. The information is produced by +// DeadArgumentElimination transformation and stored into a specific metadata +// attached to kernel functions in a module. +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" + +namespace llvm { + +// the StringRef key refers to a function name +using SPIRKernelParamOptInfoBaseTy = DenseMap; + +class SPIRKernelParamOptInfo : public SPIRKernelParamOptInfoBaseTy { +public: + void releaseMemory(); +}; + +class SPIRKernelParamOptInfoAnalysis + : public AnalysisInfoMixin { + friend AnalysisInfoMixin; + + static AnalysisKey Key; + +public: + /// Provide the result type for this analysis pass. + using Result = SPIRKernelParamOptInfo; + + /// Run the analysis pass over a module and produce kernel param opt info.
+ SPIRKernelParamOptInfo run(Module &M, ModuleAnalysisManager &AM); +}; + +} // namespace llvm diff --git a/llvm/tools/sycl-post-link/sycl-post-link.cpp b/llvm/tools/sycl-post-link/sycl-post-link.cpp index 13d7d5e4c45a1..a7fa663e31f7d 100644 --- a/llvm/tools/sycl-post-link/sycl-post-link.cpp +++ b/llvm/tools/sycl-post-link/sycl-post-link.cpp @@ -13,7 +13,9 @@ // - specialization constant intrinsic transformation //===----------------------------------------------------------------------===// +#include "SPIRKernelParamOptInfo.h" #include "SpecConstants.h" + #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Triple.h" #include "llvm/Bitcode/BitcodeWriterPass.h" @@ -34,6 +36,7 @@ #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/GlobalDCE.h" #include "llvm/Transforms/Utils/Cloning.h" + #include using namespace llvm; @@ -117,11 +120,16 @@ static cl::opt SpecConstLower{ "set spec constants to C++ defaults")), cl::cat(PostLinkCat)}; +static cl::opt EmitKernelParamInfo{ + "emit-param-info", cl::desc("emit kernel parameter optimization info"), + cl::cat(PostLinkCat)}; + struct ImagePropSaveInfo { bool NeedDeviceLibReqMask; bool DoSpecConst; bool SetSpecConstAtRT; bool SpecConstsMet; + bool EmitKernelParamInfo; }; // Please update DeviceLibFuncMap if any item is added to or removed from // fallback device libraries in libdevice. 
@@ -508,6 +516,30 @@ static string_vector saveDeviceImageProperty(
           llvm::util::PropertySetRegistry::SYCL_SPECIALIZATION_CONSTANTS,
           TmpSpecIDMap);
     }
+    if (ImgPSInfo.EmitKernelParamInfo) {
+      // extract kernel parameter optimization info per module
+      ModuleAnalysisManager MAM;
+      // Register required analysis
+      MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+      // Register the payload analysis
+      MAM.registerPass([&] { return SPIRKernelParamOptInfoAnalysis(); });
+      SPIRKernelParamOptInfo PInfo =
+          MAM.getResult<SPIRKernelParamOptInfoAnalysis>(*ResultModules[I]);
+
+      // convert analysis results into properties and record them
+      llvm::util::PropertySet &Props =
+          PropSet[llvm::util::PropertySetRegistry::SYCL_KERNEL_PARAM_OPT_INFO];
+
+      for (const auto &NameInfoPair : PInfo) {
+        const llvm::BitVector &Bits = NameInfoPair.second;
+        const llvm::ArrayRef<uintptr_t> Arr = NameInfoPair.second.getData();
+        const unsigned char *Data =
+            reinterpret_cast<const unsigned char *>(Arr.begin());
+        llvm::util::PropertyValue::SizeTy DataBitSize = Bits.size();
+        Props.insert(std::make_pair(
+            NameInfoPair.first, llvm::util::PropertyValue(Data, DataBitSize)));
+      }
+    }
     std::error_code EC;
     std::string SCFile = makeResultFileName(".prop", I);
     raw_fd_ostream SCOut(SCFile, EC);
@@ -580,8 +612,9 @@ int main(int argc, char **argv) {
 
   bool DoSplit = SplitMode.getNumOccurrences() > 0;
   bool DoSpecConst = SpecConstLower.getNumOccurrences() > 0;
+  bool DoParamInfo = EmitKernelParamInfo.getNumOccurrences() > 0;
 
-  if (!DoSplit && !DoSpecConst && !DoSymGen) {
+  if (!DoSplit && !DoSpecConst && !DoSymGen && !DoParamInfo) {
     errs() << "no actions specified; try --help for usage info\n";
     return 1;
   }
@@ -595,6 +628,11 @@ int main(int argc, char **argv) {
            << IROutputOnly.ArgStr << "\n";
     return 1;
   }
+  if (IROutputOnly && DoParamInfo) {
+    errs() << "error: -" << EmitKernelParamInfo.ArgStr << " can't be used with"
+           << " -" << IROutputOnly.ArgStr << "\n";
+    return 1;
+  }
   SMDiagnostic Err;
   std::unique_ptr<Module> M = parseIRFile(InputFilename, Err, Context);
   // It is OK to use raw pointer here as we control that it does not outlive M
@@ -664,7 +702,7 @@ int main(int argc, char **argv) {
 
   {
     ImagePropSaveInfo ImgPSInfo = {true, DoSpecConst, SetSpecConstAtRT,
-                                   SpecConstsMet};
+                                   SpecConstsMet, EmitKernelParamInfo};
     string_vector Files = saveDeviceImageProperty(ResultModules, ImgPSInfo);
     Error Err = Table.addColumn(COL_PROPS, Files);
     CHECK_AND_EXIT(Err);
diff --git a/llvm/unittests/Support/Base64Test.cpp b/llvm/unittests/Support/Base64Test.cpp
index 9622c866b38aa..0058d036c838d 100644
--- a/llvm/unittests/Support/Base64Test.cpp
+++ b/llvm/unittests/Support/Base64Test.cpp
@@ -1,5 +1,4 @@
-//===- llvm/unittest/Support/Base64Test.cpp - Base64 tests
-//--------------------===//
+//===- llvm/unittest/Support/Base64Test.cpp - Base64 tests ----------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -50,3 +49,68 @@ TEST(Base64Test, Base64) {
   TestBase64({LargeVector, sizeof(LargeVector)},
              "VGhlIHF1aWNrIGJyb3duIGZveCBqdW1wcyBvdmVyIDEzIGxhenkgZG9ncy4=");
 }
+
+// Encodes each sample byte array with Base64::encode, decodes the result with
+// Base64::decode and checks that the decoded bytes match the original input.
+TEST(Base64Test, RoundTrip) {
+  using byte = unsigned char;
+  const byte Arr0[] = {0x1};
+  const byte Arr1[] = {0x81}; // 0x81 - highest and lowest bits are set
+  const byte Arr2[] = {0x81, 0x81};
+  const byte Arr3[] = {0x81, 0x81, 0x81};
+  const byte Arr4[] = {0x81, 0x81, 0x81, 0x81};
+  const byte Arr5[] = {0xFF, 0xFF, 0x7F, 0xFF, 0x70};
+  const byte Arr6[] = {40, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0x7F, 0xFF, 0x70};
+  const byte Arr7[] = {8, 0, 0, 0, 0, 0, 0, 0, 0x1};
+  // 2 tests below model real usage case for argument opt info propagation:
+  // { 0x7E, 0x06 }
+  // 11001111110 - 11 arguments total, 3 remain
+  // encoded: "LAAAAAAAAAgfGA"
+  const byte Arr8[] = {11, 0, 0, 0, 0, 0, 0, 0, 0x7E, 0x06};
+  // 0x01
+  // 01 - 2 arguments total, 1 remains
+  // encoded: "CAAAAAAAAAQA"
+  const byte Arr9[] = {2, 0, 0, 0, 0, 0, 0, 0, 0x1};
+
+  struct {
+    const byte *Ptr;
+    size_t Size;
+  } Tests[] = {{Arr0, sizeof(Arr0)}, {Arr1, sizeof(Arr1)}, {Arr2, sizeof(Arr2)},
+               {Arr3, sizeof(Arr3)}, {Arr4, sizeof(Arr4)}, {Arr5, sizeof(Arr5)},
+               {Arr6, sizeof(Arr6)}, {Arr7, sizeof(Arr7)}, {Arr8, sizeof(Arr8)},
+               {Arr9, sizeof(Arr9)}};
+
+  for (size_t I = 0; I < sizeof(Tests) / sizeof(Tests[0]); ++I) {
+    std::string Encoded;
+    size_t Len;
+    {
+      llvm::raw_string_ostream OS(Encoded);
+      Len = Base64::encode(Tests[I].Ptr, OS, Tests[I].Size);
+    }
+    // ADD_FAILURE is non-fatal (FAIL would return from the test body), so a
+    // failing sample is reported and the remaining samples are still checked.
+    if (Len != Encoded.size()) {
+      ADD_FAILURE() << "Base64::encode failed on test " << I << "\n";
+      continue;
+    }
+    std::unique_ptr<byte[]> Decoded(new byte[Base64::getDecodedSize(Len)]);
+    Expected<size_t> Res = Base64::decode(Encoded.data(), Decoded.get(), Len);
+
+    if (!Res) {
+      // Consume the error so the failed Expected can be destroyed safely.
+      llvm::consumeError(Res.takeError());
+      ADD_FAILURE() << "Base64::decode failed on test " << I << "\n";
+      continue;
+    }
+    if (Res.get() != Tests[I].Size) {
+      ADD_FAILURE() << "Base64::decode length mismatch, test " << I << "\n";
+      continue;
+    }
+    std::string Gold((const char *)Tests[I].Ptr, Tests[I].Size);
+    std::string Test((const char *)Decoded.get(), Res.get());
+
+    if (Gold != Test) {
+      ADD_FAILURE() << "Base64::decode result mismatch, test " << I << "\n";
+    }
+  }
+}
diff --git a/llvm/unittests/Support/PropertySetIOTest.cpp b/llvm/unittests/Support/PropertySetIOTest.cpp
index 329c7692fe3f1..dd1c04ca21cd2 100644
--- a/llvm/unittests/Support/PropertySetIOTest.cpp
+++ b/llvm/unittests/Support/PropertySetIOTest.cpp
@@ -42,4 +42,33 @@ TEST(PropertySet, IntValuesIO) {
   // Check that the original and the serialized version are equal
   ASSERT_EQ(Serialized, Content);
 }
+
+// Reads a property set registry containing byte array properties and checks
+// that writing it back reproduces the original text.
+TEST(PropertySet, ByteArrayValuesIO) {
+  // '2' in '2|...' means 'byte array property'; the Base64-encoded text after
+  // it encodes the following byte arrays:
+  //   { 8, 0, 0, 0, 0, 0, 0, 0, 0x1 };
+  //   { 40, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0x7F, 0xFF, 0x70 };
+  // first 8 bytes are the size in bits (40) of what follows (5 bytes).
+
+  auto Content = "[Opt/Param]\n"
+                 "kernel1=2|IAAAAAAAAAQA\n"
+                 "kernel2=2|oAAAAAAAAAw///3/wB\n";
+  auto MemBuf = MemoryBuffer::getMemBuffer(Content);
+  // Parse a property set registry
+  auto PropSetsPtr = PropertySetRegistry::read(MemBuf.get());
+
+  if (!PropSetsPtr)
+    FAIL() << "PropertySetRegistry::read failed\n";
+
+  std::string Serialized;
+  {
+    llvm::raw_string_ostream OS(Serialized);
+    // Serialize
+    PropSetsPtr->get()->write(OS);
+  }
+  // Check that the original and the serialized version are equal
+  ASSERT_EQ(Serialized, Content);
+}
 } // namespace
\ No newline at end of file