From 6a4282a20f1f9c110ea5aef5035a0b733da6db19 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Mon, 16 Oct 2017 20:02:04 +0800 Subject: [PATCH] refine comments of sequence_pool_op --- paddle/operators/sequence_pool_op.cc | 7 ++++--- paddle/operators/sequence_pool_op.h | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/paddle/operators/sequence_pool_op.cc b/paddle/operators/sequence_pool_op.cc index 9b8d86b4042..8dc4a59ba86 100644 --- a/paddle/operators/sequence_pool_op.cc +++ b/paddle/operators/sequence_pool_op.cc @@ -36,9 +36,10 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker { SequencePoolOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "A LoDTensor, the variable-length input of SequencePoolOp"); + AddInput("X", "(LoDTensor), the variable-length input of SequencePoolOp"); AddOutput("Out", - "A LoDTensor, the variable-length output of SequencePoolOp."); + "(Tensor), output of SequencePoolOp, which does not contain LoD " + "infomation."); AddAttr( "strategy", "(int, default AVERAGE) the pooling strategy of SequencePoolOp.") @@ -53,7 +54,7 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker { Besides, for the sake of simplicity, we assume M=1 and N=1, and the value of X = [[1, 3], [2, 4, 6], [5, 1]]. - Thus, Out is a [3,1,1] LoDTensor, but Out->lod() is nullptr. + Thus, Out is a [3,1,1] Tensor without LoD infomation. And for different strategy, the value of Out is as follows: - AVERAGE: [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2 diff --git a/paddle/operators/sequence_pool_op.h b/paddle/operators/sequence_pool_op.h index 8bfb80c33fe..ce68204d416 100644 --- a/paddle/operators/sequence_pool_op.h +++ b/paddle/operators/sequence_pool_op.h @@ -109,8 +109,8 @@ class SequencePoolGradKernel : public framework::OpKernel { int64_t w = in->numel() / dims[0]; in_g->mutable_data(context.GetPlace()); - if (strategy > 2) { - // set X@Grad be zero at first when strategy is LAST/FIRST/MAX + if (strategy == LAST || strategy == FIRST) { + // set X@Grad be zero at first when strategy is LAST/FIRST math::SetConstant(context.device_context(), in_g, 0); } auto place = context.GetEigenDevice(); -- GitLab