1
+ from __future__ import annotations
2
+
1
3
from abc import ABC , abstractmethod
2
4
from collections import defaultdict
3
5
from dataclasses import dataclass
4
6
from itertools import zip_longest
7
+ from logging import getLogger
5
8
from statistics import mean
6
- from typing import Literal
9
+ from typing import TYPE_CHECKING , Annotated , Literal
7
10
from urllib .parse import urlparse
8
11
9
12
from jaro import jaro_winkler_metric
13
+ from pydantic import BaseModel , ConfigDict , Field , PlainSerializer , PlainValidator
10
14
from sklearn .linear_model import LogisticRegression
11
15
from typing_extensions import override
12
16
13
- from crawlee import Request
14
17
from crawlee ._utils .docs import docs_group
18
+ from crawlee ._utils .recoverable_state import RecoverableState
19
+
20
+ from ._utils import sklearn_model_serializer , sklearn_model_validator
21
+
22
+ if TYPE_CHECKING :
23
+ from types import TracebackType
24
+
25
+ from crawlee import Request
26
+
27
logger = getLogger(__name__)

# URL broken down into a list of string components.
UrlComponents = list[str]
# Rendering types the predictor distinguishes between.
RenderingType = Literal['static', 'client only']
# Pair of float features computed for a URL and passed to the model.
FeatureVector = tuple[float, float]
19
32
20
33
34
class RenderingTypePredictorState(BaseModel):
    """State of the rendering type predictor that can be persisted and restored.

    Bundles the scikit-learn model with the per-label coefficients so that both are handled as a single unit.
    """

    model_config = ConfigDict(populate_by_name=True)

    # The first positional argument of `Field` is the default *value*, so `Field(LogisticRegression)` would make
    # the class object itself the default. A factory produces a usable estimator instance instead.
    # The model is validated and serialized through the custom sklearn helpers.
    model: Annotated[
        LogisticRegression,
        Field(default_factory=LogisticRegression),
        PlainValidator(sklearn_model_validator),
        PlainSerializer(sklearn_model_serializer),
    ]

    # Per-label multiplier applied to the detection probability recommendation.
    # NOTE(review): the `default_factory` of the `defaultdict` supplied by the caller is presumably not preserved
    # when the state is serialized and validated again - confirm that a restored state still yields the intended
    # coefficient for labels that have not been seen yet.
    labels_coefficients: Annotated[defaultdict[str, float], Field(alias='labelsCoefficients')]
45
+
46
+
21
47
@docs_group ('Other' )
22
48
@dataclass (frozen = True )
23
49
class RenderingTypePrediction :
@@ -36,6 +62,11 @@ class RenderingTypePrediction:
36
62
class RenderingTypePredictor (ABC ):
37
63
"""Stores rendering type for previously crawled URLs and predicts the rendering type for unvisited urls."""
38
64
65
def __init__(self) -> None:
    """Initialize a new instance."""
    # Tracks whether `initialize` has been called without a matching `clear`.
    self._active = False
69
+
39
70
@abstractmethod
40
71
def predict (self , request : Request ) -> RenderingTypePrediction :
41
72
"""Get `RenderingTypePrediction` based on the input request.
@@ -53,6 +84,32 @@ def store_result(self, request: Request, rendering_type: RenderingType) -> None:
53
84
rendering_type: Known suitable `RenderingType`.
54
85
"""
55
86
87
async def initialize(self) -> None:
    """Initialize additional resources required for the predictor operation.

    Raises:
        RuntimeError: If the predictor is already active.
    """
    if self._active:
        raise RuntimeError(f'The {self.__class__.__name__} is already active.')

    self._active = True
92
+
93
async def clear(self) -> None:
    """Clear and release additional resources used by the predictor.

    Raises:
        RuntimeError: If the predictor is not active.
    """
    if not self._active:
        raise RuntimeError(f'The {self.__class__.__name__} is not active.')

    self._active = False
98
+
99
async def __aenter__(self) -> RenderingTypePredictor:
    """Initialize the predictor upon entering the context manager.

    Returns:
        The predictor itself, after `initialize` has completed.
    """
    await self.initialize()
    return self
103
+
104
async def __aexit__(
    self,
    exc_type: type[BaseException] | None,
    exc_value: BaseException | None,
    exc_traceback: TracebackType | None,
) -> None:
    """Clear the predictor upon exiting the context manager.

    Args:
        exc_type: Type of the exception raised inside the context, if any. Not inspected here.
        exc_value: The exception raised inside the context, if any. Not inspected here.
        exc_traceback: Traceback of the exception raised inside the context, if any. Not inspected here.
    """
    await self.clear()
