提交 8284947b 编写于 作者: W whs 提交者: qingqing01

Fix infershape of im2sequence. (#12183)

上级 d1135906
...@@ -33,22 +33,14 @@ class Im2SequenceOp : public framework::OperatorWithKernel { ...@@ -33,22 +33,14 @@ class Im2SequenceOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(in_dim.size(), 4, PADDLE_ENFORCE_EQ(in_dim.size(), 4,
"Input(X) format must be 4D tensor, eg., NCHW."); "Input(X) format must be 4D tensor, eg., NCHW.");
int batch_size = in_dim[0];
int img_channels = in_dim[1]; int img_channels = in_dim[1];
int img_height = in_dim[2];
int img_width = in_dim[3];
auto kernels = ctx->Attrs().Get<std::vector<int>>("kernels"); auto kernels = ctx->Attrs().Get<std::vector<int>>("kernels");
auto strides = ctx->Attrs().Get<std::vector<int>>("strides"); auto strides = ctx->Attrs().Get<std::vector<int>>("strides");
auto paddings = ctx->Attrs().Get<std::vector<int>>("paddings"); auto paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0], ctx->SetOutputDim("Out",
paddings[2], strides[0]); {in_dim[0], img_channels * kernels[0] * kernels[1]});
int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1],
paddings[3], strides[1]);
ctx->SetOutputDim("Out", {batch_size * output_height * output_width,
img_channels * kernels[0] * kernels[1]});
} }
}; };
......
...@@ -109,12 +109,13 @@ class Im2SequenceKernel : public framework::OpKernel<T> { ...@@ -109,12 +109,13 @@ class Im2SequenceKernel : public framework::OpKernel<T> {
} }
out->set_lod(lod); out->set_lod(lod);
} else { } else {
out->mutable_data<T>(ctx.GetPlace());
int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0], int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0],
paddings[2], strides[0]); paddings[2], strides[0]);
int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1], int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1],
paddings[3], strides[1]); paddings[3], strides[1]);
out->mutable_data<T>({batch_size * output_height * output_width,
img_channels * kernels[0] * kernels[1]},
ctx.GetPlace());
const std::vector<int> dilations({1, 1}); const std::vector<int> dilations({1, 1});
auto out_dims = out->dims(); auto out_dims = out->dims();
out->Resize({batch_size, out->numel() / batch_size}); out->Resize({batch_size, out->numel() / batch_size});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册