提交 cdb2c4ac 编写于 作者: Z zlsh80826

Replace hard-coded placeholder names with the TensorRT engine's network input names

上级 d9ad276c
...@@ -114,13 +114,17 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -114,13 +114,17 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
std::vector<nvinfer1::ITensor*> plugin_inputs; std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(engine_->GetITensor( plugin_inputs.emplace_back(engine_->GetITensor(
"eval_placeholder_0")); // word_embedding, eval_placeholder_0 engine_->network()->getInput(0)->getName())); // word_embedding,
// eval_placeholder_0
plugin_inputs.emplace_back(engine_->GetITensor( plugin_inputs.emplace_back(engine_->GetITensor(
"eval_placeholder_1")); // sent_embedding, eval_placeholder_1 engine_->network()->getInput(1)->getName())); // sent_embedding,
// eval_placeholder_1
plugin_inputs.emplace_back(engine_->GetITensor( plugin_inputs.emplace_back(engine_->GetITensor(
"eval_placeholder_2")); // cu_seqlens, eval_placeholder_2 engine_->network()->getInput(2)->getName())); // cu_seqlens,
// eval_placeholder_2
plugin_inputs.emplace_back(engine_->GetITensor( plugin_inputs.emplace_back(engine_->GetITensor(
"eval_placeholder_3")); // max_seqlen, eval_placeholder_3 engine_->network()->getInput(3)->getName())); // max_seqlen,
// eval_placeholder_3
auto creator = GetPluginRegistry()->getPluginCreator( auto creator = GetPluginRegistry()->getPluginCreator(
"CustomEmbLayerNormPluginDynamic", "2"); "CustomEmbLayerNormPluginDynamic", "2");
......
...@@ -149,9 +149,11 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -149,9 +149,11 @@ class MultiheadMatMulOpConverter : public OpConverter {
plugin_inputs.emplace_back(fc_layer->getOutput(0)); plugin_inputs.emplace_back(fc_layer->getOutput(0));
plugin_inputs.emplace_back(mask_tensor); plugin_inputs.emplace_back(mask_tensor);
plugin_inputs.emplace_back(engine_->GetITensor( plugin_inputs.emplace_back(engine_->GetITensor(
"eval_placeholder_2")); // cu_seqlens, eval_placeholder_2 engine_->network()->getInput(2)->getName())); // cu_seqlens,
// eval_placeholder_2
plugin_inputs.emplace_back(engine_->GetITensor( plugin_inputs.emplace_back(engine_->GetITensor(
"eval_placeholder_3")); // max_seqlen, eval_placeholder_3 engine_->network()->getInput(3)->getName())); // max_seqlen,
// eval_placeholder_3
auto plugin_layer = engine_->network()->addPluginV2( auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin); plugin_inputs.data(), plugin_inputs.size(), *plugin);
layer = plugin_layer; layer = plugin_layer;
......
...@@ -50,7 +50,9 @@ class SliceOpConverter : public OpConverter { ...@@ -50,7 +50,9 @@ class SliceOpConverter : public OpConverter {
std::vector<nvinfer1::ITensor*> plugin_inputs; std::vector<nvinfer1::ITensor*> plugin_inputs;
// plugin_inputs.emplace_back(trans_layer->getOutput(0)); // plugin_inputs.emplace_back(trans_layer->getOutput(0));
plugin_inputs.emplace_back(input); plugin_inputs.emplace_back(input);
plugin_inputs.emplace_back(engine_->GetITensor("eval_placeholder_2")); plugin_inputs.emplace_back(engine_->GetITensor(
engine_->network()->getInput(2)->getName())); // cu_seqlens,
// eval_placeholder_2
// bool ban_fp16 = engine_->disable_trt_plugin_fp16(); // bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin::SpecialSlicePluginDynamic* plugin = plugin::SpecialSlicePluginDynamic* plugin =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册