diff --git a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc index 60c9763ecaa00e47644561e07cc21c12375f42b3..20b471a4c2af4edbff04368f4609ac4865ae1513 100644 --- a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc @@ -78,23 +78,11 @@ class SkipLayerNormOpConverter : public OpConverter { auto pluginObj = creator->createPlugin("CustomSkipLayerNormPluginDynamic", pluginPtr); - nvinfer1::Permutation permutation{1, 0, 2, 3, 4}; - auto trans_layer0 = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *inputs[0]); - auto trans_layer1 = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *inputs[1]); - trans_layer0->setFirstTranspose(permutation); - trans_layer1->setFirstTranspose(permutation); - std::vector trans_tensors; - trans_tensors.emplace_back(trans_layer0->getOutput(0)); - trans_tensors.emplace_back(trans_layer1->getOutput(0)); auto plugin_layer = engine_->network()->addPluginV2( - trans_tensors.data(), trans_tensors.size(), *pluginObj); + inputs.data(), inputs.size(), *pluginObj); assert(plugin_layer != nullptr); - auto trans_layer = - TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *plugin_layer->getOutput(0)); - assert(trans_layer != nullptr); - trans_layer->setFirstTranspose(permutation); - layer = trans_layer; + layer = plugin_layer; #else float eps = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon")); bool ban_fp16 = engine_->disable_trt_plugin_fp16();