diff --git a/paddle/operators/sequence_pool_op.cc b/paddle/operators/sequence_pool_op.cc index 9b8d86b40424c5a63940c82454bbfcdf59d4ed0a..8dc4a59ba861237680b99c0d686e3dcb21b8071a 100644 --- a/paddle/operators/sequence_pool_op.cc +++ b/paddle/operators/sequence_pool_op.cc @@ -36,9 +36,10 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker { SequencePoolOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "A LoDTensor, the variable-length input of SequencePoolOp"); + AddInput("X", "(LoDTensor), the variable-length input of SequencePoolOp"); AddOutput("Out", - "A LoDTensor, the variable-length output of SequencePoolOp."); + "(Tensor), output of SequencePoolOp, which does not contain LoD " + "information."); AddAttr( "strategy", "(int, default AVERAGE) the pooling strategy of SequencePoolOp.") @@ -53,7 +54,7 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker { Besides, for the sake of simplicity, we assume M=1 and N=1, and the value of X = [[1, 3], [2, 4, 6], [5, 1]]. - Thus, Out is a [3,1,1] LoDTensor, but Out->lod() is nullptr. + Thus, Out is a [3,1,1] Tensor without LoD information. 
And for different strategy, the value of Out is as follows: - AVERAGE: [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2 diff --git a/paddle/operators/sequence_pool_op.h b/paddle/operators/sequence_pool_op.h index 8bfb80c33fec49b07ff97e92ac3aed8835d098a8..ce68204d41607ed6618a3d61a77cab391960083a 100644 --- a/paddle/operators/sequence_pool_op.h +++ b/paddle/operators/sequence_pool_op.h @@ -109,8 +109,8 @@ class SequencePoolGradKernel : public framework::OpKernel { int64_t w = in->numel() / dims[0]; in_g->mutable_data(context.GetPlace()); - if (strategy > 2) { - // set X@Grad be zero at first when strategy is LAST/FIRST/MAX + if (strategy == LAST || strategy == FIRST) { + // set X@Grad to zero first when strategy is LAST/FIRST math::SetConstant(context.device_context(), in_g, 0); } auto place = context.GetEigenDevice();