“86fd748808dee2448bf368f3b1389f91ec6e9d29”上不存在“doc/api/v2/config/attr.html”
未验证 提交 0b20b76e 编写于 作者: Q Qi Li 提交者: GitHub

[NPU] add NPU ops of stack and unstack, test=develop (#34084)

上级 2dde0eb0
...@@ -12,15 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,15 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_WITH_ASCEND_CL
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/operators/stack_op.h" #include "paddle/fluid/operators/stack_op.h"
#include "paddle/fluid/operators/unsqueeze_op.h" #include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -32,64 +25,56 @@ class StackNPUKernel : public framework::OpKernel<T> { ...@@ -32,64 +25,56 @@ class StackNPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto x = ctx.MultiInput<Tensor>("X"); auto x = ctx.MultiInput<Tensor>("X");
int32_t N = x.size(); auto* y = ctx.Output<Tensor>("Y");
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis += (x[0]->dims().size() + 1);
int num = static_cast<int>(x.size());
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument(
N, 0, platform::errors::InvalidArgument("number of input Tensor <= 0")); "number of input Tensor <= 0"));
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
std::vector<paddle::framework::Tensor> x_list; std::vector<paddle::framework::Tensor> x_list;
for (int i = 0; i < N; i++) { for (int i = 0; i < num; i++) {
x_list.push_back(*x[i]); x_list.push_back(*x[i]);
} }
y->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis"); const auto& runner =
NpuOpRunner("Pack", {x_list}, {*y}, {{"axis", axis}, {"N", num}});
runner.Run(stream);
}
};
if (axis < 0) { template <typename DeviceContext, typename T>
axis = axis + x_list[0].dims().size() + 1; class StackGradNPUKernel : public framework::OpKernel<T> {
} public:
auto* out = ctx.Output<Tensor>("Y"); void Compute(const framework::ExecutionContext& ctx) const override {
auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis += dy->dims().size();
int num = dy->dims()[axis];
auto place = ctx.GetPlace(); PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument(
"number of input Tensor <= 0"));
auto stream = auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>() ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream(); .stream();
out->mutable_data<T>(place); std::vector<paddle::framework::Tensor> dx_list;
for (int i = 0; i < num; i++) {
if (axis != 0) { dx[i]->mutable_data<T>(ctx.GetPlace());
auto x_dim = x_list[0].dims(); dx_list.push_back(*dx[i]);
std::vector<int> vec_dim_tmp;
vec_dim_tmp.push_back(N);
for (auto i = 0; i < x_dim.size(); ++i) {
vec_dim_tmp.push_back(x_dim[i]);
}
Tensor tmp_stack(out->type());
tmp_stack.Resize(framework::make_ddim(vec_dim_tmp));
tmp_stack.mutable_data<T>(ctx.GetPlace());
const auto& runner =
NpuOpRunner("Pack", {x_list}, {tmp_stack}, {{"axis", 0}, {"N", N}});
runner.Run(stream);
std::vector<int64_t> vec_trans;
for (auto i = 1; i <= x_dim.size(); ++i) {
vec_trans.push_back(i);
if (i == axis) {
vec_trans.push_back(0);
}
}
const auto& runner_trans_final =
NpuOpRunner("TransposeD", {tmp_stack}, {*out}, {{"perm", vec_trans}});
runner_trans_final.Run(stream);
} else {
const auto& runner =
NpuOpRunner("Pack", {x_list}, {*out}, {{"axis", axis}, {"N", N}});
runner.Run(stream);
} }
const auto& runner =
NpuOpRunner("Unpack", {*dy}, {dx_list}, {{"axis", axis}, {"num", num}});
runner.Run(stream);
} }
}; };
...@@ -103,4 +88,8 @@ REGISTER_OP_NPU_KERNEL( ...@@ -103,4 +88,8 @@ REGISTER_OP_NPU_KERNEL(
ops::StackNPUKernel<paddle::platform::NPUDeviceContext, ops::StackNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>); paddle::platform::float16>);
#endif REGISTER_OP_NPU_KERNEL(
stack_grad,
ops::StackGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::StackGradNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/unstack_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class UnStackNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *dy = ctx.Input<Tensor>("X");
auto dx = ctx.MultiOutput<Tensor>("Y");
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis += dy->dims().size();
int num = dy->dims()[axis];
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
std::vector<paddle::framework::Tensor> dx_list;
for (int i = 0; i < num; i++) {
dx[i]->mutable_data<T>(ctx.GetPlace());
dx_list.push_back(*dx[i]);
}
const auto &runner =
NpuOpRunner("Unpack", {*dy}, {dx_list}, {{"axis", axis}, {"num", num}});
runner.Run(stream);
}
};
template <typename DeviceContext, typename T>
class UnStackGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto x = ctx.MultiInput<Tensor>(framework::GradVarName("Y"));
auto *y = ctx.Output<Tensor>(framework::GradVarName("X"));
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis += (x[0]->dims().size() + 1);
int num = static_cast<int>(x.size());
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
std::vector<paddle::framework::Tensor> x_list;
for (int i = 0; i < num; i++) {
x_list.push_back(*x[i]);
}
y->mutable_data<T>(ctx.GetPlace());
const auto &runner =
NpuOpRunner("Pack", {x_list}, {*y}, {{"axis", axis}, {"N", num}});
runner.Run(stream);
}
};
} // namespace operators
} // namespace paddle
namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
unstack, ops::UnStackNPUKernel<plat::NPUDeviceContext, float>,
ops::UnStackNPUKernel<plat::NPUDeviceContext, plat::float16>);
REGISTER_OP_NPU_KERNEL(
unstack_grad, ops::UnStackGradNPUKernel<plat::NPUDeviceContext, float>,
ops::UnStackGradNPUKernel<plat::NPUDeviceContext, plat::float16>);
...@@ -24,17 +24,18 @@ import paddle.fluid as fluid ...@@ -24,17 +24,18 @@ import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
paddle.enable_static() paddle.enable_static()
SEED = 2021
@unittest.skipIf(not paddle.is_compiled_with_npu(), @unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU") "core is not compiled with NPU")
class TestStack1(OpTest): class TestStackOpBase(OpTest):
def initDefaultParameters(self): def initDefaultParameters(self):
self.num_inputs = 4 self.num_inputs = 4
self.input_dim = (5, 6, 7) self.input_dim = (5, 6, 7)
self.axis = 0 self.axis = 0
self.dtype = 'float32'
def initParameters(self):
pass
def get_x_names(self): def get_x_names(self):
x_names = [] x_names = []
...@@ -44,10 +45,10 @@ class TestStack1(OpTest): ...@@ -44,10 +45,10 @@ class TestStack1(OpTest):
def setUp(self): def setUp(self):
self.initDefaultParameters() self.initDefaultParameters()
self.initParameters()
self.op_type = 'stack'
self.set_npu() self.set_npu()
self.op_type = "stack" self.init_dtype()
self.place = paddle.NPUPlace(0)
self.x = [] self.x = []
for i in range(self.num_inputs): for i in range(self.num_inputs):
self.x.append( self.x.append(
...@@ -64,89 +65,191 @@ class TestStack1(OpTest): ...@@ -64,89 +65,191 @@ class TestStack1(OpTest):
def set_npu(self): def set_npu(self):
self.__class__.use_npu = True self.__class__.use_npu = True
self.place = paddle.NPUPlace(0)
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self): def test_check_output(self):
self.check_output_with_place(self.place, check_dygraph=False) self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, self.get_x_names(), 'Y')
class TestStack2(OpTest):
def initDefaultParameters(self):
self.num_inputs = 4
self.input_dim = (2, 3, 4)
self.axis = -1
self.dtype = 'float32'
def get_x_names(self): @unittest.skipIf(not paddle.is_compiled_with_npu(),
x_names = [] "core is not compiled with NPU")
for i in range(self.num_inputs): class TestStackOp1(TestStackOpBase):
x_names.append('x{}'.format(i)) def initParameters(self):
return x_names self.num_inputs = 16
def setUp(self):
self.initDefaultParameters()
self.set_npu()
self.op_type = "stack"
self.place = paddle.NPUPlace(0)
self.x = [] @unittest.skipIf(not paddle.is_compiled_with_npu(),
for i in range(self.num_inputs): "core is not compiled with NPU")
self.x.append( class TestStackOp2(TestStackOpBase):
np.random.random(size=self.input_dim).astype(self.dtype)) def initParameters(self):
self.num_inputs = 20
tmp = []
x_names = self.get_x_names()
for i in range(self.num_inputs):
tmp.append((x_names[i], self.x[i]))
self.inputs = {'X': tmp} @unittest.skipIf(not paddle.is_compiled_with_npu(),
self.outputs = {'Y': np.stack(self.x, axis=self.axis)} "core is not compiled with NPU")
self.attrs = {'axis': self.axis} class TestStackOp3(TestStackOpBase):
def initParameters(self):
self.axis = -1
def set_npu(self):
self.__class__.use_npu = True
def test_check_output(self): @unittest.skipIf(not paddle.is_compiled_with_npu(),
self.check_output_with_place(self.place, check_dygraph=False) "core is not compiled with NPU")
class TestStackOp4(TestStackOpBase):
def initParameters(self):
self.axis = -4
class TestStack3(OpTest): @unittest.skipIf(not paddle.is_compiled_with_npu(),
def initDefaultParameters(self): "core is not compiled with NPU")
self.num_inputs = 4 class TestStackOp5(TestStackOpBase):
self.input_dim = (2, 3, 4) def initParameters(self):
self.axis = 1 self.axis = 1
self.dtype = 'float32'
def get_x_names(self):
x_names = [] @unittest.skipIf(not paddle.is_compiled_with_npu(),
for i in range(self.num_inputs): "core is not compiled with NPU")
x_names.append('x{}'.format(i)) class TestStackOp6(TestStackOpBase):
return x_names def initParameters(self):
self.axis = 3
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestStackAPIWithLoDTensorArray(unittest.TestCase):
"""
Test stack api when the input(x) is a LoDTensorArray.
"""
def setUp(self): def setUp(self):
self.initDefaultParameters() self.axis = 1
self.set_npu() self.iter_num = 3
self.op_type = "stack" self.input_shape = [2, 3]
self.place = paddle.NPUPlace(0) self.x = np.random.random(self.input_shape).astype("float32")
self.place = paddle.NPUPlace(0) \
if paddle.is_compiled_with_npu() else paddle.CPUPlace()
self.set_program()
def set_program(self):
self.program = fluid.Program()
with fluid.program_guard(self.program):
input = fluid.layers.assign(self.x)
tensor_array = fluid.layers.create_array(dtype='float32')
zero = fluid.layers.fill_constant(shape=[1], value=0, dtype="int64")
for i in range(self.iter_num):
fluid.layers.array_write(input, zero + i, tensor_array)
self.out_var = fluid.layers.stack(tensor_array, axis=self.axis)
def test_case(self):
self.assertTrue(self.out_var.shape[self.axis] == -1)
exe = fluid.Executor(self.place)
res = exe.run(self.program, fetch_list=self.out_var)
self.assertTrue(
np.array_equal(
res[0], np.stack(
[self.x] * self.iter_num, axis=self.axis)))
self.x = []
for i in range(self.num_inputs):
self.x.append(
np.random.random(size=self.input_dim).astype(self.dtype))
tmp = [] @unittest.skipIf(not paddle.is_compiled_with_npu(),
x_names = self.get_x_names() "core is not compiled with NPU")
for i in range(self.num_inputs): class TestTensorStackAPIWithLoDTensorArray(unittest.TestCase):
tmp.append((x_names[i], self.x[i])) """
Test stack api when the input(x) is a LoDTensorArray.
"""
self.inputs = {'X': tmp} def setUp(self):
self.outputs = {'Y': np.stack(self.x, axis=self.axis)} self.axis = 1
self.attrs = {'axis': self.axis} self.iter_num = 3
self.input_shape = [2, 3]
self.x = np.random.random(self.input_shape).astype("float32")
self.place = paddle.NPUPlace(0) \
if paddle.is_compiled_with_npu() else paddle.CPUPlace()
self.set_program()
def set_program(self):
self.program = fluid.Program()
with fluid.program_guard(self.program):
input = fluid.layers.assign(self.x)
tensor_array = fluid.layers.create_array(dtype='float32')
zero = fluid.layers.fill_constant(shape=[1], value=0, dtype="int64")
for i in range(self.iter_num):
fluid.layers.array_write(input, zero + i, tensor_array)
self.out_var = paddle.stack(tensor_array, axis=self.axis)
def test_case(self):
self.assertTrue(self.out_var.shape[self.axis] == -1)
exe = fluid.Executor(self.place)
res = exe.run(self.program, fetch_list=self.out_var)
self.assertTrue(
np.array_equal(
res[0], np.stack(
[self.x] * self.iter_num, axis=self.axis)))
def set_npu(self):
self.__class__.use_npu = True
def test_check_output(self): @unittest.skipIf(not paddle.is_compiled_with_npu(),
self.check_output_with_place(self.place, check_dygraph=False) "core is not compiled with NPU")
class API_test(unittest.TestCase):
def test_out(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
data1 = fluid.layers.data('data1', shape=[1, 2], dtype='float64')
data2 = fluid.layers.data('data2', shape=[1, 2], dtype='float64')
data3 = fluid.layers.data('data3', shape=[1, 2], dtype='float64')
result_stack = paddle.stack([data1, data2, data3], axis=0)
place = paddle.NPUPlace(0)
exe = fluid.Executor(place)
input1 = np.random.random([1, 2]).astype('float64')
input2 = np.random.random([1, 2]).astype('float64')
input3 = np.random.random([1, 2]).astype('float64')
result, = exe.run(
feed={"data1": input1,
"data2": input2,
"data3": input3},
fetch_list=[result_stack])
expected_result = np.stack([input1, input2, input3], axis=0)
self.assertTrue(np.allclose(expected_result, result))
def test_single_tensor_error(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
x = paddle.rand([2, 3])
self.assertRaises(TypeError, paddle.stack, x)
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class API_DygraphTest(unittest.TestCase):
def test_out(self):
data1 = np.array([[1.0, 2.0]])
data2 = np.array([[3.0, 4.0]])
data3 = np.array([[5.0, 6.0]])
with fluid.dygraph.guard(place=paddle.NPUPlace(0)):
x1 = fluid.dygraph.to_variable(data1)
x2 = fluid.dygraph.to_variable(data2)
x3 = fluid.dygraph.to_variable(data3)
result = paddle.stack([x1, x2, x3])
result_np = result.numpy()
expected_result = np.stack([data1, data2, data3])
self.assertTrue(np.allclose(expected_result, result_np))
with fluid.dygraph.guard(place=paddle.NPUPlace(0)):
y1 = fluid.dygraph.to_variable(data1)
result = paddle.stack([y1], axis=0)
result_np_2 = result.numpy()
expected_result_2 = np.stack([data1], axis=0)
self.assertTrue(np.allclose(expected_result_2, result_np_2))
def test_single_tensor_error(self):
with fluid.dygraph.guard(place=paddle.NPUPlace(0)):
x = paddle.to_tensor([1, 2, 3])
self.assertRaises(Exception, paddle.stack, x)
if __name__ == '__main__': if __name__ == '__main__':
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest
import unittest
import paddle
paddle.enable_static()
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestUnStackOpBase(OpTest):
def initDefaultParameters(self):
self.input_dim = (5, 6, 7)
self.axis = 0
def initParameters(self):
pass
def get_y_names(self):
y_names = []
for i in range(self.input_dim[self.axis]):
y_names.append('y{}'.format(i))
return y_names
def setUp(self):
self.initDefaultParameters()
self.initParameters()
self.op_type = 'unstack'
self.set_npu()
self.init_dtype()
self.x = np.random.random(size=self.input_dim).astype(self.dtype)
outs = np.split(self.x, self.input_dim[self.axis], self.axis)
new_shape = list(self.input_dim)
del new_shape[self.axis]
y_names = self.get_y_names()
tmp = []
for i in range(self.input_dim[self.axis]):
tmp.append((y_names[i], np.reshape(outs[i], new_shape)))
self.inputs = {'X': self.x}
self.outputs = {'Y': tmp}
self.attrs = {'axis': self.axis, 'num': self.input_dim[self.axis]}
def set_npu(self):
self.__class__.use_npu = True
self.place = paddle.NPUPlace(0)
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], self.get_y_names())
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestStackOp3(TestUnStackOpBase):
def initParameters(self):
self.axis = -1
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestStackOp4(TestUnStackOpBase):
def initParameters(self):
self.axis = -3
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestStackOp5(TestUnStackOpBase):
def initParameters(self):
self.axis = 1
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestStackOp6(TestUnStackOpBase):
def initParameters(self):
self.axis = 2
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册