"""
Configuration management for the test tools.

Settings are read from a YAML file (by default ``config/test_config.yaml``
in the parent of this file's directory). ``${VAR}`` / ``${VAR:default}``
environment placeholders inside string values are expanded, and built-in
defaults are used when the file is missing, empty or unreadable.
"""
import os
import re
from pathlib import Path
from typing import Any, Dict, Optional, Union

try:
    import yaml
except ImportError:  # pragma: no cover - depends on the environment
    # PyYAML is only needed to read/write the YAML file; the built-in
    # defaults remain usable without it.
    yaml = None


class TestConfig:
    """Test configuration backed by a YAML file plus environment variables."""

    # The class name starts with "Test"; stop pytest from trying to collect it
    # as a test case when it is imported into a test module.
    __test__ = False

    # Matches ${NAME} or ${NAME:default} (the default may be empty).
    _ENV_VAR_PATTERN = re.compile(r'\$\{([^:}]+)(?::([^}]*))?\}')
    # Plain ASCII numeric literals only, so strings like "nan", "Infinity",
    # "1_000" or non-ASCII digits such as "²" are left as strings.
    _INT_PATTERN = re.compile(r'[+-]?[0-9]+')
    _FLOAT_PATTERN = re.compile(
        r'[+-]?(?:[0-9]+\.[0-9]*|\.[0-9]+|[0-9]+)(?:[eE][+-]?[0-9]+)?'
    )

    def __init__(self, config_path: Optional[Union[str, Path]] = None):
        """
        Initialize the configuration.

        Args:
            config_path: Path to the YAML config file. Defaults to
                ``config/test_config.yaml`` in the parent of this file's
                directory.
        """
        if config_path is None:
            config_path = Path(__file__).parent.parent / "config" / "test_config.yaml"
        self.config_path = Path(config_path)
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        """
        Load the configuration file.

        Returns:
            The configuration dict with environment placeholders resolved, or
            the built-in defaults when the file does not exist, is empty, does
            not contain a YAML mapping, or cannot be read/parsed.

        Raises:
            ImportError: If the file exists but PyYAML is not installed.
        """
        if not self.config_path.exists():
            return self._get_default_config()
        if yaml is None:
            # Fail loudly instead of silently ignoring an existing file.
            raise ImportError(f"PyYAML is required to load {self.config_path}")
        try:
            with open(self.config_path, 'r', encoding='utf-8') as f:
                loaded = yaml.safe_load(f)
            if loaded is None:
                # An empty YAML document parses to None; treat it as missing.
                return self._get_default_config()
            if not isinstance(loaded, dict):
                raise ValueError(
                    f"top-level YAML value must be a mapping, got {type(loaded).__name__}"
                )
            return self._resolve_env_vars(loaded)
        except Exception as e:
            # Best effort: report the problem and keep running on defaults.
            print(f"加载配置文件失败: {e}")
            return self._get_default_config()

    def _resolve_env_vars(self, config: Any) -> Any:
        """
        Recursively expand environment placeholders and convert scalar strings.

        Args:
            config: A config value (dict, list, str or other scalar).

        Returns:
            A new structure of the same shape. Strings are placeholder-expanded
            and then type-converted; other scalars are returned unchanged.
        """
        if isinstance(config, dict):
            return {k: self._resolve_env_vars(v) for k, v in config.items()}
        if isinstance(config, list):
            return [self._resolve_env_vars(item) for item in config]
        if isinstance(config, str):
            return self._convert_type(self._parse_env_var(config))
        return config

    def _parse_env_var(self, value: str) -> str:
        """
        Expand every ``${VAR}`` / ``${VAR:default}`` placeholder in a string.

        Text around the placeholders is preserved, so
        ``"http://${HOST:localhost}:8080"`` becomes ``"http://localhost:8080"``
        when HOST is unset.

        Args:
            value: The raw config string.

        Returns:
            The string with each placeholder replaced by the environment
            variable's value, its default, or "" when neither is available.
        """
        def _substitute(match: "re.Match[str]") -> str:
            return os.getenv(match.group(1), match.group(2) or "")

        return self._ENV_VAR_PATTERN.sub(_substitute, value)

    def _convert_type(self, value: Any) -> Any:
        """
        Convert a string value to int, float or bool where it clearly is one.

        Args:
            value: A config value.

        Returns:
            ``int`` for ASCII integer literals (including signed ones),
            ``float`` for ASCII decimal/exponent literals, ``bool`` for
            "true"/"false" (case-insensitive); anything else unchanged.
        """
        if not isinstance(value, str):
            return value
        text = value.strip()
        if self._INT_PATTERN.fullmatch(text):
            return int(text)
        if self._FLOAT_PATTERN.fullmatch(text):
            return float(text)
        if value.lower() in ('true', 'false'):
            return value.lower() == 'true'
        return value

    def _get_default_config(self) -> Dict[str, Any]:
        """
        Build the built-in default configuration.

        Returns:
            A fresh default config dict. Some values can be overridden through
            TEST_TOOLS_* environment variables.
        """
        return {
            "api": {
                "base_url": os.getenv("TEST_TOOLS_API_BASE_URL", "http://127.0.0.1:8080/api"),
                "timeout": int(os.getenv("TEST_TOOLS_API_TIMEOUT", "30")),
                "max_retries": int(os.getenv("TEST_TOOLS_API_MAX_RETRIES", "3"))
            },
            "auth": {
                "login_endpoint": "/sys/auth/login",
                "username": os.getenv("TEST_TOOLS_AUTH_USERNAME", "admin"),
                "password": os.getenv("TEST_TOOLS_AUTH_PASSWORD", "admin123"),
                "token_storage": "memory"
            },
            "test": {
                "data_dir": "data",
                "test_cases_dir": "test_cases",
                "parallel": True,
                "retry_count": 2
            },
            "report": {
                "output_dir": "../test-results/test-tools/reports",
                "formats": ["json", "html"],
                "include_details": True
            },
            "logging": {
                "level": "INFO",
                "file": "../test-results/test-tools/logs/test.log",
                "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                "console": True
            }
        }

    def get(self, key: str, default: Any = None) -> Any:
        """
        Look up a config value.

        Args:
            key: Config key; dot-separated for nested keys (e.g. "api.base_url").
            default: Value returned when any part of the path is missing.

        Returns:
            The stored value (the live nested dict for a partial path), or
            ``default``.
        """
        node: Any = self.config
        for part in key.split('.'):
            if not isinstance(node, dict) or part not in node:
                return default
            node = node[part]
        return node

    def set(self, key: str, value: Any) -> None:
        """
        Set a config value in memory, creating intermediate dicts as needed.

        Args:
            key: Config key; dot-separated for nested keys.
            value: Value to store.

        Raises:
            TypeError: If an intermediate key already holds a non-dict value.
        """
        keys = key.split('.')
        node = self.config
        for part in keys[:-1]:
            child = node.setdefault(part, {})
            if not isinstance(child, dict):
                raise TypeError(
                    f"cannot set {key!r}: {part!r} holds a {type(child).__name__}, not a mapping"
                )
            node = child
        node[keys[-1]] = value

    def save(self) -> None:
        """
        Write the current configuration to ``config_path``.

        Note: the values written are the already-expanded ones, so ``${VAR}``
        placeholders from the original file are replaced by their resolved
        values (including secrets such as the password).

        Raises:
            ImportError: If PyYAML is not installed.
        """
        if yaml is None:
            raise ImportError(f"PyYAML is required to save {self.config_path}")
        self.config_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.config_path, 'w', encoding='utf-8') as f:
            # sort_keys=False keeps the section order of the loaded config.
            yaml.dump(self.config, f, default_flow_style=False,
                      allow_unicode=True, sort_keys=False)

    @property
    def api_base_url(self) -> str:
        """Base URL of the API under test."""
        return self.get("api.base_url")

    @property
    def api_timeout(self) -> int:
        """API request timeout."""
        return self.get("api.timeout")

    @property
    def api_max_retries(self) -> int:
        """Maximum number of API retries."""
        return self.get("api.max_retries")

    @property
    def auth_login_endpoint(self) -> str:
        """Login endpoint path."""
        return self.get("auth.login_endpoint")

    @property
    def auth_username(self) -> str:
        """Test account username."""
        return self.get("auth.username")

    @property
    def auth_password(self) -> str:
        """Test account password."""
        return self.get("auth.password")

    @property
    def report_output_dir(self) -> str:
        """Directory where reports are written."""
        return self.get("report.output_dir")

    @property
    def report_formats(self) -> list:
        """Report formats to generate."""
        return self.get("report.formats")

    @property
    def logging_level(self) -> str:
        """Logging level name."""
        return self.get("logging.level")

    @property
    def logging_file(self) -> str:
        """Log file path."""
        return self.get("logging.file")

    @property
    def logging_console(self) -> bool:
        """Whether to also log to the console (defaults to True)."""
        return self.get("logging.console", True)


# Global configuration instance, created (and the config file read) at import time.
config = TestConfig()