创建自定义 Spring Cloud Gateway Filter

工程 | Fredrich Ombico | 2022 年 8 月 27 日 | ...

在本文中,我们将探讨如何为 Spring Cloud Gateway 编写自定义扩展。在我们开始之前,先回顾一下 Spring Cloud Gateway 的工作原理。

Spring Cloud Gateway diagram

  1. 首先,客户端向网关发起网络请求
  2. 网关定义了许多路由,每个路由都有谓词 (Predicates) 来将请求与路由匹配。例如,您可以基于 URL 的路径段或请求的 HTTP 方法进行匹配。
  3. 匹配成功后,网关会在应用于该路由的每个过滤器上执行请求前逻辑。例如,您可能希望向请求中添加查询参数。
  4. 代理过滤器将请求路由到被代理的服务
  5. 服务执行并返回响应
  6. 网关接收到响应并在返回响应之前对每个过滤器执行请求后逻辑。例如,您可以在返回客户端之前移除不想要的响应头。

我们的扩展将对请求体进行哈希计算,并将计算出的值添加为一个名为 X-Hash 的请求头。这对应于上面图示中的步骤 3。注意:由于我们需要读取请求体,网关将受到内存限制。

首先,我们在 start.spring.io 创建一个包含 Gateway 依赖的项目。在本例中,我们将使用 Java 和 JDK 17 以及 Spring Boot 2.7.3 的 Gradle 项目。下载、解压并在您喜欢的 IDE 中打开项目并运行,以确保您已完成本地开发的设置。

接下来,我们创建一个 GatewayFilter Factory,它是一个作用域限定于特定路由的过滤器,允许我们以某种方式修改传入的 HTTP 请求或传出的 HTTP 响应。在我们的例子中,我们将通过添加额外的请求头来修改传入的 HTTP 请求。

package com.example.demo;

import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Collections;
import java.util.List;

import org.bouncycastle.util.encoders.Hex;
import reactor.core.publisher.Mono;

import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.cloud.gateway.support.ServerWebExchangeUtils;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import org.springframework.web.reactive.function.server.HandlerStrategies;
import org.springframework.web.reactive.function.server.ServerRequest;

import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.CACHED_SERVER_HTTP_REQUEST_DECORATOR_ATTR;

/**
 * This filter hashes the request body, placing the value in the X-Hash header.
 * Note: This causes the gateway to be memory constrained.
 * Sample usage: RequestHashing=SHA-256
 */
@Component
public class RequestHashingGatewayFilterFactory extends
        AbstractGatewayFilterFactory<RequestHashingGatewayFilterFactory.Config> {

    private static final String HASH_ATTR = "hash";
    private static final String HASH_HEADER = "X-Hash";
    private final List<HttpMessageReader<?>> messageReaders =
            HandlerStrategies.withDefaults().messageReaders();

    public RequestHashingGatewayFilterFactory() {
        super(Config.class);
    }

    @Override
    public GatewayFilter apply(Config config) {
        MessageDigest digest = config.getMessageDigest();
        return (exchange, chain) -> ServerWebExchangeUtils
                .cacheRequestBodyAndRequest(exchange, (httpRequest) -> ServerRequest
                    .create(exchange.mutate().request(httpRequest).build(),
                            messageReaders)
                    .bodyToMono(String.class)
                    .doOnNext(requestPayload -> exchange
                            .getAttributes()
                            .put(HASH_ATTR, computeHash(digest, requestPayload)))
                    .then(Mono.defer(() -> {
                        ServerHttpRequest cachedRequest = exchange.getAttribute(
                                CACHED_SERVER_HTTP_REQUEST_DECORATOR_ATTR);
                        Assert.notNull(cachedRequest, 
                                "cache request shouldn't be null");
                        exchange.getAttributes()
                                .remove(CACHED_SERVER_HTTP_REQUEST_DECORATOR_ATTR);

                        String hash = exchange.getAttribute(HASH_ATTR);
                        cachedRequest = cachedRequest.mutate()
                                .header(HASH_HEADER, hash)
                                .build();
                        return chain.filter(exchange.mutate()
                                .request(cachedRequest)
                                .build());
                    })));
    }

    @Override
    public List<String> shortcutFieldOrder() {
        return Collections.singletonList("algorithm");
    }

    private String computeHash(MessageDigest messageDigest, String requestPayload) {
        return Hex.toHexString(messageDigest.digest(requestPayload.getBytes()));
    }

    static class Config {

        private MessageDigest messageDigest;

        public MessageDigest getMessageDigest() {
            return messageDigest;
        }

        public void setAlgorithm(String algorithm) throws NoSuchAlgorithmException {
            messageDigest = MessageDigest.getInstance(algorithm);
        }
    }
}

