Unverified Commit a7fc3d42 authored by: T tensor-tang committed by: GitHub

Merge pull request #15304 from tensor-tang/fuse/second_order_mul_sub

Fuse/second order mul sub and fuse repeated fc relu
......@@ -43,6 +43,8 @@ pass_library(multi_batch_merge_pass base)
pass_library(conv_bn_fuse_pass inference)
pass_library(seqconv_eltadd_relu_fuse_pass inference)
pass_library(seqpool_concat_fuse_pass inference)
pass_library(repeated_fc_relu_fuse_pass inference)
pass_library(squared_mat_sub_fuse_pass inference)
pass_library(is_test_pass base)
pass_library(conv_elementwise_add_act_fuse_pass inference)
pass_library(conv_elementwise_add2_act_fuse_pass inference)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h"
#include <algorithm> // for max
#include <string>
#include <unordered_set>  // for std::unordered_set used below
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#define MAX_NUM_FC 10
namespace paddle {
namespace framework {
namespace ir {
PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
const std::string& name_scope, int num_fc) {
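// True when var `x` feeds an "fc" op (the fc_idx-th of its consumers) whose
// single output feeds an `act_type` op that itself has a single output,
// i.e. one fc+activation stage of the repeated pattern.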
auto var_next_is_fc_act = [=](Node* x, const std::string& act_type = "relu",
bool check_in_has_only_one_out = true,
int fc_idx = 0) -> bool {
bool next_is_fc = x && x->IsVar() && VarLinksToOp(x, "fc");
if (check_in_has_only_one_out) {
next_is_fc = next_is_fc && x->outputs.size() == 1;
}
if (!next_is_fc) {
return false;
}
auto* fc_op = x->outputs[fc_idx];
bool next_is_act = fc_op && fc_op->IsOp() && fc_op->outputs.size() == 1 &&
fc_op->outputs[0] && fc_op->outputs[0]->IsVar() &&
VarLinksToOp(fc_op->outputs[0], act_type) &&
fc_op->outputs[0]->outputs.size() == 1;
if (!next_is_act) {
return false;
}
auto* act_op = fc_op->outputs[0]->outputs[0];
return act_op && act_op->IsOp() && act_op->outputs.size() == 1;
};
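// Among x's consumer ops, returns the index of the first "fc" whose output
// chain matches fc -> act with single outputs; returns 0 when none matches.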
auto find_fc_idx = [=](Node* x, const std::string& act_type = "relu") -> int {
bool next_is_fc = x && x->IsVar() && VarLinksToOp(x, "fc");
if (!next_is_fc) {
return 0;
}
for (size_t k = 0; k < x->outputs.size(); ++k) {
auto* fc_op = x->outputs[k];
bool next_is_act = fc_op && fc_op->IsOp() && fc_op->outputs.size() == 1 &&
fc_op->outputs[0] && fc_op->outputs[0]->IsVar() &&
VarLinksToOp(fc_op->outputs[0], act_type) &&
fc_op->outputs[0]->outputs.size() == 1;
if (!next_is_act) {
continue;
}
auto* act_op = fc_op->outputs[0]->outputs[0];
if (act_op && act_op->IsOp() && act_op->outputs.size() == 1) {
return k;
}
}
return 0;
};
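// Returns the output var of the act op behind x's fc_idx-th fc consumer,
// i.e. the input var of the next fc+act stage.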
auto next_var_of_part = [=](Node* x, int fc_idx = 0) -> Node* {
return x->outputs[fc_idx]->outputs[0]->outputs[0]->outputs[0];
};
auto var_next_is_fc_act_repeated_n_times = [=](
Node* x, int repeated_times, const std::string& act_type = "relu",
bool check_in_has_only_one_out = true) -> bool {
for (int i = 0; i < repeated_times; ++i) {
if (!var_next_is_fc_act(x, act_type,
i == 0 && check_in_has_only_one_out)) {
return false;
}
x = next_var_of_part(x);
}
return true;
};
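// Checks one stage backwards: var `x` must be produced by an act op that is
// fed by a single-output fc; `at_top` relaxes the check on the fc input.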
auto var_before_is_fc_act = [=](Node* x, const std::string& act_type = "relu",
bool at_top = false) -> bool {
bool before_is_act =
x && x->IsVar() && x->inputs.size() == 1 && VarLinksFromOp(x, act_type);
if (!before_is_act) {
return false;
}
auto* relu_op = x->inputs[0];
bool before_is_fc = relu_op->IsOp() && relu_op->inputs.size() == 1 &&
relu_op->inputs[0]->IsVar() &&
VarLinksFromOp(relu_op->inputs[0], "fc") &&
relu_op->inputs[0]->inputs.size() == 1;
if (!before_is_fc) {
return false;
}
auto* fc_op = relu_op->inputs[0]->inputs[0];
bool is_fc = fc_op->IsOp() && fc_op->inputs.size() == 3;
if (!is_fc) {
return false;
}
for (auto* fc_i : fc_op->inputs) {
if (!fc_i->inputs.empty()) {
if (at_top) {
return true;
} else {
return VarLinksFromOp(fc_i, "relu");
}
}
}
return false;
};
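// Steps back one fc+act stage: returns the var that feeds the fc producing x.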
auto before_var_of_part = [=](Node* x) -> Node* {
auto* fc_op = x->inputs[0]->inputs[0];
for (auto* fc_i : fc_op->inputs) {
if (!fc_i->inputs.empty()) {
return fc_i->inputs[0];
}
}
return nullptr;
};
auto var_before_is_fc_act_repeated_n_times = [=](
Node* x, int repeated_times,
const std::string& act_type = "relu") -> bool {
for (int i = 0; i < repeated_times; ++i) {
if (!var_before_is_fc_act(x, act_type, i == repeated_times - 1)) {
return false;
}
x = before_var_of_part(x);
}
return true;
};
std::vector<PDNode*> fc_input_var(num_fc);
std::vector<PDNode*> fc_output_var(num_fc);
std::vector<PDNode*> fc_weight_var(num_fc);
std::vector<PDNode*> fc_bias_var(num_fc);
std::vector<PDNode*> fc_ops(num_fc);
std::vector<PDNode*> relu_ops(num_fc);
for (int i = 0; i < num_fc; ++i) {
fc_input_var[i] = pattern->NewNode(
[=](Node* x) {
if (i == 0 && x->outputs.size() > 0) {
bool ok = x->inputs.size() > 0;
if (!ok) {
return false;
}
int idx = find_fc_idx(x);
if (idx == 0) {
return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu");
} else {
x = next_var_of_part(x, idx);
return var_next_is_fc_act_repeated_n_times(
x, std::max(1, num_fc - i - 1), "relu");
}
} else {
return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") &&
x->inputs.size() > 0 &&
var_before_is_fc_act_repeated_n_times(x, i, "relu");
}
},
name_scope + "/fc_in_" + std::to_string(i));
fc_weight_var[i] = pattern->NewNode(
[=](Node* x) {
return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") &&
x->inputs.empty() &&
var_before_is_fc_act_repeated_n_times(x->outputs[0]->inputs[0],
i, "relu") &&
x->Name() == x->outputs[0]->Op()->Input("W")[0];
},
name_scope + "/fc_weight_" + std::to_string(i));
fc_bias_var[i] = pattern->NewNode(
[=](Node* x) {
return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") &&
x->inputs.empty() &&
var_before_is_fc_act_repeated_n_times(x->outputs[0]->inputs[0],
i, "relu") &&
x->Name() == x->outputs[0]->Op()->Input("Bias")[0];
},
name_scope + "/fc_bias_" + std::to_string(i));
fc_output_var[i] = pattern->NewNode(
[=](Node* x) {
bool basic = x && x->IsVar() && VarLinksFromOp(x, "fc") &&
VarLinksToOp(x, "relu") && x->inputs.size() == 1 &&
x->inputs[0]->inputs.size() == 3;
if (!basic) {
return false;
}
x = x->inputs[0]->inputs[0];
if (i == 0 && x->outputs.size() > 0) {
bool ok = x->inputs.size() > 0;
if (!ok) {
return false;
}
int idx = find_fc_idx(x);
if (idx == 0) {
return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu");
} else {
x = next_var_of_part(x, idx);
return var_next_is_fc_act_repeated_n_times(
x, std::max(1, num_fc - i - 1), "relu");
}
} else {
return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") &&
x->inputs.size() > 0 &&
var_before_is_fc_act_repeated_n_times(x, i, "relu");
}
},
name_scope + "/fc_out_" + std::to_string(i));
fc_ops[i] = pattern->NewNode(
[=](Node* x) {
bool basic = x && x->IsOp() && x->Op()->Type() == "fc" &&
x->inputs.size() == 3 && x->outputs.size() == 1;
if (!basic) {
return false;
}
auto* fc_out_var = x->outputs[0];
return fc_out_var && fc_out_var->IsVar() &&
fc_out_var->outputs.size() == 1 &&
VarLinksToOp(fc_out_var, "relu") &&
fc_out_var->outputs[0]->outputs.size() == 1 &&
var_next_is_fc_act_repeated_n_times(
fc_out_var->outputs[0]->outputs[0], num_fc - i - 1,
"relu") &&
var_before_is_fc_act_repeated_n_times(
fc_out_var->outputs[0]->outputs[0], i + 1, "relu");
},
name_scope + "/fc_op_" + std::to_string(i));
relu_ops[i] = pattern->NewNode(
[=](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "relu" &&
x->inputs.size() == 1 && x->outputs.size() == 1 &&
x->inputs[0]->IsVar() && VarLinksFromOp(x->inputs[0], "fc") &&
x->outputs[0]->IsVar() &&
var_next_is_fc_act_repeated_n_times(x->outputs[0],
num_fc - i - 1, "relu") &&
var_before_is_fc_act_repeated_n_times(x->outputs[0], i + 1,
"relu");
},
name_scope + "/act_op_" + std::to_string(i));
fc_ops[i]
->LinksFrom({fc_input_var[i], fc_weight_var[i], fc_bias_var[i]})
.LinksTo({fc_output_var[i]});
relu_ops[i]->LinksFrom({fc_output_var[i]});
}
auto* last_out_var = pattern->NewNode(
[=](Node* x) {
return var_before_is_fc_act_repeated_n_times(x, num_fc, "relu");
},
name_scope + "/act_out");
for (int i = 0; i < num_fc - 1; ++i) {
relu_ops[i]->LinksTo({fc_input_var[i + 1]});
}
relu_ops[num_fc - 1]->LinksTo({last_out_var});
return last_out_var;
}
static int BuildFusion(Graph* graph, const std::string& name_scope,
int num_fc) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
BuildRepeatedFCReluPattern(pattern, name_scope, num_fc);
auto retrieve_node = [](const std::string& name,
const GraphPatternDetector::subgraph_t& subgraph,
const PDPattern& pat) -> Node* {
PADDLE_ENFORCE(subgraph.count(pat.RetrieveNode(name)),
"pattern has no Node called %s", name.c_str());
Node* p = subgraph.at(pat.RetrieveNode(name));
PADDLE_ENFORCE_NOT_NULL(p, "subgraph has no node %s", name.c_str());
return p;
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
LOG(INFO) << "handle Repeated FC Act fuse";
std::vector<Node*> weights_vars(num_fc);
std::vector<Node*> bias_vars(num_fc);
std::vector<Node*> relu_vars(num_fc - 1);
std::vector<std::string> weight_names(num_fc);
std::vector<std::string> bias_names(num_fc);
std::vector<std::string> relu_names(num_fc - 1);
auto& fused_pattern = gpd.pattern();
for (int i = 0; i < num_fc; ++i) {
if (i >= 1) {
relu_vars[i - 1] =
retrieve_node(name_scope + "/fc_in_" + std::to_string(i), subgraph,
fused_pattern);
relu_names[i - 1] = relu_vars[i - 1]->Name();
}
weights_vars[i] =
retrieve_node(name_scope + "/fc_weight_" + std::to_string(i),
subgraph, fused_pattern);
weight_names[i] = weights_vars[i]->Name();
bias_vars[i] = retrieve_node(name_scope + "/fc_bias_" + std::to_string(i),
subgraph, fused_pattern);
bias_names[i] = bias_vars[i]->Name();
}
auto* input_var =
retrieve_node(name_scope + "/fc_in_0", subgraph, fused_pattern);
auto* last_out_var =
retrieve_node(name_scope + "/act_out", subgraph, fused_pattern);
// Create New OpDesc
OpDesc op_desc;
op_desc.SetType("fusion_repeated_fc_relu");
op_desc.SetInput("X", {input_var->Name()});
op_desc.SetInput("W", weight_names);
op_desc.SetInput("Bias", bias_names);
op_desc.SetOutput("ReluOut", relu_names);
op_desc.SetOutput("Out", {last_out_var->Name()});
auto* op = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(input_var, op);
for (size_t i = 0; i < weights_vars.size(); ++i) {
IR_NODE_LINK_TO(weights_vars[i], op);
IR_NODE_LINK_TO(bias_vars[i], op);
}
for (size_t i = 0; i < relu_vars.size(); ++i) {
IR_NODE_LINK_TO(op, relu_vars[i]);
}
IR_NODE_LINK_TO(op, last_out_var);
std::unordered_set<const Node*> marked_nodes;
for (auto& item : subgraph) {
marked_nodes.insert(item.second);
}
for (size_t i = 0; i < weights_vars.size(); ++i) {
marked_nodes.erase(weights_vars[i]);
marked_nodes.erase(bias_vars[i]);
}
for (size_t i = 0; i < relu_vars.size(); ++i) {
marked_nodes.erase(relu_vars[i]);
}
marked_nodes.erase(input_var);
marked_nodes.erase(last_out_var);
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
std::unique_ptr<ir::Graph> RepeatedFCReluFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
int fusion_count = 0;
for (int i = MAX_NUM_FC; i > 1; --i) {
fusion_count +=
BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(i), i);
}
AddStatis(fusion_count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(repeated_fc_relu_fuse_pass,
paddle::framework::ir::RepeatedFCReluFusePass);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
/**
* Fuse Repeated FC Relu
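*
* Replaces a chain of x -> fc -> relu -> fc -> relu -> ... -> fc -> relu
* with a single fusion_repeated_fc_relu op that takes X and the lists of W
* and Bias as inputs, and produces the intermediate ReluOut tensors and Out.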
*/
class RepeatedFCReluFusePass : public FusePassBase {
public:
virtual ~RepeatedFCReluFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
const std::string name_scope_{"repeated_fc_relu_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -129,7 +129,8 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
return concat_out_var;
}
static int BuildFusion(Graph* graph, const std::string& name_scope,
int num_inputs) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
BuildSeqPoolConcatPattern(pattern, name_scope, num_inputs);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h"
#include <string>
#include <unordered_set>  // for std::unordered_set used below
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
namespace framework {
namespace ir {
PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
const std::string& name_scope) {
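// Matches the subgraph computing ((X * Y).^2 - (X.^2 * Y.^2)) .* scalar,
// built from matmul, square, elementwise_sub, fill_constant and
// elementwise_mul ops, so it can be replaced by one fusion_squared_mat_sub.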
auto var_is_op_input = [=](Node* x, const std::string& op_type,
const std::string& arg_name = "") -> bool {
if (!(x && x->IsVar())) {
return false;
}
for (auto* op : x->outputs) {
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) {
if (arg_name.empty()) {
return true;
}
for (auto& name : op->Op()->Input(arg_name)) {
if (name == x->Name()) {
return true;
}
}
}
}
return false;
};
auto var_is_op_only_output = [](Node* x, const std::string& op_type) -> bool {
return x && x->IsVar() && x->inputs.size() == 1 && x->inputs[0] &&
x->inputs[0]->IsOp() && x->inputs[0]->Op()->Type() == op_type &&
x->inputs[0]->outputs.size() == 1;
};
auto next_op = [=](Node* x, const std::string& op_type) -> Node* {
if (!(x && x->IsVar())) {
return nullptr;
}
for (auto* op : x->outputs) {
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) {
return op;
}
}
return nullptr;
};
auto get_op_input_var = [=](Node* x, const std::string& arg_name) -> Node* {
if (!(x && x->IsOp())) {
return nullptr;
}
for (auto* var : x->inputs) {
for (auto name : x->Op()->Input(arg_name)) {
if (var->Name() == name) {
return var;
}
}
}
return nullptr;
};
auto is_fusion_input_var = [=](Node* x, const std::string& arg_name) {
bool basic = var_is_op_input(x, "matmul", arg_name) &&
var_is_op_input(x, "square", "X");
if (!basic) {
return false;
}
auto* squared_x_op = next_op(x, "square");
if (!(squared_x_op && squared_x_op->outputs.size() == 1)) {
return false;
}
auto* squared_x = squared_x_op->outputs[0];
bool next_is_matmul_from_arg =
var_is_op_input(squared_x, "matmul", arg_name) &&
squared_x->outputs.size() == 1 &&
squared_x->outputs[0]->outputs.size() == 1;
if (!next_is_matmul_from_arg) {
return false;
}
auto* sub_y_in = squared_x->outputs[0]->outputs[0];
return var_is_op_input(sub_y_in, "elementwise_sub", "Y") &&
sub_y_in->outputs[0]->outputs.size() == 1 &&
var_is_op_input(sub_y_in->outputs[0]->outputs[0], "elementwise_mul");
};
auto is_fusion_first_mul_out = [=](Node* x) -> bool {
bool input_is_matmul_op = x && x->inputs.size() == 1 &&
x->inputs[0]->IsOp() &&
x->inputs[0]->Op()->Type() == "matmul";
if (!input_is_matmul_op) {
return false;
}
auto* mat_x = get_op_input_var(x->inputs[0], "X");
auto* mat_y = get_op_input_var(x->inputs[0], "Y");
bool input_mul_is_valid = mat_x && is_fusion_input_var(mat_x, "X") &&
mat_y && is_fusion_input_var(mat_y, "Y");
if (!input_mul_is_valid) {
return false;
}
bool next_is_square = var_is_op_input(x, "square", "X") &&
x->outputs.size() == 1 &&
x->outputs[0]->outputs.size() == 1;
if (!next_is_square) {
return false;
}
auto* sub_x_in = x->outputs[0]->outputs[0];
return var_is_op_input(sub_x_in, "elementwise_sub", "X") &&
sub_x_in->outputs[0]->outputs.size() == 1 &&
var_is_op_input(sub_x_in->outputs[0]->outputs[0], "elementwise_mul");
};
auto* x = pattern->NewNode(
[=](Node* x) { return is_fusion_input_var(x, "X"); }, name_scope + "/x");
auto* y = pattern->NewNode(
[=](Node* x) { return is_fusion_input_var(x, "Y"); }, name_scope + "/y");
auto* square_x_op = pattern->NewNode(
[=](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "square" &&
is_fusion_input_var(x->inputs[0], "X");
},
name_scope + "/squared_x_op");
auto* square_y_op = pattern->NewNode(
[=](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "square" &&
is_fusion_input_var(x->inputs[0], "Y");
},
name_scope + "/squared_y_op");
auto* squared_x = pattern->NewNode(
[=](Node* x) {
return x && x->inputs.size() == 1 && x->inputs[0]->inputs.size() == 1 &&
is_fusion_input_var(x->inputs[0]->inputs[0], "X");
},
name_scope + "/squared_x");
auto* squared_y = pattern->NewNode(
[=](Node* x) {
return x && x->inputs.size() == 1 && x->inputs[0]->inputs.size() == 1 &&
is_fusion_input_var(x->inputs[0]->inputs[0], "Y");
},
name_scope + "/squared_y");
auto* matmuled_xy =
pattern->NewNode([=](Node* x) { return is_fusion_first_mul_out(x); },
name_scope + "/matmuled_xy");
auto* matmul_xy_op = pattern->NewNode(
[=](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "matmul" &&
is_fusion_first_mul_out(x->outputs[0]);
},
name_scope + "/matmul_xy_op");
auto* square_matmuled_xy_op = pattern->NewNode(
[=](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "square" &&
is_fusion_first_mul_out(x->inputs[0]);
},
name_scope + "/square_matmuled_xy_op");
auto* squared_xmuly = pattern->NewNode(
[=](Node* x) {
return x && x->IsVar() && x->inputs.size() == 1 &&
x->inputs[0]->IsOp() && x->inputs[0]->Op()->Type() == "square" &&
is_fusion_first_mul_out(x->inputs[0]->inputs[0]);
},
name_scope + "/squared_xmuly");
auto is_fusion_mat_squared_x_y_op_out = [=](Node* x) -> bool {
bool basic = x && x->IsVar() && x->inputs.size() == 1 &&
x->inputs[0]->IsOp() && x->inputs[0]->Op()->Type() == "matmul";
if (!basic) {
return false;
}
auto* sqx = get_op_input_var(x->inputs[0], "X");
auto* sqy = get_op_input_var(x->inputs[0], "Y");
return var_is_op_only_output(sqx, "square") &&
var_is_op_only_output(sqy, "square") && sqx->inputs[0] &&
sqx->inputs[0]->inputs.size() == 1 &&
is_fusion_input_var(sqx->inputs[0]->inputs[0], "X") &&
sqy->inputs[0] && sqy->inputs[0]->inputs.size() == 1 &&
is_fusion_input_var(sqy->inputs[0]->inputs[0], "Y");
};
auto* matmul_squared_x_y_op = pattern->NewNode(
[=](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "matmul" &&
is_fusion_mat_squared_x_y_op_out(x->outputs[0]);
},
name_scope + "/matmul_squared_x_y_op");
auto* mat_squared_x_y_op_out = pattern->NewNode(
[=](Node* x) { return is_fusion_mat_squared_x_y_op_out(x); },
name_scope + "/mat_squared_x_y_op_out");
auto is_fusion_sub_op = [=](Node* x) -> bool {
bool is_sub_op = x && x->IsOp() && x->Op()->Type() == "elementwise_sub";
if (!is_sub_op) {
return false;
}
auto* matmul_sqx_sqy_var = get_op_input_var(x, "Y");
return is_fusion_mat_squared_x_y_op_out(matmul_sqx_sqy_var);
};
auto* sub_op = pattern->NewNode([=](Node* x) { return is_fusion_sub_op(x); },
name_scope + "/sub_op");
auto* sub_op_out = pattern->NewNode(
[=](Node* x) {
return x && x->IsVar() && x->inputs.size() == 1 &&
is_fusion_sub_op(x->inputs[0]);
},
name_scope + "/sub_op_out");
auto is_fusion_element_op = [=](Node* x) -> bool {
bool is_elemul_op = x && x->IsOp() && x->Op()->Type() == "elementwise_mul";
if (!is_elemul_op) {
return false;
}
for (auto* in : x->inputs) {
if (in && in->inputs[0] && is_fusion_sub_op(in->inputs[0])) {
return true;
}
}
return false;
};
auto* elementmul_op =
pattern->NewNode([=](Node* x) { return is_fusion_element_op(x); },
name_scope + "/elementmul_op");
auto* constant_op = pattern->NewNode(
[=](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "fill_constant" &&
x->outputs.size() == 1 &&
is_fusion_element_op(x->outputs[0]->outputs[0]);
},
name_scope + "/fill_constant_op");
auto* constant_op_out = pattern->NewNode(
[=](Node* x) {
return x && x->IsVar() && var_is_op_input(x, "elementwise_mul") &&
x->inputs[0] && x->inputs[0]->IsOp() &&
x->inputs[0]->Op()->Type() == "fill_constant" && x->outputs[0] &&
is_fusion_element_op(x->outputs[0]);
},
name_scope + "/constant_op_out");
auto* last_out_var = pattern->NewNode(
[=](Node* x) {
return var_is_op_only_output(x, "elementwise_mul") &&
is_fusion_element_op(x->inputs[0]);
},
name_scope + "/out");
square_x_op->LinksFrom({x}).LinksTo({squared_x});
square_y_op->LinksFrom({y}).LinksTo({squared_y});
matmul_xy_op->LinksFrom({x, y}).LinksTo({matmuled_xy});
matmul_squared_x_y_op->LinksFrom({squared_x, squared_y})
.LinksTo({mat_squared_x_y_op_out});
square_matmuled_xy_op->LinksFrom({matmuled_xy}).LinksTo({squared_xmuly});
sub_op->LinksFrom({squared_xmuly, mat_squared_x_y_op_out})
.LinksTo({sub_op_out});
constant_op->LinksFrom({}).LinksTo({constant_op_out});
elementmul_op->LinksFrom({constant_op_out, sub_op_out})
.LinksTo({last_out_var});
return last_out_var;
}
static int BuildFusion(Graph* graph, const std::string& name_scope) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
BuildSquaredMatSubPattern(pattern, name_scope);
auto retrieve_node = [](const std::string& name,
const GraphPatternDetector::subgraph_t& subgraph,
const PDPattern& pat) -> Node* {
PADDLE_ENFORCE(subgraph.count(pat.RetrieveNode(name)),
"pattern has no Node called %s", name.c_str());
Node* p = subgraph.at(pat.RetrieveNode(name));
PADDLE_ENFORCE_NOT_NULL(p, "subgraph has no node %s", name.c_str());
return p;
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
LOG(INFO) << "handle sqaure mat sub fuse";
auto& fused_pattern = gpd.pattern();
auto* matx = retrieve_node(name_scope + "/x", subgraph, fused_pattern);
auto* maty = retrieve_node(name_scope + "/y", subgraph, fused_pattern);
auto* squaredx =
retrieve_node(name_scope + "/squared_x", subgraph, fused_pattern);
auto* squaredy =
retrieve_node(name_scope + "/squared_y", subgraph, fused_pattern);
auto* squaredxy =
retrieve_node(name_scope + "/squared_xmuly", subgraph, fused_pattern);
auto* last_out_var =
retrieve_node(name_scope + "/out", subgraph, fused_pattern);
auto* fill_constant_op = retrieve_node(name_scope + "/fill_constant_op",
subgraph, fused_pattern);
// Create New OpDesc
OpDesc op_desc;
op_desc.SetType("fusion_squared_mat_sub");
op_desc.SetInput("X", {matx->Name()});
op_desc.SetInput("Y", {maty->Name()});
op_desc.SetOutput("SquaredX", {squaredx->Name()});
op_desc.SetOutput("SquaredY", {squaredy->Name()});
op_desc.SetOutput("SquaredXY", {squaredxy->Name()});
op_desc.SetOutput("Out", {last_out_var->Name()});
op_desc.SetAttr("scalar", fill_constant_op->Op()->GetAttr("value"));
auto* op = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(matx, op);
IR_NODE_LINK_TO(maty, op);
IR_NODE_LINK_TO(op, squaredx);
IR_NODE_LINK_TO(op, squaredy);
IR_NODE_LINK_TO(op, squaredxy);
IR_NODE_LINK_TO(op, last_out_var);
std::unordered_set<const Node*> marked_nodes;
for (auto& item : subgraph) {
marked_nodes.insert(item.second);
}
marked_nodes.erase(matx);
marked_nodes.erase(maty);
marked_nodes.erase(squaredx);
marked_nodes.erase(squaredy);
marked_nodes.erase(squaredxy);
marked_nodes.erase(last_out_var);
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
std::unique_ptr<ir::Graph> SquaredMatSubFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
int fusion_count = BuildFusion(graph.get(), name_scope_);
AddStatis(fusion_count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(squared_mat_sub_fuse_pass,
paddle::framework::ir::SquaredMatSubFusePass);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
/**
* Fuse ( (A * B).^2 - (A.^2 * B.^2) ) .* scalar
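*
* The fused fusion_squared_mat_sub op takes X and Y, and emits SquaredX,
* SquaredY and SquaredXY as intermediate outputs in addition to Out.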
*/
class SquaredMatSubFusePass : public FusePassBase {
public:
virtual ~SquaredMatSubFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
const std::string name_scope_{"squared_mat_sub_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -98,6 +98,8 @@ class CpuPassStrategy : public PassStrategy {
"mul_gru_fuse_pass", //
"seq_concat_fc_fuse_pass", //
"fc_fuse_pass", //
"repeated_fc_relu_fuse_pass", //
"squared_mat_sub_fuse_pass", //
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
"is_test_pass", //
......
......@@ -37,15 +37,21 @@ function(inference_analysis_api_test_with_refer_result target install_dir filena
--refer_result=${install_dir}/result.txt)
endfunction()
if(NOT APPLE AND WITH_MKLML)
# RNN1
set(RNN1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn1")
download_model_and_data(${RNN1_INSTALL_DIR} "rnn1%2Fmodel.tar.gz" "rnn1%2Fdata.txt.tar.gz")
inference_analysis_api_test(test_analyzer_rnn1 ${RNN1_INSTALL_DIR} analyzer_rnn1_tester.cc SERIAL)
# seq_pool1
set(SEQ_POOL1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/seq_pool")
download_model_and_data(${SEQ_POOL1_INSTALL_DIR} "seq_pool1_model_.tar.gz" "seq_pool1_data.txt.tar.gz")
inference_analysis_api_test(test_analyzer_seq_pool1 ${SEQ_POOL1_INSTALL_DIR} analyzer_seq_pool1_tester.cc SERIAL)
else()
# TODO: fix these tests on macOS and OpenBLAS; they are disabled because
# fusion_seqexpand_concat_fc_op is not supported on macOS or with OpenBLAS
message(WARNING "These tests have been disabled on macOS or with WITH_MKL=OFF until fixed: \n test_analyzer_rnn1")
message(WARNING "These tests have been disabled on macOS or with WITH_MKL=OFF until fixed: \n test_analyzer_seq_pool1")
endif()
# RNN2
......@@ -90,11 +96,6 @@ set(SEQ_CONV1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/seq_conv1")
download_model_and_data(${SEQ_CONV1_INSTALL_DIR} "seq_conv1_model.tar.gz" "seq_conv1_data.txt.tar.gz")
inference_analysis_api_test(test_analyzer_seq_conv1 ${SEQ_CONV1_INSTALL_DIR} analyzer_seq_conv1_tester.cc)
# ocr
set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr")
if (NOT EXISTS ${OCR_INSTALL_DIR})
......
......@@ -21,6 +21,12 @@ namespace paddle {
namespace inference {
namespace analysis {
// diff: similarity_norm.tmp_0, for speed: fc_4.tmp_1
static const char out_var_name[] = "reduce_sum_0.tmp_0";
// for diff: 154, for speed 111
constexpr int num_slots = 154;
struct OneSlotInBatch {
std::string name;
std::vector<std::vector<float>> data;
......@@ -41,7 +47,6 @@ struct DataRecord {
void Load(const std::string &path) {
std::ifstream file(path);
std::string line;
int num_lines = 0;
while (std::getline(file, line)) {
......@@ -187,11 +192,15 @@ void analysis_fuse_statis(bool use_zerocopy) {
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops);
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
ASSERT_EQ(fuse_statis.at("fc_fuse"), 10);
ASSERT_TRUE(fuse_statis.count("seqpool_concat_fuse"));
ASSERT_TRUE(fuse_statis.count("squared_mat_sub_fuse"));
ASSERT_TRUE(fuse_statis.count("repeated_fc_relu_fuse"));
ASSERT_EQ(fuse_statis.at("fc_fuse"), 10);
EXPECT_EQ(fuse_statis.at("seqpool_concat_fuse"), 2);
EXPECT_EQ(fuse_statis.at("squared_mat_sub_fuse"), 2);
EXPECT_EQ(fuse_statis.at("repeated_fc_relu_fuse"), 2);
LOG(INFO) << "num_ops: " << num_ops;
EXPECT_EQ(num_ops, 195);
EXPECT_EQ(num_ops, 171);
}
// Check the fuse status
......@@ -214,9 +223,6 @@ void PrepareZeroCopyInputs(
}
}
// return the output values
std::vector<float> zerocopy_profile(int repeat_times) {
AnalysisConfig config;
......@@ -322,7 +328,9 @@ TEST(Analyzer_seq_pool1, zerocopy_compare_native) {
native_outputs.front().data.length());
auto *native_data = static_cast<float *>(native_outputs.front().data.data());
for (size_t i = 0; i < zerocopy_output.size(); ++i) {
EXPECT_LT(
std::fabs((zerocopy_output[i] - native_data[i]) / zerocopy_output[i]),
1e-3);
}
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"
namespace paddle {
namespace operators {
void FusionRepeatedFCReluOp::InferShape(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FusionRepeatedFCReluOp should not be null.");
auto sz = ctx->Inputs("W").size();
PADDLE_ENFORCE_GT(
sz, 1UL, "Inputs(W) of FusionRepeatedFCReluOp should be larger than 1.");
PADDLE_ENFORCE_EQ(ctx->Inputs("Bias").size(), sz,
"Size of inputs(Bias) of FusionRepeatedFCReluOp should be "
"equal to inputs size.");
PADDLE_ENFORCE_EQ(ctx->Outputs("ReluOut").size(), sz - 1,
"Size of output(ReluOut) of FusionRepeatedFCReluOp should "
"be equal to inputs size -1.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FusionRepeatedFCReluOp should not be null.");
auto i_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(i_dims.size(), 2UL, "Input shape size should be 2");
auto w_dims = ctx->GetInputsDim("W");
auto b_dims = ctx->GetInputsDim("Bias");
PADDLE_ENFORCE_EQ(w_dims.size(), b_dims.size(),
"Shape size of weight and bias should be equal");
PADDLE_ENFORCE_EQ(w_dims.size(), sz,
"Size of weights should be equal to Inputs(W) size.");
PADDLE_ENFORCE_EQ(i_dims[1], w_dims[0][0],
"Input width should be equal to weight height.");
for (size_t i = 1; i < sz; ++i) {
PADDLE_ENFORCE_EQ(w_dims[i].size(), 2UL,
"Every weight shape size should be 2.");
PADDLE_ENFORCE_EQ(framework::product(b_dims[i]), w_dims[i][1],
"The length of Bias must be equal to w_dims[i][1].");
}
ctx->SetOutputDim("Out", {i_dims[0], w_dims[sz - 1][1]});
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType FusionRepeatedFCReluOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")),
ctx.GetPlace());
}
void FusionRepeatedFCReluOpMaker::Make() {
AddInput("X", "(LoDTensor) Input tensors of this operator.");
AddInput("W", "(Tensor) The weight tensors of this operator.").AsDuplicable();
AddInput("Bias", "(Tensor) The bias tensors of this operator.")
.AsDuplicable();
AddOutput("ReluOut", "(Tensor) The output tensor of each relu operator.")
.AsDuplicable()
.AsIntermediate();
AddOutput("Out", "(LoDTensor) Output tensor of this operator.");
AddComment(R"DOC(
Fusion Repeated FC with Relu Operator.
)DOC");
}
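// Computes y = relu(x * w + b) for an m x k input and k x n weight: one jit
// matmul followed by a fused bias-add + relu over each of the m output rows.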
template <typename T>
static void fc_relu(const T* x, const T* w, const T* b, T* y, int m, int n,
int k) {
auto matmul =
jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(k);
auto addbias_relu =
jit::Get<jit::kVAddRelu, jit::XYZNTuples<T>, platform::CPUPlace>(n);
matmul(x, w, y, m, n, k);
T* dst = y;
for (int i = 0; i < m; ++i) {
addbias_relu(b, dst, dst, n);
dst += n;
}
}
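// Chains fc_relu over all weights: X -> ReluOut[0] -> ... ->
// ReluOut[weight_sz - 2] -> Out, resizing each intermediate to {m, n}.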
template <typename T>
class FusionRepeatedFCReluKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto in = ctx.Input<Tensor>("X");
auto weights = ctx.MultiInput<Tensor>("W");
auto biases = ctx.MultiInput<Tensor>("Bias");
auto relus = ctx.MultiOutput<Tensor>("ReluOut");
auto* out = ctx.Output<Tensor>("Out");
auto place = ctx.GetPlace();
int weight_sz = static_cast<int>(weights.size());
auto i_dims = in->dims();
auto w_dims = weights[0]->dims();
int m = i_dims[0];
int n = w_dims[1];
int k = w_dims[0];
relus[0]->Resize({m, n});
fc_relu(in->data<T>(), weights[0]->data<T>(), biases[0]->data<T>(),
relus[0]->mutable_data<T>(place), m, n, k);
for (int i = 1; i < weight_sz - 1; ++i) {
auto i_dims = relus[i - 1]->dims();
auto w_dims = weights[i]->dims();
int m = i_dims[0];
int n = w_dims[1];
int k = w_dims[0];
relus[i]->Resize({m, n});
fc_relu(relus[i - 1]->data<T>(), weights[i]->data<T>(),
biases[i]->data<T>(), relus[i]->mutable_data<T>(place), m, n, k);
}
auto i_dims_last = relus[weight_sz - 2]->dims();
auto w_dims_last = weights[weight_sz - 1]->dims();
m = i_dims_last[0];
n = w_dims_last[1];
k = w_dims_last[0];
fc_relu(relus[weight_sz - 2]->data<T>(), weights[weight_sz - 1]->data<T>(),
biases[weight_sz - 1]->data<T>(), out->mutable_data<T>(place), m, n,
k);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_repeated_fc_relu, ops::FusionRepeatedFCReluOp,
ops::FusionRepeatedFCReluOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(fusion_repeated_fc_relu,
ops::FusionRepeatedFCReluKernel<float>,
ops::FusionRepeatedFCReluKernel<double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class FusionRepeatedFCReluOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FusionRepeatedFCReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"
namespace paddle {
namespace operators {
void FusionSquaredMatSubOp::InferShape(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FusionSquaredMatSubOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of FusionSquaredMatSubOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("SquaredX"),
"Output(SquaredX) of FusionSquaredMatSubOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("SquaredY"),
"Output(SquaredY) of FusionSquaredMatSubOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("SquaredXY"),
"Output(SquaredXY) of FusionSquaredMatSubOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FusionSquaredMatSubOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
"The rank of Input(X) and Input(Y) should be equal.");
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input tensors should be matrices.");
PADDLE_ENFORCE_EQ(x_dims[1], y_dims[0],
"The input matrices should be multiplicable.");
ctx->SetOutputDim("SquaredX", x_dims);
ctx->SetOutputDim("SquaredY", y_dims);
ctx->SetOutputDim("SquaredXY", {x_dims[0], y_dims[1]});
ctx->SetOutputDim("Out", {x_dims[0], y_dims[1]});
}
framework::OpKernelType FusionSquaredMatSubOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")),
ctx.GetPlace());
}
void FusionSquaredMatSubOpMaker::Make() {
AddInput("X", "(Tensor) Input Mat A of this operator.");
AddInput("Y", "(Tensor) Input Mat B of this operator.");
AddOutput("SquaredX", "(Tensor) Squared X.").AsIntermediate();
AddOutput("SquaredY", "(Tensor) Squared Y.").AsIntermediate();
AddOutput("SquaredXY", "(Tensor) Squared X*Y.").AsIntermediate();
AddOutput("Out", "(Tensor) Output tensor of concat operator.");
AddAttr<float>("scalar", "The scalar on output matrix.").SetDefault(1.f);
AddComment(R"DOC(
Fusion Squared Matrix Subtract Operator.
( (X * Y).^2 - (X.^2 * Y.^2) ) .* scalar
)DOC");
}
template <typename T>
class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto x = ctx.Input<Tensor>("X");
auto y = ctx.Input<Tensor>("Y");
auto* squared_x = ctx.Output<Tensor>("SquaredX");
auto* squared_y = ctx.Output<Tensor>("SquaredY");
auto* squared_xy = ctx.Output<Tensor>("SquaredXY");
auto* out = ctx.Output<Tensor>("Out");
auto place = ctx.GetPlace();
T scalar = static_cast<T>(ctx.Attr<float>("scalar"));
auto x_dims = x->dims();
auto y_dims = y->dims();
int m = x_dims[0];
int k = x_dims[1];
int n = y_dims[1];
int o_numel = m * n;
auto vsquare_x =
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(m * k);
auto vsquare_y =
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(k * n);
auto vsquare_xy =
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(o_numel);
auto vsub =
jit::Get<jit::kVSub, jit::XYZNTuples<T>, platform::CPUPlace>(o_numel);
auto vscal =
jit::Get<jit::kVScal, jit::AXYNTuples<T>, platform::CPUPlace>(o_numel);
auto matmul =
jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(k);
const T* x_data = x->data<T>();
const T* y_data = y->data<T>();
T* squared_x_data = squared_x->mutable_data<T>(place);
T* squared_y_data = squared_y->mutable_data<T>(place);
T* squared_xy_data = squared_xy->mutable_data<T>(place);
T* o_data = out->mutable_data<T>(place);
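// Out = ((X * Y).^2 - (X.^2 * Y.^2)) .* scalar, computed as:
// squared_xy = square(X * Y); out = squared_xy - square(X) * square(Y);
// then scaled in place by `scalar`.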
matmul(x_data, y_data, squared_xy_data, m, n, k);
vsquare_xy(squared_xy_data, squared_xy_data, o_numel);
vsquare_x(x_data, squared_x_data, m * k);
vsquare_y(y_data, squared_y_data, k * n);
matmul(squared_x_data, squared_y_data, o_data, m, n, k);
vsub(squared_xy_data, o_data, o_data, o_numel);
vscal(&scalar, o_data, o_data, o_numel);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_squared_mat_sub, ops::FusionSquaredMatSubOp,
ops::FusionSquaredMatSubOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(fusion_squared_mat_sub,
ops::FusionSquaredMatSubKernel<float>,
ops::FusionSquaredMatSubKernel<double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
// ( (A * B).^2 - (A.^2 * B.^2) ) .* scalar
class FusionSquaredMatSubOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FusionSquaredMatSubOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle
......@@ -210,6 +210,24 @@ void BenchSeqPoolKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchMatMulKernel() {
for (int m : {1, 2, 3, 4}) {
for (int n : TestSizes()) {
for (int k : TestSizes()) {
std::vector<T> a(m * k), b(k * n), c(m * n);
RandomVec<T>(m * k, a.data(), -2.f, 2.f);
RandomVec<T>(k * n, b.data(), -2.f, 2.f);
const T* a_data = a.data();
const T* b_data = b.data();
T* c_data = c.data();
BenchAllImpls<KT, jit::MatMulTuples<T>, PlaceType>(k, a_data, b_data,
c_data, m, n, k);
}
}
}
}
// Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...]
// Options:
......@@ -236,6 +254,7 @@ int main(int argc, char* argv[]) {
// xyn
BenchXYNKernel<jit::kVRelu, T, PlaceType>();
BenchXYNKernel<jit::kVIdentity, T, PlaceType>();
BenchXYNKernel<jit::kVSquare, T, PlaceType>();
BenchXYNKernel<jit::kVExp, T, PlaceType>();
BenchXYNKernel<jit::kVSigmoid, T, PlaceType>();
BenchXYNKernel<jit::kVTanh, T, PlaceType>();
......@@ -251,4 +270,7 @@ int main(int argc, char* argv[]) {
// seq pool function
BenchSeqPoolKernel<jit::kSeqPool, T, PlaceType>();
// matmul
BenchMatMulKernel<jit::kMatMul, T, PlaceType>();
}
......@@ -11,11 +11,12 @@ endfunction()
# use gen jitcode kernel by name
USE_JITKERNEL_GEN(kVMul)
USE_JITKERNEL_GEN(kVAdd)
USE_JITKERNEL_GEN(kVSub)
USE_JITKERNEL_GEN(kVAddRelu)
USE_JITKERNEL_GEN(kVScal)
USE_JITKERNEL_GEN(kVAddBias)
USE_JITKERNEL_GEN(kVRelu)
USE_JITKERNEL_GEN(kVSquare)
USE_JITKERNEL_GEN(kVIdentity)
USE_JITKERNEL_GEN(kVExp)
USE_JITKERNEL_GEN(kVSigmoid)
......
......@@ -91,6 +91,7 @@ void VActJitCode::genCode() {
}
DECLARE_ACT_CREATOR(VRelu);
DECLARE_ACT_CREATOR(VSquare);
DECLARE_ACT_CREATOR(VIdentity);
DECLARE_ACT_CREATOR(VExp);
DECLARE_ACT_CREATOR(VSigmoid);
......@@ -103,6 +104,10 @@ size_t VReluCreator::CodeSize(const int& d) const {
8 /* average bytes for each instruction */;
}
size_t VSquareCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 4 * 8;
}
size_t VIdentityCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 4 * 8;
}
......@@ -129,6 +134,7 @@ size_t VTanhCreator::CodeSize(const int& d) const {
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kVRelu, gen::VReluCreator);
REGISTER_JITKERNEL_GEN(kVSquare, gen::VSquareCreator);
REGISTER_JITKERNEL_GEN(kVIdentity, gen::VIdentityCreator);
REGISTER_JITKERNEL_GEN(kVExp, gen::VExpCreator);
REGISTER_JITKERNEL_GEN(kVSigmoid, gen::VSigmoidCreator);
......
......@@ -75,6 +75,12 @@ class VActFunc : public JitCode {
vmaxps(dst, src, zero);
}
// compute SQUARE with ymm, xmm
template <typename JMM>
void square_jmm(JMM& dst, JMM& src) { // NOLINT
vmulps(dst, src, src);
}
// compute EXP with ymm, xmm
template <typename JMM>
void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT
......@@ -228,6 +234,9 @@ class VActFunc : public JitCode {
case operand_type::RELU:
relu_jmm<JMM>(dst, src, 15);
break;
case operand_type::SQUARE:
square_jmm<JMM>(dst, src);
break;
case operand_type::EXP:
exp_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break;
......@@ -254,7 +263,7 @@ class VActJitCode : public VActFunc {
: VActFunc(code_size, code_ptr), num_(d), type_(type) {
if (!(type_ == operand_type::RELU || type_ == operand_type::EXP ||
type_ == operand_type::SIGMOID || type_ == operand_type::TANH ||
type_ == operand_type::IDENTITY || type_ == operand_type::SQUARE)) {
LOG(FATAL) << "Do not support this operand type: " << type_;
}
this->genCode();
......@@ -266,6 +275,9 @@ class VActJitCode : public VActFunc {
case operand_type::RELU:
base += "_Relu";
break;
case operand_type::SQUARE:
base += "_Square";
break;
case operand_type::EXP:
base += "_Exp";
break;
......@@ -306,6 +318,7 @@ class VActJitCode : public VActFunc {
};
DECLARE_ACT_JITCODE(VRelu, operand_type::RELU);
DECLARE_ACT_JITCODE(VSquare, operand_type::SQUARE);
DECLARE_ACT_JITCODE(VIdentity, operand_type::IDENTITY);
DECLARE_ACT_JITCODE(VExp, operand_type::EXP);
DECLARE_ACT_JITCODE(VSigmoid, operand_type::SIGMOID);
......
......@@ -43,6 +43,8 @@ void VXXJitCode::genCode() {
vmulps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::ADD) {
vaddps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::SUB) {
vsubps(ymm_dst, ymm_src1, ymm_src2);
}
if (with_relu_) {
vmaxps(ymm_dst, ymm_zero, ymm_dst);
......@@ -85,6 +87,9 @@ void VXXJitCode::genCode() {
case operand_type::ADD:
vaddps(xmm_dst, xmm_src1, xmm_src2);
break;
case operand_type::SUB:
vsubps(xmm_dst, xmm_src1, xmm_src2);
break;
default:
break;
}
......@@ -178,8 +183,7 @@ namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kVMul, gen::VMulCreator);
REGISTER_JITKERNEL_GEN(kVAdd, gen::VAddCreator);
REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator);
REGISTER_JITKERNEL_GEN(kVAddRelu, gen::VAddReluCreator);
REGISTER_JITKERNEL_GEN(kVScal, gen::VScalCreator);
REGISTER_JITKERNEL_GEN(kVAddBias, gen::VAddBiasCreator);
......
......@@ -34,7 +34,8 @@ class VXXJitCode : public JitCode {
type_(type),
scalar_index_(scalar_index),
with_relu_(with_relu) {
if (!(type_ == operand_type::MUL || type_ == operand_type::ADD ||
type_ == operand_type::SUB)) {
LOG(FATAL) << "Do not support this operand type: " << type_;
}
this->genCode();
......@@ -51,6 +52,8 @@ class VXXJitCode : public JitCode {
base += "_Mul";
} else if (type_ == operand_type::ADD) {
base += "_Add";
} else if (type_ == operand_type::SUB) {
base += "_SUB";
}
if (scalar_index_ == 2) {
base += "_Scalar";
......
......@@ -51,6 +51,7 @@ typedef enum {
SUB,
RELU,
EXP,
SQUARE,
SIGMOID,
TANH,
IDENTITY
......
......@@ -36,6 +36,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kVRelu);
ONE_CASE(kVIdentity);
ONE_CASE(kVExp);
ONE_CASE(kVSquare);
ONE_CASE(kVSigmoid);
ONE_CASE(kVTanh);
ONE_CASE(kLSTMCtHt);
......@@ -47,6 +48,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kLayerNorm);
ONE_CASE(kNCHW16CMulNC);
ONE_CASE(kSeqPool);
ONE_CASE(kMatMul);
default:
PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
return "NOT JITKernel";
......
......@@ -30,6 +30,7 @@ typedef enum {
kVAddBias,
kVRelu,
kVIdentity,
kVSquare,
kVExp,
kVSigmoid,
kVTanh,
......@@ -42,6 +43,7 @@ typedef enum {
kLayerNorm,
kNCHW16CMulNC,
kSeqPool,
kMatMul,
} KernelType;
typedef enum {
......@@ -135,6 +137,13 @@ struct SeqPoolTuples {
typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*);
};
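// func(a, b, c, m, n, k) computes C(m, n) = A(m, k) * B(k, n); the attr is k.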
template <typename T>
struct MatMulTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int, int, int);
};
template <typename T>
struct CRFDecodingTuples {
typedef T data_type;
......
......@@ -3,10 +3,12 @@ cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE)
# use mkl kernels by name and type
USE_JITKERNEL_MORE(kMatMul, mkl)
USE_JITKERNEL_MORE(kVMul, mkl)
USE_JITKERNEL_MORE(kVAdd, mkl)
USE_JITKERNEL_MORE(kVScal, mkl)
USE_JITKERNEL_MORE(kVExp, mkl)
USE_JITKERNEL_MORE(kVSquare, mkl)
USE_JITKERNEL_MORE(kVSigmoid, mkl)
USE_JITKERNEL_MORE(kVTanh, mkl)
USE_JITKERNEL_MORE(kSeqPool, mkl)
......@@ -24,6 +24,20 @@ namespace jit {
namespace more {
namespace mkl {
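// Row-major GEMM without transposition: C(m, n) = 1.0 * A(m, k) * B(k, n).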
template <>
void MatMul<float>(const float* a, const float* b, float* c, int m, int n,
int k) {
platform::dynload::cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m,
n, k, 1.f, a, k, b, n, 0.f, c, n);
}
template <>
void MatMul<double>(const double* a, const double* b, double* c, int m, int n,
int k) {
platform::dynload::cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m,
n, k, 1.0, a, k, b, n, 0.0, c, n);
}
template <>
void VMul<float>(const float* x, const float* y, float* z, int n) {
platform::dynload::vsMul(n, x, y, z);
......@@ -72,6 +86,16 @@ void VExp<double>(const double* x, double* y, int n) {
platform::dynload::vdExp(n, x, y);
}
template <>
void VSquare<float>(const float* x, float* y, int n) {
platform::dynload::vsSqr(n, x, y);
}
template <>
void VSquare<double>(const double* x, double* y, int n) {
platform::dynload::vdSqr(n, x, y);
}
template <>
void VCopy<float>(const float* x, float* y, int n) {
platform::dynload::cblas_scopy(n, x, 1, y, 1);
......@@ -93,6 +117,11 @@ void VAXPY<double>(double a, const double* x, double* y, int n) {
}
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template <>
bool MatMulKernel<float>::UseMe(const int& d) const {
return platform::MayIUse(platform::avx);
}
template <>
bool VMulKernel<float>::UseMe(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
......@@ -113,6 +142,11 @@ bool VExpKernel<float>::UseMe(const int& d) const {
return d > 7;
}
template <>
bool VSquareKernel<float>::UseMe(const int& d) const {
return d > 7;
}
template <>
bool VSigmoidKernel<float>::UseMe(const int& d) const {
return d > 7;
......@@ -139,12 +173,14 @@ bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
return true; \
}
AWALYS_USE_ME_WITH_DOUBLE(MatMul);
AWALYS_USE_ME_WITH_DOUBLE(VMul);
AWALYS_USE_ME_WITH_DOUBLE(VAdd);
AWALYS_USE_ME_WITH_DOUBLE(VScal);
AWALYS_USE_ME_WITH_DOUBLE(VExp);
AWALYS_USE_ME_WITH_DOUBLE(VSigmoid);
AWALYS_USE_ME_WITH_DOUBLE(VTanh);
AWALYS_USE_ME_WITH_DOUBLE(VSquare);
#undef AWALYS_USE_ME_WITH_DOUBLE
} // namespace mkl
......@@ -159,10 +195,12 @@ namespace mkl = paddle::operators::jit::more::mkl;
REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \
mkl::func##Kernel<double>)
REGISTER_MKL_KERNEL(kMatMul, MatMul);
REGISTER_MKL_KERNEL(kVMul, VMul);
REGISTER_MKL_KERNEL(kVAdd, VAdd);
REGISTER_MKL_KERNEL(kVScal, VScal);
REGISTER_MKL_KERNEL(kVExp, VExp);
REGISTER_MKL_KERNEL(kVSquare, VSquare);
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MKL_KERNEL(kVTanh, VTanh);
REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
......
......@@ -24,6 +24,9 @@ namespace jit {
namespace more {
namespace mkl {
template <typename T>
void MatMul(const T* a, const T* b, T* c, int m, int n, int k);
template <typename T>
void VMul(const T* x, const T* y, T* z, int n);
......@@ -36,6 +39,9 @@ void VScal(const T* a, const T* x, T* y, int n);
template <typename T>
void VExp(const T* x, T* y, int n);
template <typename T>
void VSquare(const T* x, T* y, int n);
template <typename T>
void VCopy(const T* x, T* y, int n);
......@@ -93,6 +99,9 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
const char* ImplType() const override { return "MKL"; } \
}
// ABCMNK
DECLARE_MKL_KERNEL(MatMul, MatMulTuples);
// XYZN
DECLARE_MKL_KERNEL(VMul, XYZNTuples);
DECLARE_MKL_KERNEL(VAdd, XYZNTuples);
......@@ -104,6 +113,7 @@ DECLARE_MKL_KERNEL(VScal, AXYNTuples);
DECLARE_MKL_KERNEL(VExp, XYNTuples);
DECLARE_MKL_KERNEL(VSigmoid, XYNTuples);
DECLARE_MKL_KERNEL(VTanh, XYNTuples);
DECLARE_MKL_KERNEL(VSquare, XYNTuples);
DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples);
......
......@@ -27,3 +27,5 @@ USE_JITKERNEL_REFER(kCRFDecoding)
USE_JITKERNEL_REFER(kLayerNorm)
USE_JITKERNEL_REFER(kNCHW16CMulNC)
USE_JITKERNEL_REFER(kSeqPool)
USE_JITKERNEL_REFER(kMatMul)
USE_JITKERNEL_REFER(kVSquare)
......@@ -31,6 +31,7 @@ REGISTER_REFER_KERNEL(kVAddBias, VAddBias);
REGISTER_REFER_KERNEL(kVRelu, VRelu);
REGISTER_REFER_KERNEL(kVIdentity, VIdentity);
REGISTER_REFER_KERNEL(kVSquare, VSquare);
REGISTER_REFER_KERNEL(kVExp, VExp);
REGISTER_REFER_KERNEL(kVSigmoid, VSigmoid);
REGISTER_REFER_KERNEL(kVTanh, VTanh);
......@@ -49,4 +50,6 @@ REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC);
REGISTER_REFER_KERNEL(kSeqPool, SeqPool);
REGISTER_REFER_KERNEL(kMatMul, MatMul);
#undef REGISTER_REFER_KERNEL
......@@ -83,6 +83,13 @@ inline void VIdentity(const T* x, T* y, int n) {
}
}
template <typename T>
inline void VSquare(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] * x[i];
}
}
template <typename T>
void VExp(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
......@@ -354,6 +361,23 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
}
}
// A(M,K) * B(K,N) = C(M,N)
template <typename T>
void MatMul(const T* A, const T* B, T* C, int M, int N, int K) {
for (int m = 0; m < M; ++m) {
const T* pa = A + m * K;
T* pc = C + m * N;
for (int n = 0; n < N; ++n) {
const T* pb = B + n;
T sum = static_cast<T>(0);
for (int k = 0; k < K; ++k) {
sum += (pa[k] * pb[k * N]);
}
*(pc + n) = sum;
}
}
}
#define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \
......@@ -377,6 +401,7 @@ DECLARE_REFER_KERNEL(VIdentity, XYNTuples);
DECLARE_REFER_KERNEL(VExp, XYNTuples);
DECLARE_REFER_KERNEL(VSigmoid, XYNTuples);
DECLARE_REFER_KERNEL(VTanh, XYNTuples);
DECLARE_REFER_KERNEL(VSquare, XYNTuples);
// lstm_t*, const lstm_attr_t*
DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples);
......@@ -394,6 +419,8 @@ DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples);
DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples);
DECLARE_REFER_KERNEL(MatMul, MatMulTuples);
#undef DECLARE_REFER_KERNEL
} // namespace refer
......
......@@ -229,6 +229,26 @@ struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>,
}
};
template <typename T>
struct TestFuncWithRefer<jit::MatMulTuples<T>, std::vector<T>, std::vector<T>,
std::vector<T>, int, int, int> {
void operator()(const typename jit::MatMulTuples<T>::func_type tgt,
const std::vector<T>& a, const std::vector<T>& b,
const std::vector<T>& cref, int m, int n, int k) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(a.size(), static_cast<size_t>(m * k));
EXPECT_EQ(b.size(), static_cast<size_t>(k * n));
EXPECT_EQ(cref.size(), static_cast<size_t>(m * n));
std::vector<T> c(cref.size());
const T* a_data = a.data();
const T* b_data = b.data();
const T* cref_data = cref.data();
T* c_data = c.data();
tgt(a_data, b_data, c_data, m, n, k);
ExpectEQ<T>(c_data, cref_data, m * n);
}
};
template <paddle::operators::jit::KernelType KT, typename KernelTuples,
typename PlaceType, typename... Args>
void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
......@@ -458,6 +478,28 @@ void TestSeqPoolKernel() {
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestMatMulKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int m : {1, 2, 3, 4}) {
for (int n : {1, 2, 3, 4}) {
for (int k : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::MatMulTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> a(m * k), b(k * n), c(m * n);
RandomVec<T>(m * k, a.data(), -0.2f, 0.2f);
RandomVec<T>(k * n, b.data(), -0.2f, 0.2f);
const T* a_data = a.data();
const T* b_data = b.data();
T* c_data = c.data();
ref(a_data, b_data, c_data, m, n, k);
TestAllImpls<KT, jit::MatMulTuples<T>, PlaceType, std::vector<T>,
std::vector<T>, std::vector<T>>(k, a, b, c, m, n, k);
}
}
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestNCHW16CMulNCKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
......@@ -562,6 +604,12 @@ TEST(JITKernel, kVIdentity) {
TestXYNKernel<jit::kVIdentity, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kVSquare) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::kVSquare, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::kVSquare, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kVExp) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::kVExp, float, paddle::platform::CPUPlace>();
......@@ -618,6 +666,12 @@ TEST(JITKernel, kSeqPool) {
TestSeqPoolKernel<jit::kSeqPool, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kMatMul) {
namespace jit = paddle::operators::jit;
TestMatMulKernel<jit::kMatMul, float, paddle::platform::CPUPlace>();
TestMatMulKernel<jit::kMatMul, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kNCHW16CMulNC) {
namespace jit = paddle::operators::jit;
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float,
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from test_fc_op import fc_refer, MatrixGenerate
class TestFusionRepeatedFCReluOp(OpTest):
def setUp(self):
self.bs = 3
self.ic = 9
self.oc = [2, 4, 3]
assert len(self.oc) > 1, 'Should be larger than 1'
self.set_conf()
self.op_type = 'fusion_repeated_fc_relu'
sz = len(self.oc)
ics = [self.ic] + self.oc[0:sz - 1]
assert len(ics) == len(self.oc)
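# reference chain: outs[i] = relu(outs[i - 1] * W_i + B_i), starting from
# the reshaped input X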
weights = []
biases = []
outs = []
i = 0
matrix = MatrixGenerate(self.bs, ics[i], self.oc[i], 1, 1)
inp = np.reshape(matrix.input, [self.bs, ics[i]])
weights.append(('W_{0}'.format(i), np.reshape(matrix.weights,
[ics[i], self.oc[i]])))
biases.append(('B_{0}'.format(i), matrix.bias))
outs.append(
np.reshape(
np.maximum(fc_refer(matrix, True), 0), [self.bs, self.oc[i]]))
for i in range(sz - 1):
matrix = MatrixGenerate(self.bs, ics[i + 1], self.oc[i + 1], 1, 1)
matrix.input = np.reshape(outs[i], [self.bs, ics[i + 1], 1, 1])
out = fc_refer(matrix, True)
weights.append(
('W_{0}'.format(i + 1),
np.reshape(matrix.weights, [ics[i + 1], self.oc[i + 1]])))
biases.append(('B_{0}'.format(i + 1), matrix.bias))
outs.append(
np.reshape(np.maximum(out, 0), [self.bs, self.oc[i + 1]]))
relu_outs = []
for i in range(sz - 1):
relu_outs.append(('ReluOut_{0}'.format(i), outs[i]))
self.inputs = {
'X': inp,
'W': weights,
'Bias': biases,
}
self.outputs = {'Out': outs[-1], 'ReluOut': relu_outs}
def test_check_output(self):
self.check_output()
def set_conf(self):
pass
class TestFusionRepeatedFCReluOpBS1(TestFusionRepeatedFCReluOp):
def set_conf(self):
self.bs = 1
self.oc = [4, 2, 7, 5]
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
class TestFusionSquaredMatSubOp(OpTest):
def setUp(self):
self.op_type = 'fusion_squared_mat_sub'
self.m = 11
self.n = 12
self.k = 4
self.scalar = 0.5
self.set_conf()
matx = np.random.random((self.m, self.k)).astype("float32")
maty = np.random.random((self.k, self.n)).astype("float32")
self.inputs = {'X': matx, 'Y': maty}
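# reference result of the fused op: ((X*Y)^2 - (X^2 * Y^2)) * scalar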
self.outputs = {
'Out':
(np.dot(matx, maty)**2 - np.dot(matx**2, maty**2)) * self.scalar
}
self.attrs = {'scalar': self.scalar, }
def set_conf(self):
pass
def test_check_output(self):
self.check_output()
class TestFusionSquaredMatSubOpCase1(TestFusionSquaredMatSubOp):
def set_conf(self):
self.scalar = -0.3
if __name__ == '__main__':
unittest.main()