提交 17a4b50a 编写于 作者: K kuizhiqing 提交者: Wei Shengyu

rm singleton

上级 9d69b63c
......@@ -73,25 +73,12 @@ def get_roll_index(H, W, shifts, place):
return [index_fp, index_bp]
def singleton(cls):
def wrapper_singleton(*args, **kwargs):
if not wrapper_singleton.instance:
wrapper_singleton.instance = cls(*args, **kwargs)
return wrapper_singleton.instance
wrapper_singleton.instance = None
return wrapper_singleton
@singleton
class RollWrapperSingleton():
class NpuRollWithIndexSelect():
def __init__(self):
self.index_dict = {}
self.roll_with_index_select = RollWithIndexSelect.apply
self.enable = True if 'npu' in paddle.device.get_all_custom_device_type() else False
def __call__(self, x, shifts, axis):
if not self.enable:
return paddle.roll(x, shifts, axis)
assert x.dim() == 4
assert len(shifts) == 2
assert len(axis) == 2
......@@ -103,6 +90,18 @@ class RollWrapperSingleton():
return self.roll_with_index_select(x, index_fp, index_bp)
class RollWrapper(object):
_roll = None
@staticmethod
def roll(x, shifts, axis):
if RollWrapper._roll is None:
RollWrapper._roll = NpuRollWithIndexSelect() if 'npu' in paddle.device.get_all_custom_device_type() else paddle.roll
return RollWrapper._roll(x, shifts, axis)
class Mlp(nn.Layer):
def __init__(self,
in_features,
......@@ -414,10 +413,9 @@ class SwinTransformerBlock(nn.Layer):
x = self.norm1(x)
x = x.reshape([B, H, W, C])
roll = RollWrapperSingleton()
# cyclic shift
if self.shift_size > 0:
shifted_x = roll(
shifted_x = RollWrapper.roll(
x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2))
else:
shifted_x = x
......@@ -441,7 +439,7 @@ class SwinTransformerBlock(nn.Layer):
# reverse cyclic shift
if self.shift_size > 0:
x = roll(
x = RollWrapper.roll(
shifted_x,
shifts=(self.shift_size, self.shift_size),
axis=(1, 2))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册