Unverified commit 9429936c, authored by Sławomir Siwek, committed by GitHub

Fused ops converter (#50751)

* ConvertToFusedOp

* change static to inline
Co-authored-by: Tomasz Socha <tomasz.socha@intel.com>

---------
Co-authored-by: Tomasz Socha <tomasz.socha@intel.com>
Parent 6ef3f2ce
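Summary of the change: every oneDNN fuse pass that previously promoted an op in place, by calling SetType("fused_conv2d") or SetType("fused_matmul") and hand-copying the legacy matmul attributes (transpose_X, transpose_Y, alpha) to their fused_matmul names (trans_x, trans_y, matmul_alpha), now calls a single inline helper, ConvertToFusedOp, added to mkldnn_pass_util.h. The duplicated ReplaceWithFusedOp helper in the quantize placement pass is removed in its favor. A standalone sketch of the helper's behavior follows the diff.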
@@ -18,6 +18,7 @@
 #include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/framework/tensor.h"
@@ -413,7 +414,7 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
     if (is_mkldnn) {
       if (conv->Op()->Type() == "conv2d" ||
           conv->Op()->Type() == "depthwise_conv2d") {
-        conv->Op()->SetType("fused_conv2d");
+        ConvertToFusedOp(conv->Op());
       }
       if (mkldnn_with_bias) {
         // reuse existing conv bias node
......
@@ -13,6 +13,7 @@
 // limitations under the License.
 #pragma once
+#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
 #include "paddle/fluid/framework/op_desc.h"
 namespace paddle {
......
@@ -13,7 +13,6 @@
 // limitations under the License.
 #include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
 #include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/utils/string/pretty_log.h"
@@ -66,7 +65,7 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
     OpDesc* conv_op = conv->Op();
     if (conv_op->Type() == "conv2d") {
-      conv_op->SetType("fused_conv2d");
+      ConvertToFusedOp(conv_op);
     }
     SetActivationAttrs(conv_op, activation->Op(), act_type);
@@ -138,7 +137,7 @@ void ConvActivationMkldnnFusePass::FuseConvConcatAct(
     for (auto node : concat_inputs) {
       OpDesc* conv_op = node->inputs[0]->Op();
       if (conv_op->Type() == "conv2d") {
-        conv_op->SetType("fused_conv2d");
+        ConvertToFusedOp(conv_op);
       }
       SetActivationAttrs(conv_op, activation_op->Op(), act_type);
......
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
 #include "paddle/fluid/framework/ir/graph_traits.h"
+#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/utils/string/pretty_log.h"
@@ -166,7 +167,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConv(
     }
     if (conv_op->Op()->Type() == "conv2d") {
-      conv_op->Op()->SetType("fused_conv2d");
+      ConvertToFusedOp(conv_op->Op());
     }
     conv_op->Op()->SetInput("ResidualData", {residual_data->Name()});
@@ -259,7 +260,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
     if (HasFusedActivation(residual_conv_op)) return;
     if (residual_conv_op->Op()->Type() == "conv2d") {
-      residual_conv_op->Op()->SetType("fused_conv2d");
+      ConvertToFusedOp(residual_conv_op->Op());
     }
     residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()});
     residual_conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
......
@@ -453,7 +453,7 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
     VLOG(4) << "Quantize conv2d op";
     GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
     if (conv_op->Op()->Type() == "conv2d") {
-      conv_op->Op()->SetType("fused_conv2d");
+      ConvertToFusedOp(conv_op->Op());
     }
     // skip if should not be quantized
......
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.h"
 #include <unordered_set>
+#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
 namespace paddle {
 namespace framework {
@@ -22,18 +23,6 @@ namespace ir {
 class Graph;
-void ReplaceWithFusedOp(Node* op) {
-  const std::string matmul_type = op->Op()->Type();
-  if (matmul_type == "matmul" || matmul_type == "matmul_v2") {
-    op->Op()->SetType("fused_matmul");
-    if (matmul_type == "matmul") {
-      op->Op()->SetAttr("trans_x", op->Op()->GetAttr("transpose_X"));
-      op->Op()->SetAttr("trans_y", op->Op()->GetAttr("transpose_Y"));
-      op->Op()->SetAttr("matmul_alpha", op->Op()->GetAttr("alpha"));
-    }
-  }
-}
 void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
   VLOG(3) << "Marks operators which are to be quantized.";
   std::unordered_set<std::string> supported_op_types =
@@ -97,7 +86,7 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
       return;
     }
-    ReplaceWithFusedOp(op);
+    ConvertToFusedOp(op->Op());
     op->Op()->SetAttr("mkldnn_data_type", std::string("int8"));
   };
   gpd(graph, handler);
......
@@ -18,6 +18,7 @@
 #include <string>
 #include <vector>
+#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/utils/string/pretty_log.h"
@@ -355,7 +356,7 @@ void CPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const {
     if (output_name.empty()) return;
     if (any_op->Op()->Type() == "conv2d") {
-      any_op->Op()->SetType("fused_conv2d");
+      ConvertToFusedOp(any_op->Op());
     }
     any_op->Op()->SetAttr("force_fp32_output", true);
     any_op->Op()->SetOutput(output_name,
......
@@ -14,6 +14,7 @@
 #include "paddle/fluid/framework/ir/mkldnn/int8_scale_calculation_mkldnn_pass.h"
+#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #include "paddle/phi/core/enforce.h"
@@ -124,7 +125,7 @@ void Int8ScaleCalculationMkldnnPass::Int8ScaleImpl(
     }
     GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
     if (conv_op->Op()->Type() == "conv2d") {
-      conv_op->Op()->SetType("fused_conv2d");
+      ConvertToFusedOp(conv_op->Op());
     }
     if (!platform::HasOpINT8DataType(conv_op->Op()) ||
......
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.h"
 #include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
+#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/utils/string/pretty_log.h"
@@ -63,13 +64,7 @@ void MatmulActivationMkldnnFusePass::FuseMatmulAct(
     OpDesc* matmul_op = matmul->Op();
-    matmul_op->SetType("fused_matmul");
-    if (matmul_type == "matmul") {
-      matmul_op->SetAttr("trans_x", matmul_op->GetAttr("transpose_X"));
-      matmul_op->SetAttr("trans_y", matmul_op->GetAttr("transpose_Y"));
-      matmul_op->SetAttr("matmul_alpha", matmul_op->GetAttr("alpha"));
-    }
+    ConvertToFusedOp(matmul_op);
     SetActivationAttrs(matmul_op, activation->Op(), act_type);
     matmul_op->SetOutput("Out", {activation_out->Name()});
......
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/ir/mkldnn/matmul_elementwise_add_mkldnn_fuse_pass.h"
 #include "paddle/fluid/framework/ir/graph_traits.h"
+#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/utils/string/pretty_log.h"
@@ -65,12 +66,7 @@ void MatmulElementwiseAddMKLDNNFusePass::FuseMatmulElementwiseAdd(
       return;
     }
-    matmul->Op()->SetType("fused_matmul");
-    if (matmul_type == "matmul") {
-      matmul->Op()->SetAttr("trans_x", matmul->Op()->GetAttr("transpose_X"));
-      matmul->Op()->SetAttr("trans_y", matmul->Op()->GetAttr("transpose_Y"));
-      matmul->Op()->SetAttr("matmul_alpha", matmul->Op()->GetAttr("alpha"));
-    }
+    ConvertToFusedOp(matmul->Op());
     matmul->Op()->SetInput("ResidualData", {elementwise_addend->Name()});
     matmul->Op()->SetOutput("Out", {elementwise_add_out->Name()});
......
@@ -13,6 +13,7 @@
 // limitations under the License.
 #include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.h"
+#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/utils/string/pretty_log.h"
@@ -84,12 +85,7 @@ void MatmulTransposeReshapeMKLDNNPass::Fuse(
     }
     OpDesc *matmul_desc = matmul_op->Op();
-    matmul_desc->SetType("fused_matmul");
-    if (matmul_type == "matmul") {
-      matmul_desc->SetAttr("trans_x", matmul_desc->GetAttr("transpose_X"));
-      matmul_desc->SetAttr("trans_y", matmul_desc->GetAttr("transpose_Y"));
-      matmul_desc->SetAttr("matmul_alpha", matmul_desc->GetAttr("alpha"));
-    }
+    ConvertToFusedOp(matmul_desc);
     matmul_desc->SetOutput("Out", {reshape_out->Name()});
     matmul_desc->SetAttr("fused_reshape_Out", reshape_shape);
     matmul_desc->SetAttr("fused_transpose_Out", transpose_axis);
......
@@ -155,6 +155,26 @@ static void GetInfoFromTheFirstOp(ir::Graph* graph,
   }
 }
+
+inline void ConvertToFusedOp(OpDesc* op) {
+  const std::map<std::string, std::string> fused_ops = {
+      {"conv2d", "fused_conv2d"},
+      {"depthwise_conv2d", "fused_conv2d"},
+      {"matmul", "fused_matmul"},
+      {"matmul_v2", "fused_matmul"}};
+
+  if (op->Type() == "matmul") {
+    op->SetAttr("trans_x", op->GetAttr("transpose_X"));
+    op->SetAttr("trans_y", op->GetAttr("transpose_Y"));
+    op->SetAttr("matmul_alpha", op->GetAttr("alpha"));
+  }
+
+  auto it = fused_ops.find(op->Type());
+  if (it != fused_ops.end()) {
+    op->SetType(it->second);
+    VLOG(3) << "Converted " << it->first << " to " << it->second;
+  }
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
@@ -14,6 +14,7 @@
 #include "paddle/fluid/framework/ir/mkldnn/operator_scale_onednn_fuse_pass.h"
+#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/phi/backends/onednn/onednn_reuse.h"
 #include "paddle/utils/string/pretty_log.h"
@@ -86,17 +87,8 @@ void FuseOperatorScaleOneDNNPass::FuseScale(Graph *graph,
       scale = *(scale_tensor->data<float>());
     }
-    if (op_type == "matmul") {
-      operator_op->Op()->SetType("fused_matmul");
-      operator_op->Op()->SetAttr("trans_x",
-                                 operator_op->Op()->GetAttr("transpose_X"));
-      operator_op->Op()->SetAttr("trans_y",
-                                 operator_op->Op()->GetAttr("transpose_Y"));
-      operator_op->Op()->SetAttr("matmul_alpha",
-                                 operator_op->Op()->GetAttr("alpha"));
-    }
-    if (op_type == "matmul_v2") {
-      operator_op->Op()->SetType("fused_matmul");
+    if (op_type == "matmul" || op_type == "matmul_v2") {
+      ConvertToFusedOp(operator_op->Op());
     }
     operator_op->Op()->SetAttr("fused_output_scale", scale);
......
@@ -13,6 +13,7 @@
 // limitations under the License.
 #include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
+#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/utils/string/pretty_log.h"
@@ -142,12 +143,7 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
       return;
     }
-    matmul_desc->SetType("fused_matmul");
-    if (matmul_type == "matmul") {
-      matmul_desc->SetAttr("trans_x", matmul_desc->GetAttr("transpose_X"));
-      matmul_desc->SetAttr("trans_y", matmul_desc->GetAttr("transpose_Y"));
-      matmul_desc->SetAttr("matmul_alpha", matmul_desc->GetAttr("alpha"));
-    }
+    ConvertToFusedOp(matmul_desc);
     matmul_desc->SetInput(matmul_input_name, {(reshape_in)->Name()});
     matmul_desc->SetAttr("fused_reshape_" + matmul_input_name, reshape_shape);
     matmul_desc->SetAttr("fused_transpose_" + matmul_input_name,
......
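For readers who want to poke at the conversion logic outside Paddle, here is a minimal, self-contained C++ sketch of what ConvertToFusedOp does. MiniOpDesc is a hypothetical stand-in for paddle::framework::OpDesc, reduced to the four calls the helper uses; real OpDesc attributes are typed variants, simplified here to float so the example compiles on its own, and the VLOG call is omitted.

// Minimal sketch of the ConvertToFusedOp logic; MiniOpDesc is NOT Paddle API.
#include <iostream>
#include <map>
#include <string>

// Hypothetical stand-in for paddle::framework::OpDesc. Real attributes are
// typed variants (bool, float, ...), flattened here to float for brevity.
struct MiniOpDesc {
  std::string type;
  std::map<std::string, float> attrs;

  const std::string& Type() const { return type; }
  void SetType(const std::string& t) { type = t; }
  float GetAttr(const std::string& name) const { return attrs.at(name); }
  void SetAttr(const std::string& name, float value) { attrs[name] = value; }
};

void ConvertToFusedOp(MiniOpDesc* op) {
  // Same op-type mapping as the helper added to mkldnn_pass_util.h.
  static const std::map<std::string, std::string> fused_ops = {
      {"conv2d", "fused_conv2d"},
      {"depthwise_conv2d", "fused_conv2d"},
      {"matmul", "fused_matmul"},
      {"matmul_v2", "fused_matmul"}};

  // Legacy matmul stores its flags under different attribute names than
  // fused_matmul expects, so rename them before switching the type.
  if (op->Type() == "matmul") {
    op->SetAttr("trans_x", op->GetAttr("transpose_X"));
    op->SetAttr("trans_y", op->GetAttr("transpose_Y"));
    op->SetAttr("matmul_alpha", op->GetAttr("alpha"));
  }

  auto it = fused_ops.find(op->Type());
  if (it != fused_ops.end()) {
    op->SetType(it->second);  // ops outside the table are left untouched
  }
}

int main() {
  MiniOpDesc op{"matmul",
                {{"transpose_X", 0.f}, {"transpose_Y", 1.f}, {"alpha", 1.5f}}};
  ConvertToFusedOp(&op);
  std::cout << op.Type() << " matmul_alpha=" << op.GetAttr("matmul_alpha")
            << "\n";  // prints: fused_matmul matmul_alpha=1.5
  return 0;
}

One design note: defining the helper as inline in the shared header, rather than static in a single translation unit (the "change static to inline" commit above), lets every pass include and call one common definition without duplicating code or violating the one-definition rule.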