From 6246be294f1f09a9356b1fbb4c7feb0b7f9f20f8 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sat, 21 Oct 2017 17:02:01 +0800 Subject: [PATCH] clean gradient data --- paddle/operators/sequence_project_op.cc | 2 ++ paddle/operators/sequence_project_op.h | 9 ++++++++- python/paddle/v2/framework/tests/test_seq_project.py | 6 +++--- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/paddle/operators/sequence_project_op.cc b/paddle/operators/sequence_project_op.cc index b1351e8ac53..8baae0f1d8b 100644 --- a/paddle/operators/sequence_project_op.cc +++ b/paddle/operators/sequence_project_op.cc @@ -71,6 +71,8 @@ class SequenceProjectGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Gradient of Out should not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Gradient of input(X@GRAD) should not be null."); if (ctx->Attrs().Get("padding_trainable")) { PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("PaddingData")), diff --git a/paddle/operators/sequence_project_op.h b/paddle/operators/sequence_project_op.h index 901939222e2..b31768b5583 100644 --- a/paddle/operators/sequence_project_op.h +++ b/paddle/operators/sequence_project_op.h @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/operators/math/im2col.h" +#include "paddle/operators/math/math_function.h" #include "paddle/operators/strided_memcpy.h" namespace paddle { @@ -177,6 +178,10 @@ class SequenceProjectGradKernel : public framework::OpKernel { auto* in_g = context.Output(framework::GradVarName("X")); auto* in = context.Input("X"); in_g->mutable_data(context.GetPlace()); + if (in_g) { + math::SetConstant functor; + functor(context.device_context(), in_g, 0); + } auto place = context.GetEigenDevice(); int context_start = context.Attr("context_start"); @@ -204,6 +209,8 @@ class SequenceProjectGradKernel : public framework::OpKernel { padding_width = padding_data_g->dims()[1]; PADDLE_ENFORCE(padding_width == input_width, "Input size and pooling size should be consistent."); + math::SetConstant functor; + functor(context.device_context(), padding_data_g, 0); } int up_pad = std::max(0, -context_start); @@ -282,7 +289,7 @@ class SequenceProjectGradKernel : public framework::OpKernel { } } - if (in && input_row_begin < input_row_end) { + if (in_g && input_row_begin < input_row_end) { Tensor in_t = in_g->Slice(input_row_begin, input_row_end); std::vector output_shape( diff --git a/python/paddle/v2/framework/tests/test_seq_project.py b/python/paddle/v2/framework/tests/test_seq_project.py index e97a143c469..c783aff5162 100644 --- a/python/paddle/v2/framework/tests/test_seq_project.py +++ b/python/paddle/v2/framework/tests/test_seq_project.py @@ -87,9 +87,9 @@ class TestSeqProject(OpTest): def test_check_output(self): self.check_output() - # def test_check_grad(self): - # self.check_grad( - # set(['X', 'PaddingData']), 'Out', max_relative_error=0.05) + def test_check_grad(self): + self.check_grad( + set(['X', 'PaddingData']), 'Out', max_relative_error=0.05) # def test_check_grad_no_filter(self): # self.check_grad( -- GitLab