From 843581154f26d82389269bd0b6e04d7b632fcc7e Mon Sep 17 00:00:00 2001 From: Zhaolong Xing Date: Thu, 18 Jun 2020 17:07:25 +0800 Subject: [PATCH] fix emb eltwise layernorm (#24873) test=develop --- .../plugin/emb_eltwise_layernorm_plugin.cu | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu index 175bc8c7945..575dfa68e6e 100644 --- a/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu @@ -90,12 +90,6 @@ nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic::getOutputDimensions( "so the index should be zero," "but it's (%d)", output_index)); - PADDLE_ENFORCE_EQ( - nb_inputs, 3, - platform::errors::InvalidArgument( - "The Input of the EmbEltwiseLayernorm should be 3, but we found " - "it has (%d) inputs", - nb_inputs)); nvinfer1::DimsExprs ret; ret.nbDims = 5; ret.d[0] = inputs[0].d[0]; @@ -113,13 +107,18 @@ bool EmbEltwiseLayernormPluginDynamic::supportsFormatCombination( PADDLE_ENFORCE_NOT_NULL( in_out, platform::errors::InvalidArgument( "The input of swish plugin shoule not be nullptr.")); - + PADDLE_ENFORCE_EQ(nb_outputs, 1, + platform::errors::InvalidArgument( + "The EmbEltwiseLayerNorm's output should be one" + "but it's (%d) outputs.", + nb_outputs)); PADDLE_ENFORCE_LT( pos, nb_inputs + nb_outputs, platform::errors::InvalidArgument("The pos(%d) should be less than the " "num(%d) of the input and the output.", pos, nb_inputs + nb_outputs)); - (in_out && pos < (nb_inputs + nb_outputs)); + + int all_nums = nb_inputs + nb_outputs; const nvinfer1::PluginTensorDesc &desc = in_out[pos]; if (desc.format != nvinfer1::TensorFormat::kLINEAR) { @@ -131,18 +130,19 @@ bool EmbEltwiseLayernormPluginDynamic::supportsFormatCombination( } const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1]; - if (pos == 1 || pos == 2) { + if (pos < all_nums - 1) { return desc.type == nvinfer1::DataType::kINT32 && desc.dims.d[0] == prev.dims.d[0] && desc.dims.d[1] == prev.dims.d[1]; } - if (pos == 3) { + if (pos == all_nums - 1) { if (sizeof(T) == sizeof(float)) { return desc.type == nvinfer1::DataType::kFLOAT; } else { return desc.type == nvinfer1::DataType::kHALF; } } + return false; } template -- GitLab