Source code

Revision control

Copy as Markdown

Other Tools

#!/usr/bin/env python3
"""
Example of using basic authentication middleware with aiohttp client.
This example shows how to implement a middleware that automatically adds
Basic Authentication headers to all requests. The middleware encodes the
username and password in base64 format as required by the HTTP Basic Auth
specification.
This example includes a test server that validates basic auth credentials.
"""
import asyncio
import base64
import binascii
import hmac
import logging

from aiohttp import (
    ClientHandlerType,
    ClientRequest,
    ClientResponse,
    ClientSession,
    hdrs,
    web,
)
# Logger used by the demo client code in this module.
_LOGGER = logging.getLogger(__name__)
# Configure the root logger at DEBUG level so the demo's log lines are shown.
logging.basicConfig(level=logging.DEBUG)
class BasicAuthMiddleware:
    """Client middleware that adds an HTTP Basic ``Authorization`` header.

    The header value is ``Basic <base64(username:password)>``, with the
    credentials UTF-8 encoded before base64 encoding. It is computed once in
    ``__init__``, so later changes to ``username``/``password`` do not affect
    the header that is sent. Requests that already carry an ``Authorization``
    header are forwarded unchanged.
    """

    def __init__(self, username: str, password: str) -> None:
        """Store the credentials and pre-compute the header value.

        Args:
            username: Basic auth user-id. Must not contain ``":"``.
            password: Basic auth password (may contain ``":"``).

        Raises:
            ValueError: if ``username`` contains a colon. The first colon in
                the decoded ``user:password`` string is the separator, so such
                a user-id would be split incorrectly by the server.
        """
        if ":" in username:
            raise ValueError('Basic auth username must not contain ":"')
        self.username = username
        self.password = password
        self._auth_header = self._encode_credentials()

    def _encode_credentials(self) -> str:
        """Return the ``Basic <base64(username:password)>`` header value."""
        credentials = f"{self.username}:{self.password}"
        encoded = base64.b64encode(credentials.encode("utf-8")).decode("ascii")
        return f"Basic {encoded}"

    async def __call__(
        self,
        request: ClientRequest,
        handler: ClientHandlerType,
    ) -> ClientResponse:
        """Attach the Basic header (unless one is already set) and forward.

        Args:
            request: The outgoing client request; its headers may be mutated.
            handler: The next handler in the middleware chain.

        Returns:
            Whatever ``handler`` returns for the request.
        """
        # An Authorization header set explicitly on the request wins.
        if hdrs.AUTHORIZATION not in request.headers:
            request.headers[hdrs.AUTHORIZATION] = self._auth_header
        return await handler(request)
class TestServer:
    """Test server for basic auth endpoints.

    Both handlers match the ``Basic`` scheme name case-insensitively, as
    required for HTTP auth-scheme tokens. On failure they reply with ``401``
    and a ``WWW-Authenticate`` Basic challenge.
    """

    @staticmethod
    def _unauthorized(text: str, realm: str) -> web.Response:
        """Build a 401 response with *text* and a Basic challenge for *realm*."""
        return web.Response(
            status=401,
            text=text,
            headers={hdrs.WWW_AUTHENTICATE: f'Basic realm="{realm}"'},
        )

    @staticmethod
    def _basic_token(auth_header: str) -> "str | None":
        """Return the credentials token of a Basic header, or None otherwise.

        None means the header is empty or uses a scheme other than Basic.
        """
        scheme, sep, token = auth_header.partition(" ")
        if not sep or scheme.lower() != "basic":
            return None
        return token.strip()

    @staticmethod
    def _decode_credentials(token: str) -> tuple[str, str]:
        """Decode a Basic token into ``(username, password)``.

        Raises:
            ValueError: if the token is not strict base64 (``binascii.Error``
                is a ``ValueError`` subclass), is not valid UTF-8, or has no
                ``":"`` separator.
        """
        # validate=True rejects non-alphabet characters instead of silently
        # discarding them, so malformed headers are reported as such.
        decoded = base64.b64decode(token, validate=True).decode("utf-8")
        username, sep, password = decoded.partition(":")
        if not sep:
            raise ValueError("Basic credentials lack a ':' separator")
        return username, password

    @staticmethod
    def _same(given: str, expected: str) -> bool:
        """Compare two strings in constant time.

        Both strings are encoded to UTF-8 first because ``hmac.compare_digest``
        rejects non-ASCII ``str`` arguments.
        """
        return hmac.compare_digest(given.encode("utf-8"), expected.encode("utf-8"))

    async def handle_basic_auth(self, request: web.Request) -> web.Response:
        """Check the request's Basic credentials against the URL path.

        The expected user-id and password come from the ``{user}`` and
        ``{pass}`` route segments. Returns 200 with a JSON body on a match,
        otherwise 401.
        """
        expected_user = request.match_info["user"]
        expected_pass = request.match_info["pass"]

        token = self._basic_token(request.headers.get(hdrs.AUTHORIZATION, ""))
        if token is None:
            return self._unauthorized("Unauthorized", "test")

        try:
            username, password = self._decode_credentials(token)
        except (ValueError, binascii.Error):
            return self._unauthorized("Invalid credentials format", "test")

        # Evaluate both comparisons (no short-circuit) so the response time
        # does not reveal which of the two fields mismatched.
        user_ok = self._same(username, expected_user)
        pass_ok = self._same(password, expected_pass)
        if not (user_ok and pass_ok):
            return self._unauthorized("Invalid username or password", "test")

        return web.json_response({"authenticated": True, "user": username})

    async def handle_protected_resource(self, request: web.Request) -> web.Response:
        """Serve a resource that accepts any well-formed Basic credentials.

        No particular user-id/password is required. The header must use the
        Basic scheme and decode to a ``user:password`` pair, otherwise 401.
        """
        token = self._basic_token(request.headers.get(hdrs.AUTHORIZATION, ""))
        if token is None:
            return self._unauthorized("Authentication required", "protected")

        try:
            self._decode_credentials(token)
        except (ValueError, binascii.Error):
            return self._unauthorized("Invalid credentials format", "protected")

        return web.json_response(
            {
                "message": "Access granted to protected resource",
                "auth_provided": True,
            }
        )
