config_manager.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import os
  2. import yaml
  3. from pathlib import Path
  4. from dotenv import load_dotenv
  5. from core.singleton import Singleton
  6. class ConfigManager(metaclass=Singleton):
  7. def __init__(self):
  8. self.config = {}
  9. self._logger = None # 延迟初始化日志记录器
  10. self.load_environment_variables()
  11. self.load_config_files()
  12. def set_logger(self, logger):
  13. """设置日志记录器(延迟初始化)"""
  14. self._logger = logger
  15. def log(self, level, message):
  16. """记录日志(如果日志记录器已初始化)"""
  17. if self._logger:
  18. if level == 'debug':
  19. self._logger.debug(message)
  20. elif level == 'info':
  21. self._logger.info(message)
  22. elif level == 'warning':
  23. self._logger.warning(message)
  24. elif level == 'error':
  25. self._logger.error(message)
  26. elif level == 'critical':
  27. self._logger.critical(message)
  28. else:
  29. # 如果日志记录器尚未初始化,使用简单的打印
  30. print(f"[{level.upper()}] {message}")
  31. def load_environment_variables(self):
  32. """加载环境变量"""
  33. load_dotenv()
  34. self.env_vars = {
  35. 'BROWSER': os.getenv('BROWSER', 'chrome'),
  36. 'HEADLESS': os.getenv('HEADLESS', 'false').lower() == 'true',
  37. 'BASE_URL': os.getenv('BASE_URL', 'https://example.com'),
  38. 'ENVIRONMENT': os.getenv('ENVIRONMENT', 'dev'),
  39. 'DB_HOST': os.getenv('DB_HOST', 'localhost'),
  40. 'DB_PORT': os.getenv('DB_PORT', '3306'),
  41. 'DB_NAME': os.getenv('DB_NAME', 'test_db'),
  42. 'DB_USER': os.getenv('DB_USER', 'root'),
  43. 'DB_PASSWORD': os.getenv('DB_PASSWORD', ''),
  44. 'SMTP_SERVER': os.getenv('SMTP_SERVER', 'smtp.gmail.com'),
  45. 'SMTP_PORT': os.getenv('SMTP_PORT', '587'),
  46. 'SMTP_USERNAME': os.getenv('SMTP_USERNAME', ''),
  47. 'SMTP_PASSWORD': os.getenv('SMTP_PASSWORD', ''),
  48. }
  49. self.log('info', "环境变量加载完成")
  50. def load_config_files(self):
  51. """加载配置文件"""
  52. try:
  53. config_path = Path(__file__).parent.parent.parent / 'resources' / 'config'
  54. # 加载主配置
  55. main_config_file = config_path / 'config.yaml'
  56. if main_config_file.exists():
  57. with open(main_config_file, 'r') as f:
  58. self.config.update(yaml.safe_load(f))
  59. # 加载环境配置
  60. env = self.env_vars['ENVIRONMENT']
  61. env_config_file = config_path / 'environments' / f'{env}.yaml'
  62. if env_config_file.exists():
  63. with open(env_config_file, 'r') as f:
  64. self.config.update(yaml.safe_load(f))
  65. # 加载测试数据
  66. test_data_file = config_path / 'test_data.yaml'
  67. if test_data_file.exists():
  68. with open(test_data_file, 'r') as f:
  69. self.config['test_data'] = yaml.safe_load(f)
  70. self.log('info', "配置文件加载成功")
  71. except Exception as e:
  72. self.log('error', f"配置文件加载失败: {str(e)}")
  73. raise
  74. def get(self, key, default=None):
  75. """获取配置值"""
  76. keys = key.split('.')
  77. value = self.config
  78. for k in keys:
  79. if isinstance(value, dict) and k in value:
  80. value = value[k]
  81. else:
  82. return self.env_vars.get(key, default)
  83. return value
  84. def get_browser_config(self):
  85. """获取浏览器配置"""
  86. return {
  87. 'browser': self.get('browser.type', self.env_vars['BROWSER']),
  88. 'headless': self.get('browser.headless', self.env_vars['HEADLESS']),
  89. 'implicit_wait': self.get('browser.implicit_wait', 10),
  90. 'explicit_wait': self.get('browser.explicit_wait', 30),
  91. 'window_size': self.get('browser.window_size', (1920, 1080))
  92. }
  93. def get_db_config(self):
  94. """获取数据库配置"""
  95. return {
  96. 'host': self.get('database.host', self.env_vars['DB_HOST']),
  97. 'port': self.get('database.port', int(self.env_vars['DB_PORT'])),
  98. 'name': self.get('database.name', self.env_vars['DB_NAME']),
  99. 'user': self.get('database.user', self.env_vars['DB_USER']),
  100. 'password': self.get('database.password', self.env_vars['DB_PASSWORD']),
  101. 'driver': self.get('database.driver', 'mysql')
  102. }