提交 e0708e62 编写于 作者: J jerrywgz

refine code

上级 1c591c39
...@@ -53,8 +53,8 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { ...@@ -53,8 +53,8 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
ctx->HasOutput("TargetBBox"), ctx->HasOutput("TargetBBox"),
"Output(TargetBBox) of RpnTargetAssignOp should not be null"); "Output(TargetBBox) of RpnTargetAssignOp should not be null");
PADDLE_ENFORCE( PADDLE_ENFORCE(
ctx->HasOutput("BBox_inside_weight"), ctx->HasOutput("BBoxInsideWeight"),
"Output(BBox_inside_weight) of RpnTargetAssignOp should not be null"); "Output(BBoxInsideWeight) of RpnTargetAssignOp should not be null");
auto anchor_dims = ctx->GetInputDim("Anchor"); auto anchor_dims = ctx->GetInputDim("Anchor");
auto gt_boxes_dims = ctx->GetInputDim("GtBoxes"); auto gt_boxes_dims = ctx->GetInputDim("GtBoxes");
...@@ -71,7 +71,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { ...@@ -71,7 +71,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ScoreIndex", {-1}); ctx->SetOutputDim("ScoreIndex", {-1});
ctx->SetOutputDim("TargetLabel", {-1, 1}); ctx->SetOutputDim("TargetLabel", {-1, 1});
ctx->SetOutputDim("TargetBBox", {-1, 4}); ctx->SetOutputDim("TargetBBox", {-1, 4});
ctx->SetOutputDim("BBox_inside_weight", {-1, 4}); ctx->SetOutputDim("BBoxInsideWeight", {-1, 4});
} }
protected: protected:
...@@ -345,7 +345,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> { ...@@ -345,7 +345,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
auto* score_index = context.Output<LoDTensor>("ScoreIndex"); auto* score_index = context.Output<LoDTensor>("ScoreIndex");
auto* tgt_bbox = context.Output<LoDTensor>("TargetBBox"); auto* tgt_bbox = context.Output<LoDTensor>("TargetBBox");
auto* tgt_lbl = context.Output<LoDTensor>("TargetLabel"); auto* tgt_lbl = context.Output<LoDTensor>("TargetLabel");
auto* bbox_inside_weight = context.Output<LoDTensor>("BBox_inside_weight"); auto* bbox_inside_weight = context.Output<LoDTensor>("BBoxInsideWeight");
PADDLE_ENFORCE_EQ(gt_boxes->lod().size(), 1UL, PADDLE_ENFORCE_EQ(gt_boxes->lod().size(), 1UL,
"RpnTargetAssignOp gt_boxes needs 1 level of LoD"); "RpnTargetAssignOp gt_boxes needs 1 level of LoD");
...@@ -547,7 +547,7 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -547,7 +547,7 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
"TargetLabel", "TargetLabel",
"(Tensor<int>), The target labels of each anchor with shape " "(Tensor<int>), The target labels of each anchor with shape "
"[F + B, 1], F and B are sampled foreground and backgroud number."); "[F + B, 1], F and B are sampled foreground and backgroud number.");
AddOutput("BBox_inside_weight", AddOutput("BBoxInsideWeight",
"(Tensor), The bbox inside weight with shape " "(Tensor), The bbox inside weight with shape "
"[F, 4], F is the sampled foreground number."); "[F, 4], F is the sampled foreground number.");
AddComment(R"DOC( AddComment(R"DOC(
......
...@@ -167,7 +167,7 @@ def rpn_target_assign(bbox_pred, ...@@ -167,7 +167,7 @@ def rpn_target_assign(bbox_pred,
'ScoreIndex': score_index, 'ScoreIndex': score_index,
'TargetLabel': target_label, 'TargetLabel': target_label,
'TargetBBox': target_bbox, 'TargetBBox': target_bbox,
'BBox_inside_weight': bbox_inside_weight 'BBoxInsideWeight': bbox_inside_weight
}, },
attrs={ attrs={
'rpn_batch_size_per_im': rpn_batch_size_per_im, 'rpn_batch_size_per_im': rpn_batch_size_per_im,
......
...@@ -324,6 +324,7 @@ class TestRpnTargetAssign(unittest.TestCase): ...@@ -324,6 +324,7 @@ class TestRpnTargetAssign(unittest.TestCase):
assert pred_scores.shape[1] == 1 assert pred_scores.shape[1] == 1
assert pred_loc.shape[1] == 4 assert pred_loc.shape[1] == 4
assert pred_loc.shape[1] == tgt_bbox.shape[1] assert pred_loc.shape[1] == tgt_bbox.shape[1]
print(str(program))
class TestGenerateProposals(unittest.TestCase): class TestGenerateProposals(unittest.TestCase):
......
...@@ -227,7 +227,7 @@ class TestRpnTargetAssignOp(OpTest): ...@@ -227,7 +227,7 @@ class TestRpnTargetAssignOp(OpTest):
'ScoreIndex': score_index.astype('int32'), 'ScoreIndex': score_index.astype('int32'),
'TargetBBox': tgt_bbox.astype('float32'), 'TargetBBox': tgt_bbox.astype('float32'),
'TargetLabel': labels.astype('int32'), 'TargetLabel': labels.astype('int32'),
'BBox_inside_weight': bbox_inside_weights.astype('float32') 'BBoxInsideWeight': bbox_inside_weights.astype('float32')
} }
def test_check_output(self): def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册