peng
9 天以前 75f9783d5a70a5f037e3b34dc0e479069e63c0e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
package cn.lili.common.security.filter;
 
 
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.http.HtmlUtil;
import cn.hutool.json.JSONUtil;
import lombok.extern.slf4j.Slf4j;
import org.owasp.html.HtmlPolicyBuilder;
import org.owasp.html.PolicyFactory;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
 
/**
 * 防止Xss
 *
 * @author Chopper
 * @version v1.0
 * 2021-06-04 10:39
 */
@Slf4j
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {
 
    //允许的标签
    private static final String[] allowedTags = {"h1", "h2", "h3", "h4", "h5", "h6",
            "span", "strong",
            "img", "video", "source", "iframe", "code",
            "blockquote", "p", "div",
            "ul", "ol", "li",
            "table", "thead", "caption", "tbody", "tr", "th", "td", "br",
            "a"
    };
 
    //需要转化的标签
    private static final String[] needTransformTags = {"article", "aside", "command", "datalist", "details", "figcaption", "figure",
            "footer", "header", "hgroup", "section", "summary"};
 
    //带有超链接的标签
    private static final String[] linkTags = {"img", "video", "source", "a", "iframe", "p"};
 
    //带有超链接的标签
    private static final String[] allowAttributes = {"style", "src", "href", "target", "width", "height"};
 
    public XssHttpServletRequestWrapper(HttpServletRequest request) {
        super(request);
    }
 
    /**
     * 对数组参数进行特殊字符过滤
     */
    @Override
    public String[] getParameterValues(String name) {
        String[] values = super.getParameterValues(name);
        if (values == null) {
            return new String[0];
        }
        int count = values.length;
        String[] encodedValues = new String[count];
        for (int i = 0; i < count; i++) {
            encodedValues[i] = filterXss(name, values[i]);
        }
        return encodedValues;
    }
 
    /**
     * 对参数中特殊字符进行过滤
     */
    @Override
    public String getParameter(String name) {
        String value = super.getParameter(name);
        if (value == null) {
            return null;
        }
        return filterXss(name, value);
    }
 
    /**
     * 获取attribute,特殊字符过滤
     */
    @Override
    public Object getAttribute(String name) {
        Object value = super.getAttribute(name);
        if (value instanceof String) {
            value = filterXss(name, (String) value);
        }
        return value;
    }
 
    /**
     * 对请求头部进行特殊字符过滤
     */
    @Override
    public String getHeader(String name) {
        String value = super.getHeader(name);
        if (value == null) {
            return null;
        }
        return filterXss(name, value);
    }
 
    @Override
    public Map<String, String[]> getParameterMap() {
        Map<String, String[]> parameterMap = super.getParameterMap();
        //因为super.getParameterMap()返回的是Map,所以我们需要定义Map的实现类对数据进行封装
        Map<String, String[]> params = new LinkedHashMap<>();
        //如果参数不为空
        if (parameterMap != null) {
            //对map进行遍历
            for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
                //根据key获取value
                String[] values = entry.getValue();
                //遍历数组
                for (int i = 0; i < values.length; i++) {
                    String value = values[i];
                    value = filterXss(entry.getKey(), value);
                    //将转义后的数据放回数组中
                    values[i] = value;
                }
 
                //将转义后的数组put到linkMap当中
                params.put(entry.getKey(), values);
            }
        }
        return params;
    }
 
    /**
     * 获取输入流
     *
     * @return 过滤后的输入流
     * @throws IOException 异常信息
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
 
        BufferedReader bufferedReader = null;
 
        InputStreamReader reader = null;
 
        //获取输入流
        ServletInputStream in = null;
        try {
            in = super.getInputStream();
            //用于存储输入流
            StringBuilder body = new StringBuilder();
            reader = new InputStreamReader(in, StandardCharsets.UTF_8);
            bufferedReader = new BufferedReader(reader);
            //按行读取输入流
            String line = bufferedReader.readLine();
            while (line != null) {
                //将获取到的第一行数据append到StringBuffer中
                body.append(line);
                //继续读取下一行流,直到line为空
                line = bufferedReader.readLine();
            }
            if (CharSequenceUtil.isNotEmpty(body) && Boolean.TRUE.equals(JSONUtil.isJsonObj(body.toString()))) {
                //将body转换为map
                Map<String, Object> map = JSONUtil.parseObj(body.toString());
                //创建空的map用于存储结果
                Map<String, Object> resultMap = new HashMap<>(map.size());
                //遍历数组
                for (Map.Entry<String, Object> entry : map.entrySet()) {
                    //如果map.get(key)获取到的是字符串就需要进行处理,如果不是直接存储resultMap
                    if (map.get(entry.getKey()) instanceof String) {
                        resultMap.put(entry.getKey(), filterXss(entry.getKey(), entry.getValue().toString()));
                    } else {
                        resultMap.put(entry.getKey(), entry.getValue());
                    }
                }
 
                //将resultMap转换为json字符串
                String resultStr = JSONUtil.toJsonStr(resultMap);
                //将json字符串转换为字节
                final ByteArrayInputStream resultBIS = new ByteArrayInputStream(resultStr.getBytes(StandardCharsets.UTF_8));
 
                //实现接口
                return new ServletInputStream() {
                    @Override
                    public boolean isFinished() {
                        return false;
                    }
 
                    @Override
                    public boolean isReady() {
                        return false;
                    }
 
                    @Override
                    public void setReadListener(ReadListener readListener) {
                    }
 
                    @Override
                    public int read() {
                        return resultBIS.read();
                    }
                };
            }
 
            //将json字符串转换为字节
            final ByteArrayInputStream bis = new ByteArrayInputStream(body.toString().getBytes());
            //实现接口
            return new ServletInputStream() {
                @Override
                public boolean isFinished() {
                    return false;
                }
 
                @Override
                public boolean isReady() {
                    return false;
                }
 
                @Override
                public void setReadListener(ReadListener readListener) {
 
                }
 
                @Override
                public int read() {
                    return bis.read();
                }
            };
        } catch (Exception e) {
 
            log.error("get request inputStream error", e);
            return null;
        } finally {
            //关闭流
            if (bufferedReader != null) {
                bufferedReader.close();
            }
            if (reader != null) {
                reader.close();
            }
            if (in != null) {
                in.close();
            }
        }
 
    }
 
    private String cleanXSS(String value) {
        if (value != null) {
            // 自定义策略
            PolicyFactory policy = new HtmlPolicyBuilder()
                    .allowStandardUrlProtocols()
                    //所有允许的标签
                    .allowElements(allowedTags)
                    //内容标签转化为div
                    .allowElements((elementName, attributes) -> "div", needTransformTags)
                    .allowAttributes(allowAttributes).onElements(linkTags)
                    .allowStyling()
                    .toFactory();
            // basic prepackaged policies for links, tables, integers, images, styles, blocks
            value = policy.sanitize(value);
        }
        return HtmlUtil.unescape(value);
    }
 
    /**
     * 过滤xss
     *
     * @param name  参数名
     * @param value 参数值
     * @return 参数值
     */
    private String filterXss(String name, String value) {
        return cleanXSS(value);
    }
 
}