后端:完善错误和数据库代码构建,实现应用重启
This commit is contained in:
parent
eb53c72203
commit
d2eac057ca
@ -17,3 +17,5 @@ rand = "0.8.5"
|
|||||||
chrono = "0.4"
|
chrono = "0.4"
|
||||||
regex = "1.11.1"
|
regex = "1.11.1"
|
||||||
bcrypt = "0.16"
|
bcrypt = "0.16"
|
||||||
|
uuid = { version = "1.11.0", features = ["v4", "serde"] }
|
||||||
|
hex = "0.4.3"
|
@ -1,11 +0,0 @@
|
|||||||
[info]
|
|
||||||
install = false
|
|
||||||
non_relational = false
|
|
||||||
|
|
||||||
[sql_config]
|
|
||||||
db_type = "postgresql"
|
|
||||||
address = "localhost"
|
|
||||||
port = 5432
|
|
||||||
user = "postgres"
|
|
||||||
password = "postgres"
|
|
||||||
db_name = "echoes"
|
|
16
backend/src/auth/bcrypt.rs
Normal file
16
backend/src/auth/bcrypt.rs
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
use crate::error::CustomErrorInto;
|
||||||
|
use crate::error::CustomResult;
|
||||||
|
use bcrypt::{hash, verify, DEFAULT_COST};
|
||||||
|
|
||||||
|
pub fn generate_hash(s: &str) -> CustomResult<String> {
|
||||||
|
let hashed = hash(s, DEFAULT_COST)?;
|
||||||
|
Ok(hashed)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn verify_hash(s: &str, hash: &str) -> CustomResult<()> {
|
||||||
|
let is_valid = verify(s, hash)?;
|
||||||
|
if !is_valid {
|
||||||
|
return Err("密码无效".into_custom_error());
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
use crate::utils::CustomResult;
|
use crate::error::CustomResult;
|
||||||
use chrono::{Duration, Utc};
|
use chrono::{Duration, Utc};
|
||||||
use ed25519_dalek::{SigningKey, VerifyingKey};
|
use ed25519_dalek::{SigningKey, VerifyingKey};
|
||||||
use jwt_compact::{alg::Ed25519, AlgorithmExt, Header, TimeOptions, Token, UntrustedToken};
|
use jwt_compact::{alg::Ed25519, AlgorithmExt, Header, TimeOptions, Token, UntrustedToken};
|
||||||
|
@ -1 +1,2 @@
|
|||||||
|
pub mod bcrypt;
|
||||||
pub mod jwt;
|
pub mod jwt;
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use crate::utils::CustomResult;
|
use crate::error::CustomResult;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::{env, fs};
|
use std::{env, fs};
|
||||||
@ -47,6 +47,6 @@ impl Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_path() -> CustomResult<PathBuf> {
|
pub fn get_path() -> CustomResult<PathBuf> {
|
||||||
Ok(env::current_dir()?.join("assets").join("config.toml"))
|
Ok(env::current_dir()?.join("config.toml"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,63 +1,289 @@
|
|||||||
use crate::utils::{CustomError, CustomResult};
|
use crate::error::{CustomErrorInto, CustomResult};
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
|
use serde::Serialize;
|
||||||
|
use serde_json::Value as JsonValue;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, Serialize)]
|
||||||
pub enum ValidatedValue {
|
pub enum ValidationLevel {
|
||||||
Identifier(String),
|
Strict,
|
||||||
RichText(String),
|
Standard,
|
||||||
PlainText(String),
|
Relaxed,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ValidatedValue {
|
#[derive(Debug, Clone)]
|
||||||
pub fn new_identifier(value: String) -> CustomResult<Self> {
|
pub struct TextValidator {
|
||||||
let valid_pattern = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]{0,63}$").unwrap();
|
sql_patterns: Vec<&'static str>,
|
||||||
if !valid_pattern.is_match(&value) {
|
special_chars: Vec<char>,
|
||||||
return Err(CustomError::from_str("Invalid identifier format"));
|
level_max_lengths: HashMap<ValidationLevel, usize>,
|
||||||
|
level_allowed_chars: HashMap<ValidationLevel, Vec<char>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for TextValidator {
|
||||||
|
fn default() -> Self {
|
||||||
|
let level_max_lengths = HashMap::from([
|
||||||
|
(ValidationLevel::Strict, 100),
|
||||||
|
(ValidationLevel::Standard, 1000),
|
||||||
|
(ValidationLevel::Relaxed, 100000),
|
||||||
|
]);
|
||||||
|
|
||||||
|
let level_allowed_chars = HashMap::from([
|
||||||
|
(ValidationLevel::Strict, vec!['_']),
|
||||||
|
(
|
||||||
|
ValidationLevel::Standard,
|
||||||
|
vec!['_', '-', '.', ',', '!', '?', ':', ' '],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
ValidationLevel::Relaxed,
|
||||||
|
vec![
|
||||||
|
'_', '-', '.', ',', '!', '?', ':', ' ', '"', '\'', '(', ')', '[', ']', '{',
|
||||||
|
'}', '@', '#', '$', '%', '^', '&', '*', '+', '=', '<', '>', '/', '\\',
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
|
||||||
|
TextValidator {
|
||||||
|
sql_patterns: vec![
|
||||||
|
"DROP",
|
||||||
|
"TRUNCATE",
|
||||||
|
"ALTER",
|
||||||
|
"DELETE",
|
||||||
|
"UPDATE",
|
||||||
|
"INSERT",
|
||||||
|
"MERGE",
|
||||||
|
"GRANT",
|
||||||
|
"REVOKE",
|
||||||
|
"UNION",
|
||||||
|
"--",
|
||||||
|
"/*",
|
||||||
|
"EXEC",
|
||||||
|
"EXECUTE",
|
||||||
|
"WAITFOR",
|
||||||
|
"DELAY",
|
||||||
|
"BENCHMARK",
|
||||||
|
],
|
||||||
|
special_chars: vec!['\0', '\n', '\r', '\t'],
|
||||||
|
level_max_lengths,
|
||||||
|
level_allowed_chars,
|
||||||
}
|
}
|
||||||
Ok(ValidatedValue::Identifier(value))
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextValidator {
|
||||||
|
pub fn validate(&self, text: &str, level: ValidationLevel) -> CustomResult<()> {
|
||||||
|
let max_length = self
|
||||||
|
.level_max_lengths
|
||||||
|
.get(&level)
|
||||||
|
.ok_or_else(|| "Invalid validation level".into_custom_error())?;
|
||||||
|
|
||||||
|
if text.len() > *max_length {
|
||||||
|
return Err("Text exceeds maximum length".into_custom_error());
|
||||||
|
}
|
||||||
|
|
||||||
|
// 简化验证逻辑
|
||||||
|
if level == ValidationLevel::Relaxed {
|
||||||
|
return self.validate_sql_patterns(text);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.validate_chars(text, level)?;
|
||||||
|
self.validate_special_chars(text)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_rich_text(value: String) -> CustomResult<Self> {
|
fn validate_sql_patterns(&self, text: &str) -> CustomResult<()> {
|
||||||
let dangerous_patterns = [
|
let upper_text = text.to_uppercase();
|
||||||
"UNION ALL SELECT",
|
if self
|
||||||
"UNION SELECT",
|
.sql_patterns
|
||||||
"OR 1=1",
|
.iter()
|
||||||
"OR '1'='1",
|
.any(|&pattern| upper_text.contains(&pattern.to_uppercase()))
|
||||||
"DROP TABLE",
|
{
|
||||||
"DELETE FROM",
|
return Err("Potentially dangerous SQL pattern detected".into_custom_error());
|
||||||
"UPDATE ",
|
}
|
||||||
"INSERT INTO",
|
Ok(())
|
||||||
"--",
|
}
|
||||||
"/*",
|
|
||||||
"*/",
|
|
||||||
"@@",
|
|
||||||
];
|
|
||||||
|
|
||||||
let value_upper = value.to_uppercase();
|
fn validate_chars(&self, text: &str, level: ValidationLevel) -> CustomResult<()> {
|
||||||
for pattern in dangerous_patterns.iter() {
|
let allowed_chars = self
|
||||||
if value_upper.contains(&pattern.to_uppercase()) {
|
.level_allowed_chars
|
||||||
return Err(CustomError::from_str("Invalid identifier format"));
|
.get(&level)
|
||||||
|
.ok_or_else(|| "Invalid validation level".into_custom_error())?;
|
||||||
|
|
||||||
|
if let Some(invalid_char) = text
|
||||||
|
.chars()
|
||||||
|
.find(|&c| !c.is_alphanumeric() && !allowed_chars.contains(&c))
|
||||||
|
{
|
||||||
|
return Err(
|
||||||
|
format!("Invalid character '{}' for {:?} level", invalid_char, level)
|
||||||
|
.into_custom_error(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_special_chars(&self, text: &str) -> CustomResult<()> {
|
||||||
|
if self.special_chars.iter().any(|&c| text.contains(c)) {
|
||||||
|
return Err("Invalid special character detected".into_custom_error());
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提供便捷方法
|
||||||
|
pub fn validate_relaxed(&self, text: &str) -> CustomResult<()> {
|
||||||
|
self.validate(text, ValidationLevel::Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn validate_standard(&self, text: &str) -> CustomResult<()> {
|
||||||
|
self.validate(text, ValidationLevel::Standard)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn validate_strict(&self, text: &str) -> CustomResult<()> {
|
||||||
|
self.validate(text, ValidationLevel::Strict)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sanitize(&self, text: &str) -> CustomResult<String> {
|
||||||
|
self.validate_relaxed(text)?;
|
||||||
|
Ok(text.replace('\'', "''").replace('\\', "\\\\"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub enum SafeValue {
|
||||||
|
Null,
|
||||||
|
Bool(bool),
|
||||||
|
Integer(i64),
|
||||||
|
Float(f64),
|
||||||
|
Text(String, ValidationLevel),
|
||||||
|
DateTime(DateTime<Utc>),
|
||||||
|
Uuid(Uuid),
|
||||||
|
Binary(Vec<u8>),
|
||||||
|
Array(Vec<SafeValue>),
|
||||||
|
Json(JsonValue),
|
||||||
|
Enum(String, String, ValidationLevel),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SafeValue {
|
||||||
|
pub fn from_json(value: JsonValue, level: ValidationLevel) -> CustomResult<Self> {
|
||||||
|
match value {
|
||||||
|
JsonValue::Null => Ok(SafeValue::Null),
|
||||||
|
JsonValue::Bool(b) => Ok(SafeValue::Bool(b)),
|
||||||
|
JsonValue::Number(n) => {
|
||||||
|
if let Some(i) = n.as_i64() {
|
||||||
|
Ok(SafeValue::Integer(i))
|
||||||
|
} else if let Some(f) = n.as_f64() {
|
||||||
|
Ok(SafeValue::Float(f))
|
||||||
|
} else {
|
||||||
|
Err("Invalid number format".into_custom_error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
JsonValue::String(s) => {
|
||||||
|
TextValidator::default().validate(&s, level)?;
|
||||||
|
Ok(SafeValue::Text(s, level))
|
||||||
|
}
|
||||||
|
JsonValue::Array(arr) => Ok(SafeValue::Array(
|
||||||
|
arr.into_iter()
|
||||||
|
.map(|item| SafeValue::from_json(item, level))
|
||||||
|
.collect::<CustomResult<Vec<_>>>()?,
|
||||||
|
)),
|
||||||
|
JsonValue::Object(_) => {
|
||||||
|
Self::validate_json_structure(&value, level)?;
|
||||||
|
Ok(SafeValue::Json(value))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(ValidatedValue::RichText(value))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_plain_text(value: String) -> CustomResult<Self> {
|
fn validate_json_structure(value: &JsonValue, level: ValidationLevel) -> CustomResult<()> {
|
||||||
if value.contains(';') || value.contains("--") {
|
let validator = TextValidator::default();
|
||||||
return Err(CustomError::from_str("Invalid characters in text"));
|
match value {
|
||||||
|
JsonValue::Object(map) => {
|
||||||
|
for (key, val) in map {
|
||||||
|
validator.validate(key, level)?;
|
||||||
|
Self::validate_json_structure(val, level)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
JsonValue::Array(arr) => {
|
||||||
|
arr.iter()
|
||||||
|
.try_for_each(|item| Self::validate_json_structure(item, level))?;
|
||||||
|
}
|
||||||
|
JsonValue::String(s) => validator.validate(s, level)?,
|
||||||
|
_ => {}
|
||||||
}
|
}
|
||||||
Ok(ValidatedValue::PlainText(value))
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(&self) -> &str {
|
fn get_sql_type(&self) -> CustomResult<String> {
|
||||||
|
let sql_type = match self {
|
||||||
|
SafeValue::Null => "NULL",
|
||||||
|
SafeValue::Bool(_) => "boolean",
|
||||||
|
SafeValue::Integer(_) => "bigint",
|
||||||
|
SafeValue::Float(_) => "double precision",
|
||||||
|
SafeValue::Text(_, _) => "text",
|
||||||
|
SafeValue::DateTime(_) => "timestamp with time zone",
|
||||||
|
SafeValue::Uuid(_) => "uuid",
|
||||||
|
SafeValue::Binary(_) => "bytea",
|
||||||
|
SafeValue::Array(_) | SafeValue::Json(_) => "jsonb",
|
||||||
|
SafeValue::Enum(_, enum_type, level) => {
|
||||||
|
TextValidator::default().validate(enum_type, *level)?;
|
||||||
|
return Ok(enum_type.replace('\'', "''"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(sql_type.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_sql_string(&self) -> CustomResult<String> {
|
||||||
match self {
|
match self {
|
||||||
ValidatedValue::Identifier(s)
|
SafeValue::Null => Ok("NULL".to_string()),
|
||||||
| ValidatedValue::RichText(s)
|
SafeValue::Bool(b) => Ok(b.to_string()),
|
||||||
| ValidatedValue::PlainText(s) => s,
|
SafeValue::Integer(i) => Ok(i.to_string()),
|
||||||
|
SafeValue::Float(f) => Ok(f.to_string()),
|
||||||
|
SafeValue::Text(s, level) => {
|
||||||
|
TextValidator::default().validate(s, *level)?;
|
||||||
|
Ok(s.replace('\'', "''"))
|
||||||
|
}
|
||||||
|
SafeValue::DateTime(dt) => Ok(format!("'{}'", dt.to_rfc3339())),
|
||||||
|
SafeValue::Uuid(u) => Ok(format!("'{}'", u)),
|
||||||
|
SafeValue::Binary(b) => Ok(format!("'\\x{}'", hex::encode(b))),
|
||||||
|
SafeValue::Array(arr) => {
|
||||||
|
let values: CustomResult<Vec<_>> = arr.iter().map(|v| v.to_sql_string()).collect();
|
||||||
|
Ok(format!("ARRAY[{}]", values?.join(",")))
|
||||||
|
}
|
||||||
|
SafeValue::Json(j) => {
|
||||||
|
let json_str = serde_json::to_string(j)?;
|
||||||
|
TextValidator::default().validate(&json_str, ValidationLevel::Relaxed)?;
|
||||||
|
Ok(json_str.replace('\'', "''"))
|
||||||
|
}
|
||||||
|
SafeValue::Enum(s, _, level) => {
|
||||||
|
TextValidator::default().validate(s, *level)?;
|
||||||
|
Ok(s.to_string())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn to_param_sql(&self, param_index: usize) -> CustomResult<String> {
|
||||||
|
if matches!(self, SafeValue::Null) {
|
||||||
|
Ok("NULL".to_string())
|
||||||
|
} else {
|
||||||
|
Ok(format!("${}::{}", param_index, self.get_sql_type()?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
|
pub struct Identifier(String);
|
||||||
|
|
||||||
|
impl Identifier {
|
||||||
|
pub fn new(value: String) -> CustomResult<Self> {
|
||||||
|
let valid_pattern = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_\.]{0,63}$")?;
|
||||||
|
if !valid_pattern.is_match(&value) {
|
||||||
|
return Err("Invalid identifier format".into_custom_error());
|
||||||
|
}
|
||||||
|
Ok(Identifier(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn as_str(&self) -> &str {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
@ -80,6 +306,8 @@ pub enum Operator {
|
|||||||
In,
|
In,
|
||||||
IsNull,
|
IsNull,
|
||||||
IsNotNull,
|
IsNotNull,
|
||||||
|
JsonContains,
|
||||||
|
JsonExists,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Operator {
|
impl Operator {
|
||||||
@ -95,31 +323,23 @@ impl Operator {
|
|||||||
Operator::In => "IN",
|
Operator::In => "IN",
|
||||||
Operator::IsNull => "IS NULL",
|
Operator::IsNull => "IS NULL",
|
||||||
Operator::IsNotNull => "IS NOT NULL",
|
Operator::IsNotNull => "IS NOT NULL",
|
||||||
|
Operator::JsonContains => "@>",
|
||||||
|
Operator::JsonExists => "?",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct WhereCondition {
|
pub struct Condition {
|
||||||
field: ValidatedValue,
|
field: Identifier,
|
||||||
operator: Operator,
|
operator: Operator,
|
||||||
value: Option<ValidatedValue>,
|
value: Option<SafeValue>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WhereCondition {
|
impl Condition {
|
||||||
pub fn new(field: String, operator: Operator, value: Option<String>) -> CustomResult<Self> {
|
pub fn new(field: String, operator: Operator, value: Option<SafeValue>) -> CustomResult<Self> {
|
||||||
let field = ValidatedValue::new_identifier(field)?;
|
Ok(Condition {
|
||||||
|
field: Identifier::new(field)?,
|
||||||
let value = match value {
|
|
||||||
Some(v) => Some(match operator {
|
|
||||||
Operator::Like => ValidatedValue::new_plain_text(v)?,
|
|
||||||
_ => ValidatedValue::new_plain_text(v)?,
|
|
||||||
}),
|
|
||||||
None => None,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(WhereCondition {
|
|
||||||
field,
|
|
||||||
operator,
|
operator,
|
||||||
value,
|
value,
|
||||||
})
|
})
|
||||||
@ -130,170 +350,233 @@ impl WhereCondition {
|
|||||||
pub enum WhereClause {
|
pub enum WhereClause {
|
||||||
And(Vec<WhereClause>),
|
And(Vec<WhereClause>),
|
||||||
Or(Vec<WhereClause>),
|
Or(Vec<WhereClause>),
|
||||||
Condition(WhereCondition),
|
Condition(Condition),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct QueryBuilder {
|
pub struct QueryBuilder {
|
||||||
operation: SqlOperation,
|
operation: SqlOperation,
|
||||||
table: ValidatedValue,
|
table: Identifier,
|
||||||
fields: Vec<ValidatedValue>,
|
fields: Vec<Identifier>,
|
||||||
params: HashMap<ValidatedValue, ValidatedValue>,
|
values: HashMap<Identifier, SafeValue>,
|
||||||
where_clause: Option<WhereClause>,
|
where_clause: Option<WhereClause>,
|
||||||
order_by: Option<ValidatedValue>,
|
order_by: Option<Identifier>,
|
||||||
limit: Option<i32>,
|
limit: Option<i32>,
|
||||||
|
offset: Option<i32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl QueryBuilder {
|
impl QueryBuilder {
|
||||||
pub fn new(operation: SqlOperation, table: String) -> CustomResult<Self> {
|
pub fn new(operation: SqlOperation, table: String) -> CustomResult<Self> {
|
||||||
Ok(QueryBuilder {
|
Ok(QueryBuilder {
|
||||||
operation,
|
operation,
|
||||||
table: ValidatedValue::new_identifier(table)?,
|
table: Identifier::new(table)?,
|
||||||
fields: Vec::new(),
|
fields: Vec::new(),
|
||||||
params: HashMap::new(),
|
values: HashMap::new(),
|
||||||
where_clause: None,
|
where_clause: None,
|
||||||
order_by: None,
|
order_by: None,
|
||||||
limit: None,
|
limit: None,
|
||||||
|
offset: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build(&self) -> CustomResult<(String, Vec<String>)> {
|
pub fn add_field(&mut self, field: String) -> CustomResult<&mut Self> {
|
||||||
|
self.fields.push(Identifier::new(field)?);
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_value(&mut self, field: String, value: SafeValue) -> CustomResult<&mut Self> {
|
||||||
|
self.values.insert(Identifier::new(field)?, value);
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_condition(&mut self, condition: WhereClause) -> &mut Self {
|
||||||
|
self.where_clause = Some(condition);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(&self) -> CustomResult<(String, Vec<SafeValue>)> {
|
||||||
let mut query = String::new();
|
let mut query = String::new();
|
||||||
let mut values = Vec::new();
|
let mut params = Vec::new();
|
||||||
let mut param_counter = 1;
|
|
||||||
|
|
||||||
match self.operation {
|
match self.operation {
|
||||||
SqlOperation::Select => {
|
SqlOperation::Select => self.build_select(&mut query)?,
|
||||||
let fields = if self.fields.is_empty() {
|
SqlOperation::Insert => self.build_insert(&mut query, &mut params)?,
|
||||||
"*".to_string()
|
SqlOperation::Update => self.build_update(&mut query, &mut params)?,
|
||||||
} else {
|
SqlOperation::Delete => query.push_str(&format!("DELETE FROM {}", self.table.as_str())),
|
||||||
self.fields
|
|
||||||
.iter()
|
|
||||||
.map(|f| f.get().to_string())
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join(", ")
|
|
||||||
};
|
|
||||||
query.push_str(&format!("SELECT {} FROM {}", fields, self.table.get()));
|
|
||||||
}
|
|
||||||
SqlOperation::Insert => {
|
|
||||||
let fields: Vec<String> = self.params.keys().map(|k| k.get().to_string()).collect();
|
|
||||||
let placeholders: Vec<String> =
|
|
||||||
(1..=self.params.len()).map(|i| format!("${}", i)).collect();
|
|
||||||
|
|
||||||
query.push_str(&format!(
|
|
||||||
"INSERT INTO {} ({}) VALUES ({})",
|
|
||||||
self.table.get(),
|
|
||||||
fields.join(", "),
|
|
||||||
placeholders.join(", ")
|
|
||||||
));
|
|
||||||
|
|
||||||
values.extend(self.params.values().map(|v| v.get().to_string()));
|
|
||||||
return Ok((query, values));
|
|
||||||
}
|
|
||||||
SqlOperation::Update => {
|
|
||||||
query.push_str(&format!("UPDATE {} SET ", self.table.get()));
|
|
||||||
let set_clauses: Vec<String> = self
|
|
||||||
.params
|
|
||||||
.iter()
|
|
||||||
.map(|(key, _)| {
|
|
||||||
let placeholder = format!("${}", param_counter);
|
|
||||||
values.push(self.params[key].get().to_string());
|
|
||||||
param_counter += 1;
|
|
||||||
format!("{} = {}", key.get(), placeholder)
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
query.push_str(&set_clauses.join(", "));
|
|
||||||
}
|
|
||||||
SqlOperation::Delete => {
|
|
||||||
query.push_str(&format!("DELETE FROM {}", self.table.get()));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(where_clause) = &self.where_clause {
|
if let Some(where_clause) = &self.where_clause {
|
||||||
query.push_str(" WHERE ");
|
query.push_str(" WHERE ");
|
||||||
let (where_sql, where_values) = self.build_where_clause(where_clause, param_counter)?;
|
let (where_sql, where_params) = self.build_where_clause(where_clause)?;
|
||||||
query.push_str(&where_sql);
|
query.push_str(&where_sql);
|
||||||
values.extend(where_values);
|
params.extend(where_params);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(order) = &self.order_by {
|
self.build_pagination(&mut query)?;
|
||||||
query.push_str(&format!(" ORDER BY {}", order.get()));
|
Ok((query, params))
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(limit) = self.limit {
|
|
||||||
query.push_str(&format!(" LIMIT {}", limit));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok((query, values))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_where_clause(
|
fn build_select(&self, query: &mut String) -> CustomResult<()> {
|
||||||
&self,
|
let fields = if self.fields.is_empty() {
|
||||||
clause: &WhereClause,
|
"*".to_string()
|
||||||
mut param_counter: i32,
|
} else {
|
||||||
) -> CustomResult<(String, Vec<String>)> {
|
self.fields
|
||||||
let mut values = Vec::new();
|
.iter()
|
||||||
|
.map(|f| f.as_str())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(", ")
|
||||||
|
};
|
||||||
|
query.push_str(&format!("SELECT {} FROM {}", fields, self.table.as_str()));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_insert(&self, query: &mut String, params: &mut Vec<SafeValue>) -> CustomResult<()> {
|
||||||
|
let mut fields = Vec::new();
|
||||||
|
let mut placeholders = Vec::new();
|
||||||
|
|
||||||
|
for (field, value) in &self.values {
|
||||||
|
fields.push(field.as_str());
|
||||||
|
if matches!(value, SafeValue::Null) {
|
||||||
|
placeholders.push("NULL".to_string());
|
||||||
|
} else {
|
||||||
|
placeholders.push(format!("${}::{}", params.len() + 1, value.get_sql_type()?));
|
||||||
|
params.push(value.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
query.push_str(&format!(
|
||||||
|
"INSERT INTO {} ({}) VALUES ({})",
|
||||||
|
self.table.as_str(),
|
||||||
|
fields.join(", "),
|
||||||
|
placeholders.join(", ")
|
||||||
|
));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_update(&self, query: &mut String, params: &mut Vec<SafeValue>) -> CustomResult<()> {
|
||||||
|
query.push_str(&format!("UPDATE {} SET ", self.table.as_str()));
|
||||||
|
|
||||||
|
let mut updates = Vec::new();
|
||||||
|
for (field, value) in &self.values {
|
||||||
|
let set_sql = format!(
|
||||||
|
"{} = {}",
|
||||||
|
field.as_str(),
|
||||||
|
value.to_param_sql(params.len() + 1)?
|
||||||
|
);
|
||||||
|
if !matches!(value, SafeValue::Null) {
|
||||||
|
params.push(value.clone());
|
||||||
|
}
|
||||||
|
updates.push(set_sql);
|
||||||
|
}
|
||||||
|
|
||||||
|
query.push_str(&updates.join(", "));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_delete(&self, query: &mut String) -> CustomResult<()> {
|
||||||
|
query.push_str(&format!("DELETE FROM {}", self.table.as_str()));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_where_clause(&self, clause: &WhereClause) -> CustomResult<(String, Vec<SafeValue>)> {
|
||||||
|
let mut params = Vec::new();
|
||||||
|
let mut param_index = 1; // 添加参数索引计数器
|
||||||
|
|
||||||
let sql = match clause {
|
let sql = match clause {
|
||||||
WhereClause::And(conditions) => {
|
WhereClause::And(conditions) => {
|
||||||
let mut parts = Vec::new();
|
let mut parts = Vec::new();
|
||||||
for condition in conditions {
|
for condition in conditions {
|
||||||
let (sql, mut vals) = self.build_where_clause(condition, param_counter)?;
|
let (sql, mut condition_params) =
|
||||||
param_counter += vals.len() as i32;
|
self.build_where_clause_with_index(condition, param_index)?;
|
||||||
|
param_index += condition_params.len(); // 更新参数索引
|
||||||
parts.push(sql);
|
parts.push(sql);
|
||||||
values.append(&mut vals);
|
params.append(&mut condition_params);
|
||||||
}
|
}
|
||||||
format!("({})", parts.join(" AND "))
|
format!("({})", parts.join(" AND "))
|
||||||
}
|
}
|
||||||
WhereClause::Or(conditions) => {
|
WhereClause::Or(conditions) => {
|
||||||
let mut parts = Vec::new();
|
let mut parts = Vec::new();
|
||||||
for condition in conditions {
|
for condition in conditions {
|
||||||
let (sql, mut vals) = self.build_where_clause(condition, param_counter)?;
|
let (sql, mut condition_params) =
|
||||||
param_counter += vals.len() as i32;
|
self.build_where_clause_with_index(condition, param_index)?;
|
||||||
|
param_index += condition_params.len(); // 更新参数索引
|
||||||
parts.push(sql);
|
parts.push(sql);
|
||||||
values.append(&mut vals);
|
params.append(&mut condition_params);
|
||||||
}
|
}
|
||||||
format!("({})", parts.join(" OR "))
|
format!("({})", parts.join(" OR "))
|
||||||
}
|
}
|
||||||
WhereClause::Condition(cond) => {
|
WhereClause::Condition(condition) => {
|
||||||
if let Some(value) = &cond.value {
|
self.build_condition(condition, &mut params, param_index)?
|
||||||
let placeholder = format!("${}", param_counter);
|
|
||||||
values.push(value.get().to_string());
|
|
||||||
format!(
|
|
||||||
"{} {} {}",
|
|
||||||
cond.field.get(),
|
|
||||||
cond.operator.as_str(),
|
|
||||||
placeholder
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
format!("{} {}", cond.field.get(), cond.operator.as_str())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok((sql, values))
|
Ok((sql, params))
|
||||||
}
|
|
||||||
pub fn fields(mut self, fields: Vec<ValidatedValue>) -> Self {
|
|
||||||
self.fields = fields;
|
|
||||||
self
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn params(mut self, params: HashMap<ValidatedValue, ValidatedValue>) -> Self {
|
// 添加新的辅助方法
|
||||||
self.params = params;
|
fn build_where_clause_with_index(
|
||||||
self
|
&self,
|
||||||
|
clause: &WhereClause,
|
||||||
|
start_index: usize,
|
||||||
|
) -> CustomResult<(String, Vec<SafeValue>)> {
|
||||||
|
let mut params = Vec::new();
|
||||||
|
|
||||||
|
let sql = match clause {
|
||||||
|
WhereClause::Condition(condition) => {
|
||||||
|
self.build_condition(condition, &mut params, start_index)?
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
let (sql, params_inner) = self.build_where_clause(clause)?;
|
||||||
|
params = params_inner;
|
||||||
|
sql
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok((sql, params))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn where_clause(mut self, clause: WhereClause) -> Self {
|
fn build_condition(
|
||||||
self.where_clause = Some(clause);
|
&self,
|
||||||
self
|
condition: &Condition,
|
||||||
|
params: &mut Vec<SafeValue>,
|
||||||
|
param_index: usize,
|
||||||
|
) -> CustomResult<String> {
|
||||||
|
match &condition.value {
|
||||||
|
Some(value) => {
|
||||||
|
let sql = format!(
|
||||||
|
"{} {} {}",
|
||||||
|
condition.field.as_str(),
|
||||||
|
condition.operator.as_str(),
|
||||||
|
value.to_param_sql(param_index)?
|
||||||
|
);
|
||||||
|
if !matches!(value, SafeValue::Null) {
|
||||||
|
params.push(value.clone());
|
||||||
|
}
|
||||||
|
Ok(sql)
|
||||||
|
}
|
||||||
|
None => Ok(format!(
|
||||||
|
"{} {}",
|
||||||
|
condition.field.as_str(),
|
||||||
|
condition.operator.as_str()
|
||||||
|
)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn order_by(mut self, order: ValidatedValue) -> Self {
|
// 构建分页
|
||||||
self.order_by = Some(order);
|
fn build_pagination(&self, query: &mut String) -> CustomResult<()> {
|
||||||
self
|
if let Some(order) = &self.order_by {
|
||||||
}
|
query.push_str(&format!(" ORDER BY {}", order.as_str()));
|
||||||
|
}
|
||||||
|
|
||||||
pub fn limit(mut self, limit: i32) -> Self {
|
if let Some(limit) = self.limit {
|
||||||
self.limit = Some(limit);
|
query.push_str(&format!(" LIMIT {}", limit));
|
||||||
self
|
}
|
||||||
|
|
||||||
|
if let Some(offset) = self.offset {
|
||||||
|
query.push_str(&format!(" OFFSET {}", offset));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
mod postgresql;
|
mod postgresql;
|
||||||
use crate::config;
|
use crate::config;
|
||||||
use crate::utils::{CustomError, CustomResult};
|
use crate::error::{CustomErrorInto, CustomResult};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@ -33,7 +33,7 @@ impl Database {
|
|||||||
pub async fn link(database: &config::SqlConfig) -> CustomResult<Self> {
|
pub async fn link(database: &config::SqlConfig) -> CustomResult<Self> {
|
||||||
let db = match database.db_type.as_str() {
|
let db = match database.db_type.as_str() {
|
||||||
"postgresql" => postgresql::Postgresql::connect(database).await?,
|
"postgresql" => postgresql::Postgresql::connect(database).await?,
|
||||||
_ => return Err(CustomError::from_str("unknown database type")),
|
_ => return Err("unknown database type".into_custom_error()),
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@ -44,7 +44,7 @@ impl Database {
|
|||||||
pub async fn initial_setup(database: config::SqlConfig) -> CustomResult<()> {
|
pub async fn initial_setup(database: config::SqlConfig) -> CustomResult<()> {
|
||||||
match database.db_type.as_str() {
|
match database.db_type.as_str() {
|
||||||
"postgresql" => postgresql::Postgresql::initialization(database).await?,
|
"postgresql" => postgresql::Postgresql::initialization(database).await?,
|
||||||
_ => return Err(CustomError::from_str("unknown database type")),
|
_ => return Err("unknown database type".into_custom_error()),
|
||||||
};
|
};
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
use super::{builder, DatabaseTrait};
|
use super::{builder, DatabaseTrait};
|
||||||
use crate::config;
|
use crate::config;
|
||||||
use crate::utils::CustomResult;
|
use crate::error::CustomErrorInto;
|
||||||
|
use crate::error::CustomResult;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use sqlx::{Column, Executor, PgPool, Row};
|
use sqlx::{Column, Executor, PgPool, Row};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::{env, fs};
|
use std::{env, fs};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct Postgresql {
|
pub struct Postgresql {
|
||||||
pool: PgPool,
|
pool: PgPool,
|
||||||
@ -70,10 +70,14 @@ impl DatabaseTrait for Postgresql {
|
|||||||
let mut sqlx_query = sqlx::query(&query);
|
let mut sqlx_query = sqlx::query(&query);
|
||||||
|
|
||||||
for value in values {
|
for value in values {
|
||||||
sqlx_query = sqlx_query.bind(value);
|
sqlx_query = sqlx_query.bind(value.to_sql_string()?);
|
||||||
}
|
}
|
||||||
|
|
||||||
let rows = sqlx_query.fetch_all(&self.pool).await?;
|
let rows = sqlx_query.fetch_all(&self.pool).await.map_err(|e| {
|
||||||
|
let (sql, params) = builder.build().unwrap();
|
||||||
|
format!("Err:{}\n,SQL: {}\nParams: {:?}", e.to_string(), sql, params)
|
||||||
|
.into_custom_error()
|
||||||
|
})?;
|
||||||
|
|
||||||
let mut results = Vec::new();
|
let mut results = Vec::new();
|
||||||
for row in rows {
|
for row in rows {
|
||||||
|
41
backend/src/error.rs
Normal file
41
backend/src/error.rs
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
use rocket::http::Status;
|
||||||
|
use rocket::response::status;
|
||||||
|
|
||||||
|
pub type AppResult<T> = Result<T, status::Custom<String>>;
|
||||||
|
|
||||||
|
pub trait AppResultInto<T> {
|
||||||
|
fn into_app_result(self) -> AppResult<T>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct CustomError(String);
|
||||||
|
|
||||||
|
impl std::fmt::Display for CustomError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
write!(f, "{}", self.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait CustomErrorInto {
|
||||||
|
fn into_custom_error(self) -> CustomError;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CustomErrorInto for &str {
|
||||||
|
fn into_custom_error(self) -> CustomError {
|
||||||
|
CustomError(self.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<E: std::error::Error> From<E> for CustomError {
|
||||||
|
fn from(error: E) -> Self {
|
||||||
|
CustomError(error.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type CustomResult<T> = Result<T, CustomError>;
|
||||||
|
|
||||||
|
impl<T> AppResultInto<T> for CustomResult<T> {
|
||||||
|
fn into_app_result(self) -> AppResult<T> {
|
||||||
|
self.map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))
|
||||||
|
}
|
||||||
|
}
|
@ -1,59 +1,98 @@
|
|||||||
mod auth;
|
mod auth;
|
||||||
mod config;
|
mod config;
|
||||||
mod database;
|
mod database;
|
||||||
mod manage;
|
|
||||||
mod routes;
|
mod routes;
|
||||||
mod utils;
|
mod utils;
|
||||||
use database::relational;
|
use database::relational;
|
||||||
use rocket::launch;
|
use rocket::Shutdown;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use utils::{AppResult, CustomError, CustomResult};
|
mod error;
|
||||||
|
use error::{CustomErrorInto, CustomResult};
|
||||||
|
|
||||||
struct AppState {
|
pub struct AppState {
|
||||||
db: Arc<Mutex<Option<relational::Database>>>,
|
db: Arc<Mutex<Option<relational::Database>>>,
|
||||||
configure: Arc<Mutex<config::Config>>,
|
configure: Arc<Mutex<config::Config>>,
|
||||||
|
shutdown: Arc<Mutex<Option<Shutdown>>>,
|
||||||
|
restart_progress: Arc<Mutex<bool>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AppState {
|
impl AppState {
|
||||||
async fn get_sql(&self) -> CustomResult<relational::Database> {
|
pub fn new(config: config::Config) -> Self {
|
||||||
|
Self {
|
||||||
|
db: Arc::new(Mutex::new(None)),
|
||||||
|
configure: Arc::new(Mutex::new(config)),
|
||||||
|
shutdown: Arc::new(Mutex::new(None)),
|
||||||
|
restart_progress: Arc::new(Mutex::new(false)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn sql_get(&self) -> CustomResult<relational::Database> {
|
||||||
self.db
|
self.db
|
||||||
.lock()
|
.lock()
|
||||||
.await
|
.await
|
||||||
.clone()
|
.clone()
|
||||||
.ok_or_else(|| CustomError::from_str("Database not initialized"))
|
.ok_or("数据库未连接".into_custom_error())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn link_sql(&self, config: &config::SqlConfig) -> CustomResult<()> {
|
pub async fn sql_link(&self, config: &config::SqlConfig) -> CustomResult<()> {
|
||||||
let database = relational::Database::link(config).await?;
|
let database = relational::Database::link(config).await?;
|
||||||
*self.db.lock().await = Some(database);
|
*self.db.lock().await = Some(database);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
#[launch]
|
pub async fn set_shutdown(&self, shutdown: Shutdown) {
|
||||||
async fn rocket() -> _ {
|
*self.shutdown.lock().await = Some(shutdown);
|
||||||
let config = config::Config::read().expect("Failed to read config");
|
|
||||||
|
|
||||||
let state = AppState {
|
|
||||||
db: Arc::new(Mutex::new(None)),
|
|
||||||
configure: Arc::new(Mutex::new(config.clone())),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut rocket_builder = rocket::build().manage(state);
|
|
||||||
|
|
||||||
if config.info.install {
|
|
||||||
if let Some(state) = rocket_builder.state::<AppState>() {
|
|
||||||
state
|
|
||||||
.link_sql(&config.sql_config)
|
|
||||||
.await
|
|
||||||
.expect("Failed to connect to database");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
rocket_builder = rocket_builder.mount("/", rocket::routes![routes::intsall::install]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rocket_builder = rocket_builder.mount("/auth/token", routes::jwt_routes());
|
pub async fn trigger_restart(&self) -> CustomResult<()> {
|
||||||
|
*self.restart_progress.lock().await = true;
|
||||||
|
|
||||||
rocket_builder
|
self.shutdown
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.take()
|
||||||
|
.ok_or("未能获取rocket的shutdown".into_custom_error())?
|
||||||
|
.notify();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::main]
|
||||||
|
async fn main() -> CustomResult<()> {
|
||||||
|
let config = config::Config::read()?;
|
||||||
|
|
||||||
|
let state = AppState::new(config.clone());
|
||||||
|
|
||||||
|
if config.info.install {
|
||||||
|
state.sql_link(&config.sql_config).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let state = Arc::new(state);
|
||||||
|
|
||||||
|
let rocket_builder = rocket::build().manage(state.clone());
|
||||||
|
|
||||||
|
let rocket_builder = if !config.info.install {
|
||||||
|
rocket_builder.mount("/", rocket::routes![routes::install::install])
|
||||||
|
} else {
|
||||||
|
rocket_builder.mount("/auth/token", routes::jwt_routes())
|
||||||
|
};
|
||||||
|
|
||||||
|
let rocket = rocket_builder.ignite().await?;
|
||||||
|
|
||||||
|
rocket
|
||||||
|
.state::<Arc<AppState>>()
|
||||||
|
.ok_or("未能获取AppState".into_custom_error())?
|
||||||
|
.set_shutdown(rocket.shutdown())
|
||||||
|
.await;
|
||||||
|
|
||||||
|
rocket.launch().await?;
|
||||||
|
|
||||||
|
let restart_progress = *state.restart_progress.lock().await;
|
||||||
|
if restart_progress {
|
||||||
|
let current_exe = std::env::current_exe()?;
|
||||||
|
let _ = std::process::Command::new(current_exe).spawn();
|
||||||
|
}
|
||||||
|
std::process::exit(0);
|
||||||
}
|
}
|
||||||
|
@ -1,57 +0,0 @@
|
|||||||
use rocket::shutdown::Shutdown;
|
|
||||||
use std::env;
|
|
||||||
use std::path::Path;
|
|
||||||
use std::process::{exit, Command};
|
|
||||||
use tokio::signal;
|
|
||||||
|
|
||||||
// 应用管理器
|
|
||||||
pub struct AppManager {
|
|
||||||
shutdown: Shutdown,
|
|
||||||
executable_path: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AppManager {
|
|
||||||
pub fn new(shutdown: Shutdown) -> Self {
|
|
||||||
let executable_path = env::current_exe()
|
|
||||||
.expect("Failed to get executable path")
|
|
||||||
.to_string_lossy()
|
|
||||||
.into_owned();
|
|
||||||
|
|
||||||
Self {
|
|
||||||
shutdown,
|
|
||||||
executable_path,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 优雅关闭
|
|
||||||
pub async fn graceful_shutdown(&self) {
|
|
||||||
println!("Initiating graceful shutdown...");
|
|
||||||
|
|
||||||
// 触发 Rocket 的优雅关闭
|
|
||||||
self.shutdown.notify();
|
|
||||||
|
|
||||||
// 等待一段时间以确保连接正确关闭
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 重启应用
|
|
||||||
pub async fn restart(&self) -> Result<(), Box<dyn std::error::Error>> {
|
|
||||||
println!("Preparing to restart application...");
|
|
||||||
|
|
||||||
// 执行优雅关闭
|
|
||||||
self.graceful_shutdown().await;
|
|
||||||
|
|
||||||
// 在新进程中启动应用
|
|
||||||
if cfg!(target_os = "windows") {
|
|
||||||
Command::new("cmd")
|
|
||||||
.args(&["/C", &self.executable_path])
|
|
||||||
.spawn()?;
|
|
||||||
} else {
|
|
||||||
Command::new(&self.executable_path).spawn()?;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 退出当前进程
|
|
||||||
println!("Application restarting...");
|
|
||||||
exit(0);
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,15 +1,102 @@
|
|||||||
use crate::auth;
|
use crate::auth;
|
||||||
use crate::{AppResult, AppState};
|
use crate::database::relational::builder;
|
||||||
|
use crate::error::{AppResult, AppResultInto};
|
||||||
|
use crate::AppState;
|
||||||
use chrono::Duration;
|
use chrono::Duration;
|
||||||
use rocket::{get, http::Status, response::status, State};
|
use jwt_compact::Token;
|
||||||
|
use rocket::{
|
||||||
|
http::Status,
|
||||||
|
post,
|
||||||
|
response::status,
|
||||||
|
serde::json::{Json, Value},
|
||||||
|
State,
|
||||||
|
};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::json;
|
||||||
|
use std::sync::Arc;
|
||||||
|
#[derive(Deserialize, Serialize)]
|
||||||
|
pub struct TokenSystemData {
|
||||||
|
name: String,
|
||||||
|
password: String,
|
||||||
|
}
|
||||||
|
#[post("/system", format = "application/json", data = "<data>")]
|
||||||
|
pub async fn token_system(
|
||||||
|
state: &State<Arc<AppState>>,
|
||||||
|
data: Json<TokenSystemData>,
|
||||||
|
) -> AppResult<String> {
|
||||||
|
let name_condition = builder::Condition::new(
|
||||||
|
"person_name".to_string(),
|
||||||
|
builder::Operator::Eq,
|
||||||
|
Some(builder::SafeValue::Text(
|
||||||
|
data.name.to_string(),
|
||||||
|
builder::ValidationLevel::Relaxed,
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.into_app_result()?;
|
||||||
|
|
||||||
|
let email_condition = builder::Condition::new(
|
||||||
|
"person_email".to_string(),
|
||||||
|
builder::Operator::Eq,
|
||||||
|
Some(builder::SafeValue::Text(
|
||||||
|
"author@lsy22.com".to_string(),
|
||||||
|
builder::ValidationLevel::Relaxed,
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.into_app_result()?;
|
||||||
|
|
||||||
|
let level_condition = builder::Condition::new(
|
||||||
|
"person_level".to_string(),
|
||||||
|
builder::Operator::Eq,
|
||||||
|
Some(builder::SafeValue::Enum(
|
||||||
|
"administrators".to_string(),
|
||||||
|
"privilege_level".to_string(),
|
||||||
|
builder::ValidationLevel::Standard,
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.into_app_result()?;
|
||||||
|
|
||||||
|
let where_clause = builder::WhereClause::And(vec![
|
||||||
|
builder::WhereClause::Condition(name_condition),
|
||||||
|
builder::WhereClause::Condition(email_condition),
|
||||||
|
builder::WhereClause::Condition(level_condition),
|
||||||
|
]);
|
||||||
|
|
||||||
|
let mut builder =
|
||||||
|
builder::QueryBuilder::new(builder::SqlOperation::Select, String::from("persons"))
|
||||||
|
.into_app_result()?;
|
||||||
|
|
||||||
|
let builder = builder
|
||||||
|
.add_field("person_password".to_string())
|
||||||
|
.into_app_result()?;
|
||||||
|
|
||||||
|
let sql_builder = builder.add_condition(where_clause);
|
||||||
|
let values = state
|
||||||
|
.sql_get()
|
||||||
|
.await
|
||||||
|
.into_app_result()?
|
||||||
|
.get_db()
|
||||||
|
.execute_query(&sql_builder)
|
||||||
|
.await
|
||||||
|
.into_app_result()?;
|
||||||
|
|
||||||
|
let password = values
|
||||||
|
.first()
|
||||||
|
.ok_or(status::Custom(
|
||||||
|
Status::NotFound,
|
||||||
|
String::from("该用户并非系统用户"),
|
||||||
|
))?
|
||||||
|
.get("person_password")
|
||||||
|
.ok_or(status::Custom(
|
||||||
|
Status::NotFound,
|
||||||
|
String::from("该用户密码丢失"),
|
||||||
|
))?;
|
||||||
|
|
||||||
|
auth::bcrypt::verify_hash(&data.password, password).into_app_result()?;
|
||||||
|
|
||||||
#[get("/system")]
|
|
||||||
pub async fn token_system(_state: &State<AppState>) -> AppResult<status::Custom<String>> {
|
|
||||||
let claims = auth::jwt::CustomClaims {
|
let claims = auth::jwt::CustomClaims {
|
||||||
name: "system".into(),
|
name: "system".into(),
|
||||||
};
|
};
|
||||||
|
let token = auth::jwt::generate_jwt(claims, Duration::seconds(1)).into_app_result()?;
|
||||||
|
|
||||||
auth::jwt::generate_jwt(claims, Duration::seconds(1))
|
Ok(token)
|
||||||
.map(|token| status::Custom(Status::Ok, token))
|
|
||||||
.map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))
|
|
||||||
}
|
}
|
||||||
|
10
backend/src/routes/configure.rs
Normal file
10
backend/src/routes/configure.rs
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
use super::SystemToken;
|
||||||
|
use crate::error::AppResult;
|
||||||
|
use rocket::{
|
||||||
|
get,
|
||||||
|
http::Status,
|
||||||
|
post,
|
||||||
|
response::status,
|
||||||
|
serde::json::{Json, Value},
|
||||||
|
Request,
|
||||||
|
};
|
@ -1,12 +1,13 @@
|
|||||||
use crate::auth;
|
use crate::auth;
|
||||||
use crate::database::relational;
|
use crate::database::relational;
|
||||||
|
use crate::error::{AppResult, AppResultInto};
|
||||||
use crate::routes::person;
|
use crate::routes::person;
|
||||||
use crate::utils::AppResult;
|
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
use crate::{config, utils};
|
use crate::{config, utils};
|
||||||
use chrono::Duration;
|
use chrono::Duration;
|
||||||
use rocket::{http::Status, post, response::status, serde::json::Json, State};
|
use rocket::{http::Status, post, response::status, serde::json::Json, State};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize)]
|
#[derive(Deserialize, Serialize)]
|
||||||
pub struct InstallData {
|
pub struct InstallData {
|
||||||
@ -25,7 +26,7 @@ pub struct InstallReplyData {
|
|||||||
#[post("/install", format = "application/json", data = "<data>")]
|
#[post("/install", format = "application/json", data = "<data>")]
|
||||||
pub async fn install(
|
pub async fn install(
|
||||||
data: Json<InstallData>,
|
data: Json<InstallData>,
|
||||||
state: &State<AppState>,
|
state: &State<Arc<AppState>>,
|
||||||
) -> AppResult<status::Custom<Json<InstallReplyData>>> {
|
) -> AppResult<status::Custom<Json<InstallReplyData>>> {
|
||||||
let mut config = state.configure.lock().await;
|
let mut config = state.configure.lock().await;
|
||||||
if config.info.install {
|
if config.info.install {
|
||||||
@ -39,20 +40,14 @@ pub async fn install(
|
|||||||
|
|
||||||
relational::Database::initial_setup(data.sql_config.clone())
|
relational::Database::initial_setup(data.sql_config.clone())
|
||||||
.await
|
.await
|
||||||
.map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?;
|
.into_app_result()?;
|
||||||
|
|
||||||
let _ = auth::jwt::generate_key();
|
let _ = auth::jwt::generate_key();
|
||||||
|
|
||||||
config.info.install = true;
|
config.info.install = true;
|
||||||
|
|
||||||
state
|
state.sql_link(&data.sql_config).await.into_app_result()?;
|
||||||
.link_sql(data.sql_config.clone())
|
let sql = state.sql_get().await.into_app_result()?;
|
||||||
.await
|
|
||||||
.map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?;
|
|
||||||
let sql = state
|
|
||||||
.get_sql()
|
|
||||||
.await
|
|
||||||
.map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?;
|
|
||||||
|
|
||||||
let system_name = utils::generate_random_string(20);
|
let system_name = utils::generate_random_string(20);
|
||||||
let system_password = utils::generate_random_string(20);
|
let system_password = utils::generate_random_string(20);
|
||||||
@ -63,30 +58,35 @@ pub async fn install(
|
|||||||
name: data.name.clone(),
|
name: data.name.clone(),
|
||||||
email: data.email,
|
email: data.email,
|
||||||
password: data.password,
|
password: data.password,
|
||||||
|
level: "administrators".to_string(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| status::Custom(Status::InternalServerError, e.to_string()));
|
.into_app_result()?;
|
||||||
|
|
||||||
let _ = person::insert(
|
let _ = person::insert(
|
||||||
&sql,
|
&sql,
|
||||||
person::RegisterData {
|
person::RegisterData {
|
||||||
name: system_name.clone(),
|
name: system_name.clone(),
|
||||||
email: String::from("author@lsy22.com"),
|
email: String::from("author@lsy22.com"),
|
||||||
password: system_name.clone(),
|
password: system_password.clone(),
|
||||||
|
level: "administrators".to_string(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| status::Custom(Status::InternalServerError, e.to_string()));
|
.into_app_result()?;
|
||||||
|
|
||||||
let token = auth::jwt::generate_jwt(
|
let token = auth::jwt::generate_jwt(
|
||||||
auth::jwt::CustomClaims {
|
auth::jwt::CustomClaims {
|
||||||
name: data.name.clone(),
|
name: data.name.clone(),
|
||||||
},
|
},
|
||||||
Duration::days(7),
|
Duration::days(7),
|
||||||
)
|
)
|
||||||
.map_err(|e| status::Custom(Status::Unauthorized, e.to_string()))?;
|
.into_app_result()?;
|
||||||
|
|
||||||
config::Config::write(config.clone())
|
config::Config::write(config.clone()).into_app_result()?;
|
||||||
.map_err(|e| status::Custom(Status::InternalServerError, e.to_string()))?;
|
|
||||||
|
state.trigger_restart().await.into_app_result()?;
|
||||||
Ok(status::Custom(
|
Ok(status::Custom(
|
||||||
Status::Ok,
|
Status::Ok,
|
||||||
Json(InstallReplyData {
|
Json(InstallReplyData {
|
@ -1,9 +1,55 @@
|
|||||||
pub mod auth;
|
pub mod auth;
|
||||||
pub mod intsall;
|
pub mod configure;
|
||||||
|
pub mod install;
|
||||||
pub mod person;
|
pub mod person;
|
||||||
pub mod theme;
|
use rocket::http::Status;
|
||||||
|
use rocket::request::{FromRequest, Outcome, Request};
|
||||||
use rocket::routes;
|
use rocket::routes;
|
||||||
|
|
||||||
|
pub struct Token(String);
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for Token {
|
||||||
|
type Error = ();
|
||||||
|
|
||||||
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||||
|
let token = request
|
||||||
|
.headers()
|
||||||
|
.get_one("Authorization")
|
||||||
|
.map(|value| value.replace("Bearer ", ""));
|
||||||
|
|
||||||
|
match token {
|
||||||
|
Some(token) => Outcome::Success(Token(token)),
|
||||||
|
None => Outcome::Success(Token("".to_string())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct SystemToken(String);
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for SystemToken {
|
||||||
|
type Error = ();
|
||||||
|
|
||||||
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||||
|
let token = request
|
||||||
|
.headers()
|
||||||
|
.get_one("Authorization")
|
||||||
|
.map(|value| value.replace("Bearer ", ""));
|
||||||
|
|
||||||
|
match token {
|
||||||
|
Some(token) => {
|
||||||
|
if token == "system" {
|
||||||
|
Outcome::Success(SystemToken(token))
|
||||||
|
} else {
|
||||||
|
Outcome::Error((Status::Unauthorized, ()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => Outcome::Error((Status::Unauthorized, ())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn jwt_routes() -> Vec<rocket::Route> {
|
pub fn jwt_routes() -> Vec<rocket::Route> {
|
||||||
routes![auth::token::token_system]
|
routes![auth::token::token_system]
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
|
use crate::auth;
|
||||||
|
use crate::auth::bcrypt;
|
||||||
use crate::database::{relational, relational::builder};
|
use crate::database::{relational, relational::builder};
|
||||||
use crate::utils::CustomResult;
|
use crate::error::{CustomErrorInto, CustomResult};
|
||||||
use crate::{config, utils};
|
use crate::{config, utils};
|
||||||
use bcrypt::{hash, DEFAULT_COST};
|
|
||||||
use rocket::{get, http::Status, post, response::status, serde::json::Json, State};
|
use rocket::{get, http::Status, post, response::status, serde::json::Json, State};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@ -16,28 +17,36 @@ pub struct RegisterData {
|
|||||||
pub name: String,
|
pub name: String,
|
||||||
pub email: String,
|
pub email: String,
|
||||||
pub password: String,
|
pub password: String,
|
||||||
|
pub level: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn insert(sql: &relational::Database, data: RegisterData) -> CustomResult<()> {
|
pub async fn insert(sql: &relational::Database, data: RegisterData) -> CustomResult<()> {
|
||||||
let hashed_password = hash(data.password, DEFAULT_COST).expect("Failed to hash password");
|
let mut builder =
|
||||||
|
builder::QueryBuilder::new(builder::SqlOperation::Insert, "persons".to_string())?;
|
||||||
|
|
||||||
let mut user_params = HashMap::new();
|
let password_hash = auth::bcrypt::generate_hash(&data.password)?;
|
||||||
user_params.insert(
|
|
||||||
builder::ValidatedValue::Identifier(String::from("person_name")),
|
|
||||||
builder::ValidatedValue::PlainText(data.name),
|
|
||||||
);
|
|
||||||
user_params.insert(
|
|
||||||
builder::ValidatedValue::Identifier(String::from("person_email")),
|
|
||||||
builder::ValidatedValue::PlainText(data.email),
|
|
||||||
);
|
|
||||||
user_params.insert(
|
|
||||||
builder::ValidatedValue::Identifier(String::from("person_password")),
|
|
||||||
builder::ValidatedValue::PlainText(hashed_password),
|
|
||||||
);
|
|
||||||
|
|
||||||
let builder =
|
builder
|
||||||
builder::QueryBuilder::new(builder::SqlOperation::Insert, String::from("persons"))?
|
.set_value(
|
||||||
.params(user_params);
|
"person_name".to_string(),
|
||||||
|
builder::SafeValue::Text(data.name.to_string(), builder::ValidationLevel::Relaxed),
|
||||||
|
)?
|
||||||
|
.set_value(
|
||||||
|
"person_email".to_string(),
|
||||||
|
builder::SafeValue::Text(data.email.to_string(), builder::ValidationLevel::Relaxed),
|
||||||
|
)?
|
||||||
|
.set_value(
|
||||||
|
"person_password".to_string(),
|
||||||
|
builder::SafeValue::Text(password_hash, builder::ValidationLevel::Relaxed),
|
||||||
|
)?
|
||||||
|
.set_value(
|
||||||
|
"person_level".to_string(),
|
||||||
|
builder::SafeValue::Enum(
|
||||||
|
data.level.to_string(),
|
||||||
|
"privilege_level".to_string(),
|
||||||
|
builder::ValidationLevel::Standard,
|
||||||
|
),
|
||||||
|
)?;
|
||||||
|
|
||||||
sql.get_db().execute_query(&builder).await?;
|
sql.get_db().execute_query(&builder).await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -1,12 +0,0 @@
|
|||||||
use crate::utils::AppResult;
|
|
||||||
use rocket::{
|
|
||||||
http::Status,
|
|
||||||
post,
|
|
||||||
response::status,
|
|
||||||
serde::json::{Json, Value},
|
|
||||||
};
|
|
||||||
|
|
||||||
#[post("/current", format = "application/json", data = "<data>")]
|
|
||||||
pub fn theme_current(data: Json<String>) -> AppResult<status::Custom<Json<Value>>> {
|
|
||||||
Ok(status::Custom(Status::Ok, Json(Value::Object(()))))
|
|
||||||
}
|
|
@ -1,5 +1,4 @@
|
|||||||
use rand::seq::SliceRandom;
|
use rand::seq::SliceRandom;
|
||||||
use rocket::response::status;
|
|
||||||
|
|
||||||
pub fn generate_random_string(length: usize) -> String {
|
pub fn generate_random_string(length: usize) -> String {
|
||||||
let charset = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
|
let charset = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
|
||||||
@ -8,30 +7,3 @@ pub fn generate_random_string(length: usize) -> String {
|
|||||||
.map(|_| *charset.choose(&mut rng).unwrap() as char)
|
.map(|_| *charset.choose(&mut rng).unwrap() as char)
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct CustomError(String);
|
|
||||||
|
|
||||||
impl std::fmt::Display for CustomError {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
|
||||||
write!(f, "{}", self.0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T> From<T> for CustomError
|
|
||||||
where
|
|
||||||
T: std::error::Error + Send + 'static,
|
|
||||||
{
|
|
||||||
fn from(error: T) -> Self {
|
|
||||||
CustomError(error.to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CustomError {
|
|
||||||
pub fn from_str(error: &str) -> Self {
|
|
||||||
CustomError(error.to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub type CustomResult<T> = Result<T, CustomError>;
|
|
||||||
|
|
||||||
pub type AppResult<T> = Result<T, status::Custom<String>>;
|
|
||||||
|
@ -1 +0,0 @@
|
|||||||
VITE_API_BASE_URL = 1
|
|
Loading…
Reference in New Issue
Block a user