"""Tests for AVM v2 ensemble endpoints.""" from fastapi.testclient import TestClient from app.main import app client = TestClient(app) # ── Minimal valid request payload ─────────────────────────────── _PREDICT_PAYLOAD = { "district": "Cầu Giấy", "city": "Hà Nội", "property_type": "apartment", "area_m2": 80.0, "rooms": 2, "month": 3, "quarter": 1, } def test_predict_v2_heuristic(): """Predict using heuristic fallback (no trained models).""" resp = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD) assert resp.status_code == 200 data = resp.json() assert data["estimated_price_vnd"] > 0 assert 0 <= data["confidence"] <= 1 assert data["price_per_m2_vnd"] > 0 assert data["price_range_low_vnd"] < data["estimated_price_vnd"] assert data["price_range_high_vnd"] > data["estimated_price_vnd"] assert data["ensemble_method"] == "weighted_average" assert data["model_version"] == "ensemble-v2-heuristic" def test_predict_v2_returns_model_predictions(): """Heuristic should return 3 simulated model predictions.""" resp = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD) data = resp.json() preds = data["model_predictions"] assert len(preds) == 3 names = {p["model_name"] for p in preds} assert names == {"xgboost", "lightgbm", "catboost"} for p in preds: assert p["weight"] > 0 assert p["predicted_price_vnd"] > 0 assert p["predicted_price_per_m2_vnd"] > 0 def test_predict_v2_returns_drivers(): """Heuristic should return feature importance drivers.""" resp = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD) data = resp.json() drivers = data["drivers"] assert len(drivers) > 0 assert all(0 <= d["importance"] <= 1 for d in drivers) # Most important feature should be area or district price top_feature = drivers[0]["feature"] assert top_feature in ("area_m2", "avg_price_district_3m_vnd_m2") def test_predict_v2_with_full_features(): """Predict with all features populated (including new v2 features).""" payload = { **_PREDICT_PAYLOAD, "neighborhood_score": 0.8, 
"distance_to_cbd_km": 5.0, "distance_to_metro_km": 0.8, "distance_to_school_km": 0.5, "distance_to_hospital_km": 2.0, "distance_to_park_km": 0.3, "distance_to_mall_km": 1.0, "flood_zone_risk": 0.1, "floor_level": 12, "total_floors": 25, "direction": "southeast", "floor_ratio": 1.2, "building_age_years": 5, "has_elevator": True, "has_parking": True, "has_pool": False, "developer_reputation": 0.9, "avg_price_district_3m_vnd_m2": 85_000_000, "listing_density": 12.5, "absorption_rate": 0.3, "dom_avg": 45.0, "price_momentum_30d": 0.02, "yoy_change": 0.05, "renovation_score": 0.8, "view_quality": 0.7, "interior_quality": 0.75, "noise_level": 0.3, "natural_light": 0.8, "is_year_end": False, } resp = client.post("/avm/v2/predict", json=payload) assert resp.status_code == 200 data = resp.json() assert data["estimated_price_vnd"] > 0 assert data["confidence"] > 0 def test_predict_v2_villa_premium(): """Villas should be priced higher than apartments (same area).""" apt = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD).json() villa_payload = {**_PREDICT_PAYLOAD, "property_type": "villa"} villa = client.post("/avm/v2/predict", json=villa_payload).json() assert villa["price_per_m2_vnd"] > apt["price_per_m2_vnd"] def test_predict_v2_year_end_premium(): """Q4/Tết season should add a premium.""" normal = client.post( "/avm/v2/predict", json={**_PREDICT_PAYLOAD, "is_year_end": False, "month": 6, "quarter": 2}, ).json() year_end = client.post( "/avm/v2/predict", json={**_PREDICT_PAYLOAD, "is_year_end": True, "month": 12, "quarter": 4}, ).json() assert year_end["estimated_price_vnd"] > normal["estimated_price_vnd"] def test_predict_v2_no_legal_paper_discount(): """Properties without legal papers should be discounted.""" with_paper = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD).json() without_paper = client.post( "/avm/v2/predict", json={**_PREDICT_PAYLOAD, "has_legal_paper": False}, ).json() assert without_paper["estimated_price_vnd"] < with_paper["estimated_price_vnd"] 
def test_predict_v2_validation_error():
    """Missing required fields should return 422."""
    resp = client.post("/avm/v2/predict", json={"area_m2": 80})
    assert resp.status_code == 422


def test_predict_v2_invalid_area():
    """Zero or negative area should be rejected."""
    # Exercise both halves of the documented contract: the boundary value
    # and a strictly negative value.
    for bad_area in (0, -10.0):
        resp = client.post(
            "/avm/v2/predict",
            json={**_PREDICT_PAYLOAD, "area_m2": bad_area},
        )
        assert resp.status_code == 422, f"area_m2={bad_area!r} was not rejected"


# ── New v2 features: neighborhood, floor, direction, developer ──


