提交 aeca5c50 编写于 作者: K Kaipeng Deng 提交者: qingqing01

fix grid_sampler PADDLE_ENFORCE error. test=develop (#15542)

上级 5f89ce7f
...@@ -43,12 +43,14 @@ class GridSampleOp : public framework::OperatorWithKernel { ...@@ -43,12 +43,14 @@ class GridSampleOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2."); PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0], PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0],
"Input(X) and Input(Grid) dims[0] should be equal."); "Input(X) and Input(Grid) dims[0] should be equal.");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
grid_dims[1], x_dims[2], grid_dims[1], x_dims[2],
"Input(X) dims[2] and Input(Grid) dims[1] should be equal."); "Input(X) dims[2] and Input(Grid) dims[1] should be equal.");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
grid_dims[2], x_dims[3], grid_dims[2], x_dims[3],
"Input(X) dims[3] and Input(Grid) dims[2] should be equal."); "Input(X) dims[3] and Input(Grid) dims[2] should be equal.");
}
ctx->SetOutputDim("Output", x_dims); ctx->SetOutputDim("Output", x_dims);
ctx->ShareLoD("X", "Output"); ctx->ShareLoD("X", "Output");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册