#!/usr/bin/env python3 """ 重试机制模块 提供重试装饰器和工具函数 """ import time import logging from functools import wraps from typing import Callable, Optional, Type, Tuple, Any from src.logging_config import get_logger logger = get_logger(__name__) def retry( max_attempts: int = 3, delay: float = 1.0, backoff_factor: float = 2.0, exceptions: Optional[Tuple[Type[Exception], ...]] = None, on_retry: Optional[Callable[[int, Exception], None]] = None ) -> Callable: """ 重试装饰器 参数: max_attempts: 最大重试次数 delay: 初始延迟时间(秒) backoff_factor: 退避因子,每次重试延迟时间乘以该因子 exceptions: 要捕获的异常类型,None表示捕获所有异常 on_retry: 重试时的回调函数,参数为 (attempt, exception) 使用示例: @retry(max_attempts=3, delay=2.0, backoff_factor=2.0) def fetch_data(): # 可能失败的代码 pass @retry(max_attempts=5, exceptions=(ConnectionError, TimeoutError)) def network_request(): # 网络请求代码 pass """ def decorator(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs) -> Any: last_exception = None for attempt in range(max_attempts): try: return func(*args, **kwargs) except Exception as e: # 检查是否需要捕获此异常 if exceptions and not isinstance(e, exceptions): raise last_exception = e # 如果是最后一次尝试,不再重试 if attempt == max_attempts - 1: logger.error( f"{func.__name__} 在 {max_attempts} 次尝试后仍然失败: {e}" ) raise # 计算延迟时间 current_delay = delay * (backoff_factor ** attempt) logger.warning( f"{func.__name__} 第 {attempt + 1} 次尝试失败: {e}, " f"{current_delay:.2f}秒后重试..." ) # 调用重试回调 if on_retry: try: on_retry(attempt + 1, e) except Exception as callback_error: logger.error(f"重试回调执行失败: {callback_error}") # 等待 time.sleep(current_delay) # 理论上不会到达这里,但为了类型检查 if last_exception: raise last_exception return wrapper return decorator def retry_with_exponential_backoff( max_attempts: int = 3, initial_delay: float = 1.0, max_delay: float = 60.0 ) -> Callable: """ 使用指数退避的重试装饰器 参数: max_attempts: 最大重试次数 initial_delay: 初始延迟时间(秒) max_delay: 最大延迟时间(秒) 使用示例: @retry_with_exponential_backoff(max_attempts=5, initial_delay=2.0) def api_call(): # API调用代码 pass """ def decorator(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs) -> Any: last_exception = None for attempt in range(max_attempts): try: return func(*args, **kwargs) except Exception as e: last_exception = e if attempt == max_attempts - 1: logger.error( f"{func.__name__} 在 {max_attempts} 次尝试后仍然失败: {e}" ) raise # 计算延迟时间(指数退避,但不超过最大延迟) current_delay = min(initial_delay * (2 ** attempt), max_delay) logger.warning( f"{func.__name__} 第 {attempt + 1} 次尝试失败: {e}, " f"{current_delay:.2f}秒后重试..." ) time.sleep(current_delay) if last_exception: raise last_exception return wrapper return decorator def retry_on_exception( exception_type: Type[Exception], max_attempts: int = 3, delay: float = 1.0 ) -> Callable: """ 只在特定异常时重试的装饰器 参数: exception_type: 要捕获的异常类型 max_attempts: 最大重试次数 delay: 延迟时间(秒) 使用示例: @retry_on_exception(ConnectionError, max_attempts=5, delay=2.0) def fetch_data(): # 可能抛出 ConnectionError 的代码 pass """ def decorator(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs) -> Any: last_exception = None for attempt in range(max_attempts): try: return func(*args, **kwargs) except exception_type as e: last_exception = e if attempt == max_attempts - 1: logger.error( f"{func.__name__} 在 {max_attempts} 次尝试后仍然失败: {e}" ) raise logger.warning( f"{func.__name__} 第 {attempt + 1} 次尝试失败: {e}, " f"{delay:.2f}秒后重试..." ) time.sleep(delay) if last_exception: raise last_exception return wrapper return decorator class RetryContext: """重试上下文管理器""" def __init__( self, operation_name: str, max_attempts: int = 3, delay: float = 1.0, backoff_factor: float = 2.0, exceptions: Optional[Tuple[Type[Exception], ...]] = None ): """ 初始化重试上下文 参数: operation_name: 操作名称 max_attempts: 最大重试次数 delay: 初始延迟时间(秒) backoff_factor: 退避因子 exceptions: 要捕获的异常类型 """ self.operation_name = operation_name self.max_attempts = max_attempts self.delay = delay self.backoff_factor = backoff_factor self.exceptions = exceptions self.attempt = 0 def __enter__(self): self.attempt = 0 return self def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is None: return False # 检查是否需要捕获此异常 if self.exceptions and not isinstance(exc_val, self.exceptions): return False self.attempt += 1 # 如果超过最大尝试次数,不再重试 if self.attempt >= self.max_attempts: logger.error( f"{self.operation_name} 在 {self.max_attempts} 次尝试后仍然失败: {exc_val}" ) return False # 计算延迟时间 current_delay = self.delay * (self.backoff_factor ** (self.attempt - 1)) logger.warning( f"{self.operation_name} 第 {self.attempt} 次尝试失败: {exc_val}, " f"{current_delay:.2f}秒后重试..." ) # 等待 time.sleep(current_delay) # 抑制异常,继续重试 return True def async_retry( max_attempts: int = 3, delay: float = 1.0, backoff_factor: float = 2.0 ) -> Callable: """ 异步重试装饰器(用于异步函数) 参数: max_attempts: 最大重试次数 delay: 初始延迟时间(秒) backoff_factor: 退避因子 使用示例: @async_retry(max_attempts=3, delay=2.0) async def async_fetch_data(): # 异步代码 pass """ import asyncio def decorator(func: Callable) -> Callable: @wraps(func) async def wrapper(*args, **kwargs) -> Any: last_exception = None for attempt in range(max_attempts): try: return await func(*args, **kwargs) except Exception as e: last_exception = e if attempt == max_attempts - 1: logger.error( f"{func.__name__} 在 {max_attempts} 次尝试后仍然失败: {e}" ) raise # 计算延迟时间 current_delay = delay * (backoff_factor ** attempt) logger.warning( f"{func.__name__} 第 {attempt + 1} 次尝试失败: {e}, " f"{current_delay:.2f}秒后重试..." ) # 异步等待 await asyncio.sleep(current_delay) if last_exception: raise last_exception return wrapper return decorator if __name__ == '__main__': # 测试代码 # 测试重试装饰器 call_count = 0 @retry(max_attempts=3, delay=0.1) def test_retry(): global call_count call_count += 1 print(f"调用次数: {call_count}") if call_count < 3: raise ValueError("测试异常") return "成功" result = test_retry() print(f"测试结果: {result}") # 测试重试上下文管理器 context_call_count = 0 def test_context_operation(): global context_call_count context_call_count += 1 print(f"上下文调用次数: {context_call_count}") if context_call_count < 3: raise ValueError("测试异常") return "成功" with RetryContext("测试操作", max_attempts=3, delay=0.1): result = test_context_operation() print(f"上下文测试结果: {result}") # 测试特定异常重试 @retry_on_exception(ValueError, max_attempts=3, delay=0.1) def test_specific_exception(): raise ValueError("测试异常") try: test_specific_exception() except ValueError as e: print(f"特定异常测试: {e}")