From 8a66d999f179f50a59f54e8c2fccd730b6d17039 Mon Sep 17 00:00:00 2001
From: handiz <35895648+ZhangHandi@users.noreply.github.com>
Date: Tue, 4 Apr 2023 14:32:32 +0800
Subject: [PATCH] change skip-layernorm to adapt a new method (#52456)

* change skip-layernorm to adapt a new method

* fix review problem and add vlog

* fix review problem
---
 .../ir/trt_skip_layernorm_fuse_pass.cc |  5 ++
 .../tensorrt/convert/skip_layernorm.cc | 87 ++++++++++++++-----
 2 files changed, 70 insertions(+), 22 deletions(-)

diff --git a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc
index 6186dd72688..4661d2cbf27 100644
--- a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc
@@ -181,6 +181,11 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
                      layer_norm->Op()->GetAttr("out_threshold"));
   }
 
+  if (layer_norm->Op()->HasAttr("smooth_scale")) {
+    new_desc.SetAttr("smooth_scale",
+                     layer_norm->Op()->GetAttr("smooth_scale"));
+  }
+
   // outputs
   new_desc.SetOutput("Out", {layer_norm_out->Name()});
 
diff --git a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
index fb0463f67e7..681f5798c1d 100644
--- a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
+++ b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
@@ -50,6 +50,15 @@ class SkipLayerNormOpConverter : public OpConverter {
     if (op_desc.HasAttr("enable_int8")) {
       enable_int8 = PADDLE_GET_CONST(bool, op_desc.GetAttr("enable_int8"));
     }
+
+    std::vector<float> smooth_scale;
+    bool use_smooth = false;
+    if (op_desc.HasAttr("smooth_scale")) {
+      smooth_scale =
+          PADDLE_GET_CONST(std::vector<float>, op_desc.GetAttr("smooth_scale"));
+      use_smooth = true;
+    }
+
     auto bias_weight = GetWeight("Bias").get();
     auto scale_weight = GetWeight("Scale").get();
     nvinfer1::ILayer* layer = nullptr;
@@ -121,7 +130,7 @@ class SkipLayerNormOpConverter : public OpConverter {
                             "in CustomSkipLayerNormPluginDynamic hidden "
                             "dimension should > 0"));
 
-      const std::vector<nvinfer1::PluginField> fields{
+      std::vector<nvinfer1::PluginField> fields{
           {"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
           {"ld", &hidden_size, nvinfer1::PluginFieldType::kINT32, 1},
           {"beta",
@@ -133,28 +142,62 @@ class SkipLayerNormOpConverter : public OpConverter {
            GetPluginFieldType(scale_weight.type),
            static_cast<int32_t>(scale_weight.count)},
       };
-      nvinfer1::PluginFieldCollection* pluginPtr =
-          static_cast<nvinfer1::PluginFieldCollection*>(
-              malloc(sizeof(nvinfer1::PluginFieldCollection) +
-                     fields.size() *
-                         sizeof(nvinfer1::PluginField)));  // remember to free
-      pluginPtr->nbFields = static_cast<int32_t>(fields.size());
-      pluginPtr->fields = fields.data();
-
-      auto pluginObj =
-          creator->createPlugin("CustomSkipLayerNormPluginDynamic", pluginPtr);
-      free(pluginPtr);
-
-      auto plugin_layer = engine_->network()->addPluginV2(
-          inputs.data(), inputs.size(), *pluginObj);
-
-      PADDLE_ENFORCE_NE(
-          plugin_layer,
-          nullptr,
-          platform::errors::InvalidArgument(
-              "fail to add CustomSkipLayerNormPluginDynamic layer"));
-      layer = plugin_layer;
+      if (use_smooth) {
+        VLOG(4) << "using special method, make sure you have correct version "
+                   "of tensorrt";
+        type = static_cast<int>(nvinfer1::DataType::kINT8);
+        fields.push_back({"smooth_scale",
+                          smooth_scale.data(),
+                          nvinfer1::PluginFieldType::kFLOAT32,
+                          static_cast<int32_t>(smooth_scale.size())});
+        nvinfer1::PluginFieldCollection* pluginPtr =
+            static_cast<nvinfer1::PluginFieldCollection*>(
+                malloc(sizeof(nvinfer1::PluginFieldCollection) +
+                       fields.size() *
+                           sizeof(nvinfer1::PluginField)));  // remember to free
+        pluginPtr->nbFields = static_cast<int32_t>(fields.size());
+        pluginPtr->fields = fields.data();
+
+        auto pluginObj = creator->createPlugin(
+            "CustomSkipLayerNormPluginDynamicWithSmooth", pluginPtr);
+
+        free(pluginPtr);
+
+        auto plugin_layer = engine_->network()->addPluginV2(
+            inputs.data(), inputs.size(), *pluginObj);
+
+        PADDLE_ENFORCE_NE(
+            plugin_layer,
+            nullptr,
+            platform::errors::InvalidArgument(
+                "fail to add CustomSkipLayerNormPluginDynamicWithSmooth "
+                "layer"));
+        layer = plugin_layer;
+      } else {
+        nvinfer1::PluginFieldCollection* pluginPtr =
+            static_cast<nvinfer1::PluginFieldCollection*>(
+                malloc(sizeof(nvinfer1::PluginFieldCollection) +
+                       fields.size() *
+                           sizeof(nvinfer1::PluginField)));  // remember to free
+        pluginPtr->nbFields = static_cast<int32_t>(fields.size());
+        pluginPtr->fields = fields.data();
+
+        auto pluginObj = creator->createPlugin(
+            "CustomSkipLayerNormPluginDynamic", pluginPtr);
+
+        free(pluginPtr);
+
+        auto plugin_layer = engine_->network()->addPluginV2(
+            inputs.data(), inputs.size(), *pluginObj);
+
+        PADDLE_ENFORCE_NE(
+            plugin_layer,
+            nullptr,
+            platform::errors::InvalidArgument(
+                "fail to add CustomSkipLayerNormPluginDynamic layer"));
+        layer = plugin_layer;
+      }
     }
 
     auto output_name = op_desc.Output("Out")[0];
     RreplenishLayerAndOutput(layer, "skip_layernorm", {output_name}, test_mode);
--
GitLab
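
The core of the converter change is the PluginFieldCollection handshake with
the TensorRT plugin registry: pack named, typed pointers into a malloc'd
collection, ask the registered creator for a plugin by name (the "WithSmooth"
variant when smooth_scale is present), and free the collection once
createPlugin has copied what it needs. The standalone sketch below isolates
that sequence. It is a minimal illustration, not Paddle code: the helper name
BuildSkipLayerNormPlugin, its parameter list, and the "2" version string are
hypothetical; the nvinfer1 types and registry calls are the standard TensorRT
plugin API that the diff relies on.

    // Sketch of the field-collection / createPlugin sequence in the diff.
    // BuildSkipLayerNormPlugin is a hypothetical helper, not a Paddle symbol.
    #include <NvInfer.h>

    #include <cstdlib>
    #include <vector>

    nvinfer1::IPluginV2* BuildSkipLayerNormPlugin(
        std::vector<nvinfer1::PluginField>* fields,  // type_id/ld/beta/gamma
        const std::vector<float>& smooth_scale,      // empty => plain variant
        const char* plugin_version = "2") {          // assumed version string
      const bool use_smooth = !smooth_scale.empty();

      // The smooth variant carries one extra kFLOAT32 field. The field only
      // borrows the pointer, so smooth_scale must outlive createPlugin().
      if (use_smooth) {
        fields->push_back({"smooth_scale",
                           smooth_scale.data(),
                           nvinfer1::PluginFieldType::kFLOAT32,
                           static_cast<int32_t>(smooth_scale.size())});
      }

      const char* name = use_smooth
                             ? "CustomSkipLayerNormPluginDynamicWithSmooth"
                             : "CustomSkipLayerNormPluginDynamic";
      auto* creator =
          getPluginRegistry()->getPluginCreator(name, plugin_version);
      if (creator == nullptr) return nullptr;  // plugin not registered

      // The collection is a throwaway header pointing at the vector's own
      // storage; it can be freed as soon as createPlugin() returns.
      auto* fc = static_cast<nvinfer1::PluginFieldCollection*>(
          malloc(sizeof(nvinfer1::PluginFieldCollection)));
      fc->nbFields = static_cast<int32_t>(fields->size());
      fc->fields = fields->data();

      nvinfer1::IPluginV2* plugin = creator->createPlugin(name, fc);
      free(fc);  // safe: the creator copies what it needs
      return plugin;
    }

One quirk the patch preserves rather than introduces: the converter
over-allocates the collection (sizeof(PluginFieldCollection) plus room for
the fields) yet still points fc->fields at the vector's storage, so the extra
bytes go unused; the "remember to free" comment exists because the C-style
TensorRT struct has no owning destructor. The sketch allocates only the
header to make the ownership boundary explicit.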