Unverified commit badc6f22, authored by lijianshe02, committed by GitHub

add transpose double grad, cherry-pick from #29600 (#30435)

* add transpose double grad test=develop (#29600)

* add transpose double grad test=develop

* cherry-pick test=develop
Parent a64c7c91
@@ -272,6 +272,20 @@ class Transpose2GradMaker : public framework::SingleGradOpMaker<T> {
  }
};
template <typename T>
class Transpose2DoubleGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("transpose2");
    grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
    grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
    grad_op->SetOutput("XShape", this->Input("XShape"));
    grad_op->SetAttrMap(this->Attrs());
  }
};
class Transpose2OpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
@@ -338,7 +352,9 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker,
                  ops::Transpose2GradMaker<paddle::framework::OpDesc>,
                  ops::Transpose2GradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad,
                  ops::Transpose2DoubleGradMaker<paddle::framework::OpDesc>,
                  ops::Transpose2DoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
    transpose2, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
...
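
Side note on the C++ change above: transpose is a linear op, so its first-order grad is a transpose with the inverted perm, and the grad of that grad is a transpose with the original perm. That is why Transpose2DoubleGradMaker simply re-emits a transpose2 op with the same attribute map. A minimal numpy sketch of the identity (illustrative only, not part of the patch):

import numpy as np

# Illustrative check (not part of the patch): for out = transpose(x, perm),
# the first-order grad is dX = transpose(dOut, inv_perm), which is linear in
# dOut, so the double grad is ddOut = transpose(ddX, perm) -- exactly the
# transpose2 op that Transpose2DoubleGradMaker emits.
perm = [1, 0]
inv_perm = list(np.argsort(perm))

dd_x = np.random.uniform(-1, 1, (3, 40))   # perturbation of the input grad dX
dd_out = np.transpose(dd_x, perm)          # what the double-grad op computes

# The two transposes are exact inverses, so pushing dd_out back through the
# grad map's inverted perm recovers dd_x.
np.testing.assert_allclose(np.transpose(dd_out, inv_perm), dd_x)
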
@@ -329,5 +329,49 @@ class TestUnsqueezeDoubleGradCheck(unittest.TestCase):
            self.func(p)
class TestTransposeDoubleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):
        x_shape = [3, 40]
        perm = [1, 0]
        dtype = np.float64

        x = layers.data('x', x_shape, False, dtype)
        x.persistable = True
        out = paddle.transpose(x, perm)
        x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype)

        gradient_checker.double_grad_check([x], out, x_init=x_arr, place=place)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)
class TestTransposeDoubleGradCheckCase1(unittest.TestCase):
    @prog_scope()
    def func(self, place):
        x_shape = [2, 3, 4, 5]
        perm = [0, 2, 3, 1]
        dtype = np.float64

        x = layers.data('x', x_shape, False, dtype)
        x.persistable = True
        out = paddle.transpose(x, perm)
        x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype)

        gradient_checker.double_grad_check([x], out, x_init=x_arr, place=place)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)
if __name__ == "__main__":
    unittest.main()
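
For reference, what the tests above exercise: gradient_checker.double_grad_check compares the analytic second-order gradient, produced through the newly registered double-grad op, against a numeric estimate. Assuming finite-difference semantics along those lines, this standalone numpy sketch shows why the check is tight for a linear op like transpose:

import numpy as np

# Standalone sketch (assumed semantics, not gradient_checker's actual code):
# numerically differentiate t -> transpose(x + t * ddx, perm) at t = 0 by a
# central difference and compare it with the analytic second-order gradient,
# transpose(ddx, perm).
def numeric_dd_out(x, perm, ddx, eps=1e-6):
    f = lambda t: np.transpose(x + t * ddx, perm)
    return (f(eps) - f(-eps)) / (2.0 * eps)

x = np.random.uniform(-1, 1, (2, 3, 4, 5))
ddx = np.random.uniform(-1, 1, x.shape)
perm = [0, 2, 3, 1]

np.testing.assert_allclose(
    numeric_dd_out(x, perm, ddx), np.transpose(ddx, perm),
    rtol=1e-6, atol=1e-8)
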