后端:完成错误和Result的更换

This commit is contained in:
lsy 2024-11-21 19:07:42 +08:00
parent 3a88c33a6e
commit 4bf55506b9
11 changed files with 124 additions and 103 deletions

View File

@ -5,7 +5,7 @@ use ed25519_dalek::{SigningKey, VerifyingKey};
use std::fs::File; use std::fs::File;
use std::io::Write; use std::io::Write;
use std::{env, fs}; use std::{env, fs};
use crate::utils::CustomError; use crate::utils::CustomResult;
use rand::{SeedableRng, RngCore}; use rand::{SeedableRng, RngCore};
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
@ -27,7 +27,7 @@ impl SecretKey {
} }
} }
pub fn generate_key() -> Result<(),CustomError> { pub fn generate_key() -> CustomResult<()> {
let mut csprng = rand::rngs::StdRng::from_entropy(); let mut csprng = rand::rngs::StdRng::from_entropy();
let mut private_key_bytes = [0u8; 32]; let mut private_key_bytes = [0u8; 32];
@ -49,7 +49,7 @@ pub fn generate_key() -> Result<(),CustomError> {
Ok(()) Ok(())
} }
pub fn get_key(key_type: SecretKey) -> Result<[u8; 32],CustomError> { pub fn get_key(key_type: SecretKey) -> CustomResult<[u8; 32]> {
let path = env::current_dir()? let path = env::current_dir()?
.join("assets") .join("assets")
.join("key") .join("key")
@ -60,7 +60,7 @@ pub fn get_key(key_type: SecretKey) -> Result<[u8; 32],CustomError> {
Ok(key) Ok(key)
} }
pub fn generate_jwt(claims: CustomClaims, duration: Duration) -> Result<String,CustomError> { pub fn generate_jwt(claims: CustomClaims, duration: Duration) -> CustomResult<String> {
let key_bytes = get_key(SecretKey::Signing)?; let key_bytes = get_key(SecretKey::Signing)?;
let signing_key = SigningKey::from_bytes(&key_bytes); let signing_key = SigningKey::from_bytes(&key_bytes);
@ -79,7 +79,7 @@ pub fn generate_jwt(claims: CustomClaims, duration: Duration) -> Result<String,C
Ok(token) Ok(token)
} }
pub fn validate_jwt(token: &str) -> Result<CustomClaims, CustomError> { pub fn validate_jwt(token: &str) -> CustomResult<CustomClaims> {
let key_bytes = get_key(SecretKey::Verifying)?; let key_bytes = get_key(SecretKey::Verifying)?;
let verifying = VerifyingKey::from_bytes(&key_bytes)?; let verifying = VerifyingKey::from_bytes(&key_bytes)?;
let token = UntrustedToken::new(token)?; let token = UntrustedToken::new(token)?;

View File

@ -1,6 +1,7 @@
use serde::{Deserialize,Serialize}; use serde::{Deserialize,Serialize};
use std::{ env, fs}; use std::{ env, fs};
use std::path::PathBuf; use std::path::PathBuf;
use crate::utils::CustomResult;
#[derive(Deserialize,Serialize,Debug,Clone)] #[derive(Deserialize,Serialize,Debug,Clone)]
pub struct Config { pub struct Config {
@ -35,17 +36,17 @@ pub struct NoSqlConfig {
} }
impl Config { impl Config {
pub fn read() -> Result<Self, Box<dyn std::error::Error>> { pub fn read() -> CustomResult<Self> {
let path=Self::get_path()?; let path=Self::get_path()?;
Ok(toml::from_str(&fs::read_to_string(path)?)?) Ok(toml::from_str(&fs::read_to_string(path)?)?)
} }
pub fn write(config:Config) -> Result<(), Box<dyn std::error::Error>> { pub fn write(config:Config) -> CustomResult<()> {
let path=Self::get_path()?; let path=Self::get_path()?;
fs::write(path, toml::to_string(&config)?)?; fs::write(path, toml::to_string(&config)?)?;
Ok(()) Ok(())
} }
pub fn get_path() -> Result<PathBuf, Box<dyn std::error::Error>> { pub fn get_path() -> CustomResult<PathBuf> {
Ok(env::current_dir()? Ok(env::current_dir()?
.join("assets") .join("assets")
.join("config.toml")) .join("config.toml"))

View File

@ -1,5 +1,5 @@
use regex::Regex; use regex::Regex;
use crate::utils::CustomError; use crate::utils::{CustomResult,CustomError};
use std::collections::HashMap; use std::collections::HashMap;
use std::hash::Hash; use std::hash::Hash;
@ -11,7 +11,7 @@ pub enum ValidatedValue {
} }
impl ValidatedValue { impl ValidatedValue {
pub fn new_identifier(value: String) -> Result<Self, CustomError> { pub fn new_identifier(value: String) -> CustomResult<Self> {
let valid_pattern = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]{0,63}$").unwrap(); let valid_pattern = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]{0,63}$").unwrap();
if !valid_pattern.is_match(&value) { if !valid_pattern.is_match(&value) {
return Err(CustomError::from_str("Invalid identifier format")); return Err(CustomError::from_str("Invalid identifier format"));
@ -19,7 +19,7 @@ impl ValidatedValue {
Ok(ValidatedValue::Identifier(value)) Ok(ValidatedValue::Identifier(value))
} }
pub fn new_rich_text(value: String) -> Result<Self, CustomError> { pub fn new_rich_text(value: String) -> CustomResult<Self> {
let dangerous_patterns = [ let dangerous_patterns = [
"UNION ALL SELECT", "UNION ALL SELECT",
"UNION SELECT", "UNION SELECT",
@ -44,7 +44,7 @@ impl ValidatedValue {
Ok(ValidatedValue::RichText(value)) Ok(ValidatedValue::RichText(value))
} }
pub fn new_plain_text(value: String) -> Result<Self, CustomError> { pub fn new_plain_text(value: String) -> CustomResult<Self> {
if value.contains(';') || value.contains("--") { if value.contains(';') || value.contains("--") {
return Err(CustomError::from_str("Invalid characters in text")); return Err(CustomError::from_str("Invalid characters in text"));
} }
@ -111,7 +111,7 @@ impl WhereCondition {
field: String, field: String,
operator: Operator, operator: Operator,
value: Option<String>, value: Option<String>,
) -> Result<Self, CustomError> { ) -> CustomResult<Self> {
let field = ValidatedValue::new_identifier(field)?; let field = ValidatedValue::new_identifier(field)?;
let value = match value { let value = match value {
@ -148,7 +148,7 @@ pub struct QueryBuilder {
} }
impl QueryBuilder { impl QueryBuilder {
pub fn new(operation: SqlOperation, table: String) -> Result<Self, CustomError> { pub fn new(operation: SqlOperation, table: String) -> CustomResult<Self> {
Ok(QueryBuilder { Ok(QueryBuilder {
operation, operation,
table: ValidatedValue::new_identifier(table)?, table: ValidatedValue::new_identifier(table)?,
@ -160,7 +160,7 @@ impl QueryBuilder {
}) })
} }
pub fn build(&self) -> Result<(String, Vec<String>), CustomError> { pub fn build(&self) -> CustomResult<(String, Vec<String>)> {
let mut query = String::new(); let mut query = String::new();
let mut values = Vec::new(); let mut values = Vec::new();
let mut param_counter = 1; let mut param_counter = 1;
@ -234,7 +234,7 @@ impl QueryBuilder {
&self, &self,
clause: &WhereClause, clause: &WhereClause,
mut param_counter: i32, mut param_counter: i32,
) -> Result<(String, Vec<String>), CustomError> { ) -> CustomResult<(String, Vec<String>)> {
let mut values = Vec::new(); let mut values = Vec::new();
let sql = match clause { let sql = match clause {

View File

@ -2,20 +2,20 @@ mod postgresql;
use crate::config; use crate::config;
use async_trait::async_trait; use async_trait::async_trait;
use std::collections::HashMap; use std::collections::HashMap;
use crate::utils::CustomError; use crate::utils::{CustomResult,CustomError};
use std::sync::Arc; use std::sync::Arc;
pub mod builder; pub mod builder;
#[async_trait] #[async_trait]
pub trait DatabaseTrait: Send + Sync { pub trait DatabaseTrait: Send + Sync {
async fn connect(database: &config::SqlConfig) -> Result<Self, CustomError> async fn connect(database: &config::SqlConfig) -> CustomResult<Self>
where where
Self: Sized; Self: Sized;
async fn execute_query<'a>( async fn execute_query<'a>(
&'a self, &'a self,
builder: &builder::QueryBuilder, builder: &builder::QueryBuilder,
) -> Result<Vec<HashMap<String, String>>, CustomError>; ) -> CustomResult<Vec<HashMap<String, String>>>;
async fn initialization(database: config::SqlConfig) -> Result<(), CustomError> async fn initialization(database: config::SqlConfig) -> CustomResult<()>
where where
Self: Sized; Self: Sized;
} }
@ -30,10 +30,10 @@ impl Database {
&self.db &self.db
} }
pub async fn link(database: &config::SqlConfig) -> Result<Self, CustomError> { pub async fn link(database: &config::SqlConfig) -> CustomResult<Self> {
let db = match database.db_type.as_str() { let db = match database.db_type.as_str() {
"postgresql" => postgresql::Postgresql::connect(database).await?, "postgresql" => postgresql::Postgresql::connect(database).await?,
_ => return Err("unknown database type".into()), _ => return Err(CustomError::from_str("unknown database type")),
}; };
Ok(Self { Ok(Self {
@ -41,10 +41,10 @@ impl Database {
}) })
} }
pub async fn initial_setup(database: config::SqlConfig) -> Result<(), CustomError> { pub async fn initial_setup(database: config::SqlConfig) -> CustomResult<()> {
match database.db_type.as_str() { match database.db_type.as_str() {
"postgresql" => postgresql::Postgresql::initialization(database).await?, "postgresql" => postgresql::Postgresql::initialization(database).await?,
_ => return Err("unknown database type".into()), _ => return Err(CustomError::from_str("unknown database type")),
}; };
Ok(()) Ok(())
} }

View File

@ -2,8 +2,10 @@ use super::{DatabaseTrait,builder};
use crate::config; use crate::config;
use async_trait::async_trait; use async_trait::async_trait;
use sqlx::{Column, PgPool, Row, Executor}; use sqlx::{Column, PgPool, Row, Executor};
use std::{collections::HashMap, error::Error}; use std::collections::HashMap;
use std::{env, fs}; use std::{env, fs};
use crate::utils::CustomResult;
#[derive(Clone)] #[derive(Clone)]
pub struct Postgresql { pub struct Postgresql {
@ -12,7 +14,7 @@ pub struct Postgresql {
#[async_trait] #[async_trait]
impl DatabaseTrait for Postgresql { impl DatabaseTrait for Postgresql {
async fn initialization(db_config: config::SqlConfig) -> Result<(), Box<dyn Error>> { async fn initialization(db_config: config::SqlConfig) -> CustomResult<()> {
let path = env::current_dir()? let path = env::current_dir()?
.join("src") .join("src")
.join("database") .join("database")
@ -34,15 +36,14 @@ impl DatabaseTrait for Postgresql {
Ok(()) Ok(())
} }
async fn connect(db_config: &config::SqlConfig) -> Result<Self, Box<dyn Error>> { async fn connect(db_config: &config::SqlConfig) -> CustomResult<Self> {
let connection_str = format!( let connection_str = format!(
"postgres://{}:{}@{}:{}/{}", "postgres://{}:{}@{}:{}/{}",
db_config.user, db_config.password, db_config.address, db_config.port, db_config.db_name db_config.user, db_config.password, db_config.address, db_config.port, db_config.db_name
); );
let pool = PgPool::connect(&connection_str) let pool = PgPool::connect(&connection_str)
.await .await?;
.map_err(|e| Box::new(e) as Box<dyn Error>)?;
Ok(Postgresql { pool }) Ok(Postgresql { pool })
} }
@ -50,7 +51,7 @@ impl DatabaseTrait for Postgresql {
async fn execute_query<'a>( async fn execute_query<'a>(
&'a self, &'a self,
builder: &builder::QueryBuilder, builder: &builder::QueryBuilder,
) -> Result<Vec<HashMap<String, String>>, Box<dyn Error + 'a>> { ) -> CustomResult<Vec<HashMap<String, String>>> {
let (query, values) = builder.build()?; let (query, values) = builder.build()?;
let mut sqlx_query = sqlx::query(&query); let mut sqlx_query = sqlx::query(&query);
@ -61,8 +62,7 @@ impl DatabaseTrait for Postgresql {
let rows = sqlx_query let rows = sqlx_query
.fetch_all(&self.pool) .fetch_all(&self.pool)
.await .await?;
.map_err(|e| Box::new(e) as Box<dyn Error>)?;
let mut results = Vec::new(); let mut results = Vec::new();
for row in rows { for row in rows {

View File

@ -6,11 +6,11 @@ mod routes;
use chrono::Duration; use chrono::Duration;
use database::relational; use database::relational;
use rocket::{ use rocket::{
get, http::Status, launch, outcome::IntoOutcome, post, response::status, State get, http::Status, launch, response::status, State
}; };
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use std::error::Error; use utils::{CustomResult, AppResult,CustomError};
@ -20,15 +20,15 @@ struct AppState {
} }
impl AppState { impl AppState {
async fn get_sql(&self) -> Result<relational::Database,Box<dyn Error>> { async fn get_sql(&self) -> CustomResult<relational::Database> {
self.db self.db
.lock() .lock()
.await .await
.clone() .clone()
.ok_or_else(|| "Database not initialized".into()) .ok_or_else(|| CustomError::from_str("Database not initialized"))
} }
async fn link_sql(&self, config: config::SqlConfig) -> Result<,Box<dyn Error>> { async fn link_sql(&self, config: config::SqlConfig) -> Result<(),CustomError> {
let database = relational::Database::link(&config) let database = relational::Database::link(&config)
.await?; .await?;
*self.db.lock().await = Some(database); *self.db.lock().await = Some(database);
@ -40,7 +40,7 @@ impl AppState {
#[get("/system")] #[get("/system")]
async fn token_system(_state: &State<AppState>) -> Result<status::Custom<String>, status::Custom<String>> { async fn token_system(_state: &State<AppState>) -> AppResult<status::Custom<String>> {
let claims = auth::jwt::CustomClaims { let claims = auth::jwt::CustomClaims {
name: "system".into(), name: "system".into(),
}; };
@ -77,8 +77,7 @@ async fn rocket() -> _ {
} }
rocket_builder = rocket_builder rocket_builder = rocket_builder
.mount("/auth/token", rocket::routes![token_system]) .mount("/auth/token", rocket::routes![token_system]);
.mount("/", rocket::routes![routes::intsall::test]);
rocket_builder rocket_builder
} }

View File

@ -1,82 +1,98 @@
use serde::{Deserialize,Serialize};
use crate::{config,utils};
use crate::database::relational;
use crate::AppState;
use rocket::{
post,
http::Status,
response::status,
serde::json::Json,
State,
};
use crate::routes::person;
use crate::auth; use crate::auth;
use crate::database::relational;
use crate::routes::person;
use crate::utils::AppResult;
use crate::AppState;
use crate::{config, utils};
use chrono::Duration; use chrono::Duration;
use rocket::{http::Status, post, response::status, serde::json::Json, State};
use serde::{Deserialize, Serialize};
#[derive(Deserialize, Serialize)] #[derive(Deserialize, Serialize)]
pub struct InstallData{ pub struct InstallData {
name:String, name: String,
email:String, email: String,
password:String, password: String,
sql_config: config::SqlConfig sql_config: config::SqlConfig,
} }
#[derive(Deserialize, Serialize)] #[derive(Deserialize, Serialize)]
pub struct InstallReplyData{ pub struct InstallReplyData {
token:String, token: String,
name:String, name: String,
password:String, password: String,
} }
#[post("/install", format = "application/json", data = "<data>")] #[post("/install", format = "application/json", data = "<data>")]
pub async fn install( pub async fn install(
data: Json<InstallData>, data: Json<InstallData>,
state: &State<AppState> state: &State<AppState>,
) -> Result<status::Custom<Json<InstallReplyData>>, status::Custom<String>> { ) -> AppResult<status::Custom<Json<InstallReplyData>>> {
let mut config = state.configure.lock().await; let mut config = state.configure.lock().await;
if config.info.install { if config.info.install {
return Err(status::Custom(Status::BadRequest, "Database already initialized".to_string())); return Err(status::Custom(
Status::BadRequest,
"Database already initialized".to_string(),
));
} }
let data=data.into_inner(); let data = data.into_inner();
relational::Database::initial_setup(data.sql_config.clone()) relational::Database::initial_setup(data.sql_config.clone())
.await .await
.map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?; .map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?;
auth::jwt::generate_key(); let _ = auth::jwt::generate_key();
config.info.install = true; config.info.install = true;
state.link_sql(data.sql_config.clone()).await?; state
let sql= state.get_sql().await?; .link_sql(data.sql_config.clone())
.await
.map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?;
let sql = state
.get_sql()
.await
.map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?;
let system_name = utils::generate_random_string(20);
let system_password = utils::generate_random_string(20);
let system_name=utils::generate_random_string(20); let _ = person::insert(
let system_password=utils::generate_random_string(20); &sql,
person::RegisterData {
let _ = person::insert(&sql,person::RegisterData{ name: data.name.clone(), email: data.email, password:data.password }).await?; name: data.name.clone(),
let _ = person::insert(&sql,person::RegisterData{ name: system_name.clone(), email: String::from("author@lsy22.com"), password:system_name.clone() }).await?; email: data.email,
password: data.password,
},
)
.await
.map_err(|e| status::Custom(Status::InternalServerError, e.to_string()));
let _ = person::insert(
&sql,
person::RegisterData {
name: system_name.clone(),
email: String::from("author@lsy22.com"),
password: system_name.clone(),
},
)
.await
.map_err(|e| status::Custom(Status::InternalServerError, e.to_string()));
let token = auth::jwt::generate_jwt( let token = auth::jwt::generate_jwt(
auth::jwt::CustomClaims{name:data.name.clone()}, auth::jwt::CustomClaims {
Duration::days(7) name: data.name.clone(),
).map_err(|e| status::Custom(Status::Unauthorized, e.to_string()))?; },
Duration::days(7),
)
.map_err(|e| status::Custom(Status::Unauthorized, e.to_string()))?;
config::Config::write(config.clone()) config::Config::write(config.clone())
.map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?; .map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?;
Ok( Ok(status::Custom(
status::Custom( Status::Ok,
Status::Ok, Json(InstallReplyData {
Json(InstallReplyData{ token: token,
token:token, name: system_name,
name: system_name, password: system_password,
password: system_password }),
} ))
) }
)
)
}

View File

@ -2,7 +2,9 @@ pub mod intsall;
pub mod person; pub mod person;
use rocket::routes; use rocket::routes;
// pub fn create_routes() -> Vec<rocket::Route> { pub fn create_routes() -> Vec<rocket::Route> {
routes![
// } intsall::install,
]
}

View File

@ -9,8 +9,8 @@ use rocket::{
State, State,
}; };
use std::collections::HashMap; use std::collections::HashMap;
use bcrypt::{hash, verify, DEFAULT_COST}; use bcrypt::{hash, DEFAULT_COST};
use crate::utils::CustomError; use crate::utils::CustomResult;
@ -26,7 +26,7 @@ pub struct RegisterData{
pub password:String pub password:String
} }
pub async fn insert(sql:&relational::Database,data:RegisterData) -> Result<(),CustomError>{ pub async fn insert(sql:&relational::Database,data:RegisterData) -> CustomResult<()>{
let hashed_password = hash(data.password, DEFAULT_COST).expect("Failed to hash password"); let hashed_password = hash(data.password, DEFAULT_COST).expect("Failed to hash password");

View File

@ -1,5 +1,5 @@
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use rocket::response::status;
pub fn generate_random_string(length: usize) -> String { pub fn generate_random_string(length: usize) -> String {
let charset = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; let charset = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
@ -8,7 +8,7 @@ pub fn generate_random_string(length: usize) -> String {
.map(|_| *charset.choose(&mut rng).unwrap() as char) .map(|_| *charset.choose(&mut rng).unwrap() as char)
.collect() .collect()
} }
#[derive(Debug)]
pub struct CustomError(String); pub struct CustomError(String);
impl std::fmt::Display for CustomError { impl std::fmt::Display for CustomError {
@ -17,7 +17,6 @@ impl std::fmt::Display for CustomError {
} }
} }
impl<T> From<T> for CustomError impl<T> From<T> for CustomError
where where
T: std::error::Error + Send + 'static, T: std::error::Error + Send + 'static,
@ -32,3 +31,7 @@ impl CustomError {
CustomError(error.to_string()) CustomError(error.to_string())
} }
} }
pub type CustomResult<T> = Result<T, CustomError>;
pub type AppResult<T> = Result<T, status::Custom<String>>;