""" 本地并发控制模块 提供本地并发控制的各种机制,包括信号量、读写锁、限流器等。 """ import threading import time from typing import Any, Dict, List, Optional, TypeVar, Generic from contextlib import contextmanager from collections import deque T = TypeVar('T') class SemaphoreControl: """ 信号量并发控制 限制同时执行的线程数量 """ def __init__(self, max_concurrent: int): """ 初始化信号量 Args: max_concurrent: 最大并发数 """ self._max_concurrent = max_concurrent self._semaphore = threading.Semaphore(max_concurrent) self._stats = { "total_acquisitions": 0, "active_count": 0, "peak_count": 0, } self._lock = threading.Lock() @contextmanager def acquire(self, timeout: Optional[float] = None): """ 获取信号量(上下文管理器) Args: timeout: 超时时间(秒) """ acquired = False try: acquired = self._semaphore.acquire(timeout=timeout) if not acquired: raise TimeoutError("获取信号量超时") with self._lock: self._stats["total_acquisitions"] += 1 self._stats["active_count"] += 1 self._stats["peak_count"] = max( self._stats["peak_count"], self._stats["active_count"] ) yield self finally: if acquired: with self._lock: self._stats["active_count"] -= 1 self._semaphore.release() def get_stats(self) -> Dict[str, Any]: """获取统计信息""" with self._lock: return self._stats.copy() class ReadWriteLock: """ 读写锁 支持多个读线程同时访问,写线程独占访问 """ def __init__(self): self._read_lock = threading.Lock() self._write_lock = threading.Lock() self._read_count = 0 @contextmanager def read_lock(self): """获取读锁""" with self._read_lock: self._read_count += 1 if self._read_count == 1: self._write_lock.acquire() try: yield self finally: with self._read_lock: self._read_count -= 1 if self._read_count == 0: self._write_lock.release() @contextmanager def write_lock(self): """获取写锁""" self._write_lock.acquire() try: yield self finally: self._write_lock.release() class RateLimiter: """ 限流器 限制单位时间内的请求数量 """ def __init__(self, max_requests: int, time_window: float): """ 初始化限流器 Args: max_requests: 时间窗口内最大请求数 time_window: 时间窗口(秒) """ self._max_requests = max_requests self._time_window = time_window self._requests = deque() self._lock = threading.Lock() self._stats = { "total_requests": 0, "allowed_requests": 0, "blocked_requests": 0, } def allow_request(self) -> bool: """ 检查是否允许请求 Returns: 是否允许 """ with self._lock: now = time.time() # 清理过期的请求记录 while self._requests and self._requests[0] < now - self._time_window: self._requests.popleft() self._stats["total_requests"] += 1 if len(self._requests) < self._max_requests: self._requests.append(now) self._stats["allowed_requests"] += 1 return True else: self._stats["blocked_requests"] += 1 return False def get_stats(self) -> Dict[str, Any]: """获取统计信息""" with self._lock: return self._stats.copy() class LocalDistributedLock: """ 本地模拟的分布式锁 用于单体应用内的分布式锁场景模拟 """ _locks: Dict[str, Dict[str, Any]] = {} _global_lock = threading.Lock() def __init__(self, resource_name: str, expire_seconds: float = 30.0): """ 初始化分布式锁 Args: resource_name: 资源名称 expire_seconds: 锁过期时间(秒) """ self._resource_name = resource_name self._expire_seconds = expire_seconds self._lock = threading.Lock() def acquire(self, timeout: float = 10.0) -> bool: """ 获取锁 Args: timeout: 超时时间(秒) """ start_time = time.time() while time.time() - start_time < timeout: with self._global_lock: now = time.time() lock_info = self._locks.get(self._resource_name) # 检查锁是否过期 if lock_info and now > lock_info["expires_at"]: del self._locks[self._resource_name] lock_info = None # 尝试获取锁 if not lock_info: self._locks[self._resource_name] = { "expires_at": now + self._expire_seconds, } return True time.sleep(0.1) return False def release(self) -> None: """释放锁""" with self._global_lock: if self._resource_name in self._locks: del self._locks[self._resource_name] class ConcurrentCounter: """ 并发计数器 线程安全的计数器实现 """ def __init__(self, initial_value: int = 0): """ 初始化计数器 Args: initial_value: 初始值 """ self._value = initial_value self._lock = threading.Lock() def increment(self, delta: int = 1) -> int: """ 增加计数 Args: delta: 增量 Returns: 增加后的值 """ with self._lock: self._value += delta return self._value def decrement(self, delta: int = 1) -> int: """ 减少计数 Args: delta: 减量 Returns: 减少后的值 """ with self._lock: self._value -= delta return self._value def get_value(self) -> int: """获取当前值""" with self._lock: return self._value def reset(self, value: int = 0) -> None: """重置计数器""" with self._lock: self._value = value class ThreadBarrier: """ 线程屏障 等待指定数量的线程到达后同时放行 """ def __init__(self, parties: int): """ 初始化屏障 Args: parties: 需要等待的线程数 """ self._parties = parties self._count = 0 self._lock = threading.Lock() self._condition = threading.Condition(self._lock) def wait(self, timeout: Optional[float] = None) -> bool: """ 等待屏障 Args: timeout: 超时时间(秒) Returns: 是否成功等待 """ with self._condition: self._count += 1 if self._count >= self._parties: # 所有线程已到达,放行 self._count = 0 self._condition.notify_all() return True else: # 等待其他线程 return self._condition.wait(timeout=timeout) class BoundedTaskQueue(Generic[T]): """ 有界任务队列 有容量限制的线程安全队列 """ def __init__(self, max_size: int): """ 初始化队列 Args: max_size: 最大容量 """ self._max_size = max_size self._queue = deque(maxlen=max_size) self._lock = threading.Lock() self._not_full = threading.Condition(self._lock) self._not_empty = threading.Condition(self._lock) def put(self, item: T, timeout: Optional[float] = None) -> bool: """ 添加元素 Args: item: 元素 timeout: 超时时间(秒) Returns: 是否成功 """ with self._not_full: if len(self._queue) >= self._max_size: if not self._not_full.wait(timeout=timeout): raise TimeoutError("队列已满") self._queue.append(item) self._not_empty.notify() return True def get(self, timeout: Optional[float] = None) -> T: """ 获取元素 Args: timeout: 超时时间(秒) Returns: 元素 """ with self._not_empty: while len(self._queue) == 0: if not self._not_empty.wait(timeout=timeout): raise TimeoutError("队列为空") item = self._queue.popleft() self._not_full.notify() return item def size(self) -> int: """获取队列大小""" with self._lock: return len(self._queue) def is_full(self) -> bool: """检查是否已满""" with self._lock: return len(self._queue) >= self._max_size def is_empty(self) -> bool: """检查是否为空""" with self._lock: return len(self._queue) == 0 class ConcurrencyManager: """ 并发控制器管理器 单例模式管理所有并发控制组件 """ _instance = None _lock = threading.Lock() def __new__(cls): if cls._instance is None: with cls._lock: if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._semaphores: Dict[str, SemaphoreControl] = {} cls._instance._rw_locks: Dict[str, ReadWriteLock] = {} cls._instance._rate_limiters: Dict[str, RateLimiter] = {} cls._instance._locks: Dict[str, LocalDistributedLock] = {} return cls._instance def create_semaphore(self, name: str, max_concurrent: int) -> SemaphoreControl: """创建命名信号量""" if name not in self._semaphores: self._semaphores[name] = SemaphoreControl(max_concurrent) return self._semaphores[name] def get_semaphore(self, name: str) -> Optional[SemaphoreControl]: """获取命名信号量""" return self._semaphores.get(name) def create_rw_lock(self, name: str) -> ReadWriteLock: """创建命名读写锁""" if name not in self._rw_locks: self._rw_locks[name] = ReadWriteLock() return self._rw_locks[name] def get_rw_lock(self, name: str) -> Optional[ReadWriteLock]: """获取命名读写锁""" return self._rw_locks.get(name) def create_rate_limiter(self, name: str, max_requests: int, time_window: float) -> RateLimiter: """创建命名限流器""" if name not in self._rate_limiters: self._rate_limiters[name] = RateLimiter(max_requests, time_window) return self._rate_limiters[name] def get_rate_limiter(self, name: str) -> Optional[RateLimiter]: """获取命名限流器""" return self._rate_limiters.get(name) def create_lock(self, name: str, expire_seconds: float = 30.0) -> LocalDistributedLock: """创建命名分布式锁""" if name not in self._locks: self._locks[name] = LocalDistributedLock(name, expire_seconds) return self._locks[name] def get_lock(self, name: str) -> Optional[LocalDistributedLock]: """获取命名分布式锁""" return self._locks.get(name) def get_all_stats(self) -> Dict[str, Dict[str, Any]]: """获取所有组件统计信息""" return { "semaphores": {name: sem.get_stats() for name, sem in self._semaphores.items()}, "rate_limiters": {name: rl.get_stats() for name, rl in self._rate_limiters.items()}, } # 全局并发管理器实例 concurrency_manager = ConcurrencyManager()