后端:实现了对执行sql的参数化保证了安全
This commit is contained in:
parent
5ca72e42cf
commit
9d88e9a159
@ -9,7 +9,8 @@ use std::{env, fs};
|
||||
#[derive(Deserialize)]
|
||||
pub struct Config {
|
||||
pub info: Info, // 配置信息
|
||||
pub db_config: DbConfig, // 数据库配置
|
||||
pub sql_config: SqlConfig, // 关系型数据库配置
|
||||
// pub no_sql_config:NoSqlConfig, 非关系型数据库配置
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@ -19,7 +20,7 @@ pub struct Info {
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct DbConfig {
|
||||
pub struct SqlConfig {
|
||||
pub db_type: String, // 数据库类型
|
||||
pub address: String, // 地址
|
||||
pub prot: u32, // 端口
|
||||
@ -28,6 +29,17 @@ pub struct DbConfig {
|
||||
pub db_name: String, // 数据库名称
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct NoSqlConfig {
|
||||
pub db_type: String, // 数据库类型
|
||||
pub address: String, // 地址
|
||||
pub prot: u32, // 端口
|
||||
pub user: String, // 用户名
|
||||
pub password: String, // 密码
|
||||
pub db_name: String, // 数据库名称
|
||||
}
|
||||
|
||||
|
||||
impl Config {
|
||||
/// 读取配置文件
|
||||
pub fn read() -> Result<Self, Box<dyn std::error::Error>> {
|
||||
|
@ -1,7 +1,9 @@
|
||||
// sql/mod.rs
|
||||
/*
|
||||
定义了数据库具有的特征和方法
|
||||
// File path: src/database/relational/mod.rs
|
||||
|
||||
/**
|
||||
本模块定义了数据库的特征和方法,包括查询构建器和数据库连接。
|
||||
*/
|
||||
|
||||
mod postgresql;
|
||||
use std::collections::HashMap;
|
||||
use crate::config;
|
||||
@ -9,13 +11,43 @@ use async_trait::async_trait;
|
||||
use std::error::Error;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum SqlOperation {
|
||||
Select, // 查询操作
|
||||
Insert, // 插入操作
|
||||
Update, // 更新操作
|
||||
Delete, // 删除操作
|
||||
}
|
||||
|
||||
/// 查询构建器结构
|
||||
pub struct QueryBuilder {
|
||||
operation: SqlOperation, // SQL操作类型
|
||||
table: String, // 表名
|
||||
fields: Vec<String>, // 查询字段
|
||||
params: HashMap<String, String>, // 插入或更新的参数
|
||||
where_conditions: HashMap<String, String>, // WHERE条件
|
||||
order_by: Option<String>, // 排序字段
|
||||
limit: Option<i32>, // 限制返回的记录数
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait DatabaseTrait: Send + Sync {
|
||||
// 连接数据库
|
||||
async fn connect(database: config::DbConfig) -> Result<Self, Box<dyn Error>> where Self: Sized;
|
||||
// 执行查询
|
||||
async fn query<'a>(&'a self, query: String) -> Result<Vec<HashMap<String, String>>, Box<dyn Error + 'a>>;
|
||||
/**
|
||||
连接数据库
|
||||
@param database 数据库配置
|
||||
@return Result<Self, Box<dyn Error>> 返回数据库实例或错误
|
||||
*/
|
||||
async fn connect(database: config::SqlConfig) -> Result<Self, Box<dyn Error>> where Self: Sized;
|
||||
|
||||
/**
|
||||
执行查询
|
||||
@param query SQL查询语句
|
||||
@return Result<Vec<HashMap<String, String>>, Box<dyn Error + 'a>> 返回查询结果或错误
|
||||
*/
|
||||
async fn execute_query<'a>(
|
||||
&'a self,
|
||||
builder: &QueryBuilder,
|
||||
) -> Result<Vec<HashMap<String, String>>, Box<dyn Error + 'a>> ;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -25,13 +57,20 @@ pub struct Database {
|
||||
}
|
||||
|
||||
impl Database {
|
||||
// 获取当前数据库实例
|
||||
/**
|
||||
获取当前数据库实例
|
||||
@return &Box<dyn DatabaseTrait> 返回数据库实例的引用
|
||||
*/
|
||||
pub fn get_db(&self) -> &Box<dyn DatabaseTrait> {
|
||||
&self.db
|
||||
}
|
||||
|
||||
// 初始化数据库
|
||||
pub async fn init(database: config::DbConfig) -> Result<Self, Box<dyn Error>> {
|
||||
/**
|
||||
初始化数据库
|
||||
@param database 数据库配置
|
||||
@return Result<Self, Box<dyn Error>> 返回数据库实例或错误
|
||||
*/
|
||||
pub async fn init(database: config::SqlConfig) -> Result<Self, Box<dyn Error>> {
|
||||
let db = match database.db_type.as_str() {
|
||||
"postgresql" => postgresql::Postgresql::connect(database).await?,
|
||||
_ => return Err("unknown database type".into()),
|
||||
@ -40,3 +79,193 @@ impl Database {
|
||||
Ok(Self { db: Arc::new(Box::new(db)) })
|
||||
}
|
||||
}
|
||||
|
||||
impl QueryBuilder {
|
||||
/**
|
||||
创建新的查询构建器
|
||||
@param operation SQL操作类型
|
||||
@param table 表名
|
||||
@return Self 返回新的查询构建器实例
|
||||
*/
|
||||
pub fn new(operation: SqlOperation, table: &str) -> Self {
|
||||
QueryBuilder {
|
||||
operation,
|
||||
table: table.to_string(),
|
||||
fields: Vec::new(),
|
||||
params: HashMap::new(),
|
||||
where_conditions: HashMap::new(),
|
||||
order_by: None,
|
||||
limit: None,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
构建SQL语句和参数
|
||||
@return (String, Vec<String>) 返回构建的SQL语句和参数值
|
||||
*/
|
||||
pub fn build(&self) -> (String, Vec<String>) {
|
||||
let mut query = String::new();
|
||||
let mut values = Vec::new();
|
||||
let mut param_counter = 1;
|
||||
|
||||
match self.operation {
|
||||
SqlOperation::Select => {
|
||||
// SELECT 操作
|
||||
let fields = if self.fields.is_empty() {
|
||||
"*".to_string()
|
||||
} else {
|
||||
self.fields.join(", ")
|
||||
};
|
||||
|
||||
query.push_str(&format!("SELECT {} FROM {}", fields, self.table));
|
||||
|
||||
// 添加 WHERE 条件
|
||||
if !self.where_conditions.is_empty() {
|
||||
let conditions: Vec<String> = self.where_conditions
|
||||
.iter()
|
||||
.map(|(key, _)| {
|
||||
let placeholder = format!("${}", param_counter);
|
||||
values.push(self.where_conditions[key].clone());
|
||||
param_counter += 1;
|
||||
format!("{} = {}", key, placeholder)
|
||||
})
|
||||
.collect();
|
||||
query.push_str(" WHERE ");
|
||||
query.push_str(&conditions.join(" AND "));
|
||||
}
|
||||
},
|
||||
SqlOperation::Insert => {
|
||||
// INSERT 操作
|
||||
let fields: Vec<String> = self.params.keys().cloned().collect();
|
||||
let placeholders: Vec<String> = (1..=self.params.len())
|
||||
.map(|i| format!("${}", i))
|
||||
.collect();
|
||||
|
||||
query.push_str(&format!(
|
||||
"INSERT INTO {} ({}) VALUES ({})",
|
||||
self.table,
|
||||
fields.join(", "),
|
||||
placeholders.join(", ")
|
||||
));
|
||||
|
||||
// 收集参数值
|
||||
for field in fields {
|
||||
values.push(self.params[&field].clone());
|
||||
}
|
||||
},
|
||||
SqlOperation::Update => {
|
||||
// UPDATE 操作
|
||||
query.push_str(&format!("UPDATE {}", self.table));
|
||||
|
||||
let set_clauses: Vec<String> = self.params
|
||||
.keys()
|
||||
.map(|key| {
|
||||
let placeholder = format!("${}", param_counter);
|
||||
values.push(self.params[key].clone());
|
||||
param_counter += 1;
|
||||
format!("{} = {}", key, placeholder)
|
||||
})
|
||||
.collect();
|
||||
|
||||
query.push_str(" SET ");
|
||||
query.push_str(&set_clauses.join(", "));
|
||||
|
||||
// 添加 WHERE 条件
|
||||
if !self.where_conditions.is_empty() {
|
||||
let conditions: Vec<String> = self.where_conditions
|
||||
.iter()
|
||||
.map(|(key, _)| {
|
||||
let placeholder = format!("${}", param_counter);
|
||||
values.push(self.where_conditions[key].clone());
|
||||
param_counter += 1;
|
||||
format!("{} = {}", key, placeholder)
|
||||
})
|
||||
.collect();
|
||||
query.push_str(" WHERE ");
|
||||
query.push_str(&conditions.join(" AND "));
|
||||
}
|
||||
},
|
||||
SqlOperation::Delete => {
|
||||
// DELETE 操作
|
||||
query.push_str(&format!("DELETE FROM {}", self.table));
|
||||
|
||||
// 添加 WHERE 条件
|
||||
if !self.where_conditions.is_empty() {
|
||||
let conditions: Vec<String> = self.where_conditions
|
||||
.iter()
|
||||
.map(|(key, _)| {
|
||||
let placeholder = format!("${}", param_counter);
|
||||
values.push(self.where_conditions[key].clone());
|
||||
param_counter += 1;
|
||||
format!("{} = {}", key, placeholder)
|
||||
})
|
||||
.collect();
|
||||
query.push_str(" WHERE ");
|
||||
query.push_str(&conditions.join(" AND "));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 添加 ORDER BY
|
||||
if let Some(order) = &self.order_by {
|
||||
query.push_str(&format!(" ORDER BY {}", order));
|
||||
}
|
||||
|
||||
// 添加 LIMIT
|
||||
if let Some(limit) = self.limit {
|
||||
query.push_str(&format!(" LIMIT {}", limit));
|
||||
}
|
||||
|
||||
(query, values)
|
||||
}
|
||||
|
||||
/**
|
||||
设置查询字段
|
||||
@param fields 字段列表
|
||||
@return &mut Self 返回可变引用以便链式调用
|
||||
*/
|
||||
pub fn fields(&mut self, fields: Vec<String>) -> &mut Self {
|
||||
self.fields = fields;
|
||||
self
|
||||
}
|
||||
|
||||
/**
|
||||
设置参数
|
||||
@param params 参数键值对
|
||||
@return &mut Self 返回可变引用以便链式调用
|
||||
*/
|
||||
pub fn params(&mut self, params: HashMap<String, String>) -> &mut Self {
|
||||
self.params = params;
|
||||
self
|
||||
}
|
||||
|
||||
/**
|
||||
设置WHERE条件
|
||||
@param conditions 条件键值对
|
||||
@return &mut Self 返回可变引用以便链式调用
|
||||
*/
|
||||
pub fn where_conditions(&mut self, conditions: HashMap<String, String>) -> &mut Self {
|
||||
self.where_conditions = conditions;
|
||||
self
|
||||
}
|
||||
|
||||
/**
|
||||
设置排序
|
||||
@param order 排序字段
|
||||
@return &mut Self 返回可变引用以便链式调用
|
||||
*/
|
||||
pub fn order_by(&mut self, order: &str) -> &mut Self {
|
||||
self.order_by = Some(order.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/**
|
||||
设置限制
|
||||
@param limit 限制记录数
|
||||
@return &mut Self 返回可变引用以便链式调用
|
||||
*/
|
||||
pub fn limit(&mut self, limit: i32) -> &mut Self {
|
||||
self.limit = Some(limit);
|
||||
self
|
||||
}
|
||||
}
|
@ -1,21 +1,34 @@
|
||||
// sql/psotgresql.rs
|
||||
// src/database/relational/postgresql/mod.rs
|
||||
/*
|
||||
为postgresql数据库实现具体的方法
|
||||
* 该模块实现了与PostgreSQL数据库的交互功能。
|
||||
* 包含连接池的结构体定义,提供数据库操作的基础。
|
||||
*/
|
||||
use super::DatabaseTrait;
|
||||
use super::{DatabaseTrait, QueryBuilder};
|
||||
use crate::config;
|
||||
use async_trait::async_trait;
|
||||
use sqlx::{Column, PgPool, Row};
|
||||
use std::{collections::HashMap, error::Error};
|
||||
|
||||
#[derive(Clone)]
|
||||
/// PostgreSQL数据库连接池结构体
|
||||
pub struct Postgresql {
|
||||
/// 数据库连接池
|
||||
pool: PgPool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl DatabaseTrait for Postgresql {
|
||||
async fn connect(db_config: config::DbConfig) -> Result<Self, Box<dyn Error>> {
|
||||
/**
|
||||
* 连接到PostgreSQL数据库并返回Postgresql实例。
|
||||
*
|
||||
* # 参数
|
||||
* - `db_config`: 数据库配置
|
||||
*
|
||||
* # 返回
|
||||
* - `Result<Self, Box<dyn Error>>`: 连接结果
|
||||
*/
|
||||
|
||||
async fn connect(db_config: config::SqlConfig) -> Result<Self, Box<dyn Error>> {
|
||||
let connection_str = format!(
|
||||
"postgres://{}:{}@{}:{}/{}",
|
||||
db_config.user, db_config.password, db_config.address, db_config.prot, db_config.db_name
|
||||
@ -30,34 +43,47 @@ impl DatabaseTrait for Postgresql {
|
||||
Ok(Postgresql { pool })
|
||||
}
|
||||
|
||||
/**
|
||||
/**
|
||||
* 异步执行查询并返回结果。
|
||||
*
|
||||
* # 参数
|
||||
* - `builder`: 查询构建器
|
||||
*
|
||||
* # 返回
|
||||
* - `Result<Vec<HashMap<String, String>>, Box<dyn Error + 'a>>`: 查询结果
|
||||
*/
|
||||
async fn query<'a>(
|
||||
async fn execute_query<'a>(
|
||||
&'a self,
|
||||
query: String,
|
||||
builder: &QueryBuilder,
|
||||
) -> Result<Vec<HashMap<String, String>>, Box<dyn Error + 'a>> {
|
||||
// 执行查询并获取所有行
|
||||
let rows = sqlx::query(&query)
|
||||
let (query, values) = builder.build();
|
||||
|
||||
// 构建查询
|
||||
let mut sqlx_query = sqlx::query(&query);
|
||||
|
||||
// 绑定参数
|
||||
for value in values {
|
||||
sqlx_query = sqlx_query.bind(value);
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
let rows = sqlx_query
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(|e| Box::new(e) as Box<dyn Error>)?;
|
||||
|
||||
// 存储查询结果
|
||||
// 处理结果
|
||||
let mut results = Vec::new();
|
||||
|
||||
// 遍历每一行并构建结果映射
|
||||
for row in rows {
|
||||
let mut map = HashMap::new();
|
||||
for column in row.columns() {
|
||||
// 获取列的值,若失败则使用默认值
|
||||
let value: String = row.try_get(column.name()).unwrap_or_default();
|
||||
map.insert(column.name().to_string(), value);
|
||||
}
|
||||
results.push(map);
|
||||
}
|
||||
|
||||
// 返回查询结果
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -26,7 +26,7 @@ static DB: Lazy<Arc<Mutex<Option<relational::Database>>>> = Lazy::new(|| Arc::ne
|
||||
/**
|
||||
* 初始化数据库连接
|
||||
*/
|
||||
async fn init_db(database: config::DbConfig) -> Result<(), Box<dyn std::error::Error>> {
|
||||
async fn init_db(database: config::SqlConfig) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let database = relational::Database::init(database).await?; // 初始化数据库
|
||||
*DB.lock().await = Some(database); // 保存数据库实例
|
||||
Ok(())
|
||||
@ -99,7 +99,7 @@ async fn token_system() -> Result<status::Custom<String>, status::Custom<String>
|
||||
#[launch]
|
||||
async fn rocket() -> _ {
|
||||
let config = config::Config::read().expect("Failed to read config"); // 读取配置
|
||||
init_db(config.db_config)
|
||||
init_db(config.sql_config)
|
||||
.await
|
||||
.expect("Failed to connect to database"); // 初始化数据库连接
|
||||
rocket::build()
|
||||
|
Loading…
Reference in New Issue
Block a user