diff --git a/tcp_server/src/main.rs b/tcp_server/src/main.rs index a749101..c2ae044 100644 --- a/tcp_server/src/main.rs +++ b/tcp_server/src/main.rs @@ -1,9 +1,53 @@ -use tokio::net::TcpListener; +use std::error::Error; +use tokio::net::{TcpListener, TcpStream}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use config::Config; use tokio_postgres::{NoTls, Error as PgError}; use serde::{Deserialize, Serialize}; use std::sync::Arc; +use std::net::SocketAddr; + +#[derive(Debug)] +enum AppError { + Database(PgError), + Json(serde_json::Error), +} + +impl From for AppError { + fn from(err: PgError) -> Self { + AppError::Database(err) + } +} + +impl From for AppError { + fn from(err: serde_json::Error) -> Self { + AppError::Json(err) + } +} + +impl std::fmt::Display for AppError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AppError::Database(e) => write!(f, "数据库错误: {}", e), + AppError::Json(e) => write!(f, "JSON错误: {}", e), + } + } +} + +impl std::error::Error for AppError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + AppError::Database(e) => Some(e), + AppError::Json(e) => Some(e), + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct TableType { + table_name: String, + data: serde_json::Value, +} #[derive(Debug, Serialize, Deserialize)] struct InstrumentInfo { @@ -16,6 +60,120 @@ struct InstrumentInfo { specification: String, } +async fn insert_data(client: &tokio_postgres::Client, table_type: &TableType) -> Result<(), AppError> { + // 先检查表是否存在 + let exists = client + .query_one( + "SELECT EXISTS( + SELECT 1 FROM information_schema.tables + WHERE table_schema = 'public' + AND table_name = $1 + )", + &[&table_type.table_name], + ) + .await? + .get::<_, bool>(0); + + if !exists { + println!("表 {} 不存在", table_type.table_name); + return Ok(()); + } + + // 获取表的列信息 + let columns = client + .query( + "SELECT column_name, data_type + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = $1 + ORDER BY ordinal_position", + &[&table_type.table_name], + ) + .await?; + + if columns.is_empty() { + println!("表 {} 没有列信息", table_type.table_name); + return Ok(()); + } + + // 检查 ID 是否存在 + if let Some(id) = table_type.data.get("id") { + let id_value = id.as_i64().unwrap_or(0); + let exists = client + .query_one( + &format!("SELECT EXISTS(SELECT 1 FROM public.{} WHERE id = $1)", table_type.table_name), + &[&id_value], + ) + .await? + .get::<_, bool>(0); + + if exists { + println!("表 {} 中 ID {} 已存在,跳过插入", table_type.table_name, id); + return Ok(()); + } + } + + // 构建插入语句 + let mut column_names = Vec::new(); + let mut placeholders = Vec::new(); + let mut param_count = 1; + let mut query_values: Vec> = Vec::new(); + + for row in &columns { + let column_name: String = row.get("column_name"); + let data_type: String = row.get("data_type"); + + if let Some(value) = table_type.data.get(&column_name) { + column_names.push(column_name); + placeholders.push(format!("${}", param_count)); + param_count += 1; + + // 根据数据类型转换值 + match data_type.as_str() { + "integer" | "bigint" => { + if let Some(n) = value.as_i64() { + query_values.push(Box::new(n)); + } + } + "character varying" | "text" => { + if let Some(s) = value.as_str() { + query_values.push(Box::new(s.to_string())); + } + } + "boolean" => { + if let Some(b) = value.as_bool() { + query_values.push(Box::new(b)); + } + } + "double precision" | "numeric" => { + if let Some(n) = value.as_f64() { + query_values.push(Box::new(n)); + } + } + // 可以根据需要添加更多数据类型的处理 + _ => println!("不支持的数据类型: {}", data_type), + } + } + } + + let query = format!( + "INSERT INTO public.{} ({}) OVERRIDING SYSTEM VALUE VALUES ({})", + table_type.table_name, + column_names.join(", "), + placeholders.join(", ") + ); + + let values: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = query_values + .iter() + .map(|v| v.as_ref()) + .collect(); + + client.execute(&query, &values[..]).await?; + + println!("成功插入数据到表 {}", table_type.table_name); + Ok(()) +} + async fn connect_db(config: &Config) -> Result { let host = config.get_string("database.host").unwrap(); let port = config.get_int("database.port").unwrap() as u16; @@ -31,7 +189,7 @@ async fn connect_db(config: &Config) -> Result let (client, connection) = tokio_postgres::connect(&connection_string, NoTls).await?; // 在后台运行连接 - tokio::spawn(async move { + tokio::task::spawn(async move { if let Err(e) = connection.await { eprintln!("数据库连接错误: {}", e); } @@ -40,48 +198,21 @@ async fn connect_db(config: &Config) -> Result Ok(client) } -async fn insert_instrument(client: &tokio_postgres::Client, info: &InstrumentInfo) -> Result<(), PgError> { - // 先检查 ID 是否存在 - let exists = client - .query_one( - "SELECT EXISTS(SELECT 1 FROM public.hy_instrument WHERE id = $1)", - &[&info.id], - ) - .await? - .get::<_, bool>(0); - - if exists { - println!("ID {} 已存在,跳过插入", info.id); - return Ok(()); - } - - // ID 不存在,执行插入 - client.execute( - "INSERT INTO public.hy_instrument (id, informationid, instrumentcode, laboratoryid, name, remark, specification) - OVERRIDING SYSTEM VALUE - VALUES ($1, $2, $3, $4, $5, $6, $7)", - &[ - &info.id, - &info.informationid, - &info.instrumentcode, - &info.laboratoryid, - &info.name, - &info.remark, - &info.specification, - ], - ) - .await?; - - println!("成功插入仪器信息: {} (ID: {})", info.instrumentcode, info.id); - Ok(()) +async fn insert_instrument(client: &tokio_postgres::Client, info: &InstrumentInfo) -> Result<(), AppError> { + let json_data = serde_json::to_value(info)?; + let table_type = TableType { + table_name: "hy_instrument".to_string(), + data: json_data, + }; + insert_data(client, &table_type).await } -fn check_ok_message(message: &[u8]) -> u8 { +fn check_ok_message(_message: &[u8]) -> u8 { 0xFF } #[tokio::main] -async fn main() -> Result<(), Box> { +async fn main() -> Result<(), Box> { // 读取配置文件 let settings = Config::builder() .add_source(config::File::with_name("config")) @@ -100,26 +231,45 @@ async fn main() -> Result<(), Box> { println!("服务器监听地址: {}", bind_address); loop { - let (mut socket, addr) = listener.accept().await?; + let (socket, addr) = listener.accept().await?; println!("新客户端连接: {}", addr); let client = Arc::clone(&client); - tokio::spawn(async move { - let mut buf = [0; 1024 * 64]; // 增加缓冲区大小到64KB + // 使用 spawn_blocking 来处理连接 + tokio::task::spawn(async move { + if let Err(e) = process_connection(socket, addr, client).await { + eprintln!("处理连接错误: {}", e); + } + }); + } +} + +async fn process_connection(mut socket: TcpStream, addr: SocketAddr, client: Arc) -> Result<(), Box> { + let mut buf = [0; 1024 * 64]; // 增加缓冲区大小到64KB - loop { - match socket.read(&mut buf).await { - Ok(0) => { - println!("客户端断开连接: {}", addr); - return; + loop { + match socket.read(&mut buf).await { + Ok(0) => { + println!("客户端断开连接: {}", addr); + break; + } + Ok(n) => { + let data = &buf[..n]; + // 尝试解析为 TableType + match serde_json::from_slice::(data) { + Ok(table_type) => { + println!("接收到表 {} 的数据", table_type.table_name); + if let Err(e) = insert_data(&client, &table_type).await { + eprintln!("错误: {}", e); + } } - Ok(n) => { - let data = &buf[..n]; + Err(_) => { + // 如果不是 TableType 格式,尝试解析为 InstrumentInfo(保持向后兼容) match serde_json::from_slice::(data) { Ok(info) => { println!("接收到仪器信息: {:?}", info); if let Err(e) = insert_instrument(&client, &info).await { - eprintln!("插入数据失败: {}", e); + eprintln!("错误: {}", e); } } Err(e) => { @@ -127,19 +277,20 @@ async fn main() -> Result<(), Box> { eprintln!("接收到的数据: {}", String::from_utf8_lossy(data)); } } - - let response = check_ok_message(data); - if let Err(e) = socket.write_all(&[response]).await { - println!("发送响应失败: {}", e); - return; - } - } - Err(e) => { - println!("读取数据失败: {}", e); - return; } } + + let response = check_ok_message(data); + if let Err(e) = socket.write_all(&[response]).await { + eprintln!("发送响应失败: {}", e); + break; + } } - }); + Err(e) => { + eprintln!("读取数据失败: {}", e); + break; + } + } } + Ok(()) } diff --git a/tcp_server/target/debug/.fingerprint/tcp_server-d434eb7e44cc404e/bin-tcp_server b/tcp_server/target/debug/.fingerprint/tcp_server-d434eb7e44cc404e/bin-tcp_server index 03016a9..e69de29 100644 --- a/tcp_server/target/debug/.fingerprint/tcp_server-d434eb7e44cc404e/bin-tcp_server +++ b/tcp_server/target/debug/.fingerprint/tcp_server-d434eb7e44cc404e/bin-tcp_server @@ -1 +0,0 @@ -4ef5d624c06b520e \ No newline at end of file diff --git a/tcp_server/target/debug/deps/tcp_server.exe b/tcp_server/target/debug/deps/tcp_server.exe deleted file mode 100644 index 790dbd4..0000000 Binary files a/tcp_server/target/debug/deps/tcp_server.exe and /dev/null differ diff --git a/tcp_server/target/debug/deps/tcp_server.pdb b/tcp_server/target/debug/deps/tcp_server.pdb deleted file mode 100644 index 728fd12..0000000 Binary files a/tcp_server/target/debug/deps/tcp_server.pdb and /dev/null differ