Commit f115eb0d authored by dengkaipeng

enhance api. test=develop

Parent 95d5060d
......@@ -288,7 +288,7 @@ paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'i
paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None))
paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'anchors', 'class_num', 'ignore_thresh', 'lambda_xy', 'lambda_wh', 'lambda_conf_obj', 'lambda_conf_noobj', 'lambda_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None))
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None))
paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1))
paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
......
......@@ -25,11 +25,14 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
"Input(X) of Yolov3LossOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("GTBox"),
"Input(GTBox) of Yolov3LossOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("GTLabel"),
"Input(GTLabel) of Yolov3LossOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Loss"),
"Output(Loss) of Yolov3LossOp should not be null.");
auto dim_x = ctx->GetInputDim("X");
auto dim_gt = ctx->GetInputDim("GTBox");
auto dim_gtbox = ctx->GetInputDim("GTBox");
auto dim_gtlabel = ctx->GetInputDim("GTLabel");
auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
auto class_num = ctx->Attrs().Get<int>("class_num");
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor.");
......@@ -38,8 +41,15 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(dim_x[1], anchors.size() / 2 * (5 + class_num),
"Input(X) dim[1] should be equal to (anchor_number * (5 "
"+ class_num)).");
PADDLE_ENFORCE_EQ(dim_gt.size(), 3, "Input(GTBox) should be a 3-D tensor");
PADDLE_ENFORCE_EQ(dim_gt[2], 5, "Input(GTBox) dim[2] should be 5");
PADDLE_ENFORCE_EQ(dim_gtbox.size(), 3,
"Input(GTBox) should be a 3-D tensor");
PADDLE_ENFORCE_EQ(dim_gtbox[2], 4, "Input(GTBox) dim[2] should be 4");
PADDLE_ENFORCE_EQ(dim_gtlabel.size(), 2,
"Input(GTBox) should be a 2-D tensor");
PADDLE_ENFORCE_EQ(dim_gtlabel[0], dim_gtbox[0],
"Input(GTBox) and Input(GTLabel) dim[0] should be the same");
PADDLE_ENFORCE_EQ(dim_gtlabel[1], dim_gtbox[1],
"Input(GTBox) and Input(GTLabel) dim[1] should be the same");
PADDLE_ENFORCE_GT(anchors.size(), 0,
"Attr(anchors) length should be greater then 0.");
PADDLE_ENFORCE_EQ(anchors.size() % 2, 0,
......@@ -73,11 +83,15 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
"The input tensor of ground truth boxes, "
"This is a 3-D tensor with shape of [N, max_box_num, 5], "
"max_box_num is the max number of boxes in each image, "
"In the third dimention, stores label, x, y, w, h, "
"label is an integer to specify box class, x, y is the "
"center cordinate of boxes and w, h is the width and height"
"and x, y, w, h should be divided by input image height to "
"scale to [0, 1].");
"In the third dimention, stores x, y, w, h coordinates, "
"x, y is the center cordinate of boxes and w, h is the "
"width and height and x, y, w, h should be divided by "
"input image height to scale to [0, 1].");
AddInput("GTLabel",
"The input tensor of ground truth label, "
"This is a 2-D tensor with shape of [N, max_box_num], "
"and each element shoudl be an integer to indicate the "
"box class id.");
AddOutput("Loss",
"The output yolov3 loss tensor, "
"This is a 1-D tensor with shape of [1]");
......@@ -88,19 +102,19 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
"it will be parsed pair by pair.");
AddAttr<float>("ignore_thresh",
"The ignore threshold to ignore confidence loss.");
AddAttr<float>("lambda_xy", "The weight of x, y location loss.")
AddAttr<float>("loss_weight_xy", "The weight of x, y location loss.")
.SetDefault(1.0);
AddAttr<float>("lambda_wh", "The weight of w, h location loss.")
AddAttr<float>("loss_weight_wh", "The weight of w, h location loss.")
.SetDefault(1.0);
AddAttr<float>(
"lambda_conf_obj",
"loss_weight_conf_target",
"The weight of confidence score loss in locations with target object.")
.SetDefault(1.0);
AddAttr<float>("lambda_conf_noobj",
AddAttr<float>("loss_weight_conf_notarget",
"The weight of confidence score loss in locations without "
"target object.")
.SetDefault(1.0);
AddAttr<float>("lambda_class", "The weight of classification loss.")
AddAttr<float>("loss_weight_class", "The weight of classification loss.")
.SetDefault(1.0);
AddComment(R"DOC(
This operator generates yolov3 loss from the given prediction result and ground
......@@ -141,10 +155,10 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
The final loss is represented as follows.
$$
loss = \lambda_{xy} * loss_{xy} + \lambda_{wh} * loss_{wh}
+ \lambda_{conf_obj} * loss_{conf_obj}
+ \lambda_{conf_noobj} * loss_{conf_noobj}
+ \lambda_{class} * loss_{class}
loss = loss\_weight\_xy * loss_{xy} + loss\_weight\_wh * loss_{wh}
     + loss\_weight\_conf\_target * loss_{conf\_target}
     + loss\_weight\_conf\_notarget * loss_{conf\_notarget}
     + loss\_weight\_class * loss_{class}
$$
)DOC");
}
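A minimal Python sketch of the weighted sum in the formula above (the loss_* terms and weights are placeholder scalars, not outputs of this operator):

    def combine_yolov3_loss(loss_xy, loss_wh, loss_conf_target, loss_conf_notarget,
                            loss_class, loss_weight_xy=1.0, loss_weight_wh=1.0,
                            loss_weight_conf_target=1.0,
                            loss_weight_conf_notarget=1.0, loss_weight_class=1.0):
        # weighted sum of the five loss terms, mirroring the DOC formula
        return (loss_weight_xy * loss_xy +
                loss_weight_wh * loss_wh +
                loss_weight_conf_target * loss_conf_target +
                loss_weight_conf_notarget * loss_conf_notarget +
                loss_weight_class * loss_class)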
......@@ -182,12 +196,14 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker {
op->SetType("yolov3_loss_grad");
op->SetInput("X", Input("X"));
op->SetInput("GTBox", Input("GTBox"));
op->SetInput("GTLabel", Input("GTLabel"));
op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("GTBox"), {});
op->SetOutput(framework::GradVarName("GTLabel"), {});
return std::unique_ptr<framework::OpDesc>(op);
}
};
......
......@@ -186,15 +186,17 @@ static T CalcBoxIoU(std::vector<T> box1, std::vector<T> box2) {
}
template <typename T>
static void PreProcessGTBox(const Tensor& gt_boxes, const float ignore_thresh,
std::vector<int> anchors, const int grid_size,
Tensor* obj_mask, Tensor* noobj_mask, Tensor* tx,
Tensor* ty, Tensor* tw, Tensor* th, Tensor* tconf,
static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label,
const float ignore_thresh, std::vector<int> anchors,
const int grid_size, Tensor* obj_mask,
Tensor* noobj_mask, Tensor* tx, Tensor* ty,
Tensor* tw, Tensor* th, Tensor* tconf,
Tensor* tclass) {
const int n = gt_boxes.dims()[0];
const int b = gt_boxes.dims()[1];
const int n = gt_box.dims()[0];
const int b = gt_box.dims()[1];
const int anchor_num = anchors.size() / 2;
auto gt_boxes_t = EigenTensor<T, 3>::From(gt_boxes);
auto gt_box_t = EigenTensor<T, 3>::From(gt_box);
auto gt_label_t = EigenTensor<int, 2>::From(gt_label);
auto obj_mask_t = EigenTensor<int, 4>::From(*obj_mask).setConstant(0);
auto noobj_mask_t = EigenTensor<int, 4>::From(*noobj_mask).setConstant(1);
auto tx_t = EigenTensor<T, 4>::From(*tx).setConstant(0.0);
......@@ -206,28 +208,27 @@ static void PreProcessGTBox(const Tensor& gt_boxes, const float ignore_thresh,
for (int i = 0; i < n; i++) {
for (int j = 0; j < b; j++) {
if (isZero<T>(gt_boxes_t(i, j, 0)) && isZero<T>(gt_boxes_t(i, j, 1)) &&
isZero<T>(gt_boxes_t(i, j, 2)) && isZero<T>(gt_boxes_t(i, j, 3)) &&
isZero<T>(gt_boxes_t(i, j, 4))) {
if (isZero<T>(gt_box_t(i, j, 0)) && isZero<T>(gt_box_t(i, j, 1)) &&
isZero<T>(gt_box_t(i, j, 2)) && isZero<T>(gt_box_t(i, j, 3))) {
continue;
}
int gt_label = static_cast<int>(gt_boxes_t(i, j, 0));
T gx = gt_boxes_t(i, j, 1) * grid_size;
T gy = gt_boxes_t(i, j, 2) * grid_size;
T gw = gt_boxes_t(i, j, 3) * grid_size;
T gh = gt_boxes_t(i, j, 4) * grid_size;
int cur_label = gt_label_t(i, j);
T gx = gt_box_t(i, j, 0) * grid_size;
T gy = gt_box_t(i, j, 1) * grid_size;
T gw = gt_box_t(i, j, 2) * grid_size;
T gh = gt_box_t(i, j, 3) * grid_size;
int gi = static_cast<int>(gx);
int gj = static_cast<int>(gy);
T max_iou = static_cast<T>(0);
T iou;
int best_an_index = -1;
std::vector<T> gt_box({0, 0, gw, gh});
std::vector<T> gt_box_shape({0, 0, gw, gh});
for (int an_idx = 0; an_idx < anchor_num; an_idx++) {
std::vector<T> anchor_shape({0, 0, static_cast<T>(anchors[2 * an_idx]),
static_cast<T>(anchors[2 * an_idx + 1])});
iou = CalcBoxIoU<T>(gt_box, anchor_shape);
iou = CalcBoxIoU<T>(gt_box_shape, anchor_shape);
if (iou > max_iou) {
max_iou = iou;
best_an_index = an_idx;
......@@ -242,7 +243,7 @@ static void PreProcessGTBox(const Tensor& gt_boxes, const float ignore_thresh,
ty_t(i, best_an_index, gj, gi) = gy - gj;
tw_t(i, best_an_index, gj, gi) = log(gw / anchors[2 * best_an_index]);
th_t(i, best_an_index, gj, gi) = log(gh / anchors[2 * best_an_index + 1]);
tclass_t(i, best_an_index, gj, gi, gt_label) = 1;
tclass_t(i, best_an_index, gj, gi, cur_label) = 1;
tconf_t(i, best_an_index, gj, gi) = 1;
}
}
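For reference, the anchor matching performed above reduces to the following pure-Python sketch (a simplified restatement, assuming gw/gh are already scaled to grid units; it is not the operator's exact CalcBoxIoU code):

    def best_anchor_index(gw, gh, anchors):
        # compare the (0, 0, gw, gh) ground truth shape with each (0, 0, aw, ah)
        # anchor shape by IoU and return the index of the best matching anchor
        best_idx, best_iou = -1, 0.0
        for idx in range(len(anchors) // 2):
            aw, ah = anchors[2 * idx], anchors[2 * idx + 1]
            inter = min(gw, aw) * min(gh, ah)          # boxes share the same center
            iou = inter / (gw * gh + aw * ah - inter)  # intersection over union
            if iou > best_iou:
                best_iou, best_idx = iou, idx
        return best_idx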
......@@ -267,10 +268,10 @@ static void AddAllGradToInputGrad(
Tensor* grad, T loss, const Tensor& pred_x, const Tensor& pred_y,
const Tensor& pred_conf, const Tensor& pred_class, const Tensor& grad_x,
const Tensor& grad_y, const Tensor& grad_w, const Tensor& grad_h,
const Tensor& grad_conf_obj, const Tensor& grad_conf_noobj,
const Tensor& grad_class, const int class_num, const float lambda_xy,
const float lambda_wh, const float lambda_conf_obj,
const float lambda_conf_noobj, const float lambda_class) {
const Tensor& grad_conf_target, const Tensor& grad_conf_notarget,
const Tensor& grad_class, const int class_num, const float loss_weight_xy,
const float loss_weight_wh, const float loss_weight_conf_target,
const float loss_weight_conf_notarget, const float loss_weight_class) {
const int n = pred_x.dims()[0];
const int an_num = pred_x.dims()[1];
const int h = pred_x.dims()[2];
......@@ -285,8 +286,8 @@ static void AddAllGradToInputGrad(
auto grad_y_t = EigenTensor<T, 4>::From(grad_y);
auto grad_w_t = EigenTensor<T, 4>::From(grad_w);
auto grad_h_t = EigenTensor<T, 4>::From(grad_h);
auto grad_conf_obj_t = EigenTensor<T, 4>::From(grad_conf_obj);
auto grad_conf_noobj_t = EigenTensor<T, 4>::From(grad_conf_noobj);
auto grad_conf_target_t = EigenTensor<T, 4>::From(grad_conf_target);
auto grad_conf_notarget_t = EigenTensor<T, 4>::From(grad_conf_notarget);
auto grad_class_t = EigenTensor<T, 5>::From(grad_class);
for (int i = 0; i < n; i++) {
......@@ -295,25 +296,26 @@ static void AddAllGradToInputGrad(
for (int l = 0; l < w; l++) {
grad_t(i, j * attr_num, k, l) =
grad_x_t(i, j, k, l) * pred_x_t(i, j, k, l) *
(1.0 - pred_x_t(i, j, k, l)) * loss * lambda_xy;
(1.0 - pred_x_t(i, j, k, l)) * loss * loss_weight_xy;
grad_t(i, j * attr_num + 1, k, l) =
grad_y_t(i, j, k, l) * pred_y_t(i, j, k, l) *
(1.0 - pred_y_t(i, j, k, l)) * loss * lambda_xy;
(1.0 - pred_y_t(i, j, k, l)) * loss * loss_weight_xy;
grad_t(i, j * attr_num + 2, k, l) =
grad_w_t(i, j, k, l) * loss * lambda_wh;
grad_w_t(i, j, k, l) * loss * loss_weight_wh;
grad_t(i, j * attr_num + 3, k, l) =
grad_h_t(i, j, k, l) * loss * lambda_wh;
grad_h_t(i, j, k, l) * loss * loss_weight_wh;
grad_t(i, j * attr_num + 4, k, l) =
grad_conf_obj_t(i, j, k, l) * pred_conf_t(i, j, k, l) *
(1.0 - pred_conf_t(i, j, k, l)) * loss * lambda_conf_obj;
grad_conf_target_t(i, j, k, l) * pred_conf_t(i, j, k, l) *
(1.0 - pred_conf_t(i, j, k, l)) * loss * loss_weight_conf_target;
grad_t(i, j * attr_num + 4, k, l) +=
grad_conf_noobj_t(i, j, k, l) * pred_conf_t(i, j, k, l) *
(1.0 - pred_conf_t(i, j, k, l)) * loss * lambda_conf_noobj;
grad_conf_notarget_t(i, j, k, l) * pred_conf_t(i, j, k, l) *
(1.0 - pred_conf_t(i, j, k, l)) * loss *
loss_weight_conf_notarget;
for (int c = 0; c < class_num; c++) {
grad_t(i, j * attr_num + 5 + c, k, l) =
grad_class_t(i, j, k, l, c) * pred_class_t(i, j, k, l, c) *
(1.0 - pred_class_t(i, j, k, l, c)) * loss * lambda_class;
(1.0 - pred_class_t(i, j, k, l, c)) * loss * loss_weight_class;
}
}
}
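The x/y, confidence and class predictions all pass through a sigmoid, so each gradient term above carries a pred * (1 - pred) factor; a hedged one-element sketch of that pattern (names are illustrative, not part of the operator's API):

    def sigmoid_output_grad(upstream_grad, pred, loss, loss_weight):
        # chain rule through sigmoid: d(sigmoid(z))/dz = pred * (1 - pred),
        # scaled by the scalar loss gradient and the per-term loss weight,
        # as AddAllGradToInputGrad does element by element
        return upstream_grad * pred * (1.0 - pred) * loss * loss_weight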
......@@ -326,16 +328,18 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* gt_boxes = ctx.Input<Tensor>("GTBox");
auto* gt_box = ctx.Input<Tensor>("GTBox");
auto* gt_label = ctx.Input<Tensor>("GTLabel");
auto* loss = ctx.Output<Tensor>("Loss");
auto anchors = ctx.Attr<std::vector<int>>("anchors");
int class_num = ctx.Attr<int>("class_num");
float ignore_thresh = ctx.Attr<float>("ignore_thresh");
float lambda_xy = ctx.Attr<float>("lambda_xy");
float lambda_wh = ctx.Attr<float>("lambda_wh");
float lambda_conf_obj = ctx.Attr<float>("lambda_conf_obj");
float lambda_conf_noobj = ctx.Attr<float>("lambda_conf_noobj");
float lambda_class = ctx.Attr<float>("lambda_class");
float loss_weight_xy = ctx.Attr<float>("loss_weight_xy");
float loss_weight_wh = ctx.Attr<float>("loss_weight_wh");
float loss_weight_conf_target = ctx.Attr<float>("loss_weight_conf_target");
float loss_weight_conf_notarget =
ctx.Attr<float>("loss_weight_conf_notarget");
float loss_weight_class = ctx.Attr<float>("loss_weight_class");
const int n = input->dims()[0];
const int h = input->dims()[2];
......@@ -363,7 +367,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
th.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
PreProcessGTBox<T>(*gt_boxes, ignore_thresh, anchors, h, &obj_mask,
PreProcessGTBox<T>(*gt_box, *gt_label, ignore_thresh, anchors, h, &obj_mask,
&noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass);
Tensor obj_mask_expand;
......@@ -375,15 +379,16 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
T loss_y = CalcMSEWithMask<T>(pred_y, ty, obj_mask);
T loss_w = CalcMSEWithMask<T>(pred_w, tw, obj_mask);
T loss_h = CalcMSEWithMask<T>(pred_h, th, obj_mask);
T loss_conf_obj = CalcBCEWithMask<T>(pred_conf, tconf, obj_mask);
T loss_conf_noobj = CalcBCEWithMask<T>(pred_conf, tconf, noobj_mask);
T loss_conf_target = CalcBCEWithMask<T>(pred_conf, tconf, obj_mask);
T loss_conf_notarget = CalcBCEWithMask<T>(pred_conf, tconf, noobj_mask);
T loss_class = CalcBCEWithMask<T>(pred_class, tclass, obj_mask_expand);
auto* loss_data = loss->mutable_data<T>({1}, ctx.GetPlace());
loss_data[0] =
lambda_xy * (loss_x + loss_y) + lambda_wh * (loss_w + loss_h) +
lambda_conf_obj * loss_conf_obj + lambda_conf_noobj * loss_conf_noobj +
lambda_class * loss_class;
loss_data[0] = loss_weight_xy * (loss_x + loss_y) +
loss_weight_wh * (loss_w + loss_h) +
loss_weight_conf_target * loss_conf_target +
loss_weight_conf_notarget * loss_conf_notarget +
loss_weight_class * loss_class;
}
};
......@@ -392,18 +397,20 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* gt_boxes = ctx.Input<Tensor>("GTBox");
auto* gt_box = ctx.Input<Tensor>("GTBox");
auto* gt_label = ctx.Input<Tensor>("GTLabel");
auto anchors = ctx.Attr<std::vector<int>>("anchors");
int class_num = ctx.Attr<int>("class_num");
float ignore_thresh = ctx.Attr<float>("ignore_thresh");
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
const T loss = output_grad->data<T>()[0];
float lambda_xy = ctx.Attr<float>("lambda_xy");
float lambda_wh = ctx.Attr<float>("lambda_wh");
float lambda_conf_obj = ctx.Attr<float>("lambda_conf_obj");
float lambda_conf_noobj = ctx.Attr<float>("lambda_conf_noobj");
float lambda_class = ctx.Attr<float>("lambda_class");
float loss_weight_xy = ctx.Attr<float>("loss_weight_xy");
float loss_weight_wh = ctx.Attr<float>("loss_weight_wh");
float loss_weight_conf_target = ctx.Attr<float>("loss_weight_conf_target");
float loss_weight_conf_notarget =
ctx.Attr<float>("loss_weight_conf_notarget");
float loss_weight_class = ctx.Attr<float>("loss_weight_class");
const int n = input->dims()[0];
const int c = input->dims()[1];
......@@ -432,7 +439,7 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
th.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
PreProcessGTBox<T>(*gt_boxes, ignore_thresh, anchors, h, &obj_mask,
PreProcessGTBox<T>(*gt_box, *gt_label, ignore_thresh, anchors, h, &obj_mask,
&noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass);
Tensor obj_mask_expand;
......@@ -441,13 +448,13 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
ExpandObjMaskByClassNum(&obj_mask_expand, obj_mask);
Tensor grad_x, grad_y, grad_w, grad_h;
Tensor grad_conf_obj, grad_conf_noobj, grad_class;
Tensor grad_conf_target, grad_conf_notarget, grad_class;
grad_x.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
grad_y.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
grad_w.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
grad_h.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
grad_conf_obj.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
grad_conf_noobj.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
grad_conf_target.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
grad_conf_notarget.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
grad_class.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
T obj_mf = CalcMaskPointNum<int>(obj_mask);
T noobj_mf = CalcMaskPointNum<int>(noobj_mask);
......@@ -456,8 +463,9 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
CalcMSEGradWithMask<T>(&grad_y, pred_y, ty, obj_mask, obj_mf);
CalcMSEGradWithMask<T>(&grad_w, pred_w, tw, obj_mask, obj_mf);
CalcMSEGradWithMask<T>(&grad_h, pred_h, th, obj_mask, obj_mf);
CalcBCEGradWithMask<T>(&grad_conf_obj, pred_conf, tconf, obj_mask, obj_mf);
CalcBCEGradWithMask<T>(&grad_conf_noobj, pred_conf, tconf, noobj_mask,
CalcBCEGradWithMask<T>(&grad_conf_target, pred_conf, tconf, obj_mask,
obj_mf);
CalcBCEGradWithMask<T>(&grad_conf_notarget, pred_conf, tconf, noobj_mask,
noobj_mf);
CalcBCEGradWithMask<T>(&grad_class, pred_class, tclass, obj_mask_expand,
obj_expand_mf);
......@@ -465,8 +473,9 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
input_grad->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
AddAllGradToInputGrad<T>(
input_grad, loss, pred_x, pred_y, pred_conf, pred_class, grad_x, grad_y,
grad_w, grad_h, grad_conf_obj, grad_conf_noobj, grad_class, class_num,
lambda_xy, lambda_wh, lambda_conf_obj, lambda_conf_noobj, lambda_class);
grad_w, grad_h, grad_conf_target, grad_conf_notarget, grad_class,
class_num, loss_weight_xy, loss_weight_wh, loss_weight_conf_target,
loss_weight_conf_notarget, loss_weight_class);
}
};
......
......@@ -409,32 +409,36 @@ def polygon_box_transform(input, name=None):
@templatedoc(op_type="yolov3_loss")
def yolov3_loss(x,
gtbox,
gtlabel,
anchors,
class_num,
ignore_thresh,
lambda_xy=None,
lambda_wh=None,
lambda_conf_obj=None,
lambda_conf_noobj=None,
lambda_class=None,
loss_weight_xy=None,
loss_weight_wh=None,
loss_weight_conf_target=None,
loss_weight_conf_notarget=None,
loss_weight_class=None,
name=None):
"""
${comment}
Args:
x (Variable): ${x_comment}
gtbox (Variable): groud truth boxes, shoulb be in shape of [N, B, 5],
in the third dimenstion, class_id, x, y, w, h should
be stored and x, y, w, h should be relative valud of
input image.
gtbox (Variable): ground truth boxes, should be in shape of [N, B, 4],
in the third dimension, x, y, w, h should be stored
and x, y, w, h should be relative values of the input image.
N is the batch number and B is the max box number in
an image.
gtlabel (Variable): class id of ground truth boxes, should be in shape
of [N, B].
anchors (list|tuple): ${anchors_comment}
class_num (int): ${class_num_comment}
ignore_thresh (float): ${ignore_thresh_comment}
lambda_xy (float|None): ${lambda_xy_comment}
lambda_wh (float|None): ${lambda_wh_comment}
lambda_conf_obj (float|None): ${lambda_conf_obj_comment}
lambda_conf_noobj (float|None): ${lambda_conf_noobj_comment}
lambda_class (float|None): ${lambda_class_comment}
loss_weight_xy (float|None): ${loss_weight_xy_comment}
loss_weight_wh (float|None): ${loss_weight_wh_comment}
loss_weight_conf_target (float|None): ${loss_weight_conf_target_comment}
loss_weight_conf_notarget (float|None): ${loss_weight_conf_notarget_comment}
loss_weight_class (float|None): ${loss_weight_class_comment}
name (string): the name of yolov3 loss
Returns:
......@@ -443,6 +447,7 @@ def yolov3_loss(x,
Raises:
TypeError: Input x of yolov3_loss must be Variable
TypeError: Input gtbox of yolov3_loss must be Variable
TypeError: Input gtlabel of yolov3_loss must be Variable
TypeError: Attr anchors of yolov3_loss must be list or tuple
TypeError: Attr class_num of yolov3_loss must be an integer
TypeError: Attr ignore_thresh of yolov3_loss must be a float number
......@@ -450,8 +455,9 @@ def yolov3_loss(x,
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[10, 255, 13, 13], dtype='float32')
gtbox = fluid.layers.data(name='gtbox', shape=[10, 6, 5], dtype='float32')
x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32')
gtbox = fluid.layers.data(name='gtbox', shape=[6, 4], dtype='float32')
gtlabel = fluid.layers.data(name='gtlabel', shape=[6, 1], dtype='int32')
anchors = [10, 13, 16, 30, 33, 23]
loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, gtlabel=gtlabel, class_num=80,
anchors=anchors, ignore_thresh=0.5)
......@@ -462,6 +468,8 @@ def yolov3_loss(x,
raise TypeError("Input x of yolov3_loss must be Variable")
if not isinstance(gtbox, Variable):
raise TypeError("Input gtbox of yolov3_loss must be Variable")
if not isinstance(gtlabel, Variable):
raise TypeError("Input gtlabel of yolov3_loss must be Variable")
if not isinstance(anchors, list) and not isinstance(anchors, tuple):
raise TypeError("Attr anchors of yolov3_loss must be list or tuple")
if not isinstance(class_num, int):
......@@ -482,21 +490,24 @@ def yolov3_loss(x,
"ignore_thresh": ignore_thresh,
}
if lambda_xy is not None and isinstance(lambda_xy, float):
self.attrs['lambda_xy'] = lambda_xy
if lambda_wh is not None and isinstance(lambda_wh, float):
self.attrs['lambda_wh'] = lambda_wh
if lambda_conf_obj is not None and isinstance(lambda_conf_obj, float):
self.attrs['lambda_conf_obj'] = lambda_conf_obj
if lambda_conf_noobj is not None and isinstance(lambda_conf_noobj, float):
self.attrs['lambda_conf_noobj'] = lambda_conf_noobj
if lambda_class is not None and isinstance(lambda_class, float):
self.attrs['lambda_class'] = lambda_class
if loss_weight_xy is not None and isinstance(loss_weight_xy, float):
attrs['loss_weight_xy'] = loss_weight_xy
if loss_weight_wh is not None and isinstance(loss_weight_wh, float):
attrs['loss_weight_wh'] = loss_weight_wh
if loss_weight_conf_target is not None and isinstance(
loss_weight_conf_target, float):
attrs['loss_weight_conf_target'] = loss_weight_conf_target
if loss_weight_conf_notarget is not None and isinstance(
loss_weight_conf_notarget, float):
attrs['loss_weight_conf_notarget'] = loss_weight_conf_notarget
if loss_weight_class is not None and isinstance(loss_weight_class, float):
attrs['loss_weight_class'] = loss_weight_class
helper.append_op(
type='yolov3_loss',
inputs={'X': x,
"GTBox": gtbox},
inputs={"X": x,
"GTBox": gtbox,
"GTLabel": gtlabel},
outputs={'Loss': loss},
attrs=attrs)
return loss
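A hedged end-to-end usage sketch of the updated layer (anchor values, shapes and the chosen loss weights are illustrative only; every loss_weight_* argument defaults to 1.0 when omitted):

    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32')   # 3 anchors * (5 + 80) channels
    gtbox = fluid.layers.data(name='gtbox', shape=[6, 4], dtype='float32')  # [B, 4] boxes per image
    gtlabel = fluid.layers.data(name='gtlabel', shape=[6], dtype='int32')   # [B] class ids per image
    loss = fluid.layers.yolov3_loss(
        x=x,
        gtbox=gtbox,
        gtlabel=gtlabel,
        anchors=[10, 13, 16, 30, 33, 23],
        class_num=80,
        ignore_thresh=0.5,
        loss_weight_xy=2.0,               # optional per-term weights
        loss_weight_conf_notarget=0.5)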
......
......@@ -366,5 +366,18 @@ class TestGenerateProposals(unittest.TestCase):
print(rpn_rois.shape)
class TestYoloDetection(unittest.TestCase):
def test_yolov3_loss(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[30, 7, 7], dtype='float32')
gtbox = layers.data(name='gtbox', shape=[10, 4], dtype='float32')
gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32')
loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13], 10,
0.5)
self.assertIsNotNone(loss)
if __name__ == '__main__':
unittest.main()
......@@ -911,15 +911,6 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(data_1)
print(str(program))
def test_yolov3_loss(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[30, 7, 7], dtype='float32')
gtbox = layers.data(name='gtbox', shape=[10, 5], dtype='float32')
loss = layers.yolov3_loss(x, gtbox, [10, 13, 30, 13], 10, 0.5)
self.assertIsNotNone(loss)
def test_bilinear_tensor_product_layer(self):
program = Program()
with program_guard(program):
......
......@@ -66,7 +66,7 @@ def box_iou(box1, box2):
return inter_area / (b1_area + b2_area - inter_area)
def build_target(gtboxs, attrs, grid_size):
def build_target(gtboxs, gtlabel, attrs, grid_size):
n, b, _ = gtboxs.shape
ignore_thresh = attrs["ignore_thresh"]
anchors = attrs["anchors"]
......@@ -87,11 +87,11 @@ def build_target(gtboxs, attrs, grid_size):
if gtboxs[i, j, :].sum() == 0:
continue
gt_label = int(gtboxs[i, j, 0])
gx = gtboxs[i, j, 1] * grid_size
gy = gtboxs[i, j, 2] * grid_size
gw = gtboxs[i, j, 3] * grid_size
gh = gtboxs[i, j, 4] * grid_size
gt_label = gtlabel[i, j]
gx = gtboxs[i, j, 0] * grid_size
gy = gtboxs[i, j, 1] * grid_size
gw = gtboxs[i, j, 2] * grid_size
gh = gtboxs[i, j, 3] * grid_size
gi = int(gx)
gj = int(gy)
......@@ -121,7 +121,7 @@ def build_target(gtboxs, attrs, grid_size):
return (tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask)
def YoloV3Loss(x, gtbox, attrs):
def YoloV3Loss(x, gtbox, gtlabel, attrs):
n, c, h, w = x.shape
an_num = len(attrs['anchors']) // 2
class_num = attrs["class_num"]
......@@ -134,7 +134,7 @@ def YoloV3Loss(x, gtbox, attrs):
pred_cls = sigmoid(x[:, :, :, :, 5:])
tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask = build_target(
gtbox, attrs, x.shape[2])
gtbox, gtlabel, attrs, x.shape[2])
obj_mask_expand = np.tile(
np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num'])))
......@@ -142,73 +142,73 @@ def YoloV3Loss(x, gtbox, attrs):
loss_y = mse(pred_y * obj_mask, ty * obj_mask, obj_mask.sum())
loss_w = mse(pred_w * obj_mask, tw * obj_mask, obj_mask.sum())
loss_h = mse(pred_h * obj_mask, th * obj_mask, obj_mask.sum())
loss_conf_obj = bce(pred_conf * obj_mask, tconf * obj_mask, obj_mask)
loss_conf_noobj = bce(pred_conf * noobj_mask, tconf * noobj_mask,
noobj_mask)
loss_conf_target = bce(pred_conf * obj_mask, tconf * obj_mask, obj_mask)
loss_conf_notarget = bce(pred_conf * noobj_mask, tconf * noobj_mask,
noobj_mask)
loss_class = bce(pred_cls * obj_mask_expand, tcls * obj_mask_expand,
obj_mask_expand)
return attrs['lambda_xy'] * (loss_x + loss_y) \
+ attrs['lambda_wh'] * (loss_w + loss_h) \
+ attrs['lambda_conf_obj'] * loss_conf_obj \
+ attrs['lambda_conf_noobj'] * loss_conf_noobj \
+ attrs['lambda_class'] * loss_class
return attrs['loss_weight_xy'] * (loss_x + loss_y) \
+ attrs['loss_weight_wh'] * (loss_w + loss_h) \
+ attrs['loss_weight_conf_target'] * loss_conf_target \
+ attrs['loss_weight_conf_notarget'] * loss_conf_notarget \
+ attrs['loss_weight_class'] * loss_class
class TestYolov3LossOp(OpTest):
def setUp(self):
self.lambda_xy = 1.0
self.lambda_wh = 1.0
self.lambda_conf_obj = 1.0
self.lambda_conf_noobj = 1.0
self.lambda_class = 1.0
self.loss_weight_xy = 1.0
self.loss_weight_wh = 1.0
self.loss_weight_conf_target = 1.0
self.loss_weight_conf_notarget = 1.0
self.loss_weight_class = 1.0
self.initTestCase()
self.op_type = 'yolov3_loss'
x = np.random.random(size=self.x_shape).astype('float32')
gtbox = np.random.random(size=self.gtbox_shape).astype('float32')
gtbox[:, :, 0] = np.random.randint(0, self.class_num,
self.gtbox_shape[:2])
gtlabel = np.random.randint(0, self.class_num,
self.gtbox_shape[:2]).astype('int32')
self.attrs = {
"anchors": self.anchors,
"class_num": self.class_num,
"ignore_thresh": self.ignore_thresh,
"lambda_xy": self.lambda_xy,
"lambda_wh": self.lambda_wh,
"lambda_conf_obj": self.lambda_conf_obj,
"lambda_conf_noobj": self.lambda_conf_noobj,
"lambda_class": self.lambda_class,
"loss_weight_xy": self.loss_weight_xy,
"loss_weight_wh": self.loss_weight_wh,
"loss_weight_conf_target": self.loss_weight_conf_target,
"loss_weight_conf_notarget": self.loss_weight_conf_notarget,
"loss_weight_class": self.loss_weight_class,
}
self.inputs = {'X': x, 'GTBox': gtbox}
self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel}
self.outputs = {
'Loss':
np.array([YoloV3Loss(x, gtbox, self.attrs)]).astype('float32')
'Loss': np.array(
[YoloV3Loss(x, gtbox, gtlabel, self.attrs)]).astype('float32')
}
def test_check_output(self):
place = core.CPUPlace()
self.check_output_with_place(place, atol=1e-3)
# def test_check_grad_ignore_gtbox(self):
# place = core.CPUPlace()
# self.check_grad_with_place(
# place, ['X'],
# 'Loss',
# no_grad_set=set("GTBox"),
# max_relative_error=0.06)
def test_check_grad_ignore_gtbox(self):
place = core.CPUPlace()
self.check_grad_with_place(
place, ['X'],
'Loss',
no_grad_set=set("GTBox"),
max_relative_error=0.06)
def initTestCase(self):
self.anchors = [10, 13, 12, 12]
self.class_num = 10
self.ignore_thresh = 0.5
self.x_shape = (5, len(self.anchors) // 2 * (5 + self.class_num), 7, 7)
self.gtbox_shape = (5, 5, 5)
self.lambda_xy = 2.5
self.lambda_wh = 0.8
self.lambda_conf_obj = 1.5
self.lambda_conf_noobj = 0.5
self.lambda_class = 1.2
self.gtbox_shape = (5, 10, 4)
self.loss_weight_xy = 2.5
self.loss_weight_wh = 0.8
self.loss_weight_conf_target = 1.5
self.loss_weight_conf_notarget = 0.5
self.loss_weight_class = 1.2
if __name__ == "__main__":
......