未验证 提交 b807e408 编写于 作者: Z zlsh80826 提交者: GitHub

[Paddle-TRT] add anchor generator op plugin (#31730)

* add anchor generator op plugin

* add anchor generator unit_test

* remove dbg info

* remove redundant line

* replace assertion with paddle enforce

* dynamic plugin replaces assertion with paddle enforce

* anchor generator support dynamic shape on spatial axis

* anchor generator test with fp16, dynamic shape

* add anchor generator test all

* add back main

* reduce test input size to not exceed the timelimit of ci

* change super to InferencePassTest for python2 compatibility

* reuse paddle operator anchor generator

* move creator construct to header with default

* add cuda ifdef

* reduce line

* change super to InferencePassTest for python2 compatibility

* fix anchor generator fp16 serialize setting

* split unittest from test_all

* restrict anchor generator input format before version 7234

* anchor generator only support greater than trt7.1

* change min_graph_size to 2

* min_graph size to 3 if dynamic shape

* reduce dynamic shape size to avoid trt search tactic too long to exceed time limit

* remove anchor from fetch list

* anchor generator support all trt version

* fix memory not allocated but if serialized
上级 4acc87be
......@@ -1192,6 +1192,7 @@ USE_TRT_CONVERTER(scale);
USE_TRT_CONVERTER(stack);
USE_TRT_CONVERTER(clip);
USE_TRT_CONVERTER(gather);
USE_TRT_CONVERTER(anchor_generator);
USE_TRT_CONVERTER(yolo_box);
USE_TRT_CONVERTER(roi_align);
USE_TRT_CONVERTER(affine_channel);
......
......@@ -6,6 +6,7 @@ nv_library(tensorrt_converter
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc
gather_op.cc
anchor_generator_op.cc
yolo_box_op.cc
roi_align_op.cc
affine_channel_op.cc
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/* Anchor Generator Op */
class AnchorGeneratorOpConverter : public OpConverter {
public:
void operator()(const paddle::framework::proto::OpDesc& op,
const paddle::framework::Scope& scope,
bool test_mode) override {
VLOG(3) << "convert a fluid anchor generator op to tensorrt plugin";
framework::OpDesc op_desc(op, nullptr);
std::string input_name = op_desc.Input("Input").front();
std::string anchor_name = op_desc.Output("Anchors").front();
std::string variance_name = op_desc.Output("Variances").front();
auto* input = engine_->GetITensor(input_name);
const auto input_dims = input->getDimensions(); // C, H, W
std::vector<std::string> output_names{anchor_name, variance_name};
const auto anchor_sizes =
BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("anchor_sizes"));
const auto aspect_ratios =
BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("aspect_ratios"));
const auto stride =
BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("stride"));
const auto variances =
BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("variances"));
const auto offset = BOOST_GET_CONST(float, op_desc.GetAttr("offset"));
const int num_anchors = aspect_ratios.size() * anchor_sizes.size();
bool is_dynamic = engine_->with_dynamic_shape();
const auto height = input_dims.d[1];
const auto width = input_dims.d[2];
const int box_num = width * height * num_anchors;
const nvinfer1::DataType data_type = nvinfer1::DataType::kFLOAT;
nvinfer1::IPluginV2* anchor_generator_plugin = nullptr;
if (is_dynamic) {
anchor_generator_plugin = new plugin::AnchorGeneratorPluginDynamic(
data_type, anchor_sizes, aspect_ratios, stride, variances, offset,
num_anchors);
} else {
anchor_generator_plugin = new plugin::AnchorGeneratorPlugin(
data_type, anchor_sizes, aspect_ratios, stride, variances, offset,
height, width, num_anchors, box_num);
}
std::vector<nvinfer1::ITensor*> anchor_generator_inputs{input};
auto* anchor_generator_layer = engine_->network()->addPluginV2(
anchor_generator_inputs.data(), anchor_generator_inputs.size(),
*anchor_generator_plugin);
RreplenishLayerAndOutput(anchor_generator_layer, "anchor_generator",
output_names, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(anchor_generator, AnchorGeneratorOpConverter);
......@@ -116,6 +116,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"affine_channel",
"multiclass_nms",
"nearest_interp",
"anchor_generator",
};
};
......
......@@ -5,6 +5,7 @@ nv_library(tensorrt_plugin
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu
hard_swish_op_plugin.cu stack_op_plugin.cu special_slice_plugin.cu
anchor_generator_op_plugin.cu
yolo_box_op_plugin.cu
roi_align_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class AnchorGeneratorPlugin : public nvinfer1::IPluginV2Ext {
public:
explicit AnchorGeneratorPlugin(
const nvinfer1::DataType, const std::vector<float>& anchor_sizes,
const std::vector<float>& aspect_ratios, const std::vector<float>& stride,
const std::vector<float>& variances, const float offset, const int height,
const int width, const int num_anchors, const int box_num);
AnchorGeneratorPlugin(const void* data, size_t length);
~AnchorGeneratorPlugin() override;
const char* getPluginType() const override;
const char* getPluginVersion() const override;
int getNbOutputs() const override;
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
int nb_input_dims) override;
bool supportsFormat(nvinfer1::DataType type,
nvinfer1::TensorFormat format) const override;
size_t getWorkspaceSize(int max_batch_size) const override;
int enqueue(int batch_size, const void* const* inputs, void** outputs,
void* workspace, cudaStream_t stream) override;
int initialize() override;
void terminate() override;
size_t getSerializationSize() const override;
void serialize(void* buffer) const override;
void destroy() override;
void setPluginNamespace(const char* lib_namespace) override;
const char* getPluginNamespace() const override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* input_type,
int nb_inputs) const override;
bool isOutputBroadcastAcrossBatch(int output_index,
const bool* input_is_broadcast,
int nb_inputs) const override;
bool canBroadcastInputAcrossBatch(int input_index) const override;
void configurePlugin(const nvinfer1::Dims* input_dims, int nb_inputs,
const nvinfer1::Dims* output_dims, int nb_outputs,
const nvinfer1::DataType* input_types,
const nvinfer1::DataType* output_types,
const bool* input_is_broadcast,
const bool* output_is_broadcast,
nvinfer1::PluginFormat float_format,
int max_batct_size) override;
nvinfer1::IPluginV2Ext* clone() const override;
private:
template <typename T>
int enqueue_impl(int batch_size, const void* const* inputs, void** outputs,
void* workspace, cudaStream_t stream);
nvinfer1::DataType data_type_;
std::vector<float> anchor_sizes_;
std::vector<float> aspect_ratios_;
std::vector<float> stride_;
std::vector<float> variances_;
float offset_;
void* anchor_sizes_device_;
void* aspect_ratios_device_;
void* stride_device_;
void* variances_device_;
int height_;
int width_;
int num_anchors_;
int box_num_;
std::string namespace_;
};
class AnchorGeneratorPluginCreator : public nvinfer1::IPluginCreator {
public:
AnchorGeneratorPluginCreator() = default;
~AnchorGeneratorPluginCreator() override = default;
void setPluginNamespace(const char* lib_namespace) override;
const char* getPluginNamespace() const override;
const char* getPluginName() const override;
const char* getPluginVersion() const override;
const nvinfer1::PluginFieldCollection* getFieldNames() override;
nvinfer1::IPluginV2Ext* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override;
nvinfer1::IPluginV2Ext* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override;
private:
std::string namespace_;
nvinfer1::PluginFieldCollection field_collection_;
};
REGISTER_TRT_PLUGIN_V2(AnchorGeneratorPluginCreator);
#if IS_TRT_VERSION_GE(6000)
class AnchorGeneratorPluginDynamic : public DynamicPluginTensorRT {
public:
explicit AnchorGeneratorPluginDynamic(const nvinfer1::DataType data_type,
const std::vector<float>& anchor_sizes,
const std::vector<float>& aspect_ratios,
const std::vector<float>& stride,
const std::vector<float>& variances,
const float offset,
const int num_anchors);
AnchorGeneratorPluginDynamic(void const* data, size_t length);
~AnchorGeneratorPluginDynamic();
nvinfer1::IPluginV2DynamicExt* clone() const override;
nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs, int nbOutputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const override;
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const override;
const char* getPluginType() const override;
int getNbOutputs() const override;
int initialize() override;
void terminate() override;
size_t getSerializationSize() const override;
void serialize(void* buffer) const override;
void destroy() override;
private:
template <typename T>
int enqueue_impl(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs,
void* workspace, cudaStream_t stream);
nvinfer1::DataType data_type_;
std::vector<float> anchor_sizes_;
std::vector<float> aspect_ratios_;
std::vector<float> stride_;
std::vector<float> variances_;
float offset_;
void* anchor_sizes_device_;
void* aspect_ratios_device_;
void* stride_device_;
void* variances_device_;
int num_anchors_;
std::string namespace_;
};
class AnchorGeneratorPluginDynamicCreator : public nvinfer1::IPluginCreator {
public:
AnchorGeneratorPluginDynamicCreator() = default;
~AnchorGeneratorPluginDynamicCreator() override = default;
void setPluginNamespace(const char* lib_namespace) override;
const char* getPluginNamespace() const override;
const char* getPluginName() const override;
const char* getPluginVersion() const override;
const nvinfer1::PluginFieldCollection* getFieldNames() override;
nvinfer1::IPluginV2Ext* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override;
nvinfer1::IPluginV2Ext* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override;
private:
std::string namespace_;
nvinfer1::PluginFieldCollection field_collection_;
};
REGISTER_TRT_PLUGIN_V2(AnchorGeneratorPluginDynamicCreator);
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -49,14 +49,11 @@ __global__ void GenAnchors(T* out, const T* aspect_ratios, const int ar_num,
anchor_width = scale_w * base_w;
anchor_height = scale_h * base_h;
T xmin = (x_ctr - 0.5 * (anchor_width - 1));
T ymin = (y_ctr - 0.5 * (anchor_height - 1));
T xmax = (x_ctr + 0.5 * (anchor_width - 1));
T ymax = (y_ctr + 0.5 * (anchor_height - 1));
out[i * 4] = xmin;
out[i * 4 + 1] = ymin;
out[i * 4 + 2] = xmax;
out[i * 4 + 3] = ymax;
T xmin = (x_ctr - .5f * (anchor_width - 1));
T ymin = (y_ctr - .5f * (anchor_height - 1));
T xmax = (x_ctr + .5f * (anchor_width - 1));
T ymax = (y_ctr + .5f * (anchor_height - 1));
reinterpret_cast<float4*>(out)[i] = make_float4(xmin, ymin, xmax, ymax);
}
}
......
......@@ -22,6 +22,19 @@ limitations under the License. */
namespace paddle {
namespace operators {
#ifdef PADDLE_WITH_CUDA
template <typename T>
extern __global__ void GenAnchors(T* out, const T* aspect_ratios,
const int ar_num, const T* anchor_sizes,
const int as_num, const T* stride,
const int sd_num, const int height,
const int width, const T offset);
template <typename T>
extern __global__ void SetVariance(T* out, const T* var, const int vnum,
const int num);
#endif
template <typename T>
class AnchorGeneratorOpKernel : public framework::OpKernel<T> {
public:
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import itertools
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import PassVersionChecker
from paddle.fluid.core import AnalysisConfig
class TRTAnchorGeneratorBaseTest(InferencePassTest):
def setUp(self):
self.bs = 1
self.channel = 16
self.height = 32
self.width = 32
self.anchor_sizes = [64., 128., 256., 512.]
self.aspect_ratios = [.5, 1., 2.]
self.variance = [.1, .1, .2, .2]
self.stride = [8., 8.]
self.precision = AnalysisConfig.Precision.Float32
self.serialize = False
self.enable_trt = True
self.feeds = {
'data':
np.random.random([self.bs, self.channel, self.height,
self.width]).astype('float32'),
}
def build(self):
min_graph_size = 3 if self.dynamic_shape_params is not None else 2
self.trt_parameters = InferencePassTest.TensorRTParam(
1 << 30, self.bs, min_graph_size, self.precision, self.serialize,
False)
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name='data',
shape=[-1, self.channel, self.height, self.width],
dtype='float32')
anchor, var = fluid.layers.detection.anchor_generator(
data,
anchor_sizes=self.anchor_sizes,
aspect_ratios=self.aspect_ratios,
variance=self.variance,
stride=self.stride)
if self.dynamic_shape_params is not None:
anchor = fluid.layers.transpose(anchor, [2, 3, 0, 1])
out = fluid.layers.batch_norm(anchor, is_test=True)
self.fetch_list = [out, var]
def run_test(self):
self.build()
self.check_output()
def set_dynamic(self):
self.dynamic_shape_params = InferencePassTest.DynamicShapeParam({
'data': [self.bs, self.channel, self.height // 2, self.width // 2]
}, {
'data': [self.bs, self.channel, self.height, self.width]
}, {'data': [self.bs, self.channel, self.height, self.width]}, False)
def test_base(self):
self.run_test()
def test_fp16(self):
self.precision = AnalysisConfig.Precision.Half
self.run_test()
def test_serialize(self):
self.serialize = True
self.run_test()
def test_dynamic(self):
self.set_dynamic()
self.run_test()
def test_dynamic_fp16(self):
self.precision = AnalysisConfig.Precision.Half
self.set_dynamic()
self.run_test()
def test_dynamic_serialize(self):
self.serialize = True
self.set_dynamic()
self.run_test()
def test_dynamic_fp16_serialize(self):
self.serialize = True
self.precision = AnalysisConfig.Precision.Half
self.set_dynamic()
self.run_test()
def check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
atol = 1e-5
if self.trt_parameters.precision == AnalysisConfig.Precision.Half:
atol = 1e-3
self.check_output_with_option(use_gpu, atol, flatten=True)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册