From 43d15b9d96530ec60a631fe252510d6fb90190ab Mon Sep 17 00:00:00 2001 From: Krzysztof Binias Date: Wed, 22 May 2019 09:14:07 +0200 Subject: [PATCH] Enable square operator for the nGraph Bridge. (#17551) test=develop --- paddle/fluid/operators/ngraph/ops/activation_op.h | 11 +++++++++++ .../unittests/ngraph/test_activation_ngraph_op.py | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/ngraph/ops/activation_op.h b/paddle/fluid/operators/ngraph/ops/activation_op.h index a66ec65a336..ef6c11bce70 100644 --- a/paddle/fluid/operators/ngraph/ops/activation_op.h +++ b/paddle/fluid/operators/ngraph/ops/activation_op.h @@ -37,6 +37,16 @@ void BuildReluGradNode( platform::SetOutputNode(op, "X@GRAD", relu_grad, ngb_node_map); } +void BuildSquareNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto input = platform::GetInputNode(op, "X", ngb_node_map); + auto out = input * input; + platform::SetOutputNode(op, "Out", out, ngb_node_map); +} + void BuildTanhGradNode( const std::shared_ptr& op, std::shared_ptr< @@ -55,4 +65,5 @@ void BuildTanhGradNode( } // namespace paddle REGISTER_NG_OP(relu_grad, BuildReluGradNode); +REGISTER_NG_OP(square, BuildSquareNode); REGISTER_NG_OP(tanh_grad, BuildTanhGradNode); diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_activation_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_activation_ngraph_op.py index c7d62bd8ae1..3c1db3bf640 100644 --- a/python/paddle/fluid/tests/unittests/ngraph/test_activation_ngraph_op.py +++ b/python/paddle/fluid/tests/unittests/ngraph/test_activation_ngraph_op.py @@ -18,7 +18,7 @@ import unittest import numpy as np import paddle.fluid.core as core from paddle.fluid.tests.unittests.op_test import OpTest -from paddle.fluid.tests.unittests.test_activation_op import TestAbs, TestSigmoid, TestRelu, TestTanh +from paddle.fluid.tests.unittests.test_activation_op import TestAbs, TestSigmoid, TestSquare, TestRelu, TestTanh class TestNGRAPHReluDim4(TestRelu): -- GitLab