提交 66bb5dd7 编写于 作者: J jerrywgz

refine infer shape, test=develop

上级 0d915078
...@@ -43,7 +43,7 @@ class BoxCoderOp : public framework::OperatorWithKernel { ...@@ -43,7 +43,7 @@ class BoxCoderOp : public framework::OperatorWithKernel {
if (prior_box_var_dims.size() == 1) { if (prior_box_var_dims.size() == 1) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
prior_box_var_dims[0], 4, prior_box_var_dims[0], 4,
"The 1st dimension of Input(PriorBoxVar) should be 1" "The 1st dimension of Input(PriorBoxVar) should be 4"
"when the rank is 1."); "when the rank is 1.");
} else { } else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -52,37 +52,36 @@ class BoxCoderOp : public framework::OperatorWithKernel { ...@@ -52,37 +52,36 @@ class BoxCoderOp : public framework::OperatorWithKernel {
"the dimension of Input(PriorBox when the rank is 2.)"); "the dimension of Input(PriorBox when the rank is 2.)");
} }
} }
}
auto code_type = auto code_type = GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type"));
GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type")); int axis = ctx->Attrs().Get<int>("axis");
int axis = ctx->Attrs().Get<int>("axis"); if (code_type == BoxCodeType::kEncodeCenterSize) {
if (code_type == BoxCodeType::kEncodeCenterSize) { PADDLE_ENFORCE_EQ(target_box_dims.size(), 2,
PADDLE_ENFORCE_EQ(target_box_dims.size(), 2, "The rank of Input of TargetBox must be 2");
"The rank of Input of TargetBox must be 2"); PADDLE_ENFORCE_EQ(target_box_dims[1], 4,
PADDLE_ENFORCE_EQ(target_box_dims[1], 4, "The shape of TargetBox is [M, 4]");
"The shape of TargetBox is [M, 4]"); ctx->SetOutputDim(
ctx->SetOutputDim( "OutputBox",
"OutputBox", framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4}));
framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4})); } else if (code_type == BoxCodeType::kDecodeCenterSize) {
} else if (code_type == BoxCodeType::kDecodeCenterSize) { PADDLE_ENFORCE_EQ(target_box_dims.size(), 3,
PADDLE_ENFORCE_EQ(target_box_dims.size(), 3, "The rank of Input of TargetBox must be 3");
"The rank of Input of TargetBox must be 3"); if (axis == 0) {
if (axis == 0) { PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]);
PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]); } else if (axis == 1) {
} else if (axis == 1) { PADDLE_ENFORCE_EQ(target_box_dims[0], prior_box_dims[0]);
PADDLE_ENFORCE_EQ(target_box_dims[0], prior_box_dims[0]);
} else {
PADDLE_THROW("axis must be 0 or 1.");
}
PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]);
ctx->ShareDim("TargetBox", /*->*/ "OutputBox");
}
if (code_type == BoxCodeType::kDecodeCenterSize && axis == 1) {
ctx->ShareLoD("PriorBox", /*->*/ "OutputBox");
} else { } else {
ctx->ShareLoD("TargetBox", /*->*/ "OutputBox"); PADDLE_THROW("axis must be 0 or 1.");
} }
PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]);
ctx->ShareDim("TargetBox", /*->*/ "OutputBox");
}
if (code_type == BoxCodeType::kDecodeCenterSize && axis == 1) {
ctx->ShareLoD("PriorBox", /*->*/ "OutputBox");
} else {
ctx->ShareLoD("TargetBox", /*->*/ "OutputBox");
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册