From 8284947b820e5b50e8af047bc40bc37b7b379830 Mon Sep 17 00:00:00 2001 From: whs Date: Tue, 17 Jul 2018 22:31:56 +0800 Subject: [PATCH] Fix infershape of im2sequence. (#12183) --- paddle/fluid/operators/im2sequence_op.cc | 12 ++---------- paddle/fluid/operators/im2sequence_op.h | 5 +++-- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/im2sequence_op.cc b/paddle/fluid/operators/im2sequence_op.cc index c8c7f36536a..8efd43928aa 100644 --- a/paddle/fluid/operators/im2sequence_op.cc +++ b/paddle/fluid/operators/im2sequence_op.cc @@ -33,22 +33,14 @@ class Im2SequenceOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(in_dim.size(), 4, "Input(X) format must be 4D tensor, eg., NCHW."); - int batch_size = in_dim[0]; int img_channels = in_dim[1]; - int img_height = in_dim[2]; - int img_width = in_dim[3]; auto kernels = ctx->Attrs().Get>("kernels"); auto strides = ctx->Attrs().Get>("strides"); auto paddings = ctx->Attrs().Get>("paddings"); - int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0], - paddings[2], strides[0]); - int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1], - paddings[3], strides[1]); - - ctx->SetOutputDim("Out", {batch_size * output_height * output_width, - img_channels * kernels[0] * kernels[1]}); + ctx->SetOutputDim("Out", + {in_dim[0], img_channels * kernels[0] * kernels[1]}); } }; diff --git a/paddle/fluid/operators/im2sequence_op.h b/paddle/fluid/operators/im2sequence_op.h index 5bfb91db188..4a994281941 100644 --- a/paddle/fluid/operators/im2sequence_op.h +++ b/paddle/fluid/operators/im2sequence_op.h @@ -109,12 +109,13 @@ class Im2SequenceKernel : public framework::OpKernel { } out->set_lod(lod); } else { - out->mutable_data(ctx.GetPlace()); int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0], paddings[2], strides[0]); int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1], paddings[3], strides[1]); - + out->mutable_data({batch_size * output_height * output_width, + img_channels * kernels[0] * kernels[1]}, + ctx.GetPlace()); const std::vector dilations({1, 1}); auto out_dims = out->dims(); out->Resize({batch_size, out->numel() / batch_size}); -- GitLab