diff --git a/paddle/operators/sequence_project_op.cc b/paddle/operators/sequence_project_op.cc
index b1351e8ac5311b2da98c2929202c99afd7fdbe13..8baae0f1d8b3ba12eeba02aa00049a92dd443da2 100644
--- a/paddle/operators/sequence_project_op.cc
+++ b/paddle/operators/sequence_project_op.cc
@@ -71,6 +71,8 @@ class SequenceProjectGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                    "Gradient of Out should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+                   "Gradient of input(X@GRAD) should not be null.");

     if (ctx->Attrs().Get<bool>("padding_trainable")) {
       PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("PaddingData")),
diff --git a/paddle/operators/sequence_project_op.h b/paddle/operators/sequence_project_op.h
index 901939222e2383467ffe2557362dc76b8881692d..b31768b55833de8e90ff36739b9aee6c17398bbf 100644
--- a/paddle/operators/sequence_project_op.h
+++ b/paddle/operators/sequence_project_op.h
@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/math/im2col.h"
+#include "paddle/operators/math/math_function.h"
 #include "paddle/operators/strided_memcpy.h"

 namespace paddle {
@@ -177,6 +178,10 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
     auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
     auto* in = context.Input<LoDTensor>("X");
     in_g->mutable_data<T>(context.GetPlace());
+    if (in_g) {
+      math::SetConstant<Place, T> functor;
+      functor(context.device_context(), in_g, 0);
+    }

     auto place = context.GetEigenDevice<Place>();
     int context_start = context.Attr<int>("context_start");
@@ -204,6 +209,8 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
       padding_width = padding_data_g->dims()[1];
       PADDLE_ENFORCE(padding_width == input_width,
                      "Input size and pooling size should be consistent.");
+      math::SetConstant<Place, T> functor;
+      functor(context.device_context(), padding_data_g, 0);
     }

     int up_pad = std::max(0, -context_start);
@@ -282,7 +289,7 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
         }
       }

-      if (in && input_row_begin < input_row_end) {
+      if (in_g && input_row_begin < input_row_end) {
         Tensor in_t = in_g->Slice(input_row_begin, input_row_end);

         std::vector<int64_t> output_shape(
diff --git a/python/paddle/v2/framework/tests/test_seq_project.py b/python/paddle/v2/framework/tests/test_seq_project.py
index e97a143c4690b27f8995ad5bd1afead0faf5407b..c783aff51628c21b4d97f80848b6f76f4c0e1664 100644
--- a/python/paddle/v2/framework/tests/test_seq_project.py
+++ b/python/paddle/v2/framework/tests/test_seq_project.py
@@ -87,9 +87,9 @@ class TestSeqProject(OpTest):
     def test_check_output(self):
         self.check_output()

-    # def test_check_grad(self):
-    #     self.check_grad(
-    #         set(['X', 'PaddingData']), 'Out', max_relative_error=0.05)
+    def test_check_grad(self):
+        self.check_grad(
+            set(['X', 'PaddingData']), 'Out', max_relative_error=0.05)

     # def test_check_grad_no_filter(self):
     #     self.check_grad(
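
The core of the patch is that the backward kernel now zeroes its gradient buffers (`in_g`, and `padding_data_g` when the padding is trainable) via `math::SetConstant` before the col2im step accumulates context-window contributions into them, and the write-back guard tests `in_g` (the tensor being written) rather than `in`. The sketch below is not Paddle code; it is a minimal, self-contained C++ illustration of why a buffer that receives `+=` updates from overlapping windows must start from zero. Names such as `context_length` and `context_start` merely mirror the op's attributes and the numbers are made up for the example.

```cpp
// Standalone sketch (not the Paddle API): each output row of a sequence
// projection reads a window of input rows, so several windows write back into
// the same input row during backprop and their contributions must be summed.
// Starting from a stale, uninitialized buffer would silently corrupt the
// gradient -- the role the added SetConstant(..., 0) calls play in the patch.
#include <cstdio>
#include <vector>

int main() {
  const int seq_len = 5, context_length = 3, context_start = -1;
  // Pretend every element of every window contributed a gradient of 1, so
  // in_grad[j] ends up counting how many windows cover input row j.
  std::vector<float> out_grad(seq_len * context_length, 1.0f);
  std::vector<float> in_grad(seq_len, 0.0f);  // the "set to zero" step

  for (int i = 0; i < seq_len; ++i) {            // one window per output row
    for (int k = 0; k < context_length; ++k) {   // rows inside the window
      int j = i + context_start + k;
      if (j < 0 || j >= seq_len) continue;       // padding region, skipped here
      in_grad[j] += out_grad[i * context_length + k];  // accumulate, not assign
    }
  }

  for (int j = 0; j < seq_len; ++j)
    std::printf("in_grad[%d] = %g\n", j, in_grad[j]);
  return 0;
}
```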