diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 39e676dac850d8a638b064240793de35a55dd238..159e42a1ef910674f042736ee0002dc0dcaa9b5b 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -103,6 +103,7 @@ pass_library(delete_quant_dequant_linear_op_pass inference) pass_library(delete_assign_op_pass inference) pass_library(delete_dropout_op_pass inference) pass_library(delete_concat_op_pass inference) +pass_library(conv2d_trans_filter_dilations_nxn_to_1x1_pass inference) pass_library(preln_residual_bias_fuse_pass inference) pass_library(constant_folding_pass inference) pass_library(auto_mixed_precision_pass inference) diff --git a/paddle/fluid/framework/ir/conv2d_trans_filter_dilations_nxn_to_1x1_pass.cc b/paddle/fluid/framework/ir/conv2d_trans_filter_dilations_nxn_to_1x1_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..0cad2736c1e053b214972da367603c461dfb9fc9 --- /dev/null +++ b/paddle/fluid/framework/ir/conv2d_trans_filter_dilations_nxn_to_1x1_pass.cc @@ -0,0 +1,196 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "glog/logging.h" + +#include "paddle/fluid/framework/ir/conv2d_trans_filter_dilations_nxn_to_1x1_pass.h" +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct Conv2dLargeDilationsPattern : public PatternBase { + Conv2dLargeDilationsPattern(PDPattern* pattern, + const std::string& name_scope); + + PATTERN_DECL_NODE(conv2d); +}; + +Conv2dLargeDilationsPattern::Conv2dLargeDilationsPattern( + PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + pattern->NewNode(conv2d_repr()) + ->assert_is_op("conv2d") + ->assert_more([](Node* node) { + auto data_format = + node->Op()->GetAttrIfExists("data_format"); + if (data_format != "NCHW") return false; + auto dilations = + node->Op()->GetAttrIfExists>("dilations"); + if (dilations.size() != 2) return false; + return dilations[0] * dilations[1] > 1; + }); +} + +} // namespace patterns + +void Conv2dTransFilterDilationsNxNTo1x1Pass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + conv2d_dilation_trans(graph); +} +template +static void conv2d_dilation_trans_fn(const T* weights_data, + T* new_weights_data, + int kn, + int kc, + int kh, + int kw, + int new_kh, + int new_kw, + int dilation_h, + int dilation_w) { + for (int n = 0; n < kn; n++) { + for (int c = 0; c < kc; c++) { + for (int h = 0; h < kh; h++) { + auto h_offset = dilation_h * h; + for (int w = 0; w < kw; w++) { + auto w_offset = dilation_w * w; + auto new_offset = n * kc * new_kh * new_kw + c * new_kh * new_kw + + h_offset * new_kw + w_offset; + auto old_offset = n * kc * kh * kw + c * kh * kw + h * kw + w; + new_weights_data[new_offset] = weights_data[old_offset]; + } + } + } + } +} + +void Conv2dTransFilterDilationsNxNTo1x1Pass::conv2d_dilation_trans( + ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::Conv2dLargeDilationsPattern pattern(gpd.mutable_pattern(), + name_scope_); + int found_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle conv2d large dilation trans"; + GET_IR_NODE_FROM_SUBGRAPH(conv2d, conv2d, pattern); + auto* block = conv2d->Op()->Block(); + auto* scope = param_scope(); + auto weights_name = conv2d->Op()->Input("Filter")[0]; + auto dilations = + conv2d->Op()->GetAttrIfExists>("dilations"); + auto* weights = + scope->FindVar(weights_name)->GetMutable(); + auto weights_shape = weights->dims(); + int kh = weights_shape[2]; + int kw = weights_shape[3]; + int new_kh = dilations[0] * (kh - 1) + 1; + int new_kw = dilations[1] * (kw - 1) + 1; + // New weights + auto new_weights_name = weights_name + "_dilation_trans"; + auto* new_weights = + scope->Var(new_weights_name)->GetMutable(); + new_weights->Resize({weights_shape[0], weights_shape[1], new_kh, new_kw}); + if (weights->dtype() == phi::DataType::FLOAT32) { + auto weights_data = weights->mutable_data(platform::CPUPlace()); + auto* new_weights_data = + new_weights->mutable_data(platform::CPUPlace()); + memset(new_weights_data, 0, new_weights->numel() * sizeof(float)); + conv2d_dilation_trans_fn(weights_data, + new_weights_data, + weights_shape[0], + weights_shape[1], + kh, + kw, + new_kh, + new_kw, + dilations[0], + dilations[1]); + } else if (weights->dtype() == phi::DataType::FLOAT16) { + auto weights_data = + weights->mutable_data(platform::CPUPlace()); + auto* new_weights_data = + new_weights->mutable_data(platform::CPUPlace()); + memset(new_weights_data, + 0, + new_weights->numel() * sizeof(phi::dtype::float16)); + conv2d_dilation_trans_fn(weights_data, + new_weights_data, + weights_shape[0], + weights_shape[1], + kh, + kw, + new_kh, + new_kw, + dilations[0], + dilations[1]); + } else { + VLOG(3) + << "Transfilter only support float32/float16 dtype of weights -- do " + "nothing and break."; + return; // Only support fp32/fp16 dtype + } + + VarDesc new_weights_desc(new_weights_name); + new_weights_desc.SetPersistable(true); + new_weights_desc.SetShape(vectorize(new_weights->dims())); + new_weights_desc.SetDataType( + framework::TransToProtoVarType(new_weights->dtype())); + auto* new_weights_node = graph->CreateVarNode(&new_weights_desc); + auto* block_new_weights_desc = block->Var(new_weights_name); + block_new_weights_desc->SetPersistable(new_weights_desc.Persistable()); + block_new_weights_desc->SetShape(new_weights_desc.GetShape()); + block_new_weights_desc->SetDataType(new_weights_desc.GetDataType()); + // Update conv2d node + conv2d->Op()->SetAttr("dilations", std::vector({1, 1})); + conv2d->Op()->RenameInput(weights_name, new_weights_name); + IR_NODE_LINK_TO(new_weights_node, conv2d); + found_count++; + }; + gpd(graph, handler); + AddStatis(found_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(conv2d_trans_filter_dilations_nxn_to_1x1_pass, + paddle::framework::ir::Conv2dTransFilterDilationsNxNTo1x1Pass); +REGISTER_PASS_CAPABILITY(conv2d_trans_filter_dilations_nxn_to_1x1_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().LE( + "conv2d", 1)); diff --git a/paddle/fluid/framework/ir/conv2d_trans_filter_dilations_nxn_to_1x1_pass.h b/paddle/fluid/framework/ir/conv2d_trans_filter_dilations_nxn_to_1x1_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..a7a88a78d280dddb8fe49d27b52caa7fc0fe655e --- /dev/null +++ b/paddle/fluid/framework/ir/conv2d_trans_filter_dilations_nxn_to_1x1_pass.h @@ -0,0 +1,47 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { + +class Conv2dTransFilterDilationsNxNTo1x1Pass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + void conv2d_dilation_trans(ir::Graph* graph) const; + + const std::string name_scope_{ + "conv2d_trans_filter_dilations_nxn_to_1x1_pass"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 7aa0f44aceb4ea408d917ca428d7a2d5926e6937..5a639c75437f64b04ee9305512e844794ef56371 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -531,6 +531,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "reduce_ops_fuse_pass", "delete_cast_op_pass", "xpu_delete_cast_op_pass", + "conv2d_trans_filter_dilations_nxn_to_1x1_pass", "stack_fuse_pass", "fused_multi_transformer_xpu_pass", "relu6_fuse_pass", diff --git a/test/ir/inference/test_xpu_conv2d_trans_filter_dilations_nxn_to_1x1_pass.py b/test/ir/inference/test_xpu_conv2d_trans_filter_dilations_nxn_to_1x1_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..74348e5f92f7c889347975dfaf308d67ee225c86 --- /dev/null +++ b/test/ir/inference/test_xpu_conv2d_trans_filter_dilations_nxn_to_1x1_pass.py @@ -0,0 +1,164 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial + +import hypothesis.strategies as st +import numpy as np +from auto_scan_test import PassAutoScanTest +from program_config import OpConfig, ProgramConfig, TensorConfig + + +class TestConv2dTransFilterDilationsNxNTo1x1PassPass(PassAutoScanTest): + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_xpu=True) + yield config, ["conv2d"], (1e-3, 1e-3) + + def is_program_valid(self, prog_config): + paddings = prog_config.ops[0].attrs["paddings"] + strides = prog_config.ops[0].attrs["strides"] + groups = prog_config.ops[0].attrs["groups"] + padding_algorithm = prog_config.ops[0].attrs["padding_algorithm"] + dilations = prog_config.ops[0].attrs["dilations"] + data_format = prog_config.ops[0].attrs["data_format"] + filter_shape = prog_config.weights["conv2d_weight"].shape + input_shape = prog_config.inputs["conv2d_input"].shape + if data_format != "NCHW": + return False + if padding_algorithm == "VALID": + if ( + (input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) + / strides[0] + + 1 + ) < 1 or ( + (input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) + / strides[1] + + 1 + ) < 1: + return False + if padding_algorithm == "EXPLICIT": + if ( + ( + input_shape[2] + + 2 * paddings[0] + - (dilations[0] * (filter_shape[2] - 1) + 1) + ) + / strides[0] + + 1 + ) < 1 or ( + ( + input_shape[3] + + 2 * paddings[1] + - (dilations[1] * (filter_shape[3] - 1) + 1) + ) + / strides[1] + + 1 + ) < 1: + return False + if data_format == "NCHW": + if input_shape[1] != filter_shape[1] * groups: + return False + if filter_shape[0] % groups != 0: + return False + return True + + def sample_program_config(self, draw): + data_format = draw(st.sampled_from(["NCHW"])) + + x_shape = draw( + st.lists( + st.integers(min_value=12, max_value=12), min_size=4, max_size=4 + ) + ) + x_shape[1] = draw(st.integers(min_value=1, max_value=10)) + + # 3. Generate legal shape of input:Y of conv2d + w_shape = draw( + st.lists( + st.integers(min_value=3, max_value=3), min_size=4, max_size=4 + ) + ) + if data_format == "NCHW": + w_shape[1] = x_shape[1] + + padding_algorithm = draw(st.sampled_from(["EXPLICIT", "VALID"])) + + groups = draw(st.integers(min_value=1, max_value=1)) + + paddings = draw( + st.lists( + st.integers(min_value=1, max_value=1), min_size=2, max_size=2 + ) + ) + strides = draw( + st.lists( + st.integers(min_value=1, max_value=1), min_size=2, max_size=2 + ) + ) + dilations = draw( + st.lists( + st.integers(min_value=1, max_value=5), min_size=2, max_size=2 + ) + ) + + def generate_data(shape): + return np.random.random(shape).astype(np.float32) + + # Here we will compose a program + # Still has some risks that the program is invalid or cause bug while running + # Use function `is_program_valid` to filter the invalid programs before running + # Use function `add_skip_pass_case` to ignore the programs even if they cause bug while runing + conv2d_op = OpConfig( + "conv2d", + inputs={ + "Input": ["conv2d_input"], + "Filter": ["conv2d_weight"], + }, + outputs={"Output": ["conv2d_out"]}, + data_format=data_format, + dilations=dilations, + padding_algorithm=padding_algorithm, + groups=groups, + paddings=paddings, + strides=strides, + has_bias=False, + ) + + program_config = ProgramConfig( + ops=[conv2d_op], + inputs={ + "conv2d_input": TensorConfig( + data_gen=partial(generate_data, x_shape) + ), + }, + weights={ + "conv2d_weight": TensorConfig( + data_gen=partial(generate_data, w_shape) + ), + }, + outputs=["conv2d_out"], + ) + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=25, + passes=["conv2d_trans_filter_dilations_nxn_to_1x1_pass"], + ) + + +if __name__ == "__main__": + unittest.main()