Unverified commit 76b02b7c, authored by zhoutianzi666, committed by GitHub

fix compile fail in cuda11.6 (#43559)

Parent 1a1d596b
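For context on the fix: before this change, SplitPlugin kept two thrust::device_vector objects as class members declared in the header, so every translation unit including that header had to compile Thrust device code, which breaks the build under CUDA 11.6. The commit keeps only host-side std::vector state in the class and constructs the device vectors locally inside enqueue() in the .cu file. A minimal before/after sketch (hypothetical struct names, not the plugin's full definition):

// Sketch only: illustrates the member-to-local move behind this commit.
#include <thrust/device_vector.h>
#include <vector>

// Before: device-side containers live in the header as members, forcing
// Thrust device code into every file that includes the header.
struct PluginBefore {
  std::vector<int> segment_offsets_;
  thrust::device_vector<int> d_segment_offsets_;
  thrust::device_vector<float*> d_output_ptrs_;
};

// After: the header keeps only host-side state; the device_vectors are
// rebuilt per call inside enqueue() in the .cu file.
struct PluginAfter {
  std::vector<int> segment_offsets_;
};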
@@ -13,6 +13,7 @@
 // limitations under the License.
 #include <cuda_fp16.h>
+#include <thrust/device_vector.h>
 #include <algorithm>
@@ -63,9 +64,7 @@ void SplitPlugin::shareData(const SplitPlugin* another) {
   inner_cols_ = another->inner_cols_;
   same_shape_ = another->same_shape_;
   axis_shape_ = another->axis_shape_;
-  d_segment_offsets_ = another->d_segment_offsets_;
   segment_offsets_ = another->segment_offsets_;
-  d_output_ptrs_.resize(another->d_output_ptrs_.size(), nullptr);
 }

 int SplitPlugin::initialize() TRT_NOEXCEPT {
@@ -93,9 +92,7 @@ int SplitPlugin::initialize() TRT_NOEXCEPT {
     segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
   }
   axis_shape_ = dims.d[axis_];
-  d_segment_offsets_ = segment_offsets;
   segment_offsets_ = std::move(segment_offsets);
-  d_output_ptrs_.resize(this->getNbOutputs(), nullptr);
   return 0;
 }
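Aside on the line retained above: segment_offsets_ is a running prefix sum over output_length_, with a leading 0. A tiny standalone illustration (hypothetical example values):

// Example: split sizes {2, 3, 4} yield offsets {0, 2, 5, 9}.
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  std::vector<int> output_length = {2, 3, 4};  // assumed example sizes
  std::vector<int> segment_offsets(1, 0);
  for (int len : output_length)
    segment_offsets.push_back(segment_offsets.back() + len);
  std::vector<int> segment_offsets_ = std::move(segment_offsets);
  for (int off : segment_offsets_) std::printf("%d ", off);  // prints: 0 2 5 9
  return 0;
}

The device copy of this vector is now made per call in enqueue() instead of being cached in a d_segment_offsets_ member.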
@@ -133,13 +130,18 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
                          void* const* outputs, void* workspace,
                          cudaStream_t stream) TRT_NOEXCEPT {
 #endif
+  // These two Thrust variables are declared here rather than in the .h
+  // to avoid a compile error in CUDA 11.6.
+  thrust::device_vector<int> d_segment_offsets = segment_offsets_;
+  thrust::device_vector<float*> d_output_ptrs;
+  d_output_ptrs.resize(segment_offsets_.size(), nullptr);
   const int* d_segment_offsets_ptr =
-      thrust::raw_pointer_cast(&d_segment_offsets_[0]);
+      thrust::raw_pointer_cast(&d_segment_offsets[0]);
   float const* input_ptr = reinterpret_cast<float const*>(inputs[0]);
   float* const* h_odatas = reinterpret_cast<float* const*>(outputs);
-  float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs_[0]);
+  float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs[0]);
   PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(
-      output_ptrs, h_odatas, d_output_ptrs_.size() * sizeof(float*),
+      output_ptrs, h_odatas, d_output_ptrs.size() * sizeof(float*),
       cudaMemcpyHostToDevice, stream));

   int outer_rows = outer_rows_ * batchSize;
@@ -150,7 +152,7 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
           std::min((outer_rows_ - 1) / block.z + 1, 65535u));
   split_kernel<<<grid, block, 0, stream>>>(
-      d_segment_offsets_.size(), d_segment_offsets_ptr, input_ptr, output_ptrs,
+      segment_offsets_.size(), d_segment_offsets_ptr, input_ptr, output_ptrs,
       inner_cols_, axis_shape_, outer_rows);
   return cudaGetLastError() != cudaSuccess;
 }
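To read the new control flow without diff markers: a condensed, self-contained sketch of the enqueue-side pattern (hypothetical function and parameter names; the PADDLE_ENFORCE_GPU_SUCCESS checks and the actual split_kernel launch are omitted). The key point is that both thrust::device_vector objects are now function-local, so no Thrust type remains in the header:

// Simplified sketch of the per-call device staging done in enqueue().
#include <cuda_runtime.h>
#include <thrust/device_vector.h>
#include <vector>

void enqueue_sketch(const std::vector<int>& segment_offsets,  // nb_outputs + 1 entries
                    float* const* host_output_ptrs,           // nb_outputs pointers
                    int nb_outputs, cudaStream_t stream) {
  // Rebuilt on every call: allocates device memory and copies the offsets.
  thrust::device_vector<int> d_segment_offsets = segment_offsets;
  thrust::device_vector<float*> d_output_ptrs(nb_outputs, nullptr);

  const int* d_offsets_ptr = thrust::raw_pointer_cast(d_segment_offsets.data());
  float** d_out_ptrs = thrust::raw_pointer_cast(d_output_ptrs.data());

  // Stage the host-side output pointers into device memory on the stream.
  cudaMemcpyAsync(d_out_ptrs, host_output_ptrs, nb_outputs * sizeof(float*),
                  cudaMemcpyHostToDevice, stream);

  // ... split_kernel<<<grid, block, 0, stream>>>(segment_offsets.size(),
  //                                              d_offsets_ptr, ...) ...
  (void)d_offsets_ptr;            // unused in this sketch
  cudaStreamSynchronize(stream);  // the local device_vectors are freed on return
}

The trade-off is a device allocation plus host-to-device copy of the offsets on every enqueue() call, accepted in exchange for a Thrust-free header that ordinary C++ translation units can include under CUDA 11.6.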
@@ -14,8 +14,6 @@
 #pragma once
-#include <thrust/device_vector.h>
 #include <string>
-#include <utility>
 #include <vector>
@@ -94,8 +92,6 @@ class SplitPlugin : public PluginTensorRTV2Ext {
   bool same_shape_;
   std::vector<int> output_length_;
   std::vector<int> segment_offsets_;
-  thrust::device_vector<int> d_segment_offsets_;
-  thrust::device_vector<float*> d_output_ptrs_;

  private:
   void shareData(const SplitPlugin* another);