Skip to content
2 changes: 1 addition & 1 deletion com.unity.ml-agents.extensions/LICENSE.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
com.unity.ml-agents.extensions copyright © 2020 Unity Technologies
com.unity.ml-agents.extensions copyright © 2020 Unity Technologies ApS
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Package validation was complaining about this.


Licensed under the Unity Companion License for Unity-dependent projects -- see
[Unity Companion License](http://www.unity3d.com/legal/licenses/Unity_Companion_License).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,20 @@

namespace Unity.MLAgents.Extensions.Sensors
{

/// <summary>
/// Utility class to track a hierarchy of ArticulationBodies.
/// </summary>
public class ArticulationBodyPoseExtractor : PoseExtractor
{
ArticulationBody[] m_Bodies;

public ArticulationBodyPoseExtractor(ArticulationBody rootBody)
{
if (rootBody == null)
{
return;
}

if (!rootBody.isRoot)
{
Debug.Log("Must pass ArticulationBody.isRoot");
Expand All @@ -38,23 +45,32 @@ public ArticulationBodyPoseExtractor(ArticulationBody rootBody)

for (var i = 1; i < numBodies; i++)
{
var body = m_Bodies[i];
var parent = body.GetComponentInParent<ArticulationBody>();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would always return itself. Fixed and added checks in the unit test.

parentIndices[i] = bodyToIndex[parent];
var currentArticBody = m_Bodies[i];
// Component.GetComponentInParent will consider the provided object as well.
// So start looking from the parent.
var currentGameObject = currentArticBody.gameObject;
var parentGameObject = currentGameObject.transform.parent;
var parentArticBody = parentGameObject.GetComponentInParent<ArticulationBody>();
parentIndices[i] = bodyToIndex[parentArticBody];
}

SetParentIndices(parentIndices);
}

/// <inheritdoc/>
protected override Vector3 GetLinearVelocityAt(int index)
{
return m_Bodies[index].velocity;
}

/// <inheritdoc/>
protected override Pose GetPoseAt(int index)
{
var body = m_Bodies[index];
var go = body.gameObject;
var t = go.transform;
return new Pose { rotation = t.rotation, position = t.position };
}


}
}
#endif // UNITY_2020_1_OR_NEWER
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#if UNITY_2020_1_OR_NEWER
using UnityEngine;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Sensors
{
public class ArticulationBodySensorComponent : SensorComponent
{
public ArticulationBody RootBody;

[SerializeField]
public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default();
public string sensorName;

/// <summary>
/// Creates a PhysicsBodySensor.
/// </summary>
/// <returns></returns>
public override ISensor CreateSensor()
{
return new PhysicsBodySensor(RootBody, Settings, sensorName);
}

/// <inheritdoc/>
public override int[] GetObservationShape()
{
if (RootBody == null)
{
return new[] { 0 };
}

// TODO static method in PhysicsBodySensor?
// TODO only update PoseExtractor when body changes?
var poseExtractor = new ArticulationBodyPoseExtractor(RootBody);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems a bit heavy handed, but it's not in a critical loop at the moment so maybe it's ok. looks like the TODOs would address this.

var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses);
return new[] { numTransformObservations };
}
}

}
#endif // UNITY_2020_1_OR_NEWER

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
using UnityEngine;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Sensors
{
/// <summary>
/// ISensor implementation that generates observations for a group of Rigidbodies or ArticulationBodies.
/// </summary>
public class PhysicsBodySensor : ISensor
{
int[] m_Shape;
string m_SensorName;

PoseExtractor m_PoseExtractor;
PhysicsSensorSettings m_Settings;

/// <summary>
/// Construct a new PhysicsBodySensor
/// </summary>
/// <param name="rootBody"></param>
/// <param name="settings"></param>
/// <param name="sensorName"></param>
public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsSensorSettings settings, string sensorName=null)
{
m_PoseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject);
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName;
m_Settings = settings;

var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
m_Shape = new[] { numTransformObservations };
}

#if UNITY_2020_1_OR_NEWER
public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName=null)
{
m_PoseExtractor = new ArticulationBodyPoseExtractor(rootBody);
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"ArticulationBodySensor:{rootBody?.name}" : sensorName;
m_Settings = settings;

var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
m_Shape = new[] { numTransformObservations };
}
#endif

/// <inheritdoc/>
public int[] GetObservationShape()
{
return m_Shape;
}

/// <inheritdoc/>
public int Write(ObservationWriter writer)
{
var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor);
return numWritten;
}

