diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index f1583f5312f4aedeaa13bb01dddf7329de8f2f7e..1c24fa14b5fdc400943f71d39d66dca4e10cc0e4 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -228,6 +228,7 @@ if(WITH_XPU) SRCS xpu/pass_utils.cc DEPS pass xpu_quant_utils) set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils) + pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) diff --git a/paddle/fluid/framework/ir/xpu/yolo_box_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/yolo_box_xpu_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..9a97821ea4bddc2e933971fe95261e4a390f82f6 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/yolo_box_xpu_fuse_pass.cc @@ -0,0 +1,430 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "glog/logging.h" + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/ir/xpu/quant_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { +/* +fuse block in yolo-like model to yolo_box_xpu op +------------------------------------------------------ +sub block: + x + / | \ + / | \ + / | \ + slice slice slice + | | | + | | | + ew_mul ew_mul | + | | | + | | | + ew_sub ew_pow | + | | | + | | | + ew_add ew_mul_2 | + | | | + | | | + ew_mul_2 | | + \ | / + \ | / + \ | / + concat + | + y +------------------------------------------------------ +After the pass is applied: + x + grid[left_ew_add_y] | offset[left_ew_sub_y] + \ | / + \ | / +stride[left_ew_mul_2_y] -- yolo_box_xpu --- anchor_grid[mid_ew_mul_2_y] + | \ + | \ + | \ + y y_max +*/ +struct YoloBoxXPUPattern : public PatternBase { + YoloBoxXPUPattern(PDPattern* pattern, + const std::string& name_scope, + bool with_left_ew_sub_); + // declare operator node's name + PATTERN_DECL_NODE(left_slice); + PATTERN_DECL_NODE(mid_slice); + PATTERN_DECL_NODE(right_slice); + PATTERN_DECL_NODE(left_ew_mul); + PATTERN_DECL_NODE(left_ew_sub); + PATTERN_DECL_NODE(left_ew_add); + PATTERN_DECL_NODE(left_ew_mul_2); + PATTERN_DECL_NODE(mid_ew_mul); + PATTERN_DECL_NODE(mid_ew_pow); + PATTERN_DECL_NODE(mid_ew_mul_2); + PATTERN_DECL_NODE(concat); + // declare variable node's name + PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(left_slice_out); + PATTERN_DECL_NODE(left_ew_mul_out); + PATTERN_DECL_NODE(left_ew_sub_y); + PATTERN_DECL_NODE(left_ew_sub_out); + PATTERN_DECL_NODE(left_ew_add_y); + PATTERN_DECL_NODE(left_ew_add_out); + PATTERN_DECL_NODE(left_ew_mul_2_y); + PATTERN_DECL_NODE(left_ew_mul_2_out); + PATTERN_DECL_NODE(mid_slice_out); + PATTERN_DECL_NODE(mid_ew_mul_out); + PATTERN_DECL_NODE(mid_ew_pow_out); + PATTERN_DECL_NODE(mid_ew_mul_2_y); + PATTERN_DECL_NODE(mid_ew_mul_2_out); + PATTERN_DECL_NODE(right_slice_out); + PATTERN_DECL_NODE(concat_out); + + private: + bool with_left_ew_sub_{true}; +}; + +YoloBoxXPUPattern::YoloBoxXPUPattern(PDPattern* pattern, + const std::string& name_scope, + bool with_left_ew_sub) + : PatternBase(pattern, name_scope, name_scope), + with_left_ew_sub_(with_left_ew_sub) { + auto x = pattern->NewNode(x_repr()) + ->assert_is_op_output("sigmoid", "Out") + ->assert_has_n_outputs(3); + auto* left_slice = + pattern->NewNode(left_slice_repr()) + ->assert_is_op("strided_slice") + ->assert_more([&](Node* node) { + auto* op_desc = node->Op(); + return op_desc->GetAttrIfExists>("axes") == + std::vector{4} && + op_desc->GetAttrIfExists>("starts") == + std::vector{0} && + op_desc->GetAttrIfExists>("ends") == + std::vector{2}; + }); + auto* left_slice_out = pattern->NewNode(left_slice_out_repr()) + ->assert_is_op_output("strided_slice", "Out") + ->assert_is_op_input("elementwise_mul", "X"); + left_slice->LinksFrom({x}).LinksTo({left_slice_out}); + auto* mid_slice = + pattern->NewNode(mid_slice_repr()) + ->assert_is_op("strided_slice") + ->assert_more([&](Node* node) { + auto* op_desc = node->Op(); + return op_desc->GetAttrIfExists>("axes") == + std::vector{4} && + op_desc->GetAttrIfExists>("starts") == + std::vector{2} && + op_desc->GetAttrIfExists>("ends") == + std::vector{4}; + }); + auto* mid_slice_out = pattern->NewNode(mid_slice_out_repr()) + ->assert_is_op_output("strided_slice", "Out") + ->assert_is_op_input("elementwise_mul", "X"); + mid_slice->LinksFrom({x}).LinksTo({mid_slice_out}); + auto* right_slice = + pattern->NewNode(right_slice_repr()) + ->assert_is_op("strided_slice") + ->assert_more([&](Node* node) { + auto* op_desc = node->Op(); + return op_desc->GetAttrIfExists>("axes") == + std::vector{4} && + op_desc->GetAttrIfExists>("starts") == + std::vector{4} && + op_desc->GetAttrIfExists>("ends") == + std::vector{2147483647}; + }); + auto* right_slice_out = pattern->NewNode(right_slice_out_repr()) + ->assert_is_op_output("strided_slice", "Out") + ->assert_is_op_nth_input("concat", "X", 2); + right_slice->LinksFrom({x}).LinksTo({right_slice_out}); + // left silce pattern + auto* left_ew_mul = + pattern->NewNode(left_ew_mul_repr()) + ->assert_is_op("elementwise_mul") + ->assert_more([&](Node* node) { + auto next_op_nodes = node->outputs[0]->outputs; + return next_op_nodes.size() == 1 && + (next_op_nodes[0]->Op()->Type() == "elementwise_sub" || + next_op_nodes[0]->Op()->Type() == "elementwise_add"); + }); + auto* left_ew_mul_out = pattern->NewNode(left_ew_mul_out_repr()) + ->assert_is_op_output("elementwise_mul", "Out"); + left_ew_mul->LinksFrom({left_slice_out}).LinksTo({left_ew_mul_out}); + PDNode* left_ew_sub = nullptr; + PDNode* left_ew_sub_y = nullptr; + PDNode* left_ew_sub_out = nullptr; + if (with_left_ew_sub_) { + left_ew_mul_out->assert_is_op_input("elementwise_sub", "X"); + left_ew_sub = + pattern->NewNode(left_ew_sub_repr())->assert_is_op("elementwise_sub"); + left_ew_sub_y = pattern->NewNode(left_ew_sub_y_repr()) + ->assert_is_op_input("elementwise_sub", "Y") + ->assert_is_persistable_var(); + left_ew_sub_out = pattern->NewNode(left_ew_sub_out_repr()) + ->assert_is_op_output("elementwise_sub", "Out"); + left_ew_sub->LinksFrom({left_ew_mul_out, left_ew_sub_y}) + .LinksTo({left_ew_sub_out}); + } else { + left_ew_sub_out = left_ew_mul_out; + } + left_ew_sub_out->assert_is_op_input("elementwise_add", "X"); + auto* left_ew_add = + pattern->NewNode(left_ew_add_repr())->assert_is_op("elementwise_add"); + auto* left_ew_add_y = pattern->NewNode(left_ew_add_y_repr()) + ->assert_is_op_input("elementwise_add", "Y"); + auto* left_ew_add_out = pattern->NewNode(left_ew_add_out_repr()) + ->assert_is_op_output("elementwise_add", "Out") + ->assert_is_op_input("elementwise_mul", "X"); + left_ew_add->LinksFrom({left_ew_sub_out, left_ew_add_y}) + .LinksTo({left_ew_add_out}); + auto* left_ew_mul_2 = + pattern->NewNode(left_ew_mul_2_repr()) + ->assert_is_op("elementwise_mul") + ->assert_more([&](Node* node) { + auto pre_op_nodes = node->inputs[0]->inputs; + return pre_op_nodes.size() == 1 && + pre_op_nodes[0]->Op()->Type() == "elementwise_add"; + }); + auto* left_ew_mul_2_y = pattern->NewNode(left_ew_mul_2_y_repr()) + ->assert_is_op_input("elementwise_mul", "Y"); + auto* left_ew_mul_2_out = pattern->NewNode(left_ew_mul_2_out_repr()) + ->assert_is_op_output("elementwise_mul", "Out") + ->assert_is_op_nth_input("concat", "X", 0); + left_ew_mul_2->LinksFrom({left_ew_add_out, left_ew_mul_2_y}) + .LinksTo({left_ew_mul_2_out}); + // mid slice pattern + auto* mid_ew_mul = + pattern->NewNode(mid_ew_mul_repr()) + ->assert_is_op("elementwise_mul") + ->assert_more([&](Node* node) { + auto next_op_nodes = node->outputs[0]->outputs; + return next_op_nodes.size() == 1 && + next_op_nodes[0]->Op()->Type() == "elementwise_pow"; + }); + auto* mid_ew_mul_out = pattern->NewNode(mid_ew_mul_out_repr()) + ->assert_is_op_output("elementwise_mul", "Out") + ->assert_is_op_input("elementwise_pow", "X"); + mid_ew_mul->LinksFrom({mid_slice_out}).LinksTo({mid_ew_mul_out}); + auto* mid_ew_pow = + pattern->NewNode(mid_ew_pow_repr())->assert_is_op("elementwise_pow"); + auto* mid_ew_pow_out = pattern->NewNode(mid_ew_pow_out_repr()) + ->assert_is_op_output("elementwise_pow", "Out") + ->assert_is_op_input("elementwise_mul", "X"); + mid_ew_pow->LinksFrom({mid_ew_mul_out}).LinksTo({mid_ew_pow_out}); + auto* mid_ew_mul_2 = + pattern->NewNode(mid_ew_mul_2_repr()) + ->assert_is_op("elementwise_mul") + ->assert_more([&](Node* node) { + auto pre_op_nodes = node->inputs[0]->inputs; + return pre_op_nodes.size() == 1 && + pre_op_nodes[0]->Op()->Type() == "elementwise_pow"; + }); + auto* mid_ew_mul_2_y = pattern->NewNode(mid_ew_mul_2_y_repr()) + ->assert_is_op_input("elementwise_mul", "Y"); + auto* mid_ew_mul_2_out = pattern->NewNode(mid_ew_mul_2_out_repr()) + ->assert_is_op_output("elementwise_mul", "Out") + ->assert_is_op_nth_input("concat", "X", 1); + mid_ew_mul_2->LinksFrom({mid_ew_pow_out, mid_ew_mul_2_y}) + .LinksTo({mid_ew_mul_2_out}); + // concat + auto* concat = pattern->NewNode(concat_repr())->assert_is_op("concat"); + auto* concat_out = pattern->NewNode(concat_out_repr()) + ->assert_is_op_output("concat", "Out") + ->AsOutput(); + concat->LinksFrom({left_ew_mul_2_out, mid_ew_mul_2_out, right_slice_out}) + .LinksTo({concat_out}); +} + +} // namespace patterns + +class YoloBoxXPUFusePass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + int ApplyImpl(ir::Graph* graph, bool with_left_ew_sub) const; + + const std::string name_scope_{"yolo_box_xpu_fuse_pass"}; +}; + +void YoloBoxXPUFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + + int found_subgraph_count = 0; + for (auto with_left_ew_sub : {true, false}) { + found_subgraph_count += ApplyImpl(graph, with_left_ew_sub); + } + AddStatis(found_subgraph_count); +} + +int YoloBoxXPUFusePass::ApplyImpl(ir::Graph* graph, + bool with_left_ew_sub) const { + GraphPatternDetector gpd; + patterns::YoloBoxXPUPattern pattern( + gpd.mutable_pattern(), name_scope_, with_left_ew_sub); + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle YoloBoxXPUFusePass fuse"; + /* declare operator node's name */ + // declare operator node's name + GET_IR_NODE(left_slice); + GET_IR_NODE(left_ew_mul); + GET_IR_NODE(left_ew_sub); + GET_IR_NODE(left_ew_add); + GET_IR_NODE(left_ew_mul_2); + GET_IR_NODE(mid_slice); + GET_IR_NODE(mid_ew_mul); + GET_IR_NODE(mid_ew_pow); + GET_IR_NODE(mid_ew_mul_2); + GET_IR_NODE(right_slice); + GET_IR_NODE(concat); + // declare variable node's name + GET_IR_NODE(x); + GET_IR_NODE(left_slice_out); + GET_IR_NODE(left_ew_mul_out); + GET_IR_NODE(left_ew_sub_y); + GET_IR_NODE(left_ew_sub_out); + GET_IR_NODE(left_ew_add_y); + GET_IR_NODE(left_ew_add_out); + GET_IR_NODE(left_ew_mul_2_y); + GET_IR_NODE(left_ew_mul_2_out); + GET_IR_NODE(mid_slice_out); + GET_IR_NODE(mid_ew_mul_out); + GET_IR_NODE(mid_ew_pow_out); + GET_IR_NODE(mid_ew_mul_2_y); + GET_IR_NODE(mid_ew_mul_2_out); + GET_IR_NODE(right_slice_out); + GET_IR_NODE(concat_out); + auto* block = concat->Op()->Block(); + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); + + std::string fused_op_out_name; + fused_op_out_name = concat_out->Name(); + std::string fused_op_out_max_name = fused_op_out_name + "_max"; + VarDesc fused_op_out_max_desc(fused_op_out_max_name); + Node* fused_op_out_max = graph->CreateVarNode(&fused_op_out_max_desc); + // Generate yolo_box_xpu fused op + framework::OpDesc fused_op_desc(block); + fused_op_desc.SetType("yolo_box_xpu"); + // set attrs for fused op + fused_op_desc.SetInput("x", {x->Name()}); + fused_op_desc.SetInput("grid", {left_ew_add_y->Name()}); + fused_op_desc.SetInput("stride", {left_ew_mul_2_y->Name()}); + fused_op_desc.SetInput("anchor_grid", {mid_ew_mul_2_y->Name()}); + float offset_ = 0.f; + if (with_left_ew_sub) { + const auto& left_ew_sub_y_t = + scope->FindVar(left_ew_sub_y->Name())->Get(); + auto left_ew_sub_y_dims = left_ew_sub_y_t.dims(); + PADDLE_ENFORCE_EQ(left_ew_sub_y_dims.size(), + 1, + platform::errors::InvalidArgument( + "the size(%d) of left elementwise sub tensor " + "must equal 1", + left_ew_sub_y_dims.size())); + auto tensor_type = left_ew_sub_y_t.dtype(); + if (tensor_type == phi::DataType::FLOAT16) { + auto* sub_t_fp16_ptr = left_ew_sub_y_t.data(); + offset_ = static_cast(sub_t_fp16_ptr[0]); + } else if (tensor_type == phi::DataType::FLOAT32) { + auto* sub_t_fp32_ptr = left_ew_sub_y_t.data(); + offset_ = sub_t_fp32_ptr[0]; + } else { + PADDLE_THROW(platform::errors::Unavailable( + "yolo_box_fuse_xpu_pass not supported weight dtype. " + "we now only support fp32/fp16.")); + } + } + fused_op_desc.SetAttr("offset", offset_); + fused_op_desc.SetOutput("out", {concat_out->Name()}); + fused_op_desc.SetOutput("out_max", {fused_op_out_max_name}); + // relink fused op + auto* fused_op = graph->CreateOpNode(&fused_op_desc); + IR_NODE_LINK_TO(x, fused_op); + IR_NODE_LINK_TO(left_ew_add_y, fused_op); + IR_NODE_LINK_TO(left_ew_mul_2_y, fused_op); + IR_NODE_LINK_TO(mid_ew_mul_2_y, fused_op); + IR_NODE_LINK_TO(fused_op, concat_out); + IR_NODE_LINK_TO(fused_op, fused_op_out_max); + // delete useless node + std::unordered_set delete_nodes = {left_slice, + left_slice_out, + left_ew_mul, + left_ew_mul_out, + left_ew_add, + left_ew_add_out, + left_ew_mul_2, + left_ew_mul_2_out, + mid_slice, + mid_slice_out, + mid_ew_mul, + mid_ew_mul_out, + mid_ew_pow, + mid_ew_pow_out, + mid_ew_mul_2, + mid_ew_mul_2_out, + right_slice, + right_slice_out, + concat}; + if (with_left_ew_sub) { + delete_nodes.insert(left_ew_sub); + delete_nodes.insert(left_ew_sub_out); + } + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + }; + + gpd(graph, handler); + return found_subgraph_count; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(yolo_box_xpu_fuse_pass, + paddle::framework::ir::YoloBoxXPUFusePass); + +REGISTER_PASS_CAPABILITY(yolo_box_xpu_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "yolo_box_xpu", 0)); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 423256c7ec7d308bb1663e70e7a101c7f19d2d0c..41a07585984e5f8e538f2445272068c99fa922d4 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -529,6 +529,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "fc_xpu_fuse_pass", "conv2d_xpu_fuse_pass", "add_activation_xpu_fuse_pass", + "yolo_box_xpu_fuse_pass", "link_xpu_op_max_pass", "inplace_op_var_pass", "delete_isolated_node_pass", diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index c8be15e5b5551f263c3dcf4259c8ed9ba1c78581..4aa981c95fe2a99d59d9f332ec25ccef438a77fa 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -96,3 +96,13 @@ func : multi_encoder_xpu data_type : x optional : mask, seq_lod, max_seq_len, x_fp16, out_fp16 + +- op : yolo_box_xpu + args : (Tensor x, Tensor x_max, Tensor grid, Tensor stride, Tensor anchor_grid, float offset) + output : Tensor(out), Tensor(out_max) + infer_meta : + func : YoloBoxXPUInferMeta + kernel : + func : yolo_box_xpu + data_type : x + optional : x_max diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 3777f1c8f3713eac5024ff3d0bef33a7d1e9a698..add1d7eca7d360c185b373097983db77c20f4d58 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -937,6 +937,8 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})}, {"isnan_v2", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"yolo_box_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, // AddMore {"sequence_conv", XPUKernelSet({phi::DataType::FLOAT32})}, diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 78bf3d01de090fbef4517170c9936a1360e31c64..437fbda9f476f69ffc8837df14e80156771ff003 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -20,20 +20,16 @@ limitations under the License. */ #include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/funcs/common_shape.h" +#include "paddle/phi/kernels/funcs/concat_funcs.h" +#include "paddle/phi/kernels/funcs/strided_slice.h" namespace phi { -void AddActXPUInferMeta(const MetaTensor& x, - const MetaTensor& x_max, - const MetaTensor& y, - const MetaTensor& y_max, - int act_type, - MetaTensor* out, - MetaTensor* out_max) { - int axis = -1; - if (x.dims() != y.dims()) { - auto x_dims = x.dims(); - auto y_dims = y.dims(); +static phi::DDim BroadCastInferShape(const DDim x_dims, + const DDim y_dims, + int axis) { + std::vector out_dims_array(x_dims.size(), -1); + if (x_dims != y_dims) { int max_dim = std::max(x_dims.size(), y_dims.size()); if (x_dims.size() == y_dims.size()) { PADDLE_ENFORCE_EQ((axis == -1) || (axis == 0), @@ -58,7 +54,7 @@ void AddActXPUInferMeta(const MetaTensor& x, : axis); std::vector x_dims_array(max_dim); std::vector y_dims_array(max_dim); - std::vector out_dims_array(max_dim); + out_dims_array.resize(max_dim); funcs::GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(), @@ -66,10 +62,27 @@ void AddActXPUInferMeta(const MetaTensor& x, out_dims_array.data(), max_dim, axis); - auto out_dims = phi::make_ddim(out_dims_array); + + return phi::make_ddim(out_dims_array); + } + return x_dims; +} + +void AddActXPUInferMeta(const MetaTensor& x, + const MetaTensor& x_max, + const MetaTensor& y, + const MetaTensor& y_max, + int act_type, + MetaTensor* out, + MetaTensor* out_max) { + int axis = -1; + auto x_dims = x.dims(); + auto y_dims = y.dims(); + if (x_dims != y_dims) { + auto out_dims = BroadCastInferShape(x_dims, y_dims, axis); out->set_dims(out_dims); } else { - out->set_dims(x.dims()); + out->set_dims(x_dims); } out->set_dtype(x.dtype()); out->set_layout(x.layout()); @@ -411,4 +424,98 @@ void FusedMultiTransformerXpuInferMeta( out->set_layout(x.layout()); } +void YoloBoxXPUInferMeta(const MetaTensor& x, + const MetaTensor& x_max, + const MetaTensor& grid, + const MetaTensor& stride, + const MetaTensor& anchor_grid, + float offset, + MetaTensor* out, + MetaTensor* out_max) { + auto x_dims = x.dims(); + auto x_dims_size = x_dims.size(); + PADDLE_ENFORCE_GT( + x_dims[x_dims_size - 1], + 4, + phi::errors::InvalidArgument( + "The last dim of x should be larget than 4, but received " + " is %d.", + x_dims[x_dims_size - 1])); + // compute left out_dims + // y[..., 0:2] = (x[..., 0:2] * 2 + self.grid[i]) * self.stride[i] # xy + std::vector axes_ = {x_dims_size - 1}; + std::vector infer_flags_ = {1}; + std::vector decrease_axis_ = {-1}; + std::vector strides_ = {1}; + std::vector starts_l = {0}; + std::vector ends_l = {2}; + std::vector left_slice_out_dims_vector(x_dims_size, -1); + phi::funcs::StridedSliceOutDims(starts_l, + ends_l, + strides_, + axes_, + infer_flags_, + x_dims, + decrease_axis_, + left_slice_out_dims_vector.data(), + 1, + true); + auto left_slice_out_dims = phi::make_ddim(left_slice_out_dims_vector); + auto grid_dims = grid.dims(); + auto left_add_out_dims = + BroadCastInferShape(left_slice_out_dims, grid_dims, -1); + auto stride_dims = stride.dims(); + auto left_mul_out_dims = + BroadCastInferShape(left_add_out_dims, stride_dims, -1); + // compute mid out_dims + // wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + std::vector starts_m = {2}; + std::vector ends_m = {4}; + std::vector mid_slice_out_dims_vector(x_dims_size, -1); + phi::funcs::StridedSliceOutDims(starts_m, + ends_m, + strides_, + axes_, + infer_flags_, + x_dims, + decrease_axis_, + mid_slice_out_dims_vector.data(), + 1, + true); + auto mid_slice_out_dims = phi::make_ddim(mid_slice_out_dims_vector); + auto anchor_grid_dims = anchor_grid.dims(); + auto mid_mul_out_dims = + BroadCastInferShape(mid_slice_out_dims, anchor_grid_dims, -1); + // compute right out_dims + std::vector starts_r = {4}; + std::vector ends_r = {2147483647}; + std::vector right_slice_out_dims_vector(x_dims_size, -1); + phi::funcs::StridedSliceOutDims(starts_r, + ends_r, + strides_, + axes_, + infer_flags_, + x_dims, + decrease_axis_, + right_slice_out_dims_vector.data(), + 1, + true); + auto right_slice_out_dims = phi::make_ddim(right_slice_out_dims_vector); + // compute concat out_dims + std::vector in_dims; + in_dims.reserve(3); + in_dims.emplace_back(left_mul_out_dims); + in_dims.emplace_back(mid_mul_out_dims); + in_dims.emplace_back(right_slice_out_dims); + phi::DDim out_dim = + phi::funcs::ComputeAndCheckShape(false, in_dims, x_dims_size - 1); + + out->set_dims(out_dim); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); + out_max->set_dims(phi::make_ddim({6})); + out_max->set_dtype(x.dtype()); + out_max->set_layout(x.layout()); +} + } // namespace phi diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 751bbfd95f7644192d426d57459d403e1c5718a8..b4456d07a7a5de7ad73c110cca04ae3ca631cdde 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -135,4 +135,14 @@ void FusedMultiTransformerXpuInferMeta( int gather_axis, MetaTensor* out, std::vector cache_kv_out); + +void YoloBoxXPUInferMeta(const MetaTensor& x, + const MetaTensor& x_max, + const MetaTensor& grid, + const MetaTensor& stride, + const MetaTensor& anchor_grid, + float offset, + MetaTensor* out, + MetaTensor* out_max); + } // namespace phi diff --git a/paddle/phi/kernels/fusion/xpu/yolo_box_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/yolo_box_xpu_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..193e18d38dccd40cc0ac3bab9d60747575fc162e --- /dev/null +++ b/paddle/phi/kernels/fusion/xpu/yolo_box_xpu_kernel.cc @@ -0,0 +1,106 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +namespace fusion { + +template +void YoloBoxXPUKernel(const Context& ctx, + const DenseTensor& x, + const paddle::optional& x_max, + const DenseTensor& grid, + const DenseTensor& stride, + const DenseTensor& anchor_grid, + float offset, + DenseTensor* out, + DenseTensor* out_max) { + using XPUType = typename XPUTypeTrait::Type; + + auto* x_data = reinterpret_cast(x.data()); + auto* out_data = reinterpret_cast(ctx.template Alloc(out)); + // float* x_max + float* x_max_data = nullptr; + const float* grid_data; + const float* stride_data; + const float* anchor_grid_data; + // fix precision of fp16 model + if (std::is_same::value) { + DenseTensor grid_data_fp32_t; + DenseTensor stride_data_fp32_t; + DenseTensor anchor_grid_data_fp32_t; + ctx.template Alloc(&grid_data_fp32_t, grid.numel() * sizeof(float)); + int r1 = xpu::cast( + ctx.x_context(), + reinterpret_cast(grid.data()), + grid_data_fp32_t.data(), + grid.numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r1, "cast"); + ctx.template Alloc(&stride_data_fp32_t, + stride.numel() * sizeof(float)); + int r2 = xpu::cast( + ctx.x_context(), + reinterpret_cast(stride.data()), + stride_data_fp32_t.data(), + stride.numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r2, "cast"); + ctx.template Alloc(&anchor_grid_data_fp32_t, + anchor_grid.numel() * sizeof(float)); + int r3 = xpu::cast( + ctx.x_context(), + reinterpret_cast(anchor_grid.data()), + anchor_grid_data_fp32_t.data(), + anchor_grid.numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r3, "cast"); + grid_data = grid_data_fp32_t.data(); + stride_data = stride_data_fp32_t.data(); + anchor_grid_data = anchor_grid_data_fp32_t.data(); + } else { + grid_data = grid.data(); + stride_data = stride.data(); + anchor_grid_data = anchor_grid.data(); + } + std::vector x_shape = phi::vectorize(x.dims()); + std::vector grid_shape = phi::vectorize(grid.dims()); + std::vector stride_shape = phi::vectorize(stride.dims()); + std::vector anchor_grid_shape = phi::vectorize(anchor_grid.dims()); + // yolo_box_coord only support fp32&&fp16 precision + int r = xpu::yolo_box_coord( + /* baidu::xpu::api::Context* ctx */ ctx.x_context(), + /* const T* x */ x_data, + /* T* y */ out_data, + /* const std::vector& x_shape */ x_shape, + /* const float* grid */ grid_data, + /* const float* stride */ stride_data, + /* const float* anchor_grid */ anchor_grid_data, + /* const std::vector& grid_shape */ grid_shape, + /* const std::vector& stride_shape */ stride_shape, + /* const std::vector& anchor_grid */ anchor_grid_shape, + /* float offset */ offset, + /* float* x_max */ x_max_data, + /* float* y_max */ ctx.template Alloc(out_max)); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "yolo_box_xpu"); +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(yolo_box_xpu, + XPU, + ALL_LAYOUT, + phi::fusion::YoloBoxXPUKernel, + float, + phi::dtype::float16) {}