后端:定义错误,所有返回错误都具有发送和全局生命具有所有权,所有返回的错误都可以使用?来解用

This commit is contained in:
lsy 2024-11-21 11:47:41 +08:00
parent 33b53b3663
commit 3a88c33a6e
7 changed files with 88 additions and 122 deletions

View File

@ -5,7 +5,7 @@ use ed25519_dalek::{SigningKey, VerifyingKey};
use std::fs::File; use std::fs::File;
use std::io::Write; use std::io::Write;
use std::{env, fs}; use std::{env, fs};
use std::error::Error; use crate::utils::CustomError;
use rand::{SeedableRng, RngCore}; use rand::{SeedableRng, RngCore};
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
@ -27,7 +27,7 @@ impl SecretKey {
} }
} }
pub fn generate_key() -> Result<(), Box<dyn Error>> { pub fn generate_key() -> Result<(),CustomError> {
let mut csprng = rand::rngs::StdRng::from_entropy(); let mut csprng = rand::rngs::StdRng::from_entropy();
let mut private_key_bytes = [0u8; 32]; let mut private_key_bytes = [0u8; 32];
@ -49,7 +49,7 @@ pub fn generate_key() -> Result<(), Box<dyn Error>> {
Ok(()) Ok(())
} }
pub fn get_key(key_type: SecretKey) -> Result<[u8; 32], Box<dyn Error>> { pub fn get_key(key_type: SecretKey) -> Result<[u8; 32],CustomError> {
let path = env::current_dir()? let path = env::current_dir()?
.join("assets") .join("assets")
.join("key") .join("key")
@ -60,7 +60,7 @@ pub fn get_key(key_type: SecretKey) -> Result<[u8; 32], Box<dyn Error>> {
Ok(key) Ok(key)
} }
pub fn generate_jwt(claims: CustomClaims, duration: Duration) -> Result<String, Box<dyn Error>> { pub fn generate_jwt(claims: CustomClaims, duration: Duration) -> Result<String,CustomError> {
let key_bytes = get_key(SecretKey::Signing)?; let key_bytes = get_key(SecretKey::Signing)?;
let signing_key = SigningKey::from_bytes(&key_bytes); let signing_key = SigningKey::from_bytes(&key_bytes);
@ -79,7 +79,7 @@ pub fn generate_jwt(claims: CustomClaims, duration: Duration) -> Result<String,
Ok(token) Ok(token)
} }
pub fn validate_jwt(token: &str) -> Result<CustomClaims, Box<dyn Error>> { pub fn validate_jwt(token: &str) -> Result<CustomClaims, CustomError> {
let key_bytes = get_key(SecretKey::Verifying)?; let key_bytes = get_key(SecretKey::Verifying)?;
let verifying = VerifyingKey::from_bytes(&key_bytes)?; let verifying = VerifyingKey::from_bytes(&key_bytes)?;
let token = UntrustedToken::new(token)?; let token = UntrustedToken::new(token)?;

View File

@ -1,8 +1,6 @@
use regex::Regex; use regex::Regex;
use crate::utils::CustomError;
use std::collections::HashMap; use std::collections::HashMap;
use super::DatabaseError;
use std::hash::Hash; use std::hash::Hash;
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
@ -13,17 +11,15 @@ pub enum ValidatedValue {
} }
impl ValidatedValue { impl ValidatedValue {
pub fn new_identifier(value: String) -> Result<Self, DatabaseError> { pub fn new_identifier(value: String) -> Result<Self, CustomError> {
let valid_pattern = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]{0,63}$").unwrap(); let valid_pattern = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]{0,63}$").unwrap();
if !valid_pattern.is_match(&value) { if !valid_pattern.is_match(&value) {
return Err(DatabaseError::ValidationError( return Err(CustomError::from_str("Invalid identifier format"));
"Invalid identifier format".to_string(),
));
} }
Ok(ValidatedValue::Identifier(value)) Ok(ValidatedValue::Identifier(value))
} }
pub fn new_rich_text(value: String) -> Result<Self, DatabaseError> { pub fn new_rich_text(value: String) -> Result<Self, CustomError> {
let dangerous_patterns = [ let dangerous_patterns = [
"UNION ALL SELECT", "UNION ALL SELECT",
"UNION SELECT", "UNION SELECT",
@ -42,24 +38,24 @@ impl ValidatedValue {
let value_upper = value.to_uppercase(); let value_upper = value.to_uppercase();
for pattern in dangerous_patterns.iter() { for pattern in dangerous_patterns.iter() {
if value_upper.contains(&pattern.to_uppercase()) { if value_upper.contains(&pattern.to_uppercase()) {
return Err(DatabaseError::SqlInjectionAttempt( return Err(CustomError::from_str("Invalid identifier format"));
format!("Dangerous SQL pattern detected: {}", pattern)
));
} }
} }
Ok(ValidatedValue::RichText(value)) Ok(ValidatedValue::RichText(value))
} }
pub fn new_plain_text(value: String) -> Result<Self, DatabaseError> { pub fn new_plain_text(value: String) -> Result<Self, CustomError> {
if value.contains(';') || value.contains("--") { if value.contains(';') || value.contains("--") {
return Err(DatabaseError::ValidationError("Invalid characters in text".to_string())); return Err(CustomError::from_str("Invalid characters in text"));
} }
Ok(ValidatedValue::PlainText(value)) Ok(ValidatedValue::PlainText(value))
} }
pub fn get(&self) -> &str { pub fn get(&self) -> &str {
match self { match self {
ValidatedValue::Identifier(s) | ValidatedValue::RichText(s) | ValidatedValue::PlainText(s) => s, ValidatedValue::Identifier(s)
| ValidatedValue::RichText(s)
| ValidatedValue::PlainText(s) => s,
} }
} }
} }
@ -105,7 +101,7 @@ impl Operator {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct WhereCondition { pub struct WhereCondition {
field: ValidatedValue, field: ValidatedValue,
operator: Operator, operator: Operator,
value: Option<ValidatedValue>, value: Option<ValidatedValue>,
} }
@ -115,9 +111,9 @@ impl WhereCondition {
field: String, field: String,
operator: Operator, operator: Operator,
value: Option<String>, value: Option<String>,
) -> Result<Self, DatabaseError> { ) -> Result<Self, CustomError> {
let field = ValidatedValue::new_identifier(field)?; let field = ValidatedValue::new_identifier(field)?;
let value = match value { let value = match value {
Some(v) => Some(match operator { Some(v) => Some(match operator {
Operator::Like => ValidatedValue::new_plain_text(v)?, Operator::Like => ValidatedValue::new_plain_text(v)?,
@ -140,19 +136,19 @@ pub enum WhereClause {
Or(Vec<WhereClause>), Or(Vec<WhereClause>),
Condition(WhereCondition), Condition(WhereCondition),
} }
#[derive(Debug, Clone)]
pub struct QueryBuilder { pub struct QueryBuilder {
operation: SqlOperation, operation: SqlOperation,
table: ValidatedValue, table: ValidatedValue,
fields: Vec<ValidatedValue>, fields: Vec<ValidatedValue>,
params: HashMap<ValidatedValue, ValidatedValue>, params: HashMap<ValidatedValue, ValidatedValue>,
where_clause: Option<WhereClause>, where_clause: Option<WhereClause>,
order_by: Option<ValidatedValue>, order_by: Option<ValidatedValue>,
limit: Option<i32>, limit: Option<i32>,
} }
impl QueryBuilder { impl QueryBuilder {
pub fn new(operation: SqlOperation, table: String) -> Result<Self, DatabaseError> { pub fn new(operation: SqlOperation, table: String) -> Result<Self, CustomError> {
Ok(QueryBuilder { Ok(QueryBuilder {
operation, operation,
table: ValidatedValue::new_identifier(table)?, table: ValidatedValue::new_identifier(table)?,
@ -164,7 +160,7 @@ impl QueryBuilder {
}) })
} }
pub fn build(&self) -> Result<(String, Vec<String>), DatabaseError> { pub fn build(&self) -> Result<(String, Vec<String>), CustomError> {
let mut query = String::new(); let mut query = String::new();
let mut values = Vec::new(); let mut values = Vec::new();
let mut param_counter = 1; let mut param_counter = 1;
@ -174,7 +170,8 @@ impl QueryBuilder {
let fields = if self.fields.is_empty() { let fields = if self.fields.is_empty() {
"*".to_string() "*".to_string()
} else { } else {
self.fields.iter() self.fields
.iter()
.map(|f| f.get().to_string()) .map(|f| f.get().to_string())
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", ") .join(", ")
@ -182,12 +179,9 @@ impl QueryBuilder {
query.push_str(&format!("SELECT {} FROM {}", fields, self.table.get())); query.push_str(&format!("SELECT {} FROM {}", fields, self.table.get()));
} }
SqlOperation::Insert => { SqlOperation::Insert => {
let fields: Vec<String> = self.params.keys() let fields: Vec<String> = self.params.keys().map(|k| k.get().to_string()).collect();
.map(|k| k.get().to_string()) let placeholders: Vec<String> =
.collect(); (1..=self.params.len()).map(|i| format!("${}", i)).collect();
let placeholders: Vec<String> = (1..=self.params.len())
.map(|i| format!("${}", i))
.collect();
query.push_str(&format!( query.push_str(&format!(
"INSERT INTO {} ({}) VALUES ({})", "INSERT INTO {} ({}) VALUES ({})",
@ -201,7 +195,8 @@ impl QueryBuilder {
} }
SqlOperation::Update => { SqlOperation::Update => {
query.push_str(&format!("UPDATE {} SET ", self.table.get())); query.push_str(&format!("UPDATE {} SET ", self.table.get()));
let set_clauses: Vec<String> = self.params let set_clauses: Vec<String> = self
.params
.iter() .iter()
.map(|(key, _)| { .map(|(key, _)| {
let placeholder = format!("${}", param_counter); let placeholder = format!("${}", param_counter);
@ -239,7 +234,7 @@ impl QueryBuilder {
&self, &self,
clause: &WhereClause, clause: &WhereClause,
mut param_counter: i32, mut param_counter: i32,
) -> Result<(String, Vec<String>), DatabaseError> { ) -> Result<(String, Vec<String>), CustomError> {
let mut values = Vec::new(); let mut values = Vec::new();
let sql = match clause { let sql = match clause {
@ -267,7 +262,12 @@ impl QueryBuilder {
if let Some(value) = &cond.value { if let Some(value) = &cond.value {
let placeholder = format!("${}", param_counter); let placeholder = format!("${}", param_counter);
values.push(value.get().to_string()); values.push(value.get().to_string());
format!("{} {} {}", cond.field.get(), cond.operator.as_str(), placeholder) format!(
"{} {} {}",
cond.field.get(),
cond.operator.as_str(),
placeholder
)
} else { } else {
format!("{} {}", cond.field.get(), cond.operator.as_str()) format!("{} {}", cond.field.get(), cond.operator.as_str())
} }
@ -281,7 +281,7 @@ impl QueryBuilder {
self self
} }
pub fn params(mut self, params: HashMap<ValidatedValue, ValidatedValue>) -> Self { pub fn params(mut self, params: HashMap<ValidatedValue, ValidatedValue>) -> Self {
self.params = params; self.params = params;
self self
} }
@ -300,4 +300,4 @@ impl QueryBuilder {
self.limit = Some(limit); self.limit = Some(limit);
self self
} }
} }

View File

@ -2,43 +2,20 @@ mod postgresql;
use crate::config; use crate::config;
use async_trait::async_trait; use async_trait::async_trait;
use std::collections::HashMap; use std::collections::HashMap;
use std::error::Error; use crate::utils::CustomError;
use std::sync::Arc; use std::sync::Arc;
use std::fmt;
pub mod builder; pub mod builder;
#[derive(Debug)]
pub enum DatabaseError {
ValidationError(String),
SqlInjectionAttempt(String),
InvalidParameter(String),
ExecutionError(String),
}
impl fmt::Display for DatabaseError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
DatabaseError::ValidationError(msg) => write!(f, "Validation error: {}", msg),
DatabaseError::SqlInjectionAttempt(msg) => write!(f, "SQL injection attempt: {}", msg),
DatabaseError::InvalidParameter(msg) => write!(f, "Invalid parameter: {}", msg),
DatabaseError::ExecutionError(msg) => write!(f, "Execution error: {}", msg),
}
}
}
impl Error for DatabaseError {}
#[async_trait] #[async_trait]
pub trait DatabaseTrait: Send + Sync { pub trait DatabaseTrait: Send + Sync {
async fn connect(database: &config::SqlConfig) -> Result<Self, Box<dyn Error>> async fn connect(database: &config::SqlConfig) -> Result<Self, CustomError>
where where
Self: Sized; Self: Sized;
async fn execute_query<'a>( async fn execute_query<'a>(
&'a self, &'a self,
builder: &builder::QueryBuilder, builder: &builder::QueryBuilder,
) -> Result<Vec<HashMap<String, String>>, Box<dyn Error + 'a>>; ) -> Result<Vec<HashMap<String, String>>, CustomError>;
async fn initialization(database: config::SqlConfig) -> Result<(), Box<dyn Error>> async fn initialization(database: config::SqlConfig) -> Result<(), CustomError>
where where
Self: Sized; Self: Sized;
} }
@ -53,7 +30,7 @@ impl Database {
&self.db &self.db
} }
pub async fn link(database: &config::SqlConfig) -> Result<Self, Box<dyn Error>> { pub async fn link(database: &config::SqlConfig) -> Result<Self, CustomError> {
let db = match database.db_type.as_str() { let db = match database.db_type.as_str() {
"postgresql" => postgresql::Postgresql::connect(database).await?, "postgresql" => postgresql::Postgresql::connect(database).await?,
_ => return Err("unknown database type".into()), _ => return Err("unknown database type".into()),
@ -64,7 +41,7 @@ impl Database {
}) })
} }
pub async fn initial_setup(database: config::SqlConfig) -> Result<(), Box<dyn Error>> { pub async fn initial_setup(database: config::SqlConfig) -> Result<(), CustomError> {
match database.db_type.as_str() { match database.db_type.as_str() {
"postgresql" => postgresql::Postgresql::initialization(database).await?, "postgresql" => postgresql::Postgresql::initialization(database).await?,
_ => return Err("unknown database type".into()), _ => return Err("unknown database type".into()),

View File

@ -6,33 +6,13 @@ mod routes;
use chrono::Duration; use chrono::Duration;
use database::relational; use database::relational;
use rocket::{ use rocket::{
get, post, get, http::Status, launch, outcome::IntoOutcome, post, response::status, State
http::Status,
launch,
response::status,
State,
}; };
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use std::error::Error;
#[derive(Debug)]
pub enum AppError {
Database(String),
Config(String),
Auth(String),
}
impl From<AppError> for status::Custom<String> {
fn from(error: AppError) -> Self {
match error {
AppError::Database(msg) => status::Custom(Status::InternalServerError, format!("Database error: {}", msg)),
AppError::Config(msg) => status::Custom(Status::InternalServerError, format!("Config error: {}", msg)),
AppError::Auth(msg) => status::Custom(Status::InternalServerError, format!("Auth error: {}", msg)),
}
}
}
type AppResult<T> = Result<T, AppError>;
struct AppState { struct AppState {
db: Arc<Mutex<Option<relational::Database>>>, db: Arc<Mutex<Option<relational::Database>>>,
@ -40,18 +20,17 @@ struct AppState {
} }
impl AppState { impl AppState {
async fn get_sql(&self) -> AppResult<relational::Database> { async fn get_sql(&self) -> Result<relational::Database,Box<dyn Error>> {
self.db self.db
.lock() .lock()
.await .await
.clone() .clone()
.ok_or_else(|| AppError::Database("Database not initialized".into())) .ok_or_else(|| "Database not initialized".into())
} }
async fn link_sql(&self, config: config::SqlConfig) -> AppResult<()> { async fn link_sql(&self, config: config::SqlConfig) -> Result<,Box<dyn Error>> {
let database = relational::Database::link(&config) let database = relational::Database::link(&config)
.await .await?;
.map_err(|e| AppError::Database(e.to_string()))?;
*self.db.lock().await = Some(database); *self.db.lock().await = Some(database);
Ok(()) Ok(())
} }
@ -68,7 +47,7 @@ async fn token_system(_state: &State<AppState>) -> Result<status::Custom<String>
auth::jwt::generate_jwt(claims, Duration::seconds(1)) auth::jwt::generate_jwt(claims, Duration::seconds(1))
.map(|token| status::Custom(Status::Ok, token)) .map(|token| status::Custom(Status::Ok, token))
.map_err(|e| AppError::Auth(e.to_string()).into()) .map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))
} }

View File

@ -1,7 +1,7 @@
use serde::{Deserialize,Serialize}; use serde::{Deserialize,Serialize};
use crate::{config,utils}; use crate::{config,utils};
use crate::database::relational; use crate::database::relational;
use crate::{AppState,AppError,AppResult}; use crate::AppState;
use rocket::{ use rocket::{
post, post,
http::Status, http::Status,
@ -28,20 +28,6 @@ pub struct InstallReplyData{
password:String, password:String,
} }
#[post("/test", format = "application/json", data = "<data>")]
pub async fn test(
data: Json<InstallData>,
state: &State<AppState>
) -> Result<status::Custom<String>, status::Custom<String>> {
let data=data.into_inner();
let sql= state.get_sql().await.map_err(|e| e)?;
let _ = person::insert(&sql,person::RegisterData{ name: data.name.clone(), email: data.email, password:data.password });
Ok(status::Custom(Status::Ok, "Installation successful".to_string()))
}
#[post("/install", format = "application/json", data = "<data>")] #[post("/install", format = "application/json", data = "<data>")]
pub async fn install( pub async fn install(
@ -58,6 +44,9 @@ pub async fn install(
relational::Database::initial_setup(data.sql_config.clone()) relational::Database::initial_setup(data.sql_config.clone())
.await .await
.map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?; .map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?;
auth::jwt::generate_key();
config.info.install = true; config.info.install = true;

View File

@ -1,7 +1,6 @@
use serde::{Deserialize,Serialize}; use serde::{Deserialize,Serialize};
use crate::{config,utils}; use crate::{config,utils};
use crate::database::{relational,relational::builder}; use crate::database::{relational,relational::builder};
use crate::{AppError,AppResult};
use rocket::{ use rocket::{
get, post, get, post,
http::Status, http::Status,
@ -11,6 +10,8 @@ use rocket::{
}; };
use std::collections::HashMap; use std::collections::HashMap;
use bcrypt::{hash, verify, DEFAULT_COST}; use bcrypt::{hash, verify, DEFAULT_COST};
use crate::utils::CustomError;
#[derive(Deserialize, Serialize)] #[derive(Deserialize, Serialize)]
@ -25,7 +26,7 @@ pub struct RegisterData{
pub password:String pub password:String
} }
pub async fn insert(sql:&relational::Database,data:RegisterData) -> AppResult<()>{ pub async fn insert(sql:&relational::Database,data:RegisterData) -> Result<(),CustomError>{
let hashed_password = hash(data.password, DEFAULT_COST).expect("Failed to hash password"); let hashed_password = hash(data.password, DEFAULT_COST).expect("Failed to hash password");
@ -46,16 +47,11 @@ pub async fn insert(sql:&relational::Database,data:RegisterData) -> AppResult<()
builder::ValidatedValue::PlainText(hashed_password) builder::ValidatedValue::PlainText(hashed_password)
); );
let builder = builder::QueryBuilder::new(builder::SqlOperation::Insert,String::from("persons")) let builder = builder::QueryBuilder::new(builder::SqlOperation::Insert,String::from("persons"))?
.map_err(|e|{
AppError::Database(format!("Error while building query: {}", e.to_string()))
})?
.params(user_params) .params(user_params)
; ;
let _= sql.get_db().execute_query(&builder).await.map_err(|e|{ sql.get_db().execute_query(&builder).await?;
AppError::Database(format!("Travel during execution: {}", e.to_string()))
})?;
Ok(()) Ok(())
} }

View File

@ -1,9 +1,34 @@
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
pub fn generate_random_string(length: usize) -> String { pub fn generate_random_string(length: usize) -> String {
let charset = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; let charset = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
(0..length) (0..length)
.map(|_| *charset.choose(&mut rng).unwrap() as char) .map(|_| *charset.choose(&mut rng).unwrap() as char)
.collect() .collect()
} }
pub struct CustomError(String);
impl std::fmt::Display for CustomError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl<T> From<T> for CustomError
where
T: std::error::Error + Send + 'static,
{
fn from(error: T) -> Self {
CustomError(error.to_string())
}
}
impl CustomError {
pub fn from_str(error: &str) -> Self {
CustomError(error.to_string())
}
}