未验证 提交 12bf046b 编写于 作者: W wawltor 提交者: GitHub

add the safe check for the some ops (#34978)

上级 52a7b0c4
......@@ -36,22 +36,36 @@ class DiagEmbedOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("Input");
int dim1_ = dim1 < 0 ? x_dims.size() + dim1 + 1 : dim1;
int dim2_ = dim2 < 0 ? x_dims.size() + dim2 + 1 : dim2;
int offset_ = std::abs(offset);
PADDLE_ENFORCE_GE(
dim1, -(x_dims.size() + 1),
platform::errors::OutOfRange(
"Dim1 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size() + 1), x_dims.size(), dim1));
PADDLE_ENFORCE_LE(
dim1_, x_dims.size(),
dim1, x_dims.size(),
platform::errors::OutOfRange(
"Dim1 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size() + 1), x_dims.size(), dim1));
PADDLE_ENFORCE_GE(
dim2, -(x_dims.size() + 1),
platform::errors::OutOfRange(
"Dim2 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size() + 1), x_dims.size(), dim2));
PADDLE_ENFORCE_LE(
dim2_, x_dims.size(),
dim2, x_dims.size(),
platform::errors::OutOfRange(
"Dim2 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size() + 1), x_dims.size(), dim2));
int dim1_ = dim1 < 0 ? x_dims.size() + dim1 + 1 : dim1;
int dim2_ = dim2 < 0 ? x_dims.size() + dim2 + 1 : dim2;
int offset_ = std::abs(offset);
PADDLE_ENFORCE_NE(dim1_, dim2_,
platform::errors::InvalidArgument(
"diagonal dimensions should not be identical "
......
......@@ -330,6 +330,12 @@ framework::DDim GetDimForInput(const framework::InferShapeContext &ctx,
auto axis =
ctx.Attrs().Get<std::vector<int>>("fused_transpose_" + input_name);
auto dim = ctx.GetInputDim(input_name);
PADDLE_ENFORCE_GT(dim.size(), 0,
platform::errors::InvalidArgument(
"The Input(%s) has not been initialized properly. The "
"shape of Input(%s) = [%s].",
dim));
if (!shape.empty() && !axis.empty()) {
PADDLE_ENFORCE_GE(
shape.size(), 2,
......
......@@ -35,6 +35,14 @@ class MatMulV2Op : public framework::OperatorWithKernel {
paddle::framework::vectorize(ctx->GetInputDim("Y"));
auto ndims_x = dims_x.size();
auto ndims_y = dims_y.size();
PADDLE_ENFORCE_GT(ndims_x, 0,
platform::errors::InvalidArgument(
"The Input(X) dims size must be greater than 0,"
" but reviced dims size is 0. "));
PADDLE_ENFORCE_GT(ndims_y, 0,
platform::errors::InvalidArgument(
"The Input(Y) dims size must be greater than 0,"
" but reviced dims size is 0. "));
bool x_broadcasted = false, y_broadcasted = false;
if (ndims_x == 1) {
......
......@@ -25,7 +25,21 @@ class AucOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Predict"), "Input", "Predict", "Auc");
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "Auc");
auto predict_width = ctx->GetInputDim("Predict")[1];
auto predict_dims = ctx->GetInputDim("Predict");
auto label_dims = ctx->GetInputDim("Label");
auto predict_width = predict_dims[1];
PADDLE_ENFORCE_NE(
framework::product(predict_dims), 0,
platform::errors::InvalidArgument(
"The Input(Predict) has not been initialized properly. The "
"shape of Input(Predict) = [%s], the shape can not involes 0.",
predict_dims));
PADDLE_ENFORCE_NE(
framework::product(label_dims), 0,
platform::errors::InvalidArgument(
"The Input(Label) has not been initialized properly. The "
"shape of Input(Label) = [%s], the shape can not involes 0.",
label_dims));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_LE(predict_width, 2,
platform::errors::InvalidArgument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册