diff --git a/ppcls/arch/backbone/legendary_models/swin_transformer.py b/ppcls/arch/backbone/legendary_models/swin_transformer.py index 3d83d0e83d24a271d2a8f9204e302e9c6c0016cd..c2b3cb5ea746d0b658835af63a1d0093d2a507b4 100644 --- a/ppcls/arch/backbone/legendary_models/swin_transformer.py +++ b/ppcls/arch/backbone/legendary_models/swin_transformer.py @@ -72,13 +72,14 @@ def get_roll_index(H, W, shifts, place): index_bp = paddle.to_tensor(index_fp, dtype='int64', place=place) return [index_fp, index_bp] -def singleton(class_): - instances = {} - def getinstance(*args, **kwargs): - if class_ not in instances: - instances[class_] = class_(*args, **kwargs) - return instances[class_] - return getinstance + +def singleton(cls): + def wrapper_singleton(*args, **kwargs): + if not wrapper_singleton.instance: + wrapper_singleton.instance = cls(*args, **kwargs) + return wrapper_singleton.instance + wrapper_singleton.instance = None + return wrapper_singleton @singleton class RollWrapperSingleton():