from urllib.parse import quote

import pandas as pd
from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import scoped_session, sessionmaker

from core.config.config_manager import ConfigManager
from core.singleton import Singleton
from core.utils.logger import Logger


class DBClient(metaclass=Singleton):
    """Thin database access layer built on a SQLAlchemy engine and scoped session.

    The constructor reads the database settings from ``ConfigManager`` and
    connects immediately. The class is created through the project's
    ``Singleton`` metaclass (presumably one shared instance per process --
    confirm against ``core.singleton``).
    """

    # Maps the configured driver name to the SQLAlchemy dialect+driver scheme
    # for databases that are reached over the network with credentials.
    _NETWORK_SCHEMES = {
        'mysql': 'mysql+pymysql',
        'postgresql': 'postgresql+psycopg2',
    }

    def __init__(self):
        self.config = ConfigManager()
        self.logger = Logger.get_logger()
        self.engine = None
        self.session_factory = None
        self.ScopedSession = None
        self.connect()

    @classmethod
    def _build_connection_string(cls, db_config):
        """Return the SQLAlchemy URL string for ``db_config``.

        ``user`` and ``password`` are percent-encoded so that characters with
        a meaning inside a URL (``@``, ``:``, ``/``, ``%`` ...) cannot corrupt
        it. The configuration is therefore expected to hold the raw,
        unescaped credentials.

        Raises:
            ValueError: if ``db_config['driver']`` is not mysql, postgresql
                or sqlite.
        """
        driver = db_config['driver']
        if driver == 'sqlite':
            return f"sqlite:///{db_config['name']}"

        scheme = cls._NETWORK_SCHEMES.get(driver)
        if scheme is None:
            raise ValueError(f"不支持的数据库类型: {driver}")

        # safe='' encodes every reserved character, including '/' and '+'.
        user = quote(str(db_config['user']), safe='')
        password = quote(str(db_config['password']), safe='')
        return (
            f"{scheme}://{user}:{password}"
            f"@{db_config['host']}:{db_config['port']}/{db_config['name']}"
        )

    @staticmethod
    def _as_statement(query):
        """Wrap a plain SQL string in ``text()``; pass other objects through."""
        return text(query) if isinstance(query, str) else query

    def connect(self):
        """Create the engine and session factory, then verify the connection.

        Pool options are read from the ``database.*`` configuration keys with
        the defaults shown below. A ``SELECT 1`` is executed to fail fast on
        bad settings.

        Raises:
            Exception: any error raised while building the URL, creating the
                engine or running the probe query is logged and re-raised.
        """
        try:
            db_config = self.config.get_db_config()
            connection_string = self._build_connection_string(db_config)

            # NOTE(review): pool_size / max_overflow / pool_timeout are
            # QueuePool options; on SQLAlchemy versions where SQLite does not
            # default to QueuePool, create_engine may reject them -- confirm
            # against the deployed version.
            self.engine = create_engine(
                connection_string,
                echo=self.config.get('database.echo', False),
                pool_size=self.config.get('database.pool_size', 5),
                max_overflow=self.config.get('database.max_overflow', 10),
                pool_timeout=self.config.get('database.pool_timeout', 30),
                pool_recycle=self.config.get('database.pool_recycle', 3600),
            )

            self.session_factory = sessionmaker(bind=self.engine)
            self.ScopedSession = scoped_session(self.session_factory)

            # Probe the connection so misconfiguration surfaces here.
            with self.engine.connect() as conn:
                conn.execute(text("SELECT 1"))

            self.logger.info("数据库连接成功")
        except Exception as e:
            self.logger.error(f"数据库连接失败: {str(e)}")
            raise

    def get_session(self):
        """Return a session from the scoped registry, connecting first if needed."""
        if self.ScopedSession is None:
            self.connect()
        return self.ScopedSession()

    def execute_query(self, query, params=None, return_dataframe=False):
        """Run a read query and return all rows.

        Args:
            query: SQL string (wrapped in ``text()``) or a SQLAlchemy
                executable.
            params: optional bind parameters; an empty dict is used when
                omitted.
            return_dataframe: when true, return a ``pandas.DataFrame`` whose
                columns are the result keys instead of a list of rows.

        Raises:
            SQLAlchemyError: logged, the session is rolled back, and the
                error is re-raised.
        """
        session = self.get_session()
        try:
            result = session.execute(self._as_statement(query), params or {})
            rows = result.fetchall()
            if return_dataframe:
                return pd.DataFrame(rows, columns=result.keys())
            return rows
        except SQLAlchemyError as e:
            self.logger.error(f"查询执行失败: {str(e)}")
            session.rollback()
            raise
        finally:
            session.close()

    def execute_update(self, query, params=None):
        """Run a data-modifying statement, commit, and return ``rowcount``.

        Args:
            query: SQL string (wrapped in ``text()``) or a SQLAlchemy
                executable.
            params: optional bind parameters; an empty dict is used when
                omitted.

        Raises:
            SQLAlchemyError: logged, the session is rolled back, and the
                error is re-raised.
        """
        session = self.get_session()
        try:
            result = session.execute(self._as_statement(query), params or {})
            session.commit()
            return result.rowcount
        except SQLAlchemyError as e:
            self.logger.error(f"更新执行失败: {str(e)}")
            session.rollback()
            raise
        finally:
            session.close()

    def execute_many(self, query, params_list):
        """Run one statement for a batch of parameter sets and commit.

        Args:
            query: SQL string (wrapped in ``text()``) or a SQLAlchemy
                executable.
            params_list: sequence of parameter mappings, one per execution.

        Returns:
            The ``rowcount`` reported by the result, or ``0`` when
            ``params_list`` is an empty list or tuple (nothing is executed).

        Raises:
            SQLAlchemyError: logged, the session is rolled back, and the
                error is re-raised.
        """
        # An empty batch would otherwise run the statement with no binds.
        if isinstance(params_list, (list, tuple)) and not params_list:
            return 0

        session = self.get_session()
        try:
            result = session.execute(self._as_statement(query), params_list)
            session.commit()
            return result.rowcount
        except SQLAlchemyError as e:
            self.logger.error(f"批量执行失败: {str(e)}")
            session.rollback()
            raise
        finally:
            session.close()

    def insert_dataframe(self, table_name, df, if_exists='append', index=False):
        """Write ``df`` to ``table_name`` via ``DataFrame.to_sql``.

        Frames with more than 1000 rows use ``method='multi'``; smaller ones
        use the pandas default.

        Returns:
            The number of rows in ``df``.

        Raises:
            Exception: any failure is logged and re-raised.
        """
        try:
            row_count = len(df)
            df.to_sql(
                table_name,
                self.engine,
                if_exists=if_exists,
                index=index,
                method='multi' if row_count > 1000 else None,
            )
            self.logger.info(f"成功插入 {row_count} 行数据到表 {table_name}")
            return row_count
        except Exception as e:
            self.logger.error(f"插入DataFrame失败: {str(e)}")
            raise

    def read_table(self, table_name, conditions=None, columns=None, params=None):
        """Read rows from ``table_name`` into a ``pandas.DataFrame``.

        ``table_name``, ``columns`` and ``conditions`` are interpolated into
        the SQL text verbatim, so they must come from trusted code, never from
        user input. Pass user-supplied values through ``params`` and reference
        them in ``conditions`` as named binds (for example ``"id = :id"``).

        Args:
            table_name: table to select from.
            conditions: optional SQL fragment placed after ``WHERE``.
            columns: optional iterable of column names; ``*`` when omitted.
            params: optional mapping of bind parameters. When given, the query
                is wrapped in ``text()`` so named binds are honoured; when
                omitted, the query string is passed to pandas unchanged.

        Raises:
            Exception: any failure is logged and re-raised.
        """
        try:
            selected = ','.join(columns) if columns else '*'
            query = f"SELECT {selected} FROM {table_name}"
            if conditions:
                query += f" WHERE {conditions}"

            if params is None:
                df = pd.read_sql(query, self.engine)
            else:
                df = pd.read_sql(text(query), self.engine, params=params)

            self.logger.info(f"成功从表 {table_name} 读取 {len(df)} 行数据")
            return df
        except Exception as e:
            self.logger.error(f"读取表失败: {str(e)}")
            raise

    def close(self):
        """Remove the scoped session and dispose of the engine, if present."""
        if self.ScopedSession:
            self.ScopedSession.remove()
        if self.engine:
            self.engine.dispose()
        self.logger.info("数据库连接已关闭")