未验证 提交 09b18a49 编写于 作者: S Shang Zhizhou 提交者: GitHub

[Paddle-TRT] Implement MHA fp16 order same as training (#32629) (#32785)

* implement MHA order same as training

* fix fp16 compile issue on old architecture
Co-authored-by: Nzlsh80826 <rewang@nvidia.com>
上级 2ec6b6f1
...@@ -225,6 +225,14 @@ nvinfer1::DataType QkvToContextPluginDynamic::getOutputDataType( ...@@ -225,6 +225,14 @@ nvinfer1::DataType QkvToContextPluginDynamic::getOutputDataType(
return input_types[0]; return input_types[0];
} }
template <typename T>
__global__ void apply_scale(T *data, T scale, int n) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int tid = blockIdx.x * blockDim.x + threadIdx.x;
data[tid] = data[tid] * scale;
#endif
}
int QkvToContextPluginDynamic::enqueue( int QkvToContextPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc, const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs, const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
...@@ -291,10 +299,17 @@ int QkvToContextPluginDynamic::enqueue( ...@@ -291,10 +299,17 @@ int QkvToContextPluginDynamic::enqueue(
platform::DeviceContextPool::Instance().Get( platform::DeviceContextPool::Instance().Get(
platform::CUDAPlace(device_id))); platform::CUDAPlace(device_id)));
int n_q = seq_len * head_number_ * head_size_;
constexpr int threads = 128;
int blocks = (n_q + threads - 1) / threads;
apply_scale<<<blocks, threads, 0, stream>>>(tptr, static_cast<half>(scale_),
n_q);
const platform::CUDADeviceContext &dev_ctx = *device_ctx; const platform::CUDADeviceContext &dev_ctx = *device_ctx;
operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func; operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
multihead_compute_func(dev_ctx, batch, seq_len, head_number_, head_size_, multihead_compute_func(dev_ctx, batch, seq_len, head_number_, head_size_,
qkptr, input1_data, tptr, half(scale_), half(0.0)); qkptr, input1_data, tptr, half(1.), half(0.0));
int grid = batch * head_number_ * seq_len; int grid = batch * head_number_ * seq_len;
int block = head_size_; int block = head_size_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册