import { Inject } from '@nestjs/common';
import { OnEvent } from '@nestjs/event-emitter';
import {
  WebSocketGateway,
  WebSocketServer,
  type OnGatewayConnection,
  type OnGatewayDisconnect,
  type OnGatewayInit,
} from '@nestjs/websockets';
import type { Server, Socket } from 'socket.io';
// eslint-disable-next-line @typescript-eslint/consistent-type-imports -- NestJS DI requires value imports
import { TokenService, type JwtPayload } from '@modules/auth';
// eslint-disable-next-line @typescript-eslint/consistent-type-imports -- NestJS DI requires value imports
import { LoggerService, RedisService } from '@modules/shared';
import type { NotificationSentEvent } from '../../domain/events/notification-sent.event';
import {
  NOTIFICATION_REPOSITORY,
  type INotificationRepository,
} from '../../domain/repositories/notification.repository';

/** Redis key for the per-user unread notification counter. */
const UNREAD_COUNT_KEY = (userId: string) => `notifications:unread:${userId}`;

/** TTL for the cached unread count (1 hour). */
const UNREAD_COUNT_TTL = 3600;

/** Logger context attached to every log line emitted by this gateway. */
const LOG_CONTEXT = 'NotificationsGateway';

/** Matches an HTTP `Bearer` auth-scheme prefix; the scheme name is case-insensitive. */
const BEARER_PREFIX = /^Bearer\s+/i;

/**
 * Allowed CORS origins, read once at module load from the comma-separated
 * `CORS_ORIGINS` env var. Blank entries (e.g. from a trailing comma) are dropped.
 */
const CORS_ORIGINS = (process.env['CORS_ORIGINS'] ?? 'http://localhost:3000')
  .split(',')
  .map((o) => o.trim())
  .filter((o) => o.length > 0);

/** Split an unknown thrown value into a log message and an optional stack trace. */
function describeError(error: unknown): { message: string; stack: string | undefined } {
  return error instanceof Error
    ? { message: error.message, stack: error.stack }
    : { message: String(error), stack: undefined };
}

/**
 * Socket.IO gateway on the `/notifications` namespace.
 *
 * Authenticates each connection with a JWT, joins it to a per-user room
 * (`user:<userId>`), and pushes `notification:new` and
 * `notification:unread-count` events to that room.
 */
@WebSocketGateway({
  namespace: '/notifications',
  cors: {
    origin: CORS_ORIGINS,
    credentials: true,
  },
})
export class NotificationsGateway
  implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect
{
  @WebSocketServer()
  server!: Server;

  /** Connected socket ids per user id, for multi-device support. */
  private readonly userSockets = new Map<string, Set<string>>();

  constructor(
    private readonly tokenService: TokenService,
    private readonly logger: LoggerService,
    private readonly redisService: RedisService,
    @Inject(NOTIFICATION_REPOSITORY)
    private readonly notificationRepo: INotificationRepository,
  ) {}

  afterInit(): void {
    this.logger.log('NotificationsGateway initialized', LOG_CONTEXT);
  }

  /* ────────────────────────────────────────────
   * Connection lifecycle
   * ──────────────────────────────────────────── */

  /**
   * Authenticate a new socket, join it to its user's room and push the
   * current unread count. Unauthenticated or failing sockets are disconnected.
   */
  async handleConnection(client: Socket): Promise<void> {
    try {
      const payload = this.extractAndVerifyToken(client);
      if (!payload) {
        client.disconnect(true);
        return;
      }
      const userId = payload.sub;

      // Attach identity to the socket for later use
      client.data['userId'] = userId;
      client.data['role'] = payload.role;

      // Track synchronously, before the first await. If the client drops while
      // we are awaiting below, handleDisconnect then finds and removes this
      // entry, instead of us adding it after the cleanup already ran (leak).
      this.trackSocket(userId, client.id);

      // Join the user's private room
      await client.join(`user:${userId}`);

      // Push the current unread count on connect
      const unreadCount = await this.getUnreadCount(userId);
      client.emit('notification:unread-count', { unreadCount });

      this.logger.debug(
        `WS connected: user=${userId} socket=${client.id}`,
        LOG_CONTEXT,
      );
    } catch (error) {
      const { message, stack } = describeError(error);
      this.logger.error(`WS connection error: ${message}`, stack, LOG_CONTEXT);
      client.disconnect(true);
    }
  }

  /** Drop the socket from per-user bookkeeping and log the disconnect. */
  handleDisconnect(client: Socket): void {
    const rawUserId: unknown = client.data['userId'];
    const userId = typeof rawUserId === 'string' ? rawUserId : undefined;

    if (userId) {
      this.untrackSocket(userId, client.id);
    }

    this.logger.debug(
      `WS disconnected: user=${userId ?? 'unknown'} socket=${client.id}`,
      LOG_CONTEXT,
    );
  }

  /* ────────────────────────────────────────────
   * Domain event handlers
   * ──────────────────────────────────────────── */

