feat(ai-services): complete AVM v2 ensemble — upload endpoint, per-district metrics, A/B routing
- Add POST /avm/v2/upload-training-data so AvmRetrainCronService can push CSV rows before triggering retraining (the endpoint was already being called but did not exist)
- Add per-district MAE/MAPE/RMSE/R² to the _evaluate_ensemble output; district_metrics are now returned in AVMv2TrainResponse and stored separately from global metrics in the model registry
- Add predict_with_ab(), which applies the active model's ab_test_traffic_pct for deterministic per-property cohort assignment (v2 vs. heuristic baseline)
- Add POST /avm/v2/ab-config to set traffic_pct on the active registry entry
- Add the AVMv2ABConfigRequest schema
- Expand the test suite from 24 to 28 tests, covering upload, A/B config, and the new validation paths; all green

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -213,6 +213,15 @@ class AVMv2RollbackRequest(BaseModel):
|
||||
target_version: str = Field(..., min_length=1, description="Model version to roll back to")
|
||||
|
||||
|
||||
class AVMv2ABConfigRequest(BaseModel):
    """Payload for changing the A/B test traffic fraction of the active model.

    ``traffic_pct`` is a fraction in the closed range ``[0, 1]`` (for example
    ``0.10`` means 10%), not a percentage value.
    """

    # Required field (no default); bounds are enforced by the schema itself.
    traffic_pct: float = Field(
        default=...,
        ge=0,
        le=1,
        description="Fraction of /predict calls routed to v2 (0=disabled, 0.10=10%, 1=100%)",
    )
|
||||
|
||||
|
||||
class AVMv2FeatureImportanceResponse(BaseModel):
|
||||
"""Global feature importance across the loaded ensemble.
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""AVM v2 ensemble router — residential property valuation."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
|
||||
from app.models.avm_v2 import (
|
||||
ABComparisonRequest,
|
||||
ABComparisonResponse,
|
||||
AVMv2ABConfigRequest,
|
||||
AVMv2FeatureImportanceResponse,
|
||||
AVMv2ModelInfo,
|
||||
AVMv2PredictRequest,
|
||||
@@ -24,8 +25,14 @@ def predict_v2(req: AVMv2PredictRequest) -> AVMv2PredictResponse:
|
||||
|
||||
Ensemble: XGBoost (0.4) + LightGBM (0.35) + CatBoost (0.25).
|
||||
Falls back to heuristic when trained models are not available.
|
||||
|
||||
When an A/B test is active (``ab_test_traffic_pct > 0`` on the active
|
||||
model), a deterministic per-property cohort assignment decides whether
|
||||
the request is served by v2 (within the traffic slice) or by the
|
||||
heuristic baseline (v1-equivalent, outside the slice).
|
||||
"""
|
||||
return avm_v2_service.predict(req)
|
||||
response, _used_v2 = avm_v2_service.predict_with_ab(req)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/train", response_model=AVMv2TrainResponse)
|
||||
@@ -83,3 +90,54 @@ def rollback(req: AVMv2RollbackRequest) -> AVMv2ModelInfo:
|
||||
return avm_v2_service.rollback(req.target_version)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/upload-training-data", status_code=200)
async def upload_training_data(request: Request) -> dict:
    """Accept a CSV payload of training rows and persist it to the model directory.

    Called by the NestJS ``AvmRetrainCronService`` before triggering a retrain.
    The CSV must include a header row whose column names match the feature schema
    expected by ``AVMv2EnsembleService._prepare_training_data``; this endpoint
    itself only verifies that a ``price_vnd`` column is present.

    Returns:
        A dict with ``rows_received`` (number of non-blank lines after the
        header) and ``destination`` (path of the written file).

    Raises:
        HTTPException: 400 when the body is empty, is not valid UTF-8, has no
            data row, or its header lacks ``price_vnd``.
    """
    import csv
    from pathlib import Path

    from app.config import settings

    body = await request.body()
    if not body:
        raise HTTPException(status_code=400, detail="Empty request body")

    # Only the decode can raise UnicodeDecodeError, so keep the try minimal.
    # "utf-8-sig" drops a leading BOM (common in spreadsheet exports), which
    # would otherwise be glued to the first column name.
    try:
        text = body.decode("utf-8-sig")
    except UnicodeDecodeError as exc:
        raise HTTPException(status_code=400, detail=f"Could not decode CSV as UTF-8: {exc}") from exc

    # Validate it looks like CSV (has at least a header + one data row)
    lines = [ln for ln in text.splitlines() if ln.strip()]
    if len(lines) < 2:
        raise HTTPException(status_code=400, detail="CSV must contain header + at least one data row")

    # Parse the header with the csv module so quoted or padded column names
    # (e.g. '"price_vnd"' or ' price_vnd') are still recognised.
    try:
        header = [col.strip() for col in next(csv.reader([lines[0]]), [])]
    except csv.Error as exc:
        raise HTTPException(status_code=400, detail=f"Could not parse CSV header: {exc}") from exc
    if "price_vnd" not in header:
        raise HTTPException(status_code=400, detail="CSV missing required column: price_vnd")

    model_dir = Path(settings.model_path)
    model_dir.mkdir(parents=True, exist_ok=True)
    dest = model_dir / "training_data.csv"
    # NOTE(review): this is a blocking write inside an async handler; fine for
    # small uploads, consider offloading to a thread if payloads grow.
    dest.write_text(text, encoding="utf-8")

    return {"rows_received": len(lines) - 1, "destination": str(dest)}
