1
1
from __future__ import annotations
2
2
3
+ import contextlib
3
4
import os
4
5
import socket
5
6
import ssl
7
+ from types import ModuleType
8
+ from typing import Any
6
9
7
10
import urllib3
8
11
9
- try : # pragma: no cover
10
- from urllib3 .contrib .pyopenssl import extract_from_urllib3 , inject_into_urllib3
12
+ _patches_restore : dict [tuple [ModuleType , str ], Any ] = {}
11
13
12
- pyopenssl_override = True
13
- except ImportError :
14
- pyopenssl_override = False
14
+
15
+ def _patch (module : ModuleType , name : str , patched_value : Any ) -> None :
16
+ with contextlib .suppress (KeyError ):
17
+ original_value , module .__dict__ [name ] = module .__dict__ [name ], patched_value
18
+ _patches_restore [(module , name )] = original_value
19
+
20
+
21
+ def _restore (module : ModuleType , name : str ) -> None :
22
+ if original_value := _patches_restore .pop ((module , name )):
23
+ module .__dict__ [name ] = original_value
15
24
16
25
17
26
def enable (
18
27
namespace : str | None = None ,
19
28
truesocket_recording_dir : str | None = None ,
20
29
) -> None :
21
- from mocket .mocket import Mocket
22
30
from mocket .socket import (
23
31
MocketSocket ,
24
32
mock_create_connection ,
@@ -27,99 +35,62 @@ def enable(
27
35
mock_gethostname ,
28
36
mock_inet_pton ,
29
37
mock_socketpair ,
30
- mock_urllib3_match_hostname ,
31
38
)
32
- from mocket .ssl .context import MocketSSLContext
39
+ from mocket .ssl .context import MocketSSLContext , mock_wrap_socket
40
+ from mocket .urllib3 import (
41
+ mock_match_hostname as mock_urllib3_match_hostname ,
42
+ )
43
+ from mocket .urllib3 import (
44
+ mock_ssl_wrap_socket as mock_urllib3_ssl_wrap_socket ,
45
+ )
46
+
47
+ patches = {
48
+ # stdlib: socket
49
+ (socket , "socket" ): MocketSocket ,
50
+ (socket , "create_connection" ): mock_create_connection ,
51
+ (socket , "getaddrinfo" ): mock_getaddrinfo ,
52
+ (socket , "gethostbyname" ): mock_gethostbyname ,
53
+ (socket , "gethostname" ): mock_gethostname ,
54
+ (socket , "inet_pton" ): mock_inet_pton ,
55
+ (socket , "SocketType" ): MocketSocket ,
56
+ (socket , "socketpair" ): mock_socketpair ,
57
+ # stdlib: ssl
58
+ (ssl , "SSLContext" ): MocketSSLContext ,
59
+ (ssl , "wrap_socket" ): mock_wrap_socket , # python < 3.12.0
60
+ # urllib3
61
+ (urllib3 .connection , "match_hostname" ): mock_urllib3_match_hostname ,
62
+ (urllib3 .connection , "ssl_wrap_socket" ): mock_urllib3_ssl_wrap_socket ,
63
+ (urllib3 .util , "ssl_wrap_socket" ): mock_urllib3_ssl_wrap_socket ,
64
+ (urllib3 .util .ssl_ , "ssl_wrap_socket" ): mock_urllib3_ssl_wrap_socket ,
65
+ (urllib3 .util .ssl_ , "wrap_socket" ): mock_urllib3_ssl_wrap_socket , # urllib3 < 2
66
+ }
67
+
68
+ for (module , name ), new_value in patches .items ():
69
+ _patch (module , name , new_value )
70
+
71
+ with contextlib .suppress (ImportError ):
72
+ from urllib3 .contrib .pyopenssl import extract_from_urllib3
73
+
74
+ extract_from_urllib3 ()
75
+
76
+ from mocket .mocket import Mocket
33
77
34
78
Mocket ._namespace = namespace
35
79
Mocket ._truesocket_recording_dir = truesocket_recording_dir
36
-
37
80
if truesocket_recording_dir and not os .path .isdir (truesocket_recording_dir ):
38
81
# JSON dumps will be saved here
39
82
raise AssertionError
40
83
41
- socket .socket = socket .__dict__ ["socket" ] = MocketSocket
42
- socket ._socketobject = socket .__dict__ ["_socketobject" ] = MocketSocket
43
- socket .SocketType = socket .__dict__ ["SocketType" ] = MocketSocket
44
- socket .create_connection = socket .__dict__ ["create_connection" ] = (
45
- mock_create_connection
46
- )
47
- socket .gethostname = socket .__dict__ ["gethostname" ] = mock_gethostname
48
- socket .gethostbyname = socket .__dict__ ["gethostbyname" ] = mock_gethostbyname
49
- socket .getaddrinfo = socket .__dict__ ["getaddrinfo" ] = mock_getaddrinfo
50
- socket .socketpair = socket .__dict__ ["socketpair" ] = mock_socketpair
51
- ssl .wrap_socket = ssl .__dict__ ["wrap_socket" ] = MocketSSLContext .wrap_socket
52
- ssl .SSLContext = ssl .__dict__ ["SSLContext" ] = MocketSSLContext
53
- socket .inet_pton = socket .__dict__ ["inet_pton" ] = mock_inet_pton
54
- urllib3 .util .ssl_ .wrap_socket = urllib3 .util .ssl_ .__dict__ ["wrap_socket" ] = (
55
- MocketSSLContext .wrap_socket
56
- )
57
- urllib3 .util .ssl_ .ssl_wrap_socket = urllib3 .util .ssl_ .__dict__ [
58
- "ssl_wrap_socket"
59
- ] = MocketSSLContext .wrap_socket
60
- urllib3 .util .ssl_wrap_socket = urllib3 .util .__dict__ ["ssl_wrap_socket" ] = (
61
- MocketSSLContext .wrap_socket
62
- )
63
- urllib3 .connection .ssl_wrap_socket = urllib3 .connection .__dict__ [
64
- "ssl_wrap_socket"
65
- ] = MocketSSLContext .wrap_socket
66
- urllib3 .connection .match_hostname = urllib3 .connection .__dict__ [
67
- "match_hostname"
68
- ] = mock_urllib3_match_hostname
69
- if pyopenssl_override : # pragma: no cover
70
- # Take out the pyopenssl version - use the default implementation
71
- extract_from_urllib3 ()
72
-
73
84
74
85
def disable () -> None :
86
+ for module , name in list (_patches_restore .keys ()):
87
+ _restore (module , name )
88
+
89
+ with contextlib .suppress (ImportError ):
90
+ from urllib3 .contrib .pyopenssl import inject_into_urllib3
91
+
92
+ inject_into_urllib3 ()
93
+
75
94
from mocket .mocket import Mocket
76
- from mocket .socket import (
77
- true_create_connection ,
78
- true_getaddrinfo ,
79
- true_gethostbyname ,
80
- true_gethostname ,
81
- true_inet_pton ,
82
- true_socket ,
83
- true_socketpair ,
84
- true_urllib3_match_hostname ,
85
- )
86
- from mocket .ssl .context import (
87
- true_ssl_context ,
88
- true_ssl_wrap_socket ,
89
- true_urllib3_ssl_wrap_socket ,
90
- true_urllib3_wrap_socket ,
91
- )
92
95
93
- socket .socket = socket .__dict__ ["socket" ] = true_socket
94
- socket ._socketobject = socket .__dict__ ["_socketobject" ] = true_socket
95
- socket .SocketType = socket .__dict__ ["SocketType" ] = true_socket
96
- socket .create_connection = socket .__dict__ ["create_connection" ] = (
97
- true_create_connection
98
- )
99
- socket .gethostname = socket .__dict__ ["gethostname" ] = true_gethostname
100
- socket .gethostbyname = socket .__dict__ ["gethostbyname" ] = true_gethostbyname
101
- socket .getaddrinfo = socket .__dict__ ["getaddrinfo" ] = true_getaddrinfo
102
- socket .socketpair = socket .__dict__ ["socketpair" ] = true_socketpair
103
- if true_ssl_wrap_socket :
104
- ssl .wrap_socket = ssl .__dict__ ["wrap_socket" ] = true_ssl_wrap_socket
105
- ssl .SSLContext = ssl .__dict__ ["SSLContext" ] = true_ssl_context
106
- socket .inet_pton = socket .__dict__ ["inet_pton" ] = true_inet_pton
107
- urllib3 .util .ssl_ .wrap_socket = urllib3 .util .ssl_ .__dict__ ["wrap_socket" ] = (
108
- true_urllib3_wrap_socket
109
- )
110
- urllib3 .util .ssl_ .ssl_wrap_socket = urllib3 .util .ssl_ .__dict__ [
111
- "ssl_wrap_socket"
112
- ] = true_urllib3_ssl_wrap_socket
113
- urllib3 .util .ssl_wrap_socket = urllib3 .util .__dict__ ["ssl_wrap_socket" ] = (
114
- true_urllib3_ssl_wrap_socket
115
- )
116
- urllib3 .connection .ssl_wrap_socket = urllib3 .connection .__dict__ [
117
- "ssl_wrap_socket"
118
- ] = true_urllib3_ssl_wrap_socket
119
- urllib3 .connection .match_hostname = urllib3 .connection .__dict__ [
120
- "match_hostname"
121
- ] = true_urllib3_match_hostname
122
96
Mocket .reset ()
123
- if pyopenssl_override : # pragma: no cover
124
- # Put the pyopenssl version back in place
125
- inject_into_urllib3 ()
0 commit comments