Unverified Commit a087b9cb authored by W wz1qqx, committed by GitHub

[XPU]Add yolo box fuse pass && kernel (#54163)

Parent 6d3f56f3
......@@ -228,6 +228,7 @@ if(WITH_XPU)
SRCS xpu/pass_utils.cc
DEPS pass xpu_quant_utils)
set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils)
pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
/*
Fuse the following block in YOLO-like models into a single yolo_box_xpu op
------------------------------------------------------
sub block:
x
/ | \
/ | \
/ | \
slice slice slice
| | |
| | |
ew_mul ew_mul |
| | |
| | |
ew_sub ew_pow |
| | |
| | |
ew_add ew_mul_2 |
| | |
| | |
ew_mul_2 | |
\ | /
\ | /
\ | /
concat
|
y
------------------------------------------------------
After the pass is applied:
x
grid[left_ew_add_y] | offset[left_ew_sub_y]
\ | /
\ | /
stride[left_ew_mul_2_y] -- yolo_box_xpu --- anchor_grid[mid_ew_mul_2_y]
| \
| \
| \
y y_max
*/
struct YoloBoxXPUPattern : public PatternBase {
YoloBoxXPUPattern(PDPattern* pattern,
const std::string& name_scope,
bool with_left_ew_sub_);
// declare operator node's name
PATTERN_DECL_NODE(left_slice);
PATTERN_DECL_NODE(mid_slice);
PATTERN_DECL_NODE(right_slice);
PATTERN_DECL_NODE(left_ew_mul);
PATTERN_DECL_NODE(left_ew_sub);
PATTERN_DECL_NODE(left_ew_add);
PATTERN_DECL_NODE(left_ew_mul_2);
PATTERN_DECL_NODE(mid_ew_mul);
PATTERN_DECL_NODE(mid_ew_pow);
PATTERN_DECL_NODE(mid_ew_mul_2);
PATTERN_DECL_NODE(concat);
// declare variable node's name
PATTERN_DECL_NODE(x);
PATTERN_DECL_NODE(left_slice_out);
PATTERN_DECL_NODE(left_ew_mul_out);
PATTERN_DECL_NODE(left_ew_sub_y);
PATTERN_DECL_NODE(left_ew_sub_out);
PATTERN_DECL_NODE(left_ew_add_y);
PATTERN_DECL_NODE(left_ew_add_out);
PATTERN_DECL_NODE(left_ew_mul_2_y);
PATTERN_DECL_NODE(left_ew_mul_2_out);
PATTERN_DECL_NODE(mid_slice_out);
PATTERN_DECL_NODE(mid_ew_mul_out);
PATTERN_DECL_NODE(mid_ew_pow_out);
PATTERN_DECL_NODE(mid_ew_mul_2_y);
PATTERN_DECL_NODE(mid_ew_mul_2_out);
PATTERN_DECL_NODE(right_slice_out);
PATTERN_DECL_NODE(concat_out);
private:
bool with_left_ew_sub_{true};
};
YoloBoxXPUPattern::YoloBoxXPUPattern(PDPattern* pattern,
const std::string& name_scope,
bool with_left_ew_sub)
: PatternBase(pattern, name_scope, name_scope),
with_left_ew_sub_(with_left_ew_sub) {
auto x = pattern->NewNode(x_repr())
->assert_is_op_output("sigmoid", "Out")
->assert_has_n_outputs(3);
auto* left_slice =
pattern->NewNode(left_slice_repr())
->assert_is_op("strided_slice")
->assert_more([&](Node* node) {
auto* op_desc = node->Op();
return op_desc->GetAttrIfExists<std::vector<int>>("axes") ==
std::vector<int>{4} &&
op_desc->GetAttrIfExists<std::vector<int>>("starts") ==
std::vector<int>{0} &&
op_desc->GetAttrIfExists<std::vector<int>>("ends") ==
std::vector<int>{2};
});
auto* left_slice_out = pattern->NewNode(left_slice_out_repr())
->assert_is_op_output("strided_slice", "Out")
->assert_is_op_input("elementwise_mul", "X");
left_slice->LinksFrom({x}).LinksTo({left_slice_out});
auto* mid_slice =
pattern->NewNode(mid_slice_repr())
->assert_is_op("strided_slice")
->assert_more([&](Node* node) {
auto* op_desc = node->Op();
return op_desc->GetAttrIfExists<std::vector<int>>("axes") ==
std::vector<int>{4} &&
op_desc->GetAttrIfExists<std::vector<int>>("starts") ==
std::vector<int>{2} &&
op_desc->GetAttrIfExists<std::vector<int>>("ends") ==
std::vector<int>{4};
});
auto* mid_slice_out = pattern->NewNode(mid_slice_out_repr())
->assert_is_op_output("strided_slice", "Out")
->assert_is_op_input("elementwise_mul", "X");
mid_slice->LinksFrom({x}).LinksTo({mid_slice_out});
auto* right_slice =
pattern->NewNode(right_slice_repr())
->assert_is_op("strided_slice")
->assert_more([&](Node* node) {
auto* op_desc = node->Op();
return op_desc->GetAttrIfExists<std::vector<int>>("axes") ==
std::vector<int>{4} &&
op_desc->GetAttrIfExists<std::vector<int>>("starts") ==
std::vector<int>{4} &&
op_desc->GetAttrIfExists<std::vector<int>>("ends") ==
std::vector<int>{2147483647};
});
auto* right_slice_out = pattern->NewNode(right_slice_out_repr())
->assert_is_op_output("strided_slice", "Out")
->assert_is_op_nth_input("concat", "X", 2);
right_slice->LinksFrom({x}).LinksTo({right_slice_out});
// left slice pattern
auto* left_ew_mul =
pattern->NewNode(left_ew_mul_repr())
->assert_is_op("elementwise_mul")
->assert_more([&](Node* node) {
auto next_op_nodes = node->outputs[0]->outputs;
return next_op_nodes.size() == 1 &&
(next_op_nodes[0]->Op()->Type() == "elementwise_sub" ||
next_op_nodes[0]->Op()->Type() == "elementwise_add");
});
auto* left_ew_mul_out = pattern->NewNode(left_ew_mul_out_repr())
->assert_is_op_output("elementwise_mul", "Out");
left_ew_mul->LinksFrom({left_slice_out}).LinksTo({left_ew_mul_out});
PDNode* left_ew_sub = nullptr;
PDNode* left_ew_sub_y = nullptr;
PDNode* left_ew_sub_out = nullptr;
if (with_left_ew_sub_) {
left_ew_mul_out->assert_is_op_input("elementwise_sub", "X");
left_ew_sub =
pattern->NewNode(left_ew_sub_repr())->assert_is_op("elementwise_sub");
left_ew_sub_y = pattern->NewNode(left_ew_sub_y_repr())
->assert_is_op_input("elementwise_sub", "Y")
->assert_is_persistable_var();
left_ew_sub_out = pattern->NewNode(left_ew_sub_out_repr())
->assert_is_op_output("elementwise_sub", "Out");
left_ew_sub->LinksFrom({left_ew_mul_out, left_ew_sub_y})
.LinksTo({left_ew_sub_out});
} else {
left_ew_sub_out = left_ew_mul_out;
}
left_ew_sub_out->assert_is_op_input("elementwise_add", "X");
auto* left_ew_add =
pattern->NewNode(left_ew_add_repr())->assert_is_op("elementwise_add");
auto* left_ew_add_y = pattern->NewNode(left_ew_add_y_repr())
->assert_is_op_input("elementwise_add", "Y");
auto* left_ew_add_out = pattern->NewNode(left_ew_add_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("elementwise_mul", "X");
left_ew_add->LinksFrom({left_ew_sub_out, left_ew_add_y})
.LinksTo({left_ew_add_out});
auto* left_ew_mul_2 =
pattern->NewNode(left_ew_mul_2_repr())
->assert_is_op("elementwise_mul")
->assert_more([&](Node* node) {
auto pre_op_nodes = node->inputs[0]->inputs;
return pre_op_nodes.size() == 1 &&
pre_op_nodes[0]->Op()->Type() == "elementwise_add";
});
auto* left_ew_mul_2_y = pattern->NewNode(left_ew_mul_2_y_repr())
->assert_is_op_input("elementwise_mul", "Y");
auto* left_ew_mul_2_out = pattern->NewNode(left_ew_mul_2_out_repr())
->assert_is_op_output("elementwise_mul", "Out")
->assert_is_op_nth_input("concat", "X", 0);
left_ew_mul_2->LinksFrom({left_ew_add_out, left_ew_mul_2_y})
.LinksTo({left_ew_mul_2_out});
// mid slice pattern
auto* mid_ew_mul =
pattern->NewNode(mid_ew_mul_repr())
->assert_is_op("elementwise_mul")
->assert_more([&](Node* node) {
auto next_op_nodes = node->outputs[0]->outputs;
return next_op_nodes.size() == 1 &&
next_op_nodes[0]->Op()->Type() == "elementwise_pow";
});
auto* mid_ew_mul_out = pattern->NewNode(mid_ew_mul_out_repr())
->assert_is_op_output("elementwise_mul", "Out")
->assert_is_op_input("elementwise_pow", "X");
mid_ew_mul->LinksFrom({mid_slice_out}).LinksTo({mid_ew_mul_out});
auto* mid_ew_pow =
pattern->NewNode(mid_ew_pow_repr())->assert_is_op("elementwise_pow");
auto* mid_ew_pow_out = pattern->NewNode(mid_ew_pow_out_repr())
->assert_is_op_output("elementwise_pow", "Out")
->assert_is_op_input("elementwise_mul", "X");
mid_ew_pow->LinksFrom({mid_ew_mul_out}).LinksTo({mid_ew_pow_out});
auto* mid_ew_mul_2 =
pattern->NewNode(mid_ew_mul_2_repr())
->assert_is_op("elementwise_mul")
->assert_more([&](Node* node) {
auto pre_op_nodes = node->inputs[0]->inputs;
return pre_op_nodes.size() == 1 &&
pre_op_nodes[0]->Op()->Type() == "elementwise_pow";
});
auto* mid_ew_mul_2_y = pattern->NewNode(mid_ew_mul_2_y_repr())
->assert_is_op_input("elementwise_mul", "Y");
auto* mid_ew_mul_2_out = pattern->NewNode(mid_ew_mul_2_out_repr())
->assert_is_op_output("elementwise_mul", "Out")
->assert_is_op_nth_input("concat", "X", 1);
mid_ew_mul_2->LinksFrom({mid_ew_pow_out, mid_ew_mul_2_y})
.LinksTo({mid_ew_mul_2_out});
// concat
auto* concat = pattern->NewNode(concat_repr())->assert_is_op("concat");
auto* concat_out = pattern->NewNode(concat_out_repr())
->assert_is_op_output("concat", "Out")
->AsOutput();
concat->LinksFrom({left_ew_mul_2_out, mid_ew_mul_2_out, right_slice_out})
.LinksTo({concat_out});
}
} // namespace patterns
class YoloBoxXPUFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
int ApplyImpl(ir::Graph* graph, bool with_left_ew_sub) const;
const std::string name_scope_{"yolo_box_xpu_fuse_pass"};
};
void YoloBoxXPUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
int found_subgraph_count = 0;
for (auto with_left_ew_sub : {true, false}) {
found_subgraph_count += ApplyImpl(graph, with_left_ew_sub);
}
AddStatis(found_subgraph_count);
}
int YoloBoxXPUFusePass::ApplyImpl(ir::Graph* graph,
bool with_left_ew_sub) const {
GraphPatternDetector gpd;
patterns::YoloBoxXPUPattern pattern(
gpd.mutable_pattern(), name_scope_, with_left_ew_sub);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle YoloBoxXPUFusePass fuse";
// declare operator node's name
GET_IR_NODE(left_slice);
GET_IR_NODE(left_ew_mul);
GET_IR_NODE(left_ew_sub);
GET_IR_NODE(left_ew_add);
GET_IR_NODE(left_ew_mul_2);
GET_IR_NODE(mid_slice);
GET_IR_NODE(mid_ew_mul);
GET_IR_NODE(mid_ew_pow);
GET_IR_NODE(mid_ew_mul_2);
GET_IR_NODE(right_slice);
GET_IR_NODE(concat);
// declare variable node's name
GET_IR_NODE(x);
GET_IR_NODE(left_slice_out);
GET_IR_NODE(left_ew_mul_out);
GET_IR_NODE(left_ew_sub_y);
GET_IR_NODE(left_ew_sub_out);
GET_IR_NODE(left_ew_add_y);
GET_IR_NODE(left_ew_add_out);
GET_IR_NODE(left_ew_mul_2_y);
GET_IR_NODE(left_ew_mul_2_out);
GET_IR_NODE(mid_slice_out);
GET_IR_NODE(mid_ew_mul_out);
GET_IR_NODE(mid_ew_pow_out);
GET_IR_NODE(mid_ew_mul_2_y);
GET_IR_NODE(mid_ew_mul_2_out);
GET_IR_NODE(right_slice_out);
GET_IR_NODE(concat_out);
auto* block = concat->Op()->Block();
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
std::string fused_op_out_name = concat_out->Name();
std::string fused_op_out_max_name = fused_op_out_name + "_max";
VarDesc fused_op_out_max_desc(fused_op_out_max_name);
Node* fused_op_out_max = graph->CreateVarNode(&fused_op_out_max_desc);
// Generate yolo_box_xpu fused op
framework::OpDesc fused_op_desc(block);
fused_op_desc.SetType("yolo_box_xpu");
// set inputs for fused op
fused_op_desc.SetInput("x", {x->Name()});
fused_op_desc.SetInput("grid", {left_ew_add_y->Name()});
fused_op_desc.SetInput("stride", {left_ew_mul_2_y->Name()});
fused_op_desc.SetInput("anchor_grid", {mid_ew_mul_2_y->Name()});
float offset_ = 0.f;
if (with_left_ew_sub) {
const auto& left_ew_sub_y_t =
scope->FindVar(left_ew_sub_y->Name())->Get<phi::DenseTensor>();
auto left_ew_sub_y_dims = left_ew_sub_y_t.dims();
PADDLE_ENFORCE_EQ(left_ew_sub_y_dims.size(),
1,
platform::errors::InvalidArgument(
"The rank (%d) of the left elementwise_sub Y tensor "
"must be 1.",
left_ew_sub_y_dims.size()));
auto tensor_type = left_ew_sub_y_t.dtype();
if (tensor_type == phi::DataType::FLOAT16) {
auto* sub_t_fp16_ptr = left_ew_sub_y_t.data<platform::float16>();
offset_ = static_cast<float>(sub_t_fp16_ptr[0]);
} else if (tensor_type == phi::DataType::FLOAT32) {
auto* sub_t_fp32_ptr = left_ew_sub_y_t.data<float>();
offset_ = sub_t_fp32_ptr[0];
} else {
PADDLE_THROW(platform::errors::Unavailable(
"yolo_box_xpu_fuse_pass does not support this weight dtype; "
"only fp32/fp16 are supported."));
}
}
fused_op_desc.SetAttr("offset", offset_);
fused_op_desc.SetOutput("out", {concat_out->Name()});
fused_op_desc.SetOutput("out_max", {fused_op_out_max_name});
// relink fused op
auto* fused_op = graph->CreateOpNode(&fused_op_desc);
IR_NODE_LINK_TO(x, fused_op);
IR_NODE_LINK_TO(left_ew_add_y, fused_op);
IR_NODE_LINK_TO(left_ew_mul_2_y, fused_op);
IR_NODE_LINK_TO(mid_ew_mul_2_y, fused_op);
IR_NODE_LINK_TO(fused_op, concat_out);
IR_NODE_LINK_TO(fused_op, fused_op_out_max);
// delete useless nodes
std::unordered_set<const Node*> delete_nodes = {left_slice,
left_slice_out,
left_ew_mul,
left_ew_mul_out,
left_ew_add,
left_ew_add_out,
left_ew_mul_2,
left_ew_mul_2_out,
mid_slice,
mid_slice_out,
mid_ew_mul,
mid_ew_mul_out,
mid_ew_pow,
mid_ew_pow_out,
mid_ew_mul_2,
mid_ew_mul_2_out,
right_slice,
right_slice_out,
concat};
if (with_left_ew_sub) {
delete_nodes.insert(left_ew_sub);
delete_nodes.insert(left_ew_sub_out);
}
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(yolo_box_xpu_fuse_pass,
paddle::framework::ir::YoloBoxXPUFusePass);
REGISTER_PASS_CAPABILITY(yolo_box_xpu_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"yolo_box_xpu", 0));
......@@ -529,6 +529,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"fc_xpu_fuse_pass",
"conv2d_xpu_fuse_pass",
"add_activation_xpu_fuse_pass",
"yolo_box_xpu_fuse_pass",
"link_xpu_op_max_pass",
"inplace_op_var_pass",
"delete_isolated_node_pass",
......
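Because the pass is appended to XpuPassStrategy, it runs automatically when XPU inference is enabled; no per-model setup is needed. A minimal usage sketch, assuming the Paddle Inference C++ API ("./model_dir" is a placeholder path, and the exact EnableXpu arguments vary across Paddle versions):

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config("./model_dir");  // placeholder model directory
  config.EnableXpu();  // selects XpuPassStrategy, which now runs
                       // yolo_box_xpu_fuse_pass during optimization
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}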
......@@ -96,3 +96,13 @@
func : multi_encoder_xpu
data_type : x
optional : mask, seq_lod, max_seq_len, x_fp16, out_fp16
- op : yolo_box_xpu
args : (Tensor x, Tensor x_max, Tensor grid, Tensor stride, Tensor anchor_grid, float offset)
output : Tensor(out), Tensor(out_max)
infer_meta :
func : YoloBoxXPUInferMeta
kernel :
func : yolo_box_xpu
data_type : x
optional : x_max
......@@ -937,6 +937,8 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
{"isnan_v2",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"yolo_box_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
// AddMore
{"sequence_conv", XPUKernelSet({phi::DataType::FLOAT32})},
......
......@@ -20,20 +20,16 @@ limitations under the License. */
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"
namespace phi {
void AddActXPUInferMeta(const MetaTensor& x,
const MetaTensor& x_max,
const MetaTensor& y,
const MetaTensor& y_max,
int act_type,
MetaTensor* out,
MetaTensor* out_max) {
int axis = -1;
if (x.dims() != y.dims()) {
auto x_dims = x.dims();
auto y_dims = y.dims();
static phi::DDim BroadCastInferShape(const DDim x_dims,
const DDim y_dims,
int axis) {
std::vector<int> out_dims_array(x_dims.size(), -1);
if (x_dims != y_dims) {
int max_dim = std::max(x_dims.size(), y_dims.size());
if (x_dims.size() == y_dims.size()) {
PADDLE_ENFORCE_EQ((axis == -1) || (axis == 0),
......@@ -58,7 +54,7 @@ void AddActXPUInferMeta(const MetaTensor& x,
: axis);
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
out_dims_array.resize(max_dim);
funcs::GetBroadcastDimsArrays(x_dims,
y_dims,
x_dims_array.data(),
......@@ -66,10 +62,27 @@ void AddActXPUInferMeta(const MetaTensor& x,
out_dims_array.data(),
max_dim,
axis);
auto out_dims = phi::make_ddim(out_dims_array);
return phi::make_ddim(out_dims_array);
}
return x_dims;
}
void AddActXPUInferMeta(const MetaTensor& x,
const MetaTensor& x_max,
const MetaTensor& y,
const MetaTensor& y_max,
int act_type,
MetaTensor* out,
MetaTensor* out_max) {
int axis = -1;
auto x_dims = x.dims();
auto y_dims = y.dims();
if (x_dims != y_dims) {
auto out_dims = BroadCastInferShape(x_dims, y_dims, axis);
out->set_dims(out_dims);
} else {
out->set_dims(x.dims());
out->set_dims(x_dims);
}
out->set_dtype(x.dtype());
out->set_layout(x.layout());
......@@ -411,4 +424,98 @@ void FusedMultiTransformerXpuInferMeta(
out->set_layout(x.layout());
}
void YoloBoxXPUInferMeta(const MetaTensor& x,
const MetaTensor& x_max,
const MetaTensor& grid,
const MetaTensor& stride,
const MetaTensor& anchor_grid,
float offset,
MetaTensor* out,
MetaTensor* out_max) {
auto x_dims = x.dims();
auto x_dims_size = x_dims.size();
PADDLE_ENFORCE_GT(
x_dims[x_dims_size - 1],
4,
phi::errors::InvalidArgument(
"The last dim of x must be larger than 4, but received %d.",
x_dims[x_dims_size - 1]));
// compute left out_dims
// y[..., 0:2] = (x[..., 0:2] * 2 - offset + self.grid[i]) * self.stride[i] # xy
std::vector<int> axes_ = {x_dims_size - 1};
std::vector<int> infer_flags_ = {1};
std::vector<int> decrease_axis_ = {-1};
std::vector<int64_t> strides_ = {1};
std::vector<int64_t> starts_l = {0};
std::vector<int64_t> ends_l = {2};
std::vector<int64_t> left_slice_out_dims_vector(x_dims_size, -1);
phi::funcs::StridedSliceOutDims(starts_l,
ends_l,
strides_,
axes_,
infer_flags_,
x_dims,
decrease_axis_,
left_slice_out_dims_vector.data(),
1,
true);
auto left_slice_out_dims = phi::make_ddim(left_slice_out_dims_vector);
auto grid_dims = grid.dims();
auto left_add_out_dims =
BroadCastInferShape(left_slice_out_dims, grid_dims, -1);
auto stride_dims = stride.dims();
auto left_mul_out_dims =
BroadCastInferShape(left_add_out_dims, stride_dims, -1);
// compute mid out_dims
// wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
std::vector<int64_t> starts_m = {2};
std::vector<int64_t> ends_m = {4};
std::vector<int64_t> mid_slice_out_dims_vector(x_dims_size, -1);
phi::funcs::StridedSliceOutDims(starts_m,
ends_m,
strides_,
axes_,
infer_flags_,
x_dims,
decrease_axis_,
mid_slice_out_dims_vector.data(),
1,
true);
auto mid_slice_out_dims = phi::make_ddim(mid_slice_out_dims_vector);
auto anchor_grid_dims = anchor_grid.dims();
auto mid_mul_out_dims =
BroadCastInferShape(mid_slice_out_dims, anchor_grid_dims, -1);
// compute right out_dims
std::vector<int64_t> starts_r = {4};
std::vector<int64_t> ends_r = {2147483647};
std::vector<int64_t> right_slice_out_dims_vector(x_dims_size, -1);
phi::funcs::StridedSliceOutDims(starts_r,
ends_r,
strides_,
axes_,
infer_flags_,
x_dims,
decrease_axis_,
right_slice_out_dims_vector.data(),
1,
true);
auto right_slice_out_dims = phi::make_ddim(right_slice_out_dims_vector);
// compute concat out_dims
std::vector<phi::DDim> in_dims;
in_dims.reserve(3);
in_dims.emplace_back(left_mul_out_dims);
in_dims.emplace_back(mid_mul_out_dims);
in_dims.emplace_back(right_slice_out_dims);
phi::DDim out_dim =
phi::funcs::ComputeAndCheckShape(false, in_dims, x_dims_size - 1);
out->set_dims(out_dim);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
out_max->set_dims(phi::make_ddim({6}));
out_max->set_dtype(x.dtype());
out_max->set_layout(x.layout());
}
} // namespace phi
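The shape math above leans on BroadCastInferShape, which for axis == -1 reduces to right-aligned numpy-style broadcasting via funcs::GetBroadcastDimsArrays. A standalone sketch of that rule (broadcast_dims is illustrative, not a Paddle API; -1 marks an unknown extent, as in the slice output dims above):

#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> broadcast_dims(std::vector<int64_t> x,
                                    std::vector<int64_t> y) {
  const size_t n = std::max(x.size(), y.size());
  // Right-align the shapes by left-padding the shorter one with 1s.
  x.insert(x.begin(), n - x.size(), 1);
  y.insert(y.begin(), n - y.size(), 1);
  std::vector<int64_t> out(n);
  for (size_t i = 0; i < n; ++i) {
    if (x[i] == -1 || y[i] == -1) {
      out[i] = -1;  // unknown extents propagate
    } else if (x[i] == y[i] || x[i] == 1 || y[i] == 1) {
      out[i] = std::max(x[i], y[i]);
    } else {
      throw std::invalid_argument("incompatible broadcast dims");
    }
  }
  return out;
}
// e.g. broadcast_dims({-1, 3, 80, 80, 2}, {1, 3, 80, 80, 2})
//   -> {-1, 3, 80, 80, 2}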
......@@ -135,4 +135,14 @@ void FusedMultiTransformerXpuInferMeta(
int gather_axis,
MetaTensor* out,
std::vector<MetaTensor*> cache_kv_out);
void YoloBoxXPUInferMeta(const MetaTensor& x,
const MetaTensor& x_max,
const MetaTensor& grid,
const MetaTensor& stride,
const MetaTensor& anchor_grid,
float offset,
MetaTensor* out,
MetaTensor* out_max);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void YoloBoxXPUKernel(const Context& ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& x_max,
const DenseTensor& grid,
const DenseTensor& stride,
const DenseTensor& anchor_grid,
float offset,
DenseTensor* out,
DenseTensor* out_max) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
auto* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out));
// x_max is optional and unused in this kernel; yolo_box_coord gets nullptr
float* x_max_data = nullptr;
const float* grid_data;
const float* stride_data;
const float* anchor_grid_data;
// for fp16 models, cast grid/stride/anchor_grid to fp32 to keep full precision
if (std::is_same<T, phi::dtype::float16>::value) {
DenseTensor grid_data_fp32_t;
DenseTensor stride_data_fp32_t;
DenseTensor anchor_grid_data_fp32_t;
ctx.template Alloc<float>(&grid_data_fp32_t, grid.numel() * sizeof(float));
int r1 = xpu::cast<XPUType, float>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(grid.data<T>()),
grid_data_fp32_t.data<float>(),
grid.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r1, "cast");
ctx.template Alloc<float>(&stride_data_fp32_t,
stride.numel() * sizeof(float));
int r2 = xpu::cast<XPUType, float>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(stride.data<T>()),
stride_data_fp32_t.data<float>(),
stride.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r2, "cast");
ctx.template Alloc<float>(&anchor_grid_data_fp32_t,
anchor_grid.numel() * sizeof(float));
int r3 = xpu::cast<XPUType, float>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(anchor_grid.data<T>()),
anchor_grid_data_fp32_t.data<float>(),
anchor_grid.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r3, "cast");
grid_data = grid_data_fp32_t.data<float>();
stride_data = stride_data_fp32_t.data<float>();
anchor_grid_data = anchor_grid_data_fp32_t.data<float>();
} else {
grid_data = grid.data<float>();
stride_data = stride.data<float>();
anchor_grid_data = anchor_grid.data<float>();
}
std::vector<int64_t> x_shape = phi::vectorize(x.dims());
std::vector<int64_t> grid_shape = phi::vectorize(grid.dims());
std::vector<int64_t> stride_shape = phi::vectorize(stride.dims());
std::vector<int64_t> anchor_grid_shape = phi::vectorize(anchor_grid.dims());
// yolo_box_coord only supports fp32 and fp16 precision
int r = xpu::yolo_box_coord<XPUType>(
/* baidu::xpu::api::Context* ctx */ ctx.x_context(),
/* const T* x */ x_data,
/* T* y */ out_data,
/* const std::vector<int64_t>& x_shape */ x_shape,
/* const float* grid */ grid_data,
/* const float* stride */ stride_data,
/* const float* anchor_grid */ anchor_grid_data,
/* const std::vector<int64_t>& grid_shape */ grid_shape,
/* const std::vector<int64_t>& stride_shape */ stride_shape,
/* const std::vector<int64_t>& anchor_grid_shape */ anchor_grid_shape,
/* float offset */ offset,
/* float* x_max */ x_max_data,
/* float* y_max */ ctx.template Alloc<float>(out_max));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "yolo_box_xpu");
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(yolo_box_xpu,
XPU,
ALL_LAYOUT,
phi::fusion::YoloBoxXPUKernel,
float,
phi::dtype::float16) {}
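A reference-only sketch of the mixed-precision strategy in the kernel above: fp16 parameters are widened to fp32 once before the compute call, so the coordinate math always runs in fp32. The kernel does this on device with xpu::cast; half_bits_to_float below is a simplified host-side stand-in that handles only normal and zero values (phi::dtype::float16 covers subnormals, inf and NaN properly).

#include <cmath>
#include <cstdint>
#include <vector>

// IEEE binary16 -> binary32 for normal and zero values only.
inline float half_bits_to_float(uint16_t h) {
  const int sign = (h >> 15) & 0x1;
  const int exp = (h >> 10) & 0x1f;
  const int mant = h & 0x3ff;
  if (exp == 0) return sign ? -0.f : 0.f;  // subnormals flushed to zero here
  const float v = std::ldexp(1.f + mant / 1024.f, exp - 15);
  return sign ? -v : v;
}

// Widen an fp16 buffer to fp32 once, mirroring the three xpu::cast calls.
std::vector<float> widen_fp16(const std::vector<uint16_t>& in) {
  std::vector<float> out(in.size());
  for (size_t i = 0; i < in.size(); ++i) out[i] = half_bits_to_float(in[i]);
  return out;
}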