From ffe81af0734d11e3d0239f4aeb114c280a20dba3 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Mon, 15 Apr 2019 10:44:33 +0000 Subject: [PATCH] modified infer shape test=develop --- paddle/fluid/operators/affine_channel_op.cc | 11 +++++++++-- paddle/fluid/operators/conv_op.cc | 11 ++++++++--- paddle/fluid/operators/detection_map_op.cc | 6 ++++-- paddle/fluid/operators/roi_pool_op.cc | 14 ++++++++------ paddle/fluid/operators/row_conv_op.cc | 15 ++++++++++++--- paddle/fluid/operators/unpool_op.cc | 9 +++++++-- 6 files changed, 48 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/operators/affine_channel_op.cc b/paddle/fluid/operators/affine_channel_op.cc index 268a5b894a..7663890a47 100644 --- a/paddle/fluid/operators/affine_channel_op.cc +++ b/paddle/fluid/operators/affine_channel_op.cc @@ -79,9 +79,16 @@ class AffineChannelOp : public framework::OperatorWithKernel { : x_dims[x_dims.size() - 1]); PADDLE_ENFORCE_EQ(scale_dims.size(), 1UL); - PADDLE_ENFORCE_EQ(scale_dims[0], C); PADDLE_ENFORCE_EQ(b_dims.size(), 1UL); - PADDLE_ENFORCE_EQ(b_dims[0], C); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(scale_dims[0], C); + PADDLE_ENFORCE_EQ(b_dims[0], C); + } else { + if (scale_dims[0] > 0 && b_dims[0] > 0) { + PADDLE_ENFORCE_EQ(scale_dims[0], C); + PADDLE_ENFORCE_EQ(b_dims[0], C); + } + } ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->ShareLoD("X", "Out"); diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 619e12e6ba..a78e8ca4fb 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -68,9 +68,14 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { std::vector output_shape({in_dims[0], filter_dims[0]}); for (size_t i = 0; i < strides.size(); ++i) { - output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], - dilations[i], paddings[i], - strides[i])); + if ((!ctx->IsRuntime()) && + (in_dims[i + 2] == -1 || filter_dims[i + 2] == -1)) { + output_shape.push_back(-1); + } else { + output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], + dilations[i], paddings[i], + strides[i])); + } } ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); ctx->ShareLoD("Input", "Output"); diff --git a/paddle/fluid/operators/detection_map_op.cc b/paddle/fluid/operators/detection_map_op.cc index e1d113f854..554e50725f 100644 --- a/paddle/fluid/operators/detection_map_op.cc +++ b/paddle/fluid/operators/detection_map_op.cc @@ -51,8 +51,10 @@ class DetectionMAPOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(label_dims.size(), 2, "The rank of Input(Label) must be 2, " "the shape is [N, 6]."); - PADDLE_ENFORCE(label_dims[1] == 6 || label_dims[1] == 5, - "The shape of Input(Label) is [N, 6] or [N, 5]."); + if (ctx->IsRuntime() || label_dims[1] > 0) { + PADDLE_ENFORCE(label_dims[1] == 6 || label_dims[1] == 5, + "The shape of Input(Label) is [N, 6] or [N, 5]."); + } if (ctx->HasInput("PosCount")) { PADDLE_ENFORCE(ctx->HasInput("TruePos"), diff --git a/paddle/fluid/operators/roi_pool_op.cc b/paddle/fluid/operators/roi_pool_op.cc index cfac7e09e1..11b0bf3bee 100644 --- a/paddle/fluid/operators/roi_pool_op.cc +++ b/paddle/fluid/operators/roi_pool_op.cc @@ -50,12 +50,14 @@ class ROIPoolOp : public framework::OperatorWithKernel { int pooled_width = ctx->Attrs().Get("pooled_width"); float spatial_scale = ctx->Attrs().Get("spatial_scale"); - PADDLE_ENFORCE_GT(pooled_height, 0, - "The pooled output height must greater than 0"); - PADDLE_ENFORCE_GT(pooled_width, 0, - "The pooled output width must greater than 0"); - PADDLE_ENFORCE_GT(spatial_scale, 0.0f, - "The spatial scale must greater than 0"); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_GT(pooled_height, 0, + "The pooled output height must greater than 0"); + PADDLE_ENFORCE_GT(pooled_width, 0, + "The pooled output width must greater than 0"); + PADDLE_ENFORCE_GT(spatial_scale, 0.0f, + "The spatial scale must greater than 0"); + } auto out_dims = input_dims; out_dims[0] = rois_dims[0]; diff --git a/paddle/fluid/operators/row_conv_op.cc b/paddle/fluid/operators/row_conv_op.cc index 81aabdd006..31ac8d8de9 100644 --- a/paddle/fluid/operators/row_conv_op.cc +++ b/paddle/fluid/operators/row_conv_op.cc @@ -45,9 +45,18 @@ class RowConvOp : public framework::OperatorWithKernel { auto filter_dims = ctx->GetInputDim("Filter"); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(filter_dims.size(), 2, "Input(Y)'s rank should be 2."); - PADDLE_ENFORCE_EQ( - x_dims[1], filter_dims[1], - "The 2nd dimension of Input(X) and Input(Filter) should be same."); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ( + x_dims[1], filter_dims[1], + "The 2nd dimension of Input(X) and Input(Filter) should be same."); + } else { + if (x_dims[1] > 0 && filter_dims[1] > 0) { + PADDLE_ENFORCE_EQ( + x_dims[1], filter_dims[1], + "The 2nd dimension of Input(X) and Input(Filter) should be same."); + } + } + ctx->SetOutputDim("Out", x_dims); ctx->ShareLoD("X", "Out"); } diff --git a/paddle/fluid/operators/unpool_op.cc b/paddle/fluid/operators/unpool_op.cc index 11e505d6df..bb487e94d5 100644 --- a/paddle/fluid/operators/unpool_op.cc +++ b/paddle/fluid/operators/unpool_op.cc @@ -99,10 +99,15 @@ class UnpoolOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(in_x_dims.size() == 4, "Unpooling intput must be of 4-dimensional."); PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims); + std::vector output_shape({in_x_dims[0], in_x_dims[1]}); for (size_t i = 0; i < ksize.size(); ++i) { - output_shape.push_back(UnpoolOutputSize(in_x_dims[i + 2], ksize[i], - paddings[i], strides[i])); + if (!ctx->IsRuntime() && in_x_dims[i + 2] == -1) { + output_shape.push_back(-1); + } else { + output_shape.push_back(UnpoolOutputSize(in_x_dims[i + 2], ksize[i], + paddings[i], strides[i])); + } } ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); } -- GitLab