Skip to content

Commit b7f815f

Browse files
committed
axum-extra: implement FromRequest for Either*
1 parent 58f63bb commit b7f815f

File tree

1 file changed

+162
-2
lines changed

1 file changed

+162
-2
lines changed

axum-extra/src/either.rs

Lines changed: 162 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! `Either*` types for combining extractors or responses into a single type.
22
//!
3-
//! # As an extractor
3+
//! # As an `FromRequestParts` extractor
44
//!
55
//! ```
66
//! use axum_extra::either::Either3;
@@ -54,6 +54,42 @@
5454
//! Note that if all the inner extractors reject the request, the rejection from the last
5555
//! extractor will be returned. For the example above that would be [`BytesRejection`].
5656
//!
57+
//! # As an `FromRequest` extractor
58+
//!
59+
//! In the following example, we can first try to deserialize the payload as JSON, if that fails try
60+
//! to interpret it as a UTF-8 string, and lastly just take the raw bytes.
61+
//!
62+
//! It might be preferable to instead extract `Bytes` directly and then fallibly convert them to
63+
//! `String` and then deserialize the data inside the handler.
64+
//!
65+
//! ```
66+
//! use axum_extra::either::Either3;
67+
//! use axum::{
68+
//! body::Bytes,
69+
//! Json,
70+
//! Router,
71+
//! routing::get,
72+
//! extract::FromRequestParts,
73+
//! };
74+
//!
75+
//! #[derive(serde::Deserialize)]
76+
//! struct Payload {
77+
//! user: String,
78+
//! request_id: u32,
79+
//! }
80+
//!
81+
//! async fn handler(
82+
//! body: Either3<Json<Payload>, String, Bytes>,
83+
//! ) {
84+
//! match body {
85+
//! Either3::E1(json) => { /* ... */ }
86+
//! Either3::E2(string) => { /* ... */ }
87+
//! Either3::E3(bytes) => { /* ... */ }
88+
//! }
89+
//! }
90+
//! #
91+
//! # let _: axum::routing::MethodRouter = axum::routing::get(handler);
92+
//! ```
5793
//! # As a response
5894
//!
5995
//! ```
@@ -93,9 +129,10 @@
93129
use std::task::{Context, Poll};
94130

95131
use axum::{
96-
extract::FromRequestParts,
132+
extract::{rejection::BytesRejection, FromRequest, FromRequestParts, Request},
97133
response::{IntoResponse, Response},
98134
};
135+
use bytes::Bytes;
99136
use http::request::Parts;
100137
use tower_layer::Layer;
101138
use tower_service::Service;
@@ -226,6 +263,28 @@ pub enum Either8<E1, E2, E3, E4, E5, E6, E7, E8> {
226263
E8(E8),
227264
}
228265

266+
/// Rejection used for [`Either`], [`Either3`], etc.
267+
///
268+
/// Contains one variant for a case when the whole request could not be loaded and one variant
269+
/// containing the rejection of the last variant if all extractors failed..
270+
#[derive(Debug)]
271+
pub enum EitherRejection<E> {
272+
/// Buffering of the request body failed.
273+
Bytes(BytesRejection),
274+
275+
/// All extractors failed. This contains the error returned by the last extractor.
276+
LastRejection(E),
277+
}
278+
279+
impl<E: IntoResponse> IntoResponse for EitherRejection<E> {
280+
fn into_response(self) -> Response {
281+
match self {
282+
EitherRejection::Bytes(rejection) => rejection.into_response(),
283+
EitherRejection::LastRejection(rejection) => rejection.into_response(),
284+
}
285+
}
286+
}
287+
229288
macro_rules! impl_traits_for_either {
230289
(
231290
$either:ident =>
@@ -251,6 +310,43 @@ macro_rules! impl_traits_for_either {
251310
}
252311
}
253312

313+
impl<S, $($ident),*, $last> FromRequest<S> for $either<$($ident),*, $last>
314+
where
315+
S: Send + Sync,
316+
$($ident: FromRequest<S>),*,
317+
$last: FromRequest<S>,
318+
$($ident::Rejection: Send),*,
319+
$last::Rejection: IntoResponse + Send,
320+
{
321+
type Rejection = EitherRejection<$last::Rejection>;
322+
323+
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
324+
let (parts, body) = req.into_parts();
325+
let bytes = Bytes::from_request(Request::from_parts(parts.clone(), body), state)
326+
.await
327+
.map_err(EitherRejection::Bytes)?;
328+
329+
$(
330+
let req = Request::from_parts(
331+
parts.clone(),
332+
axum::body::Body::new(http_body_util::Full::new(bytes.clone())),
333+
);
334+
if let Ok(extracted) = $ident::from_request(req, state).await {
335+
return Ok(Self::$ident(extracted));
336+
}
337+
)*
338+
339+
let req = Request::from_parts(
340+
parts.clone(),
341+
axum::body::Body::new(http_body_util::Full::new(bytes.clone())),
342+
);
343+
match $last::from_request(req, state).await {
344+
Ok(extracted) => Ok(Self::$last(extracted)),
345+
Err(error) => Err(EitherRejection::LastRejection(error)),
346+
}
347+
}
348+
}
349+
254350
impl<$($ident),*, $last> IntoResponse for $either<$($ident),*, $last>
255351
where
256352
$($ident: IntoResponse),*,
@@ -312,3 +408,67 @@ where
312408
}
313409
}
314410
}
411+
412+
#[cfg(test)]
413+
mod tests {
414+
use std::future::Future;
415+
416+
use axum::body::Body;
417+
use axum::extract::rejection::StringRejection;
418+
use axum::extract::{FromRequest, Request, State};
419+
use bytes::Bytes;
420+
use http_body_util::Full;
421+
422+
use super::*;
423+
424+
struct False;
425+
426+
impl<S> FromRequestParts<S> for False {
427+
type Rejection = ();
428+
429+
fn from_request_parts(
430+
_parts: &mut Parts,
431+
_state: &S,
432+
) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
433+
std::future::ready(Err(()))
434+
}
435+
}
436+
437+
#[tokio::test]
438+
async fn either_from_request() {
439+
// The body is by design not valid UTF-8.
440+
let request = Request::new(Body::new(Full::new(Bytes::from_static(&[255]))));
441+
442+
let either = Either4::<String, String, Request, Bytes>::from_request(request, &())
443+
.await
444+
.unwrap();
445+
446+
assert!(matches!(either, Either4::E3(_)));
447+
}
448+
449+
#[tokio::test]
450+
async fn either_from_request_rejection() {
451+
// The body is by design not valid UTF-8.
452+
let request = Request::new(Body::new(Full::new(Bytes::from_static(&[255]))));
453+
454+
let either = Either::<String, String>::from_request(request, &())
455+
.await
456+
.unwrap_err();
457+
458+
assert!(matches!(
459+
either,
460+
EitherRejection::LastRejection(StringRejection::InvalidUtf8(_))
461+
));
462+
}
463+
464+
#[tokio::test]
465+
async fn either_from_request_parts() {
466+
let (mut parts, _) = Request::new(Body::empty()).into_parts();
467+
468+
let either = Either3::<False, False, State<()>>::from_request_parts(&mut parts, &())
469+
.await
470+
.unwrap();
471+
472+
assert!(matches!(either, Either3::E3(State(()))));
473+
}
474+
}

0 commit comments

Comments
 (0)