Unverified commit c5f957ae, authored by lilong12, committed by GitHub

add double grad for tile op and expand_v2 op (#27114)

* add double grad for tile, test=develop

* add double grad for expand_v2 op, test=develop
Parent 58a88ba9
@@ -230,6 +230,26 @@ class ExpandV2GradOpMaker : public framework::SingleGradOpMaker<T> {
  }
};

template <typename T>
class ExpandV2DoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("expand_v2");
    op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
    op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
    if (this->HasInput("expand_shapes_tensor")) {
      op->SetInput("expand_shapes_tensor", this->Input("expand_shapes_tensor"));
    }
    if (this->HasInput("Shape")) {
      op->SetInput("Shape", this->Input("Shape"));
    }
    op->SetAttrMap(this->Attrs());
  }
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ExpandV2GradNoNeedBufVarsInferer, "X");
}  // namespace operators
@@ -240,6 +260,8 @@ REGISTER_OPERATOR(expand_v2, ops::ExpandV2Op, ops::ExpandV2OpMaker,
                  ops::ExpandV2GradOpMaker<paddle::framework::OpDesc>,
                  ops::ExpandV2GradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(expand_v2_grad, ops::ExpandV2GradOp,
                  ops::ExpandV2DoubleGradOpMaker<paddle::framework::OpDesc>,
                  ops::ExpandV2DoubleGradOpMaker<paddle::imperative::OpBase>,
                  ops::ExpandV2GradNoNeedBufVarsInferer);
REGISTER_OP_CPU_KERNEL(
    expand_v2, ops::ExpandV2Kernel<paddle::platform::CPUDeviceContext, float>,
...
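Registering ExpandV2DoubleGradOpMaker on expand_v2_grad makes the first-order gradient of expand_v2 differentiable in its own right: expand only copies data, so the second-order gradient is expand_v2 re-applied to the incoming grad-of-grad (ddOut = expand(ddX)). A minimal imperative-mode sketch of what this enables, assuming a Paddle build that includes this patch; the tensor values, shapes, and the printed result are illustrative only:

import numpy as np
import paddle

paddle.disable_static()  # imperative (dygraph) mode

x = paddle.to_tensor(np.random.rand(1, 12), stop_gradient=False)
out = paddle.expand(x, [4, 12])

# Inject an upstream gradient that we can differentiate through.
dout = paddle.ones_like(out)
dout.stop_gradient = False

# First order: dx = expand_v2_grad(dout) reduces the broadcast dimension.
dx = paddle.grad(out, x, grad_outputs=dout, create_graph=True)[0]

# Second order: differentiating dx w.r.t. dout goes through the new
# ExpandV2DoubleGradOpMaker, which re-applies expand_v2 to the seed gradient.
ddout = paddle.grad(dx, dout, grad_outputs=paddle.ones_like(dx))[0]
print(ddout.shape)  # expected: [4, 12]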
@@ -241,6 +241,26 @@ class TileGradOpMaker : public framework::SingleGradOpMaker<T> {
  }
};

template <typename T>
class TileDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("tile");
    op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
    op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
    if (this->HasInput("repeat_times_tensor")) {
      op->SetInput("repeat_times_tensor", this->Input("repeat_times_tensor"));
    }
    if (this->HasInput("RepeatTimes")) {
      op->SetInput("RepeatTimes", this->Input("RepeatTimes"));
    }
    op->SetAttrMap(this->Attrs());
  }
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(TileGradNoNeedBufVarsInferer, "X");
}  // namespace operators
@@ -251,6 +271,8 @@ REGISTER_OPERATOR(tile, ops::TileOp, ops::TileOpMaker,
                  ops::TileGradOpMaker<paddle::framework::OpDesc>,
                  ops::TileGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(tile_grad, ops::TileGradOp,
                  ops::TileDoubleGradOpMaker<paddle::framework::OpDesc>,
                  ops::TileDoubleGradOpMaker<paddle::imperative::OpBase>,
                  ops::TileGradNoNeedBufVarsInferer);
REGISTER_OP_CPU_KERNEL(
    tile, ops::TileKernel<paddle::platform::CPUDeviceContext, float>,
...
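TileDoubleGradOpMaker follows the same pattern: tile is linear in X, so tile_grad (which sums over the repeated blocks) has a constant Jacobian, and its gradient with respect to the incoming ddX is just another tile. A self-contained numpy sketch of the adjoint identity this relies on, with shapes chosen only for illustration and mirroring the unit test below:

import numpy as np

x_shape, repeat_times = (3, 12), (4, 9)

def tile_grad(dout):
    # Backward of tile: fold the output into (r0, h, r1, w) blocks and sum
    # over the repeat axes -- the adjoint of np.tile.
    r0, r1 = repeat_times
    h, w = x_shape
    return dout.reshape(r0, h, r1, w).sum(axis=(0, 2))

# The double-grad op computes ddOut = tile(ddX) ...
ddx = np.random.rand(*x_shape)
ddout = np.tile(ddx, repeat_times)

# ... which is correct because tile and tile_grad are adjoint linear maps:
# <tile(ddx), dout> == <ddx, tile_grad(dout)> for every dout.
dout = np.random.rand(x_shape[0] * repeat_times[0], x_shape[1] * repeat_times[1])
assert np.allclose((ddout * dout).sum(), (ddx * tile_grad(dout)).sum())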
@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
@@ -200,5 +201,53 @@ class TestExpandDoubleGradCheck(unittest.TestCase):
            self.func(p)


class TestTileDoubleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):
        x_shape = [3, 12]
        repeat_times = [4, 9]
        eps = 0.005
        dtype = np.float64

        x = layers.data('x', x_shape, False, dtype)
        x.persistable = True
        out = paddle.tile(x, repeat_times)
        x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype)

        gradient_checker.double_grad_check(
            [x], out, x_init=x_arr, place=place, eps=eps)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)


class TestExpandV2DoubleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):
        x_shape = [1, 12]
        new_shape = [4, 12]
        eps = 0.005
        dtype = np.float64

        x = layers.data('x', x_shape, False, dtype)
        x.persistable = True
        out = paddle.expand(x, new_shape)
        x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype)

        gradient_checker.double_grad_check(
            [x], out, x_init=x_arr, place=place, eps=eps)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)


if __name__ == "__main__":
    unittest.main()
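Both new cases rely on gradient_checker.double_grad_check from Paddle's unit-test utilities, which injects a gradient for the output, builds the second-order gradient graph, and compares it against central finite differences with step eps. A rough, self-contained numpy illustration of what that check amounts to for the expand case above (a sketch of the idea only, not the checker's implementation):

import numpy as np

new_shape, eps = (4, 12), 0.005

def dx_of(dout):
    # dx = expand_v2_grad(dout): sum the expanded rows back to x's shape (1, 12).
    return dout.sum(axis=0, keepdims=True)

# Analytic second-order gradient in a direction v: the double-grad op says
# d<dx, v>/d(dout) = expand(v) to the output shape.
v = np.random.rand(1, 12)
analytic = np.broadcast_to(v, new_shape)

# Numeric counterpart: central finite differences on <dx(dout), v>.
dout = np.random.rand(*new_shape)
numeric = np.zeros_like(dout)
for idx in np.ndindex(dout.shape):
    step = np.zeros_like(dout)
    step[idx] = eps
    numeric[idx] = ((dx_of(dout + step) * v).sum() -
                    (dx_of(dout - step) * v).sum()) / (2 * eps)

assert np.allclose(analytic, numeric)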