Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@ public class HallwayCollabAgent : HallwayAgent

[HideInInspector]
public int selection = 0;

public override void Initialize()
{
base.Initialize();
if (isSpotter)
{
var teamManager = new HallwayTeamManager();
SetTeamManager(teamManager);
teammate.SetTeamManager(teamManager);
}
}
public override void OnEpisodeBegin()
{
m_Message = -1;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Extensions.Teams;
using Unity.MLAgents.Sensors;

public class HallwayTeamManager : BaseTeamManager
{
List<Agent> m_AgentList = new List<Agent> { };


public override void RegisterAgent(Agent agent)
{
m_AgentList.Add(agent);
}

public override void OnAgentDone(Agent agent, Agent.DoneReason doneReason, List<ISensor> sensors)
{
agent.SendDoneToTrainer();
}

public override void AddTeamReward(float reward)
{

}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions com.unity.ml-agents.extensions/Runtime/Teams.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 35 additions & 0 deletions com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Teams
{
public class BaseTeamManager : ITeamManager
{
readonly string m_Id = System.Guid.NewGuid().ToString();
Copy link
Contributor

@chriselion chriselion Jan 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could also make this an integer - have a global static ID and increment with Interlocked.Increment (which is threadsafe)


public virtual void RegisterAgent(Agent agent)
{
throw new System.NotImplementedException();
}

public virtual void OnAgentDone(Agent agent, Agent.DoneReason doneReason, List<ISensor> sensors)
{
// Possible implementation - save reference to Agent's IPolicy so that we can repeatedly
// call IPolicy.RequestDecision on behalf of the Agent after it's dead
// If so, we'll need dummy sensor impls with the same shape as the originals.
throw new System.NotImplementedException();
}

public virtual void AddTeamReward(float reward)
{

}

public string GetId()
{
return m_Id;
}

}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ internal class BehaviorParametersEditor : UnityEditor.Editor
const string k_InferenceDeviceName = "m_InferenceDevice";
const string k_BehaviorTypeName = "m_BehaviorType";
const string k_TeamIdName = "TeamId";
const string k_GroupIdName = "GroupId";
const string k_UseChildSensorsName = "m_UseChildSensors";
const string k_ObservableAttributeHandlingName = "m_ObservableAttributeHandling";

Expand Down Expand Up @@ -68,7 +67,6 @@ public override void OnInspectorGUI()
}
needPolicyUpdate = needPolicyUpdate || EditorGUI.EndChangeCheck();

EditorGUILayout.PropertyField(so.FindProperty(k_GroupIdName));
EditorGUILayout.PropertyField(so.FindProperty(k_TeamIdName));
EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
{
Expand Down
42 changes: 38 additions & 4 deletions com.unity.ml-agents/Runtime/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ internal struct AgentInfo
/// </summary>
public int episodeId;

/// <summary>
/// Team Manager identifier.
/// </summary>
public string teamManagerId;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be public. I think a better approrach is to add an accessor for m_TeamManager, and add a method to the interface for an ID.


public void ClearActions()
{
storedActions.Clear();
Expand Down Expand Up @@ -312,6 +317,8 @@ internal struct AgentParameters
/// </summary>
float[] m_LegacyActionCache;

private ITeamManager m_TeamManager;

/// <summary>
/// Called when the attached [GameObject] becomes enabled and active.
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html
Expand Down Expand Up @@ -443,6 +450,11 @@ public void LazyInitialize()
new int[m_ActuatorManager.NumDiscreteActions]
);

if (m_TeamManager != null)
{
m_Info.teamManagerId = m_TeamManager.GetId();
}

// The first time the Academy resets, all Agents in the scene will be
// forced to reset through the <see cref="AgentForceReset"/> event.
// To avoid the Agent resetting twice, the Agents will not begin their
Expand All @@ -459,7 +471,7 @@ public void LazyInitialize()
/// <summary>
/// The reason that the Agent has been set to "done".
/// </summary>
enum DoneReason
public enum DoneReason
{
/// <summary>
/// The episode was ended manually by calling <see cref="EndEpisode"/>.
Expand Down Expand Up @@ -535,9 +547,17 @@ void NotifyAgentDone(DoneReason doneReason)
}
}
// Request the last decision with no callbacks
// We request a decision so Python knows the Agent is done immediately
m_Brain?.RequestDecision(m_Info, sensors);
ResetSensors();
if (m_TeamManager != null)
{
// Send final observations to TeamManager if it exists.
// The TeamManager is responsible to keeping track of the Agent after it's
// done, including propagating any "posthumous" rewards.
m_TeamManager.OnAgentDone(this, doneReason, sensors);
}
else
{
SendDoneToTrainer();
}

// We also have to write any to any DemonstationStores so that they get the "done" flag.
foreach (var demoWriter in DemonstrationWriters)
Expand All @@ -560,6 +580,13 @@ void NotifyAgentDone(DoneReason doneReason)
m_Info.storedActions.Clear();
}

public void SendDoneToTrainer()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

internal

{
// We request a decision so Python knows the Agent is done immediately
m_Brain?.RequestDecision(m_Info, sensors);
ResetSensors();
}

/// <summary>
/// Updates the Model assigned to this Agent instance.
/// </summary>
Expand Down Expand Up @@ -1344,5 +1371,12 @@ void DecideAction()
m_Info.CopyActions(actions);
m_ActuatorManager.UpdateActions(actions);
}

public void SetTeamManager(ITeamManager teamManager)
{
m_TeamManager = teamManager;
m_Info.teamManagerId = teamManager?.GetId();
teamManager?.RegisterAgent(this);
}
}
}
5 changes: 5 additions & 0 deletions com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai)
agentInfoProto.ActionMask.AddRange(ai.discreteActionMasks);
}

