Compare commits

..

10 Commits

Author SHA1 Message Date
Ho Ngoc Hai
99385d8263 feat(auth): validate KYC image URL hosts match MinIO bucket
Closes TEC-2725. Backend KYC presign + submit endpoints already landed in
8f8e20f; this adds the remaining acceptance criterion — host validation on
presigned URLs accepted via /auth/kyc/submit.

- Add IMediaStorageService.isTrustedUrl(url) — host+bucket check, supports
  MINIO_TRUSTED_HOSTS for CDN aliases
- SubmitKycHandler rejects imageUrls pointing outside our MinIO bucket
- Update handler specs with mock + new untrusted-host test

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-04-18 00:32:02 +07:00
Ho Ngoc Hai
e18390ead9 feat(auth): add phoneNumber to profile update with SMS OTP re-verify
TEC-2722 — PATCH /api/v1/auth/profile now accepts phoneNumber alongside
fullName, avatarUrl, and email. Phone changes are deferred until the user
confirms the SMS OTP via POST /api/v1/auth/profile/verify-phone, mirroring
the existing email-change OTP flow.

- Add PhoneChangeRequestedEvent + user.phone_change_otp SMS template
- Add VerifyPhoneChangeHandler with Redis-backed 10-minute OTP
- Re-check phone uniqueness at verify time to catch races
- Extend unit tests for UpdateProfileHandler + add VerifyPhoneChangeHandler spec

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-04-18 00:17:12 +07:00
Ho Ngoc Hai
78e46a024b feat(web): enhance KYC upload with validation, previews, test ids
- Add file type (JPG/PNG/WEBP/PDF) and 5MB size validation
- Show image previews with cleanup of object URLs
- Add data-testid attributes on inputs, buttons, previews, alerts for E2E
- Improve error messaging for expired/failed presigned uploads (403 vs other)
- Guard step 2->3 advance when front image missing

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-04-18 00:06:13 +07:00
Ho Ngoc Hai
b21f197c09 feat(notifications): add Zalo OA webhook controller + WebSocket gateway tests
- Add ZaloOaWebhookController: GET verification endpoint, POST event handler
  for follow/unfollow/user_send_text events with user linking via OAuthAccount
- Register webhook controller in NotificationsModule
- Add 13 unit tests for webhook (challenge verify, follow/unfollow/message
  handling, linked/unlinked users, error resilience)
- Add 18 unit tests for NotificationsGateway (JWT auth, multi-device tracking,
  disconnect cleanup, notification.sent event, Redis cache, unread count)

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-04-16 18:31:02 +07:00
Ho Ngoc Hai
8e9d021465 feat: add unit tests for featured listings, neighborhood scores + price history chart
- Add unit tests for FeatureListingHandler (6 tests) and ActivateFeaturedListingHandler (6 tests)
- Add unit tests for NeighborhoodScoreServiceImpl (5 tests) and GetNeighborhoodScoreHandler (2 tests)
- Add PriceHistoryChart component with recharts LineChart for listing detail page
- Wire up price history API client and integrate chart into listing detail view

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-04-16 18:21:44 +07:00
Ho Ngoc Hai
0dda2bffdb feat(api): add POST /avm/industrial endpoint for industrial rent estimation
Wire NestJS controller to Python AI service's industrial AVM. Adds CQRS
query/handler, Swagger-annotated DTOs, AI client method, and 7 unit tests
covering parameter mapping, response camelCase conversion, and error handling.

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-04-16 18:01:23 +07:00
Ho Ngoc Hai
9eaec46a37 feat(ai-services): AVM v2 residential — expanded features, training pipeline, model versioning
Add neighborhood_score, developer_reputation, floor_level, direction premiums
to the multi-model ensemble. Implement real Optuna-based training pipeline
for XGBoost/LightGBM/CatBoost with grouped train/val/test splits. Add
file-based model registry with rollback and list-versions endpoints.
23 Python tests covering all new features.

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-04-16 17:55:03 +07:00
Ho Ngoc Hai
6cf2c23170 feat(listings): add source field to PriceHistory + unit tests
- Add `source` column to PriceHistory Prisma model (manual_update, admin_override, market_adjustment)
- Add migration for the new column with default 'manual_update'
- Update ListingPriceChangedEvent domain event with optional source parameter
- Update RecordPriceHistoryHandler to persist source
- Update GetPriceHistoryHandler to return source in query results
- Add unit tests for RecordPriceHistoryHandler (5 cases)
- Add unit tests for GetPriceHistoryHandler (3 cases)
- Add ListingPriceChangedEvent tests to domain events spec (4 cases)
- Add getPriceHistory controller tests (2 cases)

All 1805 tests pass, typecheck clean.

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-04-16 17:43:48 +07:00
Ho Ngoc Hai
f3a2a012c4 feat(web): add price range filter and list view to /du-an page
Add minPrice/maxPrice inputs to ProjectFilterBar and introduce a
list view mode alongside the existing grid/map toggle for project
browsing.

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-04-16 17:40:30 +07:00
Ho Ngoc Hai
a6e53e3d06 feat(ai-services): add AVM v2 A/B comparison endpoint and tests
Add POST /avm/v2/compare-v1 endpoint that runs both v1 (single-model)
and v2 (ensemble) AVM predictions on the same property and returns a
side-by-side comparison with price diff, confidence delta, and a
recommendation on which model to prefer.

- ABComparisonRequest/Response schemas in avm_v2 models
- compare_v1() method in AVMv2EnsembleService
- 4 new integration tests for the comparison endpoint
- All 47 Python tests pass

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-04-16 17:35:30 +07:00
60 changed files with 4009 additions and 112 deletions

View File

