From e1604f9ee68b9b5d169553049f9b0104d643561b Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com> Date: Tue, 21 Jun 2022 11:41:34 +0800 Subject: [PATCH] fix compile fail in cuda11.6 (#43588) --- .../tensorrt/plugin/gather_nd_op_plugin.h | 1 - .../tensorrt/plugin/split_op_plugin.cu | 19 +++++++++++-------- .../tensorrt/plugin/split_op_plugin.h | 3 --- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/gather_nd_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/gather_nd_op_plugin.h index 841fb2f6fe3..a293cf69be3 100644 --- a/paddle/fluid/inference/tensorrt/plugin/gather_nd_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/gather_nd_op_plugin.h @@ -14,7 +14,6 @@ #pragma once -#include #include #include #include diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu index ec4fcca6d74..349716d1e2c 100644 --- a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu @@ -13,7 +13,9 @@ // limitations under the License. #include +#include #include + #include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h" namespace paddle { @@ -61,9 +63,7 @@ void SplitPlugin::shareData(const SplitPlugin* another) { inner_cols_ = another->inner_cols_; same_shape_ = another->same_shape_; axis_shape_ = another->axis_shape_; - d_segment_offsets_ = another->d_segment_offsets_; segment_offsets_ = another->segment_offsets_; - d_output_ptrs_.resize(another->d_output_ptrs_.size(), nullptr); } int SplitPlugin::initialize() TRT_NOEXCEPT { @@ -91,9 +91,7 @@ int SplitPlugin::initialize() TRT_NOEXCEPT { segment_offsets.push_back(segment_offsets.back() + output_length_[i]); } axis_shape_ = dims.d[axis_]; - d_segment_offsets_ = segment_offsets; segment_offsets_ = std::move(segment_offsets); - d_output_ptrs_.resize(this->getNbOutputs(), nullptr); return 0; } @@ -131,13 +129,18 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT { #endif + // this two thrust variables decalared here , not with in .h + // to avoid compiling error in cuda 11.6 + thrust::device_vector d_segment_offsets = segment_offsets_; + thrust::device_vector d_output_ptrs; + d_output_ptrs.resize(segment_offsets_.size(), nullptr); const int* d_segment_offsets_ptr = - thrust::raw_pointer_cast(&d_segment_offsets_[0]); + thrust::raw_pointer_cast(&d_segment_offsets[0]); float const* input_ptr = reinterpret_cast(inputs[0]); float* const* h_odatas = reinterpret_cast(outputs); - float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs_[0]); + float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs[0]); PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync( - output_ptrs, h_odatas, d_output_ptrs_.size() * sizeof(float*), + output_ptrs, h_odatas, d_output_ptrs.size() * sizeof(float*), cudaMemcpyHostToDevice, stream)); int outer_rows = outer_rows_ * batchSize; @@ -148,7 +151,7 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs, std::min((outer_rows_ - 1) / block.z + 1, 65535u)); split_kernel<<>>( - d_segment_offsets_.size(), d_segment_offsets_ptr, input_ptr, output_ptrs, + segment_offsets_.size(), d_segment_offsets_ptr, input_ptr, output_ptrs, inner_cols_, axis_shape_, outer_rows); return cudaGetLastError() != cudaSuccess; } diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h index 7a41fe1d1ee..502c903dbff 100644 --- a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h @@ -14,7 +14,6 @@ #pragma once -#include #include #include #include @@ -92,8 +91,6 @@ class SplitPlugin : public PluginTensorRTV2Ext { bool same_shape_; std::vector output_length_; std::vector segment_offsets_; - thrust::device_vector d_segment_offsets_; - thrust::device_vector d_output_ptrs_; private: void shareData(const SplitPlugin* another); -- GitLab