klg-asutk-app/backend/app/api/oidc.py

110 lines
2.9 KiB
Python

"""
OIDC JWT verification — validates Keycloak tokens.
When OIDC_ISSUER is not configured, verification returns None so that the
caller can fall back to DEV auth.
"""
import logging
import os
import time
from functools import lru_cache
from typing import Optional

import httpx
from jose import jwt, JWTError, jwk
# Module-level logger shared by every function below.
logger = logging.getLogger(__name__)

# Issuer URL of the Keycloak realm. An empty value means OIDC is not
# configured: get_oidc_config() and verify_oidc_token() then return None.
OIDC_ISSUER = os.environ.get("OIDC_ISSUER", "")

# Expected token audience (client ID); consumed by verify_oidc_token().
OIDC_AUDIENCE = os.environ.get("OIDC_AUDIENCE", "klg-frontend")

# JWKS document memoized by get_jwks(); an empty dict means nothing cached yet.
_jwks_cache: dict = {}
@lru_cache(maxsize=1)
def _fetch_oidc_config(issuer: str) -> dict:
    """Download and parse the OIDC discovery document for *issuer*.

    Raises on any network, HTTP-status or decoding failure. ``lru_cache`` does
    not memoize exceptions, so only successful fetches are cached and a
    transient outage is retried on the next call.
    """
    # Drop a trailing "/" before appending the well-known path so that an
    # issuer configured as ".../realms/x/" does not produce "//.well-known".
    url = f"{issuer.rstrip('/')}/.well-known/openid-configuration"
    resp = httpx.get(url, timeout=5)
    # Without this, an error page that happens to be JSON would be cached
    # as the configuration.
    resp.raise_for_status()
    config = resp.json()
    if not isinstance(config, dict):
        raise ValueError(
            f"unexpected discovery document type: {type(config).__name__}"
        )
    return config


def get_oidc_config() -> Optional[dict]:
    """Fetch OIDC well-known configuration.

    Returns:
        The parsed discovery document for OIDC_ISSUER, or None when OIDC is
        not configured or the document could not be fetched. Failures are
        logged and are NOT cached, so the next call tries again.
    """
    if not OIDC_ISSUER:
        return None
    try:
        return _fetch_oidc_config(OIDC_ISSUER)
    except Exception as e:  # best-effort: callers treat None as "unavailable"
        logger.error("Failed to fetch OIDC config: %s", e)
        return None


# Preserve the cache helpers the previously lru_cache-decorated function exposed.
get_oidc_config.cache_clear = _fetch_oidc_config.cache_clear  # type: ignore[attr-defined]
get_oidc_config.cache_info = _fetch_oidc_config.cache_info  # type: ignore[attr-defined]
def get_jwks(force_refresh: bool = False) -> dict:
    """Fetch JSON Web Key Set from Keycloak.

    The first successful response is memoized in the module-level
    ``_jwks_cache`` and returned on later calls.

    Args:
        force_refresh: Ignore the cached JWKS and refetch it. On failure the
            previously cached keys are left in place.

    Returns:
        The JWKS document (a dict with a ``keys`` list), or ``{}`` when OIDC
        is not configured or the keys could not be fetched.
    """
    global _jwks_cache
    if _jwks_cache and not force_refresh:
        return _jwks_cache
    config = get_oidc_config()
    if not config:
        return {}
    jwks_uri = config.get("jwks_uri")
    if not jwks_uri:
        logger.error("Failed to fetch JWKS: discovery document has no jwks_uri")
        return {}
    try:
        resp = httpx.get(jwks_uri, timeout=5)
        resp.raise_for_status()
        jwks = resp.json()
    except Exception as e:  # best-effort: callers treat {} as "no keys"
        logger.error("Failed to fetch JWKS: %s", e)
        return {}
    # Only cache a well-formed key set; caching an error body here would make
    # every later verification fail until the process restarts.
    if not isinstance(jwks, dict) or not isinstance(jwks.get("keys"), list):
        logger.error("Failed to fetch JWKS: response from %s has no 'keys' list", jwks_uri)
        return {}
    _jwks_cache = jwks
    return _jwks_cache
# Minimum delay, in seconds, between JWKS refetches triggered by an unknown
# ``kid``. Without a limit any client could force one Keycloak round-trip per
# request simply by sending tokens with made-up key IDs.
_JWKS_KID_MISS_REFRESH_INTERVAL = 60.0
# time.monotonic() value of the last kid-miss refetch; None until the first one.
_jwks_last_kid_miss_refresh: Optional[float] = None


def _find_jwk(jwks: dict, kid: Optional[str]) -> Optional[dict]:
    """Return the key in *jwks* whose ``kid`` equals *kid*, or None."""
    return next((k for k in jwks.get("keys") or [] if k.get("kid") == kid), None)


