Unverified commit ee2f296e, authored by 石晓伟, committed by GitHub

Fusion: seqpool_cvm_concat (#18471)

* add fusion_seqpool_cvm_concat test=develop

* simplify pass, test=develop

* fix code style, test=develop
Parent 768059b3
@@ -61,6 +61,7 @@ pass_library(multi_batch_merge_pass base)
pass_library(conv_bn_fuse_pass inference)
pass_library(seqconv_eltadd_relu_fuse_pass inference)
pass_library(seqpool_concat_fuse_pass inference)
pass_library(seqpool_cvm_concat_fuse_pass inference)
pass_library(repeated_fc_relu_fuse_pass inference)
pass_library(squared_mat_sub_fuse_pass inference)
pass_library(is_test_pass base)
@@ -118,6 +119,7 @@ cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto)
cc_test(test_seqpool_cvm_concat_fuse_pass SRCS seqpool_cvm_concat_fuse_pass_tester.cc DEPS seqpool_cvm_concat_fuse_pass framework_proto)
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
if(NOT WIN32)
cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass)
...
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
static PDNode* BuildCVMConcatPattern(PDPattern* pattern) {
// Accept only concat ops whose first input var is produced by a cvm op.
auto cvm_behind_x = [](Node* x) -> bool {
  if (!x || x->inputs.empty()) return false;
  Node* adj = x->inputs[0];
  if (!adj->IsVar() || adj->inputs.empty()) return false;
  Node* alt = adj->inputs[0];
  return alt->IsOp() && alt->Op()->Type() == "cvm";
};
auto* concat_op_node = pattern->NewNode("concat_op")
->assert_is_op("concat")
->assert_op_attr<int>("axis", 1)
->assert_more(cvm_behind_x);
return concat_op_node;
}
static void GetConcatNodes(ir::Graph* graph, std::vector<Node*>* concat_nodes) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
auto concat_op_node = BuildCVMConcatPattern(pattern);
GraphPatternDetector::handle_t handler = [&](
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
Node* concat_op = subgraph.at(concat_op_node);
concat_nodes->push_back(concat_op);
};
gpd(graph, handler);
}
} // anonymous namespace
void SeqPoolCVMConcatFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("seqpool_cvm_concat_fuse", graph);
std::vector<Node*> concat_nodes;
GetConcatNodes(graph, &concat_nodes);
int count = 0;
for (auto* concat_node : concat_nodes) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Accept only vars whose first consumer is this concat op.
auto concat_before_x = [=](Node* x) -> bool {
  return x && !x->outputs.empty() && x->outputs[0] == concat_node;
};
PDNode* seqpool_in_var_node =
pattern->NewNode("seqpool_in_var")
->assert_is_only_input_of_op("sequence_pool");
PDNode* seqpool_op_node =
pattern->NewNode("seqpool_op")
->assert_is_op("sequence_pool")
->assert_op_attr<std::string>("pooltype", "SUM");
PDNode* seqpool_out_var_node =
pattern->NewNode("seqpool_out_var")
->assert_is_op_nth_output("sequence_pool", "Out", 0)
->assert_is_op_nth_input("cvm", "X", 0);
PDNode* seqpool_idx_out_var_node =
pattern->NewNode("seqpool_idx_out_var")
->assert_is_op_nth_output("sequence_pool", "MaxIndex", 0);
PDNode* cvm_op_node =
pattern->NewNode("cvm_op")->assert_is_op("cvm")->assert_op_attr<bool>(
"use_cvm", true);
PDNode* cvm_out_var_node = pattern->NewNode("cvm_op_out_var")
->assert_is_op_nth_output("cvm", "Y", 0)
->assert_more(concat_before_x);
PDNode* cvm_cvm_in_var_node = pattern->NewNode("cvm_cvm_in_var")
->assert_is_op_nth_input("cvm", "CVM", 0);
seqpool_op_node->LinksFrom({seqpool_in_var_node})
.LinksTo({seqpool_out_var_node, seqpool_idx_out_var_node});
seqpool_out_var_node->LinksFrom({seqpool_op_node}).LinksTo({cvm_op_node});
cvm_op_node->LinksTo({cvm_out_var_node})
.LinksFrom({cvm_cvm_in_var_node, seqpool_out_var_node});
std::unordered_map<std::string, Node*> ins_to_concat;
std::vector<Node*> subgraph_ins;
std::vector<std::string> subgraph_ins_name;
std::unordered_set<const Node*> marked_nodes;
Node* cvm_input_of_cvm = nullptr;
Node* concat_out_var = concat_node->outputs[0];
GraphPatternDetector::handle_t handler = [&](
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
Node* seqpool_in_var = subgraph.at(seqpool_in_var_node);
Node* seqpool_op = subgraph.at(seqpool_op_node);
Node* seqpool_out_var = subgraph.at(seqpool_out_var_node);
Node* seqpool_idx_out_var = subgraph.at(seqpool_idx_out_var_node);
Node* cvm_op = subgraph.at(cvm_op_node);
Node* cvm_out_var = subgraph.at(cvm_out_var_node);
cvm_input_of_cvm = subgraph.at(cvm_cvm_in_var_node);
marked_nodes.insert({seqpool_op, seqpool_out_var, seqpool_idx_out_var,
cvm_op, cvm_out_var, concat_node});
ins_to_concat[cvm_out_var->Name()] = seqpool_in_var;
};
gpd(graph, handler);
if (!ins_to_concat.empty()) {
for (const auto* in : concat_node->inputs) {
subgraph_ins.push_back(ins_to_concat.at(in->Name()));
subgraph_ins_name.push_back(ins_to_concat.at(in->Name())->Name());
}
// Create New OpDesc
OpDesc op_desc;
op_desc.SetType("fusion_seqpool_cvm_concat");
op_desc.SetInput("X", subgraph_ins_name);
op_desc.SetInput("CVM", {cvm_input_of_cvm->Name()});
op_desc.SetAttr("pooltype", std::string("SUM"));
op_desc.SetAttr("use_cvm", true);
op_desc.SetAttr("axis", concat_node->Op()->GetAttr("axis"));
op_desc.SetOutput("Out", {concat_out_var->Name()});
auto* op = graph->CreateOpNode(&op_desc);
for (size_t i = 0; i < subgraph_ins.size(); ++i) {
IR_NODE_LINK_TO(subgraph_ins[i], op);
}
IR_NODE_LINK_TO(cvm_input_of_cvm, op);
IR_NODE_LINK_TO(op, concat_out_var);
GraphSafeRemoveNodes(graph, marked_nodes);
count++;
}
}
AddStatis(count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(seqpool_cvm_concat_fuse_pass,
paddle::framework::ir::SeqPoolCVMConcatFusePass);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
/**
 * Fuse SequencePool (only the SUM pooltype is supported yet), CVM and
 * Concat into a single FusionSeqPoolCVMConcat op.
 *
* Before fuse:
* | | |
* seq_pool, seq_pool, ... seq_pool
* | | |
* cvm cvm cvm
* \ | ... /
* concat
* |
* After fuse:
* \ | /
* FusionSeqPoolCVMConcat
* |
*/
class SeqPoolCVMConcatFusePass : public FusePassBase {
public:
virtual ~SeqPoolCVMConcatFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"seqpool_cvm_concat_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
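
For orientation, this is the kind of program the new pass targets. A minimal Python sketch, assuming the fluid layer API of this era (in particular continuous_value_model, which is not part of this commit; verify the names against your Paddle version):

import paddle.fluid as fluid

cvm = fluid.layers.data(name='cvm', shape=[2], dtype='float32')
branches = []
for i in range(3):
    x = fluid.layers.data(name='x%d' % i, shape=[11], dtype='float32',
                          lod_level=1)
    pooled = fluid.layers.sequence_pool(input=x, pool_type='sum')
    # cvm holds the show/click columns; use_cvm=True keeps the full width.
    branches.append(
        fluid.layers.continuous_value_model(pooled, cvm, use_cvm=True))
out = fluid.layers.concat(branches, axis=1)
# Running seqpool_cvm_concat_fuse_pass on the inference graph replaces the
# three sequence_pool + cvm branches and the concat with a single
# fusion_seqpool_cvm_concat op.
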
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
if (type == "sequence_pool") {
op->SetInput("X", {inputs[0]});
std::string pooltype = "SUM";
op->SetAttr("pooltype", pooltype);
op->SetOutput("MaxIndex", {outputs[0]});
op->SetOutput("Out", {outputs[1]});
} else if (type == "concat") {
op->SetInput("X", inputs);
op->SetAttr("axis", 1);
op->SetOutput("Out", {outputs[0]});
} else if (type == "cvm") {
op->SetInput("X", {inputs[0]});
op->SetInput("CVM", {inputs[1]});
op->SetOutput("Y", {outputs[0]});
op->SetAttr("use_cvm", true);
} else {
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
}
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
}
int CountOpType(const ir::Graph* graph,
const std::string& op_type = "fusion_seqpool_cvm_concat") {
int count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == op_type) {
++count;
}
}
return count;
}
std::unique_ptr<ir::Graph> GetNumNodesOfBeforeAfter(
std::unique_ptr<ir::Graph> graph, int* before, int* after,
const std::string& pass_type = "seqpool_cvm_concat_fuse_pass") {
auto pass = PassRegistry::Instance().Get(pass_type);
*before = graph->Nodes().size();
graph.reset(pass->Apply(graph.release()));
*after = graph->Nodes().size();
return graph;
}
/*
* Before fuse:
*
*
* a b c
* | | |
* op1 op2 op3
* / \ / \ / \
* d e n f g n h i n
* | / | / | /
* op4 op5 op6
* | | |
* j     k    l
* \ | /
* concat
* |
* m
*
* op1, op2 and op3 are sequence_pool ops with the "SUM" pooltype attr.
* op4, op5 and op6 are cvm ops with use_cvm set to true.
*
* After fuse:
* a b c n
* \ | | /
* fusion_seqpool_cvm_concat
* |
* m
*/
TEST(SeqPoolCVMConcatFusePass, basic) {
ProgramDesc prog;
for (auto& v :
std::vector<std::string>({"a", "b", "c", "d", "e", "f", "g", "h", "i",
"j", "k", "l", "m", "n"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::LOD_TENSOR);
}
SetOp(&prog, "sequence_pool", std::vector<std::string>({"a"}),
std::vector<std::string>({"d", "e"}));
SetOp(&prog, "sequence_pool", std::vector<std::string>({"b"}),
std::vector<std::string>({"f", "g"}));
SetOp(&prog, "sequence_pool", std::vector<std::string>({"c"}),
std::vector<std::string>({"h", "i"}));
SetOp(&prog, "cvm", std::vector<std::string>({"e", "n"}),
std::vector<std::string>({"j"}));
SetOp(&prog, "cvm", std::vector<std::string>({"g", "n"}),
std::vector<std::string>({"k"}));
SetOp(&prog, "cvm", std::vector<std::string>({"i", "n"}),
std::vector<std::string>({"l"}));
SetOp(&prog, "concat", std::vector<std::string>({"j", "k", "l"}),
std::vector<std::string>({"m"}));
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int before, after;
graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after);
// Remove 16 Nodes: op1, op2, op3, op4, op5, op6, d, e, f, g, h, i, j, k, l,
// concat_op
// Add 1 Node: fusion_seqpool_cvm_concat
EXPECT_EQ(after, before - 15);
EXPECT_EQ(CountOpType(graph.get()), 1);
}
/*
* Before fuse:
* a b
* | / \
* op1 k op2 k op3
* / \ / / \ / \
* c d e f g
* | |
* op4 op5
* | |
* h i
* \ /
* concat
* |
* j
* op1 and op2 are sequence_pool ops with the "SUM" pooltype attr.
* op4 and op5 are cvm ops with use_cvm set to true.
*
* After fuse:
* a k b
* \ | / \
* fusion_seqpool_cvm_concat op3
* | |
* j g
*/
TEST(SeqPoolCVMConcatFusePass, advanced) {
ProgramDesc prog;
for (auto& v : std::vector<std::string>(
{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::LOD_TENSOR);
}
SetOp(&prog, "sequence_pool", std::vector<std::string>({"a"}),
std::vector<std::string>({"c", "d"}));
SetOp(&prog, "sequence_pool", std::vector<std::string>({"b"}),
std::vector<std::string>({"e", "f"}));
SetOp(&prog, "op3", std::vector<std::string>({"b"}),
std::vector<std::string>({"g"}));
SetOp(&prog, "cvm", std::vector<std::string>({"d", "k"}),
std::vector<std::string>({"h"}));
SetOp(&prog, "cvm", std::vector<std::string>({"f", "k"}),
std::vector<std::string>({"i"}));
SetOp(&prog, "concat", std::vector<std::string>({"h", "i"}),
std::vector<std::string>({"j"}));
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int before, after;
graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after);
// Remove 11 Nodes: op1, op2, op4, op5, c, d, e, f, h, i, concat_op
// Add 1 Node: fusion_seqpool_cvm_concat
EXPECT_EQ(after, before - 10);
EXPECT_EQ(CountOpType(graph.get()), 1);
}
ProgramDesc BuildProgramDesc(int num_inputs_of_concat) {
ProgramDesc prog;
auto new_var = [&](const std::string& name) {
auto* var = prog.MutableBlock(0)->Var(name);
var->SetType(proto::VarType::LOD_TENSOR);
};
std::vector<std::string> concat_inputs;
new_var("cvm_in");
for (int i = 0; i < num_inputs_of_concat; ++i) {
std::string seqpool_prefix = "seqpool_op_" + std::to_string(i);
new_var(seqpool_prefix + "in");
new_var(seqpool_prefix + "out");
new_var(seqpool_prefix + "out_unused");
SetOp(&prog, "sequence_pool",
std::vector<std::string>({seqpool_prefix + "in"}),
std::vector<std::string>(
{seqpool_prefix + "out_unused", seqpool_prefix + "out"}));
std::string cvm_prefix = "cvm_op_" + std::to_string(i);
new_var(cvm_prefix + "out");
SetOp(&prog, "cvm",
std::vector<std::string>({seqpool_prefix + "out", "cvm_in"}),
std::vector<std::string>({cvm_prefix + "out"}));
concat_inputs.push_back(cvm_prefix + "out");
}
SetOp(&prog, "concat", concat_inputs,
std::vector<std::string>({"concat_out"}));
return prog;
}
// test more inputs of concat
TEST(SeqPoolCVMConcatFusePass, more_inputs) {
for (int num : {1, 2, 10}) {
ProgramDesc prog = BuildProgramDesc(num);
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int before, after;
graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after);
// Remove Nodes: n * (seqpool_op, seqpool_out, out_unused, cvm_op, cvm_out),
// and concat_op
// Add Node: fusion_seqpool_cvm_concat op
EXPECT_EQ(after, before - num * 5);
EXPECT_EQ(CountOpType(graph.get()), 1);
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(seqpool_cvm_concat_fuse_pass);
@@ -145,6 +145,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
        "attention_lstm_fuse_pass",       //
        "seqconv_eltadd_relu_fuse_pass",  //
        // "seqpool_concat_fuse_pass",    //
        "seqpool_cvm_concat_fuse_pass",   //
        // "embedding_fc_lstm_fuse_pass", //
        "fc_lstm_fuse_pass",              //
        "mul_lstm_fuse_pass",             //
...
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"
namespace paddle {
namespace operators {
void FusionSeqPoolCVMConcatOp::InferShape(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_GE(
ctx->Inputs("X").size(), 1UL,
"Inputs(X) of FusionSeqPoolCVMConcatOp should not be empty.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FusionSeqPoolCVMConcatOp should not be null.");
int axis = ctx->Attrs().Get<int>("axis");
PADDLE_ENFORCE_EQ(
axis, 1, "FusionSeqPoolCVMConcatOp only supports concat axis=1 yet.");
bool use_cvm = ctx->Attrs().Get<bool>("use_cvm");
PADDLE_ENFORCE_EQ(
use_cvm, true,
"FusionSeqPoolCVMConcatOp only supports use_cvm is true yet.");
auto ins_dims = ctx->GetInputsDim("X");
const size_t n = ins_dims.size();
PADDLE_ENFORCE_GT(n, 0UL, "The number of input tensors should be > 0.");
if (n == 1) {
LOG(WARNING) << "There is only one input, so the fusion may waste memory.";
}
// The output height should be confirmed in Compute,
// since input lod is not accessible here.
PADDLE_ENFORCE_EQ(ins_dims[0].size(), 2,
"The dims size of first input should be 2.");
ctx->SetOutputDim("Out", {-1, ins_dims[0][axis] * static_cast<int>(n)});
}
framework::OpKernelType FusionSeqPoolCVMConcatOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]), ctx.GetPlace());
}
void FusionSeqPoolCVMConcatOpMaker::Make() {
AddInput("X", "(LoDTensor) Input tensors of this operator.").AsDuplicable();
AddInput("CVM",
"(Tensor), a 2-D Tensor with shape [N x 2], where N is the batch "
"size, 2 is show and click.");
AddOutput("Out", "(LoDTensor) Output tensor of concat operator.");
AddAttr<std::string>("pooltype",
                     "(string, default 'SUM') the pooling type used by "
                     "SequencePoolOp, one of AVERAGE, SUM and SQRT.")
    .SetDefault("SUM")
    .InEnum({"AVERAGE", "SUM", "SQRT"});
AddAttr<bool>("use_cvm", "bool, use cvm or not").SetDefault(true);
AddAttr<int>("axis",
"The axis along which the input tensors will be concatenated. "
"Only supports concat axis=1 yet.")
.SetDefault(1);
AddComment(R"DOC(
Fusion of the SequencePool (with SUM, AVERAGE or SQRT pooltype), CVM and
Concat operators.
)DOC");
}
template <typename T>
class FusionSeqPoolCVMConcatKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<LoDTensor>("X");
auto* out = ctx.Output<LoDTensor>("Out");
std::string pooltype = ctx.Attr<std::string>("pooltype");
auto x0_lod = ins[0]->lod();
auto x0_dims = ins[0]->dims();
auto y_dims = out->dims();
size_t bs = x0_lod[0].size() - 1;
out->Resize({static_cast<int64_t>(bs), y_dims[1]});
framework::LoD y_lod(1);
y_lod[0].resize(bs + 1);
for (size_t i = 0; i <= bs; ++i) {
y_lod[0][i] = i;
}
out->set_lod(y_lod);
auto place = ctx.GetPlace();
T* y_data = out->mutable_data<T>(place);
int w = ins[0]->numel() / x0_dims[0];
PADDLE_ENFORCE_EQ(y_dims[1] % w, 0,
                  "The output dims[1] should be divisible by w.");
jit::seq_pool_attr_t attr(w, jit::SeqPoolType::kSum);
if (pooltype == "AVERAGE") {
attr.type = jit::SeqPoolType::kAvg;
} else if (pooltype == "SQRT") {
attr.type = jit::SeqPoolType::kSqrt;
}
auto seqpool =
jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache().At(
attr);
size_t n = ins.size();
size_t dst_step_size = n * w;
for (size_t i = 0; i < n; ++i) {
auto x_dims = ins[i]->dims();
auto x_lod = ins[i]->lod()[0];
const T* src = ins[i]->data<T>();
T* dst = y_data + i * w;
PADDLE_ENFORCE_EQ(static_cast<int>(ins[i]->numel() / x_dims[0]), w,
"Width of all inputs should be equal.");
PADDLE_ENFORCE_EQ(x_lod.size(), bs + 1,
"Batchsize of all inputs should be equal.");
for (size_t j = 0; j < bs; ++j) {
attr.h = static_cast<int>(x_lod[j + 1] - x_lod[j]);
seqpool(src, dst, &attr);
// Apply the CVM transform (only use_cvm == true is supported yet):
// out0 = log(show + 1), out1 = log(click + 1) - log(show + 1).
dst[0] = log(dst[0] + 1);
dst[1] = log(dst[1] + 1) - dst[0];
dst += dst_step_size;
src += attr.h * attr.w;
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_seqpool_cvm_concat, ops::FusionSeqPoolCVMConcatOp,
ops::FusionSeqPoolCVMConcatOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(fusion_seqpool_cvm_concat,
ops::FusionSeqPoolCVMConcatKernel<float>,
ops::FusionSeqPoolCVMConcatKernel<double>);
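
End to end, the kernel above is equivalent to the following NumPy sketch (the helper name is ours, not part of the commit; only the SUM pooltype is shown, which is what the fuse pass generates):

import numpy as np

def ref_seqpool_cvm_concat(xs, seq_lens):
    # xs: list of [total_rows, w] float arrays; seq_lens: list of
    # per-input sequence-length lists, all sharing one batch size.
    outs = []
    for x, lens in zip(xs, seq_lens):
        offsets = np.cumsum([0] + list(lens))
        # SUM-pool each sequence down to one row of width w.
        pooled = np.stack([x[offsets[j]:offsets[j + 1]].sum(axis=0)
                           for j in range(len(lens))])
        # CVM transform; the second column uses the already transformed
        # first column, i.e. log(click + 1) - log(show + 1), exactly as
        # in the kernel loop.
        pooled[:, 0] = np.log(pooled[:, 0] + 1)
        pooled[:, 1] = np.log(pooled[:, 1] + 1) - pooled[:, 0]
        outs.append(pooled)
    # Concatenate the pooled blocks along axis 1, matching dst_step_size.
    return np.concatenate(outs, axis=1)
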
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class FusionSeqPoolCVMConcatOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FusionSeqPoolCVMConcatOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from test_reorder_lod_tensor import convert_to_offset
from test_seq_pool import compute_seqpool_sum, compute_seqpool_avg, compute_seqpool_sqrt
from test_cvm_op import cvm_compute
class TestFusionSeqPoolCVMConcatOp(OpTest):
def setUp(self):
self.w = 11
self.use_cvm = True
self.lods = [[[2, 3, 5]], [[1, 5, 2]]]
self.set_conf()
self.set_pooltype()
self.op_type = 'fusion_seqpool_cvm_concat'
self.axis = 1
bs = len(self.lods[0][0])
inputs = []
outs = []
# The cvm variable is not actually used.
cvm = np.array([[0.6, 0.4]]).astype("float32")
i = 0
for lod in self.lods:
assert bs == len(lod[0]), 'All lod size should be equal'
x = np.random.uniform(0.1, 1,
[sum(lod[0]), self.w]).astype('float32')
offset = convert_to_offset(lod)
out = np.zeros((bs, self.w)).astype('float32')
if self.pooltype == "SUM":
compute_seqpool_sum(x, offset, out)
out = cvm_compute(out, self.w, self.use_cvm)
elif self.pooltype == "AVERAGE":
compute_seqpool_avg(x, offset, out)
out = cvm_compute(out, self.w, self.use_cvm)
elif self.pooltype == "SQRT":
compute_seqpool_sqrt(x, offset, out)
out = cvm_compute(out, self.w, self.use_cvm)
else:
raise Exception("Unsupported pool type!")
inputs.append(('x_{0}'.format(i), (x, lod)))
outs.append(out)
i = i + 1
self.inputs = {'X': inputs, "CVM": cvm}
self.outputs = {'Out': np.concatenate(outs, axis=self.axis)}
self.attrs = {
'pooltype': self.pooltype,
'axis': self.axis,
}
def set_pooltype(self):
self.pooltype = "SUM"
def set_conf(self):
pass
def test_check_output(self):
self.check_output()
class TestFusionSeqPoolCVMConcatOpCase1(TestFusionSeqPoolCVMConcatOp):
def set_conf(self):
self.lods = [[[1]]]
class TestFusionSeqPoolCVMConcatOpCase2(TestFusionSeqPoolCVMConcatOp):
def set_conf(self):
self.lods = [[[1]], [[1]], [[1]]]
class TestFusionSeqPoolCVMConcatOpCase3(TestFusionSeqPoolCVMConcatOp):
def set_conf(self):
self.lods = [[[1, 3, 4, 6]]]
self.w = 10
class TestFusionSeqPoolCVMConcatOpCase4(TestFusionSeqPoolCVMConcatOp):
def set_conf(self):
self.lods = [[[2, 13, 4]], [[1, 1, 1]], [[5, 3, 1]], [[9, 10, 3]]]
self.w = 3
## test avg pool and sqrt
def create_test_avg_sqrt_class(parent):
class TestSeqPoolAvgCase(parent):
def set_pooltype(self):
self.pooltype = "AVERAGE"
class TestSeqPoolSqrtCase(parent):
def set_pooltype(self):
self.pooltype = "SQRT"
cls_name_avg = "{0}_{1}".format(parent.__name__, "avg")
cls_name_sqrt = "{0}_{1}".format(parent.__name__, "sqrt")
TestSeqPoolAvgCase.__name__ = cls_name_avg
TestSeqPoolSqrtCase.__name__ = cls_name_sqrt
globals()[cls_name_avg] = TestSeqPoolAvgCase
globals()[cls_name_sqrt] = TestSeqPoolSqrtCase
create_test_avg_sqrt_class(TestFusionSeqPoolCVMConcatOp)
create_test_avg_sqrt_class(TestFusionSeqPoolCVMConcatOpCase1)
create_test_avg_sqrt_class(TestFusionSeqPoolCVMConcatOpCase2)
create_test_avg_sqrt_class(TestFusionSeqPoolCVMConcatOpCase3)
create_test_avg_sqrt_class(TestFusionSeqPoolCVMConcatOpCase4)
if __name__ == '__main__':
unittest.main()
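
As a quick sanity check of the semantics these tests exercise, a worked example with made-up numbers for a single input of width 3 and sequence lengths [2, 1]:

import numpy as np

x = np.array([[1., 2., 3.],
              [4., 5., 6.],
              [7., 8., 9.]])
# SUM pooling with lod offsets [0, 2, 3]: two sequences of lengths 2 and 1.
pooled = np.stack([x[0:2].sum(axis=0),   # [5., 7., 9.]
                   x[2:3].sum(axis=0)])  # [7., 8., 9.]
# CVM transform on the first two columns, as in the fused kernel.
pooled[:, 0] = np.log(pooled[:, 0] + 1)
pooled[:, 1] = np.log(pooled[:, 1] + 1) - pooled[:, 0]
# With a single input the concat is a no-op, so `pooled` is Out:
# row 0: [log(6), log(8) - log(6), 9.0]
# row 1: [log(8), log(9) - log(8), 9.0]
print(pooled)
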