diff --git a/paddle/fluid/operators/detection/box_coder_op.cc b/paddle/fluid/operators/detection/box_coder_op.cc index 0a51d50e06176e713922837861f2102c9ee8a899..de3612677440596387f313e1ff59184cb3fdb7ae 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cc +++ b/paddle/fluid/operators/detection/box_coder_op.cc @@ -60,14 +60,15 @@ class BoxCoderOp : public framework::OperatorWithKernel { } else if (code_type == BoxCodeType::kDecodeCenterSize) { PADDLE_ENFORCE_EQ(target_box_dims.size(), 3, "The rank of Input TargetBox must be 3"); - if (axis == 0) { - PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]); - } else if (axis == 1) { - PADDLE_ENFORCE_EQ(target_box_dims[0], prior_box_dims[0]); - } else { - PADDLE_THROW("axis must be 0 or 1."); + PADDLE_ENFORCE(axis == 0 || axis == 1, "axis must be 0 or 1"); + if (ctx->IsRuntime()) { + if (axis == 0) { + PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]); + } else if (axis == 1) { + PADDLE_ENFORCE_EQ(target_box_dims[0], prior_box_dims[0]); + } + PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]); } - PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]); ctx->ShareDim("TargetBox", /*->*/ "OutputBox"); }