Skip to content

Commit fad4c25

Browse files
authored
feat: Persist DefaultRenderingTypePredictor state (#1340)
### Description - Persist `DefaultRenderingTypePredictor` state ### Issues - Closes: #1272
1 parent 2f24600 commit fad4c25

File tree

6 files changed

+289
-81
lines changed

6 files changed

+289
-81
lines changed

docs/guides/code_examples/playwright_crawler_adaptive/init_prediction.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
class CustomRenderingTypePredictor(RenderingTypePredictor):
1414
def __init__(self) -> None:
15+
super().__init__()
16+
1517
self._learning_data = list[tuple[Request, RenderingType]]()
1618

1719
def predict(self, request: Request) -> RenderingTypePrediction:

src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ async def adaptive_pre_navigation_hook_pw(context: PlaywrightPreNavCrawlingConte
205205

206206
self._additional_context_managers = [
207207
*self._additional_context_managers,
208+
self.rendering_type_predictor,
208209
static_crawler.statistics,
209210
playwright_crawler.statistics,
210211
playwright_crawler._browser_pool, # noqa: SLF001 # Intentional access to private member.

src/crawlee/crawlers/_adaptive_playwright/_rendering_type_predictor.py

Lines changed: 105 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,49 @@
1+
from __future__ import annotations
2+
13
from abc import ABC, abstractmethod
24
from collections import defaultdict
35
from dataclasses import dataclass
46
from itertools import zip_longest
7+
from logging import getLogger
58
from statistics import mean
6-
from typing import Literal
9+
from typing import TYPE_CHECKING, Annotated, Literal
710
from urllib.parse import urlparse
811

912
from jaro import jaro_winkler_metric
13+
from pydantic import BaseModel, ConfigDict, Field, PlainSerializer, PlainValidator
1014
from sklearn.linear_model import LogisticRegression
1115
from typing_extensions import override
1216

13-
from crawlee import Request
1417
from crawlee._utils.docs import docs_group
18+
from crawlee._utils.recoverable_state import RecoverableState
19+
20+
from ._utils import sklearn_model_serializer, sklearn_model_validator
21+
22+
if TYPE_CHECKING:
23+
from types import TracebackType
24+
25+
from crawlee import Request
26+
27+
logger = getLogger(__name__)
1528

1629
UrlComponents = list[str]
1730
RenderingType = Literal['static', 'client only']
1831
FeatureVector = tuple[float, float]
1932

2033

34+
class RenderingTypePredictorState(BaseModel):
35+
model_config = ConfigDict(populate_by_name=True)
36+
37+
model: Annotated[
38+
LogisticRegression,
39+
Field(LogisticRegression),
40+
PlainValidator(sklearn_model_validator),
41+
PlainSerializer(sklearn_model_serializer),
42+
]
43+
44+
labels_coefficients: Annotated[defaultdict[str, float], Field(alias='labelsCoefficients')]
45+
46+
2147
@docs_group('Other')
2248
@dataclass(frozen=True)
2349
class RenderingTypePrediction:
@@ -36,6 +62,11 @@ class RenderingTypePrediction:
3662
class RenderingTypePredictor(ABC):
3763
"""Stores rendering type for previously crawled URLs and predicts the rendering type for unvisited urls."""
3864

65+
def __init__(self) -> None:
66+
"""Initialize a new instance."""
67+
# Flag to indicate the state.
68+
self._active = False
69+
3970
@abstractmethod
4071
def predict(self, request: Request) -> RenderingTypePrediction:
4172
"""Get `RenderingTypePrediction` based on the input request.
@@ -53,6 +84,32 @@ def store_result(self, request: Request, rendering_type: RenderingType) -> None:
5384
rendering_type: Known suitable `RenderingType`.
5485
"""
5586

87+
async def initialize(self) -> None:
88+
"""Initialize additional resources required for the predictor operation."""
89+
if self._active:
90+
raise RuntimeError(f'The {self.__class__.__name__} is already active.')
91+
self._active = True
92+
93+
async def clear(self) -> None:
94+
"""Clear and release additional resources used by the predictor."""
95+
if not self._active:
96+
raise RuntimeError(f'The {self.__class__.__name__} is not active.')
97+
self._active = False
98+
99+
async def __aenter__(self) -> RenderingTypePredictor:
100+
"""Initialize the predictor upon entering the context manager."""
101+
await self.initialize()
102+
return self
103+
104+
async def __aexit__(
105+
self,
106+
exc_type: type[BaseException] | None,
107+
exc_value: BaseException | None,
108+
exc_traceback: TracebackType | None,
109+
) -> None:
110+
"""Clear the predictor upon exiting the context manager."""
111+
await self.clear()
112+
56113

57114
@docs_group('Other')
58115
class DefaultRenderingTypePredictor(RenderingTypePredictor):
@@ -62,24 +119,59 @@ class DefaultRenderingTypePredictor(RenderingTypePredictor):
62119
https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
63120
"""
64121

65-
def __init__(self, detection_ratio: float = 0.1) -> None:
122+
def __init__(
123+
self,
124+
detection_ratio: float = 0.1,
125+
*,
126+
persistence_enabled: bool = False,
127+
persist_state_key: str = 'rendering-type-predictor-state',
128+
) -> None:
66129
"""Initialize a new instance.
67130
68131
Args:
69132
detection_ratio: A number between 0 and 1 that determines the desired ratio of rendering type detections.
133+
persist_state_key: Key in the key-value storage where the trained model parameters will be saved.
134+
If None, defaults to 'rendering-type-predictor-state'.
135+
persistence_enabled: Whether to enable persistence of the trained model parameters for reuse.
136+
70137
"""
138+
super().__init__()
139+
71140
self._rendering_type_detection_results: dict[RenderingType, dict[str, list[UrlComponents]]] = {
72141
'static': defaultdict(list),
73142
'client only': defaultdict(list),
74143
}
75-
self._model = LogisticRegression(max_iter=1000)
76144
self._detection_ratio = max(0, min(1, detection_ratio))
77145

78146
# Used to increase detection probability recommendation for initial recommendations of each label.
79147
# Reaches 1 (no additional increase) after n samples of specific label is already present in
80148
# `self._rendering_type_detection_results`.
81149
n = 3
82-
self._labels_coefficients: dict[str, float] = defaultdict(lambda: n + 2)
150+
151+
self._state = RecoverableState(
152+
default_state=RenderingTypePredictorState(
153+
model=LogisticRegression(max_iter=1000), labels_coefficients=defaultdict(lambda: n + 2)
154+
),
155+
persist_state_key=persist_state_key,
156+
persistence_enabled=persistence_enabled,
157+
logger=logger,
158+
)
159+
160+
@override
161+
async def initialize(self) -> None:
162+
"""Get current state of the predictor."""
163+
await super().initialize()
164+
165+
if not self._state.is_initialized:
166+
await self._state.initialize()
167+
168+
@override
169+
async def clear(self) -> None:
170+
"""Clear the predictor state."""
171+
await super().clear()
172+
173+
if self._state.is_initialized:
174+
await self._state.teardown()
83175

84176
@override
85177
def predict(self, request: Request) -> RenderingTypePrediction:
@@ -91,19 +183,20 @@ def predict(self, request: Request) -> RenderingTypePrediction:
91183
similarity_threshold = 0.1 # Prediction probability difference threshold to consider prediction unreliable.
92184
label = request.label or ''
93185

94-
if self._rendering_type_detection_results['static'] or self._rendering_type_detection_results['client only']:
186+
# Check that the model has already been fitted.
187+
if hasattr(self._state.current_value.model, 'coef_'):
95188
url_feature = self._calculate_feature_vector(get_url_components(request.url), label)
96189
# Are both calls expensive?
97-
prediction = self._model.predict([url_feature])[0]
98-
probability = self._model.predict_proba([url_feature])[0]
190+
prediction = self._state.current_value.model.predict([url_feature])[0]
191+
probability = self._state.current_value.model.predict_proba([url_feature])[0]
99192

100193
if abs(probability[0] - probability[1]) < similarity_threshold:
101194
# Prediction not reliable.
102195
detection_probability_recommendation = 1.0
103196
else:
104197
detection_probability_recommendation = self._detection_ratio
105198
# Increase recommendation for uncommon labels.
106-
detection_probability_recommendation *= self._labels_coefficients[label]
199+
detection_probability_recommendation *= self._state.current_value.labels_coefficients[label]
107200

108201
return RenderingTypePrediction(
109202
rendering_type=('client only', 'static')[int(prediction)],
@@ -122,8 +215,8 @@ def store_result(self, request: Request, rendering_type: RenderingType) -> None:
122215
"""
123216
label = request.label or ''
124217
self._rendering_type_detection_results[rendering_type][label].append(get_url_components(request.url))
125-
if self._labels_coefficients[label] > 1:
126-
self._labels_coefficients[label] -= 1
218+
if self._state.current_value.labels_coefficients[label] > 1:
219+
self._state.current_value.labels_coefficients[label] -= 1
127220
self._retrain()
128221

129222
def _retrain(self) -> None:
@@ -137,7 +230,7 @@ def _retrain(self) -> None:
137230
x.append(self._calculate_feature_vector(url_components, label))
138231
y.append(encoded_rendering_type)
139232

140-
self._model.fit(x, y)
233+
self._state.current_value.model.fit(x, y)
141234

142235
def _calculate_mean_similarity(self, url: UrlComponents, label: str, rendering_type: RenderingType) -> float:
143236
if not self._rendering_type_detection_results[rendering_type][label]:
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import Any
2+
3+
import numpy as np
4+
from sklearn.linear_model import LogisticRegression
5+
6+
7+
def sklearn_model_validator(v: LogisticRegression | dict[str, Any]) -> LogisticRegression:
8+
if isinstance(v, LogisticRegression):
9+
return v
10+
11+
model = LogisticRegression(max_iter=1000)
12+
if v.get('is_fitted', False):
13+
model.coef_ = np.array(v['coef'])
14+
model.intercept_ = np.array(v['intercept'])
15+
model.classes_ = np.array(v['classes'])
16+
model.n_iter_ = np.array(v.get('n_iter', [1000]))
17+
18+
return model
19+
20+
21+
def sklearn_model_serializer(model: LogisticRegression) -> dict[str, Any]:
22+
if hasattr(model, 'coef_'):
23+
return {
24+
'coef': model.coef_.tolist(),
25+
'intercept': model.intercept_.tolist(),
26+
'classes': model.classes_.tolist(),
27+
'n_iter': model.n_iter_.tolist() if hasattr(model, 'n_iter_') else [1000],
28+
'is_fitted': True,
29+
'max_iter': model.max_iter,
30+
'solver': model.solver,
31+
}
32+
return {'is_fitted': False, 'max_iter': model.max_iter, 'solver': model.solver}

tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ def __init__(
8181
rendering_types: Iterator[RenderingType] | None = None,
8282
detection_probability_recommendation: None | Iterator[float] = None,
8383
) -> None:
84+
super().__init__()
85+
8486
self._rendering_types = rendering_types or cycle(['static'])
8587
self._detection_probability_recommendation = detection_probability_recommendation or cycle([1])
8688

0 commit comments

Comments
 (0)