# File info: Python, 102 lines, 3.4 KiB
import hashlib
import logging
import os

import redis
from fastapi import Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address


logger = logging.getLogger(__name__)

# --- Rate limiting configuration --------------------------------------------
# Rate strings use the "<count>/<period>" form; each applies per client IP.
DEFAULT_RATE_LIMIT = "10/minute"  # general endpoints: 10 requests per minute
WEBHOOK_RATE_LIMIT = "60/minute"  # webhook endpoints: 60 requests per minute
BURST_RATE_LIMIT = "3/second"  # burst cap: at most 3 requests per second

# Optional Redis URL for distributed (multi-instance) rate limiting.
# Evaluates to None when the environment variable is not set.
REDIS_URL = os.environ.get("REDIS_URL")


def get_remote_address_with_forwarded(request: Request):
    """
    Get the client IP address, considering forwarded headers from proxies/load
    balancers.

    Resolution order:
      1. The first non-empty entry of ``X-Forwarded-For``.
      2. ``X-Real-IP``, if it is non-empty after stripping whitespace.
      3. The direct connection address from ``get_remote_address``.

    Blank or malformed header values (e.g. ``X-Forwarded-For: ", 1.2.3.4"``)
    fall through to the next source. They no longer yield an empty string,
    which would put unrelated clients into one shared rate-limit bucket.

    SECURITY NOTE(review): both headers are client-controlled unless a trusted
    proxy overwrites them. If the app is reachable without such a proxy, a
    client can send arbitrary values here and evade per-IP rate limits.
    """
    # Forwarded headers are common in production behind proxies; the first
    # entry in the chain is the originating client.
    forwarded_for = request.headers.get("X-Forwarded-For", "")
    first_hop = next(
        (part.strip() for part in forwarded_for.split(",") if part.strip()),
        None,
    )
    if first_hop:
        return first_hop

    real_ip = request.headers.get("X-Real-IP", "").strip()
    if real_ip:
        return real_ip

    # Fallback to the direct connection IP.
    return get_remote_address(request)


# Initialize the application-wide limiter, keyed by client IP.
if REDIS_URL:
    # Use Redis for distributed rate limiting (recommended for production).
    try:
        redis_client = redis.from_url(REDIS_URL)
        # from_url() only builds a client; it does not open a connection.
        # ping() forces a round-trip, so an unreachable server lands in the
        # in-memory fallback below instead of surfacing later, when limits are
        # checked.
        redis_client.ping()
        limiter = Limiter(
            key_func=get_remote_address_with_forwarded, storage_uri=REDIS_URL
        )
        logger.info("Rate limiting initialized with Redis backend")
    except Exception as e:
        # Best-effort startup: any Redis/limiter setup failure degrades to
        # per-process counters rather than preventing the app from starting.
        logger.warning(
            "Failed to connect to Redis: %s. Using in-memory rate limiting.", e
        )
        limiter = Limiter(key_func=get_remote_address_with_forwarded)
else:
    # Use in-memory rate limiting (fine for a single instance).
    limiter = Limiter(key_func=get_remote_address_with_forwarded)
    logger.info("Rate limiting initialized with in-memory backend")


def get_api_key_identifier(request: Request) -> str:
    """
    Get the rate-limit identifier for a request: per API key when one is
    supplied, otherwise per client IP.

    An ``Authorization: Bearer <key>`` header (scheme matched
    case-insensitively) yields ``"api_key:<digest>"``. ``<digest>`` is the
    first 16 hex characters of the SHA-256 of the *full* key. This has two
    effects:
      * distinct keys that share a common prefix never share a bucket;
      * raw key material stays out of storage and logs.

    Requests without a usable bearer key get ``"ip:<client ip>"``. That covers
    a missing header, a different scheme, or an empty key.

    SECURITY NOTE(review): the key is not validated here. A client could send a
    fresh random bearer value per request to get a fresh bucket and sidestep
    IP-based limiting. Consider keying by API key only after authentication.
    """
    auth_header = request.headers.get("Authorization", "")
    scheme, _, credentials = auth_header.strip().partition(" ")
    api_key = credentials.strip()

    if scheme.lower() == "bearer" and api_key:
        digest = hashlib.sha256(api_key.encode("utf-8")).hexdigest()[:16]
        return f"api_key:{digest}"

    # Fallback to IP address.
    return f"ip:{get_remote_address_with_forwarded(request)}"


# Key function for API-key-based limiting: delegates to get_api_key_identifier,
# so buckets are per API key, or per client IP when no bearer key is present.
def api_key_rate_limit_key(request: Request):
    """Return the rate-limit bucket key for *request*."""
    bucket_key = get_api_key_identifier(request)
    return bucket_key


# Rate limiting for webhook endpoints, keyed by API key (falling back to IP).
# Redis is used only when REDIS_URL is set AND the server answers a PING at
# import time; otherwise counters stay in process memory. This matches the
# in-memory fallback intended for the main limiter. redis.from_url() alone does
# not connect, so without the ping an unreachable server would presumably only
# surface once a limit is actually checked.
_webhook_storage_uri = None
if REDIS_URL:
    try:
        redis.from_url(REDIS_URL).ping()
        _webhook_storage_uri = REDIS_URL
    except Exception as e:
        logger.warning(
            "Webhook limiter could not reach Redis: %s. "
            "Using in-memory rate limiting.",
            e,
        )

webhook_limiter = Limiter(
    key_func=api_key_rate_limit_key, storage_uri=_webhook_storage_uri
)


# Custom rate limit exceeded handler
def custom_rate_limit_handler(request: Request, exc: RateLimitExceeded):
    """
    Handle a rate limit being exceeded.

    Logs the offending client, builds slowapi's standard 429 response, and adds
    ``X-RateLimit-Limit`` and ``X-RateLimit-Retry-After`` headers.

    slowapi's ``RateLimitExceeded`` carries the breached limit in ``exc.limit``
    and has no ``retry_after`` attribute. Both header values are therefore
    derived from the underlying rate-limit item, ``exc.limit.limit``:
      * X-RateLimit-Limit: the number of requests allowed per window.
      * X-RateLimit-Retry-After: the ``Retry-After`` header already on the
        response, if any. Otherwise the window length in seconds, which is an
        upper bound on the wait.

    If the limit details are unavailable, the extra headers are omitted rather
    than letting an AttributeError turn the 429 into a 500.
    """
    logger.warning(
        "Rate limit exceeded for %s: %s",
        get_remote_address_with_forwarded(request),
        exc.detail,
    )

    response = _rate_limit_exceeded_handler(request, exc)

    limit_item = getattr(getattr(exc, "limit", None), "limit", None)
    if limit_item is not None:
        response.headers["X-RateLimit-Limit"] = str(limit_item.amount)
        retry_after = response.headers.get("Retry-After") or str(
            limit_item.get_expiry()
        )
        response.headers["X-RateLimit-Retry-After"] = retry_after

    return response