Unverified commit 7f0eb2e3 authored by zyfncg, committed by GitHub

Refactor Pass for fused_conv (#48848)

* refactor conv_activation_mkldnn_fuse_pass

* refactor conv_affine_channel_mkldnn_fuse_pass

* fix conv_activation_mkldnn_fuse_pass

* fix mkldnn unittest

* refactor int8_scale_calculation_mkldnn_pass and params_quantization_mkldnn_pass

* refactor conv_elementwise_add_mkldnn_fuse_pass

* fix quant

* refactor conv_bn_fuse_pass

* fix conv_bn_fuse_pass

* refactor depthwise_conv_bn_fuse_pass

* fix unittest

* fix conv_bn_fuse_pass

* remove redundant conv2d in params_quantization_mkldnn_pass

* fix params_quantization_mkldnn_pass_tester
Parent b8814777
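The common step across all of the passes touched here is the same: before any MKL-DNN fusion attribute or extra input is attached, a plain conv2d / depthwise_conv2d op is retargeted to fused_conv2d. A minimal standalone sketch of that rule (illustrative only, not code from this patch; the helper name is hypothetical):

#include <iostream>
#include <string>

// Illustrative only: the retyping rule the refactored MKL-DNN fuse passes
// share. Fusion-only attributes (fuse_activation, fuse_residual_connection,
// Scale_*, ...) now live on fused_conv2d instead of on plain conv2d.
std::string RetypeForMkldnnFusion(const std::string& op_type) {
  if (op_type == "conv2d" || op_type == "depthwise_conv2d") {
    return "fused_conv2d";
  }
  return op_type;  // already fused_conv2d, or an unrelated op
}

int main() {
  std::cout << RetypeForMkldnnFusion("conv2d") << "\n";        // fused_conv2d
  std::cout << RetypeForMkldnnFusion("fused_conv2d") << "\n";  // fused_conv2d
  return 0;
}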
......@@ -213,6 +213,43 @@ ConvBNFusePass::ConvBNFusePass() {
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("fused_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("batch_norm"))
.AddInput("X")
......@@ -361,6 +398,10 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
// with MKL-DNN fuse conv+bn into conv with bias
// without MKL-DNN fuse conv+bn into conv+elementwise_add
if (fuse_option == FUSE_MKLDNN) {
if (conv->Op()->Type() == "conv2d" ||
conv->Op()->Type() == "depthwise_conv2d") {
conv->Op()->SetType("fused_conv2d");
}
auto input_names = conv->Op()->InputNames();
bool has_bias =
std::find(input_names.begin(), input_names.end(), "Bias") !=
......@@ -818,6 +859,43 @@ DepthwiseConvBNFusePass::DepthwiseConvBNFusePass() {
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("fused_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}
} // namespace ir
......
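For context, "fuse conv+bn into conv with bias" is the standard per-output-channel batch-norm folding arithmetic; a minimal sketch of that math (not code from this patch):

#include <cmath>
#include <cstdio>

// Standard per-channel batch-norm folding into the preceding conv:
//   new_weight = weight * gamma / sqrt(variance + eps)
//   new_bias   = (bias - mean) * gamma / sqrt(variance + eps) + beta
void FoldBnIntoConvChannel(float gamma, float beta, float mean, float variance,
                           float eps, float* weight, float* bias) {
  const float scale = gamma / std::sqrt(variance + eps);
  *weight *= scale;
  *bias = (*bias - mean) * scale + beta;
}

int main() {
  float w = 0.5f, b = 0.1f;
  FoldBnIntoConvChannel(1.2f, 0.3f, 0.05f, 0.9f, 1e-5f, &w, &b);
  std::printf("folded weight=%f, bias=%f\n", w, b);
  return 0;
}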
......@@ -1826,20 +1826,20 @@ PDNode *patterns::ConvBias::operator()(
return eltwise_out_var;
}
PDNode *patterns::Conv::operator()() {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
PDNode *patterns::Conv::operator()(const std::string &conv_type) {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op(conv_type);
auto input_var = pattern->NewNode(conv_input_repr())
->AsInput()
->assert_is_op_input("conv2d", "Input");
->assert_is_op_input(conv_type, "Input");
auto filter_var = pattern->NewNode(conv_filter_repr())
->AsInput()
->assert_is_op_input("conv2d", "Filter");
->assert_is_op_input(conv_type, "Filter");
auto output_var = pattern->NewNode(conv_output_repr())
->AsOutput()
->assert_is_op_output("conv2d", "Output");
->assert_is_op_output(conv_type, "Output");
conv_op->LinksFrom({input_var, filter_var}).LinksTo({output_var});
return output_var;
......@@ -2658,10 +2658,12 @@ PDNode *patterns::ConvElementwiseadd::operator()(PDNode *conv_in) {
}
PDNode *patterns::ConvAffineChannel::operator()(
paddle::framework::ir::PDNode *conv_input, bool with_eltwise_add) {
paddle::framework::ir::PDNode *conv_input,
const std::string &conv_type,
bool with_eltwise_add) {
// Create Operators
conv_input->assert_is_op_input("conv2d", "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
conv_input->assert_is_op_input(conv_type, "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(conv_type);
PDNode *eltwise_op = nullptr;
if (with_eltwise_add) {
......@@ -2676,11 +2678,11 @@ PDNode *patterns::ConvAffineChannel::operator()(
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Filter");
->assert_is_op_input(conv_type, "Filter");
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("conv2d");
->assert_is_only_output_of_op(conv_type);
PDNode *eltwise_y_in_var = nullptr;
PDNode *eltwise_out_var = nullptr;
......
......@@ -1044,7 +1044,7 @@ struct Conv : public PatternBase {
Conv(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "convolution") {}
PDNode* operator()();
PDNode* operator()(const std::string& conv_type);
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_input);
......@@ -1544,7 +1544,9 @@ struct ConvAffineChannel : public PatternBase {
ConvAffineChannel(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_affine_channel") {}
PDNode* operator()(PDNode* conv_input, bool with_eltwise_add);
PDNode* operator()(PDNode* conv_input,
const std::string& conv_type,
bool with_eltwise_add);
// declare operator node's name
PATTERN_DECL_NODE(conv);
......
......@@ -26,7 +26,7 @@ using string::PrettyLogDetail;
void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
auto act_types = phi::funcs::GetSupportedActivations();
std::vector<std::string> conv_types = {"conv2d", "fused_conv2d"};
std::vector<std::string> conv_types = {"fused_conv2d", "conv2d"};
for (auto& act_type : act_types) {
FuseConvConcatAct(graph, act_type);
......@@ -64,6 +64,10 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
OpDesc* conv_op = conv->Op();
OpDesc* act_op = activation->Op();
if (conv_op->Type() == "conv2d") {
conv_op->SetType("fused_conv2d");
}
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto& attrs : attr_map) {
if (act_op->HasAttr(attrs.first)) {
......@@ -91,8 +95,9 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
AddStatis(found_conv_activation_count);
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_conv_activation_count > 0) {
PrettyLogDetail("--- fused %d conv with %s activation",
PrettyLogDetail("--- fused %d %s with %s activation",
found_conv_activation_count,
conv_type,
act_type);
}
}
......@@ -134,15 +139,20 @@ void ConvActivationMkldnnFusePass::FuseConvConcatAct(
bool is_not_conv_mkldnn =
!(prev_op_nodes[0]->Op()->GetAttrIfExists<bool>("use_mkldnn"));
if (prev_op_nodes[0]->Op()->Type() != "conv2d" || is_not_conv_mkldnn) {
LOG(WARNING)
<< "This fuse pass supports only conv2d (mkldnn) + activation.";
if ((prev_op_nodes[0]->Op()->Type() != "conv2d" &&
prev_op_nodes[0]->Op()->Type() != "fused_conv2d") ||
is_not_conv_mkldnn) {
LOG(WARNING) << "This fuse pass supports only conv2d(mkldnn) | "
"fused_conv2d(mkldnn) + activation.";
return;
}
}
for (auto node : concat_inputs) {
OpDesc* conv_op = node->inputs[0]->Op();
if (conv_op->Type() == "conv2d") {
conv_op->SetType("fused_conv2d");
}
OpDesc* act_op = activation_op->Op();
auto attr_map = phi::funcs::GetAttributeMap(act_type);
......
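The net effect of FuseConvAct can be pictured with a toy op description (an illustrative stand-in, not the Paddle OpDesc API): the conv node is retyped to fused_conv2d, the activation op is folded away, and the activation is recorded via the fuse_activation attribute (with fuse_alpha / fuse_beta where the activation takes parameters), matching the extra attributes listed later in this diff.

#include <iostream>
#include <map>
#include <string>

// Toy stand-in for an op description; the attribute name mirrors the
// fused_conv2d extra attributes added in this patch.
struct ToyOpDesc {
  std::string type;
  std::map<std::string, std::string> attrs;
};

void FuseActivationInto(ToyOpDesc* conv, const std::string& act_type) {
  if (conv->type == "conv2d") conv->type = "fused_conv2d";
  conv->attrs["fuse_activation"] = act_type;  // e.g. "relu", "gelu", "mish"
}

int main() {
  ToyOpDesc conv{"conv2d", {}};
  FuseActivationInto(&conv, "relu");
  std::cout << conv.type << ", fuse_activation="
            << conv.attrs["fuse_activation"] << "\n";
  return 0;
}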
......@@ -165,7 +165,8 @@ void MainTest(std::string activation) {
int conv_activation_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "conv2d") {
if (node->IsOp() && (node->Op()->Type() == "conv2d" ||
node->Op()->Type() == "fused_conv2d")) {
auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
......
......@@ -143,6 +143,44 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
.IsStringIn({"NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("fused_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("affine_channel"))
.AddInput("X")
.IsTensor()
......@@ -177,6 +215,12 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
}
void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
FuseConvAffineChannel(graph, "conv2d");
FuseConvAffineChannel(graph, "fused_conv2d");
}
void ConvAffineChannelFusePass::FuseConvAffineChannel(
ir::Graph* graph, const std::string& conv_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph);
......@@ -190,10 +234,10 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput()
->assert_is_op_input("conv2d", "Input");
->assert_is_op_input(conv_type, "Input");
patterns::ConvAffineChannel conv_ac_pattern(gpd.mutable_pattern(),
name_scope_);
conv_ac_pattern(conv_input, false /*with_eltwise_add*/);
conv_ac_pattern(conv_input, conv_type, false /*with_eltwise_add*/);
int found_conv_ac_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
......
......@@ -36,6 +36,8 @@ class ConvAffineChannelFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph*) const override;
void FuseConvAffineChannel(ir::Graph* graph,
const std::string& conv_type) const;
const std::string name_scope_{"conv_affine_channel_mkldnn_fuse"};
};
......
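Analogously to the batch-norm case, affine_channel applies a per-channel scale and bias after the conv, so the fusion can fold both into the filter and conv bias. The usual arithmetic, shown here only for context and not taken from this patch:

#include <cstdio>

// Per output channel c: y_c = scale_c * conv_c(x) + bias_c, therefore
//   filter'_c    = scale_c * filter_c
//   conv_bias'_c = scale_c * conv_bias_c + bias_c
void FoldAffineChannel(float scale, float bias, float* filter,
                       float* conv_bias) {
  *filter *= scale;
  *conv_bias = scale * (*conv_bias) + bias;
}

int main() {
  float filter = 0.25f, conv_bias = 0.0f;
  FoldAffineChannel(2.0f, 0.5f, &filter, &conv_bias);
  std::printf("filter=%f, conv_bias=%f\n", filter, conv_bias);
  return 0;
}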
......@@ -60,6 +60,43 @@ ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() {
.AddAttr("data_format")
.IsStringIn({"NHWC", "NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("fused_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NHWC", "NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
......@@ -79,12 +116,13 @@ ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() {
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConv(
const std::string& name_scope,
const GraphWithStats& graph_with_stats,
const std::string& conv_type,
bool as_x) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Conv conv_pattern{pattern, name_scope};
auto conv_output = conv_pattern();
auto conv_output = conv_pattern(conv_type);
patterns::ResidualElementwise elementwise_pattern{pattern, name_scope, as_x};
elementwise_pattern(
......@@ -127,6 +165,10 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConv(
return;
}
if (conv_op->Op()->Type() == "conv2d") {
conv_op->Op()->SetType("fused_conv2d");
}
conv_op->Op()->SetInput("ResidualData", {residual_data->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
conv_op->Op()->SetAttr("fuse_residual_connection", true);
......@@ -155,15 +197,16 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConv(
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
const std::string& name_scope,
const GraphWithStats& graph_with_stats) const {
const GraphWithStats& graph_with_stats,
const std::string& conv_type) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Conv conv_x_pattern{pattern, name_scope};
auto conv_x_output = conv_x_pattern();
auto conv_x_output = conv_x_pattern(conv_type);
patterns::Conv conv_y_pattern{pattern, name_scope};
auto conv_y_output = conv_y_pattern();
auto conv_y_output = conv_y_pattern(conv_type);
patterns::Elementwise elementwise_pattern{pattern, name_scope};
elementwise_pattern(conv_x_output, conv_y_output, "elementwise_add");
......@@ -215,6 +258,9 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
if (HasFusedActivation(residual_conv_op)) return;
if (residual_conv_op->Op()->Type() == "conv2d") {
residual_conv_op->Op()->SetType("fused_conv2d");
}
residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()});
residual_conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
......@@ -243,10 +289,18 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
void ResidualConnectionMKLDNNFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto graph_with_stats =
FuseProjectionConv(name_scope_, std::make_pair(graph, 0));
graph_with_stats = FuseConv(name_scope_, graph_with_stats, true);
graph_with_stats = FuseConv(name_scope_, graph_with_stats, false);
FuseProjectionConv(name_scope_, std::make_pair(graph, 0), "fused_conv2d");
graph_with_stats =
FuseConv(name_scope_, graph_with_stats, "fused_conv2d", true);
graph_with_stats =
FuseConv(name_scope_, graph_with_stats, "fused_conv2d", false);
graph_with_stats =
FuseProjectionConv(name_scope_, graph_with_stats, "conv2d");
graph_with_stats = FuseConv(name_scope_, graph_with_stats, "conv2d", true);
graph_with_stats = FuseConv(name_scope_, graph_with_stats, "conv2d", false);
AddStatis(graph_with_stats.second);
}
......
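ApplyImpl now runs the fused_conv2d variants of the matchers before the conv2d variants, which also catches convs already retyped by earlier fuse passes. The rewrite FuseConv performs can be pictured with a toy op description (illustrative, not the Paddle OpDesc API): the other addend of the elementwise_add becomes the conv's ResidualData input, the conv writes directly to the add's output, and fuse_residual_connection is set.

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Toy stand-in for an op description; "ResidualData" and
// "fuse_residual_connection" mirror the names used in this patch.
struct ToyOp {
  std::string type;
  std::map<std::string, std::vector<std::string>> inputs;
  std::map<std::string, std::vector<std::string>> outputs;
  bool fuse_residual_connection = false;
};

void FuseResidualAdd(ToyOp* conv, const std::string& residual_var,
                     const std::string& add_out_var) {
  if (conv->type == "conv2d") conv->type = "fused_conv2d";
  conv->inputs["ResidualData"] = {residual_var};
  conv->outputs["Output"] = {add_out_var};
  conv->fuse_residual_connection = true;
}

int main() {
  ToyOp conv;
  conv.type = "conv2d";
  conv.inputs["Input"] = {"x"};
  conv.inputs["Filter"] = {"w"};
  conv.outputs["Output"] = {"conv_out"};
  FuseResidualAdd(&conv, "skip_branch", "add_out");
  std::cout << conv.type << ", ResidualData=" << conv.inputs["ResidualData"][0]
            << ", Output=" << conv.outputs["Output"][0] << "\n";
  return 0;
}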
......@@ -27,10 +27,11 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
private:
GraphWithStats FuseConv(const std::string& name_scope,
const GraphWithStats& graph_with_stats,
const std::string& conv_type,
bool as_x) const;
GraphWithStats FuseProjectionConv(
const std::string& name_scope,
const GraphWithStats& graph_with_stats) const;
GraphWithStats FuseProjectionConv(const std::string& name_scope,
const GraphWithStats& graph_with_stats,
const std::string& conv_type) const;
public:
ResidualConnectionMKLDNNFusePass();
......
......@@ -100,8 +100,8 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsYWithElementwiseAddRelu) {
"relu",
nodes_removed,
nodes_added));
EXPECT_TRUE(
test::AssertOpsCount(graph, {{"conv2d", 1}, {"elementwise_add", 0}}));
EXPECT_TRUE(test::AssertOpsCount(
graph, {{"fused_conv2d", 1}, {"elementwise_add", 0}}));
}
TEST(ConvElementwiseAddMKLDNNFusePass,
......@@ -134,8 +134,8 @@ TEST(ConvElementwiseAddMKLDNNFusePass,
"relu",
nodes_removed,
nodes_added));
EXPECT_TRUE(
test::AssertOpsCount(graph, {{"conv2d", 2}, {"elementwise_add", 0}}));
EXPECT_TRUE(test::AssertOpsCount(
graph, {{"conv2d", 1}, {"fused_conv2d", 1}, {"elementwise_add", 0}}));
}
TEST(ConvElementwiseAddMKLDNNFusePass,
......@@ -159,8 +159,8 @@ TEST(ConvElementwiseAddMKLDNNFusePass,
"relu",
nodes_removed,
nodes_added));
EXPECT_TRUE(
test::AssertOpsCount(graph, {{"conv2d", 1}, {"elementwise_add", 0}}));
EXPECT_TRUE(test::AssertOpsCount(
graph, {{"fused_conv2d", 1}, {"elementwise_add", 0}}));
}
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsXWithElementwiseAddRelu) {
......@@ -185,8 +185,8 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsXWithElementwiseAddRelu) {
"relu",
nodes_removed,
nodes_added));
EXPECT_TRUE(
test::AssertOpsCount(graph, {{"conv2d", 1}, {"elementwise_add", 0}}));
EXPECT_TRUE(test::AssertOpsCount(
graph, {{"fused_conv2d", 1}, {"elementwise_add", 0}}));
}
TEST(ConvElementwiseAddMKLDNNFusePass,
......@@ -210,8 +210,8 @@ TEST(ConvElementwiseAddMKLDNNFusePass,
"relu",
nodes_removed,
nodes_added));
EXPECT_TRUE(
test::AssertOpsCount(graph, {{"conv2d", 1}, {"elementwise_add", 0}}));
EXPECT_TRUE(test::AssertOpsCount(
graph, {{"fused_conv2d", 1}, {"elementwise_add", 0}}));
}
TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
......
......@@ -452,6 +452,9 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
Graph* g) {
VLOG(4) << "Quantize conv2d op";
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
if (conv_op->Op()->Type() == "conv2d") {
conv_op->Op()->SetType("fused_conv2d");
}
// skip if should not be quantized
if (!platform::HasOpINT8DataType(conv_op->Op())) {
......@@ -1299,10 +1302,10 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
platform::errors::InvalidArgument("Scope cannot be nullptr."));
GetQuantInfo(graph);
QuantizeConv(graph, "conv2d", false /* with_residual_data */);
QuantizeConv(graph, "conv2d", true /* with_residual_data */);
QuantizeConv(graph, "fused_conv2d", false /* with_residual_data */);
QuantizeConv(graph, "fused_conv2d", true /* with_residual_data */);
QuantizeConv(graph, "conv2d", false /* with_residual_data */);
QuantizeConv(graph, "conv2d", true /* with_residual_data */);
QuantizePool(graph);
QuantizeConcat(graph);
QuantizePriorBox(graph);
......
......@@ -174,7 +174,7 @@ void PreparePass(std::unique_ptr<ir::Graph>* graph,
void CheckScales(const OpDesc* op, float scale, float shift) {
std::string type = op->Type();
std::vector<std::string> scale_names;
if (type == "conv2d" || type == "fc") {
if (type == "conv2d" || type == "fused_conv2d" || type == "fc") {
EXPECT_EQ(op->GetAttrIfExists<std::vector<float>>("Scale_weights")[0],
scale);
scale_names.push_back("Scale_in");
......@@ -330,7 +330,7 @@ TEST(CpuQuantizePass, quantize) {
// Insert nodes: 8 Quant + 8 IN + 7 OUT + 7 DEQUANT
int added_nodes = 8 + 8 + 7 + 7;
std::unordered_map<std::string, int> expected_operators = {
{"conv2d", 4}, {"pool2d", 2}, {"quantize", 8}, {"dequantize", 7}};
{"fused_conv2d", 4}, {"pool2d", 2}, {"quantize", 8}, {"dequantize", 7}};
MainTest(BuildProgramDesc(use_mkldnn, mkldnn_data_type),
variable_names,
expected_operators,
......@@ -343,7 +343,7 @@ TEST(CpuQuantizePass, do_not_quantize) {
std::string mkldnn_data_type = "float32";
int added_nodes = 0;
std::unordered_map<std::string, int> expected_operators = {
{"conv2d", 4}, {"pool2d", 2}, {"quantize", 0}, {"dequantize", 0}};
{"fused_conv2d", 4}, {"pool2d", 2}, {"quantize", 0}, {"dequantize", 0}};
MainTest(BuildProgramDesc(use_mkldnn, mkldnn_data_type),
variable_names,
expected_operators,
......
......@@ -341,6 +341,7 @@ void CPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const {
if (dequant_in->outputs.size() == 1) {
if (any_op->Op()->Type() == "conv2d" ||
any_op->Op()->Type() == "fused_conv2d" ||
any_op->Op()->Type() == "conv2d_transpose" ||
any_op->Op()->Type() == "fc") {
// do not squash if fuse residual connection is true
......
......@@ -60,9 +60,52 @@ Int8ScaleCalculationMkldnnPass::Int8ScaleCalculationMkldnnPass() {
.AddAttr("data_format")
.IsStringIn({"NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("fused_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "AnyLayout"})
.End();
}
void Int8ScaleCalculationMkldnnPass::ApplyImpl(ir::Graph* graph) const {
Int8ScaleImpl(graph, "fused_conv2d");
Int8ScaleImpl(graph, "conv2d");
}
void Int8ScaleCalculationMkldnnPass::Int8ScaleImpl(
ir::Graph* graph, const std::string& conv_type) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
......@@ -70,7 +113,7 @@ void Int8ScaleCalculationMkldnnPass::ApplyImpl(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::Conv conv_pattern(gpd.mutable_pattern(),
"int8_scale_calculation_mkldnn_pass");
conv_pattern();
conv_pattern(conv_type);
int found_int8_scales_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
......@@ -80,6 +123,9 @@ void Int8ScaleCalculationMkldnnPass::ApplyImpl(ir::Graph* graph) const {
return;
}
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
if (conv_op->Op()->Type() == "conv2d") {
conv_op->Op()->SetType("fused_conv2d");
}
if (!platform::HasOpINT8DataType(conv_op->Op()) ||
conv_op->Op()->HasAttr("Sum_scale")) {
......
......@@ -31,6 +31,7 @@ class Int8ScaleCalculationMkldnnPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
void Int8ScaleImpl(ir::Graph* graph, const std::string& conv_type) const;
};
} // namespace ir
......
......@@ -77,7 +77,7 @@ void QuantizeConvInput(Scope* scope,
} // namespace
ParamsQuantizationMkldnnPass::ParamsQuantizationMkldnnPass() {
AddOpCompat(OpCompat("conv2d"))
AddOpCompat(OpCompat("fused_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
......@@ -117,10 +117,11 @@ ParamsQuantizationMkldnnPass::ParamsQuantizationMkldnnPass() {
}
void ParamsQuantizationMkldnnPass::QuantizeConv(ir::Graph* graph,
const std::string& conv_type,
bool with_residual_data) const {
GraphPatternDetector gpd;
patterns::ConvResidual conv_pattern(gpd.mutable_pattern(), name_scope_);
conv_pattern("conv2d", with_residual_data);
conv_pattern(conv_type, with_residual_data);
int params_to_int8_conv_found = 0;
......@@ -159,8 +160,8 @@ void ParamsQuantizationMkldnnPass::QuantizeConv(ir::Graph* graph,
AddStatis(params_to_int8_conv_found);
std::stringstream msg_ss;
msg_ss << "Quantized params of " << params_to_int8_conv_found
<< " conv2d ops";
msg_ss << "Quantized params of " << params_to_int8_conv_found << " "
<< conv_type << " ops";
if (with_residual_data) msg_ss << " with residual connection";
paddle::string::PrettyLogDetail(msg_ss.str().c_str());
}
......@@ -170,8 +171,8 @@ void ParamsQuantizationMkldnnPass::ApplyImpl(ir::Graph* graph) const {
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
FusePassBase::Init(name_scope_, graph);
QuantizeConv(graph, true /*with_residual_data*/);
QuantizeConv(graph, false /*with_residual_data*/);
QuantizeConv(graph, "fused_conv2d", true /*with_residual_data*/);
QuantizeConv(graph, "fused_conv2d", false /*with_residual_data*/);
}
} // namespace ir
......
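params_quantization_mkldnn_pass now registers OpCompat for fused_conv2d only and quantizes fused_conv2d, dropping the redundant conv2d handling noted in the commit message. For context, a generic symmetric int8 weight-quantization scheme looks like the sketch below; the exact formula used by QuantizeConvInput is not part of this diff and may differ.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Generic symmetric int8 quantization of a weight buffer (context only).
std::vector<int8_t> QuantizeSymmetric(const std::vector<float>& w,
                                      float* scale) {
  float max_abs = 0.0f;
  for (float v : w) max_abs = std::max(max_abs, std::fabs(v));
  *scale = (max_abs > 0.0f) ? 127.0f / max_abs : 1.0f;  // Scale_weights-style factor
  std::vector<int8_t> q(w.size());
  for (size_t i = 0; i < w.size(); ++i) {
    q[i] = static_cast<int8_t>(std::lround(w[i] * (*scale)));
  }
  return q;
}

int main() {
  float scale = 0.0f;
  std::vector<int8_t> q = QuantizeSymmetric({0.5f, -1.0f, 0.25f}, &scale);
  std::printf("scale=%f, q[0]=%d, q[1]=%d\n", scale, q[0], q[1]);
  return 0;
}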
......@@ -32,7 +32,9 @@ class ParamsQuantizationMkldnnPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
void QuantizeConv(Graph* graph, bool with_residual_connection) const;
void QuantizeConv(Graph* graph,
const std::string& conv_type,
bool with_residual_connection) const;
private:
const std::string name_scope_ = "params_quantization_mkldnn_pass";
......
......@@ -141,7 +141,7 @@ struct ConvProgramStrategy : public ProgramStrategy {
protected:
OpDesc* CreateBasicConvOp(const std::string conv_name = "Conv1") {
auto op = program.MutableBlock(0)->AppendOp();
op->SetType("conv2d");
op->SetType("fused_conv2d");
op->SetAttr("use_mkldnn", true);
op->SetAttr("name", conv_name);
op->SetAttr("mkldnn_data_type", std::string{"int8"});
......
......@@ -648,7 +648,8 @@ void QuantDequantMkldnnPass::DequantizeWeights(
for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp()) continue;
if (op_node->Name() == "conv2d" || op_node->Name() == "depthwise_conv2d") {
if (op_node->Name() == "conv2d" || op_node->Name() == "depthwise_conv2d" ||
op_node->Name() == "fused_conv2d") {
if (onnx_format_quantize_model) {
DequantizeOpWeightsFromONNXFormat(op_node,
scope,
......@@ -708,8 +709,12 @@ void QuantDequantMkldnnPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "quant_dequant_mkldnn_pass";
FusePassBase::Init(pattern_name, graph);
const std::unordered_set<std::string> skip_ops = {
"conv2d", "depthwise_conv2d", "mul", "matmul", "matmul_v2"};
const std::unordered_set<std::string> skip_ops = {"conv2d",
"depthwise_conv2d",
"fused_conv2d",
"mul",
"matmul",
"matmul_v2"};
const std::unordered_set<std::string> fake_quantize_types = {
"fake_quantize_moving_average_abs_max", "fake_quantize_range_abs_max"};
......
......@@ -1023,10 +1023,10 @@ void OpDesc::Flush() {
[](std::pair<std::string, Attribute> a,
std::pair<std::string, Attribute> b) { return a.first < b.first; });
for (auto &attr : sorted_attrs) {
for (auto &attr : sorted_runtime_attrs) {
set_attr_desc(attr.first, attr.second);
}
for (auto &attr : sorted_runtime_attrs) {
for (auto &attr : sorted_attrs) {
set_attr_desc(attr.first, attr.second);
}
......
......@@ -129,6 +129,7 @@ void IRPassManager::CreatePasses(Argument *argument,
new std::unordered_set<int>(argument->quantize_excluded_op_ids()));
} else if (pass_name == "cpu_quantize_pass") {
if (argument->quantize_enabled_op_types().count("conv2d") ||
argument->quantize_enabled_op_types().count("fused_conv2d") ||
argument->quantize_enabled_op_types().count("depthwise_conv2d")) {
pass->Set("data_layout", new std::string("NHWC"));
}
......
......@@ -121,7 +121,8 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs(
// force unsigned type if already know it
bool is_unsigned = false;
bool compute_scale = true;
if (op->Type() == "conv2d" || op->Type() == "fc") {
if (op->Type() == "conv2d" || op->Type() == "fused_conv2d" ||
op->Type() == "fc") {
// output of conv2d with relu must be unsigned
std::string fuse_activation =
op->GetAttrIfExists<std::string>("fuse_activation");
......
......@@ -39,14 +39,44 @@ def {
name: "data_format"
type: STRING
}
}
extra {
attrs {
name: "fuse_activation"
type: STRING
}
attrs {
name: "fuse_alpha"
type: FLOAT
}
attrs {
name: "fuse_beta"
type: FLOAT
}
attrs {
name: "fuse_residual_connection"
type: BOOLEAN
}
attrs {
name: "Scale_in"
type: FLOAT
}
attrs {
name: "Scale_out"
type: FLOAT
}
attrs {
name: "Scale_in_eltwise"
type: FLOAT
}
attrs {
name: "Scale_weights"
type: FLOATS
}
attrs {
name: "Bias_scales"
type: FLOATS
}
attrs {
name: "force_fp32_output"
type: BOOLEAN
......
......@@ -168,7 +168,7 @@ class TestConvBnFusePass(PassAutoScanTest):
if program_config.ops[0].attrs['use_mkldnn']:
config = self.create_inference_config()
config.enable_mkldnn()
yield config, ['conv2d'], (1e-5, 1e-5)
yield config, ['fused_conv2d'], (1e-5, 1e-5)
else:
config = self.create_inference_config()
yield config, ['conv2d', 'elementwise_add'], (1e-5, 1e-5)
......
......@@ -158,7 +158,7 @@ class TestConvConcatActivationMkldnnFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
yield config, ['conv2d', 'conv2d', 'concat'], (1e-5, 1e-5)
yield config, ['fused_conv2d', 'fused_conv2d', 'concat'], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
......
......@@ -118,7 +118,7 @@ class TestConvElementwiseAddMkldnnFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
yield config, ["relu", "conv2d", "conv2d"], (1e-5, 1e-5)
yield config, ["relu", "conv2d", "fused_conv2d"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
......
......@@ -96,7 +96,7 @@ class TestConvGeluMkldnnFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
yield config, ["conv2d"], (1e-5, 1e-5)
yield config, ["fused_conv2d"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
......
......@@ -93,7 +93,7 @@ class TestConvHardSigmoidMkldnnFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
yield config, ["conv2d"], (1e-5, 1e-5)
yield config, ["fused_conv2d"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
......
......@@ -98,7 +98,7 @@ class TestConvHardSwishMkldnnFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
yield config, ["conv2d"], (1e-5, 1e-5)
yield config, ["fused_conv2d"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
......
......@@ -97,7 +97,7 @@ class TestConvMishMkldnnFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
yield config, ["conv2d"], (1e-5, 1e-5)
yield config, ["fused_conv2d"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
......