提交 82992033 编写于 作者: D Double_V 提交者: lvmengsi

Support reshape_op double gradient (#20304)

* support reshape doubel grad, test=develop

* fix reshape double grad, pass converage, test=develop

* fix review, test=develop
上级 faa8e30a
...@@ -363,6 +363,20 @@ class ReshapeGradKernel { ...@@ -363,6 +363,20 @@ class ReshapeGradKernel {
} }
}; };
class ReshapeDoubleGradKernel {
public:
void operator()(const framework::ExecutionContext &ctx) const {
auto *dd_x = ctx.Input<framework::Tensor>("DDX");
auto *dd_out = ctx.Output<framework::Tensor>("DDOut");
auto out_dims = dd_out->dims();
dd_out->mutable_data(ctx.GetPlace(), dd_x->type());
framework::TensorCopySync(*dd_x, ctx.GetPlace(), dd_out);
dd_out->Resize(out_dims);
}
};
// FIXME(zcd): reshape2 adds an intermediate output(XShape) based on reshape, // FIXME(zcd): reshape2 adds an intermediate output(XShape) based on reshape,
// the XShape is used to carry the shape and lod of X which will be used in // the XShape is used to carry the shape and lod of X which will be used in
// reshape_grad, in this way, the framework can reuse the memory of X // reshape_grad, in this way, the framework can reuse the memory of X
...@@ -409,6 +423,7 @@ class Reshape2GradMaker : public framework::SingleGradOpDescMaker { ...@@ -409,6 +423,7 @@ class Reshape2GradMaker : public framework::SingleGradOpDescMaker {
std::unique_ptr<framework::OpDesc> Apply() const override { std::unique_ptr<framework::OpDesc> Apply() const override {
auto *grad_op = new framework::OpDesc(); auto *grad_op = new framework::OpDesc();
grad_op->SetType("reshape2_grad"); grad_op->SetType("reshape2_grad");
grad_op->SetInput("X", Input("X"));
grad_op->SetInput("XShape", Output("XShape")); grad_op->SetInput("XShape", Output("XShape"));
grad_op->SetInput("ShapeTensor", Input("ShapeTensor")); grad_op->SetInput("ShapeTensor", Input("ShapeTensor"));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
...@@ -418,6 +433,27 @@ class Reshape2GradMaker : public framework::SingleGradOpDescMaker { ...@@ -418,6 +433,27 @@ class Reshape2GradMaker : public framework::SingleGradOpDescMaker {
} }
}; };
class Reshape2DoubleGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto *grad_op = new framework::OpDesc();
grad_op->SetType("reshape2_grad_grad");
grad_op->SetInput("X", Input("X"));
grad_op->SetInput("ShapeTensor", Input("ShapeTensor"));
grad_op->SetInput("DOut", Input(framework::GradVarName("Out")));
grad_op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
auto ddx = OutputGrad(framework::GradVarName("X"));
grad_op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(grad_op);
}
};
class Reshape2GradOp : public framework::OperatorWithKernel { class Reshape2GradOp : public framework::OperatorWithKernel {
public: public:
Reshape2GradOp(const std::string &type, Reshape2GradOp(const std::string &type,
...@@ -456,10 +492,47 @@ class Reshape2GradOp : public framework::OperatorWithKernel { ...@@ -456,10 +492,47 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
} }
}; };
class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
public:
Reshape2DoubleGradOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) shouldn't be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("DDX"), true,
"Input(X@GRAD_GRAD) shouldn't be null.");
if (ctx->HasOutput("DDOut") && ctx->HasInput("DDX")) {
ctx->ShareDim("DOut", "DDOut");
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("DDX")->type(),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "ShapeTensor") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInToOut, {"X", "Out"}); DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInToOut, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInToOut, DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInToOut,
{framework::GradVarName("Out"), {framework::GradVarName("Out"),
framework::GradVarName("X")}); framework::GradVarName("X")});
DECLARE_INPLACE_OP_INFERER(ReshapeDoubleGradInplaceInToOut, {"DDX", "DDOut"});
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -471,6 +544,7 @@ REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, ...@@ -471,6 +544,7 @@ REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
ops::ReshapeOpInplaceInToOut); ops::ReshapeOpInplaceInToOut);
REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp, REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp,
ops::ReshapeGradInplaceInToOut); ops::ReshapeGradInplaceInToOut);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double, REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel, ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel); int64_t, ops::ReshapeKernel);
...@@ -478,11 +552,13 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel, ...@@ -478,11 +552,13 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int, double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel); ops::ReshapeGradKernel);
REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker, REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
ops::Reshape2GradMaker, ops::ReshapeOpInplaceInToOut); ops::Reshape2GradMaker, ops::ReshapeOpInplaceInToOut);
REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp, REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
ops::ReshapeGradInplaceInToOut); ops::Reshape2DoubleGradMaker, ops::ReshapeGradInplaceInToOut);
REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
ops::ReshapeDoubleGradInplaceInToOut);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double, REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel, ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel); int64_t, ops::ReshapeKernel);
...@@ -490,6 +566,11 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel, ...@@ -490,6 +566,11 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int, double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel); ops::ReshapeGradKernel);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad_grad, float,
ops::ReshapeDoubleGradKernel, double,
ops::ReshapeDoubleGradKernel, int,
ops::ReshapeDoubleGradKernel, int64_t,
ops::ReshapeDoubleGradKernel);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double, REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
...@@ -510,4 +591,11 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel, ...@@ -510,4 +591,11 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel, plat::float16, ops::ReshapeGradKernel, plat::float16,
ops::ReshapeGradKernel); ops::ReshapeGradKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad_grad, float,
ops::ReshapeDoubleGradKernel, double,
ops::ReshapeDoubleGradKernel, int,
ops::ReshapeDoubleGradKernel, int64_t,
ops::ReshapeDoubleGradKernel, plat::float16,
ops::ReshapeDoubleGradKernel);
#endif #endif
...@@ -94,5 +94,29 @@ class TestMulDoubleGradCheck(unittest.TestCase): ...@@ -94,5 +94,29 @@ class TestMulDoubleGradCheck(unittest.TestCase):
self.func(p) self.func(p)
class TestReshapeDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
x_shape = [3, 12]
new_shape = [4, 9]
eps = 0.005
dtype = np.float64
x = layers.data('x', x_shape, False, dtype)
x.persistable = True
out = layers.reshape(x, new_shape)
x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype)
gradient_checker.double_grad_check(
[x], out, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册