/**
 * Request wrapper that consumes the request payload once, at construction time, and keeps it
 * in memory so it can be read repeatedly.
 *
 * <p>For requests whose {@code Content-Type} header contains {@code multipart/form-data} the
 * parts are cached (grouped by part name) and no raw body is stored. For every other request
 * the raw body bytes are cached and served back through {@link #getInputStream()},
 * {@link #getReader()} and {@link #getContentAsByteArray()}.
 */
public static class ContentCachingHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /** Raw body bytes; remains {@code null} when the request was handled as multipart. */
    private byte[] content;

    /** Multipart parts grouped by part name; empty for non-multipart requests. */
    private final Map<String, List<Part>> partMap = new HashMap<>();

    /**
     * Wraps {@code request} and eagerly caches its parts (multipart) or its raw body (otherwise).
     *
     * @param request the request to wrap
     * @throws IOException if reading the parts or the body fails
     */
    public ContentCachingHttpServletRequestWrapper(final HttpServletRequest request) throws IOException {
        super(request);
        // Locale.ROOT keeps the match stable under locales with special casing rules
        // (e.g. Turkish, where "I" lower-cases to a dotless "ı").
        final String contentType = Optional.ofNullable(request.getHeader("Content-Type"))
                .map(value -> value.toLowerCase(java.util.Locale.ROOT))
                .orElse("");
        if (contentType.contains("multipart/form-data")) {
            try {
                Optional.ofNullable(request.getParts())
                        .filter(parts -> !parts.isEmpty())
                        .map(parts -> parts.stream().collect(Collectors.groupingBy(Part::getName)))
                        .ifPresent(partMap::putAll);
            } catch (final ServletException ignored) {
                // Best effort: if the parts cannot be obtained the wrapper simply exposes none.
            }
            return;
        }
        // Copy the raw bytes. Going through a Reader would decode with the platform charset and
        // re-encode as UTF-8, corrupting binary bodies and bodies in any other encoding.
        content = IOUtils.toByteArray(request.getInputStream());
    }

    /**
     * Returns a fresh stream over the cached body; each call starts again from the first byte.
     * The stream is empty when no body was cached (multipart requests).
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final byte[] body = content != null ? content : new byte[0];
        final ByteArrayInputStream inputStream = new ByteArrayInputStream(body);
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                // All cached bytes have been handed out once nothing is left to read.
                return inputStream.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(final ReadListener readListener) {
                // Intentionally a no-op: the listener is not registered and never notified.
            }

            @Override
            public int read() throws IOException {
                return inputStream.read();
            }

            @Override
            public int read(final byte[] buffer, final int offset, final int length) throws IOException {
                // Bulk delegate; the inherited implementation would loop over read() per byte.
                return inputStream.read(buffer, offset, length);
            }

            @Override
            public int available() throws IOException {
                return inputStream.available();
            }
        };
    }

    /**
     * Returns a reader over the cached body. Without this override the call would be delegated
     * to the wrapped request, whose input stream was already consumed by the constructor.
     *
     * <p>Characters are decoded with the request's character encoding, falling back to
     * ISO-8859-1 when the request does not declare one.
     *
     * @throws IOException if the declared character encoding is not supported
     */
    @Override
    public java.io.BufferedReader getReader() throws IOException {
        final String encoding = Optional.ofNullable(getCharacterEncoding())
                .orElse(StandardCharsets.ISO_8859_1.name());
        return new java.io.BufferedReader(new InputStreamReader(getInputStream(), encoding));
    }

    /** Returns all cached parts as a new list; empty when the request was not multipart. */
    @Override
    public Collection<Part> getParts() throws IOException, ServletException {
        return partMap.values().stream().flatMap(Collection::stream).collect(Collectors.toList());
    }

    /**
     * Returns the first cached part with the given name.
     *
     * @param name the part name to look up
     * @return the first matching part, or {@code null} when there is none
     */
    @Override
    public Part getPart(final String name) throws IOException, ServletException {
        return Optional.ofNullable(partMap.get(name)).flatMap(parts -> parts.stream().findFirst()).orElse(null);
    }

    /**
     * Returns the cached body bytes, or an empty array when no body was cached.
     * The returned array is the internal buffer, not a copy.
     */
    public byte[] getContentAsByteArray() {
        return Optional.ofNullable(content).orElseGet(() -> new byte[0]);
    }
}