diff --git a/paddle/fluid/inference/tensorrt/convert/fused_token_prune_op.cc b/paddle/fluid/inference/tensorrt/convert/fused_token_prune_op.cc index 4832b1fad1826c035885a0be95993ebaa5227d66..ffa8b15733ba96cffff5fd20a8b45964be9b7148 100644 --- a/paddle/fluid/inference/tensorrt/convert/fused_token_prune_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fused_token_prune_op.cc @@ -38,6 +38,17 @@ class FusedTokenPruneOpConverter : public OpConverter { auto output_name = op_desc.Output("SlimmedX")[0]; auto out_inds_name = op_desc.Output("CLSInds")[0]; if (engine_->with_dynamic_shape()) { + // reduce_sum: (-1,headsize,token_length,token_length) -> + // (-1,token_length) + uint32_t reduce_dim = 0; + reduce_dim |= 1 << 1; // 00000000000000000000000000000010 + reduce_dim |= 1 << 2; // 00000000000000000000000000000110 + bool keep_dim = false; + nvinfer1::ReduceOperation reduce_type = nvinfer1::ReduceOperation::kSUM; + auto* reduce_sum_layer = TRT_ENGINE_ADD_LAYER( + engine_, Reduce, *Attn, reduce_type, reduce_dim, keep_dim); + auto* Reduced = reduce_sum_layer->getOutput(0); + bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); @@ -53,21 +64,10 @@ class FusedTokenPruneOpConverter : public OpConverter { auto* pos_id = engine_->GetITensor("pos_id"); auto* mask_id = engine_->GetITensor("mask_id"); - // reduce_sum: (-1,headsize,token_length,token_length) -> - // (-1,token_length) - uint32_t reduce_dim = 0; - reduce_dim |= 1 << 1; // 00000000000000000000000000000010 - reduce_dim |= 1 << 2; // 00000000000000000000000000000110 - bool keep_dim = false; - nvinfer1::ReduceOperation reduce_type = nvinfer1::ReduceOperation::kSUM; - auto* reduce_sum_layer = TRT_ENGINE_ADD_LAYER( - engine_, Reduce, *Attn, reduce_type, reduce_dim, keep_dim); - // reduce_sum_layer->getOutput(0)->setType(reduce_sum_layer->getInput(0)->getType()); - - auto* Reduced = reduce_sum_layer->getOutput(0); std::vector itensors = { Reduced, X, Mask, NewMask, word_id, pos_id, mask_id}; - layer = engine_->AddDynamicPlugin(itensors.data(), 7, plugin); + layer = engine_->AddDynamicPlugin( + itensors.data(), itensors.size(), plugin); // inputs'number: 7 layer->getOutput(0)->setName(output_name.c_str()); engine_->SetITensor(output_name, layer->getOutput(0)); @@ -87,10 +87,13 @@ class FusedTokenPruneOpConverter : public OpConverter { layer->getOutput(4)->setName("mask_id_after_token_prune"); engine_->SetITensor("mask_id", layer->getOutput(4)); } else { - std::vector itensors = {Attn, X, Mask, NewMask}; - layer = engine_->AddDynamicPlugin(itensors.data(), 4, plugin); + std::vector itensors = {Reduced, X, Mask, NewMask}; + layer = engine_->AddDynamicPlugin( + itensors.data(), itensors.size(), plugin); // inputs'number: 4 + layer->getOutput(0)->setName(output_name.c_str()); engine_->SetITensor(output_name, layer->getOutput(0)); + layer->getOutput(1)->setName(out_inds_name.c_str()); engine_->SetITensor(out_inds_name, layer->getOutput(1)); } diff --git a/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu index 97a60b37088b7241d396bccb6c8dda34875b75fa..f65d40a0ea4b6a94768cf8a92646d4ea29aa10c6 100644 --- a/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu @@ -31,150 +31,6 @@ namespace inference { namespace tensorrt { namespace plugin { -template -__global__ void ElementwiseMask(const T* a, - const T* b, - T* res, - int num_elements) { -#if 
CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= num_elements) return; - const T zero = 0; - res[tid] = b[tid] >= zero ? a[tid] : zero; -#endif -} - -template -__global__ void FillZero(T* data, int len) { - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= len) return; - const T zero = 0; - data[tid] = zero; -} - -__global__ void FillIndex(int32_t* indices, int num_raws, int num_cols) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= num_raws * num_cols) return; - - int col = tid % num_cols; - int raw = tid / num_cols; - - indices[tid] = col; -} - -template -__global__ void MaximumFirst(T* mat, int num_raws, int num_cols, T max_value) { - auto raw = blockIdx.x * blockDim.x + threadIdx.x; - if (raw >= num_raws) return; - mat[raw * num_cols] = max_value; -} - -__global__ void FillOffsets(int* offsets, int num_raws, int num_cols) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid > num_raws) return; - - offsets[tid] = tid * num_cols; -} - -template -__global__ void Slice( - const T* src, T* dst, int num_raws, int src_num_cols, int dst_num_cols) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= num_raws * dst_num_cols) return; - int raw = tid / dst_num_cols; - int col = tid % dst_num_cols; - dst[tid] = src[raw * src_num_cols + col]; -} - -template -__global__ void ReduceSum2( - const T* src, T* dst, int bsz, int nb_head, int max_seq_len) { - int tid = threadIdx.x; - int bid = blockIdx.x; - int num_blocks_per_head = ((max_seq_len / blockDim.x) * max_seq_len); - int batch = bid / (nb_head * num_blocks_per_head); - int col = bid % max_seq_len; - int head = (bid / num_blocks_per_head) % nb_head; - - extern __shared__ T res_float[]; - res_float[tid] = - src[batch * (nb_head * max_seq_len * max_seq_len) + - head * (max_seq_len * max_seq_len) + col + tid * max_seq_len]; - __syncthreads(); - - for (int offset = blockDim.x >> 1; offset > 0; offset >>= 1) { - if (tid < offset) { - res_float[tid] += res_float[tid + offset]; - } - __syncthreads(); - if (offset % 2 == 1 && tid == offset - 2) { - res_float[tid] += res_float[tid + 1]; - } - } - - if (tid == 0) { - auto* dst_addr = dst + batch * max_seq_len + col; - atomicAdd(dst_addr, res_float[0]); - } -} - -template <> -__global__ void ReduceSum2( - const half* src, half* dst, int bsz, int nb_head, int max_seq_len) { -#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) - int tid = threadIdx.x; - int bid = blockIdx.x; - int num_blocks_per_head = ((max_seq_len / blockDim.x) * max_seq_len); - int batch = bid / (nb_head * num_blocks_per_head); - int col = bid % max_seq_len; - int head = (bid / num_blocks_per_head) % nb_head; - - extern __shared__ half res_half[]; - res_half[tid] = - src[batch * (nb_head * max_seq_len * max_seq_len) + - head * (max_seq_len * max_seq_len) + col + tid * max_seq_len]; - __syncthreads(); - - for (int offset = blockDim.x >> 1; offset > 0; offset >>= 1) { - if (tid < offset) { - res_half[tid] += res_half[tid + offset]; - } - __syncthreads(); - if (offset % 2 == 1 && tid == offset - 2) { - res_half[tid] += res_half[tid + 1]; - } - __syncthreads(); - } - - if (tid == 0) { - phi::fastAtomicAdd( - reinterpret_cast(dst), - static_cast(batch * max_seq_len + col), - static_cast(bsz * max_seq_len), - static_cast(res_half[0])); - } -#endif -} - -template -__global__ void TakeAlongAxis(const T* src, - T* dst, - int32_t* indices, - int num_raws, - int src_num_cols, - int dst_num_cols, - int num_elements) { - int tid = threadIdx.x + 
blockIdx.x * blockDim.x; - if (tid >= num_raws * dst_num_cols) return; - - int raw = tid / dst_num_cols; - int col = tid % dst_num_cols; - for (int i = 0; i < num_elements; ++i) { - dst[tid * num_elements + i] = - *(src + (raw * src_num_cols + indices[tid]) * num_elements + i); - } -} - __global__ void compute_token_length(const int32_t* src, int32_t* dst, float scale) { @@ -182,16 +38,18 @@ __global__ void compute_token_length(const int32_t* src, dst[it] = max(static_cast((src[it + 1] - src[it]) * scale), 1); } +template __global__ void fill_index_padding_score(int32_t* token_index, - const half* scores, - int32_t scores_size, - half* padding_scores) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - token_index[tid] = threadIdx.x; - if (tid < scores_size) { - padding_scores[tid] = scores[tid]; + const T* scores, + int32_t sequnce_length, + T* padding_scores) { + int padding_scores_it = threadIdx.x + blockIdx.x * blockDim.x; + int scores_it = threadIdx.x + blockIdx.x * sequnce_length; + token_index[padding_scores_it] = threadIdx.x; + if (threadIdx.x < sequnce_length) { + padding_scores[padding_scores_it] = scores[scores_it]; } else { - padding_scores[tid] = 0; + padding_scores[padding_scores_it] = 0; } } @@ -238,21 +96,64 @@ __global__ void general_topk_pair_sort(T* in_keys, int32_t* in_out_values) { .Store(in_out_values + block_offset, thread_values); } -__global__ void varlen_prune_token(const half* tokens, - const int32_t* token_pos, - const int32_t* token_index, - half* output) { +__global__ void varlen_prune_token_change_order( + const half* tokens, + const int32_t* token_pos, + const int32_t padding_token_length, + const int32_t* token_index, + half* output) { int batch = blockIdx.x; int token_it = batch * gridDim.y + blockIdx.y; int pre_value_it = token_it * gridDim.z * blockDim.x + blockIdx.z * blockDim.x + threadIdx.x; + int token_index_it = batch * padding_token_length + blockIdx.y; - if (token_index[token_it] < token_pos[batch + 1] - token_pos[batch]) { - output[(token_index[token_it] + token_pos[batch]) * gridDim.z * blockDim.x + + if (token_index[token_index_it] < token_pos[batch + 1] - token_pos[batch]) { + output[(token_index[token_index_it] + token_pos[batch]) * gridDim.z * + blockDim.x + blockIdx.z * blockDim.x + threadIdx.x] = tokens[pre_value_it]; } } +template +__global__ void prune_token_change_order(const T* tokens, + int32_t new_sequnce_length, + const int32_t padding_token_length, + const int32_t* token_index, + T* output) { + int batch = blockIdx.x; + int token_it = batch * gridDim.y + blockIdx.y; + int pre_value_it = + token_it * gridDim.z * blockDim.x + blockIdx.z * blockDim.x + threadIdx.x; + int token_index_it = batch * padding_token_length + blockIdx.y; + + if (token_index[token_index_it] < new_sequnce_length) { + output[(batch * new_sequnce_length + token_index[token_index_it]) * + gridDim.z * blockDim.x + + blockIdx.z * blockDim.x + threadIdx.x] = tokens[pre_value_it]; + } +} + +template +__global__ void prune_token_keep_order(const T* tokens, + int32_t pre_sequnce_length, + int32_t new_sequnce_length, + const int32_t padding_token_length, + const int32_t* token_index, + T* output) { + int batch = blockIdx.x; + int index = 0; + for (int i = 0; i < pre_sequnce_length; ++i) { + if (token_index[batch * padding_token_length + i] < new_sequnce_length) { + output[(batch * new_sequnce_length + index) * gridDim.y * blockDim.x + + blockIdx.y * blockDim.x + threadIdx.x] = + tokens[(batch * pre_sequnce_length + i) * gridDim.y * blockDim.x + + blockIdx.y * 
blockDim.x + threadIdx.x]; + index++; + } + } +} + nvinfer1::DimsExprs FusedTokenPrunePluginDynamic::getOutputDimensions( int output_index, const nvinfer1::DimsExprs* inputs, @@ -353,7 +254,7 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination( "should be half for varseqlen.")); } } else if (pos == 6 || pos == 11) { // mask_id, mask_id_out - return (in.type == nvinfer1::DataType::kFLOAT) && + return (in.type == nvinfer1::DataType::kHALF) && (in.format == nvinfer1::TensorFormat::kLINEAR); } else { return in.type == nvinfer1::DataType::kINT32 && @@ -364,7 +265,6 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination( if (with_fp16_) { return (in.type == nvinfer1::DataType::kHALF) && (in.format == nvinfer1::TensorFormat::kLINEAR); - } else { return (in.type == nvinfer1::DataType::kFLOAT) && (in.format == nvinfer1::TensorFormat::kLINEAR); @@ -373,8 +273,7 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination( const nvinfer1::PluginTensorDesc& prev = in_out[0]; return in.type == prev.type && in.format == prev.format; } else { - return in.type == nvinfer1::DataType::kINT32 && - in.format == nvinfer1::TensorFormat::kLINEAR; + return in.format == nvinfer1::TensorFormat::kLINEAR; } } } @@ -425,199 +324,6 @@ size_t FusedTokenPrunePluginDynamic::getWorkspaceSize( return size; } -template -inline void enqueueImpl(const nvinfer1::PluginTensorDesc* input_desc, - const nvinfer1::PluginTensorDesc* output_desc, - const void* const* inputs, - void* const* outputs, - void* workspace_ptr, - cudaStream_t stream, - int device_id, - T max_value, - bool keep_first_token_, - bool keep_order_) { - // Dims - auto attn_dims = input_desc[0].dims; - auto x_dims = input_desc[1].dims; - auto new_mask_dims = input_desc[3].dims; - - auto bsz = attn_dims.d[0], nb_head = attn_dims.d[1], - max_seq_len = attn_dims.d[2]; - auto c = x_dims.d[2]; - auto slimmed_x_len = new_mask_dims.d[2]; - - // Inputs - const T* attn_data = static_cast(inputs[0]); - const T* x_data = static_cast(inputs[1]); - const T* mask_data = static_cast(inputs[2]); - - // Outputs - T* output_data = static_cast(outputs[0]); - int32_t* output_indices_data = static_cast(outputs[1]); - - int total = bsz * nb_head * max_seq_len * max_seq_len; - int block = operators::ComputeBlockSize(max_seq_len); - int grid = operators::CeilDivide(total, block); - - // Workspace for intermediate variable - char* workspace = static_cast(workspace_ptr); - T* attn_tmp_data = reinterpret_cast(workspace); - size_t offset = total * sizeof(T); - T* attn_accu_data = reinterpret_cast(workspace + offset); - offset += bsz * max_seq_len * sizeof(T); - int32_t* attn_accu_indices_data = - reinterpret_cast(workspace + offset); - offset += bsz * max_seq_len * sizeof(int32_t); - T* sort_attn_accu_data = reinterpret_cast(workspace + offset); - offset += bsz * max_seq_len * sizeof(T); - int32_t* sort_attn_accu_indices_data = - reinterpret_cast(workspace + offset); - offset += bsz * max_seq_len * sizeof(int32_t); - int* offsets_data = reinterpret_cast(workspace + offset); - offset += (bsz + 1) * sizeof(int); - int32_t* slimmed_sort_attn_accu_indices_data = - reinterpret_cast(workspace + offset); - - // 1. Filter attn by mask - ElementwiseMask - <<>>(attn_data, mask_data, attn_tmp_data, total); - - total = bsz * max_seq_len; - block = operators::ComputeBlockSize(max_seq_len); - grid = operators::CeilDivide(total, block); - FillZero<<>>(attn_accu_data, total); - - // 2. 
Reduce sum - total = bsz * nb_head * max_seq_len * max_seq_len; - int block_tmp = max_seq_len; - while (block_tmp > 1024) - block_tmp /= 2; // if max seq len > 1024, it must be 2^n - block = - block_tmp; // make sure max_seq_len is an integral multiple of block_size - grid = operators::CeilDivide(total, block); - ReduceSum2<<>>( - attn_tmp_data, attn_accu_data, bsz, nb_head, max_seq_len); - - // 3. Prepare token indices - total = bsz * max_seq_len; - block = operators::ComputeBlockSize(max_seq_len); - grid = operators::CeilDivide(total, block); - - FillIndex<<>>( - attn_accu_indices_data, bsz, max_seq_len); - - // 4. Sort token indices by attn - if (keep_first_token_) { - MaximumFirst - <<>>(attn_accu_data, bsz, max_seq_len, max_value); - } - size_t temp_storage_bytes = -1; - int num_items = bsz * max_seq_len; - int num_segments = bsz; - FillOffsets<<>>(offsets_data, bsz, max_seq_len); - PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortPairsDescending( - nullptr, - temp_storage_bytes, - attn_accu_data, - sort_attn_accu_data, - attn_accu_indices_data, - sort_attn_accu_indices_data, - num_items, - num_segments, - offsets_data, - offsets_data + 1, - 0, - sizeof(T) * 8, - stream)); - int64_t temp_size = temp_storage_bytes; - phi::DenseTensor temp_storage; - auto* temp_storage_data = temp_storage.mutable_data( - {temp_size}, platform::CUDAPlace(device_id)); - - PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortPairsDescending( - temp_storage_data, - temp_storage_bytes, - attn_accu_data, - sort_attn_accu_data, - attn_accu_indices_data, - sort_attn_accu_indices_data, - num_items, - num_segments, - offsets_data, - offsets_data + 1, - 0, - sizeof(T) * 8, - stream)); - // 5. Slice - total = bsz * slimmed_x_len; - block = operators::ComputeBlockSize(slimmed_x_len); - grid = operators::CeilDivide(total, block); - - Slice - <<>>(sort_attn_accu_indices_data, - slimmed_sort_attn_accu_indices_data, - bsz, - max_seq_len, - slimmed_x_len); - - if (keep_order_) { - // 6. 
reorder - num_items = bsz * slimmed_x_len; - FillOffsets<<>>(offsets_data, bsz, slimmed_x_len); - temp_storage_bytes = -1; - PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, - temp_storage_bytes, - slimmed_sort_attn_accu_indices_data, - output_indices_data, - num_items, - num_segments, - offsets_data, - offsets_data + 1, - 0, - sizeof(int32_t) * 8, - stream)); - - temp_size = temp_storage_bytes; - temp_storage.Resize({temp_size}); - temp_storage_data = - temp_storage.mutable_data(platform::CUDAPlace(device_id)); - PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortKeys( - temp_storage_data, - temp_storage_bytes, - slimmed_sort_attn_accu_indices_data, - output_indices_data, - num_items, - num_segments, - offsets_data, - offsets_data + 1, - 0, - sizeof(int32_t) * 8, - stream)); - - TakeAlongAxis<<>>(x_data, - output_data, - output_indices_data, - bsz, - max_seq_len, - slimmed_x_len, - c); - } else { - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(output_indices_data, - slimmed_sort_attn_accu_indices_data, - bsz * slimmed_x_len * sizeof(int32_t), - cudaMemcpyDeviceToDevice)); - TakeAlongAxis - <<>>(x_data, - output_data, - slimmed_sort_attn_accu_indices_data, - bsz, - max_seq_len, - slimmed_x_len, - c); - } -} - int FusedTokenPrunePluginDynamic::enqueue( const nvinfer1::PluginTensorDesc* input_desc, const nvinfer1::PluginTensorDesc* output_desc, @@ -628,49 +334,56 @@ int FusedTokenPrunePluginDynamic::enqueue( if (flag_varseqlen_) { if (!(input_desc[0].type == nvinfer1::DataType::kHALF && input_desc[1].type == nvinfer1::DataType::kHALF)) { - PADDLE_THROW( - platform::errors::InvalidArgument("Token_prune'type must half")); + PADDLE_THROW(platform::errors::InvalidArgument( + "Token_prune'type must half for varseqlen")); } float scale = - static_cast(input_desc[3].dims.d[2]) / input_desc[6].dims.d[1]; - const int32_t* inputs5 = - static_cast(inputs[5]); // pre pos id - int32_t* outputs3 = static_cast(outputs[3]); // new pos id - half* outputs0 = static_cast(outputs[0]); - + static_cast(input_desc[3].dims.d[2]) / input_desc[2].dims.d[2]; + const int32_t* input5 = + static_cast(inputs[5]); // pre pos id + int32_t* output3 = static_cast(outputs[3]); // new pos id + half* output0 = static_cast(outputs[0]); const int32_t B = input_desc[1].dims.d[0]; // batchs const int32_t max_sequnce_length = input_desc[1].dims.d[1]; // max sequnce length - const int32_t length = input_desc[1].dims.d[2]; // vector length + const int32_t length = input_desc[1].dims.d[2]; // hidden size const half* scores = static_cast(inputs[0]); // reduce sum const half* tokens = static_cast(inputs[1]); - const int32_t scores_size = B * max_sequnce_length; int32_t padding_token_length; - if (max_sequnce_length <= 128) { + if (max_sequnce_length <= 64) { + padding_token_length = 64; + } else if (max_sequnce_length <= 128) { padding_token_length = 128; } else if (max_sequnce_length <= 256) { padding_token_length = 256; } else if (max_sequnce_length <= 384) { padding_token_length = 384; + } else if (max_sequnce_length <= 512) { + padding_token_length = 512; } else { PADDLE_THROW(platform::errors::InvalidArgument( - "Token_prune'token_length must <= 384")); + "Token_prune'token_length must <= 512")); } // 1. Compute the token length after pruning. compute_token_length<<<1, B, 0, stream>>>( - inputs5, pruned_token_lengths_, scale); + input5, pruned_token_lengths_, scale); - fill_index_padding_score<<>>( - token_index_, scores, scores_size, padding_scores_); + // 2. 
Padding scores + fill_index_padding_score<<>>( + token_index_, + scores, + max_sequnce_length, + static_cast(padding_scores_)); + // 3. compute new pos id // Determine temporary device storage requirements void* d_temp_storage = NULL; size_t temp_storage_bytes = 0; cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, pruned_token_lengths_, - outputs3, + output3, B + 1); // Allocate temporary storage cudaMalloc(&d_temp_storage, temp_storage_bytes); @@ -679,20 +392,28 @@ int FusedTokenPrunePluginDynamic::enqueue( cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, pruned_token_lengths_, - outputs3, + output3, B + 1); - if (padding_token_length == 128) { - general_topk_pair_sort - <<>>(padding_scores_, token_index_); // 128 + // 4. sort scores + if (padding_token_length == 64) { + general_topk_pair_sort<<>>( + static_cast(padding_scores_), token_index_); // 64 + } else if (padding_token_length == 128) { + general_topk_pair_sort<<>>( + static_cast(padding_scores_), token_index_); // 128 } else if (padding_token_length == 256) { - general_topk_pair_sort - <<>>(padding_scores_, token_index_); // 256 + general_topk_pair_sort<<>>( + static_cast(padding_scores_), token_index_); // 256 + } else if (padding_token_length == 384) { + general_topk_pair_sort<<>>( + static_cast(padding_scores_), token_index_); // 384 } else { - general_topk_pair_sort - <<>>(padding_scores_, token_index_); // 384 + general_topk_pair_sort<<>>( + static_cast(padding_scores_), token_index_); // 512 } + // 5. compute output int32_t num_threads; if (length < 1024) { num_threads = length; @@ -723,46 +444,196 @@ int FusedTokenPrunePluginDynamic::enqueue( B, max_sequnce_length, length / num_threads); // batchs, max_sequnce_length, vector_ength/*** - varlen_prune_token<<>>( - tokens, outputs3, token_index_, outputs0); + varlen_prune_token_change_order<<>>( + tokens, output3, padding_token_length, token_index_, output0); } else { auto input_type = input_desc[0].type; - auto attn_dims = input_desc[0].dims; - auto bsz = attn_dims.d[0], nb_head = attn_dims.d[1], - max_seq_len = attn_dims.d[2]; - int device_id; - cudaGetDevice(&device_id); - + const int32_t B = input_desc[1].dims.d[0]; // batchs + const int32_t pre_sequnce_length = input_desc[1].dims.d[1]; + const int32_t new_sequnce_length = input_desc[3].dims.d[2]; // new mask + const int32_t length = input_desc[1].dims.d[2]; // hidden size if (input_type == nvinfer1::DataType::kFLOAT) { VLOG(1) << "TRT Plugin DataType selected. FusedTokenPrune-->fp32"; + const float* scores = static_cast(inputs[0]); // reduce sum + const float* tokens = static_cast(inputs[1]); // X + float* output0 = static_cast(outputs[0]); + int32_t padding_token_length; + if (pre_sequnce_length <= 64) { + padding_token_length = 64; + } else if (pre_sequnce_length <= 128) { + padding_token_length = 128; + } else if (pre_sequnce_length <= 256) { + padding_token_length = 256; + } else if (pre_sequnce_length <= 384) { + padding_token_length = 384; + } else if (pre_sequnce_length <= 512) { + padding_token_length = 512; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Token_prune'token_length must <= 512")); + } - float max = std::numeric_limits::max(); - - enqueueImpl(input_desc, - output_desc, - inputs, - outputs, - workspace, - stream, - device_id, - max, - keep_first_token_, - keep_order_); + // 1. Padding scores + fill_index_padding_score<<>>( + token_index_, + scores, + pre_sequnce_length, + static_cast(padding_scores_)); + + // 2. 
sort scores + if (padding_token_length == 64) { + general_topk_pair_sort<<>>( + static_cast(padding_scores_), token_index_); // 64 + } else if (padding_token_length == 128) { + general_topk_pair_sort<<>>( + static_cast(padding_scores_), token_index_); // 128 + } else if (padding_token_length == 256) { + general_topk_pair_sort<<>>( + static_cast(padding_scores_), token_index_); // 256 + } else if (padding_token_length == 384) { + general_topk_pair_sort<<>>( + static_cast(padding_scores_), token_index_); // 384 + } else { + general_topk_pair_sort<<>>( + static_cast(padding_scores_), token_index_); // 512 + } + // 3. compute output + int32_t num_threads; + if (length < 1024) { + num_threads = length; + } else { + if (length % 512 == 0) { + num_threads = 512; + } else if (length % 256 == 0) { + num_threads = 256; + } else if (length % 128 == 0) { + num_threads = 128; + } else if (length % 64 == 0) { + num_threads = 64; + } else if (length % 32 == 0) { + num_threads = 32; + } else if (length % 16 == 0) { + num_threads = 16; + } else if (length % 8 == 0) { + num_threads = 8; + } else if (length % 4 == 0) { + num_threads = 4; + } else if (length % 2 == 0) { + num_threads = 2; + } else { + num_threads = 1; + } + } + if (keep_order_) { + const dim3 num_blocks(B, length / num_threads); + prune_token_keep_order + <<>>(tokens, + pre_sequnce_length, + new_sequnce_length, + padding_token_length, + token_index_, + output0); + } else { + const dim3 num_blocks(B, pre_sequnce_length, length / num_threads); + prune_token_change_order + <<>>(tokens, + new_sequnce_length, + padding_token_length, + token_index_, + output0); + } } else if (input_type == nvinfer1::DataType::kHALF) { VLOG(1) << "TRT Plugin DataType selected. FusedTokenPrune-->fp16"; + const half* scores = static_cast(inputs[0]); // reduce sum + const half* tokens = static_cast(inputs[1]); // X + half* output0 = static_cast(outputs[0]); + int32_t padding_token_length; + if (pre_sequnce_length <= 64) { + padding_token_length = 64; + } else if (pre_sequnce_length <= 128) { + padding_token_length = 128; + } else if (pre_sequnce_length <= 256) { + padding_token_length = 256; + } else if (pre_sequnce_length <= 384) { + padding_token_length = 384; + } else if (pre_sequnce_length <= 512) { + padding_token_length = 512; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Token_prune'token_length must <= 512")); + } + + // 1. Padding scores + fill_index_padding_score<<>>( + token_index_, + scores, + pre_sequnce_length, + static_cast(padding_scores_)); + + // 2. sort scores + if (padding_token_length == 64) { + general_topk_pair_sort<<>>( + static_cast(padding_scores_), token_index_); // 64 + } else if (padding_token_length == 128) { + general_topk_pair_sort<<>>( + static_cast(padding_scores_), token_index_); // 128 + } else if (padding_token_length == 256) { + general_topk_pair_sort<<>>( + static_cast(padding_scores_), token_index_); // 256 + } else if (padding_token_length == 384) { + general_topk_pair_sort<<>>( + static_cast(padding_scores_), token_index_); // 384 + } else { + general_topk_pair_sort<<>>( + static_cast(padding_scores_), token_index_); // 512 + } - half max = 65504.0; - enqueueImpl(input_desc, - output_desc, - inputs, - outputs, - workspace, - stream, - device_id, - max, - keep_first_token_, - keep_order_); + // 3. 
compute output + int32_t num_threads; + if (length < 1024) { + num_threads = length; + } else { + if (length % 512 == 0) { + num_threads = 512; + } else if (length % 256 == 0) { + num_threads = 256; + } else if (length % 128 == 0) { + num_threads = 128; + } else if (length % 64 == 0) { + num_threads = 64; + } else if (length % 32 == 0) { + num_threads = 32; + } else if (length % 16 == 0) { + num_threads = 16; + } else if (length % 8 == 0) { + num_threads = 8; + } else if (length % 4 == 0) { + num_threads = 4; + } else if (length % 2 == 0) { + num_threads = 2; + } else { + num_threads = 1; + } + } + if (keep_order_) { + const dim3 num_blocks(B, length / num_threads); + prune_token_keep_order + <<>>(tokens, + pre_sequnce_length, + new_sequnce_length, + padding_token_length, + token_index_, + output0); + } else { + const dim3 num_blocks(B, pre_sequnce_length, length / num_threads); + prune_token_change_order + <<>>(tokens, + new_sequnce_length, + padding_token_length, + token_index_, + output0); + } } else { PADDLE_THROW( platform::errors::Fatal("The FusedTokenPrune TRT Plugin's input type " diff --git a/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h index 4c9c24c59afa29279a615016f4342fa411ff30d5..238566c7b498febb7859d675c8c1cb0aebf7d314 100644 --- a/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h @@ -93,12 +93,33 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT { int nb_outputs) TRT_NOEXCEPT override { max_batchs_ = in[1].max.d[0]; max_token_length_ = in[1].max.d[1]; + int32_t padding_token_length; + if (max_token_length_ <= 64) { + padding_token_length = 64; + } else if (max_token_length_ <= 128) { + padding_token_length = 128; + } else if (max_token_length_ <= 256) { + padding_token_length = 256; + } else if (max_token_length_ <= 384) { + padding_token_length = 384; + } else if (max_token_length_ <= 512) { + padding_token_length = 512; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Token_prune'token_length(max) must <= 512")); + } PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc(&pruned_token_lengths_, (max_batchs_ + 1) * sizeof(int32_t))); PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc( - &token_index_, max_batchs_ * max_token_length_ * sizeof(int32_t))); + &token_index_, max_batchs_ * padding_token_length * sizeof(int32_t))); + int32_t type_size = 4; + if (in[0].desc.type == nvinfer1::DataType::kHALF) { + type_size = 2; + } else { + type_size = 4; + } PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc( - &padding_scores_, max_batchs_ * max_token_length_ * sizeof(half))); + &padding_scores_, max_batchs_ * padding_token_length * type_size)); } size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, @@ -129,7 +150,7 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT { int32_t* token_index_; int32_t max_batchs_; int32_t max_token_length_; - half* padding_scores_; + void* padding_scores_; }; class FusedTokenPrunePluginDynamicCreator : public nvinfer1::IPluginCreator { diff --git a/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc b/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc index 0fc959a14b93409a2ee4a1f4dc65f6d3ab9b266d..46451c5db57fdde33580c3e5ca66a41f42dd3963 100644 --- a/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc +++ b/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc @@ -352,24 +352,24 @@ class TensorRTDynamicTestFusedTokenPrune : public 
::testing::Test { ctx_->PartialInitWithAllocator(); std::map> min_input_shape = { - {"attn", {4, 1, 4, 4}}, + {"attn", {4, 4}}, {"x", {4, 4, 1}}, {"mask", {4, 1, 4, 4}}, {"new_mask", {4, 1, 2, 2}}}; std::map> max_input_shape = { - {"attn", {4, 1, 4, 4}}, + {"attn", {4, 4}}, {"x", {4, 4, 1}}, {"mask", {4, 1, 4, 4}}, {"new_mask", {4, 1, 2, 2}}}; std::map> optim_input_shape = { - {"attn", {4, 1, 4, 4}}, + {"attn", {4, 4}}, {"x", {4, 4, 1}}, {"mask", {4, 1, 4, 4}}, {"new_mask", {4, 1, 2, 2}}}; engine_ = new TensorRTEngine(16, 1 << 10, - AnalysisConfig::Precision::kHalf, + AnalysisConfig::Precision::kFloat32, nullptr, 0, min_input_shape, @@ -391,7 +391,7 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test { } } - void PrepareInputOutput(const std::vector> inputs, + void PrepareInputOutput(const std::vector> inputs, std::vector> output_shapes) { LOG(INFO) << "PrepareInputOutput"; int num_inputs = inputs.size(); @@ -423,15 +423,15 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) { #if IS_TRT_VERSION_GE(8000) tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt(); auto *attn = engine_->DeclareInput( - "attn", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 4, 4}); + "attn", nvinfer1::DataType::kFLOAT, nvinfer1::Dims2{-1, 4}); auto *x = engine_->DeclareInput( - "x", nvinfer1::DataType::kHALF, nvinfer1::Dims3{-1, 4, 1}); + "x", nvinfer1::DataType::kFLOAT, nvinfer1::Dims3{-1, 4, 1}); auto *mask = engine_->DeclareInput( - "mask", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 4, 4}); + "mask", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4{-1, 1, 4, 4}); auto *new_mask = engine_->DeclareInput( - "new_mask", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 2, 2}); + "new_mask", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4{-1, 1, 2, 2}); plugin::FusedTokenPrunePluginDynamic *plugin = - new plugin::FusedTokenPrunePluginDynamic(true, + new plugin::FusedTokenPrunePluginDynamic(/*with_fp16*/ false, /*keep_first_token*/ false, /*keep_order*/ true, /*flag_varseqlen*/ false); @@ -449,18 +449,215 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) { ASSERT_EQ(engine_->engine()->getNbBindings(), 6); LOG(INFO) << "create input"; - std::vector attn_v(64); + std::vector attn_v(16); + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 4; ++k) { + attn_v[j * 4 + k] = k; + } + } + std::vector x_v(16); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + x_v[i * 4 + j] = 4 - j; + } + } + std::vector mask_v(64); for (int i = 0; i < 4; ++i) { for (int j = 0; j < 4; ++j) { for (int k = 0; k < 4; ++k) { - attn_v[i * 16 + j * 4 + k] = k; + mask_v[i * 16 + j * 4 + k] = 1; } } } + std::vector new_mask_v(16); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 2; ++j) { + for (int k = 0; k < 2; ++k) { + new_mask_v[i * 4 + j * 2 + k] = 1; + } + } + } + + LOG(INFO) << "create output"; + std::vector out_slimmed_x_shape{4, 2, 1}; + std::vector out_cls_ins_shape{4, 2}; + + PrepareInputOutput({attn_v, x_v, mask_v, new_mask_v}, + {out_slimmed_x_shape, out_cls_ins_shape}); + + auto *attn_gpu_data = inputs_[0].mutable_data(ctx_->GetPlace()); + auto *x_gpu_data = inputs_[1].mutable_data(ctx_->GetPlace()); + auto *mask_gpu_data = inputs_[2].mutable_data(ctx_->GetPlace()); + auto *new_mask_gpu_data = inputs_[3].mutable_data(ctx_->GetPlace()); + + auto *slimmed_x_gpu_data = outputs_[0].mutable_data(ctx_->GetPlace()); + auto *cls_inds_gpu_data = outputs_[1].mutable_data(ctx_->GetPlace()); + + LOG(INFO) << "create buffers"; + + std::vector buffers(6); + 
buffers[0] = reinterpret_cast(attn_gpu_data); + buffers[1] = reinterpret_cast(x_gpu_data); + buffers[2] = reinterpret_cast(mask_gpu_data); + buffers[3] = reinterpret_cast(new_mask_gpu_data); + buffers[4] = reinterpret_cast(slimmed_x_gpu_data); + buffers[5] = reinterpret_cast(cls_inds_gpu_data); + + LOG(INFO) << "Execute"; + + engine_->Execute(4, &buffers, ctx_->stream()); + + std::vector slimmed_x_v(8); + std::vector cls_inds_v; + + LOG(INFO) << "GetOutput"; + GetOutput(slimmed_x_v, cls_inds_v); + + // slimmed_x_v: [[4,3,2,1],[4,3,2,1],[4,3,2,1],[4,3,2,1]] -> + // [[2,1],[2,1],[2,1],[2,1]] + + ASSERT_EQ(slimmed_x_v[0], 2); + ASSERT_EQ(slimmed_x_v[1], 1); + ASSERT_EQ(slimmed_x_v[2], 2); + ASSERT_EQ(slimmed_x_v[3], 1); + ASSERT_EQ(slimmed_x_v[4], 2); + ASSERT_EQ(slimmed_x_v[5], 1); + ASSERT_EQ(slimmed_x_v[6], 2); + ASSERT_EQ(slimmed_x_v[7], 1); + + LOG(INFO) << "finish"; +#endif +} + +class TensorRTDynamicTestFusedTokenPruneHalf : public ::testing::Test { + protected: + void SetUp() override { + ctx_ = new phi::GPUContext(platform::CUDAPlace(0)); + ctx_->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(platform::CUDAPlace(0), ctx_->stream()) + .get()); + ctx_->SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + ctx_->SetZeroAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetZeroAllocator(platform::CUDAPlace(0)) + .get()); + ctx_->SetPinnedAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CUDAPinnedPlace()) + .get()); + ctx_->PartialInitWithAllocator(); + + std::map> min_input_shape = { + {"attn", {4, 4}}, + {"x", {4, 4, 1}}, + {"mask", {4, 1, 4, 4}}, + {"new_mask", {4, 1, 2, 2}}}; + std::map> max_input_shape = { + {"attn", {4, 4}}, + {"x", {4, 4, 1}}, + {"mask", {4, 1, 4, 4}}, + {"new_mask", {4, 1, 2, 2}}}; + std::map> optim_input_shape = { + {"attn", {4, 4}}, + {"x", {4, 4, 1}}, + {"mask", {4, 1, 4, 4}}, + {"new_mask", {4, 1, 2, 2}}}; + + engine_ = new TensorRTEngine(16, + 1 << 10, + AnalysisConfig::Precision::kHalf, + nullptr, + 0, + min_input_shape, + max_input_shape, + optim_input_shape, + std::map>(), + std::map>(), + std::map>(), + false, + phi::DataType::FLOAT16, + NaiveLogger::Global()); + engine_->InitNetwork(); + } + + void TearDown() override { + if (engine_) { + delete engine_; + engine_ = nullptr; + } + } + + void PrepareInputOutput(const std::vector> inputs, + std::vector> output_shapes) { + LOG(INFO) << "PrepareInputOutput"; + int num_inputs = inputs.size(); + int num_outputs = output_shapes.size(); + inputs_.resize(num_inputs); + outputs_.resize(num_outputs); + for (int i = 0; i < num_inputs; ++i) { + paddle::framework::TensorFromVector(inputs[i], *ctx_, &inputs_[i]); + } + for (int i = 0; i < num_outputs; ++i) { + outputs_[i].Resize(phi::make_ddim(output_shapes[i])); + } + } + + void GetOutput(std::vector &slimmed_x, // NOLINT + std::vector &cls_inds) { // NOLINT + paddle::framework::TensorToVector(outputs_[0], *ctx_, &slimmed_x); + paddle::framework::TensorToVector(outputs_[1], *ctx_, &cls_inds); + } + + protected: + std::vector inputs_; + std::vector outputs_; + TensorRTEngine *engine_; + phi::GPUContext *ctx_; +}; + +TEST_F(TensorRTDynamicTestFusedTokenPruneHalf, test_fused_token_prune) { +#if IS_TRT_VERSION_GE(8000) + tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt(); + auto *attn = engine_->DeclareInput( + "attn", nvinfer1::DataType::kHALF, nvinfer1::Dims2{-1, 4}); + 
auto *x = engine_->DeclareInput( + "x", nvinfer1::DataType::kHALF, nvinfer1::Dims3{-1, 4, 1}); + auto *mask = engine_->DeclareInput( + "mask", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 4, 4}); + auto *new_mask = engine_->DeclareInput( + "new_mask", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 2, 2}); + plugin::FusedTokenPrunePluginDynamic *plugin = + new plugin::FusedTokenPrunePluginDynamic(/*with_fp16*/ true, + /*keep_first_token*/ false, + /*keep_order*/ true, + /*flag_varseqlen*/ false); + std::vector itensors = {attn, x, mask, new_mask}; + auto *layer = engine_->AddDynamicPlugin(itensors.data(), 4, plugin); + PADDLE_ENFORCE_NOT_NULL(layer, + platform::errors::InvalidArgument( + "TRT fused_token_prune layer building failed.")); + std::vector output_tensor_names{"out_slimmed_x", "out_cls_inds"}; + for (size_t i = 0; i < 2; i++) { + layer->getOutput(i)->setName(output_tensor_names[i].c_str()); + engine_->DeclareOutput(layer, i, output_tensor_names[i]); + } + engine_->FreezeNetwork(); + + ASSERT_EQ(engine_->engine()->getNbBindings(), 6); + LOG(INFO) << "create input"; + std::vector attn_v(16); + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 4; ++k) { + attn_v[j * 4 + k] = k; + } + } std::vector x_v(16); for (int i = 0; i < 4; ++i) { for (int j = 0; j < 4; ++j) { - x_v[i * 4 + j] = 1; + x_v[i * 4 + j] = 4 - j; } } std::vector mask_v(64); @@ -509,20 +706,24 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) { engine_->Execute(4, &buffers, ctx_->stream()); - std::vector slimmed_x_v; + std::vector slimmed_x_v(8); std::vector cls_inds_v; LOG(INFO) << "GetOutput"; GetOutput(slimmed_x_v, cls_inds_v); - ASSERT_EQ(cls_inds_v[0], 2); - ASSERT_EQ(cls_inds_v[1], 3); - ASSERT_EQ(cls_inds_v[2], 2); - ASSERT_EQ(cls_inds_v[3], 3); - ASSERT_EQ(cls_inds_v[4], 2); - ASSERT_EQ(cls_inds_v[5], 3); - ASSERT_EQ(cls_inds_v[6], 2); - ASSERT_EQ(cls_inds_v[7], 3); + // slimmed_x_v: [[4,3,2,1],[4,3,2,1],[4,3,2,1],[4,3,2,1]] -> + // [[2,1],[2,1],[2,1],[2,1]] + + ASSERT_EQ(slimmed_x_v[0], 2); + ASSERT_EQ(slimmed_x_v[1], 1); + ASSERT_EQ(slimmed_x_v[2], 2); + ASSERT_EQ(slimmed_x_v[3], 1); + ASSERT_EQ(slimmed_x_v[4], 2); + ASSERT_EQ(slimmed_x_v[5], 1); + ASSERT_EQ(slimmed_x_v[6], 2); + ASSERT_EQ(slimmed_x_v[7], 1); + LOG(INFO) << "finish"; #endif }
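
Note on the converter hunk: the patch moves the attention reduce-sum in front of the plugin so that both execution paths (varseqlen and the plain dynamic-shape path) receive per-token scores of shape (batch, token_length) instead of the raw (batch, head_num, token_length, token_length) attention map. Below is a minimal standalone sketch of the same TensorRT call, written against `nvinfer1::INetworkDefinition::addReduce` rather than the `TRT_ENGINE_ADD_LAYER` macro used in the converter; the `network` and `attn` names are illustrative only.

```cpp
#include <NvInfer.h>

// Sketch: collapse an attention map (batch, head_num, seq_len, seq_len)
// into per-token scores (batch, seq_len) with a TensorRT Reduce layer.
nvinfer1::ITensor* AddAttnScoreReduce(nvinfer1::INetworkDefinition* network,
                                      nvinfer1::ITensor* attn) {
  // Bit i of the axes mask selects dimension i. Summing over dims 1 (head)
  // and 2 (query position) matches reduce_dim = (1 << 1) | (1 << 2) in the
  // converter above.
  uint32_t reduce_axes = (1u << 1) | (1u << 2);
  bool keep_dims = false;
  nvinfer1::IReduceLayer* reduce = network->addReduce(
      *attn, nvinfer1::ReduceOperation::kSUM, reduce_axes, keep_dims);
  return reduce->getOutput(0);  // shape: (batch, seq_len)
}
```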
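The plugin now pads each batch's score row to a fixed block size before the pair sort, and the same 64/128/256/384/512 ladder is repeated in both `enqueue` branches and in the header's `configurePlugin`. The hypothetical helper below (not part of the patch) captures that rounding rule; the patch itself keeps the explicit if/else chains. Fixed padded sizes are consistent with a block-wide top-k pair sort that operates on fixed-size tiles, with the padded score slots filled with zero by `fill_index_padding_score`.

```cpp
#include <cstdint>
#include <initializer_list>
#include <stdexcept>

// Hypothetical helper mirroring the repeated ladders in the plugin: round a
// sequence length up to the nearest supported sort block size, or fail.
int32_t PaddedTokenLength(int32_t seq_len) {
  for (int32_t candidate : {64, 128, 256, 384, 512}) {
    if (seq_len <= candidate) return candidate;
  }
  throw std::invalid_argument("token_length must be <= 512");
}
```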
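For the non-varseqlen keep-order path, the new `prune_token_keep_order` kernel keeps token i of a batch iff `token_index[b * padding_token_length + i]` (the token's position after the score sort, as the kernel's indexing implies) is smaller than the new sequence length, and writes the kept tokens out in their original order. The following CPU reference is a hypothetical illustration of that behavior, not code from the patch; all names and the float element type are assumptions.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical CPU reference for the keep-order pruning path.
std::vector<float> PruneKeepOrderRef(
    const std::vector<float>& tokens,        // [B, pre_len, hidden]
    const std::vector<int32_t>& token_index, // [B, padded_len]
    int B, int pre_len, int new_len, int padded_len, int hidden) {
  std::vector<float> out(static_cast<size_t>(B) * new_len * hidden, 0.0f);
  for (int b = 0; b < B; ++b) {
    int dst = 0;  // next free slot in this batch's pruned sequence
    for (int i = 0; i < pre_len && dst < new_len; ++i) {
      // Tokens whose post-sort position is below new_len survive pruning
      // and are emitted in their original order.
      if (token_index[b * padded_len + i] < new_len) {
        for (int h = 0; h < hidden; ++h) {
          out[(static_cast<size_t>(b) * new_len + dst) * hidden + h] =
              tokens[(static_cast<size_t>(b) * pre_len + i) * hidden + h];
        }
        ++dst;
      }
    }
  }
  return out;
}
```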