实现连接数据库接口

This commit is contained in:
lsy 2024-11-11 01:38:58 +08:00
parent 792346d43d
commit c2bb2d21d9
6 changed files with 145 additions and 146 deletions

View File

@ -9,6 +9,7 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
toml = "0.8.19" toml = "0.8.19"
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
tokio-postgres = "0.7.12" sqlx = { version = "0.8.2", features = ["runtime-tokio-native-tls", "postgres"] }
once_cell = "1.20.2"
async-trait = "0.1.83" async-trait = "0.1.83"
anyhow = "1.0"
once_cell = "1.10.0"

View File

@ -13,10 +13,10 @@ pub struct Info {
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct Database { pub struct Database {
pub ilk : String, pub db_type : String,
pub address : String, pub address : String,
pub prot : u32, pub prot : u32,
pub user : String, pub user : String,
pub password : String, pub password : String,
pub dbname : String, pub db_name : String,
} }

View File

@ -2,9 +2,9 @@
install = false install = false
[database] [database]
ilk = "postgresql" db_type = "postgresql"
address = "localhost" address = "localhost"
prot = 5432 prot = 5432
user = "postgres" user = "postgres"
password = "postgres" password = "postgres"
dbname = "echoes" db_name = "echoes"

View File

@ -1,61 +1,68 @@
// main.rs
mod sql; mod sql;
mod config; mod config;
use rocket::{get, launch, routes, Route}; use rocket::{ get, launch, routes};
use rocket::http::{ContentType, Status}; use rocket::serde::json::Json; // Added import for Json
use rocket::serde::{ Serialize};
use rocket::response::status;
use std::sync::{Arc, Mutex};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use rocket::serde::json::Json; use rocket::http::Status;
use tokio::sync::Mutex as TokioMutex; use rocket::response::status;
use tokio_postgres::types::ToSql; use std::sync::Arc; // Added import for Arc and Mutex
use tokio::sync::Mutex;
use crate::sql::Database;
// 获取数据库连接 // 修改全局变量的类型定义
static GLOBAL_SQL: Lazy<Arc<TokioMutex<Option<Box<dyn sql::Database >>>>> static GLOBAL_SQL: Lazy<Arc<Mutex<Option<Database>>>> = Lazy::new(|| {
= Lazy::new(|| Arc::new(TokioMutex::new(None))); Arc::new(Mutex::new(None))
// 获取数据库连接
async fn initialize_sql() {
let sql_instance = sql::loading().await;
let mut lock = GLOBAL_SQL.lock().await;
*lock = sql_instance;
}
// 网站初始化
#[get("/install")]
fn install() -> status::Custom<()> {
status::Custom(Status::Ok, ())
}
// sql查询
#[derive(Serialize)]
struct SSql{
key:String,
value:String,
}
#[get("/sql")]
async fn ssql() -> status::Custom<Json<Vec<SSql>>> {
let sql_instance=GLOBAL_SQL.lock().await;
let sql =sql_instance.as_ref().unwrap();
let query = "SELECT * FROM info";
let params: &[&(dyn ToSql + Sync)] = &[];
let data = sql.query(query, params).await.expect("查询数据失败");
let mut vec = Vec::new();
for row in data {
let key=row.get(0);
let value=row.get(1);
vec.push(SSql{
key,
value,
}); });
}
status::Custom(Status::Ok, Json(vec)) // 修改数据库连接函数
async fn connect_database() -> Result<(), Box<dyn std::error::Error>> {
let database = sql::Database::init().await?;
let mut lock = GLOBAL_SQL.lock().await;
*lock = Some(database);
Ok(())
} }
async fn get_db() -> Result<Database, Box<dyn std::error::Error>> {
let lock = GLOBAL_SQL.lock().await;
match &*lock {
Some(db) => Ok(db.clone()),
None => Err("Database not initialized".into())
}
}
#[get("/sql")]
async fn ssql() -> Result<Json<Vec<std::collections::HashMap<String, String>>>, status::Custom<String>> {
let db = get_db().await.map_err(|e| {
eprintln!("Database error: {}", e);
status::Custom(Status::InternalServerError, format!("Database error: {}", e))
})?;
let query_result = db.get_db()
.query("SELECT * FROM info".to_string()) // 确保这里是正确的表名
.await
.map_err(|e| {
eprintln!("Query error: {}", e);
status::Custom(Status::InternalServerError, format!("Query error: {}", e))
})?;
Ok(Json(query_result))
}
#[get("/install")]
async fn install() -> status::Custom<String> {
match connect_database().await {
Ok(_) => status::Custom(Status::Ok, "Database connected successfully".to_string()),
Err(e) => status::Custom(Status::InternalServerError, format!("Failed to connect: {}", e))
}
}
#[launch] #[launch]
async fn rocket() -> _ { async fn rocket() -> _ {
initialize_sql().await; connect_database().await.expect("Failed to connect to database");
rocket::build() rocket::build()
.mount("/api", routes![install,ssql]) .mount("/api", routes![install,ssql])
} }

View File

@ -1,53 +1,57 @@
// mod.rs
mod postgresql; mod postgresql;
use std::fs; use std::{collections::HashMap, fs};
use tokio_postgres::{Error, Row};
use toml; use toml;
use crate::config::Config; use crate::config::Config;
use async_trait::async_trait; use async_trait::async_trait;
use std::error::Error;
use std::sync::Arc;
// 所有数据库类型
#[async_trait] #[async_trait]
pub trait Database: Send + Sync { pub trait Databasetrait: Send + Sync {
async fn query(&self, async fn connect(
query: &str, address: String,
params: &[&(dyn tokio_postgres::types::ToSql + Sync)]) port: u32,
-> Result<Vec<Row>, Error>; user: String,
async fn execute( password: String,
&self, dbname: String,
data: &str, ) -> Result<Self, Box<dyn Error>> where Self: Sized;
params: &[&(dyn tokio_postgres::types::ToSql + Sync)], async fn query<'a>(&'a self, query: String) -> Result<Vec<HashMap<String, String>>, Box<dyn Error + 'a>>;
) }
-> Result<u64, Error>; #[derive(Clone)]
pub struct Database {
pub db: Arc<Box<dyn Databasetrait>>,
}
impl Database {
pub fn get_db(&self) -> &Box<dyn Databasetrait> {
&self.db
}
} }
// 加载对应数据库
pub async fn loading() -> Option<Box<dyn Database>> { impl Database {
pub async fn init() -> Result<Database, Box<dyn Error>> {
let config_string = fs::read_to_string("./src/config.toml") let config_string = fs::read_to_string("./src/config.toml")
.expect("Could not load config file"); .map_err(|e| Box::new(e) as Box<dyn Error>)?;
let config: Config = toml::de::from_str(config_string.as_str()).expect("Could not parse config"); let config: Config = toml::from_str(&config_string)
let address = config.database.address; .map_err(|e| Box::new(e) as Box<dyn Error>)?;
let port = config.database.prot;
let user = config.database.user;
let password = config.database.password;
let dbname = config.database.dbname;
let sql_instance: Box<dyn Database>;
match config.database.ilk.as_str() { match config.database.db_type.as_str() {
"postgresql" => { "postgresql" => {
let sql = postgresql::connect(address, port, user, password, dbname).await; let db = postgresql::Postgresql::connect(
match sql { config.database.address,
Ok(conn) => { config.database.prot,
sql_instance = Box::new(conn); config.database.user,
config.database.password,
config.database.db_name,
).await?;
Ok(Database {
db: Arc::new(Box::new(db))
})
} }
Err(e) => { _ => Err(anyhow::anyhow!("unknown database type").into()),
println!("Database connection failed {}", e);
return None;
} }
} }
}
_ => { return None }
};
Some(sql_instance)
} }

