Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ and this project adheres to

### Bug Fixes
#### com.unity.ml-agents (C#)
- `Agent.CollectObservations()` and `Agent.EndEpisode()` will now throw an exception
if they are called recursively (for example, if they call `Agent.EndEpisode()`).
Previously, this would result in an infinite loop and cause the editor to hang. (#4573)
#### ml-agents / ml-agents-envs / gym-unity (Python)


Expand Down
27 changes: 3 additions & 24 deletions com.unity.ml-agents/Runtime/Academy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,8 @@ public bool IsCommunicatorOn
// Flag used to keep track of the first time the Academy is reset.
bool m_HadFirstReset;

// Whether the Academy is in the middle of a step. This is used to detect an Academy
// step called by user code that is also called by the Academy.
bool m_IsStepping;
// Detect an Academy step called by user code that is also called by the Academy.
private RecursionChecker m_StepRecursionChecker = new RecursionChecker("EnvironmentStep");

// Random seed used for inference.
int m_InferenceSeed;
Expand Down Expand Up @@ -535,22 +534,7 @@ void ForcedFullReset()
/// </summary>
public void EnvironmentStep()
{
// Check whether we're already in the middle of a step.
// This shouldn't happen generally, but could happen if user code (e.g. CollectObservations)
// that is called by EnvironmentStep() also calls EnvironmentStep(). This would result
// in an infinite loop and/or stack overflow, so stop it before it happens.
if (m_IsStepping)
{
throw new UnityAgentsException(
"Academy.EnvironmentStep() called recursively. " +
"This might happen if you call EnvironmentStep() from custom code such as " +
"CollectObservations() or OnActionReceived()."
);
}

m_IsStepping = true;

try
using (m_StepRecursionChecker.Start())
{
if (!m_HadFirstReset)
{
Expand Down Expand Up @@ -584,11 +568,6 @@ public void EnvironmentStep()
AgentAct?.Invoke();
}
}
finally
{
// Reset m_IsStepping when we're done (or if an exception occurred).
m_IsStepping = false;
}
}

/// <summary>
Expand Down
24 changes: 20 additions & 4 deletions com.unity.ml-agents/Runtime/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,9 @@ internal struct AgentParameters
/// </summary>
internal VectorSensor collectObservationsSensor;

private RecursionChecker m_CollectObservationsChecker = new RecursionChecker("CollectObservations");
private RecursionChecker m_OnEpisodeBeginChecker = new RecursionChecker("OnEpisodeBegin");

/// <summary>
/// List of IActuators that this Agent will delegate actions to if any exist.
/// </summary>
Expand Down Expand Up @@ -435,7 +438,10 @@ public void LazyInitialize()
// episode when initializing until after the Academy had its first reset.
if (Academy.Instance.TotalStepCount != 0)
{
OnEpisodeBegin();
using (m_OnEpisodeBeginChecker.Start())
{
OnEpisodeBegin();
}
}
}

Expand Down Expand Up @@ -512,7 +518,10 @@ void NotifyAgentDone(DoneReason doneReason)
{
// Make sure the latest observations are being passed to training.
collectObservationsSensor.Reset();
CollectObservations(collectObservationsSensor);
using (m_CollectObservationsChecker.Start())
{
CollectObservations(collectObservationsSensor);
}
}
// Request the last decision with no callbacks
// We request a decision so Python knows the Agent is done immediately
Expand Down Expand Up @@ -1006,7 +1015,10 @@ void SendInfoToBrain()
UpdateSensors();
using (TimerStack.Instance.Scoped("CollectObservations"))
{
CollectObservations(collectObservationsSensor);
using (m_CollectObservationsChecker.Start())
{
CollectObservations(collectObservationsSensor);
}
}
using (TimerStack.Instance.Scoped("CollectDiscreteActionMasks"))
{
Expand Down Expand Up @@ -1229,7 +1241,11 @@ void _AgentReset()
{
ResetData();
m_StepCount = 0;
OnEpisodeBegin();
using (m_OnEpisodeBeginChecker.Start())
{
OnEpisodeBegin();
}

}

/// <summary>
Expand Down
35 changes: 35 additions & 0 deletions com.unity.ml-agents/Runtime/RecursionChecker.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using System;

namespace Unity.MLAgents
{
internal class RecursionChecker : IDisposable
{
private bool m_IsRunning;
private string m_MethodName;

/// <summary>
/// Creates a checker for a single guarded method.
/// </summary>
/// <param name="methodName">
/// Name of the guarded method; it is used as the prefix of the exception message
/// produced when re-entry is detected.
/// </param>
public RecursionChecker(string methodName) => m_MethodName = methodName;

public IDisposable Start()
{
if (m_IsRunning)
{
throw new UnityAgentsException(
$"{m_MethodName} called recursively. " +
"This might happen if you call EnvironmentStep() or EndEpisode() from custom " +
"code such as CollectObservations() or OnActionReceived()."
Comment on lines +20 to +22
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be more general if, instead of passing methodName as an argument, we passed the whole error message. The constructor call would look like:

new RecursionChecker("CollectObservations() was called recursively. Make sure there are ...");

This would make the error messages clearer and the code more future-proof, but I do not have a strong opinion on this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really know how to make the error messages more clear and still catch all possible permutations that users could set up. I'm afraid if we say something like "don't call EndEpisode from CollectObservations" and someone manages to hit CollectObservations recursively some other way, it will be more confusing.

);
}
m_IsRunning = true;
return this;
}

/// <summary>
/// Clears the "running" flag so the guarded method may be entered again.
/// Because callers wrap the guarded call in a <c>using</c> block, this runs both on
/// normal completion and when an exception escapes the block.
/// </summary>
public void Dispose() => m_IsRunning = false;
}
}
3 changes: 3 additions & 0 deletions com.unity.ml-agents/Runtime/RecursionChecker.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

56 changes: 56 additions & 0 deletions com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -830,4 +830,60 @@ public void TestObservableAttributeBehaviorIgnore()
}
}
}

