实现连接数据库接口

This commit is contained in:
lsy 2024-11-11 01:38:58 +08:00
parent 792346d43d
commit c2bb2d21d9
6 changed files with 145 additions and 146 deletions

View File

@ -9,6 +9,7 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
toml = "0.8.19" toml = "0.8.19"
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
tokio-postgres = "0.7.12" sqlx = { version = "0.8.2", features = ["runtime-tokio-native-tls", "postgres"] }
once_cell = "1.20.2" async-trait = "0.1.83"
async-trait = "0.1.83" anyhow = "1.0"
once_cell = "1.10.0"

View File

@ -13,10 +13,10 @@ pub struct Info {
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct Database { pub struct Database {
pub ilk : String, pub db_type : String,
pub address : String, pub address : String,
pub prot : u32, pub prot : u32,
pub user : String, pub user : String,
pub password : String, pub password : String,
pub dbname : String, pub db_name : String,
} }

View File

@ -2,9 +2,9 @@
install = false install = false
[database] [database]
ilk = "postgresql" db_type = "postgresql"
address = "localhost" address = "localhost"
prot = 5432 prot = 5432
user = "postgres" user = "postgres"
password = "postgres" password = "postgres"
dbname = "echoes" db_name = "echoes"

View File

@ -1,61 +1,68 @@
// main.rs
mod sql; mod sql;
mod config; mod config;
use rocket::{get, launch, routes, Route}; use rocket::{ get, launch, routes};
use rocket::http::{ContentType, Status}; use rocket::serde::json::Json; // Added import for Json
use rocket::serde::{ Serialize};
use rocket::response::status;
use std::sync::{Arc, Mutex};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use rocket::serde::json::Json; use rocket::http::Status;
use tokio::sync::Mutex as TokioMutex; use rocket::response::status;
use tokio_postgres::types::ToSql; use std::sync::Arc; // Added import for Arc and Mutex
use tokio::sync::Mutex;
use crate::sql::Database;
// 获取数据库连接 // 修改全局变量的类型定义
static GLOBAL_SQL: Lazy<Arc<TokioMutex<Option<Box<dyn sql::Database >>>>> static GLOBAL_SQL: Lazy<Arc<Mutex<Option<Database>>>> = Lazy::new(|| {
= Lazy::new(|| Arc::new(TokioMutex::new(None))); Arc::new(Mutex::new(None))
});
// 获取数据库连接 // 修改数据库连接函数
async fn initialize_sql() { async fn connect_database() -> Result<(), Box<dyn std::error::Error>> {
let sql_instance = sql::loading().await; let database = sql::Database::init().await?;
let mut lock = GLOBAL_SQL.lock().await; let mut lock = GLOBAL_SQL.lock().await;
*lock = sql_instance; *lock = Some(database);
Ok(())
} }
// 网站初始化 async fn get_db() -> Result<Database, Box<dyn std::error::Error>> {
#[get("/install")] let lock = GLOBAL_SQL.lock().await;
fn install() -> status::Custom<()> { match &*lock {
status::Custom(Status::Ok, ()) Some(db) => Ok(db.clone()),
} None => Err("Database not initialized".into())
// sql查询
#[derive(Serialize)]
struct SSql{
key:String,
value:String,
}
#[get("/sql")]
async fn ssql() -> status::Custom<Json<Vec<SSql>>> {
let sql_instance=GLOBAL_SQL.lock().await;
let sql =sql_instance.as_ref().unwrap();
let query = "SELECT * FROM info";
let params: &[&(dyn ToSql + Sync)] = &[];
let data = sql.query(query, params).await.expect("查询数据失败");
let mut vec = Vec::new();
for row in data {
let key=row.get(0);
let value=row.get(1);
vec.push(SSql{
key,
value,
});
} }
status::Custom(Status::Ok, Json(vec))
} }
#[get("/sql")]
async fn ssql() -> Result<Json<Vec<std::collections::HashMap<String, String>>>, status::Custom<String>> {
let db = get_db().await.map_err(|e| {
eprintln!("Database error: {}", e);
status::Custom(Status::InternalServerError, format!("Database error: {}", e))
})?;
let query_result = db.get_db()
.query("SELECT * FROM info".to_string()) // 确保这里是正确的表名
.await
.map_err(|e| {
eprintln!("Query error: {}", e);
status::Custom(Status::InternalServerError, format!("Query error: {}", e))
})?;
Ok(Json(query_result))
}
#[get("/install")]
async fn install() -> status::Custom<String> {
match connect_database().await {
Ok(_) => status::Custom(Status::Ok, "Database connected successfully".to_string()),
Err(e) => status::Custom(Status::InternalServerError, format!("Failed to connect: {}", e))
}
}
#[launch] #[launch]
async fn rocket() -> _ { async fn rocket() -> _ {
initialize_sql().await; connect_database().await.expect("Failed to connect to database");
rocket::build() rocket::build()
.mount("/api", routes![install,ssql]) .mount("/api", routes![install,ssql])
} }

View File

