Unverified commit 65c17315, authored by Wangzheee, committed by GitHub

[Paddle Inference]optimize token prune for no varlen (#49094)

* optimize token prune for no varlen
Parent 4cdeab7b
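Orientation, summarized from the diff below (not part of the commit message): the op converter now builds a TensorRT Reduce layer that sums the attention map over its head axis and first token_length axis, and feeds the reduced (batch, token_length) score tensor to the plugin in both the varlen and the non-varlen branch; on the plugin side the old ElementwiseMask/ReduceSum2/TakeAlongAxis enqueueImpl path is removed and the non-varlen path reuses the padding, block-sort, and gather kernels that the varlen path already used. Conceptually, the selection those kernels perform looks like the host-side sketch below; it is an illustrative plain-C++ approximation, and the name prune_tokens and its signature are assumptions, not code from the commit.

#include <algorithm>
#include <numeric>
#include <vector>

// scores:     one importance score per token (the column-summed attention)
// tokens:     seq_len x hidden token embeddings, row major
// new_len:    number of tokens to keep (taken from new_mask's shape)
// keep_order: if true, surviving tokens keep their original order
std::vector<float> prune_tokens(const std::vector<float>& scores,
                                const std::vector<float>& tokens,
                                int seq_len, int hidden, int new_len,
                                bool keep_order) {
  std::vector<int> idx(seq_len);
  std::iota(idx.begin(), idx.end(), 0);
  // rank token indices by descending score
  std::stable_sort(idx.begin(), idx.end(),
                   [&](int a, int b) { return scores[a] > scores[b]; });
  idx.resize(new_len);                         // keep the top new_len tokens
  if (keep_order) std::sort(idx.begin(), idx.end());
  std::vector<float> out(new_len * hidden);
  for (int i = 0; i < new_len; ++i) {          // gather the surviving rows
    std::copy_n(&tokens[idx[i] * hidden], hidden, &out[i * hidden]);
  }
  return out;
}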
@@ -38,6 +38,17 @@ class FusedTokenPruneOpConverter : public OpConverter {
   auto output_name = op_desc.Output("SlimmedX")[0];
   auto out_inds_name = op_desc.Output("CLSInds")[0];
   if (engine_->with_dynamic_shape()) {
+    // reduce_sum: (-1,headsize,token_length,token_length) ->
+    // (-1,token_length)
+    uint32_t reduce_dim = 0;
+    reduce_dim |= 1 << 1;  // 00000000000000000000000000000010
+    reduce_dim |= 1 << 2;  // 00000000000000000000000000000110
+    bool keep_dim = false;
+    nvinfer1::ReduceOperation reduce_type = nvinfer1::ReduceOperation::kSUM;
+    auto* reduce_sum_layer = TRT_ENGINE_ADD_LAYER(
+        engine_, Reduce, *Attn, reduce_type, reduce_dim, keep_dim);
+    auto* Reduced = reduce_sum_layer->getOutput(0);
     bool with_fp16 =
         engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
@@ -53,21 +64,10 @@ class FusedTokenPruneOpConverter : public OpConverter {
       auto* pos_id = engine_->GetITensor("pos_id");
       auto* mask_id = engine_->GetITensor("mask_id");
-      // reduce_sum: (-1,headsize,token_length,token_length) ->
-      // (-1,token_length)
-      uint32_t reduce_dim = 0;
-      reduce_dim |= 1 << 1;  // 00000000000000000000000000000010
-      reduce_dim |= 1 << 2;  // 00000000000000000000000000000110
-      bool keep_dim = false;
-      nvinfer1::ReduceOperation reduce_type = nvinfer1::ReduceOperation::kSUM;
-      auto* reduce_sum_layer = TRT_ENGINE_ADD_LAYER(
-          engine_, Reduce, *Attn, reduce_type, reduce_dim, keep_dim);
-      // reduce_sum_layer->getOutput(0)->setType(reduce_sum_layer->getInput(0)->getType());
-      auto* Reduced = reduce_sum_layer->getOutput(0);
       std::vector<nvinfer1::ITensor*> itensors = {
           Reduced, X, Mask, NewMask, word_id, pos_id, mask_id};
-      layer = engine_->AddDynamicPlugin(itensors.data(), 7, plugin);
+      layer = engine_->AddDynamicPlugin(
+          itensors.data(), itensors.size(), plugin);  // inputs'number: 7
       layer->getOutput(0)->setName(output_name.c_str());
       engine_->SetITensor(output_name, layer->getOutput(0));
@@ -87,10 +87,13 @@ class FusedTokenPruneOpConverter : public OpConverter {
       layer->getOutput(4)->setName("mask_id_after_token_prune");
       engine_->SetITensor("mask_id", layer->getOutput(4));
     } else {
-      std::vector<nvinfer1::ITensor*> itensors = {Attn, X, Mask, NewMask};
-      layer = engine_->AddDynamicPlugin(itensors.data(), 4, plugin);
+      std::vector<nvinfer1::ITensor*> itensors = {Reduced, X, Mask, NewMask};
+      layer = engine_->AddDynamicPlugin(
+          itensors.data(), itensors.size(), plugin);  // inputs'number: 4
       layer->getOutput(0)->setName(output_name.c_str());
       engine_->SetITensor(output_name, layer->getOutput(0));
       layer->getOutput(1)->setName(out_inds_name.c_str());
       engine_->SetITensor(out_inds_name, layer->getOutput(1));
     }
...
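A side note on the reduce_dim bitmask built in the converter above (my reading of the TensorRT convention, not text from the commit): IReduceLayer's axes argument is a bitmask in which bit i selects dimension i, so setting bits 1 and 2 sums the (batch, head, token_length, token_length) attention over the head axis and the first token_length axis, leaving a (batch, token_length) score per token, which matches the comment in the converter. A minimal sketch of the same mask arithmetic:

#include <cassert>
#include <cstdint>

int main() {
  uint32_t reduce_dim = 0;
  reduce_dim |= 1u << 1;   // reduce over dim 1 (head)
  reduce_dim |= 1u << 2;   // reduce over dim 2 (first token_length axis)
  assert(reduce_dim == 6); // 0b110: dims 1 and 2 selected
  return 0;
}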
@@ -31,150 +31,6 @@ namespace inference {
namespace tensorrt {
namespace plugin {
template <typename T>
__global__ void ElementwiseMask(const T* a,
const T* b,
T* res,
int num_elements) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= num_elements) return;
const T zero = 0;
res[tid] = b[tid] >= zero ? a[tid] : zero;
#endif
}
template <typename T>
__global__ void FillZero(T* data, int len) {
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= len) return;
const T zero = 0;
data[tid] = zero;
}
__global__ void FillIndex(int32_t* indices, int num_raws, int num_cols) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= num_raws * num_cols) return;
int col = tid % num_cols;
int raw = tid / num_cols;
indices[tid] = col;
}
template <typename T>
__global__ void MaximumFirst(T* mat, int num_raws, int num_cols, T max_value) {
auto raw = blockIdx.x * blockDim.x + threadIdx.x;
if (raw >= num_raws) return;
mat[raw * num_cols] = max_value;
}
__global__ void FillOffsets(int* offsets, int num_raws, int num_cols) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid > num_raws) return;
offsets[tid] = tid * num_cols;
}
template <typename T>
__global__ void Slice(
const T* src, T* dst, int num_raws, int src_num_cols, int dst_num_cols) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= num_raws * dst_num_cols) return;
int raw = tid / dst_num_cols;
int col = tid % dst_num_cols;
dst[tid] = src[raw * src_num_cols + col];
}
template <typename T>
__global__ void ReduceSum2(
const T* src, T* dst, int bsz, int nb_head, int max_seq_len) {
int tid = threadIdx.x;
int bid = blockIdx.x;
int num_blocks_per_head = ((max_seq_len / blockDim.x) * max_seq_len);
int batch = bid / (nb_head * num_blocks_per_head);
int col = bid % max_seq_len;
int head = (bid / num_blocks_per_head) % nb_head;
extern __shared__ T res_float[];
res_float[tid] =
src[batch * (nb_head * max_seq_len * max_seq_len) +
head * (max_seq_len * max_seq_len) + col + tid * max_seq_len];
__syncthreads();
for (int offset = blockDim.x >> 1; offset > 0; offset >>= 1) {
if (tid < offset) {
res_float[tid] += res_float[tid + offset];
}
__syncthreads();
if (offset % 2 == 1 && tid == offset - 2) {
res_float[tid] += res_float[tid + 1];
}
}
if (tid == 0) {
auto* dst_addr = dst + batch * max_seq_len + col;
atomicAdd(dst_addr, res_float[0]);
}
}
template <>
__global__ void ReduceSum2<half>(
const half* src, half* dst, int bsz, int nb_head, int max_seq_len) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int tid = threadIdx.x;
int bid = blockIdx.x;
int num_blocks_per_head = ((max_seq_len / blockDim.x) * max_seq_len);
int batch = bid / (nb_head * num_blocks_per_head);
int col = bid % max_seq_len;
int head = (bid / num_blocks_per_head) % nb_head;
extern __shared__ half res_half[];
res_half[tid] =
src[batch * (nb_head * max_seq_len * max_seq_len) +
head * (max_seq_len * max_seq_len) + col + tid * max_seq_len];
__syncthreads();
for (int offset = blockDim.x >> 1; offset > 0; offset >>= 1) {
if (tid < offset) {
res_half[tid] += res_half[tid + offset];
}
__syncthreads();
if (offset % 2 == 1 && tid == offset - 2) {
res_half[tid] += res_half[tid + 1];
}
__syncthreads();
}
if (tid == 0) {
phi::fastAtomicAdd<platform::float16>(
reinterpret_cast<platform::float16*>(dst),
static_cast<size_t>(batch * max_seq_len + col),
static_cast<size_t>(bsz * max_seq_len),
static_cast<platform::float16>(res_half[0]));
}
#endif
}
template <typename T>
__global__ void TakeAlongAxis(const T* src,
T* dst,
int32_t* indices,
int num_raws,
int src_num_cols,
int dst_num_cols,
int num_elements) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= num_raws * dst_num_cols) return;
int raw = tid / dst_num_cols;
int col = tid % dst_num_cols;
for (int i = 0; i < num_elements; ++i) {
dst[tid * num_elements + i] =
*(src + (raw * src_num_cols + indices[tid]) * num_elements + i);
}
}
__global__ void compute_token_length(const int32_t* src,
                                     int32_t* dst,
                                     float scale) {
@@ -182,16 +38,18 @@ __global__ void compute_token_length(const int32_t* src,
  dst[it] = max(static_cast<int>((src[it + 1] - src[it]) * scale), 1);
}
+template <typename T>
__global__ void fill_index_padding_score(int32_t* token_index,
-                                         const half* scores,
-                                         int32_t scores_size,
-                                         half* padding_scores) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  token_index[tid] = threadIdx.x;
-  if (tid < scores_size) {
-    padding_scores[tid] = scores[tid];
+                                         const T* scores,
+                                         int32_t sequnce_length,
+                                         T* padding_scores) {
+  int padding_scores_it = threadIdx.x + blockIdx.x * blockDim.x;
+  int scores_it = threadIdx.x + blockIdx.x * sequnce_length;
+  token_index[padding_scores_it] = threadIdx.x;
+  if (threadIdx.x < sequnce_length) {
+    padding_scores[padding_scores_it] = scores[scores_it];
  } else {
-    padding_scores[tid] = 0;
+    padding_scores[padding_scores_it] = 0;
  }
}
@@ -238,21 +96,64 @@ __global__ void general_topk_pair_sort(T* in_keys, int32_t* in_out_values) {
      .Store(in_out_values + block_offset, thread_values);
}
-__global__ void varlen_prune_token(const half* tokens,
-                                   const int32_t* token_pos,
-                                   const int32_t* token_index,
-                                   half* output) {
+__global__ void varlen_prune_token_change_order(
+    const half* tokens,
+    const int32_t* token_pos,
+    const int32_t padding_token_length,
+    const int32_t* token_index,
+    half* output) {
  int batch = blockIdx.x;
  int token_it = batch * gridDim.y + blockIdx.y;
  int pre_value_it =
      token_it * gridDim.z * blockDim.x + blockIdx.z * blockDim.x + threadIdx.x;
-  if (token_index[token_it] < token_pos[batch + 1] - token_pos[batch]) {
-    output[(token_index[token_it] + token_pos[batch]) * gridDim.z * blockDim.x +
+  int token_index_it = batch * padding_token_length + blockIdx.y;
+  if (token_index[token_index_it] < token_pos[batch + 1] - token_pos[batch]) {
+    output[(token_index[token_index_it] + token_pos[batch]) * gridDim.z *
+               blockDim.x +
           blockIdx.z * blockDim.x + threadIdx.x] = tokens[pre_value_it];
  }
}
template <typename T>
__global__ void prune_token_change_order(const T* tokens,
int32_t new_sequnce_length,
const int32_t padding_token_length,
const int32_t* token_index,
T* output) {
int batch = blockIdx.x;
int token_it = batch * gridDim.y + blockIdx.y;
int pre_value_it =
token_it * gridDim.z * blockDim.x + blockIdx.z * blockDim.x + threadIdx.x;
int token_index_it = batch * padding_token_length + blockIdx.y;
if (token_index[token_index_it] < new_sequnce_length) {
output[(batch * new_sequnce_length + token_index[token_index_it]) *
gridDim.z * blockDim.x +
blockIdx.z * blockDim.x + threadIdx.x] = tokens[pre_value_it];
}
}
template <typename T>
__global__ void prune_token_keep_order(const T* tokens,
int32_t pre_sequnce_length,
int32_t new_sequnce_length,
const int32_t padding_token_length,
const int32_t* token_index,
T* output) {
int batch = blockIdx.x;
int index = 0;
for (int i = 0; i < pre_sequnce_length; ++i) {
if (token_index[batch * padding_token_length + i] < new_sequnce_length) {
output[(batch * new_sequnce_length + index) * gridDim.y * blockDim.x +
blockIdx.y * blockDim.x + threadIdx.x] =
tokens[(batch * pre_sequnce_length + i) * gridDim.y * blockDim.x +
blockIdx.y * blockDim.x + threadIdx.x];
index++;
}
}
}
nvinfer1::DimsExprs FusedTokenPrunePluginDynamic::getOutputDimensions(
    int output_index,
    const nvinfer1::DimsExprs* inputs,
@@ -353,7 +254,7 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
                                        "should be half for varseqlen."));
      }
    } else if (pos == 6 || pos == 11) {  // mask_id, mask_id_out
-      return (in.type == nvinfer1::DataType::kFLOAT) &&
+      return (in.type == nvinfer1::DataType::kHALF) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
    } else {
      return in.type == nvinfer1::DataType::kINT32 &&
@@ -364,7 +265,6 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
      if (with_fp16_) {
        return (in.type == nvinfer1::DataType::kHALF) &&
               (in.format == nvinfer1::TensorFormat::kLINEAR);
-
      } else {
        return (in.type == nvinfer1::DataType::kFLOAT) &&
               (in.format == nvinfer1::TensorFormat::kLINEAR);
@@ -373,8 +273,7 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
        const nvinfer1::PluginTensorDesc& prev = in_out[0];
        return in.type == prev.type && in.format == prev.format;
      } else {
-        return in.type == nvinfer1::DataType::kINT32 &&
-               in.format == nvinfer1::TensorFormat::kLINEAR;
+        return in.format == nvinfer1::TensorFormat::kLINEAR;
      }
    }
  }
@@ -425,199 +324,6 @@ size_t FusedTokenPrunePluginDynamic::getWorkspaceSize(
  return size;
}
template <typename T>
inline void enqueueImpl(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs,
void* const* outputs,
void* workspace_ptr,
cudaStream_t stream,
int device_id,
T max_value,
bool keep_first_token_,
bool keep_order_) {
// Dims
auto attn_dims = input_desc[0].dims;
auto x_dims = input_desc[1].dims;
auto new_mask_dims = input_desc[3].dims;
auto bsz = attn_dims.d[0], nb_head = attn_dims.d[1],
max_seq_len = attn_dims.d[2];
auto c = x_dims.d[2];
auto slimmed_x_len = new_mask_dims.d[2];
// Inputs
const T* attn_data = static_cast<const T*>(inputs[0]);
const T* x_data = static_cast<const T*>(inputs[1]);
const T* mask_data = static_cast<const T*>(inputs[2]);
// Outputs
T* output_data = static_cast<T*>(outputs[0]);
int32_t* output_indices_data = static_cast<int32_t*>(outputs[1]);
int total = bsz * nb_head * max_seq_len * max_seq_len;
int block = operators::ComputeBlockSize(max_seq_len);
int grid = operators::CeilDivide(total, block);
// Workspace for intermediate variable
char* workspace = static_cast<char*>(workspace_ptr);
T* attn_tmp_data = reinterpret_cast<T*>(workspace);
size_t offset = total * sizeof(T);
T* attn_accu_data = reinterpret_cast<T*>(workspace + offset);
offset += bsz * max_seq_len * sizeof(T);
int32_t* attn_accu_indices_data =
reinterpret_cast<int32_t*>(workspace + offset);
offset += bsz * max_seq_len * sizeof(int32_t);
T* sort_attn_accu_data = reinterpret_cast<T*>(workspace + offset);
offset += bsz * max_seq_len * sizeof(T);
int32_t* sort_attn_accu_indices_data =
reinterpret_cast<int32_t*>(workspace + offset);
offset += bsz * max_seq_len * sizeof(int32_t);
int* offsets_data = reinterpret_cast<int*>(workspace + offset);
offset += (bsz + 1) * sizeof(int);
int32_t* slimmed_sort_attn_accu_indices_data =
reinterpret_cast<int32_t*>(workspace + offset);
// 1. Filter attn by mask
ElementwiseMask<T>
<<<grid, block, 0, stream>>>(attn_data, mask_data, attn_tmp_data, total);
total = bsz * max_seq_len;
block = operators::ComputeBlockSize(max_seq_len);
grid = operators::CeilDivide(total, block);
FillZero<T><<<grid, block, 0, stream>>>(attn_accu_data, total);
// 2. Reduce sum
total = bsz * nb_head * max_seq_len * max_seq_len;
int block_tmp = max_seq_len;
while (block_tmp > 1024)
block_tmp /= 2; // if max seq len > 1024, it must be 2^n
block =
block_tmp; // make sure max_seq_len is an integral multiple of block_size
grid = operators::CeilDivide(total, block);
ReduceSum2<T><<<grid, block, block * sizeof(T), stream>>>(
attn_tmp_data, attn_accu_data, bsz, nb_head, max_seq_len);
// 3. Prepare token indices
total = bsz * max_seq_len;
block = operators::ComputeBlockSize(max_seq_len);
grid = operators::CeilDivide(total, block);
FillIndex<<<grid, block, 0, stream>>>(
attn_accu_indices_data, bsz, max_seq_len);
// 4. Sort token indices by attn
if (keep_first_token_) {
MaximumFirst<T>
<<<bsz, 1, 0, stream>>>(attn_accu_data, bsz, max_seq_len, max_value);
}
size_t temp_storage_bytes = -1;
int num_items = bsz * max_seq_len;
int num_segments = bsz;
FillOffsets<<<bsz + 1, 1, 0, stream>>>(offsets_data, bsz, max_seq_len);
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortPairsDescending(
nullptr,
temp_storage_bytes,
attn_accu_data,
sort_attn_accu_data,
attn_accu_indices_data,
sort_attn_accu_indices_data,
num_items,
num_segments,
offsets_data,
offsets_data + 1,
0,
sizeof(T) * 8,
stream));
int64_t temp_size = temp_storage_bytes;
phi::DenseTensor temp_storage;
auto* temp_storage_data = temp_storage.mutable_data<uint8_t>(
{temp_size}, platform::CUDAPlace(device_id));
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortPairsDescending(
temp_storage_data,
temp_storage_bytes,
attn_accu_data,
sort_attn_accu_data,
attn_accu_indices_data,
sort_attn_accu_indices_data,
num_items,
num_segments,
offsets_data,
offsets_data + 1,
0,
sizeof(T) * 8,
stream));
// 5. Slice
total = bsz * slimmed_x_len;
block = operators::ComputeBlockSize(slimmed_x_len);
grid = operators::CeilDivide(total, block);
Slice<int32_t>
<<<grid, block, 0, stream>>>(sort_attn_accu_indices_data,
slimmed_sort_attn_accu_indices_data,
bsz,
max_seq_len,
slimmed_x_len);
if (keep_order_) {
// 6. reorder
num_items = bsz * slimmed_x_len;
FillOffsets<<<bsz + 1, 1, 0, stream>>>(offsets_data, bsz, slimmed_x_len);
temp_storage_bytes = -1;
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortKeys(
nullptr,
temp_storage_bytes,
slimmed_sort_attn_accu_indices_data,
output_indices_data,
num_items,
num_segments,
offsets_data,
offsets_data + 1,
0,
sizeof(int32_t) * 8,
stream));
temp_size = temp_storage_bytes;
temp_storage.Resize({temp_size});
temp_storage_data =
temp_storage.mutable_data<uint8_t>(platform::CUDAPlace(device_id));
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortKeys(
temp_storage_data,
temp_storage_bytes,
slimmed_sort_attn_accu_indices_data,
output_indices_data,
num_items,
num_segments,
offsets_data,
offsets_data + 1,
0,
sizeof(int32_t) * 8,
stream));
TakeAlongAxis<T><<<grid, block, 0, stream>>>(x_data,
output_data,
output_indices_data,
bsz,
max_seq_len,
slimmed_x_len,
c);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(output_indices_data,
slimmed_sort_attn_accu_indices_data,
bsz * slimmed_x_len * sizeof(int32_t),
cudaMemcpyDeviceToDevice));
TakeAlongAxis<T>
<<<grid, block, 0, stream>>>(x_data,
output_data,
slimmed_sort_attn_accu_indices_data,
bsz,
max_seq_len,
slimmed_x_len,
c);
}
}
int FusedTokenPrunePluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc* input_desc,
    const nvinfer1::PluginTensorDesc* output_desc,
@@ -628,49 +334,56 @@ int FusedTokenPrunePluginDynamic::enqueue(
  if (flag_varseqlen_) {
    if (!(input_desc[0].type == nvinfer1::DataType::kHALF &&
          input_desc[1].type == nvinfer1::DataType::kHALF)) {
-      PADDLE_THROW(
-          platform::errors::InvalidArgument("Token_prune'type must half"));
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Token_prune'type must half for varseqlen"));
    }
    float scale =
-        static_cast<float>(input_desc[3].dims.d[2]) / input_desc[6].dims.d[1];
+        static_cast<float>(input_desc[3].dims.d[2]) / input_desc[2].dims.d[2];
-    const int32_t* inputs5 =
+    const int32_t* input5 =
        static_cast<const int32_t*>(inputs[5]);  // pre pos id
-    int32_t* outputs3 = static_cast<int32_t*>(outputs[3]);  // new pos id
-    half* outputs0 = static_cast<half*>(outputs[0]);
+    int32_t* output3 = static_cast<int32_t*>(outputs[3]);  // new pos id
+    half* output0 = static_cast<half*>(outputs[0]);
    const int32_t B = input_desc[1].dims.d[0];  // batchs
    const int32_t max_sequnce_length =
        input_desc[1].dims.d[1];  // max sequnce length
-    const int32_t length = input_desc[1].dims.d[2];  // vector length
+    const int32_t length = input_desc[1].dims.d[2];  // hidden size
    const half* scores = static_cast<const half*>(inputs[0]);  // reduce sum
    const half* tokens = static_cast<const half*>(inputs[1]);
-    const int32_t scores_size = B * max_sequnce_length;
    int32_t padding_token_length;
-    if (max_sequnce_length <= 128) {
+    if (max_sequnce_length <= 64) {
+      padding_token_length = 64;
+    } else if (max_sequnce_length <= 128) {
      padding_token_length = 128;
    } else if (max_sequnce_length <= 256) {
      padding_token_length = 256;
    } else if (max_sequnce_length <= 384) {
      padding_token_length = 384;
+    } else if (max_sequnce_length <= 512) {
+      padding_token_length = 512;
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
-          "Token_prune'token_length must <= 384"));
+          "Token_prune'token_length must <= 512"));
    }
    // 1. Compute the token length after pruning.
    compute_token_length<<<1, B, 0, stream>>>(
-        inputs5, pruned_token_lengths_, scale);
+        input5, pruned_token_lengths_, scale);
-    fill_index_padding_score<<<B, padding_token_length, 0, stream>>>(
-        token_index_, scores, scores_size, padding_scores_);
+    // 2. Padding scores
+    fill_index_padding_score<half><<<B, padding_token_length, 0, stream>>>(
+        token_index_,
+        scores,
+        max_sequnce_length,
+        static_cast<half*>(padding_scores_));
+    // 3. compute new pos id
    // Determine temporary device storage requirements
    void* d_temp_storage = NULL;
    size_t temp_storage_bytes = 0;
    cub::DeviceScan::ExclusiveSum(d_temp_storage,
                                  temp_storage_bytes,
                                  pruned_token_lengths_,
-                                  outputs3,
+                                  output3,
                                  B + 1);
    // Allocate temporary storage
    cudaMalloc(&d_temp_storage, temp_storage_bytes);
@@ -679,20 +392,28 @@ int FusedTokenPrunePluginDynamic::enqueue(
    cub::DeviceScan::ExclusiveSum(d_temp_storage,
                                  temp_storage_bytes,
                                  pruned_token_lengths_,
-                                  outputs3,
+                                  output3,
                                  B + 1);
-    if (padding_token_length == 128) {
-      general_topk_pair_sort<half, 32, 4>
-          <<<B, 32, 0, stream>>>(padding_scores_, token_index_);  // 128
+    // 4. sort scores
+    if (padding_token_length == 64) {
+      general_topk_pair_sort<half, 32, 2><<<B, 32, 0, stream>>>(
+          static_cast<half*>(padding_scores_), token_index_);  // 64
+    } else if (padding_token_length == 128) {
+      general_topk_pair_sort<half, 32, 4><<<B, 32, 0, stream>>>(
+          static_cast<half*>(padding_scores_), token_index_);  // 128
    } else if (padding_token_length == 256) {
-      general_topk_pair_sort<half, 64, 4>
-          <<<B, 64, 0, stream>>>(padding_scores_, token_index_);  // 256
+      general_topk_pair_sort<half, 64, 4><<<B, 64, 0, stream>>>(
+          static_cast<half*>(padding_scores_), token_index_);  // 256
+    } else if (padding_token_length == 384) {
+      general_topk_pair_sort<half, 96, 4><<<B, 96, 0, stream>>>(
+          static_cast<half*>(padding_scores_), token_index_);  // 384
    } else {
-      general_topk_pair_sort<half, 96, 4>
-          <<<B, 96, 0, stream>>>(padding_scores_, token_index_);  // 384
+      general_topk_pair_sort<half, 128, 4><<<B, 128, 0, stream>>>(
+          static_cast<half*>(padding_scores_), token_index_);  // 512
    }
+    // 5. compute output
    int32_t num_threads;
    if (length < 1024) {
      num_threads = length;
@@ -723,46 +444,196 @@ int FusedTokenPrunePluginDynamic::enqueue(
        B,
        max_sequnce_length,
        length / num_threads);  // batchs, max_sequnce_length, vector_ength/***
-    varlen_prune_token<<<num_blocks, num_threads, 0, stream>>>(
-        tokens, outputs3, token_index_, outputs0);
+    varlen_prune_token_change_order<<<num_blocks, num_threads, 0, stream>>>(
+        tokens, output3, padding_token_length, token_index_, output0);
  } else {
    auto input_type = input_desc[0].type;
-    auto attn_dims = input_desc[0].dims;
-    auto bsz = attn_dims.d[0], nb_head = attn_dims.d[1],
-         max_seq_len = attn_dims.d[2];
-    int device_id;
-    cudaGetDevice(&device_id);
+    const int32_t B = input_desc[1].dims.d[0];  // batchs
+    const int32_t pre_sequnce_length = input_desc[1].dims.d[1];
+    const int32_t new_sequnce_length = input_desc[3].dims.d[2];  // new mask
+    const int32_t length = input_desc[1].dims.d[2];  // hidden size
    if (input_type == nvinfer1::DataType::kFLOAT) {
      VLOG(1) << "TRT Plugin DataType selected. FusedTokenPrune-->fp32";
const float* scores = static_cast<const float*>(inputs[0]); // reduce sum
const float* tokens = static_cast<const float*>(inputs[1]); // X
float* output0 = static_cast<float*>(outputs[0]);
int32_t padding_token_length;
if (pre_sequnce_length <= 64) {
padding_token_length = 64;
} else if (pre_sequnce_length <= 128) {
padding_token_length = 128;
} else if (pre_sequnce_length <= 256) {
padding_token_length = 256;
} else if (pre_sequnce_length <= 384) {
padding_token_length = 384;
} else if (pre_sequnce_length <= 512) {
padding_token_length = 512;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Token_prune'token_length must <= 512"));
}
-      float max = std::numeric_limits<float>::max();
-      enqueueImpl<float>(input_desc,
-                         output_desc,
-                         inputs,
-                         outputs,
-                         workspace,
-                         stream,
-                         device_id,
-                         max,
-                         keep_first_token_,
-                         keep_order_);
+      // 1. Padding scores
+      fill_index_padding_score<float><<<B, padding_token_length, 0, stream>>>(
+          token_index_,
+          scores,
+          pre_sequnce_length,
+          static_cast<float*>(padding_scores_));
+      // 2. sort scores
+      if (padding_token_length == 64) {
+        general_topk_pair_sort<float, 32, 2><<<B, 32, 0, stream>>>(
+            static_cast<float*>(padding_scores_), token_index_);  // 64
+      } else if (padding_token_length == 128) {
general_topk_pair_sort<float, 32, 4><<<B, 32, 0, stream>>>(
static_cast<float*>(padding_scores_), token_index_); // 128
} else if (padding_token_length == 256) {
general_topk_pair_sort<float, 64, 4><<<B, 64, 0, stream>>>(
static_cast<float*>(padding_scores_), token_index_); // 256
} else if (padding_token_length == 384) {
general_topk_pair_sort<float, 96, 4><<<B, 96, 0, stream>>>(
static_cast<float*>(padding_scores_), token_index_); // 384
} else {
general_topk_pair_sort<float, 128, 4><<<B, 128, 0, stream>>>(
static_cast<float*>(padding_scores_), token_index_); // 512
}
// 3. compute output
int32_t num_threads;
if (length < 1024) {
num_threads = length;
} else {
if (length % 512 == 0) {
num_threads = 512;
} else if (length % 256 == 0) {
num_threads = 256;
} else if (length % 128 == 0) {
num_threads = 128;
} else if (length % 64 == 0) {
num_threads = 64;
} else if (length % 32 == 0) {
num_threads = 32;
} else if (length % 16 == 0) {
num_threads = 16;
} else if (length % 8 == 0) {
num_threads = 8;
} else if (length % 4 == 0) {
num_threads = 4;
} else if (length % 2 == 0) {
num_threads = 2;
} else {
num_threads = 1;
}
}
if (keep_order_) {
const dim3 num_blocks(B, length / num_threads);
prune_token_keep_order<float>
<<<num_blocks, num_threads, 0, stream>>>(tokens,
pre_sequnce_length,
new_sequnce_length,
padding_token_length,
token_index_,
output0);
} else {
const dim3 num_blocks(B, pre_sequnce_length, length / num_threads);
prune_token_change_order<float>
<<<num_blocks, num_threads, 0, stream>>>(tokens,
new_sequnce_length,
padding_token_length,
token_index_,
output0);
}
    } else if (input_type == nvinfer1::DataType::kHALF) {
      VLOG(1) << "TRT Plugin DataType selected. FusedTokenPrune-->fp16";
const half* scores = static_cast<const half*>(inputs[0]); // reduce sum
const half* tokens = static_cast<const half*>(inputs[1]); // X
half* output0 = static_cast<half*>(outputs[0]);
int32_t padding_token_length;
if (pre_sequnce_length <= 64) {
padding_token_length = 64;
} else if (pre_sequnce_length <= 128) {
padding_token_length = 128;
} else if (pre_sequnce_length <= 256) {
padding_token_length = 256;
} else if (pre_sequnce_length <= 384) {
padding_token_length = 384;
} else if (pre_sequnce_length <= 512) {
padding_token_length = 512;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Token_prune'token_length must <= 512"));
}
// 1. Padding scores
fill_index_padding_score<half><<<B, padding_token_length, 0, stream>>>(
token_index_,
scores,
pre_sequnce_length,
static_cast<half*>(padding_scores_));
// 2. sort scores
if (padding_token_length == 64) {
general_topk_pair_sort<half, 32, 2><<<B, 32, 0, stream>>>(
static_cast<half*>(padding_scores_), token_index_); // 64
} else if (padding_token_length == 128) {
general_topk_pair_sort<half, 32, 4><<<B, 32, 0, stream>>>(
static_cast<half*>(padding_scores_), token_index_); // 128
} else if (padding_token_length == 256) {
general_topk_pair_sort<half, 64, 4><<<B, 64, 0, stream>>>(
static_cast<half*>(padding_scores_), token_index_); // 256
} else if (padding_token_length == 384) {
general_topk_pair_sort<half, 96, 4><<<B, 96, 0, stream>>>(
static_cast<half*>(padding_scores_), token_index_); // 384
} else {
general_topk_pair_sort<half, 128, 4><<<B, 128, 0, stream>>>(
static_cast<half*>(padding_scores_), token_index_); // 512
}
-      half max = 65504.0;
-      enqueueImpl<half>(input_desc,
-                        output_desc,
-                        inputs,
-                        outputs,
-                        workspace,
-                        stream,
-                        device_id,
-                        max,
-                        keep_first_token_,
-                        keep_order_);
+      // 3. compute output
+      int32_t num_threads;
+      if (length < 1024) {
+        num_threads = length;
+      } else {
+        if (length % 512 == 0) {
+          num_threads = 512;
+        } else if (length % 256 == 0) {
+          num_threads = 256;
+        } else if (length % 128 == 0) {
+          num_threads = 128;
} else if (length % 64 == 0) {
num_threads = 64;
} else if (length % 32 == 0) {
num_threads = 32;
} else if (length % 16 == 0) {
num_threads = 16;
} else if (length % 8 == 0) {
num_threads = 8;
} else if (length % 4 == 0) {
num_threads = 4;
} else if (length % 2 == 0) {
num_threads = 2;
} else {
num_threads = 1;
}
}
if (keep_order_) {
const dim3 num_blocks(B, length / num_threads);
prune_token_keep_order<half>
<<<num_blocks, num_threads, 0, stream>>>(tokens,
pre_sequnce_length,
new_sequnce_length,
padding_token_length,
token_index_,
output0);
} else {
const dim3 num_blocks(B, pre_sequnce_length, length / num_threads);
prune_token_change_order<half>
<<<num_blocks, num_threads, 0, stream>>>(tokens,
new_sequnce_length,
padding_token_length,
token_index_,
output0);
}
    } else {
      PADDLE_THROW(
          platform::errors::Fatal("The FusedTokenPrune TRT Plugin's input type "
...
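One detail worth calling out from the enqueue paths above: the sequence length is padded up to one of the buckets 64/128/256/384/512, and each bucket is sorted with general_topk_pair_sort<T, BLOCK, ITEMS_PER_THREAD> launched with BLOCK threads, where BLOCK * ITEMS_PER_THREAD equals the bucket (32*2=64, 32*4=128, 64*4=256, 96*4=384, 128*4=512). A minimal sketch of the bucket selection, assuming the same limits as the diff (the helper name PaddingBucket is hypothetical):

#include <cstdint>
#include <stdexcept>

// Round a sequence length up to the smallest supported padding bucket.
int32_t PaddingBucket(int32_t seq_len) {
  for (int32_t bucket : {64, 128, 256, 384, 512}) {
    if (seq_len <= bucket) return bucket;
  }
  throw std::invalid_argument("Token_prune'token_length must <= 512");
}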
@@ -93,12 +93,33 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
                      int nb_outputs) TRT_NOEXCEPT override {
    max_batchs_ = in[1].max.d[0];
    max_token_length_ = in[1].max.d[1];
int32_t padding_token_length;
if (max_token_length_ <= 64) {
padding_token_length = 64;
} else if (max_token_length_ <= 128) {
padding_token_length = 128;
} else if (max_token_length_ <= 256) {
padding_token_length = 256;
} else if (max_token_length_ <= 384) {
padding_token_length = 384;
} else if (max_token_length_ <= 512) {
padding_token_length = 512;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Token_prune'token_length(max) must <= 512"));
}
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc(&pruned_token_lengths_,
                                          (max_batchs_ + 1) * sizeof(int32_t)));
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc(
-        &token_index_, max_batchs_ * max_token_length_ * sizeof(int32_t)));
+        &token_index_, max_batchs_ * padding_token_length * sizeof(int32_t)));
int32_t type_size = 4;
if (in[0].desc.type == nvinfer1::DataType::kHALF) {
type_size = 2;
} else {
type_size = 4;
}
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc(
-        &padding_scores_, max_batchs_ * max_token_length_ * sizeof(half)));
+        &padding_scores_, max_batchs_ * padding_token_length * type_size));
  }

  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
@@ -129,7 +150,7 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
  int32_t* token_index_;
  int32_t max_batchs_;
  int32_t max_token_length_;
-  half* padding_scores_;
+  void* padding_scores_;
};

class FusedTokenPrunePluginDynamicCreator : public nvinfer1::IPluginCreator {
...
@@ -352,24 +352,24 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test {
    ctx_->PartialInitWithAllocator();

    std::map<std::string, std::vector<int>> min_input_shape = {
-        {"attn", {4, 1, 4, 4}},
+        {"attn", {4, 4}},
        {"x", {4, 4, 1}},
        {"mask", {4, 1, 4, 4}},
        {"new_mask", {4, 1, 2, 2}}};
    std::map<std::string, std::vector<int>> max_input_shape = {
-        {"attn", {4, 1, 4, 4}},
+        {"attn", {4, 4}},
        {"x", {4, 4, 1}},
        {"mask", {4, 1, 4, 4}},
        {"new_mask", {4, 1, 2, 2}}};
    std::map<std::string, std::vector<int>> optim_input_shape = {
-        {"attn", {4, 1, 4, 4}},
+        {"attn", {4, 4}},
        {"x", {4, 4, 1}},
        {"mask", {4, 1, 4, 4}},
        {"new_mask", {4, 1, 2, 2}}};

    engine_ = new TensorRTEngine(16,
                                 1 << 10,
-                                 AnalysisConfig::Precision::kHalf,
+                                 AnalysisConfig::Precision::kFloat32,
                                 nullptr,
                                 0,
                                 min_input_shape,
@@ -391,7 +391,7 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test {
    }
  }

-  void PrepareInputOutput(const std::vector<std::vector<float16>> inputs,
+  void PrepareInputOutput(const std::vector<std::vector<float>> inputs,
                          std::vector<std::vector<int>> output_shapes) {
    LOG(INFO) << "PrepareInputOutput";
    int num_inputs = inputs.size();
@@ -423,15 +423,15 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
#if IS_TRT_VERSION_GE(8000)
  tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt();
  auto *attn = engine_->DeclareInput(
-      "attn", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 4, 4});
+      "attn", nvinfer1::DataType::kFLOAT, nvinfer1::Dims2{-1, 4});
  auto *x = engine_->DeclareInput(
-      "x", nvinfer1::DataType::kHALF, nvinfer1::Dims3{-1, 4, 1});
+      "x", nvinfer1::DataType::kFLOAT, nvinfer1::Dims3{-1, 4, 1});
  auto *mask = engine_->DeclareInput(
-      "mask", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 4, 4});
+      "mask", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4{-1, 1, 4, 4});
  auto *new_mask = engine_->DeclareInput(
-      "new_mask", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 2, 2});
+      "new_mask", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4{-1, 1, 2, 2});
  plugin::FusedTokenPrunePluginDynamic *plugin =
-      new plugin::FusedTokenPrunePluginDynamic(true,
+      new plugin::FusedTokenPrunePluginDynamic(/*with_fp16*/ false,
                                               /*keep_first_token*/ false,
                                               /*keep_order*/ true,
                                               /*flag_varseqlen*/ false);
@@ -449,18 +449,215 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
  ASSERT_EQ(engine_->engine()->getNbBindings(), 6);
  LOG(INFO) << "create input";
-  std::vector<float16> attn_v(64);
+  std::vector<float> attn_v(16);
for (int j = 0; j < 4; ++j) {
for (int k = 0; k < 4; ++k) {
attn_v[j * 4 + k] = k;
}
}
std::vector<float> x_v(16);
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 4; ++j) {
x_v[i * 4 + j] = 4 - j;
}
}
std::vector<float> mask_v(64);
  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 4; ++j) {
      for (int k = 0; k < 4; ++k) {
-        attn_v[i * 16 + j * 4 + k] = k;
+        mask_v[i * 16 + j * 4 + k] = 1;
      }
    }
  }
std::vector<float> new_mask_v(16);
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 2; ++j) {
for (int k = 0; k < 2; ++k) {
new_mask_v[i * 4 + j * 2 + k] = 1;
}
}
}
LOG(INFO) << "create output";
std::vector<int> out_slimmed_x_shape{4, 2, 1};
std::vector<int> out_cls_ins_shape{4, 2};
PrepareInputOutput({attn_v, x_v, mask_v, new_mask_v},
{out_slimmed_x_shape, out_cls_ins_shape});
auto *attn_gpu_data = inputs_[0].mutable_data<float>(ctx_->GetPlace());
auto *x_gpu_data = inputs_[1].mutable_data<float>(ctx_->GetPlace());
auto *mask_gpu_data = inputs_[2].mutable_data<float>(ctx_->GetPlace());
auto *new_mask_gpu_data = inputs_[3].mutable_data<float>(ctx_->GetPlace());
auto *slimmed_x_gpu_data = outputs_[0].mutable_data<float>(ctx_->GetPlace());
auto *cls_inds_gpu_data = outputs_[1].mutable_data<int32_t>(ctx_->GetPlace());
LOG(INFO) << "create buffers";
std::vector<void *> buffers(6);
buffers[0] = reinterpret_cast<void *>(attn_gpu_data);
buffers[1] = reinterpret_cast<void *>(x_gpu_data);
buffers[2] = reinterpret_cast<void *>(mask_gpu_data);
buffers[3] = reinterpret_cast<void *>(new_mask_gpu_data);
buffers[4] = reinterpret_cast<void *>(slimmed_x_gpu_data);
buffers[5] = reinterpret_cast<void *>(cls_inds_gpu_data);
LOG(INFO) << "Execute";
engine_->Execute(4, &buffers, ctx_->stream());
std::vector<float> slimmed_x_v(8);
std::vector<int32_t> cls_inds_v;
LOG(INFO) << "GetOutput";
GetOutput(slimmed_x_v, cls_inds_v);
// slimmed_x_v: [[4,3,2,1],[4,3,2,1],[4,3,2,1],[4,3,2,1]] ->
// [[2,1],[2,1],[2,1],[2,1]]
ASSERT_EQ(slimmed_x_v[0], 2);
ASSERT_EQ(slimmed_x_v[1], 1);
ASSERT_EQ(slimmed_x_v[2], 2);
ASSERT_EQ(slimmed_x_v[3], 1);
ASSERT_EQ(slimmed_x_v[4], 2);
ASSERT_EQ(slimmed_x_v[5], 1);
ASSERT_EQ(slimmed_x_v[6], 2);
ASSERT_EQ(slimmed_x_v[7], 1);
LOG(INFO) << "finish";
#endif
}
class TensorRTDynamicTestFusedTokenPruneHalf : public ::testing::Test {
protected:
void SetUp() override {
ctx_ = new phi::GPUContext(platform::CUDAPlace(0));
ctx_->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(platform::CUDAPlace(0), ctx_->stream())
.get());
ctx_->SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
ctx_->SetZeroAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(platform::CUDAPlace(0))
.get());
ctx_->SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
ctx_->PartialInitWithAllocator();
std::map<std::string, std::vector<int>> min_input_shape = {
{"attn", {4, 4}},
{"x", {4, 4, 1}},
{"mask", {4, 1, 4, 4}},
{"new_mask", {4, 1, 2, 2}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"attn", {4, 4}},
{"x", {4, 4, 1}},
{"mask", {4, 1, 4, 4}},
{"new_mask", {4, 1, 2, 2}}};
std::map<std::string, std::vector<int>> optim_input_shape = {
{"attn", {4, 4}},
{"x", {4, 4, 1}},
{"mask", {4, 1, 4, 4}},
{"new_mask", {4, 1, 2, 2}}};
engine_ = new TensorRTEngine(16,
1 << 10,
AnalysisConfig::Precision::kHalf,
nullptr,
0,
min_input_shape,
max_input_shape,
optim_input_shape,
std::map<std::string, std::vector<int>>(),
std::map<std::string, std::vector<int>>(),
std::map<std::string, std::vector<int>>(),
false,
phi::DataType::FLOAT16,
NaiveLogger::Global());
engine_->InitNetwork();
}
void TearDown() override {
if (engine_) {
delete engine_;
engine_ = nullptr;
}
}
void PrepareInputOutput(const std::vector<std::vector<float16>> inputs,
std::vector<std::vector<int>> output_shapes) {
LOG(INFO) << "PrepareInputOutput";
int num_inputs = inputs.size();
int num_outputs = output_shapes.size();
inputs_.resize(num_inputs);
outputs_.resize(num_outputs);
for (int i = 0; i < num_inputs; ++i) {
paddle::framework::TensorFromVector(inputs[i], *ctx_, &inputs_[i]);
}
for (int i = 0; i < num_outputs; ++i) {
outputs_[i].Resize(phi::make_ddim(output_shapes[i]));
}
}
void GetOutput(std::vector<float> &slimmed_x, // NOLINT
std::vector<int32_t> &cls_inds) { // NOLINT
paddle::framework::TensorToVector(outputs_[0], *ctx_, &slimmed_x);
paddle::framework::TensorToVector(outputs_[1], *ctx_, &cls_inds);
}
protected:
std::vector<phi::DenseTensor> inputs_;
std::vector<phi::DenseTensor> outputs_;
TensorRTEngine *engine_;
phi::GPUContext *ctx_;
};
TEST_F(TensorRTDynamicTestFusedTokenPruneHalf, test_fused_token_prune) {
#if IS_TRT_VERSION_GE(8000)
tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt();
auto *attn = engine_->DeclareInput(
"attn", nvinfer1::DataType::kHALF, nvinfer1::Dims2{-1, 4});
auto *x = engine_->DeclareInput(
"x", nvinfer1::DataType::kHALF, nvinfer1::Dims3{-1, 4, 1});
auto *mask = engine_->DeclareInput(
"mask", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 4, 4});
auto *new_mask = engine_->DeclareInput(
"new_mask", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 2, 2});
plugin::FusedTokenPrunePluginDynamic *plugin =
new plugin::FusedTokenPrunePluginDynamic(/*with_fp16*/ true,
/*keep_first_token*/ false,
/*keep_order*/ true,
/*flag_varseqlen*/ false);
std::vector<nvinfer1::ITensor *> itensors = {attn, x, mask, new_mask};
auto *layer = engine_->AddDynamicPlugin(itensors.data(), 4, plugin);
PADDLE_ENFORCE_NOT_NULL(layer,
platform::errors::InvalidArgument(
"TRT fused_token_prune layer building failed."));
std::vector<std::string> output_tensor_names{"out_slimmed_x", "out_cls_inds"};
for (size_t i = 0; i < 2; i++) {
layer->getOutput(i)->setName(output_tensor_names[i].c_str());
engine_->DeclareOutput(layer, i, output_tensor_names[i]);
}
engine_->FreezeNetwork();
ASSERT_EQ(engine_->engine()->getNbBindings(), 6);
LOG(INFO) << "create input";
std::vector<float16> attn_v(16);
for (int j = 0; j < 4; ++j) {
for (int k = 0; k < 4; ++k) {
attn_v[j * 4 + k] = k;
}
}
  std::vector<float16> x_v(16);
  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 4; ++j) {
-      x_v[i * 4 + j] = 1;
+      x_v[i * 4 + j] = 4 - j;
    }
  }
  std::vector<float16> mask_v(64);
@@ -509,20 +706,24 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
  engine_->Execute(4, &buffers, ctx_->stream());

-  std::vector<float> slimmed_x_v;
+  std::vector<float> slimmed_x_v(8);
  std::vector<int32_t> cls_inds_v;
  LOG(INFO) << "GetOutput";
  GetOutput(slimmed_x_v, cls_inds_v);

-  ASSERT_EQ(cls_inds_v[0], 2);
-  ASSERT_EQ(cls_inds_v[1], 3);
-  ASSERT_EQ(cls_inds_v[2], 2);
-  ASSERT_EQ(cls_inds_v[3], 3);
-  ASSERT_EQ(cls_inds_v[4], 2);
-  ASSERT_EQ(cls_inds_v[5], 3);
-  ASSERT_EQ(cls_inds_v[6], 2);
-  ASSERT_EQ(cls_inds_v[7], 3);
+  // slimmed_x_v: [[4,3,2,1],[4,3,2,1],[4,3,2,1],[4,3,2,1]] ->
+  // [[2,1],[2,1],[2,1],[2,1]]
+  ASSERT_EQ(slimmed_x_v[0], 2);
+  ASSERT_EQ(slimmed_x_v[1], 1);
+  ASSERT_EQ(slimmed_x_v[2], 2);
+  ASSERT_EQ(slimmed_x_v[3], 1);
+  ASSERT_EQ(slimmed_x_v[4], 2);
+  ASSERT_EQ(slimmed_x_v[5], 1);
+  ASSERT_EQ(slimmed_x_v[6], 2);
+  ASSERT_EQ(slimmed_x_v[7], 1);

  LOG(INFO) << "finish";
#endif
}
...
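For readers checking the new test expectations: in both tests x stores the row [4, 3, 2, 1] for every batch and attn stores the per-token scores [0, 1, 2, 3], so the two highest-scoring tokens are positions 3 and 2; with keep_order = true they are emitted in their original order, giving slimmed rows of [2, 1], which is what the ASSERT_EQ calls above check. A tiny self-contained sketch of that arithmetic (illustrative only, not test code from the commit):

#include <cassert>
#include <vector>

int main() {
  std::vector<float> scores = {0, 1, 2, 3};  // per-token scores from attn
  std::vector<float> x = {4, 3, 2, 1};       // one row of x
  std::vector<int> kept = {2, 3};            // top-2 indices, restored to order
  std::vector<float> slimmed;
  for (int i : kept) slimmed.push_back(x[i]);
  assert(slimmed[0] == 2 && slimmed[1] == 1);
  return 0;
}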