From a5fcc4b5458d737efb3dc671193d046e07cc6136 Mon Sep 17 00:00:00 2001
From: TTerror
Date: Tue, 8 Dec 2020 14:17:03 +0800
Subject: [PATCH] update reduce_sum op on xpu (#29367)

* update reduce_sum op on xpu

* update reduce_sum op on xpu

* support running on xpu
---
 .../operators/reduce_ops/reduce_sum_op_xpu.cc | 162 ++++++++-----
 .../unittests/xpu/test_reduce_sum_op_xpu.py   | 223 +++++++-----------
 2 files changed, 185 insertions(+), 200 deletions(-)

diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op_xpu.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op_xpu.cc
index b751eca9ee0..f67d43194a0 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_sum_op_xpu.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op_xpu.cc
@@ -16,6 +16,8 @@
 #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
 #include <memory>
 #include <string>
+#include "paddle/fluid/platform/xpu_header.h"
+
 namespace paddle {
 namespace operators {
 
@@ -27,86 +29,120 @@ class ReduceSumXPUKernel : public framework::OpKernel<T> {
         platform::is_xpu_place(context.GetPlace()), true,
         platform::errors::Unavailable("This kernel only runs on XPU."));
     bool reduce_all = context.Attr<bool>("reduce_all");
-    auto* input = context.Input<Tensor>("X");
-    auto* output = context.Output<Tensor>("Out");
-    output->mutable_data<T>(context.GetPlace());
+    auto dims = context.Attr<std::vector<int>>("dim");
+    auto* x = context.Input<Tensor>("X");
+    auto* y = context.Output<Tensor>("Out");
+    y->mutable_data<T>(context.GetPlace());
     auto& dev_ctx = context.template device_context<DeviceContext>();
+
+    int out_dtype = context.Attr<int>("out_dtype");
+    PADDLE_ENFORCE_EQ(
+        out_dtype == -1, true,
+        platform::errors::InvalidArgument(
+            "XPU only support out_dtype == -1 in reduce_sum op."));
+
+    const auto* x_data = x->data<T>();
+    auto* y_data = y->data<T>();
+    const auto& input_dim_size = x->dims().size();
+    std::vector<int> true_dims;
+    for (size_t i = 0; i < dims.size(); ++i) {
+      if (dims[i] < 0) {
+        true_dims.push_back(dims[i] + input_dim_size);
+      } else {
+        true_dims.push_back(dims[i]);
+      }
+    }
+
+    std::vector<int> reduce_dims;
+    std::vector<int> xdims((input_dim_size));
+    for (int i = 0; i < input_dim_size; ++i) {
+      xdims[i] = x->dims()[i];
+    }
     if (reduce_all) {
-      int input_len = input->numel();
-      int r = xpu::sum(dev_ctx.x_context(), input->data<T>(), output->data<T>(),
-                       input_len);
-      PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
-                        platform::errors::External("XPU kernel error!"));
+      for (int i = 0; i < input_dim_size; ++i) {
+        reduce_dims.push_back(i);
+      }
     } else {
-      int ndim = input->dims().size();
-      std::vector<int> idims;
-      for (int i = 0; i < input->dims().size(); i++) {
-        idims.push_back(input->dims()[i]);
+      std::set<int> dims_set(true_dims.begin(), true_dims.end());
+      for (auto i = 0; i < input_dim_size; i++) {
+        if (dims_set.find(i) != dims_set.end()) {
+          if (x->dims()[i] != 1) {
+            reduce_dims.push_back(i);
+          }
+        }
       }
-      auto dims = context.Attr<std::vector<int>>("dim");
-      int rdim = dims.size();
-      int r =
-          xpu::reduce(dev_ctx.x_context(), input->data<T>(), output->data<T>(),
-                      idims.data(), ndim, dims.data(), rdim, xpu::REDUCE_SUM);
-      PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
-                        platform::errors::External("XPU kernel error!"));
+    }
+
+    if (reduce_dims.size() == 0) {
+      int r = xpu::copy(dev_ctx.x_context(), x_data, y_data,
+                        x->numel() * sizeof(T));
+      PADDLE_ENFORCE_EQ(
+          r == xpu::Error_t::SUCCESS, true,
+          platform::errors::External("XPU copy in reduce_sum op return "
+                                     "wrong value[%d %s].",
+                                     r, XPUAPIErrorMsg[r]));
+    } else {
+      int r = xpu::reduce_sum(dev_ctx.x_context(), x_data, y_data, xdims,
+                              reduce_dims);
+      PADDLE_ENFORCE_EQ(
+          r == xpu::Error_t::SUCCESS, true,
+          platform::errors::External("XPU reduce_sum in reduce_sum op return"
+                                     " wrong value[%d %s].",
+                                     r, XPUAPIErrorMsg[r]));
     }
   }
 };
 
+
 template <typename DeviceContext, typename T>
 class ReduceSumGradXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto dims = context.Attr<std::vector<int>>("dim");
     bool reduce_all = context.Attr<bool>("reduce_all");
-    auto* input0 = context.Input<Tensor>("X");
-    auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* output = context.Output<Tensor>(framework::GradVarName("X"));
-    output->mutable_data<T>(context.GetPlace());
-    const auto* input2_d = input2->data<T>();
-    auto* output_d = output->data<T>();
+    auto* x = context.Input<Tensor>("X");
+    auto* out = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
+
+    int in_dtype = context.Attr<int>("in_dtype");
+    PADDLE_ENFORCE_EQ(
+        in_dtype == -1, true,
+        platform::errors::InvalidArgument(
+            "XPU only support in_dtype == -1 in reduce_sum_grad op."));
+
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    int r = 0;
-    std::vector<int> idims;
-    int reduce_dim = 0;
-    if (reduce_all) {
-      idims.push_back(input0->numel());
-      idims.push_back(1);
-      idims.push_back(1);
-      r = xpu::reduce_grad(dev_ctx.x_context(), input2_d, output_d,
-                           idims.data(), idims.size(), &reduce_dim, 1,
-                           xpu::REDUCE_SUM);
-      PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
-                        platform::errors::External("XPU kernel error!"));
-    } else if (dims.size() == 1) {
-      // handle reduce by one dimension
-      int reduce_dim_index = dims[0];
-      if (reduce_dim_index < 0) {
-        reduce_dim_index += input0->dims().size();
-      }
-      auto& input_dim = input0->dims();
-      int before_dim = 1;
-      for (int i = 0; i < reduce_dim_index; ++i) {
-        before_dim *= input_dim[i];
+    x_grad->mutable_data<T>(context.GetPlace());
+    const auto* out_data = out->data<T>();
+    auto* x_grad_data = x_grad->data<T>();
+
+    const auto& input_dim_size = x->dims().size();
+    std::vector<int> true_dims;
+    for (size_t i = 0; i < dims.size(); ++i) {
+      if (dims[i] < 0) {
+        true_dims.push_back(dims[i] + input_dim_size);
+      } else {
+        true_dims.push_back(dims[i]);
       }
-      int reduce_dim = input_dim[reduce_dim_index];
-      int after_dim = 1;
-      for (int i = reduce_dim_index + 1; i < input_dim.size(); ++i) {
-        after_dim *= input_dim[i];
+    }
+
+    std::vector<int> ydims(input_dim_size);
+    std::vector<int> xdims((input_dim_size));
+    std::set<int> dims_set(true_dims.begin(), true_dims.end());
+    for (auto i = 0; i < input_dim_size; i++) {
+      xdims[i] = x->dims()[i];
+      if (dims_set.find(i) != dims_set.end() || reduce_all) {
+        ydims[i] = 1;
+      } else {
+        ydims[i] = x->dims()[i];
       }
-      idims.push_back(before_dim);
-      idims.push_back(input_dim[reduce_dim_index]);
-      idims.push_back(after_dim);
-      reduce_dim = 1;
-      r = xpu::reduce_grad(dev_ctx.x_context(), input2_d, output_d,
-                           idims.data(), idims.size(), &reduce_dim, 1,
-                           xpu::REDUCE_SUM);
-      PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
-                        platform::errors::External("XPU kernel error!"));
-    } else {
-      PADDLE_THROW(
-          platform::errors::Unimplemented("unsupport reduce sum grad"));
     }
+
+    int r = xpu::broadcast(dev_ctx.x_context(), out_data, x_grad_data, ydims,
+                           xdims);
+    PADDLE_ENFORCE_EQ(
+        r == xpu::Error_t::SUCCESS, true,
+        platform::errors::External("XPU broadcast in reduce_sum_grad op return"
+                                   " wrong value[%d %s].",
+                                   r, XPUAPIErrorMsg[r]));
   }
 };
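For reference, the reworked forward kernel above normalizes negative entries of "dim", skips axes of size 1, expands reduce_all to every axis, and falls back to a device copy when nothing is left to reduce, while the new grad kernel is simply a broadcast of the output gradient back to the input shape. The NumPy sketch below is not part of the patch; the function names are illustrative only and the output shape handling (keep_dim, InferShape) is omitted.

import numpy as np

def xpu_reduce_sum_data(x, dim, reduce_all=False):
    # Which axes the kernel actually reduces; the output tensor's shape is
    # decided by the op's InferShape, the kernel only fills the data.
    ndim = x.ndim
    true_dims = [d + ndim if d < 0 else d for d in dim]          # map negative axes
    if reduce_all:
        reduce_dims = list(range(ndim))
    else:
        reduce_dims = [i for i in true_dims if x.shape[i] != 1]  # size-1 axes need no work
    if not reduce_dims:
        return x.copy()                                          # kernel falls back to xpu::copy
    return x.sum(axis=tuple(reduce_dims))

def xpu_reduce_sum_grad_data(x_shape, dim, out_grad, reduce_all=False):
    # The grad kernel broadcasts d(Out) back to X's shape (xpu::broadcast).
    ndim = len(x_shape)
    true_dims = {d + ndim if d < 0 else d for d in dim}
    ydims = [1 if (i in true_dims or reduce_all) else x_shape[i] for i in range(ndim)]
    return np.broadcast_to(out_grad.reshape(ydims), x_shape)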
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_reduce_sum_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_reduce_sum_op_xpu.py
index 2a0457d1862..638da601a3d 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_reduce_sum_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_reduce_sum_op_xpu.py
@@ -18,7 +18,8 @@ import unittest
 import numpy as np
 import sys
 sys.path.append("..")
-from op_test import OpTest, skip_check_grad_ci
+from op_test_xpu import OpTest, XPUOpTest
+from op_test import skip_check_grad_ci
 import paddle
 import paddle.fluid.core as core
 import paddle.fluid as fluid
@@ -26,180 +27,128 @@ from paddle.fluid import compiler, Program, program_guard
 from paddle.fluid.framework import convert_np_dtype_to_dtype_
 
 
-class TestSumOp(OpTest):
+class TestXPUReduceSumOp(XPUOpTest):
     def setUp(self):
-        self.op_type = "reduce_sum"
-        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
-        self.attrs = {'use_xpu': True}
-        self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
-
-    def test_check_output(self):
-        if paddle.is_compiled_with_xpu():
-            place = paddle.XPUPlace(0)
-            self.check_output_with_place(place)
-
-    def check_grad_(self):
-        self.check_grad(['X'], 'Out')
-
-
-class TestSumOp5D(OpTest):
-    def setUp(self):
-        self.op_type = "reduce_sum"
-        self.inputs = {
-            'X': np.random.random((1, 2, 5, 6, 10)).astype("float64")
+        self.init_op_type()
+        self.initTestCase()
+        self.use_xpu = True
+        self.use_mkldnn = False
+        self.attrs = {
+            'dim': self.axis,
+            'keep_dim': self.keep_dim,
+            'reduce_all': self.reduce_all
         }
-        self.attrs = {'use_xpu': True}
-        self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
+        self.inputs = {'X': np.random.random(self.shape).astype("float32")}
+        if self.attrs['reduce_all']:
+            self.outputs = {'Out': self.inputs['X'].sum()}
+        else:
+            self.outputs = {
+                'Out': self.inputs['X'].sum(axis=self.axis,
+                                            keepdims=self.attrs['keep_dim'])
+            }
 
     def test_check_output(self):
         if paddle.is_compiled_with_xpu():
+            paddle.enable_static()
             place = paddle.XPUPlace(0)
             self.check_output_with_place(place)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
-
-
-class TestSumOp6D(OpTest):
-    def setUp(self):
-        self.op_type = "reduce_sum"
-        self.inputs = {
-            'X': np.random.random((1, 1, 2, 5, 6, 10)).astype("float64")
-        }
-        self.attrs = {'use_xpu': True}
-        self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
-
-    def test_check_output(self):
         if paddle.is_compiled_with_xpu():
+            paddle.enable_static()
             place = paddle.XPUPlace(0)
-            self.check_output_with_place(place)
+            self.check_grad_with_place(place, ['X'], 'Out')
 
-    def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+    def init_op_type(self):
+        self.op_type = "reduce_sum"
+        self.use_mkldnn = False
+        self.keep_dim = False
+        self.reduce_all = False
 
+    def initTestCase(self):
+        self.shape = (5, 6, 10)
+        self.axis = (0, )
 
-class TestSumOp8D(OpTest):
-    def setUp(self):
-        self.op_type = "reduce_sum"
-        self.inputs = {
-            'X': np.random.random((1, 3, 1, 2, 1, 4, 3, 10)).astype("float64")
-        }
-        self.attrs = {'dim': (0, 3), 'use_xpu': True}
-        self.outputs = {'Out': self.inputs['X'].sum(axis=(0, 3))}
 
-    def test_check_output(self):
-        if paddle.is_compiled_with_xpu():
-            place = paddle.XPUPlace(0)
-            self.check_output_with_place(place)
+class TestSumOp5D(TestXPUReduceSumOp):
+    def initTestCase(self):
+        self.shape = (1, 2, 5, 6, 10)
+        self.axis = (0, )
 
-    def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
 
+class TestSumOp6D(TestXPUReduceSumOp):
+    def initTestCase(self):
+        self.shape = (1, 1, 2, 5, 6, 10)
+        self.axis = (0, )
 
-class Test1DReduce(OpTest):
-    def setUp(self):
-        self.op_type = "reduce_sum"
-        self.inputs = {'X': np.random.random(120).astype("float64")}
-        self.attrs = {'use_xpu': True}
-        self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
 
-    def test_check_output(self):
-        if paddle.is_compiled_with_xpu():
-            place = paddle.XPUPlace(0)
-            self.check_output_with_place(place)
+class TestSumOp8D(TestXPUReduceSumOp):
+    def initTestCase(self):
+        self.shape = (1, 3, 1, 2, 1, 4, 3, 10)
+        self.axis = (0, 3)
 
-    def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
 
+class Test1DReduce(TestXPUReduceSumOp):
+    def initTestCase(self):
+        self.shape = 120
+        self.axis = (0, )
 
-class Test2DReduce0(Test1DReduce):
-    def setUp(self):
-        self.op_type = "reduce_sum"
-        self.attrs = {'dim': [0], 'use_xpu': True}
-        self.inputs = {'X': np.random.random((20, 10)).astype("float64")}
-        self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
 
+class Test2DReduce0(TestXPUReduceSumOp):
+    def initTestCase(self):
+        self.shape = (20, 10)
+        self.axis = (0, )
 
-class Test2DReduce1(Test1DReduce):
-    def setUp(self):
-        self.op_type = "reduce_sum"
-        self.attrs = {'dim': [1], 'use_xpu': True}
-        self.inputs = {'X': np.random.random((20, 10)).astype("float64")}
-        self.outputs = {
-            'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
-        }
 
+class Test2DReduce1(TestXPUReduceSumOp):
+    def initTestCase(self):
+        self.shape = (20, 10)
+        self.axis = (1, )
 
-class Test3DReduce0(Test1DReduce):
-    def setUp(self):
-        self.op_type = "reduce_sum"
-        self.attrs = {'dim': [1], 'use_xpu': True}
-        self.inputs = {'X': np.random.random((5, 6, 7)).astype("float64")}
-        self.outputs = {
-            'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
-        }
 
+class Test3DReduce0(TestXPUReduceSumOp):
+    def initTestCase(self):
+        self.shape = (5, 6, 7)
+        self.axis = (1, )
 
-class Test3DReduce1(Test1DReduce):
-    def setUp(self):
-        self.op_type = "reduce_sum"
-        self.attrs = {'dim': [2], 'use_xpu': True}
-        self.inputs = {'X': np.random.random((5, 6, 7)).astype("float64")}
-        self.outputs = {
-            'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
-        }
 
+class Test3DReduce1(TestXPUReduceSumOp):
+    def initTestCase(self):
+        self.shape = (5, 6, 7)
+        self.axis = (2, )
 
-class Test3DReduce2(Test1DReduce):
-    def setUp(self):
-        self.op_type = "reduce_sum"
-        self.attrs = {'dim': [-2], 'use_xpu': True}
-        self.inputs = {'X': np.random.random((5, 6, 7)).astype("float64")}
-        self.outputs = {
-            'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
-        }
 
+class Test3DReduce2(TestXPUReduceSumOp):
+    def initTestCase(self):
+        self.shape = (5, 6, 7)
+        self.axis = (-2, )
 
-class Test3DReduce3(Test1DReduce):
-    def setUp(self):
-        self.op_type = "reduce_sum"
-        self.attrs = {'dim': [1, 2], 'use_xpu': True}
-        self.inputs = {'X': np.random.random((5, 6, 7)).astype("float64")}
-        self.outputs = {
-            'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
-        }
 
+class Test3DReduce3(TestXPUReduceSumOp):
+    def initTestCase(self):
+        self.shape = (5, 6, 7)
+        self.axis = (1, 2)
 
-class TestKeepDimReduce(Test1DReduce):
-    def setUp(self):
-        self.op_type = "reduce_sum"
-        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
-        self.attrs = {'dim': [1], 'keep_dim': True, 'use_xpu': True}
-        self.outputs = {
-            'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']),
-                                        keepdims=self.attrs['keep_dim'])
-        }
 
+class TestKeepDimReduce(TestXPUReduceSumOp):
+    def initTestCase(self):
+        self.shape = (5, 6, 10)
+        self.axis = (1, )
+        self.keep_dim = True
 
-class TestKeepDim8DReduce(Test1DReduce):
-    def setUp(self):
-        self.op_type = "reduce_sum"
-        self.inputs = {
-            'X': np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype("float64")
-        }
-        self.attrs = {'dim': (3, 4, 5), 'keep_dim': True, 'use_xpu': True}
-        self.outputs = {
-            'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']),
-                                        keepdims=self.attrs['keep_dim'])
-        }
 
+class TestKeepDim8DReduce(TestXPUReduceSumOp):
+    def initTestCase(self):
+        self.shape = (2, 5, 3, 2, 2, 3, 4, 2)
+        self.axis = (3, 4, 5)
+        self.keep_dim = True
 
-class TestReduceAll(Test1DReduce):
-    def setUp(self):
-        self.op_type = "reduce_sum"
-        self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float64")}
-        self.attrs = {'reduce_all': True, 'use_xpu': True}
-        self.outputs = {'Out': self.inputs['X'].sum()}
+
+class TestReduceAll(TestXPUReduceSumOp):
+    def initTestCase(self):
+        self.shape = (5, 6, 2, 10)
+        self.axis = (0, )
+        self.reduce_all = True
 
 
 if __name__ == '__main__':
-- 
GitLab
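A minimal static-graph sketch of exercising the updated op end to end, assuming a Paddle build with XPU support and at least one visible XPU device; it mirrors what the unit tests above do and is not part of the patch:

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[5, 6, 10], dtype='float32')
    out = paddle.sum(x, axis=1)  # lowers to the reduce_sum op

place = paddle.XPUPlace(0)
exe = fluid.Executor(place)
exe.run(startup_prog)

x_np = np.random.random((5, 6, 10)).astype('float32')
out_np, = exe.run(main_prog, feed={'x': x_np}, fetch_list=[out])
np.testing.assert_allclose(out_np, x_np.sum(axis=1), rtol=1e-5)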