Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions com.unity.ml-agents.extensions/Runtime/RuntimeExample.cs

This file was deleted.

8 changes: 8 additions & 0 deletions com.unity.ml-agents.extensions/Runtime/Sensors.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#if UNITY_2020_1_OR_NEWER

using System.Collections.Generic;
using UnityEngine;

namespace Unity.MLAgents.Extensions.Sensors
{

public class ArticulationBodyPoseExtractor : PoseExtractor
{
ArticulationBody[] m_Bodies;

public ArticulationBodyPoseExtractor(ArticulationBody rootBody)
{
if (!rootBody.isRoot)
{
Debug.Log("Must pass ArticulationBody.isRoot");
return;
}

var bodies = rootBody.GetComponentsInChildren <ArticulationBody>();
if (bodies[0] != rootBody)
{
Debug.Log("Expected root body at index 0");
return;
}

var numBodies = bodies.Length;
m_Bodies = bodies;
int[] parentIndices = new int[numBodies];
parentIndices[0] = -1;

var bodyToIndex = new Dictionary<ArticulationBody, int>();
for (var i = 0; i < numBodies; i++)
{
bodyToIndex[m_Bodies[i]] = i;
}

for (var i = 1; i < numBodies; i++)
{
var body = m_Bodies[i];
var parent = body.GetComponentInParent<ArticulationBody>();
parentIndices[i] = bodyToIndex[parent];
}

SetParentIndices(parentIndices);
}

protected override Pose GetPoseAt(int index)
{
var body = m_Bodies[index];
var go = body.gameObject;
var t = go.transform;
return new Pose { rotation = t.rotation, position = t.position };
}


}
}
#endif // UNITY_2020_1_OR_NEWER

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
using System;

