db_client.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. import pandas as pd
  2. from sqlalchemy import create_engine, text
  3. from sqlalchemy.orm import sessionmaker, scoped_session
  4. from sqlalchemy.exc import SQLAlchemyError
  5. from core.config.config_manager import ConfigManager
  6. from core.utils.logger import Logger
  7. from core.singleton import Singleton
  8. class DBClient(metaclass=Singleton):
  9. def __init__(self):
  10. self.config = ConfigManager()
  11. self.logger = Logger.get_logger()
  12. self.engine = None
  13. self.session_factory = None
  14. self.ScopedSession = None
  15. self.connect()
  16. def connect(self):
  17. """连接数据库"""
  18. try:
  19. db_config = self.config.get_db_config()
  20. driver = db_config['driver']
  21. if driver == 'mysql':
  22. connection_string = f"mysql+pymysql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['name']}"
  23. elif driver == 'postgresql':
  24. connection_string = f"postgresql+psycopg2://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['name']}"
  25. elif driver == 'sqlite':
  26. connection_string = f"sqlite:///{db_config['name']}"
  27. else:
  28. raise ValueError(f"不支持的数据库类型: {driver}")
  29. self.engine = create_engine(
  30. connection_string,
  31. echo=self.config.get('database.echo', False),
  32. pool_size=self.config.get('database.pool_size', 5),
  33. max_overflow=self.config.get('database.max_overflow', 10),
  34. pool_timeout=self.config.get('database.pool_timeout', 30),
  35. pool_recycle=self.config.get('database.pool_recycle', 3600)
  36. )
  37. self.session_factory = sessionmaker(bind=self.engine)
  38. self.ScopedSession = scoped_session(self.session_factory)
  39. # 测试连接
  40. with self.engine.connect() as conn:
  41. conn.execute(text("SELECT 1"))
  42. self.logger.info("数据库连接成功")
  43. except Exception as e:
  44. self.logger.error(f"数据库连接失败: {str(e)}")
  45. raise
  46. def get_session(self):
  47. """获取数据库会话"""
  48. if not self.ScopedSession:
  49. self.connect()
  50. return self.ScopedSession()
  51. def execute_query(self, query, params=None, return_dataframe=False):
  52. """执行查询语句"""
  53. session = self.get_session()
  54. try:
  55. if isinstance(query, str):
  56. query = text(query)
  57. result = session.execute(query, params or {})
  58. if return_dataframe:
  59. # 将结果转换为DataFrame
  60. df = pd.DataFrame(result.fetchall(), columns=result.keys())
  61. return df
  62. else:
  63. return result.fetchall()
  64. except SQLAlchemyError as e:
  65. self.logger.error(f"查询执行失败: {str(e)}")
  66. session.rollback()
  67. raise
  68. finally:
  69. session.close()
  70. def execute_update(self, query, params=None):
  71. """执行更新语句"""
  72. session = self.get_session()
  73. try:
  74. if isinstance(query, str):
  75. query = text(query)
  76. result = session.execute(query, params or {})
  77. session.commit()
  78. return result.rowcount
  79. except SQLAlchemyError as e:
  80. self.logger.error(f"更新执行失败: {str(e)}")
  81. session.rollback()
  82. raise
  83. finally:
  84. session.close()
  85. def execute_many(self, query, params_list):
  86. """批量执行语句"""
  87. session = self.get_session()
  88. try:
  89. if isinstance(query, str):
  90. query = text(query)
  91. result = session.execute(query, params_list)
  92. session.commit()
  93. return result.rowcount
  94. except SQLAlchemyError as e:
  95. self.logger.error(f"批量执行失败: {str(e)}")
  96. session.rollback()
  97. raise
  98. finally:
  99. session.close()
  100. def insert_dataframe(self, table_name, df, if_exists='append', index=False):
  101. """将DataFrame插入数据库"""
  102. try:
  103. df.to_sql(
  104. table_name,
  105. self.engine,
  106. if_exists=if_exists,
  107. index=index,
  108. method='multi' if len(df) > 1000 else None
  109. )
  110. self.logger.info(f"成功插入 {len(df)} 行数据到表 {table_name}")
  111. return len(df)
  112. except Exception as e:
  113. self.logger.error(f"插入DataFrame失败: {str(e)}")
  114. raise
  115. def read_table(self, table_name, conditions=None, columns=None):
  116. """读取数据库表"""
  117. try:
  118. query = f"SELECT {','.join(columns) if columns else '*'} FROM {table_name}"
  119. if conditions:
  120. query += f" WHERE {conditions}"
  121. df = pd.read_sql(query, self.engine)
  122. self.logger.info(f"成功从表 {table_name} 读取 {len(df)} 行数据")
  123. return df
  124. except Exception as e:
  125. self.logger.error(f"读取表失败: {str(e)}")
  126. raise
  127. def close(self):
  128. """关闭数据库连接"""
  129. if self.ScopedSession:
  130. self.ScopedSession.remove()
  131. if self.engine:
  132. self.engine.dispose()
  133. self.logger.info("数据库连接已关闭")