提交 6a4282a2 编写于 作者: L Luo Tao

refine comments of sequence_pool_op

上级 216b81ac
...@@ -36,9 +36,10 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -36,9 +36,10 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
SequencePoolOpMaker(framework::OpProto* proto, SequencePoolOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "A LoDTensor, the variable-length input of SequencePoolOp"); AddInput("X", "(LoDTensor), the variable-length input of SequencePoolOp");
AddOutput("Out", AddOutput("Out",
"A LoDTensor, the variable-length output of SequencePoolOp."); "(Tensor), output of SequencePoolOp, which does not contain LoD "
"infomation.");
AddAttr<int>( AddAttr<int>(
"strategy", "strategy",
"(int, default AVERAGE) the pooling strategy of SequencePoolOp.") "(int, default AVERAGE) the pooling strategy of SequencePoolOp.")
...@@ -53,7 +54,7 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -53,7 +54,7 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
Besides, for the sake of simplicity, we assume M=1 and N=1, Besides, for the sake of simplicity, we assume M=1 and N=1,
and the value of X = [[1, 3], [2, 4, 6], [5, 1]]. and the value of X = [[1, 3], [2, 4, 6], [5, 1]].
Thus, Out is a [3,1,1] LoDTensor, but Out->lod() is nullptr. Thus, Out is a [3,1,1] Tensor without LoD infomation.
And for different strategy, the value of Out is as follows: And for different strategy, the value of Out is as follows:
- AVERAGE: [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2 - AVERAGE: [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2
......
...@@ -109,8 +109,8 @@ class SequencePoolGradKernel : public framework::OpKernel<T> { ...@@ -109,8 +109,8 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
int64_t w = in->numel() / dims[0]; int64_t w = in->numel() / dims[0];
in_g->mutable_data<T>(context.GetPlace()); in_g->mutable_data<T>(context.GetPlace());
if (strategy > 2) { if (strategy == LAST || strategy == FIRST) {
// set X@Grad be zero at first when strategy is LAST/FIRST/MAX // set X@Grad be zero at first when strategy is LAST/FIRST
math::SetConstant<Place, T>(context.device_context(), in_g, 0); math::SetConstant<Place, T>(context.device_context(), in_g, 0);
} }
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册