2
2
import time
3
3
from functools import cached_property
4
4
from http .server import BaseHTTPRequestHandler
5
+ from typing import Callable , Optional
5
6
from urllib .parse import parse_qs , unquote , urlsplit
6
7
7
8
from h11 import SERVER , Connection , Data
@@ -82,9 +83,7 @@ def __init__(self, body="", status=200, headers=None):
82
83
self .status = status
83
84
84
85
self .set_base_headers ()
85
-
86
- if headers is not None :
87
- self .set_extra_headers (headers )
86
+ self .set_extra_headers (headers )
88
87
89
88
self .data = self .get_protocol_data () + self .body
90
89
@@ -142,9 +141,19 @@ class Entry(MocketEntry):
142
141
request_cls = Request
143
142
response_cls = Response
144
143
145
- default_config = {"match_querystring" : True }
144
+ default_config = {"match_querystring" : True , "can_handle_fun" : None }
145
+ _can_handle_fun : Optional [Callable ] = None
146
+
147
+ def __init__ (
148
+ self ,
149
+ uri ,
150
+ method ,
151
+ responses ,
152
+ match_querystring : bool = True ,
153
+ can_handle_fun : Optional [Callable ] = None ,
154
+ ):
155
+ self ._can_handle_fun = can_handle_fun if can_handle_fun else self ._can_handle
146
156
147
- def __init__ (self , uri , method , responses , match_querystring : bool = True ):
148
157
uri = urlsplit (uri )
149
158
150
159
port = uri .port
@@ -177,6 +186,18 @@ def collect(self, data):
177
186
178
187
return consume_response
179
188
189
+ def _can_handle (self , path : str , qs_dict : dict ) -> bool :
190
+ """
191
+ The default can_handle function, which checks if the path match,
192
+ and if match_querystring is True, also checks if the querystring matches.
193
+ """
194
+ can_handle = path == self .path
195
+ if self ._match_querystring :
196
+ can_handle = can_handle and qs_dict == parse_qs (
197
+ self .query , keep_blank_values = True
198
+ )
199
+ return can_handle
200
+
180
201
def can_handle (self , data ):
181
202
r"""
182
203
>>> e = Entry('http://www.github.com/?bar=foo&foobar', Entry.GET, (Response(b'<html/>'),))
@@ -192,13 +213,12 @@ def can_handle(self, data):
192
213
except ValueError :
193
214
return self is getattr (Mocket , "_last_entry" , None )
194
215
195
- uri = urlsplit (path )
196
- can_handle = uri .path == self .path and method == self .method
197
- if self ._match_querystring :
198
- kw = dict (keep_blank_values = True )
199
- can_handle = can_handle and parse_qs (uri .query , ** kw ) == parse_qs (
200
- self .query , ** kw
201
- )
216
+ _request = urlsplit (path )
217
+
218
+ can_handle = method == self .method and self ._can_handle_fun (
219
+ _request .path , parse_qs (_request .query , keep_blank_values = True )
220
+ )
221
+
202
222
if can_handle :
203
223
Mocket ._last_entry = self
204
224
return can_handle
@@ -249,8 +269,27 @@ def single_register(
249
269
headers = None ,
250
270
exception = None ,
251
271
match_querystring = True ,
272
+ can_handle_fun = None ,
252
273
** config ,
253
274
):
275
+ """
276
+ A helper method to register a single Response for a given URI and method.
277
+ Instead of passing a list of Response objects, you can just pass the response
278
+ parameters directly.
279
+
280
+ Args:
281
+ method (str): The HTTP method (e.g., 'GET', 'POST').
282
+ uri (str): The URI to register the response for.
283
+ body (str, optional): The body of the response. Defaults to an empty string.
284
+ status (int, optional): The HTTP status code. Defaults to 200.
285
+ headers (dict, optional): A dictionary of headers to include in the response. Defaults to None.
286
+ exception (Exception, optional): An exception to raise instead of returning a response. Defaults to None.
287
+ match_querystring (bool, optional): Whether to match the querystring in the URI. Defaults to True.
288
+ can_handle_fun (Callable, optional): A custom function to determine if the Entry can handle a request.
289
+ Defaults to None. If None, the default matching logic is used. The function should accept two parameters:
290
+ path (str), and querystring params (dict), and return a boolean. Method is matched before the function call.
291
+ **config: Additional configuration options.
292
+ """
254
293
response = (
255
294
exception
256
295
if exception
@@ -262,5 +301,6 @@ def single_register(
262
301
uri ,
263
302
response ,
264
303
match_querystring = match_querystring ,
304
+ can_handle_fun = can_handle_fun ,
265
305
** config ,
266
306
)
0 commit comments