Unverified commit 099c2302 authored by Qi Li, committed by GitHub

[NPU] add _npu_identity op and api, test=develop (#47850)

* [NPU] add _npu_identity op and api, test=develop

* fix doc

* address comments
Parent 7619188a
......@@ -580,6 +580,15 @@
func : mv
backward : mv_grad
- op : npu_identity
args : (Tensor x, int format = -1)
output : Tensor
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : npu_identity
- op : poisson
args : (Tensor x)
output : Tensor
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/npu_identity_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
namespace phi {
template <typename T, typename Context>
void NPUIdentityKernel(const Context& dev_ctx,
const DenseTensor& x,
const int format,
DenseTensor* out) {
  // The `format` attribute is only meaningful on NPU; this CPU/GPU fallback
  // ignores it and simply produces an uninitialized tensor with the same meta.
  VLOG(4) << "npu_identity op is only supported on NPU; the CPU/GPU kernel "
             "just returns an empty tensor with shape: "
          << out->dims() << ", please avoid using this kernel!";
  *out = phi::EmptyLike<T, Context>(dev_ctx, *out);
}
} // namespace phi
/** [ Why is the npu_identity op needed? ]
 *
 * 1. Ascend CANN uses internal storage formats for high-performance
 *    computing. For example, if the BatchNorm2D op runs with the CANN
 *    internal storage format ACL_FORMAT_NC1HWC0, the time cost of
 *    transdata is removed and the op gains about a 2x performance
 *    improvement.
 *
 * 2. The internal storage format is recorded in storage_properties_ of
 *    DenseTensor and changes the size and layout of the tensor. Before
 *    converting the tensor to numpy, this op should be called to restore
 *    the original size and format via the CANN Identity OP.
 *
 * A usage sketch follows the kernel registrations below.
 *
 * TODO(qili93): remove this op after custom op and custom device are
 * integrated, and then move this op along with its code to the plugin.
 */
PD_REGISTER_KERNEL(npu_identity,
CPU,
ALL_LAYOUT,
phi::NPUIdentityKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool,
phi::dtype::float16) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(npu_identity,
GPU,
ALL_LAYOUT,
phi::NPUIdentityKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool,
phi::dtype::float16) {}
#endif
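To make the workflow described in the comment above concrete, here is a minimal usage sketch. It assumes an Ascend NPU build of Paddle where the 'npu' device string is available; of the aclFormat values used, only ACL_FORMAT_NC1HWC0 = 3 is confirmed by this diff, while ACL_FORMAT_NCHW = 0 is an assumption from CANN's enum.

import paddle

paddle.device.set_device('npu')  # assumption: an NPU build exposes this device

x = paddle.rand([4, 6, 28, 28], dtype='float32')

# Convert to the CANN private format once, so format-sensitive ops such as
# BatchNorm2D skip the transdata step internally.
x_private = paddle.incubate._npu_identity(x, 3)  # ACL_FORMAT_NC1HWC0 = 3
y = paddle.nn.BatchNorm2D(6)(x_private)

# Restore the original format before handing the tensor back to the host,
# e.g. before calling .numpy(), so the logical NCHW shape is recovered.
y_nchw = paddle.incubate._npu_identity(y, 0)  # assumption: ACL_FORMAT_NCHW = 0
print(y_nchw.numpy().shape)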
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void NPUIdentityKernel(const Context& dev_ctx,
const DenseTensor& x,
const int format,
DenseTensor* out);
} // namespace phi
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
class TestNPUIdentityOp(unittest.TestCase):
def setUp(self):
self.op_type = "npu_identity"
self.shape = [64, 6, 28, 28]
self.x = np.random.random(self.shape).astype(np.float32)
self.format = 3 # ACL_FORMAT_NC1HWC0 = 3
self.place = paddle.CPUPlace()
def test_api_static(self):
paddle.enable_static()
main_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
with paddle.static.program_guard(main_program, startup_program):
x_data = paddle.static.data(
shape=self.shape, name="data", dtype='float32'
)
output = paddle.incubate._npu_identity(x=x_data, format=self.format)
exe = paddle.static.Executor()
exe.run(startup_program)
result = exe.run(
main_program, feed={x_data.name: self.x}, fetch_list=[output]
)
np.testing.assert_allclose(result[0].shape, self.shape, rtol=1e-08)
def test_api_dygraph(self):
paddle.disable_static(self.place)
x_tensor = paddle.to_tensor(self.x)
out = paddle.incubate._npu_identity(x_tensor, self.format)
np.testing.assert_allclose(out.shape, self.shape, rtol=1e-08)
paddle.enable_static()
if __name__ == "__main__":
unittest.main()
......@@ -27,6 +27,7 @@ from .tensor import segment_sum
from .tensor import segment_mean
from .tensor import segment_max
from .tensor import segment_min
from .tensor import _npu_identity
from .passes import fuse_resnet_unit_pass
from . import autograd # noqa: F401
......
......@@ -16,5 +16,6 @@ from .math import segment_sum
from .math import segment_mean
from .math import segment_max
from .math import segment_min
from .manipulation import _npu_identity
__all__ = []
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle import _C_ops, _legacy_C_ops
__all__ = []
# TODO(qili93): remove this op after custom op and custom device are
# integrated, and then move this op along with its code to the plugin.
def _npu_identity(x, format=-1):
"""
    This OP takes in the Tensor :attr:`x` and converts it to an output with
    the given aclFormat (an int value). This API is only used for Ascend NPU.
Args:
x(Tensor): An input N-D Tensor with data type bool, float16,
float32, float64, int32, int64, int16, int8, uint8.
format(int): Storage data format of the output in aclFormat,
default value is -1.
Returns:
Tensor: A Tensor with acl storage format on Ascend NPU.
Examples:
.. code-block:: python
# required: npu
import paddle
x = paddle.ones(shape=[6])
y = paddle.incubate._npu_identity(x, 3) # ACL_FORMAT_NC1HWC0 = 3
# y.shape = [1, 1, 1, 1, 16]
"""
if in_dygraph_mode():
return _C_ops.npu_identity(x, format)
if _in_legacy_dygraph():
return _legacy_C_ops.npu_identity(x, format)
check_variable_and_dtype(
x,
'x',
[
'bool',
'int8',
'uint8',
'int16',
'int32',
'int64',
'float16',
'float32',
'float64',
],
'npu_identity',
)
helper = LayerHelper('npu_identity', **locals())
out = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=x.stop_gradient
)
helper.append_op(
type='npu_identity',
inputs={'x': [x]},
outputs={'out': [out]},
attrs={'format': format},
)
return out
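For reference, a minimal sketch of how the dygraph and static-graph branches of _npu_identity above are exercised, mirroring the unit test. This is a CPU-only sketch; as noted in the kernel, the CPU/GPU kernel merely returns an empty tensor of the same shape.

import numpy as np
import paddle

# Dygraph: in_dygraph_mode() is True, so the call dispatches directly to
# _C_ops.npu_identity.
paddle.disable_static()
x = paddle.to_tensor(np.random.random([4, 6, 28, 28]).astype('float32'))
out = paddle.incubate._npu_identity(x, format=3)

# Static graph: the call falls through to check_variable_and_dtype and
# LayerHelper, which append an npu_identity op to the current program.
paddle.enable_static()
x_var = paddle.static.data(name='x', shape=[4, 6, 28, 28], dtype='float32')
out_var = paddle.incubate._npu_identity(x_var, format=3)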