110 lines
2.9 KiB
Python
110 lines
2.9 KiB
Python
"""
|
|
OIDC JWT verification — validates Keycloak tokens.
|
|
Falls back to DEV auth when OIDC is not configured.
|
|
"""
|
|
import logging
|
|
import os
|
|
from functools import lru_cache
|
|
from typing import Optional
|
|
|
|
import httpx
|
|
from jose import jwt, JWTError, jwk
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
OIDC_ISSUER = os.getenv("OIDC_ISSUER", "")
|
|
OIDC_AUDIENCE = os.getenv("OIDC_AUDIENCE", "klg-frontend")
|
|
|
|
_jwks_cache: dict = {}
|
|
|
|
|
|
@lru_cache(maxsize=1)
|
|
def get_oidc_config() -> Optional[dict]:
|
|
"""Fetch OIDC well-known configuration."""
|
|
if not OIDC_ISSUER:
|
|
return None
|
|
try:
|
|
resp = httpx.get(f"{OIDC_ISSUER}/.well-known/openid-configuration", timeout=5)
|
|
return resp.json()
|
|
except Exception as e:
|
|
logger.error(f"Failed to fetch OIDC config: {e}")
|
|
return None
|
|
|
|
|
|
def get_jwks() -> dict:
|
|
"""Fetch JSON Web Key Set from Keycloak."""
|
|
global _jwks_cache
|
|
if _jwks_cache:
|
|
return _jwks_cache
|
|
config = get_oidc_config()
|
|
if not config:
|
|
return {}
|
|
try:
|
|
resp = httpx.get(config["jwks_uri"], timeout=5)
|
|
_jwks_cache = resp.json()
|
|
return _jwks_cache
|
|
except Exception as e:
|
|
logger.error(f"Failed to fetch JWKS: {e}")
|
|
return {}
|
|
|
|
|
|
def verify_oidc_token(token: str) -> Optional[dict]:
|
|
"""
|
|
Verify and decode a Keycloak JWT token.
|
|
Returns decoded claims or None if invalid.
|
|
"""
|
|
if not OIDC_ISSUER:
|
|
return None # OIDC not configured
|
|
|
|
jwks = get_jwks()
|
|
if not jwks or "keys" not in jwks:
|
|
logger.warning("No JWKS keys available")
|
|
return None
|
|
|
|
try:
|
|
# Get key ID from token header
|
|
unverified = jwt.get_unverified_header(token)
|
|
kid = unverified.get("kid")
|
|
|
|
# Find matching key
|
|
key = None
|
|
for k in jwks["keys"]:
|
|
if k.get("kid") == kid:
|
|
key = k
|
|
break
|
|
|
|
if not key:
|
|
logger.warning(f"No matching key found for kid={kid}")
|
|
return None
|
|
|
|
# Verify and decode
|
|
claims = jwt.decode(
|
|
token,
|
|
key,
|
|
algorithms=["RS256"],
|
|
issuer=OIDC_ISSUER,
|
|
audience=OIDC_AUDIENCE,
|
|
options={"verify_aud": False}, # Keycloak may not include aud
|
|
)
|
|
return claims
|
|
|
|
except JWTError as e:
|
|
logger.warning(f"JWT verification failed: {e}")
|
|
return None
|
|
|
|
|
|
def extract_user_from_claims(claims: dict) -> dict:
|
|
"""Extract user info from JWT claims."""
|
|
roles = []
|
|
if "realm_access" in claims:
|
|
roles = claims["realm_access"].get("roles", [])
|
|
|
|
return {
|
|
"id": claims.get("sub", ""),
|
|
"email": claims.get("email", ""),
|
|
"display_name": claims.get("name", claims.get("preferred_username", "")),
|
|
"role": roles[0] if roles else "operator_user",
|
|
"roles": roles,
|
|
"organization_id": claims.get("organization_id"),
|
|
}
|