后端:优化重启,更新数据库配置字段为host,优化SQL查询构建逻辑;前端:改进环境变量读取逻辑,增加express来修改文件

This commit is contained in:
lsy 2024-11-30 02:15:46 +08:00
parent 3daf6280a7
commit 72db2c9de5
32 changed files with 1282 additions and 492 deletions

View File

@ -1,17 +1,10 @@
use crate::common::error::{AppResult, AppResultInto};
use crate::security; use crate::security;
use crate::storage::sql::builder; use crate::storage::sql::builder;
use crate::common::error::{AppResult, AppResultInto};
use crate::AppState; use crate::AppState;
use chrono::Duration; use chrono::Duration;
use rocket::{ use rocket::{http::Status, post, response::status, serde::json::Json, State};
http::Status,
post,
response::status,
serde::json::{Json, Value},
State,
};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json;
use std::sync::Arc; use std::sync::Arc;
#[derive(Deserialize, Serialize)] #[derive(Deserialize, Serialize)]
pub struct TokenSystemData { pub struct TokenSystemData {
@ -23,13 +16,13 @@ pub async fn token_system(
state: &State<Arc<AppState>>, state: &State<Arc<AppState>>,
data: Json<TokenSystemData>, data: Json<TokenSystemData>,
) -> AppResult<String> { ) -> AppResult<String> {
let sql = state let sql = state.sql_get().await.into_app_result()?;
.sql_get() let mut builder = builder::QueryBuilder::new(
.await builder::SqlOperation::Select,
sql.table_name("users"),
sql.get_type(),
)
.into_app_result()?; .into_app_result()?;
let mut builder =
builder::QueryBuilder::new(builder::SqlOperation::Select, sql.table_name("users"), sql.get_type())
.into_app_result()?;
builder builder
.add_field("password_hash".to_string()) .add_field("password_hash".to_string())
.into_app_result()? .into_app_result()?
@ -69,7 +62,7 @@ pub async fn token_system(
), ),
])); ]));
let values = sql let values = sql
.get_db() .get_db()
.execute_query(&builder) .execute_query(&builder)
.await .await
@ -79,14 +72,10 @@ pub async fn token_system(
.first() .first()
.and_then(|row| row.get("password_hash")) .and_then(|row| row.get("password_hash"))
.and_then(|val| val.as_str()) .and_then(|val| val.as_str())
.ok_or_else(|| { .ok_or_else(|| status::Custom(Status::NotFound, "系统用户或密码无效".into()))?;
status::Custom(Status::NotFound, "Invalid system user or password".into())
})?;
println!("{}\n{}",&data.password,password.clone());
security::bcrypt::verify_hash(&data.password, password) security::bcrypt::verify_hash(&data.password, password)
.map_err(|_| status::Custom(Status::Forbidden, "Invalid password".into()))?; .map_err(|_| status::Custom(Status::Forbidden, "密码无效".into()))?;
Ok(security::jwt::generate_jwt( Ok(security::jwt::generate_jwt(
security::jwt::CustomClaims { security::jwt::CustomClaims {

View File

@ -1,12 +1,10 @@
use super::SystemToken; use super::SystemToken;
use crate::storage::{sql, sql::builder};
use crate::common::error::{AppResult, AppResultInto, CustomResult}; use crate::common::error::{AppResult, AppResultInto, CustomResult};
use crate::storage::{sql, sql::builder};
use crate::AppState; use crate::AppState;
use rocket::data;
use rocket::{ use rocket::{
get, get,
http::Status, http::Status,
response::status,
serde::json::{Json, Value}, serde::json::{Json, Value},
State, State,
}; };
@ -51,8 +49,11 @@ pub async fn get_setting(
let where_clause = builder::WhereClause::Condition(name_condition); let where_clause = builder::WhereClause::Condition(name_condition);
let mut sql_builder = let mut sql_builder = builder::QueryBuilder::new(
builder::QueryBuilder::new(builder::SqlOperation::Select, sql.table_name("settings"),sql.get_type())?; builder::SqlOperation::Select,
sql.table_name("settings"),
sql.get_type(),
)?;
sql_builder sql_builder
.add_condition(where_clause) .add_condition(where_clause)
@ -69,8 +70,11 @@ pub async fn insert_setting(
name: String, name: String,
data: Json<Value>, data: Json<Value>,
) -> CustomResult<()> { ) -> CustomResult<()> {
let mut builder = let mut builder = builder::QueryBuilder::new(
builder::QueryBuilder::new(builder::SqlOperation::Insert, sql.table_name("settings"),sql.get_type())?; builder::SqlOperation::Insert,
sql.table_name("settings"),
sql.get_type(),
)?;
builder.set_value( builder.set_value(
"name".to_string(), "name".to_string(),
builder::SafeValue::Text( builder::SafeValue::Text(
@ -80,7 +84,7 @@ pub async fn insert_setting(
)?; )?;
builder.set_value( builder.set_value(
"data".to_string(), "data".to_string(),
builder::SafeValue::Text(data.to_string(),builder::ValidationLevel::Relaxed), builder::SafeValue::Text(data.to_string(), builder::ValidationLevel::Relaxed),
)?; )?;
sql.get_db().execute_query(&builder).await?; sql.get_db().execute_query(&builder).await?;
Ok(()) Ok(())
@ -92,7 +96,7 @@ pub async fn system_config_get(
_token: SystemToken, _token: SystemToken,
) -> AppResult<Json<Value>> { ) -> AppResult<Json<Value>> {
let sql = state.sql_get().await.into_app_result()?; let sql = state.sql_get().await.into_app_result()?;
let settings = get_setting(&sql, "system".to_string(), sql.table_name("settings")) let settings = get_setting(&sql, "system".to_string(), sql.table_name("settings"))
.await .await
.into_app_result()?; .into_app_result()?;
Ok(settings) Ok(settings)

View File

@ -6,7 +6,6 @@ use crate::security;
use crate::storage::sql; use crate::storage::sql;
use crate::AppState; use crate::AppState;
use chrono::Duration; use chrono::Duration;
use rocket::data;
use rocket::{http::Status, post, response::status, serde::json::Json, State}; use rocket::{http::Status, post, response::status, serde::json::Json, State};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
@ -21,7 +20,7 @@ pub async fn setup_sql(
if config.init.sql { if config.init.sql {
return Err(status::Custom( return Err(status::Custom(
Status::BadRequest, Status::BadRequest,
"Database already initialized".to_string(), "数据库已经初始化".to_string(),
)); ));
} }
@ -34,7 +33,7 @@ pub async fn setup_sql(
.into_app_result()?; .into_app_result()?;
config::Config::write(config).into_app_result()?; config::Config::write(config).into_app_result()?;
state.trigger_restart().await.into_app_result()?; state.restart_server().await.into_app_result()?;
Ok("Database installation successful".to_string()) Ok("Database installation successful".to_string())
} }
@ -52,17 +51,16 @@ pub struct InstallReplyData {
password: String, password: String,
} }
#[post("/administrator", format = "application/json", data = "<data>")] #[post("/administrator", format = "application/json", data = "<data>")]
pub async fn setup_account( pub async fn setup_account(
data: Json<StepAccountData>, data: Json<StepAccountData>,
state: &State<Arc<AppState>>, state: &State<Arc<AppState>>,
) -> AppResult<status::Custom<Json<InstallReplyData>>> { ) -> AppResult<status::Custom<Json<InstallReplyData>>> {
let mut config = config::Config::read().unwrap_or_default(); let mut config = config::Config::read().unwrap_or_default();
if config.init.administrator { if config.init.administrator {
return Err(status::Custom( return Err(status::Custom(
Status::BadRequest, Status::BadRequest,
"Administrator user has been set".to_string(), "管理员用户已设置".to_string(),
)); ));
} }
@ -123,9 +121,9 @@ pub async fn setup_account(
Duration::days(7), Duration::days(7),
) )
.into_app_result()?; .into_app_result()?;
config.init.administrator=true; config.init.administrator = true;
config::Config::write(config).into_app_result()?; config::Config::write(config).into_app_result()?;
state.trigger_restart().await.into_app_result()?; state.restart_server().await.into_app_result()?;
Ok(status::Custom( Ok(status::Custom(
Status::Ok, Status::Ok,

View File

@ -1,10 +1,8 @@
use crate::security; use crate::common::error::{CustomErrorInto, CustomResult};
use crate::security::bcrypt; use crate::security::bcrypt;
use crate::storage::{sql, sql::builder}; use crate::storage::{sql, sql::builder};
use crate::common::error::{CustomErrorInto, CustomResult};
use rocket::{get, http::Status, post, response::status, serde::json::Json, State}; use rocket::{get, http::Status, post, response::status, serde::json::Json, State};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Deserialize, Serialize)] #[derive(Deserialize, Serialize)]
pub struct LoginData { pub struct LoginData {
@ -20,16 +18,23 @@ pub struct RegisterData {
pub role: String, pub role: String,
} }
pub async fn insert_user(sql: &sql::Database , data: RegisterData) -> CustomResult<()> { pub async fn insert_user(sql: &sql::Database, data: RegisterData) -> CustomResult<()> {
let role = match data.role.as_str() { let role = match data.role.as_str() {
"administrator" | "contributor" => data.role, "administrator" | "contributor" => data.role,
_ => return Err("Invalid role. Must be either 'administrator' or 'contributor'".into_custom_error()), _ => {
return Err(
"Invalid role. Must be either 'administrator' or 'contributor'".into_custom_error(),
)
}
}; };
let password_hash = bcrypt::generate_hash(&data.password)?; let password_hash = bcrypt::generate_hash(&data.password)?;
let mut builder = let mut builder = builder::QueryBuilder::new(
builder::QueryBuilder::new(builder::SqlOperation::Insert, sql.table_name("users"), sql.get_type())?; builder::SqlOperation::Insert,
sql.table_name("users"),
sql.get_type(),
)?;
builder builder
.set_value( .set_value(
"username".to_string(), "username".to_string(),

View File

@ -42,19 +42,19 @@ impl Default for Init {
#[derive(Deserialize, Serialize, Debug, Clone)] #[derive(Deserialize, Serialize, Debug, Clone)]
pub struct SqlConfig { pub struct SqlConfig {
pub db_type: String, pub db_type: String,
pub address: String, pub host: String,
pub port: u32, pub port: u32,
pub user: String, pub user: String,
pub password: String, pub password: String,
pub db_name: String, pub db_name: String,
pub db_prefix:String, pub db_prefix: String,
} }
impl Default for SqlConfig { impl Default for SqlConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
db_type: "sqllite".to_string(), db_type: "sqllite".to_string(),
address: "".to_string(), host: "".to_string(),
port: 0, port: 0,
user: "".to_string(), user: "".to_string(),
password: "".to_string(), password: "".to_string(),
@ -67,7 +67,7 @@ impl Default for SqlConfig {
#[derive(Deserialize, Serialize, Debug, Clone)] #[derive(Deserialize, Serialize, Debug, Clone)]
pub struct NoSqlConfig { pub struct NoSqlConfig {
pub db_type: String, pub db_type: String,
pub address: String, pub host: String,
pub port: u32, pub port: u32,
pub user: String, pub user: String,
pub password: String, pub password: String,
@ -78,7 +78,7 @@ impl Default for NoSqlConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
db_type: "postgresql".to_string(), db_type: "postgresql".to_string(),
address: "localhost".to_string(), host: "localhost".to_string(),
port: 5432, port: 5432,
user: "postgres".to_string(), user: "postgres".to_string(),
password: "postgres".to_string(), password: "postgres".to_string(),

View File

@ -1,3 +1,3 @@
pub mod config;
pub mod error; pub mod error;
pub mod helpers; pub mod helpers;
pub mod config;

View File

@ -1,7 +1,7 @@
mod security;
mod common;
mod storage;
mod api; mod api;
mod common;
mod security;
mod storage;
use crate::common::config; use crate::common::config;
use common::error::{CustomErrorInto, CustomResult}; use common::error::{CustomErrorInto, CustomResult};
@ -13,6 +13,7 @@ pub struct AppState {
db: Arc<Mutex<Option<sql::Database>>>, db: Arc<Mutex<Option<sql::Database>>>,
shutdown: Arc<Mutex<Option<Shutdown>>>, shutdown: Arc<Mutex<Option<Shutdown>>>,
restart_progress: Arc<Mutex<bool>>, restart_progress: Arc<Mutex<bool>>,
restart_attempts: Arc<Mutex<u32>>,
} }
impl AppState { impl AppState {
@ -21,11 +22,16 @@ impl AppState {
db: Arc::new(Mutex::new(None)), db: Arc::new(Mutex::new(None)),
shutdown: Arc::new(Mutex::new(None)), shutdown: Arc::new(Mutex::new(None)),
restart_progress: Arc::new(Mutex::new(false)), restart_progress: Arc::new(Mutex::new(false)),
restart_attempts: Arc::new(Mutex::new(0)),
} }
} }
pub async fn sql_get(&self) -> CustomResult<sql::Database> { pub async fn sql_get(&self) -> CustomResult<sql::Database> {
self.db.lock().await.clone().ok_or_else(|| "数据库未连接".into_custom_error()) self.db
.lock()
.await
.clone()
.ok_or_else(|| "数据库未连接".into_custom_error())
} }
pub async fn sql_link(&self, config: &config::SqlConfig) -> CustomResult<()> { pub async fn sql_link(&self, config: &config::SqlConfig) -> CustomResult<()> {
@ -33,14 +39,40 @@ impl AppState {
Ok(()) Ok(())
} }
pub async fn set_shutdown(&self, shutdown: Shutdown) { pub async fn set_shutdown(&self, shutdown: Shutdown) {
*self.shutdown.lock().await = Some(shutdown); *self.shutdown.lock().await = Some(shutdown);
} }
pub async fn trigger_restart(&self) -> CustomResult<()> { pub async fn trigger_restart(&self) -> CustomResult<()> {
*self.restart_progress.lock().await = true; *self.restart_progress.lock().await = true;
self.shutdown.lock().await.take().ok_or_else(|| "未能获取rocket的shutdown".into_custom_error())?.notify(); self.shutdown
.lock()
.await
.take()
.ok_or_else(|| "未能获取rocket的shutdown".into_custom_error())?
.notify();
Ok(())
}
pub async fn restart_server(&self) -> CustomResult<()> {
const MAX_RESTART_ATTEMPTS: u32 = 3;
const RESTART_DELAY_MS: u64 = 1000;
let mut attempts = self.restart_attempts.lock().await;
if *attempts >= MAX_RESTART_ATTEMPTS {
return Err("达到最大重启尝试次数".into_custom_error());
}
*attempts += 1;
*self.restart_progress.lock().await = true;
self.shutdown
.lock()
.await
.take()
.ok_or_else(|| "未能获取rocket的shutdown".into_custom_error())?
.notify();
Ok(()) Ok(())
} }
} }
@ -53,9 +85,13 @@ async fn main() -> CustomResult<()> {
}); });
let state = Arc::new(AppState::new()); let state = Arc::new(AppState::new());
let rocket_config = rocket::Config::figment().merge(("address", config.address)).merge(("port", config.port)); let rocket_config = rocket::Config::figment()
.merge(("address", config.address))
.merge(("port", config.port));
let mut rocket_builder = rocket::build().configure(rocket_config).manage(state.clone()); let mut rocket_builder = rocket::build()
.configure(rocket_config)
.manage(state.clone());
if !config.init.sql { if !config.init.sql {
rocket_builder = rocket_builder.mount("/", rocket::routes![api::setup::setup_sql]); rocket_builder = rocket_builder.mount("/", rocket::routes![api::setup::setup_sql]);
@ -63,24 +99,47 @@ async fn main() -> CustomResult<()> {
rocket_builder = rocket_builder.mount("/", rocket::routes![api::setup::setup_account]); rocket_builder = rocket_builder.mount("/", rocket::routes![api::setup::setup_account]);
} else { } else {
state.sql_link(&config.sql_config).await?; state.sql_link(&config.sql_config).await?;
rocket_builder = rocket_builder.mount("/auth/token", api::jwt_routes()).mount("/config", api::configure_routes()); rocket_builder = rocket_builder
.mount("/auth/token", api::jwt_routes())
.mount("/config", api::configure_routes());
} }
let rocket = rocket_builder.ignite().await?; let rocket = rocket_builder.ignite().await?;
rocket.state::<Arc<AppState>>().ok_or_else(|| "未能获取AppState".into_custom_error())?.set_shutdown(rocket.shutdown()).await; rocket
.state::<Arc<AppState>>()
.ok_or_else(|| "无法获取AppState".into_custom_error())?
.set_shutdown(rocket.shutdown())
.await;
rocket.launch().await?; rocket.launch().await?;
if *state.restart_progress.lock().await { if *state.restart_progress.lock().await {
tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
if let Ok(current_exe) = std::env::current_exe() { if let Ok(current_exe) = std::env::current_exe() {
match std::process::Command::new(current_exe).spawn() { println!("正在尝试重启服务器...");
Ok(_) => println!("成功启动新进程"),
Err(e) => eprintln!("启动新进程失败: {}", e), let mut command = std::process::Command::new(current_exe);
command.env("RUST_BACKTRACE", "1");
match command.spawn() {
Ok(child) => {
println!("成功启动新进程 (PID: {})", child.id());
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
}
Err(e) => {
eprintln!("启动新进程失败: {}", e);
*state.restart_progress.lock().await = false;
return Err(format!("重启失败: {}", e).into_custom_error());
}
}; };
} else { } else {
eprintln!("获取当前可执行文件路径失败"); eprintln!("获取当前可执行文件路径失败");
return Err("重启失败: 无法获取可执行文件路径".into_custom_error());
} }
} }
println!("服务器正常退出");
std::process::exit(0); std::process::exit(0);
} }

View File

@ -1,10 +1,10 @@
use super::DatabaseType;
use crate::common::error::{CustomErrorInto, CustomResult}; use crate::common::error::{CustomErrorInto, CustomResult};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use regex::Regex; use regex::Regex;
use serde::Serialize; use serde::Serialize;
use std::collections::HashMap; use std::collections::HashMap;
use std::hash::Hash; use std::hash::Hash;
use crate::sql::schema::DatabaseType;
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, Serialize)]
pub enum ValidationLevel { pub enum ValidationLevel {
@ -82,10 +82,10 @@ impl TextValidator {
let max_length = self let max_length = self
.level_max_lengths .level_max_lengths
.get(&level) .get(&level)
.ok_or("Invalid validation level".into_custom_error())?; .ok_or("无效的验证级别".into_custom_error())?;
if text.len() > *max_length { if text.len() > *max_length {
return Err("Text exceeds maximum length".into_custom_error()); return Err("文本超出最大长度限制".into_custom_error());
} }
if level == ValidationLevel::Relaxed { if level == ValidationLevel::Relaxed {
@ -103,7 +103,7 @@ impl TextValidator {
.iter() .iter()
.any(|&pattern| upper_text.contains(&pattern.to_uppercase())) .any(|&pattern| upper_text.contains(&pattern.to_uppercase()))
{ {
return Err("Potentially dangerous SQL pattern detected".into_custom_error()); return Err("检测到潜在危险的SQL模式".into_custom_error());
} }
Ok(()) Ok(())
} }
@ -112,14 +112,14 @@ impl TextValidator {
let allowed_chars = self let allowed_chars = self
.level_allowed_chars .level_allowed_chars
.get(&level) .get(&level)
.ok_or_else(|| "Invalid validation level".into_custom_error())?; .ok_or_else(|| "无效的验证级别".into_custom_error())?;
if let Some(invalid_char) = text if let Some(invalid_char) = text
.chars() .chars()
.find(|&c| !c.is_alphanumeric() && !allowed_chars.contains(&c)) .find(|&c| !c.is_alphanumeric() && !allowed_chars.contains(&c))
{ {
return Err( return Err(
format!("Invalid character '{}' for {:?} level", invalid_char, level) format!("'{}'字符在{:?}验证级别中是无效的", invalid_char, level)
.into_custom_error(), .into_custom_error(),
); );
} }
@ -128,7 +128,7 @@ impl TextValidator {
fn validate_special_chars(&self, text: &str) -> CustomResult<()> { fn validate_special_chars(&self, text: &str) -> CustomResult<()> {
if self.special_chars.iter().any(|&c| text.contains(c)) { if self.special_chars.iter().any(|&c| text.contains(c)) {
return Err("Invalid special character detected".into_custom_error()); return Err("检测到无效的特殊字符".into_custom_error());
} }
Ok(()) Ok(())
} }
@ -179,7 +179,6 @@ impl std::fmt::Display for SafeValue {
} }
impl SafeValue { impl SafeValue {
fn get_sql_type(&self) -> CustomResult<String> { fn get_sql_type(&self) -> CustomResult<String> {
let sql_type = match self { let sql_type = match self {
SafeValue::Null => "NULL", SafeValue::Null => "NULL",
@ -192,7 +191,7 @@ impl SafeValue {
Ok(sql_type.to_string()) Ok(sql_type.to_string())
} }
pub fn to_sql_string(&self) -> CustomResult<String> { pub fn to_string(&self) -> CustomResult<String> {
match self { match self {
SafeValue::Null => Ok("NULL".to_string()), SafeValue::Null => Ok("NULL".to_string()),
SafeValue::Bool(b) => Ok(if *b { "true" } else { "false" }.to_string()), SafeValue::Bool(b) => Ok(if *b { "true" } else { "false" }.to_string()),
@ -225,9 +224,9 @@ pub struct Identifier(String);
impl Identifier { impl Identifier {
pub fn new(value: String) -> CustomResult<Self> { pub fn new(value: String) -> CustomResult<Self> {
let valid_pattern = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_\.]{0,63}$")?; let valid_pattern = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_.]{0,63}$")?;
if !valid_pattern.is_match(&value) { if !valid_pattern.is_match(&value) {
return Err("Invalid identifier format".into_custom_error()); return Err("标识符格式无效".into_custom_error());
} }
Ok(Identifier(value)) Ok(Identifier(value))
} }
@ -314,7 +313,11 @@ pub struct QueryBuilder {
} }
impl QueryBuilder { impl QueryBuilder {
pub fn new(operation: SqlOperation, table: String, db_type: DatabaseType) -> CustomResult<Self> { pub fn new(
operation: SqlOperation,
table: String,
db_type: DatabaseType,
) -> CustomResult<Self> {
Ok(QueryBuilder { Ok(QueryBuilder {
operation, operation,
table: Identifier::new(table)?, table: Identifier::new(table)?,

View File

@ -1,14 +1,30 @@
mod postgresql;
mod mysql;
mod sqllite;
pub mod builder; pub mod builder;
mod mysql;
mod postgresql;
mod schema; mod schema;
mod sqllite;
use crate::config;
use crate::common::error::{CustomErrorInto, CustomResult}; use crate::common::error::{CustomErrorInto, CustomResult};
use crate::config;
use async_trait::async_trait; use async_trait::async_trait;
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use schema::DatabaseType;
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DatabaseType {
PostgreSQL,
MySQL,
SQLite,
}
impl std::fmt::Display for DatabaseType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DatabaseType::PostgreSQL => write!(f, "postgresql"),
DatabaseType::MySQL => write!(f, "mysql"),
DatabaseType::SQLite => write!(f, "sqlite"),
}
}
}
#[async_trait] #[async_trait]
pub trait DatabaseTrait: Send + Sync { pub trait DatabaseTrait: Send + Sync {
@ -28,7 +44,7 @@ pub trait DatabaseTrait: Send + Sync {
pub struct Database { pub struct Database {
pub db: Arc<Box<dyn DatabaseTrait>>, pub db: Arc<Box<dyn DatabaseTrait>>,
pub prefix: Arc<String>, pub prefix: Arc<String>,
pub db_type: Arc<String> pub db_type: Arc<DatabaseType>,
} }
impl Database { impl Database {
@ -40,20 +56,16 @@ impl Database {
&self.prefix &self.prefix
} }
pub fn get_type(&self) -> DatabaseType {
match self.db_type.as_str() {
"postgresql" => DatabaseType::PostgreSQL,
"mysql" => DatabaseType::MySQL,
_ => DatabaseType::SQLite,
}
}
pub fn table_name(&self, name: &str) -> String { pub fn table_name(&self, name: &str) -> String {
format!("{}{}", self.prefix, name) format!("{}{}", self.prefix, name)
} }
pub fn get_type(&self) -> DatabaseType {
*self.db_type.clone()
}
pub async fn link(database: &config::SqlConfig) -> CustomResult<Self> { pub async fn link(database: &config::SqlConfig) -> CustomResult<Self> {
let db: Box<dyn DatabaseTrait> = match database.db_type.as_str() { let db: Box<dyn DatabaseTrait> = match database.db_type.to_lowercase().as_str() {
"postgresql" => Box::new(postgresql::Postgresql::connect(database).await?), "postgresql" => Box::new(postgresql::Postgresql::connect(database).await?),
"mysql" => Box::new(mysql::Mysql::connect(database).await?), "mysql" => Box::new(mysql::Mysql::connect(database).await?),
"sqllite" => Box::new(sqllite::Sqlite::connect(database).await?), "sqllite" => Box::new(sqllite::Sqlite::connect(database).await?),
@ -63,12 +75,17 @@ impl Database {
Ok(Self { Ok(Self {
db: Arc::new(db), db: Arc::new(db),
prefix: Arc::new(database.db_prefix.clone()), prefix: Arc::new(database.db_prefix.clone()),
db_type: Arc::new(database.db_type.clone()) db_type: Arc::new(match database.db_type.to_lowercase().as_str() {
"postgresql" => DatabaseType::PostgreSQL,
"mysql" => DatabaseType::MySQL,
"sqllite" => DatabaseType::SQLite,
_ => return Err("unknown database type".into_custom_error()),
}),
}) })
} }
pub async fn initial_setup(database: config::SqlConfig) -> CustomResult<()> { pub async fn initial_setup(database: config::SqlConfig) -> CustomResult<()> {
match database.db_type.as_str() { match database.db_type.to_lowercase().as_str() {
"postgresql" => postgresql::Postgresql::initialization(database).await?, "postgresql" => postgresql::Postgresql::initialization(database).await?,
"mysql" => mysql::Mysql::initialization(database).await?, "mysql" => mysql::Mysql::initialization(database).await?,
"sqllite" => sqllite::Sqlite::initialization(database).await?, "sqllite" => sqllite::Sqlite::initialization(database).await?,
@ -76,5 +93,4 @@ impl Database {
}; };
Ok(()) Ok(())
} }
} }

View File

@ -1,12 +1,14 @@
use super::{builder::{self, SafeValue}, schema, DatabaseTrait}; use super::{
use crate::config; builder::{self, SafeValue},
schema, DatabaseTrait,
};
use crate::common::error::CustomResult; use crate::common::error::CustomResult;
use crate::config;
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::Value; use serde_json::Value;
use sqlx::mysql::MySqlPool; use sqlx::mysql::MySqlPool;
use sqlx::{Column, Executor, Row, TypeInfo}; use sqlx::{Column, Executor, Row, TypeInfo};
use std::collections::HashMap; use std::collections::HashMap;
use chrono::{DateTime, Utc};
#[derive(Clone)] #[derive(Clone)]
pub struct Mysql { pub struct Mysql {
@ -15,51 +17,16 @@ pub struct Mysql {
#[async_trait] #[async_trait]
impl DatabaseTrait for Mysql { impl DatabaseTrait for Mysql {
async fn initialization(db_config: config::SqlConfig) -> CustomResult<()> {
let db_prefix = SafeValue::Text(format!("{}",db_config.db_prefix), builder::ValidationLevel::Strict);
let grammar = schema::generate_schema(schema::DatabaseType::MySQL,db_prefix)?;
let connection_str = format!(
"mysql://{}:{}@{}:{}",
db_config.user, db_config.password, db_config.address, db_config.port
);
let pool = MySqlPool::connect(&connection_str).await?;
pool.execute(format!("CREATE DATABASE `{}`", db_config.db_name).as_str()).await?;
pool.execute(format!(
"ALTER DATABASE `{}` DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci",
db_config.db_name
).as_str()).await?;
let new_connection_str = format!(
"mysql://{}:{}@{}:{}/{}",
db_config.user,
db_config.password,
db_config.address,
db_config.port,
db_config.db_name
);
let new_pool = MySqlPool::connect(&new_connection_str).await?;
new_pool.execute(grammar.as_str()).await?;
Ok(())
}
async fn connect(db_config: &config::SqlConfig) -> CustomResult<Self> { async fn connect(db_config: &config::SqlConfig) -> CustomResult<Self> {
let connection_str = format!( let connection_str = format!(
"mysql://{}:{}@{}:{}/{}", "mysql://{}:{}@{}:{}/{}",
db_config.user, db_config.user, db_config.password, db_config.host, db_config.port, db_config.db_name
db_config.password,
db_config.address,
db_config.port,
db_config.db_name
); );
let pool = MySqlPool::connect(&connection_str).await?; let pool = MySqlPool::connect(&connection_str).await?;
Ok(Mysql { pool }) Ok(Mysql { pool })
} }
async fn execute_query<'a>( async fn execute_query<'a>(
&'a self, &'a self,
builder: &builder::QueryBuilder, builder: &builder::QueryBuilder,
@ -107,4 +74,38 @@ impl DatabaseTrait for Mysql {
.collect()) .collect())
} }
async fn initialization(db_config: config::SqlConfig) -> CustomResult<()> {
let db_prefix = SafeValue::Text(
format!("{}", db_config.db_prefix),
builder::ValidationLevel::Strict,
);
let grammar = schema::generate_schema(super::DatabaseType::MySQL, db_prefix)?;
let connection_str = format!(
"mysql://{}:{}@{}:{}",
db_config.user, db_config.password, db_config.host, db_config.port
);
let pool = MySqlPool::connect(&connection_str).await?;
pool.execute(format!("CREATE DATABASE `{}`", db_config.db_name).as_str())
.await?;
pool.execute(
format!(
"ALTER DATABASE `{}` DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci",
db_config.db_name
)
.as_str(),
)
.await?;
let new_connection_str = format!(
"mysql://{}:{}@{}:{}/{}",
db_config.user, db_config.password, db_config.host, db_config.port, db_config.db_name
);
let new_pool = MySqlPool::connect(&new_connection_str).await?;
new_pool.execute(grammar.as_str()).await?;
Ok(())
}
} }

View File

@ -2,10 +2,9 @@ use super::{
builder::{self, SafeValue}, builder::{self, SafeValue},
schema, DatabaseTrait, schema, DatabaseTrait,
}; };
use crate::common::error::{CustomError, CustomErrorInto, CustomResult}; use crate::common::error::CustomResult;
use crate::config; use crate::config;
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde_json::Value; use serde_json::Value;
use sqlx::{Column, Executor, PgPool, Row, TypeInfo}; use sqlx::{Column, Executor, PgPool, Row, TypeInfo};
use std::collections::HashMap; use std::collections::HashMap;
@ -17,45 +16,10 @@ pub struct Postgresql {
#[async_trait] #[async_trait]
impl DatabaseTrait for Postgresql { impl DatabaseTrait for Postgresql {
async fn initialization(db_config: config::SqlConfig) -> CustomResult<()> {
let db_prefix = SafeValue::Text(
format!("{}", db_config.db_prefix),
builder::ValidationLevel::Strict,
);
let grammar = schema::generate_schema(schema::DatabaseType::PostgreSQL, db_prefix)?;
let connection_str = format!(
"postgres://{}:{}@{}:{}",
db_config.user, db_config.password, db_config.address, db_config.port
);
let pool = PgPool::connect(&connection_str).await?;
pool.execute(format!("CREATE DATABASE {}", db_config.db_name).as_str())
.await?;
let new_connection_str = format!(
"postgres://{}:{}@{}:{}/{}",
db_config.user,
db_config.password,
db_config.address,
db_config.port,
db_config.db_name
);
let new_pool = PgPool::connect(&new_connection_str).await?;
new_pool.execute(grammar.as_str()).await?;
Ok(())
}
async fn connect(db_config: &config::SqlConfig) -> CustomResult<Self> { async fn connect(db_config: &config::SqlConfig) -> CustomResult<Self> {
let connection_str = format!( let connection_str = format!(
"postgres://{}:{}@{}:{}/{}", "postgres://{}:{}@{}:{}/{}",
db_config.user, db_config.user, db_config.password, db_config.host, db_config.port, db_config.db_name
db_config.password,
db_config.address,
db_config.port,
db_config.db_name
); );
let pool = PgPool::connect(&connection_str).await?; let pool = PgPool::connect(&connection_str).await?;
@ -109,10 +73,31 @@ impl DatabaseTrait for Postgresql {
}) })
.collect()) .collect())
} }
}
impl Postgresql { async fn initialization(db_config: config::SqlConfig) -> CustomResult<()> {
fn get_sdb(&self){ let db_prefix = SafeValue::Text(
let a=self.pool; format!("{}", db_config.db_prefix),
builder::ValidationLevel::Strict,
);
let grammar = schema::generate_schema(super::DatabaseType::PostgreSQL, db_prefix)?;
let connection_str = format!(
"postgres://{}:{}@{}:{}",
db_config.user, db_config.password, db_config.host, db_config.port
);
let pool = PgPool::connect(&connection_str).await?;
pool.execute(format!("CREATE DATABASE {}", db_config.db_name).as_str())
.await?;
let new_connection_str = format!(
"postgres://{}:{}@{}:{}/{}",
db_config.user, db_config.password, db_config.host, db_config.port, db_config.db_name
);
let new_pool = PgPool::connect(&new_connection_str).await?;
new_pool.execute(grammar.as_str()).await?;
Ok(())
} }
} }

View File

@ -1,13 +1,7 @@
use super::builder::{Condition, Identifier, Operator, SafeValue, ValidationLevel, WhereClause}; use super::builder::{Condition, Identifier, Operator, SafeValue, ValidationLevel, WhereClause};
use super::DatabaseType;
use crate::common::error::{CustomErrorInto, CustomResult}; use crate::common::error::{CustomErrorInto, CustomResult};
use std::{collections::HashMap, fmt::format}; use std::fmt::Display;
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DatabaseType {
PostgreSQL,
MySQL,
SQLite,
}
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub enum FieldType { pub enum FieldType {
@ -46,16 +40,17 @@ pub struct ForeignKey {
pub on_update: Option<ForeignKeyAction>, pub on_update: Option<ForeignKeyAction>,
} }
impl ToString for ForeignKeyAction { impl Display for ForeignKeyAction {
fn to_string(&self) -> String { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { let str = match self {
ForeignKeyAction::Cascade => "CASCADE", ForeignKeyAction::Cascade => "CASCADE",
ForeignKeyAction::Restrict => "RESTRICT", ForeignKeyAction::Restrict => "RESTRICT",
ForeignKeyAction::SetNull => "SET NULL", ForeignKeyAction::SetNull => "SET NULL",
ForeignKeyAction::NoAction => "NO ACTION", ForeignKeyAction::NoAction => "NO ACTION",
ForeignKeyAction::SetDefault => "SET DEFAULT", ForeignKeyAction::SetDefault => "SET DEFAULT",
} }
.to_string() .to_string();
write!(f, "{}", str)
} }
} }
@ -216,7 +211,7 @@ impl Field {
"{} {} {}", "{} {} {}",
field_name, field_name,
condition.operator.as_str(), condition.operator.as_str(),
value.to_sql_string()? value.to_string()?
)) ))
} else { } else {
Err("Missing value for comparison".into_custom_error()) Err("Missing value for comparison".into_custom_error())
@ -256,7 +251,7 @@ impl Field {
} }
} }
if let Some(default) = &self.constraints.default_value { if let Some(default) = &self.constraints.default_value {
sql.push_str(&format!(" DEFAULT {}", default.to_sql_string()?)); sql.push_str(&format!(" DEFAULT {}", default.to_string()?));
} }
if let Some(check) = &self.constraints.check_constraint { if let Some(check) = &self.constraints.check_constraint {
let check_sql = Self::build_check_constraint(check)?; let check_sql = Self::build_check_constraint(check)?;
@ -332,7 +327,7 @@ impl Index {
}) })
} }
fn to_sql(&self, table_name: &str, db_type: DatabaseType) -> CustomResult<String> { fn to_sql(&self, table_name: &str, _db_type: DatabaseType) -> CustomResult<String> {
let unique = if self.is_unique { "UNIQUE " } else { "" }; let unique = if self.is_unique { "UNIQUE " } else { "" };
Ok(format!( Ok(format!(
"CREATE {}INDEX {} ON {} ({});", "CREATE {}INDEX {} ON {} ({});",
@ -356,9 +351,7 @@ pub struct SchemaBuilder {
impl SchemaBuilder { impl SchemaBuilder {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self { tables: Vec::new() }
tables: Vec::new(),
}
} }
pub fn add_table(&mut self, table: Table) -> CustomResult<&mut Self> { pub fn add_table(&mut self, table: Table) -> CustomResult<&mut Self> {
@ -376,14 +369,14 @@ impl SchemaBuilder {
} }
} }
pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResult<String> { pub fn generate_schema(db_type: DatabaseType, db_prefix: SafeValue) -> CustomResult<String> {
let db_prefix=db_prefix.to_sql_string()?; let db_prefix = db_prefix.to_string()?;
let mut schema = SchemaBuilder::new(); let mut schema = SchemaBuilder::new();
let user_level = "('contributor', 'administrator')"; let user_level = "('contributor', 'administrator')";
let content_state = "('draft', 'published', 'private', 'hidden')"; let content_state = "('draft', 'published', 'private', 'hidden')";
// 用户表 // 用户表
let mut users_table = Table::new(&format!("{}users",db_prefix))?; let mut users_table = Table::new(&format!("{}users", db_prefix))?;
users_table users_table
.add_field(Field::new( .add_field(Field::new(
"username", "username",
@ -423,7 +416,10 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul
.check(WhereClause::Condition(Condition::new( .check(WhereClause::Condition(Condition::new(
"role".to_string(), "role".to_string(),
Operator::In, Operator::In,
Some(SafeValue::Text(user_level.to_string(), ValidationLevel::Relaxed)), Some(SafeValue::Text(
user_level.to_string(),
ValidationLevel::Relaxed,
)),
)?)), )?)),
ValidationLevel::Strict, ValidationLevel::Strict,
)?) )?)
@ -459,7 +455,7 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul
// 独立页面表 // 独立页面表
let mut pages_table = Table::new(&format!("{}pages",db_prefix))?; let mut pages_table = Table::new(&format!("{}pages", db_prefix))?;
pages_table pages_table
.add_field(Field::new( .add_field(Field::new(
"id", "id",
@ -511,16 +507,18 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul
.check(WhereClause::Condition(Condition::new( .check(WhereClause::Condition(Condition::new(
"status".to_string(), "status".to_string(),
Operator::In, Operator::In,
Some(SafeValue::Text(content_state.to_string(), ValidationLevel::Standard)), Some(SafeValue::Text(
content_state.to_string(),
ValidationLevel::Standard,
)),
)?)), )?)),
ValidationLevel::Strict, ValidationLevel::Strict,
)?); )?);
schema.add_table(pages_table)?; schema.add_table(pages_table)?;
// posts 表 // posts 表
let mut posts_table = Table::new(&format!("{}posts",db_prefix))?; let mut posts_table = Table::new(&format!("{}posts", db_prefix))?;
posts_table posts_table
.add_field(Field::new( .add_field(Field::new(
"id", "id",
@ -533,7 +531,7 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul
FieldType::VarChar(100), FieldType::VarChar(100),
FieldConstraint::new() FieldConstraint::new()
.not_null() .not_null()
.foreign_key(format!("{}users",db_prefix), "username".to_string()) .foreign_key(format!("{}users", db_prefix), "username".to_string())
.on_delete(ForeignKeyAction::Cascade) .on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade), .on_update(ForeignKeyAction::Cascade),
ValidationLevel::Strict, ValidationLevel::Strict,
@ -576,7 +574,10 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul
.check(WhereClause::Condition(Condition::new( .check(WhereClause::Condition(Condition::new(
"status".to_string(), "status".to_string(),
Operator::In, Operator::In,
Some(SafeValue::Text(content_state.to_string(), ValidationLevel::Standard)), Some(SafeValue::Text(
content_state.to_string(),
ValidationLevel::Standard,
)),
)?)), )?)),
ValidationLevel::Strict, ValidationLevel::Strict,
)?) )?)
@ -622,7 +623,7 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul
schema.add_table(posts_table)?; schema.add_table(posts_table)?;
// 标签表 // 标签表
let mut tags_tables = Table::new(&format!("{}tags",db_prefix))?; let mut tags_tables = Table::new(&format!("{}tags", db_prefix))?;
tags_tables tags_tables
.add_field(Field::new( .add_field(Field::new(
"name", "name",
@ -639,25 +640,25 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul
schema.add_table(tags_tables)?; schema.add_table(tags_tables)?;
// 文章标签 // 文章标签
let mut post_tags_tables = Table::new(&format!("{}post_tags",db_prefix))?; let mut post_tags_tables = Table::new(&format!("{}post_tags", db_prefix))?;
post_tags_tables post_tags_tables
.add_field(Field::new( .add_field(Field::new(
"post_id", "post_id",
FieldType::Integer(false), FieldType::Integer(false),
FieldConstraint::new() FieldConstraint::new()
.not_null() .not_null()
.foreign_key(format!("{}posts",db_prefix), "id".to_string()) .foreign_key(format!("{}posts", db_prefix), "id".to_string())
.on_delete(ForeignKeyAction::Cascade) .on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade), .on_update(ForeignKeyAction::Cascade),
ValidationLevel::Strict, ValidationLevel::Strict,
)?).add_field(Field::new( )?)
.add_field(Field::new(
"tag_id", "tag_id",
FieldType::VarChar(50), FieldType::VarChar(50),
FieldConstraint::new() FieldConstraint::new()
.not_null() .not_null()
.foreign_key(format!("{}tags",db_prefix), "name".to_string()) .foreign_key(format!("{}tags", db_prefix), "name".to_string())
.on_delete(ForeignKeyAction::Cascade) .on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade), .on_update(ForeignKeyAction::Cascade),
ValidationLevel::Strict, ValidationLevel::Strict,
@ -673,7 +674,7 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul
// 分类表 // 分类表
let mut categories_table = Table::new(&format!("{}categories",db_prefix))?; let mut categories_table = Table::new(&format!("{}categories", db_prefix))?;
categories_table categories_table
.add_field(Field::new( .add_field(Field::new(
"name", "name",
@ -685,21 +686,21 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul
"parent_id", "parent_id",
FieldType::VarChar(50), FieldType::VarChar(50),
FieldConstraint::new() FieldConstraint::new()
.foreign_key(format!("{}categories",db_prefix), "name".to_string()), .foreign_key(format!("{}categories", db_prefix), "name".to_string()),
ValidationLevel::Strict, ValidationLevel::Strict,
)?); )?);
schema.add_table(categories_table)?; schema.add_table(categories_table)?;
// 文章分类关联表 // 文章分类关联表
let mut post_categories_table = Table::new(&format!("{}post_categories",db_prefix))?; let mut post_categories_table = Table::new(&format!("{}post_categories", db_prefix))?;
post_categories_table post_categories_table
.add_field(Field::new( .add_field(Field::new(
"post_id", "post_id",
FieldType::Integer(false), FieldType::Integer(false),
FieldConstraint::new() FieldConstraint::new()
.not_null() .not_null()
.foreign_key(format!("{}posts",db_prefix), "id".to_string()) .foreign_key(format!("{}posts", db_prefix), "id".to_string())
.on_delete(ForeignKeyAction::Cascade) .on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade), .on_update(ForeignKeyAction::Cascade),
ValidationLevel::Strict, ValidationLevel::Strict,
@ -709,7 +710,7 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul
FieldType::VarChar(50), FieldType::VarChar(50),
FieldConstraint::new() FieldConstraint::new()
.not_null() .not_null()
.foreign_key(format!("{}categories",db_prefix), "name".to_string()) .foreign_key(format!("{}categories", db_prefix), "name".to_string())
.on_delete(ForeignKeyAction::Cascade) .on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade), .on_update(ForeignKeyAction::Cascade),
ValidationLevel::Strict, ValidationLevel::Strict,
@ -724,7 +725,7 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul
schema.add_table(post_categories_table)?; schema.add_table(post_categories_table)?;
// 资源库表 // 资源库表
let mut resources_table = Table::new(&format!("{}resources",db_prefix))?; let mut resources_table = Table::new(&format!("{}resources", db_prefix))?;
resources_table resources_table
.add_field(Field::new( .add_field(Field::new(
"id", "id",
@ -737,7 +738,7 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul
FieldType::VarChar(100), FieldType::VarChar(100),
FieldConstraint::new() FieldConstraint::new()
.not_null() .not_null()
.foreign_key(format!("{}users",db_prefix), "username".to_string()) .foreign_key(format!("{}users", db_prefix), "username".to_string())
.on_delete(ForeignKeyAction::Cascade) .on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade), .on_update(ForeignKeyAction::Cascade),
ValidationLevel::Strict, ValidationLevel::Strict,
@ -791,7 +792,7 @@ pub fn generate_schema(db_type: DatabaseType,db_prefix:SafeValue) -> CustomResul
schema.add_table(resources_table)?; schema.add_table(resources_table)?;
// 配置表 // 配置表
let mut settings_table = Table::new(&format!("{}settings",db_prefix))?; let mut settings_table = Table::new(&format!("{}settings", db_prefix))?;
settings_table settings_table
.add_field(Field::new( .add_field(Field::new(
"name", "name",

View File

@ -2,12 +2,11 @@ use super::{
builder::{self, SafeValue}, builder::{self, SafeValue},
schema, DatabaseTrait, schema, DatabaseTrait,
}; };
use crate::common::error::{CustomError, CustomErrorInto, CustomResult}; use crate::common::error::{CustomErrorInto, CustomResult};
use crate::config; use crate::config;
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde_json::Value; use serde_json::Value;
use sqlx::{Column, Executor, SqlitePool, Row, TypeInfo}; use sqlx::{Column, Executor, Row, SqlitePool, TypeInfo};
use std::collections::HashMap; use std::collections::HashMap;
use std::env; use std::env;
@ -18,29 +17,6 @@ pub struct Sqlite {
#[async_trait] #[async_trait]
impl DatabaseTrait for Sqlite { impl DatabaseTrait for Sqlite {
async fn initialization(db_config: config::SqlConfig) -> CustomResult<()> {
let db_prefix = SafeValue::Text(
format!("{}", db_config.db_prefix),
builder::ValidationLevel::Strict,
);
let sqlite_dir = env::current_dir()?.join("assets").join("sqllite");
std::fs::create_dir_all(&sqlite_dir)?;
let db_file = sqlite_dir.join(&db_config.db_name);
std::fs::File::create(&db_file)?;
let path = db_file.to_str().ok_or("Unable to get sqllite path".into_custom_error())?;
let grammar = schema::generate_schema(schema::DatabaseType::SQLite, db_prefix)?;
let connection_str = format!("sqlite:///{}", path);
let pool = SqlitePool::connect(&connection_str).await?;
pool.execute(grammar.as_str()).await?;
Ok(())
}
async fn connect(db_config: &config::SqlConfig) -> CustomResult<Self> { async fn connect(db_config: &config::SqlConfig) -> CustomResult<Self> {
let db_file = env::current_dir()? let db_file = env::current_dir()?
.join("assets") .join("assets")
@ -48,10 +24,12 @@ impl DatabaseTrait for Sqlite {
.join(&db_config.db_name); .join(&db_config.db_name);
if !db_file.exists() { if !db_file.exists() {
return Err("SQLite database file does not exist".into_custom_error()); return Err("SQLite数据库文件不存在".into_custom_error());
} }
let path = db_file.to_str().ok_or("Unable to get sqllite path".into_custom_error())?; let path = db_file
.to_str()
.ok_or("无法获取SQLite路径".into_custom_error())?;
let connection_str = format!("sqlite:///{}", path); let connection_str = format!("sqlite:///{}", path);
let pool = SqlitePool::connect(&connection_str).await?; let pool = SqlitePool::connect(&connection_str).await?;
@ -104,4 +82,31 @@ impl DatabaseTrait for Sqlite {
}) })
.collect()) .collect())
} }
async fn initialization(db_config: config::SqlConfig) -> CustomResult<()> {
let db_prefix = SafeValue::Text(
format!("{}", db_config.db_prefix),
builder::ValidationLevel::Strict,
);
let sqlite_dir = env::current_dir()?.join("assets").join("sqllite");
std::fs::create_dir_all(&sqlite_dir)?;
let db_file = sqlite_dir.join(&db_config.db_name);
std::fs::File::create(&db_file)?;
let path = db_file
.to_str()
.ok_or("Unable to get sqllite path".into_custom_error())?;
let grammar = schema::generate_schema(super::DatabaseType::SQLite, db_prefix)?;
println!("\n{}\n", grammar);
let connection_str = format!("sqlite:///{}", path);
let pool = SqlitePool::connect(&connection_str).await?;
pool.execute(grammar.as_str()).await?;
Ok(())
}
} }

18
frontend/app/env.d.ts vendored
View File

@ -1,18 +0,0 @@
// File path: app/end.d.ts
/**
*
*/
/// <reference types="vite/client" />
interface ImportMetaEnv {
readonly VITE_INIT_STATUS: string;
readonly VITE_SERVER_API: string;
readonly VITE_PORT: string;
readonly VITE_ADDRESS: string;
}
interface ImportMeta {
readonly env: ImportMetaEnv;
}

25
frontend/app/env.ts Normal file
View File

@ -0,0 +1,25 @@
export interface EnvConfig {
VITE_PORT: string;
VITE_ADDRESS: string;
VITE_INIT_STATUS: string;
VITE_API_BASE_URL: string;
VITE_API_USERNAME: string;
VITE_API_PASSWORD: string;
}
export const DEFAULT_CONFIG: EnvConfig = {
VITE_PORT: "22100",
VITE_ADDRESS: "localhost",
VITE_INIT_STATUS: "0",
VITE_API_BASE_URL: "http://127.0.0.1:22000",
VITE_API_USERNAME: "",
VITE_API_PASSWORD: "",
} as const;
// 扩展 ImportMeta 接口
declare global {
interface ImportMetaEnv extends EnvConfig {}
interface ImportMeta {
readonly env: ImportMetaEnv;
}
}

View File

@ -1,5 +1,9 @@
import React, { createContext, useState, useEffect } from "react"; import React, { createContext, useState, useEffect } from "react";
import { useApi } from "hooks/servicesProvider"; import {useHttp} from 'hooks/servicesProvider'
import { message} from "hooks/message";
import {DEFAULT_CONFIG} from "app/env"
interface SetupContextType { interface SetupContextType {
currentStep: number; currentStep: number;
setCurrentStep: (step: number) => void; setCurrentStep: (step: number) => void;
@ -13,10 +17,8 @@ const SetupContext = createContext<SetupContextType>({
// 步骤组件的通用属性接口 // 步骤组件的通用属性接口
interface StepProps { interface StepProps {
onNext: () => void; onNext: () => void;
onPrev?: () => void;
} }
// 通用的步骤容器组件
const StepContainer: React.FC<{ title: string; children: React.ReactNode }> = ({ const StepContainer: React.FC<{ title: string; children: React.ReactNode }> = ({
title, title,
children, children,
@ -32,13 +34,22 @@ const StepContainer: React.FC<{ title: string; children: React.ReactNode }> = ({
); );
// 通用的导航按钮组件 // 通用的导航按钮组件
const NavigationButtons: React.FC<StepProps> = ({ onNext }) => ( const NavigationButtons: React.FC<StepProps & { loading?: boolean; disabled?: boolean }> = ({
onNext,
loading = false,
disabled = false
}) => (
<div className="flex justify-end mt-4"> <div className="flex justify-end mt-4">
<button <button
onClick={onNext} onClick={onNext}
className="px-6 py-2 rounded-lg bg-blue-500 hover:bg-blue-600 text-white transition-colors font-medium text-sm" disabled={loading || disabled}
className={`px-6 py-2 rounded-lg transition-colors font-medium text-sm
${loading || disabled
? 'bg-gray-400 cursor-not-allowed'
: 'bg-blue-500 hover:bg-blue-600 text-white'
}`}
> >
{loading ? '处理中...' : '下一步'}
</button> </button>
</div> </div>
); );
@ -49,14 +60,16 @@ const InputField: React.FC<{
name: string; name: string;
defaultValue?: string | number; defaultValue?: string | number;
hint?: string; hint?: string;
}> = ({ label, name, defaultValue, hint }) => ( required?: boolean;
}> = ({ label, name, defaultValue, hint, required = true }) => (
<div className="mb-6"> <div className="mb-6">
<h3 className="text-base font-medium text-custom-title-light dark:text-custom-title-dark mb-2"> <h3 className="text-base font-medium text-custom-title-light dark:text-custom-title-dark mb-2">
{label} {label} {required && <span className="text-red-500">*</span>}
</h3> </h3>
<input <input
name={name} name={name}
defaultValue={defaultValue} defaultValue={defaultValue}
required={required}
className="w-full p-2.5 rounded-lg border border-gray-300 dark:border-gray-600 bg-white dark:bg-gray-700 focus:ring-2 focus:ring-blue-500 focus:border-transparent outline-none transition-all" className="w-full p-2.5 rounded-lg border border-gray-300 dark:border-gray-600 bg-white dark:bg-gray-700 focus:ring-2 focus:ring-blue-500 focus:border-transparent outline-none transition-all"
/> />
{hint && ( {hint && (
@ -80,6 +93,102 @@ const Introduction: React.FC<StepProps> = ({ onNext }) => (
const DatabaseConfig: React.FC<StepProps> = ({ onNext }) => { const DatabaseConfig: React.FC<StepProps> = ({ onNext }) => {
const [dbType, setDbType] = useState("postgresql"); const [dbType, setDbType] = useState("postgresql");
const [loading, setLoading] = useState(false);
const api = useHttp();
const validateForm = () => {
const getRequiredFields = () => {
switch (dbType) {
case 'sqllite':
return ['db_prefix', 'db_name'];
case 'postgresql':
case 'mysql':
return ['db_host', 'db_prefix', 'db_port', 'db_user', 'db_password', 'db_name'];
default:
return [];
}
};
const requiredFields = getRequiredFields();
const emptyFields: string[] = [];
requiredFields.forEach(field => {
const input = document.querySelector(`[name="${field}"]`) as HTMLInputElement;
if (input && (!input.value || input.value.trim() === '')) {
emptyFields.push(field);
}
});
if (emptyFields.length > 0) {
const fieldNames = emptyFields.map(field => {
switch (field) {
case 'db_host': return '数据库地址';
case 'db_prefix': return '数据库前缀';
case 'db_port': return '端口';
case 'db_user': return '用户名';
case 'db_password': return '密码';
case 'db_name': return '数据库名';
default: return field;
}
});
message.error(`请填写以下必填项:${fieldNames.join('、')}`);
return false;
}
return true;
};
const handleNext = async () => {
if (!validateForm()) {
return;
}
setLoading(true);
try {
const formData = {
db_type: dbType,
host: (document.querySelector('[name="db_host"]') as HTMLInputElement)?.value?.trim()??"",
db_prefix: (document.querySelector('[name="db_prefix"]') as HTMLInputElement)?.value?.trim()??"",
port: Number((document.querySelector('[name="db_port"]') as HTMLInputElement)?.value?.trim()??0),
user: (document.querySelector('[name="db_user"]') as HTMLInputElement)?.value?.trim()??"",
password: (document.querySelector('[name="db_password"]') as HTMLInputElement)?.value?.trim()??"",
db_name: (document.querySelector('[name="db_name"]') as HTMLInputElement)?.value?.trim()??"",
};
await api.post('/sql', formData);
let oldEnv = import.meta.env?? DEFAULT_CONFIG
const viteEnv = Object.entries(oldEnv).reduce((acc, [key, value]) => {
if (key.startsWith('VITE_')) {
acc[key] = value;
}
return acc;
}, {} as Record<string, any>);
const newEnv = {
...viteEnv,
VITE_INIT_STATUS: '2'
};
await api.dev("/env", {
method: "POST",
body: JSON.stringify(newEnv),
});
Object.assign( newEnv)
message.success('数据库配置成功!');
setTimeout(() => onNext(), 1000);
} catch (error: any) {
console.error( error);
message.error(error.message );
} finally {
setLoading(false);
}
};
return ( return (
<StepContainer title="数据库配置"> <StepContainer title="数据库配置">
@ -105,34 +214,40 @@ const DatabaseConfig: React.FC<StepProps> = ({ onNext }) => {
label="数据库地址" label="数据库地址"
name="db_host" name="db_host"
defaultValue="localhost" defaultValue="localhost"
hint="通常使用 localhost" hint="通常使 localhost"
required
/> />
<InputField <InputField
label="数据库前缀" label="数据库前缀"
name="db_prefix" name="db_prefix"
defaultValue="echoec_" defaultValue="echoec_"
hint="通常使用 echoec_" hint="通常使用 echoec_"
required
/> />
<InputField <InputField
label="端口" label="端口"
name="db_port" name="db_port"
defaultValue={5432} defaultValue={5432}
hint="PostgreSQL 默认端口为 5432" hint="PostgreSQL 默认端口为 5432"
required
/> />
<InputField <InputField
label="用户名" label="用户名"
name="db_user" name="db_user"
defaultValue="postgres" defaultValue="postgres"
required
/> />
<InputField <InputField
label="密码" label="密码"
name="db_password" name="db_password"
defaultValue="postgres" defaultValue="postgres"
required
/> />
<InputField <InputField
label="数据库名" label="数据库名"
name="db_name" name="db_name"
defaultValue="echoes" defaultValue="echoes"
required
/> />
</> </>
)} )}
@ -143,33 +258,39 @@ const DatabaseConfig: React.FC<StepProps> = ({ onNext }) => {
name="db_host" name="db_host"
defaultValue="localhost" defaultValue="localhost"
hint="通常使用 localhost" hint="通常使用 localhost"
required
/> />
<InputField <InputField
label="数据库前缀" label="数据库前缀"
name="db_prefix" name="db_prefix"
defaultValue="echoec_" defaultValue="echoec_"
hint="通常使用 echoec_" hint="通常使用 echoec_"
required
/> />
<InputField <InputField
label="端口" label="端口"
name="db_port" name="db_port"
defaultValue={3306} defaultValue={3306}
hint="mysql 默认端口为 3306" hint="mysql 默认端口为 3306"
required
/> />
<InputField <InputField
label="用户名" label="用户名"
name="db_user" name="db_user"
defaultValue="root" defaultValue="root"
required
/> />
<InputField <InputField
label="密码" label="密码"
name="db_password" name="db_password"
defaultValue="mysql" defaultValue="mysql"
required
/> />
<InputField <InputField
label="数据库名" label="数据库名"
name="db_name" name="db_name"
defaultValue="echoes" defaultValue="echoes"
required
/> />
</> </>
)} )}
@ -180,40 +301,117 @@ const DatabaseConfig: React.FC<StepProps> = ({ onNext }) => {
name="db_prefix" name="db_prefix"
defaultValue="echoec_" defaultValue="echoec_"
hint="通常使用 echoec_" hint="通常使用 echoec_"
required
/> />
<InputField <InputField
label="数据库名" label="数据库名"
name="db_name" name="db_name"
defaultValue="echoes.db" defaultValue="echoes.db"
required
/> />
</> </>
)} )}
<NavigationButtons onNext={onNext} /> <NavigationButtons
onNext={handleNext}
loading={loading}
disabled={loading}
/>
</div> </div>
</StepContainer> </StepContainer>
); );
}; };
const AdminConfig: React.FC<StepProps> = ({ onNext }) => (
<StepContainer title="创建管理员账号">
<div className="space-y-6">
<InputField label="用户名" name="admin_username" />
<InputField label="密码" name="admin_password" />
<InputField label="邮箱" name="admin_email" />
<NavigationButtons onNext={onNext} />
</div>
</StepContainer>
);
const SetupComplete: React.FC = () => ( interface InstallReplyData {
<StepContainer title="安装完成"> token: string,
<div className="text-center"> username: string,
<p className="text-xl text-custom-p-light dark:text-custom-p-dark"> password: string,
... }
</p>
</div>
</StepContainer> const AdminConfig: React.FC<StepProps> = ({ onNext }) => {
); const [loading, setLoading] = useState(false);
const api = useHttp();
const handleNext = async () => {
setLoading(true);
try {
const formData = {
username: (document.querySelector('[name="admin_username"]') as HTMLInputElement)?.value,
password: (document.querySelector('[name="admin_password"]') as HTMLInputElement)?.value,
email: (document.querySelector('[name="admin_email"]') as HTMLInputElement)?.value,
};
const response = await api.post('/administrator', formData) as InstallReplyData;
const data = response;
localStorage.setItem('token', data.token);
let oldEnv = import.meta.env ?? DEFAULT_CONFIG;
const viteEnv = Object.entries(oldEnv).reduce((acc, [key, value]) => {
if (key.startsWith('VITE_')) {
acc[key] = value;
}
return acc;
}, {} as Record<string, any>);
const newEnv = {
...viteEnv,
VITE_INIT_STATUS: '3',
VITE_API_USERNAME:data.username,
VITE_API_PASSWORD:data.password
};
await api.dev("/env", {
method: "POST",
body: JSON.stringify(newEnv),
});
message.success('管理员账号创建成功!');
onNext();
} catch (error: any) {
console.error(error);
message.error(error.message);
} finally {
setLoading(false);
}
};
return (
<StepContainer title="创建管理员账号">
<div className="space-y-6">
<InputField label="用户名" name="admin_username" />
<InputField label="密码" name="admin_password" />
<InputField label="邮箱" name="admin_email" />
<NavigationButtons onNext={handleNext} loading={loading} />
</div>
</StepContainer>
);
};
const SetupComplete: React.FC = () => {
const api = useHttp();
return (
<StepContainer title="安装完成">
<div className="text-center">
<p className="text-xl text-custom-p-light dark:text-custom-p-dark mb-4">
</p>
<p className="text-base text-custom-p-light dark:text-custom-p-dark">
...
</p>
<div className="mt-4">
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-blue-500 mx-auto"></div>
</div>
</div>
</StepContainer>
);
};
// 修改主题切换按钮组件 // 修改主题切换按钮组件
const ThemeToggle: React.FC = () => { const ThemeToggle: React.FC = () => {
@ -260,7 +458,7 @@ const ThemeToggle: React.FC = () => {
}; };
export default function SetupPage() { export default function SetupPage() {
let step = Number(import.meta.env.VITE_INIT_STATUS); let step = Number(import.meta.env.VITE_INIT_STATUS)+1;
const [currentStep, setCurrentStep] = useState(step); const [currentStep, setCurrentStep] = useState(step);

View File

@ -7,6 +7,8 @@ import {
} from "@remix-run/react"; } from "@remix-run/react";
import { BaseProvider } from "hooks/servicesProvider"; import { BaseProvider } from "hooks/servicesProvider";
import { MessageProvider } from "hooks/message";
import { MessageContainer } from "hooks/message";
import "~/index.css"; import "~/index.css";
@ -22,7 +24,10 @@ export function Layout({ children }: { children: React.ReactNode }) {
</head> </head>
<body suppressHydrationWarning={true}> <body suppressHydrationWarning={true}>
<BaseProvider> <BaseProvider>
<Outlet /> <MessageProvider>
<MessageContainer />
<Outlet />
</MessageProvider>
</BaseProvider> </BaseProvider>
<ScrollRestoration /> <ScrollRestoration />
<script <script

View File

@ -0,0 +1,32 @@
export interface AppConfig {
port: string;
host: string;
initStatus: string;
apiUrl: string;
credentials: {
username: string;
password: string;
};
}
export const DEFAULT_CONFIG: AppConfig = {
port: "22100",
host: "localhost",
initStatus: "0",
apiUrl: "http://127.0.0.1:22000",
credentials: {
username: "",
password: "",
},
} as const;
declare global {
interface ImportMetaEnv extends Record<string, string> {
VITE_PORT: string;
VITE_HOST: string;
VITE_INIT_STATUS: string;
VITE_API_URL: string;
VITE_USERNAME: string;
VITE_PASSWORD: string;
}
}

29
frontend/config/env.ts Normal file
View File

@ -0,0 +1,29 @@
import { readFile, writeFile } from "fs/promises";
import { resolve } from "path";
const ENV_PATH = resolve(process.cwd(), ".env");
export async function loadEnv(): Promise<Record<string, string>> {
try {
const content = await readFile(ENV_PATH, "utf-8");
return content.split("\n").reduce(
(acc, line) => {
const [key, value] = line.split("=").map((s) => s.trim());
if (key && value) {
acc[key] = value.replace(/["']/g, "");
}
return acc;
},
{} as Record<string, string>,
);
} catch {
return {};
}
}
export async function saveEnv(env: Record<string, string>): Promise<void> {
const content = Object.entries(env)
.map(([key, value]) => `${key}="${value}"`)
.join("\n");
await writeFile(ENV_PATH, content, "utf-8");
}

View File

@ -1,125 +0,0 @@
interface ApiConfig {
baseURL: string;
timeout?: number;
}
export class ApiService {
private static instance: ApiService;
private baseURL: string;
private timeout: number;
private constructor(config: ApiConfig) {
this.baseURL = config.baseURL;
this.timeout = config.timeout || 10000;
}
public static getInstance(config?: ApiConfig): ApiService {
if (!this.instance && config) {
this.instance = new ApiService(config);
}
return this.instance;
}
private async getSystemToken(): Promise<string> {
const username = import.meta.env.VITE_SYSTEM_USERNAME;
const password = import.meta.env.VITE_SYSTEM_PASSWORD;
if (!username || !password) {
throw new Error(
"Failed to obtain the username or password of the front-end system",
);
}
try {
const response = await fetch(`${this.baseURL}/auth/token/system`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
username,
password,
}),
});
if (!response.ok) {
throw new Error("Failed to get system token");
}
const data = await response.text();
return data;
} catch (error) {
console.error("Error getting system token:", error);
throw error;
}
}
private async getToken(username: string, password: string): Promise<string> {
if (username.split(" ").length === 0 || password.split(" ").length === 0) {
throw new Error("Username or password cannot be empty");
}
try {
const response = await fetch(`${this.baseURL}/auth/token`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
username,
password,
}),
});
if (!response.ok) {
throw new Error("Failed to get system token");
}
const data = await response.text();
return data;
} catch (error) {
console.error("Error getting system token:", error);
throw error;
}
}
public async request<T>(
endpoint: string,
options: RequestInit = {},
toekn?: string,
): Promise<T> {
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), this.timeout);
try {
const headers = new Headers(options.headers);
if (toekn) {
headers.set("Authorization", `Bearer ${toekn}`);
}
const response = await fetch(`${this.baseURL}${endpoint}`, {
...options,
headers,
signal: controller.signal,
});
if (!response.ok) {
throw new Error(`API Error: ${response.statusText}`);
}
const data = await response.json();
return data as T;
} catch (error: any) {
if (error.name === "AbortError") {
throw new Error("Request timeout");
}
throw error;
} finally {
clearTimeout(timeoutId);
}
}
}
export default ApiService.getInstance({
baseURL: import.meta.env.VITE_API_BASE_URL,
});

142
frontend/core/http.ts Normal file
View File

@ -0,0 +1,142 @@
export class HttpClient {
private static instance: HttpClient;
private timeout: number;
private constructor(timeout = 10000) {
this.timeout = timeout;
}
public static getInstance(timeout?: number): HttpClient {
if (!this.instance) {
this.instance = new HttpClient(timeout);
}
return this.instance;
}
private async setHeaders(options: RequestInit = {}): Promise<RequestInit> {
const headers = new Headers(options.headers);
if (!headers.has("Content-Type")) {
headers.set("Content-Type", "application/json");
}
const token = localStorage.getItem("auth_token");
if (token) {
headers.set("Authorization", `Bearer ${token}`);
}
return { ...options, headers };
}
private async handleResponse(response: Response): Promise<any> {
if (!response.ok) {
const contentType = response.headers.get("content-type");
let message = `${response.statusText} (${response.status})`;
try {
if (contentType?.includes("application/json")) {
const error = await response.json();
message = error.message || message;
} else {
message = this.getErrorMessage(response.status);
}
} catch (e) {
console.error("解析响应错误:", e);
}
throw new Error(message);
}
const contentType = response.headers.get("content-type");
return contentType?.includes("application/json")
? response.json()
: response.text();
}
private getErrorMessage(status: number): string {
const messages: Record<number, string> = {
0: "网络连接失败",
401: "未授权访问",
403: "禁止访问",
404: "资源不存在",
405: "方法不允许",
500: "服务器错误",
502: "网关错误",
503: "服务不可用",
504: "网关超时",
};
return messages[status] || `请求失败 (${status})`;
}
private async request<T>(
endpoint: string,
options: RequestInit = {},
prefix = "api",
): Promise<T> {
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), this.timeout);
try {
const config = await this.setHeaders(options);
const url = endpoint.startsWith(`/__/${prefix}`)
? endpoint
: `/__/${prefix}${endpoint.startsWith("/") ? endpoint : `/${endpoint}`}`;
const response = await fetch(url, {
...config,
signal: controller.signal,
credentials: "include",
mode: "cors",
});
return await this.handleResponse(response);
} catch (error: any) {
throw error.name === "AbortError" ? new Error("请求超时") : error;
} finally {
clearTimeout(timeoutId);
}
}
public async api<T>(endpoint: string, options: RequestInit = {}): Promise<T> {
return this.request<T>(endpoint, options, "api");
}
public async dev<T>(endpoint: string, options: RequestInit = {}): Promise<T> {
return this.request<T>(endpoint, options, "express");
}
public async get<T>(endpoint: string, options: RequestInit = {}): Promise<T> {
return this.api<T>(endpoint, { ...options, method: "GET" });
}
public async post<T>(
endpoint: string,
data?: any,
options: RequestInit = {},
): Promise<T> {
return this.api<T>(endpoint, {
...options,
method: "POST",
body: JSON.stringify(data),
});
}
public async put<T>(
endpoint: string,
data?: any,
options: RequestInit = {},
): Promise<T> {
return this.api<T>(endpoint, {
...options,
method: "PUT",
body: JSON.stringify(data),
});
}
public async delete<T>(
endpoint: string,
options: RequestInit = {},
): Promise<T> {
return this.api<T>(endpoint, { ...options, method: "DELETE" });
}
}

View File

@ -1,12 +1,11 @@
export interface Template { export interface Template {
name: string; name: string;
description?: string; description?: string;
config: { config: {
layout?: string; layout?: string;
styles?: string[]; styles?: string[];
scripts?: string[]; scripts?: string[];
}; };
loader: () => Promise<void>; loader: () => Promise<void>;
element: () => React.ReactNode; element: () => React.ReactNode;
} }

View File

@ -26,7 +26,6 @@ export interface ThemeConfig {
}; };
} }
export class ThemeService { export class ThemeService {
private static instance: ThemeService; private static instance: ThemeService;
private currentTheme?: ThemeConfig; private currentTheme?: ThemeConfig;
@ -45,10 +44,9 @@ export class ThemeService {
public async getCurrentTheme(): Promise<void> { public async getCurrentTheme(): Promise<void> {
try { try {
const themeConfig = await this.api.request<ThemeConfig>( const themeConfig = await this.api.request<ThemeConfig>("/theme", {
"/theme", method: "GET",
{ method: "GET" }, });
);
this.currentTheme = themeConfig; this.currentTheme = themeConfig;
} catch (error) { } catch (error) {
console.error("Failed to initialize theme:", error); console.error("Failed to initialize theme:", error);
@ -60,18 +58,18 @@ export class ThemeService {
return this.currentTheme; return this.currentTheme;
} }
public async updateThemeConfig(config: Partial<ThemeConfig>,name:string): Promise<void> { public async updateThemeConfig(
config: Partial<ThemeConfig>,
name: string,
): Promise<void> {
try { try {
const updatedConfig = await this.api.request<ThemeConfig>( const updatedConfig = await this.api.request<ThemeConfig>(`/theme/`, {
`/theme/`, method: "PUT",
{ headers: {
method: "PUT", "Content-Type": "application/json",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(config),
}, },
); body: JSON.stringify(config),
});
await this.loadTheme(updatedConfig); await this.loadTheme(updatedConfig);
} catch (error) { } catch (error) {

229
frontend/hooks/message.tsx Normal file
View File

@ -0,0 +1,229 @@
import React, { createContext, useState, useContext, useEffect } from "react";
interface Message {
id: number;
type: "success" | "error" | "info" | "warning";
content: string;
duration?: number;
}
interface MessageOptions {
content: string;
duration?: number;
}
interface MessageContextType {
messages: Message[];
addMessage: (
type: Message["type"],
content: string,
duration?: number,
) => void;
}
const MessageContext = createContext<MessageContextType>({
messages: [],
addMessage: () => {},
});
export const MessageProvider: React.FC<{ children: React.ReactNode }> = ({
children,
}) => {
const [messages, setMessages] = useState<Message[]>([]);
const removeMessage = (id: number) => {
setMessages((prev) => prev.filter((msg) => msg.id !== id));
};
const addMessage = (
type: Message["type"],
content: string,
duration = 3000,
) => {
const id = Date.now();
setMessages((prevMessages) => {
const newMessages = [...prevMessages, { id, type, content }];
return newMessages;
});
if (duration > 0) {
setTimeout(() => removeMessage(id), duration);
}
};
return (
<>
<MessageContext.Provider value={{ messages, addMessage }}>
{children}
</MessageContext.Provider>
<div
id="message-container"
style={{
position: "fixed",
top: "16px",
right: "16px",
display: "flex",
flexDirection: "column",
gap: "8px",
pointerEvents: "none",
zIndex: 999999,
maxWidth: "90vw",
}}
>
{messages.map((msg) => (
<div
key={msg.id}
style={{
backgroundColor:
msg.type === "success"
? "rgba(34, 197, 94, 0.95)"
: msg.type === "error"
? "rgba(239, 68, 68, 0.95)"
: msg.type === "warning"
? "rgba(234, 179, 8, 0.95)"
: "rgba(59, 130, 246, 0.95)",
color: "white",
width: "320px",
borderRadius: "4px",
boxShadow: "0 4px 12px rgba(0, 0, 0, 0.15)",
overflow: "hidden",
animation: "slideInRight 0.3s ease-out forwards",
pointerEvents: "auto",
}}
>
<div
style={{
padding: "12px 16px",
display: "flex",
justifyContent: "space-between",
alignItems: "center",
}}
>
<span
style={{
fontSize: "14px",
lineHeight: "1.5",
marginRight: "12px",
flex: 1,
wordBreak: "break-word",
}}
>
{msg.content}
</span>
<button
onClick={() => removeMessage(msg.id)}
style={{
background: "none",
border: "none",
color: "rgba(255, 255, 255, 0.8)",
cursor: "pointer",
padding: "4px",
fontSize: "16px",
lineHeight: 1,
transition: "color 0.2s",
flexShrink: 0,
}}
onMouseEnter={(e) => (e.currentTarget.style.color = "white")}
onMouseLeave={(e) =>
(e.currentTarget.style.color = "rgba(255, 255, 255, 0.8)")
}
>
</button>
</div>
<div
style={{
height: "2px",
backgroundColor: "rgba(255, 255, 255, 0.2)",
position: "relative",
}}
>
<div
style={{
position: "absolute",
left: 0,
bottom: 0,
width: "100%",
height: "100%",
backgroundColor: "rgba(255, 255, 255, 0.4)",
animation: `progress ${msg.duration || 3000}ms linear`,
}}
/>
</div>
</div>
))}
</div>
<style>{`
@keyframes slideInRight {
from {
transform: translateX(100%);
opacity: 0;
}
to {
transform: translateX(0);
opacity: 1;
}
}
@keyframes progress {
from {
width: 100%;
}
to {
width: 0%;
}
}
`}</style>
</>
);
};
// 修改全局消息实例的实现
let globalAddMessage:
| ((type: Message["type"], content: string, duration?: number) => void)
| null = null;
export const MessageContainer: React.FC = () => {
const { addMessage } = useContext(MessageContext);
useEffect(() => {
globalAddMessage = addMessage;
return () => {
globalAddMessage = null;
};
}, [addMessage]);
return null;
};
// 修改消息方法的实现
export const message = {
success: (content: string) => {
if (!globalAddMessage) {
console.warn("Message system not initialized");
return;
}
globalAddMessage("success", content);
},
error: (content: string) => {
if (!globalAddMessage) {
console.warn("Message system not initialized");
return;
}
globalAddMessage("error", content);
},
warning: (content: string) => {
if (!globalAddMessage) {
console.warn("Message system not initialized");
return;
}
globalAddMessage("warning", content);
},
info: (content: string) => {
if (!globalAddMessage) {
console.warn("Message system not initialized");
return;
}
globalAddMessage("info", content);
},
};

View File

@ -1,5 +1,5 @@
import { CapabilityService } from "core/capability"; import { CapabilityService } from "core/capability";
import { ApiService } from "core/api"; import { HttpClient } from "core/http";
import { RouteManager } from "core/route"; import { RouteManager } from "core/route";
import { createServiceContext } from "hooks/createServiceContext"; import { createServiceContext } from "hooks/createServiceContext";
import { ReactNode } from "react"; import { ReactNode } from "react";
@ -13,14 +13,14 @@ export const { RouteProvider, useRoute } = createServiceContext("Route", () =>
RouteManager.getInstance(), RouteManager.getInstance(),
); );
export const { ApiProvider, useApi } = createServiceContext("Api", () => export const { HttpProvider, useHttp } = createServiceContext("Http", () =>
ApiService.getInstance(), HttpClient.getInstance(),
); );
export const BaseProvider = ({ children }: { children: ReactNode }) => ( export const BaseProvider = ({ children }: { children: ReactNode }) => (
<ApiProvider> <HttpProvider>
<CapabilityProvider> <CapabilityProvider>
<RouteProvider>{children}</RouteProvider> <RouteProvider>{children}</RouteProvider>
</CapabilityProvider> </CapabilityProvider>
</ApiProvider> </HttpProvider>
); );

View File

@ -5,7 +5,7 @@
"type": "module", "type": "module",
"scripts": { "scripts": {
"build": "remix vite:build", "build": "remix vite:build",
"dev": "remix vite:dev", "dev": "concurrently \"node --trace-warnings ./node_modules/vite/bin/vite.js\" \"cross-env VITE_ADDRESS=localhost VITE_PORT=22100 tsx --trace-warnings server/express.ts\"",
"format": "prettier --write \"./**/*.{ts,tsx,js,jsx}\"", "format": "prettier --write \"./**/*.{ts,tsx,js,jsx}\"",
"lint": "eslint \"./**/*.{ts,tsx,js,jsx}\" --fix", "lint": "eslint \"./**/*.{ts,tsx,js,jsx}\" --fix",
"start": "remix-serve ./build/server/index.js", "start": "remix-serve ./build/server/index.js",
@ -17,17 +17,23 @@
"@remix-run/serve": "^2.14.0", "@remix-run/serve": "^2.14.0",
"@types/axios": "^0.14.4", "@types/axios": "^0.14.4",
"axios": "^1.7.7", "axios": "^1.7.7",
"cors": "^2.8.5",
"express": "^4.21.1",
"isbot": "^4.1.0", "isbot": "^4.1.0",
"react": "^18.2.0", "react": "^18.2.0",
"react-dom": "^18.2.0" "react-dom": "^18.2.0"
}, },
"devDependencies": { "devDependencies": {
"@remix-run/dev": "^2.14.0", "@remix-run/dev": "^2.14.0",
"@types/cors": "^2.8.17",
"@types/express": "^5.0.0",
"@types/react": "^18.2.20", "@types/react": "^18.2.20",
"@types/react-dom": "^18.2.7", "@types/react-dom": "^18.2.7",
"@typescript-eslint/eslint-plugin": "^6.7.4", "@typescript-eslint/eslint-plugin": "^6.7.4",
"@typescript-eslint/parser": "^6.7.4", "@typescript-eslint/parser": "^6.7.4",
"autoprefixer": "^10.4.19", "autoprefixer": "^10.4.19",
"concurrently": "^9.1.0",
"cross-env": "^7.0.3",
"eslint": "^8.57.1", "eslint": "^8.57.1",
"eslint-import-resolver-typescript": "^3.6.1", "eslint-import-resolver-typescript": "^3.6.1",
"eslint-plugin-import": "^2.28.1", "eslint-plugin-import": "^2.28.1",
@ -37,6 +43,7 @@
"postcss": "^8.4.38", "postcss": "^8.4.38",
"prettier": "^3.3.3", "prettier": "^3.3.3",
"tailwindcss": "^3.4.4", "tailwindcss": "^3.4.4",
"tsx": "^4.19.2",
"typescript": "^5.1.6", "typescript": "^5.1.6",
"vite": "^5.1.0", "vite": "^5.1.0",
"vite-tsconfig-paths": "^4.2.1" "vite-tsconfig-paths": "^4.2.1"

32
frontend/server/env.ts Normal file
View File

@ -0,0 +1,32 @@
import fs from "fs/promises";
import path from "path";
export async function readEnvFile() {
const envPath = path.resolve(process.cwd(), ".env");
try {
const content = await fs.readFile(envPath, "utf-8");
return content.split("\n").reduce(
(acc, line) => {
const [key, value] = line.split("=").map((s) => s.trim());
if (key && value) {
acc[key] = value.replace(/["']/g, "");
}
return acc;
},
{} as Record<string, string>,
);
} catch {
return {};
}
}
export async function writeEnvFile(env: Record<string, string>) {
const envPath = path.resolve(process.cwd(), ".env");
const content = Object.entries(env)
.map(
([key, value]) =>
`${key}=${typeof value === "string" ? `"${value}"` : value}`,
)
.join("\n");
await fs.writeFile(envPath, content, "utf-8");
}

View File

@ -0,0 +1,72 @@
import express from "express";
import cors from "cors";
import { DEFAULT_CONFIG } from "../app/env";
import { readEnvFile, writeEnvFile } from "./env";
const app = express();
const address = process.env.VITE_ADDRESS ?? DEFAULT_CONFIG.VITE_ADDRESS;
const port = Number(process.env.VITE_PORT ?? DEFAULT_CONFIG.VITE_PORT);
const ALLOWED_ORIGIN = `http://${address}:${port}`;
// 配置 CORS只允许来自 Vite 服务器的请求
app.use(
cors({
origin: (origin, callback) => {
if (!origin || origin === ALLOWED_ORIGIN) {
callback(null, true);
} else {
callback(new Error("不允许的来源"));
}
},
credentials: true,
}),
);
// 添加 IP 和端口检查中间件
const checkAccessMiddleware = (
req: express.Request,
res: express.Response,
next: express.NextFunction,
) => {
const clientIp = req.ip === "::1" ? "localhost" : req.ip;
const clientPort = Number(req.get("origin")?.split(":").pop() ?? 0);
const isLocalIp = clientIp === "localhost" || clientIp === "127.0.0.1";
const isAllowedPort = clientPort === port;
if (isLocalIp && isAllowedPort) {
next();
} else {
res.status(403).json({
error: "禁止访问",
detail: `仅允许 ${address}:${port} 访问`,
});
}
};
app.use(checkAccessMiddleware);
app.use(express.json());
app.get("/env", async (req, res) => {
try {
const envData = await readEnvFile();
res.json(envData);
} catch (error) {
res.status(500).json({ error: "读取环境变量失败" });
}
});
app.post("/env", async (req, res) => {
try {
const newEnv = req.body;
await writeEnvFile(newEnv);
res.json({ success: true });
} catch (error) {
res.status(500).json({ error: "更新环境变量失败" });
}
});
app.listen(port + 1, address, () => {
console.log(`内部服务器运行在 http://${address}:${port + 1}`);
});

57
frontend/start.ts Normal file
View File

@ -0,0 +1,57 @@
import { spawn } from "child_process";
import path from "path";
import { EventEmitter } from 'events'
// 设置全局最大监听器数量
EventEmitter.defaultMaxListeners = 20
const delay = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
const startServers = async () => {
// 先启动内部服务器
const internalServer = spawn("tsx", ["backend/internalServer.ts"], {
stdio: "inherit",
shell: true,
env: {
...process.env,
NODE_ENV: process.env.NODE_ENV || "development",
},
});
internalServer.on("error", (err) => {
console.error("内部服务器启动错误:", err);
});
// 等待内部服务器启动
console.log("等待内部服务器启动...");
await delay(2000);
// 然后启动 Vite
const viteProcess = spawn("npm", ["run", "dev"], {
stdio: "inherit",
shell: true,
env: {
...process.env,
NODE_ENV: process.env.NODE_ENV || "development",
},
});
viteProcess.on("error", (err) => {
console.error("Vite 进程启动错误:", err);
});
const cleanup = () => {
console.log("正在关闭服务器...");
viteProcess.kill();
internalServer.kill();
process.exit();
};
process.on("SIGINT", cleanup);
process.on("SIGTERM", cleanup);
};
startServers().catch((err) => {
console.error("启动服务器时发生错误:", err);
process.exit(1);
});

View File

@ -15,19 +15,19 @@ export default {
custom: { custom: {
bg: { bg: {
light: "#F5F5FB", light: "#F5F5FB",
dark: "#0F172A" dark: "#0F172A",
}, },
box: { box: {
light: "#FFFFFF", light: "#FFFFFF",
dark: "#1E293B" dark: "#1E293B",
}, },
p: { p: {
light: "#4b5563", light: "#4b5563",
dark: "#94A3B8" dark: "#94A3B8",
}, },
title: { title: {
light: "#111827", light: "#111827",
dark: "#F1F5F9" dark: "#F1F5F9",
}, },
}, },
}, },

View File

@ -2,9 +2,40 @@ import { vitePlugin as remix } from "@remix-run/dev";
import { defineConfig, loadEnv } from "vite"; import { defineConfig, loadEnv } from "vite";
import tsconfigPaths from "vite-tsconfig-paths"; import tsconfigPaths from "vite-tsconfig-paths";
import { resolve } from "path"; import { resolve } from "path";
import { readEnvFile } from "./server/env";
import { DEFAULT_CONFIG, EnvConfig } from "./app/env";
// 修改为异步函数来读取最新的环境变量
async function getLatestEnv() {
try {
const envData = await readEnvFile();
return {
...DEFAULT_CONFIG,
...envData,
} as EnvConfig;
} catch (error) {
console.error("读取环境变量失败:", error);
return DEFAULT_CONFIG;
}
}
const createDefineConfig = (config: EnvConfig) => {
return Object.entries(config).reduce(
(acc, [key, value]) => {
acc[`import.meta.env.${key}`] =
typeof value === "string" ? JSON.stringify(value) : value;
return acc;
},
{} as Record<string, any>,
);
};
export default defineConfig(async ({ mode }) => {
// 确保每次都读取最新的环境变量
const currentConfig = await getLatestEnv();
const env = loadEnv(mode, process.cwd(), "VITE_");
export default defineConfig(({ mode }) => {
const env = loadEnv(mode, process.cwd(), "");
return { return {
plugins: [ plugins: [
remix({ remix({
@ -15,9 +46,12 @@ export default defineConfig(({ mode }) => {
v3_singleFetch: true, v3_singleFetch: true,
v3_lazyRouteDiscovery: true, v3_lazyRouteDiscovery: true,
}, },
routes: (defineRoutes) => { routes: async (defineRoutes) => {
// 每次路由配置时重新读取环境变量
const latestConfig = await getLatestEnv();
return defineRoutes((route) => { return defineRoutes((route) => {
if (Number(env.VITE_INIT_STATUS??1)<4) { if (Number(latestConfig.VITE_INIT_STATUS) < 3) {
route("/", "init.tsx", { id: "index-route" }); route("/", "init.tsx", { id: "index-route" });
route("*", "init.tsx", { id: "catch-all-route" }); route("*", "init.tsx", { id: "catch-all-route" });
} else { } else {
@ -29,22 +63,30 @@ export default defineConfig(({ mode }) => {
}), }),
tsconfigPaths(), tsconfigPaths(),
], ],
define: { define: createDefineConfig(currentConfig),
"import.meta.env.VITE_INIT_STATUS": JSON.stringify(1),
"import.meta.env.VITE_SERVER_API": JSON.stringify("localhost:22000"),
"import.meta.env.VITE_PORT": JSON.stringify(22100),
"import.meta.env.VITE_ADDRESS": JSON.stringify("localhost"),
},
server: { server: {
host: true, host: true,
address: "localhost", address: currentConfig.VITE_ADDRESS,
port: Number(env.VITE_SYSTEM_PORT ?? 22100), port: Number(env.VITE_SYSTEM_PORT ?? currentConfig.VITE_PORT),
strictPort: true, strictPort: true,
hmr: true, // 确保启用热更新 hmr: true,
watch: { watch: {
usePolling: true, // 添加这个配置可以解决某些系统下热更新不工作的问题 usePolling: true,
},
proxy: {
"/__/api": {
target: currentConfig.VITE_API_BASE_URL,
changeOrigin: true,
rewrite: (path: string) => path.replace(/^\/__\/api/, ""),
},
"/__/express": {
target: `http://${currentConfig.VITE_ADDRESS}:${Number(currentConfig.VITE_PORT) + 1}`,
changeOrigin: true,
rewrite: (path: string) => path.replace(/^\/__\/express/, ""),
},
}, },
}, },
publicDir: resolve(__dirname, "public"), publicDir: resolve(__dirname, "public"),
envPrefix: "VITE_",
}; };
}); });