From 69dd5152cff75b1f595952cb9360e25739966150 Mon Sep 17 00:00:00 2001 From: Yihua Xu Date: Fri, 22 Nov 2019 12:00:26 +0800 Subject: [PATCH] Fix the crash issue when scale or bias was null-pointer. (#21284) * Fix the crash issue when scale or bias was null-pointer. test=develop * Add the error message for passing CI. test=develop --- paddle/fluid/operators/layer_norm_op.h | 30 ++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/layer_norm_op.h b/paddle/fluid/operators/layer_norm_op.h index ff2935b0b45..5907d1d7278 100644 --- a/paddle/fluid/operators/layer_norm_op.h +++ b/paddle/fluid/operators/layer_norm_op.h @@ -210,17 +210,35 @@ class LayerNormKernel : public framework::OpKernel { ctx, &out, bias, /*axis*/ 1, AddFunctor(), &out); } #else - PADDLE_ENFORCE_EQ(mean->numel(), left); - PADDLE_ENFORCE_EQ(var->numel(), left); - PADDLE_ENFORCE_EQ(scale->numel(), right); - PADDLE_ENFORCE_EQ(bias->numel(), right); + PADDLE_ENFORCE_EQ(mean->numel(), left, + platform::errors::InvalidArgument( + "mean's length (%d) is not equal with expected (%d).", + mean->numel(), left)); + PADDLE_ENFORCE_EQ(var->numel(), left, + platform::errors::InvalidArgument( + "var's length (%d) is not equal with expected (%d).", + var->numel(), left)); + if (scale) { + PADDLE_ENFORCE_EQ( + scale->numel(), right, + platform::errors::InvalidArgument( + "scale's length (%d) is not equal with expected (%d).", + scale->numel(), right)); + } + if (bias) { + PADDLE_ENFORCE_EQ( + bias->numel(), right, + platform::errors::InvalidArgument( + "bias's length (%d) is not equal with expected (%d).", + bias->numel(), right)); + } auto ker = jit::KernelFuncs, platform::CPUPlace>::Cache() .At(right); ker(x.data(), out.data(), mean->data(), var->data(), - scale->data(), bias->data(), static_cast(left), - static_cast(epsilon), right); + scale ? scale->data() : nullptr, bias ? bias->data() : nullptr, + static_cast(left), static_cast(epsilon), right); #endif } }; -- GitLab