Unverified commit 79be8427, authored by Liu-xiandong, committed by GitHub

[NPU] Support npu kernel for flatten_contiguous_range op, test=develop (#34642)

* fix npu compile error, test=develop

* [NPU] Support npu kernel for flatten_contiguous_range op, test=develop

* Update flatten_op_npu.cc

Co-authored-by: qili93 <qili93@qq.com>
Parent: 8f9d573f
@@ -54,6 +54,30 @@ class Flatten2GradNPUKernel : public framework::OpKernel<T> {
  }
};

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class FlattenContiguousRangeNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *X = ctx.Input<Tensor>("X");
    auto *Out = ctx.Output<Tensor>("Out");
    int start_axis = ctx.Attr<int>("start_axis");
    int stop_axis = ctx.Attr<int>("stop_axis");

    // Allocate the output buffer on the current (NPU) place before launch.
    Out->mutable_data<T>(ctx.GetPlace());

    // The underlying CANN operator FlattenV2 names its attributes
    // "axis"/"end_axis", corresponding to the op's start_axis/stop_axis.
    const auto &runner =
        NpuOpRunner("FlattenV2", {*X}, {*Out},
                    {{"axis", static_cast<int32_t>(start_axis)},
                     {"end_axis", static_cast<int32_t>(stop_axis)}});
    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
    runner.Run(stream);
  }
};
} // namespace operators
} // namespace paddle
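For context: flatten_contiguous_range collapses the contiguous axis range [start_axis, stop_axis] of the input into a single dimension and leaves the other dimensions untouched. The following NumPy sketch of the shape rule is illustrative only (not part of the patch; the helper name flattened_shape is made up); the unit tests further down exercise exactly these cases:

import numpy as np

def flattened_shape(in_shape, start_axis, stop_axis):
    # Negative axes count from the end, as in the tests below.
    ndim = len(in_shape)
    start = start_axis + ndim if start_axis < 0 else start_axis
    stop = stop_axis + ndim if stop_axis < 0 else stop_axis
    collapsed = int(np.prod(in_shape[start:stop + 1]))
    return in_shape[:start] + (collapsed, ) + in_shape[stop + 1:]

assert flattened_shape((3, 2, 5, 4), 0, -1) == (120, )
assert flattened_shape((3, 2, 5, 4), 1, 2) == (3, 10, 4)
assert flattened_shape((3, 2, 3, 2, 4, 4), 3, 5) == (3, 2, 3, 32)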
@@ -71,3 +95,18 @@ REGISTER_OP_NPU_KERNEL(flatten2_grad, ops::Flatten2GradNPUKernel<float>,
                       ops::Flatten2GradNPUKernel<int>,
                       ops::Flatten2GradNPUKernel<int8_t>,
                       ops::Flatten2GradNPUKernel<int64_t>);

REGISTER_OP_NPU_KERNEL(
    flatten_contiguous_range,
    ops::FlattenContiguousRangeNPUKernel<paddle::platform::NPUDeviceContext,
                                         float>,
    ops::FlattenContiguousRangeNPUKernel<paddle::platform::NPUDeviceContext,
                                         double>,
    ops::FlattenContiguousRangeNPUKernel<paddle::platform::NPUDeviceContext,
                                         uint8_t>,
    ops::FlattenContiguousRangeNPUKernel<paddle::platform::NPUDeviceContext,
                                         int>,
    ops::FlattenContiguousRangeNPUKernel<paddle::platform::NPUDeviceContext,
                                         int8_t>,
    ops::FlattenContiguousRangeNPUKernel<paddle::platform::NPUDeviceContext,
                                         int64_t>);
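For a quick end-to-end check of the newly registered kernel, a dynamic-graph run along these lines should work (a sketch, assuming a Paddle build with NPU support and an available device 0):

import numpy as np
import paddle

paddle.disable_static(paddle.NPUPlace(0))  # run eagerly on NPU device 0
x = paddle.to_tensor(np.random.rand(3, 2, 5, 4).astype("float32"))
out = paddle.flatten(x, start_axis=1, stop_axis=2)  # collapses dims 1..2
print(out.shape)  # expected: [3, 10, 4]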
@@ -33,6 +33,7 @@ static std::map<framework::proto::VarType::Type, aclDataType>
    DTYPE_2_ACL_DTYPE = {
        {framework::proto::VarType::BOOL, ACL_BOOL},
        {framework::proto::VarType::UINT8, ACL_UINT8},
        {framework::proto::VarType::INT8, ACL_INT8},
        {framework::proto::VarType::INT16, ACL_INT16},
        {framework::proto::VarType::INT32, ACL_INT32},
        {framework::proto::VarType::INT64, ACL_INT64},
......
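The INT8 -> ACL_INT8 entry added to DTYPE_2_ACL_DTYPE above is presumably what the int8_t kernel registration earlier in the patch relies on: without it, int8 tensors could not be converted into ACL data-type descriptors for the CANN runtime.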
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
paddle.enable_static()


class TestFlattenOp(OpTest):
    def setUp(self):
        self.set_npu()
        self.op_type = "flatten_contiguous_range"
        self.place = paddle.NPUPlace(0)
        self.start_axis = 0
        self.stop_axis = -1
        self.dtype = np.float64
        self.init_test_case()
        self.inputs = {"X": np.random.random(self.in_shape).astype(self.dtype)}
        self.init_attrs()
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.new_shape),
            # XShape only carries shape information for the grad op; its
            # values are never read, so it is excluded via no_check_set.
            "XShape": np.random.random(self.in_shape).astype("float32")
        }

    def set_npu(self):
        self.__class__.use_npu = True

    def test_check_output(self):
        self.check_output_with_place(self.place, no_check_set=["XShape"])

    def test_check_grad(self):
        pass

    def init_test_case(self):
        self.in_shape = (3, 2, 5, 4)
        self.start_axis = 0
        self.stop_axis = -1
        self.new_shape = (120, )

    def init_attrs(self):
        self.attrs = {
            "start_axis": self.start_axis,
            "stop_axis": self.stop_axis
        }


class TestFlattenOp_1(TestFlattenOp):
    def init_test_case(self):
        self.in_shape = (3, 2, 5, 4)
        self.start_axis = 1
        self.stop_axis = 2
        self.new_shape = (3, 10, 4)

    def init_attrs(self):
        self.attrs = {
            "start_axis": self.start_axis,
            "stop_axis": self.stop_axis
        }


class TestFlattenOp_2(TestFlattenOp):
    def init_test_case(self):
        self.in_shape = (3, 2, 5, 4)
        self.start_axis = 0
        self.stop_axis = 1
        self.new_shape = (6, 5, 4)

    def init_attrs(self):
        self.attrs = {
            "start_axis": self.start_axis,
            "stop_axis": self.stop_axis
        }


class TestFlattenOp_3(TestFlattenOp):
    def init_test_case(self):
        self.in_shape = (3, 2, 5, 4)
        self.start_axis = 0
        self.stop_axis = 2
        self.new_shape = (30, 4)

    def init_attrs(self):
        self.attrs = {
            "start_axis": self.start_axis,
            "stop_axis": self.stop_axis
        }


class TestFlattenOp_4(TestFlattenOp):
    def init_test_case(self):
        self.in_shape = (3, 2, 5, 4)
        self.start_axis = -2
        self.stop_axis = -1
        self.new_shape = (3, 2, 20)

    def init_attrs(self):
        self.attrs = {
            "start_axis": self.start_axis,
            "stop_axis": self.stop_axis
        }


class TestFlattenOp_5(TestFlattenOp):
    def init_test_case(self):
        self.in_shape = (3, 2, 5, 4)
        self.start_axis = 2
        self.stop_axis = 2
        self.new_shape = (3, 2, 5, 4)

    def init_attrs(self):
        self.attrs = {
            "start_axis": self.start_axis,
            "stop_axis": self.stop_axis
        }


class TestFlattenOpSixDims(TestFlattenOp):
    def init_test_case(self):
        self.in_shape = (3, 2, 3, 2, 4, 4)
        self.start_axis = 3
        self.stop_axis = 5
        self.new_shape = (3, 2, 3, 32)

    def init_attrs(self):
        self.attrs = {
            "start_axis": self.start_axis,
            "stop_axis": self.stop_axis
        }


class TestFlattenOp_Float32(TestFlattenOp):
    def init_test_case(self):
        self.in_shape = (3, 2, 5, 4)
        self.start_axis = 0
        self.stop_axis = 1
        self.new_shape = (6, 5, 4)
        self.dtype = np.float32

    def init_attrs(self):
        self.attrs = {
            "start_axis": self.start_axis,
            "stop_axis": self.stop_axis
        }


class TestFlattenOp_int(TestFlattenOp):
    def init_test_case(self):
        self.in_shape = (3, 2, 5, 4)
        self.start_axis = 0
        self.stop_axis = 1
        self.new_shape = (6, 5, 4)
        # np.int is a deprecated alias for the builtin int (removed in
        # NumPy 1.24); np.int32 matches the registered int kernel.
        self.dtype = np.int32

    def init_attrs(self):
        self.attrs = {
            "start_axis": self.start_axis,
            "stop_axis": self.stop_axis
        }


class TestFlattenOp_uint8(TestFlattenOp):
    def init_test_case(self):
        self.in_shape = (3, 2, 5, 4)
        self.start_axis = 0
        self.stop_axis = 1
        self.new_shape = (6, 5, 4)
        self.dtype = np.uint8

    def init_attrs(self):
        self.attrs = {
            "start_axis": self.start_axis,
            "stop_axis": self.stop_axis
        }


class TestFlattenOp_int8(TestFlattenOp):
    def init_test_case(self):
        self.in_shape = (3, 2, 5, 4)
        self.start_axis = 0
        self.stop_axis = 1
        self.new_shape = (6, 5, 4)
        self.dtype = np.int8

    def init_attrs(self):
        self.attrs = {
            "start_axis": self.start_axis,
            "stop_axis": self.stop_axis
        }


class TestFlattenOp_int64(TestFlattenOp):
    def init_test_case(self):
        self.in_shape = (3, 2, 5, 4)
        self.start_axis = 0
        self.stop_axis = 1
        self.new_shape = (6, 5, 4)
        self.dtype = np.int64

    def init_attrs(self):
        self.attrs = {
            "start_axis": self.start_axis,
            "stop_axis": self.stop_axis
        }


class TestFlatten2OpError(unittest.TestCase):
    def test_errors(self):
        image_shape = (2, 3, 4, 4)
        x = np.arange(image_shape[0] * image_shape[1] * image_shape[2] *
                      image_shape[3]).reshape(image_shape) / 100.
        x = x.astype('float32')

        def test_ValueError1():
            x_var = paddle.static.data(
                name="x", shape=image_shape, dtype='float32')
            out = paddle.flatten(x_var, start_axis=2, stop_axis=1)

        self.assertRaises(ValueError, test_ValueError1)

        def test_ValueError2():
            x_var = paddle.static.data(
                name="x", shape=image_shape, dtype='float32')
            paddle.flatten(x_var, start_axis=10, stop_axis=1)

        self.assertRaises(ValueError, test_ValueError2)

        def test_ValueError3():
            x_var = paddle.static.data(
                name="x", shape=image_shape, dtype='float32')
            paddle.flatten(x_var, start_axis=2, stop_axis=10)

        self.assertRaises(ValueError, test_ValueError3)

        def test_type():
            # dtype must be float32, float64, int8, int32, int64, uint8.
            x2 = np.arange(image_shape[0] * image_shape[1] * image_shape[2] *
                           image_shape[3]).reshape(image_shape) / 100.
            x2 = x2.astype('float16')
            x2_var = paddle.fluid.data(
                name='x2', shape=[3, 2, 4, 5], dtype='float16')
            paddle.flatten(x2_var)

        self.assertRaises(TypeError, test_type)

        def test_InputError():
            out = paddle.flatten(x)

        self.assertRaises(ValueError, test_InputError)


class TestStaticFlattenPythonAPI(unittest.TestCase):
    def execute_api(self, x, start_axis=0, stop_axis=-1):
        return paddle.flatten(x, start_axis, stop_axis)

    def test_static_api(self):
        paddle.enable_static()
        np_x = np.random.rand(2, 3, 4, 4).astype('float32')

        main_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, paddle.static.Program()):
            x = paddle.static.data(
                name="x", shape=[2, 3, 4, 4], dtype='float32')
            out = self.execute_api(x, start_axis=-2, stop_axis=-1)

        exe = paddle.static.Executor(place=paddle.NPUPlace(0))
        fetch_out = exe.run(main_prog, feed={"x": np_x}, fetch_list=[out])
        self.assertTrue((2, 3, 16) == fetch_out[0].shape)


class TestStaticInplaceFlattenPythonAPI(TestStaticFlattenPythonAPI):
    def execute_api(self, x, start_axis=0, stop_axis=-1):
        return x.flatten_(start_axis, stop_axis)


class TestFlattenPython(unittest.TestCase):
    def test_python_api(self):
        image_shape = (2, 3, 4, 4)
        x = np.arange(image_shape[0] * image_shape[1] * image_shape[2] *
                      image_shape[3]).reshape(image_shape) / 100.
        x = x.astype('float32')

        def test_InputError():
            out = paddle.flatten(x)

        self.assertRaises(ValueError, test_InputError)

        def test_Negative():
            paddle.disable_static(paddle.NPUPlace(0))
            img = paddle.to_tensor(x)
            out = paddle.flatten(img, start_axis=-2, stop_axis=-1)
            return out.numpy().shape

        res_shape = test_Negative()
        self.assertTrue((2, 3, 16) == res_shape)


if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......