Files
Ho Ngoc Hai 66f952a4a8 feat(ai-services): complete AVM v2 ensemble — upload endpoint, per-district metrics, A/B routing
- Add POST /avm/v2/upload-training-data so AvmRetrainCronService can push
  CSV rows before triggering retraining (the cron service already called
  this endpoint, but it did not exist)
- Add per-district MAE/MAPE/RMSE/R² to _evaluate_ensemble output;
  district_metrics are now returned in AVMv2TrainResponse and stored
  separately from global metrics in the model registry
- Add predict_with_ab() that applies the active model's ab_test_traffic_pct
  for deterministic per-property cohort assignment (v2 vs heuristic baseline)
- Add POST /avm/v2/ab-config to set traffic_pct on the active registry entry
- Add AVMv2ABConfigRequest schema
- Expand test suite: 24 → 28 tests covering upload, A/B config, and new
  validation paths; all green

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-04-21 04:39:57 +07:00

144 lines
5.4 KiB
Python

"""AVM v2 ensemble router — residential property valuation."""
from fastapi import APIRouter, HTTPException, Request
from app.models.avm_v2 import (
ABComparisonRequest,
ABComparisonResponse,
AVMv2ABConfigRequest,
AVMv2FeatureImportanceResponse,
AVMv2ModelInfo,
AVMv2PredictRequest,
AVMv2PredictResponse,
AVMv2RollbackRequest,
AVMv2TrainRequest,
AVMv2TrainResponse,
)
from app.services.avm_v2_service import avm_v2_service
# Every endpoint in this module is mounted under "/avm/v2"; the tag groups
# them together in FastAPI's generated OpenAPI schema.
router = APIRouter(
    tags=["AVM v2 Ensemble"],
    prefix="/avm/v2",
)
@router.post("/predict", response_model=AVMv2PredictResponse)
def predict_v2(req: AVMv2PredictRequest) -> AVMv2PredictResponse:
    """Value a residential property with the v2 multi-model ensemble.

    The ensemble blends XGBoost (weight 0.4), LightGBM (0.35) and
    CatBoost (0.25), and falls back to a heuristic when trained models
    are not available.

    If the active model has an A/B test running (``ab_test_traffic_pct > 0``),
    each property is deterministically assigned to a cohort. Inside the
    traffic slice, v2 prices it. Outside the slice, the heuristic baseline
    (v1-equivalent) prices it. The service also reports which path was
    used; this endpoint discards that flag and returns only the prediction.
    """
    prediction, _served_by_v2 = avm_v2_service.predict_with_ab(req)
    return prediction
@router.post("/train", response_model=AVMv2TrainResponse)
def train_v2(req: AVMv2TrainRequest) -> AVMv2TrainResponse:
    """Retrain the ensemble using Optuna hyperparameter search.

    The service performs these steps:
    1. Load training data from the model directory.
    2. Tune each ensemble member with Optuna.
    3. Save versioned artifacts.
    4. Register the new version in the model registry.
    """
    train_result = avm_v2_service.train(req)
    return train_result
@router.post("/compare-v1", response_model=ABComparisonResponse)
def compare_v1(req: ABComparisonRequest) -> ABComparisonResponse:
    """Score one property with both v1 (single model) and v2 (ensemble).

    Both models receive the same input. The response reports the price
    difference, the confidence delta, and which model is recommended.
    """
    comparison = avm_v2_service.compare_v1(req)
    return comparison
@router.get("/model-info", response_model=AVMv2ModelInfo)
def model_info_v2() -> AVMv2ModelInfo:
    """Return metadata describing the currently active ensemble model."""
    active_info = avm_v2_service.get_model_info()
    return active_info
@router.get("/feature-importance", response_model=AVMv2FeatureImportanceResponse)
def feature_importance_v2() -> AVMv2FeatureImportanceResponse:
    """Report global feature importance for the active ensemble.

    With trained boosters loaded, the ranking blends three sources:
    XGBoost gain (weight 0.4), LightGBM gain (0.35), and CatBoost
    importance (0.25). Without artifacts, the service returns a curated
    heuristic ranking instead.
    """
    importance = avm_v2_service.get_feature_importance()
    return importance
@router.get("/versions", response_model=list[AVMv2ModelInfo])
def list_versions() -> list[AVMv2ModelInfo]:
    """Return every registered model version with its metrics and status."""
    registered = avm_v2_service.list_versions()
    return registered
