未验证 提交 24187fcb 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle Inference] add varlen_token_prune plugin, pass, convert (#44733)

* add varlen_token_prune plugin, pass, convert
上级 8482f1ae
......@@ -359,6 +359,7 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
std::vector<int64_t> skip_layernorm_x_shape =
skip_layernorm_x->Var()->GetShape();
check_flag = true;
if (skip_layernorm_x_shape.size() != multihead_matmul_input_shape.size()) {
check_flag = false;
VLOG(3) << "Transformer model remove_padding shape check failed, return "
......@@ -395,6 +396,7 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc_op, fc);
std::vector<int64_t> fc_input_shape = fc_input->Var()->GetShape();
check_flag = true;
if ((fc_input_shape.size() != multihead_matmul_input_shape.size()) ||
(fc_input_shape.size() != 3)) {
check_flag = false;
......@@ -446,11 +448,13 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
std::vector<int64_t> activation_input_shape =
activation_input->Var()->GetShape();
check_flag = true;
if ((activation_input_shape.size() !=
multihead_matmul_input_shape.size()) ||
(activation_input_shape.size() != 3)) {
check_flag = false;
VLOG(3) << "Transformer model remove_padding shape check failed, return "
VLOG(3) << "Activation: Transformer model remove_padding "
"shape(activation_input_shape.size()) check failed, return "
"remove_padding pass.";
return;
}
......@@ -465,7 +469,8 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
check_flag = false;
}
if (!check_flag) {
VLOG(3) << "Transformer model remove_padding shape check failed, return "
VLOG(3) << "Activation: Transformer model remove_padding "
"shape(activation_input_shape[i]) check failed, return "
"remove_padding pass.";
return;
}
......@@ -530,6 +535,7 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
std::vector<int64_t> skip_layernorm_x_shape =
preln_skip_layernorm_x->Var()->GetShape();
check_flag = true;
if (skip_layernorm_x_shape.size() != multihead_matmul_input_shape.size()) {
check_flag = false;
VLOG(3) << "Transformer model remove_padding shape check failed, return "
......
......@@ -60,6 +60,50 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
std::vector<std::string>{word_id_name, pos_id_name, sent_id_name};
emb_names =
std::vector<std::string>{word_emb_name, pos_emb_name, sent_emb_name};
auto mask_id_tensor = engine_->GetITensor("mask_id");
auto mask_dims = mask_id_tensor->getDimensions();
auto slice_start_dims = mask_dims;
auto slice_stride_dims = mask_dims;
for (int i = 0; i < mask_dims.nbDims; i++) {
slice_start_dims.d[i] = 0;
slice_stride_dims.d[i] = 1;
}
auto* shape_tensor = Shape(mask_id_tensor);
std::vector<nvinfer1::ITensor*> size_vec_tensor;
for (int i = 0; i < mask_dims.nbDims; i++) {
size_vec_tensor.push_back(Add1DConstantLayer(1));
}
size_vec_tensor[1] = GetEleTensorOfShape(shape_tensor, 1);
auto size_tensor = Concat(size_vec_tensor);
auto slice_layer =
TRT_ENGINE_ADD_LAYER(engine_,
Slice,
*mask_id_tensor,
slice_start_dims,
slice_start_dims,
slice_stride_dims); // unuseful slice_start_dims
slice_layer->setInput(2, *size_tensor);
slice_layer->setName(
("Embeltwise_slice_layer (Output: slice_max_seqlen " +
op_desc.Output("Out")[0] + ")")
.c_str());
engine_->SetTensorDynamicRange(slice_layer->getOutput(0), 1.0f);
auto* reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *slice_layer->getOutput(0));
nvinfer1::Dims shape_dim;
shape_dim.nbDims = 1;
shape_dim.d[0] = -1;
reshape_layer->setReshapeDimensions(shape_dim);
reshape_layer->setName(("Embeltwise_reshape_layer (Output: max_seqlen " +
op_desc.Output("Out")[0] + ")")
.c_str());
engine_->SetTensorDynamicRange(reshape_layer->getOutput(0), 1.0f);
engine_->SetITensor("max_seqlen_tensor", reshape_layer->getOutput(0));
} else {
id_names = op_desc.Input("Ids");
emb_names = op_desc.Input("Embs");
......@@ -192,20 +236,8 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
plugin_inputs.emplace_back(
engine_->GetITensor(pos_id_name)); // cu_seqlens,
// eval_placeholder_2
auto max_seqlen_tensor = engine_->GetITensor(mask_id_name);
auto* shuffle_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *max_seqlen_tensor);
nvinfer1::Dims shape_dim;
shape_dim.nbDims = 1;
shape_dim.d[0] = -1;
shuffle_layer->setReshapeDimensions(shape_dim);
shuffle_layer->setName(
("Embeltwise_Shuffle_reshape (Output: max_seqlen " +
op_desc.Output("Out")[0] + ")")
.c_str());
engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f);
plugin_inputs.emplace_back(
shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3
plugin_inputs.emplace_back(engine_->GetITensor(
"max_seqlen_tensor")); // max_seqlen, eval_placeholder_3
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomEmbLayerNormPluginDynamic", "2");
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......
......@@ -23,7 +23,6 @@ class FusedTokenPruneOpConverter : public OpConverter {
bool test_mode) override {
framework::OpDesc op_desc(op, nullptr);
nvinfer1::ILayer* layer = nullptr;
auto* Attn = engine_->GetITensor(op_desc.Input("Attn").front());
auto* X = engine_->GetITensor(op_desc.Input("X").front());
auto* Mask = engine_->GetITensor(op_desc.Input("Mask").front());
......@@ -36,28 +35,54 @@ class FusedTokenPruneOpConverter : public OpConverter {
op_desc.HasAttr("keep_order")
? PADDLE_GET_CONST(bool, op_desc.GetAttr("keep_order"))
: false;
std::vector<nvinfer1::ITensor*> itensors = {Attn, X, Mask, NewMask};
auto output_name = op_desc.Output("SlimmedX")[0];
auto out_inds_name = op_desc.Output("CLSInds")[0];
if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
with_fp16 = true;
}
bool flag_varseqlen = engine_->use_varseqlen();
plugin::FusedTokenPrunePluginDynamic* plugin =
new plugin::FusedTokenPrunePluginDynamic(
with_fp16, keep_first_token, keep_order);
layer = engine_->AddDynamicPlugin(itensors.data(), 4, plugin);
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"));
#endif
with_fp16, keep_first_token, keep_order, flag_varseqlen);
if (flag_varseqlen) {
auto* word_id = engine_->GetITensor("word_id");
auto* pos_id = engine_->GetITensor("pos_id");
auto* mask_id = engine_->GetITensor("mask_id");
std::vector<nvinfer1::ITensor*> itensors = {
Attn, X, Mask, NewMask, word_id, pos_id, mask_id};
layer = engine_->AddDynamicPlugin(itensors.data(), 7, plugin);
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0));
layer->getOutput(1)->setName(out_inds_name.c_str());
engine_->SetITensor(out_inds_name, layer->getOutput(1));
engine_->DeleteITensor("word_id", word_id);
layer->getOutput(2)->setName("word_id_after_token_prune");
engine_->SetITensor("word_id", layer->getOutput(2));
engine_->DeleteITensor("pos_id", pos_id);
layer->getOutput(3)->setName("pos_id_after_token_prune");
engine_->SetITensor("pos_id", layer->getOutput(3));
engine_->DeleteITensor("mask_id", mask_id);
layer->getOutput(4)->setName("mask_id_after_token_prune");
engine_->SetITensor("mask_id", layer->getOutput(4));
} else {
std::vector<nvinfer1::ITensor*> itensors = {Attn, X, Mask, NewMask};
layer = engine_->AddDynamicPlugin(itensors.data(), 4, plugin);
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0));
layer->getOutput(1)->setName(out_inds_name.c_str());
engine_->SetITensor(out_inds_name, layer->getOutput(1));
}
layer->setName(
("fused_token_prune(Output: " + output_name + ")").c_str());
} else {
PADDLE_THROW(platform::errors::Fatal(
"You are running the Ernie(Bert) model in static shape mode, which "
......@@ -65,8 +90,6 @@ class FusedTokenPruneOpConverter : public OpConverter {
"You can use the config.SetTRTDynamicShapeInfo(...) interface to set "
"the shape information to run the dynamic shape mode."));
}
RreplenishLayerAndOutput(
layer, "fused_token_prune", {output_name, out_inds_name}, test_mode);
}
};
......
......@@ -94,6 +94,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_data),
static_cast<int32_t>(bias_t->numel())};
auto max_seqlen_tensor = engine_->GetITensor("max_seqlen_tensor");
auto pos_id_tensor = engine_->GetITensor("pos_id");
if (engine_->with_interleaved()) {
VLOG(4) << "fused multihead_matmul op: use_varseqlen and "
"with_interleaved";
......@@ -154,31 +156,9 @@ class MultiheadMatMulOpConverter : public OpConverter {
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(fc_layer->getOutput(0));
if (engine_->Has("ernie_pos_name")) {
plugin_inputs.emplace_back(engine_->GetITensor(
engine_->Get<std::string>("ernie_pos_name")));
} else {
plugin_inputs.emplace_back(engine_->GetITensor(
engine_->network()
->getInput(2)
->getName())); // cu_seqlens, eval_placeholder_2
}
auto max_seqlen_tensor =
engine_->GetITensor(engine_->network()->getInput(3)->getName());
engine_->SetTensorDynamicRange(max_seqlen_tensor, 1.0f);
auto* shuffle_layer = TRT_ENGINE_ADD_LAYER(
engine_,
Shuffle,
*const_cast<nvinfer1::ITensor*>(max_seqlen_tensor));
nvinfer1::Dims shape_dim;
shape_dim.nbDims = 1;
shape_dim.d[0] = -1;
shuffle_layer->setReshapeDimensions(shape_dim);
engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f);
plugin_inputs.emplace_back(pos_id_tensor);
plugin_inputs.emplace_back(
shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3
shuffle_layer->setName(
("Multihead: Shuffle: (Output: " + output_name + ")").c_str());
max_seqlen_tensor); // max_seqlen, eval_placeholder_3
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
layer = plugin_layer;
......@@ -299,20 +279,9 @@ class MultiheadMatMulOpConverter : public OpConverter {
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(fc_layer->getOutput(0));
plugin_inputs.emplace_back(engine_->GetITensor("qkv_plugin_mask"));
plugin_inputs.emplace_back(engine_->GetITensor("pos_id"));
auto max_seqlen_tensor = engine_->GetITensor("mask_id");
auto* shuffle_layer = TRT_ENGINE_ADD_LAYER(
engine_,
Shuffle,
*const_cast<nvinfer1::ITensor*>(max_seqlen_tensor));
nvinfer1::Dims shape_dim;
shape_dim.nbDims = 1;
shape_dim.d[0] = -1;
shuffle_layer->setReshapeDimensions(shape_dim);
engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f);
plugin_inputs.emplace_back(pos_id_tensor);
plugin_inputs.emplace_back(
shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3
max_seqlen_tensor); // max_seqlen, eval_placeholder_3
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
......
......@@ -157,20 +157,47 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
plugin_inputs.emplace_back(
engine_->GetITensor(pos_id_name)); // cu_seqlens,
// eval_placeholder_2
auto max_seqlen_tensor = engine_->GetITensor(mask_id_name);
auto* shuffle_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *max_seqlen_tensor);
auto mask_id_tensor = engine_->GetITensor("mask_id");
auto mask_dims = mask_id_tensor->getDimensions();
auto slice_start_dims = mask_dims;
auto slice_size_dims = mask_dims;
auto slice_stride_dims = mask_dims;
for (int i = 0; i < mask_dims.nbDims; i++) {
slice_start_dims.d[i] = 0;
slice_size_dims.d[i] = 1;
slice_stride_dims.d[i] = 1;
}
slice_size_dims.d[1] = mask_dims.d[1];
auto* slice_size_tensor = Add1DConstantLayer(slice_size_dims);
auto slice_layer =
TRT_ENGINE_ADD_LAYER(engine_,
Slice,
*mask_id_tensor,
slice_start_dims,
slice_start_dims,
slice_stride_dims); // unuseful slice_start_dims
slice_layer->setInput(2, *slice_size_tensor);
slice_layer->setName(
("PrelnEmbeltwise_slice_layer (Output: slice_max_seqlen " +
op_desc.Output("Out")[0] + ")")
.c_str());
engine_->SetTensorDynamicRange(slice_layer->getOutput(0), 1.0f);
auto* reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *slice_layer->getOutput(0));
nvinfer1::Dims shape_dim;
shape_dim.nbDims = 1;
shape_dim.d[0] = -1;
shuffle_layer->setReshapeDimensions(shape_dim);
shuffle_layer->setName(
("PrelnEmbeltwise_Shuffle_reshape (Output: max_seqlen " +
op_desc.Output("Out_0")[0] + ")")
reshape_layer->setReshapeDimensions(shape_dim);
reshape_layer->setName(
("PrelnEmbeltwise_reshape_layer (Output: max_seqlen " +
op_desc.Output("Out")[0] + ")")
.c_str());
engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f);
engine_->SetTensorDynamicRange(reshape_layer->getOutput(0), 1.0f);
engine_->SetITensor("max_seqlen_tensor", reshape_layer->getOutput(0));
plugin_inputs.emplace_back(
shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3
reshape_layer->getOutput(0)); // max_seqlen, eval_placeholder_3
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomEmbLayerNormPluginDynamic", "3");
......
......@@ -111,6 +111,8 @@ class SparseMultiheadMatMulOpConverter : public OpConverter {
nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_data),
static_cast<int32_t>(bias_t->numel())};
auto max_seqlen_tensor = engine_->GetITensor("max_seqlen_tensor");
auto pos_id_tensor = engine_->GetITensor("pos_id");
if (engine_->with_interleaved()) {
VLOG(4) << "fused multihead_matmul op: use_varseqlen and "
"with_interleaved";
......@@ -171,31 +173,9 @@ class SparseMultiheadMatMulOpConverter : public OpConverter {
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(fc_layer->getOutput(0));
if (engine_->Has("ernie_pos_name")) {
plugin_inputs.emplace_back(engine_->GetITensor(
engine_->Get<std::string>("ernie_pos_name")));
} else {
plugin_inputs.emplace_back(engine_->GetITensor(
engine_->network()
->getInput(2)
->getName())); // cu_seqlens, eval_placeholder_2
}
auto max_seqlen_tensor =
engine_->GetITensor(engine_->network()->getInput(3)->getName());
engine_->SetTensorDynamicRange(max_seqlen_tensor, 1.0f);
auto* shuffle_layer = TRT_ENGINE_ADD_LAYER(
engine_,
Shuffle,
*const_cast<nvinfer1::ITensor*>(max_seqlen_tensor));
nvinfer1::Dims shape_dim;
shape_dim.nbDims = 1;
shape_dim.d[0] = -1;
shuffle_layer->setReshapeDimensions(shape_dim);
engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f);
plugin_inputs.emplace_back(pos_id_tensor);
plugin_inputs.emplace_back(
shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3
shuffle_layer->setName(
("Multihead: Shuffle: (Output: " + output_name + ")").c_str());
max_seqlen_tensor); // max_seqlen, eval_placeholder_3
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
layer = plugin_layer;
......@@ -316,21 +296,9 @@ class SparseMultiheadMatMulOpConverter : public OpConverter {
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(fc_layer->getOutput(0));
plugin_inputs.emplace_back(engine_->GetITensor("qkv_plugin_mask"));
plugin_inputs.emplace_back(engine_->GetITensor("pos_id"));
auto max_seqlen_tensor = engine_->GetITensor("mask_id");
auto* shuffle_layer = TRT_ENGINE_ADD_LAYER(
engine_,
Shuffle,
*const_cast<nvinfer1::ITensor*>(max_seqlen_tensor));
nvinfer1::Dims shape_dim;
shape_dim.nbDims = 1;
shape_dim.d[0] = -1;
shuffle_layer->setReshapeDimensions(shape_dim);
engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f);
plugin_inputs.emplace_back(pos_id_tensor);
plugin_inputs.emplace_back(
shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3
max_seqlen_tensor); // max_seqlen, eval_placeholder_3
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
layer = plugin_layer;
......
......@@ -410,6 +410,19 @@ void TensorRTEngine::DeclareOutput(const std::string &name) {
name));
network()->markOutput(*output);
}
void TensorRTEngine::DeleteITensor(const std::string &name,
nvinfer1::ITensor *tensor) {
PADDLE_ENFORCE_NOT_NULL(
tensor,
platform::errors::InvalidArgument(
"Tensor named %s of TRT engine should not be null.", name));
PADDLE_ENFORCE_EQ(
true,
itensor_map_.count(name),
platform::errors::InvalidArgument(
"Tensor named %s of TRT engine should not be null", name));
itensor_map_.erase(name);
}
void TensorRTEngine::SetITensor(const std::string &name,
nvinfer1::ITensor *tensor) {
......
......@@ -278,6 +278,7 @@ class TensorRTEngine {
void DeclareOutput(const std::string& name);
void ClearTensorMap() { itensor_map_.clear(); }
void DeleteITensor(const std::string& name, nvinfer1::ITensor* tensor);
void SetITensor(const std::string& name, nvinfer1::ITensor* tensor);
// Get an ITensor called name.
nvinfer1::ITensor* GetITensor(const std::string& name);
......
......@@ -177,22 +177,75 @@ __global__ void TakeAlongAxis(const T* src,
}
}
__global__ void pos_id_prune_kernel(const int32_t* src,
int32_t* dst,
int pos_nums,
float scale) {
dst[0] = 0;
for (int i = 1; i < pos_nums; i++) {
dst[i] =
dst[i - 1] + max(static_cast<int>((src[i] - src[i - 1]) * scale), 2);
}
}
nvinfer1::DimsExprs FusedTokenPrunePluginDynamic::getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT {
auto x_dims = inputs[1], new_mask_dims = inputs[3];
if (output_index == 0) {
nvinfer1::DimsExprs ret = x_dims;
ret.d[1] = new_mask_dims.d[2];
return ret;
if (flag_varseqlen_) {
if (output_index == 0) {
nvinfer1::DimsExprs ret = x_dims;
ret.d[1] = new_mask_dims.d[2];
return ret;
} else if (output_index == 1) {
nvinfer1::DimsExprs ret;
ret.nbDims = 2;
ret.d[0] = new_mask_dims.d[0];
ret.d[1] = new_mask_dims.d[2];
return ret;
} else if (output_index == 2) {
// word id
nvinfer1::DimsExprs ret;
ret.nbDims = 1;
// max sum of seqlen: pre_seqlen * new_mask[2] / mask[1] + 2 * batchs
const auto* two = expr_builder.constant(2);
ret.d[0] = expr_builder.operation(
nvinfer1::DimensionOperation::kSUM,
*expr_builder.operation(
nvinfer1::DimensionOperation::kFLOOR_DIV,
*expr_builder.operation(nvinfer1::DimensionOperation::kPROD,
*inputs[4].d[0],
*new_mask_dims.d[2]),
*inputs[6].d[1]),
*expr_builder.operation(
nvinfer1::DimensionOperation::kPROD, *two, *inputs[6].d[0]));
return ret;
} else if (output_index == 3) {
// pos id
nvinfer1::DimsExprs ret = inputs[5];
return ret;
} else if (output_index == 4) {
// mask id
nvinfer1::DimsExprs ret;
ret.nbDims = 2;
ret.d[0] = inputs[6].d[0];
ret.d[1] = new_mask_dims.d[2];
return ret;
}
} else {
nvinfer1::DimsExprs ret;
ret.nbDims = 2;
ret.d[0] = new_mask_dims.d[0];
ret.d[1] = new_mask_dims.d[2];
return ret;
if (output_index == 0) {
nvinfer1::DimsExprs ret = x_dims;
ret.d[1] = new_mask_dims.d[2];
return ret;
} else {
nvinfer1::DimsExprs ret;
ret.nbDims = 2;
ret.d[0] = new_mask_dims.d[0];
ret.d[1] = new_mask_dims.d[2];
return ret;
}
}
}
......@@ -215,26 +268,53 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc& in = in_out[pos];
if (pos == 0) {
if (with_fp16_) {
if (flag_varseqlen_) {
if (pos == 0) {
if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#else
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
} else {
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
}
} else if (pos <= 3 || pos == 7) {
const nvinfer1::PluginTensorDesc& prev = in_out[0];
return in.type == prev.type && in.format == prev.format;
} else if (pos == 6 || pos == 11) { // mask_id, mask_id_out
return in.type == nvinfer1::DataType::kFLOAT &&
in.format == nvinfer1::TensorFormat::kLINEAR;
} else {
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
return in.type == nvinfer1::DataType::kINT32 &&
in.format == nvinfer1::TensorFormat::kLINEAR;
}
} else if (pos <= 4) {
const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
return in.type == prev.type && in.format == prev.format;
} else {
const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
return in.type == nvinfer1::DataType::kINT32 && in.format == prev.format;
if (pos == 0) {
if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#else
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
} else {
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
}
} else if (pos <= 4) {
const nvinfer1::PluginTensorDesc& prev = in_out[0];
return in.type == prev.type && in.format == prev.format;
} else {
return in.type == nvinfer1::DataType::kINT32 &&
in.format == nvinfer1::TensorFormat::kLINEAR;
}
}
}
......@@ -242,10 +322,22 @@ nvinfer1::DataType FusedTokenPrunePluginDynamic::getOutputDataType(
int index,
const nvinfer1::DataType* input_types,
int nb_inputs) const TRT_NOEXCEPT {
if (index == 0) {
return input_types[1];
} else if (index == 1) {
return nvinfer1::DataType::kINT32;
if (flag_varseqlen_) {
if (index == 0) {
return input_types[1];
} else if (index == 4) {
return nvinfer1::DataType::kFLOAT;
} else {
// index = 1,2,3
return nvinfer1::DataType::kINT32;
}
} else {
if (index == 0) {
return input_types[1];
} else {
// index = 1
return nvinfer1::DataType::kINT32;
}
}
}
......@@ -273,15 +365,16 @@ size_t FusedTokenPrunePluginDynamic::getWorkspaceSize(
}
template <typename T>
int FusedTokenPrunePluginDynamic::enqueueImpl(
const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs,
void* const* outputs,
void* workspace_ptr,
cudaStream_t stream,
int device_id,
T max_value) {
inline void enqueueImpl(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs,
void* const* outputs,
void* workspace_ptr,
cudaStream_t stream,
int device_id,
T max_value,
bool keep_first_token_,
bool keep_order_) {
// Dims
auto attn_dims = input_desc[0].dims;
auto x_dims = input_desc[1].dims;
......@@ -462,8 +555,14 @@ int FusedTokenPrunePluginDynamic::enqueueImpl(
slimmed_x_len,
c);
}
}
return cudaGetLastError() != cudaSuccess;
inline void pos_id_prune(const int32_t* input,
int32_t* output,
int pos_nums,
float scale,
cudaStream_t stream) {
pos_id_prune_kernel<<<1, 1, 0, stream>>>(input, output, pos_nums, scale);
}
int FusedTokenPrunePluginDynamic::enqueue(
......@@ -485,14 +584,16 @@ int FusedTokenPrunePluginDynamic::enqueue(
float max = std::numeric_limits<float>::max();
return enqueueImpl<float>(input_desc,
output_desc,
inputs,
outputs,
workspace,
stream,
device_id,
max);
enqueueImpl<float>(input_desc,
output_desc,
inputs,
outputs,
workspace,
stream,
device_id,
max,
keep_first_token_,
keep_order_);
} else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
......@@ -500,14 +601,16 @@ int FusedTokenPrunePluginDynamic::enqueue(
half max = 65504.0;
return enqueueImpl<half>(input_desc,
output_desc,
inputs,
outputs,
workspace,
stream,
device_id,
max);
enqueueImpl<half>(input_desc,
output_desc,
inputs,
outputs,
workspace,
stream,
device_id,
max,
keep_first_token_,
keep_order_);
#else
PADDLE_THROW(platform::errors::Fatal(
......@@ -522,6 +625,17 @@ int FusedTokenPrunePluginDynamic::enqueue(
platform::errors::Fatal("The FusedTokenPrune TRT Plugin's input type "
"should be float or half."));
}
if (flag_varseqlen_) {
float scale =
static_cast<float>(input_desc[3].dims.d[2]) / input_desc[6].dims.d[1];
// outputs[2]=inputs[4]; // word_id
const int32_t* inputs5 = static_cast<const int32_t*>(inputs[5]);
int32_t* outputs3 = static_cast<int32_t*>(outputs[3]);
pos_id_prune(
inputs5, outputs3, input_desc[5].dims.d[0], scale, stream); // pos_id
// outputs[4]=inputs[6]; // new_mask
}
return cudaGetLastError() != cudaSuccess;
}
#endif
......
......@@ -28,34 +28,45 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
public:
explicit FusedTokenPrunePluginDynamic(bool with_fp16,
bool keep_first_token,
bool keep_order)
: keep_first_token_(keep_first_token), keep_order_(keep_order) {
bool keep_order,
bool flag_varseqlen)
: keep_first_token_(keep_first_token),
keep_order_(keep_order),
flag_varseqlen_(flag_varseqlen) {
with_fp16_ = with_fp16;
}
FusedTokenPrunePluginDynamic(void const* serial_data, size_t serial_length) {
DeserializeValue(&serial_data, &serial_length, &with_fp16_);
DeserializeValue(&serial_data, &serial_length, &keep_first_token_);
DeserializeValue(&serial_data, &serial_length, &keep_order_);
DeserializeValue(&serial_data, &serial_length, &flag_varseqlen_);
}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
return new FusedTokenPrunePluginDynamic(
with_fp16_, keep_first_token_, keep_order_);
with_fp16_, keep_first_token_, keep_order_, flag_varseqlen_);
}
const char* getPluginType() const TRT_NOEXCEPT override {
return "fused_token_prune_plugin_dynamic";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 2; }
int getNbOutputs() const TRT_NOEXCEPT override {
if (flag_varseqlen_) {
return 5;
} else {
return 2;
}
}
int initialize() TRT_NOEXCEPT override { return 0; }
size_t getSerializationSize() const TRT_NOEXCEPT override {
return SerializedSize(with_fp16_) + SerializedSize(keep_first_token_) +
SerializedSize(keep_order_);
SerializedSize(keep_order_) + SerializedSize(flag_varseqlen_);
}
void serialize(void* buffer) const TRT_NOEXCEPT override {
SerializeValue(&buffer, with_fp16_);
SerializeValue(&buffer, keep_first_token_);
SerializeValue(&buffer, keep_order_);
SerializeValue(&buffer, flag_varseqlen_);
}
nvinfer1::DimsExprs getOutputDimensions(
......@@ -95,17 +106,9 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
void destroy() TRT_NOEXCEPT override { delete this; }
private:
template <typename T>
int enqueueImpl(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream,
int device_id,
T max_value);
bool keep_first_token_;
bool keep_order_;
bool flag_varseqlen_;
};
class FusedTokenPrunePluginDynamicCreator : public nvinfer1::IPluginCreator {
......
......@@ -72,12 +72,16 @@ bool RecoverPaddingPlugin::supportsFormatCombination(
platform::errors::InvalidArgument("Must have 1 output, "
"but got %d output(s). ",
nbOutputs));
if (pos == 1) { // PosId, MaxSeqlen
if (pos == 1) { // PosId
return inOut[pos].type == nvinfer1::DataType::kINT32 &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
} else if (pos == 2) { // mask_id
return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
} else {
return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
}
return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
// return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format
// == nvinfer1::TensorFormat::kLINEAR)||
// (inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format ==
......
......@@ -72,9 +72,10 @@ bool RemovePaddingPlugin::supportsFormatCombination(
if (pos == 1 || pos == 2) { // pos_id, work_id
return inOut[pos].type == nvinfer1::DataType::kINT32 &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
} else {
return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
}
return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
// return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format
// == nvinfer1::TensorFormat::kLINEAR)||
// (inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format ==
......
......@@ -22,8 +22,10 @@ namespace tensorrt {
namespace plugin {
TEST(fused_token_prune_op_plugin, test_plugin) {
FusedTokenPrunePluginDynamic plugin(
true, /*keep_first_token*/ false, /*keep_order*/ true);
FusedTokenPrunePluginDynamic plugin(true,
/*keep_first_token*/ false,
/*keep_order*/ true,
/*flag_varseqlen*/ false);
plugin.configurePlugin(nullptr, 4, nullptr, 2);
plugin.initialize();
plugin.getPluginType();
......
......@@ -293,8 +293,10 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
auto *new_mask = engine_->DeclareInput(
"new_mask", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 2, 2});
plugin::FusedTokenPrunePluginDynamic *plugin =
new plugin::FusedTokenPrunePluginDynamic(
true, /*keep_first_token*/ false, /*keep_order*/ true);
new plugin::FusedTokenPrunePluginDynamic(true,
/*keep_first_token*/ false,
/*keep_order*/ true,
/*flag_varseqlen*/ false);
std::vector<nvinfer1::ITensor *> itensors = {attn, x, mask, new_mask};
auto *layer = engine_->AddDynamicPlugin(itensors.data(), 4, plugin);
PADDLE_ENFORCE_NOT_NULL(layer,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册