Skip to content

Commit 60dc68d

Browse files
JDPailleuxbonachea
andauthored
[flang][Lower] Add Lowering for CO_{BROADCAST, MAX, MIN, SUM} to PRIF (#154770)
In relation to the approval and merge of the #76088 specification about multi-image features in Flang. Here is a PR on adding support of the collectives CO_BROADCAST, CO_SUM, CO_MIN and CO_MAX in conformance with the PRIF specification. --------- Co-authored-by: Dan Bonachea <[email protected]>
1 parent 39c8df3 commit 60dc68d

File tree

8 files changed

+661
-9
lines changed

8 files changed

+661
-9
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,10 @@ struct IntrinsicLibrary {
246246
template <mlir::arith::CmpIPredicate pred>
247247
fir::ExtendedValue genCPtrCompare(mlir::Type,
248248
llvm::ArrayRef<fir::ExtendedValue>);
249+
void genCoBroadcast(llvm::ArrayRef<fir::ExtendedValue>);
250+
void genCoMax(llvm::ArrayRef<fir::ExtendedValue>);
251+
void genCoMin(llvm::ArrayRef<fir::ExtendedValue>);
252+
void genCoSum(llvm::ArrayRef<fir::ExtendedValue>);
249253
mlir::Value genCosd(mlir::Type, llvm::ArrayRef<mlir::Value>);
250254
mlir::Value genCospi(mlir::Type, llvm::ArrayRef<mlir::Value>);
251255
void genDateAndTime(llvm::ArrayRef<fir::ExtendedValue>);

flang/include/flang/Optimizer/Builder/Runtime/Coarray.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ namespace fir::runtime {
3434
return fir::NameUniquer::doProcedure({"prif"}, {}, oss.str()); \
3535
}()
3636

37+
#define PRIF_STAT_TYPE builder.getRefType(builder.getI32Type())
38+
#define PRIF_ERRMSG_TYPE \
39+
fir::BoxType::get(fir::CharacterType::get(builder.getContext(), 1, \
40+
fir::CharacterType::unknownLen()))
41+
3742
/// Generate Call to runtime prif_init
3843
mlir::Value genInitCoarray(fir::FirOpBuilder &builder, mlir::Location loc);
3944

@@ -49,5 +54,22 @@ mlir::Value getNumImagesWithTeam(fir::FirOpBuilder &builder, mlir::Location loc,
4954
mlir::Value getThisImage(fir::FirOpBuilder &builder, mlir::Location loc,
5055
mlir::Value team = {});
5156

57+
/// Generate call to runtime subroutine prif_co_broadcast
58+
void genCoBroadcast(fir::FirOpBuilder &builder, mlir::Location loc,
59+
mlir::Value A, mlir::Value sourceImage, mlir::Value stat,
60+
mlir::Value errmsg);
61+
62+
/// Generate call to runtime subroutine prif_co_max and prif_co_max_character
63+
void genCoMax(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value A,
64+
mlir::Value resultImage, mlir::Value stat, mlir::Value errmsg);
65+
66+
/// Generate call to runtime subroutine prif_co_min or prif_co_min_character
67+
void genCoMin(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value A,
68+
mlir::Value resultImage, mlir::Value stat, mlir::Value errmsg);
69+
70+
/// Generate call to runtime subroutine prif_co_sum
71+
void genCoSum(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value A,
72+
mlir::Value resultImage, mlir::Value stat, mlir::Value errmsg);
73+
5274
} // namespace fir::runtime
5375
#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_COARRAY_H

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,34 @@ static constexpr IntrinsicHandler handlers[]{
397397
{"cmplx",
398398
&I::genCmplx,
399399
{{{"x", asValue}, {"y", asValue, handleDynamicOptional}}}},
400+
{"co_broadcast",
401+
&I::genCoBroadcast,
402+
{{{"a", asBox},
403+
{"source_image", asAddr},
404+
{"stat", asAddr, handleDynamicOptional},
405+
{"errmsg", asBox, handleDynamicOptional}}},
406+
/*isElemental*/ false},
407+
{"co_max",
408+
&I::genCoMax,
409+
{{{"a", asBox},
410+
{"result_image", asAddr, handleDynamicOptional},
411+
{"stat", asAddr, handleDynamicOptional},
412+
{"errmsg", asBox, handleDynamicOptional}}},
413+
/*isElemental*/ false},
414+
{"co_min",
415+
&I::genCoMin,
416+
{{{"a", asBox},
417+
{"result_image", asAddr, handleDynamicOptional},
418+
{"stat", asAddr, handleDynamicOptional},
419+
{"errmsg", asBox, handleDynamicOptional}}},
420+
/*isElemental*/ false},
421+
{"co_sum",
422+
&I::genCoSum,
423+
{{{"a", asBox},
424+
{"result_image", asAddr, handleDynamicOptional},
425+
{"stat", asAddr, handleDynamicOptional},
426+
{"errmsg", asBox, handleDynamicOptional}}},
427+
/*isElemental*/ false},
400428
{"command_argument_count", &I::genCommandArgumentCount},
401429
{"conjg", &I::genConjg},
402430
{"cosd", &I::genCosd},
@@ -3686,6 +3714,85 @@ mlir::Value IntrinsicLibrary::genCmplx(mlir::Type resultType,
36863714
imag);
36873715
}
36883716

3717+
// CO_BROADCAST
3718+
void IntrinsicLibrary::genCoBroadcast(llvm::ArrayRef<fir::ExtendedValue> args) {
3719+
checkCoarrayEnabled();
3720+
assert(args.size() == 4);
3721+
mlir::Value sourceImage = fir::getBase(args[1]);
3722+
mlir::Value status =
3723+
isStaticallyAbsent(args[2])
3724+
? fir::AbsentOp::create(builder, loc,
3725+
builder.getRefType(builder.getI32Type()))
3726+
.getResult()
3727+
: fir::getBase(args[2]);
3728+
mlir::Value errmsg =
3729+
isStaticallyAbsent(args[3])
3730+
? fir::AbsentOp::create(builder, loc, PRIF_ERRMSG_TYPE).getResult()
3731+
: fir::getBase(args[3]);
3732+
fir::runtime::genCoBroadcast(builder, loc, fir::getBase(args[0]), sourceImage,
3733+
status, errmsg);
3734+
}
3735+
3736+
// CO_MAX
3737+
void IntrinsicLibrary::genCoMax(llvm::ArrayRef<fir::ExtendedValue> args) {
3738+
checkCoarrayEnabled();
3739+
assert(args.size() == 4);
3740+
mlir::Value refNone =
3741+
fir::AbsentOp::create(builder, loc,
3742+
builder.getRefType(builder.getI32Type()))
3743+
.getResult();
3744+
mlir::Value resultImage =
3745+
isStaticallyAbsent(args[1]) ? refNone : fir::getBase(args[1]);
3746+
mlir::Value status =
3747+
isStaticallyAbsent(args[2]) ? refNone : fir::getBase(args[2]);
3748+
mlir::Value errmsg =
3749+
isStaticallyAbsent(args[3])
3750+
? fir::AbsentOp::create(builder, loc, PRIF_ERRMSG_TYPE).getResult()
3751+
: fir::getBase(args[3]);
3752+
fir::runtime::genCoMax(builder, loc, fir::getBase(args[0]), resultImage,
3753+
status, errmsg);
3754+
}
3755+
3756+
// CO_MIN
3757+
void IntrinsicLibrary::genCoMin(llvm::ArrayRef<fir::ExtendedValue> args) {
3758+
checkCoarrayEnabled();
3759+
assert(args.size() == 4);
3760+
mlir::Value refNone =
3761+
fir::AbsentOp::create(builder, loc,
3762+
builder.getRefType(builder.getI32Type()))
3763+
.getResult();
3764+
mlir::Value resultImage =
3765+
isStaticallyAbsent(args[1]) ? refNone : fir::getBase(args[1]);
3766+
mlir::Value status =
3767+
isStaticallyAbsent(args[2]) ? refNone : fir::getBase(args[2]);
3768+
mlir::Value errmsg =
3769+
isStaticallyAbsent(args[3])
3770+
? fir::AbsentOp::create(builder, loc, PRIF_ERRMSG_TYPE).getResult()
3771+
: fir::getBase(args[3]);
3772+
fir::runtime::genCoMin(builder, loc, fir::getBase(args[0]), resultImage,
3773+
status, errmsg);
3774+
}
3775+
3776+
// CO_SUM
3777+
void IntrinsicLibrary::genCoSum(llvm::ArrayRef<fir::ExtendedValue> args) {
3778+
checkCoarrayEnabled();
3779+
assert(args.size() == 4);
3780+
mlir::Value absentInt =
3781+
fir::AbsentOp::create(builder, loc,
3782+
builder.getRefType(builder.getI32Type()))
3783+
.getResult();
3784+
mlir::Value resultImage =
3785+
isStaticallyAbsent(args[1]) ? absentInt : fir::getBase(args[1]);
3786+
mlir::Value status =
3787+
isStaticallyAbsent(args[2]) ? absentInt : fir::getBase(args[2]);
3788+
mlir::Value errmsg =
3789+
isStaticallyAbsent(args[3])
3790+
? fir::AbsentOp::create(builder, loc, PRIF_ERRMSG_TYPE).getResult()
3791+
: fir::getBase(args[3]);
3792+
fir::runtime::genCoSum(builder, loc, fir::getBase(args[0]), resultImage,
3793+
status, errmsg);
3794+
}
3795+
36893796
// COMMAND_ARGUMENT_COUNT
36903797
fir::ExtendedValue IntrinsicLibrary::genCommandArgumentCount(
36913798
mlir::Type resultType, llvm::ArrayRef<fir::ExtendedValue> args) {

flang/lib/Optimizer/Builder/Runtime/Coarray.cpp

Lines changed: 90 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,24 @@
1414
using namespace Fortran::runtime;
1515
using namespace Fortran::semantics;
1616

17+
// Most PRIF functions take `errmsg` and `errmsg_alloc` as two optional
18+
// arguments of intent (out). One is allocatable, the other is not.
19+
// It is the responsibility of the compiler to ensure that the appropriate
20+
// optional argument is passed, and at most one must be provided in a given
21+
// call.
22+
// Depending on the type of `errmsg`, this function will return the pair
23+
// corresponding to (`errmsg`, `errmsg_alloc`).
24+
static std::pair<mlir::Value, mlir::Value>
25+
genErrmsgPRIF(fir::FirOpBuilder &builder, mlir::Location loc,
26+
mlir::Value errmsg) {
27+
bool isAllocatableErrmsg = fir::isAllocatableType(errmsg.getType());
28+
29+
mlir::Value absent = fir::AbsentOp::create(builder, loc, PRIF_ERRMSG_TYPE);
30+
mlir::Value errMsg = isAllocatableErrmsg ? absent : errmsg;
31+
mlir::Value errMsgAlloc = isAllocatableErrmsg ? errmsg : absent;
32+
return {errMsg, errMsgAlloc};
33+
}
34+
1735
/// Generate Call to runtime prif_init
1836
mlir::Value fir::runtime::genInitCoarray(fir::FirOpBuilder &builder,
1937
mlir::Location loc) {
@@ -24,8 +42,8 @@ mlir::Value fir::runtime::genInitCoarray(fir::FirOpBuilder &builder,
2442
builder.createFunction(loc, PRIFNAME_SUB("init"), ftype);
2543
llvm::SmallVector<mlir::Value> args =
2644
fir::runtime::createArguments(builder, loc, ftype, result);
27-
builder.create<fir::CallOp>(loc, funcOp, args);
28-
return builder.create<fir::LoadOp>(loc, result);
45+
fir::CallOp::create(builder, loc, funcOp, args);
46+
return fir::LoadOp::create(builder, loc, result);
2947
}
3048

3149
/// Generate Call to runtime prif_num_images
@@ -38,8 +56,8 @@ mlir::Value fir::runtime::getNumImages(fir::FirOpBuilder &builder,
3856
builder.createFunction(loc, PRIFNAME_SUB("num_images"), ftype);
3957
llvm::SmallVector<mlir::Value> args =
4058
fir::runtime::createArguments(builder, loc, ftype, result);
41-
builder.create<fir::CallOp>(loc, funcOp, args);
42-
return builder.create<fir::LoadOp>(loc, result);
59+
fir::CallOp::create(builder, loc, funcOp, args);
60+
return fir::LoadOp::create(builder, loc, result);
4361
}
4462

4563
/// Generate Call to runtime prif_num_images_with_{team|team_number}
@@ -63,8 +81,8 @@ mlir::Value fir::runtime::getNumImagesWithTeam(fir::FirOpBuilder &builder,
6381
team = builder.createBox(loc, team);
6482
llvm::SmallVector<mlir::Value> args =
6583
fir::runtime::createArguments(builder, loc, ftype, team, result);
66-
builder.create<fir::CallOp>(loc, funcOp, args);
67-
return builder.create<fir::LoadOp>(loc, result);
84+
fir::CallOp::create(builder, loc, funcOp, args);
85+
return fir::LoadOp::create(builder, loc, result);
6886
}
6987

7088
/// Generate Call to runtime prif_this_image_no_coarray
@@ -78,9 +96,72 @@ mlir::Value fir::runtime::getThisImage(fir::FirOpBuilder &builder,
7896

7997
mlir::Value result = builder.createTemporary(loc, builder.getI32Type());
8098
mlir::Value teamArg =
81-
!team ? builder.create<fir::AbsentOp>(loc, boxTy) : team;
99+
!team ? fir::AbsentOp::create(builder, loc, boxTy) : team;
82100
llvm::SmallVector<mlir::Value> args =
83101
fir::runtime::createArguments(builder, loc, ftype, teamArg, result);
84-
builder.create<fir::CallOp>(loc, funcOp, args);
85-
return builder.create<fir::LoadOp>(loc, result);
102+
fir::CallOp::create(builder, loc, funcOp, args);
103+
return fir::LoadOp::create(builder, loc, result);
104+
}
105+
106+
/// Generate call to collective subroutines except co_reduce
107+
/// A must be lowered as a box
108+
void genCollectiveSubroutine(fir::FirOpBuilder &builder, mlir::Location loc,
109+
mlir::Value A, mlir::Value rootImage,
110+
mlir::Value stat, mlir::Value errmsg,
111+
std::string coName) {
112+
mlir::Type boxTy = fir::BoxType::get(builder.getNoneType());
113+
mlir::FunctionType ftype =
114+
PRIF_FUNCTYPE(boxTy, builder.getRefType(builder.getI32Type()),
115+
PRIF_STAT_TYPE, PRIF_ERRMSG_TYPE, PRIF_ERRMSG_TYPE);
116+
mlir::func::FuncOp funcOp = builder.createFunction(loc, coName, ftype);
117+
118+
auto [errmsgArg, errmsgAllocArg] = genErrmsgPRIF(builder, loc, errmsg);
119+
llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
120+
builder, loc, ftype, A, rootImage, stat, errmsgArg, errmsgAllocArg);
121+
fir::CallOp::create(builder, loc, funcOp, args);
122+
}
123+
124+
/// Generate call to runtime subroutine prif_co_broadcast
125+
void fir::runtime::genCoBroadcast(fir::FirOpBuilder &builder,
126+
mlir::Location loc, mlir::Value A,
127+
mlir::Value sourceImage, mlir::Value stat,
128+
mlir::Value errmsg) {
129+
genCollectiveSubroutine(builder, loc, A, sourceImage, stat, errmsg,
130+
PRIFNAME_SUB("co_broadcast"));
131+
}
132+
133+
/// Generate call to runtime subroutine prif_co_max or prif_co_max_character
134+
void fir::runtime::genCoMax(fir::FirOpBuilder &builder, mlir::Location loc,
135+
mlir::Value A, mlir::Value resultImage,
136+
mlir::Value stat, mlir::Value errmsg) {
137+
mlir::Type argTy =
138+
fir::unwrapSequenceType(fir::unwrapPassByRefType(A.getType()));
139+
if (mlir::isa<fir::CharacterType>(argTy))
140+
genCollectiveSubroutine(builder, loc, A, resultImage, stat, errmsg,
141+
PRIFNAME_SUB("co_max_character"));
142+
else
143+
genCollectiveSubroutine(builder, loc, A, resultImage, stat, errmsg,
144+
PRIFNAME_SUB("co_max"));
145+
}
146+
147+
/// Generate call to runtime subroutine prif_co_min or prif_co_min_character
148+
void fir::runtime::genCoMin(fir::FirOpBuilder &builder, mlir::Location loc,
149+
mlir::Value A, mlir::Value resultImage,
150+
mlir::Value stat, mlir::Value errmsg) {
151+
mlir::Type argTy =
152+
fir::unwrapSequenceType(fir::unwrapPassByRefType(A.getType()));
153+
if (mlir::isa<fir::CharacterType>(argTy))
154+
genCollectiveSubroutine(builder, loc, A, resultImage, stat, errmsg,
155+
PRIFNAME_SUB("co_min_character"));
156+
else
157+
genCollectiveSubroutine(builder, loc, A, resultImage, stat, errmsg,
158+
PRIFNAME_SUB("co_min"));
159+
}
160+
161+
/// Generate call to runtime subroutine prif_co_sum
162+
void fir::runtime::genCoSum(fir::FirOpBuilder &builder, mlir::Location loc,
163+
mlir::Value A, mlir::Value resultImage,
164+
mlir::Value stat, mlir::Value errmsg) {
165+
genCollectiveSubroutine(builder, loc, A, resultImage, stat, errmsg,
166+
PRIFNAME_SUB("co_sum"));
86167
}

0 commit comments

Comments
 (0)