|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import os |
| 4 | +import socket |
| 5 | +import ssl |
| 6 | + |
| 7 | +import urllib3 |
| 8 | +from urllib3.connection import match_hostname as urllib3_match_hostname |
| 9 | +from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket |
| 10 | + |
| 11 | +try: |
| 12 | + from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket |
| 13 | +except ImportError: |
| 14 | + urllib3_wrap_socket = None |
| 15 | + |
| 16 | + |
| 17 | +try: # pragma: no cover |
| 18 | + from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3 |
| 19 | + |
| 20 | + pyopenssl_override = True |
| 21 | +except ImportError: |
| 22 | + pyopenssl_override = False |
| 23 | + |
| 24 | +true_socket = socket.socket |
| 25 | +true_create_connection = socket.create_connection |
| 26 | +true_gethostbyname = socket.gethostbyname |
| 27 | +true_gethostname = socket.gethostname |
| 28 | +true_getaddrinfo = socket.getaddrinfo |
| 29 | +true_socketpair = socket.socketpair |
| 30 | +true_ssl_wrap_socket = getattr( |
| 31 | + ssl, "wrap_socket", None |
| 32 | +) # from Py3.12 it's only under SSLContext |
| 33 | +true_ssl_socket = ssl.SSLSocket |
| 34 | +true_ssl_context = ssl.SSLContext |
| 35 | +true_inet_pton = socket.inet_pton |
| 36 | +true_urllib3_wrap_socket = urllib3_wrap_socket |
| 37 | +true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket |
| 38 | +true_urllib3_match_hostname = urllib3_match_hostname |
| 39 | + |
| 40 | + |
| 41 | +def enable( |
| 42 | + namespace: str | None = None, |
| 43 | + truesocket_recording_dir: str | None = None, |
| 44 | +) -> None: |
| 45 | + from mocket.mocket import Mocket |
| 46 | + from mocket.socket import MocketSocket, create_connection, socketpair |
| 47 | + from mocket.ssl import FakeSSLContext |
| 48 | + |
| 49 | + Mocket._namespace = namespace |
| 50 | + Mocket._truesocket_recording_dir = truesocket_recording_dir |
| 51 | + |
| 52 | + if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir): |
| 53 | + # JSON dumps will be saved here |
| 54 | + raise AssertionError |
| 55 | + |
| 56 | + socket.socket = socket.__dict__["socket"] = MocketSocket |
| 57 | + socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket |
| 58 | + socket.SocketType = socket.__dict__["SocketType"] = MocketSocket |
| 59 | + socket.create_connection = socket.__dict__["create_connection"] = create_connection |
| 60 | + socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost" |
| 61 | + socket.gethostbyname = socket.__dict__["gethostbyname"] = lambda host: "127.0.0.1" |
| 62 | + socket.getaddrinfo = socket.__dict__["getaddrinfo"] = ( |
| 63 | + lambda host, port, family=None, socktype=None, proto=None, flags=None: [ |
| 64 | + (2, 1, 6, "", (host, port)) |
| 65 | + ] |
| 66 | + ) |
| 67 | + socket.socketpair = socket.__dict__["socketpair"] = socketpair |
| 68 | + ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket |
| 69 | + ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext |
| 70 | + socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: bytes( |
| 71 | + "\x7f\x00\x00\x01", "utf-8" |
| 72 | + ) |
| 73 | + urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( |
| 74 | + FakeSSLContext.wrap_socket |
| 75 | + ) |
| 76 | + urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ |
| 77 | + "ssl_wrap_socket" |
| 78 | + ] = FakeSSLContext.wrap_socket |
| 79 | + urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( |
| 80 | + FakeSSLContext.wrap_socket |
| 81 | + ) |
| 82 | + urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ |
| 83 | + "ssl_wrap_socket" |
| 84 | + ] = FakeSSLContext.wrap_socket |
| 85 | + urllib3.connection.match_hostname = urllib3.connection.__dict__[ |
| 86 | + "match_hostname" |
| 87 | + ] = lambda *args: None |
| 88 | + if pyopenssl_override: # pragma: no cover |
| 89 | + # Take out the pyopenssl version - use the default implementation |
| 90 | + extract_from_urllib3() |
| 91 | + |
| 92 | + |
| 93 | +def disable() -> None: |
| 94 | + from mocket.mocket import Mocket |
| 95 | + |
| 96 | + socket.socket = socket.__dict__["socket"] = true_socket |
| 97 | + socket._socketobject = socket.__dict__["_socketobject"] = true_socket |
| 98 | + socket.SocketType = socket.__dict__["SocketType"] = true_socket |
| 99 | + socket.create_connection = socket.__dict__["create_connection"] = ( |
| 100 | + true_create_connection |
| 101 | + ) |
| 102 | + socket.gethostname = socket.__dict__["gethostname"] = true_gethostname |
| 103 | + socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname |
| 104 | + socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo |
| 105 | + socket.socketpair = socket.__dict__["socketpair"] = true_socketpair |
| 106 | + if true_ssl_wrap_socket: |
| 107 | + ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket |
| 108 | + ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context |
| 109 | + socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton |
| 110 | + urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( |
| 111 | + true_urllib3_wrap_socket |
| 112 | + ) |
| 113 | + urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ |
| 114 | + "ssl_wrap_socket" |
| 115 | + ] = true_urllib3_ssl_wrap_socket |
| 116 | + urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( |
| 117 | + true_urllib3_ssl_wrap_socket |
| 118 | + ) |
| 119 | + urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ |
| 120 | + "ssl_wrap_socket" |
| 121 | + ] = true_urllib3_ssl_wrap_socket |
| 122 | + urllib3.connection.match_hostname = urllib3.connection.__dict__[ |
| 123 | + "match_hostname" |
| 124 | + ] = true_urllib3_match_hostname |
| 125 | + Mocket.reset() |
| 126 | + if pyopenssl_override: # pragma: no cover |
| 127 | + # Put the pyopenssl version back in place |
| 128 | + inject_into_urllib3() |
0 commit comments