diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc
index aa4ba3b62eb342ffc3c5e49a38664960c7372341..8c46e341d625098349000ea2aff6af004841fe9d 100644
--- a/paddle/fluid/operators/yolov3_loss_op.cc
+++ b/paddle/fluid/operators/yolov3_loss_op.cc
@@ -35,13 +35,16 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
     auto dim_gtlabel = ctx->GetInputDim("GTLabel");
     auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
     int anchor_num = anchors.size() / 2;
+    auto anchor_mask = ctx->Attrs().Get<std::vector<int>>("anchor_mask");
+    int mask_num = anchor_mask.size();
     auto class_num = ctx->Attrs().Get<int>("class_num");
     PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor.");
     PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3],
                       "Input(X) dim[3] and dim[4] should be euqal.");
-    PADDLE_ENFORCE_EQ(dim_x[1], anchor_num * (5 + class_num),
-                      "Input(X) dim[1] should be equal to (anchor_number * (5 "
-                      "+ class_num)).");
+    PADDLE_ENFORCE_EQ(
+        dim_x[1], mask_num * (5 + class_num),
+        "Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
+        "+ class_num)).");
     PADDLE_ENFORCE_EQ(dim_gtbox.size(), 3,
                       "Input(GTBox) should be a 3-D tensor");
     PADDLE_ENFORCE_EQ(dim_gtbox[2], 4, "Input(GTBox) dim[2] should be 5");
@@ -55,6 +58,11 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
                    "Attr(anchors) length should be greater then 0.");
     PADDLE_ENFORCE_EQ(anchors.size() % 2, 0,
                       "Attr(anchors) length should be even integer.");
+    for (size_t i = 0; i < anchor_mask.size(); i++) {
+      PADDLE_ENFORCE_LT(
+          anchor_mask[i], anchor_num,
+          "Attr(anchor_mask) should not crossover Attr(anchors).");
+    }
     PADDLE_ENFORCE_GT(class_num, 0,
                       "Attr(class_num) should be an integer greater then 0.");
 
@@ -74,7 +82,7 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X",
-             "The input tensor of YOLO v3 loss operator, "
+             "The input tensor of YOLOv3 loss operator, "
              "This is a 4-D tensor with shape of [N, C, H, W]."
              "H and W should be same, and the second dimention(C) stores"
              "box locations, confidence score and classification one-hot"
@@ -99,13 +107,20 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<int>("class_num", "The number of classes to predict.");
     AddAttr<std::vector<int>>("anchors",
                               "The anchor width and height, "
-                              "it will be parsed pair by pair.");
-    AddAttr<int>("input_size",
-                 "The input size of YOLOv3 net, "
-                 "generally this is set as 320, 416 or 608.")
-        .SetDefault(406);
+                              "it will be parsed pair by pair.")
+        .SetDefault(std::vector<int>{});
+    AddAttr<std::vector<int>>("anchor_mask",
+                              "The mask index of anchors used in "
+                              "current YOLOv3 loss calculation.")
+        .SetDefault(std::vector<int>{});
+    AddAttr<int>("downsample",
+                 "The downsample ratio from network input to YOLOv3 loss "
+                 "input, so 32, 16, 8 should be set for the first, second, "
+                 "and third YOLOv3 loss operators.")
+        .SetDefault(32);
     AddAttr<float>("ignore_thresh",
-                   "The ignore threshold to ignore confidence loss.");
+                   "The ignore threshold to ignore confidence loss.")
+        .SetDefault(0.7);
     AddComment(R"DOC(
          This operator generate yolov3 loss by given predict result and ground
          truth boxes.
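[Editor's note] The attribute changes above (anchor_mask is new, downsample replaces input_size) surface in the Python wrapper further down in this patch. The snippet below is a minimal usage sketch assuming a PaddlePaddle build that includes this change; the concrete anchors, shapes, and class_num are illustrative values borrowed from the patch's docstring and unit tests, not requirements.

import paddle.fluid as fluid

# 255 input channels = len(anchor_mask) * (5 + class_num) = 3 * (5 + 80)
x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32')
gtbox = fluid.layers.data(name='gtbox', shape=[6, 4], dtype='float32')   # (x, y, w, h) per box
gtlabel = fluid.layers.data(name='gtlabel', shape=[6], dtype='int32')    # class id per box
anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
           59, 119, 116, 90, 156, 198, 373, 326]   # all 9 anchors, parsed pair by pair
loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, gtlabel=gtlabel,
                                anchors=anchors, anchor_mask=[0, 1, 2],
                                class_num=80, ignore_thresh=0.7,
                                downsample=32)     # downsample ratio of this loss input
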
diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index e32cd309674c71a2d6c88d34adcea69ccb84ce5b..9254a6cf6f62bafa8dff48ce0811b974bbc33e4c 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -321,6 +321,182 @@ static void CalcYolov3LossGrad(T* input_grad_data, const Tensor& loss_grad, obj_mask_data, n, an_num, grid_num, class_num, class_num); } +static int mask_index(std::vector mask, int val) { + for (int i = 0; i < mask.size(); i++) { + if (mask[i] == val) { + return i; + } + } + return -1; +} + +template +struct Box { + float x, y, w, h; +}; + +template +static inline T sigmoid(T x) { + return 1.0 / (1.0 + std::exp(-x)); +} + +template +static inline void sigmoid_arrray(T* arr, int len) { + for (int i = 0; i < len; i++) { + arr[i] = sigmoid(arr[i]); + } +} + +template +static inline Box get_yolo_box(const T* x, std::vector anchors, int i, + int j, int an_idx, int grid_size, + int input_size, int index, int stride) { + Box b; + b.x = (i + sigmoid(x[index])) / grid_size; + b.y = (j + sigmoid(x[index + stride])) / grid_size; + b.w = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] / input_size; + b.h = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] / input_size; + return b; +} + +template +static inline Box get_gt_box(const T* gt, int batch, int max_boxes, + int idx) { + Box b; + b.x = gt[(batch * max_boxes + idx) * 4]; + b.y = gt[(batch * max_boxes + idx) * 4 + 1]; + b.w = gt[(batch * max_boxes + idx) * 4 + 2]; + b.h = gt[(batch * max_boxes + idx) * 4 + 3]; + return b; +} + +template +static inline T overlap(T c1, T w1, T c2, T w2) { + T l1 = c1 - w1 / 2.0; + T l2 = c2 - w2 / 2.0; + T left = l1 > l2 ? l1 : l2; + T r1 = c1 + w1 / 2.0; + T r2 = c2 + w2 / 2.0; + T right = r1 < r2 ? r1 : r2; + return right - left; +} + +template +static inline T box_iou(Box b1, Box b2) { + T w = overlap(b1.x, b1.w, b2.x, b2.w); + T h = overlap(b1.y, b1.h, b2.y, b2.h); + T inter_area = (w < 0 || h < 0) ? 
0.0 : w * h; + T union_area = b1.w * b1.h + b2.w * b2.h - inter_area; + return inter_area / union_area; +} + +static inline int entry_index(int batch, int an_idx, int hw_idx, int an_num, + int an_stride, int stride, int entry) { + return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; +} + +template +static void CalcBoxLocationLoss(T* loss, const T* input, Box gt, + std::vector anchors, int an_idx, + int box_idx, int gi, int gj, int grid_size, + int input_size, int stride) { + T tx = gt.x * grid_size - gi; + T ty = gt.y * grid_size - gj; + T tw = std::log(gt.w * input_size / anchors[2 * an_idx]); + T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); + + T scale = 2.0 - gt.w * gt.h; + loss[0] += SCE(input[box_idx], tx) * scale; + loss[0] += SCE(input[box_idx + stride], ty) * scale; + loss[0] += L1Loss(input[box_idx + 2 * stride], tw) * scale; + loss[0] += L1Loss(input[box_idx + 3 * stride], th) * scale; +} + +template +static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, + Box gt, std::vector anchors, + int an_idx, int box_idx, int gi, int gj, + int grid_size, int input_size, int stride) { + T tx = gt.x * grid_size - gi; + T ty = gt.y * grid_size - gj; + T tw = std::log(gt.w * input_size / anchors[2 * an_idx]); + T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); + + T scale = 2.0 - gt.w * gt.h; + input_grad[box_idx] = SCEGrad(input[box_idx], tx) * scale * loss; + input_grad[box_idx + stride] = + SCEGrad(input[box_idx + stride], ty) * scale * loss; + input_grad[box_idx + 2 * stride] = + L1LossGrad(input[box_idx + 2 * stride], tw) * scale * loss; + input_grad[box_idx + 3 * stride] = + L1LossGrad(input[box_idx + 3 * stride], th) * scale * loss; +} + +template +static inline void CalcLabelLoss(T* loss, const T* input, const int index, + const int label, const int class_num, + const int stride) { + for (int i = 0; i < class_num; i++) { + loss[0] += SCE(input[index + i * stride], (i == label) ? 1.0 : 0.0); + } +} + +template +static inline void CalcLabelLossGrad(T* input_grad, const T loss, + const T* input, const int index, + const int label, const int class_num, + const int stride) { + for (int i = 0; i < class_num; i++) { + input_grad[index + i * stride] = + SCEGrad(input[index + i * stride], (i == label) ? 
1.0 : 0.0) * loss; + } +} + +template +static inline void CalcObjnessLoss(T* loss, const T* input, const int* objness, + const int n, const int an_num, const int h, + const int w, const int stride, + const int an_stride) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + int obj = objness[k * w + l]; + if (obj >= 0) { + loss[i] += SCE(input[k * w + l], static_cast(obj)); + } + } + } + objness += stride; + input += an_stride; + } + } +} + +template +static inline void CalcObjnessLossGrad(T* input_grad, const T* loss, + const T* input, const int* objness, + const int n, const int an_num, + const int h, const int w, + const int stride, const int an_stride) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + int obj = objness[k * w + l]; + if (obj >= 0) { + input_grad[k * w + l] = + SCEGrad(input[k * w + l], static_cast(obj)) * loss[i]; + } + } + } + objness += stride; + input += an_stride; + input_grad += an_stride; + } + } +} + template class Yolov3LossKernel : public framework::OpKernel { public: @@ -330,55 +506,158 @@ class Yolov3LossKernel : public framework::OpKernel { auto* gt_label = ctx.Input("GTLabel"); auto* loss = ctx.Output("Loss"); auto anchors = ctx.Attr>("anchors"); + auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); - int input_size = ctx.Attr("input_size"); float ignore_thresh = ctx.Attr("ignore_thresh"); + int downsample = ctx.Attr("downsample"); const int n = input->dims()[0]; const int h = input->dims()[2]; const int w = input->dims()[3]; const int an_num = anchors.size() / 2; + const int mask_num = anchor_mask.size(); + const int b = gt_box->dims()[1]; + int input_size = downsample * h; - Tensor conf_mask, obj_mask; - Tensor tx, ty, tw, th, tweight, tconf, tclass; - conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - - math::SetConstant constant; - constant(ctx.template device_context(), - &conf_mask, static_cast(1.0)); - constant(ctx.template device_context(), - &obj_mask, static_cast(0.0)); - constant(ctx.template device_context(), &tx, - static_cast(0.0)); - constant(ctx.template device_context(), &ty, - static_cast(0.0)); - constant(ctx.template device_context(), &tw, - static_cast(0.0)); - constant(ctx.template device_context(), &th, - static_cast(0.0)); - constant(ctx.template device_context(), - &tweight, static_cast(0.0)); - constant(ctx.template device_context(), &tconf, - static_cast(0.0)); - constant(ctx.template device_context(), &tclass, - static_cast(0.0)); - - PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, - h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, &tweight, - &tconf, &tclass); - + const T* input_data = input->data(); + const T* gt_box_data = gt_box->data(); + const int* gt_label_data = gt_label->data(); T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); - memset(loss_data, 0, n * sizeof(T)); - CalcYolov3Loss(loss_data, *input, tx, ty, tw, th, tweight, tconf, 
tclass, - conf_mask, obj_mask); + memset(loss_data, 0, n * sizeof(int)); + + Tensor objness; + int* objness_data = + objness.mutable_data({n, mask_num, h, w}, ctx.GetPlace()); + memset(objness_data, 0, objness.numel() * sizeof(int)); + + const int stride = h * w; + const int an_stride = (class_num + 5) * stride; + + for (int i = 0; i < n; i++) { + for (int j = 0; j < mask_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + int box_idx = + entry_index(i, j, k * w + l, mask_num, an_stride, stride, 0); + Box pred = + get_yolo_box(input_data, anchors, l, k, anchor_mask[j], h, + input_size, box_idx, stride); + T best_iou = 0; + // int best_t = 0; + for (int t = 0; t < b; t++) { + if (isZero(gt_box_data[i * b * 4 + t * 4]) && + isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + continue; + } + Box gt = get_gt_box(gt_box_data, i, b, t); + T iou = box_iou(pred, gt); + if (iou > best_iou) { + best_iou = iou; + // best_t = t; + } + } + + if (best_iou > ignore_thresh) { + int obj_idx = (i * mask_num + j) * stride + k * w + l; + objness_data[obj_idx] = -1; + } + } + } + } + for (int t = 0; t < b; t++) { + if (isZero(gt_box_data[i * b * 4 + t * 4]) && + isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + continue; + } + Box gt = get_gt_box(gt_box_data, i, b, t); + int gi = static_cast(gt.x * w); + int gj = static_cast(gt.y * h); + Box gt_shift = gt; + gt_shift.x = 0.0; + gt_shift.y = 0.0; + T best_iou = 0.0; + int best_n = 0; + for (int an_idx = 0; an_idx < an_num; an_idx++) { + Box an_box; + an_box.x = 0.0; + an_box.y = 0.0; + an_box.w = anchors[2 * an_idx] / static_cast(input_size); + an_box.h = anchors[2 * an_idx + 1] / static_cast(input_size); + float iou = box_iou(an_box, gt_shift); + // TO DO: iou > 0.5 ? + if (iou > best_iou) { + best_iou = iou; + best_n = an_idx; + } + } + + int mask_idx = mask_index(anchor_mask, best_n); + if (mask_idx >= 0) { + int box_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 0); + CalcBoxLocationLoss(loss_data + i, input_data, gt, anchors, best_n, + box_idx, gi, gj, h, input_size, stride); + + int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; + objness_data[obj_idx] = 1; + + int label = gt_label_data[i * b + t]; + int label_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 5); + CalcLabelLoss(loss_data + i, input_data, label_idx, label, + class_num, stride); + } + } + } + + CalcObjnessLoss(loss_data, input_data + 4 * stride, objness_data, n, + mask_num, h, w, stride, an_stride); + + // Tensor conf_mask, obj_mask; + // Tensor tx, ty, tw, th, tweight, tconf, tclass; + // conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + // + // math::SetConstant constant; + // constant(ctx.template device_context(), + // &conf_mask, static_cast(1.0)); + // constant(ctx.template device_context(), + // &obj_mask, static_cast(0.0)); + // constant(ctx.template device_context(), &tx, + // static_cast(0.0)); + // constant(ctx.template device_context(), &ty, + // static_cast(0.0)); + // constant(ctx.template 
device_context(), &tw, + // static_cast(0.0)); + // constant(ctx.template device_context(), &th, + // static_cast(0.0)); + // constant(ctx.template device_context(), + // &tweight, static_cast(0.0)); + // constant(ctx.template device_context(), + // &tconf, + // static_cast(0.0)); + // constant(ctx.template device_context(), + // &tclass, + // static_cast(0.0)); + // + // PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, + // input_size, + // h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, + // &tweight, + // &tconf, &tclass); + // + // T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); + // memset(loss_data, 0, n * sizeof(T)); + // CalcYolov3Loss(loss_data, *input, tx, ty, tw, th, tweight, tconf, + // tclass, + // conf_mask, obj_mask); } }; @@ -389,59 +668,172 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); auto anchors = ctx.Attr>("anchors"); + auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); float ignore_thresh = ctx.Attr("ignore_thresh"); - auto* input_grad = ctx.Output(framework::GradVarName("X")); - auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); - int input_size = ctx.Attr("input_size"); + int downsample = ctx.Attr("downsample"); const int n = input->dims()[0]; const int c = input->dims()[1]; const int h = input->dims()[2]; const int w = input->dims()[3]; const int an_num = anchors.size() / 2; - - Tensor conf_mask, obj_mask; - Tensor tx, ty, tw, th, tweight, tconf, tclass; - conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - - math::SetConstant constant; - constant(ctx.template device_context(), - &conf_mask, static_cast(1.0)); - constant(ctx.template device_context(), - &obj_mask, static_cast(0.0)); - constant(ctx.template device_context(), &tx, - static_cast(0.0)); - constant(ctx.template device_context(), &ty, - static_cast(0.0)); - constant(ctx.template device_context(), &tw, - static_cast(0.0)); - constant(ctx.template device_context(), &th, - static_cast(0.0)); - constant(ctx.template device_context(), - &tweight, static_cast(0.0)); - constant(ctx.template device_context(), &tconf, - static_cast(0.0)); - constant(ctx.template device_context(), &tclass, - static_cast(0.0)); - - PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, - h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, &tweight, - &tconf, &tclass); - + const int mask_num = anchor_mask.size(); + const int b = gt_box->dims()[1]; + int input_size = downsample * h; + + const T* input_data = input->data(); + const T* gt_box_data = gt_box->data(); + const int* gt_label_data = gt_label->data(); + const T* loss_grad_data = loss_grad->data(); T* input_grad_data = input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); - CalcYolov3LossGrad(input_grad_data, *loss_grad, *input, tx, ty, tw, th, - tweight, tconf, tclass, conf_mask, obj_mask); + 
memset(input_grad_data, 0, input_grad->numel() * sizeof(T)); + + Tensor objness; + int* objness_data = + objness.mutable_data({n, mask_num, h, w}, ctx.GetPlace()); + memset(objness_data, 0, objness.numel() * sizeof(int)); + + const int stride = h * w; + const int an_stride = (class_num + 5) * stride; + + for (int i = 0; i < n; i++) { + for (int j = 0; j < mask_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + int box_idx = + entry_index(i, j, k * w + l, mask_num, an_stride, stride, 0); + Box pred = + get_yolo_box(input_data, anchors, l, k, anchor_mask[j], h, + input_size, box_idx, stride); + T best_iou = 0; + // int best_t = 0; + for (int t = 0; t < b; t++) { + if (isZero(gt_box_data[i * b * 4 + t * 4]) && + isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + continue; + } + Box gt = get_gt_box(gt_box_data, i, b, t); + T iou = box_iou(pred, gt); + if (iou > best_iou) { + best_iou = iou; + // best_t = t; + } + } + + if (best_iou > ignore_thresh) { + int obj_idx = (i * mask_num + j) * stride + k * w + l; + objness_data[obj_idx] = -1; + } + } + } + } + for (int t = 0; t < b; t++) { + if (isZero(gt_box_data[i * b * 4 + t * 4]) && + isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + continue; + } + Box gt = get_gt_box(gt_box_data, i, b, t); + int gi = static_cast(gt.x * w); + int gj = static_cast(gt.y * h); + Box gt_shift = gt; + gt_shift.x = 0.0; + gt_shift.y = 0.0; + T best_iou = 0.0; + int best_n = 0; + for (int an_idx = 0; an_idx < an_num; an_idx++) { + Box an_box; + an_box.x = 0.0; + an_box.y = 0.0; + an_box.w = anchors[2 * an_idx] / static_cast(input_size); + an_box.h = anchors[2 * an_idx + 1] / static_cast(input_size); + float iou = box_iou(an_box, gt_shift); + // TO DO: iou > 0.5 ? + if (iou > best_iou) { + best_iou = iou; + best_n = an_idx; + } + } + + int mask_idx = mask_index(anchor_mask, best_n); + if (mask_idx >= 0) { + int box_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 0); + CalcBoxLocationLossGrad(input_grad_data, loss_grad_data[i], + input_data, gt, anchors, best_n, box_idx, + gi, gj, h, input_size, stride); + + int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; + objness_data[obj_idx] = 1; + + int label = gt_label_data[i * b + t]; + int label_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 5); + CalcLabelLossGrad(input_grad_data, loss_grad_data[i], input_data, + label_idx, label, class_num, stride); + } + } + } + + CalcObjnessLossGrad(input_grad_data + 4 * stride, loss_grad_data, + input_data + 4 * stride, objness_data, n, mask_num, + h, w, stride, an_stride); + + // const int n = input->dims()[0]; + // const int c = input->dims()[1]; + // const int h = input->dims()[2]; + // const int w = input->dims()[3]; + // const int an_num = anchors.size() / 2; + // + // Tensor conf_mask, obj_mask; + // Tensor tx, ty, tw, th, tweight, tconf, tclass; + // conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + // + // math::SetConstant constant; + // constant(ctx.template device_context(), + // &conf_mask, 
static_cast(1.0)); + // constant(ctx.template device_context(), + // &obj_mask, static_cast(0.0)); + // constant(ctx.template device_context(), &tx, + // static_cast(0.0)); + // constant(ctx.template device_context(), &ty, + // static_cast(0.0)); + // constant(ctx.template device_context(), &tw, + // static_cast(0.0)); + // constant(ctx.template device_context(), &th, + // static_cast(0.0)); + // constant(ctx.template device_context(), + // &tweight, static_cast(0.0)); + // constant(ctx.template device_context(), + // &tconf, + // static_cast(0.0)); + // constant(ctx.template device_context(), + // &tclass, + // static_cast(0.0)); + // + // PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, + // input_size, + // h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, + // &tweight, + // &tconf, &tclass); + // + // T* input_grad_data = + // input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); + // CalcYolov3LossGrad(input_grad_data, *loss_grad, *input, tx, ty, tw, + // th, + // tweight, tconf, tclass, conf_mask, obj_mask); } }; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 92823af1e02087e6dada07e829111eb515b3caf7..542162b7f41b8c116625e7956c2a64d7711f85ea 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -413,9 +413,10 @@ def yolov3_loss(x, gtbox, gtlabel, anchors, + anchor_mask, class_num, ignore_thresh, - input_size, + downsample, name=None): """ ${comment} @@ -430,9 +431,10 @@ def yolov3_loss(x, gtlabel (Variable): class id of ground truth boxes, shoud be ins shape of [N, B]. anchors (list|tuple): ${anchors_comment} + anchor_mask (list|tuple): ${anchor_mask_comment} class_num (int): ${class_num_comment} ignore_thresh (float): ${ignore_thresh_comment} - input_size (int): ${input_size_comment} + downsample (int): ${downsample_comment} name (string): the name of yolov3 loss Returns: @@ -452,7 +454,8 @@ def yolov3_loss(x, x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32') gtbox = fluid.layers.data(name='gtbox', shape=[6, 5], dtype='float32') gtlabel = fluid.layers.data(name='gtlabel', shape=[6, 1], dtype='int32') - anchors = [10, 13, 16, 30, 33, 23] + anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] + anchors = [0, 1, 2] loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, class_num=80 anchors=anchors, ignore_thresh=0.5) """ @@ -466,6 +469,8 @@ def yolov3_loss(x, raise TypeError("Input gtlabel of yolov3_loss must be Variable") if not isinstance(anchors, list) and not isinstance(anchors, tuple): raise TypeError("Attr anchors of yolov3_loss must be list or tuple") + if not isinstance(anchor_mask, list) and not isinstance(anchor_mask, tuple): + raise TypeError("Attr anchor_mask of yolov3_loss must be list or tuple") if not isinstance(class_num, int): raise TypeError("Attr class_num of yolov3_loss must be an integer") if not isinstance(ignore_thresh, float): @@ -480,9 +485,10 @@ def yolov3_loss(x, attrs = { "anchors": anchors, + "anchor_mask": anchor_mask, "class_num": class_num, "ignore_thresh": ignore_thresh, - "input_size": input_size, + "downsample": downsample, } helper.append_op( diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 7d75562900e6498fc3c7bc8a3a35bae55ab09066..e11205d2bf33f79cfba07e5f0f66319c7c4171ac 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -463,8 +463,8 @@ class 
TestYoloDetection(unittest.TestCase): x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') gtbox = layers.data(name='gtbox', shape=[10, 4], dtype='float32') gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32') - loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13], 10, - 0.7, 416) + loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13], + [0, 1], 10, 0.7, 32) self.assertIsNotNone(loss) diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index e52047b0ad651774145b589217156a6501547cee..3cada49647d42ac33f2d7cc2de99744156fa5582 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -22,32 +22,42 @@ from op_test import OpTest from paddle.fluid import core - -def l1loss(x, y, weight): - n = x.shape[0] - x = x.reshape((n, -1)) - y = y.reshape((n, -1)) - weight = weight.reshape((n, -1)) - return (np.abs(y - x) * weight).sum(axis=1) +# def l1loss(x, y, weight): +# n = x.shape[0] +# x = x.reshape((n, -1)) +# y = y.reshape((n, -1)) +# weight = weight.reshape((n, -1)) +# return (np.abs(y - x) * weight).sum(axis=1) +# +# +# def mse(x, y, weight): +# n = x.shape[0] +# x = x.reshape((n, -1)) +# y = y.reshape((n, -1)) +# weight = weight.reshape((n, -1)) +# return ((y - x)**2 * weight).sum(axis=1) +# +# +# def sce(x, label, weight): +# n = x.shape[0] +# x = x.reshape((n, -1)) +# label = label.reshape((n, -1)) +# weight = weight.reshape((n, -1)) +# sigmoid_x = expit(x) +# term1 = label * np.log(sigmoid_x) +# term2 = (1.0 - label) * np.log(1.0 - sigmoid_x) +# return ((-term1 - term2) * weight).sum(axis=1) -def mse(x, y, weight): - n = x.shape[0] - x = x.reshape((n, -1)) - y = y.reshape((n, -1)) - weight = weight.reshape((n, -1)) - return ((y - x)**2 * weight).sum(axis=1) +def l1loss(x, y): + return abs(x - y) -def sce(x, label, weight): - n = x.shape[0] - x = x.reshape((n, -1)) - label = label.reshape((n, -1)) - weight = weight.reshape((n, -1)) +def sce(x, label): sigmoid_x = expit(x) term1 = label * np.log(sigmoid_x) term2 = (1.0 - label) * np.log(1.0 - sigmoid_x) - return ((-term1 - term2) * weight).sum(axis=1) + return -term1 - term2 def box_iou(box1, box2): @@ -160,6 +170,121 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs): return loss_x + loss_y + loss_w + loss_h + loss_obj + loss_class +def sigmoid(x): + return 1.0 / (1.0 + np.exp(-1.0 * x)) + + +def batch_xywh_box_iou(box1, box2): + b1_left = box1[:, :, 0] - box1[:, :, 2] / 2 + b1_right = box1[:, :, 0] + box1[:, :, 2] / 2 + b1_top = box1[:, :, 1] - box1[:, :, 3] / 2 + b1_bottom = box1[:, :, 1] + box1[:, :, 3] / 2 + + b2_left = box2[:, :, 0] - box2[:, :, 2] / 2 + b2_right = box2[:, :, 0] + box2[:, :, 2] / 2 + b2_top = box2[:, :, 1] - box2[:, :, 3] / 2 + b2_bottom = box2[:, :, 1] + box2[:, :, 3] / 2 + + left = np.maximum(b1_left[:, :, np.newaxis], b2_left[:, np.newaxis, :]) + right = np.minimum(b1_right[:, :, np.newaxis], b2_right[:, np.newaxis, :]) + top = np.maximum(b1_top[:, :, np.newaxis], b2_top[:, np.newaxis, :]) + bottom = np.minimum(b1_bottom[:, :, np.newaxis], + b2_bottom[:, np.newaxis, :]) + + inter_w = np.clip(right - left, 0., 1.) + inter_h = np.clip(bottom - top, 0., 1.) 
+ inter_area = inter_w * inter_h + + b1_area = (b1_right - b1_left) * (b1_bottom - b1_top) + b2_area = (b2_right - b2_left) * (b2_bottom - b2_top) + union = b1_area[:, :, np.newaxis] + b2_area[:, np.newaxis, :] - inter_area + + return inter_area / union + + +def YOLOv3Loss(x, gtbox, gtlabel, attrs): + n, c, h, w = x.shape + b = gtbox.shape[1] + anchors = attrs['anchors'] + an_num = len(anchors) // 2 + anchor_mask = attrs['anchor_mask'] + mask_num = len(anchor_mask) + class_num = attrs["class_num"] + ignore_thresh = attrs['ignore_thresh'] + downsample = attrs['downsample'] + input_size = downsample * h + x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) + loss = np.zeros((n)).astype('float32') + + pred_box = x[:, :, :, :, :4].copy() + grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1)) + grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w)) + pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w + pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h + + mask_anchors = [] + for m in anchor_mask: + mask_anchors.append((anchors[2 * m], anchors[2 * m + 1])) + anchors_s = np.array( + [(an_w / input_size, an_h / input_size) for an_w, an_h in mask_anchors]) + anchor_w = anchors_s[:, 0:1].reshape((1, mask_num, 1, 1)) + anchor_h = anchors_s[:, 1:2].reshape((1, mask_num, 1, 1)) + pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w + pred_box[:, :, :, :, 3] = np.exp(pred_box[:, :, :, :, 3]) * anchor_h + + pred_box = pred_box.reshape((n, -1, 4)) + pred_obj = x[:, :, :, :, 4].reshape((n, -1)) + objness = np.zeros(pred_box.shape[:2]) + ious = batch_xywh_box_iou(pred_box, gtbox) + ious_max = np.max(ious, axis=-1) + objness = np.where(ious_max > ignore_thresh, -np.ones_like(objness), + objness) + + gtbox_shift = gtbox.copy() + gtbox_shift[:, :, 0] = 0 + gtbox_shift[:, :, 1] = 0 + + anchors = [(anchors[2 * i], anchors[2 * i + 1]) for i in range(0, an_num)] + anchors_s = np.array( + [(an_w / input_size, an_h / input_size) for an_w, an_h in anchors]) + anchor_boxes = np.concatenate( + [np.zeros_like(anchors_s), anchors_s], axis=-1) + anchor_boxes = np.tile(anchor_boxes[np.newaxis, :, :], (n, 1, 1)) + ious = batch_xywh_box_iou(gtbox_shift, anchor_boxes) + iou_matches = np.argmax(ious, axis=-1) + for i in range(n): + for j in range(b): + if gtbox[i, j, 2:].sum() == 0: + continue + if iou_matches[i, j] not in anchor_mask: + continue + an_idx = anchor_mask.index(iou_matches[i, j]) + gi = int(gtbox[i, j, 0] * w) + gj = int(gtbox[i, j, 1] * h) + + tx = gtbox[i, j, 0] * w - gi + ty = gtbox[i, j, 1] * w - gj + tw = np.log(gtbox[i, j, 2] * input_size / mask_anchors[an_idx][0]) + th = np.log(gtbox[i, j, 3] * input_size / mask_anchors[an_idx][1]) + scale = 2.0 - gtbox[i, j, 2] * gtbox[i, j, 3] + loss[i] += sce(x[i, an_idx, gj, gi, 0], tx) * scale + loss[i] += sce(x[i, an_idx, gj, gi, 1], ty) * scale + loss[i] += l1loss(x[i, an_idx, gj, gi, 2], tw) * scale + loss[i] += l1loss(x[i, an_idx, gj, gi, 3], th) * scale + + objness[i, an_idx * h * w + gj * w + gi] = 1 + + for label_idx in range(class_num): + loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], + int(label_idx == gtlabel[i, j])) + + for j in range(mask_num * h * w): + if objness[i, j] >= 0: + loss[i] += sce(pred_obj[i, j], objness[i, j]) + + return loss + + class TestYolov3LossOp(OpTest): def setUp(self): self.initTestCase() @@ -171,13 +296,14 @@ class TestYolov3LossOp(OpTest): self.attrs = { "anchors": self.anchors, + "anchor_mask": self.anchor_mask, "class_num": 
self.class_num, "ignore_thresh": self.ignore_thresh, - "input_size": self.input_size, + "downsample": self.downsample, } self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel} - self.outputs = {'Loss': YoloV3Loss(x, gtbox, gtlabel, self.attrs)} + self.outputs = {'Loss': YOLOv3Loss(x, gtbox, gtlabel, self.attrs)} def test_check_output(self): place = core.CPUPlace() @@ -189,15 +315,19 @@ class TestYolov3LossOp(OpTest): place, ['X'], 'Loss', no_grad_set=set(["GTBox", "GTLabel"]), - max_relative_error=0.31) + max_relative_error=0.15) def initTestCase(self): - self.anchors = [12, 12] + self.anchors = [ + 10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, + 373, 326 + ] + self.anchor_mask = [0, 1, 2] self.class_num = 5 - self.ignore_thresh = 0.5 - self.input_size = 416 - self.x_shape = (1, len(self.anchors) // 2 * (5 + self.class_num), 3, 3) - self.gtbox_shape = (1, 5, 4) + self.ignore_thresh = 0.7 + self.downsample = 32 + self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5) + self.gtbox_shape = (3, 10, 4) if __name__ == "__main__":
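
[Editor's note] To make the arithmetic shared by the C++ kernel (get_yolo_box, CalcBoxLocationLoss, CalcObjnessLoss) and the numpy reference (YOLOv3Loss) easier to follow, here is a small self-contained numpy sketch of the per-cell computation for one anchor. The helper names only mirror those in the test, and every concrete number is invented for illustration.

import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def sce(x, label):
    # elementwise sigmoid cross-entropy, as in the test helper
    p = sigmoid(x)
    return -label * np.log(p) - (1.0 - label) * np.log(1.0 - p)


def l1loss(x, y):
    return abs(x - y)


# one grid cell (gi, gj) of a 13x13 map, one anchor, downsample = 32
h = w = 13
downsample = 32
input_size = downsample * h                       # 416
an_w, an_h = 116, 90                              # anchor size in input-image pixels
gi, gj = 6, 6
gt = np.array([0.52, 0.48, 0.30, 0.40])           # ground truth (x, y, w, h), normalized

# raw network outputs for this cell/anchor: tx, ty, tw, th, objectness
raw = np.array([0.1, -0.2, 0.3, 0.05, 1.5])

# box decoding (get_yolo_box): prediction normalized to the input image
px = (gi + sigmoid(raw[0])) / w
py = (gj + sigmoid(raw[1])) / h
pw = np.exp(raw[2]) * an_w / input_size
ph = np.exp(raw[3]) * an_h / input_size

# location targets and loss (CalcBoxLocationLoss): SCE on x/y, L1 on w/h,
# weighted by scale = 2 - gt_w * gt_h so that small boxes count more
tx, ty = gt[0] * w - gi, gt[1] * h - gj
tw, th = np.log(gt[2] * input_size / an_w), np.log(gt[3] * input_size / an_h)
scale = 2.0 - gt[2] * gt[3]
loc_loss = scale * (sce(raw[0], tx) + sce(raw[1], ty) +
                    l1loss(raw[2], tw) + l1loss(raw[3], th))

# objectness loss: target is 1 at the matched cell, 0 elsewhere, and the term is
# skipped where the best IoU with any gt exceeds ignore_thresh (objness == -1)
obj_loss = sce(raw[4], 1.0)
print(px, py, pw, ph, loc_loss, obj_loss)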