未验证 提交 b031c389 编写于 作者: S shangliang Xu 提交者: GitHub

[bug fix] fix unfold bug in compile time (#38925)

上级 f082e171
...@@ -143,22 +143,18 @@ class UnfoldOp : public framework::OperatorWithKernel { ...@@ -143,22 +143,18 @@ class UnfoldOp : public framework::OperatorWithKernel {
"but recieved dilations_height: %d dilations_width: %d.", "but recieved dilations_height: %d dilations_width: %d.",
dilations[0], dilations[1])); dilations[0], dilations[1]));
bool contain_unknown_dim = framework::contain_unknown_dim(in_dims); std::vector<int> out_dims;
bool check = ctx->IsRuntime() || !contain_unknown_dim; out_dims.push_back(in_dims[0]);
if (check) { int output_channels = in_dims[1] * kernel_sizes[0] * kernel_sizes[1];
std::vector<int> out_dims; out_dims.push_back(output_channels);
out_dims.push_back(in_dims[0]);
int output_height =
int output_channels = in_dims[1] * kernel_sizes[0] * kernel_sizes[1]; CalcOutputSize(in_dims[2], kernel_sizes[0], dilations[0], paddings[0],
out_dims.push_back(output_channels); paddings[2], strides[0]);
int output_width = CalcOutputSize(in_dims[3], kernel_sizes[1], dilations[1],
int output_height = paddings[1], paddings[3], strides[1]);
CalcOutputSize(in_dims[2], kernel_sizes[0], dilations[0], paddings[0], if (ctx->IsRuntime()) {
paddings[2], strides[0]); // only check output height and width in runtime
int output_width =
CalcOutputSize(in_dims[3], kernel_sizes[1], dilations[1], paddings[1],
paddings[3], strides[1]);
// check output height and width
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
output_height, 0, output_height, 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -179,11 +175,10 @@ class UnfoldOp : public framework::OperatorWithKernel { ...@@ -179,11 +175,10 @@ class UnfoldOp : public framework::OperatorWithKernel {
in_dims[2], in_dims[3], kernel_sizes[0], kernel_sizes[1], in_dims[2], in_dims[3], kernel_sizes[0], kernel_sizes[1],
strides[0], strides[1], dilations[0], dilations[1], output_height, strides[0], strides[1], dilations[0], dilations[1], output_height,
output_width)); output_width));
int output_col_length = output_height * output_width;
out_dims.push_back(output_col_length);
ctx->SetOutputDim("Y", framework::make_ddim(out_dims));
} }
int output_col_length = output_height * output_width;
out_dims.push_back(output_col_length);
ctx->SetOutputDim("Y", framework::make_ddim(out_dims));
} }
protected: protected:
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/im2col.h" #include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
...@@ -29,15 +30,6 @@ inline int CalcOutputSize(int input_size, int filter_size, int dilation, ...@@ -29,15 +30,6 @@ inline int CalcOutputSize(int input_size, int filter_size, int dilation,
int padding1, int padding2, int stride) { int padding1, int padding2, int stride) {
const int dkernel = dilation * (filter_size - 1) + 1; const int dkernel = dilation * (filter_size - 1) + 1;
int output_size = (input_size + padding1 + padding2 - dkernel) / stride + 1; int output_size = (input_size + padding1 + padding2 - dkernel) / stride + 1;
PADDLE_ENFORCE_GT(
output_size, 0UL,
platform::errors::InvalidArgument(
"Due to the settings of padding(%d, %d), filter_size(%d), "
"dilation(%d) and "
"stride(%d), the output size is less than 0, please check "
"again. Input_size:%d",
padding1, padding2, filter_size, dilation, stride, input_size));
return output_size; return output_size;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册