diff --git a/paddle/fluid/operators/ngraph/ops/elementwise_div_op.h b/paddle/fluid/operators/ngraph/ops/elementwise_div_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..b4cc2f862ba8cfbb26c21d41f061dfdd10f11903
--- /dev/null
+++ b/paddle/fluid/operators/ngraph/ops/elementwise_div_op.h
@@ -0,0 +1,103 @@
+/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "ngraph/ngraph.hpp"
+#include "paddle/fluid/operators/ngraph/ops/elementwise_node.h"
+#include "paddle/fluid/operators/ngraph/ops/op_bridge.h"
+#include "paddle/fluid/platform/ngraph_helper.h"
+
+namespace paddle {
+namespace operators {
+namespace ngraphs {
+
+void BuildElementwiseDivGradNode(
+    const std::shared_ptr<paddle::framework::OperatorBase>& op,
+    std::shared_ptr<
+        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
+        ngb_node_map) {
+  auto op_attrs = paddle::framework::AttrReader(op->Attrs());
+  int axis = op_attrs.Get<int>("axis");
+
+  auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
+  auto y = paddle::platform::GetInputNode(op, "Y", ngb_node_map);
+  auto out = paddle::platform::GetInputNode(op, "Out", ngb_node_map);
+  auto dout_shape = dout->get_shape();
+  auto y_shape = y->get_shape();
+  if (dout->get_element_type() != y->get_element_type()) {
+    y = std::make_shared<ngraph::op::Convert>(y, dout->get_element_type());
+  }
+  auto dy_hd = std::make_shared<ngraph::op::Multiply>(out, dout);
+  if (dout_shape == y_shape) {
+    auto dx = std::make_shared<ngraph::op::Divide>(dout, y);
+    auto dy = std::make_shared<ngraph::op::Divide>(dy_hd, -y);
+    paddle::platform::SetOutputNode(op, "X@GRAD", dx, ngb_node_map);
+    paddle::platform::SetOutputNode(op, "Y@GRAD", dy, ngb_node_map);
+  } else {
+    auto dy_hd_shape = dy_hd->get_shape();
+    axis = (axis == -1 ? dy_hd_shape.size() - y_shape.size() : axis);
+    paddle::platform::TrimTrailingSingularDims(&y_shape);
+    axis = (y_shape.size() == 0 ? dy_hd_shape.size() : axis);
+    int pre, n, post;
+    paddle::platform::GetMidDims(dy_hd_shape, y_shape, axis, &pre, &n, &post);
+    ngraph::Shape lhs_shape{};
+    lhs_shape.push_back(pre);
+    lhs_shape.push_back(n);
+    if (post != 1) {
+      lhs_shape.push_back(post);
+    }
+
+    std::vector<size_t> dy_order(dout_shape.size());
+    std::iota(std::begin(dy_order), std::end(dy_order), 0);
+    auto dy_hd_reshape = std::make_shared<ngraph::op::Reshape>(
+        dy_hd, ngraph::AxisVector(dy_order), lhs_shape);
+
+    ngraph::AxisSet axis_set{0};
+    if (post != 1) {
+      axis_set.insert(2);
+    }
+
+    auto dy_sum = std::make_shared<ngraph::op::Sum>(dy_hd_reshape, axis_set);
+    auto dy_sum_yshape = std::make_shared<ngraph::op::Reshape>(
+        dy_sum, ngraph::AxisVector{0}, y->get_shape());
+    auto dy_ = std::make_shared<ngraph::op::Divide>(dy_sum_yshape, -y);
+    paddle::platform::SetOutputNode(op, "Y@GRAD", dy_, ngb_node_map);
+
+    y_shape = y->get_shape();
+    std::vector<size_t> y_order(y_shape.size() == 0 ? 1 : y_shape.size());
+    std::iota(std::begin(y_order), std::end(y_order), 0);
+    auto y_reshape = std::make_shared<ngraph::op::Reshape>(
+        y, ngraph::AxisVector(y_order), ngraph::Shape{(size_t)n});
+    auto y_broadcast = std::make_shared<ngraph::op::Broadcast>(
+        y_reshape, lhs_shape, axis_set);
+    std::vector<size_t> lhs_order(lhs_shape.size());
+    std::iota(std::begin(lhs_order), std::end(lhs_order), 0);
+    auto y_broadcast_reshape = std::make_shared<ngraph::op::Reshape>(
+        y_broadcast, ngraph::AxisVector(lhs_order), dout_shape);
+    auto dx = std::make_shared<ngraph::op::Divide>(dout, y_broadcast_reshape);
+
+    paddle::platform::SetOutputNode(op, "X@GRAD", dx, ngb_node_map);
+  }
+}
+}  // namespace ngraphs
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_NG_OP(elementwise_div_grad, BuildElementwiseDivGradNode);
diff --git a/paddle/fluid/operators/ngraph/ops/elementwise_node.h b/paddle/fluid/operators/ngraph/ops/elementwise_node.h
index 7fd5e8390029a36b0c0e00df175f3950c04fbae4..2b10af4588c350e8581e304cdfdd075f56be53fd 100644
--- a/paddle/fluid/operators/ngraph/ops/elementwise_node.h
+++ b/paddle/fluid/operators/ngraph/ops/elementwise_node.h
@@ -61,6 +61,7 @@ void BuildElementwiseCompareNode(
   auto out = std::make_shared<T>(x, y);
   paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
 }
+
 }  // namespace ngraphs
 }  // namespace operators
 }  // namespace paddle
@@ -73,3 +74,4 @@ REGISTER_NG_OP(elementwise_sub,
 REGISTER_NG_OP(elementwise_min,
                BuildElementwiseBinaryNode<ngraph::op::Minimum>);
 REGISTER_NG_OP(less_than, BuildElementwiseCompareNode<ngraph::op::Less>);
+REGISTER_NG_OP(elementwise_div, BuildElementwiseBinaryNode<ngraph::op::Divide>);
diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_elementwise_div_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_elementwise_div_ngraph_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..55a2a05e23f56b6bc33c979ab46027212e506882
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ngraph/test_elementwise_div_ngraph_op.py
@@ -0,0 +1,22 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest, sys
+sys.path.append("../")
+from test_elementwise_div_op import ElementwiseDivOp, TestElementwiseDivOp_scalar, TestElementwiseDivOp_Vector, TestElementwiseDivOp_broadcast_0, TestElementwiseDivOp_broadcast_1, TestElementwiseDivOp_broadcast_2, TestElementwiseDivOp_broadcast_3
+
+if __name__ == '__main__':
+    unittest.main()
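Reviewer note: for `out = x / y`, the grad builder above constructs `dx = dout / y` and `dy = -sum(dout * out) / y`, with the sum taken over the axes along which `y` was broadcast (the `pre`/`n`/`post` reshape collapses those axes so a single `Sum` over axis 0, plus axis 2 when `post != 1`, performs the reduction). Below is a minimal NumPy sketch of that math; it uses NumPy's trailing-axis broadcasting for brevity (the operator itself follows Paddle's leading-`axis` convention), and all variable names in it are illustrative, not part of this PR.

```python
# Sketch of the elementwise_div gradient that the nGraph builder implements:
#   dx = dout / y
#   dy = -sum(dout * out, over broadcast axes) / y
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4))
y = rng.uniform(0.5, 2.0, (3, 4))      # kept away from zero
dout = rng.standard_normal((2, 3, 4))  # upstream gradient d(loss)/d(out)

out = x / y                            # y broadcasts over axis 0
dx = dout / y                          # same broadcast as the forward pass
dy = -(dout * out).sum(axis=0) / y     # reduce the broadcast axis, divide once

# Finite-difference check of dy at a single element.
eps = 1e-6
y_pert = y.copy()
y_pert[1, 2] += eps
dy_num = ((x / y_pert - out) * dout).sum() / eps
assert np.allclose(dy_num, dy[1, 2], atol=1e-3)
print("analytic:", dy[1, 2], "numeric:", dy_num)
```

The identity `sum(dout * out / y) = sum(dout * out) / y` holds because `y` is constant along the reduced axes, which is why the builder can sum `out * dout` first and divide by `-y` once.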