未验证 提交 63f3ada7 编写于 作者: C chengjuntao 提交者: GitHub

fix bug which input shape (#22965)

* fix bug which input shape, test=develop

* add error type,test=develop
上级 c8d17ab3
...@@ -177,22 +177,36 @@ class DeformableConvV1Op : public framework::OperatorWithKernel { ...@@ -177,22 +177,36 @@ class DeformableConvV1Op : public framework::OperatorWithKernel {
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]}); std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) { for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], if ((!ctx->IsRuntime()) &&
dilations[i], paddings[i], (in_dims[i + 2] <= 0 || filter_dims[i + 2] <= 0)) {
strides[i])); output_shape.push_back(-1);
} else {
output_shape.push_back(ConvOutputSize(in_dims[i + 2],
filter_dims[i + 2], dilations[i],
paddings[i], strides[i]));
} }
PADDLE_ENFORCE_EQ(output_shape[1] % deformable_groups, 0U, }
"output num_filter must divide deformable group size."); if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
output_shape[1] % deformable_groups, 0U,
platform::errors::InvalidArgument(
"output num_filter must divide deformable group size."));
PADDLE_ENFORCE_EQ(output_shape[2], offset_dims[2], PADDLE_ENFORCE_EQ(output_shape[2], offset_dims[2],
"output height must equal to offset map height."); platform::errors::InvalidArgument(
"output height must equal to offset map height."));
PADDLE_ENFORCE_EQ(output_shape[3], offset_dims[3], PADDLE_ENFORCE_EQ(output_shape[3], offset_dims[3],
"output width must equal to offset map width."); platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(offset_dims[1] % (filter_dims[2] * filter_dims[3]), 0U, "output width must equal to offset map width."));
"offset filter must divide deformable group size."); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]), offset_dims[1] % (filter_dims[2] * filter_dims[3]), 0U,
platform::errors::InvalidArgument(
"offset filter must divide deformable group size."));
PADDLE_ENFORCE_EQ(
offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]),
deformable_groups, deformable_groups,
"offset filter must divide deformable group size."); platform::errors::InvalidArgument(
"offset filter must divide deformable group size."));
}
ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册