diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 7bb092d0e3c1c0bcf1cc5fb9873b14a111bef52c..21ef3b2312ff6d6bfeae6b3ce216af2bc9bc1db4 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1192,6 +1192,7 @@ USE_TRT_CONVERTER(scale); USE_TRT_CONVERTER(stack); USE_TRT_CONVERTER(clip); USE_TRT_CONVERTER(gather); +USE_TRT_CONVERTER(yolo_box); USE_TRT_CONVERTER(roi_align); USE_TRT_CONVERTER(affine_channel); USE_TRT_CONVERTER(multiclass_nms); diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index bc7b7355ea192248b9786303f5657ef354e030ad..3f79230094241ca1582b6e820635afb085004d52 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -6,6 +6,7 @@ nv_library(tensorrt_converter shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc gather_op.cc + yolo_box_op.cc roi_align_op.cc affine_channel_op.cc multiclass_nms_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/yolo_box_op.cc b/paddle/fluid/inference/tensorrt/convert/yolo_box_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..2d12eaf736b754d623c2aa0e3c138a2ad80800b3 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/yolo_box_op.cc @@ -0,0 +1,79 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h" + +namespace paddle { +namespace framework { +class Scope; +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +class YoloBoxOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(3) << "convert a fluid yolo box op to tensorrt plugin"; + + framework::OpDesc op_desc(op, nullptr); + std::string X = op_desc.Input("X").front(); + std::string img_size = op_desc.Input("ImgSize").front(); + + auto* X_tensor = engine_->GetITensor(X); + auto* img_size_tensor = engine_->GetITensor(img_size); + + int class_num = BOOST_GET_CONST(int, op_desc.GetAttr("class_num")); + std::vector anchors = + BOOST_GET_CONST(std::vector, op_desc.GetAttr("anchors")); + + int downsample_ratio = + BOOST_GET_CONST(int, op_desc.GetAttr("downsample_ratio")); + float conf_thresh = BOOST_GET_CONST(float, op_desc.GetAttr("conf_thresh")); + bool clip_bbox = BOOST_GET_CONST(bool, op_desc.GetAttr("clip_bbox")); + float scale_x_y = BOOST_GET_CONST(float, op_desc.GetAttr("scale_x_y")); + + int type_id = static_cast(engine_->WithFp16()); + auto input_dim = X_tensor->getDimensions(); + auto* yolo_box_plugin = new plugin::YoloBoxPlugin( + type_id ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, + anchors, class_num, conf_thresh, downsample_ratio, clip_bbox, scale_x_y, + input_dim.d[1], input_dim.d[2]); + + std::vector yolo_box_inputs; + yolo_box_inputs.push_back(X_tensor); + yolo_box_inputs.push_back(img_size_tensor); + + auto* yolo_box_layer = engine_->network()->addPluginV2( + yolo_box_inputs.data(), yolo_box_inputs.size(), *yolo_box_plugin); + + std::vector output_names; + output_names.push_back(op_desc.Output("Boxes").front()); + output_names.push_back(op_desc.Output("Scores").front()); + + RreplenishLayerAndOutput(yolo_box_layer, "yolo_box", output_names, + test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(yolo_box, YoloBoxOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 7c1b2e8001edbd8cc688731507b1bb991826328b..c95912a931e0bc740775bd59a6bfcf30eb981c81 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -111,6 +111,7 @@ struct SimpleOpTypeSetTeller : public Teller { "flatten2", "flatten", "gather", + "yolo_box", "roi_align", "affine_channel", "multiclass_nms", @@ -198,6 +199,15 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, if (!with_dynamic_shape || desc.Input("Axis").size() > 0) return false; } + if (op_type == "yolo_box") { + if (with_dynamic_shape) return false; + bool has_attrs = + (desc.HasAttr("class_num") && desc.HasAttr("anchors") && + desc.HasAttr("downsample_ratio") && desc.HasAttr("conf_thresh") && + desc.HasAttr("clip_bbox") && desc.HasAttr("scale_x_y")); + return has_attrs; + } + if (op_type == "affine_channel") { if (!desc.HasAttr("data_layout")) return false; auto data_layout = framework::StringToDataLayout( diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 4107f9ef6743390e9353d6848b47d14929178af2..b4e948edd8a6bb18e22918e4e3521a36f738a084 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -5,6 +5,7 @@ nv_library(tensorrt_plugin instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu stack_op_plugin.cu special_slice_plugin.cu + yolo_box_op_plugin.cu roi_align_op_plugin.cu DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor) diff --git a/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu new file mode 100644 index 0000000000000000000000000000000000000000..e1b4c898d212ffdf8db4c4910a9aab2ee728b9c3 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu @@ -0,0 +1,404 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h" +#include "paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h" +#include "paddle/fluid/operators/detection/yolo_box_op.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +YoloBoxPlugin::YoloBoxPlugin(const nvinfer1::DataType data_type, + const std::vector& anchors, + const int class_num, const float conf_thresh, + const int downsample_ratio, const bool clip_bbox, + const float scale_x_y, const int input_h, + const int input_w) + : data_type_(data_type), + class_num_(class_num), + conf_thresh_(conf_thresh), + downsample_ratio_(downsample_ratio), + clip_bbox_(clip_bbox), + scale_x_y_(scale_x_y), + input_h_(input_h), + input_w_(input_w) { + anchors_.insert(anchors_.end(), anchors.cbegin(), anchors.cend()); + assert(data_type_ == nvinfer1::DataType::kFLOAT || + data_type_ == nvinfer1::DataType::kHALF); + assert(class_num_ > 0); + assert(input_h_ > 0); + assert(input_w_ > 0); + + cudaMalloc(&anchors_device_, anchors.size() * sizeof(int)); + cudaMemcpy(anchors_device_, anchors.data(), anchors.size() * sizeof(int), + cudaMemcpyHostToDevice); +} + +YoloBoxPlugin::YoloBoxPlugin(const void* data, size_t length) { + DeserializeValue(&data, &length, &data_type_); + DeserializeValue(&data, &length, &anchors_); + DeserializeValue(&data, &length, &class_num_); + DeserializeValue(&data, &length, &conf_thresh_); + DeserializeValue(&data, &length, &downsample_ratio_); + DeserializeValue(&data, &length, &clip_bbox_); + DeserializeValue(&data, &length, &scale_x_y_); + DeserializeValue(&data, &length, &input_h_); + DeserializeValue(&data, &length, &input_w_); +} + +YoloBoxPlugin::~YoloBoxPlugin() { + if (anchors_device_ != nullptr) { + cudaFree(anchors_device_); + anchors_device_ = nullptr; + } +} + +const char* YoloBoxPlugin::getPluginType() const { return "yolo_box_plugin"; } + +const char* YoloBoxPlugin::getPluginVersion() const { return "1"; } + +int YoloBoxPlugin::getNbOutputs() const { return 2; } + +nvinfer1::Dims YoloBoxPlugin::getOutputDimensions(int index, + const nvinfer1::Dims* inputs, + int nb_input_dims) { + const int anchor_num = anchors_.size() / 2; + const int box_num = inputs[0].d[1] * inputs[0].d[2] * anchor_num; + + assert(index <= 1); + + if (index == 0) { + return nvinfer1::Dims2(box_num, 4); + } + return nvinfer1::Dims2(box_num, class_num_); +} + +bool YoloBoxPlugin::supportsFormat(nvinfer1::DataType type, + nvinfer1::TensorFormat format) const { + return ((type == data_type_ || type == nvinfer1::DataType::kINT32) && + format == nvinfer1::TensorFormat::kLINEAR); +} + +size_t YoloBoxPlugin::getWorkspaceSize(int max_batch_size) const { return 0; } + +template +__device__ inline T sigmoid(T x) { + return 1. / (1. + exp(-x)); +} + +template <> +__device__ inline float sigmoid(float x) { + return 1.f / (1.f + expf(-x)); +} + +template +__device__ inline void GetYoloBox(float* box, const T* x, const int* anchors, + int i, int j, int an_idx, int grid_size_h, + int grid_size_w, int input_size_h, + int input_size_w, int index, int stride, + int img_height, int img_width, float scale, + float bias) { + box[0] = static_cast( + (i + sigmoid(static_cast(x[index]) * scale + bias)) * img_width / + grid_size_w); + box[1] = static_cast( + (j + sigmoid(static_cast(x[index + stride]) * scale + bias)) * + img_height / grid_size_h); + box[2] = static_cast(expf(static_cast(x[index + 2 * stride])) * + anchors[2 * an_idx] * img_width / input_size_w); + box[3] = + static_cast(expf(static_cast(x[index + 3 * stride])) * + anchors[2 * an_idx + 1] * img_height / input_size_h); +} + +__device__ inline int GetEntryIndex(int batch, int an_idx, int hw_idx, + int an_num, int an_stride, int stride, + int entry) { + return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; +} + +template +__device__ inline void CalcDetectionBox(T* boxes, const float* box, + const int box_idx, const int img_height, + const int img_width, bool clip_bbox) { + float tmp_box_0, tmp_box_1, tmp_box_2, tmp_box_3; + tmp_box_0 = box[0] - box[2] / 2; + tmp_box_1 = box[1] - box[3] / 2; + tmp_box_2 = box[0] + box[2] / 2; + tmp_box_3 = box[1] + box[3] / 2; + + if (clip_bbox) { + tmp_box_0 = max(tmp_box_0, 0.f); + tmp_box_1 = max(tmp_box_1, 0.f); + tmp_box_2 = min(tmp_box_2, static_cast(img_width - 1)); + tmp_box_3 = min(tmp_box_3, static_cast(img_height - 1)); + } + + boxes[box_idx + 0] = static_cast(tmp_box_0); + boxes[box_idx + 1] = static_cast(tmp_box_1); + boxes[box_idx + 2] = static_cast(tmp_box_2); + boxes[box_idx + 3] = static_cast(tmp_box_3); +} + +template +__device__ inline void CalcLabelScore(T* scores, const T* input, + const int label_idx, const int score_idx, + const int class_num, const float conf, + const int stride) { + for (int i = 0; i < class_num; i++) { + scores[score_idx + i] = static_cast( + conf * sigmoid(static_cast(input[label_idx + i * stride]))); + } +} + +template +__global__ void KeYoloBoxFw(const T* const input, const int* const imgsize, + T* boxes, T* scores, const float conf_thresh, + const int* anchors, const int n, const int h, + const int w, const int an_num, const int class_num, + const int box_num, int input_size_h, + int input_size_w, bool clip_bbox, const float scale, + const float bias) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + float box[4]; + for (; tid < n * box_num; tid += stride) { + int grid_num = h * w; + int i = tid / box_num; + int j = (tid % box_num) / grid_num; + int k = (tid % grid_num) / w; + int l = tid % w; + + int an_stride = (5 + class_num) * grid_num; + int img_height = imgsize[2 * i]; + int img_width = imgsize[2 * i + 1]; + + int obj_idx = + GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4); + float conf = sigmoid(static_cast(input[obj_idx])); + int box_idx = + GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0); + + if (conf < conf_thresh) { + for (int i = 0; i < 4; ++i) { + box[i] = 0.f; + } + } else { + GetYoloBox(box, input, anchors, l, k, j, h, w, input_size_h, + input_size_w, box_idx, grid_num, img_height, img_width, + scale, bias); + } + + box_idx = (i * box_num + j * grid_num + k * w + l) * 4; + CalcDetectionBox(boxes, box, box_idx, img_height, img_width, clip_bbox); + + int label_idx = + GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5); + int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num; + CalcLabelScore(scores, input, label_idx, score_idx, class_num, conf, + grid_num); + } +} + +template +int YoloBoxPlugin::enqueue_impl(int batch_size, const void* const* inputs, + void** outputs, void* workspace, + cudaStream_t stream) { + const int n = batch_size; + const int h = input_h_; + const int w = input_w_; + const int an_num = anchors_.size() / 2; + const int box_num = h * w * an_num; + int input_size_h = downsample_ratio_ * h; + int input_size_w = downsample_ratio_ * w; + + float bias = -0.5 * (scale_x_y_ - 1.); + constexpr int threads = 256; + + KeYoloBoxFw<<<(n * box_num + threads - 1) / threads, threads, 0, stream>>>( + reinterpret_cast(inputs[0]), + reinterpret_cast(inputs[1]), + reinterpret_cast(outputs[0]), reinterpret_cast(outputs[1]), + conf_thresh_, anchors_device_, n, h, w, an_num, class_num_, box_num, + input_size_h, input_size_w, clip_bbox_, scale_x_y_, bias); + return cudaGetLastError() != cudaSuccess; +} + +int YoloBoxPlugin::enqueue(int batch_size, const void* const* inputs, + void** outputs, void* workspace, + cudaStream_t stream) { + if (data_type_ == nvinfer1::DataType::kFLOAT) { + return enqueue_impl(batch_size, inputs, outputs, workspace, stream); + } else if (data_type_ == nvinfer1::DataType::kHALF) { + return enqueue_impl(batch_size, inputs, outputs, workspace, stream); + } + assert("unsupported type."); +} + +int YoloBoxPlugin::initialize() { return 0; } + +void YoloBoxPlugin::terminate() {} + +size_t YoloBoxPlugin::getSerializationSize() const { + size_t serialize_size = 0; + serialize_size += SerializedSize(data_type_); + serialize_size += SerializedSize(anchors_); + serialize_size += SerializedSize(class_num_); + serialize_size += SerializedSize(conf_thresh_); + serialize_size += SerializedSize(downsample_ratio_); + serialize_size += SerializedSize(clip_bbox_); + serialize_size += SerializedSize(scale_x_y_); + serialize_size += SerializedSize(input_h_); + serialize_size += SerializedSize(input_w_); + return serialize_size; +} + +void YoloBoxPlugin::serialize(void* buffer) const { + SerializeValue(&buffer, data_type_); + SerializeValue(&buffer, anchors_); + SerializeValue(&buffer, class_num_); + SerializeValue(&buffer, conf_thresh_); + SerializeValue(&buffer, downsample_ratio_); + SerializeValue(&buffer, clip_bbox_); + SerializeValue(&buffer, scale_x_y_); + SerializeValue(&buffer, input_h_); + SerializeValue(&buffer, input_w_); +} + +void YoloBoxPlugin::destroy() { + cudaFree(anchors_device_); + delete this; +} + +void YoloBoxPlugin::setPluginNamespace(const char* lib_namespace) { + namespace_ = std::string(lib_namespace); +} + +const char* YoloBoxPlugin::getPluginNamespace() const { + return namespace_.c_str(); +} + +nvinfer1::DataType YoloBoxPlugin::getOutputDataType( + int index, const nvinfer1::DataType* input_type, int nb_inputs) const { + return data_type_; +} + +bool YoloBoxPlugin::isOutputBroadcastAcrossBatch(int output_index, + const bool* input_is_broadcast, + int nb_inputs) const { + return false; +} + +bool YoloBoxPlugin::canBroadcastInputAcrossBatch(int input_index) const { + return false; +} + +void YoloBoxPlugin::configurePlugin( + const nvinfer1::Dims* input_dims, int nb_inputs, + const nvinfer1::Dims* output_dims, int nb_outputs, + const nvinfer1::DataType* input_types, + const nvinfer1::DataType* output_types, const bool* input_is_broadcast, + const bool* output_is_broadcast, nvinfer1::PluginFormat float_format, + int max_batct_size) {} + +nvinfer1::IPluginV2Ext* YoloBoxPlugin::clone() const { + return new YoloBoxPlugin(data_type_, anchors_, class_num_, conf_thresh_, + downsample_ratio_, clip_bbox_, scale_x_y_, input_h_, + input_w_); +} + +YoloBoxPluginCreator::YoloBoxPluginCreator() {} + +void YoloBoxPluginCreator::setPluginNamespace(const char* lib_namespace) { + namespace_ = std::string(lib_namespace); +} + +const char* YoloBoxPluginCreator::getPluginNamespace() const { + return namespace_.c_str(); +} + +const char* YoloBoxPluginCreator::getPluginName() const { + return "yolo_box_plugin"; +} + +const char* YoloBoxPluginCreator::getPluginVersion() const { return "1"; } + +const nvinfer1::PluginFieldCollection* YoloBoxPluginCreator::getFieldNames() { + return &field_collection_; +} + +nvinfer1::IPluginV2Ext* YoloBoxPluginCreator::createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* fc) { + const nvinfer1::PluginField* fields = fc->fields; + + int type_id = -1; + std::vector anchors; + int class_num = -1; + float conf_thresh = 0.01; + int downsample_ratio = 32; + bool clip_bbox = true; + float scale_x_y = 1.; + int h = -1; + int w = -1; + + for (int i = 0; i < fc->nbFields; ++i) { + const std::string field_name(fc->fields[i].name); + if (field_name.compare("type_id") == 0) { + type_id = *static_cast(fc->fields[i].data); + } else if (field_name.compare("anchors")) { + const int length = fc->fields[i].length; + const int* data = static_cast(fc->fields[i].data); + anchors.insert(anchors.end(), data, data + length); + } else if (field_name.compare("class_num")) { + class_num = *static_cast(fc->fields[i].data); + } else if (field_name.compare("conf_thresh")) { + conf_thresh = *static_cast(fc->fields[i].data); + } else if (field_name.compare("downsample_ratio")) { + downsample_ratio = *static_cast(fc->fields[i].data); + } else if (field_name.compare("clip_bbox")) { + clip_bbox = *static_cast(fc->fields[i].data); + } else if (field_name.compare("scale_x_y")) { + scale_x_y = *static_cast(fc->fields[i].data); + } else if (field_name.compare("h")) { + h = *static_cast(fc->fields[i].data); + } else if (field_name.compare("w")) { + w = *static_cast(fc->fields[i].data); + } else { + assert(false && "unknown plugin field name."); + } + } + + return new YoloBoxPlugin( + type_id ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, anchors, + class_num, conf_thresh, downsample_ratio, clip_bbox, scale_x_y, h, w); +} + +nvinfer1::IPluginV2Ext* YoloBoxPluginCreator::deserializePlugin( + const char* name, const void* serial_data, size_t serial_length) { + auto plugin = new YoloBoxPlugin(serial_data, serial_length); + plugin->setPluginNamespace(namespace_.c_str()); + return plugin; +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h new file mode 100644 index 0000000000000000000000000000000000000000..8ca21da7ae0377164cbb50c502f0abb5ca943058 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h @@ -0,0 +1,117 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +class YoloBoxPlugin : public nvinfer1::IPluginV2Ext { + public: + explicit YoloBoxPlugin(const nvinfer1::DataType data_type, + const std::vector& anchors, const int class_num, + const float conf_thresh, const int downsample_ratio, + const bool clip_bbox, const float scale_x_y, + const int input_h, const int input_w); + YoloBoxPlugin(const void* data, size_t length); + ~YoloBoxPlugin() override; + + const char* getPluginType() const override; + const char* getPluginVersion() const override; + int getNbOutputs() const override; + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, + int nb_input_dims) override; + bool supportsFormat(nvinfer1::DataType type, + nvinfer1::TensorFormat format) const override; + size_t getWorkspaceSize(int max_batch_size) const override; + int enqueue(int batch_size, const void* const* inputs, void** outputs, + void* workspace, cudaStream_t stream) override; + template + int enqueue_impl(int batch_size, const void* const* inputs, void** outputs, + void* workspace, cudaStream_t stream); + int initialize() override; + void terminate() override; + size_t getSerializationSize() const override; + void serialize(void* buffer) const override; + void destroy() override; + void setPluginNamespace(const char* lib_namespace) override; + const char* getPluginNamespace() const override; + + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* input_type, + int nb_inputs) const override; + bool isOutputBroadcastAcrossBatch(int output_index, + const bool* input_is_broadcast, + int nb_inputs) const override; + bool canBroadcastInputAcrossBatch(int input_index) const override; + void configurePlugin(const nvinfer1::Dims* input_dims, int nb_inputs, + const nvinfer1::Dims* output_dims, int nb_outputs, + const nvinfer1::DataType* input_types, + const nvinfer1::DataType* output_types, + const bool* input_is_broadcast, + const bool* output_is_broadcast, + nvinfer1::PluginFormat float_format, + int max_batct_size) override; + nvinfer1::IPluginV2Ext* clone() const override; + + private: + nvinfer1::DataType data_type_; + std::vector anchors_; + int* anchors_device_; + int class_num_; + float conf_thresh_; + int downsample_ratio_; + bool clip_bbox_; + float scale_x_y_; + int input_h_; + int input_w_; + std::string namespace_; +}; + +class YoloBoxPluginCreator : public nvinfer1::IPluginCreator { + public: + YoloBoxPluginCreator(); + ~YoloBoxPluginCreator() override = default; + + void setPluginNamespace(const char* lib_namespace) override; + const char* getPluginNamespace() const override; + const char* getPluginName() const override; + const char* getPluginVersion() const override; + const nvinfer1::PluginFieldCollection* getFieldNames() override; + + nvinfer1::IPluginV2Ext* createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* fc) override; + nvinfer1::IPluginV2Ext* deserializePlugin(const char* name, + const void* serial_data, + size_t serial_length) override; + + private: + std::string namespace_; + nvinfer1::PluginFieldCollection field_collection_; +}; + +REGISTER_TRT_PLUGIN_V2(YoloBoxPluginCreator); + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_yolo_box_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_yolo_box_op.py new file mode 100644 index 0000000000000000000000000000000000000000..cff8091cd93f8ecfb48e066e11eeda00c1e83a8b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_yolo_box_op.py @@ -0,0 +1,76 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import PassVersionChecker +from paddle.fluid.core import AnalysisConfig + + +class TRTYoloBoxTest(InferencePassTest): + def setUp(self): + self.set_params() + with fluid.program_guard(self.main_program, self.startup_program): + image_shape = [self.bs, self.channel, self.height, self.width] + image = fluid.data(name='image', shape=image_shape, dtype='float32') + image_size = fluid.data( + name='image_size', shape=[self.bs, 2], dtype='int32') + boxes, scores = self.append_yolobox(image, image_size) + scores = fluid.layers.reshape(scores, (self.bs, -1)) + out = fluid.layers.batch_norm(scores, is_test=True) + + self.feeds = { + 'image': np.random.random(image_shape).astype('float32'), + 'image_size': np.random.randint( + 32, 64, size=(self.bs, 2)).astype('int32'), + } + self.enable_trt = True + self.trt_parameters = TRTYoloBoxTest.TensorRTParam( + 1 << 30, self.bs, 1, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out, boxes] + + def set_params(self): + self.bs = 4 + self.channel = 255 + self.height = 64 + self.width = 64 + self.class_num = 80 + self.anchors = [10, 13, 16, 30, 33, 23] + self.conf_thresh = .1 + self.downsample_ratio = 32 + + def append_yolobox(self, image, image_size): + return fluid.layers.yolo_box( + x=image, + img_size=image_size, + class_num=self.class_num, + anchors=self.anchors, + conf_thresh=self.conf_thresh, + downsample_ratio=self.downsample_ratio) + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu, flatten=True) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + +if __name__ == "__main__": + unittest.main()