解决servlet流只能读取一次的问题(拦截器验签)

2025-04-02

2025-04-02-解决servlet流只能读取一次的问题(拦截器验签)

1、问题场景分析

使用拦截器对接口传来的参数进行签名验证,校验数据完整性,为获取请求参数,需要调用HttpServletRequest​的getInputStream​方法获取请求的入参,在拦截器执行完毕,需要进入controller时,发现无法正确接收参数。

2、原因分析

经过查看spring源码可知,request​的getInputStream​方法只允许调用一次,原因是因为底层获取参数时是使用下标记录当前的位置,若全部读取完毕后,下标无法回到原来的位置,request​使用usingReader​字段来记录是否读取过数据。

public ServletInputStream getInputStream() throws IOException {
        if (this.usingReader) {
            throw new IllegalStateException(sm.getString("coyoteRequest.getInputStream.ise"));
        } else {
            this.usingInputStream = true;
            if (this.inputStream == null) {
                this.inputStream = new CoyoteInputStream(this.inputBuffer);
            }

            return this.inputStream;
        }
    }

3、问题解决

  • 第一步:自定义实现ServletRequestWrapper​包装类,缓存HttpServletRequest​的流数据
  • 第二步:使用过滤器将后续传递的HttpServletRequest​都替换为自定义包装类,后续拦截器可实现重复读取

    注意:

    • 后续传递的实际都是ServletRequestWrapper​,它是HttpServletRequest​的子类,所以可以正常使用 HttpServletRequest​ 的所有方法,并且可以调用 MyHttpServletRequestWrapper​ 自己扩展的方法。

    • 在 Spring MVC 控制器方法中,@RequestMapping​ 等注解的方法如果接收 HttpServletRequest​ 参数,得到的也是 MyHttpServletRequestWrapper​,因为整个请求链(过滤器 -> 拦截器 -> 控制器)中传递的都是这个包装后的对象。

1、实现ServletRequestWrapper

无需修改,可直接复制使用

MyHttpServletRequestWrapper​缓存了HttpServletRequest​的流数据,之后直接从缓存中读取,可以做到无限读取。

public class MyHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /**
     * 缓存流
     */
    private byte[] cacheBody;

    /**
     * 构造方法
     *
     * @param request
     * @throws IOException
     */
    public MyHttpServletRequestWrapper(HttpServletRequest request) {
        super(request);
    }

    /**
     * 获取缓冲流
     *
     * @return
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), StandardCharsets.UTF_8));
    }

    /**
     * 获取流
     *
     * @return
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        // 1. 初始化缓存
        if (cacheBody == null) {
            cacheBody = StreamUtils.copyToByteArray(super.getInputStream());
        }
        // 3. 从缓存中返回流
        final ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(cacheBody);
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return false;
            }

            @Override
            public boolean isReady() {
                return false;
            }

            @Override
            public void setReadListener(ReadListener readListener) {

            }

            @Override
            public int read() throws IOException {
                return byteArrayInputStream.read();
            }
        };
    }
}

2、定义过滤器

使用MyHttpServletRequestWrapper​替代HttpServletRequest​传递下去

@Component
public class MyHttpServletRequestWrapperFilter implements Filter {

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {

    }

    /**
     * 对 request 进行包装
     *
     * @param request
     * @param response
     * @param chain
     * @throws IOException
     * @throws ServletException
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {

		// 使用MyHttpServletRequestWrapper替代ServletRequest传递下去
        MyHttpServletRequestWrapper requestWrapper = new MyHttpServletRequestWrapper((HttpServletRequest) request);

        chain.doFilter(requestWrapper, response);
    }

    @Override
    public void destroy() {

    }

} 

4、拦截器获取流数据

直接调用ServletInputStream inputStream = request.getInputStream();​获取数据

@Slf4j
@Component
public class GlobalClientInterceptor implements HandlerInterceptor {

    // 允许的时间戳误差(5分钟)
    private static final long TIME_DIFF_LIMIT = 5 * 60 * 1000;

    @Resource
    private ProductService productService;

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        String appId = request.getHeader("app-id");
        String timestamp = request.getHeader("timestamp");
        String signature = request.getHeader("signature");

        // 检查必要的请求头参数
        if (!StringUtils.hasText(appId) || !StringUtils.hasText(timestamp) || !StringUtils.hasText(signature)) {
            log.error("验签失败:参数缺失,appId:{}", appId);
            response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
            response.getWriter().write("Missing authentication headers");
            return false;
        }

        // 校验时间戳,防止重放攻击
        long requestTime;
        try {
            requestTime = Long.parseLong(timestamp);
        } catch (NumberFormatException e) {
            response.setStatus(HttpServletResponse.SC_BAD_REQUEST);
            response.getWriter().write("Invalid timestamp format");
            return false;
        }

        if (Math.abs(System.currentTimeMillis() - requestTime) > TIME_DIFF_LIMIT) {
            log.error("验签失败:服务器时间不一致,appId:{}", appId);
            log.error(String.valueOf(System.currentTimeMillis()));
            response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
            response.getWriter().write("Request expired");
            return false;
        }

        // **计算 URL 参数哈希**
        String urlParams = getSortedQueryParams(request);

        // **计算 Body 哈希** 
        ServletInputStream inputStream = request.getInputStream();
        String bodyAsString = new String(inputStream.readAllBytes());

        // 构造需要签名的数据
        String dataToSign = appId + timestamp + urlParams + bodyAsString;
        ProductApp productApp = productService.getAppInfoByAppId(appId);
        String serverSignature = SM3Util.hmacSM3(dataToSign, productApp.getAppSecret());

        // 校验签名
        if (!serverSignature.equals(signature)) {
            log.error("验签失败:签名错误,appId:{}", appId);
            log.error(serverSignature);
            response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
            response.getWriter().write("Invalid signature");
            return false;
        }

        // 初始化上下文变量
        productApp.setAppSecret(null);
        AppContextLocal.set(productApp);
        //log.info("签名校验成功,appId:{}", appId);
        return true;
    }


    /**
     * 获取url参数,排序输出
     *
     * @author huanghwh
     * @date 2025/04/02 16:53
     */
    public static String getSortedQueryParams(HttpServletRequest request) {
        Map<String, String[]> paramMap = request.getParameterMap();

        if (paramMap == null || paramMap.isEmpty()) {
            // **空格占位,确保签名计算一致**
            return " ";
        }

        TreeMap<String, String> sortedParams = new TreeMap<>();
        for (Map.Entry<String, String[]> entry : paramMap.entrySet()) {
            String key = entry.getKey();
            // 多值用逗号拼接
            String value = String.join(",", entry.getValue());
            sortedParams.put(key, value);
        }

        return sortedParams.entrySet().stream()
                .map(entry -> entry.getKey() + "=" + entry.getValue())
                .collect(Collectors.joining("&"));
    }


    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
        // **请求完成后清理 ThreadLocal**
        AppContextLocal.clear();
    }
}

拦截器注册

@Configuration
public class WebConfig implements WebMvcConfigurer {

    @Resource
    private GlobalClientInterceptor globalClientInterceptor;

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(globalClientInterceptor)
                .addPathPatterns(
                        "/c/**",
                        "/client/c/**",
                        "/auth/c/**",
                        "/client/c/**",
                        "/client/route"
                ).excludePathPatterns("/c/biz/memberOrder/callback/notify");
    }
}