未验证 提交 63f3ada7 编写于 作者: C chengjuntao 提交者: GitHub

fix bug which input shape (#22965)

* fix bug which input shape, test=develop

* add error type,test=develop
上级 c8d17ab3
......@@ -177,22 +177,36 @@ class DeformableConvV1Op : public framework::OperatorWithKernel {
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
dilations[i], paddings[i],
strides[i]));
if ((!ctx->IsRuntime()) &&
(in_dims[i + 2] <= 0 || filter_dims[i + 2] <= 0)) {
output_shape.push_back(-1);
} else {
output_shape.push_back(ConvOutputSize(in_dims[i + 2],
filter_dims[i + 2], dilations[i],
paddings[i], strides[i]));
}
}
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
output_shape[1] % deformable_groups, 0U,
platform::errors::InvalidArgument(
"output num_filter must divide deformable group size."));
PADDLE_ENFORCE_EQ(output_shape[2], offset_dims[2],
platform::errors::InvalidArgument(
"output height must equal to offset map height."));
PADDLE_ENFORCE_EQ(output_shape[3], offset_dims[3],
platform::errors::InvalidArgument(
"output width must equal to offset map width."));
PADDLE_ENFORCE_EQ(
offset_dims[1] % (filter_dims[2] * filter_dims[3]), 0U,
platform::errors::InvalidArgument(
"offset filter must divide deformable group size."));
PADDLE_ENFORCE_EQ(
offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]),
deformable_groups,
platform::errors::InvalidArgument(
"offset filter must divide deformable group size."));
}
PADDLE_ENFORCE_EQ(output_shape[1] % deformable_groups, 0U,
"output num_filter must divide deformable group size.");
PADDLE_ENFORCE_EQ(output_shape[2], offset_dims[2],
"output height must equal to offset map height.");
PADDLE_ENFORCE_EQ(output_shape[3], offset_dims[3],
"output width must equal to offset map width.");
PADDLE_ENFORCE_EQ(offset_dims[1] % (filter_dims[2] * filter_dims[3]), 0U,
"offset filter must divide deformable group size.");
PADDLE_ENFORCE_EQ(offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]),
deformable_groups,
"offset filter must divide deformable group size.");
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册