提交 69dd5152 编写于 作者: Y Yihua Xu 提交者: Tao Luo

Fix the crash issue when scale or bias was null-pointer. (#21284)

* Fix the crash issue when scale or bias was null-pointer.

test=develop

* Add the error message for passing CI.

test=develop
上级 698b8b73
......@@ -210,17 +210,35 @@ class LayerNormKernel : public framework::OpKernel<T> {
ctx, &out, bias, /*axis*/ 1, AddFunctor<T>(), &out);
}
#else
PADDLE_ENFORCE_EQ(mean->numel(), left);
PADDLE_ENFORCE_EQ(var->numel(), left);
PADDLE_ENFORCE_EQ(scale->numel(), right);
PADDLE_ENFORCE_EQ(bias->numel(), right);
PADDLE_ENFORCE_EQ(mean->numel(), left,
platform::errors::InvalidArgument(
"mean's length (%d) is not equal with expected (%d).",
mean->numel(), left));
PADDLE_ENFORCE_EQ(var->numel(), left,
platform::errors::InvalidArgument(
"var's length (%d) is not equal with expected (%d).",
var->numel(), left));
if (scale) {
PADDLE_ENFORCE_EQ(
scale->numel(), right,
platform::errors::InvalidArgument(
"scale's length (%d) is not equal with expected (%d).",
scale->numel(), right));
}
if (bias) {
PADDLE_ENFORCE_EQ(
bias->numel(), right,
platform::errors::InvalidArgument(
"bias's length (%d) is not equal with expected (%d).",
bias->numel(), right));
}
auto ker =
jit::KernelFuncs<jit::LayerNormTuple<T>, platform::CPUPlace>::Cache()
.At(right);
ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
scale->data<T>(), bias->data<T>(), static_cast<int>(left),
static_cast<const float>(epsilon), right);
scale ? scale->data<T>() : nullptr, bias ? bias->data<T>() : nullptr,
static_cast<int>(left), static_cast<const float>(epsilon), right);
#endif
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册