@@ -916,6 +916,20 @@ static void ExtractFrontendAttributesFromFunction(
916
916
}
917
917
}
918
918
919
+ static void ExtractOriginalValuesFromFunction (
920
+ mlir::func::FuncOp function,
921
+ llvm::SmallVectorImpl<std::optional<xla::OriginalValueProto>>*
922
+ original_value_protos) {
923
+ original_value_protos->resize (function.getNumArguments (), std::nullopt );
924
+ for (int i = 0 , end = function.getNumArguments (); i < end; ++i) {
925
+ if (auto original_value_attr = function.getArgAttrOfType <mlir::StringAttr>(
926
+ i, xla::kMhloOriginalValueAttr )) {
927
+ (*original_value_protos)[i] =
928
+ xla::ConvertOriginalValue (original_value_attr.getValue ());
929
+ }
930
+ }
931
+ }
932
+
919
933
static bool SomeOptionalShardingsAreSet (
920
934
llvm::ArrayRef<std::optional<xla::OpSharding>> shardings) {
921
935
return llvm::any_of (shardings,
@@ -1114,8 +1128,10 @@ class ConvertToHloModule {
1114
1128
bool ensure_single_arg,
1115
1129
const std::vector<bool >& entry_args_same_across_replicas,
1116
1130
llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings,
1131
+ llvm::ArrayRef<std::optional<xla::FrontendAttributes>> arg_fe_attrs,
1132
+ llvm::ArrayRef<std::optional<xla::OriginalValueProto>>
1133
+ arg_original_value_protos,
1117
1134
llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
1118
- llvm::ArrayRef<std::optional<xla::FrontendAttributes>> fe_attrs,
1119
1135
xla::XlaComputationId& computation,
1120
1136
llvm::ArrayRef<mlir::Value> implicit_operands = {},
1121
1137
llvm::ArrayRef<mlir::Value> implicit_results = {});
@@ -5493,8 +5509,9 @@ LogicalResult ConvertToHloModule::LowerStablehloCompositeCall(
5493
5509
/* is_entry_function=*/ false ,
5494
5510
/* ensure_single_arg=*/ false ,
5495
5511
/* entry_args_same_across_replicas=*/ {},
5496
- /* arg_shardings=*/ {}, /* ret_shardings=*/ {},
5497
- /* fe_attrs=*/ {}, /* computation=*/ computation,
5512
+ /* arg_shardings=*/ {}, /* arg_fe_attrs=*/ {},
5513
+ /* arg_original_value_protos=*/ {}, /* ret_shardings=*/ {},
5514
+ /* computation=*/ computation,
5498
5515
/* implicit_operands=*/ {}))) {
5499
5516
return failure ();
5500
5517
}
@@ -5552,8 +5569,9 @@ LogicalResult ConvertToHloModule::LowerCompositeCall(
5552
5569
/* is_entry_function=*/ false ,
5553
5570
/* ensure_single_arg=*/ false ,
5554
5571
/* entry_args_same_across_replicas=*/ {},
5555
- /* arg_shardings=*/ {}, /* ret_shardings=*/ {},
5556
- /* fe_attrs=*/ {}, /* computation=*/ computation,
5572
+ /* arg_shardings=*/ {}, /* arg_fe_attrs=*/ {},
5573
+ /* arg_original_value_protos=*/ {}, /* ret_shardings=*/ {},
5574
+ /* computation=*/ computation,
5557
5575
/* implicit_operands=*/ {}))) {
5558
5576
return failure ();
5559
5577
}
@@ -5908,6 +5926,8 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::func::FuncOp f) {
5908
5926
llvm::SmallVector<std::optional<xla::OpSharding>, 4 > arg_shardings;
5909
5927
llvm::SmallVector<std::optional<xla::OpSharding>, 4 > ret_shardings;
5910
5928
llvm::SmallVector<std::optional<xla::FrontendAttributes>, 4 > arg_fe_attrs;
5929
+ llvm::SmallVector<std::optional<xla::OriginalValueProto>, 4 >
5930
+ arg_original_value_protos;
5911
5931
if (entry_function) {
5912
5932
bool any_arg_replicated = false ;
5913
5933
entry_args_same_across_replicas.reserve (f.getNumArguments ());
@@ -5954,13 +5974,14 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::func::FuncOp f) {
5954
5974
if (!any_arg_replicated) entry_args_same_across_replicas.clear ();
5955
5975
}
5956
5976
ExtractFrontendAttributesFromFunction (f, &arg_fe_attrs);
5977
+ ExtractOriginalValuesFromFunction (f, &arg_original_value_protos);
5957
5978
ExtractShardingsFromFunction (f, &arg_shardings, &ret_shardings,
5958
5979
entry_function);
5959
5980
xla::XlaComputationId computation;
5960
5981
if (failed (LowerBasicBlockAsFunction (
5961
5982
&f.front (), builder.get (), entry_function, false ,
5962
- entry_args_same_across_replicas, arg_shardings, ret_shardings ,
5963
- arg_fe_attrs , computation))) {
5983
+ entry_args_same_across_replicas, arg_shardings, arg_fe_attrs ,
5984
+ arg_original_value_protos, ret_shardings , computation))) {
5964
5985
return failure ();
5965
5986
}
5966
5987
if (auto execution_thread =
@@ -6102,12 +6123,14 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
6102
6123
bool ensure_single_arg,
6103
6124
const std::vector<bool >& entry_args_same_across_replicas,
6104
6125
llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings,
6126
+ llvm::ArrayRef<std::optional<xla::FrontendAttributes>> arg_fe_attrs,
6127
+ llvm::ArrayRef<std::optional<xla::OriginalValueProto>>
6128
+ arg_original_value_protos,
6105
6129
llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
6106
- llvm::ArrayRef<std::optional<xla::FrontendAttributes>> fe_attrs,
6107
6130
xla::XlaComputationId& computation,
6108
6131
llvm::ArrayRef<mlir::Value> implicit_operands,
6109
6132
llvm::ArrayRef<mlir::Value> implicit_results) {
6110
- // Mapping from the Value to lowered XlaOp.
6133
+ // Mapping from the Value to lowered XlaOp.
6111
6134
ValueLoweringMap lowering;
6112
6135
6113
6136
// If using tuples as input, then there is only one input parameter that is a
@@ -6137,6 +6160,10 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
6137
6160
xla::XlaScopedShardingAssignment scoped_sharding (
6138
6161
builder, arg_shardings.empty () ? std::nullopt
6139
6162
: arg_shardings[arg.getArgNumber ()]);
6163
+ xla::XlaScopedOriginalValueAssignment original_value (
6164
+ builder, arg_original_value_protos.empty ()
6165
+ ? std::nullopt
6166
+ : arg_original_value_protos[arg.getArgNumber ()]);
6140
6167
lowering[arg] = xla::GetTupleElement (tuple, arg.getArgNumber ());
6141
6168
}
6142
6169
} else {
@@ -6172,6 +6199,10 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
6172
6199
xla::XlaScopedShardingAssignment scoped_sharding (
6173
6200
builder,
6174
6201
arg_shardings.empty () ? std::nullopt : arg_shardings[num]);
6202
+ xla::XlaScopedOriginalValueAssignment original_value (
6203
+ builder, arg_original_value_protos.empty ()
6204
+ ? std::nullopt
6205
+ : arg_original_value_protos[num]);
6175
6206
lowering[arg] = xla::GetTupleElement (tuple, num);
6176
6207
}
6177
6208
for (auto [implicit_index, implicit_operand] :
@@ -6188,6 +6219,10 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
6188
6219
xla::XlaScopedShardingAssignment scoped_sharding (
6189
6220
builder,
6190
6221
arg_shardings.empty () ? std::nullopt : arg_shardings.front ());
6222
+ xla::XlaScopedOriginalValueAssignment original_value (
6223
+ builder, arg_original_value_protos.empty ()
6224
+ ? std::nullopt
6225
+ : arg_original_value_protos.front ());
6191
6226
mlir::Value arg = implicit_operands.empty () ? block->getArgument (0 )
6192
6227
: implicit_operands.front ();
6193
6228
xla::XlaScopedOpMetadataAssignment op_metadata (
@@ -6214,10 +6249,14 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
6214
6249
xla::Shape shape = xla::TypeToShape (arg.getType ());
6215
6250
xla::XlaScopedShardingAssignment scoped_sharding (
6216
6251
builder, arg_shardings.empty () ? std::nullopt : arg_shardings[num]);
6217
- if (!fe_attrs.empty () && fe_attrs[num]) {
6252
+ xla::XlaScopedOriginalValueAssignment original_value (
6253
+ builder, arg_original_value_protos.empty ()
6254
+ ? std::nullopt
6255
+ : arg_original_value_protos[num]);
6256
+ if (!arg_fe_attrs.empty () && arg_fe_attrs[num]) {
6218
6257
// Populates frontend attributes for parameters only for the entry
6219
6258
// functions with no tuple args.
6220
- builder->SetFrontendAttributes (*fe_attrs [num]);
6259
+ builder->SetFrontendAttributes (*arg_fe_attrs [num]);
6221
6260
}
6222
6261
// Save the location information as a name. For example JAX will set the
6223
6262
// name of the function argument of these. Want to preserve these for
@@ -6277,8 +6316,10 @@ LogicalResult ConvertToHloModule::LowerRegionAsComputation(
6277
6316
®ion->front (), builder.get (),
6278
6317
/* is_entry_function=*/ false ,
6279
6318
/* ensure_single_arg*/ ensure_single_arg,
6280
- /* entry_args_same_across_replicas=*/ {}, arg_shardings, ret_shardings,
6281
- /* fe_attrs=*/ {}, func, implicit_operands, implicit_results))) {
6319
+ /* entry_args_same_across_replicas=*/ {}, arg_shardings,
6320
+ /* arg_fe_attrs=*/ {},
6321
+ /* arg_original_value_protos=*/ {}, ret_shardings, func,
6322
+ implicit_operands, implicit_results))) {
6282
6323
return failure ();
6283
6324
}
6284
6325
return success ();
0 commit comments