From 12bf046b951b7d6a357ebdbbfabf9412c7c29332 Mon Sep 17 00:00:00 2001 From: wawltor Date: Wed, 18 Aug 2021 11:13:42 +0800 Subject: [PATCH] add the safe check for the some ops (#34978) --- paddle/fluid/operators/diag_embed_op.cc | 26 ++++++++++++++++++------ paddle/fluid/operators/matmul_op.cc | 6 ++++++ paddle/fluid/operators/matmul_v2_op.cc | 8 ++++++++ paddle/fluid/operators/metrics/auc_op.cc | 16 ++++++++++++++- 4 files changed, 49 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/diag_embed_op.cc b/paddle/fluid/operators/diag_embed_op.cc index 6d8bc4d219..7e0990df26 100644 --- a/paddle/fluid/operators/diag_embed_op.cc +++ b/paddle/fluid/operators/diag_embed_op.cc @@ -36,22 +36,36 @@ class DiagEmbedOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputDim("Input"); - int dim1_ = dim1 < 0 ? x_dims.size() + dim1 + 1 : dim1; - int dim2_ = dim2 < 0 ? x_dims.size() + dim2 + 1 : dim2; - int offset_ = std::abs(offset); - + PADDLE_ENFORCE_GE( + dim1, -(x_dims.size() + 1), + platform::errors::OutOfRange( + "Dim1 is out of range (expected to be in range of [%ld, " + "%ld], but got %ld).", + -(x_dims.size() + 1), x_dims.size(), dim1)); PADDLE_ENFORCE_LE( - dim1_, x_dims.size(), + dim1, x_dims.size(), platform::errors::OutOfRange( "Dim1 is out of range (expected to be in range of [%ld, " "%ld], but got %ld).", -(x_dims.size() + 1), x_dims.size(), dim1)); + + PADDLE_ENFORCE_GE( + dim2, -(x_dims.size() + 1), + platform::errors::OutOfRange( + "Dim2 is out of range (expected to be in range of [%ld, " + "%ld], but got %ld).", + -(x_dims.size() + 1), x_dims.size(), dim2)); PADDLE_ENFORCE_LE( - dim2_, x_dims.size(), + dim2, x_dims.size(), platform::errors::OutOfRange( "Dim2 is out of range (expected to be in range of [%ld, " "%ld], but got %ld).", -(x_dims.size() + 1), x_dims.size(), dim2)); + + int dim1_ = dim1 < 0 ? x_dims.size() + dim1 + 1 : dim1; + int dim2_ = dim2 < 0 ? x_dims.size() + dim2 + 1 : dim2; + int offset_ = std::abs(offset); + PADDLE_ENFORCE_NE(dim1_, dim2_, platform::errors::InvalidArgument( "diagonal dimensions should not be identical " diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index 988a6c4f7d..78747108d4 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -330,6 +330,12 @@ framework::DDim GetDimForInput(const framework::InferShapeContext &ctx, auto axis = ctx.Attrs().Get>("fused_transpose_" + input_name); auto dim = ctx.GetInputDim(input_name); + + PADDLE_ENFORCE_GT(dim.size(), 0, + platform::errors::InvalidArgument( + "The Input(%s) has not been initialized properly. The " + "shape of Input(%s) = [%s].", + dim)); if (!shape.empty() && !axis.empty()) { PADDLE_ENFORCE_GE( shape.size(), 2, diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index d39eac0759..4ec9a052bb 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -35,6 +35,14 @@ class MatMulV2Op : public framework::OperatorWithKernel { paddle::framework::vectorize(ctx->GetInputDim("Y")); auto ndims_x = dims_x.size(); auto ndims_y = dims_y.size(); + PADDLE_ENFORCE_GT(ndims_x, 0, + platform::errors::InvalidArgument( + "The Input(X) dims size must be greater than 0," + " but reviced dims size is 0. ")); + PADDLE_ENFORCE_GT(ndims_y, 0, + platform::errors::InvalidArgument( + "The Input(Y) dims size must be greater than 0," + " but reviced dims size is 0. ")); bool x_broadcasted = false, y_broadcasted = false; if (ndims_x == 1) { diff --git a/paddle/fluid/operators/metrics/auc_op.cc b/paddle/fluid/operators/metrics/auc_op.cc index 1dfb22718e..4f2f1d0722 100644 --- a/paddle/fluid/operators/metrics/auc_op.cc +++ b/paddle/fluid/operators/metrics/auc_op.cc @@ -25,7 +25,21 @@ class AucOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override { OP_INOUT_CHECK(ctx->HasInput("Predict"), "Input", "Predict", "Auc"); OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "Auc"); - auto predict_width = ctx->GetInputDim("Predict")[1]; + auto predict_dims = ctx->GetInputDim("Predict"); + auto label_dims = ctx->GetInputDim("Label"); + auto predict_width = predict_dims[1]; + PADDLE_ENFORCE_NE( + framework::product(predict_dims), 0, + platform::errors::InvalidArgument( + "The Input(Predict) has not been initialized properly. The " + "shape of Input(Predict) = [%s], the shape can not involes 0.", + predict_dims)); + PADDLE_ENFORCE_NE( + framework::product(label_dims), 0, + platform::errors::InvalidArgument( + "The Input(Label) has not been initialized properly. The " + "shape of Input(Label) = [%s], the shape can not involes 0.", + label_dims)); if (ctx->IsRuntime()) { PADDLE_ENFORCE_LE(predict_width, 2, platform::errors::InvalidArgument( -- GitLab