From 6246be294f1f09a9356b1fbb4c7feb0b7f9f20f8 Mon Sep 17 00:00:00 2001
From: chengduoZH <zhaochengduo@163.com>
Date: Sat, 21 Oct 2017 17:02:01 +0800
Subject: [PATCH] clean gradient data

---
 paddle/operators/sequence_project_op.cc              | 2 ++
 paddle/operators/sequence_project_op.h               | 9 ++++++++-
 python/paddle/v2/framework/tests/test_seq_project.py | 6 +++---
 3 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/paddle/operators/sequence_project_op.cc b/paddle/operators/sequence_project_op.cc
index b1351e8ac53..8baae0f1d8b 100644
--- a/paddle/operators/sequence_project_op.cc
+++ b/paddle/operators/sequence_project_op.cc
@@ -71,6 +71,8 @@ class SequenceProjectGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                    "Gradient of Out should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+                   "Gradient of input(X@GRAD) should not be null.");
 
     if (ctx->Attrs().Get<bool>("padding_trainable")) {
       PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("PaddingData")),
diff --git a/paddle/operators/sequence_project_op.h b/paddle/operators/sequence_project_op.h
index 901939222e2..b31768b5583 100644
--- a/paddle/operators/sequence_project_op.h
+++ b/paddle/operators/sequence_project_op.h
@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/math/im2col.h"
+#include "paddle/operators/math/math_function.h"
 #include "paddle/operators/strided_memcpy.h"
 
 namespace paddle {
@@ -177,6 +178,10 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
     auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
     auto* in = context.Input<LoDTensor>("X");
     in_g->mutable_data<T>(context.GetPlace());
+    if (in_g) {
+      math::SetConstant<Place, T> functor;
+      functor(context.device_context(), in_g, 0);
+    }
     auto place = context.GetEigenDevice<Place>();
 
     int context_start = context.Attr<int>("context_start");
@@ -204,6 +209,8 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
       padding_width = padding_data_g->dims()[1];
       PADDLE_ENFORCE(padding_width == input_width,
                      "Input size and pooling size should be consistent.");
+      math::SetConstant<Place, T> functor;
+      functor(context.device_context(), padding_data_g, 0);
     }
 
     int up_pad = std::max(0, -context_start);
@@ -282,7 +289,7 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
         }
       }
 
-      if (in && input_row_begin < input_row_end) {
+      if (in_g && input_row_begin < input_row_end) {
         Tensor in_t = in_g->Slice(input_row_begin, input_row_end);
 
         std::vector<int64_t> output_shape(
diff --git a/python/paddle/v2/framework/tests/test_seq_project.py b/python/paddle/v2/framework/tests/test_seq_project.py
index e97a143c469..c783aff5162 100644
--- a/python/paddle/v2/framework/tests/test_seq_project.py
+++ b/python/paddle/v2/framework/tests/test_seq_project.py
@@ -87,9 +87,9 @@ class TestSeqProject(OpTest):
     def test_check_output(self):
         self.check_output()
 
-        # def test_check_grad(self):
-        #     self.check_grad(
-        #         set(['X', 'PaddingData']), 'Out', max_relative_error=0.05)
+    def test_check_grad(self):
+        self.check_grad(
+            set(['X', 'PaddingData']), 'Out', max_relative_error=0.05)
 
         # def test_check_grad_no_filter(self):
         #     self.check_grad(
-- 
GitLab