From 47fdc60ecc1890e667f6301813818b386486afaa Mon Sep 17 00:00:00 2001 From: Shang Zhizhou Date: Tue, 15 Sep 2020 15:27:32 +0800 Subject: [PATCH] Optimize slice trt plugin (#26970) * optimize slice TRT plugin This patch removes unnecessary barrier for data transfer of needed offset, so data transfer can be overlap with GPU kernel execution. This patch also fixes incorrect name of slice plugin. That is, replaces "layernorm" with "slice" test=develop * add serialize/deserialize to slice plugin * add static shape slice trt plugin * fix slice trt op convertor dynamic shape bug * fix format by clang-format * fix pylint format error * fix problems commented by peiyang Co-authored-by: Ryan Jeng --- .../inference/tensorrt/convert/slice_op.cc | 67 ++++-- paddle/fluid/inference/tensorrt/op_teller.cc | 1 + .../tensorrt/plugin/slice_op_plugin.cu | 221 ++++++++++++++++-- .../tensorrt/plugin/slice_op_plugin.h | 88 ++++++- .../ir/inference/test_trt_slice_plugin.py | 150 ++++++++++++ 5 files changed, 490 insertions(+), 37 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_slice_plugin.py diff --git a/paddle/fluid/inference/tensorrt/convert/slice_op.cc b/paddle/fluid/inference/tensorrt/convert/slice_op.cc index 2a76317eea1..3c3fead3d36 100644 --- a/paddle/fluid/inference/tensorrt/convert/slice_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/slice_op.cc @@ -23,9 +23,8 @@ class SliceOpConverter : public OpConverter { public: void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { -// This OP is implemented by trt dynamic shpae plugin. -// Dynamic shape plugin requires TRT version greater than 6.0. -#if IS_TRT_VERSION_GE(6000) + // This OP is implemented by trt dynamic shpae plugin. + // Dynamic shape plugin requires TRT version greater than 6.0. VLOG(4) << "convert slice op to tensorrt layer"; framework::OpDesc op_desc(op, nullptr); // Declare inputs @@ -38,27 +37,65 @@ class SliceOpConverter : public OpConverter { std::vector ends = BOOST_GET_CONST(std::vector, op_desc.GetAttr("ends")); + PADDLE_ENFORCE_EQ( + starts.size(), axes.size(), + platform::errors::InvalidArgument( + "The size of starts must be equal to the size of axes.")); + PADDLE_ENFORCE_EQ( + ends.size(), axes.size(), + platform::errors::InvalidArgument( + "The size of ends must be equal to the size of axes.")); + + auto input_dims = input->getDimensions(); + if (!engine_->with_dynamic_shape()) { + // notice that input shape is [CHW] without batch axis when input has + // static shape + for (size_t i = input_dims.nbDims; i > 0; i--) { + input_dims.d[i] = input_dims.d[i - 1]; + } + input_dims.d[0] = 1; // fake batchsize, not useful here + for (size_t i = 0; i < axes.size(); i++) { + // split on batch is not supported in TensorRT + PADDLE_ENFORCE_NE(axes[i], 0, platform::errors::InvalidArgument( + "Invalid slice axis. Slice on batch " + "axis is not supported in TensorRT")); + if (starts[i] < 0) { + starts[i] = std::max(starts[i] + input_dims.d[axes[i]], 0); + } + if (ends[i] < 0) { + ends[i] = std::max(ends[i] + input_dims.d[axes[i]], 0); + } + ends[i] = std::min(ends[i], input_dims.d[axes[i]]); + PADDLE_ENFORCE_GT( + ends[i], starts[i], + platform::errors::InvalidArgument( + "Attr(ends) should be greater than attr(starts) in " + "slice op. But received ends = %d, starts = %d.", + ends[i], starts[i])); + } + } + nvinfer1::ILayer* layer = nullptr; if (engine_->with_dynamic_shape()) { +#if IS_TRT_VERSION_GE(6000) bool ban_fp16 = engine_->disable_trt_plugin_fp16(); plugin::SlicePluginDynamic* plugin = - new plugin::SlicePluginDynamic(starts, ends, ends, ban_fp16); + new plugin::SlicePluginDynamic(starts, ends, axes, ban_fp16); layer = engine_->AddPluginV2(&input, 1, plugin); - } else { +#else PADDLE_THROW(platform::errors::Fatal( - "You are running the Ernie(Bert) model in static" - "shape mode, which is not supported for the time being.\n" - "You can use the config.SetTRTDynamicShapeInfo(...) interface" - " to set the shape information to run the dynamic shape mode.")); + "You are running the TRT Dynamic Shape mode, need to confirm that " + "your TRT version is no less than 6.0")); +#endif + } else { + bool ban_fp16 = engine_->disable_trt_plugin_fp16(); + plugin::SlicePlugin* plugin = + new plugin::SlicePlugin(starts, ends, axes, ban_fp16); + layer = engine_->AddPlugin(&input, 1, plugin); } auto output_name = op_desc.Output("Out")[0]; - RreplenishLayerAndOutput(layer, "skip_layernorm", {output_name}, test_mode); -#else - PADDLE_THROW(platform::errors::Fatal( - "You are running the TRT Dynamic Shape mode, need to confirm that " - "your TRT version is no less than 6.0")); -#endif + RreplenishLayerAndOutput(layer, "slice", {output_name}, test_mode); } }; diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index a5b71356d0e..31128ba8c5d 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -31,6 +31,7 @@ struct SimpleOpTypeSetTeller : public Teller { teller_set.insert("fused_embedding_eltwise_layernorm"); teller_set.insert("multihead_matmul"); teller_set.insert("skip_layernorm"); + teller_set.insert("slice"); #endif } diff --git a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu index 4fb1d824108..5c56270627a 100644 --- a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu @@ -26,8 +26,10 @@ namespace inference { namespace tensorrt { namespace plugin { -// Dynamic Plugin below. -#if IS_TRT_VERSION_GE(6000) +SlicePlugin *CreateSlicePluginDeserialize(const void *buffer, size_t length) { + return new SlicePlugin(buffer, length); +} +REGISTER_TRT_PLUGIN("slice_plugin", CreateSlicePluginDeserialize); template __global__ void SliceKernel(int num, int dims, const T *input, @@ -56,11 +58,196 @@ __global__ void SliceKernel(int num, int dims, const T *input, } } +SlicePlugin::SlicePlugin(std::vector starts, std::vector ends, + std::vector axes, bool ban_fp16) + : starts_(starts), ends_(ends), axes_(axes), ban_fp16_(ban_fp16) { + cudaEventCreate(©_event_); + cudaStreamCreate(©_stream_); +} + +SlicePlugin::SlicePlugin(void const *serial_data, size_t serial_length) { + deserializeBase(serial_data, serial_length); + DeserializeValue(&serial_data, &serial_length, &starts_); + DeserializeValue(&serial_data, &serial_length, &ends_); + DeserializeValue(&serial_data, &serial_length, &axes_); + DeserializeValue(&serial_data, &serial_length, &ban_fp16_); + cudaEventCreate(©_event_); + cudaStreamCreate(©_stream_); +} + +SlicePlugin::~SlicePlugin() { + cudaStreamDestroy(copy_stream_); + cudaEventDestroy(copy_event_); + cudaFree(offset_temp_data_); +} + +SlicePlugin *SlicePlugin::clone() const { + return new SlicePlugin(starts_, ends_, axes_, ban_fp16_); +} + +bool SlicePlugin::supportsFormat(nvinfer1::DataType type, + nvinfer1::PluginFormat format) const { +#ifdef SUPPORTS_CUDA_FP16 + return ((type == nvinfer1::DataType::kFLOAT || + type == nvinfer1::DataType::kHALF) && + (format == nvinfer1::PluginFormat::kNCHW)); +#else + return ((type == nvinfer1::DataType::kFLOAT) && + (format == nvinfer1::PluginFormat::kNCHW)); +#endif +} + +nvinfer1::Dims SlicePlugin::getOutputDimensions(int index, + const nvinfer1::Dims *inputs, + int nb_input_dims) { + auto in_dims = inputs[0]; + nvinfer1::Dims out_dims = in_dims; + for (size_t i = 0; i < axes_.size(); i++) { + int start = starts_[i]; + int end = ends_[i]; + out_dims.d[axes_[i] - 1] = end - start; + } + return out_dims; +} + +int SlicePlugin::enqueue(int batch_size, const void *const *inputs, + void **outputs, void *workspace, cudaStream_t stream) { + auto input_dims = getInputDims(0); + + // notice input dims is [C, H, W], add input batch dim here + auto out_dims = getOutputDimensions(0, &input_dims, 1); + input_dims.nbDims += 1; + out_dims.nbDims += 1; + for (auto i = input_dims.nbDims; i > 0; --i) { + input_dims.d[i] = input_dims.d[i - 1]; + out_dims.d[i] = out_dims.d[i - 1]; + } + input_dims.d[0] = batch_size; + out_dims.d[0] = batch_size; + + auto num_dims = input_dims.nbDims; + size_t out_num = ProductDim(out_dims); + + std::vector seg_offsets; + std::vector offsets; + std::vector extends; + + offsets.resize(num_dims); + extends.resize(num_dims); + seg_offsets.resize(num_dims); + + seg_offsets[num_dims - 1] = 1; + for (int i = num_dims - 2; i >= 0; i--) { + seg_offsets[i] = input_dims.d[i + 1] * seg_offsets[i + 1]; + } + for (size_t i = 0; i < num_dims; ++i) { + offsets[i] = 0; + extends[i] = out_dims.d[i]; + } + for (size_t i = 0; i < axes_.size(); ++i) { + offsets[axes_[i]] = starts_[i]; + } + + std::vector offset_info; + for (size_t i = 0; i < num_dims; ++i) { + offset_info.push_back(offsets[i]); + offset_info.push_back(extends[i]); + offset_info.push_back(seg_offsets[i]); + } + + if (offset_temp_data_ == nullptr) { + cudaMalloc(&offset_temp_data_, 3 * num_dims * sizeof(int)); + } + + cudaMemcpyAsync(offset_temp_data_, offset_info.data(), + sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice, + copy_stream_); + + cudaEventRecord(copy_event_, copy_stream_); + cudaStreamWaitEvent(stream, copy_event_, 0); + + int threads = 256; + int blocks = (out_num + threads - 1) / threads; + auto input_type = getDataType(); + if (input_type == nvinfer1::DataType::kFLOAT) { + const float *input1 = static_cast(inputs[0]); + float *output = static_cast(outputs[0]); + SliceKernel<<>>( + out_num, num_dims, input1, offset_temp_data_, output); + } else if (input_type == nvinfer1::DataType::kHALF) { +#ifdef SUPPORTS_CUDA_FP16 + const half *input1 = static_cast(inputs[0]); + half *output = static_cast(outputs[0]); + SliceKernel<<>>( + out_num, num_dims, input1, offset_temp_data_, output); +#else + PADDLE_THROW(platform::errors::Fatal( + "The cuda archs you specific should greater than 600.")); +#endif + } else { + PADDLE_THROW(platform::errors::Fatal( + "The Slice TRT Plugin's input type should be float or half.")); + } + return cudaGetLastError() != cudaSuccess; +} + +size_t SlicePlugin::getSerializationSize() { + return getBaseSerializationSize() + SerializedSize(getPluginType()) + + SerializedSize(starts_) + SerializedSize(ends_) + + SerializedSize(axes_) + SerializedSize(ban_fp16_); +} + +void SlicePlugin::serialize(void *buffer) { + SerializeValue(&buffer, getPluginType()); + serializeBase(buffer); + SerializeValue(&buffer, starts_); + SerializeValue(&buffer, ends_); + SerializeValue(&buffer, axes_); + SerializeValue(&buffer, ban_fp16_); +} + +// Dynamic Plugin below. +#if IS_TRT_VERSION_GE(6000) +SlicePluginDynamic::SlicePluginDynamic(std::vector starts, + std::vector ends, + std::vector axes, bool ban_fp16) + : starts_(starts), ends_(ends), axes_(axes), ban_fp16_(ban_fp16) { + cudaEventCreate(©_event_); + cudaStreamCreate(©_stream_); +} + +SlicePluginDynamic::SlicePluginDynamic(void const *serialData, + size_t serialLength) { + DeserializeValue(&serialData, &serialLength, &starts_); + DeserializeValue(&serialData, &serialLength, &ends_); + DeserializeValue(&serialData, &serialLength, &axes_); + DeserializeValue(&serialData, &serialLength, &ban_fp16_); + cudaEventCreate(©_event_); + cudaStreamCreate(©_stream_); +} + +void SlicePluginDynamic::destroy() { + cudaStreamDestroy(copy_stream_); + cudaEventDestroy(copy_event_); + cudaFree(offset_temp_data_); + delete this; +} + int SlicePluginDynamic::initialize() { return 0; } -size_t SlicePluginDynamic::getSerializationSize() const { return 0; } +size_t SlicePluginDynamic::getSerializationSize() const { + size_t size = SerializedSize(starts_) + SerializedSize(ends_) + + SerializedSize(axes_) + SerializedSize(ban_fp16_); -void SlicePluginDynamic::serialize(void *buffer) const {} + return size; +} + +void SlicePluginDynamic::serialize(void *buffer) const { + SerializeValue(&buffer, starts_); + SerializeValue(&buffer, ends_); + SerializeValue(&buffer, axes_); + SerializeValue(&buffer, ban_fp16_); +} nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions( int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs, @@ -136,9 +323,9 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, std::vector offsets; std::vector extends; - offsets.reserve(num_dims); - extends.reserve(num_dims); - seg_offsets.reserve(num_dims); + offsets.resize(num_dims); + extends.resize(num_dims); + seg_offsets.resize(num_dims); seg_offsets[num_dims - 1] = 1; for (int i = num_dims - 2; i >= 0; i--) { @@ -160,16 +347,16 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, offset_info.push_back(seg_offsets[i]); } - framework::Tensor offset_temp_tensor; + if (offset_temp_data_ == nullptr) { + cudaMalloc(&offset_temp_data_, 3 * num_dims * sizeof(int)); + } - int device_id; - cudaGetDevice(&device_id); - offset_temp_tensor.Resize({3 * num_dims}); - auto *offset_temp_data = - offset_temp_tensor.mutable_data(platform::CUDAPlace(device_id)); + cudaMemcpyAsync(offset_temp_data_, offset_info.data(), + sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice, + copy_stream_); - cudaMemcpyAsync(offset_temp_data, offset_info.data(), - sizeof(int) * 3 * num_dims, cudaMemcpyHostToDevice, stream); + cudaEventRecord(copy_event_, copy_stream_); + cudaStreamWaitEvent(stream, copy_event_, 0); int threads = 256; int blocks = (out_num + threads - 1) / threads; @@ -178,13 +365,13 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, const float *input1 = static_cast(inputs[0]); float *output = static_cast(outputs[0]); SliceKernel<<>>( - out_num, num_dims, input1, offset_temp_data, output); + out_num, num_dims, input1, offset_temp_data_, output); } else if (input_type == nvinfer1::DataType::kHALF) { #ifdef SUPPORTS_CUDA_FP16 const half *input1 = static_cast(inputs[0]); half *output = static_cast(outputs[0]); SliceKernel<<>>( - out_num, num_dims, input1, offset_temp_data, output); + out_num, num_dims, input1, offset_temp_data_, output); #else PADDLE_THROW(platform::errors::Fatal( "The cuda archs you specific should greater than 600.")); diff --git a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h index 13d86df131f..e36a270f05d 100644 --- a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h @@ -26,17 +26,56 @@ namespace inference { namespace tensorrt { namespace plugin { +class SlicePlugin : public PluginTensorRT { + public: + explicit SlicePlugin(std::vector starts, std::vector ends, + std::vector axes, bool ban_fp16); + + // It was used for tensorrt deserialization. + // It should not be called by users. + SlicePlugin(void const* serial_data, size_t serial_length); + ~SlicePlugin(); + SlicePlugin* clone() const override; + + const char* getPluginType() const override { return "slice_plugin"; } + int getNbOutputs() const override { return 1; } + int initialize() override { return 0; } + bool supportsFormat(nvinfer1::DataType type, + nvinfer1::PluginFormat format) const override; + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, + int nb_input_dims) override; + int enqueue(int batch_size, const void* const* inputs, void** outputs, + void* workspace, cudaStream_t stream) override; + + protected: + size_t getSerializationSize() override; + + // TRT will call this func to serialize the configuration of TRT + // It should not be called by users. + void serialize(void* buffer) override; + + private: + std::vector starts_; + std::vector ends_; + std::vector axes_; + bool ban_fp16_{false}; + int* offset_temp_data_{nullptr}; + cudaEvent_t copy_event_; + cudaStream_t copy_stream_; +}; + #if IS_TRT_VERSION_GE(6000) class SlicePluginDynamic : public DynamicPluginTensorRT { public: explicit SlicePluginDynamic(std::vector starts, std::vector ends, - std::vector axes, bool ban_fp16) - : starts_(starts), ends_(ends), axes_(axes), ban_fp16_(ban_fp16) {} - SlicePluginDynamic(void const* serialData, size_t serialLength) {} + std::vector axes, bool ban_fp16); + nvinfer1::IPluginV2DynamicExt* clone() const override { return new SlicePluginDynamic(starts_, ends_, axes_, ban_fp16_); } + SlicePluginDynamic(void const* serialData, size_t serialLength); + const char* getPluginType() const override { return "slice_plugin"; } int getNbOutputs() const override { return 1; } int initialize() override; @@ -72,15 +111,54 @@ class SlicePluginDynamic : public DynamicPluginTensorRT { const nvinfer1::DataType* inputTypes, int nbInputs) const override; - void destroy() override { delete this; } + void destroy() override; private: std::vector starts_; std::vector ends_; std::vector axes_; - bool ban_fp16_{false}; + int* offset_temp_data_{nullptr}; + cudaEvent_t copy_event_; + cudaStream_t copy_stream_; }; + +class SlicePluginV2Creator : public nvinfer1::IPluginCreator { + public: + SlicePluginV2Creator() {} + const char* getPluginName() const override { return "slice_plugin"; } + + const char* getPluginVersion() const override { return "1"; } + + const nvinfer1::PluginFieldCollection* getFieldNames() override { + return &field_collection_; + } + + nvinfer1::IPluginV2* createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* fc) override { + return nullptr; + } + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) override { + auto plugin = new SlicePluginDynamic(serialData, serialLength); + return plugin; + } + + void setPluginNamespace(const char* libNamespace) override { + namespace_ = libNamespace; + } + + const char* getPluginNamespace() const override { return namespace_.c_str(); } + + private: + std::string namespace_; + nvinfer1::PluginFieldCollection field_collection_; +}; + +REGISTER_TRT_PLUGIN_V2(SlicePluginV2Creator); + #endif } // namespace plugin diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_slice_plugin.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_slice_plugin.py new file mode 100644 index 00000000000..660a9c93e66 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_slice_plugin.py @@ -0,0 +1,150 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import AnalysisConfig + + +#normal starts && ends +class SlicePluginTRTTest1(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data(name="data", shape=[3, 3, 3, 3], dtype="float32") + axes = [1, 3] + starts = [0, 1] + ends = [2, 3] + slice_out = fluid.layers.slice( + data, axes=axes, starts=starts, ends=ends) + out = fluid.layers.batch_norm(slice_out, is_test=True) + + self.feeds = { + "data": np.random.random((3, 3, 3, 3)).astype("float32"), + } + # Diff occurred between GPU and TRT. + # In order to provide TRT CI ASAP, this test for trt part + # is disabled temporarily. + self.enable_trt = True + self.trt_parameters = SlicePluginTRTTest1.TensorRTParam( + 1 << 30, 32, 1, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out] + + def test_check_output(self): + use_gpu = [False] + if core.is_compiled_with_cuda(): + use_gpu.append(True) + for i in range(len(use_gpu)): + self.check_output_with_option(use_gpu[i]) + + +#negative starts && ends +class SlicePluginTRTTest2(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data(name="data", shape=[3, 3, 3, 3], dtype="float32") + axes = [2, 3] + starts = [-3, -2] + ends = [-1, 3] + slice_out = fluid.layers.slice( + data, axes=axes, starts=starts, ends=ends) + out = fluid.layers.batch_norm(slice_out, is_test=True) + + self.feeds = { + "data": np.random.random((3, 3, 3, 3)).astype("float32"), + } + # Diff occurred between GPU and TRT. + # In order to provide TRT CI ASAP, this test for trt part + # is disabled temporarily. + self.enable_trt = True + self.trt_parameters = SlicePluginTRTTest2.TensorRTParam( + 1 << 30, 32, 1, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out] + + def test_check_output(self): + use_gpu = [False] + if core.is_compiled_with_cuda(): + use_gpu.append(True) + for i in range(len(use_gpu)): + self.check_output_with_option(use_gpu[i]) + + +#exceeded bound starts && ends +class SlicePluginTRTTest3(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data(name="data", shape=[3, 3, 3, 3], dtype="float32") + axes = [2, 3] + starts = [-5, -2] + ends = [-1, 8] + slice_out = fluid.layers.slice( + data, axes=axes, starts=starts, ends=ends) + out = fluid.layers.batch_norm(slice_out, is_test=True) + + self.feeds = { + "data": np.random.random((3, 3, 3, 3)).astype("float32"), + } + # Diff occurred between GPU and TRT. + # In order to provide TRT CI ASAP, this test for trt part + # is disabled temporarily. + self.enable_trt = True + self.trt_parameters = SlicePluginTRTTest3.TensorRTParam( + 1 << 30, 32, 1, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out] + + def test_check_output(self): + use_gpu = [False] + if core.is_compiled_with_cuda(): + use_gpu.append(True) + for i in range(len(use_gpu)): + self.check_output_with_option(use_gpu[i]) + + +#fp16 +class SlicePluginTRTTest4(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data(name="data", shape=[3, 3, 3, 3], dtype="float32") + axes = [2, 3] + starts = [-5, -2] + ends = [-1, 8] + slice_out = fluid.layers.slice( + data, axes=axes, starts=starts, ends=ends) + out = fluid.layers.batch_norm(slice_out, is_test=True) + + self.feeds = { + "data": np.random.random((3, 3, 3, 3)).astype("float32"), + } + # Diff occurred between GPU and TRT. + # In order to provide TRT CI ASAP, this test for trt part + # is disabled temporarily. + self.enable_trt = True + self.trt_parameters = SlicePluginTRTTest3.TensorRTParam( + 1 << 30, 32, 1, AnalysisConfig.Precision.Half, False, False) + self.fetch_list = [out] + + def test_check_output(self): + use_gpu = [False] + if core.is_compiled_with_cuda(): + use_gpu.append(True) + for i in range(len(use_gpu)): + self.check_output_with_option(use_gpu[i]) + + +if __name__ == "__main__": + unittest.main() -- GitLab