geforkt von mirrored/vaultwarden
364 Zeilen
9,5 KiB
Rust
364 Zeilen
9,5 KiB
Rust
use rocket::Route;
|
|
use rocket_contrib::json::Json;
|
|
use serde_json::Value as JsonValue;
|
|
|
|
use crate::api::JsonResult;
|
|
use crate::auth::Headers;
|
|
use crate::db::DbConn;
|
|
|
|
use crate::CONFIG;
|
|
|
|
pub fn routes() -> Vec<Route> {
|
|
routes![negotiate, websockets_err]
|
|
}
|
|
|
|
#[get("/hub")]
|
|
fn websockets_err() -> JsonResult {
|
|
err!("'/notifications/hub' should be proxied to the websocket server or notifications won't work. Go to the README for more info.")
|
|
}
|
|
|
|
#[post("/hub/negotiate")]
|
|
fn negotiate(_headers: Headers, _conn: DbConn) -> JsonResult {
|
|
use crate::crypto;
|
|
use data_encoding::BASE64URL;
|
|
|
|
let conn_id = BASE64URL.encode(&crypto::get_random(vec![0u8; 16]));
|
|
let mut available_transports: Vec<JsonValue> = Vec::new();
|
|
|
|
if CONFIG.websocket_enabled() {
|
|
available_transports.push(json!({"transport":"WebSockets", "transferFormats":["Text","Binary"]}));
|
|
}
|
|
|
|
// TODO: Implement transports
|
|
// Rocket WS support: https://github.com/SergioBenitez/Rocket/issues/90
|
|
// Rocket SSE support: https://github.com/SergioBenitez/Rocket/issues/33
|
|
// {"transport":"ServerSentEvents", "transferFormats":["Text"]},
|
|
// {"transport":"LongPolling", "transferFormats":["Text","Binary"]}
|
|
Ok(Json(json!({
|
|
"connectionId": conn_id,
|
|
"availableTransports": available_transports
|
|
})))
|
|
}
|
|
|
|
//
|
|
// Websockets server
|
|
//
|
|
use std::sync::Arc;
|
|
use std::thread;
|
|
|
|
use ws::{self, util::Token, Factory, Handler, Handshake, Message, Sender, WebSocket};
|
|
|
|
use chashmap::CHashMap;
|
|
use chrono::NaiveDateTime;
|
|
use serde_json::from_str;
|
|
|
|
use crate::db::models::{Cipher, Folder, User};
|
|
|
|
use rmpv::Value;
|
|
|
|
fn serialize(val: Value) -> Vec<u8> {
|
|
use rmpv::encode::write_value;
|
|
|
|
let mut buf = Vec::new();
|
|
write_value(&mut buf, &val).expect("Error encoding MsgPack");
|
|
|
|
// Add size bytes at the start
|
|
// Extracted from BinaryMessageFormat.js
|
|
let mut size: usize = buf.len();
|
|
let mut len_buf: Vec<u8> = Vec::new();
|
|
|
|
loop {
|
|
let mut size_part = size & 0x7f;
|
|
size >>= 7;
|
|
|
|
if size > 0 {
|
|
size_part |= 0x80;
|
|
}
|
|
|
|
len_buf.push(size_part as u8);
|
|
|
|
if size == 0 {
|
|
break;
|
|
}
|
|
}
|
|
|
|
len_buf.append(&mut buf);
|
|
len_buf
|
|
}
|
|
|
|
fn serialize_date(date: NaiveDateTime) -> Value {
|
|
let seconds: i64 = date.timestamp();
|
|
let nanos: i64 = date.timestamp_subsec_nanos().into();
|
|
let timestamp = nanos << 34 | seconds;
|
|
|
|
let bs = timestamp.to_be_bytes();
|
|
|
|
// -1 is Timestamp
|
|
// https://github.com/msgpack/msgpack/blob/master/spec.md#timestamp-extension-type
|
|
Value::Ext(-1, bs.to_vec())
|
|
}
|
|
|
|
fn convert_option<T: Into<Value>>(option: Option<T>) -> Value {
|
|
match option {
|
|
Some(a) => a.into(),
|
|
None => Value::Nil,
|
|
}
|
|
}
|
|
|
|
// Server WebSocket handler
|
|
pub struct WSHandler {
|
|
out: Sender,
|
|
user_uuid: Option<String>,
|
|
users: WebSocketUsers,
|
|
}
|
|
|
|
const RECORD_SEPARATOR: u8 = 0x1e;
|
|
const INITIAL_RESPONSE: [u8; 3] = [0x7b, 0x7d, RECORD_SEPARATOR]; // {, }, <RS>
|
|
|
|
#[derive(Deserialize)]
|
|
struct InitialMessage {
|
|
protocol: String,
|
|
version: i32,
|
|
}
|
|
|
|
const PING_MS: u64 = 15_000;
|
|
const PING: Token = Token(1);
|
|
|
|
impl Handler for WSHandler {
|
|
fn on_open(&mut self, hs: Handshake) -> ws::Result<()> {
|
|
// TODO: Improve this split
|
|
let path = hs.request.resource();
|
|
let mut query_split: Vec<_> = path.split('?').nth(1).unwrap().split('&').collect();
|
|
query_split.sort();
|
|
let access_token = &query_split[0][13..];
|
|
let _id = &query_split[1][3..];
|
|
|
|
// Validate the user
|
|
use crate::auth;
|
|
let claims = match auth::decode_login(access_token) {
|
|
Ok(claims) => claims,
|
|
Err(_) => return Err(ws::Error::new(ws::ErrorKind::Internal, "Invalid access token provided")),
|
|
};
|
|
|
|
// Assign the user to the handler
|
|
let user_uuid = claims.sub;
|
|
self.user_uuid = Some(user_uuid.clone());
|
|
|
|
// Add the current Sender to the user list
|
|
let handler_insert = self.out.clone();
|
|
let handler_update = self.out.clone();
|
|
|
|
self.users
|
|
.map
|
|
.upsert(user_uuid, || vec![handler_insert], |ref mut v| v.push(handler_update));
|
|
|
|
// Schedule a ping to keep the connection alive
|
|
self.out.timeout(PING_MS, PING)
|
|
}
|
|
|
|
fn on_message(&mut self, msg: Message) -> ws::Result<()> {
|
|
info!("Server got message '{}'. ", msg);
|
|
|
|
if let Message::Text(text) = msg.clone() {
|
|
let json = &text[..text.len() - 1]; // Remove last char
|
|
|
|
if let Ok(InitialMessage { protocol, version }) = from_str::<InitialMessage>(json) {
|
|
if &protocol == "messagepack" && version == 1 {
|
|
return self.out.send(&INITIAL_RESPONSE[..]); // Respond to initial message
|
|
}
|
|
}
|
|
}
|
|
|
|
// If it's not the initial message, just echo the message
|
|
self.out.send(msg)
|
|
}
|
|
|
|
fn on_timeout(&mut self, event: Token) -> ws::Result<()> {
|
|
if event == PING {
|
|
// send ping
|
|
self.out.send(create_ping())?;
|
|
|
|
// reschedule the timeout
|
|
self.out.timeout(PING_MS, PING)
|
|
} else {
|
|
Err(ws::Error::new(
|
|
ws::ErrorKind::Internal,
|
|
"Invalid timeout token provided",
|
|
))
|
|
}
|
|
}
|
|
}
|
|
|
|
struct WSFactory {
|
|
pub users: WebSocketUsers,
|
|
}
|
|
|
|
impl WSFactory {
|
|
pub fn init() -> Self {
|
|
WSFactory {
|
|
users: WebSocketUsers {
|
|
map: Arc::new(CHashMap::new()),
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Factory for WSFactory {
|
|
type Handler = WSHandler;
|
|
|
|
fn connection_made(&mut self, out: Sender) -> Self::Handler {
|
|
WSHandler {
|
|
out,
|
|
user_uuid: None,
|
|
users: self.users.clone(),
|
|
}
|
|
}
|
|
|
|
fn connection_lost(&mut self, handler: Self::Handler) {
|
|
// Remove handler
|
|
if let Some(user_uuid) = &handler.user_uuid {
|
|
if let Some(mut user_conn) = self.users.map.get_mut(user_uuid) {
|
|
user_conn.remove_item(&handler.out);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct WebSocketUsers {
|
|
map: Arc<CHashMap<String, Vec<Sender>>>,
|
|
}
|
|
|
|
impl WebSocketUsers {
|
|
fn send_update(&self, user_uuid: &str, data: &[u8]) -> ws::Result<()> {
|
|
if let Some(user) = self.map.get(user_uuid) {
|
|
for sender in user.iter() {
|
|
sender.send(data)?;
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
// NOTE: The last modified date needs to be updated before calling these methods
|
|
pub fn send_user_update(&self, ut: UpdateType, user: &User) {
|
|
let data = create_update(
|
|
vec![
|
|
("UserId".into(), user.uuid.clone().into()),
|
|
("Date".into(), serialize_date(user.updated_at)),
|
|
],
|
|
ut,
|
|
);
|
|
|
|
self.send_update(&user.uuid, &data).ok();
|
|
}
|
|
|
|
pub fn send_folder_update(&self, ut: UpdateType, folder: &Folder) {
|
|
let data = create_update(
|
|
vec![
|
|
("Id".into(), folder.uuid.clone().into()),
|
|
("UserId".into(), folder.user_uuid.clone().into()),
|
|
("RevisionDate".into(), serialize_date(folder.updated_at)),
|
|
],
|
|
ut,
|
|
);
|
|
|
|
self.send_update(&folder.user_uuid, &data).ok();
|
|
}
|
|
|
|
pub fn send_cipher_update(&self, ut: UpdateType, cipher: &Cipher, user_uuids: &[String]) {
|
|
let user_uuid = convert_option(cipher.user_uuid.clone());
|
|
let org_uuid = convert_option(cipher.organization_uuid.clone());
|
|
|
|
let data = create_update(
|
|
vec![
|
|
("Id".into(), cipher.uuid.clone().into()),
|
|
("UserId".into(), user_uuid),
|
|
("OrganizationId".into(), org_uuid),
|
|
("CollectionIds".into(), Value::Nil),
|
|
("RevisionDate".into(), serialize_date(cipher.updated_at)),
|
|
],
|
|
ut,
|
|
);
|
|
|
|
for uuid in user_uuids {
|
|
self.send_update(&uuid, &data).ok();
|
|
}
|
|
}
|
|
}
|
|
|
|
/* Message Structure
|
|
[
|
|
1, // MessageType.Invocation
|
|
{}, // Headers
|
|
null, // InvocationId
|
|
"ReceiveMessage", // Target
|
|
[ // Arguments
|
|
{
|
|
"ContextId": "app_id",
|
|
"Type": ut as i32,
|
|
"Payload": {}
|
|
}
|
|
]
|
|
]
|
|
*/
|
|
fn create_update(payload: Vec<(Value, Value)>, ut: UpdateType) -> Vec<u8> {
|
|
use rmpv::Value as V;
|
|
|
|
let value = V::Array(vec![
|
|
1.into(),
|
|
V::Array(vec![]),
|
|
V::Nil,
|
|
"ReceiveMessage".into(),
|
|
V::Array(vec![V::Map(vec![
|
|
("ContextId".into(), "app_id".into()),
|
|
("Type".into(), (ut as i32).into()),
|
|
("Payload".into(), payload.into()),
|
|
])]),
|
|
]);
|
|
|
|
serialize(value)
|
|
}
|
|
|
|
fn create_ping() -> Vec<u8> {
|
|
serialize(Value::Array(vec![6.into()]))
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
#[derive(PartialEq)]
|
|
pub enum UpdateType {
|
|
CipherUpdate = 0,
|
|
CipherCreate = 1,
|
|
LoginDelete = 2,
|
|
FolderDelete = 3,
|
|
Ciphers = 4,
|
|
|
|
Vault = 5,
|
|
OrgKeys = 6,
|
|
FolderCreate = 7,
|
|
FolderUpdate = 8,
|
|
CipherDelete = 9,
|
|
SyncSettings = 10,
|
|
|
|
LogOut = 11,
|
|
|
|
None = 100,
|
|
}
|
|
|
|
use rocket::State;
|
|
pub type Notify<'a> = State<'a, WebSocketUsers>;
|
|
|
|
pub fn start_notification_server() -> WebSocketUsers {
|
|
let factory = WSFactory::init();
|
|
let users = factory.users.clone();
|
|
|
|
if CONFIG.websocket_enabled() {
|
|
thread::spawn(move || {
|
|
WebSocket::new(factory)
|
|
.unwrap()
|
|
.listen((CONFIG.websocket_address().as_str(), CONFIG.websocket_port()))
|
|
.unwrap();
|
|
});
|
|
}
|
|
|
|
users
|
|
}
|