- Add POST /avm/v2/upload-training-data so AvmRetrainCronService can push CSV rows before triggering retraining (the cron already called this endpoint, but it did not exist).
- Add per-district MAE/MAPE/RMSE/R² to the _evaluate_ensemble output; district_metrics are now returned in AVMv2TrainResponse and stored separately from the global metrics in the model registry.
- Add predict_with_ab(), which applies the active model's ab_test_traffic_pct for deterministic per-property cohort assignment (v2 vs. heuristic baseline).
- Add POST /avm/v2/ab-config to set traffic_pct on the active registry entry.
- Add the AVMv2ABConfigRequest schema.
- Expand the test suite from 24 to 28 tests covering upload, A/B config, and new validation paths; all green.

Co-Authored-By: Paperclip <noreply@paperclip.ing>
144 lines
5.4 KiB
Python
144 lines
5.4 KiB
Python
"""AVM v2 ensemble router — residential property valuation."""
|
|
|
|
from fastapi import APIRouter, HTTPException, Request
|
|
|
|
from app.models.avm_v2 import (
|
|
ABComparisonRequest,
|
|
ABComparisonResponse,
|
|
AVMv2ABConfigRequest,
|
|
AVMv2FeatureImportanceResponse,
|
|
AVMv2ModelInfo,
|
|
AVMv2PredictRequest,
|
|
AVMv2PredictResponse,
|
|
AVMv2RollbackRequest,
|
|
AVMv2TrainRequest,
|
|
AVMv2TrainResponse,
|
|
)
|
|
from app.services.avm_v2_service import avm_v2_service
|
|
|
|
# Every endpoint in this module is served under the /avm/v2 path prefix and
# grouped under a single OpenAPI tag.
router = APIRouter(
    tags=["AVM v2 Ensemble"],
    prefix="/avm/v2",
)
|
|
|
|
|
|
@router.post("/predict", response_model=AVMv2PredictResponse)
def predict_v2(req: AVMv2PredictRequest) -> AVMv2PredictResponse:
    """Value a residential property with the v2 multi-model ensemble.

    The ensemble blends XGBoost (0.4), LightGBM (0.35) and CatBoost (0.25);
    a heuristic fallback is used whenever trained models are unavailable.

    If the active model has an A/B test running (``ab_test_traffic_pct > 0``),
    each property is deterministically assigned to a cohort. Properties inside
    the traffic slice are served by v2. Properties outside it are served by the
    heuristic (v1-equivalent) baseline. The service also reports which cohort
    served the request, but that flag is not exposed in the response.
    """
    prediction, _served_by_v2 = avm_v2_service.predict_with_ab(req)
    return prediction
|
|
|
|
|
|
@router.post("/train", response_model=AVMv2TrainResponse)
def train_v2(req: AVMv2TrainRequest) -> AVMv2TrainResponse:
    """Retrain the ensemble using Optuna hyperparameter search.

    The service loads training data from the model directory, tunes every
    ensemble member with Optuna, saves versioned artifacts and registers the
    new version in the model registry. The retrain cron calls
    ``POST /avm/v2/upload-training-data`` before this endpoint to supply
    fresh rows.
    """
    training_result = avm_v2_service.train(req)
    return training_result
|
|
|
|
|
|
@router.post("/compare-v1", response_model=ABComparisonResponse)
def compare_v1(req: ABComparisonRequest) -> ABComparisonResponse:
    """Run v1 (single model) and v2 (ensemble) on one property and compare them.

    The response reports the price difference between the two predictions,
    the confidence delta, and a recommendation for which model to prefer.
    """
    comparison = avm_v2_service.compare_v1(req)
    return comparison
|
|
|
|
|
|
@router.get("/model-info", response_model=AVMv2ModelInfo)
def model_info_v2() -> AVMv2ModelInfo:
    """Return metadata for the ensemble model that is currently active."""
    active_info = avm_v2_service.get_model_info()
    return active_info
|
|
|
|
|
|
@router.get("/feature-importance", response_model=AVMv2FeatureImportanceResponse)
def feature_importance_v2() -> AVMv2FeatureImportanceResponse:
    """Return global feature importance for the active ensemble.

    When trained boosters are loaded, per-model importances are combined with
    the ensemble weights: XGBoost gain (0.4), LightGBM gain (0.35) and CatBoost
    importance (0.25). Without model artifacts, a curated heuristic ranking is
    returned instead.
    """
    importance = avm_v2_service.get_feature_importance()
    return importance
|
|
|
|
|
|
@router.get("/versions", response_model=list[AVMv2ModelInfo])
def list_versions() -> list[AVMv2ModelInfo]:
    """Return every registered model version together with its metrics and status."""
    registered_versions = avm_v2_service.list_versions()
    return registered_versions
|
|
|
|
|
|
@router.post("/rollback", response_model=AVMv2ModelInfo)
def rollback(req: AVMv2RollbackRequest) -> AVMv2ModelInfo:
    """Roll back to a previously trained model version.

    Copies the target version's artifacts to the active model directory,
    reloads the models, and updates the registry.

    Raises:
        HTTPException(404): ``avm_v2_service.rollback`` raised ``ValueError``
            for ``req.target_version``. This presumably means the version is
            unknown; confirm against the service implementation.
    """
    try:
        return avm_v2_service.rollback(req.target_version)
    except ValueError as exc:
        # Explicit chaining marks this as a deliberate translation of the
        # service error and keeps its traceback as the direct cause.
        raise HTTPException(status_code=404, detail=str(exc)) from exc
