diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 4b13152f554946a52137989456f8600f91957c40..7a0093c3be8a80c45bb20bf302a08a17529d34fc 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -240,6 +240,8 @@ if(WITH_XPU) pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(redundant_squeeze_unsqueeze_elimination_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + pass_library(conv2d_transpose_xpu_fuse_pass inference DIR xpu DEPS + ${XPU_PASS_DEPS}) pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc index 40f7fae13a0246113c971a7b7ae6af4a241db7fe..6e12cf00e903bc7410e642e31e725a2291b4d9a0 100644 --- a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc +++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc @@ -188,8 +188,10 @@ void AutoMixedPrecisionPass::SetDefaultBlacklist() const { "c_softmax_with_cross_entropy", "cross_entropy", "cross_entropy2", +#ifndef PADDLE_WITH_XPU // slower than fp32 "conv2d_transpose", +#endif // default fp32 can avoid return inf when the sum value large than 65504 "reduce_sum", }); diff --git a/paddle/fluid/framework/ir/xpu/conv2d_transpose_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/conv2d_transpose_xpu_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..784d5d4ec029f88e22e21ea5ebaa452eb078ef60 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/conv2d_transpose_xpu_fuse_pass.cc @@ -0,0 +1,495 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "glog/logging.h" +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/ir/xpu/quant_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct Conv2dTransposeXPUPattern : public PatternBase { + Conv2dTransposeXPUPattern(PDPattern* pattern, + const std::string& name_scope, + const std::string& act_type, + bool with_ew_bias, + bool with_bn); + // operator + PATTERN_DECL_NODE(conv); + PATTERN_DECL_NODE(ew_bias_add); + PATTERN_DECL_NODE(bn); + PATTERN_DECL_NODE(act); + // conv param + PATTERN_DECL_NODE(input); + PATTERN_DECL_NODE(conv_filter); + PATTERN_DECL_NODE(conv_out); + // ew param + PATTERN_DECL_NODE(ew_bias_add_y); + PATTERN_DECL_NODE(ew_bias_add_out); + // bn param + PATTERN_DECL_NODE(bn_bias); + PATTERN_DECL_NODE(bn_mean); + PATTERN_DECL_NODE(bn_scale); + PATTERN_DECL_NODE(bn_var); + PATTERN_DECL_NODE(bn_out); + PATTERN_DECL_NODE(bn_var_out); + PATTERN_DECL_NODE(bn_mean_out); + PATTERN_DECL_NODE(bn_saved_var); + PATTERN_DECL_NODE(bn_saved_mean); + // act param + PATTERN_DECL_NODE(act_out); + + private: + std::string act_type_; + bool with_bn_; + bool with_ew_bias_; +}; + +Conv2dTransposeXPUPattern::Conv2dTransposeXPUPattern( + PDPattern* pattern, + const std::string& name_scope, + const std::string& act_type, + bool with_ew_bias, + bool with_bn) + : PatternBase(pattern, name_scope, name_scope), + act_type_(act_type), + with_bn_(with_bn), + with_ew_bias_(with_ew_bias) { + // deconv op + auto conv = pattern->NewNode(conv_repr())->assert_is_op("conv2d_transpose"); + auto input = pattern->NewNode(input_repr()) + ->assert_is_op_input("conv2d_transpose", "Input") + ->AsInput() + ->assert_more([](Node* node) { + return node->Var()->GetShape().size() == 4; + }); + auto conv_filter = pattern->NewNode(conv_filter_repr()) + ->assert_is_op_input("conv2d_transpose", "Filter") + ->AsInput(); + auto conv_out = pattern->NewNode(conv_out_repr()) + ->assert_is_op_output("conv2d_transpose", "Output") + ->assert_has_n_outputs(1); + conv->LinksFrom({input, conv_filter}).LinksTo({conv_out}); + + // elementwise op + PDNode* ew_bias_add = nullptr; + PDNode* ew_bias_add_y = nullptr; + PDNode* ew_bias_add_out = nullptr; + if (with_ew_bias_) { + conv_out->assert_is_op_input("elementwise_add", "X"); + ew_bias_add_y = pattern->NewNode(ew_bias_add_y_repr()) + ->assert_is_op_input("elementwise_add", "Y") + ->assert_is_persistable_var() + ->assert_has_n_outputs(1) + ->assert_more([](Node* node) { + return node->Var()->GetShape().size() == 1; + }); + ew_bias_add = + pattern->NewNode(ew_bias_add_repr())->assert_is_op("elementwise_add"); + ew_bias_add_out = pattern->NewNode(ew_bias_add_out_repr()) + ->assert_is_op_output("elementwise_add", "Out"); + if (with_bn_ || !act_type_.empty()) { + ew_bias_add_out->assert_has_n_outputs(1); + } + ew_bias_add->LinksFrom({conv_out, ew_bias_add_y}) + .LinksTo({ew_bias_add_out}); + } else { + ew_bias_add_out = conv_out; + } + + // batch_norm op + PDNode* bn = nullptr; + PDNode* bn_bias = nullptr; + PDNode* bn_mean = nullptr; + PDNode* bn_scale = nullptr; + PDNode* bn_var = nullptr; + PDNode* bn_out = nullptr; + PDNode* bn_mean_out = nullptr; + PDNode* bn_saved_mean = nullptr; + PDNode* bn_var_out = nullptr; + PDNode* bn_saved_var = nullptr; + if (with_bn_) { + ew_bias_add_out->assert_is_op_input("batch_norm", "X"); + bn_bias = pattern->NewNode(bn_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("batch_norm", "Bias") + ->assert_has_n_outputs(1); + bn_mean = pattern->NewNode(bn_mean_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("batch_norm", "Mean") + ->assert_has_n_outputs(1); + bn_scale = pattern->NewNode(bn_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("batch_norm", "Scale") + ->assert_has_n_outputs(1); + bn_var = pattern->NewNode(bn_var_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("batch_norm", "Variance") + ->assert_has_n_outputs(1); + bn = pattern->NewNode(bn_repr())->assert_is_op("batch_norm"); + bn_out = + pattern->NewNode(bn_out_repr())->assert_is_op_output("batch_norm", "Y"); + if (!act_type_.empty()) { + bn_out->assert_has_n_outputs(1); + } + bn_mean_out = pattern->NewNode(bn_mean_out_repr()) + ->assert_is_op_output("batch_norm", "MeanOut"); + bn_saved_mean = pattern->NewNode(bn_saved_mean_repr()) + ->assert_is_op_output("batch_norm", "SavedMean"); + bn_var_out = pattern->NewNode(bn_var_out_repr()) + ->assert_is_op_output("batch_norm", "VarianceOut"); + bn_saved_var = pattern->NewNode(bn_saved_var_repr()) + ->assert_is_op_output("batch_norm", "SavedVariance"); + bn->LinksFrom({ew_bias_add_out, bn_bias, bn_mean, bn_scale, bn_var}) + .LinksTo( + {bn_out, bn_mean_out, bn_var_out, bn_saved_mean, bn_saved_var}); + } else { + bn_out = ew_bias_add_out; + } + + // act + PDNode* act = nullptr; + PDNode* act_out = nullptr; + if (!act_type_.empty()) { + bn_out->assert_is_op_input(act_type_, "X"); + act = pattern->NewNode(act_repr())->assert_is_op(act_type_); + act_out = + pattern->NewNode(act_out_repr())->assert_is_op_output(act_type_, "Out"); + act->LinksFrom({bn_out}).LinksTo({act_out}); + } else { + act_out = bn_out; + } + act_out->AsOutput(); +} +} // namespace patterns + +/* fuse conv2d block in resnet50-like model to xpu_conv2d op */ +/* For example: */ +/* graph[1]: sub block */ +/* in_Input */ +/* | */ +/* | */ +/* conv2d_transpose----in_Filter */ +/* | */ +/* | */ +/* elementwise_add -----ew_add */ +/* | */ +/* | */ +/* batch_norm ------in_Bias */ +/* | */ +/* | */ +/* act */ +/* | */ +/* | */ +/* out_Out */ +/* */ +class Conv2dTransposeXPUFusePass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + int ApplyImpl(ir::Graph* graph, + const std::string& act_type, + bool with_ew_bias, + bool with_bn) const; + + const std::string name_scope_{"conv2d_transpose_xpu_fuse_pass"}; +}; + +void Conv2dTransposeXPUFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + + int found_subgraph_count = 0; + for (auto with_bn : {true, false}) { + for (auto with_ew_bias : {true, false}) { + for (auto act_type : {"relu", ""}) { + found_subgraph_count += + ApplyImpl(graph, act_type, with_ew_bias, with_bn); + } + } + } + AddStatis(found_subgraph_count); +} + +int Conv2dTransposeXPUFusePass::ApplyImpl(ir::Graph* graph, + const std::string& act_type, + bool with_ew_bias, + bool with_bn) const { + GraphPatternDetector gpd; + patterns::Conv2dTransposeXPUPattern pattern( + gpd.mutable_pattern(), name_scope_, act_type, with_ew_bias, with_bn); + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle Conv2dTransposeXPUFusePass fuse"; + /* declare operator node's name */ + GET_IR_NODE(conv); + GET_IR_NODE(ew_bias_add); + GET_IR_NODE(bn); + GET_IR_NODE(act); + /* declare variable node's name*/ + GET_IR_NODE(input); + GET_IR_NODE(conv_filter); + GET_IR_NODE(conv_out); + GET_IR_NODE(ew_bias_add_y); + GET_IR_NODE(ew_bias_add_out); + GET_IR_NODE(bn_bias); + GET_IR_NODE(bn_mean); + GET_IR_NODE(bn_scale); + GET_IR_NODE(bn_var); + GET_IR_NODE(bn_out); + GET_IR_NODE(bn_var_out); + GET_IR_NODE(bn_mean_out); + GET_IR_NODE(bn_saved_var); + GET_IR_NODE(bn_saved_mean); + GET_IR_NODE(act_out); + auto* block = conv->Op()->Block(); + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); + + // recompute bias and weight for conv2d_transpose_xpu op + auto* filter_t = + scope->FindVar(conv_filter->Name())->GetMutable(); + // conv_filter fp16 --> fp32 + auto tensor_type = filter_t->dtype(); + if (tensor_type == phi::DataType::FLOAT16) { + CastToFp32(filter_t, nullptr); + } + auto filter_dims = filter_t->dims(); + bool has_bias = with_bn || with_ew_bias; + Node* fusion_bias_node = nullptr; + int groups = PADDLE_GET_CONST(int, conv->Op()->GetAttr("groups")); + int out_c = filter_dims[1] * groups; + + // ew bias + if (with_ew_bias) { + auto* ew_bias_add_y_t = + scope->FindVar(ew_bias_add_y->Name())->GetMutable(); + auto ew_bias_add_y_dims = ew_bias_add_y_t->dims(); + PADDLE_ENFORCE_EQ(out_c, + ew_bias_add_y_dims[0], + platform::errors::InvalidArgument( + "the shape[%d] of elewise bias tensor " + "must equal out_channel[%d] of conv", + ew_bias_add_y_dims[0], + out_c)); + PrepareBias(graph, scope, block, ew_bias_add_y, &fusion_bias_node); + } + // bn + if (with_bn) { + auto bn_bias_t = + scope->Var(bn_bias->Name())->GetMutable(); + PADDLE_ENFORCE_EQ(out_c, + bn_bias_t->dims()[0], + platform::errors::InvalidArgument( + "the shape[%d] of bn bias tensor " + "must equal out_channel[%d] of conv", + bn_bias_t->dims()[0], + out_c)); + auto bn_scale_t = + scope->Var(bn_scale->Name())->GetMutable(); + auto bn_mean_t = + scope->Var(bn_mean->Name())->GetMutable(); + auto bn_var_t = + scope->Var(bn_var->Name())->GetMutable(); + float* filter_ptr = filter_t->data(); + float* bn_scale_ptr = bn_scale_t->data(); + float* bn_bias_ptr = bn_bias_t->data(); + float* bn_mean_ptr = bn_mean_t->data(); + float* bn_var_ptr = bn_var_t->data(); + auto mean_len = bn_mean_t->numel(); // oc + + float epsilon = PADDLE_GET_CONST(float, bn->Op()->GetAttr("epsilon")); + // bias + if (fusion_bias_node) { + auto fusion_bias_t = scope->Var(fusion_bias_node->Name()) + ->GetMutable(); + float* fusion_bias_ptr = fusion_bias_t->data(); + for (int i = 0; i < mean_len; ++i) { + bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); + fusion_bias_ptr[i] = + bn_bias_ptr[i] + + (fusion_bias_ptr[i] - bn_mean_ptr[i]) * bn_scale_ptr[i]; + } + } else { + PrepareBias(graph, scope, block, bn_bias, &fusion_bias_node); + auto fusion_bias_t = scope->Var(fusion_bias_node->Name()) + ->GetMutable(); + float* fusion_bias_ptr = fusion_bias_t->data(); + for (int i = 0; i < mean_len; ++i) { + bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); + fusion_bias_ptr[i] += (0.0f - bn_mean_ptr[i]) * bn_scale_ptr[i]; + } + } + // compute new conv_weight, weight is ic-oc/g-h-w + int cout_group = filter_dims[1]; + int cin_group = filter_dims[0] / groups; + int c_size = cout_group * filter_dims[2] * filter_dims[3]; + int hw = filter_dims[2] * filter_dims[3]; + for (int g = 0; g < groups; g++) { + for (int k = 0; k < cin_group; ++k) { + for (int i = 0; i < cout_group; ++i) { + auto ptr_row = + filter_ptr + g * cin_group * c_size + k * c_size + i * hw; + for (int j = 0; j < hw; ++j) { + ptr_row[j] *= bn_scale_ptr[g * cout_group + i]; + } + } + } + } + } + // filter max + Node* filter_int16 = nullptr; + Node* filter_max = nullptr; + PrepareWeight( + graph, scope, block, conv_filter, &filter_int16, &filter_max, false); + // output && output max + std::string conv2d_xpu_out_name; + if (!act_type.empty()) { + conv2d_xpu_out_name = act_out->Name(); + } else if (with_bn) { + conv2d_xpu_out_name = bn_out->Name(); + } else if (with_ew_bias) { + conv2d_xpu_out_name = ew_bias_add_out->Name(); + } else { + conv2d_xpu_out_name = conv_out->Name(); + } + std::string conv_out_max_name = conv2d_xpu_out_name + "_max"; + VarDesc conv_out_max_desc(conv_out_max_name); + Node* conv2d_xpu_out_max = graph->CreateVarNode(&conv_out_max_desc); + // Generate conv2d_xpu op + framework::OpDesc conv2d_xpu_op_desc(block); + // set input&output var + conv2d_xpu_op_desc.SetType("conv2d_transpose_xpu"); + conv2d_xpu_op_desc.SetInput("x", {input->Name()}); + conv2d_xpu_op_desc.SetInput("filter", {filter_int16->Name()}); + conv2d_xpu_op_desc.SetInput("filter_max", {filter_max->Name()}); + conv2d_xpu_op_desc.SetOutput("out", {conv2d_xpu_out_name}); + conv2d_xpu_op_desc.SetOutput("out_max", {conv_out_max_name}); + // set fusion_bias input node + if (has_bias) { + conv2d_xpu_op_desc.SetInput("bias", {fusion_bias_node->Name()}); + } + conv2d_xpu_op_desc.SetAttr("has_bias", has_bias); + // set attrs of conv2d_xpu + if (!act_type.empty()) { + conv2d_xpu_op_desc.SetAttr("with_act", true); + } else { + conv2d_xpu_op_desc.SetAttr("with_act", false); + } + conv2d_xpu_op_desc.SetAttr("act_type", act_type); + conv2d_xpu_op_desc.SetAttr( + "padding_algorithm", + conv->Op()->GetAttrIfExists("padding_algorithm")); + conv2d_xpu_op_desc.SetAttr( + "output_size", + conv->Op()->GetAttrIfExists>("output_size")); + conv2d_xpu_op_desc.SetAttr( + "output_padding", + conv->Op()->GetAttrIfExists>("output_padding")); + conv2d_xpu_op_desc.SetAttr( + "dilations", + PADDLE_GET_CONST(std::vector, conv->Op()->GetAttr("dilations"))); + conv2d_xpu_op_desc.SetAttr( + "paddings", + PADDLE_GET_CONST(std::vector, conv->Op()->GetAttr("paddings"))); + conv2d_xpu_op_desc.SetAttr( + "groups", PADDLE_GET_CONST(int, conv->Op()->GetAttr("groups"))); + conv2d_xpu_op_desc.SetAttr( + "strides", + PADDLE_GET_CONST(std::vector, conv->Op()->GetAttr("strides"))); + conv2d_xpu_op_desc.SetAttr( + "data_format", conv->Op()->GetAttrIfExists("data_format")); + + auto* conv2d_xpu = graph->CreateOpNode(&conv2d_xpu_op_desc); + IR_NODE_LINK_TO(input, conv2d_xpu); + IR_NODE_LINK_TO(filter_int16, conv2d_xpu); + IR_NODE_LINK_TO(filter_max, conv2d_xpu); + if (has_bias) { + SAFE_IR_NODE_LINK_TO(fusion_bias_node, conv2d_xpu); + } + if (act_out) { + IR_NODE_LINK_TO(conv2d_xpu, act_out); + } else if (bn_out) { + IR_NODE_LINK_TO(conv2d_xpu, bn_out); + } else if (ew_bias_add_out) { + IR_NODE_LINK_TO(conv2d_xpu, ew_bias_add_out); + } else { + IR_NODE_LINK_TO(conv2d_xpu, conv_out); + } + IR_NODE_LINK_TO(conv2d_xpu, conv2d_xpu_out_max); + // delete useless node + std::unordered_set delete_nodes = {conv}; + if (act != nullptr) { + delete_nodes.insert(act); + } + if (bn != nullptr) { + delete_nodes.insert(bn); + delete_nodes.insert(bn_bias); + delete_nodes.insert(bn_var); + delete_nodes.insert(bn_mean); + delete_nodes.insert(bn_scale); + delete_nodes.insert(bn_var_out); + delete_nodes.insert(bn_mean_out); + delete_nodes.insert(bn_saved_var); + delete_nodes.insert(bn_saved_mean); + } + if (ew_bias_add) { + delete_nodes.insert(ew_bias_add); + delete_nodes.insert(ew_bias_add_y); + } + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + }; + + gpd(graph, handler); + return found_subgraph_count; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(conv2d_transpose_xpu_fuse_pass, + paddle::framework::ir::Conv2dTransposeXPUFusePass); + +REGISTER_PASS_CAPABILITY(conv2d_transpose_xpu_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "conv2d_transpose_xpu", 0)); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index ea852d1467f157606616e7983be3d49ab2112cee..d065149750bfddcc9335ddc176ae53b4638a06f8 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -540,6 +540,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "redundant_squeeze_unsqueeze_elimination_pass", "fc_xpu_fuse_pass", "conv2d_xpu_fuse_pass", + "conv2d_transpose_xpu_fuse_pass", "add_activation_xpu_fuse_pass", "yolo_box_xpu_fuse_pass", "link_xpu_op_max_pass", diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 4aa981c95fe2a99d59d9f332ec25ccef438a77fa..f9dc939bf5de57f55df66d7df6b80f80dfefa88d 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -14,6 +14,16 @@ data_type : x optional : x_max, y_max +- op : conv2d_transpose_xpu + args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format, bool has_bias, bool with_act, str act_type) + output : Tensor(out), Tensor(out_max) + infer_meta : + func : Conv2dTransposeXPUInferMeta + kernel : + func : conv2d_transpose_xpu + data_type : x + optional : bias, x_max + - op : conv2d_xpu args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, Tensor branch_max, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, bool has_bias, bool has_branch, int act_type, float act_param) output : Tensor(out), Tensor(out_max) diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index d64e67b92a7e85fdafb3c5b2751a03475e4afadd..63d7a1b3ced87a0daa3853d5281e6c715919c144 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -171,6 +171,8 @@ XPUOpMap& get_kl2_ops() { {"conv2d_transpose_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"conv2d_transpose", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"conv2d_transpose_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"cumsum", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 437fbda9f476f69ffc8837df14e80156771ff003..be9506511a9a46a575377bad0766e12e19b014aa 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -518,4 +518,172 @@ void YoloBoxXPUInferMeta(const MetaTensor& x, out_max->set_layout(x.layout()); } +void ConvTransposeXPUInferMeta(const MetaTensor& x, + const MetaTensor& filter, + const std::vector& strides, + const std::vector& paddings, + const std::vector& output_padding, + const std::vector& output_size, + const std::string& padding_algorithm, + int groups, + const std::vector& dilations, + const std::string& data_format, + MetaTensor* out, + MetaTensor* out_max) { + auto x_dims = x.dims(); + auto filter_dims = filter.dims(); + std::vector paddings_ = paddings; + std::vector dilations_ = dilations; + PADDLE_ENFORCE_EQ( + x_dims.size() == 4, + true, + errors::InvalidArgument("Input of Op(conv_transpose) should be 4-D " + "Tensor. But received: %u-D Tensor, " + "the shape of input is [%s]", + x_dims.size(), + x_dims)); + PADDLE_ENFORCE_EQ( + x_dims.size(), + filter_dims.size(), + errors::InvalidArgument( + "The input's dimension size and filter's dimension size of " + "Op (conv_transpose) should be equal. But received: the shape of " + "input is [%s], the dimension size of input is [%d], the shape " + "of filter is [%s], the dimension size of filter is [%d]. ", + x_dims, + x_dims.size(), + filter_dims, + filter_dims.size())); + + int stride_size = strides.size(); + for (int i = 0; i < stride_size; ++i) { + PADDLE_ENFORCE_GT( + strides[i], + 0, + errors::InvalidArgument( + "The stride of Op(Conv) should be larget than 0, but received " + "stride is %d.", + strides[i])); + } + + int in_sub_stride_size = x_dims.size() - stride_size; + + PADDLE_ENFORCE_EQ( + x_dims.size() - strides.size(), + 2U, + errors::InvalidArgument( + "The input's dimension size minus Attr(stride)'s size must " + "be euqal to 2 for Op(conv_transpose). But received: [%d], the " + "input's dimension size is [%d], the shape of input " + "is [%s], the Attr(stride)'s size is [%d].", + in_sub_stride_size, + x_dims.size(), + x_dims, + strides.size())); + if (output_size.size()) + PADDLE_ENFORCE_EQ( + output_size.size(), + strides.size(), + errors::InvalidArgument( + "The Attr(output_size) and Attr(stride) of Op(conv_transpose) " + "should be the same.")); + if (output_padding.size()) + PADDLE_ENFORCE_EQ( + output_padding.size(), + strides.size(), + errors::InvalidArgument( + "The Attr(output_padding) and Attr(stride) of Op(conv_transpose) " + "should be the same.")); + + const int64_t C = + (data_format != "NHWC" ? x_dims[1] : x_dims[x_dims.size() - 1]); + PADDLE_ENFORCE_EQ( + C, + filter_dims[0], + errors::InvalidArgument( + "The number of input channels should be equal to filter channels " + "for Op(conv_transpose). But received: the input's channels is " + "[%d], the shape of input is [%s], the filter's channels is [%d], " + "the shape of filter is [%s]. The data_format is %s." + "The error may come from wrong data_format setting.", + C, + x_dims, + filter_dims[0], + filter_dims, + data_format)); + + DDim x_data_dims; + if (data_format != "NHWC") { + x_data_dims = slice_ddim(x_dims, 2, x_dims.size()); + } else { + x_data_dims = slice_ddim(x_dims, 1, x_dims.size() - 1); + } + DDim filter_data_dims = slice_ddim(filter_dims, 2, filter_dims.size()); + std::vector ksize = vectorize(filter_data_dims); + UpdatePaddingAndDilation( + &paddings_, &dilations_, padding_algorithm, x_data_dims, strides, ksize); + + std::vector output_shape({x_dims[0]}); + if (data_format != "NHWC") { + output_shape.push_back(filter_dims[1] * groups); + } + const int offset = (data_format != "NHWC" ? 2 : 1); + for (size_t i = 0; i < strides.size(); ++i) { + auto filter_extent = dilations_[i] * (filter_dims[i + 2] - 1) + 1; + auto infer_shape = (x_dims[i + offset] > 0) + ? (x_dims[i + offset] - 1) * strides[i] - + paddings_[2 * i] - paddings_[2 * i + 1] + + filter_extent + : -1; + if (output_size.size()) { + output_shape.push_back(output_size[i]); + } else if (output_padding.size()) { + output_shape.push_back((infer_shape + output_padding[i])); + } else { + output_shape.push_back(infer_shape); + } + } + if (data_format == "NHWC") { + output_shape.push_back(filter_dims[1] * groups); + } + + out->set_dims(make_ddim(output_shape)); + out->set_dtype(x.dtype()); + out_max->set_dims(phi::make_ddim({6})); +} + +void Conv2dTransposeXPUInferMeta(const MetaTensor& x, + const MetaTensor& x_max, + const MetaTensor& filter, + const MetaTensor& filter_max, + const MetaTensor& bias, + const std::vector& strides, + const std::vector& paddings, + const std::vector& output_padding, + const IntArray& output_size, + const std::string& padding_algorithm, + int groups, + const std::vector& dilations, + const std::string& data_format, + bool has_bias, + bool with_act, + const std::string& act_type, + MetaTensor* out, + MetaTensor* out_max) { + std::vector vec_output_size(output_size.GetData().begin(), + output_size.GetData().end()); + ConvTransposeXPUInferMeta(x, + filter, + strides, + paddings, + output_padding, + vec_output_size, + padding_algorithm, + groups, + dilations, + data_format, + out, + out_max); +} + } // namespace phi diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index b4456d07a7a5de7ad73c110cca04ae3ca631cdde..8fc311ebdd89c4cb1b523395ab3ca7e2143d0a09 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -145,4 +145,22 @@ void YoloBoxXPUInferMeta(const MetaTensor& x, MetaTensor* out, MetaTensor* out_max); +void Conv2dTransposeXPUInferMeta(const MetaTensor& x, + const MetaTensor& x_max, + const MetaTensor& filter, + const MetaTensor& filter_max, + const MetaTensor& bias, + const std::vector& strides, + const std::vector& paddings, + const std::vector& output_padding, + const IntArray& output_size, + const std::string& padding_algorithm, + int groups, + const std::vector& dilations, + const std::string& data_format, + bool has_bias, + bool with_act, + const std::string& act_type, + MetaTensor* out, + MetaTensor* out_max); } // namespace phi diff --git a/paddle/phi/kernels/fusion/xpu/conv_transpose_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/conv_transpose_xpu_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..2a62620b53a681f6cee8c39c53d24da4482849f2 --- /dev/null +++ b/paddle/phi/kernels/fusion/xpu/conv_transpose_xpu_kernel.cc @@ -0,0 +1,109 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cpu/conv_util.h" +#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h" + +namespace phi { +namespace fusion { +template +void Conv2dTransposeXPUKernel(const Context& ctx, + const DenseTensor& x, + const paddle::optional& x_max, + const DenseTensor& filter, + const DenseTensor& filter_max, + const paddle::optional& bias, + const std::vector& strides, + const std::vector& paddings, + const std::vector& output_padding, + const IntArray& output_size, + const std::string& padding_algorithm, + int groups, + const std::vector& dilations, + const std::string& data_format, + bool has_bias, + bool with_act, + const std::string& act_type, + DenseTensor* out, + DenseTensor* out_max) { + using XPUT = typename XPUTypeTrait::Type; + + // The filter will be reshaped in the calculations, + // so here use an assignment operation, + // that avoids modifying the variable in the Scope. + DenseTensor filter_ = filter; + ctx.template Alloc(out); + ctx.template Alloc(out_max); + bool is_nchw; + is_nchw = (data_format == "NHWC") ? false : true; + + DDim in_data_dims = slice_ddim(x.dims(), 2, x.dims().size()); // hw + DDim filter_data_dims = slice_ddim(filter_.dims(), 2, filter_.dims().size()); + std::vector ksize = vectorize(filter_data_dims); + std::vector paddings_ = paddings; + std::vector dilations_ = dilations; + UpdatePaddingAndDilation( + &paddings_, &dilations_, padding_algorithm, in_data_dims, strides, ksize); + + const int batch_size = static_cast(x.dims()[0]); + const int img_yc = static_cast(x.dims()[1]); + const int img_xc = static_cast(out->dims()[1]); + const int img_xh = static_cast(out->dims()[2]); + const int img_xw = static_cast(out->dims()[3]); + auto act = xpu::Activation_t::LINEAR; + if (with_act) { + if (act_type == "relu") { + act = xpu::Activation_t::RELU; + } + } + auto bias_data = + bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data(); + auto x_max_data = + x_max.get_ptr() == nullptr ? nullptr : x_max.get_ptr()->data(); + auto filter_max_data = filter_max.data(); + + int r = xpu::conv2d_transpose_fusion_v2( + ctx.x_context(), + reinterpret_cast(x.data()), + filter_.data(), + reinterpret_cast(out->data()), + batch_size, + img_yc, + img_xh, + img_xw, + img_xc, + ksize, + strides, + paddings_, + dilations_, + groups, + x_max_data, + filter_max_data, + out_max->data(), + bias_data, + act, + is_nchw); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_fusion_v2"); +} +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(conv2d_transpose_xpu, + XPU, + ALL_LAYOUT, + phi::fusion::Conv2dTransposeXPUKernel, + float, + phi::dtype::float16) {} diff --git a/test/ir/inference/test_xpu_conv2d_transpose_fuse_pass.py b/test/ir/inference/test_xpu_conv2d_transpose_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..7aafce7bc8bf6cea139a3164940881497cd4b48c --- /dev/null +++ b/test/ir/inference/test_xpu_conv2d_transpose_fuse_pass.py @@ -0,0 +1,143 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial + +import hypothesis.strategies as st +import numpy as np +from auto_scan_test import PassAutoScanTest +from program_config import OpConfig, ProgramConfig, TensorConfig + + +class TestConvTransposeXPUFusePass(PassAutoScanTest): + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_xpu=True) + yield config, ["conv2d_transpose_xpu"], (3e-3, 3e-3) + + def sample_program_config(self, draw): + x_shape = draw( + st.lists( + st.integers(min_value=4, max_value=16), min_size=4, max_size=4 + ) + ) + oc = draw(st.integers(min_value=2, max_value=16)) + weight_shape = [x_shape[1], oc, 4, 4] + y_shape = [oc] + has_bn = draw(st.booleans()) + has_add = draw(st.booleans()) + has_relu = draw(st.booleans()) + + def generate_data(shape): + return 0.1 * np.random.random(shape).astype(np.float32) + + deconv_op = OpConfig( + "conv2d_transpose", + inputs={"Input": ["input_x"], "Filter": ["weight_x"]}, + outputs={"Output": ["output_x"]}, + data_format="NCHW", + dilations=[1, 1], + groups=1, + paddings=[0, 0], + padding_algorithm="EXPLICIT", + strides=[4, 4], + fuse_relu=False, + ) + input_name_op = "output_x" + ops = [deconv_op] + + if has_add: + add_op = OpConfig( + "elementwise_add", + inputs={"X": [input_name_op], "Y": ["bias"]}, + outputs={"Out": ["add_out"]}, + axis=1, + ) + input_name_op = "add_out" + ops.append(add_op) + + if has_bn: + bn_op = OpConfig( + "batch_norm", + inputs={ + "X": [input_name_op], + "Bias": ["bn_bias"], + "Mean": ["bn_mean"], + "Scale": ["bn_scale"], + "Variance": ["bn_var"], + }, + outputs={ + "Y": ["bn_y"], + "MeanOut": ["bn_mean"], + "SavedMean": ["bn_mean_save"], + "SavedVariance": ["bn_save_var"], + "VarianceOut": ["bn_var"], + }, + data_layout="NCHW", + epsilon=0.000009999999747378752, + momentum=0.89999, + is_test=True, + use_global_stats=True, + ) + input_name_op = "bn_y" + ops.append(bn_op) + + if has_relu: + relu_op = OpConfig( + "relu", + inputs={"X": [input_name_op]}, + outputs={"Out": ["relu_out"]}, + ) + input_name_op = "relu_out" + ops.append(relu_op) + + program_config = ProgramConfig( + ops=ops, + weights={ + "weight_x": TensorConfig( + data_gen=partial(generate_data, weight_shape) + ), + "bias": TensorConfig(data_gen=partial(generate_data, y_shape)), + "bn_bias": TensorConfig( + data_gen=partial(generate_data, y_shape) + ), + "bn_mean": TensorConfig( + data_gen=partial(generate_data, y_shape) + ), + "bn_scale": TensorConfig( + data_gen=partial(generate_data, y_shape) + ), + "bn_var": TensorConfig( + data_gen=partial(generate_data, y_shape) + ), + }, + inputs={ + "input_x": TensorConfig( + data_gen=partial(generate_data, x_shape) + ), + }, + outputs=[input_name_op], + ) + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=100, + passes=["conv2d_transpose_xpu_fuse_pass"], + ) + + +if __name__ == "__main__": + unittest.main()