diff --git a/paddle/fluid/operators/ngraph/ops/softmax_op.h b/paddle/fluid/operators/ngraph/ops/softmax_op.h
index 6eb84703998c24ee7b9e0d4f6931c3fe0bd00e2e..e1f6e8d3cfdc56c00229bbe1c3b183c309d0394e 100644
--- a/paddle/fluid/operators/ngraph/ops/softmax_op.h
+++ b/paddle/fluid/operators/ngraph/ops/softmax_op.h
@@ -27,38 +27,38 @@ namespace paddle {
 namespace operators {
 namespace ngraphs {
 
-std::shared_ptr<ngraph::Node> GetSoftmax(std::shared_ptr<ngraph::Node> x) {
+std::shared_ptr<ngraph::Node> GetSoftmax(std::shared_ptr<ngraph::Node> x,
+                                         int axis = -1) {
   auto x_shape = x->get_shape();
-  int rank = x_shape.size();
-  auto x_2d_shape = paddle::platform::FlattenTo2d(x_shape, rank - 1);
-  x = paddle::platform::NgReshaper(x, x_2d_shape);
+  size_t rank = x_shape.size();
+  size_t softmax_axis = axis;
+  if (axis < 0) softmax_axis = rank + axis;
 
-  auto x_max = std::make_shared<ngraph::op::Max>(x, ngraph::AxisSet{1});
+  auto x_max =
+      std::make_shared<ngraph::op::Max>(x, ngraph::AxisSet{softmax_axis});
   auto x_max_bcast = std::make_shared<ngraph::op::Broadcast>(
-      x_max, x_2d_shape, ngraph::AxisSet{1});
+      x_max, x_shape, ngraph::AxisSet{softmax_axis});
   auto x_shifted = x - x_max_bcast;
   auto x_clipped =
       paddle::operators::ngraphs::ElementwiseScalar<ngraph::op::Maximum>(
          -64., x_shifted);
-  auto softmax =
-      std::make_shared<ngraph::op::Softmax>(x_clipped, ngraph::AxisSet{1});
+  auto softmax = std::make_shared<ngraph::op::Softmax>(
+      x_clipped, ngraph::AxisSet{softmax_axis});
   return softmax;
 }
 
-std::shared_ptr<ngraph::Node> GetSoftmaxGrad(
-    std::shared_ptr<ngraph::Node> out, std::shared_ptr<ngraph::Node> dout) {
+std::shared_ptr<ngraph::Node> GetSoftmaxGrad(std::shared_ptr<ngraph::Node> out,
+                                             std::shared_ptr<ngraph::Node> dout,
+                                             int axis = -1) {
   auto out_shape = out->get_shape();
-  int rank = out_shape.size();
-  auto out_2d_shape = paddle::platform::FlattenTo2d(out_shape, rank - 1);
-  auto dout_2d_shape =
-      paddle::platform::FlattenTo2d(dout->get_shape(), rank - 1);
-  out = paddle::platform::NgReshaper(out, out_2d_shape);
-  dout = paddle::platform::NgReshaper(dout, dout_2d_shape);
+  size_t rank = out_shape.size();
+  size_t softmax_axis = axis;
+  if (axis < 0) softmax_axis = rank + axis;
 
-  auto node_sum =
-      std::make_shared<ngraph::op::Sum>(out * dout, ngraph::AxisSet{1});
+  auto node_sum = std::make_shared<ngraph::op::Sum>(
+      out * dout, ngraph::AxisSet{softmax_axis});
   auto node_bcast = std::make_shared<ngraph::op::Broadcast>(
-      node_sum, out_2d_shape, ngraph::AxisSet{1});
+      node_sum, out_shape, ngraph::AxisSet{softmax_axis});
   auto dx = (dout - node_bcast) * out;
   return dx;
 }
@@ -68,8 +68,9 @@ void BuildSoftmaxNode(
     std::shared_ptr<
         std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
         ngb_node_map) {
+  auto op_attrs = framework::AttrReader(op->Attrs());
   auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
-  auto softmax = GetSoftmax(x);
+  auto softmax = GetSoftmax(x, op_attrs.Get<int>("axis"));
   paddle::platform::SetOutputNode(op, "Out", softmax, ngb_node_map);
 }
 
@@ -78,9 +79,10 @@ void BuildSoftmaxGradNode(
     std::shared_ptr<
         std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
         ngb_node_map) {
+  auto op_attrs = framework::AttrReader(op->Attrs());
   auto out = paddle::platform::GetInputNode(op, "Out", ngb_node_map);
   auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
-  auto dx = GetSoftmaxGrad(out, dout);
+  auto dx = GetSoftmaxGrad(out, dout, op_attrs.Get<int>("axis"));
   paddle::platform::SetOutputNode(op, "X@GRAD", dx, ngb_node_map);
 }
 }  // namespace ngraphs
diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_softmax_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_softmax_ngraph_op.py
index 0cb08842df0797952c47a63ba2bbb8614c0e8a22..09c52e2b1084fc5f716a6d1abfb4968d2c5460da 100644
--- a/python/paddle/fluid/tests/unittests/ngraph/test_softmax_ngraph_op.py
+++ b/python/paddle/fluid/tests/unittests/ngraph/test_softmax_ngraph_op.py
@@ -14,7 +14,7 @@
 from __future__ import print_function
 
 import unittest
-from paddle.fluid.tests.unittests.test_softmax_op import TestSoftmaxOp
+from paddle.fluid.tests.unittests.test_softmax_op import TestSoftmaxOp, TestSoftmaxOp2, TestSoftmaxOp3, TestSoftmaxOp4, TestSoftmaxOp5
 
 if __name__ == "__main__":
     unittest.main()
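
Reviewer note (not part of the patch): a minimal NumPy sketch of what the updated
GetSoftmax / GetSoftmaxGrad graphs compute for a given `axis`, assuming the same
negative-axis normalization and the -64 clip shown above. The helper names
ref_softmax and ref_softmax_grad are illustrative only.

    import numpy as np

    def ref_softmax(x, axis=-1):
        # Negative axes are normalized like `softmax_axis` above
        # (e.g. axis = -1 on a rank-4 tensor becomes axis 3).
        axis = axis if axis >= 0 else x.ndim + axis
        # Max-subtract for numerical stability, then clip the shifted
        # values at -64, mirroring the ElementwiseScalar<Maximum> node.
        shifted = x - np.max(x, axis=axis, keepdims=True)
        clipped = np.maximum(shifted, -64.0)
        e = np.exp(clipped)
        return e / np.sum(e, axis=axis, keepdims=True)

    def ref_softmax_grad(out, dout, axis=-1):
        # dx = (dout - sum(out * dout, axis)) * out, with the sum broadcast
        # back over `axis`, matching the Sum + Broadcast nodes in the graph.
        axis = axis if axis >= 0 else out.ndim + axis
        s = np.sum(out * dout, axis=axis, keepdims=True)
        return (dout - s) * out

    if __name__ == "__main__":
        x = np.random.rand(2, 3, 4, 5).astype("float32")
        y = ref_softmax(x, axis=1)
        # Every slice along the softmax axis should sum to 1.
        assert np.allclose(np.sum(y, axis=1), 1.0, atol=1e-5)

The widened import in test_softmax_ngraph_op.py presumably pulls the additional
TestSoftmaxOp2-5 cases from test_softmax_op.py into the nGraph test runner, so the
engine is exercised on more shapes and axis settings than the default case alone.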