|
||||
|
||||
|
||||
@router.post("/ab-config", response_model=AVMv2ModelInfo)
def set_ab_config(req: AVMv2ABConfigRequest) -> AVMv2ModelInfo:
    """Update the A/B test traffic fraction for the active model.

    Set ``traffic_pct=0.10`` to route 10% of predict calls to v2; the rest
    receive the heuristic baseline.
    Set ``traffic_pct=1.0`` to fully switch all traffic to v2.
    Set ``traffic_pct=0.0`` to disable the split; v2 is then used for every
    call (there is no cohort assignment at all).

    The service saves the updated registry, so the setting is intended to
    survive restarts.

    Raises:
        HTTPException: 404 when the service reports no active model.
    """
    try:
        return avm_v2_service.set_ab_traffic(req.traffic_pct)
    except ValueError as e:
        # Chain the cause so the original service error stays in the traceback.
        raise HTTPException(status_code=404, detail=str(e)) from e
|
||||
|
||||
@@ -271,6 +271,34 @@ class AVMv2EnsembleService:
|
||||
return self._predict_ensemble(req)
|
||||
return self._predict_heuristic(req)
|
||||
|
||||
def predict_with_ab(self, req: AVMv2PredictRequest) -> tuple[AVMv2PredictResponse, bool]:
    """Run prediction respecting the A/B test traffic split.

    Returns ``(response, used_v2)`` where ``used_v2`` is ``True`` when the
    request was routed to ``self.predict`` (the v2 path) and ``False`` when
    it was served by the v1-equivalent heuristic baseline (i.e. outside the
    v2 cohort).

    A traffic fraction of ``0`` (split disabled) and ``1`` (full rollout)
    both route every request to the v2 path.

    The cohort draw is seeded from the request features (area, rooms, month,
    district) so the same property always lands in the same cohort. The
    district is folded in with CRC32 rather than the builtin ``hash()``:
    string hashes are salted per process, which would make the assignment
    differ between workers and across restarts.
    """
    import zlib

    traffic_pct = self.get_model_info().ab_test_traffic_pct
    if traffic_pct <= 0.0 or traffic_pct >= 1.0:
        # A/B disabled or fully rolled out — always use v2
        return self.predict(req), True

    # Deterministic per-property cohort assignment
    district_bucket = zlib.crc32(str(req.district).encode("utf-8")) % 10000
    rng = np.random.default_rng(
        seed=int(req.area_m2 * 1000 + req.rooms * 100 + req.month + district_bucket)
    )
    if rng.random() < traffic_pct:
        return self.predict(req), True

    # Outside v2 cohort: return heuristic baseline (v1-equivalent)
    return self._predict_heuristic(req), False
