Middleware¶
Middleware is the right place for cross-cutting behavior such as logging, authentication, or request shaping. Each language binding follows its own conventions, but all of them call into the same Rust pipeline.
Add middleware¶
<?php

declare(strict_types=1);

use Spikard\App;
use Spikard\Config\HookResult;
use Spikard\Config\LifecycleHooks;
use Spikard\Config\ServerConfig;
use Spikard\Http\Request;

/**
 * Request logger: writes "[timestamp] METHOD path" to the PHP error log for
 * every incoming request, then lets the request continue down the pipeline.
 */
$logRequest = static function (Request $request): HookResult {
    $line = sprintf("[%s] %s %s", date('Y-m-d H:i:s'), $request->method, $request->path);
    error_log($line);

    return HookResult::continue();
};

$hooks = LifecycleHooks::builder()
    ->withOnRequest($logRequest)
    ->build();

$app = (new App(new ServerConfig(port: 8000)))->withLifecycleHooks($hooks);
Patterns¶
Auth guards¶
Check headers/cookies, enrich context with the authenticated principal, and short-circuit on failures.
from spikard import Spikard, HTTPError
import jwt

app = Spikard()

# Placeholder secret: load it from configuration / secret storage in real code.
JWT_SECRET = "your-secret-key"


@app.on_request
async def auth_guard(request):
    """Authenticate the request with a Bearer JWT and enrich its context.

    Reads ``request["headers"]["authorization"]``, verifies the token with
    HS256, and stores ``user_id`` (from the ``sub`` claim) and ``roles`` in
    ``request["context"]``.

    Raises:
        HTTPError: 401 when the header is missing or malformed, when the token
            fails verification, or when it carries no ``sub`` claim.
    """
    auth_header = request.get("headers", {}).get("authorization", "")
    if not auth_header.startswith("Bearer "):
        raise HTTPError(401, "Missing or invalid authorization header")
    token = auth_header[len("Bearer "):]

    try:
        # Pin the algorithm and require "sub". A token without a subject is
        # then rejected as 401 instead of crashing later with a KeyError (500).
        payload = jwt.decode(
            token,
            JWT_SECRET,
            algorithms=["HS256"],
            options={"require": ["sub"]},
        )
    except jwt.InvalidTokenError:
        raise HTTPError(401, "Invalid token") from None

    # Enrich context with the authenticated user.
    request["context"] = request.get("context", {})
    request["context"]["user_id"] = payload["sub"]
    request["context"]["roles"] = payload.get("roles", [])
    return request
import { Spikard, type Request, HTTPError } from "spikard";
import * as jwt from "jsonwebtoken";

interface JWTPayload {
  sub: string;
  iat: number;
  exp: number;
  roles?: string[];
}

// Placeholder secret: load it from configuration / secret storage in real code.
const JWT_SECRET = "your-secret-key";

const app = new Spikard();

/**
 * Authenticate requests with a Bearer JWT and enrich `request.context` with
 * `userId` (from `sub`) and `roles`.
 * Throws HTTPError(401) when the header is missing or malformed, the token
 * fails verification, or the token has no string `sub` claim.
 */
app.onRequest(async (request: Request): Promise<Request> => {
  const authHeader = request.headers?.authorization || "";
  if (!authHeader.startsWith("Bearer ")) {
    throw new HTTPError(401, "Missing or invalid authorization header");
  }
  const token = authHeader.slice("Bearer ".length);

  let payload: JWTPayload;
  try {
    // Pin the algorithm, as the other language examples do, so the verifier
    // never accepts a token signed with an unexpected algorithm.
    payload = jwt.verify(token, JWT_SECRET, { algorithms: ["HS256"] }) as JWTPayload;
  } catch {
    throw new HTTPError(401, "Invalid token");
  }
  if (typeof payload.sub !== "string") {
    throw new HTTPError(401, "Invalid token");
  }

  // Enrich context with the authenticated user.
  request.context = request.context || {};
  request.context.userId = payload.sub;
  request.context.roles = payload.roles || [];
  return request;
});
require 'spikard'
require 'jwt'

app = Spikard::App.new

# Authenticate each request from its "Authorization: Bearer <jwt>" header and
# record the caller's identity in request[:context]. Any failure is a 401.
app.on_request do |request|
  auth_header = request.dig(:headers, :authorization) || ""
  unless auth_header.start_with?("Bearer ")
    raise Spikard::HTTPError.new(401, "Missing or invalid authorization header")
  end

  token = auth_header.delete_prefix("Bearer ")

  begin
    # JWT.decode returns [payload, header]; only the payload is needed.
    claims, _header = JWT.decode(token, "your-secret-key", true, { algorithm: 'HS256' })
    context = (request[:context] ||= {})
    context[:user_id] = claims["sub"]
    context[:roles] = claims["roles"] || []
    request
  rescue JWT::DecodeError
    raise Spikard::HTTPError.new(401, "Invalid token")
  end
