From f9d42dcc3b6499dc9615dbd340c8e33b72c354d7 Mon Sep 17 00:00:00 2001 From: David Montague Date: Tue, 28 Jan 2020 16:33:36 -0800 Subject: [PATCH 1/2] Add support for more generic CBVs by way of a generic router --- fastapi_utils/cbv.py | 57 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/fastapi_utils/cbv.py b/fastapi_utils/cbv.py index 32c014a0..cfdad464 100644 --- a/fastapi_utils/cbv.py +++ b/fastapi_utils/cbv.py @@ -1,5 +1,8 @@ +import functools import inspect -from typing import Any, Callable, List, Type, TypeVar, Union, get_type_hints +from copy import deepcopy +from types import FunctionType +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, get_type_hints from fastapi import APIRouter, Depends from pydantic.typing import is_classvar @@ -8,6 +11,7 @@ T = TypeVar("T") CBV_CLASS_KEY = "__cbv_class__" +GENERIC_CBV_ROUTERS_KEY = "__generic_cbv_routers__" def cbv(router: APIRouter) -> Callable[[Type[T]], Type[T]]: @@ -17,24 +21,50 @@ def decorator(cls: Type[T]) -> Type[T]: return decorator +def generic_cbv(router: APIRouter) -> Callable[[Type[T]], Type[T]]: + def decorator(cls: Type[T]) -> Type[T]: + generic_routers = getattr(cls, GENERIC_CBV_ROUTERS_KEY, None) + if generic_routers is None: + generic_routers = [] + setattr(cls, GENERIC_CBV_ROUTERS_KEY, generic_routers) + generic_routers.append(router) + return cls + + return decorator + + def _cbv(router: APIRouter, cls: Type[T]) -> Type[T]: _init_cbv(cls) cbv_router = APIRouter() functions = inspect.getmembers(cls, inspect.isfunction) - routes_by_endpoint = { - route.endpoint: route for route in router.routes if isinstance(route, (Route, WebSocketRoute)) - } + routes_by_endpoint = _routes_by_endpoint(router) + generic_routes_by_endpoint = {} + for generic_router in getattr(cls, GENERIC_CBV_ROUTERS_KEY, []): + generic_routes_by_endpoint.update(_routes_by_endpoint(generic_router)) for _, func in functions: route = routes_by_endpoint.get(func) if route is None: - continue - router.routes.remove(route) + route = generic_routes_by_endpoint.get(func) + if route is None: + continue + else: + router.routes.remove(route) + route = deepcopy(route) + route.endpoint = replace_method_with_copy(cls, func) _update_cbv_route_endpoint_signature(cls, route) cbv_router.routes.append(route) router.include_router(cbv_router) return cls +def _routes_by_endpoint(router: Optional[APIRouter]) -> Dict[Callable[..., Any], Union[Route, WebSocketRoute]]: + return ( + {} + if router is None + else {route.endpoint: route for route in router.routes if isinstance(route, (Route, WebSocketRoute))} + ) + + def _update_cbv_route_endpoint_signature(cls: Type[Any], route: Union[Route, WebSocketRoute]) -> None: old_endpoint = route.endpoint old_signature = inspect.signature(old_endpoint) @@ -78,3 +108,18 @@ def new_init(self: Any, *args: Any, **kwargs: Any) -> None: setattr(cls, "__signature__", new_signature) setattr(cls, "__init__", new_init) setattr(cls, CBV_CLASS_KEY, True) + + +def replace_method_with_copy(cls: Type[Any], function: FunctionType) -> FunctionType: + copied = FunctionType( + function.__code__, + function.__globals__, + name=function.__name__, + argdefs=function.__defaults__, + closure=function.__closure__, + ) + functools.update_wrapper(copied, function) + copied.__qualname__ = f"{cls.__name__}.{function.__name__}" + copied.__kwdefaults__ = function.__kwdefaults__ + setattr(cls, function.__name__, copied) + return copied From 6948667a57e134d128164655b44c66d45a9a0bbc Mon Sep 17 00:00:00 2001 From: David Montague Date: Thu, 30 Jan 2020 02:46:03 -0800 Subject: [PATCH 2/2] Add test coverage; hopefully make compatible with 3.6 --- fastapi_utils/cbv.py | 4 +-- tests/test_generic_cbv.py | 71 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 2 deletions(-) create mode 100644 tests/test_generic_cbv.py diff --git a/fastapi_utils/cbv.py b/fastapi_utils/cbv.py index cfdad464..bb37d0b1 100644 --- a/fastapi_utils/cbv.py +++ b/fastapi_utils/cbv.py @@ -1,6 +1,6 @@ import functools import inspect -from copy import deepcopy +from copy import copy from types import FunctionType from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, get_type_hints @@ -49,7 +49,7 @@ def _cbv(router: APIRouter, cls: Type[T]) -> Type[T]: continue else: router.routes.remove(route) - route = deepcopy(route) + route = copy(route) route.endpoint = replace_method_with_copy(cls, func) _update_cbv_route_endpoint_signature(cls, route) cbv_router.routes.append(route) diff --git a/tests/test_generic_cbv.py b/tests/test_generic_cbv.py new file mode 100644 index 00000000..f98263de --- /dev/null +++ b/tests/test_generic_cbv.py @@ -0,0 +1,71 @@ +from typing import Optional + +from fastapi import APIRouter, Depends, FastAPI +from starlette.testclient import TestClient + +from fastapi_utils.cbv import cbv, generic_cbv + + +def get_a(a: int) -> int: + return a + + +def get_double_b(b: int) -> int: + return 2 * b + + +def get_string(c: Optional[str] = None) -> Optional[str]: + return c + + +router = APIRouter() + + +@generic_cbv(router) +class BaseGenericCBV: + number: int = Depends(None) + + @router.get("/") + async def echo_number(self) -> int: + return self.number + + +other_router = APIRouter() + + +@generic_cbv(other_router) +class GenericCBV(BaseGenericCBV): + string: Optional[str] = Depends(None) + + @router.get("/string") + async def echo_string(self) -> Optional[str]: + return self.string + + +router_a = APIRouter() +router_b = APIRouter() + + +@cbv(router_a) +class CBVA(GenericCBV): + number = Depends(get_a) + string = Depends(get_string) + + +@cbv(router_b) +class CBVB(GenericCBV): + number = Depends(get_double_b) + string = Depends(get_string) + + +app = FastAPI() +app.include_router(router_a, prefix="/a") +app.include_router(router_b, prefix="/b") + + +def test_generic_cbv() -> None: + assert TestClient(app).get("/a/", params={"a": 1}).json() == 1 + assert TestClient(app).get("/b/", params={"b": 1}).json() == 2 + + assert TestClient(app).get("/a/string", params={"a": 1, "c": "hello"}).json() == "hello" + assert TestClient(app).get("/b/string", params={"b": 1, "c": "world"}).json() == "world"