Skip to content

Commit 3d3a8ed

Browse files
[release/9.0] Fixing SetSlice, Reshape, TryCopyTo. (#108282)
* working * comments from PR * can always reshape to self * fixed tests * comments from PR * fixing tests --------- Co-authored-by: Michael Sharp <[email protected]>
1 parent dc412cb commit 3d3a8ed

File tree

9 files changed

+387
-172
lines changed

9 files changed

+387
-172
lines changed

src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,4 +228,7 @@
228228
<data name="ThrowArgument_StackShapesNotSame" xml:space="preserve">
229229
<value>All tensors must have the same shape.</value>
230230
</data>
231+
<data name="Argument_CannotReshapeNonContiguousOrDense" xml:space="preserve">
232+
<value>The Tensor provided is either non-contiguous or non-dense. Reshape only works with contigous and dense memory. You may need to Broadcast or Copy the data to be contigous.</value>
233+
</data>
231234
</root>

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/ReadOnlyTensorSpan.cs

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -518,31 +518,39 @@ public void CopyTo(scoped TensorSpan<T> destination)
518518
// Using "if (!TryCopyTo(...))" results in two branches: one for the length
519519
// check, and one for the result of TryCopyTo. Since these checks are equivalent,
520520
// we can optimize by performing the check once ourselves then calling Memmove directly.
521-
if (_shape.FlattenedLength <= destination.FlattenedLength)
521+
if (TensorHelpers.IsBroadcastableTo(Lengths, destination.Lengths))
522522
{
523523
scoped Span<nint> curIndexes;
524524
nint[]? curIndexesArray;
525+
525526
if (Rank > TensorShape.MaxInlineRank)
526527
{
527-
curIndexesArray = ArrayPool<nint>.Shared.Rent(Rank);
528-
curIndexes = curIndexesArray.AsSpan(0, Rank);
528+
curIndexesArray = ArrayPool<nint>.Shared.Rent(destination.Rank);
529+
curIndexes = curIndexesArray.AsSpan(0, destination.Rank);
530+
529531
}
530532
else
531533
{
532534
curIndexesArray = null;
533-
curIndexes = stackalloc nint[Rank];
535+
curIndexes = stackalloc nint[destination.Rank];
534536
}
535537
curIndexes.Clear();
536538

537539
nint copiedValues = 0;
538-
TensorSpan<T> slice = destination.Slice(_shape.Lengths);
539-
while (copiedValues < _shape.FlattenedLength)
540+
nint[] tempLengths = Tensor.GetSmallestBroadcastableLengths(Lengths, destination.Lengths);
541+
542+
TensorSpan<T> destinationSlice = destination.Slice(tempLengths);
543+
ReadOnlyTensorSpan<T> srcSlice = Tensor.LazyBroadcast(this, tempLengths);
544+
nint copyLength = srcSlice.Strides[^1] == 1 && TensorHelpers.IsContiguousAndDense(srcSlice) ? srcSlice.Lengths[^1] : 1;
545+
int indexToAdjust = srcSlice.Strides[^1] == 1 && TensorHelpers.IsContiguousAndDense(srcSlice) ? srcSlice.Rank - 2 : srcSlice.Rank - 1;
546+
547+
while (copiedValues < destination.FlattenedLength)
540548
{
541-
TensorSpanHelpers.Memmove(ref Unsafe.Add(ref slice._reference, TensorSpanHelpers.ComputeLinearIndex(curIndexes, Strides, Lengths)), ref Unsafe.Add(ref _reference, TensorSpanHelpers.ComputeLinearIndex(curIndexes, Strides, Lengths)), Lengths[Rank - 1]);
542-
TensorSpanHelpers.AdjustIndexes(Rank - 2, 1, curIndexes, _shape.Lengths);
543-
copiedValues += Lengths[Rank - 1];
549+
TensorSpanHelpers.Memmove(ref Unsafe.Add(ref destinationSlice._reference, TensorSpanHelpers.ComputeLinearIndex(curIndexes, destinationSlice.Strides, destinationSlice.Lengths)), ref Unsafe.Add(ref srcSlice._reference, TensorSpanHelpers.ComputeLinearIndex(curIndexes, srcSlice.Strides, srcSlice.Lengths)), copyLength);
550+
TensorSpanHelpers.AdjustIndexes(indexToAdjust, 1, curIndexes, tempLengths);
551+
copiedValues += copyLength;
544552
}
545-
Debug.Assert(copiedValues == _shape.FlattenedLength, "Didn't copy the right amount to the array.");
553+
Debug.Assert(copiedValues == destination.FlattenedLength, "Didn't copy the right amount to the array.");
546554

547555
if (curIndexesArray != null)
548556
ArrayPool<nint>.Shared.Return(curIndexesArray);
@@ -565,32 +573,40 @@ public bool TryCopyTo(scoped TensorSpan<T> destination)
565573
{
566574
bool retVal = false;
567575

568-
if (_shape.FlattenedLength <= destination.FlattenedLength)
576+
if (TensorHelpers.IsBroadcastableTo(Lengths, destination.Lengths))
569577
{
570578
scoped Span<nint> curIndexes;
571579
nint[]? curIndexesArray;
580+
572581
if (Rank > TensorShape.MaxInlineRank)
573582
{
574-
curIndexesArray = ArrayPool<nint>.Shared.Rent(Rank);
575-
curIndexes = curIndexesArray.AsSpan(0, Rank);
583+
curIndexesArray = ArrayPool<nint>.Shared.Rent(destination.Rank);
584+
curIndexes = curIndexesArray.AsSpan(0, destination.Rank);
585+
576586
}
577587
else
578588
{
579589
curIndexesArray = null;
580-
curIndexes = stackalloc nint[Rank];
590+
curIndexes = stackalloc nint[destination.Rank];
581591
}
582592
curIndexes.Clear();
583593

584594
nint copiedValues = 0;
585-
TensorSpan<T> slice = destination.Slice(_shape.Lengths);
586-
while (copiedValues < _shape.FlattenedLength)
595+
nint[] tempLengths = Tensor.GetSmallestBroadcastableLengths(Lengths, destination.Lengths);
596+
597+
TensorSpan<T> destinationSlice = destination.Slice(tempLengths);
598+
ReadOnlyTensorSpan<T> srcSlice = Tensor.LazyBroadcast(this, tempLengths);
599+
nint copyLength = srcSlice.Strides[^1] == 1 && TensorHelpers.IsContiguousAndDense(srcSlice) ? srcSlice.Lengths[^1] : 1;
600+
int indexToAdjust = srcSlice.Strides[^1] == 1 && TensorHelpers.IsContiguousAndDense(srcSlice) ? srcSlice.Rank - 2 : srcSlice.Rank - 1;
601+
602+
while (copiedValues < destination.FlattenedLength)
587603
{
588-
TensorSpanHelpers.Memmove(ref Unsafe.Add(ref slice._reference, TensorSpanHelpers.ComputeLinearIndex(curIndexes, Strides, Lengths)), ref Unsafe.Add(ref _reference, TensorSpanHelpers.ComputeLinearIndex(curIndexes, Strides, Lengths)), Lengths[Rank - 1]);
589-
TensorSpanHelpers.AdjustIndexes(Rank - 2, 1, curIndexes, _shape.Lengths);
590-
copiedValues += Lengths[Rank - 1];
604+
TensorSpanHelpers.Memmove(ref Unsafe.Add(ref destinationSlice._reference, TensorSpanHelpers.ComputeLinearIndex(curIndexes, destinationSlice.Strides, destinationSlice.Lengths)), ref Unsafe.Add(ref srcSlice._reference, TensorSpanHelpers.ComputeLinearIndex(curIndexes, srcSlice.Strides, srcSlice.Lengths)), copyLength);
605+
TensorSpanHelpers.AdjustIndexes(indexToAdjust, 1, curIndexes, tempLengths);
606+
copiedValues += copyLength;
591607
}
608+
Debug.Assert(copiedValues == destination.FlattenedLength, "Didn't copy the right amount to the array.");
592609
retVal = true;
593-
Debug.Assert(copiedValues == _shape.FlattenedLength, "Didn't copy the right amount to the array.");
594610

595611
if (curIndexesArray != null)
596612
ArrayPool<nint>.Shared.Return(curIndexesArray);

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorExtensions.cs

Lines changed: 134 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2724,6 +2724,14 @@ public static Tensor<T> PermuteDimensions<T>(this Tensor<T> tensor, params ReadO
27242724
/// <param name="lengths"><see cref="ReadOnlySpan{T}"/> with the new dimensions.</param>
27252725
public static Tensor<T> Reshape<T>(this Tensor<T> tensor, params ReadOnlySpan<nint> lengths)
27262726
{
2727+
if (tensor.Lengths.SequenceEqual(lengths))
2728+
return tensor;
2729+
2730+
if (!TensorHelpers.IsContiguousAndDense<T>(tensor) && !tensor.Strides.Contains(0))
2731+
{
2732+
ThrowHelper.ThrowArgument_CannotReshapeNonContiguousOrDense();
2733+
}
2734+
27272735
nint[] arrLengths = lengths.ToArray();
27282736
// Calculate wildcard info.
27292737
if (lengths.Contains(-1))
@@ -2745,7 +2753,33 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, params ReadOnlySpan<ni
27452753
nint tempLinear = TensorSpanHelpers.CalculateTotalLength(arrLengths);
27462754
if (tempLinear != tensor.FlattenedLength)
27472755
ThrowHelper.ThrowArgument_InvalidReshapeDimensions();
2748-
nint[] strides = TensorSpanHelpers.CalculateStrides(arrLengths);
2756+
2757+
nint[] strides;
2758+
2759+
// If we contain a 0 stride we can only add dimensions of length 1.
2760+
if (tensor.Strides.Contains(0))
2761+
{
2762+
List<nint> origStrides = new List<nint>(tensor.Strides.ToArray());
2763+
int lengthOffset = 0;
2764+
for (int i = 0; i < arrLengths.Length; i++)
2765+
{
2766+
if (lengthOffset < tensor.Rank && arrLengths[i] == tensor.Lengths[lengthOffset])
2767+
lengthOffset++;
2768+
else if (arrLengths[i] == 1)
2769+
{
2770+
if (lengthOffset == tensor.Rank)
2771+
origStrides.Add(tensor.Strides[lengthOffset - 1]);
2772+
else
2773+
origStrides.Insert(i, tensor.Strides[i] * tensor.Lengths[i]);
2774+
}
2775+
else
2776+
ThrowHelper.ThrowArgument_InvalidReshapeDimensions();
2777+
}
2778+
strides = origStrides.ToArray();
2779+
}
2780+
else
2781+
strides = TensorSpanHelpers.CalculateStrides(arrLengths);
2782+
27492783
return new Tensor<T>(tensor._values, arrLengths, strides);
27502784
}
27512785

@@ -2758,6 +2792,14 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, params ReadOnlySpan<ni
27582792
/// <param name="lengths"><see cref="ReadOnlySpan{T}"/> with the new dimensions.</param>
27592793
public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, params scoped ReadOnlySpan<nint> lengths)
27602794
{
2795+
if (tensor.Lengths.SequenceEqual(lengths))
2796+
return tensor;
2797+
2798+
if (!TensorHelpers.IsContiguousAndDense<T>(tensor) && !tensor.Strides.Contains(0))
2799+
{
2800+
ThrowHelper.ThrowArgument_CannotReshapeNonContiguousOrDense();
2801+
}
2802+
27612803
nint[] arrLengths = lengths.ToArray();
27622804
// Calculate wildcard info.
27632805
if (lengths.Contains(-1))
@@ -2779,7 +2821,35 @@ public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, params scop
27792821
nint tempLinear = TensorSpanHelpers.CalculateTotalLength(arrLengths);
27802822
if (tempLinear != tensor.FlattenedLength)
27812823
ThrowHelper.ThrowArgument_InvalidReshapeDimensions();
2782-
nint[] strides = TensorSpanHelpers.CalculateStrides(arrLengths);
2824+
2825+
nint[] strides;
2826+
2827+
// If we contain a 0 stride we can only add dimensions of length 1.
2828+
if (tensor.Strides.Contains(0))
2829+
{
2830+
List<nint> origStrides = new List<nint>(tensor.Strides.ToArray());
2831+
int lengthOffset = 0;
2832+
for (int i = 0; i < arrLengths.Length; i++)
2833+
{
2834+
if (lengthOffset < tensor.Rank && arrLengths[i] == tensor.Lengths[lengthOffset])
2835+
{
2836+
lengthOffset++;
2837+
}
2838+
else if (arrLengths[i] == 1)
2839+
{
2840+
if (lengthOffset == tensor.Rank)
2841+
origStrides.Add(tensor.Strides[lengthOffset - 1]);
2842+
else
2843+
origStrides.Insert(i, tensor.Strides[i] * tensor.Lengths[i]);
2844+
}
2845+
else
2846+
ThrowHelper.ThrowArgument_InvalidReshapeDimensions();
2847+
}
2848+
strides = origStrides.ToArray();
2849+
}
2850+
else
2851+
strides = TensorSpanHelpers.CalculateStrides(arrLengths);
2852+
27832853
TensorSpan<T> output = new TensorSpan<T>(ref tensor._reference, arrLengths, strides, tensor._shape._memoryLength);
27842854
return output;
27852855
}
@@ -2793,6 +2863,14 @@ public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, params scop
27932863
/// <param name="lengths"><see cref="ReadOnlySpan{T}"/> with the new dimensions.</param>
27942864
public static ReadOnlyTensorSpan<T> Reshape<T>(in this ReadOnlyTensorSpan<T> tensor, params scoped ReadOnlySpan<nint> lengths)
27952865
{
2866+
if (tensor.Lengths.SequenceEqual(lengths))
2867+
return tensor;
2868+
2869+
if (!TensorHelpers.IsContiguousAndDense<T>(tensor) && !tensor.Strides.Contains(0))
2870+
{
2871+
ThrowHelper.ThrowArgument_CannotReshapeNonContiguousOrDense();
2872+
}
2873+
27962874
nint[] arrLengths = lengths.ToArray();
27972875
// Calculate wildcard info.
27982876
if (lengths.Contains(-1))
@@ -2814,7 +2892,33 @@ public static ReadOnlyTensorSpan<T> Reshape<T>(in this ReadOnlyTensorSpan<T> ten
28142892
nint tempLinear = TensorSpanHelpers.CalculateTotalLength(arrLengths);
28152893
if (tempLinear != tensor.FlattenedLength)
28162894
ThrowHelper.ThrowArgument_InvalidReshapeDimensions();
2817-
nint[] strides = TensorSpanHelpers.CalculateStrides(arrLengths);
2895+
2896+
nint[] strides;
2897+
2898+
// If we contain a 0 stride we can only add dimensions of length 1.
2899+
if (tensor.Strides.Contains(0))
2900+
{
2901+
List<nint> origStrides = new List<nint>(tensor.Strides.ToArray());
2902+
int lengthOffset = 0;
2903+
for (int i = 0; i < arrLengths.Length; i++)
2904+
{
2905+
if (lengthOffset < tensor.Rank && arrLengths[i] == tensor.Lengths[lengthOffset])
2906+
lengthOffset++;
2907+
else if (arrLengths[i] == 1)
2908+
{
2909+
if (lengthOffset == tensor.Rank)
2910+
origStrides.Add(tensor.Strides[lengthOffset - 1]);
2911+
else
2912+
origStrides.Insert(i, tensor.Strides[i] * tensor.Lengths[i]);
2913+
}
2914+
else
2915+
ThrowHelper.ThrowArgument_InvalidReshapeDimensions();
2916+
}
2917+
strides = origStrides.ToArray();
2918+
}
2919+
else
2920+
strides = TensorSpanHelpers.CalculateStrides(arrLengths);
2921+
28182922
ReadOnlyTensorSpan<T> output = new ReadOnlyTensorSpan<T>(ref tensor._reference, arrLengths, strides, tensor._shape._memoryLength);
28192923
return output;
28202924
}
@@ -3053,14 +3157,17 @@ public static ref readonly TensorSpan<T> SetSlice<T>(this in TensorSpan<T> tenso
30533157
TensorSpan<T> srcSpan;
30543158
if (ranges == ReadOnlySpan<NRange>.Empty)
30553159
{
3056-
if (!tensor.Lengths.SequenceEqual(values.Lengths))
3160+
if (!TensorHelpers.IsBroadcastableTo(values.Lengths, tensor.Lengths))
30573161
ThrowHelper.ThrowArgument_SetSliceNoRange(nameof(values));
3058-
srcSpan = tensor.Slice(tensor.Lengths);
3162+
srcSpan = tensor;
30593163
}
30603164
else
30613165
srcSpan = tensor.Slice(ranges);
30623166

3063-
if (!srcSpan.Lengths.SequenceEqual(values.Lengths))
3167+
if (!TensorHelpers.IsContiguousAndDense<T>(srcSpan))
3168+
ThrowHelper.ThrowArgument_SetSliceInvalidShapes(nameof(values));
3169+
3170+
if (!TensorHelpers.IsBroadcastableTo(values.Lengths, srcSpan.Lengths))
30643171
ThrowHelper.ThrowArgument_SetSliceInvalidShapes(nameof(values));
30653172

30663173
values.CopyTo(srcSpan);
@@ -3555,8 +3662,13 @@ public static Tensor<T> Unsqueeze<T>(this Tensor<T> tensor, int dimension)
35553662

35563663
List<nint> tempLengths = tensor._lengths.ToList();
35573664
tempLengths.Insert(dimension, 1);
3558-
nint[] lengths = tempLengths.ToArray();
3559-
nint[] strides = TensorSpanHelpers.CalculateStrides(lengths);
3665+
nint[] lengths = [.. tempLengths];
3666+
List<nint> tempStrides = tensor.Strides.ToArray().ToList();
3667+
if (dimension == tensor.Rank)
3668+
tempStrides.Add(tensor.Strides[dimension - 1]);
3669+
else
3670+
tempStrides.Insert(dimension, tensor.Strides[dimension] * tensor.Lengths[dimension]);
3671+
nint[] strides = [.. tempStrides];
35603672
return new Tensor<T>(tensor._values, lengths, strides);
35613673
}
35623674

@@ -3574,8 +3686,13 @@ public static TensorSpan<T> Unsqueeze<T>(in this TensorSpan<T> tensor, int dimen
35743686

35753687
List<nint> tempLengths = tensor.Lengths.ToArray().ToList();
35763688
tempLengths.Insert(dimension, 1);
3577-
nint[] lengths = tempLengths.ToArray();
3578-
nint[] strides = TensorSpanHelpers.CalculateStrides(lengths);
3689+
nint[] lengths = [.. tempLengths];
3690+
List<nint> tempStrides = tensor.Strides.ToArray().ToList();
3691+
if (dimension == tensor.Rank)
3692+
tempStrides.Add(tensor.Strides[dimension - 1]);
3693+
else
3694+
tempStrides.Insert(dimension, tensor.Strides[dimension] * tensor.Lengths[dimension]);
3695+
nint[] strides = [.. tempStrides];
35793696
return new TensorSpan<T>(ref tensor._reference, lengths, strides, tensor._shape._memoryLength);
35803697
}
35813698

@@ -3593,8 +3710,13 @@ public static ReadOnlyTensorSpan<T> Unsqueeze<T>(in this ReadOnlyTensorSpan<T> t
35933710

35943711
List<nint> tempLengths = tensor.Lengths.ToArray().ToList();
35953712
tempLengths.Insert(dimension, 1);
3596-
nint[] lengths = tempLengths.ToArray();
3597-
nint[] strides = TensorSpanHelpers.CalculateStrides(lengths);
3713+
nint[] lengths = [.. tempLengths];
3714+
List<nint> tempStrides = tensor.Strides.ToArray().ToList();
3715+
if (dimension == tensor.Rank)
3716+
tempStrides.Add(tensor.Strides[dimension - 1]);
3717+
else
3718+
tempStrides.Insert(dimension, tensor.Strides[dimension] * tensor.Lengths[dimension]);
3719+
nint[] strides = [.. tempStrides];
35983720
return new ReadOnlyTensorSpan<T>(ref tensor._reference, lengths, strides, tensor._shape._memoryLength);
35993721
}
36003722
#endregion

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorHelpers.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ internal static bool IsBroadcastableTo(ReadOnlySpan<nint> lengths1, ReadOnlySpan
4343
nint s1;
4444
nint s2;
4545

46+
if (lengths1.Length == 0 || lengths2.Length == 0)
47+
return false;
48+
4649
while (lengths1Index >= 0 || lengths2Index >= 0)
4750
{
4851
// if a dimension is missing in one of the shapes, it is considered to be 1
@@ -56,7 +59,7 @@ internal static bool IsBroadcastableTo(ReadOnlySpan<nint> lengths1, ReadOnlySpan
5659
else
5760
s2 = lengths2[lengths2Index--];
5861

59-
if (s1 == s2 || (s1 == 1 && s2 != 1) || (s2 == 1 && s1 != 1)) { }
62+
if (s1 == s2 || (s1 == 1 && s2 > 1) || (s2 == 1 && s1 > 1)) { }
6063
else
6164
{
6265
areCompatible = false;

0 commit comments

Comments
 (0)