Unverified commit 9429936c, authored by Sławomir Siwek, committed by GitHub

Fused ops converter (#50751)

* ConvertToFusedOp

* change static to inline
Co-authored-by: Tomasz Socha <tomasz.socha@intel.com>

---------
Co-authored-by: Tomasz Socha <tomasz.socha@intel.com>
Parent 6ef3f2ce
......@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/tensor.h"
......@@ -413,7 +414,7 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
if (is_mkldnn) {
if (conv->Op()->Type() == "conv2d" ||
conv->Op()->Type() == "depthwise_conv2d") {
conv->Op()->SetType("fused_conv2d");
ConvertToFusedOp(conv->Op());
}
if (mkldnn_with_bias) {
// reuse existing conv bias node
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/framework/op_desc.h"
namespace paddle {
......
......@@ -13,7 +13,6 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/utils/string/pretty_log.h"
......@@ -66,7 +65,7 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
OpDesc* conv_op = conv->Op();
if (conv_op->Type() == "conv2d") {
conv_op->SetType("fused_conv2d");
ConvertToFusedOp(conv_op);
}
SetActivationAttrs(conv_op, activation->Op(), act_type);
......@@ -138,7 +137,7 @@ void ConvActivationMkldnnFusePass::FuseConvConcatAct(
for (auto node : concat_inputs) {
OpDesc* conv_op = node->inputs[0]->Op();
if (conv_op->Type() == "conv2d") {
conv_op->SetType("fused_conv2d");
ConvertToFusedOp(conv_op);
}
SetActivationAttrs(conv_op, activation_op->Op(), act_type);
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/utils/string/pretty_log.h"
......@@ -166,7 +167,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConv(
}
if (conv_op->Op()->Type() == "conv2d") {
conv_op->Op()->SetType("fused_conv2d");
ConvertToFusedOp(conv_op->Op());
}
conv_op->Op()->SetInput("ResidualData", {residual_data->Name()});
......@@ -259,7 +260,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
if (HasFusedActivation(residual_conv_op)) return;
if (residual_conv_op->Op()->Type() == "conv2d") {
residual_conv_op->Op()->SetType("fused_conv2d");
ConvertToFusedOp(residual_conv_op->Op());
}
residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()});
residual_conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
......
......@@ -453,7 +453,7 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
VLOG(4) << "Quantize conv2d op";
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
if (conv_op->Op()->Type() == "conv2d") {
conv_op->Op()->SetType("fused_conv2d");
ConvertToFusedOp(conv_op->Op());
}
// skip if should not be quantized
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.h"
#include <unordered_set>
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
namespace paddle {
namespace framework {
......@@ -22,18 +23,6 @@ namespace ir {
class Graph;
void ReplaceWithFusedOp(Node* op) {
const std::string matmul_type = op->Op()->Type();
if (matmul_type == "matmul" || matmul_type == "matmul_v2") {
op->Op()->SetType("fused_matmul");
if (matmul_type == "matmul") {
op->Op()->SetAttr("trans_x", op->Op()->GetAttr("transpose_X"));
op->Op()->SetAttr("trans_y", op->Op()->GetAttr("transpose_Y"));
op->Op()->SetAttr("matmul_alpha", op->Op()->GetAttr("alpha"));
}
}
}
void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Marks operators which are to be quantized.";
std::unordered_set<std::string> supported_op_types =
......@@ -97,7 +86,7 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
return;
}
ReplaceWithFusedOp(op);
ConvertToFusedOp(op->Op());
op->Op()->SetAttr("mkldnn_data_type", std::string("int8"));
};
gpd(graph, handler);
......
......@@ -18,6 +18,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/utils/string/pretty_log.h"
......@@ -355,7 +356,7 @@ void CPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const {
if (output_name.empty()) return;
if (any_op->Op()->Type() == "conv2d") {
any_op->Op()->SetType("fused_conv2d");
ConvertToFusedOp(any_op->Op());
}
any_op->Op()->SetAttr("force_fp32_output", true);
any_op->Op()->SetOutput(output_name,
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/mkldnn/int8_scale_calculation_mkldnn_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/phi/core/enforce.h"
......@@ -124,7 +125,7 @@ void Int8ScaleCalculationMkldnnPass::Int8ScaleImpl(
}
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
if (conv_op->Op()->Type() == "conv2d") {
conv_op->Op()->SetType("fused_conv2d");
ConvertToFusedOp(conv_op->Op());
}
if (!platform::HasOpINT8DataType(conv_op->Op()) ||
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/utils/string/pretty_log.h"
......@@ -63,13 +64,7 @@ void MatmulActivationMkldnnFusePass::FuseMatmulAct(
OpDesc* matmul_op = matmul->Op();
matmul_op->SetType("fused_matmul");
if (matmul_type == "matmul") {
matmul_op->SetAttr("trans_x", matmul_op->GetAttr("transpose_X"));
matmul_op->SetAttr("trans_y", matmul_op->GetAttr("transpose_Y"));
matmul_op->SetAttr("matmul_alpha", matmul_op->GetAttr("alpha"));
}
ConvertToFusedOp(matmul_op);
SetActivationAttrs(matmul_op, activation->Op(), act_type);
matmul_op->SetOutput("Out", {activation_out->Name()});
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/mkldnn/matmul_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/utils/string/pretty_log.h"
......@@ -65,12 +66,7 @@ void MatmulElementwiseAddMKLDNNFusePass::FuseMatmulElementwiseAdd(
return;
}
matmul->Op()->SetType("fused_matmul");
if (matmul_type == "matmul") {
matmul->Op()->SetAttr("trans_x", matmul->Op()->GetAttr("transpose_X"));
matmul->Op()->SetAttr("trans_y", matmul->Op()->GetAttr("transpose_Y"));
matmul->Op()->SetAttr("matmul_alpha", matmul->Op()->GetAttr("alpha"));
}
ConvertToFusedOp(matmul->Op());
matmul->Op()->SetInput("ResidualData", {elementwise_addend->Name()});
matmul->Op()->SetOutput("Out", {elementwise_add_out->Name()});
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/utils/string/pretty_log.h"
......@@ -84,12 +85,7 @@ void MatmulTransposeReshapeMKLDNNPass::Fuse(
}
OpDesc *matmul_desc = matmul_op->Op();
matmul_desc->SetType("fused_matmul");
if (matmul_type == "matmul") {
matmul_desc->SetAttr("trans_x", matmul_desc->GetAttr("transpose_X"));
matmul_desc->SetAttr("trans_y", matmul_desc->GetAttr("transpose_Y"));
matmul_desc->SetAttr("matmul_alpha", matmul_desc->GetAttr("alpha"));
}
ConvertToFusedOp(matmul_desc);
matmul_desc->SetOutput("Out", {reshape_out->Name()});
matmul_desc->SetAttr("fused_reshape_Out", reshape_shape);
matmul_desc->SetAttr("fused_transpose_Out", transpose_axis);
......
......@@ -155,6 +155,26 @@ static void GetInfoFromTheFirstOp(ir::Graph* graph,
}
}
inline void ConvertToFusedOp(OpDesc* op) {
const std::map<std::string, std::string> fused_ops = {
{"conv2d", "fused_conv2d"},
{"depthwise_conv2d", "fused_conv2d"},
{"matmul", "fused_matmul"},
{"matmul_v2", "fused_matmul"}};
if (op->Type() == "matmul") {
op->SetAttr("trans_x", op->GetAttr("transpose_X"));
op->SetAttr("trans_y", op->GetAttr("transpose_Y"));
op->SetAttr("matmul_alpha", op->GetAttr("alpha"));
}
auto it = fused_ops.find(op->Type());
if (it != fused_ops.end()) {
op->SetType(it->second);
VLOG(3) << "Converted " << it->first << " to " << it->second;
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
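
A minimal sketch of how a oneDNN fuse pass uses the new helper, assuming a node already matched by the pass pattern (matched_node, activation and act_type below are illustrative placeholders, not names from this diff):

// Given a matched conv2d / depthwise_conv2d / matmul / matmul_v2 node:
OpDesc* op_desc = matched_node->Op();
// ConvertToFusedOp rewrites the op type to fused_conv2d or fused_matmul and,
// for "matmul", copies transpose_X / transpose_Y / alpha into
// trans_x / trans_y / matmul_alpha.
ConvertToFusedOp(op_desc);
// Fusion-specific attributes are then set on the converted op as before, e.g.:
// SetActivationAttrs(op_desc, activation->Op(), act_type);
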
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/mkldnn/operator_scale_onednn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/utils/string/pretty_log.h"
......@@ -86,17 +87,8 @@ void FuseOperatorScaleOneDNNPass::FuseScale(Graph *graph,
scale = *(scale_tensor->data<float>());
}
if (op_type == "matmul") {
operator_op->Op()->SetType("fused_matmul");
operator_op->Op()->SetAttr("trans_x",
operator_op->Op()->GetAttr("transpose_X"));
operator_op->Op()->SetAttr("trans_y",
operator_op->Op()->GetAttr("transpose_Y"));
operator_op->Op()->SetAttr("matmul_alpha",
operator_op->Op()->GetAttr("alpha"));
}
if (op_type == "matmul_v2") {
operator_op->Op()->SetType("fused_matmul");
if (op_type == "matmul" || op_type == "matmul_v2") {
ConvertToFusedOp(operator_op->Op());
}
operator_op->Op()->SetAttr("fused_output_scale", scale);
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/utils/string/pretty_log.h"
......@@ -142,12 +143,7 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
return;
}
matmul_desc->SetType("fused_matmul");
if (matmul_type == "matmul") {
matmul_desc->SetAttr("trans_x", matmul_desc->GetAttr("transpose_X"));
matmul_desc->SetAttr("trans_y", matmul_desc->GetAttr("transpose_Y"));
matmul_desc->SetAttr("matmul_alpha", matmul_desc->GetAttr("alpha"));
}
ConvertToFusedOp(matmul_desc);
matmul_desc->SetInput(matmul_input_name, {(reshape_in)->Name()});
matmul_desc->SetAttr("fused_reshape_" + matmul_input_name, reshape_shape);
matmul_desc->SetAttr("fused_transpose_" + matmul_input_name,
......