From 3e886e670e7e026c188af1d856ca8ae2cc3113b5 Mon Sep 17 00:00:00 2001 From: Chase Douglas Date: Mon, 3 Feb 2025 13:19:37 -0800 Subject: [PATCH] AWS Aurora DSQL support --- Cargo.lock | 446 +++++++++++++++++- Cargo.toml | 5 + build.rs | 5 +- .../metadata.toml | 1 + .../2024-12-30-100000_create_tables/up.sql | 281 +++++++++++ src/db/dsql.rs | 184 ++++++++ src/db/mod.rs | 44 +- 7 files changed, 953 insertions(+), 13 deletions(-) create mode 100644 migrations/dsql/2024-12-30-100000_create_tables/metadata.toml create mode 100644 migrations/dsql/2024-12-30-100000_create_tables/up.sql create mode 100644 src/db/dsql.rs diff --git a/Cargo.lock b/Cargo.lock index c925f664..71088b56 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -311,6 +311,331 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "aws-config" +version = "1.5.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "649316840239f4e58df0b7f620c428f5fababbbca2d504488c641534050bd141" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-sdk-sso", + "aws-sdk-ssooidc", + "aws-sdk-sts", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "hex", + "http 0.2.12", + "ring", + "time", + "tokio", + "tracing", + "url", + "zeroize", +] + +[[package]] +name = "aws-credential-types" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60e8f6b615cb5fc60a98132268508ad104310f0cfb25a1c22eee76efdf9154da" +dependencies = [ + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "zeroize", +] + +[[package]] +name = "aws-runtime" +version = "1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44f6f1124d6e19ab6daf7f2e615644305dc6cb2d706892a8a8c0b98db35de020" +dependencies = [ + "aws-credential-types", + "aws-sigv4", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "http-body 0.4.6", + "once_cell", + "percent-encoding", + "pin-project-lite", + "tracing", + "uuid", +] + +[[package]] +name = "aws-sdk-dsql" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4ab97b30c3c494278a3386dd83335c4f50de3f963305a4642b1becf8d2dcfd2" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-sigv4", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "once_cell", + "regex-lite", + "tracing", + "url", +] + +[[package]] +name = "aws-sdk-sso" +version = "1.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb25f7129c74d36afe33405af4517524df8f74b635af8c2c8e91c1552b8397b2" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "http 0.2.12", + "once_cell", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sdk-ssooidc" +version = "1.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d03a3d5ef14851625eafd89660a751776f938bf32f309308b20dcca41c44b568" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "http 0.2.12", + "once_cell", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sdk-sts" +version = "1.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf3a9f073ae3a53b54421503063dfb87ff1ea83b876f567d92e8b8d9942ba91b" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-query", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-smithy-xml", + "aws-types", + "http 0.2.12", + "once_cell", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sigv4" +version = "1.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d3820e0c08d0737872ff3c7c1f21ebbb6693d832312d6152bf18ef50a5471c2" +dependencies = [ + "aws-credential-types", + "aws-smithy-http", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "form_urlencoded", + "hex", + "hmac", + "http 0.2.12", + "http 1.2.0", + "once_cell", + "percent-encoding", + "sha2", + "time", + "tracing", +] + +[[package]] +name = "aws-smithy-async" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "427cb637d15d63d6f9aae26358e1c9a9c09d5aa490d64b09354c8217cfef0f28" +dependencies = [ + "futures-util", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "aws-smithy-http" +version = "0.60.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c8bc3e8fdc6b8d07d976e301c02fe553f72a39b7a9fea820e023268467d7ab6" +dependencies = [ + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "bytes-utils", + "futures-core", + "http 0.2.12", + "http-body 0.4.6", + "once_cell", + "percent-encoding", + "pin-project-lite", + "pin-utils", + "tracing", +] + +[[package]] +name = "aws-smithy-json" +version = "0.61.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee4e69cc50921eb913c6b662f8d909131bb3e6ad6cb6090d3a39b66fc5c52095" +dependencies = [ + "aws-smithy-types", +] + +[[package]] +name = "aws-smithy-query" +version = "0.60.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2fbd61ceb3fe8a1cb7352e42689cec5335833cd9f94103a61e98f9bb61c64bb" +dependencies = [ + "aws-smithy-types", + "urlencoding", +] + +[[package]] +name = "aws-smithy-runtime" +version = "1.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a05dd41a70fc74051758ee75b5c4db2c0ca070ed9229c3df50e9475cda1cb985" +dependencies = [ + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "fastrand", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", + "http-body 1.0.1", + "httparse", + "hyper 0.14.32", + "hyper-rustls 0.24.2", + "once_cell", + "pin-project-lite", + "pin-utils", + "rustls 0.21.12", + "tokio", + "tracing", +] + +[[package]] +name = "aws-smithy-runtime-api" +version = "1.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92165296a47a812b267b4f41032ff8069ab7ff783696d217f0994a0d7ab585cd" +dependencies = [ + "aws-smithy-async", + "aws-smithy-types", + "bytes", + "http 0.2.12", + "http 1.2.0", + "pin-project-lite", + "tokio", + "tracing", + "zeroize", +] + +[[package]] +name = "aws-smithy-types" +version = "1.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38ddc9bd6c28aeb303477170ddd183760a956a03e083b3902a990238a7e3792d" +dependencies = [ + "base64-simd", + "bytes", + "bytes-utils", + "futures-core", + "http 0.2.12", + "http 1.2.0", + "http-body 0.4.6", + "http-body 1.0.1", + "http-body-util", + "itoa", + "num-integer", + "pin-project-lite", + "pin-utils", + "ryu", + "serde", + "time", + "tokio", + "tokio-util", +] + +[[package]] +name = "aws-smithy-xml" +version = "0.60.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab0b0166827aa700d3dc519f72f8b3a91c35d0b8d042dc5d643a91e6f80648fc" +dependencies = [ + "xmlparser", +] + +[[package]] +name = "aws-types" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5221b91b3e441e6675310829fd8984801b772cb1546ef6c0e54dec9f1ac13fef" +dependencies = [ + "aws-credential-types", + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "rustc_version", + "tracing", +] + [[package]] name = "backtrace" version = "0.3.74" @@ -344,6 +669,16 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64-simd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "339abbe78e73178762e23bea9dfd08e697eb3f3301cd4be981c0f78ba5859195" +dependencies = [ + "outref", + "vsimd", +] + [[package]] name = "base64ct" version = "1.6.0" @@ -451,6 +786,16 @@ version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9" +[[package]] +name = "bytes-utils" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35" +dependencies = [ + "bytes", + "either", +] + [[package]] name = "cached" version = "0.54.0" @@ -1316,6 +1661,25 @@ dependencies = [ "phf", ] +[[package]] +name = "h2" +version = "0.3.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 0.2.12", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "h2" version = "0.4.7" @@ -1392,6 +1756,12 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "hickory-proto" version = "0.24.3" @@ -1555,6 +1925,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", + "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", "httparse", @@ -1577,7 +1948,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "h2", + "h2 0.4.7", "http 1.2.0", "http-body 1.0.1", "httparse", @@ -1588,6 +1959,22 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" +dependencies = [ + "futures-util", + "http 0.2.12", + "hyper 0.14.32", + "log", + "rustls 0.21.12", + "rustls-native-certs", + "tokio", + "tokio-rustls 0.24.1", +] + [[package]] name = "hyper-rustls" version = "0.27.5" @@ -2374,6 +2761,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "outref" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4030760ffd992bef45b0ae3f10ce1aba99e33464c90d14dd7c039884963ddc7a" + [[package]] name = "overload" version = "0.1.1" @@ -2877,6 +3270,12 @@ dependencies = [ "regex-syntax 0.8.5", ] +[[package]] +name = "regex-lite" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" + [[package]] name = "regex-syntax" version = "0.6.29" @@ -2915,12 +3314,12 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2", + "h2 0.4.7", "http 1.2.0", "http-body 1.0.1", "http-body-util", "hyper 1.6.0", - "hyper-rustls", + "hyper-rustls 0.27.5", "hyper-tls", "hyper-util", "ipnet", @@ -3119,6 +3518,15 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "0.38.44" @@ -3157,6 +3565,18 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile 1.0.4", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" @@ -4041,6 +4461,12 @@ dependencies = [ "serde", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf-8" version = "0.7.6" @@ -4085,6 +4511,8 @@ name = "vaultwarden" version = "1.0.0" dependencies = [ "argon2", + "aws-config", + "aws-sdk-dsql", "bigdecimal", "bytes", "cached", @@ -4158,6 +4586,12 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "vsimd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" + [[package]] name = "walkdir" version = "2.5.0" @@ -4616,6 +5050,12 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" +[[package]] +name = "xmlparser" +version = "0.13.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" + [[package]] name = "yansi" version = "1.0.1" diff --git a/Cargo.toml b/Cargo.toml index 85bd1c16..72647720 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ build = "build.rs" enable_syslog = [] mysql = ["diesel/mysql", "diesel_migrations/mysql"] postgresql = ["diesel/postgres", "diesel_migrations/postgres"] +dsql = ["postgresql", "dep:aws-config", "dep:aws-sdk-dsql"] sqlite = ["diesel/sqlite", "diesel_migrations/sqlite", "dep:libsqlite3-sys"] # Enable to use a vendored and statically linked openssl vendored_openssl = ["openssl/vendored"] @@ -88,6 +89,10 @@ diesel-derive-newtype = "2.1.2" # Bundled/Static SQLite libsqlite3-sys = { version = "0.31.0", features = ["bundled"], optional = true } +# Amazon Aurora DSQL +aws-config = { version = "1.5.12", features = ["behavior-version-latest"], optional = true } +aws-sdk-dsql = { version = "1.2.0", features = ["behavior-version-latest"], optional = true } + # Crypto-related libraries rand = "0.9.0" ring = "0.17.8" diff --git a/build.rs b/build.rs index 07bd99a7..6ad6846d 100644 --- a/build.rs +++ b/build.rs @@ -9,10 +9,12 @@ fn main() { println!("cargo:rustc-cfg=mysql"); #[cfg(feature = "postgresql")] println!("cargo:rustc-cfg=postgresql"); + #[cfg(feature = "dsql")] + println!("cargo:rustc-cfg=dsql"); #[cfg(feature = "query_logger")] println!("cargo:rustc-cfg=query_logger"); - #[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgresql")))] + #[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgresql", feature = "dsql")))] compile_error!( "You need to enable one DB backend. To build with previous defaults do: cargo build --features sqlite" ); @@ -22,6 +24,7 @@ fn main() { println!("cargo::rustc-check-cfg=cfg(sqlite)"); println!("cargo::rustc-check-cfg=cfg(mysql)"); println!("cargo::rustc-check-cfg=cfg(postgresql)"); + println!("cargo::rustc-check-cfg=cfg(dsql)"); println!("cargo::rustc-check-cfg=cfg(query_logger)"); // Rerun when these paths are changed. diff --git a/migrations/dsql/2024-12-30-100000_create_tables/metadata.toml b/migrations/dsql/2024-12-30-100000_create_tables/metadata.toml new file mode 100644 index 00000000..16153bc0 --- /dev/null +++ b/migrations/dsql/2024-12-30-100000_create_tables/metadata.toml @@ -0,0 +1 @@ +run_in_transaction = false \ No newline at end of file diff --git a/migrations/dsql/2024-12-30-100000_create_tables/up.sql b/migrations/dsql/2024-12-30-100000_create_tables/up.sql new file mode 100644 index 00000000..3eefed1e --- /dev/null +++ b/migrations/dsql/2024-12-30-100000_create_tables/up.sql @@ -0,0 +1,281 @@ +CREATE TABLE attachments ( + id text NOT NULL PRIMARY KEY, + cipher_uuid character varying(40) NOT NULL, + file_name text NOT NULL, + file_size bigint NOT NULL, + akey text +); + +CREATE TABLE auth_requests ( + uuid character(36) NOT NULL PRIMARY KEY, + user_uuid character(36) NOT NULL, + organization_uuid character(36), + request_device_identifier character(36) NOT NULL, + device_type integer NOT NULL, + request_ip text NOT NULL, + response_device_id character(36), + access_code text NOT NULL, + public_key text NOT NULL, + enc_key text, + master_password_hash text, + approved boolean, + creation_date timestamp without time zone NOT NULL, + response_date timestamp without time zone, + authentication_date timestamp without time zone +); + +CREATE TABLE ciphers ( + uuid character varying(40) NOT NULL PRIMARY KEY, + created_at timestamp without time zone NOT NULL, + updated_at timestamp without time zone NOT NULL, + user_uuid character varying(40), + organization_uuid character varying(40), + atype integer NOT NULL, + name text NOT NULL, + notes text, + fields text, + data text NOT NULL, + password_history text, + deleted_at timestamp without time zone, + reprompt integer, + key text +); + +CREATE TABLE ciphers_collections ( + cipher_uuid character varying(40) NOT NULL, + collection_uuid character varying(40) NOT NULL, + PRIMARY KEY (cipher_uuid, collection_uuid) +); + +CREATE TABLE collections ( + uuid character varying(40) NOT NULL PRIMARY KEY, + org_uuid character varying(40) NOT NULL, + name text NOT NULL, + external_id text +); + +CREATE TABLE collections_groups ( + collections_uuid character varying(40) NOT NULL, + groups_uuid character(36) NOT NULL, + read_only boolean NOT NULL, + hide_passwords boolean NOT NULL, + PRIMARY KEY (collections_uuid, groups_uuid) +); + +CREATE TABLE devices ( + uuid character varying(40) NOT NULL, + created_at timestamp without time zone NOT NULL, + updated_at timestamp without time zone NOT NULL, + user_uuid character varying(40) NOT NULL, + name text NOT NULL, + atype integer NOT NULL, + push_token text, + refresh_token text NOT NULL, + twofactor_remember text, + push_uuid text, + PRIMARY KEY (uuid, user_uuid) +); + +CREATE TABLE emergency_access ( + uuid character(36) NOT NULL PRIMARY KEY, + grantor_uuid character(36), + grantee_uuid character(36), + email character varying(255), + key_encrypted text, + atype integer NOT NULL, + status integer NOT NULL, + wait_time_days integer NOT NULL, + recovery_initiated_at timestamp without time zone, + last_notification_at timestamp without time zone, + updated_at timestamp without time zone NOT NULL, + created_at timestamp without time zone NOT NULL +); + +CREATE TABLE event ( + uuid character(36) NOT NULL PRIMARY KEY, + event_type integer NOT NULL, + user_uuid character(36), + org_uuid character(36), + cipher_uuid character(36), + collection_uuid character(36), + group_uuid character(36), + org_user_uuid character(36), + act_user_uuid character(36), + device_type integer, + ip_address text, + event_date timestamp without time zone NOT NULL, + policy_uuid character(36), + provider_uuid character(36), + provider_user_uuid character(36), + provider_org_uuid character(36) +); + +CREATE TABLE favorites ( + user_uuid character varying(40) NOT NULL, + cipher_uuid character varying(40) NOT NULL, + PRIMARY KEY (user_uuid, cipher_uuid) +); + +CREATE TABLE folders ( + uuid character varying(40) NOT NULL PRIMARY KEY, + created_at timestamp without time zone NOT NULL, + updated_at timestamp without time zone NOT NULL, + user_uuid character varying(40) NOT NULL, + name text NOT NULL +); + +CREATE TABLE folders_ciphers ( + cipher_uuid character varying(40) NOT NULL, + folder_uuid character varying(40) NOT NULL, + PRIMARY KEY (cipher_uuid, folder_uuid) +); + +CREATE TABLE groups ( + uuid character(36) NOT NULL PRIMARY KEY, + organizations_uuid character varying(40) NOT NULL, + name character varying(100) NOT NULL, + access_all boolean NOT NULL, + external_id character varying(300), + creation_date timestamp without time zone NOT NULL, + revision_date timestamp without time zone NOT NULL +); + +CREATE TABLE groups_users ( + groups_uuid character(36) NOT NULL, + users_organizations_uuid character varying(36) NOT NULL, + PRIMARY KEY (groups_uuid, users_organizations_uuid) +); + +CREATE TABLE invitations ( + email text NOT NULL PRIMARY KEY +); + +CREATE TABLE org_policies ( + uuid character(36) NOT NULL PRIMARY KEY, + org_uuid character(36) NOT NULL, + atype integer NOT NULL, + enabled boolean NOT NULL, + data text NOT NULL, + UNIQUE (org_uuid, atype) +); + +CREATE TABLE organization_api_key ( + uuid character(36) NOT NULL, + org_uuid character(36) NOT NULL, + atype integer NOT NULL, + api_key character varying(255), + revision_date timestamp without time zone NOT NULL, + PRIMARY KEY (uuid, org_uuid) +); + +CREATE TABLE organizations ( + uuid character varying(40) NOT NULL PRIMARY KEY, + name text NOT NULL, + billing_email text NOT NULL, + private_key text, + public_key text +); + +CREATE TABLE sends ( + uuid character(36) NOT NULL PRIMARY KEY, + user_uuid character(36), + organization_uuid character(36), + name text NOT NULL, + notes text, + atype integer NOT NULL, + data text NOT NULL, + akey text NOT NULL, + password_hash bytea, + password_salt bytea, + password_iter integer, + max_access_count integer, + access_count integer NOT NULL, + creation_date timestamp without time zone NOT NULL, + revision_date timestamp without time zone NOT NULL, + expiration_date timestamp without time zone, + deletion_date timestamp without time zone NOT NULL, + disabled boolean NOT NULL, + hide_email boolean +); + +CREATE TABLE twofactor ( + uuid character varying(40) NOT NULL PRIMARY KEY, + user_uuid character varying(40) NOT NULL, + atype integer NOT NULL, + enabled boolean NOT NULL, + data text NOT NULL, + last_used bigint DEFAULT 0 NOT NULL, + UNIQUE (user_uuid, atype) +); + +CREATE TABLE twofactor_duo_ctx ( + state character varying(64) NOT NULL PRIMARY KEY, + user_email character varying(255) NOT NULL, + nonce character varying(64) NOT NULL, + exp bigint NOT NULL +); + +CREATE TABLE twofactor_incomplete ( + user_uuid character varying(40) NOT NULL, + device_uuid character varying(40) NOT NULL, + device_name text NOT NULL, + login_time timestamp without time zone NOT NULL, + ip_address text NOT NULL, + device_type integer DEFAULT 14 NOT NULL, + PRIMARY KEY (user_uuid, device_uuid) +); + +CREATE TABLE users ( + uuid character varying(40) NOT NULL PRIMARY KEY, + created_at timestamp without time zone NOT NULL, + updated_at timestamp without time zone NOT NULL, + email text NOT NULL UNIQUE, + name text NOT NULL, + password_hash bytea NOT NULL, + salt bytea NOT NULL, + password_iterations integer NOT NULL, + password_hint text, + akey text NOT NULL, + private_key text, + public_key text, + totp_secret text, + totp_recover text, + security_stamp text NOT NULL, + equivalent_domains text NOT NULL, + excluded_globals text NOT NULL, + client_kdf_type integer DEFAULT 0 NOT NULL, + client_kdf_iter integer DEFAULT 100000 NOT NULL, + verified_at timestamp without time zone, + last_verifying_at timestamp without time zone, + login_verify_count integer DEFAULT 0 NOT NULL, + email_new character varying(255) DEFAULT NULL::character varying, + email_new_token character varying(16) DEFAULT NULL::character varying, + enabled boolean DEFAULT true NOT NULL, + stamp_exception text, + api_key text, + avatar_color text, + client_kdf_memory integer, + client_kdf_parallelism integer, + external_id text +); + +CREATE TABLE users_collections ( + user_uuid character varying(40) NOT NULL, + collection_uuid character varying(40) NOT NULL, + read_only boolean DEFAULT false NOT NULL, + hide_passwords boolean DEFAULT false NOT NULL, + PRIMARY KEY (user_uuid, collection_uuid) +); + +CREATE TABLE users_organizations ( + uuid character varying(40) NOT NULL PRIMARY KEY, + user_uuid character varying(40) NOT NULL, + org_uuid character varying(40) NOT NULL, + access_all boolean NOT NULL, + akey text NOT NULL, + status integer NOT NULL, + atype integer NOT NULL, + reset_password_key text, + external_id text, + UNIQUE (user_uuid, org_uuid) +); \ No newline at end of file diff --git a/src/db/dsql.rs b/src/db/dsql.rs new file mode 100644 index 00000000..803c376b --- /dev/null +++ b/src/db/dsql.rs @@ -0,0 +1,184 @@ +use std::sync::RwLock; + +use diesel::{ + r2d2::{ManageConnection, R2D2Connection}, + ConnectionError, + ConnectionResult, +}; +use url::Url; + +#[derive(Debug)] +pub struct ConnectionManager { + inner: RwLock>, + #[cfg(dsql)] + dsql_url: Option, +} + +impl ConnectionManager { + /// Returns a new connection manager, + /// which establishes connections to the given database URL. + pub fn new>(database_url: S) -> Self { + let database_url = database_url.into(); + + Self { + inner: RwLock::new(diesel::r2d2::ConnectionManager::new(&database_url)), + #[cfg(dsql)] + dsql_url: if database_url.starts_with("dsql:") { + Some(database_url) + } else { + None + }, + } + } +} + +impl ManageConnection for ConnectionManager +where + T: R2D2Connection + Send + 'static, +{ + type Connection = T; + type Error = diesel::r2d2::Error; + + fn connect(&self) -> Result { + #[cfg(dsql)] + if let Some(dsql_url) = &self.dsql_url { + let url = psql_url(dsql_url).map_err(|e| Self::Error::ConnectionError(e))?; + self.inner.write().expect("Failed to lock inner connection manager to set DSQL connection URL").update_database_url(&url); + } + + self.inner.read().expect("Failed to lock inner connection manager to connect").connect() + } + + fn is_valid(&self, conn: &mut T) -> Result<(), Self::Error> { + self.inner.read().expect("Failed to lock inner connection manager to check validity").is_valid(conn) + } + + fn has_broken(&self, conn: &mut T) -> bool { + self.inner.read().expect("Failed to lock inner connection manager to check if has broken").has_broken(conn) + } +} + +// Cache the AWS SDK config, as recommended by the AWS SDK documentation. The +// initial load is async, so we spawn a thread to load it and then join it to +// get the result in a blocking fashion. +static AWS_SDK_CONFIG: std::sync::LazyLock> = std::sync::LazyLock::new(|| { + std::thread::spawn(|| { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + + std::io::Result::Ok(rt.block_on(aws_config::load_defaults(aws_config::BehaviorVersion::latest()))) + }) + .join() + .map_err(|e| ConnectionError::BadConnection(format!("Failed to load AWS config for DSQL connection: {e:#?}")))? + .map_err(|e| ConnectionError::BadConnection(format!("Failed to load AWS config for DSQL connection: {e}"))) +}); + +// Generate a Postgres libpq connection string. The input connection string has +// the following format: +// +// dsql://.dsql..on.aws +// +// The generated connection string will have the form: +// +// postgresql://.dsql..on.aws/postgres?sslmode=require&user=admin&password= +// +// The auth token is a temporary token generated by the AWS SDK for DSQL. It is +// valid for up to 15 minutes. We cache the last-generated token for each unique +// DSQL connection URL, and reuse it if it is less than 14 minutes old. +pub(crate) fn psql_url(url: &str) -> Result { + use std::{ + collections::HashMap, + sync::{Arc, LazyLock, Mutex}, + time::Duration, + }; + + struct PsqlUrl { + timestamp: std::time::Instant, + url: String, + } + + static PSQL_URLS: LazyLock>>>>> = LazyLock::new(|| Mutex::new(HashMap::new())); + + let mut psql_urls = PSQL_URLS.lock().map_err(|e| ConnectionError::BadConnection(format!("Failed to lock PSQL URLs: {e}")))?; + + let psql_url_lock = if let Some(existing_psql_url_lock) = psql_urls.get(url) { + existing_psql_url_lock.clone() + } else { + let psql_url_lock = Arc::new(Mutex::new(None)); + psql_urls.insert(url.to_string(), psql_url_lock.clone()); + psql_url_lock + }; + + let mut psql_url_lock_guard = psql_url_lock.lock().map_err(|e| ConnectionError::BadConnection(format!("Failed to lock PSQL url: {e}")))?; + + drop(psql_urls); + + if let Some(ref psql_url) = *psql_url_lock_guard { + if psql_url.timestamp.elapsed() < Duration::from_secs(14 * 60) { + debug!("Reusing DSQL auth token for connection '{url}'"); + return Ok(psql_url.url.clone()); + } + + info!("Refreshing DSQL auth token for connection '{url}'"); + } else { + info!("Generating new DSQL auth token for connection '{url}'"); + } + + // This would be so much easier if ConnectionError implemented Clone. + let sdk_config = match *AWS_SDK_CONFIG { + Ok(ref sdk_config) => sdk_config.clone(), + Err(ConnectionError::BadConnection(ref e)) => return Err(ConnectionError::BadConnection(e.to_owned())), + Err(ref e) => unreachable!("Unexpected error loading AWS SDK config: {e}"), + }; + + let mut psql_url = Url::parse(url).map_err(|e| { + ConnectionError::InvalidConnectionUrl(e.to_string()) + })?; + + let host = psql_url.host_str().ok_or(ConnectionError::InvalidConnectionUrl("Missing hostname in connection URL".to_string()))?.to_string(); + + static DSQL_REGION_FROM_HOST_RE: LazyLock = LazyLock::new(|| { + regex::Regex::new(r"^[a-z0-9]+\.dsql\.(?P[a-z0-9-]+)\.on\.aws$").expect("Failed to compile DSQL region regex") + }); + + let region = (*DSQL_REGION_FROM_HOST_RE).captures(&host).ok_or(ConnectionError::InvalidConnectionUrl("Failed to find AWS region in DSQL hostname".to_string()))? + .name("region") + .ok_or(ConnectionError::InvalidConnectionUrl("Failed to find AWS region in DSQL hostname".to_string()))? + .as_str() + .to_string(); + + let region = aws_config::Region::new(region); + + let auth_config = aws_sdk_dsql::auth_token::Config::builder() + .hostname(host) + .region(region) + .build() + .map_err(|e| ConnectionError::BadConnection(format!("Failed to build AWS auth token signer config: {e}")))?; + + let signer = aws_sdk_dsql::auth_token::AuthTokenGenerator::new(auth_config); + + let now = std::time::Instant::now(); + + let auth_token = std::thread::spawn(move || { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + + rt.block_on(signer.db_connect_admin_auth_token(&sdk_config)) + }) + .join() + .map_err(|e| ConnectionError::BadConnection(format!("Failed to generate DSQL auth token: {e:#?}")))? + .map_err(|e| ConnectionError::BadConnection(format!("Failed to generate DSQL auth token: {e}")))?; + + psql_url.set_scheme("postgresql").expect("Failed to set 'postgresql' as scheme for DSQL connection URL"); + psql_url.set_path("postgres"); + psql_url.query_pairs_mut() + .append_pair("sslmode", "require") + .append_pair("user", "admin") + .append_pair("password", auth_token.as_str()); + + psql_url_lock_guard.replace(PsqlUrl { timestamp: now, url: psql_url.to_string() }); + + Ok(psql_url.to_string()) +} \ No newline at end of file diff --git a/src/db/mod.rs b/src/db/mod.rs index 464be561..c3fc755b 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -1,8 +1,11 @@ +#[cfg(dsql)] +mod dsql; + use std::{sync::Arc, time::Duration}; use diesel::{ connection::SimpleConnection, - r2d2::{ConnectionManager, CustomizeConnection, Pool, PooledConnection}, + r2d2::{CustomizeConnection, Pool, PooledConnection}, }; use rocket::{ @@ -21,6 +24,11 @@ use crate::{ CONFIG, }; +#[cfg(dsql)] +type ConnectionManager = dsql::ConnectionManager; +#[cfg(not(dsql))] +type ConnectionManager = diesel::r2d2::ConnectionManager; + #[cfg(sqlite)] #[path = "schemas/sqlite/schema.rs"] pub mod __sqlite_schema; @@ -130,7 +138,7 @@ macro_rules! generate_connections { DbConnType::$name => { #[cfg($name)] { - paste::paste!{ [< $name _migrations >]::run_migrations()?; } + paste::paste!{ [< $name _migrations >]::run_migrations(&url)?; } let manager = ConnectionManager::new(&url); let pool = Pool::builder() .max_size(CONFIG.database_max_conns()) @@ -209,6 +217,14 @@ impl DbConnType { #[cfg(not(postgresql))] err!("`DATABASE_URL` is a PostgreSQL URL, but the 'postgresql' feature is not enabled") + // Amazon Aurora DSQL + } else if url.starts_with("dsql:") { + #[cfg(dsql)] + return Ok(DbConnType::postgresql); + + #[cfg(not(dsql))] + err!("`DATABASE_URL` is a DSQL URL, but the 'dsql' feature is not enabled") + //Sqlite } else { #[cfg(sqlite)] @@ -429,13 +445,12 @@ mod sqlite_migrations { use diesel_migrations::{EmbeddedMigrations, MigrationHarness}; pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/sqlite"); - pub fn run_migrations() -> Result<(), super::Error> { + pub fn run_migrations(url: &str) -> Result<(), super::Error> { use diesel::{Connection, RunQueryDsl}; - let url = crate::CONFIG.database_url(); // Establish a connection to the sqlite database (this will create a new one, if it does // not exist, and exit if there is an error). - let mut connection = diesel::sqlite::SqliteConnection::establish(&url)?; + let mut connection = diesel::sqlite::SqliteConnection::establish(url)?; // Run the migrations after successfully establishing a connection // Disable Foreign Key Checks during migration @@ -459,10 +474,10 @@ mod mysql_migrations { use diesel_migrations::{EmbeddedMigrations, MigrationHarness}; pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/mysql"); - pub fn run_migrations() -> Result<(), super::Error> { + pub fn run_migrations(url: &str) -> Result<(), super::Error> { use diesel::{Connection, RunQueryDsl}; // Make sure the database is up to date (create if it doesn't exist, or run the migrations) - let mut connection = diesel::mysql::MysqlConnection::establish(&crate::CONFIG.database_url())?; + let mut connection = diesel::mysql::MysqlConnection::establish(url)?; // Disable Foreign Key Checks during migration // Scoped to a connection/session. @@ -480,10 +495,21 @@ mod postgresql_migrations { use diesel_migrations::{EmbeddedMigrations, MigrationHarness}; pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgresql"); - pub fn run_migrations() -> Result<(), super::Error> { + pub fn run_migrations(url: &str) -> Result<(), super::Error> { use diesel::Connection; + + #[cfg(dsql)] + if url.starts_with("dsql:") { + pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/dsql"); + + let psql_url = crate::db::dsql::psql_url(url)?; + let mut connection = diesel::pg::PgConnection::establish(&psql_url)?; + connection.run_pending_migrations(MIGRATIONS).expect("Error running migrations"); + return Ok(()) + } + // Make sure the database is up to date (create if it doesn't exist, or run the migrations) - let mut connection = diesel::pg::PgConnection::establish(&crate::CONFIG.database_url())?; + let mut connection = diesel::pg::PgConnection::establish(url)?; connection.run_pending_migrations(MIGRATIONS).expect("Error running migrations"); Ok(()) }