mirror of
https://github.com/dani-garcia/vaultwarden.git
synced 2025-06-15 08:17:00 +00:00
Update to rocket 0.5 and made code async, missing updating all db calls, that are currently blocking
This commit is contained in:
874
Cargo.lock
generated
874
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
27
Cargo.toml
27
Cargo.toml
@ -3,7 +3,7 @@ name = "vaultwarden"
|
||||
version = "1.0.0"
|
||||
authors = ["Daniel García <dani-garcia@users.noreply.github.com>"]
|
||||
edition = "2021"
|
||||
rust-version = "1.60"
|
||||
rust-version = "1.56"
|
||||
resolver = "2"
|
||||
|
||||
repository = "https://github.com/dani-garcia/vaultwarden"
|
||||
@ -13,6 +13,7 @@ publish = false
|
||||
build = "build.rs"
|
||||
|
||||
[features]
|
||||
# default = ["sqlite"]
|
||||
# Empty to keep compatibility, prefer to set USE_SYSLOG=true
|
||||
enable_syslog = []
|
||||
mysql = ["diesel/mysql", "diesel_migrations/mysql"]
|
||||
@ -29,22 +30,22 @@ unstable = []
|
||||
syslog = "4.0.1"
|
||||
|
||||
[dependencies]
|
||||
# Web framework for nightly with a focus on ease-of-use, expressibility, and speed.
|
||||
rocket = { version = "=0.5.0-dev", features = ["tls"], default-features = false }
|
||||
rocket_contrib = "=0.5.0-dev"
|
||||
# Web framework
|
||||
rocket = { version = "0.5.0-rc.1", features = ["tls", "json"], default-features = false }
|
||||
|
||||
# HTTP client
|
||||
reqwest = { version = "0.11.9", features = ["blocking", "json", "gzip", "brotli", "socks", "cookies", "trust-dns"] }
|
||||
# Async futures
|
||||
futures = "0.3.19"
|
||||
tokio = { version = "1.16.1", features = ["rt-multi-thread", "fs", "io-util", "parking_lot"] }
|
||||
|
||||
# HTTP client
|
||||
reqwest = { version = "0.11.9", features = ["stream", "json", "gzip", "brotli", "socks", "cookies", "trust-dns"] }
|
||||
bytes = "1.1.0"
|
||||
|
||||
# Used for custom short lived cookie jar
|
||||
cookie = "0.15.1"
|
||||
cookie_store = "0.15.1"
|
||||
bytes = "1.1.0"
|
||||
url = "2.2.2"
|
||||
|
||||
# multipart/form-data support
|
||||
multipart = { version = "0.18.0", features = ["server"], default-features = false }
|
||||
|
||||
# WebSockets library
|
||||
ws = { version = "0.11.1", package = "parity-ws" }
|
||||
|
||||
@ -141,10 +142,10 @@ backtrace = "0.3.64"
|
||||
paste = "1.0.6"
|
||||
governor = "0.4.1"
|
||||
|
||||
ctrlc = { version = "3.2.1", features = ["termination"] }
|
||||
|
||||
[patch.crates-io]
|
||||
# Use newest ring
|
||||
rocket = { git = 'https://github.com/SergioBenitez/Rocket', rev = '263e39b5b429de1913ce7e3036575a7b4d88b6d7' }
|
||||
rocket_contrib = { git = 'https://github.com/SergioBenitez/Rocket', rev = '263e39b5b429de1913ce7e3036575a7b4d88b6d7' }
|
||||
rocket = { git = 'https://github.com/SergioBenitez/Rocket', rev = '8cae077ba1d54b92cdef3e171a730b819d5eeb8e' }
|
||||
|
||||
# The maintainer of the `job_scheduler` crate doesn't seem to have responded
|
||||
# to any issues or PRs for almost a year (as of April 2021). This hopefully
|
||||
|
@ -1,2 +0,0 @@
|
||||
[global.limits]
|
||||
json = 10485760 # 10 MiB
|
@ -1 +1 @@
|
||||
nightly-2022-01-23
|
||||
stable
|
||||
|
@ -3,13 +3,14 @@ use serde::de::DeserializeOwned;
|
||||
use serde_json::Value;
|
||||
use std::env;
|
||||
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::{
|
||||
http::{Cookie, Cookies, SameSite, Status},
|
||||
request::{self, FlashMessage, Form, FromRequest, Outcome, Request},
|
||||
response::{content::Html, Flash, Redirect},
|
||||
form::Form,
|
||||
http::{Cookie, CookieJar, SameSite, Status},
|
||||
request::{self, FlashMessage, FromRequest, Outcome, Request},
|
||||
response::{content::RawHtml as Html, Flash, Redirect},
|
||||
Route,
|
||||
};
|
||||
use rocket_contrib::json::Json;
|
||||
|
||||
use crate::{
|
||||
api::{ApiResult, EmptyResult, JsonResult, NumberOrString},
|
||||
@ -85,10 +86,11 @@ fn admin_path() -> String {
|
||||
|
||||
struct Referer(Option<String>);
|
||||
|
||||
impl<'a, 'r> FromRequest<'a, 'r> for Referer {
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for Referer {
|
||||
type Error = ();
|
||||
|
||||
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
|
||||
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
|
||||
Outcome::Success(Referer(request.headers().get_one("Referer").map(str::to_string)))
|
||||
}
|
||||
}
|
||||
@ -96,10 +98,11 @@ impl<'a, 'r> FromRequest<'a, 'r> for Referer {
|
||||
#[derive(Debug)]
|
||||
struct IpHeader(Option<String>);
|
||||
|
||||
impl<'a, 'r> FromRequest<'a, 'r> for IpHeader {
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for IpHeader {
|
||||
type Error = ();
|
||||
|
||||
fn from_request(req: &'a Request<'r>) -> Outcome<Self, Self::Error> {
|
||||
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||
if req.headers().get_one(&CONFIG.ip_header()).is_some() {
|
||||
Outcome::Success(IpHeader(Some(CONFIG.ip_header())))
|
||||
} else if req.headers().get_one("X-Client-IP").is_some() {
|
||||
@ -138,7 +141,7 @@ fn admin_url(referer: Referer) -> String {
|
||||
#[get("/", rank = 2)]
|
||||
fn admin_login(flash: Option<FlashMessage>) -> ApiResult<Html<String>> {
|
||||
// If there is an error, show it
|
||||
let msg = flash.map(|msg| format!("{}: {}", msg.name(), msg.msg()));
|
||||
let msg = flash.map(|msg| format!("{}: {}", msg.kind(), msg.message()));
|
||||
let json = json!({
|
||||
"page_content": "admin/login",
|
||||
"version": VERSION,
|
||||
@ -159,7 +162,7 @@ struct LoginForm {
|
||||
#[post("/", data = "<data>")]
|
||||
fn post_admin_login(
|
||||
data: Form<LoginForm>,
|
||||
mut cookies: Cookies,
|
||||
cookies: &CookieJar,
|
||||
ip: ClientIp,
|
||||
referer: Referer,
|
||||
) -> Result<Redirect, Flash<Redirect>> {
|
||||
@ -180,7 +183,7 @@ fn post_admin_login(
|
||||
|
||||
let cookie = Cookie::build(COOKIE_NAME, jwt)
|
||||
.path(admin_path())
|
||||
.max_age(time::Duration::minutes(20))
|
||||
.max_age(rocket::time::Duration::minutes(20))
|
||||
.same_site(SameSite::Strict)
|
||||
.http_only(true)
|
||||
.finish();
|
||||
@ -297,7 +300,7 @@ fn test_smtp(data: Json<InviteData>, _token: AdminToken) -> EmptyResult {
|
||||
}
|
||||
|
||||
#[get("/logout")]
|
||||
fn logout(mut cookies: Cookies, referer: Referer) -> Redirect {
|
||||
fn logout(cookies: &CookieJar, referer: Referer) -> Redirect {
|
||||
cookies.remove(Cookie::named(COOKIE_NAME));
|
||||
Redirect::to(admin_url(referer))
|
||||
}
|
||||
@ -462,23 +465,23 @@ struct GitCommit {
|
||||
sha: String,
|
||||
}
|
||||
|
||||
fn get_github_api<T: DeserializeOwned>(url: &str) -> Result<T, Error> {
|
||||
async fn get_github_api<T: DeserializeOwned>(url: &str) -> Result<T, Error> {
|
||||
let github_api = get_reqwest_client();
|
||||
|
||||
Ok(github_api.get(url).send()?.error_for_status()?.json::<T>()?)
|
||||
Ok(github_api.get(url).send().await?.error_for_status()?.json::<T>().await?)
|
||||
}
|
||||
|
||||
fn has_http_access() -> bool {
|
||||
async fn has_http_access() -> bool {
|
||||
let http_access = get_reqwest_client();
|
||||
|
||||
match http_access.head("https://github.com/dani-garcia/vaultwarden").send() {
|
||||
match http_access.head("https://github.com/dani-garcia/vaultwarden").send().await {
|
||||
Ok(r) => r.status().is_success(),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/diagnostics")]
|
||||
fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResult<Html<String>> {
|
||||
async fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResult<Html<String>> {
|
||||
use crate::util::read_file_string;
|
||||
use chrono::prelude::*;
|
||||
use std::net::ToSocketAddrs;
|
||||
@ -497,7 +500,7 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
|
||||
|
||||
// Execute some environment checks
|
||||
let running_within_docker = is_running_in_docker();
|
||||
let has_http_access = has_http_access();
|
||||
let has_http_access = has_http_access().await;
|
||||
let uses_proxy = env::var_os("HTTP_PROXY").is_some()
|
||||
|| env::var_os("http_proxy").is_some()
|
||||
|| env::var_os("HTTPS_PROXY").is_some()
|
||||
@ -513,11 +516,14 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
|
||||
// TODO: Maybe we need to cache this using a LazyStatic or something. Github only allows 60 requests per hour, and we use 3 here already.
|
||||
let (latest_release, latest_commit, latest_web_build) = if has_http_access {
|
||||
(
|
||||
match get_github_api::<GitRelease>("https://api.github.com/repos/dani-garcia/vaultwarden/releases/latest") {
|
||||
match get_github_api::<GitRelease>("https://api.github.com/repos/dani-garcia/vaultwarden/releases/latest")
|
||||
.await
|
||||
{
|
||||
Ok(r) => r.tag_name,
|
||||
_ => "-".to_string(),
|
||||
},
|
||||
match get_github_api::<GitCommit>("https://api.github.com/repos/dani-garcia/vaultwarden/commits/main") {
|
||||
match get_github_api::<GitCommit>("https://api.github.com/repos/dani-garcia/vaultwarden/commits/main").await
|
||||
{
|
||||
Ok(mut c) => {
|
||||
c.sha.truncate(8);
|
||||
c.sha
|
||||
@ -531,7 +537,9 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
|
||||
} else {
|
||||
match get_github_api::<GitRelease>(
|
||||
"https://api.github.com/repos/dani-garcia/bw_web_builds/releases/latest",
|
||||
) {
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r.tag_name.trim_start_matches('v').to_string(),
|
||||
_ => "-".to_string(),
|
||||
}
|
||||
@ -562,7 +570,7 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
|
||||
"ip_header_config": &CONFIG.ip_header(),
|
||||
"uses_proxy": uses_proxy,
|
||||
"db_type": *DB_TYPE,
|
||||
"db_version": get_sql_server_version(&conn),
|
||||
"db_version": get_sql_server_version(&conn).await,
|
||||
"admin_url": format!("{}/diagnostics", admin_url(Referer(None))),
|
||||
"overrides": &CONFIG.get_overrides().join(", "),
|
||||
"server_time_local": Local::now().format("%Y-%m-%d %H:%M:%S %Z").to_string(),
|
||||
@ -591,9 +599,9 @@ fn delete_config(_token: AdminToken) -> EmptyResult {
|
||||
}
|
||||
|
||||
#[post("/config/backup_db")]
|
||||
fn backup_db(_token: AdminToken, conn: DbConn) -> EmptyResult {
|
||||
async fn backup_db(_token: AdminToken, conn: DbConn) -> EmptyResult {
|
||||
if *CAN_BACKUP {
|
||||
backup_database(&conn)
|
||||
backup_database(&conn).await
|
||||
} else {
|
||||
err!("Can't back up current DB (Only SQLite supports this feature)");
|
||||
}
|
||||
@ -601,21 +609,22 @@ fn backup_db(_token: AdminToken, conn: DbConn) -> EmptyResult {
|
||||
|
||||
pub struct AdminToken {}
|
||||
|
||||
impl<'a, 'r> FromRequest<'a, 'r> for AdminToken {
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for AdminToken {
|
||||
type Error = &'static str;
|
||||
|
||||
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
|
||||
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
|
||||
if CONFIG.disable_admin_token() {
|
||||
Outcome::Success(AdminToken {})
|
||||
} else {
|
||||
let mut cookies = request.cookies();
|
||||
let cookies = request.cookies();
|
||||
|
||||
let access_token = match cookies.get(COOKIE_NAME) {
|
||||
Some(cookie) => cookie.value(),
|
||||
None => return Outcome::Forward(()), // If there is no cookie, redirect to login
|
||||
};
|
||||
|
||||
let ip = match request.guard::<ClientIp>() {
|
||||
let ip = match ClientIp::from_request(request).await {
|
||||
Outcome::Success(ip) => ip.ip,
|
||||
_ => err_handler!("Error getting Client IP"),
|
||||
};
|
||||
|
@ -1,5 +1,5 @@
|
||||
use chrono::Utc;
|
||||
use rocket_contrib::json::Json;
|
||||
use rocket::serde::json::Json;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::{
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,6 @@
|
||||
use chrono::{Duration, Utc};
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::Route;
|
||||
use rocket_contrib::json::Json;
|
||||
use serde_json::Value;
|
||||
use std::borrow::Borrow;
|
||||
|
||||
@ -709,13 +709,13 @@ fn check_emergency_access_allowed() -> EmptyResult {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn emergency_request_timeout_job(pool: DbPool) {
|
||||
pub async fn emergency_request_timeout_job(pool: DbPool) {
|
||||
debug!("Start emergency_request_timeout_job");
|
||||
if !CONFIG.emergency_access_allowed() {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Ok(conn) = pool.get() {
|
||||
if let Ok(conn) = pool.get().await {
|
||||
let emergency_access_list = EmergencyAccess::find_all_recoveries(&conn);
|
||||
|
||||
if emergency_access_list.is_empty() {
|
||||
@ -756,13 +756,13 @@ pub fn emergency_request_timeout_job(pool: DbPool) {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn emergency_notification_reminder_job(pool: DbPool) {
|
||||
pub async fn emergency_notification_reminder_job(pool: DbPool) {
|
||||
debug!("Start emergency_notification_reminder_job");
|
||||
if !CONFIG.emergency_access_allowed() {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Ok(conn) = pool.get() {
|
||||
if let Ok(conn) = pool.get().await {
|
||||
let emergency_access_list = EmergencyAccess::find_all_recoveries(&conn);
|
||||
|
||||
if emergency_access_list.is_empty() {
|
||||
|
@ -1,4 +1,4 @@
|
||||
use rocket_contrib::json::Json;
|
||||
use rocket::serde::json::Json;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::{
|
||||
|
@ -31,8 +31,8 @@ pub fn routes() -> Vec<Route> {
|
||||
//
|
||||
// Move this somewhere else
|
||||
//
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::Route;
|
||||
use rocket_contrib::json::Json;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::{
|
||||
@ -144,7 +144,7 @@ fn put_eq_domains(data: JsonUpcase<EquivDomainData>, headers: Headers, conn: DbC
|
||||
}
|
||||
|
||||
#[get("/hibp/breach?<username>")]
|
||||
fn hibp_breach(username: String) -> JsonResult {
|
||||
async fn hibp_breach(username: String) -> JsonResult {
|
||||
let url = format!(
|
||||
"https://haveibeenpwned.com/api/v3/breachedaccount/{}?truncateResponse=false&includeUnverified=false",
|
||||
username
|
||||
@ -153,14 +153,14 @@ fn hibp_breach(username: String) -> JsonResult {
|
||||
if let Some(api_key) = crate::CONFIG.hibp_api_key() {
|
||||
let hibp_client = get_reqwest_client();
|
||||
|
||||
let res = hibp_client.get(&url).header("hibp-api-key", api_key).send()?;
|
||||
let res = hibp_client.get(&url).header("hibp-api-key", api_key).send().await?;
|
||||
|
||||
// If we get a 404, return a 404, it means no breached accounts
|
||||
if res.status() == 404 {
|
||||
return Err(Error::empty().with_code(404));
|
||||
}
|
||||
|
||||
let value: Value = res.error_for_status()?.json()?;
|
||||
let value: Value = res.error_for_status()?.json().await?;
|
||||
Ok(Json(value))
|
||||
} else {
|
||||
Ok(Json(json!([{
|
||||
|
@ -1,6 +1,6 @@
|
||||
use num_traits::FromPrimitive;
|
||||
use rocket::{request::Form, Route};
|
||||
use rocket_contrib::json::Json;
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::Route;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::{
|
||||
@ -469,12 +469,12 @@ fn put_collection_users(
|
||||
|
||||
#[derive(FromForm)]
|
||||
struct OrgIdData {
|
||||
#[form(field = "organizationId")]
|
||||
#[field(name = "organizationId")]
|
||||
organization_id: String,
|
||||
}
|
||||
|
||||
#[get("/ciphers/organization-details?<data..>")]
|
||||
fn get_org_details(data: Form<OrgIdData>, headers: Headers, conn: DbConn) -> Json<Value> {
|
||||
fn get_org_details(data: OrgIdData, headers: Headers, conn: DbConn) -> Json<Value> {
|
||||
let ciphers = Cipher::find_by_org(&data.organization_id, &conn);
|
||||
let ciphers_json: Vec<Value> =
|
||||
ciphers.iter().map(|c| c.to_json(&headers.host, &headers.user.uuid, &conn)).collect();
|
||||
@ -1097,14 +1097,14 @@ struct RelationsData {
|
||||
|
||||
#[post("/ciphers/import-organization?<query..>", data = "<data>")]
|
||||
fn post_org_import(
|
||||
query: Form<OrgIdData>,
|
||||
query: OrgIdData,
|
||||
data: JsonUpcase<ImportData>,
|
||||
headers: AdminHeaders,
|
||||
conn: DbConn,
|
||||
nt: Notify,
|
||||
) -> EmptyResult {
|
||||
let data: ImportData = data.into_inner().data;
|
||||
let org_id = query.into_inner().organization_id;
|
||||
let org_id = query.organization_id;
|
||||
|
||||
// Read and create the collections
|
||||
let collections: Vec<_> = data
|
||||
|
@ -1,9 +1,10 @@
|
||||
use std::{io::Read, path::Path};
|
||||
use std::path::Path;
|
||||
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use multipart::server::{save::SavedData, Multipart, SaveResult};
|
||||
use rocket::{http::ContentType, response::NamedFile, Data};
|
||||
use rocket_contrib::json::Json;
|
||||
use rocket::form::Form;
|
||||
use rocket::fs::NamedFile;
|
||||
use rocket::fs::TempFile;
|
||||
use rocket::serde::json::Json;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::{
|
||||
@ -31,9 +32,9 @@ pub fn routes() -> Vec<rocket::Route> {
|
||||
]
|
||||
}
|
||||
|
||||
pub fn purge_sends(pool: DbPool) {
|
||||
pub async fn purge_sends(pool: DbPool) {
|
||||
debug!("Purging sends");
|
||||
if let Ok(conn) = pool.get() {
|
||||
if let Ok(conn) = pool.get().await {
|
||||
Send::purge(&conn);
|
||||
} else {
|
||||
error!("Failed to get DB connection while purging sends")
|
||||
@ -177,25 +178,23 @@ fn post_send(data: JsonUpcase<SendData>, headers: Headers, conn: DbConn, nt: Not
|
||||
Ok(Json(send.to_json()))
|
||||
}
|
||||
|
||||
#[derive(FromForm)]
|
||||
struct UploadData<'f> {
|
||||
model: Json<crate::util::UpCase<SendData>>,
|
||||
data: TempFile<'f>,
|
||||
}
|
||||
|
||||
#[post("/sends/file", format = "multipart/form-data", data = "<data>")]
|
||||
fn post_send_file(data: Data, content_type: &ContentType, headers: Headers, conn: DbConn, nt: Notify) -> JsonResult {
|
||||
async fn post_send_file(data: Form<UploadData<'_>>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
|
||||
enforce_disable_send_policy(&headers, &conn)?;
|
||||
|
||||
let boundary = content_type.params().next().expect("No boundary provided").1;
|
||||
let UploadData {
|
||||
model,
|
||||
mut data,
|
||||
} = data.into_inner();
|
||||
let model = model.into_inner().data;
|
||||
|
||||
let mut mpart = Multipart::with_body(data.open(), boundary);
|
||||
|
||||
// First entry is the SendData JSON
|
||||
let mut model_entry = match mpart.read_entry()? {
|
||||
Some(e) if &*e.headers.name == "model" => e,
|
||||
Some(_) => err!("Invalid entry name"),
|
||||
None => err!("No model entry present"),
|
||||
};
|
||||
|
||||
let mut buf = String::new();
|
||||
model_entry.data.read_to_string(&mut buf)?;
|
||||
let data = serde_json::from_str::<crate::util::UpCase<SendData>>(&buf)?;
|
||||
enforce_disable_hide_email_policy(&data.data, &headers, &conn)?;
|
||||
enforce_disable_hide_email_policy(&model, &headers, &conn)?;
|
||||
|
||||
// Get the file length and add an extra 5% to avoid issues
|
||||
const SIZE_525_MB: u64 = 550_502_400;
|
||||
@ -212,45 +211,27 @@ fn post_send_file(data: Data, content_type: &ContentType, headers: Headers, conn
|
||||
None => SIZE_525_MB,
|
||||
};
|
||||
|
||||
// Create the Send
|
||||
let mut send = create_send(data.data, headers.user.uuid)?;
|
||||
let file_id = crate::crypto::generate_send_id();
|
||||
|
||||
let mut send = create_send(model, headers.user.uuid)?;
|
||||
if send.atype != SendType::File as i32 {
|
||||
err!("Send content is not a file");
|
||||
}
|
||||
|
||||
let file_path = Path::new(&CONFIG.sends_folder()).join(&send.uuid).join(&file_id);
|
||||
let size = data.len();
|
||||
if size > size_limit {
|
||||
err!("Attachment storage limit exceeded with this file");
|
||||
}
|
||||
|
||||
// Read the data entry and save the file
|
||||
let mut data_entry = match mpart.read_entry()? {
|
||||
Some(e) if &*e.headers.name == "data" => e,
|
||||
Some(_) => err!("Invalid entry name"),
|
||||
None => err!("No model entry present"),
|
||||
};
|
||||
let file_id = crate::crypto::generate_send_id();
|
||||
let folder_path = tokio::fs::canonicalize(&CONFIG.sends_folder()).await?.join(&send.uuid);
|
||||
let file_path = folder_path.join(&file_id);
|
||||
tokio::fs::create_dir_all(&folder_path).await?;
|
||||
data.persist_to(&file_path).await?;
|
||||
|
||||
let size = match data_entry.data.save().memory_threshold(0).size_limit(size_limit).with_path(&file_path) {
|
||||
SaveResult::Full(SavedData::File(_, size)) => size as i32,
|
||||
SaveResult::Full(other) => {
|
||||
std::fs::remove_file(&file_path).ok();
|
||||
err!(format!("Attachment is not a file: {:?}", other));
|
||||
}
|
||||
SaveResult::Partial(_, reason) => {
|
||||
std::fs::remove_file(&file_path).ok();
|
||||
err!(format!("Attachment storage limit exceeded with this file: {:?}", reason));
|
||||
}
|
||||
SaveResult::Error(e) => {
|
||||
std::fs::remove_file(&file_path).ok();
|
||||
err!(format!("Error: {:?}", e));
|
||||
}
|
||||
};
|
||||
|
||||
// Set ID and sizes
|
||||
let mut data_value: Value = serde_json::from_str(&send.data)?;
|
||||
if let Some(o) = data_value.as_object_mut() {
|
||||
o.insert(String::from("Id"), Value::String(file_id));
|
||||
o.insert(String::from("Size"), Value::Number(size.into()));
|
||||
o.insert(String::from("SizeName"), Value::String(crate::util::get_display_size(size)));
|
||||
o.insert(String::from("SizeName"), Value::String(crate::util::get_display_size(size as i32)));
|
||||
}
|
||||
send.data = serde_json::to_string(&data_value)?;
|
||||
|
||||
@ -367,10 +348,10 @@ fn post_access_file(
|
||||
}
|
||||
|
||||
#[get("/sends/<send_id>/<file_id>?<t>")]
|
||||
fn download_send(send_id: SafeString, file_id: SafeString, t: String) -> Option<NamedFile> {
|
||||
async fn download_send(send_id: SafeString, file_id: SafeString, t: String) -> Option<NamedFile> {
|
||||
if let Ok(claims) = crate::auth::decode_send(&t) {
|
||||
if claims.sub == format!("{}/{}", send_id, file_id) {
|
||||
return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).ok();
|
||||
return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).await.ok();
|
||||
}
|
||||
}
|
||||
None
|
||||
|
@ -1,6 +1,6 @@
|
||||
use data_encoding::BASE32;
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::Route;
|
||||
use rocket_contrib::json::Json;
|
||||
|
||||
use crate::{
|
||||
api::{
|
||||
|
@ -1,7 +1,7 @@
|
||||
use chrono::Utc;
|
||||
use data_encoding::BASE64;
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::Route;
|
||||
use rocket_contrib::json::Json;
|
||||
|
||||
use crate::{
|
||||
api::{core::two_factor::_generate_recover_code, ApiResult, EmptyResult, JsonResult, JsonUpcase, PasswordData},
|
||||
@ -152,7 +152,7 @@ fn check_duo_fields_custom(data: &EnableDuoData) -> bool {
|
||||
}
|
||||
|
||||
#[post("/two-factor/duo", data = "<data>")]
|
||||
fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
|
||||
async fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
|
||||
let data: EnableDuoData = data.into_inner().data;
|
||||
let mut user = headers.user;
|
||||
|
||||
@ -163,7 +163,7 @@ fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn)
|
||||
let (data, data_str) = if check_duo_fields_custom(&data) {
|
||||
let data_req: DuoData = data.into();
|
||||
let data_str = serde_json::to_string(&data_req)?;
|
||||
duo_api_request("GET", "/auth/v2/check", "", &data_req).map_res("Failed to validate Duo credentials")?;
|
||||
duo_api_request("GET", "/auth/v2/check", "", &data_req).await.map_res("Failed to validate Duo credentials")?;
|
||||
(data_req.obscure(), data_str)
|
||||
} else {
|
||||
(DuoData::secret(), String::new())
|
||||
@ -185,11 +185,11 @@ fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn)
|
||||
}
|
||||
|
||||
#[put("/two-factor/duo", data = "<data>")]
|
||||
fn activate_duo_put(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
|
||||
activate_duo(data, headers, conn)
|
||||
async fn activate_duo_put(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
|
||||
activate_duo(data, headers, conn).await
|
||||
}
|
||||
|
||||
fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> EmptyResult {
|
||||
async fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> EmptyResult {
|
||||
use reqwest::{header, Method};
|
||||
use std::str::FromStr;
|
||||
|
||||
@ -209,7 +209,8 @@ fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> Em
|
||||
.basic_auth(username, Some(password))
|
||||
.header(header::USER_AGENT, "vaultwarden:Duo/1.0 (Rust)")
|
||||
.header(header::DATE, date)
|
||||
.send()?
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?;
|
||||
|
||||
Ok(())
|
||||
|
@ -1,6 +1,6 @@
|
||||
use chrono::{Duration, NaiveDateTime, Utc};
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::Route;
|
||||
use rocket_contrib::json::Json;
|
||||
|
||||
use crate::{
|
||||
api::{core::two_factor::_generate_recover_code, EmptyResult, JsonResult, JsonUpcase, PasswordData},
|
||||
|
@ -1,7 +1,7 @@
|
||||
use chrono::{Duration, Utc};
|
||||
use data_encoding::BASE32;
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::Route;
|
||||
use rocket_contrib::json::Json;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::{
|
||||
@ -158,14 +158,14 @@ fn disable_twofactor_put(data: JsonUpcase<DisableTwoFactorData>, headers: Header
|
||||
disable_twofactor(data, headers, conn)
|
||||
}
|
||||
|
||||
pub fn send_incomplete_2fa_notifications(pool: DbPool) {
|
||||
pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
|
||||
debug!("Sending notifications for incomplete 2FA logins");
|
||||
|
||||
if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() {
|
||||
return;
|
||||
}
|
||||
|
||||
let conn = match pool.get() {
|
||||
let conn = match pool.get().await {
|
||||
Ok(conn) => conn,
|
||||
_ => {
|
||||
error!("Failed to get DB connection in send_incomplete_2fa_notifications()");
|
||||
|
@ -1,6 +1,6 @@
|
||||
use once_cell::sync::Lazy;
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::Route;
|
||||
use rocket_contrib::json::Json;
|
||||
use serde_json::Value;
|
||||
use u2f::{
|
||||
messages::{RegisterResponse, SignResponse, U2fSignRequest},
|
||||
|
@ -1,5 +1,5 @@
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::Route;
|
||||
use rocket_contrib::json::Json;
|
||||
use serde_json::Value;
|
||||
use url::Url;
|
||||
use webauthn_rs::{base64_data::Base64UrlSafeData, proto::*, AuthenticationState, RegistrationState, Webauthn};
|
||||
|
@ -1,5 +1,5 @@
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::Route;
|
||||
use rocket_contrib::json::Json;
|
||||
use serde_json::Value;
|
||||
use yubico::{config::Config, verify};
|
||||
|
||||
|
125
src/api/icons.rs
125
src/api/icons.rs
File diff suppressed because it is too large
Load Diff
@ -1,10 +1,10 @@
|
||||
use chrono::Utc;
|
||||
use num_traits::FromPrimitive;
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::{
|
||||
request::{Form, FormItems, FromForm},
|
||||
form::{Form, FromForm},
|
||||
Route,
|
||||
};
|
||||
use rocket_contrib::json::Json;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::{
|
||||
@ -455,66 +455,57 @@ fn _json_err_twofactor(providers: &[i32], user_uuid: &str, conn: &DbConn) -> Api
|
||||
|
||||
// https://github.com/bitwarden/jslib/blob/master/common/src/models/request/tokenRequest.ts
|
||||
// https://github.com/bitwarden/mobile/blob/master/src/Core/Models/Request/TokenRequest.cs
|
||||
#[derive(Debug, Clone, Default)]
|
||||
#[derive(Debug, Clone, Default, FromForm)]
|
||||
#[allow(non_snake_case)]
|
||||
struct ConnectData {
|
||||
// refresh_token, password, client_credentials (API key)
|
||||
grant_type: String,
|
||||
#[field(name = uncased("grant_type"))]
|
||||
#[field(name = uncased("granttype"))]
|
||||
grant_type: String, // refresh_token, password, client_credentials (API key)
|
||||
|
||||
// Needed for grant_type="refresh_token"
|
||||
#[field(name = uncased("refresh_token"))]
|
||||
#[field(name = uncased("refreshtoken"))]
|
||||
refresh_token: Option<String>,
|
||||
|
||||
// Needed for grant_type = "password" | "client_credentials"
|
||||
client_id: Option<String>, // web, cli, desktop, browser, mobile
|
||||
client_secret: Option<String>, // API key login (cli only)
|
||||
#[field(name = uncased("client_id"))]
|
||||
#[field(name = uncased("clientid"))]
|
||||
client_id: Option<String>, // web, cli, desktop, browser, mobile
|
||||
#[field(name = uncased("client_secret"))]
|
||||
#[field(name = uncased("clientsecret"))]
|
||||
client_secret: Option<String>,
|
||||
#[field(name = uncased("password"))]
|
||||
password: Option<String>,
|
||||
#[field(name = uncased("scope"))]
|
||||
scope: Option<String>,
|
||||
#[field(name = uncased("username"))]
|
||||
username: Option<String>,
|
||||
|
||||
#[field(name = uncased("device_identifier"))]
|
||||
#[field(name = uncased("deviceidentifier"))]
|
||||
device_identifier: Option<String>,
|
||||
#[field(name = uncased("device_name"))]
|
||||
#[field(name = uncased("devicename"))]
|
||||
device_name: Option<String>,
|
||||
#[field(name = uncased("device_type"))]
|
||||
#[field(name = uncased("devicetype"))]
|
||||
device_type: Option<String>,
|
||||
#[field(name = uncased("device_push_token"))]
|
||||
#[field(name = uncased("devicepushtoken"))]
|
||||
device_push_token: Option<String>, // Unused; mobile device push not yet supported.
|
||||
|
||||
// Needed for two-factor auth
|
||||
#[field(name = uncased("two_factor_provider"))]
|
||||
#[field(name = uncased("twofactorprovider"))]
|
||||
two_factor_provider: Option<i32>,
|
||||
#[field(name = uncased("two_factor_token"))]
|
||||
#[field(name = uncased("twofactortoken"))]
|
||||
two_factor_token: Option<String>,
|
||||
#[field(name = uncased("two_factor_remember"))]
|
||||
#[field(name = uncased("twofactorremember"))]
|
||||
two_factor_remember: Option<i32>,
|
||||
}
|
||||
|
||||
impl<'f> FromForm<'f> for ConnectData {
|
||||
type Error = String;
|
||||
|
||||
fn from_form(items: &mut FormItems<'f>, _strict: bool) -> Result<Self, Self::Error> {
|
||||
let mut form = Self::default();
|
||||
for item in items {
|
||||
let (key, value) = item.key_value_decoded();
|
||||
let mut normalized_key = key.to_lowercase();
|
||||
normalized_key.retain(|c| c != '_'); // Remove '_'
|
||||
|
||||
match normalized_key.as_ref() {
|
||||
"granttype" => form.grant_type = value,
|
||||
"refreshtoken" => form.refresh_token = Some(value),
|
||||
"clientid" => form.client_id = Some(value),
|
||||
"clientsecret" => form.client_secret = Some(value),
|
||||
"password" => form.password = Some(value),
|
||||
"scope" => form.scope = Some(value),
|
||||
"username" => form.username = Some(value),
|
||||
"deviceidentifier" => form.device_identifier = Some(value),
|
||||
"devicename" => form.device_name = Some(value),
|
||||
"devicetype" => form.device_type = Some(value),
|
||||
"devicepushtoken" => form.device_push_token = Some(value),
|
||||
"twofactorprovider" => form.two_factor_provider = value.parse().ok(),
|
||||
"twofactortoken" => form.two_factor_token = Some(value),
|
||||
"twofactorremember" => form.two_factor_remember = value.parse().ok(),
|
||||
key => warn!("Detected unexpected parameter during login: {}", key),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(form)
|
||||
}
|
||||
}
|
||||
|
||||
fn _check_is_some<T>(value: &Option<T>, msg: &str) -> EmptyResult {
|
||||
if value.is_none() {
|
||||
err!(msg)
|
||||
|
@ -5,7 +5,7 @@ mod identity;
|
||||
mod notifications;
|
||||
mod web;
|
||||
|
||||
use rocket_contrib::json::Json;
|
||||
use rocket::serde::json::Json;
|
||||
use serde_json::Value;
|
||||
|
||||
pub use crate::api::{
|
||||
|
@ -1,7 +1,7 @@
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::Route;
|
||||
use rocket_contrib::json::Json;
|
||||
use serde_json::Value as JsonValue;
|
||||
|
||||
use crate::{api::EmptyResult, auth::Headers, Error, CONFIG};
|
||||
@ -417,7 +417,7 @@ pub enum UpdateType {
|
||||
}
|
||||
|
||||
use rocket::State;
|
||||
pub type Notify<'a> = State<'a, WebSocketUsers>;
|
||||
pub type Notify<'a> = &'a State<WebSocketUsers>;
|
||||
|
||||
pub fn start_notification_server() -> WebSocketUsers {
|
||||
let factory = WsFactory::init();
|
||||
@ -430,12 +430,11 @@ pub fn start_notification_server() -> WebSocketUsers {
|
||||
settings.queue_size = 2;
|
||||
settings.panic_on_internal = false;
|
||||
|
||||
ws::Builder::new()
|
||||
.with_settings(settings)
|
||||
.build(factory)
|
||||
.unwrap()
|
||||
.listen((CONFIG.websocket_address().as_str(), CONFIG.websocket_port()))
|
||||
.unwrap();
|
||||
let ws = ws::Builder::new().with_settings(settings).build(factory).unwrap();
|
||||
CONFIG.set_ws_shutdown_handle(ws.broadcaster());
|
||||
ws.listen((CONFIG.websocket_address().as_str(), CONFIG.websocket_port())).unwrap();
|
||||
|
||||
warn!("WS Server stopped!");
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use rocket::{http::ContentType, response::content::Content, response::NamedFile, Route};
|
||||
use rocket_contrib::json::Json;
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::{fs::NamedFile, http::ContentType, Route};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::{
|
||||
@ -21,16 +21,16 @@ pub fn routes() -> Vec<Route> {
|
||||
}
|
||||
|
||||
#[get("/")]
|
||||
fn web_index() -> Cached<Option<NamedFile>> {
|
||||
Cached::short(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join("index.html")).ok(), false)
|
||||
async fn web_index() -> Cached<Option<NamedFile>> {
|
||||
Cached::short(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join("index.html")).await.ok(), false)
|
||||
}
|
||||
|
||||
#[get("/app-id.json")]
|
||||
fn app_id() -> Cached<Content<Json<Value>>> {
|
||||
fn app_id() -> Cached<(ContentType, Json<Value>)> {
|
||||
let content_type = ContentType::new("application", "fido.trusted-apps+json");
|
||||
|
||||
Cached::long(
|
||||
Content(
|
||||
(
|
||||
content_type,
|
||||
Json(json!({
|
||||
"trustedFacets": [
|
||||
@ -58,13 +58,13 @@ fn app_id() -> Cached<Content<Json<Value>>> {
|
||||
}
|
||||
|
||||
#[get("/<p..>", rank = 10)] // Only match this if the other routes don't match
|
||||
fn web_files(p: PathBuf) -> Cached<Option<NamedFile>> {
|
||||
Cached::long(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join(p)).ok(), true)
|
||||
async fn web_files(p: PathBuf) -> Cached<Option<NamedFile>> {
|
||||
Cached::long(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join(p)).await.ok(), true)
|
||||
}
|
||||
|
||||
#[get("/attachments/<uuid>/<file_id>")]
|
||||
fn attachments(uuid: SafeString, file_id: SafeString) -> Option<NamedFile> {
|
||||
NamedFile::open(Path::new(&CONFIG.attachments_folder()).join(uuid).join(file_id)).ok()
|
||||
async fn attachments(uuid: SafeString, file_id: SafeString) -> Option<NamedFile> {
|
||||
NamedFile::open(Path::new(&CONFIG.attachments_folder()).join(uuid).join(file_id)).await.ok()
|
||||
}
|
||||
|
||||
// We use DbConn here to let the alive healthcheck also verify the database connection.
|
||||
@ -78,25 +78,20 @@ fn alive(_conn: DbConn) -> Json<String> {
|
||||
}
|
||||
|
||||
#[get("/vw_static/<filename>")]
|
||||
fn static_files(filename: String) -> Result<Content<&'static [u8]>, Error> {
|
||||
fn static_files(filename: String) -> Result<(ContentType, &'static [u8]), Error> {
|
||||
match filename.as_ref() {
|
||||
"mail-github.png" => Ok(Content(ContentType::PNG, include_bytes!("../static/images/mail-github.png"))),
|
||||
"logo-gray.png" => Ok(Content(ContentType::PNG, include_bytes!("../static/images/logo-gray.png"))),
|
||||
"error-x.svg" => Ok(Content(ContentType::SVG, include_bytes!("../static/images/error-x.svg"))),
|
||||
"hibp.png" => Ok(Content(ContentType::PNG, include_bytes!("../static/images/hibp.png"))),
|
||||
"vaultwarden-icon.png" => {
|
||||
Ok(Content(ContentType::PNG, include_bytes!("../static/images/vaultwarden-icon.png")))
|
||||
}
|
||||
|
||||
"bootstrap.css" => Ok(Content(ContentType::CSS, include_bytes!("../static/scripts/bootstrap.css"))),
|
||||
"bootstrap-native.js" => {
|
||||
Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/bootstrap-native.js")))
|
||||
}
|
||||
"identicon.js" => Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/identicon.js"))),
|
||||
"datatables.js" => Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/datatables.js"))),
|
||||
"datatables.css" => Ok(Content(ContentType::CSS, include_bytes!("../static/scripts/datatables.css"))),
|
||||
"mail-github.png" => Ok((ContentType::PNG, include_bytes!("../static/images/mail-github.png"))),
|
||||
"logo-gray.png" => Ok((ContentType::PNG, include_bytes!("../static/images/logo-gray.png"))),
|
||||
"error-x.svg" => Ok((ContentType::SVG, include_bytes!("../static/images/error-x.svg"))),
|
||||
"hibp.png" => Ok((ContentType::PNG, include_bytes!("../static/images/hibp.png"))),
|
||||
"vaultwarden-icon.png" => Ok((ContentType::PNG, include_bytes!("../static/images/vaultwarden-icon.png"))),
|
||||
"bootstrap.css" => Ok((ContentType::CSS, include_bytes!("../static/scripts/bootstrap.css"))),
|
||||
"bootstrap-native.js" => Ok((ContentType::JavaScript, include_bytes!("../static/scripts/bootstrap-native.js"))),
|
||||
"identicon.js" => Ok((ContentType::JavaScript, include_bytes!("../static/scripts/identicon.js"))),
|
||||
"datatables.js" => Ok((ContentType::JavaScript, include_bytes!("../static/scripts/datatables.js"))),
|
||||
"datatables.css" => Ok((ContentType::CSS, include_bytes!("../static/scripts/datatables.css"))),
|
||||
"jquery-3.6.0.slim.js" => {
|
||||
Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/jquery-3.6.0.slim.js")))
|
||||
Ok((ContentType::JavaScript, include_bytes!("../static/scripts/jquery-3.6.0.slim.js")))
|
||||
}
|
||||
_ => err!(format!("Static file not found: {}", filename)),
|
||||
}
|
||||
|
262
src/auth.rs
262
src/auth.rs
File diff suppressed because it is too large
Load Diff
@ -36,6 +36,9 @@ macro_rules! make_config {
|
||||
pub struct Config { inner: RwLock<Inner> }
|
||||
|
||||
struct Inner {
|
||||
rocket_shutdown_handle: Option<rocket::Shutdown>,
|
||||
ws_shutdown_handle: Option<ws::Sender>,
|
||||
|
||||
templates: Handlebars<'static>,
|
||||
config: ConfigItems,
|
||||
|
||||
@ -332,6 +335,8 @@ make_config! {
|
||||
attachments_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "attachments");
|
||||
/// Sends folder
|
||||
sends_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "sends");
|
||||
/// Temp folder |> Used for storing temporary file uploads
|
||||
tmp_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "tmp");
|
||||
/// Templates folder
|
||||
templates_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "templates");
|
||||
/// Session JWT key
|
||||
@ -509,6 +514,9 @@ make_config! {
|
||||
/// Max database connection retries |> Number of times to retry the database connection during startup, with 1 second between each retry, set to 0 to retry indefinitely
|
||||
db_connection_retries: u32, false, def, 15;
|
||||
|
||||
/// Timeout when aquiring database connection
|
||||
database_timeout: u64, false, def, 30;
|
||||
|
||||
/// Database connection pool size
|
||||
database_max_conns: u32, false, def, 10;
|
||||
|
||||
@ -743,6 +751,8 @@ impl Config {
|
||||
|
||||
Ok(Config {
|
||||
inner: RwLock::new(Inner {
|
||||
rocket_shutdown_handle: None,
|
||||
ws_shutdown_handle: None,
|
||||
templates: load_templates(&config.templates_folder),
|
||||
config,
|
||||
_env,
|
||||
@ -907,6 +917,27 @@ impl Config {
|
||||
hb.render(name, data).map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_rocket_shutdown_handle(&self, handle: rocket::Shutdown) {
|
||||
self.inner.write().unwrap().rocket_shutdown_handle = Some(handle);
|
||||
}
|
||||
|
||||
pub fn set_ws_shutdown_handle(&self, handle: ws::Sender) {
|
||||
self.inner.write().unwrap().ws_shutdown_handle = Some(handle);
|
||||
}
|
||||
|
||||
pub fn shutdown(&self) {
|
||||
if let Ok(c) = self.inner.read() {
|
||||
if let Some(handle) = c.ws_shutdown_handle.clone() {
|
||||
handle.shutdown().ok();
|
||||
}
|
||||
// Wait a bit before stopping the web server
|
||||
std::thread::sleep(std::time::Duration::from_secs(1));
|
||||
if let Some(handle) = c.rocket_shutdown_handle.clone() {
|
||||
handle.notify();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use handlebars::{Context, Handlebars, Helper, HelperResult, Output, RenderContext, RenderError, Renderable};
|
||||
|
231
src/db/mod.rs
231
src/db/mod.rs
File diff suppressed because it is too large
Load Diff
10
src/error.rs
10
src/error.rs
@ -45,6 +45,7 @@ use lettre::transport::smtp::Error as SmtpErr;
|
||||
use openssl::error::ErrorStack as SSLErr;
|
||||
use regex::Error as RegexErr;
|
||||
use reqwest::Error as ReqErr;
|
||||
use rocket::error::Error as RocketErr;
|
||||
use serde_json::{Error as SerdeErr, Value};
|
||||
use std::io::Error as IoErr;
|
||||
use std::time::SystemTimeError as TimeErr;
|
||||
@ -84,6 +85,7 @@ make_error! {
|
||||
Address(AddrErr): _has_source, _api_error,
|
||||
Smtp(SmtpErr): _has_source, _api_error,
|
||||
OpenSSL(SSLErr): _has_source, _api_error,
|
||||
Rocket(RocketErr): _has_source, _api_error,
|
||||
|
||||
DieselCon(DieselConErr): _has_source, _api_error,
|
||||
DieselMig(DieselMigErr): _has_source, _api_error,
|
||||
@ -193,8 +195,8 @@ use rocket::http::{ContentType, Status};
|
||||
use rocket::request::Request;
|
||||
use rocket::response::{self, Responder, Response};
|
||||
|
||||
impl<'r> Responder<'r> for Error {
|
||||
fn respond_to(self, _: &Request) -> response::Result<'r> {
|
||||
impl<'r> Responder<'r, 'static> for Error {
|
||||
fn respond_to(self, _: &Request) -> response::Result<'static> {
|
||||
match self.error {
|
||||
ErrorKind::Empty(_) => {} // Don't print the error in this situation
|
||||
ErrorKind::Simple(_) => {} // Don't print the error in this situation
|
||||
@ -202,8 +204,8 @@ impl<'r> Responder<'r> for Error {
|
||||
};
|
||||
|
||||
let code = Status::from_code(self.error_code).unwrap_or(Status::BadRequest);
|
||||
|
||||
Response::build().status(code).header(ContentType::JSON).sized_body(Cursor::new(format!("{}", self))).ok()
|
||||
let body = self.to_string();
|
||||
Response::build().status(code).header(ContentType::JSON).sized_body(Some(body.len()), Cursor::new(body)).ok()
|
||||
}
|
||||
}
|
||||
|
||||
|
98
src/main.rs
98
src/main.rs
@ -20,8 +20,15 @@ extern crate diesel;
|
||||
#[macro_use]
|
||||
extern crate diesel_migrations;
|
||||
|
||||
use job_scheduler::{Job, JobScheduler};
|
||||
use std::{fs::create_dir_all, panic, path::Path, process::exit, str::FromStr, thread, time::Duration};
|
||||
use std::{
|
||||
fs::{canonicalize, create_dir_all},
|
||||
panic,
|
||||
path::Path,
|
||||
process::exit,
|
||||
str::FromStr,
|
||||
thread,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
#[macro_use]
|
||||
mod error;
|
||||
@ -37,9 +44,11 @@ mod util;
|
||||
|
||||
pub use config::CONFIG;
|
||||
pub use error::{Error, MapResult};
|
||||
use rocket::data::{Limits, ToByteUnit};
|
||||
pub use util::is_running_in_docker;
|
||||
|
||||
fn main() {
|
||||
#[rocket::main]
|
||||
async fn main() -> Result<(), Error> {
|
||||
parse_args();
|
||||
launch_info();
|
||||
|
||||
@ -56,13 +65,16 @@ fn main() {
|
||||
});
|
||||
check_web_vault();
|
||||
|
||||
create_icon_cache_folder();
|
||||
create_dir(&CONFIG.icon_cache_folder(), "icon cache");
|
||||
create_dir(&CONFIG.tmp_folder(), "tmp folder");
|
||||
create_dir(&CONFIG.sends_folder(), "sends folder");
|
||||
create_dir(&CONFIG.attachments_folder(), "attachments folder");
|
||||
|
||||
let pool = create_db_pool();
|
||||
schedule_jobs(pool.clone());
|
||||
crate::db::models::TwoFactor::migrate_u2f_to_webauthn(&pool.get().unwrap()).unwrap();
|
||||
schedule_jobs(pool.clone()).await;
|
||||
crate::db::models::TwoFactor::migrate_u2f_to_webauthn(&pool.get().await.unwrap()).unwrap();
|
||||
|
||||
launch_rocket(pool, extra_debug); // Blocks until program termination.
|
||||
launch_rocket(pool, extra_debug).await // Blocks until program termination.
|
||||
}
|
||||
|
||||
const HELP: &str = "\
|
||||
@ -127,10 +139,12 @@ fn init_logging(level: log::LevelFilter) -> Result<(), fern::InitError> {
|
||||
.level_for("hyper::server", log::LevelFilter::Warn)
|
||||
// Silence rocket logs
|
||||
.level_for("_", log::LevelFilter::Off)
|
||||
.level_for("launch", log::LevelFilter::Off)
|
||||
.level_for("launch_", log::LevelFilter::Off)
|
||||
.level_for("rocket::rocket", log::LevelFilter::Off)
|
||||
.level_for("rocket::fairing", log::LevelFilter::Off)
|
||||
.level_for("rocket::launch", log::LevelFilter::Error)
|
||||
.level_for("rocket::launch_", log::LevelFilter::Error)
|
||||
.level_for("rocket::rocket", log::LevelFilter::Warn)
|
||||
.level_for("rocket::server", log::LevelFilter::Warn)
|
||||
.level_for("rocket::fairing::fairings", log::LevelFilter::Warn)
|
||||
.level_for("rocket::shield::shield", log::LevelFilter::Warn)
|
||||
// Never show html5ever and hyper::proto logs, too noisy
|
||||
.level_for("html5ever", log::LevelFilter::Off)
|
||||
.level_for("hyper::proto", log::LevelFilter::Off)
|
||||
@ -243,10 +257,6 @@ fn create_dir(path: &str, description: &str) {
|
||||
create_dir_all(path).expect(&err_msg);
|
||||
}
|
||||
|
||||
fn create_icon_cache_folder() {
|
||||
create_dir(&CONFIG.icon_cache_folder(), "icon cache");
|
||||
}
|
||||
|
||||
fn check_data_folder() {
|
||||
let data_folder = &CONFIG.data_folder();
|
||||
let path = Path::new(data_folder);
|
||||
@ -314,51 +324,73 @@ fn create_db_pool() -> db::DbPool {
|
||||
}
|
||||
}
|
||||
|
||||
fn launch_rocket(pool: db::DbPool, extra_debug: bool) {
|
||||
async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error> {
|
||||
let basepath = &CONFIG.domain_path();
|
||||
|
||||
let mut config = rocket::Config::from(rocket::Config::figment());
|
||||
config.address = std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED); // TODO: Allow this to be changed, keep ROCKET_ADDRESS for compat
|
||||
config.temp_dir = canonicalize(CONFIG.tmp_folder()).unwrap().into();
|
||||
config.limits = Limits::new() //
|
||||
.limit("json", 10.megabytes())
|
||||
.limit("data-form", 150.megabytes())
|
||||
.limit("file", 150.megabytes());
|
||||
|
||||
// If adding more paths here, consider also adding them to
|
||||
// crate::utils::LOGGED_ROUTES to make sure they appear in the log
|
||||
let result = rocket::ignite()
|
||||
.mount(&[basepath, "/"].concat(), api::web_routes())
|
||||
.mount(&[basepath, "/api"].concat(), api::core_routes())
|
||||
.mount(&[basepath, "/admin"].concat(), api::admin_routes())
|
||||
.mount(&[basepath, "/identity"].concat(), api::identity_routes())
|
||||
.mount(&[basepath, "/icons"].concat(), api::icons_routes())
|
||||
.mount(&[basepath, "/notifications"].concat(), api::notifications_routes())
|
||||
let instance = rocket::custom(config)
|
||||
.mount([basepath, "/"].concat(), api::web_routes())
|
||||
.mount([basepath, "/api"].concat(), api::core_routes())
|
||||
.mount([basepath, "/admin"].concat(), api::admin_routes())
|
||||
.mount([basepath, "/identity"].concat(), api::identity_routes())
|
||||
.mount([basepath, "/icons"].concat(), api::icons_routes())
|
||||
.mount([basepath, "/notifications"].concat(), api::notifications_routes())
|
||||
.manage(pool)
|
||||
.manage(api::start_notification_server())
|
||||
.attach(util::AppHeaders())
|
||||
.attach(util::Cors())
|
||||
.attach(util::BetterLogging(extra_debug))
|
||||
.launch();
|
||||
.ignite()
|
||||
.await?;
|
||||
|
||||
// Launch and print error if there is one
|
||||
// The launch will restore the original logging level
|
||||
error!("Launch error {:#?}", result);
|
||||
CONFIG.set_rocket_shutdown_handle(instance.shutdown());
|
||||
ctrlc::set_handler(move || {
|
||||
info!("Exiting vaultwarden!");
|
||||
CONFIG.shutdown();
|
||||
})
|
||||
.expect("Error setting Ctrl-C handler");
|
||||
|
||||
instance.launch().await?;
|
||||
|
||||
info!("Vaultwarden process exited!");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn schedule_jobs(pool: db::DbPool) {
|
||||
async fn schedule_jobs(pool: db::DbPool) {
|
||||
if CONFIG.job_poll_interval_ms() == 0 {
|
||||
info!("Job scheduler disabled.");
|
||||
return;
|
||||
}
|
||||
|
||||
let runtime = tokio::runtime::Handle::current();
|
||||
|
||||
thread::Builder::new()
|
||||
.name("job-scheduler".to_string())
|
||||
.spawn(move || {
|
||||
use job_scheduler::{Job, JobScheduler};
|
||||
|
||||
let mut sched = JobScheduler::new();
|
||||
|
||||
// Purge sends that are past their deletion date.
|
||||
if !CONFIG.send_purge_schedule().is_empty() {
|
||||
sched.add(Job::new(CONFIG.send_purge_schedule().parse().unwrap(), || {
|
||||
api::purge_sends(pool.clone());
|
||||
runtime.spawn(api::purge_sends(pool.clone()));
|
||||
}));
|
||||
}
|
||||
|
||||
// Purge trashed items that are old enough to be auto-deleted.
|
||||
if !CONFIG.trash_purge_schedule().is_empty() {
|
||||
sched.add(Job::new(CONFIG.trash_purge_schedule().parse().unwrap(), || {
|
||||
api::purge_trashed_ciphers(pool.clone());
|
||||
runtime.spawn(api::purge_trashed_ciphers(pool.clone()));
|
||||
}));
|
||||
}
|
||||
|
||||
@ -366,7 +398,7 @@ fn schedule_jobs(pool: db::DbPool) {
|
||||
// indicates that a user's master password has been compromised.
|
||||
if !CONFIG.incomplete_2fa_schedule().is_empty() {
|
||||
sched.add(Job::new(CONFIG.incomplete_2fa_schedule().parse().unwrap(), || {
|
||||
api::send_incomplete_2fa_notifications(pool.clone());
|
||||
runtime.spawn(api::send_incomplete_2fa_notifications(pool.clone()));
|
||||
}));
|
||||
}
|
||||
|
||||
@ -375,7 +407,7 @@ fn schedule_jobs(pool: db::DbPool) {
|
||||
// sending reminders for requests that are about to be granted anyway.
|
||||
if !CONFIG.emergency_request_timeout_schedule().is_empty() {
|
||||
sched.add(Job::new(CONFIG.emergency_request_timeout_schedule().parse().unwrap(), || {
|
||||
api::emergency_request_timeout_job(pool.clone());
|
||||
runtime.spawn(api::emergency_request_timeout_job(pool.clone()));
|
||||
}));
|
||||
}
|
||||
|
||||
@ -383,7 +415,7 @@ fn schedule_jobs(pool: db::DbPool) {
|
||||
// emergency access requests.
|
||||
if !CONFIG.emergency_notification_reminder_schedule().is_empty() {
|
||||
sched.add(Job::new(CONFIG.emergency_notification_reminder_schedule().parse().unwrap(), || {
|
||||
api::emergency_notification_reminder_job(pool.clone());
|
||||
runtime.spawn(api::emergency_notification_reminder_job(pool.clone()));
|
||||
}));
|
||||
}
|
||||
|
||||
|
70
src/util.rs
70
src/util.rs
@ -5,10 +5,10 @@ use std::io::Cursor;
|
||||
|
||||
use rocket::{
|
||||
fairing::{Fairing, Info, Kind},
|
||||
http::{ContentType, Header, HeaderMap, Method, RawStr, Status},
|
||||
http::{ContentType, Header, HeaderMap, Method, Status},
|
||||
request::FromParam,
|
||||
response::{self, Responder},
|
||||
Data, Request, Response, Rocket,
|
||||
Data, Orbit, Request, Response, Rocket,
|
||||
};
|
||||
|
||||
use std::thread::sleep;
|
||||
@ -18,6 +18,7 @@ use crate::CONFIG;
|
||||
|
||||
pub struct AppHeaders();
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl Fairing for AppHeaders {
|
||||
fn info(&self) -> Info {
|
||||
Info {
|
||||
@ -26,7 +27,7 @@ impl Fairing for AppHeaders {
|
||||
}
|
||||
}
|
||||
|
||||
fn on_response(&self, _req: &Request, res: &mut Response) {
|
||||
async fn on_response<'r>(&self, _req: &'r Request<'_>, res: &mut Response<'r>) {
|
||||
res.set_raw_header("Permissions-Policy", "accelerometer=(), ambient-light-sensor=(), autoplay=(), camera=(), encrypted-media=(), fullscreen=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), midi=(), payment=(), picture-in-picture=(), sync-xhr=(self \"https://haveibeenpwned.com\" \"https://2fa.directory\"), usb=(), vr=()");
|
||||
res.set_raw_header("Referrer-Policy", "same-origin");
|
||||
res.set_raw_header("X-Frame-Options", "SAMEORIGIN");
|
||||
@ -72,6 +73,7 @@ impl Cors {
|
||||
}
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl Fairing for Cors {
|
||||
fn info(&self) -> Info {
|
||||
Info {
|
||||
@ -80,7 +82,7 @@ impl Fairing for Cors {
|
||||
}
|
||||
}
|
||||
|
||||
fn on_response(&self, request: &Request, response: &mut Response) {
|
||||
async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
|
||||
let req_headers = request.headers();
|
||||
|
||||
if let Some(origin) = Cors::get_allowed_origin(req_headers) {
|
||||
@ -97,7 +99,7 @@ impl Fairing for Cors {
|
||||
response.set_header(Header::new("Access-Control-Allow-Credentials", "true"));
|
||||
response.set_status(Status::Ok);
|
||||
response.set_header(ContentType::Plain);
|
||||
response.set_sized_body(Cursor::new(""));
|
||||
response.set_sized_body(Some(0), Cursor::new(""));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -134,25 +136,21 @@ impl<R> Cached<R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r, R: Responder<'r>> Responder<'r> for Cached<R> {
|
||||
fn respond_to(self, req: &Request) -> response::Result<'r> {
|
||||
impl<'r, R: 'r + Responder<'r, 'static> + Send> Responder<'r, 'static> for Cached<R> {
|
||||
fn respond_to(self, request: &'r Request<'_>) -> response::Result<'static> {
|
||||
let mut res = self.response.respond_to(request)?;
|
||||
|
||||
let cache_control_header = if self.is_immutable {
|
||||
format!("public, immutable, max-age={}", self.ttl)
|
||||
} else {
|
||||
format!("public, max-age={}", self.ttl)
|
||||
};
|
||||
res.set_raw_header("Cache-Control", cache_control_header);
|
||||
|
||||
let time_now = chrono::Local::now();
|
||||
|
||||
match self.response.respond_to(req) {
|
||||
Ok(mut res) => {
|
||||
res.set_raw_header("Cache-Control", cache_control_header);
|
||||
let expiry_time = time_now + chrono::Duration::seconds(self.ttl.try_into().unwrap());
|
||||
res.set_raw_header("Expires", format_datetime_http(&expiry_time));
|
||||
Ok(res)
|
||||
}
|
||||
e @ Err(_) => e,
|
||||
}
|
||||
let expiry_time = time_now + chrono::Duration::seconds(self.ttl.try_into().unwrap());
|
||||
res.set_raw_header("Expires", format_datetime_http(&expiry_time));
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
@ -175,11 +173,9 @@ impl<'r> FromParam<'r> for SafeString {
|
||||
type Error = ();
|
||||
|
||||
#[inline(always)]
|
||||
fn from_param(param: &'r RawStr) -> Result<Self, Self::Error> {
|
||||
let s = param.percent_decode().map(|cow| cow.into_owned()).map_err(|_| ())?;
|
||||
|
||||
if s.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) {
|
||||
Ok(SafeString(s))
|
||||
fn from_param(param: &'r str) -> Result<Self, Self::Error> {
|
||||
if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) {
|
||||
Ok(SafeString(param.to_string()))
|
||||
} else {
|
||||
Err(())
|
||||
}
|
||||
@ -193,15 +189,16 @@ const LOGGED_ROUTES: [&str; 6] =
|
||||
|
||||
// Boolean is extra debug, when true, we ignore the whitelist above and also print the mounts
|
||||
pub struct BetterLogging(pub bool);
|
||||
#[rocket::async_trait]
|
||||
impl Fairing for BetterLogging {
|
||||
fn info(&self) -> Info {
|
||||
Info {
|
||||
name: "Better Logging",
|
||||
kind: Kind::Launch | Kind::Request | Kind::Response,
|
||||
kind: Kind::Liftoff | Kind::Request | Kind::Response,
|
||||
}
|
||||
}
|
||||
|
||||
fn on_launch(&self, rocket: &Rocket) {
|
||||
async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
|
||||
if self.0 {
|
||||
info!(target: "routes", "Routes loaded:");
|
||||
let mut routes: Vec<_> = rocket.routes().collect();
|
||||
@ -225,34 +222,36 @@ impl Fairing for BetterLogging {
|
||||
info!(target: "start", "Rocket has launched from {}", addr);
|
||||
}
|
||||
|
||||
fn on_request(&self, request: &mut Request<'_>, _data: &Data) {
|
||||
async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
|
||||
let method = request.method();
|
||||
if !self.0 && method == Method::Options {
|
||||
return;
|
||||
}
|
||||
let uri = request.uri();
|
||||
let uri_path = uri.path();
|
||||
let uri_subpath = uri_path.strip_prefix(&CONFIG.domain_path()).unwrap_or(uri_path);
|
||||
let uri_path_str = uri_path.url_decode_lossy();
|
||||
let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str);
|
||||
if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
|
||||
match uri.query() {
|
||||
Some(q) => info!(target: "request", "{} {}?{}", method, uri_path, &q[..q.len().min(30)]),
|
||||
None => info!(target: "request", "{} {}", method, uri_path),
|
||||
Some(q) => info!(target: "request", "{} {}?{}", method, uri_path_str, &q[..q.len().min(30)]),
|
||||
None => info!(target: "request", "{} {}", method, uri_path_str),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn on_response(&self, request: &Request, response: &mut Response) {
|
||||
async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
|
||||
if !self.0 && request.method() == Method::Options {
|
||||
return;
|
||||
}
|
||||
let uri_path = request.uri().path();
|
||||
let uri_subpath = uri_path.strip_prefix(&CONFIG.domain_path()).unwrap_or(uri_path);
|
||||
let uri_path_str = uri_path.url_decode_lossy();
|
||||
let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str);
|
||||
if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
|
||||
let status = response.status();
|
||||
if let Some(route) = request.route() {
|
||||
info!(target: "response", "{} => {} {}", route, status.code, status.reason)
|
||||
if let Some(ref route) = request.route() {
|
||||
info!(target: "response", "{} => {}", route, status)
|
||||
} else {
|
||||
info!(target: "response", "{} {}", status.code, status.reason)
|
||||
info!(target: "response", "{}", status)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -614,10 +613,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
use reqwest::{
|
||||
blocking::{Client, ClientBuilder},
|
||||
header,
|
||||
};
|
||||
use reqwest::{header, Client, ClientBuilder};
|
||||
|
||||
pub fn get_reqwest_client() -> Client {
|
||||
get_reqwest_client_builder().build().expect("Failed to build client")
|
||||
|
Reference in New Issue
Block a user