未验证 提交 ed57237e 编写于 作者: S shangliang Xu 提交者: GitHub

fix infershape in compile time (#45156)

上级 27a5429a
...@@ -3872,6 +3872,7 @@ void UnfoldInferMeta(const MetaTensor& x, ...@@ -3872,6 +3872,7 @@ void UnfoldInferMeta(const MetaTensor& x,
paddings[1], paddings[1],
paddings[3], paddings[3],
strides[1]); strides[1]);
int output_col_length = output_height * output_width;
if (config.is_runtime) { if (config.is_runtime) {
// only check output height and width in runtime // only check output height and width in runtime
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
...@@ -3910,8 +3911,10 @@ void UnfoldInferMeta(const MetaTensor& x, ...@@ -3910,8 +3911,10 @@ void UnfoldInferMeta(const MetaTensor& x,
dilations[1], dilations[1],
output_height, output_height,
output_width)); output_width));
} else {
output_col_length =
output_height == -1 || output_width == -1 ? -1 : output_col_length;
} }
int output_col_length = output_height * output_width;
out_dims.push_back(output_col_length); out_dims.push_back(output_col_length);
out->set_dims(phi::make_ddim(out_dims)); out->set_dims(phi::make_ddim(out_dims));
} }
......
...@@ -26,7 +26,7 @@ inline int CalcOutputSize(int input_size, ...@@ -26,7 +26,7 @@ inline int CalcOutputSize(int input_size,
int stride) { int stride) {
const int dkernel = dilation * (filter_size - 1) + 1; const int dkernel = dilation * (filter_size - 1) + 1;
int output_size = (input_size + padding1 + padding2 - dkernel) / stride + 1; int output_size = (input_size + padding1 + padding2 - dkernel) / stride + 1;
return output_size; return input_size == -1 ? -1 : output_size;
} }
} // namespace funcs } // namespace funcs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册