diff --git a/src/auth.rs b/src/auth.rs index 7eabbc1e..637dec8a 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -360,20 +360,38 @@ use crate::db::{ DbConn, }; -pub struct Host { - pub host: String, +pub struct Domain { + pub domain: String, } #[rocket::async_trait] -impl<'r> FromRequest<'r> for Host { +impl<'r> FromRequest<'r> for Domain { type Error = &'static str; async fn from_request(request: &'r Request<'_>) -> Outcome { let headers = request.headers(); // Get host - let host = if CONFIG.domain_set() { - CONFIG.domain() + // TODO: UPDATE THIS SECTION + let domain = if CONFIG.domain_set() { + let host = if let Some(host) = headers.get_one("X-Forwarded-Host") { + host + } else if let Some(host) = headers.get_one("Host") { + host + } else { + // TODO fix error handling + // This is probably a 400 bad request, + // because http requests require the host header + todo!() + }; + + let Some(domain) = CONFIG.host_to_domain(host) else { + // TODO fix error handling + // This is probably a 421 misdirected request. + todo!() + }; + + domain } else if let Some(referer) = headers.get_one("Referer") { referer.to_string() } else { @@ -399,14 +417,14 @@ impl<'r> FromRequest<'r> for Host { format!("{protocol}://{host}") }; - Outcome::Success(Host { - host, + Outcome::Success(Domain { + domain, }) } } pub struct ClientHeaders { - pub host: String, + pub domain: String, pub device_type: i32, pub ip: ClientIp, } @@ -416,7 +434,7 @@ impl<'r> FromRequest<'r> for ClientHeaders { type Error = &'static str; async fn from_request(request: &'r Request<'_>) -> Outcome { - let host = try_outcome!(Host::from_request(request).await).host; + let domain = try_outcome!(Domain::from_request(request).await).domain; let ip = match ClientIp::from_request(request).await { Outcome::Success(ip) => ip, _ => err_handler!("Error getting Client IP"), @@ -426,7 +444,7 @@ impl<'r> FromRequest<'r> for ClientHeaders { request.headers().get_one("device-type").map(|d| d.parse().unwrap_or(14)).unwrap_or_else(|| 14); Outcome::Success(ClientHeaders { - host, + domain, device_type, ip, }) @@ -434,7 +452,7 @@ impl<'r> FromRequest<'r> for ClientHeaders { } pub struct Headers { - pub host: String, + pub domain: String, pub device: Device, pub user: User, pub ip: ClientIp, @@ -447,7 +465,7 @@ impl<'r> FromRequest<'r> for Headers { async fn from_request(request: &'r Request<'_>) -> Outcome { let headers = request.headers(); - let host = try_outcome!(Host::from_request(request).await).host; + let domain = try_outcome!(Domain::from_request(request).await).domain; let ip = match ClientIp::from_request(request).await { Outcome::Success(ip) => ip, _ => err_handler!("Error getting Client IP"), @@ -518,7 +536,7 @@ impl<'r> FromRequest<'r> for Headers { } Outcome::Success(Headers { - host, + domain, device, user, ip, @@ -527,7 +545,7 @@ impl<'r> FromRequest<'r> for Headers { } pub struct OrgHeaders { - pub host: String, + pub domain: String, pub device: Device, pub user: User, pub org_user_type: UserOrgType, @@ -583,7 +601,7 @@ impl<'r> FromRequest<'r> for OrgHeaders { }; Outcome::Success(Self { - host: headers.host, + domain: headers.domain, device: headers.device, user, org_user_type: { @@ -605,7 +623,7 @@ impl<'r> FromRequest<'r> for OrgHeaders { } pub struct AdminHeaders { - pub host: String, + pub domain: String, pub device: Device, pub user: User, pub org_user_type: UserOrgType, @@ -622,7 +640,7 @@ impl<'r> FromRequest<'r> for AdminHeaders { let client_version = request.headers().get_one("Bitwarden-Client-Version").map(String::from); if headers.org_user_type >= UserOrgType::Admin { Outcome::Success(Self { - host: headers.host, + domain: headers.domain, device: headers.device, user: headers.user, org_user_type: headers.org_user_type, @@ -638,7 +656,7 @@ impl<'r> FromRequest<'r> for AdminHeaders { impl From for Headers { fn from(h: AdminHeaders) -> Headers { Headers { - host: h.host, + domain: h.domain, device: h.device, user: h.user, ip: h.ip, @@ -669,7 +687,7 @@ fn get_col_id(request: &Request<'_>) -> Option { /// and have access to the specific collection provided via the /collections/collectionId. /// This does strict checking on the collection_id, ManagerHeadersLoose does not. pub struct ManagerHeaders { - pub host: String, + pub domain: String, pub device: Device, pub user: User, pub org_user_type: UserOrgType, @@ -698,7 +716,7 @@ impl<'r> FromRequest<'r> for ManagerHeaders { } Outcome::Success(Self { - host: headers.host, + domain: headers.domain, device: headers.device, user: headers.user, org_user_type: headers.org_user_type, @@ -713,7 +731,7 @@ impl<'r> FromRequest<'r> for ManagerHeaders { impl From for Headers { fn from(h: ManagerHeaders) -> Headers { Headers { - host: h.host, + domain: h.domain, device: h.device, user: h.user, ip: h.ip, @@ -724,7 +742,7 @@ impl From for Headers { /// The ManagerHeadersLoose is used when you at least need to be a Manager, /// but there is no collection_id sent with the request (either in the path or as form data). pub struct ManagerHeadersLoose { - pub host: String, + pub domain: String, pub device: Device, pub user: User, pub org_user: UserOrganization, @@ -740,7 +758,7 @@ impl<'r> FromRequest<'r> for ManagerHeadersLoose { let headers = try_outcome!(OrgHeaders::from_request(request).await); if headers.org_user_type >= UserOrgType::Manager { Outcome::Success(Self { - host: headers.host, + domain: headers.domain, device: headers.device, user: headers.user, org_user: headers.org_user, @@ -756,7 +774,7 @@ impl<'r> FromRequest<'r> for ManagerHeadersLoose { impl From for Headers { fn from(h: ManagerHeadersLoose) -> Headers { Headers { - host: h.host, + domain: h.domain, device: h.device, user: h.user, ip: h.ip, @@ -784,7 +802,7 @@ impl ManagerHeaders { } Ok(ManagerHeaders { - host: h.host, + domain: h.domain, device: h.device, user: h.user, org_user_type: h.org_user_type,