未验证 提交 4d167240 编写于 作者: P pangyoki 提交者: GitHub

[NPU] delete useless GELU in gelu grad npu op (#33872)

* delete useless GELU in gelu npu op

* add description

* fix format

* add check_grad in gelu unittest
上级 e8052710
......@@ -61,13 +61,14 @@ class GeluGradNPUKernel : public framework::OpKernel<T> {
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
Tensor out(x->type());
out.mutable_data<T>(x->dims(), place);
const auto& runner_out = NpuOpRunner("Gelu", {*x}, {out}, {});
runner_out.Run(stream);
// NOTE(pangyoki): In the original implementation of GeluGrad op, the input
// is {*dout, *x, out}, where out = Gelu(x). However, we find that variable
// `out` was not actually used. In order to improve performance, the
// useless GELU operation was deleted.
// We directly use `*dout` as a placeholder to replace `out`, it will not
// be used in calculations.
const auto& runner_dx =
NpuOpRunner("GeluGrad", {*dout, *x, out}, {*dx}, {});
NpuOpRunner("GeluGrad", {*dout, *x, *dout}, {*dx}, {});
runner_dx.Run(stream);
}
};
......
......@@ -58,12 +58,9 @@ class TestGelu(OpTest):
def test_check_output(self):
self.check_output_with_place(self.place, check_dygraph=False, atol=1e-3)
# TODO(ascendrc): Add grad test
# def test_check_grad(self):
# if self.dtype == np.float16:
# return
# self.check_grad(['X'], 'Out')
#
def test_check_grad(self):
self.check_grad_with_place(
self.place, ['X'], 'Out', check_dygraph=False)
@unittest.skipIf(not paddle.is_compiled_with_npu(),
......@@ -115,10 +112,10 @@ class TestGeluNet(unittest.TestCase):
name="label", shape=[32, 1], dtype='int64')
c = paddle.multiply(a, b)
d = fluid.layers.gelu(c)
fc_1 = fluid.layers.fc(input=d, size=128)
prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
fc_1 = fluid.layers.fc(input=c, size=128)
fc_1_gelu = fluid.layers.gelu(fc_1)
prediction = fluid.layers.fc(input=fc_1_gelu, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.reduce_mean(cost)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册