防篡改、防重放过滤器(实现 javax.servlet.Filter 接口)
package src.main.biz.village.filter;
import cn.hutool.core.convert.Convert;
import cn.hutool.core.date.DateUnit;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import src.main.biz.village.filter.sign.HttpUtil;
import src.main.biz.village.filter.sign.RequestWrapper;
import src.main.biz.village.filter.sign.SecurityProperties;
import src.main.biz.village.filter.sign.SignUtil;
import src.main.biz.village.utils.IPUtils;
import src.main.newgrand.framework.common.utils.ResultRes;
import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.HashSet;
import java.util.Set;
import java.util.SortedMap;
/**
 * Anti-tampering / anti-replay filter.
 *
 * <p>When enabled via {@code SecurityProperties.open}, a request must carry a {@code Sign} header
 * and a {@code Timestamp} header (seconds since the epoch). The timestamp must lie within
 * {@code signTimeout} minutes of server time, and {@code Sign} must match
 * {@link SignUtil#verifySign} computed over the query parameters (GET) or the JSON body
 * (POST/PUT/DELETE). Failures are answered with a JSON {@code ResultRes.fail(...)} body.
 *
 * <p>NOTE(review): verification is skipped when the User-Agent contains "mobile" (client
 * controlled, so trivially spoofable), when the URI is listed in {@code ignoreSignUri}, when the
 * client IP is whitelisted, and for HTTP methods other than GET/POST/PUT/DELETE. There is no nonce,
 * so a captured request can be replayed while it is inside the timeout window.
 */
@Slf4j
@Component
@Order(5)
public class SignAuthFilter implements Filter {

    @Autowired
    private SecurityProperties securityProperties;

    @Override
    public void init(FilterConfig filterConfig) {
        log.info("初始化 SignAuthFilter");
    }

    /**
     * Verifies the request signature. Then it either forwards the body-buffered request down the
     * chain or writes a failure response.
     *
     * @throws ServletException if verification is enabled but {@code signTimeout} or
     *     {@code secret} is not configured
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain)
            throws ServletException, IOException {
        HttpServletRequest httpRequest = (HttpServletRequest) request;
        // The body stream can only be consumed once; the wrapper buffers it so that both this
        // filter and the downstream handlers can read it.
        HttpServletRequest requestWrapper = new RequestWrapper(httpRequest);

        Integer open = securityProperties.getOpen();
        // Verification is OFF when "open" is not configured (null) or is 0; any other value enables it.
        if (open == null || open.intValue() == 0) {
            filterChain.doFilter(requestWrapper, response);
            return;
        }

        String userAgent = httpRequest.getHeader("User-Agent");
        if (userAgent != null && userAgent.toLowerCase().contains("mobile")) {
            // Mobile clients are not signature-checked.
            // NOTE(review): User-Agent is client-controlled, so any caller can bypass the check this way.
            filterChain.doFilter(requestWrapper, response);
            return;
        }

        if (isSignIgnored(httpRequest)) {
            filterChain.doFilter(requestWrapper, response);
            return;
        }

        String sign = requestWrapper.getHeader("Sign");
        Long timestamp = Convert.toLong(requestWrapper.getHeader("Timestamp"));
        if (StrUtil.isEmpty(sign)) {
            returnFail("签名不允许为空", response);
            return;
        }
        if (timestamp == null) {
            returnFail("时间戳不允许为空", response);
            return;
        }

        Integer signTimeout = securityProperties.getSignTimeout();
        String secret = securityProperties.getSecret();
        if (signTimeout == null || StrUtil.isEmpty(secret)) {
            // Fail closed. Without a timeout the expiry check cannot run. Without a secret anyone
            // could compute a "valid" signature.
            throw new ServletException("Signature checking is enabled but spring.security.sign-timeout "
                    + "or spring.security.secret is not configured");
        }

        // Replay window check. The Timestamp header is in seconds. hutool's DateUtil.between is
        // absolute by default, so timestamps too far in the future are rejected as well.
        long difference = DateUtil.between(DateUtil.date(), DateUtil.date(timestamp * 1000), DateUnit.MINUTE);
        if (difference > signTimeout) {
            returnFail("已过期的签名", response);
            log.info("前端时间戳:{},服务端时间戳:{}", DateUtil.date(timestamp * 1000), DateUtil.date());
            return;
        }

        if (verifySignature(requestWrapper, timestamp, secret, sign)) {
            filterChain.doFilter(requestWrapper, response);
        } else {
            returnFail("签名验证不通过", response);
        }
    }

    /**
     * Decides whether signature verification is skipped for this request. It is skipped when the
     * request URI exactly matches an entry of {@code ignoreSignUri}, or when the remote address or
     * the IPUtils client IP is listed in {@code whiteIP}. Unconfigured lists count as empty.
     */
    private boolean isSignIgnored(HttpServletRequest httpRequest) {
        String requestUri = httpRequest.getRequestURI();
        boolean ignoreSign = toSet(securityProperties.getIgnoreSignUri()).contains(requestUri);

        String ip1 = IPUtils.getClientIP(httpRequest);
        String ip2 = IPUtils.getClientRealIP(httpRequest);
        log.info("当前请求的ip1是==>{},ip2==>{}", ip1, ip2);
        // NOTE(review): IPUtils is not visible here. If getClientIP trusts proxy headers such as
        // X-Forwarded-For, the IP whitelist can be spoofed by the client — confirm.
        Set<String> ipSet = toSet(securityProperties.getWhiteIP());
        if (ipSet.contains(httpRequest.getRemoteAddr()) || ipSet.contains(ip1)) {
            ignoreSign = true;
        }
        log.info("当前请求的URI是==>{},ignoreSign==>{}", requestUri, ignoreSign);
        return ignoreSign;
    }

