diff --git a/paddle/fluid/operators/grid_sampler_op.cc b/paddle/fluid/operators/grid_sampler_op.cc index 14a2524bd8f4a9f7685c84f1d9767f5f7eedf0e7..241184c6f4a19a1da0d6d75c5d4e2b372c14e9da 100644 --- a/paddle/fluid/operators/grid_sampler_op.cc +++ b/paddle/fluid/operators/grid_sampler_op.cc @@ -43,12 +43,14 @@ class GridSampleOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2."); PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0], "Input(X) and Input(Grid) dims[0] should be equal."); - PADDLE_ENFORCE_EQ( - grid_dims[1], x_dims[2], - "Input(X) dims[2] and Input(Grid) dims[1] should be equal."); - PADDLE_ENFORCE_EQ( - grid_dims[2], x_dims[3], - "Input(X) dims[3] and Input(Grid) dims[2] should be equal."); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ( + grid_dims[1], x_dims[2], + "Input(X) dims[2] and Input(Grid) dims[1] should be equal."); + PADDLE_ENFORCE_EQ( + grid_dims[2], x_dims[3], + "Input(X) dims[3] and Input(Grid) dims[2] should be equal."); + } ctx->SetOutputDim("Output", x_dims); ctx->ShareLoD("X", "Output");