Files
goodgo-platform/libs/ai-services/tests/test_avm_v2.py
Ho Ngoc Hai a6e53e3d06 feat(ai-services): add AVM v2 A/B comparison endpoint and tests
Add POST /avm/v2/compare-v1 endpoint that runs both v1 (single-model)
and v2 (ensemble) AVM predictions on the same property and returns a
side-by-side comparison with price diff, confidence delta, and a
recommendation on which model to prefer.

- ABComparisonRequest/Response schemas in avm_v2 models
- compare_v1() method in AVMv2EnsembleService
- 4 new integration tests for the comparison endpoint
- All 47 Python tests pass

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-04-16 17:35:30 +07:00

247 lines
7.8 KiB
Python

"""Tests for AVM v2 ensemble endpoints."""
from fastapi.testclient import TestClient
from app.main import app
# Shared test client that drives the FastAPI application in-process.
client = TestClient(app)

# ── Minimal valid request payload ───────────────────────────────
# Baseline request body for /avm/v2/predict; individual tests extend or
# override fields with ``{**_PREDICT_PAYLOAD, ...}``.
_PREDICT_PAYLOAD = dict(
    district="Cầu Giấy",
    city="Hà Nội",
    property_type="apartment",
    area_m2=80.0,
    rooms=2,
    month=3,
    quarter=1,
)
def test_predict_v2_heuristic():
    """Predict through the heuristic fallback path (no trained models)."""
    response = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD)
    assert response.status_code == 200

    body = response.json()
    estimate = body["estimated_price_vnd"]
    assert estimate > 0
    assert 0 <= body["confidence"] <= 1
    assert body["price_per_m2_vnd"] > 0
    # The reported price range must strictly bracket the point estimate.
    assert body["price_range_low_vnd"] < estimate < body["price_range_high_vnd"]
    assert body["ensemble_method"] == "weighted_average"
    assert body["model_version"] == "ensemble-v2-heuristic"
def test_predict_v2_returns_model_predictions():
    """Heuristic should return 3 simulated model predictions.

    Each entry must name one of the ensemble members (xgboost, lightgbm,
    catboost) and carry a positive weight and positive price figures.
    """
    resp = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD)
    # Check the status first so a server error fails with a clear message
    # instead of a KeyError while reading the response body.
    assert resp.status_code == 200
    preds = resp.json()["model_predictions"]
    assert len(preds) == 3
    assert {p["model_name"] for p in preds} == {"xgboost", "lightgbm", "catboost"}
    for p in preds:
        assert p["weight"] > 0
        assert p["predicted_price_vnd"] > 0
        assert p["predicted_price_per_m2_vnd"] > 0
def test_predict_v2_returns_drivers():
    """Heuristic should return feature-importance drivers.

    Every importance must lie in [0, 1], and the first driver is expected
    to be either the area or the district's recent average price.
    """
    resp = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD)
    # Fail on a bad status rather than on a missing "drivers" key.
    assert resp.status_code == 200
    drivers = resp.json()["drivers"]
    assert len(drivers) > 0
    assert all(0 <= d["importance"] <= 1 for d in drivers)
    # NOTE(review): assumes the service orders drivers by descending
    # importance, so drivers[0] is the most important feature — confirm.
    top_feature = drivers[0]["feature"]
    assert top_feature in ("area_m2", "avg_price_district_3m_vnd_m2")
def test_predict_v2_with_full_features():
    """Prediction succeeds when every optional feature is supplied."""
    optional_features = {
        # Distances (km)
        "distance_to_cbd_km": 5.0,
        "distance_to_metro_km": 0.8,
        "distance_to_school_km": 0.5,
        "distance_to_hospital_km": 2.0,
        "distance_to_park_km": 0.3,
        "distance_to_mall_km": 1.0,
        # Site and building attributes
        "flood_zone_risk": 0.1,
        "floor_ratio": 1.2,
        "building_age_years": 5,
        "has_elevator": True,
        "has_parking": True,
        "has_pool": False,
        # Market indicators
        "avg_price_district_3m_vnd_m2": 85_000_000,
        "listing_density": 12.5,
        "absorption_rate": 0.3,
        "dom_avg": 45.0,
        "price_momentum_30d": 0.02,
        "yoy_change": 0.05,
        # Quality scores
        "renovation_score": 0.8,
        "view_quality": 0.7,
        "interior_quality": 0.75,
        "noise_level": 0.3,
        "natural_light": 0.8,
        # Seasonality flag
        "is_year_end": False,
    }
    response = client.post(
        "/avm/v2/predict",
        json={**_PREDICT_PAYLOAD, **optional_features},
    )
    assert response.status_code == 200

    body = response.json()
    assert body["estimated_price_vnd"] > 0
    assert body["confidence"] > 0
def test_predict_v2_villa_premium():
    """Villas should be priced higher per m² than apartments of the same area."""
    apt_resp = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD)
    villa_resp = client.post(
        "/avm/v2/predict",
        json={**_PREDICT_PAYLOAD, "property_type": "villa"},
    )
    # Check both statuses so an error response fails clearly instead of
    # raising a KeyError on "price_per_m2_vnd".
    assert apt_resp.status_code == 200
    assert villa_resp.status_code == 200
    assert villa_resp.json()["price_per_m2_vnd"] > apt_resp.json()["price_per_m2_vnd"]
