@@ -2724,6 +2724,14 @@ public static Tensor<T> PermuteDimensions<T>(this Tensor<T> tensor, params ReadO
2724
2724
/// <param name="lengths"><see cref="ReadOnlySpan{T}"/> with the new dimensions.</param>
2725
2725
public static Tensor < T > Reshape < T > ( this Tensor < T > tensor , params ReadOnlySpan < nint > lengths )
2726
2726
{
2727
+ if ( tensor . Lengths . SequenceEqual ( lengths ) )
2728
+ return tensor ;
2729
+
2730
+ if ( ! TensorHelpers . IsContiguousAndDense < T > ( tensor ) & & ! tensor . Strides . Contains ( 0 ) )
2731
+ {
2732
+ ThrowHelper . ThrowArgument_CannotReshapeNonContiguousOrDense ( ) ;
2733
+ }
2734
+
2727
2735
nint [ ] arrLengths = lengths . ToArray ( ) ;
2728
2736
// Calculate wildcard info.
2729
2737
if ( lengths . Contains ( - 1 ) )
@@ -2745,7 +2753,33 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, params ReadOnlySpan<ni
2745
2753
nint tempLinear = TensorSpanHelpers . CalculateTotalLength ( arrLengths ) ;
2746
2754
if ( tempLinear != tensor . FlattenedLength )
2747
2755
ThrowHelper . ThrowArgument_InvalidReshapeDimensions ( ) ;
2748
- nint [ ] strides = TensorSpanHelpers . CalculateStrides ( arrLengths ) ;
2756
+
2757
+ nint [ ] strides ;
2758
+
2759
+ // If we contain a 0 stride we can only add dimensions of length 1.
2760
+ if ( tensor . Strides . Contains ( 0 ) )
2761
+ {
2762
+ List < nint > origStrides = new List < nint > ( tensor . Strides . ToArray ( ) ) ;
2763
+ int lengthOffset = 0 ;
2764
+ for ( int i = 0 ; i < arrLengths . Length ; i ++ )
2765
+ {
2766
+ if ( lengthOffset < tensor . Rank && arrLengths [ i ] == tensor . Lengths [ lengthOffset ] )
2767
+ lengthOffset ++ ;
2768
+ else if ( arrLengths [ i ] == 1 )
2769
+ {
2770
+ if ( lengthOffset == tensor . Rank )
2771
+ origStrides . Add ( tensor . Strides [ lengthOffset - 1 ] ) ;
2772
+ else
2773
+ origStrides . Insert ( i , tensor . Strides [ i ] * tensor . Lengths [ i ] ) ;
2774
+ }
2775
+ else
2776
+ ThrowHelper . ThrowArgument_InvalidReshapeDimensions ( ) ;
2777
+ }
2778
+ strides = origStrides . ToArray ( ) ;
2779
+ }
2780
+ else
2781
+ strides = TensorSpanHelpers . CalculateStrides ( arrLengths ) ;
2782
+
2749
2783
return new Tensor < T > ( tensor . _values , arrLengths , strides ) ;
2750
2784
}
2751
2785
@@ -2758,6 +2792,14 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, params ReadOnlySpan<ni
2758
2792
/// <param name="lengths"><see cref="ReadOnlySpan{T}"/> with the new dimensions.</param>
2759
2793
public static TensorSpan < T > Reshape < T > ( in this TensorSpan < T > tensor , params scoped ReadOnlySpan < nint > lengths )
2760
2794
{
2795
+ if ( tensor . Lengths . SequenceEqual ( lengths ) )
2796
+ return tensor ;
2797
+
2798
+ if ( ! TensorHelpers . IsContiguousAndDense < T > ( tensor ) & & ! tensor . Strides . Contains ( 0 ) )
2799
+ {
2800
+ ThrowHelper . ThrowArgument_CannotReshapeNonContiguousOrDense ( ) ;
2801
+ }
2802
+
2761
2803
nint [ ] arrLengths = lengths . ToArray ( ) ;
2762
2804
// Calculate wildcard info.
2763
2805
if ( lengths . Contains ( - 1 ) )
@@ -2779,7 +2821,35 @@ public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, params scop
2779
2821
nint tempLinear = TensorSpanHelpers . CalculateTotalLength ( arrLengths ) ;
2780
2822
if ( tempLinear != tensor . FlattenedLength )
2781
2823
ThrowHelper . ThrowArgument_InvalidReshapeDimensions ( ) ;
2782
- nint [ ] strides = TensorSpanHelpers . CalculateStrides ( arrLengths ) ;
2824
+
2825
+ nint [ ] strides ;
2826
+
2827
+ // If we contain a 0 stride we can only add dimensions of length 1.
2828
+ if ( tensor . Strides . Contains ( 0 ) )
2829
+ {
2830
+ List < nint > origStrides = new List < nint > ( tensor . Strides . ToArray ( ) ) ;
2831
+ int lengthOffset = 0 ;
2832
+ for ( int i = 0 ; i < arrLengths . Length ; i ++ )
2833
+ {
2834
+ if ( lengthOffset < tensor . Rank && arrLengths [ i ] == tensor . Lengths [ lengthOffset ] )
2835
+ {
2836
+ lengthOffset ++ ;
2837
+ }
2838
+ else if ( arrLengths [ i ] == 1 )
2839
+ {
2840
+ if ( lengthOffset == tensor . Rank )
2841
+ origStrides . Add ( tensor . Strides [ lengthOffset - 1 ] ) ;
2842
+ else
2843
+ origStrides . Insert ( i , tensor . Strides [ i ] * tensor . Lengths [ i ] ) ;
2844
+ }
2845
+ else
2846
+ ThrowHelper . ThrowArgument_InvalidReshapeDimensions ( ) ;
2847
+ }
2848
+ strides = origStrides . ToArray ( ) ;
2849
+ }
2850
+ else
2851
+ strides = TensorSpanHelpers . CalculateStrides ( arrLengths ) ;
2852
+
2783
2853
TensorSpan < T > output = new TensorSpan < T > ( ref tensor . _reference , arrLengths , strides , tensor . _shape . _memoryLength ) ;
2784
2854
return output ;
2785
2855
}
@@ -2793,6 +2863,14 @@ public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, params scop
2793
2863
/// <param name="lengths"><see cref="ReadOnlySpan{T}"/> with the new dimensions.</param>
2794
2864
public static ReadOnlyTensorSpan < T > Reshape < T > ( in this ReadOnlyTensorSpan < T > tensor , params scoped ReadOnlySpan < nint > lengths )
2795
2865
{
2866
+ if ( tensor . Lengths . SequenceEqual ( lengths ) )
2867
+ return tensor ;
2868
+
2869
+ if ( ! TensorHelpers . IsContiguousAndDense < T > ( tensor ) & & ! tensor . Strides . Contains ( 0 ) )
2870
+ {
2871
+ ThrowHelper . ThrowArgument_CannotReshapeNonContiguousOrDense ( ) ;
2872
+ }
2873
+
2796
2874
nint [ ] arrLengths = lengths . ToArray ( ) ;
2797
2875
// Calculate wildcard info.
2798
2876
if ( lengths . Contains ( - 1 ) )
@@ -2814,7 +2892,33 @@ public static ReadOnlyTensorSpan<T> Reshape<T>(in this ReadOnlyTensorSpan<T> ten
2814
2892
nint tempLinear = TensorSpanHelpers . CalculateTotalLength ( arrLengths ) ;
2815
2893
if ( tempLinear != tensor . FlattenedLength )
2816
2894
ThrowHelper . ThrowArgument_InvalidReshapeDimensions ( ) ;
2817
- nint [ ] strides = TensorSpanHelpers . CalculateStrides ( arrLengths ) ;
2895
+
2896
+ nint [ ] strides ;
2897
+
2898
+ // If we contain a 0 stride we can only add dimensions of length 1.
2899
+ if ( tensor . Strides . Contains ( 0 ) )
2900
+ {
2901
+ List < nint > origStrides = new List < nint > ( tensor . Strides . ToArray ( ) ) ;
2902
+ int lengthOffset = 0 ;
2903
+ for ( int i = 0 ; i < arrLengths . Length ; i ++ )
2904
+ {
2905
+ if ( lengthOffset < tensor . Rank && arrLengths [ i ] == tensor . Lengths [ lengthOffset ] )
2906
+ lengthOffset ++ ;
2907
+ else if ( arrLengths [ i ] == 1 )
2908
+ {
2909
+ if ( lengthOffset == tensor . Rank )
2910
+ origStrides . Add ( tensor . Strides [ lengthOffset - 1 ] ) ;
2911
+ else
2912
+ origStrides . Insert ( i , tensor . Strides [ i ] * tensor . Lengths [ i ] ) ;
2913
+ }
2914
+ else
2915
+ ThrowHelper . ThrowArgument_InvalidReshapeDimensions ( ) ;
2916
+ }
2917
+ strides = origStrides . ToArray ( ) ;
2918
+ }
2919
+ else
2920
+ strides = TensorSpanHelpers . CalculateStrides ( arrLengths ) ;
2921
+
2818
2922
ReadOnlyTensorSpan < T > output = new ReadOnlyTensorSpan < T > ( ref tensor . _reference , arrLengths , strides , tensor . _shape . _memoryLength ) ;
2819
2923
return output ;
2820
2924
}
@@ -3053,14 +3157,17 @@ public static ref readonly TensorSpan<T> SetSlice<T>(this in TensorSpan<T> tenso
3053
3157
TensorSpan < T > srcSpan ;
3054
3158
if ( ranges == ReadOnlySpan < NRange > . Empty )
3055
3159
{
3056
- if ( ! tensor . Lengths . SequenceEqual ( values . Lengths ) )
3160
+ if ( ! TensorHelpers . IsBroadcastableTo ( values . Lengths , tensor . Lengths ) )
3057
3161
ThrowHelper . ThrowArgument_SetSliceNoRange ( nameof ( values ) ) ;
3058
- srcSpan = tensor . Slice ( tensor . Lengths ) ;
3162
+ srcSpan = tensor ;
3059
3163
}
3060
3164
else
3061
3165
srcSpan = tensor . Slice ( ranges ) ;
3062
3166
3063
- if ( ! srcSpan . Lengths . SequenceEqual ( values . Lengths ) )
3167
+ if ( ! TensorHelpers . IsContiguousAndDense < T > ( srcSpan ) )
3168
+ ThrowHelper . ThrowArgument_SetSliceInvalidShapes ( nameof ( values ) ) ;
3169
+
3170
+ if ( ! TensorHelpers . IsBroadcastableTo ( values . Lengths , srcSpan . Lengths ) )
3064
3171
ThrowHelper . ThrowArgument_SetSliceInvalidShapes ( nameof ( values ) ) ;
3065
3172
3066
3173
values . CopyTo ( srcSpan ) ;
@@ -3555,8 +3662,13 @@ public static Tensor<T> Unsqueeze<T>(this Tensor<T> tensor, int dimension)
3555
3662
3556
3663
List < nint > tempLengths = tensor . _lengths . ToList ( ) ;
3557
3664
tempLengths . Insert ( dimension , 1 ) ;
3558
- nint [ ] lengths = tempLengths . ToArray ( ) ;
3559
- nint [ ] strides = TensorSpanHelpers . CalculateStrides ( lengths ) ;
3665
+ nint [ ] lengths = [ .. tempLengths ] ;
3666
+ List < nint > tempStrides = tensor . Strides . ToArray ( ) . ToList ( ) ;
3667
+ if ( dimension == tensor . Rank )
3668
+ tempStrides . Add ( tensor . Strides [ dimension - 1 ] ) ;
3669
+ else
3670
+ tempStrides . Insert ( dimension , tensor . Strides [ dimension ] * tensor . Lengths [ dimension ] ) ;
3671
+ nint [ ] strides = [ .. tempStrides ] ;
3560
3672
return new Tensor < T > ( tensor . _values , lengths , strides ) ;
3561
3673
}
3562
3674
@@ -3574,8 +3686,13 @@ public static TensorSpan<T> Unsqueeze<T>(in this TensorSpan<T> tensor, int dimen
3574
3686
3575
3687
List < nint > tempLengths = tensor . Lengths . ToArray ( ) . ToList ( ) ;
3576
3688
tempLengths . Insert ( dimension , 1 ) ;
3577
- nint [ ] lengths = tempLengths . ToArray ( ) ;
3578
- nint [ ] strides = TensorSpanHelpers . CalculateStrides ( lengths ) ;
3689
+ nint [ ] lengths = [ .. tempLengths ] ;
3690
+ List < nint > tempStrides = tensor . Strides . ToArray ( ) . ToList ( ) ;
3691
+ if ( dimension == tensor . Rank )
3692
+ tempStrides . Add ( tensor . Strides [ dimension - 1 ] ) ;
3693
+ else
3694
+ tempStrides . Insert ( dimension , tensor . Strides [ dimension ] * tensor . Lengths [ dimension ] ) ;
3695
+ nint [ ] strides = [ .. tempStrides ] ;
3579
3696
return new TensorSpan < T > ( ref tensor . _reference , lengths , strides , tensor . _shape . _memoryLength ) ;
3580
3697
}
3581
3698
@@ -3593,8 +3710,13 @@ public static ReadOnlyTensorSpan<T> Unsqueeze<T>(in this ReadOnlyTensorSpan<T> t
3593
3710
3594
3711
List < nint > tempLengths = tensor . Lengths . ToArray ( ) . ToList ( ) ;
3595
3712
tempLengths . Insert ( dimension , 1 ) ;
3596
- nint [ ] lengths = tempLengths . ToArray ( ) ;
3597
- nint [ ] strides = TensorSpanHelpers . CalculateStrides ( lengths ) ;
3713
+ nint [ ] lengths = [ .. tempLengths ] ;
3714
+ List < nint > tempStrides = tensor . Strides . ToArray ( ) . ToList ( ) ;
3715
+ if ( dimension == tensor . Rank )
3716
+ tempStrides . Add ( tensor . Strides [ dimension - 1 ] ) ;
3717
+ else
3718
+ tempStrides . Insert ( dimension , tensor . Strides [ dimension ] * tensor . Lengths [ dimension ] ) ;
3719
+ nint [ ] strides = [ .. tempStrides ] ;
3598
3720
return new ReadOnlyTensorSpan < T > ( ref tensor . _reference , lengths , strides , tensor . _shape . _memoryLength ) ;
3599
3721
}
3600
3722
#endregion
0 commit comments