diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 5228840c960aca4817f03f6421a24139054e5575..d50d84137965d42c2e116925fbe9a263ed676fac 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -61,6 +61,7 @@ pass_library(multi_batch_merge_pass base) pass_library(conv_bn_fuse_pass inference) pass_library(seqconv_eltadd_relu_fuse_pass inference) pass_library(seqpool_concat_fuse_pass inference) +pass_library(seqpool_cvm_concat_fuse_pass inference) pass_library(repeated_fc_relu_fuse_pass inference) pass_library(squared_mat_sub_fuse_pass inference) pass_library(is_test_pass base) @@ -118,6 +119,7 @@ cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector) cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto) +cc_test(test_seqpool_cvm_concat_fuse_pass SRCS seqpool_cvm_concat_fuse_pass_tester.cc DEPS seqpool_cvm_concat_fuse_pass framework_proto) cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass) if(NOT WIN32) cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass) diff --git a/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.cc b/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..8261bfc15348f90af4ed7acb9e5b68373dc5e715 --- /dev/null +++ b/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.cc @@ -0,0 +1,153 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h" +#include +#include +#include +#include +#include "paddle/fluid/framework/lod_tensor.h" + +namespace paddle { +namespace framework { +namespace ir { + +namespace { +static PDNode* BuildCVMConcatPattern(PDPattern* pattern) { + auto cvm_behind_x = [](Node* x) -> bool { + Node* adj = x->inputs[0]; + Node* alt = x->inputs[0]->inputs[0]; + return x && adj && adj->IsVar() && alt->IsOp() && + alt->Op()->Type() == "cvm"; + }; + auto* concat_op_node = pattern->NewNode("concat_op") + ->assert_is_op("concat") + ->assert_op_attr("axis", 1) + ->assert_more(cvm_behind_x); + return concat_op_node; +} + +static void GetConcatNodes(ir::Graph* graph, std::vector* concat_nodes) { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + auto concat_op_node = BuildCVMConcatPattern(pattern); + GraphPatternDetector::handle_t handler = [&]( + const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { + Node* concat_op = subgraph.at(concat_op_node); + concat_nodes->push_back(concat_op); + }; + gpd(graph, handler); +} +} // anonymous namespace + +void SeqPoolCVMConcatFusePass::ApplyImpl(ir::Graph* graph) const { + FusePassBase::Init("seqpool_cvm_concat_fuse", graph); + std::vector concat_nodes; + GetConcatNodes(graph, &concat_nodes); + + int count = 0; + for (auto* concat_node : concat_nodes) { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + auto concat_before_x = [=](Node* x) -> bool { + return x && x->outputs[0] == concat_node; + }; + PDNode* seqpool_in_var_node = + pattern->NewNode("seqpool_in_var") + ->assert_is_only_input_of_op("sequence_pool"); + PDNode* seqpool_op_node = + pattern->NewNode("seqpool_op") + ->assert_is_op("sequence_pool") + ->assert_op_attr("pooltype", "SUM"); + PDNode* seqpool_out_var_node = + pattern->NewNode("seqpool_out_var") + ->assert_is_op_nth_output("sequence_pool", "Out", 0) + ->assert_is_op_nth_input("cvm", "X", 0); + PDNode* seqpool_idx_out_var_node = + pattern->NewNode("seqpool_idx_out_var") + ->assert_is_op_nth_output("sequence_pool", "MaxIndex", 0); + PDNode* cvm_op_node = + pattern->NewNode("cvm_op")->assert_is_op("cvm")->assert_op_attr( + "use_cvm", true); + PDNode* cvm_out_var_node = pattern->NewNode("cvm_op_out_var") + ->assert_is_op_nth_output("cvm", "Y", 0) + ->assert_more(concat_before_x); + PDNode* cvm_cvm_in_var_node = pattern->NewNode("cvm_cvm_in_var") + ->assert_is_op_nth_input("cvm", "CVM", 0); + + seqpool_op_node->LinksFrom({seqpool_in_var_node}) + .LinksTo({seqpool_out_var_node, seqpool_idx_out_var_node}); + seqpool_out_var_node->LinksFrom({seqpool_op_node}).LinksTo({cvm_op_node}); + cvm_op_node->LinksTo({cvm_out_var_node}) + .LinksFrom({cvm_cvm_in_var_node, seqpool_out_var_node}); + + std::unordered_map ins_to_concat; + std::vector subgraph_ins; + std::vector subgraph_ins_name; + std::unordered_set marked_nodes; + + Node* cvm_input_of_cvm; + Node* concat_out_var = concat_node->outputs[0]; + + GraphPatternDetector::handle_t handler = [&]( + const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { + Node* seqpool_in_var = subgraph.at(seqpool_in_var_node); + Node* seqpool_op = subgraph.at(seqpool_op_node); + Node* seqpool_out_var = subgraph.at(seqpool_out_var_node); + Node* seqpool_idx_out_var = subgraph.at(seqpool_idx_out_var_node); + Node* cvm_op = subgraph.at(cvm_op_node); + Node* cvm_out_var = subgraph.at(cvm_out_var_node); + cvm_input_of_cvm = subgraph.at(cvm_cvm_in_var_node); + marked_nodes.insert({seqpool_op, seqpool_out_var, seqpool_idx_out_var, + cvm_op, cvm_out_var, concat_node}); + ins_to_concat[cvm_out_var->Name()] = seqpool_in_var; + }; + gpd(graph, handler); + + if (!ins_to_concat.empty()) { + for (const auto* in : concat_node->inputs) { + subgraph_ins.push_back(ins_to_concat.at(in->Name())); + subgraph_ins_name.push_back(ins_to_concat.at(in->Name())->Name()); + } + + // Create New OpDesc + OpDesc op_desc; + op_desc.SetType("fusion_seqpool_cvm_concat"); + op_desc.SetInput("X", subgraph_ins_name); + op_desc.SetInput("CVM", {cvm_input_of_cvm->Name()}); + op_desc.SetAttr("pooltype", std::string("SUM")); + op_desc.SetAttr("use_cvm", true); + op_desc.SetAttr("axis", concat_node->Op()->GetAttr("axis")); + op_desc.SetOutput("Out", {concat_out_var->Name()}); + auto* op = graph->CreateOpNode(&op_desc); + + for (size_t i = 0; i < subgraph_ins.size(); ++i) { + IR_NODE_LINK_TO(subgraph_ins[i], op); + } + IR_NODE_LINK_TO(cvm_input_of_cvm, op); + IR_NODE_LINK_TO(op, concat_out_var); + + GraphSafeRemoveNodes(graph, marked_nodes); + count++; + } + } + AddStatis(count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(seqpool_cvm_concat_fuse_pass, + paddle::framework::ir::SeqPoolCVMConcatFusePass); diff --git a/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h b/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..88a41983c6bf7b4e76d7912dbb3821b2c2ed533b --- /dev/null +++ b/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h @@ -0,0 +1,54 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +/** + * Fuse SequencePool(with sum pooltype yet) and Concat; + * + * Before fuse: + * | | | + * seq_pool, seq_pool, ... seq_pool + * | | | + * cvm cvm cvm + * \ | ... / + * concat + * | + * After fuse: + * \ | / + * FusionSeqPoolCVMConcat + * | + */ +class SeqPoolCVMConcatFusePass : public FusePassBase { + public: + virtual ~SeqPoolCVMConcatFusePass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; + + const std::string name_scope_{"seqpool_cvm_concat_fuse"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass_tester.cc b/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..bba640cf148d1ebfc2583b420c3ffd8ff1d110f1 --- /dev/null +++ b/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass_tester.cc @@ -0,0 +1,239 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h" +#include +#include "paddle/fluid/framework/op_proto_maker.h" + +namespace paddle { +namespace framework { +namespace ir { + +void SetOp(ProgramDesc* prog, const std::string& type, + const std::vector& inputs, + const std::vector& outputs) { + auto* op = prog->MutableBlock(0)->AppendOp(); + op->SetType(type); + if (type == "sequence_pool") { + op->SetInput("X", {inputs[0]}); + std::string pooltype = "SUM"; + op->SetAttr("pooltype", pooltype); + op->SetOutput("MaxIndex", {outputs[0]}); + op->SetOutput("Out", {outputs[1]}); + } else if (type == "concat") { + op->SetInput("X", inputs); + op->SetAttr("axis", 1); + op->SetOutput("Out", {outputs[0]}); + } else if (type == "cvm") { + op->SetInput("X", {inputs[0]}); + op->SetInput("CVM", {inputs[1]}); + op->SetOutput("Y", {outputs[0]}); + op->SetAttr("use_cvm", true); + } else { + op->SetInput("X", inputs); + op->SetOutput("Out", outputs); + } + op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), + static_cast(OpRole::kForward)); +} + +int CountOpType(const ir::Graph* graph, + const std::string& op_type = "fusion_seqpool_cvm_concat") { + int count = 0; + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op()->Type() == op_type) { + ++count; + } + } + return count; +} + +std::unique_ptr GetNumNodesOfBeforeAfter( + std::unique_ptr graph, int* before, int* after, + const std::string& pass_type = "seqpool_cvm_concat_fuse_pass") { + auto pass = PassRegistry::Instance().Get(pass_type); + *before = graph->Nodes().size(); + graph.reset(pass->Apply(graph.release())); + *after = graph->Nodes().size(); + return graph; +} + +/* + * Before fuse: + * + * + * a b c + * | | | + * op1 op2 op3 + * / \ / \ / \ + * d e n f g n h i n + * | / | / | / + * op4 op5 op6 + * | | | + j k l + * \ | / + * concat + * | + * m + * + * Type of op1, op2 and op3 are sequence_pool, with "SUM" pooltype attr. + * Type of op4, op5 and op6 are cvm, with use_cvm is true. + * + * After fuse: + * a b c n + * \ | | / + * fusion_seqpool_cvm_concat + * | + * m + */ +TEST(SeqPoolCVMConcatFusePass, basic) { + ProgramDesc prog; + for (auto& v : + std::vector({"a", "b", "c", "d", "e", "f", "g", "h", "i", + "j", "k", "l", "m", "n"})) { + auto* var = prog.MutableBlock(0)->Var(v); + var->SetType(proto::VarType::LOD_TENSOR); + } + + SetOp(&prog, "sequence_pool", std::vector({"a"}), + std::vector({"d", "e"})); + SetOp(&prog, "sequence_pool", std::vector({"b"}), + std::vector({"f", "g"})); + SetOp(&prog, "sequence_pool", std::vector({"c"}), + std::vector({"h", "i"})); + SetOp(&prog, "cvm", std::vector({"e", "n"}), + std::vector({"j"})); + SetOp(&prog, "cvm", std::vector({"g", "n"}), + std::vector({"k"})); + SetOp(&prog, "cvm", std::vector({"i", "n"}), + std::vector({"l"})); + SetOp(&prog, "concat", std::vector({"j", "k", "l"}), + std::vector({"m"})); + + std::unique_ptr graph(new ir::Graph(prog)); + int before, after; + graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after); + // Remove 16 Nodes: op1, op2, op3, op4, op5, op6, d, e, f, g, h, i, j, k, l, + // concat_op + // Add 1 Node: fusion_seqpool_cvm_concat + EXPECT_EQ(after, before - 15); + EXPECT_EQ(CountOpType(graph.get()), 1); +} + +/* + * Before fuse: + * a b + * | / \ + * op1 k op2 k op3 + * / \ / / \ / \ + * c d e f g + * | | + * op4 op5 + * | | + * h i + * \ / + * concat + * | + * j + * Type of op1 and op2 are sequence_pool, with "SUM" pooltype attr. + * Type of op4 and op5 are cvm, with use_cvm is true. + * + * After fuse: + * a k b + * \ | / \ + * fusion_seqpool_cvm_concat op3 + * | | + * j g + */ +TEST(SeqPoolCVMConcatFusePass, advanced) { + ProgramDesc prog; + for (auto& v : std::vector( + {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"})) { + auto* var = prog.MutableBlock(0)->Var(v); + var->SetType(proto::VarType::LOD_TENSOR); + } + + SetOp(&prog, "sequence_pool", std::vector({"a"}), + std::vector({"c", "d"})); + SetOp(&prog, "sequence_pool", std::vector({"b"}), + std::vector({"e", "f"})); + SetOp(&prog, "op3", std::vector({"b"}), + std::vector({"g"})); + SetOp(&prog, "cvm", std::vector({"d", "k"}), + std::vector({"h"})); + SetOp(&prog, "cvm", std::vector({"f", "k"}), + std::vector({"i"})); + SetOp(&prog, "concat", std::vector({"h", "i"}), + std::vector({"j"})); + + std::unique_ptr graph(new ir::Graph(prog)); + int before, after; + graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after); + // Remove 11 Nodes: op1, op2, op4, op5, c, d, e, f, h, i, concat_op + // Add 1 Node: fusion_seqpool_cvm_concat + EXPECT_EQ(after, before - 10); + EXPECT_EQ(CountOpType(graph.get()), 1); +} + +ProgramDesc BuildProgramDesc(int num_inputs_of_concat) { + ProgramDesc prog; + auto new_var = [&](const std::string& name) { + auto* var = prog.MutableBlock(0)->Var(name); + var->SetType(proto::VarType::LOD_TENSOR); + }; + std::vector concat_inputs; + new_var("cvm_in"); + for (int i = 0; i < num_inputs_of_concat; ++i) { + std::string seqpool_prefix = "seqpool_op_" + std::to_string(i); + new_var(seqpool_prefix + "in"); + new_var(seqpool_prefix + "out"); + new_var(seqpool_prefix + "out_unused"); + SetOp(&prog, "sequence_pool", + std::vector({seqpool_prefix + "in"}), + std::vector( + {seqpool_prefix + "out_unused", seqpool_prefix + "out"})); + + std::string cvm_prefix = "cvm_op_" + std::to_string(i); + new_var(cvm_prefix + "out"); + SetOp(&prog, "cvm", + std::vector({seqpool_prefix + "out", "cvm_in"}), + std::vector({cvm_prefix + "out"})); + + concat_inputs.push_back(cvm_prefix + "out"); + } + SetOp(&prog, "concat", concat_inputs, + std::vector({"concat_out"})); + return prog; +} + +// test more inputs of concat +TEST(SeqPoolCVMConcatFusePass, more_inputs) { + for (int num : {1, 2, 10}) { + ProgramDesc prog = BuildProgramDesc(num); + std::unique_ptr graph(new ir::Graph(prog)); + int before, after; + graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after); + // Remove Nodes: n * (seqpool_op, seqpool_out, out_unused, cvm_op, cvm_out), + // and concat_op + // Add Node: fusion_seqpool_cvm_concat op + EXPECT_EQ(after, before - num * 5); + EXPECT_EQ(CountOpType(graph.get()), 1); + } +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(seqpool_cvm_concat_fuse_pass); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index bc2c0914728f30fe45dc4ece6477d03a244e8b40..7d5a77b4b8b090970b8b94ed4b04ab4e208b1b68 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -144,6 +144,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { "attention_lstm_fuse_pass", // "seqconv_eltadd_relu_fuse_pass", // // "seqpool_concat_fuse_pass", // + "seqpool_cvm_concat_fuse_pass", // // "embedding_fc_lstm_fuse_pass", // "fc_lstm_fuse_pass", // "mul_lstm_fuse_pass", // diff --git a/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc b/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..14e327bb37d1381affe0189ce220fe13c63eac99 --- /dev/null +++ b/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc @@ -0,0 +1,148 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.h" +#include +#include +#include "paddle/fluid/operators/jit/kernels.h" + +namespace paddle { +namespace operators { + +void FusionSeqPoolCVMConcatOp::InferShape( + framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE_GE( + ctx->Inputs("X").size(), 1UL, + "Inputs(X) of FusionSeqPoolCVMConcatOp should not be empty."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of FusionSeqPoolCVMConcatOp should not be null."); + int axis = ctx->Attrs().Get("axis"); + PADDLE_ENFORCE_EQ( + axis, 1, "FusionSeqPoolCVMConcatOp only supports concat axis=1 yet."); + bool use_cvm = ctx->Attrs().Get("use_cvm"); + PADDLE_ENFORCE_EQ( + use_cvm, true, + "FusionSeqPoolCVMConcatOp only supports use_cvm is true yet."); + + auto ins_dims = ctx->GetInputsDim("X"); + const size_t n = ins_dims.size(); + PADDLE_ENFORCE_GT(n, 0UL, "Input tensors count should > 0."); + if (n == 1) { + LOG(WARNING) << "Only have one input, may waste memory"; + } + + // The output height should be confirmed in Compute, + // since input lod is not accessible here. + PADDLE_ENFORCE_EQ(ins_dims[0].size(), 2, + "The dims size of first input should be 2."); + ctx->SetOutputDim("Out", {-1, ins_dims[0][axis] * static_cast(n)}); +} + +framework::OpKernelType FusionSeqPoolCVMConcatOp::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + return framework::OpKernelType( + framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]), ctx.GetPlace()); +} + +void FusionSeqPoolCVMConcatOpMaker::Make() { + AddInput("X", "(LoDTensor) Input tensors of this operator.").AsDuplicable(); + AddInput("CVM", + "(Tensor), a 2-D Tensor with shape [N x 2], where N is the batch " + "size, 2 is show and click."); + AddOutput("Out", "(LoDTensor) Output tensor of concat operator."); + AddAttr("pooltype", + "(string, default 'SUM') some of the pooling " + "pooltype of SequencePoolOp.") + .SetDefault("SUM") + .InEnum({"AVERAGE", "SUM", "SQRT"}); + AddAttr("use_cvm", "bool, use cvm or not").SetDefault(true); + AddAttr("axis", + "The axis along which the input tensors will be concatenated. " + "Only supports concat axis=1 yet.") + .SetDefault(1); + AddComment(R"DOC( +Fusion Sequence Pool of pooltype(sum, average and sqrt) and Concat Operator. +)DOC"); +} + +template +class FusionSeqPoolCVMConcatKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto ins = ctx.MultiInput("X"); + auto* out = ctx.Output("Out"); + std::string pooltype = ctx.Attr("pooltype"); + auto x0_lod = ins[0]->lod(); + auto x0_dims = ins[0]->dims(); + auto y_dims = out->dims(); + size_t bs = x0_lod[0].size() - 1; + out->Resize({static_cast(bs), y_dims[1]}); + framework::LoD y_lod(1); + y_lod[0].resize(bs + 1); + for (size_t i = 0; i <= bs; ++i) { + y_lod[0][i] = i; + } + out->set_lod(y_lod); + auto place = ctx.GetPlace(); + T* y_data = out->mutable_data(place); + + int w = ins[0]->numel() / x0_dims[0]; + PADDLE_ENFORCE_EQ(y_dims[1] % w, 0, + "The output of dims[1] should be dividable of w"); + jit::seq_pool_attr_t attr(w, jit::SeqPoolType::kSum); + if (pooltype == "AVERAGE") { + attr.type = jit::SeqPoolType::kAvg; + } else if (pooltype == "SQRT") { + attr.type = jit::SeqPoolType::kSqrt; + } + auto seqpool = + jit::KernelFuncs, platform::CPUPlace>::Cache().At( + attr); + size_t n = ins.size(); + size_t dst_step_size = n * w; + for (size_t i = 0; i < n; ++i) { + auto x_dims = ins[i]->dims(); + auto x_lod = ins[i]->lod()[0]; + const T* src = ins[i]->data(); + T* dst = y_data + i * w; + PADDLE_ENFORCE_EQ(static_cast(ins[i]->numel() / x_dims[0]), w, + "Width of all inputs should be equal."); + PADDLE_ENFORCE_EQ(x_lod.size(), bs + 1, + "Batchsize of all inputs should be equal."); + for (size_t j = 0; j < bs; ++j) { + attr.h = static_cast(x_lod[j + 1] - x_lod[j]); + seqpool(src, dst, &attr); + + // Currently only use_cvm is true. + dst[0] = log(dst[0] + 1); + dst[1] = log(dst[1] + 1) - dst[0]; + + dst += dst_step_size; + src += attr.h * attr.w; + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fusion_seqpool_cvm_concat, ops::FusionSeqPoolCVMConcatOp, + ops::FusionSeqPoolCVMConcatOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL(fusion_seqpool_cvm_concat, + ops::FusionSeqPoolCVMConcatKernel, + ops::FusionSeqPoolCVMConcatKernel); diff --git a/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.h b/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.h new file mode 100644 index 0000000000000000000000000000000000000000..75e8556c31a819572b1e73464f6dba235642ddcd --- /dev/null +++ b/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +class FusionSeqPoolCVMConcatOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; +}; + +class FusionSeqPoolCVMConcatOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override; +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_fusion_seqpool_cvm_concat_op.py b/python/paddle/fluid/tests/unittests/test_fusion_seqpool_cvm_concat_op.py new file mode 100644 index 0000000000000000000000000000000000000000..332f48ae71a9cc7b64d6aa7641c1ef8db63bc3a4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fusion_seqpool_cvm_concat_op.py @@ -0,0 +1,125 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +from test_reorder_lod_tensor import convert_to_offset +from test_seq_pool import compute_seqpool_sum, compute_seqpool_avg, compute_seqpool_sqrt +from test_cvm_op import cvm_compute + + +class TestFusionSeqPoolCVMConcatOp(OpTest): + def setUp(self): + self.w = 11 + self.use_cvm = True + self.lods = [[[2, 3, 5]], [[1, 5, 2]]] + self.set_conf() + self.set_pooltype() + self.op_type = 'fusion_seqpool_cvm_concat' + self.axis = 1 + bs = len(self.lods[0][0]) + inputs = [] + outs = [] + # The cvm variable is not actually used. + cvm = np.array([[0.6, 0.4]]).astype("float32") + i = 0 + for lod in self.lods: + assert bs == len(lod[0]), 'All lod size should be equal' + x = np.random.uniform(0.1, 1, + [sum(lod[0]), self.w]).astype('float32') + offset = convert_to_offset(lod) + out = np.zeros((bs, self.w)).astype('float32') + if self.pooltype == "SUM": + compute_seqpool_sum(x, offset, out) + out = cvm_compute(out, self.w, self.use_cvm) + elif self.pooltype == "AVERAGE": + compute_seqpool_avg(x, offset, out) + out = cvm_compute(out, self.w, self.use_cvm) + elif self.pooltype == "SQRT": + compute_seqpool_sqrt(x, offset, out) + out = cvm_compute(out, self.w, self.use_cvm) + else: + raise Exception("Unsupported pool type!") + inputs.append(('x_{0}'.format(i), (x, lod))) + outs.append(out) + i = i + 1 + + self.inputs = {'X': inputs, "CVM": cvm} + self.outputs = {'Out': np.concatenate(outs, axis=self.axis)} + self.attrs = { + 'pooltype': self.pooltype, + 'axis': self.axis, + } + + def set_pooltype(self): + self.pooltype = "SUM" + + def set_conf(self): + pass + + def test_check_output(self): + self.check_output() + + +class TestFusionSeqPoolCVMConcatOpCase1(TestFusionSeqPoolCVMConcatOp): + def set_conf(self): + self.lods = [[[1]]] + + +class TestFusionSeqPoolCVMConcatOpCase2(TestFusionSeqPoolCVMConcatOp): + def set_conf(self): + self.lods = [[[1]], [[1]], [[1]]] + + +class TestFusionSeqPoolCVMConcatOpCase3(TestFusionSeqPoolCVMConcatOp): + def set_conf(self): + self.lods = [[[1, 3, 4, 6]]] + self.w = 10 + + +class TestFusionSeqPoolCVMConcatOpCase4(TestFusionSeqPoolCVMConcatOp): + def set_conf(self): + self.lods = [[[2, 13, 4]], [[1, 1, 1]], [[5, 3, 1]], [[9, 10, 3]]] + self.w = 3 + + +## test avg pool and sqrt +def create_test_avg_sqrt_class(parent): + class TestSeqPoolAvgCase(parent): + def set_pooltype(self): + self.pooltype = "AVERAGE" + + class TestSeqPoolSqrtCase(parent): + def set_pooltype(self): + self.pooltype = "SQRT" + + cls_name_avg = "{0}_{1}".format(parent.__name__, "avg") + cls_name_sqrt = "{0}_{1}".format(parent.__name__, "sqrt") + TestSeqPoolAvgCase.__name__ = cls_name_avg + TestSeqPoolSqrtCase.__name__ = cls_name_sqrt + globals()[cls_name_avg] = TestSeqPoolAvgCase + globals()[cls_name_sqrt] = TestSeqPoolSqrtCase + + +create_test_avg_sqrt_class(TestFusionSeqPoolCVMConcatOp) +create_test_avg_sqrt_class(TestFusionSeqPoolCVMConcatOpCase1) +create_test_avg_sqrt_class(TestFusionSeqPoolCVMConcatOpCase2) +create_test_avg_sqrt_class(TestFusionSeqPoolCVMConcatOpCase3) +create_test_avg_sqrt_class(TestFusionSeqPoolCVMConcatOpCase4) + +if __name__ == '__main__': + unittest.main()