提交 3ea19b75 编写于 作者: T tensor-tang

fix bug and fc pass ut

上级 acfdbf02
......@@ -536,22 +536,21 @@ PDNode* patterns::FC(PDPattern* pattern, const std::string& name_scope,
mul_out_var = pattern->NewNode(name_scope, "mul_out")
->AsIntermediate()
->assert_is_only_output_of_op("mul")
->assert_is_op_input("elementwise_add", "X");
->assert_is_op_input("elementwise_add");
// bias
bias = pattern->NewNode(name_scope, "fc_bias")
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("elementwise_add", "Y");
->assert_is_op_input("elementwise_add");
// output
fc_out = pattern->NewNode(name_scope, "fc_out")
->AsOutput()
->assert_is_op_output("elementwise_add", "Out");
->assert_is_op_output("elementwise_add");
mul_op->LinksFrom({x, mul_weight_var}).LinksTo({mul_out_var});
elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
} else {
fc_out = pattern->NewNode(name_scope, "fc_out")
->AsOutput()
->assert_is_op_output("mul", "Out");
->assert_is_op_output("mul");
mul_op->LinksFrom({mul_weight_var, x}).LinksTo({fc_out});
}
return fc_out;
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/inference/analysis/analyzer.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_pass.h"
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/inference/analysis/analyzer.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_pass.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册