提交 fd4a97d1 编写于 作者: K kuizhiqing 提交者: Wei Shengyu

lazy roll

上级 a49e11db
...@@ -42,6 +42,9 @@ MODEL_URLS = { ...@@ -42,6 +42,9 @@ MODEL_URLS = {
__all__ = list(MODEL_URLS.keys()) __all__ = list(MODEL_URLS.keys())
# The following re-implementation of roll is inspired by
# https://gitee.com/ascend/pytorch/blob/master/torch_npu/contrib/function/roll.py
class RollWithIndexSelect(paddle.autograd.PyLayer): class RollWithIndexSelect(paddle.autograd.PyLayer):
@staticmethod @staticmethod
def forward(ctx, input1, index_fp, index_bp): def forward(ctx, input1, index_fp, index_bp):
...@@ -62,6 +65,7 @@ class RollWithIndexSelect(paddle.autograd.PyLayer): ...@@ -62,6 +65,7 @@ class RollWithIndexSelect(paddle.autograd.PyLayer):
roll_with_index_select = RollWithIndexSelect.apply roll_with_index_select = RollWithIndexSelect.apply
def get_roll_index(H, W, shifts, place): def get_roll_index(H, W, shifts, place):
# following tensors will be created on cpu place with npu custom device
index = paddle.arange(0, H * W, dtype='int64').reshape([H, W]) # cpu index = paddle.arange(0, H * W, dtype='int64').reshape([H, W]) # cpu
index_fp = paddle.roll(index, shifts=shifts, axis=(0, 1)).reshape([-1]) # cpu index_fp = paddle.roll(index, shifts=shifts, axis=(0, 1)).reshape([-1]) # cpu
index_bp = {i:idx for idx, i in enumerate(index_fp.numpy().tolist())} index_bp = {i:idx for idx, i in enumerate(index_fp.numpy().tolist())}
...@@ -85,7 +89,14 @@ class NpuRollWithIndexSelect(): ...@@ -85,7 +89,14 @@ class NpuRollWithIndexSelect():
index_fp, index_bp = self.index_dict[key] index_fp, index_bp = self.index_dict[key]
return roll_with_index_select(x, index_fp, index_bp) return roll_with_index_select(x, index_fp, index_bp)
roll = NpuRollWithIndexSelect() if 'npu' in paddle.device.get_all_custom_device_type() else paddle.roll roll = None
def _lazy_init_roll(x):
global roll
if 'npu' in paddle.device.get_all_custom_device_type() and hasattr(x, '_place_str') and 'npu' in x._place_str:
roll = NpuRollWithIndexSelect()
else:
roll = paddle.roll
class Mlp(nn.Layer): class Mlp(nn.Layer):
def __init__(self, def __init__(self,
...@@ -400,6 +411,9 @@ class SwinTransformerBlock(nn.Layer): ...@@ -400,6 +411,9 @@ class SwinTransformerBlock(nn.Layer):
# cyclic shift # cyclic shift
if self.shift_size > 0: if self.shift_size > 0:
if roll is None:
_lazy_init_roll(x)
shifted_x = roll( shifted_x = roll(
x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2)) x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2))
else: else:
...@@ -424,6 +438,9 @@ class SwinTransformerBlock(nn.Layer): ...@@ -424,6 +438,9 @@ class SwinTransformerBlock(nn.Layer):
# reverse cyclic shift # reverse cyclic shift
if self.shift_size > 0: if self.shift_size > 0:
if roll is None:
_lazy_init_roll(shifted_x)
x = roll( x = roll(
shifted_x, shifted_x,
shifts=(self.shift_size, self.shift_size), shifts=(self.shift_size, self.shift_size),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册