  /**
   * Listens to the `notification.sent` domain event emitted by
   * {@link SendNotificationHandler} after a notification is persisted & sent.
   *
   * Pushes `notification:new` to the user's room, then pushes a fresh unread
   * count. Failures are logged, never rethrown.
   */
  @OnEvent('notification.sent', { async: true })
  async handleNotificationSent(event: NotificationSentEvent): Promise<void> {
    const room = `user:${event.userId}`;
    try {
      this.server.to(room).emit('notification:new', {
        id: event.aggregateId,
        templateKey: event.templateKey,
        channel: event.channel,
        occurredAt: event.occurredAt.toISOString(),
      });

      // Invalidate rather than INCR the cached counter. EXISTS+INCR is racy:
      // the key may expire in between (INCR then recreates it as "1" with no
      // TTL), and a concurrent DB refill that already includes this
      // notification would be incremented once more. Recounting from the DB
      // costs one query per notification but is always consistent.
      await this.dropCachedUnreadCount(event.userId);

      // Also emit updated count
      const unreadCount = await this.getUnreadCount(event.userId);
      this.server.to(room).emit('notification:unread-count', { unreadCount });
    } catch (error) {
      const { message, stack } = describeError(error);
      this.logger.error(
        `Failed to emit WS notification for user ${event.userId}: ${message}`,
        stack,
        LOG_CONTEXT,
      );
    }
  }

  /* ────────────────────────────────────────────
   * Public helpers — used by the controller
   * ──────────────────────────────────────────── */

  /**
   * Emit an updated unread count to a user after they mark
   * notifications as read (called from the controller).
   */
  async emitUnreadCount(userId: string): Promise<void> {
    const unreadCount = await this.getUnreadCount(userId);
    this.server
      .to(`user:${userId}`)
      .emit('notification:unread-count', { unreadCount });
  }

  /**
   * Invalidate the cached unread count (called after mark-as-read).
   * Redis errors propagate to the caller.
   */
  async invalidateUnreadCount(userId: string): Promise<void> {
    if (this.redisService.isAvailable()) {
      await this.redisService.del(UNREAD_COUNT_KEY(userId));
    }
  }

  /* ────────────────────────────────────────────
   * Private helpers
   * ──────────────────────────────────────────── */

  /** Record that `socketId` belongs to `userId`. */
  private trackSocket(userId: string, socketId: string): void {
    let sockets = this.userSockets.get(userId);
    if (!sockets) {
      sockets = new Set<string>();
      this.userSockets.set(userId, sockets);
    }
    sockets.add(socketId);
  }

  /** Forget `socketId` for `userId`, dropping the user entry once empty. */
  private untrackSocket(userId: string, socketId: string): void {
    const sockets = this.userSockets.get(userId);
    if (!sockets) return;
    sockets.delete(socketId);
    if (sockets.size === 0) {
      this.userSockets.delete(userId);
    }
  }

  /**
   * Extract JWT from the socket handshake and verify it.
   *
   * Supports three sources (in priority order):
   * 1. `handshake.auth.token` — Socket.IO `auth` option (recommended)
   * 2. `handshake.headers.authorization` — HTTP upgrade header
   * 3. `handshake.query.token` — query string (least secure)
   *
   * @returns the verified payload, or `null` when no token is present or
   *          `tokenService.verifyAccessToken` returns a falsy value.
   */
  private extractAndVerifyToken(client: Socket): JwtPayload | null {
    const { auth, headers, query } = client.handshake;
    const raw: unknown =
      auth?.['token'] ?? headers?.['authorization'] ?? query?.['token'];

    if (!raw || typeof raw !== 'string') {
      this.logger.warn(
        `WS auth failed: no token provided (socket=${client.id})`,
        LOG_CONTEXT,
      );
      return null;
    }

    // Auth-scheme names are case-insensitive (RFC 7235 §2.1): accept "bearer" too.
    const token = raw.replace(BEARER_PREFIX, '');
    const payload = this.tokenService.verifyAccessToken(token);

    if (!payload) {
      this.logger.warn(
        `WS auth failed: invalid token (socket=${client.id})`,
        LOG_CONTEXT,
      );
    }

    return payload;
  }

  /**
   * Read the unread count from Redis (cache-aside pattern).
   * Falls back to the database when Redis is unavailable, on a cache miss, or
   * when the cached value is not a finite number (the entry is then rewritten).
   */
  private async getUnreadCount(userId: string): Promise<number> {
    const key = UNREAD_COUNT_KEY(userId);

    if (this.redisService.isAvailable()) {
      try {
        const cached = await this.redisService.get(key);
        if (cached != null) {
          const parsed = Number(cached);
          // Guard against corrupted entries: Number('abc') is NaN.
          if (Number.isFinite(parsed)) {
            return parsed;
          }
        }
      } catch {
        // Redis unavailable — fall through to DB
      }
    }

    const count = await this.notificationRepo.countUnreadByUserId(userId);

    // Warm the cache
    if (this.redisService.isAvailable()) {
      try {
        await this.redisService.set(key, String(count), UNREAD_COUNT_TTL);
      } catch {
        // Non-critical — continue without cache
      }
    }

    return count;
  }

  /**
   * Best-effort cache invalidation: errors are swallowed because a stale
   * entry still expires after UNREAD_COUNT_TTL seconds.
   */
  private async dropCachedUnreadCount(userId: string): Promise<void> {
    try {
      await this.invalidateUnreadCount(userId);
    } catch {
      // Non-critical
    }
  }
}