未验证 提交 379e3feb 编写于 作者: W wangguanzhong 提交者: GitHub

fix shape check in density_prior_box, test=develop (#21414)

* fix shape check in density_prior_box, test=develop
上级 6aa13f46
...@@ -29,11 +29,23 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel { ...@@ -29,11 +29,23 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(image_dims.size() == 4, "The layout of image is NCHW."); PADDLE_ENFORCE(image_dims.size() == 4, "The layout of image is NCHW.");
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW."); PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
PADDLE_ENFORCE_LT(input_dims[2], image_dims[2], if (ctx->IsRuntime()) {
"The height of input must smaller than image."); PADDLE_ENFORCE_LT(
input_dims[2], image_dims[2],
PADDLE_ENFORCE_LT(input_dims[3], image_dims[3], platform::errors::InvalidArgument(
"The width of input must smaller than image."); "The input tensor Input's height"
"of DensityPriorBoxOp should be smaller than input tensor Image's"
"hight. But received Input's height = %d, Image's height = %d",
input_dims[2], image_dims[2]));
PADDLE_ENFORCE_LT(
input_dims[3], image_dims[3],
platform::errors::InvalidArgument(
"The input tensor Input's width"
"of DensityPriorBoxOp should be smaller than input tensor Image's"
"width. But received Input's width = %d, Image's width = %d",
input_dims[3], image_dims[3]));
}
auto variances = ctx->Attrs().Get<std::vector<float>>("variances"); auto variances = ctx->Attrs().Get<std::vector<float>>("variances");
auto fixed_sizes = ctx->Attrs().Get<std::vector<float>>("fixed_sizes"); auto fixed_sizes = ctx->Attrs().Get<std::vector<float>>("fixed_sizes");
...@@ -55,10 +67,13 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel { ...@@ -55,10 +67,13 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel {
dim_vec[3] = 4; dim_vec[3] = 4;
ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec)); ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec));
ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec)); ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec));
} else { } else if (ctx->IsRuntime()) {
int64_t dim0 = input_dims[2] * input_dims[3] * num_priors; int64_t dim0 = input_dims[2] * input_dims[3] * num_priors;
ctx->SetOutputDim("Boxes", {dim0, 4}); ctx->SetOutputDim("Boxes", {dim0, 4});
ctx->SetOutputDim("Variances", {dim0, 4}); ctx->SetOutputDim("Variances", {dim0, 4});
} else {
ctx->SetOutputDim("Boxes", {-1, 4});
ctx->SetOutputDim("Variances", {-1, 4});
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册