""" 测试执行性能优化模块 支持并行执行、结果缓存、连接池等功能 """ import asyncio import time from typing import List, Dict, Any, Optional, Callable from concurrent.futures import ThreadPoolExecutor, as_completed from functools import lru_cache from dataclasses import dataclass import hashlib import json from .api_tester import APITester, TestResult from ..utils.logger import TestLogger @dataclass class CachedResult: """缓存结果""" key: str result: TestResult timestamp: float ttl: int = 3600 # 缓存有效期(秒) def is_expired(self) -> bool: """检查缓存是否过期""" return time.time() - self.timestamp > self.ttl class ResultCache: """测试结果缓存""" def __init__(self, default_ttl: int = 3600): """ 初始化缓存 Args: default_ttl: 默认缓存有效期(秒) """ self.cache: Dict[str, CachedResult] = {} self.default_ttl = default_ttl def _generate_key(self, method: str, endpoint: str, data: Dict[str, Any] = None, params: Dict[str, Any] = None) -> str: """ 生成缓存键 Args: method: HTTP方法 endpoint: API端点 data: 请求体数据 params: URL参数 Returns: 缓存键 """ key_data = { "method": method, "endpoint": endpoint, "data": data or {}, "params": params or {} } key_str = json.dumps(key_data, sort_keys=True) return hashlib.md5(key_str.encode()).hexdigest() def get(self, method: str, endpoint: str, data: Dict[str, Any] = None, params: Dict[str, Any] = None) -> Optional[TestResult]: """ 获取缓存结果 Args: method: HTTP方法 endpoint: API端点 data: 请求体数据 params: URL参数 Returns: 缓存的测试结果 """ key = self._generate_key(method, endpoint, data, params) if key in self.cache: cached = self.cache[key] if not cached.is_expired(): return cached.result else: del self.cache[key] return None def set(self, method: str, endpoint: str, result: TestResult, data: Dict[str, Any] = None, params: Dict[str, Any] = None, ttl: int = None) -> None: """ 设置缓存结果 Args: method: HTTP方法 endpoint: API端点 result: 测试结果 data: 请求体数据 params: URL参数 ttl: 缓存有效期(秒) """ key = self._generate_key(method, endpoint, data, params) cached = CachedResult( key=key, result=result, timestamp=time.time(), ttl=ttl or self.default_ttl ) self.cache[key] = cached def clear(self) -> None: """清空缓存""" self.cache.clear() def cleanup_expired(self) -> None: """清理过期缓存""" expired_keys = [key for key, cached in self.cache.items() if cached.is_expired()] for key in expired_keys: del self.cache[key] class ParallelTestExecutor: """并行测试执行器""" def __init__(self, max_workers: int = 4, logger: TestLogger = None): """ 初始化并行执行器 Args: max_workers: 最大工作线程数 logger: 日志记录器 """ self.max_workers = max_workers self.logger = logger def execute_tests(self, test_functions: List[Callable], use_cache: bool = True, cache: ResultCache = None) -> List[TestResult]: """ 并行执行测试 Args: test_functions: 测试函数列表 use_cache: 是否使用缓存 cache: 缓存实例 Returns: 测试结果列表 """ results = [] with ThreadPoolExecutor(max_workers=self.max_workers) as executor: # 提交所有测试任务 future_to_test = { executor.submit(self._execute_single_test, func, use_cache, cache): func for func in test_functions } # 收集结果 for future in as_completed(future_to_test): test_func = future_to_test[future] try: result = future.result() results.append(result) except Exception as e: if self.logger: self.logger.error(f"测试执行失败: {str(e)}") results.append(TestResult( passed=False, test_name=test_func.__name__, error_message=str(e) )) return results def _execute_single_test(self, test_func: Callable, use_cache: bool, cache: ResultCache) -> TestResult: """ 执行单个测试 Args: test_func: 测试函数 use_cache: 是否使用缓存 cache: 缓存实例 Returns: 测试结果 """ try: # 尝试从缓存获取结果 if use_cache and cache: # 这里简化实现,实际应该根据测试函数的参数生成缓存键 pass # 执行测试 return test_func() except Exception as e: raise e def execute_tests_async(self, test_functions: List[Callable]) -> List[TestResult]: """ 异步并行执行测试 Args: test_functions: 测试函数列表 Returns: 测试结果列表 """ async def run_all(): tasks = [self._execute_single_test_async(func) for func in test_functions] return await asyncio.gather(*tasks, return_exceptions=True) results = asyncio.run(run_all()) # 处理异常结果 processed_results = [] for i, result in enumerate(results): if isinstance(result, Exception): processed_results.append(TestResult( passed=False, test_name=test_functions[i].__name__, error_message=str(result) )) else: processed_results.append(result) return processed_results async def _execute_single_test_async(self, test_func: Callable) -> TestResult: """ 异步执行单个测试 Args: test_func: 测试函数 Returns: 测试结果 """ # 这里简化实现,实际应该使用异步HTTP客户端 return test_func() class PerformanceOptimizer: """性能优化器""" def __init__(self, logger: TestLogger = None): """ 初始化性能优化器 Args: logger: 日志记录器 """ self.logger = logger self.cache = ResultCache() self.executor = ParallelTestExecutor(logger=logger) def optimize_test_execution(self, test_functions: List[Callable], parallel: bool = True, use_cache: bool = True) -> List[TestResult]: """ 优化测试执行 Args: test_functions: 测试函数列表 parallel: 是否并行执行 use_cache: 是否使用缓存 Returns: 测试结果列表 """ start_time = time.time() if self.logger: self.logger.info(f"开始优化测试执行: {len(test_functions)}个测试用例") self.logger.info(f"并行执行: {parallel}, 使用缓存: {use_cache}") # 清理过期缓存 self.cache.cleanup_expired() # 执行测试 if parallel: results = self.executor.execute_tests(test_functions, use_cache, self.cache) else: results = [] for func in test_functions: try: result = func() results.append(result) except Exception as e: if self.logger: self.logger.error(f"测试执行失败: {str(e)}") results.append(TestResult( passed=False, test_name=func.__name__, error_message=str(e) )) execution_time = time.time() - start_time if self.logger: self.logger.info(f"测试执行完成: {execution_time:.2f}秒") self.logger.info(f"平均每个测试: {execution_time/len(test_functions):.2f}秒") return results def get_cache_stats(self) -> Dict[str, Any]: """ 获取缓存统计信息 Returns: 缓存统计信息 """ return { "total_entries": len(self.cache.cache), "expired_entries": sum(1 for cached in self.cache.cache.values() if cached.is_expired()), "valid_entries": sum(1 for cached in self.cache.cache.values() if not cached.is_expired()) } def clear_cache(self) -> None: """清空缓存""" self.cache.clear() if self.logger: self.logger.info("缓存已清空") class ConnectionPool: """连接池""" def __init__(self, max_connections: int = 10): """ 初始化连接池 Args: max_connections: 最大连接数 """ self.max_connections = max_connections self.connections = [] def get_connection(self) -> APITester: """ 获取连接 Returns: API测试器实例 """ if self.connections: return self.connections.pop() return APITester() def return_connection(self, connection: APITester) -> None: """ 归还连接 Args: connection: API测试器实例 """ if len(self.connections) < self.max_connections: self.connections.append(connection) else: connection.close() def close_all(self) -> None: """关闭所有连接""" for connection in self.connections: connection.close() self.connections.clear()