Skip to content

Commit 125b05c

Browse files
[SYCL] Handle KernelNameType templated using template template type with enum template argument (#1806)
Add support for KernelNameType with is a template specialization class with template template arguments which in turn is templated using enum. Signed-off-by: Elizabeth Andrews <[email protected]>
1 parent e3da4ef commit 125b05c

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1857,6 +1857,38 @@ void SYCLIntegrationHeader::emitForwardClassDecls(
18571857
// template class Foo specialized by class Baz<Bar>, not a template
18581858
// class template <template <typename> class> class T as it should.
18591859
TemplateDecl *TD = Arg.getAsTemplate().getAsTemplateDecl();
1860+
TemplateParameterList *TemplateParams = TD->getTemplateParameters();
1861+
for (NamedDecl *P : *TemplateParams) {
1862+
// If template template paramter type has an enum value template
1863+
// parameter, forward declaration of enum type is required. Only enum
1864+
// values (not types) need to be handled. For example, consider the
1865+
// following kernel name type:
1866+
//
1867+
// template <typename EnumTypeOut, template <EnumValueIn EnumValue,
1868+
// typename TypeIn> class T> class Foo;
1869+
//
1870+
// The correct specialization for Foo (with enum type) is:
1871+
// Foo<EnumTypeOut, Baz>, where Baz is a template class.
1872+
//
1873+
// Therefore the forward class declarations generated in the
1874+
// integration header are:
1875+
// template <EnumValueIn EnumValue, typename TypeIn> class Baz;
1876+
// template <typename EnumTypeOut, template <EnumValueIn EnumValue,
1877+
// typename EnumTypeIn> class T> class Foo;
1878+
//
1879+
// This requires the following enum forward declarations:
1880+
// enum class EnumTypeOut : int; (Used to template Foo)
1881+
// enum class EnumValueIn : int; (Used to template Baz)
1882+
if (NonTypeTemplateParmDecl *TemplateParam =
1883+
dyn_cast<NonTypeTemplateParmDecl>(P)) {
1884+
QualType T = TemplateParam->getType();
1885+
if (const auto *ET = T->getAs<EnumType>()) {
1886+
const EnumDecl *ED = ET->getDecl();
1887+
if (!checkEnumTemplateParameter(ED, Diag, KernelLocation))
1888+
emitFwdDecl(O, ED, KernelLocation);
1889+
}
1890+
}
1891+
}
18601892
if (Printed.insert(TD).second) {
18611893
emitFwdDecl(O, TD, KernelLocation);
18621894
}
@@ -1923,6 +1955,11 @@ static void printArgument(ASTContext &Ctx, raw_ostream &ArgOS,
19231955
ArgOS << getKernelNameTypeString(T, Ctx, TypePolicy);
19241956
break;
19251957
}
1958+
case TemplateArgument::ArgKind::Template: {
1959+
TemplateDecl *TD = Arg.getAsTemplate().getAsTemplateDecl();
1960+
ArgOS << TD->getQualifiedNameAsString();
1961+
break;
1962+
}
19261963
default:
19271964
Arg.print(P, ArgOS);
19281965
}

clang/test/CodeGenSYCL/kernelname-enum.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,20 @@ class T2 {};
9494
template <typename EnumType>
9595
class T3 {};
9696

97+
enum class EnumTypeOut : int { A,
98+
B,
99+
};
100+
enum class EnumValueIn : int { A,
101+
B,
102+
};
103+
template <EnumValueIn EnumValue, typename EnumTypeIn>
104+
class Baz;
105+
template <typename EnumTypeOut, template <EnumValueIn EnumValue, typename EnumTypeIn> class T>
106+
class dummy_functor_8 {
107+
public:
108+
void operator()() {}
109+
};
110+
97111
int main() {
98112

99113
dummy_functor_1<no_namespace_int::val_1> f1;
@@ -104,6 +118,7 @@ int main() {
104118
dummy_functor_6<unscoped_enum::val_1> f6;
105119
dummy_functor_7<no_namespace_int> f7;
106120
dummy_functor_7<internal::namespace_short> f8;
121+
dummy_functor_8<EnumTypeOut, Baz> f9;
107122

108123
cl::sycl::queue q;
109124

@@ -147,6 +162,10 @@ int main() {
147162
cgh.single_task<T1<T3<type_argument_template_enum::E>>>([=]() {});
148163
});
149164

165+
q.submit([&](cl::sycl::handler &cgh) {
166+
cgh.single_task(f9);
167+
});
168+
150169
return 0;
151170
}
152171

@@ -173,6 +192,10 @@ int main() {
173192
// CHECK-NEXT: }
174193
// CHECK: template <type_argument_template_enum::E EnumValue> class T2;
175194
// CHECK: template <typename T> class T1;
195+
// CHECK: enum class EnumTypeOut : int;
196+
// CHECK: enum class EnumValueIn : int;
197+
// CHECK: template <EnumValueIn EnumValue, typename EnumTypeIn> class Baz;
198+
// CHECK: template <typename EnumTypeOut, template <EnumValueIn EnumValue, typename EnumTypeIn> class T> class dummy_functor_8;
176199
// CHECK: Specializations of KernelInfo for kernel function types:
177200
// CHECK: template <> struct KernelInfo<::dummy_functor_1<(no_namespace_int)0>>
178201
// CHECK: template <> struct KernelInfo<::dummy_functor_2<(no_namespace_short)1>>
@@ -184,3 +207,4 @@ int main() {
184207
// CHECK: template <> struct KernelInfo<::dummy_functor_7<::internal::namespace_short>>
185208
// CHECK: template <> struct KernelInfo<::T1<::T2<(type_argument_template_enum::E)0>>>
186209
// CHECK: template <> struct KernelInfo<::T1<::T3<::type_argument_template_enum::E>>>
210+
// CHECK: template <> struct KernelInfo<::dummy_functor_8<::EnumTypeOut, Baz>>

0 commit comments

Comments
 (0)