提交 4d774953 编写于 作者: T tensor-tang

enable fc gru fuse pass

上级 74f95b8d
...@@ -24,6 +24,7 @@ pass_library(fc_fuse_pass) ...@@ -24,6 +24,7 @@ pass_library(fc_fuse_pass)
pass_library(attention_lstm_fuse_pass) pass_library(attention_lstm_fuse_pass)
pass_library(infer_clean_graph_pass) pass_library(infer_clean_graph_pass)
pass_library(fc_lstm_fuse_pass) pass_library(fc_lstm_fuse_pass)
pass_library(fc_gru_fuse_pass)
pass_library(seq_concat_fc_fuse_pass) pass_library(seq_concat_fc_fuse_pass)
set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
......
...@@ -20,12 +20,8 @@ namespace paddle { ...@@ -20,12 +20,8 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
std::string GenNodeName(const std::string& prefix, const std::string& name) { static void BuildPattern(PDPattern* pattern, const std::string& name_scope,
return prefix + "/" + name; bool with_fc_bias) {
}
void BuildPattern(PDPattern* pattern, const std::string& name_scope,
bool with_fc_bias) {
PDNode* x = pattern->NewNode(name_scope, "x") PDNode* x = pattern->NewNode(name_scope, "x")
->assert_is_op_input("mul") ->assert_is_op_input("mul")
->assert_var_not_persistable(); ->assert_var_not_persistable();
...@@ -35,8 +31,8 @@ void BuildPattern(PDPattern* pattern, const std::string& name_scope, ...@@ -35,8 +31,8 @@ void BuildPattern(PDPattern* pattern, const std::string& name_scope,
VLOG(3) << "\n" << pattern->DotString(); VLOG(3) << "\n" << pattern->DotString();
} }
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, static int BuildFusion(Graph* graph, const std::string& name_scope,
bool with_fc_bias) { Scope* scope, bool with_fc_bias) {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern(); auto* pattern = gpd.mutable_pattern();
...@@ -108,7 +104,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -108,7 +104,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
auto* op = graph->CreateOpNode(&op_desc); auto* op = graph->CreateOpNode(&op_desc);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr)); PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr); // auto* scope = graph->Get<Scope*>(kParamScopeAttr);
IR_NODE_LINK_TO(x_n, op); IR_NODE_LINK_TO(x_n, op);
IR_NODE_LINK_TO(weight_x_n, op); IR_NODE_LINK_TO(weight_x_n, op);
...@@ -189,5 +185,5 @@ std::unique_ptr<ir::Graph> FCGRUFusePass::ApplyImpl( ...@@ -189,5 +185,5 @@ std::unique_ptr<ir::Graph> FCGRUFusePass::ApplyImpl(
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(mul_lstm_fuse_pass, paddle::framework::ir::MulGRUFusePass); REGISTER_PASS(mul_gru_fuse_pass, paddle::framework::ir::MulGRUFusePass);
REGISTER_PASS(fc_lstm_fuse_pass, paddle::framework::ir::FCGRUFusePass); REGISTER_PASS(fc_gru_fuse_pass, paddle::framework::ir::FCGRUFusePass);
...@@ -19,12 +19,13 @@ namespace paddle { ...@@ -19,12 +19,13 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
std::string GenNodeName(const std::string& prefix, const std::string& name) { static std::string GenNodeName(const std::string& prefix,
const std::string& name) {
return prefix + "/" + name; return prefix + "/" + name;
} }
void BuildPattern(PDPattern* pattern, const std::string& name_scope, static void BuildPattern(PDPattern* pattern, const std::string& name_scope,
bool with_fc_bias) { bool with_fc_bias) {
PDNode* x = pattern->NewNode(name_scope, "x") PDNode* x = pattern->NewNode(name_scope, "x")
->assert_is_op_input("mul") ->assert_is_op_input("mul")
->assert_var_not_persistable(); ->assert_var_not_persistable();
...@@ -34,8 +35,8 @@ void BuildPattern(PDPattern* pattern, const std::string& name_scope, ...@@ -34,8 +35,8 @@ void BuildPattern(PDPattern* pattern, const std::string& name_scope,
// LOG(INFO) << "\n" << pattern->DotString(); // LOG(INFO) << "\n" << pattern->DotString();
} }
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, static int BuildFusion(Graph* graph, const std::string& name_scope,
bool with_fc_bias) { Scope* scope, bool with_fc_bias) {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern(); auto* pattern = gpd.mutable_pattern();
......
...@@ -36,6 +36,8 @@ limitations under the License. */ ...@@ -36,6 +36,8 @@ limitations under the License. */
*/ */
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <string>
#include <vector>
#include "paddle/fluid/inference/analysis/flags.h" #include "paddle/fluid/inference/analysis/flags.h"
#include "paddle/fluid/inference/analysis/pass.h" #include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/pass_manager.h" #include "paddle/fluid/inference/analysis/pass_manager.h"
...@@ -66,6 +68,8 @@ class Analyzer : public OrderedRegistry<PassManager> { ...@@ -66,6 +68,8 @@ class Analyzer : public OrderedRegistry<PassManager> {
"attention_lstm_fuse_pass", // "attention_lstm_fuse_pass", //
"fc_lstm_fuse_pass", // "fc_lstm_fuse_pass", //
"mul_lstm_fuse_pass", // "mul_lstm_fuse_pass", //
"fc_gru_fuse_pass", //
"mul_gru_fuse_pass", //
"seq_concat_fc_fuse_pass", // "seq_concat_fc_fuse_pass", //
"fc_fuse_pass", // "fc_fuse_pass", //
}}; }};
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/analysis/analyzer.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h" #include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/analysis_predictor.h" #include "paddle/fluid/inference/api/analysis_predictor.h"
...@@ -237,6 +238,30 @@ void TestLACPrediction(const std::string &model_path, ...@@ -237,6 +238,30 @@ void TestLACPrediction(const std::string &model_path,
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
EXPECT_EQ(pdata_ref[i], pdata[i]); EXPECT_EQ(pdata_ref[i], pdata[i]);
} }
AnalysisPredictor *analysis_predictor =
dynamic_cast<AnalysisPredictor *>(predictor.get());
auto &fuse_statis = analysis_predictor->analysis_argument()
.Get<std::unordered_map<std::string, int>>(
framework::ir::kFuseStatisAttr);
for (auto &item : fuse_statis) {
LOG(INFO) << "fused " << item.first << " " << item.second;
}
int num_ops = 0;
for (auto &node :
analysis_predictor->analysis_argument().main_dfg->nodes.nodes()) {
if (node->IsFunction()) {
++num_ops;
}
}
LOG(INFO) << "has num ops: " << num_ops;
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
ASSERT_TRUE(fuse_statis.count("fc_gru_fuse"));
LOG(INFO) << "fc fuse num:" << fuse_statis.at("fc_fuse");
LOG(INFO) << "fc gru fuse num:" << fuse_statis.at("fc_gru_fuse");
// ASSERT_TRUE(fuse_statis.count("fc_gru_fuse"));
// LOG(INFO) << fuse_statis.at("fc_gru_fuse");
} }
} }
......
...@@ -50,6 +50,7 @@ cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_a ...@@ -50,6 +50,7 @@ cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_a
pass pass
fc_fuse_pass fc_fuse_pass
fc_lstm_fuse_pass fc_lstm_fuse_pass
fc_gru_fuse_pass
seq_concat_fc_fuse_pass seq_concat_fc_fuse_pass
graph_viz_pass graph_viz_pass
infer_clean_graph_pass infer_clean_graph_pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册