Commit 4d774953 authored by T tensor-tang

enable fc gru fuse pass

Parent 74f95b8d
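
This change enables the FC + GRU fuse pass end to end: the pass is added to the ir pass build list, registered under its correct names (fc_gru_fuse_pass / mul_gru_fuse_pass instead of the copy-pasted LSTM names), wired into the analyzer's default pass list, linked into analysis_predictor, and covered by a fuse-statistics check in the LAC analyzer test. The diff also marks the file-local helpers GenNodeName, BuildPattern and BuildFusion as static: the GRU and LSTM fuse pass sources both define helpers with these names, and with external linkage, linking both passes into the same library risks duplicate-symbol errors and one-definition-rule violations. Below is a minimal sketch of that idiom, with purely illustrative file and function names (a.cc, b.cc, GenName are not from the Paddle sources):

// ---- a.cc ----------------------------------------------------------------
#include <string>

// Internal linkage: this definition is visible only inside a.cc.
static std::string GenName(const std::string& prefix) { return prefix + "/a"; }

// ---- b.cc ----------------------------------------------------------------
#include <string>

// A second, unrelated definition with the same name; because both are
// static, linking a.o and b.o into one library causes no symbol clash.
static std::string GenName(const std::string& prefix) { return prefix + "/b"; }

Without static (or an anonymous namespace), two identically named external definitions would collide at link time or silently violate the one-definition rule, which is presumably why GenNodeName was dropped from the GRU pass source entirely and the remaining helpers were made static in both files.
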
@@ -24,6 +24,7 @@ pass_library(fc_fuse_pass)
 pass_library(attention_lstm_fuse_pass)
 pass_library(infer_clean_graph_pass)
 pass_library(fc_lstm_fuse_pass)
+pass_library(fc_gru_fuse_pass)
 pass_library(seq_concat_fc_fuse_pass)
 set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
......
@@ -20,12 +20,8 @@ namespace paddle {
 namespace framework {
 namespace ir {
-std::string GenNodeName(const std::string& prefix, const std::string& name) {
-  return prefix + "/" + name;
-}
-void BuildPattern(PDPattern* pattern, const std::string& name_scope,
-                  bool with_fc_bias) {
+static void BuildPattern(PDPattern* pattern, const std::string& name_scope,
+                         bool with_fc_bias) {
   PDNode* x = pattern->NewNode(name_scope, "x")
                   ->assert_is_op_input("mul")
                   ->assert_var_not_persistable();
@@ -35,8 +31,8 @@ void BuildPattern(PDPattern* pattern, const std::string& name_scope,
   VLOG(3) << "\n" << pattern->DotString();
 }
-int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
-                bool with_fc_bias) {
+static int BuildFusion(Graph* graph, const std::string& name_scope,
+                       Scope* scope, bool with_fc_bias) {
   GraphPatternDetector gpd;
   auto* pattern = gpd.mutable_pattern();
@@ -108,7 +104,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
     auto* op = graph->CreateOpNode(&op_desc);
     PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-    auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+    // auto* scope = graph->Get<Scope*>(kParamScopeAttr);
     IR_NODE_LINK_TO(x_n, op);
     IR_NODE_LINK_TO(weight_x_n, op);
@@ -189,5 +185,5 @@ std::unique_ptr<ir::Graph> FCGRUFusePass::ApplyImpl(
 }  // namespace framework
 }  // namespace paddle
-REGISTER_PASS(mul_lstm_fuse_pass, paddle::framework::ir::MulGRUFusePass);
-REGISTER_PASS(fc_lstm_fuse_pass, paddle::framework::ir::FCGRUFusePass);
+REGISTER_PASS(mul_gru_fuse_pass, paddle::framework::ir::MulGRUFusePass);
+REGISTER_PASS(fc_gru_fuse_pass, paddle::framework::ir::FCGRUFusePass);
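
With the registration names fixed, the GRU fuse passes can actually be found under the names the analyzer requests. A rough sketch of how a pass registered this way is typically looked up and applied by name follows; the registry call and the Apply signature are assumptions based on the ir::Pass API of this Paddle era, not something shown in this diff, and the graph must already carry the parameter scope attribute (kParamScopeAttr), as the PADDLE_ENFORCE check above requires.

// Hypothetical driver snippet; assumes the ir::Pass / PassRegistry API of
// this Paddle version (Apply taking and returning std::unique_ptr<Graph>).
#include <memory>
#include <string>
#include <utility>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace ir = paddle::framework::ir;

std::unique_ptr<ir::Graph> ApplyGruFuse(std::unique_ptr<ir::Graph> graph) {
  // Look the pass up by its registered name; this only succeeds once
  // REGISTER_PASS uses fc_gru_fuse_pass rather than the LSTM name.
  auto pass = ir::PassRegistry::Instance().Get("fc_gru_fuse_pass");
  // The fuse pass reads GRU weights from the scope attached to the graph
  // via kParamScopeAttr, so that attribute must be set beforehand.
  return pass->Apply(std::move(graph));
}
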
@@ -19,12 +19,13 @@ namespace paddle {
 namespace framework {
 namespace ir {
-std::string GenNodeName(const std::string& prefix, const std::string& name) {
+static std::string GenNodeName(const std::string& prefix,
+                               const std::string& name) {
   return prefix + "/" + name;
 }
-void BuildPattern(PDPattern* pattern, const std::string& name_scope,
-                  bool with_fc_bias) {
+static void BuildPattern(PDPattern* pattern, const std::string& name_scope,
+                         bool with_fc_bias) {
   PDNode* x = pattern->NewNode(name_scope, "x")
                   ->assert_is_op_input("mul")
                   ->assert_var_not_persistable();
@@ -34,8 +35,8 @@ void BuildPattern(PDPattern* pattern, const std::string& name_scope,
   // LOG(INFO) << "\n" << pattern->DotString();
 }
-int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
-                bool with_fc_bias) {
+static int BuildFusion(Graph* graph, const std::string& name_scope,
+                       Scope* scope, bool with_fc_bias) {
   GraphPatternDetector gpd;
   auto* pattern = gpd.mutable_pattern();
......
@@ -36,6 +36,8 @@ limitations under the License. */
  */
 #include <gflags/gflags.h>
+#include <string>
+#include <vector>
 #include "paddle/fluid/inference/analysis/flags.h"
 #include "paddle/fluid/inference/analysis/pass.h"
 #include "paddle/fluid/inference/analysis/pass_manager.h"
@@ -66,6 +68,8 @@ class Analyzer : public OrderedRegistry<PassManager> {
       "attention_lstm_fuse_pass",  //
       "fc_lstm_fuse_pass",         //
       "mul_lstm_fuse_pass",        //
+      "fc_gru_fuse_pass",          //
+      "mul_gru_fuse_pass",         //
       "seq_concat_fc_fuse_pass",   //
       "fc_fuse_pass",              //
   }};
......
@@ -14,6 +14,7 @@
 #include "paddle/fluid/inference/analysis/analyzer.h"
 #include <gtest/gtest.h>
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/inference/analysis/ut_helper.h"
 #include "paddle/fluid/inference/api/analysis_predictor.h"
@@ -237,6 +238,30 @@ void TestLACPrediction(const std::string &model_path,
     for (size_t i = 0; i < size; ++i) {
       EXPECT_EQ(pdata_ref[i], pdata[i]);
     }
+    AnalysisPredictor *analysis_predictor =
+        dynamic_cast<AnalysisPredictor *>(predictor.get());
+    auto &fuse_statis = analysis_predictor->analysis_argument()
+                            .Get<std::unordered_map<std::string, int>>(
+                                framework::ir::kFuseStatisAttr);
+    for (auto &item : fuse_statis) {
+      LOG(INFO) << "fused " << item.first << " " << item.second;
+    }
+    int num_ops = 0;
+    for (auto &node :
+         analysis_predictor->analysis_argument().main_dfg->nodes.nodes()) {
+      if (node->IsFunction()) {
+        ++num_ops;
+      }
+    }
+    LOG(INFO) << "has num ops: " << num_ops;
+    ASSERT_TRUE(fuse_statis.count("fc_fuse"));
+    ASSERT_TRUE(fuse_statis.count("fc_gru_fuse"));
+    LOG(INFO) << "fc fuse num:" << fuse_statis.at("fc_fuse");
+    LOG(INFO) << "fc gru fuse num:" << fuse_statis.at("fc_gru_fuse");
+    // ASSERT_TRUE(fuse_statis.count("fc_gru_fuse"));
+    // LOG(INFO) << fuse_statis.at("fc_gru_fuse");
   }
 }
......
@@ -50,6 +50,7 @@ cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_a
         pass
         fc_fuse_pass
         fc_lstm_fuse_pass
+        fc_gru_fuse_pass
         seq_concat_fc_fuse_pass
         graph_viz_pass
         infer_clean_graph_pass
......