package cn.lili.common.security.filter;
|
|
|
import cn.hutool.core.text.CharSequenceUtil;
|
import cn.hutool.http.HtmlUtil;
|
import cn.hutool.json.JSONUtil;
|
import lombok.extern.slf4j.Slf4j;
|
import org.owasp.html.HtmlPolicyBuilder;
|
import org.owasp.html.PolicyFactory;
|
|
import javax.servlet.ReadListener;
|
import javax.servlet.ServletInputStream;
|
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequestWrapper;
|
import java.io.BufferedReader;
|
import java.io.ByteArrayInputStream;
|
import java.io.IOException;
|
import java.io.InputStreamReader;
|
import java.nio.charset.StandardCharsets;
|
import java.util.HashMap;
|
import java.util.LinkedHashMap;
|
import java.util.Map;
|
|
/**
|
* 防止Xss
|
*
|
* @author Chopper
|
* @version v1.0
|
* 2021-06-04 10:39
|
*/
|
@Slf4j
|
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {
|
|
//允许的标签
|
private static final String[] allowedTags = {"h1", "h2", "h3", "h4", "h5", "h6",
|
"span", "strong",
|
"img", "video", "source", "iframe", "code",
|
"blockquote", "p", "div",
|
"ul", "ol", "li",
|
"table", "thead", "caption", "tbody", "tr", "th", "td", "br",
|
"a"
|
};
|
|
//需要转化的标签
|
private static final String[] needTransformTags = {"article", "aside", "command", "datalist", "details", "figcaption", "figure",
|
"footer", "header", "hgroup", "section", "summary"};
|
|
//带有超链接的标签
|
private static final String[] linkTags = {"img", "video", "source", "a", "iframe", "p"};
|
|
//带有超链接的标签
|
private static final String[] allowAttributes = {"style", "src", "href", "target", "width", "height"};
|
|
public XssHttpServletRequestWrapper(HttpServletRequest request) {
|
super(request);
|
}
|
|
/**
|
* 对数组参数进行特殊字符过滤
|
*/
|
@Override
|
public String[] getParameterValues(String name) {
|
String[] values = super.getParameterValues(name);
|
if (values == null) {
|
return new String[0];
|
}
|
int count = values.length;
|
String[] encodedValues = new String[count];
|
for (int i = 0; i < count; i++) {
|
encodedValues[i] = filterXss(name, values[i]);
|
}
|
return encodedValues;
|
}
|
|
/**
|
* 对参数中特殊字符进行过滤
|
*/
|
@Override
|
public String getParameter(String name) {
|
String value = super.getParameter(name);
|
if (value == null) {
|
return null;
|
}
|
return filterXss(name, value);
|
}
|
|
/**
|
* 获取attribute,特殊字符过滤
|
*/
|
@Override
|
public Object getAttribute(String name) {
|
Object value = super.getAttribute(name);
|
if (value instanceof String) {
|
value = filterXss(name, (String) value);
|
}
|
return value;
|
}
|
|
/**
|
* 对请求头部进行特殊字符过滤
|
*/
|
@Override
|
public String getHeader(String name) {
|
String value = super.getHeader(name);
|
if (value == null) {
|
return null;
|
}
|
return filterXss(name, value);
|
}
|
|
@Override
|
public Map<String, String[]> getParameterMap() {
|
Map<String, String[]> parameterMap = super.getParameterMap();
|
//因为super.getParameterMap()返回的是Map,所以我们需要定义Map的实现类对数据进行封装
|
Map<String, String[]> params = new LinkedHashMap<>();
|
//如果参数不为空
|
if (parameterMap != null) {
|
//对map进行遍历
|
for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
|
//根据key获取value
|
String[] values = entry.getValue();
|
//遍历数组
|
for (int i = 0; i < values.length; i++) {
|
String value = values[i];
|
value = filterXss(entry.getKey(), value);
|
//将转义后的数据放回数组中
|
values[i] = value;
|
}
|
|
//将转义后的数组put到linkMap当中
|
params.put(entry.getKey(), values);
|
}
|
}
|
return params;
|
}
|
|
/**
|
* 获取输入流
|
*
|
* @return 过滤后的输入流
|
* @throws IOException 异常信息
|
*/
|
@Override
|
public ServletInputStream getInputStream() throws IOException {
|
|
BufferedReader bufferedReader = null;
|
|
InputStreamReader reader = null;
|
|
//获取输入流
|
ServletInputStream in = null;
|
try {
|
in = super.getInputStream();
|
//用于存储输入流
|
StringBuilder body = new StringBuilder();
|
reader = new InputStreamReader(in, StandardCharsets.UTF_8);
|
bufferedReader = new BufferedReader(reader);
|
//按行读取输入流
|
String line = bufferedReader.readLine();
|
while (line != null) {
|
//将获取到的第一行数据append到StringBuffer中
|
body.append(line);
|
//继续读取下一行流,直到line为空
|
line = bufferedReader.readLine();
|
}
|
if (CharSequenceUtil.isNotEmpty(body) && Boolean.TRUE.equals(JSONUtil.isJsonObj(body.toString()))) {
|
//将body转换为map
|
Map<String, Object> map = JSONUtil.parseObj(body.toString());
|
//创建空的map用于存储结果
|
Map<String, Object> resultMap = new HashMap<>(map.size());
|
//遍历数组
|
for (Map.Entry<String, Object> entry : map.entrySet()) {
|
//如果map.get(key)获取到的是字符串就需要进行处理,如果不是直接存储resultMap
|
if (map.get(entry.getKey()) instanceof String) {
|
resultMap.put(entry.getKey(), filterXss(entry.getKey(), entry.getValue().toString()));
|
} else {
|
resultMap.put(entry.getKey(), entry.getValue());
|
}
|
}
|
|
//将resultMap转换为json字符串
|
String resultStr = JSONUtil.toJsonStr(resultMap);
|
//将json字符串转换为字节
|
final ByteArrayInputStream resultBIS = new ByteArrayInputStream(resultStr.getBytes(StandardCharsets.UTF_8));
|
|
//实现接口
|
return new ServletInputStream() {
|
@Override
|
public boolean isFinished() {
|
return false;
|
}
|
|
@Override
|
public boolean isReady() {
|
return false;
|
}
|
|
@Override
|
public void setReadListener(ReadListener readListener) {
|
}
|
|
@Override
|
public int read() {
|
return resultBIS.read();
|
}
|
};
|
}
|
|
//将json字符串转换为字节
|
final ByteArrayInputStream bis = new ByteArrayInputStream(body.toString().getBytes());
|
//实现接口
|
return new ServletInputStream() {
|
@Override
|
public boolean isFinished() {
|
return false;
|
}
|
|
@Override
|
public boolean isReady() {
|
return false;
|
}
|
|
@Override
|
public void setReadListener(ReadListener readListener) {
|
|
}
|
|
@Override
|
public int read() {
|
return bis.read();
|
}
|
};
|
} catch (Exception e) {
|
|
log.error("get request inputStream error", e);
|
return null;
|
} finally {
|
//关闭流
|
if (bufferedReader != null) {
|
bufferedReader.close();
|
}
|
if (reader != null) {
|
reader.close();
|
}
|
if (in != null) {
|
in.close();
|
}
|
}
|
|
}
|
|
private String cleanXSS(String value) {
|
if (value != null) {
|
// 自定义策略
|
PolicyFactory policy = new HtmlPolicyBuilder()
|
.allowStandardUrlProtocols()
|
//所有允许的标签
|
.allowElements(allowedTags)
|
//内容标签转化为div
|
.allowElements((elementName, attributes) -> "div", needTransformTags)
|
.allowAttributes(allowAttributes).onElements(linkTags)
|
.allowStyling()
|
.toFactory();
|
// basic prepackaged policies for links, tables, integers, images, styles, blocks
|
value = policy.sanitize(value);
|
}
|
return HtmlUtil.unescape(value);
|
}
|
|
/**
|
* 过滤xss
|
*
|
* @param name 参数名
|
* @param value 参数值
|
* @return 参数值
|
*/
|
private String filterXss(String name, String value) {
|
return cleanXSS(value);
|
}
|
|
}
|