From 189fd7780619cfe14ce6250c23363cc6f56df7d3 Mon Sep 17 00:00:00 2001 From: Stefan Melmuk Date: Sat, 21 Dec 2024 13:31:47 +0100 Subject: [PATCH] introduce group_id newtype pattern --- src/api/core/organizations.rs | 97 ++++++++++++++++++----------------- src/db/models/group.rs | 82 ++++++++++++++++++++++++----- src/db/models/mod.rs | 2 +- src/db/models/organization.rs | 6 +-- 4 files changed, 123 insertions(+), 64 deletions(-) diff --git a/src/api/core/organizations.rs b/src/api/core/organizations.rs index b7ac7dcb..6b5048bc 100644 --- a/src/api/core/organizations.rs +++ b/src/api/core/organizations.rs @@ -124,7 +124,7 @@ struct OrganizationUpdateData { #[serde(rename_all = "camelCase")] struct NewCollectionData { name: String, - groups: Vec, + groups: Vec, users: Vec, id: Option, external_id: Option, @@ -132,9 +132,9 @@ struct NewCollectionData { #[derive(Deserialize)] #[serde(rename_all = "camelCase")] -struct NewCollectionObjectData { +struct NewCollectionGroupData { hide_passwords: bool, - id: String, + id: GroupId, read_only: bool, } @@ -155,8 +155,8 @@ struct OrgKeyData { #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] -struct OrgBulkIds { - ids: Vec, +struct BulkGroupIds { + ids: Vec, } #[derive(Deserialize, Debug)] @@ -367,7 +367,7 @@ async fn get_org_collections_details( .await .iter() .map(|collection_group| { - SelectionReadOnly::to_collection_group_details_read_only(collection_group).to_json() + GroupSelection::to_collection_group_details_read_only(collection_group).to_json() }) .collect() } else { @@ -651,7 +651,7 @@ async fn get_org_collection_detail( .await .iter() .map(|collection_group| { - SelectionReadOnly::to_collection_group_details_read_only(collection_group).to_json() + GroupSelection::to_collection_group_details_read_only(collection_group).to_json() }) .collect() } else { @@ -856,7 +856,7 @@ struct MembershipData { #[serde(rename_all = "camelCase")] struct InviteData { emails: Vec, - groups: Vec, + groups: Vec, r#type: NumberOrString, collections: Option>, #[serde(default)] @@ -942,8 +942,8 @@ async fn send_invite( new_member.save(&mut conn).await?; - for group in data.groups.iter() { - let mut group_entry = GroupUser::new(String::from(group), new_member.uuid.clone()); + for group_id in data.groups.iter() { + let mut group_entry = GroupUser::new(group_id.clone(), new_member.uuid.clone()); group_entry.save(&mut conn).await?; } @@ -1330,7 +1330,7 @@ async fn get_user( struct EditUserData { r#type: NumberOrString, collections: Option>, - groups: Option>, + groups: Option>, #[serde(default)] access_all: bool, } @@ -1432,8 +1432,8 @@ async fn edit_user( GroupUser::delete_all_by_member(&member_to_edit.uuid, &mut conn).await?; - for group in data.groups.iter().flatten() { - let mut group_entry = GroupUser::new(String::from(group), member_to_edit.uuid.clone()); + for group_id in data.groups.iter().flatten() { + let mut group_entry = GroupUser::new(group_id.clone(), member_to_edit.uuid.clone()); group_entry.save(&mut conn).await?; } @@ -2379,13 +2379,13 @@ impl GroupRequest { #[derive(Deserialize, Serialize)] #[serde(rename_all = "camelCase")] -struct SelectionReadOnly { - id: String, +struct GroupSelection { + id: GroupId, read_only: bool, hide_passwords: bool, } -impl SelectionReadOnly { +impl GroupSelection { pub fn to_collection_group_details_read_only(collection_group: &CollectionGroup) -> Self { Self { id: collection_group.groups_uuid.clone(), @@ -2408,7 +2408,7 @@ struct CollectionSelection { } impl CollectionSelection { - pub fn to_collection_group(&self, groups_uuid: String) -> CollectionGroup { + pub fn to_collection_group(&self, groups_uuid: GroupId) -> CollectionGroup { CollectionGroup::new(self.id.clone(), groups_uuid, self.read_only, self.hide_passwords) } } @@ -2438,7 +2438,7 @@ impl UserSelection { #[post("/organizations//groups/", data = "")] async fn post_group( org_id: OrganizationId, - group_id: &str, + group_id: GroupId, data: Json, headers: AdminHeaders, conn: DbConn, @@ -2477,7 +2477,7 @@ async fn post_groups( #[put("/organizations//groups/", data = "")] async fn put_group( org_id: OrganizationId, - group_id: &str, + group_id: GroupId, data: Json, headers: AdminHeaders, mut conn: DbConn, @@ -2486,15 +2486,15 @@ async fn put_group( err!("Group support is disabled"); } - let Some(group) = Group::find_by_uuid_and_org(group_id, &org_id, &mut conn).await else { + let Some(group) = Group::find_by_uuid_and_org(&group_id, &org_id, &mut conn).await else { err!("Group not found", "Group uuid is invalid or does not belong to the organization") }; let group_request = data.into_inner(); let updated_group = group_request.update_group(group); - CollectionGroup::delete_all_by_group(group_id, &mut conn).await?; - GroupUser::delete_all_by_group(group_id, &mut conn).await?; + CollectionGroup::delete_all_by_group(&group_id, &mut conn).await?; + GroupUser::delete_all_by_group(&group_id, &mut conn).await?; log_event( EventType::GroupUpdated as i32, @@ -2553,7 +2553,7 @@ async fn add_update_group( #[get("/organizations//groups//details")] async fn get_group_details( org_id: OrganizationId, - group_id: &str, + group_id: GroupId, _headers: AdminHeaders, mut conn: DbConn, ) -> JsonResult { @@ -2561,7 +2561,7 @@ async fn get_group_details( err!("Group support is disabled"); } - let Some(group) = Group::find_by_uuid_and_org(group_id, &org_id, &mut conn).await else { + let Some(group) = Group::find_by_uuid_and_org(&group_id, &org_id, &mut conn).await else { err!("Group not found", "Group uuid is invalid or does not belong to the organization") }; @@ -2571,21 +2571,26 @@ async fn get_group_details( #[post("/organizations//groups//delete")] async fn post_delete_group( org_id: OrganizationId, - group_id: &str, + group_id: GroupId, headers: AdminHeaders, mut conn: DbConn, ) -> EmptyResult { - _delete_group(&org_id, group_id, &headers, &mut conn).await + _delete_group(&org_id, &group_id, &headers, &mut conn).await } #[delete("/organizations//groups/")] -async fn delete_group(org_id: OrganizationId, group_id: &str, headers: AdminHeaders, mut conn: DbConn) -> EmptyResult { - _delete_group(&org_id, group_id, &headers, &mut conn).await +async fn delete_group( + org_id: OrganizationId, + group_id: GroupId, + headers: AdminHeaders, + mut conn: DbConn, +) -> EmptyResult { + _delete_group(&org_id, &group_id, &headers, &mut conn).await } async fn _delete_group( org_id: &OrganizationId, - group_id: &str, + group_id: &GroupId, headers: &AdminHeaders, conn: &mut DbConn, ) -> EmptyResult { @@ -2614,7 +2619,7 @@ async fn _delete_group( #[delete("/organizations//groups", data = "")] async fn bulk_delete_groups( org_id: OrganizationId, - data: Json, + data: Json, headers: AdminHeaders, mut conn: DbConn, ) -> EmptyResult { @@ -2622,7 +2627,7 @@ async fn bulk_delete_groups( err!("Group support is disabled"); } - let data: OrgBulkIds = data.into_inner(); + let data: BulkGroupIds = data.into_inner(); for group_id in data.ids { _delete_group(&org_id, &group_id, &headers, &mut conn).await? @@ -2631,12 +2636,12 @@ async fn bulk_delete_groups( } #[get("/organizations//groups/")] -async fn get_group(org_id: OrganizationId, group_id: &str, _headers: AdminHeaders, mut conn: DbConn) -> JsonResult { +async fn get_group(org_id: OrganizationId, group_id: GroupId, _headers: AdminHeaders, mut conn: DbConn) -> JsonResult { if !CONFIG.org_groups_enabled() { err!("Group support is disabled"); } - let Some(group) = Group::find_by_uuid_and_org(group_id, &org_id, &mut conn).await else { + let Some(group) = Group::find_by_uuid_and_org(&group_id, &org_id, &mut conn).await else { err!("Group not found", "Group uuid is invalid or does not belong to the organization") }; @@ -2646,7 +2651,7 @@ async fn get_group(org_id: OrganizationId, group_id: &str, _headers: AdminHeader #[get("/organizations//groups//users")] async fn get_group_users( org_id: OrganizationId, - group_id: &str, + group_id: GroupId, _headers: AdminHeaders, mut conn: DbConn, ) -> JsonResult { @@ -2654,11 +2659,11 @@ async fn get_group_users( err!("Group support is disabled"); } - if Group::find_by_uuid_and_org(group_id, &org_id, &mut conn).await.is_none() { + if Group::find_by_uuid_and_org(&&group_id, &org_id, &mut conn).await.is_none() { err!("Group could not be found!", "Group uuid is invalid or does not belong to the organization") }; - let group_users: Vec = GroupUser::find_by_group(group_id, &mut conn) + let group_users: Vec = GroupUser::find_by_group(&group_id, &mut conn) .await .iter() .map(|entry| entry.users_organizations_uuid.clone()) @@ -2670,7 +2675,7 @@ async fn get_group_users( #[put("/organizations//groups//users", data = "")] async fn put_group_users( org_id: OrganizationId, - group_id: &str, + group_id: GroupId, headers: AdminHeaders, data: Json>, mut conn: DbConn, @@ -2679,15 +2684,15 @@ async fn put_group_users( err!("Group support is disabled"); } - if Group::find_by_uuid_and_org(group_id, &org_id, &mut conn).await.is_none() { + if Group::find_by_uuid_and_org(&group_id, &org_id, &mut conn).await.is_none() { err!("Group could not be found!", "Group uuid is invalid or does not belong to the organization") }; - GroupUser::delete_all_by_group(group_id, &mut conn).await?; + GroupUser::delete_all_by_group(&group_id, &mut conn).await?; let assigned_members = data.into_inner(); for assigned_member in assigned_members { - let mut user_entry = GroupUser::new(String::from(group_id), assigned_member.clone()); + let mut user_entry = GroupUser::new(group_id.clone(), assigned_member.clone()); user_entry.save(&mut conn).await?; log_event( @@ -2720,7 +2725,7 @@ async fn get_user_groups( err!("User could not be found!") }; - let user_groups: Vec = + let user_groups: Vec = GroupUser::find_by_member(&member_id, &mut conn).await.iter().map(|entry| entry.groups_uuid.clone()).collect(); Ok(Json(json!(user_groups))) @@ -2729,7 +2734,7 @@ async fn get_user_groups( #[derive(Deserialize)] #[serde(rename_all = "camelCase")] struct OrganizationUserUpdateGroupsRequest { - group_ids: Vec, + group_ids: Vec, } #[post("/organizations//users//groups", data = "")] @@ -2784,7 +2789,7 @@ async fn put_user_groups( #[post("/organizations//groups//delete-user/")] async fn post_delete_group_user( org_id: OrganizationId, - group_id: &str, + group_id: GroupId, member_id: MembershipId, headers: AdminHeaders, conn: DbConn, @@ -2795,7 +2800,7 @@ async fn post_delete_group_user( #[delete("/organizations//groups//users/")] async fn delete_group_user( org_id: OrganizationId, - group_id: &str, + group_id: GroupId, member_id: MembershipId, headers: AdminHeaders, mut conn: DbConn, @@ -2808,7 +2813,7 @@ async fn delete_group_user( err!("User could not be found or does not belong to the organization."); } - if Group::find_by_uuid_and_org(group_id, &org_id, &mut conn).await.is_none() { + if Group::find_by_uuid_and_org(&group_id, &org_id, &mut conn).await.is_none() { err!("Group could not be found or does not belong to the organization."); } @@ -2823,7 +2828,7 @@ async fn delete_group_user( ) .await; - GroupUser::delete_by_group_and_member(group_id, &member_id, &mut conn).await + GroupUser::delete_by_group_and_member(&group_id, &member_id, &mut conn).await } #[derive(Deserialize)] diff --git a/src/db/models/group.rs b/src/db/models/group.rs index f742b64b..1ed23f04 100644 --- a/src/db/models/group.rs +++ b/src/db/models/group.rs @@ -3,14 +3,20 @@ use crate::api::EmptyResult; use crate::db::DbConn; use crate::error::MapResult; use chrono::{NaiveDateTime, Utc}; +use rocket::request::FromParam; use serde_json::Value; +use std::{ + borrow::Borrow, + fmt::{Display, Formatter}, + ops::Deref, +}; db_object! { #[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[diesel(table_name = groups)] #[diesel(primary_key(uuid))] pub struct Group { - pub uuid: String, + pub uuid: GroupId, pub organizations_uuid: OrganizationId, pub name: String, pub access_all: bool, @@ -24,7 +30,7 @@ db_object! { #[diesel(primary_key(collections_uuid, groups_uuid))] pub struct CollectionGroup { pub collections_uuid: CollectionId, - pub groups_uuid: String, + pub groups_uuid: GroupId, pub read_only: bool, pub hide_passwords: bool, } @@ -33,7 +39,7 @@ db_object! { #[diesel(table_name = groups_users)] #[diesel(primary_key(groups_uuid, users_organizations_uuid))] pub struct GroupUser { - pub groups_uuid: String, + pub groups_uuid: GroupId, pub users_organizations_uuid: MembershipId } } @@ -49,7 +55,7 @@ impl Group { let now = Utc::now().naive_utc(); let mut new_model = Self { - uuid: crate::util::get_uuid(), + uuid: GroupId(crate::util::get_uuid()), organizations_uuid, name, access_all, @@ -113,7 +119,7 @@ impl Group { } impl CollectionGroup { - pub fn new(collections_uuid: CollectionId, groups_uuid: String, read_only: bool, hide_passwords: bool) -> Self { + pub fn new(collections_uuid: CollectionId, groups_uuid: GroupId, read_only: bool, hide_passwords: bool) -> Self { Self { collections_uuid, groups_uuid, @@ -124,7 +130,7 @@ impl CollectionGroup { } impl GroupUser { - pub fn new(groups_uuid: String, users_organizations_uuid: MembershipId) -> Self { + pub fn new(groups_uuid: GroupId, users_organizations_uuid: MembershipId) -> Self { Self { groups_uuid, users_organizations_uuid, @@ -196,7 +202,7 @@ impl Group { }} } - pub async fn find_by_uuid_and_org(uuid: &str, org_uuid: &OrganizationId, conn: &mut DbConn) -> Option { + pub async fn find_by_uuid_and_org(uuid: &GroupId, org_uuid: &OrganizationId, conn: &mut DbConn) -> Option { db_run! { conn: { groups::table .filter(groups::uuid.eq(uuid)) @@ -269,13 +275,13 @@ impl Group { }} } - pub async fn update_revision(uuid: &str, conn: &mut DbConn) { + pub async fn update_revision(uuid: &GroupId, conn: &mut DbConn) { if let Err(e) = Self::_update_revision(uuid, &Utc::now().naive_utc(), conn).await { warn!("Failed to update revision for {}: {:#?}", uuid, e); } } - async fn _update_revision(uuid: &str, date: &NaiveDateTime, conn: &mut DbConn) -> EmptyResult { + async fn _update_revision(uuid: &GroupId, date: &NaiveDateTime, conn: &mut DbConn) -> EmptyResult { db_run! {conn: { crate::util::retry(|| { diesel::update(groups::table.filter(groups::uuid.eq(uuid))) @@ -343,7 +349,7 @@ impl CollectionGroup { } } - pub async fn find_by_group(group_uuid: &str, conn: &mut DbConn) -> Vec { + pub async fn find_by_group(group_uuid: &GroupId, conn: &mut DbConn) -> Vec { db_run! { conn: { collections_groups::table .filter(collections_groups::groups_uuid.eq(group_uuid)) @@ -396,7 +402,7 @@ impl CollectionGroup { }} } - pub async fn delete_all_by_group(group_uuid: &str, conn: &mut DbConn) -> EmptyResult { + pub async fn delete_all_by_group(group_uuid: &GroupId, conn: &mut DbConn) -> EmptyResult { let group_users = GroupUser::find_by_group(group_uuid, conn).await; for group_user in group_users { group_user.update_user_revision(conn).await; @@ -475,7 +481,7 @@ impl GroupUser { } } - pub async fn find_by_group(group_uuid: &str, conn: &mut DbConn) -> Vec { + pub async fn find_by_group(group_uuid: &GroupId, conn: &mut DbConn) -> Vec { db_run! { conn: { groups_users::table .filter(groups_users::groups_uuid.eq(group_uuid)) @@ -540,7 +546,7 @@ impl GroupUser { } pub async fn delete_by_group_and_member( - group_uuid: &str, + group_uuid: &GroupId, member_uuid: &MembershipId, conn: &mut DbConn, ) -> EmptyResult { @@ -558,7 +564,7 @@ impl GroupUser { }} } - pub async fn delete_all_by_group(group_uuid: &str, conn: &mut DbConn) -> EmptyResult { + pub async fn delete_all_by_group(group_uuid: &GroupId, conn: &mut DbConn) -> EmptyResult { let group_users = GroupUser::find_by_group(group_uuid, conn).await; for group_user in group_users { group_user.update_user_revision(conn).await; @@ -586,3 +592,51 @@ impl GroupUser { }} } } + +#[derive(DieselNewType, FromForm, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub struct GroupId(String); + +impl AsRef for GroupId { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl Deref for GroupId { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Borrow for GroupId { + fn borrow(&self) -> &str { + &self.0 + } +} + +impl Display for GroupId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for GroupId { + fn from(raw: String) -> Self { + Self(raw) + } +} + +impl<'r> FromParam<'r> for GroupId { + type Error = (); + + #[inline(always)] + fn from_param(param: &'r str) -> Result { + if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { + Ok(Self(param.to_string())) + } else { + Err(()) + } + } +} diff --git a/src/db/models/mod.rs b/src/db/models/mod.rs index a96c5bb9..e10951e5 100644 --- a/src/db/models/mod.rs +++ b/src/db/models/mod.rs @@ -25,7 +25,7 @@ pub use self::emergency_access::{EmergencyAccess, EmergencyAccessStatus, Emergen pub use self::event::{Event, EventType}; pub use self::favorite::Favorite; pub use self::folder::{Folder, FolderCipher}; -pub use self::group::{CollectionGroup, Group, GroupUser}; +pub use self::group::{CollectionGroup, Group, GroupId, GroupUser}; pub use self::org_policy::{OrgPolicy, OrgPolicyErr, OrgPolicyType}; pub use self::organization::{ Membership, MembershipId, MembershipStatus, MembershipType, Organization, OrganizationApiKey, OrganizationId, diff --git a/src/db/models/organization.rs b/src/db/models/organization.rs index 9e39161a..016bd3da 100644 --- a/src/db/models/organization.rs +++ b/src/db/models/organization.rs @@ -11,8 +11,8 @@ use std::{ }; use super::{ - Collection, CollectionGroup, CollectionId, CollectionUser, Group, GroupUser, OrgPolicy, OrgPolicyType, TwoFactor, - User, UserId, + Collection, CollectionGroup, CollectionId, CollectionUser, Group, GroupId, GroupUser, OrgPolicy, OrgPolicyType, + TwoFactor, User, UserId, }; use crate::CONFIG; @@ -460,7 +460,7 @@ impl Membership { let twofactor_enabled = !TwoFactor::find_by_user(&user.uuid, conn).await.is_empty(); - let groups: Vec = if include_groups && CONFIG.org_groups_enabled() { + let groups: Vec = if include_groups && CONFIG.org_groups_enabled() { GroupUser::find_by_member(&self.uuid, conn).await.iter().map(|gu| gu.groups_uuid.clone()).collect() } else { // The Bitwarden clients seem to call this API regardless of whether groups are enabled,