From 66bb5dd760f0ce72740ca755224bb3ca85194600 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Mon, 21 Jan 2019 10:18:41 +0000 Subject: [PATCH] refine infer shape, test=develop --- .../fluid/operators/detection/box_coder_op.cc | 57 +++++++++---------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/operators/detection/box_coder_op.cc b/paddle/fluid/operators/detection/box_coder_op.cc index b4b02124cc..2ce844669b 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cc +++ b/paddle/fluid/operators/detection/box_coder_op.cc @@ -43,7 +43,7 @@ class BoxCoderOp : public framework::OperatorWithKernel { if (prior_box_var_dims.size() == 1) { PADDLE_ENFORCE_EQ( prior_box_var_dims[0], 4, - "The 1st dimension of Input(PriorBoxVar) should be 1" + "The 1st dimension of Input(PriorBoxVar) should be 4" "when the rank is 1."); } else { PADDLE_ENFORCE_EQ( @@ -52,37 +52,36 @@ class BoxCoderOp : public framework::OperatorWithKernel { "the dimension of Input(PriorBox when the rank is 2.)"); } } + } - auto code_type = - GetBoxCodeType(ctx->Attrs().Get("code_type")); - int axis = ctx->Attrs().Get("axis"); - if (code_type == BoxCodeType::kEncodeCenterSize) { - PADDLE_ENFORCE_EQ(target_box_dims.size(), 2, - "The rank of Input of TargetBox must be 2"); - PADDLE_ENFORCE_EQ(target_box_dims[1], 4, - "The shape of TargetBox is [M, 4]"); - ctx->SetOutputDim( - "OutputBox", - framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4})); - } else if (code_type == BoxCodeType::kDecodeCenterSize) { - PADDLE_ENFORCE_EQ(target_box_dims.size(), 3, - "The rank of Input of TargetBox must be 3"); - if (axis == 0) { - PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]); - } else if (axis == 1) { - PADDLE_ENFORCE_EQ(target_box_dims[0], prior_box_dims[0]); - } else { - PADDLE_THROW("axis must be 0 or 1."); - } - PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]); - ctx->ShareDim("TargetBox", /*->*/ "OutputBox"); - } - - if (code_type == BoxCodeType::kDecodeCenterSize && axis == 1) { - ctx->ShareLoD("PriorBox", /*->*/ "OutputBox"); + auto code_type = GetBoxCodeType(ctx->Attrs().Get("code_type")); + int axis = ctx->Attrs().Get("axis"); + if (code_type == BoxCodeType::kEncodeCenterSize) { + PADDLE_ENFORCE_EQ(target_box_dims.size(), 2, + "The rank of Input of TargetBox must be 2"); + PADDLE_ENFORCE_EQ(target_box_dims[1], 4, + "The shape of TargetBox is [M, 4]"); + ctx->SetOutputDim( + "OutputBox", + framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4})); + } else if (code_type == BoxCodeType::kDecodeCenterSize) { + PADDLE_ENFORCE_EQ(target_box_dims.size(), 3, + "The rank of Input of TargetBox must be 3"); + if (axis == 0) { + PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]); + } else if (axis == 1) { + PADDLE_ENFORCE_EQ(target_box_dims[0], prior_box_dims[0]); } else { - ctx->ShareLoD("TargetBox", /*->*/ "OutputBox"); + PADDLE_THROW("axis must be 0 or 1."); } + PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]); + ctx->ShareDim("TargetBox", /*->*/ "OutputBox"); + } + + if (code_type == BoxCodeType::kDecodeCenterSize && axis == 1) { + ctx->ShareLoD("PriorBox", /*->*/ "OutputBox"); + } else { + ctx->ShareLoD("TargetBox", /*->*/ "OutputBox"); } } }; -- GitLab