提交 2e72a0e3 编写于 作者: Z zlsh80826

skip layer norm w/ nvinfer plugin

上级 ea6ff5a2
...@@ -47,17 +47,62 @@ class SkipLayerNormOpConverter : public OpConverter { ...@@ -47,17 +47,62 @@ class SkipLayerNormOpConverter : public OpConverter {
framework::DDim bias_dims, scale_dims; framework::DDim bias_dims, scale_dims;
auto* bias = get_persistable_data("Bias", &bias_dims); auto* bias = get_persistable_data("Bias", &bias_dims);
auto* scale = get_persistable_data("Scale", &scale_dims); auto* scale = get_persistable_data("Scale", &scale_dims);
float eps = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon"));
int bias_size = framework::product(bias_dims); int bias_size = framework::product(bias_dims);
int scale_size = framework::product(scale_dims); int scale_size = framework::product(scale_dims);
nvinfer1::ILayer* layer = nullptr; nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) { if (engine_->with_dynamic_shape()) {
#ifdef USE_NVINFER_PLUGIN
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomSkipLayerNormPluginDynamic", "1");
assert(creator != nullptr);
int type = static_cast<int>((engine_->WithFp16() == 1)
? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT);
int ld = input1->getDimensions().d[2]; // hidden dimension
assert(ld > 0);
const std::vector<nvinfer1::PluginField> fields{
{"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
{"ld", &ld, nvinfer1::PluginFieldType::kINT32, 1},
{"beta", bias, nvinfer1::PluginFieldType::kFLOAT32, bias_size},
{"gamma", scale, nvinfer1::PluginFieldType::kFLOAT32, scale_size},
};
nvinfer1::PluginFieldCollection* pluginPtr =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*pluginPtr) +
fields.size() *
sizeof(nvinfer1::PluginField))); // remember to free
pluginPtr->nbFields = static_cast<int>(fields.size());
pluginPtr->fields = fields.data();
auto pluginObj =
creator->createPlugin("CustomSkipLayerNormPluginDynamic", pluginPtr);
nvinfer1::Permutation permutation{1, 0, 2, 3, 4};
auto trans_layer0 = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *inputs[0]);
auto trans_layer1 = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *inputs[1]);
trans_layer0->setFirstTranspose(permutation);
trans_layer1->setFirstTranspose(permutation);
std::vector<nvinfer1::ITensor*> trans_tensors;
trans_tensors.emplace_back(trans_layer0->getOutput(0));
trans_tensors.emplace_back(trans_layer1->getOutput(0));
auto plugin_layer = engine_->network()->addPluginV2(
trans_tensors.data(), trans_tensors.size(), *pluginObj);
assert(plugin_layer != nullptr);
auto trans_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *plugin_layer->getOutput(0));
assert(trans_layer != nullptr);
trans_layer->setFirstTranspose(permutation);
layer = trans_layer;
#else
float eps = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon"));
bool ban_fp16 = engine_->disable_trt_plugin_fp16(); bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin::SkipLayerNormPluginDynamic* plugin = plugin::SkipLayerNormPluginDynamic* plugin =
new plugin::SkipLayerNormPluginDynamic(bias, scale, bias_size, new plugin::SkipLayerNormPluginDynamic(bias, scale, bias_size,
scale_size, eps, ban_fp16); scale_size, eps, ban_fp16);
layer = engine_->AddPluginV2(inputs.data(), 2, plugin); layer = engine_->AddPluginV2(inputs.data(), 2, plugin);
#endif
} else { } else {
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"You are running the Ernie(Bert) model in static" "You are running the Ernie(Bert) model in static"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册