    /** Copies a configured list into a set, treating an unconfigured (null) list as empty. */
    private static Set<String> toSet(java.util.Collection<String> values) {
        return values == null ? new HashSet<>() : new HashSet<>(values);
    }

    /**
     * Checks {@code sign} against the query parameters (GET) or the JSON body (POST/PUT/DELETE).
     * Requests with any other HTTP method are accepted without verification.
     */
    private static boolean verifySignature(HttpServletRequest requestWrapper, Long timestamp,
                                           String secret, String sign) throws IOException {
        SortedMap<String, String> paramMap;
        switch (requestWrapper.getMethod()) {
            case "GET":
                paramMap = HttpUtil.getUrlParams(requestWrapper);
                break;
            case "POST":
            case "PUT":
            case "DELETE":
                // NOTE(review): only the body is signed; query parameters on these methods are not covered.
                paramMap = HttpUtil.getBodyParams(requestWrapper);
                break;
            default:
                // NOTE(review): other methods (e.g. PATCH) pass without any signature check.
                return true;
        }
        return SignUtil.verifySign(paramMap, timestamp, secret, sign);
    }

    /**
     * Writes {@code ResultRes.fail(msg)} as a UTF-8 JSON body and closes the writer.
     * The HTTP status is not changed here.
     */
    private void returnFail(String msg, ServletResponse response) throws IOException {
        response.setCharacterEncoding("UTF-8");
        response.setContentType("application/json; charset=utf-8");
        String result = JSONObject.toJSONString(ResultRes.fail(msg));
        try (PrintWriter out = response.getWriter()) {
            out.println(result);
            out.flush();
        }
    }

    @Override
    public void destroy() {
        log.info("销毁 SignAuthFilter");
    }
}
签名工具类
package src.main.biz.village.filter.sign;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.codec.digest.DigestUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.SortedMap;
/**
* 签名工具类
*/
@Slf4j
public class SignUtil {
public static boolean verifySign(SortedMap<String, String> params, Long timestamp, String secret, String sign) {
String paramsSign = getSign(params, timestamp, secret);
return sign.equals(paramsSign);
}
public static String getSign(SortedMap<String, String> params, Long timestamp, String secret) {
List<String> paramKeys = new ArrayList<>(params.keySet());
Collections.sort(paramKeys);
//校验逻辑
StringBuilder sb = new StringBuilder();
//拼接参数
for (String key : paramKeys) {
Object value = params.get(key);
sb.append(key).append("=").append(value).append("&");
}
//拼接secret
sb.append("timestamp=").append(timestamp).append("&").append("secret=").append(secret);
String paramsSign = "";
try {
log.info("拼接后加密前的字符串 : {}", sb.toString());
paramsSign = DigestUtils.md5Hex(sb.toString().toUpperCase());
log.info("Param Sign : {}", paramsSign);
} catch (Exception e) {
log.info("签名生成失败");
e.printStackTrace();
}
return paramsSign;
}
}
配置类:签名忽略地址、IP 白名单、密钥等
package src.main.biz.village.filter.sign;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
import java.util.List;
/**
 * Signature-verification settings bound from {@code spring.security.*}.
 *
 * <p>NOTE(review): Spring Boot's own security auto-configuration also binds the
 * {@code spring.security} prefix; a dedicated prefix would avoid confusion — confirm before changing,
 * since it would break existing configuration.
 */