@@ -5,6 +5,7 @@ import { TrackEventHandler } from './application/commands/track-event/track-even
import { UpdateMarketIndexHandler } from './application/commands/update-market-index/update-market-index.handler';
import { ListingCreatedModerationHandler } from './application/event-handlers/listing-created-moderation.handler';
import { BatchValuationHandler } from './application/queries/batch-valuation/batch-valuation.handler';
import { IndustrialValuationHandler } from './application/queries/industrial-valuation/industrial-valuation.handler';
import { GetDistrictStatsHandler } from './application/queries/get-district-stats/get-district-stats.handler';
import { GetHeatmapHandler } from './application/queries/get-heatmap/get-heatmap.handler';
import { GetMarketReportHandler } from './application/queries/get-market-report/get-market-report.handler';
@@ -25,6 +26,7 @@ import { MarketIndexCronService } from './infrastructure/services/market-index-c
import { NeighborhoodScoreServiceImpl } from './infrastructure/services/neighborhood-score.service';
import { PrismaAVMService } from './infrastructure/services/prisma-avm.service';
import { AnalyticsController } from './presentation/controllers/analytics.controller';
import { AvmController } from './presentation/controllers/avm.controller';
const CommandHandlers = [
TrackEventHandler,
@@ -42,6 +44,7 @@ const QueryHandlers = [
ValuationHistoryHandler,
ValuationComparisonHandler,
GetNeighborhoodScoreHandler,
IndustrialValuationHandler,
];
const EventHandlers = [
@@ -50,7 +53,7 @@ const EventHandlers = [
@Module({
imports: [CqrsModule],
controllers: [AnalyticsController],
controllers: [AnalyticsController, AvmController],
providers: [
// AI service client
{ provide: AI_SERVICE_CLIENT, useClass: AiServiceClient },

View File

@@ -0,0 +1,51 @@
import { type INeighborhoodScoreService, type NeighborhoodScoreResult } from '../../domain/services/neighborhood-score.service';
import { GetNeighborhoodScoreHandler } from '../queries/get-neighborhood-score/get-neighborhood-score.handler';
import { GetNeighborhoodScoreQuery } from '../queries/get-neighborhood-score/get-neighborhood-score.query';
const sampleScore: NeighborhoodScoreResult = {
district: 'Quận 1',
city: 'Hồ Chí Minh',
educationScore: 8,
healthcareScore: 7,
transportScore: 9,
shoppingScore: 6,
greeneryScore: 5,
safetyScore: 4,
totalScore: 68.5,
poiCounts: { education: 12, healthcare: 5, transport: 10, shopping: 6, greenery: 3, safety: 2 },
calculatedAt: new Date(),
};
describe('GetNeighborhoodScoreHandler', () => {
let handler: GetNeighborhoodScoreHandler;
let mockService: { [K in keyof INeighborhoodScoreService]: ReturnType<typeof vi.fn> };
beforeEach(() => {
mockService = {
getScore: vi.fn(),
calculateAndSave: vi.fn(),
};
handler = new GetNeighborhoodScoreHandler(mockService as any);
});
it('returns cached score when available', async () => {
mockService.getScore.mockResolvedValue(sampleScore);
const result = await handler.execute(new GetNeighborhoodScoreQuery('Quận 1', 'Hồ Chí Minh'));
expect(result).toEqual(sampleScore);
expect(mockService.getScore).toHaveBeenCalledWith('Quận 1', 'Hồ Chí Minh');
expect(mockService.calculateAndSave).not.toHaveBeenCalled();
});
it('calculates and saves score when no cached score exists', async () => {
mockService.getScore.mockResolvedValue(null);
mockService.calculateAndSave.mockResolvedValue(sampleScore);
const result = await handler.execute(new GetNeighborhoodScoreQuery('Quận 2', 'Hồ Chí Minh'));
expect(result).toEqual(sampleScore);
expect(mockService.getScore).toHaveBeenCalledWith('Quận 2', 'Hồ Chí Minh');
expect(mockService.calculateAndSave).toHaveBeenCalledWith('Quận 2', 'Hồ Chí Minh');
});
});

View File

@@ -0,0 +1,141 @@
import { InternalServerErrorException } from '@nestjs/common';
import { type IAiServiceClient } from '../../../../infrastructure/services/ai-service.client';
import { IndustrialValuationHandler } from '../industrial-valuation.handler';
import { IndustrialValuationQuery } from '../industrial-valuation.query';
describe('IndustrialValuationHandler', () => {
let handler: IndustrialValuationHandler;
let mockAiClient: { predictIndustrial: ReturnType<typeof vi.fn> };
let mockLogger: { error: ReturnType<typeof vi.fn> };
const query = new IndustrialValuationQuery(
'Bình Dương',
'south',
0.85,
500,
10,
25,
40,
5,
'factory',
5000,
10,
3,
2000,
0.6,
4,
'general_industrial',
0.7,
3000,
8000000,
0.75,
);
const aiResponse = {
estimated_rent_usd_m2: 5.2,
confidence: 0.65,
rent_range_low_usd_m2: 4.16,
rent_range_high_usd_m2: 6.24,
annual_rent_usd_m2: 62.4,
total_monthly_rent_usd: 26000,
comparables: [
{
park_name: 'VSIP I',
province: 'Bình Dương',
property_type: 'factory',
area_m2: 5000,
rent_usd_m2: 5.2,
similarity_score: 0.85,
},
],
drivers: [
{ feature: 'province_baseline', importance: 0.16 },
{ feature: 'property_type', importance: 0.12 },
],
model_version: 'heuristic-v1',
};
beforeEach(() => {
mockAiClient = { predictIndustrial: vi.fn() };
mockLogger = { error: vi.fn() };
handler = new IndustrialValuationHandler(
mockAiClient as unknown as IAiServiceClient,
mockLogger as any,
);
});
it('calls AI service with correct snake_case parameters', async () => {
mockAiClient.predictIndustrial.mockResolvedValue(aiResponse);
await handler.execute(query);
expect(mockAiClient.predictIndustrial).toHaveBeenCalledWith({
province: 'Bình Dương',
region: 'south',
park_occupancy_rate: 0.85,
park_area_ha: 500,
park_age_years: 10,
distance_to_port_km: 25,
distance_to_airport_km: 40,
distance_to_highway_km: 5,
property_type: 'factory',
area_m2: 5000,
ceiling_height_m: 10,
floor_load_ton_m2: 3,
power_capacity_kva: 2000,
building_coverage: 0.6,
loading_docks: 4,
zoning: 'general_industrial',
industry_demand_index: 0.7,
fdi_province_musd: 3000,
labor_cost_province_vnd: 8000000,
logistics_connectivity_score: 0.75,
});
});
it('maps AI response to camelCase DTO', async () => {
mockAiClient.predictIndustrial.mockResolvedValue(aiResponse);
const result = await handler.execute(query);
expect(result.estimatedRentUsdM2).toBe(5.2);
expect(result.confidence).toBe(0.65);
expect(result.rentRangeLowUsdM2).toBe(4.16);
expect(result.rentRangeHighUsdM2).toBe(6.24);
expect(result.annualRentUsdM2).toBe(62.4);
expect(result.totalMonthlyRentUsd).toBe(26000);
expect(result.modelVersion).toBe('heuristic-v1');
});
it('maps comparable properties to camelCase', async () => {
mockAiClient.predictIndustrial.mockResolvedValue(aiResponse);
const result = await handler.execute(query);
expect(result.comparables).toHaveLength(1);
expect(result.comparables[0]).toEqual({
parkName: 'VSIP I',
province: 'Bình Dương',
propertyType: 'factory',
areaM2: 5000,
rentUsdM2: 5.2,
similarityScore: 0.85,
});
});
it('maps drivers array', async () => {
mockAiClient.predictIndustrial.mockResolvedValue(aiResponse);
const result = await handler.execute(query);
expect(result.drivers).toHaveLength(2);
expect(result.drivers[0]).toEqual({ feature: 'province_baseline', importance: 0.16 });
});
it('throws InternalServerErrorException on AI service failure', async () => {
mockAiClient.predictIndustrial.mockRejectedValue(new Error('AI service down'));
await expect(handler.execute(query)).rejects.toThrow(InternalServerErrorException);
expect(mockLogger.error).toHaveBeenCalled();
});
});

View File

@@ -0,0 +1,106 @@
import { Inject, InternalServerErrorException } from '@nestjs/common';
import { QueryHandler, type IQueryHandler } from '@nestjs/cqrs';
import { DomainException, type LoggerService } from '@modules/shared';
import {
AI_SERVICE_CLIENT,
type IAiServiceClient,
type AiIndustrialPredictResponse,
} from '../../../infrastructure/services/ai-service.client';
import { IndustrialValuationQuery } from './industrial-valuation.query';
export interface IndustrialValuationComparable {
parkName: string;
province: string;
propertyType: string;
areaM2: number;
rentUsdM2: number;
similarityScore: number;
}
export interface IndustrialValuationDriver {
feature: string;
importance: number;
}
export interface IndustrialValuationDto {
estimatedRentUsdM2: number;
confidence: number;
rentRangeLowUsdM2: number;
rentRangeHighUsdM2: number;
annualRentUsdM2: number;
totalMonthlyRentUsd: number;
comparables: IndustrialValuationComparable[];
drivers: IndustrialValuationDriver[];
modelVersion: string;
}
function mapResponse(res: AiIndustrialPredictResponse): IndustrialValuationDto {
return {
estimatedRentUsdM2: res.estimated_rent_usd_m2,
confidence: res.confidence,
rentRangeLowUsdM2: res.rent_range_low_usd_m2,
rentRangeHighUsdM2: res.rent_range_high_usd_m2,
annualRentUsdM2: res.annual_rent_usd_m2,
totalMonthlyRentUsd: res.total_monthly_rent_usd,
comparables: res.comparables.map((c) => ({
parkName: c.park_name,
province: c.province,
propertyType: c.property_type,
areaM2: c.area_m2,
rentUsdM2: c.rent_usd_m2,
similarityScore: c.similarity_score,
})),
drivers: res.drivers.map((d) => ({
feature: d.feature,
importance: d.importance,
})),
modelVersion: res.model_version,
};
}
@QueryHandler(IndustrialValuationQuery)
export class IndustrialValuationHandler implements IQueryHandler<IndustrialValuationQuery> {
constructor(
@Inject(AI_SERVICE_CLIENT) private readonly aiClient: IAiServiceClient,
private readonly logger: LoggerService,
) {}
async execute(query: IndustrialValuationQuery): Promise<IndustrialValuationDto> {
try {
const response = await this.aiClient.predictIndustrial({
province: query.province,
region: query.region,
park_occupancy_rate: query.parkOccupancyRate,
park_area_ha: query.parkAreaHa,
park_age_years: query.parkAgeYears,
distance_to_port_km: query.distanceToPortKm,
distance_to_airport_km: query.distanceToAirportKm,
distance_to_highway_km: query.distanceToHighwayKm,
property_type: query.propertyType,
area_m2: query.areaM2,
ceiling_height_m: query.ceilingHeightM,
floor_load_ton_m2: query.floorLoadTonM2,
power_capacity_kva: query.powerCapacityKva,
building_coverage: query.buildingCoverage,
loading_docks: query.loadingDocks,
zoning: query.zoning,
industry_demand_index: query.industryDemandIndex,
fdi_province_musd: query.fdiProvinceMusd,
labor_cost_province_vnd: query.laborCostProvinceVnd,
logistics_connectivity_score: query.logisticsConnectivityScore,
});
return mapResponse(response);
} catch (error) {
if (error instanceof DomainException) throw error;
this.logger.error(
`Failed to estimate industrial rent: ${error instanceof Error ? error.message : error}`,
error instanceof Error ? error.stack : undefined,
this.constructor.name,
);
throw new InternalServerErrorException(
'Không thể ước tính giá thuê khu công nghiệp. Vui lòng thử lại sau.',
);
}
}
}

View File

@@ -0,0 +1,24 @@
export class IndustrialValuationQuery {
constructor(
public readonly province: string,
public readonly region: string,
public readonly parkOccupancyRate: number,
public readonly parkAreaHa: number,
public readonly parkAgeYears: number,
public readonly distanceToPortKm: number,
public readonly distanceToAirportKm: number,
public readonly distanceToHighwayKm: number,
public readonly propertyType: string,
public readonly areaM2: number,
public readonly ceilingHeightM?: number,
public readonly floorLoadTonM2?: number,
public readonly powerCapacityKva?: number,
public readonly buildingCoverage?: number,
public readonly loadingDocks?: number,
public readonly zoning?: string,
public readonly industryDemandIndex?: number,
public readonly fdiProvinceMusd?: number,
public readonly laborCostProvinceVnd?: number,
public readonly logisticsConnectivityScore?: number,
) {}
}

View File

@@ -0,0 +1,132 @@
import { NeighborhoodScoreServiceImpl } from '../services/neighborhood-score.service';
describe('NeighborhoodScoreServiceImpl', () => {
let service: NeighborhoodScoreServiceImpl;
let mockPrisma: {
neighborhoodScore: { findUnique: ReturnType<typeof vi.fn>; upsert: ReturnType<typeof vi.fn> };
pOI: { count: ReturnType<typeof vi.fn> };
};
let mockLogger: { log: ReturnType<typeof vi.fn> };
beforeEach(() => {
mockPrisma = {
neighborhoodScore: {
findUnique: vi.fn(),
upsert: vi.fn(),
},
pOI: { count: vi.fn() },
};
mockLogger = { log: vi.fn() };
service = new NeighborhoodScoreServiceImpl(mockPrisma as any, mockLogger as any);
});
describe('getScore', () => {
it('returns existing score from database', async () => {
const stored = {
district: 'Quận 1',
city: 'Hồ Chí Minh',
educationScore: 8,
healthcareScore: 7,
transportScore: 9,
shoppingScore: 6,
greeneryScore: 5,
safetyScore: 4,
totalScore: 68.5,
poiCounts: { education: 12, healthcare: 5 },
calculatedAt: new Date(),
};
mockPrisma.neighborhoodScore.findUnique.mockResolvedValue(stored);
const result = await service.getScore('Quận 1', 'Hồ Chí Minh');
expect(result).not.toBeNull();
expect(result!.district).toBe('Quận 1');
expect(result!.totalScore).toBe(68.5);
expect(result!.poiCounts).toEqual({ education: 12, healthcare: 5 });
});
it('returns null when no score exists', async () => {
mockPrisma.neighborhoodScore.findUnique.mockResolvedValue(null);
const result = await service.getScore('Quận 99', 'Hồ Chí Minh');
expect(result).toBeNull();
});
});
describe('calculateAndSave', () => {
it('calculates scores from POI counts and upserts', async () => {
// Simulate POI counts: education=15 (max), healthcare=4 (50%), transport=6 (50%),
// shopping=5 (50%), greenery=3 (50%), safety=2 (50%)
const poiCountsByCategory = [15, 4, 6, 5, 3, 2];
let callIndex = 0;
mockPrisma.pOI.count.mockImplementation(() => {
return Promise.resolve(poiCountsByCategory[callIndex++]!);
});
mockPrisma.neighborhoodScore.upsert.mockImplementation(({ create }) => {
return Promise.resolve(create);
});
const result = await service.calculateAndSave('Quận 1', 'Hồ Chí Minh');
// education: 15/15 * 10 = 10 → 10 * 20/10 = 20
// healthcare: 4/8 * 10 = 5 → 5 * 20/10 = 10
// transport: 6/12 * 10 = 5 → 5 * 20/10 = 10
// shopping: 5/10 * 10 = 5 → 5 * 15/10 = 7.5
// greenery: 3/6 * 10 = 5 → 5 * 15/10 = 7.5
// safety: 2/4 * 10 = 5 → 5 * 10/10 = 5
// total = 20 + 10 + 10 + 7.5 + 7.5 + 5 = 60
expect(result.educationScore).toBe(10);
expect(result.healthcareScore).toBe(5);
expect(result.totalScore).toBe(60);
expect(mockPrisma.neighborhoodScore.upsert).toHaveBeenCalledTimes(1);
});
it('caps category scores at 10', async () => {
// All categories have way more POIs than max
mockPrisma.pOI.count.mockResolvedValue(100);
mockPrisma.neighborhoodScore.upsert.mockImplementation(({ create }) => {
return Promise.resolve(create);
});
const result = await service.calculateAndSave('Quận 1', 'Hồ Chí Minh');
// All scores capped at 10 → total = sum of weights = 100
expect(result.educationScore).toBe(10);
expect(result.healthcareScore).toBe(10);
expect(result.transportScore).toBe(10);
expect(result.shoppingScore).toBe(10);
expect(result.greeneryScore).toBe(10);
expect(result.safetyScore).toBe(10);
expect(result.totalScore).toBe(100);
});
it('returns 0 scores when no POIs exist', async () => {
mockPrisma.pOI.count.mockResolvedValue(0);
mockPrisma.neighborhoodScore.upsert.mockImplementation(({ create }) => {
return Promise.resolve(create);
});
const result = await service.calculateAndSave('Quận 1', 'Hồ Chí Minh');
expect(result.educationScore).toBe(0);
expect(result.totalScore).toBe(0);
});
it('logs the calculated score', async () => {
mockPrisma.pOI.count.mockResolvedValue(5);
mockPrisma.neighborhoodScore.upsert.mockImplementation(({ create }) => {
return Promise.resolve(create);
});
await service.calculateAndSave('Quận 1', 'Hồ Chí Minh');
expect(mockLogger.log).toHaveBeenCalledWith(
expect.stringContaining('Quận 1'),
'NeighborhoodScoreService',
);
});
});
});

View File

@@ -23,6 +23,55 @@ export interface AiPredictResponse {
price_range_high: number;
}
export interface AiIndustrialPredictRequest {
province: string;
region: string;
park_occupancy_rate: number;
park_area_ha: number;
park_age_years: number;
distance_to_port_km: number;
distance_to_airport_km: number;
distance_to_highway_km: number;
property_type: string;
area_m2: number;
ceiling_height_m?: number;
floor_load_ton_m2?: number;
power_capacity_kva?: number;
building_coverage?: number;
loading_docks?: number;
zoning?: string;
industry_demand_index?: number;
fdi_province_musd?: number;
labor_cost_province_vnd?: number;
logistics_connectivity_score?: number;
}
export interface AiIndustrialComparable {
park_name: string;
province: string;
property_type: string;
area_m2: number;
rent_usd_m2: number;
similarity_score: number;
}
export interface AiIndustrialFeatureImportance {
feature: string;
importance: number;
}
export interface AiIndustrialPredictResponse {
estimated_rent_usd_m2: number;
confidence: number;
rent_range_low_usd_m2: number;
rent_range_high_usd_m2: number;
annual_rent_usd_m2: number;
total_monthly_rent_usd: number;
comparables: AiIndustrialComparable[];
drivers: AiIndustrialFeatureImportance[];
model_version: string;
}
export interface AiModerationRequest {
text: string;
context?: string;
@@ -46,6 +95,7 @@ export const AI_SERVICE_CLIENT = Symbol('AI_SERVICE_CLIENT');
export interface IAiServiceClient {
predict(req: AiPredictRequest): Promise<AiPredictResponse>;
predictIndustrial(req: AiIndustrialPredictRequest): Promise<AiIndustrialPredictResponse>;
moderate(req: AiModerationRequest): Promise<AiModerationResponse>;
isAvailable(): Promise<boolean>;
}
@@ -66,6 +116,10 @@ export class AiServiceClient implements IAiServiceClient {
return this.post<AiPredictResponse>('/avm/predict', req);
}
async predictIndustrial(req: AiIndustrialPredictRequest): Promise<AiIndustrialPredictResponse> {
return this.post<AiIndustrialPredictResponse>('/avm/industrial/predict', req);
}
async moderate(req: AiModerationRequest): Promise<AiModerationResponse> {
return this.post<AiModerationResponse>('/moderation/check', req);
}

View File

@@ -0,0 +1,187 @@
import { type QueryBus } from '@nestjs/cqrs';
import { BatchValuationQuery } from '../../application/queries/batch-valuation/batch-valuation.query';
import { IndustrialValuationQuery } from '../../application/queries/industrial-valuation/industrial-valuation.query';
import { ValuationComparisonQuery } from '../../application/queries/valuation-comparison/valuation-comparison.query';
import { ValuationHistoryQuery } from '../../application/queries/valuation-history/valuation-history.query';
import { AvmController } from '../controllers/avm.controller';
describe('AvmController', () => {
let controller: AvmController;
let mockQueryBus: { execute: ReturnType<typeof vi.fn> };
beforeEach(() => {
mockQueryBus = { execute: vi.fn() };
controller = new AvmController(mockQueryBus as unknown as QueryBus);
});
describe('POST /avm/batch', () => {
it('dispatches BatchValuationQuery with property IDs', async () => {
const expected = {
results: [
{ propertyId: 'prop-1', valuation: { estimatedPrice: '5000000000' } },
{ propertyId: 'prop-2', valuation: { estimatedPrice: '6000000000' } },
],
};
mockQueryBus.execute.mockResolvedValue(expected);
const result = await controller.batchValuation({
propertyIds: ['prop-1', 'prop-2'],
} as any);
expect(mockQueryBus.execute).toHaveBeenCalledWith(
new BatchValuationQuery(['prop-1', 'prop-2']),
);
expect(result).toBe(expected);
});
});
describe('GET /avm/history/:propertyId', () => {
it('dispatches ValuationHistoryQuery with propertyId and limit', async () => {
const expected = { propertyId: 'prop-1', history: [], totalRecords: 0 };
mockQueryBus.execute.mockResolvedValue(expected);
const result = await controller.getHistory('prop-1', { limit: 25 } as any);
expect(mockQueryBus.execute).toHaveBeenCalledWith(
new ValuationHistoryQuery('prop-1', 25),
);
expect(result).toBe(expected);
});
it('defaults limit to 50 when not provided', async () => {
const expected = { propertyId: 'prop-1', history: [], totalRecords: 0 };
mockQueryBus.execute.mockResolvedValue(expected);
const result = await controller.getHistory('prop-1', {} as any);
expect(mockQueryBus.execute).toHaveBeenCalledWith(
new ValuationHistoryQuery('prop-1', 50),
);
expect(result).toBe(expected);
});
});
describe('GET /avm/compare', () => {
it('dispatches ValuationComparisonQuery with parsed IDs', async () => {
const expected = {
properties: [],
summary: { highestValue: null, lowestValue: null, averagePricePerM2: 0, averageConfidence: 0 },
};
mockQueryBus.execute.mockResolvedValue(expected);
const result = await controller.compare({
ids: ['prop-1', 'prop-2', 'prop-3'],
} as any);
expect(mockQueryBus.execute).toHaveBeenCalledWith(
new ValuationComparisonQuery(['prop-1', 'prop-2', 'prop-3']),
);
expect(result).toBe(expected);
});
it('handles two property IDs (minimum)', async () => {
const expected = { properties: [], summary: {} };
mockQueryBus.execute.mockResolvedValue(expected);
const result = await controller.compare({
ids: ['prop-1', 'prop-2'],
} as any);
expect(mockQueryBus.execute).toHaveBeenCalledWith(
new ValuationComparisonQuery(['prop-1', 'prop-2']),
);
expect(result).toBe(expected);
});
});
describe('POST /avm/industrial', () => {
const industrialDto = {
province: 'Bình Dương',
region: 'south',
parkOccupancyRate: 0.85,
parkAreaHa: 500,
parkAgeYears: 10,
distanceToPortKm: 25,
distanceToAirportKm: 40,
distanceToHighwayKm: 5,
propertyType: 'factory',
areaM2: 5000,
ceilingHeightM: 10,
loadingDocks: 4,
zoning: 'general_industrial',
};
it('dispatches IndustrialValuationQuery with all required fields', async () => {
const expected = {
estimatedRentUsdM2: 5.2,
confidence: 0.65,
rentRangeLowUsdM2: 4.16,
rentRangeHighUsdM2: 6.24,
annualRentUsdM2: 62.4,
totalMonthlyRentUsd: 26000,
comparables: [],
drivers: [],
modelVersion: 'heuristic-v1',
};
mockQueryBus.execute.mockResolvedValue(expected);
const result = await controller.industrialValuation(industrialDto as any);
expect(mockQueryBus.execute).toHaveBeenCalledWith(
new IndustrialValuationQuery(
'Bình Dương',
'south',
0.85,
500,
10,
25,
40,
5,
'factory',
5000,
10,
undefined,
undefined,
undefined,
4,
'general_industrial',
undefined,
undefined,
undefined,
undefined,
),
);
expect(result).toBe(expected);
});
it('passes optional fields when provided', async () => {
const fullDto = {
...industrialDto,
floorLoadTonM2: 3,
powerCapacityKva: 2000,
buildingCoverage: 0.6,
industryDemandIndex: 0.7,
fdiProvinceMusd: 3000,
laborCostProvinceVnd: 8000000,
logisticsConnectivityScore: 0.75,
};
const expected = {
estimatedRentUsdM2: 5.8,
confidence: 0.72,
comparables: [],
drivers: [],
};
mockQueryBus.execute.mockResolvedValue(expected);
const result = await controller.industrialValuation(fullDto as any);
const call = mockQueryBus.execute.mock.calls[0]![0] as IndustrialValuationQuery;
expect(call.province).toBe('Bình Dương');
expect(call.floorLoadTonM2).toBe(3);
expect(call.powerCapacityKva).toBe(2000);
expect(call.buildingCoverage).toBe(0.6);
expect(call.logisticsConnectivityScore).toBe(0.75);
expect(result).toBe(expected);
});
});
});

View File

@@ -0,0 +1,126 @@
import {
Body,
Controller,
Get,
Param,
Post,
Query,
UseGuards,
} from '@nestjs/common';
import { type QueryBus } from '@nestjs/cqrs';
import { ApiTags, ApiOperation, ApiResponse, ApiBearerAuth, ApiParam, ApiQuery } from '@nestjs/swagger';
import { JwtAuthGuard } from '@modules/auth';
import { EndpointRateLimit, EndpointRateLimitGuard } from '@modules/shared';
import { RequireQuota, QuotaGuard } from '@modules/subscriptions';
import { type BatchValuationDto as BatchValuationResultDto } from '../../application/queries/batch-valuation/batch-valuation.handler';
import { BatchValuationQuery } from '../../application/queries/batch-valuation/batch-valuation.query';
import { type IndustrialValuationDto as IndustrialValuationResultDto } from '../../application/queries/industrial-valuation/industrial-valuation.handler';
import { IndustrialValuationQuery } from '../../application/queries/industrial-valuation/industrial-valuation.query';
import { type ValuationComparisonDto as ValuationComparisonResultDto } from '../../application/queries/valuation-comparison/valuation-comparison.handler';
import { ValuationComparisonQuery } from '../../application/queries/valuation-comparison/valuation-comparison.query';
import { type ValuationHistoryDto as ValuationHistoryResultDto } from '../../application/queries/valuation-history/valuation-history.handler';
import { ValuationHistoryQuery } from '../../application/queries/valuation-history/valuation-history.query';
import { type AvmCompareQueryDto } from '../dto/avm-compare-query.dto';
import { type BatchValuationDto } from '../dto/batch-valuation.dto';
import { type IndustrialValuationDto } from '../dto/industrial-valuation.dto';
import { type ValuationHistoryDto } from '../dto/valuation-history.dto';
@ApiTags('avm')
@Controller('avm')
export class AvmController {
constructor(
private readonly queryBus: QueryBus,
) {}
@ApiBearerAuth('JWT')
@EndpointRateLimit({ limit: 10, windowSeconds: 60, keyStrategy: 'user' })
@UseGuards(EndpointRateLimitGuard, JwtAuthGuard, QuotaGuard)
@RequireQuota('analytics_queries')
@Post('batch')
@ApiOperation({ summary: 'Batch valuation for multiple properties (max 50)' })
@ApiResponse({ status: 200, description: 'Batch valuation results' })
@ApiResponse({ status: 400, description: 'Invalid parameters' })
@ApiResponse({ status: 403, description: 'Quota exceeded' })
@ApiResponse({ status: 429, description: 'Rate limit exceeded — max 10 requests per 60s' })
async batchValuation(@Body() dto: BatchValuationDto): Promise<BatchValuationResultDto> {
return this.queryBus.execute(
new BatchValuationQuery(dto.propertyIds),
);
}
@ApiBearerAuth('JWT')
@UseGuards(JwtAuthGuard, QuotaGuard)
@RequireQuota('analytics_queries')
@Get('history/:propertyId')
@ApiOperation({ summary: 'Get valuation history for a property (time-series)' })
@ApiParam({ name: 'propertyId', description: 'Property ID', example: 'prop-123' })
@ApiResponse({ status: 200, description: 'Valuation history time-series data' })
@ApiResponse({ status: 403, description: 'Quota exceeded' })
async getHistory(
@Param('propertyId') propertyId: string,
@Query() dto: ValuationHistoryDto,
): Promise<ValuationHistoryResultDto> {
return this.queryBus.execute(
new ValuationHistoryQuery(propertyId, dto.limit ?? 50),
);
}
@ApiBearerAuth('JWT')
@EndpointRateLimit({ limit: 10, windowSeconds: 60, keyStrategy: 'user' })
@UseGuards(EndpointRateLimitGuard, JwtAuthGuard, QuotaGuard)
@RequireQuota('analytics_queries')
@Get('compare')
@ApiOperation({ summary: 'Compare valuations for 2-5 properties side by side' })
@ApiQuery({
name: 'ids',
description: 'Comma-separated property IDs (2-5)',
example: 'prop-1,prop-2,prop-3',
type: String,
})
@ApiResponse({ status: 200, description: 'Normalized comparison data for UI' })
@ApiResponse({ status: 400, description: 'Invalid parameters — provide 2-5 property IDs' })
@ApiResponse({ status: 403, description: 'Quota exceeded' })
@ApiResponse({ status: 429, description: 'Rate limit exceeded — max 10 requests per 60s' })
async compare(@Query() dto: AvmCompareQueryDto): Promise<ValuationComparisonResultDto> {
return this.queryBus.execute(
new ValuationComparisonQuery(dto.ids),
);
}
@ApiBearerAuth('JWT')
@EndpointRateLimit({ limit: 10, windowSeconds: 60, keyStrategy: 'user' })
@UseGuards(EndpointRateLimitGuard, JwtAuthGuard, QuotaGuard)
@RequireQuota('analytics_queries')
@Post('industrial')
@ApiOperation({ summary: 'Estimate industrial property rent using AI model' })
@ApiResponse({ status: 200, description: 'Industrial rent estimation with comparables and drivers' })
@ApiResponse({ status: 400, description: 'Invalid parameters' })
@ApiResponse({ status: 403, description: 'Quota exceeded' })
@ApiResponse({ status: 429, description: 'Rate limit exceeded — max 10 requests per 60s' })
async industrialValuation(@Body() dto: IndustrialValuationDto): Promise<IndustrialValuationResultDto> {
return this.queryBus.execute(
new IndustrialValuationQuery(
dto.province,
dto.region,
dto.parkOccupancyRate,
dto.parkAreaHa,
dto.parkAgeYears,
dto.distanceToPortKm,
dto.distanceToAirportKm,
dto.distanceToHighwayKm,
dto.propertyType,
dto.areaM2,
dto.ceilingHeightM,
dto.floorLoadTonM2,
dto.powerCapacityKva,
dto.buildingCoverage,
dto.loadingDocks,
dto.zoning,
dto.industryDemandIndex,
dto.fdiProvinceMusd,
dto.laborCostProvinceVnd,
dto.logisticsConnectivityScore,
),
);
}
}

View File

@@ -1 +1,2 @@
export { AnalyticsController } from './analytics.controller';
export { AvmController } from './avm.controller';

View File

@@ -0,0 +1,19 @@
import { ApiProperty } from '@nestjs/swagger';
import { Transform } from 'class-transformer';
import { ArrayMaxSize, ArrayMinSize, IsArray, IsString } from 'class-validator';
export class AvmCompareQueryDto {
@ApiProperty({
description: 'Comma-separated property IDs to compare (2-5)',
example: 'prop-1,prop-2,prop-3',
type: String,
})
@Transform(({ value }) =>
typeof value === 'string' ? value.split(',').map((s: string) => s.trim()).filter(Boolean) : value,
)
@IsArray()
@ArrayMinSize(2)
@ArrayMaxSize(5)
@IsString({ each: true })
ids!: string[];
}

View File

@@ -6,3 +6,5 @@ export { GetValuationDto } from './get-valuation.dto';
export { BatchValuationDto } from './batch-valuation.dto';
export { ValuationHistoryDto } from './valuation-history.dto';
export { ValuationComparisonDto } from './valuation-comparison.dto';
export { AvmCompareQueryDto } from './avm-compare-query.dto';
export { IndustrialValuationDto } from './industrial-valuation.dto';

View File

@@ -0,0 +1,139 @@
import { ApiProperty, ApiPropertyOptional } from '@nestjs/swagger';
import { Type } from 'class-transformer';
import { IsString, IsNumber, Min, Max, IsOptional } from 'class-validator';
export class IndustrialValuationDto {
@ApiProperty({ description: 'Province name (e.g. Bình Dương)', example: 'Bình Dương' })
@IsString()
province!: string;
@ApiProperty({ description: 'Region: south, north, central, mekong_delta', example: 'south' })
@IsString()
region!: string;
@ApiProperty({ description: 'Park occupancy rate (0-1)', example: 0.85 })
@IsNumber()
@Type(() => Number)
@Min(0)
@Max(1)
parkOccupancyRate!: number;
@ApiProperty({ description: 'Total park area in hectares', example: 500 })
@IsNumber()
@Type(() => Number)
@Min(0)
parkAreaHa!: number;
@ApiProperty({ description: 'Park age in years', example: 10 })
@IsNumber()
@Type(() => Number)
@Min(0)
parkAgeYears!: number;
@ApiProperty({ description: 'Distance to nearest seaport in km', example: 25 })
@IsNumber()
@Type(() => Number)
@Min(0)
distanceToPortKm!: number;
@ApiProperty({ description: 'Distance to nearest airport in km', example: 40 })
@IsNumber()
@Type(() => Number)
@Min(0)
distanceToAirportKm!: number;
@ApiProperty({ description: 'Distance to nearest highway in km', example: 5 })
@IsNumber()
@Type(() => Number)
@Min(0)
distanceToHighwayKm!: number;
@ApiProperty({
description: 'Industrial property type',
example: 'factory',
enum: ['warehouse', 'factory', 'ready_built_factory', 'ready_built_warehouse', 'open_yard', 'office_in_park'],
})
@IsString()
propertyType!: string;
@ApiProperty({ description: 'Leasable area in m²', example: 5000 })
@IsNumber()
@Type(() => Number)
@Min(1)
areaM2!: number;
@ApiPropertyOptional({ description: 'Ceiling height in meters', example: 10 })
@IsOptional()
@IsNumber()
@Type(() => Number)
@Min(0)
ceilingHeightM?: number;
@ApiPropertyOptional({ description: 'Floor load capacity in tons/m²', example: 3 })
@IsOptional()
@IsNumber()
@Type(() => Number)
@Min(0)
floorLoadTonM2?: number;
@ApiPropertyOptional({ description: 'Power capacity in kVA', example: 2000 })
@IsOptional()
@IsNumber()
@Type(() => Number)
@Min(0)
powerCapacityKva?: number;
@ApiPropertyOptional({ description: 'Building coverage ratio (0-1)', example: 0.6 })
@IsOptional()
@IsNumber()
@Type(() => Number)
@Min(0)
@Max(1)
buildingCoverage?: number;
@ApiPropertyOptional({ description: 'Number of loading docks', example: 4 })
@IsOptional()
@IsNumber()
@Type(() => Number)
@Min(0)
loadingDocks?: number;
@ApiPropertyOptional({
description: 'Industrial zoning category',
example: 'general_industrial',
enum: ['general_industrial', 'heavy_industrial', 'light_industrial', 'logistics', 'free_trade_zone', 'high_tech'],
})
@IsOptional()
@IsString()
zoning?: string;
@ApiPropertyOptional({ description: 'Local industry demand index (0-1)', example: 0.7 })
@IsOptional()
@IsNumber()
@Type(() => Number)
@Min(0)
@Max(1)
industryDemandIndex?: number;
@ApiPropertyOptional({ description: 'Province FDI inflow in million USD', example: 3000 })
@IsOptional()
@IsNumber()
@Type(() => Number)
@Min(0)
fdiProvinceMusd?: number;
@ApiPropertyOptional({ description: 'Average province labor cost in VND/month', example: 8000000 })
@IsOptional()
@IsNumber()
@Type(() => Number)
@Min(0)
laborCostProvinceVnd?: number;
@ApiPropertyOptional({ description: 'Logistics connectivity score (0-1)', example: 0.7 })
@IsOptional()
@IsNumber()
@Type(() => Number)
@Min(0)
@Max(1)
logisticsConnectivityScore?: number;
}

View File

@@ -1,8 +1,8 @@
import { type IMediaStorageService } from '../../../../listings/infrastructure/services/media-storage.service';
import { UserEntity } from '../../domain/entities/user.entity';
import { type IUserRepository } from '../../domain/repositories/user.repository';
import { type HashedPassword } from '../../domain/value-objects/hashed-password.vo';
import { Phone } from '../../domain/value-objects/phone.vo';
import { type IMediaStorageService } from '../../../../listings/infrastructure/services/media-storage.service';
import { GenerateKycUploadUrlsCommand } from '../commands/generate-kyc-upload-urls/generate-kyc-upload-urls.command';
import { GenerateKycUploadUrlsHandler } from '../commands/generate-kyc-upload-urls/generate-kyc-upload-urls.handler';
@@ -42,6 +42,7 @@ describe('GenerateKycUploadUrlsHandler', () => {
getPresignedUploadUrl: vi.fn(),
generatePresignedUpload: vi.fn(),
getPublicUrl: vi.fn(),
isTrustedUrl: vi.fn().mockReturnValue(true),
};
mockLogger = {
error: vi.fn(),

View File

@@ -1,8 +1,8 @@
import { type IMediaStorageService } from '../../../../listings/infrastructure/services/media-storage.service';
import { UserEntity } from '../../domain/entities/user.entity';
import { type IUserRepository } from '../../domain/repositories/user.repository';
import { type HashedPassword } from '../../domain/value-objects/hashed-password.vo';
import { Phone } from '../../domain/value-objects/phone.vo';
import { type IMediaStorageService } from '../../../../listings/infrastructure/services/media-storage.service';
import { SubmitKycCommand } from '../commands/submit-kyc/submit-kyc.command';
import { SubmitKycHandler } from '../commands/submit-kyc/submit-kyc.handler';
@@ -43,6 +43,7 @@ describe('SubmitKycHandler', () => {
getPresignedUploadUrl: vi.fn(),
generatePresignedUpload: vi.fn(),
getPublicUrl: vi.fn(),
isTrustedUrl: vi.fn().mockReturnValue(true),
};
mockCache = {
invalidate: vi.fn().mockResolvedValue(undefined),
@@ -137,6 +138,27 @@ describe('SubmitKycHandler', () => {
expect(result.message).toBeTruthy();
expect(user.kycStatus).toBe('PENDING');
});
it('rejects untrusted image URL hosts', async () => {
const user = createTestUser();
mockUserRepo.findById.mockResolvedValue(user);
mockMediaStorage.isTrustedUrl.mockImplementation((url: string) =>
url.startsWith('https://minio/'),
);
const command = new SubmitKycCommand(
'user-1',
'CCCD',
'012345678901',
undefined,
undefined,
undefined,
{ frontImageUrl: 'https://evil.example.com/kyc/front.jpg' },
);
await expect(handler.execute(command)).rejects.toThrow();
expect(mockUserRepo.update).not.toHaveBeenCalled();
});
});
describe('legacy file upload flow', () => {

View File

@@ -191,4 +191,85 @@ describe('UpdateProfileHandler', () => {
expect(mockUserRepo.update).toHaveBeenCalledWith(user);
expect(mockCache.invalidate).toHaveBeenCalled();
});
it('defers phone change via SMS OTP instead of updating directly', async () => {
const user = createTestUser();
mockUserRepo.findById.mockResolvedValue(user);
mockUserRepo.findByPhone.mockResolvedValue(null);
mockUserRepo.update.mockResolvedValue(undefined);
const command = new UpdateProfileCommand(
'user-1',
undefined,
undefined,
undefined,
'0987654321',
);
const result = await handler.execute(command);
// Phone should NOT change yet — deferred pending OTP
expect(result.phoneNumber).toBe('+84912345678');
expect(result.phoneChangePending).toBe(true);
expect(mockRedis.set).toHaveBeenCalledWith(
'auth:phone_change_otp:user-1',
expect.stringContaining('+84987654321'),
600,
);
expect(mockEventBus.publish).toHaveBeenCalledWith(
expect.objectContaining({
eventName: 'user.phone_change_requested',
newPhone: '+84987654321',
}),
);
});
it('throws ConflictException when new phone is already taken', async () => {
const user = createTestUser();
const otherUser = createTestUser({ id: 'user-2' });
mockUserRepo.findById.mockResolvedValue(user);
mockUserRepo.findByPhone.mockResolvedValue(otherUser);
const command = new UpdateProfileCommand(
'user-1',
undefined,
undefined,
undefined,
'0987654321',
);
await expect(handler.execute(command)).rejects.toThrow('Số điện thoại đã được sử dụng');
});
it('skips SMS OTP when phone is unchanged', async () => {
const user = createTestUser();
mockUserRepo.findById.mockResolvedValue(user);
mockUserRepo.update.mockResolvedValue(undefined);
const command = new UpdateProfileCommand(
'user-1',
undefined,
undefined,
undefined,
'0912345678',
);
const result = await handler.execute(command);
expect(mockRedis.set).not.toHaveBeenCalled();
expect(mockEventBus.publish).not.toHaveBeenCalled();
expect(result.phoneChangePending).toBeUndefined();
});
it('throws ValidationException for invalid phone format', async () => {
const user = createTestUser();
mockUserRepo.findById.mockResolvedValue(user);
const command = new UpdateProfileCommand(
'user-1',
undefined,
undefined,
undefined,
'not-a-phone',
);
await expect(handler.execute(command)).rejects.toThrow('Số điện thoại');
});
});

View File

@@ -0,0 +1,119 @@
import { UserEntity } from '../../domain/entities/user.entity';
import { type IUserRepository } from '../../domain/repositories/user.repository';
import { type HashedPassword } from '../../domain/value-objects/hashed-password.vo';
import { Phone } from '../../domain/value-objects/phone.vo';
import { VerifyPhoneChangeCommand } from '../commands/verify-phone-change/verify-phone-change.command';
import { VerifyPhoneChangeHandler } from '../commands/verify-phone-change/verify-phone-change.handler';
function createTestUser(overrides?: Partial<{ id: string; phone: string }>): UserEntity {
const phone = Phone.create(overrides?.phone ?? '0912345678').unwrap();
const pw = { value: 'hashed' } as HashedPassword;
return new UserEntity(overrides?.id ?? 'user-1', {
email: null,
phone,
passwordHash: pw,
fullName: 'Nguyen Van A',
avatarUrl: null,
role: 'BUYER',
kycStatus: 'NONE',
kycData: null,
isActive: true,
totpSecret: null,
totpEnabled: false,
totpBackupCodes: [],
totpEnabledAt: null,
});
}
describe('VerifyPhoneChangeHandler', () => {
let handler: VerifyPhoneChangeHandler;
let mockUserRepo: { [K in keyof IUserRepository]: ReturnType<typeof vi.fn> };
let mockRedis: { get: ReturnType<typeof vi.fn>; del: ReturnType<typeof vi.fn>; set: ReturnType<typeof vi.fn> };
let mockCache: { invalidate: ReturnType<typeof vi.fn> };
beforeEach(() => {
mockUserRepo = {
findById: vi.fn(),
findByPhone: vi.fn(),
findByEmail: vi.fn(),
save: vi.fn(),
update: vi.fn(),
updateMfaSecret: vi.fn(),
updateMfaEnabled: vi.fn(),
updateMfaDisabled: vi.fn(),
updateBackupCodes: vi.fn(),
};
mockRedis = {
get: vi.fn(),
del: vi.fn().mockResolvedValue(undefined),
set: vi.fn().mockResolvedValue(undefined),
};
mockCache = { invalidate: vi.fn().mockResolvedValue(undefined) };
handler = new VerifyPhoneChangeHandler(
mockUserRepo as any,
mockRedis as any,
mockCache as any,
{ error: vi.fn() } as any,
);
});
it('verifies SMS OTP and updates phone', async () => {
const user = createTestUser();
const payload = JSON.stringify({ newPhone: '+84987654321', code: '123456' });
mockRedis.get.mockResolvedValue(payload);
mockUserRepo.findById.mockResolvedValue(user);
mockUserRepo.findByPhone.mockResolvedValue(null);
mockUserRepo.update.mockResolvedValue(undefined);
const command = new VerifyPhoneChangeCommand('user-1', '123456');
const result = await handler.execute(command);
expect(result.phoneNumber).toBe('+84987654321');
expect(result.id).toBe('user-1');
expect(mockRedis.del).toHaveBeenCalledWith('auth:phone_change_otp:user-1');
expect(mockUserRepo.update).toHaveBeenCalledWith(user);
expect(mockCache.invalidate).toHaveBeenCalledWith(
expect.stringContaining('user-1'),
);
});
it('throws ValidationException when OTP has expired', async () => {
mockRedis.get.mockResolvedValue(null);
const command = new VerifyPhoneChangeCommand('user-1', '123456');
await expect(handler.execute(command)).rejects.toThrow('hết hạn');
});
it('throws ValidationException when OTP code is wrong', async () => {
const payload = JSON.stringify({ newPhone: '+84987654321', code: '123456' });
mockRedis.get.mockResolvedValue(payload);
const command = new VerifyPhoneChangeCommand('user-1', '999999');
await expect(handler.execute(command)).rejects.toThrow('không đúng');
});
it('throws ConflictException when phone was taken since OTP was issued', async () => {
const user = createTestUser();
const otherUser = createTestUser({ id: 'user-2', phone: '0987654321' });
const payload = JSON.stringify({ newPhone: '+84987654321', code: '123456' });
mockRedis.get.mockResolvedValue(payload);
mockUserRepo.findById.mockResolvedValue(user);
mockUserRepo.findByPhone.mockResolvedValue(otherUser);
const command = new VerifyPhoneChangeCommand('user-1', '123456');
await expect(handler.execute(command)).rejects.toThrow('Số điện thoại đã được sử dụng');
// OTP should be cleaned up on conflict
expect(mockRedis.del).toHaveBeenCalledWith('auth:phone_change_otp:user-1');
});
it('throws NotFoundException when user does not exist', async () => {
const payload = JSON.stringify({ newPhone: '+84987654321', code: '123456' });
mockRedis.get.mockResolvedValue(payload);
mockUserRepo.findById.mockResolvedValue(null);
const command = new VerifyPhoneChangeCommand('user-1', '123456');
await expect(handler.execute(command)).rejects.toThrow('Người dùng');
});
});

View File

@@ -49,6 +49,17 @@ export class SubmitKycHandler implements ICommandHandler<SubmitKycCommand> {
frontImageUrl = command.imageUrls.frontImageUrl;
backImageUrl = command.imageUrls.backImageUrl ?? null;
selfieUrl = command.imageUrls.selfieUrl ?? null;
// Validate URL hosts match our MinIO bucket (reject SSRF / tampering)
const untrusted: string[] = [];
if (!this.mediaStorage.isTrustedUrl(frontImageUrl)) untrusted.push('frontImageUrl');
if (backImageUrl && !this.mediaStorage.isTrustedUrl(backImageUrl)) untrusted.push('backImageUrl');
if (selfieUrl && !this.mediaStorage.isTrustedUrl(selfieUrl)) untrusted.push('selfieUrl');
if (untrusted.length > 0) {
throw new ValidationException(
`URL khong hop le (${untrusted.join(', ')}): chi chap nhan URL tu MinIO bucket cua he thong`,
);
}
} else if (command.frontImage) {
// Legacy file upload flow: upload buffers server-side
const folder = `${KYC_FOLDER}/${command.userId}`;

View File

@@ -4,5 +4,6 @@ export class UpdateProfileCommand {
public readonly fullName?: string,
public readonly avatarUrl?: string,
public readonly email?: string,
public readonly phoneNumber?: string,
) {}
}

View File

@@ -12,8 +12,10 @@ import {
ValidationException,
} from '@modules/shared';
import { EmailChangeRequestedEvent } from '../../../domain/events/email-change-requested.event';
import { PhoneChangeRequestedEvent } from '../../../domain/events/phone-change-requested.event';
import { type IUserRepository, USER_REPOSITORY } from '../../../domain/repositories/user.repository';
import { Email } from '../../../domain/value-objects/email.vo';
import { Phone } from '../../../domain/value-objects/phone.vo';
import { UpdateProfileCommand } from './update-profile.command';
/** TTL for email-change OTP codes stored in Redis (10 minutes). */
@@ -22,12 +24,20 @@ const EMAIL_CHANGE_OTP_TTL = 600;
/** Redis key prefix for pending email-change OTP. */
export const EMAIL_CHANGE_OTP_PREFIX = 'auth:email_change_otp';
/** TTL for phone-change OTP codes stored in Redis (10 minutes). */
const PHONE_CHANGE_OTP_TTL = 600;
/** Redis key prefix for pending phone-change OTP. */
export const PHONE_CHANGE_OTP_PREFIX = 'auth:phone_change_otp';
export interface UpdateProfileResultDto {
id: string;
fullName: string;
avatarUrl: string | null;
email: string | null;
phoneNumber: string;
emailChangePending?: boolean;
phoneChangePending?: boolean;
updatedAt: Date;
}
@@ -49,6 +59,7 @@ export class UpdateProfileHandler implements ICommandHandler<UpdateProfileComman
}
let emailChangePending = false;
let phoneChangePending = false;
// Validate and handle email change via OTP
if (command.email !== undefined) {
@@ -84,7 +95,41 @@ export class UpdateProfileHandler implements ICommandHandler<UpdateProfileComman
}
}
// Apply non-email fields immediately
// Validate and handle phone change via SMS OTP
if (command.phoneNumber !== undefined) {
const phoneResult = Phone.create(command.phoneNumber);
if (phoneResult.isErr) {
throw new ValidationException(phoneResult.unwrapErr());
}
const phone = phoneResult.unwrap();
// Check if phone is actually changing
if (user.phone.value !== phone.value) {
// Check uniqueness
const existingUser = await this.userRepo.findByPhone(phone.value);
if (existingUser && existingUser.id !== command.userId) {
throw new ConflictException('Số điện thoại đã được sử dụng bởi tài khoản khác');
}
// Generate OTP and store pending change in Redis
const otpCode = String(randomInt(100_000, 999_999));
const payload = JSON.stringify({ newPhone: phone.value, code: otpCode });
await this.redis.set(
`${PHONE_CHANGE_OTP_PREFIX}:${command.userId}`,
payload,
PHONE_CHANGE_OTP_TTL,
);
// Emit event so notifications module can send the SMS OTP
this.eventBus.publish(
new PhoneChangeRequestedEvent(command.userId, phone.value, otpCode),
);
phoneChangePending = true;
}
}
// Apply non-email / non-phone fields immediately
user.updateProfile(command.fullName, command.avatarUrl, undefined);
await this.userRepo.update(user);
@@ -97,7 +142,9 @@ export class UpdateProfileHandler implements ICommandHandler<UpdateProfileComman
fullName: user.fullName,
avatarUrl: user.avatarUrl,
email: user.email?.value ?? null,
phoneNumber: user.phone.value,
...(emailChangePending ? { emailChangePending: true } : {}),
...(phoneChangePending ? { phoneChangePending: true } : {}),
updatedAt: user.updatedAt,
};
} catch (error) {

View File

@@ -0,0 +1,6 @@
export class VerifyPhoneChangeCommand {
constructor(
public readonly userId: string,
public readonly code: string,
) {}
}

View File

@@ -0,0 +1,87 @@
import { Inject, InternalServerErrorException } from '@nestjs/common';
import { CommandHandler, type ICommandHandler } from '@nestjs/cqrs';
import {
CachePrefix,
CacheService,
ConflictException,
DomainException,
type LoggerService,
NotFoundException,
type RedisService,
ValidationException,
} from '@modules/shared';
import { type IUserRepository, USER_REPOSITORY } from '../../../domain/repositories/user.repository';
import { Phone } from '../../../domain/value-objects/phone.vo';
import { PHONE_CHANGE_OTP_PREFIX } from '../update-profile/update-profile.handler';
import { VerifyPhoneChangeCommand } from './verify-phone-change.command';
export interface VerifyPhoneChangeResultDto {
id: string;
phoneNumber: string;
updatedAt: Date;
}
@CommandHandler(VerifyPhoneChangeCommand)
export class VerifyPhoneChangeHandler implements ICommandHandler<VerifyPhoneChangeCommand> {
constructor(
@Inject(USER_REPOSITORY) private readonly userRepo: IUserRepository,
private readonly redis: RedisService,
private readonly cache: CacheService,
private readonly logger: LoggerService,
) {}
async execute(command: VerifyPhoneChangeCommand): Promise<VerifyPhoneChangeResultDto> {
try {
const redisKey = `${PHONE_CHANGE_OTP_PREFIX}:${command.userId}`;
const raw = await this.redis.get(redisKey);
if (!raw) {
throw new ValidationException(
'Mã xác thực đã hết hạn hoặc không tồn tại. Vui lòng yêu cầu đổi số điện thoại lại.',
);
}
const { newPhone, code } = JSON.parse(raw) as { newPhone: string; code: string };
if (code !== command.code) {
throw new ValidationException('Mã xác thực không đúng');
}
const user = await this.userRepo.findById(command.userId);
if (!user) {
throw new NotFoundException('Người dùng', command.userId);
}
// Re-check phone uniqueness (may have been taken since the request)
const existingUser = await this.userRepo.findByPhone(newPhone);
if (existingUser && existingUser.id !== command.userId) {
await this.redis.del(redisKey);
throw new ConflictException('Số điện thoại đã được sử dụng bởi tài khoản khác');
}
const phoneVo = Phone.create(newPhone).unwrap();
user.updatePhone(phoneVo);
await this.userRepo.update(user);
// Clean up OTP and invalidate profile cache
await this.redis.del(redisKey);
await this.cache.invalidate(
CacheService.buildKey(CachePrefix.USER_PROFILE, command.userId),
);
return {
id: user.id,
phoneNumber: phoneVo.value,
updatedAt: user.updatedAt,
};
} catch (error) {
if (error instanceof DomainException) throw error;
this.logger.error(
`Failed to verify phone change: ${error instanceof Error ? error.message : error}`,
error instanceof Error ? error.stack : undefined,
this.constructor.name,
);
throw new InternalServerErrorException('Không thể xác thực đổi số điện thoại');
}
}
}

View File

@@ -25,6 +25,7 @@ import { VerifyEmailChangeHandler } from './application/commands/verify-email-ch
import { VerifyKycHandler } from './application/commands/verify-kyc/verify-kyc.handler';
import { VerifyMfaChallengeHandler } from './application/commands/verify-mfa-challenge/verify-mfa-challenge.handler';
import { VerifyMfaSetupHandler } from './application/commands/verify-mfa-setup/verify-mfa-setup.handler';
import { VerifyPhoneChangeHandler } from './application/commands/verify-phone-change/verify-phone-change.handler';
import { GetAgentByUserIdHandler } from './application/queries/get-agent-by-user-id/get-agent-by-user-id.handler';
import { GetMfaStatusHandler } from './application/queries/get-mfa-status/get-mfa-status.handler';
import { GetProfileHandler } from './application/queries/get-profile/get-profile.handler';
@@ -55,6 +56,7 @@ const CommandHandlers = [
GenerateKycUploadUrlsHandler,
UpdateProfileHandler,
VerifyEmailChangeHandler,
VerifyPhoneChangeHandler,
RequestUserDeletionHandler,
CancelUserDeletionHandler,
ForceDeleteUserHandler,

View File

@@ -145,4 +145,9 @@ export class UserEntity extends AggregateRoot<string> {
if (email !== undefined) this._email = email;
this.updatedAt = new Date();
}
updatePhone(phone: Phone): void {
this._phone = phone;
this.updatedAt = new Date();
}
}

View File

@@ -1,3 +1,4 @@
export { UserRegisteredEvent } from './user-registered.event';
export { AgentVerifiedEvent } from './agent-verified.event';
export { EmailChangeRequestedEvent } from './email-change-requested.event';
export { PhoneChangeRequestedEvent } from './phone-change-requested.event';

View File

@@ -0,0 +1,12 @@
import { type DomainEvent } from '@modules/shared';
export class PhoneChangeRequestedEvent implements DomainEvent {
readonly eventName = 'user.phone_change_requested';
readonly occurredAt = new Date();
constructor(
public readonly aggregateId: string,
public readonly newPhone: string,
public readonly otpCode: string,
) {}
}

View File

@@ -12,4 +12,5 @@ export { UserDeactivatedEvent } from './domain/events/user-deactivated.event';
export { UserKycUpdatedEvent } from './domain/events/user-kyc-updated.event';
export { UserRegisteredEvent } from './domain/events/user-registered.event';
export { EmailChangeRequestedEvent } from './domain/events/email-change-requested.event';
export { PhoneChangeRequestedEvent } from './domain/events/phone-change-requested.event';
export { USER_REPOSITORY, IUserRepository } from './domain/repositories/user.repository';

View File

@@ -16,9 +16,8 @@ import {
EndpointRateLimit,
EndpointRateLimitGuard,
UnauthorizedException,
ValidationException,
} from '@modules/shared';
import { GenerateKycUploadUrlsCommand, type KycFileRequest } from '../../application/commands/generate-kyc-upload-urls/generate-kyc-upload-urls.command';
import { GenerateKycUploadUrlsCommand } from '../../application/commands/generate-kyc-upload-urls/generate-kyc-upload-urls.command';
import { LoginUserCommand } from '../../application/commands/login-user/login-user.command';
import { type LoginResult } from '../../application/commands/login-user/login-user.handler';
import { RefreshTokenCommand } from '../../application/commands/refresh-token/refresh-token.command';
@@ -29,6 +28,8 @@ import { type UpdateProfileResultDto } from '../../application/commands/update-p
import { VerifyEmailChangeCommand } from '../../application/commands/verify-email-change/verify-email-change.command';
import { type VerifyEmailChangeResultDto } from '../../application/commands/verify-email-change/verify-email-change.handler';
import { VerifyKycCommand } from '../../application/commands/verify-kyc/verify-kyc.command';
import { VerifyPhoneChangeCommand } from '../../application/commands/verify-phone-change/verify-phone-change.command';
import { type VerifyPhoneChangeResultDto } from '../../application/commands/verify-phone-change/verify-phone-change.handler';
import { type AgentDto } from '../../application/queries/get-agent-by-user-id/get-agent-by-user-id.handler';
import { GetAgentByUserIdQuery } from '../../application/queries/get-agent-by-user-id/get-agent-by-user-id.query';
import { type UserProfileDto } from '../../application/queries/get-profile/get-profile.handler';
@@ -37,12 +38,15 @@ import { type TokenService, type JwtPayload, type TokenPair } from '../../infras
import { type LocalStrategyResult } from '../../infrastructure/strategies/local.strategy';
import { CurrentUser } from '../decorators/current-user.decorator';
import { Roles } from '../decorators/roles.decorator';
import { type GenerateKycUploadUrlsDto } from '../dto/generate-kyc-upload-urls.dto';
import { LoginDto } from '../dto/login.dto';
import { type RefreshTokenDto } from '../dto/refresh-token.dto';
import { type RegisterDto } from '../dto/register.dto';
import { type SubmitKycDto } from '../dto/submit-kyc.dto';
import { type UpdateProfileDto } from '../dto/update-profile.dto';
import { type VerifyEmailChangeDto } from '../dto/verify-email-change.dto';
import { type VerifyKycDto } from '../dto/verify-kyc.dto';
import { type VerifyPhoneChangeDto } from '../dto/verify-phone-change.dto';
import { JwtAuthGuard } from '../guards/jwt-auth.guard';
import { LocalAuthGuard } from '../guards/local-auth.guard';
import { RolesGuard } from '../guards/roles.guard';
@@ -227,11 +231,29 @@ export class AuthController {
@Body() dto: UpdateProfileDto,
): Promise<{ message: string; data: UpdateProfileResultDto }> {
const result: UpdateProfileResultDto = await this.commandBus.execute(
new UpdateProfileCommand(user.sub, dto.fullName, dto.avatarUrl, dto.email),
new UpdateProfileCommand(user.sub, dto.fullName, dto.avatarUrl, dto.email, dto.phoneNumber),
);
return { message: 'Cập nhật hồ sơ thành công', data: result };
}
@UseGuards(JwtAuthGuard)
@Post('profile/verify-phone')
@ApiBearerAuth('JWT')
@ApiOperation({ summary: 'Verify phone number change with SMS OTP code' })
@ApiResponse({ status: 201, description: 'Phone number changed successfully' })
@ApiResponse({ status: 400, description: 'Invalid or expired OTP code' })
@ApiResponse({ status: 401, description: 'Unauthorized' })
@ApiResponse({ status: 409, description: 'Phone number already in use' })
async verifyPhoneChange(
@CurrentUser() user: JwtPayload,
@Body() dto: VerifyPhoneChangeDto,
): Promise<{ message: string; data: VerifyPhoneChangeResultDto }> {
const result: VerifyPhoneChangeResultDto = await this.commandBus.execute(
new VerifyPhoneChangeCommand(user.sub, dto.code),
);
return { message: 'Số điện thoại đã được cập nhật thành công', data: result };
}
@UseGuards(JwtAuthGuard)
@Post('profile/verify-email')
@ApiBearerAuth('JWT')
@@ -268,7 +290,7 @@ export class AuthController {
@ApiResponse({ status: 400, description: 'Validation error' })
@ApiResponse({ status: 401, description: 'Unauthorized' })
async generateKycUploadUrls(
@Body() body: { files: KycFileRequest[] },
@Body() body: GenerateKycUploadUrlsDto,
@CurrentUser() user: JwtPayload,
): Promise<{ field: string; uploadUrl: string; publicUrl: string; objectKey: string }[]> {
return this.commandBus.execute(
@@ -284,20 +306,9 @@ export class AuthController {
@ApiResponse({ status: 400, description: 'Validation error' })
@ApiResponse({ status: 401, description: 'Unauthorized' })
async submitKyc(
@Body()
body: {
documentType: string;
documentNumber: string;
frontImageUrl: string;
backImageUrl?: string;
selfieUrl?: string;
},
@Body() body: SubmitKycDto,
@CurrentUser() user: JwtPayload,
): Promise<{ message: string }> {
if (!body.frontImageUrl) {
throw new ValidationException('Vui lòng tải ảnh mặt trước giấy tờ');
}
return this.commandBus.execute(
new SubmitKycCommand(
user.sub,

View File

@@ -2,4 +2,5 @@ export { RegisterDto } from './register.dto';
export { LoginDto } from './login.dto';
export { RefreshTokenDto } from './refresh-token.dto';
export { VerifyKycDto } from './verify-kyc.dto';
export { GenerateKycUploadUrlsDto, KycFileRequestDto } from './generate-kyc-upload-urls.dto';
export { VerifyMfaSetupDto, VerifyMfaChallengeDto, UseBackupCodeDto, DisableMfaDto } from './mfa.dto';

View File

@@ -21,4 +21,13 @@ export class UpdateProfileDto {
@IsOptional()
@IsEmail({}, { message: 'Email không hợp lệ' })
email?: string;
@ApiPropertyOptional({
example: '0912345678',
description: 'Vietnamese phone number (will trigger SMS OTP re-verification)',
})
@IsOptional()
@IsString()
@MinLength(9, { message: 'Số điện thoại không hợp lệ' })
phoneNumber?: string;
}

View File

@@ -0,0 +1,10 @@
import { ApiProperty } from '@nestjs/swagger';
import { IsNotEmpty, IsString, Length } from 'class-validator';
export class VerifyPhoneChangeDto {
@ApiProperty({ example: '123456', description: '6-digit OTP code sent via SMS' })
@IsNotEmpty({ message: 'Mã xác thực không được để trống' })
@IsString()
@Length(6, 6, { message: 'Mã xác thực phải gồm 6 chữ số' })
code!: string;
}

View File

@@ -0,0 +1,113 @@
import { ActivateFeaturedListingHandler } from '../event-handlers/activate-featured-listing.handler';
describe('ActivateFeaturedListingHandler', () => {
let handler: ActivateFeaturedListingHandler;
let mockPrisma: {
payment: { findUnique: ReturnType<typeof vi.fn> };
listing: { findUnique: ReturnType<typeof vi.fn>; update: ReturnType<typeof vi.fn> };
};
let mockLogger: { log: ReturnType<typeof vi.fn> };
beforeEach(() => {
mockPrisma = {
payment: { findUnique: vi.fn() },
listing: { findUnique: vi.fn(), update: vi.fn() },
};
mockLogger = { log: vi.fn() };
handler = new ActivateFeaturedListingHandler(
mockPrisma as any,
mockLogger as any,
);
});
it('activates featured listing for 7 days on 199000 VND payment', async () => {
mockPrisma.payment.findUnique.mockResolvedValue({
type: 'FEATURED_LISTING',
transactionId: 'listing-1',
amountVND: 199000n,
});
mockPrisma.listing.findUnique.mockResolvedValue({ featuredUntil: null });
mockPrisma.listing.update.mockResolvedValue({});
await handler.handle({ aggregateId: 'pay-1' } as any);
expect(mockPrisma.listing.update).toHaveBeenCalledWith({
where: { id: 'listing-1' },
data: { featuredUntil: expect.any(Date) },
});
const updateCall = mockPrisma.listing.update.mock.calls[0][0];
const featuredUntil = updateCall.data.featuredUntil as Date;
const diffDays = Math.round((featuredUntil.getTime() - Date.now()) / (1000 * 60 * 60 * 24));
expect(diffDays).toBe(7);
});
it('activates featured listing for 3 days on 99000 VND payment', async () => {
mockPrisma.payment.findUnique.mockResolvedValue({
type: 'FEATURED_LISTING',
transactionId: 'listing-1',
amountVND: 99000n,
});
mockPrisma.listing.findUnique.mockResolvedValue({ featuredUntil: null });
mockPrisma.listing.update.mockResolvedValue({});
await handler.handle({ aggregateId: 'pay-1' } as any);
const updateCall = mockPrisma.listing.update.mock.calls[0][0];
const featuredUntil = updateCall.data.featuredUntil as Date;
const diffDays = Math.round((featuredUntil.getTime() - Date.now()) / (1000 * 60 * 60 * 24));
expect(diffDays).toBe(3);
});
it('extends from existing featuredUntil if still in the future', async () => {
const futureDate = new Date(Date.now() + 5 * 24 * 60 * 60 * 1000); // 5 days from now
mockPrisma.payment.findUnique.mockResolvedValue({
type: 'FEATURED_LISTING',
transactionId: 'listing-1',
amountVND: 199000n,
});
mockPrisma.listing.findUnique.mockResolvedValue({ featuredUntil: futureDate });
mockPrisma.listing.update.mockResolvedValue({});
await handler.handle({ aggregateId: 'pay-1' } as any);
const updateCall = mockPrisma.listing.update.mock.calls[0][0];
const featuredUntil = updateCall.data.featuredUntil as Date;
// Should extend from futureDate (5 days out) + 7 days = ~12 days from now
const diffDays = Math.round((featuredUntil.getTime() - Date.now()) / (1000 * 60 * 60 * 24));
expect(diffDays).toBe(12);
});
it('ignores non-FEATURED_LISTING payments', async () => {
mockPrisma.payment.findUnique.mockResolvedValue({
type: 'SUBSCRIPTION',
transactionId: 'listing-1',
amountVND: 199000n,
});
await handler.handle({ aggregateId: 'pay-1' } as any);
expect(mockPrisma.listing.update).not.toHaveBeenCalled();
});
it('ignores payments without transactionId', async () => {
mockPrisma.payment.findUnique.mockResolvedValue({
type: 'FEATURED_LISTING',
transactionId: null,
amountVND: 199000n,
});
await handler.handle({ aggregateId: 'pay-1' } as any);
expect(mockPrisma.listing.update).not.toHaveBeenCalled();
});
it('ignores payments that do not exist', async () => {
mockPrisma.payment.findUnique.mockResolvedValue(null);
await handler.handle({ aggregateId: 'pay-1' } as any);
expect(mockPrisma.listing.update).not.toHaveBeenCalled();
});
});

View File

@@ -0,0 +1,128 @@
import { ListingEntity } from '@modules/listings/domain/entities/listing.entity';
import { type IListingRepository } from '@modules/listings/domain/repositories/listing.repository';
import { Price } from '@modules/listings/domain/value-objects/price.vo';
import { FeatureListingCommand } from '../commands/feature-listing/feature-listing.command';
import { FeatureListingHandler } from '../commands/feature-listing/feature-listing.handler';
function createListing(
id = 'listing-1',
sellerId = 'seller-1',
agentId: string | null = null,
status: 'DRAFT' | 'PENDING_REVIEW' | 'ACTIVE' = 'ACTIVE',
): ListingEntity {
const price = Price.create(2_000_000_000n).unwrap();
const listing = ListingEntity.createNew(id, 'prop-1', sellerId, 'SALE', price, 80, agentId ?? undefined);
if (status === 'PENDING_REVIEW' || status === 'ACTIVE') listing.submitForReview();
if (status === 'ACTIVE') listing.approve();
listing.clearDomainEvents();
return listing;
}
describe('FeatureListingHandler', () => {
let handler: FeatureListingHandler;
let mockListingRepo: Pick<IListingRepository, 'findById'>;
let mockCommandBus: { execute: ReturnType<typeof vi.fn> };
let mockLogger: { log: ReturnType<typeof vi.fn>; error: ReturnType<typeof vi.fn> };
beforeEach(() => {
mockListingRepo = { findById: vi.fn() };
mockCommandBus = {
execute: vi.fn().mockResolvedValue({
paymentId: 'pay-1',
paymentUrl: 'https://pay.example.com/checkout',
providerTxId: 'tx-1',
}),
};
mockLogger = { log: vi.fn(), error: vi.fn() };
handler = new FeatureListingHandler(
mockListingRepo as any,
mockCommandBus as any,
mockLogger as any,
);
});
it('creates payment for a valid feature request', async () => {
const listing = createListing('listing-1', 'seller-1', null, 'ACTIVE');
(mockListingRepo.findById as ReturnType<typeof vi.fn>).mockResolvedValue(listing);
const command = new FeatureListingCommand(
'listing-1', 'seller-1', '7_days', 'VNPAY',
'https://goodgo.vn/callback', '127.0.0.1',
);
const result = await handler.execute(command);
expect(result.paymentId).toBe('pay-1');
expect(result.paymentUrl).toBe('https://pay.example.com/checkout');
expect(result.package_).toBe('7_days');
expect(result.priceVND).toBe('199000');
expect(mockCommandBus.execute).toHaveBeenCalledTimes(1);
});
it('allows the assigned agent to feature the listing', async () => {
const listing = createListing('listing-1', 'seller-1', 'agent-1', 'ACTIVE');
(mockListingRepo.findById as ReturnType<typeof vi.fn>).mockResolvedValue(listing);
const command = new FeatureListingCommand(
'listing-1', 'agent-1', '3_days', 'MOMO',
'https://goodgo.vn/callback', '127.0.0.1',
);
const result = await handler.execute(command);
expect(result.paymentId).toBe('pay-1');
expect(result.priceVND).toBe('99000');
});
it('rejects feature request from unauthorized user', async () => {
const listing = createListing('listing-1', 'seller-1', null, 'ACTIVE');
(mockListingRepo.findById as ReturnType<typeof vi.fn>).mockResolvedValue(listing);
const command = new FeatureListingCommand(
'listing-1', 'stranger', '7_days', 'VNPAY',
'https://goodgo.vn/callback', '127.0.0.1',
);
await expect(handler.execute(command)).rejects.toThrow(/người bán|môi giới/);
});
it('rejects feature request for non-ACTIVE listing', async () => {
const listing = createListing('listing-1', 'seller-1', null, 'DRAFT');
(mockListingRepo.findById as ReturnType<typeof vi.fn>).mockResolvedValue(listing);
const command = new FeatureListingCommand(
'listing-1', 'seller-1', '7_days', 'VNPAY',
'https://goodgo.vn/callback', '127.0.0.1',
);
await expect(handler.execute(command)).rejects.toThrow(/hoạt động/);
});
it('throws NotFoundException for non-existent listing', async () => {
(mockListingRepo.findById as ReturnType<typeof vi.fn>).mockResolvedValue(null);
const command = new FeatureListingCommand(
'nonexistent', 'seller-1', '7_days', 'VNPAY',
'https://goodgo.vn/callback', '127.0.0.1',
);
await expect(handler.execute(command)).rejects.toThrow('Listing');
});
it('uses correct pricing for each package', async () => {
const listing = createListing('listing-1', 'seller-1', null, 'ACTIVE');
(mockListingRepo.findById as ReturnType<typeof vi.fn>).mockResolvedValue(listing);
for (const [pkg, expectedPrice] of [
['3_days', '99000'],
['7_days', '199000'],
['30_days', '499000'],
] as const) {
const command = new FeatureListingCommand(
'listing-1', 'seller-1', pkg, 'VNPAY',
'https://goodgo.vn/callback', '127.0.0.1',
);
const result = await handler.execute(command);
expect(result.priceVND).toBe(expectedPrice);
}
});
});

View File

@@ -0,0 +1,58 @@
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { GetPriceHistoryHandler } from '../queries/get-price-history/get-price-history.handler';
import { GetPriceHistoryQuery } from '../queries/get-price-history/get-price-history.query';
describe('GetPriceHistoryHandler', () => {
let handler: GetPriceHistoryHandler;
let mockPrisma: { priceHistory: { findMany: ReturnType<typeof vi.fn> } };
beforeEach(() => {
mockPrisma = {
priceHistory: { findMany: vi.fn() },
};
handler = new GetPriceHistoryHandler(mockPrisma as any);
});
it('should query price history for the given listing ordered by changedAt desc', async () => {
const mockHistory = [
{ id: 'ph-2', oldPrice: 5_000_000_000n, newPrice: 6_000_000_000n, source: 'manual_update', changedAt: new Date('2026-04-16') },
{ id: 'ph-1', oldPrice: 4_000_000_000n, newPrice: 5_000_000_000n, source: 'manual_update', changedAt: new Date('2026-04-10') },
];
mockPrisma.priceHistory.findMany.mockResolvedValue(mockHistory);
const query = new GetPriceHistoryQuery('listing-1');
const result = await handler.execute(query);
expect(result).toEqual(mockHistory);
expect(mockPrisma.priceHistory.findMany).toHaveBeenCalledWith({
where: { listingId: 'listing-1' },
orderBy: { changedAt: 'desc' },
select: {
id: true,
oldPrice: true,
newPrice: true,
source: true,
changedAt: true,
},
});
});
it('should return empty array when no history exists', async () => {
mockPrisma.priceHistory.findMany.mockResolvedValue([]);
const query = new GetPriceHistoryQuery('listing-no-history');
const result = await handler.execute(query);
expect(result).toEqual([]);
});
it('should include source field in the select', async () => {
mockPrisma.priceHistory.findMany.mockResolvedValue([
{ id: 'ph-1', oldPrice: 1n, newPrice: 2n, source: 'admin_override', changedAt: new Date() },
]);
const result = await handler.execute(new GetPriceHistoryQuery('listing-1'));
expect(result[0].source).toBe('admin_override');
});
});

View File

@@ -0,0 +1,94 @@
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { RecordPriceHistoryHandler } from '../event-handlers/record-price-history.handler';
import { ListingPriceChangedEvent } from '../../domain/events/listing-price-changed.event';
describe('RecordPriceHistoryHandler', () => {
let handler: RecordPriceHistoryHandler;
let mockPrisma: { priceHistory: { create: ReturnType<typeof vi.fn> } };
let mockLogger: { debug: ReturnType<typeof vi.fn>; error: ReturnType<typeof vi.fn> };
beforeEach(() => {
mockPrisma = {
priceHistory: { create: vi.fn().mockResolvedValue({ id: 'ph-1' }) },
};
mockLogger = {
debug: vi.fn(),
error: vi.fn(),
};
handler = new RecordPriceHistoryHandler(mockPrisma as any, mockLogger as any);
});
it('should persist a price history record with correct data', async () => {
const event = new ListingPriceChangedEvent(
'listing-1',
5_000_000_000n,
6_000_000_000n,
'manual_update',
);
await handler.handle(event);
expect(mockPrisma.priceHistory.create).toHaveBeenCalledWith({
data: {
listingId: 'listing-1',
oldPrice: 5_000_000_000n,
newPrice: 6_000_000_000n,
source: 'manual_update',
changedAt: event.occurredAt,
},
});
});
it('should persist source as admin_override when provided', async () => {
const event = new ListingPriceChangedEvent(
'listing-2',
3_000_000_000n,
4_500_000_000n,
'admin_override',
);
await handler.handle(event);
expect(mockPrisma.priceHistory.create).toHaveBeenCalledWith(
expect.objectContaining({
data: expect.objectContaining({ source: 'admin_override' }),
}),
);
});
it('should default source to manual_update', async () => {
const event = new ListingPriceChangedEvent('listing-3', 1_000_000n, 2_000_000n);
await handler.handle(event);
expect(mockPrisma.priceHistory.create).toHaveBeenCalledWith(
expect.objectContaining({
data: expect.objectContaining({ source: 'manual_update' }),
}),
);
});
it('should log debug message on success', async () => {
const event = new ListingPriceChangedEvent('listing-1', 100n, 200n);
await handler.handle(event);
expect(mockLogger.debug).toHaveBeenCalledWith(
expect.stringContaining('listing-1'),
'RecordPriceHistoryHandler',
);
});
it('should log error and not throw when persistence fails', async () => {
mockPrisma.priceHistory.create.mockRejectedValue(new Error('DB connection lost'));
const event = new ListingPriceChangedEvent('listing-1', 100n, 200n);
await expect(handler.handle(event)).resolves.toBeUndefined();
expect(mockLogger.error).toHaveBeenCalledWith(
expect.stringContaining('DB connection lost'),
expect.any(String),
'RecordPriceHistoryHandler',
);
});
});

View File

@@ -16,6 +16,7 @@ export class RecordPriceHistoryHandler implements IEventHandler<ListingPriceChan
listingId: event.aggregateId,
oldPrice: event.oldPrice,
newPrice: event.newPrice,
source: event.source,
changedAt: event.occurredAt,
},
});

View File

@@ -6,6 +6,7 @@ export interface PriceHistoryItem {
id: string;
oldPrice: bigint;
newPrice: bigint;
source: string;
changedAt: Date;
}
@@ -21,6 +22,7 @@ export class GetPriceHistoryHandler implements IQueryHandler<GetPriceHistoryQuer
id: true,
oldPrice: true,
newPrice: true,
source: true,
changedAt: true,
},
});

View File

@@ -1,6 +1,7 @@
import { describe, it, expect } from 'vitest';
import { ListingApprovedEvent } from '../events/listing-approved.event';
import { ListingCreatedEvent } from '../events/listing-created.event';
import { ListingPriceChangedEvent } from '../events/listing-price-changed.event';
import { ListingSoldEvent } from '../events/listing-sold.event';
import { ListingStatusChangedEvent } from '../events/listing-status-changed.event';
@@ -51,6 +52,34 @@ describe('Listings Domain Events', () => {
});
});
describe('ListingPriceChangedEvent', () => {
it('creates event with correct properties', () => {
const event = new ListingPriceChangedEvent('listing-1', 5_000_000_000n, 6_000_000_000n, 'manual_update');
expect(event.eventName).toBe('listing.price_changed');
expect(event.aggregateId).toBe('listing-1');
expect(event.oldPrice).toBe(5_000_000_000n);
expect(event.newPrice).toBe(6_000_000_000n);
expect(event.source).toBe('manual_update');
expect(event.occurredAt).toBeInstanceOf(Date);
});
it('defaults source to manual_update', () => {
const event = new ListingPriceChangedEvent('listing-2', 1_000_000n, 2_000_000n);
expect(event.source).toBe('manual_update');
});
it('accepts admin_override source', () => {
const event = new ListingPriceChangedEvent('listing-3', 1n, 2n, 'admin_override');
expect(event.source).toBe('admin_override');
});
it('accepts market_adjustment source', () => {
const event = new ListingPriceChangedEvent('listing-4', 1n, 2n, 'market_adjustment');
expect(event.source).toBe('market_adjustment');
});
});
describe('ListingStatusChangedEvent', () => {
it('creates event with correct properties', () => {
const event = new ListingStatusChangedEvent('listing-1', 'prop-1', 'DRAFT', 'PENDING_REVIEW');

View File

@@ -1,5 +1,7 @@
import type { DomainEvent } from '@modules/shared';
export type PriceChangeSource = 'manual_update' | 'admin_override' | 'market_adjustment';
export class ListingPriceChangedEvent implements DomainEvent {
readonly eventName = 'listing.price_changed';
readonly occurredAt = new Date();
@@ -8,5 +10,6 @@ export class ListingPriceChangedEvent implements DomainEvent {
public readonly aggregateId: string,
public readonly oldPrice: bigint,
public readonly newPrice: bigint,
public readonly source: PriceChangeSource = 'manual_update',
) {}
}

View File

@@ -30,6 +30,7 @@ export interface IMediaStorageService {
expiresInSeconds?: number,
): Promise<PresignedUploadResult>;
getPublicUrl(objectKey: string): string;
isTrustedUrl(url: string): boolean;
}
function requireEnv(key: string): string {
@@ -151,6 +152,45 @@ export class MinioMediaStorageService implements IMediaStorageService, OnModuleI
return `${protocol}://${this.endpoint}:${this.port}/${this.bucket}/${objectKey}`;
}
/**
* Validates that a URL points to our configured MinIO bucket.
* Accepts the primary endpoint, plus an optional comma-separated list of
* additional trusted hosts via `MINIO_TRUSTED_HOSTS` (e.g. public CDN domains).
* Also enforces the bucket is the first path segment.
*/
isTrustedUrl(url: string): boolean {
if (!url || typeof url !== 'string') {
return false;
}
let parsed: URL;
try {
parsed = new URL(url);
} catch {
return false;
}
const allowedHosts = new Set<string>();
allowedHosts.add(this.endpoint.toLowerCase());
allowedHosts.add(`${this.endpoint.toLowerCase()}:${this.port}`);
const extra = process.env['MINIO_TRUSTED_HOSTS'];
if (extra) {
for (const h of extra.split(',')) {
const trimmed = h.trim().toLowerCase();
if (trimmed) allowedHosts.add(trimmed);
}
}
const host = parsed.host.toLowerCase();
if (!allowedHosts.has(host) && !allowedHosts.has(parsed.hostname.toLowerCase())) {
return false;
}
// Path must start with /<bucket>/
const expectedPrefix = `/${this.bucket}/`;
return parsed.pathname.startsWith(expectedPrefix) && parsed.pathname.length > expectedPrefix.length;
}
async delete(fileUrl: string): Promise<void> {
try {
const urlObj = new URL(fileUrl);

View File

@@ -96,6 +96,29 @@ describe('ListingsController', () => {
});
});
describe('getPriceHistory', () => {
it('should execute GetPriceHistoryQuery via query bus', async () => {
const mockHistory = [
{ id: 'ph-1', oldPrice: '5000000000', newPrice: '6000000000', source: 'manual_update', changedAt: '2026-04-16T00:00:00.000Z' },
];
mockQueryBus.execute.mockResolvedValue(mockHistory);
const result = await controller.getPriceHistory('listing-1');
expect(result).toEqual(mockHistory);
expect(mockQueryBus.execute).toHaveBeenCalledTimes(1);
});
it('should return empty array when no price history exists', async () => {
mockQueryBus.execute.mockResolvedValue([]);
const result = await controller.getPriceHistory('listing-no-history');
expect(result).toEqual([]);
expect(mockQueryBus.execute).toHaveBeenCalledTimes(1);
});
});
describe('updateListing', () => {
it('should execute UpdateListingCommand via command bus', async () => {
const mockResult = {

View File

@@ -0,0 +1,32 @@
import { Injectable } from '@nestjs/common';
import { type CommandBus } from '@nestjs/cqrs';
import { OnEvent } from '@nestjs/event-emitter';
import { type PhoneChangeRequestedEvent } from '@modules/auth';
import { type LoggerService } from '@modules/shared';
import { SendNotificationCommand } from '../commands/send-notification/send-notification.command';
@Injectable()
export class PhoneChangeRequestedListener {
constructor(
private readonly commandBus: CommandBus,
private readonly logger: LoggerService,
) {}
@OnEvent('user.phone_change_requested', { async: true })
async handle(event: PhoneChangeRequestedEvent): Promise<void> {
this.logger.log(
`Handling phone change OTP for user ${event.aggregateId}`,
'PhoneChangeRequestedListener',
);
await this.commandBus.execute(
new SendNotificationCommand(
event.aggregateId,
'SMS',
'user.phone_change_otp',
{ otpCode: event.otpCode },
event.newPhone,
),
);
}
}

View File

@@ -81,10 +81,10 @@ describe('TemplateService', () => {
expect(result.body).toContain('/listings/2');
});
it('getTemplateKeys returns all 12 template keys', () => {
it('getTemplateKeys returns all 13 template keys', () => {
const keys = service.getTemplateKeys();
expect(keys).toHaveLength(12);
expect(keys).toHaveLength(13);
expect(keys).toContain('user.registered');
expect(keys).toContain('agent.verified');
expect(keys).toContain('listing.approved');
@@ -97,5 +97,6 @@ describe('TemplateService', () => {
expect(keys).toContain('saved_search_alert');
expect(keys).toContain('saved_search_digest');
expect(keys).toContain('user.email_change_otp');
expect(keys).toContain('user.phone_change_otp');
});
});

View File

@@ -86,6 +86,10 @@ const TEMPLATES: Record<string, TemplateDefinition> = {
<p>Nếu bạn không yêu cầu, hãy bỏ qua email này.</p>
<p>Trân trọng,<br/>Đội ngũ GoodGo</p>`,
},
'user.phone_change_otp': {
subject: 'Xác nhận thay đổi số điện thoại — GoodGo',
body: `Mã xác nhận thay đổi số điện thoại GoodGo: {{otpCode}}. Mã có hiệu lực trong 10 phút. Nếu bạn không yêu cầu, hãy bỏ qua tin nhắn này.`,
},
'saved_search_alert': {
subject: 'Tin mới phù hợp tìm kiếm "{{searchName}}"',
body: `<h1>Xin chào {{userName}}!</h1>

View File

@@ -11,6 +11,7 @@ import { ListingSoldListener } from './application/listeners/listing-sold.listen
import { PaymentCompletedListener } from './application/listeners/payment-completed.listener';
import { PaymentFailedListener } from './application/listeners/payment-failed.listener';
import { PaymentRefundedListener } from './application/listeners/payment-refunded.listener';
import { PhoneChangeRequestedListener } from './application/listeners/phone-change-requested.listener';
import { QuotaExceededListener } from './application/listeners/quota-exceeded.listener';
import { SubscriptionExpiredListener } from './application/listeners/subscription-expired.listener';
import { SubscriptionExpiringListener } from './application/listeners/subscription-expiring.listener';
@@ -27,6 +28,7 @@ import { StringeeSmsService } from './infrastructure/services/stringee-sms.servi
import { TemplateService } from './infrastructure/services/template.service';
import { ZaloOaService } from './infrastructure/services/zalo-oa.service';
import { NotificationsController } from './presentation/controllers/notifications.controller';
import { ZaloOaWebhookController } from './presentation/controllers/zalo-oa-webhook.controller';
import { NotificationsGateway } from './presentation/gateways/notifications.gateway';
const CommandHandlers = [SendNotificationHandler];
@@ -47,11 +49,12 @@ const EventListeners = [
ListingSoldListener,
UserKycUpdatedListener,
EmailChangeRequestedListener,
PhoneChangeRequestedListener,
];
@Module({
imports: [CqrsModule, AuthModule],
controllers: [NotificationsController],
controllers: [NotificationsController, ZaloOaWebhookController],
providers: [
// Repositories
{ provide: NOTIFICATION_REPOSITORY, useClass: PrismaNotificationRepository },

View File

@@ -0,0 +1,276 @@
import type { NotificationSentEvent } from '../../domain/events/notification-sent.event';
import { NotificationsGateway } from '../gateways/notifications.gateway';
function createMockSocket(overrides: Partial<{
id: string;
data: Record<string, unknown>;
handshake: { auth?: Record<string, unknown>; headers?: Record<string, unknown>; query?: Record<string, unknown> };
join: ReturnType<typeof vi.fn>;
emit: ReturnType<typeof vi.fn>;
disconnect: ReturnType<typeof vi.fn>;
}> = {}) {
return {
id: overrides.id ?? 'socket-1',
data: overrides.data ?? {},
handshake: overrides.handshake ?? { auth: { token: 'valid-jwt' }, headers: {}, query: {} },
join: overrides.join ?? vi.fn().mockResolvedValue(undefined),
emit: overrides.emit ?? vi.fn(),
disconnect: overrides.disconnect ?? vi.fn(),
} as any;
}
describe('NotificationsGateway', () => {
let gateway: NotificationsGateway;
let mockTokenService: { verifyAccessToken: ReturnType<typeof vi.fn> };
let mockLogger: {
log: ReturnType<typeof vi.fn>;
debug: ReturnType<typeof vi.fn>;
warn: ReturnType<typeof vi.fn>;
error: ReturnType<typeof vi.fn>;
};
let mockRedisService: {
isAvailable: ReturnType<typeof vi.fn>;
get: ReturnType<typeof vi.fn>;
set: ReturnType<typeof vi.fn>;
del: ReturnType<typeof vi.fn>;
getClient: ReturnType<typeof vi.fn>;
};
let mockNotificationRepo: { countUnreadByUserId: ReturnType<typeof vi.fn> };
let mockServer: {
to: ReturnType<typeof vi.fn>;
};
beforeEach(() => {
mockTokenService = {
verifyAccessToken: vi.fn().mockReturnValue({ sub: 'user-1', role: 'USER' }),
};
mockLogger = { log: vi.fn(), debug: vi.fn(), warn: vi.fn(), error: vi.fn() };
mockRedisService = {
isAvailable: vi.fn().mockReturnValue(true),
get: vi.fn().mockResolvedValue(null),
set: vi.fn().mockResolvedValue(undefined),
del: vi.fn().mockResolvedValue(undefined),
getClient: vi.fn().mockReturnValue({ exists: vi.fn().mockResolvedValue(0), incr: vi.fn() }),
};
mockNotificationRepo = { countUnreadByUserId: vi.fn().mockResolvedValue(3) };
gateway = new NotificationsGateway(
mockTokenService as any,
mockLogger as any,
mockRedisService as any,
mockNotificationRepo as any,
);
// Wire the server mock
mockServer = { to: vi.fn().mockReturnValue({ emit: vi.fn() }) };
(gateway as any).server = mockServer;
});
describe('afterInit', () => {
it('logs initialization', () => {
gateway.afterInit();
expect(mockLogger.log).toHaveBeenCalledWith(
expect.stringContaining('initialized'),
'NotificationsGateway',
);
});
});
describe('handleConnection', () => {
it('authenticates, joins room, and emits unread count', async () => {
const socket = createMockSocket();
await gateway.handleConnection(socket);
expect(mockTokenService.verifyAccessToken).toHaveBeenCalledWith('valid-jwt');
expect(socket.data['userId']).toBe('user-1');
expect(socket.data['role']).toBe('USER');
expect(socket.join).toHaveBeenCalledWith('user:user-1');
expect(socket.emit).toHaveBeenCalledWith('notification:unread-count', { unreadCount: 3 });
});
it('strips Bearer prefix from Authorization header', async () => {
const socket = createMockSocket({
handshake: { auth: {}, headers: { authorization: 'Bearer my-token' }, query: {} },
});
await gateway.handleConnection(socket);
expect(mockTokenService.verifyAccessToken).toHaveBeenCalledWith('my-token');
});
it('disconnects client when no token provided', async () => {
const socket = createMockSocket({
handshake: { auth: {}, headers: {}, query: {} },
});
mockTokenService.verifyAccessToken.mockReturnValue(null);
await gateway.handleConnection(socket);
expect(socket.disconnect).toHaveBeenCalledWith(true);
});
it('disconnects client when token is invalid', async () => {
mockTokenService.verifyAccessToken.mockReturnValue(null);
const socket = createMockSocket();
await gateway.handleConnection(socket);
expect(socket.disconnect).toHaveBeenCalledWith(true);
});
it('tracks multiple sockets per user (multi-device)', async () => {
const socket1 = createMockSocket({ id: 'sock-a' });
const socket2 = createMockSocket({ id: 'sock-b' });
await gateway.handleConnection(socket1);
await gateway.handleConnection(socket2);
// Both sockets tracked
const userSockets = (gateway as any).userSockets as Map<string, Set<string>>;
expect(userSockets.get('user-1')?.size).toBe(2);
expect(userSockets.get('user-1')?.has('sock-a')).toBe(true);
expect(userSockets.get('user-1')?.has('sock-b')).toBe(true);
});
it('uses cached unread count from Redis when available', async () => {
mockRedisService.get.mockResolvedValue('7');
const socket = createMockSocket();
await gateway.handleConnection(socket);
expect(socket.emit).toHaveBeenCalledWith('notification:unread-count', { unreadCount: 7 });
expect(mockNotificationRepo.countUnreadByUserId).not.toHaveBeenCalled();
});
it('falls back to DB when Redis unavailable', async () => {
mockRedisService.isAvailable.mockReturnValue(false);
const socket = createMockSocket();
await gateway.handleConnection(socket);
expect(mockNotificationRepo.countUnreadByUserId).toHaveBeenCalledWith('user-1');
expect(socket.emit).toHaveBeenCalledWith('notification:unread-count', { unreadCount: 3 });
});
});
describe('handleDisconnect', () => {
it('removes socket from tracking map', async () => {
const socket = createMockSocket({ id: 'sock-1' });
await gateway.handleConnection(socket);
gateway.handleDisconnect(socket);
const userSockets = (gateway as any).userSockets as Map<string, Set<string>>;
expect(userSockets.has('user-1')).toBe(false);
});
it('keeps other sockets when one disconnects', async () => {
const socket1 = createMockSocket({ id: 'sock-1' });
const socket2 = createMockSocket({ id: 'sock-2' });
await gateway.handleConnection(socket1);
await gateway.handleConnection(socket2);
gateway.handleDisconnect(socket1);
const userSockets = (gateway as any).userSockets as Map<string, Set<string>>;
expect(userSockets.get('user-1')?.size).toBe(1);
expect(userSockets.get('user-1')?.has('sock-2')).toBe(true);
});
it('handles disconnect from unknown socket gracefully', () => {
const socket = createMockSocket();
// No prior connection — should not throw
expect(() => gateway.handleDisconnect(socket)).not.toThrow();
});
});
describe('handleNotificationSent', () => {
const event: NotificationSentEvent = {
aggregateId: 'notif-1',
userId: 'user-1',
templateKey: 'listing_approved',
channel: 'EMAIL',
occurredAt: new Date('2026-04-16T12:00:00Z'),
} as any;
it('emits notification:new to user room', async () => {
const roomEmit = vi.fn();
mockServer.to.mockReturnValue({ emit: roomEmit });
await gateway.handleNotificationSent(event);
expect(mockServer.to).toHaveBeenCalledWith('user:user-1');
expect(roomEmit).toHaveBeenCalledWith('notification:new', {
id: 'notif-1',
templateKey: 'listing_approved',
channel: 'EMAIL',
occurredAt: '2026-04-16T12:00:00.000Z',
});
});
it('emits updated unread count after notification', async () => {
const roomEmit = vi.fn();
mockServer.to.mockReturnValue({ emit: roomEmit });
await gateway.handleNotificationSent(event);
// Called twice: once for notification:new, once for unread-count
expect(roomEmit).toHaveBeenCalledWith('notification:unread-count', { unreadCount: 3 });
});
it('increments cached unread count in Redis when key exists', async () => {
const mockIncr = vi.fn();
mockRedisService.getClient.mockReturnValue({
exists: vi.fn().mockResolvedValue(1),
incr: mockIncr,
});
mockServer.to.mockReturnValue({ emit: vi.fn() });
await gateway.handleNotificationSent(event);
expect(mockIncr).toHaveBeenCalled();
});
it('does not throw when event handling fails', async () => {
mockServer.to.mockImplementation(() => {
throw new Error('server error');
});
await expect(gateway.handleNotificationSent(event)).resolves.not.toThrow();
expect(mockLogger.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to emit'),
expect.any(String),
'NotificationsGateway',
);
});
});
describe('emitUnreadCount', () => {
it('emits unread count to user room', async () => {
const roomEmit = vi.fn();
mockServer.to.mockReturnValue({ emit: roomEmit });
await gateway.emitUnreadCount('user-1');
expect(mockServer.to).toHaveBeenCalledWith('user:user-1');
expect(roomEmit).toHaveBeenCalledWith('notification:unread-count', { unreadCount: 3 });
});
});
describe('invalidateUnreadCount', () => {
it('deletes cached unread count from Redis', async () => {
await gateway.invalidateUnreadCount('user-1');
expect(mockRedisService.del).toHaveBeenCalledWith('notifications:unread:user-1');
});
it('skips deletion when Redis unavailable', async () => {
mockRedisService.isAvailable.mockReturnValue(false);
await gateway.invalidateUnreadCount('user-1');
expect(mockRedisService.del).not.toHaveBeenCalled();
});
});
});

View File

@@ -0,0 +1,225 @@
import { ZaloOaWebhookController } from '../controllers/zalo-oa-webhook.controller';
describe('ZaloOaWebhookController', () => {
let controller: ZaloOaWebhookController;
let mockPrisma: {
oAuthAccount: {
findFirst: ReturnType<typeof vi.fn>;
};
};
let mockLogger: {
log: ReturnType<typeof vi.fn>;
warn: ReturnType<typeof vi.fn>;
error: ReturnType<typeof vi.fn>;
};
let mockZaloOaService: { isAvailable: boolean };
beforeEach(() => {
mockPrisma = {
oAuthAccount: { findFirst: vi.fn() },
};
mockLogger = { log: vi.fn(), warn: vi.fn(), error: vi.fn() };
mockZaloOaService = { isAvailable: true };
controller = new ZaloOaWebhookController(
mockPrisma as any,
mockLogger as any,
mockZaloOaService as any,
);
});
describe('verify', () => {
it('returns the challenge token', () => {
const result = controller.verify('test-challenge-123');
expect(result).toBe('test-challenge-123');
});
it('returns empty string when no challenge provided', () => {
const result = controller.verify(undefined as any);
expect(result).toBe('');
});
});
describe('handleEvent', () => {
const mockReq = {} as any;
it('returns received:true for all events', async () => {
const result = await controller.handleEvent(
{ app_id: 'app-1', event_name: 'follow', timestamp: '123', sender: { id: 'zalo-1' }, recipient: { id: 'oa-1' } },
mockReq,
);
expect(result).toEqual({ received: true });
});
it('skips processing when Zalo OA not configured', async () => {
mockZaloOaService.isAvailable = false;
await controller.handleEvent(
{ app_id: 'app-1', event_name: 'follow', timestamp: '123', sender: { id: 'zalo-1' }, recipient: { id: 'oa-1' } },
mockReq,
);
expect(mockLogger.warn).toHaveBeenCalledWith(
expect.stringContaining('not configured'),
'ZaloOaWebhookController',
);
expect(mockPrisma.oAuthAccount.findFirst).not.toHaveBeenCalled();
});
describe('follow event', () => {
it('checks for existing OAuth link on follow', async () => {
mockPrisma.oAuthAccount.findFirst.mockResolvedValue(null);
await controller.handleEvent(
{ app_id: 'app-1', event_name: 'follow', timestamp: '123', sender: { id: 'zalo-user-123' }, recipient: { id: 'oa-1' } },
mockReq,
);
expect(mockPrisma.oAuthAccount.findFirst).toHaveBeenCalledWith({
where: { provider: 'ZALO', providerUserId: 'zalo-user-123' },
});
});
it('logs when user is already linked', async () => {
mockPrisma.oAuthAccount.findFirst.mockResolvedValue({
userId: 'user-abc',
providerUserId: 'zalo-user-123',
});
await controller.handleEvent(
{ app_id: 'app-1', event_name: 'follow', timestamp: '123', sender: { id: 'zalo-user-123' }, recipient: { id: 'oa-1' } },
mockReq,
);
expect(mockLogger.log).toHaveBeenCalledWith(
expect.stringContaining('already linked'),
'ZaloOaWebhookController',
);
});
it('logs when no link found (manual linking needed)', async () => {
mockPrisma.oAuthAccount.findFirst.mockResolvedValue(null);
await controller.handleEvent(
{ app_id: 'app-1', event_name: 'follow', timestamp: '123', sender: { id: 'zalo-user-456' }, recipient: { id: 'oa-1' } },
mockReq,
);
expect(mockLogger.log).toHaveBeenCalledWith(
expect.stringContaining('no existing link'),
'ZaloOaWebhookController',
);
});
});
describe('unfollow event', () => {
it('logs unfollow event', async () => {
await controller.handleEvent(
{ app_id: 'app-1', event_name: 'unfollow', timestamp: '123', sender: { id: 'zalo-user-789' }, recipient: { id: 'oa-1' } },
mockReq,
);
expect(mockLogger.log).toHaveBeenCalledWith(
expect.stringContaining('unfollowed'),
'ZaloOaWebhookController',
);
});
});
describe('user_send_text event', () => {
it('logs incoming message and checks for linked user', async () => {
mockPrisma.oAuthAccount.findFirst.mockResolvedValue({ userId: 'user-linked' });
await controller.handleEvent(
{
app_id: 'app-1',
event_name: 'user_send_text',
timestamp: '123',
sender: { id: 'zalo-user-100' },
recipient: { id: 'oa-1' },
message: { text: 'Xin chào', msg_id: 'msg-001' },
},
mockReq,
);
expect(mockPrisma.oAuthAccount.findFirst).toHaveBeenCalledWith({
where: { provider: 'ZALO', providerUserId: 'zalo-user-100' },
select: { userId: true },
});
expect(mockLogger.log).toHaveBeenCalledWith(
expect.stringContaining('linked user user-linked'),
'ZaloOaWebhookController',
);
});
it('handles message from unlinked user', async () => {
mockPrisma.oAuthAccount.findFirst.mockResolvedValue(null);
await controller.handleEvent(
{
app_id: 'app-1',
event_name: 'user_send_text',
timestamp: '123',
sender: { id: 'zalo-user-200' },
recipient: { id: 'oa-1' },
message: { text: 'Hello', msg_id: 'msg-002' },
},
mockReq,
);
expect(mockLogger.log).toHaveBeenCalledWith(
expect.stringContaining('Message from Zalo UID'),
'ZaloOaWebhookController',
);
});
it('ignores messages without text', async () => {
await controller.handleEvent(
{
app_id: 'app-1',
event_name: 'user_send_text',
timestamp: '123',
sender: { id: 'zalo-user-300' },
recipient: { id: 'oa-1' },
message: { msg_id: 'msg-003' },
},
mockReq,
);
expect(mockPrisma.oAuthAccount.findFirst).not.toHaveBeenCalled();
});
});
describe('unknown events', () => {
it('logs unhandled event types', async () => {
await controller.handleEvent(
{ app_id: 'app-1', event_name: 'user_send_image', timestamp: '123', sender: { id: 'zalo-1' }, recipient: { id: 'oa-1' } },
mockReq,
);
expect(mockLogger.log).toHaveBeenCalledWith(
expect.stringContaining('Unhandled event type'),
'ZaloOaWebhookController',
);
});
});
describe('error handling', () => {
it('catches and logs errors without throwing', async () => {
mockPrisma.oAuthAccount.findFirst.mockRejectedValue(new Error('DB connection lost'));
const result = await controller.handleEvent(
{ app_id: 'app-1', event_name: 'follow', timestamp: '123', sender: { id: 'zalo-1' }, recipient: { id: 'oa-1' } },
mockReq,
);
expect(result).toEqual({ received: true });
expect(mockLogger.error).toHaveBeenCalledWith(
expect.stringContaining('DB connection lost'),
expect.any(String),
'ZaloOaWebhookController',
);
});
});
});
});

View File

@@ -0,0 +1,166 @@
import { Body, Controller, Get, HttpCode, Post, Query, RawBodyRequest, Req } from '@nestjs/common';
import type { Request } from 'express';
import { LoggerService, PrismaService } from '@modules/shared';
import { ZaloOaService } from '../../infrastructure/services/zalo-oa.service';
/**
* Zalo OA event types from webhook payloads.
*
* @see https://developers.zalo.me/docs/official-account/webhook
*/
interface ZaloOaWebhookPayload {
app_id: string;
event_name: string;
timestamp: string;
sender: { id: string };
recipient: { id: string };
message?: { text?: string; msg_id?: string; attachments?: unknown[] };
follower?: { id: string };
user_id_by_app?: string;
}
const WEBHOOK_CONTEXT = 'ZaloOaWebhookController';
@Controller('webhooks/zalo-oa')
export class ZaloOaWebhookController {
constructor(
private readonly prisma: PrismaService,
private readonly logger: LoggerService,
private readonly zaloOaService: ZaloOaService,
) {}
/**
* Webhook verification endpoint.
* Zalo OA sends a GET request with a challenge token during webhook setup.
*/
@Get()
verify(@Query('challenge') challenge: string): string {
this.logger.log(`Webhook verification: challenge=${challenge}`, WEBHOOK_CONTEXT);
return challenge ?? '';
}
/**
* Receive and process Zalo OA webhook events.
*
* Supported events:
* - `follow` — user follows the OA, attempt to link via phone
* - `unfollow` — user unfollows the OA
* - `user_send_text` — user sends a text message to the OA
*/
@Post()
@HttpCode(200)
async handleEvent(
@Body() payload: ZaloOaWebhookPayload,
@Req() req: RawBodyRequest<Request>,
): Promise<{ received: true }> {
const { event_name, sender, timestamp } = payload;
this.logger.log(
`Webhook event: ${event_name} from=${sender?.id ?? 'unknown'} at=${timestamp}`,
WEBHOOK_CONTEXT,
);
// Verify OA secret (app_id must match our configured OA)
if (!this.zaloOaService.isAvailable) {
this.logger.warn('Zalo OA not configured — ignoring webhook event', WEBHOOK_CONTEXT);
return { received: true };
}
try {
switch (event_name) {
case 'follow':
await this.handleFollow(payload);
break;
case 'unfollow':
await this.handleUnfollow(payload);
break;
case 'user_send_text':
await this.handleUserMessage(payload);
break;
default:
this.logger.log(`Unhandled event type: ${event_name}`, WEBHOOK_CONTEXT);
}
} catch (error) {
this.logger.error(
`Webhook processing failed for ${event_name}: ${error instanceof Error ? error.message : error}`,
error instanceof Error ? error.stack : undefined,
WEBHOOK_CONTEXT,
);
}
return { received: true };
}
/**
* Handle `follow` event — attempt to link the Zalo user to a platform user.
*
* Linking strategy: look up OAuthAccount with provider=ZALO and matching providerUserId,
* or try phone-based matching if the Zalo user ID can be resolved to a phone.
*/
private async handleFollow(payload: ZaloOaWebhookPayload): Promise<void> {
const zaloUid = payload.sender?.id ?? payload.follower?.id;
if (!zaloUid) return;
// Check if already linked via OAuth
const existingLink = await this.prisma.oAuthAccount.findFirst({
where: { provider: 'ZALO', providerUserId: zaloUid },
});
if (existingLink) {
this.logger.log(
`Follow event: Zalo UID ${zaloUid.slice(0, 6)}*** already linked to user ${existingLink.userId}`,
WEBHOOK_CONTEXT,
);
return;
}
this.logger.log(
`Follow event: Zalo UID ${zaloUid.slice(0, 6)}*** — no existing link found. Manual linking may be required via phone verification.`,
WEBHOOK_CONTEXT,
);
}
/**
* Handle `unfollow` event — log the event for analytics.
* We do NOT remove the OAuth link (user may re-follow).
*/
private async handleUnfollow(payload: ZaloOaWebhookPayload): Promise<void> {
const zaloUid = payload.sender?.id;
if (!zaloUid) return;
this.logger.log(
`Unfollow event: Zalo UID ${zaloUid.slice(0, 6)}*** unfollowed OA`,
WEBHOOK_CONTEXT,
);
}
/**
* Handle incoming text message from a Zalo user.
* Logs the message for now — can be extended to create inquiries or route to messaging.
*/
private async handleUserMessage(payload: ZaloOaWebhookPayload): Promise<void> {
const zaloUid = payload.sender?.id;
const text = payload.message?.text;
const msgId = payload.message?.msg_id;
if (!zaloUid || !text) return;
this.logger.log(
`Message from Zalo UID ${zaloUid.slice(0, 6)}***: msgId=${msgId ?? 'unknown'} length=${text.length}`,
WEBHOOK_CONTEXT,
);
// Find linked user if any
const link = await this.prisma.oAuthAccount.findFirst({
where: { provider: 'ZALO', providerUserId: zaloUid },
select: { userId: true },
});
if (link) {
this.logger.log(
`Message from linked user ${link.userId} via Zalo OA`,
WEBHOOK_CONTEXT,
);
}
}
}

View File

@@ -1,6 +1,6 @@
'use client';
import { useState, useCallback } from 'react';
import { useState, useCallback, useEffect } from 'react';
import { Badge } from '@/components/ui/badge';
import { Button } from '@/components/ui/button';
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
@@ -32,6 +32,20 @@ const KYC_STEPS = [
const API_BASE_URL =
process.env['NEXT_PUBLIC_API_URL'] || 'http://localhost:3001/api/v1';
const MAX_FILE_SIZE_BYTES = 5 * 1024 * 1024; // 5MB
const ACCEPTED_MIME_TYPES = ['image/jpeg', 'image/jpg', 'image/png', 'image/webp', 'application/pdf'];
const ACCEPTED_ACCEPT_ATTR = 'image/jpeg,image/png,image/webp,application/pdf';
function validateFile(file: File): string | null {
if (!ACCEPTED_MIME_TYPES.includes(file.type)) {
return 'Định dạng không hợp lệ. Vui lòng chọn JPG, PNG, WEBP hoặc PDF.';
}
if (file.size > MAX_FILE_SIZE_BYTES) {
return 'Kích thước tệp vượt quá 5MB. Vui lòng chọn tệp nhỏ hơn.';
}
return null;
}
function getCsrfToken(): string | undefined {
const csrfMatch = document.cookie.match(/(?:^|;\s*)XSRF-TOKEN=([^;]*)/);
return csrfMatch?.[1] ? decodeURIComponent(csrfMatch[1]) : undefined;
@@ -71,8 +85,10 @@ function uploadFileWithProgress(
if (xhr.status >= 200 && xhr.status < 300) {
onProgress(100);
resolve();
} else if (xhr.status === 403) {
reject(new Error('Liên kết tải lên đã hết hạn. Vui lòng thử lại.'));
} else {
reject(new Error(`Upload thất bại (${xhr.status})`));
reject(new Error(`Tải ảnh thất bại (${xhr.status}). Vui lòng thử lại.`));
}
});
@@ -96,6 +112,48 @@ export default function KycPage() {
const [frontImage, setFrontImage] = useState<File | null>(null);
const [backImage, setBackImage] = useState<File | null>(null);
const [selfieImage, setSelfieImage] = useState<File | null>(null);
const [frontPreview, setFrontPreview] = useState<string | null>(null);
const [backPreview, setBackPreview] = useState<string | null>(null);
const [selfiePreview, setSelfiePreview] = useState<string | null>(null);
// Revoke object URLs on cleanup to avoid memory leaks
useEffect(() => {
return () => {
if (frontPreview) URL.revokeObjectURL(frontPreview);
if (backPreview) URL.revokeObjectURL(backPreview);
if (selfiePreview) URL.revokeObjectURL(selfiePreview);
};
}, [frontPreview, backPreview, selfiePreview]);
const handleFileSelect = useCallback(
(
field: 'front' | 'back' | 'selfie',
file: File | null,
) => {
if (!file) return;
const validationError = validateFile(file);
if (validationError) {
setError(validationError);
return;
}
setError(null);
const previewUrl = file.type.startsWith('image/') ? URL.createObjectURL(file) : null;
if (field === 'front') {
if (frontPreview) URL.revokeObjectURL(frontPreview);
setFrontImage(file);
setFrontPreview(previewUrl);
} else if (field === 'back') {
if (backPreview) URL.revokeObjectURL(backPreview);
setBackImage(file);
setBackPreview(previewUrl);
} else {
if (selfiePreview) URL.revokeObjectURL(selfiePreview);
setSelfieImage(file);
setSelfiePreview(previewUrl);
}
},
[frontPreview, backPreview, selfiePreview],
);
const kycStatus = user?.kycStatus ?? 'NONE';
const kycInfo = KYC_STATUS_MAP[kycStatus] ?? { label: 'Chưa xác minh', variant: 'outline' as const, description: 'Bạn chưa gửi hồ sơ xác minh danh tính.' };
@@ -197,7 +255,7 @@ export default function KycPage() {
};
return (
<div className="space-y-6">
<div className="space-y-6" data-testid="kyc-page">
<div>
<h1 className="text-2xl font-bold sm:text-3xl">Xác minh danh tính (KYC)</h1>
<p className="mt-2 text-muted-foreground">
@@ -219,7 +277,11 @@ export default function KycPage() {
</Card>
{error && (
<div className="rounded-lg border border-red-200 bg-red-50 p-4 text-sm text-red-700">
<div
role="alert"
data-testid="kyc-error"
className="rounded-lg border border-red-200 bg-red-50 p-4 text-sm text-red-700"
>
{error}
<button onClick={() => setError(null)} className="ml-2 font-medium underline">
Đóng
@@ -228,7 +290,11 @@ export default function KycPage() {
)}
{success && (
<div className="rounded-lg border border-green-200 bg-green-50 p-4 text-sm text-green-700">
<div
role="status"
data-testid="kyc-success"
className="rounded-lg border border-green-200 bg-green-50 p-4 text-sm text-green-700"
>
Hồ KYC đã đưc gửi thành công. Vui lòng chờ 1-3 ngày làm việc đ đưc xem xét.
</div>
)}
@@ -299,41 +365,71 @@ export default function KycPage() {
{/* Step 2: Upload images */}
{currentStep === 2 && (
<>
<p className="text-xs text-muted-foreground">
Đnh dạng hỗ trợ: JPG, PNG, WEBP, PDF. Kích thước tối đa: 5MB.
</p>
<div className="space-y-2">
<Label htmlFor="frontImg">nh mặt trước *</Label>
<Input
id="frontImg"
data-testid="kyc-front-input"
type="file"
accept="image/*"
onChange={(e) => setFrontImage(e.target.files?.[0] ?? null)}
accept={ACCEPTED_ACCEPT_ATTR}
onChange={(e) => handleFileSelect('front', e.target.files?.[0] ?? null)}
/>
{frontImage && (
<p className="text-xs text-muted-foreground">{frontImage.name}</p>
)}
{frontPreview && (
<img
src={frontPreview}
alt="Xem trước mặt trước"
data-testid="kyc-front-preview"
className="mt-2 max-h-48 rounded-md border object-contain"
/>
)}
</div>
<div className="space-y-2">
<Label htmlFor="backImg">nh mặt sau</Label>
<Input
id="backImg"
data-testid="kyc-back-input"
type="file"
accept="image/*"
onChange={(e) => setBackImage(e.target.files?.[0] ?? null)}
accept={ACCEPTED_ACCEPT_ATTR}
onChange={(e) => handleFileSelect('back', e.target.files?.[0] ?? null)}
/>
{backImage && (
<p className="text-xs text-muted-foreground">{backImage.name}</p>
)}
{backPreview && (
<img
src={backPreview}
alt="Xem trước mặt sau"
data-testid="kyc-back-preview"
className="mt-2 max-h-48 rounded-md border object-contain"
/>
)}
</div>
<div className="space-y-2">
<Label htmlFor="selfieImg">nh selfie cầm giấy tờ</Label>
<Input
id="selfieImg"
data-testid="kyc-selfie-input"
type="file"
accept="image/*"
onChange={(e) => setSelfieImage(e.target.files?.[0] ?? null)}
accept={ACCEPTED_ACCEPT_ATTR}
onChange={(e) => handleFileSelect('selfie', e.target.files?.[0] ?? null)}
/>
{selfieImage && (
<p className="text-xs text-muted-foreground">{selfieImage.name}</p>
)}
{selfiePreview && (
<img
src={selfiePreview}
alt="Xem trước selfie"
data-testid="kyc-selfie-preview"
className="mt-2 max-h-48 rounded-md border object-contain"
/>
)}
</div>
</>
)}
@@ -386,7 +482,12 @@ export default function KycPage() {
{/* Navigation buttons */}
<div className="flex justify-between pt-2">
{currentStep > 1 ? (
<Button variant="outline" onClick={() => setCurrentStep((s) => s - 1)} disabled={submitting}>
<Button
variant="outline"
data-testid="kyc-back-button"
onClick={() => setCurrentStep((s) => s - 1)}
disabled={submitting}
>
Quay lại
</Button>
) : (
@@ -394,11 +495,16 @@ export default function KycPage() {
)}
{currentStep < 3 ? (
<Button
data-testid="kyc-next-button"
onClick={() => {
if (currentStep === 1 && !documentNumber.trim()) {
setError('Vui lòng nhập số giấy tờ');
return;
}
if (currentStep === 2 && !frontImage) {
setError('Vui lòng tải ảnh mặt trước');
return;
}
setError(null);
setCurrentStep((s) => s + 1);
}}
@@ -406,7 +512,11 @@ export default function KycPage() {
Tiếp tục
</Button>
) : (
<Button onClick={handleSubmit} disabled={submitting}>
<Button
data-testid="kyc-submit-button"
onClick={handleSubmit}
disabled={submitting}
>
{submitting ? 'Đang gửi...' : 'Gửi hồ sơ xác minh'}
</Button>
)}

View File

@@ -1,12 +1,23 @@
'use client';
import { Building2, LayoutGrid, Map } from 'lucide-react';
import { Building2, LayoutGrid, List, Map, MapPin } from 'lucide-react';
import dynamic from 'next/dynamic';
import Image from 'next/image';
import * as React from 'react';
import { ProjectCard } from '@/components/du-an/project-card';
import { ProjectFilterBar } from '@/components/du-an/project-filter-bar';
import { Badge } from '@/components/ui/badge';
import { Button } from '@/components/ui/button';
import type { SearchProjectsParams } from '@/lib/du-an-api';
import { Card } from '@/components/ui/card';
import { Link } from '@/i18n/navigation';
import { formatPrice } from '@/lib/currency';
import {
PROJECT_PROPERTY_TYPE_LABELS,
PROJECT_STATUS_COLORS,
PROJECT_STATUS_LABELS,
type ProjectSummary,
type SearchProjectsParams,
} from '@/lib/du-an-api';
import { useProjectsSearch } from '@/lib/hooks/use-du-an';
import { cn } from '@/lib/utils';
@@ -17,7 +28,7 @@ const ProjectMap = dynamic(
const PAGE_SIZE = 12;
type ViewMode = 'grid' | 'map';
type ViewMode = 'grid' | 'list' | 'map';
export default function DuAnPage() {
const [filters, setFilters] = React.useState<SearchProjectsParams>({
@@ -62,6 +73,19 @@ export default function DuAnPage() {
>
<LayoutGrid className="h-4 w-4" />
</button>
<button
type="button"
onClick={() => setViewMode('list')}
className={cn(
'rounded-md p-2 transition-colors',
viewMode === 'list'
? 'bg-primary text-primary-foreground'
: 'text-muted-foreground hover:text-foreground',
)}
aria-label="Xem dạng danh sách"
>
<List className="h-4 w-4" />
</button>
<button
type="button"
onClick={() => setViewMode('map')}
@@ -113,6 +137,12 @@ export default function DuAnPage() {
{viewMode === 'map' ? (
<ProjectMap projects={data.data} />
) : viewMode === 'list' ? (
<div className="space-y-4">
{data.data.map((project) => (
<ProjectListItem key={project.id} project={project} />
))}
</div>
) : (
<div className="grid gap-6 sm:grid-cols-2 lg:grid-cols-3">
{data.data.map((project) => (
@@ -121,8 +151,8 @@ export default function DuAnPage() {
</div>
)}
{/* Pagination (grid mode only) */}
{viewMode === 'grid' && data.totalPages > 1 && (
{/* Pagination (grid/list mode) */}
{viewMode !== 'map' && data.totalPages > 1 && (
<div className="mt-8 flex items-center justify-center gap-2">
<Button
variant="outline"
@@ -159,3 +189,75 @@ export default function DuAnPage() {
</div>
);
}
function ProjectListItem({ project }: { project: ProjectSummary }) {
const statusLabel = PROJECT_STATUS_LABELS[project.status];
const statusColor = PROJECT_STATUS_COLORS[project.status];
const propertyLabels = project.propertyTypes
.map((t) => PROJECT_PROPERTY_TYPE_LABELS[t])
.join(', ');
return (
<Link href={`/du-an/${project.slug}`}>
<Card className="group flex overflow-hidden transition-shadow hover:shadow-lg">
{/* Thumbnail */}
<div className="relative aspect-[4/3] w-48 shrink-0 overflow-hidden bg-muted sm:w-56 md:w-64">
{project.thumbnailUrl ? (
<Image
src={project.thumbnailUrl}
alt={project.name}
fill
className="object-cover transition-transform group-hover:scale-105"
sizes="256px"
/>
) : (
<div className="flex h-full items-center justify-center">
<Building2 className="h-10 w-10 text-muted-foreground/30" />
</div>
)}
<Badge
className={cn('absolute left-2 top-2 text-xs', statusColor)}
variant="secondary"
>
{statusLabel}
</Badge>
</div>
{/* Content */}
<div className="flex min-w-0 flex-1 flex-col justify-between p-4">
<div>
<h3 className="line-clamp-1 text-base font-semibold group-hover:text-primary">
{project.name}
</h3>
<div className="mt-1 flex items-center gap-1 text-sm text-muted-foreground">
<MapPin className="h-3.5 w-3.5 shrink-0" />
<span className="line-clamp-1">
{project.district}, {project.city}
</span>
</div>
<div className="mt-1 flex flex-wrap items-center gap-2 text-xs text-muted-foreground">
<span>{propertyLabels}</span>
<span>·</span>
<span>{project.totalUnits} căn</span>
<span>·</span>
<span>{project.developer.name}</span>
</div>
</div>
<div className="mt-2">
{project.minPrice ? (
<p className="text-sm font-bold text-primary">
{formatPrice(project.minPrice)}
{project.maxPrice && project.maxPrice !== project.minPrice && (
<span> {formatPrice(project.maxPrice)}</span>
)}
</p>
) : (
<p className="text-sm text-muted-foreground">Liên hệ</p>
)}
</div>
</div>
</Card>
</Link>
);
}

View File

@@ -29,7 +29,7 @@ export function ProjectFilterBar({ filters, onFilterChange }: ProjectFilterBarPr
updateFilter('q', search.trim());
};
const hasFilters = filters.city || filters.district || filters.status || filters.propertyType || filters.q;
const hasFilters = filters.city || filters.district || filters.status || filters.propertyType || filters.minPrice || filters.maxPrice || filters.q;
const clearAll = () => {
setSearch('');
@@ -104,6 +104,34 @@ export function ProjectFilterBar({ filters, onFilterChange }: ProjectFilterBarPr
aria-label="Quận/Huyện"
/>
<Input
type="number"
placeholder="Giá từ (tỷ)"
value={filters.minPrice ? String(Number(filters.minPrice) / 1_000_000_000) : ''}
onChange={(e) => {
const val = e.target.value;
updateFilter('minPrice', val ? String(Number(val) * 1_000_000_000) : '');
}}
className="w-28"
aria-label="Giá tối thiểu"
min={0}
step={0.1}
/>
<Input
type="number"
placeholder="Giá đến (tỷ)"
value={filters.maxPrice ? String(Number(filters.maxPrice) / 1_000_000_000) : ''}
onChange={(e) => {
const val = e.target.value;
updateFilter('maxPrice', val ? String(Number(val) * 1_000_000_000) : '');
}}
className="w-28"
aria-label="Giá tối đa"
min={0}
step={0.1}
/>
<select
value={filters.sort || ''}
onChange={(e) => updateFilter('sort', e.target.value)}

View File

@@ -6,13 +6,14 @@ import * as React from 'react';
import { AddToCompareButton } from '@/components/comparison/add-to-compare-button';
import { ImageGallery } from '@/components/listings/image-gallery';
import { InquiryModal } from '@/components/listings/inquiry-modal';
import { PriceHistoryChart } from '@/components/listings/price-history-chart';
import { SocialShare } from '@/components/listings/social-share';
import { Badge } from '@/components/ui/badge';
import { Button } from '@/components/ui/button';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { AiEstimateButton } from '@/components/valuation/ai-estimate-button';
import { formatPrice, formatPricePerM2 } from '@/lib/currency';
import type { ListingDetail, NeighborhoodScoreResult } from '@/lib/listings-api';
import type { ListingDetail, NeighborhoodScoreResult, PriceHistoryItem } from '@/lib/listings-api';
import { listingsApi } from '@/lib/listings-api';
import { PROPERTY_TYPES, DIRECTIONS, TRANSACTION_TYPES } from '@/lib/validations/listings';
@@ -59,6 +60,7 @@ export function ListingDetailClient({ listing }: ListingDetailClientProps) {
const propertyTypeLabel = getLabel(PROPERTY_TYPES, property.propertyType);
const [inquiryOpen, setInquiryOpen] = React.useState(false);
const [neighborhoodScore, setNeighborhoodScore] = React.useState<NeighborhoodScoreResult | null>(null);
const [priceHistory, setPriceHistory] = React.useState<PriceHistoryItem[]>([]);
React.useEffect(() => {
if (!property.district || !property.city) return;
@@ -68,6 +70,13 @@ export function ListingDetailClient({ listing }: ListingDetailClientProps) {
.catch(() => {/* silently ignore — section simply won't render */});
}, [property.district, property.city]);
React.useEffect(() => {
listingsApi
.getPriceHistory(listing.id)
.then(setPriceHistory)
.catch(() => {/* silently ignore */});
}, [listing.id]);
return (
<div className="mx-auto max-w-6xl px-4 py-6">
{/* Breadcrumb */}
@@ -201,26 +210,55 @@ export function ListingDetailClient({ listing }: ListingDetailClientProps) {
</CardContent>
</Card>
{/* Price History Chart */}
{priceHistory.length > 0 && (
<Card>
<CardHeader>
<CardTitle>Lịch sử giá</CardTitle>
</CardHeader>
<CardContent>
<PriceHistoryChart data={priceHistory} />
</CardContent>
</Card>
)}
{/* Neighborhood Score Radar Chart */}
{neighborhoodScore && (
<Card>
<CardHeader>
<CardTitle>Đánh giá khu vực</CardTitle>
</CardHeader>
<CardContent>
{neighborhoodScore ? (
<>
<div className="mb-3 flex items-center gap-2">
<span className="text-2xl font-bold text-primary">
{neighborhoodScore.totalScore.toFixed(1)}
</span>
<span className="text-sm text-muted-foreground">/10 điểm tổng</span>
<Badge
variant={
neighborhoodScore.totalScore > 7
? 'success'
: neighborhoodScore.totalScore >= 5
? 'warning'
: 'destructive'
}
className="px-3 py-1 text-lg font-bold"
>
{neighborhoodScore.totalScore.toFixed(1)}/10
</Badge>
<span className="text-sm text-muted-foreground">Điểm tổng khu vực</span>
</div>
<NeighborhoodRadarChart
categories={mapScoreToCategories(neighborhoodScore)}
height={300}
/>
</>
) : (
<div className="flex h-[200px] items-center justify-center rounded-lg bg-muted/50">
<p className="text-sm text-muted-foreground">
Chưa dữ liệu đánh giá khu vực này
</p>
</div>
)}
</CardContent>
</Card>
)}
</div>
{/* Sidebar */}

View File

@@ -0,0 +1,73 @@
'use client';
import {
LineChart,
Line,
XAxis,
YAxis,
CartesianGrid,
Tooltip,
ResponsiveContainer,
} from 'recharts';
import type { PriceHistoryItem } from '@/lib/listings-api';
interface PriceHistoryChartProps {
data: PriceHistoryItem[];
height?: number;
}
function priceToMillions(priceStr: string): number {
return Math.round(Number(priceStr) / 1_000_000);
}
function formatMillions(value: number): string {
if (value >= 1000) return `${(value / 1000).toFixed(1)} tỷ`;
return `${value} tr`;
}
export function PriceHistoryChart({ data, height = 280 }: PriceHistoryChartProps) {
if (data.length === 0) return null;
const chartData = [...data]
.sort((a, b) => new Date(a.changedAt).getTime() - new Date(b.changedAt).getTime())
.map((item) => ({
date: new Date(item.changedAt).toLocaleDateString('vi-VN', {
day: '2-digit',
month: '2-digit',
year: 'numeric',
}),
price: priceToMillions(item.newPrice),
}));
return (
<ResponsiveContainer width="100%" height={height}>
<LineChart data={chartData} margin={{ top: 5, right: 20, left: 0, bottom: 5 }}>
<CartesianGrid strokeDasharray="3 3" className="stroke-muted" />
<XAxis dataKey="date" tick={{ fontSize: 11 }} className="fill-muted-foreground" />
<YAxis
tick={{ fontSize: 11 }}
className="fill-muted-foreground"
tickFormatter={(v: number) => formatMillions(v)}
/>
<Tooltip
contentStyle={{
backgroundColor: 'hsl(var(--card))',
border: '1px solid hsl(var(--border))',
borderRadius: '0.5rem',
fontSize: '0.875rem',
}}
formatter={(value) => [formatMillions(Number(value)), 'Giá']}
/>
<Line
type="monotone"
dataKey="price"
stroke="hsl(var(--primary))"
strokeWidth={2}
dot={{ r: 4 }}
activeDot={{ r: 6 }}
/>
</LineChart>
</ResponsiveContainer>
);
}

View File

@@ -132,6 +132,14 @@ export interface SearchListingsParams {
limit?: number;
}
export interface PriceHistoryItem {
id: string;
oldPrice: string;
newPrice: string;
source: string;
changedAt: string;
}
export interface NeighborhoodScoreResult {
district: string;
city: string;
@@ -203,6 +211,9 @@ export const listingsApi = {
return res.json() as Promise<{ mediaId: string; url: string }>;
},
getPriceHistory: (listingId: string) =>
apiClient.get<PriceHistoryItem[]>(`/listings/${listingId}/price-history`),
getNeighborhoodScore: (district: string, city: string = 'Hồ Chí Minh') =>
apiClient.get<NeighborhoodScoreResult>(
`/analytics/neighborhoods/${encodeURIComponent(district)}/score?city=${encodeURIComponent(city)}`,

View File

@@ -29,10 +29,28 @@ class AVMv2PredictRequest(BaseModel):
0.0, ge=0, le=1, description="Flood zone risk score (0=safe, 1=high risk)"
)
# ── Neighborhood features ─────────────────────────────
neighborhood_score: float = Field(
0.5, ge=0, le=1,
description="Overall neighborhood quality score (0-1, aggregated from safety, amenities, walkability)",
)
# ── Physical features ──────────────────────────────────
property_type: str = Field(..., description="e.g. apartment, house, villa, land")
area_m2: float = Field(..., gt=0, description="Property area in m²")
rooms: int = Field(0, ge=0, description="Total rooms (bedrooms)")
floor_level: int = Field(
0, ge=0,
description="Floor level (0=ground or N/A, relevant for apartments/penthouses)",
)
total_floors: int = Field(
0, ge=0,
description="Total floors in the building (0=N/A)",
)
direction: str = Field(
"unknown",
description="Facing direction: north, south, east, west, northeast, northwest, southeast, southwest, unknown",
)
floor_ratio: float = Field(
1.0, gt=0, description="Total floor area / land area ratio"
)
@@ -41,6 +59,10 @@ class AVMv2PredictRequest(BaseModel):
has_parking: bool = Field(False, description="Property has dedicated parking")
has_pool: bool = Field(False, description="Property has swimming pool")
has_legal_paper: bool = Field(True, description="Has sổ đỏ/sổ hồng")
developer_reputation: float = Field(
0.5, ge=0, le=1,
description="Project developer reputation score (0-1, based on past projects, delivery record)",
)
# ── Market features ────────────────────────────────────
avg_price_district_3m_vnd_m2: float = Field(
@@ -183,3 +205,75 @@ class AVMv2ModelInfo(BaseModel):
metrics: dict
is_active: bool = Field(True)
ab_test_traffic_pct: float = Field(0.0, ge=0, le=1)
class AVMv2RollbackRequest(BaseModel):
"""Request to rollback to a specific model version."""
target_version: str = Field(..., min_length=1, description="Model version to roll back to")
class AVMv1Summary(BaseModel):
"""Compact summary of a v1 prediction for comparison."""
estimated_price_vnd: float
confidence: float
price_per_m2: float
price_range_low: float
price_range_high: float
class AVMv2Summary(BaseModel):
"""Compact summary of a v2 prediction for comparison."""
estimated_price_vnd: float
confidence: float
price_per_m2_vnd: float
price_range_low_vnd: float
price_range_high_vnd: float
model_version: str
ensemble_method: str
class ABComparisonRequest(BaseModel):
"""Request for A/B comparison between v1 and v2."""
district: str = Field(..., min_length=1)
city: str = Field(..., min_length=1)
property_type: str = Field(...)
area_m2: float = Field(..., gt=0)
rooms: int = Field(0, ge=0)
bedrooms: int = Field(0, ge=0, description="Alias for rooms, used by v1")
floors: int = Field(0, ge=0)
frontage: float = Field(0.0, ge=0)
has_legal_paper: bool = Field(True)
# v2-specific features (optional, defaults applied)
neighborhood_score: float = Field(0.5, ge=0, le=1)
distance_to_cbd_km: float = Field(0.0, ge=0)
distance_to_metro_km: float = Field(0.0, ge=0)
flood_zone_risk: float = Field(0.0, ge=0, le=1)
building_age_years: int = Field(0, ge=0)
floor_level: int = Field(0, ge=0)
total_floors: int = Field(0, ge=0)
direction: str = Field("unknown")
has_elevator: bool = Field(False)
has_parking: bool = Field(False)
has_pool: bool = Field(False)
developer_reputation: float = Field(0.5, ge=0, le=1)
renovation_score: float = Field(0.5, ge=0, le=1)
view_quality: float = Field(0.5, ge=0, le=1)
interior_quality: float = Field(0.5, ge=0, le=1)
month: int = Field(1, ge=1, le=12)
quarter: int = Field(1, ge=1, le=4)
is_year_end: bool = Field(False)
class ABComparisonResponse(BaseModel):
"""Side-by-side A/B comparison of v1 vs v2 predictions."""
v1: AVMv1Summary
v2: AVMv2Summary
price_diff_vnd: float = Field(..., description="v2 - v1 price difference")
price_diff_pct: float = Field(..., description="Percentage difference ((v2-v1)/v1 * 100)")
confidence_diff: float = Field(..., description="v2 - v1 confidence difference")
recommendation: str = Field(..., description="Which model to prefer and why")

View File

@@ -1,11 +1,14 @@
"""AVM v2 ensemble router — residential property valuation."""
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException
from app.models.avm_v2 import (
ABComparisonRequest,
ABComparisonResponse,
AVMv2ModelInfo,
AVMv2PredictRequest,
AVMv2PredictResponse,
AVMv2RollbackRequest,
AVMv2TrainRequest,
AVMv2TrainResponse,
)
@@ -28,12 +31,43 @@ def predict_v2(req: AVMv2PredictRequest) -> AVMv2PredictResponse:
def train_v2(req: AVMv2TrainRequest) -> AVMv2TrainResponse:
"""Trigger model retraining with Optuna hyperparameter optimization.
Requires training data pipeline (Phase 3). Currently returns scaffold.
Loads training data from the model directory, runs Optuna for each
model in the ensemble, saves versioned artifacts, and registers
the new version in the model registry.
"""
return avm_v2_service.train(req)
@router.post("/compare-v1", response_model=ABComparisonResponse)
def compare_v1(req: ABComparisonRequest) -> ABComparisonResponse:
"""Compare v1 (single-model) vs v2 (ensemble) predictions side by side.
Runs both models on the same property and returns price difference,
confidence delta, and a recommendation on which to prefer.
"""
return avm_v2_service.compare_v1(req)
@router.get("/model-info", response_model=AVMv2ModelInfo)
def model_info_v2() -> AVMv2ModelInfo:
"""Get current active ensemble model information."""
return avm_v2_service.get_model_info()
@router.get("/versions", response_model=list[AVMv2ModelInfo])
def list_versions() -> list[AVMv2ModelInfo]:
"""List all registered model versions with their metrics and status."""
return avm_v2_service.list_versions()
@router.post("/rollback", response_model=AVMv2ModelInfo)
def rollback(req: AVMv2RollbackRequest) -> AVMv2ModelInfo:
"""Rollback to a previously trained model version.
Copies the target version's artifacts to the active model directory,
reloads models, and updates the registry.
"""
try:
return avm_v2_service.rollback(req.target_version)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -5,19 +5,27 @@ Ensemble weights: XGBoost 0.4, LightGBM 0.35, CatBoost 0.25.
Confidence = 1 - CV(3 predictions), where CV = std / mean.
"""
import json
import logging
import os
import shutil
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import numpy as np
from app.models.avm import AVMPredictRequest
from app.models.avm_v2 import (
ABComparisonRequest,
ABComparisonResponse,
AVMv1Summary,
AVMv2Comparable,
AVMv2FeatureImportance,
AVMv2ModelInfo,
AVMv2PredictRequest,
AVMv2PredictResponse,
AVMv2Summary,
AVMv2TrainRequest,
AVMv2TrainResponse,
ModelPrediction,
@@ -42,16 +50,22 @@ FEATURE_NAMES = [
"distance_to_park_km",
"distance_to_mall_km",
"flood_zone_risk",
# Physical (8)
# Neighborhood (1)
"neighborhood_score",
# Physical (13)
"property_type_encoded",
"area_m2",
"rooms",
"floor_level",
"total_floors",
"direction_encoded",
"floor_ratio",
"building_age_years",
"has_elevator",
"has_parking",
"has_pool",
"has_legal_paper",
"developer_reputation",
# Market (6)
"avg_price_district_3m_vnd_m2",
"listing_density",
@@ -71,6 +85,18 @@ FEATURE_NAMES = [
"is_year_end",
]
DIRECTION_MAP = {
"south": 0,
"southeast": 1,
"east": 2,
"southwest": 3,
"northeast": 4,
"west": 5,
"northwest": 6,
"north": 7,
"unknown": 4, # neutral mid-value
}
PROPERTY_TYPE_MAP = {
"apartment": 0,
"house": 1,
@@ -101,7 +127,7 @@ def _encode_features(req: AVMv2PredictRequest) -> np.ndarray:
month_rad = 2 * np.pi * req.month / 12.0
return np.array(
[[
# Location
# Location (7)
req.distance_to_cbd_km,
req.distance_to_metro_km,
req.distance_to_school_km,
@@ -109,30 +135,36 @@ def _encode_features(req: AVMv2PredictRequest) -> np.ndarray:
req.distance_to_park_km,
req.distance_to_mall_km,
req.flood_zone_risk,
# Physical
# Neighborhood (1)
req.neighborhood_score,
# Physical (13)
PROPERTY_TYPE_MAP.get(req.property_type.lower(), 1),
req.area_m2,
req.rooms,
float(req.floor_level),
float(req.total_floors),
float(DIRECTION_MAP.get(req.direction.lower(), 4)),
req.floor_ratio,
req.building_age_years,
1.0 if req.has_elevator else 0.0,
1.0 if req.has_parking else 0.0,
1.0 if req.has_pool else 0.0,
1.0 if req.has_legal_paper else 0.0,
# Market
req.developer_reputation,
# Market (6)
req.avg_price_district_3m_vnd_m2,
req.listing_density,
req.absorption_rate,
req.dom_avg,
req.price_momentum_30d,
req.yoy_change,
# LLM-extracted
# LLM-extracted (5)
req.renovation_score,
req.view_quality,
req.interior_quality,
req.noise_level,
req.natural_light,
# Temporal
# Temporal (3)
np.sin(month_rad),
np.cos(month_rad),
1.0 if req.is_year_end else 0.0,
@@ -314,6 +346,9 @@ class AVMv2EnsembleService:
metro_adj = 1.0 + max(0.0, (2.0 - req.distance_to_metro_km) * 0.05)
flood_adj = 1.0 - req.flood_zone_risk * 0.15
# Neighborhood adjustment: ±15% swing around 0.5 midpoint
neighborhood_adj = 1.0 + (req.neighborhood_score - 0.5) * 0.30
# Physical adjustments
room_adj = 1.0 + req.rooms * 0.015
age_adj = max(0.75, 1.0 - req.building_age_years * 0.008)
@@ -325,6 +360,34 @@ class AVMv2EnsembleService:
)
legal_adj = 1.0 if req.has_legal_paper else 0.70
# Floor level premium (apartments/penthouses: higher floors = premium)
floor_adj = 1.0
if req.floor_level > 0 and req.property_type.lower() in ("apartment", "penthouse"):
if req.total_floors > 0:
relative_floor = req.floor_level / req.total_floors
# Mid-to-high floors get up to +8% premium, ground floor -3%
floor_adj = 1.0 + (relative_floor - 0.3) * 0.12
floor_adj = max(0.97, min(1.08, floor_adj))
else:
# No total_floors info: mild premium for higher floors
floor_adj = min(1.08, 1.0 + req.floor_level * 0.003)
# Direction premium (Vietnamese preference: south/southeast best)
direction_adj = {
"south": 1.05,
"southeast": 1.04,
"east": 1.02,
"southwest": 1.01,
"northeast": 1.0,
"west": 0.98,
"northwest": 0.97,
"north": 0.96,
"unknown": 1.0,
}.get(req.direction.lower(), 1.0)
# Developer reputation: ±10% swing
developer_adj = 1.0 + (req.developer_reputation - 0.5) * 0.20
# Market adjustments
if req.avg_price_district_3m_vnd_m2 > 0:
market_adj = req.avg_price_district_3m_vnd_m2 / (base * 1_000_000)
@@ -352,10 +415,14 @@ class AVMv2EnsembleService:
* cbd_adj
* metro_adj
* flood_adj
* neighborhood_adj
* room_adj
* age_adj
* amenity_adj
* legal_adj
* floor_adj
* direction_adj
* developer_adj
* market_adj
* momentum_adj
* quality_adj
@@ -402,16 +469,20 @@ class AVMv2EnsembleService:
# Heuristic driver ranking
drivers = [
AVMv2FeatureImportance(feature="area_m2", importance=0.18),
AVMv2FeatureImportance(feature="avg_price_district_3m_vnd_m2", importance=0.15),
AVMv2FeatureImportance(feature="property_type_encoded", importance=0.12),
AVMv2FeatureImportance(feature="distance_to_cbd_km", importance=0.10),
AVMv2FeatureImportance(feature="renovation_score", importance=0.08),
AVMv2FeatureImportance(feature="building_age_years", importance=0.07),
AVMv2FeatureImportance(feature="has_legal_paper", importance=0.06),
AVMv2FeatureImportance(feature="distance_to_metro_km", importance=0.05),
AVMv2FeatureImportance(feature="interior_quality", importance=0.05),
AVMv2FeatureImportance(feature="price_momentum_30d", importance=0.04),
AVMv2FeatureImportance(feature="area_m2", importance=0.14),
AVMv2FeatureImportance(feature="avg_price_district_3m_vnd_m2", importance=0.12),
AVMv2FeatureImportance(feature="neighborhood_score", importance=0.10),
AVMv2FeatureImportance(feature="property_type_encoded", importance=0.10),
AVMv2FeatureImportance(feature="distance_to_cbd_km", importance=0.08),
AVMv2FeatureImportance(feature="developer_reputation", importance=0.07),
AVMv2FeatureImportance(feature="renovation_score", importance=0.07),
AVMv2FeatureImportance(feature="building_age_years", importance=0.06),
AVMv2FeatureImportance(feature="direction_encoded", importance=0.05),
AVMv2FeatureImportance(feature="floor_level", importance=0.05),
AVMv2FeatureImportance(feature="has_legal_paper", importance=0.05),
AVMv2FeatureImportance(feature="distance_to_metro_km", importance=0.04),
AVMv2FeatureImportance(feature="interior_quality", importance=0.04),
AVMv2FeatureImportance(feature="price_momentum_30d", importance=0.03),
]
return AVMv2PredictResponse(
@@ -476,52 +547,455 @@ class AVMv2EnsembleService:
# ── Training pipeline ───────────────────────────────────────
def train(self, req: AVMv2TrainRequest) -> AVMv2TrainResponse:
"""Train the ensemble models.
"""Train the ensemble models on available data.
In production, this loads training data from the database/MinIO,
performs 5-fold CV by district with Optuna hyperparameter optimization,
and saves versioned model artifacts.
Currently returns a scaffold response. Real training requires
the data pipeline from Phase 3.
Pipeline:
1. Load training data from CSV/database export
2. Feature engineering (encode, normalize, cyclical)
3. Train/val/test split stratified by district
4. For each model: Optuna hyperparameter optimization
5. Save versioned artifacts + register in model registry
"""
from app.config import settings
version = f"ensemble-v2-{datetime.now(timezone.utc).strftime('%Y%m%d-%H%M%S')}"
logger.info("Training AVM v2 ensemble — version %s, trials=%d", version, req.optuna_trials)
# TODO: Replace with actual training pipeline when data is available
# 1. Load data from PostgreSQL/MinIO
# 2. Feature engineering (encode categoricals, normalize, cyclical)
# 3. 80/10/10 split stratified by district
# 4. For each model (XGBoost, LightGBM, CatBoost):
# a. Optuna study with req.optuna_trials trials
# b. 5-fold CV grouped by district
# c. Train on best params
# 5. Save artifacts to MinIO with version tag
# 6. Register in model registry
model_dir = Path(settings.model_path)
data_path = model_dir / "training_data.csv"
# Check for training data
if not data_path.exists():
logger.warning("No training data found at %s — returning scaffold", data_path)
return AVMv2TrainResponse(
model_version=version,
metrics={
"mae": 0.0,
"mape": 0.0,
"rmse": 0.0,
"r2": 0.0,
},
metrics={"mae": 0.0, "mape": 0.0, "rmse": 0.0, "r2": 0.0},
district_metrics={},
training_samples=0,
validation_samples=0,
test_samples=0,
best_params={
"xgboost": {"n_estimators": 500, "max_depth": 6, "learning_rate": 0.05},
"lightgbm": {"n_estimators": 500, "num_leaves": 31, "learning_rate": 0.05},
"catboost": {"iterations": 500, "depth": 6, "learning_rate": 0.05},
},
best_params={},
)
# Load and prepare data
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit
df = pd.read_csv(data_path)
logger.info("Loaded %d training samples", len(df))
# Feature engineering
X, y, groups = self._prepare_training_data(df)
if len(X) < 50:
logger.warning("Insufficient training data (%d samples)", len(X))
return AVMv2TrainResponse(
model_version=version,
metrics={"mae": 0.0, "mape": 0.0, "rmse": 0.0, "r2": 0.0},
district_metrics={},
training_samples=len(X),
validation_samples=0,
test_samples=0,
best_params={},
)
# Split: train/val/test grouped by district
gss_test = GroupShuffleSplit(n_splits=1, test_size=req.test_size, random_state=42)
train_val_idx, test_idx = next(gss_test.split(X, y, groups))
X_trainval, y_trainval = X[train_val_idx], y[train_val_idx]
X_test, y_test = X[test_idx], y[test_idx]
groups_trainval = groups[train_val_idx]
val_ratio = req.val_size / (1.0 - req.test_size)
gss_val = GroupShuffleSplit(n_splits=1, test_size=val_ratio, random_state=42)
train_idx, val_idx = next(gss_val.split(X_trainval, y_trainval, groups_trainval))
X_train, y_train = X_trainval[train_idx], y_trainval[train_idx]
X_val, y_val = X_trainval[val_idx], y_trainval[val_idx]
logger.info("Split: train=%d, val=%d, test=%d", len(X_train), len(X_val), len(X_test))
# Train each model with Optuna
best_params: dict[str, dict] = {}
trained_models: dict[str, Any] = {}
xgb_params, xgb_model = self._train_xgboost(X_train, y_train, X_val, y_val, req.optuna_trials)
if xgb_model:
best_params["xgboost"] = xgb_params
trained_models["xgboost"] = xgb_model
lgb_params, lgb_model = self._train_lightgbm(X_train, y_train, X_val, y_val, req.optuna_trials)
if lgb_model:
best_params["lightgbm"] = lgb_params
trained_models["lightgbm"] = lgb_model
cat_params, cat_model = self._train_catboost(X_train, y_train, X_val, y_val, req.optuna_trials)
if cat_model:
best_params["catboost"] = cat_params
trained_models["catboost"] = cat_model
# Evaluate ensemble on test set
metrics = self._evaluate_ensemble(trained_models, X_test, y_test)
# Save versioned artifacts
version_dir = model_dir / "versions" / version
version_dir.mkdir(parents=True, exist_ok=True)
for name, model in trained_models.items():
self._save_model(name, model, version_dir)
# Also save to active model directory
self._save_model(name, model, model_dir)
# Register in model registry
registry_entry = AVMv2ModelInfo(
model_version=version,
created_at=datetime.now(timezone.utc).isoformat(),
metrics=metrics,
is_active=True,
ab_test_traffic_pct=0.0,
)
self._register_model(registry_entry, model_dir)
# Reload models
self._models = trained_models
self._model_version = version
return AVMv2TrainResponse(
model_version=version,
metrics=metrics,
district_metrics={},
training_samples=len(X_train),
validation_samples=len(X_val),
test_samples=len(X_test),
best_params=best_params,
)
def _prepare_training_data(
self, df: "pd.DataFrame"
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Encode a DataFrame into feature matrix, target vector, and group labels."""
import pandas as pd # noqa: F811
feature_cols = [
"distance_to_cbd_km", "distance_to_metro_km", "distance_to_school_km",
"distance_to_hospital_km", "distance_to_park_km", "distance_to_mall_km",
"flood_zone_risk", "neighborhood_score",
"property_type", "area_m2", "rooms", "floor_level", "total_floors",
"direction", "floor_ratio", "building_age_years",
"has_elevator", "has_parking", "has_pool", "has_legal_paper",
"developer_reputation",
"avg_price_district_3m_vnd_m2", "listing_density", "absorption_rate",
"dom_avg", "price_momentum_30d", "yoy_change",
"renovation_score", "view_quality", "interior_quality",
"noise_level", "natural_light",
"month",
]
# Fill missing columns with defaults
for col in feature_cols:
if col not in df.columns:
df[col] = 0.0 if col not in ("property_type", "direction") else "unknown"
# Encode categoricals
df["property_type_encoded"] = df["property_type"].str.lower().map(PROPERTY_TYPE_MAP).fillna(1)
df["direction_encoded"] = df["direction"].str.lower().map(DIRECTION_MAP).fillna(4)
# Cyclical month encoding
month_rad = 2 * np.pi * df["month"].astype(float) / 12.0
df["month_sin"] = np.sin(month_rad)
df["month_cos"] = np.cos(month_rad)
df["is_year_end_encoded"] = (df["month"].astype(int).isin([10, 11, 12])).astype(float)
# Boolean encoding
for col in ["has_elevator", "has_parking", "has_pool", "has_legal_paper"]:
df[col] = df[col].astype(float)
encoded_feature_cols = [
"distance_to_cbd_km", "distance_to_metro_km", "distance_to_school_km",
"distance_to_hospital_km", "distance_to_park_km", "distance_to_mall_km",
"flood_zone_risk", "neighborhood_score",
"property_type_encoded", "area_m2", "rooms", "floor_level", "total_floors",
"direction_encoded", "floor_ratio", "building_age_years",
"has_elevator", "has_parking", "has_pool", "has_legal_paper",
"developer_reputation",
"avg_price_district_3m_vnd_m2", "listing_density", "absorption_rate",
"dom_avg", "price_momentum_30d", "yoy_change",
"renovation_score", "view_quality", "interior_quality",
"noise_level", "natural_light",
"month_sin", "month_cos", "is_year_end_encoded",
]
X = df[encoded_feature_cols].values.astype(np.float64)
y = np.log(df["price_vnd"].values.astype(np.float64)) # Log-price target
groups = df.get("district", pd.Series(["default"] * len(df))).values
return X, y, groups
def _train_xgboost(
self,
X_train: np.ndarray, y_train: np.ndarray,
X_val: np.ndarray, y_val: np.ndarray,
n_trials: int,
) -> tuple[dict, Any]:
"""Train XGBoost with Optuna hyperparameter optimization."""
try:
import optuna
import xgboost as xgb
optuna.logging.set_verbosity(optuna.logging.WARNING)
dtrain = xgb.DMatrix(X_train, label=y_train, feature_names=FEATURE_NAMES)
dval = xgb.DMatrix(X_val, label=y_val, feature_names=FEATURE_NAMES)
def objective(trial: optuna.Trial) -> float:
params = {
"objective": "reg:squarederror",
"eval_metric": "rmse",
"max_depth": trial.suggest_int("max_depth", 3, 10),
"learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3, log=True),
"subsample": trial.suggest_float("subsample", 0.6, 1.0),
"colsample_bytree": trial.suggest_float("colsample_bytree", 0.6, 1.0),
"min_child_weight": trial.suggest_int("min_child_weight", 1, 10),
"reg_alpha": trial.suggest_float("reg_alpha", 1e-8, 10.0, log=True),
"reg_lambda": trial.suggest_float("reg_lambda", 1e-8, 10.0, log=True),
"verbosity": 0,
}
n_rounds = trial.suggest_int("n_rounds", 100, 1000)
model = xgb.train(
params, dtrain, num_boost_round=n_rounds,
evals=[(dval, "val")], verbose_eval=False,
early_stopping_rounds=50,
)
preds = model.predict(dval)
rmse = float(np.sqrt(np.mean((preds - y_val) ** 2)))
return rmse
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=n_trials, show_progress_bar=False)
# Retrain with best params on full train set
best = study.best_params
n_rounds = best.pop("n_rounds", 500)
best.update({"objective": "reg:squarederror", "eval_metric": "rmse", "verbosity": 0})
model = xgb.train(
best, dtrain, num_boost_round=n_rounds,
evals=[(dval, "val")], verbose_eval=False,
early_stopping_rounds=50,
)
logger.info("XGBoost trained — best RMSE: %.4f", study.best_value)
return best, model
except Exception as e:
logger.warning("XGBoost training failed: %s", e)
return {}, None
def _train_lightgbm(
self,
X_train: np.ndarray, y_train: np.ndarray,
X_val: np.ndarray, y_val: np.ndarray,
n_trials: int,
) -> tuple[dict, Any]:
"""Train LightGBM with Optuna hyperparameter optimization."""
try:
import lightgbm as lgb
import optuna
optuna.logging.set_verbosity(optuna.logging.WARNING)
dtrain = lgb.Dataset(X_train, label=y_train, feature_name=FEATURE_NAMES)
dval = lgb.Dataset(X_val, label=y_val, feature_name=FEATURE_NAMES, reference=dtrain)
def objective(trial: optuna.Trial) -> float:
params = {
"objective": "regression",
"metric": "rmse",
"num_leaves": trial.suggest_int("num_leaves", 15, 127),
"learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3, log=True),
"feature_fraction": trial.suggest_float("feature_fraction", 0.6, 1.0),
"bagging_fraction": trial.suggest_float("bagging_fraction", 0.6, 1.0),
"bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
"min_child_samples": trial.suggest_int("min_child_samples", 5, 50),
"reg_alpha": trial.suggest_float("reg_alpha", 1e-8, 10.0, log=True),
"reg_lambda": trial.suggest_float("reg_lambda", 1e-8, 10.0, log=True),
"verbosity": -1,
}
n_rounds = trial.suggest_int("n_rounds", 100, 1000)
callbacks = [lgb.early_stopping(50, verbose=False), lgb.log_evaluation(period=0)]
model = lgb.train(
params, dtrain, num_boost_round=n_rounds,
valid_sets=[dval], callbacks=callbacks,
)
preds = model.predict(X_val)
rmse = float(np.sqrt(np.mean((preds - y_val) ** 2)))
return rmse
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=n_trials, show_progress_bar=False)
best = study.best_params
n_rounds = best.pop("n_rounds", 500)
best.update({"objective": "regression", "metric": "rmse", "verbosity": -1})
callbacks = [lgb.early_stopping(50, verbose=False), lgb.log_evaluation(period=0)]
model = lgb.train(
best, dtrain, num_boost_round=n_rounds,
valid_sets=[dval], callbacks=callbacks,
)
logger.info("LightGBM trained — best RMSE: %.4f", study.best_value)
return best, model
except Exception as e:
logger.warning("LightGBM training failed: %s", e)
return {}, None
def _train_catboost(
self,
X_train: np.ndarray, y_train: np.ndarray,
X_val: np.ndarray, y_val: np.ndarray,
n_trials: int,
) -> tuple[dict, Any]:
"""Train CatBoost with Optuna hyperparameter optimization."""
try:
import optuna
from catboost import CatBoostRegressor, Pool
optuna.logging.set_verbosity(optuna.logging.WARNING)
train_pool = Pool(X_train, label=y_train, feature_names=FEATURE_NAMES)
val_pool = Pool(X_val, label=y_val, feature_names=FEATURE_NAMES)
def objective(trial: optuna.Trial) -> float:
params = {
"iterations": trial.suggest_int("iterations", 100, 1000),
"depth": trial.suggest_int("depth", 4, 10),
"learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3, log=True),
"l2_leaf_reg": trial.suggest_float("l2_leaf_reg", 1e-8, 10.0, log=True),
"bagging_temperature": trial.suggest_float("bagging_temperature", 0.0, 1.0),
"random_strength": trial.suggest_float("random_strength", 1e-8, 10.0, log=True),
"verbose": 0,
"loss_function": "RMSE",
"early_stopping_rounds": 50,
}
model = CatBoostRegressor(**params)
model.fit(train_pool, eval_set=val_pool, verbose=0)
preds = model.predict(val_pool)
rmse = float(np.sqrt(np.mean((preds - y_val) ** 2)))
return rmse
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=n_trials, show_progress_bar=False)
best = study.best_params
best.update({"verbose": 0, "loss_function": "RMSE", "early_stopping_rounds": 50})
model = CatBoostRegressor(**best)
model.fit(train_pool, eval_set=val_pool, verbose=0)
logger.info("CatBoost trained — best RMSE: %.4f", study.best_value)
return best, model
except Exception as e:
logger.warning("CatBoost training failed: %s", e)
return {}, None
def _evaluate_ensemble(
self, models: dict[str, Any], X_test: np.ndarray, y_test: np.ndarray
) -> dict:
"""Evaluate ensemble performance on a test set."""
if not models:
return {"mae": 0.0, "mape": 0.0, "rmse": 0.0, "r2": 0.0}
predictions = []
weights = []
for name, model in models.items():
w = ENSEMBLE_WEIGHTS.get(name, 0.0)
features = X_test
if name == "xgboost":
import xgboost as xgb
preds = model.predict(xgb.DMatrix(features, feature_names=FEATURE_NAMES))
elif name == "lightgbm":
preds = model.predict(features)
elif name == "catboost":
preds = model.predict(features)
else:
continue
predictions.append(preds * w)
weights.append(w)
total_weight = sum(weights) or 1.0
ensemble_preds = sum(predictions) / total_weight
# Metrics in log-space then convert
y_actual = np.exp(y_test)
y_pred = np.exp(ensemble_preds)
mae = float(np.mean(np.abs(y_actual - y_pred)))
mape = float(np.mean(np.abs((y_actual - y_pred) / y_actual))) * 100
rmse = float(np.sqrt(np.mean((y_actual - y_pred) ** 2)))
ss_res = np.sum((y_actual - y_pred) ** 2)
ss_tot = np.sum((y_actual - np.mean(y_actual)) ** 2)
r2 = float(1.0 - ss_res / ss_tot) if ss_tot > 0 else 0.0
return {
"mae": round(mae, 2),
"mape": round(mape, 2),
"rmse": round(rmse, 2),
"r2": round(r2, 4),
}
def _save_model(self, name: str, model: Any, directory: Path) -> None:
"""Save a trained model to the specified directory."""
if name == "xgboost":
model.save_model(str(directory / "avm_v2_xgboost.json"))
elif name == "lightgbm":
model.save_model(str(directory / "avm_v2_lightgbm.txt"))
elif name == "catboost":
model.save_model(str(directory / "avm_v2_catboost.cbm"))
# ── Model registry ──────────────────────────────────────────
def _get_registry_path(self, model_dir: Path | None = None) -> Path:
"""Get the path to the model registry JSON file."""
if model_dir is None:
from app.config import settings
model_dir = Path(settings.model_path)
return model_dir / "model_registry.json"
def _load_registry(self, model_dir: Path | None = None) -> list[dict]:
"""Load the model registry from disk."""
path = self._get_registry_path(model_dir)
if path.exists():
with open(path) as f:
return json.load(f)
return []
def _save_registry(self, entries: list[dict], model_dir: Path | None = None) -> None:
"""Save the model registry to disk."""
path = self._get_registry_path(model_dir)
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w") as f:
json.dump(entries, f, indent=2)
def _register_model(self, info: AVMv2ModelInfo, model_dir: Path) -> None:
"""Register a new model version and mark it as active."""
entries = self._load_registry(model_dir)
# Deactivate previous active models
for entry in entries:
entry["is_active"] = False
entries.append({
"model_version": info.model_version,
"created_at": info.created_at,
"metrics": info.metrics,
"is_active": True,
"ab_test_traffic_pct": info.ab_test_traffic_pct,
})
self._save_registry(entries, model_dir)
self._model_registry = [
AVMv2ModelInfo(**e) for e in entries
]
def get_model_info(self) -> AVMv2ModelInfo:
"""Return current active model information."""
# Check registry for active model
entries = self._load_registry()
for entry in reversed(entries):
if entry.get("is_active"):
return AVMv2ModelInfo(**entry)
return AVMv2ModelInfo(
model_version=self._model_version,
created_at=datetime.now(timezone.utc).isoformat(),
@@ -530,6 +1004,142 @@ class AVMv2EnsembleService:
ab_test_traffic_pct=0.0,
)
def list_versions(self) -> list[AVMv2ModelInfo]:
"""List all registered model versions."""
entries = self._load_registry()
return [AVMv2ModelInfo(**e) for e in entries]
def rollback(self, target_version: str) -> AVMv2ModelInfo:
"""Rollback to a previously trained model version.
Copies the target version's artifacts to the active model directory
and updates the registry.
"""
from app.config import settings
model_dir = Path(settings.model_path)
version_dir = model_dir / "versions" / target_version
if not version_dir.exists():
raise ValueError(f"Model version {target_version} not found")
# Copy versioned artifacts to active directory
for artifact in version_dir.iterdir():
if artifact.is_file():
shutil.copy2(artifact, model_dir / artifact.name)
# Update registry
entries = self._load_registry(model_dir)
found = False
for entry in entries:
entry["is_active"] = entry["model_version"] == target_version
if entry["model_version"] == target_version:
found = True
if not found:
raise ValueError(f"Model version {target_version} not in registry")
self._save_registry(entries, model_dir)
# Reload models from disk
self._models = {}
self._load_models()
self._model_version = target_version
logger.info("Rolled back to model version %s", target_version)
active = next(e for e in entries if e["is_active"])
return AVMv2ModelInfo(**active)
# ── A/B comparison ─────────────────────────────────────────
def compare_v1(self, req: ABComparisonRequest) -> ABComparisonResponse:
"""Compare v1 and v2 predictions on the same property."""
from app.services.avm_service import avm_service
# Build v1 request
v1_req = AVMPredictRequest(
area=req.area_m2,
district=req.district,
city=req.city,
property_type=req.property_type,
bedrooms=req.bedrooms or req.rooms,
floors=req.floors,
frontage=req.frontage,
has_legal_paper=req.has_legal_paper,
)
v1_result = avm_service.predict(v1_req)
# Build v2 request
v2_req = AVMv2PredictRequest(
district=req.district,
city=req.city,
property_type=req.property_type,
area_m2=req.area_m2,
rooms=req.rooms or req.bedrooms,
has_legal_paper=req.has_legal_paper,
neighborhood_score=req.neighborhood_score,
distance_to_cbd_km=req.distance_to_cbd_km,
distance_to_metro_km=req.distance_to_metro_km,
flood_zone_risk=req.flood_zone_risk,
building_age_years=req.building_age_years,
floor_level=req.floor_level,
total_floors=req.total_floors,
direction=req.direction,
has_elevator=req.has_elevator,
has_parking=req.has_parking,
has_pool=req.has_pool,
developer_reputation=req.developer_reputation,
renovation_score=req.renovation_score,
view_quality=req.view_quality,
interior_quality=req.interior_quality,
month=req.month,
quarter=req.quarter,
is_year_end=req.is_year_end,
)
v2_result = self.predict(v2_req)
# Compute diffs
price_diff = v2_result.estimated_price_vnd - v1_result.estimated_price_vnd
price_diff_pct = (
(price_diff / v1_result.estimated_price_vnd * 100)
if v1_result.estimated_price_vnd > 0
else 0.0
)
confidence_diff = v2_result.confidence - v1_result.confidence
# Recommendation logic
if v2_result.confidence > v1_result.confidence + 0.05:
recommendation = "v2 — higher confidence from ensemble model agreement"
elif v1_result.confidence > v2_result.confidence + 0.05:
recommendation = "v1 — higher confidence, v2 models may disagree on this property"
elif abs(price_diff_pct) < 5:
recommendation = "Both models agree (< 5% price difference)"
else:
recommendation = "v2 — richer feature set captures more market factors"
return ABComparisonResponse(
v1=AVMv1Summary(
estimated_price_vnd=v1_result.estimated_price_vnd,
confidence=v1_result.confidence,
price_per_m2=v1_result.price_per_m2,
price_range_low=v1_result.price_range_low,
price_range_high=v1_result.price_range_high,
),
v2=AVMv2Summary(
estimated_price_vnd=v2_result.estimated_price_vnd,
confidence=v2_result.confidence,
price_per_m2_vnd=v2_result.price_per_m2_vnd,
price_range_low_vnd=v2_result.price_range_low_vnd,
price_range_high_vnd=v2_result.price_range_high_vnd,
model_version=v2_result.model_version,
ensemble_method=v2_result.ensemble_method,
),
price_diff_vnd=round(price_diff, -3),
price_diff_pct=round(price_diff_pct, 2),
confidence_diff=round(confidence_diff, 4),
recommendation=recommendation,
)
# Module-level singleton
avm_v2_service = AVMv2EnsembleService()

View File

@@ -65,9 +65,10 @@ def test_predict_v2_returns_drivers():
def test_predict_v2_with_full_features():
"""Predict with all features populated."""
"""Predict with all features populated (including new v2 features)."""
payload = {
**_PREDICT_PAYLOAD,
"neighborhood_score": 0.8,
"distance_to_cbd_km": 5.0,
"distance_to_metro_km": 0.8,
"distance_to_school_km": 0.5,
@@ -75,11 +76,15 @@ def test_predict_v2_with_full_features():
"distance_to_park_km": 0.3,
"distance_to_mall_km": 1.0,
"flood_zone_risk": 0.1,
"floor_level": 12,
"total_floors": 25,
"direction": "southeast",
"floor_ratio": 1.2,
"building_age_years": 5,
"has_elevator": True,
"has_parking": True,
"has_pool": False,
"developer_reputation": 0.9,
"avg_price_district_3m_vnd_m2": 85_000_000,
"listing_density": 12.5,
"absorption_rate": 0.3,
@@ -149,8 +154,93 @@ def test_predict_v2_invalid_area():
assert resp.status_code == 422
def test_train_v2_scaffold():
"""Training endpoint should return scaffold response."""
# ── New v2 features: neighborhood, floor, direction, developer ──
def test_predict_v2_neighborhood_premium():
"""High neighborhood score should increase price."""
low_nb = client.post(
"/avm/v2/predict",
json={**_PREDICT_PAYLOAD, "neighborhood_score": 0.2},
).json()
high_nb = client.post(
"/avm/v2/predict",
json={**_PREDICT_PAYLOAD, "neighborhood_score": 0.9},
).json()
assert high_nb["estimated_price_vnd"] > low_nb["estimated_price_vnd"]
def test_predict_v2_floor_level_premium():
"""Higher floor apartments should command a premium."""
ground = client.post(
"/avm/v2/predict",
json={**_PREDICT_PAYLOAD, "floor_level": 2, "total_floors": 25},
).json()
high = client.post(
"/avm/v2/predict",
json={**_PREDICT_PAYLOAD, "floor_level": 20, "total_floors": 25},
).json()
assert high["estimated_price_vnd"] > ground["estimated_price_vnd"]
def test_predict_v2_direction_premium():
"""South-facing properties should be priced higher than north-facing."""
south = client.post(
"/avm/v2/predict",
json={**_PREDICT_PAYLOAD, "direction": "south"},
).json()
north = client.post(
"/avm/v2/predict",
json={**_PREDICT_PAYLOAD, "direction": "north"},
).json()
assert south["estimated_price_vnd"] > north["estimated_price_vnd"]
def test_predict_v2_developer_reputation():
"""Properties from reputable developers should be valued higher."""
low_rep = client.post(
"/avm/v2/predict",
json={**_PREDICT_PAYLOAD, "developer_reputation": 0.2},
).json()
high_rep = client.post(
"/avm/v2/predict",
json={**_PREDICT_PAYLOAD, "developer_reputation": 0.9},
).json()
assert high_rep["estimated_price_vnd"] > low_rep["estimated_price_vnd"]
def test_predict_v2_direction_defaults_unknown():
"""Unknown direction should not affect price (neutral)."""
explicit = client.post(
"/avm/v2/predict",
json={**_PREDICT_PAYLOAD, "direction": "unknown"},
).json()
default = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD).json()
assert explicit["estimated_price_vnd"] == default["estimated_price_vnd"]
def test_predict_v2_drivers_include_new_features():
"""Drivers should include neighborhood_score, direction, floor_level."""
resp = client.post("/avm/v2/predict", json=_PREDICT_PAYLOAD)
data = resp.json()
driver_names = {d["feature"] for d in data["drivers"]}
assert "neighborhood_score" in driver_names
assert "direction_encoded" in driver_names
assert "floor_level" in driver_names
assert "developer_reputation" in driver_names
# ── Training & model info ───────────────────────────────────────
def test_train_v2_no_data():
"""Training without data returns scaffold with zero metrics."""
resp = client.post(
"/avm/v2/train",
json={"optuna_trials": 10},
@@ -159,10 +249,7 @@ def test_train_v2_scaffold():
data = resp.json()
assert "model_version" in data
assert "ensemble-v2-" in data["model_version"]
assert data["metrics"]["mae"] == 0.0 # scaffold returns zeros
assert "xgboost" in data["best_params"]
assert "lightgbm" in data["best_params"]
assert "catboost" in data["best_params"]
assert data["training_samples"] == 0
def test_model_info_v2():
@@ -172,3 +259,100 @@ def test_model_info_v2():
data = resp.json()
assert "model_version" in data
assert data["is_active"] is True
# ── Model versioning ────────────────────────────────────────────
def test_list_versions():
"""Versions endpoint returns a list."""
resp = client.get("/avm/v2/versions")
assert resp.status_code == 200
data = resp.json()
assert isinstance(data, list)
def test_rollback_not_found():
"""Rollback to non-existent version returns 404."""
resp = client.post(
"/avm/v2/rollback",
json={"target_version": "nonexistent-version-xyz"},
)
assert resp.status_code == 404
# ── A/B comparison tests ─────────────────────────────────────
_COMPARE_PAYLOAD = {
"district": "Cầu Giấy",
"city": "Hà Nội",
"property_type": "apartment",
"area_m2": 80.0,
"rooms": 2,
"month": 3,
"quarter": 1,
}
def test_compare_v1_returns_both_models():
"""Compare endpoint returns v1 and v2 predictions."""
resp = client.post("/avm/v2/compare-v1", json=_COMPARE_PAYLOAD)
assert resp.status_code == 200
data = resp.json()
assert "v1" in data
assert "v2" in data
assert data["v1"]["estimated_price_vnd"] > 0
assert data["v2"]["estimated_price_vnd"] > 0
assert 0 <= data["v1"]["confidence"] <= 1
assert 0 <= data["v2"]["confidence"] <= 1
def test_compare_v1_returns_diffs():
"""Compare endpoint computes price and confidence differences."""
resp = client.post("/avm/v2/compare-v1", json=_COMPARE_PAYLOAD)
data = resp.json()
expected_diff = data["v2"]["estimated_price_vnd"] - data["v1"]["estimated_price_vnd"]
assert abs(data["price_diff_vnd"] - expected_diff) < 10_000 # rounding tolerance
assert "price_diff_pct" in data
assert isinstance(data["price_diff_pct"], float)
assert "confidence_diff" in data
def test_compare_v1_returns_recommendation():
"""Compare endpoint provides a recommendation string."""
resp = client.post("/avm/v2/compare-v1", json=_COMPARE_PAYLOAD)
data = resp.json()
assert "recommendation" in data
assert len(data["recommendation"]) > 0
def test_compare_v1_with_v2_features():
"""Compare endpoint passes v2-specific features correctly."""
payload = {
**_COMPARE_PAYLOAD,
"neighborhood_score": 0.8,
"distance_to_cbd_km": 5.0,
"distance_to_metro_km": 0.8,
"flood_zone_risk": 0.1,
"building_age_years": 3,
"floor_level": 15,
"total_floors": 30,
"direction": "southeast",
"has_elevator": True,
"has_parking": True,
"developer_reputation": 0.85,
"renovation_score": 0.9,
"view_quality": 0.8,
"interior_quality": 0.85,
}
resp = client.post("/avm/v2/compare-v1", json=payload)
assert resp.status_code == 200
data = resp.json()
# v2 should capture these extra features
assert data["v2"]["estimated_price_vnd"] > 0
assert data["v2"]["model_version"] is not None

View File

@@ -0,0 +1,2 @@
-- AlterTable: Add source column to PriceHistory
ALTER TABLE "PriceHistory" ADD COLUMN "source" TEXT NOT NULL DEFAULT 'manual_update';

View File

@@ -366,6 +366,7 @@ model PriceHistory {
listing Listing @relation(fields: [listingId], references: [id], onDelete: Cascade)
oldPrice BigInt
newPrice BigInt
source String @default("manual_update")
changedAt DateTime @default(now())
@@index([listingId, changedAt(sort: Desc)])