Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 60 additions & 89 deletions mocket/inject.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
from __future__ import annotations

import contextlib
import os
import socket
import ssl
from types import ModuleType
from typing import Any

import urllib3

try: # pragma: no cover
from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3
_patches_restore: dict[tuple[ModuleType, str], Any] = {}

pyopenssl_override = True
except ImportError:
pyopenssl_override = False

def _patch(module: ModuleType, name: str, patched_value: Any) -> None:
with contextlib.suppress(KeyError):
original_value, module.__dict__[name] = module.__dict__[name], patched_value
_patches_restore[(module, name)] = original_value


def _restore(module: ModuleType, name: str) -> None:
if original_value := _patches_restore.pop((module, name)):
module.__dict__[name] = original_value


def enable(
namespace: str | None = None,
truesocket_recording_dir: str | None = None,
) -> None:
from mocket.mocket import Mocket
from mocket.socket import (
MocketSocket,
mock_create_connection,
Expand All @@ -27,99 +35,62 @@ def enable(
mock_gethostname,
mock_inet_pton,
mock_socketpair,
mock_urllib3_match_hostname,
)
from mocket.ssl.context import MocketSSLContext
from mocket.ssl.context import MocketSSLContext, mock_wrap_socket
from mocket.urllib3 import (
mock_match_hostname as mock_urllib3_match_hostname,
)
from mocket.urllib3 import (
mock_ssl_wrap_socket as mock_urllib3_ssl_wrap_socket,
)

patches = {
# stdlib: socket
(socket, "socket"): MocketSocket,
(socket, "create_connection"): mock_create_connection,
(socket, "getaddrinfo"): mock_getaddrinfo,
(socket, "gethostbyname"): mock_gethostbyname,
(socket, "gethostname"): mock_gethostname,
(socket, "inet_pton"): mock_inet_pton,
(socket, "SocketType"): MocketSocket,
(socket, "socketpair"): mock_socketpair,
# stdlib: ssl
(ssl, "SSLContext"): MocketSSLContext,
(ssl, "wrap_socket"): mock_wrap_socket, # python < 3.12.0
# urllib3
(urllib3.connection, "match_hostname"): mock_urllib3_match_hostname,
(urllib3.connection, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket,
(urllib3.util, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket,
(urllib3.util.ssl_, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket,
(urllib3.util.ssl_, "wrap_socket"): mock_urllib3_ssl_wrap_socket, # urllib3 < 2
}

for (module, name), new_value in patches.items():
_patch(module, name, new_value)

with contextlib.suppress(ImportError):
from urllib3.contrib.pyopenssl import extract_from_urllib3

extract_from_urllib3()

from mocket.mocket import Mocket

Mocket._namespace = namespace
Mocket._truesocket_recording_dir = truesocket_recording_dir

if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
# JSON dumps will be saved here
raise AssertionError

socket.socket = socket.__dict__["socket"] = MocketSocket
socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket
socket.SocketType = socket.__dict__["SocketType"] = MocketSocket
socket.create_connection = socket.__dict__["create_connection"] = (
mock_create_connection
)
socket.gethostname = socket.__dict__["gethostname"] = mock_gethostname
socket.gethostbyname = socket.__dict__["gethostbyname"] = mock_gethostbyname
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = mock_getaddrinfo
socket.socketpair = socket.__dict__["socketpair"] = mock_socketpair
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = MocketSSLContext.wrap_socket
ssl.SSLContext = ssl.__dict__["SSLContext"] = MocketSSLContext
socket.inet_pton = socket.__dict__["inet_pton"] = mock_inet_pton
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
MocketSSLContext.wrap_socket
)
urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
"ssl_wrap_socket"
] = MocketSSLContext.wrap_socket
urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
MocketSSLContext.wrap_socket
)
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
"ssl_wrap_socket"
] = MocketSSLContext.wrap_socket
urllib3.connection.match_hostname = urllib3.connection.__dict__[
"match_hostname"
] = mock_urllib3_match_hostname
if pyopenssl_override: # pragma: no cover
# Take out the pyopenssl version - use the default implementation
extract_from_urllib3()


def disable() -> None:
for module, name in list(_patches_restore.keys()):
_restore(module, name)

with contextlib.suppress(ImportError):
from urllib3.contrib.pyopenssl import inject_into_urllib3

inject_into_urllib3()

from mocket.mocket import Mocket
from mocket.socket import (
true_create_connection,
true_getaddrinfo,
true_gethostbyname,
true_gethostname,
true_inet_pton,
true_socket,
true_socketpair,
true_urllib3_match_hostname,
)
from mocket.ssl.context import (
true_ssl_context,
true_ssl_wrap_socket,
true_urllib3_ssl_wrap_socket,
true_urllib3_wrap_socket,
)

socket.socket = socket.__dict__["socket"] = true_socket
socket._socketobject = socket.__dict__["_socketobject"] = true_socket
socket.SocketType = socket.__dict__["SocketType"] = true_socket
socket.create_connection = socket.__dict__["create_connection"] = (
true_create_connection
)
socket.gethostname = socket.__dict__["gethostname"] = true_gethostname
socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo
socket.socketpair = socket.__dict__["socketpair"] = true_socketpair
if true_ssl_wrap_socket:
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket
ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context
socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
true_urllib3_wrap_socket
)
urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
"ssl_wrap_socket"
] = true_urllib3_ssl_wrap_socket
urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
true_urllib3_ssl_wrap_socket
)
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
"ssl_wrap_socket"
] = true_urllib3_ssl_wrap_socket
urllib3.connection.match_hostname = urllib3.connection.__dict__[
"match_hostname"
] = true_urllib3_match_hostname
Mocket.reset()
if pyopenssl_override: # pragma: no cover
# Put the pyopenssl version back in place
inject_into_urllib3()
11 changes: 0 additions & 11 deletions mocket/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from types import TracebackType
from typing import Any, Type

