Source code
Revision control
Copy as Markdown
Other Tools
#!/usr/bin/env python3
"""
Example of using retry middleware with aiohttp client.
This example shows how to implement a middleware that automatically retries
failed requests with exponential backoff. The middleware can be configured
with custom retry statuses, maximum retries, and backoff parameters.
This example includes a test server that simulates various HTTP responses
and can return different status codes on sequential requests.
"""
import asyncio
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, List, Set, Union

from aiohttp import (
    ClientHandlerType,
    ClientRequest,
    ClientResponse,
    ClientSession,
    ClientTimeout,
    web,
)
logging.basicConfig(level=logging.INFO)
_LOGGER = logging.getLogger(__name__)

# Statuses that trigger a retry when no explicit set is supplied:
#   429 TOO_MANY_REQUESTS, 500 INTERNAL_SERVER_ERROR, 502 BAD_GATEWAY,
#   503 SERVICE_UNAVAILABLE, 504 GATEWAY_TIMEOUT.
DEFAULT_RETRY_STATUSES: Set[HTTPStatus] = {
    HTTPStatus(code) for code in (429, 500, 502, 503, 504)
}
class RetryMiddleware:
    """Middleware that retries failed requests with exponential backoff.

    The request is re-sent while the handler keeps answering with a status
    in ``retry_statuses`` and retries remain. The wait before the first
    retry is ``initial_delay`` seconds; it is multiplied by
    ``backoff_factor`` after every retry. Only response statuses are
    inspected: exceptions raised by the handler propagate unchanged and are
    not retried.
    """

    def __init__(
        self,
        max_retries: int = 3,
        retry_statuses: Union[Set[HTTPStatus], None] = None,
        initial_delay: float = 1.0,
        backoff_factor: float = 2.0,
    ) -> None:
        """Configure the retry policy.

        Args:
            max_retries: Number of extra attempts after the first one;
                ``0`` disables retrying. Must not be negative.
            retry_statuses: Statuses that trigger a retry. ``None`` selects
                ``DEFAULT_RETRY_STATUSES``; an empty set disables retrying.
                The statuses are copied, so later mutation of the caller's
                set (or of the module default) does not affect this instance.
            initial_delay: Seconds to wait before the first retry.
            backoff_factor: Multiplier applied to the delay after each retry.

        Raises:
            ValueError: If ``max_retries`` is negative.
        """
        if max_retries < 0:
            # A negative count would mean "zero attempts", leaving nothing to
            # return from ``__call__``.
            raise ValueError(f"max_retries must be >= 0, got {max_retries!r}")
        self.max_retries = max_retries
        # ``is None`` instead of ``or``: an explicitly empty set means
        # "retry nothing" and must not fall back to the defaults.
        if retry_statuses is None:
            retry_statuses = DEFAULT_RETRY_STATUSES
        self.retry_statuses: Set[HTTPStatus] = set(retry_statuses)
        self.initial_delay = initial_delay
        self.backoff_factor = backoff_factor

    async def __call__(
        self,
        request: ClientRequest,
        handler: ClientHandlerType,
    ) -> ClientResponse:
        """Execute request with retry logic.

        Returns the first response whose status is not retryable, or the
        response from the final attempt once ``max_retries`` retries have
        been used up.
        """
        delay = self.initial_delay
        attempt = 0
        while True:
            if attempt > 0:
                _LOGGER.info(
                    "Retrying request to %s (attempt %s/%s)",
                    request.url,
                    attempt + 1,
                    self.max_retries + 1,
                )
            response = await handler(request)

            if response.status not in self.retry_statuses:
                return response

            if attempt >= self.max_retries:
                _LOGGER.warning(
                    "Max retries (%s) exceeded for %s", self.max_retries, request.url
                )
                return response

            # This response is being discarded in favour of a retry; close it
            # so its unread body does not keep the connection tied up.
            response.close()

            _LOGGER.debug("Waiting %ss before retry...", delay)
            await asyncio.sleep(delay)
            delay *= self.backoff_factor
            attempt += 1
class TestServer:
    """Test server with stateful endpoints for retry testing.

    ``/sequence/{name}`` walks through a scripted list of status codes, one
    entry per request, so a client can observe failures that eventually
    turn into success (or never do).
    """

    def __init__(self) -> None:
        # Number of requests served so far, keyed by request path.
        self.request_counters: Dict[str, int] = {}
        # Scripted statuses per sequence name; once a script is exhausted
        # its last status is repeated.
        self.status_sequences: Dict[str, List[int]] = {
            "eventually-ok": [500, 503, 502, 200],  # Fails 3 times, then succeeds
            "always-error": [500, 500, 500, 500],  # Always fails
            "immediate-ok": [200],  # Succeeds immediately
            "flaky": [503, 200],  # Fails once, then succeeds
        }

    async def handle_status(self, request: web.Request) -> web.Response:
        """Return the status code specified in the path.

        A non-integer path segment yields 400. Letting ``int()`` raise would
        surface as a 500, which retry clients treat as retryable, so a
        malformed URL would be retried for nothing.
        """
        raw_status = request.match_info["status"]
        try:
            status = int(raw_status)
        except ValueError:
            return web.Response(status=400, text=f"Invalid status: {raw_status}")
        return web.Response(status=status, text=f"Status: {status}")

    async def handle_status_sequence(self, request: web.Request) -> web.Response:
        """Return different status codes on sequential requests.

        Unknown sequence names yield 404 without touching the counters.
        """
        sequence_name = request.match_info["name"]
        sequence = self.status_sequences.get(sequence_name)
        if sequence is None:
            return web.Response(status=404, text="Sequence not found")

        path = request.path
        count = self.request_counters.get(path, 0)
        # After the sequence ends, keep returning its last status.
        status = sequence[min(count, len(sequence) - 1)]
        self.request_counters[path] = count + 1
        return web.Response(
            status=status, text=f"Request #{count + 1}: Status {status}"
        )

    async def handle_delay(self, request: web.Request) -> web.Response:
        """Delay response by specified seconds.

        A non-numeric delay yields 400 rather than an internal-error 500.
        """
        raw_delay = request.match_info["delay"]
        try:
            delay = float(raw_delay)
        except ValueError:
            return web.Response(status=400, text=f"Invalid delay: {raw_delay}")
        await asyncio.sleep(delay)
        return web.json_response({"delay": delay, "message": "Response after delay"})

    async def handle_reset(self, request: web.Request) -> web.Response:
        """Reset request counters so every sequence restarts from its first status."""
        self.request_counters = {}
        return web.Response(text="Counters reset")
