08ea5fbe98
添加用户管理视图、API和状态管理文件
339 lines
10 KiB
Python
339 lines
10 KiB
Python
"""
|
||
认证管理器模块
|
||
提供自动令牌获取、存储、验证和刷新功能
|
||
"""
|
||
|
||
import time
|
||
import json
|
||
import hashlib
|
||
from typing import Optional, Dict, Any
|
||
from pathlib import Path
|
||
from dataclasses import dataclass
|
||
from datetime import datetime, timedelta
|
||
|
||
from utils.logger import TestLogger
|
||
from config.settings import config
|
||
|
||
|
||
@dataclass
|
||
class TokenInfo:
|
||
"""令牌信息"""
|
||
token: str
|
||
username: str
|
||
issued_at: float
|
||
expires_at: float
|
||
refresh_token: Optional[str] = None
|
||
|
||
def is_expired(self, buffer_seconds: int = 60) -> bool:
|
||
"""
|
||
检查令牌是否过期
|
||
|
||
Args:
|
||
buffer_seconds: 缓冲时间(秒),提前多少秒认为过期
|
||
|
||
Returns:
|
||
是否过期
|
||
"""
|
||
return time.time() > (self.expires_at - buffer_seconds)
|
||
|
||
def time_until_expiry(self) -> float:
|
||
"""
|
||
获取距离过期的时间
|
||
|
||
Returns:
|
||
距离过期的秒数
|
||
"""
|
||
return max(0, self.expires_at - time.time())
|
||
|
||
def to_dict(self) -> Dict[str, Any]:
|
||
"""转换为字典"""
|
||
return {
|
||
"token": self.token,
|
||
"username": self.username,
|
||
"issued_at": self.issued_at,
|
||
"expires_at": self.expires_at,
|
||
"refresh_token": self.refresh_token
|
||
}
|
||
|
||
@classmethod
|
||
def from_dict(cls, data: Dict[str, Any]) -> 'TokenInfo':
|
||
"""从字典创建"""
|
||
return cls(
|
||
token=data["token"],
|
||
username=data["username"],
|
||
issued_at=data["issued_at"],
|
||
expires_at=data["expires_at"],
|
||
refresh_token=data.get("refresh_token")
|
||
)
|
||
|
||
|
||
class AuthManager:
|
||
"""认证管理器"""
|
||
|
||
def __init__(self, logger: TestLogger = None):
|
||
"""
|
||
初始化认证管理器
|
||
|
||
Args:
|
||
logger: 日志记录器
|
||
"""
|
||
self.logger = logger or TestLogger("auth_manager", config.logging_file, config.logging_level)
|
||
self.token_info: Optional[TokenInfo] = None
|
||
self.token_cache_file = Path(config.report_output_dir) / "token_cache.json"
|
||
|
||
# 令牌刷新缓冲时间(秒)
|
||
self.refresh_buffer = 60
|
||
|
||
# 加载缓存的令牌
|
||
self._load_cached_token()
|
||
|
||
def _load_cached_token(self) -> None:
|
||
"""从缓存加载令牌"""
|
||
if not self.token_cache_file.exists():
|
||
return
|
||
|
||
try:
|
||
with open(self.token_cache_file, 'r', encoding='utf-8') as f:
|
||
data = json.load(f)
|
||
|
||
self.token_info = TokenInfo.from_dict(data)
|
||
|
||
if self.token_info.is_expired():
|
||
self.logger.info("缓存的令牌已过期,将重新获取")
|
||
self.token_info = None
|
||
else:
|
||
self.logger.info(f"从缓存加载令牌,剩余有效期: {self.token_info.time_until_expiry():.0f}秒")
|
||
|
||
except Exception as e:
|
||
self.logger.warning(f"加载缓存令牌失败: {e}")
|
||
self.token_info = None
|
||
|
||
def _save_cached_token(self) -> None:
|
||
"""保存令牌到缓存"""
|
||
if self.token_info is None:
|
||
return
|
||
|
||
try:
|
||
self.token_cache_file.parent.mkdir(parents=True, exist_ok=True)
|
||
|
||
with open(self.token_cache_file, 'w', encoding='utf-8') as f:
|
||
json.dump(self.token_info.to_dict(), f, indent=2)
|
||
|
||
self.logger.info("令牌已缓存")
|
||
|
||
except Exception as e:
|
||
self.logger.warning(f"保存缓存令牌失败: {e}")
|
||
|
||
def _clear_cached_token(self) -> None:
|
||
"""清除缓存的令牌"""
|
||
try:
|
||
if self.token_cache_file.exists():
|
||
self.token_cache_file.unlink()
|
||
self.logger.info("缓存的令牌已清除")
|
||
except Exception as e:
|
||
self.logger.warning(f"清除缓存令牌失败: {e}")
|
||
|
||
def login(
|
||
self,
|
||
username: str = None,
|
||
password: str = None,
|
||
force_refresh: bool = False
|
||
) -> bool:
|
||
"""
|
||
用户登录
|
||
|
||
Args:
|
||
username: 用户名
|
||
password: 密码
|
||
force_refresh: 是否强制刷新令牌
|
||
|
||
Returns:
|
||
是否登录成功
|
||
"""
|
||
username = username or config.auth_username
|
||
password = password or config.auth_password
|
||
|
||
# 检查是否需要重新登录
|
||
if not force_refresh and self.token_info and not self.token_info.is_expired(self.refresh_buffer):
|
||
self.logger.info(f"使用现有令牌,剩余有效期: {self.token_info.time_until_expiry():.0f}秒")
|
||
return True
|
||
|
||
# 执行登录
|
||
self.logger.info(f"用户登录: {username}")
|
||
|
||
try:
|
||
import requests
|
||
|
||
login_url = f"{config.api_base_url}{config.auth_login_endpoint}"
|
||
|
||
response = requests.post(
|
||
login_url,
|
||
json={"username": username, "password": password},
|
||
timeout=config.api_timeout
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
data = response.json()
|
||
|
||
# 兼容两种响应格式:
|
||
# 格式1: {"code": 200, "data": {"token": "...", "user": {...}}}
|
||
# 格式2: {"token": "...", "user": {...}}
|
||
token = None
|
||
|
||
if "data" in data and isinstance(data["data"], dict):
|
||
token = data["data"].get("token")
|
||
else:
|
||
token = data.get("token")
|
||
|
||
if token:
|
||
# 解析JWT令牌获取过期时间
|
||
expires_at = self._parse_token_expiry(token)
|
||
|
||
# 创建令牌信息
|
||
self.token_info = TokenInfo(
|
||
token=token,
|
||
username=username,
|
||
issued_at=time.time(),
|
||
expires_at=expires_at,
|
||
refresh_token=data.get("refreshToken") if "data" in data else None
|
||
)
|
||
|
||
# 缓存令牌
|
||
self._save_cached_token()
|
||
|
||
self.logger.info(f"✅ 登录成功,令牌有效期: {(expires_at - time.time()):.0f}秒")
|
||
return True
|
||
|
||
self.logger.error(f"❌ 登录失败: {response.text}")
|
||
return False
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"❌ 登录异常: {str(e)}")
|
||
return False
|
||
|
||
def _parse_token_expiry(self, token: str) -> float:
|
||
"""
|
||
解析JWT令牌的过期时间
|
||
|
||
Args:
|
||
token: JWT令牌
|
||
|
||
Returns:
|
||
过期时间戳
|
||
"""
|
||
try:
|
||
# JWT格式: header.payload.signature
|
||
parts = token.split('.')
|
||
if len(parts) != 3:
|
||
raise ValueError("无效的JWT令牌格式")
|
||
|
||
# 解码payload(Base64URL编码)
|
||
import base64
|
||
|
||
payload = parts[1]
|
||
# 添加必要的填充
|
||
padding = 4 - len(payload) % 4
|
||
if padding != 4:
|
||
payload += '=' * padding
|
||
|
||
decoded = base64.urlsafe_b64decode(payload)
|
||
payload_data = json.loads(decoded)
|
||
|
||
# 获取过期时间(exp字段是Unix时间戳,秒)
|
||
exp = payload_data.get('exp')
|
||
|
||
if exp:
|
||
return float(exp)
|
||
|
||
# 如果没有exp字段,默认24小时后过期
|
||
return time.time() + 24 * 3600
|
||
|
||
except Exception as e:
|
||
self.logger.warning(f"解析令牌过期时间失败: {e},使用默认过期时间")
|
||
return time.time() + 24 * 3600
|
||
|
||
def get_token(self, auto_refresh: bool = True) -> Optional[str]:
|
||
"""
|
||
获取当前令牌
|
||
|
||
Args:
|
||
auto_refresh: 是否自动刷新过期令牌
|
||
|
||
Returns:
|
||
令牌字符串,如果未登录则返回None
|
||
"""
|
||
if self.token_info is None:
|
||
return None
|
||
|
||
# 检查令牌是否过期
|
||
if self.token_info.is_expired(self.refresh_buffer):
|
||
if auto_refresh:
|
||
self.logger.info("令牌即将过期,尝试自动刷新")
|
||
if self.login(force_refresh=True):
|
||
return self.token_info.token
|
||
else:
|
||
self.logger.error("自动刷新令牌失败")
|
||
return None
|
||
else:
|
||
self.logger.warning("令牌已过期")
|
||
return None
|
||
|
||
return self.token_info.token
|
||
|
||
def get_auth_header(self, auto_refresh: bool = True) -> Dict[str, str]:
|
||
"""
|
||
获取认证请求头
|
||
|
||
Args:
|
||
auto_refresh: 是否自动刷新过期令牌
|
||
|
||
Returns:
|
||
认证请求头字典
|
||
"""
|
||
token = self.get_token(auto_refresh)
|
||
|
||
if token:
|
||
return {"Authorization": f"Bearer {token}"}
|
||
else:
|
||
return {}
|
||
|
||
def logout(self) -> None:
|
||
"""用户登出"""
|
||
self.token_info = None
|
||
self._clear_cached_token()
|
||
self.logger.info("用户已登出")
|
||
|
||
def is_authenticated(self) -> bool:
|
||
"""
|
||
检查是否已认证
|
||
|
||
Returns:
|
||
是否已认证
|
||
"""
|
||
return self.token_info is not None and not self.token_info.is_expired()
|
||
|
||
def get_token_info(self) -> Optional[TokenInfo]:
|
||
"""
|
||
获取令牌信息
|
||
|
||
Returns:
|
||
令牌信息对象
|
||
"""
|
||
return self.token_info
|
||
|
||
def ensure_authenticated(self, username: str = None, password: str = None) -> bool:
|
||
"""
|
||
确保已认证,如果未认证则自动登录
|
||
|
||
Args:
|
||
username: 用户名
|
||
password: 密码
|
||
|
||
Returns:
|
||
是否认证成功
|
||
"""
|
||
if self.is_authenticated():
|
||
return True
|
||
|
||
return self.login(username, password)
|