diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 4dca3ceb4569fb708c7a98621c5239acbe217586..01733fdda20a2c3ac1598c1e380a955d4fae0535 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -34,6 +34,7 @@ endif() pass_library(attention_lstm_fuse_pass inference) pass_library(infer_clean_graph_pass inference) pass_library(fc_lstm_fuse_pass inference) +pass_library(embedding_fc_lstm_fuse_pass inference) pass_library(fc_gru_fuse_pass inference) pass_library(seq_concat_fc_fuse_pass inference) diff --git a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..38495125c346cbaa631b4891822201413996b136 --- /dev/null +++ b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc @@ -0,0 +1,242 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h" +#include +#include "paddle/fluid/framework/lod_tensor.h" + +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/cpu_vec.h" +#include "paddle/fluid/operators/math/fc_compute.h" +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace framework { +namespace ir { + +static int BuildFusion(Graph* graph, const std::string& name_scope, + Scope* scope, bool with_fc_bias) { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + + // Build pattern + PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "x")) + ->assert_is_op_input("lookup_table") + ->assert_var_not_persistable(); + patterns::Embedding embedding_pattern(pattern, name_scope); + // TODO(jczaja): Intermediate can only be for val that are not used anywhere + // but lookup table output may go into other LSTM (for reverse + // direction) + auto* embedding_out = embedding_pattern(x); + patterns::FC fc_pattern(pattern, name_scope); + + // fc_out is a tmp var, will be removed after fuse, so marked as intermediate. + auto* fc_out = fc_pattern(embedding_out, with_fc_bias)->AsIntermediate(); + patterns::LSTM lstm_pattern(pattern, name_scope); + lstm_pattern(fc_out); + + // Create New OpDesc + auto embedding_lstm_creator = [&](Node* embedding, Node* W, Node* lstm, + Node* input, Node* weight_x, Node* weight_h, + Node* bias, Node* hidden, Node* cell, + Node* xx, Node* fc_bias) { + OpDesc op_desc; + op_desc.SetType("fused_embedding_fc_lstm"); +#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()}); + SET_IN(Ids, input); + SET_IN(WeightH, weight_h); + // Neet to have this passed as We need Wc data for peephole connections + SET_IN(Bias, bias); +#undef SET_IN + + // Multiply embeddings with Weights + PADDLE_ENFORCE(scope); + const std::string& embeddings = patterns::UniqueKey("Embeddings"); + auto* embeddings_var = scope->Var(embeddings); + PADDLE_ENFORCE(embeddings_var); + auto* embeddings_tensor = + embeddings_var->GetMutable(); + // Get WeightX size: [single_embedding, fc_size] + // and embedding size: [dict_size, single_embedding] + // and create new size of embeddings eg. [dict_size , hidden_size] + auto* embedding_var = scope->FindVar(W->Name()); + PADDLE_ENFORCE(embedding_var); + const auto& embedding_tensor = embedding_var->Get(); + + const auto& weightx_tensor = + scope->FindVar(weight_x->Name())->Get(); + embeddings_tensor->Resize( + {embedding_tensor.dims()[0], weightx_tensor.dims()[1]}); + + // Multiplie embeddings via WeightsX and add bias + auto embedding_data = embedding_tensor.data(); + auto weightx_data = weightx_tensor.data(); + auto embeddings_data = + embeddings_tensor->mutable_data(platform::CPUPlace()); + + // Adding biases to GEMM result to be + auto* lstm_bias_var = scope->FindVar(bias->Name()); + PADDLE_ENFORCE(lstm_bias_var); + const auto& lstm_bias_tensor = lstm_bias_var->Get(); + + auto alpha = 1.0f; + auto beta = 1.0f; + int m = embedding_tensor.dims()[0]; + int n = weightx_tensor.dims()[1]; + int k = embedding_tensor.dims()[1]; + + // Copy only gate biases values (only actual bias data, not peephole + // weights) + std::vector combined_biases(n, 0.0f); + memcpy(&combined_biases[0], lstm_bias_tensor.data(), + n * sizeof(float)); + + if (with_fc_bias) { + // Add FC-bias with LSTM-bias (into GEMM result to be) + auto* fc_bias_var = scope->FindVar(fc_bias->Name()); + const auto& fc_bias_tensor = fc_bias_var->Get(); + for (int i = 0; i < fc_bias_tensor.numel(); i++) { + combined_biases[i] = + lstm_bias_tensor.data()[i] + fc_bias_tensor.data()[i]; + } + } + + // broadcast biases + std::vector ones(m, 1.0f); + paddle::operators::math::CBlas::GEMM( + CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, 1, alpha, &ones[0], 1, + &combined_biases[0], n, 0.0f, embeddings_data, n); + + // Wx*embeddings + paddle::operators::math::CBlas::GEMM( + CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha, + embedding_data, k, weightx_data, n, beta, embeddings_data, n); + op_desc.SetInput("Embeddings", {embeddings}); + + // Create temp variables. + const std::string BatchedInput = patterns::UniqueKey("BatchedInput"); + const std::string BatchedCellPreAct = + patterns::UniqueKey("BatchedCellPreAct"); + const std::string BatchedGate = patterns::UniqueKey("BatchedGate"); + + scope->Var(BatchedInput)->GetMutable(); + scope->Var(BatchedCellPreAct)->GetMutable(); + scope->Var(BatchedGate)->GetMutable(); + + op_desc.SetInput("H0", {}); + op_desc.SetInput("C0", {}); + op_desc.SetOutput("Hidden", {hidden->Name()}); + op_desc.SetOutput("Cell", {cell->Name()}); + op_desc.SetOutput("XX", {xx->Name()}); + op_desc.SetOutput("BatchedGate", {BatchedGate}); + op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct}); + op_desc.SetOutput("BatchedInput", {BatchedInput}); + op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse")); + op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes")); + // TODO(TJ): get from attr + op_desc.SetAttr("use_seq", true); + + PADDLE_ENFORCE(graph->Has(kParamScopeAttr)); + auto* scope = graph->Get(kParamScopeAttr); +#define OP_SET_OUT(x) \ + const std::string x = patterns::UniqueKey(#x); \ + op_desc.SetOutput(#x, {x}); \ + scope->Var(x)->GetMutable() + OP_SET_OUT(BatchedCell); + OP_SET_OUT(BatchedHidden); + OP_SET_OUT(ReorderedH0); + OP_SET_OUT(ReorderedC0); +#undef OP_SET_OUT + + auto* op = graph->CreateOpNode(&op_desc); + IR_NODE_LINK_TO(input, op); + IR_NODE_LINK_TO(weight_x, op); + IR_NODE_LINK_TO(weight_h, op); + IR_NODE_LINK_TO(bias, op); + IR_NODE_LINK_TO(op, hidden); + return op; + }; + + int fusion_count{0}; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, lstm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table, lookup_table, embedding_pattern); + GET_IR_NODE_FROM_SUBGRAPH(W, W, embedding_pattern); + GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern); + + // TODO(jczaja): Add support for is_sparse / is_distributed + auto is_sparse = boost::get(lookup_table->Op()->GetAttr("is_sparse")); + auto is_distributed = + boost::get(lookup_table->Op()->GetAttr("is_distributed")); + + if (is_sparse == true || is_distributed == true) { + return; + } + + if (with_fc_bias) { + GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern); + embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight, + Bias, Hidden, Cell, fc_out, fc_bias); + // Remove unneeded nodes. + // TODO(jczaja): Proper removing of loopup table + std::unordered_set marked_nodes( + //{lookup_table, mul, lstm, elementwise_add, fc_bias, W}); + {mul, lstm, elementwise_add, fc_bias}); + GraphSafeRemoveNodes(graph, marked_nodes); + } else { + GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern); + embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight, + Bias, Hidden, Cell, fc_out, nullptr); + // Remove unneeded nodes. + // TODO(jczaja): Proper removing of loopup table + // std::unordered_set marked_nodes({lookup_table, W, mul, + // lstm}); + std::unordered_set marked_nodes({mul, lstm}); + GraphSafeRemoveNodes(graph, marked_nodes); + } + + ++fusion_count; + }; + + gpd(graph, handler); + + return fusion_count; +} + +std::unique_ptr EmbeddingFCLSTMFusePass::ApplyImpl( + std::unique_ptr graph) const { + FusePassBase::Init(name_scope_, graph.get()); + + int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(), + true /*with_fc_bias*/); + + AddStatis(fusion_count); + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(embedding_fc_lstm_fuse_pass, + paddle::framework::ir::EmbeddingFCLSTMFusePass); diff --git a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..e5ad3067ec4060e41f1464395f3fc76183de3e66 --- /dev/null +++ b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h @@ -0,0 +1,40 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +// Fusing of Embedding , FC and LSTM op + +// Just FC without bias +class EmbeddingFCLSTMFusePass : public FusePassBase { + public: + virtual ~EmbeddingFCLSTMFusePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; + + const std::string name_scope_{"embedding_fc_lstm_fuse"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 6d2c51b0e9bed8461f6491b84a36a3bf6663a138..46c6a52c09e896596aa6d8e1e901955a68a4957d 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -692,6 +692,24 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x, } } +PDNode *patterns::Embedding::operator()(PDNode *x) { + x->assert_is_op_input("lookup_table", "Ids"); + auto *lookup_table_op = + pattern->NewNode(lookup_table_repr())->assert_is_op("lookup_table"); +#define NEW_NODE(arg__, io__) \ + auto *arg__ = pattern->NewNode(arg__##_repr()) \ + ->assert_is_op_##io__("lookup_table", #arg__); + + NEW_NODE(W, input); + + NEW_NODE(Out, output); +#undef NEW_NODE + + lookup_table_op->LinksFrom({x, W}); + lookup_table_op->LinksTo({Out}); + return Out; +} + PDNode *patterns::LSTM::operator()(PDNode *x) { x->assert_is_op_input("lstm", "Input"); auto *lstm_op = pattern->NewNode(lstm_repr())->assert_is_op("lstm"); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 69b486c29d8bd1102a8372f5041051c25ce19359..508113bf4fcab274394f2705c36eddbf4ba3c77a 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -418,6 +418,23 @@ struct FC : public PatternBase { PATTERN_DECL_NODE(Out); }; +// Embedding +struct Embedding : public PatternBase { + Embedding(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "embedding") {} + + PDNode* operator()(PDNode* x); + + // declare operator node's name + PATTERN_DECL_NODE(lookup_table); + // Inputs + // + PATTERN_DECL_NODE(Ids); + PATTERN_DECL_NODE(W); // embeddings + // Outputs + PATTERN_DECL_NODE(Out); +}; + struct LSTM : public PatternBase { LSTM(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "lstm") {} diff --git a/paddle/fluid/inference/analysis/analyzer.h b/paddle/fluid/inference/analysis/analyzer.h index 9bdbefc07cbc4bf7a4714927c84855837610430e..0aa9367bf5692e53e2a1f1247523cf9a4f0b3a1d 100644 --- a/paddle/fluid/inference/analysis/analyzer.h +++ b/paddle/fluid/inference/analysis/analyzer.h @@ -64,14 +64,15 @@ class Analyzer : public OrderedRegistry { // larger fusion. const std::vector all_ir_passes_{{ // Manual update the passes here. - "infer_clean_graph_pass", // - "attention_lstm_fuse_pass", // - "fc_lstm_fuse_pass", // - "mul_lstm_fuse_pass", // - "fc_gru_fuse_pass", // - "mul_gru_fuse_pass", // - "seq_concat_fc_fuse_pass", // - "fc_fuse_pass", // + "infer_clean_graph_pass", // + "attention_lstm_fuse_pass", // + "embedding_fc_lstm_fuse_pass", // + "fc_lstm_fuse_pass", // + "mul_lstm_fuse_pass", // + "fc_gru_fuse_pass", // + "mul_gru_fuse_pass", // + "seq_concat_fc_fuse_pass", // + "fc_fuse_pass", // #ifdef PADDLE_WITH_MKLDNN "conv_relu_mkldnn_fuse_pass", // #endif diff --git a/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..3c4cc7745256b10db43dfbdc980ed19511bdd39f --- /dev/null +++ b/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc @@ -0,0 +1,608 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fused_embedding_fc_lstm_op.h" +#include +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/cpu_vec.h" +#include "paddle/fluid/operators/math/fc_compute.h" +#include "paddle/fluid/operators/math/sequence2batch.h" +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { + +void FusedEmbeddingFCLSTMOp::InferShape( + framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE(ctx->HasInput("Embeddings"), + "Assert only one Input(Embeddings) of LSTM."); + PADDLE_ENFORCE(ctx->HasInput("WeightH"), + "Assert only one Input(WeightH) of LSTM."); + PADDLE_ENFORCE(ctx->HasInput("Bias"), "Assert only one Input(Bias) of LSTM."); + PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of LSTM."); + PADDLE_ENFORCE(ctx->HasOutput("Hidden"), + "Assert only one Output(Hidden) of LSTM."); + PADDLE_ENFORCE(ctx->HasOutput("Cell"), + "Assert only one Output(Cell) of LSTM."); + PADDLE_ENFORCE(ctx->HasInput("Ids"), + "Input(Ids) of LookupTableOp should not be null."); + + auto table_dims = ctx->GetInputDim("Embeddings"); + auto ids_dims = ctx->GetInputDim("Ids"); + int ids_rank = ids_dims.size(); + + PADDLE_ENFORCE_EQ(table_dims.size(), 2); + PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1, + "The last dimension of the 'Ids' tensor must be 1."); + + auto x_dims = ctx->GetInputDim("Ids"); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(Ids)'s rank must be 2."); + + if (ctx->HasInput("H0")) { + PADDLE_ENFORCE(ctx->HasInput("C0"), + "Input(Cell) and Input(Hidden) of LSTM should not " + "be null at the same time."); + auto h_dims = ctx->GetInputDim("H0"); + auto c_dims = ctx->GetInputDim("C0"); + PADDLE_ENFORCE(h_dims == c_dims, + "The dimension of Input(H0) and Input(C0) " + "should be the same."); + } + + auto embeddings_dims = ctx->GetInputDim("Embeddings"); + PADDLE_ENFORCE_EQ(embeddings_dims.size(), 2, + "The rank of Input(Embeddings) should be 2."); + // PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1], + // "The first dimension of Input(Embeddings) " + // "should be %d.", + // x_dims[1]); + + auto wh_dims = ctx->GetInputDim("WeightH"); + int frame_size = wh_dims[1] / 4; + PADDLE_ENFORCE_EQ(wh_dims.size(), 2, + "The rank of Input(WeightH) should be 2."); + PADDLE_ENFORCE_EQ(wh_dims[0], frame_size, + "The first dimension of Input(WeightH) " + "should be %d.", + frame_size); + PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size, + "The second dimension of Input(WeightH) " + "should be 4 * %d.", + frame_size); + + auto b_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); + PADDLE_ENFORCE_EQ(b_dims[0], 1, + "The first dimension of Input(Bias) should be 1."); + PADDLE_ENFORCE_EQ( + b_dims[1], (ctx->Attrs().Get("use_peepholes") ? 7 : 4) * frame_size, + "The second dimension of Input(Bias) should be " + "7 * %d if enable peepholes connection or" + "4 * %d if disable peepholes", + frame_size, frame_size); + + framework::DDim out_dims({x_dims[0], frame_size}); + ctx->SetOutputDim("Hidden", out_dims); + ctx->SetOutputDim("Cell", out_dims); + ctx->ShareLoD("Ids", "Hidden"); + ctx->ShareLoD("Ids", "Cell"); + int xx_width; + if (ctx->Attrs().Get("use_seq")) { + xx_width = wh_dims[1]; + } else { + xx_width = x_dims[1] > wh_dims[1] ? wh_dims[1] : x_dims[1]; + PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), + "Assert only one Output(BatchedInput) of LSTM."); + PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"), + "Assert only one Output(BatchedHidden) of LSTM."); + PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"), + "Assert only one Output(BatchedCell) of LSTM."); + PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), + "Assert only one Output(ReorderedH0) of LSTM"); + PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"), + "Assert only one Output(ReorderedC0) of LSTM."); + ctx->SetOutputDim("BatchedInput", {x_dims[0], wh_dims[1]}); + ctx->SetOutputDim("BatchedHidden", out_dims); + ctx->SetOutputDim("BatchedCell", out_dims); + } + ctx->SetOutputDim("XX", {x_dims[0], xx_width}); + ctx->ShareLoD("Ids", "XX"); +} + +framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + return framework::OpKernelType( + framework::ToDataType( + ctx.Input("Embeddings")->type()), + ctx.device_context()); +} + +void FusedEmbeddingFCLSTMOpMaker::Make() { + AddInput("Ids", + "An input with type int32 or int64 " + "contains the ids to be looked up in W. " + "The last dimension size must be 1."); + AddInput("Embeddings", + "(Tensor) the learnable weights of X." + " - The shape is (M x 4D), where M is the dim size of x, D is the " + "hidden size. " + " - Weight = {W_cx, W_ix, W_fx, W_ox}"); + AddInput("WeightH", + "(Tensor) same as LSTMOp, the learnable hidden-hidden weights." + " - The shape is (D x 4D), where D is the hidden size. " + " - Weight = {W_ch, W_ih, W_fh, W_oh}"); + AddInput("Bias", + "(Tensor) the learnable weights. Almost same as LSTMOp" + "Note: we should add the fc bias into this (1x4D) in bias." + "input-hidden bias weight and peephole connections weight if " + "setting `use_peepholes` True. " + "1. `use_peepholes = False` " + " - The shape is (1 x 4D). " + " - Bias = {b_c, b_i, b_f, b_o}." + "2. `use_peepholes = True` " + " - The shape is (1 x 7D). " + " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."); + AddInput("H0", + "(Tensor, optional) (same as LSTMOp) the initial hidden state is an " + "optional " + "input. This is a tensor with shape (N x D), where N is the " + "batch size and D is the hidden size.") + .AsDispensable(); + AddInput("C0", + "(Tensor, optional) (same as LSTMOp) (the initial cell state is an " + "optional " + "input. This is a tensor with shape (N x D), where N is the " + "batch size. `H0` and `C0` can be NULL but only at the same time.") + .AsDispensable(); + AddOutput("Hidden", + "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. " + "The shape is (T x D), and lod is the same with the `Input`."); + AddOutput("Cell", + "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. " + "The shape is (T x D), and lod is the same with the `Input`."); + AddOutput("XX", + "(LoDTensor) the result after X * WeightX (size is T x 4D)" + " or batched_X (size is T x M), this will be automatically chosen," + " where T is the total time steps in this mini-batch," + " D is the hidden size, M is the dim size of x input.") + .AsIntermediate(); + AddOutput("BatchedInput", "(LoDTensor) (T x 4D).").AsIntermediate(); + AddOutput("BatchedHidden", "(LoDTensor) (T x D).").AsIntermediate(); + AddOutput("BatchedCell", "(LoDTensor) (T x D).").AsIntermediate(); + AddOutput("ReorderedH0", "(LoDTensor) (N x D).").AsIntermediate(); + AddOutput("ReorderedC0", "(LoDTensor) (N x D).").AsIntermediate(); + AddAttr("use_peepholes", + "(bool, defalut: True) " + "whether to enable diagonal/peephole connections.") + .SetDefault(true); + AddAttr("is_reverse", + "(bool, defalut: False) " + "whether to compute reversed LSTM.") + .SetDefault(false); + AddAttr("use_seq", + "(bool, defalut: True) " + "whether to use seq mode to compute.") + .SetDefault(true); + AddAttr("gate_activation", + "(string, default: sigmoid)" + "The activation for input gate, forget gate and output " + "gate, `sigmoid` by default.") + .SetDefault("sigmoid") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddAttr("cell_activation", + "(string, default: tanh)" + "The activation for cell output, `tanh` by defalut.") + .SetDefault("tanh") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddAttr("candidate_activation", + "(string, default: tanh)" + "The activation for candidate hidden state, " + "`tanh` by default.") + .SetDefault("tanh") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddComment(R"DOC( +Fusion Long-Short Term Memory (LSTM) Operator. +This operator fuse the X into LSTM, more details can refer to LSTM op. +)DOC"); +} + +template +class FusedEmbeddingFCLSTMKernel : public framework::OpKernel { + public: +#define INIT_VEC_FUNC \ + std::function act_gate, act_cell, act_cand; \ + auto& act_gate_str = ctx.Attr("gate_activation"); \ + auto& act_cell_str = ctx.Attr("cell_activation"); \ + auto& act_cand_str = ctx.Attr("candidate_activation"); \ + if (platform::jit::MayIUse(platform::jit::avx)) { \ + math::VecActivations act_functor; \ + act_gate = act_functor(act_gate_str); \ + act_cell = act_functor(act_cell_str); \ + act_cand = act_functor(act_cand_str); \ + } else { \ + math::VecActivations act_functor; \ + act_gate = act_functor(act_gate_str); \ + act_cell = act_functor(act_cell_str); \ + act_cand = act_functor(act_cand_str); \ + } + +#define INIT_BASE_INPUT_OUTPUT \ + auto* ids = ctx.Input("Ids"); \ + auto* h0 = ctx.Input("H0"); \ + auto* c0 = ctx.Input("C0"); \ + auto* embeddings = ctx.Input("Embeddings"); \ + auto* wh = ctx.Input("WeightH"); \ + auto* bias = ctx.Input("Bias"); \ + auto* xx = ctx.Output("XX"); \ + auto* hidden_out = ctx.Output("Hidden"); \ + auto* cell_out = ctx.Output("Cell"); \ + bool is_reverse = ctx.Attr("is_reverse"); \ + bool use_peepholes = ctx.Attr("use_peepholes"); + +#define INIT_BASE_SIZES \ + auto ids_dims = ids->dims(); /* T x M*/ \ + auto ids_numel = ids->numel(); /* T x 1*/ \ + auto wh_dims = wh->dims(); /* D x 4D*/ \ + const int D = wh_dims[0]; \ + const int D2 = D * 2; \ + const int D3 = D * 3; \ + int64_t row_number = embeddings->dims()[0]; \ + int64_t row_width = embeddings->dims()[1]; \ + const int D4 = wh_dims[1]; + +#define INIT_BASE_INPUT_DATAS \ + const int64_t* ids_data = ids->data(); \ + const T* embeddings_data = embeddings->data(); \ + const T* wh_data = wh->data(); \ + /* diagonal weight*/ \ + const T* wc_data = bias->data() + D4; \ + /* for peephole only*/ \ + Tensor checked_cell; \ + T* checked_cell_data = nullptr; \ + auto place = ctx.GetPlace(); \ + if (use_peepholes) { \ + /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ + checked_cell_data = checked_cell.mutable_data({2, D}, place); \ + } + +/// Compute LSTM +#define GEMM_WH_ADDON(bs, prev, out) \ + blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast(1), prev, D, \ + wh_data, D4, static_cast(1), out, D4) + +// gates: W_ch, W_ih, W_fh, W_oh +#define GET_Ct(ct_1, gates, ct) \ + /* C_t = C_t-1 * fgated + cand_gated * igated*/ \ + act_cand(D, gates, gates); \ + blas.VMUL(D, gates, gates + D, gates + D); \ + blas.VMUL(D, ct_1, gates + D2, gates + D2); \ + blas.VADD(D, gates + D, gates + D2, ct) + +#define GET_Ht(ct, gates, ht) \ + /* H_t = act_cell(C_t) * ogated */ \ + act_cell(D, ct, gates + D2); \ + blas.VMUL(D, gates + D2, gates + D3, ht) + +#define GET_Ct_NOH0C0(gates, ct) \ + /* C_t = igated * cgated*/ \ + act_gate(D, gates + D, gates + D); \ + act_cand(D, gates, gates); \ + blas.VMUL(D, gates, gates + D, ct) + +#define COMPUTE_CtHt_NOH0C0(gates, ct, ht) \ + GET_Ct_NOH0C0(gates, ct); \ + act_gate(D, gates + D3, gates + D3); \ + GET_Ht(ct, gates, ht) + +#define COMPUTE_CtHt_PEEPHOLE_NOH0C0(gates, ct, ht) \ + GET_Ct_NOH0C0(gates, ct); \ + /* get outgated, put W_oc * C_t on igated */ \ + blas.VMUL(D, wc_data + D2, ct, gates + D); \ + blas.VADD(D, gates + D, gates + D3, gates + D3); \ + act_gate(D, gates + D3, gates + D3); \ + GET_Ht(ct, gates, ht) + +#define COMPUTE_CtHt(gates, ct_1, ct, ht) \ + act_gate(D3, gates + D, gates + D); \ + GET_Ct(ct_1, gates, ct); \ + GET_Ht(ct, gates, ht) + +#define COMPUTE_CtHt_PEEPHOLE(gates, ct_1, ct, ht) \ + /* get fgated and igated*/ \ + blas.VMUL(D, wc_data, ct_1, checked_cell_data); \ + blas.VMUL(D, wc_data + D, ct_1, checked_cell_data + D); \ + blas.VADD(D2, checked_cell_data, gates + D, gates + D); \ + act_gate(D2, gates + D, gates + D); \ + GET_Ct(ct_1, gates, ct); \ + /* get ogated*/ \ + blas.VMUL(D, wc_data + D2, ct, gates + D); \ + blas.VADD(D, gates + D, gates + D3, gates + D3); \ + act_gate(D, gates + D3, gates + D3); \ + GET_Ht(ct, gates, ht) + + void SeqCompute(const framework::ExecutionContext& ctx) const { + using DeviceContext = paddle::platform::CPUDeviceContext; + INIT_BASE_INPUT_OUTPUT + INIT_BASE_SIZES + INIT_VEC_FUNC + INIT_BASE_INPUT_DATAS + + // std::cout << "====> SeqCompute" << std::endl; + auto ids_lod = ids->lod(); + const int total_T = ids_dims[0]; + const int N = ids_lod[0].size() - 1; + const T* h0_data = h0 ? h0->data() : nullptr; + const T* c0_data = c0 ? c0->data() : nullptr; + T* xx_data = xx->mutable_data(place); + T* h_out_data = hidden_out->mutable_data(place); + T* c_out_data = cell_out->mutable_data(place); + auto blas = math::GetBlas(ctx); + + for (int64_t i = 0; i < ids_numel; ++i) { + PADDLE_ENFORCE_LT(ids_data[i], row_number); + PADDLE_ENFORCE_GE(ids_data[i], 0, "ids %d", i); + memcpy(xx_data + i * row_width, embeddings_data + ids_data[i] * row_width, + row_width * sizeof(T)); + } + + int xx_offset = D4; + int gate_offset = D; + if (is_reverse) { + const int offset = (total_T - 1) * D; + xx_data = xx_data + offset * 4; + h_out_data = h_out_data + offset; + c_out_data = c_out_data + offset; + xx_offset = -D4; + gate_offset = -D; + } + +#define MOVE_ONE_STEP \ + prev_h_data = h_out_data; \ + prev_c_data = c_out_data; \ + xx_data = xx_data + xx_offset; \ + h_out_data = h_out_data + gate_offset; \ + c_out_data = c_out_data + gate_offset + +#define PROCESS_H0C0_DEFINES \ + int bid = is_reverse ? N - 1 - i : i; \ + int seq_len = ids_lod[0][bid + 1] - ids_lod[0][bid]; \ + const T* prev_c_data = nullptr; \ + const T* prev_h_data = nullptr; \ + int tstart = 0 + +#define PROCESS_H0C0_PEEPHOLE \ + PROCESS_H0C0_DEFINES; \ + if (h0_data) { \ + prev_h_data = h0_data + bid * D; \ + prev_c_data = c0_data + bid * D; \ + } else { \ + COMPUTE_CtHt_PEEPHOLE_NOH0C0(xx_data, c_out_data, h_out_data); \ + MOVE_ONE_STEP; \ + tstart = 1; \ + } + +#define PROCESS_H0C0 \ + PROCESS_H0C0_DEFINES; \ + if (h0_data) { \ + prev_h_data = h0_data + bid * D; \ + prev_c_data = c0_data + bid * D; \ + } else { \ + COMPUTE_CtHt_NOH0C0(xx_data, c_out_data, h_out_data); \ + MOVE_ONE_STEP; \ + tstart = 1; \ + } + + if (use_peepholes) { + for (int i = 0; i < N; ++i) { + PROCESS_H0C0_PEEPHOLE + for (int step = tstart; step < seq_len; ++step) { + GEMM_WH_ADDON(1, prev_h_data, xx_data); + COMPUTE_CtHt_PEEPHOLE(xx_data, prev_c_data, c_out_data, h_out_data); + MOVE_ONE_STEP; + } + } + } else { + for (int i = 0; i < N; ++i) { + PROCESS_H0C0 + for (int step = tstart; step < seq_len; ++step) { + GEMM_WH_ADDON(1, prev_h_data, xx_data); + COMPUTE_CtHt(xx_data, prev_c_data, c_out_data, h_out_data); + MOVE_ONE_STEP; + } + } + } +#undef PROCESS_H0C0_DEFINES +#undef PROCESS_H0C0_PEEPHOLE +#undef PROCESS_H0C0 +#undef MOVE_ONE_STEP + } + + void BatchCompute(const framework::ExecutionContext& ctx) const { + using DeviceContext = platform::CPUDeviceContext; + INIT_BASE_INPUT_OUTPUT + if (ids->lod()[0].size() == 2) { + SeqCompute(ctx); + return; + } + INIT_BASE_SIZES + INIT_VEC_FUNC + INIT_BASE_INPUT_DATAS + + // std::cout << "===> Batch Compute" << std::endl; + + auto* reordered_h0 = ctx.Output("ReorderedH0"); + auto* reordered_c0 = ctx.Output("ReorderedC0"); + auto* batched_input = ctx.Output("BatchedInput"); + auto* batched_c_out = ctx.Output("BatchedCell"); + auto* batched_h_out = ctx.Output("BatchedHidden"); + T* xx_data = xx->mutable_data(place); + T* batched_input_data = batched_input->mutable_data(place); + T* batched_c_out_data = batched_c_out->mutable_data(place); + T* batched_h_out_data = batched_h_out->mutable_data(place); + hidden_out->mutable_data(place); + cell_out->mutable_data(place); + + math::LoDTensor2BatchFunctor to_batch; + auto& dev_ctx = ctx.template device_context(); + auto blas = math::GetBlas(dev_ctx); + + for (int64_t i = 0; i < ids_numel; ++i) { + PADDLE_ENFORCE_LT(ids_data[i], row_number); + PADDLE_ENFORCE_GE(ids_data[i], 0, "ids %d", i); + memcpy(xx_data + i * row_width, embeddings_data + ids_data[i] * row_width, + row_width * sizeof(T)); + } + + to_batch(dev_ctx, *xx, batched_input, true, is_reverse); + + auto batched_lod = batched_input->lod(); + const auto& seq_order = batched_lod[2]; + const int max_bs = seq_order.size(); + reordered_h0->Resize({max_bs, D}); + reordered_c0->Resize({max_bs, D}); + + int tstart = 0; + T* prev_h_data = nullptr; + T* prev_c_data = nullptr; + if (h0) { + // reorder h0, c0 + T* reordered_h0_data = reordered_h0->mutable_data(place); + T* reordered_c0_data = reordered_c0->mutable_data(place); + const T* h0_data = h0->data(); + const T* c0_data = c0->data(); + prev_h_data = reordered_h0_data; + prev_c_data = reordered_c0_data; + size_t sz = sizeof(T) * D; + for (int i = 0; i < max_bs; ++i) { + std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz); + std::memcpy(reordered_c0_data, c0_data + seq_order[i] * D, sz); + reordered_h0_data += D; + reordered_c0_data += D; + } + } else { + // compute without h0, c0 + T* cur_in_data = batched_input_data; + T* cur_h_out_data = batched_h_out_data; + T* cur_c_out_data = batched_c_out_data; + for (int i = 0; i < max_bs; ++i) { + GET_Ct_NOH0C0(cur_in_data, cur_c_out_data); + if (use_peepholes) { + blas.VMUL(D, wc_data + D2, cur_c_out_data, cur_in_data + D); + blas.VADD(D, cur_in_data + D, cur_in_data + D3, cur_in_data + D3); + } + act_gate(D, cur_in_data + D3, cur_in_data + D3); + GET_Ht(cur_c_out_data, cur_in_data, cur_h_out_data); + cur_in_data += D4; + cur_c_out_data += D; + cur_h_out_data += D; + } + tstart = 1; + prev_h_data = batched_h_out_data; + prev_c_data = batched_c_out_data; + } + const auto& batch_starts = batched_lod[0]; + const int max_seq_len = batch_starts.size() - 1; + const int offset = tstart * max_bs * D; + batched_input_data = batched_input_data + offset * 4; + batched_h_out_data = batched_h_out_data + offset; + batched_c_out_data = batched_c_out_data + offset; + +#define DEFINE_CUR \ + T* cur_in_data = batched_input_data; \ + T* cur_prev_c_data = prev_c_data; \ + T* cur_c_out_data = batched_c_out_data; \ + T* cur_h_out_data = batched_h_out_data + +#define MOVE_ONE_BATCH \ + cur_in_data += D4; \ + cur_prev_c_data += D; \ + cur_c_out_data += D; \ + cur_h_out_data += D + +#define MOVE_ONE_STEP \ + prev_c_data = batched_c_out_data; \ + prev_h_data = batched_h_out_data; \ + batched_c_out_data = cur_c_out_data; \ + batched_h_out_data = cur_h_out_data; \ + batched_input_data = cur_in_data + + if (use_peepholes) { + for (int step = tstart; step < max_seq_len; ++step) { + const int cur_bs = batch_starts[step + 1] - batch_starts[step]; + GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data); + DEFINE_CUR; + for (int i = 0; i < cur_bs; ++i) { + COMPUTE_CtHt_PEEPHOLE(cur_in_data, cur_prev_c_data, cur_c_out_data, + cur_h_out_data); + MOVE_ONE_BATCH; + } + MOVE_ONE_STEP; + } + } else { + for (int step = tstart; step < max_seq_len; ++step) { + const int cur_bs = batch_starts[step + 1] - batch_starts[step]; + GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data); + DEFINE_CUR; + for (int i = 0; i < cur_bs; ++i) { + COMPUTE_CtHt(cur_in_data, cur_prev_c_data, cur_c_out_data, + cur_h_out_data); + MOVE_ONE_BATCH; + } + MOVE_ONE_STEP; + } + } +#undef MOVE_ONE_STEP +#undef MOVE_ONE_BATCH +#undef DEFINE_CUR + + math::Batch2LoDTensorFunctor to_seq; + batched_h_out->set_lod(batched_lod); + to_seq(dev_ctx, *batched_h_out, hidden_out); + batched_c_out->set_lod(batched_lod); + to_seq(dev_ctx, *batched_c_out, cell_out); + } + + void Compute(const framework::ExecutionContext& ctx) const override { + if (ctx.Attr("use_seq")) { + SeqCompute(ctx); + } else { + BatchCompute(ctx); + } + } + +#undef COMPUTE_CtHt_PEEPHOLE +#undef COMPUTE_CtHt +#undef GET_Ct_NOH0C0 +#undef COMPUTE_CtHt_NOH0C0 +#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0 +#undef GET_Ht +#undef GET_Ct +#undef GEMM_WH_ADDON +#undef INIT_BASE_INPUT_DATAS +#undef INIT_BASE_SIZES +#undef INIT_BASE_INPUT_OUTPUT +#undef INIT_VEC_FUNC +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fused_embedding_fc_lstm, ops::FusedEmbeddingFCLSTMOp, + ops::FusedEmbeddingFCLSTMOpMaker, + paddle::framework::DefaultGradOpDescMaker); + +REGISTER_OP_CPU_KERNEL(fused_embedding_fc_lstm, + ops::FusedEmbeddingFCLSTMKernel, + ops::FusedEmbeddingFCLSTMKernel); diff --git a/paddle/fluid/operators/fused_embedding_fc_lstm_op.h b/paddle/fluid/operators/fused_embedding_fc_lstm_op.h new file mode 100644 index 0000000000000000000000000000000000000000..2775b2ac04d2890355fe6d75a1e2507a2668dc95 --- /dev/null +++ b/paddle/fluid/operators/fused_embedding_fc_lstm_op.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +class FusedEmbeddingFCLSTMOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; +}; + +class FusedEmbeddingFCLSTMOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override; +}; + +} // namespace operators +} // namespace paddle