/// <inheritdoc/>
public byte[] GetCompressedObservation()
{
return null;
}

/// <inheritdoc/>
public void Update()
{
if (m_Settings.UseModelSpace)
{
m_PoseExtractor.UpdateModelSpacePoses();
}

if (m_Settings.UseLocalSpace)
{
m_PoseExtractor.UpdateLocalSpacePoses();
}
}

/// <inheritdoc/>
public void Reset() {}

/// <inheritdoc/>
public SensorCompressionType GetCompressionType()
{
return SensorCompressionType.None;
}

/// <inheritdoc/>
public string GetName()
{
return m_SensorName;
}
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

namespace Unity.MLAgents.Extensions.Sensors
{
/// <summary>
/// Settings that define the observations generated for physics-based sensors.
/// </summary>
[Serializable]
public struct PhysicsSensorSettings
{
Expand All @@ -13,7 +16,7 @@ public struct PhysicsSensorSettings
public bool UseModelSpaceTranslations;

/// <summary>
/// Whether to use model space (relative to the root body) rotatoins as observations.
/// Whether to use model space (relative to the root body) rotations as observations.
/// </summary>
public bool UseModelSpaceRotations;

Expand All @@ -27,6 +30,16 @@ public struct PhysicsSensorSettings
/// </summary>
public bool UseLocalSpaceRotations;

/// <summary>
/// Whether to use model space (relative to the root body) linear velocities as observations.
/// </summary>
public bool UseModelSpaceLinearVelocity;

/// <summary>
/// Whether to use local space (relative to the parent body) linear velocities as observations.
/// </summary>
public bool UseLocalSpaceLinearVelocity;

/// <summary>
/// Creates a PhysicsSensorSettings with reasonable default values.
/// </summary>
Expand All @@ -45,15 +58,15 @@ public static PhysicsSensorSettings Default()
/// </summary>
public bool UseModelSpace
{
get { return UseModelSpaceTranslations || UseModelSpaceRotations; }
get { return UseModelSpaceTranslations || UseModelSpaceRotations || UseModelSpaceLinearVelocity; }
}

/// <summary>
/// Whether any local space observations are being used.
/// </summary>
public bool UseLocalSpace
{
get { return UseLocalSpaceTranslations || UseLocalSpaceRotations; }
get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity; }
}


Expand All @@ -70,6 +83,9 @@ public int TransformSize(int numTransforms)
obsPerTransform += UseLocalSpaceTranslations ? 3 : 0;
obsPerTransform += UseLocalSpaceRotations ? 4 : 0;

obsPerTransform += UseModelSpaceLinearVelocity ? 3 : 0;
obsPerTransform += UseLocalSpaceLinearVelocity ? 3 : 0;

return numTransforms * obsPerTransform;
}
}
Expand All @@ -89,8 +105,12 @@ public static int WritePoses(this ObservationWriter writer, PhysicsSensorSetting
var offset = baseOffset;
if (settings.UseModelSpace)
{
foreach (var pose in poseExtractor.ModelSpacePoses)
var poses = poseExtractor.ModelSpacePoses;
var vels = poseExtractor.ModelSpaceVelocities;

for(var i=0; i<poseExtractor.NumPoses; i++)
{
var pose = poses[i];
if(settings.UseModelSpaceTranslations)
{
writer.Add(pose.position, offset);
Expand All @@ -101,13 +121,22 @@ public static int WritePoses(this ObservationWriter writer, PhysicsSensorSetting
writer.Add(pose.rotation, offset);
offset += 4;
}
if (settings.UseModelSpaceLinearVelocity)
{
writer.Add(vels[i], offset);
offset += 3;
}
}
}

if (settings.UseLocalSpace)
{
foreach (var pose in poseExtractor.LocalSpacePoses)
var poses = poseExtractor.LocalSpacePoses;
var vels = poseExtractor.LocalSpaceVelocities;

for(var i=0; i<poseExtractor.NumPoses; i++)
{
var pose = poses[i];
if(settings.UseLocalSpaceTranslations)
{
writer.Add(pose.position, offset);
Expand All @@ -118,6 +147,11 @@ public static int WritePoses(this ObservationWriter writer, PhysicsSensorSetting
writer.Add(pose.rotation, offset);
offset += 4;
}
if (settings.UseLocalSpaceLinearVelocity)
{
writer.Add(vels[i], offset);
offset += 3;
}
}
}

Expand Down
Loading