Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,17 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat
observationProto.CompressedChannelMapping.AddRange(compressibleSensor.GetCompressedChannelMapping());
}
}
// Add the dimension properties if any to the observationProto
var dimensionPropertySensor = sensor as IDimensionPropertiesSensor;
if (dimensionPropertySensor != null)
{
var dimensionProperties = dimensionPropertySensor.GetDimensionProperties();
int[] intDimensionProperties = new int[dimensionProperties.Length];
for (int i = 0; i < dimensionProperties.Length; i++)
{
observationProto.DimensionProperties.Add((int)dimensionProperties[i]);
}
}
observationProto.Shape.AddRange(shape);
return observationProto;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,20 @@ static ObservationReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjRtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0",
"aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyKdAgoQT2JzZXJ2YXRp",
"aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyK7AgoQT2JzZXJ2YXRp",
"b25Qcm90bxINCgVzaGFwZRgBIAMoBRJEChBjb21wcmVzc2lvbl90eXBlGAIg",
"ASgOMiouY29tbXVuaWNhdG9yX29iamVjdHMuQ29tcHJlc3Npb25UeXBlUHJv",
"dG8SGQoPY29tcHJlc3NlZF9kYXRhGAMgASgMSAASRgoKZmxvYXRfZGF0YRgE",
"IAEoCzIwLmNvbW11bmljYXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8u",
"RmxvYXREYXRhSAASIgoaY29tcHJlc3NlZF9jaGFubmVsX21hcHBpbmcYBSAD",
"KAUaGQoJRmxvYXREYXRhEgwKBGRhdGEYASADKAJCEgoQb2JzZXJ2YXRpb25f",
"ZGF0YSopChRDb21wcmVzc2lvblR5cGVQcm90bxIICgROT05FEAASBwoDUE5H",
"EAFCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnBy",
"b3RvMw=="));
"KAUSHAoUZGltZW5zaW9uX3Byb3BlcnRpZXMYBiADKAUaGQoJRmxvYXREYXRh",
"EgwKBGRhdGEYASADKAJCEgoQb2JzZXJ2YXRpb25fZGF0YSopChRDb21wcmVz",
"c2lvblR5cGVQcm90bxIICgROT05FEAASBwoDUE5HEAFCJaoCIlVuaXR5Lk1M",
"QWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto), }, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData", "CompressedChannelMapping" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)})
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData", "CompressedChannelMapping", "DimensionProperties" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)})
}));
}
#endregion
Expand Down Expand Up @@ -81,6 +81,7 @@ public ObservationProto(ObservationProto other) : this() {
shape_ = other.shape_.Clone();
compressionType_ = other.compressionType_;
compressedChannelMapping_ = other.compressedChannelMapping_.Clone();
dimensionProperties_ = other.dimensionProperties_.Clone();
switch (other.ObservationDataCase) {
case ObservationDataOneofCase.CompressedData:
CompressedData = other.CompressedData;
Expand Down Expand Up @@ -151,6 +152,16 @@ public ObservationProto Clone() {
get { return compressedChannelMapping_; }
}

/// <summary>Field number for the "dimension_properties" field.</summary>
public const int DimensionPropertiesFieldNumber = 6;
private static readonly pb::FieldCodec<int> _repeated_dimensionProperties_codec
= pb::FieldCodec.ForInt32(50);
private readonly pbc::RepeatedField<int> dimensionProperties_ = new pbc::RepeatedField<int>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<int> DimensionProperties {
get { return dimensionProperties_; }
}

private object observationData_;
/// <summary>Enum of possible cases for the "observation_data" oneof.</summary>
public enum ObservationDataOneofCase {
Expand Down Expand Up @@ -188,6 +199,7 @@ public bool Equals(ObservationProto other) {
if (CompressedData != other.CompressedData) return false;
if (!object.Equals(FloatData, other.FloatData)) return false;
if(!compressedChannelMapping_.Equals(other.compressedChannelMapping_)) return false;
if(!dimensionProperties_.Equals(other.dimensionProperties_)) return false;
if (ObservationDataCase != other.ObservationDataCase) return false;
return Equals(_unknownFields, other._unknownFields);
}
Expand All @@ -200,6 +212,7 @@ public override int GetHashCode() {
if (observationDataCase_ == ObservationDataOneofCase.CompressedData) hash ^= CompressedData.GetHashCode();
if (observationDataCase_ == ObservationDataOneofCase.FloatData) hash ^= FloatData.GetHashCode();
hash ^= compressedChannelMapping_.GetHashCode();
hash ^= dimensionProperties_.GetHashCode();
hash ^= (int) observationDataCase_;
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
Expand Down Expand Up @@ -228,6 +241,7 @@ public void WriteTo(pb::CodedOutputStream output) {
output.WriteMessage(FloatData);
}
compressedChannelMapping_.WriteTo(output, _repeated_compressedChannelMapping_codec);
dimensionProperties_.WriteTo(output, _repeated_dimensionProperties_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
Expand All @@ -247,6 +261,7 @@ public int CalculateSize() {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(FloatData);
}
size += compressedChannelMapping_.CalculateSize(_repeated_compressedChannelMapping_codec);
size += dimensionProperties_.CalculateSize(_repeated_dimensionProperties_codec);
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
Expand All @@ -263,6 +278,7 @@ public void MergeFrom(ObservationProto other) {
CompressionType = other.CompressionType;
}
compressedChannelMapping_.Add(other.compressedChannelMapping_);
dimensionProperties_.Add(other.dimensionProperties_);
switch (other.ObservationDataCase) {
case ObservationDataOneofCase.CompressedData:
CompressedData = other.CompressedData;
Expand Down Expand Up @@ -313,6 +329,11 @@ public void MergeFrom(pb::CodedInputStream input) {
compressedChannelMapping_.AddEntriesFrom(input, _repeated_compressedChannelMapping_codec);
break;
}
case 50:
case 48: {
dimensionProperties_.AddEntriesFrom(input, _repeated_dimensionProperties_codec);
break;
}
}
}
}
Expand Down
95 changes: 95 additions & 0 deletions com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
using System;

namespace Unity.MLAgents.Sensors
{
public class BufferSensor : ISensor, IDimensionPropertiesSensor
{
private int m_MaxNumObs;
private int m_ObsSize;
float[] m_ObservationBuffer;
int m_CurrentNumObservables;
public BufferSensor(int maxNumberObs, int obsSize)
{
m_MaxNumObs = maxNumberObs;
m_ObsSize = obsSize;
m_ObservationBuffer = new float[m_ObsSize * m_MaxNumObs];
m_CurrentNumObservables = 0;
}

/// <inheritdoc/>
public int[] GetObservationShape()
{
return new int[] { m_MaxNumObs, m_ObsSize };
}

/// <inheritdoc/>
public DimensionProperty[] GetDimensionProperties()
{
return new DimensionProperty[]{
DimensionProperty.VariableSize,
DimensionProperty.None
};
}

/// <summary>
/// Appends an observation to the buffer. If the buffer is full (maximum number
/// of observation is reached) the observation will be ignored. the length of
/// the provided observation array must be equal to the observation size of
/// the buffer sensor.
/// </summary>
/// <param name="obs"> The float array observation</param>
public void AppendObservation(float[] obs)
{
if (m_CurrentNumObservables >= m_MaxNumObs)
{
return;
}
for (int i = 0; i < obs.Length; i++)
{
m_ObservationBuffer[m_CurrentNumObservables * m_ObsSize + i] = obs[i];
}
m_CurrentNumObservables++;
}

/// <inheritdoc/>
public int Write(ObservationWriter writer)
{
for (int i = 0; i < m_ObsSize * m_MaxNumObs; i++)
{
writer[i] = m_ObservationBuffer[i];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this is implicitly zero padded, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the assumption is that if all the floats in an obs are 0, then the observation is padded. This padding will be ignored by the trainers, but still fed as input during inference / training.
We can pad with another value, but I chose 0.

}
return m_ObsSize * m_MaxNumObs;
}

/// <inheritdoc/>
public virtual byte[] GetCompressedObservation()
{
return null;
}

/// <inheritdoc/>
public void Update()
{
Reset();
}

/// <inheritdoc/>
public void Reset()
{
m_CurrentNumObservables = 0;
Array.Clear(m_ObservationBuffer, 0, m_ObservationBuffer.Length);
}

public SensorCompressionType GetCompressionType()
{
return SensorCompressionType.None;
}

public string GetName()
{
return "BufferSensor";
}

}

}
11 changes: 11 additions & 0 deletions com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 41 additions & 0 deletions com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using UnityEngine;

namespace Unity.MLAgents.Sensors
{

/// <summary>
/// A component for BufferSensor.
/// </summary>
[AddComponentMenu("ML Agents/Buffer Sensor", (int)MenuGroup.Sensors)]
public class BufferSensorComponent : SensorComponent
{
public int ObservableSize;
public int MaxNumObservables;
private BufferSensor m_Sensor;

/// <inheritdoc/>
public override ISensor CreateSensor()
{
m_Sensor = new BufferSensor(MaxNumObservables, ObservableSize);
return m_Sensor;
}

/// <inheritdoc/>
public override int[] GetObservationShape()
{
return new[] { MaxNumObservables, ObservableSize };
}

/// <summary>
/// Appends an observation to the buffer. If the buffer is full (maximum number
/// of observation is reached) the observation will be ignored. the length of
/// the provided observation array must be equal to the observation size of
/// the buffer sensor.
/// </summary>
/// <param name="obs"> The float array observation</param>
public void AppendObservation(float[] obs)
{
m_Sensor.AppendObservation(obs);
}
}
}
11 changes: 11 additions & 0 deletions com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

47 changes: 47 additions & 0 deletions com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
namespace Unity.MLAgents.Sensors
{

/// <summary>
/// The Dimension property flags of the observations
/// </summary>
[System.Flags]
public enum DimensionProperty
{
/// <summary>
/// No properties specified.
/// </summary>
Unspecified = 0,

/// <summary>
/// No Property of the observation in that dimension. Observation can be processed with
/// fully connected networks.
/// </summary>
None = 1,

/// <summary>
/// Means it is suitable to do a convolution in this dimension.
/// </summary>
TranslationalEquivariance = 2,

/// <summary>
/// Means that there can be a variable number of observations in this dimension.
/// The observations are unordered.
/// </summary>
VariableSize = 4,
}


/// <summary>
/// Sensor interface for sensors with special dimension properties.
/// </summary>
public interface IDimensionPropertiesSensor
{
/// <summary>
/// Returns the array containing the properties of each dimensions of the
/// observation. The length of the array must be equal to the rank of the
/// observation tensor.
/// </summary>
/// <returns>The array of DimensionProperty</returns>
DimensionProperty[] GetDimensionProperties();
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions docs/Python-API.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,12 @@ A `TerminalStep` has the following fields:

A `BehaviorSpec` has the following fields :

- `observation_shapes` is a List of Tuples of int : Each Tuple corresponds to an
observation's dimensions (without the number of agents dimension). The shape
tuples have the same ordering as the ordering of the DecisionSteps,
- `sensor_specs` is a List of `SensorSpec` objects : Each `SensorSpec`
corresponds to an observation's properties: `shape` is a tuple of ints that
corresponds to the shape of the observation (without the number of agents dimension).
`dimension_property` is a tuple of flags containing extra information about how the
data should be processed in the corresponding dimension. Note that the `SensorSpec`
have the same ordering as the ordering of observations in the DecisionSteps,
DecisionStep, TerminalSteps and TerminalStep.
- `action_spec` is an `ActionSpec` namedtuple that defines the number and types
of actions for the Agent.
Expand Down
16 changes: 8 additions & 8 deletions gym-unity/gym_unity/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,16 +229,16 @@ def _preprocess_single(self, single_visual_obs: np.ndarray) -> np.ndarray:

def _get_n_vis_obs(self) -> int:
result = 0
for shape in self.group_spec.observation_shapes:
if len(shape) == 3:
for sen_spec in self.group_spec.sensor_specs:
if len(sen_spec.shape) == 3:
result += 1
return result

def _get_vis_obs_shape(self) -> List[Tuple]:
result: List[Tuple] = []
for shape in self.group_spec.observation_shapes:
if len(shape) == 3:
result.append(shape)
for sen_spec in self.group_spec.sensor_specs:
if len(sen_spec.shape) == 3:
result.append(sen_spec.shape)
return result

def _get_vis_obs_list(
Expand All @@ -261,9 +261,9 @@ def _get_vector_obs(

def _get_vec_obs_size(self) -> int:
result = 0
for shape in self.group_spec.observation_shapes:
if len(shape) == 1:
result += shape[0]
for sen_spec in self.group_spec.sensor_specs:
if len(sen_spec.shape) == 1:
result += sen_spec.shape[0]
return result

def render(self, mode="rgb_array"):
Expand Down
Loading