async def run_test_server(host: str = "localhost", port: int = 8080) -> web.AppRunner:
    """Start the demo basic-auth server and return its runner.

    Routes:
        ``GET /basic-auth/{user}/{pass}`` and ``GET /protected``.

    Args:
        host: Interface to bind. Defaults to ``"localhost"``.
        port: TCP port to bind. Defaults to ``8080``.

    Returns:
        The set-up ``web.AppRunner``. The caller must
        ``await runner.cleanup()`` to stop the server.

    Raises:
        Whatever binding the site raises (presumably ``OSError`` when the
        port is already in use). The runner is cleaned up before the error
        propagates.
    """
    app = web.Application()
    server = TestServer()
    app.router.add_get("/basic-auth/{user}/{pass}", server.handle_basic_auth)
    app.router.add_get("/protected", server.handle_protected_resource)

    runner = web.AppRunner(app)
    await runner.setup()
    try:
        site = web.TCPSite(runner, host, port)
        await site.start()
    except BaseException:
        # setup() already acquired resources; do not leak them if the site
        # fails to start.
        await runner.cleanup()
        raise
    return runner
async def run_tests(base_url: str = "http://localhost:8080") -> None:
    """Run all basic auth middleware checks against a running test server.

    Every request goes through a session configured with
    ``BasicAuthMiddleware("user", "pass")``, so none of the calls below set
    an ``Authorization`` header themselves. Results are printed to stdout.

    Args:
        base_url: Scheme, host and port of the server under test. A trailing
            slash is tolerated. Defaults to ``"http://localhost:8080"``.
    """
    base = base_url.rstrip("/")
    auth_middleware = BasicAuthMiddleware("user", "pass")

    async with ClientSession(middlewares=(auth_middleware,)) as session:
        # Test 1: the path names the same credentials the middleware sends.
        print("=== Test 1: Correct credentials ===")
        async with session.get(f"{base}/basic-auth/user/pass") as resp:
            _LOGGER.info("Status: %s", resp.status)
            if resp.status == 200:
                data = await resp.json()
                _LOGGER.info("Response: %s", data)
                print("Authentication successful!")
                print(f"Authenticated: {data.get('authenticated')}")
                print(f"User: {data.get('user')}")
            else:
                print("Authentication failed!")
                print(f"Status: {resp.status}")
                text = await resp.text()
                print(f"Response: {text}")

        # Test 2: the path expects different credentials, so 401 is expected.
        print("\n=== Test 2: Wrong credentials endpoint ===")
        async with session.get(f"{base}/basic-auth/other/secret") as resp:
            if resp.status == 401:
                print("Authentication failed as expected (wrong credentials)")
                text = await resp.text()
                print(f"Response: {text}")
            else:
                print(f"Unexpected status: {resp.status}")

        # Test 3: an endpoint that needs authentication but no specific user.
        print("\n=== Test 3: Access protected resource ===")
        async with session.get(f"{base}/protected") as resp:
            if resp.status == 200:
                data = await resp.json()
                print("Successfully accessed protected resource!")
                print(f"Response: {data}")
            else:
                print(f"Failed to access protected resource: {resp.status}")
async def main() -> None:
    """Start the demo server, run the client checks, and always shut down."""
    runner = await run_test_server()
    try:
        await run_tests()
    finally:
        # Stop the server even when one of the checks raised.
        await runner.cleanup()


if __name__ == "__main__":
    # Script entry point: drive the async demo on a fresh event loop.
    asyncio.run(main())