@@ -701,175 +701,105 @@ class MapInfoFinalizationPass
701
701
702
702
auto recordType = mlir::cast<fir::RecordType>(underlyingType);
703
703
llvm::SmallVector<mlir::Value> newMapOpsForFields;
704
- llvm::SmallVector<llvm::SmallVector< int64_t >> newMemberIndexPaths ;
704
+ llvm::SmallVector<int64_t > fieldIndicies ;
705
705
706
- auto appendMemberMap = [&](mlir::Location loc, mlir::Value coordRef,
707
- mlir::Type memTy,
708
- llvm::ArrayRef<int64_t > indexPath,
709
- llvm::StringRef memberName) {
710
- // Check if already mapped (index path equality).
706
+ for (auto fieldMemTyPair : recordType.getTypeList ()) {
707
+ auto &field = fieldMemTyPair.first ;
708
+ auto memTy = fieldMemTyPair.second ;
709
+
710
+ bool shouldMapField =
711
+ llvm::find_if (mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
712
+ if (!fir::isAllocatableType (memTy))
713
+ return false ;
714
+
715
+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
716
+ if (!designateOp)
717
+ return false ;
718
+
719
+ return designateOp.getComponent () &&
720
+ designateOp.getComponent ()->strref () == field;
721
+ }) != mapVarForwardSlice.end ();
722
+
723
+ // TODO Handle recursive record types. Adapting
724
+ // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
725
+ // entities might be helpful here.
726
+
727
+ if (!shouldMapField)
728
+ continue ;
729
+
730
+ int32_t fieldIdx = recordType.getFieldIndex (field);
711
731
bool alreadyMapped = [&]() {
712
732
if (op.getMembersIndexAttr ())
713
733
for (auto indexList : op.getMembersIndexAttr ()) {
714
734
auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
715
- if (indexListAttr.size () != indexPath.size ())
716
- continue ;
717
- bool allEq = true ;
718
- for (auto [i, attr] : llvm::enumerate (indexListAttr)) {
719
- if (mlir::cast<mlir::IntegerAttr>(attr).getInt () !=
720
- indexPath[i]) {
721
- allEq = false ;
722
- break ;
723
- }
724
- }
725
- if (allEq)
735
+ if (indexListAttr.size () == 1 &&
736
+ mlir::cast<mlir::IntegerAttr>(indexListAttr[0 ]).getInt () ==
737
+ fieldIdx)
726
738
return true ;
727
739
}
728
740
729
741
return false ;
730
742
}();
731
743
732
744
if (alreadyMapped)
733
- return ;
745
+ continue ;
734
746
735
747
builder.setInsertionPoint (op);
748
+ fir::IntOrValue idxConst =
749
+ mlir::IntegerAttr::get (builder.getI32Type (), fieldIdx);
750
+ auto fieldCoord = fir::CoordinateOp::create (
751
+ builder, op.getLoc (), builder.getRefType (memTy), op.getVarPtr (),
752
+ llvm::SmallVector<fir::IntOrValue, 1 >{idxConst});
736
753
fir::factory::AddrAndBoundsInfo info =
737
- fir::factory::getDataOperandBaseAddr (builder, coordRef,
738
- /* isOptional=*/ false , loc );
754
+ fir::factory::getDataOperandBaseAddr (
755
+ builder, fieldCoord, /* isOptional=*/ false , op. getLoc () );
739
756
llvm::SmallVector<mlir::Value> bounds =
740
757
fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
741
758
mlir::omp::MapBoundsType>(
742
759
builder, info,
743
- hlfir::translateToExtendedValue (loc , builder,
744
- hlfir::Entity{coordRef })
760
+ hlfir::translateToExtendedValue (op. getLoc () , builder,
761
+ hlfir::Entity{fieldCoord })
745
762
.first ,
746
- /* dataExvIsAssumedSize=*/ false , loc );
763
+ /* dataExvIsAssumedSize=*/ false , op. getLoc () );
747
764
748
765
mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create (
749
- builder, loc, coordRef.getType (), coordRef,
750
- mlir::TypeAttr::get (fir::unwrapRefType (coordRef.getType ())),
766
+ builder, op.getLoc (), fieldCoord.getResult ().getType (),
767
+ fieldCoord.getResult (),
768
+ mlir::TypeAttr::get (
769
+ fir::unwrapRefType (fieldCoord.getResult ().getType ())),
751
770
op.getMapTypeAttr (),
752
771
builder.getAttr <mlir::omp::VariableCaptureKindAttr>(
753
772
mlir::omp::VariableCaptureKind::ByRef),
754
773
/* varPtrPtr=*/ mlir::Value{}, /* members=*/ mlir::ValueRange{},
755
774
/* members_index=*/ mlir::ArrayAttr{}, bounds,
756
775
/* mapperId=*/ mlir::FlatSymbolRefAttr (),
757
- builder.getStringAttr (op.getNameAttr ().strref () + " ." +
758
- memberName + " .implicit_map" ),
776
+ builder.getStringAttr (op.getNameAttr ().strref () + " ." + field +
777
+ " .implicit_map" ),
759
778
/* partial_map=*/ builder.getBoolAttr (false ));
760
779
newMapOpsForFields.emplace_back (fieldMapOp);
761
- newMemberIndexPaths.emplace_back (indexPath.begin (), indexPath.end ());
762
- };
763
-
764
- // 1) Handle direct top-level allocatable fields (existing behavior).
765
- for (auto fieldMemTyPair : recordType.getTypeList ()) {
766
- auto &field = fieldMemTyPair.first ;
767
- auto memTy = fieldMemTyPair.second ;
768
-
769
- if (!fir::isAllocatableType (memTy))
770
- continue ;
771
-
772
- bool referenced = llvm::any_of (mapVarForwardSlice, [&](auto *opv) {
773
- auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(opv);
774
- return designateOp && designateOp.getComponent () &&
775
- designateOp.getComponent ()->strref () == field;
776
- });
777
- if (!referenced)
778
- continue ;
779
-
780
- int32_t fieldIdx = recordType.getFieldIndex (field);
781
- builder.setInsertionPoint (op);
782
- fir::IntOrValue idxConst =
783
- mlir::IntegerAttr::get (builder.getI32Type (), fieldIdx);
784
- auto fieldCoord = fir::CoordinateOp::create (
785
- builder, op.getLoc (), builder.getRefType (memTy), op.getVarPtr (),
786
- llvm::SmallVector<fir::IntOrValue, 1 >{idxConst});
787
- appendMemberMap (op.getLoc (), fieldCoord, memTy, {fieldIdx}, field);
788
- }
789
-
790
- // Handle nested allocatable fields along any component chain
791
- // referenced in the region via HLFIR designates.
792
- for (mlir::Operation *sliceOp : mapVarForwardSlice) {
793
- auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
794
- if (!designateOp || !designateOp.getComponent ())
795
- continue ;
796
- llvm::SmallVector<llvm::StringRef> compPathReversed;
797
- compPathReversed.push_back (designateOp.getComponent ()->strref ());
798
- mlir::Value curBase = designateOp.getMemref ();
799
- bool rootedAtMapArg = false ;
800
- while (true ) {
801
- if (auto parentDes = curBase.getDefiningOp <hlfir::DesignateOp>()) {
802
- if (!parentDes.getComponent ())
803
- break ;
804
- compPathReversed.push_back (parentDes.getComponent ()->strref ());
805
- curBase = parentDes.getMemref ();
806
- continue ;
807
- }
808
- if (auto decl = curBase.getDefiningOp <hlfir::DeclareOp>()) {
809
- if (auto barg =
810
- mlir::dyn_cast<mlir::BlockArgument>(decl.getMemref ()))
811
- rootedAtMapArg = (barg == opBlockArg);
812
- } else if (auto blockArg =
813
- mlir::dyn_cast_or_null<mlir::BlockArgument>(
814
- curBase)) {
815
- rootedAtMapArg = (blockArg == opBlockArg);
816
- }
817
- break ;
818
- }
819
- if (!rootedAtMapArg || compPathReversed.size () < 2 )
820
- continue ;
821
- builder.setInsertionPoint (op);
822
- llvm::SmallVector<int64_t > indexPath;
823
- mlir::Type curTy = underlyingType;
824
- mlir::Value coordRef = op.getVarPtr ();
825
- bool validPath = true ;
826
- for (llvm::StringRef compName : llvm::reverse (compPathReversed)) {
827
- auto recTy = mlir::dyn_cast<fir::RecordType>(curTy);
828
- if (!recTy) {
829
- validPath = false ;
830
- break ;
831
- }
832
- int32_t idx = recTy.getFieldIndex (compName);
833
- if (idx < 0 ) {
834
- validPath = false ;
835
- break ;
836
- }
837
- indexPath.push_back (idx);
838
- mlir::Type memTy = recTy.getType (idx);
839
- fir::IntOrValue idxConst =
840
- mlir::IntegerAttr::get (builder.getI32Type (), idx);
841
- coordRef = fir::CoordinateOp::create (
842
- builder, op.getLoc (), builder.getRefType (memTy), coordRef,
843
- llvm::SmallVector<fir::IntOrValue, 1 >{idxConst});
844
- curTy = memTy;
845
- }
846
- if (!validPath)
847
- continue ;
848
- if (auto finalRefTy =
849
- mlir::dyn_cast<fir::ReferenceType>(coordRef.getType ())) {
850
- mlir::Type eleTy = finalRefTy.getElementType ();
851
- if (fir::isAllocatableType (eleTy))
852
- appendMemberMap (op.getLoc (), coordRef, eleTy, indexPath,
853
- compPathReversed.front ());
854
- }
780
+ fieldIndicies.emplace_back (fieldIdx);
855
781
}
856
782
857
783
if (newMapOpsForFields.empty ())
858
784
return mlir::WalkResult::advance ();
859
785
860
786
op.getMembersMutable ().append (newMapOpsForFields);
861
787
llvm::SmallVector<llvm::SmallVector<int64_t >> newMemberIndices;
862
- if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr ())
863
- for (mlir::Attribute indexList : oldAttr) {
788
+ mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr ();
789
+
790
+ if (oldMembersIdxAttr)
791
+ for (mlir::Attribute indexList : oldMembersIdxAttr) {
864
792
llvm::SmallVector<int64_t > listVec;
865
793
866
794
for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
867
795
listVec.push_back (mlir::cast<mlir::IntegerAttr>(index).getInt ());
868
796
869
797
newMemberIndices.emplace_back (std::move (listVec));
870
798
}
871
- for (auto &path : newMemberIndexPaths)
872
- newMemberIndices.emplace_back (path);
799
+
800
+ for (int64_t newFieldIdx : fieldIndicies)
801
+ newMemberIndices.emplace_back (
802
+ llvm::SmallVector<int64_t >(1 , newFieldIdx));
873
803
874
804
op.setMembersIndexAttr (builder.create2DI64ArrayAttr (newMemberIndices));
875
805
op.setPartialMap (true );
0 commit comments