diff --git a/sycl/doc/extensions/proposed/sycl_ext_oneapi_syclbin.asciidoc b/sycl/doc/extensions/proposed/sycl_ext_oneapi_syclbin.asciidoc index 84d1ede6847a1..e308531453510 100644 --- a/sycl/doc/extensions/proposed/sycl_ext_oneapi_syclbin.asciidoc +++ b/sycl/doc/extensions/proposed/sycl_ext_oneapi_syclbin.asciidoc @@ -242,3 +242,142 @@ _{endnote}_] |==== +=== New free functions for linking + +This extension adds the following new free functions to link kernel bundles in +`object` state into a new kernel bundle in `executable` state. + +|==== +a| +[frame=all,grid=none] +!==== +a! +[source,c++] +---- +namespace sycl::ext::oneapi::experimental { + +template +kernel_bundle +link(const std::vector>& objectBundles, + const std::vector& devs, PropertyListT props = {}); + +} // namespace sycl::ext::oneapi::experimental +---- +!==== + +_Constraints:_ Available only when `PropertyListT` is an instance of +`sycl::ext::oneapi::experimental::properties` which contains no properties +other than those listed below in the section "New properties for the +`link` function". + +_Effects:_ Duplicate device images from `objectBundles` are eliminated as though +they were joined via `join()`, then the remaining device images are translated +into one or more new device images of state `bundle_state::executable`, and a +new kernel bundle is created to contain these new device images. The new bundle +represents all of the kernels in `objectBundles` that are compatible with at +least one of the devices in `devs`. Any remaining kernels (those that are not +compatible with any of the devices in `devs`) are not linked and not represented +in the new bundle. + +The new bundle has the same associated context as those in `objectBundles`, and +the new bundle’s set of associated devices is `devs` (with duplicate devices +removed). + +_Returns:_ The new kernel bundle. + +_Throws:_ + +* An `exception` with the `errc::invalid` error code if the bundles in +`objectBundles` do not all have the same associated context. 
+ +* An `exception` with the `errc::invalid` error code if any of the devices in +`devs` are not in the set of associated devices for any of the bundles in +`objectBundles` (as defined by `kernel_bundle::get_devices()`) or if the `devs` +vector is empty. + +* An `exception` with the `errc::build` error code if the online link operation +fails. + + +a| +[frame=all,grid=none] +!==== +a! +[source] +---- + +namespace sycl::ext::oneapi::experimental { + +template (1) +kernel_bundle +link(const kernel_bundle& objectBundle, + const std::vector& devs, PropertyListT props = {}); + +template (2) +kernel_bundle +link(const std::vector>& objectBundles, + PropertyListT props = {}); + +template (3) +kernel_bundle +link(const kernel_bundle& objectBundle, + PropertyListT props = {}); + +} // namespace sycl::ext::oneapi::experimental +---- +!==== + +_Effects (1):_ Equivalent to `link({objectBundle}, devs, props)`. + +_Effects (2):_ Equivalent to `link(objectBundles, devs, props)`, where `devs` is +the intersection of associated devices in common for all bundles in +`objectBundles`. + +_Effects (3):_ Equivalent to +`link({objectBundle}, objectBundle.get_devices(), props)`. + + +|==== + +=== New properties for the `link` function + +This extension adds the following properties, which can be used in conjunction +with the `link` function that is defined above: + +|==== +a| +[frame=all,grid=none] +!==== +a! +[source,c++] +---- +namespace sycl::ext::oneapi::experimental { + +struct fast_link { + fast_link(bool do_fast_link = true); (1) + + bool value; +}; +using fast_link_key = fast_link; + +template<> struct is_property_key : std::true_type {}; + +} // namespace sycl::ext::oneapi::experimental +---- +!==== + +This property instructs the `link` operation to do "fast linking". Enabling this +instructs the implementation to use device binary images that have been +pre-compiled. 
+ +For example, SYCLBIN files may contain ahead-of-time compiled binary images +together with just-in-time compiled binary images, with the kernels and exported +functions potentially overlapping. When fast-linking is enabled, the +implementation will try to use the ahead-of-time compiled binary images over +their just-in-time compiled counterparts. + +_Effects (1):_ Creates a new `fast_link` property with a boolean value +indicating whether the `link` operation should do fast-linking. + +|==== + diff --git a/sycl/include/sycl/ext/oneapi/experimental/syclbin_kernel_bundle.hpp b/sycl/include/sycl/ext/oneapi/experimental/syclbin_kernel_bundle.hpp index 6e33c33c0ed75..c8e0c5794de08 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/syclbin_kernel_bundle.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/syclbin_kernel_bundle.hpp @@ -8,6 +8,7 @@ #pragma once +#include #include #include @@ -24,6 +25,13 @@ namespace sycl { inline namespace _V1 { + +namespace detail { +__SYCL_EXPORT std::shared_ptr +link_impl(const std::vector> &ObjectBundles, + const std::vector &Devs, bool FastLink); +} + namespace ext::oneapi::experimental { template @@ -77,6 +85,58 @@ get_kernel_bundle(const context &Ctxt, const std::filesystem::path &Filename, } #endif +template >> +kernel_bundle +link(const std::vector> &ObjectBundles, + const std::vector &Devs, PropertyListT Props = {}) { + std::vector UniqueDevices = + sycl::detail::removeDuplicateDevices(Devs); + + bool UseFastLink = [&]() { + if constexpr (Props.template has_property()) + return Props.template get_property().value; + return false; + }(); + + sycl::detail::KernelBundleImplPtr Impl = + sycl::detail::link_impl(ObjectBundles, UniqueDevices, UseFastLink); + return detail::createSyclObjFromImpl< + kernel_bundle>(std::move(Impl)); +} + +template >> +kernel_bundle +link(const kernel_bundle &ObjectBundle, + const std::vector &Devs, PropertyListT Props = {}) { + return link(std::vector>{ObjectBundle}, + Devs, Props); +} + +template 
>> +kernel_bundle +link(const std::vector> &ObjectBundles, + PropertyListT Props = {}) { + std::vector IntersectDevices = + sycl::detail::find_device_intersection(ObjectBundles); + return link(ObjectBundles, IntersectDevices, Props); +} + +template >> +kernel_bundle +link(const kernel_bundle &ObjectBundle, + PropertyListT Props = {}) { + return link(std::vector>{ObjectBundle}, + ObjectBundle.get_devices(), Props); +} + } // namespace ext::oneapi::experimental } // namespace _V1 } // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/experimental/syclbin_properties.hpp b/sycl/include/sycl/ext/oneapi/experimental/syclbin_properties.hpp new file mode 100644 index 0000000000000..6c21a92ea719e --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/experimental/syclbin_properties.hpp @@ -0,0 +1,39 @@ +//==-------- syclbin_properties.hpp - SYCLBIN and tooling properties -------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +namespace sycl { +inline namespace _V1 { + +namespace detail { +struct link_props; +} // namespace detail + +namespace ext::oneapi::experimental { + +///////////////////////// +// PropertyT syclex::fast_link +///////////////////////// +struct fast_link + : detail::run_time_property_key { + fast_link(bool DoFastLink = true) : value(DoFastLink) {} + + bool value; +}; +using fast_link_key = fast_link; + +template <> +struct is_property_key_of + : std::true_type {}; +} // namespace ext::oneapi::experimental +} // namespace _V1 +} // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/properties/property.hpp b/sycl/include/sycl/ext/oneapi/properties/property.hpp index d68ec8884d93b..c2b790f5bc3ba 100644 --- a/sycl/include/sycl/ext/oneapi/properties/property.hpp +++ b/sycl/include/sycl/ext/oneapi/properties/property.hpp @@ -228,8 +228,9 @@ enum PropKind : uint32_t { InitialThreshold = 83, MaximumSize = 84, ZeroInit = 85, + FastLink = 86, // PropKindSize must always be the last value. 
- PropKindSize = 86, + PropKindSize = 87, }; template struct PropertyToKind { diff --git a/sycl/include/sycl/sycl.hpp b/sycl/include/sycl/sycl.hpp index a09870dd77c30..fad384683d538 100644 --- a/sycl/include/sycl/sycl.hpp +++ b/sycl/include/sycl/sycl.hpp @@ -134,6 +134,7 @@ can be disabled by setting SYCL_DISABLE_FSYCL_SYCLHPP_WARNING macro.") #include #include #include +#include #include #include #include diff --git a/sycl/source/backend.cpp b/sycl/source/backend.cpp index 4866266c0dd72..5f7ac8af1e1fb 100644 --- a/sycl/source/backend.cpp +++ b/sycl/source/backend.cpp @@ -231,7 +231,7 @@ make_kernel_bundle(ur_native_handle_t NativeHandle, case (UR_PROGRAM_BINARY_TYPE_NONE): if (State == bundle_state::object) { auto Res = Adapter.call_nocheck( - UrProgram, 1u, &Dev, nullptr); + UrProgram, 1u, &Dev, ur_exp_program_flags_t{}, nullptr); if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) { Res = Adapter.call_nocheck( ContextImpl.getHandleRef(), UrProgram, nullptr); @@ -241,7 +241,7 @@ make_kernel_bundle(ur_native_handle_t NativeHandle, else if (State == bundle_state::executable) { auto Res = Adapter.call_nocheck( - UrProgram, 1u, &Dev, nullptr); + UrProgram, 1u, &Dev, ur_exp_program_flags_t{}, nullptr); if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) { Res = Adapter.call_nocheck( ContextImpl.getHandleRef(), UrProgram, nullptr); @@ -261,8 +261,8 @@ make_kernel_bundle(ur_native_handle_t NativeHandle, Managed UrLinkedProgram{Adapter}; ur_program_handle_t ProgramsToLink[] = {UrProgram}; auto Res = Adapter.call_nocheck( - ContextImpl.getHandleRef(), 1u, &Dev, 1u, ProgramsToLink, nullptr, - &UrLinkedProgram); + ContextImpl.getHandleRef(), 1u, &Dev, ur_exp_program_flags_t{}, 1u, + ProgramsToLink, nullptr, &UrLinkedProgram); if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) { Res = Adapter.call_nocheck( ContextImpl.getHandleRef(), 1u, ProgramsToLink, nullptr, diff --git a/sycl/source/detail/device_image_impl.hpp b/sycl/source/detail/device_image_impl.hpp index 
f11782237db9b..9bd9088067dd6 100644 --- a/sycl/source/detail/device_image_impl.hpp +++ b/sycl/source/detail/device_image_impl.hpp @@ -761,7 +761,8 @@ class device_image_impl std::string XsFlags = extractXsFlags(BuildOptions, MRTCBinInfo->MLanguage); auto Res = Adapter.call_nocheck( - UrProgram, DeviceVec.size(), DeviceVec.data(), XsFlags.c_str()); + UrProgram, DeviceVec.size(), DeviceVec.data(), ur_exp_program_flags_t{}, + XsFlags.c_str()); if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) { Res = Adapter.call_nocheck( ContextImpl.getHandleRef(), UrProgram, XsFlags.c_str()); diff --git a/sycl/source/detail/kernel_bundle_impl.hpp b/sycl/source/detail/kernel_bundle_impl.hpp index 463181fd7ca5d..fa76bc0e70c1e 100644 --- a/sycl/source/detail/kernel_bundle_impl.hpp +++ b/sycl/source/detail/kernel_bundle_impl.hpp @@ -61,6 +61,72 @@ inline bool checkAllDevicesHaveAspect(devices_range Devices, aspect Aspect) { [&Aspect](device_impl &Dev) { return Dev.has(Aspect); }); } +inline LinkGraph +CreateLinkGraph(const std::vector &DevImages) { + // Create a map between exported symbols and their indices in the device + // images collection. + std::map ExportMap; + for (size_t I = 0; I < DevImages.size(); ++I) { + device_image_impl &DevImageImpl = *getSyclObjImpl(DevImages[I]); + if (DevImageImpl.get_bin_image_ref() == nullptr) + continue; + for (const sycl_device_binary_property &ESProp : + DevImageImpl.get_bin_image_ref()->getExportedSymbols()) { + if (ExportMap.find(ESProp->Name) != ExportMap.end()) + throw sycl::exception(make_error_code(errc::invalid), + "Duplicate exported symbol \"" + + std::string{ESProp->Name} + + "\" found in binaries."); + ExportMap.emplace(ESProp->Name, I); + } + } + + // Create dependency mappings. 
+ std::vector> Dependencies; + Dependencies.resize(DevImages.size()); + for (size_t I = 0; I < DevImages.size(); ++I) { + device_image_impl &DevImageImpl = *getSyclObjImpl(DevImages[I]); + if (DevImageImpl.get_bin_image_ref() == nullptr) + continue; + std::set DeviceImageDepsSet; + for (const sycl_device_binary_property &ISProp : + DevImageImpl.get_bin_image_ref()->getImportedSymbols()) { + auto ExportSymbolIt = ExportMap.find(ISProp->Name); + if (ExportSymbolIt == ExportMap.end()) + throw sycl::exception(make_error_code(errc::invalid), + "No exported symbol \"" + + std::string{ISProp->Name} + + "\" found in linked images."); + DeviceImageDepsSet.emplace(ExportSymbolIt->second); + } + Dependencies[I].insert(Dependencies[I].end(), DeviceImageDepsSet.begin(), + DeviceImageDepsSet.end()); + } + return LinkGraph{DevImages, Dependencies}; +} + +inline void +ThrowIfConflictingKernels(const std::vector &DevImages) { + std::set> SeenKernelNames; + std::set> Conflicts; + for (const device_image_plain &DevImage : DevImages) { + const KernelNameSetT &KernelNames = + getSyclObjImpl(DevImage)->getKernelNames(); + std::vector Intersect; + std::set_intersection(SeenKernelNames.begin(), SeenKernelNames.end(), + KernelNames.begin(), KernelNames.end(), + std::inserter(Conflicts, Conflicts.begin())); + SeenKernelNames.insert(KernelNames.begin(), KernelNames.end()); + } + if (Conflicts.empty()) + return; + std::stringstream MsgS; + MsgS << "Conflicting kernel definitions: "; + for (const std::string_view &Conflict : Conflicts) + MsgS << " " << Conflict; + throw sycl::exception(make_error_code(errc::invalid), MsgS.str()); +} + namespace syclex = sycl::ext::oneapi::experimental; class kernel_impl; @@ -204,7 +270,8 @@ class kernel_bundle_impl // Matches sycl::link kernel_bundle_impl( const std::vector> &ObjectBundles, - devices_range Devs, const property_list &PropList, private_tag) + devices_range Devs, const property_list &PropList, bool FastLink, + private_tag) : MDevices(Devs.to>()), 
MState(bundle_state::executable) { if (MDevices.empty()) @@ -263,97 +330,117 @@ class kernel_bundle_impl } } - // Collect all unique images. - std::vector DevImages; - { - std::set DevImagesSet; - std::unordered_set SeenBinImgs; - for (const kernel_bundle &ObjectBundle : - ObjectBundles) - for (device_image_impl &DevImg : - getSyclObjImpl(ObjectBundle)->device_images()) - if (OfflineDeviceImageSet.find(&DevImg) == - OfflineDeviceImageSet.end()) + std::map> DevImageLinkGraphs; + if (FastLink) { + // When doing fast-linking, we insert the suitable AOT binaries from the + // object bundles. This needs to be done per-device, as AOT binaries may + // not be compatible across different architectures. + for (device_impl &Dev : get_devices()) { + std::vector DevImages; + std::set DevImagesSet; + for (const kernel_bundle &ObjectBundle : + ObjectBundles) { + auto &ObjectBundleImpl = getSyclObjImpl(ObjectBundle); + + // Firstly find all suitable AOT binaries, if the object bundle was + // made from SYCLBIN. + std::vector AOTBinaries = + ObjectBundleImpl->GetSYCLBINAOTBinaries(Dev); + + // The AOT binaries need to be brought into executable state. They + // are considered unique, so they are placed directly into the unique + // images list. + DevImages.reserve(AOTBinaries.size()); + for (const detail::RTDeviceBinaryImage *Image : AOTBinaries) { + device_image_plain &AOTDevImg = + DevImages.emplace_back(device_image_impl::create( + Image, MContext, devices_range{Dev}, + bundle_state::executable, + /*KernelIDs=*/nullptr, Managed{}, + ImageOriginSYCLBIN)); + DevImgPlainWithDeps AOTDevImgWithDeps{AOTDevImg}; + ProgramManager::getInstance().bringSYCLDeviceImageToState( + AOTDevImgWithDeps, bundle_state::executable); + } + + // Record all the AOT exported symbols and kernels. 
+ std::unordered_set AOTExportedSymbols; + std::unordered_set AOTKernelNames; + for (const RTDeviceBinaryImage *AOTBin : AOTBinaries) { + for (const sycl_device_binary_property &ESProp : + AOTBin->getExportedSymbols()) + AOTExportedSymbols.insert(ESProp->Name); + for (const sycl_device_binary_property &KNProp : + AOTBin->getKernelNames()) + AOTKernelNames.insert(KNProp->Name); + } + + for (device_image_impl &DevImg : ObjectBundleImpl->device_images()) { + // If the image is the same as one of the offline images, we can + // skip it. + if (OfflineDeviceImageSet.find(&DevImg) != + OfflineDeviceImageSet.end()) + continue; + + // Skip the image if any of its exported symbols or kernels overlap + // with an AOT binary. A flag is required here: a `continue` inside + // the inner loops would not skip the image itself. + bool OverlapsAOT = false; + if (const RTDeviceBinaryImage *BinImg = DevImg.get_bin_image_ref()) { + for (const sycl_device_binary_property &ESProp : + BinImg->getExportedSymbols()) + OverlapsAOT |= AOTExportedSymbols.count(ESProp->Name) != 0; + for (const sycl_device_binary_property &KNProp : + BinImg->getKernelNames()) + OverlapsAOT |= AOTKernelNames.count(KNProp->Name) != 0; + } + if (OverlapsAOT) + continue; + DevImagesSet.insert(&DevImg); - DevImages.reserve(DevImagesSet.size()); - for (auto It = DevImagesSet.begin(); It != DevImagesSet.end();) - DevImages.push_back(createSyclObjFromImpl( - *DevImagesSet.extract(It++).value())); - } + } + } - // Check for conflicting kernels in RTC kernel bundles. 
- { - std::set> SeenKernelNames; - std::set> Conflicts; - for (const device_image_plain &DevImage : DevImages) { - const KernelNameSetT &KernelNames = - getSyclObjImpl(DevImage)->getKernelNames(); - std::vector Intersect; - std::set_intersection(SeenKernelNames.begin(), SeenKernelNames.end(), - KernelNames.begin(), KernelNames.end(), - std::inserter(Conflicts, Conflicts.begin())); - SeenKernelNames.insert(KernelNames.begin(), KernelNames.end()); - } + DevImages.reserve(DevImages.size() + DevImagesSet.size()); + for (auto It = DevImagesSet.begin(); It != DevImagesSet.end();) + DevImages.push_back(createSyclObjFromImpl( + *DevImagesSet.extract(It++).value())); - if (!Conflicts.empty()) { - std::stringstream MsgS; - MsgS << "Conflicting kernel definitions: "; - for (const std::string_view &Conflict : Conflicts) - MsgS << " " << Conflict; - throw sycl::exception(make_error_code(errc::invalid), MsgS.str()); - } - } + // Check for conflicting kernels in RTC kernel bundles. + ThrowIfConflictingKernels(DevImages); - // Create a map between exported symbols and their indices in the device - // images collection. - std::map ExportMap; - for (size_t I = 0; I < DevImages.size(); ++I) { - device_image_impl &DevImageImpl = *getSyclObjImpl(DevImages[I]); - if (DevImageImpl.get_bin_image_ref() == nullptr) - continue; - for (const sycl_device_binary_property &ESProp : - DevImageImpl.get_bin_image_ref()->getExportedSymbols()) { - if (ExportMap.find(ESProp->Name) != ExportMap.end()) - throw sycl::exception(make_error_code(errc::invalid), - "Duplicate exported symbol \"" + - std::string{ESProp->Name} + - "\" found in binaries."); - ExportMap.emplace(ESProp->Name, I); + // Create and insert the corresponding link graph. + DevImageLinkGraphs.emplace(&Dev, CreateLinkGraph(DevImages)); } - } - - // Create dependency mappings. 
- std::vector> Dependencies; - Dependencies.resize(DevImages.size()); - for (size_t I = 0; I < DevImages.size(); ++I) { - device_image_impl &DevImageImpl = *getSyclObjImpl(DevImages[I]); - if (DevImageImpl.get_bin_image_ref() == nullptr) - continue; - std::set DeviceImageDepsSet; - for (const sycl_device_binary_property &ISProp : - DevImageImpl.get_bin_image_ref()->getImportedSymbols()) { - auto ExportSymbolIt = ExportMap.find(ISProp->Name); - if (ExportSymbolIt == ExportMap.end()) - throw sycl::exception(make_error_code(errc::invalid), - "No exported symbol \"" + - std::string{ISProp->Name} + - "\" found in linked images."); - DeviceImageDepsSet.emplace(ExportSymbolIt->second); + } else { + // Collect all unique images. + std::vector DevImages; + { + std::set DevImagesSet; + for (const kernel_bundle &ObjectBundle : + ObjectBundles) + for (device_image_impl &DevImg : + getSyclObjImpl(ObjectBundle)->device_images()) + if (OfflineDeviceImageSet.find(&DevImg) == + OfflineDeviceImageSet.end()) + DevImagesSet.insert(&DevImg); + DevImages.reserve(DevImagesSet.size()); + for (auto It = DevImagesSet.begin(); It != DevImagesSet.end();) + DevImages.push_back(createSyclObjFromImpl( + *DevImagesSet.extract(It++).value())); } - Dependencies[I].insert(Dependencies[I].end(), DeviceImageDepsSet.begin(), - DeviceImageDepsSet.end()); - } - // Create a link graph and clone it for each device. - device_impl &FirstDevice = get_devices().front(); - std::map> DevImageLinkGraphs; - const auto &FirstGraph = - DevImageLinkGraphs - .emplace(&FirstDevice, - LinkGraph{DevImages, Dependencies}) - .first->second; - for (device_impl &Dev : get_devices()) - DevImageLinkGraphs.emplace(&Dev, FirstGraph.Clone()); + // Check for conflicting kernels in RTC kernel bundles. + ThrowIfConflictingKernels(DevImages); + + // Create a link graph and clone it for each device. 
+ device_impl &FirstDevice = get_devices().front(); + const auto &FirstGraph = + DevImageLinkGraphs.emplace(&FirstDevice, CreateLinkGraph(DevImages)) + .first->second; + for (device_impl &Dev : get_devices()) + DevImageLinkGraphs.emplace(&Dev, FirstGraph.Clone()); + } // Poison the images based on whether the corresponding device supports it. for (auto &GraphIt : DevImageLinkGraphs) { @@ -369,9 +456,49 @@ class kernel_bundle_impl // Link based on the resulting graphs. for (auto &GraphIt : UnifiedGraphs) { + const std::vector &GraphDevs = GraphIt.first; + std::vector GraphImgs = + GraphIt.second.GetNodeValues(); + + sycl::span JITImgs{GraphImgs}; + sycl::span AOTImgs{}; + if (FastLink) { + std::sort( + GraphImgs.begin(), GraphImgs.end(), + [](const device_image_plain &LHS, const device_image_plain &RHS) { + // Sort by state: That leaves objects (JIT) at the beginning and + // executables (AOT) at the end. + return getSyclObjImpl(LHS)->get_state() < + getSyclObjImpl(RHS)->get_state(); + }); + auto AOTImgsBegin = + std::find_if(GraphImgs.begin(), GraphImgs.end(), + [](const device_image_plain &Img) { + return getSyclObjImpl(Img)->get_state() == + bundle_state::executable; + }); + size_t NumJITImgs = std::distance(GraphImgs.begin(), AOTImgsBegin); + JITImgs = sycl::span( + GraphImgs.data(), NumJITImgs); + AOTImgs = sycl::span( + GraphImgs.data() + NumJITImgs, GraphImgs.size() - NumJITImgs); + } + + // If there are AOT binaries, the link should allow unresolved symbols. std::vector LinkedResults = detail::ProgramManager::getInstance().link( - GraphIt.second.GetNodeValues(), GraphIt.first, PropList); + JITImgs, GraphDevs, PropList, + /*AllowUnresolvedSymbols=*/!AOTImgs.empty()); + + if (!AOTImgs.empty()) { + // In dynamic linking, AOT binaries count as results as well. 
+ LinkedResults.insert(LinkedResults.end(), AOTImgs.begin(), + AOTImgs.end()); + sycl::span LinkedResultsSpan( + LinkedResults.data(), LinkedResults.size()); + detail::ProgramManager::getInstance().dynamicLink(LinkedResultsSpan); + } + MDeviceImages.insert(MDeviceImages.end(), LinkedResults.begin(), LinkedResults.end()); MUniqueDeviceImages.insert(MUniqueDeviceImages.end(), @@ -390,9 +517,13 @@ class kernel_bundle_impl })) continue; + const std::vector &AllDevImgs = + DeviceImageWithDeps->getAll(); + sycl::span AllDevImgsSpan( + AllDevImgs.data(), AllDevImgs.size()); std::vector LinkedResults = - detail::ProgramManager::getInstance().link( - DeviceImageWithDeps->getAll(), MDevices, PropList); + detail::ProgramManager::getInstance().link(AllDevImgsSpan, MDevices, + PropList); MDeviceImages.insert(MDeviceImages.end(), LinkedResults.begin(), LinkedResults.end()); MUniqueDeviceImages.insert(MUniqueDeviceImages.end(), @@ -1077,6 +1208,20 @@ class kernel_bundle_impl MUniqueDeviceImages.erase(It, MUniqueDeviceImages.end()); } + std::vector + GetSYCLBINAOTBinaries(device_impl &Dev) { + if (MSYCLBINs.size() == 1) + return MSYCLBINs[0]->getNativeBinaryImages(Dev); + + std::vector Result; + for (auto &SYCLBIN : MSYCLBINs) { + auto NativeBinImgs = SYCLBIN->getNativeBinaryImages(Dev); + Result.insert(Result.end(), NativeBinImgs.begin(), NativeBinImgs.end()); + } + + return Result; + } + context MContext; std::vector MDevices; diff --git a/sycl/source/detail/program_manager/program_manager.cpp b/sycl/source/detail/program_manager/program_manager.cpp index eab2b88d0ad34..a8be1daeb579d 100644 --- a/sycl/source/detail/program_manager/program_manager.cpp +++ b/sycl/source/detail/program_manager/program_manager.cpp @@ -1302,7 +1302,7 @@ static ur_result_t doCompile(adapter_impl &Adapter, ur_program_handle_t Program, // Try to compile with given devices, fall back to compiling with the program // context if unsupported by the adapter auto Result = Adapter.call_nocheck( - Program, 
NumDevs, Devs, Opts); + Program, NumDevs, Devs, ur_exp_program_flags_t{}, Opts); if (Result == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) { return Adapter.call_nocheck(Ctx, Program, Opts); @@ -1723,7 +1723,8 @@ Managed ProgramManager::build( ? CompileOptions : (CompileOptions + " " + LinkOptions); ur_result_t Error = Adapter.call_nocheck( - Program, Devices.size(), Devices.data(), Options.c_str()); + Program, Devices.size(), Devices.data(), ur_exp_program_flags_t{}, + Options.c_str()); if (Error == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) { Error = Adapter.call_nocheck( Context.getHandleRef(), Program, Options.c_str()); @@ -1759,8 +1760,8 @@ Managed ProgramManager::build( auto doLink = [&] { auto Res = Adapter.call_nocheck( Context.getHandleRef(), Devices.size(), Devices.data(), - LinkPrograms.size(), LinkPrograms.data(), LinkOptions.c_str(), - &LinkedProg); + ur_exp_program_flags_t{}, LinkPrograms.size(), LinkPrograms.data(), + LinkOptions.c_str(), &LinkedProg); if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) { Res = Adapter.call_nocheck( Context.getHandleRef(), LinkPrograms.size(), LinkPrograms.data(), @@ -2908,7 +2909,7 @@ ProgramManager::compile(const DevImgPlainWithDeps &ImgWithDeps, // Returns a merged device binary image, new set of kernel IDs and new // specialization constant data. 
static const RTDeviceBinaryImage * -mergeImageData(const std::vector &Imgs, +mergeImageData(sycl::span Imgs, std::vector &KernelIDs, std::vector &NewSpecConstBlob, device_image_impl::SpecConstMapT &NewSpecConstMap, @@ -2969,8 +2970,9 @@ mergeImageData(const std::vector &Imgs, } std::vector -ProgramManager::link(const std::vector &Imgs, - devices_range Devs, const property_list &PropList) { +ProgramManager::link(sycl::span Imgs, + devices_range Devs, const property_list &PropList, + bool AllowUnresolvedSymbols) { { auto NoAllowedPropertiesCheck = [](int) { return false; }; detail::PropertyValidator::checkPropsAndThrow( @@ -2997,12 +2999,16 @@ ProgramManager::link(const std::vector &Imgs, context_impl &ContextImpl = *getSyclObjImpl(Context); adapter_impl &Adapter = ContextImpl.getAdapter(); + ur_exp_program_flags_t UrLinkFlags{}; + if (AllowUnresolvedSymbols) + UrLinkFlags &= UR_EXP_PROGRAM_FLAG_ALLOW_UNRESOLVED_SYMBOLS; + Managed LinkedProg{Adapter}; auto doLink = [&] { auto Res = Adapter.call_nocheck( ContextImpl.getHandleRef(), URDevices.size(), URDevices.data(), - URPrograms.size(), URPrograms.data(), LinkOptionsStr.c_str(), - &LinkedProg); + UrLinkFlags, URPrograms.size(), URPrograms.data(), + LinkOptionsStr.c_str(), &LinkedProg); if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) { Res = Adapter.call_nocheck( ContextImpl.getHandleRef(), URPrograms.size(), URPrograms.data(), @@ -3084,6 +3090,23 @@ ProgramManager::link(const std::vector &Imgs, std::move(MergedImageStorage)))}; } +void ProgramManager::dynamicLink( + sycl::span Imgs) { + if (Imgs.empty()) + return; + + std::vector URPrograms; + URPrograms.reserve(Imgs.size()); + for (const device_image_plain &Img : Imgs) + URPrograms.push_back(getSyclObjImpl(Img)->get_ur_program()); + + device_image_impl &FirstImgImpl = *getSyclObjImpl(Imgs[0]); + context_impl &ContextImpl = *getSyclObjImpl(FirstImgImpl.get_context()); + adapter_impl &Adapter = ContextImpl.getAdapter(); + Adapter.call( + ContextImpl.getHandleRef(), 
 URPrograms.size(), URPrograms.data()); +} + // The function duplicates most of the code from existing getBuiltPIProgram. // The differences are: // Different API - uses different objects to extract required info diff --git a/sycl/source/detail/program_manager/program_manager.hpp b/sycl/source/detail/program_manager/program_manager.hpp index b9d0dc700f77c..0d7f7adbe287e 100644 --- a/sycl/source/detail/program_manager/program_manager.hpp +++ b/sycl/source/detail/program_manager/program_manager.hpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -350,8 +351,12 @@ class ProgramManager { // Produces set of device images by convering input device images to object // the executable state std::vector - link(const std::vector &Imgs, devices_range Devs, - const property_list &PropList); + link(sycl::span Imgs, + devices_range Devs, const property_list &PropList, + bool AllowUnresolvedSymbols = false); + + // Dynamically links images in executable state. + void dynamicLink(sycl::span Imgs); // Produces new device image by converting input device image to the // executable state diff --git a/sycl/source/detail/syclbin.cpp b/sycl/source/detail/syclbin.cpp index 90513f253a2cb..16ed4ace2126a 100644 --- a/sycl/source/detail/syclbin.cpp +++ b/sycl/source/detail/syclbin.cpp @@ -453,6 +453,25 @@ SYCLBINBinaries::getBestCompatibleImages(devices_range Dev, return {Images.cbegin(), Images.cend()}; } +std::vector +SYCLBINBinaries::getNativeBinaryImages(device_impl &Dev) { + std::vector Images; + for (size_t I = 0; I < getNumAbstractModules(); ++I) { + const AbstractModuleDesc &AMDesc = AbstractModuleDescriptors[I]; + // Use the first native binary that is compatible with the device, if any. 
+ + const RTDeviceBinaryImage *CompatImagePtr = std::find_if( + AMDesc.NativeBinaries, AMDesc.NativeBinaries + AMDesc.NumNativeBinaries, + [&](const RTDeviceBinaryImage &Img) { + return doesDevSupportDeviceRequirements(Dev, Img) && + doesImageTargetMatchDevice(Img, Dev); + }); + if (CompatImagePtr != AMDesc.NativeBinaries + AMDesc.NumNativeBinaries) + Images.push_back(CompatImagePtr); + } + return Images; +} + } // namespace detail } // namespace _V1 } // namespace sycl diff --git a/sycl/source/detail/syclbin.hpp b/sycl/source/detail/syclbin.hpp index 02b9ec8348c99..c9e20bbc4458a 100644 --- a/sycl/source/detail/syclbin.hpp +++ b/sycl/source/detail/syclbin.hpp @@ -128,6 +128,9 @@ struct SYCLBINBinaries { std::vector getBestCompatibleImages(devices_range Dev, bundle_state State); + std::vector + getNativeBinaryImages(device_impl &Dev); + uint8_t getState() const { PropertySet &GlobalMetadata = (*ParsedSYCLBIN diff --git a/sycl/source/kernel_bundle.cpp b/sycl/source/kernel_bundle.cpp index 1d844dd9517d6..0dcd05026e24f 100644 --- a/sycl/source/kernel_bundle.cpp +++ b/sycl/source/kernel_bundle.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -311,7 +312,15 @@ compile_impl(const kernel_bundle &InputBundle, std::shared_ptr link_impl(const std::vector> &ObjectBundles, const std::vector &Devs, const property_list &PropList) { - return detail::kernel_bundle_impl::create(ObjectBundles, Devs, PropList); + return detail::kernel_bundle_impl::create(ObjectBundles, Devs, PropList, + /*FastLink=*/false); +} + +std::shared_ptr +link_impl(const std::vector> &ObjectBundles, + const std::vector &Devs, bool FastLink) { + return detail::kernel_bundle_impl::create(ObjectBundles, Devs, + property_list{}, FastLink); } std::shared_ptr diff --git a/sycl/test-e2e/SYCLBIN/Inputs/link.hpp b/sycl/test-e2e/SYCLBIN/Inputs/link.hpp index ce14379c4e19a..729171b43df25 100644 --- a/sycl/test-e2e/SYCLBIN/Inputs/link.hpp +++ b/sycl/test-e2e/SYCLBIN/Inputs/link.hpp @@ -4,6 
 +4,12 @@ namespace syclex = sycl::ext::oneapi::experimental; +#ifdef SYCLBIN_USE_FAST_LINK +static constexpr bool USE_FAST_LINK = true; +#else +static constexpr bool USE_FAST_LINK = false; +#endif + static constexpr size_t NUM = 10; int main(int argc, char *argv[]) { @@ -34,7 +40,8 @@ int main(int argc, char *argv[]) { #endif // Link the bundles. - auto KBExe = sycl::link({KBObj1, KBObj2}); + auto KBExe = syclexp::link( + {KBObj1, KBObj2}, syclexp::properties{syclexp::fast_link{USE_FAST_LINK}}); // TestKernel1 does not have any requirements, so should be there always. assert(KBExe.ext_oneapi_has_kernel("TestKernel1")); diff --git a/sycl/test-e2e/SYCLBIN/link_object_fast_link.cpp b/sycl/test-e2e/SYCLBIN/link_object_fast_link.cpp new file mode 100644 index 0000000000000..106b0ea6ce895 --- /dev/null +++ b/sycl/test-e2e/SYCLBIN/link_object_fast_link.cpp @@ -0,0 +1,29 @@ +//==--------- link_object_fast_link.cpp --- SYCLBIN extension tests --------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// REQUIRES: aspect-usm_shared_allocations + +// -- Test for fast-linking two SYCLBIN kernel bundles in object state. 
+ +// REQUIRES: target-spir + +// RUN: %clangxx --offload-new-driver -fsycl-rdc -fsyclbin=object %{syclbin_exec_opts} -fsycl-allow-device-image-dependencies %S/Inputs/exporting_function.cpp -o %t.export_w_aot.syclbin +// RUN: %clangxx --offload-new-driver -fsycl-rdc -fsyclbin=object %{syclbin_exec_opts} -fsycl-allow-device-image-dependencies %S/Inputs/importing_kernel.cpp -o %t.import_w_aot.syclbin +// RUN: %clangxx --offload-new-driver -fsyclbin=object -fsycl-allow-device-image-dependencies %S/Inputs/exporting_function.cpp -o %t.export.syclbin +// RUN: %clangxx --offload-new-driver -fsyclbin=object -fsycl-allow-device-image-dependencies %S/Inputs/importing_kernel.cpp -o %t.import.syclbin +// RUN: %{build} -o %t.out + +// RUN: %{run} %t.out %t.export.syclbin %t.import.syclbin +// RUN: %{run} %t.out %t.export_w_aot.syclbin %t.import_w_aot.syclbin +// RUN: %{run} %t.out %t.export.syclbin %t.import_w_aot.syclbin +// RUN: %{run} %t.out %t.export_w_aot.syclbin %t.import.syclbin + +#define SYCLBIN_OBJECT_STATE +#define SYCLBIN_USE_FAST_LINK + +#include "Inputs/link.hpp" diff --git a/sycl/test/abi/sycl_symbols_linux.dump b/sycl/test/abi/sycl_symbols_linux.dump index ec2a21ef34424..d8823a8dadd5a 100644 --- a/sycl/test/abi/sycl_symbols_linux.dump +++ b/sycl/test/abi/sycl_symbols_linux.dump @@ -3384,6 +3384,7 @@ _ZN4sycl3_V16detail6OSUtil12getOSMemSizeEv _ZN4sycl3_V16detail6OSUtil16getCurrentDSODirB5cxx11Ev _ZN4sycl3_V16detail6OSUtil7makeDirEPKc _ZN4sycl3_V16detail9join_implERKSt6vectorISt10shared_ptrINS1_18kernel_bundle_implEESaIS5_EENS0_12bundle_stateE +_ZN4sycl3_V16detail9link_implERKSt6vectorINS0_13kernel_bundleILNS0_12bundle_stateE1EEESaIS5_EERKS2_INS0_6deviceESaISA_EEb _ZN4sycl3_V16detail9link_implERKSt6vectorINS0_13kernel_bundleILNS0_12bundle_stateE1EEESaIS5_EERKS2_INS0_6deviceESaISA_EERKNS0_13property_listE _ZN4sycl3_V16detail9modf_implENS1_9half_impl4halfEPS3_ _ZN4sycl3_V16detail9modf_implEdPd diff --git a/sycl/test/abi/sycl_symbols_windows.dump 
b/sycl/test/abi/sycl_symbols_windows.dump index 8eed8b8dba437..f743a351d5061 100644 --- a/sycl/test/abi/sycl_symbols_windows.dump +++ b/sycl/test/abi/sycl_symbols_windows.dump @@ -4292,6 +4292,7 @@ ?lgamma_r_impl@detail@_V1@sycl@@YAMMPEAH@Z ?lgamma_r_impl@detail@_V1@sycl@@YANNPEAH@Z ?link_impl@detail@_V1@sycl@@YA?AV?$shared_ptr@Vkernel_bundle_impl@detail@_V1@sycl@@@std@@AEBV?$vector@V?$kernel_bundle@$00@_V1@sycl@@V?$allocator@V?$kernel_bundle@$00@_V1@sycl@@@std@@@5@AEBV?$vector@Vdevice@_V1@sycl@@V?$allocator@Vdevice@_V1@sycl@@@std@@@5@AEBVproperty_list@23@@Z +?link_impl@detail@_V1@sycl@@YA?AV?$shared_ptr@Vkernel_bundle_impl@detail@_V1@sycl@@@std@@AEBV?$vector@V?$kernel_bundle@$00@_V1@sycl@@V?$allocator@V?$kernel_bundle@$00@_V1@sycl@@@std@@@5@AEBV?$vector@Vdevice@_V1@sycl@@V?$allocator@Vdevice@_V1@sycl@@@std@@@5@_N@Z ?makeDir@OSUtil@detail@_V1@sycl@@SAHPEBD@Z ?make_context@detail@_V1@sycl@@YA?AVcontext@23@_KAEBV?$function@$$A6AXVexception_list@_V1@sycl@@@Z@std@@W4backend@23@_NAEBV?$vector@Vdevice@_V1@sycl@@V?$allocator@Vdevice@_V1@sycl@@@std@@@6@@Z ?make_device@detail@_V1@sycl@@YA?AVdevice@23@_KW4backend@23@@Z diff --git a/unified-runtime/include/ur_api.h b/unified-runtime/include/ur_api.h index 736060f3ab721..1491c73b37c65 100644 --- a/unified-runtime/include/ur_api.h +++ b/unified-runtime/include/ur_api.h @@ -475,6 +475,8 @@ typedef enum ur_function_t { UR_FUNCTION_MEMORY_EXPORT_EXPORT_MEMORY_HANDLE_EXP = 287, /// Enumerator for ::urBindlessImagesSupportsImportingHandleTypeExp UR_FUNCTION_BINDLESS_IMAGES_SUPPORTS_IMPORTING_HANDLE_TYPE_EXP = 288, + /// Enumerator for ::urProgramDynamicLinkExp + UR_FUNCTION_PROGRAM_DYNAMIC_LINK_EXP = 289, /// @cond UR_FUNCTION_FORCE_UINT32 = 0x7fffffff /// @endcond @@ -2444,6 +2446,9 @@ typedef enum ur_device_info_t { /// [::ur_bool_t] Returns true if the device supports the multi device /// compile experimental feature. 
UR_DEVICE_INFO_MULTI_DEVICE_COMPILE_SUPPORT_EXP = 0x6000, + /// [::ur_bool_t] Returns true if the device supports the dynamic linking + /// experimental feature. + UR_DEVICE_INFO_DYNAMIC_LINK_SUPPORT_EXP = 0x6001, /// [::ur_bool_t] returns true if the device supports /// ::urUSMContextMemcpyExp UR_DEVICE_INFO_USM_CONTEXT_MEMCPY_SUPPORT_EXP = 0x7000, @@ -12343,6 +12348,47 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferGetNativeHandleExp( /// [out] A pointer to the native handle of the command-buffer. ur_native_handle_t *phNativeCommandBuffer); +#if !defined(__GNUC__) +#pragma endregion +#endif +// Intel 'oneAPI' Unified Runtime Experimental APIs for dynamic linking +#if !defined(__GNUC__) +#pragma region dynamic_link_(experimental) +#endif +/////////////////////////////////////////////////////////////////////////////// +/// @brief Creates dynamic links between exported and imported symbols in one or +/// more programs. +/// +/// @details +/// - The application may call this function from simultaneous threads. +/// - Following a successful call to this entry point the programs in +/// `phPrograms` will have all external symbols resolved and kernels +/// inside these programs would be ready for use. +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hContext` +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// + `NULL == phPrograms` +/// - ::UR_RESULT_ERROR_INVALID_PROGRAM +/// + If one of the programs in `phPrograms` isn't a valid program +/// object. +/// - ::UR_RESULT_ERROR_INVALID_SIZE +/// + `count == 0` +/// - ::UR_RESULT_ERROR_PROGRAM_LINK_FAILURE +/// + If an error occurred while linking `phPrograms`. +UR_APIEXPORT ur_result_t UR_APICALL urProgramDynamicLinkExp( + /// [in] handle of the context instance. 
+ ur_context_handle_t hContext, + /// [in] number of program handles in `phPrograms`. + uint32_t count, + /// [in][range(0, count)] pointer to array of program handles. + const ur_program_handle_t *phPrograms); + #if !defined(__GNUC__) #pragma endregion #endif @@ -12520,6 +12566,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemoryExportExportMemoryHandleExp( #if !defined(__GNUC__) #pragma region multi_device_compile_(experimental) #endif +/////////////////////////////////////////////////////////////////////////////// +/// @brief Program operation behavior control flags +typedef uint32_t ur_exp_program_flags_t; +typedef enum ur_exp_program_flag_t { + /// Allow unresolved symbols in the program resulting from the + /// corresponding operation + UR_EXP_PROGRAM_FLAG_ALLOW_UNRESOLVED_SYMBOLS = UR_BIT(0), + /// @cond + UR_EXP_PROGRAM_FLAG_FORCE_UINT32 = 0x7fffffff + /// @endcond + +} ur_exp_program_flag_t; +/// @brief Bit Mask for validating ur_exp_program_flags_t +#define UR_EXP_PROGRAM_FLAGS_MASK 0xfffffffe + /////////////////////////////////////////////////////////////////////////////// /// @brief Produces an executable program from one program, negates need for the /// linking step. @@ -12543,6 +12604,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemoryExportExportMemoryHandleExp( /// + `NULL == hProgram` /// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER /// + `NULL == phDevices` +/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION +/// + `::UR_EXP_PROGRAM_FLAGS_MASK & flags` /// - ::UR_RESULT_ERROR_INVALID_PROGRAM /// + If `hProgram` isn't a valid program object. /// - ::UR_RESULT_ERROR_PROGRAM_BUILD_FAILURE @@ -12554,6 +12617,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in][optional] pointer to build options null-terminated string. 
const char *pOptions); @@ -12579,6 +12644,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp( /// + `NULL == hProgram` /// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER /// + `NULL == phDevices` +/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION +/// + `::UR_EXP_PROGRAM_FLAGS_MASK & flags` /// - ::UR_RESULT_ERROR_INVALID_PROGRAM /// + If `hProgram` isn't a valid program object. /// - ::UR_RESULT_ERROR_PROGRAM_BUILD_FAILURE @@ -12590,6 +12657,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in][optional] pointer to build options null-terminated string. const char *pOptions); @@ -12623,6 +12692,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp( /// + `NULL == phDevices` /// + `NULL == phPrograms` /// + `NULL == phProgram` +/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION +/// + `::UR_EXP_PROGRAM_FLAGS_MASK & flags` /// - ::UR_RESULT_ERROR_INVALID_PROGRAM /// + If one of the programs in `phPrograms` isn't a valid program /// object. @@ -12637,6 +12708,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in] number of program handles in `phPrograms`. uint32_t count, /// [in][range(0, count)] pointer to array of program handles. 
@@ -13464,6 +13537,16 @@ typedef struct ur_program_build_params_t { const char **ppOptions; } ur_program_build_params_t; +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function parameters for urProgramDynamicLinkExp +/// @details Each entry is a pointer to the parameter passed to the function; +/// allowing the callback the ability to modify the parameter's value +typedef struct ur_program_dynamic_link_exp_params_t { + ur_context_handle_t *phContext; + uint32_t *pcount; + const ur_program_handle_t **pphPrograms; +} ur_program_dynamic_link_exp_params_t; + /////////////////////////////////////////////////////////////////////////////// /// @brief Function parameters for urProgramBuildExp /// @details Each entry is a pointer to the parameter passed to the function; @@ -13472,6 +13555,7 @@ typedef struct ur_program_build_exp_params_t { ur_program_handle_t *phProgram; uint32_t *pnumDevices; ur_device_handle_t **pphDevices; + ur_exp_program_flags_t *pflags; const char **ppOptions; } ur_program_build_exp_params_t; @@ -13493,6 +13577,7 @@ typedef struct ur_program_compile_exp_params_t { ur_program_handle_t *phProgram; uint32_t *pnumDevices; ur_device_handle_t **pphDevices; + ur_exp_program_flags_t *pflags; const char **ppOptions; } ur_program_compile_exp_params_t; @@ -13516,6 +13601,7 @@ typedef struct ur_program_link_exp_params_t { ur_context_handle_t *phContext; uint32_t *pnumDevices; ur_device_handle_t **pphDevices; + ur_exp_program_flags_t *pflags; uint32_t *pcount; const ur_program_handle_t **pphPrograms; const char **ppOptions; diff --git a/unified-runtime/include/ur_api_funcs.def b/unified-runtime/include/ur_api_funcs.def index f0c92445b9238..c7d8cf3f39d45 100644 --- a/unified-runtime/include/ur_api_funcs.def +++ b/unified-runtime/include/ur_api_funcs.def @@ -59,6 +59,7 @@ _UR_API(urProgramGetBuildInfo) _UR_API(urProgramSetSpecializationConstants) _UR_API(urProgramGetNativeHandle) _UR_API(urProgramCreateWithNativeHandle) 
+_UR_API(urProgramDynamicLinkExp) _UR_API(urProgramBuildExp) _UR_API(urProgramCompileExp) _UR_API(urProgramLinkExp) diff --git a/unified-runtime/include/ur_ddi.h b/unified-runtime/include/ur_ddi.h index f59e15a9eb3cd..591229f0590af 100644 --- a/unified-runtime/include/ur_ddi.h +++ b/unified-runtime/include/ur_ddi.h @@ -418,29 +418,35 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramProcAddrTable( typedef ur_result_t(UR_APICALL *ur_pfnGetProgramProcAddrTable_t)( ur_api_version_t, ur_program_dditable_t *); +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function-pointer for urProgramDynamicLinkExp +typedef ur_result_t(UR_APICALL *ur_pfnProgramDynamicLinkExp_t)( + ur_context_handle_t, uint32_t, const ur_program_handle_t *); + /////////////////////////////////////////////////////////////////////////////// /// @brief Function-pointer for urProgramBuildExp typedef ur_result_t(UR_APICALL *ur_pfnProgramBuildExp_t)(ur_program_handle_t, uint32_t, ur_device_handle_t *, + ur_exp_program_flags_t, const char *); /////////////////////////////////////////////////////////////////////////////// /// @brief Function-pointer for urProgramCompileExp -typedef ur_result_t(UR_APICALL *ur_pfnProgramCompileExp_t)(ur_program_handle_t, - uint32_t, - ur_device_handle_t *, - const char *); +typedef ur_result_t(UR_APICALL *ur_pfnProgramCompileExp_t)( + ur_program_handle_t, uint32_t, ur_device_handle_t *, ur_exp_program_flags_t, + const char *); /////////////////////////////////////////////////////////////////////////////// /// @brief Function-pointer for urProgramLinkExp typedef ur_result_t(UR_APICALL *ur_pfnProgramLinkExp_t)( - ur_context_handle_t, uint32_t, ur_device_handle_t *, uint32_t, - const ur_program_handle_t *, const char *, ur_program_handle_t *); + ur_context_handle_t, uint32_t, ur_device_handle_t *, ur_exp_program_flags_t, + uint32_t, const ur_program_handle_t *, const char *, ur_program_handle_t *); 
/////////////////////////////////////////////////////////////////////////////// /// @brief Table of ProgramExp functions pointers typedef struct ur_program_exp_dditable_t { + ur_pfnProgramDynamicLinkExp_t pfnDynamicLinkExp; ur_pfnProgramBuildExp_t pfnBuildExp; ur_pfnProgramCompileExp_t pfnCompileExp; ur_pfnProgramLinkExp_t pfnLinkExp; diff --git a/unified-runtime/include/ur_print.h b/unified-runtime/include/ur_print.h index 7c4528ec9ea81..9ce9e9a124945 100644 --- a/unified-runtime/include/ur_print.h +++ b/unified-runtime/include/ur_print.h @@ -1415,6 +1415,16 @@ urPrintExpCommandBufferUpdateKernelLaunchDesc( const struct ur_exp_command_buffer_update_kernel_launch_desc_t params, char *buffer, const size_t buff_size, size_t *out_size); +/////////////////////////////////////////////////////////////////////////////// +/// @brief Print ur_exp_program_flag_t enum +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_SIZE +/// - `buff_size < out_size` +UR_APIEXPORT ur_result_t UR_APICALL +urPrintExpProgramFlags(enum ur_exp_program_flag_t value, char *buffer, + const size_t buff_size, size_t *out_size); + /////////////////////////////////////////////////////////////////////////////// /// @brief Print ur_exp_peer_info_t enum /// @returns @@ -1846,6 +1856,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urPrintProgramBuildParams( const struct ur_program_build_params_t *params, char *buffer, const size_t buff_size, size_t *out_size); +/////////////////////////////////////////////////////////////////////////////// +/// @brief Print ur_program_dynamic_link_exp_params_t struct +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_SIZE +/// - `buff_size < out_size` +UR_APIEXPORT ur_result_t UR_APICALL urPrintProgramDynamicLinkExpParams( + const struct ur_program_dynamic_link_exp_params_t *params, char *buffer, + const size_t buff_size, size_t *out_size); + /////////////////////////////////////////////////////////////////////////////// /// @brief 
Print ur_program_build_exp_params_t struct /// @returns diff --git a/unified-runtime/include/ur_print.hpp b/unified-runtime/include/ur_print.hpp index 9d7bef35061fd..7c2440920eb36 100644 --- a/unified-runtime/include/ur_print.hpp +++ b/unified-runtime/include/ur_print.hpp @@ -259,6 +259,10 @@ inline ur_result_t printTagged(std::ostream &os, const void *ptr, ur_exp_command_buffer_command_info_t value, size_t size); +template <> +inline ur_result_t printFlag(std::ostream &os, + uint32_t flag); + template <> inline ur_result_t printTagged(std::ostream &os, const void *ptr, ur_exp_peer_info_t value, size_t size); @@ -592,6 +596,8 @@ inline std::ostream &operator<<( inline std::ostream & operator<<(std::ostream &os, [[maybe_unused]] const struct ur_exp_command_buffer_update_kernel_launch_desc_t params); +inline std::ostream &operator<<(std::ostream &os, + enum ur_exp_program_flag_t value); inline std::ostream &operator<<(std::ostream &os, enum ur_exp_peer_info_t value); inline std::ostream &operator<<(std::ostream &os, @@ -1276,6 +1282,9 @@ inline std::ostream &operator<<(std::ostream &os, enum ur_function_t value) { case UR_FUNCTION_BINDLESS_IMAGES_SUPPORTS_IMPORTING_HANDLE_TYPE_EXP: os << "UR_FUNCTION_BINDLESS_IMAGES_SUPPORTS_IMPORTING_HANDLE_TYPE_EXP"; break; + case UR_FUNCTION_PROGRAM_DYNAMIC_LINK_EXP: + os << "UR_FUNCTION_PROGRAM_DYNAMIC_LINK_EXP"; + break; default: os << "unknown enumerator"; break; @@ -3136,6 +3145,9 @@ inline std::ostream &operator<<(std::ostream &os, enum ur_device_info_t value) { case UR_DEVICE_INFO_MULTI_DEVICE_COMPILE_SUPPORT_EXP: os << "UR_DEVICE_INFO_MULTI_DEVICE_COMPILE_SUPPORT_EXP"; break; + case UR_DEVICE_INFO_DYNAMIC_LINK_SUPPORT_EXP: + os << "UR_DEVICE_INFO_DYNAMIC_LINK_SUPPORT_EXP"; + break; case UR_DEVICE_INFO_USM_CONTEXT_MEMCPY_SUPPORT_EXP: os << "UR_DEVICE_INFO_USM_CONTEXT_MEMCPY_SUPPORT_EXP"; break; @@ -5333,6 +5345,19 @@ inline ur_result_t printTagged(std::ostream &os, const void *ptr, os << ")"; } break; + case 
UR_DEVICE_INFO_DYNAMIC_LINK_SUPPORT_EXP: { + const ur_bool_t *tptr = (const ur_bool_t *)ptr; + if (sizeof(ur_bool_t) > size) { + os << "invalid size (is: " << size + << ", expected: >=" << sizeof(ur_bool_t) << ")"; + return UR_RESULT_ERROR_INVALID_SIZE; + } + os << (const void *)(tptr) << " ("; + + os << *tptr; + + os << ")"; + } break; case UR_DEVICE_INFO_USM_CONTEXT_MEMCPY_SUPPORT_EXP: { const ur_bool_t *tptr = (const ur_bool_t *)ptr; if (sizeof(ur_bool_t) > size) { @@ -12299,6 +12324,54 @@ inline std::ostream &operator<<( return os; } /////////////////////////////////////////////////////////////////////////////// +/// @brief Print operator for the ur_exp_program_flag_t type +/// @returns +/// std::ostream & +inline std::ostream &operator<<(std::ostream &os, + enum ur_exp_program_flag_t value) { + switch (value) { + case UR_EXP_PROGRAM_FLAG_ALLOW_UNRESOLVED_SYMBOLS: + os << "UR_EXP_PROGRAM_FLAG_ALLOW_UNRESOLVED_SYMBOLS"; + break; + default: + os << "unknown enumerator"; + break; + } + return os; +} + +namespace ur::details { +/////////////////////////////////////////////////////////////////////////////// +/// @brief Print ur_exp_program_flag_t flag +template <> +inline ur_result_t printFlag(std::ostream &os, + uint32_t flag) { + uint32_t val = flag; + bool first = true; + + if ((val & UR_EXP_PROGRAM_FLAG_ALLOW_UNRESOLVED_SYMBOLS) == + (uint32_t)UR_EXP_PROGRAM_FLAG_ALLOW_UNRESOLVED_SYMBOLS) { + val ^= (uint32_t)UR_EXP_PROGRAM_FLAG_ALLOW_UNRESOLVED_SYMBOLS; + if (!first) { + os << " | "; + } else { + first = false; + } + os << UR_EXP_PROGRAM_FLAG_ALLOW_UNRESOLVED_SYMBOLS; + } + if (val != 0) { + std::bitset<32> bits(val); + if (!first) { + os << " | "; + } + os << "unknown bit flags " << bits; + } else if (first) { + os << "0"; + } + return UR_RESULT_SUCCESS; +} +} // namespace ur::details +/////////////////////////////////////////////////////////////////////////////// /// @brief Print operator for the ur_exp_peer_info_t type /// @returns /// std::ostream & @@ 
-13579,6 +13652,43 @@ operator<<(std::ostream &os, return os; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Print operator for the ur_program_dynamic_link_exp_params_t type +/// @returns +/// std::ostream & +inline std::ostream & +operator<<(std::ostream &os, + [[maybe_unused]] const struct ur_program_dynamic_link_exp_params_t + *params) { + + os << ".hContext = "; + + ur::details::printPtr(os, *(params->phContext)); + + os << ", "; + os << ".count = "; + + os << *(params->pcount); + + os << ", "; + os << ".phPrograms = "; + ur::details::printPtr(os, + reinterpret_cast(*(params->pphPrograms))); + if (*(params->pphPrograms) != NULL) { + os << " {"; + for (size_t i = 0; i < *params->pcount; ++i) { + if (i != 0) { + os << ", "; + } + + ur::details::printPtr(os, (*(params->pphPrograms))[i]); + } + os << "}"; + } + + return os; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Print operator for the ur_program_build_exp_params_t type /// @returns @@ -13612,6 +13722,11 @@ inline std::ostream &operator<<( os << "}"; } + os << ", "; + os << ".flags = "; + + ur::details::printFlag(os, *(params->pflags)); + os << ", "; os << ".pOptions = "; @@ -13678,6 +13793,11 @@ inline std::ostream &operator<<( os << "}"; } + os << ", "; + os << ".flags = "; + + ur::details::printFlag(os, *(params->pflags)); + os << ", "; os << ".pOptions = "; @@ -13765,6 +13885,11 @@ operator<<(std::ostream &os, os << "}"; } + os << ", "; + os << ".flags = "; + + ur::details::printFlag(os, *(params->pflags)); + os << ", "; os << ".count = "; @@ -21277,6 +21402,9 @@ inline ur_result_t UR_APICALL printFunctionParams(std::ostream &os, case UR_FUNCTION_PROGRAM_BUILD: { os << (const struct ur_program_build_params_t *)params; } break; + case UR_FUNCTION_PROGRAM_DYNAMIC_LINK_EXP: { + os << (const struct ur_program_dynamic_link_exp_params_t *)params; + } break; case UR_FUNCTION_PROGRAM_BUILD_EXP: { os << (const struct 
ur_program_build_exp_params_t *)params; } break; diff --git a/unified-runtime/scripts/core/EXP-DYNAMIC-LINK.rst b/unified-runtime/scripts/core/EXP-DYNAMIC-LINK.rst new file mode 100644 index 0000000000000..71f11ae33dc2b --- /dev/null +++ b/unified-runtime/scripts/core/EXP-DYNAMIC-LINK.rst @@ -0,0 +1,65 @@ +<% + OneApi=tags['$OneApi'] + x=tags['$x'] + X=x.upper() +%> + +.. _experimental-dynamic-link: + +================================================================================ +Dynamic Link +================================================================================ + +.. warning:: + + Experimental features: + + * May be replaced, updated, or removed at any time. + * Do not require maintaining API/ABI stability of their own additions over + time. + * Do not require conformance testing of their own additions. + + + +Motivation +-------------------------------------------------------------------------------- + +Some adapters support the ability to do dynamic linking between programs, +resolving external symbols through that. This may allow AOT compiled binaries to +be linked, despite already having been built.
+ +API +-------------------------------------------------------------------------------- + +Enums +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +* ${x}_device_info_t + * ${X}_DEVICE_INFO_DYNAMIC_LINK_SUPPORT_EXP + +Functions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +* ${x}ProgramDynamicLinkExp + +Changelog +-------------------------------------------------------------------------------- + ++-----------+---------------------------------------------+ +| Revision | Changes | ++===========+=============================================+ +| 1.0 | Initial Draft | ++-----------+---------------------------------------------+ + +Support +-------------------------------------------------------------------------------- + +Adapters which support this experimental feature *must* return ``true`` when +queried for ${X}_DEVICE_INFO_DYNAMIC_LINK_SUPPORT_EXP via +${x}DeviceGetInfo. Conversely, before using any of the functionality defined +in this experimental feature the user *must* use the device query to determine +if the adapter supports this feature. 
+ +Contributors +-------------------------------------------------------------------------------- + +* Steffen Larsen `steffen.larsen@intel.com <steffen.larsen@intel.com>`_ diff --git a/unified-runtime/scripts/core/EXP-MULTI-DEVICE-COMPILE.rst b/unified-runtime/scripts/core/EXP-MULTI-DEVICE-COMPILE.rst index 19dc26268b318..9f1bb7e912fbb 100644 --- a/unified-runtime/scripts/core/EXP-MULTI-DEVICE-COMPILE.rst +++ b/unified-runtime/scripts/core/EXP-MULTI-DEVICE-COMPILE.rst @@ -36,6 +36,8 @@ Enums ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ * ${x}_device_info_t * ${X}_DEVICE_INFO_MULTI_DEVICE_COMPILE_SUPPORT_EXP +* ${x}_exp_program_flags_t + * ${X}_EXP_PROGRAM_FLAG_ALLOW_UNRESOLVED_SYMBOLS Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/unified-runtime/scripts/core/exp-dynamic-link.yml b/unified-runtime/scripts/core/exp-dynamic-link.yml new file mode 100644 index 0000000000000..5c5cdba6686e3 --- /dev/null +++ b/unified-runtime/scripts/core/exp-dynamic-link.yml @@ -0,0 +1,50 @@ +# +# Copyright (C) 2023 Intel Corporation +# +# Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. +# See LICENSE.TXT +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +# See YaML.md for syntax definition +# +--- #-------------------------------------------------------------------------- +type: header +desc: "Intel $OneApi Unified Runtime Experimental APIs for dynamic linking" +ordinal: "99" +--- #-------------------------------------------------------------------------- +type: enum +extend: true +typed_etors: true +desc: "Extension enums for $x_device_info_t to support dynamic linking." +name: $x_device_info_t +etors: + - name: DYNAMIC_LINK_SUPPORT_EXP + value: "0x6001" + desc: "[$x_bool_t] Returns true if the device supports the dynamic linking experimental feature."
+--- #-------------------------------------------------------------------------- +type: function +desc: "Creates dynamic links between exported and imported symbols in one or more programs." +class: $xProgram +name: DynamicLinkExp +decl: static +ordinal: "2" +details: + - "The application may call this function from simultaneous threads." + - "Following a successful call to this entry point the programs in `phPrograms` will have all external symbols resolved and kernels inside these programs would be ready for use." +params: + - type: $x_context_handle_t + name: hContext + desc: "[in] handle of the context instance." + - type: uint32_t + name: count + desc: "[in] number of program handles in `phPrograms`." + - type: const $x_program_handle_t* + name: phPrograms + desc: "[in][range(0, count)] pointer to array of program handles." +returns: + - $X_RESULT_ERROR_INVALID_PROGRAM: + - "If one of the programs in `phPrograms` isn't a valid program object." + - $X_RESULT_ERROR_INVALID_SIZE: + - "`count == 0`" + - $X_RESULT_ERROR_PROGRAM_LINK_FAILURE: + - "If an error occurred while linking `phPrograms`." diff --git a/unified-runtime/scripts/core/exp-multi-device-compile.yml b/unified-runtime/scripts/core/exp-multi-device-compile.yml index 06309cb019dfc..39bdaa407cfa8 100644 --- a/unified-runtime/scripts/core/exp-multi-device-compile.yml +++ b/unified-runtime/scripts/core/exp-multi-device-compile.yml @@ -22,6 +22,15 @@ etors: value: "0x6000" desc: "[$x_bool_t] Returns true if the device supports the multi device compile experimental feature." 
--- #-------------------------------------------------------------------------- +type: enum +desc: "Program operation behavior control flags" +class: $xProgram +name: $x_exp_program_flags_t +etors: + - name: ALLOW_UNRESOLVED_SYMBOLS + desc: "Allow unresolved symbols in the program resulting from the corresponding operation" + value: "$X_BIT(0)" +--- #-------------------------------------------------------------------------- type: function desc: "Produces an executable program from one program, negates need for the linking step." class: $xProgram @@ -43,6 +52,9 @@ params: - type: $x_device_handle_t* name: phDevices desc: "[in][range(0, numDevices)] pointer to array of device handles" + - type: $x_exp_program_flags_t + name: flags + desc: "[in] program information flags" - type: const char* name: pOptions desc: "[in][optional] pointer to build options null-terminated string." @@ -74,6 +86,9 @@ params: - type: $x_device_handle_t* name: phDevices desc: "[in][range(0, numDevices)] pointer to array of device handles" + - type: $x_exp_program_flags_t + name: flags + desc: "[in] program information flags" - type: const char* name: pOptions desc: "[in][optional] pointer to build options null-terminated string." @@ -106,6 +121,9 @@ params: - type: $x_device_handle_t* name: phDevices desc: "[in][range(0, numDevices)] pointer to array of device handles" + - type: $x_exp_program_flags_t + name: flags + desc: "[in] program information flags" - type: uint32_t name: count desc: "[in] number of program handles in `phPrograms`." 
diff --git a/unified-runtime/scripts/core/registry.yml b/unified-runtime/scripts/core/registry.yml index a6237d93bf5ce..13d89f2652359 100644 --- a/unified-runtime/scripts/core/registry.yml +++ b/unified-runtime/scripts/core/registry.yml @@ -670,6 +670,9 @@ etors: - name: BINDLESS_IMAGES_SUPPORTS_IMPORTING_HANDLE_TYPE_EXP desc: Enumerator for $xBindlessImagesSupportsImportingHandleTypeExp value: '288' +- name: PROGRAM_DYNAMIC_LINK_EXP + desc: Enumerator for $xProgramDynamicLinkExp + value: '289' --- type: enum desc: Defines structure types diff --git a/unified-runtime/source/adapters/cuda/device.cpp b/unified-runtime/source/adapters/cuda/device.cpp index 03d9a13999f84..e500ad3546c06 100644 --- a/unified-runtime/source/adapters/cuda/device.cpp +++ b/unified-runtime/source/adapters/cuda/device.cpp @@ -1171,6 +1171,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, return ReturnValue(true); case UR_DEVICE_INFO_MULTI_DEVICE_COMPILE_SUPPORT_EXP: return ReturnValue(false); + case UR_DEVICE_INFO_DYNAMIC_LINK_SUPPORT_EXP: + return ReturnValue(false); case UR_DEVICE_INFO_ASYNC_USM_ALLOCATIONS_SUPPORT_EXP: return ReturnValue(true); case UR_DEVICE_INFO_KERNEL_LAUNCH_CAPABILITIES: { diff --git a/unified-runtime/source/adapters/cuda/program.cpp b/unified-runtime/source/adapters/cuda/program.cpp index 15afa125a80a2..9cea38ab02e0c 100644 --- a/unified-runtime/source/adapters/cuda/program.cpp +++ b/unified-runtime/source/adapters/cuda/program.cpp @@ -201,6 +201,7 @@ urProgramCompile(ur_context_handle_t hContext, ur_program_handle_t hProgram, UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(ur_program_handle_t, uint32_t, ur_device_handle_t *, + ur_exp_program_flags_t, const char *) { return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } @@ -208,6 +209,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(ur_program_handle_t, UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_program_handle_t, uint32_t, ur_device_handle_t *, + 
ur_exp_program_flags_t, const char *) { return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } @@ -229,15 +231,21 @@ urProgramBuild(ur_context_handle_t /*hContext*/, ur_program_handle_t hProgram, return UR_RESULT_SUCCESS; } -UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp( - ur_context_handle_t, uint32_t, ur_device_handle_t *, uint32_t, - const ur_program_handle_t *, const char *, ur_program_handle_t *phProgram) { +UR_APIEXPORT ur_result_t UR_APICALL +urProgramLinkExp(ur_context_handle_t, uint32_t, ur_device_handle_t *, + ur_exp_program_flags_t, uint32_t, const ur_program_handle_t *, + const char *, ur_program_handle_t *phProgram) { if (nullptr != phProgram) { *phProgram = nullptr; } return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } +UR_APIEXPORT ur_result_t UR_APICALL urProgramDynamicLinkExp( + ur_context_handle_t, uint32_t, const ur_program_handle_t *) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; +} + /// Creates a new UR program object that is the outcome of linking all input /// programs. 
/// \TODO Implement linker options, requires mapping of OpenCL to CUDA diff --git a/unified-runtime/source/adapters/cuda/ur_interface_loader.cpp b/unified-runtime/source/adapters/cuda/ur_interface_loader.cpp index 8430df0ab0678..3fcc2ce47f41b 100644 --- a/unified-runtime/source/adapters/cuda/ur_interface_loader.cpp +++ b/unified-runtime/source/adapters/cuda/ur_interface_loader.cpp @@ -468,6 +468,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramExpProcAddrTable( pDdiTable->pfnBuildExp = urProgramBuildExp; pDdiTable->pfnCompileExp = urProgramCompileExp; pDdiTable->pfnLinkExp = urProgramLinkExp; + pDdiTable->pfnDynamicLinkExp = urProgramDynamicLinkExp; return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/hip/device.cpp b/unified-runtime/source/adapters/hip/device.cpp index c48033ec88826..fa9c81a819134 100644 --- a/unified-runtime/source/adapters/hip/device.cpp +++ b/unified-runtime/source/adapters/hip/device.cpp @@ -1041,6 +1041,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, return ReturnValue(true); case UR_DEVICE_INFO_MULTI_DEVICE_COMPILE_SUPPORT_EXP: return ReturnValue(false); + case UR_DEVICE_INFO_DYNAMIC_LINK_SUPPORT_EXP: + return ReturnValue(false); case UR_DEVICE_INFO_KERNEL_LAUNCH_CAPABILITIES: return ReturnValue(0); case UR_DEVICE_INFO_MEMORY_EXPORT_EXPORTABLE_DEVICE_MEM_EXP: diff --git a/unified-runtime/source/adapters/hip/program.cpp b/unified-runtime/source/adapters/hip/program.cpp index 94451730c66ca..cd45448fa7ceb 100644 --- a/unified-runtime/source/adapters/hip/program.cpp +++ b/unified-runtime/source/adapters/hip/program.cpp @@ -284,6 +284,7 @@ urProgramCompile(ur_context_handle_t hContext, ur_program_handle_t hProgram, UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(ur_program_handle_t, uint32_t, ur_device_handle_t *, + ur_exp_program_flags_t, const char *) { return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } @@ -291,6 +292,7 @@ UR_APIEXPORT ur_result_t UR_APICALL 
urProgramCompileExp(ur_program_handle_t, UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_program_handle_t, uint32_t, ur_device_handle_t *, + ur_exp_program_flags_t, const char *) { return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } @@ -313,15 +315,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuild(ur_context_handle_t, return UR_RESULT_SUCCESS; } -UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp( - ur_context_handle_t, uint32_t, ur_device_handle_t *, uint32_t, - const ur_program_handle_t *, const char *, ur_program_handle_t *phProgram) { +UR_APIEXPORT ur_result_t UR_APICALL +urProgramLinkExp(ur_context_handle_t, uint32_t, ur_device_handle_t *, + ur_exp_program_flags_t, uint32_t, const ur_program_handle_t *, + const char *, ur_program_handle_t *phProgram) { if (nullptr != phProgram) { *phProgram = nullptr; } return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } +UR_APIEXPORT ur_result_t UR_APICALL urProgramDynamicLinkExp( + ur_context_handle_t, uint32_t, const ur_program_handle_t *) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; +} + UR_APIEXPORT ur_result_t UR_APICALL urProgramLink(ur_context_handle_t, uint32_t, const ur_program_handle_t *, const char *, ur_program_handle_t *phProgram) { diff --git a/unified-runtime/source/adapters/hip/ur_interface_loader.cpp b/unified-runtime/source/adapters/hip/ur_interface_loader.cpp index dfb4382cad828..2a4fdd3c54270 100644 --- a/unified-runtime/source/adapters/hip/ur_interface_loader.cpp +++ b/unified-runtime/source/adapters/hip/ur_interface_loader.cpp @@ -461,6 +461,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramExpProcAddrTable( pDdiTable->pfnBuildExp = urProgramBuildExp; pDdiTable->pfnCompileExp = urProgramCompileExp; pDdiTable->pfnLinkExp = urProgramLinkExp; + pDdiTable->pfnDynamicLinkExp = urProgramDynamicLinkExp; return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/level_zero/device.cpp b/unified-runtime/source/adapters/level_zero/device.cpp index 40ab701312292..b450b0bc632c4 100644 --- 
a/unified-runtime/source/adapters/level_zero/device.cpp +++ b/unified-runtime/source/adapters/level_zero/device.cpp @@ -1288,6 +1288,8 @@ ur_result_t urDeviceGetInfo( return ReturnValue(true); case UR_DEVICE_INFO_MULTI_DEVICE_COMPILE_SUPPORT_EXP: return ReturnValue(true); + case UR_DEVICE_INFO_DYNAMIC_LINK_SUPPORT_EXP: + return ReturnValue(true); case UR_DEVICE_INFO_ASYNC_USM_ALLOCATIONS_SUPPORT_EXP: return ReturnValue(true); case UR_DEVICE_INFO_CURRENT_CLOCK_THROTTLE_REASONS: { diff --git a/unified-runtime/source/adapters/level_zero/program.cpp b/unified-runtime/source/adapters/level_zero/program.cpp index 6b8fa9fff2db2..ef4f061f7aac6 100644 --- a/unified-runtime/source/adapters/level_zero/program.cpp +++ b/unified-runtime/source/adapters/level_zero/program.cpp @@ -161,7 +161,8 @@ ur_result_t urProgramBuild( const char *Options) { std::vector Devices = Context->getDevices(); return ur::level_zero::urProgramBuildExp(Program, Devices.size(), - Devices.data(), Options); + Devices.data(), + ur_exp_program_flags_t{}, Options); } ur_result_t urProgramBuildExp( @@ -171,6 +172,8 @@ ur_result_t urProgramBuildExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in][optional] pointer to build options null-terminated string. const char *pOptions) { // TODO @@ -251,10 +254,11 @@ ur_result_t urProgramBuildExp( } else { // The call to zeModuleCreate does not report an error if there are // unresolved symbols because it thinks these could be resolved later via - // a call to zeModuleDynamicLink. However, modules created with - // urProgramBuild are supposed to be fully linked and ready to use. - // Therefore, do an extra check now for unresolved symbols. - ZeResult = checkUnresolvedSymbols(ZeModuleHandle, &ZeBuildLog); + // a call to zeModuleDynamicLink. 
However, unless explicitly allowed, + // modules created with urProgramBuild are supposed to be fully linked and + // ready to use. Therefore, do an extra check now for unresolved symbols. + if (!(flags & UR_EXP_PROGRAM_FLAG_ALLOW_UNRESOLVED_SYMBOLS)) + ZeResult = checkUnresolvedSymbols(ZeModuleHandle, &ZeBuildLog); if (ZeResult != ZE_RESULT_SUCCESS) { hProgram->setState(ZeDevice, ur_program_handle_t_::Invalid); Result = (ZeResult == ZE_RESULT_ERROR_MODULE_LINK_FAILURE) @@ -280,6 +284,8 @@ ur_result_t urProgramCompileExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + [[maybe_unused]] ur_exp_program_flags_t flags, /// [in][optional] pointer to build options null-terminated string. const char *pOptions) { std::scoped_lock Guard(hProgram->Mutex); @@ -325,7 +331,8 @@ ur_result_t urProgramCompile( const char *Options) { auto devices = Context->getDevices(); return ur::level_zero::urProgramCompileExp(Program, devices.size(), - devices.data(), Options); + devices.data(), + ur_exp_program_flags_t{}, Options); } ur_result_t urProgramLink( @@ -340,9 +347,9 @@ ur_result_t urProgramLink( /// [out] pointer to handle of program object created. ur_program_handle_t *Program) { std::vector Devices = Context->getDevices(); - return ur::level_zero::urProgramLinkExp(Context, Devices.size(), - Devices.data(), Count, Programs, - Options, Program); + return ur::level_zero::urProgramLinkExp( + Context, Devices.size(), Devices.data(), ur_exp_program_flags_t{}, Count, + Programs, Options, Program); } ur_result_t urProgramLinkExp( @@ -352,6 +359,8 @@ ur_result_t urProgramLinkExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in] number of program handles in `phPrograms`. 
uint32_t count, /// [in][range(0, count)] pointer to array of program handles. @@ -383,17 +392,18 @@ ur_result_t urProgramLinkExp( ur_result_t UrResult = UR_RESULT_SUCCESS; try { - // Acquire a "shared" lock on each of the input programs, and also validate - // that they are all in Object state for each device in the input list. + // Acquire a "shared" lock on each of the input programs, and also + // validate that they are all in Object state for each device in the input + // list. // // There is no danger of deadlock here even if two threads call - // urProgramLink simultaneously with the same input programs in a different - // order. If we were acquiring these with "exclusive" access, this could - // lead to a classic lock ordering deadlock. However, there is no such - // deadlock potential with "shared" access. There could also be a deadlock - // potential if there was some other code that holds more than one of these - // locks simultaneously with "exclusive" access. However, there is no such - // code like that, so this is also not a danger. + // urProgramLink simultaneously with the same input programs in a + // different order. If we were acquiring these with "exclusive" access, + // this could lead to a classic lock ordering deadlock. However, there is + // no such deadlock potential with "shared" access. There could also be a + // deadlock potential if there was some other code that holds more than + // one of these locks simultaneously with "exclusive" access. However, + // there is no such code like that, so this is also not a danger. std::vector> Guards(count); const ur_program_handle_t_::CodeFormat CommonCodeFormat = phPrograms[0]->getCodeFormat(); @@ -418,11 +428,11 @@ ur_result_t urProgramLinkExp( // Previous calls to urProgramCompile did not actually compile the SPIR-V. // Instead, we postpone compilation until this point, when all the modules - // are linked together. 
By doing compilation and linking together, the JIT - // compiler is able see all modules and do cross-module optimizations. + // are linked together. By doing compilation and linking together, the + // JIT compiler is able see all modules and do cross-module optimizations. // - // Construct a ze_module_program_exp_desc_t which contains information about - // all of the modules that will be linked together. + // Construct a ze_module_program_exp_desc_t which contains information + // about all of the modules that will be linked together. ZeStruct ZeExtModuleDesc; std::vector CodeSizes(count); std::vector CodeBufs(count); @@ -466,15 +476,15 @@ ur_result_t urProgramLinkExp( ZeModuleDesc.pInputModule = reinterpret_cast(1); ZeModuleDesc.inputSize = 1; - // We need a Level Zero extension to compile multiple programs together into - // a single Level Zero module. However, we don't need that extension if - // there happens to be only one input program. + // We need a Level Zero extension to compile multiple programs together + // into a single Level Zero module. However, we don't need that extension + // if there happens to be only one input program. // // The "|| (NumInputPrograms == 1)" term is a workaround for a bug in the // Level Zero driver. The driver's "ze_module_program_exp_desc_t" // extension should work even in the case when there is just one input - // module. However, there is currently a bug in the driver that leads to a - // crash. As a workaround, do not use the extension when there is one + // module. However, there is currently a bug in the driver that leads to + // a crash. As a workaround, do not use the extension when there is one // input module. // // TODO: Remove this workaround when the driver is fixed. @@ -519,8 +529,9 @@ ur_result_t urProgramLinkExp( &ZeModule, &ZeBuildLog)); // We still create a ur_program_handle_t_ object even if there is a - // BUILD_FAILURE because we need the object to hold the ZeBuildLog. 
There - // is no build log created for other errors, so we don't create an object. + // BUILD_FAILURE because we need the object to hold the ZeBuildLog. + // There is no build log created for other errors, so we don't create an + // object. UrResult = ze2urResult(ZeResult); if (ZeResult != ZE_RESULT_SUCCESS && ZeResult != ZE_RESULT_ERROR_MODULE_BUILD_FAILURE) { @@ -528,13 +539,14 @@ ur_result_t urProgramLinkExp( } // The call to zeModuleCreate does not report an error if there are - // unresolved symbols because it thinks these could be resolved later via - // a call to zeModuleDynamicLink. However, modules created with + // unresolved symbols because it thinks these could be resolved later + // via a call to zeModuleDynamicLink. However, modules created with // piProgramLink are supposed to be fully linked and ready to use. - // Therefore, do an extra check now for unresolved symbols. Note that we - // still create a ur_program_handle_t_ if there are unresolved symbols - // because the ZeBuildLog tells which symbols are unresolved. - if (ZeResult == ZE_RESULT_SUCCESS) { + // Therefore, do an extra check now for unresolved symbols. Note that + // we still create a ur_program_handle_t_ if there are unresolved + // symbols because the ZeBuildLog tells which symbols are unresolved. + if (ZeResult == ZE_RESULT_SUCCESS && + !(flags & UR_EXP_PROGRAM_FLAG_ALLOW_UNRESOLVED_SYMBOLS)) { ZeResult = checkUnresolvedSymbols(ZeModule, &ZeBuildLog); UrResult = ze2urResult(ZeResult); } @@ -555,6 +567,49 @@ ur_result_t urProgramLinkExp( return UrResult; } +ur_result_t urProgramDynamicLinkExp( + /// [in] handle of the context instance. + ur_context_handle_t hContext, + /// [in] number of program handles in `phPrograms`. + uint32_t count, + /// [in][range(0, count)] pointer to array of program handles. + const ur_program_handle_t *phPrograms) { + ur_result_t UrResult = UR_RESULT_SUCCESS; + + try { + // Reserve room for all modules. 
It may be too much on some devices, which + // is why we do not resize. + std::vector ZeModules; + ZeModules.reserve(count); + + for (ur_device_handle_t Device : hContext->getDevices()) { + for (uint32_t I = 0; I < count; ++I) + if (phPrograms[I]->hasZeModuleForDevice(Device->ZeDevice)) + ZeModules.push_back( + phPrograms[I]->getZeModuleHandle(Device->ZeDevice)); + + if (ZeModules.empty()) + continue; + + // TODO: What should be done with the log? Since there is no result + // program, what can it be attached to? + ze_result_t ZeResult = ZE_CALL_NOCHECK( + zeModuleDynamicLink, (ZeModules.size(), ZeModules.data(), nullptr)); + + if (ZeResult != ZE_RESULT_SUCCESS) + return ze2urResult(ZeResult); + + // Clear so the storage stays allocated, but the size is reset to 0. + ZeModules.clear(); + } + } catch (const std::bad_alloc &) { + return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } catch (...) { + return UR_RESULT_ERROR_UNKNOWN; + } + return UrResult; +} + ur_result_t urProgramRetain( /// [in] handle for the Program to retain ur_program_handle_t Program) { @@ -608,7 +663,8 @@ ur_result_t urProgramGetFunctionPointer( ur_program_handle_t Program, /// [in] A null-terminates string denoting the mangled function name. const char *FunctionName, - /// [out] Returns the pointer to the function if it is found in the program. + /// [out] Returns the pointer to the function if it is found in the + /// program. void **FunctionPointerRet) { std::shared_lock Guard(Program->Mutex); if (Program->getState(Device->ZeDevice) != ur_program_handle_t_::Exe) { @@ -696,10 +752,11 @@ ur_result_t urProgramGetInfo( ur_program_info_t PropName, /// [in] the size of the Program property. size_t PropSize, - /// [in,out][optional] array of bytes of holding the program info property. - /// If propSize is not equal to or greater than the real number of bytes - /// needed to return the info then the ::UR_RESULT_ERROR_INVALID_SIZE error - /// is returned and pProgramInfo is not used. 
+ /// [in,out][optional] array of bytes of holding the program info + /// property. If propSize is not equal to or greater than the real number + /// of bytes needed to return the info then the + /// ::UR_RESULT_ERROR_INVALID_SIZE error is returned and pProgramInfo is + /// not used. void *ProgramInfo, /// [out][optional] pointer to the actual size in bytes of data copied to /// propName. @@ -749,12 +806,13 @@ ur_result_t urProgramGetInfo( std::shared_lock Guard(Program->Mutex); size_t NumDevices = Program->AssociatedDevices.size(); if (PropSizeRet) { - // Return the size of the array of pointers to binaries (for each device). + // Return the size of the array of pointers to binaries (for each + // device). *PropSizeRet = NumDevices * sizeof(uint8_t *); } - // If the caller did not provide an array of pointers to copy binaries into, - // return early. + // If the caller did not provide an array of pointers to copy binaries + // into, return early. if (!ProgramInfo) break; @@ -1005,8 +1063,9 @@ ur_result_t urProgramSetSpecializationConstants( std::scoped_lock Guard(Program->Mutex); // Remember the value of this specialization constant until the program is - // built. Note that we only save the pointer to the buffer that contains the - // value. The caller is responsible for maintaining storage for this buffer. + // built. Note that we only save the pointer to the buffer that contains + // the value. The caller is responsible for maintaining storage for this + // buffer. // // NOTE: SpecSize is unused in Level Zero, the size is known from SPIR-V by // SpecID. 
diff --git a/unified-runtime/source/adapters/level_zero/program.hpp b/unified-runtime/source/adapters/level_zero/program.hpp index fb2cc1f12ee5e..91bdd8c64d19a 100644 --- a/unified-runtime/source/adapters/level_zero/program.hpp +++ b/unified-runtime/source/adapters/level_zero/program.hpp @@ -113,11 +113,15 @@ struct ur_program_handle_t_ : ur_object { return DeviceDataMap[ZeDevice].State; } - ze_module_handle_t getZeModuleHandle(ze_device_handle_t ZeDevice) { + bool hasZeModuleForDevice(ze_device_handle_t ZeDevice) const { + return DeviceDataMap.find(ZeDevice) != DeviceDataMap.end(); + } + + ze_module_handle_t getZeModuleHandle(ze_device_handle_t ZeDevice) const { if (DeviceDataMap.find(ZeDevice) == DeviceDataMap.end()) return InteropZeModule; - return DeviceDataMap[ZeDevice].ZeModule; + return DeviceDataMap.at(ZeDevice).ZeModule; } CodeFormat getCodeFormat(ze_device_handle_t ZeDevice = nullptr) const { diff --git a/unified-runtime/source/adapters/level_zero/ur_interface_loader.cpp b/unified-runtime/source/adapters/level_zero/ur_interface_loader.cpp index 13d7274e7aebf..38195bca41699 100644 --- a/unified-runtime/source/adapters/level_zero/ur_interface_loader.cpp +++ b/unified-runtime/source/adapters/level_zero/ur_interface_loader.cpp @@ -398,6 +398,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urGetProgramExpProcAddrTable( return result; } + pDdiTable->pfnDynamicLinkExp = ur::level_zero::urProgramDynamicLinkExp; pDdiTable->pfnBuildExp = ur::level_zero::urProgramBuildExp; pDdiTable->pfnCompileExp = ur::level_zero::urProgramCompileExp; pDdiTable->pfnLinkExp = ur::level_zero::urProgramLinkExp; diff --git a/unified-runtime/source/adapters/level_zero/ur_interface_loader.hpp b/unified-runtime/source/adapters/level_zero/ur_interface_loader.hpp index 77bc0b7d5b737..7a291011a36fc 100644 --- a/unified-runtime/source/adapters/level_zero/ur_interface_loader.hpp +++ b/unified-runtime/source/adapters/level_zero/ur_interface_loader.hpp @@ -767,6 +767,9 @@ 
urCommandBufferGetInfoExp(ur_exp_command_buffer_handle_t hCommandBuffer, ur_result_t urCommandBufferGetNativeHandleExp(ur_exp_command_buffer_handle_t hCommandBuffer, ur_native_handle_t *phNativeCommandBuffer); +ur_result_t urProgramDynamicLinkExp(ur_context_handle_t hContext, + uint32_t count, + const ur_program_handle_t *phPrograms); ur_result_t urEnqueueTimestampRecordingExp( ur_queue_handle_t hQueue, bool blocking, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent); @@ -782,13 +785,16 @@ ur_result_t urMemoryExportExportMemoryHandleExp( void *pMemHandleRet); ur_result_t urProgramBuildExp(ur_program_handle_t hProgram, uint32_t numDevices, ur_device_handle_t *phDevices, + ur_exp_program_flags_t flags, const char *pOptions); ur_result_t urProgramCompileExp(ur_program_handle_t hProgram, uint32_t numDevices, ur_device_handle_t *phDevices, + ur_exp_program_flags_t flags, const char *pOptions); ur_result_t urProgramLinkExp(ur_context_handle_t hContext, uint32_t numDevices, - ur_device_handle_t *phDevices, uint32_t count, + ur_device_handle_t *phDevices, + ur_exp_program_flags_t flags, uint32_t count, const ur_program_handle_t *phPrograms, const char *pOptions, ur_program_handle_t *phProgram); diff --git a/unified-runtime/source/adapters/mock/ur_mockddi.cpp b/unified-runtime/source/adapters/mock/ur_mockddi.cpp index c7ecf0979b8f5..3f73a906ab170 100644 --- a/unified-runtime/source/adapters/mock/ur_mockddi.cpp +++ b/unified-runtime/source/adapters/mock/ur_mockddi.cpp @@ -11239,6 +11239,53 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferGetNativeHandleExp( return exceptionToResult(std::current_exception()); } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramDynamicLinkExp +__urdlllocal ur_result_t UR_APICALL urProgramDynamicLinkExp( + /// [in] handle of the context instance. 
+ ur_context_handle_t hContext, + /// [in] number of program handles in `phPrograms`. + uint32_t count, + /// [in][range(0, count)] pointer to array of program handles. + const ur_program_handle_t *phPrograms) try { + ur_result_t result = UR_RESULT_SUCCESS; + + ur_program_dynamic_link_exp_params_t params = {&hContext, &count, + &phPrograms}; + + auto beforeCallback = reinterpret_cast( + mock::getCallbacks().get_before_callback("urProgramDynamicLinkExp")); + if (beforeCallback) { + result = beforeCallback(&params); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + mock::getCallbacks().get_replace_callback("urProgramDynamicLinkExp")); + if (replaceCallback) { + result = replaceCallback(&params); + } else { + + result = UR_RESULT_SUCCESS; + } + + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + mock::getCallbacks().get_after_callback("urProgramDynamicLinkExp")); + if (afterCallback) { + return afterCallback(&params); + } + + return result; +} catch (...) { + return exceptionToResult(std::current_exception()); +} + /////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueTimestampRecordingExp __urdlllocal ur_result_t UR_APICALL urEnqueueTimestampRecordingExp( @@ -11479,12 +11526,14 @@ __urdlllocal ur_result_t UR_APICALL urProgramBuildExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in][optional] pointer to build options null-terminated string. 
const char *pOptions) try { ur_result_t result = UR_RESULT_SUCCESS; ur_program_build_exp_params_t params = {&hProgram, &numDevices, &phDevices, - &pOptions}; + &flags, &pOptions}; auto beforeCallback = reinterpret_cast( mock::getCallbacks().get_before_callback("urProgramBuildExp")); @@ -11528,12 +11577,14 @@ __urdlllocal ur_result_t UR_APICALL urProgramCompileExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in][optional] pointer to build options null-terminated string. const char *pOptions) try { ur_result_t result = UR_RESULT_SUCCESS; ur_program_compile_exp_params_t params = {&hProgram, &numDevices, &phDevices, - &pOptions}; + &flags, &pOptions}; auto beforeCallback = reinterpret_cast( mock::getCallbacks().get_before_callback("urProgramCompileExp")); @@ -11577,6 +11628,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in] number of program handles in `phPrograms`. uint32_t count, /// [in][range(0, count)] pointer to array of program handles. 
@@ -11591,8 +11644,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp( } ur_program_link_exp_params_t params = {&hContext, &numDevices, &phDevices, - &count, &phPrograms, &pOptions, - &phProgram}; + &flags, &count, &phPrograms, + &pOptions, &phProgram}; auto beforeCallback = reinterpret_cast( mock::getCallbacks().get_before_callback("urProgramLinkExp")); @@ -12782,6 +12835,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramExpProcAddrTable( ur_result_t result = UR_RESULT_SUCCESS; + pDdiTable->pfnDynamicLinkExp = driver::urProgramDynamicLinkExp; + pDdiTable->pfnBuildExp = driver::urProgramBuildExp; pDdiTable->pfnCompileExp = driver::urProgramCompileExp; diff --git a/unified-runtime/source/adapters/native_cpu/device.cpp b/unified-runtime/source/adapters/native_cpu/device.cpp index 369b4cd7ed013..7c4dba411182e 100644 --- a/unified-runtime/source/adapters/native_cpu/device.cpp +++ b/unified-runtime/source/adapters/native_cpu/device.cpp @@ -446,6 +446,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, case UR_DEVICE_INFO_MULTI_DEVICE_COMPILE_SUPPORT_EXP: return ReturnValue(true); + case UR_DEVICE_INFO_DYNAMIC_LINK_SUPPORT_EXP: + return ReturnValue(true); + case UR_DEVICE_INFO_GLOBAL_VARIABLE_SUPPORT: return ReturnValue(false); diff --git a/unified-runtime/source/adapters/native_cpu/program.cpp b/unified-runtime/source/adapters/native_cpu/program.cpp index fee72f8a6bc3c..be7dc0764fb6d 100644 --- a/unified-runtime/source/adapters/native_cpu/program.cpp +++ b/unified-runtime/source/adapters/native_cpu/program.cpp @@ -143,6 +143,7 @@ urProgramLink(ur_context_handle_t /*hContext*/, uint32_t /*count*/, UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(ur_program_handle_t, uint32_t, ur_device_handle_t *, + ur_exp_program_flags_t, const char *) { // Currently for Native CPU the program is offline compiled, so // urProgramCompile is a no-op. 
@@ -152,15 +153,17 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(ur_program_handle_t, UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_program_handle_t, uint32_t, ur_device_handle_t *, + ur_exp_program_flags_t, const char *) { // Currently for Native CPU the program is offline compiled and linked, // so urProgramBuild is a no-op. return UR_RESULT_SUCCESS; } -UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp( - ur_context_handle_t, uint32_t, ur_device_handle_t *, uint32_t, - const ur_program_handle_t *, const char *, ur_program_handle_t *phProgram) { +UR_APIEXPORT ur_result_t UR_APICALL +urProgramLinkExp(ur_context_handle_t, uint32_t, ur_device_handle_t *, + ur_exp_program_flags_t, uint32_t, const ur_program_handle_t *, + const char *, ur_program_handle_t *phProgram) { if (nullptr != phProgram) { *phProgram = nullptr; } @@ -169,6 +172,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp( return UR_RESULT_SUCCESS; } +UR_APIEXPORT ur_result_t UR_APICALL urProgramDynamicLinkExp( + ur_context_handle_t, uint32_t, const ur_program_handle_t *) { + // Currently for Native CPU the program is already linked and all its + // symbols are resolved, so this is a no-op. 
+ return UR_RESULT_SUCCESS; +} + UR_APIEXPORT ur_result_t UR_APICALL urProgramRetain(ur_program_handle_t hProgram) { hProgram->incrementReferenceCount(); diff --git a/unified-runtime/source/adapters/native_cpu/ur_interface_loader.cpp b/unified-runtime/source/adapters/native_cpu/ur_interface_loader.cpp index 3f6fe061b4917..6135420022a0e 100644 --- a/unified-runtime/source/adapters/native_cpu/ur_interface_loader.cpp +++ b/unified-runtime/source/adapters/native_cpu/ur_interface_loader.cpp @@ -445,6 +445,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramExpProcAddrTable( pDdiTable->pfnBuildExp = urProgramBuildExp; pDdiTable->pfnCompileExp = urProgramCompileExp; pDdiTable->pfnLinkExp = urProgramLinkExp; + pDdiTable->pfnDynamicLinkExp = urProgramDynamicLinkExp; return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/offload/ur_interface_loader.cpp b/unified-runtime/source/adapters/offload/ur_interface_loader.cpp index fa3adfdaf92bd..3b575e7f4af25 100644 --- a/unified-runtime/source/adapters/offload/ur_interface_loader.cpp +++ b/unified-runtime/source/adapters/offload/ur_interface_loader.cpp @@ -398,6 +398,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramExpProcAddrTable( pDdiTable->pfnBuildExp = urProgramBuildExp; pDdiTable->pfnCompileExp = nullptr; pDdiTable->pfnLinkExp = nullptr; + pDdiTable->pfnDynamicLinkExp = nullptr; return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/opencl/device.cpp b/unified-runtime/source/adapters/opencl/device.cpp index eac2c9fe0bf2c..9a6be2a56a8ed 100644 --- a/unified-runtime/source/adapters/opencl/device.cpp +++ b/unified-runtime/source/adapters/opencl/device.cpp @@ -1425,6 +1425,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, return ReturnValue(false); case UR_DEVICE_INFO_MULTI_DEVICE_COMPILE_SUPPORT_EXP: return ReturnValue(true); + case UR_DEVICE_INFO_DYNAMIC_LINK_SUPPORT_EXP: + return ReturnValue(false); case UR_DEVICE_INFO_KERNEL_LAUNCH_CAPABILITIES: 
return ReturnValue(0); case UR_DEVICE_INFO_LUID: { diff --git a/unified-runtime/source/adapters/opencl/program.cpp b/unified-runtime/source/adapters/opencl/program.cpp index dffd9aed9074b..dbd963debfd27 100644 --- a/unified-runtime/source/adapters/opencl/program.cpp +++ b/unified-runtime/source/adapters/opencl/program.cpp @@ -294,6 +294,7 @@ urProgramLink(ur_context_handle_t hContext, uint32_t count, UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(ur_program_handle_t, uint32_t, ur_device_handle_t *, + ur_exp_program_flags_t, const char *) { return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } @@ -301,19 +302,26 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(ur_program_handle_t, UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_program_handle_t, uint32_t, ur_device_handle_t *, + ur_exp_program_flags_t, const char *) { return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } -UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp( - ur_context_handle_t, uint32_t, ur_device_handle_t *, uint32_t, - const ur_program_handle_t *, const char *, ur_program_handle_t *phProgram) { +UR_APIEXPORT ur_result_t UR_APICALL +urProgramLinkExp(ur_context_handle_t, uint32_t, ur_device_handle_t *, + ur_exp_program_flags_t, uint32_t, const ur_program_handle_t *, + const char *, ur_program_handle_t *phProgram) { if (nullptr != phProgram) { *phProgram = nullptr; } return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } +UR_APIEXPORT ur_result_t UR_APICALL urProgramDynamicLinkExp( + ur_context_handle_t, uint32_t, const ur_program_handle_t *) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; +} + static cl_int mapURProgramBuildInfoToCL(ur_program_build_info_t URPropName) { switch (static_cast(URPropName)) { diff --git a/unified-runtime/source/adapters/opencl/ur_interface_loader.cpp b/unified-runtime/source/adapters/opencl/ur_interface_loader.cpp index c619fa36b1ab0..5647a5582a1f2 100644 --- a/unified-runtime/source/adapters/opencl/ur_interface_loader.cpp +++ 
b/unified-runtime/source/adapters/opencl/ur_interface_loader.cpp @@ -448,6 +448,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramExpProcAddrTable( pDdiTable->pfnBuildExp = urProgramBuildExp; pDdiTable->pfnCompileExp = urProgramCompileExp; pDdiTable->pfnLinkExp = urProgramLinkExp; + pDdiTable->pfnDynamicLinkExp = urProgramDynamicLinkExp; return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/loader/layers/sanitizer/asan/asan_ddi.cpp b/unified-runtime/source/loader/layers/sanitizer/asan/asan_ddi.cpp index 012aa1422cfa1..90e26e5566aa7 100644 --- a/unified-runtime/source/loader/layers/sanitizer/asan/asan_ddi.cpp +++ b/unified-runtime/source/loader/layers/sanitizer/asan/asan_ddi.cpp @@ -345,6 +345,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramBuildExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in][optional] pointer to build options null-terminated string. const char *pOptions) { auto pfnBuildExp = getContext()->urDdiTable.ProgramExp.pfnBuildExp; @@ -355,7 +357,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramBuildExp( UR_LOG_L(getContext()->logger, DEBUG, "==== urProgramBuildExp"); - auto UrRes = pfnBuildExp(hProgram, numDevices, phDevices, pOptions); + auto UrRes = pfnBuildExp(hProgram, numDevices, phDevices, flags, pOptions); if (UrRes != UR_RESULT_SUCCESS) { PrintUrBuildLogIfError(UrRes, hProgram, phDevices, numDevices); return UrRes; @@ -409,6 +411,8 @@ ur_result_t UR_APICALL urProgramLinkExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in] number of program handles in `phPrograms`. uint32_t count, /// [in][range(0, count)] pointer to array of program handles. 
@@ -425,7 +429,7 @@ ur_result_t UR_APICALL urProgramLinkExp( UR_LOG_L(getContext()->logger, DEBUG, "==== urProgramLinkExp"); - auto UrRes = pfnProgramLinkExp(hContext, numDevices, phDevices, count, + auto UrRes = pfnProgramLinkExp(hContext, numDevices, phDevices, flags, count, phPrograms, pOptions, phProgram); if (UrRes != UR_RESULT_SUCCESS) { PrintUrBuildLogIfError(UrRes, *phProgram, phDevices, numDevices); @@ -438,6 +442,27 @@ ur_result_t UR_APICALL urProgramLinkExp( return UR_RESULT_SUCCESS; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramDynamicLinkExp +ur_result_t UR_APICALL urProgramDynamicLinkExp( + /// [in] handle of the context instance. + ur_context_handle_t hContext, + /// [in] number of program handles in `phPrograms`. + uint32_t count, + /// [in][range(0, count)] pointer to array of program handles. + const ur_program_handle_t *phPrograms) { + auto pfnProgramDynamicLinkExp = + getContext()->urDdiTable.ProgramExp.pfnDynamicLinkExp; + + if (nullptr == pfnProgramDynamicLinkExp) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + UR_LOG_L(getContext()->logger, DEBUG, "==== urProgramDynamicLinkExp"); + + return pfnProgramDynamicLinkExp(hContext, count, phPrograms); +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urProgramRelease ur_result_t UR_APICALL urProgramRelease( @@ -1848,6 +1873,8 @@ __urdlllocal ur_result_t UR_APICALL urGetProgramExpProcAddrTable( pDdiTable->pfnBuildExp = ur_sanitizer_layer::asan::urProgramBuildExp; pDdiTable->pfnLinkExp = ur_sanitizer_layer::asan::urProgramLinkExp; + pDdiTable->pfnDynamicLinkExp = + ur_sanitizer_layer::asan::urProgramDynamicLinkExp; return result; } diff --git a/unified-runtime/source/loader/layers/sanitizer/msan/msan_ddi.cpp b/unified-runtime/source/loader/layers/sanitizer/msan/msan_ddi.cpp index 810fd76120b2d..ebd1b197988b3 100644 --- 
a/unified-runtime/source/loader/layers/sanitizer/msan/msan_ddi.cpp +++ b/unified-runtime/source/loader/layers/sanitizer/msan/msan_ddi.cpp @@ -329,13 +329,15 @@ ur_result_t urProgramBuildExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in][optional] pointer to build options null-terminated string. const char *pOptions) { auto pfnBuildExp = getContext()->urDdiTable.ProgramExp.pfnBuildExp; UR_LOG_L(getContext()->logger, DEBUG, "==== urProgramBuildExp"); - auto UrRes = pfnBuildExp(hProgram, numDevices, phDevices, pOptions); + auto UrRes = pfnBuildExp(hProgram, numDevices, phDevices, flags, pOptions); if (UrRes != UR_RESULT_SUCCESS) { PrintUrBuildLogIfError(UrRes, hProgram, phDevices, numDevices); return UrRes; @@ -385,6 +387,8 @@ ur_result_t urProgramLinkExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in] number of program handles in `phPrograms`. uint32_t count, /// [in][range(0, count)] pointer to array of program handles. @@ -397,7 +401,7 @@ ur_result_t urProgramLinkExp( UR_LOG_L(getContext()->logger, DEBUG, "==== urProgramLinkExp"); - auto UrRes = pfnProgramLinkExp(hContext, numDevices, phDevices, count, + auto UrRes = pfnProgramLinkExp(hContext, numDevices, phDevices, flags, count, phPrograms, pOptions, phProgram); if (UrRes != UR_RESULT_SUCCESS) { PrintUrBuildLogIfError(UrRes, *phProgram, phDevices, numDevices); @@ -410,6 +414,23 @@ ur_result_t urProgramLinkExp( return UR_RESULT_SUCCESS; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramDynamicLinkExp +ur_result_t urProgramDynamicLinkExp( + /// [in] handle of the context instance. 
+ ur_context_handle_t hContext, + /// [in] number of program handles in `phPrograms`. + uint32_t count, + /// [in][range(0, count)] pointer to array of program handles. + const ur_program_handle_t *phPrograms) { + auto pfnProgramDynamicLinkExp = + getContext()->urDdiTable.ProgramExp.pfnDynamicLinkExp; + + UR_LOG_L(getContext()->logger, DEBUG, "==== urProgramDynamicLinkExp"); + + return pfnProgramDynamicLinkExp(hContext, count, phPrograms); +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urProgramRelease ur_result_t urProgramRelease( @@ -1931,6 +1952,8 @@ ur_result_t urGetProgramExpProcAddrTable( pDdiTable->pfnBuildExp = ur_sanitizer_layer::msan::urProgramBuildExp; pDdiTable->pfnLinkExp = ur_sanitizer_layer::msan::urProgramLinkExp; + pDdiTable->pfnDynamicLinkExp = + ur_sanitizer_layer::msan::urProgramDynamicLinkExp; return result; } diff --git a/unified-runtime/source/loader/layers/sanitizer/tsan/tsan_ddi.cpp b/unified-runtime/source/loader/layers/sanitizer/tsan/tsan_ddi.cpp index 61849ac0b363a..48ab96effd63a 100644 --- a/unified-runtime/source/loader/layers/sanitizer/tsan/tsan_ddi.cpp +++ b/unified-runtime/source/loader/layers/sanitizer/tsan/tsan_ddi.cpp @@ -322,12 +322,14 @@ ur_result_t urProgramBuildExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in][optional] pointer to build options null-terminated string. 
const char *pOptions) { UR_LOG_L(getContext()->logger, DEBUG, "==== urProgramBuildExp"); auto UrRes = getContext()->urDdiTable.ProgramExp.pfnBuildExp( - hProgram, numDevices, phDevices, pOptions); + hProgram, numDevices, phDevices, flags, pOptions); if (UrRes != UR_RESULT_SUCCESS) { PrintUrBuildLogIfError(UrRes, hProgram, phDevices, numDevices); return UrRes; @@ -345,6 +347,8 @@ ur_result_t urProgramLinkExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in] number of program handles in `phPrograms`. uint32_t count, /// [in][range(0, count)] pointer to array of program handles. @@ -356,7 +360,8 @@ ur_result_t urProgramLinkExp( UR_LOG_L(getContext()->logger, DEBUG, "==== urProgramLinkExp"); auto UrRes = getContext()->urDdiTable.ProgramExp.pfnLinkExp( - hContext, numDevices, phDevices, count, phPrograms, pOptions, phProgram); + hContext, numDevices, phDevices, flags, count, phPrograms, pOptions, + phProgram); if (UrRes != UR_RESULT_SUCCESS) { PrintUrBuildLogIfError(UrRes, *phProgram, phDevices, numDevices); return UrRes; @@ -368,6 +373,20 @@ ur_result_t urProgramLinkExp( return UR_RESULT_SUCCESS; } +/// @brief Intercept function for urProgramDynamicLinkExp +ur_result_t urProgramDynamicLinkExp( + /// [in] handle of the context instance. + ur_context_handle_t hContext, + /// [in] number of program handles in `phPrograms`. + uint32_t count, + /// [in][range(0, count)] pointer to array of program handles. 
+ const ur_program_handle_t *phPrograms) { + UR_LOG_L(getContext()->logger, DEBUG, "==== urProgramDynamicLinkExp"); + + return getContext()->urDdiTable.ProgramExp.pfnDynamicLinkExp(hContext, count, + phPrograms); +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urMemBufferCreate ur_result_t urMemBufferCreate( @@ -1434,6 +1453,8 @@ ur_result_t urGetProgramExpProcAddrTable( pDdiTable->pfnBuildExp = ur_sanitizer_layer::tsan::urProgramBuildExp; pDdiTable->pfnLinkExp = ur_sanitizer_layer::tsan::urProgramLinkExp; + pDdiTable->pfnDynamicLinkExp = + ur_sanitizer_layer::tsan::urProgramDynamicLinkExp; return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/loader/layers/tracing/ur_trcddi.cpp b/unified-runtime/source/loader/layers/tracing/ur_trcddi.cpp index 0bca15321849d..23702e4924df1 100644 --- a/unified-runtime/source/loader/layers/tracing/ur_trcddi.cpp +++ b/unified-runtime/source/loader/layers/tracing/ur_trcddi.cpp @@ -9517,6 +9517,46 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferGetNativeHandleExp( return result; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramDynamicLinkExp +__urdlllocal ur_result_t UR_APICALL urProgramDynamicLinkExp( + /// [in] handle of the context instance. + ur_context_handle_t hContext, + /// [in] number of program handles in `phPrograms`. + uint32_t count, + /// [in][range(0, count)] pointer to array of program handles. 
+ const ur_program_handle_t *phPrograms) { + auto pfnDynamicLinkExp = + getContext()->urDdiTable.ProgramExp.pfnDynamicLinkExp; + + if (nullptr == pfnDynamicLinkExp) + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + + ur_program_dynamic_link_exp_params_t params = {&hContext, &count, + &phPrograms}; + uint64_t instance = getContext()->notify_begin( + UR_FUNCTION_PROGRAM_DYNAMIC_LINK_EXP, "urProgramDynamicLinkExp", &params); + + auto &logger = getContext()->logger; + UR_LOG_L(logger, INFO, " ---> urProgramDynamicLinkExp\n"); + + ur_result_t result = pfnDynamicLinkExp(hContext, count, phPrograms); + + getContext()->notify_end(UR_FUNCTION_PROGRAM_DYNAMIC_LINK_EXP, + "urProgramDynamicLinkExp", &params, &result, + instance); + + if (logger.getLevel() <= UR_LOGGER_LEVEL_INFO) { + std::ostringstream args_str; + ur::extras::printFunctionParams( + args_str, UR_FUNCTION_PROGRAM_DYNAMIC_LINK_EXP, &params); + UR_LOG_L(logger, INFO, " <--- urProgramDynamicLinkExp({}) -> {};\n", + args_str.str(), result); + } + + return result; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urEnqueueTimestampRecordingExp __urdlllocal ur_result_t UR_APICALL urEnqueueTimestampRecordingExp( @@ -9729,6 +9769,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramBuildExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in][optional] pointer to build options null-terminated string. 
const char *pOptions) { auto pfnBuildExp = getContext()->urDdiTable.ProgramExp.pfnBuildExp; @@ -9737,14 +9779,15 @@ __urdlllocal ur_result_t UR_APICALL urProgramBuildExp( return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; ur_program_build_exp_params_t params = {&hProgram, &numDevices, &phDevices, - &pOptions}; + &flags, &pOptions}; uint64_t instance = getContext()->notify_begin(UR_FUNCTION_PROGRAM_BUILD_EXP, "urProgramBuildExp", &params); auto &logger = getContext()->logger; UR_LOG_L(logger, INFO, " ---> urProgramBuildExp\n"); - ur_result_t result = pfnBuildExp(hProgram, numDevices, phDevices, pOptions); + ur_result_t result = + pfnBuildExp(hProgram, numDevices, phDevices, flags, pOptions); getContext()->notify_end(UR_FUNCTION_PROGRAM_BUILD_EXP, "urProgramBuildExp", &params, &result, instance); @@ -9769,6 +9812,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramCompileExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in][optional] pointer to build options null-terminated string. 
const char *pOptions) { auto pfnCompileExp = getContext()->urDdiTable.ProgramExp.pfnCompileExp; @@ -9777,14 +9822,15 @@ __urdlllocal ur_result_t UR_APICALL urProgramCompileExp( return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; ur_program_compile_exp_params_t params = {&hProgram, &numDevices, &phDevices, - &pOptions}; + &flags, &pOptions}; uint64_t instance = getContext()->notify_begin( UR_FUNCTION_PROGRAM_COMPILE_EXP, "urProgramCompileExp", &params); auto &logger = getContext()->logger; UR_LOG_L(logger, INFO, " ---> urProgramCompileExp\n"); - ur_result_t result = pfnCompileExp(hProgram, numDevices, phDevices, pOptions); + ur_result_t result = + pfnCompileExp(hProgram, numDevices, phDevices, flags, pOptions); getContext()->notify_end(UR_FUNCTION_PROGRAM_COMPILE_EXP, "urProgramCompileExp", &params, &result, instance); @@ -9809,6 +9855,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in] number of program handles in `phPrograms`. uint32_t count, /// [in][range(0, count)] pointer to array of program handles. 
@@ -9826,15 +9874,15 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp( return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; ur_program_link_exp_params_t params = {&hContext, &numDevices, &phDevices, - &count, &phPrograms, &pOptions, - &phProgram}; + &flags, &count, &phPrograms, + &pOptions, &phProgram}; uint64_t instance = getContext()->notify_begin(UR_FUNCTION_PROGRAM_LINK_EXP, "urProgramLinkExp", &params); auto &logger = getContext()->logger; UR_LOG_L(logger, INFO, " ---> urProgramLinkExp\n"); - ur_result_t result = pfnLinkExp(hContext, numDevices, phDevices, count, + ur_result_t result = pfnLinkExp(hContext, numDevices, phDevices, flags, count, phPrograms, pOptions, phProgram); getContext()->notify_end(UR_FUNCTION_PROGRAM_LINK_EXP, "urProgramLinkExp", @@ -11162,6 +11210,9 @@ __urdlllocal ur_result_t UR_APICALL urGetProgramExpProcAddrTable( ur_result_t result = UR_RESULT_SUCCESS; + dditable.pfnDynamicLinkExp = pDdiTable->pfnDynamicLinkExp; + pDdiTable->pfnDynamicLinkExp = ur_tracing_layer::urProgramDynamicLinkExp; + dditable.pfnBuildExp = pDdiTable->pfnBuildExp; pDdiTable->pfnBuildExp = ur_tracing_layer::urProgramBuildExp; diff --git a/unified-runtime/source/loader/layers/validation/ur_valddi.cpp b/unified-runtime/source/loader/layers/validation/ur_valddi.cpp index 81151d4124b10..b1656f88ee788 100644 --- a/unified-runtime/source/loader/layers/validation/ur_valddi.cpp +++ b/unified-runtime/source/loader/layers/validation/ur_valddi.cpp @@ -10271,6 +10271,43 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferGetNativeHandleExp( return result; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramDynamicLinkExp +__urdlllocal ur_result_t UR_APICALL urProgramDynamicLinkExp( + /// [in] handle of the context instance. + ur_context_handle_t hContext, + /// [in] number of program handles in `phPrograms`. + uint32_t count, + /// [in][range(0, count)] pointer to array of program handles. 
+ const ur_program_handle_t *phPrograms) { + auto pfnDynamicLinkExp = + getContext()->urDdiTable.ProgramExp.pfnDynamicLinkExp; + + if (nullptr == pfnDynamicLinkExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + if (getContext()->enableParameterValidation) { + if (NULL == phPrograms) + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + + if (NULL == hContext) + return UR_RESULT_ERROR_INVALID_NULL_HANDLE; + + if (count == 0) + return UR_RESULT_ERROR_INVALID_SIZE; + } + + if (getContext()->enableLifetimeValidation && + !getContext()->refCountContext->isReferenceValid(hContext)) { + URLOG_CTX_INVALID_REFERENCE(hContext); + } + + ur_result_t result = pfnDynamicLinkExp(hContext, count, phPrograms); + + return result; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urEnqueueTimestampRecordingExp __urdlllocal ur_result_t UR_APICALL urEnqueueTimestampRecordingExp( @@ -10513,6 +10550,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramBuildExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in][optional] pointer to build options null-terminated string. 
const char *pOptions) { auto pfnBuildExp = getContext()->urDdiTable.ProgramExp.pfnBuildExp; @@ -10527,6 +10566,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramBuildExp( if (NULL == hProgram) return UR_RESULT_ERROR_INVALID_NULL_HANDLE; + + if (UR_EXP_PROGRAM_FLAGS_MASK & flags) + return UR_RESULT_ERROR_INVALID_ENUMERATION; } if (getContext()->enableLifetimeValidation && @@ -10534,7 +10576,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramBuildExp( URLOG_CTX_INVALID_REFERENCE(hProgram); } - ur_result_t result = pfnBuildExp(hProgram, numDevices, phDevices, pOptions); + ur_result_t result = + pfnBuildExp(hProgram, numDevices, phDevices, flags, pOptions); return result; } @@ -10548,6 +10591,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramCompileExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in][optional] pointer to build options null-terminated string. 
const char *pOptions) { auto pfnCompileExp = getContext()->urDdiTable.ProgramExp.pfnCompileExp; @@ -10562,6 +10607,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramCompileExp( if (NULL == hProgram) return UR_RESULT_ERROR_INVALID_NULL_HANDLE; + + if (UR_EXP_PROGRAM_FLAGS_MASK & flags) + return UR_RESULT_ERROR_INVALID_ENUMERATION; } if (getContext()->enableLifetimeValidation && @@ -10569,7 +10617,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramCompileExp( URLOG_CTX_INVALID_REFERENCE(hProgram); } - ur_result_t result = pfnCompileExp(hProgram, numDevices, phDevices, pOptions); + ur_result_t result = + pfnCompileExp(hProgram, numDevices, phDevices, flags, pOptions); return result; } @@ -10583,6 +10632,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in] number of program handles in `phPrograms`. uint32_t count, /// [in][range(0, count)] pointer to array of program handles. 
@@ -10613,6 +10664,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp( if (NULL == hContext) return UR_RESULT_ERROR_INVALID_NULL_HANDLE; + if (UR_EXP_PROGRAM_FLAGS_MASK & flags) + return UR_RESULT_ERROR_INVALID_ENUMERATION; + if (count == 0) return UR_RESULT_ERROR_INVALID_SIZE; } @@ -10622,7 +10676,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp( URLOG_CTX_INVALID_REFERENCE(hContext); } - ur_result_t result = pfnLinkExp(hContext, numDevices, phDevices, count, + ur_result_t result = pfnLinkExp(hContext, numDevices, phDevices, flags, count, phPrograms, pOptions, phProgram); return result; @@ -11977,6 +12031,9 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramExpProcAddrTable( ur_result_t result = UR_RESULT_SUCCESS; + dditable.pfnDynamicLinkExp = pDdiTable->pfnDynamicLinkExp; + pDdiTable->pfnDynamicLinkExp = ur_validation_layer::urProgramDynamicLinkExp; + dditable.pfnBuildExp = pDdiTable->pfnBuildExp; pDdiTable->pfnBuildExp = ur_validation_layer::urProgramBuildExp; diff --git a/unified-runtime/source/loader/loader.def.in b/unified-runtime/source/loader/loader.def.in index e86a6c65a7957..f3d8b12c02a95 100644 --- a/unified-runtime/source/loader/loader.def.in +++ b/unified-runtime/source/loader/loader.def.in @@ -350,6 +350,7 @@ EXPORTS urPrintExpImageCopyRegion urPrintExpImageMemType urPrintExpPeerInfo + urPrintExpProgramFlags urPrintExpSamplerAddrModes urPrintExpSamplerCubemapFilterMode urPrintExpSamplerCubemapProperties @@ -451,6 +452,7 @@ EXPORTS urPrintProgramCreateWithBinaryParams urPrintProgramCreateWithIlParams urPrintProgramCreateWithNativeHandleParams + urPrintProgramDynamicLinkExpParams urPrintProgramGetBuildInfoParams urPrintProgramGetFunctionPointerParams urPrintProgramGetGlobalVariablePointerParams @@ -552,6 +554,7 @@ EXPORTS urProgramCreateWithBinary urProgramCreateWithIL urProgramCreateWithNativeHandle + urProgramDynamicLinkExp urProgramGetBuildInfo urProgramGetFunctionPointer urProgramGetGlobalVariablePointer diff --git 
a/unified-runtime/source/loader/loader.map.in b/unified-runtime/source/loader/loader.map.in index 6a30c9186f674..6eac3f2aa92a9 100644 --- a/unified-runtime/source/loader/loader.map.in +++ b/unified-runtime/source/loader/loader.map.in @@ -350,6 +350,7 @@ urPrintExpImageCopyRegion; urPrintExpImageMemType; urPrintExpPeerInfo; + urPrintExpProgramFlags; urPrintExpSamplerAddrModes; urPrintExpSamplerCubemapFilterMode; urPrintExpSamplerCubemapProperties; @@ -451,6 +452,7 @@ urPrintProgramCreateWithBinaryParams; urPrintProgramCreateWithIlParams; urPrintProgramCreateWithNativeHandleParams; + urPrintProgramDynamicLinkExpParams; urPrintProgramGetBuildInfoParams; urPrintProgramGetFunctionPointerParams; urPrintProgramGetGlobalVariablePointerParams; @@ -552,6 +554,7 @@ urProgramCreateWithBinary; urProgramCreateWithIL; urProgramCreateWithNativeHandle; + urProgramDynamicLinkExp; urProgramGetBuildInfo; urProgramGetFunctionPointer; urProgramGetGlobalVariablePointer; diff --git a/unified-runtime/source/loader/ur_ldrddi.cpp b/unified-runtime/source/loader/ur_ldrddi.cpp index d943fe99afdb6..bac156daede65 100644 --- a/unified-runtime/source/loader/ur_ldrddi.cpp +++ b/unified-runtime/source/loader/ur_ldrddi.cpp @@ -5415,6 +5415,26 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferGetNativeHandleExp( return pfnGetNativeHandleExp(hCommandBuffer, phNativeCommandBuffer); } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramDynamicLinkExp +__urdlllocal ur_result_t UR_APICALL urProgramDynamicLinkExp( + /// [in] handle of the context instance. + ur_context_handle_t hContext, + /// [in] number of program handles in `phPrograms`. + uint32_t count, + /// [in][range(0, count)] pointer to array of program handles. 
+ const ur_program_handle_t *phPrograms) { + + auto *dditable = *reinterpret_cast(hContext); + + auto *pfnDynamicLinkExp = dditable->ProgramExp.pfnDynamicLinkExp; + if (nullptr == pfnDynamicLinkExp) + return UR_RESULT_ERROR_UNINITIALIZED; + + // forward to device-platform + return pfnDynamicLinkExp(hContext, count, phPrograms); +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urEnqueueTimestampRecordingExp __urdlllocal ur_result_t UR_APICALL urEnqueueTimestampRecordingExp( @@ -5541,6 +5561,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramBuildExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in][optional] pointer to build options null-terminated string. const char *pOptions) { @@ -5551,7 +5573,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramBuildExp( return UR_RESULT_ERROR_UNINITIALIZED; // forward to device-platform - return pfnBuildExp(hProgram, numDevices, phDevices, pOptions); + return pfnBuildExp(hProgram, numDevices, phDevices, flags, pOptions); } /////////////////////////////////////////////////////////////////////////////// @@ -5563,6 +5585,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramCompileExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in][optional] pointer to build options null-terminated string. 
const char *pOptions) { @@ -5573,7 +5597,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramCompileExp( return UR_RESULT_ERROR_UNINITIALIZED; // forward to device-platform - return pfnCompileExp(hProgram, numDevices, phDevices, pOptions); + return pfnCompileExp(hProgram, numDevices, phDevices, flags, pOptions); } /////////////////////////////////////////////////////////////////////////////// @@ -5585,6 +5609,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in] number of program handles in `phPrograms`. uint32_t count, /// [in][range(0, count)] pointer to array of program handles. @@ -5603,7 +5629,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramLinkExp( return UR_RESULT_ERROR_UNINITIALIZED; // forward to device-platform - return pfnLinkExp(hContext, numDevices, phDevices, count, phPrograms, + return pfnLinkExp(hContext, numDevices, phDevices, flags, count, phPrograms, pOptions, phProgram); } @@ -6742,6 +6768,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramExpProcAddrTable( if (ur_loader::getContext()->platforms.size() != 1 || ur_loader::getContext()->forceIntercept) { // return pointers to loader's DDIs + pDdiTable->pfnDynamicLinkExp = ur_loader::urProgramDynamicLinkExp; pDdiTable->pfnBuildExp = ur_loader::urProgramBuildExp; pDdiTable->pfnCompileExp = ur_loader::urProgramCompileExp; pDdiTable->pfnLinkExp = ur_loader::urProgramLinkExp; diff --git a/unified-runtime/source/loader/ur_libapi.cpp b/unified-runtime/source/loader/ur_libapi.cpp index 3fa51dc105547..061683a888a3d 100644 --- a/unified-runtime/source/loader/ur_libapi.cpp +++ b/unified-runtime/source/loader/ur_libapi.cpp @@ -9965,6 +9965,49 @@ ur_result_t UR_APICALL urCommandBufferGetNativeHandleExp( return exceptionToResult(std::current_exception()); } 
+/////////////////////////////////////////////////////////////////////////////// +/// @brief Creates dynamic links between exported and imported symbols in one or +/// more programs. +/// +/// @details +/// - The application may call this function from simultaneous threads. +/// - Following a successful call to this entry point the programs in +/// `phPrograms` will have all external symbols resolved and kernels +/// inside these programs would be ready for use. +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hContext` +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// + `NULL == phPrograms` +/// - ::UR_RESULT_ERROR_INVALID_PROGRAM +/// + If one of the programs in `phPrograms` isn't a valid program +/// object. +/// - ::UR_RESULT_ERROR_INVALID_SIZE +/// + `count == 0` +/// - ::UR_RESULT_ERROR_PROGRAM_LINK_FAILURE +/// + If an error occurred while linking `phPrograms`. +ur_result_t UR_APICALL urProgramDynamicLinkExp( + /// [in] handle of the context instance. + ur_context_handle_t hContext, + /// [in] number of program handles in `phPrograms`. + uint32_t count, + /// [in][range(0, count)] pointer to array of program handles. + const ur_program_handle_t *phPrograms) try { + auto pfnDynamicLinkExp = + ur_lib::getContext()->urDdiTable.ProgramExp.pfnDynamicLinkExp; + if (nullptr == pfnDynamicLinkExp) + return UR_RESULT_ERROR_UNINITIALIZED; + + return pfnDynamicLinkExp(hContext, count, phPrograms); +} catch (...) 
{ + return exceptionToResult(std::current_exception()); +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Enqueue a command for recording the device timestamp /// @@ -10184,6 +10227,8 @@ ur_result_t UR_APICALL urMemoryExportExportMemoryHandleExp( /// + `NULL == hProgram` /// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER /// + `NULL == phDevices` +/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION +/// + `::UR_EXP_PROGRAM_FLAGS_MASK & flags` /// - ::UR_RESULT_ERROR_INVALID_PROGRAM /// + If `hProgram` isn't a valid program object. /// - ::UR_RESULT_ERROR_PROGRAM_BUILD_FAILURE @@ -10195,13 +10240,15 @@ ur_result_t UR_APICALL urProgramBuildExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in][optional] pointer to build options null-terminated string. const char *pOptions) try { auto pfnBuildExp = ur_lib::getContext()->urDdiTable.ProgramExp.pfnBuildExp; if (nullptr == pfnBuildExp) return UR_RESULT_ERROR_UNINITIALIZED; - return pfnBuildExp(hProgram, numDevices, phDevices, pOptions); + return pfnBuildExp(hProgram, numDevices, phDevices, flags, pOptions); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -10228,6 +10275,8 @@ ur_result_t UR_APICALL urProgramBuildExp( /// + `NULL == hProgram` /// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER /// + `NULL == phDevices` +/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION +/// + `::UR_EXP_PROGRAM_FLAGS_MASK & flags` /// - ::UR_RESULT_ERROR_INVALID_PROGRAM /// + If `hProgram` isn't a valid program object. 
/// - ::UR_RESULT_ERROR_PROGRAM_BUILD_FAILURE @@ -10239,6 +10288,8 @@ ur_result_t UR_APICALL urProgramCompileExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in][optional] pointer to build options null-terminated string. const char *pOptions) try { auto pfnCompileExp = @@ -10246,7 +10297,7 @@ ur_result_t UR_APICALL urProgramCompileExp( if (nullptr == pfnCompileExp) return UR_RESULT_ERROR_UNINITIALIZED; - return pfnCompileExp(hProgram, numDevices, phDevices, pOptions); + return pfnCompileExp(hProgram, numDevices, phDevices, flags, pOptions); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -10281,6 +10332,8 @@ ur_result_t UR_APICALL urProgramCompileExp( /// + `NULL == phDevices` /// + `NULL == phPrograms` /// + `NULL == phProgram` +/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION +/// + `::UR_EXP_PROGRAM_FLAGS_MASK & flags` /// - ::UR_RESULT_ERROR_INVALID_PROGRAM /// + If one of the programs in `phPrograms` isn't a valid program /// object. @@ -10295,6 +10348,8 @@ ur_result_t UR_APICALL urProgramLinkExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in] number of program handles in `phPrograms`. uint32_t count, /// [in][range(0, count)] pointer to array of program handles. @@ -10310,7 +10365,7 @@ ur_result_t UR_APICALL urProgramLinkExp( if (nullptr == pfnLinkExp) return UR_RESULT_ERROR_UNINITIALIZED; - return pfnLinkExp(hContext, numDevices, phDevices, count, phPrograms, + return pfnLinkExp(hContext, numDevices, phDevices, flags, count, phPrograms, pOptions, phProgram); } catch (...) 
{ return exceptionToResult(std::current_exception()); diff --git a/unified-runtime/source/loader/ur_print.cpp b/unified-runtime/source/loader/ur_print.cpp index 06619c8f7f625..24f01b9fb0db2 100644 --- a/unified-runtime/source/loader/ur_print.cpp +++ b/unified-runtime/source/loader/ur_print.cpp @@ -1140,6 +1140,14 @@ ur_result_t urPrintExpCommandBufferUpdateKernelLaunchDesc( return str_copy(&ss, buffer, buff_size, out_size); } +ur_result_t urPrintExpProgramFlags(enum ur_exp_program_flag_t value, + char *buffer, const size_t buff_size, + size_t *out_size) { + std::stringstream ss; + ss << value; + return str_copy(&ss, buffer, buff_size, out_size); +} + ur_result_t urPrintExpPeerInfo(enum ur_exp_peer_info_t value, char *buffer, const size_t buff_size, size_t *out_size) { std::stringstream ss; @@ -2438,6 +2446,14 @@ urPrintProgramBuildParams(const struct ur_program_build_params_t *params, return str_copy(&ss, buffer, buff_size, out_size); } +ur_result_t urPrintProgramDynamicLinkExpParams( + const struct ur_program_dynamic_link_exp_params_t *params, char *buffer, + const size_t buff_size, size_t *out_size) { + std::stringstream ss; + ss << params; + return str_copy(&ss, buffer, buff_size, out_size); +} + ur_result_t urPrintProgramBuildExpParams(const struct ur_program_build_exp_params_t *params, char *buffer, const size_t buff_size, diff --git a/unified-runtime/source/ur_api.cpp b/unified-runtime/source/ur_api.cpp index 9c6670964b958..1c38f9c8e28b2 100644 --- a/unified-runtime/source/ur_api.cpp +++ b/unified-runtime/source/ur_api.cpp @@ -8679,6 +8679,43 @@ ur_result_t UR_APICALL urCommandBufferGetNativeHandleExp( return result; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Creates dynamic links between exported and imported symbols in one or +/// more programs. +/// +/// @details +/// - The application may call this function from simultaneous threads. 
+/// - Following a successful call to this entry point the programs in +/// `phPrograms` will have all external symbols resolved and kernels +/// inside these programs would be ready for use. +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hContext` +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// + `NULL == phPrograms` +/// - ::UR_RESULT_ERROR_INVALID_PROGRAM +/// + If one of the programs in `phPrograms` isn't a valid program +/// object. +/// - ::UR_RESULT_ERROR_INVALID_SIZE +/// + `count == 0` +/// - ::UR_RESULT_ERROR_PROGRAM_LINK_FAILURE +/// + If an error occurred while linking `phPrograms`. +ur_result_t UR_APICALL urProgramDynamicLinkExp( + /// [in] handle of the context instance. + ur_context_handle_t hContext, + /// [in] number of program handles in `phPrograms`. + uint32_t count, + /// [in][range(0, count)] pointer to array of program handles. + const ur_program_handle_t *phPrograms) { + ur_result_t result = UR_RESULT_SUCCESS; + return result; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Enqueue a command for recording the device timestamp /// @@ -8869,6 +8906,8 @@ ur_result_t UR_APICALL urMemoryExportExportMemoryHandleExp( /// + `NULL == hProgram` /// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER /// + `NULL == phDevices` +/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION +/// + `::UR_EXP_PROGRAM_FLAGS_MASK & flags` /// - ::UR_RESULT_ERROR_INVALID_PROGRAM /// + If `hProgram` isn't a valid program object. 
/// - ::UR_RESULT_ERROR_PROGRAM_BUILD_FAILURE @@ -8880,6 +8919,8 @@ ur_result_t UR_APICALL urProgramBuildExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in][optional] pointer to build options null-terminated string. const char *pOptions) { ur_result_t result = UR_RESULT_SUCCESS; @@ -8908,6 +8949,8 @@ ur_result_t UR_APICALL urProgramBuildExp( /// + `NULL == hProgram` /// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER /// + `NULL == phDevices` +/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION +/// + `::UR_EXP_PROGRAM_FLAGS_MASK & flags` /// - ::UR_RESULT_ERROR_INVALID_PROGRAM /// + If `hProgram` isn't a valid program object. /// - ::UR_RESULT_ERROR_PROGRAM_BUILD_FAILURE @@ -8919,6 +8962,8 @@ ur_result_t UR_APICALL urProgramCompileExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in][optional] pointer to build options null-terminated string. const char *pOptions) { ur_result_t result = UR_RESULT_SUCCESS; @@ -8955,6 +9000,8 @@ ur_result_t UR_APICALL urProgramCompileExp( /// + `NULL == phDevices` /// + `NULL == phPrograms` /// + `NULL == phProgram` +/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION +/// + `::UR_EXP_PROGRAM_FLAGS_MASK & flags` /// - ::UR_RESULT_ERROR_INVALID_PROGRAM /// + If one of the programs in `phPrograms` isn't a valid program /// object. @@ -8969,6 +9016,8 @@ ur_result_t UR_APICALL urProgramLinkExp( uint32_t numDevices, /// [in][range(0, numDevices)] pointer to array of device handles ur_device_handle_t *phDevices, + /// [in] program information flags + ur_exp_program_flags_t flags, /// [in] number of program handles in `phPrograms`. uint32_t count, /// [in][range(0, count)] pointer to array of program handles. 
diff --git a/unified-runtime/test/conformance/program/urMultiDeviceProgramCreateWithBinary.cpp b/unified-runtime/test/conformance/program/urMultiDeviceProgramCreateWithBinary.cpp index 5f541fc9bc98f..a6631b8a89b4c 100644 --- a/unified-runtime/test/conformance/program/urMultiDeviceProgramCreateWithBinary.cpp +++ b/unified-runtime/test/conformance/program/urMultiDeviceProgramCreateWithBinary.cpp @@ -135,20 +135,23 @@ TEST_P(urMultiDeviceProgramCreateWithBinaryTest, MultipleBuildCalls) { auto second_subset = std::vector( devices.begin() + devices.size() / 2, devices.end()); ASSERT_SUCCESS(urProgramBuildExp(binary_program, first_subset.size(), - first_subset.data(), nullptr)); + first_subset.data(), + ur_exp_program_flags_t{}, nullptr)); auto kernelName = uur::KernelsEnvironment::instance->GetEntryPointNames("foo")[0]; uur::raii::Kernel kernel; ASSERT_SUCCESS( urKernelCreate(binary_program, kernelName.data(), kernel.ptr())); ASSERT_SUCCESS(urProgramBuildExp(binary_program, second_subset.size(), - second_subset.data(), nullptr)); + second_subset.data(), + ur_exp_program_flags_t{}, nullptr)); ASSERT_SUCCESS( urKernelCreate(binary_program, kernelName.data(), kernel.ptr())); // Building for the same subset of devices should not fail. 
ASSERT_SUCCESS(urProgramBuildExp(binary_program, first_subset.size(), - first_subset.data(), nullptr)); + first_subset.data(), + ur_exp_program_flags_t{}, nullptr)); } // Test the case we get native binaries from program created with multiple diff --git a/unified-runtime/test/conformance/program/urMultiDeviceProgramCreateWithIL.cpp b/unified-runtime/test/conformance/program/urMultiDeviceProgramCreateWithIL.cpp index 071a3ba901d7c..8630cb355405b 100644 --- a/unified-runtime/test/conformance/program/urMultiDeviceProgramCreateWithIL.cpp +++ b/unified-runtime/test/conformance/program/urMultiDeviceProgramCreateWithIL.cpp @@ -32,8 +32,8 @@ TEST_P(urMultiDeviceProgramTest, urMultiDeviceProgramGetInfo) { auto subset = std::vector( associated_devices.begin(), associated_devices.begin() + associated_devices.size() / 2); - ASSERT_SUCCESS( - urProgramBuildExp(program, subset.size(), subset.data(), nullptr)); + ASSERT_SUCCESS(urProgramBuildExp(program, subset.size(), subset.data(), + ur_exp_program_flags_t{}, nullptr)); std::vector binary_sizes(associated_devices.size()); ASSERT_SUCCESS(urProgramGetInfo(program, UR_PROGRAM_INFO_BINARY_SIZES, @@ -84,8 +84,8 @@ TEST_P(urMultiDeviceProgramTest, urMultiDeviceProgramGetInfoBinaries) { } // Build program for the second device only. 
- ASSERT_SUCCESS( - urProgramBuildExp(program, 1, associated_devices.data() + 1, nullptr)); + ASSERT_SUCCESS(urProgramBuildExp(program, 1, associated_devices.data() + 1, + ur_exp_program_flags_t{}, nullptr)); std::vector binary_sizes(associated_devices.size()); ASSERT_SUCCESS(urProgramGetInfo(program, UR_PROGRAM_INFO_BINARY_SIZES, binary_sizes.size() * sizeof(size_t), @@ -108,6 +108,7 @@ TEST_P(urMultiDeviceProgramTest, urMultiDeviceProgramGetInfoBinaries) { pointers.data() + 1, nullptr, &program_from_binary)); ASSERT_NE(program_from_binary, nullptr); ASSERT_SUCCESS(urProgramBuildExp(program_from_binary, 1, - associated_devices.data() + 1, nullptr)); + associated_devices.data() + 1, + ur_exp_program_flags_t{}, nullptr)); ASSERT_SUCCESS(urProgramRelease(program_from_binary)); } diff --git a/unified-runtime/tools/urinfo/urinfo.hpp b/unified-runtime/tools/urinfo/urinfo.hpp index 3407c57f847d7..4a29a4f38a8dd 100644 --- a/unified-runtime/tools/urinfo/urinfo.hpp +++ b/unified-runtime/tools/urinfo/urinfo.hpp @@ -461,6 +461,8 @@ inline void printDeviceInfos(ur_device_handle_t hDevice, printDeviceInfo(hDevice, UR_DEVICE_INFO_MULTI_DEVICE_COMPILE_SUPPORT_EXP); std::cout << prefix; + printDeviceInfo(hDevice, UR_DEVICE_INFO_DYNAMIC_LINK_SUPPORT_EXP); + std::cout << prefix; printDeviceInfo(hDevice, UR_DEVICE_INFO_USM_CONTEXT_MEMCPY_SUPPORT_EXP); std::cout << prefix;