fabianmcg
Contributor

Introduces a dataflow analysis for tracking offset, size, and stride ranges of operations.
Inference of the metadata is accomplished through the implementation of the interface
InferStridedMetadataOpInterface.

To keep the size of the patch small, this patch only implements the interface for the
memref.subview operation. It's future work to add more operations.

Example:

func.func @memref_subview(%arg0: memref<8x16x4xf32, strided<[64, 4, 1]>>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %0 = test.with_bounds {smax = 13 : index, smin = 11 : index, umax = 13 : index, umin = 11 : index} : index
  %1 = test.with_bounds {smax = 7 : index, smin = 5 : index, umax = 7 : index, umin = 5 : index} : index
  %subview = memref.subview %arg0[%c0, %c0, %c1] [%1, %0, %c2] [%c1, %c1, %c1] : memref<8x16x4xf32, strided<[64, 4, 1]>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
  return
}

Applying mlir-opt --test-strided-metadata-range-analysis prints:

Op: %subview = memref.subview %arg0[%c0, %c0, %c1] [%1, %0, %c2] [%c1, %c1, %c1] : memref<8x16x4xf32, strided<[64, 4, 1]>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
  result[0]: strided_metadata<offset = [{unsigned : [1, 1] signed : [1, 1]}], sizes = [{unsigned : [5, 7] signed : [5, 7]}, {unsigned : [11, 13] signed : [11, 13]}, {unsigned : [2, 2] signed : [2, 2]}], strides = [{unsigned : [64, 64] signed : [64, 64]}, {unsigned : [4, 4] signed : [4, 4]}, {unsigned : [1, 1] signed : [1, 1]}]>
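
For intuition, the ranges printed above can be reproduced by hand with interval arithmetic. The following is a minimal standalone C++ sketch, not the code in this patch. It assumes the usual memref.subview semantics (result offset = source offset + sum of offset_i * source stride_i, result stride_i = source stride_i * subview stride_i, result sizes = subview sizes) and non-negative bounds; `Range`, `Metadata`, and `subview` are hypothetical names introduced only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical closed interval [lo, hi] of index values.
struct Range {
  int64_t lo;
  int64_t hi;
};

// Interval sum and product. The product is only valid for non-negative
// bounds, which is an assumption of this sketch.
Range add(Range a, Range b) { return {a.lo + b.lo, a.hi + b.hi}; }
Range mul(Range a, Range b) { return {a.lo * b.lo, a.hi * b.hi}; }

// Hypothetical container: one offset, plus one size and stride per dimension.
struct Metadata {
  Range offset;
  std::vector<Range> sizes;
  std::vector<Range> strides;
};

// Assumed subview semantics:
//   offset'     = offset + sum_i(subOffsets[i] * strides[i])
//   sizes'[i]   = subSizes[i]
//   strides'[i] = strides[i] * subStrides[i]
Metadata subview(const Metadata &src, const std::vector<Range> &subOffsets,
                 const std::vector<Range> &subSizes,
                 const std::vector<Range> &subStrides) {
  assert(subOffsets.size() == src.strides.size() &&
         subSizes.size() == src.strides.size() &&
         subStrides.size() == src.strides.size() && "rank mismatch");
  Metadata result{src.offset, subSizes, {}};
  for (std::size_t i = 0; i < src.strides.size(); ++i) {
    result.offset = add(result.offset, mul(subOffsets[i], src.strides[i]));
    result.strides.push_back(mul(src.strides[i], subStrides[i]));
  }
  return result;
}
```

Feeding in the example (source offset 0, strides [64, 4, 1], subview offsets [0, 0, 1], sizes [5..7, 11..13, 2], strides [1, 1, 1]) gives offset 0 + 0*64 + 0*4 + 1*1 = 1 and strides [64, 4, 1], which agrees with the printed result.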

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:memref labels Sep 29, 2025
@llvmbot
Member

llvmbot commented Sep 29, 2025

@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Fabian Mora (fabianmcg)

Changes

Patch is 38.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/161280.diff

17 Files Affected:

  • (added) mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h (+54)
  • (added) mlir/include/mlir/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.h (+25)
  • (modified) mlir/include/mlir/Interfaces/CMakeLists.txt (+1)
  • (modified) mlir/include/mlir/Interfaces/InferIntRangeInterface.h (+2-1)
  • (added) mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h (+148)
  • (added) mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td (+43)
  • (modified) mlir/lib/Analysis/CMakeLists.txt (+1)
  • (added) mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp (+127)
  • (modified) mlir/lib/Dialect/MemRef/IR/CMakeLists.txt (+3-1)
  • (added) mlir/lib/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.cpp (+118)
  • (modified) mlir/lib/Interfaces/CMakeLists.txt (+2)
  • (added) mlir/lib/Interfaces/InferStridedMetadataInterface.cpp (+36)
  • (modified) mlir/lib/RegisterAllDialects.cpp (+2)
  • (added) mlir/test/Analysis/DataFlow/test-strided-metadata-range-analysis.mlir (+67)
  • (modified) mlir/test/lib/Analysis/CMakeLists.txt (+1)
  • (added) mlir/test/lib/Analysis/DataFlow/TestStridedMetadataRangeAnalysis.cpp (+86)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
diff --git a/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h
new file mode 100644
index 0000000000000..72ac2477435db
--- /dev/null
+++ b/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h
@@ -0,0 +1,54 @@
+//===- StridedMetadataRange.h - Strided metadata range analysis -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
+#define MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
+
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Interfaces/InferStridedMetadataInterface.h"
+
+namespace mlir {
+namespace dataflow {
+
+/// This lattice element represents the strided metadata of an SSA value.
+class StridedMetadataRangeLattice : public Lattice<StridedMetadataRange> {
+public:
+  using Lattice::Lattice;
+};
+
+/// Strided metadata range analysis determines the strided metadata ranges of
+/// SSA values using operations that define `InferStridedMetadataInterface`.
+///
+/// This analysis depends on DeadCodeAnalysis, SparseConstantPropagation, and
+/// IntegerRangeAnalysis, and will be a silent no-op if the analyses are not
+/// loaded in the same solver context.
+class StridedMetadataRangeAnalysis
+    : public SparseForwardDataFlowAnalysis<StridedMetadataRangeLattice> {
+public:
+  StridedMetadataRangeAnalysis(DataFlowSolver &solver,
+                               int32_t indexBitwidth = 64);
+
+  /// At an entry point, we cannot reason about strided metadata ranges unless
+  /// the type also encodes the data. For example, a memref with static layout.
+  void setToEntryState(StridedMetadataRangeLattice *lattice) override;
+
+  /// Visit an operation. Invoke the transfer function on each operation that
+  /// implements `InferStridedMetadataInterface`.
+  LogicalResult
+  visitOperation(Operation *op,
+                 ArrayRef<const StridedMetadataRangeLattice *> operands,
+                 ArrayRef<StridedMetadataRangeLattice *> results) override;
+
+private:
+  /// Index bitwidth to use when operating with the int-ranges.
+  int32_t indexBitwidth = 64;
+};
+} // namespace dataflow
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.h
new file mode 100644
index 0000000000000..ca3bc78648ab2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.h
@@ -0,0 +1,25 @@
+//===- InferStridedMetadataOpInterfaceImpl.h - Impl. of infer strided md --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_IR_INFERSTRIDEDMETADATAOPINTERFACEIMPL_H
+#define MLIR_DIALECT_MEMREF_IR_INFERSTRIDEDMETADATAOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+/// Register the external models for the infer strided metadata op interface,
+/// for the `memref` dialect. This implementation assumes that the strided
+/// metadata of a ranked memref consists of one offset, and zero or more sizes
+/// and strides.
+void registerInferStridedMetadataOpInterfaceExternalModels(
+    DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_IR_INFERSTRIDEDMETADATAOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index a5feb592045c0..72ed046a1ba5d 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_interface(DestinationStyleOpInterface)
 add_mlir_interface(FunctionInterfaces)
 add_mlir_interface(IndexingMapOpInterface)
 add_mlir_interface(InferIntRangeInterface)
+add_mlir_interface(InferStridedMetadataInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
 add_mlir_interface(MemOpInterfaces)
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 0e107e88f5232..a9e3e82acdc4f 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -117,7 +117,8 @@ class IntegerValueRange {
   IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
 
   /// Create an integer value range lattice value.
-  IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
+  explicit IntegerValueRange(
+      std::optional<ConstantIntRanges> value = std::nullopt)
       : value(std::move(value)) {}
 
   /// Whether the range is uninitialized. This happens when the state hasn't
diff --git a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h
new file mode 100644
index 0000000000000..8d633daba6f1b
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h
@@ -0,0 +1,148 @@
+//===- InferStridedMetadataInterface.h - Strided Metadata Inference -C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of the strided metadata inference interface
+// defined in `InferStridedMetadataInterface.td`
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
+#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+
+namespace mlir {
+/// A class that represents the strided metadata range information, including
+/// offsets, sizes, and strides as integer ranges.
+class StridedMetadataRange {
+public:
+  /// Default constructor creates uninitialized ranges.
+  StridedMetadataRange() = default;
+
+  /// Returns a ranked strided metadata range.
+  static StridedMetadataRange
+  getRanked(SmallVectorImpl<ConstantIntRanges> &&offsets,
+            SmallVectorImpl<ConstantIntRanges> &&sizes,
+            SmallVectorImpl<ConstantIntRanges> &&strides) {
+    return StridedMetadataRange(std::move(offsets), std::move(sizes),
+                                std::move(strides));
+  }
+
+  /// Returns a strided metadata range with maximum ranges.
+  static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
+                                           int32_t offsetsRank,
+                                           int32_t sizeRank,
+                                           int32_t stridedRank) {
+    return StridedMetadataRange(
+        SmallVector<ConstantIntRanges>(
+            offsetsRank, ConstantIntRanges::maxRange(indexBitwidth)),
+        SmallVector<ConstantIntRanges>(
+            sizeRank, ConstantIntRanges::maxRange(indexBitwidth)),
+        SmallVector<ConstantIntRanges>(
+            stridedRank, ConstantIntRanges::maxRange(indexBitwidth)));
+  }
+
+  static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
+                                           int32_t rank) {
+    return getMaxRanges(indexBitwidth, 1, rank, rank);
+  }
+
+  /// Returns whether the metadata is uninitialized.
+  bool isUninitialized() const { return !offsets.has_value(); }
+
+  /// Get the offsets range.
+  ArrayRef<ConstantIntRanges> getOffsets() const {
+    return offsets ? *offsets : ArrayRef<ConstantIntRanges>();
+  }
+  MutableArrayRef<ConstantIntRanges> getOffsets() {
+    return offsets ? *offsets : MutableArrayRef<ConstantIntRanges>();
+  }
+
+  /// Get the sizes ranges.
+  ArrayRef<ConstantIntRanges> getSizes() const { return sizes; }
+  MutableArrayRef<ConstantIntRanges> getSizes() { return sizes; }
+
+  /// Get the strides ranges.
+  ArrayRef<ConstantIntRanges> getStrides() const { return strides; }
+  MutableArrayRef<ConstantIntRanges> getStrides() { return strides; }
+
+  /// Compare two strided metadata ranges.
+  bool operator==(const StridedMetadataRange &other) const {
+    return offsets == other.offsets && sizes == other.sizes &&
+           strides == other.strides;
+  }
+
+  /// Print the strided metadata range.
+  void print(raw_ostream &os) const;
+
+  /// Join two strided metadata ranges, by taking the element-wise union of the
+  /// metadata.
+  static StridedMetadataRange join(const StridedMetadataRange &lhs,
+                                   const StridedMetadataRange &rhs) {
+    if (lhs.isUninitialized())
+      return rhs;
+    if (rhs.isUninitialized())
+      return lhs;
+
+    // Helper fuction to compute the range union of constant ranges.
+    auto rangeUnion =
+        +[](const std::tuple<ConstantIntRanges, ConstantIntRanges> &lhsRhs)
+        -> ConstantIntRanges {
+      return std::get<0>(lhsRhs).rangeUnion(std::get<1>(lhsRhs));
+    };
+
+    // Get the elementwise range union. Note, that `zip_equal` will assert if
+    // sizes are not equal.
+    SmallVector<ConstantIntRanges> offsets = llvm::map_to_vector(
+        llvm::zip_equal(*lhs.offsets, *rhs.offsets), rangeUnion);
+    SmallVector<ConstantIntRanges> sizes =
+        llvm::map_to_vector(llvm::zip_equal(lhs.sizes, rhs.sizes), rangeUnion);
+    SmallVector<ConstantIntRanges> strides = llvm::map_to_vector(
+        llvm::zip_equal(lhs.strides, rhs.strides), rangeUnion);
+
+    // Return the joined metadata.
+    return StridedMetadataRange(std::move(offsets), std::move(sizes),
+                                std::move(strides));
+  }
+
+private:
+  /// Create a strided metadata range with the given offset, sizes, and strides.
+  StridedMetadataRange(SmallVectorImpl<ConstantIntRanges> &&offsets,
+                       SmallVectorImpl<ConstantIntRanges> &&sizes,
+                       SmallVectorImpl<ConstantIntRanges> &&strides)
+      : offsets(std::move(offsets)), sizes(std::move(sizes)),
+        strides(std::move(strides)) {}
+
+  /// The offsets range.
+  std::optional<SmallVector<ConstantIntRanges>> offsets;
+
+  /// The sizes ranges.
+  SmallVector<ConstantIntRanges> sizes;
+
+  /// The strides ranges.
+  SmallVector<ConstantIntRanges> strides;
+};
+
+/// Print the strided metadata to `os`.
+inline raw_ostream &operator<<(raw_ostream &os,
+                               const StridedMetadataRange &range) {
+  range.print(os);
+  return os;
+}
+
+/// Callback function type to get the integer range of a value.
+using GetIntRangeFn = function_ref<IntegerValueRange(Value)>;
+
+/// Callback function type for setting the strided metadata of a value.
+using SetStridedMetadataRangeFn =
+    function_ref<void(Value, const StridedMetadataRange &)>;
+} // end namespace mlir
+
+#include "mlir/Interfaces/InferStridedMetadataInterface.h.inc"
+
+#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
\ No newline at end of file
diff --git a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td
new file mode 100644
index 0000000000000..892b44d0aaa65
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td
@@ -0,0 +1,43 @@
+//===- InferStridedMetadataInterface.td - Strided MD Inference ----------*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface for strided metadata range analysis
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
+#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def InferStridedMetadataOpInterface :
+    OpInterface<"InferStridedMetadataOpInterface"> {
+  let description = [{
+    Allows operations to participate in strided metadata analysis by providing
+    methods that allow them to specify bounds on offsets, sizes, and strides
+    of their result(s) given bounds on their input(s) if known.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+      Infer the strided metadata bounds on the results of this op given
+      the bounds on its operands.
+      For each result value or block argument, the method should call
+      `setMetadata` with that `Value` as an argument.
+      The `getIntRange` callback is provided for obtaining the int-range
+      analysis result for a given value.
+    }],
+    "void", "inferStridedMetadataRanges",
+    (ins "::llvm::ArrayRef<::mlir::StridedMetadataRange>":$operands,
+         "::mlir::GetIntRangeFn":$getIntRange,
+         "::mlir::SetStridedMetadataRangeFn":$setMetadata,
+         "int32_t":$indexBitwidth)>
+  ];
+}
+#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
\ No newline at end of file
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 609cb34309829..bef189600d8e7 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -40,6 +40,7 @@ add_mlir_library(MLIRAnalysis
   DataFlow/IntegerRangeAnalysis.cpp
   DataFlow/LivenessAnalysis.cpp
   DataFlow/SparseAnalysis.cpp
+  DataFlow/StridedMetadataRangeAnalysis.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Analysis
diff --git a/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
new file mode 100644
index 0000000000000..01c9dafaddf10
--- /dev/null
+++ b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
@@ -0,0 +1,127 @@
+//===- StridedMetadataRangeAnalysis.cpp - Integer range analysis --------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dataflow analysis class for integer range inference
+// which is used in transformations over the `arith` dialect such as
+// branch elimination or signed->unsigned rewriting
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/DebugStringHelper.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "strided-metadata-range-analysis"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+/// Get the entry state for a value. For any value that is not a ranked memref,
+/// this function sets the metadata to a top state with no offsets, sizes, or
+/// strides. For `memref` types, this function will use the metadata in the type
+/// to try to deduce as much informaiton as possible.
+static StridedMetadataRange getEntryStateImpl(Value v, int32_t indexBitwidth) {
+  // TODO: generalize this method with a type interface.
+  auto mTy = dyn_cast<BaseMemRefType>(v.getType());
+
+  // If not a memref or it's un-ranked, don't infer any metadata.
+  if (!mTy || !mTy.hasRank())
+    return StridedMetadataRange::getMaxRanges(indexBitwidth, 0, 0, 0);
+
+  // Get the top state.
+  auto metadata =
+      StridedMetadataRange::getMaxRanges(indexBitwidth, mTy.getRank());
+
+  // Compute the offset and strides.
+  int64_t offset;
+  SmallVector<int64_t> strides;
+  if (failed(cast<MemRefType>(mTy).getStridesAndOffset(strides, offset)))
+    return metadata;
+
+  // Refine the metadata if we know it from the type.
+  if (!ShapedType::isDynamic(offset)) {
+    metadata.getOffsets()[0] =
+        ConstantIntRanges::constant(APInt(indexBitwidth, offset));
+  }
+  for (auto &&[size, range] :
+       llvm::zip_equal(mTy.getShape(), metadata.getSizes())) {
+    if (ShapedType::isDynamic(size))
+      continue;
+    range = ConstantIntRanges::constant(APInt(indexBitwidth, size));
+  }
+  for (auto &&[stride, range] :
+       llvm::zip_equal(strides, metadata.getStrides())) {
+    if (ShapedType::isDynamic(stride))
+      continue;
+    range = ConstantIntRanges::constant(APInt(indexBitwidth, stride));
+  }
+
+  return metadata;
+}
+
+StridedMetadataRangeAnalysis::StridedMetadataRangeAnalysis(
+    DataFlowSolver &solver, int32_t indexBitwidth)
+    : SparseForwardDataFlowAnalysis(solver), indexBitwidth(indexBitwidth) {
+  assert(indexBitwidth > 0 && "invalid bitwidth");
+}
+
+void StridedMetadataRangeAnalysis::setToEntryState(
+    StridedMetadataRangeLattice *lattice) {
+  propagateIfChanged(lattice, lattice->join(getEntryStateImpl(
+                                  lattice->getAnchor(), indexBitwidth)));
+}
+
+LogicalResult StridedMetadataRangeAnalysis::visitOperation(
+    Operation *op, ArrayRef<const StridedMetadataRangeLattice *> operands,
+    ArrayRef<StridedMetadataRangeLattice *> results) {
+  auto inferrable = dyn_cast<InferStridedMetadataOpInterface>(op);
+
+  // Bail if we cannot reason about the op.
+  if (!inferrable) {
+    setAllToEntryStates(results);
+    return success();
+  }
+
+  LDBG() << "Inferring metadata for: "
+         << OpWithFlags(op, OpPrintingFlags().skipRegions());
+
+  // Helper function to retrieve int range values.
+  auto getIntRange = [&](Value value) -> IntegerValueRange {
+    auto lattice = getOrCreateFor<IntegerValueRangeLattice>(
+        getProgramPointAfter(op), value);
+    return lattice ? lattice->getValue() : IntegerValueRange();
+  };
+
+  // Convert the arguments lattices to a vector.
+  SmallVector<StridedMetadataRange> argRanges = llvm::map_to_vector(
+      operands, [](const StridedMetadataRangeLattice *lattice) {
+        return lattice->getValue();
+      });
+
+  // Callback to set metadata on a result.
+  auto joinCallback = [&](Value v, const StridedMetadataRange &md) {
+    auto result = cast<OpResult>(v);
+    assert(llvm::is_contained(op->getResults(), result));
+    LDBG() << "- Inferred metadata: " << md;
+    StridedMetadataRangeLattice *lattice = results[result.getResultNumber()];
+    ChangeResult changed = lattice->join(md);
+    LDBG() << "- Joined metadata: " << lattice->getValue();
+    propagateIfChanged(lattice, changed);
+  };
+
+  // Infer the metadata.
+  inferrable.inferStridedMetadataRanges(argRanges, getIntRange, joinCallback,
+                                        indexBitwidth);
+  return success();
+}
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index e25a0121a3359..9707dc0cc64e9 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -2,10 +2,11 @@ add_mlir_dialect_library(MLIRMemRefDialect
   MemRefDialect.cpp
   MemRefMemorySlot.cpp
   MemRefOps.cpp
+  InferStridedMetadataInterfaceImpl.cpp
   ValueBoundsOpInterfaceImpl.cpp
 
   ADDITIONAL_HEADER_DIRS
-  ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect
+  ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRef/IR
 
   DEPENDS
   MLIRMemRefOpsIncGen
@@ -18,6 +19,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
   MLIRDialectUtils
   MLIRInferIntRangeCommon
   MLIRInferIntRangeInterface
+  MLIRInferStridedMetadataInterface
   MLIRInferTypeOpInterface
   MLIRIR
   MLIRMemOpInterfaces
diff --git a/mlir/lib/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.cpp
new file mode 100644
index 0000000000000..4bc4edc0357e8
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/IR/InferStridedMetadataIn...
[truncated]

@fabianmcg fabianmcg force-pushed the users/fabianmcg/strided-metadata branch from a3dfb47 to 37ec2ca Compare September 29, 2025 21:28
@fabianmcg fabianmcg requested a review from Copilot September 29, 2025 21:29
Contributor

@Copilot Copilot AI left a comment

Pull Request Overview

This PR introduces a new dataflow analysis for tracking strided metadata (offset, size, stride) ranges in MLIR operations. The analysis uses the integer range analysis to infer bounds on memref metadata parameters. Currently implements the interface only for memref.subview operations as a starting point, with future work planned for additional operations.

Key changes:

  • New InferStridedMetadataOpInterface interface for operations to participate in strided metadata analysis
  • StridedMetadataRangeAnalysis dataflow analysis that tracks offset, size, and stride ranges
  • Implementation of the interface for memref.subview operation in the MemRef dialect

Reviewed Changes

Copilot reviewed 17 out of 17 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td | Interface definition for strided metadata inference |
| mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h | Header with StridedMetadataRange class and interface declarations |
| mlir/lib/Interfaces/InferStridedMetadataInterface.cpp | Implementation of StridedMetadataRange print method |
| mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h | Header for the dataflow analysis class |
| mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp | Main analysis implementation |
| mlir/lib/Dialect/MemRef/IR/InferStridedMetadataInterfaceImpl.cpp | Implementation of interface for memref.subview operation |
| mlir/test/lib/Analysis/DataFlow/TestStridedMetadataRangeAnalysis.cpp | Test pass for the analysis |
| mlir/test/Analysis/DataFlow/test-strided-metadata-range-analysis.mlir | Test cases demonstrating the analysis |
Comments suppressed due to low confidence (1)

mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h:1

  • There's a typo in the word 'informaiton' which should be 'information'.
//===- InferStridedMetadataInterface.h - Strided Metadata Inference -C++-*-===//

@fabianmcg fabianmcg force-pushed the users/fabianmcg/strided-metadata branch from 37ec2ca to d5443d9 Compare September 29, 2025 21:31
Contributor

@nicolasvasilache nicolasvasilache left a comment

Very nice! I haven't checked the actual numbers in the test but this looks good to me!

namespace mlir {
namespace dataflow {

/// This lattice element represents the strided metadata of an SSA value.
Member

Does this apply only to memrefs? If so, I would mention it here.

Contributor Author

Not really. If someone wanted to implement it for tensors, e.g. tensor.extract_slice, they could.

void mlir::memref::registerInferStridedMetadataOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
memref::SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
Member

This interface should also be promised.

Collaborator

Why is it even registered as an external interface in the first place? Can we just define the interface on the op directly?

Contributor Author

I was following the perceived convention, where ValueBounds and other interfaces like IntegerRangeAnalysis are added as external models, presumably to allow overriding by downstreams or to avoid introducing deps for downstreams. I can promote it.

Collaborator

The only reason I know we started using external models is to be able to implement interfaces that are not in include/mlir/interface (that is: it is necessary to work around a layering violation).
If the pattern has been copied beyond its use, I doubt it is really intentional.


include "mlir/IR/OpBase.td"

def InferStridedMetadataOpInterface :
Member

Do you anticipate this interface being implemented on any ops outside of the memref dialect? If not, this interface+analysis may be better suited for mlir/Dialect/MemRef/IR instead of mlir/Interfaces+mlir/Analysis.

Contributor Author

Potentially yes.

InterfaceMethod<[{
Infer the strided metadata bounds on the results of this op given
the bounds on its operands.
For each result value or block argument, the method should call
Member

Clarify documentation: For each value/block arg or only the ones that are memrefs?

analysis result for a given value.
}],
"void", "inferStridedMetadataRanges",
(ins "::llvm::ArrayRef<::mlir::StridedMetadataRange>":$operands,
Member

How many elements are in operands? One per operand or one per memref operand? In the former case, if a value does not have strided metadata, is it an "invalid"/unspecified StridedMetadataRange object?

Contributor Author

How many elements are in operands?

All, that's an unfortunate limitation of dataflow. It will construct the lattice for every operand, even if they shouldn't participate in the analysis.
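
To illustrate what this means for values that do not carry strided metadata, here is a simplified standalone C++ sketch loosely modelled on `getEntryStateImpl` from the patch. It is not the patch code: `TypeInfo`, `Metadata`, `fromStatic`, and `entryState` are hypothetical names, only signed bounds are tracked, and `kDynamic` is a stand-in sentinel for a dynamic dimension. Non-memref values get an initialized but empty state, while ranked memrefs get one offset plus per-dimension sizes and strides, refined to constants wherever the type is static.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical closed interval [lo, hi]; signed bounds only in this sketch.
struct Range {
  int64_t lo;
  int64_t hi;
};

// Stand-in sentinel marking a dynamic size, stride, or offset (assumption).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// A static value becomes a constant range; a dynamic one becomes the top range.
Range fromStatic(int64_t v) {
  if (v == kDynamic)
    return {std::numeric_limits<int64_t>::min(),
            std::numeric_limits<int64_t>::max()};
  return {v, v};
}

// Hypothetical summary of a value's type.
struct TypeInfo {
  bool isRankedMemRef;
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
  int64_t offset;
};

struct Metadata {
  std::vector<Range> offsets;
  std::vector<Range> sizes;
  std::vector<Range> strides;
};

// Entry state for one operand. Every operand gets a state, including
// non-memref operands, whose state simply tracks nothing.
Metadata entryState(const TypeInfo &type) {
  Metadata md;
  if (!type.isRankedMemRef)
    return md;
  assert(type.shape.size() == type.strides.size() && "rank mismatch");
  md.offsets.push_back(fromStatic(type.offset));
  for (int64_t size : type.shape)
    md.sizes.push_back(fromStatic(size));
  for (int64_t stride : type.strides)
    md.strides.push_back(fromStatic(stride));
  return md;
}
```

An interface implementation written against such a scheme would therefore index `operands` by operand position and simply ignore the empty entries belonging to index or scalar operands.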

strides(std::move(strides)) {}

/// The offsets range.
std::optional<SmallVector<ConstantIntRanges>> offsets;
Member

Why is this optional? Also, why can there be multiple offsets?

Contributor Author

That's to handle the uninitialized state without adding a new field. If nullopt it's uninitialized.

For tensor (or even memref) someone can decide to track the non-linearized offsets producing the value.

Also, we have interfaces like OffsetSizeAndStrideOpInterface that already allow the notion of multiple offsets.
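
The optional-as-uninitialized idea and the element-wise join can be sketched in standalone C++ as follows. This mirrors the shape of `StridedMetadataRange::join` in the patch but is only an illustration: `Range`, `MetadataRange`, `unionAll`, and `join` here are hypothetical simplifications that track a single signed interval per entry.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical closed interval [lo, hi].
struct Range {
  int64_t lo;
  int64_t hi;
};

// Smallest interval containing both inputs.
Range rangeUnion(Range a, Range b) {
  return {std::min(a.lo, b.lo), std::max(a.hi, b.hi)};
}

struct MetadataRange {
  // nullopt encodes the uninitialized state, so no extra flag is needed.
  std::optional<std::vector<Range>> offsets;
  std::vector<Range> sizes;
  std::vector<Range> strides;

  bool isUninitialized() const { return !offsets.has_value(); }
};

// Element-wise union; both sides must have the same number of entries.
std::vector<Range> unionAll(const std::vector<Range> &a,
                            const std::vector<Range> &b) {
  assert(a.size() == b.size() && "ranks must match");
  std::vector<Range> out;
  out.reserve(a.size());
  for (std::size_t i = 0; i < a.size(); ++i)
    out.push_back(rangeUnion(a[i], b[i]));
  return out;
}

// Join: an uninitialized side acts as the identity, otherwise take the
// element-wise union of offsets, sizes, and strides.
MetadataRange join(const MetadataRange &lhs, const MetadataRange &rhs) {
  if (lhs.isUninitialized())
    return rhs;
  if (rhs.isUninitialized())
    return lhs;
  return {unionAll(*lhs.offsets, *rhs.offsets),
          unionAll(lhs.sizes, rhs.sizes),
          unionAll(lhs.strides, rhs.strides)};
}
```

Because `offsets` is a vector, the same structure can hold a single linearized offset (one entry) or several non-linearized offsets, as long as both sides of a join agree on the count.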
