Source code for litestar_vite.plugin._proxy_headers
"""Proxy headers middleware for handling X-Forwarded-* headers securely.
This module provides middleware to handle X-Forwarded-* headers from reverse proxies
like Railway, Heroku, AWS ALB, nginx, etc.
Security: Headers are only trusted when the direct caller IP is in the configured
trusted_proxies list. This prevents header spoofing attacks.
Related: https://github.com/litestar-org/litestar-vite/issues/167
"""
import ipaddress
from typing import TYPE_CHECKING, Any, cast
from litestar.enums import ScopeType
from litestar.middleware import AbstractMiddleware
if TYPE_CHECKING:
from litestar.types import ASGIApp, Receive, Scope, Send
__all__ = ("ProxyHeadersMiddleware", "TrustedHosts")
[docs]
class TrustedHosts:
"""Container for trusted proxy hosts and networks.
Provides efficient lookup for IP addresses and CIDR networks.
Following Uvicorn's security model for proxy header validation.
Supports:
- Wildcard "*" to trust all hosts (for controlled environments)
- IPv4 addresses: "192.168.1.1"
- IPv6 addresses: "::1"
- CIDR notation: "10.0.0.0/8", "fd00::/8"
- Literals for non-IP hosts (e.g., Unix socket paths)
"""
__slots__ = ("always_trust", "trusted_hosts", "trusted_literals", "trusted_networks")
[docs]
def __init__(self, trusted_hosts: "list[str] | str") -> None:
"""Initialize trusted hosts container.
Args:
trusted_hosts: A single host, comma-separated string, or list of hosts.
Use "*" to trust all hosts (only in controlled environments).
"""
self.always_trust: bool = trusted_hosts in ("*", ["*"])
self.trusted_literals: set[str] = set()
self.trusted_hosts: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
self.trusted_networks: set[ipaddress.IPv4Network | ipaddress.IPv6Network] = set()
if not self.always_trust:
hosts_list: list[str]
if isinstance(trusted_hosts, str):
hosts_list = [h.strip() for h in trusted_hosts.split(",") if h.strip()]
else:
hosts_list = trusted_hosts
for host in hosts_list:
if "/" in host:
# CIDR notation
try:
self.trusted_networks.add(ipaddress.ip_network(host, strict=False))
except ValueError:
# Not a valid network, treat as literal
self.trusted_literals.add(host)
else:
try:
self.trusted_hosts.add(ipaddress.ip_address(host))
except ValueError:
# Not a valid IP, treat as literal (e.g., Unix socket path)
self.trusted_literals.add(host)
[docs]
def __contains__(self, host: "str | None") -> bool:
"""Check if a host is trusted.
Args:
host: The host to check. Can be an IP address or literal.
Returns:
True if the host is trusted, False otherwise.
"""
# None and empty string are never trusted
if not host:
return False
if self.always_trust:
return True
try:
ip = ipaddress.ip_address(host)
if ip in self.trusted_hosts:
return True
return any(ip in net for net in self.trusted_networks)
except ValueError:
return host in self.trusted_literals
[docs]
def get_trusted_client_host(self, x_forwarded_for: str) -> str:
"""Extract the real client IP from X-Forwarded-For header.
The X-Forwarded-For header contains a comma-separated list of IPs.
Each proxy appends the client IP to the list. We find the first
untrusted host (reading from right to left) which is the real client.
Args:
x_forwarded_for: The X-Forwarded-For header value.
Returns:
The first untrusted host in the chain, or the original client
if all hosts are trusted.
"""
hosts = [h.strip() for h in x_forwarded_for.split(",") if h.strip()]
if not hosts:
return ""
if self.always_trust:
# When trusting all, return the leftmost (original client)
return hosts[0]
# Each proxy appends to the list, so check in reverse
# Find the first untrusted host from the right
for host in reversed(hosts):
if host not in self:
return host
# All hosts are trusted - return the original client
return hosts[0]
[docs]
class ProxyHeadersMiddleware(AbstractMiddleware):
"""ASGI middleware for secure proxy header handling.
Only processes X-Forwarded-* headers when the direct caller (scope["client"])
is in the trusted hosts list. This prevents header spoofing attacks.
Handles:
- X-Forwarded-Proto: Sets scope["scheme"] (http/https/ws/wss)
- X-Forwarded-For: Sets scope["client"] to the real client IP
- X-Forwarded-Host: Optionally sets the Host header
Security:
Never blindly trusts headers from any client. Validates caller IP
against trusted hosts before reading headers. Validates scheme values
to only allow http/https/ws/wss.
Example::
from litestar_vite import VitePlugin, ViteConfig
from litestar_vite.config import RuntimeConfig
# Trust all proxies (Railway, Heroku, container environments)
app = Litestar(
plugins=[VitePlugin(config=ViteConfig(
runtime=RuntimeConfig(trusted_proxies="*")
))]
)
# Trust specific proxy IPs
app = Litestar(
plugins=[VitePlugin(config=ViteConfig(
runtime=RuntimeConfig(trusted_proxies=["10.0.0.0/8", "172.16.0.0/12"])
))]
)
"""
scopes = {ScopeType.HTTP, ScopeType.WEBSOCKET}
[docs]
def __init__(
self, app: "ASGIApp", trusted_hosts: "list[str] | str" = "127.0.0.1", handle_forwarded_host: bool = True
) -> None:
"""Initialize the proxy headers middleware.
Args:
app: The ASGI application to wrap.
trusted_hosts: Hosts to trust for X-Forwarded-* headers.
Defaults to "127.0.0.1" (localhost only).
handle_forwarded_host: Whether to handle X-Forwarded-Host header
for Host header rewriting. Defaults to True.
"""
super().__init__(app)
self.trusted_hosts = TrustedHosts(trusted_hosts)
self.handle_forwarded_host = handle_forwarded_host
async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None:
"""Process the request and apply proxy headers if trusted.
Args:
scope: The ASGI scope.
receive: The receive callable.
send: The send callable.
"""
client_addr = scope.get("client") # pyright: ignore[reportUnknownMemberType]
client_host = client_addr[0] if client_addr else None
if client_host in self.trusted_hosts:
# Build a dict of headers for efficient lookup
headers: dict[bytes, bytes] = {}
for key, value in scope.get("headers", []): # pyright: ignore[reportUnknownMemberType]
# Use first occurrence only (as per HTTP spec)
if key not in headers:
headers[key] = value
scope_dict = cast("dict[str, Any]", scope)
# X-Forwarded-Proto -> scope["scheme"]
if b"x-forwarded-proto" in headers:
proto = headers[b"x-forwarded-proto"].decode("latin-1").strip().lower()
if proto in {"http", "https", "ws", "wss"}:
# For WebSocket, ensure ws/wss scheme
if scope["type"] == "websocket":
if proto == "https":
scope_dict["scheme"] = "wss"
elif proto == "http":
scope_dict["scheme"] = "ws"
else:
scope_dict["scheme"] = proto
else:
scope_dict["scheme"] = proto
# X-Forwarded-For -> scope["client"]
if b"x-forwarded-for" in headers:
x_forwarded_for = headers[b"x-forwarded-for"].decode("latin-1")
real_client = self.trusted_hosts.get_trusted_client_host(x_forwarded_for)
if real_client:
scope_dict["client"] = (real_client, 0)
# X-Forwarded-Host -> replace Host header
if self.handle_forwarded_host and b"x-forwarded-host" in headers:
forwarded_host = headers[b"x-forwarded-host"]
# Rebuild headers list with replaced Host
new_headers: list[tuple[bytes, bytes]] = []
host_replaced = False
for key, value in scope.get("headers", []): # pyright: ignore[reportUnknownMemberType]
if key == b"host" and not host_replaced:
new_headers.append((b"host", forwarded_host))
host_replaced = True
else:
new_headers.append((key, value))
# If no Host header existed, add it
if not host_replaced:
new_headers.append((b"host", forwarded_host))
scope_dict["headers"] = new_headers
await self.app(scope, receive, send)