From 2deada9ac81ae8f6f100bbd992c8ef3ffc2257ca Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 6 Feb 2023 17:02:15 +0800 Subject: [PATCH] Delete extra input (Bias, ResidualData) in OpMaker of conv2d (#49121) * remove extra input of conv2d * fix bug * fix unittest bug * adjust conv2d.pbtxt * fix cpu_quantize_pass_tester * revert use_addto of conv2d * fix runtime attribute * fix bug * recover force_fp32_output in conv2d * refine error info * fix bug --- .../ir/mkldnn/cpu_quantize_pass_tester.cc | 3 -- .../ir/mkldnn/cpu_quantize_squash_pass.cc | 4 +- .../ir/quant_conv2d_dequant_fuse_pass.cc | 11 ++-- paddle/fluid/framework/op_desc.cc | 5 ++ paddle/fluid/operators/compat/conv2d.pbtxt | 50 ------------------- .../operators/compat/depthwise_conv2d.pbtxt | 46 ----------------- paddle/fluid/operators/conv_op.cc | 12 ----- .../fluid/operators/fused/conv_fusion_op.cc | 10 ++++ .../fluid/operators/fused/fused_conv2d_op.cc | 10 ++++ paddle/fluid/operators/ops_extra_info.h | 4 +- paddle/phi/api/yaml/op_compat.yaml | 12 +++-- paddle/phi/kernels/onednn/conv_kernel.cc | 23 ++------- .../mkldnn/test_conv2d_bf16_mkldnn_op.py | 1 + .../mkldnn/test_conv2d_int8_mkldnn_op.py | 4 ++ .../unittests/mkldnn/test_conv2d_mkldnn_op.py | 7 +++ .../quantization/quant_int8_mkldnn_pass.py | 2 +- 16 files changed, 62 insertions(+), 142 deletions(-) diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc index 8b2c140163a..f61e236bb38 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc @@ -59,9 +59,6 @@ void SetOp(ProgramDesc* prog, op->SetAttr("fuse_residual_connection", false); } op->SetOutput("Output", {outputs[0]}); - op->SetAttr("Scale_in", 1.0f); - op->SetAttr("Scale_out", 1.0f); - op->SetAttr("Scale_weights", std::vector{1.0f}); } else if (type == "pool2d" || type == "transpose2" || type == "reshape2" || type == 
"nearest_interp" || type == "nearest_interp_v2") { op->SetInput("X", {inputs[0]}); diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc index 6475f01dc76..e3c2e553bbf 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc @@ -354,7 +354,9 @@ void CPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const { FindOutputNameByVarName(any_op->Op(), dequant_in->Name()); if (output_name.empty()) return; - + if (any_op->Op()->Type() == "conv2d") { + any_op->Op()->SetType("fused_conv2d"); + } any_op->Op()->SetAttr("force_fp32_output", true); any_op->Op()->SetOutput(output_name, std::vector({dequant_out->Name()})); diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc index cb9178f365f..a4cb15dcf3f 100644 --- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc +++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc @@ -411,6 +411,7 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, std::string input_name = ""; if (quantized_op_type == "conv2d" || quantized_op_type == "depthwise_conv2d" || + quantized_op_type == "fused_conv2d" || quantized_op_type == "conv2d_fusion" || quantized_op_type == "conv2d_transpose") { weight_name = "Filter"; @@ -424,9 +425,10 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, input_name = "Input"; } else { PADDLE_THROW(platform::errors::Unimplemented( - "QuantDequantFuse: We only support conv2d, conv2d_fusion, " - "conv2d_transpose, fc, mul, matmul, matmul_v2 for " - "now.")); + "QuantDequantFuse: We only support conv2d, conv2d_fusion, fused_conv2d," + "conv2d_transpose, fc, mul, matmul, matmul_v2 for now, but received: " + "%s.", + quantized_op_type)); } const std::string pattern_name = "dequant_fuse"; GraphPatternDetector gpd; @@ -559,6 +561,7 @@ void 
QuantDequantFusePass::FuseDequant(ir::Graph* graph, } } } else if (quantized_op_type == "conv2d" || + quantized_op_type == "fused_conv2d" || quantized_op_type == "depthwise_conv2d") { PADDLE_ENFORCE_EQ( dequant_type, @@ -642,6 +645,7 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, new_op_desc.SetType(quantized_op_type); new_op_desc.SetAttr("enable_int8", true); if (quantized_op_type == "conv2d" || quantized_op_type == "conv2d_fusion" || + quantized_op_type == "fused_conv2d" || quantized_op_type == "depthwise_conv2d" || quantized_op_type == "conv2d_transpose") { new_op_desc.SetInput("Input", {new_input}); @@ -677,6 +681,7 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const { "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"}; std::unordered_set quantized_op_types = { "conv2d", + "fused_conv2d", "mul", "matmul", "depthwise_conv2d", diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 4e80bd7ff0d..9cb7c8c9e8b 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -671,6 +671,11 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) { if (extra_attr_iter != extra_attr_map.end()) { is_runtime_attr = true; attrs_ptr = &(this->runtime_attrs_); + // When an attribute is found in both attrs and runtime_attrs, it must + // be a runtime attribute, so its value in attrs should be removed. 
+ if (this->attrs_.find(name) != this->attrs_.end()) { + this->attrs_.erase(name); + } } // NOTICE(minqiyang): pybind11 will take the empty list in python as // the std::vector type in C++; so we have to change the attr's type diff --git a/paddle/fluid/operators/compat/conv2d.pbtxt b/paddle/fluid/operators/compat/conv2d.pbtxt index 8de061a3cc2..b18e0264992 100644 --- a/paddle/fluid/operators/compat/conv2d.pbtxt +++ b/paddle/fluid/operators/compat/conv2d.pbtxt @@ -6,12 +6,6 @@ def { inputs { name: "Filter" } - inputs { - name: "Bias" - } - inputs { - name: "ResidualData" - } outputs { name: "Output" } @@ -69,54 +63,10 @@ extra { name: "skip_quant" type: BOOLEAN } - attrs { - name: "fuse_relu_before_depthwise_conv" - type: BOOLEAN - } - attrs { - name: "fuse_relu" - type: BOOLEAN - } - attrs { - name: "fuse_activation" - type: STRING - } - attrs { - name: "fuse_alpha" - type: FLOAT - } - attrs { - name: "fuse_beta" - type: FLOAT - } attrs { name: "use_addto" type: BOOLEAN } - attrs { - name: "fuse_residual_connection" - type: BOOLEAN - } - attrs { - name: "Scale_in" - type: FLOAT - } - attrs { - name: "Scale_out" - type: FLOAT - } - attrs { - name: "Scale_in_eltwise" - type: FLOAT - } - attrs { - name: "Scale_weights" - type: FLOATS - } - attrs { - name: "force_fp32_output" - type: BOOLEAN - } attrs { name: "workspace_size_MB" type: INT diff --git a/paddle/fluid/operators/compat/depthwise_conv2d.pbtxt b/paddle/fluid/operators/compat/depthwise_conv2d.pbtxt index 1fbb99c03e8..ee04cd73dd7 100644 --- a/paddle/fluid/operators/compat/depthwise_conv2d.pbtxt +++ b/paddle/fluid/operators/compat/depthwise_conv2d.pbtxt @@ -6,12 +6,6 @@ def { inputs { name: "Filter" } - inputs { - name: "Bias" - } - inputs { - name: "ResidualData" - } outputs { name: "Output" } @@ -65,50 +59,10 @@ extra { name: "fuse_relu_before_depthwise_conv" type: BOOLEAN } - attrs { - name: "fuse_relu" - type: BOOLEAN - } - attrs { - name: "fuse_activation" - type: STRING - } - attrs { - name: "fuse_alpha" - 
type: FLOAT - } - attrs { - name: "fuse_beta" - type: FLOAT - } attrs { name: "use_addto" type: BOOLEAN } - attrs { - name: "fuse_residual_connection" - type: BOOLEAN - } - attrs { - name: "Scale_in" - type: FLOAT - } - attrs { - name: "Scale_out" - type: FLOAT - } - attrs { - name: "Scale_in_eltwise" - type: FLOAT - } - attrs { - name: "Scale_weights" - type: FLOATS - } - attrs { - name: "force_fp32_output" - type: BOOLEAN - } attrs { name: "workspace_size_MB" type: INT diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index e41270de650..710f57b4280 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -250,18 +250,6 @@ void Conv2DOpMaker::Make() { "H is the height of the filter, and W is the width of the filter. " "If the groups attribute is greater than 1, C equals the number of " "input image channels divided by the groups."); - AddInput("Bias", - "(Tensor) Bias to be added to each output of filter application." - "The format of output tensor is X (one-dimensional) of size equal" - "to the number of output channels. Only used with MKL-DNN.") - .AsDispensable() - .AsExtra(); - AddInput("ResidualData", - "(Tensor) Tensor with residual data " - "to which convolution output will be added." - "Used with fuse_residual_connection fusion.") - .AsDispensable() - .AsExtra(); AddOutput("Output", "(Tensor) The output tensor of convolution operator. " "It has same data fromat and data type as the Input."); diff --git a/paddle/fluid/operators/fused/conv_fusion_op.cc b/paddle/fluid/operators/fused/conv_fusion_op.cc index e50b42832f1..e6c2cda275a 100644 --- a/paddle/fluid/operators/fused/conv_fusion_op.cc +++ b/paddle/fluid/operators/fused/conv_fusion_op.cc @@ -33,6 +33,16 @@ namespace operators { class Conv2DFusionOpMaker : public Conv2DOpMaker { protected: void Apply() override { + AddInput("Bias", + "(Tensor) Bias to be added to each output of filter application." 
+ "The format of output tensor is X (one-dimensional) of size equal" + "to the number of output channels. Only used with MKL-DNN.") + .AsDispensable(); + AddInput("ResidualData", + "(Tensor) Tensor with residual data " + "to which convolution output will be added." + "Used with fuse_residual_connection fusion.") + .AsDispensable(); AddAttr( "activation", "The activation type can be 'identity', 'sigmoid', 'relu', 'relu6' " diff --git a/paddle/fluid/operators/fused/fused_conv2d_op.cc b/paddle/fluid/operators/fused/fused_conv2d_op.cc index 178c2a963e2..322bc6944af 100644 --- a/paddle/fluid/operators/fused/fused_conv2d_op.cc +++ b/paddle/fluid/operators/fused/fused_conv2d_op.cc @@ -23,6 +23,16 @@ namespace operators { class FusedConvOpMaker : public Conv2DOpMaker { protected: void Apply() override { + AddInput("Bias", + "(Tensor) Bias to be added to each output of filter application." + "The format of output tensor is X (one-dimensional) of size equal" + "to the number of output channels. Only used with MKL-DNN.") + .AsDispensable(); + AddInput("ResidualData", + "(Tensor) Tensor with residual data " + "to which convolution output will be added." + "Used with fuse_residual_connection fusion.") + .AsDispensable(); AddAttr( "mkldnn_data_type", "(string, default \"float32\"). 
Data type of mkldnn kernel") diff --git a/paddle/fluid/operators/ops_extra_info.h b/paddle/fluid/operators/ops_extra_info.h index 02624b9a49f..10ee3994b58 100644 --- a/paddle/fluid/operators/ops_extra_info.h +++ b/paddle/fluid/operators/ops_extra_info.h @@ -218,9 +218,7 @@ class ExtraInfoUtils { // TODO(chenweihang): move these extra inputs into op_compat.yaml std::unordered_map> - g_extra_input_names_map_ = {{"conv2d", {"Bias", "ResidualData"}}, - {"conv2d_transpose", {"Bias"}}, - {"conv2d_grad", {"Bias"}}}; + g_extra_input_names_map_ = {{"conv2d_transpose", {"Bias"}}}; std::vector empty_extra_input_names_; }; diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 80db4ae909b..c5eaf089b7c 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -242,11 +242,8 @@ - op : conv2d backward : conv2d_grad extra : - attrs : [bool is_test = false, bool use_cudnn = true, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false, - bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false, - str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f, bool use_addto = false, - bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f, - float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false, + attrs : [bool is_test = false, bool use_cudnn = true, bool use_mkldnn = false, bool use_addto = false, + str mkldnn_data_type = "float32", bool force_fp32_output = false, int workspace_size_MB = phi::backends::gpu::GetDefaultConvWorkspaceSizeLimitMB(), bool exhaustive_search = false] - op : conv2d_fusion @@ -602,6 +599,11 @@ extra : attrs : [bool use_mkldnn = false] +- op : fused_conv2d + extra : + attrs : [bool use_cudnn = false, float fuse_alpha = 0.0f, float fuse_beta = 0.0f, float Scale_in = 1.0f, + float Scale_out = 1.0f, float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}'] + - op 
: gather backward : gather_grad extra : diff --git a/paddle/phi/kernels/onednn/conv_kernel.cc b/paddle/phi/kernels/onednn/conv_kernel.cc index 1e54ba0337e..c2ed2c10410 100644 --- a/paddle/phi/kernels/onednn/conv_kernel.cc +++ b/paddle/phi/kernels/onednn/conv_kernel.cc @@ -41,29 +41,16 @@ void ConvKernel(const Context& dev_ctx, dev_ctx.GetDnnAttr("mkldnn_data_type")) == "bfloat16" : false; - const auto* bias = - dev_ctx.HasDnnInput("Bias") ? dev_ctx.GetDnnInput("Bias") : nullptr; - const auto* residual_param = dev_ctx.HasDnnInput("ResidualData") - ? dev_ctx.GetDnnInput("ResidualData") - : nullptr; - bool fuse_residual_conn = - dev_ctx.HasDnnAttr("fuse_residual_connection") - ? PADDLE_GET_CONST(bool, - dev_ctx.GetDnnAttr("fuse_residual_connection")) - : false; - const std::string& fuse_activation = - dev_ctx.HasDnnAttr("fuse_activation") - ? PADDLE_GET_CONST(std::string, dev_ctx.GetDnnAttr("fuse_activation")) - : ""; bool force_fp32_output = dev_ctx.HasDnnAttr("force_fp32_output") ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output")) : false; + ConvOnednn(dev_ctx, &input, &filter, - bias, - residual_param, + nullptr, + nullptr, strides, paddings, padding_algorithm, @@ -72,8 +59,8 @@ void ConvKernel(const Context& dev_ctx, data_format, is_test, is_BFLOAT16, - fuse_activation, - fuse_residual_conn, + "", + false, force_fp32_output, out); } diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py index 332b9e1e860..9f620d10905 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py @@ -104,6 +104,7 @@ class TestConv2DBF16Op(TestConv2DOp): } if self.fuse_residual: + self.op_type = "fused_conv2d" self.inputs['ResidualData'] = OpTest.np_dtype_to_fluid_dtype( convert_float_to_uint16(self.input_residual) ) diff --git 
a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py index def69dea569..ed34371fe06 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py @@ -158,6 +158,9 @@ class TestConv2DInt8Op(TestConv2DOp): input_residual ) + if self.fuse_activation != "" or self.fuse_residual: + self.op_type = "fused_conv2d" + self.attrs = { 'strides': self.stride, 'paddings': self.pad, @@ -341,6 +344,7 @@ class TestWithInput1x1Filter1x1(TestConv2DInt8Op): def init_data_type_with_fusion(self, input_dt, fuse_activation, fuse_residual): + self.op_type = "fused_conv2d" self.srctype = input_dt self.dsttype = np.uint8 if fuse_activation == "relu" else np.int8 diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py index 7fd7c867788..cb515755428 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py @@ -99,6 +99,13 @@ class TestConv2DMKLDNNOp(TestConv2DOp): output = np.minimum(np.maximum(output, 0), self.fuse_alpha).astype( self.dsttype ) + if ( + self.fuse_activation != "" + or self.fuse_bias + or self.fuse_residual_connection + ): + self.op_type = 'fused_conv2d' + output = output.astype(self.dtype) self.attrs['fuse_bias'] = self.fuse_bias diff --git a/python/paddle/static/quantization/quant_int8_mkldnn_pass.py b/python/paddle/static/quantization/quant_int8_mkldnn_pass.py index de04f66b3b8..f4ddcad8659 100644 --- a/python/paddle/static/quantization/quant_int8_mkldnn_pass.py +++ b/python/paddle/static/quantization/quant_int8_mkldnn_pass.py @@ -161,7 +161,7 @@ class QuantInt8MkldnnPass: } conv_op_node = graph.create_op_node( - op_type='conv2d', + op_type='fused_conv2d', attrs=attrs, inputs={'Input': 
input_var_node, 'Filter': weight_var_node}, outputs={'Output': output_var_node}, -- GitLab