Skip to content

Commit a5b5e34

Browse files
authored
improve injection code, make backwards compat explicit, make ssl-api explicit (#268)
* refactor: make injection code more readable and make backwards-compat explicit * refactor: move ssl socket-wrapping code to ssl/socket.py * refactor: convert MocketSSLContext.wrap_socket and wrap_bio to instance-methods * refactor: MocketSSLSocket use proper ssl-context instead of urllib3
1 parent 0da2722 commit a5b5e34

File tree

7 files changed

+132
-146
lines changed

7 files changed

+132
-146
lines changed

mocket/inject.py

Lines changed: 60 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,32 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import os
45
import socket
56
import ssl
7+
from types import ModuleType
8+
from typing import Any
69

710
import urllib3
811

9-
try: # pragma: no cover
10-
from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3
12+
_patches_restore: dict[tuple[ModuleType, str], Any] = {}
1113

12-
pyopenssl_override = True
13-
except ImportError:
14-
pyopenssl_override = False
14+
15+
def _patch(module: ModuleType, name: str, patched_value: Any) -> None:
16+
with contextlib.suppress(KeyError):
17+
original_value, module.__dict__[name] = module.__dict__[name], patched_value
18+
_patches_restore[(module, name)] = original_value
19+
20+
21+
def _restore(module: ModuleType, name: str) -> None:
22+
if original_value := _patches_restore.pop((module, name)):
23+
module.__dict__[name] = original_value
1524

1625

1726
def enable(
1827
namespace: str | None = None,
1928
truesocket_recording_dir: str | None = None,
2029
) -> None:
21-
from mocket.mocket import Mocket
2230
from mocket.socket import (
2331
MocketSocket,
2432
mock_create_connection,
@@ -27,99 +35,62 @@ def enable(
2735
mock_gethostname,
2836
mock_inet_pton,
2937
mock_socketpair,
30-
mock_urllib3_match_hostname,
3138
)
32-
from mocket.ssl.context import MocketSSLContext
39+
from mocket.ssl.context import MocketSSLContext, mock_wrap_socket
40+
from mocket.urllib3 import (
41+
mock_match_hostname as mock_urllib3_match_hostname,
42+
)
43+
from mocket.urllib3 import (
44+
mock_ssl_wrap_socket as mock_urllib3_ssl_wrap_socket,
45+
)
46+
47+
patches = {
48+
# stdlib: socket
49+
(socket, "socket"): MocketSocket,
50+
(socket, "create_connection"): mock_create_connection,
51+
(socket, "getaddrinfo"): mock_getaddrinfo,
52+
(socket, "gethostbyname"): mock_gethostbyname,
53+
(socket, "gethostname"): mock_gethostname,
54+
(socket, "inet_pton"): mock_inet_pton,
55+
(socket, "SocketType"): MocketSocket,
56+
(socket, "socketpair"): mock_socketpair,
57+
# stdlib: ssl
58+
(ssl, "SSLContext"): MocketSSLContext,
59+
(ssl, "wrap_socket"): mock_wrap_socket, # python < 3.12.0
60+
# urllib3
61+
(urllib3.connection, "match_hostname"): mock_urllib3_match_hostname,
62+
(urllib3.connection, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket,
63+
(urllib3.util, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket,
64+
(urllib3.util.ssl_, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket,
65+
(urllib3.util.ssl_, "wrap_socket"): mock_urllib3_ssl_wrap_socket, # urllib3 < 2
66+
}
67+
68+
for (module, name), new_value in patches.items():
69+
_patch(module, name, new_value)
70+
71+
with contextlib.suppress(ImportError):
72+
from urllib3.contrib.pyopenssl import extract_from_urllib3
73+
74+
extract_from_urllib3()
75+
76+
from mocket.mocket import Mocket
3377

3478
Mocket._namespace = namespace
3579
Mocket._truesocket_recording_dir = truesocket_recording_dir
36-
3780
if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
3881
# JSON dumps will be saved here
3982
raise AssertionError
4083

41-
socket.socket = socket.__dict__["socket"] = MocketSocket
42-
socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket
43-
socket.SocketType = socket.__dict__["SocketType"] = MocketSocket
44-
socket.create_connection = socket.__dict__["create_connection"] = (
45-
mock_create_connection
46-
)
47-
socket.gethostname = socket.__dict__["gethostname"] = mock_gethostname
48-
socket.gethostbyname = socket.__dict__["gethostbyname"] = mock_gethostbyname
49-
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = mock_getaddrinfo
50-
socket.socketpair = socket.__dict__["socketpair"] = mock_socketpair
51-
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = MocketSSLContext.wrap_socket
52-
ssl.SSLContext = ssl.__dict__["SSLContext"] = MocketSSLContext
53-
socket.inet_pton = socket.__dict__["inet_pton"] = mock_inet_pton
54-
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
55-
MocketSSLContext.wrap_socket
56-
)
57-
urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
58-
"ssl_wrap_socket"
59-
] = MocketSSLContext.wrap_socket
60-
urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
61-
MocketSSLContext.wrap_socket
62-
)
63-
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
64-
"ssl_wrap_socket"
65-
] = MocketSSLContext.wrap_socket
66-
urllib3.connection.match_hostname = urllib3.connection.__dict__[
67-
"match_hostname"
68-
] = mock_urllib3_match_hostname
69-
if pyopenssl_override: # pragma: no cover
70-
# Take out the pyopenssl version - use the default implementation
71-
extract_from_urllib3()
72-
7384

7485
def disable() -> None:
86+
for module, name in list(_patches_restore.keys()):
87+
_restore(module, name)
88+
89+
with contextlib.suppress(ImportError):
90+
from urllib3.contrib.pyopenssl import inject_into_urllib3
91+
92+
inject_into_urllib3()
93+
7594
from mocket.mocket import Mocket
76-
from mocket.socket import (
77-
true_create_connection,
78-
true_getaddrinfo,
79-
true_gethostbyname,
80-
true_gethostname,
81-
true_inet_pton,
82-
true_socket,
83-
true_socketpair,
84-
true_urllib3_match_hostname,
85-
)
86-
from mocket.ssl.context import (
87-
true_ssl_context,
88-
true_ssl_wrap_socket,
89-
true_urllib3_ssl_wrap_socket,
90-
true_urllib3_wrap_socket,
91-
)
9295

93-
socket.socket = socket.__dict__["socket"] = true_socket
94-
socket._socketobject = socket.__dict__["_socketobject"] = true_socket
95-
socket.SocketType = socket.__dict__["SocketType"] = true_socket
96-
socket.create_connection = socket.__dict__["create_connection"] = (
97-
true_create_connection
98-
)
99-
socket.gethostname = socket.__dict__["gethostname"] = true_gethostname
100-
socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname
101-
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo
102-
socket.socketpair = socket.__dict__["socketpair"] = true_socketpair
103-
if true_ssl_wrap_socket:
104-
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket
105-
ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context
106-
socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton
107-
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
108-
true_urllib3_wrap_socket
109-
)
110-
urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
111-
"ssl_wrap_socket"
112-
] = true_urllib3_ssl_wrap_socket
113-
urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
114-
true_urllib3_ssl_wrap_socket
115-
)
116-
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
117-
"ssl_wrap_socket"
118-
] = true_urllib3_ssl_wrap_socket
119-
urllib3.connection.match_hostname = urllib3.connection.__dict__[
120-
"match_hostname"
121-
] = true_urllib3_match_hostname
12296
Mocket.reset()
123-
if pyopenssl_override: # pragma: no cover
124-
# Put the pyopenssl version back in place
125-
inject_into_urllib3()

