/// /// JWT Handling /// use crate::util::read_file; use chrono::Duration; use jsonwebtoken::{self, Algorithm, Header}; use serde::ser::Serialize; use crate::CONFIG; const JWT_ALGORITHM: Algorithm = Algorithm::RS256; lazy_static! { pub static ref DEFAULT_VALIDITY: Duration = Duration::hours(2); pub static ref JWT_ISSUER: String = CONFIG.domain.clone(); static ref JWT_HEADER: Header = Header::new(JWT_ALGORITHM); static ref PRIVATE_RSA_KEY: Vec = match read_file(&CONFIG.private_rsa_key) { Ok(key) => key, Err(e) => panic!("Error loading private RSA Key from {}\n Error: {}", CONFIG.private_rsa_key, e) }; static ref PUBLIC_RSA_KEY: Vec = match read_file(&CONFIG.public_rsa_key) { Ok(key) => key, Err(e) => panic!("Error loading public RSA Key from {}\n Error: {}", CONFIG.public_rsa_key, e) }; } pub fn encode_jwt(claims: &T) -> String { match jsonwebtoken::encode(&JWT_HEADER, claims, &PRIVATE_RSA_KEY) { Ok(token) => token, Err(e) => panic!("Error encoding jwt {}", e) } } pub fn decode_jwt(token: &str) -> Result { let validation = jsonwebtoken::Validation { leeway: 30, // 30 seconds validate_exp: true, validate_iat: false, // IssuedAt is the same as NotBefore validate_nbf: true, aud: None, iss: Some(JWT_ISSUER.clone()), sub: None, algorithms: vec![JWT_ALGORITHM], }; match jsonwebtoken::decode(token, &PUBLIC_RSA_KEY, &validation) { Ok(decoded) => Ok(decoded.claims), Err(msg) => { error!("Error validating jwt - {:#?}", msg); Err(msg.to_string()) } } } pub fn decode_invite_jwt(token: &str) -> Result { let validation = jsonwebtoken::Validation { leeway: 30, // 30 seconds validate_exp: true, validate_iat: false, // IssuedAt is the same as NotBefore validate_nbf: true, aud: None, iss: Some(JWT_ISSUER.clone()), sub: None, algorithms: vec![JWT_ALGORITHM], }; match jsonwebtoken::decode(token, &PUBLIC_RSA_KEY, &validation) { Ok(decoded) => Ok(decoded.claims), Err(msg) => { error!("Error validating jwt - {:#?}", msg); Err(msg.to_string()) } } } #[derive(Debug, Serialize, Deserialize)] pub struct JWTClaims { // Not before pub nbf: i64, // Expiration time pub exp: i64, // Issuer pub iss: String, // Subject pub sub: String, pub premium: bool, pub name: String, pub email: String, pub email_verified: bool, pub orgowner: Vec, pub orgadmin: Vec, pub orguser: Vec, pub orgmanager: Vec, // user security_stamp pub sstamp: String, // device uuid pub device: String, // [ "api", "offline_access" ] pub scope: Vec, // [ "Application" ] pub amr: Vec, } #[derive(Debug, Serialize, Deserialize)] pub struct InviteJWTClaims { // Not before pub nbf: i64, // Expiration time pub exp: i64, // Issuer pub iss: String, // Subject pub sub: String, pub email: String, pub org_id: String, pub user_org_id: Option, } /// /// Bearer token authentication /// use rocket::Outcome; use rocket::request::{self, Request, FromRequest}; use crate::db::DbConn; use crate::db::models::{User, UserOrganization, UserOrgType, UserOrgStatus, Device}; pub struct Headers { pub host: String, pub device: Device, pub user: User, } impl<'a, 'r> FromRequest<'a, 'r> for Headers { type Error = &'static str; fn from_request(request: &'a Request<'r>) -> request::Outcome { let headers = request.headers(); // Get host let host = if CONFIG.domain_set { CONFIG.domain.clone() } else if let Some(referer) = headers.get_one("Referer") { referer.to_string() } else { // Try to guess from the headers use std::env; let protocol = if let Some(proto) = headers.get_one("X-Forwarded-Proto") { proto } else if env::var("ROCKET_TLS").is_ok() { "https" } else { "http" }; let host = if let Some(host) = headers.get_one("X-Forwarded-Host") { host } else if let Some(host) = headers.get_one("Host") { host } else { "" }; format!("{}://{}", protocol, host) }; // Get access_token let access_token: &str = match headers.get_one("Authorization") { Some(a) => match a.rsplit("Bearer ").next() { Some(split) => split, None => err_handler!("No access token provided"), }, None => err_handler!("No access token provided"), }; // Check JWT token is valid and get device and user from it let claims: JWTClaims = match decode_jwt(access_token) { Ok(claims) => claims, Err(_) => err_handler!("Invalid claim") }; let device_uuid = claims.device; let user_uuid = claims.sub; let conn = match request.guard::() { Outcome::Success(conn) => conn, _ => err_handler!("Error getting DB") }; let device = match Device::find_by_uuid(&device_uuid, &conn) { Some(device) => device, None => err_handler!("Invalid device id") }; let user = match User::find_by_uuid(&user_uuid, &conn) { Some(user) => user, None => err_handler!("Device has no user associated") }; if user.security_stamp != claims.sstamp { err_handler!("Invalid security stamp") } Outcome::Success(Headers { host, device, user }) } } pub struct OrgHeaders { pub host: String, pub device: Device, pub user: User, pub org_user_type: UserOrgType, } impl<'a, 'r> FromRequest<'a, 'r> for OrgHeaders { type Error = &'static str; fn from_request(request: &'a Request<'r>) -> request::Outcome { match request.guard::() { Outcome::Forward(_) => Outcome::Forward(()), Outcome::Failure(f) => Outcome::Failure(f), Outcome::Success(headers) => { // org_id is expected to be the second param ("/organizations/") match request.get_param::(1) { Some(Ok(org_id)) => { let conn = match request.guard::() { Outcome::Success(conn) => conn, _ => err_handler!("Error getting DB") }; let org_user = match UserOrganization::find_by_user_and_org(&headers.user.uuid, &org_id, &conn) { Some(user) => { if user.status == UserOrgStatus::Confirmed as i32 { user } else { err_handler!("The current user isn't confirmed member of the organization") } } None => err_handler!("The current user isn't member of the organization") }; Outcome::Success(Self{ host: headers.host, device: headers.device, user: headers.user, org_user_type: { if let Some(org_usr_type) = UserOrgType::from_i32(org_user.type_) { org_usr_type } else { // This should only happen if the DB is corrupted err_handler!("Unknown user type in the database") } }, }) }, _ => err_handler!("Error getting the organization id"), } } } } } pub struct AdminHeaders { pub host: String, pub device: Device, pub user: User, pub org_user_type: UserOrgType, } impl<'a, 'r> FromRequest<'a, 'r> for AdminHeaders { type Error = &'static str; fn from_request(request: &'a Request<'r>) -> request::Outcome { match request.guard::() { Outcome::Forward(_) => Outcome::Forward(()), Outcome::Failure(f) => Outcome::Failure(f), Outcome::Success(headers) => { if headers.org_user_type >= UserOrgType::Admin { Outcome::Success(Self { host: headers.host, device: headers.device, user: headers.user, org_user_type: headers.org_user_type, }) } else { err_handler!("You need to be Admin or Owner to call this endpoint") } } } } } pub struct OwnerHeaders { pub host: String, pub device: Device, pub user: User, } impl<'a, 'r> FromRequest<'a, 'r> for OwnerHeaders { type Error = &'static str; fn from_request(request: &'a Request<'r>) -> request::Outcome { match request.guard::() { Outcome::Forward(_) => Outcome::Forward(()), Outcome::Failure(f) => Outcome::Failure(f), Outcome::Success(headers) => { if headers.org_user_type == UserOrgType::Owner { Outcome::Success(Self { host: headers.host, device: headers.device, user: headers.user, }) } else { err_handler!("You need to be Owner to call this endpoint") } } } } } /// /// Client IP address detection /// use std::net::IpAddr; pub struct ClientIp { pub ip: IpAddr, } impl<'a, 'r> FromRequest<'a, 'r> for ClientIp { type Error = (); fn from_request(request: &'a Request<'r>) -> request::Outcome { let ip = match request.client_ip() { Some(addr) => addr, None => "0.0.0.0".parse().unwrap(), }; Outcome::Success(ClientIp { ip }) } }