From 597b73053d169b2c5e02f9de155cc76f7f1d969d Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Mon, 3 Sep 2018 10:02:27 +0800 Subject: [PATCH] refine/fc lstm fusion link (#13158) --- .../fluid/framework/ir/fc_lstm_fuse_pass.cc | 150 +++++++++++------ paddle/fluid/framework/ir/fc_lstm_fuse_pass.h | 18 ++- .../framework/ir/graph_pattern_detector.cc | 152 +++++++++++++++++- .../framework/ir/graph_pattern_detector.h | 89 ++++------ paddle/fluid/framework/ir/graph_viz_pass.h | 7 + paddle/fluid/inference/analysis/analyzer.cc | 1 + .../inference/analysis/analyzer_tester.cc | 1 + 7 files changed, 312 insertions(+), 106 deletions(-) diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc index 5852705b6b8..c29eb4d4a62 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -13,37 +13,37 @@ // limitations under the License. #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h" +#include "paddle/fluid/framework/lod_tensor.h" namespace paddle { namespace framework { namespace ir { -std::unique_ptr FCLstmFusePass::ApplyImpl( - std::unique_ptr graph) const { - GraphPatternDetector gpd; - auto* pattern = gpd.mutable_pattern(); - - std::unordered_set fused_ops({// first lstm - 13, 15, 16, - // second lstm - 23, 25, 26}); - - pattern->NewNode([&](Node* x) { return fused_ops.count(x->id()); }, - "any_node"); +std::string GenNodeName(const std::string& prefix, const std::string& name) { + return prefix + "/" + name; +} - std::unordered_set marked_nodes; +void BuildPattern(PDPattern* pattern, const std::string& name_scope, + bool with_fc_bias) { + PDNode* x = pattern->NewNode(name_scope, "x") + ->assert_is_op_input("mul") + ->assert_var_not_persistable(); + auto* fc_out = patterns::FC(pattern, name_scope, x, with_fc_bias); + fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse. + patterns::LSTM(pattern, name_scope, fc_out); + // LOG(INFO) << "\n" << pattern->DotString(); +} - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, - Graph* g) { +int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, + bool with_fc_bias) { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); - auto* id = subgraph.at(gpd.pattern().RetrieveNode("any_node")); - marked_nodes.insert(id); - }; - gpd(graph.get(), handler); + BuildPattern(pattern, name_scope, with_fc_bias); // Create New OpDesc auto lstm_creator = [&](int lstm, int input, int weight_x, int weight_h, - int bias, int hidden, int cell, int xx) { + int bias, int hidden, int cell, int xx, int fc_bias) { #define GET_NODE(x) auto* x##_n = graph->RetriveNode(x); GET_NODE(input); GET_NODE(weight_x); @@ -61,12 +61,33 @@ std::unique_ptr FCLstmFusePass::ApplyImpl( SET_IN(WeightX, weight_x); SET_IN(WeightH, weight_h); SET_IN(Bias, bias); -#undef GET_NODE #undef SET_IN + if (with_fc_bias) { + // Add FC-bias with LSTM-bias and create a new weight + PADDLE_ENFORCE(scope); + const std::string& new_bias_var = name_scope + "_bias.new"; + auto* bias_var = scope->Var(new_bias_var); + PADDLE_ENFORCE(bias_var); + auto* bias_tensor = bias_var->GetMutable(); + auto* lstm_bias_var = scope->FindVar(bias_n->Name()); + PADDLE_ENFORCE(lstm_bias_var); + const auto& lstm_bias_tensor = lstm_bias_var->Get(); + bias_tensor->Resize(lstm_bias_tensor.dims()); + + GET_NODE(fc_bias); + auto* fc_bias_var = scope->FindVar(fc_bias_n->Name()); + const auto& fc_bias_tensor = fc_bias_var->Get(); + + auto* data = bias_tensor->mutable_data(platform::CPUPlace()); + + for (int i = 0; i < bias_tensor->numel(); i++) { + data[i] = + fc_bias_tensor.data()[i] + lstm_bias_tensor.data()[i]; + } + op_desc.SetInput("Bias", {new_bias_var}); + } - VLOG(4) << "hidden_n: " << hidden_n->Name(); - VLOG(4) << "cell: " << cell_n->Name(); - VLOG(4) << "xx: " << xx_n->Name(); +#undef GET_NODE op_desc.SetInput("H0", {}); op_desc.SetInput("C0", {}); @@ -76,7 +97,7 @@ std::unique_ptr FCLstmFusePass::ApplyImpl( op_desc.SetOutput("BatchedGate", {"blstm_0.tmp_2"}); op_desc.SetOutput("BatchCellPreAct", {"blstm_1.tmp_2"}); op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse")); - op_desc.SetAttr("use_peepholes", false); + op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes")); auto* op = graph->CreateOpNode(&op_desc); #define LINK_TO(a, b) \ @@ -89,33 +110,71 @@ std::unique_ptr FCLstmFusePass::ApplyImpl( LINK_TO(op, hidden_n); #undef LINK_TO return op; - }; - lstm_creator(16, 12, 14, 18, 17, 22, 21, 19); - lstm_creator(26, 12, 24, 28, 27, 32, 31, 29); + int fusion_count{0}; - // remove all the nodes + auto fc_no_bias_handler = [&]( + const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - for (auto* node : marked_nodes) { - graph->RemoveNode(const_cast(node)); - } +#define GET_NODE(name__) \ + std::string name__##key = name_scope + "/" + #name__; \ + auto* name__##n = pattern->RetrieveNode(name__##key); \ + PADDLE_ENFORCE(name__##n); \ + PADDLE_ENFORCE(subgraph.count(name__##n)); \ + Node* name__##_n = subgraph.at(name__##n); \ + int name__ __attribute__((unused)) = name__##_n->id(); - for (auto* node : graph->Nodes()) { - for (auto it = node->inputs.begin(); it != node->inputs.end();) { - if (marked_nodes.count(*it)) { - it = const_cast(node)->inputs.erase(it); - } else - it++; - } - for (auto it = node->outputs.begin(); it != node->outputs.end();) { - if (marked_nodes.count(*it)) { - it = const_cast(node)->outputs.erase(it); - } else - it++; + GET_NODE(x); + GET_NODE(w); + GET_NODE(mul); + GET_NODE(fc_out); + GET_NODE(Weight); + GET_NODE(lstm); + GET_NODE(Bias); + GET_NODE(Hidden); + GET_NODE(Cell); + + if (with_fc_bias) { + GET_NODE(fc_bias); + lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, fc_bias); + } else { + lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, -1); } - } +#undef GET_NODE + + // Remove unneeded nodes. + std::unordered_set marked_nodes({mul_n, lstm_n}); + + GraphSafeRemoveNodes(graph, marked_nodes); + + ++fusion_count; + }; + + gpd(graph, fc_no_bias_handler); + + return fusion_count; +} + +std::unique_ptr MulLstmFusePass::ApplyImpl( + std::unique_ptr graph) const { + FusePassBase::Init(name_scope_, graph.get()); + + int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(), + false /*with_fc_bias*/); + + AddStatis(fusion_count); + return graph; +} + +std::unique_ptr FCLstmFusePass::ApplyImpl( + std::unique_ptr graph) const { + FusePassBase::Init(name_scope_, graph.get()); + + int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(), + true /*with_fc_bias*/); + AddStatis(fusion_count); return graph; } @@ -123,4 +182,5 @@ std::unique_ptr FCLstmFusePass::ApplyImpl( } // namespace framework } // namespace paddle +REGISTER_PASS(mul_lstm_fuse_pass, paddle::framework::ir::MulLstmFusePass); REGISTER_PASS(fc_lstm_fuse_pass, paddle::framework::ir::FCLstmFusePass); diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h index 74b08ae558b..5a6687872eb 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h @@ -12,20 +12,34 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" -#include "paddle/fluid/framework/ir/pass.h" namespace paddle { namespace framework { namespace ir { -class FCLstmFusePass : public Pass { +// The MulLstmFusePass and MulLstmFusePass will fuse to the same FusionLstm op. + +// Just FC without bias +class FCLstmFusePass : public FusePassBase { public: virtual ~FCLstmFusePass() {} protected: std::unique_ptr ApplyImpl(std::unique_ptr graph) const; + + const std::string name_scope_{"fc_lstm_fuse"}; +}; + +class MulLstmFusePass : public FusePassBase { + public: + virtual ~MulLstmFusePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; + const std::string name_scope_{"fc_nobias_lstm_fuse"}; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 945ab110b14..f651ab635ea 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -19,6 +19,7 @@ #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_traits.h" +#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -71,7 +72,11 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) { void GraphPatternDetector::operator()(Graph* graph, GraphPatternDetector::handle_t handler) { - if (!MarkPDNodesInGraph(*graph)) return; + if (!MarkPDNodesInGraph(*graph)) { + LOG(INFO) << "Mark failed"; + return; + } + auto subgraphs = DetectPatterns(); UniquePatterns(&subgraphs); RemoveOverlappedMatch(&subgraphs); @@ -87,7 +92,7 @@ void GraphPatternDetector::operator()(Graph* graph, } bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) { - VLOG(4) << "mark pdnodes in graph"; + VLOG(3) << "mark pdnodes in graph"; if (graph.Nodes().empty()) return false; for (auto& node : GraphTraits::DFS(graph)) { @@ -107,6 +112,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) { } } VLOG(3) << pdnodes2nodes_.size() << " nodes marked"; + return !pdnodes2nodes_.empty(); } @@ -357,7 +363,9 @@ PDNode* PDNode::assert_is_op_nth_input(const std::string& op_type, assert_is_op_input(op_type); asserts_.emplace_back([=](Node* x) { for (auto* op : x->outputs) { - if (IsNthInput(x, op, argument, nth)) return true; + if (op->IsOp() && op->Op()->Type() == op_type && + IsNthInput(x, op, argument, nth)) + return true; } return false; }); @@ -368,7 +376,9 @@ PDNode* PDNode::assert_is_op_nth_output(const std::string& op_type, assert_is_var(); asserts_.emplace_back([=](Node* x) { for (auto* op : x->inputs) { - if (IsNthOutput(x, op, argument, nth)) return true; + if (op->IsOp() && op->Op()->Type() == op_type && + IsNthOutput(x, op, argument, nth)) + return true; } return false; }); @@ -412,6 +422,12 @@ PDNode* PDNode::assert_is_op_output(const std::string& op_type) { }); return this; } +PDNode* PDNode::assert_is_op_output(const std::string& op_type, + const std::string& argument) { + assert_is_var(); + assert_is_op_nth_output(op_type, argument, 0); + return this; +} PDNode* PDNode::assert_is_op_input(const std::string& op_type) { assert_is_var(); asserts_.emplace_back([=](Node* x) { @@ -424,6 +440,12 @@ PDNode* PDNode::assert_is_op_input(const std::string& op_type) { }); return this; } +PDNode* PDNode::assert_is_op_input(const std::string& op_type, + const std::string& argument) { + assert_is_var(); + assert_is_op_nth_input(op_type, argument, 0); + return this; +} PDNode* PDNode::assert_op_has_n_inputs(const std::string& op_type, size_t n) { assert_is_op(op_type); asserts_.emplace_back([=](Node* x) { return x->inputs.size() == n; }); @@ -439,6 +461,128 @@ PDNode* PDNode::assert_more(PDNode::teller_t&& teller) { return this; } +bool VarLinksToOp(Node* node, const std::string& op_type) { + for (auto* out : node->outputs) { + if (out->IsOp() && out->Op()->Type() == op_type) { + return true; + } + } + return false; +} +bool IsNthInput(Node* var, Node* op, const std::string& argument, size_t nth) { + PADDLE_ENFORCE(var->IsVar()); + PADDLE_ENFORCE(op->IsOp()); + if (op->Op()->Input(argument).size() <= nth) return false; + return var->Name() == op->Op()->Input(argument)[nth]; +} +bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth) { + PADDLE_ENFORCE(var->IsVar()); + PADDLE_ENFORCE(op->IsOp()); + if (op->Op()->Output(argument).size() <= nth) return false; + return var->Name() == op->Op()->Output(argument)[nth]; +} +void GraphSafeRemoveNodes(Graph* graph, + const std::unordered_set& nodes) { + for (auto* node : nodes) { + graph->RemoveNode(const_cast(node)); + } + + for (auto* node : graph->Nodes()) { + for (auto it = node->inputs.begin(); it != node->inputs.end();) { + if (nodes.count(*it)) { + it = const_cast(node)->inputs.erase(it); + } else + it++; + } + for (auto it = node->outputs.begin(); it != node->outputs.end();) { + if (nodes.count(*it)) { + it = const_cast(node)->outputs.erase(it); + } else + it++; + } + } +} +bool VarLinksFromOp(Node* node, const std::string& op_type) { + for (auto* out : node->inputs) { + if (out->IsOp() && out->Op()->Type() == op_type) { + return true; + } + } + return false; +} + +PDNode* patterns::FC(PDPattern* pattern, const std::string& name_scope, + PDNode* x, bool with_bias) { + // Create Operators + PDNode* elementwise_add_op{nullptr}; + auto* mul_op = pattern->NewNode(name_scope, "mul")->assert_is_op("mul"); + if (with_bias) { + elementwise_add_op = pattern->NewNode(name_scope, "elementwise_add") + ->assert_is_op("elementwise_add"); + } + // Create variables + // w + auto* mul_weight_var = pattern->NewNode(name_scope, "w") + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_nth_input("mul", "Y", 0); + PDNode* mul_out_var{nullptr}; + if (with_bias) { + // intermediate variable, will be removed in the IR after fuse. + mul_out_var = pattern->NewNode(name_scope, "mul_out") + ->AsIntermediate() + ->assert_is_only_output_of_op("mul") + ->assert_is_op_input("elementwise_add"); + } + PDNode *bias{nullptr}, *fc_out{nullptr}; + if (with_bias) { + // bias + bias = pattern->NewNode(name_scope, "fc_bias") + ->assert_is_op_input("elementwise_add") + ->AsInput(); + // output + fc_out = pattern->NewNode(name_scope, "fc_out") + ->AsOutput() + ->assert_is_op_output("elementwise_add"); + } else { + fc_out = pattern->NewNode(name_scope, "fc_out") + ->AsOutput() + ->assert_is_op_output("mul"); + } + + if (with_bias) { + mul_op->LinksFrom({mul_weight_var, x}).LinksTo({mul_out_var}); + elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out}); + } else { + mul_op->LinksFrom({mul_weight_var, x}).LinksTo({fc_out}); + } + + return fc_out; +} +PDNode* patterns::LSTM(PDPattern* pattern, const std::string& name_scope, + PDNode* x) { + x->assert_is_op_input("lstm", "Input"); + auto* lstm_op = pattern->NewNode(name_scope, "lstm")->assert_is_op("lstm"); +#define NEW_NODE(arg__, io__) \ + auto* arg__ = pattern->NewNode(name_scope, #arg__) \ + ->assert_is_op_##io__("lstm", #arg__); + + // Currently, the H0 and C0 are optional + // TODO(Superjomn) upgrade the fuse framework to support optional. + // NEW_NODE(H0, input); + // NEW_NODE(C0, input); + NEW_NODE(Weight, input); + NEW_NODE(Bias, input); + + NEW_NODE(Hidden, output); + NEW_NODE(Cell, output); + NEW_NODE(BatchGate, output); + NEW_NODE(BatchCellPreAct, output); + + lstm_op->LinksFrom({x, Weight, Bias}); + lstm_op->LinksTo({Hidden, Cell, BatchGate, BatchCellPreAct}); + return Hidden; +} } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index f8488c84962..024ce8ce556 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -95,7 +95,11 @@ struct PDNode { PDNode* assert_var_not_persistable(); PDNode* assert_is_persistable_var(); PDNode* assert_is_op_output(const std::string& op_type); + PDNode* assert_is_op_output(const std::string& op_type, + const std::string& argument); PDNode* assert_is_op_input(const std::string& op_type); + PDNode* assert_is_op_input(const std::string& op_type, + const std::string& argument); PDNode* assert_is_op_nth_input(const std::string& op_type, const std::string& argument, int nth); PDNode* assert_is_op_nth_output(const std::string& op_type, @@ -167,6 +171,9 @@ class PDPattern { PDNode* NewNode(PDNode::teller_t&& teller, const std::string& name = NewID()); PDNode* NewNode(const std::string& name = NewID()); + PDNode* NewNode(const std::string& prefix, const std::string& name) { + return NewNode(prefix + "/" + name); + } PDNode* RetrieveNode(const std::string& id) const; const std::vector>& nodes() const { return nodes_; } @@ -257,64 +264,36 @@ class GraphPatternDetector { // some helper methods. -// Op's input. -static bool VarLinksToOp(Node* node, const std::string& op_type) { - for (auto* out : node->outputs) { - if (out->IsOp() && out->Op()->Type() == op_type) { - return true; - } - } - return false; -} - -// Op's output. -static bool VarLinksFromOp(Node* node, const std::string& op_type) { - for (auto* out : node->inputs) { - if (out->IsOp() && out->Op()->Type() == op_type) { - return true; - } - } - return false; -} +// Tell if a var links to an Op +bool VarLinksToOp(Node* node, const std::string& op_type); + +// Tell if an op links to a var +bool VarLinksFromOp(Node* node, const std::string& op_type); // Check whether a var node is a op node's nth input. -static bool IsNthInput(Node* var, Node* op, const std::string& argument, - size_t nth) { - PADDLE_ENFORCE(var->IsVar()); - PADDLE_ENFORCE(op->IsOp()); - if (op->inputs.size() <= nth) return false; - return var->Name() == op->Op()->Input(argument)[nth]; -} - -static bool IsNthOutput(Node* var, Node* op, const std::string& argument, - size_t nth) { - PADDLE_ENFORCE(var->IsVar()); - PADDLE_ENFORCE(op->IsOp()); - if (op->inputs.size() <= nth) return false; - return var->Name() == op->Op()->Output(argument)[nth]; -} - -static void GraphSafeRemoveNodes(Graph* graph, - const std::unordered_set& nodes) { - for (auto* node : nodes) { - graph->RemoveNode(const_cast(node)); - } +bool IsNthInput(Node* var, Node* op, const std::string& argument, size_t nth); - for (auto* node : graph->Nodes()) { - for (auto it = node->inputs.begin(); it != node->inputs.end();) { - if (nodes.count(*it)) { - it = const_cast(node)->inputs.erase(it); - } else - it++; - } - for (auto it = node->outputs.begin(); it != node->outputs.end();) { - if (nodes.count(*it)) { - it = const_cast(node)->outputs.erase(it); - } else - it++; - } - } -} +// Tell whether a var node is a op node's nth output. +bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth); + +// Graph safely remove some nodes, will automatically clean up the edges. +void GraphSafeRemoveNodes(Graph* graph, + const std::unordered_set& nodes); + +// Some pre-defined patterns those can be reused in multiple passes. +namespace patterns { + +// FC with bias +// op: mul + elementwise_add +// named nodes: +// mul, elementwise_add +// w, mul_out, bias, fc_out +PDNode* FC(PDPattern* pattern, const std::string& name_scope, PDNode* x, + bool with_bias); + +PDNode* LSTM(PDPattern* pattern, const std::string& name_scope, PDNode* x); + +} // namespace patterns } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/graph_viz_pass.h b/paddle/fluid/framework/ir/graph_viz_pass.h index 8d885cb9e4e..e64916a5bb6 100644 --- a/paddle/fluid/framework/ir/graph_viz_pass.h +++ b/paddle/fluid/framework/ir/graph_viz_pass.h @@ -42,6 +42,13 @@ class GraphVizPass : public Pass { marked_nodes_t ConsumeMarkedNodes(Graph* graph) const; }; +static GraphVizPass::marked_nodes_t& GetMarkedNodes(Graph* graph) { + if (!graph->Has(kGraphvizMarkedNodeAttr)) { + graph->Set(kGraphvizMarkedNodeAttr, new GraphVizPass::marked_nodes_t); + } + return graph->Get(kGraphvizMarkedNodeAttr); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index e6e63544ffa..192ac2daa6a 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -109,6 +109,7 @@ void Analyzer::Run(Argument* argument) { "infer_clean_graph_pass", "graph_viz_pass", // "attention_lstm_fuse_pass", "graph_viz_pass", // "fc_lstm_fuse_pass", "graph_viz_pass", // + "mul_lstm_fuse_pass", "graph_viz_pass", // "seq_concat_fc_fuse_pass", "graph_viz_pass", // "fc_fuse_pass", "graph_viz_pass" // diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc index 119247cfe43..ec1f3979a74 100644 --- a/paddle/fluid/inference/analysis/analyzer_tester.cc +++ b/paddle/fluid/inference/analysis/analyzer_tester.cc @@ -329,6 +329,7 @@ void TestDituRNNPrediction(const std::string &model_path, ASSERT_TRUE(fuse_statis.count("fc")); EXPECT_EQ(fuse_statis.at("fc"), 1); + EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 1); } } -- GitLab