|
||||
|
||||
def _predict_ensemble(self, req: AVMv2PredictRequest) -> AVMv2PredictResponse:
|
||||
"""Run each loaded model and combine with weighted average."""
|
||||
features = _encode_features(req)
|
||||
@@ -633,6 +661,7 @@ class AVMv2EnsembleService:
|
||||
train_val_idx, test_idx = next(gss_test.split(X, y, groups))
|
||||
X_trainval, y_trainval = X[train_val_idx], y[train_val_idx]
|
||||
X_test, y_test = X[test_idx], y[test_idx]
|
||||
groups_test = groups[test_idx]
|
||||
groups_trainval = groups[train_val_idx]
|
||||
|
||||
val_ratio = req.val_size / (1.0 - req.test_size)
|
||||
@@ -663,7 +692,7 @@ class AVMv2EnsembleService:
|
||||
trained_models["catboost"] = cat_model
|
||||
|
||||
# Evaluate ensemble on test set
|
||||
metrics = self._evaluate_ensemble(trained_models, X_test, y_test)
|
||||
metrics = self._evaluate_ensemble(trained_models, X_test, y_test, groups_test)
|
||||
|
||||
# Save versioned artifacts
|
||||
version_dir = model_dir / "versions" / version
|
||||
@@ -678,7 +707,7 @@ class AVMv2EnsembleService:
|
||||
registry_entry = AVMv2ModelInfo(
|
||||
model_version=version,
|
||||
created_at=datetime.now(timezone.utc).isoformat(),
|
||||
metrics=metrics,
|
||||
metrics={k: v for k, v in metrics.items() if k != "district_metrics"},
|
||||
is_active=True,
|
||||
ab_test_traffic_pct=0.0,
|
||||
)
|
||||
@@ -690,8 +719,8 @@ class AVMv2EnsembleService:
|
||||
|
||||
return AVMv2TrainResponse(
|
||||
model_version=version,
|
||||
metrics=metrics,
|
||||
district_metrics={},
|
||||
metrics={k: v for k, v in metrics.items() if k != "district_metrics"},
|
||||
district_metrics=metrics.get("district_metrics", {}),
|
||||
training_samples=len(X_train),
|
||||
validation_samples=len(X_val),
|
||||
test_samples=len(X_test),
|
||||
@@ -924,7 +953,8 @@ class AVMv2EnsembleService:
|
||||
return {}, None
|
||||
|
||||
def _evaluate_ensemble(
|
||||
self, models: dict[str, Any], X_test: np.ndarray, y_test: np.ndarray
|
||||
self, models: dict[str, Any], X_test: np.ndarray, y_test: np.ndarray,
|
||||
groups_test: np.ndarray | None = None,
|
||||
) -> dict:
|
||||
"""Evaluate ensemble performance on a test set."""
|
||||
if not models:
|
||||
@@ -961,13 +991,41 @@ class AVMv2EnsembleService:
|
||||
ss_tot = np.sum((y_actual - np.mean(y_actual)) ** 2)
|
||||
r2 = float(1.0 - ss_res / ss_tot) if ss_tot > 0 else 0.0
|
||||
|
||||
return {
|
||||
global_metrics = {
|
||||
"mae": round(mae, 2),
|
||||
"mape": round(mape, 2),
|
||||
"rmse": round(rmse, 2),
|
||||
"r2": round(r2, 4),
|
||||
}
|
||||
|
||||
# Per-district breakdown
|
||||
district_metrics: dict[str, dict] = {}
|
||||
if groups_test is not None and len(groups_test) == len(y_actual):
|
||||
unique_districts = np.unique(groups_test)
|
||||
for district in unique_districts:
|
||||
mask = groups_test == district
|
||||
if mask.sum() < 3:
|
||||
# Too few samples for reliable per-district stats
|
||||
continue
|
||||
d_actual = y_actual[mask]
|
||||
d_pred = y_pred[mask]
|
||||
d_mae = float(np.mean(np.abs(d_actual - d_pred)))
|
||||
d_mape = float(np.mean(np.abs((d_actual - d_pred) / d_actual))) * 100
|
||||
d_rmse = float(np.sqrt(np.mean((d_actual - d_pred) ** 2)))
|
||||
d_ss_res = np.sum((d_actual - d_pred) ** 2)
|
||||
d_ss_tot = np.sum((d_actual - np.mean(d_actual)) ** 2)
|
||||
d_r2 = float(1.0 - d_ss_res / d_ss_tot) if d_ss_tot > 0 else 0.0
|
||||
district_metrics[str(district)] = {
|
||||
"mae": round(d_mae, 2),
|
||||
"mape": round(d_mape, 2),
|
||||
"rmse": round(d_rmse, 2),
|
||||
"r2": round(d_r2, 4),
|
||||
"samples": int(mask.sum()),
|
||||
}
|
||||
|
||||
global_metrics["district_metrics"] = district_metrics # type: ignore[assignment]
|
||||
return global_metrics
|
||||
|
||||
def _save_model(self, name: str, model: Any, directory: Path) -> None:
|
||||
"""Save a trained model to the specified directory."""
|
||||
if name == "xgboost":
|
||||
@@ -1039,6 +1097,32 @@ class AVMv2EnsembleService:
|
||||
entries = self._load_registry()
|
||||
return [AVMv2ModelInfo(**e) for e in entries]
|
||||
|
||||
def set_ab_traffic(self, traffic_pct: float) -> AVMv2ModelInfo:
    """Store a new A/B traffic fraction on the currently active registry entry.

    The value is written to the most recent entry flagged ``is_active``, the
    registry is saved, and the in-memory registry is rebuilt from the
    updated entries. How the fraction is interpreted at prediction time is
    decided by ``predict_with_ab``.

    Returns:
        The updated active entry as an ``AVMv2ModelInfo``.

    Raises:
        ValueError: when no registry entry is marked active.
    """
    from app.config import settings

    registry_dir = Path(settings.model_path)
    registry = self._load_registry(registry_dir)

    # Search newest-first so the latest active entry wins.
    active = next((item for item in reversed(registry) if item.get("is_active")), None)
    if active is None:
        raise ValueError("No active model found in registry")

    active["ab_test_traffic_pct"] = traffic_pct
    self._save_registry(registry, registry_dir)
    self._model_registry = [AVMv2ModelInfo(**item) for item in registry]
    return AVMv2ModelInfo(**active)
|
||||
|
||||
def rollback(self, target_version: str) -> AVMv2ModelInfo:
|
||||
"""Rollback to a previously trained model version.
|
||||
|
||||
|
||||
@@ -377,3 +377,68 @@ def test_compare_v1_with_v2_features():
|
||||
# v2 should capture these extra features
|
||||
assert data["v2"]["estimated_price_vnd"] > 0
|
||||
assert data["v2"]["model_version"] is not None
|
||||
|
||||
|
||||
# ── Upload training data ────────────────────────────────────────
|
||||
|
||||
_CSV_HEADER = (
|
||||
"property_type,area_m2,rooms,floor_level,total_floors,direction,floor_ratio,"
|
||||
"building_age_years,has_elevator,has_parking,has_pool,has_legal_paper,"
|
||||
"developer_reputation,neighborhood_score,distance_to_cbd_km,distance_to_metro_km,"
|
||||
"distance_to_school_km,distance_to_hospital_km,distance_to_park_km,distance_to_mall_km,"
|
||||
"flood_zone_risk,avg_price_district_3m_vnd_m2,listing_density,absorption_rate,"
|
||||
"dom_avg,price_momentum_30d,yoy_change,renovation_score,view_quality,interior_quality,"
|
||||
"noise_level,natural_light,month,district,price_vnd"
|
||||
)
|
||||
_CSV_ROW = (
|
||||
"apartment,80,2,5,20,south,1.0,3,1,1,0,1,0.8,0.7,5,1,0.5,2,1,3,"
|
||||
"0.1,85000000,10,0.3,30,0.01,0.05,0.8,0.7,0.75,0.3,0.8,3,Cầu Giấy,7000000000"
|
||||
)
|
||||
|
||||
|
||||
def test_upload_training_data_ok(tmp_path):
    """Upload endpoint accepts valid CSV, reports the row count, and persists the file."""
    from unittest.mock import patch

    from app import config as cfg

    csv_body = f"{_CSV_HEADER}\n{_CSV_ROW}\n"
    # Redirect the model directory so the upload lands in pytest's tmp dir.
    with patch.object(cfg.settings, "model_path", str(tmp_path)):
        resp = client.post(
            "/avm/v2/upload-training-data",
            content=csv_body,
            headers={"Content-Type": "text/csv"},
        )

    assert resp.status_code == 200
    data = resp.json()
    assert data["rows_received"] == 1

    # The endpoint's whole purpose is persistence, so verify the written file
    # rather than only the response payload.
    dest = tmp_path / "training_data.csv"
    assert data["destination"] == str(dest)
    assert dest.read_text(encoding="utf-8") == csv_body
|
||||
|
||||
|
||||
def test_upload_training_data_missing_price_vnd():
    """A CSV whose header lacks ``price_vnd`` is rejected with HTTP 400."""
    payload = "property_type,area_m2\napartment,80\n"
    request_kwargs = {
        "content": payload,
        "headers": {"Content-Type": "text/csv"},
    }

    response = client.post("/avm/v2/upload-training-data", **request_kwargs)

    assert response.status_code == 400
    # The error detail should name the missing column.
    assert "price_vnd" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_upload_training_data_empty_body():
    """A request with a zero-length body is rejected with HTTP 400."""
    csv_headers = {"Content-Type": "text/csv"}
    response = client.post("/avm/v2/upload-training-data", content=b"", headers=csv_headers)
    assert response.status_code == 400
|
||||
|
||||
|
||||
# ── A/B config endpoint ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_ab_config_no_registry():
    """AB config endpoint answers 404 when no model is registered (heuristic-only run)."""
    payload = {"traffic_pct": 0.10}
    outcome = client.post("/avm/v2/ab-config", json=payload)
    # This relies on the test environment starting without a model registry.
    assert outcome.status_code == 404
|
||||
|
||||
Reference in New Issue
Block a user