112
+
56
113
57
114
@docs_group ('Other' )
58
115
class DefaultRenderingTypePredictor (RenderingTypePredictor ):
@@ -62,24 +119,59 @@ class DefaultRenderingTypePredictor(RenderingTypePredictor):
62
119
https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
63
120
"""
64
121
65
def __init__(
    self,
    detection_ratio: float = 0.1,
    *,
    persistence_enabled: bool = False,
    persist_state_key: str = 'rendering-type-predictor-state',
) -> None:
    """Initialize a new instance.

    Args:
        detection_ratio: A number between 0 and 1 that determines the desired ratio of rendering type detections.
            Values outside of this range are clamped to it.
        persistence_enabled: Whether to enable persistence of the trained model parameters for reuse.
        persist_state_key: Key in the key-value storage where the trained model parameters will be saved.
    """
    super().__init__()

    # Known detection results: rendering type -> label -> URL components of every stored request.
    self._rendering_type_detection_results: dict[RenderingType, dict[str, list[UrlComponents]]] = {
        'static': defaultdict(list),
        'client only': defaultdict(list),
    }
    self._detection_ratio = max(0, min(1, detection_ratio))

    # Used to increase detection probability recommendation for initial recommendations of each label.
    # The coefficient of a label starts at `n + 2` and is decreased by one with every stored result of that
    # label until it reaches 1 (no additional increase).
    n = 3

    self._state = RecoverableState(
        default_state=RenderingTypePredictorState(
            model=LogisticRegression(max_iter=1000), labels_coefficients=defaultdict(lambda: n + 2)
        ),
        persist_state_key=persist_state_key,
        persistence_enabled=persistence_enabled,
        logger=logger,
    )
159
+
160
@override
async def initialize(self) -> None:
    """Activate the predictor and initialize its recoverable state.

    The recoverable state is initialized only when it has not been initialized yet.
    """
    await super().initialize()

    if not self._state.is_initialized:
        await self._state.initialize()
167
+
168
@override
async def clear(self) -> None:
    """Deactivate the predictor and tear down its recoverable state.

    The recoverable state is torn down only when it is currently initialized.
    """
    await super().clear()

    if self._state.is_initialized:
        await self._state.teardown()
83
175
84
176
@override
85
177
def predict (self , request : Request ) -> RenderingTypePrediction :
@@ -91,19 +183,20 @@ def predict(self, request: Request) -> RenderingTypePrediction:
91
183
similarity_threshold = 0.1 # Prediction probability difference threshold to consider prediction unreliable.
92
184
label = request .label or ''
93
185
94
- if self ._rendering_type_detection_results ['static' ] or self ._rendering_type_detection_results ['client only' ]:
186
+ # Check that the model has already been fitted.
187
+ if hasattr (self ._state .current_value .model , 'coef_' ):
95
188
url_feature = self ._calculate_feature_vector (get_url_components (request .url ), label )
96
189
# Are both calls expensive?
97
- prediction = self ._model .predict ([url_feature ])[0 ]
98
- probability = self ._model .predict_proba ([url_feature ])[0 ]
190
+ prediction = self ._state . current_value . model .predict ([url_feature ])[0 ]
191
+ probability = self ._state . current_value . model .predict_proba ([url_feature ])[0 ]
99
192
100
193
if abs (probability [0 ] - probability [1 ]) < similarity_threshold :
101
194
# Prediction not reliable.
102
195
detection_probability_recommendation = 1.0
103
196
else :
104
197
detection_probability_recommendation = self ._detection_ratio
105
198
# Increase recommendation for uncommon labels.
106
- detection_probability_recommendation *= self ._labels_coefficients [label ]
199
+ detection_probability_recommendation *= self ._state . current_value . labels_coefficients [label ]
107
200
108
201
return RenderingTypePrediction (
109
202
rendering_type = ('client only' , 'static' )[int (prediction )],
@@ -122,8 +215,8 @@ def store_result(self, request: Request, rendering_type: RenderingType) -> None:
122
215
"""
123
216
label = request .label or ''
124
217
self ._rendering_type_detection_results [rendering_type ][label ].append (get_url_components (request .url ))
125
- if self ._labels_coefficients [label ] > 1 :
126
- self ._labels_coefficients [label ] -= 1
218
+ if self ._state . current_value . labels_coefficients [label ] > 1 :
219
+ self ._state . current_value . labels_coefficients [label ] -= 1
127
220
self ._retrain ()
128
221
129
222
def _retrain (self ) -> None :
@@ -137,7 +230,7 @@ def _retrain(self) -> None:
137
230
x .append (self ._calculate_feature_vector (url_components , label ))
138
231
y .append (encoded_rendering_type )
139
232
140
- self ._model .fit (x , y )
233
+ self ._state . current_value . model .fit (x , y )
141
234
142
235
def _calculate_mean_similarity (self , url : UrlComponents , label : str , rendering_type : RenderingType ) -> float :
143
236
if not self ._rendering_type_detection_results [rendering_type ][label ]:
0 commit comments