Commit d618e483, authored by T tensor-tang

fix fuse square mat order and refine test

test=develop
Parent: a5d2a6d1
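In summary: two pattern-matching lambdas in the squared_mat_sub fuse pass returned the boolean false where a Node* (nullptr) was expected; the fused squared-matmul-subtract pattern, kernel, documentation, and Python reference had the two operands of the subtraction reversed and now compute ( (X * Y).^2 - (X.^2 * Y.^2) ) .* scalar; and the analyzer test is updated for the renamed pass name scopes and the resulting op counts.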
@@ -33,7 +33,7 @@ class RepeatedFCReluFusePass : public FusePassBase {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
- const std::string name_scope_{"repeated_fc_relu"};
+ const std::string name_scope_{"repeated_fc_relu_fuse"};
};
} // namespace ir
......
@@ -51,7 +51,7 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
auto next_op = [=](Node* x, const std::string& op_type) -> Node* {
if (!(x && x->IsVar())) {
- return false;
+ return nullptr;
}
for (auto* op : x->outputs) {
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) {
@@ -63,7 +63,7 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
auto get_op_input_var = [=](Node* x, const std::string& arg_name) -> Node* {
if (!(x && x->IsOp())) {
- return false;
+ return nullptr;
}
for (auto* var : x->inputs) {
for (auto name : x->Op()->Input(arg_name)) {
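Both fixes above are the same bug: helper lambdas declared with an explicit Node* return type returned the boolean false on their failure path instead of a null pointer. A minimal standalone sketch of the corrected shape (the Node struct here is a hypothetical stand-in, not the real IR class):

```cpp
#include <cstdio>

struct Node {
  bool IsVar() const { return true; }
};

int main() {
  // A pointer-returning lambda must signal "no match" with nullptr;
  // `return false;` relies on a bool-to-pointer conversion that
  // standard C++ does not allow.
  auto next_op = [](Node* x) -> Node* {
    if (!(x && x->IsVar())) {
      return nullptr;  // was: return false;
    }
    return x;  // stand-in for the real "find consuming op" lookup
  };

  Node n;
  std::printf("match: %p\n", static_cast<void*>(next_op(&n)));
  return 0;
}
```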
@@ -93,10 +93,10 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
if (!next_is_matmul_from_arg) {
return false;
}
- auto* sub_x = squared_x->outputs[0]->outputs[0];
- return var_is_op_input(sub_x, "elementwise_sub", "X") &&
- sub_x->outputs[0]->outputs.size() == 1 &&
- var_is_op_input(sub_x->outputs[0]->outputs[0], "elementwise_mul");
+ auto* sub_y_in = squared_x->outputs[0]->outputs[0];
+ return var_is_op_input(sub_y_in, "elementwise_sub", "Y") &&
+ sub_y_in->outputs[0]->outputs.size() == 1 &&
+ var_is_op_input(sub_y_in->outputs[0]->outputs[0], "elementwise_mul");
};
auto is_fusion_first_mul_out = [=](Node* x) -> bool {
@@ -120,10 +120,10 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
if (!next_is_square) {
return false;
}
- auto* sub_y = x->outputs[0]->outputs[0];
- return var_is_op_input(sub_y, "elementwise_sub", "Y") &&
- sub_y->outputs[0]->outputs.size() == 1 &&
- var_is_op_input(sub_y->outputs[0]->outputs[0], "elementwise_mul");
+ auto* sub_x_in = x->outputs[0]->outputs[0];
+ return var_is_op_input(sub_x_in, "elementwise_sub", "X") &&
+ sub_x_in->outputs[0]->outputs.size() == 1 &&
+ var_is_op_input(sub_x_in->outputs[0]->outputs[0], "elementwise_mul");
};
auto* x = pattern->NewNode(
@@ -219,7 +219,7 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
if (!is_sub_op) {
return false;
}
- auto* matmul_sqx_sqy_var = get_op_input_var(x, "X");
+ auto* matmul_sqx_sqy_var = get_op_input_var(x, "Y");
return is_fusion_mat_squared_x_y_op_out(matmul_sqx_sqy_var);
};
@@ -280,7 +280,7 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
matmul_squared_x_y_op->LinksFrom({squared_x, squared_y})
.LinksTo({mat_squared_x_y_op_out});
square_matmuled_xy_op->LinksFrom({matmuled_xy}).LinksTo({squared_xmuly});
- sub_op->LinksFrom({mat_squared_x_y_op_out, squared_xmuly})
+ sub_op->LinksFrom({squared_xmuly, mat_squared_x_y_op_out})
.LinksTo({sub_op_out});
constant_op->LinksFrom({}).LinksTo({constant_op_out});
elementmul_op->LinksFrom({constant_op_out, sub_op_out})
......
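The changes in this file are one coordinated swap. For matrices, (X * Y).^2 (the elementwise square of the product) and X.^2 * Y.^2 (the product of the elementwise squares) are different quantities, and the corrected pattern wires (X * Y).^2 into the X (minuend) input of elementwise_sub and X.^2 * Y.^2 into its Y input. A small self-contained 2x2 check (plain arrays, not Paddle code) showing that the two quantities genuinely differ:

```cpp
#include <cstdio>

int main() {
  const double x[2][2] = {{1, 2}, {3, 4}};
  const double y[2][2] = {{5, 6}, {7, 8}};
  const double scalar = 0.5;
  double xy[2][2] = {}, sqx_sqy[2][2] = {};

  // xy = X * Y and sqx_sqy = X.^2 * Y.^2 (matmul of elementwise squares).
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 2; ++k) {
        xy[i][j] += x[i][k] * y[k][j];
        sqx_sqy[i][j] += x[i][k] * x[i][k] * y[k][j] * y[k][j];
      }

  // Fused output after this commit: ((X * Y).^2 - (X.^2 * Y.^2)) .* scalar.
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 2; ++j) {
      std::printf("%8.1f", (xy[i][j] * xy[i][j] - sqx_sqy[i][j]) * scalar);
    }
    std::printf("\n");
  }
  return 0;
}
```

For these values the result is {{70, 96}, {420, 576}}; with the old, reversed order every element would simply change sign, which is the bug named in the commit title.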
@@ -33,7 +33,7 @@ class SquaredMatSubFusePass : public FusePassBase {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"squared_mat_sub"};
const std::string name_scope_{"squared_mat_sub_fuse"};
};
} // namespace ir
......
@@ -21,6 +21,12 @@ namespace paddle {
namespace inference {
namespace analysis {
+ // diff: similarity_norm.tmp_0, for speed: fc_4.tmp_1
+ static const char out_var_name[] = "reduce_sum_0.tmp_0";
+ // for diff: 154, for speed 111
+ constexpr int num_slots = 154;
struct OneSlotInBatch {
std::string name;
std::vector<std::vector<float>> data;
@@ -41,7 +47,6 @@ struct DataRecord {
void Load(const std::string &path) {
std::ifstream file(path);
- constexpr int num_slots = 154;
std::string line;
int num_lines = 0;
while (std::getline(file, line)) {
@@ -190,13 +195,15 @@ void analysis_fuse_statis(bool use_zerocopy) {
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops);
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
+ ASSERT_EQ(fuse_statis.at("fc_fuse"), 10);
ASSERT_TRUE(fuse_statis.count("seqpool_concat_fuse"));
+ ASSERT_TRUE(fuse_statis.count("squared_mat_sub_fuse"));
+ ASSERT_TRUE(fuse_statis.count("repeated_fc_relu_fuse"));
- ASSERT_EQ(fuse_statis.at("fc_fuse"), 10);
EXPECT_EQ(fuse_statis.at("seqpool_concat_fuse"), 2);
- ASSERT_TRUE(fuse_statis.count("repeated_fc_relu"));
- EXPECT_EQ(fuse_statis.at("repeated_fc_relu"), 2);
+ EXPECT_EQ(fuse_statis.at("squared_mat_sub_fuse"), 2);
+ EXPECT_EQ(fuse_statis.at("repeated_fc_relu_fuse"), 2);
LOG(INFO) << "num_ops: " << num_ops;
- EXPECT_EQ(num_ops, 185);
+ EXPECT_EQ(num_ops, 171);
}
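The assertions above track the name_scope_ renames from the two pass headers in this commit (repeated_fc_relu becomes repeated_fc_relu_fuse; squared_mat_sub becomes squared_mat_sub_fuse), since those strings are the keys the fuse statistics are reported under, and the checks are tightened from mere presence (ASSERT_TRUE on count) to exact counts. With the two extra fuses each firing twice, the total op count drops from 185 to 171.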
// Check the fuse status
@@ -219,9 +226,6 @@ void PrepareZeroCopyInputs(
}
}
- // diff: similarity_norm.tmp_0, // speed: fc_4.tmp_1
- static const char out_var_name[] = "reduce_sum_0.tmp_0";
// return the output values
std::vector<float> zerocopy_profile(int repeat_times) {
AnalysisConfig config;
......
@@ -68,7 +68,7 @@ void FusionSquaredMatSubOpMaker::Make() {
AddComment(R"DOC(
Fusion Squared Matrix and subtract operator.
- ( (A.^2 * B.^2) - (A * B).^2 ) .* scalar
+ ( (X * Y).^2 - (X.^2 * Y.^2) ) .* scalar
)DOC");
}
@@ -112,14 +112,14 @@ class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
T* squared_xy_data = squared_xy->mutable_data<T>(place);
T* o_data = out->mutable_data<T>(place);
- matmul(x_data, y_data, squared_xy_data, m, n, k);
- vsquare_xy(squared_xy_data, squared_xy_data, o_numel);
vsquare_x(x_data, squared_x_data, m * k);
vsquare_y(y_data, squared_y_data, k * n);
+ matmul(x_data, y_data, o_data, m, n, k);
+ vsquare_xy(o_data, squared_xy_data, o_numel);
matmul(squared_x_data, squared_y_data, o_data, m, n, k);
- vsub(o_data, squared_xy_data, o_data, o_numel);
+ vsub(squared_xy_data, o_data, o_data, o_numel);
vscal(&scalar, o_data, o_data, o_numel);
}
};
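Taken together, the reordered kernel evaluates the corrected formula in place using only the two temporaries it already had. The same lines as above, annotated (a recap of the diff, not standalone code):

```cpp
matmul(x_data, y_data, o_data, m, n, k);                  // o = X * Y
vsquare_xy(o_data, squared_xy_data, o_numel);             // squared_xy = (X * Y).^2
vsquare_x(x_data, squared_x_data, m * k);                 // squared_x = X.^2
vsquare_y(y_data, squared_y_data, k * n);                 // squared_y = Y.^2
matmul(squared_x_data, squared_y_data, o_data, m, n, k);  // o = X.^2 * Y.^2 (o reused)
vsub(squared_xy_data, o_data, o_data, o_numel);           // o = (X * Y).^2 - X.^2 * Y.^2
vscal(&scalar, o_data, o_data, o_numel);                  // o = o .* scalar
```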
......
@@ -33,7 +33,7 @@ class TestFusionSquaredMatSubOp(OpTest):
self.inputs = {'X': matx, 'Y': maty}
self.outputs = {
'Out':
- (np.dot(matx**2, maty**2) - np.dot(matx, maty)**2) * self.scalar
+ (np.dot(matx, maty)**2 - np.dot(matx**2, maty**2)) * self.scalar
}
self.attrs = {'scalar': self.scalar, }
......
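The updated Python reference is the numpy transcription of the same corrected order, np.dot(matx, maty)**2 - np.dot(matx**2, maty**2) scaled by scalar, so the OpTest now agrees with both the DOC string and the kernel above.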