def test_predict_v2_neighborhood_premium():
    """High neighborhood score should increase price."""
    low_nb = client.post(
        "/avm/v2/predict",
        json={**_PREDICT_PAYLOAD, "neighborhood_score": 0.2},
    ).json()
    high_nb = client.post(
        "/avm/v2/predict",
        json={**_PREDICT_PAYLOAD, "neighborhood_score": 0.9},
    ).json()
    assert high_nb["estimated_price_vnd"] > low_nb["estimated_price_vnd"]


def test_predict_v2_floor_level_premium():
    """Higher floor apartments should command a premium."""
    ground = client.post(
        "/avm/v2/predict",
        json={**_PREDICT_PAYLOAD, "floor_level": 2, "total_floors": 25},
    ).json()
    high = client.post(
        "/avm/v2/predict",
        json={**_PREDICT_PAYLOAD, "floor_level": 20, "total_floors": 25},
    ).json()
    assert high["estimated_price_vnd"] > ground["estimated_price_vnd"]


def test_predict_v2_direction_premium():
    """South-facing properties should be priced higher than north-facing."""
    south = client.post(
        "/avm/v2/predict",
        json={**_PREDICT_PAYLOAD, "direction": "south"},
    ).json()
    north = client.post(
        "/avm/v2/predict",
        json={**_PREDICT_PAYLOAD, "direction": "north"},
    ).json()
    assert south["estimated_price_vnd"] > north["estimated_price_vnd"]


def test_predict_v2_developer_reputation():
    """Properties from reputable developers should be valued higher."""
    low_rep = client.post(
        "/avm/v2/predict",
        json={**_PREDICT_PAYLOAD, "developer_reputation": 0.2},
    ).json()
    high_rep = client.post(
        "/avm/v2/predict",
        json={**_PREDICT_PAYLOAD, "developer_reputation": 0.9},
    ).json()
    assert high_rep["estimated_price_vnd"] > low_rep["estimated_price_vnd"]


def test_predict_v2_direction_defaults_unknown():
    """Unknown direction should not affect price (neutral)."""
    explicit = client.post(
        "/avm/v2/predict",
        json={**_PREDICT_PAYLOAD, "direction": "unknown"},
    ).json()
    # The base payload omits ``direction`` entirely, so this compares the
    # explicit "unknown" value against whatever the endpoint defaults to.
    default = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD).json()
    assert explicit["estimated_price_vnd"] == default["estimated_price_vnd"]


def test_predict_v2_drivers_include_new_features():
    """Drivers should include neighborhood_score, direction, floor_level."""
    resp = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD)
    # Fail on the HTTP status first so an error response is not reported
    # as a confusing KeyError on the body below.
    assert resp.status_code == 200

    driver_names = {d["feature"] for d in resp.json()["drivers"]}
    assert "neighborhood_score" in driver_names
    assert "direction_encoded" in driver_names
    assert "floor_level" in driver_names
    assert "developer_reputation" in driver_names


# ── Training & model info ───────────────────────────────────────


def test_train_v2_no_data():
    """Training without data returns scaffold with zero metrics."""
    resp = client.post(
        "/avm/v2/train",
        json={"optuna_trials": 10},
    )
    assert resp.status_code == 200

    data = resp.json()
    assert "model_version" in data
    assert "ensemble-v2-" in data["model_version"]
    assert data["training_samples"] == 0


def test_model_info_v2():
    """Model info endpoint should return current model version."""
    resp = client.get("/avm/v2/model-info")
    assert resp.status_code == 200

    data = resp.json()
    assert "model_version" in data
    assert data["is_active"] is True


# ── Feature importance endpoint ──────────────────────────────────


def test_feature_importance_heuristic():
    """Dedicated endpoint returns heuristic drivers when no models are loaded."""
    resp = client.get("/avm/v2/feature-importance")
    assert resp.status_code == 200

    data = resp.json()
    assert data["source"] == "heuristic"
    assert data["model_version"] == "ensemble-v2-heuristic"

    drivers = data["drivers"]
    assert len(drivers) > 0
    importances = [d["importance"] for d in drivers]
    # Drivers must arrive ordered from most to least important.
    assert importances == sorted(importances, reverse=True)
    assert all(0 <= i <= 1 for i in importances)

    feature_names = {d["feature"] for d in drivers}
    assert "area_m2" in feature_names
    assert "neighborhood_score" in feature_names


# ── Model versioning ────────────────────────────────────────────


def test_list_versions():
    """Versions endpoint returns a list."""
    resp = client.get("/avm/v2/versions")
    assert resp.status_code == 200

    data = resp.json()
    assert isinstance(data, list)


def test_rollback_not_found():
    """Rollback to non-existent version returns 404."""
    resp = client.post(
        "/avm/v2/rollback",
        json={"target_version": "nonexistent-version-xyz"},
    )
    assert resp.status_code == 404


# ── A/B comparison tests ─────────────────────────────────────

_COMPARE_PAYLOAD = {
    "district": "Cầu Giấy",
    "city": "Hà Nội",
    "property_type": "apartment",
    "area_m2": 80.0,
    "rooms": 2,
    "month": 3,
    "quarter": 1,
}


