Skip to content
14 changes: 14 additions & 0 deletions docs/api-guide/schemas.md
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,20 @@ operationIds.
In order to work around this, you can override `get_operation_id_base()` to
provide a different base for name part of the ID.

#### `get_serializer()`

If the view has implemented `get_serializer()`, returns the result.

#### `get_request_serializer()`

By default returns `get_serializer()` but can be overridden to
differentiate between request and response objects.

#### `get_response_serializer()`

By default returns `get_serializer()` but can be overridden to
differentiate between request and response objects.

### `AutoSchema.__init__()` kwargs

`AutoSchema` provides a number of `__init__()` kwargs that can be used for
Expand Down
37 changes: 29 additions & 8 deletions rest_framework/schemas/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,22 @@ def get_components(self, path, method):
if method.lower() == 'delete':
return {}

serializer = self.get_serializer(path, method)
request_serializer = self.get_request_serializer(path, method)
response_serializer = self.get_response_serializer(path, method)

if not isinstance(serializer, serializers.Serializer):
return {}
components = {}

if isinstance(request_serializer, serializers.Serializer):
component_name = self.get_component_name(request_serializer)
content = self.map_serializer(request_serializer)
components.setdefault(component_name, content)

component_name = self.get_component_name(serializer)
if isinstance(response_serializer, serializers.Serializer):
component_name = self.get_component_name(response_serializer)
content = self.map_serializer(response_serializer)
components.setdefault(component_name, content)

content = self.map_serializer(serializer)
return {component_name: content}
return components

def _to_camel_case(self, snake_str):
components = snake_str.split('_')
Expand Down Expand Up @@ -615,6 +622,20 @@ def get_serializer(self, path, method):
.format(view.__class__.__name__, method, path))
return None

def get_request_serializer(self, path, method):
"""
Override this method if your view uses a different serializer for
handling request body.
"""
return self.get_serializer(path, method)

def get_response_serializer(self, path, method):
"""
Override this method if your view uses a different serializer for
populating response data.
"""
return self.get_serializer(path, method)

def _get_reference(self, serializer):
return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))}

Expand All @@ -624,7 +645,7 @@ def get_request_body(self, path, method):

self.request_media_types = self.map_parsers(path, method)

serializer = self.get_serializer(path, method)
serializer = self.get_request_serializer(path, method)

if not isinstance(serializer, serializers.Serializer):
item_schema = {}
Expand All @@ -648,7 +669,7 @@ def get_responses(self, path, method):

self.response_media_types = self.map_renderers(path, method)

serializer = self.get_serializer(path, method)
serializer = self.get_response_serializer(path, method)

if not isinstance(serializer, serializers.Serializer):
item_schema = {}
Expand Down
85 changes: 85 additions & 0 deletions tests/schemas/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,91 @@ def get_operation_id_base(self, path, method, action):
operationId = inspector.get_operation_id(path, method)
assert operationId == 'listItem'

def test_different_request_response_objects(self):
class RequestSerializer(serializers.Serializer):
text = serializers.CharField()

class ResponseSerializer(serializers.Serializer):
text = serializers.BooleanField()

class CustomSchema(AutoSchema):
def get_request_serializer(self, path, method):
return RequestSerializer()

def get_response_serializer(self, path, method):
return ResponseSerializer()

path = '/'
method = 'POST'
view = create_view(
views.ExampleGenericAPIView,
method,
create_request(path),
)
inspector = CustomSchema()
inspector.view = view

components = inspector.get_components(path, method)
assert components == {
'Request': {
'properties': {
'text': {
'type': 'string'
}
},
'required': ['text'],
'type': 'object'
},
'Response': {
'properties': {
'text': {
'type': 'boolean'
}
},
'required': ['text'],
'type': 'object'
}
}

operation = inspector.get_operation(path, method)
assert operation == {
'operationId': 'createExample',
'description': '',
'parameters': [],
'requestBody': {
'content': {
'application/json': {
'schema': {
'$ref': '#/components/schemas/Request'
}
},
'application/x-www-form-urlencoded': {
'schema': {
'$ref': '#/components/schemas/Request'
}
},
'multipart/form-data': {
'schema': {
'$ref': '#/components/schemas/Request'
}
}
}
},
'responses': {
'201': {
'content': {
'application/json': {
'schema': {
'$ref': '#/components/schemas/Response'
}
}
},
'description': ''
}
},
'tags': ['']
}

def test_repeat_operation_ids(self):
router = routers.SimpleRouter()
router.register('account', views.ExampleGenericViewSet, basename="account")
Expand Down