@ -1,53 +1,57 @@
// mod.rs
mod postgresql; mod postgresql;
use std::fs; use std::{collections::HashMap, fs};
use tokio_postgres::{Error, Row};
use toml; use toml;
use crate::config::Config; use crate::config::Config;
use async_trait::async_trait; use async_trait::async_trait;
use std::error::Error;
use std::sync::Arc;
// 所有数据库类型
#[async_trait] #[async_trait]
pub trait Database: Send + Sync { pub trait Databasetrait: Send + Sync {
async fn query(&self, async fn connect(
query: &str, address: String,
params: &[&(dyn tokio_postgres::types::ToSql + Sync)]) port: u32,
-> Result<Vec<Row>, Error>; user: String,
async fn execute( password: String,
&self, dbname: String,
data: &str, ) -> Result<Self, Box<dyn Error>> where Self: Sized;
params: &[&(dyn tokio_postgres::types::ToSql + Sync)], async fn query<'a>(&'a self, query: String) -> Result<Vec<HashMap<String, String>>, Box<dyn Error + 'a>>;
) }
-> Result<u64, Error>; #[derive(Clone)]
pub struct Database {
pub db: Arc<Box<dyn Databasetrait>>,
}
impl Database {
pub fn get_db(&self) -> &Box<dyn Databasetrait> {
&self.db
}
} }
// 加载对应数据库
pub async fn loading() -> Option<Box<dyn Database>> {
let config_string = fs::read_to_string("./src/config.toml")
.expect("Could not load config file");
let config: Config = toml::de::from_str(config_string.as_str()).expect("Could not parse config");
let address = config.database.address;
let port = config.database.prot;
let user = config.database.user;
let password = config.database.password;
let dbname = config.database.dbname;
let sql_instance: Box<dyn Database>;
match config.database.ilk.as_str() { impl Database {
"postgresql" => { pub async fn init() -> Result<Database, Box<dyn Error>> {
let sql = postgresql::connect(address, port, user, password, dbname).await; let config_string = fs::read_to_string("./src/config.toml")
match sql { .map_err(|e| Box::new(e) as Box<dyn Error>)?;
Ok(conn) => { let config: Config = toml::from_str(&config_string)
sql_instance = Box::new(conn); .map_err(|e| Box::new(e) as Box<dyn Error>)?;
}
Err(e) => { match config.database.db_type.as_str() {
println!("Database connection failed {}", e); "postgresql" => {
return None; let db = postgresql::Postgresql::connect(
} config.database.address,
config.database.prot,
config.database.user,
config.database.password,
config.database.db_name,
).await?;
Ok(Database {
db: Arc::new(Box::new(db))
})
} }
_ => Err(anyhow::anyhow!("unknown database type").into()),
} }
_ => { return None } }
};
Some(sql_instance)
} }

View File

@ -1,65 +1,52 @@
use tokio_postgres::{NoTls, Error, Row, Client, Connection, Socket};
use crate::sql;
use async_trait::async_trait; use async_trait::async_trait;
use tokio_postgres::tls::NoTlsStream; use sqlx::{PgPool, Row,Column};
use std::{collections::HashMap, error::Error};
use super::Databasetrait;
#[derive(Clone)]
pub struct Postgresql { pub struct Postgresql {
pub client: tokio_postgres::Client, pool: PgPool,
} }
pub async fn connect(
address: String,
port: u32,
user: String,
password: String,
dbname: String,
) -> Result<Postgresql, Error> {
let connection_str = format!(
"host={} port={} user={} password={} dbname={}",
address, port, user, password, dbname
);
let client:Client;
let connection:Connection<Socket, NoTlsStream>;
let link = tokio_postgres::connect(&connection_str, NoTls).await;
match link {
Ok((clie,conne)) => {
client = clie;
connection = conne;
}
Err(err) => {
println!("Failed to connect to postgresql: {}", err);
return Err(err);
}
}
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("postgresql connection error: {}", e);
}
});
Ok(Postgresql { client })
}
impl Postgresql {
}
#[async_trait] #[async_trait]
impl sql::Database for Postgresql { impl Databasetrait for Postgresql {
async fn query(&self, query: & str, async fn connect(
params: &[&(dyn tokio_postgres::types::ToSql + Sync)] address: String,
) -> Result<Vec<Row>, Error> { port: u32,
let rows = self.client.query(query, params).await?; user: String,
Ok(rows) password: String,
} dbname: String,
) -> Result<Self, Box<dyn Error>> {
let connection_str = format!(
"postgres://{}:{}@{}:{}/{}",
user, password, address, port, dbname
);
async fn execute( let pool = PgPool::connect(&connection_str)
&self, .await
data: &str, .map_err(|e| Box::new(e) as Box<dyn Error>)?;
params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
) -> Result<u64, Error> { Ok(Postgresql { pool })
let rows_affected = self.client.execute(data, params).await?;
Ok(rows_affected)
} }
} async fn query<'a>(&'a self, query: String) -> Result<Vec<HashMap<String, String>>, Box<dyn Error + 'a>> {
let rows = sqlx::query(&query)
.fetch_all(&self.pool)
.await
.map_err(|e| Box::new(e) as Box<dyn Error>)?;
let mut results = Vec::new();
for row in rows {
let mut map = HashMap::new();
for column in row.columns() {
let value: String = row.try_get(column.name()).unwrap_or_default();
map.insert(column.name().to_string(), value);
}
results.push(map);
}
Ok(results)
}
}