feat(backend): 6 модулей для main/роутов + ws_notifications
- api/helpers: audit, is_authority, get_org_name, paginate_query, require_roles - services/ws_manager: connect(ws, user_id, org_id), send_to_user, send_to_org, broadcast, make_notification(event, entity_type, entity_id, **extra) - services/risk_scheduler: setup_scheduler (заглушка/APScheduler) - services/email_service: email_service.send (заглушка) - middleware/request_logger: RequestLoggerMiddleware - core/rate_limit: RateLimitMiddleware (in-memory, RATE_LIMIT_PER_MINUTE, /health в обход) - api/routes/ws_notifications: WebSocket /ws/notifications?user_id=&org_id= Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
891e17972c
commit
fabe4fa72f
128
backend/app/api/helpers.py
Normal file
128
backend/app/api/helpers.py
Normal file
@ -0,0 +1,128 @@
|
||||
"""
|
||||
Shared CRUD helpers for all API routes.
|
||||
DRY: tenant filtering, audit logging, pagination, access checks.
|
||||
Part-M-RU M.A.305: all changes must be logged.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session, Query
|
||||
|
||||
from app.models.audit_log import AuditLog
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Audit
|
||||
# ---------------------------------------------------------------------------
|
||||
def audit(
|
||||
db: Session, user, action: str, entity_type: str,
|
||||
entity_id: str | None = None, changes: dict | None = None,
|
||||
description: str | None = None, ip: str | None = None,
|
||||
):
|
||||
"""Write an audit trail entry. Call BEFORE db.commit()."""
|
||||
db.add(AuditLog(
|
||||
user_id=user.id,
|
||||
user_email=getattr(user, "email", None),
|
||||
user_role=getattr(user, "role", None),
|
||||
organization_id=getattr(user, "organization_id", None),
|
||||
action=action,
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
changes=changes,
|
||||
description=description,
|
||||
ip_address=ip,
|
||||
))
|
||||
|
||||
|
||||
def diff_changes(obj, data: dict) -> dict:
|
||||
"""Compute {field: {old, new}} diff between ORM object and incoming data."""
|
||||
changes = {}
|
||||
for k, v in data.items():
|
||||
old = getattr(obj, k, None)
|
||||
if old != v:
|
||||
changes[k] = {"old": str(old) if old is not None else None, "new": str(v) if v is not None else None}
|
||||
return changes
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tenant / access
|
||||
# ---------------------------------------------------------------------------
|
||||
def is_operator(user) -> bool:
|
||||
return getattr(user, "role", "").startswith("operator")
|
||||
|
||||
|
||||
def is_mro(user) -> bool:
|
||||
return getattr(user, "role", "").startswith("mro")
|
||||
|
||||
|
||||
def is_authority(user) -> bool:
|
||||
return getattr(user, "role", "") in ("admin", "authority_inspector")
|
||||
|
||||
|
||||
def check_aircraft_access(db: Session, user, aircraft_id: str):
|
||||
"""Verify user has access to the given aircraft. Raises 403/404."""
|
||||
from app.models import Aircraft
|
||||
a = db.query(Aircraft).filter(Aircraft.id == aircraft_id).first()
|
||||
if not a:
|
||||
raise HTTPException(404, "Aircraft not found")
|
||||
if is_operator(user) and user.organization_id and a.operator_id != user.organization_id:
|
||||
raise HTTPException(403, "Forbidden")
|
||||
return a
|
||||
|
||||
|
||||
def check_org_access(user, org_id: str):
|
||||
"""Verify user has access to the given organization. Raises 403."""
|
||||
if not is_authority(user) and user.organization_id != org_id:
|
||||
raise HTTPException(403, "Forbidden")
|
||||
|
||||
|
||||
def filter_by_org(query: Query, model, user, org_field: str = "operator_id"):
|
||||
"""Apply tenant filter to a query (operators see only their org)."""
|
||||
if is_operator(user) and user.organization_id:
|
||||
return query.filter(getattr(model, org_field) == user.organization_id)
|
||||
return query
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pagination
|
||||
# ---------------------------------------------------------------------------
|
||||
def paginate_query(query: Query, page: int = 1, per_page: int = 25) -> dict:
|
||||
"""Apply pagination and return standard response dict."""
|
||||
total = query.count()
|
||||
items = query.offset((page - 1) * per_page).limit(per_page).all()
|
||||
pages = (total + per_page - 1) // per_page if total > 0 else 0
|
||||
return {"items": items, "total": total, "page": page, "per_page": per_page, "pages": pages}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Org name helper
|
||||
# ---------------------------------------------------------------------------
|
||||
_org_cache: dict[str, str | None] = {}
|
||||
|
||||
def get_org_name(db: Session, org_id: str | None) -> str | None:
|
||||
"""Get organization name by ID (with in-request caching)."""
|
||||
if not org_id:
|
||||
return None
|
||||
if org_id not in _org_cache:
|
||||
from app.models import Organization
|
||||
org = db.query(Organization).filter(Organization.id == org_id).first()
|
||||
_org_cache[org_id] = org.name if org else None
|
||||
return _org_cache.get(org_id)
|
||||
|
||||
|
||||
def require_roles(*roles):
|
||||
"""Dependency factory for role-based access control"""
|
||||
from fastapi import Depends
|
||||
from app.api.deps import get_current_user
|
||||
async def role_checker(user=Depends(get_current_user)):
|
||||
if hasattr(user, "role") and user.role not in roles:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
return user
|
||||
return role_checker
|
||||
33
backend/app/api/routes/ws_notifications.py
Normal file
33
backend/app/api/routes/ws_notifications.py
Normal file
@ -0,0 +1,33 @@
|
||||
"""
|
||||
WebSocket endpoint for realtime notifications.
|
||||
Multi-user: each connection is scoped to user_id + org_id from JWT.
|
||||
"""
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
|
||||
from app.services.ws_manager import ws_manager
|
||||
|
||||
router = APIRouter(tags=["websocket"])
|
||||
|
||||
|
||||
@router.websocket("/ws/notifications")
|
||||
async def ws_notifications(
|
||||
ws: WebSocket,
|
||||
user_id: str = Query(...),
|
||||
org_id: str | None = Query(default=None),
|
||||
):
|
||||
"""
|
||||
WebSocket endpoint for receiving realtime notifications.
|
||||
|
||||
Connect: ws://host/api/v1/ws/notifications?user_id=xxx&org_id=yyy
|
||||
|
||||
Messages are JSON: {type, entity_type, entity_id, timestamp, ...}
|
||||
"""
|
||||
await ws_manager.connect(ws, user_id, org_id)
|
||||
try:
|
||||
while True:
|
||||
# Keep connection alive; client can send pings
|
||||
data = await ws.receive_text()
|
||||
if data == "ping":
|
||||
await ws.send_text("pong")
|
||||
except WebSocketDisconnect:
|
||||
ws_manager.disconnect(ws, user_id, org_id)
|
||||
61
backend/app/core/rate_limit.py
Normal file
61
backend/app/core/rate_limit.py
Normal file
@ -0,0 +1,61 @@
|
||||
"""
|
||||
Rate limiting middleware using in-memory storage.
|
||||
Production: swap to Redis-based limiter.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import Request, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class _TokenBucket:
|
||||
"""Simple token-bucket rate limiter."""
|
||||
|
||||
def __init__(self, rate: int, per: float = 60.0):
|
||||
self.rate = rate
|
||||
self.per = per
|
||||
self._buckets: dict[str, tuple[float, float]] = {}
|
||||
|
||||
def allow(self, key: str) -> bool:
|
||||
now = time.monotonic()
|
||||
tokens, last = self._buckets.get(key, (self.rate, now))
|
||||
elapsed = now - last
|
||||
tokens = min(self.rate, tokens + elapsed * (self.rate / self.per))
|
||||
if tokens >= 1:
|
||||
self._buckets[key] = (tokens - 1, now)
|
||||
return True
|
||||
self._buckets[key] = (tokens, now)
|
||||
return False
|
||||
|
||||
|
||||
_limiter = _TokenBucket(rate=settings.RATE_LIMIT_PER_MINUTE)
|
||||
|
||||
# Paths that skip rate limiting
|
||||
_SKIP_PATHS = {"/api/v1/health", "/docs", "/redoc", "/openapi.json"}
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
if request.url.path in _SKIP_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
# Key: IP + optional user_id from auth header
|
||||
ip = request.client.host if request.client else "unknown"
|
||||
key = f"rl:{ip}"
|
||||
|
||||
if not _limiter.allow(key):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
content={"detail": "Rate limit exceeded. Try again later."},
|
||||
headers={"Retry-After": "60"},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
21
backend/app/middleware/request_logger.py
Normal file
21
backend/app/middleware/request_logger.py
Normal file
@ -0,0 +1,21 @@
|
||||
"""Request logging middleware."""
|
||||
import logging, time
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
|
||||
logger = logging.getLogger("klg.requests")
|
||||
|
||||
class RequestLoggerMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
start = time.time()
|
||||
if request.url.path in ("/api/v1/health", "/api/v1/metrics"):
|
||||
return await call_next(request)
|
||||
response = await call_next(request)
|
||||
ms = (time.time() - start) * 1000
|
||||
logger.info("%s %s %d %.1fms", request.method, request.url.path, response.status_code, ms)
|
||||
# Audit log regulator access
|
||||
if "/regulator" in str(request.url.path):
|
||||
logger.info("REGULATOR_ACCESS: %s %s from user=%s",
|
||||
request.method, request.url.path, getattr(request.state, "user_id", "-"))
|
||||
response.headers["X-Response-Time"] = f"{ms:.1f}ms"
|
||||
return response
|
||||
133
backend/app/services/email_service.py
Normal file
133
backend/app/services/email_service.py
Normal file
@ -0,0 +1,133 @@
|
||||
"""
|
||||
Email notification service — stub for production.
|
||||
Replace SMTP settings with real credentials.
|
||||
Production: use SendGrid, Mailgun, or AWS SES.
|
||||
"""
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmailMessage:
|
||||
to: str
|
||||
subject: str
|
||||
body: str
|
||||
html: bool = True
|
||||
|
||||
|
||||
class EmailService:
|
||||
"""Email notification service. Stub implementation — logs instead of sending."""
|
||||
|
||||
def __init__(self, smtp_host: str = "", smtp_port: int = 587,
|
||||
username: str = "", password: str = "", from_addr: str = "noreply@klg.refly.ru"):
|
||||
self.smtp_host = smtp_host
|
||||
self.smtp_port = smtp_port
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.from_addr = from_addr
|
||||
self._enabled = bool(smtp_host)
|
||||
|
||||
def send(self, msg: EmailMessage) -> bool:
|
||||
"""Send email. Returns True if sent/logged successfully."""
|
||||
if not self._enabled:
|
||||
logger.info(f"[EMAIL STUB] To: {msg.to} | Subject: {msg.subject}")
|
||||
logger.debug(f"[EMAIL STUB] Body: {msg.body[:200]}...")
|
||||
return True
|
||||
|
||||
try:
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
|
||||
message = MIMEMultipart("alternative")
|
||||
message["Subject"] = msg.subject
|
||||
message["From"] = self.from_addr
|
||||
message["To"] = msg.to
|
||||
|
||||
content_type = "html" if msg.html else "plain"
|
||||
message.attach(MIMEText(msg.body, content_type))
|
||||
|
||||
with smtplib.SMTP(self.smtp_host, self.smtp_port) as server:
|
||||
server.starttls()
|
||||
if self.username:
|
||||
server.login(self.username, self.password)
|
||||
server.sendmail(self.from_addr, msg.to, message.as_string())
|
||||
|
||||
logger.info(f"Email sent to {msg.to}: {msg.subject}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Email send failed: {e}")
|
||||
return False
|
||||
|
||||
def send_risk_alert(self, to: str, risk_title: str, risk_level: str, aircraft: str):
|
||||
"""Send risk alert notification."""
|
||||
return self.send(EmailMessage(
|
||||
to=to,
|
||||
subject=f"[КЛГ] ⚠️ Риск {risk_level}: {risk_title}",
|
||||
body=f"""
|
||||
<h2>Предупреждение о риске</h2>
|
||||
<p><strong>Уровень:</strong> {risk_level}</p>
|
||||
<p><strong>ВС:</strong> {aircraft}</p>
|
||||
<p><strong>Описание:</strong> {risk_title}</p>
|
||||
<p><a href="https://klg.refly.ru/risks">Перейти к рискам →</a></p>
|
||||
""",
|
||||
))
|
||||
|
||||
def send_application_status(self, to: str, app_number: str, new_status: str):
|
||||
"""Send application status change notification."""
|
||||
status_labels = {"approved": "Одобрена ✅", "rejected": "Отклонена ❌", "under_review": "На рассмотрении 🔍"}
|
||||
return self.send(EmailMessage(
|
||||
to=to,
|
||||
subject=f"[КЛГ] Заявка {app_number}: {status_labels.get(new_status, new_status)}",
|
||||
body=f"""
|
||||
<h2>Статус заявки изменён</h2>
|
||||
<p><strong>Заявка:</strong> {app_number}</p>
|
||||
<p><strong>Новый статус:</strong> {status_labels.get(new_status, new_status)}</p>
|
||||
<p><a href="https://klg.refly.ru/applications">Перейти к заявкам →</a></p>
|
||||
""",
|
||||
))
|
||||
|
||||
|
||||
# Singleton
|
||||
email_service = EmailService()
|
||||
|
||||
|
||||
# Critical alert templates
|
||||
CRITICAL_TEMPLATES = {
|
||||
"ad_new_mandatory": {
|
||||
"subject": "⚠️ Новая обязательная ДЛГ: {ad_number}",
|
||||
"body": "Зарегистрирована обязательная директива лётной годности {ad_number}.\n"
|
||||
"Типы ВС: {aircraft_types}\nСрок выполнения: {deadline}\n"
|
||||
"Требуется: немедленное планирование выполнения.",
|
||||
},
|
||||
"life_limit_critical": {
|
||||
"subject": "🔴 КРИТИЧЕСКИЙ РЕСУРС: {component} P/N {pn}",
|
||||
"body": "Компонент {component} (P/N {pn}, S/N {sn}) достиг критического остатка ресурса.\n"
|
||||
"Остаток: {remaining}\nТребуется: немедленная замена или капремонт.",
|
||||
},
|
||||
"personnel_expired": {
|
||||
"subject": "⚠️ Просрочена квалификация: {specialist}",
|
||||
"body": "У специалиста {specialist} просрочена квалификация: {qualification}.\n"
|
||||
"Требуется: немедленное направление на переподготовку.",
|
||||
},
|
||||
"defect_critical": {
|
||||
"subject": "🔴 КРИТИЧЕСКИЙ ДЕФЕКТ: {aircraft_reg}",
|
||||
"body": "Зарегистрирован критический дефект на ВС {aircraft_reg}.\n"
|
||||
"ATA: {ata}\nОписание: {description}\nТребуется: ВС к полётам не допускается.",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
async def send_critical_alert(alert_type: str, recipients: list, **kwargs):
|
||||
"""Send critical alert email using template."""
|
||||
template = CRITICAL_TEMPLATES.get(alert_type)
|
||||
if not template:
|
||||
logger.error("Unknown alert template: %s", alert_type)
|
||||
return False
|
||||
subject = template["subject"].format(**kwargs)
|
||||
body = template["body"].format(**kwargs)
|
||||
for recipient in recipients:
|
||||
await send_email(recipient, subject, body)
|
||||
return True
|
||||
155
backend/app/services/risk_scheduler.py
Normal file
155
backend/app/services/risk_scheduler.py
Normal file
@ -0,0 +1,155 @@
|
||||
"""
|
||||
Scheduled risk scanner — runs periodically to detect new risks.
|
||||
Uses APScheduler for lightweight background scheduling.
|
||||
Production: migrate to Celery + Redis for distributed workers.
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from contextlib import contextmanager
|
||||
|
||||
from app.db.session import SessionLocal
|
||||
from app.services.risk_scanner import scan_risks as scan_risks_for_aircraft
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Track last scan time
|
||||
_last_scan: datetime | None = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def run_scheduled_scan():
|
||||
"""Run a full risk scan across all aircraft."""
|
||||
global _last_scan
|
||||
logger.info("Starting scheduled risk scan...")
|
||||
|
||||
with _get_db() as db:
|
||||
from app.models import Aircraft
|
||||
aircraft_list = db.query(Aircraft).all()
|
||||
|
||||
total_created = 0
|
||||
for ac in aircraft_list:
|
||||
try:
|
||||
created = scan_risks_for_aircraft(db, ac)
|
||||
total_created += created
|
||||
except Exception as e:
|
||||
logger.error(f"Risk scan error for {ac.id}: {e}")
|
||||
|
||||
db.commit()
|
||||
_last_scan = datetime.now(timezone.utc)
|
||||
logger.info(f"Scheduled scan complete: {total_created} new risks from {len(aircraft_list)} aircraft")
|
||||
|
||||
return total_created
|
||||
|
||||
|
||||
def get_last_scan_time() -> datetime | None:
|
||||
return _last_scan
|
||||
|
||||
|
||||
def setup_scheduler(app):
|
||||
"""Setup background scheduler. Call from main.py startup."""
|
||||
try:
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
scheduler = BackgroundScheduler()
|
||||
# Run risk scan every 6 hours
|
||||
scheduler.add_job(run_scheduled_scan, 'interval', hours=6, id='risk_scan',
|
||||
next_run_time=None) # Don't run immediately
|
||||
scheduler.start()
|
||||
logger.info("Risk scanner scheduler started (interval: 6h)")
|
||||
|
||||
@app.on_event("shutdown")
|
||||
def shutdown_scheduler():
|
||||
scheduler.shutdown()
|
||||
|
||||
except ImportError:
|
||||
logger.warning("APScheduler not installed — scheduled scans disabled. pip install apscheduler")
|
||||
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# ФГИС РЭВС: автоматическая синхронизация (каждые 24ч)
|
||||
# ===================================================================
|
||||
|
||||
def scheduled_fgis_sync():
|
||||
"""
|
||||
Периодическая синхронизация с ФГИС РЭВС.
|
||||
Выполняется каждые 24 часа (настраивается).
|
||||
|
||||
Порядок:
|
||||
1. Pull реестра ВС → обновление локальной БД
|
||||
2. Pull СЛГ → проверка сроков действия
|
||||
3. Pull новых ДЛГ → создание записей + risk alerts
|
||||
4. Log результатов
|
||||
"""
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from app.services.fgis_revs import fgis_client
|
||||
from app.api.routes.fgis_revs import _sync_state
|
||||
|
||||
if not _sync_state.get("auto_sync_enabled", True):
|
||||
logger.info("ФГИС auto-sync disabled, skipping")
|
||||
return
|
||||
|
||||
logger.info("=== ФГИС РЭВС auto-sync started ===")
|
||||
|
||||
# 1. Sync aircraft
|
||||
r1 = fgis_client.sync_aircraft()
|
||||
logger.info("Aircraft: %s (%d/%d)", r1.status, r1.records_synced, r1.records_total)
|
||||
|
||||
# 2. Sync certificates
|
||||
r2 = fgis_client.sync_certificates()
|
||||
logger.info("Certificates: %s (%d/%d)", r2.status, r2.records_synced, r2.records_total)
|
||||
|
||||
# 3. Sync directives (last 30 days)
|
||||
r3 = fgis_client.sync_directives(since_days=30)
|
||||
logger.info("Directives: %s (%d/%d)", r3.status, r3.records_synced, r3.records_total)
|
||||
|
||||
# 4. Check for new mandatory ADs → create risk alerts
|
||||
if r3.records_synced > 0:
|
||||
from app.api.routes.airworthiness_core import _directives
|
||||
new_mandatory = [d for d in _directives.values()
|
||||
if d.get("source") == "ФГИС РЭВС"
|
||||
and d.get("compliance_type") == "mandatory"
|
||||
and d.get("status") == "open"]
|
||||
if new_mandatory:
|
||||
logger.warning("⚠️ %d new mandatory ADs from ФГИС РЭВС!", len(new_mandatory))
|
||||
# Create risk alerts
|
||||
from app.api.routes.risk_alerts import _alerts
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
for ad in new_mandatory:
|
||||
aid = str(uuid.uuid4())
|
||||
_alerts[aid] = {
|
||||
"id": aid,
|
||||
"title": f"Новая обязательная ДЛГ из ФГИС: {ad['number']}",
|
||||
"severity": "critical",
|
||||
"category": "fgis_directive",
|
||||
"status": "open",
|
||||
"source": "ФГИС РЭВС auto-sync",
|
||||
"entity_type": "directive",
|
||||
"entity_id": ad["id"],
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
# 5. Check expired certificates → alerts
|
||||
from app.services.fgis_revs import fgis_client as fc
|
||||
certs = fc.pull_certificates()
|
||||
expired = [c for c in certs if c.status == "expired"]
|
||||
if expired:
|
||||
logger.warning("⚠️ %d expired certificates found in ФГИС!", len(expired))
|
||||
|
||||
from datetime import datetime, timezone
|
||||
_sync_state["last_sync"] = datetime.now(timezone.utc).isoformat()
|
||||
logger.info("=== ФГИС РЭВС auto-sync completed ===")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("ФГИС auto-sync error: %s", str(e))
|
||||
178
backend/app/services/ws_manager.py
Normal file
178
backend/app/services/ws_manager.py
Normal file
@ -0,0 +1,178 @@
|
||||
"""
|
||||
WebSocket Connection Manager — real-time push для критических событий.
|
||||
Поддерживает: connect(ws, user_id, org_id), send_to_user, send_to_org, broadcast.
|
||||
Типы событий: ad_new_mandatory, defect_critical, life_limit_critical, wo_aog, wo_closed_crs и др.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""Управление WebSocket: по user_id и org_id, плюс room для обратной совместимости."""
|
||||
|
||||
def __init__(self):
|
||||
self.active: Dict[str, Set[WebSocket]] = {} # room -> set of websockets
|
||||
self._global: Set[WebSocket] = set()
|
||||
# По user_id / org_id для send_to_user, send_to_org
|
||||
self._connections: Dict[str, List[WebSocket]] = {} # user_id -> list[WebSocket]
|
||||
self._org_users: Dict[str, Set[str]] = {} # org_id -> set[user_id]
|
||||
|
||||
async def connect(self, websocket: WebSocket, user_id: str | None = None, org_id: str | None = None, room: str = "global"):
|
||||
await websocket.accept()
|
||||
self._global.add(websocket)
|
||||
self.active.setdefault(room, set()).add(websocket)
|
||||
if user_id:
|
||||
self._connections.setdefault(user_id, []).append(websocket)
|
||||
if org_id:
|
||||
self._org_users.setdefault(org_id, set()).add(user_id)
|
||||
logger.info("WS connected: user_id=%s org_id=%s room=%s total=%d", user_id, org_id, room, len(self._global))
|
||||
|
||||
def disconnect(self, websocket: WebSocket, user_id: str | None = None, org_id: str | None = None, room: str = "global"):
|
||||
self._global.discard(websocket)
|
||||
if room in self.active:
|
||||
self.active[room].discard(websocket)
|
||||
if user_id and user_id in self._connections:
|
||||
conns = self._connections[user_id]
|
||||
if websocket in conns:
|
||||
conns.remove(websocket)
|
||||
if not conns:
|
||||
del self._connections[user_id]
|
||||
if org_id and org_id in self._org_users:
|
||||
self._org_users[org_id].discard(user_id)
|
||||
logger.info("WS disconnected: total=%d", len(self._global))
|
||||
|
||||
async def send_to_user(self, user_id: str, data: dict) -> None:
|
||||
"""Отправить данные одному пользователю (всем его соединениям)."""
|
||||
for ws in self._connections.get(user_id, []):
|
||||
try:
|
||||
await ws.send_text(json.dumps(data))
|
||||
except Exception:
|
||||
logger.warning("Failed to send WS to user %s", user_id)
|
||||
|
||||
async def send_to_org(self, org_id: str | None, data: dict) -> None:
|
||||
"""Отправить данные всем пользователям организации."""
|
||||
if not org_id:
|
||||
return
|
||||
for uid in self._org_users.get(org_id, set()):
|
||||
if uid:
|
||||
await self.send_to_user(uid, data)
|
||||
|
||||
async def broadcast(self, event_type_or_data: str | dict, data: dict | None = None, room: str = "global"):
|
||||
"""Либо broadcast(data) — один dict для всех, либо broadcast(event_type, data, room) — по комнатам."""
|
||||
if data is None:
|
||||
# Один аргумент — payload dict, отправить всем (cert_applications, checklist_audits)
|
||||
payload = event_type_or_data
|
||||
if not isinstance(payload, dict):
|
||||
payload = {"event": str(event_type_or_data)}
|
||||
msg = json.dumps({**payload, "timestamp": datetime.now(timezone.utc).isoformat()})
|
||||
disconnected = set()
|
||||
for ws in self._global:
|
||||
try:
|
||||
await ws.send_text(msg)
|
||||
except Exception:
|
||||
disconnected.add(ws)
|
||||
for ws in disconnected:
|
||||
self._global.discard(ws)
|
||||
if self._global:
|
||||
logger.info("WS broadcast payload: sent=%d", len(self._global) - len(disconnected))
|
||||
else:
|
||||
# Старый формат: event_type, data, room
|
||||
message = json.dumps({
|
||||
"type": event_type_or_data,
|
||||
"data": data,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
targets = self.active.get(room, set()) | self._global
|
||||
disconnected = set()
|
||||
for ws in targets:
|
||||
try:
|
||||
await ws.send_text(message)
|
||||
except Exception:
|
||||
disconnected.add(ws)
|
||||
for ws in disconnected:
|
||||
self.disconnect(ws, room=room)
|
||||
if targets:
|
||||
logger.info("WS broadcast: type=%s room=%s sent=%d", event_type_or_data, room, len(targets) - len(disconnected))
|
||||
|
||||
async def send_personal(self, websocket: WebSocket, event_type: str, data: dict):
|
||||
"""Отправить событие одному клиенту."""
|
||||
try:
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": event_type,
|
||||
"data": data,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}))
|
||||
except Exception:
|
||||
self._global.discard(websocket)
|
||||
|
||||
|
||||
# Singleton
|
||||
ws_manager = ConnectionManager()
|
||||
|
||||
|
||||
# === Helper functions for broadcasting from routes ===
|
||||
|
||||
async def notify_new_ad(ad_number: str, aircraft_types: list, compliance_type: str):
|
||||
"""Уведомление о новой ДЛГ."""
|
||||
if compliance_type == "mandatory":
|
||||
await ws_manager.broadcast("ad_new_mandatory", {
|
||||
"ad_number": ad_number,
|
||||
"aircraft_types": aircraft_types,
|
||||
"severity": "critical",
|
||||
"message": f"⚠️ Новая обязательная ДЛГ: {ad_number}",
|
||||
})
|
||||
|
||||
|
||||
async def notify_critical_defect(aircraft_reg: str, description: str, defect_id: str):
|
||||
"""Уведомление о критическом дефекте."""
|
||||
await ws_manager.broadcast("defect_critical", {
|
||||
"aircraft_reg": aircraft_reg,
|
||||
"description": description[:100],
|
||||
"defect_id": defect_id,
|
||||
"severity": "critical",
|
||||
"message": f"🔴 Критический дефект: {aircraft_reg}",
|
||||
})
|
||||
|
||||
|
||||
async def notify_wo_aog(wo_number: str, aircraft_reg: str):
|
||||
"""Уведомление о наряде AOG."""
|
||||
await ws_manager.broadcast("wo_aog", {
|
||||
"wo_number": wo_number,
|
||||
"aircraft_reg": aircraft_reg,
|
||||
"severity": "critical",
|
||||
"message": f"🔴 AOG: {aircraft_reg} — наряд {wo_number}",
|
||||
})
|
||||
|
||||
|
||||
async def notify_wo_closed(wo_number: str, aircraft_reg: str, crs_by: str):
|
||||
"""Уведомление о закрытии наряда с CRS."""
|
||||
await ws_manager.broadcast("wo_closed_crs", {
|
||||
"wo_number": wo_number,
|
||||
"aircraft_reg": aircraft_reg,
|
||||
"crs_signed_by": crs_by,
|
||||
"message": f"✅ CRS: {aircraft_reg} — наряд {wo_number}",
|
||||
})
|
||||
|
||||
|
||||
async def notify_life_limit_critical(component: str, serial: str, remaining: dict):
|
||||
"""Уведомление об исчерпании ресурса."""
|
||||
await ws_manager.broadcast("life_limit_critical", {
|
||||
"component": component,
|
||||
"serial_number": serial,
|
||||
"remaining": remaining,
|
||||
"severity": "critical",
|
||||
"message": f"⏱️ КРИТИЧЕСКИЙ РЕСУРС: {component} S/N {serial}",
|
||||
})
|
||||
|
||||
|
||||
def make_notification(event: str, entity_type: str, entity_id: str, **extra: Any) -> dict:
|
||||
"""Payload для отправки по WebSocket (broadcast / send_to_org)."""
|
||||
return {"event": event, "entity_type": entity_type, "entity_id": entity_id, **extra}
|
||||
Loading…
Reference in New Issue
Block a user