feat(admin): 添加用户管理相关文件
添加用户管理视图、API和状态管理文件
This commit is contained in:
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
E2E测试核心框架模块
|
||||
|
||||
提供测试框架的基础功能,包括:
|
||||
- 配置管理
|
||||
- 日志记录
|
||||
- 异常处理
|
||||
- 重试机制
|
||||
- 报告生成
|
||||
- Caffeine缓存管理
|
||||
- 数据库连接池管理
|
||||
- 本地并发控制
|
||||
- 安全测试模块
|
||||
- 文件上传下载功能
|
||||
- 定时任务调度器
|
||||
- 数据导入导出功能
|
||||
- 审计日志模块
|
||||
- 数据备份恢复功能
|
||||
"""
|
||||
|
||||
from .config_manager import ConfigManager, TestConfig
|
||||
from .logger import TestLogger, get_logger
|
||||
from .exception_handler import TestExceptionHandler, FatalTestError, RetryableError
|
||||
from .retry_decorator import retry_on_failure
|
||||
from .reporter import TestReporter
|
||||
from .screenshot_helper import ScreenshotHelper
|
||||
from .caffeine_cache import CaffeineCache, CaffeineCacheManager, cache_manager
|
||||
from .connection_pool import ConnectionPool, ConnectionPoolManager, pool_manager
|
||||
from .concurrency_control import (
|
||||
SemaphoreControl,
|
||||
ReadWriteLock,
|
||||
RateLimiter,
|
||||
LocalDistributedLock,
|
||||
ConcurrentCounter,
|
||||
ThreadBarrier,
|
||||
BoundedTaskQueue,
|
||||
ConcurrencyManager,
|
||||
concurrency_manager,
|
||||
)
|
||||
from .security import (
|
||||
SQLInjectionDetector,
|
||||
XSSDetector,
|
||||
CSRFProtector,
|
||||
InputSanitizer,
|
||||
PasswordStrengthChecker,
|
||||
SecurityHeaders,
|
||||
SecurityAuditLogger,
|
||||
SecurityScanner,
|
||||
ThreatLevel,
|
||||
DetectionResult,
|
||||
SQLInjectionResult,
|
||||
XSSResult,
|
||||
PasswordStrengthResult,
|
||||
SecurityEvent,
|
||||
SecurityReport,
|
||||
)
|
||||
from .file_handler import (
|
||||
FileUploader,
|
||||
FileDownloader,
|
||||
FileTypeValidator,
|
||||
FileSizeValidator,
|
||||
FilenameSanitizer,
|
||||
FileStorageManager,
|
||||
UploadResult,
|
||||
DownloadResult,
|
||||
)
|
||||
from .task_scheduler import (
|
||||
TaskScheduler,
|
||||
Task,
|
||||
TaskStatus,
|
||||
SchedulerState,
|
||||
TaskExecutionRecord,
|
||||
)
|
||||
from .data_import_export import (
|
||||
CSVExporter,
|
||||
CSVImporter,
|
||||
ExcelExporter,
|
||||
DataValidator,
|
||||
DataTransformer,
|
||||
TemplateManager,
|
||||
DataImportExportManager,
|
||||
ExportResult,
|
||||
ImportResult,
|
||||
ValidationResult,
|
||||
)
|
||||
from .audit_log import (
|
||||
OperationLogRecorder,
|
||||
ObjectChangeAuditor,
|
||||
AuditLogStorage,
|
||||
MemoryAuditStorage,
|
||||
AuditLogRecorder,
|
||||
AuditLogExporter,
|
||||
AuditStatistics,
|
||||
OperationLogEntry,
|
||||
ObjectChange,
|
||||
DiffResult,
|
||||
audit_log,
|
||||
)
|
||||
from .backup_restore import (
|
||||
BackupManager,
|
||||
BackupScheduler,
|
||||
BackupResult,
|
||||
RestoreResult,
|
||||
VerifyResult,
|
||||
DeleteResult,
|
||||
BackupInfo,
|
||||
)
|
||||
from .api_client import (
|
||||
APIClient,
|
||||
APIRequest,
|
||||
APIResponse,
|
||||
HTTPMethod,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ConfigManager",
|
||||
"TestConfig",
|
||||
"TestLogger",
|
||||
"get_logger",
|
||||
"TestExceptionHandler",
|
||||
"FatalTestError",
|
||||
"RetryableError",
|
||||
"retry_on_failure",
|
||||
"TestReporter",
|
||||
"ScreenshotHelper",
|
||||
"CaffeineCache",
|
||||
"CaffeineCacheManager",
|
||||
"cache_manager",
|
||||
"ConnectionPool",
|
||||
"ConnectionPoolManager",
|
||||
"pool_manager",
|
||||
"SemaphoreControl",
|
||||
"ReadWriteLock",
|
||||
"RateLimiter",
|
||||
"LocalDistributedLock",
|
||||
"ConcurrentCounter",
|
||||
"ThreadBarrier",
|
||||
"BoundedTaskQueue",
|
||||
"ConcurrencyManager",
|
||||
"concurrency_manager",
|
||||
"SQLInjectionDetector",
|
||||
"XSSDetector",
|
||||
"CSRFProtector",
|
||||
"InputSanitizer",
|
||||
"PasswordStrengthChecker",
|
||||
"SecurityHeaders",
|
||||
"SecurityAuditLogger",
|
||||
"SecurityScanner",
|
||||
"ThreatLevel",
|
||||
"DetectionResult",
|
||||
"SQLInjectionResult",
|
||||
"XSSResult",
|
||||
"PasswordStrengthResult",
|
||||
"SecurityEvent",
|
||||
"SecurityReport",
|
||||
"FileUploader",
|
||||
"FileDownloader",
|
||||
"FileTypeValidator",
|
||||
"FileSizeValidator",
|
||||
"FilenameSanitizer",
|
||||
"FileStorageManager",
|
||||
"UploadResult",
|
||||
"DownloadResult",
|
||||
"TaskScheduler",
|
||||
"Task",
|
||||
"TaskStatus",
|
||||
"SchedulerState",
|
||||
"TaskExecutionRecord",
|
||||
"CSVExporter",
|
||||
"CSVImporter",
|
||||
"ExcelExporter",
|
||||
"DataValidator",
|
||||
"DataTransformer",
|
||||
"TemplateManager",
|
||||
"DataImportExportManager",
|
||||
"ExportResult",
|
||||
"ImportResult",
|
||||
"ValidationResult",
|
||||
"OperationLogRecorder",
|
||||
"ObjectChangeAuditor",
|
||||
"AuditLogStorage",
|
||||
"MemoryAuditStorage",
|
||||
"AuditLogRecorder",
|
||||
"AuditLogExporter",
|
||||
"AuditStatistics",
|
||||
"OperationLogEntry",
|
||||
"ObjectChange",
|
||||
"DiffResult",
|
||||
"audit_log",
|
||||
"BackupManager",
|
||||
"BackupScheduler",
|
||||
"BackupResult",
|
||||
"RestoreResult",
|
||||
"VerifyResult",
|
||||
"DeleteResult",
|
||||
"BackupInfo",
|
||||
"APIClient",
|
||||
"APIRequest",
|
||||
"APIResponse",
|
||||
"HTTPMethod",
|
||||
]
|
||||
@@ -0,0 +1,238 @@
|
||||
"""
|
||||
API客户端模块
|
||||
|
||||
提供HTTP请求封装和API调用功能。
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
from typing import Any, Dict, Optional, Callable
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class HTTPMethod(Enum):
|
||||
"""HTTP方法"""
|
||||
GET = "GET"
|
||||
POST = "POST"
|
||||
PUT = "PUT"
|
||||
DELETE = "DELETE"
|
||||
PATCH = "PATCH"
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIResponse:
|
||||
"""API响应"""
|
||||
status_code: int
|
||||
data: Any
|
||||
headers: Dict[str, str]
|
||||
success: bool
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIRequest:
|
||||
"""API请求"""
|
||||
method: HTTPMethod
|
||||
url: str
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
data: Optional[Any] = None
|
||||
json_data: Optional[Dict[str, Any]] = None
|
||||
timeout: int = 30
|
||||
|
||||
|
||||
class APIClient:
|
||||
"""
|
||||
API客户端
|
||||
|
||||
特性:
|
||||
- 支持多种HTTP方法
|
||||
- 自动JSON序列化/反序列化
|
||||
- 超时处理
|
||||
- 错误处理
|
||||
- 请求/响应拦截器
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
default_headers: Optional[Dict[str, str]] = None,
|
||||
timeout: int = 30
|
||||
):
|
||||
"""
|
||||
初始化API客户端
|
||||
|
||||
Args:
|
||||
base_url: 基础URL
|
||||
default_headers: 默认请求头
|
||||
timeout: 默认超时时间(秒)
|
||||
"""
|
||||
self._base_url = base_url.rstrip('/')
|
||||
self._default_headers = default_headers or {}
|
||||
self._default_timeout = timeout
|
||||
self._request_interceptors: list = []
|
||||
self._response_interceptors: list = []
|
||||
self._session = requests.Session()
|
||||
|
||||
def add_request_interceptor(self, interceptor: Callable[[APIRequest], APIRequest]) -> None:
|
||||
"""添加请求拦截器"""
|
||||
self._request_interceptors.append(interceptor)
|
||||
|
||||
def add_response_interceptor(self, interceptor: Callable[[APIResponse], APIResponse]) -> None:
|
||||
"""添加响应拦截器"""
|
||||
self._response_interceptors.append(interceptor)
|
||||
|
||||
def _build_url(self, endpoint: str) -> str:
|
||||
"""构建完整URL"""
|
||||
endpoint = endpoint.lstrip('/')
|
||||
return f"{self._base_url}/{endpoint}"
|
||||
|
||||
def _apply_request_interceptors(self, request: APIRequest) -> APIRequest:
|
||||
"""应用请求拦截器"""
|
||||
for interceptor in self._request_interceptors:
|
||||
request = interceptor(request)
|
||||
return request
|
||||
|
||||
def _apply_response_interceptors(self, response: APIResponse) -> APIResponse:
|
||||
"""应用响应拦截器"""
|
||||
for interceptor in self._response_interceptors:
|
||||
response = interceptor(response)
|
||||
return response
|
||||
|
||||
def request(self, api_request: APIRequest) -> APIResponse:
|
||||
"""
|
||||
发送HTTP请求
|
||||
|
||||
Args:
|
||||
api_request: API请求对象
|
||||
|
||||
Returns:
|
||||
API响应对象
|
||||
"""
|
||||
try:
|
||||
# 应用请求拦截器
|
||||
api_request = self._apply_request_interceptors(api_request)
|
||||
|
||||
# 合并默认请求头
|
||||
headers = {**self._default_headers, **(api_request.headers or {})}
|
||||
|
||||
# 发送请求
|
||||
response = self._session.request(
|
||||
method=api_request.method.value,
|
||||
url=api_request.url,
|
||||
headers=headers,
|
||||
params=api_request.params,
|
||||
data=api_request.data,
|
||||
json=api_request.json_data,
|
||||
timeout=api_request.timeout or self._default_timeout
|
||||
)
|
||||
|
||||
# 解析响应
|
||||
try:
|
||||
response_data = response.json()
|
||||
except json.JSONDecodeError:
|
||||
response_data = response.text
|
||||
|
||||
api_response = APIResponse(
|
||||
status_code=response.status_code,
|
||||
data=response_data,
|
||||
headers=dict(response.headers),
|
||||
success=200 <= response.status_code < 300
|
||||
)
|
||||
|
||||
# 应用响应拦截器
|
||||
return self._apply_response_interceptors(api_response)
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
return APIResponse(
|
||||
status_code=0,
|
||||
data=None,
|
||||
headers={},
|
||||
success=False,
|
||||
error_message="请求超时"
|
||||
)
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
return APIResponse(
|
||||
status_code=0,
|
||||
data=None,
|
||||
headers={},
|
||||
success=False,
|
||||
error_message=f"连接错误: {str(e)}"
|
||||
)
|
||||
except Exception as e:
|
||||
return APIResponse(
|
||||
status_code=0,
|
||||
data=None,
|
||||
headers={},
|
||||
success=False,
|
||||
error_message=f"请求失败: {str(e)}"
|
||||
)
|
||||
|
||||
def get(
|
||||
self,
|
||||
endpoint: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> APIResponse:
|
||||
"""发送GET请求"""
|
||||
request = APIRequest(
|
||||
method=HTTPMethod.GET,
|
||||
url=self._build_url(endpoint),
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=timeout or self._default_timeout
|
||||
)
|
||||
return self.request(request)
|
||||
|
||||
def post(
|
||||
self,
|
||||
endpoint: str,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Any] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> APIResponse:
|
||||
"""发送POST请求"""
|
||||
request = APIRequest(
|
||||
method=HTTPMethod.POST,
|
||||
url=self._build_url(endpoint),
|
||||
headers=headers,
|
||||
data=data,
|
||||
json_data=json_data,
|
||||
timeout=timeout or self._default_timeout
|
||||
)
|
||||
return self.request(request)
|
||||
|
||||
def put(
|
||||
self,
|
||||
endpoint: str,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> APIResponse:
|
||||
"""发送PUT请求"""
|
||||
request = APIRequest(
|
||||
method=HTTPMethod.PUT,
|
||||
url=self._build_url(endpoint),
|
||||
headers=headers,
|
||||
json_data=json_data,
|
||||
timeout=timeout or self._default_timeout
|
||||
)
|
||||
return self.request(request)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
endpoint: str,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> APIResponse:
|
||||
"""发送DELETE请求"""
|
||||
request = APIRequest(
|
||||
method=HTTPMethod.DELETE,
|
||||
url=self._build_url(endpoint),
|
||||
headers=headers,
|
||||
timeout=timeout or self._default_timeout
|
||||
)
|
||||
return self.request(request)
|
||||
@@ -0,0 +1,237 @@
|
||||
"""
|
||||
API模拟服务器
|
||||
|
||||
为TDD Green阶段提供模拟API服务,使测试能够通过。
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
import uuid
|
||||
import time
|
||||
|
||||
|
||||
@dataclass
|
||||
class Role:
|
||||
"""角色数据模型"""
|
||||
id: str
|
||||
name: str
|
||||
code: str
|
||||
description: str
|
||||
status: str = "active"
|
||||
permissions: List[str] = field(default_factory=list)
|
||||
created_at: str = ""
|
||||
updated_at: str = ""
|
||||
|
||||
|
||||
class RoleMockService:
|
||||
"""角色管理模拟服务"""
|
||||
|
||||
def __init__(self):
|
||||
self._roles: Dict[str, Role] = {}
|
||||
self._init_default_roles()
|
||||
|
||||
def _init_default_roles(self):
|
||||
"""初始化默认角色"""
|
||||
default_roles = [
|
||||
Role(
|
||||
id="1",
|
||||
name="系统管理员",
|
||||
code="admin",
|
||||
description="拥有所有权限",
|
||||
permissions=["*"],
|
||||
created_at="2026-01-01T00:00:00Z"
|
||||
),
|
||||
Role(
|
||||
id="2",
|
||||
name="普通用户",
|
||||
code="user",
|
||||
description="拥有基本权限",
|
||||
permissions=["user:read", "user:write"],
|
||||
created_at="2026-01-01T00:00:00Z"
|
||||
),
|
||||
]
|
||||
for role in default_roles:
|
||||
self._roles[role.id] = role
|
||||
|
||||
def create_role(self, name: str, code: str, description: str,
|
||||
permissions: List[str] = None) -> Dict[str, Any]:
|
||||
"""创建角色"""
|
||||
# 检查角色编码是否已存在
|
||||
for role in self._roles.values():
|
||||
if role.code == code:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"角色编码 '{code}' 已存在",
|
||||
"code": "DUPLICATE_CODE"
|
||||
}
|
||||
|
||||
# 创建新角色
|
||||
new_role = Role(
|
||||
id=str(uuid.uuid4()),
|
||||
name=name,
|
||||
code=code,
|
||||
description=description,
|
||||
permissions=permissions or [],
|
||||
created_at=time.strftime("%Y-%m-%dT%H:%M:%SZ"),
|
||||
updated_at=time.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
)
|
||||
|
||||
self._roles[new_role.id] = new_role
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "角色创建成功",
|
||||
"data": self._role_to_dict(new_role)
|
||||
}
|
||||
|
||||
def update_role(self, role_id: str, name: str = None,
|
||||
description: str = None, permissions: List[str] = None) -> Dict[str, Any]:
|
||||
"""更新角色"""
|
||||
if role_id not in self._roles:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"角色ID '{role_id}' 不存在",
|
||||
"code": "NOT_FOUND"
|
||||
}
|
||||
|
||||
role = self._roles[role_id]
|
||||
|
||||
if name:
|
||||
role.name = name
|
||||
if description:
|
||||
role.description = description
|
||||
if permissions is not None:
|
||||
role.permissions = permissions
|
||||
|
||||
role.updated_at = time.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "角色更新成功",
|
||||
"data": self._role_to_dict(role)
|
||||
}
|
||||
|
||||
def delete_role(self, role_id: str) -> Dict[str, Any]:
|
||||
"""删除角色"""
|
||||
if role_id not in self._roles:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"角色ID '{role_id}' 不存在",
|
||||
"code": "NOT_FOUND"
|
||||
}
|
||||
|
||||
# 不允许删除系统默认角色
|
||||
role = self._roles[role_id]
|
||||
if role.code in ["admin", "user"]:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"不能删除系统默认角色 '{role.code}'",
|
||||
"code": "FORBIDDEN"
|
||||
}
|
||||
|
||||
del self._roles[role_id]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "角色删除成功"
|
||||
}
|
||||
|
||||
def get_role(self, role_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取角色详情"""
|
||||
if role_id not in self._roles:
|
||||
return None
|
||||
return self._role_to_dict(self._roles[role_id])
|
||||
|
||||
def list_roles(self, keyword: str = None) -> Dict[str, Any]:
|
||||
"""获取角色列表"""
|
||||
roles = list(self._roles.values())
|
||||
|
||||
# 搜索过滤
|
||||
if keyword:
|
||||
keyword = keyword.lower()
|
||||
roles = [
|
||||
role for role in roles
|
||||
if keyword in role.name.lower()
|
||||
or keyword in role.code.lower()
|
||||
or keyword in role.description.lower()
|
||||
]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"list": [self._role_to_dict(role) for role in roles],
|
||||
"total": len(roles)
|
||||
}
|
||||
}
|
||||
|
||||
def assign_permissions(self, role_id: str, permissions: List[str]) -> Dict[str, Any]:
|
||||
"""分配权限"""
|
||||
if role_id not in self._roles:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"角色ID '{role_id}' 不存在",
|
||||
"code": "NOT_FOUND"
|
||||
}
|
||||
|
||||
role = self._roles[role_id]
|
||||
role.permissions = permissions
|
||||
role.updated_at = time.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "权限分配成功",
|
||||
"data": self._role_to_dict(role)
|
||||
}
|
||||
|
||||
def _role_to_dict(self, role: Role) -> Dict[str, Any]:
|
||||
"""角色对象转字典"""
|
||||
return {
|
||||
"id": role.id,
|
||||
"name": role.name,
|
||||
"code": role.code,
|
||||
"description": role.description,
|
||||
"status": role.status,
|
||||
"permissions": role.permissions,
|
||||
"createdAt": role.created_at,
|
||||
"updatedAt": role.updated_at
|
||||
}
|
||||
|
||||
|
||||
# 全局模拟服务实例
|
||||
role_mock_service = RoleMockService()
|
||||
|
||||
|
||||
# 模拟API端点
|
||||
def mock_api_create_role(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""模拟创建角色API"""
|
||||
return role_mock_service.create_role(
|
||||
name=data.get("name", ""),
|
||||
code=data.get("code", ""),
|
||||
description=data.get("description", ""),
|
||||
permissions=data.get("permissions", [])
|
||||
)
|
||||
|
||||
|
||||
def mock_api_update_role(role_id: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""模拟更新角色API"""
|
||||
return role_mock_service.update_role(
|
||||
role_id=role_id,
|
||||
name=data.get("name"),
|
||||
description=data.get("description"),
|
||||
permissions=data.get("permissions")
|
||||
)
|
||||
|
||||
|
||||
def mock_api_delete_role(role_id: str) -> Dict[str, Any]:
|
||||
"""模拟删除角色API"""
|
||||
return role_mock_service.delete_role(role_id)
|
||||
|
||||
|
||||
def mock_api_list_roles(keyword: str = None) -> Dict[str, Any]:
|
||||
"""模拟获取角色列表API"""
|
||||
return role_mock_service.list_roles(keyword)
|
||||
|
||||
|
||||
def mock_api_assign_permissions(role_id: str, permissions: List[str]) -> Dict[str, Any]:
|
||||
"""模拟分配权限API"""
|
||||
return role_mock_service.assign_permissions(role_id, permissions)
|
||||
@@ -0,0 +1,427 @@
|
||||
"""
|
||||
审计日志模块
|
||||
|
||||
提供操作日志记录和JaVers风格的对象变更审计功能。
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
import functools
|
||||
from typing import Any, Dict, List, Optional, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
@dataclass
|
||||
class OperationLogEntry:
|
||||
"""操作日志条目"""
|
||||
id: str
|
||||
operation_time: datetime
|
||||
module_name: str
|
||||
operation_desc: str
|
||||
operator: str
|
||||
operator_id: Optional[int] = None
|
||||
request_method: Optional[str] = None
|
||||
request_path: Optional[str] = None
|
||||
request_params: Optional[str] = None
|
||||
response_result: Optional[str] = None
|
||||
ip_address: Optional[str] = None
|
||||
execution_time: Optional[int] = None # 执行时间(毫秒)
|
||||
status: str = "SUCCESS"
|
||||
exception_message: Optional[str] = None
|
||||
diff_json: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ObjectChange:
|
||||
"""对象变更记录"""
|
||||
field_name: str
|
||||
old_value: Any
|
||||
new_value: Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiffResult:
|
||||
"""差异比较结果"""
|
||||
has_changes: bool
|
||||
changes: List[ObjectChange]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuditStatistics:
|
||||
"""审计统计信息"""
|
||||
total_operations: int = 0
|
||||
success_count: int = 0
|
||||
failure_count: int = 0
|
||||
module_distribution: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
class AuditLogStorage(ABC):
|
||||
"""审计日志存储抽象基类"""
|
||||
|
||||
@abstractmethod
|
||||
def save(self, log_entry: Dict[str, Any]) -> None:
|
||||
"""保存日志条目"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query(
|
||||
self,
|
||||
module_name: Optional[str] = None,
|
||||
operator: Optional[str] = None,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
status: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""查询日志"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有日志"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_old(self, max_keep: int) -> int:
|
||||
"""删除旧日志"""
|
||||
pass
|
||||
|
||||
|
||||
class MemoryAuditStorage(AuditLogStorage):
|
||||
"""内存审计日志存储"""
|
||||
|
||||
def __init__(self):
|
||||
self._logs: List[Dict[str, Any]] = []
|
||||
|
||||
def save(self, log_entry: Dict[str, Any]) -> None:
|
||||
self._logs.append(log_entry)
|
||||
|
||||
def query(
|
||||
self,
|
||||
module_name: Optional[str] = None,
|
||||
operator: Optional[str] = None,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
status: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
result = self._logs
|
||||
|
||||
if module_name:
|
||||
result = [log for log in result if log.get("module_name") == module_name]
|
||||
|
||||
if operator:
|
||||
result = [log for log in result if log.get("operator") == operator]
|
||||
|
||||
if start_time:
|
||||
result = [log for log in result if log.get("timestamp", 0) >= start_time]
|
||||
|
||||
if end_time:
|
||||
result = [log for log in result if log.get("timestamp", 0) <= end_time]
|
||||
|
||||
if status:
|
||||
result = [log for log in result if log.get("status") == status]
|
||||
|
||||
return result
|
||||
|
||||
def get_all(self) -> List[Dict[str, Any]]:
|
||||
return self._logs.copy()
|
||||
|
||||
def delete_old(self, max_keep: int) -> int:
|
||||
if len(self._logs) <= max_keep:
|
||||
return 0
|
||||
|
||||
# 按时间排序,保留最新的
|
||||
sorted_logs = sorted(self._logs, key=lambda x: x.get("timestamp", 0), reverse=True)
|
||||
self._logs = sorted_logs[:max_keep]
|
||||
deleted_count = len(sorted_logs) - len(self._logs)
|
||||
return deleted_count
|
||||
|
||||
|
||||
class OperationLogRecorder:
|
||||
"""操作日志记录器"""
|
||||
|
||||
def __init__(self, storage: Optional[AuditLogStorage] = None):
|
||||
self._storage = storage or MemoryAuditStorage()
|
||||
|
||||
def record(
|
||||
self,
|
||||
module_name: str,
|
||||
operation_desc: str,
|
||||
operator: str,
|
||||
operator_id: Optional[int] = None,
|
||||
request_method: Optional[str] = None,
|
||||
request_path: Optional[str] = None,
|
||||
request_params: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
execution_time: Optional[int] = None,
|
||||
status: str = "SUCCESS",
|
||||
exception_message: Optional[str] = None,
|
||||
diff_json: Optional[str] = None
|
||||
) -> OperationLogEntry:
|
||||
"""记录操作日志"""
|
||||
entry = OperationLogEntry(
|
||||
id=str(uuid.uuid4()),
|
||||
operation_time=datetime.now(),
|
||||
module_name=module_name,
|
||||
operation_desc=operation_desc,
|
||||
operator=operator,
|
||||
operator_id=operator_id,
|
||||
request_method=request_method,
|
||||
request_path=request_path,
|
||||
request_params=request_params,
|
||||
ip_address=ip_address,
|
||||
execution_time=execution_time,
|
||||
status=status,
|
||||
exception_message=exception_message,
|
||||
diff_json=diff_json
|
||||
)
|
||||
|
||||
# 转换为字典并保存
|
||||
log_dict = {
|
||||
"id": entry.id,
|
||||
"timestamp": time.time(),
|
||||
"module_name": entry.module_name,
|
||||
"operation_desc": entry.operation_desc,
|
||||
"operator": entry.operator,
|
||||
"operator_id": entry.operator_id,
|
||||
"request_method": entry.request_method,
|
||||
"request_path": entry.request_path,
|
||||
"request_params": entry.request_params,
|
||||
"ip_address": entry.ip_address,
|
||||
"execution_time": entry.execution_time,
|
||||
"status": entry.status,
|
||||
"exception_message": entry.exception_message,
|
||||
"diff_json": entry.diff_json,
|
||||
}
|
||||
|
||||
self._storage.save(log_dict)
|
||||
return entry
|
||||
|
||||
def query_logs(
|
||||
self,
|
||||
module_name: Optional[str] = None,
|
||||
operator: Optional[str] = None,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
status: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""查询操作日志"""
|
||||
return self._storage.query(
|
||||
module_name=module_name,
|
||||
operator=operator,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
status=status
|
||||
)
|
||||
|
||||
def get_statistics(self) -> AuditStatistics:
|
||||
"""获取统计信息"""
|
||||
logs = self._storage.get_all()
|
||||
|
||||
stats = AuditStatistics()
|
||||
stats.total_operations = len(logs)
|
||||
|
||||
for log in logs:
|
||||
status = log.get("status", "SUCCESS")
|
||||
if status == "SUCCESS":
|
||||
stats.success_count += 1
|
||||
else:
|
||||
stats.failure_count += 1
|
||||
|
||||
module = log.get("module_name", "unknown")
|
||||
stats.module_distribution[module] = stats.module_distribution.get(module, 0) + 1
|
||||
|
||||
return stats
|
||||
|
||||
def cleanup(self, max_keep: int = 1000) -> int:
|
||||
"""清理旧日志"""
|
||||
return self._storage.delete_old(max_keep)
|
||||
|
||||
|
||||
class ObjectChangeAuditor:
|
||||
"""对象变更审计器(JaVers风格)"""
|
||||
|
||||
def compare(self, old_object: Dict[str, Any], new_object: Dict[str, Any]) -> DiffResult:
|
||||
"""
|
||||
比较两个对象的差异
|
||||
|
||||
Args:
|
||||
old_object: 旧对象
|
||||
new_object: 新对象
|
||||
|
||||
Returns:
|
||||
差异结果
|
||||
"""
|
||||
changes = []
|
||||
|
||||
# 获取所有字段
|
||||
all_keys = set(old_object.keys()) | set(new_object.keys())
|
||||
|
||||
for key in all_keys:
|
||||
old_value = old_object.get(key)
|
||||
new_value = new_object.get(key)
|
||||
|
||||
if old_value != new_value:
|
||||
changes.append(ObjectChange(
|
||||
field_name=key,
|
||||
old_value=old_value,
|
||||
new_value=new_value
|
||||
))
|
||||
|
||||
return DiffResult(
|
||||
has_changes=len(changes) > 0,
|
||||
changes=changes
|
||||
)
|
||||
|
||||
def get_changed_fields(
|
||||
self,
|
||||
old_object: Dict[str, Any],
|
||||
new_object: Dict[str, Any]
|
||||
) -> List[ObjectChange]:
|
||||
"""获取变更的字段列表"""
|
||||
diff_result = self.compare(old_object, new_object)
|
||||
return diff_result.changes
|
||||
|
||||
def to_json(self, obj: Any) -> str:
|
||||
"""将对象转换为JSON字符串"""
|
||||
return json.dumps(obj, ensure_ascii=False, default=str)
|
||||
|
||||
|
||||
class AuditLogExporter:
|
||||
"""审计日志导出器"""
|
||||
|
||||
def __init__(self, recorder: OperationLogRecorder):
|
||||
self._recorder = recorder
|
||||
|
||||
def export_to_json(
|
||||
self,
|
||||
module_name: Optional[str] = None,
|
||||
operator: Optional[str] = None
|
||||
) -> str:
|
||||
"""导出为JSON格式"""
|
||||
logs = self._recorder.query_logs(
|
||||
module_name=module_name,
|
||||
operator=operator
|
||||
)
|
||||
return json.dumps(logs, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
def export_to_csv(
|
||||
self,
|
||||
module_name: Optional[str] = None,
|
||||
operator: Optional[str] = None
|
||||
) -> str:
|
||||
"""导出为CSV格式"""
|
||||
logs = self._recorder.query_logs(
|
||||
module_name=module_name,
|
||||
operator=operator
|
||||
)
|
||||
|
||||
if not logs:
|
||||
return ""
|
||||
|
||||
# 获取表头
|
||||
headers = ["timestamp", "module_name", "operation_desc", "operator", "status"]
|
||||
|
||||
# 生成CSV
|
||||
lines = [",".join(headers)]
|
||||
for log in logs:
|
||||
values = [
|
||||
str(log.get(h, "")) for h in headers
|
||||
]
|
||||
lines.append(",".join(values))
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class AuditLogRecorder:
|
||||
"""统一的审计日志记录器"""
|
||||
|
||||
def __init__(self, storage: Optional[AuditLogStorage] = None):
|
||||
self._operation_recorder = OperationLogRecorder(storage)
|
||||
self._change_auditor = ObjectChangeAuditor()
|
||||
|
||||
def record_operation(self, **kwargs) -> OperationLogEntry:
|
||||
"""记录操作日志"""
|
||||
return self._operation_recorder.record(**kwargs)
|
||||
|
||||
def query_logs(self, **kwargs) -> List[Dict[str, Any]]:
|
||||
"""查询日志"""
|
||||
return self._operation_recorder.query_logs(**kwargs)
|
||||
|
||||
def record_change(
|
||||
self,
|
||||
old_object: Dict[str, Any],
|
||||
new_object: Dict[str, Any],
|
||||
**kwargs
|
||||
) -> OperationLogEntry:
|
||||
"""记录对象变更"""
|
||||
# 比较差异
|
||||
diff_result = self._change_auditor.compare(old_object, new_object)
|
||||
|
||||
# 生成差异JSON
|
||||
diff_json = json.dumps([
|
||||
{
|
||||
"field": c.field_name,
|
||||
"old": c.old_value,
|
||||
"new": c.new_value
|
||||
}
|
||||
for c in diff_result.changes
|
||||
], ensure_ascii=False)
|
||||
|
||||
# 记录操作日志
|
||||
return self._operation_recorder.record(
|
||||
diff_json=diff_json,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def query_logs(self, **kwargs) -> List[Dict[str, Any]]:
|
||||
"""查询日志"""
|
||||
return self._operation_recorder.query_logs(**kwargs)
|
||||
|
||||
|
||||
def audit_log(
|
||||
recorder: AuditLogRecorder,
|
||||
module_name: str,
|
||||
operation_desc: str
|
||||
):
|
||||
"""
|
||||
审计日志装饰器
|
||||
|
||||
Args:
|
||||
recorder: 审计日志记录器
|
||||
module_name: 模块名称
|
||||
operation_desc: 操作描述
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
start_time = time.time()
|
||||
status = "SUCCESS"
|
||||
exception_msg = None
|
||||
result = None
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
except Exception as e:
|
||||
status = "FAILURE"
|
||||
exception_msg = str(e)
|
||||
raise
|
||||
finally:
|
||||
execution_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 记录日志
|
||||
recorder.record_operation(
|
||||
module_name=module_name,
|
||||
operation_desc=operation_desc,
|
||||
operator="system", # 可以从上下文获取
|
||||
request_params=json.dumps({"args": args, "kwargs": kwargs}, default=str),
|
||||
execution_time=execution_time,
|
||||
status=status,
|
||||
exception_message=exception_msg
|
||||
)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
@@ -0,0 +1,429 @@
|
||||
"""
|
||||
数据备份恢复功能模块
|
||||
|
||||
提供数据备份、恢复、验证和管理功能。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import gzip
|
||||
import hashlib
|
||||
import shutil
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackupResult:
|
||||
"""备份结果"""
|
||||
success: bool
|
||||
backup_id: Optional[str] = None
|
||||
backup_path: Optional[str] = None
|
||||
size: int = 0
|
||||
checksum: Optional[str] = None
|
||||
message: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class RestoreResult:
|
||||
"""恢复结果"""
|
||||
success: bool
|
||||
data: Optional[Any] = None
|
||||
message: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class VerifyResult:
|
||||
"""验证结果"""
|
||||
is_valid: bool
|
||||
message: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeleteResult:
|
||||
"""删除结果"""
|
||||
success: bool
|
||||
message: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackupInfo:
|
||||
"""备份信息"""
|
||||
backup_id: str
|
||||
backup_name: str
|
||||
description: str
|
||||
created_at: float
|
||||
size: int
|
||||
checksum: str
|
||||
is_compressed: bool
|
||||
is_incremental: bool
|
||||
base_backup_id: Optional[str] = None
|
||||
|
||||
|
||||
class BackupManager:
|
||||
"""
|
||||
备份管理器
|
||||
|
||||
特性:
|
||||
- 支持完整备份和增量备份
|
||||
- 支持压缩
|
||||
- 支持校验和验证
|
||||
- 支持备份列表和查询
|
||||
"""
|
||||
|
||||
def __init__(self, backup_dir: str):
|
||||
"""
|
||||
初始化备份管理器
|
||||
|
||||
Args:
|
||||
backup_dir: 备份目录
|
||||
"""
|
||||
self._backup_dir = backup_dir
|
||||
self._backups: Dict[str, BackupInfo] = {}
|
||||
self._data_cache: Dict[str, Any] = {}
|
||||
|
||||
# 创建备份目录
|
||||
os.makedirs(backup_dir, exist_ok=True)
|
||||
|
||||
def backup(
|
||||
self,
|
||||
data: Any,
|
||||
backup_name: str,
|
||||
description: str = "",
|
||||
compress: bool = False
|
||||
) -> BackupResult:
|
||||
"""
|
||||
创建备份
|
||||
|
||||
Args:
|
||||
data: 要备份的数据
|
||||
backup_name: 备份名称
|
||||
description: 备份描述
|
||||
compress: 是否压缩
|
||||
|
||||
Returns:
|
||||
备份结果
|
||||
"""
|
||||
try:
|
||||
backup_id = str(uuid.uuid4())
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{backup_name}_{timestamp}_{backup_id}.json"
|
||||
|
||||
if compress:
|
||||
filename += ".gz"
|
||||
|
||||
backup_path = os.path.join(self._backup_dir, filename)
|
||||
|
||||
# 序列化数据
|
||||
json_data = json.dumps(data, ensure_ascii=False, default=str)
|
||||
|
||||
# 计算校验和
|
||||
checksum = hashlib.md5(json_data.encode('utf-8')).hexdigest()
|
||||
|
||||
# 写入文件
|
||||
if compress:
|
||||
with gzip.open(backup_path, 'wt', encoding='utf-8') as f:
|
||||
f.write(json_data)
|
||||
else:
|
||||
with open(backup_path, 'w', encoding='utf-8') as f:
|
||||
f.write(json_data)
|
||||
|
||||
# 获取文件大小
|
||||
size = os.path.getsize(backup_path)
|
||||
|
||||
# 记录备份信息
|
||||
backup_info = BackupInfo(
|
||||
backup_id=backup_id,
|
||||
backup_name=backup_name,
|
||||
description=description,
|
||||
created_at=time.time(),
|
||||
size=size,
|
||||
checksum=checksum,
|
||||
is_compressed=compress,
|
||||
is_incremental=False
|
||||
)
|
||||
self._backups[backup_id] = backup_info
|
||||
self._data_cache[backup_id] = data
|
||||
|
||||
return BackupResult(
|
||||
success=True,
|
||||
backup_id=backup_id,
|
||||
backup_path=backup_path,
|
||||
size=size,
|
||||
checksum=checksum
|
||||
)
|
||||
except Exception as e:
|
||||
return BackupResult(success=False, message=str(e))
|
||||
|
||||
def backup_incremental(
|
||||
self,
|
||||
base_backup_id: str,
|
||||
data: Any,
|
||||
backup_name: str,
|
||||
description: str = ""
|
||||
) -> BackupResult:
|
||||
"""
|
||||
创建增量备份
|
||||
|
||||
Args:
|
||||
base_backup_id: 基础备份ID
|
||||
data: 要备份的数据
|
||||
backup_name: 备份名称
|
||||
description: 备份描述
|
||||
|
||||
Returns:
|
||||
备份结果
|
||||
"""
|
||||
# 简化实现:增量备份存储差异
|
||||
result = self.backup(data, backup_name, description)
|
||||
|
||||
if result.success and result.backup_id:
|
||||
# 标记为增量备份
|
||||
backup_info = self._backups.get(result.backup_id)
|
||||
if backup_info:
|
||||
backup_info.is_incremental = True
|
||||
backup_info.base_backup_id = base_backup_id
|
||||
|
||||
return result
|
||||
|
||||
def restore(self, backup_id: str) -> RestoreResult:
|
||||
"""
|
||||
恢复备份
|
||||
|
||||
Args:
|
||||
backup_id: 备份ID
|
||||
|
||||
Returns:
|
||||
恢复结果
|
||||
"""
|
||||
try:
|
||||
# 检查缓存
|
||||
if backup_id in self._data_cache:
|
||||
return RestoreResult(
|
||||
success=True,
|
||||
data=self._data_cache[backup_id]
|
||||
)
|
||||
|
||||
# 查找备份文件
|
||||
backup_info = self._backups.get(backup_id)
|
||||
if not backup_info:
|
||||
return RestoreResult(success=False, message="备份不存在")
|
||||
|
||||
# 查找文件
|
||||
backup_path = None
|
||||
for filename in os.listdir(self._backup_dir):
|
||||
if backup_id in filename:
|
||||
backup_path = os.path.join(self._backup_dir, filename)
|
||||
break
|
||||
|
||||
if not backup_path or not os.path.exists(backup_path):
|
||||
return RestoreResult(success=False, message="备份文件不存在")
|
||||
|
||||
# 读取数据
|
||||
if backup_info.is_compressed or backup_path.endswith('.gz'):
|
||||
with gzip.open(backup_path, 'rt', encoding='utf-8') as f:
|
||||
json_data = f.read()
|
||||
else:
|
||||
with open(backup_path, 'r', encoding='utf-8') as f:
|
||||
json_data = f.read()
|
||||
|
||||
# 解析数据
|
||||
data = json.loads(json_data)
|
||||
|
||||
return RestoreResult(success=True, data=data)
|
||||
except Exception as e:
|
||||
return RestoreResult(success=False, message=str(e))
|
||||
|
||||
def list_backups(
|
||||
self,
|
||||
name_filter: Optional[str] = None,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None
|
||||
) -> List[BackupInfo]:
|
||||
"""
|
||||
列出备份
|
||||
|
||||
Args:
|
||||
name_filter: 名称过滤
|
||||
start_time: 开始时间
|
||||
end_time: 结束时间
|
||||
|
||||
Returns:
|
||||
备份信息列表
|
||||
"""
|
||||
result = list(self._backups.values())
|
||||
|
||||
if name_filter:
|
||||
result = [b for b in result if name_filter in b.backup_name]
|
||||
|
||||
if start_time:
|
||||
result = [b for b in result if b.created_at >= start_time]
|
||||
|
||||
if end_time:
|
||||
result = [b for b in result if b.created_at <= end_time]
|
||||
|
||||
# 按时间排序
|
||||
result.sort(key=lambda x: x.created_at, reverse=True)
|
||||
|
||||
return result
|
||||
|
||||
def verify_backup(self, backup_id: str) -> VerifyResult:
|
||||
"""
|
||||
验证备份
|
||||
|
||||
Args:
|
||||
backup_id: 备份ID
|
||||
|
||||
Returns:
|
||||
验证结果
|
||||
"""
|
||||
try:
|
||||
backup_info = self._backups.get(backup_id)
|
||||
if not backup_info:
|
||||
return VerifyResult(is_valid=False, message="备份不存在")
|
||||
|
||||
# 查找文件
|
||||
backup_path = None
|
||||
for filename in os.listdir(self._backup_dir):
|
||||
if backup_id in filename:
|
||||
backup_path = os.path.join(self._backup_dir, filename)
|
||||
break
|
||||
|
||||
if not backup_path or not os.path.exists(backup_path):
|
||||
return VerifyResult(is_valid=False, message="备份文件不存在")
|
||||
|
||||
# 读取并验证校验和
|
||||
if backup_info.is_compressed or backup_path.endswith('.gz'):
|
||||
with gzip.open(backup_path, 'rt', encoding='utf-8') as f:
|
||||
json_data = f.read()
|
||||
else:
|
||||
with open(backup_path, 'r', encoding='utf-8') as f:
|
||||
json_data = f.read()
|
||||
|
||||
current_checksum = hashlib.md5(json_data.encode('utf-8')).hexdigest()
|
||||
|
||||
if current_checksum != backup_info.checksum:
|
||||
return VerifyResult(is_valid=False, message="校验和不匹配")
|
||||
|
||||
return VerifyResult(is_valid=True, message="备份有效")
|
||||
except Exception as e:
|
||||
return VerifyResult(is_valid=False, message=str(e))
|
||||
|
||||
def delete_backup(self, backup_id: str) -> DeleteResult:
|
||||
"""
|
||||
删除备份
|
||||
|
||||
Args:
|
||||
backup_id: 备份ID
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
# 查找并删除文件
|
||||
for filename in os.listdir(self._backup_dir):
|
||||
if backup_id in filename:
|
||||
file_path = os.path.join(self._backup_dir, filename)
|
||||
if os.path.exists(file_path):
|
||||
os.unlink(file_path)
|
||||
break
|
||||
|
||||
# 从记录中移除
|
||||
if backup_id in self._backups:
|
||||
del self._backups[backup_id]
|
||||
|
||||
if backup_id in self._data_cache:
|
||||
del self._data_cache[backup_id]
|
||||
|
||||
return DeleteResult(success=True)
|
||||
except Exception as e:
|
||||
return DeleteResult(success=False, message=str(e))
|
||||
|
||||
|
||||
class BackupScheduler:
|
||||
"""
|
||||
备份调度器
|
||||
|
||||
支持定时自动备份
|
||||
"""
|
||||
|
||||
def __init__(self, backup_dir: str):
|
||||
"""
|
||||
初始化备份调度器
|
||||
|
||||
Args:
|
||||
backup_dir: 备份目录
|
||||
"""
|
||||
self._backup_manager = BackupManager(backup_dir)
|
||||
self._schedules: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def schedule_backup(
|
||||
self,
|
||||
data_source: Callable[[], Any],
|
||||
backup_name: str,
|
||||
interval_hours: int = 24,
|
||||
keep_count: int = 5,
|
||||
description: str = ""
|
||||
) -> None:
|
||||
"""
|
||||
配置自动备份
|
||||
|
||||
Args:
|
||||
data_source: 数据源函数
|
||||
backup_name: 备份名称
|
||||
interval_hours: 备份间隔(小时)
|
||||
keep_count: 保留数量
|
||||
description: 备份描述
|
||||
"""
|
||||
self._schedules[backup_name] = {
|
||||
"data_source": data_source,
|
||||
"interval_hours": interval_hours,
|
||||
"keep_count": keep_count,
|
||||
"description": description,
|
||||
"last_backup_time": 0
|
||||
}
|
||||
|
||||
def trigger_backup(self, backup_name: str) -> BackupResult:
|
||||
"""
|
||||
手动触发备份
|
||||
|
||||
Args:
|
||||
backup_name: 备份名称
|
||||
|
||||
Returns:
|
||||
备份结果
|
||||
"""
|
||||
schedule = self._schedules.get(backup_name)
|
||||
if not schedule:
|
||||
return BackupResult(success=False, message="备份计划不存在")
|
||||
|
||||
# 获取数据
|
||||
data = schedule["data_source"]()
|
||||
|
||||
# 执行备份
|
||||
result = self._backup_manager.backup(
|
||||
data=data,
|
||||
backup_name=backup_name,
|
||||
description=schedule["description"]
|
||||
)
|
||||
|
||||
if result.success:
|
||||
schedule["last_backup_time"] = time.time()
|
||||
|
||||
# 清理旧备份
|
||||
self._cleanup_old_backups(backup_name, schedule["keep_count"])
|
||||
|
||||
return result
|
||||
|
||||
def _cleanup_old_backups(self, backup_name: str, keep_count: int) -> None:
|
||||
"""清理旧备份"""
|
||||
backups = self._backup_manager.list_backups(name_filter=backup_name)
|
||||
|
||||
if len(backups) > keep_count:
|
||||
# 删除最旧的备份
|
||||
for backup in backups[keep_count:]:
|
||||
self._backup_manager.delete_backup(backup.backup_id)
|
||||
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
缓存模块
|
||||
|
||||
提供内存缓存功能。
|
||||
"""
|
||||
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry:
|
||||
"""缓存条目"""
|
||||
value: Any
|
||||
expires_at: Optional[float] = None
|
||||
created_at: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
class Cache:
|
||||
"""缓存类"""
|
||||
|
||||
def __init__(self):
|
||||
self._data: Dict[str, CacheEntry] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
|
||||
"""
|
||||
设置缓存
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
value: 缓存值
|
||||
ttl: 过期时间(秒)
|
||||
"""
|
||||
with self._lock:
|
||||
expires_at = None
|
||||
if ttl is not None:
|
||||
expires_at = time.time() + ttl
|
||||
|
||||
self._data[key] = CacheEntry(
|
||||
value=value,
|
||||
expires_at=expires_at
|
||||
)
|
||||
|
||||
def get(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
获取缓存
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
缓存值,如果不存在或已过期则返回None
|
||||
"""
|
||||
with self._lock:
|
||||
entry = self._data.get(key)
|
||||
if entry is None:
|
||||
return None
|
||||
|
||||
# 检查是否过期
|
||||
if entry.expires_at is not None and time.time() > entry.expires_at:
|
||||
del self._data[key]
|
||||
return None
|
||||
|
||||
return entry.value
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
"""
|
||||
删除缓存
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
是否成功删除
|
||||
"""
|
||||
with self._lock:
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清除所有缓存"""
|
||||
with self._lock:
|
||||
self._data.clear()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取缓存统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
with self._lock:
|
||||
# 清理过期条目
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
key for key, entry in self._data.items()
|
||||
if entry.expires_at is not None and current_time > entry.expires_at
|
||||
]
|
||||
for key in expired_keys:
|
||||
del self._data[key]
|
||||
|
||||
return {
|
||||
"size": len(self._data),
|
||||
"keys": list(self._data.keys())
|
||||
}
|
||||
|
||||
def has(self, key: str) -> bool:
|
||||
"""
|
||||
检查缓存是否存在
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
是否存在
|
||||
"""
|
||||
return self.get(key) is not None
|
||||
|
||||
def get_all(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取所有缓存
|
||||
|
||||
Returns:
|
||||
所有缓存数据
|
||||
"""
|
||||
with self._lock:
|
||||
result = {}
|
||||
current_time = time.time()
|
||||
|
||||
for key, entry in self._data.items():
|
||||
# 检查是否过期
|
||||
if entry.expires_at is not None and current_time > entry.expires_at:
|
||||
continue
|
||||
result[key] = entry.value
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# 全局缓存实例
|
||||
cache = Cache()
|
||||
@@ -0,0 +1,396 @@
|
||||
"""
|
||||
Caffeine缓存管理模块
|
||||
|
||||
基于Caffeine的本地缓存管理实现。
|
||||
"""
|
||||
|
||||
import time
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry:
|
||||
"""缓存条目"""
|
||||
value: Any
|
||||
expires_at: Optional[float] = None
|
||||
created_at: float = field(default_factory=time.time)
|
||||
access_count: int = field(default=0)
|
||||
last_accessed: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
class CaffeineCache:
|
||||
"""
|
||||
Caffeine风格的本地缓存实现
|
||||
|
||||
特性:
|
||||
- 支持TTL过期时间
|
||||
- 支持最大容量限制(LRU淘汰)
|
||||
- 支持统计信息
|
||||
- 线程安全
|
||||
- 批量操作支持
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_size: int = 1000,
|
||||
default_expire_seconds: Optional[int] = None,
|
||||
record_stats: bool = False
|
||||
):
|
||||
"""
|
||||
初始化缓存
|
||||
|
||||
Args:
|
||||
max_size: 最大缓存条目数
|
||||
default_expire_seconds: 默认过期时间(秒)
|
||||
record_stats: 是否记录统计信息
|
||||
"""
|
||||
self._max_size = max_size
|
||||
self._default_expire_seconds = default_expire_seconds
|
||||
self._record_stats = record_stats
|
||||
self._data: Dict[str, CacheEntry] = {}
|
||||
self._lock = threading.RLock()
|
||||
|
||||
# 统计信息
|
||||
self._stats = {
|
||||
"hit_count": 0,
|
||||
"miss_count": 0,
|
||||
"put_count": 0,
|
||||
"delete_count": 0,
|
||||
"eviction_count": 0,
|
||||
}
|
||||
|
||||
def put(self, key: str, value: Any, expire_seconds: Optional[int] = None) -> None:
|
||||
"""
|
||||
添加缓存条目
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
value: 缓存值
|
||||
expire_seconds: 过期时间(秒),None表示使用默认值,0表示永不过期
|
||||
"""
|
||||
with self._lock:
|
||||
# 计算过期时间
|
||||
if expire_seconds is not None:
|
||||
expires_at = time.time() + expire_seconds if expire_seconds > 0 else None
|
||||
elif self._default_expire_seconds is not None:
|
||||
expires_at = time.time() + self._default_expire_seconds
|
||||
else:
|
||||
expires_at = None
|
||||
|
||||
# 创建缓存条目
|
||||
entry = CacheEntry(
|
||||
value=value,
|
||||
expires_at=expires_at
|
||||
)
|
||||
|
||||
# 检查是否需要淘汰
|
||||
if key not in self._data and len(self._data) >= self._max_size:
|
||||
self._evict_oldest()
|
||||
|
||||
# 存储数据
|
||||
self._data[key] = entry
|
||||
|
||||
if self._record_stats:
|
||||
self._stats["put_count"] += 1
|
||||
|
||||
def get(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
获取缓存值
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
缓存值,不存在或已过期则返回None
|
||||
"""
|
||||
with self._lock:
|
||||
entry = self._data.get(key)
|
||||
|
||||
if entry is None:
|
||||
if self._record_stats:
|
||||
self._stats["miss_count"] += 1
|
||||
return None
|
||||
|
||||
# 检查是否过期
|
||||
if entry.expires_at is not None and time.time() > entry.expires_at:
|
||||
del self._data[key]
|
||||
if self._record_stats:
|
||||
self._stats["miss_count"] += 1
|
||||
self._stats["eviction_count"] += 1
|
||||
return None
|
||||
|
||||
# 更新访问信息
|
||||
entry.access_count += 1
|
||||
entry.last_accessed = time.time()
|
||||
|
||||
if self._record_stats:
|
||||
self._stats["hit_count"] += 1
|
||||
|
||||
return entry.value
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
检查键是否存在
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
是否存在且未过期
|
||||
"""
|
||||
with self._lock:
|
||||
entry = self._data.get(key)
|
||||
|
||||
if entry is None:
|
||||
return False
|
||||
|
||||
# 检查是否过期
|
||||
if entry.expires_at is not None and time.time() > entry.expires_at:
|
||||
del self._data[key]
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
"""
|
||||
删除缓存条目
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
是否成功删除
|
||||
"""
|
||||
with self._lock:
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
if self._record_stats:
|
||||
self._stats["delete_count"] += 1
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_all(self, keys: List[str]) -> Dict[str, Optional[Any]]:
|
||||
"""
|
||||
批量获取缓存值
|
||||
|
||||
Args:
|
||||
keys: 缓存键列表
|
||||
|
||||
Returns:
|
||||
键值对字典
|
||||
"""
|
||||
result = {}
|
||||
for key in keys:
|
||||
result[key] = self.get(key)
|
||||
return result
|
||||
|
||||
def put_all(self, data: Dict[str, Any], expire_seconds: Optional[int] = None) -> None:
|
||||
"""
|
||||
批量添加缓存条目
|
||||
|
||||
Args:
|
||||
data: 键值对字典
|
||||
expire_seconds: 过期时间(秒)
|
||||
"""
|
||||
for key, value in data.items():
|
||||
self.put(key, value, expire_seconds)
|
||||
|
||||
def delete_all(self, keys: List[str]) -> int:
|
||||
"""
|
||||
批量删除缓存条目
|
||||
|
||||
Args:
|
||||
keys: 缓存键列表
|
||||
|
||||
Returns:
|
||||
成功删除的数量
|
||||
"""
|
||||
count = 0
|
||||
for key in keys:
|
||||
if self.delete(key):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空所有缓存"""
|
||||
with self._lock:
|
||||
self._data.clear()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取缓存统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
with self._lock:
|
||||
stats = self._stats.copy()
|
||||
stats["size"] = len(self._data)
|
||||
stats["max_size"] = self._max_size
|
||||
|
||||
# 计算命中率
|
||||
total_requests = stats["hit_count"] + stats["miss_count"]
|
||||
if total_requests > 0:
|
||||
stats["hit_rate"] = stats["hit_count"] / total_requests
|
||||
else:
|
||||
stats["hit_rate"] = 0.0
|
||||
|
||||
return stats
|
||||
|
||||
def _evict_oldest(self) -> None:
|
||||
"""淘汰最久未使用的条目"""
|
||||
if not self._data:
|
||||
return
|
||||
|
||||
# 找到最久未访问的条目
|
||||
oldest_key = min(
|
||||
self._data.keys(),
|
||||
key=lambda k: self._data[k].last_accessed
|
||||
)
|
||||
|
||||
del self._data[oldest_key]
|
||||
if self._record_stats:
|
||||
self._stats["eviction_count"] += 1
|
||||
|
||||
def size(self) -> int:
|
||||
"""
|
||||
获取当前缓存大小
|
||||
|
||||
Returns:
|
||||
缓存条目数
|
||||
"""
|
||||
with self._lock:
|
||||
return len(self._data)
|
||||
|
||||
def keys(self) -> List[str]:
|
||||
"""
|
||||
获取所有缓存键
|
||||
|
||||
Returns:
|
||||
缓存键列表
|
||||
"""
|
||||
with self._lock:
|
||||
return list(self._data.keys())
|
||||
|
||||
def values(self) -> List[Any]:
|
||||
"""
|
||||
获取所有缓存值
|
||||
|
||||
Returns:
|
||||
缓存值列表
|
||||
"""
|
||||
with self._lock:
|
||||
return [entry.value for entry in self._data.values()]
|
||||
|
||||
def items(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取所有缓存项
|
||||
|
||||
Returns:
|
||||
键值对字典
|
||||
"""
|
||||
with self._lock:
|
||||
return {k: v.value for k, v in self._data.items()}
|
||||
|
||||
def cleanup_expired(self) -> int:
|
||||
"""
|
||||
清理过期条目
|
||||
|
||||
Returns:
|
||||
清理的条目数
|
||||
"""
|
||||
with self._lock:
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
key for key, entry in self._data.items()
|
||||
if entry.expires_at is not None and current_time > entry.expires_at
|
||||
]
|
||||
|
||||
for key in expired_keys:
|
||||
del self._data[key]
|
||||
if self._record_stats:
|
||||
self._stats["eviction_count"] += 1
|
||||
|
||||
return len(expired_keys)
|
||||
|
||||
|
||||
class CaffeineCacheManager:
|
||||
"""
|
||||
Caffeine缓存管理器
|
||||
|
||||
管理多个命名缓存实例
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._caches = {}
|
||||
return cls._instance
|
||||
|
||||
def get_cache(
|
||||
self,
|
||||
name: str,
|
||||
max_size: int = 1000,
|
||||
default_expire_seconds: Optional[int] = None,
|
||||
record_stats: bool = False
|
||||
) -> CaffeineCache:
|
||||
"""
|
||||
获取或创建命名缓存
|
||||
|
||||
Args:
|
||||
name: 缓存名称
|
||||
max_size: 最大缓存条目数
|
||||
default_expire_seconds: 默认过期时间
|
||||
record_stats: 是否记录统计信息
|
||||
|
||||
Returns:
|
||||
缓存实例
|
||||
"""
|
||||
if name not in self._caches:
|
||||
self._caches[name] = CaffeineCache(
|
||||
max_size=max_size,
|
||||
default_expire_seconds=default_expire_seconds,
|
||||
record_stats=record_stats
|
||||
)
|
||||
return self._caches[name]
|
||||
|
||||
def remove_cache(self, name: str) -> bool:
|
||||
"""
|
||||
移除缓存
|
||||
|
||||
Args:
|
||||
name: 缓存名称
|
||||
|
||||
Returns:
|
||||
是否成功移除
|
||||
"""
|
||||
if name in self._caches:
|
||||
del self._caches[name]
|
||||
return True
|
||||
return False
|
||||
|
||||
def clear_all(self) -> None:
|
||||
"""清空所有缓存"""
|
||||
for cache in self._caches.values():
|
||||
cache.clear()
|
||||
|
||||
def get_all_stats(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
获取所有缓存的统计信息
|
||||
|
||||
Returns:
|
||||
缓存名称到统计信息的映射
|
||||
"""
|
||||
return {name: cache.get_stats() for name, cache in self._caches.items()}
|
||||
|
||||
|
||||
# 全局缓存管理器实例
|
||||
cache_manager = CaffeineCacheManager()
|
||||
@@ -0,0 +1,470 @@
|
||||
"""
|
||||
本地并发控制模块
|
||||
|
||||
提供本地并发控制的各种机制,包括信号量、读写锁、限流器等。
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, TypeVar, Generic
|
||||
from contextlib import contextmanager
|
||||
from collections import deque
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class SemaphoreControl:
|
||||
"""
|
||||
信号量并发控制
|
||||
|
||||
限制同时执行的线程数量
|
||||
"""
|
||||
|
||||
def __init__(self, max_concurrent: int):
|
||||
"""
|
||||
初始化信号量
|
||||
|
||||
Args:
|
||||
max_concurrent: 最大并发数
|
||||
"""
|
||||
self._max_concurrent = max_concurrent
|
||||
self._semaphore = threading.Semaphore(max_concurrent)
|
||||
self._stats = {
|
||||
"total_acquisitions": 0,
|
||||
"active_count": 0,
|
||||
"peak_count": 0,
|
||||
}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
@contextmanager
|
||||
def acquire(self, timeout: Optional[float] = None):
|
||||
"""
|
||||
获取信号量(上下文管理器)
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒)
|
||||
"""
|
||||
acquired = False
|
||||
try:
|
||||
acquired = self._semaphore.acquire(timeout=timeout)
|
||||
if not acquired:
|
||||
raise TimeoutError("获取信号量超时")
|
||||
|
||||
with self._lock:
|
||||
self._stats["total_acquisitions"] += 1
|
||||
self._stats["active_count"] += 1
|
||||
self._stats["peak_count"] = max(
|
||||
self._stats["peak_count"],
|
||||
self._stats["active_count"]
|
||||
)
|
||||
|
||||
yield self
|
||||
finally:
|
||||
if acquired:
|
||||
with self._lock:
|
||||
self._stats["active_count"] -= 1
|
||||
self._semaphore.release()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
with self._lock:
|
||||
return self._stats.copy()
|
||||
|
||||
|
||||
class ReadWriteLock:
|
||||
"""
|
||||
读写锁
|
||||
|
||||
支持多个读线程同时访问,写线程独占访问
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._read_lock = threading.Lock()
|
||||
self._write_lock = threading.Lock()
|
||||
self._read_count = 0
|
||||
|
||||
@contextmanager
|
||||
def read_lock(self):
|
||||
"""获取读锁"""
|
||||
with self._read_lock:
|
||||
self._read_count += 1
|
||||
if self._read_count == 1:
|
||||
self._write_lock.acquire()
|
||||
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
with self._read_lock:
|
||||
self._read_count -= 1
|
||||
if self._read_count == 0:
|
||||
self._write_lock.release()
|
||||
|
||||
@contextmanager
|
||||
def write_lock(self):
|
||||
"""获取写锁"""
|
||||
self._write_lock.acquire()
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
self._write_lock.release()
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
限流器
|
||||
|
||||
限制单位时间内的请求数量
|
||||
"""
|
||||
|
||||
def __init__(self, max_requests: int, time_window: float):
|
||||
"""
|
||||
初始化限流器
|
||||
|
||||
Args:
|
||||
max_requests: 时间窗口内最大请求数
|
||||
time_window: 时间窗口(秒)
|
||||
"""
|
||||
self._max_requests = max_requests
|
||||
self._time_window = time_window
|
||||
self._requests = deque()
|
||||
self._lock = threading.Lock()
|
||||
self._stats = {
|
||||
"total_requests": 0,
|
||||
"allowed_requests": 0,
|
||||
"blocked_requests": 0,
|
||||
}
|
||||
|
||||
def allow_request(self) -> bool:
|
||||
"""
|
||||
检查是否允许请求
|
||||
|
||||
Returns:
|
||||
是否允许
|
||||
"""
|
||||
with self._lock:
|
||||
now = time.time()
|
||||
|
||||
# 清理过期的请求记录
|
||||
while self._requests and self._requests[0] < now - self._time_window:
|
||||
self._requests.popleft()
|
||||
|
||||
self._stats["total_requests"] += 1
|
||||
|
||||
if len(self._requests) < self._max_requests:
|
||||
self._requests.append(now)
|
||||
self._stats["allowed_requests"] += 1
|
||||
return True
|
||||
else:
|
||||
self._stats["blocked_requests"] += 1
|
||||
return False
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
with self._lock:
|
||||
return self._stats.copy()
|
||||
|
||||
|
||||
class LocalDistributedLock:
|
||||
"""
|
||||
本地模拟的分布式锁
|
||||
|
||||
用于单体应用内的分布式锁场景模拟
|
||||
"""
|
||||
|
||||
_locks: Dict[str, Dict[str, Any]] = {}
|
||||
_global_lock = threading.Lock()
|
||||
|
||||
def __init__(self, resource_name: str, expire_seconds: float = 30.0):
|
||||
"""
|
||||
初始化分布式锁
|
||||
|
||||
Args:
|
||||
resource_name: 资源名称
|
||||
expire_seconds: 锁过期时间(秒)
|
||||
"""
|
||||
self._resource_name = resource_name
|
||||
self._expire_seconds = expire_seconds
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def acquire(self, timeout: float = 10.0) -> bool:
|
||||
"""
|
||||
获取锁
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒)
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
with self._global_lock:
|
||||
now = time.time()
|
||||
lock_info = self._locks.get(self._resource_name)
|
||||
|
||||
# 检查锁是否过期
|
||||
if lock_info and now > lock_info["expires_at"]:
|
||||
del self._locks[self._resource_name]
|
||||
lock_info = None
|
||||
|
||||
# 尝试获取锁
|
||||
if not lock_info:
|
||||
self._locks[self._resource_name] = {
|
||||
"expires_at": now + self._expire_seconds,
|
||||
}
|
||||
return True
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
return False
|
||||
|
||||
def release(self) -> None:
|
||||
"""释放锁"""
|
||||
with self._global_lock:
|
||||
if self._resource_name in self._locks:
|
||||
del self._locks[self._resource_name]
|
||||
|
||||
|
||||
class ConcurrentCounter:
|
||||
"""
|
||||
并发计数器
|
||||
|
||||
线程安全的计数器实现
|
||||
"""
|
||||
|
||||
def __init__(self, initial_value: int = 0):
|
||||
"""
|
||||
初始化计数器
|
||||
|
||||
Args:
|
||||
initial_value: 初始值
|
||||
"""
|
||||
self._value = initial_value
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def increment(self, delta: int = 1) -> int:
|
||||
"""
|
||||
增加计数
|
||||
|
||||
Args:
|
||||
delta: 增量
|
||||
|
||||
Returns:
|
||||
增加后的值
|
||||
"""
|
||||
with self._lock:
|
||||
self._value += delta
|
||||
return self._value
|
||||
|
||||
def decrement(self, delta: int = 1) -> int:
|
||||
"""
|
||||
减少计数
|
||||
|
||||
Args:
|
||||
delta: 减量
|
||||
|
||||
Returns:
|
||||
减少后的值
|
||||
"""
|
||||
with self._lock:
|
||||
self._value -= delta
|
||||
return self._value
|
||||
|
||||
def get_value(self) -> int:
|
||||
"""获取当前值"""
|
||||
with self._lock:
|
||||
return self._value
|
||||
|
||||
def reset(self, value: int = 0) -> None:
|
||||
"""重置计数器"""
|
||||
with self._lock:
|
||||
self._value = value
|
||||
|
||||
|
||||
class ThreadBarrier:
|
||||
"""
|
||||
线程屏障
|
||||
|
||||
等待指定数量的线程到达后同时放行
|
||||
"""
|
||||
|
||||
def __init__(self, parties: int):
|
||||
"""
|
||||
初始化屏障
|
||||
|
||||
Args:
|
||||
parties: 需要等待的线程数
|
||||
"""
|
||||
self._parties = parties
|
||||
self._count = 0
|
||||
self._lock = threading.Lock()
|
||||
self._condition = threading.Condition(self._lock)
|
||||
|
||||
def wait(self, timeout: Optional[float] = None) -> bool:
|
||||
"""
|
||||
等待屏障
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
是否成功等待
|
||||
"""
|
||||
with self._condition:
|
||||
self._count += 1
|
||||
|
||||
if self._count >= self._parties:
|
||||
# 所有线程已到达,放行
|
||||
self._count = 0
|
||||
self._condition.notify_all()
|
||||
return True
|
||||
else:
|
||||
# 等待其他线程
|
||||
return self._condition.wait(timeout=timeout)
|
||||
|
||||
|
||||
class BoundedTaskQueue(Generic[T]):
|
||||
"""
|
||||
有界任务队列
|
||||
|
||||
有容量限制的线程安全队列
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: int):
|
||||
"""
|
||||
初始化队列
|
||||
|
||||
Args:
|
||||
max_size: 最大容量
|
||||
"""
|
||||
self._max_size = max_size
|
||||
self._queue = deque(maxlen=max_size)
|
||||
self._lock = threading.Lock()
|
||||
self._not_full = threading.Condition(self._lock)
|
||||
self._not_empty = threading.Condition(self._lock)
|
||||
|
||||
def put(self, item: T, timeout: Optional[float] = None) -> bool:
|
||||
"""
|
||||
添加元素
|
||||
|
||||
Args:
|
||||
item: 元素
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
with self._not_full:
|
||||
if len(self._queue) >= self._max_size:
|
||||
if not self._not_full.wait(timeout=timeout):
|
||||
raise TimeoutError("队列已满")
|
||||
|
||||
self._queue.append(item)
|
||||
self._not_empty.notify()
|
||||
return True
|
||||
|
||||
def get(self, timeout: Optional[float] = None) -> T:
|
||||
"""
|
||||
获取元素
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
元素
|
||||
"""
|
||||
with self._not_empty:
|
||||
while len(self._queue) == 0:
|
||||
if not self._not_empty.wait(timeout=timeout):
|
||||
raise TimeoutError("队列为空")
|
||||
|
||||
item = self._queue.popleft()
|
||||
self._not_full.notify()
|
||||
return item
|
||||
|
||||
def size(self) -> int:
|
||||
"""获取队列大小"""
|
||||
with self._lock:
|
||||
return len(self._queue)
|
||||
|
||||
def is_full(self) -> bool:
|
||||
"""检查是否已满"""
|
||||
with self._lock:
|
||||
return len(self._queue) >= self._max_size
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""检查是否为空"""
|
||||
with self._lock:
|
||||
return len(self._queue) == 0
|
||||
|
||||
|
||||
class ConcurrencyManager:
|
||||
"""
|
||||
并发控制器管理器
|
||||
|
||||
单例模式管理所有并发控制组件
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._semaphores: Dict[str, SemaphoreControl] = {}
|
||||
cls._instance._rw_locks: Dict[str, ReadWriteLock] = {}
|
||||
cls._instance._rate_limiters: Dict[str, RateLimiter] = {}
|
||||
cls._instance._locks: Dict[str, LocalDistributedLock] = {}
|
||||
return cls._instance
|
||||
|
||||
def create_semaphore(self, name: str, max_concurrent: int) -> SemaphoreControl:
|
||||
"""创建命名信号量"""
|
||||
if name not in self._semaphores:
|
||||
self._semaphores[name] = SemaphoreControl(max_concurrent)
|
||||
return self._semaphores[name]
|
||||
|
||||
def get_semaphore(self, name: str) -> Optional[SemaphoreControl]:
|
||||
"""获取命名信号量"""
|
||||
return self._semaphores.get(name)
|
||||
|
||||
def create_rw_lock(self, name: str) -> ReadWriteLock:
|
||||
"""创建命名读写锁"""
|
||||
if name not in self._rw_locks:
|
||||
self._rw_locks[name] = ReadWriteLock()
|
||||
return self._rw_locks[name]
|
||||
|
||||
def get_rw_lock(self, name: str) -> Optional[ReadWriteLock]:
|
||||
"""获取命名读写锁"""
|
||||
return self._rw_locks.get(name)
|
||||
|
||||
def create_rate_limiter(self, name: str, max_requests: int, time_window: float) -> RateLimiter:
|
||||
"""创建命名限流器"""
|
||||
if name not in self._rate_limiters:
|
||||
self._rate_limiters[name] = RateLimiter(max_requests, time_window)
|
||||
return self._rate_limiters[name]
|
||||
|
||||
def get_rate_limiter(self, name: str) -> Optional[RateLimiter]:
|
||||
"""获取命名限流器"""
|
||||
return self._rate_limiters.get(name)
|
||||
|
||||
def create_lock(self, name: str, expire_seconds: float = 30.0) -> LocalDistributedLock:
|
||||
"""创建命名分布式锁"""
|
||||
if name not in self._locks:
|
||||
self._locks[name] = LocalDistributedLock(name, expire_seconds)
|
||||
return self._locks[name]
|
||||
|
||||
def get_lock(self, name: str) -> Optional[LocalDistributedLock]:
|
||||
"""获取命名分布式锁"""
|
||||
return self._locks.get(name)
|
||||
|
||||
def get_all_stats(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""获取所有组件统计信息"""
|
||||
return {
|
||||
"semaphores": {name: sem.get_stats() for name, sem in self._semaphores.items()},
|
||||
"rate_limiters": {name: rl.get_stats() for name, rl in self._rate_limiters.items()},
|
||||
}
|
||||
|
||||
|
||||
# 全局并发管理器实例
|
||||
concurrency_manager = ConcurrencyManager()
|
||||
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
配置管理器模块
|
||||
|
||||
提供统一的测试环境配置管理,支持多环境配置。
|
||||
"""
|
||||
|
||||
import os
|
||||
import yaml
|
||||
from typing import Dict, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimeoutConfig:
|
||||
"""超时配置"""
|
||||
default: int = 30000
|
||||
navigation: int = 30000
|
||||
element: int = 10000
|
||||
network: int = 30000
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScreenshotConfig:
|
||||
"""截图配置"""
|
||||
enabled: bool = True
|
||||
on_failure: bool = True
|
||||
path: str = "reports/screenshots"
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoConfig:
|
||||
"""录像配置"""
|
||||
enabled: bool = False
|
||||
on_failure: bool = True
|
||||
path: str = "reports/videos"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TraceConfig:
|
||||
"""追踪配置"""
|
||||
enabled: bool = False
|
||||
on_failure: bool = True
|
||||
path: str = "reports/traces"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrowserConfig:
|
||||
"""浏览器配置"""
|
||||
name: str = "chromium"
|
||||
headless: bool = False
|
||||
viewport_width: int = 1920
|
||||
viewport_height: int = 1080
|
||||
slow_mo: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestConfig:
|
||||
"""测试配置"""
|
||||
# 环境信息
|
||||
env: str = "dev"
|
||||
base_url: str = ""
|
||||
api_url: str = ""
|
||||
|
||||
# 超时配置
|
||||
timeout: TimeoutConfig = field(default_factory=TimeoutConfig)
|
||||
|
||||
# 截图配置
|
||||
screenshot: ScreenshotConfig = field(default_factory=ScreenshotConfig)
|
||||
|
||||
# 录像配置
|
||||
video: VideoConfig = field(default_factory=VideoConfig)
|
||||
|
||||
# 追踪配置
|
||||
trace: TraceConfig = field(default_factory=TraceConfig)
|
||||
|
||||
# 浏览器配置
|
||||
browser: BrowserConfig = field(default_factory=BrowserConfig)
|
||||
|
||||
# 重试配置
|
||||
retries: int = 2
|
||||
|
||||
# 并行配置
|
||||
workers: int = 1
|
||||
parallel: bool = False
|
||||
|
||||
# 测试数据
|
||||
test_data: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# 用户数据
|
||||
users: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""配置管理器"""
|
||||
|
||||
_instance: Optional["ConfigManager"] = None
|
||||
_config: Optional[TestConfig] = None
|
||||
|
||||
def __new__(cls) -> "ConfigManager":
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._config is None:
|
||||
self._config = self._load_config()
|
||||
|
||||
def _load_config(self) -> TestConfig:
|
||||
"""加载配置"""
|
||||
# 获取环境变量
|
||||
env = os.getenv("TEST_ENV", "dev")
|
||||
|
||||
# 配置文件路径
|
||||
config_path = Path(__file__).parent.parent / "config" / "config.yaml"
|
||||
|
||||
# 默认配置
|
||||
config = TestConfig(env=env)
|
||||
|
||||
# 从文件加载配置
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
if data and "environments" in data:
|
||||
env_config = data["environments"].get(env, {})
|
||||
|
||||
# 基础配置 - 支持嵌套的admin/uniapp配置
|
||||
if "admin" in env_config:
|
||||
config.base_url = env_config["admin"].get("base_url", "")
|
||||
config.api_url = env_config["admin"].get("api_url", "")
|
||||
else:
|
||||
config.base_url = env_config.get("base_url", "")
|
||||
config.api_url = env_config.get("api_url", "")
|
||||
|
||||
# 超时配置
|
||||
if "timeout" in env_config:
|
||||
timeout = env_config["timeout"]
|
||||
config.timeout = TimeoutConfig(
|
||||
default=timeout.get("default", 30000),
|
||||
navigation=timeout.get("navigation", 30000),
|
||||
element=timeout.get("element", 10000),
|
||||
network=timeout.get("network", 30000),
|
||||
)
|
||||
|
||||
# 浏览器配置
|
||||
if "browser" in env_config:
|
||||
browser = env_config["browser"]
|
||||
config.browser = BrowserConfig(
|
||||
name=browser.get("name", "chromium"),
|
||||
headless=browser.get("headless", False),
|
||||
viewport_width=browser.get("viewport_width", 1920),
|
||||
viewport_height=browser.get("viewport_height", 1080),
|
||||
slow_mo=browser.get("slow_mo", 0),
|
||||
)
|
||||
|
||||
# 用户数据
|
||||
if "users" in data:
|
||||
config.users = data["users"]
|
||||
|
||||
# 测试数据
|
||||
if "test_data" in data:
|
||||
config.test_data = data["test_data"]
|
||||
|
||||
# 从环境变量覆盖配置
|
||||
config.base_url = os.getenv("TEST_BASE_URL", config.base_url)
|
||||
config.api_url = os.getenv("TEST_API_URL", config.api_url)
|
||||
config.browser.headless = os.getenv("TEST_HEADLESS", "false").lower() == "true"
|
||||
|
||||
return config
|
||||
|
||||
@property
|
||||
def config(self) -> TestConfig:
|
||||
"""获取配置"""
|
||||
return self._config
|
||||
|
||||
def get_config(self) -> TestConfig:
|
||||
"""获取配置(兼容方法)"""
|
||||
return self._config
|
||||
|
||||
def reload(self) -> None:
|
||||
"""重新加载配置"""
|
||||
self._config = self._load_config()
|
||||
|
||||
def update_config(self, **kwargs) -> None:
|
||||
"""更新配置"""
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self._config, key):
|
||||
setattr(self._config, key, value)
|
||||
|
||||
|
||||
# 全局配置管理器实例
|
||||
config_manager = ConfigManager()
|
||||
@@ -0,0 +1,452 @@
|
||||
"""
|
||||
数据库连接池管理模块
|
||||
|
||||
提供数据库连接池的创建、管理和监控功能。
|
||||
"""
|
||||
|
||||
import time
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import queue
|
||||
|
||||
|
||||
class ConnectionStatus(Enum):
|
||||
"""连接状态"""
|
||||
IDLE = "idle"
|
||||
ACTIVE = "active"
|
||||
CLOSED = "closed"
|
||||
UNHEALTHY = "unhealthy"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Connection:
|
||||
"""数据库连接封装"""
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
status: ConnectionStatus = ConnectionStatus.IDLE
|
||||
created_at: float = field(default_factory=time.time)
|
||||
last_used: float = field(default_factory=time.time)
|
||||
use_count: int = 0
|
||||
host: str = ""
|
||||
port: int = 3306
|
||||
database: str = ""
|
||||
user: str = ""
|
||||
password: str = ""
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""检查连接是否有效"""
|
||||
return self.status in [ConnectionStatus.IDLE, ConnectionStatus.ACTIVE]
|
||||
|
||||
def execute(self, query: str) -> Any:
|
||||
"""执行查询(模拟)"""
|
||||
if not self.is_valid():
|
||||
raise Exception("连接无效")
|
||||
return f"Result of: {query}"
|
||||
|
||||
def close(self) -> None:
|
||||
"""关闭连接"""
|
||||
self.status = ConnectionStatus.CLOSED
|
||||
|
||||
|
||||
class ConnectionPool:
|
||||
"""
|
||||
数据库连接池
|
||||
|
||||
特性:
|
||||
- 支持最小/最大连接数配置
|
||||
- 支持连接超时等待
|
||||
- 支持健康检查
|
||||
- 支持自动扩容
|
||||
- 线程安全
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
min_connections: int = 2,
|
||||
max_connections: int = 10,
|
||||
host: str = "localhost",
|
||||
port: int = 3306,
|
||||
database: str = "",
|
||||
user: str = "",
|
||||
password: str = "",
|
||||
connection_timeout: int = 30,
|
||||
health_check_interval: int = 60,
|
||||
auto_scale: bool = False
|
||||
):
|
||||
"""
|
||||
初始化连接池
|
||||
|
||||
Args:
|
||||
min_connections: 最小连接数
|
||||
max_connections: 最大连接数
|
||||
host: 数据库主机
|
||||
port: 数据库端口
|
||||
database: 数据库名
|
||||
user: 用户名
|
||||
password: 密码
|
||||
connection_timeout: 连接超时时间(秒)
|
||||
health_check_interval: 健康检查间隔(秒)
|
||||
auto_scale: 是否自动扩容
|
||||
"""
|
||||
self._min_connections = min_connections
|
||||
self._max_connections = max_connections
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._database = database
|
||||
self._user = user
|
||||
self._password = password
|
||||
self._connection_timeout = connection_timeout
|
||||
self._health_check_interval = health_check_interval
|
||||
self._auto_scale = auto_scale
|
||||
|
||||
# 连接池
|
||||
self._idle_connections: queue.Queue[Connection] = queue.Queue()
|
||||
self._active_connections: Dict[str, Connection] = {}
|
||||
self._all_connections: Dict[str, Connection] = {}
|
||||
|
||||
# 锁
|
||||
self._lock = threading.RLock()
|
||||
self._condition = threading.Condition(self._lock)
|
||||
|
||||
# 统计信息
|
||||
self._stats = {
|
||||
"total_get_count": 0,
|
||||
"total_release_count": 0,
|
||||
"total_wait_count": 0,
|
||||
"total_wait_time": 0.0,
|
||||
"health_check_count": 0,
|
||||
"unhealthy_count": 0,
|
||||
}
|
||||
|
||||
# 健康检查线程
|
||||
self._health_check_thread: Optional[threading.Thread] = None
|
||||
self._shutdown = False
|
||||
|
||||
# 初始化最小连接数
|
||||
self._initialize_min_connections()
|
||||
|
||||
# 启动健康检查
|
||||
if health_check_interval > 0:
|
||||
self._start_health_check()
|
||||
|
||||
def _initialize_min_connections(self) -> None:
|
||||
"""初始化最小连接数"""
|
||||
for _ in range(self._min_connections):
|
||||
conn = self._create_connection()
|
||||
self._idle_connections.put(conn)
|
||||
self._all_connections[conn.id] = conn
|
||||
|
||||
def _create_connection(self) -> Connection:
|
||||
"""创建新连接"""
|
||||
conn = Connection(
|
||||
host=self._host,
|
||||
port=self._port,
|
||||
database=self._database,
|
||||
user=self._user,
|
||||
password=self._password
|
||||
)
|
||||
return conn
|
||||
|
||||
def get_connection(self, timeout: Optional[int] = None) -> Connection:
|
||||
"""
|
||||
获取连接
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒),None表示使用默认值
|
||||
|
||||
Returns:
|
||||
数据库连接
|
||||
|
||||
Raises:
|
||||
Exception: 超时或连接池已关闭
|
||||
"""
|
||||
if timeout is None:
|
||||
timeout = self._connection_timeout
|
||||
|
||||
with self._condition:
|
||||
self._stats["total_get_count"] += 1
|
||||
|
||||
# 尝试获取空闲连接
|
||||
while not self._shutdown:
|
||||
# 如果有空闲连接,直接返回
|
||||
try:
|
||||
conn = self._idle_connections.get_nowait()
|
||||
if conn.is_valid():
|
||||
conn.status = ConnectionStatus.ACTIVE
|
||||
conn.last_used = time.time()
|
||||
conn.use_count += 1
|
||||
self._active_connections[conn.id] = conn
|
||||
return conn
|
||||
else:
|
||||
# 连接无效,移除
|
||||
self._remove_connection(conn)
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
# 如果没有空闲连接,尝试创建新连接
|
||||
if len(self._all_connections) < self._max_connections:
|
||||
conn = self._create_connection()
|
||||
conn.status = ConnectionStatus.ACTIVE
|
||||
conn.last_used = time.time()
|
||||
conn.use_count += 1
|
||||
self._active_connections[conn.id] = conn
|
||||
self._all_connections[conn.id] = conn
|
||||
return conn
|
||||
|
||||
# 如果达到最大连接数,等待
|
||||
self._stats["total_wait_count"] += 1
|
||||
start_wait = time.time()
|
||||
|
||||
if not self._condition.wait(timeout=timeout):
|
||||
raise Exception(f"获取连接超时({timeout}秒)")
|
||||
|
||||
self._stats["total_wait_time"] += time.time() - start_wait
|
||||
|
||||
raise Exception("连接池已关闭")
|
||||
|
||||
def release_connection(self, conn: Connection) -> None:
|
||||
"""
|
||||
释放连接
|
||||
|
||||
Args:
|
||||
conn: 要释放的连接
|
||||
"""
|
||||
with self._condition:
|
||||
if conn.id in self._active_connections:
|
||||
del self._active_connections[conn.id]
|
||||
|
||||
if conn.is_valid():
|
||||
conn.status = ConnectionStatus.IDLE
|
||||
conn.last_used = time.time()
|
||||
self._idle_connections.put(conn)
|
||||
self._stats["total_release_count"] += 1
|
||||
else:
|
||||
# 连接无效,移除
|
||||
self._remove_connection(conn)
|
||||
|
||||
# 通知等待的线程
|
||||
self._condition.notify()
|
||||
|
||||
def _remove_connection(self, conn: Connection) -> None:
|
||||
"""移除连接"""
|
||||
conn.close()
|
||||
if conn.id in self._all_connections:
|
||||
del self._all_connections[conn.id]
|
||||
if conn.id in self._active_connections:
|
||||
del self._active_connections[conn.id]
|
||||
|
||||
def close(self) -> None:
|
||||
"""关闭连接池"""
|
||||
self._shutdown = True
|
||||
|
||||
with self._condition:
|
||||
# 关闭所有连接
|
||||
for conn in self._all_connections.values():
|
||||
conn.close()
|
||||
|
||||
self._all_connections.clear()
|
||||
self._active_connections.clear()
|
||||
|
||||
# 清空空闲队列
|
||||
while not self._idle_connections.empty():
|
||||
try:
|
||||
self._idle_connections.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# 通知所有等待的线程
|
||||
self._condition.notify_all()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
with self._lock:
|
||||
return {
|
||||
"total_connections": len(self._all_connections),
|
||||
"idle_connections": self._idle_connections.qsize(),
|
||||
"active_connections": len(self._active_connections),
|
||||
"min_connections": self._min_connections,
|
||||
"max_connections": self._max_connections,
|
||||
"total_get_count": self._stats["total_get_count"],
|
||||
"total_release_count": self._stats["total_release_count"],
|
||||
"total_wait_count": self._stats["total_wait_count"],
|
||||
"total_wait_time": self._stats["total_wait_time"],
|
||||
"health_check_count": self._stats["health_check_count"],
|
||||
"unhealthy_count": self._stats["unhealthy_count"],
|
||||
}
|
||||
|
||||
def health_check(self) -> bool:
|
||||
"""
|
||||
执行健康检查
|
||||
|
||||
Returns:
|
||||
是否所有连接都健康
|
||||
"""
|
||||
with self._lock:
|
||||
self._stats["health_check_count"] += 1
|
||||
|
||||
unhealthy_count = 0
|
||||
for conn in list(self._all_connections.values()):
|
||||
if not conn.is_valid():
|
||||
unhealthy_count += 1
|
||||
self._remove_connection(conn)
|
||||
|
||||
self._stats["unhealthy_count"] += unhealthy_count
|
||||
|
||||
# 补充最小连接数
|
||||
while len(self._all_connections) < self._min_connections:
|
||||
conn = self._create_connection()
|
||||
self._idle_connections.put(conn)
|
||||
self._all_connections[conn.id] = conn
|
||||
|
||||
return unhealthy_count == 0
|
||||
|
||||
def get_health_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取健康统计
|
||||
|
||||
Returns:
|
||||
健康统计字典
|
||||
"""
|
||||
with self._lock:
|
||||
healthy = sum(1 for conn in self._all_connections.values() if conn.is_valid())
|
||||
unhealthy = len(self._all_connections) - healthy
|
||||
|
||||
return {
|
||||
"healthy_connections": healthy,
|
||||
"unhealthy_connections": unhealthy,
|
||||
"health_check_count": self._stats["health_check_count"],
|
||||
"total_unhealthy_count": self._stats["unhealthy_count"],
|
||||
}
|
||||
|
||||
def _start_health_check(self) -> None:
|
||||
"""启动健康检查线程"""
|
||||
def health_check_worker():
|
||||
while not self._shutdown:
|
||||
time.sleep(self._health_check_interval)
|
||||
if not self._shutdown:
|
||||
self.health_check()
|
||||
|
||||
self._health_check_thread = threading.Thread(
|
||||
target=health_check_worker,
|
||||
daemon=True
|
||||
)
|
||||
self._health_check_thread.start()
|
||||
|
||||
|
||||
class ConnectionPoolManager:
|
||||
"""
|
||||
连接池管理器
|
||||
|
||||
单例模式管理多个命名连接池
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._pools = {}
|
||||
return cls._instance
|
||||
|
||||
def create_pool(
|
||||
self,
|
||||
name: str,
|
||||
min_connections: int = 2,
|
||||
max_connections: int = 10,
|
||||
host: str = "localhost",
|
||||
port: int = 3306,
|
||||
database: str = "",
|
||||
user: str = "",
|
||||
password: str = "",
|
||||
**kwargs
|
||||
) -> ConnectionPool:
|
||||
"""
|
||||
创建命名连接池
|
||||
|
||||
Args:
|
||||
name: 连接池名称
|
||||
min_connections: 最小连接数
|
||||
max_connections: 最大连接数
|
||||
host: 数据库主机
|
||||
port: 数据库端口
|
||||
database: 数据库名
|
||||
user: 用户名
|
||||
password: 密码
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
连接池实例
|
||||
"""
|
||||
if name in self._pools:
|
||||
return self._pools[name]
|
||||
|
||||
pool = ConnectionPool(
|
||||
min_connections=min_connections,
|
||||
max_connections=max_connections,
|
||||
host=host,
|
||||
port=port,
|
||||
database=database,
|
||||
user=user,
|
||||
password=password,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self._pools[name] = pool
|
||||
return pool
|
||||
|
||||
def get_pool(self, name: str) -> Optional[ConnectionPool]:
|
||||
"""
|
||||
获取命名连接池
|
||||
|
||||
Args:
|
||||
name: 连接池名称
|
||||
|
||||
Returns:
|
||||
连接池实例,不存在则返回None
|
||||
"""
|
||||
return self._pools.get(name)
|
||||
|
||||
def remove_pool(self, name: str) -> bool:
|
||||
"""
|
||||
移除连接池
|
||||
|
||||
Args:
|
||||
name: 连接池名称
|
||||
|
||||
Returns:
|
||||
是否成功移除
|
||||
"""
|
||||
if name in self._pools:
|
||||
self._pools[name].close()
|
||||
del self._pools[name]
|
||||
return True
|
||||
return False
|
||||
|
||||
def close_all(self) -> None:
|
||||
"""关闭所有连接池"""
|
||||
for pool in self._pools.values():
|
||||
pool.close()
|
||||
self._pools.clear()
|
||||
|
||||
def get_all_stats(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
获取所有连接池统计
|
||||
|
||||
Returns:
|
||||
连接池名称到统计信息的映射
|
||||
"""
|
||||
return {name: pool.get_stats() for name, pool in self._pools.items()}
|
||||
|
||||
|
||||
# 全局连接池管理器实例
|
||||
pool_manager = ConnectionPoolManager()
|
||||
@@ -0,0 +1,317 @@
|
||||
"""
|
||||
数据导入导出功能模块
|
||||
|
||||
提供Excel/CSV数据的导入导出功能。
|
||||
"""
|
||||
|
||||
import csv
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Callable
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExportResult:
|
||||
"""导出结果"""
|
||||
success: bool
|
||||
file_path: Optional[str] = None
|
||||
record_count: int = 0
|
||||
message: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImportResult:
|
||||
"""导入结果"""
|
||||
success: bool
|
||||
data: List[Dict[str, Any]] = field(default_factory=list)
|
||||
record_count: int = 0
|
||||
error_count: int = 0
|
||||
message: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""验证结果"""
|
||||
is_valid: bool
|
||||
errors: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class CSVExporter:
|
||||
"""CSV导出器"""
|
||||
|
||||
def export(self, data: List[Dict[str, Any]], file_path: str, encoding: str = 'utf-8') -> ExportResult:
|
||||
"""
|
||||
导出数据为CSV
|
||||
|
||||
Args:
|
||||
data: 数据列表
|
||||
file_path: 输出文件路径
|
||||
encoding: 文件编码
|
||||
|
||||
Returns:
|
||||
导出结果
|
||||
"""
|
||||
try:
|
||||
if not data:
|
||||
return ExportResult(success=True, file_path=file_path, record_count=0)
|
||||
|
||||
# 获取表头
|
||||
headers = list(data[0].keys())
|
||||
|
||||
with open(file_path, 'w', newline='', encoding=encoding) as f:
|
||||
writer = csv.DictWriter(f, fieldnames=headers)
|
||||
writer.writeheader()
|
||||
writer.writerows(data)
|
||||
|
||||
return ExportResult(
|
||||
success=True,
|
||||
file_path=file_path,
|
||||
record_count=len(data)
|
||||
)
|
||||
except Exception as e:
|
||||
return ExportResult(success=False, message=str(e))
|
||||
|
||||
|
||||
class CSVImporter:
|
||||
"""CSV导入器"""
|
||||
|
||||
def import_file(self, file_path: str, encoding: str = 'utf-8') -> ImportResult:
|
||||
"""
|
||||
从CSV文件导入数据
|
||||
|
||||
Args:
|
||||
file_path: CSV文件路径
|
||||
encoding: 文件编码
|
||||
|
||||
Returns:
|
||||
导入结果
|
||||
"""
|
||||
try:
|
||||
data = []
|
||||
with open(file_path, 'r', encoding=encoding) as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
data.append(dict(row))
|
||||
|
||||
return ImportResult(
|
||||
success=True,
|
||||
data=data,
|
||||
record_count=len(data)
|
||||
)
|
||||
except Exception as e:
|
||||
return ImportResult(success=False, message=str(e))
|
||||
|
||||
|
||||
class ExcelExporter:
|
||||
"""Excel导出器(简化版,实际使用时需要openpyxl库)"""
|
||||
|
||||
def export(self, data: List[Dict[str, Any]], file_path: str) -> ExportResult:
|
||||
"""
|
||||
导出数据为Excel(实际实现需要openpyxl)
|
||||
|
||||
Args:
|
||||
data: 数据列表
|
||||
file_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
导出结果
|
||||
"""
|
||||
try:
|
||||
# 简化实现:导出为CSV格式但使用.xlsx扩展名
|
||||
# 实际项目中应该使用openpyxl或pandas
|
||||
csv_path = file_path.replace('.xlsx', '.csv')
|
||||
exporter = CSVExporter()
|
||||
result = exporter.export(data, csv_path)
|
||||
|
||||
if result.success:
|
||||
# 重命名为xlsx
|
||||
os.rename(csv_path, file_path)
|
||||
return ExportResult(
|
||||
success=True,
|
||||
file_path=file_path,
|
||||
record_count=len(data)
|
||||
)
|
||||
else:
|
||||
return result
|
||||
except Exception as e:
|
||||
return ExportResult(success=False, message=str(e))
|
||||
|
||||
|
||||
class DataValidator:
|
||||
"""数据验证器"""
|
||||
|
||||
def validate(self, data: Dict[str, Any], rules: Dict[str, Dict[str, Any]]) -> ValidationResult:
|
||||
"""
|
||||
验证数据
|
||||
|
||||
Args:
|
||||
data: 要验证的数据
|
||||
rules: 验证规则
|
||||
|
||||
Returns:
|
||||
验证结果
|
||||
"""
|
||||
errors = []
|
||||
|
||||
for field, rule in rules.items():
|
||||
value = data.get(field)
|
||||
|
||||
# 必填验证
|
||||
if rule.get('required') and not value:
|
||||
errors.append(f"{field}: 必填字段不能为空")
|
||||
continue
|
||||
|
||||
if not value:
|
||||
continue
|
||||
|
||||
# 类型验证
|
||||
field_type = rule.get('type')
|
||||
if field_type == 'integer':
|
||||
try:
|
||||
int(value)
|
||||
except (ValueError, TypeError):
|
||||
errors.append(f"{field}: 必须是整数")
|
||||
continue
|
||||
elif field_type == 'email':
|
||||
if not re.match(r'^[\w\.-]+@[\w\.-]+\.\w+$', str(value)):
|
||||
errors.append(f"{field}: 邮箱格式不正确")
|
||||
continue
|
||||
|
||||
# 范围验证
|
||||
if field_type == 'integer' and isinstance(value, (int, str)):
|
||||
try:
|
||||
int_value = int(value)
|
||||
min_val = rule.get('min')
|
||||
max_val = rule.get('max')
|
||||
if min_val is not None and int_value < min_val:
|
||||
errors.append(f"{field}: 不能小于{min_val}")
|
||||
if max_val is not None and int_value > max_val:
|
||||
errors.append(f"{field}: 不能大于{max_val}")
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
return ValidationResult(is_valid=len(errors) == 0, errors=errors)
|
||||
|
||||
|
||||
class DataTransformer:
|
||||
"""数据转换器"""
|
||||
|
||||
def transform(self, data: List[Dict[str, Any]], mapping: Dict[str, str]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
转换数据字段映射
|
||||
|
||||
Args:
|
||||
data: 原始数据
|
||||
mapping: 字段映射 {源字段: 目标字段}
|
||||
|
||||
Returns:
|
||||
转换后的数据
|
||||
"""
|
||||
result = []
|
||||
for item in data:
|
||||
new_item = {}
|
||||
for source_field, target_field in mapping.items():
|
||||
if source_field in item:
|
||||
new_item[target_field] = item[source_field]
|
||||
result.append(new_item)
|
||||
return result
|
||||
|
||||
|
||||
class TemplateManager:
|
||||
"""模板管理器"""
|
||||
|
||||
def generate_template(self, template: Dict[str, Any], file_path: str) -> ExportResult:
|
||||
"""
|
||||
生成模板文件
|
||||
|
||||
Args:
|
||||
template: 模板定义
|
||||
file_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
生成结果
|
||||
"""
|
||||
try:
|
||||
columns = template.get('columns', [])
|
||||
headers = [col['name'] for col in columns]
|
||||
|
||||
with open(file_path, 'w', newline='', encoding='utf-8') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(headers)
|
||||
|
||||
return ExportResult(
|
||||
success=True,
|
||||
file_path=file_path,
|
||||
record_count=0
|
||||
)
|
||||
except Exception as e:
|
||||
return ExportResult(success=False, message=str(e))
|
||||
|
||||
|
||||
class DataImportExportManager:
|
||||
"""数据导入导出管理器"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化管理器"""
|
||||
self._csv_exporter = CSVExporter()
|
||||
self._csv_importer = CSVImporter()
|
||||
self._stats = {
|
||||
'total_exports': 0,
|
||||
'total_imports': 0,
|
||||
'total_records_exported': 0,
|
||||
'total_records_imported': 0,
|
||||
}
|
||||
|
||||
def export_batch(
|
||||
self,
|
||||
data: List[Dict[str, Any]],
|
||||
file_path: str,
|
||||
batch_size: int = 1000
|
||||
) -> ExportResult:
|
||||
"""
|
||||
批量导出数据
|
||||
|
||||
Args:
|
||||
data: 数据列表
|
||||
file_path: 输出文件路径
|
||||
batch_size: 批次大小
|
||||
|
||||
Returns:
|
||||
导出结果
|
||||
"""
|
||||
result = self._csv_exporter.export(data, file_path)
|
||||
|
||||
if result.success:
|
||||
self._stats['total_exports'] += 1
|
||||
self._stats['total_records_exported'] += result.record_count
|
||||
|
||||
return result
|
||||
|
||||
def import_batch(self, file_path: str, batch_size: int = 1000) -> ImportResult:
|
||||
"""
|
||||
批量导入数据
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
batch_size: 批次大小
|
||||
|
||||
Returns:
|
||||
导入结果
|
||||
"""
|
||||
result = self._csv_importer.import_file(file_path)
|
||||
|
||||
if result.success:
|
||||
self._stats['total_imports'] += 1
|
||||
self._stats['total_records_imported'] += result.record_count
|
||||
|
||||
return result
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
return self._stats.copy()
|
||||
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
异常处理器模块
|
||||
|
||||
提供测试异常分类处理功能,区分致命错误和可重试错误。
|
||||
"""
|
||||
|
||||
from typing import Optional, List
|
||||
from playwright.sync_api import Page
|
||||
|
||||
|
||||
class FatalTestError(Exception):
|
||||
"""致命测试错误"""
|
||||
pass
|
||||
|
||||
|
||||
class RetryableError(Exception):
|
||||
"""可重试错误"""
|
||||
pass
|
||||
|
||||
|
||||
class TestExceptionHandler:
|
||||
"""测试异常处理器"""
|
||||
|
||||
# 致命错误:立即停止测试
|
||||
FATAL_ERRORS: List[str] = [
|
||||
"Browser crashed",
|
||||
"Connection refused",
|
||||
"Target page closed",
|
||||
"Database connection failed",
|
||||
"Session deleted",
|
||||
"invalid session id",
|
||||
]
|
||||
|
||||
# 可重试错误:自动重试
|
||||
RETRYABLE_ERRORS: List[str] = [
|
||||
"TimeoutError",
|
||||
"Timeout exceeded",
|
||||
"Element not found",
|
||||
"Network error",
|
||||
"net::ERR",
|
||||
"Stale element reference",
|
||||
"element is detached",
|
||||
"Execution context was destroyed",
|
||||
"Unable to locate element",
|
||||
"waiting for locator",
|
||||
]
|
||||
|
||||
# 非致命错误:记录但继续
|
||||
NON_FATAL_ERRORS: List[str] = [
|
||||
"Screenshot failed",
|
||||
"Log write failed",
|
||||
"Screenshot is not supported",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def handle_exception(
|
||||
cls,
|
||||
error: Exception,
|
||||
page: Optional[Page] = None,
|
||||
test_name: str = "",
|
||||
screenshot_helper=None,
|
||||
) -> Optional[Exception]:
|
||||
"""
|
||||
处理测试异常
|
||||
|
||||
Args:
|
||||
error: 异常对象
|
||||
page: Playwright页面对象
|
||||
test_name: 测试名称
|
||||
screenshot_helper: 截图辅助工具
|
||||
|
||||
Returns:
|
||||
None: 非致命错误,继续执行
|
||||
RetryableError: 可重试错误
|
||||
FatalTestError: 致命错误,停止测试
|
||||
"""
|
||||
error_msg = str(error)
|
||||
error_type = type(error).__name__
|
||||
|
||||
# 致命错误
|
||||
if any(fatal in error_msg for fatal in cls.FATAL_ERRORS):
|
||||
if screenshot_helper and page:
|
||||
screenshot_helper.take_screenshot(
|
||||
page, f"{test_name}_fatal_error" if test_name else "fatal_error"
|
||||
)
|
||||
raise FatalTestError(f"测试遇到致命错误 [{error_type}]: {error_msg}")
|
||||
|
||||
# 可重试错误
|
||||
if any(retryable in error_msg for retryable in cls.RETRYABLE_ERRORS):
|
||||
if screenshot_helper and page:
|
||||
screenshot_helper.take_screenshot(
|
||||
page,
|
||||
f"{test_name}_retryable_error" if test_name else "retryable_error",
|
||||
)
|
||||
return RetryableError(f"可重试错误 [{error_type}]: {error_msg}")
|
||||
|
||||
# 非致命错误
|
||||
if any(non_fatal in error_msg for non_fatal in cls.NON_FATAL_ERRORS):
|
||||
# 仅记录错误,不抛出
|
||||
return None
|
||||
|
||||
# 未知错误 - 根据错误类型判断
|
||||
if "assert" in error_msg.lower():
|
||||
# 断言错误通常是测试逻辑问题,不重试
|
||||
raise error
|
||||
|
||||
# 其他错误视为可重试
|
||||
if screenshot_helper and page:
|
||||
screenshot_helper.take_screenshot(
|
||||
page, f"{test_name}_unknown_error" if test_name else "unknown_error"
|
||||
)
|
||||
return RetryableError(f"未知错误 [{error_type}]: {error_msg}")
|
||||
|
||||
@classmethod
|
||||
def is_fatal_error(cls, error: Exception) -> bool:
|
||||
"""判断是否为致命错误"""
|
||||
error_msg = str(error)
|
||||
return any(fatal in error_msg for fatal in cls.FATAL_ERRORS)
|
||||
|
||||
@classmethod
|
||||
def is_retryable_error(cls, error: Exception) -> bool:
|
||||
"""判断是否为可重试错误"""
|
||||
error_msg = str(error)
|
||||
return any(retryable in error_msg for retryable in cls.RETRYABLE_ERRORS)
|
||||
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
自定义异常类
|
||||
|
||||
提供测试框架的自定义异常。
|
||||
"""
|
||||
|
||||
|
||||
class TestFrameworkError(Exception):
|
||||
"""测试框架基础异常"""
|
||||
pass
|
||||
|
||||
|
||||
class APIError(TestFrameworkError):
|
||||
"""API错误异常"""
|
||||
pass
|
||||
|
||||
|
||||
class APITimeoutError(APIError):
|
||||
"""API超时异常"""
|
||||
pass
|
||||
|
||||
|
||||
class ValidationError(TestFrameworkError):
|
||||
"""验证错误异常"""
|
||||
pass
|
||||
|
||||
|
||||
class ElementNotFoundError(TestFrameworkError):
|
||||
"""元素未找到异常"""
|
||||
pass
|
||||
|
||||
|
||||
class TimeoutError(TestFrameworkError):
|
||||
"""超时异常"""
|
||||
pass
|
||||
|
||||
|
||||
class ConfigurationError(TestFrameworkError):
|
||||
"""配置错误异常"""
|
||||
pass
|
||||
@@ -0,0 +1,358 @@
|
||||
"""
|
||||
文件上传下载功能模块
|
||||
|
||||
提供文件上传、下载、验证和管理功能。
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import shutil
|
||||
from typing import Any, Dict, List, Optional, BinaryIO
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class UploadResult:
|
||||
"""上传结果"""
|
||||
success: bool
|
||||
file_id: Optional[str] = None
|
||||
file_path: Optional[str] = None
|
||||
filename: Optional[str] = None
|
||||
size: int = 0
|
||||
message: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DownloadResult:
|
||||
"""下载结果"""
|
||||
success: bool
|
||||
content: Optional[bytes] = None
|
||||
filename: Optional[str] = None
|
||||
message: str = ""
|
||||
|
||||
|
||||
class FileTypeValidator:
|
||||
"""文件类型验证器"""
|
||||
|
||||
def __init__(self, allowed_extensions: Optional[List[str]] = None):
|
||||
"""
|
||||
初始化文件类型验证器
|
||||
|
||||
Args:
|
||||
allowed_extensions: 允许的文件扩展名列表
|
||||
"""
|
||||
self._allowed_extensions = allowed_extensions or []
|
||||
|
||||
def validate(self, filename: str) -> bool:
|
||||
"""
|
||||
验证文件类型
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
|
||||
Returns:
|
||||
是否允许
|
||||
"""
|
||||
if not self._allowed_extensions:
|
||||
return True
|
||||
|
||||
ext = Path(filename).suffix.lower()
|
||||
return ext in self._allowed_extensions
|
||||
|
||||
|
||||
class FileSizeValidator:
|
||||
"""文件大小验证器"""
|
||||
|
||||
def __init__(self, max_size: int):
|
||||
"""
|
||||
初始化文件大小验证器
|
||||
|
||||
Args:
|
||||
max_size: 最大文件大小(字节)
|
||||
"""
|
||||
self._max_size = max_size
|
||||
|
||||
def validate(self, size: int) -> bool:
|
||||
"""
|
||||
验证文件大小
|
||||
|
||||
Args:
|
||||
size: 文件大小(字节)
|
||||
|
||||
Returns:
|
||||
是否允许
|
||||
"""
|
||||
return size <= self._max_size
|
||||
|
||||
|
||||
class FilenameSanitizer:
|
||||
"""文件名净化器"""
|
||||
|
||||
# 危险字符
|
||||
DANGEROUS_CHARS = r'[;|&$<>\`\\]'
|
||||
|
||||
def sanitize(self, filename: str) -> str:
|
||||
"""
|
||||
净化文件名
|
||||
|
||||
Args:
|
||||
filename: 原始文件名
|
||||
|
||||
Returns:
|
||||
安全的文件名
|
||||
"""
|
||||
# 移除路径遍历
|
||||
filename = os.path.basename(filename)
|
||||
|
||||
# 移除危险字符
|
||||
filename = re.sub(self.DANGEROUS_CHARS, '', filename)
|
||||
|
||||
# 移除连续的点
|
||||
filename = re.sub(r'\.{2,}', '.', filename)
|
||||
|
||||
# 确保不为空
|
||||
if not filename or filename == '.':
|
||||
filename = 'unnamed_file'
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
class FileStorageManager:
|
||||
"""文件存储管理器"""
|
||||
|
||||
def __init__(self, storage_dir: str):
|
||||
"""
|
||||
初始化存储管理器
|
||||
|
||||
Args:
|
||||
storage_dir: 存储目录
|
||||
"""
|
||||
self._storage_dir = storage_dir
|
||||
self._metadata: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# 创建存储目录
|
||||
os.makedirs(storage_dir, exist_ok=True)
|
||||
|
||||
def save(
|
||||
self,
|
||||
content: bytes,
|
||||
filename: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
保存文件
|
||||
|
||||
Args:
|
||||
content: 文件内容
|
||||
filename: 文件名
|
||||
metadata: 元数据
|
||||
|
||||
Returns:
|
||||
文件ID
|
||||
"""
|
||||
file_id = str(uuid.uuid4())
|
||||
file_path = os.path.join(self._storage_dir, file_id)
|
||||
|
||||
with open(file_path, 'wb') as f:
|
||||
f.write(content)
|
||||
|
||||
# 保存元数据
|
||||
self._metadata[file_id] = {
|
||||
'filename': filename,
|
||||
'size': len(content),
|
||||
'metadata': metadata or {},
|
||||
}
|
||||
|
||||
return file_id
|
||||
|
||||
def get(self, file_id: str) -> Optional[bytes]:
|
||||
"""
|
||||
获取文件内容
|
||||
|
||||
Args:
|
||||
file_id: 文件ID
|
||||
|
||||
Returns:
|
||||
文件内容
|
||||
"""
|
||||
file_path = os.path.join(self._storage_dir, file_id)
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
return None
|
||||
|
||||
with open(file_path, 'rb') as f:
|
||||
return f.read()
|
||||
|
||||
def get_metadata(self, file_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取文件元数据
|
||||
|
||||
Args:
|
||||
file_id: 文件ID
|
||||
|
||||
Returns:
|
||||
元数据
|
||||
"""
|
||||
meta = self._metadata.get(file_id)
|
||||
if meta:
|
||||
return meta.get('metadata')
|
||||
return None
|
||||
|
||||
def delete(self, file_id: str) -> bool:
|
||||
"""
|
||||
删除文件
|
||||
|
||||
Args:
|
||||
file_id: 文件ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
file_path = os.path.join(self._storage_dir, file_id)
|
||||
|
||||
if os.path.exists(file_path):
|
||||
os.unlink(file_path)
|
||||
if file_id in self._metadata:
|
||||
del self._metadata[file_id]
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class FileUploader:
|
||||
"""文件上传器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
upload_dir: str,
|
||||
type_validator: Optional[FileTypeValidator] = None,
|
||||
size_validator: Optional[FileSizeValidator] = None,
|
||||
filename_sanitizer: Optional[FilenameSanitizer] = None
|
||||
):
|
||||
"""
|
||||
初始化文件上传器
|
||||
|
||||
Args:
|
||||
upload_dir: 上传目录
|
||||
type_validator: 文件类型验证器
|
||||
size_validator: 文件大小验证器
|
||||
filename_sanitizer: 文件名净化器
|
||||
"""
|
||||
self._upload_dir = upload_dir
|
||||
self._type_validator = type_validator or FileTypeValidator()
|
||||
self._size_validator = size_validator or FileSizeValidator(max_size=10 * 1024 * 1024) # 10MB
|
||||
self._filename_sanitizer = filename_sanitizer or FilenameSanitizer()
|
||||
self._storage = FileStorageManager(upload_dir)
|
||||
|
||||
def upload(self, file_obj: BinaryIO, filename: str) -> UploadResult:
|
||||
"""
|
||||
上传文件
|
||||
|
||||
Args:
|
||||
file_obj: 文件对象
|
||||
filename: 文件名
|
||||
|
||||
Returns:
|
||||
上传结果
|
||||
"""
|
||||
# 净化文件名
|
||||
safe_filename = self._filename_sanitizer.sanitize(filename)
|
||||
|
||||
# 验证文件类型
|
||||
if not self._type_validator.validate(safe_filename):
|
||||
return UploadResult(
|
||||
success=False,
|
||||
message=f"不支持的文件类型: {safe_filename}"
|
||||
)
|
||||
|
||||
# 读取文件内容
|
||||
content = file_obj.read()
|
||||
|
||||
# 验证文件大小
|
||||
if not self._size_validator.validate(len(content)):
|
||||
return UploadResult(
|
||||
success=False,
|
||||
message=f"文件大小超过限制"
|
||||
)
|
||||
|
||||
# 保存文件
|
||||
file_id = self._storage.save(content, safe_filename)
|
||||
file_path = os.path.join(self._upload_dir, file_id)
|
||||
|
||||
return UploadResult(
|
||||
success=True,
|
||||
file_id=file_id,
|
||||
file_path=file_path,
|
||||
filename=safe_filename,
|
||||
size=len(content)
|
||||
)
|
||||
|
||||
def upload_batch(self, file_paths: List[str]) -> List[UploadResult]:
|
||||
"""
|
||||
批量上传文件
|
||||
|
||||
Args:
|
||||
file_paths: 文件路径列表
|
||||
|
||||
Returns:
|
||||
上传结果列表
|
||||
"""
|
||||
results = []
|
||||
|
||||
for file_path in file_paths:
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
filename = os.path.basename(file_path)
|
||||
result = self.upload(f, filename)
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
results.append(UploadResult(
|
||||
success=False,
|
||||
message=str(e)
|
||||
))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class FileDownloader:
|
||||
"""文件下载器"""
|
||||
|
||||
def __init__(self, storage_manager: Optional[FileStorageManager] = None):
|
||||
"""
|
||||
初始化文件下载器
|
||||
|
||||
Args:
|
||||
storage_manager: 存储管理器
|
||||
"""
|
||||
self._storage = storage_manager
|
||||
|
||||
def download(self, file_id: str) -> DownloadResult:
|
||||
"""
|
||||
下载文件
|
||||
|
||||
Args:
|
||||
file_id: 文件ID
|
||||
|
||||
Returns:
|
||||
下载结果
|
||||
"""
|
||||
if self._storage is None:
|
||||
return DownloadResult(
|
||||
success=False,
|
||||
message="存储管理器未设置"
|
||||
)
|
||||
|
||||
content = self._storage.get(file_id)
|
||||
|
||||
if content is None:
|
||||
return DownloadResult(
|
||||
success=False,
|
||||
message="文件不存在"
|
||||
)
|
||||
|
||||
return DownloadResult(
|
||||
success=True,
|
||||
content=content
|
||||
)
|
||||
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
日志记录器模块
|
||||
|
||||
提供结构化的测试日志记录功能。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class TestLogger:
|
||||
"""测试日志记录器"""
|
||||
|
||||
def __init__(self, name: str = "e2e_test", log_dir: str = "logs"):
|
||||
self.name = name
|
||||
self.log_dir = Path(log_dir)
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 创建logger
|
||||
self.logger = logging.getLogger(name)
|
||||
self.logger.setLevel(logging.DEBUG)
|
||||
|
||||
# 避免重复添加handler
|
||||
if not self.logger.handlers:
|
||||
# 控制台处理器
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setLevel(logging.INFO)
|
||||
console_formatter = logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
self.logger.addHandler(console_handler)
|
||||
|
||||
# 文件处理器
|
||||
log_file = self.log_dir / f"test_{datetime.now().strftime('%Y%m%d')}.log"
|
||||
file_handler = logging.FileHandler(log_file, encoding="utf-8")
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
file_formatter = logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
self.logger.addHandler(file_handler)
|
||||
|
||||
def debug(self, message: str) -> None:
|
||||
"""记录调试日志"""
|
||||
self.logger.debug(message)
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
"""记录信息日志"""
|
||||
self.logger.info(message)
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
"""记录警告日志"""
|
||||
self.logger.warning(message)
|
||||
|
||||
def error(self, message: str, error: Optional[Exception] = None) -> None:
|
||||
"""记录错误日志"""
|
||||
if error:
|
||||
self.logger.error(f"{message}: {str(error)}", exc_info=True)
|
||||
else:
|
||||
self.logger.error(message)
|
||||
|
||||
def critical(self, message: str) -> None:
|
||||
"""记录严重错误日志"""
|
||||
self.logger.critical(message)
|
||||
|
||||
def start_test(self, test_name: str) -> None:
|
||||
"""记录测试开始"""
|
||||
self.info(f"{'='*60}")
|
||||
self.info(f"开始执行测试: {test_name}")
|
||||
self.info(f"{'='*60}")
|
||||
|
||||
def end_test(self, test_name: str, status: str, error: Optional[Exception] = None) -> None:
|
||||
"""记录测试结束"""
|
||||
icon = "✅" if status == "passed" else "❌"
|
||||
self.info(f"{'='*60}")
|
||||
self.info(f"{icon} 测试结束: {test_name} - 状态: {status}")
|
||||
if error:
|
||||
self.error(f"错误信息: {str(error)}")
|
||||
self.info(f"{'='*60}")
|
||||
|
||||
def start_step(self, step_name: str) -> None:
|
||||
"""记录步骤开始"""
|
||||
self.info(f"▶️ 执行步骤: {step_name}")
|
||||
|
||||
def end_step(self, step_name: str, status: str) -> None:
|
||||
"""记录步骤结束"""
|
||||
icon = "✓" if status == "passed" else "✗"
|
||||
self.info(f"{icon} 步骤完成: {step_name} - 状态: {status}")
|
||||
|
||||
|
||||
# 全局日志记录器实例
|
||||
_logger: Optional[TestLogger] = None
|
||||
|
||||
|
||||
def get_logger(name: str = "e2e_test", log_dir: str = "logs") -> TestLogger:
|
||||
"""获取日志记录器实例"""
|
||||
global _logger
|
||||
if _logger is None:
|
||||
_logger = TestLogger(name, log_dir)
|
||||
return _logger
|
||||
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
性能监控服务
|
||||
|
||||
提供性能监控和优化功能。
|
||||
"""
|
||||
|
||||
import time
|
||||
import functools
|
||||
from typing import Dict, Any, List, Optional, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
import threading
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerformanceMetric:
|
||||
"""性能指标数据类"""
|
||||
name: str
|
||||
duration: float
|
||||
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class PerformanceMonitor:
|
||||
"""性能监控器"""
|
||||
|
||||
def __init__(self):
|
||||
self._metrics: List[PerformanceMetric] = []
|
||||
self._thresholds: Dict[str, float] = {
|
||||
"page_load": 3.0,
|
||||
"table_load": 2.0,
|
||||
"search_response": 1.0,
|
||||
"form_submit": 2.0,
|
||||
"concurrent_op": 1.0,
|
||||
"memory_growth": 50.0,
|
||||
}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def record_metric(self, name: str, duration: float, metadata: Dict[str, Any] = None) -> PerformanceMetric:
|
||||
"""记录性能指标"""
|
||||
metric = PerformanceMetric(
|
||||
name=name,
|
||||
duration=duration,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self._metrics.append(metric)
|
||||
|
||||
return metric
|
||||
|
||||
def measure(self, name: str, metadata: Dict[str, Any] = None):
|
||||
"""性能测量装饰器"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end_time = time.time()
|
||||
|
||||
duration = end_time - start_time
|
||||
self.record_metric(name, duration, metadata)
|
||||
|
||||
return result
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
def measure_context(self, name: str, metadata: Dict[str, Any] = None):
|
||||
"""性能测量上下文管理器"""
|
||||
return PerformanceContext(self, name, metadata)
|
||||
|
||||
def get_metrics(self, name: Optional[str] = None) -> List[PerformanceMetric]:
|
||||
"""获取性能指标"""
|
||||
with self._lock:
|
||||
if name:
|
||||
return [m for m in self._metrics if m.name == name]
|
||||
return self._metrics.copy()
|
||||
|
||||
def get_average_duration(self, name: str) -> float:
|
||||
"""获取平均持续时间"""
|
||||
metrics = self.get_metrics(name)
|
||||
if not metrics:
|
||||
return 0.0
|
||||
return sum(m.duration for m in metrics) / len(metrics)
|
||||
|
||||
def check_threshold(self, name: str, duration: float) -> bool:
|
||||
"""检查是否超过阈值"""
|
||||
threshold = self._thresholds.get(name, float('inf'))
|
||||
return duration <= threshold
|
||||
|
||||
def get_performance_report(self) -> Dict[str, Any]:
|
||||
"""生成性能报告"""
|
||||
with self._lock:
|
||||
report = {
|
||||
"total_metrics": len(self._metrics),
|
||||
"metrics_by_name": {},
|
||||
"threshold_violations": [],
|
||||
}
|
||||
|
||||
# 按名称分组统计
|
||||
for metric in self._metrics:
|
||||
if metric.name not in report["metrics_by_name"]:
|
||||
report["metrics_by_name"][metric.name] = {
|
||||
"count": 0,
|
||||
"total_duration": 0.0,
|
||||
"min_duration": float('inf'),
|
||||
"max_duration": 0.0,
|
||||
}
|
||||
|
||||
stats = report["metrics_by_name"][metric.name]
|
||||
stats["count"] += 1
|
||||
stats["total_duration"] += metric.duration
|
||||
stats["min_duration"] = min(stats["min_duration"], metric.duration)
|
||||
stats["max_duration"] = max(stats["max_duration"], metric.duration)
|
||||
|
||||
# 检查阈值违规
|
||||
if not self.check_threshold(metric.name, metric.duration):
|
||||
report["threshold_violations"].append({
|
||||
"name": metric.name,
|
||||
"duration": metric.duration,
|
||||
"threshold": self._thresholds.get(metric.name),
|
||||
"timestamp": metric.timestamp,
|
||||
})
|
||||
|
||||
# 计算平均值
|
||||
for name, stats in report["metrics_by_name"].items():
|
||||
stats["avg_duration"] = stats["total_duration"] / stats["count"]
|
||||
|
||||
return report
|
||||
|
||||
def clear_metrics(self):
|
||||
"""清除所有指标"""
|
||||
with self._lock:
|
||||
self._metrics.clear()
|
||||
|
||||
|
||||
class PerformanceContext:
|
||||
"""性能测量上下文"""
|
||||
|
||||
def __init__(self, monitor: PerformanceMonitor, name: str, metadata: Dict[str, Any] = None):
|
||||
self.monitor = monitor
|
||||
self.name = name
|
||||
self.metadata = metadata or {}
|
||||
self.start_time = None
|
||||
self.metric = None
|
||||
|
||||
def __enter__(self):
|
||||
self.start_time = time.time()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
end_time = time.time()
|
||||
duration = end_time - self.start_time
|
||||
self.metric = self.monitor.record_metric(self.name, duration, self.metadata)
|
||||
|
||||
|
||||
# 全局性能监控器实例
|
||||
performance_monitor = PerformanceMonitor()
|
||||
|
||||
|
||||
# 性能测试辅助函数
|
||||
def measure_page_load(page, url: str) -> float:
|
||||
"""测量页面加载时间"""
|
||||
start_time = time.time()
|
||||
page.goto(url)
|
||||
page.wait_for_load_state("networkidle")
|
||||
end_time = time.time()
|
||||
return end_time - start_time
|
||||
|
||||
|
||||
def measure_element_load(page, selector: str, timeout: int = 10000) -> float:
|
||||
"""测量元素加载时间"""
|
||||
start_time = time.time()
|
||||
page.wait_for_selector(selector, timeout=timeout)
|
||||
end_time = time.time()
|
||||
return end_time - start_time
|
||||
|
||||
|
||||
def measure_api_response(func: Callable, *args, **kwargs) -> tuple:
|
||||
"""测量API响应时间"""
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
return result, duration
|
||||
@@ -0,0 +1,302 @@
|
||||
"""
|
||||
测试报告生成器模块
|
||||
|
||||
提供测试报告生成功能,支持多种报告格式。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass, field, asdict
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestResult:
|
||||
"""测试结果数据类"""
|
||||
name: str
|
||||
status: str # passed, failed, skipped
|
||||
duration: float = 0.0
|
||||
start_time: Optional[str] = None
|
||||
end_time: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
traceback: Optional[str] = None
|
||||
screenshot: Optional[str] = None
|
||||
steps: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestSuite:
|
||||
"""测试套件数据类"""
|
||||
name: str
|
||||
tests: List[TestResult] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def passed_count(self) -> int:
|
||||
return sum(1 for t in self.tests if t.status == "passed")
|
||||
|
||||
@property
|
||||
def failed_count(self) -> int:
|
||||
return sum(1 for t in self.tests if t.status == "failed")
|
||||
|
||||
@property
|
||||
def skipped_count(self) -> int:
|
||||
return sum(1 for t in self.tests if t.status == "skipped")
|
||||
|
||||
@property
|
||||
def total_count(self) -> int:
|
||||
return len(self.tests)
|
||||
|
||||
@property
|
||||
def pass_rate(self) -> float:
|
||||
if self.total_count == 0:
|
||||
return 0.0
|
||||
return (self.passed_count / self.total_count) * 100
|
||||
|
||||
|
||||
class TestReporter:
|
||||
"""测试报告生成器"""
|
||||
|
||||
def __init__(self, report_dir: str = "reports"):
|
||||
self.report_dir = Path(report_dir)
|
||||
self.report_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.suites: Dict[str, TestSuite] = {}
|
||||
self.start_time: Optional[datetime] = None
|
||||
self.end_time: Optional[datetime] = None
|
||||
|
||||
def start_report(self) -> None:
|
||||
"""开始测试报告"""
|
||||
self.start_time = datetime.now()
|
||||
print(f"📝 测试报告开始于: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
def end_report(self) -> None:
|
||||
"""结束测试报告"""
|
||||
self.end_time = datetime.now()
|
||||
print(f"📝 测试报告结束于: {self.end_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
def add_test_result(self, suite_name: str, result: TestResult) -> None:
|
||||
"""添加测试结果"""
|
||||
if suite_name not in self.suites:
|
||||
self.suites[suite_name] = TestSuite(name=suite_name)
|
||||
self.suites[suite_name].tests.append(result)
|
||||
|
||||
def generate_html_report(self, filename: str = "test_report.html") -> str:
|
||||
"""生成HTML报告"""
|
||||
filepath = self.report_dir / filename
|
||||
|
||||
total_tests = sum(s.total_count for s in self.suites.values())
|
||||
total_passed = sum(s.passed_count for s in self.suites.values())
|
||||
total_failed = sum(s.failed_count for s in self.suites.values())
|
||||
total_skipped = sum(s.skipped_count for s in self.suites.values())
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>E2E测试报告</title>
|
||||
<style>
|
||||
* {{ margin: 0; padding: 0; box-sizing: border-box; }}
|
||||
body {{
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
background: #f5f5f5;
|
||||
padding: 20px;
|
||||
}}
|
||||
.container {{ max-width: 1200px; margin: 0 auto; }}
|
||||
.header {{
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
padding: 30px;
|
||||
border-radius: 10px;
|
||||
margin-bottom: 20px;
|
||||
}}
|
||||
.header h1 {{ font-size: 28px; margin-bottom: 10px; }}
|
||||
.summary {{
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
||||
gap: 15px;
|
||||
margin-bottom: 20px;
|
||||
}}
|
||||
.summary-card {{
|
||||
background: white;
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
}}
|
||||
.summary-card h3 {{ font-size: 14px; color: #666; margin-bottom: 8px; }}
|
||||
.summary-card .value {{ font-size: 32px; font-weight: bold; }}
|
||||
.summary-card.passed .value {{ color: #10b981; }}
|
||||
.summary-card.failed .value {{ color: #ef4444; }}
|
||||
.summary-card.skipped .value {{ color: #f59e0b; }}
|
||||
.summary-card.total .value {{ color: #3b82f6; }}
|
||||
.suite {{
|
||||
background: white;
|
||||
border-radius: 8px;
|
||||
margin-bottom: 20px;
|
||||
overflow: hidden;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
}}
|
||||
.suite-header {{
|
||||
background: #f8f9fa;
|
||||
padding: 15px 20px;
|
||||
border-bottom: 1px solid #e5e7eb;
|
||||
}}
|
||||
.suite-header h2 {{ font-size: 18px; color: #374151; }}
|
||||
.test-list {{ padding: 0; }}
|
||||
.test-item {{
|
||||
padding: 15px 20px;
|
||||
border-bottom: 1px solid #e5e7eb;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
}}
|
||||
.test-item:last-child {{ border-bottom: none; }}
|
||||
.test-item.passed {{ border-left: 4px solid #10b981; }}
|
||||
.test-item.failed {{ border-left: 4px solid #ef4444; }}
|
||||
.test-item.skipped {{ border-left: 4px solid #f59e0b; }}
|
||||
.test-name {{ font-weight: 500; color: #374151; }}
|
||||
.test-status {{
|
||||
padding: 4px 12px;
|
||||
border-radius: 12px;
|
||||
font-size: 12px;
|
||||
font-weight: 500;
|
||||
}}
|
||||
.test-status.passed {{ background: #d1fae5; color: #065f46; }}
|
||||
.test-status.failed {{ background: #fee2e2; color: #991b1b; }}
|
||||
.test-status.skipped {{ background: #fef3c7; color: #92400e; }}
|
||||
.test-duration {{ color: #6b7280; font-size: 12px; margin-left: 10px; }}
|
||||
.error-message {{
|
||||
background: #fee2e2;
|
||||
color: #991b1b;
|
||||
padding: 10px;
|
||||
border-radius: 4px;
|
||||
margin-top: 10px;
|
||||
font-size: 12px;
|
||||
}}
|
||||
.footer {{
|
||||
text-align: center;
|
||||
padding: 20px;
|
||||
color: #6b7280;
|
||||
font-size: 12px;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<h1>🧪 E2E测试报告</h1>
|
||||
<p>生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
|
||||
</div>
|
||||
|
||||
<div class="summary">
|
||||
<div class="summary-card total">
|
||||
<h3>总测试数</h3>
|
||||
<div class="value">{total_tests}</div>
|
||||
</div>
|
||||
<div class="summary-card passed">
|
||||
<h3>通过</h3>
|
||||
<div class="value">{total_passed}</div>
|
||||
</div>
|
||||
<div class="summary-card failed">
|
||||
<h3>失败</h3>
|
||||
<div class="value">{total_failed}</div>
|
||||
</div>
|
||||
<div class="summary-card skipped">
|
||||
<h3>跳过</h3>
|
||||
<div class="value">{total_skipped}</div>
|
||||
</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
# 添加测试套件详情
|
||||
for suite_name, suite in self.suites.items():
|
||||
html_content += f"""
|
||||
<div class="suite">
|
||||
<div class="suite-header">
|
||||
<h2>{suite_name}</h2>
|
||||
<p>通过率: {suite.pass_rate:.1f}% ({suite.passed_count}/{suite.total_count})</p>
|
||||
</div>
|
||||
<div class="test-list">
|
||||
"""
|
||||
for test in suite.tests:
|
||||
status_class = test.status
|
||||
error_html = ""
|
||||
if test.error_message:
|
||||
error_html = f'<div class="error-message">{test.error_message}</div>'
|
||||
|
||||
html_content += f"""
|
||||
<div class="test-item {status_class}">
|
||||
<div>
|
||||
<span class="test-name">{test.name}</span>
|
||||
<span class="test-duration">{test.duration:.2f}s</span>
|
||||
{error_html}
|
||||
</div>
|
||||
<span class="test-status {status_class}">{test.status.upper()}</span>
|
||||
</div>
|
||||
"""
|
||||
|
||||
html_content += """
|
||||
</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
html_content += """
|
||||
<div class="footer">
|
||||
<p>Generated by Python E2E Test Framework</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write(html_content)
|
||||
|
||||
print(f"📊 HTML报告已生成: {filepath}")
|
||||
return str(filepath)
|
||||
|
||||
def generate_json_report(self, filename: str = "test_report.json") -> str:
|
||||
"""生成JSON报告"""
|
||||
filepath = self.report_dir / filename
|
||||
|
||||
report_data = {
|
||||
"report_info": {
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"start_time": self.start_time.isoformat() if self.start_time else None,
|
||||
"end_time": self.end_time.isoformat() if self.end_time else None,
|
||||
},
|
||||
"summary": {
|
||||
"total": sum(s.total_count for s in self.suites.values()),
|
||||
"passed": sum(s.passed_count for s in self.suites.values()),
|
||||
"failed": sum(s.failed_count for s in self.suites.values()),
|
||||
"skipped": sum(s.skipped_count for s in self.suites.values()),
|
||||
},
|
||||
"suites": {}
|
||||
}
|
||||
|
||||
for suite_name, suite in self.suites.items():
|
||||
report_data["suites"][suite_name] = {
|
||||
"summary": {
|
||||
"total": suite.total_count,
|
||||
"passed": suite.passed_count,
|
||||
"failed": suite.failed_count,
|
||||
"skipped": suite.skipped_count,
|
||||
"pass_rate": suite.pass_rate,
|
||||
},
|
||||
"tests": [asdict(test) for test in suite.tests]
|
||||
}
|
||||
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(report_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"📊 JSON报告已生成: {filepath}")
|
||||
return str(filepath)
|
||||
|
||||
def generate_all_reports(self) -> Dict[str, str]:
|
||||
"""生成所有报告"""
|
||||
return {
|
||||
"html": self.generate_html_report(),
|
||||
"json": self.generate_json_report(),
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
"""
|
||||
重试装饰器模块
|
||||
|
||||
提供测试方法重试功能。
|
||||
"""
|
||||
|
||||
import functools
|
||||
import time
|
||||
from typing import Callable, TypeVar, Optional, Tuple, Type
|
||||
from .exception_handler import TestExceptionHandler, RetryableError, FatalTestError
|
||||
from .logger import get_logger
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def retry_on_failure(
|
||||
max_retries: int = 3,
|
||||
delay: float = 1.0,
|
||||
backoff: float = 2.0,
|
||||
exceptions: Optional[Tuple[Type[Exception], ...]] = None,
|
||||
on_retry: Optional[Callable[[Exception, int], None]] = None,
|
||||
):
|
||||
"""
|
||||
重试装饰器
|
||||
|
||||
Args:
|
||||
max_retries: 最大重试次数
|
||||
delay: 初始重试延迟(秒)
|
||||
backoff: 延迟增长倍数
|
||||
exceptions: 需要重试的异常类型
|
||||
on_retry: 重试时的回调函数
|
||||
|
||||
Returns:
|
||||
装饰器函数
|
||||
|
||||
Example:
|
||||
@retry_on_failure(max_retries=3, delay=1.0)
|
||||
def test_unstable_feature(page):
|
||||
# 测试代码
|
||||
pass
|
||||
"""
|
||||
if exceptions is None:
|
||||
exceptions = (Exception,)
|
||||
|
||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs) -> T:
|
||||
logger = get_logger()
|
||||
retry_count = 0
|
||||
current_delay = delay
|
||||
last_exception: Optional[Exception] = None
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except exceptions as e:
|
||||
retry_count += 1
|
||||
last_exception = e
|
||||
|
||||
# 检查是否为致命错误
|
||||
if TestExceptionHandler.is_fatal_error(e):
|
||||
logger.error(f"遇到致命错误,停止重试: {e}")
|
||||
raise
|
||||
|
||||
# 检查是否为可重试错误
|
||||
if not TestExceptionHandler.is_retryable_error(e):
|
||||
logger.error(f"非可重试错误,停止重试: {e}")
|
||||
raise
|
||||
|
||||
if retry_count >= max_retries:
|
||||
logger.error(
|
||||
f"函数 {func.__name__} 在 {max_retries} 次尝试后仍然失败"
|
||||
)
|
||||
raise
|
||||
|
||||
logger.warning(
|
||||
f"函数 {func.__name__} 失败 (尝试 {retry_count}/{max_retries}): {e}"
|
||||
)
|
||||
|
||||
# 调用重试回调
|
||||
if on_retry:
|
||||
on_retry(e, retry_count)
|
||||
|
||||
# 等待后重试
|
||||
time.sleep(current_delay)
|
||||
current_delay *= backoff
|
||||
|
||||
# 如果所有重试都失败,抛出最后的异常
|
||||
if last_exception:
|
||||
raise last_exception
|
||||
|
||||
# 不应该到达这里
|
||||
raise RuntimeError("重试逻辑异常")
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def retry_with_timeout(
|
||||
timeout: float = 30.0,
|
||||
interval: float = 0.5,
|
||||
exceptions: Optional[Tuple[Type[Exception], ...]] = None,
|
||||
):
|
||||
"""
|
||||
超时重试装饰器
|
||||
|
||||
在指定超时时间内持续重试
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒)
|
||||
interval: 重试间隔(秒)
|
||||
exceptions: 需要重试的异常类型
|
||||
|
||||
Returns:
|
||||
装饰器函数
|
||||
"""
|
||||
if exceptions is None:
|
||||
exceptions = (Exception,)
|
||||
|
||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs) -> T:
|
||||
logger = get_logger()
|
||||
start_time = time.time()
|
||||
attempt = 0
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except exceptions as e:
|
||||
attempt += 1
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# 检查是否为致命错误
|
||||
if TestExceptionHandler.is_fatal_error(e):
|
||||
logger.error(f"遇到致命错误,停止重试: {e}")
|
||||
raise
|
||||
|
||||
if elapsed >= timeout:
|
||||
logger.error(
|
||||
f"函数 {func.__name__} 在 {timeout} 秒后超时,共尝试 {attempt} 次"
|
||||
)
|
||||
raise
|
||||
|
||||
logger.debug(
|
||||
f"函数 {func.__name__} 失败 (尝试 {attempt}): {e},{interval}秒后重试"
|
||||
)
|
||||
time.sleep(interval)
|
||||
|
||||
raise TimeoutError(f"函数 {func.__name__} 在 {timeout} 秒内未完成")
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
截图辅助工具模块
|
||||
|
||||
提供测试截图功能,支持多种截图模式。
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from playwright.sync_api import Page
|
||||
|
||||
|
||||
class ScreenshotHelper:
|
||||
"""截图辅助工具"""
|
||||
|
||||
def __init__(self, screenshot_dir: str = "reports/screenshots"):
|
||||
self.screenshot_dir = Path(screenshot_dir)
|
||||
self.screenshot_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def take_screenshot(
|
||||
self,
|
||||
page: Page,
|
||||
name: str,
|
||||
full_page: bool = False,
|
||||
selector: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
截取页面截图
|
||||
|
||||
Args:
|
||||
page: Playwright页面对象
|
||||
name: 截图文件名
|
||||
full_page: 是否截取整个页面
|
||||
selector: 特定元素选择器
|
||||
|
||||
Returns:
|
||||
截图文件路径
|
||||
"""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{name}_{timestamp}.png"
|
||||
filepath = self.screenshot_dir / filename
|
||||
|
||||
try:
|
||||
if selector:
|
||||
# 截取特定元素
|
||||
element = page.locator(selector)
|
||||
element.screenshot(path=str(filepath))
|
||||
else:
|
||||
# 截取页面
|
||||
page.screenshot(path=str(filepath), full_page=full_page)
|
||||
|
||||
print(f"📸 截图已保存: {filepath}")
|
||||
return str(filepath)
|
||||
except Exception as e:
|
||||
print(f"❌ 截图失败: {e}")
|
||||
return ""
|
||||
|
||||
def take_screenshot_on_failure(
|
||||
self, page: Page, test_name: str, full_page: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
测试失败时截取截图
|
||||
|
||||
Args:
|
||||
page: Playwright页面对象
|
||||
test_name: 测试名称
|
||||
full_page: 是否截取整个页面
|
||||
|
||||
Returns:
|
||||
截图文件路径
|
||||
"""
|
||||
return self.take_screenshot(
|
||||
page, f"{test_name}_failed", full_page=full_page
|
||||
)
|
||||
|
||||
def take_comparison_screenshot(
|
||||
self, page: Page, name: str, baseline_dir: str = "screenshots/baseline"
|
||||
) -> tuple:
|
||||
"""
|
||||
截取对比截图
|
||||
|
||||
Args:
|
||||
page: Playwright页面对象
|
||||
name: 截图名称
|
||||
baseline_dir: 基线截图目录
|
||||
|
||||
Returns:
|
||||
(当前截图路径, 基线截图路径)
|
||||
"""
|
||||
current_path = self.take_screenshot(page, name, full_page=True)
|
||||
|
||||
baseline_path = Path(baseline_dir) / f"{name}.png"
|
||||
|
||||
return current_path, str(baseline_path)
|
||||
|
||||
def cleanup_old_screenshots(self, days: int = 7) -> int:
|
||||
"""
|
||||
清理旧截图
|
||||
|
||||
Args:
|
||||
days: 保留天数
|
||||
|
||||
Returns:
|
||||
删除的文件数量
|
||||
"""
|
||||
import time
|
||||
|
||||
deleted_count = 0
|
||||
cutoff_time = time.time() - (days * 24 * 60 * 60)
|
||||
|
||||
for file_path in self.screenshot_dir.glob("*.png"):
|
||||
if file_path.stat().st_mtime < cutoff_time:
|
||||
file_path.unlink()
|
||||
deleted_count += 1
|
||||
|
||||
print(f"🗑️ 已清理 {deleted_count} 个旧截图文件")
|
||||
return deleted_count
|
||||
@@ -0,0 +1,482 @@
|
||||
"""
|
||||
安全测试模块
|
||||
|
||||
提供SQL注入、XSS、CSRF等安全防护功能。
|
||||
"""
|
||||
|
||||
import re
|
||||
import hashlib
|
||||
import secrets
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ThreatLevel(Enum):
|
||||
"""威胁等级"""
|
||||
LOW = "low"
|
||||
MEDIUM = "medium"
|
||||
HIGH = "high"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectionResult:
|
||||
"""检测结果"""
|
||||
is_threat: bool
|
||||
threat_type: str
|
||||
level: ThreatLevel
|
||||
details: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SQLInjectionResult:
|
||||
"""SQL注入检测结果"""
|
||||
is_injection: bool = False
|
||||
level: ThreatLevel = ThreatLevel.LOW
|
||||
details: str = ""
|
||||
|
||||
@property
|
||||
def is_threat(self) -> bool:
|
||||
return self.is_injection
|
||||
|
||||
@property
|
||||
def threat_type(self) -> str:
|
||||
return "SQL_INJECTION"
|
||||
|
||||
|
||||
@dataclass
|
||||
class XSSResult:
|
||||
"""XSS检测结果"""
|
||||
is_xss: bool = False
|
||||
level: ThreatLevel = ThreatLevel.LOW
|
||||
details: str = ""
|
||||
|
||||
@property
|
||||
def is_threat(self) -> bool:
|
||||
return self.is_xss
|
||||
|
||||
@property
|
||||
def threat_type(self) -> str:
|
||||
return "XSS"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PasswordStrengthResult:
|
||||
"""密码强度结果"""
|
||||
score: int
|
||||
strength: str
|
||||
suggestions: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityEvent:
|
||||
"""安全事件"""
|
||||
timestamp: float
|
||||
event_type: str
|
||||
source_ip: str
|
||||
details: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityReport:
|
||||
"""安全扫描报告"""
|
||||
total_scanned: int
|
||||
threats: List[DetectionResult]
|
||||
scan_time: float
|
||||
|
||||
|
||||
class SQLInjectionDetector:
|
||||
"""SQL注入检测器"""
|
||||
|
||||
# SQL注入特征模式
|
||||
PATTERNS = [
|
||||
r"(\%27)|(\')|(\-\-)|(\%23)|(#)", # 单引号、注释
|
||||
r"((\%3D)|(=))[^\n]*((\%27)|(\')|(\-\-)|(\%3B)|(;))", # =后面跟引号或注释
|
||||
r"\w*((\%27)|(\'))((\%6F)|o|(\%4F))((\%72)|r|(\%52))", # 'or
|
||||
r"((\%27)|(\'))union", # 'union
|
||||
r"exec(\s|\+)+(s|x)p\w+", # exec xp_
|
||||
r"UNION\s+SELECT", # UNION SELECT
|
||||
r"INSERT\s+INTO", # INSERT INTO
|
||||
r"DELETE\s+FROM", # DELETE FROM
|
||||
r"DROP\s+TABLE", # DROP TABLE
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self._compiled_patterns = [re.compile(p, re.IGNORECASE) for p in self.PATTERNS]
|
||||
|
||||
def detect(self, input_str: str) -> SQLInjectionResult:
|
||||
"""
|
||||
检测SQL注入
|
||||
|
||||
Args:
|
||||
input_str: 输入字符串
|
||||
|
||||
Returns:
|
||||
检测结果
|
||||
"""
|
||||
for pattern in self._compiled_patterns:
|
||||
if pattern.search(input_str):
|
||||
return SQLInjectionResult(
|
||||
is_injection=True,
|
||||
level=ThreatLevel.HIGH,
|
||||
details=f"匹配模式: {pattern.pattern}"
|
||||
)
|
||||
|
||||
return SQLInjectionResult(is_injection=False, level=ThreatLevel.LOW)
|
||||
|
||||
|
||||
class XSSDetector:
|
||||
"""XSS检测器"""
|
||||
|
||||
# XSS攻击特征模式
|
||||
PATTERNS = [
|
||||
r"<script[^>]*>[\s\S]*?</script>", # <script>标签
|
||||
r"javascript:", # javascript:协议
|
||||
r"on\w+\s*=", # 事件处理器
|
||||
r"<iframe", # iframe标签
|
||||
r"<object", # object标签
|
||||
r"<embed", # embed标签
|
||||
r"<form", # form标签
|
||||
r"<input[^>]*type\s*=\s*['\"]?hidden", # hidden input
|
||||
r"expression\s*\(", # CSS expression
|
||||
r"url\s*\(\s*['\"]?javascript:", # CSS url javascript
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self._compiled_patterns = [re.compile(p, re.IGNORECASE) for p in self.PATTERNS]
|
||||
|
||||
def detect(self, input_str: str) -> XSSResult:
|
||||
"""
|
||||
检测XSS攻击
|
||||
|
||||
Args:
|
||||
input_str: 输入字符串
|
||||
|
||||
Returns:
|
||||
检测结果
|
||||
"""
|
||||
for pattern in self._compiled_patterns:
|
||||
if pattern.search(input_str):
|
||||
return XSSResult(
|
||||
is_xss=True,
|
||||
level=ThreatLevel.HIGH,
|
||||
details=f"匹配模式: {pattern.pattern}"
|
||||
)
|
||||
|
||||
return XSSResult(is_xss=False, level=ThreatLevel.LOW)
|
||||
|
||||
|
||||
class CSRFProtector:
|
||||
"""CSRF防护器"""
|
||||
|
||||
def __init__(self, token_expiry: int = 3600):
|
||||
"""
|
||||
初始化CSRF防护器
|
||||
|
||||
Args:
|
||||
token_expiry: Token过期时间(秒)
|
||||
"""
|
||||
self._token_expiry = token_expiry
|
||||
self._tokens: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def generate_token(self, user_id: str) -> str:
|
||||
"""
|
||||
生成CSRF Token
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
Token字符串
|
||||
"""
|
||||
token = secrets.token_urlsafe(32)
|
||||
self._tokens[token] = {
|
||||
"user_id": user_id,
|
||||
"created_at": time.time(),
|
||||
}
|
||||
return token
|
||||
|
||||
def validate_token(self, user_id: str, token: str) -> bool:
|
||||
"""
|
||||
验证CSRF Token
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
token: Token字符串
|
||||
|
||||
Returns:
|
||||
是否有效
|
||||
"""
|
||||
if token not in self._tokens:
|
||||
return False
|
||||
|
||||
token_data = self._tokens[token]
|
||||
|
||||
# 检查用户ID
|
||||
if token_data["user_id"] != user_id:
|
||||
return False
|
||||
|
||||
# 检查是否过期
|
||||
if time.time() - token_data["created_at"] > self._token_expiry:
|
||||
del self._tokens[token]
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def invalidate_token(self, token: str) -> None:
|
||||
"""使Token失效"""
|
||||
if token in self._tokens:
|
||||
del self._tokens[token]
|
||||
|
||||
|
||||
class InputSanitizer:
|
||||
"""输入净化器"""
|
||||
|
||||
# HTML危险标签和属性
|
||||
DANGEROUS_TAGS = [
|
||||
"script", "iframe", "object", "embed", "form", "input",
|
||||
"textarea", "button", "link", "meta", "style"
|
||||
]
|
||||
|
||||
DANGEROUS_ATTRIBUTES = [
|
||||
"onerror", "onload", "onclick", "onmouseover", "onmouseout",
|
||||
"onkeydown", "onkeypress", "onkeyup", "onsubmit", "onchange",
|
||||
"onfocus", "onblur", "onselect", "onreset"
|
||||
]
|
||||
|
||||
def sanitize_html(self, html: str) -> str:
|
||||
"""
|
||||
净化HTML内容
|
||||
|
||||
Args:
|
||||
html: HTML字符串
|
||||
|
||||
Returns:
|
||||
净化后的HTML
|
||||
"""
|
||||
# 移除危险标签
|
||||
for tag in self.DANGEROUS_TAGS:
|
||||
pattern = f"<{tag}[^>]*>[\\s\\S]*?</{tag}>"
|
||||
html = re.sub(pattern, "", html, flags=re.IGNORECASE)
|
||||
pattern = f"<{tag}[^>]*/?>"
|
||||
html = re.sub(pattern, "", html, flags=re.IGNORECASE)
|
||||
|
||||
# 移除危险属性
|
||||
for attr in self.DANGEROUS_ATTRIBUTES:
|
||||
pattern = f"\\s{attr}=[\"'][^\"']*[\"']"
|
||||
html = re.sub(pattern, "", html, flags=re.IGNORECASE)
|
||||
|
||||
# 移除javascript:协议
|
||||
html = re.sub(r"javascript:", "", html, flags=re.IGNORECASE)
|
||||
|
||||
return html
|
||||
|
||||
def sanitize_sql(self, input_str: str) -> str:
|
||||
"""
|
||||
净化SQL输入
|
||||
|
||||
Args:
|
||||
input_str: 输入字符串
|
||||
|
||||
Returns:
|
||||
净化后的字符串
|
||||
"""
|
||||
# 转义单引号
|
||||
return input_str.replace("'", "''")
|
||||
|
||||
|
||||
class PasswordStrengthChecker:
|
||||
"""密码强度检查器"""
|
||||
|
||||
def check(self, password: str) -> PasswordStrengthResult:
|
||||
"""
|
||||
检查密码强度
|
||||
|
||||
Args:
|
||||
password: 密码字符串
|
||||
|
||||
Returns:
|
||||
强度结果
|
||||
"""
|
||||
score = 0
|
||||
suggestions = []
|
||||
|
||||
# 长度检查
|
||||
if len(password) >= 8:
|
||||
score += 2
|
||||
elif len(password) >= 6:
|
||||
score += 1
|
||||
else:
|
||||
suggestions.append("密码长度至少8位")
|
||||
|
||||
# 包含小写字母
|
||||
if re.search(r"[a-z]", password):
|
||||
score += 1
|
||||
else:
|
||||
suggestions.append("应包含小写字母")
|
||||
|
||||
# 包含大写字母
|
||||
if re.search(r"[A-Z]", password):
|
||||
score += 1
|
||||
else:
|
||||
suggestions.append("应包含大写字母")
|
||||
|
||||
# 包含数字
|
||||
if re.search(r"\d", password):
|
||||
score += 1
|
||||
else:
|
||||
suggestions.append("应包含数字")
|
||||
|
||||
# 包含特殊字符
|
||||
if re.search(r"[!@#$%^&*(),.?\":{}|<>]", password):
|
||||
score += 2
|
||||
else:
|
||||
suggestions.append("应包含特殊字符")
|
||||
|
||||
# 确定强度等级
|
||||
if score >= 7:
|
||||
strength = "strong"
|
||||
elif score >= 4:
|
||||
strength = "medium"
|
||||
else:
|
||||
strength = "weak"
|
||||
|
||||
return PasswordStrengthResult(
|
||||
score=score,
|
||||
strength=strength,
|
||||
suggestions=suggestions
|
||||
)
|
||||
|
||||
|
||||
class SecurityHeaders:
|
||||
"""安全HTTP头部生成器"""
|
||||
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
"""
|
||||
获取安全HTTP头部
|
||||
|
||||
Returns:
|
||||
安全头部字典
|
||||
"""
|
||||
return {
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-Frame-Options": "DENY",
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
"Strict-Transport-Security": "max-age=31536000; includeSubDomains",
|
||||
"Content-Security-Policy": "default-src 'self'",
|
||||
"Referrer-Policy": "strict-origin-when-cross-origin",
|
||||
"Permissions-Policy": "geolocation=(), microphone=(), camera=()",
|
||||
}
|
||||
|
||||
|
||||
class SecurityAuditLogger:
|
||||
"""安全审计日志器"""
|
||||
|
||||
def __init__(self):
|
||||
self._events: List[SecurityEvent] = []
|
||||
|
||||
def log_event(
|
||||
self,
|
||||
event_type: str,
|
||||
source_ip: str,
|
||||
details: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
记录安全事件
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
source_ip: 来源IP
|
||||
details: 详细信息
|
||||
"""
|
||||
event = SecurityEvent(
|
||||
timestamp=time.time(),
|
||||
event_type=event_type,
|
||||
source_ip=source_ip,
|
||||
details=details
|
||||
)
|
||||
self._events.append(event)
|
||||
|
||||
def get_events(
|
||||
self,
|
||||
event_type: Optional[str] = None,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None
|
||||
) -> List[SecurityEvent]:
|
||||
"""
|
||||
获取安全事件
|
||||
|
||||
Args:
|
||||
event_type: 事件类型过滤
|
||||
start_time: 开始时间
|
||||
end_time: 结束时间
|
||||
|
||||
Returns:
|
||||
事件列表
|
||||
"""
|
||||
events = self._events
|
||||
|
||||
if event_type:
|
||||
events = [e for e in events if e.event_type == event_type]
|
||||
|
||||
if start_time:
|
||||
events = [e for e in events if e.timestamp >= start_time]
|
||||
|
||||
if end_time:
|
||||
events = [e for e in events if e.timestamp <= end_time]
|
||||
|
||||
return events
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
event_types = {}
|
||||
for event in self._events:
|
||||
event_types[event.event_type] = event_types.get(event.event_type, 0) + 1
|
||||
|
||||
return {
|
||||
"total_events": len(self._events),
|
||||
"event_types": event_types,
|
||||
}
|
||||
|
||||
|
||||
class SecurityScanner:
|
||||
"""综合安全扫描器"""
|
||||
|
||||
def __init__(self):
|
||||
self._sql_detector = SQLInjectionDetector()
|
||||
self._xss_detector = XSSDetector()
|
||||
|
||||
def scan(self, data: Dict[str, Any]) -> SecurityReport:
|
||||
"""
|
||||
扫描数据
|
||||
|
||||
Args:
|
||||
data: 要扫描的数据
|
||||
|
||||
Returns:
|
||||
扫描报告
|
||||
"""
|
||||
threats = []
|
||||
start_time = time.time()
|
||||
|
||||
for key, value in data.items():
|
||||
if isinstance(value, str):
|
||||
# SQL注入检测
|
||||
sql_result = self._sql_detector.detect(value)
|
||||
if sql_result.is_injection:
|
||||
threats.append(sql_result)
|
||||
|
||||
# XSS检测
|
||||
xss_result = self._xss_detector.detect(value)
|
||||
if xss_result.is_xss:
|
||||
threats.append(xss_result)
|
||||
|
||||
scan_time = time.time() - start_time
|
||||
|
||||
return SecurityReport(
|
||||
total_scanned=len(data),
|
||||
threats=threats,
|
||||
scan_time=scan_time
|
||||
)
|
||||
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
定时任务调度器模块
|
||||
|
||||
提供定时任务的创建、调度、执行和管理功能。
|
||||
"""
|
||||
|
||||
import time
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import heapq
|
||||
|
||||
|
||||
class TaskStatus(Enum):
|
||||
"""任务状态"""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
CANCELLED = "cancelled"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class SchedulerState(Enum):
|
||||
"""调度器状态"""
|
||||
RUNNING = "running"
|
||||
PAUSED = "paused"
|
||||
STOPPED = "stopped"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
"""任务定义"""
|
||||
name: str
|
||||
func: Callable
|
||||
interval: float = 0 # 执行间隔(秒)
|
||||
delay: float = 0 # 延迟执行时间(秒)
|
||||
repeat: bool = False # 是否重复执行
|
||||
priority: int = 5 # 优先级(1-10,数字越大优先级越高)
|
||||
on_error: Optional[Callable[[Exception], None]] = None
|
||||
max_retries: int = 0
|
||||
|
||||
# 内部字段
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
status: TaskStatus = TaskStatus.PENDING
|
||||
next_run_time: float = 0
|
||||
execution_count: int = 0
|
||||
error_count: int = 0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.next_run_time == 0:
|
||||
self.next_run_time = time.time() + self.delay
|
||||
|
||||
def __lt__(self, other):
|
||||
# 用于优先级队列比较
|
||||
if self.next_run_time != other.next_run_time:
|
||||
return self.next_run_time < other.next_run_time
|
||||
return self.priority > other.priority
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskExecutionRecord:
|
||||
"""任务执行记录"""
|
||||
task_id: str
|
||||
task_name: str
|
||||
start_time: float
|
||||
end_time: float
|
||||
success: bool
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class TaskScheduler:
|
||||
"""
|
||||
任务调度器
|
||||
|
||||
特性:
|
||||
- 支持定时和周期性任务
|
||||
- 支持任务优先级
|
||||
- 支持任务取消
|
||||
- 支持错误处理
|
||||
- 支持暂停/恢复
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化调度器"""
|
||||
self._tasks: Dict[str, Task] = {}
|
||||
self._task_queue: List[Task] = []
|
||||
self._lock = threading.RLock()
|
||||
self._condition = threading.Condition(self._lock)
|
||||
self._state = SchedulerState.STOPPED
|
||||
self._worker_thread: Optional[threading.Thread] = None
|
||||
self._execution_records: List[TaskExecutionRecord] = []
|
||||
self._total_executions = 0
|
||||
self._total_errors = 0
|
||||
|
||||
def start(self) -> None:
|
||||
"""启动调度器"""
|
||||
with self._lock:
|
||||
if self._state == SchedulerState.RUNNING:
|
||||
return
|
||||
|
||||
self._state = SchedulerState.RUNNING
|
||||
self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
|
||||
self._worker_thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""停止调度器"""
|
||||
with self._lock:
|
||||
self._state = SchedulerState.STOPPED
|
||||
self._condition.notify_all()
|
||||
|
||||
if self._worker_thread and self._worker_thread.is_alive():
|
||||
self._worker_thread.join(timeout=5)
|
||||
|
||||
def pause(self) -> None:
|
||||
"""暂停调度器"""
|
||||
with self._lock:
|
||||
self._state = SchedulerState.PAUSED
|
||||
|
||||
def resume(self) -> None:
|
||||
"""恢复调度器"""
|
||||
with self._lock:
|
||||
if self._state == SchedulerState.PAUSED:
|
||||
self._state = SchedulerState.RUNNING
|
||||
self._condition.notify_all()
|
||||
|
||||
def schedule(self, task: Task) -> str:
|
||||
"""
|
||||
调度任务
|
||||
|
||||
Args:
|
||||
task: 要调度的任务
|
||||
|
||||
Returns:
|
||||
任务ID
|
||||
"""
|
||||
with self._lock:
|
||||
self._tasks[task.id] = task
|
||||
heapq.heappush(self._task_queue, task)
|
||||
self._condition.notify()
|
||||
|
||||
# 自动启动调度器
|
||||
if self._state == SchedulerState.STOPPED:
|
||||
self.start()
|
||||
|
||||
return task.id
|
||||
|
||||
def cancel(self, task_id: str) -> bool:
|
||||
"""
|
||||
取消任务
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
|
||||
Returns:
|
||||
是否成功取消
|
||||
"""
|
||||
with self._lock:
|
||||
if task_id in self._tasks:
|
||||
task = self._tasks[task_id]
|
||||
task.status = TaskStatus.CANCELLED
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
with self._lock:
|
||||
return {
|
||||
"total_executions": self._total_executions,
|
||||
"total_errors": self._total_errors,
|
||||
"pending_tasks": len([t for t in self._tasks.values() if t.status == TaskStatus.PENDING]),
|
||||
"running_tasks": len([t for t in self._tasks.values() if t.status == TaskStatus.RUNNING]),
|
||||
"state": self._state.value,
|
||||
}
|
||||
|
||||
def _worker_loop(self) -> None:
|
||||
"""工作线程循环"""
|
||||
while True:
|
||||
with self._lock:
|
||||
# 检查状态
|
||||
if self._state == SchedulerState.STOPPED:
|
||||
break
|
||||
|
||||
# 如果暂停,等待恢复
|
||||
while self._state == SchedulerState.PAUSED:
|
||||
self._condition.wait()
|
||||
if self._state == SchedulerState.STOPPED:
|
||||
return
|
||||
|
||||
# 获取下一个要执行的任务
|
||||
task = self._get_next_task()
|
||||
|
||||
if task is None:
|
||||
# 没有任务,等待一段时间
|
||||
self._condition.wait(timeout=0.1)
|
||||
continue
|
||||
|
||||
# 检查任务是否被取消
|
||||
if task.status == TaskStatus.CANCELLED:
|
||||
continue
|
||||
|
||||
# 执行任务
|
||||
task.status = TaskStatus.RUNNING
|
||||
|
||||
# 在锁外执行任务
|
||||
self._execute_task(task)
|
||||
|
||||
def _get_next_task(self) -> Optional[Task]:
|
||||
"""获取下一个要执行的任务"""
|
||||
now = time.time()
|
||||
|
||||
while self._task_queue:
|
||||
task = heapq.heappop(self._task_queue)
|
||||
|
||||
# 检查任务是否有效
|
||||
if task.status == TaskStatus.CANCELLED:
|
||||
continue
|
||||
|
||||
# 检查是否到执行时间
|
||||
if task.next_run_time <= now:
|
||||
return task
|
||||
else:
|
||||
# 还没到时间,放回队列
|
||||
heapq.heappush(self._task_queue, task)
|
||||
break
|
||||
|
||||
return None
|
||||
|
||||
def _execute_task(self, task: Task) -> None:
|
||||
"""执行任务"""
|
||||
# 再次检查任务是否被取消
|
||||
if task.status == TaskStatus.CANCELLED:
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
success = False
|
||||
error_msg = None
|
||||
|
||||
try:
|
||||
task.func()
|
||||
success = True
|
||||
task.execution_count += 1
|
||||
except Exception as e:
|
||||
success = False
|
||||
error_msg = str(e)
|
||||
task.error_count += 1
|
||||
|
||||
# 调用错误处理回调
|
||||
if task.on_error:
|
||||
try:
|
||||
task.on_error(e)
|
||||
except:
|
||||
pass
|
||||
|
||||
with self._lock:
|
||||
self._total_errors += 1
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
# 记录执行
|
||||
with self._lock:
|
||||
self._total_executions += 1
|
||||
self._execution_records.append(TaskExecutionRecord(
|
||||
task_id=task.id,
|
||||
task_name=task.name,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
success=success,
|
||||
error=error_msg
|
||||
))
|
||||
|
||||
# 处理周期性任务
|
||||
with self._lock:
|
||||
if task.repeat and task.status != TaskStatus.CANCELLED:
|
||||
if task.error_count <= task.max_retries or task.max_retries == 0:
|
||||
task.status = TaskStatus.PENDING
|
||||
task.next_run_time = time.time() + task.interval
|
||||
heapq.heappush(self._task_queue, task)
|
||||
else:
|
||||
task.status = TaskStatus.ERROR
|
||||
else:
|
||||
task.status = TaskStatus.COMPLETED if success else TaskStatus.ERROR
|
||||
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
验证服务
|
||||
|
||||
提供各种边界条件的验证功能。
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, Any, Tuple, Optional
|
||||
|
||||
|
||||
class ValidationService:
|
||||
"""验证服务类"""
|
||||
|
||||
# 验证规则常量
|
||||
USERNAME_MIN_LENGTH = 3
|
||||
USERNAME_MAX_LENGTH = 20
|
||||
USERNAME_PATTERN = re.compile(r'^[a-zA-Z0-9_]+$')
|
||||
|
||||
EMAIL_PATTERN = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
|
||||
|
||||
ROLE_NAME_MIN_LENGTH = 1
|
||||
ROLE_NAME_MAX_LENGTH = 50
|
||||
ROLE_CODE_MIN_LENGTH = 1
|
||||
ROLE_CODE_MAX_LENGTH = 30
|
||||
ROLE_CODE_PATTERN = re.compile(r'^[a-zA-Z0-9_]+$')
|
||||
|
||||
@staticmethod
|
||||
def validate_username(username: str) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
验证用户名
|
||||
|
||||
Returns:
|
||||
(is_valid, error_message)
|
||||
"""
|
||||
if not username:
|
||||
return False, "用户名不能为空"
|
||||
|
||||
if len(username) < ValidationService.USERNAME_MIN_LENGTH:
|
||||
return False, f"用户名长度不能少于{ValidationService.USERNAME_MIN_LENGTH}个字符"
|
||||
|
||||
if len(username) > ValidationService.USERNAME_MAX_LENGTH:
|
||||
return False, f"用户名长度不能超过{ValidationService.USERNAME_MAX_LENGTH}个字符"
|
||||
|
||||
if not ValidationService.USERNAME_PATTERN.match(username):
|
||||
return False, "用户名只能包含字母、数字和下划线"
|
||||
|
||||
return True, None
|
||||
|
||||
@staticmethod
|
||||
def validate_email(email: str) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
验证邮箱格式
|
||||
|
||||
Returns:
|
||||
(is_valid, error_message)
|
||||
"""
|
||||
if not email:
|
||||
return False, "邮箱不能为空"
|
||||
|
||||
if not ValidationService.EMAIL_PATTERN.match(email):
|
||||
return False, "邮箱格式不正确"
|
||||
|
||||
return True, None
|
||||
|
||||
@staticmethod
|
||||
def validate_role_name(name: str) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
验证角色名称
|
||||
|
||||
Returns:
|
||||
(is_valid, error_message)
|
||||
"""
|
||||
if not name:
|
||||
return False, "角色名称不能为空"
|
||||
|
||||
if len(name) > ValidationService.ROLE_NAME_MAX_LENGTH:
|
||||
return False, f"角色名称长度不能超过{ValidationService.ROLE_NAME_MAX_LENGTH}个字符"
|
||||
|
||||
return True, None
|
||||
|
||||
@staticmethod
|
||||
def validate_role_code(code: str) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
验证角色编码
|
||||
|
||||
Returns:
|
||||
(is_valid, error_message)
|
||||
"""
|
||||
if not code:
|
||||
return False, "角色编码不能为空"
|
||||
|
||||
if len(code) > ValidationService.ROLE_CODE_MAX_LENGTH:
|
||||
return False, f"角色编码长度不能超过{ValidationService.ROLE_CODE_MAX_LENGTH}个字符"
|
||||
|
||||
if not ValidationService.ROLE_CODE_PATTERN.match(code):
|
||||
return False, "角色编码只能包含字母、数字和下划线"
|
||||
|
||||
return True, None
|
||||
|
||||
@staticmethod
|
||||
def validate_user_data(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
验证用户数据
|
||||
|
||||
Returns:
|
||||
{"valid": bool, "errors": Dict[str, str]}
|
||||
"""
|
||||
errors = {}
|
||||
|
||||
# 验证用户名
|
||||
if "username" in data:
|
||||
is_valid, error = ValidationService.validate_username(data["username"])
|
||||
if not is_valid:
|
||||
errors["username"] = error
|
||||
|
||||
# 验证邮箱
|
||||
if "email" in data:
|
||||
is_valid, error = ValidationService.validate_email(data["email"])
|
||||
if not is_valid:
|
||||
errors["email"] = error
|
||||
|
||||
return {
|
||||
"valid": len(errors) == 0,
|
||||
"errors": errors
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def validate_role_data(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
验证角色数据
|
||||
|
||||
Returns:
|
||||
{"valid": bool, "errors": Dict[str, str]}
|
||||
"""
|
||||
errors = {}
|
||||
|
||||
# 验证角色名称
|
||||
if "name" in data:
|
||||
is_valid, error = ValidationService.validate_role_name(data["name"])
|
||||
if not is_valid:
|
||||
errors["name"] = error
|
||||
|
||||
# 验证角色编码
|
||||
if "code" in data:
|
||||
is_valid, error = ValidationService.validate_role_code(data["code"])
|
||||
if not is_valid:
|
||||
errors["code"] = error
|
||||
|
||||
return {
|
||||
"valid": len(errors) == 0,
|
||||
"errors": errors
|
||||
}
|
||||
|
||||
|
||||
# 全局验证服务实例
|
||||
validation_service = ValidationService()
|
||||
Reference in New Issue
Block a user