未验证 提交 47fdc60e 编写于 作者: S Shang Zhizhou 提交者: GitHub

Optimize slice trt plugin (#26970)

* optimize slice TRT plugin

This patch removes unnecessary barrier for data transfer of needed offset,
so data transfer can be overlap with GPU kernel execution.

This patch also fixes incorrect name of slice plugin. That is, replaces
"layernorm" with "slice"

test=develop

* add serialize/deserialize to slice plugin

* add static shape slice trt plugin

* fix slice trt op convertor dynamic shape bug

* fix format by clang-format

* fix pylint format error

* fix problems commented by peiyang
Co-authored-by: NRyan Jeng <rjeng@nvidia.com>
上级 cb34cf18
...@@ -23,9 +23,8 @@ class SliceOpConverter : public OpConverter { ...@@ -23,9 +23,8 @@ class SliceOpConverter : public OpConverter {
public: public:
void operator()(const framework::proto::OpDesc& op, void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override { const framework::Scope& scope, bool test_mode) override {
// This OP is implemented by trt dynamic shpae plugin. // This OP is implemented by trt dynamic shpae plugin.
// Dynamic shape plugin requires TRT version greater than 6.0. // Dynamic shape plugin requires TRT version greater than 6.0.
#if IS_TRT_VERSION_GE(6000)
VLOG(4) << "convert slice op to tensorrt layer"; VLOG(4) << "convert slice op to tensorrt layer";
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
// Declare inputs // Declare inputs
...@@ -38,27 +37,65 @@ class SliceOpConverter : public OpConverter { ...@@ -38,27 +37,65 @@ class SliceOpConverter : public OpConverter {
std::vector<int> ends = std::vector<int> ends =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("ends")); BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("ends"));
PADDLE_ENFORCE_EQ(
starts.size(), axes.size(),
platform::errors::InvalidArgument(
"The size of starts must be equal to the size of axes."));
PADDLE_ENFORCE_EQ(
ends.size(), axes.size(),
platform::errors::InvalidArgument(
"The size of ends must be equal to the size of axes."));
auto input_dims = input->getDimensions();
if (!engine_->with_dynamic_shape()) {
// notice that input shape is [CHW] without batch axis when input has
// static shape
for (size_t i = input_dims.nbDims; i > 0; i--) {
input_dims.d[i] = input_dims.d[i - 1];
}
input_dims.d[0] = 1; // fake batchsize, not useful here
for (size_t i = 0; i < axes.size(); i++) {
// split on batch is not supported in TensorRT
PADDLE_ENFORCE_NE(axes[i], 0, platform::errors::InvalidArgument(
"Invalid slice axis. Slice on batch "
"axis is not supported in TensorRT"));
if (starts[i] < 0) {
starts[i] = std::max(starts[i] + input_dims.d[axes[i]], 0);
}
if (ends[i] < 0) {
ends[i] = std::max(ends[i] + input_dims.d[axes[i]], 0);
}
ends[i] = std::min(ends[i], input_dims.d[axes[i]]);
PADDLE_ENFORCE_GT(
ends[i], starts[i],
platform::errors::InvalidArgument(
"Attr(ends) should be greater than attr(starts) in "
"slice op. But received ends = %d, starts = %d.",
ends[i], starts[i]));
}
}
nvinfer1::ILayer* layer = nullptr; nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) { if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
bool ban_fp16 = engine_->disable_trt_plugin_fp16(); bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin::SlicePluginDynamic* plugin = plugin::SlicePluginDynamic* plugin =
new plugin::SlicePluginDynamic(starts, ends, ends, ban_fp16); new plugin::SlicePluginDynamic(starts, ends, axes, ban_fp16);
layer = engine_->AddPluginV2(&input, 1, plugin); layer = engine_->AddPluginV2(&input, 1, plugin);
} else {
PADDLE_THROW(platform::errors::Fatal(
"You are running the Ernie(Bert) model in static"
"shape mode, which is not supported for the time being.\n"
"You can use the config.SetTRTDynamicShapeInfo(...) interface"
" to set the shape information to run the dynamic shape mode."));
}
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "skip_layernorm", {output_name}, test_mode);
#else #else
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that " "You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0")); "your TRT version is no less than 6.0"));
#endif #endif
} else {
bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin::SlicePlugin* plugin =
new plugin::SlicePlugin(starts, ends, axes, ban_fp16);
layer = engine_->AddPlugin(&input, 1, plugin);
}
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "slice", {output_name}, test_mode);
} }
}; };
......
...@@ -31,6 +31,7 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -31,6 +31,7 @@ struct SimpleOpTypeSetTeller : public Teller {
teller_set.insert("fused_embedding_eltwise_layernorm"); teller_set.insert("fused_embedding_eltwise_layernorm");
teller_set.insert("multihead_matmul"); teller_set.insert("multihead_matmul");
teller_set.insert("skip_layernorm"); teller_set.insert("skip_layernorm");
teller_set.insert("slice");
#endif #endif
} }
......
...@@ -26,8 +26,10 @@ namespace inference { ...@@ -26,8 +26,10 @@ namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
// Dynamic Plugin below. SlicePlugin *CreateSlicePluginDeserialize(const void *buffer, size_t length) {
#if IS_TRT_VERSION_GE(6000) return new SlicePlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("slice_plugin", CreateSlicePluginDeserialize);
template <typename T> template <typename T>
__global__ void SliceKernel(int num, int dims, const T *input, __global__ void SliceKernel(int num, int dims, const T *input,
...@@ -56,11 +58,196 @@ __global__ void SliceKernel(int num, int dims, const T *input, ...@@ -56,11 +58,196 @@ __global__ void SliceKernel(int num, int dims, const T *input,
} }
} }
SlicePlugin::SlicePlugin(std::vector<int> starts, std::vector<int> ends,
std::vector<int> axes, bool ban_fp16)
: starts_(starts), ends_(ends), axes_(axes), ban_fp16_(ban_fp16) {
cudaEventCreate(&copy_event_);
cudaStreamCreate(&copy_stream_);
}
SlicePlugin::SlicePlugin(void const *serial_data, size_t serial_length) {
deserializeBase(serial_data, serial_length);
DeserializeValue(&serial_data, &serial_length, &starts_);
DeserializeValue(&serial_data, &serial_length, &ends_);
DeserializeValue(&serial_data, &serial_length, &axes_);
DeserializeValue(&serial_data, &serial_length, &ban_fp16_);
cudaEventCreate(&copy_event_);
cudaStreamCreate(&copy_stream_);
}
SlicePlugin::~SlicePlugin() {
cudaStreamDestroy(copy_stream_);
cudaEventDestroy(copy_event_);
cudaFree(offset_temp_data_);
}
SlicePlugin *SlicePlugin::clone() const {
return new SlicePlugin(starts_, ends_, axes_, ban_fp16_);
}
bool SlicePlugin::supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const {
#ifdef SUPPORTS_CUDA_FP16
return ((type == nvinfer1::DataType::kFLOAT ||
type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::PluginFormat::kNCHW));
#else
return ((type == nvinfer1::DataType::kFLOAT) &&
(format == nvinfer1::PluginFormat::kNCHW));
#endif
}
nvinfer1::Dims SlicePlugin::getOutputDimensions(int index,
const nvinfer1::Dims *inputs,
int nb_input_dims) {
auto in_dims = inputs[0];
nvinfer1::Dims out_dims = in_dims;
for (size_t i = 0; i < axes_.size(); i++) {
int start = starts_[i];
int end = ends_[i];
out_dims.d[axes_[i] - 1] = end - start;
}
return out_dims;
}
int SlicePlugin::enqueue(int batch_size, const void *const *inputs,
void **outputs, void *workspace, cudaStream_t stream) {
auto input_dims = getInputDims(0);
// notice input dims is [C, H, W], add input batch dim here
auto out_dims = getOutputDimensions(0, &input_dims, 1);
input_dims.nbDims += 1;
out_dims.nbDims += 1;
for (auto i = input_dims.nbDims; i > 0; --i) {
input_dims.d[i] = input_dims.d[i - 1];
out_dims.d[i] = out_dims.d[i - 1];
}
input_dims.d[0] = batch_size;
out_dims.d[0] = batch_size;
auto num_dims = input_dims.nbDims;
size_t out_num = ProductDim(out_dims);
std::vector<int> seg_offsets;
std::vector<int> offsets;
std::vector<int> extends;
offsets.resize(num_dims);
extends.resize(num_dims);
seg_offsets.resize(num_dims);
seg_offsets[num_dims - 1] = 1;
for (int i = num_dims - 2; i >= 0; i--) {
seg_offsets[i] = input_dims.d[i + 1] * seg_offsets[i + 1];
}
for (size_t i = 0; i < num_dims; ++i) {
offsets[i] = 0;
extends[i] = out_dims.d[i];
}
for (size_t i = 0; i < axes_.size(); ++i) {
offsets[axes_[i]] = starts_[i];
}
std::vector<int> offset_info;
for (size_t i = 0; i < num_dims; ++i) {
offset_info.push_back(offsets[i]);
offset_info.push_back(extends[i]);
offset_info.push_back(seg_offsets[i]);
}
if (offset_temp_data_ == nullptr) {
cudaMalloc(&offset_temp_data_, 3 * num_dims * sizeof(int));
}
cudaMemcpyAsync(offset_temp_data_, offset_info.data(),
sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice,
copy_stream_);
cudaEventRecord(copy_event_, copy_stream_);
cudaStreamWaitEvent(stream, copy_event_, 0);
int threads = 256;
int blocks = (out_num + threads - 1) / threads;
auto input_type = getDataType();
if (input_type == nvinfer1::DataType::kFLOAT) {
const float *input1 = static_cast<const float *>(inputs[0]);
float *output = static_cast<float *>(outputs[0]);
SliceKernel<float><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
out_num, num_dims, input1, offset_temp_data_, output);
} else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef SUPPORTS_CUDA_FP16
const half *input1 = static_cast<const half *>(inputs[0]);
half *output = static_cast<half *>(outputs[0]);
SliceKernel<half><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
out_num, num_dims, input1, offset_temp_data_, output);
#else
PADDLE_THROW(platform::errors::Fatal(
"The cuda archs you specific should greater than 600."));
#endif
} else {
PADDLE_THROW(platform::errors::Fatal(
"The Slice TRT Plugin's input type should be float or half."));
}
return cudaGetLastError() != cudaSuccess;
}
size_t SlicePlugin::getSerializationSize() {
return getBaseSerializationSize() + SerializedSize(getPluginType()) +
SerializedSize(starts_) + SerializedSize(ends_) +
SerializedSize(axes_) + SerializedSize(ban_fp16_);
}
void SlicePlugin::serialize(void *buffer) {
SerializeValue(&buffer, getPluginType());
serializeBase(buffer);
SerializeValue(&buffer, starts_);
SerializeValue(&buffer, ends_);
SerializeValue(&buffer, axes_);
SerializeValue(&buffer, ban_fp16_);
}
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
SlicePluginDynamic::SlicePluginDynamic(std::vector<int> starts,
std::vector<int> ends,
std::vector<int> axes, bool ban_fp16)
: starts_(starts), ends_(ends), axes_(axes), ban_fp16_(ban_fp16) {
cudaEventCreate(&copy_event_);
cudaStreamCreate(&copy_stream_);
}
SlicePluginDynamic::SlicePluginDynamic(void const *serialData,
size_t serialLength) {
DeserializeValue(&serialData, &serialLength, &starts_);
DeserializeValue(&serialData, &serialLength, &ends_);
DeserializeValue(&serialData, &serialLength, &axes_);
DeserializeValue(&serialData, &serialLength, &ban_fp16_);
cudaEventCreate(&copy_event_);
cudaStreamCreate(&copy_stream_);
}
void SlicePluginDynamic::destroy() {
cudaStreamDestroy(copy_stream_);
cudaEventDestroy(copy_event_);
cudaFree(offset_temp_data_);
delete this;
}
int SlicePluginDynamic::initialize() { return 0; } int SlicePluginDynamic::initialize() { return 0; }
size_t SlicePluginDynamic::getSerializationSize() const { return 0; } size_t SlicePluginDynamic::getSerializationSize() const {
size_t size = SerializedSize(starts_) + SerializedSize(ends_) +
SerializedSize(axes_) + SerializedSize(ban_fp16_);
void SlicePluginDynamic::serialize(void *buffer) const {} return size;
}
void SlicePluginDynamic::serialize(void *buffer) const {
SerializeValue(&buffer, starts_);
SerializeValue(&buffer, ends_);
SerializeValue(&buffer, axes_);
SerializeValue(&buffer, ban_fp16_);
}
nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions( nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs, int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
...@@ -136,9 +323,9 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, ...@@ -136,9 +323,9 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
std::vector<int> offsets; std::vector<int> offsets;
std::vector<int> extends; std::vector<int> extends;
offsets.reserve(num_dims); offsets.resize(num_dims);
extends.reserve(num_dims); extends.resize(num_dims);
seg_offsets.reserve(num_dims); seg_offsets.resize(num_dims);
seg_offsets[num_dims - 1] = 1; seg_offsets[num_dims - 1] = 1;
for (int i = num_dims - 2; i >= 0; i--) { for (int i = num_dims - 2; i >= 0; i--) {
...@@ -160,16 +347,16 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, ...@@ -160,16 +347,16 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
offset_info.push_back(seg_offsets[i]); offset_info.push_back(seg_offsets[i]);
} }
framework::Tensor offset_temp_tensor; if (offset_temp_data_ == nullptr) {
cudaMalloc(&offset_temp_data_, 3 * num_dims * sizeof(int));
}
int device_id; cudaMemcpyAsync(offset_temp_data_, offset_info.data(),
cudaGetDevice(&device_id); sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice,
offset_temp_tensor.Resize({3 * num_dims}); copy_stream_);
auto *offset_temp_data =
offset_temp_tensor.mutable_data<int>(platform::CUDAPlace(device_id));
cudaMemcpyAsync(offset_temp_data, offset_info.data(), cudaEventRecord(copy_event_, copy_stream_);
sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice, stream); cudaStreamWaitEvent(stream, copy_event_, 0);
int threads = 256; int threads = 256;
int blocks = (out_num + threads - 1) / threads; int blocks = (out_num + threads - 1) / threads;
...@@ -178,13 +365,13 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, ...@@ -178,13 +365,13 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
const float *input1 = static_cast<const float *>(inputs[0]); const float *input1 = static_cast<const float *>(inputs[0]);
float *output = static_cast<float *>(outputs[0]); float *output = static_cast<float *>(outputs[0]);
SliceKernel<float><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>( SliceKernel<float><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
out_num, num_dims, input1, offset_temp_data, output); out_num, num_dims, input1, offset_temp_data_, output);
} else if (input_type == nvinfer1::DataType::kHALF) { } else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef SUPPORTS_CUDA_FP16 #ifdef SUPPORTS_CUDA_FP16
const half *input1 = static_cast<const half *>(inputs[0]); const half *input1 = static_cast<const half *>(inputs[0]);
half *output = static_cast<half *>(outputs[0]); half *output = static_cast<half *>(outputs[0]);
SliceKernel<half><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>( SliceKernel<half><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
out_num, num_dims, input1, offset_temp_data, output); out_num, num_dims, input1, offset_temp_data_, output);
#else #else
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"The cuda archs you specific should greater than 600.")); "The cuda archs you specific should greater than 600."));
......
...@@ -26,17 +26,56 @@ namespace inference { ...@@ -26,17 +26,56 @@ namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
class SlicePlugin : public PluginTensorRT {
public:
explicit SlicePlugin(std::vector<int> starts, std::vector<int> ends,
std::vector<int> axes, bool ban_fp16);
// It was used for tensorrt deserialization.
// It should not be called by users.
SlicePlugin(void const* serial_data, size_t serial_length);
~SlicePlugin();
SlicePlugin* clone() const override;
const char* getPluginType() const override { return "slice_plugin"; }
int getNbOutputs() const override { return 1; }
int initialize() override { return 0; }
bool supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const override;
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
int nb_input_dims) override;
int enqueue(int batch_size, const void* const* inputs, void** outputs,
void* workspace, cudaStream_t stream) override;
protected:
size_t getSerializationSize() override;
// TRT will call this func to serialize the configuration of TRT
// It should not be called by users.
void serialize(void* buffer) override;
private:
std::vector<int> starts_;
std::vector<int> ends_;
std::vector<int> axes_;
bool ban_fp16_{false};
int* offset_temp_data_{nullptr};
cudaEvent_t copy_event_;
cudaStream_t copy_stream_;
};
#if IS_TRT_VERSION_GE(6000) #if IS_TRT_VERSION_GE(6000)
class SlicePluginDynamic : public DynamicPluginTensorRT { class SlicePluginDynamic : public DynamicPluginTensorRT {
public: public:
explicit SlicePluginDynamic(std::vector<int> starts, std::vector<int> ends, explicit SlicePluginDynamic(std::vector<int> starts, std::vector<int> ends,
std::vector<int> axes, bool ban_fp16) std::vector<int> axes, bool ban_fp16);
: starts_(starts), ends_(ends), axes_(axes), ban_fp16_(ban_fp16) {}
SlicePluginDynamic(void const* serialData, size_t serialLength) {}
nvinfer1::IPluginV2DynamicExt* clone() const override { nvinfer1::IPluginV2DynamicExt* clone() const override {
return new SlicePluginDynamic(starts_, ends_, axes_, ban_fp16_); return new SlicePluginDynamic(starts_, ends_, axes_, ban_fp16_);
} }
SlicePluginDynamic(void const* serialData, size_t serialLength);
const char* getPluginType() const override { return "slice_plugin"; } const char* getPluginType() const override { return "slice_plugin"; }
int getNbOutputs() const override { return 1; } int getNbOutputs() const override { return 1; }
int initialize() override; int initialize() override;
...@@ -72,15 +111,54 @@ class SlicePluginDynamic : public DynamicPluginTensorRT { ...@@ -72,15 +111,54 @@ class SlicePluginDynamic : public DynamicPluginTensorRT {
const nvinfer1::DataType* inputTypes, const nvinfer1::DataType* inputTypes,
int nbInputs) const override; int nbInputs) const override;
void destroy() override { delete this; } void destroy() override;
private: private:
std::vector<int> starts_; std::vector<int> starts_;
std::vector<int> ends_; std::vector<int> ends_;
std::vector<int> axes_; std::vector<int> axes_;
bool ban_fp16_{false}; bool ban_fp16_{false};
int* offset_temp_data_{nullptr};
cudaEvent_t copy_event_;
cudaStream_t copy_stream_;
}; };
class SlicePluginV2Creator : public nvinfer1::IPluginCreator {
public:
SlicePluginV2Creator() {}
const char* getPluginName() const override { return "slice_plugin"; }
const char* getPluginVersion() const override { return "1"; }
const nvinfer1::PluginFieldCollection* getFieldNames() override {
return &field_collection_;
}
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override {
return nullptr;
}
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serialData,
size_t serialLength) override {
auto plugin = new SlicePluginDynamic(serialData, serialLength);
return plugin;
}
void setPluginNamespace(const char* libNamespace) override {
namespace_ = libNamespace;
}
const char* getPluginNamespace() const override { return namespace_.c_str(); }
private:
std::string namespace_;
nvinfer1::PluginFieldCollection field_collection_;
};
REGISTER_TRT_PLUGIN_V2(SlicePluginV2Creator);
#endif #endif
} // namespace plugin } // namespace plugin
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import AnalysisConfig
#normal starts && ends
class SlicePluginTRTTest1(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(name="data", shape=[3, 3, 3, 3], dtype="float32")
axes = [1, 3]
starts = [0, 1]
ends = [2, 3]
slice_out = fluid.layers.slice(
data, axes=axes, starts=starts, ends=ends)
out = fluid.layers.batch_norm(slice_out, is_test=True)
self.feeds = {
"data": np.random.random((3, 3, 3, 3)).astype("float32"),
}
# Diff occurred between GPU and TRT.
# In order to provide TRT CI ASAP, this test for trt part
# is disabled temporarily.
self.enable_trt = True
self.trt_parameters = SlicePluginTRTTest1.TensorRTParam(
1 << 30, 32, 1, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [out]
def test_check_output(self):
use_gpu = [False]
if core.is_compiled_with_cuda():
use_gpu.append(True)
for i in range(len(use_gpu)):
self.check_output_with_option(use_gpu[i])
#negative starts && ends
class SlicePluginTRTTest2(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(name="data", shape=[3, 3, 3, 3], dtype="float32")
axes = [2, 3]
starts = [-3, -2]
ends = [-1, 3]
slice_out = fluid.layers.slice(
data, axes=axes, starts=starts, ends=ends)
out = fluid.layers.batch_norm(slice_out, is_test=True)
self.feeds = {
"data": np.random.random((3, 3, 3, 3)).astype("float32"),
}
# Diff occurred between GPU and TRT.
# In order to provide TRT CI ASAP, this test for trt part
# is disabled temporarily.
self.enable_trt = True
self.trt_parameters = SlicePluginTRTTest2.TensorRTParam(
1 << 30, 32, 1, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [out]
def test_check_output(self):
use_gpu = [False]
if core.is_compiled_with_cuda():
use_gpu.append(True)
for i in range(len(use_gpu)):
self.check_output_with_option(use_gpu[i])
#exceeded bound starts && ends
class SlicePluginTRTTest3(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(name="data", shape=[3, 3, 3, 3], dtype="float32")
axes = [2, 3]
starts = [-5, -2]
ends = [-1, 8]
slice_out = fluid.layers.slice(
data, axes=axes, starts=starts, ends=ends)
out = fluid.layers.batch_norm(slice_out, is_test=True)
self.feeds = {
"data": np.random.random((3, 3, 3, 3)).astype("float32"),
}
# Diff occurred between GPU and TRT.
# In order to provide TRT CI ASAP, this test for trt part
# is disabled temporarily.
self.enable_trt = True
self.trt_parameters = SlicePluginTRTTest3.TensorRTParam(
1 << 30, 32, 1, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [out]
def test_check_output(self):
use_gpu = [False]
if core.is_compiled_with_cuda():
use_gpu.append(True)
for i in range(len(use_gpu)):
self.check_output_with_option(use_gpu[i])
#fp16
class SlicePluginTRTTest4(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(name="data", shape=[3, 3, 3, 3], dtype="float32")
axes = [2, 3]
starts = [-5, -2]
ends = [-1, 8]
slice_out = fluid.layers.slice(
data, axes=axes, starts=starts, ends=ends)
out = fluid.layers.batch_norm(slice_out, is_test=True)
self.feeds = {
"data": np.random.random((3, 3, 3, 3)).astype("float32"),
}
# Diff occurred between GPU and TRT.
# In order to provide TRT CI ASAP, this test for trt part
# is disabled temporarily.
self.enable_trt = True
self.trt_parameters = SlicePluginTRTTest3.TensorRTParam(
1 << 30, 32, 1, AnalysisConfig.Precision.Half, False, False)
self.fetch_list = [out]
def test_check_output(self):
use_gpu = [False]
if core.is_compiled_with_cuda():
use_gpu.append(True)
for i in range(len(use_gpu)):
self.check_output_with_option(use_gpu[i])
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册