未验证 提交 9eefd2c7 编写于 作者: Q qingqing01 提交者: GitHub

Modify some infer-shape about detection operators in compile-time. (#14483)

* Modify some infer-shape in compile-time.
上级 cf685f36
...@@ -30,27 +30,30 @@ class BoxCoderOp : public framework::OperatorWithKernel { ...@@ -30,27 +30,30 @@ class BoxCoderOp : public framework::OperatorWithKernel {
auto prior_box_dims = ctx->GetInputDim("PriorBox"); auto prior_box_dims = ctx->GetInputDim("PriorBox");
auto target_box_dims = ctx->GetInputDim("TargetBox"); auto target_box_dims = ctx->GetInputDim("TargetBox");
PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2, if (ctx->IsRuntime()) {
"The rank of Input of PriorBoxVar must be 2"); PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2,
PADDLE_ENFORCE_EQ(prior_box_dims[1], 4, "The shape of PriorBox is [N, 4]"); "The rank of Input of PriorBoxVar must be 2");
if (ctx->HasInput("PriorBoxVar")) { PADDLE_ENFORCE_EQ(prior_box_dims[1], 4,
auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar"); "The shape of PriorBox is [N, 4]");
PADDLE_ENFORCE_EQ(prior_box_dims, prior_box_var_dims); if (ctx->HasInput("PriorBoxVar")) {
auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar");
PADDLE_ENFORCE_EQ(prior_box_dims, prior_box_var_dims);
}
auto code_type =
GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type"));
if (code_type == BoxCodeType::kEncodeCenterSize) {
PADDLE_ENFORCE_EQ(target_box_dims.size(), 2,
"The rank of Input of TargetBox must be 2");
PADDLE_ENFORCE_EQ(target_box_dims[1], 4,
"The shape of TargetBox is [M, 4]");
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
PADDLE_ENFORCE_EQ(target_box_dims.size(), 3,
"The rank of Input of TargetBox must be 3");
PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]);
PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]);
}
} }
auto code_type = GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type"));
if (code_type == BoxCodeType::kEncodeCenterSize) {
PADDLE_ENFORCE_EQ(target_box_dims.size(), 2,
"The rank of Input of TargetBox must be 2");
PADDLE_ENFORCE_EQ(target_box_dims[1], 4,
"The shape of TargetBox is [M, 4]");
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
PADDLE_ENFORCE_EQ(target_box_dims.size(), 3,
"The rank of Input of TargetBox must be 3");
PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]);
PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]);
}
ctx->SetOutputDim( ctx->SetOutputDim(
"OutputBox", "OutputBox",
framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4})); framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4}));
......
...@@ -36,24 +36,26 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { ...@@ -36,24 +36,26 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
auto box_dims = ctx->GetInputDim("BBoxes"); auto box_dims = ctx->GetInputDim("BBoxes");
auto score_dims = ctx->GetInputDim("Scores"); auto score_dims = ctx->GetInputDim("Scores");
PADDLE_ENFORCE_EQ(box_dims.size(), 3, if (ctx->IsRuntime()) {
"The rank of Input(BBoxes) must be 3."); PADDLE_ENFORCE_EQ(box_dims.size(), 3,
PADDLE_ENFORCE_EQ(score_dims.size(), 3, "The rank of Input(BBoxes) must be 3.");
"The rank of Input(Scores) must be 3."); PADDLE_ENFORCE_EQ(score_dims.size(), 3,
PADDLE_ENFORCE(box_dims[2] == 4 || box_dims[2] == 8 || box_dims[2] == 16 || "The rank of Input(Scores) must be 3.");
box_dims[2] == 24 || box_dims[2] == 32, PADDLE_ENFORCE(box_dims[2] == 4 || box_dims[2] == 8 ||
"The 2nd dimension of Input(BBoxes) must be 4 or 8, " box_dims[2] == 16 || box_dims[2] == 24 ||
"represents the layout of coordinate " box_dims[2] == 32,
"[xmin, ymin, xmax, ymax] or " "The 2nd dimension of Input(BBoxes) must be 4 or 8, "
"4 points: [x1, y1, x2, y2, x3, y3, x4, y4] or " "represents the layout of coordinate "
"8 points: [xi, yi] i= 1,2,...,8 or " "[xmin, ymin, xmax, ymax] or "
"12 points: [xi, yi] i= 1,2,...,12 or " "4 points: [x1, y1, x2, y2, x3, y3, x4, y4] or "
"16 points: [xi, yi] i= 1,2,...,16"); "8 points: [xi, yi] i= 1,2,...,8 or "
PADDLE_ENFORCE_EQ(box_dims[1], score_dims[2], "12 points: [xi, yi] i= 1,2,...,12 or "
"The 1st dimensiong of Input(BBoxes) must be equal to " "16 points: [xi, yi] i= 1,2,...,16");
"3rd dimension of Input(Scores), which represents the " PADDLE_ENFORCE_EQ(box_dims[1], score_dims[2],
"predicted bboxes."); "The 1st dimensiong of Input(BBoxes) must be equal to "
"3rd dimension of Input(Scores), which represents the "
"predicted bboxes.");
}
// Here the box_dims[0] is not the real dimension of output. // Here the box_dims[0] is not the real dimension of output.
// It will be rewritten in the computing kernel. // It will be rewritten in the computing kernel.
ctx->SetOutputDim("Out", {box_dims[1], box_dims[2] + 2}); ctx->SetOutputDim("Out", {box_dims[1], box_dims[2] + 2});
......
...@@ -283,11 +283,7 @@ def detection_output(loc, ...@@ -283,11 +283,7 @@ def detection_output(loc,
prior_box_var=prior_box_var, prior_box_var=prior_box_var,
target_box=loc, target_box=loc,
code_type='decode_center_size') code_type='decode_center_size')
compile_shape = scores.shape
run_shape = nn.shape(scores)
scores = nn.flatten(x=scores, axis=2)
scores = nn.softmax(input=scores) scores = nn.softmax(input=scores)
scores = nn.reshape(x=scores, shape=compile_shape, actual_shape=run_shape)
scores = nn.transpose(scores, perm=[0, 2, 1]) scores = nn.transpose(scores, perm=[0, 2, 1])
scores.stop_gradient = True scores.stop_gradient = True
nmsed_outs = helper.create_variable_for_type_inference( nmsed_outs = helper.create_variable_for_type_inference(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册