未验证 提交 f4867e57 编写于 作者: S shangliang Xu 提交者: GitHub

[bug fix] fix unfold runtime bug (#38834)

上级 4d4a9c6c
......@@ -143,6 +143,9 @@ class UnfoldOp : public framework::OperatorWithKernel {
"but recieved dilations_height: %d dilations_width: %d.",
dilations[0], dilations[1]));
bool contain_unknown_dim = framework::contain_unknown_dim(in_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
std::vector<int> out_dims;
out_dims.push_back(in_dims[0]);
......@@ -152,24 +155,27 @@ class UnfoldOp : public framework::OperatorWithKernel {
int output_height =
CalcOutputSize(in_dims[2], kernel_sizes[0], dilations[0], paddings[0],
paddings[2], strides[0]);
int output_width = CalcOutputSize(in_dims[3], kernel_sizes[1], dilations[1],
paddings[1], paddings[3], strides[1]);
int output_width =
CalcOutputSize(in_dims[3], kernel_sizes[1], dilations[1], paddings[1],
paddings[3], strides[1]);
// check output height and width
PADDLE_ENFORCE_GT(
output_height, 0,
platform::errors::InvalidArgument(
"The sliding blocks calculated from input spatial size (%d, %d), "
"kernel_sizes (%d, %d), strides (%d, %d), dilations (%d, %d), "
"is (%d, %d), which should be a positive integer.",
"The sliding blocks calculated from input spatial size "
"(%d, %d), kernel_sizes (%d, %d), strides (%d, %d), "
"dilations (%d, %d), is (%d, %d), which should be a "
"positive integer.",
in_dims[2], in_dims[3], kernel_sizes[0], kernel_sizes[1],
strides[0], strides[1], dilations[0], dilations[1], output_height,
output_width));
PADDLE_ENFORCE_GT(
output_width, 0,
platform::errors::InvalidArgument(
"The sliding blocks calculated from input spatial size (%d, %d), "
"kernel_sizes (%d, %d), strides (%d, %d), dilations (%d, %d), "
"is (%d, %d), which should be a positive integer.",
"The sliding blocks calculated from input spatial size "
"(%d, %d), kernel_sizes (%d, %d), strides (%d, %d), "
"dilations (%d, %d), is (%d, %d), which should be a "
"positive integer.",
in_dims[2], in_dims[3], kernel_sizes[0], kernel_sizes[1],
strides[0], strides[1], dilations[0], dilations[1], output_height,
output_width));
......@@ -178,6 +184,7 @@ class UnfoldOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Y", framework::make_ddim(out_dims));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册