提交 3ea19b75 编写于 作者: T tensor-tang

fix bug and fc pass ut

上级 acfdbf02
...@@ -536,22 +536,21 @@ PDNode* patterns::FC(PDPattern* pattern, const std::string& name_scope, ...@@ -536,22 +536,21 @@ PDNode* patterns::FC(PDPattern* pattern, const std::string& name_scope,
mul_out_var = pattern->NewNode(name_scope, "mul_out") mul_out_var = pattern->NewNode(name_scope, "mul_out")
->AsIntermediate() ->AsIntermediate()
->assert_is_only_output_of_op("mul") ->assert_is_only_output_of_op("mul")
->assert_is_op_input("elementwise_add", "X"); ->assert_is_op_input("elementwise_add");
// bias // bias
bias = pattern->NewNode(name_scope, "fc_bias") bias = pattern->NewNode(name_scope, "fc_bias")
->AsInput() ->AsInput()
->assert_is_persistable_var() ->assert_is_op_input("elementwise_add");
->assert_is_op_input("elementwise_add", "Y");
// output // output
fc_out = pattern->NewNode(name_scope, "fc_out") fc_out = pattern->NewNode(name_scope, "fc_out")
->AsOutput() ->AsOutput()
->assert_is_op_output("elementwise_add", "Out"); ->assert_is_op_output("elementwise_add");
mul_op->LinksFrom({x, mul_weight_var}).LinksTo({mul_out_var}); mul_op->LinksFrom({x, mul_weight_var}).LinksTo({mul_out_var});
elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out}); elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
} else { } else {
fc_out = pattern->NewNode(name_scope, "fc_out") fc_out = pattern->NewNode(name_scope, "fc_out")
->AsOutput() ->AsOutput()
->assert_is_op_output("mul", "Out"); ->assert_is_op_output("mul");
mul_op->LinksFrom({mul_weight_var, x}).LinksTo({fc_out}); mul_op->LinksFrom({mul_weight_var, x}).LinksTo({fc_out});
} }
return fc_out; return fc_out;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/analysis/analyzer.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/analysis_predictor.h" #include "paddle/fluid/inference/api/analysis_predictor.h"
#include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_pass.h" #include "paddle/fluid/inference/api/paddle_inference_pass.h"
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/analysis/analyzer.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/analysis_predictor.h" #include "paddle/fluid/inference/api/analysis_predictor.h"
#include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_pass.h" #include "paddle/fluid/inference/api/paddle_inference_pass.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册