From 3a88c33a6eb2bae396582afedb1369ac3e1dd344 Mon Sep 17 00:00:00 2001 From: lsy Date: Thu, 21 Nov 2024 11:47:41 +0800 Subject: [PATCH] =?UTF-8?q?=E5=90=8E=E7=AB=AF=EF=BC=9A=E5=AE=9A=E4=B9=89?= =?UTF-8?q?=E9=94=99=E8=AF=AF=EF=BC=8C=E6=89=80=E6=9C=89=E8=BF=94=E5=9B=9E?= =?UTF-8?q?=E9=94=99=E8=AF=AF=E9=83=BD=E5=85=B7=E6=9C=89=E5=8F=91=E9=80=81?= =?UTF-8?q?=E5=92=8C=E5=85=A8=E5=B1=80=E7=94=9F=E5=91=BD=E5=85=B7=E6=9C=89?= =?UTF-8?q?=E6=89=80=E6=9C=89=E6=9D=83=EF=BC=8C=E6=89=80=E6=9C=89=E8=BF=94?= =?UTF-8?q?=E5=9B=9E=E7=9A=84=E9=94=99=E8=AF=AF=E9=83=BD=E5=8F=AF=E4=BB=A5?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=3F=E6=9D=A5=E8=A7=A3=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/src/auth/jwt.rs | 10 ++-- backend/src/database/relational/builder.rs | 70 +++++++++++----------- backend/src/database/relational/mod.rs | 35 ++--------- backend/src/main.rs | 35 +++-------- backend/src/routes/intsall.rs | 19 ++---- backend/src/routes/person.rs | 14 ++--- backend/src/utils.rs | 27 ++++++++- 7 files changed, 88 insertions(+), 122 deletions(-) diff --git a/backend/src/auth/jwt.rs b/backend/src/auth/jwt.rs index 5138e31..92b58f7 100644 --- a/backend/src/auth/jwt.rs +++ b/backend/src/auth/jwt.rs @@ -5,7 +5,7 @@ use ed25519_dalek::{SigningKey, VerifyingKey}; use std::fs::File; use std::io::Write; use std::{env, fs}; -use std::error::Error; +use crate::utils::CustomError; use rand::{SeedableRng, RngCore}; #[derive(Debug, Serialize, Deserialize, Clone)] @@ -27,7 +27,7 @@ impl SecretKey { } } -pub fn generate_key() -> Result<(), Box> { +pub fn generate_key() -> Result<(),CustomError> { let mut csprng = rand::rngs::StdRng::from_entropy(); let mut private_key_bytes = [0u8; 32]; @@ -49,7 +49,7 @@ pub fn generate_key() -> Result<(), Box> { Ok(()) } -pub fn get_key(key_type: SecretKey) -> Result<[u8; 32], Box> { +pub fn get_key(key_type: SecretKey) -> Result<[u8; 32],CustomError> { let path = env::current_dir()? .join("assets") .join("key") @@ -60,7 +60,7 @@ pub fn get_key(key_type: SecretKey) -> Result<[u8; 32], Box> { Ok(key) } -pub fn generate_jwt(claims: CustomClaims, duration: Duration) -> Result> { +pub fn generate_jwt(claims: CustomClaims, duration: Duration) -> Result { let key_bytes = get_key(SecretKey::Signing)?; let signing_key = SigningKey::from_bytes(&key_bytes); @@ -79,7 +79,7 @@ pub fn generate_jwt(claims: CustomClaims, duration: Duration) -> Result Result> { +pub fn validate_jwt(token: &str) -> Result { let key_bytes = get_key(SecretKey::Verifying)?; let verifying = VerifyingKey::from_bytes(&key_bytes)?; let token = UntrustedToken::new(token)?; diff --git a/backend/src/database/relational/builder.rs b/backend/src/database/relational/builder.rs index 3e0d03d..8b46cf2 100644 --- a/backend/src/database/relational/builder.rs +++ b/backend/src/database/relational/builder.rs @@ -1,8 +1,6 @@ use regex::Regex; +use crate::utils::CustomError; use std::collections::HashMap; -use super::DatabaseError; - - use std::hash::Hash; #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -13,17 +11,15 @@ pub enum ValidatedValue { } impl ValidatedValue { - pub fn new_identifier(value: String) -> Result { + pub fn new_identifier(value: String) -> Result { let valid_pattern = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]{0,63}$").unwrap(); if !valid_pattern.is_match(&value) { - return Err(DatabaseError::ValidationError( - "Invalid identifier format".to_string(), - )); + return Err(CustomError::from_str("Invalid identifier format")); } Ok(ValidatedValue::Identifier(value)) } - pub fn new_rich_text(value: String) -> Result { + pub fn new_rich_text(value: String) -> Result { let dangerous_patterns = [ "UNION ALL SELECT", "UNION SELECT", @@ -42,24 +38,24 @@ impl ValidatedValue { let value_upper = value.to_uppercase(); for pattern in dangerous_patterns.iter() { if value_upper.contains(&pattern.to_uppercase()) { - return Err(DatabaseError::SqlInjectionAttempt( - format!("Dangerous SQL pattern detected: {}", pattern) - )); + return Err(CustomError::from_str("Invalid identifier format")); } } Ok(ValidatedValue::RichText(value)) } - pub fn new_plain_text(value: String) -> Result { + pub fn new_plain_text(value: String) -> Result { if value.contains(';') || value.contains("--") { - return Err(DatabaseError::ValidationError("Invalid characters in text".to_string())); + return Err(CustomError::from_str("Invalid characters in text")); } Ok(ValidatedValue::PlainText(value)) } pub fn get(&self) -> &str { match self { - ValidatedValue::Identifier(s) | ValidatedValue::RichText(s) | ValidatedValue::PlainText(s) => s, + ValidatedValue::Identifier(s) + | ValidatedValue::RichText(s) + | ValidatedValue::PlainText(s) => s, } } } @@ -105,7 +101,7 @@ impl Operator { #[derive(Debug, Clone)] pub struct WhereCondition { - field: ValidatedValue, + field: ValidatedValue, operator: Operator, value: Option, } @@ -115,9 +111,9 @@ impl WhereCondition { field: String, operator: Operator, value: Option, - ) -> Result { + ) -> Result { let field = ValidatedValue::new_identifier(field)?; - + let value = match value { Some(v) => Some(match operator { Operator::Like => ValidatedValue::new_plain_text(v)?, @@ -140,19 +136,19 @@ pub enum WhereClause { Or(Vec), Condition(WhereCondition), } - +#[derive(Debug, Clone)] pub struct QueryBuilder { operation: SqlOperation, table: ValidatedValue, - fields: Vec, - params: HashMap, + fields: Vec, + params: HashMap, where_clause: Option, - order_by: Option, + order_by: Option, limit: Option, } impl QueryBuilder { - pub fn new(operation: SqlOperation, table: String) -> Result { + pub fn new(operation: SqlOperation, table: String) -> Result { Ok(QueryBuilder { operation, table: ValidatedValue::new_identifier(table)?, @@ -164,7 +160,7 @@ impl QueryBuilder { }) } - pub fn build(&self) -> Result<(String, Vec), DatabaseError> { + pub fn build(&self) -> Result<(String, Vec), CustomError> { let mut query = String::new(); let mut values = Vec::new(); let mut param_counter = 1; @@ -174,7 +170,8 @@ impl QueryBuilder { let fields = if self.fields.is_empty() { "*".to_string() } else { - self.fields.iter() + self.fields + .iter() .map(|f| f.get().to_string()) .collect::>() .join(", ") @@ -182,12 +179,9 @@ impl QueryBuilder { query.push_str(&format!("SELECT {} FROM {}", fields, self.table.get())); } SqlOperation::Insert => { - let fields: Vec = self.params.keys() - .map(|k| k.get().to_string()) - .collect(); - let placeholders: Vec = (1..=self.params.len()) - .map(|i| format!("${}", i)) - .collect(); + let fields: Vec = self.params.keys().map(|k| k.get().to_string()).collect(); + let placeholders: Vec = + (1..=self.params.len()).map(|i| format!("${}", i)).collect(); query.push_str(&format!( "INSERT INTO {} ({}) VALUES ({})", @@ -201,7 +195,8 @@ impl QueryBuilder { } SqlOperation::Update => { query.push_str(&format!("UPDATE {} SET ", self.table.get())); - let set_clauses: Vec = self.params + let set_clauses: Vec = self + .params .iter() .map(|(key, _)| { let placeholder = format!("${}", param_counter); @@ -239,7 +234,7 @@ impl QueryBuilder { &self, clause: &WhereClause, mut param_counter: i32, - ) -> Result<(String, Vec), DatabaseError> { + ) -> Result<(String, Vec), CustomError> { let mut values = Vec::new(); let sql = match clause { @@ -267,7 +262,12 @@ impl QueryBuilder { if let Some(value) = &cond.value { let placeholder = format!("${}", param_counter); values.push(value.get().to_string()); - format!("{} {} {}", cond.field.get(), cond.operator.as_str(), placeholder) + format!( + "{} {} {}", + cond.field.get(), + cond.operator.as_str(), + placeholder + ) } else { format!("{} {}", cond.field.get(), cond.operator.as_str()) } @@ -281,7 +281,7 @@ impl QueryBuilder { self } - pub fn params(mut self, params: HashMap) -> Self { + pub fn params(mut self, params: HashMap) -> Self { self.params = params; self } @@ -300,4 +300,4 @@ impl QueryBuilder { self.limit = Some(limit); self } -} \ No newline at end of file +} diff --git a/backend/src/database/relational/mod.rs b/backend/src/database/relational/mod.rs index 62dacd6..5d2d213 100644 --- a/backend/src/database/relational/mod.rs +++ b/backend/src/database/relational/mod.rs @@ -2,43 +2,20 @@ mod postgresql; use crate::config; use async_trait::async_trait; use std::collections::HashMap; -use std::error::Error; +use crate::utils::CustomError; use std::sync::Arc; -use std::fmt; pub mod builder; -#[derive(Debug)] -pub enum DatabaseError { - ValidationError(String), - SqlInjectionAttempt(String), - InvalidParameter(String), - ExecutionError(String), -} - -impl fmt::Display for DatabaseError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - DatabaseError::ValidationError(msg) => write!(f, "Validation error: {}", msg), - DatabaseError::SqlInjectionAttempt(msg) => write!(f, "SQL injection attempt: {}", msg), - DatabaseError::InvalidParameter(msg) => write!(f, "Invalid parameter: {}", msg), - DatabaseError::ExecutionError(msg) => write!(f, "Execution error: {}", msg), - } - } -} - -impl Error for DatabaseError {} - - #[async_trait] pub trait DatabaseTrait: Send + Sync { - async fn connect(database: &config::SqlConfig) -> Result> + async fn connect(database: &config::SqlConfig) -> Result where Self: Sized; async fn execute_query<'a>( &'a self, builder: &builder::QueryBuilder, - ) -> Result>, Box>; - async fn initialization(database: config::SqlConfig) -> Result<(), Box> + ) -> Result>, CustomError>; + async fn initialization(database: config::SqlConfig) -> Result<(), CustomError> where Self: Sized; } @@ -53,7 +30,7 @@ impl Database { &self.db } - pub async fn link(database: &config::SqlConfig) -> Result> { + pub async fn link(database: &config::SqlConfig) -> Result { let db = match database.db_type.as_str() { "postgresql" => postgresql::Postgresql::connect(database).await?, _ => return Err("unknown database type".into()), @@ -64,7 +41,7 @@ impl Database { }) } - pub async fn initial_setup(database: config::SqlConfig) -> Result<(), Box> { + pub async fn initial_setup(database: config::SqlConfig) -> Result<(), CustomError> { match database.db_type.as_str() { "postgresql" => postgresql::Postgresql::initialization(database).await?, _ => return Err("unknown database type".into()), diff --git a/backend/src/main.rs b/backend/src/main.rs index c21fa63..62f8d56 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -6,33 +6,13 @@ mod routes; use chrono::Duration; use database::relational; use rocket::{ - get, post, - http::Status, - launch, - response::status, - State, + get, http::Status, launch, outcome::IntoOutcome, post, response::status, State }; use std::sync::Arc; use tokio::sync::Mutex; +use std::error::Error; -#[derive(Debug)] -pub enum AppError { - Database(String), - Config(String), - Auth(String), -} -impl From for status::Custom { - fn from(error: AppError) -> Self { - match error { - AppError::Database(msg) => status::Custom(Status::InternalServerError, format!("Database error: {}", msg)), - AppError::Config(msg) => status::Custom(Status::InternalServerError, format!("Config error: {}", msg)), - AppError::Auth(msg) => status::Custom(Status::InternalServerError, format!("Auth error: {}", msg)), - } - } -} - -type AppResult = Result; struct AppState { db: Arc>>, @@ -40,18 +20,17 @@ struct AppState { } impl AppState { - async fn get_sql(&self) -> AppResult { + async fn get_sql(&self) -> Result> { self.db .lock() .await .clone() - .ok_or_else(|| AppError::Database("Database not initialized".into())) + .ok_or_else(|| "Database not initialized".into()) } - async fn link_sql(&self, config: config::SqlConfig) -> AppResult<()> { + async fn link_sql(&self, config: config::SqlConfig) -> Result<,Box> { let database = relational::Database::link(&config) - .await - .map_err(|e| AppError::Database(e.to_string()))?; + .await?; *self.db.lock().await = Some(database); Ok(()) } @@ -68,7 +47,7 @@ async fn token_system(_state: &State) -> Result auth::jwt::generate_jwt(claims, Duration::seconds(1)) .map(|token| status::Custom(Status::Ok, token)) - .map_err(|e| AppError::Auth(e.to_string()).into()) + .map_err(|e| status::Custom(Status::InternalServerError, e.to_string())) } diff --git a/backend/src/routes/intsall.rs b/backend/src/routes/intsall.rs index 67143a1..3fa354d 100644 --- a/backend/src/routes/intsall.rs +++ b/backend/src/routes/intsall.rs @@ -1,7 +1,7 @@ use serde::{Deserialize,Serialize}; use crate::{config,utils}; use crate::database::relational; -use crate::{AppState,AppError,AppResult}; +use crate::AppState; use rocket::{ post, http::Status, @@ -28,20 +28,6 @@ pub struct InstallReplyData{ password:String, } -#[post("/test", format = "application/json", data = "")] -pub async fn test( - data: Json, - state: &State -) -> Result, status::Custom> { - let data=data.into_inner(); - - - let sql= state.get_sql().await.map_err(|e| e)?; - - let _ = person::insert(&sql,person::RegisterData{ name: data.name.clone(), email: data.email, password:data.password }); - Ok(status::Custom(Status::Ok, "Installation successful".to_string())) - -} #[post("/install", format = "application/json", data = "")] pub async fn install( @@ -58,6 +44,9 @@ pub async fn install( relational::Database::initial_setup(data.sql_config.clone()) .await .map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?; + + + auth::jwt::generate_key(); config.info.install = true; diff --git a/backend/src/routes/person.rs b/backend/src/routes/person.rs index 42afe5f..b1cc03d 100644 --- a/backend/src/routes/person.rs +++ b/backend/src/routes/person.rs @@ -1,7 +1,6 @@ use serde::{Deserialize,Serialize}; use crate::{config,utils}; use crate::database::{relational,relational::builder}; -use crate::{AppError,AppResult}; use rocket::{ get, post, http::Status, @@ -11,6 +10,8 @@ use rocket::{ }; use std::collections::HashMap; use bcrypt::{hash, verify, DEFAULT_COST}; +use crate::utils::CustomError; + #[derive(Deserialize, Serialize)] @@ -25,7 +26,7 @@ pub struct RegisterData{ pub password:String } -pub async fn insert(sql:&relational::Database,data:RegisterData) -> AppResult<()>{ +pub async fn insert(sql:&relational::Database,data:RegisterData) -> Result<(),CustomError>{ let hashed_password = hash(data.password, DEFAULT_COST).expect("Failed to hash password"); @@ -46,16 +47,11 @@ pub async fn insert(sql:&relational::Database,data:RegisterData) -> AppResult<() builder::ValidatedValue::PlainText(hashed_password) ); - let builder = builder::QueryBuilder::new(builder::SqlOperation::Insert,String::from("persons")) - .map_err(|e|{ - AppError::Database(format!("Error while building query: {}", e.to_string())) - })? + let builder = builder::QueryBuilder::new(builder::SqlOperation::Insert,String::from("persons"))? .params(user_params) ; - let _= sql.get_db().execute_query(&builder).await.map_err(|e|{ - AppError::Database(format!("Travel during execution: {}", e.to_string())) - })?; + sql.get_db().execute_query(&builder).await?; Ok(()) } diff --git a/backend/src/utils.rs b/backend/src/utils.rs index de34532..dc56d9f 100644 --- a/backend/src/utils.rs +++ b/backend/src/utils.rs @@ -1,9 +1,34 @@ use rand::seq::SliceRandom; + pub fn generate_random_string(length: usize) -> String { let charset = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; let mut rng = rand::thread_rng(); (0..length) .map(|_| *charset.choose(&mut rng).unwrap() as char) .collect() -} \ No newline at end of file +} + +pub struct CustomError(String); + +impl std::fmt::Display for CustomError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + + +impl From for CustomError +where + T: std::error::Error + Send + 'static, +{ + fn from(error: T) -> Self { + CustomError(error.to_string()) + } +} + +impl CustomError { + pub fn from_str(error: &str) -> Self { + CustomError(error.to_string()) + } +}