feat(notifications): production-ready WebSocket gateway (TEC-2766)

- Add RedisIoAdapter (shared/infra) for multi-instance Socket.IO fan-out
  with graceful fallback to the in-memory IoAdapter when Redis is
  unreachable.
- Pin Socket.IO heartbeat (pingInterval/pingTimeout/connectTimeout)
  via env-tunable gateway options for reconnect stability.
- Expose Prometheus metrics for the /notifications namespace:
  goodgo_ws_connected_clients (Gauge) and goodgo_ws_messages_total
  (Counter) with namespace/event/direction labels. Wired through
  MetricsService and tracked across connect/disconnect + emits.
- Unit tests: RedisIoAdapter connect/fallback/close, new MetricsService
  WS helpers, and gateway metric increments/decrements on auth paths.

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
Ho Ngoc Hai
2026-04-18 15:06:25 +07:00
parent 5d4ecdeb2f
commit 329a821b4a
13 changed files with 410 additions and 5 deletions

View File

@@ -1,6 +1,7 @@
import { Module } from '@nestjs/common';
import { CqrsModule } from '@nestjs/cqrs';
import { AuthModule } from '@modules/auth';
import { MetricsModule } from '@modules/metrics';
import { SendNotificationHandler } from './application/commands/send-notification/send-notification.handler';
import { AgentVerifiedListener } from './application/listeners/agent-verified.listener';
import { EmailChangeRequestedListener } from './application/listeners/email-change-requested.listener';
@@ -53,7 +54,7 @@ const EventListeners = [
];
@Module({
imports: [CqrsModule, AuthModule],
imports: [CqrsModule, AuthModule, MetricsModule],
controllers: [NotificationsController, ZaloOaWebhookController],
providers: [
// Repositories

View File

@@ -36,6 +36,11 @@ describe('NotificationsGateway', () => {
getClient: ReturnType<typeof vi.fn>;
};
let mockNotificationRepo: { countUnreadByUserId: ReturnType<typeof vi.fn> };
let mockMetrics: {
recordWsConnection: ReturnType<typeof vi.fn>;
setWsConnectedClients: ReturnType<typeof vi.fn>;
recordWsMessage: ReturnType<typeof vi.fn>;
};
let mockServer: {
to: ReturnType<typeof vi.fn>;
};
@@ -53,11 +58,17 @@ describe('NotificationsGateway', () => {
getClient: vi.fn().mockReturnValue({ exists: vi.fn().mockResolvedValue(0), incr: vi.fn() }),
};
mockNotificationRepo = { countUnreadByUserId: vi.fn().mockResolvedValue(3) };
mockMetrics = {
recordWsConnection: vi.fn(),
setWsConnectedClients: vi.fn(),
recordWsMessage: vi.fn(),
};
gateway = new NotificationsGateway(
mockTokenService as any,
mockLogger as any,
mockRedisService as any,
mockMetrics as any,
mockNotificationRepo as any,
);
@@ -74,6 +85,14 @@ describe('NotificationsGateway', () => {
'NotificationsGateway',
);
});
it('resets the WS connected-clients gauge to 0', () => {
  // Running the init hook should zero the gauge for the '/notifications' label.
  gateway.afterInit();

  const { setWsConnectedClients } = mockMetrics;
  expect(setWsConnectedClients).toHaveBeenCalledWith('/notifications', 0);
});
});
describe('handleConnection', () => {
@@ -152,6 +171,28 @@ describe('NotificationsGateway', () => {
expect(mockNotificationRepo.countUnreadByUserId).toHaveBeenCalledWith('user-1');
expect(socket.emit).toHaveBeenCalledWith('notification:unread-count', { unreadCount: 3 });
});
it('increments WS connection metric and records the initial unread-count emit', async () => {
  const client = createMockSocket();

  await gateway.handleConnection(client);

  // One connection is counted, and the unread-count push is recorded as an outbound message.
  const { recordWsConnection, recordWsMessage } = mockMetrics;
  expect(recordWsConnection).toHaveBeenCalledWith('/notifications', 1);
  expect(recordWsMessage).toHaveBeenCalledWith('/notifications', 'notification:unread-count', 'out');
});
it('does not increment metrics when auth fails', async () => {
  // Token verification yields null for this test, i.e. the auth-failure path.
  mockTokenService.verifyAccessToken.mockReturnValue(null);
  const client = createMockSocket();

  await gateway.handleConnection(client);

  const { recordWsConnection } = mockMetrics;
  expect(recordWsConnection).not.toHaveBeenCalled();
});
});
describe('handleDisconnect', () => {
@@ -183,6 +224,24 @@ describe('NotificationsGateway', () => {
// No prior connection — should not throw
expect(() => gateway.handleDisconnect(socket)).not.toThrow();
});
it('decrements the WS connection metric when a tracked socket disconnects', async () => {
  const client = createMockSocket({ id: 'sock-1' });
  const { recordWsConnection } = mockMetrics;

  // Connect first, then forget the call made by the connect step so that
  // only the disconnect is inspected below.
  await gateway.handleConnection(client);
  recordWsConnection.mockClear();

  gateway.handleDisconnect(client);

  expect(recordWsConnection).toHaveBeenCalledWith('/notifications', -1);
});
it('does not decrement the gauge for untracked sockets', () => {
  // The socket is disconnected without a preceding handleConnection call.
  const client = createMockSocket();

  gateway.handleDisconnect(client);

  const { recordWsConnection } = mockMetrics;
  expect(recordWsConnection).not.toHaveBeenCalled();
});
});
describe('handleNotificationSent', () => {

View File

@@ -11,6 +11,8 @@ import type { Server, Socket } from 'socket.io';
// eslint-disable-next-line @typescript-eslint/consistent-type-imports -- NestJS DI requires value imports
import { TokenService, type JwtPayload } from '@modules/auth';
// eslint-disable-next-line @typescript-eslint/consistent-type-imports -- NestJS DI requires value imports
import { MetricsService } from '@modules/metrics';
// eslint-disable-next-line @typescript-eslint/consistent-type-imports -- NestJS DI requires value imports
import { LoggerService, RedisService } from '@modules/shared';
import type { NotificationSentEvent } from '../../domain/events/notification-sent.event';
import {
@@ -24,6 +26,20 @@ const UNREAD_COUNT_KEY = (userId: string) => `notifications:unread:${userId}`;
/** TTL for the cached unread count (1 hour). */
const UNREAD_COUNT_TTL = 3600;
/** Namespace label used for Prometheus metrics. */
const NAMESPACE_LABEL = '/notifications';

/**
 * Reads a millisecond duration from the environment variable `name`.
 *
 * Falls back to `fallbackMs` when the variable is unset, blank, not a finite
 * number, or not strictly positive. A bare `Number(...)` conversion would turn
 * `''` into 0 and `'abc'` into NaN, and either value would then be handed to
 * socket.io as a timer setting.
 *
 * @param name - Environment variable to read.
 * @param fallbackMs - Value used when the variable holds no usable duration.
 * @returns The configured duration in milliseconds, or `fallbackMs`.
 */
function readEnvDurationMs(name: string, fallbackMs: number): number {
  const raw = process.env[name];
  if (raw === undefined || raw.trim() === '') {
    return fallbackMs;
  }
  const parsed = Number(raw);
  return Number.isFinite(parsed) && parsed > 0 ? parsed : fallbackMs;
}

/**
 * Server → client heartbeat every 25 s and 20 s wait for the pong
 * before declaring the connection dead. Matches socket.io defaults but
 * pinned explicitly so operations teams can tune via env without code
 * changes. Clients must reconnect with exponential backoff on their side.
 */
const WS_PING_INTERVAL_MS = readEnvDurationMs('WS_PING_INTERVAL_MS', 25_000);
const WS_PING_TIMEOUT_MS = readEnvDurationMs('WS_PING_TIMEOUT_MS', 20_000);
/**
 * How long a freshly connected client may take to join a namespace before
 * the server disconnects it (socket.io `connectTimeout`). Kept generous so
 * slow networks are not cut off mid-handshake.
 */
const WS_CONNECT_TIMEOUT_MS = readEnvDurationMs('WS_CONNECT_TIMEOUT_MS', 45_000);
@WebSocketGateway({
namespace: '/notifications',
cors: {
@@ -32,6 +48,10 @@ const UNREAD_COUNT_TTL = 3600;
.map((o) => o.trim()),
credentials: true,
},
pingInterval: WS_PING_INTERVAL_MS,
pingTimeout: WS_PING_TIMEOUT_MS,
connectTimeout: WS_CONNECT_TIMEOUT_MS,
transports: ['websocket', 'polling'],
})
export class NotificationsGateway
implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect
@@ -46,12 +66,17 @@ export class NotificationsGateway
private readonly tokenService: TokenService,
private readonly logger: LoggerService,
private readonly redisService: RedisService,
private readonly metrics: MetricsService,
@Inject(NOTIFICATION_REPOSITORY)
private readonly notificationRepo: INotificationRepository,
) {}
/**
 * `OnGatewayInit` hook: logs that the gateway is up and sets the
 * connected-clients gauge for the notifications namespace to 0.
 */
afterInit(): void {
// NOTE(review): this bare message and the detailed one below both appear in
// this view; the bare one looks like the pre-change line of the diff — confirm
// only one of them exists in the actual file.
this.logger.log('NotificationsGateway initialized', 'NotificationsGateway');
// Reset the gauge for this namespace (presumably so a fresh process starts
// counting from zero — confirm against MetricsService).
this.metrics.setWsConnectedClients(NAMESPACE_LABEL, 0);
// Include the heartbeat settings in the start-up log line.
this.logger.log(
`NotificationsGateway initialized (pingInterval=${WS_PING_INTERVAL_MS}ms, pingTimeout=${WS_PING_TIMEOUT_MS}ms)`,
'NotificationsGateway',
);
}
/* ────────────────────────────────────────────
@@ -83,6 +108,13 @@ export class NotificationsGateway
const unreadCount = await this.getUnreadCount(payload.sub);
client.emit('notification:unread-count', { unreadCount });
this.metrics.recordWsConnection(NAMESPACE_LABEL, 1);
this.metrics.recordWsMessage(
NAMESPACE_LABEL,
'notification:unread-count',
'out',
);
this.logger.debug(
`WS connected: user=${payload.sub} socket=${client.id}`,
'NotificationsGateway',
@@ -107,6 +139,8 @@ export class NotificationsGateway
this.userSockets.delete(userId);
}
}
// Only decrement if the socket completed auth (we tracked it).
this.metrics.recordWsConnection(NAMESPACE_LABEL, -1);
}
this.logger.debug(
`WS disconnected: user=${userId ?? 'unknown'} socket=${client.id}`,