Files
alpinebits_python/src/alpine_bits_python/rate_limit.py
2025-09-29 13:56:34 +02:00

102 lines
3.4 KiB
Python

import hashlib
import logging
import os

import redis
from fastapi import Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Rate limit presets, expressed as "<count>/<period>" strings.
DEFAULT_RATE_LIMIT = "10/minute"  # 10 requests per minute per IP
WEBHOOK_RATE_LIMIT = "60/minute"  # 60 webhook requests per minute per IP
BURST_RATE_LIMIT = "3/second"  # at most 3 requests per second per IP

# Optional Redis URL for distributed rate limiting; None when the
# environment variable is not set.
REDIS_URL = os.environ.get("REDIS_URL")
def get_remote_address_with_forwarded(request: Request) -> str:
    """Return the client IP, honouring proxy/load-balancer forwarding headers.

    Resolution order:

    1. The first non-empty entry of ``X-Forwarded-For``.
    2. ``X-Real-IP`` (whitespace-stripped), if non-empty.
    3. The direct connection address via ``get_remote_address``.

    NOTE(security): both headers are supplied by the client unless a trusted
    proxy overwrites them, so a client that can reach the application
    directly can choose its own rate-limit key. Confirm that the deployment
    only accepts traffic through a proxy that sets these headers.
    """
    forwarded_for = request.headers.get("X-Forwarded-For")
    if forwarded_for:
        # Take the first usable hop in the chain. Blank entries such as the
        # leading one in ", 10.0.0.1" are skipped so that an empty string is
        # never returned as the rate-limit key.
        for hop in forwarded_for.split(","):
            candidate = hop.strip()
            if candidate:
                return candidate

    real_ip = (request.headers.get("X-Real-IP") or "").strip()
    if real_ip:
        return real_ip

    # Fallback to the direct connection IP.
    return get_remote_address(request)
# Initialize the default limiter, keyed by client IP.
if REDIS_URL:
    # Use Redis for distributed rate limiting (recommended for production).
    try:
        # redis.from_url() only builds a client; it does not open a
        # connection. Issue an explicit PING so an unreachable server is
        # detected here and the in-memory fallback below can actually run.
        # A connect timeout can be supplied through the URL query string
        # (e.g. "?socket_connect_timeout=5") if start-up must not block.
        redis_client = redis.from_url(REDIS_URL)
        try:
            redis_client.ping()
        finally:
            # The probe connection is not needed afterwards; the limiter
            # opens its own connection from storage_uri.
            redis_client.close()
        limiter = Limiter(
            key_func=get_remote_address_with_forwarded, storage_uri=REDIS_URL
        )
        logger.info("Rate limiting initialized with Redis backend")
    except Exception as e:  # start-up boundary: degrade instead of crashing
        logger.warning(
            "Failed to connect to Redis: %s. Using in-memory rate limiting.", e
        )
        limiter = Limiter(key_func=get_remote_address_with_forwarded)
else:
    # Use in-memory rate limiting (fine for a single instance).
    limiter = Limiter(key_func=get_remote_address_with_forwarded)
    logger.info("Rate limiting initialized with in-memory backend")
def get_api_key_identifier(request: Request) -> str:
    """Return a rate-limit identifier based on the API key, else the client IP.

    When the request carries a non-empty ``Authorization: Bearer <key>``
    header, the identifier is ``"api_key:"`` followed by the first 16 hex
    characters of the key's SHA-256 digest. Hashing keeps the raw key (and
    any prefix of it) out of rate-limit storage and logs, and avoids putting
    different keys that share a common prefix into the same bucket.

    Otherwise the identifier is ``"ip:"`` followed by the address returned by
    ``get_remote_address_with_forwarded``.
    """
    auth_header = request.headers.get("Authorization")
    if auth_header:
        # The auth scheme name is matched case-insensitively.
        scheme, _, credentials = auth_header.partition(" ")
        api_key = credentials.strip()
        # An empty token must not map every such request to one shared
        # "api_key:" bucket; treat it as unauthenticated instead.
        if scheme.lower() == "bearer" and api_key:
            digest = hashlib.sha256(api_key.encode("utf-8")).hexdigest()
            return f"api_key:{digest[:16]}"

    # Fallback to IP address.
    return f"ip:{get_remote_address_with_forwarded(request)}"
# Key function for API-key based rate limiting
def api_key_rate_limit_key(request: Request):
    """Return the rate-limit bucket key for ``request``.

    Thin adapter that delegates to ``get_api_key_identifier`` so it can be
    passed as a limiter ``key_func``.
    """
    bucket_key = get_api_key_identifier(request)
    return bucket_key
# Separate limiter instance keyed by API key (falling back to client IP).
# It uses Redis storage whenever REDIS_URL is set, otherwise in-memory.
# NOTE(review): unlike `limiter`, there is no in-memory fallback here if
# Redis turns out to be unreachable - confirm whether that is intended.
webhook_limiter = Limiter(
    key_func=api_key_rate_limit_key,
    storage_uri=REDIS_URL or None,
)
# Custom rate limit exceeded handler
def custom_rate_limit_handler(request: Request, exc: RateLimitExceeded):
    """Log a rate-limit violation and return the 429 response with extra headers.

    The response itself is built by slowapi's ``_rate_limit_exceeded_handler``.
    On top of that this handler sets:

    * ``X-RateLimit-Limit`` - the numeric quota of the limit that was hit.
    * ``X-RateLimit-Retry-After`` - when the client may retry.

    Each header is only added when its value can be determined.

    Args:
        request: The incoming request that exceeded its limit.
        exc: The ``RateLimitExceeded`` raised by slowapi.

    Returns:
        The response produced by slowapi's default handler, with the extra
        headers applied.
    """
    logger.warning(
        "Rate limit exceeded for %s: %s",
        get_remote_address_with_forwarded(request),
        exc.detail,
    )
    response = _rate_limit_exceeded_handler(request, exc)

    # The quota lives on the limit that was hit, not on retry timing. Reading
    # it defensively keeps the handler from failing while it builds a 429.
    # NOTE(review): assumes exc.limit.limit is the `limits` rate-limit item
    # exposing `.amount` - confirm against the installed slowapi version.
    limit_item = getattr(getattr(exc, "limit", None), "limit", None)
    amount = getattr(limit_item, "amount", None)
    if amount is not None:
        response.headers["X-RateLimit-Limit"] = str(amount)

    # NOTE(review): RateLimitExceeded does not appear to define
    # `retry_after`; prefer it if present, otherwise reuse the Retry-After
    # header that the default handler may already have injected.
    retry_after = getattr(exc, "retry_after", None)
    if retry_after is None:
        retry_after = response.headers.get("Retry-After")
    if retry_after is not None:
        response.headers["X-RateLimit-Retry-After"] = str(retry_after)

    return response