Unverified commit 24187fcb, authored by Wangzheee, committed by GitHub

[Paddle Inference] add varlen_token_prune plugin, pass, convert (#44733)

* add varlen_token_prune plugin, pass, convert
Parent commit: 8482f1ae
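For context: this commit only touches the TensorRT varseqlen (variable-length) path, so none of it fires unless the engine is built in dynamic-shape, variable-length mode. A minimal client-side sketch is given below; SetTRTDynamicShapeInfo is referenced later in this diff, while the input names, shape ranges and the EnableVarseqlen() switch are assumptions for illustration, not part of this commit.

// Hedged sketch: a Paddle Inference config that exercises the dynamic-shape
// TensorRT path this commit extends. Input names, shape ranges and
// EnableVarseqlen() are assumptions, not taken from this diff.
#include <map>
#include <string>
#include <vector>
#include "paddle_inference_api.h"

paddle_infer::Config MakeVarlenConfig(const std::string& model_dir) {
  paddle_infer::Config config;
  config.SetModel(model_dir + "/model.pdmodel", model_dir + "/model.pdiparams");
  config.EnableUseGpu(/*memory_mb=*/256, /*device_id=*/0);
  config.EnableTensorRtEngine(/*workspace_size=*/1 << 30,
                              /*max_batch_size=*/1,
                              /*min_subgraph_size=*/3,
                              paddle_infer::PrecisionType::kHalf,
                              /*use_static=*/false,
                              /*use_calib_mode=*/false);
  // Dynamic-shape ranges for ERNIE-style ids (illustrative values only).
  std::map<std::string, std::vector<int>> min_shape{{"word_id", {1, 1}}};
  std::map<std::string, std::vector<int>> max_shape{{"word_id", {8, 128}}};
  std::map<std::string, std::vector<int>> opt_shape{{"word_id", {4, 64}}};
  config.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);
  config.EnableVarseqlen();  // assumed name for the varseqlen switch
  return config;
}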
@@ -359,6 +359,7 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
    std::vector<int64_t> skip_layernorm_x_shape =
        skip_layernorm_x->Var()->GetShape();
    check_flag = true;
    if (skip_layernorm_x_shape.size() != multihead_matmul_input_shape.size()) {
      check_flag = false;
      VLOG(3) << "Transformer model remove_padding shape check failed, return "
@@ -395,6 +396,7 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
    GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc_op, fc);
    std::vector<int64_t> fc_input_shape = fc_input->Var()->GetShape();
    check_flag = true;
    if ((fc_input_shape.size() != multihead_matmul_input_shape.size()) ||
        (fc_input_shape.size() != 3)) {
      check_flag = false;
@@ -446,11 +448,13 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
    std::vector<int64_t> activation_input_shape =
        activation_input->Var()->GetShape();
    check_flag = true;
    if ((activation_input_shape.size() !=
         multihead_matmul_input_shape.size()) ||
        (activation_input_shape.size() != 3)) {
      check_flag = false;
      VLOG(3) << "Activation: Transformer model remove_padding "
                 "shape(activation_input_shape.size()) check failed, return "
                 "remove_padding pass.";
      return;
    }
@@ -465,7 +469,8 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
        check_flag = false;
      }
    if (!check_flag) {
      VLOG(3) << "Activation: Transformer model remove_padding "
                 "shape(activation_input_shape[i]) check failed, return "
                 "remove_padding pass.";
      return;
    }
@@ -530,6 +535,7 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
    std::vector<int64_t> skip_layernorm_x_shape =
        preln_skip_layernorm_x->Var()->GetShape();
    check_flag = true;
    if (skip_layernorm_x_shape.size() != multihead_matmul_input_shape.size()) {
      check_flag = false;
      VLOG(3) << "Transformer model remove_padding shape check failed, return "
......
@@ -60,6 +60,50 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
          std::vector<std::string>{word_id_name, pos_id_name, sent_id_name};
      emb_names =
          std::vector<std::string>{word_emb_name, pos_emb_name, sent_emb_name};
auto mask_id_tensor = engine_->GetITensor("mask_id");
auto mask_dims = mask_id_tensor->getDimensions();
auto slice_start_dims = mask_dims;
auto slice_stride_dims = mask_dims;
for (int i = 0; i < mask_dims.nbDims; i++) {
slice_start_dims.d[i] = 0;
slice_stride_dims.d[i] = 1;
}
auto* shape_tensor = Shape(mask_id_tensor);
std::vector<nvinfer1::ITensor*> size_vec_tensor;
for (int i = 0; i < mask_dims.nbDims; i++) {
size_vec_tensor.push_back(Add1DConstantLayer(1));
}
size_vec_tensor[1] = GetEleTensorOfShape(shape_tensor, 1);
auto size_tensor = Concat(size_vec_tensor);
auto slice_layer =
TRT_ENGINE_ADD_LAYER(engine_,
Slice,
*mask_id_tensor,
slice_start_dims,
slice_start_dims,
                           slice_stride_dims);  // the second slice_start_dims is a placeholder; real size is set via setInput(2, ...) below
slice_layer->setInput(2, *size_tensor);
slice_layer->setName(
("Embeltwise_slice_layer (Output: slice_max_seqlen " +
op_desc.Output("Out")[0] + ")")
.c_str());
engine_->SetTensorDynamicRange(slice_layer->getOutput(0), 1.0f);
auto* reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *slice_layer->getOutput(0));
nvinfer1::Dims shape_dim;
shape_dim.nbDims = 1;
shape_dim.d[0] = -1;
reshape_layer->setReshapeDimensions(shape_dim);
reshape_layer->setName(("Embeltwise_reshape_layer (Output: max_seqlen " +
op_desc.Output("Out")[0] + ")")
.c_str());
engine_->SetTensorDynamicRange(reshape_layer->getOutput(0), 1.0f);
engine_->SetITensor("max_seqlen_tensor", reshape_layer->getOutput(0));
    } else {
      id_names = op_desc.Input("Ids");
      emb_names = op_desc.Input("Embs");
@@ -192,20 +236,8 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
      plugin_inputs.emplace_back(
          engine_->GetITensor(pos_id_name));  // cu_seqlens,
                                              // eval_placeholder_2
-     auto max_seqlen_tensor = engine_->GetITensor(mask_id_name);
-     auto* shuffle_layer =
-         TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *max_seqlen_tensor);
-     nvinfer1::Dims shape_dim;
-     shape_dim.nbDims = 1;
-     shape_dim.d[0] = -1;
-     shuffle_layer->setReshapeDimensions(shape_dim);
-     shuffle_layer->setName(
-         ("Embeltwise_Shuffle_reshape (Output: max_seqlen " +
-          op_desc.Output("Out")[0] + ")")
-             .c_str());
-     engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f);
-     plugin_inputs.emplace_back(
-         shuffle_layer->getOutput(0));  // max_seqlen, eval_placeholder_3
      plugin_inputs.emplace_back(engine_->GetITensor(
          "max_seqlen_tensor"));  // max_seqlen, eval_placeholder_3
      auto creator = GetPluginRegistry()->getPluginCreator(
          "CustomEmbLayerNormPluginDynamic", "2");
......
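Note on the block above: instead of reshaping mask_id separately in every converter, the emb_eltwise_layernorm converter now slices the [batch, seq_len] mask down to a single [1, seq_len] row, flattens it to 1-D, and registers the result once as "max_seqlen_tensor" for every later varseqlen plugin. A minimal raw-TensorRT sketch of the same slice-then-flatten trick (the network, tensors and dims here are illustrative, not Paddle code):

// Sketch: derive a 1-D max_seqlen tensor from a [batch, seq_len] mask.
// `network`, `mask` and `size` ({1, seq_len} as a shape tensor) are assumed.
#include <NvInfer.h>

nvinfer1::ITensor* BuildMaxSeqlen(nvinfer1::INetworkDefinition* network,
                                  nvinfer1::ITensor* mask,
                                  nvinfer1::ITensor* size) {
  nvinfer1::Dims start{2, {0, 0}};
  nvinfer1::Dims dummy{2, {1, 1}};   // placeholder, overridden by setInput(2, ...)
  nvinfer1::Dims stride{2, {1, 1}};
  auto* slice = network->addSlice(*mask, start, dummy, stride);
  slice->setInput(2, *size);         // dynamic size: one row, full seq_len
  auto* flatten = network->addShuffle(*slice->getOutput(0));
  nvinfer1::Dims one_d{1, {-1}};
  flatten->setReshapeDimensions(one_d);  // -> shape [seq_len]
  return flatten->getOutput(0);
}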
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......
@@ -23,7 +23,6 @@ class FusedTokenPruneOpConverter : public OpConverter {
                  bool test_mode) override {
    framework::OpDesc op_desc(op, nullptr);
    nvinfer1::ILayer* layer = nullptr;
    auto* Attn = engine_->GetITensor(op_desc.Input("Attn").front());
    auto* X = engine_->GetITensor(op_desc.Input("X").front());
    auto* Mask = engine_->GetITensor(op_desc.Input("Mask").front());
@@ -36,28 +35,54 @@ class FusedTokenPruneOpConverter : public OpConverter {
        op_desc.HasAttr("keep_order")
            ? PADDLE_GET_CONST(bool, op_desc.GetAttr("keep_order"))
            : false;
-   std::vector<nvinfer1::ITensor*> itensors = {Attn, X, Mask, NewMask};
    auto output_name = op_desc.Output("SlimmedX")[0];
    auto out_inds_name = op_desc.Output("CLSInds")[0];
    if (engine_->with_dynamic_shape()) {
-     #if IS_TRT_VERSION_GE(6000)
      bool with_fp16 =
          engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
      if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
        with_fp16 = true;
      }
      bool flag_varseqlen = engine_->use_varseqlen();
      plugin::FusedTokenPrunePluginDynamic* plugin =
          new plugin::FusedTokenPrunePluginDynamic(
              with_fp16, keep_first_token, keep_order, flag_varseqlen);
-     layer = engine_->AddDynamicPlugin(itensors.data(), 4, plugin);
-     #else
-     PADDLE_THROW(platform::errors::Fatal(
-         "You are running the TRT Dynamic Shape mode, need to confirm that "
-         "your TRT version is no less than 6.0"));
-     #endif
      if (flag_varseqlen) {
        auto* word_id = engine_->GetITensor("word_id");
        auto* pos_id = engine_->GetITensor("pos_id");
        auto* mask_id = engine_->GetITensor("mask_id");
        std::vector<nvinfer1::ITensor*> itensors = {
            Attn, X, Mask, NewMask, word_id, pos_id, mask_id};
layer = engine_->AddDynamicPlugin(itensors.data(), 7, plugin);
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0));
layer->getOutput(1)->setName(out_inds_name.c_str());
engine_->SetITensor(out_inds_name, layer->getOutput(1));
engine_->DeleteITensor("word_id", word_id);
layer->getOutput(2)->setName("word_id_after_token_prune");
engine_->SetITensor("word_id", layer->getOutput(2));
engine_->DeleteITensor("pos_id", pos_id);
layer->getOutput(3)->setName("pos_id_after_token_prune");
engine_->SetITensor("pos_id", layer->getOutput(3));
engine_->DeleteITensor("mask_id", mask_id);
layer->getOutput(4)->setName("mask_id_after_token_prune");
engine_->SetITensor("mask_id", layer->getOutput(4));
} else {
std::vector<nvinfer1::ITensor*> itensors = {Attn, X, Mask, NewMask};
layer = engine_->AddDynamicPlugin(itensors.data(), 4, plugin);
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0));
layer->getOutput(1)->setName(out_inds_name.c_str());
engine_->SetITensor(out_inds_name, layer->getOutput(1));
}
layer->setName(
("fused_token_prune(Output: " + output_name + ")").c_str());
    } else {
      PADDLE_THROW(platform::errors::Fatal(
          "You are running the Ernie(Bert) model in static shape mode, which "
@@ -65,8 +90,6 @@ class FusedTokenPruneOpConverter : public OpConverter {
          "You can use the config.SetTRTDynamicShapeInfo(...) interface to set "
          "the shape information to run the dynamic shape mode."));
    }
-   RreplenishLayerAndOutput(
-       layer, "fused_token_prune", {output_name, out_inds_name}, test_mode);
  }
};
......
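To summarize the wiring above: when use_varseqlen is on, the prune plugin takes seven inputs and produces five outputs, and the converter re-registers "word_id", "pos_id" and "mask_id" with the pruned tensors so every downstream op sees the shortened sequences; this is also why supportsFormatCombination later special-cases positions 6 and 11 (mask_id in and out of the 7-input/5-output layout). The ordering, collected here for reference (the enum names are only illustrative):

// Tensor ordering used by the varseqlen FusedTokenPrune wiring above.
// Inputs  : 0 Attn, 1 X, 2 Mask, 3 NewMask, 4 word_id, 5 pos_id, 6 mask_id
// Outputs : 0 SlimmedX, 1 CLSInds, 2 word_id', 3 pos_id', 4 mask_id'
//           (outputs 2-4 are re-registered under the original names)
enum VarlenPruneInput  { kAttn = 0, kX, kMask, kNewMask, kWordId, kPosId, kMaskId };
enum VarlenPruneOutput { kSlimmedX = 0, kCLSInds, kWordIdOut, kPosIdOut, kMaskIdOut };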
@@ -94,6 +94,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
        nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT,
                               static_cast<void*>(bias_data),
                               static_cast<int32_t>(bias_t->numel())};
auto max_seqlen_tensor = engine_->GetITensor("max_seqlen_tensor");
auto pos_id_tensor = engine_->GetITensor("pos_id");
        if (engine_->with_interleaved()) {
          VLOG(4) << "fused multihead_matmul op: use_varseqlen and "
                     "with_interleaved";
@@ -154,31 +156,9 @@ class MultiheadMatMulOpConverter : public OpConverter {
          std::vector<nvinfer1::ITensor*> plugin_inputs;
          plugin_inputs.emplace_back(fc_layer->getOutput(0));
-         if (engine_->Has("ernie_pos_name")) {
-           plugin_inputs.emplace_back(engine_->GetITensor(
-               engine_->Get<std::string>("ernie_pos_name")));
-         } else {
-           plugin_inputs.emplace_back(engine_->GetITensor(
-               engine_->network()
-                   ->getInput(2)
-                   ->getName()));  // cu_seqlens, eval_placeholder_2
-         }
-         auto max_seqlen_tensor =
-             engine_->GetITensor(engine_->network()->getInput(3)->getName());
-         engine_->SetTensorDynamicRange(max_seqlen_tensor, 1.0f);
-         auto* shuffle_layer = TRT_ENGINE_ADD_LAYER(
-             engine_,
-             Shuffle,
-             *const_cast<nvinfer1::ITensor*>(max_seqlen_tensor));
-         nvinfer1::Dims shape_dim;
-         shape_dim.nbDims = 1;
-         shape_dim.d[0] = -1;
-         shuffle_layer->setReshapeDimensions(shape_dim);
-         engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f);
-         plugin_inputs.emplace_back(
-             shuffle_layer->getOutput(0));  // max_seqlen, eval_placeholder_3
-         shuffle_layer->setName(
-             ("Multihead: Shuffle: (Output: " + output_name + ")").c_str());
          plugin_inputs.emplace_back(pos_id_tensor);
          plugin_inputs.emplace_back(
              max_seqlen_tensor);  // max_seqlen, eval_placeholder_3
          auto plugin_layer = engine_->network()->addPluginV2(
              plugin_inputs.data(), plugin_inputs.size(), *plugin);
          layer = plugin_layer;
@@ -299,20 +279,9 @@ class MultiheadMatMulOpConverter : public OpConverter {
          std::vector<nvinfer1::ITensor*> plugin_inputs;
          plugin_inputs.emplace_back(fc_layer->getOutput(0));
          plugin_inputs.emplace_back(engine_->GetITensor("qkv_plugin_mask"));
-         plugin_inputs.emplace_back(engine_->GetITensor("pos_id"));
-         auto max_seqlen_tensor = engine_->GetITensor("mask_id");
-         auto* shuffle_layer = TRT_ENGINE_ADD_LAYER(
-             engine_,
-             Shuffle,
-             *const_cast<nvinfer1::ITensor*>(max_seqlen_tensor));
-         nvinfer1::Dims shape_dim;
-         shape_dim.nbDims = 1;
-         shape_dim.d[0] = -1;
-         shuffle_layer->setReshapeDimensions(shape_dim);
-         engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f);
-         plugin_inputs.emplace_back(
-             shuffle_layer->getOutput(0));  // max_seqlen, eval_placeholder_3
          plugin_inputs.emplace_back(pos_id_tensor);
          plugin_inputs.emplace_back(
              max_seqlen_tensor);  // max_seqlen, eval_placeholder_3
          auto plugin_layer = engine_->network()->addPluginV2(
              plugin_inputs.data(), plugin_inputs.size(), *plugin);
......
@@ -157,20 +157,47 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
      plugin_inputs.emplace_back(
          engine_->GetITensor(pos_id_name));  // cu_seqlens,
                                              // eval_placeholder_2
-     auto max_seqlen_tensor = engine_->GetITensor(mask_id_name);
-     auto* shuffle_layer =
-         TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *max_seqlen_tensor);
      auto mask_id_tensor = engine_->GetITensor("mask_id");
      auto mask_dims = mask_id_tensor->getDimensions();
      auto slice_start_dims = mask_dims;
auto slice_size_dims = mask_dims;
auto slice_stride_dims = mask_dims;
for (int i = 0; i < mask_dims.nbDims; i++) {
slice_start_dims.d[i] = 0;
slice_size_dims.d[i] = 1;
slice_stride_dims.d[i] = 1;
}
slice_size_dims.d[1] = mask_dims.d[1];
auto* slice_size_tensor = Add1DConstantLayer(slice_size_dims);
auto slice_layer =
TRT_ENGINE_ADD_LAYER(engine_,
Slice,
*mask_id_tensor,
slice_start_dims,
slice_start_dims,
                           slice_stride_dims);  // the second slice_start_dims is a placeholder; real size is set via setInput(2, ...) below
slice_layer->setInput(2, *slice_size_tensor);
slice_layer->setName(
("PrelnEmbeltwise_slice_layer (Output: slice_max_seqlen " +
op_desc.Output("Out")[0] + ")")
.c_str());
engine_->SetTensorDynamicRange(slice_layer->getOutput(0), 1.0f);
auto* reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *slice_layer->getOutput(0));
      nvinfer1::Dims shape_dim;
      shape_dim.nbDims = 1;
      shape_dim.d[0] = -1;
      reshape_layer->setReshapeDimensions(shape_dim);
      reshape_layer->setName(
          ("PrelnEmbeltwise_reshape_layer (Output: max_seqlen " +
           op_desc.Output("Out")[0] + ")")
              .c_str());
      engine_->SetTensorDynamicRange(reshape_layer->getOutput(0), 1.0f);
      engine_->SetITensor("max_seqlen_tensor", reshape_layer->getOutput(0));
      plugin_inputs.emplace_back(
          reshape_layer->getOutput(0));  // max_seqlen, eval_placeholder_3
      auto creator = GetPluginRegistry()->getPluginCreator(
          "CustomEmbLayerNormPluginDynamic", "3");
......
@@ -111,6 +111,8 @@ class SparseMultiheadMatMulOpConverter : public OpConverter {
        nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT,
                               static_cast<void*>(bias_data),
                               static_cast<int32_t>(bias_t->numel())};
auto max_seqlen_tensor = engine_->GetITensor("max_seqlen_tensor");
auto pos_id_tensor = engine_->GetITensor("pos_id");
        if (engine_->with_interleaved()) {
          VLOG(4) << "fused multihead_matmul op: use_varseqlen and "
                     "with_interleaved";
@@ -171,31 +173,9 @@ class SparseMultiheadMatMulOpConverter : public OpConverter {
          std::vector<nvinfer1::ITensor*> plugin_inputs;
          plugin_inputs.emplace_back(fc_layer->getOutput(0));
-         if (engine_->Has("ernie_pos_name")) {
-           plugin_inputs.emplace_back(engine_->GetITensor(
-               engine_->Get<std::string>("ernie_pos_name")));
-         } else {
-           plugin_inputs.emplace_back(engine_->GetITensor(
-               engine_->network()
-                   ->getInput(2)
-                   ->getName()));  // cu_seqlens, eval_placeholder_2
-         }
-         auto max_seqlen_tensor =
-             engine_->GetITensor(engine_->network()->getInput(3)->getName());
-         engine_->SetTensorDynamicRange(max_seqlen_tensor, 1.0f);
-         auto* shuffle_layer = TRT_ENGINE_ADD_LAYER(
-             engine_,
-             Shuffle,
-             *const_cast<nvinfer1::ITensor*>(max_seqlen_tensor));
-         nvinfer1::Dims shape_dim;
-         shape_dim.nbDims = 1;
-         shape_dim.d[0] = -1;
-         shuffle_layer->setReshapeDimensions(shape_dim);
-         engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f);
-         plugin_inputs.emplace_back(
-             shuffle_layer->getOutput(0));  // max_seqlen, eval_placeholder_3
-         shuffle_layer->setName(
-             ("Multihead: Shuffle: (Output: " + output_name + ")").c_str());
          plugin_inputs.emplace_back(pos_id_tensor);
          plugin_inputs.emplace_back(
              max_seqlen_tensor);  // max_seqlen, eval_placeholder_3
          auto plugin_layer = engine_->network()->addPluginV2(
              plugin_inputs.data(), plugin_inputs.size(), *plugin);
          layer = plugin_layer;
@@ -316,21 +296,9 @@ class SparseMultiheadMatMulOpConverter : public OpConverter {
          std::vector<nvinfer1::ITensor*> plugin_inputs;
          plugin_inputs.emplace_back(fc_layer->getOutput(0));
          plugin_inputs.emplace_back(engine_->GetITensor("qkv_plugin_mask"));
-         plugin_inputs.emplace_back(engine_->GetITensor("pos_id"));
-         auto max_seqlen_tensor = engine_->GetITensor("mask_id");
-         auto* shuffle_layer = TRT_ENGINE_ADD_LAYER(
-             engine_,
-             Shuffle,
-             *const_cast<nvinfer1::ITensor*>(max_seqlen_tensor));
-         nvinfer1::Dims shape_dim;
-         shape_dim.nbDims = 1;
-         shape_dim.d[0] = -1;
-         shuffle_layer->setReshapeDimensions(shape_dim);
-         engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f);
-         plugin_inputs.emplace_back(
-             shuffle_layer->getOutput(0));  // max_seqlen, eval_placeholder_3
          plugin_inputs.emplace_back(pos_id_tensor);
          plugin_inputs.emplace_back(
              max_seqlen_tensor);  // max_seqlen, eval_placeholder_3
          auto plugin_layer = engine_->network()->addPluginV2(
              plugin_inputs.data(), plugin_inputs.size(), *plugin);
          layer = plugin_layer;
......
@@ -410,6 +410,19 @@ void TensorRTEngine::DeclareOutput(const std::string &name) {
          name));
  network()->markOutput(*output);
}
void TensorRTEngine::DeleteITensor(const std::string &name,
nvinfer1::ITensor *tensor) {
PADDLE_ENFORCE_NOT_NULL(
tensor,
platform::errors::InvalidArgument(
"Tensor named %s of TRT engine should not be null.", name));
  PADDLE_ENFORCE_EQ(
      true,
      itensor_map_.count(name),
      platform::errors::InvalidArgument(
          "Tensor named %s is not registered in the TRT engine.", name));
itensor_map_.erase(name);
}
void TensorRTEngine::SetITensor(const std::string &name,
                                nvinfer1::ITensor *tensor) {
......
@@ -278,6 +278,7 @@ class TensorRTEngine {
  void DeclareOutput(const std::string& name);
  void ClearTensorMap() { itensor_map_.clear(); }
void DeleteITensor(const std::string& name, nvinfer1::ITensor* tensor);
  void SetITensor(const std::string& name, nvinfer1::ITensor* tensor);
  // Get an ITensor called name.
  nvinfer1::ITensor* GetITensor(const std::string& name);
......
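The new DeleteITensor lets a converter rebind a well-known tensor name to a new producer: the fused_token_prune converter above drops the stale entry and then registers the pruned output under the same name. A minimal sketch of that pattern, mirroring the converter code (engine_ and layer are assumed to be in scope):

// Rebind "pos_id" to the pruned cu_seqlens produced by the plugin layer.
// `engine_` and `layer` are assumed, as in the converter code above.
auto* stale_pos_id = engine_->GetITensor("pos_id");
engine_->DeleteITensor("pos_id", stale_pos_id);        // remove the old mapping
layer->getOutput(3)->setName("pos_id_after_token_prune");
engine_->SetITensor("pos_id", layer->getOutput(3));    // downstream ops now see pruned ids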
@@ -177,22 +177,75 @@ __global__ void TakeAlongAxis(const T* src,
  }
}
__global__ void pos_id_prune_kernel(const int32_t* src,
int32_t* dst,
int pos_nums,
float scale) {
dst[0] = 0;
for (int i = 1; i < pos_nums; i++) {
dst[i] =
dst[i - 1] + max(static_cast<int>((src[i] - src[i - 1]) * scale), 2);
}
}
nvinfer1::DimsExprs FusedTokenPrunePluginDynamic::getOutputDimensions(
    int output_index,
    const nvinfer1::DimsExprs* inputs,
    int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT {
  auto x_dims = inputs[1], new_mask_dims = inputs[3];
  if (flag_varseqlen_) {
    if (output_index == 0) {
      nvinfer1::DimsExprs ret = x_dims;
      ret.d[1] = new_mask_dims.d[2];
      return ret;
} else if (output_index == 1) {
nvinfer1::DimsExprs ret;
ret.nbDims = 2;
ret.d[0] = new_mask_dims.d[0];
ret.d[1] = new_mask_dims.d[2];
return ret;
} else if (output_index == 2) {
// word id
nvinfer1::DimsExprs ret;
ret.nbDims = 1;
      // max sum of seqlen: pre_seqlen * new_mask[2] / mask[1] + 2 * batches
const auto* two = expr_builder.constant(2);
ret.d[0] = expr_builder.operation(
nvinfer1::DimensionOperation::kSUM,
*expr_builder.operation(
nvinfer1::DimensionOperation::kFLOOR_DIV,
*expr_builder.operation(nvinfer1::DimensionOperation::kPROD,
*inputs[4].d[0],
*new_mask_dims.d[2]),
*inputs[6].d[1]),
*expr_builder.operation(
nvinfer1::DimensionOperation::kPROD, *two, *inputs[6].d[0]));
return ret;
} else if (output_index == 3) {
// pos id
nvinfer1::DimsExprs ret = inputs[5];
return ret;
} else if (output_index == 4) {
// mask id
nvinfer1::DimsExprs ret;
ret.nbDims = 2;
ret.d[0] = inputs[6].d[0];
ret.d[1] = new_mask_dims.d[2];
return ret;
}
  } else {
    if (output_index == 0) {
      nvinfer1::DimsExprs ret = x_dims;
      ret.d[1] = new_mask_dims.d[2];
      return ret;
    } else {
      nvinfer1::DimsExprs ret;
      ret.nbDims = 2;
      ret.d[0] = new_mask_dims.d[0];
      ret.d[1] = new_mask_dims.d[2];
      return ret;
    }
  }
}
@@ -215,26 +268,53 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
          nb_inputs + nb_outputs));
  const nvinfer1::PluginTensorDesc& in = in_out[pos];
  if (flag_varseqlen_) {
    if (pos == 0) {
      if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
        return (in.type == nvinfer1::DataType::kFLOAT ||
                in.type == nvinfer1::DataType::kHALF) &&
               (in.format == nvinfer1::TensorFormat::kLINEAR);
#else
        return (in.type == nvinfer1::DataType::kFLOAT) &&
               (in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
      } else {
        return (in.type == nvinfer1::DataType::kFLOAT) &&
               (in.format == nvinfer1::TensorFormat::kLINEAR);
      }
    } else if (pos <= 3 || pos == 7) {
      const nvinfer1::PluginTensorDesc& prev = in_out[0];
      return in.type == prev.type && in.format == prev.format;
    } else if (pos == 6 || pos == 11) {  // mask_id, mask_id_out
      return in.type == nvinfer1::DataType::kFLOAT &&
             in.format == nvinfer1::TensorFormat::kLINEAR;
    } else {
      return in.type == nvinfer1::DataType::kINT32 &&
             in.format == nvinfer1::TensorFormat::kLINEAR;
    }
- } else if (pos <= 4) {
-   const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
-   return in.type == prev.type && in.format == prev.format;
- } else {
-   const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
-   return in.type == nvinfer1::DataType::kINT32 && in.format == prev.format;
  } else {
    if (pos == 0) {
      if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#else
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
} else {
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
}
} else if (pos <= 4) {
const nvinfer1::PluginTensorDesc& prev = in_out[0];
return in.type == prev.type && in.format == prev.format;
} else {
return in.type == nvinfer1::DataType::kINT32 &&
in.format == nvinfer1::TensorFormat::kLINEAR;
}
  }
}
@@ -242,10 +322,22 @@ nvinfer1::DataType FusedTokenPrunePluginDynamic::getOutputDataType(
    int index,
    const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  if (flag_varseqlen_) {
    if (index == 0) {
      return input_types[1];
    } else if (index == 4) {
return nvinfer1::DataType::kFLOAT;
} else {
// index = 1,2,3
return nvinfer1::DataType::kINT32;
}
} else {
if (index == 0) {
return input_types[1];
} else {
// index = 1
return nvinfer1::DataType::kINT32;
}
  }
}
@@ -273,15 +365,16 @@ size_t FusedTokenPrunePluginDynamic::getWorkspaceSize(
}

template <typename T>
- int FusedTokenPrunePluginDynamic::enqueueImpl(
inline void enqueueImpl(const nvinfer1::PluginTensorDesc* input_desc,
                        const nvinfer1::PluginTensorDesc* output_desc,
                        const void* const* inputs,
                        void* const* outputs,
                        void* workspace_ptr,
                        cudaStream_t stream,
                        int device_id,
                        T max_value,
                        bool keep_first_token_,
                        bool keep_order_) {
  // Dims
  auto attn_dims = input_desc[0].dims;
  auto x_dims = input_desc[1].dims;
@@ -462,8 +555,14 @@ int FusedTokenPrunePluginDynamic::enqueueImpl(
        slimmed_x_len,
        c);
  }
}
- return cudaGetLastError() != cudaSuccess;

inline void pos_id_prune(const int32_t* input,
int32_t* output,
int pos_nums,
float scale,
cudaStream_t stream) {
pos_id_prune_kernel<<<1, 1, 0, stream>>>(input, output, pos_nums, scale);
}

int FusedTokenPrunePluginDynamic::enqueue(
@@ -485,14 +584,16 @@ int FusedTokenPrunePluginDynamic::enqueue(
    float max = std::numeric_limits<float>::max();
    enqueueImpl<float>(input_desc,
                       output_desc,
                       inputs,
                       outputs,
                       workspace,
                       stream,
                       device_id,
                       max,
keep_first_token_,
keep_order_);
  } else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
@@ -500,14 +601,16 @@ int FusedTokenPrunePluginDynamic::enqueue(
    half max = 65504.0;
    enqueueImpl<half>(input_desc,
                      output_desc,
                      inputs,
                      outputs,
                      workspace,
                      stream,
                      device_id,
                      max,
keep_first_token_,
keep_order_);
#else
    PADDLE_THROW(platform::errors::Fatal(
@@ -522,6 +625,17 @@ int FusedTokenPrunePluginDynamic::enqueue(
        platform::errors::Fatal("The FusedTokenPrune TRT Plugin's input type "
                                "should be float or half."));
  }
if (flag_varseqlen_) {
float scale =
static_cast<float>(input_desc[3].dims.d[2]) / input_desc[6].dims.d[1];
// outputs[2]=inputs[4]; // word_id
const int32_t* inputs5 = static_cast<const int32_t*>(inputs[5]);
int32_t* outputs3 = static_cast<int32_t*>(outputs[3]);
pos_id_prune(
inputs5, outputs3, input_desc[5].dims.d[0], scale, stream); // pos_id
    // outputs[4]=inputs[6];  // mask_id
}
return cudaGetLastError() != cudaSuccess;
}
#endif
......
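On the pos_id pruning math above: pos_id holds cu_seqlens (prefix sums of per-sequence lengths), and with scale = new_mask.dims[2] / mask.dims[1] the single-thread kernel shrinks every span src[i] - src[i-1] to max(int(span * scale), 2) and rebuilds the prefix sum from 0. A hedged host-side reference, handy for unit-checking the kernel (values are illustrative):

// CPU reference for pos_id_prune_kernel: rebuild cu_seqlens after pruning.
#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<int32_t> PrunePosId(const std::vector<int32_t>& src, float scale) {
  std::vector<int32_t> dst(src.size(), 0);
  for (size_t i = 1; i < src.size(); ++i) {
    int span = static_cast<int>((src[i] - src[i - 1]) * scale);
    dst[i] = dst[i - 1] + std::max(span, 2);  // keep at least 2 tokens per sequence
  }
  return dst;
}
// Example: src = {0, 10, 30} (two sequences of 10 and 20 tokens), scale = 0.5
//          -> dst = {0, 5, 15}.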
@@ -28,34 +28,45 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
 public:
  explicit FusedTokenPrunePluginDynamic(bool with_fp16,
                                        bool keep_first_token,
                                        bool keep_order,
                                        bool flag_varseqlen)
: keep_first_token_(keep_first_token),
keep_order_(keep_order),
flag_varseqlen_(flag_varseqlen) {
    with_fp16_ = with_fp16;
  }

  FusedTokenPrunePluginDynamic(void const* serial_data, size_t serial_length) {
    DeserializeValue(&serial_data, &serial_length, &with_fp16_);
    DeserializeValue(&serial_data, &serial_length, &keep_first_token_);
    DeserializeValue(&serial_data, &serial_length, &keep_order_);
DeserializeValue(&serial_data, &serial_length, &flag_varseqlen_);
  }

  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
    return new FusedTokenPrunePluginDynamic(
        with_fp16_, keep_first_token_, keep_order_, flag_varseqlen_);
  }

  const char* getPluginType() const TRT_NOEXCEPT override {
    return "fused_token_prune_plugin_dynamic";
  }

  int getNbOutputs() const TRT_NOEXCEPT override {
if (flag_varseqlen_) {
return 5;
} else {
return 2;
}
}
  int initialize() TRT_NOEXCEPT override { return 0; }

  size_t getSerializationSize() const TRT_NOEXCEPT override {
    return SerializedSize(with_fp16_) + SerializedSize(keep_first_token_) +
           SerializedSize(keep_order_) + SerializedSize(flag_varseqlen_);
  }

  void serialize(void* buffer) const TRT_NOEXCEPT override {
    SerializeValue(&buffer, with_fp16_);
    SerializeValue(&buffer, keep_first_token_);
    SerializeValue(&buffer, keep_order_);
SerializeValue(&buffer, flag_varseqlen_);
  }

  nvinfer1::DimsExprs getOutputDimensions(
@@ -95,17 +106,9 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
  void destroy() TRT_NOEXCEPT override { delete this; }

 private:
- template <typename T>
- int enqueueImpl(const nvinfer1::PluginTensorDesc* input_desc,
-                 const nvinfer1::PluginTensorDesc* output_desc,
-                 const void* const* inputs,
-                 void* const* outputs,
-                 void* workspace,
-                 cudaStream_t stream,
-                 int device_id,
-                 T max_value);
  bool keep_first_token_;
  bool keep_order_;
bool flag_varseqlen_;
};

class FusedTokenPrunePluginDynamicCreator : public nvinfer1::IPluginCreator {
......
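Because flag_varseqlen_ is now part of the serialized state, serialize(), getSerializationSize() and the deserializing constructor above must stay in lock-step (four fields after this commit instead of three), and a deserialized plugin reports 5 outputs when the flag was true, otherwise 2. A small round-trip sketch (buffer handling simplified; `plugin` is assumed to be an existing instance):

// Round-trip sketch: the field order written by serialize() must match the
// order read back by the deserializing constructor.
std::vector<char> buffer(plugin.getSerializationSize());
plugin.serialize(buffer.data());
plugin::FusedTokenPrunePluginDynamic restored(buffer.data(), buffer.size());
// restored.getNbOutputs() is 5 if flag_varseqlen was true, else 2.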
@@ -72,12 +72,16 @@ bool RecoverPaddingPlugin::supportsFormatCombination(
      platform::errors::InvalidArgument("Must have 1 output, "
                                        "but got %d output(s). ",
                                        nbOutputs));
  if (pos == 1) {  // PosId
    return inOut[pos].type == nvinfer1::DataType::kINT32 &&
           inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
} else if (pos == 2) { // mask_id
return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
} else {
return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
  }
- return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
-        inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
  // return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format
  // == nvinfer1::TensorFormat::kLINEAR)||
  // (inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format ==
......
@@ -72,9 +72,10 @@ bool RemovePaddingPlugin::supportsFormatCombination(
  if (pos == 1 || pos == 2) {  // pos_id, word_id
    return inOut[pos].type == nvinfer1::DataType::kINT32 &&
           inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
} else {
return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
  }
- return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
-        inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
  // return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format
  // == nvinfer1::TensorFormat::kLINEAR)||
  // (inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format ==
......
@@ -22,8 +22,10 @@ namespace tensorrt {
namespace plugin {

TEST(fused_token_prune_op_plugin, test_plugin) {
  FusedTokenPrunePluginDynamic plugin(true,
                                      /*keep_first_token*/ false,
/*keep_order*/ true,
/*flag_varseqlen*/ false);
  plugin.configurePlugin(nullptr, 4, nullptr, 2);
  plugin.initialize();
  plugin.getPluginType();
......
@@ -293,8 +293,10 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
  auto *new_mask = engine_->DeclareInput(
      "new_mask", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 2, 2});
  plugin::FusedTokenPrunePluginDynamic *plugin =
      new plugin::FusedTokenPrunePluginDynamic(true,
                                               /*keep_first_token*/ false,
/*keep_order*/ true,
/*flag_varseqlen*/ false);
  std::vector<nvinfer1::ITensor *> itensors = {attn, x, mask, new_mask};
  auto *layer = engine_->AddDynamicPlugin(itensors.data(), 4, plugin);
  PADDLE_ENFORCE_NOT_NULL(layer,
......