def test_predict_v2_year_end_premium():
    """Q4/Tết season should add a premium.

    Compares a mid-year request (June, Q2) against a year-end request
    (December, Q4, ``is_year_end=True``) for the same property.
    """
    normal_resp = client.post(
        "/avm/v2/predict",
        json={**_PREDICT_PAYLOAD, "is_year_end": False, "month": 6, "quarter": 2},
    )
    year_end_resp = client.post(
        "/avm/v2/predict",
        json={**_PREDICT_PAYLOAD, "is_year_end": True, "month": 12, "quarter": 4},
    )
    # Check statuses so an error response fails clearly instead of via KeyError.
    assert normal_resp.status_code == 200
    assert year_end_resp.status_code == 200
    normal = normal_resp.json()
    year_end = year_end_resp.json()
    assert year_end["estimated_price_vnd"] > normal["estimated_price_vnd"]
def test_predict_v2_no_legal_paper_discount():
    """Properties without legal papers should be discounted.

    Both requests set ``has_legal_paper`` explicitly, so the comparison
    does not depend on the schema's default value for that field.
    """
    with_paper_resp = client.post(
        "/avm/v2/predict",
        json={**_PREDICT_PAYLOAD, "has_legal_paper": True},
    )
    without_paper_resp = client.post(
        "/avm/v2/predict",
        json={**_PREDICT_PAYLOAD, "has_legal_paper": False},
    )
    assert with_paper_resp.status_code == 200
    assert without_paper_resp.status_code == 200
    with_paper = with_paper_resp.json()
    without_paper = without_paper_resp.json()
    assert without_paper["estimated_price_vnd"] < with_paper["estimated_price_vnd"]
def test_predict_v2_validation_error():
    """A payload lacking the required fields is rejected with HTTP 422."""
    incomplete_payload = {"area_m2": 80}
    response = client.post("/avm/v2/predict", json=incomplete_payload)
    assert response.status_code == 422
def test_predict_v2_invalid_area():
    """Zero or negative area should be rejected with HTTP 422."""
    # Cover both halves of the contract stated above: zero and negative.
    for bad_area in (0, -50.0):
        resp = client.post(
            "/avm/v2/predict",
            json={**_PREDICT_PAYLOAD, "area_m2": bad_area},
        )
        assert resp.status_code == 422, f"area_m2={bad_area!r} was not rejected"
def test_train_v2_scaffold():
    """The training endpoint currently answers with a scaffold payload."""
    response = client.post("/avm/v2/train", json={"optuna_trials": 10})
    assert response.status_code == 200

    result = response.json()
    assert "model_version" in result
    assert "ensemble-v2-" in result["model_version"]
    # The scaffold implementation reports zeroed metrics.
    assert result["metrics"]["mae"] == 0.0
    best_params = result["best_params"]
    for model_name in ("xgboost", "lightgbm", "catboost"):
        assert model_name in best_params
def test_model_info_v2():
    """The model-info endpoint reports the active model's version."""
    response = client.get("/avm/v2/model-info")
    assert response.status_code == 200

    info = response.json()
    assert "model_version" in info
    assert info["is_active"] is True
# ── A/B comparison tests ─────────────────────────────────────
# The comparison tests use the same baseline property as the predict tests.
# Derive it from _PREDICT_PAYLOAD so the two cannot silently drift apart.
# It is a separate copy so per-test changes cannot leak between suites.
_COMPARE_PAYLOAD = dict(_PREDICT_PAYLOAD)
def test_compare_v1_returns_both_models():
    """The compare endpoint returns v1 and v2 predictions side by side."""
    response = client.post("/avm/v2/compare-v1", json=_COMPARE_PAYLOAD)
    assert response.status_code == 200

    body = response.json()
    versions = ("v1", "v2")
    for version in versions:
        assert version in body
    for version in versions:
        prediction = body[version]
        assert prediction["estimated_price_vnd"] > 0
        assert 0 <= prediction["confidence"] <= 1
def test_compare_v1_returns_diffs():
    """Compare endpoint computes price and confidence differences.

    ``price_diff_vnd`` must equal v2 minus v1 (within a rounding tolerance).
    ``price_diff_pct`` must be a float, and ``confidence_diff`` must be numeric.
    """
    resp = client.post("/avm/v2/compare-v1", json=_COMPARE_PAYLOAD)
    # Fail on a bad status rather than on a missing key below.
    assert resp.status_code == 200
    data = resp.json()
    expected_diff = data["v2"]["estimated_price_vnd"] - data["v1"]["estimated_price_vnd"]
    assert abs(data["price_diff_vnd"] - expected_diff) < 10_000  # rounding tolerance
    assert "price_diff_pct" in data
    assert isinstance(data["price_diff_pct"], float)
    assert "confidence_diff" in data
    # A confidence delta should be a number, not null or a string.
    assert isinstance(data["confidence_diff"], (int, float))
def test_compare_v1_returns_recommendation():
    """Compare endpoint provides a non-empty recommendation string."""
    resp = client.post("/avm/v2/compare-v1", json=_COMPARE_PAYLOAD)
    assert resp.status_code == 200
    data = resp.json()
    assert "recommendation" in data
    # Check the type too: a non-empty list would otherwise pass the length check.
    assert isinstance(data["recommendation"], str)
    assert len(data["recommendation"]) > 0
def test_compare_v1_with_v2_features():
    """The compare endpoint accepts v2-specific features and still answers."""
    v2_only_features = {
        "distance_to_cbd_km": 5.0,
        "distance_to_metro_km": 0.8,
        "flood_zone_risk": 0.1,
        "building_age_years": 3,
        "has_elevator": True,
        "has_parking": True,
        "renovation_score": 0.9,
        "view_quality": 0.8,
        "interior_quality": 0.85,
    }
    response = client.post(
        "/avm/v2/compare-v1",
        json={**_COMPARE_PAYLOAD, **v2_only_features},
    )
    assert response.status_code == 200

    # Only checks that the v2 side produced a positive estimate and a
    # version tag. It does not verify how the extra features affect the price.
    v2_result = response.json()["v2"]
    assert v2_result["estimated_price_vnd"] > 0
    assert v2_result["model_version"] is not None