阅读Spring WebFlux Security的CsrfWebFilter

2022-05-30 14:27:48 浏览数 (1)

上篇介绍了 CSRF 及 Spring Security 对防范 CSRF 所做的支持。Spring Security 实现了基于 Session 和 Cookie 的 Sychronizer Token Pattern,以防范 CSRF,默认实现是基于 session 的。

Spring Security 对 Servlet 和 WebFlux 技术栈分别进行了实现。 其中,基于 Servlet 技术栈的实现代码是:

代码语言:javascript复制
CsrfFilter:执行过滤、验证。
CsrfTokenRepository:定义了存储 csrf token 的接口。
CookieCsrfTokenRepository:用 cookie 保存 csrf token 的实现类
HttpSessionCsrfTokenRepository:用session 保存csrf token 的实现类
CsrfAuthenticationStrategy.onAuthentication() 每次请求生成新 csrf token

基于 WebFlux 技术栈的实现代码没有与 CsrfAuthenticationStrategy 对应的类,而是直接在 CsrfWebFilter.filter()–> continueFilterChain() –> … this.csrfTokenRepository.loadToken(exchange).switchIfEmpty(generateToken(exchange)) 完成了。

Filter 和 TokenRepository 在Reactive模式下都有对应的类,如下:

代码语言:javascript复制
CsrfWebFilter
ServerCsrfTokenRepository
  CookieServerCsrfTokenRepository WebSessionServerCsrfTokenRepository

CookieServerCsrfTokenRepository WebSessionServerCsrfTokenRepository 的逻辑更直接一些。下面以注释的方式解读一下 CsrfWebFilter 代码。

可以特别留意以下几个 operators的用法和场景:switchIfEmpty() ilterWhen() delayUntil() Mono.defer()

代码语言:javascript复制
public class CsrfWebFilter implements WebFilter {

  public static final ServerWebExchangeMatcher DEFAULT_CSRF_MATCHER = new DefaultRequireCsrfProtectionMatcher();

  /**
  * The attribute name to use when marking a given request as one that should not be
  * filtered.
  *
  * To use, set the attribute on your {@link ServerWebExchange}: <pre>
  *     CsrfWebFilter.skipExchange(exchange);
  * </pre>
  */
  private static final String SHOULD_NOT_FILTER = "SHOULD_NOT_FILTER"   CsrfWebFilter.class.getName();

  private ServerWebExchangeMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;

  // 默认基于Session保存csrf token。
  private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository();

  private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(
      HttpStatus.FORBIDDEN);

  private boolean isTokenFromMultipartDataEnabled;

  public void setAccessDeniedHandler(ServerAccessDeniedHandler accessDeniedHandler) {
    Assert.notNull(accessDeniedHandler, "accessDeniedHandler");
    this.accessDeniedHandler = accessDeniedHandler;
  }

  public void setCsrfTokenRepository(ServerCsrfTokenRepository csrfTokenRepository) {
    Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
    this.csrfTokenRepository = csrfTokenRepository;
  }

  /**
  * 如果我们的应用GET,HEAD,TRACE,OPTIONS 这些方法会改变应用的状态,就需要自定义matcher。
  * 自定义的matcher可以有更复杂的匹配逻辑,而不仅仅限于DefaultRequireCsrfProtectionMatcher 中所用的request method。
  * 可参考 DefaultRequireCsrfProtectionMatcher。
  *
  * @param requireCsrfProtectionMatcher
  */
  public void setRequireCsrfProtectionMatcher(ServerWebExchangeMatcher requireCsrfProtectionMatcher) {
    Assert.notNull(requireCsrfProtectionMatcher, "requireCsrfProtectionMatcher cannot be null");
    this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher;
  }

