feat(security): add per-endpoint API rate limiting with Redis sliding window

Implement @EndpointRateLimit() decorator and EndpointRateLimitGuard for
granular per-endpoint rate limiting using a Redis sorted-set sliding window.
This prevents brute force attacks on auth endpoints, replay attacks on
payment callbacks, and scraping on search endpoints.

Applied rate limits:
- /auth/login: 5 req/min per IP
- /auth/register: 3 req/min per IP
- /listings POST: 10 req/min per user
- /search: 30 req/min per user
- /payments/callback/*: 100 req/min per IP

Features:
- True sliding window (sorted set) for accurate rate measurement
- Configurable key strategy (IP or authenticated user)
- Admin bypass support (enabled by default)
- Fail-open on Redis errors
- Proper 429 response with Retry-After header
- Rate limit headers (X-RateLimit-Limit/Remaining/Reset)
- 22 unit tests covering all scenarios

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
Ho Ngoc Hai
2026-04-11 00:36:35 +07:00
parent f27b13f712
commit d824d16760
8 changed files with 777 additions and 4 deletions

View File

@@ -12,7 +12,7 @@ import { type CommandBus, type QueryBus } from '@nestjs/cqrs';
import { ApiTags, ApiOperation, ApiResponse, ApiBearerAuth, ApiBody } from '@nestjs/swagger';
import { Throttle } from '@nestjs/throttler';
import { type Request, type Response } from 'express';
import { UnauthorizedException } from '@modules/shared';
import { EndpointRateLimit, EndpointRateLimitGuard, UnauthorizedException } from '@modules/shared';
import { LoginUserCommand } from '../../application/commands/login-user/login-user.command';
import { RefreshTokenCommand } from '../../application/commands/refresh-token/refresh-token.command';
import { RegisterUserCommand } from '../../application/commands/register-user/register-user.command';
@@ -79,6 +79,8 @@ export class AuthController {
) {}
@Throttle({ default: { ttl: 3_600_000, limit: AUTH_RATE_LIMIT }, auth: { ttl: 3_600_000, limit: AUTH_RATE_LIMIT } })
@EndpointRateLimit({ limit: IS_TEST ? 10_000 : 3, windowSeconds: 60, keyStrategy: 'ip' })
@UseGuards(EndpointRateLimitGuard)
@Post('register')
@ApiOperation({ summary: 'Register a new user' })
@ApiResponse({ status: 201, description: 'User registered, auth cookies set' })
@@ -100,7 +102,8 @@ export class AuthController {
}
@Throttle({ default: { ttl: 3_600_000, limit: AUTH_RATE_LIMIT }, auth: { ttl: 3_600_000, limit: AUTH_RATE_LIMIT } })
@UseGuards(LocalAuthGuard)
@EndpointRateLimit({ limit: IS_TEST ? 10_000 : 5, windowSeconds: 60, keyStrategy: 'ip' })
@UseGuards(EndpointRateLimitGuard, LocalAuthGuard)
@Post('login')
@ApiOperation({ summary: 'Login with phone and password' })
@ApiBody({ type: LoginDto })

View File

@@ -22,7 +22,7 @@ import {
ApiParam,
} from '@nestjs/swagger';
import { type JwtPayload, CurrentUser, Roles, JwtAuthGuard, RolesGuard } from '@modules/auth';
import { FileValidationPipe, type UploadedFile as ValidatedFile } from '@modules/shared';
import { EndpointRateLimit, EndpointRateLimitGuard, FileValidationPipe, type UploadedFile as ValidatedFile } from '@modules/shared';
import { RequireQuota, QuotaGuard } from '@modules/subscriptions';
import { CreateListingCommand } from '../../application/commands/create-listing/create-listing.command';
import { type CreateListingResult } from '../../application/commands/create-listing/create-listing.handler';
@@ -53,7 +53,8 @@ export class ListingsController {
@ApiResponse({ status: 400, description: 'Validation error' })
@ApiResponse({ status: 401, description: 'Unauthorized' })
@ApiResponse({ status: 403, description: 'Quota exceeded' })
@UseGuards(JwtAuthGuard, QuotaGuard)
@UseGuards(EndpointRateLimitGuard, JwtAuthGuard, QuotaGuard)
@EndpointRateLimit({ limit: 10, windowSeconds: 60, keyStrategy: 'user' })
@RequireQuota('listings_created')
@Post()
async createListing(

View File

@@ -19,6 +19,7 @@ import {
import { Throttle } from '@nestjs/throttler';
import { type PaymentProvider } from '@prisma/client';
import { type JwtPayload, CurrentUser, Roles, JwtAuthGuard, RolesGuard } from '@modules/auth';
import { EndpointRateLimit, EndpointRateLimitGuard } from '@modules/shared';
import { CreatePaymentCommand } from '../../application/commands/create-payment/create-payment.command';
import { type CreatePaymentResult } from '../../application/commands/create-payment/create-payment.handler';
import { HandleCallbackCommand } from '../../application/commands/handle-callback/handle-callback.command';
@@ -72,6 +73,8 @@ export class PaymentsController {
@ApiResponse({ status: 201, description: 'Callback processed successfully' })
@ApiParam({ name: 'provider', enum: ['vnpay', 'momo', 'zalopay'] })
@Throttle({ 'payment-callback': { ttl: 60_000, limit: 20 } })
@EndpointRateLimit({ limit: 100, windowSeconds: 60, keyStrategy: 'ip', adminBypass: false })
@UseGuards(EndpointRateLimitGuard)
@Post('callback/:provider')
async handleCallback(
@Param('provider') provider: string,

View File

@@ -13,6 +13,7 @@ import {
ApiBearerAuth,
} from '@nestjs/swagger';
import { Roles, JwtAuthGuard, RolesGuard } from '@modules/auth';
import { EndpointRateLimit, EndpointRateLimitGuard } from '@modules/shared';
import { ReindexAllCommand } from '../../application/commands/reindex-all/reindex-all.command';
import { type ReindexResult } from '../../application/commands/reindex-all/reindex-all.handler';
import { GeoSearchQuery } from '../../application/queries/geo-search/geo-search.query';
@@ -29,6 +30,8 @@ export class SearchController {
private readonly queryBus: QueryBus,
) {}
@EndpointRateLimit({ limit: 30, windowSeconds: 60, keyStrategy: 'user' })
@UseGuards(EndpointRateLimitGuard)
@Get()
@ApiOperation({ summary: 'Search properties', description: 'Public full-text and faceted property search' })
@ApiResponse({ status: 200, description: 'Search results returned successfully' })
@@ -52,6 +55,8 @@ export class SearchController {
);
}
@EndpointRateLimit({ limit: 30, windowSeconds: 60, keyStrategy: 'user' })
@UseGuards(EndpointRateLimitGuard)
@Get('geo')
@ApiOperation({ summary: 'Geo search properties', description: 'Public geographic radius property search' })
@ApiResponse({ status: 200, description: 'Geo search results returned successfully' })

View File

@@ -0,0 +1,480 @@
import { HttpException, HttpStatus, type ExecutionContext } from '@nestjs/common';
import type { Reflector } from '@nestjs/core';
import { type EndpointRateLimitOptions } from '../decorators/endpoint-rate-limit.decorator';
import { EndpointRateLimitGuard } from '../guards/endpoint-rate-limit.guard';
// ── helpers ──────────────────────────────────────────────────────────────────
function mockRedis(
overrides: Partial<{
evalResult: [number, number];
throwError: boolean;
}> = {},
) {
const evalFn = overrides.throwError
? vi.fn().mockRejectedValue(new Error('Redis connection lost'))
: vi.fn().mockResolvedValue(overrides.evalResult ?? [1, 0]);
return {
getClient: vi.fn().mockReturnValue({ eval: evalFn }),
isAvailable: vi.fn().mockReturnValue(!overrides.throwError),
get: vi.fn(),
set: vi.fn(),
del: vi.fn(),
onModuleDestroy: vi.fn(),
} as any;
}
function mockLogger() {
return {
log: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
debug: vi.fn(),
} as any;
}
function mockReflector(options?: EndpointRateLimitOptions | undefined) {
return {
getAllAndOverride: vi.fn().mockReturnValue(options ?? undefined),
} as unknown as Reflector;
}
interface MockContextOptions {
user?: { sub: string; role: string } | null;
method?: string;
path?: string;
routePath?: string;
ip?: string;
forwarded?: string | string[];
}
function buildContext(opts: MockContextOptions = {}): ExecutionContext {
const headers: Record<string, string | string[] | undefined> = {};
if (opts.forwarded !== undefined) {
headers['x-forwarded-for'] = opts.forwarded;
}
const response = { setHeader: vi.fn() };
const user = 'user' in opts ? opts.user : null;
return {
switchToHttp: () => ({
getRequest: () => ({
user,
method: opts.method ?? 'GET',
path: opts.path ?? '/search',
route: opts.routePath ? { path: opts.routePath } : undefined,
ip: opts.ip ?? '192.168.1.1',
headers,
}),
getResponse: () => response,
}),
getHandler: () => ({ name: 'testHandler' }),
getClass: () => ({ name: 'TestController' }),
} as unknown as ExecutionContext;
}
// ── tests ────────────────────────────────────────────────────────────────────
describe('EndpointRateLimitGuard', () => {
describe('when no @EndpointRateLimit decorator is present', () => {
it('allows request (skips rate limiting)', async () => {
const redis = mockRedis();
const guard = new EndpointRateLimitGuard(redis, mockReflector(), mockLogger());
const result = await guard.canActivate(buildContext());
expect(result).toBe(true);
// Redis should not be called at all
expect(redis.getClient).not.toHaveBeenCalled();
});
});
describe('when @EndpointRateLimit is present', () => {
const defaultOptions: EndpointRateLimitOptions = {
limit: 5,
windowSeconds: 60,
keyStrategy: 'ip',
};
it('allows request within rate limit', async () => {
const redis = mockRedis({ evalResult: [1, 0] });
const guard = new EndpointRateLimitGuard(
redis,
mockReflector(defaultOptions),
mockLogger(),
);
const result = await guard.canActivate(buildContext());
expect(result).toBe(true);
});
it('rejects request with 429 when limit exceeded', async () => {
// current=5 (at limit), retryAfterMs=30000
const redis = mockRedis({ evalResult: [5, 30_000] });
const guard = new EndpointRateLimitGuard(
redis,
mockReflector(defaultOptions),
mockLogger(),
);
await expect(guard.canActivate(buildContext())).rejects.toThrow(HttpException);
try {
await guard.canActivate(buildContext());
} catch (error) {
expect(error).toBeInstanceOf(HttpException);
expect((error as HttpException).getStatus()).toBe(HttpStatus.TOO_MANY_REQUESTS);
const body = (error as HttpException).getResponse();
expect(body).toMatchObject({
statusCode: 429,
message: 'Too many requests. Please try again later.',
retryAfter: 30, // 30000ms → 30s
});
}
});
it('sets rate limit headers on allowed requests', async () => {
const redis = mockRedis({ evalResult: [3, 0] });
const guard = new EndpointRateLimitGuard(
redis,
mockReflector(defaultOptions),
mockLogger(),
);
const ctx = buildContext();
await guard.canActivate(ctx);
const response = ctx.switchToHttp().getResponse();
expect(response.setHeader).toHaveBeenCalledWith('X-RateLimit-Limit', 5);
expect(response.setHeader).toHaveBeenCalledWith('X-RateLimit-Remaining', 2); // 5 - 3
expect(response.setHeader).toHaveBeenCalledWith('X-RateLimit-Reset', 60);
});
it('sets Retry-After header when limit exceeded', async () => {
const redis = mockRedis({ evalResult: [5, 45_000] });
const guard = new EndpointRateLimitGuard(
redis,
mockReflector(defaultOptions),
mockLogger(),
);
const ctx = buildContext();
try {
await guard.canActivate(ctx);
} catch {
// expected
}
const response = ctx.switchToHttp().getResponse();
expect(response.setHeader).toHaveBeenCalledWith('Retry-After', 45);
});
});
describe('key strategy', () => {
it('uses IP-based key by default', async () => {
const redis = mockRedis({ evalResult: [1, 0] });
const guard = new EndpointRateLimitGuard(
redis,
mockReflector({ limit: 5, keyStrategy: 'ip' }),
mockLogger(),
);
await guard.canActivate(
buildContext({ ip: '10.0.0.1', method: 'POST', routePath: '/auth/login' }),
);
const evalCall = redis.getClient().eval.mock.calls[0];
const key: string = evalCall[2];
expect(key).toContain('ip:10.0.0.1');
expect(key).toContain('POST');
expect(key).toContain('/auth/login');
});
it('uses X-Forwarded-For IP when available', async () => {
const redis = mockRedis({ evalResult: [1, 0] });
const guard = new EndpointRateLimitGuard(
redis,
mockReflector({ limit: 5, keyStrategy: 'ip' }),
mockLogger(),
);
await guard.canActivate(
buildContext({ forwarded: '203.0.113.5, 70.41.3.18', ip: '127.0.0.1' }),
);
const evalCall = redis.getClient().eval.mock.calls[0];
const key: string = evalCall[2];
expect(key).toContain('ip:203.0.113.5');
});
it('uses authenticated user ID when strategy is "user"', async () => {
const redis = mockRedis({ evalResult: [1, 0] });
const guard = new EndpointRateLimitGuard(
redis,
mockReflector({ limit: 10, keyStrategy: 'user' }),
mockLogger(),
);
await guard.canActivate(
buildContext({
user: { sub: 'user-abc', role: 'BUYER' },
method: 'POST',
routePath: '/listings',
}),
);
const evalCall = redis.getClient().eval.mock.calls[0];
const key: string = evalCall[2];
expect(key).toContain('user:user-abc');
expect(key).toContain('POST');
expect(key).toContain('/listings');
});
it('falls back to IP when strategy is "user" but no authenticated user', async () => {
const redis = mockRedis({ evalResult: [1, 0] });
const guard = new EndpointRateLimitGuard(
redis,
mockReflector({ limit: 30, keyStrategy: 'user' }),
mockLogger(),
);
await guard.canActivate(
buildContext({ user: null, ip: '10.0.0.5', routePath: '/search' }),
);
const evalCall = redis.getClient().eval.mock.calls[0];
const key: string = evalCall[2];
expect(key).toContain('ip:10.0.0.5');
});
});
describe('admin bypass', () => {
it('bypasses rate limit for ADMIN users by default', async () => {
const redis = mockRedis();
const guard = new EndpointRateLimitGuard(
redis,
mockReflector({ limit: 5, keyStrategy: 'ip' }),
mockLogger(),
);
const result = await guard.canActivate(
buildContext({ user: { sub: 'admin-1', role: 'ADMIN' } }),
);
expect(result).toBe(true);
// Redis should not be called — admin bypasses
expect(redis.getClient).not.toHaveBeenCalled();
});
it('does not bypass ADMIN when adminBypass is false', async () => {
const redis = mockRedis({ evalResult: [1, 0] });
const guard = new EndpointRateLimitGuard(
redis,
mockReflector({ limit: 100, keyStrategy: 'ip', adminBypass: false }),
mockLogger(),
);
await guard.canActivate(
buildContext({ user: { sub: 'admin-1', role: 'ADMIN' } }),
);
// Redis WAS called — no bypass
expect(redis.getClient).toHaveBeenCalled();
});
it('does not bypass for non-ADMIN roles', async () => {
const redis = mockRedis({ evalResult: [1, 0] });
const guard = new EndpointRateLimitGuard(
redis,
mockReflector({ limit: 5, keyStrategy: 'ip' }),
mockLogger(),
);
await guard.canActivate(
buildContext({ user: { sub: 'buyer-1', role: 'BUYER' } }),
);
expect(redis.getClient).toHaveBeenCalled();
});
});
describe('fail-open behavior', () => {
it('allows request when Redis throws an error', async () => {
const redis = mockRedis({ throwError: true });
const logger = mockLogger();
const guard = new EndpointRateLimitGuard(
redis,
mockReflector({ limit: 5, keyStrategy: 'ip' }),
logger,
);
const result = await guard.canActivate(buildContext());
expect(result).toBe(true);
expect(logger.warn).toHaveBeenCalledWith(
expect.stringContaining('Endpoint rate limit check failed'),
'EndpointRateLimitGuard',
);
});
it('re-throws 429 errors (does not swallow intentional rejections)', async () => {
// Simulate a situation where Redis returns limit exceeded
const redis = mockRedis({ evalResult: [5, 30_000] });
const guard = new EndpointRateLimitGuard(
redis,
mockReflector({ limit: 5, keyStrategy: 'ip' }),
mockLogger(),
);
await expect(guard.canActivate(buildContext())).rejects.toThrow(HttpException);
});
});
describe('window seconds', () => {
it('defaults to 60 seconds when windowSeconds is not specified', async () => {
const redis = mockRedis({ evalResult: [1, 0] });
const guard = new EndpointRateLimitGuard(
redis,
mockReflector({ limit: 5 }), // no windowSeconds
mockLogger(),
);
await guard.canActivate(buildContext());
const evalCall = redis.getClient().eval.mock.calls[0];
// ARGV[2] = windowMs = 60 * 1000
expect(evalCall[4]).toBe(60_000);
// ARGV[5] = windowSeconds = 60
expect(evalCall[7]).toBe(60);
});
it('uses custom window seconds when specified', async () => {
const redis = mockRedis({ evalResult: [1, 0] });
const guard = new EndpointRateLimitGuard(
redis,
mockReflector({ limit: 10, windowSeconds: 120 }),
mockLogger(),
);
await guard.canActivate(buildContext());
const evalCall = redis.getClient().eval.mock.calls[0];
// ARGV[2] = windowMs = 120 * 1000
expect(evalCall[4]).toBe(120_000);
// ARGV[5] = windowSeconds = 120
expect(evalCall[7]).toBe(120);
});
});
describe('logging', () => {
it('logs warning when rate limit exceeded', async () => {
const redis = mockRedis({ evalResult: [5, 30_000] });
const logger = mockLogger();
const guard = new EndpointRateLimitGuard(
redis,
mockReflector({ limit: 5, keyStrategy: 'ip' }),
logger,
);
try {
await guard.canActivate(buildContext({ method: 'POST', routePath: '/auth/login' }));
} catch {
// expected
}
expect(logger.warn).toHaveBeenCalledWith(
expect.stringContaining('Endpoint rate limit exceeded'),
'EndpointRateLimitGuard',
);
});
it('logs warning on Redis failure', async () => {
const redis = mockRedis({ throwError: true });
const logger = mockLogger();
const guard = new EndpointRateLimitGuard(
redis,
mockReflector({ limit: 5, keyStrategy: 'ip' }),
logger,
);
await guard.canActivate(buildContext());
expect(logger.warn).toHaveBeenCalledWith(
expect.stringContaining('Redis error'),
'EndpointRateLimitGuard',
);
});
});
describe('Retry-After computation', () => {
it('ensures minimum Retry-After of 1 second', async () => {
// retryAfterMs = 500 → should ceil to 1 second
const redis = mockRedis({ evalResult: [5, 500] });
const guard = new EndpointRateLimitGuard(
redis,
mockReflector({ limit: 5, keyStrategy: 'ip' }),
mockLogger(),
);
const ctx = buildContext();
try {
await guard.canActivate(ctx);
} catch (error) {
const body = (error as HttpException).getResponse() as any;
expect(body.retryAfter).toBe(1);
}
});
it('rounds up Retry-After to next second', async () => {
// retryAfterMs = 15100 → should ceil to 16 seconds
const redis = mockRedis({ evalResult: [5, 15_100] });
const guard = new EndpointRateLimitGuard(
redis,
mockReflector({ limit: 5, keyStrategy: 'ip' }),
mockLogger(),
);
const ctx = buildContext();
try {
await guard.canActivate(ctx);
} catch (error) {
const body = (error as HttpException).getResponse() as any;
expect(body.retryAfter).toBe(16);
}
});
});
describe('endpoint path in key', () => {
it('uses route path when available', async () => {
const redis = mockRedis({ evalResult: [1, 0] });
const guard = new EndpointRateLimitGuard(
redis,
mockReflector({ limit: 100, keyStrategy: 'ip' }),
mockLogger(),
);
await guard.canActivate(
buildContext({ routePath: '/payments/callback/:provider', path: '/payments/callback/vnpay' }),
);
const evalCall = redis.getClient().eval.mock.calls[0];
const key: string = evalCall[2];
// Should use the route pattern, not the resolved path
expect(key).toContain('/payments/callback/:provider');
});
it('falls back to request path when no route defined', async () => {
const redis = mockRedis({ evalResult: [1, 0] });
const guard = new EndpointRateLimitGuard(
redis,
mockReflector({ limit: 30, keyStrategy: 'ip' }),
mockLogger(),
);
await guard.canActivate(buildContext({ path: '/search' }));
const evalCall = redis.getClient().eval.mock.calls[0];
const key: string = evalCall[2];
expect(key).toContain('/search');
});
});
});

View File

@@ -0,0 +1,57 @@
import { SetMetadata } from '@nestjs/common';
/**
* Metadata key for per-endpoint rate limit configuration.
* Read by EndpointRateLimitGuard via reflector.
*/
export const ENDPOINT_RATE_LIMIT_KEY = 'endpoint_rate_limit';
export interface EndpointRateLimitOptions {
/**
* Maximum number of requests allowed within the window.
* @example 5 (for login: 5 attempts per minute)
*/
limit: number;
/**
* Sliding window duration in seconds.
* @default 60
*/
windowSeconds?: number;
/**
* Key strategy for identifying the requester:
* - `'ip'` — key by client IP (for unauthenticated endpoints like login/register)
* - `'user'` — key by authenticated user ID (falls back to IP if unauthenticated)
* @default 'ip'
*/
keyStrategy?: 'ip' | 'user';
/**
* Whether ADMIN users bypass this rate limit.
* @default true
*/
adminBypass?: boolean;
}
/**
* Decorator to configure per-endpoint rate limiting with Redis sliding window.
*
* Works in conjunction with `EndpointRateLimitGuard`. Apply it to specific
* controller methods or the entire controller class to enforce granular
* rate limits keyed by IP or authenticated user + endpoint path.
*
* @example
* // Login: 5 attempts per minute per IP
* @EndpointRateLimit({ limit: 5, windowSeconds: 60, keyStrategy: 'ip' })
*
* @example
* // Listing creation: 10 per minute per user
* @EndpointRateLimit({ limit: 10, windowSeconds: 60, keyStrategy: 'user' })
*
* @example
* // Search: 30 per minute per user, admin bypass enabled
* @EndpointRateLimit({ limit: 30, windowSeconds: 60, keyStrategy: 'user', adminBypass: true })
*/
export const EndpointRateLimit = (options: EndpointRateLimitOptions) =>
SetMetadata<string, EndpointRateLimitOptions>(ENDPOINT_RATE_LIMIT_KEY, options);

