Browse Source

Authentication middleware

master
Julio Biason 2 years ago
parent
commit
9c0f3a4836
  1. 2
      axumtest/Cargo.toml
  2. 43
      axumtest/src/auth/mod.rs
  3. 19
      axumtest/src/main.rs

2
axumtest/Cargo.toml

@ -3,8 +3,6 @@ name = "axumtest"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
axum = { version = "0.5.15", features = ["headers"] }
clap = { version = "3.2.17", features = ["derive", "env"] }

43
axumtest/src/auth/mod.rs

@ -0,0 +1,43 @@
//! "Authentication" middleware.
use axum::headers::HeaderName;
use axum::http::Request;
use axum::http::StatusCode;
use axum::middleware::Next;
use axum::response::Response;
static CIUSR: HeaderName = HeaderName::from_static("x-ciusr");
static CIPWD: HeaderName = HeaderName::from_static("x-cipwd");
static CIROLE: HeaderName = HeaderName::from_static("x-cirole");
pub async fn ci_auth<B>(
req: Request<B>,
next: Next<B>,
expected_usr: &str,
expected_pwd: &str,
expected_role: &str,
) -> Result<Response, StatusCode> {
let usr = req
.headers()
.get(&CIUSR)
.and_then(|header| header.to_str().ok());
let pwd = req
.headers()
.get(&CIPWD)
.and_then(|header| header.to_str().ok());
let role = req
.headers()
.get(&CIROLE)
.and_then(|header| header.to_str().ok());
tracing::debug!(usr, pwd, role);
match (usr, pwd, role) {
(Some(inc_usr), Some(inc_pwd), Some(inc_role))
if inc_usr == expected_usr && inc_pwd == expected_pwd && inc_role == expected_role =>
{
Ok(next.run(req).await)
}
(_, _, _) => Err(StatusCode::UNAUTHORIZED),
}
}

19
axumtest/src/main.rs

@ -1,12 +1,14 @@
mod auth;
mod headers;
use std::sync::Arc;
use axum::middleware;
use axum::routing::get;
use headers::cipwd::CiPwd;
use headers::cirole::CiRole;
use headers::ciusr::CiUsr;
use axum::routing::get;
use axum::routing::Router;
use axum::TypedHeader;
use clap::Parser;
@ -14,7 +16,6 @@ use dotenv::dotenv;
use mongodb::options::ClientOptions;
use mongodb::Client;
use mongodb::Database;
use tracing_subscriber;
#[derive(Parser)]
struct Params {
@ -50,8 +51,11 @@ async fn main() {
"/collections",
get({
let shared_state = Arc::clone(&state);
move |usr, pwd, role| get_collections(usr, pwd, role, Arc::clone(&shared_state))
}),
move || get_collections(Arc::clone(&shared_state))
})
.route_layer(middleware::from_fn(move |req, next| {
auth::ci_auth(req, next, "A", "A", "A")
})),
);
tracing::info!(args.addr, "Server listening in");
@ -69,11 +73,6 @@ async fn index(
format!("Hellow {}", usr)
}
async fn get_collections(
TypedHeader(usr): TypedHeader<CiUsr>,
TypedHeader(pwd): TypedHeader<CiPwd>,
TypedHeader(role): TypedHeader<CiRole>,
state: Arc<State>,
) -> String {
async fn get_collections(state: Arc<State>) -> String {
format!("Collections")
}

Loading…
Cancel
Save