@Component
@ConfigurationProperties(prefix = "spring.security")
@Data
public class SecurityProperties {
    /**
     * Request URIs that skip signature verification. SignAuthFilter compares them exactly against
     * getRequestURI(). Defaults to empty.
     */
    private List<String> ignoreSignUri = new java.util.ArrayList<>();
    /** Client IPs that skip signature verification. Defaults to empty. */
    private List<String> whiteIP = new java.util.ArrayList<>();
    /**
     * Maximum allowed difference between the client Timestamp and server time, in minutes.
     */
    private Integer signTimeout;
    /** Shared secret appended to the signing string. */
    private String secret;
    /** Master switch: null or 0 disables signature verification; any other value enables it. */
    private Integer open;
}
请求包装类:缓存请求体,使其在过滤器中读取后仍可被后续处理重复读取
package src.main.biz.village.filter.sign;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;
/**
 * Request wrapper that buffers the request body in memory so it can be read more than once.
 * For example, a filter reads it for signature checking and the handler reads it again.
 *
 * <p>NOTE(review): the body is buffered without a size limit. For form posts
 * (application/x-www-form-urlencoded), consuming the stream here may leave
 * {@code getParameter()} empty, depending on the container — confirm.
 */
public class RequestWrapper extends HttpServletRequestWrapper {
    /** The request body exactly as received (raw bytes, no decoding or re-encoding). */
    private final byte[] body;

    /**
     * Wraps {@code request} and reads its entire body into memory.
     *
     * @param request the request to wrap; its input stream is fully consumed
     * @throws UncheckedIOException if the body cannot be read completely
     */
    public RequestWrapper(HttpServletRequest request) {
        super(request);
        // Keep the raw bytes. Decoding line by line would drop line breaks and corrupt binary
        // bodies such as multipart uploads.
        body = readAllBytes(request);
    }

    /**
     * Reads the whole input stream of {@code request}. This is a static helper so the constructor
     * does not call an overridable method.
     */
    private static byte[] readAllBytes(ServletRequest request) {
        ByteArrayOutputStream collected = new ByteArrayOutputStream();
        byte[] chunk = new byte[4096];
        try {
            InputStream in = request.getInputStream();
            int n;
            while ((n = in.read(chunk)) != -1) {
                collected.write(chunk, 0, n);
            }
        } catch (IOException e) {
            // Processing a silently truncated body would be worse than failing the request.
            throw new UncheckedIOException("Failed to read request body", e);
        }
        return collected.toByteArray();
    }

