提交 cdb2c4ac 编写于 作者: Z zlsh80826

replace hard core name to tensorrt engine input

上级 d9ad276c
......@@ -114,13 +114,17 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(engine_->GetITensor(
"eval_placeholder_0")); // word_embedding, eval_placeholder_0
engine_->network()->getInput(0)->getName())); // word_embedding,
// eval_placeholder_0
plugin_inputs.emplace_back(engine_->GetITensor(
"eval_placeholder_1")); // sent_embedding, eval_placeholder_1
engine_->network()->getInput(1)->getName())); // sent_embedding,
// eval_placeholder_1
plugin_inputs.emplace_back(engine_->GetITensor(
"eval_placeholder_2")); // cu_seqlens, eval_placeholder_2
engine_->network()->getInput(2)->getName())); // cu_seqlens,
// eval_placeholder_2
plugin_inputs.emplace_back(engine_->GetITensor(
"eval_placeholder_3")); // max_seqlen, eval_placeholder_3
engine_->network()->getInput(3)->getName())); // max_seqlen,
// eval_placeholder_3
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomEmbLayerNormPluginDynamic", "2");
......
......@@ -149,9 +149,11 @@ class MultiheadMatMulOpConverter : public OpConverter {
plugin_inputs.emplace_back(fc_layer->getOutput(0));
plugin_inputs.emplace_back(mask_tensor);
plugin_inputs.emplace_back(engine_->GetITensor(
"eval_placeholder_2")); // cu_seqlens, eval_placeholder_2
engine_->network()->getInput(2)->getName())); // cu_seqlens,
// eval_placeholder_2
plugin_inputs.emplace_back(engine_->GetITensor(
"eval_placeholder_3")); // max_seqlen, eval_placeholder_3
engine_->network()->getInput(3)->getName())); // max_seqlen,
// eval_placeholder_3
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
layer = plugin_layer;
......
......@@ -50,7 +50,9 @@ class SliceOpConverter : public OpConverter {
std::vector<nvinfer1::ITensor*> plugin_inputs;
// plugin_inputs.emplace_back(trans_layer->getOutput(0));
plugin_inputs.emplace_back(input);
plugin_inputs.emplace_back(engine_->GetITensor("eval_placeholder_2"));
plugin_inputs.emplace_back(engine_->GetITensor(
engine_->network()->getInput(2)->getName())); // cu_seqlens,
// eval_placeholder_2
// bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin::SpecialSlicePluginDynamic* plugin =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册