让我们更详细地看一下代码

  • 我们为类添加了 @Component 注解。Spring Cloud Gateway 需要能够检测到这个类才能使用它。或者,我们可以使用 @Bean 定义一个实例。
  • 在我们的类名中,我们使用 GatewayFilterFactory 作为后缀。在 application.yaml 中添加此过滤器时,我们不包含后缀,只写 RequestHashing。这是 Spring Cloud Gateway 过滤器命名约定。
  • 我们的类也像所有其他 Spring Cloud Gateway 过滤器一样,扩展了 AbstractGatewayFilterFactory。我们还指定了一个类来配置我们的过滤器,一个嵌套的静态类 Config 有助于保持简单。配置类允许我们设置使用哪种哈希算法。
  • 重写的 apply 方法是所有工作发生的地方。在参数中,我们得到了一个配置类的实例,在这里我们可以访问用于哈希计算的 MessageDigest 实例。接下来,我们看到 (exchange, chain),这是返回的 GatewayFilter 接口类的一个 lambda 表达式。exchange 是 ServerWebExchange 的一个实例,它为 Gateway 过滤器提供了访问 HTTP 请求和响应的能力。对于我们的情况,我们希望修改 HTTP 请求,这需要我们修改 (mutate) exchange。
  • 我们需要读取请求体来生成哈希,然而,由于请求体存储在字节缓冲区中,它在过滤器中只能被读取一次。通过使用 ServerWebExchangeUtils,我们将请求作为属性缓存在 exchange 中。属性提供了一种在过滤器链中共享特定请求数据的方式。我们还将存储计算出的请求体哈希。
  • 我们使用 exchange 属性来获取缓存的请求和计算出的哈希。然后,通过添加哈希头来修改 exchange,最后将其发送到链中的下一个过滤器。
  • shortcutFieldOrder 方法有助于将参数的数量和顺序映射到过滤器。字符串 algorithm 匹配到 Config 类中的 setter 方法。

为了测试代码,我们将使用 WireMock。将依赖添加到您的 build.gradle 文件中。

testImplementation 'com.github.tomakehurst:wiremock:2.27.2'

这里我们有一个测试检查请求头的存在和值,另一个测试检查如果请求体不存在,请求头是否也不存在。

package com.example.demo;

import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

import com.github.tomakehurst.wiremock.WireMockServer;
import com.github.tomakehurst.wiremock.client.WireMock;
import com.github.tomakehurst.wiremock.core.WireMockConfiguration;
import org.bouncycastle.jcajce.provider.digest.SHA512;
import org.bouncycastle.util.encoders.Hex;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.autoconfigure.web.reactive.AutoConfigureWebTestClient;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.route.RouteLocator;
import org.springframework.cloud.gateway.route.builder.RouteLocatorBuilder;
import org.springframework.context.annotation.Bean;
import org.springframework.http.HttpStatus;
import org.springframework.test.web.reactive.server.WebTestClient;

import static com.example.demo.RequestHashingGatewayFilterFactory.*;
import static com.example.demo.RequestHashingGatewayFilterFactoryTest.*;
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.RANDOM_PORT;

@SpringBootTest(
        webEnvironment = RANDOM_PORT,
        classes = RequestHashingFilterTestConfig.class)
@AutoConfigureWebTestClient
class RequestHashingGatewayFilterFactoryTest {

    @TestConfiguration
    static class RequestHashingFilterTestConfig {

        @Autowired
        RequestHashingGatewayFilterFactory requestHashingGatewayFilter;

        @Bean(destroyMethod = "stop")
        WireMockServer wireMockServer() {
            WireMockConfiguration options = wireMockConfig().dynamicPort();
            WireMockServer wireMock = new WireMockServer(options);
            wireMock.start();
            return wireMock;
        }

