提交 43d15b9d 编写于 作者: K Krzysztof Binias 提交者: Tao Luo

Enable square operator for the nGraph Bridge. (#17551)

test=develop
上级 ff5fdc0b
...@@ -37,6 +37,16 @@ void BuildReluGradNode( ...@@ -37,6 +37,16 @@ void BuildReluGradNode(
platform::SetOutputNode(op, "X@GRAD", relu_grad, ngb_node_map); platform::SetOutputNode(op, "X@GRAD", relu_grad, ngb_node_map);
} }
void BuildSquareNode(
const std::shared_ptr<framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto input = platform::GetInputNode(op, "X", ngb_node_map);
auto out = input * input;
platform::SetOutputNode(op, "Out", out, ngb_node_map);
}
void BuildTanhGradNode( void BuildTanhGradNode(
const std::shared_ptr<framework::OperatorBase>& op, const std::shared_ptr<framework::OperatorBase>& op,
std::shared_ptr< std::shared_ptr<
...@@ -55,4 +65,5 @@ void BuildTanhGradNode( ...@@ -55,4 +65,5 @@ void BuildTanhGradNode(
} // namespace paddle } // namespace paddle
REGISTER_NG_OP(relu_grad, BuildReluGradNode); REGISTER_NG_OP(relu_grad, BuildReluGradNode);
REGISTER_NG_OP(square, BuildSquareNode);
REGISTER_NG_OP(tanh_grad, BuildTanhGradNode); REGISTER_NG_OP(tanh_grad, BuildTanhGradNode);
...@@ -18,7 +18,7 @@ import unittest ...@@ -18,7 +18,7 @@ import unittest
import numpy as np import numpy as np
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.test_activation_op import TestAbs, TestSigmoid, TestRelu, TestTanh from paddle.fluid.tests.unittests.test_activation_op import TestAbs, TestSigmoid, TestSquare, TestRelu, TestTanh
class TestNGRAPHReluDim4(TestRelu): class TestNGRAPHReluDim4(TestRelu):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册