提交 b1bd483a 编写于 作者: K Krzysztof Binias 提交者: tensor-tang

[NGraph] Enable gelu operator for the nGraph Bridge. (#17547)

test=develop
上级 8bd651b7
......@@ -26,6 +26,52 @@ namespace paddle {
namespace operators {
namespace ngraphs {
// Emits the nGraph subgraph for the exact (erf-based) GELU activation:
//   Out = 0.5 * X * (1 + erf(X / sqrt(2)))
void BuildGeluNode(
    const std::shared_ptr<framework::OperatorBase>& op,
    std::shared_ptr<
        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
        ngb_node_map) {
  auto x = platform::GetInputNode(op, "X", ngb_node_map);
  const auto& elem_type = x->get_element_type();
  const auto& shape = x->get_shape();
  auto c_half = paddle::platform::CreateConstant(elem_type, shape, {0.5});
  auto c_one = paddle::platform::CreateConstant(elem_type, shape, {1});
  // sqrt(2) is built as a Sqrt node over a constant-2 tensor.
  auto root_two = std::make_shared<ngraph::op::Sqrt>(
      paddle::platform::CreateConstant(elem_type, shape, {2}));
  auto erf_term = std::make_shared<ngraph::op::Erf>(x / root_two);
  auto result = c_half * x * (c_one + erf_term);
  platform::SetOutputNode(op, "Out", result, ngb_node_map);
}
// Emits the nGraph subgraph for the GELU gradient.
// With cdf(x) = 0.5 * (1 + erf(x / sqrt(2))) and
// pdf(x) = exp(-x^2 / 2) / sqrt(2 * pi), the derivative of
// gelu(x) = x * cdf(x) is cdf(x) + x * pdf(x), so:
//   X@GRAD = Out@GRAD * (cdf(X) + X * pdf(X))
void BuildGeluGradNode(
    const std::shared_ptr<framework::OperatorBase>& op,
    std::shared_ptr<
        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
        ngb_node_map) {
  auto input = platform::GetInputNode(op, "X", ngb_node_map);
  auto dout = platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
  auto half = paddle::platform::CreateConstant(input->get_element_type(),
                                               input->get_shape(), {0.5});
  auto minus_half = paddle::platform::CreateConstant(
      input->get_element_type(), input->get_shape(), {-0.5});
  auto one = paddle::platform::CreateConstant(input->get_element_type(),
                                              input->get_shape(), {1});
  auto two = paddle::platform::CreateConstant(input->get_element_type(),
                                              input->get_shape(), {2});
  auto pi = paddle::platform::CreateConstant(
      input->get_element_type(), input->get_shape(), {3.14159265359});
  auto sqrt_two = std::make_shared<ngraph::op::Sqrt>(two);
  auto sqrt_pi = std::make_shared<ngraph::op::Sqrt>(pi);
  // cdf(X). The original built `input * one` here; that redundant multiply
  // added an extra graph node without changing the result, so it is dropped.
  auto first =
      half * (one + std::make_shared<ngraph::op::Erf>(input / sqrt_two));
  // X * pdf(X): half * (2/sqrt(pi)) * (1/sqrt(2)) == 1 / sqrt(2 * pi).
  auto second = half * (two / sqrt_pi) * (one / sqrt_two) * input *
                std::make_shared<ngraph::op::Exp>(minus_half * input * input);
  auto gelu_grad = dout * (first + second);
  platform::SetOutputNode(op, "X@GRAD", gelu_grad, ngb_node_map);
}
void BuildReluGradNode(
const std::shared_ptr<framework::OperatorBase>& op,
std::shared_ptr<
......@@ -64,6 +110,8 @@ void BuildTanhGradNode(
} // namespace operators
} // namespace paddle
// Register each builder with the nGraph bridge so the engine can translate
// the corresponding Paddle operator into the nGraph subgraph built above.
REGISTER_NG_OP(gelu, BuildGeluNode);
REGISTER_NG_OP(gelu_grad, BuildGeluGradNode);
REGISTER_NG_OP(relu_grad, BuildReluGradNode);
REGISTER_NG_OP(square, BuildSquareNode);
REGISTER_NG_OP(tanh_grad, BuildTanhGradNode);
......@@ -18,7 +18,7 @@ import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.test_activation_op import TestAbs, TestSigmoid, TestSquare, TestRelu, TestTanh
from paddle.fluid.tests.unittests.test_activation_op import TestAbs, TestGelu, TestSigmoid, TestSquare, TestRelu, TestTanh
class TestNGRAPHReluDim4(TestRelu):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册