From d2eac057cac1b230e98b85153a3ed3933e755f52 Mon Sep 17 00:00:00 2001 From: lsy Date: Mon, 25 Nov 2024 03:36:24 +0800 Subject: [PATCH] =?UTF-8?q?=E5=90=8E=E7=AB=AF=EF=BC=9A=E5=AE=8C=E5=96=84?= =?UTF-8?q?=E9=94=99=E8=AF=AF=E5=92=8C=E6=95=B0=E6=8D=AE=E5=BA=93=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E6=9E=84=E5=BB=BA=EF=BC=8C=E5=AE=9E=E7=8E=B0=E5=BA=94?= =?UTF-8?q?=E7=94=A8=E9=87=8D=E5=90=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/Cargo.toml | 2 + backend/assets/config.toml | 11 - backend/src/auth/bcrypt.rs | 16 + backend/src/auth/jwt.rs | 2 +- backend/src/auth/mod.rs | 1 + backend/src/config.rs | 4 +- backend/src/database/relational/builder.rs | 613 +++++++++++++----- backend/src/database/relational/mod.rs | 6 +- .../src/database/relational/postgresql/mod.rs | 12 +- backend/src/error.rs | 41 ++ backend/src/main.rs | 99 ++- backend/src/manage.rs | 57 -- backend/src/routes/auth/token.rs | 101 ++- backend/src/routes/configure.rs | 10 + backend/src/routes/{intsall.rs => install.rs} | 34 +- backend/src/routes/mod.rs | 50 +- backend/src/routes/person.rs | 47 +- backend/src/routes/theme.rs | 12 - backend/src/utils.rs | 28 - frontend/.env | 1 - 20 files changed, 788 insertions(+), 359 deletions(-) delete mode 100644 backend/assets/config.toml create mode 100644 backend/src/auth/bcrypt.rs create mode 100644 backend/src/error.rs delete mode 100644 backend/src/manage.rs create mode 100644 backend/src/routes/configure.rs rename backend/src/routes/{intsall.rs => install.rs} (70%) delete mode 100644 backend/src/routes/theme.rs delete mode 100644 frontend/.env diff --git a/backend/Cargo.toml b/backend/Cargo.toml index af25587..3433ed7 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -17,3 +17,5 @@ rand = "0.8.5" chrono = "0.4" regex = "1.11.1" bcrypt = "0.16" +uuid = { version = "1.11.0", features = ["v4", "serde"] } +hex = "0.4.3" \ No newline at end of file diff --git a/backend/assets/config.toml b/backend/assets/config.toml deleted file mode 100644 index 5920086..0000000 --- a/backend/assets/config.toml +++ /dev/null @@ -1,11 +0,0 @@ -[info] -install = false -non_relational = false - -[sql_config] -db_type = "postgresql" -address = "localhost" -port = 5432 -user = "postgres" -password = "postgres" -db_name = "echoes" diff --git a/backend/src/auth/bcrypt.rs b/backend/src/auth/bcrypt.rs new file mode 100644 index 0000000..602cfc2 --- /dev/null +++ b/backend/src/auth/bcrypt.rs @@ -0,0 +1,16 @@ +use crate::error::CustomErrorInto; +use crate::error::CustomResult; +use bcrypt::{hash, verify, DEFAULT_COST}; + +pub fn generate_hash(s: &str) -> CustomResult { + let hashed = hash(s, DEFAULT_COST)?; + Ok(hashed) +} + +pub fn verify_hash(s: &str, hash: &str) -> CustomResult<()> { + let is_valid = verify(s, hash)?; + if !is_valid { + return Err("密码无效".into_custom_error()); + } + Ok(()) +} diff --git a/backend/src/auth/jwt.rs b/backend/src/auth/jwt.rs index cd74b6c..79c5921 100644 --- a/backend/src/auth/jwt.rs +++ b/backend/src/auth/jwt.rs @@ -1,4 +1,4 @@ -use crate::utils::CustomResult; +use crate::error::CustomResult; use chrono::{Duration, Utc}; use ed25519_dalek::{SigningKey, VerifyingKey}; use jwt_compact::{alg::Ed25519, AlgorithmExt, Header, TimeOptions, Token, UntrustedToken}; diff --git a/backend/src/auth/mod.rs b/backend/src/auth/mod.rs index 417233c..d3d9866 100644 --- a/backend/src/auth/mod.rs +++ b/backend/src/auth/mod.rs @@ -1 +1,2 @@ +pub mod bcrypt; pub mod jwt; diff --git a/backend/src/config.rs b/backend/src/config.rs index 082eba1..0523104 100644 --- a/backend/src/config.rs +++ b/backend/src/config.rs @@ -1,4 +1,4 @@ -use crate::utils::CustomResult; +use crate::error::CustomResult; use serde::{Deserialize, Serialize}; use std::path::PathBuf; use std::{env, fs}; @@ -47,6 +47,6 @@ impl Config { } pub fn get_path() -> CustomResult { - Ok(env::current_dir()?.join("assets").join("config.toml")) + Ok(env::current_dir()?.join("config.toml")) } } diff --git a/backend/src/database/relational/builder.rs b/backend/src/database/relational/builder.rs index cc6e12a..805cc52 100644 --- a/backend/src/database/relational/builder.rs +++ b/backend/src/database/relational/builder.rs @@ -1,63 +1,289 @@ -use crate::utils::{CustomError, CustomResult}; +use crate::error::{CustomErrorInto, CustomResult}; +use chrono::{DateTime, Utc}; use regex::Regex; +use serde::Serialize; +use serde_json::Value as JsonValue; use std::collections::HashMap; use std::hash::Hash; +use uuid::Uuid; -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum ValidatedValue { - Identifier(String), - RichText(String), - PlainText(String), +#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, Serialize)] +pub enum ValidationLevel { + Strict, + Standard, + Relaxed, } -impl ValidatedValue { - pub fn new_identifier(value: String) -> CustomResult { - let valid_pattern = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]{0,63}$").unwrap(); - if !valid_pattern.is_match(&value) { - return Err(CustomError::from_str("Invalid identifier format")); +#[derive(Debug, Clone)] +pub struct TextValidator { + sql_patterns: Vec<&'static str>, + special_chars: Vec, + level_max_lengths: HashMap, + level_allowed_chars: HashMap>, +} + +impl Default for TextValidator { + fn default() -> Self { + let level_max_lengths = HashMap::from([ + (ValidationLevel::Strict, 100), + (ValidationLevel::Standard, 1000), + (ValidationLevel::Relaxed, 100000), + ]); + + let level_allowed_chars = HashMap::from([ + (ValidationLevel::Strict, vec!['_']), + ( + ValidationLevel::Standard, + vec!['_', '-', '.', ',', '!', '?', ':', ' '], + ), + ( + ValidationLevel::Relaxed, + vec![ + '_', '-', '.', ',', '!', '?', ':', ' ', '"', '\'', '(', ')', '[', ']', '{', + '}', '@', '#', '$', '%', '^', '&', '*', '+', '=', '<', '>', '/', '\\', + ], + ), + ]); + + TextValidator { + sql_patterns: vec![ + "DROP", + "TRUNCATE", + "ALTER", + "DELETE", + "UPDATE", + "INSERT", + "MERGE", + "GRANT", + "REVOKE", + "UNION", + "--", + "/*", + "EXEC", + "EXECUTE", + "WAITFOR", + "DELAY", + "BENCHMARK", + ], + special_chars: vec!['\0', '\n', '\r', '\t'], + level_max_lengths, + level_allowed_chars, } - Ok(ValidatedValue::Identifier(value)) + } +} + +impl TextValidator { + pub fn validate(&self, text: &str, level: ValidationLevel) -> CustomResult<()> { + let max_length = self + .level_max_lengths + .get(&level) + .ok_or_else(|| "Invalid validation level".into_custom_error())?; + + if text.len() > *max_length { + return Err("Text exceeds maximum length".into_custom_error()); + } + + // 简化验证逻辑 + if level == ValidationLevel::Relaxed { + return self.validate_sql_patterns(text); + } + + self.validate_chars(text, level)?; + self.validate_special_chars(text) } - pub fn new_rich_text(value: String) -> CustomResult { - let dangerous_patterns = [ - "UNION ALL SELECT", - "UNION SELECT", - "OR 1=1", - "OR '1'='1", - "DROP TABLE", - "DELETE FROM", - "UPDATE ", - "INSERT INTO", - "--", - "/*", - "*/", - "@@", - ]; + fn validate_sql_patterns(&self, text: &str) -> CustomResult<()> { + let upper_text = text.to_uppercase(); + if self + .sql_patterns + .iter() + .any(|&pattern| upper_text.contains(&pattern.to_uppercase())) + { + return Err("Potentially dangerous SQL pattern detected".into_custom_error()); + } + Ok(()) + } - let value_upper = value.to_uppercase(); - for pattern in dangerous_patterns.iter() { - if value_upper.contains(&pattern.to_uppercase()) { - return Err(CustomError::from_str("Invalid identifier format")); + fn validate_chars(&self, text: &str, level: ValidationLevel) -> CustomResult<()> { + let allowed_chars = self + .level_allowed_chars + .get(&level) + .ok_or_else(|| "Invalid validation level".into_custom_error())?; + + if let Some(invalid_char) = text + .chars() + .find(|&c| !c.is_alphanumeric() && !allowed_chars.contains(&c)) + { + return Err( + format!("Invalid character '{}' for {:?} level", invalid_char, level) + .into_custom_error(), + ); + } + Ok(()) + } + + fn validate_special_chars(&self, text: &str) -> CustomResult<()> { + if self.special_chars.iter().any(|&c| text.contains(c)) { + return Err("Invalid special character detected".into_custom_error()); + } + Ok(()) + } + + // 提供便捷方法 + pub fn validate_relaxed(&self, text: &str) -> CustomResult<()> { + self.validate(text, ValidationLevel::Relaxed) + } + + pub fn validate_standard(&self, text: &str) -> CustomResult<()> { + self.validate(text, ValidationLevel::Standard) + } + + pub fn validate_strict(&self, text: &str) -> CustomResult<()> { + self.validate(text, ValidationLevel::Strict) + } + + pub fn sanitize(&self, text: &str) -> CustomResult { + self.validate_relaxed(text)?; + Ok(text.replace('\'', "''").replace('\\', "\\\\")) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum SafeValue { + Null, + Bool(bool), + Integer(i64), + Float(f64), + Text(String, ValidationLevel), + DateTime(DateTime), + Uuid(Uuid), + Binary(Vec), + Array(Vec), + Json(JsonValue), + Enum(String, String, ValidationLevel), +} + +impl SafeValue { + pub fn from_json(value: JsonValue, level: ValidationLevel) -> CustomResult { + match value { + JsonValue::Null => Ok(SafeValue::Null), + JsonValue::Bool(b) => Ok(SafeValue::Bool(b)), + JsonValue::Number(n) => { + if let Some(i) = n.as_i64() { + Ok(SafeValue::Integer(i)) + } else if let Some(f) = n.as_f64() { + Ok(SafeValue::Float(f)) + } else { + Err("Invalid number format".into_custom_error()) + } + } + JsonValue::String(s) => { + TextValidator::default().validate(&s, level)?; + Ok(SafeValue::Text(s, level)) + } + JsonValue::Array(arr) => Ok(SafeValue::Array( + arr.into_iter() + .map(|item| SafeValue::from_json(item, level)) + .collect::>>()?, + )), + JsonValue::Object(_) => { + Self::validate_json_structure(&value, level)?; + Ok(SafeValue::Json(value)) } } - Ok(ValidatedValue::RichText(value)) } - pub fn new_plain_text(value: String) -> CustomResult { - if value.contains(';') || value.contains("--") { - return Err(CustomError::from_str("Invalid characters in text")); + fn validate_json_structure(value: &JsonValue, level: ValidationLevel) -> CustomResult<()> { + let validator = TextValidator::default(); + match value { + JsonValue::Object(map) => { + for (key, val) in map { + validator.validate(key, level)?; + Self::validate_json_structure(val, level)?; + } + } + JsonValue::Array(arr) => { + arr.iter() + .try_for_each(|item| Self::validate_json_structure(item, level))?; + } + JsonValue::String(s) => validator.validate(s, level)?, + _ => {} } - Ok(ValidatedValue::PlainText(value)) + Ok(()) } - pub fn get(&self) -> &str { + fn get_sql_type(&self) -> CustomResult { + let sql_type = match self { + SafeValue::Null => "NULL", + SafeValue::Bool(_) => "boolean", + SafeValue::Integer(_) => "bigint", + SafeValue::Float(_) => "double precision", + SafeValue::Text(_, _) => "text", + SafeValue::DateTime(_) => "timestamp with time zone", + SafeValue::Uuid(_) => "uuid", + SafeValue::Binary(_) => "bytea", + SafeValue::Array(_) | SafeValue::Json(_) => "jsonb", + SafeValue::Enum(_, enum_type, level) => { + TextValidator::default().validate(enum_type, *level)?; + return Ok(enum_type.replace('\'', "''")); + } + }; + Ok(sql_type.to_string()) + } + + pub fn to_sql_string(&self) -> CustomResult { match self { - ValidatedValue::Identifier(s) - | ValidatedValue::RichText(s) - | ValidatedValue::PlainText(s) => s, + SafeValue::Null => Ok("NULL".to_string()), + SafeValue::Bool(b) => Ok(b.to_string()), + SafeValue::Integer(i) => Ok(i.to_string()), + SafeValue::Float(f) => Ok(f.to_string()), + SafeValue::Text(s, level) => { + TextValidator::default().validate(s, *level)?; + Ok(s.replace('\'', "''")) + } + SafeValue::DateTime(dt) => Ok(format!("'{}'", dt.to_rfc3339())), + SafeValue::Uuid(u) => Ok(format!("'{}'", u)), + SafeValue::Binary(b) => Ok(format!("'\\x{}'", hex::encode(b))), + SafeValue::Array(arr) => { + let values: CustomResult> = arr.iter().map(|v| v.to_sql_string()).collect(); + Ok(format!("ARRAY[{}]", values?.join(","))) + } + SafeValue::Json(j) => { + let json_str = serde_json::to_string(j)?; + TextValidator::default().validate(&json_str, ValidationLevel::Relaxed)?; + Ok(json_str.replace('\'', "''")) + } + SafeValue::Enum(s, _, level) => { + TextValidator::default().validate(s, *level)?; + Ok(s.to_string()) + } } } + + fn to_param_sql(&self, param_index: usize) -> CustomResult { + if matches!(self, SafeValue::Null) { + Ok("NULL".to_string()) + } else { + Ok(format!("${}::{}", param_index, self.get_sql_type()?)) + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Identifier(String); + +impl Identifier { + pub fn new(value: String) -> CustomResult { + let valid_pattern = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_\.]{0,63}$")?; + if !valid_pattern.is_match(&value) { + return Err("Invalid identifier format".into_custom_error()); + } + Ok(Identifier(value)) + } + + pub fn as_str(&self) -> &str { + &self.0 + } } #[derive(Debug, Clone, PartialEq)] @@ -80,6 +306,8 @@ pub enum Operator { In, IsNull, IsNotNull, + JsonContains, + JsonExists, } impl Operator { @@ -95,31 +323,23 @@ impl Operator { Operator::In => "IN", Operator::IsNull => "IS NULL", Operator::IsNotNull => "IS NOT NULL", + Operator::JsonContains => "@>", + Operator::JsonExists => "?", } } } #[derive(Debug, Clone)] -pub struct WhereCondition { - field: ValidatedValue, +pub struct Condition { + field: Identifier, operator: Operator, - value: Option, + value: Option, } -impl WhereCondition { - pub fn new(field: String, operator: Operator, value: Option) -> CustomResult { - let field = ValidatedValue::new_identifier(field)?; - - let value = match value { - Some(v) => Some(match operator { - Operator::Like => ValidatedValue::new_plain_text(v)?, - _ => ValidatedValue::new_plain_text(v)?, - }), - None => None, - }; - - Ok(WhereCondition { - field, +impl Condition { + pub fn new(field: String, operator: Operator, value: Option) -> CustomResult { + Ok(Condition { + field: Identifier::new(field)?, operator, value, }) @@ -130,170 +350,233 @@ impl WhereCondition { pub enum WhereClause { And(Vec), Or(Vec), - Condition(WhereCondition), + Condition(Condition), } + #[derive(Debug, Clone)] pub struct QueryBuilder { operation: SqlOperation, - table: ValidatedValue, - fields: Vec, - params: HashMap, + table: Identifier, + fields: Vec, + values: HashMap, where_clause: Option, - order_by: Option, + order_by: Option, limit: Option, + offset: Option, } impl QueryBuilder { pub fn new(operation: SqlOperation, table: String) -> CustomResult { Ok(QueryBuilder { operation, - table: ValidatedValue::new_identifier(table)?, + table: Identifier::new(table)?, fields: Vec::new(), - params: HashMap::new(), + values: HashMap::new(), where_clause: None, order_by: None, limit: None, + offset: None, }) } - pub fn build(&self) -> CustomResult<(String, Vec)> { + pub fn add_field(&mut self, field: String) -> CustomResult<&mut Self> { + self.fields.push(Identifier::new(field)?); + Ok(self) + } + + pub fn set_value(&mut self, field: String, value: SafeValue) -> CustomResult<&mut Self> { + self.values.insert(Identifier::new(field)?, value); + Ok(self) + } + + pub fn add_condition(&mut self, condition: WhereClause) -> &mut Self { + self.where_clause = Some(condition); + self + } + + pub fn build(&self) -> CustomResult<(String, Vec)> { let mut query = String::new(); - let mut values = Vec::new(); - let mut param_counter = 1; + let mut params = Vec::new(); match self.operation { - SqlOperation::Select => { - let fields = if self.fields.is_empty() { - "*".to_string() - } else { - self.fields - .iter() - .map(|f| f.get().to_string()) - .collect::>() - .join(", ") - }; - query.push_str(&format!("SELECT {} FROM {}", fields, self.table.get())); - } - SqlOperation::Insert => { - let fields: Vec = self.params.keys().map(|k| k.get().to_string()).collect(); - let placeholders: Vec = - (1..=self.params.len()).map(|i| format!("${}", i)).collect(); - - query.push_str(&format!( - "INSERT INTO {} ({}) VALUES ({})", - self.table.get(), - fields.join(", "), - placeholders.join(", ") - )); - - values.extend(self.params.values().map(|v| v.get().to_string())); - return Ok((query, values)); - } - SqlOperation::Update => { - query.push_str(&format!("UPDATE {} SET ", self.table.get())); - let set_clauses: Vec = self - .params - .iter() - .map(|(key, _)| { - let placeholder = format!("${}", param_counter); - values.push(self.params[key].get().to_string()); - param_counter += 1; - format!("{} = {}", key.get(), placeholder) - }) - .collect(); - query.push_str(&set_clauses.join(", ")); - } - SqlOperation::Delete => { - query.push_str(&format!("DELETE FROM {}", self.table.get())); - } + SqlOperation::Select => self.build_select(&mut query)?, + SqlOperation::Insert => self.build_insert(&mut query, &mut params)?, + SqlOperation::Update => self.build_update(&mut query, &mut params)?, + SqlOperation::Delete => query.push_str(&format!("DELETE FROM {}", self.table.as_str())), } if let Some(where_clause) = &self.where_clause { query.push_str(" WHERE "); - let (where_sql, where_values) = self.build_where_clause(where_clause, param_counter)?; + let (where_sql, where_params) = self.build_where_clause(where_clause)?; query.push_str(&where_sql); - values.extend(where_values); + params.extend(where_params); } - if let Some(order) = &self.order_by { - query.push_str(&format!(" ORDER BY {}", order.get())); - } - - if let Some(limit) = self.limit { - query.push_str(&format!(" LIMIT {}", limit)); - } - - Ok((query, values)) + self.build_pagination(&mut query)?; + Ok((query, params)) } - fn build_where_clause( - &self, - clause: &WhereClause, - mut param_counter: i32, - ) -> CustomResult<(String, Vec)> { - let mut values = Vec::new(); + fn build_select(&self, query: &mut String) -> CustomResult<()> { + let fields = if self.fields.is_empty() { + "*".to_string() + } else { + self.fields + .iter() + .map(|f| f.as_str()) + .collect::>() + .join(", ") + }; + query.push_str(&format!("SELECT {} FROM {}", fields, self.table.as_str())); + Ok(()) + } + + fn build_insert(&self, query: &mut String, params: &mut Vec) -> CustomResult<()> { + let mut fields = Vec::new(); + let mut placeholders = Vec::new(); + + for (field, value) in &self.values { + fields.push(field.as_str()); + if matches!(value, SafeValue::Null) { + placeholders.push("NULL".to_string()); + } else { + placeholders.push(format!("${}::{}", params.len() + 1, value.get_sql_type()?)); + params.push(value.clone()); + } + } + + query.push_str(&format!( + "INSERT INTO {} ({}) VALUES ({})", + self.table.as_str(), + fields.join(", "), + placeholders.join(", ") + )); + + Ok(()) + } + + fn build_update(&self, query: &mut String, params: &mut Vec) -> CustomResult<()> { + query.push_str(&format!("UPDATE {} SET ", self.table.as_str())); + + let mut updates = Vec::new(); + for (field, value) in &self.values { + let set_sql = format!( + "{} = {}", + field.as_str(), + value.to_param_sql(params.len() + 1)? + ); + if !matches!(value, SafeValue::Null) { + params.push(value.clone()); + } + updates.push(set_sql); + } + + query.push_str(&updates.join(", ")); + Ok(()) + } + + fn build_delete(&self, query: &mut String) -> CustomResult<()> { + query.push_str(&format!("DELETE FROM {}", self.table.as_str())); + Ok(()) + } + + fn build_where_clause(&self, clause: &WhereClause) -> CustomResult<(String, Vec)> { + let mut params = Vec::new(); + let mut param_index = 1; // 添加参数索引计数器 let sql = match clause { WhereClause::And(conditions) => { let mut parts = Vec::new(); for condition in conditions { - let (sql, mut vals) = self.build_where_clause(condition, param_counter)?; - param_counter += vals.len() as i32; + let (sql, mut condition_params) = + self.build_where_clause_with_index(condition, param_index)?; + param_index += condition_params.len(); // 更新参数索引 parts.push(sql); - values.append(&mut vals); + params.append(&mut condition_params); } format!("({})", parts.join(" AND ")) } WhereClause::Or(conditions) => { let mut parts = Vec::new(); for condition in conditions { - let (sql, mut vals) = self.build_where_clause(condition, param_counter)?; - param_counter += vals.len() as i32; + let (sql, mut condition_params) = + self.build_where_clause_with_index(condition, param_index)?; + param_index += condition_params.len(); // 更新参数索引 parts.push(sql); - values.append(&mut vals); + params.append(&mut condition_params); } format!("({})", parts.join(" OR ")) } - WhereClause::Condition(cond) => { - if let Some(value) = &cond.value { - let placeholder = format!("${}", param_counter); - values.push(value.get().to_string()); - format!( - "{} {} {}", - cond.field.get(), - cond.operator.as_str(), - placeholder - ) - } else { - format!("{} {}", cond.field.get(), cond.operator.as_str()) - } + WhereClause::Condition(condition) => { + self.build_condition(condition, &mut params, param_index)? } }; - Ok((sql, values)) - } - pub fn fields(mut self, fields: Vec) -> Self { - self.fields = fields; - self + Ok((sql, params)) } - pub fn params(mut self, params: HashMap) -> Self { - self.params = params; - self + // 添加新的辅助方法 + fn build_where_clause_with_index( + &self, + clause: &WhereClause, + start_index: usize, + ) -> CustomResult<(String, Vec)> { + let mut params = Vec::new(); + + let sql = match clause { + WhereClause::Condition(condition) => { + self.build_condition(condition, &mut params, start_index)? + } + _ => { + let (sql, params_inner) = self.build_where_clause(clause)?; + params = params_inner; + sql + } + }; + + Ok((sql, params)) } - pub fn where_clause(mut self, clause: WhereClause) -> Self { - self.where_clause = Some(clause); - self + fn build_condition( + &self, + condition: &Condition, + params: &mut Vec, + param_index: usize, + ) -> CustomResult { + match &condition.value { + Some(value) => { + let sql = format!( + "{} {} {}", + condition.field.as_str(), + condition.operator.as_str(), + value.to_param_sql(param_index)? + ); + if !matches!(value, SafeValue::Null) { + params.push(value.clone()); + } + Ok(sql) + } + None => Ok(format!( + "{} {}", + condition.field.as_str(), + condition.operator.as_str() + )), + } } - pub fn order_by(mut self, order: ValidatedValue) -> Self { - self.order_by = Some(order); - self - } + // 构建分页 + fn build_pagination(&self, query: &mut String) -> CustomResult<()> { + if let Some(order) = &self.order_by { + query.push_str(&format!(" ORDER BY {}", order.as_str())); + } - pub fn limit(mut self, limit: i32) -> Self { - self.limit = Some(limit); - self + if let Some(limit) = self.limit { + query.push_str(&format!(" LIMIT {}", limit)); + } + + if let Some(offset) = self.offset { + query.push_str(&format!(" OFFSET {}", offset)); + } + + Ok(()) } } diff --git a/backend/src/database/relational/mod.rs b/backend/src/database/relational/mod.rs index bb365cd..3ed7237 100644 --- a/backend/src/database/relational/mod.rs +++ b/backend/src/database/relational/mod.rs @@ -1,6 +1,6 @@ mod postgresql; use crate::config; -use crate::utils::{CustomError, CustomResult}; +use crate::error::{CustomErrorInto, CustomResult}; use async_trait::async_trait; use std::collections::HashMap; use std::sync::Arc; @@ -33,7 +33,7 @@ impl Database { pub async fn link(database: &config::SqlConfig) -> CustomResult { let db = match database.db_type.as_str() { "postgresql" => postgresql::Postgresql::connect(database).await?, - _ => return Err(CustomError::from_str("unknown database type")), + _ => return Err("unknown database type".into_custom_error()), }; Ok(Self { @@ -44,7 +44,7 @@ impl Database { pub async fn initial_setup(database: config::SqlConfig) -> CustomResult<()> { match database.db_type.as_str() { "postgresql" => postgresql::Postgresql::initialization(database).await?, - _ => return Err(CustomError::from_str("unknown database type")), + _ => return Err("unknown database type".into_custom_error()), }; Ok(()) } diff --git a/backend/src/database/relational/postgresql/mod.rs b/backend/src/database/relational/postgresql/mod.rs index b761e71..6bea843 100644 --- a/backend/src/database/relational/postgresql/mod.rs +++ b/backend/src/database/relational/postgresql/mod.rs @@ -1,11 +1,11 @@ use super::{builder, DatabaseTrait}; use crate::config; -use crate::utils::CustomResult; +use crate::error::CustomErrorInto; +use crate::error::CustomResult; use async_trait::async_trait; use sqlx::{Column, Executor, PgPool, Row}; use std::collections::HashMap; use std::{env, fs}; - #[derive(Clone)] pub struct Postgresql { pool: PgPool, @@ -70,10 +70,14 @@ impl DatabaseTrait for Postgresql { let mut sqlx_query = sqlx::query(&query); for value in values { - sqlx_query = sqlx_query.bind(value); + sqlx_query = sqlx_query.bind(value.to_sql_string()?); } - let rows = sqlx_query.fetch_all(&self.pool).await?; + let rows = sqlx_query.fetch_all(&self.pool).await.map_err(|e| { + let (sql, params) = builder.build().unwrap(); + format!("Err:{}\n,SQL: {}\nParams: {:?}", e.to_string(), sql, params) + .into_custom_error() + })?; let mut results = Vec::new(); for row in rows { diff --git a/backend/src/error.rs b/backend/src/error.rs new file mode 100644 index 0000000..21c2c9d --- /dev/null +++ b/backend/src/error.rs @@ -0,0 +1,41 @@ +use rocket::http::Status; +use rocket::response::status; + +pub type AppResult = Result>; + +pub trait AppResultInto { + fn into_app_result(self) -> AppResult; +} + +#[derive(Debug)] +pub struct CustomError(String); + +impl std::fmt::Display for CustomError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +pub trait CustomErrorInto { + fn into_custom_error(self) -> CustomError; +} + +impl CustomErrorInto for &str { + fn into_custom_error(self) -> CustomError { + CustomError(self.to_string()) + } +} + +impl From for CustomError { + fn from(error: E) -> Self { + CustomError(error.to_string()) + } +} + +pub type CustomResult = Result; + +impl AppResultInto for CustomResult { + fn into_app_result(self) -> AppResult { + self.map_err(|e| status::Custom(Status::InternalServerError, e.to_string())) + } +} diff --git a/backend/src/main.rs b/backend/src/main.rs index 56571d2..a21adbb 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -1,59 +1,98 @@ mod auth; mod config; mod database; -mod manage; mod routes; mod utils; use database::relational; -use rocket::launch; +use rocket::Shutdown; use std::sync::Arc; use tokio::sync::Mutex; -use utils::{AppResult, CustomError, CustomResult}; +mod error; +use error::{CustomErrorInto, CustomResult}; -struct AppState { +pub struct AppState { db: Arc>>, configure: Arc>, + shutdown: Arc>>, + restart_progress: Arc>, } impl AppState { - async fn get_sql(&self) -> CustomResult { + pub fn new(config: config::Config) -> Self { + Self { + db: Arc::new(Mutex::new(None)), + configure: Arc::new(Mutex::new(config)), + shutdown: Arc::new(Mutex::new(None)), + restart_progress: Arc::new(Mutex::new(false)), + } + } + + pub async fn sql_get(&self) -> CustomResult { self.db .lock() .await .clone() - .ok_or_else(|| CustomError::from_str("Database not initialized")) + .ok_or("数据库未连接".into_custom_error()) } - async fn link_sql(&self, config: &config::SqlConfig) -> CustomResult<()> { + pub async fn sql_link(&self, config: &config::SqlConfig) -> CustomResult<()> { let database = relational::Database::link(config).await?; *self.db.lock().await = Some(database); Ok(()) } -} -#[launch] -async fn rocket() -> _ { - let config = config::Config::read().expect("Failed to read config"); - - let state = AppState { - db: Arc::new(Mutex::new(None)), - configure: Arc::new(Mutex::new(config.clone())), - }; - - let mut rocket_builder = rocket::build().manage(state); - - if config.info.install { - if let Some(state) = rocket_builder.state::() { - state - .link_sql(&config.sql_config) - .await - .expect("Failed to connect to database"); - } - } else { - rocket_builder = rocket_builder.mount("/", rocket::routes![routes::intsall::install]); + pub async fn set_shutdown(&self, shutdown: Shutdown) { + *self.shutdown.lock().await = Some(shutdown); } - rocket_builder = rocket_builder.mount("/auth/token", routes::jwt_routes()); + pub async fn trigger_restart(&self) -> CustomResult<()> { + *self.restart_progress.lock().await = true; - rocket_builder + self.shutdown + .lock() + .await + .take() + .ok_or("未能获取rocket的shutdown".into_custom_error())? + .notify(); + + Ok(()) + } +} + +#[rocket::main] +async fn main() -> CustomResult<()> { + let config = config::Config::read()?; + + let state = AppState::new(config.clone()); + + if config.info.install { + state.sql_link(&config.sql_config).await?; + } + + let state = Arc::new(state); + + let rocket_builder = rocket::build().manage(state.clone()); + + let rocket_builder = if !config.info.install { + rocket_builder.mount("/", rocket::routes![routes::install::install]) + } else { + rocket_builder.mount("/auth/token", routes::jwt_routes()) + }; + + let rocket = rocket_builder.ignite().await?; + + rocket + .state::>() + .ok_or("未能获取AppState".into_custom_error())? + .set_shutdown(rocket.shutdown()) + .await; + + rocket.launch().await?; + + let restart_progress = *state.restart_progress.lock().await; + if restart_progress { + let current_exe = std::env::current_exe()?; + let _ = std::process::Command::new(current_exe).spawn(); + } + std::process::exit(0); } diff --git a/backend/src/manage.rs b/backend/src/manage.rs deleted file mode 100644 index ba69f4f..0000000 --- a/backend/src/manage.rs +++ /dev/null @@ -1,57 +0,0 @@ -use rocket::shutdown::Shutdown; -use std::env; -use std::path::Path; -use std::process::{exit, Command}; -use tokio::signal; - -// 应用管理器 -pub struct AppManager { - shutdown: Shutdown, - executable_path: String, -} - -impl AppManager { - pub fn new(shutdown: Shutdown) -> Self { - let executable_path = env::current_exe() - .expect("Failed to get executable path") - .to_string_lossy() - .into_owned(); - - Self { - shutdown, - executable_path, - } - } - - // 优雅关闭 - pub async fn graceful_shutdown(&self) { - println!("Initiating graceful shutdown..."); - - // 触发 Rocket 的优雅关闭 - self.shutdown.notify(); - - // 等待一段时间以确保连接正确关闭 - tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; - } - - // 重启应用 - pub async fn restart(&self) -> Result<(), Box> { - println!("Preparing to restart application..."); - - // 执行优雅关闭 - self.graceful_shutdown().await; - - // 在新进程中启动应用 - if cfg!(target_os = "windows") { - Command::new("cmd") - .args(&["/C", &self.executable_path]) - .spawn()?; - } else { - Command::new(&self.executable_path).spawn()?; - } - - // 退出当前进程 - println!("Application restarting..."); - exit(0); - } -} diff --git a/backend/src/routes/auth/token.rs b/backend/src/routes/auth/token.rs index 9ee568d..336e3de 100644 --- a/backend/src/routes/auth/token.rs +++ b/backend/src/routes/auth/token.rs @@ -1,15 +1,102 @@ use crate::auth; -use crate::{AppResult, AppState}; +use crate::database::relational::builder; +use crate::error::{AppResult, AppResultInto}; +use crate::AppState; use chrono::Duration; -use rocket::{get, http::Status, response::status, State}; +use jwt_compact::Token; +use rocket::{ + http::Status, + post, + response::status, + serde::json::{Json, Value}, + State, +}; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use std::sync::Arc; +#[derive(Deserialize, Serialize)] +pub struct TokenSystemData { + name: String, + password: String, +} +#[post("/system", format = "application/json", data = "")] +pub async fn token_system( + state: &State>, + data: Json, +) -> AppResult { + let name_condition = builder::Condition::new( + "person_name".to_string(), + builder::Operator::Eq, + Some(builder::SafeValue::Text( + data.name.to_string(), + builder::ValidationLevel::Relaxed, + )), + ) + .into_app_result()?; + + let email_condition = builder::Condition::new( + "person_email".to_string(), + builder::Operator::Eq, + Some(builder::SafeValue::Text( + "author@lsy22.com".to_string(), + builder::ValidationLevel::Relaxed, + )), + ) + .into_app_result()?; + + let level_condition = builder::Condition::new( + "person_level".to_string(), + builder::Operator::Eq, + Some(builder::SafeValue::Enum( + "administrators".to_string(), + "privilege_level".to_string(), + builder::ValidationLevel::Standard, + )), + ) + .into_app_result()?; + + let where_clause = builder::WhereClause::And(vec![ + builder::WhereClause::Condition(name_condition), + builder::WhereClause::Condition(email_condition), + builder::WhereClause::Condition(level_condition), + ]); + + let mut builder = + builder::QueryBuilder::new(builder::SqlOperation::Select, String::from("persons")) + .into_app_result()?; + + let builder = builder + .add_field("person_password".to_string()) + .into_app_result()?; + + let sql_builder = builder.add_condition(where_clause); + let values = state + .sql_get() + .await + .into_app_result()? + .get_db() + .execute_query(&sql_builder) + .await + .into_app_result()?; + + let password = values + .first() + .ok_or(status::Custom( + Status::NotFound, + String::from("该用户并非系统用户"), + ))? + .get("person_password") + .ok_or(status::Custom( + Status::NotFound, + String::from("该用户密码丢失"), + ))?; + + auth::bcrypt::verify_hash(&data.password, password).into_app_result()?; -#[get("/system")] -pub async fn token_system(_state: &State) -> AppResult> { let claims = auth::jwt::CustomClaims { name: "system".into(), }; + let token = auth::jwt::generate_jwt(claims, Duration::seconds(1)).into_app_result()?; - auth::jwt::generate_jwt(claims, Duration::seconds(1)) - .map(|token| status::Custom(Status::Ok, token)) - .map_err(|e| status::Custom(Status::InternalServerError, e.to_string())) + Ok(token) } diff --git a/backend/src/routes/configure.rs b/backend/src/routes/configure.rs new file mode 100644 index 0000000..c358d25 --- /dev/null +++ b/backend/src/routes/configure.rs @@ -0,0 +1,10 @@ +use super::SystemToken; +use crate::error::AppResult; +use rocket::{ + get, + http::Status, + post, + response::status, + serde::json::{Json, Value}, + Request, +}; diff --git a/backend/src/routes/intsall.rs b/backend/src/routes/install.rs similarity index 70% rename from backend/src/routes/intsall.rs rename to backend/src/routes/install.rs index 80a2f3f..38f0608 100644 --- a/backend/src/routes/intsall.rs +++ b/backend/src/routes/install.rs @@ -1,12 +1,13 @@ use crate::auth; use crate::database::relational; +use crate::error::{AppResult, AppResultInto}; use crate::routes::person; -use crate::utils::AppResult; use crate::AppState; use crate::{config, utils}; use chrono::Duration; use rocket::{http::Status, post, response::status, serde::json::Json, State}; use serde::{Deserialize, Serialize}; +use std::sync::Arc; #[derive(Deserialize, Serialize)] pub struct InstallData { @@ -25,7 +26,7 @@ pub struct InstallReplyData { #[post("/install", format = "application/json", data = "")] pub async fn install( data: Json, - state: &State, + state: &State>, ) -> AppResult>> { let mut config = state.configure.lock().await; if config.info.install { @@ -39,20 +40,14 @@ pub async fn install( relational::Database::initial_setup(data.sql_config.clone()) .await - .map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?; + .into_app_result()?; let _ = auth::jwt::generate_key(); config.info.install = true; - state - .link_sql(data.sql_config.clone()) - .await - .map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?; - let sql = state - .get_sql() - .await - .map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?; + state.sql_link(&data.sql_config).await.into_app_result()?; + let sql = state.sql_get().await.into_app_result()?; let system_name = utils::generate_random_string(20); let system_password = utils::generate_random_string(20); @@ -63,30 +58,35 @@ pub async fn install( name: data.name.clone(), email: data.email, password: data.password, + level: "administrators".to_string(), }, ) .await - .map_err(|e| status::Custom(Status::InternalServerError, e.to_string())); + .into_app_result()?; + let _ = person::insert( &sql, person::RegisterData { name: system_name.clone(), email: String::from("author@lsy22.com"), - password: system_name.clone(), + password: system_password.clone(), + level: "administrators".to_string(), }, ) .await - .map_err(|e| status::Custom(Status::InternalServerError, e.to_string())); + .into_app_result()?; + let token = auth::jwt::generate_jwt( auth::jwt::CustomClaims { name: data.name.clone(), }, Duration::days(7), ) - .map_err(|e| status::Custom(Status::Unauthorized, e.to_string()))?; + .into_app_result()?; - config::Config::write(config.clone()) - .map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?; + config::Config::write(config.clone()).into_app_result()?; + + state.trigger_restart().await.into_app_result()?; Ok(status::Custom( Status::Ok, Json(InstallReplyData { diff --git a/backend/src/routes/mod.rs b/backend/src/routes/mod.rs index 15ea6a8..26b62ce 100644 --- a/backend/src/routes/mod.rs +++ b/backend/src/routes/mod.rs @@ -1,9 +1,55 @@ pub mod auth; -pub mod intsall; +pub mod configure; +pub mod install; pub mod person; -pub mod theme; +use rocket::http::Status; +use rocket::request::{FromRequest, Outcome, Request}; use rocket::routes; +pub struct Token(String); + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for Token { + type Error = (); + + async fn from_request(request: &'r Request<'_>) -> Outcome { + let token = request + .headers() + .get_one("Authorization") + .map(|value| value.replace("Bearer ", "")); + + match token { + Some(token) => Outcome::Success(Token(token)), + None => Outcome::Success(Token("".to_string())), + } + } +} + +pub struct SystemToken(String); + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for SystemToken { + type Error = (); + + async fn from_request(request: &'r Request<'_>) -> Outcome { + let token = request + .headers() + .get_one("Authorization") + .map(|value| value.replace("Bearer ", "")); + + match token { + Some(token) => { + if token == "system" { + Outcome::Success(SystemToken(token)) + } else { + Outcome::Error((Status::Unauthorized, ())) + } + } + None => Outcome::Error((Status::Unauthorized, ())), + } + } +} + pub fn jwt_routes() -> Vec { routes![auth::token::token_system] } diff --git a/backend/src/routes/person.rs b/backend/src/routes/person.rs index 3016dec..4237c08 100644 --- a/backend/src/routes/person.rs +++ b/backend/src/routes/person.rs @@ -1,7 +1,8 @@ +use crate::auth; +use crate::auth::bcrypt; use crate::database::{relational, relational::builder}; -use crate::utils::CustomResult; +use crate::error::{CustomErrorInto, CustomResult}; use crate::{config, utils}; -use bcrypt::{hash, DEFAULT_COST}; use rocket::{get, http::Status, post, response::status, serde::json::Json, State}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -16,28 +17,36 @@ pub struct RegisterData { pub name: String, pub email: String, pub password: String, + pub level: String, } pub async fn insert(sql: &relational::Database, data: RegisterData) -> CustomResult<()> { - let hashed_password = hash(data.password, DEFAULT_COST).expect("Failed to hash password"); + let mut builder = + builder::QueryBuilder::new(builder::SqlOperation::Insert, "persons".to_string())?; - let mut user_params = HashMap::new(); - user_params.insert( - builder::ValidatedValue::Identifier(String::from("person_name")), - builder::ValidatedValue::PlainText(data.name), - ); - user_params.insert( - builder::ValidatedValue::Identifier(String::from("person_email")), - builder::ValidatedValue::PlainText(data.email), - ); - user_params.insert( - builder::ValidatedValue::Identifier(String::from("person_password")), - builder::ValidatedValue::PlainText(hashed_password), - ); + let password_hash = auth::bcrypt::generate_hash(&data.password)?; - let builder = - builder::QueryBuilder::new(builder::SqlOperation::Insert, String::from("persons"))? - .params(user_params); + builder + .set_value( + "person_name".to_string(), + builder::SafeValue::Text(data.name.to_string(), builder::ValidationLevel::Relaxed), + )? + .set_value( + "person_email".to_string(), + builder::SafeValue::Text(data.email.to_string(), builder::ValidationLevel::Relaxed), + )? + .set_value( + "person_password".to_string(), + builder::SafeValue::Text(password_hash, builder::ValidationLevel::Relaxed), + )? + .set_value( + "person_level".to_string(), + builder::SafeValue::Enum( + data.level.to_string(), + "privilege_level".to_string(), + builder::ValidationLevel::Standard, + ), + )?; sql.get_db().execute_query(&builder).await?; Ok(()) diff --git a/backend/src/routes/theme.rs b/backend/src/routes/theme.rs deleted file mode 100644 index 1b5a6de..0000000 --- a/backend/src/routes/theme.rs +++ /dev/null @@ -1,12 +0,0 @@ -use crate::utils::AppResult; -use rocket::{ - http::Status, - post, - response::status, - serde::json::{Json, Value}, -}; - -#[post("/current", format = "application/json", data = "")] -pub fn theme_current(data: Json) -> AppResult>> { - Ok(status::Custom(Status::Ok, Json(Value::Object(())))) -} diff --git a/backend/src/utils.rs b/backend/src/utils.rs index 5a74e5f..9e02d93 100644 --- a/backend/src/utils.rs +++ b/backend/src/utils.rs @@ -1,5 +1,4 @@ use rand::seq::SliceRandom; -use rocket::response::status; pub fn generate_random_string(length: usize) -> String { let charset = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; @@ -8,30 +7,3 @@ pub fn generate_random_string(length: usize) -> String { .map(|_| *charset.choose(&mut rng).unwrap() as char) .collect() } -#[derive(Debug)] -pub struct CustomError(String); - -impl std::fmt::Display for CustomError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From for CustomError -where - T: std::error::Error + Send + 'static, -{ - fn from(error: T) -> Self { - CustomError(error.to_string()) - } -} - -impl CustomError { - pub fn from_str(error: &str) -> Self { - CustomError(error.to_string()) - } -} - -pub type CustomResult = Result; - -pub type AppResult = Result>; diff --git a/frontend/.env b/frontend/.env deleted file mode 100644 index 6d3542e..0000000 --- a/frontend/.env +++ /dev/null @@ -1 +0,0 @@ -VITE_API_BASE_URL = 1 \ No newline at end of file