From 7d0e234b34c830eae63a713177f4bea310a8ae2d Mon Sep 17 00:00:00 2001 From: Jeremy Lin Date: Sun, 7 Mar 2021 00:35:08 -0800 Subject: [PATCH] CORS fixes * The Safari extension apparently now uses the origin `file://` and expects that to be returned (see bitwarden/browser#1311, bitwarden/server#800). * The `Access-Control-Allow-Origin` header was reflecting the value of the `Origin` header without checking whether the origin was actually allowed. This effectively allows any origin to interact with the server, which defeats the purpose of CORS. --- src/util.rs | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/util.rs b/src/util.rs index 82343bcf..de663583 100644 --- a/src/util.rs +++ b/src/util.rs @@ -48,10 +48,16 @@ impl CORS { } } - fn valid_url(url: String) -> String { - match url.as_ref() { - "file://" => "*".to_string(), - _ => url, + // Check a request's `Origin` header against the list of allowed origins. + // If a match exists, return it. Otherwise, return None. + fn get_allowed_origin(headers: &HeaderMap) -> Option { + let origin = CORS::get_header(headers, "Origin"); + let domain_origin = CONFIG.domain_origin(); + let safari_extension_origin = "file://"; + if origin == domain_origin || origin == safari_extension_origin { + Some(origin) + } else { + None } } } @@ -67,11 +73,11 @@ impl Fairing for CORS { fn on_response(&self, request: &Request, response: &mut Response) { let req_headers = request.headers(); - // We need to explicitly get the Origin header for Access-Control-Allow-Origin - let req_allow_origin = CORS::valid_url(CORS::get_header(req_headers, "Origin")); - - response.set_header(Header::new("Access-Control-Allow-Origin", req_allow_origin)); + if let Some(origin) = CORS::get_allowed_origin(req_headers) { + response.set_header(Header::new("Access-Control-Allow-Origin", origin)); + } + // Preflight request if request.method() == Method::Options { let req_allow_headers = CORS::get_header(req_headers, "Access-Control-Request-Headers"); let req_allow_method = CORS::get_header(req_headers, "Access-Control-Request-Method");