|
1 | 1 | import collections
|
2 |
| -import contextlib |
3 |
| -import errno |
4 |
| -import hashlib |
5 | 2 | import itertools
|
6 |
| -import json |
7 | 3 | import os
|
8 |
| -import select |
9 | 4 | import socket
|
10 | 5 | import ssl
|
11 |
| -from datetime import datetime, timedelta |
12 |
| -from json.decoder import JSONDecodeError |
13 | 6 | from typing import Optional, Tuple
|
14 | 7 |
|
15 | 8 | import urllib3
|
|
22 | 15 | urllib3_wrap_socket = None
|
23 | 16 |
|
24 | 17 |
|
25 |
| -from mocket.compat import decode_from_bytes, encode_to_bytes |
26 |
| -from mocket.io import MocketSocketCore |
27 |
| -from mocket.mode import MocketMode |
| 18 | +from mocket.socket import MocketSocket, create_connection, socketpair |
28 | 19 | from mocket.ssl import FakeSSLContext
|
29 |
| -from mocket.utils import hexdump, hexload |
30 |
| - |
31 |
| -xxh32 = None |
32 |
| -try: |
33 |
| - from xxhash import xxh32 |
34 |
| -except ImportError: # pragma: no cover |
35 |
| - with contextlib.suppress(ImportError): |
36 |
| - from xxhash_cffi import xxh32 |
37 |
| -hasher = xxh32 or hashlib.md5 |
38 | 20 |
|
39 | 21 | try: # pragma: no cover
|
40 | 22 | from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3
|
|
60 | 42 | true_urllib3_match_hostname = urllib3_match_hostname
|
61 | 43 |
|
62 | 44 |
|
63 |
| -def create_connection(address, timeout=None, source_address=None): |
64 |
| - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) |
65 |
| - if timeout: |
66 |
| - s.settimeout(timeout) |
67 |
| - s.connect(address) |
68 |
| - return s |
69 |
| - |
70 |
| - |
71 |
| -def socketpair(*args, **kwargs): |
72 |
| - """Returns a real socketpair() used by asyncio loop for supporting calls made by fastapi and similar services.""" |
73 |
| - import _socket |
74 |
| - |
75 |
| - return _socket.socketpair(*args, **kwargs) |
76 |
| - |
77 |
| - |
78 |
| -def _hash_request(h, req): |
79 |
| - return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest() |
80 |
| - |
81 |
| - |
82 |
| -class MocketSocket: |
83 |
| - timeout = None |
84 |
| - _fd = None |
85 |
| - family = None |
86 |
| - type = None |
87 |
| - proto = None |
88 |
| - _host = None |
89 |
| - _port = None |
90 |
| - _address = None |
91 |
| - cipher = lambda s: ("ADH", "AES256", "SHA") |
92 |
| - compression = lambda s: ssl.OP_NO_COMPRESSION |
93 |
| - _mode = None |
94 |
| - _bufsize = None |
95 |
| - _secure_socket = False |
96 |
| - _did_handshake = False |
97 |
| - _sent_non_empty_bytes = False |
98 |
| - _io = None |
99 |
| - |
100 |
| - def __init__( |
101 |
| - self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs |
102 |
| - ): |
103 |
| - self.true_socket = true_socket(family, type, proto) |
104 |
| - self._buflen = 65536 |
105 |
| - self._entry = None |
106 |
| - self.family = int(family) |
107 |
| - self.type = int(type) |
108 |
| - self.proto = int(proto) |
109 |
| - self._truesocket_recording_dir = None |
110 |
| - self.kwargs = kwargs |
111 |
| - |
112 |
| - def __str__(self): |
113 |
| - return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})" |
114 |
| - |
115 |
| - def __enter__(self): |
116 |
| - return self |
117 |
| - |
118 |
| - def __exit__(self, exc_type, exc_val, exc_tb): |
119 |
| - self.close() |
120 |
| - |
121 |
| - @property |
122 |
| - def io(self): |
123 |
| - if self._io is None: |
124 |
| - self._io = MocketSocketCore((self._host, self._port)) |
125 |
| - return self._io |
126 |
| - |
127 |
| - def fileno(self): |
128 |
| - address = (self._host, self._port) |
129 |
| - r_fd, _ = Mocket.get_pair(address) |
130 |
| - if not r_fd: |
131 |
| - r_fd, w_fd = os.pipe() |
132 |
| - Mocket.set_pair(address, (r_fd, w_fd)) |
133 |
| - return r_fd |
134 |
| - |
135 |
| - def gettimeout(self): |
136 |
| - return self.timeout |
137 |
| - |
138 |
| - def setsockopt(self, family, type, proto): |
139 |
| - self.family = family |
140 |
| - self.type = type |
141 |
| - self.proto = proto |
142 |
| - |
143 |
| - if self.true_socket: |
144 |
| - self.true_socket.setsockopt(family, type, proto) |
145 |
| - |
146 |
| - def settimeout(self, timeout): |
147 |
| - self.timeout = timeout |
148 |
| - |
149 |
| - @staticmethod |
150 |
| - def getsockopt(level, optname, buflen=None): |
151 |
| - return socket.SOCK_STREAM |
152 |
| - |
153 |
| - def do_handshake(self): |
154 |
| - self._did_handshake = True |
155 |
| - |
156 |
| - def getpeername(self): |
157 |
| - return self._address |
158 |
| - |
159 |
| - def setblocking(self, block): |
160 |
| - self.settimeout(None) if block else self.settimeout(0.0) |
161 |
| - |
162 |
| - def getblocking(self): |
163 |
| - return self.gettimeout() is None |
164 |
| - |
165 |
| - def getsockname(self): |
166 |
| - return socket.gethostbyname(self._address[0]), self._address[1] |
167 |
| - |
168 |
| - def getpeercert(self, *args, **kwargs): |
169 |
| - if not (self._host and self._port): |
170 |
| - self._address = self._host, self._port = Mocket._address |
171 |
| - |
172 |
| - now = datetime.now() |
173 |
| - shift = now + timedelta(days=30 * 12) |
174 |
| - return { |
175 |
| - "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), |
176 |
| - "subjectAltName": ( |
177 |
| - ("DNS", f"*.{self._host}"), |
178 |
| - ("DNS", self._host), |
179 |
| - ("DNS", "*"), |
180 |
| - ), |
181 |
| - "subject": ( |
182 |
| - (("organizationName", f"*.{self._host}"),), |
183 |
| - (("organizationalUnitName", "Domain Control Validated"),), |
184 |
| - (("commonName", f"*.{self._host}"),), |
185 |
| - ), |
186 |
| - } |
187 |
| - |
188 |
| - def unwrap(self): |
189 |
| - return self |
190 |
| - |
191 |
| - def write(self, data): |
192 |
| - return self.send(encode_to_bytes(data)) |
193 |
| - |
194 |
| - def connect(self, address): |
195 |
| - self._address = self._host, self._port = address |
196 |
| - Mocket._address = address |
197 |
| - |
198 |
| - def makefile(self, mode="r", bufsize=-1): |
199 |
| - self._mode = mode |
200 |
| - self._bufsize = bufsize |
201 |
| - return self.io |
202 |
| - |
203 |
| - def get_entry(self, data): |
204 |
| - return Mocket.get_entry(self._host, self._port, data) |
205 |
| - |
206 |
| - def sendall(self, data, entry=None, *args, **kwargs): |
207 |
| - if entry is None: |
208 |
| - entry = self.get_entry(data) |
209 |
| - |
210 |
| - if entry: |
211 |
| - consume_response = entry.collect(data) |
212 |
| - response = entry.get_response() if consume_response is not False else None |
213 |
| - else: |
214 |
| - response = self.true_sendall(data, *args, **kwargs) |
215 |
| - |
216 |
| - if response is not None: |
217 |
| - self.io.seek(0) |
218 |
| - self.io.write(response) |
219 |
| - self.io.truncate() |
220 |
| - self.io.seek(0) |
221 |
| - |
222 |
| - def read(self, buffersize): |
223 |
| - rv = self.io.read(buffersize) |
224 |
| - if rv: |
225 |
| - self._sent_non_empty_bytes = True |
226 |
| - if self._did_handshake and not self._sent_non_empty_bytes: |
227 |
| - raise ssl.SSLWantReadError("The operation did not complete (read)") |
228 |
| - return rv |
229 |
| - |
230 |
| - def recv_into(self, buffer, buffersize=None, flags=None): |
231 |
| - if hasattr(buffer, "write"): |
232 |
| - return buffer.write(self.read(buffersize)) |
233 |
| - # buffer is a memoryview |
234 |
| - data = self.read(buffersize) |
235 |
| - if data: |
236 |
| - buffer[: len(data)] = data |
237 |
| - return len(data) |
238 |
| - |
239 |
| - def recv(self, buffersize, flags=None): |
240 |
| - r_fd, _ = Mocket.get_pair((self._host, self._port)) |
241 |
| - if r_fd: |
242 |
| - return os.read(r_fd, buffersize) |
243 |
| - data = self.read(buffersize) |
244 |
| - if data: |
245 |
| - return data |
246 |
| - # used by Redis mock |
247 |
| - exc = BlockingIOError() |
248 |
| - exc.errno = errno.EWOULDBLOCK |
249 |
| - exc.args = (0,) |
250 |
| - raise exc |
251 |
| - |
252 |
| - def true_sendall(self, data, *args, **kwargs): |
253 |
| - if not MocketMode().is_allowed((self._host, self._port)): |
254 |
| - MocketMode.raise_not_allowed() |
255 |
| - |
256 |
| - req = decode_from_bytes(data) |
257 |
| - # make request unique again |
258 |
| - req_signature = _hash_request(hasher, req) |
259 |
| - # port should be always a string |
260 |
| - port = str(self._port) |
261 |
| - |
262 |
| - # prepare responses dictionary |
263 |
| - responses = {} |
264 |
| - |
265 |
| - if Mocket.get_truesocket_recording_dir(): |
266 |
| - path = os.path.join( |
267 |
| - Mocket.get_truesocket_recording_dir(), Mocket.get_namespace() + ".json" |
268 |
| - ) |
269 |
| - # check if there's already a recorded session dumped to a JSON file |
270 |
| - try: |
271 |
| - with open(path) as f: |
272 |
| - responses = json.load(f) |
273 |
| - # if not, create a new dictionary |
274 |
| - except (FileNotFoundError, JSONDecodeError): |
275 |
| - pass |
276 |
| - |
277 |
| - try: |
278 |
| - try: |
279 |
| - response_dict = responses[self._host][port][req_signature] |
280 |
| - except KeyError: |
281 |
| - if hasher is not hashlib.md5: |
282 |
| - # Fallback for backwards compatibility |
283 |
| - req_signature = _hash_request(hashlib.md5, req) |
284 |
| - response_dict = responses[self._host][port][req_signature] |
285 |
| - else: |
286 |
| - raise |
287 |
| - except KeyError: |
288 |
| - # preventing next KeyError exceptions |
289 |
| - responses.setdefault(self._host, {}) |
290 |
| - responses[self._host].setdefault(port, {}) |
291 |
| - responses[self._host][port].setdefault(req_signature, {}) |
292 |
| - response_dict = responses[self._host][port][req_signature] |
293 |
| - |
294 |
| - # try to get the response from the dictionary |
295 |
| - try: |
296 |
| - encoded_response = hexload(response_dict["response"]) |
297 |
| - # if not available, call the real sendall |
298 |
| - except KeyError: |
299 |
| - host, port = self._host, self._port |
300 |
| - host = true_gethostbyname(host) |
301 |
| - |
302 |
| - if isinstance(self.true_socket, true_socket) and self._secure_socket: |
303 |
| - self.true_socket = true_urllib3_ssl_wrap_socket( |
304 |
| - self.true_socket, |
305 |
| - **self.kwargs, |
306 |
| - ) |
307 |
| - |
308 |
| - with contextlib.suppress(OSError, ValueError): |
309 |
| - # already connected |
310 |
| - self.true_socket.connect((host, port)) |
311 |
| - self.true_socket.sendall(data, *args, **kwargs) |
312 |
| - encoded_response = b"" |
313 |
| - # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12 |
314 |
| - while True: |
315 |
| - more_to_read = select.select([self.true_socket], [], [], 0.1)[0] |
316 |
| - if not more_to_read and encoded_response: |
317 |
| - break |
318 |
| - new_content = self.true_socket.recv(self._buflen) |
319 |
| - if not new_content: |
320 |
| - break |
321 |
| - encoded_response += new_content |
322 |
| - |
323 |
| - # dump the resulting dictionary to a JSON file |
324 |
| - if Mocket.get_truesocket_recording_dir(): |
325 |
| - # update the dictionary with request and response lines |
326 |
| - response_dict["request"] = req |
327 |
| - response_dict["response"] = hexdump(encoded_response) |
328 |
| - |
329 |
| - with open(path, mode="w") as f: |
330 |
| - f.write( |
331 |
| - decode_from_bytes( |
332 |
| - json.dumps(responses, indent=4, sort_keys=True) |
333 |
| - ) |
334 |
| - ) |
335 |
| - |
336 |
| - # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO |
337 |
| - return encoded_response |
338 |
| - |
339 |
| - def send(self, data, *args, **kwargs): # pragma: no cover |
340 |
| - entry = self.get_entry(data) |
341 |
| - if not entry or (entry and self._entry != entry): |
342 |
| - kwargs["entry"] = entry |
343 |
| - self.sendall(data, *args, **kwargs) |
344 |
| - else: |
345 |
| - req = Mocket.last_request() |
346 |
| - if hasattr(req, "add_data"): |
347 |
| - req.add_data(data) |
348 |
| - self._entry = entry |
349 |
| - return len(data) |
350 |
| - |
351 |
| - def close(self): |
352 |
| - if self.true_socket and not self.true_socket._closed: |
353 |
| - self.true_socket.close() |
354 |
| - self._fd = None |
355 |
| - |
356 |
| - def __getattr__(self, name): |
357 |
| - """Do nothing catchall function, for methods like shutdown()""" |
358 |
| - |
359 |
| - def do_nothing(*args, **kwargs): |
360 |
| - pass |
361 |
| - |
362 |
| - return do_nothing |
363 |
| - |
364 |
| - |
365 | 45 | class Mocket:
|
366 | 46 | _socket_pairs = {}
|
367 | 47 | _address = (None, None)
|
|
0 commit comments