Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 43 additions & 49 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,38 +66,6 @@ using namespace jit_compiler;

namespace {

class HashPreprocessedAction : public PreprocessorFrontendAction {
protected:
void ExecuteAction() override {
CompilerInstance &CI = getCompilerInstance();

std::string PreprocessedSource;
raw_string_ostream PreprocessStream(PreprocessedSource);

PreprocessorOutputOptions Opts;
Opts.ShowCPP = 1;
Opts.MinimizeWhitespace = 1;
// Make cache key insensitive to virtual source file and header locations.
Opts.ShowLineMarkers = 0;

DoPrintPreprocessedInput(CI.getPreprocessor(), &PreprocessStream, Opts);

Hash = BLAKE3::hash(arrayRefFromStringRef(PreprocessedSource));
Executed = true;
}

public:
BLAKE3Result<> takeHash() {
assert(Executed);
Executed = false;
return std::move(Hash);
}

private:
BLAKE3Result<> Hash;
bool Executed = false;
};

class SYCLToolchain {
SYCLToolchain() {
using namespace jit_compiler::resource;
Expand Down Expand Up @@ -318,28 +286,54 @@ Expected<std::string> jit_compiler::calculateHash(
std::vector<std::string> CommandLine =
createCommandLine(UserArgList, Format, SourceFile.Path);

HashPreprocessedAction HashAction;
class HashPreprocessedAction : public PreprocessorFrontendAction {
protected:
void ExecuteAction() override {
CompilerInstance &CI = getCompilerInstance();

if (SYCLToolchain::instance().run(CommandLine, HashAction,
getInMemoryFS(SourceFile, IncludeFiles))) {
BLAKE3Result<> SourceHash = HashAction.takeHash();
// Last argument is the source file in the format `rtc_N.cpp` which is
// unique for each query, so drop it:
CommandLine.pop_back();
std::string PreprocessedSource;
raw_string_ostream PreprocessStream(PreprocessedSource);

// TODO: Include hash of the current libsycl-jit.so/.dll somehow...
BLAKE3Result<> CommandLineHash =
BLAKE3::hash(arrayRefFromStringRef(join(CommandLine, ",")));
PreprocessorOutputOptions Opts;
Opts.ShowCPP = 1;
Opts.MinimizeWhitespace = 1;
// Make cache key insensitive to virtual source file and header locations.
Opts.ShowLineMarkers = 0;

std::string EncodedHash =
encodeBase64(SourceHash) + encodeBase64(CommandLineHash);
// Make the encoding filesystem-friendly.
std::replace(EncodedHash.begin(), EncodedHash.end(), '/', '-');
return std::move(EncodedHash);
DoPrintPreprocessedInput(CI.getPreprocessor(), &PreprocessStream, Opts);

} else {
Hasher.update(PreprocessedSource);
}

public:
HashPreprocessedAction(BLAKE3 &Hasher) : Hasher(Hasher) {}

private:
BLAKE3 &Hasher;
};

BLAKE3 Hasher;
HashPreprocessedAction HashAction{Hasher};

if (!SYCLToolchain::instance().run(CommandLine, HashAction,
getInMemoryFS(SourceFile, IncludeFiles)))
return createStringError("Calculating source hash failed");
}

Hasher.update(CLANG_VERSION_STRING);
Hasher.update(
ArrayRef<uint8_t>{reinterpret_cast<const uint8_t *>(&Format),
reinterpret_cast<const uint8_t *>(&Format + 1)});

// Last argument is "rtc_N.cpp" source file name which is never the same,
// ignore it:
for (auto &Opt : drop_end(CommandLine, 1))
Hasher.update(Opt);

std::string EncodedHash = encodeBase64(Hasher.result());

// Make the encoding filesystem-friendly.
std::replace(EncodedHash.begin(), EncodedHash.end(), '/', '-');
return std::move(EncodedHash);
}

Expected<ModuleUPtr> jit_compiler::compileDeviceCode(
Expand Down