@@ -474,6 +474,151 @@ TVM_REGISTER_OP("relax.flatten")
474
474
.set_attr<TMixedPrecisionPolicy>(" TMixedPrecisionPolicy" , MixedPrecisionPolicyKind::kFollow )
475
475
.set_attr<Bool>(" FPurity" , Bool(true ));
476
476
477
+ /* relax.index_tensor */
478
+
479
+ Expr index_tensor (Expr first, Expr tensors) {
480
+ static const Op& op = Op::Get (" relax.index_tensor" );
481
+ return Call (op, {std::move (first), std::move (tensors)}, Attrs (), {});
482
+ }
483
+
484
+ TVM_REGISTER_GLOBAL (" relax.op.index_tensor" ).set_body_typed(index_tensor);
485
+
486
+ StructInfo InferStructInfoIndexTensor (const Call& call, const BlockBuilder& ctx) {
487
+ if (call->args .size () != 2 ) {
488
+ ctx->ReportFatal (Diagnostic::Error (call) << " Index.Tensor op should have 2 arguments" );
489
+ }
490
+
491
+ TensorStructInfo data_sinfo = GetInputTensorStructInfo (call, 0 , ctx);
492
+ Array<TensorStructInfo> indices_sinfo = GetTensorStructInfoFromTuple (call, ctx, call->args [1 ]);
493
+
494
+ if (indices_sinfo.empty ()) {
495
+ ctx->ReportFatal (Diagnostic::Error (call)
496
+ << " index_tensor expects a non‑empty tuple of index tensors" );
497
+ }
498
+
499
+ DataType output_dtype = data_sinfo->dtype ;
500
+ int n_indices = static_cast <int >(indices_sinfo.size ());
501
+ Optional<VDevice> vdev = data_sinfo->vdevice ;
502
+
503
+ // Indices must be integers
504
+ for (int i = 0 ; i < n_indices; ++i) {
505
+ const auto & s = indices_sinfo[i];
506
+ if (!s->IsUnknownDtype () && !s->dtype .is_int ()) {
507
+ ctx->ReportFatal (Diagnostic::Error (call)
508
+ << " index_tensor requires every index tensor to have an integer dtype; "
509
+ << " index " << i << " has dtype " << s->dtype );
510
+ }
511
+ }
512
+
513
+ // Count of indices must be less than or equal to data.ndim
514
+ if (!data_sinfo->IsUnknownNdim () && n_indices > data_sinfo->ndim ) {
515
+ ctx->ReportFatal (Diagnostic::Error (call)
516
+ << " index_tensor received " << n_indices
517
+ << " index tensors, but data has only " << data_sinfo->ndim << " dimensions" );
518
+ }
519
+
520
+ arith::Analyzer* analyzer = ctx->GetAnalyzer ();
521
+ bool all_index_have_shape_value = true ;
522
+ std::vector<Array<PrimExpr>> index_shapes;
523
+ int max_index_ndim = 0 ;
524
+
525
+ for (const auto & s : indices_sinfo) {
526
+ const auto * shp = s->shape .as <ShapeExprNode>();
527
+ if (!shp) {
528
+ all_index_have_shape_value = false ;
529
+ } else {
530
+ index_shapes.push_back (shp->values );
531
+ max_index_ndim = std::max (max_index_ndim, static_cast <int >(shp->values .size ()));
532
+ }
533
+ if (!s->IsUnknownNdim ()) {
534
+ max_index_ndim = std::max (max_index_ndim, s->ndim );
535
+ }
536
+ }
537
+
538
+ Optional<Array<PrimExpr>> broadcast_shape;
539
+ bool shape_unknown = !all_index_have_shape_value;
540
+
541
+ if (all_index_have_shape_value) {
542
+ // initialise broadcast result with 1’s
543
+ Array<PrimExpr> out_shape;
544
+ for (int i = 0 ; i < max_index_ndim; ++i) {
545
+ out_shape.push_back (IntImm (DataType::Int (64 ), 1 ));
546
+ }
547
+
548
+ for (const auto & ishape : index_shapes) {
549
+ int cur_ndim = ishape.size ();
550
+ for (int axis = 0 ; axis < max_index_ndim; ++axis) {
551
+ int lhs_axis = max_index_ndim - 1 - axis; // aligned from right
552
+ int rhs_axis = cur_ndim - 1 - axis;
553
+ if (rhs_axis < 0 ) break ; // shorter rank – done
554
+
555
+ PrimExpr lhs_dim = out_shape[lhs_axis];
556
+ PrimExpr rhs_dim = ishape[rhs_axis];
557
+
558
+ const auto * lhs_int = lhs_dim.as <IntImmNode>();
559
+ const auto * rhs_int = rhs_dim.as <IntImmNode>();
560
+
561
+ // Case 1: current broadcast slot is 1 -> always replace
562
+ if (lhs_int && lhs_int->value == 1 ) {
563
+ out_shape.Set (lhs_axis, rhs_dim);
564
+ continue ;
565
+ }
566
+ // Case 2: rhs is 1 -> keep lhs_dim unchanged
567
+ if (rhs_int && rhs_int->value == 1 ) {
568
+ continue ;
569
+ }
570
+ // Both are non‑one constants: must equal
571
+ if (lhs_int && rhs_int && lhs_int->value != rhs_int->value ) {
572
+ ctx->ReportFatal (Diagnostic::Error (call)
573
+ << " index_tensor: cannot broadcast index shapes. Mismatch at axis "
574
+ << lhs_axis << " : " << lhs_dim << " vs " << rhs_dim);
575
+ }
576
+ // Give up if not provablt equal
577
+ if (!analyzer->CanProveEqual (lhs_dim, rhs_dim)) {
578
+ shape_unknown = true ;
579
+ break ;
580
+ }
581
+ }
582
+ if (shape_unknown) break ;
583
+ }
584
+
585
+ if (!shape_unknown) broadcast_shape = out_shape;
586
+ }
587
+
588
+ // Count of dimensions in output
589
+ int out_ndim = kUnknownNDim ;
590
+ if (!data_sinfo->IsUnknownNdim ()) {
591
+ int tail_ndim = data_sinfo->ndim - n_indices;
592
+ if (broadcast_shape.defined ()) {
593
+ out_ndim = static_cast <int >(broadcast_shape.value ().size ()) + tail_ndim;
594
+ } else if (!shape_unknown) {
595
+ out_ndim = max_index_ndim + tail_ndim;
596
+ }
597
+ }
598
+
599
+ // Derive output shape
600
+ if (broadcast_shape.defined ()) {
601
+ const auto * data_shape_expr = data_sinfo->shape .as <ShapeExprNode>();
602
+ if (data_shape_expr) {
603
+ Array<PrimExpr> result_shape = broadcast_shape.value ();
604
+ for (int i = n_indices; i < data_sinfo->ndim ; ++i) {
605
+ result_shape.push_back (data_shape_expr->values [i]);
606
+ }
607
+ return TensorStructInfo (ShapeExpr (result_shape), output_dtype, vdev);
608
+ }
609
+ }
610
+
611
+ // Unknown output shape
612
+ return TensorStructInfo (output_dtype, out_ndim, vdev);
613
+ }
614
+
615
+ TVM_REGISTER_OP (" relax.index_tensor" )
616
+ .set_num_inputs(2 )
617
+ .add_argument(" data" , " Tensor" , " The input data." )
618
+ .add_argument(" indices" , " List of Tensors" , " The indices used to index." )
619
+ .set_attr<FInferStructInfo>(" FInferStructInfo" , InferStructInfoIndexTensor)
620
+ .set_attr<Bool>(" FPurity" , Bool(true ));
621
+
477
622
/* relax.layout_transform */
478
623
TVM_REGISTER_NODE_TYPE (LayoutTransformAttrs);
479
624
0 commit comments