Skip to content

Commit 14513de

Browse files
committed
refactor: move MocketSocket from mocket.mocket to mocket.socket
1 parent 012df13 commit 14513de

File tree

3 files changed

+348
-322
lines changed

3 files changed

+348
-322
lines changed

mocket/mocket.py

Lines changed: 1 addition & 321 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,8 @@
11
import collections
2-
import contextlib
3-
import errno
4-
import hashlib
52
import itertools
6-
import json
73
import os
8-
import select
94
import socket
105
import ssl
11-
from datetime import datetime, timedelta
12-
from json.decoder import JSONDecodeError
136
from typing import Optional, Tuple
147

158
import urllib3
@@ -22,19 +15,8 @@
2215
urllib3_wrap_socket = None
2316

2417

25-
from mocket.compat import decode_from_bytes, encode_to_bytes
26-
from mocket.io import MocketSocketCore
27-
from mocket.mode import MocketMode
18+
from mocket.socket import MocketSocket, create_connection, socketpair
2819
from mocket.ssl import FakeSSLContext
29-
from mocket.utils import hexdump, hexload
30-
31-
xxh32 = None
32-
try:
33-
from xxhash import xxh32
34-
except ImportError: # pragma: no cover
35-
with contextlib.suppress(ImportError):
36-
from xxhash_cffi import xxh32
37-
hasher = xxh32 or hashlib.md5
3820

3921
try: # pragma: no cover
4022
from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3
@@ -60,308 +42,6 @@
6042
true_urllib3_match_hostname = urllib3_match_hostname
6143

6244

