feat(ai-services): complete AVM v2 ensemble — upload endpoint, per-district metrics, A/B routing

- Add POST /avm/v2/upload-training-data so AvmRetrainCronService can push
  CSV rows before triggering retraining (the cron service already called
  this endpoint, but the endpoint did not exist)
- Add per-district MAE/MAPE/RMSE/R² to _evaluate_ensemble output;
  district_metrics are now returned in AVMv2TrainResponse and stored
  separately from global metrics in the model registry
- Add predict_with_ab() that applies the active model's ab_test_traffic_pct
  for deterministic per-property cohort assignment (v2 vs heuristic baseline)
- Add POST /avm/v2/ab-config to set traffic_pct on the active registry entry
- Add AVMv2ABConfigRequest schema
- Expand test suite: 24 → 28 tests covering upload, A/B config, and new
  validation paths; all green

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
Ho Ngoc Hai
2026-04-21 04:39:57 +07:00
parent 9cefd439db
commit 66f952a4a8
4 changed files with 224 additions and 8 deletions

View File

@@ -1,10 +1,11 @@
"""AVM v2 ensemble router — residential property valuation."""
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Request
from app.models.avm_v2 import (
ABComparisonRequest,
ABComparisonResponse,
AVMv2ABConfigRequest,
AVMv2FeatureImportanceResponse,
AVMv2ModelInfo,
AVMv2PredictRequest,
@@ -24,8 +25,14 @@ def predict_v2(req: AVMv2PredictRequest) -> AVMv2PredictResponse:
Ensemble: XGBoost (0.4) + LightGBM (0.35) + CatBoost (0.25).
Falls back to heuristic when trained models are not available.
When an A/B test is active (``ab_test_traffic_pct > 0`` on the active
model), a deterministic per-property cohort assignment decides whether
the request is served by v2 (within the traffic slice) or by the
heuristic baseline (v1-equivalent, outside the slice).
"""
return avm_v2_service.predict(req)
response, _used_v2 = avm_v2_service.predict_with_ab(req)
return response
@router.post("/train", response_model=AVMv2TrainResponse)
@@ -83,3 +90,54 @@ def rollback(req: AVMv2RollbackRequest) -> AVMv2ModelInfo:
return avm_v2_service.rollback(req.target_version)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@router.post("/upload-training-data", status_code=200)
async def upload_training_data(request: Request) -> dict:
    """Accept a CSV payload of training rows and persist it to the model directory.

    Called by the NestJS ``AvmRetrainCronService`` before triggering a retrain.
    The CSV must include a header row whose column names match the feature schema
    expected by ``AVMv2EnsembleService._prepare_training_data``.

    The raw request body is decoded as UTF-8 (a leading byte-order mark, as
    emitted by some spreadsheet exports, is tolerated and dropped), checked for
    a header plus at least one non-blank data line and for the required
    ``price_vnd`` column, then written to ``training_data.csv`` under
    ``settings.model_path``.

    Args:
        request: Incoming request; its raw body is the CSV text.

    Returns:
        A dict with ``rows_received`` (number of non-blank lines after the
        header) and ``destination`` (path of the written file, as a string).

    Raises:
        HTTPException: 400 when the body is empty, is not valid UTF-8, has
            fewer than two non-blank lines, has an unparsable header, or lacks
            the ``price_vnd`` column.
    """
    import csv
    from pathlib import Path

    from app.config import settings

    body = await request.body()
    if not body:
        raise HTTPException(status_code=400, detail="Empty request body")

    # Keep the try body to the single call that can raise UnicodeDecodeError.
    # "utf-8-sig" behaves like "utf-8" but strips a leading BOM, which would
    # otherwise stay attached to the first column name and break the
    # ``price_vnd`` check when that column comes first.
    try:
        text = body.decode("utf-8-sig")
    except UnicodeDecodeError as exc:
        raise HTTPException(
            status_code=400, detail=f"Could not decode CSV as UTF-8: {exc}"
        ) from exc

    # Validate it looks like CSV (has at least a header + one data row).
    lines = [ln for ln in text.splitlines() if ln.strip()]
    if len(lines) < 2:
        raise HTTPException(
            status_code=400,
            detail="CSV must contain header + at least one data row",
        )

    # Parse the header with the csv module so quoted column names
    # (e.g. ``"price_vnd"``) are recognised; a bare ``split(",")`` keeps the
    # surrounding quotes and rejects an otherwise valid file.
    try:
        header = next(csv.reader([lines[0]]), [])
    except csv.Error as exc:
        raise HTTPException(
            status_code=400, detail=f"Could not parse CSV header: {exc}"
        ) from exc
    if "price_vnd" not in header:
        raise HTTPException(
            status_code=400, detail="CSV missing required column: price_vnd"
        )

    model_dir = Path(settings.model_path)
    model_dir.mkdir(parents=True, exist_ok=True)
    dest = model_dir / "training_data.csv"
    dest.write_text(text, encoding="utf-8")

    return {"rows_received": len(lines) - 1, "destination": str(dest)}
@router.post("/ab-config", response_model=AVMv2ModelInfo)
def set_ab_config(req: AVMv2ABConfigRequest) -> AVMv2ModelInfo:
    """Set the A/B traffic share on the active model and return its info.

    ``req.traffic_pct`` is forwarded unchanged to
    ``avm_v2_service.set_ab_traffic``. Documented values:

    * ``0.10`` routes 10% of predict calls to v2.
    * ``1.0`` switches all traffic to v2.
    * ``0.0`` runs v2 for all calls with no split.

    The registry is persisted to disk so the setting survives restarts.

    Raises:
        HTTPException: 404 when the service raises ``ValueError``
            (presumably because there is no active model to configure;
            confirm against ``set_ab_traffic``).
    """
    try:
        updated_info = avm_v2_service.set_ab_traffic(req.traffic_pct)
    except ValueError as err:
        # Surface the service's message to the client as a "not found" error.
        raise HTTPException(status_code=404, detail=str(err))
    return updated_info