提交 8e086a85 编写于 作者: T tensor-tang

follow comment and fix typo

test=develop
上级 54afcb7e
...@@ -39,21 +39,25 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern, ...@@ -39,21 +39,25 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
auto is_seqpool_op_with_pootype_of_nth_input_of_concat = [=]( auto is_seqpool_op_with_pootype_of_nth_input_of_concat = [=](
Node* x, const std::string& type, int idx) -> bool { Node* x, const std::string& type, int idx) -> bool {
bool ok = x && x->IsOp() && x->Op()->Type() == "sequence_pool" && bool this_is_seqpool_op =
x && x->IsOp() && x->Op()->Type() == "sequence_pool" &&
x->Op()->HasAttr("pooltype") && x->Op()->HasAttr("pooltype") &&
boost::get<std::string>(x->Op()->GetAttr("pooltype")) == type && boost::get<std::string>(x->Op()->GetAttr("pooltype")) == type &&
x->outputs.size() == 2; // seqpool should only have 2 outputs x->outputs.size() == 2; // seqpool should only have 2 outputs
if (ok) { bool satisfied_all = this_is_seqpool_op;
// only one output of seqpool_op is nth_input_var of concat if (this_is_seqpool_op) {
// the other one should be unused empty var // Only one output of seqpool_op is nth_input_var of concat,
// the other one should be unused empty var.
if (is_nth_input_var_of_concat(x->outputs[0], idx)) { if (is_nth_input_var_of_concat(x->outputs[0], idx)) {
ok = ok && x->outputs[1]->IsVar() && x->outputs[1]->outputs.size() == 0; satisfied_all = satisfied_all && x->outputs[1]->IsVar() &&
x->outputs[1]->outputs.size() == 0;
} else { } else {
ok = ok && is_nth_input_var_of_concat(x->outputs[1], idx) && satisfied_all =
satisfied_all && is_nth_input_var_of_concat(x->outputs[1], idx) &&
x->outputs[0]->IsVar() && x->outputs[0]->outputs.size() == 0; x->outputs[0]->IsVar() && x->outputs[0]->outputs.size() == 0;
} }
} }
return ok; return satisfied_all;
}; };
auto* concat_op = pattern->NewNode( auto* concat_op = pattern->NewNode(
......
...@@ -23,6 +23,20 @@ namespace paddle { ...@@ -23,6 +23,20 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
/**
* Fuse SequencePool(with sum pooltype yet) and Concat;
*
* Before fuse:
* | | |
* seq_pool, seq_pool, ... seq_pool
* \ | ... /
* concat
* |
* After fuse:
* \ | /
* FusionSeqPoolConcat
* |
*/
class SeqPoolConcatFusePass : public FusePassBase { class SeqPoolConcatFusePass : public FusePassBase {
public: public:
virtual ~SeqPoolConcatFusePass() {} virtual ~SeqPoolConcatFusePass() {}
......
...@@ -23,7 +23,7 @@ namespace operators { ...@@ -23,7 +23,7 @@ namespace operators {
void FusionSeqPoolConcatOp::InferShape( void FusionSeqPoolConcatOp::InferShape(
framework::InferShapeContext* ctx) const { framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL, PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
"Inputs(X) of FusionSeqPoolConcatOp should be empty."); "Inputs(X) of FusionSeqPoolConcatOp should not be empty.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FusionSeqPoolConcatOp should not be null."); "Output(Out) of FusionSeqPoolConcatOp should not be null.");
int axis = ctx->Attrs().Get<int>("axis"); int axis = ctx->Attrs().Get<int>("axis");
...@@ -54,12 +54,13 @@ void FusionSeqPoolConcatOpMaker::Make() { ...@@ -54,12 +54,13 @@ void FusionSeqPoolConcatOpMaker::Make() {
AddInput("X", "(LoDTensor) Input tensors of this operator.").AsDuplicable(); AddInput("X", "(LoDTensor) Input tensors of this operator.").AsDuplicable();
AddOutput("Out", "(LoDTensor) Output tensor of concat operator."); AddOutput("Out", "(LoDTensor) Output tensor of concat operator.");
AddAttr<std::string>("pooltype", AddAttr<std::string>("pooltype",
"(string, default 'AVERAGE') some of the pooling " "(string, default 'SUM') some of the pooling "
"pooltype of SequencePoolOp.") "pooltype of SequencePoolOp.")
.SetDefault("SUM") .SetDefault("SUM")
.InEnum({"AVERAGE", "SUM", "SQRT"}); .InEnum({"AVERAGE", "SUM", "SQRT"});
AddAttr<int>("axis", AddAttr<int>("axis",
"The axis along which the input tensors will be concatenated.") "The axis along which the input tensors will be concatenated. "
"Only supports concat axis=1 yet.")
.SetDefault(1); .SetDefault(1);
AddComment(R"DOC( AddComment(R"DOC(
Fusion Sequence Pool of pooltype(sum, average and sqrt) and Concat Operator. Fusion Sequence Pool of pooltype(sum, average and sqrt) and Concat Operator.
...@@ -100,6 +101,7 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> { ...@@ -100,6 +101,7 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>( jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
attr); attr);
size_t n = ins.size(); size_t n = ins.size();
size_t dst_step_size = n * w;
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
auto x_dims = ins[i]->dims(); auto x_dims = ins[i]->dims();
auto x_lod = ins[i]->lod()[0]; auto x_lod = ins[i]->lod()[0];
...@@ -112,7 +114,7 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> { ...@@ -112,7 +114,7 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
for (size_t j = 0; j < bs; ++j) { for (size_t j = 0; j < bs; ++j) {
attr.h = static_cast<int>(x_lod[j + 1] - x_lod[j]); attr.h = static_cast<int>(x_lod[j + 1] - x_lod[j]);
seqpool(src, dst, &attr); seqpool(src, dst, &attr);
dst += n * w; dst += dst_step_size;
src += attr.h * attr.w; src += attr.h * attr.w;
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册