| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- import pandas as pd
- from sqlalchemy import create_engine, text
- from sqlalchemy.orm import sessionmaker, scoped_session
- from sqlalchemy.exc import SQLAlchemyError
- from core.config.config_manager import ConfigManager
- from core.utils.logger import Logger
- from core.singleton import Singleton
class DBClient(metaclass=Singleton):
    """Database client wrapping a SQLAlchemy engine and a scoped session registry.

    Connection settings come from ``ConfigManager``; the constructor connects
    immediately.  The class uses the project's ``Singleton`` metaclass
    (presumably one shared instance per process -- confirm in
    ``core.singleton``).
    """

    def __init__(self):
        """Load configuration and logger, then open the database connection."""
        self.config = ConfigManager()
        self.logger = Logger.get_logger()
        self.engine = None
        self.session_factory = None
        self.ScopedSession = None
        self.connect()

    def connect(self):
        """Build the engine and session registry and verify them with ``SELECT 1``.

        The ``driver`` entry of the database configuration selects the dialect
        (``mysql``, ``postgresql`` or ``sqlite``).  User name and password are
        read from the configuration as raw values and percent-encoded before
        being placed in the URL, so characters such as ``@``, ``:`` or ``/``
        do not corrupt it.

        ``self.engine``, ``self.session_factory`` and ``self.ScopedSession``
        are only assigned after the connectivity probe succeeds; on failure a
        newly created engine is disposed and the previous attribute values
        are left untouched.

        Raises:
            ValueError: If the configured driver is not supported.
            Exception: Any error raised while reading the configuration,
                creating the engine or probing the connection; it is logged
                and re-raised unchanged.
        """
        # Imported locally so this block does not depend on a file-level import.
        from urllib.parse import quote_plus

        engine = None
        try:
            db_config = self.config.get_db_config()
            driver = db_config['driver']

            if driver in ('mysql', 'postgresql'):
                dialect = 'mysql+pymysql' if driver == 'mysql' else 'postgresql+psycopg2'
                # Credentials may contain URL-reserved characters; escape them.
                user = quote_plus(str(db_config['user']))
                password = quote_plus(str(db_config['password']))
                connection_string = (
                    f"{dialect}://{user}:{password}"
                    f"@{db_config['host']}:{db_config['port']}/{db_config['name']}"
                )
            elif driver == 'sqlite':
                connection_string = f"sqlite:///{db_config['name']}"
            else:
                raise ValueError(f"不支持的数据库类型: {driver}")

            # NOTE(review): the pool sizing arguments are passed for every
            # driver, including sqlite -- confirm the installed SQLAlchemy
            # version accepts them for the pool class it picks for sqlite.
            engine = create_engine(
                connection_string,
                echo=self.config.get('database.echo', False),
                pool_size=self.config.get('database.pool_size', 5),
                max_overflow=self.config.get('database.max_overflow', 10),
                pool_timeout=self.config.get('database.pool_timeout', 30),
                pool_recycle=self.config.get('database.pool_recycle', 3600),
            )
            session_factory = sessionmaker(bind=engine)
            scoped = scoped_session(session_factory)

            # Probe the connection before publishing the new engine.
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
        except Exception as e:
            self.logger.error(f"数据库连接失败: {str(e)}")
            if engine is not None:
                # Do not leave a half-initialised pool behind.
                engine.dispose()
            raise
        else:
            self.engine = engine
            self.session_factory = session_factory
            self.ScopedSession = scoped
            self.logger.info("数据库连接成功")

    def get_session(self):
        """Return a session from the scoped registry, connecting first if needed."""
        if self.ScopedSession is None:
            self.connect()
        return self.ScopedSession()

    @staticmethod
    def _as_statement(query):
        """Wrap a plain SQL string in ``text()``; return any other object as-is."""
        return text(query) if isinstance(query, str) else query

    def execute_query(self, query, params=None, return_dataframe=False):
        """Execute a statement and return all of its rows.

        Args:
            query: SQL string (wrapped in ``text()``) or an executable
                statement object.
            params: Bind parameters; ``None`` or an empty value becomes ``{}``.
            return_dataframe: When true, return a ``pandas.DataFrame`` whose
                columns are the result keys instead of the fetched rows.

        Returns:
            The result of ``fetchall()``, or a DataFrame built from it.

        Raises:
            SQLAlchemyError: Logged, the session is rolled back, then re-raised.
        """
        session = self.get_session()
        try:
            result = session.execute(self._as_statement(query), params or {})
            rows = result.fetchall()
            if return_dataframe:
                return pd.DataFrame(rows, columns=result.keys())
            return rows
        except SQLAlchemyError as e:
            self.logger.error(f"查询执行失败: {str(e)}")
            session.rollback()
            raise
        finally:
            session.close()

    def execute_update(self, query, params=None):
        """Execute a data-modifying statement, commit, and return ``rowcount``.

        Args:
            query: SQL string (wrapped in ``text()``) or an executable
                statement object.
            params: Bind parameters; ``None`` or an empty value becomes ``{}``.

        Raises:
            SQLAlchemyError: Logged, the session is rolled back, then re-raised.
        """
        session = self.get_session()
        try:
            result = session.execute(self._as_statement(query), params or {})
            session.commit()
            return result.rowcount
        except SQLAlchemyError as e:
            self.logger.error(f"更新执行失败: {str(e)}")
            session.rollback()
            raise
        finally:
            session.close()

    def execute_many(self, query, params_list):
        """Execute one statement with a batch of parameter sets and commit.

        Args:
            query: SQL string (wrapped in ``text()``) or an executable
                statement object.
            params_list: Parameter sets handed to ``session.execute``.  An
                empty list or tuple means there is nothing to run, so ``0``
                is returned without touching the database.

        Returns:
            The ``rowcount`` reported by the result, or ``0`` for an empty batch.

        Raises:
            SQLAlchemyError: Logged, the session is rolled back, then re-raised.
        """
        if isinstance(params_list, (list, tuple)) and not params_list:
            return 0

        session = self.get_session()
        try:
            result = session.execute(self._as_statement(query), params_list)
            session.commit()
            return result.rowcount
        except SQLAlchemyError as e:
            self.logger.error(f"批量执行失败: {str(e)}")
            session.rollback()
            raise
        finally:
            session.close()

    def insert_dataframe(self, table_name, df, if_exists='append', index=False,
                         chunksize=None):
        """Write a DataFrame to a table through ``DataFrame.to_sql``.

        Args:
            table_name: Target table name.
            df: DataFrame to write.
            if_exists: Forwarded to ``to_sql``.
            index: Forwarded to ``to_sql``.
            chunksize: Forwarded to ``to_sql``; ``None`` (the default) keeps
                the pandas default of writing all rows in one batch.  Useful
                to bound statement size when the multi-row insert path is
                taken.

        Returns:
            The number of rows in ``df``.

        Raises:
            Exception: Any error from the write is logged and re-raised.
        """
        try:
            row_count = len(df)
            df.to_sql(
                table_name,
                self.engine,
                if_exists=if_exists,
                index=index,
                chunksize=chunksize,
                # Multi-row INSERT statements are only used for frames above 1000 rows.
                method='multi' if row_count > 1000 else None,
            )
            self.logger.info(f"成功插入 {row_count} 行数据到表 {table_name}")
            return row_count
        except Exception as e:
            self.logger.error(f"插入DataFrame失败: {str(e)}")
            raise

    def read_table(self, table_name, conditions=None, columns=None):
        """Read a table (optionally filtered) into a DataFrame.

        WARNING: ``table_name``, ``columns`` and ``conditions`` are
        interpolated into the SQL text without any escaping.  Pass only
        trusted values; never forward user input here.

        Args:
            table_name: Table to select from.
            conditions: Raw SQL placed after ``WHERE`` when truthy.
            columns: Iterable of column names joined with ``,``; all columns
                (``*``) are selected when empty or ``None``.

        Returns:
            The DataFrame produced by ``pandas.read_sql``.

        Raises:
            Exception: Any error while building or running the query is
                logged and re-raised.
        """
        try:
            selected = ','.join(columns) if columns else '*'
            query = f"SELECT {selected} FROM {table_name}"
            if conditions:
                query += f" WHERE {conditions}"
            df = pd.read_sql(query, self.engine)
            self.logger.info(f"成功从表 {table_name} 读取 {len(df)} 行数据")
            return df
        except Exception as e:
            self.logger.error(f"读取表失败: {str(e)}")
            raise

    def close(self):
        """Remove the scoped session registry's current session and dispose the engine."""
        if self.ScopedSession:
            self.ScopedSession.remove()
        if self.engine:
            self.engine.dispose()
        self.logger.info("数据库连接已关闭")
|