Files
张翔 08ea5fbe98 feat(admin): 添加用户管理相关文件
添加用户管理视图、API和状态管理文件
2026-03-28 14:37:29 +08:00

471 lines
13 KiB
Python

"""
Local concurrency-control module.

Provides in-process concurrency primitives: a counting semaphore with usage
statistics, a read/write lock, a sliding-window rate limiter, an in-process
simulation of a distributed lock, a thread-safe counter, a thread barrier,
a bounded task queue, and a singleton manager for named instances.
"""
import threading
import time
from collections import deque
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar
# Element type parameter for the generic BoundedTaskQueue below.
T = TypeVar('T')
class SemaphoreControl:
    """Cap how many threads may be inside a section at once, with statistics.

    Wraps a ``threading.Semaphore`` and, under a separate lock, tracks how
    many times a slot was taken, how many slots are currently held, and the
    highest number of slots ever held simultaneously.
    """

    def __init__(self, max_concurrent: int):
        """
        Args:
            max_concurrent: Number of slots (the semaphore's initial value).
        """
        self._max_concurrent = max_concurrent
        self._semaphore = threading.Semaphore(max_concurrent)
        self._stats = dict(total_acquisitions=0, active_count=0, peak_count=0)
        self._lock = threading.Lock()

    def _record_entry(self) -> None:
        """Update counters after a slot has been taken."""
        with self._lock:
            stats = self._stats
            stats["total_acquisitions"] += 1
            stats["active_count"] += 1
            if stats["active_count"] > stats["peak_count"]:
                stats["peak_count"] = stats["active_count"]

    def _record_exit(self) -> None:
        """Update counters just before a slot is given back."""
        with self._lock:
            self._stats["active_count"] -= 1

    @contextmanager
    def acquire(self, timeout: Optional[float] = None):
        """Hold one slot for the duration of a ``with`` block.

        Args:
            timeout: Seconds to wait for a free slot; ``None`` waits forever.

        Yields:
            This ``SemaphoreControl`` instance.

        Raises:
            TimeoutError: No slot became free within ``timeout``.
        """
        got_slot = False
        try:
            got_slot = self._semaphore.acquire(timeout=timeout)
            if not got_slot:
                raise TimeoutError("获取信号量超时")
            self._record_entry()
            yield self
        finally:
            # Only give back a slot we actually obtained.
            if got_slot:
                self._record_exit()
                self._semaphore.release()

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot copy of the usage counters."""
        with self._lock:
            return dict(self._stats)
class ReadWriteLock:
    """Read/write lock: many concurrent readers, or exactly one writer.

    Readers-preference design: the first reader to arrive takes the writer
    lock on behalf of all readers and the last reader to leave releases it,
    so a continuous stream of readers can keep a writer waiting. Neither
    mode is reentrant (both are built on plain ``threading.Lock``).
    """

    def __init__(self):
        self._read_lock = threading.Lock()   # guards _read_count
        self._write_lock = threading.Lock()  # held by a writer or by the reader group
        self._read_count = 0

    def _enter_reader(self) -> None:
        """Register a reader; the first one blocks out writers."""
        with self._read_lock:
            self._read_count += 1
            first_reader = self._read_count == 1
            if first_reader:
                # Held while waiting so later readers queue behind the writer.
                self._write_lock.acquire()

    def _exit_reader(self) -> None:
        """Unregister a reader; the last one lets writers in again."""
        with self._read_lock:
            self._read_count -= 1
            if not self._read_count:
                self._write_lock.release()

    @contextmanager
    def read_lock(self):
        """Hold shared (read) access for the duration of a ``with`` block."""
        self._enter_reader()
        try:
            yield self
        finally:
            self._exit_reader()

    @contextmanager
    def write_lock(self):
        """Hold exclusive (write) access for the duration of a ``with`` block."""
        with self._write_lock:
            yield self
class RateLimiter:
    """Sliding-window rate limiter.

    Keeps the timestamps of recently allowed requests and admits a new one
    only while fewer than ``max_requests`` of them fall inside the last
    ``time_window`` seconds.
    """

    def __init__(
        self,
        max_requests: int,
        time_window: float,
        *,
        clock: Callable[[], float] = time.monotonic,
    ):
        """
        Args:
            max_requests: Maximum number of requests allowed inside the window.
            time_window: Window length in seconds.
            clock: Zero-argument callable returning the current time in
                seconds. Defaults to ``time.monotonic`` so that wall-clock
                adjustments (NTP, manual changes) cannot shrink or stretch
                the window; override it for deterministic tests.
        """
        self._max_requests = max_requests
        self._time_window = time_window
        self._clock = clock
        self._requests = deque()  # timestamps of allowed requests, oldest first
        self._lock = threading.Lock()
        self._stats = {
            "total_requests": 0,
            "allowed_requests": 0,
            "blocked_requests": 0,
        }

    def allow_request(self) -> bool:
        """Decide whether one more request may proceed now.

        Returns:
            ``True`` if the request is allowed and has been recorded,
            ``False`` if the window is already full.
        """
        with self._lock:
            now = self._clock()
            cutoff = now - self._time_window
            # Drop timestamps that have slid out of the window.
            while self._requests and self._requests[0] < cutoff:
                self._requests.popleft()
            self._stats["total_requests"] += 1
            if len(self._requests) < self._max_requests:
                self._requests.append(now)
                self._stats["allowed_requests"] += 1
                return True
            self._stats["blocked_requests"] += 1
            return False

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot copy of the request counters."""
        with self._lock:
            return self._stats.copy()
