Skip to content

Commit 41bbd45

Browse files
author
Chris Elion
authored
[MLA-1762] reduce memory allocations from DiscreteActionOutputApplier (#4922)
1 parent 09c4e56 commit 41bbd45

File tree

4 files changed

+74
-281
lines changed

4 files changed

+74
-281
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ and this project adheres to
5353
- Removed unnecessary memory allocations in `SideChannelManager.GetSideChannelMessage()` (#4886)
5454
- Removed several memory allocations that happened during inference. On a test scene, this
5555
reduced the amount of memory allocated by approximately 25%. (#4887)
56+
- Removed several memory allocations that happened during inference with discrete actions. (#4922)
5657
- Properly catch permission errors when writing timer files. (#4921)
5758

5859
#### ml-agents / ml-agents-envs / gym-unity (Python)

com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs

Lines changed: 28 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
using System;
21
using System.Collections.Generic;
32
using System.Linq;
43
using Unity.MLAgents.Inference.Utils;
@@ -55,62 +54,26 @@ internal class DiscreteActionOutputApplier : TensorApplier.IApplier
5554
{
5655
readonly int[] m_ActionSize;
5756
readonly Multinomial m_Multinomial;
58-
readonly ITensorAllocator m_Allocator;
5957
readonly ActionSpec m_ActionSpec;
58+
readonly int[] m_StartActionIndices;
59+
readonly float[] m_CdfBuffer;
60+
6061

6162
public DiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
6263
{
6364
m_ActionSize = actionSpec.BranchSizes;
6465
m_Multinomial = new Multinomial(seed);
65-
m_Allocator = allocator;
6666
m_ActionSpec = actionSpec;
67+
m_StartActionIndices = Utilities.CumSum(m_ActionSize);
68+
69+
// Scratch space for computing the cumulative distribution function.
70+
// In order to reuse it, make it the size of the largest branch.
71+
var largestBranch = Mathf.Max(m_ActionSize);
72+
m_CdfBuffer = new float[largestBranch];
6773
}
6874

6975
public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
7076
{
71-
//var tensorDataProbabilities = tensorProxy.Data as float[,];
72-
var idActionPairList = actionIds as List<int> ?? actionIds.ToList();
73-
var batchSize = idActionPairList.Count;
74-
var actionValues = new float[batchSize, m_ActionSize.Length];
75-
var startActionIndices = Utilities.CumSum(m_ActionSize);
76-
for (var actionIndex = 0; actionIndex < m_ActionSize.Length; actionIndex++)
77-
{
78-
var nBranchAction = m_ActionSize[actionIndex];
79-
var actionProbs = new TensorProxy()
80-
{
81-
valueType = TensorProxy.TensorType.FloatingPoint,
82-
shape = new long[] { batchSize, nBranchAction },
83-
data = m_Allocator.Alloc(new TensorShape(batchSize, nBranchAction))
84-
};
85-
86-
for (var batchIndex = 0; batchIndex < batchSize; batchIndex++)
87-
{
88-
for (var branchActionIndex = 0;
89-
branchActionIndex < nBranchAction;
90-
branchActionIndex++)
91-
{
92-
actionProbs.data[batchIndex, branchActionIndex] =
93-
tensorProxy.data[batchIndex, startActionIndices[actionIndex] + branchActionIndex];
94-
}
95-
}
96-
97-
var outputTensor = new TensorProxy()
98-
{
99-
valueType = TensorProxy.TensorType.FloatingPoint,
100-
shape = new long[] { batchSize, 1 },
101-
data = m_Allocator.Alloc(new TensorShape(batchSize, 1))
102-
};
103-
104-
Eval(actionProbs, outputTensor, m_Multinomial);
105-
106-
for (var ii = 0; ii < batchSize; ii++)
107-
{
108-
actionValues[ii, actionIndex] = outputTensor.data[ii, 0];
109-
}
110-
actionProbs.data.Dispose();
111-
outputTensor.data.Dispose();
112-
}
113-
11477
var agentIndex = 0;
11578
for (var i = 0; i < actionIds.Count; i++)
11679
{
@@ -126,74 +89,38 @@ public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int,
12689
var discreteBuffer = actionBuffer.DiscreteActions;
12790
for (var j = 0; j < m_ActionSize.Length; j++)
12891
{
129-
discreteBuffer[j] = (int)actionValues[agentIndex, j];
92+
ComputeCdf(tensorProxy, agentIndex, m_StartActionIndices[j], m_ActionSize[j]);
93+
discreteBuffer[j] = m_Multinomial.Sample(m_CdfBuffer, m_ActionSize[j]);
13094
}
13195
}
13296
agentIndex++;
13397
}
13498
}
13599

136100
/// <summary>
137-
/// Draw samples from a multinomial distribution based on log-probabilities specified
138-
/// in tensor src. The samples will be saved in the dst tensor.
101+
/// Compute the cumulative distribution function for a given agent's action
102+
/// given the log-probabilities.
103+
/// The results are stored in m_CdfBuffer, which is sized to fit the largest branch;
/// only the first branchSize entries are written.
139104
/// </summary>
140-
/// <param name="src">2-D tensor with shape batch_size x num_classes</param>
141-
/// <param name="dst">Allocated tensor with size batch_size x num_samples</param>
142-
/// <param name="multinomial">Multinomial object used to sample values</param>
143-
/// <exception cref="NotImplementedException">
144-
/// Multinomial doesn't support integer tensors
145-
/// </exception>
146-
/// <exception cref="ArgumentException">Issue with tensor shape or type</exception>
147-
/// <exception cref="ArgumentNullException">
148-
/// At least one of the tensors is not allocated
149-
/// </exception>
150-
public static void Eval(TensorProxy src, TensorProxy dst, Multinomial multinomial)
105+
/// <param name="logProbs">Tensor of log-probabilities, indexed as [batch, channel].</param>
106+
/// <param name="batch">Index of the agent being considered</param>
107+
/// <param name="channelOffset">Offset into the tensor's channel.</param>
108+
/// <param name="branchSize">Number of channels to read for this branch, i.e. how many entries of m_CdfBuffer are filled.</param>
109+
internal void ComputeCdf(TensorProxy logProbs, int batch, int channelOffset, int branchSize)
151110
{
152-
if (src.DataType != typeof(float))
111+
// Find the class maximum
112+
var maxProb = float.NegativeInfinity;
113+
for (var cls = 0; cls < branchSize; ++cls)
153114
{
154-
throw new NotImplementedException("Only float tensors are currently supported");
115+
maxProb = Mathf.Max(logProbs.data[batch, cls + channelOffset], maxProb);
155116
}
156117

157-
if (src.valueType != dst.valueType)
118+
// Sum the exponentiated (max-shifted) log probabilities to build the unnormalized CDF
119+
var sumProb = 0.0f;
120+
for (var cls = 0; cls < branchSize; ++cls)
158121
{
159-
throw new ArgumentException(
160-
"Source and destination tensors have different types!");
161-
}
162-
163-
if (src.data == null || dst.data == null)
164-
{
165-
throw new ArgumentNullException();
166-
}
167-
168-
if (src.data.batch != dst.data.batch)
169-
{
170-
throw new ArgumentException("Batch size for input and output data is different!");
171-
}
172-
173-
var cdf = new float[src.data.channels];
174-
175-
for (var batch = 0; batch < src.data.batch; ++batch)
176-
{
177-
// Find the class maximum
178-
var maxProb = float.NegativeInfinity;
179-
for (var cls = 0; cls < src.data.channels; ++cls)
180-
{
181-
maxProb = Mathf.Max(src.data[batch, cls], maxProb);
182-
}
183-
184-
// Sum the log probabilities and compute CDF
185-
var sumProb = 0.0f;
186-
for (var cls = 0; cls < src.data.channels; ++cls)
187-
{
188-
sumProb += Mathf.Exp(src.data[batch, cls] - maxProb);
189-
cdf[cls] = sumProb;
190-
}
191-
192-
// Generate the samples
193-
for (var sample = 0; sample < dst.data.channels; ++sample)
194-
{
195-
dst.data[batch, sample] = multinomial.Sample(cdf);
196-
}
122+
sumProb += Mathf.Exp(logProbs.data[batch, cls + channelOffset] - maxProb);
123+
m_CdfBuffer[cls] = sumProb;
197124
}
198125
}
199126
}

com.unity.ml-agents/Runtime/Inference/Utils/Multinomial.cs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@ public Multinomial(int seed)
3232
/// to be monotonic (always increasing). If the CMF is scaled, then the last entry in
3333
/// the array will be 1.0.
3434
/// </param>
35-
/// <returns>A sampled index from the CMF ranging from 0 to cmf.Length-1.</returns>
36-
public int Sample(float[] cmf)
35+
/// <param name="branchSize">The number of possible branches, i.e. the effective size of the cmf array.</param>
36+
/// <returns>A sampled index from the CMF ranging from 0 to branchSize-1.</returns>
37+
public int Sample(float[] cmf, int branchSize)
3738
{
38-
var p = (float)m_Random.NextDouble() * cmf[cmf.Length - 1];
39+
var p = (float)m_Random.NextDouble() * cmf[branchSize - 1];
3940
var cls = 0;
4041
while (cmf[cls] < p)
4142
{
@@ -44,5 +45,15 @@ public int Sample(float[] cmf)
4445

4546
return cls;
4647
}
48+
49+
/// <summary>
50+
/// Samples from the Multinomial distribution defined by the provided cumulative
51+
/// mass function.
52+
/// </summary>
53+
/// <returns>A sampled index from the CMF ranging from 0 to cmf.Length-1.</returns>
54+
public int Sample(float[] cmf)
55+
{
56+
return Sample(cmf, cmf.Length);
57+
}
4758
}
4859
}

0 commit comments

Comments
 (0)