From 954ebda134e3a5fc1a089a25a0fd96c9220595b7 Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Tue, 25 Aug 2020 15:27:34 +0800 Subject: [PATCH] skip_layernorm is not related to B, S axis --- .../inference/tensorrt/convert/skip_layernorm.cc | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc index 60c9763ecaa..20b471a4c2a 100644 --- a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc @@ -78,23 +78,11 @@ class SkipLayerNormOpConverter : public OpConverter { auto pluginObj = creator->createPlugin("CustomSkipLayerNormPluginDynamic", pluginPtr); - nvinfer1::Permutation permutation{1, 0, 2, 3, 4}; - auto trans_layer0 = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *inputs[0]); - auto trans_layer1 = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *inputs[1]); - trans_layer0->setFirstTranspose(permutation); - trans_layer1->setFirstTranspose(permutation); - std::vector trans_tensors; - trans_tensors.emplace_back(trans_layer0->getOutput(0)); - trans_tensors.emplace_back(trans_layer1->getOutput(0)); auto plugin_layer = engine_->network()->addPluginV2( - trans_tensors.data(), trans_tensors.size(), *pluginObj); + inputs.data(), inputs.size(), *pluginObj); assert(plugin_layer != nullptr); - auto trans_layer = - TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *plugin_layer->getOutput(0)); - assert(trans_layer != nullptr); - trans_layer->setFirstTranspose(permutation); - layer = trans_layer; + layer = plugin_layer; #else float eps = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon")); bool ban_fp16 = engine_->disable_trt_plugin_fp16(); -- GitLab