未验证 提交 8a66d999 编写于 作者: H handiz 提交者: GitHub

change skip-layernorm to adapt a new method (#52456)

* change skip-layernorm to adapt a new method

* fix review problem and add vlog

* fix review problem
上级 67a6dd32
......@@ -181,6 +181,11 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
layer_norm->Op()->GetAttr("out_threshold"));
}
if (layer_norm->Op()->HasAttr("smooth_scale")) {
new_desc.SetAttr("smooth_scale",
layer_norm->Op()->GetAttr("smooth_scale"));
}
// outputs
new_desc.SetOutput("Out", {layer_norm_out->Name()});
......
......@@ -50,6 +50,15 @@ class SkipLayerNormOpConverter : public OpConverter {
if (op_desc.HasAttr("enable_int8")) {
enable_int8 = PADDLE_GET_CONST(bool, op_desc.GetAttr("enable_int8"));
}
std::vector<float> smooth_scale;
bool use_smooth = false;
if (op_desc.HasAttr("smooth_scale")) {
smooth_scale =
PADDLE_GET_CONST(std::vector<float>, op_desc.GetAttr("smooth_scale"));
use_smooth = true;
}
auto bias_weight = GetWeight("Bias").get();
auto scale_weight = GetWeight("Scale").get();
nvinfer1::ILayer* layer = nullptr;
......@@ -121,7 +130,7 @@ class SkipLayerNormOpConverter : public OpConverter {
"in CustomSkipLayerNormPluginDynamic hidden "
"dimension should > 0"));
const std::vector<nvinfer1::PluginField> fields{
std::vector<nvinfer1::PluginField> fields{
{"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
{"ld", &hidden_size, nvinfer1::PluginFieldType::kINT32, 1},
{"beta",
......@@ -133,6 +142,15 @@ class SkipLayerNormOpConverter : public OpConverter {
GetPluginFieldType(scale_weight.type),
static_cast<int32_t>(scale_weight.count)},
};
if (use_smooth) {
VLOG(4) << "using special method, make sure you have correct version "
"of tensorrt";
type = static_cast<int32_t>(nvinfer1::DataType::kINT8);
fields.push_back({"smooth_scale",
smooth_scale.data(),
nvinfer1::PluginFieldType::kFLOAT32,
static_cast<int32_t>(smooth_scale.size())});
nvinfer1::PluginFieldCollection* pluginPtr =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(nvinfer1::PluginFieldCollection) +
......@@ -141,8 +159,32 @@ class SkipLayerNormOpConverter : public OpConverter {
pluginPtr->nbFields = static_cast<int32_t>(fields.size());
pluginPtr->fields = fields.data();
auto pluginObj =
creator->createPlugin("CustomSkipLayerNormPluginDynamic", pluginPtr);
auto pluginObj = creator->createPlugin(
"CustomSkipLayerNormPluginDynamicWithSmooth", pluginPtr);
free(pluginPtr);
auto plugin_layer = engine_->network()->addPluginV2(
inputs.data(), inputs.size(), *pluginObj);
PADDLE_ENFORCE_NE(
plugin_layer,
nullptr,
platform::errors::InvalidArgument(
"fail to add CustomSkipLayerNormPluginDynamicWithSmooth "
"layer"));
layer = plugin_layer;
} else {
nvinfer1::PluginFieldCollection* pluginPtr =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(nvinfer1::PluginFieldCollection) +
fields.size() *
sizeof(nvinfer1::PluginField))); // remember to free
pluginPtr->nbFields = static_cast<int32_t>(fields.size());
pluginPtr->fields = fields.data();
auto pluginObj = creator->createPlugin(
"CustomSkipLayerNormPluginDynamic", pluginPtr);
free(pluginPtr);
......@@ -156,6 +198,7 @@ class SkipLayerNormOpConverter : public OpConverter {
"fail to add CustomSkipLayerNormPluginDynamic layer"));
layer = plugin_layer;
}
}
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "skip_layernorm", {output_name}, test_mode);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册