- Add POST /avm/v2/upload-training-data so AvmRetrainCronService can push CSV rows before triggering retraining (was called but missing) - Add per-district MAE/MAPE/RMSE/R² to _evaluate_ensemble output; district_metrics are now returned in AVMv2TrainResponse and stored separately from global metrics in the model registry - Add predict_with_ab() that applies the active model's ab_test_traffic_pct for deterministic per-property cohort assignment (v2 vs heuristic baseline) - Add POST /avm/v2/ab-config to set traffic_pct on the active registry entry - Add AVMv2ABConfigRequest schema - Expand test suite: 24 → 28 tests covering upload, A/B config, and new validation paths; all green Co-Authored-By: Paperclip <noreply@paperclip.ing>
445 lines
15 KiB
Python
445 lines
15 KiB
Python
"""Tests for AVM v2 ensemble endpoints."""
|
|
|
|
from fastapi.testclient import TestClient
|
|
|
|
from app.main import app
|
|
|
|
client = TestClient(app)
|
|
|
|
# ── Minimal valid request payload ───────────────────────────────
|
|
|
|
_PREDICT_PAYLOAD = {
|
|
"district": "Cầu Giấy",
|
|
"city": "Hà Nội",
|
|
"property_type": "apartment",
|
|
"area_m2": 80.0,
|
|
"rooms": 2,
|
|
"month": 3,
|
|
"quarter": 1,
|
|
}
|
|
|
|
|
|
def test_predict_v2_heuristic():
|
|
"""Predict using heuristic fallback (no trained models)."""
|
|
resp = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD)
|
|
assert resp.status_code == 200
|
|
data = resp.json()
|
|
|
|
assert data["estimated_price_vnd"] > 0
|
|
assert 0 <= data["confidence"] <= 1
|
|
assert data["price_per_m2_vnd"] > 0
|
|
assert data["price_range_low_vnd"] < data["estimated_price_vnd"]
|
|
assert data["price_range_high_vnd"] > data["estimated_price_vnd"]
|
|
assert data["ensemble_method"] == "weighted_average"
|
|
assert data["model_version"] == "ensemble-v2-heuristic"
|
|
|
|
|
|
def test_predict_v2_returns_model_predictions():
|
|
"""Heuristic should return 3 simulated model predictions."""
|
|
resp = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD)
|
|
data = resp.json()
|
|
|
|
preds = data["model_predictions"]
|
|
assert len(preds) == 3
|
|
|
|
names = {p["model_name"] for p in preds}
|
|
assert names == {"xgboost", "lightgbm", "catboost"}
|
|
|
|
for p in preds:
|
|
assert p["weight"] > 0
|
|
assert p["predicted_price_vnd"] > 0
|
|
assert p["predicted_price_per_m2_vnd"] > 0
|
|
|
|
|
|
def test_predict_v2_returns_drivers():
|
|
"""Heuristic should return feature importance drivers."""
|
|
resp = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD)
|
|
data = resp.json()
|
|
|
|
drivers = data["drivers"]
|
|
assert len(drivers) > 0
|
|
assert all(0 <= d["importance"] <= 1 for d in drivers)
|
|
# Most important feature should be area or district price
|
|
top_feature = drivers[0]["feature"]
|
|
assert top_feature in ("area_m2", "avg_price_district_3m_vnd_m2")
|
|
|
|
|
|
def test_predict_v2_with_full_features():
|
|
"""Predict with all features populated (including new v2 features)."""
|
|
payload = {
|
|
**_PREDICT_PAYLOAD,
|
|
"neighborhood_score": 0.8,
|
|
"distance_to_cbd_km": 5.0,
|
|
"distance_to_metro_km": 0.8,
|
|
"distance_to_school_km": 0.5,
|
|
"distance_to_hospital_km": 2.0,
|
|
"distance_to_park_km": 0.3,
|
|
"distance_to_mall_km": 1.0,
|
|
"flood_zone_risk": 0.1,
|
|
"floor_level": 12,
|
|
"total_floors": 25,
|
|
"direction": "southeast",
|
|
"floor_ratio": 1.2,
|
|
"building_age_years": 5,
|
|
"has_elevator": True,
|
|
"has_parking": True,
|
|
"has_pool": False,
|
|
"developer_reputation": 0.9,
|
|
"avg_price_district_3m_vnd_m2": 85_000_000,
|
|
"listing_density": 12.5,
|
|
"absorption_rate": 0.3,
|
|
"dom_avg": 45.0,
|
|
"price_momentum_30d": 0.02,
|
|
"yoy_change": 0.05,
|
|
"renovation_score": 0.8,
|
|
"view_quality": 0.7,
|
|
"interior_quality": 0.75,
|
|
"noise_level": 0.3,
|
|
"natural_light": 0.8,
|
|
"is_year_end": False,
|
|
}
|
|
resp = client.post("/avm/v2/predict", json=payload)
|
|
assert resp.status_code == 200
|
|
data = resp.json()
|
|
assert data["estimated_price_vnd"] > 0
|
|
assert data["confidence"] > 0
|
|
|
|
|
|
def test_predict_v2_villa_premium():
|
|
"""Villas should be priced higher than apartments (same area)."""
|
|
apt = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD).json()
|
|
villa_payload = {**_PREDICT_PAYLOAD, "property_type": "villa"}
|
|
villa = client.post("/avm/v2/predict", json=villa_payload).json()
|
|
|
|
assert villa["price_per_m2_vnd"] > apt["price_per_m2_vnd"]
|
|
|
|
|
|
def test_predict_v2_year_end_premium():
|
|
"""Q4/Tết season should add a premium."""
|
|
normal = client.post(
|
|
"/avm/v2/predict",
|
|
json={**_PREDICT_PAYLOAD, "is_year_end": False, "month": 6, "quarter": 2},
|
|
).json()
|
|
year_end = client.post(
|
|
"/avm/v2/predict",
|
|
json={**_PREDICT_PAYLOAD, "is_year_end": True, "month": 12, "quarter": 4},
|
|
).json()
|
|
|
|
assert year_end["estimated_price_vnd"] > normal["estimated_price_vnd"]
|
|
|
|
|
|
def test_predict_v2_no_legal_paper_discount():
|
|
"""Properties without legal papers should be discounted."""
|
|
with_paper = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD).json()
|
|
without_paper = client.post(
|
|
"/avm/v2/predict",
|
|
json={**_PREDICT_PAYLOAD, "has_legal_paper": False},
|
|
).json()
|
|
|
|
assert without_paper["estimated_price_vnd"] < with_paper["estimated_price_vnd"]
|
|
|
|
|
|
def test_predict_v2_validation_error():
|
|
"""Missing required fields should return 422."""
|
|
resp = client.post("/avm/v2/predict", json={"area_m2": 80})
|
|
assert resp.status_code == 422
|
|
|
|
|
|
def test_predict_v2_invalid_area():
|
|
"""Zero or negative area should be rejected."""
|
|
resp = client.post(
|
|
"/avm/v2/predict",
|
|
json={**_PREDICT_PAYLOAD, "area_m2": 0},
|
|
)
|
|
assert resp.status_code == 422
|
|
|
|
|
|
# ── New v2 features: neighborhood, floor, direction, developer ──
|
|
|
|
|
|
def test_predict_v2_neighborhood_premium():
|
|
"""High neighborhood score should increase price."""
|
|
low_nb = client.post(
|
|
"/avm/v2/predict",
|
|
json={**_PREDICT_PAYLOAD, "neighborhood_score": 0.2},
|
|
).json()
|
|
high_nb = client.post(
|
|
"/avm/v2/predict",
|
|
json={**_PREDICT_PAYLOAD, "neighborhood_score": 0.9},
|
|
).json()
|
|
|
|
assert high_nb["estimated_price_vnd"] > low_nb["estimated_price_vnd"]
|
|
|
|
|
|
def test_predict_v2_floor_level_premium():
|
|
"""Higher floor apartments should command a premium."""
|
|
ground = client.post(
|
|
"/avm/v2/predict",
|
|
json={**_PREDICT_PAYLOAD, "floor_level": 2, "total_floors": 25},
|
|
).json()
|
|
high = client.post(
|
|
"/avm/v2/predict",
|
|
json={**_PREDICT_PAYLOAD, "floor_level": 20, "total_floors": 25},
|
|
).json()
|
|
|
|
assert high["estimated_price_vnd"] > ground["estimated_price_vnd"]
|
|
|
|
|
|
def test_predict_v2_direction_premium():
|
|
"""South-facing properties should be priced higher than north-facing."""
|
|
south = client.post(
|
|
"/avm/v2/predict",
|
|
json={**_PREDICT_PAYLOAD, "direction": "south"},
|
|
).json()
|
|
north = client.post(
|
|
"/avm/v2/predict",
|
|
json={**_PREDICT_PAYLOAD, "direction": "north"},
|
|
).json()
|
|
|
|
assert south["estimated_price_vnd"] > north["estimated_price_vnd"]
|
|
|
|
|
|
def test_predict_v2_developer_reputation():
|
|
"""Properties from reputable developers should be valued higher."""
|
|
low_rep = client.post(
|
|
"/avm/v2/predict",
|
|
json={**_PREDICT_PAYLOAD, "developer_reputation": 0.2},
|
|
).json()
|
|
high_rep = client.post(
|
|
"/avm/v2/predict",
|
|
json={**_PREDICT_PAYLOAD, "developer_reputation": 0.9},
|
|
).json()
|
|
|
|
assert high_rep["estimated_price_vnd"] > low_rep["estimated_price_vnd"]
|
|
|
|
|
|
def test_predict_v2_direction_defaults_unknown():
|
|
"""Unknown direction should not affect price (neutral)."""
|
|
explicit = client.post(
|
|
"/avm/v2/predict",
|
|
json={**_PREDICT_PAYLOAD, "direction": "unknown"},
|
|
).json()
|
|
default = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD).json()
|
|
|
|
assert explicit["estimated_price_vnd"] == default["estimated_price_vnd"]
|
|
|
|
|
|
def test_predict_v2_drivers_include_new_features():
|
|
"""Drivers should include neighborhood_score, direction, floor_level."""
|
|
resp = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD)
|
|
data = resp.json()
|
|
driver_names = {d["feature"] for d in data["drivers"]}
|
|
|
|
assert "neighborhood_score" in driver_names
|
|
assert "direction_encoded" in driver_names
|
|
assert "floor_level" in driver_names
|
|
assert "developer_reputation" in driver_names
|
|
|
|
|
|
# ── Training & model info ───────────────────────────────────────
|
|
|
|
|
|
def test_train_v2_no_data():
|
|
"""Training without data returns scaffold with zero metrics."""
|
|
resp = client.post(
|
|
"/avm/v2/train",
|
|
json={"optuna_trials": 10},
|
|
)
|
|
assert resp.status_code == 200
|
|
data = resp.json()
|
|
assert "model_version" in data
|
|
assert "ensemble-v2-" in data["model_version"]
|
|
assert data["training_samples"] == 0
|
|
|
|
|
|
def test_model_info_v2():
|
|
"""Model info endpoint should return current model version."""
|
|
resp = client.get("/avm/v2/model-info")
|
|
assert resp.status_code == 200
|
|
data = resp.json()
|
|
assert "model_version" in data
|
|
assert data["is_active"] is True
|
|
|
|
|
|
# ── Feature importance endpoint ──────────────────────────────────
|
|
|
|
|
|
def test_feature_importance_heuristic():
|
|
"""Dedicated endpoint returns heuristic drivers when no models are loaded."""
|
|
resp = client.get("/avm/v2/feature-importance")
|
|
assert resp.status_code == 200
|
|
data = resp.json()
|
|
|
|
assert data["source"] == "heuristic"
|
|
assert data["model_version"] == "ensemble-v2-heuristic"
|
|
drivers = data["drivers"]
|
|
assert len(drivers) > 0
|
|
importances = [d["importance"] for d in drivers]
|
|
assert importances == sorted(importances, reverse=True)
|
|
assert all(0 <= i <= 1 for i in importances)
|
|
feature_names = {d["feature"] for d in drivers}
|
|
assert "area_m2" in feature_names
|
|
assert "neighborhood_score" in feature_names
|
|
|
|
|
|
# ── Model versioning ────────────────────────────────────────────
|
|
|
|
|
|
def test_list_versions():
|
|
"""Versions endpoint returns a list."""
|
|
resp = client.get("/avm/v2/versions")
|
|
assert resp.status_code == 200
|
|
data = resp.json()
|
|
assert isinstance(data, list)
|
|
|
|
|
|
def test_rollback_not_found():
|
|
"""Rollback to non-existent version returns 404."""
|
|
resp = client.post(
|
|
"/avm/v2/rollback",
|
|
json={"target_version": "nonexistent-version-xyz"},
|
|
)
|
|
assert resp.status_code == 404
|
|
|
|
|
|
# ── A/B comparison tests ─────────────────────────────────────
|
|
|
|
_COMPARE_PAYLOAD = {
|
|
"district": "Cầu Giấy",
|
|
"city": "Hà Nội",
|
|
"property_type": "apartment",
|
|
"area_m2": 80.0,
|
|
"rooms": 2,
|
|
"month": 3,
|
|
"quarter": 1,
|
|
}
|
|
|
|
|
|
def test_compare_v1_returns_both_models():
|
|
"""Compare endpoint returns v1 and v2 predictions."""
|
|
resp = client.post("/avm/v2/compare-v1", json=_COMPARE_PAYLOAD)
|
|
assert resp.status_code == 200
|
|
data = resp.json()
|
|
|
|
assert "v1" in data
|
|
assert "v2" in data
|
|
assert data["v1"]["estimated_price_vnd"] > 0
|
|
assert data["v2"]["estimated_price_vnd"] > 0
|
|
assert 0 <= data["v1"]["confidence"] <= 1
|
|
assert 0 <= data["v2"]["confidence"] <= 1
|
|
|
|
|
|
def test_compare_v1_returns_diffs():
|
|
"""Compare endpoint computes price and confidence differences."""
|
|
resp = client.post("/avm/v2/compare-v1", json=_COMPARE_PAYLOAD)
|
|
data = resp.json()
|
|
|
|
expected_diff = data["v2"]["estimated_price_vnd"] - data["v1"]["estimated_price_vnd"]
|
|
assert abs(data["price_diff_vnd"] - expected_diff) < 10_000 # rounding tolerance
|
|
|
|
assert "price_diff_pct" in data
|
|
assert isinstance(data["price_diff_pct"], float)
|
|
assert "confidence_diff" in data
|
|
|
|
|
|
def test_compare_v1_returns_recommendation():
|
|
"""Compare endpoint provides a recommendation string."""
|
|
resp = client.post("/avm/v2/compare-v1", json=_COMPARE_PAYLOAD)
|
|
data = resp.json()
|
|
|
|
assert "recommendation" in data
|
|
assert len(data["recommendation"]) > 0
|
|
|
|
|
|
def test_compare_v1_with_v2_features():
|
|
"""Compare endpoint passes v2-specific features correctly."""
|
|
payload = {
|
|
**_COMPARE_PAYLOAD,
|
|
"neighborhood_score": 0.8,
|
|
"distance_to_cbd_km": 5.0,
|
|
"distance_to_metro_km": 0.8,
|
|
"flood_zone_risk": 0.1,
|
|
"building_age_years": 3,
|
|
"floor_level": 15,
|
|
"total_floors": 30,
|
|
"direction": "southeast",
|
|
"has_elevator": True,
|
|
"has_parking": True,
|
|
"developer_reputation": 0.85,
|
|
"renovation_score": 0.9,
|
|
"view_quality": 0.8,
|
|
"interior_quality": 0.85,
|
|
}
|
|
resp = client.post("/avm/v2/compare-v1", json=payload)
|
|
assert resp.status_code == 200
|
|
data = resp.json()
|
|
|
|
# v2 should capture these extra features
|
|
assert data["v2"]["estimated_price_vnd"] > 0
|
|
assert data["v2"]["model_version"] is not None
|
|
|
|
|
|
# ── Upload training data ────────────────────────────────────────
|
|
|
|
_CSV_HEADER = (
|
|
"property_type,area_m2,rooms,floor_level,total_floors,direction,floor_ratio,"
|
|
"building_age_years,has_elevator,has_parking,has_pool,has_legal_paper,"
|
|
"developer_reputation,neighborhood_score,distance_to_cbd_km,distance_to_metro_km,"
|
|
"distance_to_school_km,distance_to_hospital_km,distance_to_park_km,distance_to_mall_km,"
|
|
"flood_zone_risk,avg_price_district_3m_vnd_m2,listing_density,absorption_rate,"
|
|
"dom_avg,price_momentum_30d,yoy_change,renovation_score,view_quality,interior_quality,"
|
|
"noise_level,natural_light,month,district,price_vnd"
|
|
)
|
|
_CSV_ROW = (
|
|
"apartment,80,2,5,20,south,1.0,3,1,1,0,1,0.8,0.7,5,1,0.5,2,1,3,"
|
|
"0.1,85000000,10,0.3,30,0.01,0.05,0.8,0.7,0.75,0.3,0.8,3,Cầu Giấy,7000000000"
|
|
)
|
|
|
|
|
|
def test_upload_training_data_ok(tmp_path):
|
|
"""Upload endpoint accepts valid CSV and returns row count."""
|
|
from unittest.mock import patch
|
|
from app import config as cfg
|
|
with patch.object(cfg.settings, "model_path", str(tmp_path)):
|
|
csv_body = f"{_CSV_HEADER}\n{_CSV_ROW}\n"
|
|
resp = client.post(
|
|
"/avm/v2/upload-training-data",
|
|
content=csv_body,
|
|
headers={"Content-Type": "text/csv"},
|
|
)
|
|
assert resp.status_code == 200
|
|
data = resp.json()
|
|
assert data["rows_received"] == 1
|
|
|
|
|
|
def test_upload_training_data_missing_price_vnd():
|
|
"""Upload endpoint rejects CSV without price_vnd column."""
|
|
bad_csv = "property_type,area_m2\napartment,80\n"
|
|
resp = client.post(
|
|
"/avm/v2/upload-training-data",
|
|
content=bad_csv,
|
|
headers={"Content-Type": "text/csv"},
|
|
)
|
|
assert resp.status_code == 400
|
|
assert "price_vnd" in resp.json()["detail"]
|
|
|
|
|
|
def test_upload_training_data_empty_body():
|
|
"""Upload endpoint rejects empty body."""
|
|
resp = client.post(
|
|
"/avm/v2/upload-training-data",
|
|
content=b"",
|
|
headers={"Content-Type": "text/csv"},
|
|
)
|
|
assert resp.status_code == 400
|
|
|
|
|
|
# ── A/B config endpoint ─────────────────────────────────────────
|
|
|
|
|
|
def test_ab_config_no_registry():
|
|
"""AB config endpoint returns 404 when no model is registered (heuristic-only run)."""
|
|
resp = client.post("/avm/v2/ab-config", json={"traffic_pct": 0.10})
|
|
# Fresh test env has no registry → 404
|
|
assert resp.status_code == 404
|