import urllib3.connection
from typing_extensions import Self

from mocket.compat import decode_from_bytes, encode_to_bytes
Expand All @@ -27,14 +26,8 @@
)
from mocket.utils import hexdump, hexload

true_create_connection = socket.create_connection
true_getaddrinfo = socket.getaddrinfo
true_gethostbyname = socket.gethostbyname
true_gethostname = socket.gethostname
true_inet_pton = socket.inet_pton
true_socket = socket.socket
true_socketpair = socket.socketpair
true_urllib3_match_hostname = urllib3.connection.match_hostname


xxh32 = None
Expand Down Expand Up @@ -84,10 +77,6 @@ def mock_socketpair(*args, **kwargs):
return _socket.socketpair(*args, **kwargs)


def mock_urllib3_match_hostname(*args: Any) -> None:
return None


def _hash_request(h, req):
return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest()

Expand Down
60 changes: 17 additions & 43 deletions mocket/ssl/context.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,10 @@
from __future__ import annotations

import contextlib
import ssl
from typing import Any

import urllib3.util.ssl_

from mocket.socket import MocketSocket
from mocket.ssl.socket import MocketSSLSocket

true_ssl_context = ssl.SSLContext

true_ssl_wrap_socket = None
true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket
true_urllib3_wrap_socket = None

with contextlib.suppress(ImportError):
# from Py3.12 it's only under SSLContext
from ssl import wrap_socket as ssl_wrap_socket

true_ssl_wrap_socket = ssl_wrap_socket

with contextlib.suppress(ImportError):
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket

true_urllib3_wrap_socket = urllib3_wrap_socket


class _MocketSSLContext:
"""For Python 3.6 and newer."""
Expand Down Expand Up @@ -70,30 +49,16 @@ def dummy_method(*args: Any, **kwargs: Any) -> Any:
for m in self.DUMMY_METHODS:
setattr(self, m, dummy_method)

@staticmethod
def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSSLSocket:
ssl_socket = MocketSSLSocket()
ssl_socket._original_socket = sock

ssl_socket._true_socket = true_urllib3_ssl_wrap_socket(
sock._true_socket,
**kwargs,
)
ssl_socket._kwargs = kwargs

ssl_socket._timeout = sock._timeout

ssl_socket._host = sock._host
ssl_socket._port = sock._port
ssl_socket._address = sock._address

ssl_socket._io = sock._io
ssl_socket._entry = sock._entry

return ssl_socket
def wrap_socket(
self,
sock: MocketSocket,
*args: Any,
**kwargs: Any,
) -> MocketSSLSocket:
return MocketSSLSocket._create(sock, *args, **kwargs)

@staticmethod
def wrap_bio(
self,
incoming: Any, # _ssl.MemoryBIO
outgoing: Any, # _ssl.MemoryBIO
server_side: bool = False,
Expand All @@ -102,3 +67,12 @@ def wrap_bio(
ssl_obj = MocketSSLSocket()
ssl_obj._host = server_hostname
return ssl_obj


def mock_wrap_socket(
sock: MocketSocket,
*args: Any,
**kwargs: Any,
) -> MocketSSLSocket:
context = MocketSSLContext()
return context.wrap_socket(sock, *args, **kwargs)
32 changes: 32 additions & 0 deletions mocket/ssl/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,35 @@ def compression(self) -> str | None:

def unwrap(self) -> MocketSocket:
return self._original_socket

@classmethod
def _create(
cls,
sock: MocketSocket,
ssl_context: ssl.SSLContext | None = None,
server_hostname: str | None = None,
*args: Any,
**kwargs: Any,
) -> MocketSSLSocket:
ssl_socket = MocketSSLSocket()
ssl_socket._original_socket = sock
ssl_socket._true_socket = sock._true_socket

if ssl_context:
ssl_socket._true_socket = ssl_context.wrap_socket(
sock=ssl_socket._true_socket,
server_hostname=server_hostname,
)

ssl_socket._kwargs = kwargs

ssl_socket._timeout = sock._timeout

ssl_socket._host = sock._host
ssl_socket._port = sock._port
ssl_socket._address = sock._address

ssl_socket._io = sock._io
ssl_socket._entry = sock._entry

return ssl_socket
20 changes: 20 additions & 0 deletions mocket/urllib3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations

from typing import Any

from mocket.socket import MocketSocket
from mocket.ssl.context import MocketSSLContext
from mocket.ssl.socket import MocketSSLSocket


def mock_match_hostname(*args: Any) -> None:
return None


def mock_ssl_wrap_socket(
sock: MocketSocket,
*args: Any,
**kwargs: Any,
) -> MocketSSLSocket:
context = MocketSSLContext()
return context.wrap_socket(sock, *args, **kwargs)
4 changes: 2 additions & 2 deletions mocket/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ def get_mocketize(wrapper_: Callable) -> Callable:


__all__ = (
"MocketSocketCore",
"MocketMode",
"MocketSocketCore",
"SSL_PROTOCOL",
"get_mocketize",
"hexdump",
"hexload",
"get_mocketize",
)
2 changes: 1 addition & 1 deletion tests/test_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from mocket import Mocketizer, mocketize
from mocket.exceptions import StrictMocketException
from mocket.mockhttp import Entry, Response
from mocket.utils import MocketMode
from mocket.mode import MocketMode


@mocketize(strict_mode=True)
Expand Down