Skip to content

Commit f540d32

Browse files
committed
refactor: add class that handles request records
1 parent 7f00e02 commit f540d32

File tree

4 files changed

+227
-125
lines changed

4 files changed

+227
-125
lines changed

mocket/mocket.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import collections
44
import itertools
55
import os
6+
from pathlib import Path
67
from typing import TYPE_CHECKING, ClassVar
78

89
import mocket.inject
10+
from mocket.recording import MocketRecordStorage
911

1012
# NOTE this is here for backwards-compat to keep old import-paths working
1113
# from mocket.socket import MocketSocket as MocketSocket
@@ -20,21 +22,28 @@ class Mocket:
2022
_address: ClassVar[Address] = (None, None)
2123
_entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list)
2224
_requests: ClassVar[list] = []
23-
_namespace: ClassVar[str] = str(id(_entries))
24-
_truesocket_recording_dir: ClassVar[str | None] = None
25+
_record_storage: ClassVar[MocketRecordStorage | None] = None
2526

2627
@classmethod
2728
def enable(
2829
cls,
2930
namespace: str | None = None,
3031
truesocket_recording_dir: str | None = None,
3132
) -> None:
32-
if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
33-
# JSON dumps will be saved here
34-
raise AssertionError
33+
if namespace is None:
34+
namespace = str(id(cls._entries))
3535

36-
cls._namespace = namespace
37-
cls._truesocket_recording_dir = truesocket_recording_dir
36+
if truesocket_recording_dir is not None:
37+
recording_dir = Path(truesocket_recording_dir)
38+
39+
if not recording_dir.is_dir():
40+
# JSON dumps will be saved here
41+
raise AssertionError
42+
43+
cls._record_storage = MocketRecordStorage(
44+
directory=recording_dir,
45+
namespace=namespace,
46+
)
3847

3948
mocket.inject.enable()
4049

@@ -87,6 +96,7 @@ def reset(cls) -> None:
8796
cls._socket_pairs = {}
8897
cls._entries = collections.defaultdict(list)
8998
cls._requests = []
99+
cls._record_storage = None
90100

91101
@classmethod
92102
def last_request(cls):
@@ -107,12 +117,16 @@ def has_requests(cls) -> bool:
107117
return bool(cls.request_list())
108118

109119
@classmethod
110-
def get_namespace(cls) -> str:
111-
return cls._namespace
120+
def get_namespace(cls) -> str | None:
121+
if not cls._record_storage:
122+
return None
123+
return cls._record_storage.namespace
112124

113125
@classmethod
114126
def get_truesocket_recording_dir(cls) -> str | None:
115-
return cls._truesocket_recording_dir
127+
if not cls._record_storage:
128+
return None
129+
return str(cls._record_storage.directory)
116130

117131
@classmethod
118132
def assert_fail_if_entries_not_served(cls) -> None:

