Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 23 additions & 38 deletions UnitySDK/Assets/ML-Agents/Scripts/CoreBrainInternal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ public enum TensorType
/// Modify only in inspector : Reference to the Graph asset
public TextAsset graphModel;

/// Modify only in inspector : If a scope was used when training the model, specify it here
public string graphScope;

[SerializeField]
[Tooltip(
"If your graph takes additional inputs that are fixed (example: noise level) you can specify them here.")]
Expand Down Expand Up @@ -136,40 +133,35 @@ public void InitializeCoreBrain(MLAgents.Batcher brainBatcher)

// TODO: Make this a loop over a dynamic set of graph inputs

if ((graphScope.Length > 1) && (graphScope[graphScope.Length - 1] != '/'))
{
graphScope = graphScope + '/';
}

if (graph[graphScope + BatchSizePlaceholderName] != null)
if (graph[BatchSizePlaceholderName] != null)
{
hasBatchSize = true;
}

if ((graph[graphScope + RecurrentInPlaceholderName] != null) &&
(graph[graphScope + RecurrentOutPlaceholderName] != null))
if ((graph[RecurrentInPlaceholderName] != null) &&
(graph[RecurrentOutPlaceholderName] != null))
{
hasRecurrent = true;
var runner = session.GetRunner();
runner.Fetch(graph[graphScope + "memory_size"][0]);
runner.Fetch(graph["memory_size"][0]);
var networkOutput = runner.Run()[0].GetValue();
memorySize = (int) networkOutput;
}

if (graph[graphScope + VectorObservationPlacholderName] != null)
if (graph[VectorObservationPlacholderName] != null)
{
hasState = true;
}

if (graph[graphScope + PreviousActionPlaceholderName] != null)
if (graph[PreviousActionPlaceholderName] != null)
{
hasPrevAction = true;
}
if (graph[graphScope + "value_estimate"] != null)
if (graph["value_estimate"] != null)
{
hasValueEstimate = true;
}
if (graph[graphScope + ActionMaskPlaceholderName] != null)
if (graph[ActionMaskPlaceholderName] != null)
{
hasMaskedActions = true;
}
Expand Down Expand Up @@ -304,18 +296,18 @@ public void DecideAction(Dictionary<Agent, AgentInfo> agentInfo)
var runner = session.GetRunner();
try
{
runner.Fetch(graph[graphScope + ActionPlaceholderName][0]);
runner.Fetch(graph[ActionPlaceholderName][0]);
}
catch
{
throw new UnityAgentsException(string.Format(
@"The node {0} could not be found. Please make sure the graphScope {1} is correct",
graphScope + ActionPlaceholderName, graphScope));
@"The node {0} could not be found. Please make sure the node name is correct",
ActionPlaceholderName));
}

if (hasBatchSize)
{
runner.AddInput(graph[graphScope + BatchSizePlaceholderName][0], new int[] {currentBatchSize});
runner.AddInput(graph[BatchSizePlaceholderName][0], new int[] {currentBatchSize});
}

foreach (TensorFlowAgentPlaceholder placeholder in graphPlaceholders)
Expand All @@ -324,12 +316,12 @@ public void DecideAction(Dictionary<Agent, AgentInfo> agentInfo)
{
if (placeholder.valueType == TensorFlowAgentPlaceholder.TensorType.FloatingPoint)
{
runner.AddInput(graph[graphScope + placeholder.name][0],
runner.AddInput(graph[placeholder.name][0],
new float[] {Random.Range(placeholder.minValue, placeholder.maxValue)});
}
else if (placeholder.valueType == TensorFlowAgentPlaceholder.TensorType.Integer)
{
runner.AddInput(graph[graphScope + placeholder.name][0],
runner.AddInput(graph[placeholder.name][0],
new int[] {Random.Range((int) placeholder.minValue, (int) placeholder.maxValue + 1)});
}
}
Expand All @@ -338,26 +330,26 @@ public void DecideAction(Dictionary<Agent, AgentInfo> agentInfo)
throw new UnityAgentsException(string.Format(
@"One of the Tensorflow placeholder cound nout be found.
In brain {0}, there are no {1} placeholder named {2}.",
brain.gameObject.name, placeholder.valueType.ToString(), graphScope + placeholder.name));
brain.gameObject.name, placeholder.valueType.ToString(), placeholder.name));
}
}

// Create the state tensor
if (hasState)
{
runner.AddInput(graph[graphScope + VectorObservationPlacholderName][0], inputState);
runner.AddInput(graph[VectorObservationPlacholderName][0], inputState);
}

// Create the previous action tensor
if (hasPrevAction)
{
runner.AddInput(graph[graphScope + PreviousActionPlaceholderName][0], inputPrevAction);
runner.AddInput(graph[PreviousActionPlaceholderName][0], inputPrevAction);
}

// Create the mask action tensor
if (hasMaskedActions)
{
runner.AddInput(graph[graphScope + ActionMaskPlaceholderName][0], maskedActions);
runner.AddInput(graph[ActionMaskPlaceholderName][0], maskedActions);
}

// Create the observation tensors
Expand All @@ -366,20 +358,20 @@ public void DecideAction(Dictionary<Agent, AgentInfo> agentInfo)
obsNumber < brain.brainParameters.cameraResolutions.Length;
obsNumber++)
{
runner.AddInput(graph[graphScope + VisualObservationPlaceholderName[obsNumber]][0],
runner.AddInput(graph[VisualObservationPlaceholderName[obsNumber]][0],
observationMatrixList[obsNumber]);
}

if (hasRecurrent)
{
runner.AddInput(graph[graphScope + "sequence_length"][0], 1);
runner.AddInput(graph[graphScope + RecurrentInPlaceholderName][0], inputOldMemories);
runner.Fetch(graph[graphScope + RecurrentOutPlaceholderName][0]);
runner.AddInput(graph["sequence_length"][0], 1);
runner.AddInput(graph[RecurrentInPlaceholderName][0], inputOldMemories);
runner.Fetch(graph[RecurrentOutPlaceholderName][0]);
}

if (hasValueEstimate)
{
runner.Fetch(graph[graphScope + "value_estimate"][0]);
runner.Fetch(graph["value_estimate"][0]);
}

TFTensor[] networkOutput;
Expand Down Expand Up @@ -504,13 +496,6 @@ public void OnInspector()
{
EditorGUILayout.HelpBox("Please provide a tensorflow graph as a bytes file.", MessageType.Error);
}


graphScope =
EditorGUILayout.TextField(new GUIContent("Graph Scope",
"If you set a scope while training your tensorflow model, " +
"all your placeholder name will have a prefix. You must specify that prefix here."), graphScope);

if (BatchSizePlaceholderName == "")
{
BatchSizePlaceholderName = "batch_size";
Expand Down