未验证 提交 1bec83f4 编写于 作者: W Wangzheee 提交者: GitHub

disable_skip_layernorm_fp16 (#45041)

上级 9a04540c
......@@ -22,7 +22,8 @@ namespace tensorrt {
class SkipLayerNormOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
const framework::Scope& scope,
bool test_mode) override {
#if IS_TRT_VERSION_GE(6000)
VLOG(4) << "convert fused skip layernorm op to tensorrt layer";
framework::OpDesc op_desc(op, nullptr);
......@@ -63,7 +64,8 @@ class SkipLayerNormOpConverter : public OpConverter {
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomSkipLayerNormPluginDynamic", "3");
PADDLE_ENFORCE_NE(
creator, nullptr,
creator,
nullptr,
platform::errors::InvalidArgument(
"fail to get creator of CustomSkipLayerNormPluginDynamic"));
const std::vector<nvinfer1::PluginField> fields{
......@@ -85,7 +87,8 @@ class SkipLayerNormOpConverter : public OpConverter {
inputs.data(), inputs.size(), *pluginObj);
PADDLE_ENFORCE_NE(
plugin_layer, nullptr,
plugin_layer,
nullptr,
platform::errors::InvalidArgument(
"fail to add CustomSkipLayerNormPluginDynamic layer"));
layer = plugin_layer;
......@@ -93,14 +96,16 @@ class SkipLayerNormOpConverter : public OpConverter {
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomSkipLayerNormPluginDynamic", "2");
PADDLE_ENFORCE_NE(
creator, nullptr,
creator,
nullptr,
platform::errors::InvalidArgument(
"fail to get creator of CustomSkipLayerNormPluginDynamic"));
int type = static_cast<int>((engine_->WithFp16() == 1)
? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT);
int ld = input1->getDimensions().d[2]; // hidden dimension
PADDLE_ENFORCE_GT(ld, 0,
PADDLE_ENFORCE_GT(ld,
0,
platform::errors::InvalidArgument(
"in CustomSkipLayerNormPluginDynamic hidden "
"dimension should > 0"));
......@@ -128,18 +133,21 @@ class SkipLayerNormOpConverter : public OpConverter {
inputs.data(), inputs.size(), *pluginObj);
PADDLE_ENFORCE_NE(
plugin_layer, nullptr,
plugin_layer,
nullptr,
platform::errors::InvalidArgument(
"fail to add CustomSkipLayerNormPluginDynamic layer"));
layer = plugin_layer;
}
} else {
float eps = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon"));
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
/* bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
*/
bool with_fp16 = false;
plugin::SkipLayerNormPluginDynamic* plugin =
new plugin::SkipLayerNormPluginDynamic(bias, scale, bias_size,
scale_size, eps, with_fp16);
new plugin::SkipLayerNormPluginDynamic(
bias, scale, bias_size, scale_size, eps, with_fp16);
layer = engine_->AddDynamicPlugin(inputs.data(), 2, plugin);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册