后端:优化代码结构,改进数据库查询构建,增强数据库返回数据

This commit is contained in:
lsy 2024-11-25 22:43:24 +08:00
parent b1854b4fb8
commit 2c1923da07
11 changed files with 215 additions and 200 deletions

View File

@ -1,16 +1,12 @@
use crate::error::CustomErrorInto;
use crate::error::CustomResult;
use crate::error::{CustomErrorInto, CustomResult};
use bcrypt::{hash, verify, DEFAULT_COST};
pub fn generate_hash(s: &str) -> CustomResult<String> {
let hashed = hash(s, DEFAULT_COST)?;
Ok(hashed)
Ok(hash(s, DEFAULT_COST)?)
}
pub fn verify_hash(s: &str, hash: &str) -> CustomResult<()> {
let is_valid = verify(s, hash)?;
if !is_valid {
return Err("密码无效".into_custom_error());
}
Ok(())
verify(s, hash)?
.then_some(())
.ok_or_else(|| "密码无效".into_custom_error())
}

View File

@ -4,9 +4,7 @@ use ed25519_dalek::{SigningKey, VerifyingKey};
use jwt_compact::{alg::Ed25519, AlgorithmExt, Header, TimeOptions, Token, UntrustedToken};
use rand::{RngCore, SeedableRng};
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::Write;
use std::{env, fs};
use std::{env, fs, path::PathBuf};
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CustomClaims {
@ -19,73 +17,74 @@ pub enum SecretKey {
}
impl SecretKey {
fn as_string(&self) -> String {
const fn as_str(&self) -> &'static str {
match self {
Self::Signing => String::from("signing"),
Self::Verifying => String::from("verifying"),
Self::Signing => "signing",
Self::Verifying => "verifying",
}
}
}
fn get_key_path(key_type: &SecretKey) -> CustomResult<PathBuf> {
Ok(env::current_dir()?
.join("assets")
.join("key")
.join(key_type.as_str()))
}
pub fn generate_key() -> CustomResult<()> {
let mut csprng = rand::rngs::StdRng::from_entropy();
let mut private_key_bytes = [0u8; 32];
csprng.fill_bytes(&mut private_key_bytes);
let signing_key = SigningKey::from_bytes(&private_key_bytes);
let verifying_key = signing_key.verifying_key();
let base_path = env::current_dir()?.join("assets").join("key");
let base_path = get_key_path(&SecretKey::Signing)?
.parent()
.unwrap()
.to_path_buf();
fs::create_dir_all(&base_path)?;
File::create(base_path.join(SecretKey::Signing.as_string()))?
.write_all(signing_key.as_bytes())?;
File::create(base_path.join(SecretKey::Verifying.as_string()))?
.write_all(verifying_key.as_bytes())?;
fs::write(get_key_path(&SecretKey::Signing)?, signing_key.as_bytes())?;
fs::write(
get_key_path(&SecretKey::Verifying)?,
verifying_key.as_bytes(),
)?;
Ok(())
}
pub fn get_key(key_type: SecretKey) -> CustomResult<[u8; 32]> {
let path = env::current_dir()?
.join("assets")
.join("key")
.join(key_type.as_string());
let key_bytes = fs::read(path)?;
let key_bytes = fs::read(get_key_path(&key_type)?)?;
let mut key = [0u8; 32];
key.copy_from_slice(&key_bytes[..32]);
Ok(key)
}
pub fn generate_jwt(claims: CustomClaims, duration: Duration) -> CustomResult<String> {
let key_bytes = get_key(SecretKey::Signing)?;
let signing_key = SigningKey::from_bytes(&key_bytes);
let signing_key = SigningKey::from_bytes(&get_key(SecretKey::Signing)?);
let time_options = TimeOptions::new(Duration::seconds(0), Utc::now);
let claims = jwt_compact::Claims::new(claims)
.set_duration_and_issuance(&time_options, duration)
.set_not_before(Utc::now());
let header = Header::empty();
let token = Ed25519.token(&header, &claims, &signing_key)?;
Ok(token)
Ok(Ed25519.token(&Header::empty(), &claims, &signing_key)?)
}
pub fn validate_jwt(token: &str) -> CustomResult<CustomClaims> {
let key_bytes = get_key(SecretKey::Verifying)?;
let verifying = VerifyingKey::from_bytes(&key_bytes)?;
let token = UntrustedToken::new(token)?;
let token: Token<CustomClaims> = Ed25519.validator(&verifying).validate(&token)?;
let verifying = VerifyingKey::from_bytes(&get_key(SecretKey::Verifying)?)?;
let time_options = TimeOptions::new(Duration::seconds(0), Utc::now);
let token: Token<CustomClaims> = Ed25519
.validator(&verifying)
.validate(&UntrustedToken::new(token)?)?;
token
.claims()
.validate_expiration(&time_options)?
.validate_maturity(&time_options)?;
let claims = token.claims().custom.clone();
Ok(claims)
Ok(token.claims().custom.clone())
}

View File

@ -11,7 +11,6 @@ pub struct Config {
pub sql_config: SqlConfig,
}
impl Default for Config {
fn default() -> Self {
Self {

View File

@ -77,7 +77,7 @@ impl TextValidator {
let max_length = self
.level_max_lengths
.get(&level)
.ok_or( "Invalid validation level".into_custom_error())?;
.ok_or("Invalid validation level".into_custom_error())?;
if text.len() > *max_length {
return Err("Text exceeds maximum length".into_custom_error());
@ -499,7 +499,7 @@ impl QueryBuilder {
for condition in conditions {
let (sql, mut condition_params) =
self.build_where_clause_with_index(condition, param_index)?;
param_index += condition_params.len(); // 更新参数索引
param_index += condition_params.len();
parts.push(sql);
params.append(&mut condition_params);
}

View File

@ -2,8 +2,7 @@ mod postgresql;
use crate::config;
use crate::error::{CustomErrorInto, CustomResult};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};
pub mod builder;
#[async_trait]
@ -14,7 +13,7 @@ pub trait DatabaseTrait: Send + Sync {
async fn execute_query<'a>(
&'a self,
builder: &builder::QueryBuilder,
) -> CustomResult<Vec<HashMap<String, String>>>;
) -> CustomResult<Vec<HashMap<String, serde_json::Value>>>;
async fn initialization(database: config::SqlConfig) -> CustomResult<()>
where
Self: Sized;

View File

@ -3,7 +3,8 @@ use crate::config;
use crate::error::CustomErrorInto;
use crate::error::CustomResult;
use async_trait::async_trait;
use sqlx::{Column, Executor, PgPool, Row};
use serde_json::Value;
use sqlx::{Column, Executor, PgPool, Row, TypeInfo};
use std::collections::HashMap;
use std::{env, fs};
#[derive(Clone)]
@ -64,7 +65,7 @@ impl DatabaseTrait for Postgresql {
async fn execute_query<'a>(
&'a self,
builder: &builder::QueryBuilder,
) -> CustomResult<Vec<HashMap<String, String>>> {
) -> CustomResult<Vec<HashMap<String, Value>>> {
let (query, values) = builder.build()?;
let mut sqlx_query = sqlx::query(&query);
@ -73,22 +74,32 @@ impl DatabaseTrait for Postgresql {
sqlx_query = sqlx_query.bind(value.to_sql_string()?);
}
let rows = sqlx_query.fetch_all(&self.pool).await.map_err(|e| {
let (sql, params) = builder.build().unwrap();
format!("Err:{}\n,SQL: {}\nParams: {:?}", e.to_string(), sql, params)
.into_custom_error()
})?;
let rows = sqlx_query.fetch_all(&self.pool).await?;
let mut results = Vec::new();
for row in rows {
let mut map = HashMap::new();
for column in row.columns() {
let value: String = row.try_get(column.name()).unwrap_or_default();
map.insert(column.name().to_string(), value);
}
results.push(map);
}
Ok(results)
Ok(rows
.into_iter()
.map(|row| {
row.columns()
.iter()
.map(|col| {
let value = match col.type_info().name() {
"INT4" | "INT8" => Value::Number(
row.try_get::<i64, _>(col.name()).unwrap_or_default().into(),
),
"FLOAT4" | "FLOAT8" => Value::Number(
serde_json::Number::from_f64(
row.try_get::<f64, _>(col.name()).unwrap_or(0.0),
)
.unwrap_or_else(|| 0.into()),
),
"BOOL" => Value::Bool(row.try_get(col.name()).unwrap_or_default()),
"JSON" | "JSONB" => row.try_get(col.name()).unwrap_or(Value::Null),
_ => Value::String(row.try_get(col.name()).unwrap_or_default()),
};
(col.name().to_string(), value)
})
.collect()
})
.collect())
}
}

View File

@ -1,14 +1,15 @@
mod auth;
mod config;
mod database;
mod error;
mod routes;
mod utils;
use database::relational;
use error::{CustomErrorInto, CustomResult};
use rocket::Shutdown;
use std::sync::Arc;
use tokio::sync::Mutex;
mod error;
use error::{CustomErrorInto, CustomResult};
pub struct AppState {
db: Arc<Mutex<Option<relational::Database>>>,
@ -30,12 +31,11 @@ impl AppState {
.lock()
.await
.clone()
.ok_or("数据库未连接".into_custom_error())
.ok_or_else(|| "数据库未连接".into_custom_error())
}
pub async fn sql_link(&self, config: &config::SqlConfig) -> CustomResult<()> {
let database = relational::Database::link(config).await?;
*self.db.lock().await = Some(database);
*self.db.lock().await = Some(relational::Database::link(config).await?);
Ok(())
}
@ -45,14 +45,12 @@ impl AppState {
pub async fn trigger_restart(&self) -> CustomResult<()> {
*self.restart_progress.lock().await = true;
self.shutdown
.lock()
.await
.take()
.ok_or("未能获取rocket的shutdown".into_custom_error())?
.ok_or_else(|| "未能获取rocket的shutdown".into_custom_error())?
.notify();
Ok(())
}
}
@ -60,39 +58,39 @@ impl AppState {
#[rocket::main]
async fn main() -> CustomResult<()> {
let config = config::Config::read().unwrap_or_default();
let state = Arc::new(AppState::new());
let rocket_config = rocket::Config::figment()
.merge(("address", config.address.clone()))
.merge(("address", config.address))
.merge(("port", config.port));
let state = AppState::new();
let state = Arc::new(state);
let mut rocket_builder = rocket::build()
.configure(rocket_config)
.manage(state.clone());
let rocket_builder = rocket::build().configure(rocket_config).manage(state.clone());
let rocket_builder = if !config.info.install {
rocket_builder.mount("/", rocket::routes![routes::install::install])
if !config.info.install {
rocket_builder = rocket_builder.mount("/", rocket::routes![routes::install::install]);
} else {
state.sql_link(&config.sql_config).await?;
rocket_builder
rocket_builder = rocket_builder
.mount("/auth/token", routes::jwt_routes())
.mount("/config", routes::configure_routes())
};
.mount("/config", routes::configure_routes());
}
let rocket = rocket_builder.ignite().await?;
rocket
.state::<Arc<AppState>>()
.ok_or("未能获取AppState".into_custom_error())?
.ok_or_else(|| "未能获取AppState".into_custom_error())?
.set_shutdown(rocket.shutdown())
.await;
rocket.launch().await?;
let restart_progress = *state.restart_progress.lock().await;
if restart_progress {
let current_exe = std::env::current_exe()?;
let _ = std::process::Command::new(current_exe).spawn();
if *state.restart_progress.lock().await {
if let Ok(current_exe) = std::env::current_exe() {
let _ = std::process::Command::new(current_exe).spawn();
}
}
std::process::exit(0);
}

View File

@ -23,81 +23,74 @@ pub async fn token_system(
state: &State<Arc<AppState>>,
data: Json<TokenSystemData>,
) -> AppResult<String> {
let name_condition = builder::Condition::new(
"person_name".to_string(),
builder::Operator::Eq,
Some(builder::SafeValue::Text(
data.name.to_string(),
builder::ValidationLevel::Relaxed,
)),
)
.into_app_result()?;
let email_condition = builder::Condition::new(
"person_email".to_string(),
builder::Operator::Eq,
Some(builder::SafeValue::Text(
"author@lsy22.com".to_string(),
builder::ValidationLevel::Relaxed,
)),
)
.into_app_result()?;
let level_condition = builder::Condition::new(
"person_level".to_string(),
builder::Operator::Eq,
Some(builder::SafeValue::Enum(
"administrators".to_string(),
"privilege_level".to_string(),
builder::ValidationLevel::Standard,
)),
)
.into_app_result()?;
let where_clause = builder::WhereClause::And(vec![
builder::WhereClause::Condition(name_condition),
builder::WhereClause::Condition(email_condition),
builder::WhereClause::Condition(level_condition),
]);
let mut builder =
builder::QueryBuilder::new(builder::SqlOperation::Select, String::from("persons"))
builder::QueryBuilder::new(builder::SqlOperation::Select, "persons".to_string())
.into_app_result()?;
let builder = builder
builder
.add_field("person_password".to_string())
.into_app_result()?;
.into_app_result()?
.add_condition(builder::WhereClause::And(vec![
builder::WhereClause::Condition(
builder::Condition::new(
"person_name".to_string(),
builder::Operator::Eq,
Some(builder::SafeValue::Text(
data.name.clone(),
builder::ValidationLevel::Relaxed,
)),
)
.into_app_result()?,
),
builder::WhereClause::Condition(
builder::Condition::new(
"person_email".to_string(),
builder::Operator::Eq,
Some(builder::SafeValue::Text(
"author@lsy22.com".into(),
builder::ValidationLevel::Relaxed,
)),
)
.into_app_result()?,
),
builder::WhereClause::Condition(
builder::Condition::new(
"person_level".to_string(),
builder::Operator::Eq,
Some(builder::SafeValue::Enum(
"administrators".into(),
"privilege_level".into(),
builder::ValidationLevel::Standard,
)),
)
.into_app_result()?,
),
]));
let sql_builder = builder.add_condition(where_clause);
let values = state
.sql_get()
.await
.into_app_result()?
.get_db()
.execute_query(&sql_builder)
.execute_query(&builder)
.await
.into_app_result()?;
let password = values
.first()
.ok_or(status::Custom(
Status::NotFound,
String::from("该用户并非系统用户"),
))?
.get("person_password")
.ok_or(status::Custom(
Status::NotFound,
String::from("该用户密码丢失"),
))?;
.and_then(|row| row.get("person_password"))
.and_then(|val| val.as_str())
.ok_or_else(|| {
status::Custom(Status::NotFound, "Invalid system user or password".into())
})?;
auth::bcrypt::verify_hash(&data.password, password).map_err(|_| {
status::Custom(Status::Forbidden, String::from("密码错误"))
})?;
auth::bcrypt::verify_hash(&data.password, password)
.map_err(|_| status::Custom(Status::Forbidden, "Invalid password".into()))?;
let claims = auth::jwt::CustomClaims {
name: "system".into(),
};
let token = auth::jwt::generate_jwt(claims, Duration::minutes(1)).into_app_result()?;
Ok(token)
Ok(auth::jwt::generate_jwt(
auth::jwt::CustomClaims {
name: "system".into(),
},
Duration::minutes(1),
)
.into_app_result()?)
}

View File

@ -39,15 +39,28 @@ pub async fn get_configure(
comfig_type: String,
name: String,
) -> CustomResult<Json<Value>> {
let name_condition = builder::Condition::new(
"config_name".to_string(),
builder::Operator::Eq,
Some(builder::SafeValue::Text(
format!("{}_{}", comfig_type, name),
builder::ValidationLevel::Strict,
)),
)?;
println!(
"Searching for config_name: {}",
format!("{}_{}", comfig_type, name)
);
let where_clause = builder::WhereClause::Condition(name_condition);
let mut sql_builder =
builder::QueryBuilder::new(builder::SqlOperation::Select, "config".to_string())?;
sql_builder.set_value(
"config_name".to_string(),
builder::SafeValue::Text(
format!("{}_{}", comfig_type, name).to_string(),
builder::ValidationLevel::Strict,
),
)?;
sql_builder
.add_condition(where_clause)
.add_field("config_data".to_string())?;
let result = sql.get_db().execute_query(&sql_builder).await?;
Ok(Json(json!(result)))
}
@ -76,9 +89,12 @@ pub async fn insert_configure(
}
#[get("/system")]
pub async fn system_config_get(state: &State<Arc<AppState>>,token: SystemToken) -> AppResult<Json<Value>> {
pub async fn system_config_get(
state: &State<Arc<AppState>>,
_token: SystemToken,
) -> AppResult<Json<Value>> {
let sql = state.sql_get().await.into_app_result()?;
let configure = get_configure(&sql, "system".to_string(), "configure".to_string())
let configure = get_configure(&sql, "system".to_string(), "config".to_string())
.await
.into_app_result()?;
Ok(configure)

View File

@ -1,14 +1,14 @@
use super::{configure, person};
use crate::auth;
use crate::database::relational;
use crate::error::{AppResult, AppResultInto};
use super::{person, configure};
use crate::AppState;
use crate::{config, utils};
use chrono::Duration;
use rocket::{http::Status, post, response::status, serde::json::Json, State};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use serde_json::json;
use std::sync::Arc;
#[derive(Deserialize, Serialize)]
pub struct InstallData {
@ -37,26 +37,26 @@ pub async fn install(
));
}
config.info.install = true;
config.sql_config = data.sql_config.clone();
let data = data.into_inner();
let sql = {
config.info.install = true;
config.sql_config = data.sql_config.clone();
relational::Database::initial_setup(data.sql_config.clone())
.await
.into_app_result()?;
relational::Database::initial_setup(data.sql_config.clone())
.await
.into_app_result()?;
auth::jwt::generate_key().into_app_result()?;
let _ = auth::jwt::generate_key();
state.sql_link(&data.sql_config).await.into_app_result()?;
state.sql_get().await.into_app_result()?
};
let system_credentials = (
utils::generate_random_string(20),
utils::generate_random_string(20),
);
state.sql_link(&data.sql_config).await.into_app_result()?;
let sql = state.sql_get().await.into_app_result()?;
let system_name = utils::generate_random_string(20);
let system_password = utils::generate_random_string(20);
let _ = person::insert(
person::insert(
&sql,
person::RegisterData {
name: data.name.clone(),
@ -68,41 +68,45 @@ pub async fn install(
.await
.into_app_result()?;
let _ = person::insert(
person::insert(
&sql,
person::RegisterData {
name: system_name.clone(),
email: String::from("author@lsy22.com"),
password: system_password.clone(),
name: system_credentials.0.clone(),
email: "author@lsy22.com".to_string(),
password: system_credentials.1.clone(),
level: "administrators".to_string(),
},
)
.await
.into_app_result()?;
let mut system_configure = configure::SystemConfigure::default();
system_configure.author_name = data.name.clone();
configure::insert_configure(&sql, "system".to_string(), "configure".to_string(), Json(json!(system_configure))).await.into_app_result()?;
configure::insert_configure(
&sql,
"system".to_string(),
"config".to_string(),
Json(json!(configure::SystemConfigure {
author_name: data.name.clone(),
..configure::SystemConfigure::default()
})),
)
.await
.into_app_result()?;
let token = auth::jwt::generate_jwt(
auth::jwt::CustomClaims {
name: data.name.clone(),
},
auth::jwt::CustomClaims { name: data.name },
Duration::days(7),
)
.into_app_result()?;
config::Config::write(config.clone()).into_app_result()?;
config::Config::write(config).into_app_result()?;
state.trigger_restart().await.into_app_result()?;
Ok(status::Custom(
Status::Ok,
Json(InstallReplyData {
token: token,
name: system_name,
password: system_password,
token,
name: system_credentials.0,
password: system_credentials.1,
}),
))
}

View File

@ -23,26 +23,26 @@ pub struct RegisterData {
pub async fn insert(sql: &relational::Database, data: RegisterData) -> CustomResult<()> {
let mut builder =
builder::QueryBuilder::new(builder::SqlOperation::Insert, "persons".to_string())?;
let password_hash = auth::bcrypt::generate_hash(&data.password)?;
builder
.set_value(
"person_name".to_string(),
builder::SafeValue::Text(data.name.to_string(), builder::ValidationLevel::Relaxed),
builder::SafeValue::Text(data.name, builder::ValidationLevel::Relaxed),
)?
.set_value(
"person_email".to_string(),
builder::SafeValue::Text(data.email.to_string(), builder::ValidationLevel::Relaxed),
builder::SafeValue::Text(data.email, builder::ValidationLevel::Relaxed),
)?
.set_value(
"person_password".to_string(),
builder::SafeValue::Text(password_hash, builder::ValidationLevel::Relaxed),
builder::SafeValue::Text(
bcrypt::generate_hash(&data.password)?,
builder::ValidationLevel::Relaxed,
),
)?
.set_value(
"person_level".to_string(),
builder::SafeValue::Enum(
data.level.to_string(),
data.level,
"privilege_level".to_string(),
builder::ValidationLevel::Standard,
),