if (ai.teamManagerId != null)
{
agentInfoProto.TeamManagerId = ai.teamManagerId;
}

return agentInfoProto;
}

Expand Down
39 changes: 34 additions & 5 deletions com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,18 @@ static AgentInfoReflection() {
string.Concat(
"CjNtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu",
"Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjRtbGFnZW50c19lbnZz",
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvItEBCg5B",
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIuoBCg5B",
"Z2VudEluZm9Qcm90bxIOCgZyZXdhcmQYByABKAISDAoEZG9uZRgIIAEoCBIY",
"ChBtYXhfc3RlcF9yZWFjaGVkGAkgASgIEgoKAmlkGAogASgFEhMKC2FjdGlv",
"bl9tYXNrGAsgAygIEjwKDG9ic2VydmF0aW9ucxgNIAMoCzImLmNvbW11bmlj",
"YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG9KBAgBEAJKBAgCEANKBAgD",
"EARKBAgEEAVKBAgFEAZKBAgGEAdKBAgMEA1CJaoCIlVuaXR5Lk1MQWdlbnRz",
"LkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
"YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8SFwoPdGVhbV9tYW5hZ2Vy",
"X2lkGA4gASgJSgQIARACSgQIAhADSgQIAxAESgQIBBAFSgQIBRAGSgQIBhAH",
"SgQIDBANQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3Rz",
"YgZwcm90bzM="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { global::Unity.MLAgents.CommunicatorObjects.ObservationReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto), global::Unity.MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations", "TeamManagerId" }, null, null, null)
}));
}
#endregion
Expand Down Expand Up @@ -74,6 +75,7 @@ public AgentInfoProto(AgentInfoProto other) : this() {
id_ = other.id_;
actionMask_ = other.actionMask_.Clone();
observations_ = other.observations_.Clone();
teamManagerId_ = other.teamManagerId_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

Expand Down Expand Up @@ -146,6 +148,17 @@ public int Id {
get { return observations_; }
}

/// <summary>Field number for the "team_manager_id" field.</summary>
public const int TeamManagerIdFieldNumber = 14;
private string teamManagerId_ = "";
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string TeamManagerId {
get { return teamManagerId_; }
set {
teamManagerId_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as AgentInfoProto);
Expand All @@ -165,6 +178,7 @@ public bool Equals(AgentInfoProto other) {
if (Id != other.Id) return false;
if(!actionMask_.Equals(other.actionMask_)) return false;
if(!observations_.Equals(other.observations_)) return false;
if (TeamManagerId != other.TeamManagerId) return false;
return Equals(_unknownFields, other._unknownFields);
}

Expand All @@ -177,6 +191,7 @@ public override int GetHashCode() {
if (Id != 0) hash ^= Id.GetHashCode();
hash ^= actionMask_.GetHashCode();
hash ^= observations_.GetHashCode();
if (TeamManagerId.Length != 0) hash ^= TeamManagerId.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
Expand Down Expand Up @@ -208,6 +223,10 @@ public void WriteTo(pb::CodedOutputStream output) {
}
actionMask_.WriteTo(output, _repeated_actionMask_codec);
observations_.WriteTo(output, _repeated_observations_codec);
if (TeamManagerId.Length != 0) {
output.WriteRawTag(114);
output.WriteString(TeamManagerId);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
Expand All @@ -230,6 +249,9 @@ public int CalculateSize() {
}
size += actionMask_.CalculateSize(_repeated_actionMask_codec);
size += observations_.CalculateSize(_repeated_observations_codec);
if (TeamManagerId.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(TeamManagerId);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
Expand All @@ -255,6 +277,9 @@ public void MergeFrom(AgentInfoProto other) {
}
actionMask_.Add(other.actionMask_);
observations_.Add(other.observations_);
if (other.TeamManagerId.Length != 0) {
TeamManagerId = other.TeamManagerId;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

Expand Down Expand Up @@ -291,6 +316,10 @@ public void MergeFrom(pb::CodedInputStream input) {
observations_.AddEntriesFrom(input, _repeated_observations_codec);
break;
}
case 114: {
TeamManagerId = input.ReadString();
break;
}
}
}
}
Expand Down
14 changes: 14 additions & 0 deletions com.unity.ml-agents/Runtime/ITeamManager.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using System.Collections.Generic;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents
{
public interface ITeamManager
{
string GetId();

void RegisterAgent(Agent agent);
// TODO not sure this is all the info we need, maybe pass a class/struct instead.
void OnAgentDone(Agent agent, Agent.DoneReason doneReason, List<ISensor> sensors);
}
}
3 changes: 3 additions & 0 deletions com.unity.ml-agents/Runtime/ITeamManager.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 1 addition & 7 deletions com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,6 @@ public string BehaviorName
[HideInInspector, SerializeField, FormerlySerializedAs("m_TeamID")]
public int TeamId;

/// <summary>
/// The group ID for this behavior.
/// </summary>
[HideInInspector, SerializeField]
[Tooltip("Assign the same Group ID to all Agents in the same Area.")]
public int GroupId;
// TODO properties here instead of Agent

[FormerlySerializedAs("m_useChildSensors")]
Expand Down Expand Up @@ -200,7 +194,7 @@ public ObservableAttributeOptions ObservableAttributeHandling
/// </summary>
public string FullyQualifiedBehaviorName
{
get { return m_BehaviorName + "?team=" + TeamId + "&group=" + GroupId; }
get { return m_BehaviorName + "?team=" + TeamId; }
}

internal IPolicy GeneratePolicy(ActionSpec actionSpec, HeuristicPolicy.ActionGenerator heuristic)
Expand Down
Loading