end
<?php

declare(strict_types=1);

use Firebase\JWT\JWT;
use Firebase\JWT\Key;
use Spikard\App;
use Spikard\Config\HookResult;
use Spikard\Config\LifecycleHooks;
use Spikard\Config\ServerConfig;
use Spikard\Http\Request;
use Spikard\Http\Response;

/** Short-circuit the pipeline with a JSON 401 response carrying $message. */
$unauthorized = static fn (string $message): HookResult =>
    HookResult::shortCircuit(Response::json(['error' => $message], 401));

$hooks = LifecycleHooks::builder()
    ->withPreHandler(function (Request $request) use ($unauthorized): HookResult {
        // Expect "Authorization: Bearer <jwt>".
        $authHeader = $request->headers['authorization'] ?? '';
        if (!str_starts_with($authHeader, 'Bearer ')) {
            return $unauthorized('Missing or invalid authorization header');
        }

        try {
            $claims = JWT::decode(
                substr($authHeader, strlen('Bearer ')),
                new Key('your-secret-key', 'HS256')
            );
            // Expose the authenticated principal to downstream handlers.
            $request->context['user_id'] = $claims->sub;
            $request->context['roles'] = $claims->roles ?? [];
        } catch (\Exception $e) {
            return $unauthorized('Invalid token');
        }

        return HookResult::continue();
    })
    ->build();

$app = (new App(new ServerConfig(port: 8000)))
    ->withLifecycleHooks($hooks);
