"""Tests for AVM v2 ensemble endpoints.""" from fastapi.testclient import TestClient from app.main import app client = TestClient(app) # ── Minimal valid request payload ─────────────────────────────── _PREDICT_PAYLOAD = { "district": "Cầu Giấy", "city": "Hà Nội", "property_type": "apartment", "area_m2": 80.0, "rooms": 2, "month": 3, "quarter": 1, } def test_predict_v2_heuristic(): """Predict using heuristic fallback (no trained models).""" resp = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD) assert resp.status_code == 200 data = resp.json() assert data["estimated_price_vnd"] > 0 assert 0 <= data["confidence"] <= 1 assert data["price_per_m2_vnd"] > 0 assert data["price_range_low_vnd"] < data["estimated_price_vnd"] assert data["price_range_high_vnd"] > data["estimated_price_vnd"] assert data["ensemble_method"] == "weighted_average" assert data["model_version"] == "ensemble-v2-heuristic" def test_predict_v2_returns_model_predictions(): """Heuristic should return 3 simulated model predictions.""" resp = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD) data = resp.json() preds = data["model_predictions"] assert len(preds) == 3 names = {p["model_name"] for p in preds} assert names == {"xgboost", "lightgbm", "catboost"} for p in preds: assert p["weight"] > 0 assert p["predicted_price_vnd"] > 0 assert p["predicted_price_per_m2_vnd"] > 0 def test_predict_v2_returns_drivers(): """Heuristic should return feature importance drivers.""" resp = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD) data = resp.json() drivers = data["drivers"] assert len(drivers) > 0 assert all(0 <= d["importance"] <= 1 for d in drivers) # Most important feature should be area or district price top_feature = drivers[0]["feature"] assert top_feature in ("area_m2", "avg_price_district_3m_vnd_m2") def test_predict_v2_with_full_features(): """Predict with all features populated.""" payload = { **_PREDICT_PAYLOAD, "distance_to_cbd_km": 5.0, "distance_to_metro_km": 0.8, "distance_to_school_km": 0.5, "distance_to_hospital_km": 2.0, "distance_to_park_km": 0.3, "distance_to_mall_km": 1.0, "flood_zone_risk": 0.1, "floor_ratio": 1.2, "building_age_years": 5, "has_elevator": True, "has_parking": True, "has_pool": False, "avg_price_district_3m_vnd_m2": 85_000_000, "listing_density": 12.5, "absorption_rate": 0.3, "dom_avg": 45.0, "price_momentum_30d": 0.02, "yoy_change": 0.05, "renovation_score": 0.8, "view_quality": 0.7, "interior_quality": 0.75, "noise_level": 0.3, "natural_light": 0.8, "is_year_end": False, } resp = client.post("/avm/v2/predict", json=payload) assert resp.status_code == 200 data = resp.json() assert data["estimated_price_vnd"] > 0 assert data["confidence"] > 0 def test_predict_v2_villa_premium(): """Villas should be priced higher than apartments (same area).""" apt = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD).json() villa_payload = {**_PREDICT_PAYLOAD, "property_type": "villa"} villa = client.post("/avm/v2/predict", json=villa_payload).json() assert villa["price_per_m2_vnd"] > apt["price_per_m2_vnd"] def test_predict_v2_year_end_premium(): """Q4/Tết season should add a premium.""" normal = client.post( "/avm/v2/predict", json={**_PREDICT_PAYLOAD, "is_year_end": False, "month": 6, "quarter": 2}, ).json() year_end = client.post( "/avm/v2/predict", json={**_PREDICT_PAYLOAD, "is_year_end": True, "month": 12, "quarter": 4}, ).json() assert year_end["estimated_price_vnd"] > normal["estimated_price_vnd"] def test_predict_v2_no_legal_paper_discount(): """Properties without legal papers should be discounted.""" with_paper = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD).json() without_paper = client.post( "/avm/v2/predict", json={**_PREDICT_PAYLOAD, "has_legal_paper": False}, ).json() assert without_paper["estimated_price_vnd"] < with_paper["estimated_price_vnd"] def test_predict_v2_validation_error(): """Missing required fields should return 422.""" resp = client.post("/avm/v2/predict", json={"area_m2": 80}) assert resp.status_code == 422 def test_predict_v2_invalid_area(): """Zero or negative area should be rejected.""" resp = client.post( "/avm/v2/predict", json={**_PREDICT_PAYLOAD, "area_m2": 0}, ) assert resp.status_code == 422 def test_train_v2_scaffold(): """Training endpoint should return scaffold response.""" resp = client.post( "/avm/v2/train", json={"optuna_trials": 10}, ) assert resp.status_code == 200 data = resp.json() assert "model_version" in data assert "ensemble-v2-" in data["model_version"] assert data["metrics"]["mae"] == 0.0 # scaffold returns zeros assert "xgboost" in data["best_params"] assert "lightgbm" in data["best_params"] assert "catboost" in data["best_params"] def test_model_info_v2(): """Model info endpoint should return current model version.""" resp = client.get("/avm/v2/model-info") assert resp.status_code == 200 data = resp.json() assert "model_version" in data assert data["is_active"] is True