diff --git a/backend/src/auth/jwt.rs b/backend/src/auth/jwt.rs index 92b58f7..b567024 100644 --- a/backend/src/auth/jwt.rs +++ b/backend/src/auth/jwt.rs @@ -5,7 +5,7 @@ use ed25519_dalek::{SigningKey, VerifyingKey}; use std::fs::File; use std::io::Write; use std::{env, fs}; -use crate::utils::CustomError; +use crate::utils::CustomResult; use rand::{SeedableRng, RngCore}; #[derive(Debug, Serialize, Deserialize, Clone)] @@ -27,7 +27,7 @@ impl SecretKey { } } -pub fn generate_key() -> Result<(),CustomError> { +pub fn generate_key() -> CustomResult<()> { let mut csprng = rand::rngs::StdRng::from_entropy(); let mut private_key_bytes = [0u8; 32]; @@ -49,7 +49,7 @@ pub fn generate_key() -> Result<(),CustomError> { Ok(()) } -pub fn get_key(key_type: SecretKey) -> Result<[u8; 32],CustomError> { +pub fn get_key(key_type: SecretKey) -> CustomResult<[u8; 32]> { let path = env::current_dir()? .join("assets") .join("key") @@ -60,7 +60,7 @@ pub fn get_key(key_type: SecretKey) -> Result<[u8; 32],CustomError> { Ok(key) } -pub fn generate_jwt(claims: CustomClaims, duration: Duration) -> Result { +pub fn generate_jwt(claims: CustomClaims, duration: Duration) -> CustomResult { let key_bytes = get_key(SecretKey::Signing)?; let signing_key = SigningKey::from_bytes(&key_bytes); @@ -79,7 +79,7 @@ pub fn generate_jwt(claims: CustomClaims, duration: Duration) -> Result Result { +pub fn validate_jwt(token: &str) -> CustomResult { let key_bytes = get_key(SecretKey::Verifying)?; let verifying = VerifyingKey::from_bytes(&key_bytes)?; let token = UntrustedToken::new(token)?; diff --git a/backend/src/config.rs b/backend/src/config.rs index 9312fbe..ccbd8de 100644 --- a/backend/src/config.rs +++ b/backend/src/config.rs @@ -1,6 +1,7 @@ use serde::{Deserialize,Serialize}; use std::{ env, fs}; use std::path::PathBuf; +use crate::utils::CustomResult; #[derive(Deserialize,Serialize,Debug,Clone)] pub struct Config { @@ -35,17 +36,17 @@ pub struct NoSqlConfig { } impl Config { - pub fn read() -> Result> { + pub fn read() -> CustomResult { let path=Self::get_path()?; Ok(toml::from_str(&fs::read_to_string(path)?)?) } - pub fn write(config:Config) -> Result<(), Box> { + pub fn write(config:Config) -> CustomResult<()> { let path=Self::get_path()?; fs::write(path, toml::to_string(&config)?)?; Ok(()) } - pub fn get_path() -> Result> { + pub fn get_path() -> CustomResult { Ok(env::current_dir()? .join("assets") .join("config.toml")) diff --git a/backend/src/database/relational/builder.rs b/backend/src/database/relational/builder.rs index 8b46cf2..cd38459 100644 --- a/backend/src/database/relational/builder.rs +++ b/backend/src/database/relational/builder.rs @@ -1,5 +1,5 @@ use regex::Regex; -use crate::utils::CustomError; +use crate::utils::{CustomResult,CustomError}; use std::collections::HashMap; use std::hash::Hash; @@ -11,7 +11,7 @@ pub enum ValidatedValue { } impl ValidatedValue { - pub fn new_identifier(value: String) -> Result { + pub fn new_identifier(value: String) -> CustomResult { let valid_pattern = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]{0,63}$").unwrap(); if !valid_pattern.is_match(&value) { return Err(CustomError::from_str("Invalid identifier format")); @@ -19,7 +19,7 @@ impl ValidatedValue { Ok(ValidatedValue::Identifier(value)) } - pub fn new_rich_text(value: String) -> Result { + pub fn new_rich_text(value: String) -> CustomResult { let dangerous_patterns = [ "UNION ALL SELECT", "UNION SELECT", @@ -44,7 +44,7 @@ impl ValidatedValue { Ok(ValidatedValue::RichText(value)) } - pub fn new_plain_text(value: String) -> Result { + pub fn new_plain_text(value: String) -> CustomResult { if value.contains(';') || value.contains("--") { return Err(CustomError::from_str("Invalid characters in text")); } @@ -111,7 +111,7 @@ impl WhereCondition { field: String, operator: Operator, value: Option, - ) -> Result { + ) -> CustomResult { let field = ValidatedValue::new_identifier(field)?; let value = match value { @@ -148,7 +148,7 @@ pub struct QueryBuilder { } impl QueryBuilder { - pub fn new(operation: SqlOperation, table: String) -> Result { + pub fn new(operation: SqlOperation, table: String) -> CustomResult { Ok(QueryBuilder { operation, table: ValidatedValue::new_identifier(table)?, @@ -160,7 +160,7 @@ impl QueryBuilder { }) } - pub fn build(&self) -> Result<(String, Vec), CustomError> { + pub fn build(&self) -> CustomResult<(String, Vec)> { let mut query = String::new(); let mut values = Vec::new(); let mut param_counter = 1; @@ -234,7 +234,7 @@ impl QueryBuilder { &self, clause: &WhereClause, mut param_counter: i32, - ) -> Result<(String, Vec), CustomError> { + ) -> CustomResult<(String, Vec)> { let mut values = Vec::new(); let sql = match clause { diff --git a/backend/src/database/relational/mod.rs b/backend/src/database/relational/mod.rs index 5d2d213..9c14ff3 100644 --- a/backend/src/database/relational/mod.rs +++ b/backend/src/database/relational/mod.rs @@ -2,20 +2,20 @@ mod postgresql; use crate::config; use async_trait::async_trait; use std::collections::HashMap; -use crate::utils::CustomError; +use crate::utils::{CustomResult,CustomError}; use std::sync::Arc; pub mod builder; #[async_trait] pub trait DatabaseTrait: Send + Sync { - async fn connect(database: &config::SqlConfig) -> Result + async fn connect(database: &config::SqlConfig) -> CustomResult where Self: Sized; async fn execute_query<'a>( &'a self, builder: &builder::QueryBuilder, - ) -> Result>, CustomError>; - async fn initialization(database: config::SqlConfig) -> Result<(), CustomError> + ) -> CustomResult>>; + async fn initialization(database: config::SqlConfig) -> CustomResult<()> where Self: Sized; } @@ -30,10 +30,10 @@ impl Database { &self.db } - pub async fn link(database: &config::SqlConfig) -> Result { + pub async fn link(database: &config::SqlConfig) -> CustomResult { let db = match database.db_type.as_str() { "postgresql" => postgresql::Postgresql::connect(database).await?, - _ => return Err("unknown database type".into()), + _ => return Err(CustomError::from_str("unknown database type")), }; Ok(Self { @@ -41,10 +41,10 @@ impl Database { }) } - pub async fn initial_setup(database: config::SqlConfig) -> Result<(), CustomError> { + pub async fn initial_setup(database: config::SqlConfig) -> CustomResult<()> { match database.db_type.as_str() { "postgresql" => postgresql::Postgresql::initialization(database).await?, - _ => return Err("unknown database type".into()), + _ => return Err(CustomError::from_str("unknown database type")), }; Ok(()) } diff --git a/backend/src/database/relational/postgresql/mod.rs b/backend/src/database/relational/postgresql/mod.rs index 566775b..b875f29 100644 --- a/backend/src/database/relational/postgresql/mod.rs +++ b/backend/src/database/relational/postgresql/mod.rs @@ -2,8 +2,10 @@ use super::{DatabaseTrait,builder}; use crate::config; use async_trait::async_trait; use sqlx::{Column, PgPool, Row, Executor}; -use std::{collections::HashMap, error::Error}; +use std::collections::HashMap; use std::{env, fs}; +use crate::utils::CustomResult; + #[derive(Clone)] pub struct Postgresql { @@ -12,7 +14,7 @@ pub struct Postgresql { #[async_trait] impl DatabaseTrait for Postgresql { - async fn initialization(db_config: config::SqlConfig) -> Result<(), Box> { + async fn initialization(db_config: config::SqlConfig) -> CustomResult<()> { let path = env::current_dir()? .join("src") .join("database") @@ -34,15 +36,14 @@ impl DatabaseTrait for Postgresql { Ok(()) } - async fn connect(db_config: &config::SqlConfig) -> Result> { + async fn connect(db_config: &config::SqlConfig) -> CustomResult { let connection_str = format!( "postgres://{}:{}@{}:{}/{}", db_config.user, db_config.password, db_config.address, db_config.port, db_config.db_name ); let pool = PgPool::connect(&connection_str) - .await - .map_err(|e| Box::new(e) as Box)?; + .await?; Ok(Postgresql { pool }) } @@ -50,7 +51,7 @@ impl DatabaseTrait for Postgresql { async fn execute_query<'a>( &'a self, builder: &builder::QueryBuilder, - ) -> Result>, Box> { + ) -> CustomResult>> { let (query, values) = builder.build()?; let mut sqlx_query = sqlx::query(&query); @@ -61,8 +62,7 @@ impl DatabaseTrait for Postgresql { let rows = sqlx_query .fetch_all(&self.pool) - .await - .map_err(|e| Box::new(e) as Box)?; + .await?; let mut results = Vec::new(); for row in rows { diff --git a/backend/src/database/relational/uilts.rs b/backend/src/database/relational/uilts.rs deleted file mode 100644 index e69de29..0000000 diff --git a/backend/src/main.rs b/backend/src/main.rs index 62f8d56..385d465 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -6,11 +6,11 @@ mod routes; use chrono::Duration; use database::relational; use rocket::{ - get, http::Status, launch, outcome::IntoOutcome, post, response::status, State + get, http::Status, launch, response::status, State }; use std::sync::Arc; use tokio::sync::Mutex; -use std::error::Error; +use utils::{CustomResult, AppResult,CustomError}; @@ -20,15 +20,15 @@ struct AppState { } impl AppState { - async fn get_sql(&self) -> Result> { + async fn get_sql(&self) -> CustomResult { self.db .lock() .await .clone() - .ok_or_else(|| "Database not initialized".into()) + .ok_or_else(|| CustomError::from_str("Database not initialized")) } - async fn link_sql(&self, config: config::SqlConfig) -> Result<,Box> { + async fn link_sql(&self, config: config::SqlConfig) -> Result<(),CustomError> { let database = relational::Database::link(&config) .await?; *self.db.lock().await = Some(database); @@ -40,7 +40,7 @@ impl AppState { #[get("/system")] -async fn token_system(_state: &State) -> Result, status::Custom> { +async fn token_system(_state: &State) -> AppResult> { let claims = auth::jwt::CustomClaims { name: "system".into(), }; @@ -77,8 +77,7 @@ async fn rocket() -> _ { } rocket_builder = rocket_builder - .mount("/auth/token", rocket::routes![token_system]) - .mount("/", rocket::routes![routes::intsall::test]); + .mount("/auth/token", rocket::routes![token_system]); rocket_builder } diff --git a/backend/src/routes/intsall.rs b/backend/src/routes/intsall.rs index 3fa354d..80a2f3f 100644 --- a/backend/src/routes/intsall.rs +++ b/backend/src/routes/intsall.rs @@ -1,82 +1,98 @@ -use serde::{Deserialize,Serialize}; -use crate::{config,utils}; -use crate::database::relational; -use crate::AppState; -use rocket::{ - post, - http::Status, - response::status, - serde::json::Json, - State, -}; -use crate::routes::person; use crate::auth; +use crate::database::relational; +use crate::routes::person; +use crate::utils::AppResult; +use crate::AppState; +use crate::{config, utils}; use chrono::Duration; - +use rocket::{http::Status, post, response::status, serde::json::Json, State}; +use serde::{Deserialize, Serialize}; #[derive(Deserialize, Serialize)] -pub struct InstallData{ - name:String, - email:String, - password:String, - sql_config: config::SqlConfig +pub struct InstallData { + name: String, + email: String, + password: String, + sql_config: config::SqlConfig, } #[derive(Deserialize, Serialize)] -pub struct InstallReplyData{ - token:String, - name:String, - password:String, +pub struct InstallReplyData { + token: String, + name: String, + password: String, } - #[post("/install", format = "application/json", data = "")] pub async fn install( data: Json, - state: &State -) -> Result>, status::Custom> { + state: &State, +) -> AppResult>> { let mut config = state.configure.lock().await; if config.info.install { - return Err(status::Custom(Status::BadRequest, "Database already initialized".to_string())); + return Err(status::Custom( + Status::BadRequest, + "Database already initialized".to_string(), + )); } - let data=data.into_inner(); + let data = data.into_inner(); relational::Database::initial_setup(data.sql_config.clone()) .await .map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?; - - auth::jwt::generate_key(); + let _ = auth::jwt::generate_key(); config.info.install = true; - state.link_sql(data.sql_config.clone()).await?; - let sql= state.get_sql().await?; + state + .link_sql(data.sql_config.clone()) + .await + .map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?; + let sql = state + .get_sql() + .await + .map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?; + let system_name = utils::generate_random_string(20); + let system_password = utils::generate_random_string(20); - let system_name=utils::generate_random_string(20); - let system_password=utils::generate_random_string(20); - - let _ = person::insert(&sql,person::RegisterData{ name: data.name.clone(), email: data.email, password:data.password }).await?; - let _ = person::insert(&sql,person::RegisterData{ name: system_name.clone(), email: String::from("author@lsy22.com"), password:system_name.clone() }).await?; + let _ = person::insert( + &sql, + person::RegisterData { + name: data.name.clone(), + email: data.email, + password: data.password, + }, + ) + .await + .map_err(|e| status::Custom(Status::InternalServerError, e.to_string())); + let _ = person::insert( + &sql, + person::RegisterData { + name: system_name.clone(), + email: String::from("author@lsy22.com"), + password: system_name.clone(), + }, + ) + .await + .map_err(|e| status::Custom(Status::InternalServerError, e.to_string())); let token = auth::jwt::generate_jwt( - auth::jwt::CustomClaims{name:data.name.clone()}, - Duration::days(7) - ).map_err(|e| status::Custom(Status::Unauthorized, e.to_string()))?; - - + auth::jwt::CustomClaims { + name: data.name.clone(), + }, + Duration::days(7), + ) + .map_err(|e| status::Custom(Status::Unauthorized, e.to_string()))?; config::Config::write(config.clone()) - .map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?; - Ok( - status::Custom( - Status::Ok, - Json(InstallReplyData{ - token:token, - name: system_name, - password: system_password - } - ) - ) - ) -} \ No newline at end of file + .map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?; + Ok(status::Custom( + Status::Ok, + Json(InstallReplyData { + token: token, + name: system_name, + password: system_password, + }), + )) +} diff --git a/backend/src/routes/mod.rs b/backend/src/routes/mod.rs index 9d3a597..c245595 100644 --- a/backend/src/routes/mod.rs +++ b/backend/src/routes/mod.rs @@ -2,7 +2,9 @@ pub mod intsall; pub mod person; use rocket::routes; -// pub fn create_routes() -> Vec { - -// } +pub fn create_routes() -> Vec { + routes![ + intsall::install, + ] +} diff --git a/backend/src/routes/person.rs b/backend/src/routes/person.rs index b1cc03d..6834b9e 100644 --- a/backend/src/routes/person.rs +++ b/backend/src/routes/person.rs @@ -9,8 +9,8 @@ use rocket::{ State, }; use std::collections::HashMap; -use bcrypt::{hash, verify, DEFAULT_COST}; -use crate::utils::CustomError; +use bcrypt::{hash, DEFAULT_COST}; +use crate::utils::CustomResult; @@ -26,7 +26,7 @@ pub struct RegisterData{ pub password:String } -pub async fn insert(sql:&relational::Database,data:RegisterData) -> Result<(),CustomError>{ +pub async fn insert(sql:&relational::Database,data:RegisterData) -> CustomResult<()>{ let hashed_password = hash(data.password, DEFAULT_COST).expect("Failed to hash password"); diff --git a/backend/src/utils.rs b/backend/src/utils.rs index dc56d9f..5a74e5f 100644 --- a/backend/src/utils.rs +++ b/backend/src/utils.rs @@ -1,5 +1,5 @@ use rand::seq::SliceRandom; - +use rocket::response::status; pub fn generate_random_string(length: usize) -> String { let charset = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; @@ -8,7 +8,7 @@ pub fn generate_random_string(length: usize) -> String { .map(|_| *charset.choose(&mut rng).unwrap() as char) .collect() } - +#[derive(Debug)] pub struct CustomError(String); impl std::fmt::Display for CustomError { @@ -17,7 +17,6 @@ impl std::fmt::Display for CustomError { } } - impl From for CustomError where T: std::error::Error + Send + 'static, @@ -32,3 +31,7 @@ impl CustomError { CustomError(error.to_string()) } } + +pub type CustomResult = Result; + +pub type AppResult = Result>;