View File

@@ -0,0 +1,218 @@
import {
Injectable,
type CanActivate,
type ExecutionContext,
HttpException,
HttpStatus,
} from '@nestjs/common';
import { type Reflector } from '@nestjs/core';
import { type Request, type Response } from 'express';
import {
ENDPOINT_RATE_LIMIT_KEY,
type EndpointRateLimitOptions,
} from '../decorators/endpoint-rate-limit.decorator';
import { type LoggerService } from '../logger.service';
import { type RedisService } from '../redis.service';
/** Express request extended with optional JWT user payload. */
interface AuthenticatedRequest extends Request {
user?: { sub: string; role: string };
route?: { path: string };
}
/**
* Lua script implementing a true sliding-window rate limiter using a Redis sorted set.
*
* Algorithm:
* 1. Remove all entries outside the current window (older than `now - windowMs`).
* 2. Count remaining entries (= requests within the window).
* 3. If under the limit, add the current request with score = current timestamp.
* 4. Set TTL on the key to auto-expire after the window.
* 5. Return [currentCount, oldestEntryAge] so we can compute Retry-After.
*
* This is more precise than a fixed-window counter because it considers
* the exact timestamp of each request rather than bucketing.
*
* KEYS[1] = rate limit key
* ARGV[1] = now (ms timestamp)
* ARGV[2] = windowMs (window duration in ms)
* ARGV[3] = limit (max requests)
* ARGV[4] = unique request ID (to distinguish concurrent requests)
* ARGV[5] = windowSeconds (for key TTL)
*/
const SLIDING_WINDOW_LUA = `
local key = KEYS[1]
local now = tonumber(ARGV[1])
local windowMs = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
local requestId = ARGV[4]
local windowSec = tonumber(ARGV[5])
-- Remove entries outside the window
redis.call('ZREMRANGEBYSCORE', key, 0, now - windowMs)
-- Count current requests in window
local current = redis.call('ZCARD', key)
if current < limit then
-- Add this request
redis.call('ZADD', key, now, requestId)
-- Ensure key expires after the window
redis.call('EXPIRE', key, windowSec + 1)
return {current + 1, 0}
else
-- Rate limit exceeded — compute time until oldest entry expires
local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
local retryAfterMs = 0
if #oldest >= 2 then
retryAfterMs = tonumber(oldest[2]) + windowMs - now
if retryAfterMs < 0 then retryAfterMs = 0 end
end
return {current, retryAfterMs}
end
`;
/** Monotonically increasing counter to generate unique request IDs. */
let requestCounter = 0;
/**
* Guard that enforces per-endpoint rate limiting using a Redis sorted-set sliding window.
*
* Only activates on routes decorated with `@EndpointRateLimit()`.
* Routes without the decorator are not affected.
*
* Features:
* - True sliding window (sorted set) for accurate rate measurement
* - Keys by IP or authenticated user + endpoint path
* - Proper 429 response with `Retry-After` header
* - Rate limit headers (`X-RateLimit-Limit`, `X-RateLimit-Remaining`, `X-RateLimit-Reset`)
* - Admin bypass support
* - Fail-open on Redis errors (logs warning, allows request)
*/
@Injectable()
export class EndpointRateLimitGuard implements CanActivate {
constructor(
private readonly redis: RedisService,
private readonly reflector: Reflector,
private readonly logger: LoggerService,
) {}
async canActivate(context: ExecutionContext): Promise<boolean> {
const options = this.reflector.getAllAndOverride<EndpointRateLimitOptions | undefined>(
ENDPOINT_RATE_LIMIT_KEY,
[context.getHandler(), context.getClass()],
);
// No @EndpointRateLimit decorator → skip
if (!options) {
return true;
}
const request = context.switchToHttp().getRequest<AuthenticatedRequest>();
const response: Response = context.switchToHttp().getResponse();
const user = request.user;
// Admin bypass (default: true)
const adminBypass = options.adminBypass !== false;
if (adminBypass && user?.role === 'ADMIN') {
return true;
}
const windowSeconds = options.windowSeconds ?? 60;
const windowMs = windowSeconds * 1000;
const limit = options.limit;
const keyStrategy = options.keyStrategy ?? 'ip';
// Build the rate limit key: endpoint_rate_limit:{method}:{path}:{identifier}
const method = request.method;
const routePath = request.route?.path ?? request.path;
const identifier = this.resolveIdentifier(request, keyStrategy);
const key = `endpoint_rate_limit:${method}:${routePath}:${identifier}`;
try {
const client = this.redis.getClient();
const now = Date.now();
const requestId = `${now}:${process.pid}:${++requestCounter}`;
const result = await client.eval(
SLIDING_WINDOW_LUA,
1,
key,
now,
windowMs,
limit,
requestId,
windowSeconds,
) as [number, number];
const current = result[0];
const retryAfterMs = result[1];
// Set rate limit headers for observability
response.setHeader('X-RateLimit-Limit', limit);
response.setHeader('X-RateLimit-Remaining', Math.max(0, limit - current));
response.setHeader('X-RateLimit-Reset', windowSeconds);
if (current > limit || retryAfterMs > 0) {
// The request was NOT added (limit reached) — return 429
const retryAfterSeconds = Math.max(1, Math.ceil(retryAfterMs / 1000));
response.setHeader('Retry-After', retryAfterSeconds);
this.logger.warn(
`Endpoint rate limit exceeded: ${method} ${routePath}, ` +
`key=${keyStrategy}:${identifier}, current=${current}/${limit}, ` +
`retryAfter=${retryAfterSeconds}s`,
'EndpointRateLimitGuard',
);
throw new HttpException(
{
statusCode: HttpStatus.TOO_MANY_REQUESTS,
message: 'Too many requests. Please try again later.',
retryAfter: retryAfterSeconds,
},
HttpStatus.TOO_MANY_REQUESTS,
);
}
return true;
} catch (error) {
// Re-throw intentional 429 errors
if (error instanceof HttpException && error.getStatus() === HttpStatus.TOO_MANY_REQUESTS) {
throw error;
}
// Fail open on Redis errors
this.logger.warn(
`Endpoint rate limit check failed (Redis error), allowing request: ` +
`${method} ${routePath}, key=${keyStrategy}:${identifier}, ` +
`error=${error instanceof Error ? error.message : 'unknown'}`,
'EndpointRateLimitGuard',
);
return true;
}
}
/**
* Resolve the identifier for the rate limit key based on the strategy.
*/
private resolveIdentifier(request: AuthenticatedRequest, strategy: 'ip' | 'user'): string {
if (strategy === 'user') {
const user = request.user;
if (user?.sub) {
return `user:${user.sub}`;
}
// Fall back to IP for unauthenticated requests on user-keyed endpoints
}
// IP-based: extract real IP behind proxy
const forwarded = request.headers['x-forwarded-for'];
const ip =
typeof forwarded === 'string'
? (forwarded.split(',')[0]?.trim() ?? '127.0.0.1')
: (request.ip ?? '127.0.0.1');
return `ip:${ip}`;
}
}

View File

@@ -20,6 +20,12 @@ export {
type UserRateLimitOptions,
} from './guards/user-rate-limit.guard';
export { UserRateLimit } from './decorators/user-rate-limit.decorator';
export {
EndpointRateLimit,
ENDPOINT_RATE_LIMIT_KEY,
type EndpointRateLimitOptions,
} from './decorators/endpoint-rate-limit.decorator';
export { EndpointRateLimitGuard } from './guards/endpoint-rate-limit.guard';
export { FileValidationPipe } from './pipes/file-validation.pipe';
export type { FileValidationOptions, UploadedFile } from './pipes/file-validation.pipe';
export { validateEnv, validateJwtSecret } from './env-validation';