diff --git a/backend/src/api/auth/token.rs b/backend/src/api/auth/token.rs index 977907e..60ca90d 100644 --- a/backend/src/api/auth/token.rs +++ b/backend/src/api/auth/token.rs @@ -1,17 +1,10 @@ +use crate::common::error::{AppResult, AppResultInto}; use crate::security; use crate::storage::sql::builder; -use crate::common::error::{AppResult, AppResultInto}; use crate::AppState; use chrono::Duration; -use rocket::{ - http::Status, - post, - response::status, - serde::json::{Json, Value}, - State, -}; +use rocket::{http::Status, post, response::status, serde::json::Json, State}; use serde::{Deserialize, Serialize}; -use serde_json::json; use std::sync::Arc; #[derive(Deserialize, Serialize)] pub struct TokenSystemData { @@ -23,13 +16,13 @@ pub async fn token_system( state: &State>, data: Json, ) -> AppResult { - let sql = state - .sql_get() - .await + let sql = state.sql_get().await.into_app_result()?; + let mut builder = builder::QueryBuilder::new( + builder::SqlOperation::Select, + sql.table_name("users"), + sql.get_type(), + ) .into_app_result()?; - let mut builder = - builder::QueryBuilder::new(builder::SqlOperation::Select, sql.table_name("users"), sql.get_type()) - .into_app_result()?; builder .add_field("password_hash".to_string()) .into_app_result()? @@ -69,7 +62,7 @@ pub async fn token_system( ), ])); - let values = sql + let values = sql .get_db() .execute_query(&builder) .await @@ -79,14 +72,10 @@ pub async fn token_system( .first() .and_then(|row| row.get("password_hash")) .and_then(|val| val.as_str()) - .ok_or_else(|| { - status::Custom(Status::NotFound, "Invalid system user or password".into()) - })?; - - println!("{}\n{}",&data.password,password.clone()); + .ok_or_else(|| status::Custom(Status::NotFound, "系统用户或密码无效".into()))?; security::bcrypt::verify_hash(&data.password, password) - .map_err(|_| status::Custom(Status::Forbidden, "Invalid password".into()))?; + .map_err(|_| status::Custom(Status::Forbidden, "密码无效".into()))?; Ok(security::jwt::generate_jwt( security::jwt::CustomClaims { diff --git a/backend/src/api/settings.rs b/backend/src/api/settings.rs index 6df1363..0ca75ae 100644 --- a/backend/src/api/settings.rs +++ b/backend/src/api/settings.rs @@ -1,12 +1,10 @@ use super::SystemToken; -use crate::storage::{sql, sql::builder}; use crate::common::error::{AppResult, AppResultInto, CustomResult}; +use crate::storage::{sql, sql::builder}; use crate::AppState; -use rocket::data; use rocket::{ get, http::Status, - response::status, serde::json::{Json, Value}, State, }; @@ -51,8 +49,11 @@ pub async fn get_setting( let where_clause = builder::WhereClause::Condition(name_condition); - let mut sql_builder = - builder::QueryBuilder::new(builder::SqlOperation::Select, sql.table_name("settings"),sql.get_type())?; + let mut sql_builder = builder::QueryBuilder::new( + builder::SqlOperation::Select, + sql.table_name("settings"), + sql.get_type(), + )?; sql_builder .add_condition(where_clause) @@ -69,8 +70,11 @@ pub async fn insert_setting( name: String, data: Json, ) -> CustomResult<()> { - let mut builder = - builder::QueryBuilder::new(builder::SqlOperation::Insert, sql.table_name("settings"),sql.get_type())?; + let mut builder = builder::QueryBuilder::new( + builder::SqlOperation::Insert, + sql.table_name("settings"), + sql.get_type(), + )?; builder.set_value( "name".to_string(), builder::SafeValue::Text( @@ -80,7 +84,7 @@ pub async fn insert_setting( )?; builder.set_value( "data".to_string(), - builder::SafeValue::Text(data.to_string(),builder::ValidationLevel::Relaxed), + builder::SafeValue::Text(data.to_string(), builder::ValidationLevel::Relaxed), )?; sql.get_db().execute_query(&builder).await?; Ok(()) @@ -92,7 +96,7 @@ pub async fn system_config_get( _token: SystemToken, ) -> AppResult> { let sql = state.sql_get().await.into_app_result()?; - let settings = get_setting(&sql, "system".to_string(), sql.table_name("settings")) + let settings = get_setting(&sql, "system".to_string(), sql.table_name("settings")) .await .into_app_result()?; Ok(settings) @@ -109,4 +113,4 @@ pub async fn theme_config_get( .await .into_app_result()?; Ok(settings) -} \ No newline at end of file +} diff --git a/backend/src/api/setup.rs b/backend/src/api/setup.rs index ac1b25c..21a78d2 100644 --- a/backend/src/api/setup.rs +++ b/backend/src/api/setup.rs @@ -6,7 +6,6 @@ use crate::security; use crate::storage::sql; use crate::AppState; use chrono::Duration; -use rocket::data; use rocket::{http::Status, post, response::status, serde::json::Json, State}; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -21,7 +20,7 @@ pub async fn setup_sql( if config.init.sql { return Err(status::Custom( Status::BadRequest, - "Database already initialized".to_string(), + "数据库已经初始化".to_string(), )); } @@ -34,7 +33,7 @@ pub async fn setup_sql( .into_app_result()?; config::Config::write(config).into_app_result()?; - state.trigger_restart().await.into_app_result()?; + state.restart_server().await.into_app_result()?; Ok("Database installation successful".to_string()) } @@ -52,17 +51,16 @@ pub struct InstallReplyData { password: String, } - #[post("/administrator", format = "application/json", data = "")] pub async fn setup_account( data: Json, state: &State>, ) -> AppResult>> { let mut config = config::Config::read().unwrap_or_default(); - if config.init.administrator { + if config.init.administrator { return Err(status::Custom( Status::BadRequest, - "Administrator user has been set".to_string(), + "管理员用户已设置".to_string(), )); } @@ -123,9 +121,9 @@ pub async fn setup_account( Duration::days(7), ) .into_app_result()?; - config.init.administrator=true; + config.init.administrator = true; config::Config::write(config).into_app_result()?; - state.trigger_restart().await.into_app_result()?; + state.restart_server().await.into_app_result()?; Ok(status::Custom( Status::Ok, diff --git a/backend/src/api/users.rs b/backend/src/api/users.rs index 1b0d74b..8bf71cf 100644 --- a/backend/src/api/users.rs +++ b/backend/src/api/users.rs @@ -1,10 +1,8 @@ -use crate::security; +use crate::common::error::{CustomErrorInto, CustomResult}; use crate::security::bcrypt; use crate::storage::{sql, sql::builder}; -use crate::common::error::{CustomErrorInto, CustomResult}; use rocket::{get, http::Status, post, response::status, serde::json::Json, State}; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; #[derive(Deserialize, Serialize)] pub struct LoginData { @@ -20,16 +18,23 @@ pub struct RegisterData { pub role: String, } -pub async fn insert_user(sql: &sql::Database , data: RegisterData) -> CustomResult<()> { +pub async fn insert_user(sql: &sql::Database, data: RegisterData) -> CustomResult<()> { let role = match data.role.as_str() { "administrator" | "contributor" => data.role, - _ => return Err("Invalid role. Must be either 'administrator' or 'contributor'".into_custom_error()), + _ => { + return Err( + "Invalid role. Must be either 'administrator' or 'contributor'".into_custom_error(), + ) + } }; let password_hash = bcrypt::generate_hash(&data.password)?; - let mut builder = - builder::QueryBuilder::new(builder::SqlOperation::Insert, sql.table_name("users"), sql.get_type())?; + let mut builder = builder::QueryBuilder::new( + builder::SqlOperation::Insert, + sql.table_name("users"), + sql.get_type(), + )?; builder .set_value( "username".to_string(), diff --git a/backend/src/common/config.rs b/backend/src/common/config.rs index c55bc22..ab1f15a 100644 --- a/backend/src/common/config.rs +++ b/backend/src/common/config.rs @@ -42,19 +42,19 @@ impl Default for Init { #[derive(Deserialize, Serialize, Debug, Clone)] pub struct SqlConfig { pub db_type: String, - pub address: String, + pub host: String, pub port: u32, pub user: String, pub password: String, pub db_name: String, - pub db_prefix:String, + pub db_prefix: String, } impl Default for SqlConfig { fn default() -> Self { Self { db_type: "sqllite".to_string(), - address: "".to_string(), + host: "".to_string(), port: 0, user: "".to_string(), password: "".to_string(), @@ -67,7 +67,7 @@ impl Default for SqlConfig { #[derive(Deserialize, Serialize, Debug, Clone)] pub struct NoSqlConfig { pub db_type: String, - pub address: String, + pub host: String, pub port: u32, pub user: String, pub password: String, @@ -78,7 +78,7 @@ impl Default for NoSqlConfig { fn default() -> Self { Self { db_type: "postgresql".to_string(), - address: "localhost".to_string(), + host: "localhost".to_string(), port: 5432, user: "postgres".to_string(), password: "postgres".to_string(), diff --git a/backend/src/common/mod.rs b/backend/src/common/mod.rs index d8458fe..562b5dd 100644 --- a/backend/src/common/mod.rs +++ b/backend/src/common/mod.rs @@ -1,3 +1,3 @@ +pub mod config; pub mod error; pub mod helpers; -pub mod config; diff --git a/backend/src/main.rs b/backend/src/main.rs index 8c09594..76bb1f6 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -1,7 +1,7 @@ -mod security; -mod common; -mod storage; mod api; +mod common; +mod security; +mod storage; use crate::common::config; use common::error::{CustomErrorInto, CustomResult}; @@ -13,6 +13,7 @@ pub struct AppState { db: Arc>>, shutdown: Arc>>, restart_progress: Arc>, + restart_attempts: Arc>, } impl AppState { @@ -21,11 +22,16 @@ impl AppState { db: Arc::new(Mutex::new(None)), shutdown: Arc::new(Mutex::new(None)), restart_progress: Arc::new(Mutex::new(false)), + restart_attempts: Arc::new(Mutex::new(0)), } } pub async fn sql_get(&self) -> CustomResult { - self.db.lock().await.clone().ok_or_else(|| "数据库未连接".into_custom_error()) + self.db + .lock() + .await + .clone() + .ok_or_else(|| "数据库未连接".into_custom_error()) } pub async fn sql_link(&self, config: &config::SqlConfig) -> CustomResult<()> { @@ -33,14 +39,40 @@ impl AppState { Ok(()) } - pub async fn set_shutdown(&self, shutdown: Shutdown) { *self.shutdown.lock().await = Some(shutdown); } pub async fn trigger_restart(&self) -> CustomResult<()> { *self.restart_progress.lock().await = true; - self.shutdown.lock().await.take().ok_or_else(|| "未能获取rocket的shutdown".into_custom_error())?.notify(); + self.shutdown + .lock() + .await + .take() + .ok_or_else(|| "未能获取rocket的shutdown".into_custom_error())? + .notify(); + Ok(()) + } + + pub async fn restart_server(&self) -> CustomResult<()> { + const MAX_RESTART_ATTEMPTS: u32 = 3; + const RESTART_DELAY_MS: u64 = 1000; + + let mut attempts = self.restart_attempts.lock().await; + if *attempts >= MAX_RESTART_ATTEMPTS { + return Err("达到最大重启尝试次数".into_custom_error()); + } + *attempts += 1; + + *self.restart_progress.lock().await = true; + + self.shutdown + .lock() + .await + .take() + .ok_or_else(|| "未能获取rocket的shutdown".into_custom_error())? + .notify(); + Ok(()) } } @@ -53,9 +85,13 @@ async fn main() -> CustomResult<()> { }); let state = Arc::new(AppState::new()); - let rocket_config = rocket::Config::figment().merge(("address", config.address)).merge(("port", config.port)); + let rocket_config = rocket::Config::figment() + .merge(("address", config.address)) + .merge(("port", config.port)); - let mut rocket_builder = rocket::build().configure(rocket_config).manage(state.clone()); + let mut rocket_builder = rocket::build() + .configure(rocket_config) + .manage(state.clone()); if !config.init.sql { rocket_builder = rocket_builder.mount("/", rocket::routes![api::setup::setup_sql]); @@ -63,24 +99,47 @@ async fn main() -> CustomResult<()> { rocket_builder = rocket_builder.mount("/", rocket::routes![api::setup::setup_account]); } else { state.sql_link(&config.sql_config).await?; - rocket_builder = rocket_builder.mount("/auth/token", api::jwt_routes()).mount("/config", api::configure_routes()); + rocket_builder = rocket_builder + .mount("/auth/token", api::jwt_routes()) + .mount("/config", api::configure_routes()); } let rocket = rocket_builder.ignite().await?; - rocket.state::>().ok_or_else(|| "未能获取AppState".into_custom_error())?.set_shutdown(rocket.shutdown()).await; + rocket + .state::>() + .ok_or_else(|| "无法获取AppState".into_custom_error())? + .set_shutdown(rocket.shutdown()) + .await; rocket.launch().await?; if *state.restart_progress.lock().await { + tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; + if let Ok(current_exe) = std::env::current_exe() { - match std::process::Command::new(current_exe).spawn() { - Ok(_) => println!("成功启动新进程"), - Err(e) => eprintln!("启动新进程失败: {}", e), + println!("正在尝试重启服务器..."); + + let mut command = std::process::Command::new(current_exe); + command.env("RUST_BACKTRACE", "1"); + + match command.spawn() { + Ok(child) => { + println!("成功启动新进程 (PID: {})", child.id()); + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + } + Err(e) => { + eprintln!("启动新进程失败: {}", e); + *state.restart_progress.lock().await = false; + return Err(format!("重启失败: {}", e).into_custom_error()); + } }; } else { eprintln!("获取当前可执行文件路径失败"); + return Err("重启失败: 无法获取可执行文件路径".into_custom_error()); } } + + println!("服务器正常退出"); std::process::exit(0); } diff --git a/backend/src/storage/mod.rs b/backend/src/storage/mod.rs index 07c4d2f..2752f63 100644 --- a/backend/src/storage/mod.rs +++ b/backend/src/storage/mod.rs @@ -1 +1 @@ -pub mod sql; \ No newline at end of file +pub mod sql; diff --git a/backend/src/storage/sql/builder.rs b/backend/src/storage/sql/builder.rs index aedc3e0..0c72c49 100644 --- a/backend/src/storage/sql/builder.rs +++ b/backend/src/storage/sql/builder.rs @@ -1,17 +1,17 @@ +use super::DatabaseType; use crate::common::error::{CustomErrorInto, CustomResult}; use chrono::{DateTime, Utc}; use regex::Regex; use serde::Serialize; use std::collections::HashMap; use std::hash::Hash; -use crate::sql::schema::DatabaseType; #[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, Serialize)] pub enum ValidationLevel { Strict, Standard, Relaxed, - Raw, + Raw, } #[derive(Debug, Clone)] @@ -82,10 +82,10 @@ impl TextValidator { let max_length = self .level_max_lengths .get(&level) - .ok_or("Invalid validation level".into_custom_error())?; + .ok_or("无效的验证级别".into_custom_error())?; if text.len() > *max_length { - return Err("Text exceeds maximum length".into_custom_error()); + return Err("文本超出最大长度限制".into_custom_error()); } if level == ValidationLevel::Relaxed { @@ -103,7 +103,7 @@ impl TextValidator { .iter() .any(|&pattern| upper_text.contains(&pattern.to_uppercase())) { - return Err("Potentially dangerous SQL pattern detected".into_custom_error()); + return Err("检测到潜在危险的SQL模式".into_custom_error()); } Ok(()) } @@ -112,14 +112,14 @@ impl TextValidator { let allowed_chars = self .level_allowed_chars .get(&level) - .ok_or_else(|| "Invalid validation level".into_custom_error())?; + .ok_or_else(|| "无效的验证级别".into_custom_error())?; if let Some(invalid_char) = text .chars() .find(|&c| !c.is_alphanumeric() && !allowed_chars.contains(&c)) { return Err( - format!("Invalid character '{}' for {:?} level", invalid_char, level) + format!("'{}'字符在{:?}验证级别中是无效的", invalid_char, level) .into_custom_error(), ); } @@ -128,7 +128,7 @@ impl TextValidator { fn validate_special_chars(&self, text: &str) -> CustomResult<()> { if self.special_chars.iter().any(|&c| text.contains(c)) { - return Err("Invalid special character detected".into_custom_error()); + return Err("检测到无效的特殊字符".into_custom_error()); } Ok(()) } @@ -179,7 +179,6 @@ impl std::fmt::Display for SafeValue { } impl SafeValue { - fn get_sql_type(&self) -> CustomResult { let sql_type = match self { SafeValue::Null => "NULL", @@ -192,7 +191,7 @@ impl SafeValue { Ok(sql_type.to_string()) } - pub fn to_sql_string(&self) -> CustomResult { + pub fn to_string(&self) -> CustomResult { match self { SafeValue::Null => Ok("NULL".to_string()), SafeValue::Bool(b) => Ok(if *b { "true" } else { "false" }.to_string()), @@ -225,9 +224,9 @@ pub struct Identifier(String); impl Identifier { pub fn new(value: String) -> CustomResult { - let valid_pattern = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_\.]{0,63}$")?; + let valid_pattern = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_.]{0,63}$")?; if !valid_pattern.is_match(&value) { - return Err("Invalid identifier format".into_custom_error()); + return Err("标识符格式无效".into_custom_error()); } Ok(Identifier(value)) } @@ -314,7 +313,11 @@ pub struct QueryBuilder { } impl QueryBuilder { - pub fn new(operation: SqlOperation, table: String, db_type: DatabaseType) -> CustomResult { + pub fn new( + operation: SqlOperation, + table: String, + db_type: DatabaseType, + ) -> CustomResult { Ok(QueryBuilder { operation, table: Identifier::new(table)?, @@ -413,7 +416,7 @@ impl QueryBuilder { } else { value.to_param_sql(params.len() + 1, self.db_type)? }; - + let set_sql = format!("{} = {}", field.as_str(), placeholder); if !matches!(value, SafeValue::Null) { params.push(value.clone()); @@ -499,7 +502,7 @@ impl QueryBuilder { } else { value.to_param_sql(param_index, self.db_type)? }; - + let sql = format!( "{} {} {}", condition.field.as_str(), diff --git a/backend/src/storage/sql/mod.rs b/backend/src/storage/sql/mod.rs index db3df33..d09ff9c 100644 --- a/backend/src/storage/sql/mod.rs +++ b/backend/src/storage/sql/mod.rs @@ -1,16 +1,32 @@ -mod postgresql; -mod mysql; -mod sqllite; pub mod builder; +mod mysql; +mod postgresql; mod schema; +mod sqllite; -use crate::config; use crate::common::error::{CustomErrorInto, CustomResult}; +use crate::config; use async_trait::async_trait; use std::{collections::HashMap, sync::Arc}; -use schema::DatabaseType; -#[async_trait] +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum DatabaseType { + PostgreSQL, + MySQL, + SQLite, +} + +impl std::fmt::Display for DatabaseType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + DatabaseType::PostgreSQL => write!(f, "postgresql"), + DatabaseType::MySQL => write!(f, "mysql"), + DatabaseType::SQLite => write!(f, "sqlite"), + } + } +} + +#[async_trait] pub trait DatabaseTrait: Send + Sync { async fn connect(database: &config::SqlConfig) -> CustomResult where @@ -28,7 +44,7 @@ pub trait DatabaseTrait: Send + Sync { pub struct Database { pub db: Arc>, pub prefix: Arc, - pub db_type: Arc + pub db_type: Arc, } impl Database { @@ -40,20 +56,16 @@ impl Database { &self.prefix } - pub fn get_type(&self) -> DatabaseType { - match self.db_type.as_str() { - "postgresql" => DatabaseType::PostgreSQL, - "mysql" => DatabaseType::MySQL, - _ => DatabaseType::SQLite, - } - } - pub fn table_name(&self, name: &str) -> String { format!("{}{}", self.prefix, name) } + pub fn get_type(&self) -> DatabaseType { + *self.db_type.clone() + } + pub async fn link(database: &config::SqlConfig) -> CustomResult { - let db: Box = match database.db_type.as_str() { + let db: Box = match database.db_type.to_lowercase().as_str() { "postgresql" => Box::new(postgresql::Postgresql::connect(database).await?), "mysql" => Box::new(mysql::Mysql::connect(database).await?), "sqllite" => Box::new(sqllite::Sqlite::connect(database).await?), @@ -63,12 +75,17 @@ impl Database { Ok(Self { db: Arc::new(db), prefix: Arc::new(database.db_prefix.clone()), - db_type: Arc::new(database.db_type.clone()) + db_type: Arc::new(match database.db_type.to_lowercase().as_str() { + "postgresql" => DatabaseType::PostgreSQL, + "mysql" => DatabaseType::MySQL, + "sqllite" => DatabaseType::SQLite, + _ => return Err("unknown database type".into_custom_error()), + }), }) } pub async fn initial_setup(database: config::SqlConfig) -> CustomResult<()> { - match database.db_type.as_str() { + match database.db_type.to_lowercase().as_str() { "postgresql" => postgresql::Postgresql::initialization(database).await?, "mysql" => mysql::Mysql::initialization(database).await?, "sqllite" => sqllite::Sqlite::initialization(database).await?, @@ -76,5 +93,4 @@ impl Database { }; Ok(()) } - } diff --git a/backend/src/storage/sql/mysql.rs b/backend/src/storage/sql/mysql.rs index 31e590a..b6087bb 100644 --- a/backend/src/storage/sql/mysql.rs +++ b/backend/src/storage/sql/mysql.rs @@ -1,12 +1,14 @@ -use super::{builder::{self, SafeValue}, schema, DatabaseTrait}; -use crate::config; +use super::{ + builder::{self, SafeValue}, + schema, DatabaseTrait, +}; use crate::common::error::CustomResult; +use crate::config; use async_trait::async_trait; use serde_json::Value; use sqlx::mysql::MySqlPool; use sqlx::{Column, Executor, Row, TypeInfo}; use std::collections::HashMap; -use chrono::{DateTime, Utc}; #[derive(Clone)] pub struct Mysql { @@ -15,51 +17,16 @@ pub struct Mysql { #[async_trait] impl DatabaseTrait for Mysql { - async fn initialization(db_config: config::SqlConfig) -> CustomResult<()> { - let db_prefix = SafeValue::Text(format!("{}",db_config.db_prefix), builder::ValidationLevel::Strict); - let grammar = schema::generate_schema(schema::DatabaseType::MySQL,db_prefix)?; - let connection_str = format!( - "mysql://{}:{}@{}:{}", - db_config.user, db_config.password, db_config.address, db_config.port - ); - - let pool = MySqlPool::connect(&connection_str).await?; - - pool.execute(format!("CREATE DATABASE `{}`", db_config.db_name).as_str()).await?; - pool.execute(format!( - "ALTER DATABASE `{}` DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci", - db_config.db_name - ).as_str()).await?; - - - let new_connection_str = format!( - "mysql://{}:{}@{}:{}/{}", - db_config.user, - db_config.password, - db_config.address, - db_config.port, - db_config.db_name - ); - let new_pool = MySqlPool::connect(&new_connection_str).await?; - - new_pool.execute(grammar.as_str()).await?; - Ok(()) - } async fn connect(db_config: &config::SqlConfig) -> CustomResult { let connection_str = format!( "mysql://{}:{}@{}:{}/{}", - db_config.user, - db_config.password, - db_config.address, - db_config.port, - db_config.db_name + db_config.user, db_config.password, db_config.host, db_config.port, db_config.db_name ); let pool = MySqlPool::connect(&connection_str).await?; Ok(Mysql { pool }) } - async fn execute_query<'a>( &'a self, builder: &builder::QueryBuilder, @@ -107,4 +74,38 @@ impl DatabaseTrait for Mysql { .collect()) } -} \ No newline at end of file + async fn initialization(db_config: config::SqlConfig) -> CustomResult<()> { + let db_prefix = SafeValue::Text( + format!("{}", db_config.db_prefix), + builder::ValidationLevel::Strict, + ); + let grammar = schema::generate_schema(super::DatabaseType::MySQL, db_prefix)?; + + let connection_str = format!( + "mysql://{}:{}@{}:{}", + db_config.user, db_config.password, db_config.host, db_config.port + ); + + let pool = MySqlPool::connect(&connection_str).await?; + + pool.execute(format!("CREATE DATABASE `{}`", db_config.db_name).as_str()) + .await?; + pool.execute( + format!( + "ALTER DATABASE `{}` DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci", + db_config.db_name + ) + .as_str(), + ) + .await?; + + let new_connection_str = format!( + "mysql://{}:{}@{}:{}/{}", + db_config.user, db_config.password, db_config.host, db_config.port, db_config.db_name + ); + let new_pool = MySqlPool::connect(&new_connection_str).await?; + + new_pool.execute(grammar.as_str()).await?; + Ok(()) + } +} diff --git a/backend/src/storage/sql/postgresql.rs b/backend/src/storage/sql/postgresql.rs index 64100f8..5b96f94 100644 --- a/backend/src/storage/sql/postgresql.rs +++ b/backend/src/storage/sql/postgresql.rs @@ -2,10 +2,9 @@ use super::{ builder::{self, SafeValue}, schema, DatabaseTrait, }; -use crate::common::error::{CustomError, CustomErrorInto, CustomResult}; +use crate::common::error::CustomResult; use crate::config; use async_trait::async_trait; -use chrono::{DateTime, Utc}; use serde_json::Value; use sqlx::{Column, Executor, PgPool, Row, TypeInfo}; use std::collections::HashMap; @@ -17,45 +16,10 @@ pub struct Postgresql { #[async_trait] impl DatabaseTrait for Postgresql { - async fn initialization(db_config: config::SqlConfig) -> CustomResult<()> { - let db_prefix = SafeValue::Text( - format!("{}", db_config.db_prefix), - builder::ValidationLevel::Strict, - ); - let grammar = schema::generate_schema(schema::DatabaseType::PostgreSQL, db_prefix)?; - - let connection_str = format!( - "postgres://{}:{}@{}:{}", - db_config.user, db_config.password, db_config.address, db_config.port - ); - let pool = PgPool::connect(&connection_str).await?; - - pool.execute(format!("CREATE DATABASE {}", db_config.db_name).as_str()) - .await?; - - let new_connection_str = format!( - "postgres://{}:{}@{}:{}/{}", - db_config.user, - db_config.password, - db_config.address, - db_config.port, - db_config.db_name - ); - let new_pool = PgPool::connect(&new_connection_str).await?; - - new_pool.execute(grammar.as_str()).await?; - - Ok(()) - } - async fn connect(db_config: &config::SqlConfig) -> CustomResult { let connection_str = format!( "postgres://{}:{}@{}:{}/{}", - db_config.user, - db_config.password, - db_config.address, - db_config.port, - db_config.db_name + db_config.user, db_config.password, db_config.host, db_config.port, db_config.db_name ); let pool = PgPool::connect(&connection_str).await?; @@ -109,10 +73,31 @@ impl DatabaseTrait for Postgresql { }) .collect()) } -} -impl Postgresql { - fn get_sdb(&self){ - let a=self.pool; + async fn initialization(db_config: config::SqlConfig) -> CustomResult<()> { + let db_prefix = SafeValue::Text( + format!("{}", db_config.db_prefix), + builder::ValidationLevel::Strict, + ); + let grammar = schema::generate_schema(super::DatabaseType::PostgreSQL, db_prefix)?; + + let connection_str = format!( + "postgres://{}:{}@{}:{}", + db_config.user, db_config.password, db_config.host, db_config.port + ); + let pool = PgPool::connect(&connection_str).await?; + + pool.execute(format!("CREATE DATABASE {}", db_config.db_name).as_str()) + .await?; + + let new_connection_str = format!( + "postgres://{}:{}@{}:{}/{}", + db_config.user, db_config.password, db_config.host, db_config.port, db_config.db_name + ); + let new_pool = PgPool::connect(&new_connection_str).await?; + + new_pool.execute(grammar.as_str()).await?; + + Ok(()) } -} \ No newline at end of file +} diff --git a/backend/src/storage/sql/schema.rs b/backend/src/storage/sql/schema.rs index f3145b6..4a2c376 100644 --- a/backend/src/storage/sql/schema.rs +++ b/backend/src/storage/sql/schema.rs @@ -1,13 +1,7 @@ use super::builder::{Condition, Identifier, Operator, SafeValue, ValidationLevel, WhereClause}; +use super::DatabaseType; use crate::common::error::{CustomErrorInto, CustomResult}; -use std::{collections::HashMap, fmt::format}; - -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum DatabaseType { - PostgreSQL, - MySQL, - SQLite, -} +use std::fmt::Display; #[derive(Debug, Clone, PartialEq)] pub enum FieldType { @@ -46,16 +40,17 @@ pub struct ForeignKey { pub on_update: Option, } -impl ToString for ForeignKeyAction { - fn to_string(&self) -> String { - match self { +impl Display for ForeignKeyAction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let str = match self { ForeignKeyAction::Cascade => "CASCADE", ForeignKeyAction::Restrict => "RESTRICT", ForeignKeyAction::SetNull => "SET NULL", ForeignKeyAction::NoAction => "NO ACTION", ForeignKeyAction::SetDefault => "SET DEFAULT", } - .to_string() + .to_string(); + write!(f, "{}", str) } } @@ -216,7 +211,7 @@ impl Field { "{} {} {}", field_name, condition.operator.as_str(), - value.to_sql_string()? + value.to_string()? )) } else { Err("Missing value for comparison".into_custom_error()) @@ -256,7 +251,7 @@ impl Field { } } if let Some(default) = &self.constraints.default_value { - sql.push_str(&format!(" DEFAULT {}", default.to_sql_string()?)); + sql.push_str(&format!(" DEFAULT {}", default.to_string()?)); } if let Some(check) = &self.constraints.check_constraint { let check_sql = Self::build_check_constraint(check)?; @@ -332,7 +327,7 @@ impl Index { }) } - fn to_sql(&self, table_name: &str, db_type: DatabaseType) -> CustomResult { + fn to_sql(&self, table_name: &str, _db_type: DatabaseType) -> CustomResult { let unique = if self.is_unique { "UNIQUE " } else { "" }; Ok(format!( "CREATE {}INDEX {} ON {} ({});", @@ -356,9 +351,7 @@ pub struct SchemaBuilder { impl SchemaBuilder { pub fn new() -> Self { - Self { - tables: Vec::new(), - } + Self { tables: Vec::new() } } pub fn add_table(&mut self, table: Table) -> CustomResult<&mut Self> { @@ -376,14 +369,14 @@ impl SchemaBuilder { } } -pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResult { - let db_prefix=db_prefix.to_sql_string()?; +pub fn generate_schema(db_type: DatabaseType, db_prefix: SafeValue) -> CustomResult { + let db_prefix = db_prefix.to_string()?; let mut schema = SchemaBuilder::new(); let user_level = "('contributor', 'administrator')"; let content_state = "('draft', 'published', 'private', 'hidden')"; // 用户表 - let mut users_table = Table::new(&format!("{}users",db_prefix))?; + let mut users_table = Table::new(&format!("{}users", db_prefix))?; users_table .add_field(Field::new( "username", @@ -423,7 +416,10 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul .check(WhereClause::Condition(Condition::new( "role".to_string(), Operator::In, - Some(SafeValue::Text(user_level.to_string(), ValidationLevel::Relaxed)), + Some(SafeValue::Text( + user_level.to_string(), + ValidationLevel::Relaxed, + )), )?)), ValidationLevel::Strict, )?) @@ -458,8 +454,8 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul schema.add_table(users_table)?; // 独立页面表 - - let mut pages_table = Table::new(&format!("{}pages",db_prefix))?; + + let mut pages_table = Table::new(&format!("{}pages", db_prefix))?; pages_table .add_field(Field::new( "id", @@ -511,16 +507,18 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul .check(WhereClause::Condition(Condition::new( "status".to_string(), Operator::In, - Some(SafeValue::Text(content_state.to_string(), ValidationLevel::Standard)), + Some(SafeValue::Text( + content_state.to_string(), + ValidationLevel::Standard, + )), )?)), ValidationLevel::Strict, )?); schema.add_table(pages_table)?; - // posts 表 - let mut posts_table = Table::new(&format!("{}posts",db_prefix))?; + let mut posts_table = Table::new(&format!("{}posts", db_prefix))?; posts_table .add_field(Field::new( "id", @@ -533,7 +531,7 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul FieldType::VarChar(100), FieldConstraint::new() .not_null() - .foreign_key(format!("{}users",db_prefix), "username".to_string()) + .foreign_key(format!("{}users", db_prefix), "username".to_string()) .on_delete(ForeignKeyAction::Cascade) .on_update(ForeignKeyAction::Cascade), ValidationLevel::Strict, @@ -576,7 +574,10 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul .check(WhereClause::Condition(Condition::new( "status".to_string(), Operator::In, - Some(SafeValue::Text(content_state.to_string(), ValidationLevel::Standard)), + Some(SafeValue::Text( + content_state.to_string(), + ValidationLevel::Standard, + )), )?)), ValidationLevel::Strict, )?) @@ -622,7 +623,7 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul schema.add_table(posts_table)?; // 标签表 - let mut tags_tables = Table::new(&format!("{}tags",db_prefix))?; + let mut tags_tables = Table::new(&format!("{}tags", db_prefix))?; tags_tables .add_field(Field::new( "name", @@ -636,28 +637,28 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul FieldConstraint::new(), ValidationLevel::Strict, )?); - + schema.add_table(tags_tables)?; - // 文章标签 - let mut post_tags_tables = Table::new(&format!("{}post_tags",db_prefix))?; + let mut post_tags_tables = Table::new(&format!("{}post_tags", db_prefix))?; post_tags_tables .add_field(Field::new( "post_id", FieldType::Integer(false), FieldConstraint::new() .not_null() - .foreign_key(format!("{}posts",db_prefix), "id".to_string()) + .foreign_key(format!("{}posts", db_prefix), "id".to_string()) .on_delete(ForeignKeyAction::Cascade) .on_update(ForeignKeyAction::Cascade), ValidationLevel::Strict, - )?).add_field(Field::new( + )?) + .add_field(Field::new( "tag_id", FieldType::VarChar(50), FieldConstraint::new() .not_null() - .foreign_key(format!("{}tags",db_prefix), "name".to_string()) + .foreign_key(format!("{}tags", db_prefix), "name".to_string()) .on_delete(ForeignKeyAction::Cascade) .on_update(ForeignKeyAction::Cascade), ValidationLevel::Strict, @@ -672,8 +673,8 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul schema.add_table(post_tags_tables)?; // 分类表 - - let mut categories_table = Table::new(&format!("{}categories",db_prefix))?; + + let mut categories_table = Table::new(&format!("{}categories", db_prefix))?; categories_table .add_field(Field::new( "name", @@ -685,21 +686,21 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul "parent_id", FieldType::VarChar(50), FieldConstraint::new() - .foreign_key(format!("{}categories",db_prefix), "name".to_string()), + .foreign_key(format!("{}categories", db_prefix), "name".to_string()), ValidationLevel::Strict, )?); schema.add_table(categories_table)?; // 文章分类关联表 - let mut post_categories_table = Table::new(&format!("{}post_categories",db_prefix))?; + let mut post_categories_table = Table::new(&format!("{}post_categories", db_prefix))?; post_categories_table .add_field(Field::new( "post_id", FieldType::Integer(false), FieldConstraint::new() .not_null() - .foreign_key(format!("{}posts",db_prefix), "id".to_string()) + .foreign_key(format!("{}posts", db_prefix), "id".to_string()) .on_delete(ForeignKeyAction::Cascade) .on_update(ForeignKeyAction::Cascade), ValidationLevel::Strict, @@ -709,7 +710,7 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul FieldType::VarChar(50), FieldConstraint::new() .not_null() - .foreign_key(format!("{}categories",db_prefix), "name".to_string()) + .foreign_key(format!("{}categories", db_prefix), "name".to_string()) .on_delete(ForeignKeyAction::Cascade) .on_update(ForeignKeyAction::Cascade), ValidationLevel::Strict, @@ -724,7 +725,7 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul schema.add_table(post_categories_table)?; // 资源库表 - let mut resources_table = Table::new(&format!("{}resources",db_prefix))?; + let mut resources_table = Table::new(&format!("{}resources", db_prefix))?; resources_table .add_field(Field::new( "id", @@ -737,7 +738,7 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul FieldType::VarChar(100), FieldConstraint::new() .not_null() - .foreign_key(format!("{}users",db_prefix), "username".to_string()) + .foreign_key(format!("{}users", db_prefix), "username".to_string()) .on_delete(ForeignKeyAction::Cascade) .on_update(ForeignKeyAction::Cascade), ValidationLevel::Strict, @@ -791,7 +792,7 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul schema.add_table(resources_table)?; // 配置表 - let mut settings_table = Table::new(&format!("{}settings",db_prefix))?; + let mut settings_table = Table::new(&format!("{}settings", db_prefix))?; settings_table .add_field(Field::new( "name", @@ -806,7 +807,7 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul ValidationLevel::Strict, )?); - schema.add_table(settings_table)?; + schema.add_table(settings_table)?; schema.build(db_type) } diff --git a/backend/src/storage/sql/sqllite.rs b/backend/src/storage/sql/sqllite.rs index cd4d3e9..4d06e5d 100644 --- a/backend/src/storage/sql/sqllite.rs +++ b/backend/src/storage/sql/sqllite.rs @@ -2,12 +2,11 @@ use super::{ builder::{self, SafeValue}, schema, DatabaseTrait, }; -use crate::common::error::{CustomError, CustomErrorInto, CustomResult}; +use crate::common::error::{CustomErrorInto, CustomResult}; use crate::config; use async_trait::async_trait; -use chrono::{DateTime, Utc}; use serde_json::Value; -use sqlx::{Column, Executor, SqlitePool, Row, TypeInfo}; +use sqlx::{Column, Executor, Row, SqlitePool, TypeInfo}; use std::collections::HashMap; use std::env; @@ -18,40 +17,19 @@ pub struct Sqlite { #[async_trait] impl DatabaseTrait for Sqlite { - async fn initialization(db_config: config::SqlConfig) -> CustomResult<()> { - let db_prefix = SafeValue::Text( - format!("{}", db_config.db_prefix), - builder::ValidationLevel::Strict, - ); - - let sqlite_dir = env::current_dir()?.join("assets").join("sqllite"); - std::fs::create_dir_all(&sqlite_dir)?; - - let db_file = sqlite_dir.join(&db_config.db_name); - std::fs::File::create(&db_file)?; - - let path = db_file.to_str().ok_or("Unable to get sqllite path".into_custom_error())?; - let grammar = schema::generate_schema(schema::DatabaseType::SQLite, db_prefix)?; - - let connection_str = format!("sqlite:///{}", path); - let pool = SqlitePool::connect(&connection_str).await?; - - pool.execute(grammar.as_str()).await?; - - Ok(()) - } - async fn connect(db_config: &config::SqlConfig) -> CustomResult { let db_file = env::current_dir()? .join("assets") .join("sqllite") .join(&db_config.db_name); - + if !db_file.exists() { - return Err("SQLite database file does not exist".into_custom_error()); + return Err("SQLite数据库文件不存在".into_custom_error()); } - - let path = db_file.to_str().ok_or("Unable to get sqllite path".into_custom_error())?; + + let path = db_file + .to_str() + .ok_or("无法获取SQLite路径".into_custom_error())?; let connection_str = format!("sqlite:///{}", path); let pool = SqlitePool::connect(&connection_str).await?; @@ -104,4 +82,31 @@ impl DatabaseTrait for Sqlite { }) .collect()) } + + async fn initialization(db_config: config::SqlConfig) -> CustomResult<()> { + let db_prefix = SafeValue::Text( + format!("{}", db_config.db_prefix), + builder::ValidationLevel::Strict, + ); + + let sqlite_dir = env::current_dir()?.join("assets").join("sqllite"); + std::fs::create_dir_all(&sqlite_dir)?; + + let db_file = sqlite_dir.join(&db_config.db_name); + std::fs::File::create(&db_file)?; + + let path = db_file + .to_str() + .ok_or("Unable to get sqllite path".into_custom_error())?; + let grammar = schema::generate_schema(super::DatabaseType::SQLite, db_prefix)?; + + println!("\n{}\n", grammar); + + let connection_str = format!("sqlite:///{}", path); + let pool = SqlitePool::connect(&connection_str).await?; + + pool.execute(grammar.as_str()).await?; + + Ok(()) + } } diff --git a/frontend/app/env.d.ts b/frontend/app/env.d.ts deleted file mode 100644 index 0d45b87..0000000 --- a/frontend/app/env.d.ts +++ /dev/null @@ -1,18 +0,0 @@ -// File path: app/end.d.ts - -/** - * 配置 - */ - -/// - -interface ImportMetaEnv { - readonly VITE_INIT_STATUS: string; - readonly VITE_SERVER_API: string; - readonly VITE_PORT: string; - readonly VITE_ADDRESS: string; -} - -interface ImportMeta { - readonly env: ImportMetaEnv; -} diff --git a/frontend/app/env.ts b/frontend/app/env.ts new file mode 100644 index 0000000..d747d6f --- /dev/null +++ b/frontend/app/env.ts @@ -0,0 +1,25 @@ +export interface EnvConfig { + VITE_PORT: string; + VITE_ADDRESS: string; + VITE_INIT_STATUS: string; + VITE_API_BASE_URL: string; + VITE_API_USERNAME: string; + VITE_API_PASSWORD: string; +} + +export const DEFAULT_CONFIG: EnvConfig = { + VITE_PORT: "22100", + VITE_ADDRESS: "localhost", + VITE_INIT_STATUS: "0", + VITE_API_BASE_URL: "http://127.0.0.1:22000", + VITE_API_USERNAME: "", + VITE_API_PASSWORD: "", +} as const; + +// 扩展 ImportMeta 接口 +declare global { + interface ImportMetaEnv extends EnvConfig {} + interface ImportMeta { + readonly env: ImportMetaEnv; + } +} diff --git a/frontend/app/init.tsx b/frontend/app/init.tsx index 92a5194..b950465 100644 --- a/frontend/app/init.tsx +++ b/frontend/app/init.tsx @@ -1,5 +1,9 @@ import React, { createContext, useState, useEffect } from "react"; -import { useApi } from "hooks/servicesProvider"; +import {useHttp} from 'hooks/servicesProvider' +import { message} from "hooks/message"; +import {DEFAULT_CONFIG} from "app/env" + + interface SetupContextType { currentStep: number; setCurrentStep: (step: number) => void; @@ -13,10 +17,8 @@ const SetupContext = createContext({ // 步骤组件的通用属性接口 interface StepProps { onNext: () => void; - onPrev?: () => void; } -// 通用的步骤容器组件 const StepContainer: React.FC<{ title: string; children: React.ReactNode }> = ({ title, children, @@ -32,13 +34,22 @@ const StepContainer: React.FC<{ title: string; children: React.ReactNode }> = ({ ); // 通用的导航按钮组件 -const NavigationButtons: React.FC = ({ onNext }) => ( +const NavigationButtons: React.FC = ({ + onNext, + loading = false, + disabled = false +}) => (
); @@ -49,14 +60,16 @@ const InputField: React.FC<{ name: string; defaultValue?: string | number; hint?: string; -}> = ({ label, name, defaultValue, hint }) => ( + required?: boolean; +}> = ({ label, name, defaultValue, hint, required = true }) => (

- {label} + {label} {required && *}

{hint && ( @@ -80,6 +93,102 @@ const Introduction: React.FC = ({ onNext }) => ( const DatabaseConfig: React.FC = ({ onNext }) => { const [dbType, setDbType] = useState("postgresql"); + const [loading, setLoading] = useState(false); + const api = useHttp(); + + const validateForm = () => { + const getRequiredFields = () => { + switch (dbType) { + case 'sqllite': + return ['db_prefix', 'db_name']; + case 'postgresql': + case 'mysql': + return ['db_host', 'db_prefix', 'db_port', 'db_user', 'db_password', 'db_name']; + default: + return []; + } + }; + + const requiredFields = getRequiredFields(); + const emptyFields: string[] = []; + + requiredFields.forEach(field => { + const input = document.querySelector(`[name="${field}"]`) as HTMLInputElement; + if (input && (!input.value || input.value.trim() === '')) { + emptyFields.push(field); + } + }); + + if (emptyFields.length > 0) { + const fieldNames = emptyFields.map(field => { + switch (field) { + case 'db_host': return '数据库地址'; + case 'db_prefix': return '数据库前缀'; + case 'db_port': return '端口'; + case 'db_user': return '用户名'; + case 'db_password': return '密码'; + case 'db_name': return '数据库名'; + default: return field; + } + }); + message.error(`请填写以下必填项:${fieldNames.join('、')}`); + return false; + } + return true; + }; + + const handleNext = async () => { + if (!validateForm()) { + return; + } + + setLoading(true); + try { + const formData = { + db_type: dbType, + host: (document.querySelector('[name="db_host"]') as HTMLInputElement)?.value?.trim()??"", + db_prefix: (document.querySelector('[name="db_prefix"]') as HTMLInputElement)?.value?.trim()??"", + port: Number((document.querySelector('[name="db_port"]') as HTMLInputElement)?.value?.trim()??0), + user: (document.querySelector('[name="db_user"]') as HTMLInputElement)?.value?.trim()??"", + password: (document.querySelector('[name="db_password"]') as HTMLInputElement)?.value?.trim()??"", + db_name: (document.querySelector('[name="db_name"]') as HTMLInputElement)?.value?.trim()??"", + }; + + await api.post('/sql', formData); + + let oldEnv = import.meta.env?? DEFAULT_CONFIG + + + const viteEnv = Object.entries(oldEnv).reduce((acc, [key, value]) => { + if (key.startsWith('VITE_')) { + acc[key] = value; + } + return acc; + }, {} as Record); + + + const newEnv = { + ...viteEnv, + VITE_INIT_STATUS: '2' + }; + + + await api.dev("/env", { + method: "POST", + body: JSON.stringify(newEnv), + }); + + Object.assign( newEnv) + + message.success('数据库配置成功!'); + setTimeout(() => onNext(), 1000); + } catch (error: any) { + console.error( error); + message.error(error.message ); + } finally { + setLoading(false); + } + }; return ( @@ -105,34 +214,40 @@ const DatabaseConfig: React.FC = ({ onNext }) => { label="数据库地址" name="db_host" defaultValue="localhost" - hint="通常使用 localhost" + hint="通常使 localhost" + required /> )} @@ -143,33 +258,39 @@ const DatabaseConfig: React.FC = ({ onNext }) => { name="db_host" defaultValue="localhost" hint="通常使用 localhost" + required /> )} @@ -180,40 +301,117 @@ const DatabaseConfig: React.FC = ({ onNext }) => { name="db_prefix" defaultValue="echoec_" hint="通常使用 echoec_" + required /> )} - +
); }; -const AdminConfig: React.FC = ({ onNext }) => ( - -
- - - - -
-
-); -const SetupComplete: React.FC = () => ( - -
-

- 恭喜!安装已完成,系统即将重启... -

-
-
-); +interface InstallReplyData { + token: string, + username: string, + password: string, +} + + +const AdminConfig: React.FC = ({ onNext }) => { + const [loading, setLoading] = useState(false); + const api = useHttp(); + + const handleNext = async () => { + setLoading(true); + try { + const formData = { + username: (document.querySelector('[name="admin_username"]') as HTMLInputElement)?.value, + password: (document.querySelector('[name="admin_password"]') as HTMLInputElement)?.value, + email: (document.querySelector('[name="admin_email"]') as HTMLInputElement)?.value, + }; + + const response = await api.post('/administrator', formData) as InstallReplyData; + const data = response; + + localStorage.setItem('token', data.token); + + let oldEnv = import.meta.env ?? DEFAULT_CONFIG; + const viteEnv = Object.entries(oldEnv).reduce((acc, [key, value]) => { + if (key.startsWith('VITE_')) { + acc[key] = value; + } + return acc; + }, {} as Record); + + const newEnv = { + ...viteEnv, + VITE_INIT_STATUS: '3', + VITE_API_USERNAME:data.username, + VITE_API_PASSWORD:data.password + }; + + + + await api.dev("/env", { + method: "POST", + body: JSON.stringify(newEnv), + }); + + message.success('管理员账号创建成功!'); + onNext(); + } catch (error: any) { + console.error(error); + message.error(error.message); + } finally { + setLoading(false); + } + }; + + return ( + +
+ + + + +
+
+ ); +}; + +const SetupComplete: React.FC = () => { + const api = useHttp(); + + + + return ( + +
+

+ 恭喜!安装已完成 +

+

+ 系统正在重启中,请稍候... +

+
+
+
+
+
+ ); +}; // 修改主题切换按钮组件 const ThemeToggle: React.FC = () => { @@ -260,7 +458,7 @@ const ThemeToggle: React.FC = () => { }; export default function SetupPage() { - let step = Number(import.meta.env.VITE_INIT_STATUS); + let step = Number(import.meta.env.VITE_INIT_STATUS)+1; const [currentStep, setCurrentStep] = useState(step); diff --git a/frontend/app/root.tsx b/frontend/app/root.tsx index 1e2ff67..ae76c45 100644 --- a/frontend/app/root.tsx +++ b/frontend/app/root.tsx @@ -7,6 +7,8 @@ import { } from "@remix-run/react"; import { BaseProvider } from "hooks/servicesProvider"; +import { MessageProvider } from "hooks/message"; +import { MessageContainer } from "hooks/message"; import "~/index.css"; @@ -22,7 +24,10 @@ export function Layout({ children }: { children: React.ReactNode }) { - + + + +