Skip to content

Commit c94b5f0

Browse files
authored
Reland: [OpenMP][clang] 6.0: num_threads strict (part 3: codegen) (#155839)
OpenMP 6.0 12.1.2 specifies the behavior of the strict modifier for the num_threads clause on parallel directives, along with the message and severity clauses. This commit implements necessary codegen changes.
1 parent 6af2c18 commit c94b5f0

29 files changed

+17060
-465
lines changed

clang/include/clang/AST/OpenMPClause.h

Lines changed: 20 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1865,62 +1865,43 @@ class OMPSeverityClause final : public OMPClause {
18651865
/// \endcode
18661866
/// In this example directive '#pragma omp error' has simple
18671867
/// 'message' clause with user error message of "GNU compiler required.".
1868-
class OMPMessageClause final : public OMPClause {
1868+
class OMPMessageClause final
1869+
: public OMPOneStmtClause<llvm::omp::OMPC_message, OMPClause>,
1870+
public OMPClauseWithPreInit {
18691871
friend class OMPClauseReader;
18701872

1871-
/// Location of '('
1872-
SourceLocation LParenLoc;
1873-
1874-
// Expression of the 'message' clause.
1875-
Stmt *MessageString = nullptr;
1876-
18771873
/// Set message string of the clause.
1878-
void setMessageString(Expr *MS) { MessageString = MS; }
1879-
1880-
/// Sets the location of '('.
1881-
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
1874+
void setMessageString(Expr *MS) { setStmt(MS); }
18821875

18831876
public:
18841877
/// Build 'message' clause with message string argument
18851878
///
18861879
/// \param MS Argument of the clause (message string).
1880+
/// \param HelperMS Helper statement for the construct.
1881+
/// \param CaptureRegion Innermost OpenMP region where expressions in this
1882+
/// clause must be captured.
18871883
/// \param StartLoc Starting location of the clause.
18881884
/// \param LParenLoc Location of '('.
18891885
/// \param EndLoc Ending location of the clause.
1890-
OMPMessageClause(Expr *MS, SourceLocation StartLoc, SourceLocation LParenLoc,
1886+
OMPMessageClause(Expr *MS, Stmt *HelperMS, OpenMPDirectiveKind CaptureRegion,
1887+
SourceLocation StartLoc, SourceLocation LParenLoc,
18911888
SourceLocation EndLoc)
1892-
: OMPClause(llvm::omp::OMPC_message, StartLoc, EndLoc),
1893-
LParenLoc(LParenLoc), MessageString(MS) {}
1894-
1895-
/// Build an empty clause.
1896-
OMPMessageClause()
1897-
: OMPClause(llvm::omp::OMPC_message, SourceLocation(), SourceLocation()) {
1889+
: OMPOneStmtClause(MS, StartLoc, LParenLoc, EndLoc),
1890+
OMPClauseWithPreInit(this) {
1891+
setPreInitStmt(HelperMS, CaptureRegion);
18981892
}
18991893

1900-
/// Returns the locaiton of '('.
1901-
SourceLocation getLParenLoc() const { return LParenLoc; }
1894+
/// Build an empty clause.
1895+
OMPMessageClause() : OMPOneStmtClause(), OMPClauseWithPreInit(this) {}
19021896

19031897
/// Returns message string of the clause.
1904-
Expr *getMessageString() const { return cast_or_null<Expr>(MessageString); }
1905-
1906-
child_range children() {
1907-
return child_range(&MessageString, &MessageString + 1);
1908-
}
1909-
1910-
const_child_range children() const {
1911-
return const_child_range(&MessageString, &MessageString + 1);
1912-
}
1913-
1914-
child_range used_children() {
1915-
return child_range(child_iterator(), child_iterator());
1916-
}
1917-
1918-
const_child_range used_children() const {
1919-
return const_child_range(const_child_iterator(), const_child_iterator());
1920-
}
1898+
Expr *getMessageString() const { return getStmtAs<Expr>(); }
19211899

1922-
static bool classof(const OMPClause *T) {
1923-
return T->getClauseKind() == llvm::omp::OMPC_message;
1900+
/// Try to evaluate the message string at compile time.
1901+
std::optional<std::string> tryEvaluateString(ASTContext &Ctx) const {
1902+
if (Expr *MessageExpr = getMessageString())
1903+
return MessageExpr->tryEvaluateString(Ctx);
1904+
return std::nullopt;
19241905
}
19251906
};
19261907

clang/include/clang/Basic/DiagnosticParseKinds.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,8 +1506,8 @@ def err_omp_unexpected_directive : Error<
15061506
"unexpected OpenMP directive %select{|'#pragma omp %1'}0">;
15071507
def err_omp_expected_punc : Error<
15081508
"expected ',' or ')' in '%0' %select{clause|directive}1">;
1509-
def warn_clause_expected_string : Warning<
1510-
"expected string literal in 'clause %0' - ignoring">, InGroup<IgnoredPragmas>;
1509+
def warn_clause_expected_string: Warning<
1510+
"expected string %select{|literal }1in 'clause %0' - ignoring">, InGroup<IgnoredPragmas>;
15111511
def err_omp_unexpected_clause : Error<
15121512
"unexpected OpenMP clause '%0' in directive '#pragma omp %1'">;
15131513
def err_omp_unexpected_clause_extension_only : Error<

clang/lib/AST/OpenMPClause.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ const OMPClauseWithPreInit *OMPClauseWithPreInit::get(const OMPClause *C) {
104104
return static_cast<const OMPFilterClause *>(C);
105105
case OMPC_ompx_dyn_cgroup_mem:
106106
return static_cast<const OMPXDynCGroupMemClause *>(C);
107+
case OMPC_message:
108+
return static_cast<const OMPMessageClause *>(C);
107109
case OMPC_default:
108110
case OMPC_proc_bind:
109111
case OMPC_safelen:
@@ -158,7 +160,6 @@ const OMPClauseWithPreInit *OMPClauseWithPreInit::get(const OMPClause *C) {
158160
case OMPC_self_maps:
159161
case OMPC_at:
160162
case OMPC_severity:
161-
case OMPC_message:
162163
case OMPC_device_type:
163164
case OMPC_match:
164165
case OMPC_nontemporal:
@@ -1963,8 +1964,10 @@ void OMPClausePrinter::VisitOMPSeverityClause(OMPSeverityClause *Node) {
19631964
}
19641965

19651966
void OMPClausePrinter::VisitOMPMessageClause(OMPMessageClause *Node) {
1966-
OS << "message(\""
1967-
<< cast<StringLiteral>(Node->getMessageString())->getString() << "\")";
1967+
OS << "message(";
1968+
if (Expr *E = Node->getMessageString())
1969+
E->printPretty(OS, nullptr, Policy);
1970+
OS << ")";
19681971
}
19691972

19701973
void OMPClausePrinter::VisitOMPScheduleClause(OMPScheduleClause *Node) {

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1845,11 +1845,11 @@ void CGOpenMPRuntime::emitIfClause(CodeGenFunction &CGF, const Expr *Cond,
18451845
CGF.EmitBlock(ContBlock, /*IsFinished=*/true);
18461846
}
18471847

1848-
void CGOpenMPRuntime::emitParallelCall(CodeGenFunction &CGF, SourceLocation Loc,
1849-
llvm::Function *OutlinedFn,
1850-
ArrayRef<llvm::Value *> CapturedVars,
1851-
const Expr *IfCond,
1852-
llvm::Value *NumThreads) {
1848+
void CGOpenMPRuntime::emitParallelCall(
1849+
CodeGenFunction &CGF, SourceLocation Loc, llvm::Function *OutlinedFn,
1850+
ArrayRef<llvm::Value *> CapturedVars, const Expr *IfCond,
1851+
llvm::Value *NumThreads, OpenMPNumThreadsClauseModifier NumThreadsModifier,
1852+
OpenMPSeverityClauseKind Severity, const Expr *Message) {
18531853
if (!CGF.HaveInsertPoint())
18541854
return;
18551855
llvm::Value *RTLoc = emitUpdateLocation(CGF, Loc);
@@ -2372,9 +2372,8 @@ void CGOpenMPRuntime::emitBarrierCall(CodeGenFunction &CGF, SourceLocation Loc,
23722372

23732373
void CGOpenMPRuntime::emitErrorCall(CodeGenFunction &CGF, SourceLocation Loc,
23742374
Expr *ME, bool IsFatal) {
2375-
llvm::Value *MVL =
2376-
ME ? CGF.EmitStringLiteralLValue(cast<StringLiteral>(ME)).getPointer(CGF)
2377-
: llvm::ConstantPointerNull::get(CGF.VoidPtrTy);
2375+
llvm::Value *MVL = ME ? CGF.EmitScalarExpr(ME)
2376+
: llvm::ConstantPointerNull::get(CGF.VoidPtrTy);
23782377
// Build call void __kmpc_error(ident_t *loc, int severity, const char
23792378
// *message)
23802379
llvm::Value *Args[] = {
@@ -2699,18 +2698,54 @@ llvm::Value *CGOpenMPRuntime::emitForNext(CodeGenFunction &CGF,
26992698
CGF.getContext().BoolTy, Loc);
27002699
}
27012700

2702-
void CGOpenMPRuntime::emitNumThreadsClause(CodeGenFunction &CGF,
2703-
llvm::Value *NumThreads,
2704-
SourceLocation Loc) {
2701+
llvm::Value *CGOpenMPRuntime::emitMessageClause(CodeGenFunction &CGF,
2702+
const Expr *Message) {
2703+
if (!Message)
2704+
return llvm::ConstantPointerNull::get(CGF.VoidPtrTy);
2705+
return CGF.EmitScalarExpr(Message);
2706+
}
2707+
2708+
llvm::Value *
2709+
CGOpenMPRuntime::emitMessageClause(CodeGenFunction &CGF,
2710+
const OMPMessageClause *MessageClause) {
2711+
return emitMessageClause(
2712+
CGF, MessageClause ? MessageClause->getMessageString() : nullptr);
2713+
}
2714+
2715+
llvm::Value *
2716+
CGOpenMPRuntime::emitSeverityClause(OpenMPSeverityClauseKind Severity) {
2717+
// OpenMP 6.0, 10.4: "If no severity clause is specified then the effect is
2718+
// as if sev-level is fatal."
2719+
return llvm::ConstantInt::get(CGM.Int32Ty,
2720+
Severity == OMPC_SEVERITY_warning ? 1 : 2);
2721+
}
2722+
2723+
llvm::Value *
2724+
CGOpenMPRuntime::emitSeverityClause(const OMPSeverityClause *SeverityClause) {
2725+
return emitSeverityClause(SeverityClause ? SeverityClause->getSeverityKind()
2726+
: OMPC_SEVERITY_unknown);
2727+
}
2728+
2729+
void CGOpenMPRuntime::emitNumThreadsClause(
2730+
CodeGenFunction &CGF, llvm::Value *NumThreads, SourceLocation Loc,
2731+
OpenMPNumThreadsClauseModifier Modifier, OpenMPSeverityClauseKind Severity,
2732+
const Expr *Message) {
27052733
if (!CGF.HaveInsertPoint())
27062734
return;
2735+
llvm::SmallVector<llvm::Value *, 4> Args(
2736+
{emitUpdateLocation(CGF, Loc), getThreadID(CGF, Loc),
2737+
CGF.Builder.CreateIntCast(NumThreads, CGF.Int32Ty, /*isSigned*/ true)});
27072738
// Build call __kmpc_push_num_threads(&loc, global_tid, num_threads)
2708-
llvm::Value *Args[] = {
2709-
emitUpdateLocation(CGF, Loc), getThreadID(CGF, Loc),
2710-
CGF.Builder.CreateIntCast(NumThreads, CGF.Int32Ty, /*isSigned*/ true)};
2711-
CGF.EmitRuntimeCall(OMPBuilder.getOrCreateRuntimeFunction(
2712-
CGM.getModule(), OMPRTL___kmpc_push_num_threads),
2713-
Args);
2739+
// or __kmpc_push_num_threads_strict(&loc, global_tid, num_threads, severity,
2740+
// messsage) if strict modifier is used.
2741+
RuntimeFunction FnID = OMPRTL___kmpc_push_num_threads;
2742+
if (Modifier == OMPC_NUMTHREADS_strict) {
2743+
FnID = OMPRTL___kmpc_push_num_threads_strict;
2744+
Args.push_back(emitSeverityClause(Severity));
2745+
Args.push_back(emitMessageClause(CGF, Message));
2746+
}
2747+
CGF.EmitRuntimeCall(
2748+
OMPBuilder.getOrCreateRuntimeFunction(CGM.getModule(), FnID), Args);
27142749
}
27152750

27162751
void CGOpenMPRuntime::emitProcBindClause(CodeGenFunction &CGF,
@@ -12114,12 +12149,11 @@ llvm::Function *CGOpenMPSIMDRuntime::emitTaskOutlinedFunction(
1211412149
llvm_unreachable("Not supported in SIMD-only mode");
1211512150
}
1211612151

12117-
void CGOpenMPSIMDRuntime::emitParallelCall(CodeGenFunction &CGF,
12118-
SourceLocation Loc,
12119-
llvm::Function *OutlinedFn,
12120-
ArrayRef<llvm::Value *> CapturedVars,
12121-
const Expr *IfCond,
12122-
llvm::Value *NumThreads) {
12152+
void CGOpenMPSIMDRuntime::emitParallelCall(
12153+
CodeGenFunction &CGF, SourceLocation Loc, llvm::Function *OutlinedFn,
12154+
ArrayRef<llvm::Value *> CapturedVars, const Expr *IfCond,
12155+
llvm::Value *NumThreads, OpenMPNumThreadsClauseModifier NumThreadsModifier,
12156+
OpenMPSeverityClauseKind Severity, const Expr *Message) {
1212312157
llvm_unreachable("Not supported in SIMD-only mode");
1212412158
}
1212512159

@@ -12222,9 +12256,10 @@ llvm::Value *CGOpenMPSIMDRuntime::emitForNext(CodeGenFunction &CGF,
1222212256
llvm_unreachable("Not supported in SIMD-only mode");
1222312257
}
1222412258

12225-
void CGOpenMPSIMDRuntime::emitNumThreadsClause(CodeGenFunction &CGF,
12226-
llvm::Value *NumThreads,
12227-
SourceLocation Loc) {
12259+
void CGOpenMPSIMDRuntime::emitNumThreadsClause(
12260+
CodeGenFunction &CGF, llvm::Value *NumThreads, SourceLocation Loc,
12261+
OpenMPNumThreadsClauseModifier Modifier, OpenMPSeverityClauseKind Severity,
12262+
const Expr *Message) {
1222812263
llvm_unreachable("Not supported in SIMD-only mode");
1222912264
}
1223012265

clang/lib/CodeGen/CGOpenMPRuntime.h

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -777,11 +777,22 @@ class CGOpenMPRuntime {
777777
/// specified, nullptr otherwise.
778778
/// \param NumThreads The value corresponding to the num_threads clause, if
779779
/// any, or nullptr.
780+
/// \param NumThreadsModifier The modifier of the num_threads clause, if
781+
/// any, ignored otherwise.
782+
/// \param Severity The severity corresponding to the num_threads clause, if
783+
/// any, ignored otherwise.
784+
/// \param Message The message string corresponding to the num_threads clause,
785+
/// if any, or nullptr.
780786
///
781-
virtual void emitParallelCall(CodeGenFunction &CGF, SourceLocation Loc,
782-
llvm::Function *OutlinedFn,
783-
ArrayRef<llvm::Value *> CapturedVars,
784-
const Expr *IfCond, llvm::Value *NumThreads);
787+
virtual void
788+
emitParallelCall(CodeGenFunction &CGF, SourceLocation Loc,
789+
llvm::Function *OutlinedFn,
790+
ArrayRef<llvm::Value *> CapturedVars, const Expr *IfCond,
791+
llvm::Value *NumThreads,
792+
OpenMPNumThreadsClauseModifier NumThreadsModifier =
793+
OMPC_NUMTHREADS_unknown,
794+
OpenMPSeverityClauseKind Severity = OMPC_SEVERITY_fatal,
795+
const Expr *Message = nullptr);
785796

786797
/// Emits a critical region.
787798
/// \param CriticalName Name of the critical region.
@@ -1037,13 +1048,28 @@ class CGOpenMPRuntime {
10371048
Address IL, Address LB,
10381049
Address UB, Address ST);
10391050

1051+
virtual llvm::Value *emitMessageClause(CodeGenFunction &CGF,
1052+
const Expr *Message);
1053+
virtual llvm::Value *emitMessageClause(CodeGenFunction &CGF,
1054+
const OMPMessageClause *MessageClause);
1055+
1056+
virtual llvm::Value *emitSeverityClause(OpenMPSeverityClauseKind Severity);
1057+
virtual llvm::Value *
1058+
emitSeverityClause(const OMPSeverityClause *SeverityClause);
1059+
10401060
/// Emits call to void __kmpc_push_num_threads(ident_t *loc, kmp_int32
10411061
/// global_tid, kmp_int32 num_threads) to generate code for 'num_threads'
10421062
/// clause.
1063+
/// If the modifier 'strict' is given:
1064+
/// Emits call to void __kmpc_push_num_threads_strict(ident_t *loc, kmp_int32
1065+
/// global_tid, kmp_int32 num_threads, int severity, const char *message) to
1066+
/// generate code for 'num_threads' clause with 'strict' modifier.
10431067
/// \param NumThreads An integer value of threads.
1044-
virtual void emitNumThreadsClause(CodeGenFunction &CGF,
1045-
llvm::Value *NumThreads,
1046-
SourceLocation Loc);
1068+
virtual void emitNumThreadsClause(
1069+
CodeGenFunction &CGF, llvm::Value *NumThreads, SourceLocation Loc,
1070+
OpenMPNumThreadsClauseModifier Modifier = OMPC_NUMTHREADS_unknown,
1071+
OpenMPSeverityClauseKind Severity = OMPC_SEVERITY_fatal,
1072+
const Expr *Message = nullptr);
10471073

10481074
/// Emit call to void __kmpc_push_proc_bind(ident_t *loc, kmp_int32
10491075
/// global_tid, int proc_bind) to generate code for 'proc_bind' clause.
@@ -1737,11 +1763,21 @@ class CGOpenMPSIMDRuntime final : public CGOpenMPRuntime {
17371763
/// specified, nullptr otherwise.
17381764
/// \param NumThreads The value corresponding to the num_threads clause, if
17391765
/// any, or nullptr.
1766+
/// \param NumThreadsModifier The modifier of the num_threads clause, if
1767+
/// any, ignored otherwise.
1768+
/// \param Severity The severity corresponding to the num_threads clause, if
1769+
/// any, ignored otherwise.
1770+
/// \param Message The message string corresponding to the num_threads clause,
1771+
/// if any, or nullptr.
17401772
///
17411773
void emitParallelCall(CodeGenFunction &CGF, SourceLocation Loc,
17421774
llvm::Function *OutlinedFn,
17431775
ArrayRef<llvm::Value *> CapturedVars,
1744-
const Expr *IfCond, llvm::Value *NumThreads) override;
1776+
const Expr *IfCond, llvm::Value *NumThreads,
1777+
OpenMPNumThreadsClauseModifier NumThreadsModifier =
1778+
OMPC_NUMTHREADS_unknown,
1779+
OpenMPSeverityClauseKind Severity = OMPC_SEVERITY_fatal,
1780+
const Expr *Message = nullptr) override;
17451781

17461782
/// Emits a critical region.
17471783
/// \param CriticalName Name of the critical region.
@@ -1911,9 +1947,16 @@ class CGOpenMPSIMDRuntime final : public CGOpenMPRuntime {
19111947
/// Emits call to void __kmpc_push_num_threads(ident_t *loc, kmp_int32
19121948
/// global_tid, kmp_int32 num_threads) to generate code for 'num_threads'
19131949
/// clause.
1950+
/// If the modifier 'strict' is given:
1951+
/// Emits call to void __kmpc_push_num_threads_strict(ident_t *loc, kmp_int32
1952+
/// global_tid, kmp_int32 num_threads, int severity, const char *message) to
1953+
/// generate code for 'num_threads' clause with 'strict' modifier.
19141954
/// \param NumThreads An integer value of threads.
1915-
void emitNumThreadsClause(CodeGenFunction &CGF, llvm::Value *NumThreads,
1916-
SourceLocation Loc) override;
1955+
void emitNumThreadsClause(
1956+
CodeGenFunction &CGF, llvm::Value *NumThreads, SourceLocation Loc,
1957+
OpenMPNumThreadsClauseModifier Modifier = OMPC_NUMTHREADS_unknown,
1958+
OpenMPSeverityClauseKind Severity = OMPC_SEVERITY_fatal,
1959+
const Expr *Message = nullptr) override;
19171960

19181961
/// Emit call to void __kmpc_push_proc_bind(ident_t *loc, kmp_int32
19191962
/// global_tid, int proc_bind) to generate code for 'proc_bind' clause.

0 commit comments

Comments
 (0)