class LocalDistributedLock:
    """In-process stand-in for a distributed lock (single-application use).

    All instances share one class-level table keyed by resource name, so two
    instances created with the same ``resource_name`` contend for the same
    lock. Each hold is a lease that lapses ``expire_seconds`` after it was
    taken; an expired lease may be taken over by another caller.

    Ownership is tracked per instance: ``release()`` frees the lease only if
    it is still the one this instance most recently acquired, so a holder
    whose lease expired and was taken over cannot delete the new holder's lock.
    NOTE(review): threads that share one instance also share its ownership
    token and are not distinguished from each other.
    """

    # resource_name -> {"expires_at": monotonic deadline, "owner": token}
    _locks: Dict[str, Dict[str, Any]] = {}
    _global_lock = threading.Lock()

    def __init__(self, resource_name: str, expire_seconds: float = 30.0,
                 poll_interval: float = 0.1):
        """
        Args:
            resource_name: Name of the resource being locked.
            expire_seconds: Lease duration in seconds.
            poll_interval: Seconds to sleep between attempts while waiting.
        """
        self._resource_name = resource_name
        self._expire_seconds = expire_seconds
        self._poll_interval = poll_interval
        # Token of the lease this instance currently holds; None if none.
        self._token: Optional[object] = None

    def acquire(self, timeout: float = 10.0) -> bool:
        """Try to take the lock, polling until it is free or time runs out.

        At least one attempt is always made, so ``timeout <= 0`` acts as a
        non-blocking try-lock.

        Args:
            timeout: Maximum seconds to keep trying.

        Returns:
            ``True`` if the lock was acquired, ``False`` on timeout.
        """
        # Monotonic time: lease and timeout math must not follow wall-clock jumps.
        deadline = time.monotonic() + timeout
        while True:
            with self._global_lock:
                now = time.monotonic()
                lock_info = self._locks.get(self._resource_name)
                # An expired lease counts as free.
                if lock_info is not None and now > lock_info["expires_at"]:
                    del self._locks[self._resource_name]
                    lock_info = None
                if lock_info is None:
                    token = object()
                    self._locks[self._resource_name] = {
                        "expires_at": now + self._expire_seconds,
                        "owner": token,
                    }
                    self._token = token
                    return True
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            # Never sleep past the deadline.
            time.sleep(min(self._poll_interval, remaining))

    def release(self) -> None:
        """Release the lease held by this instance, if it still owns it.

        Does nothing if this instance holds no lease, or if its lease expired
        and the resource is now held by someone else.
        """
        with self._global_lock:
            token = self._token
            if token is None:
                return
            self._token = None
            lock_info = self._locks.get(self._resource_name)
            if lock_info is not None and lock_info.get("owner") is token:
                del self._locks[self._resource_name]
