1
- using System ;
2
1
using System . Collections . Generic ;
3
2
using System . Linq ;
4
3
using Unity . MLAgents . Inference . Utils ;
@@ -55,62 +54,26 @@ internal class DiscreteActionOutputApplier : TensorApplier.IApplier
55
54
{
56
55
readonly int [ ] m_ActionSize ;
57
56
readonly Multinomial m_Multinomial ;
58
- readonly ITensorAllocator m_Allocator ;
59
57
readonly ActionSpec m_ActionSpec ;
58
+ readonly int [ ] m_StartActionIndices ;
59
+ readonly float [ ] m_CdfBuffer ;
60
+
60
61
61
62
public DiscreteActionOutputApplier ( ActionSpec actionSpec , int seed , ITensorAllocator allocator )
62
63
{
63
64
m_ActionSize = actionSpec . BranchSizes ;
64
65
m_Multinomial = new Multinomial ( seed ) ;
65
- m_Allocator = allocator ;
66
66
m_ActionSpec = actionSpec ;
67
+ m_StartActionIndices = Utilities . CumSum ( m_ActionSize ) ;
68
+
69
+ // Scratch space for computing the cumulative distribution function.
70
+ // In order to reuse it, make it the size of the largest branch.
71
+ var largestBranch = Mathf . Max ( m_ActionSize ) ;
72
+ m_CdfBuffer = new float [ largestBranch ] ;
67
73
}
68
74
69
75
public void Apply ( TensorProxy tensorProxy , IList < int > actionIds , Dictionary < int , ActionBuffers > lastActions )
70
76
{
71
- //var tensorDataProbabilities = tensorProxy.Data as float[,];
72
- var idActionPairList = actionIds as List < int > ?? actionIds . ToList ( ) ;
73
- var batchSize = idActionPairList . Count ;
74
- var actionValues = new float [ batchSize , m_ActionSize . Length ] ;
75
- var startActionIndices = Utilities . CumSum ( m_ActionSize ) ;
76
- for ( var actionIndex = 0 ; actionIndex < m_ActionSize . Length ; actionIndex ++ )
77
- {
78
- var nBranchAction = m_ActionSize [ actionIndex ] ;
79
- var actionProbs = new TensorProxy ( )
80
- {
81
- valueType = TensorProxy . TensorType . FloatingPoint ,
82
- shape = new long [ ] { batchSize , nBranchAction } ,
83
- data = m_Allocator . Alloc ( new TensorShape ( batchSize , nBranchAction ) )
84
- } ;
85
-
86
- for ( var batchIndex = 0 ; batchIndex < batchSize ; batchIndex ++ )
87
- {
88
- for ( var branchActionIndex = 0 ;
89
- branchActionIndex < nBranchAction ;
90
- branchActionIndex ++ )
91
- {
92
- actionProbs . data [ batchIndex , branchActionIndex ] =
93
- tensorProxy . data [ batchIndex , startActionIndices [ actionIndex ] + branchActionIndex ] ;
94
- }
95
- }
96
-
97
- var outputTensor = new TensorProxy ( )
98
- {
99
- valueType = TensorProxy . TensorType . FloatingPoint ,
100
- shape = new long [ ] { batchSize , 1 } ,
101
- data = m_Allocator . Alloc ( new TensorShape ( batchSize , 1 ) )
102
- } ;
103
-
104
- Eval ( actionProbs , outputTensor , m_Multinomial ) ;
105
-
106
- for ( var ii = 0 ; ii < batchSize ; ii ++ )
107
- {
108
- actionValues [ ii , actionIndex ] = outputTensor . data [ ii , 0 ] ;
109
- }
110
- actionProbs . data . Dispose ( ) ;
111
- outputTensor . data . Dispose ( ) ;
112
- }
113
-
114
77
var agentIndex = 0 ;
115
78
for ( var i = 0 ; i < actionIds . Count ; i ++ )
116
79
{
@@ -126,74 +89,38 @@ public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int,
126
89
var discreteBuffer = actionBuffer . DiscreteActions ;
127
90
for ( var j = 0 ; j < m_ActionSize . Length ; j ++ )
128
91
{
129
- discreteBuffer [ j ] = ( int ) actionValues [ agentIndex , j ] ;
92
+ ComputeCdf ( tensorProxy , agentIndex , m_StartActionIndices [ j ] , m_ActionSize [ j ] ) ;
93
+ discreteBuffer [ j ] = m_Multinomial . Sample ( m_CdfBuffer , m_ActionSize [ j ] ) ;
130
94
}
131
95
}
132
96
agentIndex ++ ;
133
97
}
134
98
}
135
99
136
100
/// <summary>
137
- /// Draw samples from a multinomial distribution based on log-probabilities specified
138
- /// in tensor src. The samples will be saved in the dst tensor.
101
+ /// Compute the cumulative distribution function for a given agent's action
102
+ /// given the log-probabilities.
103
+ /// The results are stored in m_CdfBuffer, which is the size of the largest action's number of branches.
139
104
/// </summary>
140
- /// <param name="src">2-D tensor with shape batch_size x num_classes</param>
141
- /// <param name="dst">Allocated tensor with size batch_size x num_samples</param>
142
- /// <param name="multinomial">Multinomial object used to sample values</param>
143
- /// <exception cref="NotImplementedException">
144
- /// Multinomial doesn't support integer tensors
145
- /// </exception>
146
- /// <exception cref="ArgumentException">Issue with tensor shape or type</exception>
147
- /// <exception cref="ArgumentNullException">
148
- /// At least one of the tensors is not allocated
149
- /// </exception>
150
- public static void Eval ( TensorProxy src , TensorProxy dst , Multinomial multinomial )
105
+ /// <param name="logProbs"></param>
106
+ /// <param name="batch">Index of the agent being considered</param>
107
+ /// <param name="channelOffset">Offset into the tensor's channel.</param>
108
+ /// <param name="branchSize"></param>
109
+ internal void ComputeCdf ( TensorProxy logProbs , int batch , int channelOffset , int branchSize )
151
110
{
152
- if ( src . DataType != typeof ( float ) )
111
+ // Find the class maximum
112
+ var maxProb = float . NegativeInfinity ;
113
+ for ( var cls = 0 ; cls < branchSize ; ++ cls )
153
114
{
154
- throw new NotImplementedException ( "Only float tensors are currently supported" ) ;
115
+ maxProb = Mathf . Max ( logProbs . data [ batch , cls + channelOffset ] , maxProb ) ;
155
116
}
156
117
157
- if ( src . valueType != dst . valueType )
118
+ // Sum the log probabilities and compute CDF
119
+ var sumProb = 0.0f ;
120
+ for ( var cls = 0 ; cls < branchSize ; ++ cls )
158
121
{
159
- throw new ArgumentException (
160
- "Source and destination tensors have different types!" ) ;
161
- }
162
-
163
- if ( src . data == null || dst . data == null )
164
- {
165
- throw new ArgumentNullException ( ) ;
166
- }
167
-
168
- if ( src . data . batch != dst . data . batch )
169
- {
170
- throw new ArgumentException ( "Batch size for input and output data is different!" ) ;
171
- }
172
-
173
- var cdf = new float [ src . data . channels ] ;
174
-
175
- for ( var batch = 0 ; batch < src . data . batch ; ++ batch )
176
- {
177
- // Find the class maximum
178
- var maxProb = float . NegativeInfinity ;
179
- for ( var cls = 0 ; cls < src . data . channels ; ++ cls )
180
- {
181
- maxProb = Mathf . Max ( src . data [ batch , cls ] , maxProb ) ;
182
- }
183
-
184
- // Sum the log probabilities and compute CDF
185
- var sumProb = 0.0f ;
186
- for ( var cls = 0 ; cls < src . data . channels ; ++ cls )
187
- {
188
- sumProb += Mathf . Exp ( src . data [ batch , cls ] - maxProb ) ;
189
- cdf [ cls ] = sumProb ;
190
- }
191
-
192
- // Generate the samples
193
- for ( var sample = 0 ; sample < dst . data . channels ; ++ sample )
194
- {
195
- dst . data [ batch , sample ] = multinomial . Sample ( cdf ) ;
196
- }
122
+ sumProb += Mathf . Exp ( logProbs . data [ batch , cls + channelOffset ] - maxProb ) ;
123
+ m_CdfBuffer [ cls ] = sumProb ;
197
124
}
198
125
}
199
126
}
0 commit comments