未验证 提交 84358115 编写于 作者: Z Zhaolong Xing 提交者: GitHub

fix emb eltwise layernorm (#24873)

test=develop
上级 a01113c3
......@@ -90,12 +90,6 @@ nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic<T>::getOutputDimensions(
"so the index should be zero,"
"but it's (%d)",
output_index));
PADDLE_ENFORCE_EQ(
nb_inputs, 3,
platform::errors::InvalidArgument(
"The Input of the EmbEltwiseLayernorm should be 3, but we found "
"it has (%d) inputs",
nb_inputs));
nvinfer1::DimsExprs ret;
ret.nbDims = 5;
ret.d[0] = inputs[0].d[0];
......@@ -113,13 +107,18 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
PADDLE_ENFORCE_NOT_NULL(
in_out, platform::errors::InvalidArgument(
"The input of swish plugin shoule not be nullptr."));
PADDLE_ENFORCE_EQ(nb_outputs, 1,
platform::errors::InvalidArgument(
"The EmbEltwiseLayerNorm's output should be one"
"but it's (%d) outputs.",
nb_outputs));
PADDLE_ENFORCE_LT(
pos, nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos, nb_inputs + nb_outputs));
(in_out && pos < (nb_inputs + nb_outputs));
int all_nums = nb_inputs + nb_outputs;
const nvinfer1::PluginTensorDesc &desc = in_out[pos];
if (desc.format != nvinfer1::TensorFormat::kLINEAR) {
......@@ -131,18 +130,19 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
}
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
if (pos == 1 || pos == 2) {
if (pos < all_nums - 1) {
return desc.type == nvinfer1::DataType::kINT32 &&
desc.dims.d[0] == prev.dims.d[0] && desc.dims.d[1] == prev.dims.d[1];
}
if (pos == 3) {
if (pos == all_nums - 1) {
if (sizeof(T) == sizeof(float)) {
return desc.type == nvinfer1::DataType::kFLOAT;
} else {
return desc.type == nvinfer1::DataType::kHALF;
}
}
return false;
}
template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册