        @Bean
        RouteLocator testRoutes(RouteLocatorBuilder builder, WireMockServer wireMock)
                throws NoSuchAlgorithmException {
            Config config = new Config();
            config.setAlgorithm("SHA-512");

            GatewayFilter gatewayFilter = requestHashingGatewayFilter.apply(config);
            return builder
                    .routes()
                    .route(predicateSpec -> predicateSpec
                            .path("/post")
                            .filters(spec -> spec.filter(gatewayFilter))
                            .uri(wireMock.baseUrl()))
                    .build();
        }
    }

    @Autowired
    WebTestClient webTestClient;

    @Autowired
    WireMockServer wireMockServer;

    @AfterEach
    void afterEach() {
        wireMockServer.resetAll();
    }

    @Test
    void shouldAddHeaderWithComputedHash() {
        MessageDigest messageDigest = new SHA512.Digest();
        String body = "hello world";
        String expectedHash = Hex.toHexString(messageDigest.digest(body.getBytes()));

        wireMockServer.stubFor(WireMock.post("/post").willReturn(WireMock.ok()));

        webTestClient.post().uri("/post")
                .bodyValue(body)
                .exchange()
                .expectStatus()
                .isEqualTo(HttpStatus.OK);

        wireMockServer.verify(postRequestedFor(urlEqualTo("/post"))
                .withHeader("X-Hash", equalTo(expectedHash)));
    }

    @Test
    void shouldNotAddHeaderIfNoBody() {
        wireMockServer.stubFor(WireMock.post("/post").willReturn(WireMock.ok()));

        webTestClient.post().uri("/post")
                .exchange()
                .expectStatus()
                .isEqualTo(HttpStatus.OK);

        wireMockServer.verify(postRequestedFor(urlEqualTo("/post"))
                .withoutHeader("X-Hash"));
    }
}

要在网关中使用该过滤器,我们将 RequestHashing 过滤器添加到 application.yaml 中的一个路由中,使用 SHA-256 作为算法。

spring:
  cloud:
    gateway:
      routes:
        - id: demo
          uri: https://httpbin.org
          predicates:
            - Path=/post/**
          filters:
            - RequestHashing=SHA-256

我们使用 https://httpbin.org,因为它在返回的响应中显示了我们的请求头。运行应用程序并发送一个 curl 请求查看结果。

$> curl --request POST 'http://localhost:8080/post' \
--header 'Content-Type: application/json' \
--data-raw '{
    "data": {
        "hello": "world"
    }
}'

{
  ...
  "data": "{\n    \"data\": {\n        \"hello\": \"world\"\n    }\n}",
  "headers": {
        "Accept": "*/*",
        "Accept-Encoding": "gzip, deflate, br",
        "Content-Length": "48",
        "Content-Type": "application/json",
        "Forwarded": "proto=http;host=\"localhost:8080\";for=\"[0:0:0:0:0:0:0:1]:55647\"",
        "Host": "httpbin.org",
        "User-Agent": "PostmanRuntime/7.29.0",
        "X-Forwarded-Host": "localhost:8080",
        "X-Hash": "1bd93d38735501b5aec7a822f8bc8136d9f1f71a30c2020511bdd5df379772b8"
    },
  ...
}

总之,我们了解了如何为 Spring Cloud Gateway 编写自定义扩展。我们的过滤器读取了请求体以生成一个哈希值,并将其添加为请求头。我们还使用 WireMock 编写了过滤器的测试,以检查请求头的值。最后,我们运行了一个带有该过滤器的网关来验证结果。

如果您计划在 Kubernetes 集群上部署 Spring Cloud Gateway,请务必查阅 VMware Spring Cloud Gateway for Kubernetes。除了支持开源 Spring Cloud Gateway 过滤器和自定义过滤器(例如我们在上面编写的过滤器)之外,它还提供了 更多内置过滤器 来操作您的请求和响应。Spring Cloud Gateway for Kubernetes 代表 API 开发团队处理横切关注点,例如:单点登录 (SSO)、访问控制、速率限制、弹性、安全性等。

订阅 Spring 新闻通讯

通过 Spring 新闻通讯保持联系

订阅

抢先一步

VMware 提供培训和认证,助力您加速发展。

了解更多

获取支持

Tanzu Spring 通过一个简单的订阅提供对 OpenJDK™、Spring 和 Apache Tomcat® 的支持和二进制文件。

了解更多

即将举行的活动

查看 Spring 社区所有即将举行的活动。

查看全部