mocket/recording.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
from __future__ import annotations
2+
3+
import binascii
4+
import contextlib
5+
import hashlib
6+
import json
7+
from collections import defaultdict
8+
from dataclasses import dataclass
9+
from pathlib import Path
10+
11+
from mocket.compat import decode_from_bytes, encode_to_bytes
12+
from mocket.types import Address
13+
14+
hash_function = hashlib.md5
15+
16+
with contextlib.suppress(ImportError):
17+
from xxhash_cffi import xxh32 as xxhash_cffi_xxh32
18+
19+
hash_function = xxhash_cffi_xxh32
20+
21+
with contextlib.suppress(ImportError):
22+
from xxhash import xxh32 as xxhash_xxh32
23+
24+
hash_function = xxhash_xxh32
25+
26+
27+
def _hash_prepare_request(data: bytes) -> bytes:
28+
_data = decode_from_bytes(data)
29+
return encode_to_bytes("".join(sorted(_data.split("\r\n"))))
30+
31+
32+
def _hash_request(data: bytes) -> str:
33+
_data = _hash_prepare_request(data)
34+
return hash_function(_data).hexdigest()
35+
36+
37+
def _hash_request_fallback(data: bytes) -> str:
38+
_data = _hash_prepare_request(data)
39+
return hashlib.md5(_data).hexdigest()
40+
41+
42+
def hexdump(binary_string: bytes) -> str:
43+
r"""
44+
>>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F"))
45+
True
46+
"""
47+
bs = decode_from_bytes(binascii.hexlify(binary_string).upper())
48+
return " ".join(a + b for a, b in zip(bs[::2], bs[1::2]))
49+
50+
51+
def hexload(string: str) -> bytes:
52+
r"""
53+
>>> hexload("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F") == encode_to_bytes("bar foobar foo")
54+
True
55+
"""
56+
string_no_spaces = "".join(string.split())
57+
return encode_to_bytes(binascii.unhexlify(string_no_spaces))
58+
59+
60+
@dataclass
61+
class MocketRecord:
62+
host: str
63+
port: int
64+
request: bytes
65+
response: bytes
66+
67+
68+
class MocketRecordStorage:
69+
def __init__(self, directory: Path, namespace: str) -> None:
70+
self._directory = directory
71+
self._namespace = namespace
72+
self._records: defaultdict[Address, defaultdict[str, MocketRecord]] = (
73+
defaultdict(defaultdict)
74+
)
75+
76+
self._load()
77+
78+
@property
79+
def directory(self) -> Path:
80+
return self._directory
81+
82+
@property
83+
def namespace(self) -> str:
84+
return self._namespace
85+
86+
@property
87+
def file(self) -> Path:
88+
return self._directory / f"{self._namespace}.json"
89+
90+
def _load(self) -> None:
91+
if not self.file.exists():
92+
return
93+
94+
json_data = self.file.read_text()
95+
records = json.loads(json_data)
96+
for host, port_signature_record in records.items():
97+
for port, signature_record in port_signature_record.items():
98+
for signature, record in signature_record.items():
99+
# NOTE backward-compat
100+
try:
101+
request_data = hexload(record["request"])
102+
except binascii.Error:
103+
request_data = record["request"]
104+
105+
self._records[(host, int(port))][signature] = MocketRecord(
106+
host=host,
107+
port=port,
108+
request=request_data,
109+
response=hexload(record["response"]),
110+
)
111+
112+
def _save(self) -> None:
113+
data: dict[str, dict[str, dict[str, dict[str, str]]]] = defaultdict(
114+
lambda: defaultdict(defaultdict)
115+
)
116+
for address, signature_record in self._records.items():
117+
host, port = address
118+
for signature, record in signature_record.items():
119+
data[host][str(port)][signature] = dict(
120+
request=decode_from_bytes(record.request),
121+
response=hexdump(record.response),
122+
)
123+
124+
json_data = json.dumps(data, indent=4, sort_keys=True)
125+
self.file.parent.mkdir(exist_ok=True)
126+
self.file.write_text(json_data)
127+
128+
def get_records(self, address: Address) -> list[MocketRecord]:
129+
return list(self._records[address].values())
130+
131+
def get_record(self, address: Address, request: bytes) -> MocketRecord | None:
132+
# NOTE for backward-compat
133+
request_signature_fallback = _hash_request_fallback(request)
134+
if request_signature_fallback in self._records[address]:
135+
return self._records[address].get(request_signature_fallback)
136+
137+
request_signature = _hash_request(request)
138+
if request_signature in self._records[address]:
139+
return self._records[address][request_signature]
140+
141+
return None
142+
143+
def put_record(
144+
self,
145+
address: Address,
146+
request: bytes,
147+
response: bytes,
148+
) -> None:
149+
host, port = address
150+
record = MocketRecord(
151+
host=host,
152+
port=port,
153+
request=request,
154+
response=response,
155+
)
156+
157+
# NOTE for backward-compat
158+
request_signature_fallback = _hash_request_fallback(request)
159+
if request_signature_fallback in self._records[address]:
160+
self._records[address][request_signature_fallback] = record
161+
return
162+
163+
request_signature = _hash_request(request)
164+
self._records[address][request_signature] = record
165+
self._save()

0 commit comments

Comments
 (0)