    /**
     * Returns the body of {@code request} decoded as UTF-8, with the lines concatenated and the
     * line terminators removed. Read errors are printed and whatever was read is returned.
     * Kept for compatibility; this wrapper itself no longer uses it.
     *
     * @param request the request whose input stream is read
     * @return the concatenated body text
     */
    public String getBodyString(final ServletRequest request) {
        StringBuilder sb = new StringBuilder();
        try (
            InputStream inputStream = cloneInputStream(request.getInputStream());
            BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, Charset.forName("UTF-8")))
        ) {
            String line;
            while ((line = reader.readLine()) != null) {
                sb.append(line);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return sb.toString();
    }

    /**
     * Copies the remaining content of {@code inputStream} into an in-memory stream. On a read
     * error the stack trace is printed and only the bytes read so far are returned.
     * Kept for compatibility.
     *
     * @param inputStream the stream to drain
     * @return an in-memory stream over the copied bytes
     */
    public InputStream cloneInputStream(ServletInputStream inputStream) {
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        byte[] buffer = new byte[1024];
        int len;
        try {
            while ((len = inputStream.read(buffer)) > -1) {
                byteArrayOutputStream.write(buffer, 0, len);
            }
            byteArrayOutputStream.flush();
        } catch (IOException e) {
            e.printStackTrace();
        }
        return new ByteArrayInputStream(byteArrayOutputStream.toByteArray());
    }

    /** Returns a reader over the buffered body, decoded with the request's charset (UTF-8 if absent or unknown). */
    @Override
    public BufferedReader getReader() {
        return new BufferedReader(new InputStreamReader(getInputStream(), bodyCharset()));
    }

    /** The charset declared by the request, falling back to UTF-8 when missing or unsupported. */
    private Charset bodyCharset() {
        String encoding = getCharacterEncoding();
        if (encoding != null) {
            try {
                return Charset.forName(encoding);
            } catch (IllegalArgumentException ignored) {
                // Illegal or unsupported charset name declared by the client: use the default below.
            }
        }
        return Charset.forName("UTF-8");
    }

    /** Returns a fresh stream over the buffered body; each call starts from the beginning. */
    @Override
    public ServletInputStream getInputStream() {
        final ByteArrayInputStream source = new ByteArrayInputStream(body);
        return new ServletInputStream() {
            @Override
            public int read() {
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                return source.read(b, off, len);
            }

            @Override
            public int available() {
                return source.available();
            }

            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // All data is already in memory, so reads never block.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking (async) reads are not supported by this in-memory stream.
            }
        };
    }
}
MD5UtilSign工具类
package src.main.biz.village.filter.sign;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
/**
 * MD5 helper that formats digests as 32-character lowercase hex strings.
 *
 * <p>{@link MessageDigest} instances are not thread-safe, so one instance is cached per thread.
 * NOTE(review): MD5 is not collision-resistant; avoid it for new security-sensitive uses.
 */
