// // Web Headers and caching // use std::{ io::{Cursor, ErrorKind}, ops::Deref, }; use rocket::{ fairing::{Fairing, Info, Kind}, http::{ContentType, Header, HeaderMap, Method, Status}, request::FromParam, response::{self, Responder}, Data, Orbit, Request, Response, Rocket, }; use tokio::{ runtime::Handle, time::{sleep, Duration}, }; use crate::CONFIG; pub struct AppHeaders(); #[rocket::async_trait] impl Fairing for AppHeaders { fn info(&self) -> Info { Info { name: "Application Headers", kind: Kind::Response, } } async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) { let req_uri_path = req.uri().path(); let req_headers = req.headers(); // Check if this connection is an Upgrade/WebSocket connection and return early // We do not want add any extra headers, this could cause issues with reverse proxies or CloudFlare if req_uri_path.ends_with("notifications/hub") || req_uri_path.ends_with("notifications/anonymous-hub") { match (req_headers.get_one("connection"), req_headers.get_one("upgrade")) { (Some(c), Some(u)) if c.to_lowercase().contains("upgrade") && u.to_lowercase().contains("websocket") => { // Remove headers which could cause websocket connection issues res.remove_header("X-Frame-Options"); res.remove_header("X-Content-Type-Options"); return; } (_, _) => (), } } res.set_raw_header("Permissions-Policy", "accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(), midi=(), payment=(), picture-in-picture=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()"); res.set_raw_header("Referrer-Policy", "same-origin"); res.set_raw_header("X-Content-Type-Options", "nosniff"); // Obsolete in modern browsers, unsafe (XS-Leak), and largely replaced by CSP res.set_raw_header("X-XSS-Protection", "0"); // Do not send the Content-Security-Policy (CSP) Header and X-Frame-Options for the *-connector.html files. // This can cause issues when some MFA requests needs to open a popup or page within the clients like WebAuthn, or Duo. // This is the same behavior as upstream Bitwarden. if !req_uri_path.ends_with("connector.html") { // # Frame Ancestors: // Chrome Web Store: https://chrome.google.com/webstore/detail/bitwarden-free-password-m/nngceckbapebfimnlniiiahkandclblb // Edge Add-ons: https://microsoftedge.microsoft.com/addons/detail/bitwarden-free-password/jbkfoedolllekgbhcbcoahefnbanhhlh?hl=en-US // Firefox Browser Add-ons: https://addons.mozilla.org/en-US/firefox/addon/bitwarden-password-manager/ // # img/child/frame src: // Have I Been Pwned to allow those calls to work. // # Connect src: // Leaked Passwords check: api.pwnedpasswords.com // 2FA/MFA Site check: api.2fa.directory // # Mail Relay: https://bitwarden.com/blog/add-privacy-and-security-using-email-aliases-with-bitwarden/ // app.simplelogin.io, app.addy.io, api.fastmail.com, quack.duckduckgo.com let csp = format!( "default-src 'self'; \ base-uri 'self'; \ form-action 'self'; \ object-src 'self' blob:; \ script-src 'self' 'wasm-unsafe-eval'; \ style-src 'self' 'unsafe-inline'; \ child-src 'self' https://*.duosecurity.com https://*.duofederal.com; \ frame-src 'self' https://*.duosecurity.com https://*.duofederal.com; \ frame-ancestors 'self' \ chrome-extension://nngceckbapebfimnlniiiahkandclblb \ chrome-extension://jbkfoedolllekgbhcbcoahefnbanhhlh \ moz-extension://* \ {allowed_iframe_ancestors}; \ img-src 'self' data: \ https://haveibeenpwned.com \ {icon_service_csp}; \ connect-src 'self' \ https://api.pwnedpasswords.com \ https://api.2fa.directory \ https://app.simplelogin.io/api/ \ https://app.addy.io/api/ \ https://api.fastmail.com/ \ https://api.forwardemail.net \ ;\ ", icon_service_csp = CONFIG._icon_service_csp(), allowed_iframe_ancestors = CONFIG.allowed_iframe_ancestors() ); res.set_raw_header("Content-Security-Policy", csp); res.set_raw_header("X-Frame-Options", "SAMEORIGIN"); } else { // It looks like this header get's set somewhere else also, make sure this is not sent for these files, it will cause MFA issues. res.remove_header("X-Frame-Options"); } // Disable cache unless otherwise specified if !res.headers().contains("cache-control") { res.set_raw_header("Cache-Control", "no-cache, no-store, max-age=0"); } } } pub struct Cors(); impl Cors { fn get_header(headers: &HeaderMap<'_>, name: &str) -> String { match headers.get_one(name) { Some(h) => h.to_string(), _ => String::new(), } } // Check a request's `Origin` header against the list of allowed origins. // If a match exists, return it. Otherwise, return None. fn get_allowed_origin(headers: &HeaderMap<'_>) -> Option { let origin = Cors::get_header(headers, "Origin"); let domain_origin = CONFIG.domain_origin(); let safari_extension_origin = "file://"; if origin == domain_origin || origin == safari_extension_origin { Some(origin) } else { None } } } #[rocket::async_trait] impl Fairing for Cors { fn info(&self) -> Info { Info { name: "Cors", kind: Kind::Response, } } async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) { let req_headers = request.headers(); if let Some(origin) = Cors::get_allowed_origin(req_headers) { response.set_header(Header::new("Access-Control-Allow-Origin", origin)); } // Preflight request if request.method() == Method::Options { let req_allow_headers = Cors::get_header(req_headers, "Access-Control-Request-Headers"); let req_allow_method = Cors::get_header(req_headers, "Access-Control-Request-Method"); response.set_header(Header::new("Access-Control-Allow-Methods", req_allow_method)); response.set_header(Header::new("Access-Control-Allow-Headers", req_allow_headers)); response.set_header(Header::new("Access-Control-Allow-Credentials", "true")); response.set_status(Status::Ok); response.set_header(ContentType::Plain); response.set_sized_body(Some(0), Cursor::new("")); } } } pub struct Cached { response: R, is_immutable: bool, ttl: u64, } impl Cached { pub fn long(response: R, is_immutable: bool) -> Cached { Self { response, is_immutable, ttl: 604800, // 7 days } } pub fn short(response: R, is_immutable: bool) -> Cached { Self { response, is_immutable, ttl: 600, // 10 minutes } } pub fn ttl(response: R, ttl: u64, is_immutable: bool) -> Cached { Self { response, is_immutable, ttl, } } } impl<'r, R: 'r + Responder<'r, 'static> + Send> Responder<'r, 'static> for Cached { fn respond_to(self, request: &'r Request<'_>) -> response::Result<'static> { let mut res = self.response.respond_to(request)?; let cache_control_header = if self.is_immutable { format!("public, immutable, max-age={}", self.ttl) } else { format!("public, max-age={}", self.ttl) }; res.set_raw_header("Cache-Control", cache_control_header); let time_now = chrono::Local::now(); let expiry_time = time_now + chrono::Duration::seconds(self.ttl.try_into().unwrap()); res.set_raw_header("Expires", format_datetime_http(&expiry_time)); Ok(res) } } pub struct SafeString(String); impl std::fmt::Display for SafeString { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } } impl Deref for SafeString { type Target = String; fn deref(&self) -> &Self::Target { &self.0 } } impl AsRef for SafeString { #[inline] fn as_ref(&self) -> &Path { Path::new(&self.0) } } impl<'r> FromParam<'r> for SafeString { type Error = (); #[inline(always)] fn from_param(param: &'r str) -> Result { if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { Ok(SafeString(param.to_string())) } else { Err(()) } } } // Log all the routes from the main paths list, and the attachments endpoint // Effectively ignores, any static file route, and the alive endpoint const LOGGED_ROUTES: [&str; 7] = ["/api", "/admin", "/identity", "/icons", "/attachments", "/events", "/notifications"]; // Boolean is extra debug, when true, we ignore the whitelist above and also print the mounts pub struct BetterLogging(pub bool); #[rocket::async_trait] impl Fairing for BetterLogging { fn info(&self) -> Info { Info { name: "Better Logging", kind: Kind::Liftoff | Kind::Request | Kind::Response, } } async fn on_liftoff(&self, rocket: &Rocket) { if self.0 { info!(target: "routes", "Routes loaded:"); let mut routes: Vec<_> = rocket.routes().collect(); routes.sort_by_key(|r| r.uri.path()); for route in routes { if route.rank < 0 { info!(target: "routes", "{:<6} {}", route.method, route.uri); } else { info!(target: "routes", "{:<6} {} [{}]", route.method, route.uri, route.rank); } } } let config = rocket.config(); let scheme = if config.tls_enabled() { "https" } else { "http" }; let addr = format!("{}://{}:{}", &scheme, &config.address, &config.port); info!(target: "start", "Rocket has launched from {}", addr); } async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) { let method = request.method(); if !self.0 && method == Method::Options { return; } let uri = request.uri(); let uri_path = uri.path(); let uri_path_str = uri_path.url_decode_lossy(); let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str); if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) { match uri.query() { Some(q) => info!(target: "request", "{} {}?{}", method, uri_path_str, &q[..q.len().min(30)]), None => info!(target: "request", "{} {}", method, uri_path_str), }; } } async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) { if !self.0 && request.method() == Method::Options { return; } let uri_path = request.uri().path(); let uri_path_str = uri_path.url_decode_lossy(); let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str); if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) { let status = response.status(); if let Some(ref route) = request.route() { info!(target: "response", "{} => {}", route, status) } else { info!(target: "response", "{}", status) } } } } // // File handling // use std::{ fs::{self, File}, io::Result as IOResult, path::Path, }; pub fn file_exists(path: &str) -> bool { Path::new(path).exists() } pub fn write_file(path: &str, content: &[u8]) -> Result<(), crate::error::Error> { use std::io::Write; let mut f = match File::create(path) { Ok(file) => file, Err(e) => { if e.kind() == ErrorKind::PermissionDenied { error!("Can't create '{}': Permission denied", path); } return Err(From::from(e)); } }; f.write_all(content)?; f.flush()?; Ok(()) } pub fn delete_file(path: &str) -> IOResult<()> { let res = fs::remove_file(path); if let Some(parent) = Path::new(path).parent() { // If the directory isn't empty, this returns an error, which we ignore // We only want to delete the folder if it's empty fs::remove_dir(parent).ok(); } res } pub fn get_display_size(size: i32) -> String { const UNITS: [&str; 6] = ["bytes", "KB", "MB", "GB", "TB", "PB"]; let mut size: f64 = size.into(); let mut unit_counter = 0; loop { if size > 1024. { size /= 1024.; unit_counter += 1; } else { break; } } format!("{:.2} {}", size, UNITS[unit_counter]) } pub fn get_uuid() -> String { uuid::Uuid::new_v4().to_string() } // // String util methods // use std::str::FromStr; #[inline] pub fn upcase_first(s: &str) -> String { let mut c = s.chars(); match c.next() { None => String::new(), Some(f) => f.to_uppercase().collect::() + c.as_str(), } } #[inline] pub fn lcase_first(s: &str) -> String { let mut c = s.chars(); match c.next() { None => String::new(), Some(f) => f.to_lowercase().collect::() + c.as_str(), } } pub fn try_parse_string(string: Option) -> Option where S: AsRef, T: FromStr, { if let Some(Ok(value)) = string.map(|s| s.as_ref().parse::()) { Some(value) } else { None } } // // Env methods // use std::env; pub fn get_env_str_value(key: &str) -> Option { let key_file = format!("{key}_FILE"); let value_from_env = env::var(key); let value_file = env::var(&key_file); match (value_from_env, value_file) { (Ok(_), Ok(_)) => panic!("You should not define both {key} and {key_file}!"), (Ok(v_env), Err(_)) => Some(v_env), (Err(_), Ok(v_file)) => match fs::read_to_string(v_file) { Ok(content) => Some(content.trim().to_string()), Err(e) => panic!("Failed to load {key}: {e:?}"), }, _ => None, } } pub fn get_env(key: &str) -> Option where V: FromStr, { try_parse_string(get_env_str_value(key)) } pub fn get_env_bool(key: &str) -> Option { const TRUE_VALUES: &[&str] = &["true", "t", "yes", "y", "1"]; const FALSE_VALUES: &[&str] = &["false", "f", "no", "n", "0"]; match get_env_str_value(key) { Some(val) if TRUE_VALUES.contains(&val.to_lowercase().as_ref()) => Some(true), Some(val) if FALSE_VALUES.contains(&val.to_lowercase().as_ref()) => Some(false), _ => None, } } // // Date util methods // use chrono::{DateTime, Local, NaiveDateTime, TimeZone}; // Format used by Bitwarden API const DATETIME_FORMAT: &str = "%Y-%m-%dT%H:%M:%S%.6fZ"; /// Formats a UTC-offset `NaiveDateTime` in the format used by Bitwarden API /// responses with "date" fields (`CreationDate`, `RevisionDate`, etc.). pub fn format_date(dt: &NaiveDateTime) -> String { dt.format(DATETIME_FORMAT).to_string() } /// Formats a `DateTime` using the specified format string. /// /// For a `DateTime`, the `%Z` specifier normally formats as the /// time zone's UTC offset (e.g., `+00:00`). In this function, if the /// `TZ` environment variable is set, then `%Z` instead formats as the /// abbreviation for that time zone (e.g., `UTC`). pub fn format_datetime_local(dt: &DateTime, fmt: &str) -> String { // Try parsing the `TZ` environment variable to enable formatting `%Z` as // a time zone abbreviation. if let Ok(tz) = env::var("TZ") { if let Ok(tz) = tz.parse::() { return dt.with_timezone(&tz).format(fmt).to_string(); } } // Otherwise, fall back to formatting `%Z` as a UTC offset. dt.format(fmt).to_string() } /// Formats a UTC-offset `NaiveDateTime` as a datetime in the local time zone. /// /// This function basically converts the `NaiveDateTime` to a `DateTime`, /// and then calls [format_datetime_local](crate::util::format_datetime_local). pub fn format_naive_datetime_local(dt: &NaiveDateTime, fmt: &str) -> String { format_datetime_local(&Local.from_utc_datetime(dt), fmt) } /// Formats a `DateTime` as required for HTTP /// /// https://httpwg.org/specs/rfc7231.html#http.date pub fn format_datetime_http(dt: &DateTime) -> String { let expiry_time = DateTime::::from_naive_utc_and_offset(dt.naive_utc(), chrono::Utc); // HACK: HTTP expects the date to always be GMT (UTC) rather than giving an // offset (which would always be 0 in UTC anyway) expiry_time.to_rfc2822().replace("+0000", "GMT") } pub fn parse_date(date: &str) -> NaiveDateTime { NaiveDateTime::parse_from_str(date, DATETIME_FORMAT).unwrap() } // // Deployment environment methods // /// Returns true if the program is running in Docker or Podman. pub fn is_running_in_docker() -> bool { Path::new("/.dockerenv").exists() || Path::new("/run/.containerenv").exists() } /// Simple check to determine on which docker base image vaultwarden is running. /// We build images based upon Debian or Alpine, so these we check here. pub fn docker_base_image() -> &'static str { if Path::new("/etc/debian_version").exists() { "Debian" } else if Path::new("/etc/alpine-release").exists() { "Alpine" } else { "Unknown" } } // // Deserialization methods // use std::fmt; use serde::de::{self, DeserializeOwned, Deserializer, MapAccess, SeqAccess, Visitor}; use serde_json::{self, Value}; pub type JsonMap = serde_json::Map; #[derive(Serialize, Deserialize)] pub struct UpCase { #[serde(deserialize_with = "upcase_deserialize")] #[serde(flatten)] pub data: T, } // https://github.com/serde-rs/serde/issues/586 pub fn upcase_deserialize<'de, T, D>(deserializer: D) -> Result where T: DeserializeOwned, D: Deserializer<'de>, { let d = deserializer.deserialize_any(UpCaseVisitor)?; T::deserialize(d).map_err(de::Error::custom) } struct UpCaseVisitor; impl<'de> Visitor<'de> for UpCaseVisitor { type Value = Value; fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { formatter.write_str("an object or an array") } fn visit_map(self, mut map: A) -> Result where A: MapAccess<'de>, { let mut result_map = JsonMap::new(); while let Some((key, value)) = map.next_entry()? { result_map.insert(upcase_first(key), upcase_value(value)); } Ok(Value::Object(result_map)) } fn visit_seq(self, mut seq: A) -> Result where A: SeqAccess<'de>, { let mut result_seq = Vec::::new(); while let Some(value) = seq.next_element()? { result_seq.push(upcase_value(value)); } Ok(Value::Array(result_seq)) } } fn upcase_value(value: Value) -> Value { if let Value::Object(map) = value { let mut new_value = Value::Object(serde_json::Map::new()); for (key, val) in map.into_iter() { let processed_key = _process_key(&key); new_value[processed_key] = upcase_value(val); } new_value } else if let Value::Array(array) = value { // Initialize array with null values let mut new_value = Value::Array(vec![Value::Null; array.len()]); for (index, val) in array.into_iter().enumerate() { new_value[index] = upcase_value(val); } new_value } else { value } } // Inner function to handle a special case for the 'ssn' key. // This key is part of the Identity Cipher (Social Security Number) fn _process_key(key: &str) -> String { match key.to_lowercase().as_ref() { "ssn" => "SSN".into(), _ => self::upcase_first(key), } } // // Retry methods // pub fn retry(mut func: F, max_tries: u32) -> Result where F: FnMut() -> Result, { let mut tries = 0; loop { match func() { ok @ Ok(_) => return ok, err @ Err(_) => { tries += 1; if tries >= max_tries { return err; } Handle::current().block_on(sleep(Duration::from_millis(500))); } } } } pub async fn retry_db(mut func: F, max_tries: u32) -> Result where F: FnMut() -> Result, E: std::error::Error, { let mut tries = 0; loop { match func() { ok @ Ok(_) => return ok, Err(e) => { tries += 1; if tries >= max_tries && max_tries > 0 { return Err(e); } warn!("Can't connect to database, retrying: {:?}", e); sleep(Duration::from_millis(1_000)).await; } } } } use reqwest::{header, Client, ClientBuilder}; pub fn get_reqwest_client() -> Client { match get_reqwest_client_builder().build() { Ok(client) => client, Err(e) => { error!("Possible trust-dns error, trying with trust-dns disabled: '{e}'"); get_reqwest_client_builder().trust_dns(false).build().expect("Failed to build client") } } } pub fn get_reqwest_client_builder() -> ClientBuilder { let mut headers = header::HeaderMap::new(); headers.insert(header::USER_AGENT, header::HeaderValue::from_static("Vaultwarden")); Client::builder().default_headers(headers).timeout(Duration::from_secs(10)) } pub fn convert_json_key_lcase_first(src_json: Value) -> Value { match src_json { Value::Array(elm) => { let mut new_array: Vec = Vec::with_capacity(elm.len()); for obj in elm { new_array.push(convert_json_key_lcase_first(obj)); } Value::Array(new_array) } Value::Object(obj) => { let mut json_map = JsonMap::new(); for (key, value) in obj.iter() { match (key, value) { (key, Value::Object(elm)) => { let inner_value = convert_json_key_lcase_first(Value::Object(elm.clone())); json_map.insert(lcase_first(key), inner_value); } (key, Value::Array(elm)) => { let mut inner_array: Vec = Vec::with_capacity(elm.len()); for inner_obj in elm { inner_array.push(convert_json_key_lcase_first(inner_obj.clone())); } json_map.insert(lcase_first(key), Value::Array(inner_array)); } (key, value) => { json_map.insert(lcase_first(key), value.clone()); } } } Value::Object(json_map) } value => value, } }