diff --git a/com.unity.ml-agents.extensions/Runtime/RuntimeExample.cs b/com.unity.ml-agents.extensions/Runtime/RuntimeExample.cs deleted file mode 100644 index d12535d5b0..0000000000 --- a/com.unity.ml-agents.extensions/Runtime/RuntimeExample.cs +++ /dev/null @@ -1,6 +0,0 @@ -using Unity.MLAgents; - -namespace Unity.MLAgents.Extensions -{ - -} \ No newline at end of file diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors.meta b/com.unity.ml-agents.extensions/Runtime/Sensors.meta new file mode 100644 index 0000000000..8a56d01593 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 2a66e31170bb04777b9ade862995a624 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs new file mode 100644 index 0000000000..3f5068ed92 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs @@ -0,0 +1,60 @@ +#if UNITY_2020_1_OR_NEWER + +using System.Collections.Generic; +using UnityEngine; + +namespace Unity.MLAgents.Extensions.Sensors +{ + + public class ArticulationBodyPoseExtractor : PoseExtractor + { + ArticulationBody[] m_Bodies; + + public ArticulationBodyPoseExtractor(ArticulationBody rootBody) + { + if (!rootBody.isRoot) + { + Debug.Log("Must pass ArticulationBody.isRoot"); + return; + } + + var bodies = rootBody.GetComponentsInChildren (); + if (bodies[0] != rootBody) + { + Debug.Log("Expected root body at index 0"); + return; + } + + var numBodies = bodies.Length; + m_Bodies = bodies; + int[] parentIndices = new int[numBodies]; + parentIndices[0] = -1; + + var bodyToIndex = new Dictionary(); + for (var i = 0; i < numBodies; i++) + { + bodyToIndex[m_Bodies[i]] = i; + } + + for (var i = 1; i < numBodies; i++) + { + var body = m_Bodies[i]; + var parent = body.GetComponentInParent(); + parentIndices[i] = bodyToIndex[parent]; + } + + SetParentIndices(parentIndices); + } + + protected override Pose GetPoseAt(int index) + { + var body = m_Bodies[index]; + var go = body.gameObject; + var t = go.transform; + return new Pose { rotation = t.rotation, position = t.position }; + } + + + } +} +#endif // UNITY_2020_1_OR_NEWER \ No newline at end of file diff --git a/com.unity.ml-agents.extensions/Runtime/RuntimeExample.cs.meta b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs.meta similarity index 83% rename from com.unity.ml-agents.extensions/Runtime/RuntimeExample.cs.meta rename to com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs.meta index a517bd42bc..41f7baad44 100644 --- a/com.unity.ml-agents.extensions/Runtime/RuntimeExample.cs.meta +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs.meta @@ -1,5 +1,5 @@ fileFormatVersion: 2 -guid: 2763a568e3b8541e08b90cb15e442281 +guid: 11fe037a02b4a483cb9342c3454232cd MonoImporter: externalObjects: {} serializedVersion: 2 diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs new file mode 100644 index 0000000000..3a920eb8cd --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs @@ -0,0 +1,127 @@ +using System; + +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Extensions.Sensors +{ + [Serializable] + public struct PhysicsSensorSettings + { + /// + /// Whether to use model space (relative to the root body) translations as observations. + /// + public bool UseModelSpaceTranslations; + + /// + /// Whether to use model space (relative to the root body) rotatoins as observations. + /// + public bool UseModelSpaceRotations; + + /// + /// Whether to use local space (relative to the parent body) translations as observations. + /// + public bool UseLocalSpaceTranslations; + + /// + /// Whether to use local space (relative to the parent body) translations as observations. + /// + public bool UseLocalSpaceRotations; + + /// + /// Creates a PhysicsSensorSettings with reasonable default values. + /// + /// + public static PhysicsSensorSettings Default() + { + return new PhysicsSensorSettings + { + UseModelSpaceTranslations = true, + UseModelSpaceRotations = true, + }; + } + + /// + /// Whether any model space observations are being used. + /// + public bool UseModelSpace + { + get { return UseModelSpaceTranslations || UseModelSpaceRotations; } + } + + /// + /// Whether any local space observations are being used. + /// + public bool UseLocalSpace + { + get { return UseLocalSpaceTranslations || UseLocalSpaceRotations; } + } + + + /// + /// The number of floats needed to represent a given number of transforms. + /// + /// + /// + public int TransformSize(int numTransforms) + { + int obsPerTransform = 0; + obsPerTransform += UseModelSpaceTranslations ? 3 : 0; + obsPerTransform += UseModelSpaceRotations ? 4 : 0; + obsPerTransform += UseLocalSpaceTranslations ? 3 : 0; + obsPerTransform += UseLocalSpaceRotations ? 4 : 0; + + return numTransforms * obsPerTransform; + } + } + + internal static class ObservationWriterPhysicsExtensions + { + /// + /// Utility method for writing a PoseExtractor to an ObservationWriter. + /// + /// + /// + /// + /// The offset into the ObservationWriter to start writing at. + /// The number of observations written. + public static int WritePoses(this ObservationWriter writer, PhysicsSensorSettings settings, PoseExtractor poseExtractor, int baseOffset = 0) + { + var offset = baseOffset; + if (settings.UseModelSpace) + { + foreach (var pose in poseExtractor.ModelSpacePoses) + { + if(settings.UseModelSpaceTranslations) + { + writer.Add(pose.position, offset); + offset += 3; + } + if (settings.UseModelSpaceRotations) + { + writer.Add(pose.rotation, offset); + offset += 4; + } + } + } + + if (settings.UseLocalSpace) + { + foreach (var pose in poseExtractor.LocalSpacePoses) + { + if(settings.UseLocalSpaceTranslations) + { + writer.Add(pose.position, offset); + offset += 3; + } + if (settings.UseLocalSpaceRotations) + { + writer.Add(pose.rotation, offset); + offset += 4; + } + } + } + + return offset - baseOffset; + } + } +} diff --git a/com.unity.ml-agents.extensions/Tests/Editor/EditorExampleTest.cs.meta b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs.meta similarity index 83% rename from com.unity.ml-agents.extensions/Tests/Editor/EditorExampleTest.cs.meta rename to com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs.meta index fafd878624..04f26d1a8a 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/EditorExampleTest.cs.meta +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs.meta @@ -1,5 +1,5 @@ fileFormatVersion: 2 -guid: 8dadec7ee45de484bac4cab4efcca17d +guid: fcb7a51f0d5f8404db7b85bd35ecc1fb MonoImporter: externalObjects: {} serializedVersion: 2 diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs new file mode 100644 index 0000000000..d2493ffbc3 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs @@ -0,0 +1,170 @@ +using System.Collections.Generic; +using UnityEngine; + +namespace Unity.MLAgents.Extensions.Sensors +{ + /// + /// Abstract class for managing the transforms of a hierarchy of objects. + /// This could be GameObjects or Monobehaviours in the scene graph, but this is + /// not a requirement; for example, the objects could be rigid bodies whose hierarchy + /// is defined by Joint configurations. + /// + /// Poses are either considered in model space, which is relative to a root body, + /// or in local space, which is relative to their parent. + /// + public abstract class PoseExtractor + { + int[] m_ParentIndices; + Pose[] m_ModelSpacePoses; + Pose[] m_LocalSpacePoses; + + /// + /// Read access to the model space transforms. + /// + public IList ModelSpacePoses + { + get { return m_ModelSpacePoses; } + } + + /// + /// Read access to the local space transforms. + /// + public IList LocalSpacePoses + { + get { return m_LocalSpacePoses; } + } + + /// + /// Number of transforms in the hierarchy (read-only). + /// + public int NumPoses + { + get { return m_ModelSpacePoses?.Length ?? 0; } + } + + /// + /// Initialize with the mapping of parent indices. + /// The 0th element is assumed to be -1, indicating that it's the root. + /// + /// + protected void SetParentIndices(int[] parentIndices) + { + m_ParentIndices = parentIndices; + var numTransforms = parentIndices.Length; + m_ModelSpacePoses = new Pose[numTransforms]; + m_LocalSpacePoses = new Pose[numTransforms]; + } + + /// + /// Return the world space Pose of the i'th object. + /// + /// + /// + protected abstract Pose GetPoseAt(int index); + + /// + /// Update the internal model space transform storage based on the underlying system. + /// + public void UpdateModelSpacePoses() + { + if (m_ModelSpacePoses == null) + { + return; + } + + var worldTransform = GetPoseAt(0); + var worldToModel = worldTransform.Inverse(); + + for (var i = 0; i < m_ModelSpacePoses.Length; i++) + { + var currentTransform = GetPoseAt(i); + m_ModelSpacePoses[i] = worldToModel.Multiply(currentTransform); + } + } + + /// + /// Update the internal model space transform storage based on the underlying system. + /// + public void UpdateLocalSpacePoses() + { + if (m_LocalSpacePoses == null) + { + return; + } + + for (var i = 0; i < m_LocalSpacePoses.Length; i++) + { + if (m_ParentIndices[i] != -1) + { + var parentTransform = GetPoseAt(m_ParentIndices[i]); + // This is slightly inefficient, since for a body with multiple children, we'll end up inverting + // the transform multiple times. Might be able to trade space for perf here. + var invParent = parentTransform.Inverse(); + var currentTransform = GetPoseAt(i); + m_LocalSpacePoses[i] = invParent.Multiply(currentTransform); + } + else + { + m_LocalSpacePoses[i] = Pose.identity; + } + } + } + + + public void DrawModelSpace(Vector3 offset) + { + UpdateLocalSpacePoses(); + UpdateModelSpacePoses(); + + var pose = m_ModelSpacePoses; + var localPose = m_LocalSpacePoses; + for (var i = 0; i < pose.Length; i++) + { + var current = pose[i]; + if (m_ParentIndices[i] == -1) + { + continue; + } + + var parent = pose[m_ParentIndices[i]]; + Debug.DrawLine(current.position + offset, parent.position + offset, Color.cyan); + var localUp = localPose[i].rotation * Vector3.up; + var localFwd = localPose[i].rotation * Vector3.forward; + var localRight = localPose[i].rotation * Vector3.right; + Debug.DrawLine(current.position+offset, current.position+offset+.1f*localUp, Color.red); + Debug.DrawLine(current.position+offset, current.position+offset+.1f*localFwd, Color.green); + Debug.DrawLine(current.position+offset, current.position+offset+.1f*localRight, Color.blue); + } + } + } + + public static class PoseExtensions + { + /// + /// Compute the inverse of a Pose. For any Pose P, + /// P.Inverse() * P + /// will equal the identity pose (within tolerance). + /// + /// + /// + public static Pose Inverse(this Pose pose) + { + var rotationInverse = Quaternion.Inverse(pose.rotation); + var translationInverse = -(rotationInverse * pose.position); + return new Pose { rotation = rotationInverse, position = translationInverse }; + } + + /// + /// This is equivalent to Pose.GetTransformedBy(), but keeps the order more intuitive. + /// + /// + /// + /// + public static Pose Multiply(this Pose pose, Pose rhs) + { + return rhs.GetTransformedBy(pose); + } + + // TODO optimize inv(A)*B? + } +} diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs.meta b/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs.meta new file mode 100644 index 0000000000..5a16a709be --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: e0f1e5147a394f428f9a6447c4a8a1f4 +timeCreated: 1591919825 \ No newline at end of file diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs new file mode 100644 index 0000000000..176f21b2da --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs @@ -0,0 +1,70 @@ +using System.Collections.Generic; +using UnityEngine; + +namespace Unity.MLAgents.Extensions.Sensors +{ + + /// + /// Utility class to track a hierarchy of RigidBodies. These are assumed to have a root node, + /// and child nodes are connect to their parents via Joints. + /// + public class RigidBodyPoseExtractor : PoseExtractor + { + Rigidbody[] m_Bodies; + + /// + /// Initialize given a root RigidBody. + /// + /// + public RigidBodyPoseExtractor(Rigidbody rootBody) + { + if (rootBody == null) + { + return; + } + var rbs = rootBody.GetComponentsInChildren (); + var bodyToIndex = new Dictionary(rbs.Length); + var parentIndices = new int[rbs.Length]; + + if (rbs[0] != rootBody) + { + Debug.Log("Expected root body at index 0"); + return; + } + + for (var i = 0; i < rbs.Length; i++) + { + bodyToIndex[rbs[i]] = i; + } + + var joints = rootBody.GetComponentsInChildren (); + + + foreach (var j in joints) + { + var parent = j.connectedBody; + var child = j.GetComponent(); + + var parentIndex = bodyToIndex[parent]; + var childIndex = bodyToIndex[child]; + parentIndices[childIndex] = parentIndex; + } + + m_Bodies = rbs; + SetParentIndices(parentIndices); + } + + /// + /// Get the pose of the i'th RigidBody. + /// + /// + /// + protected override Pose GetPoseAt(int index) + { + var body = m_Bodies[index]; + return new Pose { rotation = body.rotation, position = body.position }; + } + + + } +} diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs.meta b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs.meta new file mode 100644 index 0000000000..8418e2f146 --- /dev/null +++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 867cab4a07f244518bae3e6fdda14416 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents.extensions/Tests/Editor/EditorExampleTest.cs b/com.unity.ml-agents.extensions/Tests/Editor/EditorExampleTest.cs deleted file mode 100644 index 118258ce0c..0000000000 --- a/com.unity.ml-agents.extensions/Tests/Editor/EditorExampleTest.cs +++ /dev/null @@ -1,20 +0,0 @@ -using UnityEngine; -using UnityEditor; -using UnityEngine.TestTools; -using NUnit.Framework; -using System.Collections; - -namespace Unity.MLAgents.Extensions.Tests -{ - - internal class EditorExampleTest { - - [Test] - public void EditorTestMath() - { - Assert.AreEqual(2, 1 + 1); - } - - } - -} diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors.meta b/com.unity.ml-agents.extensions/Tests/Editor/Sensors.meta new file mode 100644 index 0000000000..cb33686b6b --- /dev/null +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: b7ffdca5cd8064ee6831175d7ffd3f0f +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs new file mode 100644 index 0000000000..b140e1f607 --- /dev/null +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs @@ -0,0 +1,122 @@ +using UnityEngine; +using NUnit.Framework; +using Unity.MLAgents.Extensions.Sensors; + +namespace Unity.MLAgents.Extensions.Tests.Sensors +{ + public class PoseExtractorTests + { + class UselessPoseExtractor : PoseExtractor + { + protected override Pose GetPoseAt(int index) + { + return Pose.identity; + } + + public void Init(int[] parentIndices) + { + SetParentIndices(parentIndices); + } + } + + [Test] + public void TestEmptyExtractor() + { + var poseExtractor = new UselessPoseExtractor(); + + // These should be no-ops + poseExtractor.UpdateLocalSpacePoses(); + poseExtractor.UpdateModelSpacePoses(); + + Assert.AreEqual(0, poseExtractor.NumPoses); + } + + [Test] + public void TestSimpleExtractor() + { + var poseExtractor = new UselessPoseExtractor(); + var parentIndices = new[] { -1, 0 }; + poseExtractor.Init(parentIndices); + Assert.AreEqual(2, poseExtractor.NumPoses); + } + + + /// + /// A simple "chain" hierarchy, where each object is parented to the one before it. + /// 0 <- 1 <- 2 <- ... + /// + class ChainPoseExtractor : PoseExtractor + { + public Vector3 offset; + public ChainPoseExtractor(int size) + { + var parents = new int[size]; + for (var i = 0; i < size; i++) + { + parents[i] = i - 1; + } + SetParentIndices(parents); + } + + protected override Pose GetPoseAt(int index) + { + var rotation = Quaternion.identity; + var translation = offset + new Vector3(index, index, index); + return new Pose + { + rotation = rotation, + position = translation + }; + } + } + + [Test] + public void TestChain() + { + var size = 4; + var chain = new ChainPoseExtractor(size); + chain.offset = new Vector3(.5f, .75f, .333f); + + chain.UpdateModelSpacePoses(); + chain.UpdateLocalSpacePoses(); + + // Root transforms are currently always the identity. + Assert.IsTrue(chain.ModelSpacePoses[0] == Pose.identity); + Assert.IsTrue(chain.LocalSpacePoses[0] == Pose.identity); + + // Check the non-root transforms + for (var i = 1; i < size; i++) + { + var modelSpace = chain.ModelSpacePoses[i]; + var expectedModelTranslation = new Vector3(i, i, i); + Assert.IsTrue(expectedModelTranslation == modelSpace.position); + + var localSpace = chain.LocalSpacePoses[i]; + var expectedLocalTranslation = new Vector3(1, 1, 1); + Assert.IsTrue(expectedLocalTranslation == localSpace.position); + } + } + + } + + public class PoseExtensionTests + { + [Test] + public void TestInverse() + { + Pose t = new Pose + { + rotation = Quaternion.AngleAxis(23.0f, new Vector3(1, 1, 1).normalized), + position = new Vector3(-1.0f, 2.0f, 3.0f) + }; + + var inverseT = t.Inverse(); + var product = inverseT.Multiply(t); + Assert.IsTrue(Vector3.zero == product.position); + Assert.IsTrue(Quaternion.identity == product.rotation); + + Assert.IsTrue(Pose.identity == product); + } + + } +} diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs.meta b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs.meta new file mode 100644 index 0000000000..732dda565c --- /dev/null +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 9b0deb29b2f5d4b03a7f75a516943c81 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs new file mode 100644 index 0000000000..079e7dac23 --- /dev/null +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs @@ -0,0 +1,62 @@ +using System.Collections.Generic; +using UnityEngine; +using NUnit.Framework; +using Unity.MLAgents.Extensions.Sensors; + +namespace Unity.MLAgents.Extensions.Tests.Sensors +{ + public class RigidBodyPoseExtractorTests + { + [TearDown] + public void RemoveGameObjects() + { + var objects = GameObject.FindObjectsOfType(); + foreach (var o in objects) + { + UnityEngine.Object.DestroyImmediate(o); + } + } + + [Test] + public void TestNullRoot() + { + var poseExtractor = new RigidBodyPoseExtractor(null); + // These should be no-ops + poseExtractor.UpdateLocalSpacePoses(); + poseExtractor.UpdateModelSpacePoses(); + + Assert.AreEqual(0, poseExtractor.NumPoses); + } + + [Test] + public void TestSingleBody() + { + var go = new GameObject(); + var rootRb = go.AddComponent(); + var poseExtractor = new RigidBodyPoseExtractor(rootRb); + Assert.AreEqual(1, poseExtractor.NumPoses); + } + + [Test] + public void TestTwoBodies() + { + // * rootObj + // - rb1 + // * go2 + // - rb2 + // - joint + var rootObj = new GameObject(); + var rb1 = rootObj.AddComponent(); + + var go2 = new GameObject(); + var rb2 = go2.AddComponent(); + go2.transform.SetParent(rootObj.transform); + + var joint = go2.AddComponent(); + joint.connectedBody = rb1; + + var poseExtractor = new RigidBodyPoseExtractor(rb1); + Assert.AreEqual(2, poseExtractor.NumPoses); + } + } +} diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs.meta b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs.meta new file mode 100644 index 0000000000..70d23b254c --- /dev/null +++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 23f57b248aaf940a3962cef68c6d83f5 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: