From 8e086a8521ab1aa8d7b2632b71b9193df630d0a4 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 9 Jan 2019 14:51:51 +0000 Subject: [PATCH] follow comment and fix typo test=develop --- .../framework/ir/seqpool_concat_fuse_pass.cc | 26 +++++++++++-------- .../framework/ir/seqpool_concat_fuse_pass.h | 14 ++++++++++ .../fused/fusion_seqpool_concat_op.cc | 10 ++++--- 3 files changed, 35 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc index 20b8220033..7dd6f4880a 100644 --- a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc @@ -39,21 +39,25 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern, auto is_seqpool_op_with_pootype_of_nth_input_of_concat = [=]( Node* x, const std::string& type, int idx) -> bool { - bool ok = x && x->IsOp() && x->Op()->Type() == "sequence_pool" && - x->Op()->HasAttr("pooltype") && - boost::get(x->Op()->GetAttr("pooltype")) == type && - x->outputs.size() == 2; // seqpool should only have 2 outputs - if (ok) { - // only one output of seqpool_op is nth_input_var of concat - // the other one should be unused empty var + bool this_is_seqpool_op = + x && x->IsOp() && x->Op()->Type() == "sequence_pool" && + x->Op()->HasAttr("pooltype") && + boost::get(x->Op()->GetAttr("pooltype")) == type && + x->outputs.size() == 2; // seqpool should only have 2 outputs + bool satisfied_all = this_is_seqpool_op; + if (this_is_seqpool_op) { + // Only one output of seqpool_op is nth_input_var of concat, + // the other one should be unused empty var. if (is_nth_input_var_of_concat(x->outputs[0], idx)) { - ok = ok && x->outputs[1]->IsVar() && x->outputs[1]->outputs.size() == 0; + satisfied_all = satisfied_all && x->outputs[1]->IsVar() && + x->outputs[1]->outputs.size() == 0; } else { - ok = ok && is_nth_input_var_of_concat(x->outputs[1], idx) && - x->outputs[0]->IsVar() && x->outputs[0]->outputs.size() == 0; + satisfied_all = + satisfied_all && is_nth_input_var_of_concat(x->outputs[1], idx) && + x->outputs[0]->IsVar() && x->outputs[0]->outputs.size() == 0; } } - return ok; + return satisfied_all; }; auto* concat_op = pattern->NewNode( diff --git a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h index 59730fde55..ba2154045e 100644 --- a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h +++ b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h @@ -23,6 +23,20 @@ namespace paddle { namespace framework { namespace ir { +/** + * Fuse SequencePool(with sum pooltype yet) and Concat; + * + * Before fuse: + * | | | + * seq_pool, seq_pool, ... seq_pool + * \ | ... / + * concat + * | + * After fuse: + * \ | / + * FusionSeqPoolConcat + * | + */ class SeqPoolConcatFusePass : public FusePassBase { public: virtual ~SeqPoolConcatFusePass() {} diff --git a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc index 578ff6b2d0..b181140db7 100644 --- a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc @@ -23,7 +23,7 @@ namespace operators { void FusionSeqPoolConcatOp::InferShape( framework::InferShapeContext* ctx) const { PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL, - "Inputs(X) of FusionSeqPoolConcatOp should be empty."); + "Inputs(X) of FusionSeqPoolConcatOp should not be empty."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of FusionSeqPoolConcatOp should not be null."); int axis = ctx->Attrs().Get("axis"); @@ -54,12 +54,13 @@ void FusionSeqPoolConcatOpMaker::Make() { AddInput("X", "(LoDTensor) Input tensors of this operator.").AsDuplicable(); AddOutput("Out", "(LoDTensor) Output tensor of concat operator."); AddAttr("pooltype", - "(string, default 'AVERAGE') some of the pooling " + "(string, default 'SUM') some of the pooling " "pooltype of SequencePoolOp.") .SetDefault("SUM") .InEnum({"AVERAGE", "SUM", "SQRT"}); AddAttr("axis", - "The axis along which the input tensors will be concatenated.") + "The axis along which the input tensors will be concatenated. " + "Only supports concat axis=1 yet.") .SetDefault(1); AddComment(R"DOC( Fusion Sequence Pool of pooltype(sum, average and sqrt) and Concat Operator. @@ -100,6 +101,7 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel { jit::Get, platform::CPUPlace>( attr); size_t n = ins.size(); + size_t dst_step_size = n * w; for (size_t i = 0; i < n; ++i) { auto x_dims = ins[i]->dims(); auto x_lod = ins[i]->lod()[0]; @@ -112,7 +114,7 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel { for (size_t j = 0; j < bs; ++j) { attr.h = static_cast(x_lod[j + 1] - x_lod[j]); seqpool(src, dst, &attr); - dst += n * w; + dst += dst_step_size; src += attr.h * attr.w; } } -- GitLab