From f9e55dee9cc1c7cac70bd87200d228aec931deea Mon Sep 17 00:00:00 2001 From: Aganlengzi Date: Tue, 31 May 2022 14:52:36 +0800 Subject: [PATCH] [NPU] fix arg_max and reduce_max (#42887) * fix arg_max and reduce_max * add arg_max ut --- paddle/fluid/operators/arg_max_op_npu.cc | 9 ++++++- .../operators/reduce_ops/reduce_max_op_npu.cc | 24 +++++++++++++++-- .../unittests/npu/test_arg_max_op_npu.py | 27 +++++++++++++++++++ 3 files changed, 57 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/arg_max_op_npu.cc b/paddle/fluid/operators/arg_max_op_npu.cc index 680183b6ad..5c6b276c01 100644 --- a/paddle/fluid/operators/arg_max_op_npu.cc +++ b/paddle/fluid/operators/arg_max_op_npu.cc @@ -34,11 +34,18 @@ struct VisitDataArgNPUMaxFunctor { out.template mutable_data(ctx.GetPlace()); auto axis = ctx.Attr("axis"); auto dtype = ctx.Attr("dtype"); + const bool& flatten = ctx.Attr("flatten"); + + Tensor transformed_x(x.type()); + transformed_x.ShareDataWith(x); + if (flatten) { + transformed_x.Resize(phi::make_ddim({x.numel()})); + } auto stream = ctx.template device_context().stream(); NpuOpRunner runner; runner.SetType("ArgMaxV2") - .AddInput(x) + .AddInput(transformed_x) .AddInput(std::vector{axis}) .AddOutput(out) .AddAttrDataType("dtype", dtype) diff --git a/paddle/fluid/operators/reduce_ops/reduce_max_op_npu.cc b/paddle/fluid/operators/reduce_ops/reduce_max_op_npu.cc index 04660fb501..e3d8d15a30 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_max_op_npu.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_max_op_npu.cc @@ -112,6 +112,8 @@ class ReduceMaxGradNPUKernel : public framework::OpKernel { auto* x = context.Input("X"); auto* out = context.Input("Out"); auto* out_grad = context.Input(framework::GradVarName("Out")); + auto reduce_dims = context.Attr>("dim"); + bool reduce_all = context.Attr("reduce_all"); int in_dtype = context.Attr("in_dtype"); PADDLE_ENFORCE_EQ( @@ -129,12 +131,30 @@ class ReduceMaxGradNPUKernel : public framework::OpKernel { // broadcast auto x_dims_vec = phi::vectorize(x->dims()); + if (reduce_all) { + reduce_dims.clear(); + for (size_t d = 0; d < x_dims_vec.size(); ++d) { + reduce_dims.push_back(static_cast(d)); + } + } + + Tensor tmp_out, tmp_out_grad; + auto tmp_out_dims_vec = x_dims_vec; + for (auto d : reduce_dims) { + tmp_out_dims_vec[d] = 1; + } + + tmp_out.ShareDataWith(*out); + tmp_out.Resize(phi::make_ddim(tmp_out_dims_vec)); + tmp_out_grad.ShareDataWith(*out_grad); + tmp_out_grad.Resize(phi::make_ddim(tmp_out_dims_vec)); + Tensor transformed_out(x->type()); transformed_out.Resize(phi::make_ddim(x_dims_vec)); transformed_out.mutable_data(place); NpuOpRunner r_brd_out; r_brd_out.SetType("BroadcastTo") - .AddInput(*out) + .AddInput(tmp_out) .AddInput(std::move(x_dims_vec)) .AddOutput(transformed_out) .Run(stream); @@ -143,7 +163,7 @@ class ReduceMaxGradNPUKernel : public framework::OpKernel { transformed_out_grad.mutable_data(place); NpuOpRunner r_brd_out_grad; r_brd_out_grad.SetType("BroadcastTo") - .AddInput(*out_grad) + .AddInput(tmp_out_grad) .AddInput(std::move(x_dims_vec)) .AddOutput(transformed_out_grad) .Run(stream); diff --git a/python/paddle/fluid/tests/unittests/npu/test_arg_max_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_arg_max_op_npu.py index 85ade1179b..c613538372 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_arg_max_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_arg_max_op_npu.py @@ -328,5 +328,32 @@ class TestArgMaxAPI_2(unittest.TestCase): run(place) +class TestArgMaxAPI_3(unittest.TestCase): + def initTestCase(self): + self.dims = (1, 9) + self.dtype = 'float32' + + def setUp(self): + self.initTestCase() + self.__class__.use_npu = True + self.place = [paddle.NPUPlace(0)] + + def test_dygraph_api(self): + def run(place): + paddle.disable_static(place) + np.random.seed(2021) + numpy_input = (np.random.random(self.dims)).astype(self.dtype) + tensor_input = paddle.to_tensor(numpy_input) + numpy_output = np.argmax(numpy_input).reshape([1]) + paddle_output = paddle.argmax(tensor_input) + self.assertEqual( + np.allclose(numpy_output, paddle_output.numpy()), True) + self.assertEqual(numpy_output.shape, paddle_output.numpy().shape) + paddle.enable_static() + + for place in self.place: + run(place) + + if __name__ == '__main__': unittest.main() -- GitLab