diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 9facd8e917273b3f420404a5687b9714e3319408..eef98d96632fa3330813ed1b46816ab0bc151c44 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -975,6 +975,29 @@ class BatchNorm(Layer): ) self._variance.stop_gradient = True + # TODO(qili93): temporary for Ascend NPU performance, to be removed along with npu_identity op + if ( + _global_flags()['FLAGS_npu_storage_format'] + and 'npu' in get_all_custom_device_type() + ): + with no_grad(): + weight_trans = _C_ops.npu_identity( + self.weight, 3 + ) # ACL_FORMAT_NC1HWC0 = 3 + bias_trans = _C_ops.npu_identity( + self.bias, 3 + ) # ACL_FORMAT_NC1HWC0 = 3 + mean_trans = _C_ops.npu_identity( + self._mean, 3 + ) # ACL_FORMAT_NC1HWC0 = 3 + var_trans = _C_ops.npu_identity( + self._variance, 3 + ) # ACL_FORMAT_NC1HWC0 = 3 + weight_trans._share_underline_tensor_to(self.weight) + bias_trans._share_underline_tensor_to(self.bias) + mean_trans._share_underline_tensor_to(self._mean) + var_trans._share_underline_tensor_to(self._variance) + self._in_place = in_place self._data_layout = data_layout self._momentum = momentum