diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index a3b49476d820f2871d8d9edb6b3d665ebb05e52e..283e79b81e7c678e8a4fcc3becefa5279eacb38b 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -118,6 +118,7 @@ if(WITH_MKLDNN) pass_library(mkldnn_placement_pass base DEPS placement_pass_base DIR mkldnn) pass_library(mkldnn_inplace_pass inference DEPS mkldnn_placement_pass op_registry elementwise_add_op gelu_op activation_op softmax_op softmax DIR mkldnn) pass_library(depthwise_conv_mkldnn_pass base DIR mkldnn) + pass_library(conv_affine_channel_mkldnn_fuse_pass inference DIR mkldnn) pass_library(conv_bias_mkldnn_fuse_pass inference DIR mkldnn) pass_library(conv_activation_mkldnn_fuse_pass inference DIR mkldnn) pass_library(conv_concat_relu_mkldnn_fuse_pass inference DIR mkldnn) diff --git a/paddle/fluid/framework/ir/mkldnn/conv_affine_channel_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_affine_channel_mkldnn_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..50e751e02dfa03f920a662884081579f0171fed4 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/conv_affine_channel_mkldnn_fuse_pass.cc @@ -0,0 +1,272 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/mkldnn/conv_affine_channel_mkldnn_fuse_pass.h" + +#include + +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { + +class Node; + +#define GET_CONV_BN_NODES(pattern_name) \ + /* OPERATORS */ \ + GET_IR_NODE_FROM_SUBGRAPH(conv, conv, pattern_name); \ + GET_IR_NODE_FROM_SUBGRAPH(affine_channel, affine_channel, pattern_name); \ + /* CONV inputs */ \ + GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, pattern_name); \ + /* CONV outputs */ \ + GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, pattern_name); \ + /* Affine Channel inputs */ \ + GET_IR_NODE_FROM_SUBGRAPH(ac_scale, ac_scale, pattern_name); \ + GET_IR_NODE_FROM_SUBGRAPH(ac_bias, ac_bias, pattern_name); \ + /* Affine channel outputs */ \ + GET_IR_NODE_FROM_SUBGRAPH(ac_out, ac_out, pattern_name); /* Out */ + +void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight, + const ir::Node& ac_scale, + const LoDTensor& ac_bias_tensor, + LoDTensor* eltwise_y_in_tensor) { + using EigenVectorArrayMap = + Eigen::Map>; + using ConstEigenVectorArrayMap = + Eigen::Map>; + using EigenMatrixArrayMap = Eigen::Map< + Eigen::Array>; + + // Re-compute bias of conv2d from AffineChannel + PADDLE_ENFORCE_EQ( + eltwise_y_in_tensor->dims(), ac_bias_tensor.dims(), + platform::errors::InvalidArgument( + "Tensor elementwise y(%d) and activation bias(%d) must have same " + "dimension.", + eltwise_y_in_tensor->dims().size(), ac_bias_tensor.dims().size())); + + auto* scale_tensor = scope->FindVar(ac_scale.Name())->GetMutable(); + + ConstEigenVectorArrayMap scale_array(scale_tensor->data(), + scale_tensor->numel(), 1); + ConstEigenVectorArrayMap ac_bias_array(ac_bias_tensor.data(), + ac_bias_tensor.numel(), 1); + + EigenVectorArrayMap eltwise_y_in_array( + eltwise_y_in_tensor->mutable_data(platform::CPUPlace()), + eltwise_y_in_tensor->numel(), 1); + + eltwise_y_in_array = (eltwise_y_in_array * scale_array) + ac_bias_array; + + // Re-compute weight of conv2d from AffineChannel + auto* weights = scope->FindVar(conv_weight->Name())->GetMutable(); + auto weights_shape = weights->dims(); + auto weights_shape_2d = phi::flatten_to_2d(weights_shape, 1); + auto* weights_data = weights->mutable_data(platform::CPUPlace()); + + EigenMatrixArrayMap weights_array_2d(weights_data, weights_shape_2d[0], + weights_shape_2d[1]); + + weights_array_2d.colwise() *= scale_array; + + // Check for subnormal values that slows down convolution execution + for (int i = 0; i < weights->numel(); ++i) { + if (std::fpclassify(weights_data[i]) == FP_SUBNORMAL) weights_data[i] = 0; + } +} + +ConvAffineChannelFusePass::ConvAffineChannelFusePass() { + AddOpCompat(OpCompat("conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddInput("ResidualData") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "AnyLayout"}) + .End(); + + AddOpCompat(OpCompat("affine_channel")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("data_layout") + .IsStringIn({"NCHW", "AnyLayout"}) + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(1) + .End(); +} + +void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + FusePassBase::Init(name_scope_, graph); + + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); + + GraphPatternDetector gpd; + auto* conv_input = + gpd.mutable_pattern() + ->NewNode(patterns::PDNodeName(name_scope_, "conv_input")) + ->AsInput() + ->assert_is_op_input("conv2d", "Input"); + patterns::ConvAffineChannel conv_ac_pattern(gpd.mutable_pattern(), + name_scope_); + conv_ac_pattern(conv_input, false /*with_eltwise_add*/); + + int found_conv_ac_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "ConvAffineChannelFusePass in op compat failed."; + return; + } + + VLOG(4) << "handle ConvAffineChannel fuse"; + + GET_CONV_BN_NODES(conv_ac_pattern); + + auto data_format = conv->Op()->GetAttrIfExists("data_format"); + if (data_format == "AnyLayout") { + LOG_FIRST_N(WARNING, 1) << "conv_affine_channel_fuse_pass is enabled, " + "it's wrong if data_format of conv is not " + "NCHW."; + } + + // Get affine_channel bias for resizing eltwise_y! + auto* ac_bias_tensor = + scope->FindVar(ac_bias->Name())->GetMutable(); + + // Create eltwise_y (conv bias) variable + VarDesc eltwise_y_in_desc( + patterns::PDNodeName(name_scope_, "eltwise_y_in")); + // Set shape && datatype manually + eltwise_y_in_desc.SetShape(phi::vectorize(ac_bias_tensor->dims())); + eltwise_y_in_desc.SetDataType( + framework::TransToProtoVarType(ac_bias_tensor->dtype())); + eltwise_y_in_desc.SetLoDLevel(ac_bias->Var()->GetLoDLevel()); + eltwise_y_in_desc.SetPersistable(true); + + // Initialize eltwise_y + auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc); + auto* eltwise_y_in_tensor = + scope->Var(eltwise_y_in_node->Name())->GetMutable(); + eltwise_y_in_tensor->Resize(ac_bias_tensor->dims()); + std::fill_n(eltwise_y_in_tensor->mutable_data(platform::CPUPlace()), + eltwise_y_in_tensor->numel(), 0.0f); + + // update weights and biases + recompute_bias_and_weights(scope, conv_weight, *ac_scale, *ac_bias_tensor, + eltwise_y_in_tensor); + + // create an elementwise add node. + OpDesc desc; + desc.SetInput("X", std::vector({conv_out->Name()})); + desc.SetInput("Y", std::vector({eltwise_y_in_node->Name()})); + desc.SetOutput("Out", std::vector({ac_out->Name()})); + desc.SetType("elementwise_add"); + desc.SetAttr("axis", 1); + desc.SetAttr("use_mkldnn", conv->Op()->GetAttrIfExists("use_mkldnn")); + + auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied. + + GraphSafeRemoveNodes(graph, {ac_scale, ac_bias, affine_channel}); + + IR_NODE_LINK_TO(conv_out, eltwise_op); + IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op); + IR_NODE_LINK_TO(eltwise_op, ac_out); + found_conv_ac_count++; + }; + + gpd(graph, handler); + + AddStatis(found_conv_ac_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(conv_affine_channel_mkldnn_fuse_pass, + paddle::framework::ir::ConvAffineChannelFusePass); + +REGISTER_PASS_CAPABILITY(conv_affine_channel_mkldnn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("conv2d", 1) + .EQ("affine_channel", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_affine_channel_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_affine_channel_mkldnn_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..075b6d7220316c3dfd18d10e709ff2839ca09a82 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/conv_affine_channel_mkldnn_fuse_pass.h @@ -0,0 +1,44 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * Fuse the Conv and ConvAffineChannel. + */ +class Graph; + +class ConvAffineChannelFusePass : public FusePassBase { + public: + ConvAffineChannelFusePass(); + virtual ~ConvAffineChannelFusePass() {} + + protected: + void ApplyImpl(ir::Graph*) const override; + const std::string name_scope_{"conv_affine_channel_mkldnn_fuse"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 01988d5f539dc44559b4f231cdee039d3f331709..f59494628ad7e5327523b54022323327f161b773 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -282,7 +282,8 @@ void CpuPassStrategy::EnableMKLDNN() { "depthwise_conv_mkldnn_pass", // "conv_bn_fuse_pass", // Execute BN passes again to "conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order - "conv_transpose_bn_fuse_pass", // + "conv_affine_channel_mkldnn_fuse_pass", // + "conv_transpose_bn_fuse_pass", // "conv_transpose_eltwiseadd_bn_fuse_pass", // "conv_bias_mkldnn_fuse_pass", // "conv_transpose_bias_mkldnn_fuse_pass", diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py index 9d9fbd39a5767ffe72ad579df2d31ac66eda2234..e8a9300635e2ca3fce8f240ad4e72b26b84313a0 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py @@ -426,6 +426,7 @@ class Quant2Int8MkldnnPass(object): graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass') graph = self._apply_pass(graph, 'conv_bn_fuse_pass') graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass') + graph = self._apply_pass(graph, 'conv_affine_channel_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'conv_transpose_bn_fuse_pass') graph = self._apply_pass(graph, 'conv_transpose_eltwiseadd_bn_fuse_pass') diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index 661fbbc7759c6356f7014dbc763b0ddcd212e2d1..4717dfa1eab526da822781bac362e7516c9cf270 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -143,5 +143,6 @@ if (WITH_MKLDNN) set_tests_properties(test_mkldnn_conv_mish_fuse_pass PROPERTIES TIMEOUT 300) set_tests_properties(test_mkldnn_fc_mish_fuse_pass PROPERTIES TIMEOUT 300) set_tests_properties(test_mkldnn_fc_elementwise_add_fuse_pass PROPERTIES TIMEOUT 120) + set_tests_properties(test_mkldnn_conv_affine_channel_fuse_pass PROPERTIES TIMEOUT 60) endif() endif() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_affine_channel_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_affine_channel_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..a35b75e69f812ceae882fcb2df46ea78b242e76d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_affine_channel_fuse_pass.py @@ -0,0 +1,158 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import PassAutoScanTest, IgnoreReasons +from program_config import TensorConfig, ProgramConfig, OpConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume, reproduce_failure +import hypothesis.strategies as st + + +class TestConvAffineChannelFusePass(PassAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_config(self, draw): + padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"])) + groups = draw(st.integers(min_value=1, max_value=3)) + data_format = draw(st.sampled_from(["NCHW", "NHWC"])) + axis = draw(st.sampled_from([1])) + filter_channel = draw(st.integers(min_value=1, max_value=16)) * 4 + filter_size = draw(st.integers(min_value=1, max_value=4)) + in_channel = groups * filter_channel + out_channel_factor = draw(st.integers(min_value=1, max_value=16)) * 4 + out_channel = groups * out_channel_factor + batch_size = draw(st.integers(min_value=1, max_value=4)) + dilations = draw( + st.lists( + st.integers( + min_value=1, max_value=2), min_size=2, max_size=2)) + paddings = draw( + st.lists( + st.integers( + min_value=0, max_value=2), min_size=2, max_size=2)) + strides = draw( + st.lists( + st.integers( + min_value=1, max_value=2), min_size=2, max_size=2)) + has_bias = draw(st.booleans()) + + x_shape = [ + batch_size, in_channel, 64, 64 + ] if data_format == "NCHW" else [batch_size, 64, 64, in_channel] + w_shape = [out_channel, filter_channel, filter_size, filter_size] + scale_shape = [out_channel] + bias_shape = [out_channel] + + def generate_input(): + return np.random.random(x_shape).astype(np.float32) + + def generate_weight(): + return np.random.random(w_shape).astype(np.float32) + + def generate_bias(): + return np.random.random(bias_shape).astype(np.float32) + + def generate_scale_bias(): + return np.random.random(bias_shape).astype(np.float32) + + conv2d_op = OpConfig( + "conv2d", + inputs={ + "Input": ["input_data"], + "Filter": ["conv2d_weight"], + }, + outputs={"Output": ["conv_output"]}, + data_format=data_format, + dilations=dilations, + padding_algorithm=padding_algorithm, + groups=groups, + paddings=paddings, + strides=strides, + has_bias=has_bias, + is_test=True) + ac_op = OpConfig( + "affine_channel", + inputs={ + "X": ["conv_output"], + "Scale": ["affine_channel_scale"], + "Bias": ["affine_channel_bias"] + }, + outputs={"Out": ["affine_channel_ouput"]}, + data_layout=data_format) + if has_bias == True: + conv2d_op.inputs["Bias"] = ["conv2d_bias"] + ops = [conv2d_op, ac_op] + + program_config = ProgramConfig( + ops=ops, + inputs={ + "input_data": TensorConfig(data_gen=partial(generate_input)), + }, + weights={ + "conv2d_weight": + TensorConfig(data_gen=partial(generate_weight)), + "conv2d_bias": TensorConfig(data_gen=partial(generate_bias)), + "affine_channel_scale": + TensorConfig(data_gen=partial(generate_scale_bias)), + "affine_channel_bias": + TensorConfig(data_gen=partial(generate_scale_bias)), + }, + outputs=["affine_channel_ouput"]) + if has_bias == True: + program_config.weights["conv2d_bias"] = TensorConfig( + data_gen=partial(generate_bias)) + return program_config + + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_mkldnn=True) + yield config, ['conv2d', 'elementwise_add'], (1e-4, 1e-4) + + def add_ignore_pass_case(self): + # If the problem has been fixed, the judgment + # in is_program_valid needs to be deleted!!! + def teller1(program_config, predictor_config): + if program_config.ops[0].attrs['data_format'] == "NHWC": + return True + return False + + # mkldnn Output has diff with bias! + def teller2(program_config, predictor_config): + return predictor_config.mkldnn_enabled() and program_config.ops[ + 0].attrs['has_bias'] == True + + self.add_ignore_check_case( + teller1, IgnoreReasons.PASS_ACCURACY_ERROR, + "The output format of conv2d is wrong when data_format attribute is NHWC, \ + because currently its fused op (Conv2DFusion) only supports data format of channel first (NCHW)." + ) + + self.add_ignore_check_case( + teller2, IgnoreReasons.PASS_ACCURACY_ERROR, + "Currently mkldnn Output has diff with bias!") + + def test(self): + self.run_and_statis( + quant=False, + passes=["conv_affine_channel_mkldnn_fuse_pass"], ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 0c57c527bfffc66ae7a552673a5991bef92dd893..5070ea2ef06a3e06efc2bf4f2a4d9e79aff6cfbf 100755 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -659,6 +659,7 @@ STATIC_MODE_TESTING_LIST = [ 'test_mkldnn_matmul_transpose_reshape_fuse_pass', 'test_mkldnn_scale_matmul_fuse_pass', 'test_mkldnn_inplace_fuse_pass', + 'test_mkldnn_conv_affine_channel_fuse_pass', 'test_batch_fc_op', 'test_c_comm_init_all_op', 'test_conv2d_fusion_op',