diff --git a/axumtest/Cargo.lock b/axumtest/Cargo.lock index 3164176..cb3d47e 100644 --- a/axumtest/Cargo.lock +++ b/axumtest/Cargo.lock @@ -113,6 +113,8 @@ dependencies = [ "clap", "dotenv", "mongodb", + "serde", + "serde_json", "tokio", "tracing", "tracing-subscriber", diff --git a/axumtest/Cargo.toml b/axumtest/Cargo.toml index e94fb27..faeef37 100644 --- a/axumtest/Cargo.toml +++ b/axumtest/Cargo.toml @@ -8,6 +8,8 @@ axum = { version = "0.5.15", features = ["headers"] } clap = { version = "3.2.17", features = ["derive", "env"] } dotenv = "0.15.0" mongodb = "2.3.0" +serde = { version = "1.0.143", features = ["derive"] } +serde_json = "1.0.83" tokio = { version = "1.20.1", features = ["macros", "rt-multi-thread"] } tracing = "0.1.36" tracing-subscriber = "0.3.15" diff --git a/axumtest/src/auth/mod.rs b/axumtest/src/auth/mod.rs index 510347d..879aaa9 100644 --- a/axumtest/src/auth/mod.rs +++ b/axumtest/src/auth/mod.rs @@ -13,9 +13,9 @@ static CIROLE: HeaderName = HeaderName::from_static("x-cirole"); pub async fn ci_auth( req: Request, next: Next, - expected_usr: &str, - expected_pwd: &str, - expected_role: &str, + expected_usr: String, + expected_pwd: String, + expected_role: String, ) -> Result { let usr = req .headers() diff --git a/axumtest/src/main.rs b/axumtest/src/main.rs index b6ec2f2..4d112c7 100644 --- a/axumtest/src/main.rs +++ b/axumtest/src/main.rs @@ -1,16 +1,12 @@ mod auth; -mod headers; use std::sync::Arc; use axum::middleware; use axum::routing::get; -use headers::cipwd::CiPwd; -use headers::cirole::CiRole; -use headers::ciusr::CiUsr; +use axum::Json; use axum::routing::Router; -use axum::TypedHeader; use clap::Parser; use dotenv::dotenv; use mongodb::options::ClientOptions; @@ -24,11 +20,18 @@ struct Params { #[clap(short, long, env = "MONGO_URI")] mongo_addr: String, -} + #[clap(long, env = "CIUSR")] + ci_usr: String, + + #[clap(long, env = "CIPWD")] + ci_pwd: String, + + #[clap(long, env = "CIROLE")] + ci_role: String, +} struct State { db: Database, - collections: Vec, } #[tokio::main] @@ -45,7 +48,9 @@ async fn main() { let collections = db.list_collection_names(None).await.unwrap(); tracing::debug!("Collections = {:?}", &collections); - let state = Arc::new(State { db, collections }); + let state = Arc::new(State { + db, + }); let app = Router::new().route("/", get(index)).route( "/collections", @@ -54,7 +59,11 @@ async fn main() { move || get_collections(Arc::clone(&shared_state)) }) .route_layer(middleware::from_fn(move |req, next| { - auth::ci_auth(req, next, "A", "A", "A") + let ci_usr = args.ci_usr.clone(); + let ci_pwd = args.ci_pwd.clone(); + let ci_role = args.ci_role.clone(); + + auth::ci_auth(req, next, ci_usr, ci_pwd, ci_role) })), ); @@ -65,14 +74,11 @@ async fn main() { .unwrap(); } -async fn index( - TypedHeader(usr): TypedHeader, - TypedHeader(_pwd): TypedHeader, - TypedHeader(_role): TypedHeader, -) -> String { - format!("Hellow {}", usr) +async fn index() -> String { + format!("Hellow") } -async fn get_collections(state: Arc) -> String { - format!("Collections") +async fn get_collections(state: Arc) -> Json> { + let collections = state.db.list_collection_names(None).await.unwrap(); + Json(collections) }