diff --git a/apps/api/src/modules/auth/presentation/controllers/auth.controller.ts b/apps/api/src/modules/auth/presentation/controllers/auth.controller.ts index faaeb5f..1a44e30 100644 --- a/apps/api/src/modules/auth/presentation/controllers/auth.controller.ts +++ b/apps/api/src/modules/auth/presentation/controllers/auth.controller.ts @@ -12,7 +12,7 @@ import { type CommandBus, type QueryBus } from '@nestjs/cqrs'; import { ApiTags, ApiOperation, ApiResponse, ApiBearerAuth, ApiBody } from '@nestjs/swagger'; import { Throttle } from '@nestjs/throttler'; import { type Request, type Response } from 'express'; -import { UnauthorizedException } from '@modules/shared'; +import { EndpointRateLimit, EndpointRateLimitGuard, UnauthorizedException } from '@modules/shared'; import { LoginUserCommand } from '../../application/commands/login-user/login-user.command'; import { RefreshTokenCommand } from '../../application/commands/refresh-token/refresh-token.command'; import { RegisterUserCommand } from '../../application/commands/register-user/register-user.command'; @@ -79,6 +79,8 @@ export class AuthController { ) {} @Throttle({ default: { ttl: 3_600_000, limit: AUTH_RATE_LIMIT }, auth: { ttl: 3_600_000, limit: AUTH_RATE_LIMIT } }) + @EndpointRateLimit({ limit: IS_TEST ? 10_000 : 3, windowSeconds: 60, keyStrategy: 'ip' }) + @UseGuards(EndpointRateLimitGuard) @Post('register') @ApiOperation({ summary: 'Register a new user' }) @ApiResponse({ status: 201, description: 'User registered, auth cookies set' }) @@ -100,7 +102,8 @@ export class AuthController { } @Throttle({ default: { ttl: 3_600_000, limit: AUTH_RATE_LIMIT }, auth: { ttl: 3_600_000, limit: AUTH_RATE_LIMIT } }) - @UseGuards(LocalAuthGuard) + @EndpointRateLimit({ limit: IS_TEST ? 10_000 : 5, windowSeconds: 60, keyStrategy: 'ip' }) + @UseGuards(EndpointRateLimitGuard, LocalAuthGuard) @Post('login') @ApiOperation({ summary: 'Login with phone and password' }) @ApiBody({ type: LoginDto }) diff --git a/apps/api/src/modules/listings/presentation/controllers/listings.controller.ts b/apps/api/src/modules/listings/presentation/controllers/listings.controller.ts index 195617b..6dee9b6 100644 --- a/apps/api/src/modules/listings/presentation/controllers/listings.controller.ts +++ b/apps/api/src/modules/listings/presentation/controllers/listings.controller.ts @@ -22,7 +22,7 @@ import { ApiParam, } from '@nestjs/swagger'; import { type JwtPayload, CurrentUser, Roles, JwtAuthGuard, RolesGuard } from '@modules/auth'; -import { FileValidationPipe, type UploadedFile as ValidatedFile } from '@modules/shared'; +import { EndpointRateLimit, EndpointRateLimitGuard, FileValidationPipe, type UploadedFile as ValidatedFile } from '@modules/shared'; import { RequireQuota, QuotaGuard } from '@modules/subscriptions'; import { CreateListingCommand } from '../../application/commands/create-listing/create-listing.command'; import { type CreateListingResult } from '../../application/commands/create-listing/create-listing.handler'; @@ -53,7 +53,8 @@ export class ListingsController { @ApiResponse({ status: 400, description: 'Validation error' }) @ApiResponse({ status: 401, description: 'Unauthorized' }) @ApiResponse({ status: 403, description: 'Quota exceeded' }) - @UseGuards(JwtAuthGuard, QuotaGuard) + @UseGuards(EndpointRateLimitGuard, JwtAuthGuard, QuotaGuard) + @EndpointRateLimit({ limit: 10, windowSeconds: 60, keyStrategy: 'user' }) @RequireQuota('listings_created') @Post() async createListing( diff --git a/apps/api/src/modules/payments/presentation/controllers/payments.controller.ts b/apps/api/src/modules/payments/presentation/controllers/payments.controller.ts index e23fa66..bdcadc1 100644 --- a/apps/api/src/modules/payments/presentation/controllers/payments.controller.ts +++ b/apps/api/src/modules/payments/presentation/controllers/payments.controller.ts @@ -19,6 +19,7 @@ import { import { Throttle } from '@nestjs/throttler'; import { type PaymentProvider } from '@prisma/client'; import { type JwtPayload, CurrentUser, Roles, JwtAuthGuard, RolesGuard } from '@modules/auth'; +import { EndpointRateLimit, EndpointRateLimitGuard } from '@modules/shared'; import { CreatePaymentCommand } from '../../application/commands/create-payment/create-payment.command'; import { type CreatePaymentResult } from '../../application/commands/create-payment/create-payment.handler'; import { HandleCallbackCommand } from '../../application/commands/handle-callback/handle-callback.command'; @@ -72,6 +73,8 @@ export class PaymentsController { @ApiResponse({ status: 201, description: 'Callback processed successfully' }) @ApiParam({ name: 'provider', enum: ['vnpay', 'momo', 'zalopay'] }) @Throttle({ 'payment-callback': { ttl: 60_000, limit: 20 } }) + @EndpointRateLimit({ limit: 100, windowSeconds: 60, keyStrategy: 'ip', adminBypass: false }) + @UseGuards(EndpointRateLimitGuard) @Post('callback/:provider') async handleCallback( @Param('provider') provider: string, diff --git a/apps/api/src/modules/search/presentation/controllers/search.controller.ts b/apps/api/src/modules/search/presentation/controllers/search.controller.ts index 679d15a..a679c03 100644 --- a/apps/api/src/modules/search/presentation/controllers/search.controller.ts +++ b/apps/api/src/modules/search/presentation/controllers/search.controller.ts @@ -13,6 +13,7 @@ import { ApiBearerAuth, } from '@nestjs/swagger'; import { Roles, JwtAuthGuard, RolesGuard } from '@modules/auth'; +import { EndpointRateLimit, EndpointRateLimitGuard } from '@modules/shared'; import { ReindexAllCommand } from '../../application/commands/reindex-all/reindex-all.command'; import { type ReindexResult } from '../../application/commands/reindex-all/reindex-all.handler'; import { GeoSearchQuery } from '../../application/queries/geo-search/geo-search.query'; @@ -29,6 +30,8 @@ export class SearchController { private readonly queryBus: QueryBus, ) {} + @EndpointRateLimit({ limit: 30, windowSeconds: 60, keyStrategy: 'user' }) + @UseGuards(EndpointRateLimitGuard) @Get() @ApiOperation({ summary: 'Search properties', description: 'Public full-text and faceted property search' }) @ApiResponse({ status: 200, description: 'Search results returned successfully' }) @@ -52,6 +55,8 @@ export class SearchController { ); } + @EndpointRateLimit({ limit: 30, windowSeconds: 60, keyStrategy: 'user' }) + @UseGuards(EndpointRateLimitGuard) @Get('geo') @ApiOperation({ summary: 'Geo search properties', description: 'Public geographic radius property search' }) @ApiResponse({ status: 200, description: 'Geo search results returned successfully' }) diff --git a/apps/api/src/modules/shared/infrastructure/__tests__/endpoint-rate-limit.guard.spec.ts b/apps/api/src/modules/shared/infrastructure/__tests__/endpoint-rate-limit.guard.spec.ts new file mode 100644 index 0000000..1344ca4 --- /dev/null +++ b/apps/api/src/modules/shared/infrastructure/__tests__/endpoint-rate-limit.guard.spec.ts @@ -0,0 +1,480 @@ +import { HttpException, HttpStatus, type ExecutionContext } from '@nestjs/common'; +import type { Reflector } from '@nestjs/core'; +import { type EndpointRateLimitOptions } from '../decorators/endpoint-rate-limit.decorator'; +import { EndpointRateLimitGuard } from '../guards/endpoint-rate-limit.guard'; + +// ── helpers ────────────────────────────────────────────────────────────────── + +function mockRedis( + overrides: Partial<{ + evalResult: [number, number]; + throwError: boolean; + }> = {}, +) { + const evalFn = overrides.throwError + ? vi.fn().mockRejectedValue(new Error('Redis connection lost')) + : vi.fn().mockResolvedValue(overrides.evalResult ?? [1, 0]); + + return { + getClient: vi.fn().mockReturnValue({ eval: evalFn }), + isAvailable: vi.fn().mockReturnValue(!overrides.throwError), + get: vi.fn(), + set: vi.fn(), + del: vi.fn(), + onModuleDestroy: vi.fn(), + } as any; +} + +function mockLogger() { + return { + log: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + } as any; +} + +function mockReflector(options?: EndpointRateLimitOptions | undefined) { + return { + getAllAndOverride: vi.fn().mockReturnValue(options ?? undefined), + } as unknown as Reflector; +} + +interface MockContextOptions { + user?: { sub: string; role: string } | null; + method?: string; + path?: string; + routePath?: string; + ip?: string; + forwarded?: string | string[]; +} + +function buildContext(opts: MockContextOptions = {}): ExecutionContext { + const headers: Record = {}; + if (opts.forwarded !== undefined) { + headers['x-forwarded-for'] = opts.forwarded; + } + + const response = { setHeader: vi.fn() }; + const user = 'user' in opts ? opts.user : null; + + return { + switchToHttp: () => ({ + getRequest: () => ({ + user, + method: opts.method ?? 'GET', + path: opts.path ?? '/search', + route: opts.routePath ? { path: opts.routePath } : undefined, + ip: opts.ip ?? '192.168.1.1', + headers, + }), + getResponse: () => response, + }), + getHandler: () => ({ name: 'testHandler' }), + getClass: () => ({ name: 'TestController' }), + } as unknown as ExecutionContext; +} + +// ── tests ──────────────────────────────────────────────────────────────────── + +describe('EndpointRateLimitGuard', () => { + describe('when no @EndpointRateLimit decorator is present', () => { + it('allows request (skips rate limiting)', async () => { + const redis = mockRedis(); + const guard = new EndpointRateLimitGuard(redis, mockReflector(), mockLogger()); + + const result = await guard.canActivate(buildContext()); + expect(result).toBe(true); + // Redis should not be called at all + expect(redis.getClient).not.toHaveBeenCalled(); + }); + }); + + describe('when @EndpointRateLimit is present', () => { + const defaultOptions: EndpointRateLimitOptions = { + limit: 5, + windowSeconds: 60, + keyStrategy: 'ip', + }; + + it('allows request within rate limit', async () => { + const redis = mockRedis({ evalResult: [1, 0] }); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector(defaultOptions), + mockLogger(), + ); + + const result = await guard.canActivate(buildContext()); + expect(result).toBe(true); + }); + + it('rejects request with 429 when limit exceeded', async () => { + // current=5 (at limit), retryAfterMs=30000 + const redis = mockRedis({ evalResult: [5, 30_000] }); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector(defaultOptions), + mockLogger(), + ); + + await expect(guard.canActivate(buildContext())).rejects.toThrow(HttpException); + + try { + await guard.canActivate(buildContext()); + } catch (error) { + expect(error).toBeInstanceOf(HttpException); + expect((error as HttpException).getStatus()).toBe(HttpStatus.TOO_MANY_REQUESTS); + const body = (error as HttpException).getResponse(); + expect(body).toMatchObject({ + statusCode: 429, + message: 'Too many requests. Please try again later.', + retryAfter: 30, // 30000ms → 30s + }); + } + }); + + it('sets rate limit headers on allowed requests', async () => { + const redis = mockRedis({ evalResult: [3, 0] }); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector(defaultOptions), + mockLogger(), + ); + + const ctx = buildContext(); + await guard.canActivate(ctx); + + const response = ctx.switchToHttp().getResponse(); + expect(response.setHeader).toHaveBeenCalledWith('X-RateLimit-Limit', 5); + expect(response.setHeader).toHaveBeenCalledWith('X-RateLimit-Remaining', 2); // 5 - 3 + expect(response.setHeader).toHaveBeenCalledWith('X-RateLimit-Reset', 60); + }); + + it('sets Retry-After header when limit exceeded', async () => { + const redis = mockRedis({ evalResult: [5, 45_000] }); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector(defaultOptions), + mockLogger(), + ); + + const ctx = buildContext(); + try { + await guard.canActivate(ctx); + } catch { + // expected + } + + const response = ctx.switchToHttp().getResponse(); + expect(response.setHeader).toHaveBeenCalledWith('Retry-After', 45); + }); + }); + + describe('key strategy', () => { + it('uses IP-based key by default', async () => { + const redis = mockRedis({ evalResult: [1, 0] }); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector({ limit: 5, keyStrategy: 'ip' }), + mockLogger(), + ); + + await guard.canActivate( + buildContext({ ip: '10.0.0.1', method: 'POST', routePath: '/auth/login' }), + ); + + const evalCall = redis.getClient().eval.mock.calls[0]; + const key: string = evalCall[2]; + expect(key).toContain('ip:10.0.0.1'); + expect(key).toContain('POST'); + expect(key).toContain('/auth/login'); + }); + + it('uses X-Forwarded-For IP when available', async () => { + const redis = mockRedis({ evalResult: [1, 0] }); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector({ limit: 5, keyStrategy: 'ip' }), + mockLogger(), + ); + + await guard.canActivate( + buildContext({ forwarded: '203.0.113.5, 70.41.3.18', ip: '127.0.0.1' }), + ); + + const evalCall = redis.getClient().eval.mock.calls[0]; + const key: string = evalCall[2]; + expect(key).toContain('ip:203.0.113.5'); + }); + + it('uses authenticated user ID when strategy is "user"', async () => { + const redis = mockRedis({ evalResult: [1, 0] }); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector({ limit: 10, keyStrategy: 'user' }), + mockLogger(), + ); + + await guard.canActivate( + buildContext({ + user: { sub: 'user-abc', role: 'BUYER' }, + method: 'POST', + routePath: '/listings', + }), + ); + + const evalCall = redis.getClient().eval.mock.calls[0]; + const key: string = evalCall[2]; + expect(key).toContain('user:user-abc'); + expect(key).toContain('POST'); + expect(key).toContain('/listings'); + }); + + it('falls back to IP when strategy is "user" but no authenticated user', async () => { + const redis = mockRedis({ evalResult: [1, 0] }); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector({ limit: 30, keyStrategy: 'user' }), + mockLogger(), + ); + + await guard.canActivate( + buildContext({ user: null, ip: '10.0.0.5', routePath: '/search' }), + ); + + const evalCall = redis.getClient().eval.mock.calls[0]; + const key: string = evalCall[2]; + expect(key).toContain('ip:10.0.0.5'); + }); + }); + + describe('admin bypass', () => { + it('bypasses rate limit for ADMIN users by default', async () => { + const redis = mockRedis(); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector({ limit: 5, keyStrategy: 'ip' }), + mockLogger(), + ); + + const result = await guard.canActivate( + buildContext({ user: { sub: 'admin-1', role: 'ADMIN' } }), + ); + expect(result).toBe(true); + // Redis should not be called — admin bypasses + expect(redis.getClient).not.toHaveBeenCalled(); + }); + + it('does not bypass ADMIN when adminBypass is false', async () => { + const redis = mockRedis({ evalResult: [1, 0] }); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector({ limit: 100, keyStrategy: 'ip', adminBypass: false }), + mockLogger(), + ); + + await guard.canActivate( + buildContext({ user: { sub: 'admin-1', role: 'ADMIN' } }), + ); + + // Redis WAS called — no bypass + expect(redis.getClient).toHaveBeenCalled(); + }); + + it('does not bypass for non-ADMIN roles', async () => { + const redis = mockRedis({ evalResult: [1, 0] }); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector({ limit: 5, keyStrategy: 'ip' }), + mockLogger(), + ); + + await guard.canActivate( + buildContext({ user: { sub: 'buyer-1', role: 'BUYER' } }), + ); + + expect(redis.getClient).toHaveBeenCalled(); + }); + }); + + describe('fail-open behavior', () => { + it('allows request when Redis throws an error', async () => { + const redis = mockRedis({ throwError: true }); + const logger = mockLogger(); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector({ limit: 5, keyStrategy: 'ip' }), + logger, + ); + + const result = await guard.canActivate(buildContext()); + expect(result).toBe(true); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('Endpoint rate limit check failed'), + 'EndpointRateLimitGuard', + ); + }); + + it('re-throws 429 errors (does not swallow intentional rejections)', async () => { + // Simulate a situation where Redis returns limit exceeded + const redis = mockRedis({ evalResult: [5, 30_000] }); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector({ limit: 5, keyStrategy: 'ip' }), + mockLogger(), + ); + + await expect(guard.canActivate(buildContext())).rejects.toThrow(HttpException); + }); + }); + + describe('window seconds', () => { + it('defaults to 60 seconds when windowSeconds is not specified', async () => { + const redis = mockRedis({ evalResult: [1, 0] }); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector({ limit: 5 }), // no windowSeconds + mockLogger(), + ); + + await guard.canActivate(buildContext()); + + const evalCall = redis.getClient().eval.mock.calls[0]; + // ARGV[2] = windowMs = 60 * 1000 + expect(evalCall[4]).toBe(60_000); + // ARGV[5] = windowSeconds = 60 + expect(evalCall[7]).toBe(60); + }); + + it('uses custom window seconds when specified', async () => { + const redis = mockRedis({ evalResult: [1, 0] }); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector({ limit: 10, windowSeconds: 120 }), + mockLogger(), + ); + + await guard.canActivate(buildContext()); + + const evalCall = redis.getClient().eval.mock.calls[0]; + // ARGV[2] = windowMs = 120 * 1000 + expect(evalCall[4]).toBe(120_000); + // ARGV[5] = windowSeconds = 120 + expect(evalCall[7]).toBe(120); + }); + }); + + describe('logging', () => { + it('logs warning when rate limit exceeded', async () => { + const redis = mockRedis({ evalResult: [5, 30_000] }); + const logger = mockLogger(); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector({ limit: 5, keyStrategy: 'ip' }), + logger, + ); + + try { + await guard.canActivate(buildContext({ method: 'POST', routePath: '/auth/login' })); + } catch { + // expected + } + + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('Endpoint rate limit exceeded'), + 'EndpointRateLimitGuard', + ); + }); + + it('logs warning on Redis failure', async () => { + const redis = mockRedis({ throwError: true }); + const logger = mockLogger(); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector({ limit: 5, keyStrategy: 'ip' }), + logger, + ); + + await guard.canActivate(buildContext()); + + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('Redis error'), + 'EndpointRateLimitGuard', + ); + }); + }); + + describe('Retry-After computation', () => { + it('ensures minimum Retry-After of 1 second', async () => { + // retryAfterMs = 500 → should ceil to 1 second + const redis = mockRedis({ evalResult: [5, 500] }); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector({ limit: 5, keyStrategy: 'ip' }), + mockLogger(), + ); + + const ctx = buildContext(); + try { + await guard.canActivate(ctx); + } catch (error) { + const body = (error as HttpException).getResponse() as any; + expect(body.retryAfter).toBe(1); + } + }); + + it('rounds up Retry-After to next second', async () => { + // retryAfterMs = 15100 → should ceil to 16 seconds + const redis = mockRedis({ evalResult: [5, 15_100] }); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector({ limit: 5, keyStrategy: 'ip' }), + mockLogger(), + ); + + const ctx = buildContext(); + try { + await guard.canActivate(ctx); + } catch (error) { + const body = (error as HttpException).getResponse() as any; + expect(body.retryAfter).toBe(16); + } + }); + }); + + describe('endpoint path in key', () => { + it('uses route path when available', async () => { + const redis = mockRedis({ evalResult: [1, 0] }); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector({ limit: 100, keyStrategy: 'ip' }), + mockLogger(), + ); + + await guard.canActivate( + buildContext({ routePath: '/payments/callback/:provider', path: '/payments/callback/vnpay' }), + ); + + const evalCall = redis.getClient().eval.mock.calls[0]; + const key: string = evalCall[2]; + // Should use the route pattern, not the resolved path + expect(key).toContain('/payments/callback/:provider'); + }); + + it('falls back to request path when no route defined', async () => { + const redis = mockRedis({ evalResult: [1, 0] }); + const guard = new EndpointRateLimitGuard( + redis, + mockReflector({ limit: 30, keyStrategy: 'ip' }), + mockLogger(), + ); + + await guard.canActivate(buildContext({ path: '/search' })); + + const evalCall = redis.getClient().eval.mock.calls[0]; + const key: string = evalCall[2]; + expect(key).toContain('/search'); + }); + }); +}); diff --git a/apps/api/src/modules/shared/infrastructure/decorators/endpoint-rate-limit.decorator.ts b/apps/api/src/modules/shared/infrastructure/decorators/endpoint-rate-limit.decorator.ts new file mode 100644 index 0000000..b9731d4 --- /dev/null +++ b/apps/api/src/modules/shared/infrastructure/decorators/endpoint-rate-limit.decorator.ts @@ -0,0 +1,57 @@ +import { SetMetadata } from '@nestjs/common'; + +/** + * Metadata key for per-endpoint rate limit configuration. + * Read by EndpointRateLimitGuard via reflector. + */ +export const ENDPOINT_RATE_LIMIT_KEY = 'endpoint_rate_limit'; + +export interface EndpointRateLimitOptions { + /** + * Maximum number of requests allowed within the window. + * @example 5 (for login: 5 attempts per minute) + */ + limit: number; + + /** + * Sliding window duration in seconds. + * @default 60 + */ + windowSeconds?: number; + + /** + * Key strategy for identifying the requester: + * - `'ip'` — key by client IP (for unauthenticated endpoints like login/register) + * - `'user'` — key by authenticated user ID (falls back to IP if unauthenticated) + * @default 'ip' + */ + keyStrategy?: 'ip' | 'user'; + + /** + * Whether ADMIN users bypass this rate limit. + * @default true + */ + adminBypass?: boolean; +} + +/** + * Decorator to configure per-endpoint rate limiting with Redis sliding window. + * + * Works in conjunction with `EndpointRateLimitGuard`. Apply it to specific + * controller methods or the entire controller class to enforce granular + * rate limits keyed by IP or authenticated user + endpoint path. + * + * @example + * // Login: 5 attempts per minute per IP + * @EndpointRateLimit({ limit: 5, windowSeconds: 60, keyStrategy: 'ip' }) + * + * @example + * // Listing creation: 10 per minute per user + * @EndpointRateLimit({ limit: 10, windowSeconds: 60, keyStrategy: 'user' }) + * + * @example + * // Search: 30 per minute per user, admin bypass enabled + * @EndpointRateLimit({ limit: 30, windowSeconds: 60, keyStrategy: 'user', adminBypass: true }) + */ +export const EndpointRateLimit = (options: EndpointRateLimitOptions) => + SetMetadata(ENDPOINT_RATE_LIMIT_KEY, options); diff --git a/apps/api/src/modules/shared/infrastructure/guards/endpoint-rate-limit.guard.ts b/apps/api/src/modules/shared/infrastructure/guards/endpoint-rate-limit.guard.ts new file mode 100644 index 0000000..7f2e2a9 --- /dev/null +++ b/apps/api/src/modules/shared/infrastructure/guards/endpoint-rate-limit.guard.ts @@ -0,0 +1,218 @@ +import { + Injectable, + type CanActivate, + type ExecutionContext, + HttpException, + HttpStatus, +} from '@nestjs/common'; +import { type Reflector } from '@nestjs/core'; +import { type Request, type Response } from 'express'; +import { + ENDPOINT_RATE_LIMIT_KEY, + type EndpointRateLimitOptions, +} from '../decorators/endpoint-rate-limit.decorator'; +import { type LoggerService } from '../logger.service'; +import { type RedisService } from '../redis.service'; + +/** Express request extended with optional JWT user payload. */ +interface AuthenticatedRequest extends Request { + user?: { sub: string; role: string }; + route?: { path: string }; +} + +/** + * Lua script implementing a true sliding-window rate limiter using a Redis sorted set. + * + * Algorithm: + * 1. Remove all entries outside the current window (older than `now - windowMs`). + * 2. Count remaining entries (= requests within the window). + * 3. If under the limit, add the current request with score = current timestamp. + * 4. Set TTL on the key to auto-expire after the window. + * 5. Return [currentCount, oldestEntryAge] so we can compute Retry-After. + * + * This is more precise than a fixed-window counter because it considers + * the exact timestamp of each request rather than bucketing. + * + * KEYS[1] = rate limit key + * ARGV[1] = now (ms timestamp) + * ARGV[2] = windowMs (window duration in ms) + * ARGV[3] = limit (max requests) + * ARGV[4] = unique request ID (to distinguish concurrent requests) + * ARGV[5] = windowSeconds (for key TTL) + */ +const SLIDING_WINDOW_LUA = ` +local key = KEYS[1] +local now = tonumber(ARGV[1]) +local windowMs = tonumber(ARGV[2]) +local limit = tonumber(ARGV[3]) +local requestId = ARGV[4] +local windowSec = tonumber(ARGV[5]) + +-- Remove entries outside the window +redis.call('ZREMRANGEBYSCORE', key, 0, now - windowMs) + +-- Count current requests in window +local current = redis.call('ZCARD', key) + +if current < limit then + -- Add this request + redis.call('ZADD', key, now, requestId) + -- Ensure key expires after the window + redis.call('EXPIRE', key, windowSec + 1) + return {current + 1, 0} +else + -- Rate limit exceeded — compute time until oldest entry expires + local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES') + local retryAfterMs = 0 + if #oldest >= 2 then + retryAfterMs = tonumber(oldest[2]) + windowMs - now + if retryAfterMs < 0 then retryAfterMs = 0 end + end + return {current, retryAfterMs} +end +`; + +/** Monotonically increasing counter to generate unique request IDs. */ +let requestCounter = 0; + +/** + * Guard that enforces per-endpoint rate limiting using a Redis sorted-set sliding window. + * + * Only activates on routes decorated with `@EndpointRateLimit()`. + * Routes without the decorator are not affected. + * + * Features: + * - True sliding window (sorted set) for accurate rate measurement + * - Keys by IP or authenticated user + endpoint path + * - Proper 429 response with `Retry-After` header + * - Rate limit headers (`X-RateLimit-Limit`, `X-RateLimit-Remaining`, `X-RateLimit-Reset`) + * - Admin bypass support + * - Fail-open on Redis errors (logs warning, allows request) + */ +@Injectable() +export class EndpointRateLimitGuard implements CanActivate { + constructor( + private readonly redis: RedisService, + private readonly reflector: Reflector, + private readonly logger: LoggerService, + ) {} + + async canActivate(context: ExecutionContext): Promise { + const options = this.reflector.getAllAndOverride( + ENDPOINT_RATE_LIMIT_KEY, + [context.getHandler(), context.getClass()], + ); + + // No @EndpointRateLimit decorator → skip + if (!options) { + return true; + } + + const request = context.switchToHttp().getRequest(); + const response: Response = context.switchToHttp().getResponse(); + const user = request.user; + + // Admin bypass (default: true) + const adminBypass = options.adminBypass !== false; + if (adminBypass && user?.role === 'ADMIN') { + return true; + } + + const windowSeconds = options.windowSeconds ?? 60; + const windowMs = windowSeconds * 1000; + const limit = options.limit; + const keyStrategy = options.keyStrategy ?? 'ip'; + + // Build the rate limit key: endpoint_rate_limit:{method}:{path}:{identifier} + const method = request.method; + const routePath = request.route?.path ?? request.path; + const identifier = this.resolveIdentifier(request, keyStrategy); + + const key = `endpoint_rate_limit:${method}:${routePath}:${identifier}`; + + try { + const client = this.redis.getClient(); + const now = Date.now(); + const requestId = `${now}:${process.pid}:${++requestCounter}`; + + const result = await client.eval( + SLIDING_WINDOW_LUA, + 1, + key, + now, + windowMs, + limit, + requestId, + windowSeconds, + ) as [number, number]; + + const current = result[0]; + const retryAfterMs = result[1]; + + // Set rate limit headers for observability + response.setHeader('X-RateLimit-Limit', limit); + response.setHeader('X-RateLimit-Remaining', Math.max(0, limit - current)); + response.setHeader('X-RateLimit-Reset', windowSeconds); + + if (current > limit || retryAfterMs > 0) { + // The request was NOT added (limit reached) — return 429 + const retryAfterSeconds = Math.max(1, Math.ceil(retryAfterMs / 1000)); + response.setHeader('Retry-After', retryAfterSeconds); + + this.logger.warn( + `Endpoint rate limit exceeded: ${method} ${routePath}, ` + + `key=${keyStrategy}:${identifier}, current=${current}/${limit}, ` + + `retryAfter=${retryAfterSeconds}s`, + 'EndpointRateLimitGuard', + ); + + throw new HttpException( + { + statusCode: HttpStatus.TOO_MANY_REQUESTS, + message: 'Too many requests. Please try again later.', + retryAfter: retryAfterSeconds, + }, + HttpStatus.TOO_MANY_REQUESTS, + ); + } + + return true; + } catch (error) { + // Re-throw intentional 429 errors + if (error instanceof HttpException && error.getStatus() === HttpStatus.TOO_MANY_REQUESTS) { + throw error; + } + + // Fail open on Redis errors + this.logger.warn( + `Endpoint rate limit check failed (Redis error), allowing request: ` + + `${method} ${routePath}, key=${keyStrategy}:${identifier}, ` + + `error=${error instanceof Error ? error.message : 'unknown'}`, + 'EndpointRateLimitGuard', + ); + return true; + } + } + + /** + * Resolve the identifier for the rate limit key based on the strategy. + */ + private resolveIdentifier(request: AuthenticatedRequest, strategy: 'ip' | 'user'): string { + if (strategy === 'user') { + const user = request.user; + if (user?.sub) { + return `user:${user.sub}`; + } + // Fall back to IP for unauthenticated requests on user-keyed endpoints + } + + // IP-based: extract real IP behind proxy + const forwarded = request.headers['x-forwarded-for']; + const ip = + typeof forwarded === 'string' + ? (forwarded.split(',')[0]?.trim() ?? '127.0.0.1') + : (request.ip ?? '127.0.0.1'); + + return `ip:${ip}`; + } +} diff --git a/apps/api/src/modules/shared/infrastructure/index.ts b/apps/api/src/modules/shared/infrastructure/index.ts index 676dab3..13635bc 100644 --- a/apps/api/src/modules/shared/infrastructure/index.ts +++ b/apps/api/src/modules/shared/infrastructure/index.ts @@ -20,6 +20,12 @@ export { type UserRateLimitOptions, } from './guards/user-rate-limit.guard'; export { UserRateLimit } from './decorators/user-rate-limit.decorator'; +export { + EndpointRateLimit, + ENDPOINT_RATE_LIMIT_KEY, + type EndpointRateLimitOptions, +} from './decorators/endpoint-rate-limit.decorator'; +export { EndpointRateLimitGuard } from './guards/endpoint-rate-limit.guard'; export { FileValidationPipe } from './pipes/file-validation.pipe'; export type { FileValidationOptions, UploadedFile } from './pipes/file-validation.pipe'; export { validateEnv, validateJwtSecret } from './env-validation';