diff --git a/backend/Cargo.toml b/backend/Cargo.toml index 3433ed7..eec9e57 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -3,13 +3,19 @@ name = "echoes" version = "0.1.0" edition = "2021" + [dependencies] rocket = { version = "0.5", features = ["json"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" toml = "0.8.19" tokio = { version = "1", features = ["full"] } -sqlx = { version = "0.8.2", features = ["runtime-tokio-native-tls", "postgres"] } +sqlx = { version = "0.8.2", features = [ + "runtime-tokio-native-tls", + "postgres", + "mysql", + "sqlite" +] } async-trait = "0.1.83" jwt-compact = { version = "0.8.0", features = ["ed25519-dalek"] } ed25519-dalek = "2.1.1" @@ -17,5 +23,4 @@ rand = "0.8.5" chrono = "0.4" regex = "1.11.1" bcrypt = "0.16" -uuid = { version = "1.11.0", features = ["v4", "serde"] } hex = "0.4.3" \ No newline at end of file diff --git a/backend/src/api/auth/token.rs b/backend/src/api/auth/token.rs index b93fe5f..977907e 100644 --- a/backend/src/api/auth/token.rs +++ b/backend/src/api/auth/token.rs @@ -23,8 +23,12 @@ pub async fn token_system( state: &State>, data: Json, ) -> AppResult { + let sql = state + .sql_get() + .await + .into_app_result()?; let mut builder = - builder::QueryBuilder::new(builder::SqlOperation::Select, "users".to_string()) + builder::QueryBuilder::new(builder::SqlOperation::Select, sql.table_name("users"), sql.get_type()) .into_app_result()?; builder .add_field("password_hash".to_string()) @@ -56,9 +60,8 @@ pub async fn token_system( builder::Condition::new( "role".to_string(), builder::Operator::Eq, - Some(builder::SafeValue::Enum( + Some(builder::SafeValue::Text( "administrator".into(), - "user_role".into(), builder::ValidationLevel::Standard, )), ) @@ -66,10 +69,7 @@ pub async fn token_system( ), ])); - let values = state - .sql_get() - .await - .into_app_result()? + let values = sql .get_db() .execute_query(&builder) .await @@ -83,6 +83,8 @@ pub async fn token_system( status::Custom(Status::NotFound, "Invalid system user or password".into()) })?; + println!("{}\n{}",&data.password,password.clone()); + security::bcrypt::verify_hash(&data.password, password) .map_err(|_| status::Custom(Status::Forbidden, "Invalid password".into()))?; diff --git a/backend/src/api/settings.rs b/backend/src/api/settings.rs index ba81eb7..6df1363 100644 --- a/backend/src/api/settings.rs +++ b/backend/src/api/settings.rs @@ -2,6 +2,7 @@ use super::SystemToken; use crate::storage::{sql, sql::builder}; use crate::common::error::{AppResult, AppResultInto, CustomResult}; use crate::AppState; +use rocket::data; use rocket::{ get, http::Status, @@ -40,7 +41,7 @@ pub async fn get_setting( name: String, ) -> CustomResult> { let name_condition = builder::Condition::new( - "key".to_string(), + "name".to_string(), builder::Operator::Eq, Some(builder::SafeValue::Text( format!("{}_{}", comfig_type, name), @@ -51,7 +52,7 @@ pub async fn get_setting( let where_clause = builder::WhereClause::Condition(name_condition); let mut sql_builder = - builder::QueryBuilder::new(builder::SqlOperation::Select, "settings".to_string())?; + builder::QueryBuilder::new(builder::SqlOperation::Select, sql.table_name("settings"),sql.get_type())?; sql_builder .add_condition(where_clause) @@ -69,9 +70,9 @@ pub async fn insert_setting( data: Json, ) -> CustomResult<()> { let mut builder = - builder::QueryBuilder::new(builder::SqlOperation::Insert, "settings".to_string())?; + builder::QueryBuilder::new(builder::SqlOperation::Insert, sql.table_name("settings"),sql.get_type())?; builder.set_value( - "key".to_string(), + "name".to_string(), builder::SafeValue::Text( format!("{}_{}", comfig_type, name).to_string(), builder::ValidationLevel::Strict, @@ -79,7 +80,7 @@ pub async fn insert_setting( )?; builder.set_value( "data".to_string(), - builder::SafeValue::Json(data.into_inner()), + builder::SafeValue::Text(data.to_string(),builder::ValidationLevel::Relaxed), )?; sql.get_db().execute_query(&builder).await?; Ok(()) @@ -91,7 +92,7 @@ pub async fn system_config_get( _token: SystemToken, ) -> AppResult> { let sql = state.sql_get().await.into_app_result()?; - let settings = get_setting(&sql, "system".to_string(), "settings".to_string()) + let settings = get_setting(&sql, "system".to_string(), sql.table_name("settings")) .await .into_app_result()?; Ok(settings) diff --git a/backend/src/api/setup.rs b/backend/src/api/setup.rs index 5d4028c..ac1b25c 100644 --- a/backend/src/api/setup.rs +++ b/backend/src/api/setup.rs @@ -1,35 +1,22 @@ use super::{settings, users}; +use crate::common::config; +use crate::common::error::{AppResult, AppResultInto}; +use crate::common::helpers; use crate::security; use crate::storage::sql; -use crate::common::error::{AppResult, AppResultInto}; use crate::AppState; -use crate::common::config; -use crate::common::helpers; use chrono::Duration; +use rocket::data; use rocket::{http::Status, post, response::status, serde::json::Json, State}; use serde::{Deserialize, Serialize}; use serde_json::json; use std::sync::Arc; -#[derive(Deserialize, Serialize,Debug)] -pub struct InstallData { - username: String, - email: String, - password: String, - sql_config: config::SqlConfig, -} -#[derive(Deserialize, Serialize,Debug)] -pub struct InstallReplyData { - token: String, - username: String, - password: String, -} - -#[post("/sql", format = "application/json", data = "")] -pub async fn steup_sql( - data: Json, +#[post("/sql", format = "application/json", data = "")] +pub async fn setup_sql( + sql_config: Json, state: &State>, -) -> AppResult>> { +) -> AppResult { let mut config = config::Config::read().unwrap_or_default(); if config.init.sql { return Err(status::Custom( @@ -37,19 +24,57 @@ pub async fn steup_sql( "Database already initialized".to_string(), )); } + + let sql_config = sql_config.into_inner(); + + config.init.sql = true; + config.sql_config = sql_config.clone(); + sql::Database::initial_setup(sql_config.clone()) + .await + .into_app_result()?; + + config::Config::write(config).into_app_result()?; + state.trigger_restart().await.into_app_result()?; + Ok("Database installation successful".to_string()) +} + +#[derive(Deserialize, Serialize, Debug)] +pub struct StepAccountData { + username: String, + email: String, + password: String, +} + +#[derive(Deserialize, Serialize, Debug)] +pub struct InstallReplyData { + token: String, + username: String, + password: String, +} + + +#[post("/administrator", format = "application/json", data = "")] +pub async fn setup_account( + data: Json, + state: &State>, +) -> AppResult>> { + let mut config = config::Config::read().unwrap_or_default(); + if config.init.administrator { + return Err(status::Custom( + Status::BadRequest, + "Administrator user has been set".to_string(), + )); + } + + security::jwt::generate_key().into_app_result()?; + let data = data.into_inner(); + let sql = { - config.init.sql = true; - config.sql_config = data.sql_config.clone(); - sql::Database::initial_setup(data.sql_config.clone()) - .await - .into_app_result()?; - security::jwt::generate_key().into_app_result()?; - state.sql_link(&data.sql_config).await.into_app_result()?; + state.sql_link(&config.sql_config).await.into_app_result()?; state.sql_get().await.into_app_result()? }; - let system_credentials = ( helpers::generate_random_string(20), helpers::generate_random_string(20), @@ -79,7 +104,6 @@ pub async fn steup_sql( .await .into_app_result()?; - settings::insert_setting( &sql, "system".to_string(), @@ -93,11 +117,13 @@ pub async fn steup_sql( .into_app_result()?; let token = security::jwt::generate_jwt( - security::jwt::CustomClaims { name: data.username }, + security::jwt::CustomClaims { + name: data.username, + }, Duration::days(7), ) .into_app_result()?; - + config.init.administrator=true; config::Config::write(config).into_app_result()?; state.trigger_restart().await.into_app_result()?; diff --git a/backend/src/api/users.rs b/backend/src/api/users.rs index 47f70d4..1b0d74b 100644 --- a/backend/src/api/users.rs +++ b/backend/src/api/users.rs @@ -20,37 +20,32 @@ pub struct RegisterData { pub role: String, } -pub async fn insert_user(sql: &sql::Database, data: RegisterData) -> CustomResult<()> { +pub async fn insert_user(sql: &sql::Database , data: RegisterData) -> CustomResult<()> { let role = match data.role.as_str() { "administrator" | "contributor" => data.role, _ => return Err("Invalid role. Must be either 'administrator' or 'contributor'".into_custom_error()), }; + let password_hash = bcrypt::generate_hash(&data.password)?; + let mut builder = - builder::QueryBuilder::new(builder::SqlOperation::Insert, "users".to_string())?; + builder::QueryBuilder::new(builder::SqlOperation::Insert, sql.table_name("users"), sql.get_type())?; builder .set_value( "username".to_string(), - builder::SafeValue::Text(data.username, builder::ValidationLevel::Relaxed), + builder::SafeValue::Text(data.username, builder::ValidationLevel::Standard), )? .set_value( "email".to_string(), - builder::SafeValue::Text(data.email, builder::ValidationLevel::Relaxed), + builder::SafeValue::Text(data.email, builder::ValidationLevel::Standard), )? .set_value( "password_hash".to_string(), - builder::SafeValue::Text( - bcrypt::generate_hash(&data.password)?, - builder::ValidationLevel::Relaxed, - ), + builder::SafeValue::Text(password_hash, builder::ValidationLevel::Relaxed), )? .set_value( "role".to_string(), - builder::SafeValue::Enum( - role, - "user_role".to_string(), - builder::ValidationLevel::Standard, - ), + builder::SafeValue::Text(role, builder::ValidationLevel::Strict), )?; sql.get_db().execute_query(&builder).await?; diff --git a/backend/src/common/config.rs b/backend/src/common/config.rs index b56fe78..c55bc22 100644 --- a/backend/src/common/config.rs +++ b/backend/src/common/config.rs @@ -47,17 +47,19 @@ pub struct SqlConfig { pub user: String, pub password: String, pub db_name: String, + pub db_prefix:String, } impl Default for SqlConfig { fn default() -> Self { Self { - db_type: "postgresql".to_string(), - address: "localhost".to_string(), - port: 5432, - user: "postgres".to_string(), - password: "postgres".to_string(), + db_type: "sqllite".to_string(), + address: "".to_string(), + port: 0, + user: "".to_string(), + password: "".to_string(), db_name: "echoes".to_string(), + db_prefix: "echoes_".to_string(), } } } diff --git a/backend/src/main.rs b/backend/src/main.rs index ee70f49..8c09594 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -3,12 +3,12 @@ mod common; mod storage; mod api; -use storage::sql; +use crate::common::config; use common::error::{CustomErrorInto, CustomResult}; use rocket::Shutdown; use std::sync::Arc; +use storage::sql; use tokio::sync::Mutex; -use crate::common::config; pub struct AppState { db: Arc>>, shutdown: Arc>>, @@ -25,11 +25,7 @@ impl AppState { } pub async fn sql_get(&self) -> CustomResult { - self.db - .lock() - .await - .clone() - .ok_or_else(|| "数据库未连接".into_custom_error()) + self.db.lock().await.clone().ok_or_else(|| "数据库未连接".into_custom_error()) } pub async fn sql_link(&self, config: &config::SqlConfig) -> CustomResult<()> { @@ -37,57 +33,53 @@ impl AppState { Ok(()) } + pub async fn set_shutdown(&self, shutdown: Shutdown) { *self.shutdown.lock().await = Some(shutdown); } pub async fn trigger_restart(&self) -> CustomResult<()> { *self.restart_progress.lock().await = true; - self.shutdown - .lock() - .await - .take() - .ok_or_else(|| "未能获取rocket的shutdown".into_custom_error())? - .notify(); + self.shutdown.lock().await.take().ok_or_else(|| "未能获取rocket的shutdown".into_custom_error())?.notify(); Ok(()) } } #[rocket::main] async fn main() -> CustomResult<()> { - let config = config::Config::read().unwrap_or_default(); + let config = config::Config::read().unwrap_or_else(|e| { + eprintln!("配置读取失败: {}", e); + config::Config::default() + }); let state = Arc::new(AppState::new()); - let rocket_config = rocket::Config::figment() - .merge(("address", config.address)) - .merge(("port", config.port)); + let rocket_config = rocket::Config::figment().merge(("address", config.address)).merge(("port", config.port)); - let mut rocket_builder = rocket::build() - .configure(rocket_config) - .manage(state.clone()); + let mut rocket_builder = rocket::build().configure(rocket_config).manage(state.clone()); - if !config.info.install { - rocket_builder = rocket_builder.mount("/", rocket::routes![api::setup::install]); + if !config.init.sql { + rocket_builder = rocket_builder.mount("/", rocket::routes![api::setup::setup_sql]); + } else if !config.init.administrator { + rocket_builder = rocket_builder.mount("/", rocket::routes![api::setup::setup_account]); } else { state.sql_link(&config.sql_config).await?; - rocket_builder = rocket_builder - .mount("/auth/token", api::jwt_routes()) - .mount("/config", api::configure_routes()); + rocket_builder = rocket_builder.mount("/auth/token", api::jwt_routes()).mount("/config", api::configure_routes()); } let rocket = rocket_builder.ignite().await?; - rocket - .state::>() - .ok_or_else(|| "未能获取AppState".into_custom_error())? - .set_shutdown(rocket.shutdown()) - .await; + rocket.state::>().ok_or_else(|| "未能获取AppState".into_custom_error())?.set_shutdown(rocket.shutdown()).await; rocket.launch().await?; if *state.restart_progress.lock().await { if let Ok(current_exe) = std::env::current_exe() { - let _ = std::process::Command::new(current_exe).spawn(); + match std::process::Command::new(current_exe).spawn() { + Ok(_) => println!("成功启动新进程"), + Err(e) => eprintln!("启动新进程失败: {}", e), + }; + } else { + eprintln!("获取当前可执行文件路径失败"); } } std::process::exit(0); diff --git a/backend/src/storage/sql/builder.rs b/backend/src/storage/sql/builder.rs index 46fae24..aedc3e0 100644 --- a/backend/src/storage/sql/builder.rs +++ b/backend/src/storage/sql/builder.rs @@ -2,16 +2,16 @@ use crate::common::error::{CustomErrorInto, CustomResult}; use chrono::{DateTime, Utc}; use regex::Regex; use serde::Serialize; -use serde_json::Value as JsonValue; use std::collections::HashMap; use std::hash::Hash; -use uuid::Uuid; +use crate::sql::schema::DatabaseType; #[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, Serialize)] pub enum ValidationLevel { Strict, Standard, Relaxed, + Raw, } #[derive(Debug, Clone)] @@ -28,6 +28,7 @@ impl Default for TextValidator { (ValidationLevel::Strict, 100), (ValidationLevel::Standard, 1000), (ValidationLevel::Relaxed, 100000), + (ValidationLevel::Raw, usize::MAX), ]); let level_allowed_chars = HashMap::from([ @@ -43,6 +44,7 @@ impl Default for TextValidator { '}', '@', '#', '$', '%', '^', '&', '*', '+', '=', '<', '>', '/', '\\', ], ), + (ValidationLevel::Raw, vec![]), ]); TextValidator { @@ -74,6 +76,9 @@ impl Default for TextValidator { impl TextValidator { pub fn validate(&self, text: &str, level: ValidationLevel) -> CustomResult<()> { + if level == ValidationLevel::Raw { + return self.validate_sql_patterns(text); + } let max_length = self .level_max_lengths .get(&level) @@ -140,6 +145,9 @@ impl TextValidator { pub fn validate_strict(&self, text: &str) -> CustomResult<()> { self.validate(text, ValidationLevel::Strict) } + pub fn validate_raw(&self, text: &str) -> CustomResult<()> { + self.validate(text, ValidationLevel::Raw) + } pub fn sanitize(&self, text: &str) -> CustomResult { self.validate_relaxed(text)?; @@ -155,77 +163,31 @@ pub enum SafeValue { Float(f64), Text(String, ValidationLevel), DateTime(DateTime), - Uuid(Uuid), - Binary(Vec), - Array(Vec), - Json(JsonValue), - Enum(String, String, ValidationLevel), +} + +impl std::fmt::Display for SafeValue { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + SafeValue::Null => write!(f, "NULL"), + SafeValue::Bool(b) => write!(f, "{}", b), + SafeValue::Integer(i) => write!(f, "{}", i), + SafeValue::Float(f_val) => write!(f, "{}", f_val), + SafeValue::Text(s, _) => write!(f, "{}", s), + SafeValue::DateTime(dt) => write!(f, "{}", dt.to_rfc3339()), + } + } } impl SafeValue { - pub fn from_json(value: JsonValue, level: ValidationLevel) -> CustomResult { - match value { - JsonValue::Null => Ok(SafeValue::Null), - JsonValue::Bool(b) => Ok(SafeValue::Bool(b)), - JsonValue::Number(n) => { - if let Some(i) = n.as_i64() { - Ok(SafeValue::Integer(i)) - } else if let Some(f) = n.as_f64() { - Ok(SafeValue::Float(f)) - } else { - Err("Invalid number format".into_custom_error()) - } - } - JsonValue::String(s) => { - TextValidator::default().validate(&s, level)?; - Ok(SafeValue::Text(s, level)) - } - JsonValue::Array(arr) => Ok(SafeValue::Array( - arr.into_iter() - .map(|item| SafeValue::from_json(item, level)) - .collect::>>()?, - )), - JsonValue::Object(_) => { - Self::validate_json_structure(&value, level)?; - Ok(SafeValue::Json(value)) - } - } - } - - fn validate_json_structure(value: &JsonValue, level: ValidationLevel) -> CustomResult<()> { - let validator = TextValidator::default(); - match value { - JsonValue::Object(map) => { - for (key, val) in map { - validator.validate(key, level)?; - Self::validate_json_structure(val, level)?; - } - } - JsonValue::Array(arr) => { - arr.iter() - .try_for_each(|item| Self::validate_json_structure(item, level))?; - } - JsonValue::String(s) => validator.validate(s, level)?, - _ => {} - } - Ok(()) - } fn get_sql_type(&self) -> CustomResult { let sql_type = match self { SafeValue::Null => "NULL", - SafeValue::Bool(_) => "boolean", - SafeValue::Integer(_) => "bigint", - SafeValue::Float(_) => "double precision", - SafeValue::Text(_, _) => "text", - SafeValue::DateTime(_) => "timestamp with time zone", - SafeValue::Uuid(_) => "uuid", - SafeValue::Binary(_) => "bytea", - SafeValue::Array(_) | SafeValue::Json(_) => "jsonb", - SafeValue::Enum(_, enum_type, level) => { - TextValidator::default().validate(enum_type, *level)?; - return Ok(enum_type.replace('\'', "''")); - } + SafeValue::Bool(_) => "BOOLEAN", + SafeValue::Integer(_) => "INTEGER", + SafeValue::Float(_) => "REAL", + SafeValue::Text(_, _) => "TEXT", + SafeValue::DateTime(_) => "TEXT", }; Ok(sql_type.to_string()) } @@ -233,37 +195,27 @@ impl SafeValue { pub fn to_sql_string(&self) -> CustomResult { match self { SafeValue::Null => Ok("NULL".to_string()), - SafeValue::Bool(b) => Ok(b.to_string()), + SafeValue::Bool(b) => Ok(if *b { "true" } else { "false" }.to_string()), SafeValue::Integer(i) => Ok(i.to_string()), SafeValue::Float(f) => Ok(f.to_string()), SafeValue::Text(s, level) => { TextValidator::default().validate(s, *level)?; - Ok(s.replace('\'', "''")) - } - SafeValue::DateTime(dt) => Ok(format!("'{}'", dt.to_rfc3339())), - SafeValue::Uuid(u) => Ok(format!("'{}'", u)), - SafeValue::Binary(b) => Ok(format!("'\\x{}'", hex::encode(b))), - SafeValue::Array(arr) => { - let values: CustomResult> = arr.iter().map(|v| v.to_sql_string()).collect(); - Ok(format!("ARRAY[{}]", values?.join(","))) - } - SafeValue::Json(j) => { - let json_str = serde_json::to_string(j)?; - TextValidator::default().validate(&json_str, ValidationLevel::Relaxed)?; - Ok(json_str.replace('\'', "''")) - } - SafeValue::Enum(s, _, level) => { - TextValidator::default().validate(s, *level)?; - Ok(s.to_string()) + Ok(format!("{}", s.replace('\'', "''"))) } + SafeValue::DateTime(dt) => Ok(format!("{}", dt.to_rfc3339())), } } - fn to_param_sql(&self, param_index: usize) -> CustomResult { + fn to_param_sql(&self, param_index: usize, db_type: DatabaseType) -> CustomResult { if matches!(self, SafeValue::Null) { - Ok("NULL".to_string()) - } else { - Ok(format!("${}::{}", param_index, self.get_sql_type()?)) + return Ok("NULL".to_string()); + } + + // 根据数据库类型返回不同的参数占位符 + match db_type { + DatabaseType::MySQL => Ok("?".to_string()), + DatabaseType::PostgreSQL => Ok(format!("${}", param_index)), + DatabaseType::SQLite => Ok("?".to_string()), } } } @@ -305,12 +257,10 @@ pub enum Operator { In, IsNull, IsNotNull, - JsonContains, - JsonExists, } impl Operator { - fn as_str(&self) -> &'static str { + pub fn as_str(&self) -> &'static str { match self { Operator::Eq => "=", Operator::Ne => "!=", @@ -322,17 +272,15 @@ impl Operator { Operator::In => "IN", Operator::IsNull => "IS NULL", Operator::IsNotNull => "IS NOT NULL", - Operator::JsonContains => "@>", - Operator::JsonExists => "?", } } } #[derive(Debug, Clone)] pub struct Condition { - field: Identifier, - operator: Operator, - value: Option, + pub field: Identifier, + pub operator: Operator, + pub value: Option, } impl Condition { @@ -362,10 +310,11 @@ pub struct QueryBuilder { order_by: Option, limit: Option, offset: Option, + db_type: DatabaseType, } impl QueryBuilder { - pub fn new(operation: SqlOperation, table: String) -> CustomResult { + pub fn new(operation: SqlOperation, table: String, db_type: DatabaseType) -> CustomResult { Ok(QueryBuilder { operation, table: Identifier::new(table)?, @@ -375,6 +324,7 @@ impl QueryBuilder { order_by: None, limit: None, offset: None, + db_type, }) } @@ -438,7 +388,7 @@ impl QueryBuilder { if matches!(value, SafeValue::Null) { placeholders.push("NULL".to_string()); } else { - placeholders.push(format!("${}::{}", params.len() + 1, value.get_sql_type()?)); + placeholders.push(value.to_param_sql(params.len() + 1, self.db_type)?); params.push(value.clone()); } } @@ -458,11 +408,13 @@ impl QueryBuilder { let mut updates = Vec::new(); for (field, value) in &self.values { - let set_sql = format!( - "{} = {}", - field.as_str(), - value.to_param_sql(params.len() + 1)? - ); + let placeholder = if matches!(value, SafeValue::Null) { + "NULL".to_string() + } else { + value.to_param_sql(params.len() + 1, self.db_type)? + }; + + let set_sql = format!("{} = {}", field.as_str(), placeholder); if !matches!(value, SafeValue::Null) { params.push(value.clone()); } @@ -542,11 +494,17 @@ impl QueryBuilder { ) -> CustomResult { match &condition.value { Some(value) => { + let placeholder = if matches!(value, SafeValue::Null) { + "NULL".to_string() + } else { + value.to_param_sql(param_index, self.db_type)? + }; + let sql = format!( "{} {} {}", condition.field.as_str(), condition.operator.as_str(), - value.to_param_sql(param_index)? + placeholder ); if !matches!(value, SafeValue::Null) { params.push(value.clone()); diff --git a/backend/src/storage/sql/mod.rs b/backend/src/storage/sql/mod.rs index f4e987c..db3df33 100644 --- a/backend/src/storage/sql/mod.rs +++ b/backend/src/storage/sql/mod.rs @@ -1,11 +1,16 @@ mod postgresql; +mod mysql; +mod sqllite; +pub mod builder; +mod schema; + use crate::config; use crate::common::error::{CustomErrorInto, CustomResult}; use async_trait::async_trait; use std::{collections::HashMap, sync::Arc}; -pub mod builder; +use schema::DatabaseType; -#[async_trait] +#[async_trait] pub trait DatabaseTrait: Send + Sync { async fn connect(database: &config::SqlConfig) -> CustomResult where @@ -22,6 +27,8 @@ pub trait DatabaseTrait: Send + Sync { #[derive(Clone)] pub struct Database { pub db: Arc>, + pub prefix: Arc, + pub db_type: Arc } impl Database { @@ -29,22 +36,45 @@ impl Database { &self.db } + pub fn get_prefix(&self) -> &str { + &self.prefix + } + + pub fn get_type(&self) -> DatabaseType { + match self.db_type.as_str() { + "postgresql" => DatabaseType::PostgreSQL, + "mysql" => DatabaseType::MySQL, + _ => DatabaseType::SQLite, + } + } + + pub fn table_name(&self, name: &str) -> String { + format!("{}{}", self.prefix, name) + } + pub async fn link(database: &config::SqlConfig) -> CustomResult { - let db = match database.db_type.as_str() { - "postgresql" => postgresql::Postgresql::connect(database).await?, + let db: Box = match database.db_type.as_str() { + "postgresql" => Box::new(postgresql::Postgresql::connect(database).await?), + "mysql" => Box::new(mysql::Mysql::connect(database).await?), + "sqllite" => Box::new(sqllite::Sqlite::connect(database).await?), _ => return Err("unknown database type".into_custom_error()), }; Ok(Self { - db: Arc::new(Box::new(db)), + db: Arc::new(db), + prefix: Arc::new(database.db_prefix.clone()), + db_type: Arc::new(database.db_type.clone()) }) } pub async fn initial_setup(database: config::SqlConfig) -> CustomResult<()> { match database.db_type.as_str() { "postgresql" => postgresql::Postgresql::initialization(database).await?, + "mysql" => mysql::Mysql::initialization(database).await?, + "sqllite" => sqllite::Sqlite::initialization(database).await?, _ => return Err("unknown database type".into_custom_error()), }; Ok(()) } + } diff --git a/backend/src/storage/sql/mysql.rs b/backend/src/storage/sql/mysql.rs new file mode 100644 index 0000000..31e590a --- /dev/null +++ b/backend/src/storage/sql/mysql.rs @@ -0,0 +1,110 @@ +use super::{builder::{self, SafeValue}, schema, DatabaseTrait}; +use crate::config; +use crate::common::error::CustomResult; +use async_trait::async_trait; +use serde_json::Value; +use sqlx::mysql::MySqlPool; +use sqlx::{Column, Executor, Row, TypeInfo}; +use std::collections::HashMap; +use chrono::{DateTime, Utc}; + +#[derive(Clone)] +pub struct Mysql { + pool: MySqlPool, +} + +#[async_trait] +impl DatabaseTrait for Mysql { + async fn initialization(db_config: config::SqlConfig) -> CustomResult<()> { + let db_prefix = SafeValue::Text(format!("{}",db_config.db_prefix), builder::ValidationLevel::Strict); + let grammar = schema::generate_schema(schema::DatabaseType::MySQL,db_prefix)?; + let connection_str = format!( + "mysql://{}:{}@{}:{}", + db_config.user, db_config.password, db_config.address, db_config.port + ); + + let pool = MySqlPool::connect(&connection_str).await?; + + pool.execute(format!("CREATE DATABASE `{}`", db_config.db_name).as_str()).await?; + pool.execute(format!( + "ALTER DATABASE `{}` DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci", + db_config.db_name + ).as_str()).await?; + + + let new_connection_str = format!( + "mysql://{}:{}@{}:{}/{}", + db_config.user, + db_config.password, + db_config.address, + db_config.port, + db_config.db_name + ); + let new_pool = MySqlPool::connect(&new_connection_str).await?; + + new_pool.execute(grammar.as_str()).await?; + Ok(()) + } + async fn connect(db_config: &config::SqlConfig) -> CustomResult { + let connection_str = format!( + "mysql://{}:{}@{}:{}/{}", + db_config.user, + db_config.password, + db_config.address, + db_config.port, + db_config.db_name + ); + + let pool = MySqlPool::connect(&connection_str).await?; + + Ok(Mysql { pool }) + } + + async fn execute_query<'a>( + &'a self, + builder: &builder::QueryBuilder, + ) -> CustomResult>> { + let (query, values) = builder.build()?; + + let mut sqlx_query = sqlx::query(&query); + + for value in values { + match value { + SafeValue::Null => sqlx_query = sqlx_query.bind(None::), + SafeValue::Bool(b) => sqlx_query = sqlx_query.bind(b), + SafeValue::Integer(i) => sqlx_query = sqlx_query.bind(i), + SafeValue::Float(f) => sqlx_query = sqlx_query.bind(f), + SafeValue::Text(s, _) => sqlx_query = sqlx_query.bind(s), + SafeValue::DateTime(dt) => sqlx_query = sqlx_query.bind(dt.to_rfc3339()), + } + } + + let rows = sqlx_query.fetch_all(&self.pool).await?; + + Ok(rows + .into_iter() + .map(|row| { + row.columns() + .iter() + .map(|col| { + let value = match col.type_info().name() { + "INT4" | "INT8" => Value::Number( + row.try_get::(col.name()).unwrap_or_default().into(), + ), + "FLOAT4" | "FLOAT8" => Value::Number( + serde_json::Number::from_f64( + row.try_get::(col.name()).unwrap_or(0.0), + ) + .unwrap_or_else(|| 0.into()), + ), + "BOOL" => Value::Bool(row.try_get(col.name()).unwrap_or_default()), + _ => Value::String(row.try_get(col.name()).unwrap_or_default()), + }; + (col.name().to_string(), value) + }) + .collect() + }) + .collect()) + } + +} \ No newline at end of file diff --git a/backend/src/storage/sql/postgresql/mod.rs b/backend/src/storage/sql/postgresql.rs similarity index 74% rename from backend/src/storage/sql/postgresql/mod.rs rename to backend/src/storage/sql/postgresql.rs index 6206e59..64100f8 100644 --- a/backend/src/storage/sql/postgresql/mod.rs +++ b/backend/src/storage/sql/postgresql.rs @@ -1,11 +1,15 @@ -use super::{builder, DatabaseTrait}; +use super::{ + builder::{self, SafeValue}, + schema, DatabaseTrait, +}; +use crate::common::error::{CustomError, CustomErrorInto, CustomResult}; use crate::config; -use crate::common::error::CustomResult; use async_trait::async_trait; +use chrono::{DateTime, Utc}; use serde_json::Value; use sqlx::{Column, Executor, PgPool, Row, TypeInfo}; use std::collections::HashMap; -use std::{env, fs}; + #[derive(Clone)] pub struct Postgresql { pool: PgPool, @@ -14,13 +18,11 @@ pub struct Postgresql { #[async_trait] impl DatabaseTrait for Postgresql { async fn initialization(db_config: config::SqlConfig) -> CustomResult<()> { - let path = env::current_dir()? - .join("src") - .join("storage") - .join("sql") - .join("postgresql") - .join("schema.sql"); - let grammar = fs::read_to_string(&path)?; + let db_prefix = SafeValue::Text( + format!("{}", db_config.db_prefix), + builder::ValidationLevel::Strict, + ); + let grammar = schema::generate_schema(schema::DatabaseType::PostgreSQL, db_prefix)?; let connection_str = format!( "postgres://{}:{}@{}:{}", @@ -70,7 +72,14 @@ impl DatabaseTrait for Postgresql { let mut sqlx_query = sqlx::query(&query); for value in values { - sqlx_query = sqlx_query.bind(value.to_sql_string()?); + match value { + SafeValue::Null => sqlx_query = sqlx_query.bind(None::), + SafeValue::Bool(b) => sqlx_query = sqlx_query.bind(b), + SafeValue::Integer(i) => sqlx_query = sqlx_query.bind(i), + SafeValue::Float(f) => sqlx_query = sqlx_query.bind(f), + SafeValue::Text(s, _) => sqlx_query = sqlx_query.bind(s), + SafeValue::DateTime(dt) => sqlx_query = sqlx_query.bind(dt.to_rfc3339()), + } } let rows = sqlx_query.fetch_all(&self.pool).await?; @@ -92,7 +101,6 @@ impl DatabaseTrait for Postgresql { .unwrap_or_else(|| 0.into()), ), "BOOL" => Value::Bool(row.try_get(col.name()).unwrap_or_default()), - "JSON" | "JSONB" => row.try_get(col.name()).unwrap_or(Value::Null), _ => Value::String(row.try_get(col.name()).unwrap_or_default()), }; (col.name().to_string(), value) @@ -102,3 +110,9 @@ impl DatabaseTrait for Postgresql { .collect()) } } + +impl Postgresql { + fn get_sdb(&self){ + let a=self.pool; + } +} \ No newline at end of file diff --git a/backend/src/storage/sql/postgresql/schema.sql b/backend/src/storage/sql/postgresql/schema.sql deleted file mode 100644 index 53462d4..0000000 --- a/backend/src/storage/sql/postgresql/schema.sql +++ /dev/null @@ -1,102 +0,0 @@ --- 自定义类型定义 -CREATE EXTENSION IF NOT EXISTS pgcrypto; - -CREATE TYPE user_role AS ENUM ('contributor', 'administrator'); -CREATE TYPE content_status AS ENUM ('draft', 'published', 'private', 'hidden'); - --- 用户表 -CREATE TABLE users -( - username VARCHAR(100) PRIMARY KEY, - avatar_url VARCHAR(255), - email VARCHAR(255) UNIQUE NOT NULL, - profile_icon VARCHAR(255), - password_hash VARCHAR(255) NOT NULL, - role user_role NOT NULL DEFAULT 'contributor', - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - last_login_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP -); - --- 页面表 -CREATE TABLE pages -( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - meta_keywords VARCHAR(255) NOT NULL, - meta_description VARCHAR(255) NOT NULL, - title VARCHAR(255) NOT NULL, - content TEXT NOT NULL, - template VARCHAR(50), - custom_fields JSON, - status content_status DEFAULT 'draft' -); - --- 文章表 -CREATE TABLE posts -( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - author_id VARCHAR(100) NOT NULL REFERENCES users (username) ON DELETE CASCADE, - cover_image VARCHAR(255), - title VARCHAR(255) NOT NULL, - meta_keywords VARCHAR(255) NOT NULL, - meta_description VARCHAR(255) NOT NULL, - content TEXT NOT NULL, - status content_status DEFAULT 'draft', - is_editor BOOLEAN DEFAULT FALSE, - draft_content TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - published_at TIMESTAMP, - CONSTRAINT check_update_time CHECK (updated_at >= created_at) -); - --- 标签表 -CREATE TABLE tags -( - name VARCHAR(50) PRIMARY KEY CHECK (LOWER(name) = name), - icon VARCHAR(255) -); - --- 文章标签关联表 -CREATE TABLE post_tags -( - post_id UUID REFERENCES posts (id) ON DELETE CASCADE, - tag_id VARCHAR(50) REFERENCES tags (name) ON DELETE CASCADE, - PRIMARY KEY (post_id, tag_id) -); - --- 分类表 -CREATE TABLE categories -( - name VARCHAR(50) PRIMARY KEY, - parent_id VARCHAR(50), - FOREIGN KEY (parent_id) REFERENCES categories (name) -); - --- 文章分类关联表 -CREATE TABLE post_categories -( - post_id UUID REFERENCES posts (id) ON DELETE CASCADE, - category_id VARCHAR(50) REFERENCES categories (name) ON DELETE CASCADE, - PRIMARY KEY (post_id, category_id) -); - --- 资源库表 -CREATE TABLE resources -( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - author_id VARCHAR(100) NOT NULL REFERENCES users (username) ON DELETE CASCADE, - name VARCHAR(255) NOT NULL, - size_bytes BIGINT NOT NULL, - storage_path VARCHAR(255) NOT NULL UNIQUE, - file_type VARCHAR(50) NOT NULL, - category VARCHAR(50), - description VARCHAR(255), - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP -); --- 配置表 -CREATE TABLE settings -( - key VARCHAR(50) PRIMARY KEY CHECK (LOWER(key) = key), - data JSON -); \ No newline at end of file diff --git a/backend/src/storage/sql/schema.rs b/backend/src/storage/sql/schema.rs new file mode 100644 index 0000000..f3145b6 --- /dev/null +++ b/backend/src/storage/sql/schema.rs @@ -0,0 +1,812 @@ +use super::builder::{Condition, Identifier, Operator, SafeValue, ValidationLevel, WhereClause}; +use crate::common::error::{CustomErrorInto, CustomResult}; +use std::{collections::HashMap, fmt::format}; + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum DatabaseType { + PostgreSQL, + MySQL, + SQLite, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum FieldType { + Integer(bool), + BigInt, + VarChar(usize), + Text, + Boolean, + Timestamp, +} + +#[derive(Debug, Clone)] +pub struct FieldConstraint { + pub is_primary: bool, + pub is_unique: bool, + pub is_nullable: bool, + pub default_value: Option, + pub check_constraint: Option, + pub foreign_key: Option, +} + +#[derive(Debug, Clone)] +pub enum ForeignKeyAction { + Cascade, + Restrict, + SetNull, + NoAction, + SetDefault, +} + +#[derive(Debug, Clone)] +pub struct ForeignKey { + pub ref_table: String, + pub ref_column: String, + pub on_delete: Option, + pub on_update: Option, +} + +impl ToString for ForeignKeyAction { + fn to_string(&self) -> String { + match self { + ForeignKeyAction::Cascade => "CASCADE", + ForeignKeyAction::Restrict => "RESTRICT", + ForeignKeyAction::SetNull => "SET NULL", + ForeignKeyAction::NoAction => "NO ACTION", + ForeignKeyAction::SetDefault => "SET DEFAULT", + } + .to_string() + } +} + +#[derive(Debug, Clone)] +pub struct Field { + pub name: Identifier, + pub field_type: FieldType, + pub constraints: FieldConstraint, + pub validation_level: ValidationLevel, +} + +#[derive(Debug, Clone)] +pub struct Table { + pub name: Identifier, + pub fields: Vec, + pub indexes: Vec, +} + +#[derive(Debug, Clone)] +pub struct Index { + pub name: Identifier, + pub fields: Vec, + pub is_unique: bool, +} + +impl FieldConstraint { + pub fn new() -> Self { + Self { + is_primary: false, + is_unique: false, + is_nullable: true, + default_value: None, + check_constraint: None, + foreign_key: None, + } + } + + pub fn primary(mut self) -> Self { + self.is_primary = true; + self.is_nullable = false; + self + } + + pub fn unique(mut self) -> Self { + self.is_unique = true; + self + } + + pub fn not_null(mut self) -> Self { + self.is_nullable = false; + self + } + + pub fn default(mut self, value: SafeValue) -> Self { + self.default_value = Some(value); + self + } + + pub fn check(mut self, clause: WhereClause) -> Self { + self.check_constraint = Some(clause); + self + } + + pub fn foreign_key(mut self, ref_table: String, ref_column: String) -> Self { + self.foreign_key = Some(ForeignKey { + ref_table, + ref_column, + on_delete: None, + on_update: None, + }); + self + } + + pub fn on_delete(mut self, action: ForeignKeyAction) -> Self { + if let Some(ref mut fk) = self.foreign_key { + fk.on_delete = Some(action); + } + self + } + + pub fn on_update(mut self, action: ForeignKeyAction) -> Self { + if let Some(ref mut fk) = self.foreign_key { + fk.on_update = Some(action); + } + self + } +} + +impl Field { + pub fn new( + name: &str, + field_type: FieldType, + constraints: FieldConstraint, + validation_level: ValidationLevel, + ) -> CustomResult { + Ok(Self { + name: Identifier::new(name.to_string())?, + field_type, + constraints, + validation_level, + }) + } + + fn field_type_sql(&self, db_type: DatabaseType) -> CustomResult { + Ok(match &self.field_type { + FieldType::Integer(auto_increment) => { + if *auto_increment && self.constraints.is_primary { + match db_type { + DatabaseType::MySQL => "INT AUTO_INCREMENT".to_string(), + DatabaseType::PostgreSQL => { + "INTEGER GENERATED ALWAYS AS IDENTITY".to_string() + } + DatabaseType::SQLite => "INTEGER".to_string(), + } + } else { + match db_type { + DatabaseType::MySQL => "INT".to_string(), + _ => "INTEGER".to_string(), + } + } + } + FieldType::BigInt => "BIGINT".to_string(), + FieldType::VarChar(size) => format!("VARCHAR({})", size), + FieldType::Text => "TEXT".to_string(), + FieldType::Boolean => match db_type { + DatabaseType::PostgreSQL => "BOOLEAN".to_string(), + DatabaseType::MySQL => "BOOLEAN".to_string(), + DatabaseType::SQLite => "INTEGER".to_string(), + }, + FieldType::Timestamp => match db_type { + DatabaseType::PostgreSQL => "TIMESTAMP WITH TIME ZONE".to_string(), + DatabaseType::MySQL => "TIMESTAMP".to_string(), + DatabaseType::SQLite => "TEXT".to_string(), + }, + }) + } + + fn build_check_constraint(check: &WhereClause) -> CustomResult { + match check { + WhereClause::Condition(condition) => { + let field_name = condition.field.as_str(); + match condition.operator { + Operator::In => { + if let Some(SafeValue::Text(values, _)) = &condition.value { + Ok(format!("{} IN {}", field_name, values)) + } else { + Err("Invalid IN clause value".into_custom_error()) + } + } + Operator::Eq + | Operator::Ne + | Operator::Gt + | Operator::Lt + | Operator::Gte + | Operator::Lte => { + if let Some(value) = &condition.value { + Ok(format!( + "{} {} {}", + field_name, + condition.operator.as_str(), + value.to_sql_string()? + )) + } else { + Err("Missing value for comparison".into_custom_error()) + } + } + _ => Err("Unsupported operator for CHECK constraint".into_custom_error()), + } + } + _ => { + Err("Only simple conditions are supported for CHECK constraints" + .into_custom_error()) + } + } + } + + pub fn to_sql(&self, db_type: DatabaseType) -> CustomResult { + let mut sql = format!("{} {}", self.name.as_str(), self.field_type_sql(db_type)?); + + if !self.constraints.is_nullable { + sql.push_str(" NOT NULL"); + } + if self.constraints.is_unique { + sql.push_str(" UNIQUE"); + } + if self.constraints.is_primary { + match (db_type, &self.field_type) { + (DatabaseType::SQLite, FieldType::Integer(true)) => { + sql.push_str(" PRIMARY KEY AUTOINCREMENT"); + } + (DatabaseType::MySQL, FieldType::Integer(true)) => { + sql.push_str(" PRIMARY KEY"); + } + (DatabaseType::PostgreSQL, FieldType::Integer(true)) => { + sql.push_str(" PRIMARY KEY"); + } + _ => sql.push_str(" PRIMARY KEY"), + } + } + if let Some(default) = &self.constraints.default_value { + sql.push_str(&format!(" DEFAULT {}", default.to_sql_string()?)); + } + if let Some(check) = &self.constraints.check_constraint { + let check_sql = Self::build_check_constraint(check)?; + sql.push_str(&format!(" CHECK ({})", check_sql)); + } + if let Some(fk) = &self.constraints.foreign_key { + sql.push_str(&format!(" REFERENCES {}({})", fk.ref_table, fk.ref_column)); + + if let Some(on_delete) = &fk.on_delete { + sql.push_str(&format!(" ON DELETE {}", on_delete.to_string())); + } + + if let Some(on_update) = &fk.on_update { + sql.push_str(&format!(" ON UPDATE {}", on_update.to_string())); + } + } + + Ok(sql) + } +} + +impl Table { + pub fn new(name: &str) -> CustomResult { + Ok(Self { + name: Identifier::new(name.to_string())?, + fields: Vec::new(), + indexes: Vec::new(), + }) + } + + pub fn add_field(&mut self, field: Field) -> &mut Self { + self.fields.push(field); + self + } + + pub fn add_index(&mut self, index: Index) -> &mut Self { + self.indexes.push(index); + self + } + + pub fn to_sql(&self, db_type: DatabaseType) -> CustomResult { + let fields_sql: CustomResult> = + self.fields.iter().map(|f| f.to_sql(db_type)).collect(); + let fields_sql = fields_sql?; + + let mut sql = format!( + "CREATE TABLE {} (\n {}\n);", + self.name.as_str(), + fields_sql.join(",\n ") + ); + + // 添加索引 + for index in &self.indexes { + sql.push_str(&format!( + "\n\n{}", + index.to_sql(self.name.as_str(), db_type)? + )); + } + + Ok(sql) + } +} + +impl Index { + pub fn new(name: &str, fields: Vec, is_unique: bool) -> CustomResult { + Ok(Self { + name: Identifier::new(name.to_string())?, + fields: fields + .into_iter() + .map(|f| Identifier::new(f)) + .collect::>>()?, + is_unique, + }) + } + + fn to_sql(&self, table_name: &str, db_type: DatabaseType) -> CustomResult { + let unique = if self.is_unique { "UNIQUE " } else { "" }; + Ok(format!( + "CREATE {}INDEX {} ON {} ({});", + unique, + self.name.as_str(), + table_name, + self.fields + .iter() + .map(|f| f.as_str()) + .collect::>() + .join(", ") + )) + } +} + +// Schema构建器 +#[derive(Debug, Default)] +pub struct SchemaBuilder { + tables: Vec, +} + +impl SchemaBuilder { + pub fn new() -> Self { + Self { + tables: Vec::new(), + } + } + + pub fn add_table(&mut self, table: Table) -> CustomResult<&mut Self> { + self.tables.push(table); + Ok(self) + } + + pub fn build(&self, db_type: DatabaseType) -> CustomResult { + let mut sql = String::new(); + for table in &self.tables { + sql.push_str(&table.to_sql(db_type)?); + sql.push_str("\n\n"); + } + Ok(sql) + } +} + +pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResult { + let db_prefix=db_prefix.to_sql_string()?; + let mut schema = SchemaBuilder::new(); + let user_level = "('contributor', 'administrator')"; + let content_state = "('draft', 'published', 'private', 'hidden')"; + + // 用户表 + let mut users_table = Table::new(&format!("{}users",db_prefix))?; + users_table + .add_field(Field::new( + "username", + FieldType::VarChar(100), + FieldConstraint::new().primary(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "avatar_url", + FieldType::VarChar(255), + FieldConstraint::new(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "email", + FieldType::VarChar(255), + FieldConstraint::new().unique().not_null(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "profile_icon", + FieldType::VarChar(255), + FieldConstraint::new(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "password_hash", + FieldType::VarChar(255), + FieldConstraint::new().not_null(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "role", + FieldType::VarChar(20), + FieldConstraint::new() + .not_null() + .check(WhereClause::Condition(Condition::new( + "role".to_string(), + Operator::In, + Some(SafeValue::Text(user_level.to_string(), ValidationLevel::Relaxed)), + )?)), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "created_at", + FieldType::Timestamp, + FieldConstraint::new().not_null().default(SafeValue::Text( + "CURRENT_TIMESTAMP".to_string(), + ValidationLevel::Strict, + )), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "updated_at", + FieldType::Timestamp, + FieldConstraint::new().not_null().default(SafeValue::Text( + "CURRENT_TIMESTAMP".to_string(), + ValidationLevel::Strict, + )), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "last_login_at", + FieldType::Timestamp, + FieldConstraint::new().default(SafeValue::Text( + "CURRENT_TIMESTAMP".to_string(), + ValidationLevel::Strict, + )), + ValidationLevel::Strict, + )?); + + schema.add_table(users_table)?; + + // 独立页面表 + + let mut pages_table = Table::new(&format!("{}pages",db_prefix))?; + pages_table + .add_field(Field::new( + "id", + FieldType::Integer(true), + FieldConstraint::new().primary(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "title", + FieldType::VarChar(255), + FieldConstraint::new().not_null(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "meta_keywords", + FieldType::VarChar(255), + FieldConstraint::new().not_null(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "meta_description", + FieldType::VarChar(255), + FieldConstraint::new().not_null(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "content", + FieldType::Text, + FieldConstraint::new().not_null(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "template", + FieldType::VarChar(50), + FieldConstraint::new(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "custom_fields", + FieldType::Text, + FieldConstraint::new(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "status", + FieldType::VarChar(20), + FieldConstraint::new() + .not_null() + .check(WhereClause::Condition(Condition::new( + "status".to_string(), + Operator::In, + Some(SafeValue::Text(content_state.to_string(), ValidationLevel::Standard)), + )?)), + ValidationLevel::Strict, + )?); + + schema.add_table(pages_table)?; + + + // posts 表 + let mut posts_table = Table::new(&format!("{}posts",db_prefix))?; + posts_table + .add_field(Field::new( + "id", + FieldType::Integer(true), + FieldConstraint::new().primary(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "author_name", + FieldType::VarChar(100), + FieldConstraint::new() + .not_null() + .foreign_key(format!("{}users",db_prefix), "username".to_string()) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "cover_image", + FieldType::VarChar(255), + FieldConstraint::new(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "title", + FieldType::VarChar(255), + FieldConstraint::new(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "meta_keywords", + FieldType::VarChar(255), + FieldConstraint::new().not_null(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "meta_description", + FieldType::VarChar(255), + FieldConstraint::new().not_null(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "content", + FieldType::Text, + FieldConstraint::new().not_null(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "status", + FieldType::VarChar(20), + FieldConstraint::new() + .not_null() + .check(WhereClause::Condition(Condition::new( + "status".to_string(), + Operator::In, + Some(SafeValue::Text(content_state.to_string(), ValidationLevel::Standard)), + )?)), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "is_editor", + FieldType::Boolean, + FieldConstraint::new() + .not_null() + .default(SafeValue::Bool(false)), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "draft_content", + FieldType::Text, + FieldConstraint::new(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "created_at", + FieldType::Timestamp, + FieldConstraint::new().not_null().default(SafeValue::Text( + "CURRENT_TIMESTAMP".to_string(), + ValidationLevel::Strict, + )), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "updated_at", + FieldType::Timestamp, + FieldConstraint::new().not_null().default(SafeValue::Text( + "CURRENT_TIMESTAMP".to_string(), + ValidationLevel::Strict, + )), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "published_at", + FieldType::Timestamp, + FieldConstraint::new(), + ValidationLevel::Strict, + )?); + + schema.add_table(posts_table)?; + + // 标签表 + let mut tags_tables = Table::new(&format!("{}tags",db_prefix))?; + tags_tables + .add_field(Field::new( + "name", + FieldType::VarChar(50), + FieldConstraint::new().primary(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "icon", + FieldType::VarChar(255), + FieldConstraint::new(), + ValidationLevel::Strict, + )?); + + schema.add_table(tags_tables)?; + + + // 文章标签 + let mut post_tags_tables = Table::new(&format!("{}post_tags",db_prefix))?; + post_tags_tables + .add_field(Field::new( + "post_id", + FieldType::Integer(false), + FieldConstraint::new() + .not_null() + .foreign_key(format!("{}posts",db_prefix), "id".to_string()) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ValidationLevel::Strict, + )?).add_field(Field::new( + "tag_id", + FieldType::VarChar(50), + FieldConstraint::new() + .not_null() + .foreign_key(format!("{}tags",db_prefix), "name".to_string()) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ValidationLevel::Strict, + )?); + + post_tags_tables.add_index(Index::new( + "pk_post_tags", + vec!["post_id".to_string(), "tag_id".to_string()], + true, + )?); + + schema.add_table(post_tags_tables)?; + + // 分类表 + + let mut categories_table = Table::new(&format!("{}categories",db_prefix))?; + categories_table + .add_field(Field::new( + "name", + FieldType::VarChar(50), + FieldConstraint::new().primary(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "parent_id", + FieldType::VarChar(50), + FieldConstraint::new() + .foreign_key(format!("{}categories",db_prefix), "name".to_string()), + ValidationLevel::Strict, + )?); + + schema.add_table(categories_table)?; + + // 文章分类关联表 + let mut post_categories_table = Table::new(&format!("{}post_categories",db_prefix))?; + post_categories_table + .add_field(Field::new( + "post_id", + FieldType::Integer(false), + FieldConstraint::new() + .not_null() + .foreign_key(format!("{}posts",db_prefix), "id".to_string()) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "category_id", + FieldType::VarChar(50), + FieldConstraint::new() + .not_null() + .foreign_key(format!("{}categories",db_prefix), "name".to_string()) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ValidationLevel::Strict, + )?); + + post_categories_table.add_index(Index::new( + "pk_post_categories", + vec!["post_id".to_string(), "category_id".to_string()], + true, + )?); + + schema.add_table(post_categories_table)?; + + // 资源库表 + let mut resources_table = Table::new(&format!("{}resources",db_prefix))?; + resources_table + .add_field(Field::new( + "id", + FieldType::Integer(true), + FieldConstraint::new().primary(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "author_id", + FieldType::VarChar(100), + FieldConstraint::new() + .not_null() + .foreign_key(format!("{}users",db_prefix), "username".to_string()) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "name", + FieldType::VarChar(255), + FieldConstraint::new().not_null(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "size_bytes", + FieldType::BigInt, + FieldConstraint::new().not_null(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "storage_path", + FieldType::VarChar(255), + FieldConstraint::new().not_null().unique(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "file_type", + FieldType::VarChar(50), + FieldConstraint::new().not_null(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "category", + FieldType::VarChar(50), + FieldConstraint::new(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "description", + FieldType::VarChar(255), + FieldConstraint::new(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "created_at", + FieldType::Timestamp, + FieldConstraint::new().not_null().default(SafeValue::Text( + "CURRENT_TIMESTAMP".to_string(), + ValidationLevel::Strict, + )), + ValidationLevel::Strict, + )?); + + schema.add_table(resources_table)?; + + // 配置表 + let mut settings_table = Table::new(&format!("{}settings",db_prefix))?; + settings_table + .add_field(Field::new( + "name", + FieldType::VarChar(50), + FieldConstraint::new().primary(), + ValidationLevel::Strict, + )?) + .add_field(Field::new( + "data", + FieldType::Text, + FieldConstraint::new(), + ValidationLevel::Strict, + )?); + + schema.add_table(settings_table)?; + + schema.build(db_type) +} diff --git a/backend/src/storage/sql/sqllite.rs b/backend/src/storage/sql/sqllite.rs new file mode 100644 index 0000000..cd4d3e9 --- /dev/null +++ b/backend/src/storage/sql/sqllite.rs @@ -0,0 +1,107 @@ +use super::{ + builder::{self, SafeValue}, + schema, DatabaseTrait, +}; +use crate::common::error::{CustomError, CustomErrorInto, CustomResult}; +use crate::config; +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use serde_json::Value; +use sqlx::{Column, Executor, SqlitePool, Row, TypeInfo}; +use std::collections::HashMap; +use std::env; + +#[derive(Clone)] +pub struct Sqlite { + pool: SqlitePool, +} + +#[async_trait] +impl DatabaseTrait for Sqlite { + async fn initialization(db_config: config::SqlConfig) -> CustomResult<()> { + let db_prefix = SafeValue::Text( + format!("{}", db_config.db_prefix), + builder::ValidationLevel::Strict, + ); + + let sqlite_dir = env::current_dir()?.join("assets").join("sqllite"); + std::fs::create_dir_all(&sqlite_dir)?; + + let db_file = sqlite_dir.join(&db_config.db_name); + std::fs::File::create(&db_file)?; + + let path = db_file.to_str().ok_or("Unable to get sqllite path".into_custom_error())?; + let grammar = schema::generate_schema(schema::DatabaseType::SQLite, db_prefix)?; + + let connection_str = format!("sqlite:///{}", path); + let pool = SqlitePool::connect(&connection_str).await?; + + pool.execute(grammar.as_str()).await?; + + Ok(()) + } + + async fn connect(db_config: &config::SqlConfig) -> CustomResult { + let db_file = env::current_dir()? + .join("assets") + .join("sqllite") + .join(&db_config.db_name); + + if !db_file.exists() { + return Err("SQLite database file does not exist".into_custom_error()); + } + + let path = db_file.to_str().ok_or("Unable to get sqllite path".into_custom_error())?; + let connection_str = format!("sqlite:///{}", path); + let pool = SqlitePool::connect(&connection_str).await?; + + Ok(Sqlite { pool }) + } + + async fn execute_query<'a>( + &'a self, + builder: &builder::QueryBuilder, + ) -> CustomResult>> { + let (query, values) = builder.build()?; + + let mut sqlx_query = sqlx::query(&query); + + for value in values { + match value { + SafeValue::Null => sqlx_query = sqlx_query.bind(None::), + SafeValue::Bool(b) => sqlx_query = sqlx_query.bind(b), + SafeValue::Integer(i) => sqlx_query = sqlx_query.bind(i), + SafeValue::Float(f) => sqlx_query = sqlx_query.bind(f), + SafeValue::Text(s, _) => sqlx_query = sqlx_query.bind(s), + SafeValue::DateTime(dt) => sqlx_query = sqlx_query.bind(dt.to_rfc3339()), + } + } + + let rows = sqlx_query.fetch_all(&self.pool).await?; + + Ok(rows + .into_iter() + .map(|row| { + row.columns() + .iter() + .map(|col| { + let value = match col.type_info().name() { + "INTEGER" => Value::Number( + row.try_get::(col.name()).unwrap_or_default().into(), + ), + "REAL" => Value::Number( + serde_json::Number::from_f64( + row.try_get::(col.name()).unwrap_or(0.0), + ) + .unwrap_or_else(|| 0.into()), + ), + "BOOLEAN" => Value::Bool(row.try_get(col.name()).unwrap_or_default()), + _ => Value::String(row.try_get(col.name()).unwrap_or_default()), + }; + (col.name().to_string(), value) + }) + .collect() + }) + .collect()) + } +} diff --git a/frontend/app/env.d.ts b/frontend/app/env.d.ts index ebb1615..0d45b87 100644 --- a/frontend/app/env.d.ts +++ b/frontend/app/env.d.ts @@ -7,12 +7,10 @@ /// interface ImportMetaEnv { - readonly VITE_SERVER_API: string; // 用于访问API的基础URL - readonly VITE_ADDRESS: string; // 前端地址 - readonly VITE_PORT: number; // 前端系统端口 - VITE_USERNAME: string; // 前端账号名称 - VITE_PASSWORD: string; // 前端账号密码 - VITE_INIT_STATUS: boolean; // 系统是否进行安装 + readonly VITE_INIT_STATUS: string; + readonly VITE_SERVER_API: string; + readonly VITE_PORT: string; + readonly VITE_ADDRESS: string; } interface ImportMeta { diff --git a/frontend/app/init.tsx b/frontend/app/init.tsx index 1911606..92a5194 100644 --- a/frontend/app/init.tsx +++ b/frontend/app/init.tsx @@ -1,5 +1,5 @@ -import React, { useContext, createContext, useState } from "react"; - +import React, { createContext, useState, useEffect } from "react"; +import { useApi } from "hooks/servicesProvider"; interface SetupContextType { currentStep: number; setCurrentStep: (step: number) => void; @@ -21,30 +21,22 @@ const StepContainer: React.FC<{ title: string; children: React.ReactNode }> = ({ title, children, }) => ( -
-

+
+

{title}

-
+
{children}
); // 通用的导航按钮组件 -const NavigationButtons: React.FC = ({ onNext, onPrev }) => ( -
- {onPrev && ( - - )} +const NavigationButtons: React.FC = ({ onNext }) => ( +
@@ -58,17 +50,17 @@ const InputField: React.FC<{ defaultValue?: string | number; hint?: string; }> = ({ label, name, defaultValue, hint }) => ( -
-

+
+

{label}

{hint && ( -

+

{hint}

)} @@ -78,7 +70,7 @@ const InputField: React.FC<{ const Introduction: React.FC = ({ onNext }) => (
-

+

欢迎使用 Echoes

@@ -86,22 +78,24 @@ const Introduction: React.FC = ({ onNext }) => ( ); -const DatabaseConfig: React.FC = ({ onNext, onPrev }) => { +const DatabaseConfig: React.FC = ({ onNext }) => { const [dbType, setDbType] = useState("postgresql"); return ( -
+
-

+

数据库类型

@@ -113,6 +107,12 @@ const DatabaseConfig: React.FC = ({ onNext, onPrev }) => { defaultValue="localhost" hint="通常使用 localhost" /> + = ({ onNext, onPrev }) => { /> )} - + {dbType === "mysql" && ( + <> + + + + + + + + )} + {dbType === "sqllite" && ( + <> + + + + )} +
); }; -const AdminConfig: React.FC = ({ onNext, onPrev }) => ( +const AdminConfig: React.FC = ({ onNext }) => (
- +
); @@ -163,31 +215,71 @@ const SetupComplete: React.FC = () => ( ); -export default function SetupPage() { - const [currentStep, setCurrentStep] = useState(1); +// 修改主题切换按钮组件 +const ThemeToggle: React.FC = () => { + const [isDark, setIsDark] = useState(false); + const [isVisible, setIsVisible] = useState(true); + + useEffect(() => { + const isDarkMode = document.documentElement.classList.contains('dark'); + setIsDark(isDarkMode); + + const handleScroll = () => { + const currentScrollPos = window.scrollY; + setIsVisible(currentScrollPos < 100); // 滚动超过100px就隐藏 + }; + + window.addEventListener('scroll', handleScroll); + return () => window.removeEventListener('scroll', handleScroll); + }, []); + + const toggleTheme = () => { + const newIsDark = !isDark; + setIsDark(newIsDark); + document.documentElement.classList.toggle('dark'); + }; return ( -
-
-
-

- Echoes -

-
+ + ); +}; + +export default function SetupPage() { + let step = Number(import.meta.env.VITE_INIT_STATUS); + + const [currentStep, setCurrentStep] = useState(step); + + return ( +
+ +
{currentStep === 1 && ( setCurrentStep(currentStep + 1)} /> )} {currentStep === 2 && ( - setCurrentStep(currentStep + 1)} - onPrev={() => setCurrentStep(currentStep - 1)} /> )} {currentStep === 3 && ( - setCurrentStep(currentStep + 1)} - onPrev={() => setCurrentStep(currentStep - 1)} /> )} {currentStep === 4 && } diff --git a/frontend/core/template.ts b/frontend/core/template.ts new file mode 100644 index 0000000..0c6f8d1 --- /dev/null +++ b/frontend/core/template.ts @@ -0,0 +1,12 @@ +export interface Template { + name: string; + description?: string; + config: { + layout?: string; + styles?: string[]; + scripts?: string[]; + }; + loader: () => Promise; + element: () => React.ReactNode; + } + \ No newline at end of file diff --git a/frontend/core/theme.ts b/frontend/core/theme.ts index 4f57d1b..b059e21 100644 --- a/frontend/core/theme.ts +++ b/frontend/core/theme.ts @@ -1,5 +1,6 @@ import { Configuration, PathDescription } from "common/serializableType"; -import { ApiService } from "./api"; +import { ApiService } from "core/api"; +import { Template } from "core/template"; export interface ThemeConfig { name: string; @@ -25,17 +26,6 @@ export interface ThemeConfig { }; } -export interface Template { - name: string; - description?: string; - config: { - layout?: string; - styles?: string[]; - scripts?: string[]; - }; - loader: () => Promise; - element: () => React.ReactNode; -} export class ThemeService { private static instance: ThemeService; @@ -56,7 +46,7 @@ export class ThemeService { public async getCurrentTheme(): Promise { try { const themeConfig = await this.api.request( - "/theme/current", + "/theme", { method: "GET" }, ); this.currentTheme = themeConfig; @@ -70,10 +60,10 @@ export class ThemeService { return this.currentTheme; } - public async updateThemeConfig(config: Partial): Promise { + public async updateThemeConfig(config: Partial,name:string): Promise { try { const updatedConfig = await this.api.request( - "/theme/config", + `/theme/`, { method: "PUT", headers: { diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts index 7c96432..d1548dc 100644 --- a/frontend/vite.config.ts +++ b/frontend/vite.config.ts @@ -17,7 +17,7 @@ export default defineConfig(({ mode }) => { }, routes: (defineRoutes) => { return defineRoutes((route) => { - if (!env.VITE_INIT_STATUS) { + if (Number(env.VITE_INIT_STATUS??1)<4) { route("/", "init.tsx", { id: "index-route" }); route("*", "init.tsx", { id: "catch-all-route" }); } else { @@ -30,7 +30,7 @@ export default defineConfig(({ mode }) => { tsconfigPaths(), ], define: { - "import.meta.env.VITE_INIT_STATUS": JSON.stringify(false), + "import.meta.env.VITE_INIT_STATUS": JSON.stringify(1), "import.meta.env.VITE_SERVER_API": JSON.stringify("localhost:22000"), "import.meta.env.VITE_PORT": JSON.stringify(22100), "import.meta.env.VITE_ADDRESS": JSON.stringify("localhost"),