未验证 提交 84358115 编写于 作者: Z Zhaolong Xing 提交者: GitHub

fix emb eltwise layernorm (#24873)

test=develop
上级 a01113c3
...@@ -90,12 +90,6 @@ nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic<T>::getOutputDimensions( ...@@ -90,12 +90,6 @@ nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic<T>::getOutputDimensions(
"so the index should be zero," "so the index should be zero,"
"but it's (%d)", "but it's (%d)",
output_index)); output_index));
PADDLE_ENFORCE_EQ(
nb_inputs, 3,
platform::errors::InvalidArgument(
"The Input of the EmbEltwiseLayernorm should be 3, but we found "
"it has (%d) inputs",
nb_inputs));
nvinfer1::DimsExprs ret; nvinfer1::DimsExprs ret;
ret.nbDims = 5; ret.nbDims = 5;
ret.d[0] = inputs[0].d[0]; ret.d[0] = inputs[0].d[0];
...@@ -113,13 +107,18 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination( ...@@ -113,13 +107,18 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
in_out, platform::errors::InvalidArgument( in_out, platform::errors::InvalidArgument(
"The input of swish plugin shoule not be nullptr.")); "The input of swish plugin shoule not be nullptr."));
PADDLE_ENFORCE_EQ(nb_outputs, 1,
platform::errors::InvalidArgument(
"The EmbEltwiseLayerNorm's output should be one"
"but it's (%d) outputs.",
nb_outputs));
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
pos, nb_inputs + nb_outputs, pos, nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the " platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.", "num(%d) of the input and the output.",
pos, nb_inputs + nb_outputs)); pos, nb_inputs + nb_outputs));
(in_out && pos < (nb_inputs + nb_outputs));
int all_nums = nb_inputs + nb_outputs;
const nvinfer1::PluginTensorDesc &desc = in_out[pos]; const nvinfer1::PluginTensorDesc &desc = in_out[pos];
if (desc.format != nvinfer1::TensorFormat::kLINEAR) { if (desc.format != nvinfer1::TensorFormat::kLINEAR) {
...@@ -131,18 +130,19 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination( ...@@ -131,18 +130,19 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
} }
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1]; const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
if (pos == 1 || pos == 2) { if (pos < all_nums - 1) {
return desc.type == nvinfer1::DataType::kINT32 && return desc.type == nvinfer1::DataType::kINT32 &&
desc.dims.d[0] == prev.dims.d[0] && desc.dims.d[1] == prev.dims.d[1]; desc.dims.d[0] == prev.dims.d[0] && desc.dims.d[1] == prev.dims.d[1];
} }
if (pos == 3) { if (pos == all_nums - 1) {
if (sizeof(T) == sizeof(float)) { if (sizeof(T) == sizeof(float)) {
return desc.type == nvinfer1::DataType::kFLOAT; return desc.type == nvinfer1::DataType::kFLOAT;
} else { } else {
return desc.type == nvinfer1::DataType::kHALF; return desc.type == nvinfer1::DataType::kHALF;
} }
} }
return false;
} }
template <typename T> template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册