diff --git a/backend/src/auth/bcrypt.rs b/backend/src/auth/bcrypt.rs index 602cfc2..c0f1fb8 100644 --- a/backend/src/auth/bcrypt.rs +++ b/backend/src/auth/bcrypt.rs @@ -1,16 +1,12 @@ -use crate::error::CustomErrorInto; -use crate::error::CustomResult; +use crate::error::{CustomErrorInto, CustomResult}; use bcrypt::{hash, verify, DEFAULT_COST}; pub fn generate_hash(s: &str) -> CustomResult { - let hashed = hash(s, DEFAULT_COST)?; - Ok(hashed) + Ok(hash(s, DEFAULT_COST)?) } pub fn verify_hash(s: &str, hash: &str) -> CustomResult<()> { - let is_valid = verify(s, hash)?; - if !is_valid { - return Err("密码无效".into_custom_error()); - } - Ok(()) + verify(s, hash)? + .then_some(()) + .ok_or_else(|| "密码无效".into_custom_error()) } diff --git a/backend/src/auth/jwt.rs b/backend/src/auth/jwt.rs index 79c5921..aa655b4 100644 --- a/backend/src/auth/jwt.rs +++ b/backend/src/auth/jwt.rs @@ -4,9 +4,7 @@ use ed25519_dalek::{SigningKey, VerifyingKey}; use jwt_compact::{alg::Ed25519, AlgorithmExt, Header, TimeOptions, Token, UntrustedToken}; use rand::{RngCore, SeedableRng}; use serde::{Deserialize, Serialize}; -use std::fs::File; -use std::io::Write; -use std::{env, fs}; +use std::{env, fs, path::PathBuf}; #[derive(Debug, Serialize, Deserialize, Clone)] pub struct CustomClaims { @@ -19,73 +17,74 @@ pub enum SecretKey { } impl SecretKey { - fn as_string(&self) -> String { + const fn as_str(&self) -> &'static str { match self { - Self::Signing => String::from("signing"), - Self::Verifying => String::from("verifying"), + Self::Signing => "signing", + Self::Verifying => "verifying", } } } +fn get_key_path(key_type: &SecretKey) -> CustomResult { + Ok(env::current_dir()? + .join("assets") + .join("key") + .join(key_type.as_str())) +} + pub fn generate_key() -> CustomResult<()> { let mut csprng = rand::rngs::StdRng::from_entropy(); - let mut private_key_bytes = [0u8; 32]; csprng.fill_bytes(&mut private_key_bytes); let signing_key = SigningKey::from_bytes(&private_key_bytes); let verifying_key = signing_key.verifying_key(); - let base_path = env::current_dir()?.join("assets").join("key"); - + let base_path = get_key_path(&SecretKey::Signing)? + .parent() + .unwrap() + .to_path_buf(); fs::create_dir_all(&base_path)?; - File::create(base_path.join(SecretKey::Signing.as_string()))? - .write_all(signing_key.as_bytes())?; - File::create(base_path.join(SecretKey::Verifying.as_string()))? - .write_all(verifying_key.as_bytes())?; + + fs::write(get_key_path(&SecretKey::Signing)?, signing_key.as_bytes())?; + fs::write( + get_key_path(&SecretKey::Verifying)?, + verifying_key.as_bytes(), + )?; Ok(()) } pub fn get_key(key_type: SecretKey) -> CustomResult<[u8; 32]> { - let path = env::current_dir()? - .join("assets") - .join("key") - .join(key_type.as_string()); - let key_bytes = fs::read(path)?; + let key_bytes = fs::read(get_key_path(&key_type)?)?; let mut key = [0u8; 32]; key.copy_from_slice(&key_bytes[..32]); Ok(key) } pub fn generate_jwt(claims: CustomClaims, duration: Duration) -> CustomResult { - let key_bytes = get_key(SecretKey::Signing)?; - let signing_key = SigningKey::from_bytes(&key_bytes); - + let signing_key = SigningKey::from_bytes(&get_key(SecretKey::Signing)?); let time_options = TimeOptions::new(Duration::seconds(0), Utc::now); + let claims = jwt_compact::Claims::new(claims) .set_duration_and_issuance(&time_options, duration) .set_not_before(Utc::now()); - let header = Header::empty(); - - let token = Ed25519.token(&header, &claims, &signing_key)?; - - Ok(token) + Ok(Ed25519.token(&Header::empty(), &claims, &signing_key)?) } pub fn validate_jwt(token: &str) -> CustomResult { - let key_bytes = get_key(SecretKey::Verifying)?; - let verifying = VerifyingKey::from_bytes(&key_bytes)?; - let token = UntrustedToken::new(token)?; - let token: Token = Ed25519.validator(&verifying).validate(&token)?; - + let verifying = VerifyingKey::from_bytes(&get_key(SecretKey::Verifying)?)?; let time_options = TimeOptions::new(Duration::seconds(0), Utc::now); + + let token: Token = Ed25519 + .validator(&verifying) + .validate(&UntrustedToken::new(token)?)?; + token .claims() .validate_expiration(&time_options)? .validate_maturity(&time_options)?; - let claims = token.claims().custom.clone(); - Ok(claims) + Ok(token.claims().custom.clone()) } diff --git a/backend/src/config.rs b/backend/src/config.rs index 20e1915..d11a321 100644 --- a/backend/src/config.rs +++ b/backend/src/config.rs @@ -11,7 +11,6 @@ pub struct Config { pub sql_config: SqlConfig, } - impl Default for Config { fn default() -> Self { Self { diff --git a/backend/src/database/relational/builder.rs b/backend/src/database/relational/builder.rs index ba5515e..d0cabab 100644 --- a/backend/src/database/relational/builder.rs +++ b/backend/src/database/relational/builder.rs @@ -77,7 +77,7 @@ impl TextValidator { let max_length = self .level_max_lengths .get(&level) - .ok_or( "Invalid validation level".into_custom_error())?; + .ok_or("Invalid validation level".into_custom_error())?; if text.len() > *max_length { return Err("Text exceeds maximum length".into_custom_error()); @@ -499,7 +499,7 @@ impl QueryBuilder { for condition in conditions { let (sql, mut condition_params) = self.build_where_clause_with_index(condition, param_index)?; - param_index += condition_params.len(); // 更新参数索引 + param_index += condition_params.len(); parts.push(sql); params.append(&mut condition_params); } diff --git a/backend/src/database/relational/mod.rs b/backend/src/database/relational/mod.rs index 3ed7237..0de5052 100644 --- a/backend/src/database/relational/mod.rs +++ b/backend/src/database/relational/mod.rs @@ -2,8 +2,7 @@ mod postgresql; use crate::config; use crate::error::{CustomErrorInto, CustomResult}; use async_trait::async_trait; -use std::collections::HashMap; -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; pub mod builder; #[async_trait] @@ -14,7 +13,7 @@ pub trait DatabaseTrait: Send + Sync { async fn execute_query<'a>( &'a self, builder: &builder::QueryBuilder, - ) -> CustomResult>>; + ) -> CustomResult>>; async fn initialization(database: config::SqlConfig) -> CustomResult<()> where Self: Sized; diff --git a/backend/src/database/relational/postgresql/mod.rs b/backend/src/database/relational/postgresql/mod.rs index 6bea843..2067ebd 100644 --- a/backend/src/database/relational/postgresql/mod.rs +++ b/backend/src/database/relational/postgresql/mod.rs @@ -3,7 +3,8 @@ use crate::config; use crate::error::CustomErrorInto; use crate::error::CustomResult; use async_trait::async_trait; -use sqlx::{Column, Executor, PgPool, Row}; +use serde_json::Value; +use sqlx::{Column, Executor, PgPool, Row, TypeInfo}; use std::collections::HashMap; use std::{env, fs}; #[derive(Clone)] @@ -64,7 +65,7 @@ impl DatabaseTrait for Postgresql { async fn execute_query<'a>( &'a self, builder: &builder::QueryBuilder, - ) -> CustomResult>> { + ) -> CustomResult>> { let (query, values) = builder.build()?; let mut sqlx_query = sqlx::query(&query); @@ -73,22 +74,32 @@ impl DatabaseTrait for Postgresql { sqlx_query = sqlx_query.bind(value.to_sql_string()?); } - let rows = sqlx_query.fetch_all(&self.pool).await.map_err(|e| { - let (sql, params) = builder.build().unwrap(); - format!("Err:{}\n,SQL: {}\nParams: {:?}", e.to_string(), sql, params) - .into_custom_error() - })?; + let rows = sqlx_query.fetch_all(&self.pool).await?; - let mut results = Vec::new(); - for row in rows { - let mut map = HashMap::new(); - for column in row.columns() { - let value: String = row.try_get(column.name()).unwrap_or_default(); - map.insert(column.name().to_string(), value); - } - results.push(map); - } - - Ok(results) + Ok(rows + .into_iter() + .map(|row| { + row.columns() + .iter() + .map(|col| { + let value = match col.type_info().name() { + "INT4" | "INT8" => Value::Number( + row.try_get::(col.name()).unwrap_or_default().into(), + ), + "FLOAT4" | "FLOAT8" => Value::Number( + serde_json::Number::from_f64( + row.try_get::(col.name()).unwrap_or(0.0), + ) + .unwrap_or_else(|| 0.into()), + ), + "BOOL" => Value::Bool(row.try_get(col.name()).unwrap_or_default()), + "JSON" | "JSONB" => row.try_get(col.name()).unwrap_or(Value::Null), + _ => Value::String(row.try_get(col.name()).unwrap_or_default()), + }; + (col.name().to_string(), value) + }) + .collect() + }) + .collect()) } } diff --git a/backend/src/main.rs b/backend/src/main.rs index 7a18b84..a4ab47e 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -1,14 +1,15 @@ mod auth; mod config; mod database; +mod error; mod routes; mod utils; + use database::relational; +use error::{CustomErrorInto, CustomResult}; use rocket::Shutdown; use std::sync::Arc; use tokio::sync::Mutex; -mod error; -use error::{CustomErrorInto, CustomResult}; pub struct AppState { db: Arc>>, @@ -30,12 +31,11 @@ impl AppState { .lock() .await .clone() - .ok_or("数据库未连接".into_custom_error()) + .ok_or_else(|| "数据库未连接".into_custom_error()) } pub async fn sql_link(&self, config: &config::SqlConfig) -> CustomResult<()> { - let database = relational::Database::link(config).await?; - *self.db.lock().await = Some(database); + *self.db.lock().await = Some(relational::Database::link(config).await?); Ok(()) } @@ -45,14 +45,12 @@ impl AppState { pub async fn trigger_restart(&self) -> CustomResult<()> { *self.restart_progress.lock().await = true; - self.shutdown .lock() .await .take() - .ok_or("未能获取rocket的shutdown".into_custom_error())? + .ok_or_else(|| "未能获取rocket的shutdown".into_custom_error())? .notify(); - Ok(()) } } @@ -60,39 +58,39 @@ impl AppState { #[rocket::main] async fn main() -> CustomResult<()> { let config = config::Config::read().unwrap_or_default(); + let state = Arc::new(AppState::new()); let rocket_config = rocket::Config::figment() - .merge(("address", config.address.clone())) + .merge(("address", config.address)) .merge(("port", config.port)); - let state = AppState::new(); - let state = Arc::new(state); + let mut rocket_builder = rocket::build() + .configure(rocket_config) + .manage(state.clone()); - let rocket_builder = rocket::build().configure(rocket_config).manage(state.clone()); - - let rocket_builder = if !config.info.install { - rocket_builder.mount("/", rocket::routes![routes::install::install]) + if !config.info.install { + rocket_builder = rocket_builder.mount("/", rocket::routes![routes::install::install]); } else { state.sql_link(&config.sql_config).await?; - rocket_builder + rocket_builder = rocket_builder .mount("/auth/token", routes::jwt_routes()) - .mount("/config", routes::configure_routes()) - }; + .mount("/config", routes::configure_routes()); + } let rocket = rocket_builder.ignite().await?; rocket .state::>() - .ok_or("未能获取AppState".into_custom_error())? + .ok_or_else(|| "未能获取AppState".into_custom_error())? .set_shutdown(rocket.shutdown()) .await; rocket.launch().await?; - let restart_progress = *state.restart_progress.lock().await; - if restart_progress { - let current_exe = std::env::current_exe()?; - let _ = std::process::Command::new(current_exe).spawn(); + if *state.restart_progress.lock().await { + if let Ok(current_exe) = std::env::current_exe() { + let _ = std::process::Command::new(current_exe).spawn(); + } } std::process::exit(0); } diff --git a/backend/src/routes/auth/token.rs b/backend/src/routes/auth/token.rs index aa90dd6..6e74998 100644 --- a/backend/src/routes/auth/token.rs +++ b/backend/src/routes/auth/token.rs @@ -23,81 +23,74 @@ pub async fn token_system( state: &State>, data: Json, ) -> AppResult { - let name_condition = builder::Condition::new( - "person_name".to_string(), - builder::Operator::Eq, - Some(builder::SafeValue::Text( - data.name.to_string(), - builder::ValidationLevel::Relaxed, - )), - ) - .into_app_result()?; - - let email_condition = builder::Condition::new( - "person_email".to_string(), - builder::Operator::Eq, - Some(builder::SafeValue::Text( - "author@lsy22.com".to_string(), - builder::ValidationLevel::Relaxed, - )), - ) - .into_app_result()?; - - let level_condition = builder::Condition::new( - "person_level".to_string(), - builder::Operator::Eq, - Some(builder::SafeValue::Enum( - "administrators".to_string(), - "privilege_level".to_string(), - builder::ValidationLevel::Standard, - )), - ) - .into_app_result()?; - - let where_clause = builder::WhereClause::And(vec![ - builder::WhereClause::Condition(name_condition), - builder::WhereClause::Condition(email_condition), - builder::WhereClause::Condition(level_condition), - ]); - let mut builder = - builder::QueryBuilder::new(builder::SqlOperation::Select, String::from("persons")) + builder::QueryBuilder::new(builder::SqlOperation::Select, "persons".to_string()) .into_app_result()?; - - let builder = builder + builder .add_field("person_password".to_string()) - .into_app_result()?; + .into_app_result()? + .add_condition(builder::WhereClause::And(vec![ + builder::WhereClause::Condition( + builder::Condition::new( + "person_name".to_string(), + builder::Operator::Eq, + Some(builder::SafeValue::Text( + data.name.clone(), + builder::ValidationLevel::Relaxed, + )), + ) + .into_app_result()?, + ), + builder::WhereClause::Condition( + builder::Condition::new( + "person_email".to_string(), + builder::Operator::Eq, + Some(builder::SafeValue::Text( + "author@lsy22.com".into(), + builder::ValidationLevel::Relaxed, + )), + ) + .into_app_result()?, + ), + builder::WhereClause::Condition( + builder::Condition::new( + "person_level".to_string(), + builder::Operator::Eq, + Some(builder::SafeValue::Enum( + "administrators".into(), + "privilege_level".into(), + builder::ValidationLevel::Standard, + )), + ) + .into_app_result()?, + ), + ])); - let sql_builder = builder.add_condition(where_clause); let values = state .sql_get() .await .into_app_result()? .get_db() - .execute_query(&sql_builder) + .execute_query(&builder) .await .into_app_result()?; let password = values .first() - .ok_or(status::Custom( - Status::NotFound, - String::from("该用户并非系统用户"), - ))? - .get("person_password") - .ok_or(status::Custom( - Status::NotFound, - String::from("该用户密码丢失"), - ))?; + .and_then(|row| row.get("person_password")) + .and_then(|val| val.as_str()) + .ok_or_else(|| { + status::Custom(Status::NotFound, "Invalid system user or password".into()) + })?; - auth::bcrypt::verify_hash(&data.password, password).map_err(|_| { - status::Custom(Status::Forbidden, String::from("密码错误")) - })?; + auth::bcrypt::verify_hash(&data.password, password) + .map_err(|_| status::Custom(Status::Forbidden, "Invalid password".into()))?; - let claims = auth::jwt::CustomClaims { - name: "system".into(), - }; - let token = auth::jwt::generate_jwt(claims, Duration::minutes(1)).into_app_result()?; - - Ok(token) + Ok(auth::jwt::generate_jwt( + auth::jwt::CustomClaims { + name: "system".into(), + }, + Duration::minutes(1), + ) + .into_app_result()?) } diff --git a/backend/src/routes/configure.rs b/backend/src/routes/configure.rs index 388a5d5..39c7fad 100644 --- a/backend/src/routes/configure.rs +++ b/backend/src/routes/configure.rs @@ -39,15 +39,28 @@ pub async fn get_configure( comfig_type: String, name: String, ) -> CustomResult> { + let name_condition = builder::Condition::new( + "config_name".to_string(), + builder::Operator::Eq, + Some(builder::SafeValue::Text( + format!("{}_{}", comfig_type, name), + builder::ValidationLevel::Strict, + )), + )?; + + println!( + "Searching for config_name: {}", + format!("{}_{}", comfig_type, name) + ); + + let where_clause = builder::WhereClause::Condition(name_condition); + let mut sql_builder = builder::QueryBuilder::new(builder::SqlOperation::Select, "config".to_string())?; - sql_builder.set_value( - "config_name".to_string(), - builder::SafeValue::Text( - format!("{}_{}", comfig_type, name).to_string(), - builder::ValidationLevel::Strict, - ), - )?; + sql_builder + .add_condition(where_clause) + .add_field("config_data".to_string())?; + let result = sql.get_db().execute_query(&sql_builder).await?; Ok(Json(json!(result))) } @@ -76,9 +89,12 @@ pub async fn insert_configure( } #[get("/system")] -pub async fn system_config_get(state: &State>,token: SystemToken) -> AppResult> { +pub async fn system_config_get( + state: &State>, + _token: SystemToken, +) -> AppResult> { let sql = state.sql_get().await.into_app_result()?; - let configure = get_configure(&sql, "system".to_string(), "configure".to_string()) + let configure = get_configure(&sql, "system".to_string(), "config".to_string()) .await .into_app_result()?; Ok(configure) diff --git a/backend/src/routes/install.rs b/backend/src/routes/install.rs index 78629bf..ab9b5b1 100644 --- a/backend/src/routes/install.rs +++ b/backend/src/routes/install.rs @@ -1,14 +1,14 @@ +use super::{configure, person}; use crate::auth; use crate::database::relational; use crate::error::{AppResult, AppResultInto}; -use super::{person, configure}; use crate::AppState; use crate::{config, utils}; use chrono::Duration; use rocket::{http::Status, post, response::status, serde::json::Json, State}; use serde::{Deserialize, Serialize}; -use std::sync::Arc; use serde_json::json; +use std::sync::Arc; #[derive(Deserialize, Serialize)] pub struct InstallData { @@ -37,26 +37,26 @@ pub async fn install( )); } - config.info.install = true; - config.sql_config = data.sql_config.clone(); - let data = data.into_inner(); + let sql = { + config.info.install = true; + config.sql_config = data.sql_config.clone(); - relational::Database::initial_setup(data.sql_config.clone()) - .await - .into_app_result()?; + relational::Database::initial_setup(data.sql_config.clone()) + .await + .into_app_result()?; + auth::jwt::generate_key().into_app_result()?; - let _ = auth::jwt::generate_key(); + state.sql_link(&data.sql_config).await.into_app_result()?; + state.sql_get().await.into_app_result()? + }; - + let system_credentials = ( + utils::generate_random_string(20), + utils::generate_random_string(20), + ); - state.sql_link(&data.sql_config).await.into_app_result()?; - let sql = state.sql_get().await.into_app_result()?; - - let system_name = utils::generate_random_string(20); - let system_password = utils::generate_random_string(20); - - let _ = person::insert( + person::insert( &sql, person::RegisterData { name: data.name.clone(), @@ -68,41 +68,45 @@ pub async fn install( .await .into_app_result()?; - let _ = person::insert( + person::insert( &sql, person::RegisterData { - name: system_name.clone(), - email: String::from("author@lsy22.com"), - password: system_password.clone(), + name: system_credentials.0.clone(), + email: "author@lsy22.com".to_string(), + password: system_credentials.1.clone(), level: "administrators".to_string(), }, ) .await .into_app_result()?; - let mut system_configure = configure::SystemConfigure::default(); - system_configure.author_name = data.name.clone(); - - configure::insert_configure(&sql, "system".to_string(), "configure".to_string(), Json(json!(system_configure))).await.into_app_result()?; - + configure::insert_configure( + &sql, + "system".to_string(), + "config".to_string(), + Json(json!(configure::SystemConfigure { + author_name: data.name.clone(), + ..configure::SystemConfigure::default() + })), + ) + .await + .into_app_result()?; let token = auth::jwt::generate_jwt( - auth::jwt::CustomClaims { - name: data.name.clone(), - }, + auth::jwt::CustomClaims { name: data.name }, Duration::days(7), ) .into_app_result()?; - config::Config::write(config.clone()).into_app_result()?; - + config::Config::write(config).into_app_result()?; state.trigger_restart().await.into_app_result()?; + Ok(status::Custom( Status::Ok, Json(InstallReplyData { - token: token, - name: system_name, - password: system_password, + token, + name: system_credentials.0, + password: system_credentials.1, }), )) } diff --git a/backend/src/routes/person.rs b/backend/src/routes/person.rs index 4237c08..73849ae 100644 --- a/backend/src/routes/person.rs +++ b/backend/src/routes/person.rs @@ -23,26 +23,26 @@ pub struct RegisterData { pub async fn insert(sql: &relational::Database, data: RegisterData) -> CustomResult<()> { let mut builder = builder::QueryBuilder::new(builder::SqlOperation::Insert, "persons".to_string())?; - - let password_hash = auth::bcrypt::generate_hash(&data.password)?; - builder .set_value( "person_name".to_string(), - builder::SafeValue::Text(data.name.to_string(), builder::ValidationLevel::Relaxed), + builder::SafeValue::Text(data.name, builder::ValidationLevel::Relaxed), )? .set_value( "person_email".to_string(), - builder::SafeValue::Text(data.email.to_string(), builder::ValidationLevel::Relaxed), + builder::SafeValue::Text(data.email, builder::ValidationLevel::Relaxed), )? .set_value( "person_password".to_string(), - builder::SafeValue::Text(password_hash, builder::ValidationLevel::Relaxed), + builder::SafeValue::Text( + bcrypt::generate_hash(&data.password)?, + builder::ValidationLevel::Relaxed, + ), )? .set_value( "person_level".to_string(), builder::SafeValue::Enum( - data.level.to_string(), + data.level, "privilege_level".to_string(), builder::ValidationLevel::Standard, ),