|
|
|
|
|
|
def _write_text_atomically(dest, text: str) -> None:
    """Write *text* to *dest* (a ``pathlib.Path``) so readers never see a partial file.

    The text is written to a uniquely named temp file in the same directory.
    That file is then renamed over *dest* with ``os.replace``, which is atomic
    on POSIX when both paths are on the same filesystem. If anything fails
    before the rename, the temp file is removed and the error is re-raised.
    """
    import contextlib
    import os
    import uuid

    dest.parent.mkdir(parents=True, exist_ok=True)
    # A unique name keeps concurrent uploads from writing into the same temp file.
    tmp = dest.with_name(f".{dest.name}.{uuid.uuid4().hex}.tmp")
    try:
        # newline="" writes the decoded text verbatim (no newline translation).
        with tmp.open("x", encoding="utf-8", newline="") as fh:
            fh.write(text)
        os.replace(tmp, dest)
    except BaseException:
        # Best-effort cleanup; never mask the original error.
        with contextlib.suppress(OSError):
            tmp.unlink()
        raise


@router.post("/upload-training-data", status_code=200)
async def upload_training_data(request: Request) -> dict:
    """Accept a CSV payload of training rows and persist it to the model directory.

    Called by the NestJS ``AvmRetrainCronService`` before triggering a retrain.
    The CSV must include a header row whose column names match the feature schema
    expected by ``AVMv2EnsembleService._prepare_training_data``. This endpoint
    itself only checks that a ``price_vnd`` column is present.

    The body is parsed with the ``csv`` module, so quoted headers, embedded
    commas and newlines, and a leading UTF-8 BOM are handled. It is then written
    atomically to ``<settings.model_path>/training_data.csv``, replacing any
    previous upload.

    Returns:
        ``{"rows_received": <non-blank data rows>, "destination": <path written>}``

    Raises:
        HTTPException(400): if any of the following holds:

            * the body is empty;
            * the body is not valid UTF-8;
            * the CSV is malformed;
            * there is no data row after the header;
            * the ``price_vnd`` column is missing.
    """
    import asyncio
    import csv
    import io
    from pathlib import Path

    from app.config import settings

    body = await request.body()
    if not body:
        raise HTTPException(status_code=400, detail="Empty request body")

    try:
        # "utf-8-sig" drops a leading BOM (common in spreadsheet exports) that
        # would otherwise be glued onto the first header name.
        text = body.decode("utf-8-sig")
    except UnicodeDecodeError as exc:
        raise HTTPException(status_code=400, detail=f"Could not decode CSV as UTF-8: {exc}") from exc

    try:
        # Blank / whitespace-only records don't count as rows. Rows are
        # consumed lazily, so a large upload is not duplicated in memory.
        records = (
            row
            for row in csv.reader(io.StringIO(text, newline=""))
            if any(field.strip() for field in row)
        )
        header = next(records, None)
        data_rows = sum(1 for _ in records)
    except csv.Error as exc:
        raise HTTPException(status_code=400, detail=f"Malformed CSV: {exc}") from exc

    if header is None or data_rows == 0:
        raise HTTPException(status_code=400, detail="CSV must contain header + at least one data row")
    if "price_vnd" not in {name.strip() for name in header}:
        raise HTTPException(status_code=400, detail="CSV missing required column: price_vnd")

    dest = Path(settings.model_path) / "training_data.csv"
    # Disk I/O blocks; run it in a worker thread so this async endpoint does
    # not stall the event loop on large uploads.
    await asyncio.to_thread(_write_text_atomically, dest, text)

    return {"rows_received": data_rows, "destination": str(dest)}
|
|
|
|
|
|
@router.post("/ab-config", response_model=AVMv2ModelInfo)
def set_ab_config(req: AVMv2ABConfigRequest) -> AVMv2ModelInfo:
    """Update the A/B test traffic share (``ab_test_traffic_pct``) of the active model.

    ``traffic_pct`` is a fraction of predict traffic, assigned per property:

    * ``0.10`` routes 10% of predict calls to v2 and the rest to the
      heuristic baseline.
    * ``1.0`` switches all traffic to v2.
    * ``0.0`` disables the A/B split altogether. The split is only active when
      the value is ``> 0``, so no cohort assignment happens and v2 serves
      every call.

    The registry is persisted to disk so the setting survives restarts.

    Raises:
        HTTPException(404): ``avm_v2_service.set_ab_traffic`` raised
            ``ValueError``. This presumably means there is no active model
            to configure; confirm against the service implementation.
    """
    try:
        return avm_v2_service.set_ab_traffic(req.traffic_pct)
    except ValueError as exc:
        # Explicit chaining keeps the service-level error as the direct cause.
        raise HTTPException(status_code=404, detail=str(exc)) from exc
|