mocket/socket.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from types import TracebackType
1212
from typing import Any, Type
1313

14-
import urllib3.connection
1514
from typing_extensions import Self
1615

1716
from mocket.compat import decode_from_bytes, encode_to_bytes
@@ -27,14 +26,8 @@
2726
)
2827
from mocket.utils import hexdump, hexload
2928

30-
true_create_connection = socket.create_connection
31-
true_getaddrinfo = socket.getaddrinfo
3229
true_gethostbyname = socket.gethostbyname
33-
true_gethostname = socket.gethostname
34-
true_inet_pton = socket.inet_pton
3530
true_socket = socket.socket
36-
true_socketpair = socket.socketpair
37-
true_urllib3_match_hostname = urllib3.connection.match_hostname
3831

3932

4033
xxh32 = None
@@ -84,10 +77,6 @@ def mock_socketpair(*args, **kwargs):
8477
return _socket.socketpair(*args, **kwargs)
8578

8679

87-
def mock_urllib3_match_hostname(*args: Any) -> None:
88-
return None
89-
90-
9180
def _hash_request(h, req):
9281
return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest()
9382

mocket/ssl/context.py

Lines changed: 17 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,10 @@
11
from __future__ import annotations
22

3-
import contextlib
4-
import ssl
53
from typing import Any
64

7-
import urllib3.util.ssl_
8-
95
from mocket.socket import MocketSocket
106
from mocket.ssl.socket import MocketSSLSocket
117

12-
true_ssl_context = ssl.SSLContext
13-
14-
true_ssl_wrap_socket = None
15-
true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket
16-
true_urllib3_wrap_socket = None
17-
18-
with contextlib.suppress(ImportError):
19-
# from Py3.12 it's only under SSLContext
20-
from ssl import wrap_socket as ssl_wrap_socket
21-
22-
true_ssl_wrap_socket = ssl_wrap_socket
23-
24-
with contextlib.suppress(ImportError):
25-
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
26-
27-
true_urllib3_wrap_socket = urllib3_wrap_socket
28-
298