public class MD5UtilSign {
    /** Per-thread MD5 instance, created lazily the first time each thread needs one. */
    private static final ThreadLocal<MessageDigest> messageDigestHolder = new ThreadLocal<>();
    static final char[] hexDigits = new char[]{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'};

    public MD5UtilSign() {
    }

    /**
     * Returns the MD5 digest of {@code data} (encoded as UTF-8) as lowercase hex.
     *
     * @param data text to hash
     * @return 32-character lowercase hex digest
     * @throws RuntimeException if {@code data} is null or MD5 is unavailable
     */
    public static String getMD5Format(String data) {
        try {
            // Explicit UTF-8 so the digest does not depend on the platform default charset.
            return md5Hex(data.getBytes(java.nio.charset.StandardCharsets.UTF_8));
        } catch (Exception var5) {
            throw new RuntimeException("MD5格式化时发生异常: " , var5);
        }
    }

    /**
     * Returns the MD5 digest of {@code data} as lowercase hex.
     *
     * @param data bytes to hash
     * @return 32-character lowercase hex digest, or {@code null} if {@code data} is null or hashing fails
     */
    public static String getMD5Format(byte[] data) {
        try {
            return md5Hex(data);
        } catch (Exception e) {
            return null;
        }
    }

    /** Hashes {@code data} with this thread's cached MD5 instance and hex-encodes the result. */
    private static String md5Hex(byte[] data) throws NoSuchAlgorithmException {
        MessageDigest message = messageDigestHolder.get();
        if (message == null) {
            message = MessageDigest.getInstance("MD5");
            messageDigestHolder.set(message);
        }
        // digest(byte[]) performs update + digest and resets the instance for the next call.
        byte[] digest = message.digest(data);
        StringBuilder hex = new StringBuilder(digest.length * 2);
        for (byte b : digest) {
            hex.append(byteHEX(b));
        }
        return hex.toString();
    }

    /** Two lowercase hex characters for one byte (high nibble first). */
    private static String byteHEX(byte ib) {
        return new String(new char[]{hexDigits[ib >>> 4 & 15], hexDigits[ib & 15]});
    }

    static {
        // Fail class initialization early if MD5 is unavailable. This also primes the cache for
        // the thread that initializes the class.
        try {
            MessageDigest message = MessageDigest.getInstance("MD5");
            messageDigestHolder.set(message);
        } catch (NoSuchAlgorithmException var1) {
            throw new RuntimeException("初始化java.security.MessageDigest失败:" , var1);
        }
    }
}
HTTP 工具类:获取请求中的参数
package src.main.biz.village.filter.sign;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSON;
import cn.hutool.json.JSONUtil;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.type.MapType;
import com.fasterxml.jackson.databind.type.TypeFactory;
import lombok.extern.slf4j.Slf4j;
import javax.servlet.http.HttpServletRequest;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
/**
 * HTTP helpers that extract request parameters as a key-sorted map for signature checking.
 */
@Slf4j
public class HttpUtil {
    /** Shared Jackson mapper; creating one per call is expensive and it is safe to reuse once configured. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    /**
     * Reads the request body (as UTF-8) and converts it to sorted parameters.
     * <ul>
     *   <li>An empty body is treated as {@code {}}.</li>
     *   <li>A JSON array yields a single entry {@code "array"} holding the raw body text.</li>
     *   <li>A JSON object yields its top-level keys. String values are kept as-is; other values are
     *       serialized to JSON text by Jackson.</li>
     * </ul>
     * A body that is not valid JSON makes hutool's parser throw.
     *
     * @param request request whose input stream is read (typically a body-buffering wrapper)
     * @return key-sorted parameters
     * @throws IOException if the body cannot be read
     */
    public static SortedMap<String, String> getBodyParams(final HttpServletRequest request) throws IOException {
        String body = readBody(request);
        if (StrUtil.isEmpty(body)) {
            body = "{}";
        }
        Object json = JSONUtil.parse(body);
        SortedMap<String, String> result = new TreeMap<>();
        if (json instanceof cn.hutool.json.JSONArray) {
            result.put("array", body);
        } else if (json instanceof JSON) {
            try {
                result = jsonToSortedMap(body);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        return result;
    }

    /** Reads the body as UTF-8, concatenating lines without their terminators. */
    private static String readBody(HttpServletRequest request) throws IOException {
        StringBuilder wholeStr = new StringBuilder();
        // JSON is UTF-8 per RFC 8259; avoid the platform default charset.
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(request.getInputStream(), java.nio.charset.StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                wholeStr.append(line);
            }
        }
        return wholeStr.toString();
    }

    /**
     * Parses the query string into sorted, URL-decoded parameters.
     * A parameter without {@code =} maps to an empty value, empty segments are skipped, and for a
     * repeated key the last value wins.
     *
     * @param request the request whose query string is parsed
     * @return key-sorted parameters (empty if there is no query string)
     */
    public static SortedMap<String, String> getUrlParams(HttpServletRequest request) {
        SortedMap<String, String> result = new TreeMap<>();
        String query = request.getQueryString();
        if (StrUtil.isEmpty(query)) {
            return result;
        }
        // Split first, then decode each part. Otherwise encoded '&' / '=' inside values would be
        // treated as separators.
        for (String pair : query.split("&")) {
            if (pair.isEmpty()) {
                continue;
            }
            int index = pair.indexOf('=');
            String key = index < 0 ? pair : pair.substring(0, index);
            String value = index < 0 ? "" : pair.substring(index + 1);
            result.put(urlDecode(key), urlDecode(value));
        }
        return result;
    }

    /** URL-decodes {@code s} as UTF-8 ('+' becomes a space). */
    private static String urlDecode(String s) {
        try {
            return URLDecoder.decode(s, "utf-8");
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException("UTF-8 is always supported", e);
        }
    }

    /**
     * Converts a JSON object string into a key-sorted map of top-level entries. String values are
     * kept as-is; all other values, including nested objects, arrays, numbers and null, are
     * serialized back to JSON text with Jackson.
     * NOTE(review): nested objects are serialized in the key order Jackson produces, and clients
     * must reproduce that exactly when signing — confirm.
     *
     * @param jsonString a JSON object
     * @return key-sorted map (empty for the JSON literal {@code null})
     * @throws Exception if the input is not a JSON object
     */
    public static SortedMap<String, String> jsonToSortedMap(String jsonString) throws Exception {
        TypeFactory factory = MAPPER.getTypeFactory();
        MapType type = factory.constructMapType(TreeMap.class, String.class, Object.class);
        Map<String, Object> map = MAPPER.readValue(jsonString, type);
        SortedMap<String, String> sortedMap = new TreeMap<>();
        if (map == null) {
            return sortedMap;
        }
        for (Map.Entry<String, Object> entry : map.entrySet()) {
            Object value = entry.getValue();
            sortedMap.put(entry.getKey(), value instanceof String ? (String) value : MAPPER.writeValueAsString(value));
        }
        return sortedMap;
    }
}