Commit d618e483 authored by tensor-tang

fix fuse square mat order and refine test

test=develop
Parent a5d2a6d1
@@ -33,7 +33,7 @@ class RepeatedFCReluFusePass : public FusePassBase {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
- const std::string name_scope_{"repeated_fc_relu"};
+ const std::string name_scope_{"repeated_fc_relu_fuse"};
 };
 }  // namespace ir
...
@@ -51,7 +51,7 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
   auto next_op = [=](Node* x, const std::string& op_type) -> Node* {
     if (!(x && x->IsVar())) {
-      return false;
+      return nullptr;
     }
     for (auto* op : x->outputs) {
       if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) {
@@ -63,7 +63,7 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
   auto get_op_input_var = [=](Node* x, const std::string& arg_name) -> Node* {
     if (!(x && x->IsOp())) {
-      return false;
+      return nullptr;
     }
     for (auto* var : x->inputs) {
       for (auto name : x->Op()->Input(arg_name)) {
@@ -93,10 +93,10 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
     if (!next_is_matmul_from_arg) {
       return false;
     }
-    auto* sub_x = squared_x->outputs[0]->outputs[0];
-    return var_is_op_input(sub_x, "elementwise_sub", "X") &&
-           sub_x->outputs[0]->outputs.size() == 1 &&
-           var_is_op_input(sub_x->outputs[0]->outputs[0], "elementwise_mul");
+    auto* sub_y_in = squared_x->outputs[0]->outputs[0];
+    return var_is_op_input(sub_y_in, "elementwise_sub", "Y") &&
+           sub_y_in->outputs[0]->outputs.size() == 1 &&
+           var_is_op_input(sub_y_in->outputs[0]->outputs[0], "elementwise_mul");
   };
   auto is_fusion_first_mul_out = [=](Node* x) -> bool {
@@ -120,10 +120,10 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
     if (!next_is_square) {
       return false;
     }
-    auto* sub_y = x->outputs[0]->outputs[0];
-    return var_is_op_input(sub_y, "elementwise_sub", "Y") &&
-           sub_y->outputs[0]->outputs.size() == 1 &&
-           var_is_op_input(sub_y->outputs[0]->outputs[0], "elementwise_mul");
+    auto* sub_x_in = x->outputs[0]->outputs[0];
+    return var_is_op_input(sub_x_in, "elementwise_sub", "X") &&
+           sub_x_in->outputs[0]->outputs.size() == 1 &&
+           var_is_op_input(sub_x_in->outputs[0]->outputs[0], "elementwise_mul");
   };
   auto* x = pattern->NewNode(
@@ -219,7 +219,7 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
     if (!is_sub_op) {
       return false;
     }
-    auto* matmul_sqx_sqy_var = get_op_input_var(x, "X");
+    auto* matmul_sqx_sqy_var = get_op_input_var(x, "Y");
     return is_fusion_mat_squared_x_y_op_out(matmul_sqx_sqy_var);
   };
@@ -280,7 +280,7 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
   matmul_squared_x_y_op->LinksFrom({squared_x, squared_y})
       .LinksTo({mat_squared_x_y_op_out});
   square_matmuled_xy_op->LinksFrom({matmuled_xy}).LinksTo({squared_xmuly});
-  sub_op->LinksFrom({mat_squared_x_y_op_out, squared_xmuly})
+  sub_op->LinksFrom({squared_xmuly, mat_squared_x_y_op_out})
       .LinksTo({sub_op_out});
   constant_op->LinksFrom({}).LinksTo({constant_op_out});
   elementmul_op->LinksFrom({constant_op_out, sub_op_out})
...
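All the X/Y renames and swaps in the pattern above follow from the corrected operand order: the square of the matmul result now feeds elementwise_sub as its X input, while the matmul of the squared operands feeds it as Y. A minimal numpy sketch of the computation the matched subgraph performs (variable names and shapes here are illustrative, not the actual graph node names):

import numpy as np

x = np.random.rand(3, 4).astype('float32')
y = np.random.rand(4, 5).astype('float32')
scalar = 0.5

squared_xmuly = np.dot(x, y) ** 2       # square(matmul(X, Y)) -> elementwise_sub input X
mat_sq_x_sq_y = np.dot(x ** 2, y ** 2)  # matmul(square(X), square(Y)) -> elementwise_sub input Y
out = (squared_xmuly - mat_sq_x_sq_y) * scalar  # elementwise_sub, then scale by the constant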
@@ -33,7 +33,7 @@ class SquaredMatSubFusePass : public FusePassBase {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
- const std::string name_scope_{"squared_mat_sub"};
+ const std::string name_scope_{"squared_mat_sub_fuse"};
 };
 }  // namespace ir
...
@@ -21,6 +21,12 @@ namespace paddle {
 namespace inference {
 namespace analysis {
+// diff: similarity_norm.tmp_0, for speed: fc_4.tmp_1
+static const char out_var_name[] = "reduce_sum_0.tmp_0";
+
+// for diff: 154, for speed 111
+constexpr int num_slots = 154;
+
 struct OneSlotInBatch {
   std::string name;
   std::vector<std::vector<float>> data;
@@ -41,7 +47,6 @@ struct DataRecord {
   void Load(const std::string &path) {
     std::ifstream file(path);
-    constexpr int num_slots = 154;
     std::string line;
     int num_lines = 0;
     while (std::getline(file, line)) {
@@ -190,13 +195,15 @@ void analysis_fuse_statis(bool use_zerocopy) {
   auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
   auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops);
   ASSERT_TRUE(fuse_statis.count("fc_fuse"));
-  ASSERT_EQ(fuse_statis.at("fc_fuse"), 10);
   ASSERT_TRUE(fuse_statis.count("seqpool_concat_fuse"));
+  ASSERT_TRUE(fuse_statis.count("squared_mat_sub_fuse"));
+  ASSERT_TRUE(fuse_statis.count("repeated_fc_relu_fuse"));
+  ASSERT_EQ(fuse_statis.at("fc_fuse"), 10);
   EXPECT_EQ(fuse_statis.at("seqpool_concat_fuse"), 2);
-  ASSERT_TRUE(fuse_statis.count("repeated_fc_relu"));
-  EXPECT_EQ(fuse_statis.at("repeated_fc_relu"), 2);
+  EXPECT_EQ(fuse_statis.at("squared_mat_sub_fuse"), 2);
+  EXPECT_EQ(fuse_statis.at("repeated_fc_relu_fuse"), 2);
   LOG(INFO) << "num_ops: " << num_ops;
-  EXPECT_EQ(num_ops, 185);
+  EXPECT_EQ(num_ops, 171);
 }
 // Check the fuse status
@@ -219,9 +226,6 @@ void PrepareZeroCopyInputs(
   }
 }
-// diff: similarity_norm.tmp_0, // speed: fc_4.tmp_1
-static const char out_var_name[] = "reduce_sum_0.tmp_0";
-
 // return the output values
 std::vector<float> zerocopy_profile(int repeat_times) {
   AnalysisConfig config;
...
@@ -68,7 +68,7 @@ void FusionSquaredMatSubOpMaker::Make() {
   AddComment(R"DOC(
     Fusion Squared Matrix and substrct operator.
-    ( (A.^2 * B.^2) - (A * B).^2 ) .* scalar
+    ( (X * Y).^2 - (X.^2 * Y.^2) ) .* scalar
 )DOC");
 }
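Written out as math (our rendering, not from the source; M^{∘2} denotes the elementwise square of M), the updated comment says the fused op computes

$\mathrm{Out} = s \cdot \left( (XY)^{\circ 2} - X^{\circ 2} Y^{\circ 2} \right)$

which is exactly the negation of what the old comment described; the kernel and the unit test below change in the same direction.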
@@ -112,14 +112,14 @@ class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
     T* squared_xy_data = squared_xy->mutable_data<T>(place);
     T* o_data = out->mutable_data<T>(place);
-    matmul(x_data, y_data, squared_xy_data, m, n, k);
-    vsquare_xy(squared_xy_data, squared_xy_data, o_numel);
     vsquare_x(x_data, squared_x_data, m * k);
     vsquare_y(y_data, squared_y_data, k * n);
+    matmul(x_data, y_data, o_data, m, n, k);
+    vsquare_xy(o_data, squared_xy_data, o_numel);
     matmul(squared_x_data, squared_y_data, o_data, m, n, k);
-    vsub(o_data, squared_xy_data, o_data, o_numel);
+    vsub(squared_xy_data, o_data, o_data, o_numel);
     vscal(&scalar, o_data, o_data, o_numel);
   }
 };
...
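A minimal numpy mirror of the reordered kernel steps (a sketch: the function name is ours, and the real code uses JIT-generated matmul/vsquare/vsub/vscal kernels that reuse the o_data buffer for both matmuls):

import numpy as np

def squared_mat_sub_ref(x, y, scalar):
    squared_x = x ** 2                  # vsquare_x
    squared_y = y ** 2                  # vsquare_y
    squared_xy = np.dot(x, y) ** 2      # matmul into o_data, then vsquare_xy
    out = np.dot(squared_x, squared_y)  # second matmul overwrites o_data
    return (squared_xy - out) * scalar  # vsub(squared_xy, o, o), then vscal

This matches the updated numpy reference in the Python unit test below.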
@@ -33,7 +33,7 @@ class TestFusionSquaredMatSubOp(OpTest):
         self.inputs = {'X': matx, 'Y': maty}
         self.outputs = {
             'Out':
-            (np.dot(matx**2, maty**2) - np.dot(matx, maty)**2) * self.scalar
+            (np.dot(matx, maty)**2 - np.dot(matx**2, maty**2)) * self.scalar
         }
         self.attrs = {'scalar': self.scalar, }
...