未验证 提交 1255e7d6 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle-Inference] fix special_slice plugin (#39875)

* fix plugin: special slice for ernie
上级 ce207c3a
...@@ -113,12 +113,12 @@ nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType( ...@@ -113,12 +113,12 @@ nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType(
template <typename T> template <typename T>
__global__ void SpecialSliceKernel(const T* slice_input, __global__ void SpecialSliceKernel(const T* slice_input,
const int32_t* cu_seqlens, T* output) { const int32_t* cu_seqlens, T* output) {
const int hidden = blockDim.x * gridDim.y; const int hidden = blockDim.x * gridDim.x;
const int batch = blockIdx.x; const int hidden_id = blockIdx.x * blockDim.x + threadIdx.x;
const int local_idx = blockIdx.y * blockDim.y + threadIdx.x; const int batch_id = blockIdx.y;
output[batch * hidden + local_idx] = output[batch_id * hidden + hidden_id] =
slice_input[cu_seqlens[batch] * hidden + local_idx]; slice_input[cu_seqlens[batch_id] * hidden + hidden_id];
} }
int SpecialSlicePluginDynamic::enqueue( int SpecialSlicePluginDynamic::enqueue(
...@@ -137,15 +137,16 @@ int SpecialSlicePluginDynamic::enqueue( ...@@ -137,15 +137,16 @@ int SpecialSlicePluginDynamic::enqueue(
"hidden should be multiple of 128.")); "hidden should be multiple of 128."));
constexpr int num_threads = 128; constexpr int num_threads = 128;
const dim3 blocks(out_dims.d[0], hidden / num_threads);
const half* slice_input = static_cast<const half*>(inputs[0]); const half* slice_input = static_cast<const half*>(inputs[0]);
const int32_t* cu_seqlens = static_cast<const int32_t*>(inputs[1]); const int32_t* cu_seqlens = static_cast<const int32_t*>(inputs[1]);
half* output = static_cast<half*>(outputs[0]); half* output = static_cast<half*>(outputs[0]);
SpecialSliceKernel<<<blocks, num_threads, 0, stream>>>(slice_input, const int32_t num_blocks_x = hidden / num_threads;
cu_seqlens, output); const int32_t num_blocks_y = out_dims.d[0]; // batchs
const dim3 num_blocks(num_blocks_x, num_blocks_y); // blocks
SpecialSliceKernel<<<num_blocks, num_threads, 0, stream>>>(
slice_input, cu_seqlens, output);
return cudaGetLastError() != cudaSuccess; return cudaGetLastError() != cudaSuccess;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册