Skip to content

Middleware

Middleware is the right place for cross-cutting behavior like logging, auth, or request shaping. The API mirrors per-language conventions but calls into the same Rust pipeline.

Add middleware

from spikard import Spikard

app = Spikard()

@app.on_request
async def logging_hook(request):
    """Print "<METHOD> <PATH>" for each incoming request and hand the request back unchanged."""
    verb = request['method']
    route = request['path']
    print(f"{verb} {route}")
    return request
import { Spikard, type Request } from "spikard";

const app = new Spikard();

/** Logs "<METHOD> <PATH>" for each incoming request and passes it through untouched. */
const logRequest = async (request: Request): Promise<Request> => {
  const { method, path } = request;
  console.log(`${method} ${path}`);
  return request;
};

app.onRequest(logRequest);
require "spikard"

app = Spikard::App.new

# Log "<METHOD> <PATH>" for each incoming request, then return the request unchanged.
app.on_request do |request|
  verb = request[:method]
  route = request[:path]
  puts "#{verb} #{route}"
  request
end
<?php

declare(strict_types=1);

use Spikard\App;
use Spikard\Config\ServerConfig;
use Spikard\Config\LifecycleHooks;
use Spikard\Config\HookResult;
use Spikard\Http\Request;

/**
 * onRequest hook: writes "[timestamp] METHOD PATH" to the PHP error log and lets
 * the request continue through the pipeline.
 */
$logRequest = function (Request $request): HookResult {
    $timestamp = date('Y-m-d H:i:s');
    $line = sprintf("[%s] %s %s", $timestamp, $request->method, $request->path);
    error_log($line);

    return HookResult::continue();
};

$hooks = LifecycleHooks::builder()
    ->withOnRequest($logRequest)
    ->build();

$app = (new App(new ServerConfig(port: 8000)))
    ->withLifecycleHooks($hooks);
use tower_http::trace::TraceLayer;

// Attach tower-http's TraceLayer (HTTP-flavoured defaults) to every route.
// NOTE(review): `let mut` suggests `App::layer` mutates `app` in place rather than
// returning a new value — confirm against the App API.
let mut app = App::new();
app.layer(TraceLayer::new_for_http());

Patterns

Auth guards

Check headers/cookies, enrich context with the authenticated principal, and short-circuit on failures.

from spikard import Spikard, HTTPError
import jwt

app = Spikard()

_BEARER_PREFIX = "Bearer "


@app.on_request
async def auth_guard(request):
    """Authenticate the request with an HS256 Bearer JWT and record the caller in its context.

    On success ``request["context"]["user_id"]`` holds the token's ``sub`` claim and
    ``request["context"]["roles"]`` its ``roles`` claim (``[]`` when absent).

    Raises:
        HTTPError(401): if the Authorization header is missing or not a Bearer token,
            if the token fails verification, or if it carries no ``sub`` claim.
    """
    auth_header = request.get("headers", {}).get("authorization", "")
    if not auth_header.startswith(_BEARER_PREFIX):
        raise HTTPError(401, "Missing or invalid authorization header")

    token = auth_header[len(_BEARER_PREFIX):]

    try:
        # Pinning `algorithms` stops the token header from choosing its own algorithm.
        payload = jwt.decode(token, "your-secret-key", algorithms=["HS256"])
    except jwt.InvalidTokenError:
        # Suppress the chained library exception; the 401 is the whole story for the client.
        raise HTTPError(401, "Invalid token") from None

    # A verified token without a subject cannot identify a user. Reject it as 401
    # rather than letting a KeyError surface as a 500.
    user_id = payload.get("sub")
    if not user_id:
        raise HTTPError(401, "Invalid token")

    # Enrich context with the authenticated user.
    request["context"] = request.get("context", {})
    request["context"]["user_id"] = user_id
    request["context"]["roles"] = payload.get("roles", [])

    return request
import { Spikard, type Request, HTTPError } from "spikard";
import * as jwt from "jsonwebtoken";

interface JWTPayload {
  sub: string;
  iat: number;
  exp: number;
  roles?: string[];
}

const app = new Spikard();

const BEARER_PREFIX = "Bearer ";

/**
 * Auth guard: verifies an HS256 Bearer JWT and stores the caller in `request.context`
 * (`userId` from `sub`, `roles` defaulting to `[]`).
 *
 * @throws HTTPError(401) when the Authorization header is missing or not a Bearer token,
 *   when verification fails, or when the payload has no string `sub` claim.
 */
app.onRequest(async (request: Request): Promise<Request> => {
  const authHeader = request.headers?.authorization || "";
  if (!authHeader.startsWith(BEARER_PREFIX)) {
    throw new HTTPError(401, "Missing or invalid authorization header");
  }

  const token = authHeader.slice(BEARER_PREFIX.length);

  let payload: JWTPayload;
  try {
    // Pin the accepted algorithm (as the Python/Ruby/PHP/Rust guards do) so the token
    // header cannot select a different one.
    const decoded = jwt.verify(token, "your-secret-key", { algorithms: ["HS256"] });
    if (typeof decoded === "string" || typeof decoded.sub !== "string") {
      throw new Error("token payload has no subject");
    }
    payload = decoded as JWTPayload;
  } catch {
    throw new HTTPError(401, "Invalid token");
  }

  // Enrich context with the authenticated user.
  request.context = request.context || {};
  request.context.userId = payload.sub;
  request.context.roles = payload.roles || [];

  return request;
});
require 'spikard'
require 'jwt'

app = Spikard::App.new

# Auth guard: verify an HS256 Bearer JWT, then record the caller's subject and roles
# in request[:context]. Any failure is reported as Spikard::HTTPError 401.
app.on_request do |request|
  auth_header = request.dig(:headers, :authorization) || ""
  unless auth_header.start_with?("Bearer ")
    raise Spikard::HTTPError.new(401, "Missing or invalid authorization header")
  end

  token = auth_header.delete_prefix("Bearer ")

  begin
    claims, _jwt_header = JWT.decode(token, "your-secret-key", true, { algorithm: 'HS256' })
  rescue JWT::DecodeError
    raise Spikard::HTTPError.new(401, "Invalid token")
  end

  request[:context] ||= {}
  request[:context][:user_id] = claims["sub"]
  request[:context][:roles] = claims["roles"] || []

  request
end
<?php

declare(strict_types=1);

use Firebase\JWT\JWT;
use Firebase\JWT\Key;
use Spikard\App;
use Spikard\Config\ServerConfig;
use Spikard\Config\LifecycleHooks;
use Spikard\Config\HookResult;
use Spikard\Http\Request;
use Spikard\Http\Response;

$hooks = LifecycleHooks::builder()
    // Pre-handler auth guard: verify an HS256 Bearer JWT and store the caller in
    // $request->context; otherwise short-circuit with a 401 JSON response.
    ->withPreHandler(function (Request $request): HookResult {
        $authHeader = $request->headers['authorization'] ?? '';

        if (!str_starts_with($authHeader, 'Bearer ')) {
            return HookResult::shortCircuit(
                Response::json(['error' => 'Missing or invalid authorization header'], 401)
            );
        }

        $token = substr($authHeader, strlen('Bearer '));

        try {
            $payload = JWT::decode($token, new Key('your-secret-key', 'HS256'));
        } catch (\Exception) {
            // JWT::decode signals verification failures by throwing; map them all to 401.
            return HookResult::shortCircuit(
                Response::json(['error' => 'Invalid token'], 401)
            );
        }

        // A verified token without a subject cannot identify a user. Reading the
        // missing property would only emit a warning and store null, so reject it.
        if (!isset($payload->sub) || $payload->sub === '') {
            return HookResult::shortCircuit(
                Response::json(['error' => 'Invalid token'], 401)
            );
        }

        // Enrich context with the authenticated user.
        $request->context['user_id'] = $payload->sub;
        $request->context['roles'] = $payload->roles ?? [];

        return HookResult::continue();
    })
    ->build();

$app = (new App(new ServerConfig(port: 8000)))
    ->withLifecycleHooks($hooks);
use axum::{
    extract::Request,
    http::{header, StatusCode},
    middleware::Next,
    response::Response,
    Extension,
};
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims {
    pub sub: String,
    pub roles: Vec<String>,
    pub exp: usize,
}

#[derive(Debug, Clone)]
pub struct AuthContext {
    pub user_id: String,
    pub roles: Vec<String>,
}

pub async fn auth_guard(
    mut request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // Extract token from Authorization header
    let auth_header = request
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|h| h.to_str().ok())
        .ok_or(StatusCode::UNAUTHORIZED)?;

    if !auth_header.starts_with("Bearer ") {
        return Err(StatusCode::UNAUTHORIZED);
    }

    let token = &auth_header[7..]; // Strip "Bearer "

    // Verify and decode JWT
    let token_data = decode::<Claims>(
        token,
        &DecodingKey::from_secret(b"your-secret-key"),
        &Validation::default(),
    )
    .map_err(|_| StatusCode::UNAUTHORIZED)?;

    // Enrich context with authenticated user
    let auth_ctx = AuthContext {
        user_id: token_data.claims.sub,
        roles: token_data.claims.roles,
    };

    request.extensions_mut().insert(auth_ctx);

    Ok(next.run(request).await)
}

// Usage with App:
// app.layer(axum::middleware::from_fn(auth_guard));

Observability

Emit structured logs and traces; forward request IDs and correlation IDs for distributed tracing.

import logging
import uuid
from spikard import Spikard

app = Spikard()
logger = logging.getLogger(__name__)


@app.on_request
async def observability_middleware(request):
    """Assign a request ID, store it in ``request["context"]``, and log ``request_started``.

    A non-empty inbound ``x-request-id`` header is reused so logs correlate across
    services; otherwise a fresh UUID4 is generated.
    """
    headers = request.get("headers", {})
    # `or` also replaces an empty header value and only builds a UUID when needed.
    request_id = headers.get("x-request-id") or str(uuid.uuid4())

    # Inject into context for handlers to use.
    request["context"] = request.get("context", {})
    request["context"]["request_id"] = request_id

    logger.info("request_started", extra={
        "request_id": request_id,
        "method": request["method"],
        "path": request["path"],
        "user_agent": headers.get("user-agent"),
    })

    return request


@app.on_response
async def response_logger(response):
    """Log ``request_completed`` and echo the request ID back as ``X-Request-ID``.

    NOTE(review): assumes the framework exposes the request context on the response —
    confirm; if it does not, no header is added.
    """
    request_id = response.get("context", {}).get("request_id")

    logger.info("request_completed", extra={
        "request_id": request_id,
        "status": response["status"],
        "duration_ms": response.get("duration_ms"),
    })

    # Only propagate the header when an ID is known; a None header value is invalid.
    if request_id is not None:
        response["headers"] = response.get("headers", {})
        response["headers"]["X-Request-ID"] = request_id

    return response
import { Spikard, type Request, type Response } from "spikard";
import { v4 as uuidv4 } from "uuid";

const app = new Spikard();

/**
 * Assigns a request ID (reusing a non-empty inbound `x-request-id`, else a new UUID v4),
 * stores it in `request.context.requestId`, and logs a `request_started` JSON line.
 */
app.onRequest(async (request: Request): Promise<Request> => {
  const requestId = request.headers?.["x-request-id"] || uuidv4();

  // Inject into context for handlers to use.
  request.context = request.context || {};
  request.context.requestId = requestId;

  console.log(JSON.stringify({
    event: "request_started",
    request_id: requestId,
    method: request.method,
    path: request.path,
    user_agent: request.headers?.["user-agent"],
  }));

  return request;
});

/**
 * Logs a `request_completed` JSON line and echoes the request ID as `X-Request-ID`.
 * NOTE(review): assumes the response carries the request context — confirm.
 */
app.onResponse(async (response: Response): Promise<Response> => {
  const requestId = response.context?.requestId;

  console.log(JSON.stringify({
    event: "request_completed",
    request_id: requestId,
    status: response.status,
    duration_ms: response.durationMs,
  }));

  // Only propagate the header when an ID is known; never send an undefined value.
  if (requestId) {
    response.headers = response.headers || {};
    response.headers["X-Request-ID"] = requestId;
  }

  return response;
});
require 'spikard'
require 'json'
require 'logger'
require 'securerandom'

app = Spikard::App.new
logger = Logger.new(STDOUT)

# Assign a request ID (reuse a non-empty inbound x-request-id, else a new UUID),
# store it in request[:context], and log a request_started JSON line.
app.on_request do |request|
  inbound_id = request.dig(:headers, :'x-request-id')
  request_id = inbound_id.to_s.empty? ? SecureRandom.uuid : inbound_id

  # Inject into context for handlers to use
  request[:context] ||= {}
  request[:context][:request_id] = request_id

  # Hash#to_json comes from the json library required above.
  logger.info({
    event: 'request_started',
    request_id: request_id,
    method: request[:method],
    path: request[:path],
    user_agent: request.dig(:headers, :'user-agent')
  }.to_json)

  request
end

# Log a request_completed JSON line and echo the request ID as X-Request-ID.
# NOTE(review): assumes the response carries the request context — confirm.
app.on_response do |response|
  request_id = response.dig(:context, :request_id)

  logger.info({
    event: 'request_completed',
    request_id: request_id,
    status: response[:status],
    duration_ms: response[:duration_ms]
  }.to_json)

  # Only propagate the header when an ID is known; never send a nil value.
  unless request_id.nil?
    response[:headers] ||= {}
    response[:headers][:'X-Request-ID'] = request_id
  end

  response
end
<?php

declare(strict_types=1);

use Monolog\Logger;
use Monolog\Handler\StreamHandler;
use Spikard\App;
use Spikard\Config\ServerConfig;
use Spikard\Config\LifecycleHooks;
use Spikard\Config\HookResult;
use Spikard\Http\Request;
use Spikard\Http\Response;

// PSR-3 compatible logger (Monolog)
$logger = new Logger('app');
$logger->pushHandler(new StreamHandler('php://stdout', Logger::INFO));

$hooks = LifecycleHooks::builder()
    // Assign a request ID, store it in the request context, and log request_started.
    ->withOnRequest(function (Request $request) use ($logger): HookResult {
        // Reuse a non-empty inbound ID so logs correlate across services. Otherwise
        // generate 128 random bits: uniqid() is time-based and is not guaranteed
        // to be unique across processes.
        $inboundId = $request->headers['x-request-id'] ?? '';
        $requestId = $inboundId !== '' ? $inboundId : 'req_' . bin2hex(random_bytes(16));

        // Inject into context for handlers to use
        $request->context['request_id'] = $requestId;

        $logger->info('request_started', [
            'request_id' => $requestId,
            'method' => $request->method,
            'path' => $request->path,
            'user_agent' => $request->headers['user-agent'] ?? null,
        ]);

        return HookResult::continue();
    })
    // Log request_completed and echo the request ID as X-Request-ID.
    ->withOnResponse(function (Request $request, Response $response) use ($logger): HookResult {
        $requestId = $request->context['request_id'] ?? 'unknown';

        $logger->info('request_completed', [
            'request_id' => $requestId,
            'status' => $response->status,
            'duration_ms' => $response->durationMs ?? null,
        ]);

        // Propagate request ID in response headers
        $response->headers['X-Request-ID'] = $requestId;

        return HookResult::continue();
    })
    ->build();

$app = (new App(new ServerConfig(port: 8000)))
    ->withLifecycleHooks($hooks);
use axum::{
    extract::Request,
    http::HeaderValue,
    middleware::Next,
    response::Response,
};
use tracing::{info, span, Instrument, Level};
use uuid::Uuid;

/// Axum middleware that wraps each request in an `http_request` tracing span.
///
/// It does three things:
/// - logs `request_started` and `request_completed` (with status and duration);
/// - propagates the request ID back to the client in `x-request-id`;
/// - reuses a non-empty inbound `x-request-id` header, or generates a UUID v4 when
///   there is none.
pub async fn observability_middleware(request: Request, next: Next) -> Response {
    let request_id = request
        .headers()
        .get("x-request-id")
        .and_then(|h| h.to_str().ok())
        .filter(|id| !id.is_empty())
        .map_or_else(|| Uuid::new_v4().to_string(), String::from);

    // Span fields are recorded now, before `request` is moved into `next.run`.
    let span = span!(
        Level::INFO,
        "http_request",
        request_id = %request_id,
        method = %request.method(),
        path = %request.uri().path(),
        user_agent = request
            .headers()
            .get("user-agent")
            .and_then(|h| h.to_str().ok())
            .unwrap_or("unknown"),
    );

    info!(parent: &span, "request_started");

    let start = std::time::Instant::now();

    // Run the rest of the stack inside the span so downstream events inherit it.
    let mut response = next.run(request).instrument(span.clone()).await;

    let duration_ms = start.elapsed().as_millis();

    info!(
        parent: &span,
        status = %response.status().as_u16(),
        duration_ms = %duration_ms,
        "request_completed"
    );

    // `request_id` is either a UUID or text that already passed `HeaderValue::to_str`,
    // so this conversion should succeed. If it ever fails, omit the header rather
    // than send an empty one.
    if let Ok(value) = HeaderValue::from_str(&request_id) {
        response.headers_mut().insert("x-request-id", value);
    }

    response
}

// Usage with App:
// app.layer(axum::middleware::from_fn(observability_middleware));

Request shaping

Normalize headers, coerce parameters, inject tenant/feature flags, or apply rate limiting.

from spikard import Spikard, HTTPError
import gzip
import time
from collections import defaultdict

app = Spikard()

# Simple in-memory sliding-window rate limiter (use Redis in production).
# NOTE: an IP's key is never removed once seen, so this dict grows with the number
# of distinct clients — another reason to use a store with TTLs in production.
RATE_LIMIT_REQUESTS = 100
RATE_LIMIT_WINDOW_SECONDS = 60
rate_limits = defaultdict(list)

# Response bodies larger than this many bytes are gzip-compressed.
COMPRESSION_THRESHOLD_BYTES = 1024


@app.on_request
async def request_shaper(request):
    """Rate-limit by client IP, lowercase header names, and attach tenant/feature context.

    Sets two context keys:

    * ``request["context"]["tenant"]``: the first Host label, or ``"default"`` when
      the host has no dot.
    * ``request["context"]["features"]``: the set of non-empty values from the
      comma-separated ``features`` query parameter.

    Raises:
        HTTPError(429): when the client IP has already made ``RATE_LIMIT_REQUESTS``
            requests in the last ``RATE_LIMIT_WINDOW_SECONDS`` seconds.
    """
    # 1. Rate limiting. time.monotonic() is unaffected by wall-clock adjustments,
    #    which would otherwise stretch or shrink the window.
    client_ip = request.get("client_ip", "unknown")
    now = time.monotonic()

    recent = [ts for ts in rate_limits[client_ip] if now - ts < RATE_LIMIT_WINDOW_SECONDS]
    rate_limits[client_ip] = recent

    if len(recent) >= RATE_LIMIT_REQUESTS:
        raise HTTPError(429, "Rate limit exceeded")

    recent.append(now)

    # 2. Normalize headers (lowercase keys) so later lookups are case-insensitive.
    if "headers" in request:
        request["headers"] = {k.lower(): v for k, v in request["headers"].items()}

    # 3. Inject tenant from subdomain.
    host = request.get("headers", {}).get("host", "")
    tenant = host.split(".")[0] if "." in host else "default"
    request["context"] = request.get("context", {})
    request["context"]["tenant"] = tenant

    # 4. Feature flags from the comma-separated `features` query parameter.
    feature_flags = request.get("query", {}).get("features", "").split(",")
    request["context"]["features"] = {f for f in feature_flags if f}

    return request


@app.on_response
async def compress_response(response):
    """gzip-compress str/bytes response bodies larger than ``COMPRESSION_THRESHOLD_BYTES``.

    NOTE(review): this hook only sees the response, so it cannot check the client's
    ``Accept-Encoding``. Clients that do not accept gzip will receive a body they
    cannot decode. Prefer negotiated server-side compression if the framework offers it.
    """
    headers = response.get("headers", {})
    if any(name.lower() == "content-encoding" for name in headers):
        # Already encoded (e.g. by the handler); compressing again would corrupt it.
        return response

    body = response.get("body", "")
    if isinstance(body, str):
        raw = body.encode("utf-8")
    elif isinstance(body, (bytes, bytearray)):
        raw = bytes(body)
    else:
        # Other body types (presumably serialized later by the framework) are left alone.
        return response

    # The threshold is measured in encoded bytes, not characters.
    if len(raw) > COMPRESSION_THRESHOLD_BYTES:
        response["body"] = gzip.compress(raw)
        # Any Content-Length describes the uncompressed body and is now wrong.
        for name in [n for n in headers if n.lower() == "content-length"]:
            del headers[name]
        headers["content-encoding"] = "gzip"
        response["headers"] = headers

    return response
import { Spikard, type Request, type Response, HTTPError } from "spikard";
import * as zlib from "zlib";

const app = new Spikard();

// Simple in-memory sliding-window rate limiter (use Redis in production).
// NOTE: keys for IPs that stop sending requests are never evicted.
const RATE_LIMIT_REQUESTS = 100;
const RATE_LIMIT_WINDOW_MS = 60_000;
const rateLimits = new Map<string, number[]>();

// Response bodies larger than this many bytes are gzip-compressed.
const COMPRESSION_THRESHOLD_BYTES = 1024;

/**
 * Rate-limits by client IP, lowercases header names, and stores the tenant (first Host
 * label, or "default" when the host has no dot) and the `features` query flags in
 * `request.context`.
 *
 * @throws HTTPError(429) when the client IP exceeds the rate limit.
 */
app.onRequest(async (request: Request): Promise<Request> => {
  // 1. Rate limiting
  const clientIp = request.clientIp || "unknown";
  const now = Date.now();

  // Keep only timestamps inside the current window.
  const timestamps = (rateLimits.get(clientIp) || [])
    .filter(ts => now - ts < RATE_LIMIT_WINDOW_MS);

  if (timestamps.length >= RATE_LIMIT_REQUESTS) {
    throw new HTTPError(429, "Rate limit exceeded");
  }

  timestamps.push(now);
  rateLimits.set(clientIp, timestamps);

  // 2. Normalize headers (lowercase keys)
  if (request.headers) {
    const normalized: Record<string, string> = {};
    for (const [key, value] of Object.entries(request.headers)) {
      normalized[key.toLowerCase()] = value;
    }
    request.headers = normalized;
  }

  // 3. Inject tenant from subdomain
  const host = request.headers?.host || "";
  const tenant = host.includes(".") ? host.split(".")[0] : "default";
  request.context = request.context || {};
  request.context.tenant = tenant;

  // 4. Feature flags from the comma-separated `features` query parameter
  const featureStr = request.query?.features || "";
  request.context.features = new Set(
    featureStr.split(",").filter(f => f)
  );

  return request;
});

/**
 * gzip-compresses string/Buffer bodies larger than COMPRESSION_THRESHOLD_BYTES.
 * NOTE(review): this hook cannot see the request's Accept-Encoding, so clients that do
 * not accept gzip still get compressed bodies; prefer negotiated server compression.
 */
app.onResponse(async (response: Response): Promise<Response> => {
  const headers = response.headers || {};
  const alreadyEncoded = Object.keys(headers)
    .some(name => name.toLowerCase() === "content-encoding");
  const body = response.body;

  // Re-compressing an encoded body would corrupt it; other body types are left alone.
  if (alreadyEncoded || (typeof body !== "string" && !Buffer.isBuffer(body))) {
    return response;
  }

  // Measure the threshold in bytes, not UTF-16 code units.
  const raw = typeof body === "string" ? Buffer.from(body) : body;
  if (raw.length > COMPRESSION_THRESHOLD_BYTES) {
    response.body = zlib.gzipSync(raw);
    // Any Content-Length describes the uncompressed body and is now wrong.
    for (const name of Object.keys(headers)) {
      if (name.toLowerCase() === "content-length") {
        delete headers[name];
      }
    }
    headers["content-encoding"] = "gzip";
    response.headers = headers;
  }

  return response;
});
require 'spikard'
require 'set'
require 'zlib'

app = Spikard::App.new

# Simple in-memory sliding-window rate limiter (use Redis in production).
# NOTE: keys for IPs that stop sending requests are never evicted.
RATE_LIMIT_REQUESTS = 100
RATE_LIMIT_WINDOW_SECONDS = 60
rate_limits = Hash.new { |h, k| h[k] = [] }

# Response bodies larger than this many bytes are gzip-compressed.
COMPRESSION_THRESHOLD_BYTES = 1024

# Rate-limit by client IP, lowercase header names, and store the tenant (first Host
# label, or 'default' when the host has no dot) and `features` query flags in
# request[:context]. Raises Spikard::HTTPError 429 when the limit is exceeded.
app.on_request do |request|
  # 1. Rate limiting. A monotonic clock is unaffected by wall-clock adjustments.
  client_ip = request[:client_ip] || 'unknown'
  now = Process.clock_gettime(Process::CLOCK_MONOTONIC)

  # Drop timestamps that fell out of the window.
  rate_limits[client_ip].reject! { |ts| now - ts >= RATE_LIMIT_WINDOW_SECONDS }

  if rate_limits[client_ip].length >= RATE_LIMIT_REQUESTS
    raise Spikard::HTTPError.new(429, 'Rate limit exceeded')
  end

  rate_limits[client_ip] << now

  # 2. Normalize headers (lowercase keys)
  if request[:headers]
    request[:headers] = request[:headers].transform_keys(&:downcase)
  end

  # 3. Inject tenant from subdomain
  host = request.dig(:headers, :host) || ''
  tenant = host.include?('.') ? host.split('.')[0] : 'default'
  request[:context] ||= {}
  request[:context][:tenant] = tenant

  # 4. Feature flags from query params (Set needs `require 'set'` before Ruby 3.2)
  feature_str = request.dig(:query, :features) || ''
  request[:context][:features] = Set.new(feature_str.split(',').reject(&:empty?))

  request
end

# gzip-compress String bodies larger than COMPRESSION_THRESHOLD_BYTES.
# NOTE(review): this hook cannot see the request's Accept-Encoding, so clients that do
# not accept gzip still get compressed bodies; prefer negotiated server compression.
app.on_response do |response|
  headers = response[:headers] || {}
  body = response[:body] || ''

  # Skip non-String bodies and bodies that are already encoded (compressing twice corrupts them).
  if body.is_a?(String) && body.bytesize > COMPRESSION_THRESHOLD_BYTES &&
     !headers.key?(:'content-encoding')
    # Zlib.gzip emits the gzip container that `content-encoding: gzip` promises.
    # Zlib::Deflate.deflate emits a zlib stream, which clients cannot decode as gzip.
    response[:body] = Zlib.gzip(body)
    # Any Content-Length describes the uncompressed body and is now wrong.
    headers.delete(:'content-length')
    headers[:'content-encoding'] = 'gzip'
    response[:headers] = headers
  end

  response
end
<?php

declare(strict_types=1);

use Spikard\App;
use Spikard\Config\ServerConfig;
use Spikard\Config\RateLimitConfig;
use Spikard\Config\LifecycleHooks;
use Spikard\Config\HookResult;
use Spikard\Http\Request;
use Spikard\Http\Response;

// Server-side rate limiting configuration (uses Rust pipeline)
$rateLimit = RateLimitConfig::builder()
    ->withPerSecond(100)
    ->withBurst(200)
    ->withIpBased(true)
    ->build();

$hooks = LifecycleHooks::builder()
    // Lowercase header names and record tenant + feature flags in the request context.
    ->withOnRequest(function (Request $request): HookResult {
        // 1. Normalize headers (lowercase keys)
        $normalizedHeaders = [];
        foreach ($request->headers as $key => $value) {
            $normalizedHeaders[strtolower($key)] = $value;
        }
        $request->headers = $normalizedHeaders;

        // 2. Inject tenant from subdomain
        $host = $request->headers['host'] ?? '';
        $tenant = str_contains($host, '.') ? explode('.', $host)[0] : 'default';
        $request->context['tenant'] = $tenant;

        // 3. Feature flags from query params. Only empty entries are dropped:
        // array_filter() without a callback would also drop a flag named "0".
        // array_values() re-indexes the result into a plain list.
        $featureStr = $request->query['features'] ?? '';
        $features = array_values(array_filter(
            explode(',', $featureStr),
            static fn (string $flag): bool => $flag !== '',
        ));
        $request->context['features'] = $features;

        return HookResult::continue();
    })
    ->withOnResponse(function (Request $request, Response $response): HookResult {
        // Response compression for large payloads (handled by Rust layer)
        // Add custom headers if needed
        $response->headers['X-Tenant'] = $request->context['tenant'] ?? 'default';

        return HookResult::continue();
    })
    ->build();

$config = ServerConfig::builder()
    ->withPort(8000)
    ->withRateLimit($rateLimit)
    ->withMaxBodySize(10 * 1024 * 1024)  // 10 MB limit
    ->withCompression(true)               // Enable gzip compression
    ->build();

$app = (new App($config))
    ->withLifecycleHooks($hooks);
use axum::{
    body::Body,
    extract::Request,
    http::{header, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
    Extension,
};
use governor::{Quota, RateLimiter};
use nonzero_ext::nonzero;
use std::{net::IpAddr, sync::Arc};

#[derive(Debug, Clone)]
pub struct RequestContext {
    pub tenant: String,
    pub features: Vec<String>,
}

type SharedRateLimiter = Arc<RateLimiter<IpAddr, governor::state::keyed::DefaultKeyedStateStore<IpAddr>, governor::clock::DefaultClock>>;

pub fn create_rate_limiter() -> SharedRateLimiter {
    // 100 requests per minute per IP
    Arc::new(RateLimiter::keyed(Quota::per_minute(nonzero!(100u32))))
}

pub async fn request_shaper(
    Extension(limiter): Extension<SharedRateLimiter>,
    mut request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // 1. Rate limiting
    let client_ip: IpAddr = request
        .headers()
        .get("x-forwarded-for")
        .and_then(|h| h.to_str().ok())
        .and_then(|s| s.split(',').next())
        .and_then(|s| s.trim().parse().ok())
        .unwrap_or([127, 0, 0, 1].into());

    if limiter.check_key(&client_ip).is_err() {
        return Err(StatusCode::TOO_MANY_REQUESTS);
    }

    // 2. Extract tenant from subdomain
    let host = request
        .headers()
        .get(header::HOST)
        .and_then(|h| h.to_str().ok())
        .unwrap_or("");

    let tenant = host
        .split('.')
        .next()
        .filter(|s| !s.is_empty())
        .unwrap_or("default")
        .to_string();

    // 3. Parse feature flags from query params
    let features: Vec<String> = request
        .uri()
        .query()
        .and_then(|q| {
            q.split('&')
                .find(|p| p.starts_with("features="))
                .map(|p| p.trim_start_matches("features="))
        })
        .map(|f| f.split(',').map(String::from).collect())
        .unwrap_or_default();

    // Inject context for handlers
    let ctx = RequestContext { tenant, features };
    request.extensions_mut().insert(ctx);

    Ok(next.run(request).await)
}

// Usage with App:
// let limiter = create_rate_limiter();
// app.layer(Extension(limiter));
// app.layer(axum::middleware::from_fn(request_shaper));

Middleware chaining and execution order

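Middleware executes in the order it is registered: request middleware runs top-to-bottom, and response middleware runs in reverse (bottom-to-top). The Rust examples compose tower layers, so check how `App::layer` orders them. With a plain axum `Router`, the layer added last wraps the others and therefore runs first.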

Request flow:
  → Middleware 1 (observability: log request)
    → Middleware 2 (auth: verify token)
      → Middleware 3 (rate limit: check limits)
        → Handler
      ← Middleware 3 (response shaping: compress)
    ← Middleware 2 (auth: add headers)
  ← Middleware 1 (observability: log response)

Register middleware in this order (outermost first):

  1. Observability (request ID generation)
  2. Security (CORS, auth)
  3. Request shaping (rate limiting, normalization)
  4. Handler-specific middleware

Testing middleware

Test middleware in isolation by passing mock request/response objects:

import time

import jwt
import pytest
from spikard import HTTPError
from your_app import auth_guard

# Must match the key auth_guard verifies with. A made-up string such as
# "valid-jwt-token" is not a JWT and would be rejected with 401.
SECRET = "your-secret-key"


def make_token(sub="user-123", roles=None):
    """Return a one-hour HS256 token that auth_guard accepts."""
    claims = {"sub": sub, "roles": roles or [], "exp": int(time.time()) + 3600}
    return jwt.encode(claims, SECRET, algorithm="HS256")


@pytest.mark.asyncio
async def test_auth_guard_valid_token():
    request = {
        "headers": {"authorization": f"Bearer {make_token()}"},
        "method": "GET",
        "path": "/api/users"
    }

    result = await auth_guard(request)

    assert result["context"]["user_id"] == "user-123"


@pytest.mark.asyncio
async def test_auth_guard_missing_token():
    request = {"headers": {}, "method": "GET", "path": "/api/users"}

    with pytest.raises(HTTPError) as exc:
        await auth_guard(request)

    assert exc.value.status == 401


@pytest.mark.asyncio
async def test_auth_guard_invalid_token():
    request = {
        "headers": {"authorization": "Bearer not-a-jwt"},
        "method": "GET",
        "path": "/api/users"
    }

    with pytest.raises(HTTPError) as exc:
        await auth_guard(request)

    assert exc.value.status == 401
import { describe, it, expect } from "vitest";
import * as jwt from "jsonwebtoken";
import { HTTPError } from "spikard";
// NOTE(review): assumes ./middleware exports the onRequest guard as `authGuard`.
import { authGuard } from "./middleware";

// Must match the secret the guard verifies with. A made-up string such as
// "valid-jwt-token" is not a JWT and would be rejected with 401.
const SECRET = "your-secret-key";

/** Signs a one-hour HS256 token that the guard accepts. */
const makeToken = (sub: string): string =>
  jwt.sign({ sub, roles: [] }, SECRET, { algorithm: "HS256", expiresIn: "1h" });

describe("authGuard", () => {
  it("allows valid token", async () => {
    const request = {
      headers: { authorization: `Bearer ${makeToken("user-123")}` },
      method: "GET",
      path: "/api/users",
    };

    const result = await authGuard(request);

    expect(result.context?.userId).toBe("user-123");
  });

  it("rejects missing token", async () => {
    const request = {
      headers: {},
      method: "GET",
      path: "/api/users",
    };

    // Assert on the error type: the message is human-readable text and is not
    // guaranteed to contain the status code.
    await expect(authGuard(request)).rejects.toThrow(HTTPError);
  });
});
require 'spikard'
require 'jwt'
require 'rspec'

# NOTE(review): assumes `auth_guard` is a callable exposing the on_request hook
# (e.g. defined via `let(:auth_guard) { ... }`); it is not defined in this file.
RSpec.describe 'auth_guard' do
  # Must match the secret the guard verifies with. A made-up string such as
  # 'valid-jwt-token' is not a JWT and would be rejected with 401.
  let(:secret) { 'your-secret-key' }
  let(:valid_token) do
    JWT.encode({ 'sub' => 'user-123', 'roles' => [], 'exp' => Time.now.to_i + 3600 }, secret, 'HS256')
  end

  it 'allows valid token' do
    request = {
      headers: { authorization: "Bearer #{valid_token}" },
      method: 'GET',
      path: '/api/users'
    }

    result = auth_guard.call(request)

    # Plain RSpec has no `be_present` (that matcher needs ActiveSupport); compare the value.
    expect(result[:context][:user_id]).to eq('user-123')
  end

  it 'rejects missing token' do
    request = {
      headers: {},
      method: 'GET',
      path: '/api/users'
    }

    expect { auth_guard.call(request) }.to raise_error(Spikard::HTTPError)
  end
end
<?php

declare(strict_types=1);

use PHPUnit\Framework\TestCase;
use Spikard\Http\Request;
use Spikard\Http\Response;
use Spikard\Config\HookResult;

final class AuthGuardTest extends TestCase
{
    public function testAllowsValidToken(): void
    {
        $request = $this->makeRequest(['authorization' => 'Bearer valid-jwt-token']);

        $outcome = $this->authGuard($request);

        $this->assertInstanceOf(HookResult::class, $outcome);
        $this->assertArrayHasKey('user_id', $request->context);
    }

    public function testRejectsMissingToken(): void
    {
        $outcome = $this->authGuard($this->makeRequest([]));

        $this->assertTrue($outcome->isShortCircuit());
        $this->assertEquals(401, $outcome->getResponse()->status);
    }

    /** Builds a GET /api/users request carrying the given headers. */
    private function makeRequest(array $headers): Request
    {
        return new Request(
            method: 'GET',
            path: '/api/users',
            headers: $headers
        );
    }

    /**
     * Stand-in auth guard. It continues and sets context['user_id'] for any Bearer
     * header; otherwise it short-circuits with a 401 JSON response.
     */
    private function authGuard(Request $request): HookResult
    {
        $header = $request->headers['authorization'] ?? '';
        $hasBearer = str_starts_with($header, 'Bearer ');

        if ($hasBearer) {
            $request->context['user_id'] = 'extracted-user-id';
            return HookResult::continue();
        }

        return HookResult::shortCircuit(
            Response::json(['error' => 'Unauthorized'], 401)
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        routing::get,
        Router,
    };
    use std::time::{Duration, SystemTime, UNIX_EPOCH};
    use tower::ServiceExt;

    async fn handler() -> &'static str {
        "OK"
    }

    /// Router with a single `/` route guarded by `auth_guard`.
    fn guarded_app() -> Router {
        Router::new()
            .route("/", get(handler))
            .layer(axum::middleware::from_fn(auth_guard))
    }

    /// Sends `GET /` through a fresh guarded router, with an optional `Authorization`
    /// header, and returns the response status.
    async fn status_for(authorization: Option<&str>) -> StatusCode {
        let mut builder = Request::builder().uri("/");
        if let Some(value) = authorization {
            builder = builder.header("Authorization", value);
        }
        let request = builder.body(Body::empty()).expect("failed to build request");

        guarded_app()
            .oneshot(request)
            .await
            .expect("request failed")
            .status()
    }

    #[tokio::test]
    async fn test_auth_guard_valid_token() {
        let token = create_test_jwt("user-123", vec!["admin".into()]);
        let bearer = format!("Bearer {token}");

        assert_eq!(status_for(Some(&bearer)).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_guard_missing_token() {
        assert_eq!(status_for(None).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_guard_invalid_token() {
        assert_eq!(
            status_for(Some("Bearer invalid-token")).await,
            StatusCode::UNAUTHORIZED
        );
    }

    /// Signs an HS256 token, valid for one hour, with the secret `auth_guard` expects.
    fn create_test_jwt(sub: &str, roles: Vec<String>) -> String {
        use jsonwebtoken::{encode, EncodingKey, Header};

        let exp_secs = (SystemTime::now() + Duration::from_secs(3600))
            .duration_since(UNIX_EPOCH)
            .expect("system clock is before the Unix epoch")
            .as_secs();

        let claims = Claims {
            sub: sub.to_string(),
            roles,
            exp: usize::try_from(exp_secs).expect("expiry timestamp fits in usize"),
        };

        encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret(b"your-secret-key"),
        )
        .expect("failed to encode JWT")
    }
}

Tips

  • Keep middleware free of side effects where possible, and make expensive I/O asynchronous so it does not block the request pipeline.
  • Prefer per-route middleware for sensitive endpoints.
  • Use shared context keys to pass data to handlers, and namespace those keys to avoid collisions between middleware.
  • Chain middleware thoughtfully: observability first, then security, then request shaping.
  • Test middleware in isolation with mock requests to ensure correct error handling.