diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index f71a3d0f2e51f1a8d30fbc5f436edc97e80c57c1..a595a8ab4299298f625b8322a0adbed6d0b4fda3 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -69,6 +69,7 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass) cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector) cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) +cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto) cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass) if (WITH_MKLDNN) cc_test(test_depthwise_conv_mkldnn_pass SRCS depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass) diff --git a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc index 20b822003359a169c2b0a6f8c95b7f5abf404b1c..96a3b7ee058647156258b946c1301138c185fa31 100644 --- a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc @@ -39,21 +39,25 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern, auto is_seqpool_op_with_pootype_of_nth_input_of_concat = [=]( Node* x, const std::string& type, int idx) -> bool { - bool ok = x && x->IsOp() && x->Op()->Type() == "sequence_pool" && - x->Op()->HasAttr("pooltype") && - boost::get(x->Op()->GetAttr("pooltype")) == type && - x->outputs.size() == 2; // seqpool should only have 2 outputs - if (ok) { - // only one output of seqpool_op is nth_input_var of concat - // the other one should be unused empty var + bool this_is_seqpool_op = + x && x->IsOp() && x->Op()->Type() == "sequence_pool" && + x->Op()->HasAttr("pooltype") && + boost::get(x->Op()->GetAttr("pooltype")) == type && + x->outputs.size() == 2; // seqpool should only have 2 outputs + bool satisfied_all = this_is_seqpool_op; + if (this_is_seqpool_op) { + // Only one output of seqpool_op is nth_input_var of concat, + // the other one should be unused empty var. if (is_nth_input_var_of_concat(x->outputs[0], idx)) { - ok = ok && x->outputs[1]->IsVar() && x->outputs[1]->outputs.size() == 0; + satisfied_all = satisfied_all && x->outputs[1]->IsVar() && + x->outputs[1]->outputs.size() == 0; } else { - ok = ok && is_nth_input_var_of_concat(x->outputs[1], idx) && - x->outputs[0]->IsVar() && x->outputs[0]->outputs.size() == 0; + satisfied_all = + satisfied_all && is_nth_input_var_of_concat(x->outputs[1], idx) && + x->outputs[0]->IsVar() && x->outputs[0]->outputs.size() == 0; } } - return ok; + return satisfied_all; }; auto* concat_op = pattern->NewNode( @@ -72,6 +76,7 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern, std::vector seqpool_ops_input_var(num_inputs); std::vector seqpool_ops_output_var(num_inputs); + std::vector seqpool_ops_output_unused_var(num_inputs); std::vector seqpool_ops(num_inputs); for (int i = 0; i < num_inputs; ++i) { @@ -84,6 +89,15 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern, }, name_scope + "/sequence_pool_out_" + std::to_string(i)); + seqpool_ops_output_unused_var[i] = pattern->NewNode( + [=](Node* x) { + return x && x->IsVar() && x->inputs.size() == 1 && + x->outputs.size() == 0 && + is_seqpool_op_with_pootype_of_nth_input_of_concat(x->inputs[0], + "SUM", i); + }, + name_scope + "/sequence_pool_unused_out_" + std::to_string(i)); + seqpool_ops[i] = pattern->NewNode( [=](Node* x) { return x && x->IsOp() && @@ -93,23 +107,29 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern, seqpool_ops_input_var[i] = pattern->NewNode( [=](Node* x) { - return x && x->IsVar() && x->outputs.size() >= 1 && - is_seqpool_op_with_pootype_of_nth_input_of_concat( - x->outputs[0], "SUM", i); + bool basic = x && x->IsVar() && x->outputs.size() >= 1; + bool next_is_fine = false; + for (auto* o : x->outputs) { + if (is_seqpool_op_with_pootype_of_nth_input_of_concat(o, "SUM", + i)) { + next_is_fine = true; + break; + } + } + return basic && next_is_fine; }, name_scope + "/sequence_pool_in_" + std::to_string(i)); // Links seqpool_ops[i] ->LinksFrom({seqpool_ops_input_var[i]}) - .LinksTo({seqpool_ops_output_var[i]}); + .LinksTo({seqpool_ops_output_var[i], seqpool_ops_output_unused_var[i]}); } concat_op->LinksFrom(seqpool_ops_output_var).LinksTo({concat_out_var}); return concat_out_var; } -int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, - int num_inputs) { +int BuildFusion(Graph* graph, const std::string& name_scope, int num_inputs) { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); BuildSeqPoolConcatPattern(pattern, name_scope, num_inputs); @@ -178,8 +198,8 @@ std::unique_ptr SeqPoolConcatFusePass::ApplyImpl( FusePassBase::Init(name_scope_, graph.get()); int fusion_count = 0; for (int i = MAX_CONCAT_INPUTS; i > 0; --i) { - fusion_count += BuildFusion( - graph.get(), name_scope_ + "/" + std::to_string(i), param_scope(), i); + fusion_count += + BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(i), i); } AddStatis(fusion_count); diff --git a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h index 59730fde55f6090a952bb8f56d0e2a73130972b8..ba2154045e62c687173565c5ad30ea4d45d3c8f4 100644 --- a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h +++ b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h @@ -23,6 +23,20 @@ namespace paddle { namespace framework { namespace ir { +/** + * Fuse SequencePool(with sum pooltype yet) and Concat; + * + * Before fuse: + * | | | + * seq_pool, seq_pool, ... seq_pool + * \ | ... / + * concat + * | + * After fuse: + * \ | / + * FusionSeqPoolConcat + * | + */ class SeqPoolConcatFusePass : public FusePassBase { public: virtual ~SeqPoolConcatFusePass() {} diff --git a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..456a03192cc4e4a9d0dbe2dcb649b6c1b4d9cd5a --- /dev/null +++ b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc @@ -0,0 +1,198 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h" +#include +#include "paddle/fluid/framework/op_proto_maker.h" + +namespace paddle { +namespace framework { +namespace ir { + +void SetOp(ProgramDesc* prog, const std::string& type, + const std::vector& inputs, + const std::vector& outputs) { + auto* op = prog->MutableBlock(0)->AppendOp(); + op->SetType(type); + if (type == "sequence_pool") { + op->SetInput("X", {inputs[0]}); + std::string pooltype = "SUM"; + op->SetAttr("pooltype", pooltype); + op->SetOutput("MaxIndex", {outputs[0]}); + op->SetOutput("Out", {outputs[1]}); + } else if (type == "concat") { + op->SetInput("X", inputs); + op->SetAttr("axis", 1); + op->SetOutput("Out", {outputs[0]}); + } else { + op->SetInput("X", inputs); + op->SetOutput("Out", outputs); + } + op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), + static_cast(OpRole::kForward)); +} + +int CountOpType(const ir::Graph* graph, + const std::string& op_type = "fusion_seqpool_concat") { + int count = 0; + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op()->Type() == op_type) { + ++count; + } + } + return count; +} + +std::unique_ptr GetNumNodesOfBeforeAfter( + std::unique_ptr graph, int* before, int* after, + const std::string& pass_type = "seqpool_concat_fuse_pass") { + auto pass = PassRegistry::Instance().Get(pass_type); + *before = graph->Nodes().size(); + graph = pass->Apply(std::move(graph)); + *after = graph->Nodes().size(); + return graph; +} + +/* + * Before fuse: + * a b c + * | | | + * op1 op2 op3 + * / \ / \ / \ + * d e f g h i + * \ | / + * concat + * | + * j + * Type of op1, op2 and op3 are sequence_pool, with "SUM" pooltype attr + * + * After fuse: + * a b c + * \ | / + * fusion_seqpool_concat + * | + * j + */ +TEST(SeqPoolConcatFusePass, basic) { + ProgramDesc prog; + for (auto& v : std::vector( + {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"})) { + auto* var = prog.MutableBlock(0)->Var(v); + var->SetType(proto::VarType::LOD_TENSOR); + } + + SetOp(&prog, "sequence_pool", std::vector({"a"}), + std::vector({"d", "e"})); + SetOp(&prog, "sequence_pool", std::vector({"b"}), + std::vector({"f", "g"})); + SetOp(&prog, "sequence_pool", std::vector({"c"}), + std::vector({"h", "i"})); + SetOp(&prog, "concat", std::vector({"e", "g", "i"}), + std::vector({"j"})); + + std::unique_ptr graph(new ir::Graph(prog)); + int before, after; + graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after); + // Remove 10 Nodes: op1, op2, op3, d, e, f, g, h, i, concat_op + // Add 1 Node: fusion_seqpool_concat + EXPECT_EQ(after, before - 9); + EXPECT_EQ(CountOpType(graph.get()), 1); +} + +/* + * Before fuse: + * a b + * | / \ + * op1 op2 op3 + * / \ / \ \ + * c d e f g + * \ / + * concat + * | + * h + * Type of op1 and op2 are sequence_pool, with "SUM" pooltype attr + * + * After fuse: + * a b + * \ / \ + * fusion_seqpool_concat op3 + * | | + * h g + */ +TEST(SeqPoolConcatFusePass, advanced) { + ProgramDesc prog; + for (auto& v : + std::vector({"a", "b", "c", "d", "e", "f", "g", "h"})) { + auto* var = prog.MutableBlock(0)->Var(v); + var->SetType(proto::VarType::LOD_TENSOR); + } + + SetOp(&prog, "sequence_pool", std::vector({"a"}), + std::vector({"c", "d"})); + SetOp(&prog, "sequence_pool", std::vector({"b"}), + std::vector({"e", "f"})); + SetOp(&prog, "op3", std::vector({"b"}), + std::vector({"g"})); + SetOp(&prog, "concat", std::vector({"d", "f"}), + std::vector({"h"})); + + std::unique_ptr graph(new ir::Graph(prog)); + int before, after; + graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after); + // Remove 7 Nodes: op1, op2, c, d, e, f concat_op + // Add 1 Node: fusion_seqpool_concat + EXPECT_EQ(after, before - 6); + EXPECT_EQ(CountOpType(graph.get()), 1); +} + +ProgramDesc BuildProgramDesc(int num_inputs_of_concat) { + ProgramDesc prog; + auto new_var = [&](const std::string& name) { + auto* var = prog.MutableBlock(0)->Var(name); + var->SetType(proto::VarType::LOD_TENSOR); + }; + std::vector concat_inputs; + for (int i = 0; i < num_inputs_of_concat; ++i) { + std::string prefix = "seqpool_op_" + i; + new_var(prefix + "in"); + new_var(prefix + "out"); + new_var(prefix + "out_unused"); + SetOp(&prog, "sequence_pool", std::vector({prefix + "in"}), + std::vector({prefix + "out", prefix + "out_unused"})); + concat_inputs.push_back(prefix + "out"); + } + SetOp(&prog, "concat", concat_inputs, + std::vector({"concat_out"})); + return prog; +} + +// test more inputs of concat +TEST(SeqPoolConcatFusePass, more_inputs) { + for (int num : {1, 2, 10}) { + ProgramDesc prog = BuildProgramDesc(num); + std::unique_ptr graph(new ir::Graph(prog)); + int before, after; + graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after); + // Remove Nodes: n * (seqpool_op, out, out_unused), and concat_op + // Add Node: fusion_seqpool_concat op + EXPECT_EQ(after, before - num * 3); + EXPECT_EQ(CountOpType(graph.get()), 1); + } +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(seqpool_concat_fuse_pass); diff --git a/paddle/fluid/inference/api/helper.h b/paddle/fluid/inference/api/helper.h index 7830e859567747e6c05686335919e8346f76320d..cdd01cb9f06424b39d17e192f9a924451ad1daaf 100644 --- a/paddle/fluid/inference/api/helper.h +++ b/paddle/fluid/inference/api/helper.h @@ -204,11 +204,14 @@ static std::string DescribeTensor(const PaddleTensor &tensor) { os << to_string(l) << "; "; } os << "\n"; - os << " - data: "; + os << " - memory length: " << tensor.data.length(); + os << "\n"; + os << " - data: "; int dim = VecReduceToInt(tensor.shape); + float *pdata = static_cast(tensor.data.data()); for (int i = 0; i < dim; i++) { - os << static_cast(tensor.data.data())[i] << " "; + os << pdata[i] << " "; } os << '\n'; return os.str(); @@ -224,10 +227,12 @@ static std::string DescribeZeroCopyTensor(const ZeroCopyTensor &tensor) { os << to_string(l) << "; "; } os << "\n"; - os << " - data: "; PaddlePlace place; int size; const auto *data = tensor.data(&place, &size); + os << " - numel: " << size; + os << "\n"; + os << " - data: "; for (int i = 0; i < size; i++) { os << data[i] << " "; } diff --git a/paddle/fluid/inference/api/paddle_api.h b/paddle/fluid/inference/api/paddle_api.h index 3642f36127f1f8df30858e34bc0e1a8d09603775..832c8cdf2849279c4c32a81e9f81ef522c401b86 100644 --- a/paddle/fluid/inference/api/paddle_api.h +++ b/paddle/fluid/inference/api/paddle_api.h @@ -123,7 +123,8 @@ class ZeroCopyTensor { */ template T* mutable_data(PaddlePlace place); - /** Get the memory directly, will return the place and memory size by pointer. + /** Get the memory directly, will return the place and element size by + * pointer. * This is for reading the output tensor. */ template diff --git a/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc index 7e7c386f97e6bf132bbf262e534c153baf3af84f..315b49533260e57f7124d382555c902b31be4e1e 100644 --- a/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc @@ -351,10 +351,10 @@ TEST(Analyzer_rnn1, ZeroCopy) { ASSERT_TRUE(native_predictor->Run(native_inputs.front(), &native_outputs)); LOG(INFO) << "native output " << DescribeTensor(native_outputs.front()); - int output_size{0}; + int output_size{0}; // this is the number of elements not memory size auto *zero_copy_data = output_tensor->data(&place, &output_size); auto *native_data = static_cast(native_outputs.front().data.data()); - for (size_t i = 0; i < output_size / sizeof(float); i++) { + for (int i = 0; i < output_size; i++) { EXPECT_NEAR(zero_copy_data[i], native_data[i], 1e-3); } } diff --git a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc index 083bdf15e92782666af18f4ab2c07c6603675105..d9de55ab76e66d4f129674db829ca85cb471d2de 100644 --- a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc @@ -121,14 +121,6 @@ void PrepareInputs(std::vector *input_slots, DataRecord *data) { } } -void SetConfig(AnalysisConfig *cfg) { - cfg->SetModel(FLAGS_infer_model + "/model", FLAGS_infer_model + "/params"); - cfg->DisableGpu(); - cfg->SwitchSpecifyInputNames(); - cfg->pass_builder()->TurnOnDebug(); - cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads); -} - void SetInput(std::vector> *inputs) { DataRecord data(FLAGS_infer_data, FLAGS_batch_size); std::vector input_slots; @@ -141,15 +133,22 @@ void SetInput(std::vector> *inputs) { } } +void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) { + cfg->SetModel(FLAGS_infer_model + "/model", FLAGS_infer_model + "/params"); + cfg->DisableGpu(); + cfg->SwitchSpecifyInputNames(); + cfg->pass_builder()->TurnOnDebug(); + cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads); + if (use_mkldnn) { + cfg->EnableMKLDNN(); + } +} + void profile(bool use_mkldnn = false) { AnalysisConfig cfg; - SetConfig(&cfg); + SetConfig(&cfg, use_mkldnn); - if (use_mkldnn) { - cfg.EnableMKLDNN(); - } std::vector outputs; - std::vector> input_slots_all; SetInput(&input_slots_all); TestPrediction(reinterpret_cast(&cfg), @@ -169,22 +168,112 @@ TEST(Analyzer_seq_pool1, compare) { reinterpret_cast(&cfg), input_slots_all); } -// Check the fuse status -TEST(Analyzer_seq_pool1, fuse_statis) { +// Compare Deterministic result +TEST(Analyzer_seq_pool1, compare_determine) { + AnalysisConfig cfg; + SetConfig(&cfg); + + std::vector> input_slots_all; + SetInput(&input_slots_all); + CompareDeterministic(reinterpret_cast(&cfg), + input_slots_all); +} + +void analysis_fuse_statis(bool use_zerocopy) { AnalysisConfig cfg; SetConfig(&cfg); + cfg.SwitchUseFeedFetchOps(!use_zerocopy); int num_ops; auto predictor = CreatePaddlePredictor(cfg); - auto fuse_statis = GetFuseStatis( - static_cast(predictor.get()), &num_ops); - + auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops); + ASSERT_TRUE(fuse_statis.count("fc_fuse")); + ASSERT_EQ(fuse_statis.at("fc_fuse"), 10); ASSERT_TRUE(fuse_statis.count("seqpool_concat_fuse")); EXPECT_EQ(fuse_statis.at("seqpool_concat_fuse"), 2); - LOG(INFO) << "num_ops: " << num_ops; EXPECT_EQ(num_ops, 195); } +// Check the fuse status +TEST(Analyzer_seq_pool1, fuse_statis) { analysis_fuse_statis(false); } + +void PrepareZeroCopyInputs( + const std::unique_ptr &predictor, + std::vector> *inputs) { + DataRecord data(FLAGS_infer_data, FLAGS_batch_size); + // only feed one batch + const auto &one_batch = data.NextBatch(); + inputs->clear(); + for (size_t i = 0; i < one_batch.size(); ++i) { + auto &slot = one_batch[i]; + auto tensor = predictor->GetInputTensor(slot.name + "_embed"); + tensor->Reshape(slot.shape); + tensor->SetLoD({slot.lod}); + ZeroCopyTensorAssignData(tensor.get(), slot.data); + inputs->emplace_back(std::move(tensor)); + } +} + +// return the output values +std::vector zerocopy_profile(int repeat_times) { + AnalysisConfig config; + SetConfig(&config); + config.SwitchUseFeedFetchOps(false); + auto predictor = CreatePaddlePredictor(config); + std::vector> inputs; + PrepareZeroCopyInputs(predictor, &inputs); + auto output_tensor = predictor->GetOutputTensor("reduce_sum_0.tmp_0"); + Timer timer; + LOG(INFO) << "Warm up run..."; + timer.tic(); + predictor->ZeroCopyRun(); + PrintTime(FLAGS_batch_size, 1, 1, 0, timer.toc(), 1); + if (FLAGS_profile) { + paddle::platform::ResetProfiler(); + } + LOG(INFO) << "Run " << repeat_times << " times..."; + timer.tic(); + for (int i = 0; i < repeat_times; i++) { + predictor->ZeroCopyRun(); + } + PrintTime(FLAGS_batch_size, repeat_times, 1, 0, timer.toc() / repeat_times, + 1); + + VLOG(3) << "ZeroCopy output: " << DescribeZeroCopyTensor(*output_tensor); + PaddlePlace place; + int output_size{0}; + auto *pdata = output_tensor->data(&place, &output_size); + std::vector res(output_size); + for (int i = 0; i < output_size; ++i) { + res[i] = pdata[i]; + } + return res; +} + +TEST(Analyzer_seq_pool1, zerocopy_profile) { zerocopy_profile(FLAGS_repeat); } + +TEST(Analyzer_seq_pool1, zerocopy_fuse_statis) { analysis_fuse_statis(true); } + +TEST(Analyzer_seq_pool1, zerocopy_compare_native) { + AnalysisConfig config; + SetConfig(&config); + config.SwitchUseFeedFetchOps(true); + auto predictor = CreatePaddlePredictor(config.ToNativeConfig()); + std::vector native_outputs; + std::vector> input_slots_all; + SetInput(&input_slots_all); + ASSERT_TRUE(predictor->Run(input_slots_all[0], &native_outputs)); + EXPECT_EQ(native_outputs.size(), 1UL); + + auto zerocopy_output = zerocopy_profile(1); + EXPECT_EQ(zerocopy_output.size() * sizeof(float), + native_outputs.front().data.length()); + auto *native_data = static_cast(native_outputs.front().data.data()); + for (size_t i = 0; i < zerocopy_output.size(); ++i) { + EXPECT_NEAR(zerocopy_output[i], native_data[i], 1e-3); + } +} + } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc index 578ff6b2d010bebdacac3a9ef7ff69a4490fdf4c..b181140db750a8d1b74c0b6cc93259a208fe5b06 100644 --- a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc @@ -23,7 +23,7 @@ namespace operators { void FusionSeqPoolConcatOp::InferShape( framework::InferShapeContext* ctx) const { PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL, - "Inputs(X) of FusionSeqPoolConcatOp should be empty."); + "Inputs(X) of FusionSeqPoolConcatOp should not be empty."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of FusionSeqPoolConcatOp should not be null."); int axis = ctx->Attrs().Get("axis"); @@ -54,12 +54,13 @@ void FusionSeqPoolConcatOpMaker::Make() { AddInput("X", "(LoDTensor) Input tensors of this operator.").AsDuplicable(); AddOutput("Out", "(LoDTensor) Output tensor of concat operator."); AddAttr("pooltype", - "(string, default 'AVERAGE') some of the pooling " + "(string, default 'SUM') some of the pooling " "pooltype of SequencePoolOp.") .SetDefault("SUM") .InEnum({"AVERAGE", "SUM", "SQRT"}); AddAttr("axis", - "The axis along which the input tensors will be concatenated.") + "The axis along which the input tensors will be concatenated. " + "Only supports concat axis=1 yet.") .SetDefault(1); AddComment(R"DOC( Fusion Sequence Pool of pooltype(sum, average and sqrt) and Concat Operator. @@ -100,6 +101,7 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel { jit::Get, platform::CPUPlace>( attr); size_t n = ins.size(); + size_t dst_step_size = n * w; for (size_t i = 0; i < n; ++i) { auto x_dims = ins[i]->dims(); auto x_lod = ins[i]->lod()[0]; @@ -112,7 +114,7 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel { for (size_t j = 0; j < bs; ++j) { attr.h = static_cast(x_lod[j + 1] - x_lod[j]); seqpool(src, dst, &attr); - dst += n * w; + dst += dst_step_size; src += attr.h * attr.w; } }