Unverified Commit 0f99debd authored by ronnywang and committed by GitHub

Revert "remove ASCEND* keyword" (#53131)

* Revert "remove ASCEND* keyword (#53046)"

This reverts commit 7fa415ca.

* Delete ascend_trigger_op.cc

* revert-53046-remove_ASCEND_keyword

* update

* update
Parent e8e9d6c5
......@@ -1153,6 +1153,16 @@ PADDLE_DEFINE_EXPORTED_string(jit_engine_type,
"Predictor",
"Choose default funciton type in JitLayer.");
/**
* Custom Device NPU related FLAG
* Name: FLAGS_npu_storage_format
* Since Version: 2.5.0
* Value Range: bool, default=false
* Example:
* Note: Enable NPU Storage Format for Ascend910 performance improvement.
*/
PADDLE_DEFINE_EXPORTED_bool(npu_storage_format, false, "");
#ifdef PADDLE_WITH_CUDNN_FRONTEND
/**
* CUDNNv8 related FLAG
......
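A minimal sketch of toggling this flag from Python, assuming the usual paddle.set_flags/get_flags interface for exported FLAGS (the flag can also be set through the FLAGS_npu_storage_format environment variable before startup):

    import paddle

    # Exported bool flag; off by default, intended only for Ascend NPU custom devices.
    paddle.set_flags({'FLAGS_npu_storage_format': True})
    print(paddle.get_flags(['FLAGS_npu_storage_format']))
    # {'FLAGS_npu_storage_format': True}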
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/npu_identity_kernel.h"
#include "glog/logging.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
namespace phi {
template <typename T, typename Context>
void NPUIdentityKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const int format,
                       DenseTensor* out) {
  VLOG(4) << "npu_identity op is only for NPU, please avoid using this kernel!";
  out->ShareDataWith(x);
}
} // namespace phi
/** [ Why need npu_identity op? ]
 *
 * 1. Ascend CANN uses internal storage formats for high-performance
 * computing; for example, if the BatchNorm2D op runs with the CANN internal
 * storage format ACL_FORMAT_NC1HWC0, the time cost of transdata is
 * removed and it can gain roughly a 2x performance improvement.
 *
 * 2. The internal storage format uses storage_properties_ in
 * DenseTensor and changes the size and layout of the dense tensor, so
 * this op should be called when converting the tensor to numpy, to
 * restore the original size and format via the CANN Identity OP.
 *
 * TODO(qili93): remove this op after custom op and custom device are
 * integrated, then move this op along with its code to the plugin.
 */
PD_REGISTER_KERNEL(npu_identity,
CPU,
ALL_LAYOUT,
phi::NPUIdentityKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool,
phi::dtype::float16) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(npu_identity,
GPU,
ALL_LAYOUT,
phi::NPUIdentityKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool,
phi::dtype::float16) {}
#endif
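A minimal usage sketch of the kernel registered above, assuming a build that exposes the npu_identity op: on CPU/GPU it is a pass-through that only shares the input's storage, so the tensor comes back unchanged; the format argument only takes effect on Ascend NPU.

    import paddle

    x = paddle.ones([6])
    # Pass-through on CPU/GPU: output shares data with x; no layout change happens here.
    y = paddle.incubate._npu_identity(x, -1)
    print(y.shape)  # [6]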
......@@ -269,6 +269,11 @@ class ShardingOptimizer(MetaOptimizerBase):
        self._gradient_merge_acc_step = gm_acc_step
        self._optimizer_sharding = optimizer_sharding

        # this feature is designed for Ascend, and should NOT be used in GPU training
        self.pp_allreduce_in_optimize = sharding_configs[
            "pp_allreduce_in_optimize"
        ]

    def _inner_opt_minimize(
        self, loss, startup_program, parameter_list, no_grad_set
    ):
......
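For context, pp_allreduce_in_optimize is one of the fleet sharding_configs keys; a hedged sketch of passing it through DistributedStrategy (the other key shown is illustrative), keeping it off for GPU training as the comment above advises:

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.sharding = True
    strategy.sharding_configs = {
        "sharding_degree": 2,               # illustrative value
        "pp_allreduce_in_optimize": False,  # Ascend-only feature; keep False on GPU
    }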
......@@ -377,6 +377,12 @@ def monkey_patch_varbase():
            return None

        new_ivar = self._grad_ivar()
        # TODO(qili93): temporary for Ascend NPU performance; to be removed along with the npu_identity op
        if (
            _global_flags()['FLAGS_npu_storage_format']
            and 'npu' in get_all_custom_device_type()
        ):
            new_ivar = paddle.incubate._npu_identity(x=new_ivar, format=-1)
        new_ivar = new_ivar._copy_to(core.CPUPlace(), True)
        if self._grad_ivar().type == core.VarDesc.VarType.SELECTED_ROWS:
            return (
......
......@@ -27,6 +27,7 @@ from .tensor import segment_sum
from .tensor import segment_mean
from .tensor import segment_max
from .tensor import segment_min
from .tensor import _npu_identity
from .passes import fuse_resnet_unit_pass
from . import autograd # noqa: F401
......
......@@ -16,5 +16,6 @@ from .math import segment_sum
from .math import segment_mean
from .math import segment_max
from .math import segment_min
from .manipulation import _npu_identity
__all__ = []
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import _C_ops
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
__all__ = []
# TODO(qili93): remove this op after custom op and custom device are
# integrated, then move this op along with its code to the plugin.
def _npu_identity(x, format=-1):
    """
    This OP takes in the Tensor :attr:`x` and returns an output whose storage
    format is set to the given aclFormat int value. This API is only used for
    Ascend NPU.

    Args:
        x(Tensor): An input N-D Tensor with data type bool, float16,
            float32, float64, int32, int64, int16, int8, uint8.
        format(int): Storage data format of the output in aclFormat,
            default value is -1.

    Returns:
        Tensor: A Tensor with acl storage format on Ascend NPU.

    Examples:
        .. code-block:: python

            # required: npu
            import paddle
            x = paddle.ones(shape=[6])
            y = paddle.incubate._npu_identity(x, 3)  # ACL_FORMAT_NC1HWC0 = 3
            # y.shape = [1, 1, 1, 1, 16]
    """
    if in_dygraph_mode():
        return _C_ops.npu_identity(x, format)
    else:
        check_variable_and_dtype(
            x,
            'x',
            [
                'bool',
                'int8',
                'uint8',
                'int16',
                'int32',
                'int64',
                'float16',
                'float32',
                'float64',
            ],
            'npu_identity',
        )
        helper = LayerHelper('npu_identity', **locals())
        out = helper.create_variable_for_type_inference(
            dtype=x.dtype, stop_gradient=x.stop_gradient
        )
        helper.append_op(
            type='npu_identity',
            inputs={'x': [x]},
            outputs={'out': [out]},
            attrs={'format': format},
        )
        return out
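The else branch above is the static-graph path; a minimal sketch of it, assuming the npu_identity operator is available in static graph in this build (the shape is arbitrary and the call is only meaningful on Ascend NPU):

    import paddle

    paddle.enable_static()
    x = paddle.static.data(name='x', shape=[6], dtype='float32')
    # Appends an npu_identity op via LayerHelper instead of calling _C_ops directly.
    y = paddle.incubate._npu_identity(x, format=-1)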
......@@ -14,6 +14,7 @@
from paddle import _C_ops, _legacy_C_ops, get_flags, in_dynamic_mode
from paddle.device import (
get_all_custom_device_type,
is_compiled_with_cuda,
is_compiled_with_custom_device,
is_compiled_with_rocm,
......@@ -26,6 +27,7 @@ from ...common_ops_import import Variable
from ...device import get_cudnn_version
from ...fluid.data_feeder import check_dtype, check_variable_and_dtype
from ...fluid.layer_helper import LayerHelper
from ...framework import no_grad
from ...tensor.manipulation import squeeze, unsqueeze
from ...utils import (
_contain_var,
......@@ -143,6 +145,16 @@ def _conv_nd(
            new_shape = [1] * len(x.shape)
            new_shape[channel_dim] = -1
            bias = bias.reshape(new_shape)
            # TODO(qili93): temporary for Ascend NPU performance; to be removed along with the npu_identity op
            if (
                _global_flags()['FLAGS_npu_storage_format']
                and 'npu' in get_all_custom_device_type()
            ):
                with no_grad():
                    bias_storage = _C_ops.npu_identity(
                        bias, 3
                    )  # ACL_FORMAT_NC1HWC0 = 3
                    bias_storage._share_underline_tensor_to(bias)
            return _C_ops.add(pre_bias, bias)
        else:
            return pre_bias
......@@ -727,6 +739,16 @@ def conv2d(
                + bias.shape
                + [1 for i in range(len(x.shape) - channel_dim - 1)],
            )
            # TODO(qili93): temporary for Ascend NPU performance; to be removed along with the npu_identity op
            if (
                _global_flags()['FLAGS_npu_storage_format']
                and 'npu' in get_all_custom_device_type()
            ):
                with no_grad():
                    bias_storage = _C_ops.npu_identity(
                        bias, 3
                    )  # ACL_FORMAT_NC1HWC0 = 3
                    bias_storage._share_underline_tensor_to(bias)
            return _C_ops.add(pre_bias, bias)
        else:
            return pre_bias
......
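A hedged end-to-end sketch of how this conv bias path is reached, assuming an Ascend NPU custom-device build (e.g. the PaddleCustomDevice plugin) that registers a device named 'npu':

    import paddle

    paddle.set_flags({'FLAGS_npu_storage_format': True})
    paddle.set_device('npu')  # assumes an 'npu' custom device is registered
    conv = paddle.nn.Conv2D(3, 8, kernel_size=3)
    # With the flag on, the bias tensor is re-stored as ACL_FORMAT_NC1HWC0 in place.
    out = conv(paddle.randn([1, 3, 32, 32]))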
......@@ -33,6 +33,7 @@ import warnings
import numpy as np
from paddle import _C_ops, _legacy_C_ops, in_dynamic_mode
from paddle.device import get_all_custom_device_type
from paddle.fluid.framework import in_dygraph_mode
from ...fluid import dygraph_utils
......@@ -723,6 +724,30 @@ class _BatchNormBase(Layer):
            shape=param_shape,
        )
        self._variance.stop_gradient = True

        # TODO(qili93): temporary for Ascend NPU performance; to be removed along with the npu_identity op
        if (
            _global_flags()['FLAGS_npu_storage_format']
            and 'npu' in get_all_custom_device_type()
        ):
            with no_grad():
                weight_trans = _C_ops.npu_identity(
                    self.weight, 3
                )  # ACL_FORMAT_NC1HWC0 = 3
                bias_trans = _C_ops.npu_identity(
                    self.bias, 3
                )  # ACL_FORMAT_NC1HWC0 = 3
                mean_trans = _C_ops.npu_identity(
                    self._mean, 3
                )  # ACL_FORMAT_NC1HWC0 = 3
                var_trans = _C_ops.npu_identity(
                    self._variance, 3
                )  # ACL_FORMAT_NC1HWC0 = 3
                weight_trans._share_underline_tensor_to(self.weight)
                bias_trans._share_underline_tensor_to(self.bias)
                mean_trans._share_underline_tensor_to(self._mean)
                var_trans._share_underline_tensor_to(self._variance)

        self._data_format = data_format
        self._in_place = False
        self._momentum = momentum
......
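Similarly, the _BatchNormBase change converts the layer's parameters and running statistics at construction time; a hedged sketch under the same Ascend custom-device assumption:

    import paddle

    paddle.set_flags({'FLAGS_npu_storage_format': True})
    paddle.set_device('npu')  # assumes an 'npu' custom device is registered
    # weight, bias, running mean and variance are re-stored as ACL_FORMAT_NC1HWC0.
    bn = paddle.nn.BatchNorm2D(8)
    y = bn(paddle.randn([1, 8, 16, 16]))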