From 397c9403a2c02c988ce8ea322209099ff073c263 Mon Sep 17 00:00:00 2001
From: QingshuChen
Date: Tue, 21 Feb 2023 18:20:53 +0800
Subject: [PATCH] add c_reduce_sum/unstack/all_reduce_datatype for kunlun
 (#50606)

---
 .../collective/c_allreduce_max_op_xpu.cc      |   2 +
 .../collective/c_reduce_sum_op_xpu.cc         |   1 +
 paddle/phi/backends/xpu/xpu2_op_list.cc       |  16 +++
 paddle/phi/kernels/xpu/unstack_grad_kernel.cc |  65 +++++++++
 paddle/phi/kernels/xpu/unstack_kernel.cc      |  60 ++++++++
 .../unittests/xpu/test_unstack_op_xpu.py      | 131 ++++++++++++++++++
 6 files changed, 275 insertions(+)
 create mode 100644 paddle/phi/kernels/xpu/unstack_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/xpu/unstack_kernel.cc
 create mode 100755 python/paddle/fluid/tests/unittests/xpu/test_unstack_op_xpu.py

diff --git a/paddle/fluid/operators/collective/c_allreduce_max_op_xpu.cc b/paddle/fluid/operators/collective/c_allreduce_max_op_xpu.cc
index 296a8b3a633..86527cf0e6e 100644
--- a/paddle/fluid/operators/collective/c_allreduce_max_op_xpu.cc
+++ b/paddle/fluid/operators/collective/c_allreduce_max_op_xpu.cc
@@ -18,4 +18,6 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
 REGISTER_OP_XPU_KERNEL(c_allreduce_max,
+                       ops::CAllReduceOpXPUKernel<ops::kRedMax, int>,
+                       ops::CAllReduceOpXPUKernel<ops::kRedMax, plat::float16>,
                        ops::CAllReduceOpXPUKernel<ops::kRedMax, float>)
diff --git a/paddle/fluid/operators/collective/c_reduce_sum_op_xpu.cc b/paddle/fluid/operators/collective/c_reduce_sum_op_xpu.cc
index f3ec15ca9e5..9f35a6866ee 100644
--- a/paddle/fluid/operators/collective/c_reduce_sum_op_xpu.cc
+++ b/paddle/fluid/operators/collective/c_reduce_sum_op_xpu.cc
@@ -18,4 +18,5 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
 REGISTER_OP_XPU_KERNEL(c_reduce_sum,
+                       ops::CReduceOpXPUKernel<ops::kRedSum, plat::float16>,
                        ops::CReduceOpXPUKernel<ops::kRedSum, float>)
diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index 31681d95703..f9bb267bd01 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -79,6 +79,10 @@ XPUOpMap& get_kl2_ops() {
                    phi::DataType::FLOAT64,
                    phi::DataType::INT32,
                    phi::DataType::INT64})},
+    {"c_allreduce_max",
+     XPUKernelSet({phi::DataType::FLOAT16,
+                   phi::DataType::FLOAT32,
+                   phi::DataType::INT32})},
     {"c_allreduce_sum",
      XPUKernelSet({phi::DataType::FLOAT16,
                    phi::DataType::FLOAT32,
@@ -94,6 +98,8 @@ XPUOpMap& get_kl2_ops() {
                    phi::DataType::FLOAT64,
                    phi::DataType::INT32,
                    phi::DataType::INT64})},
+    {"c_reduce_sum",
+     XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
     {"c_split",
      XPUKernelSet({phi::DataType::FLOAT16,
                    phi::DataType::FLOAT32,
@@ -730,6 +736,16 @@ XPUOpMap& get_kl2_ops() {
                    phi::DataType::UINT8,
                    phi::DataType::FLOAT16,
                    phi::DataType::FLOAT32})},
+    {"unstack",
+     XPUKernelSet({phi::DataType::INT64,
+                   phi::DataType::INT32,
+                   phi::DataType::FLOAT16,
+                   phi::DataType::FLOAT32})},
+    {"unstack_grad",
+     XPUKernelSet({phi::DataType::INT64,
+                   phi::DataType::INT32,
+                   phi::DataType::FLOAT16,
+                   phi::DataType::FLOAT32})},
     {"warpctc_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"warpctc", XPUKernelSet({phi::DataType::FLOAT32})},
     {"where_index",
diff --git a/paddle/phi/kernels/xpu/unstack_grad_kernel.cc b/paddle/phi/kernels/xpu/unstack_grad_kernel.cc
new file mode 100644
index 00000000000..e29c313dad1
--- /dev/null
+++ b/paddle/phi/kernels/xpu/unstack_grad_kernel.cc
@@ -0,0 +1,65 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/unstack_grad_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void UnStackGradKernel(const Context &dev_ctx,
+                       const std::vector<const DenseTensor *> &x,
+                       int axis,
+                       DenseTensor *x_grad) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  if (axis < 0) {
+    axis += x[0]->dims().size() + 1;
+  }
+  dev_ctx.template Alloc<T>(x_grad);
+  auto &dim = x[0]->dims();
+  std::vector<int> xdims;
+  for (auto i = 0; i < dim.size(); ++i) {
+    xdims.push_back(dim[i]);
+  }
+  xdims.push_back(1);
+  std::vector<std::vector<int>> xdims_list;
+  int n = static_cast<int>(x.size());
+  for (int i = 0; i < n; i++) {
+    xdims_list.push_back(xdims);
+  }
+
+  std::vector<const XPUType *> x_list;
+  for (int i = 0; i < n; i++) {
+    x_list.push_back(reinterpret_cast<const XPUType *>(x[i]->data<T>()));
+  }
+
+  int r = xpu::concat<XPUType>(dev_ctx.x_context(),
+                               x_list,
+                               reinterpret_cast<XPUType *>(x_grad->data<T>()),
+                               xdims_list,
+                               axis);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "concat in unstack_grad op");
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(unstack_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::UnStackGradKernel,
+                   float,
+                   phi::dtype::float16,
+                   int,
+                   int64_t) {}
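The gradient kernel above realizes unstack's backward pass as a single xpu::concat: each output gradient is treated as a slice with an extra unit dimension, and the slices are concatenated back along the unstacked axis. A minimal NumPy sketch of the same semantics, for reference only (numpy and the example shapes here are assumptions, not part of the patch):

    import numpy as np

    def unstack_grad_reference(out_grads, axis):
        # Stacking the output gradients along `axis` reproduces the
        # concat-with-unit-dims that UnStackGradKernel performs on device.
        return np.stack(out_grads, axis=axis)

    x_shape = (5, 6, 7)
    grads = [np.ones(x_shape[1:], dtype=np.float32) for _ in range(x_shape[0])]
    assert unstack_grad_reference(grads, axis=0).shape == x_shape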
diff --git a/paddle/phi/kernels/xpu/unstack_kernel.cc b/paddle/phi/kernels/xpu/unstack_kernel.cc
new file mode 100644
index 00000000000..1c9c7a79795
--- /dev/null
+++ b/paddle/phi/kernels/xpu/unstack_kernel.cc
@@ -0,0 +1,60 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/unstack_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void UnStackKernel(const Context &dev_ctx,
+                   const DenseTensor &x,
+                   int axis,
+                   int num,
+                   std::vector<DenseTensor *> outs) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  auto x_dims = x.dims();
+
+  if (axis < 0) axis += x_dims.size();
+  auto x_shape = phi::vectorize<int>(x_dims);
+
+  std::vector<int> dx_dims_list(outs.size(), 1);
+  std::vector<XPUType *> dx_lists;
+  for (size_t j = 0; j < outs.size(); ++j) {
+    dev_ctx.template Alloc<T>(outs[j]);
+    dx_lists.push_back(reinterpret_cast<XPUType *>(outs[j]->data<T>()));
+  }
+
+  int r = xpu::split<XPUType>(dev_ctx.x_context(),
+                              reinterpret_cast<const XPUType *>(x.data<T>()),
+                              dx_lists,
+                              x_shape,
+                              dx_dims_list,
+                              axis);
+
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "split in unstack op");
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(unstack,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::UnStackKernel,
+                   phi::dtype::float16,
+                   float,
+                   int,
+                   int64_t) {}
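The forward kernel maps unstack onto xpu::split: the input is cut into `num` width-1 slices along `axis` (which is why `dx_dims_list` is filled with 1s), and each slice is written to an output whose shape simply omits that axis. The same semantics in NumPy, as a reference sketch (numpy and the example shapes are assumptions, not part of the patch):

    import numpy as np

    def unstack_reference(x, axis, num):
        # Split into `num` width-1 slices along `axis`, then drop that axis,
        # mirroring xpu::split plus the outputs' declared shapes.
        return [np.squeeze(p, axis=axis) for p in np.split(x, num, axis=axis)]

    x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
    ys = unstack_reference(x, axis=1, num=3)
    assert len(ys) == 3 and ys[0].shape == (2, 4)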
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_unstack_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_unstack_op_xpu.py
new file mode 100755
index 00000000000..bb9e9bff21d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_unstack_op_xpu.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import unittest
+
+import numpy as np
+
+sys.path.append("..")
+from op_test_xpu import XPUOpTest
+from xpu.get_test_cover_info import (
+    XPUOpTestWrapper,
+    create_test_class,
+    get_xpu_op_support_types,
+)
+
+import paddle
+
+paddle.enable_static()
+
+
+class XPUTestUnStackOp(XPUOpTestWrapper):
+    def __init__(self):
+        self.op_name = 'unstack'
+        self.use_dynamic_create_class = False
+
+    class TestUnStackOpBase(XPUOpTest):
+        def initDefaultParameters(self):
+            self.input_dim = (5, 6, 7)
+            self.axis = 0
+            self.dtype = 'float32'
+
+        def initParameters(self):
+            pass
+
+        def get_y_names(self):
+            y_names = []
+            for i in range(self.input_dim[self.axis]):
+                y_names.append('y{}'.format(i))
+            return y_names
+
+        def setUp(self):
+            self.initDefaultParameters()
+            self.initParameters()
+            self.op_type = 'unstack'
+            self.python_api = paddle.unstack
+            self.x = np.random.random(size=self.input_dim).astype(self.dtype)
+
+            outs = np.split(self.x, self.input_dim[self.axis], self.axis)
+            new_shape = list(self.input_dim)
+            del new_shape[self.axis]
+            y_names = self.get_y_names()
+            tmp = []
+            tmp_names = []
+            for i in range(self.input_dim[self.axis]):
+                tmp.append((y_names[i], np.reshape(outs[i], new_shape)))
+                tmp_names.append(y_names[i])
+
+            self.python_out_sig = tmp_names
+            self.inputs = {'X': self.x}
+            self.outputs = {'Y': tmp}
+            self.attrs = {'axis': self.axis, 'num': self.input_dim[self.axis]}
+
+        def test_check_output(self):
+            if paddle.is_compiled_with_xpu():
+                place = paddle.XPUPlace(0)
+                self.check_output_with_place(place)
+
+        def test_check_grad(self):
+            self.check_grad_with_place(
+                paddle.XPUPlace(0), self.get_y_names(), 'Y'
+            )
+
+    class TestStackOp3(TestUnStackOpBase):
+        def initParameters(self):
+            self.axis = -1
+
+    class TestStackOp4(TestUnStackOpBase):
+        def initParameters(self):
+            self.axis = -3
+
+    class TestStackOp5(TestUnStackOpBase):
+        def initParameters(self):
+            self.axis = 1
+
+    class TestStackOp6(TestUnStackOpBase):
+        def initParameters(self):
+            self.axis = 2
+
+    class TestUnstackZeroInputOp(unittest.TestCase):
+        def unstack_zero_input_static(self):
+
+            paddle.enable_static()
+
+            array = np.array([], dtype=np.float32)
+            x = paddle.to_tensor(np.reshape(array, [0]), dtype='float32')
+            paddle.unstack(x, axis=1)
+
+        def unstack_zero_input_dynamic(self):
+
+            array = np.array([], dtype=np.float32)
+            x = paddle.to_tensor(np.reshape(array, [0]), dtype='float32')
+            paddle.unstack(x, axis=1)
+
+        def test_type_error(self):
+            paddle.disable_static()
+
+            self.assertRaises(ValueError, self.unstack_zero_input_dynamic)
+            self.assertRaises(ValueError, self.unstack_zero_input_static)
+
+            paddle.disable_static()
+
+
+support_types = get_xpu_op_support_types('unstack')
+for stype in support_types:
+    create_test_class(globals(), XPUTestUnStackOp, stype)
+
+
+if __name__ == '__main__':
+    unittest.main()
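For a quick manual smoke test of the new kernels beyond the unit tests, a dynamic-graph snippet along these lines should exercise both paths (a sketch only: it assumes a Paddle build with XPU support and a Kunlun device visible as xpu:0):

    import numpy as np
    import paddle

    paddle.set_device('xpu:0')  # assumes paddle.is_compiled_with_xpu() is True

    x = paddle.to_tensor(
        np.random.rand(5, 6, 7).astype('float32'), stop_gradient=False
    )
    ys = paddle.unstack(x, axis=0)  # forward: the new unstack kernel (xpu::split)
    loss = paddle.add_n([y.sum() for y in ys])
    loss.backward()  # backward: the new unstack_grad kernel (xpu::concat)
    print(len(ys), ys[0].shape, x.grad.shape)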