// AuthState.java
package me.zhyd.oauth.utils;

import cn.hutool.core.codec.Base64;
import cn.hutool.core.util.RandomUtil;
import com.alibaba.fastjson.JSON;
import lombok.extern.slf4j.Slf4j;
import me.zhyd.oauth.config.AuthSource;
import me.zhyd.oauth.exception.AuthException;
import me.zhyd.oauth.model.AuthResponseStatus;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.ConcurrentHashMap;

/**
 * state工具,负责创建、获取和删除state
 *
 * @author yadong.zhang (yadong.zhang0415(a)gmail.com)
 * @version 1.0
 * @since 1.8
 */
@Slf4j
public class AuthState {

    /**
     * 空字符串
     */
    private static final String EMPTY_STR = "";

    /**
     * state存储器
     */
    private static ConcurrentHashMap<String, String> stateBucket = new ConcurrentHashMap<>();

34 35 36 37 38 39 40 41 42 43
    /**
     * 生成随机的state
     *
     * @param source oauth平台
     * @return state
     */
    public static String create(AuthSource source) {
        return create(source.name());
    }

44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
    /**
     * 生成随机的state
     *
     * @param source oauth平台
     * @return state
     */
    public static String create(String source) {
        return create(source, RandomUtil.randomString(4));
    }

    /**
     * 创建state
     *
     * @param source oauth平台
     * @param body   希望加密到state的消息体
     * @return state
     */
    public static String create(String source, Object body) {
        return create(source, JSON.toJSONString(body));
    }

    /**
     * 创建state
     *
     * @param source oauth平台
     * @param body   希望加密到state的消息体
     * @return state
     */
    public static String create(String source, String body) {
        String currentIp = getCurrentIp();
        String simpleKey = ((source + currentIp));
        String key = Base64.encode(simpleKey.getBytes(Charset.forName("UTF-8")));
        log.debug("Create the state: ip={}, platform={}, simpleKey={}, key={}, body={}", currentIp, source, simpleKey, key, body);

        if (stateBucket.containsKey(key)) {
            log.debug("Get from bucket: {}", stateBucket.get(key));
            return stateBucket.get(key);
        }

        String simpleState = source + "_" + currentIp + "_" + body;
        String state = Base64.encode(simpleState.getBytes(Charset.forName("UTF-8")));
        log.debug("Create a new state: {}", state, simpleState);
        stateBucket.put(key, state);
        return state;
    }

    /**
     * 获取state
     *
     * @param source oauth平台
     * @return state
     */
    public static String get(String source) {
        String currentIp = getCurrentIp();
        String simpleKey = ((source + currentIp));
        String key = Base64.encode(simpleKey.getBytes(Charset.forName("UTF-8")));
        log.debug("Get state by the key[{}], current ip[{}]", key, currentIp);
        return stateBucket.get(key);
    }

    /**
     * 获取state中保存的body内容
     *
     * @param source oauth平台
     * @param state  加密后的state
     * @param clazz  body的实际类型
智布道's avatar
智布道 已提交
110
     * @param <T>    需要转换的具体的class类型
111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
     * @return state
     */
    public static <T> T getBody(String source, String state, Class<T> clazz) {
        if (StringUtils.isEmpty(state) || null == clazz) {
            return null;
        }
        log.debug("Get body from the state[{}] of the {} and convert it to {}", state, source, clazz.toString());
        String currentIp = getCurrentIp();
        String decodedState = Base64.decodeStr(state);
        log.debug("The decoded state is [{}]", decodedState);
        if (!decodedState.startsWith(source)) {
            return null;
        }
        String noneSourceState = decodedState.substring(source.length() + 1);
        if (!noneSourceState.startsWith(currentIp)) {
            // ip不相同,可能为非法的请求
智布道's avatar
智布道 已提交
127
            throw new AuthException(AuthResponseStatus.ILLEGAL_REQUEST);
128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
        }
        String body = noneSourceState.substring(currentIp.length() + 1);
        log.debug("body is [{}]", body);
        if (clazz == String.class) {
            return (T) body;
        }
        if (clazz == Integer.class) {
            return (T) Integer.valueOf(Integer.parseInt(body));
        }
        if (clazz == Long.class) {
            return (T) Long.valueOf(Long.parseLong(body));
        }
        if (clazz == Short.class) {
            return (T) Short.valueOf(Short.parseShort(body));
        }
        if (clazz == Double.class) {
            return (T) Double.valueOf(Double.parseDouble(body));
        }
        if (clazz == Float.class) {
            return (T) Float.valueOf(Float.parseFloat(body));
        }
        if (clazz == Boolean.class) {
            return (T) Boolean.valueOf(Boolean.parseBoolean(body));
        }
        if (clazz == Byte.class) {
            return (T) Byte.valueOf(Byte.parseByte(body));
        }
        return JSON.parseObject(body, clazz);
    }

    /**
     * 登录成功后,清除state
     *
     * @param source oauth平台
     */
    public static void delete(String source) {
        String currentIp = getCurrentIp();

        String simpleKey = ((source + currentIp));
        String key = Base64.encode(simpleKey.getBytes(Charset.forName("UTF-8")));
        log.debug("Delete used state[{}] by the key[{}], current ip[{}]", stateBucket.get(key), key, currentIp);
        stateBucket.remove(key);
    }

172 173 174 175 176 177 178 179 180
    /**
     * 登录成功后,清除state
     *
     * @param source oauth平台
     */
    public static void delete(AuthSource source) {
        delete(source.name());
    }

181 182 183 184 185
    private static String getCurrentIp() {
        String currentIp = IpUtils.getIp();
        return StringUtils.isEmpty(currentIp) ? EMPTY_STR : currentIp;
    }
}