Unverified · Commit 1255e7d6 · authored by Wangzheee, committed by GitHub

[Paddle-Inference] fix special_slice plugin (#39875)

* fix plugin: special slice for ernie
Parent: ce207c3a
@@ -113,12 +113,12 @@ nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType(
 template <typename T>
 __global__ void SpecialSliceKernel(const T* slice_input,
                                    const int32_t* cu_seqlens, T* output) {
-  const int hidden = blockDim.x * gridDim.y;
-  const int batch = blockIdx.x;
-  const int local_idx = blockIdx.y * blockDim.y + threadIdx.x;
+  const int hidden = blockDim.x * gridDim.x;
+  const int hidden_id = blockIdx.x * blockDim.x + threadIdx.x;
+  const int batch_id = blockIdx.y;
 
-  output[batch * hidden + local_idx] =
-      slice_input[cu_seqlens[batch] * hidden + local_idx];
+  output[batch_id * hidden + hidden_id] =
+      slice_input[cu_seqlens[batch_id] * hidden + hidden_id];
 }
 
 int SpecialSlicePluginDynamic::enqueue(
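The hunk above rewrites the kernel's index math. Previously the hidden size was derived from gridDim.y and the element offset used blockDim.y, even though the thread block is one-dimensional (blockDim.y == 1), so the offset collapsed to blockIdx.y + threadIdx.x and most elements were read from the wrong positions. The fix maps the hidden dimension onto (blockIdx.x, threadIdx.x) and the batch onto blockIdx.y; the matching launch change is in the next hunk. Below is a minimal standalone sketch (not part of the patch) of the corrected mapping, assuming the packed variable-length layout used by the ERNIE path, where cu_seqlens holds per-sequence token offsets and the slice copies each sequence's first token. The harness, the sizes, and the name SliceFirstTokenKernel are hypothetical.

#include <cstdio>
#include <vector>

#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Same index math as the fixed SpecialSliceKernel: gridDim.x * blockDim.x spans
// the hidden dimension, gridDim.y spans the batch.
__global__ void SliceFirstTokenKernel(const half* slice_input,
                                      const int32_t* cu_seqlens, half* output) {
  const int hidden = blockDim.x * gridDim.x;                    // hidden size
  const int hidden_id = blockIdx.x * blockDim.x + threadIdx.x;  // column in the hidden dim
  const int batch_id = blockIdx.y;                              // one grid row per sequence
  // cu_seqlens[batch_id] is the token offset of sequence batch_id in the packed
  // [total_tokens, hidden] input; copy its first token into output row batch_id.
  output[batch_id * hidden + hidden_id] =
      slice_input[cu_seqlens[batch_id] * hidden + hidden_id];
}

int main() {
  const int batch = 3, hidden = 256;
  const std::vector<int32_t> cu_seqlens = {0, 5, 12, 20};  // 3 sequences, 20 tokens total
  const int total_tokens = cu_seqlens.back();

  // Fill every element of token t with the value t so the result is easy to check.
  std::vector<half> h_in(total_tokens * hidden);
  for (int t = 0; t < total_tokens; ++t)
    for (int h = 0; h < hidden; ++h) h_in[t * hidden + h] = __float2half(float(t));

  half *d_in = nullptr, *d_out = nullptr;
  int32_t* d_cu = nullptr;
  cudaMalloc(&d_in, h_in.size() * sizeof(half));
  cudaMalloc(&d_out, batch * hidden * sizeof(half));
  cudaMalloc(&d_cu, cu_seqlens.size() * sizeof(int32_t));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(half), cudaMemcpyHostToDevice);
  cudaMemcpy(d_cu, cu_seqlens.data(), cu_seqlens.size() * sizeof(int32_t),
             cudaMemcpyHostToDevice);

  constexpr int num_threads = 128;
  const dim3 num_blocks(hidden / num_threads, batch);  // x: hidden chunks, y: batch
  SliceFirstTokenKernel<<<num_blocks, num_threads>>>(d_in, d_cu, d_out);

  std::vector<half> h_out(batch * hidden);
  cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(half), cudaMemcpyDeviceToHost);
  for (int b = 0; b < batch; ++b)  // expect 0, 5, 12: the first token of each sequence
    std::printf("batch %d -> %.0f\n", b, __half2float(h_out[b * hidden]));

  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_cu);
  return 0;
}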
@@ -137,15 +137,16 @@ int SpecialSlicePluginDynamic::enqueue(
                         "hidden should be multiple of 128."));
   constexpr int num_threads = 128;
-  const dim3 blocks(out_dims.d[0], hidden / num_threads);
   const half* slice_input = static_cast<const half*>(inputs[0]);
   const int32_t* cu_seqlens = static_cast<const int32_t*>(inputs[1]);
   half* output = static_cast<half*>(outputs[0]);
 
-  SpecialSliceKernel<<<blocks, num_threads, 0, stream>>>(slice_input,
-                                                         cu_seqlens, output);
+  const int32_t num_blocks_x = hidden / num_threads;
+  const int32_t num_blocks_y = out_dims.d[0];  // batchs
+  const dim3 num_blocks(num_blocks_x, num_blocks_y);  // blocks
+  SpecialSliceKernel<<<num_blocks, num_threads, 0, stream>>>(
+      slice_input, cu_seqlens, output);
 
   return cudaGetLastError() != cudaSuccess;
 }
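This second hunk makes the host-side launch match the kernel's new mapping: the grid's x dimension now covers the hidden size in chunks of num_threads and the y dimension enumerates the batch, instead of the old (batch, hidden / num_threads) order. A small worked example follows; the concrete sizes are assumptions for illustration, not values from the patch.

#include <cstdio>

#include <cuda_runtime.h>

int main() {
  const int hidden = 768;  // the plugin enforces hidden % 128 == 0
  const int batch = 4;     // out_dims.d[0] in the plugin
  constexpr int num_threads = 128;

  const int num_blocks_x = hidden / num_threads;  // 6 blocks span the hidden dim
  const int num_blocks_y = batch;                 // one grid row per sequence
  const dim3 num_blocks(num_blocks_x, num_blocks_y);

  std::printf("grid = (%d, %d), block = %d, threads per sequence = %d\n",
              (int)num_blocks.x, (int)num_blocks.y, num_threads,
              num_blocks_x * num_threads);
  return 0;
}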