提交 954ebda1 编写于 作者: Z zlsh80826

skip_layernorm does not depend on the B (batch) and S (sequence) axes, so the transpose shuffles around the plugin are unnecessary

上级 2e72a0e3
......@@ -78,23 +78,11 @@ class SkipLayerNormOpConverter : public OpConverter {
auto pluginObj =
creator->createPlugin("CustomSkipLayerNormPluginDynamic", pluginPtr);
nvinfer1::Permutation permutation{1, 0, 2, 3, 4};
auto trans_layer0 = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *inputs[0]);
auto trans_layer1 = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *inputs[1]);
trans_layer0->setFirstTranspose(permutation);
trans_layer1->setFirstTranspose(permutation);
std::vector<nvinfer1::ITensor*> trans_tensors;
trans_tensors.emplace_back(trans_layer0->getOutput(0));
trans_tensors.emplace_back(trans_layer1->getOutput(0));
auto plugin_layer = engine_->network()->addPluginV2(
trans_tensors.data(), trans_tensors.size(), *pluginObj);
inputs.data(), inputs.size(), *pluginObj);
assert(plugin_layer != nullptr);
auto trans_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *plugin_layer->getOutput(0));
assert(trans_layer != nullptr);
trans_layer->setFirstTranspose(permutation);
layer = trans_layer;
layer = plugin_layer;
#else
float eps = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon"));
bool ban_fp16 = engine_->disable_trt_plugin_fp16();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册