def _refresh_jwks_for_unknown_kid(kid: Optional[str]) -> Optional[dict]:
    """Refetch the JWKS (rate-limited) and look *kid* up in the fresh set.

    Handles signing-key rotation: without this, a cached JWKS would never
    learn about a new key until the process restarts. If the refetch fails,
    the previously cached keys are restored.
    """
    global _jwks_cache, _jwks_last_kid_miss_refresh
    now = time.monotonic()
    last = _jwks_last_kid_miss_refresh
    if last is not None and now - last < _JWKS_KID_MISS_REFRESH_INTERVAL:
        return None
    _jwks_last_kid_miss_refresh = now
    previous = _jwks_cache
    _jwks_cache = {}  # an empty cache makes get_jwks() refetch
    fresh = get_jwks()
    if not fresh:
        _jwks_cache = previous
        return None
    return _find_jwk(fresh, kid)


def _audience_matches(claims: dict) -> bool:
    """Return True if the token was issued for OIDC_AUDIENCE.

    Keycloak tokens may not list this client in ``aud``, so the OIDC
    authorized-party claim ``azp`` is accepted as well. An empty
    OIDC_AUDIENCE disables the check.
    """
    if not OIDC_AUDIENCE:
        return True
    aud = claims.get("aud")
    if isinstance(aud, str):
        audiences = [aud]
    elif isinstance(aud, (list, tuple)):
        audiences = list(aud)
    else:
        audiences = []
    return OIDC_AUDIENCE in audiences or claims.get("azp") == OIDC_AUDIENCE


def verify_oidc_token(token: str) -> Optional[dict]:
    """
    Verify and decode a Keycloak JWT token.

    The token must be RS256-signed by a key from the realm JWKS, carry
    ``iss == OIDC_ISSUER``, and name OIDC_AUDIENCE in ``aud`` or ``azp``.
    Expiry and the other registered claims are checked by ``jwt.decode``.

    Returns decoded claims or None if invalid (or if OIDC is not configured).
    """
    if not OIDC_ISSUER:
        return None  # OIDC not configured
    jwks = get_jwks()
    if not jwks or "keys" not in jwks:
        logger.warning("No JWKS keys available")
        return None
    try:
        # The header is read unverified only to select the signing key.
        kid = jwt.get_unverified_header(token).get("kid")
        key = _find_jwk(jwks, kid)
        if key is None:
            key = _refresh_jwks_for_unknown_kid(kid)
        if key is None:
            logger.warning("No matching key found for kid=%s", kid)
            return None
        # Audience is checked manually below (aud OR azp), so python-jose's
        # aud-only check is switched off here.
        claims = jwt.decode(
            token,
            key,
            algorithms=["RS256"],
            issuer=OIDC_ISSUER,
            options={"verify_aud": False},
        )
    except JWTError as e:
        logger.warning("JWT verification failed: %s", e)
        return None
    if not _audience_matches(claims):
        logger.warning(
            "JWT verification failed: token not issued for %s (aud=%s, azp=%s)",
            OIDC_AUDIENCE,
            claims.get("aud"),
            claims.get("azp"),
        )
        return None
    return claims
# Keycloak's built-in default realm roles. They carry no application meaning,
# so they must never be chosen as the user's primary role.
_KEYCLOAK_BUILTIN_ROLES = frozenset({"offline_access", "uma_authorization"})
_KEYCLOAK_DEFAULT_ROLES_PREFIX = "default-roles-"


def extract_user_from_claims(claims: dict) -> dict:
    """Extract user info from JWT claims.

    Returns a dict with ``id`` (``sub``), ``email``, ``display_name`` (``name``,
    falling back to ``preferred_username``), ``role``, ``roles`` (all realm
    roles from ``realm_access.roles``) and ``organization_id``.

    ``role`` is the first realm role that is not a Keycloak built-in default
    role, or ``"operator_user"`` when there is none.
    NOTE(review): realm roles arrive in no guaranteed order, so with several
    application roles the choice of primary role is arbitrary.
    """
    # ``realm_access`` or its ``roles`` may be absent or null.
    realm_access = claims.get("realm_access") or {}
    roles = realm_access.get("roles") or []
    app_roles = [
        r
        for r in roles
        if r not in _KEYCLOAK_BUILTIN_ROLES
        and not str(r).startswith(_KEYCLOAK_DEFAULT_ROLES_PREFIX)
    ]
    return {
        "id": claims.get("sub", ""),
        "email": claims.get("email", ""),
        "display_name": claims.get("name", claims.get("preferred_username", "")),
        "role": app_roles[0] if app_roles else "operator_user",
        "roles": roles,
        "organization_id": claims.get("organization_id"),
    }