def test_compare_v1_returns_both_models():
    """Compare endpoint returns v1 and v2 predictions."""
    resp = client.post("/avm/v2/compare-v1", json=_COMPARE_PAYLOAD)
    assert resp.status_code == 200

    data = resp.json()
    assert "v1" in data
    assert "v2" in data
    assert data["v1"]["estimated_price_vnd"] > 0
    assert data["v2"]["estimated_price_vnd"] > 0
    assert 0 <= data["v1"]["confidence"] <= 1
    assert 0 <= data["v2"]["confidence"] <= 1


def test_compare_v1_returns_diffs():
    """Compare endpoint computes price and confidence differences."""
    resp = client.post("/avm/v2/compare-v1", json=_COMPARE_PAYLOAD)
    # Fail on the HTTP status first so an error response is not reported
    # as a confusing KeyError on the body below.
    assert resp.status_code == 200

    data = resp.json()
    expected_diff = data["v2"]["estimated_price_vnd"] - data["v1"]["estimated_price_vnd"]
    assert abs(data["price_diff_vnd"] - expected_diff) < 10_000  # rounding tolerance
    assert "price_diff_pct" in data
    assert isinstance(data["price_diff_pct"], float)
    assert "confidence_diff" in data


def test_compare_v1_returns_recommendation():
    """Compare endpoint provides a recommendation string."""
    resp = client.post("/avm/v2/compare-v1", json=_COMPARE_PAYLOAD)
    # Fail on the HTTP status first so an error response is not reported
    # as a confusing KeyError on the body below.
    assert resp.status_code == 200

    data = resp.json()
    assert "recommendation" in data
    assert len(data["recommendation"]) > 0


def test_compare_v1_with_v2_features():
    """Compare endpoint passes v2-specific features correctly."""
    payload = {
        **_COMPARE_PAYLOAD,
        "neighborhood_score": 0.8,
        "distance_to_cbd_km": 5.0,
        "distance_to_metro_km": 0.8,
        "flood_zone_risk": 0.1,
        "building_age_years": 3,
        "floor_level": 15,
        "total_floors": 30,
        "direction": "southeast",
        "has_elevator": True,
        "has_parking": True,
        "developer_reputation": 0.85,
        "renovation_score": 0.9,
        "view_quality": 0.8,
        "interior_quality": 0.85,
    }
    resp = client.post("/avm/v2/compare-v1", json=payload)
    assert resp.status_code == 200

    data = resp.json()
    # v2 should capture these extra features
    assert data["v2"]["estimated_price_vnd"] > 0
    assert data["v2"]["model_version"] is not None


# ── Upload training data ────────────────────────────────────────

_CSV_HEADER = (
    "property_type,area_m2,rooms,floor_level,total_floors,direction,floor_ratio,"
    "building_age_years,has_elevator,has_parking,has_pool,has_legal_paper,"
    "developer_reputation,neighborhood_score,distance_to_cbd_km,distance_to_metro_km,"
    "distance_to_school_km,distance_to_hospital_km,distance_to_park_km,distance_to_mall_km,"
    "flood_zone_risk,avg_price_district_3m_vnd_m2,listing_density,absorption_rate,"
    "dom_avg,price_momentum_30d,yoy_change,renovation_score,view_quality,interior_quality,"
    "noise_level,natural_light,month,district,price_vnd"
)

_CSV_ROW = (
    "apartment,80,2,5,20,south,1.0,3,1,1,0,1,0.8,0.7,5,1,0.5,2,1,3,"
    "0.1,85000000,10,0.3,30,0.01,0.05,0.8,0.7,0.75,0.3,0.8,3,Cầu Giấy,7000000000"
)


def test_upload_training_data_ok(tmp_path):
    """Upload endpoint accepts valid CSV and returns row count."""
    from unittest.mock import patch

    from app import config as cfg

    # Point the configured model path at pytest's per-test temp directory
    # for the duration of the request.
    with patch.object(cfg.settings, "model_path", str(tmp_path)):
        csv_body = f"{_CSV_HEADER}\n{_CSV_ROW}\n"
        resp = client.post(
            "/avm/v2/upload-training-data",
            content=csv_body,
            headers={"Content-Type": "text/csv"},
        )

    assert resp.status_code == 200
    data = resp.json()
    assert data["rows_received"] == 1


def test_upload_training_data_missing_price_vnd():
    """Upload endpoint rejects CSV without price_vnd column."""
    bad_csv = "property_type,area_m2\napartment,80\n"
    resp = client.post(
        "/avm/v2/upload-training-data",
        content=bad_csv,
        headers={"Content-Type": "text/csv"},
    )
    assert resp.status_code == 400
    # The error detail should name the missing column.
    assert "price_vnd" in resp.json()["detail"]


def test_upload_training_data_empty_body():
    """Upload endpoint rejects empty body."""
    resp = client.post(
        "/avm/v2/upload-training-data",
        content=b"",
        headers={"Content-Type": "text/csv"},
    )
    assert resp.status_code == 400


# ── A/B config endpoint ─────────────────────────────────────────


def test_ab_config_no_registry():
    """AB config endpoint returns 404 when no model is registered (heuristic-only run)."""
    resp = client.post("/avm/v2/ab-config", json={"traffic_pct": 0.10})
    # Fresh test env has no registry → 404
    assert resp.status_code == 404