package cn.lili.common.security.filter; import cn.hutool.core.text.CharSequenceUtil; import cn.hutool.http.HtmlUtil; import cn.hutool.json.JSONUtil; import lombok.extern.slf4j.Slf4j; import org.owasp.html.HtmlPolicyBuilder; import org.owasp.html.PolicyFactory; import javax.servlet.ReadListener; import javax.servlet.ServletInputStream; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import java.io.BufferedReader; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStreamReader; import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.Map; /** * 防止Xss * * @author Chopper * @version v1.0 * 2021-06-04 10:39 */ @Slf4j public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper { //允许的标签 private static final String[] allowedTags = {"h1", "h2", "h3", "h4", "h5", "h6", "span", "strong", "img", "video", "source", "iframe", "code", "blockquote", "p", "div", "ul", "ol", "li", "table", "thead", "caption", "tbody", "tr", "th", "td", "br", "a" }; //需要转化的标签 private static final String[] needTransformTags = {"article", "aside", "command", "datalist", "details", "figcaption", "figure", "footer", "header", "hgroup", "section", "summary"}; //带有超链接的标签 private static final String[] linkTags = {"img", "video", "source", "a", "iframe", "p"}; //带有超链接的标签 private static final String[] allowAttributes = {"style", "src", "href", "target", "width", "height"}; public XssHttpServletRequestWrapper(HttpServletRequest request) { super(request); } /** * 对数组参数进行特殊字符过滤 */ @Override public String[] getParameterValues(String name) { String[] values = super.getParameterValues(name); if (values == null) { return new String[0]; } int count = values.length; String[] encodedValues = new String[count]; for (int i = 0; i < count; i++) { encodedValues[i] = filterXss(name, values[i]); } return encodedValues; } /** * 对参数中特殊字符进行过滤 */ @Override public String getParameter(String name) { String value = super.getParameter(name); if (value == null) { return null; } return filterXss(name, value); } /** * 获取attribute,特殊字符过滤 */ @Override public Object getAttribute(String name) { Object value = super.getAttribute(name); if (value instanceof String) { value = filterXss(name, (String) value); } return value; } /** * 对请求头部进行特殊字符过滤 */ @Override public String getHeader(String name) { String value = super.getHeader(name); if (value == null) { return null; } return filterXss(name, value); } @Override public Map getParameterMap() { Map parameterMap = super.getParameterMap(); //因为super.getParameterMap()返回的是Map,所以我们需要定义Map的实现类对数据进行封装 Map params = new LinkedHashMap<>(); //如果参数不为空 if (parameterMap != null) { //对map进行遍历 for (Map.Entry entry : parameterMap.entrySet()) { //根据key获取value String[] values = entry.getValue(); //遍历数组 for (int i = 0; i < values.length; i++) { String value = values[i]; value = filterXss(entry.getKey(), value); //将转义后的数据放回数组中 values[i] = value; } //将转义后的数组put到linkMap当中 params.put(entry.getKey(), values); } } return params; } /** * 获取输入流 * * @return 过滤后的输入流 * @throws IOException 异常信息 */ @Override public ServletInputStream getInputStream() throws IOException { BufferedReader bufferedReader = null; InputStreamReader reader = null; //获取输入流 ServletInputStream in = null; try { in = super.getInputStream(); //用于存储输入流 StringBuilder body = new StringBuilder(); reader = new InputStreamReader(in, StandardCharsets.UTF_8); bufferedReader = new BufferedReader(reader); //按行读取输入流 String line = bufferedReader.readLine(); while (line != null) { //将获取到的第一行数据append到StringBuffer中 body.append(line); //继续读取下一行流,直到line为空 line = bufferedReader.readLine(); } if (CharSequenceUtil.isNotEmpty(body) && Boolean.TRUE.equals(JSONUtil.isJsonObj(body.toString()))) { //将body转换为map Map map = JSONUtil.parseObj(body.toString()); //创建空的map用于存储结果 Map resultMap = new HashMap<>(map.size()); //遍历数组 for (Map.Entry entry : map.entrySet()) { //如果map.get(key)获取到的是字符串就需要进行处理,如果不是直接存储resultMap if (map.get(entry.getKey()) instanceof String) { resultMap.put(entry.getKey(), filterXss(entry.getKey(), entry.getValue().toString())); } else { resultMap.put(entry.getKey(), entry.getValue()); } } //将resultMap转换为json字符串 String resultStr = JSONUtil.toJsonStr(resultMap); //将json字符串转换为字节 final ByteArrayInputStream resultBIS = new ByteArrayInputStream(resultStr.getBytes(StandardCharsets.UTF_8)); //实现接口 return new ServletInputStream() { @Override public boolean isFinished() { return false; } @Override public boolean isReady() { return false; } @Override public void setReadListener(ReadListener readListener) { } @Override public int read() { return resultBIS.read(); } }; } //将json字符串转换为字节 final ByteArrayInputStream bis = new ByteArrayInputStream(body.toString().getBytes()); //实现接口 return new ServletInputStream() { @Override public boolean isFinished() { return false; } @Override public boolean isReady() { return false; } @Override public void setReadListener(ReadListener readListener) { } @Override public int read() { return bis.read(); } }; } catch (Exception e) { log.error("get request inputStream error", e); return null; } finally { //关闭流 if (bufferedReader != null) { bufferedReader.close(); } if (reader != null) { reader.close(); } if (in != null) { in.close(); } } } private String cleanXSS(String value) { if (value != null) { // 自定义策略 PolicyFactory policy = new HtmlPolicyBuilder() .allowStandardUrlProtocols() //所有允许的标签 .allowElements(allowedTags) //内容标签转化为div .allowElements((elementName, attributes) -> "div", needTransformTags) .allowAttributes(allowAttributes).onElements(linkTags) .allowStyling() .toFactory(); // basic prepackaged policies for links, tables, integers, images, styles, blocks value = policy.sanitize(value); } return HtmlUtil.unescape(value); } /** * 过滤xss * * @param name 参数名 * @param value 参数值 * @return 参数值 */ private String filterXss(String name, String value) { return cleanXSS(value); } }