diff --git a/ppcls/arch/backbone/legendary_models/swin_transformer.py b/ppcls/arch/backbone/legendary_models/swin_transformer.py index a1c0906dcd60af898103f65bf409b9869063f561..412d76ba82c23f313c144d6cf22f16a68298a9f1 100644 --- a/ppcls/arch/backbone/legendary_models/swin_transformer.py +++ b/ppcls/arch/backbone/legendary_models/swin_transformer.py @@ -42,6 +42,9 @@ MODEL_URLS = { __all__ = list(MODEL_URLS.keys()) +# The following re-implementation of roll is inspired by +# https://gitee.com/ascend/pytorch/blob/master/torch_npu/contrib/function/roll.py + class RollWithIndexSelect(paddle.autograd.PyLayer): @staticmethod def forward(ctx, input1, index_fp, index_bp): @@ -62,6 +65,7 @@ class RollWithIndexSelect(paddle.autograd.PyLayer): roll_with_index_select = RollWithIndexSelect.apply def get_roll_index(H, W, shifts, place): + # following tensors will be created on cpu place with npu custom device index = paddle.arange(0, H * W, dtype='int64').reshape([H, W]) # cpu index_fp = paddle.roll(index, shifts=shifts, axis=(0, 1)).reshape([-1]) # cpu index_bp = {i:idx for idx, i in enumerate(index_fp.numpy().tolist())} @@ -85,7 +89,14 @@ class NpuRollWithIndexSelect(): index_fp, index_bp = self.index_dict[key] return roll_with_index_select(x, index_fp, index_bp) -roll = NpuRollWithIndexSelect() if 'npu' in paddle.device.get_all_custom_device_type() else paddle.roll +roll = None + +def _lazy_init_roll(x): + global roll + if 'npu' in paddle.device.get_all_custom_device_type() and hasattr(x, '_place_str') and 'npu' in x._place_str: + roll = NpuRollWithIndexSelect() + else: + roll = paddle.roll class Mlp(nn.Layer): def __init__(self, @@ -400,6 +411,9 @@ class SwinTransformerBlock(nn.Layer): # cyclic shift if self.shift_size > 0: + if roll is None: + _lazy_init_roll(x) + shifted_x = roll( x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2)) else: @@ -424,6 +438,9 @@ class SwinTransformerBlock(nn.Layer): # reverse cyclic shift if self.shift_size > 0: + if roll is None: + _lazy_init_roll(shifted_x) + x = roll( shifted_x, shifts=(self.shift_size, self.shift_size),