
Commit 7d00ca8

jcai19 authored and Google-ML-Automation committed
[XLA][Numerics][HLO Value Tracking] Propagate original values in HLO parameters through the StableHLO round trip
This saves HLO original values of parameters in MLIR function argument attributes.

PiperOrigin-RevId: 814947158
1 parent 968aea7 / commit 7d00ca8
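For context, here is the round trip this change enables, assembled from the test updates below. The MLIR function is taken from the updated stablehlo.mlir test; the HLO lines in the trailing comment paraphrase the updated FileCheck expectations.

  // The original value rides along as a string-typed argument attribute:
  func.func @main(%arg0: tensor<192xf32> {mhlo.original_value = "{{\22a\22}}"}) -> tensor<1x17x17x192xf32> {
    %0 = stablehlo.broadcast_in_dim %arg0, dims = [3] {mhlo.original_value = "{{\22broadcast.2342\22}}"} : (tensor<192xf32>) -> tensor<1x17x17x192xf32>
    return %0 : tensor<1x17x17x192xf32>
  }
  // On export, the attribute is restored as origin={...} on the HLO parameter:
  //   %Arg_0 = f32[192] parameter(0), origin={{"a"}}
  //   ROOT %result = f32[1,17,17,192] broadcast(%Arg_0), dimensions={3}, origin={{"broadcast.2342"}}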

File tree (7 files changed: +100 -22 lines)

  xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc
  xla/hlo/translate/hlo_to_mhlo/tests/import.hlo
  xla/hlo/translate/mhlo_to_hlo/BUILD
  xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
  xla/hlo/translate/mhlo_to_hlo/tests/export.mlir
  xla/hlo/translate/mhlo_to_hlo/translate.cc
  xla/hlo/translate/tests/stablehlo.mlir


xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc

Lines changed: 9 additions & 0 deletions
@@ -567,9 +567,18 @@ absl::StatusOr<Value> HloFunctionImporter::ImportInstructionsImpl(
   // Setup the input parameters.
   const int num_parameters = computation.num_parameters();
 
+  FuncOp func = llvm::dyn_cast<FuncOp>(builder->getBlock()->getParentOp());
   for (int i = 0; i < num_parameters; i++) {
     auto* hlo_parameter = computation.parameter_instruction(i);
     instruction_value_map_[hlo_parameter] = arguments[i];
+    // Only add original value attributes to parameters in functions. Skip
+    // regions.
+    if (hlo_parameter->original_value() && func) {
+      func.setArgAttr(
+          i, kMhloOriginalValueAttr,
+          builder_->getStringAttr(
+              "{" + hlo_parameter->original_value()->ToString() + "}"));
+    }
   }
 
   for (auto instruction : computation.MakeInstructionPostOrder()) {
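Note the null check on func: ImportInstructionsImpl is also used to import regions (for example reduction bodies), whose parent op is not a FuncOp, and block arguments of a region cannot carry argument attributes, so original values are attached only when importing into an actual function. A hand-written MLIR sketch of the distinction, mirroring the all-reduce test below (the function name and the all_reduce attribute list are abbreviated for illustration):

  // A function parameter can carry the attribute...
  func.func private @sketch(%arg0: tensor<8xf32> {mhlo.original_value = "{{\22b\22}}"}) -> tensor<8xf32> {
    %0 = "mhlo.all_reduce"(%arg0) <{replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>}> ({
    // ...but these region block arguments have no attribute slot, so the
    // importer skips them:
    ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
      %1 = mhlo.add %lhs, %rhs : tensor<f32>
      mhlo.return %1 : tensor<f32>
    }) : (tensor<8xf32>) -> tensor<8xf32>
    return %0 : tensor<8xf32>
  }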

xla/hlo/translate/hlo_to_mhlo/tests/import.hlo

Lines changed: 19 additions & 5 deletions
@@ -2130,9 +2130,23 @@ add {
 // FLATTEN-CHECK-SAME: ([[ARG:%.*]]: tensor<4x4xf32>) -> (tensor<4x2xf32>, tensor<4x2xi32>)
 
 // Test HLO original value
-// CHECK-LABEL: func private @test_original_value
-%test_original_value (Arg_0: f32[192]) -> f32[1,17,17,192] {
-  %Arg_0 = f32[192]{0} parameter(0)
-  // CHECK: "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> {mhlo.original_value = "{{[{][{]}}\22broadcast.2342\22{{[}][}]}}"} : (tensor<192xf32>) -> tensor<1x17x17x192xf32>
-  ROOT %broadcast.2342 = f32[1,17,17,192]{3,2,1,0} broadcast(f32[192]{0} %Arg_0), dimensions={3}, origin={{"broadcast.2342"}}
+// CHECK-LABEL: func private @add_original_value
+// CHECK-SAME: {mhlo.original_value = "{{[{][{]}}\22a\22{{[}][}]}}"}
+add_original_value {
+  lhs = f32[] parameter(0), origin={{"a"}}
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+// CHECK-LABEL: func private @test_orignal_value
+// CHECK-SAME: ([[INPUT:%.*]]: tensor<8xf32>
+// CHECK-SAME: {mhlo.original_value = "{{[{][{]}}\22b\22{{[}][}]}}"})
+%test_orignal_value {
+  input = f32[8] parameter(0), origin={{"b"}}
+  // CHECK-NEXT: "mhlo.all_reduce"([[INPUT]]) <{
+  // CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
+  // CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
+  // CHECK: mhlo.return [[ADD]] : tensor<f32>
+  // CHECK: }) {mhlo.original_value = "{{[{][{]}}\22c\22{{[}][}]}}"}
+  ROOT result = f32[8] all-reduce(input), channel_id=1, replica_groups={{0,1,2,3}, {4,5,6,7}}, to_apply=add_original_value, origin={{"c"}}
 }

xla/hlo/translate/mhlo_to_hlo/BUILD

Lines changed: 1 addition & 0 deletions
@@ -291,6 +291,7 @@ cc_library(
     srcs = ["translate.cc"],
     hdrs = ["translate.h"],
     deps = [
+        ":attribute_exporter",
         ":mlir_hlo_to_hlo",
         ":type_to_shape",
         "//xla:debug_options_flags",

xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc

Lines changed: 54 additions & 13 deletions
@@ -916,6 +916,20 @@ static void ExtractFrontendAttributesFromFunction(
   }
 }
 
+static void ExtractOriginalValuesFromFunction(
+    mlir::func::FuncOp function,
+    llvm::SmallVectorImpl<std::optional<xla::OriginalValueProto>>*
+        original_value_protos) {
+  original_value_protos->resize(function.getNumArguments(), std::nullopt);
+  for (int i = 0, end = function.getNumArguments(); i < end; ++i) {
+    if (auto original_value_attr = function.getArgAttrOfType<mlir::StringAttr>(
+            i, xla::kMhloOriginalValueAttr)) {
+      (*original_value_protos)[i] =
+          xla::ConvertOriginalValue(original_value_attr.getValue());
+    }
+  }
+}
+
 static bool SomeOptionalShardingsAreSet(
     llvm::ArrayRef<std::optional<xla::OpSharding>> shardings) {
   return llvm::any_of(shardings,
@@ -1114,8 +1128,10 @@
       bool ensure_single_arg,
       const std::vector<bool>& entry_args_same_across_replicas,
       llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings,
+      llvm::ArrayRef<std::optional<xla::FrontendAttributes>> arg_fe_attrs,
+      llvm::ArrayRef<std::optional<xla::OriginalValueProto>>
+          arg_original_value_protos,
       llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
-      llvm::ArrayRef<std::optional<xla::FrontendAttributes>> fe_attrs,
       xla::XlaComputationId& computation,
       llvm::ArrayRef<mlir::Value> implicit_operands = {},
       llvm::ArrayRef<mlir::Value> implicit_results = {});
@@ -5493,8 +5509,9 @@ LogicalResult ConvertToHloModule::LowerStablehloCompositeCall(
           /*is_entry_function=*/false,
           /*ensure_single_arg=*/false,
           /*entry_args_same_across_replicas=*/{},
-          /*arg_shardings=*/{}, /*ret_shardings=*/{},
-          /*fe_attrs=*/{}, /*computation=*/computation,
+          /*arg_shardings=*/{}, /*arg_fe_attrs=*/{},
+          /*arg_original_value_protos=*/{}, /*ret_shardings=*/{},
+          /*computation=*/computation,
           /*implicit_operands=*/{}))) {
     return failure();
   }
@@ -5552,8 +5569,9 @@ LogicalResult ConvertToHloModule::LowerCompositeCall(
           /*is_entry_function=*/false,
           /*ensure_single_arg=*/false,
           /*entry_args_same_across_replicas=*/{},
-          /*arg_shardings=*/{}, /*ret_shardings=*/{},
-          /*fe_attrs=*/{}, /*computation=*/computation,
+          /*arg_shardings=*/{}, /*arg_fe_attrs=*/{},
+          /*arg_original_value_protos=*/{}, /*ret_shardings=*/{},
+          /*computation=*/computation,
           /*implicit_operands=*/{}))) {
     return failure();
   }
@@ -5908,6 +5926,8 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::func::FuncOp f) {
   llvm::SmallVector<std::optional<xla::OpSharding>, 4> arg_shardings;
   llvm::SmallVector<std::optional<xla::OpSharding>, 4> ret_shardings;
   llvm::SmallVector<std::optional<xla::FrontendAttributes>, 4> arg_fe_attrs;
+  llvm::SmallVector<std::optional<xla::OriginalValueProto>, 4>
+      arg_original_value_protos;
   if (entry_function) {
     bool any_arg_replicated = false;
     entry_args_same_across_replicas.reserve(f.getNumArguments());
@@ -5954,13 +5974,14 @@
     if (!any_arg_replicated) entry_args_same_across_replicas.clear();
   }
   ExtractFrontendAttributesFromFunction(f, &arg_fe_attrs);
+  ExtractOriginalValuesFromFunction(f, &arg_original_value_protos);
   ExtractShardingsFromFunction(f, &arg_shardings, &ret_shardings,
                                entry_function);
   xla::XlaComputationId computation;
   if (failed(LowerBasicBlockAsFunction(
           &f.front(), builder.get(), entry_function, false,
-          entry_args_same_across_replicas, arg_shardings, ret_shardings,
-          arg_fe_attrs, computation))) {
+          entry_args_same_across_replicas, arg_shardings, arg_fe_attrs,
+          arg_original_value_protos, ret_shardings, computation))) {
    return failure();
  }
  if (auto execution_thread =
@@ -6102,12 +6123,14 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
     bool ensure_single_arg,
     const std::vector<bool>& entry_args_same_across_replicas,
     llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings,
+    llvm::ArrayRef<std::optional<xla::FrontendAttributes>> arg_fe_attrs,
+    llvm::ArrayRef<std::optional<xla::OriginalValueProto>>
+        arg_original_value_protos,
     llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
-    llvm::ArrayRef<std::optional<xla::FrontendAttributes>> fe_attrs,
     xla::XlaComputationId& computation,
     llvm::ArrayRef<mlir::Value> implicit_operands,
     llvm::ArrayRef<mlir::Value> implicit_results) {
-  // Mapping from the Value to lowered XlaOp.
+  // Mapping from the Value to lowered XlaOp.
   ValueLoweringMap lowering;
 
   // If using tuples as input, then there is only one input parameter that is a
@@ -6137,6 +6160,10 @@
       xla::XlaScopedShardingAssignment scoped_sharding(
           builder, arg_shardings.empty() ? std::nullopt
                                          : arg_shardings[arg.getArgNumber()]);
+      xla::XlaScopedOriginalValueAssignment original_value(
+          builder, arg_original_value_protos.empty()
+                       ? std::nullopt
+                       : arg_original_value_protos[arg.getArgNumber()]);
       lowering[arg] = xla::GetTupleElement(tuple, arg.getArgNumber());
     }
   } else {
@@ -6172,6 +6199,10 @@
       xla::XlaScopedShardingAssignment scoped_sharding(
           builder,
           arg_shardings.empty() ? std::nullopt : arg_shardings[num]);
+      xla::XlaScopedOriginalValueAssignment original_value(
+          builder, arg_original_value_protos.empty()
+                       ? std::nullopt
+                       : arg_original_value_protos[num]);
       lowering[arg] = xla::GetTupleElement(tuple, num);
     }
     for (auto [implicit_index, implicit_operand] :
@@ -6188,6 +6219,10 @@
     xla::XlaScopedShardingAssignment scoped_sharding(
         builder,
         arg_shardings.empty() ? std::nullopt : arg_shardings.front());
+    xla::XlaScopedOriginalValueAssignment original_value(
+        builder, arg_original_value_protos.empty()
+                     ? std::nullopt
+                     : arg_original_value_protos.front());
     mlir::Value arg = implicit_operands.empty() ? block->getArgument(0)
                                                 : implicit_operands.front();
     xla::XlaScopedOpMetadataAssignment op_metadata(
@@ -6214,10 +6249,14 @@
       xla::Shape shape = xla::TypeToShape(arg.getType());
       xla::XlaScopedShardingAssignment scoped_sharding(
           builder, arg_shardings.empty() ? std::nullopt : arg_shardings[num]);
-      if (!fe_attrs.empty() && fe_attrs[num]) {
+      xla::XlaScopedOriginalValueAssignment original_value(
+          builder, arg_original_value_protos.empty()
+                       ? std::nullopt
+                       : arg_original_value_protos[num]);
+      if (!arg_fe_attrs.empty() && arg_fe_attrs[num]) {
         // Populates frontend attributes for parameters only for the entry
         // functions with no tuple args.
-        builder->SetFrontendAttributes(*fe_attrs[num]);
+        builder->SetFrontendAttributes(*arg_fe_attrs[num]);
       }
       // Save the location information as a name. For example JAX will set the
       // name of the function argument of these. Want to preserve these for
@@ -6277,8 +6316,10 @@ LogicalResult ConvertToHloModule::LowerRegionAsComputation(
           &region->front(), builder.get(),
           /*is_entry_function=*/false,
           /*ensure_single_arg*/ ensure_single_arg,
-          /*entry_args_same_across_replicas=*/{}, arg_shardings, ret_shardings,
-          /*fe_attrs=*/{}, func, implicit_operands, implicit_results))) {
+          /*entry_args_same_across_replicas=*/{}, arg_shardings,
+          /*arg_fe_attrs=*/{},
+          /*arg_original_value_protos=*/{}, ret_shardings, func,
+          implicit_operands, implicit_results))) {
    return failure();
  }
  return success();
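The recurring xla::XlaScopedOriginalValueAssignment objects mirror the existing XlaScopedShardingAssignment pattern: judging by that analogy, each is an RAII guard that installs the argument's original-value proto on the XlaBuilder while the corresponding parameter instruction is created, then restores the previous state when the scope ends, so the emitted parameter carries its origin from the start. The guard is added on all four parameter-lowering paths (tuple arguments indexed by argument number, tuple arguments indexed by num, the single-argument/implicit-operand case, and plain arguments), which is why the same four-line addition appears four times.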

xla/hlo/translate/mhlo_to_hlo/tests/export.mlir

Lines changed: 2 additions & 2 deletions
@@ -3172,10 +3172,10 @@ func.func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
 // -----
 // CHECK: HloModule
 // CHECK: ENTRY
-// CHECK: %[[ARG0:.*]] = f32[192] parameter(0)
+// CHECK: %[[ARG0:.*]] = f32[192] parameter(0), origin={{[{][{]}}"a"{{[}][}]}}
 // CHECK: ROOT %[[RESULT:.*]] = f32[1,17,17,192] broadcast(%[[ARG0]]), dimensions={3}, origin={{[{][{]}}"broadcast.2342"{{[}][}]}}
 
-func.func @main(%arg0: tensor<192xf32>) -> tensor<1x17x17x192xf32> {
+func.func @main(%arg0: tensor<192xf32> {mhlo.original_value = "{{\22a\22}}"}) -> tensor<1x17x17x192xf32> {
   %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<3> : tensor<1xi64>}> {mhlo.original_value = "{{\22broadcast.2342\22}}"} : (tensor<192xf32>) -> tensor<1x17x17x192xf32>
   return %0 : tensor<1x17x17x192xf32>
 }

xla/hlo/translate/mhlo_to_hlo/translate.cc

Lines changed: 13 additions & 0 deletions
@@ -42,6 +42,7 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_input_output_alias_config.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h"
 #include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
 #include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h"
 #include "xla/hlo/translate/register.h"
@@ -139,6 +140,18 @@ absl::Status ConvertMlirHloToHloViaBuilder(
       }
     }
   }
+
+  for (int i = 0; i < main.getNumArguments(); ++i) {
+    if (auto original_value_attr = main.getArgAttrOfType<mlir::StringAttr>(
+            i, xla::kMhloOriginalValueAttr)) {
+      *computation.mutable_proto()
+           ->mutable_computations(0)
+           ->mutable_instructions(i)
+           ->mutable_original_value() =
+          *ConvertOriginalValue(original_value_attr);
+    }
+  }
+
   auto hlo_module = computation.proto();
   mlir::StringRef module_name = module.getName() ? *module.getName() : "main";
   hlo_module.set_name(module_name.str());
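Unlike the builder paths above, this translation route patches the proto after lowering: argument i's attribute is converted with ConvertOriginalValue and written onto instruction i of computation 0. This appears to assume that the entry parameters are emitted as the leading instructions of the entry computation, in argument order, which holds for the module this function has just built via XlaBuilder.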

xla/hlo/translate/tests/stablehlo.mlir

Lines changed: 2 additions & 2 deletions
@@ -1953,11 +1953,11 @@ module {
 // CHECK-LABEL: HloModule main
 
 // CHECK: ENTRY
-// CHECK: %[[ARG0:.*]] = f32[192] parameter(0)
+// CHECK: %[[ARG0:.*]] = f32[192] parameter(0), origin={{[{][{]}}"a"{{[}][}]}}
 // CHECK: ROOT %[[RESULT:.*]] = f32[1,17,17,192] broadcast(%[[ARG0]]), dimensions={3}, origin={{[{][{]}}"broadcast.2342"{{[}][}]}}
 
 module {
-  func.func @main(%arg0: tensor<192xf32>) -> tensor<1x17x17x192xf32> {
+  func.func @main(%arg0: tensor<192xf32> {mhlo.original_value = "{{\22a\22}}"}) -> tensor<1x17x17x192xf32> {
     %0 = stablehlo.broadcast_in_dim %arg0, dims = [3] {mhlo.original_value = "{{\22broadcast.2342\22}}"} : (tensor<192xf32>) -> tensor<1x17x17x192xf32>
     return %0 : tensor<1x17x17x192xf32>
   }