diff --git a/axumtest/Cargo.toml b/axumtest/Cargo.toml index cead273..e94fb27 100644 --- a/axumtest/Cargo.toml +++ b/axumtest/Cargo.toml @@ -3,8 +3,6 @@ name = "axumtest" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] axum = { version = "0.5.15", features = ["headers"] } clap = { version = "3.2.17", features = ["derive", "env"] } diff --git a/axumtest/src/auth/mod.rs b/axumtest/src/auth/mod.rs new file mode 100644 index 0000000..510347d --- /dev/null +++ b/axumtest/src/auth/mod.rs @@ -0,0 +1,43 @@ +//! "Authentication" middleware. + +use axum::headers::HeaderName; +use axum::http::Request; +use axum::http::StatusCode; +use axum::middleware::Next; +use axum::response::Response; + +static CIUSR: HeaderName = HeaderName::from_static("x-ciusr"); +static CIPWD: HeaderName = HeaderName::from_static("x-cipwd"); +static CIROLE: HeaderName = HeaderName::from_static("x-cirole"); + +pub async fn ci_auth( + req: Request, + next: Next, + expected_usr: &str, + expected_pwd: &str, + expected_role: &str, +) -> Result { + let usr = req + .headers() + .get(&CIUSR) + .and_then(|header| header.to_str().ok()); + let pwd = req + .headers() + .get(&CIPWD) + .and_then(|header| header.to_str().ok()); + let role = req + .headers() + .get(&CIROLE) + .and_then(|header| header.to_str().ok()); + + tracing::debug!(usr, pwd, role); + + match (usr, pwd, role) { + (Some(inc_usr), Some(inc_pwd), Some(inc_role)) + if inc_usr == expected_usr && inc_pwd == expected_pwd && inc_role == expected_role => + { + Ok(next.run(req).await) + } + (_, _, _) => Err(StatusCode::UNAUTHORIZED), + } +} diff --git a/axumtest/src/main.rs b/axumtest/src/main.rs index fca73bb..b6ec2f2 100644 --- a/axumtest/src/main.rs +++ b/axumtest/src/main.rs @@ -1,12 +1,14 @@ +mod auth; mod headers; use std::sync::Arc; +use axum::middleware; +use axum::routing::get; use headers::cipwd::CiPwd; use headers::cirole::CiRole; use headers::ciusr::CiUsr; -use axum::routing::get; use axum::routing::Router; use axum::TypedHeader; use clap::Parser; @@ -14,7 +16,6 @@ use dotenv::dotenv; use mongodb::options::ClientOptions; use mongodb::Client; use mongodb::Database; -use tracing_subscriber; #[derive(Parser)] struct Params { @@ -50,8 +51,11 @@ async fn main() { "/collections", get({ let shared_state = Arc::clone(&state); - move |usr, pwd, role| get_collections(usr, pwd, role, Arc::clone(&shared_state)) - }), + move || get_collections(Arc::clone(&shared_state)) + }) + .route_layer(middleware::from_fn(move |req, next| { + auth::ci_auth(req, next, "A", "A", "A") + })), ); tracing::info!(args.addr, "Server listening in"); @@ -69,11 +73,6 @@ async fn index( format!("Hellow {}", usr) } -async fn get_collections( - TypedHeader(usr): TypedHeader, - TypedHeader(pwd): TypedHeader, - TypedHeader(role): TypedHeader, - state: Arc, -) -> String { +async fn get_collections(state: Arc) -> String { format!("Collections") }