From 6cd988eaa0c58587eda953976e404c104f176179 Mon Sep 17 00:00:00 2001 From: tomsun28 Date: Wed, 12 Aug 2020 23:13:47 +0800 Subject: [PATCH] reconstruct security subject holder - ThreadContext to SurenessContextHolder --- .../com/usthe/sureness/subject/Subject.java | 7 +- .../sureness/util/SurenessContextHolder.java | 100 ++++++++++++++++++ .../usthe/sureness/util/ThreadContext.java | 99 ----------------- 3 files changed, 101 insertions(+), 105 deletions(-) create mode 100644 core/src/main/java/com/usthe/sureness/util/SurenessContextHolder.java delete mode 100644 core/src/main/java/com/usthe/sureness/util/ThreadContext.java diff --git a/core/src/main/java/com/usthe/sureness/subject/Subject.java b/core/src/main/java/com/usthe/sureness/subject/Subject.java index ed04796..e2b4556 100644 --- a/core/src/main/java/com/usthe/sureness/subject/Subject.java +++ b/core/src/main/java/com/usthe/sureness/subject/Subject.java @@ -2,7 +2,6 @@ package com.usthe.sureness.subject; import com.usthe.sureness.subject.support.SurenessSubjectSum; -import com.usthe.sureness.util.ThreadContext; import java.io.Serializable; import java.util.List; @@ -68,14 +67,10 @@ public interface Subject extends Serializable { String principal = (String)getPrincipal(); List roles = (List)getOwnRoles(); String targetUri = (String)getTargetResource(); - SubjectSum subject = SurenessSubjectSum.builder() + return SurenessSubjectSum.builder() .setTargetResource(targetUri) .setRoles(roles) .setPrincipal(principal) .build(); - // 将subject 绑定到localThread变量中 - ThreadContext.bind(subject); - // 如果是网关认证中心, 之后可以考虑把subject绑定到request请求中,供子系统使用 - return subject; } } diff --git a/core/src/main/java/com/usthe/sureness/util/SurenessContextHolder.java b/core/src/main/java/com/usthe/sureness/util/SurenessContextHolder.java new file mode 100644 index 0000000..6fc4215 --- /dev/null +++ b/core/src/main/java/com/usthe/sureness/util/SurenessContextHolder.java @@ -0,0 +1,100 @@ +package com.usthe.sureness.util; + +import com.usthe.sureness.subject.Subject; +import com.usthe.sureness.subject.SubjectSum; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; + +/** + * learn from ThreadContext + * @author from shiro + * @date 23:01 2019-01-09 + */ +public class SurenessContextHolder { + + private static final Logger logger = LoggerFactory.getLogger(SurenessContextHolder.class); + + public static final String SUBJECT_KEY = "SUBJECT_KEY"; + + private static final ThreadLocal> RESOURCES = InheritableThreadLocal + .withInitial(() -> new HashMap<>(8)); + + /** + * 线程结束前调用 清空内容 防止oom + */ + public static void clear() { + RESOURCES.remove(); + } + + public static void bind(Object key, Object value) { + internalPut(key, value); + } + + public static void unbind(Object key) { + if (key != null) { + internalRemove(key); + } + } + + public static Object getBind(Object key) { + if (key == null) { + return null; + } + return internalGet(key); + } + + public static void bindSubject(SubjectSum subjectSum) { + internalPut(SUBJECT_KEY, subjectSum); + } + + public static void bindSubject(Subject subject) { + if (subject != null) { + internalPut(SUBJECT_KEY, subject.generateSubjectSummary()); + } + } + + public static void unbindSubject() { + internalRemove(SUBJECT_KEY); + } + + + public static SubjectSum getBindSubject() { + return (SubjectSum) internalGet(SUBJECT_KEY); + } + + + private static void internalPut(Object key, Object value) { + if (key == null) { + throw new NullPointerException("key cannot be null"); + } else if (value == null) { + internalRemove(key); + } else { + ensureResourcesInitialized(); + RESOURCES.get().put(key, value); + } + } + + private static Object internalGet(Object key) { + if (logger.isTraceEnabled()) { + logger.trace("get() - in thread [{}]", Thread.currentThread().getName()); + } + Map perThreadResources = RESOURCES.get(); + return perThreadResources != null ? perThreadResources.get(key) : null; + } + + private static void internalRemove(Object key) { + Map perThreadResources = RESOURCES.get(); + if (perThreadResources != null) { + perThreadResources.remove(key); + } + } + + private static void ensureResourcesInitialized() { + if (RESOURCES.get() == null) { + RESOURCES.set(new HashMap<>(8)); + } + } +} diff --git a/core/src/main/java/com/usthe/sureness/util/ThreadContext.java b/core/src/main/java/com/usthe/sureness/util/ThreadContext.java deleted file mode 100644 index 9528f5e..0000000 --- a/core/src/main/java/com/usthe/sureness/util/ThreadContext.java +++ /dev/null @@ -1,99 +0,0 @@ -package com.usthe.sureness.util; - -import com.usthe.sureness.subject.SubjectSum; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.HashMap; -import java.util.Map; - -/** - * learn from shiro ThreadContext - * @author from shiro - * @date 23:01 2019-01-09 - */ -public class ThreadContext { - - private static final Logger logger = LoggerFactory.getLogger(ThreadContext.class); - - public static final String SUBJECT_KEY = ThreadContext.class.getName() + "_SUBJECT_KEY"; - @SuppressWarnings("unchecked") - private static final ThreadLocal> RESOURCES = new ThreadContext.InheritableThreadLocalMap(); - - public static void bind(SubjectSum subject) { - if (subject != null) { - put(SUBJECT_KEY, subject); - } - } - - public static SubjectSum unbindSubject() { - return (SubjectSum)remove(SUBJECT_KEY); - } - - public static SubjectSum getBindSubject() { - return (SubjectSum)get(SUBJECT_KEY); - } - - - private static void put(Object key, Object value) { - if (key == null) { - throw new IllegalArgumentException("key cannot be null"); - } else if (value == null) { - remove(key); - } else { - ensureResourcesInitialized(); - (RESOURCES.get()).put(key, value); - if (logger.isTraceEnabled()) { - String msg = "Bound value of type [" + value.getClass().getName() + "] for key [" + key + "] to thread [" + Thread.currentThread().getName() + "]"; - logger.trace(msg); - } - } - } - - private static Object get(Object key) { - if (logger.isTraceEnabled()) { - String msg = "get() - in thread [" + Thread.currentThread().getName() + "]"; - logger.trace(msg); - } - Map perThreadResources = RESOURCES.get(); - Object value = perThreadResources != null ? perThreadResources.get(key) : null; - if (value != null && logger.isTraceEnabled()) { - String msg = "Retrieved value of type [" + value.getClass().getName() + "] for key [" + key + "] bound to thread [" + Thread.currentThread().getName() + "]"; - logger.trace(msg); - } - return value; - } - - private static Object remove(Object key) { - Map perThreadResources = RESOURCES.get(); - Object value = perThreadResources != null ? perThreadResources.remove(key) : null; - if (value != null && logger.isTraceEnabled()) { - String msg = "Removed value of type [" + value.getClass().getName() + "] for key [" + key + "]from thread [" + Thread.currentThread().getName() + "]"; - logger.trace(msg); - } - return value; - } - - @SuppressWarnings("unchecked") - private static void ensureResourcesInitialized() { - if (RESOURCES.get() == null) { - RESOURCES.set(new HashMap(8)); - } - - } - - public static void remove() { - RESOURCES.remove(); - } - - private static final class InheritableThreadLocalMap> extends InheritableThreadLocal> { - private InheritableThreadLocalMap() { - } - - @Override - @SuppressWarnings("unchecked") - protected Map childValue(Map parentValue) { - return parentValue != null ? (Map)((HashMap)parentValue).clone() : null; - } - } -} -- GitLab