后端:实现了对执行sql的参数化保证了安全

This commit is contained in:
lsy 2024-11-18 13:40:47 +08:00
parent 5ca72e42cf
commit 9d88e9a159
4 changed files with 295 additions and 28 deletions

View File

@ -9,7 +9,8 @@ use std::{env, fs};
#[derive(Deserialize)]
pub struct Config {
pub info: Info, // 配置信息
pub db_config: DbConfig, // 数据库配置
pub sql_config: SqlConfig, // 关系型数据库配置
// pub no_sql_config:NoSqlConfig, 非关系型数据库配置
}
#[derive(Deserialize)]
@ -19,7 +20,7 @@ pub struct Info {
}
#[derive(Deserialize)]
pub struct DbConfig {
pub struct SqlConfig {
pub db_type: String, // 数据库类型
pub address: String, // 地址
pub prot: u32, // 端口
@ -28,6 +29,17 @@ pub struct DbConfig {
pub db_name: String, // 数据库名称
}
#[derive(Deserialize)]
pub struct NoSqlConfig {
pub db_type: String, // 数据库类型
pub address: String, // 地址
pub prot: u32, // 端口
pub user: String, // 用户名
pub password: String, // 密码
pub db_name: String, // 数据库名称
}
impl Config {
/// 读取配置文件
pub fn read() -> Result<Self, Box<dyn std::error::Error>> {

View File

@ -1,7 +1,9 @@
// sql/mod.rs
/*
// File path: src/database/relational/mod.rs
/**
*/
mod postgresql;
use std::collections::HashMap;
use crate::config;
@ -9,13 +11,43 @@ use async_trait::async_trait;
use std::error::Error;
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq)]
pub enum SqlOperation {
Select, // 查询操作
Insert, // 插入操作
Update, // 更新操作
Delete, // 删除操作
}
/// 查询构建器结构
pub struct QueryBuilder {
operation: SqlOperation, // SQL操作类型
table: String, // 表名
fields: Vec<String>, // 查询字段
params: HashMap<String, String>, // 插入或更新的参数
where_conditions: HashMap<String, String>, // WHERE条件
order_by: Option<String>, // 排序字段
limit: Option<i32>, // 限制返回的记录数
}
#[async_trait]
pub trait DatabaseTrait: Send + Sync {
// 连接数据库
async fn connect(database: config::DbConfig) -> Result<Self, Box<dyn Error>> where Self: Sized;
// 执行查询
async fn query<'a>(&'a self, query: String) -> Result<Vec<HashMap<String, String>>, Box<dyn Error + 'a>>;
/**
@param database
@return Result<Self, Box<dyn Error>>
*/
async fn connect(database: config::SqlConfig) -> Result<Self, Box<dyn Error>> where Self: Sized;
/**
@param query SQL查询语句
@return Result<Vec<HashMap<String, String>>, Box<dyn Error + 'a>>
*/
async fn execute_query<'a>(
&'a self,
builder: &QueryBuilder,
) -> Result<Vec<HashMap<String, String>>, Box<dyn Error + 'a>> ;
}
#[derive(Clone)]
@ -25,13 +57,20 @@ pub struct Database {
}
impl Database {
// 获取当前数据库实例
/**
@return &Box<dyn DatabaseTrait>
*/
pub fn get_db(&self) -> &Box<dyn DatabaseTrait> {
&self.db
}
// 初始化数据库
pub async fn init(database: config::DbConfig) -> Result<Self, Box<dyn Error>> {
/**
@param database
@return Result<Self, Box<dyn Error>>
*/
pub async fn init(database: config::SqlConfig) -> Result<Self, Box<dyn Error>> {
let db = match database.db_type.as_str() {
"postgresql" => postgresql::Postgresql::connect(database).await?,
_ => return Err("unknown database type".into()),
@ -40,3 +79,193 @@ impl Database {
Ok(Self { db: Arc::new(Box::new(db)) })
}
}
impl QueryBuilder {
/**
@param operation SQL操作类型
@param table
@return Self
*/
pub fn new(operation: SqlOperation, table: &str) -> Self {
QueryBuilder {
operation,
table: table.to_string(),
fields: Vec::new(),
params: HashMap::new(),
where_conditions: HashMap::new(),
order_by: None,
limit: None,
}
}
/**
SQL语句和参数
@return (String, Vec<String>) SQL语句和参数值
*/
pub fn build(&self) -> (String, Vec<String>) {
let mut query = String::new();
let mut values = Vec::new();
let mut param_counter = 1;
match self.operation {
SqlOperation::Select => {
// SELECT 操作
let fields = if self.fields.is_empty() {
"*".to_string()
} else {
self.fields.join(", ")
};
query.push_str(&format!("SELECT {} FROM {}", fields, self.table));
// 添加 WHERE 条件
if !self.where_conditions.is_empty() {
let conditions: Vec<String> = self.where_conditions
.iter()
.map(|(key, _)| {
let placeholder = format!("${}", param_counter);
values.push(self.where_conditions[key].clone());
param_counter += 1;
format!("{} = {}", key, placeholder)
})
.collect();
query.push_str(" WHERE ");
query.push_str(&conditions.join(" AND "));
}
},
SqlOperation::Insert => {
// INSERT 操作
let fields: Vec<String> = self.params.keys().cloned().collect();
let placeholders: Vec<String> = (1..=self.params.len())
.map(|i| format!("${}", i))
.collect();
query.push_str(&format!(
"INSERT INTO {} ({}) VALUES ({})",
self.table,
fields.join(", "),
placeholders.join(", ")
));
// 收集参数值
for field in fields {
values.push(self.params[&field].clone());
}
},
SqlOperation::Update => {
// UPDATE 操作
query.push_str(&format!("UPDATE {}", self.table));
let set_clauses: Vec<String> = self.params
.keys()
.map(|key| {
let placeholder = format!("${}", param_counter);
values.push(self.params[key].clone());
param_counter += 1;
format!("{} = {}", key, placeholder)
})
.collect();
query.push_str(" SET ");
query.push_str(&set_clauses.join(", "));
// 添加 WHERE 条件
if !self.where_conditions.is_empty() {
let conditions: Vec<String> = self.where_conditions
.iter()
.map(|(key, _)| {
let placeholder = format!("${}", param_counter);
values.push(self.where_conditions[key].clone());
param_counter += 1;
format!("{} = {}", key, placeholder)
})
.collect();
query.push_str(" WHERE ");
query.push_str(&conditions.join(" AND "));
}
},
SqlOperation::Delete => {
// DELETE 操作
query.push_str(&format!("DELETE FROM {}", self.table));
// 添加 WHERE 条件
if !self.where_conditions.is_empty() {
let conditions: Vec<String> = self.where_conditions
.iter()
.map(|(key, _)| {
let placeholder = format!("${}", param_counter);
values.push(self.where_conditions[key].clone());
param_counter += 1;
format!("{} = {}", key, placeholder)
})
.collect();
query.push_str(" WHERE ");
query.push_str(&conditions.join(" AND "));
}
}
}
// 添加 ORDER BY
if let Some(order) = &self.order_by {
query.push_str(&format!(" ORDER BY {}", order));
}
// 添加 LIMIT
if let Some(limit) = self.limit {
query.push_str(&format!(" LIMIT {}", limit));
}
(query, values)
}
/**
@param fields
@return &mut Self 便
*/
pub fn fields(&mut self, fields: Vec<String>) -> &mut Self {
self.fields = fields;
self
}
/**
@param params
@return &mut Self 便
*/
pub fn params(&mut self, params: HashMap<String, String>) -> &mut Self {
self.params = params;
self
}
/**
WHERE条件
@param conditions
@return &mut Self 便
*/
pub fn where_conditions(&mut self, conditions: HashMap<String, String>) -> &mut Self {
self.where_conditions = conditions;
self
}
/**
@param order
@return &mut Self 便
*/
pub fn order_by(&mut self, order: &str) -> &mut Self {
self.order_by = Some(order.to_string());
self
}
/**
@param limit
@return &mut Self 便
*/
pub fn limit(&mut self, limit: i32) -> &mut Self {
self.limit = Some(limit);
self
}
}

View File

@ -1,21 +1,34 @@
// sql/psotgresql.rs
// src/database/relational/postgresql/mod.rs
/*
postgresql数据库实现具体的方法
* PostgreSQL数据库的交互功能
*
*/
use super::DatabaseTrait;
use super::{DatabaseTrait, QueryBuilder};
use crate::config;
use async_trait::async_trait;
use sqlx::{Column, PgPool, Row};
use std::{collections::HashMap, error::Error};
#[derive(Clone)]
/// PostgreSQL数据库连接池结构体
pub struct Postgresql {
/// 数据库连接池
pool: PgPool,
}
#[async_trait]
impl DatabaseTrait for Postgresql {
async fn connect(db_config: config::DbConfig) -> Result<Self, Box<dyn Error>> {
/**
* PostgreSQL数据库并返回Postgresql实例
*
* #
* - `db_config`:
*
* #
* - `Result<Self, Box<dyn Error>>`:
*/
async fn connect(db_config: config::SqlConfig) -> Result<Self, Box<dyn Error>> {
let connection_str = format!(
"postgres://{}:{}@{}:{}/{}",
db_config.user, db_config.password, db_config.address, db_config.prot, db_config.db_name
@ -30,34 +43,47 @@ impl DatabaseTrait for Postgresql {
Ok(Postgresql { pool })
}
/**
/**
*
*
* #
* - `builder`:
*
* #
* - `Result<Vec<HashMap<String, String>>, Box<dyn Error + 'a>>`:
*/
async fn query<'a>(
async fn execute_query<'a>(
&'a self,
query: String,
builder: &QueryBuilder,
) -> Result<Vec<HashMap<String, String>>, Box<dyn Error + 'a>> {
// 执行查询并获取所有行
let rows = sqlx::query(&query)
let (query, values) = builder.build();
// 构建查询
let mut sqlx_query = sqlx::query(&query);
// 绑定参数
for value in values {
sqlx_query = sqlx_query.bind(value);
}
// 执行查询
let rows = sqlx_query
.fetch_all(&self.pool)
.await
.map_err(|e| Box::new(e) as Box<dyn Error>)?;
// 存储查询结果
// 处理结果
let mut results = Vec::new();
// 遍历每一行并构建结果映射
for row in rows {
let mut map = HashMap::new();
for column in row.columns() {
// 获取列的值,若失败则使用默认值
let value: String = row.try_get(column.name()).unwrap_or_default();
map.insert(column.name().to_string(), value);
}
results.push(map);
}
// 返回查询结果
Ok(results)
}
}

View File

@ -26,7 +26,7 @@ static DB: Lazy<Arc<Mutex<Option<relational::Database>>>> = Lazy::new(|| Arc::ne
/**
*
*/
async fn init_db(database: config::DbConfig) -> Result<(), Box<dyn std::error::Error>> {
async fn init_db(database: config::SqlConfig) -> Result<(), Box<dyn std::error::Error>> {
let database = relational::Database::init(database).await?; // 初始化数据库
*DB.lock().await = Some(database); // 保存数据库实例
Ok(())
@ -99,7 +99,7 @@ async fn token_system() -> Result<status::Custom<String>, status::Custom<String>
#[launch]
async fn rocket() -> _ {
let config = config::Config::read().expect("Failed to read config"); // 读取配置
init_db(config.db_config)
init_db(config.sql_config)
.await
.expect("Failed to connect to database"); // 初始化数据库连接
rocket::build()