63-
def create_connection(address, timeout=None, source_address=None):
64-
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
65-
if timeout:
66-
s.settimeout(timeout)
67-
s.connect(address)
68-
return s
69-
70-
71-
def socketpair(*args, **kwargs):
72-
"""Returns a real socketpair() used by asyncio loop for supporting calls made by fastapi and similar services."""
73-
import _socket
74-
75-
return _socket.socketpair(*args, **kwargs)
76-
77-
78-
def _hash_request(h, req):
79-
return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest()
80-
81-
82-
class MocketSocket:
83-
timeout = None
84-
_fd = None
85-
family = None
86-
type = None
87-
proto = None
88-
_host = None
89-
_port = None
90-
_address = None
91-
cipher = lambda s: ("ADH", "AES256", "SHA")
92-
compression = lambda s: ssl.OP_NO_COMPRESSION
93-
_mode = None
94-
_bufsize = None
95-
_secure_socket = False
96-
_did_handshake = False
97-
_sent_non_empty_bytes = False
98-
_io = None
99-
100-
def __init__(
101-
self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs
102-
):
103-
self.true_socket = true_socket(family, type, proto)
104-
self._buflen = 65536
105-
self._entry = None
106-
self.family = int(family)
107-
self.type = int(type)
108-
self.proto = int(proto)
109-
self._truesocket_recording_dir = None
110-
self.kwargs = kwargs
111-
112-
def __str__(self):
113-
return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})"
114-
115-
def __enter__(self):
116-
return self
117-
118-
def __exit__(self, exc_type, exc_val, exc_tb):
119-
self.close()
120-
121-
@property
122-
def io(self):
123-
if self._io is None:
124-
self._io = MocketSocketCore((self._host, self._port))
125-
return self._io
126-
127-
def fileno(self):
128-
address = (self._host, self._port)
129-
r_fd, _ = Mocket.get_pair(address)
130-
if not r_fd:
131-
r_fd, w_fd = os.pipe()
132-
Mocket.set_pair(address, (r_fd, w_fd))
133-
return r_fd
134-
135-
def gettimeout(self):
136-
return self.timeout
137-
138-
def setsockopt(self, family, type, proto):
139-
self.family = family
140-
self.type = type
141-
self.proto = proto
142-
143-
if self.true_socket:
144-
self.true_socket.setsockopt(family, type, proto)
145-
146-
def settimeout(self, timeout):
147-
self.timeout = timeout
148-
149-
@staticmethod
150-
def getsockopt(level, optname, buflen=None):
151-
return socket.SOCK_STREAM
152-
153-
def do_handshake(self):
154-
self._did_handshake = True
155-
156-
def getpeername(self):
157-
return self._address
158-
159-
def setblocking(self, block):
160-
self.settimeout(None) if block else self.settimeout(0.0)
161-
162-
def getblocking(self):
163-
return self.gettimeout() is None
164-
165-
def getsockname(self):
166-
return socket.gethostbyname(self._address[0]), self._address[1]
167-
168-
def getpeercert(self, *args, **kwargs):
169-
if not (self._host and self._port):
170-
self._address = self._host, self._port = Mocket._address
171-
172-
now = datetime.now()
173-
shift = now + timedelta(days=30 * 12)
174-
return {
175-
"notAfter": shift.strftime("%b %d %H:%M:%S GMT"),
176-
"subjectAltName": (
177-
("DNS", f"*.{self._host}"),
178-
("DNS", self._host),
179-
("DNS", "*"),
180-
),
181-
"subject": (
182-
(("organizationName", f"*.{self._host}"),),
183-
(("organizationalUnitName", "Domain Control Validated"),),
184-
(("commonName", f"*.{self._host}"),),
185-
),
186-
}
187-
188-
def unwrap(self):
189-
return self
190-
191-
def write(self, data):
192-
return self.send(encode_to_bytes(data))
193-
194-
def connect(self, address):
195-
self._address = self._host, self._port = address
196-
Mocket._address = address
197-
198-
def makefile(self, mode="r", bufsize=-1):
199-
self._mode = mode
200-
self._bufsize = bufsize
201-
return self.io
202-
203-
def get_entry(self, data):
204-
return Mocket.get_entry(self._host, self._port, data)
205-
206-
def sendall(self, data, entry=None, *args, **kwargs):
207-
if entry is None:
208-
entry = self.get_entry(data)
209-
210-
if entry:
211-
consume_response = entry.collect(data)
212-
response = entry.get_response() if consume_response is not False else None
213-
else:
214-
response = self.true_sendall(data, *args, **kwargs)
215-
216-
if response is not None:
217-
self.io.seek(0)
218-
self.io.write(response)
219-
self.io.truncate()
220-
self.io.seek(0)
221-
222-
def read(self, buffersize):
223-
rv = self.io.read(buffersize)
224-
if rv:
225-
self._sent_non_empty_bytes = True
226-
if self._did_handshake and not self._sent_non_empty_bytes:
227-
raise ssl.SSLWantReadError("The operation did not complete (read)")
228-
return rv
229-
230-
def recv_into(self, buffer, buffersize=None, flags=None):
231-
if hasattr(buffer, "write"):
232-
return buffer.write(self.read(buffersize))
233-
# buffer is a memoryview
234-
data = self.read(buffersize)
235-
if data:
236-
buffer[: len(data)] = data
237-
return len(data)
238-
239-
def recv(self, buffersize, flags=None):
240-
r_fd, _ = Mocket.get_pair((self._host, self._port))
241-
if r_fd:
242-
return os.read(r_fd, buffersize)
243-
data = self.read(buffersize)
244-
if data:
245-
return data
246-
# used by Redis mock
247-
exc = BlockingIOError()
248-
exc.errno = errno.EWOULDBLOCK
249-
exc.args = (0,)
250-
raise exc
251-
252-
def true_sendall(self, data, *args, **kwargs):
253-
if not MocketMode().is_allowed((self._host, self._port)):
254-
MocketMode.raise_not_allowed()
255-
256-
req = decode_from_bytes(data)
257-
# make request unique again
258-
req_signature = _hash_request(hasher, req)
259-
# port should be always a string
260-
port = str(self._port)
261-
262-
# prepare responses dictionary
263-
responses = {}
264-
265-
if Mocket.get_truesocket_recording_dir():
266-
path = os.path.join(
267-
Mocket.get_truesocket_recording_dir(), Mocket.get_namespace() + ".json"
268-
)
269-
# check if there's already a recorded session dumped to a JSON file
270-
try:
271-
with open(path) as f:
272-
responses = json.load(f)
273-
# if not, create a new dictionary
274-
except (FileNotFoundError, JSONDecodeError):
275-
pass
276-
277-
try:
278-
try:
279-
response_dict = responses[self._host][port][req_signature]
280-
except KeyError:
281-
if hasher is not hashlib.md5:
282-
# Fallback for backwards compatibility
283-
req_signature = _hash_request(hashlib.md5, req)
284-
response_dict = responses[self._host][port][req_signature]
285-
else:
286-
raise
287-
except KeyError:
288-
# preventing next KeyError exceptions
289-
responses.setdefault(self._host, {})
290-
responses[self._host].setdefault(port, {})
291-
responses[self._host][port].setdefault(req_signature, {})
292-
response_dict = responses[self._host][port][req_signature]
293-
294-
# try to get the response from the dictionary
295-
try:
296-
encoded_response = hexload(response_dict["response"])
297-
# if not available, call the real sendall
298-
except KeyError:
299-
host, port = self._host, self._port
300-
host = true_gethostbyname(host)
301-
302-
if isinstance(self.true_socket, true_socket) and self._secure_socket:
303-
self.true_socket = true_urllib3_ssl_wrap_socket(
304-
self.true_socket,
305-
**self.kwargs,
306-
)
307-
308-
with contextlib.suppress(OSError, ValueError):
309-
# already connected
310-
self.true_socket.connect((host, port))
311-
self.true_socket.sendall(data, *args, **kwargs)
312-
encoded_response = b""
313-
# https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12
314-
while True:
315-
more_to_read = select.select([self.true_socket], [], [], 0.1)[0]
316-
if not more_to_read and encoded_response:
317-
break
318-
new_content = self.true_socket.recv(self._buflen)
319-
if not new_content:
320-
break
321-
encoded_response += new_content
322-
323-
# dump the resulting dictionary to a JSON file
324-
if Mocket.get_truesocket_recording_dir():
325-
# update the dictionary with request and response lines
326-
response_dict["request"] = req
327-
response_dict["response"] = hexdump(encoded_response)
328-
329-
with open(path, mode="w") as f:
330-
f.write(
331-
decode_from_bytes(
332-
json.dumps(responses, indent=4, sort_keys=True)
333-
)
334-
)
335-
336-
# response back to .sendall() which writes it to the Mocket socket and flush the BytesIO
337-
return encoded_response
338-
339-
def send(self, data, *args, **kwargs): # pragma: no cover
340-
entry = self.get_entry(data)
341-
if not entry or (entry and self._entry != entry):
342-
kwargs["entry"] = entry
343-
self.sendall(data, *args, **kwargs)
344-
else:
345-
req = Mocket.last_request()
346-
if hasattr(req, "add_data"):
347-
req.add_data(data)
348-
self._entry = entry
349-
return len(data)
350-
351-
def close(self):
352-
if self.true_socket and not self.true_socket._closed:
353-
self.true_socket.close()
354-
self._fd = None
355-
356-
def __getattr__(self, name):
357-
"""Do nothing catchall function, for methods like shutdown()"""
358-
359-
def do_nothing(*args, **kwargs):
360-
pass
361-
362-
return do_nothing
363-
364-
36545
class Mocket:
36646
_socket_pairs = {}
36747
_address = (None, None)

0 commit comments

Comments
 (0)