async def run_test_server(
    host: str = "localhost", port: int = 8080
) -> web.AppRunner:
    """Start the test server and return its runner.

    The caller owns the returned runner and must ``await runner.cleanup()``
    to stop the server.

    Args:
        host: Interface to bind.
        port: TCP port to listen on.
    """
    app = web.Application()
    server = TestServer()
    app.router.add_get("/status/{status}", server.handle_status)
    app.router.add_get("/sequence/{name}", server.handle_status_sequence)
    app.router.add_get("/delay/{delay}", server.handle_delay)
    app.router.add_post("/reset", server.handle_reset)

    runner = web.AppRunner(app)
    await runner.setup()
    try:
        site = web.TCPSite(runner, host, port)
        await site.start()
    except BaseException:
        # The caller never receives the runner if starting the site fails
        # (e.g. the port is already in use), so release it here.
        await runner.cleanup()
        raise
    return runner
async def run_tests(base_url: str = "http://localhost:8080") -> None:
    """Run all retry middleware tests.

    Each scenario requests one test-server endpoint through a session that
    has the retry middleware installed and prints the final outcome.

    Args:
        base_url: Scheme, host and port of the test server; the default
            matches the address ``run_test_server`` binds (localhost:8080).
    """
    # Short delays keep the demo quick: retries wait 0.5s, 1s, then 2s.
    retry_middleware = RetryMiddleware(
        max_retries=3,
        retry_statuses=DEFAULT_RETRY_STATUSES,
        initial_delay=0.5,
        backoff_factor=2.0,
    )

    async with ClientSession(middlewares=(retry_middleware,)) as session:
        # Reset counters before tests so every sequence starts from its
        # first scripted status.
        async with session.post(f"{base_url}/reset") as resp:
            await resp.read()

        # Test 1: Request that succeeds immediately
        print("=== Test 1: Immediate success ===")
        async with session.get(f"{base_url}/sequence/immediate-ok") as resp:
            text = await resp.text()
            print(f"Final status: {resp.status}")
            print(f"Response: {text}")
            print("Success - no retries needed\n")

        # Test 2: Request that eventually succeeds after retries
        print("=== Test 2: Eventually succeeds (500->503->502->200) ===")
        async with session.get(f"{base_url}/sequence/eventually-ok") as resp:
            text = await resp.text()
            print(f"Final status: {resp.status}")
            print(f"Response: {text}")
            if resp.status == 200:
                print("Success after retries!\n")
            else:
                print("Failed after retries\n")

        # Test 3: Request that always fails
        print("=== Test 3: Always fails (500->500->500->500) ===")
        async with session.get(f"{base_url}/sequence/always-error") as resp:
            text = await resp.text()
            print(f"Final status: {resp.status}")
            print(f"Response: {text}")
            print("Failed after exhausting all retries\n")

        # Test 4: Flaky service (fails once then succeeds)
        print("=== Test 4: Flaky service (503->200) ===")
        async with session.get(f"{base_url}/sequence/flaky") as resp:
            text = await resp.text()
            print(f"Final status: {resp.status}")
            print(f"Response: {text}")
            print("Success after one retry!\n")

        # Test 5: Non-retryable status
        print("=== Test 5: Non-retryable status (404) ===")
        async with session.get(f"{base_url}/status/404") as resp:
            print(f"Final status: {resp.status}")
            print("Failed immediately - not a retryable status\n")

        # Test 6: Delayed response, bounded by an explicit total timeout so
        # the TimeoutError branch below can actually be reached.
        print("=== Test 6: Testing with delay endpoint ===")
        try:
            async with session.get(
                f"{base_url}/delay/0.5", timeout=ClientTimeout(total=5)
            ) as resp:
                print(f"Status: {resp.status}")
                data = await resp.json()
                print(f"Response received after delay: {data}\n")
        except asyncio.TimeoutError:
            print("Request timed out\n")
async def main() -> None:
    """Start the local test server, run every demo scenario, then shut down."""
    runner = await run_test_server()
    try:
        await run_tests()
    finally:
        # Tear the server down even when a scenario raises.
        await runner.cleanup()


if __name__ == "__main__":
    # Script entry point: drive the async demo on a fresh event loop.
    asyncio.run(main())