class ConcurrentCounter:
    """Counter whose every read and update is serialized by a single lock."""

    def __init__(self, initial_value: int = 0):
        """
        Args:
            initial_value: Starting value of the counter.
        """
        self._lock = threading.Lock()
        self._value = initial_value

    def _adjust(self, amount: int) -> int:
        """Add ``amount`` under the lock and return the resulting value."""
        with self._lock:
            self._value += amount
            return self._value

    def increment(self, delta: int = 1) -> int:
        """Add ``delta`` and return the new value."""
        return self._adjust(delta)

    def decrement(self, delta: int = 1) -> int:
        """Subtract ``delta`` and return the new value."""
        return self._adjust(-delta)

    def get_value(self) -> int:
        """Return the current value."""
        with self._lock:
            current = self._value
        return current

    def reset(self, value: int = 0) -> None:
        """Overwrite the counter with ``value``."""
        with self._lock:
            self._value = value
class ThreadBarrier:
    """Reusable barrier: blocks callers until ``parties`` threads have arrived.

    Unlike ``threading.Barrier``, a timed-out waiter does not break the
    barrier; it simply withdraws and returns ``False``.
    """

    def __init__(self, parties: int):
        """
        Args:
            parties: Number of threads that must call ``wait`` to trip it.
        """
        self._parties = parties
        self._count = 0        # arrivals in the current round
        self._generation = 0   # bumped each time the barrier trips
        self._lock = threading.Lock()
        self._condition = threading.Condition(self._lock)

    def wait(self, timeout: Optional[float] = None) -> bool:
        """Arrive at the barrier and wait for the rest of the round.

        Args:
            timeout: Seconds to wait; ``None`` waits forever.

        Returns:
            ``True`` if the barrier tripped, ``False`` if this caller timed out.
        """
        with self._condition:
            generation = self._generation
            self._count += 1
            if self._count >= self._parties:
                # Last arrival: start a new round and release everyone waiting.
                self._count = 0
                self._generation += 1
                self._condition.notify_all()
                return True
            # Wait for *this* round to trip; the predicate guards against
            # wake-ups that are not for our generation.
            tripped = self._condition.wait_for(
                lambda: self._generation != generation, timeout=timeout
            )
            if not tripped:
                # Withdraw our arrival so a stale count cannot let a later
                # round through with too few threads.
                self._count -= 1
            return tripped
