Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1831,6 +1831,38 @@ void SYCLIntegrationHeader::emitForwardClassDecls(
// template class Foo specialized by class Baz<Bar>, not a template
// class template <template <typename> class> class T as it should.
TemplateDecl *TD = Arg.getAsTemplate().getAsTemplateDecl();
TemplateParameterList *TemplateParams = TD->getTemplateParameters();
for (NamedDecl *P : *TemplateParams) {
// If template template paramter type has an enum value template
// parameter, forward declaration of enum type is required. Only enum
// values (not types) need to be handled. For example, consider the
// following kernel name type:
//
// template <typename EnumTypeOut, template <EnumValueIn EnumValue,
// typename TypeIn> class T> class Foo;
//
// The correct specialization for Foo (with enum type) is:
// Foo<EnumTypeOut, Baz>, where Baz is a template class.
//
// Therefore the forward class declarations generated in the
// integration header are:
// template <EnumValueIn EnumValue, typename TypeIn> class Baz;
// template <typename EnumTypeOut, template <EnumValueIn EnumValue,
// typename EnumTypeIn> class T> class Foo;
//
// This requires the following enum forward declarations:
// enum class EnumTypeOut : int; (Used to template Foo)
// enum class EnumValueIn : int; (Used to template Baz)
if (NonTypeTemplateParmDecl *TemplateParam =
dyn_cast<NonTypeTemplateParmDecl>(P)) {
QualType T = TemplateParam->getType();
if (const auto *ET = T->getAs<EnumType>()) {
const EnumDecl *ED = ET->getDecl();
if (!checkEnumTemplateParameter(ED, Diag, KernelLocation))
emitFwdDecl(O, ED, KernelLocation);
}
}
}
if (Printed.insert(TD).second) {
emitFwdDecl(O, TD, KernelLocation);
}
Expand Down Expand Up @@ -1897,6 +1929,11 @@ static void printArgument(ASTContext &Ctx, raw_ostream &ArgOS,
ArgOS << getKernelNameTypeString(T, Ctx, TypePolicy);
break;
}
case TemplateArgument::ArgKind::Template: {
TemplateDecl *TD = Arg.getAsTemplate().getAsTemplateDecl();
ArgOS << TD->getQualifiedNameAsString();
break;
}
default:
Arg.print(P, ArgOS);
}
Expand Down
24 changes: 24 additions & 0 deletions clang/test/CodeGenSYCL/kernelname-enum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,20 @@ class T2 {};
template <typename EnumType>
class T3 {};

enum class EnumTypeOut : int { A,
B,
};
enum class EnumValueIn : int { A,
B,
};
template <EnumValueIn EnumValue, typename EnumTypeIn>
class Baz;
template <typename EnumTypeOut, template <EnumValueIn EnumValue, typename EnumTypeIn> class T>
class dummy_functor_8 {
public:
void operator()() {}
};

int main() {

dummy_functor_1<no_namespace_int::val_1> f1;
Expand All @@ -104,6 +118,7 @@ int main() {
dummy_functor_6<unscoped_enum::val_1> f6;
dummy_functor_7<no_namespace_int> f7;
dummy_functor_7<internal::namespace_short> f8;
dummy_functor_8<EnumTypeOut, Baz> f9;

cl::sycl::queue q;

Expand Down Expand Up @@ -147,6 +162,10 @@ int main() {
cgh.single_task<T1<T3<type_argument_template_enum::E>>>([=]() {});
});

q.submit([&](cl::sycl::handler &cgh) {
cgh.single_task(f9);
});

return 0;
}

Expand All @@ -173,6 +192,10 @@ int main() {
// CHECK-NEXT: }
// CHECK: template <type_argument_template_enum::E EnumValue> class T2;
// CHECK: template <typename T> class T1;
// CHECK: enum class EnumTypeOut : int;
// CHECK: enum class EnumValueIn : int;
// CHECK: template <EnumValueIn EnumValue, typename EnumTypeIn> class Baz;
// CHECK: template <typename EnumTypeOut, template <EnumValueIn EnumValue, typename EnumTypeIn> class T> class dummy_functor_8;
// CHECK: Specializations of KernelInfo for kernel function types:
// CHECK: template <> struct KernelInfo<::dummy_functor_1<(no_namespace_int)0>>
// CHECK: template <> struct KernelInfo<::dummy_functor_2<(no_namespace_short)1>>
Expand All @@ -184,3 +207,4 @@ int main() {
// CHECK: template <> struct KernelInfo<::dummy_functor_7<::internal::namespace_short>>
// CHECK: template <> struct KernelInfo<::T1<::T2<(type_argument_template_enum::E)0>>>
// CHECK: template <> struct KernelInfo<::T1<::T3<::type_argument_template_enum::E>>>
// CHECK: template <> struct KernelInfo<::dummy_functor_8<::EnumTypeOut, Baz>>