未验证 提交 76b02b7c 编写于 作者: Z zhoutianzi666 提交者: GitHub

fix compile fail in cuda11.6 (#43559)

上级 1a1d596b
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
#pragma once #pragma once
#include <thrust/device_vector.h>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <thrust/device_vector.h>
#include <algorithm> #include <algorithm>
...@@ -63,9 +64,7 @@ void SplitPlugin::shareData(const SplitPlugin* another) { ...@@ -63,9 +64,7 @@ void SplitPlugin::shareData(const SplitPlugin* another) {
inner_cols_ = another->inner_cols_; inner_cols_ = another->inner_cols_;
same_shape_ = another->same_shape_; same_shape_ = another->same_shape_;
axis_shape_ = another->axis_shape_; axis_shape_ = another->axis_shape_;
d_segment_offsets_ = another->d_segment_offsets_;
segment_offsets_ = another->segment_offsets_; segment_offsets_ = another->segment_offsets_;
d_output_ptrs_.resize(another->d_output_ptrs_.size(), nullptr);
} }
int SplitPlugin::initialize() TRT_NOEXCEPT { int SplitPlugin::initialize() TRT_NOEXCEPT {
...@@ -93,9 +92,7 @@ int SplitPlugin::initialize() TRT_NOEXCEPT { ...@@ -93,9 +92,7 @@ int SplitPlugin::initialize() TRT_NOEXCEPT {
segment_offsets.push_back(segment_offsets.back() + output_length_[i]); segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
} }
axis_shape_ = dims.d[axis_]; axis_shape_ = dims.d[axis_];
d_segment_offsets_ = segment_offsets;
segment_offsets_ = std::move(segment_offsets); segment_offsets_ = std::move(segment_offsets);
d_output_ptrs_.resize(this->getNbOutputs(), nullptr);
return 0; return 0;
} }
...@@ -133,13 +130,18 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs, ...@@ -133,13 +130,18 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
void* const* outputs, void* workspace, void* const* outputs, void* workspace,
cudaStream_t stream) TRT_NOEXCEPT { cudaStream_t stream) TRT_NOEXCEPT {
#endif #endif
// this two thrust variables decalared here , not with in .h
// to avoid compiling error in cuda 11.6
thrust::device_vector<int> d_segment_offsets = segment_offsets_;
thrust::device_vector<float*> d_output_ptrs;
d_output_ptrs.resize(segment_offsets_.size(), nullptr);
const int* d_segment_offsets_ptr = const int* d_segment_offsets_ptr =
thrust::raw_pointer_cast(&d_segment_offsets_[0]); thrust::raw_pointer_cast(&d_segment_offsets[0]);
float const* input_ptr = reinterpret_cast<float const*>(inputs[0]); float const* input_ptr = reinterpret_cast<float const*>(inputs[0]);
float* const* h_odatas = reinterpret_cast<float* const*>(outputs); float* const* h_odatas = reinterpret_cast<float* const*>(outputs);
float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs_[0]); float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs[0]);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync( PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(
output_ptrs, h_odatas, d_output_ptrs_.size() * sizeof(float*), output_ptrs, h_odatas, d_output_ptrs.size() * sizeof(float*),
cudaMemcpyHostToDevice, stream)); cudaMemcpyHostToDevice, stream));
int outer_rows = outer_rows_ * batchSize; int outer_rows = outer_rows_ * batchSize;
...@@ -150,7 +152,7 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs, ...@@ -150,7 +152,7 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
std::min((outer_rows_ - 1) / block.z + 1, 65535u)); std::min((outer_rows_ - 1) / block.z + 1, 65535u));
split_kernel<<<grid, block, 0, stream>>>( split_kernel<<<grid, block, 0, stream>>>(
d_segment_offsets_.size(), d_segment_offsets_ptr, input_ptr, output_ptrs, segment_offsets_.size(), d_segment_offsets_ptr, input_ptr, output_ptrs,
inner_cols_, axis_shape_, outer_rows); inner_cols_, axis_shape_, outer_rows);
return cudaGetLastError() != cudaSuccess; return cudaGetLastError() != cudaSuccess;
} }
......
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
#pragma once #pragma once
#include <thrust/device_vector.h>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -94,8 +92,6 @@ class SplitPlugin : public PluginTensorRTV2Ext { ...@@ -94,8 +92,6 @@ class SplitPlugin : public PluginTensorRTV2Ext {
bool same_shape_; bool same_shape_;
std::vector<int> output_length_; std::vector<int> output_length_;
std::vector<int> segment_offsets_; std::vector<int> segment_offsets_;
thrust::device_vector<int> d_segment_offsets_;
thrust::device_vector<float*> d_output_ptrs_;
private: private:
void shareData(const SplitPlugin* another); void shareData(const SplitPlugin* another);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册