Unverified commit af97b310 authored by ccrrong, committed by GitHub

add slice plugin int32 support (#43808)

* add slice plugin int32 support
Parent eec4e034
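This commit extends Paddle's TensorRT slice plugin so the templated SliceKernel can also run on int32 tensors: the type checks in supportsFormat / supportsFormatCombination / getOutputDataType now accept nvinfer1::DataType::kINT32, both enqueue paths gain a SliceKernel<int> dispatch branch, and int32 unit tests are added. A compilable sketch of the dispatch pattern being extended (names such as launch_slice and dispatch_slice are illustrative, not from the Paddle source):

```cpp
// Sketch only: maps a runtime TensorRT dtype to a SliceKernel<T>
// instantiation. kINT32 is the branch this commit adds.
#include <NvInfer.h>
#include <cuda_fp16.h>

template <typename T>
void launch_slice(const void *in, void *out) {
  // stands in for SliceKernel<T><<<blocks, threads, smem, stream>>>(...)
  (void)in;
  (void)out;
}

inline bool dispatch_slice(nvinfer1::DataType t, const void *in, void *out) {
  switch (t) {
    case nvinfer1::DataType::kFLOAT:
      launch_slice<float>(in, out);
      return true;
    case nvinfer1::DataType::kHALF:
      launch_slice<half>(in, out);
      return true;
    case nvinfer1::DataType::kINT32:  // newly supported type
      launch_slice<int>(in, out);
      return true;
    default:
      return false;  // the real code raises PADDLE_THROW here
  }
}
```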
@@ -28,8 +28,8 @@ namespace tensorrt {
 namespace plugin {
 template <typename T>
-__global__ void SliceKernel(int num, int dims, const T *input,
-                            const int *offsets_info, T *output) {
+__global__ void SliceKernel(
+    int num, int dims, const T *input, const int *offsets_info, T *output) {
   const int idx = blockIdx.x * blockDim.x + threadIdx.x;
   extern __shared__ int shared_data[];
@@ -54,8 +54,10 @@ __global__ void SliceKernel(int num, int dims, const T *input,
   }
 }
-SlicePlugin::SlicePlugin(std::vector<int> starts, std::vector<int> ends,
-                         std::vector<int> axes, bool with_fp16)
+SlicePlugin::SlicePlugin(std::vector<int> starts,
+                         std::vector<int> ends,
+                         std::vector<int> axes,
+                         bool with_fp16)
     : starts_(starts), ends_(ends), axes_(axes) {
   with_fp16_ = with_fp16;
 }
@@ -79,10 +81,12 @@ bool SlicePlugin::supportsFormat(
     nvinfer1::DataType type, nvinfer1::PluginFormat format) const TRT_NOEXCEPT {
   if (with_fp16_) {
     return ((type == nvinfer1::DataType::kFLOAT ||
-             type == nvinfer1::DataType::kHALF) &&
+             type == nvinfer1::DataType::kHALF ||
+             type == nvinfer1::DataType::kINT32) &&
             (format == nvinfer1::PluginFormat::kLINEAR));
   } else {
-    return ((type == nvinfer1::DataType::kFLOAT) &&
+    return ((type == nvinfer1::DataType::kFLOAT ||
+             type == nvinfer1::DataType::kINT32) &&
             (format == nvinfer1::PluginFormat::kLINEAR));
   }
 }
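The effect of the supportsFormat change, restated as a small standalone predicate (a sketch assuming only the NvInfer.h enums): kFLOAT and kINT32 are now accepted whether or not FP16 mode is on, and with_fp16 only widens the set by kHALF.

```cpp
#include <NvInfer.h>

// Equivalent acceptance check after this change (linear format assumed).
bool slice_type_supported(nvinfer1::DataType type, bool with_fp16) {
  if (type == nvinfer1::DataType::kFLOAT ||
      type == nvinfer1::DataType::kINT32) {
    return true;
  }
  return with_fp16 && type == nvinfer1::DataType::kHALF;
}
```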
@@ -99,11 +103,15 @@ nvinfer1::Dims SlicePlugin::getOutputDimensions(
   return out_dims;
 }
-int SlicePlugin::enqueue(int batch_size, const void *const *inputs,
+int SlicePlugin::enqueue(int batch_size,
+                         const void *const *inputs,
 #if IS_TRT_VERSION_LT(8000)
-                         void **outputs, void *workspace, cudaStream_t stream) {
+                         void **outputs,
+                         void *workspace,
+                         cudaStream_t stream) {
 #else
-                         void *const *outputs, void *workspace,
+                         void *const *outputs,
+                         void *workspace,
                          cudaStream_t stream) TRT_NOEXCEPT {
 #endif
   auto input_dims = getInputDims(0);
@@ -153,8 +161,11 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs,
     cudaMalloc(&offset_temp_data_, 3 * num_dims * sizeof(int));
   }
-  cudaMemcpyAsync(offset_temp_data_, offset_info.data(),
-                  sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice, stream);
+  cudaMemcpyAsync(offset_temp_data_,
+                  offset_info.data(),
+                  sizeof(int) * 3 * num_dims,
+                  cudaMemcpyHostToDevice,
+                  stream);
   int threads = 256;
   int blocks = (out_num + threads - 1) / threads;
@@ -171,9 +182,15 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs,
     half *output = static_cast<half *>(outputs[0]);
     SliceKernel<half><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
         out_num, num_dims, input1, offset_temp_data_, output);
+  } else if (input_type == nvinfer1::DataType::kINT32) {
+    VLOG(1) << "TRT Plugin DataType selected. Slice-->int32";
+    const int *input1 = static_cast<const int *>(inputs[0]);
+    int *output = static_cast<int *>(outputs[0]);
+    SliceKernel<int><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
+        out_num, num_dims, input1, offset_temp_data_, output);
   } else {
     PADDLE_THROW(platform::errors::Fatal(
-        "The Slice TRT Plugin's input type should be float or half."));
+        "The Slice TRT Plugin's input type should be float, half or int."));
   }
   return cudaGetLastError() != cudaSuccess;
 }
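The new int branch reuses the machinery the float and half paths already rely on: offset_temp_data_ holds 3 ints per dimension (judging from the 3 * num_dims * sizeof(int) allocation, a plausible layout is a start offset, an output extent, and an input stride per axis), and the grid is sized by ceiling division, blocks = (out_num + threads - 1) / threads. Below is a self-contained toy program illustrating that indexing scheme on an int32 tensor; it is a sketch of the idea, not the Paddle kernel, which additionally stages offsets_info in shared memory:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Toy version of the slice indexing: decode a flat output index into an
// input index using per-dimension [start offset, output extent, stride].
template <typename T>
__global__ void ToySliceKernel(int num, int dims, const T *input,
                               const int *offsets_info, T *output) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= num) return;
  int t_idx = idx, in_idx = 0;
  for (int i = dims - 1; i >= 0; --i) {
    int out_extent = offsets_info[i * 3 + 1];
    int coord = t_idx % out_extent + offsets_info[i * 3];  // add start offset
    in_idx += coord * offsets_info[i * 3 + 2];             // apply input stride
    t_idx /= out_extent;
  }
  output[idx] = input[in_idx];
}

int main() {
  // Slice rows [1, 3) of a 4x4 int32 matrix -> 2x4 output.
  const int in_data[16] = {0, 1, 2, 3, 4, 5, 6, 7,
                           8, 9, 10, 11, 12, 13, 14, 15};
  const int offsets_info[6] = {1, 2, 4,   // dim0: start=1, extent=2, stride=4
                               0, 4, 1};  // dim1: start=0, extent=4, stride=1
  int *d_in, *d_out, *d_info, out[8];
  cudaMalloc(&d_in, sizeof(in_data));
  cudaMalloc(&d_out, sizeof(out));
  cudaMalloc(&d_info, sizeof(offsets_info));
  cudaMemcpy(d_in, in_data, sizeof(in_data), cudaMemcpyHostToDevice);
  cudaMemcpy(d_info, offsets_info, sizeof(offsets_info),
             cudaMemcpyHostToDevice);
  int threads = 256, out_num = 8;
  int blocks = (out_num + threads - 1) / threads;  // ceiling division
  ToySliceKernel<int><<<blocks, threads>>>(out_num, 2, d_in, d_info, d_out);
  cudaMemcpy(out, d_out, sizeof(out), cudaMemcpyDeviceToHost);
  for (int v : out) printf("%d ", v);  // expected: 4 5 6 7 8 9 10 11
  printf("\n");
  return 0;
}
```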
@@ -197,7 +214,8 @@ void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT {
 #if IS_TRT_VERSION_GE(6000)
 SlicePluginDynamic::SlicePluginDynamic(std::vector<int> starts,
                                        std::vector<int> ends,
-                                       std::vector<int> axes, int decrease_axis,
+                                       std::vector<int> axes,
+                                       int decrease_axis,
                                        bool with_fp16)
     : starts_(starts), ends_(ends), axes_(axes), decrease_axis_(decrease_axis) {
   with_fp16_ = with_fp16;
@@ -238,7 +256,9 @@ void SlicePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
 }
 nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
-    int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
+    int output_index,
+    const nvinfer1::DimsExprs *inputs,
+    int nb_inputs,
     nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
   auto in_dims = inputs[0];
   nvinfer1::DimsExprs ret = in_dims;
@@ -264,7 +284,8 @@ nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
     for (size_t i = 0; i < in_dims.nbDims; i++) {
       if (decrease_axis_ == i) continue;
       res.d[j++] = expr_builder.operation(nvinfer1::DimensionOperation::kMAX,
-                                          *expr_builder.constant(0), *ret.d[i]);
+                                          *expr_builder.constant(0),
+                                          *ret.d[i]);
     }
     return res;
   }
@@ -272,26 +293,33 @@ nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
 }
 bool SlicePluginDynamic::supportsFormatCombination(
-    int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
+    int pos,
+    const nvinfer1::PluginTensorDesc *in_out,
+    int nb_inputs,
     int nb_outputs) TRT_NOEXCEPT {
   PADDLE_ENFORCE_NOT_NULL(
-      in_out, platform::errors::InvalidArgument(
-                  "The input of swish plugin shoule not be nullptr."));
+      in_out,
+      platform::errors::InvalidArgument(
+          "The input of swish plugin shoule not be nullptr."));
   PADDLE_ENFORCE_LT(
-      pos, nb_inputs + nb_outputs,
+      pos,
+      nb_inputs + nb_outputs,
       platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                         "num(%d) of the input and the output.",
-                                        pos, nb_inputs + nb_outputs));
+                                        pos,
+                                        nb_inputs + nb_outputs));
   const nvinfer1::PluginTensorDesc &in = in_out[pos];
   if (pos == 0) {
     if (with_fp16_) {
       return (in.type == nvinfer1::DataType::kFLOAT ||
-              in.type == nvinfer1::DataType::kHALF) &&
+              in.type == nvinfer1::DataType::kHALF ||
+              in.type == nvinfer1::DataType::kINT32) &&
              (in.format == nvinfer1::TensorFormat::kLINEAR);
     } else {
-      return (in.type == nvinfer1::DataType::kFLOAT) &&
+      return (in.type == nvinfer1::DataType::kFLOAT ||
+              in.type == nvinfer1::DataType::kINT32) &&
              (in.format == nvinfer1::TensorFormat::kLINEAR);
     }
   }
@@ -301,24 +329,28 @@ bool SlicePluginDynamic::supportsFormatCombination(
 }
 nvinfer1::DataType SlicePluginDynamic::getOutputDataType(
-    int index, const nvinfer1::DataType *input_types,
+    int index,
+    const nvinfer1::DataType *input_types,
     int nb_inputs) const TRT_NOEXCEPT {
-  PADDLE_ENFORCE_EQ(index, 0,
+  PADDLE_ENFORCE_EQ(index,
+                    0,
                     platform::errors::InvalidArgument(
                         "The Slice Plugin only has one input, so the "
                         "index value should be 0, but get %d.",
                         index));
   PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT ||
-                     input_types[0] == nvinfer1::DataType::kHALF),
+                     input_types[0] == nvinfer1::DataType::kHALF ||
+                     input_types[0] == nvinfer1::DataType::kINT32),
                     true,
                     platform::errors::InvalidArgument(
-                        "The input type should be half or float"));
+                        "The input type should be half, float or int"));
   return input_types[0];
 }
 int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
                                 const nvinfer1::PluginTensorDesc *output_desc,
-                                const void *const *inputs, void *const *outputs,
+                                const void *const *inputs,
+                                void *const *outputs,
                                 void *workspace,
                                 cudaStream_t stream) TRT_NOEXCEPT {
   auto input_dims = input_desc[0].dims;
@@ -362,8 +394,11 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
     cudaMalloc(&offset_temp_data_, 3 * num_dims * sizeof(int));
   }
-  cudaMemcpyAsync(offset_temp_data_, offset_info_.data(),
-                  sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice, stream);
+  cudaMemcpyAsync(offset_temp_data_,
+                  offset_info_.data(),
+                  sizeof(int) * 3 * num_dims,
+                  cudaMemcpyHostToDevice,
+                  stream);
   int threads = 256;
   int blocks = (out_num + threads - 1) / threads;
@@ -380,9 +415,15 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
     half *output = static_cast<half *>(outputs[0]);
     SliceKernel<half><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
         out_num, num_dims, input1, offset_temp_data_, output);
+  } else if (input_type == nvinfer1::DataType::kINT32) {
+    VLOG(1) << "TRT Plugin DataType selected. Slice-->int32";
+    const int *input1 = static_cast<const int *>(inputs[0]);
+    int *output = static_cast<int *>(outputs[0]);
+    SliceKernel<int><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
+        out_num, num_dims, input1, offset_temp_data_, output);
   } else {
     PADDLE_THROW(platform::errors::Fatal(
-        "The Slice TRT Plugin's input type should be float or half."));
+        "The Slice TRT Plugin's input type should be float, half or int."));
   }
   return cudaGetLastError() != cudaSuccess;
 }
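The dynamic-shape path mirrors the static one line for line, and both enqueue implementations end with the same return convention: TensorRT treats a zero return as success, so any pending CUDA error from the kernel launch is mapped to a nonzero value. A minimal restatement of that convention (sketch):

```cpp
#include <cuda_runtime.h>

// 0 on success, nonzero if a CUDA error is pending after the launch.
inline int enqueue_status() {
  return cudaGetLastError() != cudaSuccess;
}
```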
@@ -108,5 +108,56 @@ class StaticSlicePluginTRTTestFp32(SlicePluginTRTTest):
         self.enable_trt = True
+
+
+class SlicePluginTRTTestInt32(SlicePluginTRTTest):
+    def setUp(self):
+        self.setUpSliceParams()
+        self.setUpTensorRTParams()
+        with fluid.program_guard(self.main_program, self.startup_program):
+            data = fluid.data(name="data", shape=[3, 3, 3, 3], dtype="int32")
+            axes = self.params_axes
+            starts = self.params_starts
+            ends = self.params_ends
+            slice_out = fluid.layers.slice(data,
+                                           axes=axes,
+                                           starts=starts,
+                                           ends=ends)
+            cast_out = fluid.layers.cast(slice_out, 'float32')
+            out = fluid.layers.batch_norm(cast_out, is_test=True)
+
+        self.feeds = {
+            "data": np.random.random((3, 3, 3, 3)).astype("int32"),
+        }
+        self.fetch_list = [out]
+
+
+class StaticSlicePluginTRTTestInt32(SlicePluginTRTTest):
+    def setUpTensorRTParams(self):
+        self.trt_parameters = SlicePluginTRTTest.TensorRTParam(
+            1 << 30, 32, 1, AnalysisConfig.Precision.Float32, True, False)
+        self.enable_trt = True
+
+    def setUp(self):
+        self.setUpSliceParams()
+        self.setUpTensorRTParams()
+        with fluid.program_guard(self.main_program, self.startup_program):
+            data = fluid.data(name="data", shape=[3, 3, 3, 3], dtype="int32")
+            axes = self.params_axes
+            starts = self.params_starts
+            ends = self.params_ends
+            slice_out = fluid.layers.slice(data,
+                                           axes=axes,
+                                           starts=starts,
+                                           ends=ends)
+            cast_out = fluid.layers.cast(slice_out, 'float32')
+            out = fluid.layers.batch_norm(cast_out, is_test=True)
+
+        self.feeds = {
+            "data": np.random.random((3, 3, 3, 3)).astype("int32"),
+        }
+        self.fetch_list = [out]
+
 if __name__ == "__main__":
     unittest.main()