feat(admin): 添加用户管理相关文件

添加用户管理视图、API和状态管理文件
This commit is contained in:
张翔
2026-03-28 14:37:29 +08:00
commit 08ea5fbe98
1643 changed files with 255646 additions and 0 deletions
@@ -0,0 +1,201 @@
"""
E2E测试核心框架模块
提供测试框架的基础功能,包括:
- 配置管理
- 日志记录
- 异常处理
- 重试机制
- 报告生成
- Caffeine缓存管理
- 数据库连接池管理
- 本地并发控制
- 安全测试模块
- 文件上传下载功能
- 定时任务调度器
- 数据导入导出功能
- 审计日志模块
- 数据备份恢复功能
"""
from .config_manager import ConfigManager, TestConfig
from .logger import TestLogger, get_logger
from .exception_handler import TestExceptionHandler, FatalTestError, RetryableError
from .retry_decorator import retry_on_failure
from .reporter import TestReporter
from .screenshot_helper import ScreenshotHelper
from .caffeine_cache import CaffeineCache, CaffeineCacheManager, cache_manager
from .connection_pool import ConnectionPool, ConnectionPoolManager, pool_manager
from .concurrency_control import (
SemaphoreControl,
ReadWriteLock,
RateLimiter,
LocalDistributedLock,
ConcurrentCounter,
ThreadBarrier,
BoundedTaskQueue,
ConcurrencyManager,
concurrency_manager,
)
from .security import (
SQLInjectionDetector,
XSSDetector,
CSRFProtector,
InputSanitizer,
PasswordStrengthChecker,
SecurityHeaders,
SecurityAuditLogger,
SecurityScanner,
ThreatLevel,
DetectionResult,
SQLInjectionResult,
XSSResult,
PasswordStrengthResult,
SecurityEvent,
SecurityReport,
)
from .file_handler import (
FileUploader,
FileDownloader,
FileTypeValidator,
FileSizeValidator,
FilenameSanitizer,
FileStorageManager,
UploadResult,
DownloadResult,
)
from .task_scheduler import (
TaskScheduler,
Task,
TaskStatus,
SchedulerState,
TaskExecutionRecord,
)
from .data_import_export import (
CSVExporter,
CSVImporter,
ExcelExporter,
DataValidator,
DataTransformer,
TemplateManager,
DataImportExportManager,
ExportResult,
ImportResult,
ValidationResult,
)
from .audit_log import (
OperationLogRecorder,
ObjectChangeAuditor,
AuditLogStorage,
MemoryAuditStorage,
AuditLogRecorder,
AuditLogExporter,
AuditStatistics,
OperationLogEntry,
ObjectChange,
DiffResult,
audit_log,
)
from .backup_restore import (
BackupManager,
BackupScheduler,
BackupResult,
RestoreResult,
VerifyResult,
DeleteResult,
BackupInfo,
)
from .api_client import (
APIClient,
APIRequest,
APIResponse,
HTTPMethod,
)
__all__ = [
"ConfigManager",
"TestConfig",
"TestLogger",
"get_logger",
"TestExceptionHandler",
"FatalTestError",
"RetryableError",
"retry_on_failure",
"TestReporter",
"ScreenshotHelper",
"CaffeineCache",
"CaffeineCacheManager",
"cache_manager",
"ConnectionPool",
"ConnectionPoolManager",
"pool_manager",
"SemaphoreControl",
"ReadWriteLock",
"RateLimiter",
"LocalDistributedLock",
"ConcurrentCounter",
"ThreadBarrier",
"BoundedTaskQueue",
"ConcurrencyManager",
"concurrency_manager",
"SQLInjectionDetector",
"XSSDetector",
"CSRFProtector",
"InputSanitizer",
"PasswordStrengthChecker",
"SecurityHeaders",
"SecurityAuditLogger",
"SecurityScanner",
"ThreatLevel",
"DetectionResult",
"SQLInjectionResult",
"XSSResult",
"PasswordStrengthResult",
"SecurityEvent",
"SecurityReport",
"FileUploader",
"FileDownloader",
"FileTypeValidator",
"FileSizeValidator",
"FilenameSanitizer",
"FileStorageManager",
"UploadResult",
"DownloadResult",
"TaskScheduler",
"Task",
"TaskStatus",
"SchedulerState",
"TaskExecutionRecord",
"CSVExporter",
"CSVImporter",
"ExcelExporter",
"DataValidator",
"DataTransformer",
"TemplateManager",
"DataImportExportManager",
"ExportResult",
"ImportResult",
"ValidationResult",
"OperationLogRecorder",
"ObjectChangeAuditor",
"AuditLogStorage",
"MemoryAuditStorage",
"AuditLogRecorder",
"AuditLogExporter",
"AuditStatistics",
"OperationLogEntry",
"ObjectChange",
"DiffResult",
"audit_log",
"BackupManager",
"BackupScheduler",
"BackupResult",
"RestoreResult",
"VerifyResult",
"DeleteResult",
"BackupInfo",
"APIClient",
"APIRequest",
"APIResponse",
"HTTPMethod",
]
@@ -0,0 +1,238 @@
"""
API客户端模块
提供HTTP请求封装和API调用功能。
"""
import requests
import json
from typing import Any, Dict, Optional, Callable
from dataclasses import dataclass
from enum import Enum
class HTTPMethod(Enum):
"""HTTP方法"""
GET = "GET"
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
PATCH = "PATCH"
@dataclass
class APIResponse:
"""API响应"""
status_code: int
data: Any
headers: Dict[str, str]
success: bool
error_message: Optional[str] = None
@dataclass
class APIRequest:
"""API请求"""
method: HTTPMethod
url: str
headers: Optional[Dict[str, str]] = None
params: Optional[Dict[str, Any]] = None
data: Optional[Any] = None
json_data: Optional[Dict[str, Any]] = None
timeout: int = 30
class APIClient:
"""
API客户端
特性:
- 支持多种HTTP方法
- 自动JSON序列化/反序列化
- 超时处理
- 错误处理
- 请求/响应拦截器
"""
def __init__(
self,
base_url: str,
default_headers: Optional[Dict[str, str]] = None,
timeout: int = 30
):
"""
初始化API客户端
Args:
base_url: 基础URL
default_headers: 默认请求头
timeout: 默认超时时间(秒)
"""
self._base_url = base_url.rstrip('/')
self._default_headers = default_headers or {}
self._default_timeout = timeout
self._request_interceptors: list = []
self._response_interceptors: list = []
self._session = requests.Session()
def add_request_interceptor(self, interceptor: Callable[[APIRequest], APIRequest]) -> None:
"""添加请求拦截器"""
self._request_interceptors.append(interceptor)
def add_response_interceptor(self, interceptor: Callable[[APIResponse], APIResponse]) -> None:
"""添加响应拦截器"""
self._response_interceptors.append(interceptor)
def _build_url(self, endpoint: str) -> str:
"""构建完整URL"""
endpoint = endpoint.lstrip('/')
return f"{self._base_url}/{endpoint}"
def _apply_request_interceptors(self, request: APIRequest) -> APIRequest:
"""应用请求拦截器"""
for interceptor in self._request_interceptors:
request = interceptor(request)
return request
def _apply_response_interceptors(self, response: APIResponse) -> APIResponse:
"""应用响应拦截器"""
for interceptor in self._response_interceptors:
response = interceptor(response)
return response
def request(self, api_request: APIRequest) -> APIResponse:
"""
发送HTTP请求
Args:
api_request: API请求对象
Returns:
API响应对象
"""
try:
# 应用请求拦截器
api_request = self._apply_request_interceptors(api_request)
# 合并默认请求头
headers = {**self._default_headers, **(api_request.headers or {})}
# 发送请求
response = self._session.request(
method=api_request.method.value,
url=api_request.url,
headers=headers,
params=api_request.params,
data=api_request.data,
json=api_request.json_data,
timeout=api_request.timeout or self._default_timeout
)
# 解析响应
try:
response_data = response.json()
except json.JSONDecodeError:
response_data = response.text
api_response = APIResponse(
status_code=response.status_code,
data=response_data,
headers=dict(response.headers),
success=200 <= response.status_code < 300
)
# 应用响应拦截器
return self._apply_response_interceptors(api_response)
except requests.exceptions.Timeout:
return APIResponse(
status_code=0,
data=None,
headers={},
success=False,
error_message="请求超时"
)
except requests.exceptions.ConnectionError as e:
return APIResponse(
status_code=0,
data=None,
headers={},
success=False,
error_message=f"连接错误: {str(e)}"
)
except Exception as e:
return APIResponse(
status_code=0,
data=None,
headers={},
success=False,
error_message=f"请求失败: {str(e)}"
)
def get(
self,
endpoint: str,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None
) -> APIResponse:
"""发送GET请求"""
request = APIRequest(
method=HTTPMethod.GET,
url=self._build_url(endpoint),
headers=headers,
params=params,
timeout=timeout or self._default_timeout
)
return self.request(request)
def post(
self,
endpoint: str,
json_data: Optional[Dict[str, Any]] = None,
data: Optional[Any] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None
) -> APIResponse:
"""发送POST请求"""
request = APIRequest(
method=HTTPMethod.POST,
url=self._build_url(endpoint),
headers=headers,
data=data,
json_data=json_data,
timeout=timeout or self._default_timeout
)
return self.request(request)
def put(
self,
endpoint: str,
json_data: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None
) -> APIResponse:
"""发送PUT请求"""
request = APIRequest(
method=HTTPMethod.PUT,
url=self._build_url(endpoint),
headers=headers,
json_data=json_data,
timeout=timeout or self._default_timeout
)
return self.request(request)
def delete(
self,
endpoint: str,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None
) -> APIResponse:
"""发送DELETE请求"""
request = APIRequest(
method=HTTPMethod.DELETE,
url=self._build_url(endpoint),
headers=headers,
timeout=timeout or self._default_timeout
)
return self.request(request)
@@ -0,0 +1,237 @@
"""
API模拟服务器
为TDD Green阶段提供模拟API服务,使测试能够通过。
"""
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
import uuid
import time
@dataclass
class Role:
"""角色数据模型"""
id: str
name: str
code: str
description: str
status: str = "active"
permissions: List[str] = field(default_factory=list)
created_at: str = ""
updated_at: str = ""
class RoleMockService:
"""角色管理模拟服务"""
def __init__(self):
self._roles: Dict[str, Role] = {}
self._init_default_roles()
def _init_default_roles(self):
"""初始化默认角色"""
default_roles = [
Role(
id="1",
name="系统管理员",
code="admin",
description="拥有所有权限",
permissions=["*"],
created_at="2026-01-01T00:00:00Z"
),
Role(
id="2",
name="普通用户",
code="user",
description="拥有基本权限",
permissions=["user:read", "user:write"],
created_at="2026-01-01T00:00:00Z"
),
]
for role in default_roles:
self._roles[role.id] = role
def create_role(self, name: str, code: str, description: str,
permissions: List[str] = None) -> Dict[str, Any]:
"""创建角色"""
# 检查角色编码是否已存在
for role in self._roles.values():
if role.code == code:
return {
"success": False,
"message": f"角色编码 '{code}' 已存在",
"code": "DUPLICATE_CODE"
}
# 创建新角色
new_role = Role(
id=str(uuid.uuid4()),
name=name,
code=code,
description=description,
permissions=permissions or [],
created_at=time.strftime("%Y-%m-%dT%H:%M:%SZ"),
updated_at=time.strftime("%Y-%m-%dT%H:%M:%SZ")
)
self._roles[new_role.id] = new_role
return {
"success": True,
"message": "角色创建成功",
"data": self._role_to_dict(new_role)
}
def update_role(self, role_id: str, name: str = None,
description: str = None, permissions: List[str] = None) -> Dict[str, Any]:
"""更新角色"""
if role_id not in self._roles:
return {
"success": False,
"message": f"角色ID '{role_id}' 不存在",
"code": "NOT_FOUND"
}
role = self._roles[role_id]
if name:
role.name = name
if description:
role.description = description
if permissions is not None:
role.permissions = permissions
role.updated_at = time.strftime("%Y-%m-%dT%H:%M:%SZ")
return {
"success": True,
"message": "角色更新成功",
"data": self._role_to_dict(role)
}
def delete_role(self, role_id: str) -> Dict[str, Any]:
"""删除角色"""
if role_id not in self._roles:
return {
"success": False,
"message": f"角色ID '{role_id}' 不存在",
"code": "NOT_FOUND"
}
# 不允许删除系统默认角色
role = self._roles[role_id]
if role.code in ["admin", "user"]:
return {
"success": False,
"message": f"不能删除系统默认角色 '{role.code}'",
"code": "FORBIDDEN"
}
del self._roles[role_id]
return {
"success": True,
"message": "角色删除成功"
}
def get_role(self, role_id: str) -> Optional[Dict[str, Any]]:
"""获取角色详情"""
if role_id not in self._roles:
return None
return self._role_to_dict(self._roles[role_id])
def list_roles(self, keyword: str = None) -> Dict[str, Any]:
"""获取角色列表"""
roles = list(self._roles.values())
# 搜索过滤
if keyword:
keyword = keyword.lower()
roles = [
role for role in roles
if keyword in role.name.lower()
or keyword in role.code.lower()
or keyword in role.description.lower()
]
return {
"success": True,
"data": {
"list": [self._role_to_dict(role) for role in roles],
"total": len(roles)
}
}
def assign_permissions(self, role_id: str, permissions: List[str]) -> Dict[str, Any]:
"""分配权限"""
if role_id not in self._roles:
return {
"success": False,
"message": f"角色ID '{role_id}' 不存在",
"code": "NOT_FOUND"
}
role = self._roles[role_id]
role.permissions = permissions
role.updated_at = time.strftime("%Y-%m-%dT%H:%M:%SZ")
return {
"success": True,
"message": "权限分配成功",
"data": self._role_to_dict(role)
}
def _role_to_dict(self, role: Role) -> Dict[str, Any]:
"""角色对象转字典"""
return {
"id": role.id,
"name": role.name,
"code": role.code,
"description": role.description,
"status": role.status,
"permissions": role.permissions,
"createdAt": role.created_at,
"updatedAt": role.updated_at
}
# 全局模拟服务实例
role_mock_service = RoleMockService()
# 模拟API端点
def mock_api_create_role(data: Dict[str, Any]) -> Dict[str, Any]:
"""模拟创建角色API"""
return role_mock_service.create_role(
name=data.get("name", ""),
code=data.get("code", ""),
description=data.get("description", ""),
permissions=data.get("permissions", [])
)
def mock_api_update_role(role_id: str, data: Dict[str, Any]) -> Dict[str, Any]:
"""模拟更新角色API"""
return role_mock_service.update_role(
role_id=role_id,
name=data.get("name"),
description=data.get("description"),
permissions=data.get("permissions")
)
def mock_api_delete_role(role_id: str) -> Dict[str, Any]:
"""模拟删除角色API"""
return role_mock_service.delete_role(role_id)
def mock_api_list_roles(keyword: str = None) -> Dict[str, Any]:
"""模拟获取角色列表API"""
return role_mock_service.list_roles(keyword)
def mock_api_assign_permissions(role_id: str, permissions: List[str]) -> Dict[str, Any]:
"""模拟分配权限API"""
return role_mock_service.assign_permissions(role_id, permissions)
@@ -0,0 +1,427 @@
"""
审计日志模块
提供操作日志记录和JaVers风格的对象变更审计功能。
"""
import json
import time
import uuid
import functools
from typing import Any, Dict, List, Optional, Callable
from dataclasses import dataclass, field
from datetime import datetime
from abc import ABC, abstractmethod
@dataclass
class OperationLogEntry:
"""操作日志条目"""
id: str
operation_time: datetime
module_name: str
operation_desc: str
operator: str
operator_id: Optional[int] = None
request_method: Optional[str] = None
request_path: Optional[str] = None
request_params: Optional[str] = None
response_result: Optional[str] = None
ip_address: Optional[str] = None
execution_time: Optional[int] = None # 执行时间(毫秒)
status: str = "SUCCESS"
exception_message: Optional[str] = None
diff_json: Optional[str] = None
@dataclass
class ObjectChange:
"""对象变更记录"""
field_name: str
old_value: Any
new_value: Any
@dataclass
class DiffResult:
"""差异比较结果"""
has_changes: bool
changes: List[ObjectChange]
@dataclass
class AuditStatistics:
"""审计统计信息"""
total_operations: int = 0
success_count: int = 0
failure_count: int = 0
module_distribution: Dict[str, int] = field(default_factory=dict)
class AuditLogStorage(ABC):
"""审计日志存储抽象基类"""
@abstractmethod
def save(self, log_entry: Dict[str, Any]) -> None:
"""保存日志条目"""
pass
@abstractmethod
def query(
self,
module_name: Optional[str] = None,
operator: Optional[str] = None,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
status: Optional[str] = None
) -> List[Dict[str, Any]]:
"""查询日志"""
pass
@abstractmethod
def get_all(self) -> List[Dict[str, Any]]:
"""获取所有日志"""
pass
@abstractmethod
def delete_old(self, max_keep: int) -> int:
"""删除旧日志"""
pass
class MemoryAuditStorage(AuditLogStorage):
"""内存审计日志存储"""
def __init__(self):
self._logs: List[Dict[str, Any]] = []
def save(self, log_entry: Dict[str, Any]) -> None:
self._logs.append(log_entry)
def query(
self,
module_name: Optional[str] = None,
operator: Optional[str] = None,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
status: Optional[str] = None
) -> List[Dict[str, Any]]:
result = self._logs
if module_name:
result = [log for log in result if log.get("module_name") == module_name]
if operator:
result = [log for log in result if log.get("operator") == operator]
if start_time:
result = [log for log in result if log.get("timestamp", 0) >= start_time]
if end_time:
result = [log for log in result if log.get("timestamp", 0) <= end_time]
if status:
result = [log for log in result if log.get("status") == status]
return result
def get_all(self) -> List[Dict[str, Any]]:
return self._logs.copy()
def delete_old(self, max_keep: int) -> int:
if len(self._logs) <= max_keep:
return 0
# 按时间排序,保留最新的
sorted_logs = sorted(self._logs, key=lambda x: x.get("timestamp", 0), reverse=True)
self._logs = sorted_logs[:max_keep]
deleted_count = len(sorted_logs) - len(self._logs)
return deleted_count
class OperationLogRecorder:
"""操作日志记录器"""
def __init__(self, storage: Optional[AuditLogStorage] = None):
self._storage = storage or MemoryAuditStorage()
def record(
self,
module_name: str,
operation_desc: str,
operator: str,
operator_id: Optional[int] = None,
request_method: Optional[str] = None,
request_path: Optional[str] = None,
request_params: Optional[str] = None,
ip_address: Optional[str] = None,
execution_time: Optional[int] = None,
status: str = "SUCCESS",
exception_message: Optional[str] = None,
diff_json: Optional[str] = None
) -> OperationLogEntry:
"""记录操作日志"""
entry = OperationLogEntry(
id=str(uuid.uuid4()),
operation_time=datetime.now(),
module_name=module_name,
operation_desc=operation_desc,
operator=operator,
operator_id=operator_id,
request_method=request_method,
request_path=request_path,
request_params=request_params,
ip_address=ip_address,
execution_time=execution_time,
status=status,
exception_message=exception_message,
diff_json=diff_json
)
# 转换为字典并保存
log_dict = {
"id": entry.id,
"timestamp": time.time(),
"module_name": entry.module_name,
"operation_desc": entry.operation_desc,
"operator": entry.operator,
"operator_id": entry.operator_id,
"request_method": entry.request_method,
"request_path": entry.request_path,
"request_params": entry.request_params,
"ip_address": entry.ip_address,
"execution_time": entry.execution_time,
"status": entry.status,
"exception_message": entry.exception_message,
"diff_json": entry.diff_json,
}
self._storage.save(log_dict)
return entry
def query_logs(
self,
module_name: Optional[str] = None,
operator: Optional[str] = None,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
status: Optional[str] = None
) -> List[Dict[str, Any]]:
"""查询操作日志"""
return self._storage.query(
module_name=module_name,
operator=operator,
start_time=start_time,
end_time=end_time,
status=status
)
def get_statistics(self) -> AuditStatistics:
"""获取统计信息"""
logs = self._storage.get_all()
stats = AuditStatistics()
stats.total_operations = len(logs)
for log in logs:
status = log.get("status", "SUCCESS")
if status == "SUCCESS":
stats.success_count += 1
else:
stats.failure_count += 1
module = log.get("module_name", "unknown")
stats.module_distribution[module] = stats.module_distribution.get(module, 0) + 1
return stats
def cleanup(self, max_keep: int = 1000) -> int:
"""清理旧日志"""
return self._storage.delete_old(max_keep)
class ObjectChangeAuditor:
"""对象变更审计器(JaVers风格)"""
def compare(self, old_object: Dict[str, Any], new_object: Dict[str, Any]) -> DiffResult:
"""
比较两个对象的差异
Args:
old_object: 旧对象
new_object: 新对象
Returns:
差异结果
"""
changes = []
# 获取所有字段
all_keys = set(old_object.keys()) | set(new_object.keys())
for key in all_keys:
old_value = old_object.get(key)
new_value = new_object.get(key)
if old_value != new_value:
changes.append(ObjectChange(
field_name=key,
old_value=old_value,
new_value=new_value
))
return DiffResult(
has_changes=len(changes) > 0,
changes=changes
)
def get_changed_fields(
self,
old_object: Dict[str, Any],
new_object: Dict[str, Any]
) -> List[ObjectChange]:
"""获取变更的字段列表"""
diff_result = self.compare(old_object, new_object)
return diff_result.changes
def to_json(self, obj: Any) -> str:
"""将对象转换为JSON字符串"""
return json.dumps(obj, ensure_ascii=False, default=str)
class AuditLogExporter:
"""审计日志导出器"""
def __init__(self, recorder: OperationLogRecorder):
self._recorder = recorder
def export_to_json(
self,
module_name: Optional[str] = None,
operator: Optional[str] = None
) -> str:
"""导出为JSON格式"""
logs = self._recorder.query_logs(
module_name=module_name,
operator=operator
)
return json.dumps(logs, ensure_ascii=False, indent=2, default=str)
def export_to_csv(
self,
module_name: Optional[str] = None,
operator: Optional[str] = None
) -> str:
"""导出为CSV格式"""
logs = self._recorder.query_logs(
module_name=module_name,
operator=operator
)
if not logs:
return ""
# 获取表头
headers = ["timestamp", "module_name", "operation_desc", "operator", "status"]
# 生成CSV
lines = [",".join(headers)]
for log in logs:
values = [
str(log.get(h, "")) for h in headers
]
lines.append(",".join(values))
return "\n".join(lines)
class AuditLogRecorder:
"""统一的审计日志记录器"""
def __init__(self, storage: Optional[AuditLogStorage] = None):
self._operation_recorder = OperationLogRecorder(storage)
self._change_auditor = ObjectChangeAuditor()
def record_operation(self, **kwargs) -> OperationLogEntry:
"""记录操作日志"""
return self._operation_recorder.record(**kwargs)
def query_logs(self, **kwargs) -> List[Dict[str, Any]]:
"""查询日志"""
return self._operation_recorder.query_logs(**kwargs)
def record_change(
self,
old_object: Dict[str, Any],
new_object: Dict[str, Any],
**kwargs
) -> OperationLogEntry:
"""记录对象变更"""
# 比较差异
diff_result = self._change_auditor.compare(old_object, new_object)
# 生成差异JSON
diff_json = json.dumps([
{
"field": c.field_name,
"old": c.old_value,
"new": c.new_value
}
for c in diff_result.changes
], ensure_ascii=False)
# 记录操作日志
return self._operation_recorder.record(
diff_json=diff_json,
**kwargs
)
def query_logs(self, **kwargs) -> List[Dict[str, Any]]:
"""查询日志"""
return self._operation_recorder.query_logs(**kwargs)
def audit_log(
recorder: AuditLogRecorder,
module_name: str,
operation_desc: str
):
"""
审计日志装饰器
Args:
recorder: 审计日志记录器
module_name: 模块名称
operation_desc: 操作描述
"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs):
start_time = time.time()
status = "SUCCESS"
exception_msg = None
result = None
try:
result = func(*args, **kwargs)
return result
except Exception as e:
status = "FAILURE"
exception_msg = str(e)
raise
finally:
execution_time = int((time.time() - start_time) * 1000)
# 记录日志
recorder.record_operation(
module_name=module_name,
operation_desc=operation_desc,
operator="system", # 可以从上下文获取
request_params=json.dumps({"args": args, "kwargs": kwargs}, default=str),
execution_time=execution_time,
status=status,
exception_message=exception_msg
)
return wrapper
return decorator
@@ -0,0 +1,429 @@
"""
数据备份恢复功能模块
提供数据备份、恢复、验证和管理功能。
"""
import json
import os
import gzip
import hashlib
import shutil
import time
import uuid
from typing import Any, Callable, Dict, List, Optional
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
@dataclass
class BackupResult:
"""备份结果"""
success: bool
backup_id: Optional[str] = None
backup_path: Optional[str] = None
size: int = 0
checksum: Optional[str] = None
message: str = ""
@dataclass
class RestoreResult:
"""恢复结果"""
success: bool
data: Optional[Any] = None
message: str = ""
@dataclass
class VerifyResult:
"""验证结果"""
is_valid: bool
message: str = ""
@dataclass
class DeleteResult:
"""删除结果"""
success: bool
message: str = ""
@dataclass
class BackupInfo:
"""备份信息"""
backup_id: str
backup_name: str
description: str
created_at: float
size: int
checksum: str
is_compressed: bool
is_incremental: bool
base_backup_id: Optional[str] = None
class BackupManager:
"""
备份管理器
特性:
- 支持完整备份和增量备份
- 支持压缩
- 支持校验和验证
- 支持备份列表和查询
"""
def __init__(self, backup_dir: str):
"""
初始化备份管理器
Args:
backup_dir: 备份目录
"""
self._backup_dir = backup_dir
self._backups: Dict[str, BackupInfo] = {}
self._data_cache: Dict[str, Any] = {}
# 创建备份目录
os.makedirs(backup_dir, exist_ok=True)
def backup(
self,
data: Any,
backup_name: str,
description: str = "",
compress: bool = False
) -> BackupResult:
"""
创建备份
Args:
data: 要备份的数据
backup_name: 备份名称
description: 备份描述
compress: 是否压缩
Returns:
备份结果
"""
try:
backup_id = str(uuid.uuid4())
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"{backup_name}_{timestamp}_{backup_id}.json"
if compress:
filename += ".gz"
backup_path = os.path.join(self._backup_dir, filename)
# 序列化数据
json_data = json.dumps(data, ensure_ascii=False, default=str)
# 计算校验和
checksum = hashlib.md5(json_data.encode('utf-8')).hexdigest()
# 写入文件
if compress:
with gzip.open(backup_path, 'wt', encoding='utf-8') as f:
f.write(json_data)
else:
with open(backup_path, 'w', encoding='utf-8') as f:
f.write(json_data)
# 获取文件大小
size = os.path.getsize(backup_path)
# 记录备份信息
backup_info = BackupInfo(
backup_id=backup_id,
backup_name=backup_name,
description=description,
created_at=time.time(),
size=size,
checksum=checksum,
is_compressed=compress,
is_incremental=False
)
self._backups[backup_id] = backup_info
self._data_cache[backup_id] = data
return BackupResult(
success=True,
backup_id=backup_id,
backup_path=backup_path,
size=size,
checksum=checksum
)
except Exception as e:
return BackupResult(success=False, message=str(e))
def backup_incremental(
self,
base_backup_id: str,
data: Any,
backup_name: str,
description: str = ""
) -> BackupResult:
"""
创建增量备份
Args:
base_backup_id: 基础备份ID
data: 要备份的数据
backup_name: 备份名称
description: 备份描述
Returns:
备份结果
"""
# 简化实现:增量备份存储差异
result = self.backup(data, backup_name, description)
if result.success and result.backup_id:
# 标记为增量备份
backup_info = self._backups.get(result.backup_id)
if backup_info:
backup_info.is_incremental = True
backup_info.base_backup_id = base_backup_id
return result
def restore(self, backup_id: str) -> RestoreResult:
"""
恢复备份
Args:
backup_id: 备份ID
Returns:
恢复结果
"""
try:
# 检查缓存
if backup_id in self._data_cache:
return RestoreResult(
success=True,
data=self._data_cache[backup_id]
)
# 查找备份文件
backup_info = self._backups.get(backup_id)
if not backup_info:
return RestoreResult(success=False, message="备份不存在")
# 查找文件
backup_path = None
for filename in os.listdir(self._backup_dir):
if backup_id in filename:
backup_path = os.path.join(self._backup_dir, filename)
break
if not backup_path or not os.path.exists(backup_path):
return RestoreResult(success=False, message="备份文件不存在")
# 读取数据
if backup_info.is_compressed or backup_path.endswith('.gz'):
with gzip.open(backup_path, 'rt', encoding='utf-8') as f:
json_data = f.read()
else:
with open(backup_path, 'r', encoding='utf-8') as f:
json_data = f.read()
# 解析数据
data = json.loads(json_data)
return RestoreResult(success=True, data=data)
except Exception as e:
return RestoreResult(success=False, message=str(e))
def list_backups(
self,
name_filter: Optional[str] = None,
start_time: Optional[float] = None,
end_time: Optional[float] = None
) -> List[BackupInfo]:
"""
列出备份
Args:
name_filter: 名称过滤
start_time: 开始时间
end_time: 结束时间
Returns:
备份信息列表
"""
result = list(self._backups.values())
if name_filter:
result = [b for b in result if name_filter in b.backup_name]
if start_time:
result = [b for b in result if b.created_at >= start_time]
if end_time:
result = [b for b in result if b.created_at <= end_time]
# 按时间排序
result.sort(key=lambda x: x.created_at, reverse=True)
return result
def verify_backup(self, backup_id: str) -> VerifyResult:
"""
验证备份
Args:
backup_id: 备份ID
Returns:
验证结果
"""
try:
backup_info = self._backups.get(backup_id)
if not backup_info:
return VerifyResult(is_valid=False, message="备份不存在")
# 查找文件
backup_path = None
for filename in os.listdir(self._backup_dir):
if backup_id in filename:
backup_path = os.path.join(self._backup_dir, filename)
break
if not backup_path or not os.path.exists(backup_path):
return VerifyResult(is_valid=False, message="备份文件不存在")
# 读取并验证校验和
if backup_info.is_compressed or backup_path.endswith('.gz'):
with gzip.open(backup_path, 'rt', encoding='utf-8') as f:
json_data = f.read()
else:
with open(backup_path, 'r', encoding='utf-8') as f:
json_data = f.read()
current_checksum = hashlib.md5(json_data.encode('utf-8')).hexdigest()
if current_checksum != backup_info.checksum:
return VerifyResult(is_valid=False, message="校验和不匹配")
return VerifyResult(is_valid=True, message="备份有效")
except Exception as e:
return VerifyResult(is_valid=False, message=str(e))
def delete_backup(self, backup_id: str) -> DeleteResult:
"""
删除备份
Args:
backup_id: 备份ID
Returns:
删除结果
"""
try:
# 查找并删除文件
for filename in os.listdir(self._backup_dir):
if backup_id in filename:
file_path = os.path.join(self._backup_dir, filename)
if os.path.exists(file_path):
os.unlink(file_path)
break
# 从记录中移除
if backup_id in self._backups:
del self._backups[backup_id]
if backup_id in self._data_cache:
del self._data_cache[backup_id]
return DeleteResult(success=True)
except Exception as e:
return DeleteResult(success=False, message=str(e))
class BackupScheduler:
"""
备份调度器
支持定时自动备份
"""
def __init__(self, backup_dir: str):
"""
初始化备份调度器
Args:
backup_dir: 备份目录
"""
self._backup_manager = BackupManager(backup_dir)
self._schedules: Dict[str, Dict[str, Any]] = {}
def schedule_backup(
self,
data_source: Callable[[], Any],
backup_name: str,
interval_hours: int = 24,
keep_count: int = 5,
description: str = ""
) -> None:
"""
配置自动备份
Args:
data_source: 数据源函数
backup_name: 备份名称
interval_hours: 备份间隔(小时)
keep_count: 保留数量
description: 备份描述
"""
self._schedules[backup_name] = {
"data_source": data_source,
"interval_hours": interval_hours,
"keep_count": keep_count,
"description": description,
"last_backup_time": 0
}
def trigger_backup(self, backup_name: str) -> BackupResult:
"""
手动触发备份
Args:
backup_name: 备份名称
Returns:
备份结果
"""
schedule = self._schedules.get(backup_name)
if not schedule:
return BackupResult(success=False, message="备份计划不存在")
# 获取数据
data = schedule["data_source"]()
# 执行备份
result = self._backup_manager.backup(
data=data,
backup_name=backup_name,
description=schedule["description"]
)
if result.success:
schedule["last_backup_time"] = time.time()
# 清理旧备份
self._cleanup_old_backups(backup_name, schedule["keep_count"])
return result
def _cleanup_old_backups(self, backup_name: str, keep_count: int) -> None:
"""清理旧备份"""
backups = self._backup_manager.list_backups(name_filter=backup_name)
if len(backups) > keep_count:
# 删除最旧的备份
for backup in backups[keep_count:]:
self._backup_manager.delete_backup(backup.backup_id)
@@ -0,0 +1,145 @@
"""
缓存模块
提供内存缓存功能。
"""
import time
import threading
from typing import Dict, Any, Optional
from dataclasses import dataclass, field
@dataclass
class CacheEntry:
"""缓存条目"""
value: Any
expires_at: Optional[float] = None
created_at: float = field(default_factory=time.time)
class Cache:
"""缓存类"""
def __init__(self):
self._data: Dict[str, CacheEntry] = {}
self._lock = threading.Lock()
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
"""
设置缓存
Args:
key: 缓存键
value: 缓存值
ttl: 过期时间(秒)
"""
with self._lock:
expires_at = None
if ttl is not None:
expires_at = time.time() + ttl
self._data[key] = CacheEntry(
value=value,
expires_at=expires_at
)
def get(self, key: str) -> Optional[Any]:
"""
获取缓存
Args:
key: 缓存键
Returns:
缓存值,如果不存在或已过期则返回None
"""
with self._lock:
entry = self._data.get(key)
if entry is None:
return None
# 检查是否过期
if entry.expires_at is not None and time.time() > entry.expires_at:
del self._data[key]
return None
return entry.value
def delete(self, key: str) -> bool:
"""
删除缓存
Args:
key: 缓存键
Returns:
是否成功删除
"""
with self._lock:
if key in self._data:
del self._data[key]
return True
return False
def clear(self) -> None:
"""清除所有缓存"""
with self._lock:
self._data.clear()
def get_stats(self) -> Dict[str, Any]:
"""
获取缓存统计信息
Returns:
统计信息字典
"""
with self._lock:
# 清理过期条目
current_time = time.time()
expired_keys = [
key for key, entry in self._data.items()
if entry.expires_at is not None and current_time > entry.expires_at
]
for key in expired_keys:
del self._data[key]
return {
"size": len(self._data),
"keys": list(self._data.keys())
}
def has(self, key: str) -> bool:
"""
检查缓存是否存在
Args:
key: 缓存键
Returns:
是否存在
"""
return self.get(key) is not None
def get_all(self) -> Dict[str, Any]:
"""
获取所有缓存
Returns:
所有缓存数据
"""
with self._lock:
result = {}
current_time = time.time()
for key, entry in self._data.items():
# 检查是否过期
if entry.expires_at is not None and current_time > entry.expires_at:
continue
result[key] = entry.value
return result
# 全局缓存实例
cache = Cache()
@@ -0,0 +1,396 @@
"""
Caffeine缓存管理模块
基于Caffeine的本地缓存管理实现。
"""
import time
import threading
from typing import Any, Dict, List, Optional, Callable
from dataclasses import dataclass, field
from collections import OrderedDict
@dataclass
class CacheEntry:
"""缓存条目"""
value: Any
expires_at: Optional[float] = None
created_at: float = field(default_factory=time.time)
access_count: int = field(default=0)
last_accessed: float = field(default_factory=time.time)
class CaffeineCache:
"""
Caffeine风格的本地缓存实现
特性:
- 支持TTL过期时间
- 支持最大容量限制(LRU淘汰)
- 支持统计信息
- 线程安全
- 批量操作支持
"""
def __init__(
self,
max_size: int = 1000,
default_expire_seconds: Optional[int] = None,
record_stats: bool = False
):
"""
初始化缓存
Args:
max_size: 最大缓存条目数
default_expire_seconds: 默认过期时间(秒)
record_stats: 是否记录统计信息
"""
self._max_size = max_size
self._default_expire_seconds = default_expire_seconds
self._record_stats = record_stats
self._data: Dict[str, CacheEntry] = {}
self._lock = threading.RLock()
# 统计信息
self._stats = {
"hit_count": 0,
"miss_count": 0,
"put_count": 0,
"delete_count": 0,
"eviction_count": 0,
}
def put(self, key: str, value: Any, expire_seconds: Optional[int] = None) -> None:
"""
添加缓存条目
Args:
key: 缓存键
value: 缓存值
expire_seconds: 过期时间(秒),None表示使用默认值,0表示永不过期
"""
with self._lock:
# 计算过期时间
if expire_seconds is not None:
expires_at = time.time() + expire_seconds if expire_seconds > 0 else None
elif self._default_expire_seconds is not None:
expires_at = time.time() + self._default_expire_seconds
else:
expires_at = None
# 创建缓存条目
entry = CacheEntry(
value=value,
expires_at=expires_at
)
# 检查是否需要淘汰
if key not in self._data and len(self._data) >= self._max_size:
self._evict_oldest()
# 存储数据
self._data[key] = entry
if self._record_stats:
self._stats["put_count"] += 1
def get(self, key: str) -> Optional[Any]:
"""
获取缓存值
Args:
key: 缓存键
Returns:
缓存值,不存在或已过期则返回None
"""
with self._lock:
entry = self._data.get(key)
if entry is None:
if self._record_stats:
self._stats["miss_count"] += 1
return None
# 检查是否过期
if entry.expires_at is not None and time.time() > entry.expires_at:
del self._data[key]
if self._record_stats:
self._stats["miss_count"] += 1
self._stats["eviction_count"] += 1
return None
# 更新访问信息
entry.access_count += 1
entry.last_accessed = time.time()
if self._record_stats:
self._stats["hit_count"] += 1
return entry.value
def exists(self, key: str) -> bool:
"""
检查键是否存在
Args:
key: 缓存键
Returns:
是否存在且未过期
"""
with self._lock:
entry = self._data.get(key)
if entry is None:
return False
# 检查是否过期
if entry.expires_at is not None and time.time() > entry.expires_at:
del self._data[key]
return False
return True
def delete(self, key: str) -> bool:
"""
删除缓存条目
Args:
key: 缓存键
Returns:
是否成功删除
"""
with self._lock:
if key in self._data:
del self._data[key]
if self._record_stats:
self._stats["delete_count"] += 1
return True
return False
def get_all(self, keys: List[str]) -> Dict[str, Optional[Any]]:
"""
批量获取缓存值
Args:
keys: 缓存键列表
Returns:
键值对字典
"""
result = {}
for key in keys:
result[key] = self.get(key)
return result
def put_all(self, data: Dict[str, Any], expire_seconds: Optional[int] = None) -> None:
"""
批量添加缓存条目
Args:
data: 键值对字典
expire_seconds: 过期时间(秒)
"""
for key, value in data.items():
self.put(key, value, expire_seconds)
def delete_all(self, keys: List[str]) -> int:
"""
批量删除缓存条目
Args:
keys: 缓存键列表
Returns:
成功删除的数量
"""
count = 0
for key in keys:
if self.delete(key):
count += 1
return count
def clear(self) -> None:
"""清空所有缓存"""
with self._lock:
self._data.clear()
def get_stats(self) -> Dict[str, Any]:
"""
获取缓存统计信息
Returns:
统计信息字典
"""
with self._lock:
stats = self._stats.copy()
stats["size"] = len(self._data)
stats["max_size"] = self._max_size
# 计算命中率
total_requests = stats["hit_count"] + stats["miss_count"]
if total_requests > 0:
stats["hit_rate"] = stats["hit_count"] / total_requests
else:
stats["hit_rate"] = 0.0
return stats
def _evict_oldest(self) -> None:
"""淘汰最久未使用的条目"""
if not self._data:
return
# 找到最久未访问的条目
oldest_key = min(
self._data.keys(),
key=lambda k: self._data[k].last_accessed
)
del self._data[oldest_key]
if self._record_stats:
self._stats["eviction_count"] += 1
def size(self) -> int:
"""
获取当前缓存大小
Returns:
缓存条目数
"""
with self._lock:
return len(self._data)
def keys(self) -> List[str]:
"""
获取所有缓存键
Returns:
缓存键列表
"""
with self._lock:
return list(self._data.keys())
def values(self) -> List[Any]:
"""
获取所有缓存值
Returns:
缓存值列表
"""
with self._lock:
return [entry.value for entry in self._data.values()]
def items(self) -> Dict[str, Any]:
"""
获取所有缓存项
Returns:
键值对字典
"""
with self._lock:
return {k: v.value for k, v in self._data.items()}
def cleanup_expired(self) -> int:
"""
清理过期条目
Returns:
清理的条目数
"""
with self._lock:
current_time = time.time()
expired_keys = [
key for key, entry in self._data.items()
if entry.expires_at is not None and current_time > entry.expires_at
]
for key in expired_keys:
del self._data[key]
if self._record_stats:
self._stats["eviction_count"] += 1
return len(expired_keys)
class CaffeineCacheManager:
"""
Caffeine缓存管理器
管理多个命名缓存实例
"""
_instance = None
_lock = threading.Lock()
def __new__(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._caches = {}
return cls._instance
def get_cache(
self,
name: str,
max_size: int = 1000,
default_expire_seconds: Optional[int] = None,
record_stats: bool = False
) -> CaffeineCache:
"""
获取或创建命名缓存
Args:
name: 缓存名称
max_size: 最大缓存条目数
default_expire_seconds: 默认过期时间
record_stats: 是否记录统计信息
Returns:
缓存实例
"""
if name not in self._caches:
self._caches[name] = CaffeineCache(
max_size=max_size,
default_expire_seconds=default_expire_seconds,
record_stats=record_stats
)
return self._caches[name]
def remove_cache(self, name: str) -> bool:
"""
移除缓存
Args:
name: 缓存名称
Returns:
是否成功移除
"""
if name in self._caches:
del self._caches[name]
return True
return False
def clear_all(self) -> None:
"""清空所有缓存"""
for cache in self._caches.values():
cache.clear()
def get_all_stats(self) -> Dict[str, Dict[str, Any]]:
"""
获取所有缓存的统计信息
Returns:
缓存名称到统计信息的映射
"""
return {name: cache.get_stats() for name, cache in self._caches.items()}
# 全局缓存管理器实例
cache_manager = CaffeineCacheManager()
@@ -0,0 +1,470 @@
"""
本地并发控制模块
提供本地并发控制的各种机制,包括信号量、读写锁、限流器等。
"""
import threading
import time
from typing import Any, Dict, List, Optional, TypeVar, Generic
from contextlib import contextmanager
from collections import deque
T = TypeVar('T')
class SemaphoreControl:
"""
信号量并发控制
限制同时执行的线程数量
"""
def __init__(self, max_concurrent: int):
"""
初始化信号量
Args:
max_concurrent: 最大并发数
"""
self._max_concurrent = max_concurrent
self._semaphore = threading.Semaphore(max_concurrent)
self._stats = {
"total_acquisitions": 0,
"active_count": 0,
"peak_count": 0,
}
self._lock = threading.Lock()
@contextmanager
def acquire(self, timeout: Optional[float] = None):
"""
获取信号量(上下文管理器)
Args:
timeout: 超时时间(秒)
"""
acquired = False
try:
acquired = self._semaphore.acquire(timeout=timeout)
if not acquired:
raise TimeoutError("获取信号量超时")
with self._lock:
self._stats["total_acquisitions"] += 1
self._stats["active_count"] += 1
self._stats["peak_count"] = max(
self._stats["peak_count"],
self._stats["active_count"]
)
yield self
finally:
if acquired:
with self._lock:
self._stats["active_count"] -= 1
self._semaphore.release()
def get_stats(self) -> Dict[str, Any]:
"""获取统计信息"""
with self._lock:
return self._stats.copy()
class ReadWriteLock:
"""
读写锁
支持多个读线程同时访问,写线程独占访问
"""
def __init__(self):
self._read_lock = threading.Lock()
self._write_lock = threading.Lock()
self._read_count = 0
@contextmanager
def read_lock(self):
"""获取读锁"""
with self._read_lock:
self._read_count += 1
if self._read_count == 1:
self._write_lock.acquire()
try:
yield self
finally:
with self._read_lock:
self._read_count -= 1
if self._read_count == 0:
self._write_lock.release()
@contextmanager
def write_lock(self):
"""获取写锁"""
self._write_lock.acquire()
try:
yield self
finally:
self._write_lock.release()
class RateLimiter:
"""
限流器
限制单位时间内的请求数量
"""
def __init__(self, max_requests: int, time_window: float):
"""
初始化限流器
Args:
max_requests: 时间窗口内最大请求数
time_window: 时间窗口(秒)
"""
self._max_requests = max_requests
self._time_window = time_window
self._requests = deque()
self._lock = threading.Lock()
self._stats = {
"total_requests": 0,
"allowed_requests": 0,
"blocked_requests": 0,
}
def allow_request(self) -> bool:
"""
检查是否允许请求
Returns:
是否允许
"""
with self._lock:
now = time.time()
# 清理过期的请求记录
while self._requests and self._requests[0] < now - self._time_window:
self._requests.popleft()
self._stats["total_requests"] += 1
if len(self._requests) < self._max_requests:
self._requests.append(now)
self._stats["allowed_requests"] += 1
return True
else:
self._stats["blocked_requests"] += 1
return False
def get_stats(self) -> Dict[str, Any]:
"""获取统计信息"""
with self._lock:
return self._stats.copy()
class LocalDistributedLock:
"""
本地模拟的分布式锁
用于单体应用内的分布式锁场景模拟
"""
_locks: Dict[str, Dict[str, Any]] = {}
_global_lock = threading.Lock()
def __init__(self, resource_name: str, expire_seconds: float = 30.0):
"""
初始化分布式锁
Args:
resource_name: 资源名称
expire_seconds: 锁过期时间(秒)
"""
self._resource_name = resource_name
self._expire_seconds = expire_seconds
self._lock = threading.Lock()
def acquire(self, timeout: float = 10.0) -> bool:
"""
获取锁
Args:
timeout: 超时时间(秒)
"""
start_time = time.time()
while time.time() - start_time < timeout:
with self._global_lock:
now = time.time()
lock_info = self._locks.get(self._resource_name)
# 检查锁是否过期
if lock_info and now > lock_info["expires_at"]:
del self._locks[self._resource_name]
lock_info = None
# 尝试获取锁
if not lock_info:
self._locks[self._resource_name] = {
"expires_at": now + self._expire_seconds,
}
return True
time.sleep(0.1)
return False
def release(self) -> None:
"""释放锁"""
with self._global_lock:
if self._resource_name in self._locks:
del self._locks[self._resource_name]
class ConcurrentCounter:
"""
并发计数器
线程安全的计数器实现
"""
def __init__(self, initial_value: int = 0):
"""
初始化计数器
Args:
initial_value: 初始值
"""
self._value = initial_value
self._lock = threading.Lock()
def increment(self, delta: int = 1) -> int:
"""
增加计数
Args:
delta: 增量
Returns:
增加后的值
"""
with self._lock:
self._value += delta
return self._value
def decrement(self, delta: int = 1) -> int:
"""
减少计数
Args:
delta: 减量
Returns:
减少后的值
"""
with self._lock:
self._value -= delta
return self._value
def get_value(self) -> int:
"""获取当前值"""
with self._lock:
return self._value
def reset(self, value: int = 0) -> None:
"""重置计数器"""
with self._lock:
self._value = value
class ThreadBarrier:
"""
线程屏障
等待指定数量的线程到达后同时放行
"""
def __init__(self, parties: int):
"""
初始化屏障
Args:
parties: 需要等待的线程数
"""
self._parties = parties
self._count = 0
self._lock = threading.Lock()
self._condition = threading.Condition(self._lock)
def wait(self, timeout: Optional[float] = None) -> bool:
"""
等待屏障
Args:
timeout: 超时时间(秒)
Returns:
是否成功等待
"""
with self._condition:
self._count += 1
if self._count >= self._parties:
# 所有线程已到达,放行
self._count = 0
self._condition.notify_all()
return True
else:
# 等待其他线程
return self._condition.wait(timeout=timeout)
class BoundedTaskQueue(Generic[T]):
"""
有界任务队列
有容量限制的线程安全队列
"""
def __init__(self, max_size: int):
"""
初始化队列
Args:
max_size: 最大容量
"""
self._max_size = max_size
self._queue = deque(maxlen=max_size)
self._lock = threading.Lock()
self._not_full = threading.Condition(self._lock)
self._not_empty = threading.Condition(self._lock)
def put(self, item: T, timeout: Optional[float] = None) -> bool:
"""
添加元素
Args:
item: 元素
timeout: 超时时间(秒)
Returns:
是否成功
"""
with self._not_full:
if len(self._queue) >= self._max_size:
if not self._not_full.wait(timeout=timeout):
raise TimeoutError("队列已满")
self._queue.append(item)
self._not_empty.notify()
return True
def get(self, timeout: Optional[float] = None) -> T:
"""
获取元素
Args:
timeout: 超时时间(秒)
Returns:
元素
"""
with self._not_empty:
while len(self._queue) == 0:
if not self._not_empty.wait(timeout=timeout):
raise TimeoutError("队列为空")
item = self._queue.popleft()
self._not_full.notify()
return item
def size(self) -> int:
"""获取队列大小"""
with self._lock:
return len(self._queue)
def is_full(self) -> bool:
"""检查是否已满"""
with self._lock:
return len(self._queue) >= self._max_size
def is_empty(self) -> bool:
"""检查是否为空"""
with self._lock:
return len(self._queue) == 0
class ConcurrencyManager:
"""
并发控制器管理器
单例模式管理所有并发控制组件
"""
_instance = None
_lock = threading.Lock()
def __new__(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._semaphores: Dict[str, SemaphoreControl] = {}
cls._instance._rw_locks: Dict[str, ReadWriteLock] = {}
cls._instance._rate_limiters: Dict[str, RateLimiter] = {}
cls._instance._locks: Dict[str, LocalDistributedLock] = {}
return cls._instance
def create_semaphore(self, name: str, max_concurrent: int) -> SemaphoreControl:
"""创建命名信号量"""
if name not in self._semaphores:
self._semaphores[name] = SemaphoreControl(max_concurrent)
return self._semaphores[name]
def get_semaphore(self, name: str) -> Optional[SemaphoreControl]:
"""获取命名信号量"""
return self._semaphores.get(name)
def create_rw_lock(self, name: str) -> ReadWriteLock:
"""创建命名读写锁"""
if name not in self._rw_locks:
self._rw_locks[name] = ReadWriteLock()
return self._rw_locks[name]
def get_rw_lock(self, name: str) -> Optional[ReadWriteLock]:
"""获取命名读写锁"""
return self._rw_locks.get(name)
def create_rate_limiter(self, name: str, max_requests: int, time_window: float) -> RateLimiter:
"""创建命名限流器"""
if name not in self._rate_limiters:
self._rate_limiters[name] = RateLimiter(max_requests, time_window)
return self._rate_limiters[name]
def get_rate_limiter(self, name: str) -> Optional[RateLimiter]:
"""获取命名限流器"""
return self._rate_limiters.get(name)
def create_lock(self, name: str, expire_seconds: float = 30.0) -> LocalDistributedLock:
"""创建命名分布式锁"""
if name not in self._locks:
self._locks[name] = LocalDistributedLock(name, expire_seconds)
return self._locks[name]
def get_lock(self, name: str) -> Optional[LocalDistributedLock]:
"""获取命名分布式锁"""
return self._locks.get(name)
def get_all_stats(self) -> Dict[str, Dict[str, Any]]:
"""获取所有组件统计信息"""
return {
"semaphores": {name: sem.get_stats() for name, sem in self._semaphores.items()},
"rate_limiters": {name: rl.get_stats() for name, rl in self._rate_limiters.items()},
}
# 全局并发管理器实例
concurrency_manager = ConcurrencyManager()
@@ -0,0 +1,193 @@
"""
配置管理器模块
提供统一的测试环境配置管理,支持多环境配置。
"""
import os
import yaml
from typing import Dict, Any, Optional
from dataclasses import dataclass, field
from pathlib import Path
@dataclass
class TimeoutConfig:
"""超时配置"""
default: int = 30000
navigation: int = 30000
element: int = 10000
network: int = 30000
@dataclass
class ScreenshotConfig:
"""截图配置"""
enabled: bool = True
on_failure: bool = True
path: str = "reports/screenshots"
@dataclass
class VideoConfig:
"""录像配置"""
enabled: bool = False
on_failure: bool = True
path: str = "reports/videos"
@dataclass
class TraceConfig:
"""追踪配置"""
enabled: bool = False
on_failure: bool = True
path: str = "reports/traces"
@dataclass
class BrowserConfig:
"""浏览器配置"""
name: str = "chromium"
headless: bool = False
viewport_width: int = 1920
viewport_height: int = 1080
slow_mo: int = 0
@dataclass
class TestConfig:
"""测试配置"""
# 环境信息
env: str = "dev"
base_url: str = ""
api_url: str = ""
# 超时配置
timeout: TimeoutConfig = field(default_factory=TimeoutConfig)
# 截图配置
screenshot: ScreenshotConfig = field(default_factory=ScreenshotConfig)
# 录像配置
video: VideoConfig = field(default_factory=VideoConfig)
# 追踪配置
trace: TraceConfig = field(default_factory=TraceConfig)
# 浏览器配置
browser: BrowserConfig = field(default_factory=BrowserConfig)
# 重试配置
retries: int = 2
# 并行配置
workers: int = 1
parallel: bool = False
# 测试数据
test_data: Dict[str, Any] = field(default_factory=dict)
# 用户数据
users: Dict[str, Any] = field(default_factory=dict)
class ConfigManager:
"""配置管理器"""
_instance: Optional["ConfigManager"] = None
_config: Optional[TestConfig] = None
def __new__(cls) -> "ConfigManager":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if self._config is None:
self._config = self._load_config()
def _load_config(self) -> TestConfig:
"""加载配置"""
# 获取环境变量
env = os.getenv("TEST_ENV", "dev")
# 配置文件路径
config_path = Path(__file__).parent.parent / "config" / "config.yaml"
# 默认配置
config = TestConfig(env=env)
# 从文件加载配置
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
if data and "environments" in data:
env_config = data["environments"].get(env, {})
# 基础配置 - 支持嵌套的admin/uniapp配置
if "admin" in env_config:
config.base_url = env_config["admin"].get("base_url", "")
config.api_url = env_config["admin"].get("api_url", "")
else:
config.base_url = env_config.get("base_url", "")
config.api_url = env_config.get("api_url", "")
# 超时配置
if "timeout" in env_config:
timeout = env_config["timeout"]
config.timeout = TimeoutConfig(
default=timeout.get("default", 30000),
navigation=timeout.get("navigation", 30000),
element=timeout.get("element", 10000),
network=timeout.get("network", 30000),
)
# 浏览器配置
if "browser" in env_config:
browser = env_config["browser"]
config.browser = BrowserConfig(
name=browser.get("name", "chromium"),
headless=browser.get("headless", False),
viewport_width=browser.get("viewport_width", 1920),
viewport_height=browser.get("viewport_height", 1080),
slow_mo=browser.get("slow_mo", 0),
)
# 用户数据
if "users" in data:
config.users = data["users"]
# 测试数据
if "test_data" in data:
config.test_data = data["test_data"]
# 从环境变量覆盖配置
config.base_url = os.getenv("TEST_BASE_URL", config.base_url)
config.api_url = os.getenv("TEST_API_URL", config.api_url)
config.browser.headless = os.getenv("TEST_HEADLESS", "false").lower() == "true"
return config
@property
def config(self) -> TestConfig:
"""获取配置"""
return self._config
def get_config(self) -> TestConfig:
"""获取配置(兼容方法)"""
return self._config
def reload(self) -> None:
"""重新加载配置"""
self._config = self._load_config()
def update_config(self, **kwargs) -> None:
"""更新配置"""
for key, value in kwargs.items():
if hasattr(self._config, key):
setattr(self._config, key, value)
# 全局配置管理器实例
config_manager = ConfigManager()
@@ -0,0 +1,452 @@
"""
数据库连接池管理模块
提供数据库连接池的创建、管理和监控功能。
"""
import time
import threading
import uuid
from typing import Any, Dict, List, Optional, Callable
from dataclasses import dataclass, field
from enum import Enum
import queue
class ConnectionStatus(Enum):
"""连接状态"""
IDLE = "idle"
ACTIVE = "active"
CLOSED = "closed"
UNHEALTHY = "unhealthy"
@dataclass
class Connection:
"""数据库连接封装"""
id: str = field(default_factory=lambda: str(uuid.uuid4()))
status: ConnectionStatus = ConnectionStatus.IDLE
created_at: float = field(default_factory=time.time)
last_used: float = field(default_factory=time.time)
use_count: int = 0
host: str = ""
port: int = 3306
database: str = ""
user: str = ""
password: str = ""
def is_valid(self) -> bool:
"""检查连接是否有效"""
return self.status in [ConnectionStatus.IDLE, ConnectionStatus.ACTIVE]
def execute(self, query: str) -> Any:
"""执行查询(模拟)"""
if not self.is_valid():
raise Exception("连接无效")
return f"Result of: {query}"
def close(self) -> None:
"""关闭连接"""
self.status = ConnectionStatus.CLOSED
class ConnectionPool:
"""
数据库连接池
特性:
- 支持最小/最大连接数配置
- 支持连接超时等待
- 支持健康检查
- 支持自动扩容
- 线程安全
"""
def __init__(
self,
min_connections: int = 2,
max_connections: int = 10,
host: str = "localhost",
port: int = 3306,
database: str = "",
user: str = "",
password: str = "",
connection_timeout: int = 30,
health_check_interval: int = 60,
auto_scale: bool = False
):
"""
初始化连接池
Args:
min_connections: 最小连接数
max_connections: 最大连接数
host: 数据库主机
port: 数据库端口
database: 数据库名
user: 用户名
password: 密码
connection_timeout: 连接超时时间(秒)
health_check_interval: 健康检查间隔(秒)
auto_scale: 是否自动扩容
"""
self._min_connections = min_connections
self._max_connections = max_connections
self._host = host
self._port = port
self._database = database
self._user = user
self._password = password
self._connection_timeout = connection_timeout
self._health_check_interval = health_check_interval
self._auto_scale = auto_scale
# 连接池
self._idle_connections: queue.Queue[Connection] = queue.Queue()
self._active_connections: Dict[str, Connection] = {}
self._all_connections: Dict[str, Connection] = {}
# 锁
self._lock = threading.RLock()
self._condition = threading.Condition(self._lock)
# 统计信息
self._stats = {
"total_get_count": 0,
"total_release_count": 0,
"total_wait_count": 0,
"total_wait_time": 0.0,
"health_check_count": 0,
"unhealthy_count": 0,
}
# 健康检查线程
self._health_check_thread: Optional[threading.Thread] = None
self._shutdown = False
# 初始化最小连接数
self._initialize_min_connections()
# 启动健康检查
if health_check_interval > 0:
self._start_health_check()
def _initialize_min_connections(self) -> None:
"""初始化最小连接数"""
for _ in range(self._min_connections):
conn = self._create_connection()
self._idle_connections.put(conn)
self._all_connections[conn.id] = conn
def _create_connection(self) -> Connection:
"""创建新连接"""
conn = Connection(
host=self._host,
port=self._port,
database=self._database,
user=self._user,
password=self._password
)
return conn
def get_connection(self, timeout: Optional[int] = None) -> Connection:
"""
获取连接
Args:
timeout: 超时时间(秒)None表示使用默认值
Returns:
数据库连接
Raises:
Exception: 超时或连接池已关闭
"""
if timeout is None:
timeout = self._connection_timeout
with self._condition:
self._stats["total_get_count"] += 1
# 尝试获取空闲连接
while not self._shutdown:
# 如果有空闲连接,直接返回
try:
conn = self._idle_connections.get_nowait()
if conn.is_valid():
conn.status = ConnectionStatus.ACTIVE
conn.last_used = time.time()
conn.use_count += 1
self._active_connections[conn.id] = conn
return conn
else:
# 连接无效,移除
self._remove_connection(conn)
except queue.Empty:
pass
# 如果没有空闲连接,尝试创建新连接
if len(self._all_connections) < self._max_connections:
conn = self._create_connection()
conn.status = ConnectionStatus.ACTIVE
conn.last_used = time.time()
conn.use_count += 1
self._active_connections[conn.id] = conn
self._all_connections[conn.id] = conn
return conn
# 如果达到最大连接数,等待
self._stats["total_wait_count"] += 1
start_wait = time.time()
if not self._condition.wait(timeout=timeout):
raise Exception(f"获取连接超时({timeout}秒)")
self._stats["total_wait_time"] += time.time() - start_wait
raise Exception("连接池已关闭")
def release_connection(self, conn: Connection) -> None:
"""
释放连接
Args:
conn: 要释放的连接
"""
with self._condition:
if conn.id in self._active_connections:
del self._active_connections[conn.id]
if conn.is_valid():
conn.status = ConnectionStatus.IDLE
conn.last_used = time.time()
self._idle_connections.put(conn)
self._stats["total_release_count"] += 1
else:
# 连接无效,移除
self._remove_connection(conn)
# 通知等待的线程
self._condition.notify()
def _remove_connection(self, conn: Connection) -> None:
"""移除连接"""
conn.close()
if conn.id in self._all_connections:
del self._all_connections[conn.id]
if conn.id in self._active_connections:
del self._active_connections[conn.id]
def close(self) -> None:
"""关闭连接池"""
self._shutdown = True
with self._condition:
# 关闭所有连接
for conn in self._all_connections.values():
conn.close()
self._all_connections.clear()
self._active_connections.clear()
# 清空空闲队列
while not self._idle_connections.empty():
try:
self._idle_connections.get_nowait()
except queue.Empty:
break
# 通知所有等待的线程
self._condition.notify_all()
def get_stats(self) -> Dict[str, Any]:
"""
获取统计信息
Returns:
统计信息字典
"""
with self._lock:
return {
"total_connections": len(self._all_connections),
"idle_connections": self._idle_connections.qsize(),
"active_connections": len(self._active_connections),
"min_connections": self._min_connections,
"max_connections": self._max_connections,
"total_get_count": self._stats["total_get_count"],
"total_release_count": self._stats["total_release_count"],
"total_wait_count": self._stats["total_wait_count"],
"total_wait_time": self._stats["total_wait_time"],
"health_check_count": self._stats["health_check_count"],
"unhealthy_count": self._stats["unhealthy_count"],
}
def health_check(self) -> bool:
"""
执行健康检查
Returns:
是否所有连接都健康
"""
with self._lock:
self._stats["health_check_count"] += 1
unhealthy_count = 0
for conn in list(self._all_connections.values()):
if not conn.is_valid():
unhealthy_count += 1
self._remove_connection(conn)
self._stats["unhealthy_count"] += unhealthy_count
# 补充最小连接数
while len(self._all_connections) < self._min_connections:
conn = self._create_connection()
self._idle_connections.put(conn)
self._all_connections[conn.id] = conn
return unhealthy_count == 0
def get_health_stats(self) -> Dict[str, Any]:
"""
获取健康统计
Returns:
健康统计字典
"""
with self._lock:
healthy = sum(1 for conn in self._all_connections.values() if conn.is_valid())
unhealthy = len(self._all_connections) - healthy
return {
"healthy_connections": healthy,
"unhealthy_connections": unhealthy,
"health_check_count": self._stats["health_check_count"],
"total_unhealthy_count": self._stats["unhealthy_count"],
}
def _start_health_check(self) -> None:
"""启动健康检查线程"""
def health_check_worker():
while not self._shutdown:
time.sleep(self._health_check_interval)
if not self._shutdown:
self.health_check()
self._health_check_thread = threading.Thread(
target=health_check_worker,
daemon=True
)
self._health_check_thread.start()
class ConnectionPoolManager:
"""
连接池管理器
单例模式管理多个命名连接池
"""
_instance = None
_lock = threading.Lock()
def __new__(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._pools = {}
return cls._instance
def create_pool(
self,
name: str,
min_connections: int = 2,
max_connections: int = 10,
host: str = "localhost",
port: int = 3306,
database: str = "",
user: str = "",
password: str = "",
**kwargs
) -> ConnectionPool:
"""
创建命名连接池
Args:
name: 连接池名称
min_connections: 最小连接数
max_connections: 最大连接数
host: 数据库主机
port: 数据库端口
database: 数据库名
user: 用户名
password: 密码
**kwargs: 其他参数
Returns:
连接池实例
"""
if name in self._pools:
return self._pools[name]
pool = ConnectionPool(
min_connections=min_connections,
max_connections=max_connections,
host=host,
port=port,
database=database,
user=user,
password=password,
**kwargs
)
self._pools[name] = pool
return pool
def get_pool(self, name: str) -> Optional[ConnectionPool]:
"""
获取命名连接池
Args:
name: 连接池名称
Returns:
连接池实例,不存在则返回None
"""
return self._pools.get(name)
def remove_pool(self, name: str) -> bool:
"""
移除连接池
Args:
name: 连接池名称
Returns:
是否成功移除
"""
if name in self._pools:
self._pools[name].close()
del self._pools[name]
return True
return False
def close_all(self) -> None:
"""关闭所有连接池"""
for pool in self._pools.values():
pool.close()
self._pools.clear()
def get_all_stats(self) -> Dict[str, Dict[str, Any]]:
"""
获取所有连接池统计
Returns:
连接池名称到统计信息的映射
"""
return {name: pool.get_stats() for name, pool in self._pools.items()}
# 全局连接池管理器实例
pool_manager = ConnectionPoolManager()
@@ -0,0 +1,317 @@
"""
数据导入导出功能模块
提供Excel/CSV数据的导入导出功能。
"""
import csv
import os
import re
from typing import Any, Dict, List, Optional, Callable
from dataclasses import dataclass, field
@dataclass
class ExportResult:
"""导出结果"""
success: bool
file_path: Optional[str] = None
record_count: int = 0
message: str = ""
@dataclass
class ImportResult:
"""导入结果"""
success: bool
data: List[Dict[str, Any]] = field(default_factory=list)
record_count: int = 0
error_count: int = 0
message: str = ""
@dataclass
class ValidationResult:
"""验证结果"""
is_valid: bool
errors: List[str] = field(default_factory=list)
class CSVExporter:
"""CSV导出器"""
def export(self, data: List[Dict[str, Any]], file_path: str, encoding: str = 'utf-8') -> ExportResult:
"""
导出数据为CSV
Args:
data: 数据列表
file_path: 输出文件路径
encoding: 文件编码
Returns:
导出结果
"""
try:
if not data:
return ExportResult(success=True, file_path=file_path, record_count=0)
# 获取表头
headers = list(data[0].keys())
with open(file_path, 'w', newline='', encoding=encoding) as f:
writer = csv.DictWriter(f, fieldnames=headers)
writer.writeheader()
writer.writerows(data)
return ExportResult(
success=True,
file_path=file_path,
record_count=len(data)
)
except Exception as e:
return ExportResult(success=False, message=str(e))
class CSVImporter:
"""CSV导入器"""
def import_file(self, file_path: str, encoding: str = 'utf-8') -> ImportResult:
"""
从CSV文件导入数据
Args:
file_path: CSV文件路径
encoding: 文件编码
Returns:
导入结果
"""
try:
data = []
with open(file_path, 'r', encoding=encoding) as f:
reader = csv.DictReader(f)
for row in reader:
data.append(dict(row))
return ImportResult(
success=True,
data=data,
record_count=len(data)
)
except Exception as e:
return ImportResult(success=False, message=str(e))
class ExcelExporter:
"""Excel导出器(简化版,实际使用时需要openpyxl库)"""
def export(self, data: List[Dict[str, Any]], file_path: str) -> ExportResult:
"""
导出数据为Excel(实际实现需要openpyxl
Args:
data: 数据列表
file_path: 输出文件路径
Returns:
导出结果
"""
try:
# 简化实现:导出为CSV格式但使用.xlsx扩展名
# 实际项目中应该使用openpyxl或pandas
csv_path = file_path.replace('.xlsx', '.csv')
exporter = CSVExporter()
result = exporter.export(data, csv_path)
if result.success:
# 重命名为xlsx
os.rename(csv_path, file_path)
return ExportResult(
success=True,
file_path=file_path,
record_count=len(data)
)
else:
return result
except Exception as e:
return ExportResult(success=False, message=str(e))
class DataValidator:
"""数据验证器"""
def validate(self, data: Dict[str, Any], rules: Dict[str, Dict[str, Any]]) -> ValidationResult:
"""
验证数据
Args:
data: 要验证的数据
rules: 验证规则
Returns:
验证结果
"""
errors = []
for field, rule in rules.items():
value = data.get(field)
# 必填验证
if rule.get('required') and not value:
errors.append(f"{field}: 必填字段不能为空")
continue
if not value:
continue
# 类型验证
field_type = rule.get('type')
if field_type == 'integer':
try:
int(value)
except (ValueError, TypeError):
errors.append(f"{field}: 必须是整数")
continue
elif field_type == 'email':
if not re.match(r'^[\w\.-]+@[\w\.-]+\.\w+$', str(value)):
errors.append(f"{field}: 邮箱格式不正确")
continue
# 范围验证
if field_type == 'integer' and isinstance(value, (int, str)):
try:
int_value = int(value)
min_val = rule.get('min')
max_val = rule.get('max')
if min_val is not None and int_value < min_val:
errors.append(f"{field}: 不能小于{min_val}")
if max_val is not None and int_value > max_val:
errors.append(f"{field}: 不能大于{max_val}")
except (ValueError, TypeError):
pass
return ValidationResult(is_valid=len(errors) == 0, errors=errors)
class DataTransformer:
"""数据转换器"""
def transform(self, data: List[Dict[str, Any]], mapping: Dict[str, str]) -> List[Dict[str, Any]]:
"""
转换数据字段映射
Args:
data: 原始数据
mapping: 字段映射 {源字段: 目标字段}
Returns:
转换后的数据
"""
result = []
for item in data:
new_item = {}
for source_field, target_field in mapping.items():
if source_field in item:
new_item[target_field] = item[source_field]
result.append(new_item)
return result
class TemplateManager:
"""模板管理器"""
def generate_template(self, template: Dict[str, Any], file_path: str) -> ExportResult:
"""
生成模板文件
Args:
template: 模板定义
file_path: 输出文件路径
Returns:
生成结果
"""
try:
columns = template.get('columns', [])
headers = [col['name'] for col in columns]
with open(file_path, 'w', newline='', encoding='utf-8') as f:
writer = csv.writer(f)
writer.writerow(headers)
return ExportResult(
success=True,
file_path=file_path,
record_count=0
)
except Exception as e:
return ExportResult(success=False, message=str(e))
class DataImportExportManager:
"""数据导入导出管理器"""
def __init__(self):
"""初始化管理器"""
self._csv_exporter = CSVExporter()
self._csv_importer = CSVImporter()
self._stats = {
'total_exports': 0,
'total_imports': 0,
'total_records_exported': 0,
'total_records_imported': 0,
}
def export_batch(
self,
data: List[Dict[str, Any]],
file_path: str,
batch_size: int = 1000
) -> ExportResult:
"""
批量导出数据
Args:
data: 数据列表
file_path: 输出文件路径
batch_size: 批次大小
Returns:
导出结果
"""
result = self._csv_exporter.export(data, file_path)
if result.success:
self._stats['total_exports'] += 1
self._stats['total_records_exported'] += result.record_count
return result
def import_batch(self, file_path: str, batch_size: int = 1000) -> ImportResult:
"""
批量导入数据
Args:
file_path: 文件路径
batch_size: 批次大小
Returns:
导入结果
"""
result = self._csv_importer.import_file(file_path)
if result.success:
self._stats['total_imports'] += 1
self._stats['total_records_imported'] += result.record_count
return result
def get_statistics(self) -> Dict[str, Any]:
"""
获取统计信息
Returns:
统计信息字典
"""
return self._stats.copy()
@@ -0,0 +1,124 @@
"""
异常处理器模块
提供测试异常分类处理功能,区分致命错误和可重试错误。
"""
from typing import Optional, List
from playwright.sync_api import Page
class FatalTestError(Exception):
"""致命测试错误"""
pass
class RetryableError(Exception):
"""可重试错误"""
pass
class TestExceptionHandler:
"""测试异常处理器"""
# 致命错误:立即停止测试
FATAL_ERRORS: List[str] = [
"Browser crashed",
"Connection refused",
"Target page closed",
"Database connection failed",
"Session deleted",
"invalid session id",
]
# 可重试错误:自动重试
RETRYABLE_ERRORS: List[str] = [
"TimeoutError",
"Timeout exceeded",
"Element not found",
"Network error",
"net::ERR",
"Stale element reference",
"element is detached",
"Execution context was destroyed",
"Unable to locate element",
"waiting for locator",
]
# 非致命错误:记录但继续
NON_FATAL_ERRORS: List[str] = [
"Screenshot failed",
"Log write failed",
"Screenshot is not supported",
]
@classmethod
def handle_exception(
cls,
error: Exception,
page: Optional[Page] = None,
test_name: str = "",
screenshot_helper=None,
) -> Optional[Exception]:
"""
处理测试异常
Args:
error: 异常对象
page: Playwright页面对象
test_name: 测试名称
screenshot_helper: 截图辅助工具
Returns:
None: 非致命错误,继续执行
RetryableError: 可重试错误
FatalTestError: 致命错误,停止测试
"""
error_msg = str(error)
error_type = type(error).__name__
# 致命错误
if any(fatal in error_msg for fatal in cls.FATAL_ERRORS):
if screenshot_helper and page:
screenshot_helper.take_screenshot(
page, f"{test_name}_fatal_error" if test_name else "fatal_error"
)
raise FatalTestError(f"测试遇到致命错误 [{error_type}]: {error_msg}")
# 可重试错误
if any(retryable in error_msg for retryable in cls.RETRYABLE_ERRORS):
if screenshot_helper and page:
screenshot_helper.take_screenshot(
page,
f"{test_name}_retryable_error" if test_name else "retryable_error",
)
return RetryableError(f"可重试错误 [{error_type}]: {error_msg}")
# 非致命错误
if any(non_fatal in error_msg for non_fatal in cls.NON_FATAL_ERRORS):
# 仅记录错误,不抛出
return None
# 未知错误 - 根据错误类型判断
if "assert" in error_msg.lower():
# 断言错误通常是测试逻辑问题,不重试
raise error
# 其他错误视为可重试
if screenshot_helper and page:
screenshot_helper.take_screenshot(
page, f"{test_name}_unknown_error" if test_name else "unknown_error"
)
return RetryableError(f"未知错误 [{error_type}]: {error_msg}")
@classmethod
def is_fatal_error(cls, error: Exception) -> bool:
"""判断是否为致命错误"""
error_msg = str(error)
return any(fatal in error_msg for fatal in cls.FATAL_ERRORS)
@classmethod
def is_retryable_error(cls, error: Exception) -> bool:
"""判断是否为可重试错误"""
error_msg = str(error)
return any(retryable in error_msg for retryable in cls.RETRYABLE_ERRORS)
@@ -0,0 +1,40 @@
"""
自定义异常类
提供测试框架的自定义异常。
"""
class TestFrameworkError(Exception):
"""测试框架基础异常"""
pass
class APIError(TestFrameworkError):
"""API错误异常"""
pass
class APITimeoutError(APIError):
"""API超时异常"""
pass
class ValidationError(TestFrameworkError):
"""验证错误异常"""
pass
class ElementNotFoundError(TestFrameworkError):
"""元素未找到异常"""
pass
class TimeoutError(TestFrameworkError):
"""超时异常"""
pass
class ConfigurationError(TestFrameworkError):
"""配置错误异常"""
pass
@@ -0,0 +1,358 @@
"""
文件上传下载功能模块
提供文件上传、下载、验证和管理功能。
"""
import os
import re
import uuid
import shutil
from typing import Any, Dict, List, Optional, BinaryIO
from dataclasses import dataclass, field
from pathlib import Path
@dataclass
class UploadResult:
"""上传结果"""
success: bool
file_id: Optional[str] = None
file_path: Optional[str] = None
filename: Optional[str] = None
size: int = 0
message: str = ""
@dataclass
class DownloadResult:
"""下载结果"""
success: bool
content: Optional[bytes] = None
filename: Optional[str] = None
message: str = ""
class FileTypeValidator:
"""文件类型验证器"""
def __init__(self, allowed_extensions: Optional[List[str]] = None):
"""
初始化文件类型验证器
Args:
allowed_extensions: 允许的文件扩展名列表
"""
self._allowed_extensions = allowed_extensions or []
def validate(self, filename: str) -> bool:
"""
验证文件类型
Args:
filename: 文件名
Returns:
是否允许
"""
if not self._allowed_extensions:
return True
ext = Path(filename).suffix.lower()
return ext in self._allowed_extensions
class FileSizeValidator:
"""文件大小验证器"""
def __init__(self, max_size: int):
"""
初始化文件大小验证器
Args:
max_size: 最大文件大小(字节)
"""
self._max_size = max_size
def validate(self, size: int) -> bool:
"""
验证文件大小
Args:
size: 文件大小(字节)
Returns:
是否允许
"""
return size <= self._max_size
class FilenameSanitizer:
"""文件名净化器"""
# 危险字符
DANGEROUS_CHARS = r'[;|&$<>\`\\]'
def sanitize(self, filename: str) -> str:
"""
净化文件名
Args:
filename: 原始文件名
Returns:
安全的文件名
"""
# 移除路径遍历
filename = os.path.basename(filename)
# 移除危险字符
filename = re.sub(self.DANGEROUS_CHARS, '', filename)
# 移除连续的点
filename = re.sub(r'\.{2,}', '.', filename)
# 确保不为空
if not filename or filename == '.':
filename = 'unnamed_file'
return filename
class FileStorageManager:
"""文件存储管理器"""
def __init__(self, storage_dir: str):
"""
初始化存储管理器
Args:
storage_dir: 存储目录
"""
self._storage_dir = storage_dir
self._metadata: Dict[str, Dict[str, Any]] = {}
# 创建存储目录
os.makedirs(storage_dir, exist_ok=True)
def save(
self,
content: bytes,
filename: str,
metadata: Optional[Dict[str, Any]] = None
) -> str:
"""
保存文件
Args:
content: 文件内容
filename: 文件名
metadata: 元数据
Returns:
文件ID
"""
file_id = str(uuid.uuid4())
file_path = os.path.join(self._storage_dir, file_id)
with open(file_path, 'wb') as f:
f.write(content)
# 保存元数据
self._metadata[file_id] = {
'filename': filename,
'size': len(content),
'metadata': metadata or {},
}
return file_id
def get(self, file_id: str) -> Optional[bytes]:
"""
获取文件内容
Args:
file_id: 文件ID
Returns:
文件内容
"""
file_path = os.path.join(self._storage_dir, file_id)
if not os.path.exists(file_path):
return None
with open(file_path, 'rb') as f:
return f.read()
def get_metadata(self, file_id: str) -> Optional[Dict[str, Any]]:
"""
获取文件元数据
Args:
file_id: 文件ID
Returns:
元数据
"""
meta = self._metadata.get(file_id)
if meta:
return meta.get('metadata')
return None
def delete(self, file_id: str) -> bool:
"""
删除文件
Args:
file_id: 文件ID
Returns:
是否成功
"""
file_path = os.path.join(self._storage_dir, file_id)
if os.path.exists(file_path):
os.unlink(file_path)
if file_id in self._metadata:
del self._metadata[file_id]
return True
return False
class FileUploader:
"""文件上传器"""
def __init__(
self,
upload_dir: str,
type_validator: Optional[FileTypeValidator] = None,
size_validator: Optional[FileSizeValidator] = None,
filename_sanitizer: Optional[FilenameSanitizer] = None
):
"""
初始化文件上传器
Args:
upload_dir: 上传目录
type_validator: 文件类型验证器
size_validator: 文件大小验证器
filename_sanitizer: 文件名净化器
"""
self._upload_dir = upload_dir
self._type_validator = type_validator or FileTypeValidator()
self._size_validator = size_validator or FileSizeValidator(max_size=10 * 1024 * 1024) # 10MB
self._filename_sanitizer = filename_sanitizer or FilenameSanitizer()
self._storage = FileStorageManager(upload_dir)
def upload(self, file_obj: BinaryIO, filename: str) -> UploadResult:
"""
上传文件
Args:
file_obj: 文件对象
filename: 文件名
Returns:
上传结果
"""
# 净化文件名
safe_filename = self._filename_sanitizer.sanitize(filename)
# 验证文件类型
if not self._type_validator.validate(safe_filename):
return UploadResult(
success=False,
message=f"不支持的文件类型: {safe_filename}"
)
# 读取文件内容
content = file_obj.read()
# 验证文件大小
if not self._size_validator.validate(len(content)):
return UploadResult(
success=False,
message=f"文件大小超过限制"
)
# 保存文件
file_id = self._storage.save(content, safe_filename)
file_path = os.path.join(self._upload_dir, file_id)
return UploadResult(
success=True,
file_id=file_id,
file_path=file_path,
filename=safe_filename,
size=len(content)
)
def upload_batch(self, file_paths: List[str]) -> List[UploadResult]:
"""
批量上传文件
Args:
file_paths: 文件路径列表
Returns:
上传结果列表
"""
results = []
for file_path in file_paths:
try:
with open(file_path, 'rb') as f:
filename = os.path.basename(file_path)
result = self.upload(f, filename)
results.append(result)
except Exception as e:
results.append(UploadResult(
success=False,
message=str(e)
))
return results
class FileDownloader:
"""文件下载器"""
def __init__(self, storage_manager: Optional[FileStorageManager] = None):
"""
初始化文件下载器
Args:
storage_manager: 存储管理器
"""
self._storage = storage_manager
def download(self, file_id: str) -> DownloadResult:
"""
下载文件
Args:
file_id: 文件ID
Returns:
下载结果
"""
if self._storage is None:
return DownloadResult(
success=False,
message="存储管理器未设置"
)
content = self._storage.get(file_id)
if content is None:
return DownloadResult(
success=False,
message="文件不存在"
)
return DownloadResult(
success=True,
content=content
)
@@ -0,0 +1,106 @@
"""
日志记录器模块
提供结构化的测试日志记录功能。
"""
import logging
import sys
from pathlib import Path
from typing import Optional
from datetime import datetime
class TestLogger:
"""测试日志记录器"""
def __init__(self, name: str = "e2e_test", log_dir: str = "logs"):
self.name = name
self.log_dir = Path(log_dir)
self.log_dir.mkdir(parents=True, exist_ok=True)
# 创建logger
self.logger = logging.getLogger(name)
self.logger.setLevel(logging.DEBUG)
# 避免重复添加handler
if not self.logger.handlers:
# 控制台处理器
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
console_formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S"
)
console_handler.setFormatter(console_formatter)
self.logger.addHandler(console_handler)
# 文件处理器
log_file = self.log_dir / f"test_{datetime.now().strftime('%Y%m%d')}.log"
file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S"
)
file_handler.setFormatter(file_formatter)
self.logger.addHandler(file_handler)
def debug(self, message: str) -> None:
"""记录调试日志"""
self.logger.debug(message)
def info(self, message: str) -> None:
"""记录信息日志"""
self.logger.info(message)
def warning(self, message: str) -> None:
"""记录警告日志"""
self.logger.warning(message)
def error(self, message: str, error: Optional[Exception] = None) -> None:
"""记录错误日志"""
if error:
self.logger.error(f"{message}: {str(error)}", exc_info=True)
else:
self.logger.error(message)
def critical(self, message: str) -> None:
"""记录严重错误日志"""
self.logger.critical(message)
def start_test(self, test_name: str) -> None:
"""记录测试开始"""
self.info(f"{'='*60}")
self.info(f"开始执行测试: {test_name}")
self.info(f"{'='*60}")
def end_test(self, test_name: str, status: str, error: Optional[Exception] = None) -> None:
"""记录测试结束"""
icon = "" if status == "passed" else ""
self.info(f"{'='*60}")
self.info(f"{icon} 测试结束: {test_name} - 状态: {status}")
if error:
self.error(f"错误信息: {str(error)}")
self.info(f"{'='*60}")
def start_step(self, step_name: str) -> None:
"""记录步骤开始"""
self.info(f"▶️ 执行步骤: {step_name}")
def end_step(self, step_name: str, status: str) -> None:
"""记录步骤结束"""
icon = "" if status == "passed" else ""
self.info(f"{icon} 步骤完成: {step_name} - 状态: {status}")
# 全局日志记录器实例
_logger: Optional[TestLogger] = None
def get_logger(name: str = "e2e_test", log_dir: str = "logs") -> TestLogger:
"""获取日志记录器实例"""
global _logger
if _logger is None:
_logger = TestLogger(name, log_dir)
return _logger
@@ -0,0 +1,185 @@
"""
性能监控服务
提供性能监控和优化功能。
"""
import time
import functools
from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass, field
from datetime import datetime
import threading
@dataclass
class PerformanceMetric:
"""性能指标数据类"""
name: str
duration: float
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
metadata: Dict[str, Any] = field(default_factory=dict)
class PerformanceMonitor:
"""性能监控器"""
def __init__(self):
self._metrics: List[PerformanceMetric] = []
self._thresholds: Dict[str, float] = {
"page_load": 3.0,
"table_load": 2.0,
"search_response": 1.0,
"form_submit": 2.0,
"concurrent_op": 1.0,
"memory_growth": 50.0,
}
self._lock = threading.Lock()
def record_metric(self, name: str, duration: float, metadata: Dict[str, Any] = None) -> PerformanceMetric:
"""记录性能指标"""
metric = PerformanceMetric(
name=name,
duration=duration,
metadata=metadata or {}
)
with self._lock:
self._metrics.append(metric)
return metric
def measure(self, name: str, metadata: Dict[str, Any] = None):
"""性能测量装饰器"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
duration = end_time - start_time
self.record_metric(name, duration, metadata)
return result
return wrapper
return decorator
def measure_context(self, name: str, metadata: Dict[str, Any] = None):
"""性能测量上下文管理器"""
return PerformanceContext(self, name, metadata)
def get_metrics(self, name: Optional[str] = None) -> List[PerformanceMetric]:
"""获取性能指标"""
with self._lock:
if name:
return [m for m in self._metrics if m.name == name]
return self._metrics.copy()
def get_average_duration(self, name: str) -> float:
"""获取平均持续时间"""
metrics = self.get_metrics(name)
if not metrics:
return 0.0
return sum(m.duration for m in metrics) / len(metrics)
def check_threshold(self, name: str, duration: float) -> bool:
"""检查是否超过阈值"""
threshold = self._thresholds.get(name, float('inf'))
return duration <= threshold
def get_performance_report(self) -> Dict[str, Any]:
"""生成性能报告"""
with self._lock:
report = {
"total_metrics": len(self._metrics),
"metrics_by_name": {},
"threshold_violations": [],
}
# 按名称分组统计
for metric in self._metrics:
if metric.name not in report["metrics_by_name"]:
report["metrics_by_name"][metric.name] = {
"count": 0,
"total_duration": 0.0,
"min_duration": float('inf'),
"max_duration": 0.0,
}
stats = report["metrics_by_name"][metric.name]
stats["count"] += 1
stats["total_duration"] += metric.duration
stats["min_duration"] = min(stats["min_duration"], metric.duration)
stats["max_duration"] = max(stats["max_duration"], metric.duration)
# 检查阈值违规
if not self.check_threshold(metric.name, metric.duration):
report["threshold_violations"].append({
"name": metric.name,
"duration": metric.duration,
"threshold": self._thresholds.get(metric.name),
"timestamp": metric.timestamp,
})
# 计算平均值
for name, stats in report["metrics_by_name"].items():
stats["avg_duration"] = stats["total_duration"] / stats["count"]
return report
def clear_metrics(self):
"""清除所有指标"""
with self._lock:
self._metrics.clear()
class PerformanceContext:
"""性能测量上下文"""
def __init__(self, monitor: PerformanceMonitor, name: str, metadata: Dict[str, Any] = None):
self.monitor = monitor
self.name = name
self.metadata = metadata or {}
self.start_time = None
self.metric = None
def __enter__(self):
self.start_time = time.time()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
end_time = time.time()
duration = end_time - self.start_time
self.metric = self.monitor.record_metric(self.name, duration, self.metadata)
# 全局性能监控器实例
performance_monitor = PerformanceMonitor()
# 性能测试辅助函数
def measure_page_load(page, url: str) -> float:
"""测量页面加载时间"""
start_time = time.time()
page.goto(url)
page.wait_for_load_state("networkidle")
end_time = time.time()
return end_time - start_time
def measure_element_load(page, selector: str, timeout: int = 10000) -> float:
"""测量元素加载时间"""
start_time = time.time()
page.wait_for_selector(selector, timeout=timeout)
end_time = time.time()
return end_time - start_time
def measure_api_response(func: Callable, *args, **kwargs) -> tuple:
"""测量API响应时间"""
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
duration = end_time - start_time
return result, duration
@@ -0,0 +1,302 @@
"""
测试报告生成器模块
提供测试报告生成功能,支持多种报告格式。
"""
import json
import os
from pathlib import Path
from typing import Dict, List, Any, Optional
from datetime import datetime
from dataclasses import dataclass, field, asdict
@dataclass
class TestResult:
"""测试结果数据类"""
name: str
status: str # passed, failed, skipped
duration: float = 0.0
start_time: Optional[str] = None
end_time: Optional[str] = None
error_message: Optional[str] = None
traceback: Optional[str] = None
screenshot: Optional[str] = None
steps: List[Dict[str, Any]] = field(default_factory=list)
@dataclass
class TestSuite:
"""测试套件数据类"""
name: str
tests: List[TestResult] = field(default_factory=list)
@property
def passed_count(self) -> int:
return sum(1 for t in self.tests if t.status == "passed")
@property
def failed_count(self) -> int:
return sum(1 for t in self.tests if t.status == "failed")
@property
def skipped_count(self) -> int:
return sum(1 for t in self.tests if t.status == "skipped")
@property
def total_count(self) -> int:
return len(self.tests)
@property
def pass_rate(self) -> float:
if self.total_count == 0:
return 0.0
return (self.passed_count / self.total_count) * 100
class TestReporter:
"""测试报告生成器"""
def __init__(self, report_dir: str = "reports"):
self.report_dir = Path(report_dir)
self.report_dir.mkdir(parents=True, exist_ok=True)
self.suites: Dict[str, TestSuite] = {}
self.start_time: Optional[datetime] = None
self.end_time: Optional[datetime] = None
def start_report(self) -> None:
"""开始测试报告"""
self.start_time = datetime.now()
print(f"📝 测试报告开始于: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
def end_report(self) -> None:
"""结束测试报告"""
self.end_time = datetime.now()
print(f"📝 测试报告结束于: {self.end_time.strftime('%Y-%m-%d %H:%M:%S')}")
def add_test_result(self, suite_name: str, result: TestResult) -> None:
"""添加测试结果"""
if suite_name not in self.suites:
self.suites[suite_name] = TestSuite(name=suite_name)
self.suites[suite_name].tests.append(result)
def generate_html_report(self, filename: str = "test_report.html") -> str:
"""生成HTML报告"""
filepath = self.report_dir / filename
total_tests = sum(s.total_count for s in self.suites.values())
total_passed = sum(s.passed_count for s in self.suites.values())
total_failed = sum(s.failed_count for s in self.suites.values())
total_skipped = sum(s.skipped_count for s in self.suites.values())
html_content = f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>E2E测试报告</title>
<style>
* {{ margin: 0; padding: 0; box-sizing: border-box; }}
body {{
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #f5f5f5;
padding: 20px;
}}
.container {{ max-width: 1200px; margin: 0 auto; }}
.header {{
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 30px;
border-radius: 10px;
margin-bottom: 20px;
}}
.header h1 {{ font-size: 28px; margin-bottom: 10px; }}
.summary {{
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 15px;
margin-bottom: 20px;
}}
.summary-card {{
background: white;
padding: 20px;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}}
.summary-card h3 {{ font-size: 14px; color: #666; margin-bottom: 8px; }}
.summary-card .value {{ font-size: 32px; font-weight: bold; }}
.summary-card.passed .value {{ color: #10b981; }}
.summary-card.failed .value {{ color: #ef4444; }}
.summary-card.skipped .value {{ color: #f59e0b; }}
.summary-card.total .value {{ color: #3b82f6; }}
.suite {{
background: white;
border-radius: 8px;
margin-bottom: 20px;
overflow: hidden;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}}
.suite-header {{
background: #f8f9fa;
padding: 15px 20px;
border-bottom: 1px solid #e5e7eb;
}}
.suite-header h2 {{ font-size: 18px; color: #374151; }}
.test-list {{ padding: 0; }}
.test-item {{
padding: 15px 20px;
border-bottom: 1px solid #e5e7eb;
display: flex;
justify-content: space-between;
align-items: center;
}}
.test-item:last-child {{ border-bottom: none; }}
.test-item.passed {{ border-left: 4px solid #10b981; }}
.test-item.failed {{ border-left: 4px solid #ef4444; }}
.test-item.skipped {{ border-left: 4px solid #f59e0b; }}
.test-name {{ font-weight: 500; color: #374151; }}
.test-status {{
padding: 4px 12px;
border-radius: 12px;
font-size: 12px;
font-weight: 500;
}}
.test-status.passed {{ background: #d1fae5; color: #065f46; }}
.test-status.failed {{ background: #fee2e2; color: #991b1b; }}
.test-status.skipped {{ background: #fef3c7; color: #92400e; }}
.test-duration {{ color: #6b7280; font-size: 12px; margin-left: 10px; }}
.error-message {{
background: #fee2e2;
color: #991b1b;
padding: 10px;
border-radius: 4px;
margin-top: 10px;
font-size: 12px;
}}
.footer {{
text-align: center;
padding: 20px;
color: #6b7280;
font-size: 12px;
}}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>🧪 E2E测试报告</h1>
<p>生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
</div>
<div class="summary">
<div class="summary-card total">
<h3>总测试数</h3>
<div class="value">{total_tests}</div>
</div>
<div class="summary-card passed">
<h3>通过</h3>
<div class="value">{total_passed}</div>
</div>
<div class="summary-card failed">
<h3>失败</h3>
<div class="value">{total_failed}</div>
</div>
<div class="summary-card skipped">
<h3>跳过</h3>
<div class="value">{total_skipped}</div>
</div>
</div>
"""
# 添加测试套件详情
for suite_name, suite in self.suites.items():
html_content += f"""
<div class="suite">
<div class="suite-header">
<h2>{suite_name}</h2>
<p>通过率: {suite.pass_rate:.1f}% ({suite.passed_count}/{suite.total_count})</p>
</div>
<div class="test-list">
"""
for test in suite.tests:
status_class = test.status
error_html = ""
if test.error_message:
error_html = f'<div class="error-message">{test.error_message}</div>'
html_content += f"""
<div class="test-item {status_class}">
<div>
<span class="test-name">{test.name}</span>
<span class="test-duration">{test.duration:.2f}s</span>
{error_html}
</div>
<span class="test-status {status_class}">{test.status.upper()}</span>
</div>
"""
html_content += """
</div>
</div>
"""
html_content += """
<div class="footer">
<p>Generated by Python E2E Test Framework</p>
</div>
</div>
</body>
</html>
"""
with open(filepath, "w", encoding="utf-8") as f:
f.write(html_content)
print(f"📊 HTML报告已生成: {filepath}")
return str(filepath)
def generate_json_report(self, filename: str = "test_report.json") -> str:
"""生成JSON报告"""
filepath = self.report_dir / filename
report_data = {
"report_info": {
"generated_at": datetime.now().isoformat(),
"start_time": self.start_time.isoformat() if self.start_time else None,
"end_time": self.end_time.isoformat() if self.end_time else None,
},
"summary": {
"total": sum(s.total_count for s in self.suites.values()),
"passed": sum(s.passed_count for s in self.suites.values()),
"failed": sum(s.failed_count for s in self.suites.values()),
"skipped": sum(s.skipped_count for s in self.suites.values()),
},
"suites": {}
}
for suite_name, suite in self.suites.items():
report_data["suites"][suite_name] = {
"summary": {
"total": suite.total_count,
"passed": suite.passed_count,
"failed": suite.failed_count,
"skipped": suite.skipped_count,
"pass_rate": suite.pass_rate,
},
"tests": [asdict(test) for test in suite.tests]
}
with open(filepath, "w", encoding="utf-8") as f:
json.dump(report_data, f, ensure_ascii=False, indent=2)
print(f"📊 JSON报告已生成: {filepath}")
return str(filepath)
def generate_all_reports(self) -> Dict[str, str]:
"""生成所有报告"""
return {
"html": self.generate_html_report(),
"json": self.generate_json_report(),
}
@@ -0,0 +1,155 @@
"""
重试装饰器模块
提供测试方法重试功能。
"""
import functools
import time
from typing import Callable, TypeVar, Optional, Tuple, Type
from .exception_handler import TestExceptionHandler, RetryableError, FatalTestError
from .logger import get_logger
T = TypeVar("T")
def retry_on_failure(
max_retries: int = 3,
delay: float = 1.0,
backoff: float = 2.0,
exceptions: Optional[Tuple[Type[Exception], ...]] = None,
on_retry: Optional[Callable[[Exception, int], None]] = None,
):
"""
重试装饰器
Args:
max_retries: 最大重试次数
delay: 初始重试延迟(秒)
backoff: 延迟增长倍数
exceptions: 需要重试的异常类型
on_retry: 重试时的回调函数
Returns:
装饰器函数
Example:
@retry_on_failure(max_retries=3, delay=1.0)
def test_unstable_feature(page):
# 测试代码
pass
"""
if exceptions is None:
exceptions = (Exception,)
def decorator(func: Callable[..., T]) -> Callable[..., T]:
@functools.wraps(func)
def wrapper(*args, **kwargs) -> T:
logger = get_logger()
retry_count = 0
current_delay = delay
last_exception: Optional[Exception] = None
while retry_count < max_retries:
try:
return func(*args, **kwargs)
except exceptions as e:
retry_count += 1
last_exception = e
# 检查是否为致命错误
if TestExceptionHandler.is_fatal_error(e):
logger.error(f"遇到致命错误,停止重试: {e}")
raise
# 检查是否为可重试错误
if not TestExceptionHandler.is_retryable_error(e):
logger.error(f"非可重试错误,停止重试: {e}")
raise
if retry_count >= max_retries:
logger.error(
f"函数 {func.__name__}{max_retries} 次尝试后仍然失败"
)
raise
logger.warning(
f"函数 {func.__name__} 失败 (尝试 {retry_count}/{max_retries}): {e}"
)
# 调用重试回调
if on_retry:
on_retry(e, retry_count)
# 等待后重试
time.sleep(current_delay)
current_delay *= backoff
# 如果所有重试都失败,抛出最后的异常
if last_exception:
raise last_exception
# 不应该到达这里
raise RuntimeError("重试逻辑异常")
return wrapper
return decorator
def retry_with_timeout(
timeout: float = 30.0,
interval: float = 0.5,
exceptions: Optional[Tuple[Type[Exception], ...]] = None,
):
"""
超时重试装饰器
在指定超时时间内持续重试
Args:
timeout: 超时时间(秒)
interval: 重试间隔(秒)
exceptions: 需要重试的异常类型
Returns:
装饰器函数
"""
if exceptions is None:
exceptions = (Exception,)
def decorator(func: Callable[..., T]) -> Callable[..., T]:
@functools.wraps(func)
def wrapper(*args, **kwargs) -> T:
logger = get_logger()
start_time = time.time()
attempt = 0
while time.time() - start_time < timeout:
try:
return func(*args, **kwargs)
except exceptions as e:
attempt += 1
elapsed = time.time() - start_time
# 检查是否为致命错误
if TestExceptionHandler.is_fatal_error(e):
logger.error(f"遇到致命错误,停止重试: {e}")
raise
if elapsed >= timeout:
logger.error(
f"函数 {func.__name__}{timeout} 秒后超时,共尝试 {attempt}"
)
raise
logger.debug(
f"函数 {func.__name__} 失败 (尝试 {attempt}): {e}{interval}秒后重试"
)
time.sleep(interval)
raise TimeoutError(f"函数 {func.__name__}{timeout} 秒内未完成")
return wrapper
return decorator
@@ -0,0 +1,118 @@
"""
截图辅助工具模块
提供测试截图功能,支持多种截图模式。
"""
import os
from pathlib import Path
from typing import Optional
from datetime import datetime
from playwright.sync_api import Page
class ScreenshotHelper:
"""截图辅助工具"""
def __init__(self, screenshot_dir: str = "reports/screenshots"):
self.screenshot_dir = Path(screenshot_dir)
self.screenshot_dir.mkdir(parents=True, exist_ok=True)
def take_screenshot(
self,
page: Page,
name: str,
full_page: bool = False,
selector: Optional[str] = None,
) -> str:
"""
截取页面截图
Args:
page: Playwright页面对象
name: 截图文件名
full_page: 是否截取整个页面
selector: 特定元素选择器
Returns:
截图文件路径
"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"{name}_{timestamp}.png"
filepath = self.screenshot_dir / filename
try:
if selector:
# 截取特定元素
element = page.locator(selector)
element.screenshot(path=str(filepath))
else:
# 截取页面
page.screenshot(path=str(filepath), full_page=full_page)
print(f"📸 截图已保存: {filepath}")
return str(filepath)
except Exception as e:
print(f"❌ 截图失败: {e}")
return ""
def take_screenshot_on_failure(
self, page: Page, test_name: str, full_page: bool = True
) -> str:
"""
测试失败时截取截图
Args:
page: Playwright页面对象
test_name: 测试名称
full_page: 是否截取整个页面
Returns:
截图文件路径
"""
return self.take_screenshot(
page, f"{test_name}_failed", full_page=full_page
)
def take_comparison_screenshot(
self, page: Page, name: str, baseline_dir: str = "screenshots/baseline"
) -> tuple:
"""
截取对比截图
Args:
page: Playwright页面对象
name: 截图名称
baseline_dir: 基线截图目录
Returns:
(当前截图路径, 基线截图路径)
"""
current_path = self.take_screenshot(page, name, full_page=True)
baseline_path = Path(baseline_dir) / f"{name}.png"
return current_path, str(baseline_path)
def cleanup_old_screenshots(self, days: int = 7) -> int:
"""
清理旧截图
Args:
days: 保留天数
Returns:
删除的文件数量
"""
import time
deleted_count = 0
cutoff_time = time.time() - (days * 24 * 60 * 60)
for file_path in self.screenshot_dir.glob("*.png"):
if file_path.stat().st_mtime < cutoff_time:
file_path.unlink()
deleted_count += 1
print(f"🗑️ 已清理 {deleted_count} 个旧截图文件")
return deleted_count
@@ -0,0 +1,482 @@
"""
安全测试模块
提供SQL注入、XSS、CSRF等安全防护功能。
"""
import re
import hashlib
import secrets
import time
from typing import Any, Dict, List, Optional
from dataclasses import dataclass, field
from enum import Enum
class ThreatLevel(Enum):
"""威胁等级"""
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
CRITICAL = "critical"
@dataclass
class DetectionResult:
"""检测结果"""
is_threat: bool
threat_type: str
level: ThreatLevel
details: str = ""
@dataclass
class SQLInjectionResult:
"""SQL注入检测结果"""
is_injection: bool = False
level: ThreatLevel = ThreatLevel.LOW
details: str = ""
@property
def is_threat(self) -> bool:
return self.is_injection
@property
def threat_type(self) -> str:
return "SQL_INJECTION"
@dataclass
class XSSResult:
"""XSS检测结果"""
is_xss: bool = False
level: ThreatLevel = ThreatLevel.LOW
details: str = ""
@property
def is_threat(self) -> bool:
return self.is_xss
@property
def threat_type(self) -> str:
return "XSS"
@dataclass
class PasswordStrengthResult:
"""密码强度结果"""
score: int
strength: str
suggestions: List[str] = field(default_factory=list)
@dataclass
class SecurityEvent:
"""安全事件"""
timestamp: float
event_type: str
source_ip: str
details: Dict[str, Any]
@dataclass
class SecurityReport:
"""安全扫描报告"""
total_scanned: int
threats: List[DetectionResult]
scan_time: float
class SQLInjectionDetector:
"""SQL注入检测器"""
# SQL注入特征模式
PATTERNS = [
r"(\%27)|(\')|(\-\-)|(\%23)|(#)", # 单引号、注释
r"((\%3D)|(=))[^\n]*((\%27)|(\')|(\-\-)|(\%3B)|(;))", # =后面跟引号或注释
r"\w*((\%27)|(\'))((\%6F)|o|(\%4F))((\%72)|r|(\%52))", # 'or
r"((\%27)|(\'))union", # 'union
r"exec(\s|\+)+(s|x)p\w+", # exec xp_
r"UNION\s+SELECT", # UNION SELECT
r"INSERT\s+INTO", # INSERT INTO
r"DELETE\s+FROM", # DELETE FROM
r"DROP\s+TABLE", # DROP TABLE
]
def __init__(self):
self._compiled_patterns = [re.compile(p, re.IGNORECASE) for p in self.PATTERNS]
def detect(self, input_str: str) -> SQLInjectionResult:
"""
检测SQL注入
Args:
input_str: 输入字符串
Returns:
检测结果
"""
for pattern in self._compiled_patterns:
if pattern.search(input_str):
return SQLInjectionResult(
is_injection=True,
level=ThreatLevel.HIGH,
details=f"匹配模式: {pattern.pattern}"
)
return SQLInjectionResult(is_injection=False, level=ThreatLevel.LOW)
class XSSDetector:
"""XSS检测器"""
# XSS攻击特征模式
PATTERNS = [
r"<script[^>]*>[\s\S]*?</script>", # <script>标签
r"javascript:", # javascript:协议
r"on\w+\s*=", # 事件处理器
r"<iframe", # iframe标签
r"<object", # object标签
r"<embed", # embed标签
r"<form", # form标签
r"<input[^>]*type\s*=\s*['\"]?hidden", # hidden input
r"expression\s*\(", # CSS expression
r"url\s*\(\s*['\"]?javascript:", # CSS url javascript
]
def __init__(self):
self._compiled_patterns = [re.compile(p, re.IGNORECASE) for p in self.PATTERNS]
def detect(self, input_str: str) -> XSSResult:
"""
检测XSS攻击
Args:
input_str: 输入字符串
Returns:
检测结果
"""
for pattern in self._compiled_patterns:
if pattern.search(input_str):
return XSSResult(
is_xss=True,
level=ThreatLevel.HIGH,
details=f"匹配模式: {pattern.pattern}"
)
return XSSResult(is_xss=False, level=ThreatLevel.LOW)
class CSRFProtector:
"""CSRF防护器"""
def __init__(self, token_expiry: int = 3600):
"""
初始化CSRF防护器
Args:
token_expiry: Token过期时间(秒)
"""
self._token_expiry = token_expiry
self._tokens: Dict[str, Dict[str, Any]] = {}
def generate_token(self, user_id: str) -> str:
"""
生成CSRF Token
Args:
user_id: 用户ID
Returns:
Token字符串
"""
token = secrets.token_urlsafe(32)
self._tokens[token] = {
"user_id": user_id,
"created_at": time.time(),
}
return token
def validate_token(self, user_id: str, token: str) -> bool:
"""
验证CSRF Token
Args:
user_id: 用户ID
token: Token字符串
Returns:
是否有效
"""
if token not in self._tokens:
return False
token_data = self._tokens[token]
# 检查用户ID
if token_data["user_id"] != user_id:
return False
# 检查是否过期
if time.time() - token_data["created_at"] > self._token_expiry:
del self._tokens[token]
return False
return True
def invalidate_token(self, token: str) -> None:
"""使Token失效"""
if token in self._tokens:
del self._tokens[token]
class InputSanitizer:
"""输入净化器"""
# HTML危险标签和属性
DANGEROUS_TAGS = [
"script", "iframe", "object", "embed", "form", "input",
"textarea", "button", "link", "meta", "style"
]
DANGEROUS_ATTRIBUTES = [
"onerror", "onload", "onclick", "onmouseover", "onmouseout",
"onkeydown", "onkeypress", "onkeyup", "onsubmit", "onchange",
"onfocus", "onblur", "onselect", "onreset"
]
def sanitize_html(self, html: str) -> str:
"""
净化HTML内容
Args:
html: HTML字符串
Returns:
净化后的HTML
"""
# 移除危险标签
for tag in self.DANGEROUS_TAGS:
pattern = f"<{tag}[^>]*>[\\s\\S]*?</{tag}>"
html = re.sub(pattern, "", html, flags=re.IGNORECASE)
pattern = f"<{tag}[^>]*/?>"
html = re.sub(pattern, "", html, flags=re.IGNORECASE)
# 移除危险属性
for attr in self.DANGEROUS_ATTRIBUTES:
pattern = f"\\s{attr}=[\"'][^\"']*[\"']"
html = re.sub(pattern, "", html, flags=re.IGNORECASE)
# 移除javascript:协议
html = re.sub(r"javascript:", "", html, flags=re.IGNORECASE)
return html
def sanitize_sql(self, input_str: str) -> str:
"""
净化SQL输入
Args:
input_str: 输入字符串
Returns:
净化后的字符串
"""
# 转义单引号
return input_str.replace("'", "''")
class PasswordStrengthChecker:
"""密码强度检查器"""
def check(self, password: str) -> PasswordStrengthResult:
"""
检查密码强度
Args:
password: 密码字符串
Returns:
强度结果
"""
score = 0
suggestions = []
# 长度检查
if len(password) >= 8:
score += 2
elif len(password) >= 6:
score += 1
else:
suggestions.append("密码长度至少8位")
# 包含小写字母
if re.search(r"[a-z]", password):
score += 1
else:
suggestions.append("应包含小写字母")
# 包含大写字母
if re.search(r"[A-Z]", password):
score += 1
else:
suggestions.append("应包含大写字母")
# 包含数字
if re.search(r"\d", password):
score += 1
else:
suggestions.append("应包含数字")
# 包含特殊字符
if re.search(r"[!@#$%^&*(),.?\":{}|<>]", password):
score += 2
else:
suggestions.append("应包含特殊字符")
# 确定强度等级
if score >= 7:
strength = "strong"
elif score >= 4:
strength = "medium"
else:
strength = "weak"
return PasswordStrengthResult(
score=score,
strength=strength,
suggestions=suggestions
)
class SecurityHeaders:
"""安全HTTP头部生成器"""
def get_headers(self) -> Dict[str, str]:
"""
获取安全HTTP头部
Returns:
安全头部字典
"""
return {
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "DENY",
"X-XSS-Protection": "1; mode=block",
"Strict-Transport-Security": "max-age=31536000; includeSubDomains",
"Content-Security-Policy": "default-src 'self'",
"Referrer-Policy": "strict-origin-when-cross-origin",
"Permissions-Policy": "geolocation=(), microphone=(), camera=()",
}
class SecurityAuditLogger:
"""安全审计日志器"""
def __init__(self):
self._events: List[SecurityEvent] = []
def log_event(
self,
event_type: str,
source_ip: str,
details: Dict[str, Any]
) -> None:
"""
记录安全事件
Args:
event_type: 事件类型
source_ip: 来源IP
details: 详细信息
"""
event = SecurityEvent(
timestamp=time.time(),
event_type=event_type,
source_ip=source_ip,
details=details
)
self._events.append(event)
def get_events(
self,
event_type: Optional[str] = None,
start_time: Optional[float] = None,
end_time: Optional[float] = None
) -> List[SecurityEvent]:
"""
获取安全事件
Args:
event_type: 事件类型过滤
start_time: 开始时间
end_time: 结束时间
Returns:
事件列表
"""
events = self._events
if event_type:
events = [e for e in events if e.event_type == event_type]
if start_time:
events = [e for e in events if e.timestamp >= start_time]
if end_time:
events = [e for e in events if e.timestamp <= end_time]
return events
def get_stats(self) -> Dict[str, Any]:
"""获取统计信息"""
event_types = {}
for event in self._events:
event_types[event.event_type] = event_types.get(event.event_type, 0) + 1
return {
"total_events": len(self._events),
"event_types": event_types,
}
class SecurityScanner:
"""综合安全扫描器"""
def __init__(self):
self._sql_detector = SQLInjectionDetector()
self._xss_detector = XSSDetector()
def scan(self, data: Dict[str, Any]) -> SecurityReport:
"""
扫描数据
Args:
data: 要扫描的数据
Returns:
扫描报告
"""
threats = []
start_time = time.time()
for key, value in data.items():
if isinstance(value, str):
# SQL注入检测
sql_result = self._sql_detector.detect(value)
if sql_result.is_injection:
threats.append(sql_result)
# XSS检测
xss_result = self._xss_detector.detect(value)
if xss_result.is_xss:
threats.append(xss_result)
scan_time = time.time() - start_time
return SecurityReport(
total_scanned=len(data),
threats=threats,
scan_time=scan_time
)
@@ -0,0 +1,288 @@
"""
定时任务调度器模块
提供定时任务的创建、调度、执行和管理功能。
"""
import time
import threading
import uuid
from typing import Any, Callable, Dict, List, Optional
from dataclasses import dataclass, field
from enum import Enum
import heapq
class TaskStatus(Enum):
"""任务状态"""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
CANCELLED = "cancelled"
ERROR = "error"
class SchedulerState(Enum):
"""调度器状态"""
RUNNING = "running"
PAUSED = "paused"
STOPPED = "stopped"
@dataclass
class Task:
"""任务定义"""
name: str
func: Callable
interval: float = 0 # 执行间隔(秒)
delay: float = 0 # 延迟执行时间(秒)
repeat: bool = False # 是否重复执行
priority: int = 5 # 优先级(1-10,数字越大优先级越高)
on_error: Optional[Callable[[Exception], None]] = None
max_retries: int = 0
# 内部字段
id: str = field(default_factory=lambda: str(uuid.uuid4()))
status: TaskStatus = TaskStatus.PENDING
next_run_time: float = 0
execution_count: int = 0
error_count: int = 0
def __post_init__(self):
if self.next_run_time == 0:
self.next_run_time = time.time() + self.delay
def __lt__(self, other):
# 用于优先级队列比较
if self.next_run_time != other.next_run_time:
return self.next_run_time < other.next_run_time
return self.priority > other.priority
@dataclass
class TaskExecutionRecord:
"""任务执行记录"""
task_id: str
task_name: str
start_time: float
end_time: float
success: bool
error: Optional[str] = None
class TaskScheduler:
"""
任务调度器
特性:
- 支持定时和周期性任务
- 支持任务优先级
- 支持任务取消
- 支持错误处理
- 支持暂停/恢复
"""
def __init__(self):
"""初始化调度器"""
self._tasks: Dict[str, Task] = {}
self._task_queue: List[Task] = []
self._lock = threading.RLock()
self._condition = threading.Condition(self._lock)
self._state = SchedulerState.STOPPED
self._worker_thread: Optional[threading.Thread] = None
self._execution_records: List[TaskExecutionRecord] = []
self._total_executions = 0
self._total_errors = 0
def start(self) -> None:
"""启动调度器"""
with self._lock:
if self._state == SchedulerState.RUNNING:
return
self._state = SchedulerState.RUNNING
self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
self._worker_thread.start()
def stop(self) -> None:
"""停止调度器"""
with self._lock:
self._state = SchedulerState.STOPPED
self._condition.notify_all()
if self._worker_thread and self._worker_thread.is_alive():
self._worker_thread.join(timeout=5)
def pause(self) -> None:
"""暂停调度器"""
with self._lock:
self._state = SchedulerState.PAUSED
def resume(self) -> None:
"""恢复调度器"""
with self._lock:
if self._state == SchedulerState.PAUSED:
self._state = SchedulerState.RUNNING
self._condition.notify_all()
def schedule(self, task: Task) -> str:
"""
调度任务
Args:
task: 要调度的任务
Returns:
任务ID
"""
with self._lock:
self._tasks[task.id] = task
heapq.heappush(self._task_queue, task)
self._condition.notify()
# 自动启动调度器
if self._state == SchedulerState.STOPPED:
self.start()
return task.id
def cancel(self, task_id: str) -> bool:
"""
取消任务
Args:
task_id: 任务ID
Returns:
是否成功取消
"""
with self._lock:
if task_id in self._tasks:
task = self._tasks[task_id]
task.status = TaskStatus.CANCELLED
return True
return False
def get_stats(self) -> Dict[str, Any]:
"""
获取统计信息
Returns:
统计信息字典
"""
with self._lock:
return {
"total_executions": self._total_executions,
"total_errors": self._total_errors,
"pending_tasks": len([t for t in self._tasks.values() if t.status == TaskStatus.PENDING]),
"running_tasks": len([t for t in self._tasks.values() if t.status == TaskStatus.RUNNING]),
"state": self._state.value,
}
def _worker_loop(self) -> None:
"""工作线程循环"""
while True:
with self._lock:
# 检查状态
if self._state == SchedulerState.STOPPED:
break
# 如果暂停,等待恢复
while self._state == SchedulerState.PAUSED:
self._condition.wait()
if self._state == SchedulerState.STOPPED:
return
# 获取下一个要执行的任务
task = self._get_next_task()
if task is None:
# 没有任务,等待一段时间
self._condition.wait(timeout=0.1)
continue
# 检查任务是否被取消
if task.status == TaskStatus.CANCELLED:
continue
# 执行任务
task.status = TaskStatus.RUNNING
# 在锁外执行任务
self._execute_task(task)
def _get_next_task(self) -> Optional[Task]:
"""获取下一个要执行的任务"""
now = time.time()
while self._task_queue:
task = heapq.heappop(self._task_queue)
# 检查任务是否有效
if task.status == TaskStatus.CANCELLED:
continue
# 检查是否到执行时间
if task.next_run_time <= now:
return task
else:
# 还没到时间,放回队列
heapq.heappush(self._task_queue, task)
break
return None
def _execute_task(self, task: Task) -> None:
"""执行任务"""
# 再次检查任务是否被取消
if task.status == TaskStatus.CANCELLED:
return
start_time = time.time()
success = False
error_msg = None
try:
task.func()
success = True
task.execution_count += 1
except Exception as e:
success = False
error_msg = str(e)
task.error_count += 1
# 调用错误处理回调
if task.on_error:
try:
task.on_error(e)
except:
pass
with self._lock:
self._total_errors += 1
end_time = time.time()
# 记录执行
with self._lock:
self._total_executions += 1
self._execution_records.append(TaskExecutionRecord(
task_id=task.id,
task_name=task.name,
start_time=start_time,
end_time=end_time,
success=success,
error=error_msg
))
# 处理周期性任务
with self._lock:
if task.repeat and task.status != TaskStatus.CANCELLED:
if task.error_count <= task.max_retries or task.max_retries == 0:
task.status = TaskStatus.PENDING
task.next_run_time = time.time() + task.interval
heapq.heappush(self._task_queue, task)
else:
task.status = TaskStatus.ERROR
else:
task.status = TaskStatus.COMPLETED if success else TaskStatus.ERROR
@@ -0,0 +1,156 @@
"""
验证服务
提供各种边界条件的验证功能。
"""
import re
from typing import Dict, Any, Tuple, Optional
class ValidationService:
"""验证服务类"""
# 验证规则常量
USERNAME_MIN_LENGTH = 3
USERNAME_MAX_LENGTH = 20
USERNAME_PATTERN = re.compile(r'^[a-zA-Z0-9_]+$')
EMAIL_PATTERN = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
ROLE_NAME_MIN_LENGTH = 1
ROLE_NAME_MAX_LENGTH = 50
ROLE_CODE_MIN_LENGTH = 1
ROLE_CODE_MAX_LENGTH = 30
ROLE_CODE_PATTERN = re.compile(r'^[a-zA-Z0-9_]+$')
@staticmethod
def validate_username(username: str) -> Tuple[bool, Optional[str]]:
"""
验证用户名
Returns:
(is_valid, error_message)
"""
if not username:
return False, "用户名不能为空"
if len(username) < ValidationService.USERNAME_MIN_LENGTH:
return False, f"用户名长度不能少于{ValidationService.USERNAME_MIN_LENGTH}个字符"
if len(username) > ValidationService.USERNAME_MAX_LENGTH:
return False, f"用户名长度不能超过{ValidationService.USERNAME_MAX_LENGTH}个字符"
if not ValidationService.USERNAME_PATTERN.match(username):
return False, "用户名只能包含字母、数字和下划线"
return True, None
@staticmethod
def validate_email(email: str) -> Tuple[bool, Optional[str]]:
"""
验证邮箱格式
Returns:
(is_valid, error_message)
"""
if not email:
return False, "邮箱不能为空"
if not ValidationService.EMAIL_PATTERN.match(email):
return False, "邮箱格式不正确"
return True, None
@staticmethod
def validate_role_name(name: str) -> Tuple[bool, Optional[str]]:
"""
验证角色名称
Returns:
(is_valid, error_message)
"""
if not name:
return False, "角色名称不能为空"
if len(name) > ValidationService.ROLE_NAME_MAX_LENGTH:
return False, f"角色名称长度不能超过{ValidationService.ROLE_NAME_MAX_LENGTH}个字符"
return True, None
@staticmethod
def validate_role_code(code: str) -> Tuple[bool, Optional[str]]:
"""
验证角色编码
Returns:
(is_valid, error_message)
"""
if not code:
return False, "角色编码不能为空"
if len(code) > ValidationService.ROLE_CODE_MAX_LENGTH:
return False, f"角色编码长度不能超过{ValidationService.ROLE_CODE_MAX_LENGTH}个字符"
if not ValidationService.ROLE_CODE_PATTERN.match(code):
return False, "角色编码只能包含字母、数字和下划线"
return True, None
@staticmethod
def validate_user_data(data: Dict[str, Any]) -> Dict[str, Any]:
"""
验证用户数据
Returns:
{"valid": bool, "errors": Dict[str, str]}
"""
errors = {}
# 验证用户名
if "username" in data:
is_valid, error = ValidationService.validate_username(data["username"])
if not is_valid:
errors["username"] = error
# 验证邮箱
if "email" in data:
is_valid, error = ValidationService.validate_email(data["email"])
if not is_valid:
errors["email"] = error
return {
"valid": len(errors) == 0,
"errors": errors
}
@staticmethod
def validate_role_data(data: Dict[str, Any]) -> Dict[str, Any]:
"""
验证角色数据
Returns:
{"valid": bool, "errors": Dict[str, str]}
"""
errors = {}
# 验证角色名称
if "name" in data:
is_valid, error = ValidationService.validate_role_name(data["name"])
if not is_valid:
errors["name"] = error
# 验证角色编码
if "code" in data:
is_valid, error = ValidationService.validate_role_code(data["code"])
if not is_valid:
errors["code"] = error
return {
"valid": len(errors) == 0,
"errors": errors
}
# 全局验证服务实例
validation_service = ValidationService()