Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 41 additions & 4 deletions logfire/_internal/integrations/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def instrument_fastapi(
excluded_urls: str | Iterable[str] | None = None,
record_send_receive: bool = False,
extra_spans: bool = False,
record_handled_exceptions: bool = True,
**opentelemetry_kwargs: Any,
) -> AbstractContextManager[None]:
"""Instrument a FastAPI app so that spans and logs are automatically created for each request.
Expand Down Expand Up @@ -100,6 +101,7 @@ def instrument_fastapi(
logfire_instance,
request_attributes_mapper or _default_request_attributes_mapper,
extra_spans=extra_spans,
record_handled_exceptions=record_handled_exceptions,
)

@contextmanager
Expand Down Expand Up @@ -163,14 +165,16 @@ def __init__(
dict[str, Any] | None,
],
extra_spans: bool,
record_handled_exceptions: bool = True,
):
self.logfire_instance = logfire_instance.with_settings(custom_scope_suffix='fastapi')
self.timestamp_generator = self.logfire_instance.config.advanced.ns_timestamp_generator
self.request_attributes_mapper = request_attributes_mapper
self.extra_spans = extra_spans
self.record_handled_exceptions = record_handled_exceptions

@contextmanager
def pseudo_span(self, namespace: str, root_span: Span):
def pseudo_span(self, namespace: str, root_span: Span, request: Request | WebSocket | None = None):
"""Record start and end timestamps in the root span, and possibly exceptions."""

def set_timestamp(attribute_name: str):
Expand All @@ -179,15 +183,48 @@ def set_timestamp(attribute_name: str):
root_span.set_attribute(f'fastapi.{namespace}.{attribute_name}', value)

set_timestamp('start_timestamp')
exception_to_check = None
try:
try:
yield
finally:
# Record the end timestamp before recording exceptions.
set_timestamp('end_timestamp')
except Exception as exc:
root_span.record_exception(exc)
exception_to_check = exc
raise
finally:
if exception_to_check and self._should_record_exception(exception_to_check, request):
root_span.record_exception(exception_to_check)

def _is_exception_handled(self, exc: Exception, request: Request | WebSocket) -> bool:
"""Check if the exception will be handled by a FastAPI exception handler."""
app = request.app
exc_type = type(exc)

# Check direct type handlers
if exc_type in app.exception_handlers:
return True

# Check inheritance chain handlers
for handler_type in app.exception_handlers:
if isinstance(exc, handler_type):
return True

return False

def _should_record_exception(
self,
exc: Exception,
request: Request | WebSocket | None = None,
) -> bool:
if self.record_handled_exceptions:
return True

if request and self._is_exception_handled(exc, request):
return False

return True

async def solve_dependencies(self, request: Request | WebSocket, original: Awaitable[Any]) -> Any:
root_span = request.scope.get(LOGFIRE_SPAN_SCOPE_KEY)
Expand All @@ -207,7 +244,7 @@ async def solve_dependencies(self, request: Request | WebSocket, original: Await
set_user_attributes_on_raw_span(root_span, fastapi_route_attributes)
span.set_attributes(fastapi_route_attributes)

with self.pseudo_span('arguments', root_span):
with self.pseudo_span('arguments', root_span, request):
result: Any = await original

with handle_internal_errors:
Expand Down Expand Up @@ -297,7 +334,7 @@ async def run_endpoint_function(
)
else:
extra_span = NoopSpan()
with extra_span, self.pseudo_span('endpoint_function', root_span):
with extra_span, self.pseudo_span('endpoint_function', root_span, request):
return await original


Expand Down
5 changes: 5 additions & 0 deletions logfire/_internal/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,7 @@ def instrument_fastapi(
excluded_urls: str | Iterable[str] | None = None,
record_send_receive: bool = False,
extra_spans: bool = False,
record_handled_exceptions: bool = True,
**opentelemetry_kwargs: Any,
) -> AbstractContextManager[None]:
"""Instrument a FastAPI app so that spans and logs are automatically created for each request.
Expand Down Expand Up @@ -1084,6 +1085,9 @@ def instrument_fastapi(
since many can be created for a single request, and they are not often useful.
If enabled, they will be set to debug level, meaning they will usually still be hidden in the UI.
extra_spans: Whether to include the extra 'FastAPI arguments' and 'endpoint function' spans.
record_handled_exceptions: Set to `False` to prevent recording handled exceptions.

When `False`, exceptions that have registered exception handlers in FastAPI are not recorded as errors.
opentelemetry_kwargs: Additional keyword arguments to pass to the OpenTelemetry FastAPI instrumentation.

Returns:
Expand All @@ -1104,6 +1108,7 @@ def instrument_fastapi(
excluded_urls=excluded_urls,
record_send_receive=record_send_receive,
extra_spans=extra_spans,
record_handled_exceptions=record_handled_exceptions,
**opentelemetry_kwargs,
)

Expand Down
175 changes: 174 additions & 1 deletion tests/otel_integrations/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from inline_snapshot import snapshot
from opentelemetry.propagate import inject
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.testclient import TestClient

import logfire
Expand All @@ -27,6 +27,20 @@
from logfire.testing import TestExporter


class CustomError(Exception):
"""Custom exception for testing."""

pass


def add_custom_error_handler(app: FastAPI) -> None:
"""Add custom error handler to FastAPI app."""

@app.exception_handler(CustomError)
async def custom_error_handler(request: Request, exc: CustomError): # pyright: ignore[reportUnusedFunction]
return JSONResponse(status_code=400, content={'detail': str(exc)})


def test_missing_opentelemetry_dependency() -> None:
with mock.patch.dict('sys.modules', {'opentelemetry.instrumentation.fastapi': None}):
with pytest.raises(RuntimeError) as exc_info:
Expand Down Expand Up @@ -78,6 +92,19 @@ async def bad_request_error():
raise HTTPException(400)


async def custom_error_manual():
"""Endpoint that manually handles custom exceptions."""
try:
raise CustomError('User not found')
except CustomError as exc:
return JSONResponse(status_code=404, content={'detail': str(exc)})


async def custom_error_unhandled():
"""Endpoint that raises custom exceptions without handling."""
raise CustomError('Unhandled custom error')


async def websocket_endpoint(websocket: WebSocket, name: str):
logfire.info('websocket_endpoint: {name}', name=name)
await websocket.accept()
Expand Down Expand Up @@ -113,7 +140,12 @@ def app():
app.get('/with_path_param/{param}')(with_path_param)
app.get('/secret/{path_param}', name='secret')(get_secret)
app.get('/bad_dependency_route/{good}')(bad_dependency_route)
app.get('/custom_error_manual')(custom_error_manual)
app.get('/custom_error_unhandled')(custom_error_unhandled)
app.websocket('/ws/{name}')(websocket_endpoint)

add_custom_error_handler(app)

first_lvl_app.get('/other', name='other_route_name', operation_id='other_route_operation_id')(other_route)
second_lvl_app.get('/other', name='other_route_name', operation_id='other_route_operation_id')(other_route)
return app
Expand Down Expand Up @@ -2334,3 +2366,144 @@ def test_sampled_out(client: TestClient, exporter: TestExporter, config_kwargs:
make_request_hook_spans(record_send_receive=False)

assert exporter.exported_spans_as_dict() == []


def test_custom_error_with_exception_handler_default_behavior(exporter: TestExporter):
"""Test that custom exceptions handled by FastAPI handlers are recorded by default (backward compatibility)."""
app = FastAPI()
app.get('/custom_error_unhandled')(custom_error_unhandled)

add_custom_error_handler(app)

logfire.instrument_fastapi(app)
client = TestClient(app)

response = client.get('/custom_error_unhandled')
assert response.status_code == 400

spans = exporter.exported_spans_as_dict()

# Should have recorded the exception (default behavior)
exception_spans = [s for s in spans if 'events' in s and s['events']]
assert len(exception_spans) > 0

exception_event = exception_spans[0]['events'][0]
assert exception_event['name'] == 'exception'
assert 'CustomError' in exception_event['attributes']['exception.type']


def test_http_exception_default_behavior(exporter: TestExporter):
"""Test that HTTPExceptions are recorded by default."""
app = FastAPI()
app.get('/bad_request_error')(bad_request_error)

logfire.instrument_fastapi(app)
client = TestClient(app)

response = client.get('/bad_request_error')
assert response.status_code == 400

spans = exporter.exported_spans_as_dict()

# Should have recorded the exception (default behavior)
exception_spans = [s for s in spans if 'events' in s and s['events']]
assert len(exception_spans) > 0

exception_event = exception_spans[0]['events'][0]
assert exception_event['name'] == 'exception'
assert 'HTTPException' in exception_event['attributes']['exception.type']


def test_custom_error_with_exception_handler_record_handled_false(exporter: TestExporter):
"""Test that custom exceptions with handlers are NOT recorded when record_handled_exceptions=False.

With the new handler detection logic, exceptions that have registered handlers
are considered "handled" and are filtered out to avoid noise.
"""
app = FastAPI()
app.get('/custom_error_unhandled')(custom_error_unhandled)

add_custom_error_handler(app)

logfire.instrument_fastapi(app, record_handled_exceptions=False)
client = TestClient(app)

response = client.get('/custom_error_unhandled')
assert response.status_code == 400

spans = exporter.exported_spans_as_dict()

# Custom exceptions with handlers are now filtered out (handled exceptions)
exception_spans = [s for s in spans if 'events' in s and s['events']]
assert len(exception_spans) == 0


def test_custom_error_without_handler_record_handled_false(exporter: TestExporter):
"""Test that custom exceptions WITHOUT handlers are still recorded when record_handled_exceptions=False.

Exceptions without handlers are truly unhandled and should always be logged.
"""
app = FastAPI()

async def unhandled_custom_error():
raise CustomError('This exception has no handler')

app.get('/unhandled_custom_error')(unhandled_custom_error)
# Note: NOT adding the custom error handler

logfire.instrument_fastapi(app, record_handled_exceptions=False)
client = TestClient(app)

# This will raise the exception because there's no handler for CustomError
with pytest.raises(CustomError):
client.get('/unhandled_custom_error')

spans = exporter.exported_spans_as_dict()

# Unhandled custom exceptions should still be recorded
exception_spans = [s for s in spans if 'events' in s and s['events']]
assert len(exception_spans) > 0

exception_event = exception_spans[0]['events'][0]
assert exception_event['name'] == 'exception'
assert 'CustomError' in exception_event['attributes']['exception.type']


def test_http_exception_record_handled_false(exporter: TestExporter):
"""Test that HTTPExceptions are not recorded when record_handled_exceptions=False."""
app = FastAPI()
app.get('/bad_request_error')(bad_request_error)

logfire.instrument_fastapi(app, record_handled_exceptions=False)
client = TestClient(app)

response = client.get('/bad_request_error')
assert response.status_code == 400

spans = exporter.exported_spans_as_dict()

# Should not have recorded the exception
exception_spans = [s for s in spans if 'events' in s and s['events']]
assert len(exception_spans) == 0


def test_unhandled_exception_always_recorded(exporter: TestExporter):
"""Test that truly unhandled exceptions are always recorded, regardless of record_handled_exceptions setting."""
app = FastAPI()
app.get('/exception')(exception)

logfire.instrument_fastapi(app, record_handled_exceptions=False)
client = TestClient(app)

with pytest.raises(ValueError):
client.get('/exception')

spans = exporter.exported_spans_as_dict()

# Should have recorded the exception (it's unhandled)
exception_spans = [s for s in spans if 'events' in s and s['events']]
assert len(exception_spans) > 0

exception_event = exception_spans[0]['events'][0]
assert exception_event['name'] == 'exception'
assert 'ValueError' in exception_event['attributes']['exception.type']
Loading