提交 e0708e62 编写于 作者: J jerrywgz

refine code

上级 1c591c39
......@@ -53,8 +53,8 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
ctx->HasOutput("TargetBBox"),
"Output(TargetBBox) of RpnTargetAssignOp should not be null");
PADDLE_ENFORCE(
ctx->HasOutput("BBox_inside_weight"),
"Output(BBox_inside_weight) of RpnTargetAssignOp should not be null");
ctx->HasOutput("BBoxInsideWeight"),
"Output(BBoxInsideWeight) of RpnTargetAssignOp should not be null");
auto anchor_dims = ctx->GetInputDim("Anchor");
auto gt_boxes_dims = ctx->GetInputDim("GtBoxes");
......@@ -71,7 +71,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ScoreIndex", {-1});
ctx->SetOutputDim("TargetLabel", {-1, 1});
ctx->SetOutputDim("TargetBBox", {-1, 4});
ctx->SetOutputDim("BBox_inside_weight", {-1, 4});
ctx->SetOutputDim("BBoxInsideWeight", {-1, 4});
}
protected:
......@@ -345,7 +345,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
auto* score_index = context.Output<LoDTensor>("ScoreIndex");
auto* tgt_bbox = context.Output<LoDTensor>("TargetBBox");
auto* tgt_lbl = context.Output<LoDTensor>("TargetLabel");
auto* bbox_inside_weight = context.Output<LoDTensor>("BBox_inside_weight");
auto* bbox_inside_weight = context.Output<LoDTensor>("BBoxInsideWeight");
PADDLE_ENFORCE_EQ(gt_boxes->lod().size(), 1UL,
"RpnTargetAssignOp gt_boxes needs 1 level of LoD");
......@@ -547,7 +547,7 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
"TargetLabel",
"(Tensor<int>), The target labels of each anchor with shape "
"[F + B, 1], F and B are sampled foreground and backgroud number.");
AddOutput("BBox_inside_weight",
AddOutput("BBoxInsideWeight",
"(Tensor), The bbox inside weight with shape "
"[F, 4], F is the sampled foreground number.");
AddComment(R"DOC(
......
......@@ -167,7 +167,7 @@ def rpn_target_assign(bbox_pred,
'ScoreIndex': score_index,
'TargetLabel': target_label,
'TargetBBox': target_bbox,
'BBox_inside_weight': bbox_inside_weight
'BBoxInsideWeight': bbox_inside_weight
},
attrs={
'rpn_batch_size_per_im': rpn_batch_size_per_im,
......
......@@ -324,6 +324,7 @@ class TestRpnTargetAssign(unittest.TestCase):
assert pred_scores.shape[1] == 1
assert pred_loc.shape[1] == 4
assert pred_loc.shape[1] == tgt_bbox.shape[1]
print(str(program))
class TestGenerateProposals(unittest.TestCase):
......
......@@ -227,7 +227,7 @@ class TestRpnTargetAssignOp(OpTest):
'ScoreIndex': score_index.astype('int32'),
'TargetBBox': tgt_bbox.astype('float32'),
'TargetLabel': labels.astype('int32'),
'BBox_inside_weight': bbox_inside_weights.astype('float32')
'BBoxInsideWeight': bbox_inside_weights.astype('float32')
}
def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册