diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 42fb6a1aa5375bfbb266454cfbc7f0fb756f779c..84b5321264796b18ac5bf666bc4b4ac403d9e4ea 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -43,6 +43,8 @@ pass_library(multi_batch_merge_pass base) pass_library(conv_bn_fuse_pass inference) pass_library(seqconv_eltadd_relu_fuse_pass inference) pass_library(seqpool_concat_fuse_pass inference) +pass_library(repeated_fc_relu_fuse_pass inference) +pass_library(squared_mat_sub_fuse_pass inference) pass_library(is_test_pass base) pass_library(conv_elementwise_add_act_fuse_pass inference) pass_library(conv_elementwise_add2_act_fuse_pass inference) diff --git a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..84a4ff2de173d86184fcef53b8e55fe17958fb8c --- /dev/null +++ b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc @@ -0,0 +1,386 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#include "paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h" +#include // for max +#include +#include +#include "paddle/fluid/framework/lod_tensor.h" + +#define MAX_NUM_FC 10 + +namespace paddle { +namespace framework { +namespace ir { + +PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern, + const std::string& name_scope, int num_fc) { + auto var_next_is_fc_act = [=](Node* x, const std::string& act_type = "relu", + bool check_in_has_only_one_out = true, + int fc_idx = 0) -> bool { + bool next_is_fc = x && x->IsVar() && VarLinksToOp(x, "fc"); + if (check_in_has_only_one_out) { + next_is_fc = next_is_fc && x->outputs.size() == 1; + } + if (!next_is_fc) { + return false; + } + auto* fc_op = x->outputs[fc_idx]; + bool next_is_act = fc_op && fc_op->IsOp() && fc_op->outputs.size() == 1 && + fc_op->outputs[0] && fc_op->outputs[0]->IsVar() && + VarLinksToOp(fc_op->outputs[0], act_type) && + fc_op->outputs[0]->outputs.size() == 1; + if (!next_is_act) { + return false; + } + auto* act_op = fc_op->outputs[0]->outputs[0]; + return act_op && act_op->IsOp() && act_op->outputs.size() == 1; + }; + + auto find_fc_idx = [=](Node* x, const std::string& act_type = "relu") -> int { + bool next_is_fc = x && x->IsVar() && VarLinksToOp(x, "fc"); + if (!next_is_fc) { + return 0; + } + for (size_t k = 0; k < x->outputs.size(); ++k) { + auto* fc_op = x->outputs[k]; + bool next_is_act = fc_op && fc_op->IsOp() && fc_op->outputs.size() == 1 && + fc_op->outputs[0] && fc_op->outputs[0]->IsVar() && + VarLinksToOp(fc_op->outputs[0], act_type) && + fc_op->outputs[0]->outputs.size() == 1; + if (!next_is_act) { + continue; + } + auto* act_op = fc_op->outputs[0]->outputs[0]; + if (act_op && act_op->IsOp() && act_op->outputs.size() == 1) { + return k; + } + } + return 0; + }; + + auto next_var_of_part = [=](Node* x, int fc_idx = 0) -> Node* { + return x->outputs[fc_idx]->outputs[0]->outputs[0]->outputs[0]; + }; + auto var_next_is_fc_act_repeated_n_times = [=]( + Node* x, int 
repeated_times, const std::string& act_type = "relu", + bool check_in_has_only_one_out = true) -> bool { + for (int i = 0; i < repeated_times; ++i) { + if (!var_next_is_fc_act(x, act_type, + i == 0 && check_in_has_only_one_out)) { + return false; + } + x = next_var_of_part(x); + } + return true; + }; + + auto var_before_is_fc_act = [=](Node* x, const std::string& act_type = "relu", + bool at_top = false) -> bool { + bool before_is_act = + x && x->IsVar() && x->inputs.size() == 1 && VarLinksFromOp(x, "relu"); + if (!before_is_act) { + return false; + } + auto* relu_op = x->inputs[0]; + bool before_is_fc = relu_op->IsOp() && relu_op->inputs.size() == 1 && + relu_op->inputs[0]->IsVar() && + VarLinksFromOp(relu_op->inputs[0], "fc") && + relu_op->inputs[0]->inputs.size() == 1; + + if (!before_is_fc) { + return false; + } + auto* fc_op = relu_op->inputs[0]->inputs[0]; + bool is_fc = fc_op->IsOp() && fc_op->inputs.size() == 3; + if (!is_fc) { + return false; + } + for (auto* fc_i : fc_op->inputs) { + if (!fc_i->inputs.empty()) { + if (at_top) { + return true; + } else { + return VarLinksFromOp(fc_i, "relu"); + } + } + } + return false; + }; + + auto before_var_of_part = [=](Node* x) -> Node* { + auto* fc_op = x->inputs[0]->inputs[0]; + for (auto* fc_i : fc_op->inputs) { + if (!fc_i->inputs.empty()) { + return fc_i->inputs[0]; + } + } + return nullptr; + }; + + auto var_before_is_fc_act_repeated_n_times = [=]( + Node* x, int repeated_times, + const std::string& act_type = "relu") -> bool { + for (int i = 0; i < repeated_times; ++i) { + if (!var_before_is_fc_act(x, act_type, i == repeated_times - 1)) { + return false; + } + x = before_var_of_part(x); + } + return true; + }; + + std::vector fc_input_var(num_fc); + std::vector fc_output_var(num_fc); + std::vector fc_weight_var(num_fc); + std::vector fc_bias_var(num_fc); + std::vector fc_ops(num_fc); + std::vector relu_ops(num_fc); + + for (int i = 0; i < num_fc; ++i) { + fc_input_var[i] = pattern->NewNode( + [=](Node* x) { + 
if (i == 0 && x->outputs.size() > 0) { + bool ok = x->inputs.size() > 0; + if (!ok) { + return false; + } + int idx = find_fc_idx(x); + if (idx == 0) { + return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu"); + } else { + x = next_var_of_part(x, idx); + return var_next_is_fc_act_repeated_n_times( + x, std::max(1, num_fc - i - 1), "relu"); + } + } else { + return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") && + x->inputs.size() > 0 && + var_before_is_fc_act_repeated_n_times(x, i, "relu"); + } + }, + name_scope + "/fc_in_" + std::to_string(i)); + + fc_weight_var[i] = pattern->NewNode( + [=](Node* x) { + return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") && + x->inputs.empty() && + var_before_is_fc_act_repeated_n_times(x->outputs[0]->inputs[0], + i, "relu") && + x->Name() == x->outputs[0]->Op()->Input("W")[0]; + }, + name_scope + "/fc_weight_" + std::to_string(i)); + + fc_bias_var[i] = pattern->NewNode( + [=](Node* x) { + return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") && + x->inputs.empty() && + var_before_is_fc_act_repeated_n_times(x->outputs[0]->inputs[0], + i, "relu") && + x->Name() == x->outputs[0]->Op()->Input("Bias")[0]; + }, + name_scope + "/fc_bias_" + std::to_string(i)); + + fc_output_var[i] = pattern->NewNode( + [=](Node* x) { + bool basic = x && x->IsVar() && VarLinksFromOp(x, "fc") && + VarLinksToOp(x, "relu") && x->inputs.size() == 1 && + x->inputs[0]->inputs.size() == 3; + if (!basic) { + return false; + } + x = x->inputs[0]->inputs[0]; + if (i == 0 && x->outputs.size() > 0) { + bool ok = x->inputs.size() > 0; + if (!ok) { + return false; + } + int idx = find_fc_idx(x); + if (idx == 0) { + return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu"); + } else { + x = next_var_of_part(x, idx); + return var_next_is_fc_act_repeated_n_times( + x, std::max(1, num_fc - i - 1), "relu"); + } + } else { + return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") && + x->inputs.size() > 0 && 
+ var_before_is_fc_act_repeated_n_times(x, i, "relu"); + } + }, + name_scope + "/fc_out_" + std::to_string(i)); + + fc_ops[i] = pattern->NewNode( + [=](Node* x) { + bool basic = x && x->IsOp() && x->Op()->Type() == "fc" && + x->inputs.size() == 3 && x->outputs.size() == 1; + if (!basic) { + return false; + } + auto* fc_out_var = x->outputs[0]; + return fc_out_var && fc_out_var->IsVar() && + fc_out_var->outputs.size() == 1 && + VarLinksToOp(fc_out_var, "relu") && + fc_out_var->outputs[0]->outputs.size() == 1 && + var_next_is_fc_act_repeated_n_times( + fc_out_var->outputs[0]->outputs[0], num_fc - i - 1, + "relu") && + var_before_is_fc_act_repeated_n_times( + fc_out_var->outputs[0]->outputs[0], i + 1, "relu"); + }, + name_scope + "/fc_op_" + std::to_string(i)); + + relu_ops[i] = pattern->NewNode( + [=](Node* x) { + return x && x->IsOp() && x->Op()->Type() == "relu" && + x->inputs.size() == 1 && x->outputs.size() == 1 && + x->inputs[0]->IsVar() && VarLinksFromOp(x->inputs[0], "fc") && + x->outputs[0]->IsVar() && + var_next_is_fc_act_repeated_n_times(x->outputs[0], + num_fc - i - 1, "relu") && + var_before_is_fc_act_repeated_n_times(x->outputs[0], i + 1, + "relu"); + }, + name_scope + "/act_op_" + std::to_string(i)); + + fc_ops[i] + ->LinksFrom({fc_input_var[i], fc_weight_var[i], fc_bias_var[i]}) + .LinksTo({fc_output_var[i]}); + relu_ops[i]->LinksFrom({fc_output_var[i]}); + } + + auto* last_out_var = pattern->NewNode( + [=](Node* x) { + return var_before_is_fc_act_repeated_n_times(x, num_fc, "relu"); + }, + name_scope + "/act_out"); + for (int i = 0; i < num_fc - 1; ++i) { + relu_ops[i]->LinksTo({fc_input_var[i + 1]}); + } + relu_ops[num_fc - 1]->LinksTo({last_out_var}); + return last_out_var; +} + +static int BuildFusion(Graph* graph, const std::string& name_scope, + int num_fc) { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + BuildRepeatedFCReluPattern(pattern, name_scope, num_fc); + + auto retrieve_node = [](const std::string& name, + const 
GraphPatternDetector::subgraph_t& subgraph, + const PDPattern& pat) -> Node* { + PADDLE_ENFORCE(subgraph.count(pat.RetrieveNode(name)), + "pattern has no Node called %s", name.c_str()); + Node* p = subgraph.at(pat.RetrieveNode(name)); + PADDLE_ENFORCE_NOT_NULL(p, "subgraph has no node %s", name.c_str()); + return p; + }; + + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + LOG(INFO) << "handle Repeated FC Act fuse"; + std::vector weights_vars(num_fc); + std::vector bias_vars(num_fc); + std::vector relu_vars(num_fc - 1); + + std::vector weight_names(num_fc); + std::vector bias_names(num_fc); + std::vector relu_names(num_fc - 1); + + auto& fused_pattern = gpd.pattern(); + for (int i = 0; i < num_fc; ++i) { + if (i >= 1) { + relu_vars[i - 1] = + retrieve_node(name_scope + "/fc_in_" + std::to_string(i), subgraph, + fused_pattern); + relu_names[i - 1] = relu_vars[i - 1]->Name(); + } + + weights_vars[i] = + retrieve_node(name_scope + "/fc_weight_" + std::to_string(i), + subgraph, fused_pattern); + weight_names[i] = weights_vars[i]->Name(); + + bias_vars[i] = retrieve_node(name_scope + "/fc_bias_" + std::to_string(i), + subgraph, fused_pattern); + bias_names[i] = bias_vars[i]->Name(); + } + + auto* input_var = + retrieve_node(name_scope + "/fc_in_0", subgraph, fused_pattern); + auto* last_out_var = + retrieve_node(name_scope + "/act_out", subgraph, fused_pattern); + + // Create New OpDesc + OpDesc op_desc; + op_desc.SetType("fusion_repeated_fc_relu"); + op_desc.SetInput("X", {input_var->Name()}); + op_desc.SetInput("W", weight_names); + op_desc.SetInput("Bias", bias_names); + op_desc.SetOutput("ReluOut", relu_names); + op_desc.SetOutput("Out", {last_out_var->Name()}); + auto* op = graph->CreateOpNode(&op_desc); + IR_NODE_LINK_TO(input_var, op); + for (size_t i = 0; i < weights_vars.size(); ++i) { + IR_NODE_LINK_TO(weights_vars[i], op); + IR_NODE_LINK_TO(bias_vars[i], op); + } + for (size_t i = 0; i < 
relu_vars.size(); ++i) { + IR_NODE_LINK_TO(op, relu_vars[i]); + } + IR_NODE_LINK_TO(op, last_out_var); + + std::unordered_set marked_nodes; + for (auto& item : subgraph) { + marked_nodes.insert(item.second); + } + for (size_t i = 0; i < weights_vars.size(); ++i) { + marked_nodes.erase(weights_vars[i]); + marked_nodes.erase(bias_vars[i]); + } + for (size_t i = 0; i < relu_vars.size(); ++i) { + marked_nodes.erase(relu_vars[i]); + } + marked_nodes.erase(input_var); + marked_nodes.erase(last_out_var); + GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + + gpd(graph, handler); + return fusion_count; +} + +std::unique_ptr RepeatedFCReluFusePass::ApplyImpl( + std::unique_ptr graph) const { + FusePassBase::Init(name_scope_, graph.get()); + int fusion_count = 0; + for (int i = MAX_NUM_FC; i > 1; --i) { + fusion_count += + BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(i), i); + } + AddStatis(fusion_count); + + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(repeated_fc_relu_fuse_pass, + paddle::framework::ir::RepeatedFCReluFusePass); diff --git a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..3f3f0846eba1201e57a653f8e515c28d2bcdd5e3 --- /dev/null +++ b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +/** + * Fuse Repeated FC Relu + */ +class RepeatedFCReluFusePass : public FusePassBase { + public: + virtual ~RepeatedFCReluFusePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; + + const std::string name_scope_{"repeated_fc_relu_fuse"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc index fa75e3b4aa7feb7ff856dc26338d089f90efa2e2..63a0c24f2a6b6e1afe3d25210ec6eb3cbaac2f2f 100644 --- a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc @@ -129,7 +129,8 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern, return concat_out_var; } -int BuildFusion(Graph* graph, const std::string& name_scope, int num_inputs) { +static int BuildFusion(Graph* graph, const std::string& name_scope, + int num_inputs) { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); BuildSeqPoolConcatPattern(pattern, name_scope, num_inputs); diff --git a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..78c8cabb10f5b7718375f8052644074869929d04 --- /dev/null +++ b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc @@ -0,0 +1,379 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h" +#include +#include +#include "paddle/fluid/framework/lod_tensor.h" + +namespace paddle { +namespace framework { +namespace ir { + +PDNode* BuildSquaredMatSubPattern(PDPattern* pattern, + const std::string& name_scope) { + auto var_is_op_input = [=](Node* x, const std::string& op_type, + const std::string& arg_name = "") -> bool { + if (!(x && x->IsVar())) { + return false; + } + for (auto* op : x->outputs) { + if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) { + if (arg_name.empty()) { + return true; + } + for (auto& name : op->Op()->Input(arg_name)) { + if (name == x->Name()) { + return true; + } + } + } + } + return false; + }; + + auto var_is_op_only_output = [](Node* x, const std::string& op_type) -> bool { + return x && x->IsVar() && x->inputs.size() == 1 && x->inputs[0] && + x->inputs[0]->IsOp() && x->inputs[0]->Op()->Type() == op_type && + x->inputs[0]->outputs.size() == 1; + }; + + auto next_op = [=](Node* x, const std::string& op_type) -> Node* { + if (!(x && x->IsVar())) { + return nullptr; + } + for (auto* op : x->outputs) { + if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) { + return op; + } + } + return nullptr; + }; + + auto get_op_input_var = [=](Node* x, const std::string& arg_name) -> Node* { + if (!(x && x->IsOp())) { + return nullptr; + } + for (auto* var : x->inputs) { + for (auto name : x->Op()->Input(arg_name)) { + if (var->Name() == name) { + return var; + } + } + } + return nullptr; + }; + + auto is_fusion_input_var = 
[=](Node* x, const std::string& arg_name) { + bool basic = var_is_op_input(x, "matmul", arg_name) && + var_is_op_input(x, "square", "X"); + if (!basic) { + return false; + } + auto* squared_x_op = next_op(x, "square"); + if (!(squared_x_op && squared_x_op->outputs.size() == 1)) { + return false; + } + auto* squared_x = squared_x_op->outputs[0]; + bool next_is_matmul_from_arg = + var_is_op_input(squared_x, "matmul", arg_name) && + squared_x->outputs.size() == 1 && + squared_x->outputs[0]->outputs.size() == 1; + if (!next_is_matmul_from_arg) { + return false; + } + auto* sub_y_in = squared_x->outputs[0]->outputs[0]; + return var_is_op_input(sub_y_in, "elementwise_sub", "Y") && + sub_y_in->outputs[0]->outputs.size() == 1 && + var_is_op_input(sub_y_in->outputs[0]->outputs[0], "elementwise_mul"); + }; + + auto is_fusion_first_mul_out = [=](Node* x) -> bool { + bool input_is_matmul_op = x && x->inputs.size() == 1 && + x->inputs[0]->IsOp() && + x->inputs[0]->Op()->Type() == "matmul"; + if (!input_is_matmul_op) { + return false; + } + auto* mat_x = get_op_input_var(x->inputs[0], "X"); + auto* mat_y = get_op_input_var(x->inputs[0], "Y"); + bool input_mul_is_valid = mat_x && is_fusion_input_var(mat_x, "X") && + mat_y && is_fusion_input_var(mat_y, "Y"); + if (!input_mul_is_valid) { + return false; + } + + bool next_is_square = var_is_op_input(x, "square", "X") && + x->outputs.size() == 1 && + x->outputs[0]->outputs.size() == 1; + if (!next_is_square) { + return false; + } + auto* sub_x_in = x->outputs[0]->outputs[0]; + return var_is_op_input(sub_x_in, "elementwise_sub", "X") && + sub_x_in->outputs[0]->outputs.size() == 1 && + var_is_op_input(sub_x_in->outputs[0]->outputs[0], "elementwise_mul"); + }; + + auto* x = pattern->NewNode( + [=](Node* x) { return is_fusion_input_var(x, "X"); }, name_scope + "/x"); + + auto* y = pattern->NewNode( + [=](Node* x) { return is_fusion_input_var(x, "Y"); }, name_scope + "/y"); + + auto* square_x_op = pattern->NewNode( + [=](Node* x) { + 
return x && x->IsOp() && x->Op()->Type() == "square" && + is_fusion_input_var(x->inputs[0], "X"); + }, + name_scope + "/squared_x_op"); + + auto* square_y_op = pattern->NewNode( + [=](Node* x) { + return x && x->IsOp() && x->Op()->Type() == "square" && + is_fusion_input_var(x->inputs[0], "Y"); + }, + name_scope + "/squared_y_op"); + + auto* squared_x = pattern->NewNode( + [=](Node* x) { + return x && x->inputs.size() == 1 && x->inputs[0]->inputs.size() == 1 && + is_fusion_input_var(x->inputs[0]->inputs[0], "X"); + }, + name_scope + "/squared_x"); + + auto* squared_y = pattern->NewNode( + [=](Node* x) { + return x && x->inputs.size() == 1 && x->inputs[0]->inputs.size() == 1 && + is_fusion_input_var(x->inputs[0]->inputs[0], "Y"); + }, + name_scope + "/squared_y"); + + auto* matmuled_xy = + pattern->NewNode([=](Node* x) { return is_fusion_first_mul_out(x); }, + name_scope + "/matmuled_xy"); + + auto* matmul_xy_op = pattern->NewNode( + [=](Node* x) { + return x && x->IsOp() && x->Op()->Type() == "matmul" && + is_fusion_first_mul_out(x->outputs[0]); + }, + name_scope + "/matmul_xy_op"); + + auto* square_matmuled_xy_op = pattern->NewNode( + [=](Node* x) { + return x && x->IsOp() && x->Op()->Type() == "square" && + is_fusion_first_mul_out(x->inputs[0]); + }, + name_scope + "/square_matmuled_xy_op"); + + auto* squared_xmuly = pattern->NewNode( + [=](Node* x) { + return x && x->IsVar() && x->inputs.size() == 1 && + x->inputs[0]->IsOp() && x->inputs[0]->Op()->Type() == "square" && + is_fusion_first_mul_out(x->inputs[0]->inputs[0]); + }, + name_scope + "/squared_xmuly"); + + auto is_fusion_mat_squared_x_y_op_out = [=](Node* x) -> bool { + bool basic = x && x->IsVar() && x->inputs.size() == 1 && + x->inputs[0]->IsOp() && x->inputs[0]->Op()->Type() == "matmul"; + if (!basic) { + return false; + } + auto* sqx = get_op_input_var(x->inputs[0], "X"); + auto* sqy = get_op_input_var(x->inputs[0], "Y"); + + return var_is_op_only_output(sqx, "square") && + var_is_op_only_output(sqy, 
"square") && sqx->inputs[0] && + sqx->inputs[0]->inputs.size() == 1 && + is_fusion_input_var(sqx->inputs[0]->inputs[0], "X") && + sqy->inputs[0] && sqy->inputs[0]->inputs.size() == 1 && + is_fusion_input_var(sqy->inputs[0]->inputs[0], "Y"); + }; + + auto* matmul_squared_x_y_op = pattern->NewNode( + [=](Node* x) { + return x && x->IsOp() && x->Op()->Type() == "matmul" && + is_fusion_mat_squared_x_y_op_out(x->outputs[0]); + }, + name_scope + "/matmul_squared_x_y_op"); + + auto* mat_squared_x_y_op_out = pattern->NewNode( + [=](Node* x) { return is_fusion_mat_squared_x_y_op_out(x); }, + name_scope + "/mat_squared_x_y_op_out"); + + auto is_fusion_sub_op = [=](Node* x) -> bool { + bool is_sub_op = x && x->IsOp() && x->Op()->Type() == "elementwise_sub"; + if (!is_sub_op) { + return false; + } + auto* matmul_sqx_sqy_var = get_op_input_var(x, "Y"); + return is_fusion_mat_squared_x_y_op_out(matmul_sqx_sqy_var); + }; + + auto* sub_op = pattern->NewNode([=](Node* x) { return is_fusion_sub_op(x); }, + name_scope + "/sub_op"); + + auto* sub_op_out = pattern->NewNode( + [=](Node* x) { + return x && x->IsVar() && x->inputs.size() == 1 && + is_fusion_sub_op(x->inputs[0]); + }, + name_scope + "/sub_op_out"); + + auto is_fusion_element_op = [=](Node* x) -> bool { + bool is_elemul_op = x && x->IsOp() && x->Op()->Type() == "elementwise_mul"; + if (!is_elemul_op) { + return false; + } + for (auto* in : x->inputs) { + if (in && in->inputs[0] && is_fusion_sub_op(in->inputs[0])) { + return true; + } + } + return false; + }; + + auto* elementmul_op = + pattern->NewNode([=](Node* x) { return is_fusion_element_op(x); }, + name_scope + "/elementmul_op"); + + auto* constant_op = pattern->NewNode( + [=](Node* x) { + return x && x->IsOp() && x->Op()->Type() == "fill_constant" && + x->outputs.size() == 1 && + is_fusion_element_op(x->outputs[0]->outputs[0]); + }, + name_scope + "/fill_constant_op"); + + auto* constant_op_out = pattern->NewNode( + [=](Node* x) { + return x && x->IsVar() && 
var_is_op_input(x, "elementwise_mul") && + x->inputs[0] && x->inputs[0]->IsOp() && + x->inputs[0]->Op()->Type() == "fill_constant" && x->outputs[0] && + is_fusion_element_op(x->outputs[0]); + }, + name_scope + "/constant_op_out"); + + auto* last_out_var = pattern->NewNode( + [=](Node* x) { + return var_is_op_only_output(x, "elementwise_mul") && + is_fusion_element_op(x->inputs[0]); + }, + name_scope + "/out"); + + square_x_op->LinksFrom({x}).LinksTo({squared_x}); + square_y_op->LinksFrom({y}).LinksTo({squared_y}); + matmul_xy_op->LinksFrom({x, y}).LinksTo({matmuled_xy}); + matmul_squared_x_y_op->LinksFrom({squared_x, squared_y}) + .LinksTo({mat_squared_x_y_op_out}); + square_matmuled_xy_op->LinksFrom({matmuled_xy}).LinksTo({squared_xmuly}); + sub_op->LinksFrom({squared_xmuly, mat_squared_x_y_op_out}) + .LinksTo({sub_op_out}); + constant_op->LinksFrom({}).LinksTo({constant_op_out}); + elementmul_op->LinksFrom({constant_op_out, sub_op_out}) + .LinksTo({last_out_var}); + + return last_out_var; +} + +static int BuildFusion(Graph* graph, const std::string& name_scope) { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + + BuildSquaredMatSubPattern(pattern, name_scope); + + auto retrieve_node = [](const std::string& name, + const GraphPatternDetector::subgraph_t& subgraph, + const PDPattern& pat) -> Node* { + PADDLE_ENFORCE(subgraph.count(pat.RetrieveNode(name)), + "pattern has no Node called %s", name.c_str()); + Node* p = subgraph.at(pat.RetrieveNode(name)); + PADDLE_ENFORCE_NOT_NULL(p, "subgraph has no node %s", name.c_str()); + return p; + }; + + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + LOG(INFO) << "handle sqaure mat sub fuse"; + auto& fused_pattern = gpd.pattern(); + + auto* matx = retrieve_node(name_scope + "/x", subgraph, fused_pattern); + auto* maty = retrieve_node(name_scope + "/y", subgraph, fused_pattern); + auto* squaredx = + retrieve_node(name_scope + "/squared_x", 
subgraph, fused_pattern); + auto* squaredy = + retrieve_node(name_scope + "/squared_y", subgraph, fused_pattern); + auto* squaredxy = + retrieve_node(name_scope + "/squared_xmuly", subgraph, fused_pattern); + auto* last_out_var = + retrieve_node(name_scope + "/out", subgraph, fused_pattern); + auto* fill_constant_op = retrieve_node(name_scope + "/fill_constant_op", + subgraph, fused_pattern); + + // Create New OpDesc + OpDesc op_desc; + op_desc.SetType("fusion_squared_mat_sub"); + op_desc.SetInput("X", {matx->Name()}); + op_desc.SetInput("Y", {maty->Name()}); + op_desc.SetOutput("SquaredX", {squaredx->Name()}); + op_desc.SetOutput("SquaredY", {squaredy->Name()}); + op_desc.SetOutput("SquaredXY", {squaredxy->Name()}); + op_desc.SetOutput("Out", {last_out_var->Name()}); + op_desc.SetAttr("scalar", fill_constant_op->Op()->GetAttr("value")); + + auto* op = graph->CreateOpNode(&op_desc); + IR_NODE_LINK_TO(matx, op); + IR_NODE_LINK_TO(maty, op); + IR_NODE_LINK_TO(op, squaredx); + IR_NODE_LINK_TO(op, squaredy); + IR_NODE_LINK_TO(op, squaredxy); + IR_NODE_LINK_TO(op, last_out_var); + + std::unordered_set marked_nodes; + for (auto& item : subgraph) { + marked_nodes.insert(item.second); + } + + marked_nodes.erase(matx); + marked_nodes.erase(maty); + marked_nodes.erase(squaredx); + marked_nodes.erase(squaredy); + marked_nodes.erase(squaredxy); + marked_nodes.erase(last_out_var); + GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + + gpd(graph, handler); + return fusion_count; +} + +std::unique_ptr SquaredMatSubFusePass::ApplyImpl( + std::unique_ptr graph) const { + FusePassBase::Init(name_scope_, graph.get()); + int fusion_count = BuildFusion(graph.get(), name_scope_); + AddStatis(fusion_count); + + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(squared_mat_sub_fuse_pass, + paddle::framework::ir::SquaredMatSubFusePass); diff --git a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h 
b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..fb49adc3768ec99cab4321c6b90c93dfed6d32f2 --- /dev/null +++ b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +/** + * Fuse ( (A.^2 * B.^2) - (A * B).^2 ) .* scalar + */ +class SquaredMatSubFusePass : public FusePassBase { + public: + virtual ~SquaredMatSubFusePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; + + const std::string name_scope_{"squared_mat_sub_fuse"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h index de9650735adfe158e72213d4f6d5d3569aa90d55..efe1ba106a2fcbac66a773e56b98d1a6452f4013 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.h +++ b/paddle/fluid/inference/api/paddle_pass_builder.h @@ -98,6 +98,8 @@ class CpuPassStrategy : public PassStrategy { "mul_gru_fuse_pass", // "seq_concat_fc_fuse_pass", // "fc_fuse_pass", // + 
"repeated_fc_relu_fuse_pass", // + "squared_mat_sub_fuse_pass", // "conv_bn_fuse_pass", // "conv_eltwiseadd_bn_fuse_pass", // "is_test_pass", // diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index 6854282a164773ad32a105c254b12a3bb4731e7f..0f670658892b9926dcc534038925c46047a113fd 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -37,15 +37,21 @@ function(inference_analysis_api_test_with_refer_result target install_dir filena --refer_result=${install_dir}/result.txt) endfunction() -# RNN1 if(NOT APPLE AND WITH_MKLML) + # RNN1 set(RNN1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn1") download_model_and_data(${RNN1_INSTALL_DIR} "rnn1%2Fmodel.tar.gz" "rnn1%2Fdata.txt.tar.gz") inference_analysis_api_test(test_analyzer_rnn1 ${RNN1_INSTALL_DIR} analyzer_rnn1_tester.cc SERIAL) + + # seq_pool1 + set(SEQ_POOL1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/seq_pool") + download_model_and_data(${SEQ_POOL1_INSTALL_DIR} "seq_pool1_model_.tar.gz" "seq_pool1_data.txt.tar.gz") + inference_analysis_api_test(test_analyzer_seq_pool1 ${SEQ_POOL1_INSTALL_DIR} analyzer_seq_pool1_tester.cc SERIAL) else() # TODO: fix this test on MACOS and OPENBLAS, the reason is that # fusion_seqexpand_concat_fc_op is not supported on MACOS and OPENBLAS message(WARNING "These tests has been disabled in OSX or WITH_MKL=OFF before being fixed: \n test_analyzer_rnn1") + message(WARNING "These tests has been disabled in OSX or WITH_MKL=OFF before being fixed: \n test_analyzer_seq_pool1") endif() # RNN2 @@ -90,11 +96,6 @@ set(SEQ_CONV1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/seq_conv1") download_model_and_data(${SEQ_CONV1_INSTALL_DIR} "seq_conv1_model.tar.gz" "seq_conv1_data.txt.tar.gz") inference_analysis_api_test(test_analyzer_seq_conv1 ${SEQ_CONV1_INSTALL_DIR} analyzer_seq_conv1_tester.cc) -# seq_pool1 -set(SEQ_POOL1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/seq_pool") 
-download_model_and_data(${SEQ_POOL1_INSTALL_DIR} "seq_pool1_model_.tar.gz" "seq_pool1_data.txt.tar.gz") -inference_analysis_api_test(test_analyzer_seq_pool1 ${SEQ_POOL1_INSTALL_DIR} analyzer_seq_pool1_tester.cc) - # ocr set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr") if (NOT EXISTS ${OCR_INSTALL_DIR}) diff --git a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc index c137090879e67d5314b94709586c5292dc208745..8be2a6d79b2ede2c149aa523e38c3960ab30acb1 100644 --- a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc @@ -21,6 +21,12 @@ namespace paddle { namespace inference { namespace analysis { +// diff: similarity_norm.tmp_0, for speed: fc_4.tmp_1 +static const char out_var_name[] = "reduce_sum_0.tmp_0"; + +// for diff: 154, for speed 111 +constexpr int num_slots = 154; + struct OneSlotInBatch { std::string name; std::vector> data; @@ -41,7 +47,6 @@ struct DataRecord { void Load(const std::string &path) { std::ifstream file(path); - constexpr int num_slots = 154; std::string line; int num_lines = 0; while (std::getline(file, line)) { @@ -187,11 +192,15 @@ void analysis_fuse_statis(bool use_zerocopy) { auto predictor = CreatePaddlePredictor(cfg); auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops); ASSERT_TRUE(fuse_statis.count("fc_fuse")); - ASSERT_EQ(fuse_statis.at("fc_fuse"), 10); ASSERT_TRUE(fuse_statis.count("seqpool_concat_fuse")); + ASSERT_TRUE(fuse_statis.count("squared_mat_sub_fuse")); + ASSERT_TRUE(fuse_statis.count("repeated_fc_relu_fuse")); + ASSERT_EQ(fuse_statis.at("fc_fuse"), 10); EXPECT_EQ(fuse_statis.at("seqpool_concat_fuse"), 2); + EXPECT_EQ(fuse_statis.at("squared_mat_sub_fuse"), 2); + EXPECT_EQ(fuse_statis.at("repeated_fc_relu_fuse"), 2); LOG(INFO) << "num_ops: " << num_ops; - EXPECT_EQ(num_ops, 195); + EXPECT_EQ(num_ops, 171); } // Check the fuse status @@ -214,9 +223,6 @@ 
void PrepareZeroCopyInputs( } } -// diff: similarity_norm.tmp_0, // speed: fc_4.tmp_1 -static const char out_var_name[] = "reduce_sum_0.tmp_0"; - // return the output values std::vector zerocopy_profile(int repeat_times) { AnalysisConfig config; @@ -322,7 +328,9 @@ TEST(Analyzer_seq_pool1, zerocopy_compare_native) { native_outputs.front().data.length()); auto *native_data = static_cast(native_outputs.front().data.data()); for (size_t i = 0; i < zerocopy_output.size(); ++i) { - EXPECT_NEAR(zerocopy_output[i], native_data[i], 1e-3); + EXPECT_LT( + std::fabs((zerocopy_output[i] - native_data[i]) / zerocopy_output[i]), + 1e-3); } } diff --git a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a35ee8a09ed5ddcc4ac465d200b84358fa65b2f3 --- /dev/null +++ b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc @@ -0,0 +1,149 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#include "paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h" +#include +#include +#include "paddle/fluid/operators/jit/kernels.h" + +namespace paddle { +namespace operators { + +void FusionRepeatedFCReluOp::InferShape( + framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of FusionRepeatedFCReluOp should not be null."); + auto sz = ctx->Inputs("W").size(); + PADDLE_ENFORCE_GT( + sz, 1UL, "Inputs(W) of FusionRepeatedFCReluOp should larger than 1."); + PADDLE_ENFORCE_EQ(ctx->Inputs("Bias").size(), sz, + "Size of inputs(Bias) of FusionRepeatedFCReluOp should be " + "equal to inputs size."); + PADDLE_ENFORCE_EQ(ctx->Outputs("ReluOut").size(), sz - 1, + "Size of output(ReluOut) of FusionRepeatedFCReluOp should " + "be equal to inputs size -1."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of FusionRepeatedFCReluOp should not be null."); + + auto i_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ(i_dims.size(), 2UL, "Input shape size should be 2"); + + auto w_dims = ctx->GetInputsDim("W"); + auto b_dims = ctx->GetInputsDim("Bias"); + PADDLE_ENFORCE_EQ(w_dims.size(), b_dims.size(), + "Shape size of weight and bias should be equal"); + PADDLE_ENFORCE_EQ(w_dims.size(), sz, + "Shape size of weight and bias should be equal"); + PADDLE_ENFORCE_EQ(i_dims[1], w_dims[0][0], + "inpute width should be equal with weight height"); + + for (size_t i = 1; i < sz; ++i) { + PADDLE_ENFORCE_EQ(w_dims[i].size(), 2UL, + "Every weight shape size should be 2."); + PADDLE_ENFORCE_EQ(framework::product(b_dims[i]), w_dims[i][1], + "The length of Bias must be equal with w_dims[1]."); + } + ctx->SetOutputDim("Out", {i_dims[0], w_dims[sz - 1][1]}); + ctx->ShareLoD("X", /*->*/ "Out"); +} + +framework::OpKernelType FusionRepeatedFCReluOp::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")), + ctx.GetPlace()); +} + +void 
FusionRepeatedFCReluOpMaker::Make() { + AddInput("X", "(LoDTensor) Input tensors of this operator."); + AddInput("W", "(Tensor) The weight tensors of this operator.").AsDuplicable(); + AddInput("Bias", "(Tensor) The bias tensors of this operator.") + .AsDuplicable(); + AddOutput("ReluOut", "(Tensor) The output tensor of each relu operator.") + .AsDuplicable() + .AsIntermediate(); + AddOutput("Out", "(LoDTensor) Output tensor of this operator."); + AddComment(R"DOC( + Fusion Repeated FC with Relu Operator. +)DOC"); +} + +template +static void fc_relu(const T* x, const T* w, const T* b, T* y, int m, int n, + int k) { + auto matmul = + jit::Get, platform::CPUPlace>(k); + auto addbias_relu = + jit::Get, platform::CPUPlace>(n); + matmul(x, w, y, m, n, k); + T* dst = y; + for (int i = 0; i < m; ++i) { + addbias_relu(b, dst, dst, n); + dst += n; + } +} + +template +class FusionRepeatedFCReluKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto in = ctx.Input("X"); + auto weights = ctx.MultiInput("W"); + auto biases = ctx.MultiInput("Bias"); + auto relus = ctx.MultiOutput("ReluOut"); + auto* out = ctx.Output("Out"); + auto place = ctx.GetPlace(); + int weight_sz = static_cast(weights.size()); + + auto i_dims = in->dims(); + auto w_dims = weights[0]->dims(); + int m = i_dims[0]; + int n = w_dims[1]; + int k = w_dims[0]; + relus[0]->Resize({m, n}); + fc_relu(in->data(), weights[0]->data(), biases[0]->data(), + relus[0]->mutable_data(place), m, n, k); + + for (int i = 1; i < weight_sz - 1; ++i) { + auto i_dims = relus[i - 1]->dims(); + auto w_dims = weights[i]->dims(); + int m = i_dims[0]; + int n = w_dims[1]; + int k = w_dims[0]; + relus[i]->Resize({m, n}); + fc_relu(relus[i - 1]->data(), weights[i]->data(), + biases[i]->data(), relus[i]->mutable_data(place), m, n, k); + } + + auto i_dims_last = relus[weight_sz - 2]->dims(); + auto w_dims_last = weights[weight_sz - 1]->dims(); + m = i_dims_last[0]; + n 
= w_dims_last[1]; + k = w_dims_last[0]; + fc_relu(relus[weight_sz - 2]->data(), weights[weight_sz - 1]->data(), + biases[weight_sz - 1]->data(), out->mutable_data(place), m, n, + k); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fusion_repeated_fc_relu, ops::FusionRepeatedFCReluOp, + ops::FusionRepeatedFCReluOpMaker, + paddle::framework::DefaultGradOpDescMaker); + +REGISTER_OP_CPU_KERNEL(fusion_repeated_fc_relu, + ops::FusionRepeatedFCReluKernel, + ops::FusionRepeatedFCReluKernel); diff --git a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h new file mode 100644 index 0000000000000000000000000000000000000000..cdcaf8b4833464100ed579a5962c60013edecdb0 --- /dev/null +++ b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +class FusionRepeatedFCReluOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; +}; + +class FusionRepeatedFCReluOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..00dafdead53bbd4614c70875441c565724fca46d --- /dev/null +++ b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc @@ -0,0 +1,137 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#include "paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h" +#include +#include +#include "paddle/fluid/operators/jit/kernels.h" + +namespace paddle { +namespace operators { + +void FusionSquaredMatSubOp::InferShape( + framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of FusionSquaredMatSubOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), + "Input(Y) of FusionSquaredMatSubOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("SquaredX"), + "Output(SquaredX) of FusionSquaredMatSubOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("SquaredY"), + "Output(SquaredY) of FusionSquaredMatSubOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("SquaredXY"), + "Output(SquaredXY) of FusionSquaredMatSubOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of FusionSquaredMatSubOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(), + "Input tensors dims size should be equal."); + PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input tensors should be a Matrix."); + PADDLE_ENFORCE_EQ(x_dims[1], y_dims[0], "Inputs Matrix should be multiply."); + + ctx->SetOutputDim("SquaredX", x_dims); + ctx->SetOutputDim("SquaredY", y_dims); + ctx->SetOutputDim("SquaredXY", {x_dims[0], y_dims[1]}); + ctx->SetOutputDim("Out", {x_dims[0], y_dims[1]}); +} + +framework::OpKernelType FusionSquaredMatSubOp::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")), + ctx.GetPlace()); +} + +void FusionSquaredMatSubOpMaker::Make() { + AddInput("X", "(Tensor) Input Mat A of this operator."); + AddInput("Y", "(Tensor) Input Mat B of this operator."); + AddOutput("SquaredX", "(Tensor) Squared X.").AsIntermediate(); + AddOutput("SquaredY", "(Tensor) Squared Y.").AsIntermediate(); + AddOutput("SquaredXY", 
"(Tensor) Squared X*Y.").AsIntermediate(); + AddOutput("Out", "(Tensor) Output tensor of concat operator."); + AddAttr("scalar", "The scalar on output matrix.").SetDefault(1.f); + AddComment(R"DOC( + Fusion Squared Matrix and substrct operator. + + ( (X * Y).^2 - (X.^2 * Y.^2) ) .* scalar +)DOC"); +} + +template +class FusionSquaredMatSubKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto x = ctx.Input("X"); + auto y = ctx.Input("Y"); + auto* squared_x = ctx.Output("SquaredX"); + auto* squared_y = ctx.Output("SquaredY"); + auto* squared_xy = ctx.Output("SquaredXY"); + auto* out = ctx.Output("Out"); + auto place = ctx.GetPlace(); + T scalar = static_cast(ctx.Attr("scalar")); + + auto x_dims = x->dims(); + auto y_dims = y->dims(); + int m = x_dims[0]; + int k = x_dims[1]; + int n = y_dims[1]; + int o_numel = m * n; + + auto vsquare_x = + jit::Get, platform::CPUPlace>(m * k); + auto vsquare_y = + jit::Get, platform::CPUPlace>(k * n); + auto vsquare_xy = + jit::Get, platform::CPUPlace>(o_numel); + auto vsub = + jit::Get, platform::CPUPlace>(o_numel); + auto vscal = + jit::Get, platform::CPUPlace>(o_numel); + auto matmul = + jit::Get, platform::CPUPlace>(k); + + const T* x_data = x->data(); + const T* y_data = y->data(); + T* squared_x_data = squared_x->mutable_data(place); + T* squared_y_data = squared_y->mutable_data(place); + T* squared_xy_data = squared_xy->mutable_data(place); + T* o_data = out->mutable_data(place); + + matmul(x_data, y_data, squared_xy_data, m, n, k); + vsquare_xy(squared_xy_data, squared_xy_data, o_numel); + + vsquare_x(x_data, squared_x_data, m * k); + vsquare_y(y_data, squared_y_data, k * n); + matmul(squared_x_data, squared_y_data, o_data, m, n, k); + + vsub(squared_xy_data, o_data, o_data, o_numel); + vscal(&scalar, o_data, o_data, o_numel); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; 
+REGISTER_OPERATOR(fusion_squared_mat_sub, ops::FusionSquaredMatSubOp, + ops::FusionSquaredMatSubOpMaker, + paddle::framework::DefaultGradOpDescMaker); + +REGISTER_OP_CPU_KERNEL(fusion_squared_mat_sub, + ops::FusionSquaredMatSubKernel, + ops::FusionSquaredMatSubKernel); diff --git a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h new file mode 100644 index 0000000000000000000000000000000000000000..0ab2c2bb10a15cc6d9a472142416bd363e65944f --- /dev/null +++ b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +// ( (A.^2 * B.^2) - (A * B).^2 ) .* scalar +class FusionSquaredMatSubOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; +}; + +class FusionSquaredMatSubOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index 4b4ce07fa78b97e636173566fa104cb8a18c914e..b39ce280939515ec8f4fa3b443ff4332074825fd 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -210,6 +210,24 @@ void BenchSeqPoolKernel() { } } +template +void BenchMatMulKernel() { + for (int m : {1, 2, 3, 4}) { + for (int n : TestSizes()) { + for (int k : TestSizes()) { + std::vector a(m * k), b(k * n), c(m * n); + RandomVec(m * k, a.data(), -2.f, 2.f); + RandomVec(k * n, b.data(), -2.f, 2.f); + const T* a_data = a.data(); + const T* b_data = b.data(); + T* c_data = c.data(); + BenchAllImpls, PlaceType>(k, a_data, b_data, + c_data, m, n, k); + } + } + } +} + // Benchmark all jit kernels including jitcode, mkl and refer. // To use this tool, run command: ./benchmark [options...] 
// Options: @@ -236,6 +254,7 @@ int main(int argc, char* argv[]) { // xyn BenchXYNKernel(); BenchXYNKernel(); + BenchXYNKernel(); BenchXYNKernel(); BenchXYNKernel(); BenchXYNKernel(); @@ -251,4 +270,7 @@ int main(int argc, char* argv[]) { // seq pool function BenchSeqPoolKernel(); + + // matmul + BenchMatMulKernel(); } diff --git a/paddle/fluid/operators/jit/gen/CMakeLists.txt b/paddle/fluid/operators/jit/gen/CMakeLists.txt index 2b8c758a032fd7edff0d4b7e23bd8e685eb3ab15..40310c2d2b372a414054f75348e8e1b4471bf3d2 100644 --- a/paddle/fluid/operators/jit/gen/CMakeLists.txt +++ b/paddle/fluid/operators/jit/gen/CMakeLists.txt @@ -11,11 +11,12 @@ endfunction() # use gen jitcode kernel by name USE_JITKERNEL_GEN(kVMul) USE_JITKERNEL_GEN(kVAdd) -#USE_JITKERNEL_GEN(kVSub) # TODO(TJ): enable me +USE_JITKERNEL_GEN(kVSub) USE_JITKERNEL_GEN(kVAddRelu) USE_JITKERNEL_GEN(kVScal) USE_JITKERNEL_GEN(kVAddBias) USE_JITKERNEL_GEN(kVRelu) +USE_JITKERNEL_GEN(kVSquare) USE_JITKERNEL_GEN(kVIdentity) USE_JITKERNEL_GEN(kVExp) USE_JITKERNEL_GEN(kVSigmoid) diff --git a/paddle/fluid/operators/jit/gen/act.cc b/paddle/fluid/operators/jit/gen/act.cc index 3ea076f217dc7c8a755055d3f48c22b7a3627012..a2a5661b93ad3d885983c502566860aa313d110f 100644 --- a/paddle/fluid/operators/jit/gen/act.cc +++ b/paddle/fluid/operators/jit/gen/act.cc @@ -91,6 +91,7 @@ void VActJitCode::genCode() { } DECLARE_ACT_CREATOR(VRelu); +DECLARE_ACT_CREATOR(VSquare); DECLARE_ACT_CREATOR(VIdentity); DECLARE_ACT_CREATOR(VExp); DECLARE_ACT_CREATOR(VSigmoid); @@ -103,6 +104,10 @@ size_t VReluCreator::CodeSize(const int& d) const { 8 /* average bytes for each instruction */; } +size_t VSquareCreator::CodeSize(const int& d) const { + return 96 + (d / YMM_FLOAT_BLOCK + 3) * 4 * 8; +} + size_t VIdentityCreator::CodeSize(const int& d) const { return 96 + (d / YMM_FLOAT_BLOCK + 3) * 4 * 8; } @@ -129,6 +134,7 @@ size_t VTanhCreator::CodeSize(const int& d) const { namespace gen = paddle::operators::jit::gen; REGISTER_JITKERNEL_GEN(kVRelu, 
gen::VReluCreator); +REGISTER_JITKERNEL_GEN(kVSquare, gen::VSquareCreator); REGISTER_JITKERNEL_GEN(kVIdentity, gen::VIdentityCreator); REGISTER_JITKERNEL_GEN(kVExp, gen::VExpCreator); REGISTER_JITKERNEL_GEN(kVSigmoid, gen::VSigmoidCreator); diff --git a/paddle/fluid/operators/jit/gen/act.h b/paddle/fluid/operators/jit/gen/act.h index 81503c42ab5cd46961378847584f68f2cbed0ed5..68e66f9298c4eafabb55c20195d46fed800f4ec4 100644 --- a/paddle/fluid/operators/jit/gen/act.h +++ b/paddle/fluid/operators/jit/gen/act.h @@ -75,6 +75,12 @@ class VActFunc : public JitCode { vmaxps(dst, src, zero); } + // compute SQUARE with ymm, xmm + template + void square_jmm(JMM& dst, JMM& src) { // NOLINT + vmulps(dst, src, src); + } + // compute EXP with ymm, xmm template void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT @@ -228,6 +234,9 @@ class VActFunc : public JitCode { case operand_type::RELU: relu_jmm(dst, src, 15); break; + case operand_type::SQUARE: + square_jmm(dst, src); + break; case operand_type::EXP: exp_jmm(dst, src, 11, 12, 13, 14, 15); break; @@ -254,7 +263,7 @@ class VActJitCode : public VActFunc { : VActFunc(code_size, code_ptr), num_(d), type_(type) { if (!(type_ == operand_type::RELU || type_ == operand_type::EXP || type_ == operand_type::SIGMOID || type_ == operand_type::TANH || - type_ == operand_type::IDENTITY)) { + type_ == operand_type::IDENTITY || type_ == operand_type::SQUARE)) { LOG(FATAL) << "Do not support this operand type: " << type_; } this->genCode(); @@ -266,6 +275,9 @@ class VActJitCode : public VActFunc { case operand_type::RELU: base += "_Relu"; break; + case operand_type::SQUARE: + base += "_Square"; + break; case operand_type::EXP: base += "_Exp"; break; @@ -306,6 +318,7 @@ class VActJitCode : public VActFunc { }; DECLARE_ACT_JITCODE(VRelu, operand_type::RELU); +DECLARE_ACT_JITCODE(VSquare, operand_type::SQUARE); DECLARE_ACT_JITCODE(VIdentity, operand_type::IDENTITY); DECLARE_ACT_JITCODE(VExp, operand_type::EXP); 
DECLARE_ACT_JITCODE(VSigmoid, operand_type::SIGMOID); diff --git a/paddle/fluid/operators/jit/gen/blas.cc b/paddle/fluid/operators/jit/gen/blas.cc index c1198773088faa594bac0714dd8449b240b3ce4d..dee6c7b9d3ee9756c1b11d10d55fdca341cbee85 100644 --- a/paddle/fluid/operators/jit/gen/blas.cc +++ b/paddle/fluid/operators/jit/gen/blas.cc @@ -43,6 +43,8 @@ void VXXJitCode::genCode() { vmulps(ymm_dst, ymm_src1, ymm_src2); } else if (type_ == operand_type::ADD) { vaddps(ymm_dst, ymm_src1, ymm_src2); + } else if (type_ == operand_type::SUB) { + vsubps(ymm_dst, ymm_src1, ymm_src2); } if (with_relu_) { vmaxps(ymm_dst, ymm_zero, ymm_dst); @@ -85,6 +87,9 @@ void VXXJitCode::genCode() { case operand_type::ADD: vaddps(xmm_dst, xmm_src1, xmm_src2); break; + case operand_type::SUB: + vsubps(xmm_dst, xmm_src1, xmm_src2); + break; default: break; } @@ -178,8 +183,7 @@ namespace gen = paddle::operators::jit::gen; REGISTER_JITKERNEL_GEN(kVMul, gen::VMulCreator); REGISTER_JITKERNEL_GEN(kVAdd, gen::VAddCreator); -// TODO(TJ): enable sub -// REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator); +REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator); REGISTER_JITKERNEL_GEN(kVAddRelu, gen::VAddReluCreator); REGISTER_JITKERNEL_GEN(kVScal, gen::VScalCreator); REGISTER_JITKERNEL_GEN(kVAddBias, gen::VAddBiasCreator); diff --git a/paddle/fluid/operators/jit/gen/blas.h b/paddle/fluid/operators/jit/gen/blas.h index c46ec15fb788c0c7a90cfc8732aad375a9e226a1..de6b33f467279124d7acd97709516c31706ec4f9 100644 --- a/paddle/fluid/operators/jit/gen/blas.h +++ b/paddle/fluid/operators/jit/gen/blas.h @@ -34,7 +34,8 @@ class VXXJitCode : public JitCode { type_(type), scalar_index_(scalar_index), with_relu_(with_relu) { - if (!(type_ == operand_type::MUL || type_ == operand_type::ADD)) { + if (!(type_ == operand_type::MUL || type_ == operand_type::ADD || + type_ == operand_type::SUB)) { LOG(FATAL) << "Do not support this operand type: " << type_; } this->genCode(); @@ -51,6 +52,8 @@ class VXXJitCode : public JitCode { base 
+= "_Mul"; } else if (type_ == operand_type::ADD) { base += "_Add"; + } else if (type_ == operand_type::SUB) { + base += "_SUB"; } if (scalar_index_ == 2) { base += "_Scalar"; diff --git a/paddle/fluid/operators/jit/gen/jitcode.h b/paddle/fluid/operators/jit/gen/jitcode.h index 5b7234c1cb5d15d290685a3dceb3b757be1ef0c6..f63d40ad5a559ab87a9b3735406671cfd936d9e4 100644 --- a/paddle/fluid/operators/jit/gen/jitcode.h +++ b/paddle/fluid/operators/jit/gen/jitcode.h @@ -51,6 +51,7 @@ typedef enum { SUB, RELU, EXP, + SQUARE, SIGMOID, TANH, IDENTITY diff --git a/paddle/fluid/operators/jit/helper.cc b/paddle/fluid/operators/jit/helper.cc index 7d02590f2e5d82b5105132d7af716f14c661d067..5dbe22a81b4866bdf60a03710d8ffd0b7bcb597b 100644 --- a/paddle/fluid/operators/jit/helper.cc +++ b/paddle/fluid/operators/jit/helper.cc @@ -36,6 +36,7 @@ const char* to_string(KernelType kt) { ONE_CASE(kVRelu); ONE_CASE(kVIdentity); ONE_CASE(kVExp); + ONE_CASE(kVSquare); ONE_CASE(kVSigmoid); ONE_CASE(kVTanh); ONE_CASE(kLSTMCtHt); @@ -47,6 +48,7 @@ const char* to_string(KernelType kt) { ONE_CASE(kLayerNorm); ONE_CASE(kNCHW16CMulNC); ONE_CASE(kSeqPool); + ONE_CASE(kMatMul); default: PADDLE_THROW("Not support type: %d, or forget to add it.", kt); return "NOT JITKernel"; diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index 2a7697a6f253dcc2b8143d9f14a80a1cfd45996d..adb101bd5cdf231ac330dbf44beb4c24c1fcf29e 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -30,6 +30,7 @@ typedef enum { kVAddBias, kVRelu, kVIdentity, + kVSquare, kVExp, kVSigmoid, kVTanh, @@ -42,6 +43,7 @@ typedef enum { kLayerNorm, kNCHW16CMulNC, kSeqPool, + kMatMul, } KernelType; typedef enum { @@ -135,6 +137,13 @@ struct SeqPoolTuples { typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*); }; +template +struct MatMulTuples { + typedef T data_type; + typedef int attr_type; + typedef void (*func_type)(const T*, const T*, T*, 
int, int, int); +}; + template struct CRFDecodingTuples { typedef T data_type; diff --git a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt index f5ed2f0572176e42b774259c2b8fe9713d989417..667c6dfad6676d00ab994564bff57c90caa0cb41 100644 --- a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt +++ b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt @@ -3,10 +3,12 @@ cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml) set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE) # use mkl kernels by name and type +USE_JITKERNEL_MORE(kMatMul, mkl) USE_JITKERNEL_MORE(kVMul, mkl) USE_JITKERNEL_MORE(kVAdd, mkl) USE_JITKERNEL_MORE(kVScal, mkl) USE_JITKERNEL_MORE(kVExp, mkl) +USE_JITKERNEL_MORE(kVSquare, mkl) USE_JITKERNEL_MORE(kVSigmoid, mkl) USE_JITKERNEL_MORE(kVTanh, mkl) USE_JITKERNEL_MORE(kSeqPool, mkl) diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc index 5a499ac2c02aa70d2824f0d3be618e083ba10334..fccdc68f5efa34bac6f5a34a41569d2f77416284 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.cc +++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc @@ -24,6 +24,20 @@ namespace jit { namespace more { namespace mkl { +template <> +void MatMul(const float* a, const float* b, float* c, int m, int n, + int k) { + platform::dynload::cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, + n, k, 1.f, a, k, b, n, 0.f, c, n); +} + +template <> +void MatMul(const double* a, const double* b, double* c, int m, int n, + int k) { + platform::dynload::cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, + n, k, 1.0, a, k, b, n, 0.0, c, n); +} + template <> void VMul(const float* x, const float* y, float* z, int n) { platform::dynload::vsMul(n, x, y, z); @@ -72,6 +86,16 @@ void VExp(const double* x, double* y, int n) { platform::dynload::vdExp(n, x, y); } +template <> +void VSquare(const float* x, float* y, int n) { + 
platform::dynload::vsSqr(n, x, y); +} + +template <> +void VSquare(const double* x, double* y, int n) { + platform::dynload::vdSqr(n, x, y); +} + template <> void VCopy(const float* x, float* y, int n) { platform::dynload::cblas_scopy(n, x, 1, y, 1); @@ -93,6 +117,11 @@ void VAXPY(double a, const double* x, double* y, int n) { } // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512 +template <> +bool MatMulKernel::UseMe(const int& d) const { + return platform::MayIUse(platform::avx); +} + template <> bool VMulKernel::UseMe(const int& d) const { return platform::MayIUse(platform::avx512f) && d > 512; @@ -113,6 +142,11 @@ bool VExpKernel::UseMe(const int& d) const { return d > 7; } +template <> +bool VSquareKernel::UseMe(const int& d) const { + return d > 7; +} + template <> bool VSigmoidKernel::UseMe(const int& d) const { return d > 7; @@ -139,12 +173,14 @@ bool SeqPoolKernel::UseMe(const seq_pool_attr_t& attr) const { return true; \ } +AWALYS_USE_ME_WITH_DOUBLE(MatMul); AWALYS_USE_ME_WITH_DOUBLE(VMul); AWALYS_USE_ME_WITH_DOUBLE(VAdd); AWALYS_USE_ME_WITH_DOUBLE(VScal); AWALYS_USE_ME_WITH_DOUBLE(VExp); AWALYS_USE_ME_WITH_DOUBLE(VSigmoid); AWALYS_USE_ME_WITH_DOUBLE(VTanh); +AWALYS_USE_ME_WITH_DOUBLE(VSquare); #undef AWALYS_USE_ME_WITH_DOUBLE } // namespace mkl @@ -159,10 +195,12 @@ namespace mkl = paddle::operators::jit::more::mkl; REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel, \ mkl::func##Kernel) +REGISTER_MKL_KERNEL(kMatMul, MatMul); REGISTER_MKL_KERNEL(kVMul, VMul); REGISTER_MKL_KERNEL(kVAdd, VAdd); REGISTER_MKL_KERNEL(kVScal, VScal); REGISTER_MKL_KERNEL(kVExp, VExp); +REGISTER_MKL_KERNEL(kVSquare, VSquare); REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid); REGISTER_MKL_KERNEL(kVTanh, VTanh); REGISTER_MKL_KERNEL(kSeqPool, SeqPool); diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h b/paddle/fluid/operators/jit/more/mkl/mkl.h index 0a3816db24ccd0820cb259b40044e1f5b66665f7..a27196fa19f1d3e9aa6c414b6b9f99a21ef49025 100644 --- 
a/paddle/fluid/operators/jit/more/mkl/mkl.h +++ b/paddle/fluid/operators/jit/more/mkl/mkl.h @@ -24,6 +24,9 @@ namespace jit { namespace more { namespace mkl { +template +void MatMul(const T* a, const T* b, T* c, int m, int n, int k); + template void VMul(const T* x, const T* y, T* z, int n); @@ -36,6 +39,9 @@ void VScal(const T* a, const T* x, T* y, int n); template void VExp(const T* x, T* y, int n); +template +void VSquare(const T* x, T* y, int n); + template void VCopy(const T* x, T* y, int n); @@ -93,6 +99,9 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { const char* ImplType() const override { return "MKL"; } \ } +// ABCMNK +DECLARE_MKL_KERNEL(MatMul, MatMulTuples); + // XYZN DECLARE_MKL_KERNEL(VMul, XYZNTuples); DECLARE_MKL_KERNEL(VAdd, XYZNTuples); @@ -104,6 +113,7 @@ DECLARE_MKL_KERNEL(VScal, AXYNTuples); DECLARE_MKL_KERNEL(VExp, XYNTuples); DECLARE_MKL_KERNEL(VSigmoid, XYNTuples); DECLARE_MKL_KERNEL(VTanh, XYNTuples); +DECLARE_MKL_KERNEL(VSquare, XYNTuples); DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples); diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt b/paddle/fluid/operators/jit/refer/CMakeLists.txt index 0f626bb3bfd2851e3fb6ad8265169f9bb9860851..4b9bc5e8d49c62404d5d4ef99b7c50987fcb415a 100644 --- a/paddle/fluid/operators/jit/refer/CMakeLists.txt +++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt @@ -27,3 +27,5 @@ USE_JITKERNEL_REFER(kCRFDecoding) USE_JITKERNEL_REFER(kLayerNorm) USE_JITKERNEL_REFER(kNCHW16CMulNC) USE_JITKERNEL_REFER(kSeqPool) +USE_JITKERNEL_REFER(kMatMul) +USE_JITKERNEL_REFER(kVSquare) diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc index 85381daa47484a4053326f04e12d583543a423e0..3512ad7fe7921381afb6152330fff6be34de5ad7 100644 --- a/paddle/fluid/operators/jit/refer/refer.cc +++ b/paddle/fluid/operators/jit/refer/refer.cc @@ -31,6 +31,7 @@ REGISTER_REFER_KERNEL(kVAddBias, VAddBias); REGISTER_REFER_KERNEL(kVRelu, VRelu); REGISTER_REFER_KERNEL(kVIdentity, 
VIdentity); +REGISTER_REFER_KERNEL(kVSquare, VSquare); REGISTER_REFER_KERNEL(kVExp, VExp); REGISTER_REFER_KERNEL(kVSigmoid, VSigmoid); REGISTER_REFER_KERNEL(kVTanh, VTanh); @@ -49,4 +50,6 @@ REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC); REGISTER_REFER_KERNEL(kSeqPool, SeqPool); +REGISTER_REFER_KERNEL(kMatMul, MatMul); + #undef REGISTER_REFER_KERNEL diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index b4e9c8dd107ee844544165b1719d38754ae976bc..97d029358594d757f0e1874e9c87ecb8f97c9d50 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -83,6 +83,13 @@ inline void VIdentity(const T* x, T* y, int n) { } } +template +inline void VSquare(const T* x, T* y, int n) { + for (int i = 0; i < n; ++i) { + y[i] = x[i] * x[i]; + } +} + template void VExp(const T* x, T* y, int n) { for (int i = 0; i < n; ++i) { @@ -354,6 +361,23 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { } } +// A(M,K) * B(K,N) = C(M,N) +template +void MatMul(const T* A, const T* B, T* C, int M, int N, int K) { + for (int m = 0; m < M; ++m) { + const T* pa = A + m * K; + T* pc = C + m * N; + for (int n = 0; n < N; ++n) { + const T* pb = B + n; + T sum = static_cast(0); + for (int k = 0; k < K; ++k) { + sum += (pa[k] * pb[k * N]); + } + *(pc + n) = sum; + } + } +} + #define DECLARE_REFER_KERNEL(name, tuples) \ template \ class name##Kernel : public ReferKernel> { \ @@ -377,6 +401,7 @@ DECLARE_REFER_KERNEL(VIdentity, XYNTuples); DECLARE_REFER_KERNEL(VExp, XYNTuples); DECLARE_REFER_KERNEL(VSigmoid, XYNTuples); DECLARE_REFER_KERNEL(VTanh, XYNTuples); +DECLARE_REFER_KERNEL(VSquare, XYNTuples); // lstm_t*, const lstm_attr_t* DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples); @@ -394,6 +419,8 @@ DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples); DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples); +DECLARE_REFER_KERNEL(MatMul, MatMulTuples); + #undef DECLARE_REFER_KERNEL } // namespace refer diff --git 
a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index 30291bfef3bc96fe2e687e5be6d782eee89496aa..f4415a54ca9678c75038a820bb5d212e61593ec7 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -229,6 +229,26 @@ struct TestFuncWithRefer, std::vector, } }; +template +struct TestFuncWithRefer, std::vector, std::vector, + std::vector, int, int, int> { + void operator()(const typename jit::MatMulTuples::func_type tgt, + const std::vector& a, const std::vector& b, + const std::vector& cref, int m, int n, int k) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(a.size(), static_cast(m * k)); + EXPECT_EQ(b.size(), static_cast(k * n)); + EXPECT_EQ(cref.size(), static_cast(m * n)); + std::vector c(cref.size()); + const T* a_data = a.data(); + const T* b_data = b.data(); + const T* cref_data = cref.data(); + T* c_data = c.data(); + tgt(a_data, b_data, c_data, m, n, k); + ExpectEQ(c_data, cref_data, m * n); + } +}; + template void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... 
args) { @@ -458,6 +478,28 @@ void TestSeqPoolKernel() { } } +/* Drives the MatMul kernel test. For every m and n in 1..4 and every k returned by TestSizes(), it fills a (m*k elements) and b (k*n elements) through RandomVec with bounds -0.2f and 0.2f, computes c with the refer implementation obtained from jit::GetRefer, and hands a, b, c, m, n, k to TestAllImpls so that c serves as the expected result. The leading k in that call fills the attr parameter of TestAllImpls. */ template +void TestMatMulKernel() { + VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); + for (int m : {1, 2, 3, 4}) { + for (int n : {1, 2, 3, 4}) { + for (int k : TestSizes()) { + auto ref = jit::GetRefer>(); + EXPECT_TRUE(ref != nullptr); + std::vector a(m * k), b(k * n), c(m * n); + RandomVec(m * k, a.data(), -0.2f, 0.2f); + RandomVec(k * n, b.data(), -0.2f, 0.2f); + const T* a_data = a.data(); + const T* b_data = b.data(); + T* c_data = c.data(); + ref(a_data, b_data, c_data, m, n, k); + TestAllImpls, PlaceType, std::vector, + std::vector, std::vector>(k, a, b, c, m, n, k); + } + } + } +} + template void TestNCHW16CMulNCKernel() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); @@ -562,6 +604,12 @@ TEST(JITKernel, kVIdentity) { TestXYNKernel(); } +/* gtest case for kVSquare: aliases the jit namespace and invokes TestXYNKernel twice. NOTE(review): the template arguments of both calls are missing in this dump, presumably one call per element type - TODO confirm. */ TEST(JITKernel, kVSquare) { + namespace jit = paddle::operators::jit; + TestXYNKernel(); + TestXYNKernel(); +} + TEST(JITKernel, kVExp) { namespace jit = paddle::operators::jit; TestXYNKernel(); @@ -618,6 +666,12 @@ TEST(JITKernel, kSeqPool) { TestSeqPoolKernel(); } +/* gtest case for kMatMul: aliases the jit namespace and invokes TestMatMulKernel twice. NOTE(review): the template arguments of both calls are missing in this dump, presumably one call per element type - TODO confirm. */ TEST(JITKernel, kMatMul) { + namespace jit = paddle::operators::jit; + TestMatMulKernel(); + TestMatMulKernel(); +} + TEST(JITKernel, kNCHW16CMulNC) { namespace jit = paddle::operators::jit; TestNCHW16CMulNCKernel 1, 'Should larger than 1' + self.set_conf() + self.op_type = 'fusion_repeated_fc_relu' + sz = len(self.oc) + ics = [self.ic] + self.oc[0:sz - 1] + assert len(ics) == len(self.oc) + weights = [] + biases = [] + outs = [] + + i = 0 + matrix = MatrixGenerate(self.bs, ics[i], self.oc[i], 1, 1) + inp = np.reshape(matrix.input, [self.bs, ics[i]]) + weights.append(('W_{0}'.format(i), np.reshape(matrix.weights, + [ics[i], self.oc[i]]))) + biases.append(('B_{0}'.format(i), matrix.bias)) + outs.append( + np.reshape( + np.maximum(fc_refer(matrix, True), 0), [self.bs, self.oc[i]])) + + for i in range(sz - 1): + matrix = MatrixGenerate(self.bs, ics[i + 1], self.oc[i + 1], 1, 1) 
+ matrix.input = np.reshape(outs[i], [self.bs, ics[i + 1], 1, 1]) + out = fc_refer(matrix, True) + weights.append( + ('W_{0}'.format(i + 1), + np.reshape(matrix.weights, [ics[i + 1], self.oc[i + 1]]))) + biases.append(('B_{0}'.format(i + 1), matrix.bias)) + outs.append( + np.reshape(np.maximum(out, 0), [self.bs, self.oc[i + 1]])) + + relu_outs = [] + for i in range(sz - 1): + relu_outs.append(('ReluOut_{0}'.format(i), outs[i])) + + self.inputs = { + 'X': inp, + 'W': weights, + 'Bias': biases, + } + + self.outputs = {'Out': outs[-1], 'ReluOut': relu_outs} + + def test_check_output(self): + self.check_output() + + def set_conf(self): + pass + + +class TestFusionRepeatedFCReluOpBS1(TestFusionRepeatedFCReluOp): + def set_conf(self): + self.bs = 1 + self.oc = [4, 2, 7, 5] + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fusion_squared_mat_sub_op.py b/python/paddle/fluid/tests/unittests/test_fusion_squared_mat_sub_op.py new file mode 100644 index 0000000000000000000000000000000000000000..a097d3d9a20f0b4b5dddf286f064d5698de35b5f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fusion_squared_mat_sub_op.py @@ -0,0 +1,53 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Unit test module for the fusion_squared_mat_sub operator. setUp draws random float32 inputs X of shape (m, k) and Y of shape (k, n), defaulting to m=11, n=12, k=4 and scalar=0.5, and sets the expected output to (np.dot(X, Y)**2 - np.dot(X**2, Y**2)) * scalar, where ** is the element-wise NumPy power. set_conf is a hook called before the inputs are generated; the Case1 subclass uses it to override scalar to -0.3. Only check_output is exercised by the tests in this module.
+ +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + + +class TestFusionSquaredMatSubOp(OpTest): + def setUp(self): + self.op_type = 'fusion_squared_mat_sub' + self.m = 11 + self.n = 12 + self.k = 4 + self.scalar = 0.5 + self.set_conf() + matx = np.random.random((self.m, self.k)).astype("float32") + maty = np.random.random((self.k, self.n)).astype("float32") + + self.inputs = {'X': matx, 'Y': maty} + self.outputs = { + 'Out': + (np.dot(matx, maty)**2 - np.dot(matx**2, maty**2)) * self.scalar + } + self.attrs = {'scalar': self.scalar, } + + def set_conf(self): + pass + + +class TestFusionSquaredMatSubOpCase1(TestFusionSquaredMatSubOp): + def set_conf(self): + self.scalar = -0.3 + + +if __name__ == '__main__': + unittest.main()