Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', ' 3.13', 'pypy3.10']
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13', 'pypy3.10']

steps:
- uses: actions/checkout@v4
Expand Down
18 changes: 1 addition & 17 deletions mocket/inject.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import contextlib
import os
import socket
import ssl
from types import ModuleType
Expand All @@ -23,10 +22,7 @@ def _restore(module: ModuleType, name: str) -> None:
module.__dict__[name] = original_value


def enable(
namespace: str | None = None,
truesocket_recording_dir: str | None = None,
) -> None:
def enable() -> None:
from mocket.socket import (
MocketSocket,
mock_create_connection,
Expand Down Expand Up @@ -73,14 +69,6 @@ def enable(

extract_from_urllib3()

from mocket.mocket import Mocket

Mocket._namespace = namespace
Mocket._truesocket_recording_dir = truesocket_recording_dir
if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
# JSON dumps will be saved here
raise AssertionError


def disable() -> None:
for module, name in list(_patches_restore.keys()):
Expand All @@ -90,7 +78,3 @@ def disable() -> None:
from urllib3.contrib.pyopenssl import inject_into_urllib3

inject_into_urllib3()

from mocket.mocket import Mocket

Mocket.reset()
46 changes: 39 additions & 7 deletions mocket/mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import collections
import itertools
import os
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar

import mocket.inject
from mocket.recording import MocketRecordStorage

# NOTE this is here for backwards-compat to keep old import-paths working
# from mocket.socket import MocketSocket as MocketSocket
Expand All @@ -20,11 +22,36 @@ class Mocket:
_address: ClassVar[Address] = (None, None)
_entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list)
_requests: ClassVar[list] = []
_namespace: ClassVar[str] = str(id(_entries))
_truesocket_recording_dir: ClassVar[str | None] = None
_record_storage: ClassVar[MocketRecordStorage | None] = None

enable = mocket.inject.enable
disable = mocket.inject.disable
@classmethod
def enable(
cls,
namespace: str | None = None,
truesocket_recording_dir: str | None = None,
) -> None:
if namespace is None:
namespace = str(id(cls._entries))

if truesocket_recording_dir is not None:
recording_dir = Path(truesocket_recording_dir)

if not recording_dir.is_dir():
# JSON dumps will be saved here
raise AssertionError

cls._record_storage = MocketRecordStorage(
directory=recording_dir,
namespace=namespace,
)

mocket.inject.enable()

@classmethod
def disable(cls) -> None:
cls.reset()

mocket.inject.disable()

@classmethod
def get_pair(cls, address: Address) -> tuple[int, int] | tuple[None, None]:
Expand Down Expand Up @@ -69,6 +96,7 @@ def reset(cls) -> None:
cls._socket_pairs = {}
cls._entries = collections.defaultdict(list)
cls._requests = []
cls._record_storage = None

@classmethod
def last_request(cls):
Expand All @@ -89,12 +117,16 @@ def has_requests(cls) -> bool:
return bool(cls.request_list())

@classmethod
def get_namespace(cls) -> str:
return cls._namespace
def get_namespace(cls) -> str | None:
if not cls._record_storage:
return None
return cls._record_storage.namespace

@classmethod
def get_truesocket_recording_dir(cls) -> str | None:
return cls._truesocket_recording_dir
if not cls._record_storage:
return None
return str(cls._record_storage.directory)

@classmethod
def assert_fail_if_entries_not_served(cls) -> None:
Expand Down
147 changes: 147 additions & 0 deletions mocket/recording.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from __future__ import annotations

import contextlib
import hashlib
import json
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path

from mocket.compat import decode_from_bytes, encode_to_bytes
from mocket.types import Address
from mocket.utils import hexdump, hexload

hash_function = hashlib.md5

with contextlib.suppress(ImportError):
from xxhash_cffi import xxh32 as xxhash_cffi_xxh32

hash_function = xxhash_cffi_xxh32

with contextlib.suppress(ImportError):
from xxhash import xxh32 as xxhash_xxh32

hash_function = xxhash_xxh32


def _hash_prepare_request(data: bytes) -> bytes:
_data = decode_from_bytes(data)
return encode_to_bytes("".join(sorted(_data.split("\r\n"))))


def _hash_request(data: bytes) -> str:
_data = _hash_prepare_request(data)
return hash_function(_data).hexdigest()


def _hash_request_fallback(data: bytes) -> str:
_data = _hash_prepare_request(data)
return hashlib.md5(_data).hexdigest()


@dataclass
class MocketRecord:
host: str
port: int
request: bytes
response: bytes


class MocketRecordStorage:
def __init__(self, directory: Path, namespace: str) -> None:
self._directory = directory
self._namespace = namespace
self._records: defaultdict[Address, defaultdict[str, MocketRecord]] = (
defaultdict(defaultdict)
)

self._load()

@property
def directory(self) -> Path:
return self._directory

@property
def namespace(self) -> str:
return self._namespace

@property
def file(self) -> Path:
return self._directory / f"{self._namespace}.json"

def _load(self) -> None:
if not self.file.exists():
return

json_data = self.file.read_text()
records = json.loads(json_data)
for host, port_signature_record in records.items():
for port, signature_record in port_signature_record.items():
for signature, record in signature_record.items():
# NOTE backward-compat
try:
request_data = hexload(record["request"])
except ValueError:
request_data = record["request"]

self._records[(host, int(port))][signature] = MocketRecord(
host=host,
port=port,
request=request_data,
response=hexload(record["response"]),
)

def _save(self) -> None:
data: dict[str, dict[str, dict[str, dict[str, str]]]] = defaultdict(
lambda: defaultdict(defaultdict)
)
for address, signature_record in self._records.items():
host, port = address
for signature, record in signature_record.items():
data[host][str(port)][signature] = dict(
request=decode_from_bytes(record.request),
response=hexdump(record.response),
)

json_data = json.dumps(data, indent=4, sort_keys=True)
self.file.parent.mkdir(exist_ok=True)
self.file.write_text(json_data)

def get_records(self, address: Address) -> list[MocketRecord]:
return list(self._records[address].values())

def get_record(self, address: Address, request: bytes) -> MocketRecord | None:
# NOTE for backward-compat
request_signature_fallback = _hash_request_fallback(request)
if request_signature_fallback in self._records[address]:
return self._records[address].get(request_signature_fallback)

request_signature = _hash_request(request)
if request_signature in self._records[address]:
return self._records[address][request_signature]

return None

def put_record(
self,
address: Address,
request: bytes,
response: bytes,
) -> None:
host, port = address
record = MocketRecord(
host=host,
port=port,
request=request,
response=response,
)

# NOTE for backward-compat
request_signature_fallback = _hash_request_fallback(request)
if request_signature_fallback in self._records[address]:
self._records[address][request_signature_fallback] = record
return

request_signature = _hash_request(request)
self._records[address][request_signature] = record
self._save()
Loading