  /**
  * Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token
  * from the body of multipart data requests.
  * @param tokenFromMultipartDataEnabled true if should read from multipart form body,
  * else false. Default is false
  */
  public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) {
    this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled;
  }

  @Override
  public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
    // 如果代码通过调用 CsrfWebFilter.skipExchange(exchange) 设置无需进行CSRF保护,则继续chain的下一个节点。
    if (Boolean.TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) {
      return chain.filter(exchange).then(Mono.empty());
    }

    return this.requireCsrfProtectionMatcher.matches(exchange) // 根据 requireCsrfProtectionMatcher 进行匹配
        .filter(MatchResult::isMatch) // 如果需要CSRF保护。
        // 这里matchResult只是个占位符。过滤出已经设置了CsrfToken的exchange。
        .filter((matchResult) -> !exchange.getAttributes().containsKey(CsrfToken.class.getName()))
        // 验证这个csrf token,这个验证逻辑是CSRF的核心。根据csrf token是保存在session还是cookie,逻辑不同。
        // 详细逻辑参考 ServerCsrfTokenRepository 的两个实现 CookieServerCsrfTokenRepository WebSessionServerCsrfTokenRepository
        .flatMap((m) -> validateToken(exchange))
        // 校验文csrf token,继续chain的下个节点处理。
        .flatMap((m) -> continueFilterChain(exchange, chain))
        // switchIfEmpty实际是继续.filter(MatchResult::isMatch).filter((matchResult) -> 。。。) 这两个filter之后的情况。
        .switchIfEmpty(continueFilterChain(exchange, chain).then(Mono.empty()))
        .onErrorResume(CsrfException.class, (ex) -> this.accessDeniedHandler.handle(exchange, ex));
  }

  public static void skipExchange(ServerWebExchange exchange) {
    exchange.getAttributes().put(SHOULD_NOT_FILTER, Boolean.TRUE);
  }

  private Mono<Void> validateToken(ServerWebExchange exchange) {
    return this.csrfTokenRepository.loadToken(exchange) // 从session或者cookie中得到csrf token,作为比较基准。
        // 如果没有找到csrf token则抛出异常没有找到期望的csrf token。(之前版本的错误信息有误 https://github.com/spring-projects/spring-security/commit/a1083d9a5ce3fef8fa458a47e5a6b7a6576ec01e#diff-9c109ef13a33c07de3231051c08e424e664985d142f54ccd7830169d4adcadb1)
        // 注意这里用的defer而不是just用来避免不必要的对象创建。
        // switchIfEmpty() 感觉上就是reactive stream里的的if语句:如果是空流,则用另一个publisher。
        .switchIfEmpty(
            Mono.defer(() -> Mono.error(new CsrfException("An expected CSRF token cannot be found"))))
        /**
        * 判断用户从页面提交的csrf token和从session/cookie中保存的是否一致。
        * 注意,这里用了filterWhen而不是filter。表面的原因是containsValidCsrfToken()的返回值是Mono<Boolean>而不是Boolean。
        * 而更深层的原因应该是CsrfWebFilter不应该是阻塞的,因为这个filter要处理所有的request,所以需要异步地测试、过滤。
        * 感觉上也有点像flatMap,都是接收Publisher参数。
        **/
        .filterWhen((expected) -> containsValidCsrfToken(exchange, expected))
        // 如果没有找到有效的csrf token,那么就抛出异常。
        .switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("Invalid CSRF Token")))).then();
  }

  private Mono<Boolean> containsValidCsrfToken(ServerWebExchange exchange, CsrfToken expected) {
    // 先从form里找用户提交的csrf token。
    return exchange.getFormData().flatMap((data) -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
        // form 表单里没有,就从header里找
        .switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
        // 最后从multipart里找
        .switchIfEmpty(tokenFromMultipartData(exchange, expected))
        // 如果找到了,则和从session或cookie里得到的csrf token做比较
        .map((actual) -> equalsConstantTime(actual, expected.getToken()));
  }

  private Mono<String> tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) {
    if (!this.isTokenFromMultipartDataEnabled) {
      return Mono.empty();
    }
    ServerHttpRequest request = exchange.getRequest();
    HttpHeaders headers = request.getHeaders();
    MediaType contentType = headers.getContentType();
    if (!contentType.includes(MediaType.MULTIPART_FORM_DATA)) {
      return Mono.empty();
    }
    return exchange.getMultipartData().map((d) -> d.getFirst(expected.getParameterName())).cast(FormFieldPart.class)
        .map(FormFieldPart::value);
  }

  private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
    return Mono.defer(() -> {
      Mono<CsrfToken> csrfToken = csrfToken(exchange);
      exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken);
      return chain.filter(exchange);
    });
  }

  private Mono<CsrfToken> csrfToken(ServerWebExchange exchange) {
    // 如果从repository中找到了csrf token就返回,否则就生成一个token返回。
    return this.csrfTokenRepository.loadToken(exchange).switchIfEmpty(generateToken(exchange));
  }

  /**
  * Constant time comparison to prevent against timing attacks.
  * @param expected
  * @param actual
  * @return
  */
  private static boolean equalsConstantTime(String expected, String actual) {
    if (expected == actual) {
      return true;
    }
    if (expected == null || actual == null) {
      return false;
    }
    // Encode after ensure that the string is not null
    byte[] expectedBytes = Utf8.encode(expected);
    byte[] actualBytes = Utf8.encode(actual);
    return MessageDigest.isEqual(expectedBytes, actualBytes);
  }

  private Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
    // 委托给csrfTokenRepository 产生一个新token。
    return this.csrfTokenRepository.generateToken(exchange)
        // 注意这里的delayUntil()用法:直到token保存好之后,上步由csrfTokenRepository产生的Mono<CsrfToken>才继续走流水线。
        .delayUntil((token) -> this.csrfTokenRepository.saveToken(exchange, token));
  }

  private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher {

    // 如果我们应用在保证GET,HEAD,TRACE,OPTIONS等方法不改变系统状态,不需要进行CSRF保护,就可使用这个默认的matcher类。
    private static final Set<HttpMethod> ALLOWED_METHODS = new HashSet<>(
        Arrays.asList(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.TRACE, HttpMethod.OPTIONS));

    @Override
    public Mono<MatchResult> matches(ServerWebExchange exchange) {
      return Mono.just(exchange.getRequest()) // 从exchagne得到request对象
          .flatMap((r) -> Mono.justOrEmpty(r.getMethod())) // 得到http request的method。
          .filter(ALLOWED_METHODS::contains) // 仅保留无需CSRF保护的方法
          .flatMap((m) -> MatchResult.notMatch()) // 过滤后,有无需CSRF保护的方法则不匹配(无需CSRF保护)。
          .switchIfEmpty(MatchResult.match()); // 过滤后,没有匹配到方法,则匹配(需CSRF保护)。
    }

  }

}

Reference: https://docs.spring.io/spring-security/site/docs/5.4.5/reference/html5/#csrf

0 人点赞