class BoundedTaskQueue(Generic[T]):
    """Thread-safe FIFO queue with a fixed capacity.

    ``put`` blocks while the queue is full and ``get`` blocks while it is
    empty; both raise ``TimeoutError`` if their wait exceeds ``timeout``.
    """

    def __init__(self, max_size: int):
        """
        Args:
            max_size: Maximum number of items held at once.
        """
        self._max_size = max_size
        # maxlen is only a backstop: put() waits for free space first, so the
        # deque never reaches the point where it would silently evict items.
        self._queue = deque(maxlen=max_size)
        self._lock = threading.Lock()
        self._not_full = threading.Condition(self._lock)
        self._not_empty = threading.Condition(self._lock)

    def put(self, item: T, timeout: Optional[float] = None) -> bool:
        """Append ``item``, waiting for free space if necessary.

        Args:
            item: Item to enqueue.
            timeout: Total seconds to wait for space; ``None`` waits forever.

        Returns:
            Always ``True`` (failure is reported by raising).

        Raises:
            TimeoutError: The queue stayed full for ``timeout`` seconds.
        """
        with self._not_full:
            # wait_for re-checks the condition after every wake-up, so a
            # producer that lost the race to another producer keeps waiting
            # instead of overfilling the queue; it also enforces one overall
            # deadline rather than restarting the timeout on each wake-up.
            has_space = self._not_full.wait_for(
                lambda: len(self._queue) < self._max_size, timeout=timeout
            )
            if not has_space:
                raise TimeoutError("队列已满")
            self._queue.append(item)
            self._not_empty.notify()
            return True

    def get(self, timeout: Optional[float] = None) -> T:
        """Remove and return the oldest item, waiting if the queue is empty.

        Args:
            timeout: Total seconds to wait for an item; ``None`` waits forever.

        Returns:
            The oldest queued item.

        Raises:
            TimeoutError: The queue stayed empty for ``timeout`` seconds.
        """
        with self._not_empty:
            has_item = self._not_empty.wait_for(
                lambda: len(self._queue) > 0, timeout=timeout
            )
            if not has_item:
                raise TimeoutError("队列为空")
            item = self._queue.popleft()
            self._not_full.notify()
            return item

    def size(self) -> int:
        """Return the current number of queued items."""
        with self._lock:
            return len(self._queue)

    def is_full(self) -> bool:
        """Return ``True`` if the queue is at capacity."""
        with self._lock:
            return len(self._queue) >= self._max_size

    def is_empty(self) -> bool:
        """Return ``True`` if the queue holds no items."""
        with self._lock:
            return len(self._queue) == 0
class ConcurrencyManager:
"""
并发控制器管理器
单例模式管理所有并发控制组件
"""
_instance = None
_lock = threading.Lock()
def __new__(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._semaphores: Dict[str, SemaphoreControl] = {}
cls._instance._rw_locks: Dict[str, ReadWriteLock] = {}
cls._instance._rate_limiters: Dict[str, RateLimiter] = {}
cls._instance._locks: Dict[str, LocalDistributedLock] = {}
return cls._instance
def create_semaphore(self, name: str, max_concurrent: int) -> SemaphoreControl:
"""创建命名信号量"""
if name not in self._semaphores:
self._semaphores[name] = SemaphoreControl(max_concurrent)
return self._semaphores[name]
def get_semaphore(self, name: str) -> Optional[SemaphoreControl]:
"""获取命名信号量"""
return self._semaphores.get(name)
def create_rw_lock(self, name: str) -> ReadWriteLock:
"""创建命名读写锁"""
if name not in self._rw_locks:
self._rw_locks[name] = ReadWriteLock()
return self._rw_locks[name]
def get_rw_lock(self, name: str) -> Optional[ReadWriteLock]:
"""获取命名读写锁"""
return self._rw_locks.get(name)
def create_rate_limiter(self, name: str, max_requests: int, time_window: float) -> RateLimiter:
"""创建命名限流器"""
if name not in self._rate_limiters:
self._rate_limiters[name] = RateLimiter(max_requests, time_window)
return self._rate_limiters[name]
def get_rate_limiter(self, name: str) -> Optional[RateLimiter]:
"""获取命名限流器"""
return self._rate_limiters.get(name)
def create_lock(self, name: str, expire_seconds: float = 30.0) -> LocalDistributedLock:
"""创建命名分布式锁"""
if name not in self._locks:
self._locks[name] = LocalDistributedLock(name, expire_seconds)
return self._locks[name]
def get_lock(self, name: str) -> Optional[LocalDistributedLock]:
"""获取命名分布式锁"""
return self._locks.get(name)
def get_all_stats(self) -> Dict[str, Dict[str, Any]]:
"""获取所有组件统计信息"""
return {
"semaphores": {name: sem.get_stats() for name, sem in self._semaphores.items()},
"rate_limiters": {name: rl.get_stats() for name, rl in self._rate_limiters.items()},
}
# Module-level shared instance of the ConcurrencyManager singleton.
concurrency_manager = ConcurrencyManager()