Skip to main content

缓存请求 content 的 HttpServletRequest 包装类

/**
 * {@link HttpServletRequestWrapper} that caches the request content so it can be read more than once.
 *
 * <p>For {@code multipart/form-data} requests the parts are resolved once via
 * {@link HttpServletRequest#getParts()} and cached by part name; the raw body is not cached in that case.
 * For every other content type the raw body bytes are read once in the constructor and replayed by
 * {@link #getInputStream()} and {@link #getReader()}.
 *
 * <p>NOTE(review): because the body is consumed here, containers that parse
 * {@code application/x-www-form-urlencoded} parameters lazily from the body may no longer expose them
 * through {@code getParameter()} — confirm against the target container.
 */
public static class ContentCachingHttpServletRequestWrapper extends HttpServletRequestWrapper {
    /**
     * Raw request body bytes (including application/x-www-form-urlencoded); {@code null} for multipart requests.
     */
    private byte[] content;
    /**
     * Multipart parts grouped by part name.
     */
    private final Map<String, List<Part>> partMap = new HashMap<>();

    /**
     * Wraps {@code request} and eagerly caches its parts (multipart) or raw body (everything else).
     *
     * @param request the request to wrap
     * @throws IOException if reading the parts or the body fails
     */
    public ContentCachingHttpServletRequestWrapper(final HttpServletRequest request) throws IOException {
        super(request);
        // Locale.ROOT: under a default locale such as Turkish, "MULTIPART".toLowerCase() produces a dotless
        // 'ı' and the multipart check below would silently fail.
        final String contentType = Optional.ofNullable(request.getHeader("Content-Type"))
                .map(type -> type.toLowerCase(java.util.Locale.ROOT))
                .orElse("");
        if (contentType.contains("multipart/form-data")) {
            try {
                Optional.ofNullable(request.getParts()).filter(parts -> !parts.isEmpty())
                        .map(parts -> parts.stream().collect(Collectors.groupingBy(Part::getName)))
                        .ifPresent(partMap::putAll);
            } catch (final ServletException ignored) {
                // Best effort: if the container cannot parse the parts, the wrapper simply exposes no parts.
            }
            return;
        }
        // Copy the raw bytes. Decoding through a Reader (platform charset) and re-encoding as UTF-8 would
        // corrupt binary bodies and any body not encoded in the platform default charset.
        content = IOUtils.toByteArray(request.getInputStream());
    }

    /**
     * Returns a fresh stream over the cached body; each call starts again from the first byte.
     * The stream is empty for multipart requests.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        // ByteArrayInputStream never writes to its buffer, so sharing the cached array is safe.
        final ByteArrayInputStream inputStream = new ByteArrayInputStream(content != null ? content : new byte[0]);
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                // All data is in memory, so the stream is finished exactly when nothing is left to read.
                return inputStream.available() == 0;
            }

            @Override
            public boolean isReady() {
                return true;
            }

            @Override
            public void setReadListener(final ReadListener readListener) {
                // Non-blocking reads are not supported: the listener is ignored.
            }

            @Override
            public int read() throws IOException {
                return inputStream.read();
            }

            @Override
            public int read(final byte[] b, final int off, final int len) throws IOException {
                // Bulk read avoids the default per-byte loop through read().
                return inputStream.read(b, off, len);
            }

            @Override
            public int available() throws IOException {
                return inputStream.available();
            }
        };
    }

    /**
     * Returns a reader over the cached body. The wrapped request's own reader cannot be used because its
     * input stream was already consumed by the constructor.
     *
     * @throws IOException if the declared character encoding is unsupported
     */
    @Override
    public java.io.BufferedReader getReader() throws IOException {
        // Fall back to UTF-8 when the request declares no charset (the Servlet spec default is ISO-8859-1).
        final String encoding = Optional.ofNullable(getCharacterEncoding()).orElse(StandardCharsets.UTF_8.name());
        return new java.io.BufferedReader(new InputStreamReader(getInputStream(), encoding));
    }

    /**
     * Returns all cached multipart parts, flattened across part names.
     */
    @Override
    public Collection<Part> getParts() throws IOException, ServletException {
        return partMap.values().stream().flatMap(Collection::stream).collect(Collectors.toList());
    }

    /**
     * Returns the first cached part with the given name, or {@code null} if there is none.
     */
    @Override
    public Part getPart(final String name) throws IOException, ServletException {
        return Optional.ofNullable(partMap.get(name)).flatMap(parts -> parts.stream().findFirst()).orElse(null);
    }

    /**
     * Returns the cached request content.
     *
     * @return a copy of the request body bytes; an empty array for multipart requests. A copy is returned so
     *         callers cannot alter what later {@link #getInputStream()} calls replay.
     */
    public byte[] getContentAsByteArray() {
        return content == null ? new byte[0] : content.clone();
    }
}