use axum::{
extract::Request,
http::{header, StatusCode},
middleware::Next,
response::Response,
Extension,
};
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims {
pub sub: String,
pub roles: Vec<String>,
pub exp: usize,
}
#[derive(Debug, Clone)]
pub struct AuthContext {
pub user_id: String,
pub roles: Vec<String>,
}
pub async fn auth_guard(
mut request: Request,
next: Next,
) -> Result<Response, StatusCode> {
// Extract token from Authorization header
let auth_header = request
.headers()
.get(header::AUTHORIZATION)
.and_then(|h| h.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;
if !auth_header.starts_with("Bearer ") {
return Err(StatusCode::UNAUTHORIZED);
}
let token = &auth_header[7..]; // Strip "Bearer "
// Verify and decode JWT
let token_data = decode::<Claims>(
token,
&DecodingKey::from_secret(b"your-secret-key"),
&Validation::default(),
)
.map_err(|_| StatusCode::UNAUTHORIZED)?;
// Enrich context with authenticated user
let auth_ctx = AuthContext {
user_id: token_data.claims.sub,
roles: token_data.claims.roles,
};
request.extensions_mut().insert(auth_ctx);
Ok(next.run(request).await)
}
// Usage with App:
// app.layer(axum::middleware::from_fn(auth_guard));
Observability¶
Emit structured logs and traces; forward request IDs and correlation IDs for distributed tracing.
import logging
import uuid
from spikard import Spikard

app = Spikard()


@app.on_request
async def observability_middleware(request):
    """Attach a request ID to the request and log a ``request_started`` event.

    Reuses the incoming ``x-request-id`` header when it is non-empty and
    otherwise generates a UUID4. Handlers can read the ID from
    ``request["context"]["request_id"]``.
    """
    headers = request.get("headers", {})
    # Using `or` (instead of a .get default) also replaces an empty header
    # value, and generates a UUID only when one is actually needed.
    request_id = headers.get("x-request-id") or str(uuid.uuid4())

    # Inject into context for handlers to use.
    request["context"] = request.get("context", {})
    request["context"]["request_id"] = request_id

    logging.info("request_started", extra={
        "request_id": request_id,
        "method": request["method"],
        "path": request["path"],
        "user_agent": headers.get("user-agent"),
    })
    return request


@app.on_response
async def response_logger(response):
    """Log a ``request_completed`` event and echo the request ID header."""
    request_id = response.get("context", {}).get("request_id")
    logging.info("request_completed", extra={
        "request_id": request_id,
        "status": response["status"],
        "duration_ms": response.get("duration_ms"),
    })
    # Only propagate the header when an ID is known; None is not a valid
    # HTTP header value.
    if request_id is not None:
        response["headers"] = response.get("headers", {})
        response["headers"]["X-Request-ID"] = request_id
    return response
import { Spikard, type Request, type Response } from "spikard";
import { v4 as uuidv4 } from "uuid";

const app = new Spikard();

/**
 * Attach a request ID (the incoming `x-request-id` or a fresh UUIDv4) to
 * `request.context.requestId` and log a structured `request_started` event.
 */
app.onRequest(async (request: Request): Promise<Request> => {
  const requestId = request.headers?.["x-request-id"] || uuidv4();

  // Inject into context for handlers to use.
  request.context = request.context || {};
  request.context.requestId = requestId;

  console.log(JSON.stringify({
    event: "request_started",
    request_id: requestId,
    method: request.method,
    path: request.path,
    user_agent: request.headers?.["user-agent"],
  }));
  return request;
});

/**
 * Log a structured `request_completed` event and echo the request ID in the
 * `X-Request-ID` response header when it is known.
 */
app.onResponse(async (response: Response): Promise<Response> => {
  const requestId = response.context?.requestId;
  console.log(JSON.stringify({
    event: "request_completed",
    request_id: requestId,
    status: response.status,
    duration_ms: response.durationMs,
  }));
  // Don't emit an "undefined" header value when no ID reached the response.
  if (requestId !== undefined && requestId !== null) {
    response.headers = response.headers || {};
    response.headers["X-Request-ID"] = requestId;
  }
  return response;
});
require 'spikard'
require 'securerandom'
require 'logger'
require 'json' # Hash#to_json is only defined once 'json' is loaded

app = Spikard::App.new
logger = Logger.new(STDOUT)

# Propagate (or generate) a request ID, expose it to handlers via
# request[:context][:request_id], and log a structured "request_started" line.
app.on_request do |request|
  incoming_id = request.dig(:headers, :'x-request-id')
  # Treat an empty header like a missing one so every request gets a usable ID.
  request_id = incoming_id.to_s.empty? ? SecureRandom.uuid : incoming_id

  request[:context] ||= {}
  request[:context][:request_id] = request_id

  logger.info({
    event: 'request_started',
    request_id: request_id,
    method: request[:method],
    path: request[:path],
    user_agent: request.dig(:headers, :'user-agent')
  }.to_json)
  request
end

# Log a structured "request_completed" line and echo the request ID in the
# X-Request-ID response header when one is known.
app.on_response do |response|
  request_id = response.dig(:context, :request_id)
  logger.info({
    event: 'request_completed',
    request_id: request_id,
    status: response[:status],
    duration_ms: response[:duration_ms]
  }.to_json)
  # A nil header value is not valid HTTP, so only propagate a real ID.
  unless request_id.nil?
    response[:headers] ||= {}
    response[:headers][:'X-Request-ID'] = request_id
  end
  response
end
<?php

declare(strict_types=1);

use Monolog\Handler\StreamHandler;
use Monolog\Logger;
use Spikard\App;
use Spikard\Config\HookResult;
use Spikard\Config\LifecycleHooks;
use Spikard\Config\ServerConfig;
use Spikard\Http\Request;
use Spikard\Http\Response;

// PSR-3 compatible logger (Monolog)
$logger = new Logger('app');
$logger->pushHandler(new StreamHandler('php://stdout', Logger::INFO));

/**
 * Generate a random RFC 4122 version-4 UUID, like the request IDs in the
 * other language examples. uniqid() is time-based and, per the PHP manual,
 * not guaranteed to be unique, which makes it a poor correlation ID.
 */
$generateRequestId = static function (): string {
    $bytes = random_bytes(16);
    $bytes[6] = chr((ord($bytes[6]) & 0x0f) | 0x40); // version 4
    $bytes[8] = chr((ord($bytes[8]) & 0x3f) | 0x80); // RFC 4122 variant

    return vsprintf('%s%s-%s-%s-%s-%s%s%s', str_split(bin2hex($bytes), 4));
};

$hooks = LifecycleHooks::builder()
    ->withOnRequest(function (Request $request) use ($logger, $generateRequestId): HookResult {
        // Reuse the caller's request ID unless it is missing or empty.
        $incomingId = $request->headers['x-request-id'] ?? '';
        $requestId = is_string($incomingId) && $incomingId !== ''
            ? $incomingId
            : $generateRequestId();

        // Inject into context for handlers to use.
        $request->context['request_id'] = $requestId;

        $logger->info('request_started', [
            'request_id' => $requestId,
            'method' => $request->method,
            'path' => $request->path,
            'user_agent' => $request->headers['user-agent'] ?? null,
        ]);
        return HookResult::continue();
    })
    ->withOnResponse(function (Request $request, Response $response) use ($logger): HookResult {
        $requestId = $request->context['request_id'] ?? 'unknown';
        $logger->info('request_completed', [
            'request_id' => $requestId,
            'status' => $response->status,
            'duration_ms' => $response->durationMs ?? null,
        ]);
        // Propagate request ID in response headers.
        $response->headers['X-Request-ID'] = $requestId;
        return HookResult::continue();
    })
    ->build();

$app = (new App(new ServerConfig(port: 8000)))
    ->withLifecycleHooks($hooks);
use axum::{
    extract::Request,
    http::HeaderValue,
    middleware::Next,
    response::Response,
};
use tracing::{info, span, Instrument, Level};
use uuid::Uuid;

/// Axum middleware that:
/// - tags every request with a request ID;
/// - runs the downstream service inside an `http_request` tracing span;
/// - logs start and completion (status and elapsed milliseconds);
/// - echoes the ID back in the `x-request-id` response header.
pub async fn observability_middleware(request: Request, next: Next) -> Response {
    // Reuse the caller's ID when the header is readable as a string;
    // otherwise generate a fresh v4 UUID.
    let incoming_id = request
        .headers()
        .get("x-request-id")
        .and_then(|value| value.to_str().ok());
    let request_id = match incoming_id {
        Some(id) => id.to_owned(),
        None => Uuid::new_v4().to_string(),
    };

    let user_agent = request
        .headers()
        .get("user-agent")
        .and_then(|value| value.to_str().ok())
        .unwrap_or("unknown");

    // Span carrying the request metadata for distributed tracing.
    let request_span = span!(
        Level::INFO,
        "http_request",
        request_id = %request_id,
        method = %request.method(),
        path = %request.uri().path(),
        user_agent = user_agent,
    );

    info!(parent: &request_span, "request_started");

    let started_at = std::time::Instant::now();
    let mut response = next.run(request).instrument(request_span.clone()).await;
    let elapsed_ms = started_at.elapsed().as_millis();

    info!(
        parent: &request_span,
        status = %response.status().as_u16(),
        duration_ms = %elapsed_ms,
        "request_completed"
    );

    let header_value =
        HeaderValue::from_str(&request_id).unwrap_or_else(|_| HeaderValue::from_static(""));
    response.headers_mut().insert("x-request-id", header_value);

    response
}

// Usage with App:
// app.layer(axum::middleware::from_fn(observability_middleware));
Request shaping¶
Normalize headers, coerce parameters, inject tenant/feature flags, or apply rate limiting.
from spikard import Spikard, HTTPError
import gzip
import time
from collections import defaultdict

app = Spikard()

RATE_LIMIT_REQUESTS = 100        # requests allowed per window per client IP
RATE_LIMIT_WINDOW_SECONDS = 60
COMPRESSION_THRESHOLD_BYTES = 1024

# Simple in-memory rate limiter (use Redis in production). Entries for IPs
# that stop sending requests are never evicted.
rate_limits = defaultdict(list)


@app.on_request
async def request_shaper(request):
    """Rate-limit, normalize headers, and inject tenant/feature-flag context.

    Raises:
        HTTPError: 429 when the client IP exceeded RATE_LIMIT_REQUESTS within
            the last RATE_LIMIT_WINDOW_SECONDS.
    """
    # 1. Rate limiting with a sliding window per client IP.
    client_ip = request.get("client_ip", "unknown")
    now = time.time()
    recent = [
        ts for ts in rate_limits[client_ip] if now - ts < RATE_LIMIT_WINDOW_SECONDS
    ]
    rate_limits[client_ip] = recent
    if len(recent) >= RATE_LIMIT_REQUESTS:
        raise HTTPError(429, "Rate limit exceeded")
    recent.append(now)

    # 2. Normalize headers (lowercase keys).
    if "headers" in request:
        request["headers"] = {
            k.lower(): v for k, v in request["headers"].items()
        }

    # 3. Inject tenant from subdomain ("acme.example.com" -> "acme").
    host = request.get("headers", {}).get("host", "")
    tenant = host.split(".")[0] if "." in host else "default"
    request["context"] = request.get("context", {})
    request["context"]["tenant"] = tenant

    # 4. Feature flags from the comma-separated "features" query parameter.
    feature_flags = request.get("query", {}).get("features", "").split(",")
    request["context"]["features"] = set(f for f in feature_flags if f)
    return request


@app.on_response
async def compress_response(response):
    """Gzip str/bytes bodies larger than COMPRESSION_THRESHOLD_BYTES.

    NOTE(review): this hook only sees the response, so it cannot check the
    client's Accept-Encoding. Prefer server-level compression when available.
    """
    headers = response.get("headers", {})
    # Never compress a body that is already encoded.
    if any(k.lower() == "content-encoding" for k in headers):
        return response

    body = response.get("body", "")
    if isinstance(body, str):
        body = body.encode()
    # Measure bytes, not characters, and skip non-text bodies (e.g. dicts)
    # instead of crashing on them.
    if isinstance(body, (bytes, bytearray)) and len(body) > COMPRESSION_THRESHOLD_BYTES:
        response["body"] = gzip.compress(body)
        response["headers"] = headers
        headers["content-encoding"] = "gzip"
    return response
import { Spikard, type Request, type Response, HTTPError } from "spikard";
import * as zlib from "zlib";

const app = new Spikard();

const RATE_LIMIT_REQUESTS = 100; // requests allowed per window per client IP
const RATE_LIMIT_WINDOW_MS = 60_000;
const COMPRESSION_THRESHOLD_BYTES = 1024;

// Simple in-memory rate limiter (use Redis in production). Entries for IPs
// that stop sending requests are never evicted.
const rateLimits = new Map<string, number[]>();

/**
 * Rate-limit per client IP (throws HTTPError(429) when exceeded), then
 * normalize headers and inject tenant and feature-flag context.
 */
app.onRequest(async (request: Request): Promise<Request> => {
  // 1. Rate limiting with a sliding window per client IP.
  const clientIp = request.clientIp || "unknown";
  const now = Date.now();
  const timestamps = (rateLimits.get(clientIp) || [])
    .filter(ts => now - ts < RATE_LIMIT_WINDOW_MS);
  if (timestamps.length >= RATE_LIMIT_REQUESTS) {
    rateLimits.set(clientIp, timestamps);
    throw new HTTPError(429, "Rate limit exceeded");
  }
  timestamps.push(now);
  rateLimits.set(clientIp, timestamps);

  // 2. Normalize headers (lowercase keys).
  if (request.headers) {
    const normalized: Record<string, string> = {};
    for (const [key, value] of Object.entries(request.headers)) {
      normalized[key.toLowerCase()] = value;
    }
    request.headers = normalized;
  }

  // 3. Inject tenant from subdomain ("acme.example.com" -> "acme").
  const host = request.headers?.host || "";
  const tenant = host.includes(".") ? host.split(".")[0] : "default";
  request.context = request.context || {};
  request.context.tenant = tenant;

  // 4. Feature flags from the comma-separated "features" query parameter.
  const featureStr = request.query?.features || "";
  request.context.features = new Set(
    featureStr.split(",").filter(f => f)
  );
  return request;
});

/**
 * Gzip string/Buffer bodies larger than COMPRESSION_THRESHOLD_BYTES unless
 * the response already declares a Content-Encoding.
 */
app.onResponse(async (response: Response): Promise<Response> => {
  const headers = response.headers || {};
  const alreadyEncoded = Object.keys(headers)
    .some(key => key.toLowerCase() === "content-encoding");
  const body = response.body;
  if (!alreadyEncoded && (typeof body === "string" || Buffer.isBuffer(body))) {
    // Measure bytes rather than UTF-16 code units.
    const raw = typeof body === "string" ? Buffer.from(body) : body;
    if (raw.length > COMPRESSION_THRESHOLD_BYTES) {
      response.body = zlib.gzipSync(raw);
      headers["content-encoding"] = "gzip";
      response.headers = headers;
    }
  }
  return response;
});
require 'spikard'
require 'set'  # Set is only autoloaded from Ruby 3.2 onwards
require 'zlib'

app = Spikard::App.new

RATE_LIMIT_REQUESTS = 100        # requests allowed per window per client IP
RATE_LIMIT_WINDOW_SECONDS = 60
COMPRESSION_THRESHOLD_BYTES = 1024

# Simple in-memory rate limiter (use Redis in production). Entries for IPs
# that stop sending requests are never evicted.
rate_limits = Hash.new { |h, k| h[k] = [] }

# Rate-limit per client IP (raises HTTPError 429 when exceeded), then
# normalize headers and inject tenant and feature-flag context.
app.on_request do |request|
  # 1. Rate limiting with a sliding window per client IP.
  client_ip = request[:client_ip] || 'unknown'
  now = Time.now.to_f
  window = rate_limits[client_ip]
  window.reject! { |ts| now - ts >= RATE_LIMIT_WINDOW_SECONDS }
  if window.length >= RATE_LIMIT_REQUESTS
    raise Spikard::HTTPError.new(429, 'Rate limit exceeded')
  end
  window << now

  # 2. Normalize headers (lowercase keys).
  if request[:headers]
    request[:headers] = request[:headers].transform_keys(&:downcase)
  end

  # 3. Inject tenant from subdomain ("acme.example.com" -> "acme").
  host = request.dig(:headers, :host) || ''
  tenant = host.include?('.') ? host.split('.')[0] : 'default'
  request[:context] ||= {}
  request[:context][:tenant] = tenant

  # 4. Feature flags from the comma-separated "features" query parameter.
  feature_str = request.dig(:query, :features) || ''
  request[:context][:features] = Set.new(feature_str.split(',').reject(&:empty?))
  request
end

# Gzip String bodies larger than COMPRESSION_THRESHOLD_BYTES unless the
# response already declares a Content-Encoding.
app.on_response do |response|
  headers = response[:headers] || {}
  already_encoded = headers.keys.any? { |key| key.to_s.casecmp?('content-encoding') }
  body = response[:body]
  if !already_encoded && body.is_a?(String) && body.bytesize > COMPRESSION_THRESHOLD_BYTES
    # Zlib.gzip produces the gzip format that "Content-Encoding: gzip"
    # promises. Zlib::Deflate.deflate would emit zlib-wrapped data instead.
    response[:body] = Zlib.gzip(body)
    response[:headers] = headers
    headers[:'content-encoding'] = 'gzip'
  end
  response
end
<?php

declare(strict_types=1);

use Spikard\App;
use Spikard\Config\HookResult;
use Spikard\Config\LifecycleHooks;
use Spikard\Config\RateLimitConfig;
use Spikard\Config\ServerConfig;
use Spikard\Http\Request;
use Spikard\Http\Response;

// Server-side rate limiting configuration (uses Rust pipeline)
$rateLimit = RateLimitConfig::builder()
    ->withPerSecond(100)
    ->withBurst(200)
    ->withIpBased(true)
    ->build();

$hooks = LifecycleHooks::builder()
    ->withOnRequest(function (Request $request): HookResult {
        // 1. Normalize headers (lowercase keys).
        $normalizedHeaders = [];
        foreach ($request->headers as $key => $value) {
            $normalizedHeaders[strtolower($key)] = $value;
        }
        $request->headers = $normalizedHeaders;

        // 2. Inject tenant from subdomain ("acme.example.com" -> "acme").
        $host = $request->headers['host'] ?? '';
        $tenant = str_contains($host, '.') ? explode('.', $host)[0] : 'default';
        $request->context['tenant'] = $tenant;

        // 3. Feature flags from the comma-separated "features" query parameter.
        $featureStr = $request->query['features'] ?? '';
        if (!is_string($featureStr)) {
            $featureStr = ''; // e.g. "features[]=..." arrives as an array
        }
        // Drop only empty entries. A bare array_filter() would also drop a
        // flag named "0". array_values() re-indexes the result so it
        // JSON-encodes as a list, not an object.
        $features = array_values(array_filter(
            explode(',', $featureStr),
            static fn (string $flag): bool => $flag !== ''
        ));
        $request->context['features'] = $features;

        return HookResult::continue();
    })
    ->withOnResponse(function (Request $request, Response $response): HookResult {
        // Response compression for large payloads is handled by the Rust layer
        // (see withCompression below); only add custom headers here.
        $response->headers['X-Tenant'] = $request->context['tenant'] ?? 'default';
        return HookResult::continue();
    })
    ->build();

$config = ServerConfig::builder()
    ->withPort(8000)
    ->withRateLimit($rateLimit)
    ->withMaxBodySize(10 * 1024 * 1024) // 10 MB limit
    ->withCompression(true)             // Enable gzip compression
    ->build();

$app = (new App($config))
    ->withLifecycleHooks($hooks);
use axum::{
    extract::{ConnectInfo, Request},
    http::{header, StatusCode},
    middleware::Next,
    response::Response,
    Extension,
};
use governor::{Quota, RateLimiter};
use nonzero_ext::nonzero;
use std::{
    net::{IpAddr, SocketAddr},
    sync::Arc,
};

/// Per-request data injected into the request extensions for handlers.
#[derive(Debug, Clone)]
pub struct RequestContext {
    pub tenant: String,
    pub features: Vec<String>,
}

/// Keyed (per client IP) in-memory rate limiter shared across requests.
type SharedRateLimiter = Arc<RateLimiter<IpAddr, governor::state::keyed::DefaultKeyedStateStore<IpAddr>, governor::clock::DefaultClock>>;

/// Build the shared limiter: 100 requests per minute per IP.
pub fn create_rate_limiter() -> SharedRateLimiter {
    Arc::new(RateLimiter::keyed(Quota::per_minute(nonzero!(100u32))))
}

/// Resolve the IP used as the rate-limit key.
///
/// Tries, in order:
/// 1. the first `X-Forwarded-For` entry (spoofable; only trust it behind a
///    proxy that overwrites it);
/// 2. the socket peer address, when the server was started with connect info;
/// 3. 127.0.0.1 (all such clients then share one bucket).
fn client_ip(request: &Request) -> IpAddr {
    request
        .headers()
        .get("x-forwarded-for")
        .and_then(|h| h.to_str().ok())
        .and_then(|s| s.split(',').next())
        .and_then(|s| s.trim().parse().ok())
        .or_else(|| {
            request
                .extensions()
                .get::<ConnectInfo<SocketAddr>>()
                .map(|info| info.0.ip())
        })
        .unwrap_or(IpAddr::from([127, 0, 0, 1]))
}

/// Tenant = first label of a Host that has a subdomain (`acme.example.com`
/// -> `acme`), otherwise `"default"`. Hosts with no dot, such as
/// `localhost:8000`, fall back to the default, as in the other examples.
fn tenant_from_host(host: &str) -> String {
    host.split_once('.')
        .map(|(label, _)| label)
        .filter(|label| !label.is_empty())
        .unwrap_or("default")
        .to_string()
}

/// Non-empty entries of the comma-separated `features=` query parameter.
/// NOTE(review): values are not percent-decoded.
fn features_from_query(query: Option<&str>) -> Vec<String> {
    query
        .and_then(|q| q.split('&').find_map(|pair| pair.strip_prefix("features=")))
        .map(|list| {
            list.split(',')
                .filter(|flag| !flag.is_empty())
                .map(String::from)
                .collect()
        })
        .unwrap_or_default()
}

/// Axum middleware.
///
/// Returns 429 when the client IP is over its quota. Otherwise it inserts a
/// [`RequestContext`] (tenant and feature flags) and forwards the request.
pub async fn request_shaper(
    Extension(limiter): Extension<SharedRateLimiter>,
    mut request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // 1. Rate limiting
    if limiter.check_key(&client_ip(&request)).is_err() {
        return Err(StatusCode::TOO_MANY_REQUESTS);
    }

    // 2. Tenant from subdomain, 3. feature flags from query params
    let host = request
        .headers()
        .get(header::HOST)
        .and_then(|h| h.to_str().ok())
        .unwrap_or("");
    let ctx = RequestContext {
        tenant: tenant_from_host(host),
        features: features_from_query(request.uri().query()),
    };

    // Inject context for handlers
    request.extensions_mut().insert(ctx);
    Ok(next.run(request).await)
}

// Usage with App:
// let limiter = create_rate_limiter();
// app.layer(Extension(limiter));
// app.layer(axum::middleware::from_fn(request_shaper));
Middleware chaining and execution order¶
Middleware executes in the order it is registered: request middleware runs top-to-bottom, and response middleware runs bottom-to-top:
Request flow:
→ Middleware 1 (observability: log request)
→ Middleware 2 (auth: verify token)
→ Middleware 3 (request shaping: check rate limits)
→ Handler
← Middleware 3 (request shaping: compress response)
← Middleware 2 (auth: add headers)
← Middleware 1 (observability: log response)
Register middleware in this order (the first one registered runs first on requests and last on responses):
- Observability (request ID generation)
- Security (CORS, auth)
- Request shaping (rate limiting, normalization)
- Handler-specific middleware
Testing middleware¶
Test middleware in isolation by passing mock request/response objects:
import jwt
import pytest
from spikard import HTTPError

# Assumes `@app.on_request` returns the decorated coroutine unchanged, so the
# guard can be imported and awaited directly.
from your_app import auth_guard

SECRET = "your-secret-key"  # must match the secret auth_guard verifies with


def _bearer(sub="user-123"):
    """Build an Authorization header value carrying a real HS256 JWT."""
    token = jwt.encode({"sub": sub}, SECRET, algorithm="HS256")
    return f"Bearer {token}"


@pytest.mark.asyncio
async def test_auth_guard_valid_token():
    request = {
        "headers": {"authorization": _bearer("user-123")},
        "method": "GET",
        "path": "/api/users",
    }
    result = await auth_guard(request)
    assert result["context"]["user_id"] == "user-123"


@pytest.mark.asyncio
async def test_auth_guard_missing_token():
    request = {"headers": {}, "method": "GET", "path": "/api/users"}
    with pytest.raises(HTTPError) as exc:
        await auth_guard(request)
    assert exc.value.status == 401


@pytest.mark.asyncio
async def test_auth_guard_invalid_token():
    request = {
        "headers": {"authorization": "Bearer not-a-jwt"},
        "method": "GET",
        "path": "/api/users",
    }
    with pytest.raises(HTTPError) as exc:
        await auth_guard(request)
    assert exc.value.status == 401
import { describe, it, expect } from "vitest";
import * as jwt from "jsonwebtoken";
import { HTTPError } from "spikard";
// Assumes ./middleware exports the auth hook as `authGuard` (the same
// function that is passed to app.onRequest).
import { authGuard } from "./middleware";

const SECRET = "your-secret-key"; // must match the secret authGuard verifies with

/** Authorization header value carrying a real HS256 JWT for `sub`. */
const bearer = (sub = "user-123"): string =>
  `Bearer ${jwt.sign({ sub }, SECRET, { algorithm: "HS256", expiresIn: "1h" })}`;

describe("authGuard", () => {
  it("allows valid token", async () => {
    const request = {
      headers: { authorization: bearer("user-123") },
      method: "GET",
      path: "/api/users",
    };
    const result = await authGuard(request);
    expect(result.context?.userId).toBe("user-123");
  });

  it("rejects missing token", async () => {
    const request = {
      headers: {},
      method: "GET",
      path: "/api/users",
    };
    // Match on the error type: the message text does not contain "401".
    await expect(authGuard(request)).rejects.toThrow(HTTPError);
  });

  it("rejects invalid token", async () => {
    const request = {
      headers: { authorization: "Bearer not-a-jwt" },
      method: "GET",
      path: "/api/users",
    };
    await expect(authGuard(request)).rejects.toThrow(HTTPError);
  });
});
require 'spikard'
require 'jwt'
require 'rspec'

# Assumes `auth_guard` is a callable (e.g. a lambda) holding the body of the
# on_request hook, defined where these specs can see it.
RSpec.describe 'auth_guard' do
  # Authorization header value carrying a real HS256 JWT. The secret must
  # match the one auth_guard verifies with.
  def bearer(sub = 'user-123')
    "Bearer #{JWT.encode({ 'sub' => sub }, 'your-secret-key', 'HS256')}"
  end

  it 'allows valid token' do
    request = {
      headers: { authorization: bearer('user-123') },
      method: 'GET',
      path: '/api/users'
    }
    result = auth_guard.call(request)
    # Plain RSpec has no `be_present` (that matcher comes from Rails).
    expect(result[:context][:user_id]).to eq('user-123')
  end

  it 'rejects missing token' do
    request = {
      headers: {},
      method: 'GET',
      path: '/api/users'
    }
    expect { auth_guard.call(request) }.to raise_error(Spikard::HTTPError)
  end

  it 'rejects invalid token' do
    request = {
      headers: { authorization: 'Bearer not-a-jwt' },
      method: 'GET',
      path: '/api/users'
    }
    expect { auth_guard.call(request) }.to raise_error(Spikard::HTTPError)
  end
end
<?php

declare(strict_types=1);

use PHPUnit\Framework\TestCase;
use Spikard\Config\HookResult;
use Spikard\Http\Request;
use Spikard\Http\Response;

/**
 * Exercises an auth-guard hook in isolation by calling it directly with
 * hand-built Request objects.
 */
final class AuthGuardTest extends TestCase
{
    public function testAllowsValidToken(): void
    {
        $request = $this->requestWithHeaders(['authorization' => 'Bearer valid-jwt-token']);

        $this->assertInstanceOf(HookResult::class, $this->authGuard($request));
        $this->assertArrayHasKey('user_id', $request->context);
    }

    public function testRejectsMissingToken(): void
    {
        $result = $this->authGuard($this->requestWithHeaders([]));

        $this->assertTrue($result->isShortCircuit());
        $this->assertEquals(401, $result->getResponse()->status);
    }

    /** Build a GET /api/users request carrying the given headers. */
    private function requestWithHeaders(array $headers): Request
    {
        return new Request(
            method: 'GET',
            path: '/api/users',
            headers: $headers
        );
    }

    /**
     * Stand-in for your auth guard implementation. Any "Bearer ..." header
     * is accepted and a fixed user_id is stored in the context; anything
     * else short-circuits with a JSON 401.
     */
    private function authGuard(Request $request): HookResult
    {
        $authHeader = $request->headers['authorization'] ?? '';
        if (str_starts_with($authHeader, 'Bearer ')) {
            $request->context['user_id'] = 'extracted-user-id';

            return HookResult::continue();
        }

        return HookResult::shortCircuit(
            Response::json(['error' => 'Unauthorized'], 401)
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    async fn handler() -> &'static str {
        "OK"
    }

    /// Router with a single `GET /` route protected by `auth_guard`.
    fn guarded_app() -> Router {
        Router::new()
            .route("/", get(handler))
            .layer(axum::middleware::from_fn(auth_guard))
    }

    /// Send `GET /` through the guarded app, with an optional Authorization
    /// header, and return the response status.
    async fn status_for(authorization: Option<String>) -> StatusCode {
        let mut builder = Request::builder().uri("/");
        if let Some(value) = authorization {
            builder = builder.header("Authorization", value);
        }
        let request = builder
            .body(Body::empty())
            .expect("failed to build request");

        guarded_app()
            .oneshot(request)
            .await
            .expect("request failed")
            .status()
    }

    #[tokio::test]
    async fn test_auth_guard_valid_token() {
        let valid_token = create_test_jwt("user-123", vec!["admin".into()]);
        let status = status_for(Some(format!("Bearer {}", valid_token))).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_guard_missing_token() {
        assert_eq!(status_for(None).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_guard_invalid_token() {
        let status = status_for(Some("Bearer invalid-token".to_string())).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    /// Sign an HS256 token for `sub` with `roles`, expiring in one hour.
    fn create_test_jwt(sub: &str, roles: Vec<String>) -> String {
        use jsonwebtoken::{encode, EncodingKey, Header};

        let expires_at = chrono::Utc::now() + chrono::Duration::hours(1);
        let claims = Claims {
            sub: sub.to_owned(),
            roles,
            exp: expires_at.timestamp() as usize,
        };

        encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret(b"your-secret-key"),
        )
        .expect("failed to encode JWT")
    }
}
Tips¶
- Keep middleware pure and side-effect free when possible; expensive IO should be async.
- Prefer per-route middleware for sensitive endpoints.
- Use shared context keys to pass data to handlers; keep them namespaced to avoid collisions.
- Chain middleware thoughtfully: observability first, then security, then request shaping.
- Test middleware in isolation with mock requests to ensure correct error handling.