未验证 提交 f4867e57 编写于 作者: S shangliang Xu 提交者: GitHub

[bug fix] fix unfold runtime bug (#38834)

上级 4d4a9c6c
...@@ -143,40 +143,47 @@ class UnfoldOp : public framework::OperatorWithKernel { ...@@ -143,40 +143,47 @@ class UnfoldOp : public framework::OperatorWithKernel {
"but recieved dilations_height: %d dilations_width: %d.", "but recieved dilations_height: %d dilations_width: %d.",
dilations[0], dilations[1])); dilations[0], dilations[1]));
std::vector<int> out_dims; bool contain_unknown_dim = framework::contain_unknown_dim(in_dims);
out_dims.push_back(in_dims[0]); bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
int output_channels = in_dims[1] * kernel_sizes[0] * kernel_sizes[1]; std::vector<int> out_dims;
out_dims.push_back(output_channels); out_dims.push_back(in_dims[0]);
int output_height = int output_channels = in_dims[1] * kernel_sizes[0] * kernel_sizes[1];
CalcOutputSize(in_dims[2], kernel_sizes[0], dilations[0], paddings[0], out_dims.push_back(output_channels);
paddings[2], strides[0]);
int output_width = CalcOutputSize(in_dims[3], kernel_sizes[1], dilations[1], int output_height =
paddings[1], paddings[3], strides[1]); CalcOutputSize(in_dims[2], kernel_sizes[0], dilations[0], paddings[0],
// check output height and width paddings[2], strides[0]);
PADDLE_ENFORCE_GT( int output_width =
output_height, 0, CalcOutputSize(in_dims[3], kernel_sizes[1], dilations[1], paddings[1],
platform::errors::InvalidArgument( paddings[3], strides[1]);
"The sliding blocks calculated from input spatial size (%d, %d), " // check output height and width
"kernel_sizes (%d, %d), strides (%d, %d), dilations (%d, %d), " PADDLE_ENFORCE_GT(
"is (%d, %d), which should be a positive integer.", output_height, 0,
in_dims[2], in_dims[3], kernel_sizes[0], kernel_sizes[1], platform::errors::InvalidArgument(
strides[0], strides[1], dilations[0], dilations[1], output_height, "The sliding blocks calculated from input spatial size "
output_width)); "(%d, %d), kernel_sizes (%d, %d), strides (%d, %d), "
PADDLE_ENFORCE_GT( "dilations (%d, %d), is (%d, %d), which should be a "
output_width, 0, "positive integer.",
platform::errors::InvalidArgument( in_dims[2], in_dims[3], kernel_sizes[0], kernel_sizes[1],
"The sliding blocks calculated from input spatial size (%d, %d), " strides[0], strides[1], dilations[0], dilations[1], output_height,
"kernel_sizes (%d, %d), strides (%d, %d), dilations (%d, %d), " output_width));
"is (%d, %d), which should be a positive integer.", PADDLE_ENFORCE_GT(
in_dims[2], in_dims[3], kernel_sizes[0], kernel_sizes[1], output_width, 0,
strides[0], strides[1], dilations[0], dilations[1], output_height, platform::errors::InvalidArgument(
output_width)); "The sliding blocks calculated from input spatial size "
int output_col_length = output_height * output_width; "(%d, %d), kernel_sizes (%d, %d), strides (%d, %d), "
out_dims.push_back(output_col_length); "dilations (%d, %d), is (%d, %d), which should be a "
"positive integer.",
ctx->SetOutputDim("Y", framework::make_ddim(out_dims)); in_dims[2], in_dims[3], kernel_sizes[0], kernel_sizes[1],
strides[0], strides[1], dilations[0], dilations[1], output_height,
output_width));
int output_col_length = output_height * output_width;
out_dims.push_back(output_col_length);
ctx->SetOutputDim("Y", framework::make_ddim(out_dims));
}
} }
protected: protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册