From df2b381263c01c9732fd9cc6b8e04a82f1524a3f Mon Sep 17 00:00:00 2001 From: betaboon Date: Sat, 23 Nov 2024 18:39:20 +0100 Subject: [PATCH 1/4] refactor: make injection code more readable and make backwards-compat explicit --- mocket/inject.py | 147 +++++++++++++++++------------------------- mocket/socket.py | 11 ---- mocket/ssl/context.py | 25 +------ mocket/urllib3.py | 23 +++++++ mocket/utils.py | 4 +- tests/test_mode.py | 2 +- 6 files changed, 88 insertions(+), 124 deletions(-) create mode 100644 mocket/urllib3.py diff --git a/mocket/inject.py b/mocket/inject.py index 35e9da01..4a9bb5ee 100644 --- a/mocket/inject.py +++ b/mocket/inject.py @@ -1,24 +1,32 @@ from __future__ import annotations +import contextlib import os import socket import ssl +from types import ModuleType +from typing import Any import urllib3 -try: # pragma: no cover - from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3 +_patches_restore: dict[tuple[ModuleType, str], Any] = {} - pyopenssl_override = True -except ImportError: - pyopenssl_override = False + +def _patch(module: ModuleType, name: str, patched_value: Any) -> None: + with contextlib.suppress(KeyError): + original_value, module.__dict__[name] = module.__dict__[name], patched_value + _patches_restore[(module, name)] = original_value + + +def _restore(module: ModuleType, name: str) -> None: + if original_value := _patches_restore.pop((module, name)): + module.__dict__[name] = original_value def enable( namespace: str | None = None, truesocket_recording_dir: str | None = None, ) -> None: - from mocket.mocket import Mocket from mocket.socket import ( MocketSocket, mock_create_connection, @@ -27,99 +35,62 @@ def enable( mock_gethostname, mock_inet_pton, mock_socketpair, - mock_urllib3_match_hostname, ) from mocket.ssl.context import MocketSSLContext + from mocket.urllib3 import ( + mock_match_hostname as mock_urllib3_match_hostname, + ) + from mocket.urllib3 import ( + mock_ssl_wrap_socket as mock_urllib3_ssl_wrap_socket, + ) + + patches = { + # stdlib: socket + (socket, "socket"): MocketSocket, + (socket, "create_connection"): mock_create_connection, + (socket, "getaddrinfo"): mock_getaddrinfo, + (socket, "gethostbyname"): mock_gethostbyname, + (socket, "gethostname"): mock_gethostname, + (socket, "inet_pton"): mock_inet_pton, + (socket, "SocketType"): MocketSocket, + (socket, "socketpair"): mock_socketpair, + # stdlib: ssl + (ssl, "SSLContext"): MocketSSLContext, + (ssl, "wrap_socket"): MocketSSLContext.wrap_socket, # python < 3.12.0 + # urllib3 + (urllib3.connection, "match_hostname"): mock_urllib3_match_hostname, + (urllib3.connection, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket, + (urllib3.util, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket, + (urllib3.util.ssl_, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket, + (urllib3.util.ssl_, "wrap_socket"): mock_urllib3_ssl_wrap_socket, # urllib3 < 2 + } + + for (module, name), new_value in patches.items(): + _patch(module, name, new_value) + + with contextlib.suppress(ImportError): + from urllib3.contrib.pyopenssl import extract_from_urllib3 + + extract_from_urllib3() + + from mocket.mocket import Mocket Mocket._namespace = namespace Mocket._truesocket_recording_dir = truesocket_recording_dir - if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir): # JSON dumps will be saved here raise AssertionError - socket.socket = socket.__dict__["socket"] = MocketSocket - socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket - socket.SocketType = socket.__dict__["SocketType"] = MocketSocket - socket.create_connection = socket.__dict__["create_connection"] = ( - mock_create_connection - ) - socket.gethostname = socket.__dict__["gethostname"] = mock_gethostname - socket.gethostbyname = socket.__dict__["gethostbyname"] = mock_gethostbyname - socket.getaddrinfo = socket.__dict__["getaddrinfo"] = mock_getaddrinfo - socket.socketpair = socket.__dict__["socketpair"] = mock_socketpair - ssl.wrap_socket = ssl.__dict__["wrap_socket"] = MocketSSLContext.wrap_socket - ssl.SSLContext = ssl.__dict__["SSLContext"] = MocketSSLContext - socket.inet_pton = socket.__dict__["inet_pton"] = mock_inet_pton - urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( - MocketSSLContext.wrap_socket - ) - urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ - "ssl_wrap_socket" - ] = MocketSSLContext.wrap_socket - urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( - MocketSSLContext.wrap_socket - ) - urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ - "ssl_wrap_socket" - ] = MocketSSLContext.wrap_socket - urllib3.connection.match_hostname = urllib3.connection.__dict__[ - "match_hostname" - ] = mock_urllib3_match_hostname - if pyopenssl_override: # pragma: no cover - # Take out the pyopenssl version - use the default implementation - extract_from_urllib3() - def disable() -> None: + for module, name in list(_patches_restore.keys()): + _restore(module, name) + + with contextlib.suppress(ImportError): + from urllib3.contrib.pyopenssl import inject_into_urllib3 + + inject_into_urllib3() + from mocket.mocket import Mocket - from mocket.socket import ( - true_create_connection, - true_getaddrinfo, - true_gethostbyname, - true_gethostname, - true_inet_pton, - true_socket, - true_socketpair, - true_urllib3_match_hostname, - ) - from mocket.ssl.context import ( - true_ssl_context, - true_ssl_wrap_socket, - true_urllib3_ssl_wrap_socket, - true_urllib3_wrap_socket, - ) - socket.socket = socket.__dict__["socket"] = true_socket - socket._socketobject = socket.__dict__["_socketobject"] = true_socket - socket.SocketType = socket.__dict__["SocketType"] = true_socket - socket.create_connection = socket.__dict__["create_connection"] = ( - true_create_connection - ) - socket.gethostname = socket.__dict__["gethostname"] = true_gethostname - socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname - socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo - socket.socketpair = socket.__dict__["socketpair"] = true_socketpair - if true_ssl_wrap_socket: - ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket - ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context - socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton - urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( - true_urllib3_wrap_socket - ) - urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ - "ssl_wrap_socket" - ] = true_urllib3_ssl_wrap_socket - urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( - true_urllib3_ssl_wrap_socket - ) - urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ - "ssl_wrap_socket" - ] = true_urllib3_ssl_wrap_socket - urllib3.connection.match_hostname = urllib3.connection.__dict__[ - "match_hostname" - ] = true_urllib3_match_hostname Mocket.reset() - if pyopenssl_override: # pragma: no cover - # Put the pyopenssl version back in place - inject_into_urllib3() diff --git a/mocket/socket.py b/mocket/socket.py index 03c6f7e5..9480d365 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -11,7 +11,6 @@ from types import TracebackType from typing import Any, Type -import urllib3.connection from typing_extensions import Self from mocket.compat import decode_from_bytes, encode_to_bytes @@ -27,14 +26,8 @@ ) from mocket.utils import hexdump, hexload -true_create_connection = socket.create_connection -true_getaddrinfo = socket.getaddrinfo true_gethostbyname = socket.gethostbyname -true_gethostname = socket.gethostname -true_inet_pton = socket.inet_pton true_socket = socket.socket -true_socketpair = socket.socketpair -true_urllib3_match_hostname = urllib3.connection.match_hostname xxh32 = None @@ -84,10 +77,6 @@ def mock_socketpair(*args, **kwargs): return _socket.socketpair(*args, **kwargs) -def mock_urllib3_match_hostname(*args: Any) -> None: - return None - - def _hash_request(h, req): return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest() diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py index 438faa10..12d2eb2d 100644 --- a/mocket/ssl/context.py +++ b/mocket/ssl/context.py @@ -1,31 +1,10 @@ from __future__ import annotations -import contextlib -import ssl from typing import Any -import urllib3.util.ssl_ - from mocket.socket import MocketSocket from mocket.ssl.socket import MocketSSLSocket -true_ssl_context = ssl.SSLContext - -true_ssl_wrap_socket = None -true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket -true_urllib3_wrap_socket = None - -with contextlib.suppress(ImportError): - # from Py3.12 it's only under SSLContext - from ssl import wrap_socket as ssl_wrap_socket - - true_ssl_wrap_socket = ssl_wrap_socket - -with contextlib.suppress(ImportError): - from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket - - true_urllib3_wrap_socket = urllib3_wrap_socket - class _MocketSSLContext: """For Python 3.6 and newer.""" @@ -75,7 +54,9 @@ def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSSLSocke ssl_socket = MocketSSLSocket() ssl_socket._original_socket = sock - ssl_socket._true_socket = true_urllib3_ssl_wrap_socket( + from mocket.urllib3 import true_ssl_wrap_socket + + ssl_socket._true_socket = true_ssl_wrap_socket( sock._true_socket, **kwargs, ) diff --git a/mocket/urllib3.py b/mocket/urllib3.py new file mode 100644 index 00000000..9a8a6569 --- /dev/null +++ b/mocket/urllib3.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from typing import Any + +import urllib3 +from mocket.socket import MocketSocket +from mocket.ssl.context import MocketSSLContext +from mocket.ssl.socket import MocketSSLSocket + +true_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket + + +def mock_match_hostname(*args: Any) -> None: + return None + + +def mock_ssl_wrap_socket( + sock: MocketSocket, + *args: Any, + **kwargs: Any, +) -> MocketSSLSocket: + context = MocketSSLContext() + return context.wrap_socket(sock, *args, **kwargs) diff --git a/mocket/utils.py b/mocket/utils.py index b9e2c259..59403954 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -45,10 +45,10 @@ def get_mocketize(wrapper_: Callable) -> Callable: __all__ = ( - "MocketSocketCore", "MocketMode", + "MocketSocketCore", "SSL_PROTOCOL", + "get_mocketize", "hexdump", "hexload", - "get_mocketize", ) diff --git a/tests/test_mode.py b/tests/test_mode.py index 2a764949..ea5905b0 100644 --- a/tests/test_mode.py +++ b/tests/test_mode.py @@ -4,7 +4,7 @@ from mocket import Mocketizer, mocketize from mocket.exceptions import StrictMocketException from mocket.mockhttp import Entry, Response -from mocket.utils import MocketMode +from mocket.mode import MocketMode @mocketize(strict_mode=True) From 5d99f07f3426dad67b530d42af42dba1f2c572b3 Mon Sep 17 00:00:00 2001 From: betaboon Date: Sat, 23 Nov 2024 20:47:38 +0100 Subject: [PATCH 2/4] refactor: move ssl socket-wrapping code to ssl/socket.py --- mocket/ssl/context.py | 22 +--------------------- mocket/ssl/socket.py | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py index 12d2eb2d..84b848d5 100644 --- a/mocket/ssl/context.py +++ b/mocket/ssl/context.py @@ -51,27 +51,7 @@ def dummy_method(*args: Any, **kwargs: Any) -> Any: @staticmethod def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSSLSocket: - ssl_socket = MocketSSLSocket() - ssl_socket._original_socket = sock - - from mocket.urllib3 import true_ssl_wrap_socket - - ssl_socket._true_socket = true_ssl_wrap_socket( - sock._true_socket, - **kwargs, - ) - ssl_socket._kwargs = kwargs - - ssl_socket._timeout = sock._timeout - - ssl_socket._host = sock._host - ssl_socket._port = sock._port - ssl_socket._address = sock._address - - ssl_socket._io = sock._io - ssl_socket._entry = sock._entry - - return ssl_socket + return MocketSSLSocket._create(sock, *args, **kwargs) @staticmethod def wrap_bio( diff --git a/mocket/ssl/socket.py b/mocket/ssl/socket.py index e50b7320..d0dda0ce 100644 --- a/mocket/ssl/socket.py +++ b/mocket/ssl/socket.py @@ -60,3 +60,28 @@ def compression(self) -> str | None: def unwrap(self) -> MocketSocket: return self._original_socket + + @classmethod + def _create(cls, sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSSLSocket: + ssl_socket = MocketSSLSocket() + ssl_socket._original_socket = sock + + from mocket.urllib3 import true_ssl_wrap_socket + + ssl_socket._true_socket = true_ssl_wrap_socket( + sock._true_socket, + **kwargs, + ) + + ssl_socket._kwargs = kwargs + + ssl_socket._timeout = sock._timeout + + ssl_socket._host = sock._host + ssl_socket._port = sock._port + ssl_socket._address = sock._address + + ssl_socket._io = sock._io + ssl_socket._entry = sock._entry + + return ssl_socket From 815e86fcd48b0f3d32a3f5dbad5aaaac492802e6 Mon Sep 17 00:00:00 2001 From: betaboon Date: Sat, 23 Nov 2024 20:51:10 +0100 Subject: [PATCH 3/4] refactor: convert MocketSSLContext.wrap_socket and wrap_bio to instance-methods --- mocket/inject.py | 4 ++-- mocket/ssl/context.py | 19 ++++++++++++++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/mocket/inject.py b/mocket/inject.py index 4a9bb5ee..469ab30b 100644 --- a/mocket/inject.py +++ b/mocket/inject.py @@ -36,7 +36,7 @@ def enable( mock_inet_pton, mock_socketpair, ) - from mocket.ssl.context import MocketSSLContext + from mocket.ssl.context import MocketSSLContext, mock_wrap_socket from mocket.urllib3 import ( mock_match_hostname as mock_urllib3_match_hostname, ) @@ -56,7 +56,7 @@ def enable( (socket, "socketpair"): mock_socketpair, # stdlib: ssl (ssl, "SSLContext"): MocketSSLContext, - (ssl, "wrap_socket"): MocketSSLContext.wrap_socket, # python < 3.12.0 + (ssl, "wrap_socket"): mock_wrap_socket, # python < 3.12.0 # urllib3 (urllib3.connection, "match_hostname"): mock_urllib3_match_hostname, (urllib3.connection, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket, diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py index 84b848d5..6d5e7307 100644 --- a/mocket/ssl/context.py +++ b/mocket/ssl/context.py @@ -49,12 +49,16 @@ def dummy_method(*args: Any, **kwargs: Any) -> Any: for m in self.DUMMY_METHODS: setattr(self, m, dummy_method) - @staticmethod - def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSSLSocket: + def wrap_socket( + self, + sock: MocketSocket, + *args: Any, + **kwargs: Any, + ) -> MocketSSLSocket: return MocketSSLSocket._create(sock, *args, **kwargs) - @staticmethod def wrap_bio( + self, incoming: Any, # _ssl.MemoryBIO outgoing: Any, # _ssl.MemoryBIO server_side: bool = False, @@ -63,3 +67,12 @@ def wrap_bio( ssl_obj = MocketSSLSocket() ssl_obj._host = server_hostname return ssl_obj + + +def mock_wrap_socket( + sock: MocketSocket, + *args: Any, + **kwargs: Any, +) -> MocketSSLSocket: + context = MocketSSLContext() + return context.wrap_socket(sock, *args, **kwargs) From a1551905857ae5792f43e7969fa6631e5751bb7c Mon Sep 17 00:00:00 2001 From: betaboon Date: Sat, 23 Nov 2024 20:54:50 +0100 Subject: [PATCH 4/4] refactor: MocketSSLSocket use proper ssl-context instead of urllib3 --- mocket/ssl/socket.py | 21 ++++++++++++++------- mocket/urllib3.py | 3 --- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/mocket/ssl/socket.py b/mocket/ssl/socket.py index d0dda0ce..f7f41761 100644 --- a/mocket/ssl/socket.py +++ b/mocket/ssl/socket.py @@ -62,16 +62,23 @@ def unwrap(self) -> MocketSocket: return self._original_socket @classmethod - def _create(cls, sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSSLSocket: + def _create( + cls, + sock: MocketSocket, + ssl_context: ssl.SSLContext | None = None, + server_hostname: str | None = None, + *args: Any, + **kwargs: Any, + ) -> MocketSSLSocket: ssl_socket = MocketSSLSocket() ssl_socket._original_socket = sock + ssl_socket._true_socket = sock._true_socket - from mocket.urllib3 import true_ssl_wrap_socket - - ssl_socket._true_socket = true_ssl_wrap_socket( - sock._true_socket, - **kwargs, - ) + if ssl_context: + ssl_socket._true_socket = ssl_context.wrap_socket( + sock=ssl_socket._true_socket, + server_hostname=server_hostname, + ) ssl_socket._kwargs = kwargs diff --git a/mocket/urllib3.py b/mocket/urllib3.py index 9a8a6569..e89bc7b5 100644 --- a/mocket/urllib3.py +++ b/mocket/urllib3.py @@ -2,13 +2,10 @@ from typing import Any -import urllib3 from mocket.socket import MocketSocket from mocket.ssl.context import MocketSSLContext from mocket.ssl.socket import MocketSSLSocket -true_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket - def mock_match_hostname(*args: Any) -> None: return None