提交 7ab5626d 编写于 作者: J Jacek Czaja

- Added initial pass for embedding-fc-lstm

- Added draft of new operator

- Added fused embedding fc lstm files

- First time embedding_fc_lstm_fuse_pass was invoked in
  test_text_classification

- Added Embedding pattern

- Not crashing

- Enabled draft of embedding_fc_lstm pass (does it job)

- First working (Seqcompute only) version

- Removed diagnostic comment

- First enabling of BatchCompute

- Disabling pass for embedding with is_sparse and is_distributed

- Cosmetics

- Style

- Style
上级 4e81e228
...@@ -34,6 +34,7 @@ endif() ...@@ -34,6 +34,7 @@ endif()
pass_library(attention_lstm_fuse_pass inference) pass_library(attention_lstm_fuse_pass inference)
pass_library(infer_clean_graph_pass inference) pass_library(infer_clean_graph_pass inference)
pass_library(fc_lstm_fuse_pass inference) pass_library(fc_lstm_fuse_pass inference)
pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference) pass_library(fc_gru_fuse_pass inference)
pass_library(seq_concat_fc_fuse_pass inference) pass_library(seq_concat_fc_fuse_pass inference)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace framework {
namespace ir {
static int BuildFusion(Graph* graph, const std::string& name_scope,
Scope* scope, bool with_fc_bias) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Build pattern
PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "x"))
->assert_is_op_input("lookup_table")
->assert_var_not_persistable();
patterns::Embedding embedding_pattern(pattern, name_scope);
// TODO(jczaja): Intermediate can only be for val that are not used anywhere
// but lookup table output may go into other LSTM (for reverse
// direction)
auto* embedding_out = embedding_pattern(x);
patterns::FC fc_pattern(pattern, name_scope);
// fc_out is a tmp var, will be removed after fuse, so marked as intermediate.
auto* fc_out = fc_pattern(embedding_out, with_fc_bias)->AsIntermediate();
patterns::LSTM lstm_pattern(pattern, name_scope);
lstm_pattern(fc_out);
// Create New OpDesc
auto embedding_lstm_creator = [&](Node* embedding, Node* W, Node* lstm,
Node* input, Node* weight_x, Node* weight_h,
Node* bias, Node* hidden, Node* cell,
Node* xx, Node* fc_bias) {
OpDesc op_desc;
op_desc.SetType("fused_embedding_fc_lstm");
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
SET_IN(Ids, input);
SET_IN(WeightH, weight_h);
// Neet to have this passed as We need Wc data for peephole connections
SET_IN(Bias, bias);
#undef SET_IN
// Multiply embeddings with Weights
PADDLE_ENFORCE(scope);
const std::string& embeddings = patterns::UniqueKey("Embeddings");
auto* embeddings_var = scope->Var(embeddings);
PADDLE_ENFORCE(embeddings_var);
auto* embeddings_tensor =
embeddings_var->GetMutable<framework::LoDTensor>();
// Get WeightX size: [single_embedding, fc_size]
// and embedding size: [dict_size, single_embedding]
// and create new size of embeddings eg. [dict_size , hidden_size]
auto* embedding_var = scope->FindVar(W->Name());
PADDLE_ENFORCE(embedding_var);
const auto& embedding_tensor = embedding_var->Get<framework::LoDTensor>();
const auto& weightx_tensor =
scope->FindVar(weight_x->Name())->Get<framework::LoDTensor>();
embeddings_tensor->Resize(
{embedding_tensor.dims()[0], weightx_tensor.dims()[1]});
// Multiplie embeddings via WeightsX and add bias
auto embedding_data = embedding_tensor.data<float>();
auto weightx_data = weightx_tensor.data<float>();
auto embeddings_data =
embeddings_tensor->mutable_data<float>(platform::CPUPlace());
// Adding biases to GEMM result to be
auto* lstm_bias_var = scope->FindVar(bias->Name());
PADDLE_ENFORCE(lstm_bias_var);
const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>();
auto alpha = 1.0f;
auto beta = 1.0f;
int m = embedding_tensor.dims()[0];
int n = weightx_tensor.dims()[1];
int k = embedding_tensor.dims()[1];
// Copy only gate biases values (only actual bias data, not peephole
// weights)
std::vector<float> combined_biases(n, 0.0f);
memcpy(&combined_biases[0], lstm_bias_tensor.data<float>(),
n * sizeof(float));
if (with_fc_bias) {
// Add FC-bias with LSTM-bias (into GEMM result to be)
auto* fc_bias_var = scope->FindVar(fc_bias->Name());
const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
for (int i = 0; i < fc_bias_tensor.numel(); i++) {
combined_biases[i] =
lstm_bias_tensor.data<float>()[i] + fc_bias_tensor.data<float>()[i];
}
}
// broadcast biases
std::vector<float> ones(m, 1.0f);
paddle::operators::math::CBlas<float>::GEMM(
CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, 1, alpha, &ones[0], 1,
&combined_biases[0], n, 0.0f, embeddings_data, n);
// Wx*embeddings
paddle::operators::math::CBlas<float>::GEMM(
CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha,
embedding_data, k, weightx_data, n, beta, embeddings_data, n);
op_desc.SetInput("Embeddings", {embeddings});
// Create temp variables.
const std::string BatchedInput = patterns::UniqueKey("BatchedInput");
const std::string BatchedCellPreAct =
patterns::UniqueKey("BatchedCellPreAct");
const std::string BatchedGate = patterns::UniqueKey("BatchedGate");
scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>();
scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>();
scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>();
op_desc.SetInput("H0", {});
op_desc.SetInput("C0", {});
op_desc.SetOutput("Hidden", {hidden->Name()});
op_desc.SetOutput("Cell", {cell->Name()});
op_desc.SetOutput("XX", {xx->Name()});
op_desc.SetOutput("BatchedGate", {BatchedGate});
op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
op_desc.SetOutput("BatchedInput", {BatchedInput});
op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
// TODO(TJ): get from attr
op_desc.SetAttr("use_seq", true);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
#define OP_SET_OUT(x) \
const std::string x = patterns::UniqueKey(#x); \
op_desc.SetOutput(#x, {x}); \
scope->Var(x)->GetMutable<LoDTensor>()
OP_SET_OUT(BatchedCell);
OP_SET_OUT(BatchedHidden);
OP_SET_OUT(ReorderedH0);
OP_SET_OUT(ReorderedC0);
#undef OP_SET_OUT
auto* op = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(input, op);
IR_NODE_LINK_TO(weight_x, op);
IR_NODE_LINK_TO(weight_h, op);
IR_NODE_LINK_TO(bias, op);
IR_NODE_LINK_TO(op, hidden);
return op;
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table, lookup_table, embedding_pattern);
GET_IR_NODE_FROM_SUBGRAPH(W, W, embedding_pattern);
GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
// TODO(jczaja): Add support for is_sparse / is_distributed
auto is_sparse = boost::get<bool>(lookup_table->Op()->GetAttr("is_sparse"));
auto is_distributed =
boost::get<bool>(lookup_table->Op()->GetAttr("is_distributed"));
if (is_sparse == true || is_distributed == true) {
return;
}
if (with_fc_bias) {
GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
Bias, Hidden, Cell, fc_out, fc_bias);
// Remove unneeded nodes.
// TODO(jczaja): Proper removing of loopup table
std::unordered_set<const Node*> marked_nodes(
//{lookup_table, mul, lstm, elementwise_add, fc_bias, W});
{mul, lstm, elementwise_add, fc_bias});
GraphSafeRemoveNodes(graph, marked_nodes);
} else {
GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
Bias, Hidden, Cell, fc_out, nullptr);
// Remove unneeded nodes.
// TODO(jczaja): Proper removing of loopup table
// std::unordered_set<const Node*> marked_nodes({lookup_table, W, mul,
// lstm});
std::unordered_set<const Node*> marked_nodes({mul, lstm});
GraphSafeRemoveNodes(graph, marked_nodes);
}
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
std::unique_ptr<ir::Graph> EmbeddingFCLSTMFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
true /*with_fc_bias*/);
AddStatis(fusion_count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(embedding_fc_lstm_fuse_pass,
paddle::framework::ir::EmbeddingFCLSTMFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
// Fusing of Embedding , FC and LSTM op
// Just FC without bias
class EmbeddingFCLSTMFusePass : public FusePassBase {
public:
virtual ~EmbeddingFCLSTMFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"embedding_fc_lstm_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -692,6 +692,24 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x, ...@@ -692,6 +692,24 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
} }
} }
PDNode *patterns::Embedding::operator()(PDNode *x) {
x->assert_is_op_input("lookup_table", "Ids");
auto *lookup_table_op =
pattern->NewNode(lookup_table_repr())->assert_is_op("lookup_table");
#define NEW_NODE(arg__, io__) \
auto *arg__ = pattern->NewNode(arg__##_repr()) \
->assert_is_op_##io__("lookup_table", #arg__);
NEW_NODE(W, input);
NEW_NODE(Out, output);
#undef NEW_NODE
lookup_table_op->LinksFrom({x, W});
lookup_table_op->LinksTo({Out});
return Out;
}
PDNode *patterns::LSTM::operator()(PDNode *x) { PDNode *patterns::LSTM::operator()(PDNode *x) {
x->assert_is_op_input("lstm", "Input"); x->assert_is_op_input("lstm", "Input");
auto *lstm_op = pattern->NewNode(lstm_repr())->assert_is_op("lstm"); auto *lstm_op = pattern->NewNode(lstm_repr())->assert_is_op("lstm");
......
...@@ -418,6 +418,23 @@ struct FC : public PatternBase { ...@@ -418,6 +418,23 @@ struct FC : public PatternBase {
PATTERN_DECL_NODE(Out); PATTERN_DECL_NODE(Out);
}; };
// Embedding
struct Embedding : public PatternBase {
Embedding(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "embedding") {}
PDNode* operator()(PDNode* x);
// declare operator node's name
PATTERN_DECL_NODE(lookup_table);
// Inputs
//
PATTERN_DECL_NODE(Ids);
PATTERN_DECL_NODE(W); // embeddings
// Outputs
PATTERN_DECL_NODE(Out);
};
struct LSTM : public PatternBase { struct LSTM : public PatternBase {
LSTM(PDPattern* pattern, const std::string& name_scope) LSTM(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "lstm") {} : PatternBase(pattern, name_scope, "lstm") {}
......
...@@ -66,6 +66,7 @@ class Analyzer : public OrderedRegistry<PassManager> { ...@@ -66,6 +66,7 @@ class Analyzer : public OrderedRegistry<PassManager> {
// Manual update the passes here. // Manual update the passes here.
"infer_clean_graph_pass", // "infer_clean_graph_pass", //
"attention_lstm_fuse_pass", // "attention_lstm_fuse_pass", //
"embedding_fc_lstm_fuse_pass", //
"fc_lstm_fuse_pass", // "fc_lstm_fuse_pass", //
"mul_lstm_fuse_pass", // "mul_lstm_fuse_pass", //
"fc_gru_fuse_pass", // "fc_gru_fuse_pass", //
......
此差异已折叠。
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class FusedEmbeddingFCLSTMOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FusedEmbeddingFCLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册