using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Sensors
{
[Serializable]
public struct PhysicsSensorSettings
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might get extended later to have different representation types for rotations, e.g. 3x3 rotation matrices, Euler angles.

{
/// <summary>
/// Whether to use model space (relative to the root body) translations as observations.
/// </summary>
public bool UseModelSpaceTranslations;

/// <summary>
/// Whether to use model space (relative to the root body) rotatoins as observations.
/// </summary>
public bool UseModelSpaceRotations;

/// <summary>
/// Whether to use local space (relative to the parent body) translations as observations.
/// </summary>
public bool UseLocalSpaceTranslations;

/// <summary>
/// Whether to use local space (relative to the parent body) translations as observations.
/// </summary>
public bool UseLocalSpaceRotations;

/// <summary>
/// Creates a PhysicsSensorSettings with reasonable default values.
/// </summary>
/// <returns></returns>
public static PhysicsSensorSettings Default()
{
return new PhysicsSensorSettings
{
UseModelSpaceTranslations = true,
UseModelSpaceRotations = true,
};
}

/// <summary>
/// Whether any model space observations are being used.
/// </summary>
public bool UseModelSpace
{
get { return UseModelSpaceTranslations || UseModelSpaceRotations; }
}

/// <summary>
/// Whether any local space observations are being used.
/// </summary>
public bool UseLocalSpace
{
get { return UseLocalSpaceTranslations || UseLocalSpaceRotations; }
}


/// <summary>
/// The number of floats needed to represent a given number of transforms.
/// </summary>
/// <param name="numTransforms"></param>
/// <returns></returns>
public int TransformSize(int numTransforms)
{
int obsPerTransform = 0;
obsPerTransform += UseModelSpaceTranslations ? 3 : 0;
obsPerTransform += UseModelSpaceRotations ? 4 : 0;
obsPerTransform += UseLocalSpaceTranslations ? 3 : 0;
obsPerTransform += UseLocalSpaceRotations ? 4 : 0;

return numTransforms * obsPerTransform;
}
}

internal static class ObservationWriterPhysicsExtensions
{
/// <summary>
/// Utility method for writing a PoseExtractor to an ObservationWriter.
/// </summary>
/// <param name="writer"></param>
/// <param name="settings"></param>
/// <param name="poseExtractor"></param>
/// <param name="baseOffset">The offset into the ObservationWriter to start writing at.</param>
/// <returns>The number of observations written.</returns>
public static int WritePoses(this ObservationWriter writer, PhysicsSensorSettings settings, PoseExtractor poseExtractor, int baseOffset = 0)
{
var offset = baseOffset;
if (settings.UseModelSpace)
{
foreach (var pose in poseExtractor.ModelSpacePoses)
{
if(settings.UseModelSpaceTranslations)
{
writer.Add(pose.position, offset);
offset += 3;
}
if (settings.UseModelSpaceRotations)
{
writer.Add(pose.rotation, offset);
offset += 4;
}
}
}

if (settings.UseLocalSpace)
{
foreach (var pose in poseExtractor.LocalSpacePoses)
{
if(settings.UseLocalSpaceTranslations)
{
writer.Add(pose.position, offset);
offset += 3;
}
if (settings.UseLocalSpaceRotations)
{
writer.Add(pose.rotation, offset);
offset += 4;
}
}
}

return offset - baseOffset;
}
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

170 changes: 170 additions & 0 deletions com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
using System.Collections.Generic;
using UnityEngine;

namespace Unity.MLAgents.Extensions.Sensors
{
/// <summary>
/// Abstract class for managing the transforms of a hierarchy of objects.
/// This could be GameObjects or Monobehaviours in the scene graph, but this is
/// not a requirement; for example, the objects could be rigid bodies whose hierarchy
/// is defined by Joint configurations.
///
/// Poses are either considered in model space, which is relative to a root body,
/// or in local space, which is relative to their parent.
/// </summary>
public abstract class PoseExtractor
{
int[] m_ParentIndices;
Pose[] m_ModelSpacePoses;
Pose[] m_LocalSpacePoses;

/// <summary>
/// Read access to the model space transforms.
/// </summary>
public IList<Pose> ModelSpacePoses
{
get { return m_ModelSpacePoses; }
}

/// <summary>
/// Read access to the local space transforms.
/// </summary>
public IList<Pose> LocalSpacePoses
{
get { return m_LocalSpacePoses; }
}

/// <summary>
/// Number of transforms in the hierarchy (read-only).
/// </summary>
public int NumPoses
{
get { return m_ModelSpacePoses?.Length ?? 0; }
}

/// <summary>
/// Initialize with the mapping of parent indices.
/// The 0th element is assumed to be -1, indicating that it's the root.
/// </summary>
/// <param name="parentIndices"></param>
protected void SetParentIndices(int[] parentIndices)
{
m_ParentIndices = parentIndices;
var numTransforms = parentIndices.Length;
m_ModelSpacePoses = new Pose[numTransforms];
m_LocalSpacePoses = new Pose[numTransforms];
}

/// <summary>
/// Return the world space Pose of the i'th object.
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
protected abstract Pose GetPoseAt(int index);

/// <summary>
/// Update the internal model space transform storage based on the underlying system.
/// </summary>
public void UpdateModelSpacePoses()
{
if (m_ModelSpacePoses == null)
{
return;
}

var worldTransform = GetPoseAt(0);
var worldToModel = worldTransform.Inverse();

for (var i = 0; i < m_ModelSpacePoses.Length; i++)
{
var currentTransform = GetPoseAt(i);
m_ModelSpacePoses[i] = worldToModel.Multiply(currentTransform);
}
}

/// <summary>
/// Update the internal model space transform storage based on the underlying system.
/// </summary>
public void UpdateLocalSpacePoses()
{
if (m_LocalSpacePoses == null)
{
return;
}

for (var i = 0; i < m_LocalSpacePoses.Length; i++)
{
if (m_ParentIndices[i] != -1)
{
var parentTransform = GetPoseAt(m_ParentIndices[i]);
// This is slightly inefficient, since for a body with multiple children, we'll end up inverting
// the transform multiple times. Might be able to trade space for perf here.
var invParent = parentTransform.Inverse();
var currentTransform = GetPoseAt(i);
m_LocalSpacePoses[i] = invParent.Multiply(currentTransform);
}
else
{
m_LocalSpacePoses[i] = Pose.identity;
}
}
}


public void DrawModelSpace(Vector3 offset)
{
UpdateLocalSpacePoses();
UpdateModelSpacePoses();

var pose = m_ModelSpacePoses;
var localPose = m_LocalSpacePoses;
for (var i = 0; i < pose.Length; i++)
{
var current = pose[i];
if (m_ParentIndices[i] == -1)
{
continue;
}

var parent = pose[m_ParentIndices[i]];
Debug.DrawLine(current.position + offset, parent.position + offset, Color.cyan);
var localUp = localPose[i].rotation * Vector3.up;
var localFwd = localPose[i].rotation * Vector3.forward;
var localRight = localPose[i].rotation * Vector3.right;
Debug.DrawLine(current.position+offset, current.position+offset+.1f*localUp, Color.red);
Debug.DrawLine(current.position+offset, current.position+offset+.1f*localFwd, Color.green);
Debug.DrawLine(current.position+offset, current.position+offset+.1f*localRight, Color.blue);
}
}
}

public static class PoseExtensions
{
/// <summary>
/// Compute the inverse of a Pose. For any Pose P,
/// P.Inverse() * P
/// will equal the identity pose (within tolerance).
/// </summary>
/// <param name="pose"></param>
/// <returns></returns>
public static Pose Inverse(this Pose pose)
{
var rotationInverse = Quaternion.Inverse(pose.rotation);
var translationInverse = -(rotationInverse * pose.position);
return new Pose { rotation = rotationInverse, position = translationInverse };
}

/// <summary>
/// This is equivalent to Pose.GetTransformedBy(), but keeps the order more intuitive.
/// </summary>
/// <param name="pose"></param>
/// <param name="rhs"></param>
/// <returns></returns>
public static Pose Multiply(this Pose pose, Pose rhs)
{
return rhs.GetTransformedBy(pose);
}

// TODO optimize inv(A)*B?
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading