From e782b54b9c6f7355946a798b74d5499d35cc38a7 Mon Sep 17 00:00:00 2001
From: baojun <32073718+baojun-nervana@users.noreply.github.com>
Date: Tue, 7 May 2019 00:34:34 -0700
Subject: [PATCH] update softmax with axis arg test=develop (#17190)

---
 .../fluid/operators/ngraph/ops/softmax_op.h    | 44 ++++++++++---------
 .../ngraph/test_softmax_ngraph_op.py           |  2 +-
 2 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/paddle/fluid/operators/ngraph/ops/softmax_op.h b/paddle/fluid/operators/ngraph/ops/softmax_op.h
index 6eb84703998..e1f6e8d3cfd 100644
--- a/paddle/fluid/operators/ngraph/ops/softmax_op.h
+++ b/paddle/fluid/operators/ngraph/ops/softmax_op.h
@@ -27,38 +27,38 @@ namespace paddle {
 namespace operators {
 namespace ngraphs {
 
-std::shared_ptr<ngraph::Node> GetSoftmax(std::shared_ptr<ngraph::Node> x) {
+std::shared_ptr<ngraph::Node> GetSoftmax(std::shared_ptr<ngraph::Node> x,
+                                         int axis = -1) {
   auto x_shape = x->get_shape();
-  int rank = x_shape.size();
-  auto x_2d_shape = paddle::platform::FlattenTo2d(x_shape, rank - 1);
-  x = paddle::platform::NgReshaper(x, x_2d_shape);
+  size_t rank = x_shape.size();
+  size_t softmax_axis = axis;
+  if (axis < 0) softmax_axis = rank + axis;
 
-  auto x_max = std::make_shared<ngraph::op::Max>(x, ngraph::AxisSet{1});
+  auto x_max =
+      std::make_shared<ngraph::op::Max>(x, ngraph::AxisSet{softmax_axis});
   auto x_max_bcast = std::make_shared<ngraph::op::Broadcast>(
-      x_max, x_2d_shape, ngraph::AxisSet{1});
+      x_max, x_shape, ngraph::AxisSet{softmax_axis});
   auto x_shifted = x - x_max_bcast;
   auto x_clipped = paddle::operators::ngraphs::ElementwiseScalar<ngraph::op::Maximum>(
       -64., x_shifted);
-  auto softmax =
-      std::make_shared<ngraph::op::Softmax>(x_clipped, ngraph::AxisSet{1});
+  auto softmax = std::make_shared<ngraph::op::Softmax>(
+      x_clipped, ngraph::AxisSet{softmax_axis});
   return softmax;
 }
 
-std::shared_ptr<ngraph::Node> GetSoftmaxGrad(
-    std::shared_ptr<ngraph::Node> out, std::shared_ptr<ngraph::Node> dout) {
+std::shared_ptr<ngraph::Node> GetSoftmaxGrad(std::shared_ptr<ngraph::Node> out,
+                                             std::shared_ptr<ngraph::Node> dout,
+                                             int axis = -1) {
   auto out_shape = out->get_shape();
-  int rank = out_shape.size();
-  auto out_2d_shape = paddle::platform::FlattenTo2d(out_shape, rank - 1);
-  auto dout_2d_shape =
-      paddle::platform::FlattenTo2d(dout->get_shape(), rank - 1);
-  out = paddle::platform::NgReshaper(out, out_2d_shape);
-  dout = paddle::platform::NgReshaper(dout, dout_2d_shape);
+  size_t rank = out_shape.size();
+  size_t softmax_axis = axis;
+  if (axis < 0) softmax_axis = rank + axis;
 
-  auto node_sum =
-      std::make_shared<ngraph::op::Sum>(out * dout, ngraph::AxisSet{1});
+  auto node_sum = std::make_shared<ngraph::op::Sum>(
+      out * dout, ngraph::AxisSet{softmax_axis});
   auto node_bcast = std::make_shared<ngraph::op::Broadcast>(
-      node_sum, out_2d_shape, ngraph::AxisSet{1});
+      node_sum, out_shape, ngraph::AxisSet{softmax_axis});
   auto dx = (dout - node_bcast) * out;
   return dx;
 }
@@ -68,8 +68,9 @@ void BuildSoftmaxNode(
     std::shared_ptr<
         std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
         ngb_node_map) {
+  auto op_attrs = framework::AttrReader(op->Attrs());
   auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
-  auto softmax = GetSoftmax(x);
+  auto softmax = GetSoftmax(x, op_attrs.Get<int>("axis"));
   paddle::platform::SetOutputNode(op, "Out", softmax, ngb_node_map);
 }
 
@@ -78,9 +79,10 @@ void BuildSoftmaxGradNode(
     std::shared_ptr<
         std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
         ngb_node_map) {
+  auto op_attrs = framework::AttrReader(op->Attrs());
   auto out = paddle::platform::GetInputNode(op, "Out", ngb_node_map);
   auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
-  auto dx = GetSoftmaxGrad(out, dout);
+  auto dx = GetSoftmaxGrad(out, dout, op_attrs.Get<int>("axis"));
   paddle::platform::SetOutputNode(op, "X@GRAD", dx, ngb_node_map);
 }
 }  // namespace ngraphs
diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_softmax_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_softmax_ngraph_op.py
index 0cb08842df0..09c52e2b108 100644
--- a/python/paddle/fluid/tests/unittests/ngraph/test_softmax_ngraph_op.py
+++ b/python/paddle/fluid/tests/unittests/ngraph/test_softmax_ngraph_op.py
@@ -14,7 +14,7 @@
 from __future__ import print_function
 
 import unittest
-from paddle.fluid.tests.unittests.test_softmax_op import TestSoftmaxOp
+from paddle.fluid.tests.unittests.test_softmax_op import TestSoftmaxOp, TestSoftmaxOp2, TestSoftmaxOp3, TestSoftmaxOp4, TestSoftmaxOp5
 
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab
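
Reviewer note (not part of the patch): the snippet below is a minimal NumPy sketch of the math the patched GetSoftmax/GetSoftmaxGrad graphs build, useful for sanity-checking the axis handling. The helper names stable_softmax and softmax_grad and the test shapes are illustrative assumptions, not Paddle or nGraph APIs.

    import numpy as np


    def stable_softmax(x, axis=-1):
        # Mirrors GetSoftmax: normalize a negative axis, subtract the per-axis max,
        # clip the shifted logits at -64, then exponentiate and normalize.
        softmax_axis = axis if axis >= 0 else x.ndim + axis
        x_shifted = x - x.max(axis=softmax_axis, keepdims=True)
        x_clipped = np.maximum(x_shifted, -64.0)
        e = np.exp(x_clipped)
        return e / e.sum(axis=softmax_axis, keepdims=True)


    def softmax_grad(out, dout, axis=-1):
        # Mirrors GetSoftmaxGrad: dx = (dout - sum(out * dout, axis)) * out.
        softmax_axis = axis if axis >= 0 else out.ndim + axis
        node_sum = (out * dout).sum(axis=softmax_axis, keepdims=True)
        return (dout - node_sum) * out


    if __name__ == "__main__":
        rng = np.random.RandomState(0)
        x = rng.randn(2, 3, 4)
        dout = rng.randn(2, 3, 4)
        for axis in (-1, 0, 1, 2):
            out = stable_softmax(x, axis)
            dx = softmax_grad(out, dout, axis)
            assert np.allclose(out.sum(axis=axis), 1.0)
            # Each softmax slice sums to one, so its gradient sums to zero along the axis.
            assert np.allclose(dx.sum(axis=axis), 0.0)

Design-wise, the patch drops the old FlattenTo2d/NgReshaper round trip and reduces/broadcasts directly over softmax_axis, so intermediate nodes keep the input's original shape and the op's axis attribute (defaulting to -1, the last dimension) is honored instead of always normalizing over the flattened last dimension.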