From 75282e7466f948673faa7adf9a2da513e82c7d52 Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Thu, 29 Apr 2021 09:04:32 +0800 Subject: [PATCH] [Paddle-TRT] Implement MHA fp16 order same as training (#32629) * implement MHA order same as training * fix fp16 compile issue on old architecture * fix format * fix format --- .../tensorrt/plugin/qkv_to_context_plugin.cu | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu index a5fc9e73c5f..214e1a81e7d 100644 --- a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu @@ -225,6 +225,14 @@ nvinfer1::DataType QkvToContextPluginDynamic::getOutputDataType( return input_types[0]; } +template +__global__ void apply_scale(T *data, T scale, int n) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + int tid = blockIdx.x * blockDim.x + threadIdx.x; + data[tid] = data[tid] * scale; +#endif +} + int QkvToContextPluginDynamic::enqueue( const nvinfer1::PluginTensorDesc *input_desc, const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs, @@ -291,10 +299,17 @@ int QkvToContextPluginDynamic::enqueue( platform::DeviceContextPool::Instance().Get( platform::CUDAPlace(device_id))); + int n_q = seq_len * head_number_ * head_size_; + constexpr int threads = 128; + int blocks = (n_q + threads - 1) / threads; + + apply_scale<<>>(tptr, static_cast(scale_), + n_q); + const platform::CUDADeviceContext &dev_ctx = *device_ctx; operators::math::MultiHeadGPUComputeFunctor multihead_compute_func; multihead_compute_func(dev_ctx, batch, seq_len, head_number_, head_size_, - qkptr, input1_data, tptr, half(scale_), half(0.0)); + qkptr, input1_data, tptr, half(1.), half(0.0)); int grid = batch * head_number_ * seq_len; int block = head_size_; -- GitLab