From 379e3febf20a1a5e31839af898e830363b6a0c2e Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Mon, 2 Dec 2019 11:31:35 +0800 Subject: [PATCH] fix shape check in density_prior_box, test=develop (#21414) * fix shape check in density_prior_box, test=develop --- .../detection/density_prior_box_op.cc | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/detection/density_prior_box_op.cc b/paddle/fluid/operators/detection/density_prior_box_op.cc index cfa5f467f0d..8a71ed0b13f 100644 --- a/paddle/fluid/operators/detection/density_prior_box_op.cc +++ b/paddle/fluid/operators/detection/density_prior_box_op.cc @@ -29,11 +29,23 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(image_dims.size() == 4, "The layout of image is NCHW."); PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW."); - PADDLE_ENFORCE_LT(input_dims[2], image_dims[2], - "The height of input must smaller than image."); - - PADDLE_ENFORCE_LT(input_dims[3], image_dims[3], - "The width of input must smaller than image."); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_LT( + input_dims[2], image_dims[2], + platform::errors::InvalidArgument( + "The input tensor Input's height" + "of DensityPriorBoxOp should be smaller than input tensor Image's" + "hight. But received Input's height = %d, Image's height = %d", + input_dims[2], image_dims[2])); + + PADDLE_ENFORCE_LT( + input_dims[3], image_dims[3], + platform::errors::InvalidArgument( + "The input tensor Input's width" + "of DensityPriorBoxOp should be smaller than input tensor Image's" + "width. But received Input's width = %d, Image's width = %d", + input_dims[3], image_dims[3])); + } auto variances = ctx->Attrs().Get>("variances"); auto fixed_sizes = ctx->Attrs().Get>("fixed_sizes"); @@ -55,10 +67,13 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel { dim_vec[3] = 4; ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec)); ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec)); - } else { + } else if (ctx->IsRuntime()) { int64_t dim0 = input_dims[2] * input_dims[3] * num_priors; ctx->SetOutputDim("Boxes", {dim0, 4}); ctx->SetOutputDim("Variances", {dim0, 4}); + } else { + ctx->SetOutputDim("Boxes", {-1, 4}); + ctx->SetOutputDim("Variances", {-1, 4}); } } -- GitLab