// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <cassert>
#include <cub/cub.cuh>  // NOLINT

#include "glog/logging.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)

template <typename T>
__global__ void transpose(T *src, T *dst, const int batch_size,
                          const int seq_len, const int head_num,
                          const int size_per_head) {
  int batch_id = blockIdx.x / (head_num * seq_len);
  int seq_id = blockIdx.x % seq_len;
  int head_id = (blockIdx.x % (head_num * seq_len)) / seq_len;
  dst[batch_id * (head_num * seq_len * size_per_head) +
      seq_id * head_num * size_per_head + head_id * size_per_head +
      threadIdx.x] = src[blockIdx.x * size_per_head + threadIdx.x];
}

template <typename T>
__global__ void TransposeQkvKernel(const int H, const T *input, T *output) {
  // Input:  BxSx3xNxH
  // Output: 3xBxNxSxH
  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;

  const int N = blockDim.y;
  const int S = gridDim.x;
  const int B = gridDim.y;

  const int NH = N * H;
  const int NHS = NH * S;
  const int in_offset = n * H + m * NH + s * 3 * NH + b * NHS * 3;
  const int out_offset = s * H + n * S * H + b * NHS + m * NHS * B;

  const int i = threadIdx.x;
  output[out_offset + i] = input[in_offset + i];
}

inline void TransposeQKV(const int batch, const int seq_len,
                         const int head_size, const int head_num,
                         const float *input, float *output,
                         cudaStream_t stream) {
  int scratch_size = batch * head_num * seq_len * seq_len;
  const dim3 grid(seq_len, batch, 3);
  if (head_size % 4 == 0 && scratch_size % 4 == 0) {
    const int h = head_size / 4;
    const float4 *input4 = reinterpret_cast<const float4 *>(input);
    float4 *output4 = reinterpret_cast<float4 *>(output);
    const dim3 block(h, head_num, 1);
    // limit h * head_num to max block size(1024).
    PADDLE_ENFORCE_LE(h * head_num, 1024,
                      platform::errors::InvalidArgument(
                          "head_num (%d) * head_size (%d) should <= %d",
                          head_num, head_size, 1024 * 4));
    TransposeQkvKernel<float4><<<grid, block, 0, stream>>>(h, input4, output4);
  } else if (head_size % 2 == 0 && scratch_size % 2 == 0) {
    const int h = head_size / 2;
    const float2 *input2 = reinterpret_cast<const float2 *>(input);
    float2 *output2 = reinterpret_cast<float2 *>(output);
    const dim3 block(h, head_num, 1);
    // limit h * head_num to max block size(1024).
    PADDLE_ENFORCE_LE(h * head_num, 1024,
                      platform::errors::InvalidArgument(
                          "head_num (%d) * head_size (%d) should <= %d",
                          head_num, head_size, 1024 * 2));
    TransposeQkvKernel<float2><<<grid, block, 0, stream>>>(h, input2, output2);
  } else {
    const dim3 block(head_size, head_num, 1);
    // limit head_size * head_num to max block size(1024).
    PADDLE_ENFORCE_LE(head_size * head_num, 1024,
                      platform::errors::InvalidArgument(
                          "head_num (%d) * head_size (%d) should <= %d",
                          head_num, head_size, 1024));
    TransposeQkvKernel<float><<<grid, block, 0, stream>>>(head_size, input,
                                                          output);
  }
}
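
// Both TransposeQKV overloads (float above, half below) rearrange the fused
// QKV tensor from BxSx3xNxH to 3xBxNxSxH. When head_size (and the scratch
// size) is suitably aligned, the buffers are reinterpreted as wider vector
// types so that each thread copies one vector: float4/int4 move 16 bytes per
// thread, float2/half2 move 8/4 bytes. The launch uses grid = (seq_len, batch,
// 3) and block = (head_size / vec_width, head_num, 1); the PADDLE_ENFORCE_LE
// checks keep each block under the CUDA limit of 1024 threads. For example,
// with head_num = 12 and head_size = 64 the float path takes the float4
// branch (h = 16) and launches 16 * 12 = 192 threads per block, assuming the
// scratch-size alignment check also passes.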
inline void TransposeQKV(const int batch, const int seq_len,
                         const int head_size, const int head_num,
                         const half *input, half *output,
                         cudaStream_t stream) {
  int scratch_size = batch * head_num * seq_len * seq_len;
  const dim3 grid(seq_len, batch, 3);
  if (head_size % 8 == 0 && scratch_size % 8 == 0) {
    int h = head_size / 8;
    const int4 *input4 = reinterpret_cast<const int4 *>(input);
    int4 *output4 = reinterpret_cast<int4 *>(output);
    dim3 block(h, head_num, 1);
    // limit h * head_num to max block size(1024).
    PADDLE_ENFORCE_LE(h * head_num, 1024,
                      platform::errors::InvalidArgument(
                          "head_num (%d) * head_size (%d) should <= %d",
                          head_num, head_size, 1024 * 8));
    TransposeQkvKernel<int4><<<grid, block, 0, stream>>>(h, input4, output4);
  } else if (head_size % 2 == 0 && scratch_size % 2 == 0) {
    const int h = head_size / 2;
    const half2 *input2 = reinterpret_cast<const half2 *>(input);
    half2 *output2 = reinterpret_cast<half2 *>(output);
    const dim3 block(h, head_num, 1);
    // limit h * head_num to max block size(1024).
    PADDLE_ENFORCE_LE(h * head_num, 1024,
                      platform::errors::InvalidArgument(
                          "head_num (%d) * head_size (%d) should <= %d",
                          head_num, head_size, 1024 * 2));
    TransposeQkvKernel<half2><<<grid, block, 0, stream>>>(h, input2, output2);
  } else {
    const dim3 block(head_size, head_num, 1);
    // limit head_size * head_num to max block size(1024).
    PADDLE_ENFORCE_LE(head_size * head_num, 1024,
                      platform::errors::InvalidArgument(
                          "head_num (%d) * head_size (%d) should <= %d",
                          head_num, head_size, 1024));
    TransposeQkvKernel<half><<<grid, block, 0, stream>>>(head_size, input,
                                                         output);
  }
}

int QkvToContextPluginDynamic::initialize() { return 0; }

nvinfer1::DimsExprs QkvToContextPluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
    nvinfer1::IExprBuilder &expr_builder) {
  // input[0], (B, S, 3 * N * H, 1, 1)
  // input[1], (B, head_num, seq_len, seq_len)
  // output, (B, seq_len, hidden)
  PADDLE_ENFORCE_EQ(output_index, 0,
                    platform::errors::InvalidArgument(
                        "There is only one output of the QkvToContext plugin, "
                        "so the index should be zero, but it's (%d)",
                        output_index));
  PADDLE_ENFORCE_EQ(
      nb_inputs, 2,
      platform::errors::InvalidArgument(
          "The input number of the QkvToContext plugin should be 2, but we "
          "found it has (%d) inputs",
          nb_inputs));
  nvinfer1::DimsExprs ret;
  ret.nbDims = 3;
  ret.d[0] = inputs[0].d[0];
  ret.d[1] = inputs[0].d[1];
  ret.d[2] = expr_builder.constant(head_size_ * head_number_);
  return ret;
}
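
// supportsFormatCombination() below only advertises linear (kLINEAR) tensors.
// The first input may be fp32, or additionally fp16 when the plugin was built
// with with_fp16_ and the fp16 code path is compiled in; the mask input and
// the output are only required to match the type and format already chosen
// for the preceding tensor.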
bool QkvToContextPluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
    int nb_outputs) {
  PADDLE_ENFORCE_NOT_NULL(
      in_out, platform::errors::InvalidArgument(
                  "The input of QkvToContext plugin should not be nullptr."));

  PADDLE_ENFORCE_LT(
      pos, nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos, nb_inputs + nb_outputs));

  const nvinfer1::PluginTensorDesc &in = in_out[pos];
  if (pos == 0) {
    if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
      return (in.type == nvinfer1::DataType::kFLOAT ||
              in.type == nvinfer1::DataType::kHALF) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
#else
      return (in.type == nvinfer1::DataType::kFLOAT) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
    } else {
      return (in.type == nvinfer1::DataType::kFLOAT) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
    }
  }
  const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];

  if (pos == 1) {
    return in.type == prev.type && in.format == prev.format;
  }

  // output
  return in.type == prev.type && in.format == prev.format;
}

nvinfer1::DataType QkvToContextPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType *input_types, int nb_inputs) const {
  PADDLE_ENFORCE_EQ(
      index, 0,
      platform::errors::InvalidArgument(
          "The QkvToContext plugin has only one output, so the "
          "index value should be 0, but get %d.",
          index));
  return input_types[0];
}

template <typename T>
__global__ void apply_scale(T *data, T scale, int n) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < n) {
    data[tid] = data[tid] * scale;
  }
#endif
}
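
// enqueue() drives the attention computation. Roughly:
//   1. A temporary tensor of scratch_size + input_num elements is allocated:
//      the first scratch_size elements (qkptr) hold the BxNxSxS attention
//      scores, the remaining input_num elements (tptr) hold the transposed
//      QKV data.
//   2. TransposeQKV rearranges input[0] from BxSx3xNxH into tptr as
//      3xBxNxSxH.
//   3. MultiHeadGPUComputeFunctor (bert_encoder_functor) computes QK^T into
//      qkptr, adds the attention mask bias (input[1]), applies softmax and
//      multiplies by V, leaving the BxNxSxH context at the start of tptr. In
//      the fp16 path the softmax scale is pre-applied to Q with apply_scale
//      and the functor receives alpha = 1; the fp32 path passes scale_ as
//      alpha directly.
//   4. The final transpose kernel converts BxNxSxH to BxSxNxH, i.e. the
//      (B, seq_len, hidden) output layout.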
int QkvToContextPluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc *input_desc,
    const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
    void *const *outputs, void *workspace, cudaStream_t stream) {
  auto input_dims = input_desc[0].dims;
  int input_num = ProductDim(input_dims);
  // input[0], (B, S, 3 * N * H, 1, 1)
  int batch = input_dims.d[0];
  int seq_len = input_dims.d[1];
  framework::Tensor multihead_temp_tensor;
  int scratch_size = batch * head_number_ * seq_len * seq_len * 1;

  int device_id;
  cudaGetDevice(&device_id);
  multihead_temp_tensor.Resize({scratch_size + input_num});

  auto input_type = input_desc[0].type;
  if (input_type == nvinfer1::DataType::kFLOAT) {
    VLOG(1) << "TRT Plugin DataType selected. QkvToContext-->fp32";
    auto *multihead_temp_data = multihead_temp_tensor.mutable_data<float>(
        platform::CUDAPlace(device_id));
    auto *qkptr = multihead_temp_data;
    auto *tptr = multihead_temp_data + scratch_size;

    const float *input0_data = static_cast<const float *>(inputs[0]);
    const float *input1_data = static_cast<const float *>(inputs[1]);
    // BxSx3xNxH => tptr: 3xBxNxSxH.
    TransposeQKV(batch, seq_len, head_size_, head_number_, input0_data, tptr,
                 stream);

    auto *device_ctx = static_cast<platform::CUDADeviceContext *>(
        platform::DeviceContextPool::Instance().Get(
            platform::CUDAPlace(device_id)));

    const platform::CUDADeviceContext &dev_ctx = *device_ctx;
    operators::math::MultiHeadGPUComputeFunctor<float> multihead_compute_func;
    multihead_compute_func(dev_ctx, batch, seq_len, head_number_, head_size_,
                           qkptr, input1_data, tptr, scale_,
                           static_cast<float>(0.0));

    int grid = batch * head_number_ * seq_len;
    int block = head_size_;
    float *output = static_cast<float *>(outputs[0]);
    transpose<float><<<grid, block, 0, stream>>>(tptr, output, batch, seq_len,
                                                 head_number_, head_size_);
  } else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
    VLOG(1) << "TRT Plugin DataType selected. QkvToContext-->fp16";
    auto *multihead_temp_data =
        multihead_temp_tensor.mutable_data<int16_t>(  // NOLINT
            platform::CUDAPlace(device_id));
    half *qkptr = reinterpret_cast<half *>(multihead_temp_data);
    half *tptr = qkptr + scratch_size;

    const half *input0_data = static_cast<const half *>(inputs[0]);
    const half *input1_data = static_cast<const half *>(inputs[1]);
    // BxSx3xNxH => tptr: 3xBxNxSxH.
    TransposeQKV(batch, seq_len, head_size_, head_number_, input0_data, tptr,
                 stream);

    auto *device_ctx = static_cast<platform::CUDADeviceContext *>(
        platform::DeviceContextPool::Instance().Get(
            platform::CUDAPlace(device_id)));

    // Scale the whole Q block (the first batch * head_number_ * seq_len *
    // head_size_ values of the 3xBxNxSxH buffer) before the QK^T GEMM.
    int n_q = batch * seq_len * head_number_ * head_size_;
    constexpr int threads = 128;
    int blocks = (n_q + threads - 1) / threads;

    apply_scale<<<blocks, threads, 0, stream>>>(tptr,
                                                static_cast<half>(scale_), n_q);

    const platform::CUDADeviceContext &dev_ctx = *device_ctx;
    operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
    multihead_compute_func(dev_ctx, batch, seq_len, head_number_, head_size_,
                           qkptr, input1_data, tptr, half(1.), half(0.0));

    int grid = batch * head_number_ * seq_len;
    int block = head_size_;
    half *output = static_cast<half *>(outputs[0]);
    transpose<half><<<grid, block, 0, stream>>>(tptr, output, batch, seq_len,
                                                head_number_, head_size_);
#else
    PADDLE_THROW(platform::errors::Fatal(
        "The Ernie(Bert) TensorRT Plugin should be "
        "compiled with CUDA version >= 10.0 when running with fp16. "
        "Please recompile it or try to use fp32 by setting "
        "config.SetTRTDynamicShapeInfo(min_input_shape, "
        "max_input_shape, opt_input_shape, true)."));
#endif
  } else {
    PADDLE_THROW(platform::errors::Fatal(
        "The QKV TRT Plugin's input type should be float or half."));
  }
  return cudaGetLastError() != cudaSuccess;
}
#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle