Skip to content

Commit 3e7e60a

Browse files
authored
Revert "[Flang][OpenMP] Implicitly map nested allocatable components in derived types" (#160759)
Reverts #160116
1 parent 04258fe commit 3e7e60a

File tree

3 files changed

+55
-206
lines changed

3 files changed

+55
-206
lines changed

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 55 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -701,175 +701,105 @@ class MapInfoFinalizationPass
701701

702702
auto recordType = mlir::cast<fir::RecordType>(underlyingType);
703703
llvm::SmallVector<mlir::Value> newMapOpsForFields;
704-
llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndexPaths;
704+
llvm::SmallVector<int64_t> fieldIndicies;
705705

706-
auto appendMemberMap = [&](mlir::Location loc, mlir::Value coordRef,
707-
mlir::Type memTy,
708-
llvm::ArrayRef<int64_t> indexPath,
709-
llvm::StringRef memberName) {
710-
// Check if already mapped (index path equality).
706+
for (auto fieldMemTyPair : recordType.getTypeList()) {
707+
auto &field = fieldMemTyPair.first;
708+
auto memTy = fieldMemTyPair.second;
709+
710+
bool shouldMapField =
711+
llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
712+
if (!fir::isAllocatableType(memTy))
713+
return false;
714+
715+
auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
716+
if (!designateOp)
717+
return false;
718+
719+
return designateOp.getComponent() &&
720+
designateOp.getComponent()->strref() == field;
721+
}) != mapVarForwardSlice.end();
722+
723+
// TODO Handle recursive record types. Adapting
724+
// `createParentSymAndGenIntermediateMaps` to work directly on MLIR
725+
// entities might be helpful here.
726+
727+
if (!shouldMapField)
728+
continue;
729+
730+
int32_t fieldIdx = recordType.getFieldIndex(field);
711731
bool alreadyMapped = [&]() {
712732
if (op.getMembersIndexAttr())
713733
for (auto indexList : op.getMembersIndexAttr()) {
714734
auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
715-
if (indexListAttr.size() != indexPath.size())
716-
continue;
717-
bool allEq = true;
718-
for (auto [i, attr] : llvm::enumerate(indexListAttr)) {
719-
if (mlir::cast<mlir::IntegerAttr>(attr).getInt() !=
720-
indexPath[i]) {
721-
allEq = false;
722-
break;
723-
}
724-
}
725-
if (allEq)
735+
if (indexListAttr.size() == 1 &&
736+
mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
737+
fieldIdx)
726738
return true;
727739
}
728740

729741
return false;
730742
}();
731743

732744
if (alreadyMapped)
733-
return;
745+
continue;
734746

735747
builder.setInsertionPoint(op);
748+
fir::IntOrValue idxConst =
749+
mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
750+
auto fieldCoord = fir::CoordinateOp::create(
751+
builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
752+
llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
736753
fir::factory::AddrAndBoundsInfo info =
737-
fir::factory::getDataOperandBaseAddr(builder, coordRef,
738-
/*isOptional=*/false, loc);
754+
fir::factory::getDataOperandBaseAddr(
755+
builder, fieldCoord, /*isOptional=*/false, op.getLoc());
739756
llvm::SmallVector<mlir::Value> bounds =
740757
fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
741758
mlir::omp::MapBoundsType>(
742759
builder, info,
743-
hlfir::translateToExtendedValue(loc, builder,
744-
hlfir::Entity{coordRef})
760+
hlfir::translateToExtendedValue(op.getLoc(), builder,
761+
hlfir::Entity{fieldCoord})
745762
.first,
746-
/*dataExvIsAssumedSize=*/false, loc);
763+
/*dataExvIsAssumedSize=*/false, op.getLoc());
747764

748765
mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create(
749-
builder, loc, coordRef.getType(), coordRef,
750-
mlir::TypeAttr::get(fir::unwrapRefType(coordRef.getType())),
766+
builder, op.getLoc(), fieldCoord.getResult().getType(),
767+
fieldCoord.getResult(),
768+
mlir::TypeAttr::get(
769+
fir::unwrapRefType(fieldCoord.getResult().getType())),
751770
op.getMapTypeAttr(),
752771
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
753772
mlir::omp::VariableCaptureKind::ByRef),
754773
/*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{},
755774
/*members_index=*/mlir::ArrayAttr{}, bounds,
756775
/*mapperId=*/mlir::FlatSymbolRefAttr(),
757-
builder.getStringAttr(op.getNameAttr().strref() + "." +
758-
memberName + ".implicit_map"),
776+
builder.getStringAttr(op.getNameAttr().strref() + "." + field +
777+
".implicit_map"),
759778
/*partial_map=*/builder.getBoolAttr(false));
760779
newMapOpsForFields.emplace_back(fieldMapOp);
761-
newMemberIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
762-
};
763-
764-
// 1) Handle direct top-level allocatable fields (existing behavior).
765-
for (auto fieldMemTyPair : recordType.getTypeList()) {
766-
auto &field = fieldMemTyPair.first;
767-
auto memTy = fieldMemTyPair.second;
768-
769-
if (!fir::isAllocatableType(memTy))
770-
continue;
771-
772-
bool referenced = llvm::any_of(mapVarForwardSlice, [&](auto *opv) {
773-
auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(opv);
774-
return designateOp && designateOp.getComponent() &&
775-
designateOp.getComponent()->strref() == field;
776-
});
777-
if (!referenced)
778-
continue;
779-
780-
int32_t fieldIdx = recordType.getFieldIndex(field);
781-
builder.setInsertionPoint(op);
782-
fir::IntOrValue idxConst =
783-
mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
784-
auto fieldCoord = fir::CoordinateOp::create(
785-
builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
786-
llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
787-
appendMemberMap(op.getLoc(), fieldCoord, memTy, {fieldIdx}, field);
788-
}
789-
790-
// Handle nested allocatable fields along any component chain
791-
// referenced in the region via HLFIR designates.
792-
for (mlir::Operation *sliceOp : mapVarForwardSlice) {
793-
auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
794-
if (!designateOp || !designateOp.getComponent())
795-
continue;
796-
llvm::SmallVector<llvm::StringRef> compPathReversed;
797-
compPathReversed.push_back(designateOp.getComponent()->strref());
798-
mlir::Value curBase = designateOp.getMemref();
799-
bool rootedAtMapArg = false;
800-
while (true) {
801-
if (auto parentDes = curBase.getDefiningOp<hlfir::DesignateOp>()) {
802-
if (!parentDes.getComponent())
803-
break;
804-
compPathReversed.push_back(parentDes.getComponent()->strref());
805-
curBase = parentDes.getMemref();
806-
continue;
807-
}
808-
if (auto decl = curBase.getDefiningOp<hlfir::DeclareOp>()) {
809-
if (auto barg =
810-
mlir::dyn_cast<mlir::BlockArgument>(decl.getMemref()))
811-
rootedAtMapArg = (barg == opBlockArg);
812-
} else if (auto blockArg =
813-
mlir::dyn_cast_or_null<mlir::BlockArgument>(
814-
curBase)) {
815-
rootedAtMapArg = (blockArg == opBlockArg);
816-
}
817-
break;
818-
}
819-
if (!rootedAtMapArg || compPathReversed.size() < 2)
820-
continue;
821-
builder.setInsertionPoint(op);
822-
llvm::SmallVector<int64_t> indexPath;
823-
mlir::Type curTy = underlyingType;
824-
mlir::Value coordRef = op.getVarPtr();
825-
bool validPath = true;
826-
for (llvm::StringRef compName : llvm::reverse(compPathReversed)) {
827-
auto recTy = mlir::dyn_cast<fir::RecordType>(curTy);
828-
if (!recTy) {
829-
validPath = false;
830-
break;
831-
}
832-
int32_t idx = recTy.getFieldIndex(compName);
833-
if (idx < 0) {
834-
validPath = false;
835-
break;
836-
}
837-
indexPath.push_back(idx);
838-
mlir::Type memTy = recTy.getType(idx);
839-
fir::IntOrValue idxConst =
840-
mlir::IntegerAttr::get(builder.getI32Type(), idx);
841-
coordRef = fir::CoordinateOp::create(
842-
builder, op.getLoc(), builder.getRefType(memTy), coordRef,
843-
llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
844-
curTy = memTy;
845-
}
846-
if (!validPath)
847-
continue;
848-
if (auto finalRefTy =
849-
mlir::dyn_cast<fir::ReferenceType>(coordRef.getType())) {
850-
mlir::Type eleTy = finalRefTy.getElementType();
851-
if (fir::isAllocatableType(eleTy))
852-
appendMemberMap(op.getLoc(), coordRef, eleTy, indexPath,
853-
compPathReversed.front());
854-
}
780+
fieldIndicies.emplace_back(fieldIdx);
855781
}
856782

857783
if (newMapOpsForFields.empty())
858784
return mlir::WalkResult::advance();
859785

860786
op.getMembersMutable().append(newMapOpsForFields);
861787
llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
862-
if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr())
863-
for (mlir::Attribute indexList : oldAttr) {
788+
mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
789+
790+
if (oldMembersIdxAttr)
791+
for (mlir::Attribute indexList : oldMembersIdxAttr) {
864792
llvm::SmallVector<int64_t> listVec;
865793

866794
for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
867795
listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());
868796

869797
newMemberIndices.emplace_back(std::move(listVec));
870798
}
871-
for (auto &path : newMemberIndexPaths)
872-
newMemberIndices.emplace_back(path);
799+
800+
for (int64_t newFieldIdx : fieldIndicies)
801+
newMemberIndices.emplace_back(
802+
llvm::SmallVector<int64_t>(1, newFieldIdx));
873803

874804
op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
875805
op.setPartialMap(true);

flang/test/Lower/OpenMP/declare-mapper.f90

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-3.f90 -o - | FileCheck %t/omp-declare-mapper-3.f90
77
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-4.f90 -o - | FileCheck %t/omp-declare-mapper-4.f90
88
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-5.f90 -o - | FileCheck %t/omp-declare-mapper-5.f90
9-
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 %t/omp-declare-mapper-6.f90 -o - | FileCheck %t/omp-declare-mapper-6.f90
109

1110
!--- omp-declare-mapper-1.f90
1211
subroutine declare_mapper_1
@@ -263,40 +262,3 @@ subroutine use_inner()
263262
!$omp end target
264263
end subroutine
265264
end program declare_mapper_5
266-
267-
!--- omp-declare-mapper-6.f90
268-
subroutine declare_mapper_nested_parent
269-
type :: inner_t
270-
real, allocatable :: deep_arr(:)
271-
end type inner_t
272-
273-
type, abstract :: base_t
274-
real, allocatable :: base_arr(:)
275-
type(inner_t) :: inner
276-
end type base_t
277-
278-
type, extends(base_t) :: real_t
279-
real, allocatable :: real_arr(:)
280-
end type real_t
281-
282-
!$omp declare mapper (custommapper : real_t :: t) map(tofrom: t%base_arr, t%real_arr)
283-
284-
type(real_t) :: r
285-
286-
allocate(r%base_arr(10))
287-
allocate(r%inner%deep_arr(10))
288-
allocate(r%real_arr(10))
289-
r%base_arr = 1.0
290-
r%inner%deep_arr = 4.0
291-
r%real_arr = 0.0
292-
293-
! CHECK: omp.target
294-
! Check implicit maps for nested parent and deep nested allocatable payloads
295-
! CHECK-DAG: omp.map.info {{.*}} {name = "r.base_arr.implicit_map"}
296-
! CHECK-DAG: omp.map.info {{.*}} {name = "r.deep_arr.implicit_map"}
297-
! The declared mapper's own allocatable is still mapped implicitly
298-
! CHECK-DAG: omp.map.info {{.*}} {name = "r.real_arr.implicit_map"}
299-
!$omp target map(mapper(custommapper), tofrom: r)
300-
r%real_arr = r%base_arr(1) + r%inner%deep_arr(1)
301-
!$omp end target
302-
end subroutine declare_mapper_nested_parent

offload/test/offloading/fortran/target-declare-mapper-parent-allocatable.f90

Lines changed: 0 additions & 43 deletions
This file was deleted.

0 commit comments

Comments
 (0)