View File

@ -1,65 +1,52 @@
use tokio_postgres::{NoTls, Error, Row, Client, Connection, Socket};
use crate::sql;
use async_trait::async_trait; use async_trait::async_trait;
use tokio_postgres::tls::NoTlsStream; use sqlx::{PgPool, Row,Column};
use std::{collections::HashMap, error::Error};
use super::Databasetrait;
#[derive(Clone)]
pub struct Postgresql { pub struct Postgresql {
pub client: tokio_postgres::Client, pool: PgPool,
} }
pub async fn connect(
#[async_trait]
impl Databasetrait for Postgresql {
async fn connect(
address: String, address: String,
port: u32, port: u32,
user: String, user: String,
password: String, password: String,
dbname: String, dbname: String,
) -> Result<Postgresql, Error> { ) -> Result<Self, Box<dyn Error>> {
let connection_str = format!( let connection_str = format!(
"host={} port={} user={} password={} dbname={}", "postgres://{}:{}@{}:{}/{}",
address, port, user, password, dbname user, password, address, port, dbname
); );
let client:Client; let pool = PgPool::connect(&connection_str)
let connection:Connection<Socket, NoTlsStream>; .await
let link = tokio_postgres::connect(&connection_str, NoTls).await; .map_err(|e| Box::new(e) as Box<dyn Error>)?;
match link {
Ok((clie,conne)) => { Ok(Postgresql { pool })
client = clie;
connection = conne;
} }
Err(err) => { async fn query<'a>(&'a self, query: String) -> Result<Vec<HashMap<String, String>>, Box<dyn Error + 'a>> {
println!("Failed to connect to postgresql: {}", err); let rows = sqlx::query(&query)
return Err(err); .fetch_all(&self.pool)
.await
.map_err(|e| Box::new(e) as Box<dyn Error>)?;
let mut results = Vec::new();
for row in rows {
let mut map = HashMap::new();
for column in row.columns() {
let value: String = row.try_get(column.name()).unwrap_or_default();
map.insert(column.name().to_string(), value);
} }
results.push(map);
} }
tokio::spawn(async move { Ok(results)
if let Err(e) = connection.await {
eprintln!("postgresql connection error: {}", e);
}
});
Ok(Postgresql { client })
}
impl Postgresql {
}
#[async_trait]
impl sql::Database for Postgresql {
async fn query(&self, query: & str,
params: &[&(dyn tokio_postgres::types::ToSql + Sync)]
) -> Result<Vec<Row>, Error> {
let rows = self.client.query(query, params).await?;
Ok(rows)
}
async fn execute(
&self,
data: &str,
params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
) -> Result<u64, Error> {
let rows_affected = self.client.execute(data, params).await?;
Ok(rows_affected)
} }
} }