未验证 提交 6b58de95 编写于 作者: Z zhupengyang 提交者: GitHub

add yolo_box_fuse_pass, yolo_box_head_op, yolo_box_post_op (#42641)

上级 a51c492c
......@@ -95,6 +95,7 @@ pass_library(skip_layernorm_fuse_pass base)
pass_library(multihead_matmul_fuse_pass inference)
pass_library(adaptive_pool2d_convert_global_pass inference)
pass_library(unsqueeze2_eltwise_fuse_pass inference)
pass_library(yolo_box_fuse_pass inference)
pass_library(layer_norm_fuse_pass inference)
pass_library(add_support_int8_pass inference)
pass_library(matmul_scale_fuse_pass inference)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/yolo_box_fuse_pass.h"
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
namespace patterns {
struct YoloBoxPattern : public PatternBase {
YoloBoxPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
// elementwise_div pattern
auto* elt_div_in_x = pattern->NewNode(elt_div_in_x_repr())
->assert_is_op_input("elementwise_div", "X");
auto* elt_div_in_y = pattern->NewNode(elt_div_in_y_repr())
->assert_is_op_input("elementwise_div", "Y");
auto* elt_div =
pattern->NewNode(elt_div_repr())->assert_is_op("elementwise_div");
auto* elt_div_out = pattern->NewNode(elt_div_out_repr())
->assert_is_op_output("elementwise_div", "Out")
->assert_is_op_input("cast", "X");
elt_div->LinksFrom({elt_div_in_x, elt_div_in_y}).LinksTo({elt_div_out});
// cast pattern
auto* cast = pattern->NewNode(cast_repr())->assert_is_op("cast");
auto* cast_out = pattern->NewNode(cast_out_repr())
->assert_is_op_output("cast", "Out")
->assert_is_op_input("yolo_box", "ImgSize");
cast->LinksFrom({elt_div_out}).LinksTo({cast_out});
// 3 * (yolo_box + transpose) pattern
#define YOLO_BOX_TRANSPOSE_PATTERN(idx_) \
auto* yolo_box##idx_##_in_x = pattern->NewNode(yolo_box##idx_##_in_x_repr()) \
->assert_is_op_input("yolo_box", "X"); \
auto* yolo_box##idx_ = \
pattern->NewNode(yolo_box##idx_##_repr())->assert_is_op("yolo_box"); \
auto* yolo_box##idx_##_out_boxes = \
pattern->NewNode(yolo_box##idx_##_out_boxes_repr()) \
->assert_is_op_output("yolo_box", "Boxes") \
->assert_is_op_nth_input("concat", "X", idx_); \
auto* yolo_box##idx_##_out_scores = \
pattern->NewNode(yolo_box##idx_##_out_scores_repr()) \
->assert_is_op_output("yolo_box", "Scores") \
->assert_is_op_input("transpose2", "X"); \
yolo_box##idx_->LinksFrom({yolo_box##idx_##_in_x, cast_out}) \
.LinksTo({yolo_box##idx_##_out_boxes, yolo_box##idx_##_out_scores}); \
auto* transpose##idx_ = \
pattern->NewNode(transpose##idx_##_repr())->assert_is_op("transpose2"); \
auto* transpose##idx_##_out = \
pattern->NewNode(transpose##idx_##_out_repr()) \
->assert_is_op_output("transpose2", "Out") \
->assert_is_op_nth_input("concat", "X", idx_); \
auto* transpose##idx_##_out_xshape = \
pattern->NewNode(transpose##idx_##_out_xshape_repr()) \
->assert_is_op_output("transpose2", "XShape"); \
transpose##idx_->LinksFrom({yolo_box##idx_##_out_scores}) \
.LinksTo({transpose##idx_##_out, transpose##idx_##_out_xshape});
YOLO_BOX_TRANSPOSE_PATTERN(0);
YOLO_BOX_TRANSPOSE_PATTERN(1);
YOLO_BOX_TRANSPOSE_PATTERN(2);
#undef YOLO_BOX_TRANSPOSE_PATTERN
// concat0 pattern
auto* concat0 = pattern->NewNode(concat0_repr())->assert_is_op("concat");
auto* concat0_out = pattern->NewNode(concat0_out_repr())
->assert_is_op_output("concat", "Out")
->assert_is_op_input("multiclass_nms3", "BBoxes");
concat0
->LinksFrom(
{yolo_box0_out_boxes, yolo_box1_out_boxes, yolo_box2_out_boxes})
.LinksTo({concat0_out});
// concat1 pattern
auto* concat1 = pattern->NewNode(concat1_repr())->assert_is_op("concat");
auto* concat1_out = pattern->NewNode(concat1_out_repr())
->assert_is_op_output("concat", "Out")
->assert_is_op_input("multiclass_nms3", "Scores");
concat1->LinksFrom({transpose0_out, transpose1_out, transpose2_out})
.LinksTo({concat1_out});
// nms pattern
auto* nms = pattern->NewNode(nms_repr())->assert_is_op("multiclass_nms3");
auto* nms_out = pattern->NewNode(nms_out_repr())
->assert_is_op_output("multiclass_nms3", "Out");
auto* nms_out_index = pattern->NewNode(nms_out_index_repr())
->assert_is_op_output("multiclass_nms3", "Index");
auto* nms_out_rois_num =
pattern->NewNode(nms_out_rois_num_repr())
->assert_is_op_output("multiclass_nms3", "NmsRoisNum");
nms->LinksFrom({concat0_out, concat1_out})
.LinksTo({nms_out, nms_out_index, nms_out_rois_num});
}
// declare operator node's name
PATTERN_DECL_NODE(elt_div);
PATTERN_DECL_NODE(cast);
PATTERN_DECL_NODE(yolo_box0);
PATTERN_DECL_NODE(yolo_box1);
PATTERN_DECL_NODE(yolo_box2);
PATTERN_DECL_NODE(concat0);
PATTERN_DECL_NODE(transpose0);
PATTERN_DECL_NODE(transpose1);
PATTERN_DECL_NODE(transpose2);
PATTERN_DECL_NODE(concat1);
PATTERN_DECL_NODE(nms);
// declare variable node's name
PATTERN_DECL_NODE(elt_div_in_x);
PATTERN_DECL_NODE(elt_div_in_y);
PATTERN_DECL_NODE(elt_div_out);
PATTERN_DECL_NODE(cast_out);
PATTERN_DECL_NODE(yolo_box0_in_x);
PATTERN_DECL_NODE(yolo_box1_in_x);
PATTERN_DECL_NODE(yolo_box2_in_x);
PATTERN_DECL_NODE(yolo_box0_out_boxes);
PATTERN_DECL_NODE(yolo_box1_out_boxes);
PATTERN_DECL_NODE(yolo_box2_out_boxes);
PATTERN_DECL_NODE(yolo_box0_out_scores);
PATTERN_DECL_NODE(yolo_box1_out_scores);
PATTERN_DECL_NODE(yolo_box2_out_scores);
PATTERN_DECL_NODE(concat0_out);
PATTERN_DECL_NODE(transpose0_out);
PATTERN_DECL_NODE(transpose1_out);
PATTERN_DECL_NODE(transpose2_out);
PATTERN_DECL_NODE(transpose0_out_xshape);
PATTERN_DECL_NODE(transpose1_out_xshape);
PATTERN_DECL_NODE(transpose2_out_xshape);
PATTERN_DECL_NODE(concat1_out);
PATTERN_DECL_NODE(nms_out);
PATTERN_DECL_NODE(nms_out_index);
PATTERN_DECL_NODE(nms_out_rois_num);
};
} // namespace patterns
YoloBoxFusePass::YoloBoxFusePass() {}
void YoloBoxFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
GraphPatternDetector gpd;
patterns::YoloBoxPattern yolo_box_pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle YoloBoxFusePass fuse";
#define GET_IR_NODE(node_) \
GET_IR_NODE_FROM_SUBGRAPH(node_, node_, yolo_box_pattern)
GET_IR_NODE(elt_div);
GET_IR_NODE(cast);
GET_IR_NODE(yolo_box0);
GET_IR_NODE(yolo_box1);
GET_IR_NODE(yolo_box2);
GET_IR_NODE(concat0);
GET_IR_NODE(transpose0);
GET_IR_NODE(transpose1);
GET_IR_NODE(transpose2);
GET_IR_NODE(concat1);
GET_IR_NODE(nms);
GET_IR_NODE(elt_div_in_x);
GET_IR_NODE(elt_div_in_y);
GET_IR_NODE(elt_div_out);
GET_IR_NODE(cast_out);
GET_IR_NODE(yolo_box0_in_x);
GET_IR_NODE(yolo_box1_in_x);
GET_IR_NODE(yolo_box2_in_x);
GET_IR_NODE(yolo_box0_out_boxes);
GET_IR_NODE(yolo_box1_out_boxes);
GET_IR_NODE(yolo_box2_out_boxes);
GET_IR_NODE(yolo_box0_out_scores);
GET_IR_NODE(yolo_box1_out_scores);
GET_IR_NODE(yolo_box2_out_scores);
GET_IR_NODE(concat0_out);
GET_IR_NODE(transpose0_out);
GET_IR_NODE(transpose1_out);
GET_IR_NODE(transpose2_out);
GET_IR_NODE(transpose0_out_xshape);
GET_IR_NODE(transpose1_out_xshape);
GET_IR_NODE(transpose2_out_xshape);
GET_IR_NODE(concat1_out);
GET_IR_NODE(nms_out);
GET_IR_NODE(nms_out_index);
GET_IR_NODE(nms_out_rois_num);
#undef GET_IR_NODE
// create yolo_box_head
#define CREATE_YOLO_BOX_HEAD(idx_) \
framework::OpDesc yolo_box_head##idx_##_op_desc; \
yolo_box_head##idx_##_op_desc.SetType("yolo_box_head"); \
yolo_box_head##idx_##_op_desc.SetInput("X", \
{yolo_box##idx_##_in_x->Name()}); \
yolo_box_head##idx_##_op_desc.SetAttr( \
"anchors", yolo_box##idx_->Op()->GetAttr("anchors")); \
yolo_box_head##idx_##_op_desc.SetAttr( \
"class_num", yolo_box##idx_->Op()->GetAttr("class_num")); \
yolo_box_head##idx_##_op_desc.SetOutput( \
"Out", {yolo_box##idx_##_out_boxes->Name()}); \
yolo_box_head##idx_##_op_desc.Flush(); \
auto* yolo_box_head##idx_ = \
graph->CreateOpNode(&yolo_box_head##idx_##_op_desc); \
IR_NODE_LINK_TO(yolo_box##idx_##_in_x, yolo_box_head##idx_); \
IR_NODE_LINK_TO(yolo_box_head##idx_, yolo_box##idx_##_out_boxes);
CREATE_YOLO_BOX_HEAD(0);
CREATE_YOLO_BOX_HEAD(1);
CREATE_YOLO_BOX_HEAD(2);
#undef CREATE_YOLO_BOX_HEAD
// create yolo_box_post
framework::OpDesc yolo_box_post_op_desc;
yolo_box_post_op_desc.SetType("yolo_box_post");
yolo_box_post_op_desc.SetInput("Boxes0", {yolo_box0_out_boxes->Name()});
yolo_box_post_op_desc.SetInput("Boxes1", {yolo_box1_out_boxes->Name()});
yolo_box_post_op_desc.SetInput("Boxes2", {yolo_box2_out_boxes->Name()});
yolo_box_post_op_desc.SetInput("ImageShape", {elt_div_in_x->Name()});
yolo_box_post_op_desc.SetInput("ImageScale", {elt_div_in_y->Name()});
yolo_box_post_op_desc.SetAttr("anchors0",
yolo_box0->Op()->GetAttr("anchors"));
yolo_box_post_op_desc.SetAttr("anchors1",
yolo_box1->Op()->GetAttr("anchors"));
yolo_box_post_op_desc.SetAttr("anchors2",
yolo_box2->Op()->GetAttr("anchors"));
yolo_box_post_op_desc.SetAttr("class_num",
yolo_box0->Op()->GetAttr("class_num"));
yolo_box_post_op_desc.SetAttr("conf_thresh",
yolo_box0->Op()->GetAttr("conf_thresh"));
yolo_box_post_op_desc.SetAttr("downsample_ratio0",
yolo_box0->Op()->GetAttr("downsample_ratio"));
yolo_box_post_op_desc.SetAttr("downsample_ratio1",
yolo_box1->Op()->GetAttr("downsample_ratio"));
yolo_box_post_op_desc.SetAttr("downsample_ratio2",
yolo_box2->Op()->GetAttr("downsample_ratio"));
yolo_box_post_op_desc.SetAttr("clip_bbox",
yolo_box0->Op()->GetAttr("clip_bbox"));
yolo_box_post_op_desc.SetAttr("scale_x_y",
yolo_box0->Op()->GetAttr("scale_x_y"));
yolo_box_post_op_desc.SetAttr("nms_threshold",
nms->Op()->GetAttr("nms_threshold"));
yolo_box_post_op_desc.SetOutput("Out", {nms_out->Name()});
yolo_box_post_op_desc.SetOutput("NmsRoisNum", {nms_out_rois_num->Name()});
auto* yolo_box_post = graph->CreateOpNode(&yolo_box_post_op_desc);
IR_NODE_LINK_TO(yolo_box0_out_boxes, yolo_box_post);
IR_NODE_LINK_TO(yolo_box1_out_boxes, yolo_box_post);
IR_NODE_LINK_TO(yolo_box2_out_boxes, yolo_box_post);
IR_NODE_LINK_TO(elt_div_in_x, yolo_box_post);
IR_NODE_LINK_TO(elt_div_in_y, yolo_box_post);
IR_NODE_LINK_TO(yolo_box_post, nms_out);
IR_NODE_LINK_TO(yolo_box_post, nms_out_rois_num);
// delete useless node
GraphSafeRemoveNodes(graph, {elt_div,
cast,
yolo_box0,
yolo_box1,
yolo_box2,
concat0,
transpose0,
transpose1,
transpose2,
concat1,
nms,
elt_div_out,
cast_out,
yolo_box0_out_scores,
yolo_box1_out_scores,
yolo_box2_out_scores,
concat0_out,
transpose0_out,
transpose1_out,
transpose2_out,
transpose0_out_xshape,
transpose1_out_xshape,
transpose2_out_xshape,
concat1_out,
nms_out_index});
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(yolo_box_fuse_pass, paddle::framework::ir::YoloBoxFusePass);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
/*
1. before fuse:
div
|
cast-----------|-------------|
| | |
yolo_box yolo_box yolo_box
| | |
transpose-| transpose-| transpose-|
|------|-----|-------|------| |
| concat | |
|-----|-------|-------------|
| cocnat
|-------|
nms3
2. after fuse:
yolo_box_head yolo_box_head yolo_box_head
|------------------|------------------|
yolo_box_post
*/
class YoloBoxFusePass : public FusePassBase {
public:
YoloBoxFusePass();
virtual ~YoloBoxFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
std::string name_scope_{"yolo_box_fuse_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -1777,6 +1777,7 @@ USE_TRT_CONVERTER(clip);
USE_TRT_CONVERTER(gather);
USE_TRT_CONVERTER(anchor_generator);
USE_TRT_CONVERTER(yolo_box);
USE_TRT_CONVERTER(yolo_box_head);
USE_TRT_CONVERTER(roi_align);
USE_TRT_CONVERTER(affine_channel);
USE_TRT_CONVERTER(multiclass_nms);
......
......@@ -111,6 +111,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"conv_elementwise_add_fuse_pass", //
// "remove_padding_recover_padding_pass", //
// "delete_remove_padding_recover_padding_pass", //
// "yolo_box_fuse_pass", //
"tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
......
......@@ -36,6 +36,7 @@ nv_library(tensorrt_converter
gather_op.cc
anchor_generator_op.cc
yolo_box_op.cc
yolo_box_head_op.cc
roi_align_op.cc
affine_channel_op.cc
multiclass_nms_op.cc
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/yolo_box_head_op_plugin.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
class YoloBoxHeadOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(3) << "convert a yolo_box_head op to tensorrt plugin";
framework::OpDesc op_desc(op, nullptr);
auto* x_tensor = engine_->GetITensor(op_desc.Input("X").front());
std::vector<int> anchors =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("anchors"));
int class_num = BOOST_GET_CONST(int, op_desc.GetAttr("class_num"));
auto* yolo_box_plugin = new plugin::YoloBoxHeadPlugin(anchors, class_num);
std::vector<nvinfer1::ITensor*> yolo_box_inputs;
yolo_box_inputs.push_back(x_tensor);
auto* yolo_box_head_layer = engine_->network()->addPluginV2(
yolo_box_inputs.data(), yolo_box_inputs.size(), *yolo_box_plugin);
std::vector<std::string> output_names;
output_names.push_back(op_desc.Output("Out").front());
RreplenishLayerAndOutput(yolo_box_head_layer, "yolo_box_head", output_names,
test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(yolo_box_head, YoloBoxHeadOpConverter);
......@@ -100,6 +100,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"gather",
"gather_nd",
"yolo_box",
"yolo_box_head",
"roi_align",
"affine_channel",
"nearest_interp",
......@@ -165,6 +166,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"gather",
"gather_nd",
"yolo_box",
"yolo_box_head",
"roi_align",
"affine_channel",
"nearest_interp",
......@@ -634,6 +636,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (!has_attrs) return false;
}
if (op_type == "yolo_box_head") {
if (with_dynamic_shape) return false;
bool has_attrs = desc.HasAttr("class_num") && desc.HasAttr("anchors");
if (!has_attrs) return false;
}
if (op_type == "affine_channel") {
if (!desc.HasAttr("data_layout")) return false;
auto data_layout = framework::StringToDataLayout(
......
......@@ -7,6 +7,7 @@ nv_library(tensorrt_plugin
hard_swish_op_plugin.cu stack_op_plugin.cu special_slice_plugin.cu
anchor_generator_op_plugin.cu
yolo_box_op_plugin.cu
yolo_box_head_op_plugin.cu
roi_align_op_plugin.cu
gather_nd_op_plugin.cu
mish_op_plugin.cu
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/yolo_box_head_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
inline __device__ float SigmoidGPU(const float& x) {
return 1.0f / (1.0f + __expf(-x));
}
__global__ void YoloBoxHeadKernel(const float* input, float* output,
const int grid_size_x, const int grid_size_y,
const int class_num, const int anchors_num) {
int x_id = blockIdx.x * blockDim.x + threadIdx.x;
int y_id = blockIdx.y * blockDim.y + threadIdx.y;
int z_id = blockIdx.z * blockDim.z + threadIdx.z;
if ((x_id >= grid_size_x) || (y_id >= grid_size_y) || (z_id >= anchors_num)) {
return;
}
const int grids_num = grid_size_x * grid_size_y;
const int bbindex = y_id * grid_size_x + x_id;
// objectness
output[bbindex + grids_num * (z_id * (5 + class_num) + 4)] =
SigmoidGPU(input[bbindex + grids_num * (z_id * (5 + class_num) + 4)]);
// x
output[bbindex + grids_num * (z_id * (5 + class_num) + 0)] =
SigmoidGPU(input[bbindex + grids_num * (z_id * (5 + class_num) + 0)]);
// y
output[bbindex + grids_num * (z_id * (5 + class_num) + 1)] =
SigmoidGPU(input[bbindex + grids_num * (z_id * (5 + class_num) + 1)]);
// w
output[bbindex + grids_num * (z_id * (5 + class_num) + 2)] =
__expf(input[bbindex + grids_num * (z_id * (5 + class_num) + 2)]);
// h
output[bbindex + grids_num * (z_id * (5 + class_num) + 3)] =
__expf(input[bbindex + grids_num * (z_id * (5 + class_num) + 3)]);
// Probabilities of classes
for (int i = 0; i < class_num; ++i) {
output[bbindex + grids_num * (z_id * (5 + class_num) + (5 + i))] =
SigmoidGPU(
input[bbindex + grids_num * (z_id * (5 + class_num) + (5 + i))]);
}
}
int YoloBoxHeadPlugin::enqueue(int batch_size, const void* const* inputs,
#if IS_TRT_VERSION_LT(8000)
void** outputs,
#else
void* const* outputs,
#endif
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT {
const int h = input_dims_[0].d[1];
const int w = input_dims_[0].d[2];
const int grid_size_x = w;
const int grid_size_y = h;
const int anchors_num = anchors_.size() / 2;
const float* input_data = static_cast<const float*>(inputs[0]);
float* output_data = static_cast<float*>(outputs[0]);
const int volume = input_dims_[0].d[0] * h * w;
dim3 block(16, 16, 4);
dim3 grid((grid_size_x / block.x) + 1, (grid_size_y / block.y) + 1,
(anchors_num / block.z) + 1);
for (int n = 0; n < batch_size; n++) {
YoloBoxHeadKernel<<<grid, block, 0, stream>>>(
input_data + n * volume, output_data + n * volume, grid_size_x,
grid_size_y, class_num_, anchors_num);
}
return 0;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class YoloBoxHeadPlugin : public PluginTensorRT {
public:
explicit YoloBoxHeadPlugin(const std::vector<int>& anchors,
const int class_num)
: anchors_(anchors), class_num_(class_num) {}
YoloBoxHeadPlugin(const void* data, size_t length) {
deserializeBase(data, length);
DeserializeValue(&data, &length, &anchors_);
DeserializeValue(&data, &length, &class_num_);
}
~YoloBoxHeadPlugin() override{};
nvinfer1::IPluginV2* clone() const TRT_NOEXCEPT override {
return new YoloBoxHeadPlugin(anchors_, class_num_);
}
const char* getPluginType() const TRT_NOEXCEPT override {
return "yolo_box_head_plugin";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
int initialize() TRT_NOEXCEPT override { return 0; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
int nb_input_dims) TRT_NOEXCEPT override {
assert(index == 0);
assert(nb_input_dims == 1);
return inputs[0];
}
int enqueue(int batch_size, const void* const* inputs,
#if IS_TRT_VERSION_LT(8000)
void** outputs,
#else
void* const* outputs,
#endif
void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override {
return getBaseSerializationSize() + SerializedSize(anchors_) +
SerializedSize(class_num_);
}
void serialize(void* buffer) const TRT_NOEXCEPT override {
serializeBase(buffer);
SerializeValue(&buffer, anchors_);
SerializeValue(&buffer, class_num_);
}
private:
std::vector<int> anchors_;
int class_num_;
std::string namespace_;
};
class YoloBoxHeadPluginCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const TRT_NOEXCEPT override {
return "yolo_box_head_plugin";
}
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(
const char* name, const void* serial_data,
size_t serial_length) TRT_NOEXCEPT override {
return new YoloBoxHeadPlugin(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(YoloBoxHeadPluginCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -11,6 +11,8 @@ register_operators(EXCLUDES
fused_fc_elementwise_layernorm_op
multihead_matmul_op
skip_layernorm_op
yolo_box_head_op
yolo_box_post_op
fused_embedding_eltwise_layernorm_op
fusion_group_op
fusion_gru_op
......@@ -53,6 +55,8 @@ if (WITH_GPU OR WITH_ROCM)
# multihead_matmul_op
op_library(multihead_matmul_op)
op_library(skip_layernorm_op)
op_library(yolo_box_head_op)
op_library(yolo_box_post_op)
op_library(fused_embedding_eltwise_layernorm_op)
# fusion_group
if(NOT APPLE AND NOT WIN32)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class YoloBoxHeadOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "yolo_box_head");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "yolo_box_head");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
};
class YoloBoxHeadOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "The input tensor");
AddAttr<std::vector<int>>("anchors",
"The anchor width and height, "
"it will be parsed pair by pair.");
AddAttr<int>("class_num", "The number of classes to predict.");
AddOutput("Out", "The output tensor");
AddComment(R"DOC(
yolo_box_head Operator.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(yolo_box_head, ops::YoloBoxHeadOp, ops::YoloBoxHeadOpMaker);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
template <typename T>
inline __device__ T SigmoidGPU(const T& x) {
return 1.0f / (1.0f + __expf(-x));
}
template <typename T>
__global__ void YoloBoxHeadCudaKernel(const T* input, T* output,
const int grid_size_x,
const int grid_size_y,
const int class_num,
const int anchors_num) {
int x_id = blockIdx.x * blockDim.x + threadIdx.x;
int y_id = blockIdx.y * blockDim.y + threadIdx.y;
int z_id = blockIdx.z * blockDim.z + threadIdx.z;
if ((x_id >= grid_size_x) || (y_id >= grid_size_y) || (z_id >= anchors_num)) {
return;
}
const int grids_num = grid_size_x * grid_size_y;
const int bbindex = y_id * grid_size_x + x_id;
// objectness
output[bbindex + grids_num * (z_id * (5 + class_num) + 4)] =
SigmoidGPU(input[bbindex + grids_num * (z_id * (5 + class_num) + 4)]);
// x
output[bbindex + grids_num * (z_id * (5 + class_num) + 0)] =
SigmoidGPU(input[bbindex + grids_num * (z_id * (5 + class_num) + 0)]);
// y
output[bbindex + grids_num * (z_id * (5 + class_num) + 1)] =
SigmoidGPU(input[bbindex + grids_num * (z_id * (5 + class_num) + 1)]);
// w
output[bbindex + grids_num * (z_id * (5 + class_num) + 2)] =
__expf(input[bbindex + grids_num * (z_id * (5 + class_num) + 2)]);
// h
output[bbindex + grids_num * (z_id * (5 + class_num) + 3)] =
__expf(input[bbindex + grids_num * (z_id * (5 + class_num) + 3)]);
// Probabilities of classes
for (int i = 0; i < class_num; ++i) {
output[bbindex + grids_num * (z_id * (5 + class_num) + (5 + i))] =
SigmoidGPU(
input[bbindex + grids_num * (z_id * (5 + class_num) + (5 + i))]);
}
}
template <typename T>
class YoloBoxHeadKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using Tensor = framework::Tensor;
auto* x = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
auto anchors = context.Attr<std::vector<int>>("anchors");
auto class_num = context.Attr<int>("class_num");
auto& device_ctx =
context.template device_context<platform::CUDADeviceContext>();
auto x_dims = x->dims();
const int batch_size = x_dims[0];
const int h = x_dims[2];
const int w = x_dims[3];
const int grid_size_x = w;
const int grid_size_y = h;
const int anchors_num = anchors.size() / 2;
const T* input_data = x->data<T>();
T* output_data = out->mutable_data<T>(context.GetPlace());
auto stream = device_ctx.stream();
const int volume = x_dims[1] * h * w;
dim3 block(16, 16, 4);
dim3 grid((grid_size_x / block.x) + 1, (grid_size_y / block.y) + 1,
(anchors_num / block.z) + 1);
for (int n = 0; n < batch_size; n++) {
YoloBoxHeadCudaKernel<<<grid, block, 0, stream>>>(
input_data + n * volume, output_data + n * volume, grid_size_x,
grid_size_y, class_num, anchors_num);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(yolo_box_head, ops::YoloBoxHeadKernel<float>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class YoloBoxPostOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const {
OP_INOUT_CHECK(ctx->HasInput("Boxes0"), "Input", "Boxes0", "yolo_box_post");
OP_INOUT_CHECK(ctx->HasInput("Boxes1"), "Input", "Boxes1", "yolo_box_post");
OP_INOUT_CHECK(ctx->HasInput("Boxes2"), "Input", "Boxes2", "yolo_box_post");
OP_INOUT_CHECK(ctx->HasInput("ImageShape"), "Input", "ImageShape",
"yolo_box_post");
OP_INOUT_CHECK(ctx->HasInput("ImageScale"), "Input", "ImageScale",
"yolo_box_post");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "yolo_box_post");
OP_INOUT_CHECK(ctx->HasOutput("NmsRoisNum"), "Output", "NmsRoisNum",
"yolo_box_post");
}
};
class YoloBoxPostOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("Boxes0", "The Boxes0 tensor");
AddInput("Boxes1", "The Boxes1 tensor");
AddInput("Boxes2", "The Boxes2 tensor");
AddInput("ImageShape", "The height and width of each input image.");
AddInput("ImageScale", "The scale factor of ImageShape.");
AddAttr<std::vector<int>>("anchors0", "The anchors of Boxes0.");
AddAttr<std::vector<int>>("anchors1", "The anchors of Boxes1.");
AddAttr<std::vector<int>>("anchors2", "The anchors of Boxes2.");
AddAttr<int>("class_num", "The number of classes to predict.");
AddAttr<float>("conf_thresh",
"The confidence scores threshold of detection boxes. "
"Boxes with confidence scores under threshold should "
"be ignored.");
AddAttr<int>("downsample_ratio0", "The downsample ratio of Boxes0.");
AddAttr<int>("downsample_ratio1", "The downsample ratio of Boxes1.");
AddAttr<int>("downsample_ratio2", "The downsample ratio of Boxes2.");
AddAttr<bool>("clip_bbox",
"Whether clip output bonding box in Input(ImgSize) "
"boundary. Default true.");
AddAttr<float>("scale_x_y",
"Scale the center point of decoded bounding "
"box. Default 1.0");
AddAttr<float>("nms_threshold", "The threshold to be used in NMS.");
AddOutput("Out", "The output tensor");
AddOutput("NmsRoisNum", "The output RoIs tensor");
AddComment(R"DOC(
yolo_box_post Operator.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(yolo_box_post, ops::YoloBoxPostOp, ops::YoloBoxPostOpMaker);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace operators {
struct Box {
float x, y, w, h;
};
struct Detection {
Box bbox;
int classes;
float* prob;
float* mask;
float objectness;
int sort_class;
int max_prob_class_index;
};
struct TensorInfo {
int bbox_count_host; // record bbox numbers
int bbox_count_max_alloc{50};
float* bboxes_dev_ptr;
float* bboxes_host_ptr;
int* bbox_count_device_ptr; // Box counter in gpu memory, used by atomicAdd
};
static int NMSComparator(const void* pa, const void* pb) {
const Detection a = *reinterpret_cast<const Detection*>(pa);
const Detection b = *reinterpret_cast<const Detection*>(pb);
if (a.max_prob_class_index > b.max_prob_class_index)
return 1;
else if (a.max_prob_class_index < b.max_prob_class_index)
return -1;
float diff = 0;
if (b.sort_class >= 0) {
diff = a.prob[b.sort_class] - b.prob[b.sort_class];
} else {
diff = a.objectness - b.objectness;
}
if (diff < 0)
return 1;
else if (diff > 0)
return -1;
return 0;
}
static float Overlap(float x1, float w1, float x2, float w2) {
float l1 = x1 - w1 / 2;
float l2 = x2 - w2 / 2;
float left = l1 > l2 ? l1 : l2;
float r1 = x1 + w1 / 2;
float r2 = x2 + w2 / 2;
float right = r1 < r2 ? r1 : r2;
return right - left;
}
static float BoxIntersection(Box a, Box b) {
float w = Overlap(a.x, a.w, b.x, b.w);
float h = Overlap(a.y, a.h, b.y, b.h);
if (w < 0 || h < 0) return 0;
float area = w * h;
return area;
}
static float BoxUnion(Box a, Box b) {
float i = BoxIntersection(a, b);
float u = a.w * a.h + b.w * b.h - i;
return u;
}
static float BoxIOU(Box a, Box b) {
return BoxIntersection(a, b) / BoxUnion(a, b);
}
static void PostNMS(std::vector<Detection>* det_bboxes, float thresh,
int classes) {
int total = det_bboxes->size();
if (total <= 0) {
return;
}
Detection* dets = det_bboxes->data();
int i, j, k;
k = total - 1;
for (i = 0; i <= k; ++i) {
if (dets[i].objectness == 0) {
Detection swap = dets[i];
dets[i] = dets[k];
dets[k] = swap;
--k;
--i;
}
}
total = k + 1;
qsort(dets, total, sizeof(Detection), NMSComparator);
for (i = 0; i < total; ++i) {
if (dets[i].objectness == 0) continue;
Box a = dets[i].bbox;
for (j = i + 1; j < total; ++j) {
if (dets[j].objectness == 0) continue;
if (dets[j].max_prob_class_index != dets[i].max_prob_class_index) break;
Box b = dets[j].bbox;
if (BoxIOU(a, b) > thresh) {
dets[j].objectness = 0;
for (k = 0; k < classes; ++k) {
dets[j].prob[k] = 0;
}
}
}
}
}
__global__ void YoloBoxNum(const float* input, int* bbox_count,
const int grid_size, const int class_num,
const int anchors_num, float prob_thresh) {
int x_id = blockIdx.x * blockDim.x + threadIdx.x;
int y_id = blockIdx.y * blockDim.y + threadIdx.y;
int z_id = blockIdx.z * blockDim.z + threadIdx.z;
if ((x_id >= grid_size) || (y_id >= grid_size) || (z_id >= anchors_num)) {
return;
}
const int grids_num = grid_size * grid_size;
const int bbindex = y_id * grid_size + x_id;
float objectness = input[bbindex + grids_num * (z_id * (5 + class_num) + 4)];
if (objectness < prob_thresh) {
return;
}
atomicAdd(bbox_count, 1);
}
__global__ void YoloTensorParseKernel(
const float* input, const float* im_shape_data, const float* im_scale_data,
float* output, int* bbox_index, const int grid_size, const int class_num,
const int anchors_num, const int netw, const int neth, int* biases,
float prob_thresh) {
int x_id = blockIdx.x * blockDim.x + threadIdx.x;
int y_id = blockIdx.y * blockDim.y + threadIdx.y;
int z_id = blockIdx.z * blockDim.z + threadIdx.z;
if ((x_id >= grid_size) || (y_id >= grid_size) || (z_id >= anchors_num)) {
return;
}
const float pic_h = im_shape_data[0] / im_scale_data[0];
const float pic_w = im_shape_data[1] / im_scale_data[1];
const int grids_num = grid_size * grid_size;
const int bbindex = y_id * grid_size + x_id;
float objectness = input[bbindex + grids_num * (z_id * (5 + class_num) + 4)];
if (objectness < prob_thresh) {
return;
}
int cur_bbox_index = atomicAdd(bbox_index, 1);
int tensor_index = cur_bbox_index * (5 + class_num);
// x
float x = input[bbindex + grids_num * (z_id * (5 + class_num) + 0)];
x = (x + static_cast<float>(x_id)) * static_cast<float>(pic_w) /
static_cast<float>(grid_size);
// y
float y = input[bbindex + grids_num * (z_id * (5 + class_num) + 1)];
y = (y + static_cast<float>(y_id)) * static_cast<float>(pic_h) /
static_cast<float>(grid_size);
// w
float w = input[bbindex + grids_num * (z_id * (5 + class_num) + 2)];
w = w * biases[2 * z_id] * pic_w / netw;
// h
float h = input[bbindex + grids_num * (z_id * (5 + class_num) + 3)];
h = h * biases[2 * z_id + 1] * pic_h / neth;
output[tensor_index] = objectness;
output[tensor_index + 1] = x - w / 2;
output[tensor_index + 2] = y - h / 2;
output[tensor_index + 3] = x + w / 2;
output[tensor_index + 4] = y + h / 2;
output[tensor_index + 1] =
output[tensor_index + 1] > 0 ? output[tensor_index + 1] : 0.f;
output[tensor_index + 2] =
output[tensor_index + 2] > 0 ? output[tensor_index + 2] : 0.f;
output[tensor_index + 3] = output[tensor_index + 3] < pic_w - 1
? output[tensor_index + 3]
: pic_w - 1;
output[tensor_index + 4] = output[tensor_index + 4] < pic_h - 1
? output[tensor_index + 4]
: pic_h - 1;
// Probabilities of classes
for (int i = 0; i < class_num; ++i) {
float prob =
input[bbindex + grids_num * (z_id * (5 + class_num) + (5 + i))] *
objectness;
output[tensor_index + 5 + i] = prob;
}
}
static void YoloTensorParseCuda(
const float* input_data, // [in] YOLO_BOX_HEAD layer output
const float* image_shape_data, const float* image_scale_data,
float** bboxes_tensor_ptr, // [out] Bounding boxes output tensor
int* bbox_count_max_alloc, // [in/out] maximum bounding Box number
// allocated in dev
int* bbox_count_host, // [in/out] bounding boxes number recorded in host
int* bbox_count_device_ptr, // [in/out] bounding boxes number calculated
// in
// device side
int* bbox_index_device_ptr, // [in] bounding Box index for kernel threads
// shared access
int grid_size, int class_num, int anchors_num, int netw, int neth,
int* biases_device, float prob_thresh) {
dim3 threads_per_block(16, 16, 4);
dim3 number_of_blocks((grid_size / threads_per_block.x) + 1,
(grid_size / threads_per_block.y) + 1,
(anchors_num / threads_per_block.z) + 1);
// Estimate how many boxes will be choosed
int bbox_count = 0;
#ifdef PADDLE_WITH_HIP
hipMemcpy(bbox_count_device_ptr, &bbox_count, sizeof(int),
hipMemcpyHostToDevice);
#else
cudaMemcpy(bbox_count_device_ptr, &bbox_count, sizeof(int),
cudaMemcpyHostToDevice);
#endif
YoloBoxNum<<<number_of_blocks, threads_per_block, 0>>>(
input_data, bbox_count_device_ptr, grid_size, class_num, anchors_num,
prob_thresh);
#ifdef PADDLE_WITH_HIP
hipMemcpy(&bbox_count, bbox_count_device_ptr, sizeof(int),
hipMemcpyDeviceToHost);
#else
cudaMemcpy(&bbox_count, bbox_count_device_ptr, sizeof(int),
cudaMemcpyDeviceToHost);
#endif
// Record actual bbox number
*bbox_count_host = bbox_count;
// Obtain previous allocated bbox tensor in device side
float* bbox_tensor = *bboxes_tensor_ptr;
// Update previous maximum bbox number
if (bbox_count > *bbox_count_max_alloc) {
#ifdef PADDLE_WITH_HIP
hipFree(bbox_tensor);
hipMalloc(&bbox_tensor, bbox_count * (5 + class_num) * sizeof(float));
#else
cudaFree(bbox_tensor);
cudaMalloc(&bbox_tensor, bbox_count * (5 + class_num) * sizeof(float));
#endif
*bbox_count_max_alloc = bbox_count;
*bboxes_tensor_ptr = bbox_tensor;
}
// Now generate bboxes
int bbox_index = 0;
#ifdef PADDLE_WITH_HIP
hipMemcpy(bbox_index_device_ptr, &bbox_index, sizeof(int),
hipMemcpyHostToDevice);
#else
cudaMemcpy(bbox_index_device_ptr, &bbox_index, sizeof(int),
cudaMemcpyHostToDevice);
#endif
YoloTensorParseKernel<<<number_of_blocks, threads_per_block, 0>>>(
input_data, image_shape_data, image_scale_data, bbox_tensor,
bbox_index_device_ptr, grid_size, class_num, anchors_num, netw, neth,
biases_device, prob_thresh);
}
template <typename T>
class YoloBoxPostKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using Tensor = framework::Tensor;
// prepare inputs
std::vector<const float*> boxes_input(3);
std::vector<std::vector<int32_t>> boxes_input_dims(3);
for (int i = 0; i < 3; i++) {
auto* boxes_tensor =
context.Input<framework::Tensor>("Boxes" + std::to_string(i));
boxes_input[i] = boxes_tensor->data<float>();
auto dims = boxes_tensor->dims();
for (int j = 0; j < dims.size(); j++) {
boxes_input_dims[i].push_back(dims[j]);
}
}
const float* image_shape_data =
context.Input<framework::Tensor>("ImageShape")->data<float>();
const float* image_scale_data =
context.Input<framework::Tensor>("ImageScale")->data<float>();
// prepare outputs
auto* boxes_scores_tensor = context.Output<framework::Tensor>("Out");
auto* boxes_num_tensor = context.Output<framework::Tensor>("NmsRoisNum");
// prepare anchors
std::vector<int32_t> anchors;
auto anchors0 = context.Attr<std::vector<int>>("anchors0");
auto anchors1 = context.Attr<std::vector<int>>("anchors1");
auto anchors2 = context.Attr<std::vector<int>>("anchors2");
anchors.insert(anchors.end(), anchors0.begin(), anchors0.end());
anchors.insert(anchors.end(), anchors1.begin(), anchors1.end());
anchors.insert(anchors.end(), anchors2.begin(), anchors2.end());
int* device_anchors;
#ifdef PADDLE_WITH_HIP
hipMalloc(reinterpret_cast<void**>(&device_anchors),
anchors.size() * sizeof(int));
hipMemcpy(device_anchors, anchors.data(), anchors.size() * sizeof(int),
hipMemcpyHostToDevice);
#else
cudaMalloc(reinterpret_cast<void**>(&device_anchors),
anchors.size() * sizeof(int));
cudaMemcpy(device_anchors, anchors.data(), anchors.size() * sizeof(int),
cudaMemcpyHostToDevice);
#endif
int* device_anchors_ptr[3];
device_anchors_ptr[0] = device_anchors;
device_anchors_ptr[1] = device_anchors_ptr[0] + anchors0.size();
device_anchors_ptr[2] = device_anchors_ptr[1] + anchors1.size();
std::vector<int> anchors_num{static_cast<int>(anchors0.size()) / 2,
static_cast<int>(anchors1.size()) / 2,
static_cast<int>(anchors2.size()) / 2};
// prepare other attrs
int class_num = context.Attr<int>("class_num");
float conf_thresh = context.Attr<float>("conf_thresh");
std::vector<int> downsample_ratio{context.Attr<int>("downsample_ratio0"),
context.Attr<int>("downsample_ratio1"),
context.Attr<int>("downsample_ratio2")};
// clip_bbox and scale_x_y is not used now!
float nms_threshold = context.Attr<float>("nms_threshold");
int batch = context.Input<framework::Tensor>("ImageShape")->dims()[0];
TensorInfo* ts_info = new TensorInfo[batch * boxes_input.size()];
for (int i = 0; i < batch * static_cast<int>(boxes_input.size()); i++) {
#ifdef PADDLE_WITH_HIP
hipMalloc(
reinterpret_cast<void**>(&ts_info[i].bboxes_dev_ptr),
ts_info[i].bbox_count_max_alloc * (5 + class_num) * sizeof(float));
#else
cudaMalloc(
reinterpret_cast<void**>(&ts_info[i].bboxes_dev_ptr),
ts_info[i].bbox_count_max_alloc * (5 + class_num) * sizeof(float));
#endif
ts_info[i].bboxes_host_ptr = reinterpret_cast<float*>(malloc(
ts_info[i].bbox_count_max_alloc * (5 + class_num) * sizeof(float)));
#ifdef PADDLE_WITH_HIP
hipMalloc(reinterpret_cast<void**>(&ts_info[i].bbox_count_device_ptr),
sizeof(int));
#else
cudaMalloc(reinterpret_cast<void**>(&ts_info[i].bbox_count_device_ptr),
sizeof(int));
#endif
}
// Box index counter in gpu memory
// *bbox_index_device_ptr used by atomicAdd
int* bbox_index_device_ptr;
#ifdef PADDLE_WITH_HIP
hipMalloc(reinterpret_cast<void**>(&bbox_index_device_ptr), sizeof(int));
#else
cudaMalloc(reinterpret_cast<void**>(&bbox_index_device_ptr), sizeof(int));
#endif
int total_bbox = 0;
for (int batch_id = 0; batch_id < batch; batch_id++) {
for (int input_id = 0; input_id < boxes_input.size(); input_id++) {
int c = boxes_input_dims[input_id][1];
int h = boxes_input_dims[input_id][2];
int w = boxes_input_dims[input_id][3];
int ts_id = batch_id * boxes_input.size() + input_id;
int bbox_count_max_alloc = ts_info[ts_id].bbox_count_max_alloc;
YoloTensorParseCuda(
boxes_input[input_id] + batch_id * c * h * w,
image_shape_data + batch_id * 2, image_scale_data + batch_id * 2,
&(ts_info[ts_id].bboxes_dev_ptr), // output in gpu,must use 2-level
// pointer, because we may
// re-malloc
&bbox_count_max_alloc, // bbox_count_alloc_ptr boxes we
// pre-allocate
&(ts_info[ts_id].bbox_count_host), // record bbox numbers
ts_info[ts_id].bbox_count_device_ptr, // for atomicAdd
bbox_index_device_ptr, // for atomicAdd
h, class_num, anchors_num[input_id], downsample_ratio[input_id] * h,
downsample_ratio[input_id] * w, device_anchors_ptr[input_id],
conf_thresh);
// batch info update
if (bbox_count_max_alloc > ts_info[ts_id].bbox_count_max_alloc) {
ts_info[ts_id].bbox_count_max_alloc = bbox_count_max_alloc;
ts_info[ts_id].bboxes_host_ptr = reinterpret_cast<float*>(
realloc(ts_info[ts_id].bboxes_host_ptr,
bbox_count_max_alloc * (5 + class_num) * sizeof(float)));
}
// we need copy bbox_count_host boxes to cpu memory
#ifdef PADDLE_WITH_HIP
hipMemcpyAsync(
ts_info[ts_id].bboxes_host_ptr, ts_info[ts_id].bboxes_dev_ptr,
ts_info[ts_id].bbox_count_host * (5 + class_num) * sizeof(float),
hipMemcpyDeviceToHost);
#else
cudaMemcpyAsync(
ts_info[ts_id].bboxes_host_ptr, ts_info[ts_id].bboxes_dev_ptr,
ts_info[ts_id].bbox_count_host * (5 + class_num) * sizeof(float),
cudaMemcpyDeviceToHost);
#endif
total_bbox += ts_info[ts_id].bbox_count_host;
}
}
boxes_scores_tensor->Resize({total_bbox > 0 ? total_bbox : 1, 6});
float* boxes_scores_data =
boxes_scores_tensor->mutable_data<float>(platform::CPUPlace());
memset(boxes_scores_data, 0, sizeof(float) * 6);
boxes_num_tensor->Resize({batch});
int* boxes_num_data =
boxes_num_tensor->mutable_data<int>(platform::CPUPlace());
int boxes_scores_id = 0;
// NMS
for (int batch_id = 0; batch_id < batch; batch_id++) {
std::vector<Detection> bbox_det_vec;
for (int input_id = 0; input_id < boxes_input.size(); input_id++) {
int ts_id = batch_id * boxes_input.size() + input_id;
int bbox_count = ts_info[ts_id].bbox_count_host;
if (bbox_count <= 0) {
continue;
}
float* bbox_host_ptr = ts_info[ts_id].bboxes_host_ptr;
for (int bbox_index = 0; bbox_index < bbox_count; ++bbox_index) {
Detection bbox_det;
memset(&bbox_det, 0, sizeof(Detection));
bbox_det.objectness = bbox_host_ptr[bbox_index * (5 + class_num) + 0];
bbox_det.bbox.x = bbox_host_ptr[bbox_index * (5 + class_num) + 1];
bbox_det.bbox.y = bbox_host_ptr[bbox_index * (5 + class_num) + 2];
bbox_det.bbox.w =
bbox_host_ptr[bbox_index * (5 + class_num) + 3] - bbox_det.bbox.x;
bbox_det.bbox.h =
bbox_host_ptr[bbox_index * (5 + class_num) + 4] - bbox_det.bbox.y;
bbox_det.classes = class_num;
bbox_det.prob =
reinterpret_cast<float*>(malloc(class_num * sizeof(float)));
int max_prob_class_id = -1;
float max_class_prob = 0.0;
for (int class_id = 0; class_id < class_num; class_id++) {
float prob =
bbox_host_ptr[bbox_index * (5 + class_num) + 5 + class_id];
bbox_det.prob[class_id] = prob;
if (prob > max_class_prob) {
max_class_prob = prob;
max_prob_class_id = class_id;
}
}
bbox_det.max_prob_class_index = max_prob_class_id;
bbox_det.sort_class = max_prob_class_id;
bbox_det_vec.push_back(bbox_det);
}
}
PostNMS(&bbox_det_vec, nms_threshold, class_num);
for (int i = 0; i < bbox_det_vec.size(); i++) {
boxes_scores_data[boxes_scores_id++] =
bbox_det_vec[i].max_prob_class_index;
boxes_scores_data[boxes_scores_id++] = bbox_det_vec[i].objectness;
boxes_scores_data[boxes_scores_id++] = bbox_det_vec[i].bbox.x;
boxes_scores_data[boxes_scores_id++] = bbox_det_vec[i].bbox.y;
boxes_scores_data[boxes_scores_id++] =
bbox_det_vec[i].bbox.w + bbox_det_vec[i].bbox.x;
boxes_scores_data[boxes_scores_id++] =
bbox_det_vec[i].bbox.h + bbox_det_vec[i].bbox.y;
free(bbox_det_vec[i].prob);
}
boxes_num_data[batch_id] = bbox_det_vec.size();
}
#ifdef PADDLE_WITH_HIP
hipFree(bbox_index_device_ptr);
#else
cudaFree(bbox_index_device_ptr);
#endif
for (int i = 0; i < batch * boxes_input.size(); i++) {
#ifdef PADDLE_WITH_HIP
hipFree(ts_info[i].bboxes_dev_ptr);
hipFree(ts_info[i].bbox_count_device_ptr);
#else
cudaFree(ts_info[i].bboxes_dev_ptr);
cudaFree(ts_info[i].bbox_count_device_ptr);
#endif
free(ts_info[i].bboxes_host_ptr);
}
delete[] ts_info;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(yolo_box_post, ops::YoloBoxPostKernel<float>);
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trt_layer_auto_scan_test import TrtLayerAutoScanTest
from program_config import TensorConfig, ProgramConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import List, Dict, Any
import unittest
class TrtConvertYoloBoxHeadTest(TrtLayerAutoScanTest):
def sample_program_configs(self):
def generate_input(attrs: List[Dict[str, Any]], batch, shape):
gen_shape = shape.copy()
gen_shape.insert(0, batch)
return np.random.uniform(0, 1, gen_shape).astype("float32")
input_shape = [[255, 19, 19], [255, 38, 38], [255, 76, 76]]
anchors = [[116, 90, 156, 198, 373, 326], [30, 61, 62, 45, 59, 119],
[10, 13, 16, 30, 33, 23]]
class_num = 80
for batch in [1, 4]:
for i in range(len(anchors)):
attrs_dict = {
"anchors": anchors[i],
"class_num": class_num,
}
ops_config = [{
"op_type": "yolo_box_head",
"op_inputs": {
"X": ["yolo_box_head_input"],
},
"op_outputs": {
"Out": ["yolo_box_head_output"],
},
"op_attrs": attrs_dict
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"yolo_box_head_input": TensorConfig(data_gen=partial(
generate_input, attrs_dict, batch, input_shape[i]))
},
outputs=["yolo_box_head_output"])
yield program_config
def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
# for static_shape
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), [1, 2], 1e-5
def test(self):
self.run_test()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper
paddle.enable_static()
def yolo_box_post(box0,
box1,
box2,
im_shape,
im_scale,
anchors0=[116, 90, 156, 198, 373, 326],
anchors1=[30, 61, 62, 45, 59, 119],
anchors2=[10, 13, 16, 30, 33, 23],
class_num=80,
conf_thresh=0.005,
downsample_ratio0=32,
downsample_ratio1=16,
downsample_ratio2=8,
clip_bbox=True,
scale_x_y=1.,
nms_threshold=0.45):
helper = LayerHelper('yolo_box_post', **locals())
output = helper.create_variable_for_type_inference(dtype=box0.dtype)
nms_rois_num = helper.create_variable_for_type_inference(dtype='int32')
inputs = {
'Boxes0': box0,
'Boxes1': box1,
'Boxes2': box2,
"ImageShape": im_shape,
"ImageScale": im_scale
}
outputs = {'Out': output, 'NmsRoisNum': nms_rois_num}
helper.append_op(
type="yolo_box_post",
inputs=inputs,
attrs={
'anchors0': anchors0,
'anchors1': anchors1,
'anchors2': anchors2,
'class_num': class_num,
'conf_thresh': conf_thresh,
'downsample_ratio0': downsample_ratio0,
'downsample_ratio1': downsample_ratio1,
'downsample_ratio2': downsample_ratio2,
'clip_bbox': clip_bbox,
'scale_x_y': scale_x_y,
'nms_threshold': nms_threshold
},
outputs=outputs)
output.stop_gradient = True
nms_rois_num.stop_gradient = True
return output, nms_rois_num
@unittest.skipIf(not paddle.is_compiled_with_cuda(),
"only support cuda kernel.")
class TestYoloBoxPost(unittest.TestCase):
def test_yolo_box_post(self):
place = paddle.CUDAPlace(0)
program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(program, startup_program):
box0 = paddle.static.data("box0", [1, 255, 19, 19])
box1 = paddle.static.data("box1", [1, 255, 38, 38])
box2 = paddle.static.data("box2", [1, 255, 76, 76])
im_shape = paddle.static.data("im_shape", [1, 2])
im_scale = paddle.static.data("im_scale", [1, 2])
out, rois_num = yolo_box_post(box0, box1, box2, im_shape, im_scale)
exe = paddle.static.Executor(place)
exe.run(startup_program)
feed = {
"box0": np.random.uniform(size=[1, 255, 19, 19]).astype("float32"),
"box1": np.random.uniform(size=[1, 255, 38, 38]).astype("float32"),
"box2": np.random.uniform(size=[1, 255, 76, 76]).astype("float32"),
"im_shape": np.array([[608., 608.]], "float32"),
"im_scale": np.array([[1., 1.]], "float32")
}
outs = exe.run(program, feed=feed, fetch_list=[out.name, rois_num.name])
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper
paddle.enable_static()
def multiclass_nms(bboxes,
scores,
score_threshold,
nms_top_k,
keep_top_k,
nms_threshold=0.3,
normalized=True,
nms_eta=1.,
background_label=-1):
helper = LayerHelper('multiclass_nms3', **locals())
output = helper.create_variable_for_type_inference(dtype=bboxes.dtype)
index = helper.create_variable_for_type_inference(dtype='int32')
nms_rois_num = helper.create_variable_for_type_inference(dtype='int32')
inputs = {'BBoxes': bboxes, 'Scores': scores}
outputs = {'Out': output, 'Index': index, 'NmsRoisNum': nms_rois_num}
helper.append_op(
type="multiclass_nms3",
inputs=inputs,
attrs={
'background_label': background_label,
'score_threshold': score_threshold,
'nms_top_k': nms_top_k,
'nms_threshold': nms_threshold,
'keep_top_k': keep_top_k,
'nms_eta': nms_eta,
'normalized': normalized
},
outputs=outputs)
output.stop_gradient = True
index.stop_gradient = True
return output, index, nms_rois_num
class TestYoloBoxPass(unittest.TestCase):
def test_yolo_box_pass(self):
program = paddle.static.Program()
with paddle.static.program_guard(program):
im_shape = paddle.static.data("im_shape", [1, 2])
im_scale = paddle.static.data("im_scale", [1, 2])
yolo_box0_x = paddle.static.data("yolo_box0_x", [1, 255, 19, 19])
yolo_box1_x = paddle.static.data("yolo_box1_x", [1, 255, 38, 38])
yolo_box2_x = paddle.static.data("yolo_box2_x", [1, 255, 76, 76])
div = paddle.divide(im_shape, im_scale)
cast = paddle.cast(div, "int32")
boxes0, scores0 = paddle.vision.ops.yolo_box(
yolo_box0_x, cast, [116, 90, 156, 198, 373, 326], 80, 0.005, 32)
boxes1, scores1 = paddle.vision.ops.yolo_box(
yolo_box1_x, cast, [30, 61, 62, 45, 59, 119], 80, 0.005, 16)
boxes2, scores2 = paddle.vision.ops.yolo_box(
yolo_box2_x, cast, [10, 13, 16, 30, 33, 23], 80, 0.005, 8)
transpose0 = paddle.transpose(scores0, [0, 2, 1])
transpose1 = paddle.transpose(scores1, [0, 2, 1])
transpose2 = paddle.transpose(scores2, [0, 2, 1])
concat0 = paddle.concat([boxes0, boxes1, boxes2], 1)
concat1 = paddle.concat([transpose0, transpose1, transpose2], 2)
out0, out1, out2 = multiclass_nms(concat0, concat1, 0.01, 1000, 100,
0.45, True, 1., 80)
graph = core.Graph(program.desc)
core.get_pass("yolo_box_fuse_pass").apply(graph)
graph = paddle.fluid.framework.IrGraph(graph)
op_nodes = graph.all_op_nodes()
for op_node in op_nodes:
op_type = op_node.op().type()
self.assertTrue(op_type in ["yolo_box_head", "yolo_box_post"])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册