@router.post("/rollback", response_model=AVMv2ModelInfo)
def rollback(req: AVMv2RollbackRequest) -> AVMv2ModelInfo:
    """Roll back to a previously trained model version.

    The service copies the target version's artifacts into the active
    model directory, reloads the models, and updates the registry.

    Raises:
        HTTPException: 404 when the service rejects ``req.target_version``
            with a ``ValueError``. Presumably this means the version is not
            registered; confirm against ``avm_v2_service.rollback``.
    """
    try:
        return avm_v2_service.rollback(req.target_version)
    except ValueError as e:
        # Chain explicitly so logs/tracebacks keep the service-level cause.
        raise HTTPException(status_code=404, detail=str(e)) from e
@router.post("/upload-training-data", status_code=200)
async def upload_training_data(request: Request) -> dict:
    """Accept a CSV payload of training rows and persist it to the model directory.

    Called by the NestJS ``AvmRetrainCronService`` before triggering a retrain.
    The CSV must include a header row whose column names match the feature schema
    expected by ``AVMv2EnsembleService._prepare_training_data``. At minimum, a
    ``price_vnd`` column is required.

    The payload is stored as ``<settings.model_path>/training_data.csv``. It is
    first written to a uniquely named temporary file in the same directory,
    then moved into place with ``os.replace``. This way a concurrent reader
    sees either the previous file or the new one, never a partial write.
    A leading UTF-8 byte-order mark, if present, is not written out.

    Returns:
        ``{"rows_received": <number of data rows>, "destination": <path>}``.
        Rows whose cells are all empty or whitespace are not counted.

    Raises:
        HTTPException: 400 when the body is empty, is not valid UTF-8, cannot
            be parsed as CSV, has no data row after the header, or lacks a
            ``price_vnd`` column.
    """
    import csv
    import io
    import os
    import uuid
    from pathlib import Path

    from app.config import settings

    body = await request.body()
    if not body:
        raise HTTPException(status_code=400, detail="Empty request body")

    try:
        # "utf-8-sig" strips a leading BOM (common in spreadsheet exports).
        # Otherwise the BOM would stick to the first column name and break
        # the header check.
        text = body.decode("utf-8-sig")
    except UnicodeDecodeError as exc:
        raise HTTPException(status_code=400, detail=f"Could not decode CSV as UTF-8: {exc}") from exc

    # Use the csv module rather than str.split so that quoted header names and
    # newlines inside quoted fields are parsed the way a CSV reader downstream
    # would parse them. newline="" is what the csv docs require for the
    # reader to handle \r, \n and \r\n terminators itself.
    try:
        non_blank_rows = (
            row
            for row in csv.reader(io.StringIO(text, newline=""))
            if any(cell.strip() for cell in row)
        )
        header = next(non_blank_rows, None)
        data_row_count = sum(1 for _ in non_blank_rows)
    except csv.Error as exc:
        raise HTTPException(status_code=400, detail=f"Malformed CSV: {exc}") from exc

    if header is None or data_row_count < 1:
        raise HTTPException(status_code=400, detail="CSV must contain header + at least one data row")
    if "price_vnd" not in header:
        raise HTTPException(status_code=400, detail="CSV missing required column: price_vnd")

    model_dir = Path(settings.model_path)
    model_dir.mkdir(parents=True, exist_ok=True)
    dest = model_dir / "training_data.csv"
    # Unique temp name so concurrent uploads cannot clobber each other's
    # partial writes; the last os.replace wins.
    tmp = model_dir / f".{dest.name}.{uuid.uuid4().hex}.tmp"
    try:
        # write_bytes avoids platform newline translation (write_text would
        # turn CRLF input into \r\r\n on Windows).
        tmp.write_bytes(text.encode("utf-8"))
        os.replace(tmp, dest)
    finally:
        # No-op after a successful replace; removes the temp file on failure.
        tmp.unlink(missing_ok=True)
    return {"rows_received": data_row_count, "destination": str(dest)}
@router.post("/ab-config", response_model=AVMv2ModelInfo)
def set_ab_config(req: AVMv2ABConfigRequest) -> AVMv2ModelInfo:
    """Update the A/B test traffic percentage for the active model.

    Examples:
    - ``traffic_pct=0.10`` routes 10% of predict calls to v2; the rest go to
      the heuristic baseline.
    - ``traffic_pct=1.0`` switches all traffic to v2.
    - ``traffic_pct=0.0`` disables the A/B split entirely, so v2 serves every
      call (the same outcome as 1.0).

    The registry is persisted to disk so the setting survives restarts.

    Raises:
        HTTPException: 404 when the service raises ``ValueError``. Presumably
            this happens when there is no active registry entry; confirm
            against ``avm_v2_service.set_ab_traffic``.
    """
    try:
        return avm_v2_service.set_ab_traffic(req.traffic_pct)
    except ValueError as e:
        # Chain explicitly so logs/tracebacks keep the service-level cause.
        raise HTTPException(status_code=404, detail=str(e)) from e