309
class _MocketSSLContext:
3110
"""For Python 3.6 and newer."""
@@ -70,30 +49,16 @@ def dummy_method(*args: Any, **kwargs: Any) -> Any:
7049
for m in self.DUMMY_METHODS:
7150
setattr(self, m, dummy_method)
7251

73-
@staticmethod
74-
def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSSLSocket:
75-
ssl_socket = MocketSSLSocket()
76-
ssl_socket._original_socket = sock
77-
78-
ssl_socket._true_socket = true_urllib3_ssl_wrap_socket(
79-
sock._true_socket,
80-
**kwargs,
81-
)
82-
ssl_socket._kwargs = kwargs
83-
84-
ssl_socket._timeout = sock._timeout
85-
86-
ssl_socket._host = sock._host
87-
ssl_socket._port = sock._port
88-
ssl_socket._address = sock._address
89-
90-
ssl_socket._io = sock._io
91-
ssl_socket._entry = sock._entry
92-
93-
return ssl_socket
52+
def wrap_socket(
53+
self,
54+
sock: MocketSocket,
55+
*args: Any,
56+
**kwargs: Any,
57+
) -> MocketSSLSocket:
58+
return MocketSSLSocket._create(sock, *args, **kwargs)
9459

95-
@staticmethod
9660
def wrap_bio(
61+
self,
9762
incoming: Any, # _ssl.MemoryBIO
9863
outgoing: Any, # _ssl.MemoryBIO
9964
server_side: bool = False,
@@ -102,3 +67,12 @@ def wrap_bio(
10267
ssl_obj = MocketSSLSocket()
10368
ssl_obj._host = server_hostname
10469
return ssl_obj
70+
71+
72+
def mock_wrap_socket(
73+
sock: MocketSocket,
74+
*args: Any,
75+
**kwargs: Any,
76+
) -> MocketSSLSocket:
77+
context = MocketSSLContext()
78+
return context.wrap_socket(sock, *args, **kwargs)

mocket/ssl/socket.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,35 @@ def compression(self) -> str | None:
6060

6161
def unwrap(self) -> MocketSocket:
6262
return self._original_socket
63+
64+
@classmethod
65+
def _create(
66+
cls,
67+
sock: MocketSocket,
68+
ssl_context: ssl.SSLContext | None = None,
69+
server_hostname: str | None = None,
70+
*args: Any,
71+
**kwargs: Any,
72+
) -> MocketSSLSocket:
73+
ssl_socket = MocketSSLSocket()
74+
ssl_socket._original_socket = sock
75+
ssl_socket._true_socket = sock._true_socket
76+
77+
if ssl_context:
78+
ssl_socket._true_socket = ssl_context.wrap_socket(
79+
sock=ssl_socket._true_socket,
80+
server_hostname=server_hostname,
81+
)
82+
83+
ssl_socket._kwargs = kwargs
84+
85+
ssl_socket._timeout = sock._timeout
86+
87+
ssl_socket._host = sock._host
88+
ssl_socket._port = sock._port
89+
ssl_socket._address = sock._address
90+
91+
ssl_socket._io = sock._io
92+
ssl_socket._entry = sock._entry
93+
94+
return ssl_socket

mocket/urllib3.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
from mocket.socket import MocketSocket
6+
from mocket.ssl.context import MocketSSLContext
7+
from mocket.ssl.socket import MocketSSLSocket
8+
9+
10+
def mock_match_hostname(*args: Any) -> None:
11+
return None
12+
13+
14+
def mock_ssl_wrap_socket(
15+
sock: MocketSocket,
16+
*args: Any,
17+
**kwargs: Any,
18+
) -> MocketSSLSocket:
19+
context = MocketSSLContext()
20+
return context.wrap_socket(sock, *args, **kwargs)

mocket/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ def get_mocketize(wrapper_: Callable) -> Callable:
4545

4646

4747
__all__ = (
48-
"MocketSocketCore",
4948
"MocketMode",
49+
"MocketSocketCore",
5050
"SSL_PROTOCOL",
51+
"get_mocketize",
5152
"hexdump",
5253
"hexload",
53-
"get_mocketize",
5454
)

tests/test_mode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from mocket import Mocketizer, mocketize
55
from mocket.exceptions import StrictMocketException
66
from mocket.mockhttp import Entry, Response
7-
from mocket.utils import MocketMode
7+
from mocket.mode import MocketMode
88

99

1010
@mocketize(strict_mode=True)

0 commit comments

Comments
 (0)