Got db saving working

This commit is contained in:
Jonas Linter
2025-09-29 13:56:34 +02:00
parent 384fb2b558
commit 06739ebea9
21 changed files with 1188 additions and 830 deletions

View File

@@ -1,18 +1,29 @@
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request, Depends, APIRouter, Form, File, UploadFile
from fastapi import (
FastAPI,
HTTPException,
BackgroundTasks,
Request,
Depends,
APIRouter,
Form,
File,
UploadFile,
)
from fastapi.concurrency import asynccontextmanager
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPBasicCredentials, HTTPBasic
from .config_loader import load_config
from fastapi.responses import HTMLResponse, PlainTextResponse, Response
from .models import WixFormSubmission
from datetime import datetime, date, timezone
from datetime import datetime, date, timezone
from .auth import validate_api_key, validate_wix_signature, generate_api_key
from .rate_limit import (
limiter,
webhook_limiter,
limiter,
webhook_limiter,
custom_rate_limit_handler,
DEFAULT_RATE_LIMIT,
WEBHOOK_RATE_LIMIT,
BURST_RATE_LIMIT
BURST_RATE_LIMIT,
)
from slowapi.errors import RateLimitExceeded
import logging
@@ -24,8 +35,14 @@ import gzip
import xml.etree.ElementTree as ET
from .alpinebits_server import AlpineBitsServer, Version
import urllib.parse
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from .db import get_async_session, Customer as DBCustomer, Reservation as DBReservation
from .db import (
Base,
Customer as DBCustomer,
Reservation as DBReservation,
get_database_url,
)
# Configure logging
@@ -42,12 +59,36 @@ except Exception as e:
_LOGGER.error(f"Failed to load config: {str(e)}")
config = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
# Setup DB
DATABASE_URL = get_database_url(config)
engine = create_async_engine(DATABASE_URL, echo=True)
AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)
app.state.engine = engine
app.state.async_sessionmaker = AsyncSessionLocal
# Create tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
_LOGGER.info("Database tables checked/created at startup.")
yield
# Optional: Dispose engine on shutdown
await engine.dispose()
async def get_async_session(request: Request):
async_sessionmaker = request.app.state.async_sessionmaker
async with async_sessionmaker() as session:
yield session
app = FastAPI(
title="Wix Form Handler API",
description="Secure API endpoint to receive and process Wix form submissions with authentication and rate limiting",
version="1.0.0"
version="1.0.0",
lifespan=lifespan
)
# Create API router with /api prefix
@@ -62,9 +103,9 @@ app.add_middleware(
CORSMiddleware,
allow_origins=[
"https://*.wix.com",
"https://*.wixstatic.com",
"https://*.wixstatic.com",
"http://localhost:3000", # For development
"http://localhost:8000" # For local testing
"http://localhost:8000", # For local testing
],
allow_credentials=True,
allow_methods=["GET", "POST"],
@@ -78,27 +119,39 @@ async def process_form_submission(submission_data: Dict[str, Any]) -> None:
Add your business logic here.
"""
try:
_LOGGER.info(f"Processing form submission: {submission_data.get('submissionId')}")
_LOGGER.info(
f"Processing form submission: {submission_data.get('submissionId')}"
)
# Example processing - you can replace this with your actual logic
form_name = submission_data.get('formName')
contact_email = submission_data.get('contact', {}).get('email') if submission_data.get('contact') else None
form_name = submission_data.get("formName")
contact_email = (
submission_data.get("contact", {}).get("email")
if submission_data.get("contact")
else None
)
# Extract form fields
form_fields = {k: v for k, v in submission_data.items() if k.startswith('field:')}
_LOGGER.info(f"Form: {form_name}, Contact: {contact_email}, Fields: {len(form_fields)}")
form_fields = {
k: v for k, v in submission_data.items() if k.startswith("field:")
}
_LOGGER.info(
f"Form: {form_name}, Contact: {contact_email}, Fields: {len(form_fields)}"
)
# Here you could:
# - Save to database
# - Send emails
# - Call external APIs
# - Process the data further
except Exception as e:
_LOGGER.error(f"Error processing form submission: {str(e)}")
@api_router.get("/")
@limiter.limit(DEFAULT_RATE_LIMIT)
async def root(request: Request):
@@ -111,8 +164,8 @@ async def root(request: Request):
"rate_limits": {
"default": DEFAULT_RATE_LIMIT,
"webhook": WEBHOOK_RATE_LIMIT,
"burst": BURST_RATE_LIMIT
}
"burst": BURST_RATE_LIMIT,
},
}
@@ -126,11 +179,10 @@ async def health_check(request: Request):
"service": "wix-form-handler",
"version": "1.0.0",
"authentication": "enabled",
"rate_limiting": "enabled"
"rate_limiting": "enabled",
}
# Extracted business logic for handling Wix form submissions
async def process_wix_form_submission(request: Request, data: Dict[str, Any], db):
"""
@@ -138,10 +190,9 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
"""
timestamp = datetime.now().isoformat()
_LOGGER.info(f"Received Wix form data at {timestamp}")
#_LOGGER.info(f"Data keys: {list(data.keys())}")
#_LOGGER.info(f"Full data: {json.dumps(data, indent=2)}")
# _LOGGER.info(f"Data keys: {list(data.keys())}")
# _LOGGER.info(f"Full data: {json.dumps(data, indent=2)}")
log_entry = {
"timestamp": timestamp,
"client_ip": request.client.host if request.client else "unknown",
@@ -154,9 +205,13 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
if not os.path.exists(logs_dir):
os.makedirs(logs_dir, mode=0o755, exist_ok=True)
stat_info = os.stat(logs_dir)
_LOGGER.info(f"Created directory owner: uid:{stat_info.st_uid}, gid:{stat_info.st_gid}")
_LOGGER.info(
f"Created directory owner: uid:{stat_info.st_uid}, gid:{stat_info.st_gid}"
)
_LOGGER.info(f"Directory mode: {oct(stat_info.st_mode)[-3:]}")
log_filename = f"{logs_dir}/wix_test_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
log_filename = (
f"{logs_dir}/wix_test_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
)
with open(log_filename, "w", encoding="utf-8") as f:
json.dump(log_entry, f, indent=2, default=str, ensure_ascii=False)
file_stat = os.stat(log_filename)
@@ -164,16 +219,10 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
_LOGGER.info(f"File mode: {oct(file_stat.st_mode)[-3:]}")
_LOGGER.info(f"Data logged to: {log_filename}")
data = data.get("data") # Handle nested "data" key if present
# save customer and reservation to DB
contact_info = data.get("contact", {})
first_name = contact_info.get("name", {}).get("first")
last_name = contact_info.get("name", {}).get("last")
@@ -193,10 +242,18 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
language = data.get("contact", {}).get("locale", "en")[:2]
# Dates
start_date = data.get("field:date_picker_a7c8") or data.get("Anreisedatum") or data.get("submissions", [{}])[1].get("value")
end_date = data.get("field:date_picker_7e65") or data.get("Abreisedatum") or data.get("submissions", [{}])[2].get("value")
start_date = (
data.get("field:date_picker_a7c8")
or data.get("Anreisedatum")
or data.get("submissions", [{}])[1].get("value")
)
end_date = (
data.get("field:date_picker_7e65")
or data.get("Abreisedatum")
or data.get("submissions", [{}])[2].get("value")
)
# Room/guest info
# Room/guest info
num_adults = int(data.get("field:number_7cf5") or 2)
num_children = int(data.get("field:anzahl_kinder") or 0)
children_ages = []
@@ -258,7 +315,7 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
end_date=date.fromisoformat(end_date) if end_date else None,
num_adults=num_adults,
num_children=num_children,
children_ages=','.join(str(a) for a in children_ages),
children_ages=",".join(str(a) for a in children_ages),
offer=offer,
utm_comment=utm_comment,
created_at=datetime.now(timezone.utc),
@@ -277,23 +334,21 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
await db.commit()
await db.refresh(db_reservation)
return {
"status": "success",
"message": "Wix form data received successfully",
"received_keys": list(data.keys()),
"data_logged_to": log_filename,
"timestamp": timestamp,
"process_info": log_entry["process_info"],
"note": "No authentication required for this endpoint"
"note": "No authentication required for this endpoint",
}
@api_router.post("/webhook/wix-form")
@webhook_limiter.limit(WEBHOOK_RATE_LIMIT)
async def handle_wix_form(request: Request, data: Dict[str, Any], db_session=Depends(get_async_session)):
async def handle_wix_form(
request: Request, data: Dict[str, Any], db_session=Depends(get_async_session)
):
"""
Unified endpoint to handle Wix form submissions (test and production).
No authentication required for this endpoint.
@@ -304,16 +359,19 @@ async def handle_wix_form(request: Request, data: Dict[str, Any], db_session=Dep
_LOGGER.error(f"Error in handle_wix_form: {str(e)}")
# log stacktrace
import traceback
traceback_str = traceback.format_exc()
_LOGGER.error(f"Stack trace for handle_wix_form: {traceback_str}")
raise HTTPException(
status_code=500,
detail=f"Error processing Wix form data: {str(e)}"
status_code=500, detail=f"Error processing Wix form data: {str(e)}"
)
@api_router.post("/webhook/wix-form/test")
@limiter.limit(DEFAULT_RATE_LIMIT)
async def handle_wix_form_test(request: Request, data: Dict[str, Any],db_session=Depends(get_async_session)):
async def handle_wix_form_test(
request: Request, data: Dict[str, Any], db_session=Depends(get_async_session)
):
"""
Test endpoint to verify the API is working with raw JSON data.
No authentication required for testing purposes.
@@ -323,40 +381,37 @@ async def handle_wix_form_test(request: Request, data: Dict[str, Any],db_session
except Exception as e:
_LOGGER.error(f"Error in handle_wix_form_test: {str(e)}")
raise HTTPException(
status_code=500,
detail=f"Error processing test data: {str(e)}"
status_code=500, detail=f"Error processing test data: {str(e)}"
)
@api_router.post("/admin/generate-api-key")
@limiter.limit("5/hour") # Very restrictive for admin operations
async def generate_new_api_key(
request: Request,
admin_key: str = Depends(validate_api_key)
request: Request, admin_key: str = Depends(validate_api_key)
):
"""
Admin endpoint to generate new API keys.
Requires admin API key and is heavily rate limited.
"""
if admin_key != "admin-key":
raise HTTPException(
status_code=403,
detail="Admin access required"
)
raise HTTPException(status_code=403, detail="Admin access required")
new_key = generate_api_key()
_LOGGER.info(f"Generated new API key (requested by: {admin_key})")
return {
"status": "success",
"message": "New API key generated",
"api_key": new_key,
"timestamp": datetime.now().isoformat(),
"note": "Store this key securely - it won't be shown again"
"note": "Store this key securely - it won't be shown again",
}
async def validate_basic_auth(credentials: HTTPBasicCredentials = Depends(security_basic)) -> str:
async def validate_basic_auth(
credentials: HTTPBasicCredentials = Depends(security_basic),
) -> str:
"""
Validate basic authentication for AlpineBits protocol.
Returns username if valid, raises HTTPException if not.
@@ -369,8 +424,11 @@ async def validate_basic_auth(credentials: HTTPBasicCredentials = Depends(securi
headers={"WWW-Authenticate": "Basic"},
)
valid = False
for entry in config['alpine_bits_auth']:
if credentials.username == entry['username'] and credentials.password == entry['password']:
for entry in config["alpine_bits_auth"]:
if (
credentials.username == entry["username"]
and credentials.password == entry["password"]
):
valid = True
break
if not valid:
@@ -379,7 +437,9 @@ async def validate_basic_auth(credentials: HTTPBasicCredentials = Depends(securi
detail="ERROR: Invalid credentials",
headers={"WWW-Authenticate": "Basic"},
)
_LOGGER.info(f"AlpineBits authentication successful for user: {credentials.username} (from config)")
_LOGGER.info(
f"AlpineBits authentication successful for user: {credentials.username} (from config)"
)
return credentials.username
@@ -390,10 +450,9 @@ def parse_multipart_data(content_type: str, body: bytes) -> Dict[str, Any]:
"""
if "multipart/form-data" not in content_type:
raise HTTPException(
status_code=400,
detail="ERROR: Content-Type must be multipart/form-data"
status_code=400, detail="ERROR: Content-Type must be multipart/form-data"
)
# Extract boundary
boundary = None
for part in content_type.split(";"):
@@ -401,62 +460,56 @@ def parse_multipart_data(content_type: str, body: bytes) -> Dict[str, Any]:
if part.startswith("boundary="):
boundary = part.split("=", 1)[1].strip('"')
break
if not boundary:
raise HTTPException(
status_code=400,
detail="ERROR: Missing boundary in multipart/form-data"
status_code=400, detail="ERROR: Missing boundary in multipart/form-data"
)
# Simple multipart parsing
parts = body.split(f"--{boundary}".encode())
data = {}
for part in parts:
if not part.strip() or part.strip() == b"--":
continue
# Split headers and content
if b"\r\n\r\n" in part:
headers_section, content = part.split(b"\r\n\r\n", 1)
content = content.rstrip(b"\r\n")
# Parse Content-Disposition header
headers = headers_section.decode('utf-8', errors='ignore')
headers = headers_section.decode("utf-8", errors="ignore")
name = None
for line in headers.split('\n'):
if 'Content-Disposition' in line and 'name=' in line:
for line in headers.split("\n"):
if "Content-Disposition" in line and "name=" in line:
# Extract name parameter
for param in line.split(';'):
for param in line.split(";"):
param = param.strip()
if param.startswith('name='):
name = param.split('=', 1)[1].strip('"')
if param.startswith("name="):
name = param.split("=", 1)[1].strip('"')
break
if name:
# Handle file uploads or text content
if content.startswith(b'<'):
if content.startswith(b"<"):
# Likely XML content
data[name] = content.decode('utf-8', errors='ignore')
data[name] = content.decode("utf-8", errors="ignore")
else:
data[name] = content.decode('utf-8', errors='ignore')
data[name] = content.decode("utf-8", errors="ignore")
return data
@api_router.post("/alpinebits/server-2024-10")
@limiter.limit("60/minute")
async def alpinebits_server_handshake(
request: Request,
username: str = Depends(validate_basic_auth)
request: Request, username: str = Depends(validate_basic_auth)
):
"""
AlpineBits server endpoint implementing the handshake protocol.
This endpoint handles:
- Protocol version negotiation via X-AlpineBits-ClientProtocolVersion header
- Client identification via X-AlpineBits-ClientID header (optional)
@@ -464,62 +517,67 @@ async def alpinebits_server_handshake(
- Gzip compression support
- Proper error handling with HTTP status codes
- Handshaking action processing
Authentication: HTTP Basic Auth required
Content-Type: multipart/form-data
Compression: gzip supported (check X-AlpineBits-Server-Accept-Encoding)
"""
try:
# Check required headers
client_protocol_version = request.headers.get("X-AlpineBits-ClientProtocolVersion")
client_protocol_version = request.headers.get(
"X-AlpineBits-ClientProtocolVersion"
)
if not client_protocol_version:
# Server concludes client speaks a protocol version preceding 2013-04
client_protocol_version = "pre-2013-04"
_LOGGER.info("No X-AlpineBits-ClientProtocolVersion header found, assuming pre-2013-04")
_LOGGER.info(
"No X-AlpineBits-ClientProtocolVersion header found, assuming pre-2013-04"
)
else:
_LOGGER.info(f"Client protocol version: {client_protocol_version}")
# Optional client ID
client_id = request.headers.get("X-AlpineBits-ClientID")
if client_id:
_LOGGER.info(f"Client ID: {client_id}")
# Check content encoding
content_encoding = request.headers.get("Content-Encoding")
is_compressed = content_encoding == "gzip"
if is_compressed:
_LOGGER.info("Request is gzip compressed")
# Get content type before processing
content_type = request.headers.get("Content-Type", "")
_LOGGER.info(f"Content-Type: {content_type}")
_LOGGER.info(f"Content-Encoding: {content_encoding}")
# Get request body
body = await request.body()
# Decompress if needed
if is_compressed:
try:
body = gzip.decompress(body)
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"ERROR: Failed to decompress gzip content: {str(e)}"
detail=f"ERROR: Failed to decompress gzip content: {str(e)}",
)
# Check content type (after decompression)
if "multipart/form-data" not in content_type and "application/x-www-form-urlencoded" not in content_type:
if (
"multipart/form-data" not in content_type
and "application/x-www-form-urlencoded" not in content_type
):
raise HTTPException(
status_code=400,
detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded"
detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded",
)
# Parse multipart data
if "multipart/form-data" in content_type:
try:
@@ -527,7 +585,7 @@ async def alpinebits_server_handshake(
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"ERROR: Failed to parse multipart/form-data: {str(e)}"
detail=f"ERROR: Failed to parse multipart/form-data: {str(e)}",
)
elif "application/x-www-form-urlencoded" in content_type:
# Parse as urlencoded
@@ -535,75 +593,59 @@ async def alpinebits_server_handshake(
else:
raise HTTPException(
status_code=400,
detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded"
detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded",
)
# Check for required action parameter
action = form_data.get("action")
if not action:
raise HTTPException(
status_code=400,
detail="ERROR: Missing required 'action' parameter")
status_code=400, detail="ERROR: Missing required 'action' parameter"
)
_LOGGER.info(f"AlpineBits action: {action}")
# Get optional request XML
request_xml = form_data.get("request")
request_xml = form_data.get("request")
server = AlpineBitsServer()
version = Version.V2024_10
# Create successful handshake response
response = await server.handle_request(action, request_xml, version)
response = await server.handle_request(action, request_xml, version)
response_xml = response.xml_content
# Set response headers indicating server capabilities
headers = {
"Content-Type": "application/xml; charset=utf-8",
"X-AlpineBits-Server-Accept-Encoding": "gzip", # Indicate gzip support
"X-AlpineBits-Server-Version": "2024-10"
"X-AlpineBits-Server-Version": "2024-10",
}
return Response(
content=response_xml,
status_code=response.status_code,
headers=headers
)
return Response(
content=response_xml, status_code=response.status_code, headers=headers
)
except HTTPException:
# Re-raise HTTP exceptions (auth errors, etc.)
raise
except Exception as e:
_LOGGER.error(f"Error in AlpineBits handshake: {str(e)}")
raise HTTPException(
status_code=500,
detail=f"Internal server error: {str(e)}"
)
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@api_router.get("/admin/stats")
@limiter.limit("10/minute")
async def get_api_stats(
request: Request,
admin_key: str = Depends(validate_api_key)
):
async def get_api_stats(request: Request, admin_key: str = Depends(validate_api_key)):
"""
Admin endpoint to get API usage statistics.
Requires admin API key.
"""
if admin_key != "admin-key":
raise HTTPException(
status_code=403,
detail="Admin access required"
)
raise HTTPException(status_code=403, detail="Admin access required")
# In a real application, you'd fetch this from your database/monitoring system
return {
"status": "success",
@@ -611,9 +653,9 @@ async def get_api_stats(
"uptime": "Available in production deployment",
"total_requests": "Available with monitoring setup",
"active_api_keys": len([k for k in ["wix-webhook-key", "admin-key"] if k]),
"rate_limit_backend": "redis" if os.getenv("REDIS_URL") else "memory"
"rate_limit_backend": "redis" if os.getenv("REDIS_URL") else "memory",
},
"timestamp": datetime.now().isoformat()
"timestamp": datetime.now().isoformat(),
}
@@ -629,11 +671,12 @@ async def landing_page():
try:
# Get the path to the HTML file
import os
html_path = os.path.join(os.path.dirname(__file__), "templates", "index.html")
with open(html_path, "r", encoding="utf-8") as f:
html_content = f.read()
return HTMLResponse(content=html_content, status_code=200)
except FileNotFoundError:
# Fallback if HTML file is not found
@@ -660,4 +703,5 @@ async def landing_page():
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
uvicorn.run(app, host="0.0.0.0", port=8000)