未验证 提交 5c357504 编写于 作者: S shangliang Xu 提交者: GitHub

[bug fix] fix unfold runtime bug (#38819)

上级 a8afed69
......@@ -143,40 +143,47 @@ class UnfoldOp : public framework::OperatorWithKernel {
"but recieved dilations_height: %d dilations_width: %d.",
dilations[0], dilations[1]));
std::vector<int> out_dims;
out_dims.push_back(in_dims[0]);
int output_channels = in_dims[1] * kernel_sizes[0] * kernel_sizes[1];
out_dims.push_back(output_channels);
int output_height =
CalcOutputSize(in_dims[2], kernel_sizes[0], dilations[0], paddings[0],
paddings[2], strides[0]);
int output_width = CalcOutputSize(in_dims[3], kernel_sizes[1], dilations[1],
paddings[1], paddings[3], strides[1]);
// check output height and width
PADDLE_ENFORCE_GT(
output_height, 0,
platform::errors::InvalidArgument(
"The sliding blocks calculated from input spatial size (%d, %d), "
"kernel_sizes (%d, %d), strides (%d, %d), dilations (%d, %d), "
"is (%d, %d), which should be a positive integer.",
in_dims[2], in_dims[3], kernel_sizes[0], kernel_sizes[1],
strides[0], strides[1], dilations[0], dilations[1], output_height,
output_width));
PADDLE_ENFORCE_GT(
output_width, 0,
platform::errors::InvalidArgument(
"The sliding blocks calculated from input spatial size (%d, %d), "
"kernel_sizes (%d, %d), strides (%d, %d), dilations (%d, %d), "
"is (%d, %d), which should be a positive integer.",
in_dims[2], in_dims[3], kernel_sizes[0], kernel_sizes[1],
strides[0], strides[1], dilations[0], dilations[1], output_height,
output_width));
int output_col_length = output_height * output_width;
out_dims.push_back(output_col_length);
ctx->SetOutputDim("Y", framework::make_ddim(out_dims));
bool contain_unknown_dim = framework::contain_unknown_dim(in_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
std::vector<int> out_dims;
out_dims.push_back(in_dims[0]);
int output_channels = in_dims[1] * kernel_sizes[0] * kernel_sizes[1];
out_dims.push_back(output_channels);
int output_height =
CalcOutputSize(in_dims[2], kernel_sizes[0], dilations[0], paddings[0],
paddings[2], strides[0]);
int output_width =
CalcOutputSize(in_dims[3], kernel_sizes[1], dilations[1], paddings[1],
paddings[3], strides[1]);
// check output height and width
PADDLE_ENFORCE_GT(
output_height, 0,
platform::errors::InvalidArgument(
"The sliding blocks calculated from input spatial size "
"(%d, %d), kernel_sizes (%d, %d), strides (%d, %d), "
"dilations (%d, %d), is (%d, %d), which should be a "
"positive integer.",
in_dims[2], in_dims[3], kernel_sizes[0], kernel_sizes[1],
strides[0], strides[1], dilations[0], dilations[1], output_height,
output_width));
PADDLE_ENFORCE_GT(
output_width, 0,
platform::errors::InvalidArgument(
"The sliding blocks calculated from input spatial size "
"(%d, %d), kernel_sizes (%d, %d), strides (%d, %d), "
"dilations (%d, %d), is (%d, %d), which should be a "
"positive integer.",
in_dims[2], in_dims[3], kernel_sizes[0], kernel_sizes[1],
strides[0], strides[1], dilations[0], dilations[1], output_height,
output_width));
int output_col_length = output_height * output_width;
out_dims.push_back(output_col_length);
ctx->SetOutputDim("Y", framework::make_ddim(out_dims));
}
}
protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册