Skip to content

Commit 47f1b5b

Browse files
committed
Define an alternative can_handle logic by passing a callable.
1 parent d193c96 commit 47f1b5b

File tree

2 files changed

+79
-12
lines changed

2 files changed

+79
-12
lines changed

mocket/mocks/mockhttp.py

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import time
33
from functools import cached_property
44
from http.server import BaseHTTPRequestHandler
5+
from typing import Callable, Optional
56
from urllib.parse import parse_qs, unquote, urlsplit
67

78
from h11 import SERVER, Connection, Data
@@ -82,9 +83,7 @@ def __init__(self, body="", status=200, headers=None):
8283
self.status = status
8384

8485
self.set_base_headers()
85-
86-
if headers is not None:
87-
self.set_extra_headers(headers)
86+
self.set_extra_headers(headers)
8887

8988
self.data = self.get_protocol_data() + self.body
9089

@@ -142,9 +141,19 @@ class Entry(MocketEntry):
142141
request_cls = Request
143142
response_cls = Response
144143

145-
default_config = {"match_querystring": True}
144+
default_config = {"match_querystring": True, "can_handle_fun": None}
145+
_can_handle_fun: Optional[Callable] = None
146+
147+
def __init__(
148+
self,
149+
uri,
150+
method,
151+
responses,
152+
match_querystring: bool = True,
153+
can_handle_fun: Optional[Callable] = None,
154+
):
155+
self._can_handle_fun = can_handle_fun if can_handle_fun else self._can_handle
146156

147-
def __init__(self, uri, method, responses, match_querystring: bool = True):
148157
uri = urlsplit(uri)
149158

150159
port = uri.port
@@ -177,6 +186,18 @@ def collect(self, data):
177186

178187
return consume_response
179188

189+
def _can_handle(self, path: str, qs_dict: dict) -> bool:
190+
"""
191+
The default can_handle function, which checks if the path match,
192+
and if match_querystring is True, also checks if the querystring matches.
193+
"""
194+
can_handle = path == self.path
195+
if self._match_querystring:
196+
can_handle = can_handle and qs_dict == parse_qs(
197+
self.query, keep_blank_values=True
198+
)
199+
return can_handle
200+
180201
def can_handle(self, data):
181202
r"""
182203
>>> e = Entry('http://www.github.com/?bar=foo&foobar', Entry.GET, (Response(b'<html/>'),))
@@ -192,13 +213,12 @@ def can_handle(self, data):
192213
except ValueError:
193214
return self is getattr(Mocket, "_last_entry", None)
194215

195-
uri = urlsplit(path)
196-
can_handle = uri.path == self.path and method == self.method
197-
if self._match_querystring:
198-
kw = dict(keep_blank_values=True)
199-
can_handle = can_handle and parse_qs(uri.query, **kw) == parse_qs(
200-
self.query, **kw
201-
)
216+
_request = urlsplit(path)
217+
218+
can_handle = method == self.method and self._can_handle_fun(
219+
_request.path, parse_qs(_request.query, keep_blank_values=True)
220+
)
221+
202222
if can_handle:
203223
Mocket._last_entry = self
204224
return can_handle
@@ -249,8 +269,27 @@ def single_register(
249269
headers=None,
250270
exception=None,
251271
match_querystring=True,
272+
can_handle_fun=None,
252273
**config,
253274
):
275+
"""
276+
A helper method to register a single Response for a given URI and method.
277+
Instead of passing a list of Response objects, you can just pass the response
278+
parameters directly.
279+
280+
Args:
281+
method (str): The HTTP method (e.g., 'GET', 'POST').
282+
uri (str): The URI to register the response for.
283+
body (str, optional): The body of the response. Defaults to an empty string.
284+
status (int, optional): The HTTP status code. Defaults to 200.
285+
headers (dict, optional): A dictionary of headers to include in the response. Defaults to None.
286+
exception (Exception, optional): An exception to raise instead of returning a response. Defaults to None.
287+
match_querystring (bool, optional): Whether to match the querystring in the URI. Defaults to True.
288+
can_handle_fun (Callable, optional): A custom function to determine if the Entry can handle a request.
289+
Defaults to None. If None, the default matching logic is used. The function should accept two parameters:
290+
path (str), and querystring params (dict), and return a boolean. Method is matched before the function call.
291+
**config: Additional configuration options.
292+
"""
254293
response = (
255294
exception
256295
if exception
@@ -262,5 +301,6 @@ def single_register(
262301
uri,
263302
response,
264303
match_querystring=match_querystring,
304+
can_handle_fun=can_handle_fun,
265305
**config,
266306
)

tests/test_http.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,3 +455,30 @@ def test_mocket_with_no_path(self):
455455
response = urlopen("http://httpbin.local/")
456456
self.assertEqual(response.code, 202)
457457
self.assertEqual(Mocket._entries[("httpbin.local", 80)][0].path, "/")
458+
459+
@mocketize
460+
def test_can_handle(self):
461+
Entry.single_register(
462+
Entry.POST,
463+
"http://testme.org/foobar",
464+
body=json.dumps({"message": "Spooky!"}),
465+
match_querystring=False,
466+
)
467+
Entry.single_register(
468+
Entry.GET,
469+
"http://testme.org/",
470+
body=json.dumps({"message": "Gotcha!"}),
471+
can_handle_fun=lambda p, q: p.endswith("/foobar") and "a" in q,
472+
)
473+
Entry.single_register(
474+
Entry.GET,
475+
"http://testme.org/foobar",
476+
body=json.dumps({"message": "Missed!"}),
477+
match_querystring=False,
478+
)
479+
response = requests.get("http://testme.org/foobar?a=1")
480+
self.assertEqual(response.status_code, 200)
481+
self.assertEqual(response.json(), {"message": "Gotcha!"})
482+
response = requests.get("http://testme.org/foobar?b=2")
483+
self.assertEqual(response.status_code, 200)
484+
self.assertEqual(response.json(), {"message": "Missed!"})

0 commit comments

Comments
 (0)