From b2dcdb5100c957ec9b1b3eec9c3c3e1f718d9560 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Sun, 14 Apr 2019 14:12:31 +0000 Subject: [PATCH] infer shape compatable -1. test=develop --- paddle/fluid/operators/grid_sampler_op.cc | 9 ++++++--- paddle/fluid/operators/interpolate_op.cc | 9 +++++++-- paddle/fluid/operators/kldiv_loss_op.cc | 6 ++++-- paddle/fluid/operators/spectral_norm_op.cc | 20 +++++++++++++------- 4 files changed, 30 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/operators/grid_sampler_op.cc b/paddle/fluid/operators/grid_sampler_op.cc index 241184c6f4a..57a1fcd42da 100644 --- a/paddle/fluid/operators/grid_sampler_op.cc +++ b/paddle/fluid/operators/grid_sampler_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/grid_sampler_op.h" +#include #include "paddle/fluid/framework/op_registry.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cudnn_helper.h" @@ -40,10 +41,12 @@ class GridSampleOp : public framework::OperatorWithKernel { "Input(X) of GridSampleOp should be 4-D Tensor."); PADDLE_ENFORCE(grid_dims.size() == 4, "Input(Grid) of GridSampleOp should be 4-D Tensor."); - PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2."); - PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0], - "Input(X) and Input(Grid) dims[0] should be equal."); + if (ctx->IsRuntime() || grid_dims[3] > 0) { + PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2."); + } if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0], + "Input(X) and Input(Grid) dims[0] should be equal."); PADDLE_ENFORCE_EQ( grid_dims[1], x_dims[2], "Input(X) dims[2] and Input(Grid) dims[1] should be equal."); diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 9f2e3ad4a5a..16291a06dcf 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -58,8 +58,13 @@ class InterpolateOp : public framework::OperatorWithKernel { ctx->ShareLoD("X", "Out"); return; } - std::vector dim_out({dim_x[0], dim_x[1], out_h, out_w}); - ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); + + if (ctx->IsRuntime() || (out_h > 0 && out_w > 0)) { + std::vector dim_out({dim_x[0], dim_x[1], out_h, out_w}); + ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); + } else { + ctx->SetOutputDim("Out", dim_x); + } } protected: diff --git a/paddle/fluid/operators/kldiv_loss_op.cc b/paddle/fluid/operators/kldiv_loss_op.cc index a43f22c0496..a7c5d6305b0 100644 --- a/paddle/fluid/operators/kldiv_loss_op.cc +++ b/paddle/fluid/operators/kldiv_loss_op.cc @@ -35,8 +35,10 @@ class KLDivLossOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(), "Input(X) rank and Input(Target) rank should be same."); for (int i = 0; i < dim_x.size(); i++) { - PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i], - "Input(X) and Input(Target) should in same shape."); + if (ctx->IsRuntime() || (dim_x[i] > 0 && dim_target[i] > 0)) { + PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i], + "Input(X) and Input(Target) should in same shape."); + } } auto reduction = ctx->Attrs().Get("reduction"); diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc index 04f659a465a..ec5ee487729 100644 --- a/paddle/fluid/operators/spectral_norm_op.cc +++ b/paddle/fluid/operators/spectral_norm_op.cc @@ -56,13 +56,19 @@ class SpectralNormOp : public framework::OperatorWithKernel { } auto dim_u = ctx->GetInputDim("U"); auto dim_v = ctx->GetInputDim("V"); - PADDLE_ENFORCE_EQ(dim_u[0], h, - "Input(U) dims[0] should be equal to " - "Input(Weight) dims[Attr(dim)]"); - PADDLE_ENFORCE_EQ( - dim_v[0], w, - "Input(V) dims[0] should be equal to " - "the product of Input(Weight) dims except dims[Attr(dim)]"); + + if (ctx->IsRuntime() || (dim_u[0] > 0 && h > 0)) { + PADDLE_ENFORCE_EQ(dim_u[0], h, + "Input(U) dims[0] should be equal to " + "Input(Weight) dims[Attr(dim)]"); + } + + if (ctx->IsRuntime() || (dim_v[0] > 0 && w > 0)) { + PADDLE_ENFORCE_EQ( + dim_v[0], w, + "Input(V) dims[0] should be equal to " + "the product of Input(Weight) dims except dims[Attr(dim)]"); + } ctx->SetOutputDim("Out", dim_weight); ctx->ShareLoD("Weight", /*->*/ "Out"); -- GitLab