未验证 提交 29782728 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle Inference]optimize token prune for Paddle-TensorRT (#48241)

* optimize token prune
上级 d39f3fb6
...@@ -131,6 +131,21 @@ void Activation::operator()() { ...@@ -131,6 +131,21 @@ void Activation::operator()() {
// Add links for activation op. // Add links for activation op.
activation_op->LinksFrom({activation_input}).LinksTo({activation_out}); activation_op->LinksFrom({activation_input}).LinksTo({activation_out});
} }
void FusedTokenPrune::operator()() {
// Create nodes for fused_token_prune.
auto* fused_token_prune_input =
pattern->NewNode(fused_token_prune_input_repr())
->assert_is_op_input("fused_token_prune", "X");
auto* fused_token_prune_op = pattern->NewNode(fused_token_prune_op_repr())
->assert_is_op("fused_token_prune");
auto* fused_token_prune_output =
pattern->NewNode(fused_token_prune_output_repr())
->assert_is_op_output("fused_token_prune", "SlimmedX");
fused_token_prune_op->LinksFrom({fused_token_prune_input})
.LinksTo({fused_token_prune_output});
}
} // namespace patterns } // namespace patterns
void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
...@@ -563,6 +578,48 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { ...@@ -563,6 +578,48 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
}; };
gpd6(graph, handler6); gpd6(graph, handler6);
GraphPatternDetector gpd7;
patterns::FusedTokenPrune fused_token_prune(
gpd7.mutable_pattern(), "remove_padding_recover_padding_pass");
fused_token_prune();
auto handler7 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
"fused_token_prune";
GET_IR_NODE_FROM_SUBGRAPH(
fused_token_prune_input, fused_token_prune_input, fused_token_prune);
GET_IR_NODE_FROM_SUBGRAPH(
fused_token_prune_op, fused_token_prune_op, fused_token_prune);
GET_IR_NODE_FROM_SUBGRAPH(
fused_token_prune_output, fused_token_prune_output, fused_token_prune);
std::vector<int64_t> fused_token_prune_input_shape =
fused_token_prune_input->Var()->GetShape();
check_flag = true;
if (fused_token_prune_input_shape.size() !=
multihead_matmul_input_shape.size()) {
check_flag = false;
VLOG(3) << "Transformer model remove_padding shape check failed, return "
"remove_padding pass.";
return;
}
for (size_t i = 0; i < fused_token_prune_input_shape.size(); ++i) {
if (fused_token_prune_input_shape[i] != multihead_matmul_input_shape[i]) {
check_flag = false;
}
}
if (!check_flag) {
VLOG(3) << "Transformer model remove_padding shape check failed, return "
"remove_padding pass.";
return;
}
insert_recover_padding_op(fused_token_prune_op, fused_token_prune_output);
found_subgraph_count++;
};
gpd7(graph, handler7);
AddStatis(found_subgraph_count); AddStatis(found_subgraph_count);
} }
......
...@@ -95,7 +95,6 @@ struct Fc : public PatternBase { ...@@ -95,7 +95,6 @@ struct Fc : public PatternBase {
PATTERN_DECL_NODE(fc_input); PATTERN_DECL_NODE(fc_input);
PATTERN_DECL_NODE(fc_op); PATTERN_DECL_NODE(fc_op);
PATTERN_DECL_NODE(fc_out);
}; };
struct Activation : public PatternBase { struct Activation : public PatternBase {
...@@ -108,6 +107,17 @@ struct Activation : public PatternBase { ...@@ -108,6 +107,17 @@ struct Activation : public PatternBase {
PATTERN_DECL_NODE(activation_op); PATTERN_DECL_NODE(activation_op);
PATTERN_DECL_NODE(activation_out); PATTERN_DECL_NODE(activation_out);
}; };
struct FusedTokenPrune : public PatternBase {
FusedTokenPrune(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "fused_token_prune") {}
void operator()();
PATTERN_DECL_NODE(fused_token_prune_input);
PATTERN_DECL_NODE(fused_token_prune_op);
PATTERN_DECL_NODE(fused_token_prune_output);
};
} // namespace patterns } // namespace patterns
class RemovePaddingRecoverPaddingPass : public FusePassBase { class RemovePaddingRecoverPaddingPass : public FusePassBase {
......
...@@ -52,8 +52,21 @@ class FusedTokenPruneOpConverter : public OpConverter { ...@@ -52,8 +52,21 @@ class FusedTokenPruneOpConverter : public OpConverter {
auto* word_id = engine_->GetITensor("word_id"); auto* word_id = engine_->GetITensor("word_id");
auto* pos_id = engine_->GetITensor("pos_id"); auto* pos_id = engine_->GetITensor("pos_id");
auto* mask_id = engine_->GetITensor("mask_id"); auto* mask_id = engine_->GetITensor("mask_id");
// reduce_sum: (-1,headsize,token_length,token_length) ->
// (-1,token_length)
uint32_t reduce_dim = 0;
reduce_dim |= 1 << 1; // 00000000000000000000000000000010
reduce_dim |= 1 << 2; // 00000000000000000000000000000110
bool keep_dim = false;
nvinfer1::ReduceOperation reduce_type = nvinfer1::ReduceOperation::kSUM;
auto* reduce_sum_layer = TRT_ENGINE_ADD_LAYER(
engine_, Reduce, *Attn, reduce_type, reduce_dim, keep_dim);
// reduce_sum_layer->getOutput(0)->setType(reduce_sum_layer->getInput(0)->getType());
auto* Reduced = reduce_sum_layer->getOutput(0);
std::vector<nvinfer1::ITensor*> itensors = { std::vector<nvinfer1::ITensor*> itensors = {
Attn, X, Mask, NewMask, word_id, pos_id, mask_id}; Reduced, X, Mask, NewMask, word_id, pos_id, mask_id};
layer = engine_->AddDynamicPlugin(itensors.data(), 7, plugin); layer = engine_->AddDynamicPlugin(itensors.data(), 7, plugin);
layer->getOutput(0)->setName(output_name.c_str()); layer->getOutput(0)->setName(output_name.c_str());
......
...@@ -31,19 +31,15 @@ namespace inference { ...@@ -31,19 +31,15 @@ namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
#if IS_TRT_VERSION_GE(6000)
template <typename T> template <typename T>
__global__ void ElementwiseMask(const T* a, __global__ void ElementwiseMask(const T* a,
const T* b, const T* b,
T* res, T* res,
int num_elements) { int num_elements) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
auto tid = threadIdx.x + blockIdx.x * blockDim.x; auto tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= num_elements) return; if (tid >= num_elements) return;
const T zero = 0; const T zero = 0;
res[tid] = b[tid] >= zero ? a[tid] : zero; res[tid] = b[tid] >= zero ? a[tid] : zero;
#endif
} }
template <typename T> template <typename T>
...@@ -123,7 +119,6 @@ __global__ void ReduceSum2( ...@@ -123,7 +119,6 @@ __global__ void ReduceSum2(
template <> template <>
__global__ void ReduceSum2<half>( __global__ void ReduceSum2<half>(
const half* src, half* dst, int bsz, int nb_head, int max_seq_len) { const half* src, half* dst, int bsz, int nb_head, int max_seq_len) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int tid = threadIdx.x; int tid = threadIdx.x;
int bid = blockIdx.x; int bid = blockIdx.x;
int num_blocks_per_head = ((max_seq_len / blockDim.x) * max_seq_len); int num_blocks_per_head = ((max_seq_len / blockDim.x) * max_seq_len);
...@@ -155,7 +150,6 @@ __global__ void ReduceSum2<half>( ...@@ -155,7 +150,6 @@ __global__ void ReduceSum2<half>(
static_cast<size_t>(bsz * max_seq_len), static_cast<size_t>(bsz * max_seq_len),
static_cast<platform::float16>(res_half[0])); static_cast<platform::float16>(res_half[0]));
} }
#endif
} }
template <typename T> template <typename T>
...@@ -177,14 +171,81 @@ __global__ void TakeAlongAxis(const T* src, ...@@ -177,14 +171,81 @@ __global__ void TakeAlongAxis(const T* src,
} }
} }
__global__ void pos_id_prune_kernel(const int32_t* src, __global__ void compute_token_length(const int32_t* src,
int32_t* dst, int32_t* dst,
int pos_nums,
float scale) { float scale) {
dst[0] = 0; int32_t it = threadIdx.x;
for (int i = 1; i < pos_nums; i++) { dst[it] = max(static_cast<int>((src[it + 1] - src[it]) * scale), 1);
dst[i] = }
dst[i - 1] + max(static_cast<int>((src[i] - src[i - 1]) * scale), 2);
__global__ void fill_index_padding_score(int32_t* token_index,
const half* scores,
int32_t scores_size,
half* padding_scores) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
token_index[tid] = threadIdx.x;
if (tid < scores_size) {
padding_scores[tid] = scores[tid];
} else {
padding_scores[tid] = 0;
}
}
template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD>
__global__ void general_topk_pair_sort(T* in_keys, int32_t* in_out_values) {
typedef cub::BlockRadixSort<T, BLOCK_THREADS, ITEMS_PER_THREAD, int>
BlockRadixSort;
typedef cub::
BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
BlockLoadKey;
typedef cub::
BlockLoad<int, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
BlockLoadValue;
typedef cub::
BlockStore<T, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_TRANSPOSE>
BlockStoreKey;
typedef cub::BlockStore<int,
BLOCK_THREADS,
ITEMS_PER_THREAD,
cub::BLOCK_STORE_TRANSPOSE>
BlockStoreValue;
__shared__ union {
typename BlockRadixSort::TempStorage sort;
typename BlockLoadKey::TempStorage loadkey;
typename BlockLoadValue::TempStorage loadvalue;
typename BlockStoreKey::TempStorage storekey;
typename BlockStoreValue::TempStorage storevalue;
} temp_storage;
int block_offset = blockIdx.x * BLOCK_THREADS * ITEMS_PER_THREAD;
T thread_keys[ITEMS_PER_THREAD];
int thread_values[ITEMS_PER_THREAD];
BlockLoadKey(temp_storage.loadkey).Load(in_keys + block_offset, thread_keys);
BlockLoadValue(temp_storage.loadvalue)
.Load(in_out_values + block_offset, thread_values);
__syncthreads();
BlockRadixSort(temp_storage.sort).SortDescending(thread_keys, thread_values);
__syncthreads();
BlockStoreValue(temp_storage.storevalue)
.Store(in_out_values + block_offset, thread_values);
}
__global__ void varlen_prune_token(const half* tokens,
const int32_t* token_pos,
const int32_t* token_index,
half* output) {
int batch = blockIdx.x;
int token_it = batch * gridDim.y + blockIdx.y;
int pre_value_it =
token_it * gridDim.z * blockDim.x + blockIdx.z * blockDim.x + threadIdx.x;
if (token_index[token_it] < token_pos[batch + 1] - token_pos[batch]) {
output[(token_index[token_it] + token_pos[batch]) * gridDim.z * blockDim.x +
blockIdx.z * blockDim.x + threadIdx.x] = tokens[pre_value_it];
} }
} }
...@@ -195,9 +256,29 @@ nvinfer1::DimsExprs FusedTokenPrunePluginDynamic::getOutputDimensions( ...@@ -195,9 +256,29 @@ nvinfer1::DimsExprs FusedTokenPrunePluginDynamic::getOutputDimensions(
nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT { nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT {
auto x_dims = inputs[1], new_mask_dims = inputs[3]; auto x_dims = inputs[1], new_mask_dims = inputs[3];
if (flag_varseqlen_) { if (flag_varseqlen_) {
// max sum of seqlen: ceil(sum / scale) + n -1 >= for(i=0;i<n;i++) {sum +=
// floor(num(i) / scale)} auto
// pruned_sum_length=std::ceil(inputs[4].d[0]*new_mask_dims.d[2]/inputs[6].d[1])+
// inputs[1].d[0] - 1;
auto pruned_sum_length = expr_builder.operation(
nvinfer1::DimensionOperation::kSUB,
*expr_builder.operation(
nvinfer1::DimensionOperation::kSUM,
*expr_builder.operation(
nvinfer1::DimensionOperation::kCEIL_DIV,
*expr_builder.operation(nvinfer1::DimensionOperation::kPROD,
*inputs[4].d[0],
*new_mask_dims.d[2]),
*inputs[6].d[1]),
*inputs[1].d[0]),
*expr_builder.constant(1));
if (output_index == 0) { if (output_index == 0) {
nvinfer1::DimsExprs ret = x_dims; nvinfer1::DimsExprs ret;
ret.d[1] = new_mask_dims.d[2]; ret.nbDims = 4;
ret.d[0] = pruned_sum_length;
ret.d[1] = x_dims.d[2];
ret.d[2] = expr_builder.constant(1);
ret.d[3] = expr_builder.constant(1);
return ret; return ret;
} else if (output_index == 1) { } else if (output_index == 1) {
nvinfer1::DimsExprs ret; nvinfer1::DimsExprs ret;
...@@ -209,18 +290,7 @@ nvinfer1::DimsExprs FusedTokenPrunePluginDynamic::getOutputDimensions( ...@@ -209,18 +290,7 @@ nvinfer1::DimsExprs FusedTokenPrunePluginDynamic::getOutputDimensions(
// word id // word id
nvinfer1::DimsExprs ret; nvinfer1::DimsExprs ret;
ret.nbDims = 1; ret.nbDims = 1;
// max sum of seqlen: pre_seqlen * new_mask[2] / mask[1] + 2 * batchs ret.d[0] = pruned_sum_length;
const auto* two = expr_builder.constant(2);
ret.d[0] = expr_builder.operation(
nvinfer1::DimensionOperation::kSUM,
*expr_builder.operation(
nvinfer1::DimensionOperation::kFLOOR_DIV,
*expr_builder.operation(nvinfer1::DimensionOperation::kPROD,
*inputs[4].d[0],
*new_mask_dims.d[2]),
*inputs[6].d[1]),
*expr_builder.operation(
nvinfer1::DimensionOperation::kPROD, *two, *inputs[6].d[0]));
return ret; return ret;
} else if (output_index == 3) { } else if (output_index == 3) {
// pos id // pos id
...@@ -269,26 +339,18 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination( ...@@ -269,26 +339,18 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
const nvinfer1::PluginTensorDesc& in = in_out[pos]; const nvinfer1::PluginTensorDesc& in = in_out[pos];
if (flag_varseqlen_) { if (flag_varseqlen_) {
if (pos == 0) { if (pos <= 3 || pos == 7) {
if (with_fp16_) { if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE return (in.type == nvinfer1::DataType::kHALF) &&
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR); (in.format == nvinfer1::TensorFormat::kLINEAR);
#else
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
} else { } else {
return (in.type == nvinfer1::DataType::kFLOAT) && PADDLE_THROW(platform::errors::Fatal(
(in.format == nvinfer1::TensorFormat::kLINEAR); "The FusedTokenPrune TRT Plugin's input type "
"should be half for varseqlen."));
} }
} else if (pos <= 3 || pos == 7) {
const nvinfer1::PluginTensorDesc& prev = in_out[0];
return in.type == prev.type && in.format == prev.format;
} else if (pos == 6 || pos == 11) { // mask_id, mask_id_out } else if (pos == 6 || pos == 11) { // mask_id, mask_id_out
return in.type == nvinfer1::DataType::kFLOAT && return (in.type == nvinfer1::DataType::kFLOAT) &&
in.format == nvinfer1::TensorFormat::kLINEAR; (in.format == nvinfer1::TensorFormat::kLINEAR);
} else { } else {
return in.type == nvinfer1::DataType::kINT32 && return in.type == nvinfer1::DataType::kINT32 &&
in.format == nvinfer1::TensorFormat::kLINEAR; in.format == nvinfer1::TensorFormat::kLINEAR;
...@@ -296,14 +358,9 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination( ...@@ -296,14 +358,9 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
} else { } else {
if (pos == 0) { if (pos == 0) {
if (with_fp16_) { if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE return (in.type == nvinfer1::DataType::kHALF) &&
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR); (in.format == nvinfer1::TensorFormat::kLINEAR);
#else
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
} else { } else {
return (in.type == nvinfer1::DataType::kFLOAT) && return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR); (in.format == nvinfer1::TensorFormat::kLINEAR);
...@@ -324,9 +381,9 @@ nvinfer1::DataType FusedTokenPrunePluginDynamic::getOutputDataType( ...@@ -324,9 +381,9 @@ nvinfer1::DataType FusedTokenPrunePluginDynamic::getOutputDataType(
int nb_inputs) const TRT_NOEXCEPT { int nb_inputs) const TRT_NOEXCEPT {
if (flag_varseqlen_) { if (flag_varseqlen_) {
if (index == 0) { if (index == 0) {
return input_types[1]; return nvinfer1::DataType::kHALF;
} else if (index == 4) { } else if (index == 4) { // mask id
return nvinfer1::DataType::kFLOAT; return input_types[6];
} else { } else {
// index = 1,2,3 // index = 1,2,3
return nvinfer1::DataType::kINT32; return nvinfer1::DataType::kINT32;
...@@ -557,14 +614,6 @@ inline void enqueueImpl(const nvinfer1::PluginTensorDesc* input_desc, ...@@ -557,14 +614,6 @@ inline void enqueueImpl(const nvinfer1::PluginTensorDesc* input_desc,
} }
} }
inline void pos_id_prune(const int32_t* input,
int32_t* output,
int pos_nums,
float scale,
cudaStream_t stream) {
pos_id_prune_kernel<<<1, 1, 0, stream>>>(input, output, pos_nums, scale);
}
int FusedTokenPrunePluginDynamic::enqueue( int FusedTokenPrunePluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc* input_desc, const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc, const nvinfer1::PluginTensorDesc* output_desc,
...@@ -572,6 +621,107 @@ int FusedTokenPrunePluginDynamic::enqueue( ...@@ -572,6 +621,107 @@ int FusedTokenPrunePluginDynamic::enqueue(
void* const* outputs, void* const* outputs,
void* workspace, void* workspace,
cudaStream_t stream) TRT_NOEXCEPT { cudaStream_t stream) TRT_NOEXCEPT {
if (flag_varseqlen_) {
if (!(input_desc[0].type == nvinfer1::DataType::kHALF &&
input_desc[1].type == nvinfer1::DataType::kHALF)) {
PADDLE_THROW(
platform::errors::InvalidArgument("Token_prune'type must half"));
}
float scale =
static_cast<float>(input_desc[3].dims.d[2]) / input_desc[6].dims.d[1];
const int32_t* inputs5 =
static_cast<const int32_t*>(inputs[5]); // pre pos id
int32_t* outputs3 = static_cast<int32_t*>(outputs[3]); // new pos id
half* outputs0 = static_cast<half*>(outputs[0]);
const int32_t B = input_desc[1].dims.d[0]; // batchs
const int32_t max_sequnce_length =
input_desc[1].dims.d[1]; // max sequnce length
const int32_t length = input_desc[1].dims.d[2]; // vector length
const half* scores = static_cast<const half*>(inputs[0]); // reduce sum
const half* tokens = static_cast<const half*>(inputs[1]);
const int32_t scores_size = B * max_sequnce_length;
int32_t padding_token_length;
if (max_sequnce_length <= 128) {
padding_token_length = 128;
} else if (max_sequnce_length <= 256) {
padding_token_length = 256;
} else if (max_sequnce_length <= 384) {
padding_token_length = 384;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Token_prune'token_length must <= 384"));
}
// 1. Compute the token length after pruning.
compute_token_length<<<1, B, 0, stream>>>(
inputs5, pruned_token_lengths_, scale);
fill_index_padding_score<<<B, padding_token_length, 0, stream>>>(
token_index_, scores, scores_size, padding_scores_);
// Determine temporary device storage requirements
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceScan::ExclusiveSum(d_temp_storage,
temp_storage_bytes,
pruned_token_lengths_,
outputs3,
B + 1);
// Allocate temporary storage
cudaMalloc(&d_temp_storage, temp_storage_bytes);
// Run exclusive prefix sum
cub::DeviceScan::ExclusiveSum(d_temp_storage,
temp_storage_bytes,
pruned_token_lengths_,
outputs3,
B + 1);
if (padding_token_length == 128) {
general_topk_pair_sort<half, 32, 4>
<<<B, 32, 0, stream>>>(padding_scores_, token_index_); // 128
} else if (padding_token_length == 256) {
general_topk_pair_sort<half, 64, 4>
<<<B, 64, 0, stream>>>(padding_scores_, token_index_); // 256
} else {
general_topk_pair_sort<half, 96, 4>
<<<B, 96, 0, stream>>>(padding_scores_, token_index_); // 384
}
int32_t num_threads;
if (length < 1024) {
num_threads = length;
} else {
if (length % 512 == 0) {
num_threads = 512;
} else if (length % 256 == 0) {
num_threads = 256;
} else if (length % 128 == 0) {
num_threads = 128;
} else if (length % 64 == 0) {
num_threads = 64;
} else if (length % 32 == 0) {
num_threads = 32;
} else if (length % 16 == 0) {
num_threads = 16;
} else if (length % 8 == 0) {
num_threads = 8;
} else if (length % 4 == 0) {
num_threads = 4;
} else if (length % 2 == 0) {
num_threads = 2;
} else {
num_threads = 1;
}
}
const dim3 num_blocks(
B,
max_sequnce_length,
length / num_threads); // batchs, max_sequnce_length, vector_ength/***
varlen_prune_token<<<num_blocks, num_threads, 0, stream>>>(
tokens, outputs3, token_index_, outputs0);
} else {
auto input_type = input_desc[0].type; auto input_type = input_desc[0].type;
auto attn_dims = input_desc[0].dims; auto attn_dims = input_desc[0].dims;
auto bsz = attn_dims.d[0], nb_head = attn_dims.d[1], auto bsz = attn_dims.d[0], nb_head = attn_dims.d[1],
...@@ -596,11 +746,9 @@ int FusedTokenPrunePluginDynamic::enqueue( ...@@ -596,11 +746,9 @@ int FusedTokenPrunePluginDynamic::enqueue(
keep_order_); keep_order_);
} else if (input_type == nvinfer1::DataType::kHALF) { } else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
VLOG(1) << "TRT Plugin DataType selected. FusedTokenPrune-->fp16"; VLOG(1) << "TRT Plugin DataType selected. FusedTokenPrune-->fp16";
half max = 65504.0; half max = 65504.0;
enqueueImpl<half>(input_desc, enqueueImpl<half>(input_desc,
output_desc, output_desc,
inputs, inputs,
...@@ -611,34 +759,15 @@ int FusedTokenPrunePluginDynamic::enqueue( ...@@ -611,34 +759,15 @@ int FusedTokenPrunePluginDynamic::enqueue(
max, max,
keep_first_token_, keep_first_token_,
keep_order_); keep_order_);
#else
PADDLE_THROW(platform::errors::Fatal(
"The Ernie(Bert) TensorRT Plugin should be "
"complied with CUDA version >= 10.0 when running with fp16. "
"Please recomplie it or try to use fp32 by set "
"config.SetTRTDynamicShapeInfo(min_input_shape, "
"max_input_shape, opt_input_shape, true"));
#endif
} else { } else {
PADDLE_THROW( PADDLE_THROW(
platform::errors::Fatal("The FusedTokenPrune TRT Plugin's input type " platform::errors::Fatal("The FusedTokenPrune TRT Plugin's input type "
"should be float or half.")); "should be float or half."));
} }
if (flag_varseqlen_) {
float scale =
static_cast<float>(input_desc[3].dims.d[2]) / input_desc[6].dims.d[1];
// outputs[2]=inputs[4]; // word_id
const int32_t* inputs5 = static_cast<const int32_t*>(inputs[5]);
int32_t* outputs3 = static_cast<int32_t*>(outputs[3]);
pos_id_prune(
inputs5, outputs3, input_desc[5].dims.d[0], scale, stream); // pos_id
// outputs[4]=inputs[6]; // new_mask
} }
return cudaGetLastError() != cudaSuccess; return cudaGetLastError() != cudaSuccess;
} }
#endif
} // namespace plugin } // namespace plugin
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -30,11 +31,10 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT { ...@@ -30,11 +31,10 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
bool keep_first_token, bool keep_first_token,
bool keep_order, bool keep_order,
bool flag_varseqlen) bool flag_varseqlen)
: keep_first_token_(keep_first_token), : with_fp16_(with_fp16),
keep_first_token_(keep_first_token),
keep_order_(keep_order), keep_order_(keep_order),
flag_varseqlen_(flag_varseqlen) { flag_varseqlen_(flag_varseqlen) {}
with_fp16_ = with_fp16;
}
FusedTokenPrunePluginDynamic(void const* serial_data, size_t serial_length) { FusedTokenPrunePluginDynamic(void const* serial_data, size_t serial_length) {
DeserializeValue(&serial_data, &serial_length, &with_fp16_); DeserializeValue(&serial_data, &serial_length, &with_fp16_);
DeserializeValue(&serial_data, &serial_length, &keep_first_token_); DeserializeValue(&serial_data, &serial_length, &keep_first_token_);
...@@ -42,8 +42,14 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT { ...@@ -42,8 +42,14 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
DeserializeValue(&serial_data, &serial_length, &flag_varseqlen_); DeserializeValue(&serial_data, &serial_length, &flag_varseqlen_);
} }
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
return new FusedTokenPrunePluginDynamic( FusedTokenPrunePluginDynamic* ptr = new FusedTokenPrunePluginDynamic(
with_fp16_, keep_first_token_, keep_order_, flag_varseqlen_); with_fp16_, keep_first_token_, keep_order_, flag_varseqlen_);
ptr->max_batchs_ = max_batchs_;
ptr->max_token_length_ = max_token_length_;
ptr->pruned_token_lengths_ = pruned_token_lengths_;
ptr->token_index_ = token_index_;
ptr->padding_scores_ = padding_scores_;
return ptr;
} }
const char* getPluginType() const TRT_NOEXCEPT override { const char* getPluginType() const TRT_NOEXCEPT override {
...@@ -84,7 +90,16 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT { ...@@ -84,7 +90,16 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nb_inputs, int nb_inputs,
const nvinfer1::DynamicPluginTensorDesc* out, const nvinfer1::DynamicPluginTensorDesc* out,
int nb_outputs) TRT_NOEXCEPT override {} int nb_outputs) TRT_NOEXCEPT override {
max_batchs_ = in[1].max.d[0];
max_token_length_ = in[1].max.d[1];
PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc(&pruned_token_lengths_,
(max_batchs_ + 1) * sizeof(int32_t)));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc(
&token_index_, max_batchs_ * max_token_length_ * sizeof(int32_t)));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc(
&padding_scores_, max_batchs_ * max_token_length_ * sizeof(half)));
}
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nb_inputs, int nb_inputs,
...@@ -106,9 +121,15 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT { ...@@ -106,9 +121,15 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
void destroy() TRT_NOEXCEPT override { delete this; } void destroy() TRT_NOEXCEPT override { delete this; }
private: private:
bool with_fp16_;
bool keep_first_token_; bool keep_first_token_;
bool keep_order_; bool keep_order_;
bool flag_varseqlen_; bool flag_varseqlen_;
int32_t* pruned_token_lengths_;
int32_t* token_index_;
int32_t max_batchs_;
int32_t max_token_length_;
half* padding_scores_;
}; };
class FusedTokenPrunePluginDynamicCreator : public nvinfer1::IPluginCreator { class FusedTokenPrunePluginDynamicCreator : public nvinfer1::IPluginCreator {
......
...@@ -19,9 +19,9 @@ namespace inference { ...@@ -19,9 +19,9 @@ namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
__global__ void RecoverPaddingKernel(const float* input0, __global__ void RecoverPaddingKernel(const half* input0,
const int32_t* input1, const int32_t* input1,
float* output) { half* output) {
int word_id = blockIdx.x * gridDim.y + blockIdx.y; int word_id = blockIdx.x * gridDim.y + blockIdx.y;
int32_t seqence_length = input1[blockIdx.x + 1] - input1[blockIdx.x]; int32_t seqence_length = input1[blockIdx.x + 1] - input1[blockIdx.x];
if (blockIdx.y < seqence_length) { if (blockIdx.y < seqence_length) {
...@@ -79,7 +79,7 @@ bool RecoverPaddingPlugin::supportsFormatCombination( ...@@ -79,7 +79,7 @@ bool RecoverPaddingPlugin::supportsFormatCombination(
return inOut[pos].type == nvinfer1::DataType::kFLOAT && return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
} else { } else {
return inOut[pos].type == nvinfer1::DataType::kFLOAT && return inOut[pos].type == nvinfer1::DataType::kHALF &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
} }
// return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format // return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format
...@@ -114,38 +114,43 @@ int RecoverPaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, ...@@ -114,38 +114,43 @@ int RecoverPaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const auto input0_desc = inputDesc[0]; const auto input0_desc = inputDesc[0];
const auto input1_desc = inputDesc[1]; const auto input1_desc = inputDesc[1];
const auto input2_desc = inputDesc[2]; const auto input2_desc = inputDesc[2];
const float* input0 = static_cast<const float*>(inputs[0]); const half* input0 = static_cast<const half*>(inputs[0]);
const int32_t* input1 = const int32_t* input1 =
static_cast<const int32_t*>(inputs[1]); // pos_id_tensor static_cast<const int32_t*>(inputs[1]); // pos_id_tensor
float* output = static_cast<float*>(outputs[0]); half* output = static_cast<half*>(outputs[0]);
const int32_t vector_length = input0_desc.dims.d[1];
int32_t num_threads; int32_t num_threads;
if (input0_desc.dims.d[1] % 512 == 0) { if (vector_length < 1024) {
num_threads = vector_length;
} else {
if (vector_length % 512 == 0) {
num_threads = 512; num_threads = 512;
} else if (input0_desc.dims.d[1] % 256 == 0) { } else if (vector_length % 256 == 0) {
num_threads = 256; num_threads = 256;
} else if (input0_desc.dims.d[1] % 128 == 0) { } else if (vector_length % 128 == 0) {
num_threads = 128; num_threads = 128;
} else if (input0_desc.dims.d[1] % 64 == 0) { } else if (vector_length % 64 == 0) {
num_threads = 64; num_threads = 64;
} else if (input0_desc.dims.d[1] % 32 == 0) { } else if (vector_length % 32 == 0) {
num_threads = 32; num_threads = 32;
} else if (input0_desc.dims.d[1] % 16 == 0) { } else if (vector_length % 16 == 0) {
num_threads = 16; num_threads = 16;
} else if (input0_desc.dims.d[1] % 8 == 0) { } else if (vector_length % 8 == 0) {
num_threads = 8; num_threads = 8;
} else if (input0_desc.dims.d[1] % 4 == 0) { } else if (vector_length % 4 == 0) {
num_threads = 4; num_threads = 4;
} else if (input0_desc.dims.d[1] % 2 == 0) { } else if (vector_length % 2 == 0) {
num_threads = 2; num_threads = 2;
} else { } else {
num_threads = 1; num_threads = 1;
} }
}
const dim3 num_blocks( const dim3 num_blocks(
input1_desc.dims.d[0] - 1, input1_desc.dims.d[0] - 1,
input2_desc.dims.d[1], input2_desc.dims.d[1],
input0_desc.dims.d[1] / num_threads); // batchs, max sequnce length vector_length / num_threads); // batchs, max sequnce length
// (mask_id.dims.d[1]), // (mask_id.dims.d[1]),
// input.dims.d[1]/256 // input.dims.d[1]/***
RecoverPaddingKernel<<<num_blocks, num_threads, 0, stream>>>( RecoverPaddingKernel<<<num_blocks, num_threads, 0, stream>>>(
input0, input1, output); input0, input1, output);
return cudaGetLastError() != cudaSuccess; return cudaGetLastError() != cudaSuccess;
......
...@@ -19,9 +19,9 @@ namespace inference { ...@@ -19,9 +19,9 @@ namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
__global__ void RemovePaddingKernel(const float* input0, __global__ void RemovePaddingKernel(const half* input0,
const int32_t* input1, const int32_t* input1,
float* output) { half* output) {
int word_id = blockIdx.x * gridDim.y + blockIdx.y; int word_id = blockIdx.x * gridDim.y + blockIdx.y;
int32_t seqence_length = input1[blockIdx.x + 1] - input1[blockIdx.x]; int32_t seqence_length = input1[blockIdx.x + 1] - input1[blockIdx.x];
if (blockIdx.y < seqence_length) { if (blockIdx.y < seqence_length) {
...@@ -73,7 +73,7 @@ bool RemovePaddingPlugin::supportsFormatCombination( ...@@ -73,7 +73,7 @@ bool RemovePaddingPlugin::supportsFormatCombination(
return inOut[pos].type == nvinfer1::DataType::kINT32 && return inOut[pos].type == nvinfer1::DataType::kINT32 &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
} else { } else {
return inOut[pos].type == nvinfer1::DataType::kFLOAT && return inOut[pos].type == nvinfer1::DataType::kHALF &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
} }
// return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format // return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format
...@@ -106,38 +106,43 @@ int RemovePaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, ...@@ -106,38 +106,43 @@ int RemovePaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
void* workspace, void* workspace,
cudaStream_t stream) TRT_NOEXCEPT { cudaStream_t stream) TRT_NOEXCEPT {
const auto input_desc = inputDesc[0]; const auto input_desc = inputDesc[0];
const float* input0 = static_cast<const float*>(inputs[0]); const half* input0 = static_cast<const half*>(inputs[0]);
const int32_t* input1 = const int32_t* input1 =
static_cast<const int32_t*>(inputs[1]); // pos_id_tensor static_cast<const int32_t*>(inputs[1]); // pos_id_tensor
float* output = static_cast<float*>(outputs[0]); half* output = static_cast<half*>(outputs[0]);
const auto input0_desc = inputDesc[0]; const auto input0_desc = inputDesc[0];
const int32_t vector_length = input0_desc.dims.d[2];
int32_t num_threads; int32_t num_threads;
if (input0_desc.dims.d[2] % 512 == 0) { if (vector_length < 1024) {
num_threads = vector_length;
} else {
if (vector_length % 512 == 0) {
num_threads = 512; num_threads = 512;
} else if (input0_desc.dims.d[2] % 256 == 0) { } else if (vector_length % 256 == 0) {
num_threads = 256; num_threads = 256;
} else if (input0_desc.dims.d[2] % 128 == 0) { } else if (vector_length % 128 == 0) {
num_threads = 128; num_threads = 128;
} else if (input0_desc.dims.d[2] % 64 == 0) { } else if (vector_length % 64 == 0) {
num_threads = 64; num_threads = 64;
} else if (input0_desc.dims.d[2] % 32 == 0) { } else if (vector_length % 32 == 0) {
num_threads = 32; num_threads = 32;
} else if (input0_desc.dims.d[2] % 16 == 0) { } else if (vector_length % 16 == 0) {
num_threads = 16; num_threads = 16;
} else if (input0_desc.dims.d[2] % 8 == 0) { } else if (vector_length % 8 == 0) {
num_threads = 8; num_threads = 8;
} else if (input0_desc.dims.d[2] % 4 == 0) { } else if (vector_length % 4 == 0) {
num_threads = 4; num_threads = 4;
} else if (input0_desc.dims.d[2] % 2 == 0) { } else if (vector_length % 2 == 0) {
num_threads = 2; num_threads = 2;
} else { } else {
num_threads = 1; num_threads = 1;
} }
}
const dim3 num_blocks( const dim3 num_blocks(
input0_desc.dims.d[0], input0_desc.dims.d[0],
input0_desc.dims.d[1], input0_desc.dims.d[1],
input0_desc.dims.d[2] / vector_length /
num_threads); // batchs, max sequnce length, input.dims.d[2]/256 num_threads); // batchs, max sequnce length, input0.dims.d[2]/***
RemovePaddingKernel<<<num_blocks, num_threads, 0, stream>>>( RemovePaddingKernel<<<num_blocks, num_threads, 0, stream>>>(
input0, input1, output); input0, input1, output);
......
...@@ -26,12 +26,9 @@ TEST(fused_token_prune_op_plugin, test_plugin) { ...@@ -26,12 +26,9 @@ TEST(fused_token_prune_op_plugin, test_plugin) {
/*keep_first_token*/ false, /*keep_first_token*/ false,
/*keep_order*/ true, /*keep_order*/ true,
/*flag_varseqlen*/ false); /*flag_varseqlen*/ false);
plugin.configurePlugin(nullptr, 4, nullptr, 2);
plugin.initialize(); plugin.initialize();
plugin.getPluginType(); plugin.getPluginType();
plugin.getNbOutputs(); plugin.getNbOutputs();
auto clone_plugin = plugin.clone();
clone_plugin->destroy();
size_t buf_size = plugin.getSerializationSize(); size_t buf_size = plugin.getSerializationSize();
std::vector<char> buf(buf_size); std::vector<char> buf(buf_size);
plugin.serialize(buf.data()); plugin.serialize(buf.data());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册