Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
e9b9af6
make actionSpec not read-only
Nov 30, 2020
4840d09
add actionSpec in BrainParameters and update BrainParameters Drawer
Nov 30, 2020
d7080be
Integrate BrainParameters with ActionSpec, deprecate some fields
Nov 30, 2020
a1c81e9
add serialization callbacks
Nov 30, 2020
178e095
add function to set ActionSpec
Dec 1, 2020
accf0d5
update BrainParameters when changed in editor
Dec 1, 2020
05a3c6b
fix tests
Dec 1, 2020
3e3ebaa
enable hybrid ParameterLoaderTest
Dec 1, 2020
cc227fd
fix tests
Dec 1, 2020
351a479
fix demonstration drawer
Dec 1, 2020
6864686
add comments
Dec 1, 2020
397fcf1
rename method
Dec 3, 2020
0bf75e3
check both continuous/discrete output shape regardless of model outpu…
Dec 8, 2020
91d4125
simplify actionspec. put property names in const string
Dec 9, 2020
def4b0c
fix merge
Dec 9, 2020
8ed6ac6
make SetContinuous/Discrete() internal
Dec 9, 2020
2ae9473
fix prev action check
Dec 9, 2020
d53a9fd
fix tests
Dec 9, 2020
ddde653
rename VectorActionSpec to ActionSpec
Dec 9, 2020
b42a740
remove unused import
Dec 9, 2020
f9336eb
fix bad merge
Dec 9, 2020
2d0f9da
pass multiple args instead of new array to makeDiscrete()
Dec 9, 2020
ca778e3
change back public fields' name
Dec 9, 2020
8e9977b
change back public fields' name
Dec 9, 2020
a5b7302
fix clone brainParameter
Dec 10, 2020
ede61b2
remove setContinuous and setDiscrete
Dec 10, 2020
4f32b56
Merge branch 'master' into develop-hybrid-brainparameters2
Dec 10, 2020
59285bd
update changelog
Dec 10, 2020
b24c013
fix internal field serialization
Dec 10, 2020
51587de
revert breaking change
Dec 11, 2020
b6c62d8
sync deprecated fields
Dec 11, 2020
cd28e02
fix BrainParameterToProto
Dec 11, 2020
b1d04b5
change back VectorActuator suffix
Dec 11, 2020
4219e5c
add back ActionSpec
Dec 11, 2020
aa3d799
set deprecated to null
Dec 11, 2020
b64d85a
replace with SyncDeprecatedActionFields
Dec 11, 2020
d7c8298
update changelog
Dec 12, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,11 @@ and this project adheres to
### Major Changes
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- PyTorch trainers now support training agents with both continuous and discrete action spaces.
Currently, this can only be done with Actuators. Please see
[here](../Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicActuatorComponent.cs) for an
example of how to use Actuators. (#4702)

- PyTorch trainers now support training agents with both continuous and discrete action spaces. (#4702)
### Minor Changes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
- Agent with both continuous and discrete actions is now supported. You can specify
continuous and discrete action sizes repectively in Behavior Parameters. (#4702, #4718)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- `ActionSpec.validate_action()` now enforces that `UnityEnvironment.set_action_for_agent()` receives a 1D `np.array`.

Expand Down
29 changes: 20 additions & 9 deletions com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,49 +21,60 @@ internal class BehaviorParametersEditor : UnityEditor.Editor
float m_TimeSinceModelReload;
// Whether or not the model needs to be reloaded
bool m_RequireReload;
const string k_BehaviorName = "m_BehaviorName";
const string k_BrainParametersName = "m_BrainParameters";
const string k_ModelName = "m_Model";
const string k_InferenceDeviceName = "m_InferenceDevice";
const string k_BehaviorTypeName = "m_BehaviorType";
const string k_TeamIdName = "TeamId";
const string k_UseChildSensorsName = "m_UseChildSensors";
const string k_ObservableAttributeHandlingName = "m_ObservableAttributeHandling";

public override void OnInspectorGUI()
{
var so = serializedObject;
so.Update();
bool needPolicyUpdate; // Whether the name, model, inference device, or BehaviorType changed.
bool needBrainParametersUpdate; // Whether the brain parameters changed

// Drawing the Behavior Parameters
EditorGUI.indentLevel++;
EditorGUI.BeginChangeCheck(); // global

EditorGUI.BeginChangeCheck();
{
EditorGUILayout.PropertyField(so.FindProperty("m_BehaviorName"));
EditorGUILayout.PropertyField(so.FindProperty(k_BehaviorName));
}
needPolicyUpdate = EditorGUI.EndChangeCheck();

EditorGUI.BeginChangeCheck();
EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
{
EditorGUILayout.PropertyField(so.FindProperty("m_BrainParameters"), true);
EditorGUILayout.PropertyField(so.FindProperty(k_BrainParametersName), true);
}
EditorGUI.EndDisabledGroup();
needBrainParametersUpdate = EditorGUI.EndChangeCheck();

EditorGUI.BeginChangeCheck();
{
EditorGUILayout.PropertyField(so.FindProperty("m_Model"), true);
EditorGUILayout.PropertyField(so.FindProperty(k_ModelName), true);
EditorGUI.indentLevel++;
EditorGUILayout.PropertyField(so.FindProperty("m_InferenceDevice"), true);
EditorGUILayout.PropertyField(so.FindProperty(k_InferenceDeviceName), true);
EditorGUI.indentLevel--;
}
needPolicyUpdate = needPolicyUpdate || EditorGUI.EndChangeCheck();

EditorGUI.BeginChangeCheck();
{
EditorGUILayout.PropertyField(so.FindProperty("m_BehaviorType"));
EditorGUILayout.PropertyField(so.FindProperty(k_BehaviorTypeName));
}
needPolicyUpdate = needPolicyUpdate || EditorGUI.EndChangeCheck();

EditorGUILayout.PropertyField(so.FindProperty("TeamId"));
EditorGUILayout.PropertyField(so.FindProperty(k_TeamIdName));
EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
{
EditorGUILayout.PropertyField(so.FindProperty("m_UseChildSensors"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservableAttributeHandling"), true);
EditorGUILayout.PropertyField(so.FindProperty(k_UseChildSensorsName), true);
EditorGUILayout.PropertyField(so.FindProperty(k_ObservableAttributeHandlingName), true);
}
EditorGUI.EndDisabledGroup();

Expand Down Expand Up @@ -91,7 +102,7 @@ void DisplayFailedModelChecks()
// Display all failed checks
D.logEnabled = false;
Model barracudaModel = null;
var model = (NNModel)serializedObject.FindProperty("m_Model").objectReferenceValue;
var model = (NNModel)serializedObject.FindProperty(k_ModelName).objectReferenceValue;
var behaviorParameters = (BehaviorParameters)target;

// Grab the sensor components, since we need them to determine the observation sizes.
Expand Down
58 changes: 17 additions & 41 deletions com.unity.ml-agents/Editor/BrainParametersDrawer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ internal class BrainParametersDrawer : PropertyDrawer
// The height of a line in the Unity Inspectors
const float k_LineHeight = 17f;
const int k_VecObsNumLine = 3;
const string k_ActionSizePropName = "VectorActionSize";
const string k_ActionTypePropName = "VectorActionSpaceType";
const string k_ActionSpecName = "m_ActionSpec";
const string k_ContinuousActionSizeName = "m_NumContinuousActions";
const string k_DiscreteBranchSizeName = "BranchSizes";
const string k_ActionDescriptionPropName = "VectorActionDescriptions";
const string k_VecObsPropName = "VectorObservationSize";
const string k_NumVecObsPropName = "NumStackedVectorObservations";
Expand Down Expand Up @@ -97,22 +98,10 @@ static void DrawVectorAction(Rect position, SerializedProperty property)
EditorGUI.LabelField(position, "Vector Action");
position.y += k_LineHeight;
EditorGUI.indentLevel++;
var bpVectorActionType = property.FindPropertyRelative(k_ActionTypePropName);
EditorGUI.PropertyField(
position,
bpVectorActionType,
new GUIContent("Space Type",
"Corresponds to whether state vector contains a single integer (Discrete) " +
"or a series of real-valued floats (Continuous)."));
var actionSpecProperty = property.FindPropertyRelative(k_ActionSpecName);
DrawContinuousVectorAction(position, actionSpecProperty);
position.y += k_LineHeight;
if (bpVectorActionType.enumValueIndex == 1)
{
DrawContinuousVectorAction(position, property);
}
else
{
DrawDiscreteVectorAction(position, property);
}
DrawDiscreteVectorAction(position, actionSpecProperty);
}

/// <summary>
Expand All @@ -123,21 +112,11 @@ static void DrawVectorAction(Rect position, SerializedProperty property)
/// to make the custom GUI for.</param>
static void DrawContinuousVectorAction(Rect position, SerializedProperty property)
{
var vecActionSize = property.FindPropertyRelative(k_ActionSizePropName);

// This check is here due to:
// https://fogbugz.unity3d.com/f/cases/1246524/
// If this case has been resolved, please remove this if condition.
if (vecActionSize.arraySize != 1)
{
vecActionSize.arraySize = 1;
}
var continuousActionSize =
vecActionSize.GetArrayElementAtIndex(0);
var continuousActionSize = property.FindPropertyRelative(k_ContinuousActionSizeName);
EditorGUI.PropertyField(
position,
continuousActionSize,
new GUIContent("Space Size", "Length of continuous action vector."));
new GUIContent("Continuous Action Size", "Length of continuous action vector."));
}

/// <summary>
Expand All @@ -148,27 +127,27 @@ static void DrawContinuousVectorAction(Rect position, SerializedProperty propert
/// to make the custom GUI for.</param>
static void DrawDiscreteVectorAction(Rect position, SerializedProperty property)
{
var vecActionSize = property.FindPropertyRelative(k_ActionSizePropName);
var branchSizes = property.FindPropertyRelative(k_DiscreteBranchSizeName);
var newSize = EditorGUI.IntField(
position, "Branches Size", vecActionSize.arraySize);
position, "Discrete Branch Size", branchSizes.arraySize);

// This check is here due to:
// https://fogbugz.unity3d.com/f/cases/1246524/
// If this case has been resolved, please remove this if condition.
if (newSize != vecActionSize.arraySize)
if (newSize != branchSizes.arraySize)
{
vecActionSize.arraySize = newSize;
branchSizes.arraySize = newSize;
}

position.y += k_LineHeight;
position.x += 20;
position.width -= 20;
for (var branchIndex = 0;
branchIndex < vecActionSize.arraySize;
branchIndex < branchSizes.arraySize;
branchIndex++)
{
var branchActionSize =
vecActionSize.GetArrayElementAtIndex(branchIndex);
branchSizes.GetArrayElementAtIndex(branchIndex);

EditorGUI.PropertyField(
position,
Expand All @@ -185,12 +164,9 @@ static void DrawDiscreteVectorAction(Rect position, SerializedProperty property)
/// <returns>The height of the drawer of the Vector Action.</returns>
static float GetHeightDrawVectorAction(SerializedProperty property)
{
var actionSize = 2 + property.FindPropertyRelative(k_ActionSizePropName).arraySize;
if (property.FindPropertyRelative(k_ActionTypePropName).enumValueIndex == 0)
{
actionSize += 1;
}
return actionSize * k_LineHeight;
var actionSpecProperty = property.FindPropertyRelative(k_ActionSpecName);
var numActionLines = 3 + actionSpecProperty.FindPropertyRelative(k_DiscreteBranchSizeName).arraySize;
return numActionLines * k_LineHeight;
}
}
}
46 changes: 29 additions & 17 deletions com.unity.ml-agents/Editor/DemonstrationDrawer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
using System.Text;
using UnityEditor;
using Unity.MLAgents.Demonstrations;
using Unity.MLAgents.Policies;


namespace Unity.MLAgents.Editor
Expand All @@ -17,23 +16,35 @@ internal class DemonstrationEditor : UnityEditor.Editor
SerializedProperty m_BrainParameters;
SerializedProperty m_DemoMetaData;
SerializedProperty m_ObservationShapes;
const string k_BrainParametersName = "brainParameters";
const string k_MetaDataName = "metaData";
const string k_ObservationSummariesName = "observationSummaries";
const string k_DemonstrationName = "demonstrationName";
const string k_NumberStepsName = "numberSteps";
const string k_NumberEpisodesName = "numberEpisodes";
const string k_MeanRewardName = "meanReward";
const string k_ActionSpecName = "ActionSpec";
const string k_NumContinuousActionsName = "m_NumContinuousActions";
const string k_NumDiscreteActionsName = "m_NumDiscreteActions";
const string k_ShapeName = "shape";


void OnEnable()
{
m_BrainParameters = serializedObject.FindProperty("brainParameters");
m_DemoMetaData = serializedObject.FindProperty("metaData");
m_ObservationShapes = serializedObject.FindProperty("observationSummaries");
m_BrainParameters = serializedObject.FindProperty(k_BrainParametersName);
m_DemoMetaData = serializedObject.FindProperty(k_MetaDataName);
m_ObservationShapes = serializedObject.FindProperty(k_ObservationSummariesName);
}

/// <summary>
/// Renders Inspector UI for Demonstration metadata.
/// </summary>
void MakeMetaDataProperty(SerializedProperty property)
{
var nameProp = property.FindPropertyRelative("demonstrationName");
var experiencesProp = property.FindPropertyRelative("numberSteps");
var episodesProp = property.FindPropertyRelative("numberEpisodes");
var rewardsProp = property.FindPropertyRelative("meanReward");
var nameProp = property.FindPropertyRelative(k_DemonstrationName);
var experiencesProp = property.FindPropertyRelative(k_NumberStepsName);
var episodesProp = property.FindPropertyRelative(k_NumberEpisodesName);
var rewardsProp = property.FindPropertyRelative(k_MeanRewardName);

var nameLabel = nameProp.displayName + ": " + nameProp.stringValue;
var experiencesLabel = experiencesProp.displayName + ": " + experiencesProp.intValue;
Expand Down Expand Up @@ -72,16 +83,17 @@ static string BuildIntArrayLabel(SerializedProperty actionSizeProperty)
/// </summary>
void MakeActionsProperty(SerializedProperty property)
{
var actSizeProperty = property.FindPropertyRelative("VectorActionSize");
var actSpaceTypeProp = property.FindPropertyRelative("VectorActionSpaceType");
var actSpecProperty = property.FindPropertyRelative(k_ActionSpecName);
var continuousSizeProperty = actSpecProperty.FindPropertyRelative(k_NumContinuousActionsName);
var discreteSizeProperty = actSpecProperty.FindPropertyRelative(k_NumDiscreteActionsName);

var vecActSizeLabel =
actSizeProperty.displayName + ": " + BuildIntArrayLabel(actSizeProperty);
var actSpaceTypeLabel = actSpaceTypeProp.displayName + ": " +
(SpaceType)actSpaceTypeProp.enumValueIndex;
var continuousSizeLabel =
continuousSizeProperty.displayName + ": " + continuousSizeProperty.intValue;
var discreteSizeLabel = discreteSizeProperty.displayName + ": " +
discreteSizeProperty.intValue;

EditorGUILayout.LabelField(vecActSizeLabel);
EditorGUILayout.LabelField(actSpaceTypeLabel);
EditorGUILayout.LabelField(continuousSizeLabel);
EditorGUILayout.LabelField(discreteSizeLabel);
}

/// <summary>
Expand All @@ -95,7 +107,7 @@ void MakeObservationsProperty(SerializedProperty obsSummariesProperty)
for (var i = 0; i < numObservations; i++)
{
var summary = obsSummariesProperty.GetArrayElementAtIndex(i);
var shapeProperty = summary.FindPropertyRelative("shape");
var shapeProperty = summary.FindPropertyRelative(k_ShapeName);
shapesLabels.Add(BuildIntArrayLabel(shapeProperty));
}

Expand Down
26 changes: 14 additions & 12 deletions com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents.Policies;
using UnityEngine;

namespace Unity.MLAgents.Actuators
{
/// <summary>
/// Defines the structure of an Action Space to be used by the Actuator system.
/// </summary>
public readonly struct ActionSpec
[Serializable]
public struct ActionSpec
{
[SerializeField]
int m_NumContinuousActions;

/// <summary>
/// An array of branch sizes for our action space.
///
Expand All @@ -20,23 +24,23 @@ public readonly struct ActionSpec
///
/// For an IActuator with a Continuous it will be null.
/// </summary>
public readonly int[] BranchSizes;
public int[] BranchSizes;

/// <summary>
/// The number of actions for a Continuous <see cref="SpaceType"/>.
/// </summary>
public int NumContinuousActions { get; }
public int NumContinuousActions { get { return m_NumContinuousActions; } set { m_NumContinuousActions = value; } }

/// <summary>
/// The number of branches for a Discrete <see cref="SpaceType"/>.
/// </summary>
public int NumDiscreteActions { get; }
public int NumDiscreteActions { get { return BranchSizes == null ? 0 : BranchSizes.Length; } }

/// <summary>
/// Get the total number of Discrete Actions that can be taken by calculating the Sum
/// of all of the Discrete Action branch sizes.
/// </summary>
public int SumOfDiscreteBranchSizes { get; }
public int SumOfDiscreteBranchSizes { get { return BranchSizes == null ? 0 : BranchSizes.Sum(); } }

/// <summary>
/// Creates a Continuous <see cref="ActionSpec"/> with the number of actions available.
Expand All @@ -45,7 +49,7 @@ public readonly struct ActionSpec
/// <returns>An Continuous ActionSpec initialized with the number of actions available.</returns>
public static ActionSpec MakeContinuous(int numActions)
{
var actuatorSpace = new ActionSpec(numActions, 0);
var actuatorSpace = new ActionSpec(numActions, null);
return actuatorSpace;
}

Expand All @@ -59,16 +63,14 @@ public static ActionSpec MakeContinuous(int numActions)
public static ActionSpec MakeDiscrete(params int[] branchSizes)
{
var numActions = branchSizes.Length;
var actuatorSpace = new ActionSpec(0, numActions, branchSizes);
var actuatorSpace = new ActionSpec(0, branchSizes);
return actuatorSpace;
}

internal ActionSpec(int numContinuousActions, int numDiscreteActions, int[] branchSizes = null)
internal ActionSpec(int numContinuousActions, int[] branchSizes = null)
{
NumContinuousActions = numContinuousActions;
NumDiscreteActions = numDiscreteActions;
m_NumContinuousActions = numContinuousActions;
BranchSizes = branchSizes;
SumOfDiscreteBranchSizes = branchSizes?.Sum() ?? 0;
}

/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ internal static ActionSpec CombineActionSpecs(IList<IActuator> actuators)
}
}

return new ActionSpec(numContinuousActions, numDiscreteActions, combinedBranchSizes);
return new ActionSpec(numContinuousActions, combinedBranchSizes);
}

/// <summary>
Expand Down
Loading