/// <summary>
/// Checks that Agent callbacks which call back into EndEpisode() surface a
/// UnityAgentsException from Academy.EnvironmentStep().
/// </summary>
[TestFixture]
public class AgentRecursionTests
{
    /// <summary>
    /// Disposes any Academy that is already initialized before each test runs.
    /// </summary>
    [SetUp]
    public void SetUp()
    {
        if (!Academy.IsInitialized)
        {
            return;
        }

        Academy.Instance.Dispose();
    }

    /// <summary>
    /// Agent that ends its episode from inside CollectObservations().
    /// </summary>
    class CollectObsEndEpisodeAgent : Agent
    {
        public override void CollectObservations(VectorSensor sensor)
        {
            // Deliberately wrong usage for the test - never do this in real code!
            EndEpisode();
        }
    }

    /// <summary>
    /// Agent that ends its episode from inside OnEpisodeBegin().
    /// </summary>
    class OnEpisodeBeginEndEpisodeAgent : Agent
    {
        public override void OnEpisodeBegin()
        {
            // Deliberately wrong usage for the test - never do this in real code!
            EndEpisode();
        }
    }

    /// <summary>
    /// Adds an agent of type <typeparamref name="T"/> to a new GameObject, initializes it,
    /// requests a decision, and asserts that stepping the Academy throws UnityAgentsException.
    /// </summary>
    /// <typeparam name="T">Agent subclass under test.</typeparam>
    void TestRecursiveThrows<T>() where T : Agent
    {
        var host = new GameObject();
        var agentUnderTest = host.AddComponent<T>();
        agentUnderTest.LazyInitialize();
        agentUnderTest.RequestDecision();

        Assert.Throws<UnityAgentsException>(() => Academy.Instance.EnvironmentStep());
    }

    [Test]
    public void TestRecursiveCollectObsEndEpisodeThrows()
    {
        TestRecursiveThrows<CollectObsEndEpisodeAgent>();
    }

    [Test]
    public void TestRecursiveOnEpisodeBeginEndEpisodeThrows()
    {
        TestRecursiveThrows<OnEpisodeBeginEndEpisodeAgent>();
    }
}
}
72 changes: 72 additions & 0 deletions com.unity.ml-agents/Tests/Editor/RecursionCheckerTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using System;
using NUnit.Framework;

namespace Unity.MLAgents.Tests
{
/// <summary>
/// Unit tests for RecursionChecker: re-entry detection and flag reset after an exception.
/// </summary>
[TestFixture]
public class RecursionCheckerTests
{
    /// <summary>
    /// Helper whose only method re-enters itself from inside its own checker scope.
    /// </summary>
    class InfiniteRecurser
    {
        RecursionChecker m_Checker = new RecursionChecker("InfiniteRecurser");
        public int NumCalls = 0;

        /// <summary>
        /// Counts the call, then calls itself while the checker scope is still open.
        /// </summary>
        public void Implode()
        {
            NumCalls += 1;
            using (m_Checker.Start())
            {
                Implode();
            }
        }
    }

    [Test]
    public void TestRecursionCheck()
    {
        var recurser = new InfiniteRecurser();

        Assert.Throws<UnityAgentsException>(() => recurser.Implode());

        // The outer call opens the scope and the nested call trips the check,
        // so exactly two calls are counted.
        Assert.AreEqual(2, recurser.NumCalls);
    }

    /// <summary>
    /// Helper that throws from inside the checker scope on its first call only.
    /// </summary>
    class OneTimeThrower
    {
        RecursionChecker m_Checker = new RecursionChecker("OneTimeThrower");
        public int NumCalls;

        /// <summary>
        /// Counts the call; the first call throws ArgumentException from inside the
        /// checker scope, later calls just enter and leave the scope.
        /// </summary>
        public void DoStuff()
        {
            NumCalls += 1;
            using (m_Checker.Start())
            {
                var isFirstCall = NumCalls == 1;
                if (isFirstCall)
                {
                    throw new ArgumentException("oops");
                }
            }
        }
    }

    [Test]
    public void TestThrowResetsFlag()
    {
        var thrower = new OneTimeThrower();

        Assert.Throws<ArgumentException>(() => thrower.DoStuff());

        // The exception left the "using" scope, so the flag must be cleared and the
        // following calls must go through without a recursion error.
        for (var attempt = 0; attempt < 2; attempt++)
        {
            thrower.DoStuff();
        }

        Assert.AreEqual(3, thrower.NumCalls);
    }
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.