Commit d43bb7f2 authored by: Z zlsh80826

convert mask with fp32/fp16 support

Parent d4dcc80d
......@@ -1029,7 +1029,7 @@ USE_TRT_CONVERTER(elementwise_mul_tensor);
USE_TRT_CONVERTER(elementwise_max_tensor);
USE_TRT_CONVERTER(elementwise_min_tensor);
USE_TRT_CONVERTER(elementwise_pow_tensor);
USE_TRT_CONVERTER(mul);
USE_TRT_CONVERTER(matmul);
USE_TRT_CONVERTER(conv2d);
USE_TRT_CONVERTER(relu);
USE_TRT_CONVERTER(sigmoid);
......
......@@ -11,7 +11,6 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h"
namespace paddle {
......@@ -81,24 +80,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
auto pos_tensor = engine_->GetITensor("eval_placeholder_2");
plugin::CastIntPluginDynamic* cast_plugin =
new plugin::CastIntPluginDynamic();
auto cast_layer = engine_->AddPluginV2(&pos_tensor, 1, cast_plugin);
auto casted_pos_tensor = cast_layer->getOutput(0);
auto reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *casted_pos_tensor);
nvinfer1::Dims2 reshape_dim(0, 0);
nvinfer1::Permutation perm{1, 0, 2};
reshape_layer->setFirstTranspose(perm);
reshape_layer->setReshapeDimensions(reshape_dim);
auto imask_layer =
TRT_ENGINE_ADD_LAYER(engine_, Reduce, *reshape_layer->getOutput(0),
nvinfer1::ReduceOperation::kMAX, 1, false);
engine_->SetITensor("imask_tensor", imask_layer->getOutput(0));
plugin::DynamicPluginTensorRT* plugin = nullptr;
plugin = new plugin::EmbEltwiseLayernormPluginDynamic<float>(
input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden,
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h"
namespace paddle {
namespace inference {
......@@ -31,17 +32,27 @@ class MulOpConverter : public OpConverter {
// Declare inputs
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
bool transpose_x = BOOST_GET_CONST(bool, op_desc.GetAttr("transpose_X"));
bool transpose_y = BOOST_GET_CONST(bool, op_desc.GetAttr("transpose_Y"));
#ifdef USE_NVINFER_PLUGIN
nvinfer1::DataType type = (engine_->WithFp16() == 1)
? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT;
plugin::ConvertMaskPluginDynamic* plugin =
new plugin::ConvertMaskPluginDynamic(type);
auto convert_mask_layer = engine_->AddPluginV2(&input1, 1, plugin);
engine_->SetITensor("qkv_plugin_mask", convert_mask_layer->getOutput(0));
#endif
// Both the input1 and input2 do not need transpose.
auto* layer = TRT_ENGINE_ADD_LAYER(
engine_, MatrixMultiply, *const_cast<nvinfer1::ITensor*>(input1), false,
*const_cast<nvinfer1::ITensor*>(input2), false);
engine_, MatrixMultiply, *const_cast<nvinfer1::ITensor*>(input1),
transpose_x, *const_cast<nvinfer1::ITensor*>(input2), transpose_y);
auto output_name = op_desc.Output("Out")[0];
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) { // the test framework can not determine which is the
// output, so place the declaration inside.
engine_->DeclareOutput(output_name);
}
RreplenishLayerAndOutput(layer, "matmul", {output_name}, test_mode);
}
};
......@@ -49,4 +60,4 @@ class MulOpConverter : public OpConverter {
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(mul, MulOpConverter);
REGISTER_TRT_OP_CONVERTER(matmul, MulOpConverter);
......@@ -113,33 +113,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
static_cast<void*>(bias_data),
static_cast<int32_t>(bias_t->numel())};
nvinfer1::Permutation permutation{0, 1, 2, 3, 4};
auto trans_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
trans_layer->setFirstTranspose(permutation);
auto* fc_layer = TRT_ENGINE_ADD_LAYER(
engine_, FullyConnected, *trans_layer->getOutput(0), n, weight, bias);
/*
auto pos_tensor = engine_->GetITensor("eval_placeholder_2");
plugin::CastIntPluginDynamic* cast_plugin =
new plugin::CastIntPluginDynamic();
auto cast_layer = engine_->AddPluginV2(&pos_tensor, 1, cast_plugin);
auto casted_pos_tensor = cast_layer->getOutput(0);
auto reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *casted_pos_tensor);
nvinfer1::Dims2 reshape_dim(0, 0);
nvinfer1::Permutation perm{1, 0, 2};
reshape_layer->setFirstTranspose(perm);
reshape_layer->setReshapeDimensions(reshape_dim);
auto reduce_layer =
TRT_ENGINE_ADD_LAYER(engine_, Reduce,
*reshape_layer->getOutput(0),
nvinfer1::ReduceOperation::kMAX, 1, false);
*/
// auto imask_tensor = engine_->GetITensor("imask_tensor");
auto imask_tensor = engine_->GetITensor("fused_mha_mask");
auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input, n,
weight, bias);
auto mask_tensor = engine_->GetITensor("qkv_plugin_mask");
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomQKVToContextPluginDynamic", "1");
......@@ -154,28 +131,24 @@ class MultiheadMatMulOpConverter : public OpConverter {
{"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1},
{"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1},
};
nvinfer1::PluginFieldCollection* pluginPtr =
nvinfer1::PluginFieldCollection* plugin_collection =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*pluginPtr) +
malloc(sizeof(*plugin_collection) +
fields.size() *
sizeof(nvinfer1::PluginField))); // remember to free
pluginPtr->nbFields = static_cast<int>(fields.size());
pluginPtr->fields = fields.data();
plugin_collection->nbFields = static_cast<int>(fields.size());
plugin_collection->fields = fields.data();
auto plugin = creator->createPlugin("CustomQKVToContextPluginDynamic",
plugin_collection);
free(plugin_collection);
auto pluginObj =
creator->createPlugin("CustomQKVToContextPluginDynamic", pluginPtr);
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.push_back(fc_layer->getOutput(0));
// plugin_inputs.push_back(reduce_layer->getOutput(0));
plugin_inputs.push_back(imask_tensor);
plugin_inputs.push_back(mask_tensor);
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *pluginObj);
assert(plugin_layer != nullptr);
auto trans_r_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *plugin_layer->getOutput(0));
assert(trans_r_layer != nullptr);
trans_r_layer->setFirstTranspose(permutation);
layer = trans_r_layer;
plugin_inputs.data(), plugin_inputs.size(), *plugin);
layer = plugin_layer;
#else
// transpose weight_data from m * n to n * m
auto* input_bias_qk =
......
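For reference, a standalone sketch of the PluginFieldCollection hand-off performed in the multihead-matmul converter above: the collection header is malloc'ed, its fields pointer is set to a std::vector of attributes, the creator builds the plugin, and the collection is freed afterwards. The PluginField/PluginFieldCollection structs and CreatePlugin below are simplified stand-ins, not the real nvinfer1 API, and the attribute values are illustrative only.
#include <cstdlib>
#include <string>
#include <vector>
struct PluginField {            // simplified stand-in for nvinfer1::PluginField
  const char* name;
  const void* data;
  int type;
  int length;
};
struct PluginFieldCollection {  // simplified stand-in for nvinfer1::PluginFieldCollection
  int nbFields;
  const PluginField* fields;
};
// Stand-in for IPluginCreator::createPlugin; a real creator would read fc->fields.
std::string CreatePlugin(const char* name, const PluginFieldCollection* fc) {
  return std::string(name) + " built from " + std::to_string(fc->nbFields) + " fields";
}
int main() {
  int hidden_size = 768, num_heads = 12, has_mask = 1;  // illustrative values
  std::vector<PluginField> fields{
      {"hidden_size", &hidden_size, /*kINT32*/ 0, 1},
      {"num_heads", &num_heads, 0, 1},
      {"has_mask", &has_mask, 0, 1},
  };
  // Allocate the collection header and point it at the attribute vector.
  auto* collection = static_cast<PluginFieldCollection*>(
      malloc(sizeof(*collection) + fields.size() * sizeof(PluginField)));
  collection->nbFields = static_cast<int>(fields.size());
  collection->fields = fields.data();
  std::string plugin = CreatePlugin("CustomQKVToContextPluginDynamic", collection);
  free(collection);  // the collection is only needed while the plugin is created
  return plugin.empty() ? 1 : 0;
}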
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h"
namespace paddle {
namespace inference {
......@@ -27,7 +26,6 @@ class ScaleOpConverter : public OpConverter {
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(3) << "convert a fluid scale op to tensorrt mul layer without bias";
std::cerr << "Scale converter" << std::endl;
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
......@@ -66,12 +64,6 @@ class ScaleOpConverter : public OpConverter {
platform::errors::Fatal(
"Paddle-TRT scale mode only support dimension >= 3"));
plugin::ConvertMaskPluginDynamic* plugin =
new plugin::ConvertMaskPluginDynamic();
auto convert_mask_layer = engine_->AddPluginV2(&input, 1, plugin);
convert_mask_layer->setName("convert_mask_layer");
engine_->SetITensor("fused_mha_mask", convert_mask_layer->getOutput(0));
nvinfer1::IShuffleLayer* expand_layer = nullptr;
nvinfer1::IShuffleLayer* squeeze_layer = nullptr;
......
......@@ -43,7 +43,7 @@ struct SimpleOpTypeSetTeller : public Teller {
private:
// use this set for no calib int8.
std::unordered_set<std::string> int8_teller_set{"mul",
std::unordered_set<std::string> int8_teller_set{"matmul",
"conv2d",
"pool2d",
"relu",
......@@ -59,7 +59,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"elementwise_mul",
"conv2d_transpose"};
std::unordered_set<std::string> teller_set{
"mul",
"matmul",
"conv2d",
"pool2d",
"relu",
......
......@@ -2,7 +2,7 @@ nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
cast_int_plugin.cu stack_op_plugin.cu convert_mask_plugin.cu
stack_op_plugin.cu convert_mask_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
nvinfer1::DimsExprs CastIntPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) {
assert(output_index == 0);
return inputs[0];
}
bool CastIntPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
int nb_outputs) {
const nvinfer1::PluginTensorDesc& in = in_out[pos];
return (in.type == nvinfer1::DataType::kINT32);
}
nvinfer1::DataType CastIntPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType* input_types, int nb_inputs) const {
PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
"The Cast Int only has one input, so the "
"index value should be 0, but get %d.",
index));
return input_types[index];
}
__global__ void castIntKernel(const int64_t* input, int32_t* output,
size_t num_elements) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_elements) return;
output[idx] = input[idx] + 1;
}
int CastIntPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs,
void* const* outputs, void* workspace,
cudaStream_t stream) {
auto input_dims = input_desc[0].dims;
auto output_dims = output_desc[0].dims;
size_t num_elements = ProductDim(input_dims);
size_t out_num_elements = ProductDim(output_dims);
assert(input_type ==
nvinfer1::DataType::kINT32); // although the input is int64_t
assert(num_elements == out_num_elements);
const size_t num_threads = 256;
castIntKernel<<<num_elements / num_threads + 1, num_threads>>>(
static_cast<const int64_t*>(inputs[0]), static_cast<int32_t*>(outputs[0]),
num_elements);
return cudaGetLastError() != cudaSuccess;
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
class CastIntPluginDynamic : public DynamicPluginTensorRT {
public:
CastIntPluginDynamic() {}
CastIntPluginDynamic(void const* serial_data, size_t serial_length) {}
~CastIntPluginDynamic() {}
nvinfer1::IPluginV2DynamicExt* clone() const override {
return new CastIntPluginDynamic();
}
const char* getPluginType() const override { return "cast_int_plugin"; }
int getNbOutputs() const override { return 1; }
int initialize() override { return 0; }
size_t getSerializationSize() const override { return 0; }
void serialize(void* buffer) const override {}
nvinfer1::DimsExprs getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* in_out,
int nb_inputs, int nb_outputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nb_inputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nb_outputs) override {}
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nb_inputs,
const nvinfer1::PluginTensorDesc* outputs,
int nb_outputs) const override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* input_types,
int nb_inputs) const override;
void destroy() override { delete this; }
};
class CastIntPluginV2Creator : public nvinfer1::IPluginCreator {
public:
CastIntPluginV2Creator() {}
const char* getPluginName() const override { return "cast_int_plugin"; }
const char* getPluginVersion() const override { return "1"; }
const nvinfer1::PluginFieldCollection* getFieldNames() override {
return &field_collection_;
}
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override {
return nullptr;
}
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
auto plugin = new CastIntPluginDynamic(serial_data, serial_length);
return plugin;
}
void setPluginNamespace(const char* lib_namespace) override {
plugin_namespace_ = lib_namespace;
}
const char* getPluginNamespace() const override {
return plugin_namespace_.c_str();
}
private:
std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(CastIntPluginV2Creator);
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -17,6 +17,7 @@
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
namespace paddle {
namespace inference {
......@@ -38,15 +39,23 @@ constexpr size_t packedMaskSize128 = xmmasM128 * threadsPerCta128;
nvinfer1::DimsExprs ConvertMaskPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) {
auto cms128 = expr_builder.constant(packedMaskSize128);
auto fp16maskSize = expr_builder.operation(
nvinfer1::DimensionOperation::kPROD, *cms128, *expr_builder.constant(2));
assert(output_index == 0);
if (type_ == nvinfer1::DataType::kHALF) {
auto cms128 = expr_builder.constant(packedMaskSize128);
auto fp16maskSize =
expr_builder.operation(nvinfer1::DimensionOperation::kPROD, *cms128,
*expr_builder.constant(2));
nvinfer1::DimsExprs ret;
ret.nbDims = 2;
ret.d[0] = inputs[0].d[0];
ret.d[1] = fp16maskSize;
return ret;
}
nvinfer1::DimsExprs ret;
ret.nbDims = 2;
ret.nbDims = 1;
ret.d[0] = inputs[0].d[0];
ret.d[1] = fp16maskSize;
return ret;
}
......@@ -54,22 +63,21 @@ bool ConvertMaskPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
int nb_outputs) {
const nvinfer1::PluginTensorDesc& desc = in_out[pos];
/* input: [B, S, S] */
/* input: [B, S, 1] */
/* output: [B, 2*maskSize] */
assert(nb_inputs == 1);
assert(nb_outputs == 1);
if (pos == 0) {
std::cerr << "desc.type: " << static_cast<int>(desc.type) << " "
<< desc.dims.nbDims << std::endl;
return ((desc.type == nvinfer1::DataType::kFLOAT ||
desc.type == nvinfer1::DataType::kHALF) &&
desc.dims.nbDims == 3);
}
std::cerr << "output.type: " << static_cast<int>(desc.type) << " "
<< desc.dims.nbDims << std::endl;
// return desc.type == nvinfer1::DataType::kHALF;
return true;
// return true;
/* fp16 -> fp16, fp32 -> int32 */
if (type_ == nvinfer1::DataType::kHALF)
return desc.type == nvinfer1::DataType::kHALF;
return desc.type == nvinfer1::DataType::kINT32;
}
nvinfer1::DataType ConvertMaskPluginDynamic::getOutputDataType(
......@@ -79,16 +87,36 @@ nvinfer1::DataType ConvertMaskPluginDynamic::getOutputDataType(
"The convert mask plugin only has one input, so the "
"index value should be 0, but get %d.",
index));
return nvinfer1::DataType::kHALF;
if (type_ == nvinfer1::DataType::kHALF) {
return nvinfer1::DataType::kHALF;
}
return nvinfer1::DataType::kINT32;
}
/* half [B, S, 1] -> int [S, B, 1] */
template <typename T>
__global__ void CastToIntAndReduce(const T* input, int* output, int seq_len,
__global__ void FullMaskPreprocess(const T* input, int* output, int seq_len,
int batch) {
int bid = blockIdx.x;
int sid = threadIdx.x;
output[sid * batch + bid] =
static_cast<int>(input[bid * seq_len * seq_len + sid]);
output[sid * batch + bid] = static_cast<int>(input[bid * seq_len + sid]);
}
/* float [B, S, 1] -> int [B] */
/* [[1. 1. 1. 0. 0.], -> [3, 4]
[1. 1. 1. 1. 0.]] */
__global__ void IMaskPreprocess(const float* input, int* output, int seq_len,
int batch) {
float sum = 0.f;
int bid = blockIdx.x;
int sid = threadIdx.x;
float thread_data = input[bid * seq_len + sid];
sum = paddle::operators::math::blockReduceSum<float>(thread_data, 0xffffffff);
if (sid == 0) {
output[bid] = static_cast<int>(sum);
}
}
__global__ void fillSBSMaskKernel(const uint32_t warps_m,
......@@ -159,33 +187,33 @@ int ConvertMaskPluginDynamic::enqueue(
int batch = input_dims.d[0];
int seq_len = input_dims.d[1];
assert(num_elements == out_num_elements * seq_len);
assert(seq_len <= 1024);
assert(output_desc.type == nvinfer1::DataType::kHALF);
// temp use, should remove
int* inputMaskSB;
cudaMalloc(&inputMaskSB, batch * seq_len * sizeof(int));
assert(seq_len == 128);
if (input_desc[0].type == nvinfer1::DataType::kFLOAT) {
CastToIntAndReduce<float><<<batch, seq_len, 0, stream>>>(
static_cast<const float*>(inputs[0]), inputMaskSB, seq_len, batch);
if (type_ == nvinfer1::DataType::kFLOAT) {
IMaskPreprocess<<<batch, seq_len, 0, stream>>>(
static_cast<const float*>(inputs[0]), static_cast<int*>(outputs[0]),
seq_len, batch);
} else {
CastToIntAndReduce<half><<<batch, seq_len, 0, stream>>>(
static_cast<const half*>(inputs[0]), inputMaskSB, seq_len, batch);
int* inputMaskSB;
cudaMalloc(&inputMaskSB, batch * seq_len * sizeof(int));
if (input_desc[0].type == nvinfer1::DataType::kFLOAT) {
FullMaskPreprocess<float><<<batch, seq_len, 0, stream>>>(
static_cast<const float*>(inputs[0]), inputMaskSB, seq_len, batch);
} else {
FullMaskPreprocess<half><<<batch, seq_len, 0, stream>>>(
static_cast<const half*>(inputs[0]), inputMaskSB, seq_len, batch);
}
size_t warps_m = 0, warps_n = 0, warps_k = 1;
if (seq_len == 128) {
warps_m = 2;
warps_n = 2;
}
convertMask(seq_len, batch, warps_m, warps_n, warps_k, inputMaskSB,
static_cast<uint32_t*>(outputs[0]), stream);
cudaFree(inputMaskSB);
}
assert(seq_len == 128);
size_t warps_m = 0, warps_n = 0, warps_k = 1;
if (seq_len == 128) {
warps_m = 2;
warps_n = 2;
}
convertMask(seq_len, batch, warps_m, warps_n, warps_k, inputMaskSB,
static_cast<uint32_t*>(outputs[0]), stream);
cudaFree(inputMaskSB);
return cudaGetLastError() != cudaSuccess;
}
#endif
......
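For illustration, a self-contained CUDA sketch of the fp32 path added above: a [B, S] 0/1 padding mask is reduced to a per-batch valid-token count, the same transformation IMaskPreprocess performs (e.g. [[1,1,1,0,0],[1,1,1,1,0]] -> [3,4]). A naive shared-memory tree reduction stands in for paddle::operators::math::blockReduceSum here; the kernel and variable names are illustrative, not part of the plugin.
#include <cstdio>
#include <cuda_runtime.h>
// One block per batch row; each thread loads one mask position, then a tree
// reduction in shared memory produces the number of valid (non-padded) tokens.
__global__ void MaskToValidLen(const float* mask, int* valid_len, int seq_len) {
  extern __shared__ float partial[];
  int bid = blockIdx.x;   // batch index
  int sid = threadIdx.x;  // sequence position
  partial[sid] = (sid < seq_len) ? mask[bid * seq_len + sid] : 0.f;
  __syncthreads();
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (sid < stride) partial[sid] += partial[sid + stride];
    __syncthreads();
  }
  if (sid == 0) valid_len[bid] = static_cast<int>(partial[0]);
}
int main() {
  const int batch = 2, seq_len = 5, block = 8;  // block rounded up to a power of two
  float h_mask[batch * seq_len] = {1, 1, 1, 0, 0,   // 3 valid tokens
                                   1, 1, 1, 1, 0};  // 4 valid tokens
  float* d_mask = nullptr;
  int* d_len = nullptr;
  int h_len[batch];
  cudaMalloc(&d_mask, sizeof(h_mask));
  cudaMalloc(&d_len, batch * sizeof(int));
  cudaMemcpy(d_mask, h_mask, sizeof(h_mask), cudaMemcpyHostToDevice);
  MaskToValidLen<<<batch, block, block * sizeof(float)>>>(d_mask, d_len, seq_len);
  cudaMemcpy(h_len, d_len, batch * sizeof(int), cudaMemcpyDeviceToHost);
  printf("valid lengths: %d %d\n", h_len[0], h_len[1]);  // expected: 3 4
  cudaFree(d_mask);
  cudaFree(d_len);
  return 0;
}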
......@@ -27,20 +27,32 @@ namespace plugin {
#if IS_TRT_VERSION_GE(6000)
class ConvertMaskPluginDynamic : public DynamicPluginTensorRT {
public:
ConvertMaskPluginDynamic() {}
ConvertMaskPluginDynamic(void const* serial_data, size_t serial_length) {}
explicit ConvertMaskPluginDynamic(nvinfer1::DataType type) : type_(type) {
assert(type == nvinfer1::DataType::kHALF ||
type == nvinfer1::DataType::kFLOAT);
}
ConvertMaskPluginDynamic(void const* serial_data, size_t serial_length) {
DeserializeValue(&serial_data, &serial_length, &type_);
}
~ConvertMaskPluginDynamic() {}
nvinfer1::IPluginV2DynamicExt* clone() const override {
return new ConvertMaskPluginDynamic();
return new ConvertMaskPluginDynamic(type_);
}
const char* getPluginType() const override { return "convert_mask_plugin"; }
int getNbOutputs() const override { return 1; }
int initialize() override { return 0; }
size_t getSerializationSize() const override { return 0; }
void serialize(void* buffer) const override {}
size_t getSerializationSize() const override {
size_t serialize_size = 0;
serialize_size += SerializedSize(type_);
return serialize_size;
}
void serialize(void* buffer) const override {
SerializeValue(&buffer, type_);
}
nvinfer1::DimsExprs getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
......@@ -71,6 +83,9 @@ class ConvertMaskPluginDynamic : public DynamicPluginTensorRT {
int nb_inputs) const override;
void destroy() override { delete this; }
private:
nvinfer1::DataType type_;
};
class ConvertMaskPluginV2Creator : public nvinfer1::IPluginCreator {
......
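As a companion to the new serialization members above, a minimal sketch of the serialize/deserialize round trip for the type_ field. SerializeValue and DeserializeValue below are simplified stand-ins for Paddle's helpers of the same names, and the DataType enum is a stand-in for nvinfer1::DataType.
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>
enum class DataType : int { kFLOAT = 0, kHALF = 1 };  // stand-in for nvinfer1::DataType
// Simplified stand-ins for Paddle's SerializeValue / DeserializeValue helpers.
template <typename T>
void SerializeValue(void** buffer, const T& value) {
  std::memcpy(*buffer, &value, sizeof(T));
  *buffer = static_cast<char*>(*buffer) + sizeof(T);
}
template <typename T>
void DeserializeValue(const void** buffer, size_t* length, T* value) {
  assert(*length >= sizeof(T));
  std::memcpy(value, *buffer, sizeof(T));
  *buffer = static_cast<const char*>(*buffer) + sizeof(T);
  *length -= sizeof(T);
}
int main() {
  DataType type = DataType::kHALF;
  std::vector<char> storage(sizeof(DataType));   // what getSerializationSize() reports
  void* write_ptr = storage.data();
  SerializeValue(&write_ptr, type);              // what serialize(buffer) does
  const void* read_ptr = storage.data();
  size_t remaining = storage.size();
  DataType restored = DataType::kFLOAT;
  DeserializeValue(&read_ptr, &remaining, &restored);  // what the deserializing ctor does
  assert(restored == type && remaining == 0);
  return 0;
}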