diff --git a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc
index 6186dd72688e74d6b365f484347b98929eb04e70..4661d2cbf275664aeaf0b1ea20fe2656ad8c7ac6 100644
--- a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc
@@ -181,6 +181,11 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
                      layer_norm->Op()->GetAttr("out_threshold"));
   }
 
+  if (layer_norm->Op()->HasAttr("smooth_scale")) {
+    new_desc.SetAttr("smooth_scale",
+                     layer_norm->Op()->GetAttr("smooth_scale"));
+  }
+
   // outputs
   new_desc.SetOutput("Out", {layer_norm_out->Name()});
diff --git a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
index fb0463f67e7cbcce5d20f23bfc7451e555dcb7c5..681f5798c1da09f6e5e221377b9c5cacbfa15bf0 100644
--- a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
+++ b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
@@ -50,6 +50,15 @@ class SkipLayerNormOpConverter : public OpConverter {
     if (op_desc.HasAttr("enable_int8")) {
       enable_int8 = PADDLE_GET_CONST(bool, op_desc.GetAttr("enable_int8"));
     }
+
+    std::vector<float> smooth_scale;
+    bool use_smooth = false;
+    if (op_desc.HasAttr("smooth_scale")) {
+      smooth_scale =
+          PADDLE_GET_CONST(std::vector<float>, op_desc.GetAttr("smooth_scale"));
+      use_smooth = true;
+    }
+
     auto bias_weight = GetWeight("Bias").get();
     auto scale_weight = GetWeight("Scale").get();
     nvinfer1::ILayer* layer = nullptr;
@@ -121,7 +130,7 @@ class SkipLayerNormOpConverter : public OpConverter {
                             "in CustomSkipLayerNormPluginDynamic hidden "
                             "dimension should > 0"));
 
-      const std::vector<nvinfer1::PluginField> fields{
+      std::vector<nvinfer1::PluginField> fields{
          {"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
          {"ld", &hidden_size, nvinfer1::PluginFieldType::kINT32, 1},
          {"beta",
@@ -133,28 +142,62 @@ class SkipLayerNormOpConverter : public OpConverter {
           GetPluginFieldType(scale_weight.type),
           static_cast<int32_t>(scale_weight.count)},
       };
-      nvinfer1::PluginFieldCollection* pluginPtr =
-          static_cast<nvinfer1::PluginFieldCollection*>(
-              malloc(sizeof(nvinfer1::PluginFieldCollection) +
-                     fields.size() *
-                         sizeof(nvinfer1::PluginField)));  // remember to free
-      pluginPtr->nbFields = static_cast<int>(fields.size());
-      pluginPtr->fields = fields.data();
-
-      auto pluginObj =
-          creator->createPlugin("CustomSkipLayerNormPluginDynamic", pluginPtr);
-      free(pluginPtr);
-
-      auto plugin_layer = engine_->network()->addPluginV2(
-          inputs.data(), inputs.size(), *pluginObj);
-
-      PADDLE_ENFORCE_NE(
-          plugin_layer,
-          nullptr,
-          platform::errors::InvalidArgument(
-              "fail to add CustomSkipLayerNormPluginDynamic layer"));
-      layer = plugin_layer;
+      if (use_smooth) {
+        VLOG(4) << "using special method, make sure you have correct version "
+                   "of tensorrt";
+        type = static_cast<int>(nvinfer1::DataType::kINT8);
+        fields.push_back({"smooth_scale",
+                          smooth_scale.data(),
+                          nvinfer1::PluginFieldType::kFLOAT32,
+                          static_cast<int32_t>(smooth_scale.size())});
+        nvinfer1::PluginFieldCollection* pluginPtr =
+            static_cast<nvinfer1::PluginFieldCollection*>(
+                malloc(sizeof(nvinfer1::PluginFieldCollection) +
+                       fields.size() *
+                           sizeof(nvinfer1::PluginField)));  // remember to free
+        pluginPtr->nbFields = static_cast<int>(fields.size());
+        pluginPtr->fields = fields.data();
+
+        auto pluginObj = creator->createPlugin(
+            "CustomSkipLayerNormPluginDynamicWithSmooth", pluginPtr);
+
+        free(pluginPtr);
+
+        auto plugin_layer = engine_->network()->addPluginV2(
+            inputs.data(), inputs.size(), *pluginObj);
+
+        PADDLE_ENFORCE_NE(
+            plugin_layer,
+            nullptr,
+            platform::errors::InvalidArgument(
+                "fail to add CustomSkipLayerNormPluginDynamicWithSmooth "
+                "layer"));
+        layer = plugin_layer;
+      } else {
+        nvinfer1::PluginFieldCollection* pluginPtr =
+            static_cast<nvinfer1::PluginFieldCollection*>(
+                malloc(sizeof(nvinfer1::PluginFieldCollection) +
+                       fields.size() *
+                           sizeof(nvinfer1::PluginField)));  // remember to free
+        pluginPtr->nbFields = static_cast<int>(fields.size());
+        pluginPtr->fields = fields.data();
+
+        auto pluginObj = creator->createPlugin(
+            "CustomSkipLayerNormPluginDynamic", pluginPtr);
+
+        free(pluginPtr);
+
+        auto plugin_layer = engine_->network()->addPluginV2(
+            inputs.data(), inputs.size(), *pluginObj);
+
+        PADDLE_ENFORCE_NE(
+            plugin_layer,
+            nullptr,
+            platform::errors::InvalidArgument(
+                "fail to add CustomSkipLayerNormPluginDynamic layer"));
+        layer = plugin_layer;
+      }
     }
     auto output_name = op_desc.Output("Out")[0];
     RreplenishLayerAndOutput(layer, "skip_layernorm", {output_name}, test_mode);
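
Note on the fuse-pass hunk: the new smooth_scale block follows the same HasAttr/SetAttr pattern already used for out_threshold. A minimal sketch (an assumption, not part of this change) of how the repeated pairs could be folded into one loop over the optional attributes that must survive fusion:

  // Hypothetical refactor: forward each optional attribute only if present.
  // HasAttr/GetAttr/SetAttr are the same OpDesc calls used in the hunk above.
  for (const char *attr : {"out_threshold", "smooth_scale"}) {
    if (layer_norm->Op()->HasAttr(attr)) {
      new_desc.SetAttr(attr, layer_norm->Op()->GetAttr(attr));
    }
  }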
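
Note on the converter hunk: the use_smooth and fallback branches differ only in the plugin name; the malloc'd PluginFieldCollection, the field wiring, addPluginV2, and the enforce are duplicated. The malloc also over-allocates (the collection only points at the vector's storage; the trailing fields.size() * sizeof(PluginField) bytes are never used). A minimal sketch of a hypothetical helper, CreateSkipLayerNormPlugin (not in this change), that builds the collection on the stack instead:

  #include <NvInfer.h>

  #include <vector>

  // Hypothetical helper: wire the fields into a stack-allocated collection and
  // add the plugin layer. `fields` must stay alive until createPlugin()
  // returns, because the collection merely points at its storage.
  static nvinfer1::IPluginV2Layer *CreateSkipLayerNormPlugin(
      nvinfer1::IPluginCreator *creator,
      nvinfer1::INetworkDefinition *network,
      const std::vector<nvinfer1::ITensor *> &inputs,
      const std::vector<nvinfer1::PluginField> &fields,
      const char *plugin_name) {
    nvinfer1::PluginFieldCollection fc;
    fc.nbFields = static_cast<int32_t>(fields.size());
    fc.fields = fields.data();
    nvinfer1::IPluginV2 *plugin = creator->createPlugin(plugin_name, &fc);
    if (plugin == nullptr) return nullptr;
    return network->addPluginV2(
        inputs.data(), static_cast<int32_t>(inputs.size()), *plugin);
  }

Both branches would then reduce to a single call, keeping the existing PADDLE_ENFORCE_NE check on the returned layer:

  auto *plugin_layer = CreateSkipLayerNormPlugin(
      creator,
      engine_->network(),
      inputs,
      fields,
      use_smooth ? "CustomSkipLayerNormPluginDynamicWithSmooth"
                 : "CustomSkipLayerNormPluginDynamic");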