feat(security): add per-endpoint API rate limiting with Redis sliding window
Implement @EndpointRateLimit() decorator and EndpointRateLimitGuard for granular per-endpoint rate limiting using a Redis sorted-set sliding window. This prevents brute force attacks on auth endpoints, replay attacks on payment callbacks, and scraping on search endpoints. Applied rate limits: - /auth/login: 5 req/min per IP - /auth/register: 3 req/min per IP - /listings POST: 10 req/min per user - /search: 30 req/min per user - /payments/callback/*: 100 req/min per IP Features: - True sliding window (sorted set) for accurate rate measurement - Configurable key strategy (IP or authenticated user) - Admin bypass support (enabled by default) - Fail-open on Redis errors - Proper 429 response with Retry-After header - Rate limit headers (X-RateLimit-Limit/Remaining/Reset) - 22 unit tests covering all scenarios Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -12,7 +12,7 @@ import { type CommandBus, type QueryBus } from '@nestjs/cqrs';
|
||||
import { ApiTags, ApiOperation, ApiResponse, ApiBearerAuth, ApiBody } from '@nestjs/swagger';
|
||||
import { Throttle } from '@nestjs/throttler';
|
||||
import { type Request, type Response } from 'express';
|
||||
import { UnauthorizedException } from '@modules/shared';
|
||||
import { EndpointRateLimit, EndpointRateLimitGuard, UnauthorizedException } from '@modules/shared';
|
||||
import { LoginUserCommand } from '../../application/commands/login-user/login-user.command';
|
||||
import { RefreshTokenCommand } from '../../application/commands/refresh-token/refresh-token.command';
|
||||
import { RegisterUserCommand } from '../../application/commands/register-user/register-user.command';
|
||||
@@ -79,6 +79,8 @@ export class AuthController {
|
||||
) {}
|
||||
|
||||
@Throttle({ default: { ttl: 3_600_000, limit: AUTH_RATE_LIMIT }, auth: { ttl: 3_600_000, limit: AUTH_RATE_LIMIT } })
|
||||
@EndpointRateLimit({ limit: IS_TEST ? 10_000 : 3, windowSeconds: 60, keyStrategy: 'ip' })
|
||||
@UseGuards(EndpointRateLimitGuard)
|
||||
@Post('register')
|
||||
@ApiOperation({ summary: 'Register a new user' })
|
||||
@ApiResponse({ status: 201, description: 'User registered, auth cookies set' })
|
||||
@@ -100,7 +102,8 @@ export class AuthController {
|
||||
}
|
||||
|
||||
@Throttle({ default: { ttl: 3_600_000, limit: AUTH_RATE_LIMIT }, auth: { ttl: 3_600_000, limit: AUTH_RATE_LIMIT } })
|
||||
@UseGuards(LocalAuthGuard)
|
||||
@EndpointRateLimit({ limit: IS_TEST ? 10_000 : 5, windowSeconds: 60, keyStrategy: 'ip' })
|
||||
@UseGuards(EndpointRateLimitGuard, LocalAuthGuard)
|
||||
@Post('login')
|
||||
@ApiOperation({ summary: 'Login with phone and password' })
|
||||
@ApiBody({ type: LoginDto })
|
||||
|
||||
@@ -22,7 +22,7 @@ import {
|
||||
ApiParam,
|
||||
} from '@nestjs/swagger';
|
||||
import { type JwtPayload, CurrentUser, Roles, JwtAuthGuard, RolesGuard } from '@modules/auth';
|
||||
import { FileValidationPipe, type UploadedFile as ValidatedFile } from '@modules/shared';
|
||||
import { EndpointRateLimit, EndpointRateLimitGuard, FileValidationPipe, type UploadedFile as ValidatedFile } from '@modules/shared';
|
||||
import { RequireQuota, QuotaGuard } from '@modules/subscriptions';
|
||||
import { CreateListingCommand } from '../../application/commands/create-listing/create-listing.command';
|
||||
import { type CreateListingResult } from '../../application/commands/create-listing/create-listing.handler';
|
||||
@@ -53,7 +53,8 @@ export class ListingsController {
|
||||
@ApiResponse({ status: 400, description: 'Validation error' })
|
||||
@ApiResponse({ status: 401, description: 'Unauthorized' })
|
||||
@ApiResponse({ status: 403, description: 'Quota exceeded' })
|
||||
@UseGuards(JwtAuthGuard, QuotaGuard)
|
||||
@UseGuards(EndpointRateLimitGuard, JwtAuthGuard, QuotaGuard)
|
||||
@EndpointRateLimit({ limit: 10, windowSeconds: 60, keyStrategy: 'user' })
|
||||
@RequireQuota('listings_created')
|
||||
@Post()
|
||||
async createListing(
|
||||
|
||||
@@ -19,6 +19,7 @@ import {
|
||||
import { Throttle } from '@nestjs/throttler';
|
||||
import { type PaymentProvider } from '@prisma/client';
|
||||
import { type JwtPayload, CurrentUser, Roles, JwtAuthGuard, RolesGuard } from '@modules/auth';
|
||||
import { EndpointRateLimit, EndpointRateLimitGuard } from '@modules/shared';
|
||||
import { CreatePaymentCommand } from '../../application/commands/create-payment/create-payment.command';
|
||||
import { type CreatePaymentResult } from '../../application/commands/create-payment/create-payment.handler';
|
||||
import { HandleCallbackCommand } from '../../application/commands/handle-callback/handle-callback.command';
|
||||
@@ -72,6 +73,8 @@ export class PaymentsController {
|
||||
@ApiResponse({ status: 201, description: 'Callback processed successfully' })
|
||||
@ApiParam({ name: 'provider', enum: ['vnpay', 'momo', 'zalopay'] })
|
||||
@Throttle({ 'payment-callback': { ttl: 60_000, limit: 20 } })
|
||||
@EndpointRateLimit({ limit: 100, windowSeconds: 60, keyStrategy: 'ip', adminBypass: false })
|
||||
@UseGuards(EndpointRateLimitGuard)
|
||||
@Post('callback/:provider')
|
||||
async handleCallback(
|
||||
@Param('provider') provider: string,
|
||||
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
ApiBearerAuth,
|
||||
} from '@nestjs/swagger';
|
||||
import { Roles, JwtAuthGuard, RolesGuard } from '@modules/auth';
|
||||
import { EndpointRateLimit, EndpointRateLimitGuard } from '@modules/shared';
|
||||
import { ReindexAllCommand } from '../../application/commands/reindex-all/reindex-all.command';
|
||||
import { type ReindexResult } from '../../application/commands/reindex-all/reindex-all.handler';
|
||||
import { GeoSearchQuery } from '../../application/queries/geo-search/geo-search.query';
|
||||
@@ -29,6 +30,8 @@ export class SearchController {
|
||||
private readonly queryBus: QueryBus,
|
||||
) {}
|
||||
|
||||
@EndpointRateLimit({ limit: 30, windowSeconds: 60, keyStrategy: 'user' })
|
||||
@UseGuards(EndpointRateLimitGuard)
|
||||
@Get()
|
||||
@ApiOperation({ summary: 'Search properties', description: 'Public full-text and faceted property search' })
|
||||
@ApiResponse({ status: 200, description: 'Search results returned successfully' })
|
||||
@@ -52,6 +55,8 @@ export class SearchController {
|
||||
);
|
||||
}
|
||||
|
||||
@EndpointRateLimit({ limit: 30, windowSeconds: 60, keyStrategy: 'user' })
|
||||
@UseGuards(EndpointRateLimitGuard)
|
||||
@Get('geo')
|
||||
@ApiOperation({ summary: 'Geo search properties', description: 'Public geographic radius property search' })
|
||||
@ApiResponse({ status: 200, description: 'Geo search results returned successfully' })
|
||||
|
||||
@@ -0,0 +1,480 @@
|
||||
import { HttpException, HttpStatus, type ExecutionContext } from '@nestjs/common';
|
||||
import type { Reflector } from '@nestjs/core';
|
||||
import { type EndpointRateLimitOptions } from '../decorators/endpoint-rate-limit.decorator';
|
||||
import { EndpointRateLimitGuard } from '../guards/endpoint-rate-limit.guard';
|
||||
|
||||
// ── helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
function mockRedis(
|
||||
overrides: Partial<{
|
||||
evalResult: [number, number];
|
||||
throwError: boolean;
|
||||
}> = {},
|
||||
) {
|
||||
const evalFn = overrides.throwError
|
||||
? vi.fn().mockRejectedValue(new Error('Redis connection lost'))
|
||||
: vi.fn().mockResolvedValue(overrides.evalResult ?? [1, 0]);
|
||||
|
||||
return {
|
||||
getClient: vi.fn().mockReturnValue({ eval: evalFn }),
|
||||
isAvailable: vi.fn().mockReturnValue(!overrides.throwError),
|
||||
get: vi.fn(),
|
||||
set: vi.fn(),
|
||||
del: vi.fn(),
|
||||
onModuleDestroy: vi.fn(),
|
||||
} as any;
|
||||
}
|
||||
|
||||
function mockLogger() {
|
||||
return {
|
||||
log: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
} as any;
|
||||
}
|
||||
|
||||
function mockReflector(options?: EndpointRateLimitOptions | undefined) {
|
||||
return {
|
||||
getAllAndOverride: vi.fn().mockReturnValue(options ?? undefined),
|
||||
} as unknown as Reflector;
|
||||
}
|
||||
|
||||
interface MockContextOptions {
|
||||
user?: { sub: string; role: string } | null;
|
||||
method?: string;
|
||||
path?: string;
|
||||
routePath?: string;
|
||||
ip?: string;
|
||||
forwarded?: string | string[];
|
||||
}
|
||||
|
||||
function buildContext(opts: MockContextOptions = {}): ExecutionContext {
|
||||
const headers: Record<string, string | string[] | undefined> = {};
|
||||
if (opts.forwarded !== undefined) {
|
||||
headers['x-forwarded-for'] = opts.forwarded;
|
||||
}
|
||||
|
||||
const response = { setHeader: vi.fn() };
|
||||
const user = 'user' in opts ? opts.user : null;
|
||||
|
||||
return {
|
||||
switchToHttp: () => ({
|
||||
getRequest: () => ({
|
||||
user,
|
||||
method: opts.method ?? 'GET',
|
||||
path: opts.path ?? '/search',
|
||||
route: opts.routePath ? { path: opts.routePath } : undefined,
|
||||
ip: opts.ip ?? '192.168.1.1',
|
||||
headers,
|
||||
}),
|
||||
getResponse: () => response,
|
||||
}),
|
||||
getHandler: () => ({ name: 'testHandler' }),
|
||||
getClass: () => ({ name: 'TestController' }),
|
||||
} as unknown as ExecutionContext;
|
||||
}
|
||||
|
||||
// ── tests ────────────────────────────────────────────────────────────────────
|
||||
|
||||
describe('EndpointRateLimitGuard', () => {
|
||||
describe('when no @EndpointRateLimit decorator is present', () => {
|
||||
it('allows request (skips rate limiting)', async () => {
|
||||
const redis = mockRedis();
|
||||
const guard = new EndpointRateLimitGuard(redis, mockReflector(), mockLogger());
|
||||
|
||||
const result = await guard.canActivate(buildContext());
|
||||
expect(result).toBe(true);
|
||||
// Redis should not be called at all
|
||||
expect(redis.getClient).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('when @EndpointRateLimit is present', () => {
|
||||
const defaultOptions: EndpointRateLimitOptions = {
|
||||
limit: 5,
|
||||
windowSeconds: 60,
|
||||
keyStrategy: 'ip',
|
||||
};
|
||||
|
||||
it('allows request within rate limit', async () => {
|
||||
const redis = mockRedis({ evalResult: [1, 0] });
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector(defaultOptions),
|
||||
mockLogger(),
|
||||
);
|
||||
|
||||
const result = await guard.canActivate(buildContext());
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('rejects request with 429 when limit exceeded', async () => {
|
||||
// current=5 (at limit), retryAfterMs=30000
|
||||
const redis = mockRedis({ evalResult: [5, 30_000] });
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector(defaultOptions),
|
||||
mockLogger(),
|
||||
);
|
||||
|
||||
await expect(guard.canActivate(buildContext())).rejects.toThrow(HttpException);
|
||||
|
||||
try {
|
||||
await guard.canActivate(buildContext());
|
||||
} catch (error) {
|
||||
expect(error).toBeInstanceOf(HttpException);
|
||||
expect((error as HttpException).getStatus()).toBe(HttpStatus.TOO_MANY_REQUESTS);
|
||||
const body = (error as HttpException).getResponse();
|
||||
expect(body).toMatchObject({
|
||||
statusCode: 429,
|
||||
message: 'Too many requests. Please try again later.',
|
||||
retryAfter: 30, // 30000ms → 30s
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
it('sets rate limit headers on allowed requests', async () => {
|
||||
const redis = mockRedis({ evalResult: [3, 0] });
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector(defaultOptions),
|
||||
mockLogger(),
|
||||
);
|
||||
|
||||
const ctx = buildContext();
|
||||
await guard.canActivate(ctx);
|
||||
|
||||
const response = ctx.switchToHttp().getResponse();
|
||||
expect(response.setHeader).toHaveBeenCalledWith('X-RateLimit-Limit', 5);
|
||||
expect(response.setHeader).toHaveBeenCalledWith('X-RateLimit-Remaining', 2); // 5 - 3
|
||||
expect(response.setHeader).toHaveBeenCalledWith('X-RateLimit-Reset', 60);
|
||||
});
|
||||
|
||||
it('sets Retry-After header when limit exceeded', async () => {
|
||||
const redis = mockRedis({ evalResult: [5, 45_000] });
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector(defaultOptions),
|
||||
mockLogger(),
|
||||
);
|
||||
|
||||
const ctx = buildContext();
|
||||
try {
|
||||
await guard.canActivate(ctx);
|
||||
} catch {
|
||||
// expected
|
||||
}
|
||||
|
||||
const response = ctx.switchToHttp().getResponse();
|
||||
expect(response.setHeader).toHaveBeenCalledWith('Retry-After', 45);
|
||||
});
|
||||
});
|
||||
|
||||
describe('key strategy', () => {
|
||||
it('uses IP-based key by default', async () => {
|
||||
const redis = mockRedis({ evalResult: [1, 0] });
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector({ limit: 5, keyStrategy: 'ip' }),
|
||||
mockLogger(),
|
||||
);
|
||||
|
||||
await guard.canActivate(
|
||||
buildContext({ ip: '10.0.0.1', method: 'POST', routePath: '/auth/login' }),
|
||||
);
|
||||
|
||||
const evalCall = redis.getClient().eval.mock.calls[0];
|
||||
const key: string = evalCall[2];
|
||||
expect(key).toContain('ip:10.0.0.1');
|
||||
expect(key).toContain('POST');
|
||||
expect(key).toContain('/auth/login');
|
||||
});
|
||||
|
||||
it('uses X-Forwarded-For IP when available', async () => {
|
||||
const redis = mockRedis({ evalResult: [1, 0] });
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector({ limit: 5, keyStrategy: 'ip' }),
|
||||
mockLogger(),
|
||||
);
|
||||
|
||||
await guard.canActivate(
|
||||
buildContext({ forwarded: '203.0.113.5, 70.41.3.18', ip: '127.0.0.1' }),
|
||||
);
|
||||
|
||||
const evalCall = redis.getClient().eval.mock.calls[0];
|
||||
const key: string = evalCall[2];
|
||||
expect(key).toContain('ip:203.0.113.5');
|
||||
});
|
||||
|
||||
it('uses authenticated user ID when strategy is "user"', async () => {
|
||||
const redis = mockRedis({ evalResult: [1, 0] });
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector({ limit: 10, keyStrategy: 'user' }),
|
||||
mockLogger(),
|
||||
);
|
||||
|
||||
await guard.canActivate(
|
||||
buildContext({
|
||||
user: { sub: 'user-abc', role: 'BUYER' },
|
||||
method: 'POST',
|
||||
routePath: '/listings',
|
||||
}),
|
||||
);
|
||||
|
||||
const evalCall = redis.getClient().eval.mock.calls[0];
|
||||
const key: string = evalCall[2];
|
||||
expect(key).toContain('user:user-abc');
|
||||
expect(key).toContain('POST');
|
||||
expect(key).toContain('/listings');
|
||||
});
|
||||
|
||||
it('falls back to IP when strategy is "user" but no authenticated user', async () => {
|
||||
const redis = mockRedis({ evalResult: [1, 0] });
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector({ limit: 30, keyStrategy: 'user' }),
|
||||
mockLogger(),
|
||||
);
|
||||
|
||||
await guard.canActivate(
|
||||
buildContext({ user: null, ip: '10.0.0.5', routePath: '/search' }),
|
||||
);
|
||||
|
||||
const evalCall = redis.getClient().eval.mock.calls[0];
|
||||
const key: string = evalCall[2];
|
||||
expect(key).toContain('ip:10.0.0.5');
|
||||
});
|
||||
});
|
||||
|
||||
describe('admin bypass', () => {
|
||||
it('bypasses rate limit for ADMIN users by default', async () => {
|
||||
const redis = mockRedis();
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector({ limit: 5, keyStrategy: 'ip' }),
|
||||
mockLogger(),
|
||||
);
|
||||
|
||||
const result = await guard.canActivate(
|
||||
buildContext({ user: { sub: 'admin-1', role: 'ADMIN' } }),
|
||||
);
|
||||
expect(result).toBe(true);
|
||||
// Redis should not be called — admin bypasses
|
||||
expect(redis.getClient).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('does not bypass ADMIN when adminBypass is false', async () => {
|
||||
const redis = mockRedis({ evalResult: [1, 0] });
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector({ limit: 100, keyStrategy: 'ip', adminBypass: false }),
|
||||
mockLogger(),
|
||||
);
|
||||
|
||||
await guard.canActivate(
|
||||
buildContext({ user: { sub: 'admin-1', role: 'ADMIN' } }),
|
||||
);
|
||||
|
||||
// Redis WAS called — no bypass
|
||||
expect(redis.getClient).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('does not bypass for non-ADMIN roles', async () => {
|
||||
const redis = mockRedis({ evalResult: [1, 0] });
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector({ limit: 5, keyStrategy: 'ip' }),
|
||||
mockLogger(),
|
||||
);
|
||||
|
||||
await guard.canActivate(
|
||||
buildContext({ user: { sub: 'buyer-1', role: 'BUYER' } }),
|
||||
);
|
||||
|
||||
expect(redis.getClient).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('fail-open behavior', () => {
|
||||
it('allows request when Redis throws an error', async () => {
|
||||
const redis = mockRedis({ throwError: true });
|
||||
const logger = mockLogger();
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector({ limit: 5, keyStrategy: 'ip' }),
|
||||
logger,
|
||||
);
|
||||
|
||||
const result = await guard.canActivate(buildContext());
|
||||
expect(result).toBe(true);
|
||||
expect(logger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Endpoint rate limit check failed'),
|
||||
'EndpointRateLimitGuard',
|
||||
);
|
||||
});
|
||||
|
||||
it('re-throws 429 errors (does not swallow intentional rejections)', async () => {
|
||||
// Simulate a situation where Redis returns limit exceeded
|
||||
const redis = mockRedis({ evalResult: [5, 30_000] });
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector({ limit: 5, keyStrategy: 'ip' }),
|
||||
mockLogger(),
|
||||
);
|
||||
|
||||
await expect(guard.canActivate(buildContext())).rejects.toThrow(HttpException);
|
||||
});
|
||||
});
|
||||
|
||||
describe('window seconds', () => {
|
||||
it('defaults to 60 seconds when windowSeconds is not specified', async () => {
|
||||
const redis = mockRedis({ evalResult: [1, 0] });
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector({ limit: 5 }), // no windowSeconds
|
||||
mockLogger(),
|
||||
);
|
||||
|
||||
await guard.canActivate(buildContext());
|
||||
|
||||
const evalCall = redis.getClient().eval.mock.calls[0];
|
||||
// ARGV[2] = windowMs = 60 * 1000
|
||||
expect(evalCall[4]).toBe(60_000);
|
||||
// ARGV[5] = windowSeconds = 60
|
||||
expect(evalCall[7]).toBe(60);
|
||||
});
|
||||
|
||||
it('uses custom window seconds when specified', async () => {
|
||||
const redis = mockRedis({ evalResult: [1, 0] });
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector({ limit: 10, windowSeconds: 120 }),
|
||||
mockLogger(),
|
||||
);
|
||||
|
||||
await guard.canActivate(buildContext());
|
||||
|
||||
const evalCall = redis.getClient().eval.mock.calls[0];
|
||||
// ARGV[2] = windowMs = 120 * 1000
|
||||
expect(evalCall[4]).toBe(120_000);
|
||||
// ARGV[5] = windowSeconds = 120
|
||||
expect(evalCall[7]).toBe(120);
|
||||
});
|
||||
});
|
||||
|
||||
describe('logging', () => {
|
||||
it('logs warning when rate limit exceeded', async () => {
|
||||
const redis = mockRedis({ evalResult: [5, 30_000] });
|
||||
const logger = mockLogger();
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector({ limit: 5, keyStrategy: 'ip' }),
|
||||
logger,
|
||||
);
|
||||
|
||||
try {
|
||||
await guard.canActivate(buildContext({ method: 'POST', routePath: '/auth/login' }));
|
||||
} catch {
|
||||
// expected
|
||||
}
|
||||
|
||||
expect(logger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Endpoint rate limit exceeded'),
|
||||
'EndpointRateLimitGuard',
|
||||
);
|
||||
});
|
||||
|
||||
it('logs warning on Redis failure', async () => {
|
||||
const redis = mockRedis({ throwError: true });
|
||||
const logger = mockLogger();
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector({ limit: 5, keyStrategy: 'ip' }),
|
||||
logger,
|
||||
);
|
||||
|
||||
await guard.canActivate(buildContext());
|
||||
|
||||
expect(logger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Redis error'),
|
||||
'EndpointRateLimitGuard',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Retry-After computation', () => {
|
||||
it('ensures minimum Retry-After of 1 second', async () => {
|
||||
// retryAfterMs = 500 → should ceil to 1 second
|
||||
const redis = mockRedis({ evalResult: [5, 500] });
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector({ limit: 5, keyStrategy: 'ip' }),
|
||||
mockLogger(),
|
||||
);
|
||||
|
||||
const ctx = buildContext();
|
||||
try {
|
||||
await guard.canActivate(ctx);
|
||||
} catch (error) {
|
||||
const body = (error as HttpException).getResponse() as any;
|
||||
expect(body.retryAfter).toBe(1);
|
||||
}
|
||||
});
|
||||
|
||||
it('rounds up Retry-After to next second', async () => {
|
||||
// retryAfterMs = 15100 → should ceil to 16 seconds
|
||||
const redis = mockRedis({ evalResult: [5, 15_100] });
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector({ limit: 5, keyStrategy: 'ip' }),
|
||||
mockLogger(),
|
||||
);
|
||||
|
||||
const ctx = buildContext();
|
||||
try {
|
||||
await guard.canActivate(ctx);
|
||||
} catch (error) {
|
||||
const body = (error as HttpException).getResponse() as any;
|
||||
expect(body.retryAfter).toBe(16);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('endpoint path in key', () => {
|
||||
it('uses route path when available', async () => {
|
||||
const redis = mockRedis({ evalResult: [1, 0] });
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector({ limit: 100, keyStrategy: 'ip' }),
|
||||
mockLogger(),
|
||||
);
|
||||
|
||||
await guard.canActivate(
|
||||
buildContext({ routePath: '/payments/callback/:provider', path: '/payments/callback/vnpay' }),
|
||||
);
|
||||
|
||||
const evalCall = redis.getClient().eval.mock.calls[0];
|
||||
const key: string = evalCall[2];
|
||||
// Should use the route pattern, not the resolved path
|
||||
expect(key).toContain('/payments/callback/:provider');
|
||||
});
|
||||
|
||||
it('falls back to request path when no route defined', async () => {
|
||||
const redis = mockRedis({ evalResult: [1, 0] });
|
||||
const guard = new EndpointRateLimitGuard(
|
||||
redis,
|
||||
mockReflector({ limit: 30, keyStrategy: 'ip' }),
|
||||
mockLogger(),
|
||||
);
|
||||
|
||||
await guard.canActivate(buildContext({ path: '/search' }));
|
||||
|
||||
const evalCall = redis.getClient().eval.mock.calls[0];
|
||||
const key: string = evalCall[2];
|
||||
expect(key).toContain('/search');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,57 @@
|
||||
import { SetMetadata } from '@nestjs/common';
|
||||
|
||||
/**
|
||||
* Metadata key for per-endpoint rate limit configuration.
|
||||
* Read by EndpointRateLimitGuard via reflector.
|
||||
*/
|
||||
export const ENDPOINT_RATE_LIMIT_KEY = 'endpoint_rate_limit';
|
||||
|
||||
export interface EndpointRateLimitOptions {
|
||||
/**
|
||||
* Maximum number of requests allowed within the window.
|
||||
* @example 5 (for login: 5 attempts per minute)
|
||||
*/
|
||||
limit: number;
|
||||
|
||||
/**
|
||||
* Sliding window duration in seconds.
|
||||
* @default 60
|
||||
*/
|
||||
windowSeconds?: number;
|
||||
|
||||
/**
|
||||
* Key strategy for identifying the requester:
|
||||
* - `'ip'` — key by client IP (for unauthenticated endpoints like login/register)
|
||||
* - `'user'` — key by authenticated user ID (falls back to IP if unauthenticated)
|
||||
* @default 'ip'
|
||||
*/
|
||||
keyStrategy?: 'ip' | 'user';
|
||||
|
||||
/**
|
||||
* Whether ADMIN users bypass this rate limit.
|
||||
* @default true
|
||||
*/
|
||||
adminBypass?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Decorator to configure per-endpoint rate limiting with Redis sliding window.
|
||||
*
|
||||
* Works in conjunction with `EndpointRateLimitGuard`. Apply it to specific
|
||||
* controller methods or the entire controller class to enforce granular
|
||||
* rate limits keyed by IP or authenticated user + endpoint path.
|
||||
*
|
||||
* @example
|
||||
* // Login: 5 attempts per minute per IP
|
||||
* @EndpointRateLimit({ limit: 5, windowSeconds: 60, keyStrategy: 'ip' })
|
||||
*
|
||||
* @example
|
||||
* // Listing creation: 10 per minute per user
|
||||
* @EndpointRateLimit({ limit: 10, windowSeconds: 60, keyStrategy: 'user' })
|
||||
*
|
||||
* @example
|
||||
* // Search: 30 per minute per user, admin bypass enabled
|
||||
* @EndpointRateLimit({ limit: 30, windowSeconds: 60, keyStrategy: 'user', adminBypass: true })
|
||||
*/
|
||||
export const EndpointRateLimit = (options: EndpointRateLimitOptions) =>
|
||||
SetMetadata<string, EndpointRateLimitOptions>(ENDPOINT_RATE_LIMIT_KEY, options);
|
||||
@@ -0,0 +1,218 @@
|
||||
import {
|
||||
Injectable,
|
||||
type CanActivate,
|
||||
type ExecutionContext,
|
||||
HttpException,
|
||||
HttpStatus,
|
||||
} from '@nestjs/common';
|
||||
import { type Reflector } from '@nestjs/core';
|
||||
import { type Request, type Response } from 'express';
|
||||
import {
|
||||
ENDPOINT_RATE_LIMIT_KEY,
|
||||
type EndpointRateLimitOptions,
|
||||
} from '../decorators/endpoint-rate-limit.decorator';
|
||||
import { type LoggerService } from '../logger.service';
|
||||
import { type RedisService } from '../redis.service';
|
||||
|
||||
/** Express request extended with optional JWT user payload. */
|
||||
interface AuthenticatedRequest extends Request {
|
||||
user?: { sub: string; role: string };
|
||||
route?: { path: string };
|
||||
}
|
||||
|
||||
/**
|
||||
* Lua script implementing a true sliding-window rate limiter using a Redis sorted set.
|
||||
*
|
||||
* Algorithm:
|
||||
* 1. Remove all entries outside the current window (older than `now - windowMs`).
|
||||
* 2. Count remaining entries (= requests within the window).
|
||||
* 3. If under the limit, add the current request with score = current timestamp.
|
||||
* 4. Set TTL on the key to auto-expire after the window.
|
||||
* 5. Return [currentCount, oldestEntryAge] so we can compute Retry-After.
|
||||
*
|
||||
* This is more precise than a fixed-window counter because it considers
|
||||
* the exact timestamp of each request rather than bucketing.
|
||||
*
|
||||
* KEYS[1] = rate limit key
|
||||
* ARGV[1] = now (ms timestamp)
|
||||
* ARGV[2] = windowMs (window duration in ms)
|
||||
* ARGV[3] = limit (max requests)
|
||||
* ARGV[4] = unique request ID (to distinguish concurrent requests)
|
||||
* ARGV[5] = windowSeconds (for key TTL)
|
||||
*/
|
||||
const SLIDING_WINDOW_LUA = `
|
||||
local key = KEYS[1]
|
||||
local now = tonumber(ARGV[1])
|
||||
local windowMs = tonumber(ARGV[2])
|
||||
local limit = tonumber(ARGV[3])
|
||||
local requestId = ARGV[4]
|
||||
local windowSec = tonumber(ARGV[5])
|
||||
|
||||
-- Remove entries outside the window
|
||||
redis.call('ZREMRANGEBYSCORE', key, 0, now - windowMs)
|
||||
|
||||
-- Count current requests in window
|
||||
local current = redis.call('ZCARD', key)
|
||||
|
||||
if current < limit then
|
||||
-- Add this request
|
||||
redis.call('ZADD', key, now, requestId)
|
||||
-- Ensure key expires after the window
|
||||
redis.call('EXPIRE', key, windowSec + 1)
|
||||
return {current + 1, 0}
|
||||
else
|
||||
-- Rate limit exceeded — compute time until oldest entry expires
|
||||
local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
|
||||
local retryAfterMs = 0
|
||||
if #oldest >= 2 then
|
||||
retryAfterMs = tonumber(oldest[2]) + windowMs - now
|
||||
if retryAfterMs < 0 then retryAfterMs = 0 end
|
||||
end
|
||||
return {current, retryAfterMs}
|
||||
end
|
||||
`;
|
||||
|
||||
/** Monotonically increasing counter to generate unique request IDs. */
|
||||
let requestCounter = 0;
|
||||
|
||||
/**
|
||||
* Guard that enforces per-endpoint rate limiting using a Redis sorted-set sliding window.
|
||||
*
|
||||
* Only activates on routes decorated with `@EndpointRateLimit()`.
|
||||
* Routes without the decorator are not affected.
|
||||
*
|
||||
* Features:
|
||||
* - True sliding window (sorted set) for accurate rate measurement
|
||||
* - Keys by IP or authenticated user + endpoint path
|
||||
* - Proper 429 response with `Retry-After` header
|
||||
* - Rate limit headers (`X-RateLimit-Limit`, `X-RateLimit-Remaining`, `X-RateLimit-Reset`)
|
||||
* - Admin bypass support
|
||||
* - Fail-open on Redis errors (logs warning, allows request)
|
||||
*/
|
||||
@Injectable()
|
||||
export class EndpointRateLimitGuard implements CanActivate {
|
||||
constructor(
|
||||
private readonly redis: RedisService,
|
||||
private readonly reflector: Reflector,
|
||||
private readonly logger: LoggerService,
|
||||
) {}
|
||||
|
||||
async canActivate(context: ExecutionContext): Promise<boolean> {
|
||||
const options = this.reflector.getAllAndOverride<EndpointRateLimitOptions | undefined>(
|
||||
ENDPOINT_RATE_LIMIT_KEY,
|
||||
[context.getHandler(), context.getClass()],
|
||||
);
|
||||
|
||||
// No @EndpointRateLimit decorator → skip
|
||||
if (!options) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const request = context.switchToHttp().getRequest<AuthenticatedRequest>();
|
||||
const response: Response = context.switchToHttp().getResponse();
|
||||
const user = request.user;
|
||||
|
||||
// Admin bypass (default: true)
|
||||
const adminBypass = options.adminBypass !== false;
|
||||
if (adminBypass && user?.role === 'ADMIN') {
|
||||
return true;
|
||||
}
|
||||
|
||||
const windowSeconds = options.windowSeconds ?? 60;
|
||||
const windowMs = windowSeconds * 1000;
|
||||
const limit = options.limit;
|
||||
const keyStrategy = options.keyStrategy ?? 'ip';
|
||||
|
||||
// Build the rate limit key: endpoint_rate_limit:{method}:{path}:{identifier}
|
||||
const method = request.method;
|
||||
const routePath = request.route?.path ?? request.path;
|
||||
const identifier = this.resolveIdentifier(request, keyStrategy);
|
||||
|
||||
const key = `endpoint_rate_limit:${method}:${routePath}:${identifier}`;
|
||||
|
||||
try {
|
||||
const client = this.redis.getClient();
|
||||
const now = Date.now();
|
||||
const requestId = `${now}:${process.pid}:${++requestCounter}`;
|
||||
|
||||
const result = await client.eval(
|
||||
SLIDING_WINDOW_LUA,
|
||||
1,
|
||||
key,
|
||||
now,
|
||||
windowMs,
|
||||
limit,
|
||||
requestId,
|
||||
windowSeconds,
|
||||
) as [number, number];
|
||||
|
||||
const current = result[0];
|
||||
const retryAfterMs = result[1];
|
||||
|
||||
// Set rate limit headers for observability
|
||||
response.setHeader('X-RateLimit-Limit', limit);
|
||||
response.setHeader('X-RateLimit-Remaining', Math.max(0, limit - current));
|
||||
response.setHeader('X-RateLimit-Reset', windowSeconds);
|
||||
|
||||
if (current > limit || retryAfterMs > 0) {
|
||||
// The request was NOT added (limit reached) — return 429
|
||||
const retryAfterSeconds = Math.max(1, Math.ceil(retryAfterMs / 1000));
|
||||
response.setHeader('Retry-After', retryAfterSeconds);
|
||||
|
||||
this.logger.warn(
|
||||
`Endpoint rate limit exceeded: ${method} ${routePath}, ` +
|
||||
`key=${keyStrategy}:${identifier}, current=${current}/${limit}, ` +
|
||||
`retryAfter=${retryAfterSeconds}s`,
|
||||
'EndpointRateLimitGuard',
|
||||
);
|
||||
|
||||
throw new HttpException(
|
||||
{
|
||||
statusCode: HttpStatus.TOO_MANY_REQUESTS,
|
||||
message: 'Too many requests. Please try again later.',
|
||||
retryAfter: retryAfterSeconds,
|
||||
},
|
||||
HttpStatus.TOO_MANY_REQUESTS,
|
||||
);
|
||||
}
|
||||
|
||||
return true;
|
||||
} catch (error) {
|
||||
// Re-throw intentional 429 errors
|
||||
if (error instanceof HttpException && error.getStatus() === HttpStatus.TOO_MANY_REQUESTS) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
// Fail open on Redis errors
|
||||
this.logger.warn(
|
||||
`Endpoint rate limit check failed (Redis error), allowing request: ` +
|
||||
`${method} ${routePath}, key=${keyStrategy}:${identifier}, ` +
|
||||
`error=${error instanceof Error ? error.message : 'unknown'}`,
|
||||
'EndpointRateLimitGuard',
|
||||
);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve the identifier for the rate limit key based on the strategy.
|
||||
*/
|
||||
private resolveIdentifier(request: AuthenticatedRequest, strategy: 'ip' | 'user'): string {
|
||||
if (strategy === 'user') {
|
||||
const user = request.user;
|
||||
if (user?.sub) {
|
||||
return `user:${user.sub}`;
|
||||
}
|
||||
// Fall back to IP for unauthenticated requests on user-keyed endpoints
|
||||
}
|
||||
|
||||
// IP-based: extract real IP behind proxy
|
||||
const forwarded = request.headers['x-forwarded-for'];
|
||||
const ip =
|
||||
typeof forwarded === 'string'
|
||||
? (forwarded.split(',')[0]?.trim() ?? '127.0.0.1')
|
||||
: (request.ip ?? '127.0.0.1');
|
||||
|
||||
return `ip:${ip}`;
|
||||
}
|
||||
}
|
||||
@@ -20,6 +20,12 @@ export {
|
||||
type UserRateLimitOptions,
|
||||
} from './guards/user-rate-limit.guard';
|
||||
export { UserRateLimit } from './decorators/user-rate-limit.decorator';
|
||||
export {
|
||||
EndpointRateLimit,
|
||||
ENDPOINT_RATE_LIMIT_KEY,
|
||||
type EndpointRateLimitOptions,
|
||||
} from './decorators/endpoint-rate-limit.decorator';
|
||||
export { EndpointRateLimitGuard } from './guards/endpoint-rate-limit.guard';
|
||||
export { FileValidationPipe } from './pipes/file-validation.pipe';
|
||||
export type { FileValidationOptions, UploadedFile } from './pipes/file-validation.pipe';
|
||||
export { validateEnv, validateJwtSecret } from './env-validation';
|
||||
|
||||
Reference in New Issue
Block a user