未验证 提交 1bec83f4 编写于 作者: W Wangzheee 提交者: GitHub

disable_skip_layernorm_fp16 (#45041)

上级 9a04540c
...@@ -22,7 +22,8 @@ namespace tensorrt { ...@@ -22,7 +22,8 @@ namespace tensorrt {
class SkipLayerNormOpConverter : public OpConverter { class SkipLayerNormOpConverter : public OpConverter {
public: public:
void operator()(const framework::proto::OpDesc& op, void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override { const framework::Scope& scope,
bool test_mode) override {
#if IS_TRT_VERSION_GE(6000) #if IS_TRT_VERSION_GE(6000)
VLOG(4) << "convert fused skip layernorm op to tensorrt layer"; VLOG(4) << "convert fused skip layernorm op to tensorrt layer";
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
...@@ -63,7 +64,8 @@ class SkipLayerNormOpConverter : public OpConverter { ...@@ -63,7 +64,8 @@ class SkipLayerNormOpConverter : public OpConverter {
auto creator = GetPluginRegistry()->getPluginCreator( auto creator = GetPluginRegistry()->getPluginCreator(
"CustomSkipLayerNormPluginDynamic", "3"); "CustomSkipLayerNormPluginDynamic", "3");
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
creator, nullptr, creator,
nullptr,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"fail to get creator of CustomSkipLayerNormPluginDynamic")); "fail to get creator of CustomSkipLayerNormPluginDynamic"));
const std::vector<nvinfer1::PluginField> fields{ const std::vector<nvinfer1::PluginField> fields{
...@@ -85,7 +87,8 @@ class SkipLayerNormOpConverter : public OpConverter { ...@@ -85,7 +87,8 @@ class SkipLayerNormOpConverter : public OpConverter {
inputs.data(), inputs.size(), *pluginObj); inputs.data(), inputs.size(), *pluginObj);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
plugin_layer, nullptr, plugin_layer,
nullptr,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"fail to add CustomSkipLayerNormPluginDynamic layer")); "fail to add CustomSkipLayerNormPluginDynamic layer"));
layer = plugin_layer; layer = plugin_layer;
...@@ -93,14 +96,16 @@ class SkipLayerNormOpConverter : public OpConverter { ...@@ -93,14 +96,16 @@ class SkipLayerNormOpConverter : public OpConverter {
auto creator = GetPluginRegistry()->getPluginCreator( auto creator = GetPluginRegistry()->getPluginCreator(
"CustomSkipLayerNormPluginDynamic", "2"); "CustomSkipLayerNormPluginDynamic", "2");
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
creator, nullptr, creator,
nullptr,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"fail to get creator of CustomSkipLayerNormPluginDynamic")); "fail to get creator of CustomSkipLayerNormPluginDynamic"));
int type = static_cast<int>((engine_->WithFp16() == 1) int type = static_cast<int>((engine_->WithFp16() == 1)
? nvinfer1::DataType::kHALF ? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT); : nvinfer1::DataType::kFLOAT);
int ld = input1->getDimensions().d[2]; // hidden dimension int ld = input1->getDimensions().d[2]; // hidden dimension
PADDLE_ENFORCE_GT(ld, 0, PADDLE_ENFORCE_GT(ld,
0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"in CustomSkipLayerNormPluginDynamic hidden " "in CustomSkipLayerNormPluginDynamic hidden "
"dimension should > 0")); "dimension should > 0"));
...@@ -128,18 +133,21 @@ class SkipLayerNormOpConverter : public OpConverter { ...@@ -128,18 +133,21 @@ class SkipLayerNormOpConverter : public OpConverter {
inputs.data(), inputs.size(), *pluginObj); inputs.data(), inputs.size(), *pluginObj);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
plugin_layer, nullptr, plugin_layer,
nullptr,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"fail to add CustomSkipLayerNormPluginDynamic layer")); "fail to add CustomSkipLayerNormPluginDynamic layer"));
layer = plugin_layer; layer = plugin_layer;
} }
} else { } else {
float eps = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon")); float eps = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon"));
bool with_fp16 = /* bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
*/
bool with_fp16 = false;
plugin::SkipLayerNormPluginDynamic* plugin = plugin::SkipLayerNormPluginDynamic* plugin =
new plugin::SkipLayerNormPluginDynamic(bias, scale, bias_size, new plugin::SkipLayerNormPluginDynamic(
scale_size, eps, with_fp16); bias, scale, bias_size, scale_size, eps, with_fp16);
layer = engine_->AddDynamicPlugin(inputs.data(), 2, plugin); layer = engine_->AddDynamicPlugin(inputs.data(), 2, plugin);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册