From c2bb2d21d9f53a40b4af87808d7c96d9d5293d9a Mon Sep 17 00:00:00 2001 From: lsy Date: Mon, 11 Nov 2024 01:38:58 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E8=BF=9E=E6=8E=A5=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/Cargo.toml | 7 +-- backend/src/config.rs | 4 +- backend/src/config.toml | 4 +- backend/src/main.rs | 95 ++++++++++++++++++---------------- backend/src/sql/mod.rs | 84 +++++++++++++++--------------- backend/src/sql/postgresql.rs | 97 +++++++++++++++-------------------- 6 files changed, 145 insertions(+), 146 deletions(-) diff --git a/backend/Cargo.toml b/backend/Cargo.toml index 0002f48..fdde733 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -9,6 +9,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" toml = "0.8.19" tokio = { version = "1", features = ["full"] } -tokio-postgres = "0.7.12" -once_cell = "1.20.2" -async-trait = "0.1.83" \ No newline at end of file +sqlx = { version = "0.8.2", features = ["runtime-tokio-native-tls", "postgres"] } +async-trait = "0.1.83" +anyhow = "1.0" +once_cell = "1.10.0" \ No newline at end of file diff --git a/backend/src/config.rs b/backend/src/config.rs index 5b26a1a..23ffd0a 100644 --- a/backend/src/config.rs +++ b/backend/src/config.rs @@ -13,10 +13,10 @@ pub struct Info { #[derive(Deserialize)] pub struct Database { - pub ilk : String, + pub db_type : String, pub address : String, pub prot : u32, pub user : String, pub password : String, - pub dbname : String, + pub db_name : String, } \ No newline at end of file diff --git a/backend/src/config.toml b/backend/src/config.toml index 8d73a7a..fddd649 100644 --- a/backend/src/config.toml +++ b/backend/src/config.toml @@ -2,9 +2,9 @@ install = false [database] -ilk = "postgresql" +db_type = "postgresql" address = "localhost" prot = 5432 user = "postgres" password = "postgres" -dbname = "echoes" \ No newline at end of file +db_name = "echoes" \ No newline at end of file diff --git a/backend/src/main.rs b/backend/src/main.rs index fe27177..593520b 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -1,61 +1,68 @@ +// main.rs mod sql; mod config; -use rocket::{get, launch, routes, Route}; -use rocket::http::{ContentType, Status}; -use rocket::serde::{ Serialize}; -use rocket::response::status; -use std::sync::{Arc, Mutex}; +use rocket::{ get, launch, routes}; +use rocket::serde::json::Json; // Added import for Json use once_cell::sync::Lazy; -use rocket::serde::json::Json; -use tokio::sync::Mutex as TokioMutex; -use tokio_postgres::types::ToSql; +use rocket::http::Status; +use rocket::response::status; +use std::sync::Arc; // Added import for Arc and Mutex +use tokio::sync::Mutex; +use crate::sql::Database; -// 获取数据库连接 -static GLOBAL_SQL: Lazy>>>> -= Lazy::new(|| Arc::new(TokioMutex::new(None))); +// 修改全局变量的类型定义 +static GLOBAL_SQL: Lazy>>> = Lazy::new(|| { + Arc::new(Mutex::new(None)) +}); -// 获取数据库连接 -async fn initialize_sql() { - let sql_instance = sql::loading().await; +// 修改数据库连接函数 +async fn connect_database() -> Result<(), Box> { + let database = sql::Database::init().await?; let mut lock = GLOBAL_SQL.lock().await; - *lock = sql_instance; + *lock = Some(database); + Ok(()) } -// 网站初始化 -#[get("/install")] -fn install() -> status::Custom<()> { - status::Custom(Status::Ok, ()) -} -// sql查询 -#[derive(Serialize)] -struct SSql{ - key:String, - value:String, -} -#[get("/sql")] -async fn ssql() -> status::Custom>> { - let sql_instance=GLOBAL_SQL.lock().await; - let sql =sql_instance.as_ref().unwrap(); - let query = "SELECT * FROM info"; - let params: &[&(dyn ToSql + Sync)] = &[]; - let data = sql.query(query, params).await.expect("查询数据失败"); - let mut vec = Vec::new(); - for row in data { - let key=row.get(0); - let value=row.get(1); - vec.push(SSql{ - key, - value, - }); +async fn get_db() -> Result> { + let lock = GLOBAL_SQL.lock().await; + match &*lock { + Some(db) => Ok(db.clone()), + None => Err("Database not initialized".into()) } - status::Custom(Status::Ok, Json(vec)) } +#[get("/sql")] +async fn ssql() -> Result>>, status::Custom> { + let db = get_db().await.map_err(|e| { + eprintln!("Database error: {}", e); + status::Custom(Status::InternalServerError, format!("Database error: {}", e)) + })?; + + let query_result = db.get_db() + .query("SELECT * FROM info".to_string()) // 确保这里是正确的表名 + .await + .map_err(|e| { + eprintln!("Query error: {}", e); + status::Custom(Status::InternalServerError, format!("Query error: {}", e)) + })?; + + Ok(Json(query_result)) +} + + +#[get("/install")] +async fn install() -> status::Custom { + match connect_database().await { + Ok(_) => status::Custom(Status::Ok, "Database connected successfully".to_string()), + Err(e) => status::Custom(Status::InternalServerError, format!("Failed to connect: {}", e)) + } +} + #[launch] async fn rocket() -> _ { - initialize_sql().await; + connect_database().await.expect("Failed to connect to database"); rocket::build() .mount("/api", routes![install,ssql]) -} +} \ No newline at end of file diff --git a/backend/src/sql/mod.rs b/backend/src/sql/mod.rs index 8f3df0d..fb37715 100644 --- a/backend/src/sql/mod.rs +++ b/backend/src/sql/mod.rs @@ -1,53 +1,57 @@ +// mod.rs mod postgresql; -use std::fs; -use tokio_postgres::{Error, Row}; +use std::{collections::HashMap, fs}; use toml; use crate::config::Config; use async_trait::async_trait; +use std::error::Error; +use std::sync::Arc; -// 所有数据库类型 #[async_trait] -pub trait Database: Send + Sync { - async fn query(&self, - query: &str, - params: &[&(dyn tokio_postgres::types::ToSql + Sync)]) - -> Result, Error>; - async fn execute( - &self, - data: &str, - params: &[&(dyn tokio_postgres::types::ToSql + Sync)], - ) - -> Result; +pub trait Databasetrait: Send + Sync { + async fn connect( + address: String, + port: u32, + user: String, + password: String, + dbname: String, + ) -> Result> where Self: Sized; + async fn query<'a>(&'a self, query: String) -> Result>, Box>; +} +#[derive(Clone)] +pub struct Database { + pub db: Arc>, +} + +impl Database { + pub fn get_db(&self) -> &Box { + &self.db + } } -// 加载对应数据库 -pub async fn loading() -> Option> { - let config_string = fs::read_to_string("./src/config.toml") - .expect("Could not load config file"); - let config: Config = toml::de::from_str(config_string.as_str()).expect("Could not parse config"); - let address = config.database.address; - let port = config.database.prot; - let user = config.database.user; - let password = config.database.password; - let dbname = config.database.dbname; - let sql_instance: Box; - match config.database.ilk.as_str() { - "postgresql" => { - let sql = postgresql::connect(address, port, user, password, dbname).await; - match sql { - Ok(conn) => { - sql_instance = Box::new(conn); - } - Err(e) => { - println!("Database connection failed {}", e); - return None; - } +impl Database { + pub async fn init() -> Result> { + let config_string = fs::read_to_string("./src/config.toml") + .map_err(|e| Box::new(e) as Box)?; + let config: Config = toml::from_str(&config_string) + .map_err(|e| Box::new(e) as Box)?; + + match config.database.db_type.as_str() { + "postgresql" => { + let db = postgresql::Postgresql::connect( + config.database.address, + config.database.prot, + config.database.user, + config.database.password, + config.database.db_name, + ).await?; + Ok(Database { + db: Arc::new(Box::new(db)) + }) } - + _ => Err(anyhow::anyhow!("unknown database type").into()), } - _ => { return None } - }; - Some(sql_instance) + } } \ No newline at end of file diff --git a/backend/src/sql/postgresql.rs b/backend/src/sql/postgresql.rs index e0f67a9..2fb5f2d 100644 --- a/backend/src/sql/postgresql.rs +++ b/backend/src/sql/postgresql.rs @@ -1,65 +1,52 @@ -use tokio_postgres::{NoTls, Error, Row, Client, Connection, Socket}; -use crate::sql; use async_trait::async_trait; -use tokio_postgres::tls::NoTlsStream; +use sqlx::{PgPool, Row,Column}; +use std::{collections::HashMap, error::Error}; +use super::Databasetrait; +#[derive(Clone)] pub struct Postgresql { - pub client: tokio_postgres::Client, + pool: PgPool, } -pub async fn connect( - address: String, - port: u32, - user: String, - password: String, - dbname: String, -) -> Result { - let connection_str = format!( - "host={} port={} user={} password={} dbname={}", - address, port, user, password, dbname - ); - let client:Client; - let connection:Connection; - let link = tokio_postgres::connect(&connection_str, NoTls).await; - match link { - Ok((clie,conne)) => { - client = clie; - connection = conne; - } - Err(err) => { - println!("Failed to connect to postgresql: {}", err); - return Err(err); - } - } - tokio::spawn(async move { - if let Err(e) = connection.await { - eprintln!("postgresql connection error: {}", e); - } - }); - - Ok(Postgresql { client }) -} - -impl Postgresql { - -} #[async_trait] -impl sql::Database for Postgresql { - async fn query(&self, query: & str, - params: &[&(dyn tokio_postgres::types::ToSql + Sync)] - ) -> Result, Error> { - let rows = self.client.query(query, params).await?; - Ok(rows) - } +impl Databasetrait for Postgresql { + async fn connect( + address: String, + port: u32, + user: String, + password: String, + dbname: String, + ) -> Result> { + let connection_str = format!( + "postgres://{}:{}@{}:{}/{}", + user, password, address, port, dbname + ); - async fn execute( - &self, - data: &str, - params: &[&(dyn tokio_postgres::types::ToSql + Sync)], - ) -> Result { - let rows_affected = self.client.execute(data, params).await?; - Ok(rows_affected) + let pool = PgPool::connect(&connection_str) + .await + .map_err(|e| Box::new(e) as Box)?; + + Ok(Postgresql { pool }) } -} + async fn query<'a>(&'a self, query: String) -> Result>, Box> { + let rows = sqlx::query(&query) + .fetch_all(&self.pool) + .await + .map_err(|e| Box::new(e) as Box)?; + + let mut results = Vec::new(); + + for row in rows { + let mut map = HashMap::new(); + for column in row.columns() { + let value: String = row.try_get(column.name()).unwrap_or_default(); + map.insert(column.name().to_string(), value); + } + results.push(map); + } + + Ok(results) + } +} \ No newline at end of file