Skip to content

Commit 89055e8

Browse files
authored
Refactor: introduce state object (#264)
* refactor: move enable- and disable-functions from mocket.mocket to mocket.inject * refactor: Mocket - add typing and get rid of cyclic import
1 parent 2beb8f1 commit 89055e8

File tree

10 files changed

+186
-178
lines changed

10 files changed

+186
-178
lines changed

mocket/entry.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import collections.abc
22

33
from mocket.compat import encode_to_bytes
4+
from mocket.mocket import Mocket
45

56

67
class MocketEntry:
@@ -41,8 +42,6 @@ def can_handle(data):
4142
return True
4243

4344
def collect(self, data):
44-
from mocket import Mocket
45-
4645
req = self.request_cls(data)
4746
Mocket.collect(req)
4847

mocket/inject.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
from __future__ import annotations
2+
3+
import os
4+
import socket
5+
import ssl
6+
7+
import urllib3
8+
from urllib3.connection import match_hostname as urllib3_match_hostname
9+
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket
10+
11+
try:
12+
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
13+
except ImportError:
14+
urllib3_wrap_socket = None
15+
16+
17+
try: # pragma: no cover
18+
from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3
19+
20+
pyopenssl_override = True
21+
except ImportError:
22+
pyopenssl_override = False
23+
24+
true_socket = socket.socket
25+
true_create_connection = socket.create_connection
26+
true_gethostbyname = socket.gethostbyname
27+
true_gethostname = socket.gethostname
28+
true_getaddrinfo = socket.getaddrinfo
29+
true_socketpair = socket.socketpair
30+
true_ssl_wrap_socket = getattr(
31+
ssl, "wrap_socket", None
32+
) # from Py3.12 it's only under SSLContext
33+
true_ssl_socket = ssl.SSLSocket
34+
true_ssl_context = ssl.SSLContext
35+
true_inet_pton = socket.inet_pton
36+
true_urllib3_wrap_socket = urllib3_wrap_socket
37+
true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket
38+
true_urllib3_match_hostname = urllib3_match_hostname
39+
40+
41+
def enable(
42+
namespace: str | None = None,
43+
truesocket_recording_dir: str | None = None,
44+
) -> None:
45+
from mocket.mocket import Mocket
46+
from mocket.socket import MocketSocket, create_connection, socketpair
47+
from mocket.ssl import FakeSSLContext
48+
49+
Mocket._namespace = namespace
50+
Mocket._truesocket_recording_dir = truesocket_recording_dir
51+
52+
if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
53+
# JSON dumps will be saved here
54+
raise AssertionError
55+
56+
socket.socket = socket.__dict__["socket"] = MocketSocket
57+
socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket
58+
socket.SocketType = socket.__dict__["SocketType"] = MocketSocket
59+
socket.create_connection = socket.__dict__["create_connection"] = create_connection
60+
socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost"
61+
socket.gethostbyname = socket.__dict__["gethostbyname"] = lambda host: "127.0.0.1"
62+
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = (
63+
lambda host, port, family=None, socktype=None, proto=None, flags=None: [
64+
(2, 1, 6, "", (host, port))
65+
]
66+
)
67+
socket.socketpair = socket.__dict__["socketpair"] = socketpair
68+
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket
69+
ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext
70+
socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: bytes(
71+
"\x7f\x00\x00\x01", "utf-8"
72+
)
73+
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
74+
FakeSSLContext.wrap_socket
75+
)
76+
urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
77+
"ssl_wrap_socket"
78+
] = FakeSSLContext.wrap_socket
79+
urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
80+
FakeSSLContext.wrap_socket
81+
)
82+
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
83+
"ssl_wrap_socket"
84+
] = FakeSSLContext.wrap_socket
85+
urllib3.connection.match_hostname = urllib3.connection.__dict__[
86+
"match_hostname"
87+
] = lambda *args: None
88+
if pyopenssl_override: # pragma: no cover
89+
# Take out the pyopenssl version - use the default implementation
90+
extract_from_urllib3()
91+
92+
93+
def disable() -> None:
94+
from mocket.mocket import Mocket
95+
96+
socket.socket = socket.__dict__["socket"] = true_socket
97+
socket._socketobject = socket.__dict__["_socketobject"] = true_socket
98+
socket.SocketType = socket.__dict__["SocketType"] = true_socket
99+
socket.create_connection = socket.__dict__["create_connection"] = (
100+
true_create_connection
101+
)
102+
socket.gethostname = socket.__dict__["gethostname"] = true_gethostname
103+
socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname
104+
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo
105+
socket.socketpair = socket.__dict__["socketpair"] = true_socketpair
106+
if true_ssl_wrap_socket:
107+
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket
108+
ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context
109+
socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton
110+
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
111+
true_urllib3_wrap_socket
112+
)
113+
urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
114+
"ssl_wrap_socket"
115+
] = true_urllib3_ssl_wrap_socket
116+
urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
117+
true_urllib3_ssl_wrap_socket
118+
)
119+
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
120+
"ssl_wrap_socket"
121+
] = true_urllib3_ssl_wrap_socket
122+
urllib3.connection.match_hostname = urllib3.connection.__dict__[
123+
"match_hostname"
124+
] = true_urllib3_match_hostname
125+
Mocket.reset()
126+
if pyopenssl_override: # pragma: no cover
127+
# Put the pyopenssl version back in place
128+
inject_into_urllib3()

mocket/io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import io
22
import os
33

4+
from mocket.mocket import Mocket
5+
46

57
class MocketSocketCore(io.BytesIO):
68
def __init__(self, address) -> None:
79
self._address = address
810
super().__init__()
911

1012
def write(self, content):
11-
from mocket import Mocket
12-
1313
super().write(content)
1414

1515
_, w_fd = Mocket.get_pair(self._address)

0 commit comments

Comments
 (0)