提交 32d533c2 编写于 作者: D dengkaipeng

cache obj_mask and gt_match_mask. test=develop

上级 6c5a5d07
...@@ -29,6 +29,11 @@ class Yolov3LossOp : public framework::OperatorWithKernel { ...@@ -29,6 +29,11 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
"Input(GTLabel) of Yolov3LossOp should not be null."); "Input(GTLabel) of Yolov3LossOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Loss"), PADDLE_ENFORCE(ctx->HasOutput("Loss"),
"Output(Loss) of Yolov3LossOp should not be null."); "Output(Loss) of Yolov3LossOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("ObjectnessMask"),
"Output(ObjectnessMask) of Yolov3LossOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("GTMatchMask"),
"Output(GTMatchMask) of Yolov3LossOp should not be null.");
auto dim_x = ctx->GetInputDim("X"); auto dim_x = ctx->GetInputDim("X");
auto dim_gtbox = ctx->GetInputDim("GTBox"); auto dim_gtbox = ctx->GetInputDim("GTBox");
...@@ -68,6 +73,12 @@ class Yolov3LossOp : public framework::OperatorWithKernel { ...@@ -68,6 +73,12 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
std::vector<int64_t> dim_out({dim_x[0]}); std::vector<int64_t> dim_out({dim_x[0]});
ctx->SetOutputDim("Loss", framework::make_ddim(dim_out)); ctx->SetOutputDim("Loss", framework::make_ddim(dim_out));
std::vector<int64_t> dim_obj_mask({dim_x[0], mask_num, dim_x[2], dim_x[3]});
ctx->SetOutputDim("ObjectnessMask", framework::make_ddim(dim_obj_mask));
std::vector<int64_t> dim_gt_match_mask({dim_gtbox[0], dim_gtbox[1]});
ctx->SetOutputDim("GTMatchMask", framework::make_ddim(dim_gt_match_mask));
} }
protected: protected:
...@@ -103,6 +114,16 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -103,6 +114,16 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Loss", AddOutput("Loss",
"The output yolov3 loss tensor, " "The output yolov3 loss tensor, "
"This is a 1-D tensor with shape of [N]"); "This is a 1-D tensor with shape of [N]");
AddOutput("ObjectnessMask",
"This is an intermediate tensor with shape of [N, M, H, W], "
"M is the number of anchor masks. This parameter caches the "
"mask for calculate objectness loss in gradient kernel.")
.AsIntermediate();
AddOutput("GTMatchMask",
"This is an intermediate tensor with shape if [N, B], "
"B is the max box number of GT boxes. This parameter caches "
"matched mask index of each GT boxes for gradient calculate.")
.AsIntermediate();
AddAttr<int>("class_num", "The number of classes to predict."); AddAttr<int>("class_num", "The number of classes to predict.");
AddAttr<std::vector<int>>("anchors", AddAttr<std::vector<int>>("anchors",
...@@ -208,6 +229,8 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { ...@@ -208,6 +229,8 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker {
op->SetInput("GTBox", Input("GTBox")); op->SetInput("GTBox", Input("GTBox"));
op->SetInput("GTLabel", Input("GTLabel")); op->SetInput("GTLabel", Input("GTLabel"));
op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
op->SetInput("ObjectnessMask", Output("ObjectnessMask"));
op->SetInput("GTMatchMask", Output("GTMatchMask"));
op->SetAttrMap(Attrs()); op->SetAttrMap(Attrs());
......
...@@ -227,6 +227,8 @@ class Yolov3LossKernel : public framework::OpKernel<T> { ...@@ -227,6 +227,8 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
auto* gt_box = ctx.Input<Tensor>("GTBox"); auto* gt_box = ctx.Input<Tensor>("GTBox");
auto* gt_label = ctx.Input<Tensor>("GTLabel"); auto* gt_label = ctx.Input<Tensor>("GTLabel");
auto* loss = ctx.Output<Tensor>("Loss"); auto* loss = ctx.Output<Tensor>("Loss");
auto* objness_mask = ctx.Output<Tensor>("ObjectnessMask");
auto* gt_match_mask = ctx.Output<Tensor>("GTMatchMask");
auto anchors = ctx.Attr<std::vector<int>>("anchors"); auto anchors = ctx.Attr<std::vector<int>>("anchors");
auto anchor_mask = ctx.Attr<std::vector<int>>("anchor_mask"); auto anchor_mask = ctx.Attr<std::vector<int>>("anchor_mask");
int class_num = ctx.Attr<int>("class_num"); int class_num = ctx.Attr<int>("class_num");
...@@ -241,19 +243,19 @@ class Yolov3LossKernel : public framework::OpKernel<T> { ...@@ -241,19 +243,19 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
const int b = gt_box->dims()[1]; const int b = gt_box->dims()[1];
int input_size = downsample * h; int input_size = downsample * h;
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
const T* gt_box_data = gt_box->data<T>(); const T* gt_box_data = gt_box->data<T>();
const int* gt_label_data = gt_label->data<int>(); const int* gt_label_data = gt_label->data<int>();
T* loss_data = loss->mutable_data<T>({n}, ctx.GetPlace()); T* loss_data = loss->mutable_data<T>({n}, ctx.GetPlace());
memset(loss_data, 0, loss->numel() * sizeof(T)); memset(loss_data, 0, loss->numel() * sizeof(T));
int* obj_mask_data =
Tensor objness; objness_mask->mutable_data<int>({n, mask_num, h, w}, ctx.GetPlace());
int* objness_data = memset(obj_mask_data, 0, objness_mask->numel() * sizeof(int));
objness.mutable_data<int>({n, mask_num, h, w}, ctx.GetPlace()); int* gt_match_mask_data =
memset(objness_data, 0, objness.numel() * sizeof(int)); gt_match_mask->mutable_data<int>({n, b}, ctx.GetPlace());
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
for (int j = 0; j < mask_num; j++) { for (int j = 0; j < mask_num; j++) {
...@@ -277,7 +279,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> { ...@@ -277,7 +279,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
if (best_iou > ignore_thresh) { if (best_iou > ignore_thresh) {
int obj_idx = (i * mask_num + j) * stride + k * w + l; int obj_idx = (i * mask_num + j) * stride + k * w + l;
objness_data[obj_idx] = -1; obj_mask_data[obj_idx] = -1;
} }
} }
} }
...@@ -285,6 +287,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> { ...@@ -285,6 +287,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
for (int t = 0; t < b; t++) { for (int t = 0; t < b; t++) {
Box<T> gt = GetGtBox(gt_box_data, i, b, t); Box<T> gt = GetGtBox(gt_box_data, i, b, t);
if (LessEqualZero<T>(gt.w) || LessEqualZero<T>(gt.h)) { if (LessEqualZero<T>(gt.w) || LessEqualZero<T>(gt.h)) {
gt_match_mask_data[i * b + t] = -1;
continue; continue;
} }
int gi = static_cast<int>(gt.x * w); int gi = static_cast<int>(gt.x * w);
...@@ -309,6 +312,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> { ...@@ -309,6 +312,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
} }
int mask_idx = GetMaskIndex(anchor_mask, best_n); int mask_idx = GetMaskIndex(anchor_mask, best_n);
gt_match_mask_data[i * b + t] = mask_idx;
if (mask_idx >= 0) { if (mask_idx >= 0) {
int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 0); an_stride, stride, 0);
...@@ -316,7 +320,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> { ...@@ -316,7 +320,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
box_idx, gi, gj, h, input_size, stride); box_idx, gi, gj, h, input_size, stride);
int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi;
objness_data[obj_idx] = 1; obj_mask_data[obj_idx] = 1;
int label = gt_label_data[i * b + t]; int label = gt_label_data[i * b + t];
int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
...@@ -327,7 +331,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> { ...@@ -327,7 +331,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
} }
} }
CalcObjnessLoss<T>(loss_data, input_data + 4 * stride, objness_data, n, CalcObjnessLoss<T>(loss_data, input_data + 4 * stride, obj_mask_data, n,
mask_num, h, w, stride, an_stride); mask_num, h, w, stride, an_stride);
} }
}; };
...@@ -341,64 +345,35 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> { ...@@ -341,64 +345,35 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
auto* gt_label = ctx.Input<Tensor>("GTLabel"); auto* gt_label = ctx.Input<Tensor>("GTLabel");
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X")); auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss")); auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
auto* objness_mask = ctx.Input<Tensor>("ObjectnessMask");
auto* gt_match_mask = ctx.Input<Tensor>("GTMatchMask");
auto anchors = ctx.Attr<std::vector<int>>("anchors"); auto anchors = ctx.Attr<std::vector<int>>("anchors");
auto anchor_mask = ctx.Attr<std::vector<int>>("anchor_mask"); auto anchor_mask = ctx.Attr<std::vector<int>>("anchor_mask");
int class_num = ctx.Attr<int>("class_num"); int class_num = ctx.Attr<int>("class_num");
float ignore_thresh = ctx.Attr<float>("ignore_thresh");
int downsample = ctx.Attr<int>("downsample"); int downsample = ctx.Attr<int>("downsample");
const int n = input->dims()[0]; const int n = input_grad->dims()[0];
const int c = input->dims()[1]; const int c = input_grad->dims()[1];
const int h = input->dims()[2]; const int h = input_grad->dims()[2];
const int w = input->dims()[3]; const int w = input_grad->dims()[3];
const int an_num = anchors.size() / 2;
const int mask_num = anchor_mask.size(); const int mask_num = anchor_mask.size();
const int b = gt_box->dims()[1]; const int b = gt_match_mask->dims()[1];
int input_size = downsample * h; int input_size = downsample * h;
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
const T* gt_box_data = gt_box->data<T>(); const T* gt_box_data = gt_box->data<T>();
const int* gt_label_data = gt_label->data<int>(); const int* gt_label_data = gt_label->data<int>();
const T* loss_grad_data = loss_grad->data<T>(); const T* loss_grad_data = loss_grad->data<T>();
const int* obj_mask_data = objness_mask->data<int>();
const int* gt_match_mask_data = gt_match_mask->data<int>();
T* input_grad_data = T* input_grad_data =
input_grad->mutable_data<T>({n, c, h, w}, ctx.GetPlace()); input_grad->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
memset(input_grad_data, 0, input_grad->numel() * sizeof(T)); memset(input_grad_data, 0, input_grad->numel() * sizeof(T));
Tensor objness;
int* objness_data =
objness.mutable_data<int>({n, mask_num, h, w}, ctx.GetPlace());
memset(objness_data, 0, objness.numel() * sizeof(int));
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
for (int j = 0; j < mask_num; j++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
int box_idx =
GetEntryIndex(i, j, k * w + l, mask_num, an_stride, stride, 0);
Box<T> pred = GetYoloBox(input_data, anchors, l, k, anchor_mask[j],
h, input_size, box_idx, stride);
T best_iou = 0;
for (int t = 0; t < b; t++) {
Box<T> gt = GetGtBox(gt_box_data, i, b, t);
if (LessEqualZero<T>(gt.w) || LessEqualZero<T>(gt.h)) {
continue;
}
T iou = CalcBoxIoU(pred, gt);
if (iou > best_iou) {
best_iou = iou;
}
}
if (best_iou > ignore_thresh) {
int obj_idx = (i * mask_num + j) * stride + k * w + l;
objness_data[obj_idx] = -1;
}
}
}
}
for (int t = 0; t < b; t++) { for (int t = 0; t < b; t++) {
Box<T> gt = GetGtBox(gt_box_data, i, b, t); Box<T> gt = GetGtBox(gt_box_data, i, b, t);
if (LessEqualZero<T>(gt.w) || LessEqualZero<T>(gt.h)) { if (LessEqualZero<T>(gt.w) || LessEqualZero<T>(gt.h)) {
...@@ -406,35 +381,14 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> { ...@@ -406,35 +381,14 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
} }
int gi = static_cast<int>(gt.x * w); int gi = static_cast<int>(gt.x * w);
int gj = static_cast<int>(gt.y * h); int gj = static_cast<int>(gt.y * h);
Box<T> gt_shift = gt;
gt_shift.x = 0.0;
gt_shift.y = 0.0;
T best_iou = 0.0;
int best_n = 0;
for (int an_idx = 0; an_idx < an_num; an_idx++) {
Box<T> an_box;
an_box.x = 0.0;
an_box.y = 0.0;
an_box.w = anchors[2 * an_idx] / static_cast<T>(input_size);
an_box.h = anchors[2 * an_idx + 1] / static_cast<T>(input_size);
float iou = CalcBoxIoU<T>(an_box, gt_shift);
// TO DO: iou > 0.5 ?
if (iou > best_iou) {
best_iou = iou;
best_n = an_idx;
}
}
int mask_idx = GetMaskIndex(anchor_mask, best_n); int mask_idx = gt_match_mask_data[i * b + t];
if (mask_idx >= 0) { if (mask_idx >= 0) {
int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 0); an_stride, stride, 0);
CalcBoxLocationLossGrad<T>(input_grad_data, loss_grad_data[i], CalcBoxLocationLossGrad<T>(
input_data, gt, anchors, best_n, box_idx, input_grad_data, loss_grad_data[i], input_data, gt, anchors,
gi, gj, h, input_size, stride); anchor_mask[mask_idx], box_idx, gi, gj, h, input_size, stride);
int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi;
objness_data[obj_idx] = 1;
int label = gt_label_data[i * b + t]; int label = gt_label_data[i * b + t];
int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
...@@ -446,7 +400,7 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> { ...@@ -446,7 +400,7 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
} }
CalcObjnessLossGrad<T>(input_grad_data + 4 * stride, loss_grad_data, CalcObjnessLossGrad<T>(input_grad_data + 4 * stride, loss_grad_data,
input_data + 4 * stride, objness_data, n, mask_num, input_data + 4 * stride, obj_mask_data, n, mask_num,
h, w, stride, an_stride); h, w, stride, an_stride);
} }
}; };
......
...@@ -483,6 +483,9 @@ def yolov3_loss(x, ...@@ -483,6 +483,9 @@ def yolov3_loss(x,
loss = helper.create_variable( loss = helper.create_variable(
name=name, dtype=x.dtype, persistable=False) name=name, dtype=x.dtype, persistable=False)
objectness_mask = helper.create_variable_for_type_inference(dtype='int32')
gt_match_mask = helper.create_variable_for_type_inference(dtype='int32')
attrs = { attrs = {
"anchors": anchors, "anchors": anchors,
"anchor_mask": anchor_mask, "anchor_mask": anchor_mask,
...@@ -496,7 +499,11 @@ def yolov3_loss(x, ...@@ -496,7 +499,11 @@ def yolov3_loss(x,
inputs={"X": x, inputs={"X": x,
"GTBox": gtbox, "GTBox": gtbox,
"GTLabel": gtlabel}, "GTLabel": gtlabel},
outputs={'Loss': loss}, outputs={
'Loss': loss,
'ObjectnessMask': objectness_mask,
'GTMatchMask': gt_match_mask
},
attrs=attrs) attrs=attrs)
return loss return loss
......
...@@ -116,13 +116,17 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs): ...@@ -116,13 +116,17 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs):
anchor_boxes = np.tile(anchor_boxes[np.newaxis, :, :], (n, 1, 1)) anchor_boxes = np.tile(anchor_boxes[np.newaxis, :, :], (n, 1, 1))
ious = batch_xywh_box_iou(gtbox_shift, anchor_boxes) ious = batch_xywh_box_iou(gtbox_shift, anchor_boxes)
iou_matches = np.argmax(ious, axis=-1) iou_matches = np.argmax(ious, axis=-1)
gt_matches = iou_matches.copy()
for i in range(n): for i in range(n):
for j in range(b): for j in range(b):
if gtbox[i, j, 2:].sum() == 0: if gtbox[i, j, 2:].sum() == 0:
gt_matches[i, j] = -1
continue continue
if iou_matches[i, j] not in anchor_mask: if iou_matches[i, j] not in anchor_mask:
gt_matches[i, j] = -1
continue continue
an_idx = anchor_mask.index(iou_matches[i, j]) an_idx = anchor_mask.index(iou_matches[i, j])
gt_matches[i, j] = an_idx
gi = int(gtbox[i, j, 0] * w) gi = int(gtbox[i, j, 0] * w)
gj = int(gtbox[i, j, 1] * h) gj = int(gtbox[i, j, 1] * h)
...@@ -146,7 +150,8 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs): ...@@ -146,7 +150,8 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs):
if objness[i, j] >= 0: if objness[i, j] >= 0:
loss[i] += sce(pred_obj[i, j], objness[i, j]) loss[i] += sce(pred_obj[i, j], objness[i, j])
return loss return (loss, objness.reshape((n, mask_num, h, w)).astype('int32'), \
gt_matches.astype('int32'))
class TestYolov3LossOp(OpTest): class TestYolov3LossOp(OpTest):
...@@ -173,11 +178,16 @@ class TestYolov3LossOp(OpTest): ...@@ -173,11 +178,16 @@ class TestYolov3LossOp(OpTest):
'GTBox': gtbox.astype('float32'), 'GTBox': gtbox.astype('float32'),
'GTLabel': gtlabel.astype('int32') 'GTLabel': gtlabel.astype('int32')
} }
self.outputs = {'Loss': YOLOv3Loss(x, gtbox, gtlabel, self.attrs)} loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, self.attrs)
self.outputs = {
'Loss': loss,
'ObjectnessMask': objness,
"GTMatchMask": gt_matches
}
def test_check_output(self): def test_check_output(self):
place = core.CPUPlace() place = core.CPUPlace()
self.check_output_with_place(place, atol=1e-3) self.check_output_with_place(place, atol=2e-3)
def test_check_grad_ignore_gtbox(self): def test_check_grad_ignore_gtbox(self):
place = core.CPUPlace() place = core.CPUPlace()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册