Skip to content

Commit d2fdf24

Browse files
authored
fix stack 2x2 tensor along dimension 1 (#108620)
1 parent c929990 commit d2fdf24

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorExtensions.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3462,7 +3462,7 @@ public static Tensor<T> StackAlongDimension<T>(int dimension, params ReadOnlySpa
34623462
Tensor<T>[] outputs = new Tensor<T>[tensors.Length];
34633463
for (int i = 0; i < tensors.Length; i++)
34643464
{
3465-
outputs[i] = Tensor.Unsqueeze(tensors[0], dimension);
3465+
outputs[i] = Tensor.Unsqueeze(tensors[i], dimension);
34663466
}
34673467
return Tensor.ConcatenateOnDimension<T>(dimension, outputs);
34683468
}
@@ -3500,7 +3500,7 @@ public static ref readonly TensorSpan<T> StackAlongDimension<T>(scoped ReadOnlyS
35003500
Tensor<T>[] outputs = new Tensor<T>[tensors.Length];
35013501
for (int i = 0; i < tensors.Length; i++)
35023502
{
3503-
outputs[i] = Tensor.Unsqueeze(tensors[0], dimension);
3503+
outputs[i] = Tensor.Unsqueeze(tensors[i], dimension);
35043504
}
35053505
return ref Tensor.ConcatenateOnDimension<T>(dimension, tensors, destination);
35063506
}

src/libraries/System.Numerics.Tensors/tests/TensorTests.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,6 +1074,27 @@ public static void TensorStackTests()
10741074
Assert.Equal(8, resultTensor[1, 3, 1]);
10751075
Assert.Equal(9, resultTensor[1, 4, 0]);
10761076
Assert.Equal(9, resultTensor[1, 4, 1]);
1077+
1078+
// stacking 2x2 tensors along dimention 1
1079+
Tensor<int> v1 = Tensor.Create([1, 2, 3, 4], [2, 2]);
1080+
Tensor<int> v2 = Tensor.Create([10, 20, 30, 40], [2, 2]);
1081+
1082+
resultTensor = Tensor.StackAlongDimension(1, [v1, v2]);
1083+
1084+
Assert.Equal(3, resultTensor.Rank);
1085+
Assert.Equal(2, resultTensor.Lengths[0]);
1086+
Assert.Equal(2, resultTensor.Lengths[1]);
1087+
Assert.Equal(2, resultTensor.Lengths[2]);
1088+
1089+
Assert.Equal(1, resultTensor[0, 0, 0]);
1090+
Assert.Equal(2, resultTensor[0, 0, 1]);
1091+
Assert.Equal(10, resultTensor[0, 1, 0]);
1092+
Assert.Equal(20, resultTensor[0, 1, 1]);
1093+
1094+
Assert.Equal(3, resultTensor[1, 0, 0]);
1095+
Assert.Equal(4, resultTensor[1, 0, 1]);
1096+
Assert.Equal(30, resultTensor[1, 1, 0]);
1097+
Assert.Equal(40, resultTensor[1, 1, 1]);
10771098
}
10781099

10791100
[Fact]

0 commit comments

Comments
 (0)