Commit 577424e5 authored by dengkaipeng

use darknet loss and trick

Parent 042fecef
......@@ -324,7 +324,7 @@ paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes',
paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None))
paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'gtscore', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'use_label_smooth', 'name'], varargs=None, keywords=None, defaults=(True, None,))
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None))
paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1))
......
......@@ -27,8 +27,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
"Input(GTBox) of Yolov3LossOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("GTLabel"),
"Input(GTLabel) of Yolov3LossOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("GTScore"),
"Input(GTScore) of Yolov3LossOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Loss"),
"Output(Loss) of Yolov3LossOp should not be null.");
PADDLE_ENFORCE(
......@@ -40,7 +38,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
auto dim_x = ctx->GetInputDim("X");
auto dim_gtbox = ctx->GetInputDim("GTBox");
auto dim_gtlabel = ctx->GetInputDim("GTLabel");
auto dim_gtscore = ctx->GetInputDim("GTScore");
auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
int anchor_num = anchors.size() / 2;
auto anchor_mask = ctx->Attrs().Get<std::vector<int>>("anchor_mask");
......@@ -63,12 +60,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
"Input(GTBox) and Input(GTLabel) dim[0] should be same");
PADDLE_ENFORCE_EQ(dim_gtlabel[1], dim_gtbox[1],
"Input(GTBox) and Input(GTLabel) dim[1] should be same");
PADDLE_ENFORCE_EQ(dim_gtscore.size(), 2,
"Input(GTScore) should be a 2-D tensor");
PADDLE_ENFORCE_EQ(dim_gtscore[0], dim_gtbox[0],
"Input(GTBox) and Input(GTScore) dim[0] should be same");
PADDLE_ENFORCE_EQ(dim_gtscore[1], dim_gtbox[1],
"Input(GTBox) and Input(GTScore) dim[1] should be same");
PADDLE_ENFORCE_GT(anchors.size(), 0,
"Attr(anchors) length should be greater then 0.");
PADDLE_ENFORCE_EQ(anchors.size() % 2, 0,
......@@ -121,11 +112,6 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
"This is a 2-D tensor with shape of [N, max_box_num], "
"and each element should be an integer to indicate the "
"box class id.");
AddInput("GTScore",
"The score of GTLabel, This is a 2-D tensor in same shape "
"GTLabel, and score values should in range (0, 1). This "
"input is for GTLabel score can be not 1.0 in image mixup "
"augmentation.");
AddOutput("Loss",
"The output yolov3 loss tensor, "
"This is a 1-D tensor with shape of [N]");
......@@ -157,8 +143,6 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<float>("ignore_thresh",
"The ignore threshold to ignore confidence loss.")
.SetDefault(0.7);
AddAttr<bool>("use_label_smooth", "bool,default True", "use label smooth")
.SetDefault(true);
AddComment(R"DOC(
This operator generates the yolov3 loss from the given prediction
results and ground truth boxes.
......@@ -245,7 +229,6 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker {
op->SetInput("X", Input("X"));
op->SetInput("GTBox", Input("GTBox"));
op->SetInput("GTLabel", Input("GTLabel"));
op->SetInput("GTScore", Input("GTScore"));
op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
op->SetInput("ObjectnessMask", Output("ObjectnessMask"));
op->SetInput("GTMatchMask", Output("GTMatchMask"));
......@@ -255,7 +238,6 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker {
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("GTBox"), {});
op->SetOutput(framework::GradVarName("GTLabel"), {});
op->SetOutput(framework::GradVarName("GTScore"), {});
return std::unique_ptr<framework::OpDesc>(op);
}
};
......
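As a quick reference for the shape checks above, here is a minimal NumPy sketch of the tensors the op consumes once GTScore is dropped; the concrete sizes are illustrative only and loosely follow the docstrings and the unit test further down in this diff.

import numpy as np

# Illustrative sizes only: batch of 3 images, up to 5 boxes each.
N, B = 3, 5
class_num = 10
anchor_mask = [0, 1, 2]
h = w = 5

x = np.zeros((N, len(anchor_mask) * (5 + class_num), h, w), dtype=np.float32)  # Input("X")
gtbox = np.zeros((N, B, 4), dtype=np.float32)   # Input("GTBox"): (x, y, w, h), normalized
gtlabel = np.zeros((N, B), dtype=np.int32)      # Input("GTLabel"): integer class ids
loss = np.zeros((N,), dtype=np.float32)         # Output("Loss"): one scalar per image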
......@@ -36,11 +36,6 @@ static T SCE(T x, T label) {
return (x > 0 ? x : 0.0) - x * label + std::log(1.0 + std::exp(-std::abs(x)));
}
template <typename T>
static T L1Loss(T x, T y) {
return std::abs(y - x);
}
template <typename T>
static T L2Loss(T x, T y) {
return 0.5 * (y - x) * (y - x);
......@@ -51,11 +46,6 @@ static T SCEGrad(T x, T label) {
return 1.0 / (1.0 + std::exp(-x)) - label;
}
template <typename T>
static T L1LossGrad(T x, T y) {
return x > y ? 1.0 : -1.0;
}
template <typename T>
static T L2LossGrad(T x, T y) {
return x - y;
......@@ -131,13 +121,13 @@ template <typename T>
static void CalcBoxLocationLoss(T* loss, const T* input, Box<T> gt,
std::vector<int> anchors, int an_idx,
int box_idx, int gi, int gj, int grid_size,
int input_size, int stride, T score) {
int input_size, int stride) {
T tx = gt.x * grid_size - gi;
T ty = gt.y * grid_size - gj;
T tw = std::log(gt.w * input_size / anchors[2 * an_idx]);
T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]);
T scale = (2.0 - gt.w * gt.h) * score;
T scale = (2.0 - gt.w * gt.h);
loss[0] += SCE<T>(input[box_idx], tx) * scale;
loss[0] += SCE<T>(input[box_idx + stride], ty) * scale;
loss[0] += L2Loss<T>(input[box_idx + 2 * stride], tw) * scale;
......@@ -148,14 +138,13 @@ template <typename T>
static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input,
Box<T> gt, std::vector<int> anchors,
int an_idx, int box_idx, int gi, int gj,
int grid_size, int input_size, int stride,
T score) {
int grid_size, int input_size, int stride) {
T tx = gt.x * grid_size - gi;
T ty = gt.y * grid_size - gj;
T tw = std::log(gt.w * input_size / anchors[2 * an_idx]);
T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]);
T scale = (2.0 - gt.w * gt.h) * score;
T scale = (2.0 - gt.w * gt.h);
input_grad[box_idx] = SCEGrad<T>(input[box_idx], tx) * scale * loss;
input_grad[box_idx + stride] =
SCEGrad<T>(input[box_idx + stride], ty) * scale * loss;
......@@ -168,11 +157,10 @@ static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input,
template <typename T>
static inline void CalcLabelLoss(T* loss, const T* input, const int index,
const int label, const int class_num,
const int stride, const T pos, const T neg,
T score) {
const int stride) {
for (int i = 0; i < class_num; i++) {
T pred = input[index + i * stride];
loss[0] += SCE<T>(pred, (i == label) ? pos : neg) * score;
loss[0] += SCE<T>(pred, (i == label) ? 1.0 : 0.0);
}
}
......@@ -180,12 +168,11 @@ template <typename T>
static inline void CalcLabelLossGrad(T* input_grad, const T loss,
const T* input, const int index,
const int label, const int class_num,
const int stride, const T pos, const T neg,
T score) {
const int stride) {
for (int i = 0; i < class_num; i++) {
T pred = input[index + i * stride];
input_grad[index + i * stride] =
SCEGrad<T>(pred, (i == label) ? pos : neg) * score * loss;
SCEGrad<T>(pred, (i == label) ? 1.0 : 0.0) * loss;
}
}
......@@ -201,7 +188,7 @@ static inline void CalcObjnessLoss(T* loss, const T* input, const T* objness,
T obj = objness[k * w + l];
if (obj > 1e-5) {
// positive sample: obj = mixup score
loss[i] += SCE<T>(input[k * w + l], 1.0) * obj;
loss[i] += SCE<T>(input[k * w + l], 1.0);
} else if (obj > -0.5) {
// negative sample: obj = 0
loss[i] += SCE<T>(input[k * w + l], 0.0);
......@@ -226,8 +213,7 @@ static inline void CalcObjnessLossGrad(T* input_grad, const T* loss,
for (int l = 0; l < w; l++) {
T obj = objness[k * w + l];
if (obj > 1e-5) {
input_grad[k * w + l] =
SCEGrad<T>(input[k * w + l], 1.0) * obj * loss[i];
input_grad[k * w + l] = SCEGrad<T>(input[k * w + l], 1.0) * loss[i];
} else if (obj > -0.5) {
input_grad[k * w + l] = SCEGrad<T>(input[k * w + l], 0.0) * loss[i];
}
......@@ -263,7 +249,6 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
auto* input = ctx.Input<Tensor>("X");
auto* gt_box = ctx.Input<Tensor>("GTBox");
auto* gt_label = ctx.Input<Tensor>("GTLabel");
auto* gt_score = ctx.Input<Tensor>("GTScore");
auto* loss = ctx.Output<Tensor>("Loss");
auto* objness_mask = ctx.Output<Tensor>("ObjectnessMask");
auto* gt_match_mask = ctx.Output<Tensor>("GTMatchMask");
......@@ -272,7 +257,6 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
int class_num = ctx.Attr<int>("class_num");
float ignore_thresh = ctx.Attr<float>("ignore_thresh");
int downsample = ctx.Attr<int>("downsample");
bool use_label_smooth = ctx.Attr<bool>("use_label_smooth");
const int n = input->dims()[0];
const int h = input->dims()[2];
......@@ -285,17 +269,9 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
T label_pos = 1.0;
T label_neg = 0.0;
if (use_label_smooth) {
label_pos = 1.0 - 1.0 / static_cast<T>(class_num);
label_neg = 1.0 / static_cast<T>(class_num);
}
const T* input_data = input->data<T>();
const T* gt_box_data = gt_box->data<T>();
const int* gt_label_data = gt_label->data<int>();
const T* gt_score_data = gt_score->data<T>();
T* loss_data = loss->mutable_data<T>({n}, ctx.GetPlace());
memset(loss_data, 0, loss->numel() * sizeof(T));
T* obj_mask_data =
......@@ -376,20 +352,19 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
int mask_idx = GetMaskIndex(anchor_mask, best_n);
gt_match_mask_data[i * b + t] = mask_idx;
if (mask_idx >= 0) {
T score = gt_score_data[i * b + t];
int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 0);
CalcBoxLocationLoss<T>(loss_data + i, input_data, gt, anchors, best_n,
box_idx, gi, gj, h, input_size, stride, score);
box_idx, gi, gj, h, input_size, stride);
int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi;
obj_mask_data[obj_idx] = score;
obj_mask_data[obj_idx] = 1.0;
int label = gt_label_data[i * b + t];
int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 5);
CalcLabelLoss<T>(loss_data + i, input_data, label_idx, label,
class_num, stride, label_pos, label_neg, score);
class_num, stride);
}
}
}
......@@ -406,7 +381,6 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
auto* input = ctx.Input<Tensor>("X");
auto* gt_box = ctx.Input<Tensor>("GTBox");
auto* gt_label = ctx.Input<Tensor>("GTLabel");
auto* gt_score = ctx.Input<Tensor>("GTScore");
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
auto* objness_mask = ctx.Input<Tensor>("ObjectnessMask");
......@@ -415,7 +389,6 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
auto anchor_mask = ctx.Attr<std::vector<int>>("anchor_mask");
int class_num = ctx.Attr<int>("class_num");
int downsample = ctx.Attr<int>("downsample");
bool use_label_smooth = ctx.Attr<bool>("use_label_smooth");
const int n = input_grad->dims()[0];
const int c = input_grad->dims()[1];
......@@ -428,17 +401,9 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
T label_pos = 1.0;
T label_neg = 0.0;
if (use_label_smooth) {
label_pos = 1.0 - 1.0 / static_cast<T>(class_num);
label_neg = 1.0 / static_cast<T>(class_num);
}
const T* input_data = input->data<T>();
const T* gt_box_data = gt_box->data<T>();
const int* gt_label_data = gt_label->data<int>();
const T* gt_score_data = gt_score->data<T>();
const T* loss_grad_data = loss_grad->data<T>();
const T* obj_mask_data = objness_mask->data<T>();
const int* gt_match_mask_data = gt_match_mask->data<int>();
......@@ -450,24 +415,21 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
for (int t = 0; t < b; t++) {
int mask_idx = gt_match_mask_data[i * b + t];
if (mask_idx >= 0) {
T score = gt_score_data[i * b + t];
Box<T> gt = GetGtBox(gt_box_data, i, b, t);
int gi = static_cast<int>(gt.x * w);
int gj = static_cast<int>(gt.y * h);
int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 0);
CalcBoxLocationLossGrad<T>(input_grad_data, loss_grad_data[i],
input_data, gt, anchors,
anchor_mask[mask_idx], box_idx, gi, gj, h,
input_size, stride, score);
CalcBoxLocationLossGrad<T>(
input_grad_data, loss_grad_data[i], input_data, gt, anchors,
anchor_mask[mask_idx], box_idx, gi, gj, h, input_size, stride);
int label = gt_label_data[i * b + t];
int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 5);
CalcLabelLossGrad<T>(input_grad_data, loss_grad_data[i], input_data,
label_idx, label, class_num, stride, label_pos,
label_neg, score);
label_idx, label, class_num, stride);
}
}
}
......
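To make the kernel math above easier to follow, here is a small NumPy sketch of the numerically stable sigmoid cross entropy (SCE) and of the darknet-style location targets computed in CalcBoxLocationLoss; argument names mirror the hunk above and are illustrative, not a public API.

import numpy as np

def sce(x, label):
    # Numerically stable sigmoid cross entropy, same form as SCE<T> above.
    return max(x, 0.0) - x * label + np.log(1.0 + np.exp(-abs(x)))

def box_location_loss(pred, gt, anchor, gi, gj, grid_size, input_size):
    # pred: raw predictions (tx, ty, tw, th) for one grid cell and anchor
    # gt:   ground-truth box (x, y, w, h), normalized to [0, 1]
    tx = gt[0] * grid_size - gi
    ty = gt[1] * grid_size - gj
    tw = np.log(gt[2] * input_size / anchor[0])
    th = np.log(gt[3] * input_size / anchor[1])
    scale = 2.0 - gt[2] * gt[3]                   # small boxes are weighted more
    loss = sce(pred[0], tx) * scale
    loss += sce(pred[1], ty) * scale
    loss += 0.5 * (tw - pred[2]) ** 2 * scale     # L2Loss<T> for width
    loss += 0.5 * (th - pred[3]) ** 2 * scale     # L2Loss<T> for height
    return loss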
......@@ -412,13 +412,11 @@ def polygon_box_transform(input, name=None):
def yolov3_loss(x,
gtbox,
gtlabel,
gtscore,
anchors,
anchor_mask,
class_num,
ignore_thresh,
downsample,
use_label_smooth=True,
name=None):
"""
${comment}
......@@ -432,14 +430,11 @@ def yolov3_loss(x,
an image.
gtlabel (Variable): class id of ground truth boxes, should be in shape
of [N, B].
gtscore (Variable): score of gtlabel, should have the same shape as
gtlabel, with score values in range (0, 1).
anchors (list|tuple): ${anchors_comment}
anchor_mask (list|tuple): ${anchor_mask_comment}
class_num (int): ${class_num_comment}
ignore_thresh (float): ${ignore_thresh_comment}
downsample (int): ${downsample_comment}
use_label_smooth(bool): ${use_label_smooth_comment}
name (string): the name of yolov3 loss
Returns:
......@@ -449,11 +444,9 @@ def yolov3_loss(x,
TypeError: Input x of yolov3_loss must be Variable
TypeError: Input gtbox of yolov3_loss must be Variable
TypeError: Input gtlabel of yolov3_loss must be Variable
TypeError: Input gtscore of yolov3_loss must be Variable
TypeError: Attr anchors of yolov3_loss must be list or tuple
TypeError: Attr class_num of yolov3_loss must be an integer
TypeError: Attr ignore_thresh of yolov3_loss must be a float number
TypeError: Attr use_label_smooth of yolov3_loss must be a bool value
Examples:
.. code-block:: python
......@@ -474,16 +467,12 @@ def yolov3_loss(x,
raise TypeError("Input gtbox of yolov3_loss must be Variable")
if not isinstance(gtlabel, Variable):
raise TypeError("Input gtlabel of yolov3_loss must be Variable")
if not isinstance(gtscore, Variable):
raise TypeError("Input gtscore of yolov3_loss must be Variable")
if not isinstance(anchors, list) and not isinstance(anchors, tuple):
raise TypeError("Attr anchors of yolov3_loss must be list or tuple")
if not isinstance(anchor_mask, list) and not isinstance(anchor_mask, tuple):
raise TypeError("Attr anchor_mask of yolov3_loss must be list or tuple")
if not isinstance(class_num, int):
raise TypeError("Attr class_num of yolov3_loss must be an integer")
if not isinstance(use_label_smooth, bool):
raise TypeError("Attr ues_label_smooth of yolov3 must be a bool value")
if not isinstance(ignore_thresh, float):
raise TypeError(
"Attr ignore_thresh of yolov3_loss must be a float number")
......@@ -503,7 +492,6 @@ def yolov3_loss(x,
"class_num": class_num,
"ignore_thresh": ignore_thresh,
"downsample": downsample,
"use_label_smooth": use_label_smooth
}
helper.append_op(
......@@ -512,7 +500,6 @@ def yolov3_loss(x,
"X": x,
"GTBox": gtbox,
"GTLabel": gtlabel,
"GTScore": gtscore
},
outputs={
'Loss': loss,
......
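A minimal usage sketch of the updated Python API follows; the concrete shapes and anchor values are illustrative (the docstring's own example is collapsed in this diff), assuming 80 classes and a 13x13 feature map.

import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32')    # 3 * (5 + 80) channels
gtbox = fluid.layers.data(name='gtbox', shape=[6, 4], dtype='float32')   # up to 6 boxes per image
gtlabel = fluid.layers.data(name='gtlabel', shape=[6], dtype='int32')
loss = fluid.layers.yolov3_loss(
    x=x,
    gtbox=gtbox,
    gtlabel=gtlabel,
    anchors=[10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
             59, 119, 116, 90, 156, 198, 373, 326],
    anchor_mask=[0, 1, 2],
    class_num=80,
    ignore_thresh=0.7,
    downsample=32)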
......@@ -23,10 +23,6 @@ from op_test import OpTest
from paddle.fluid import core
def l1loss(x, y):
return abs(x - y)
def l2loss(x, y):
return 0.5 * (y - x) * (y - x)
......@@ -70,7 +66,7 @@ def batch_xywh_box_iou(box1, box2):
return inter_area / union
def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs):
def YOLOv3Loss(x, gtbox, gtlabel, attrs):
n, c, h, w = x.shape
b = gtbox.shape[1]
anchors = attrs['anchors']
......@@ -80,14 +76,10 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs):
class_num = attrs["class_num"]
ignore_thresh = attrs['ignore_thresh']
downsample = attrs['downsample']
use_label_smooth = attrs['use_label_smooth']
input_size = downsample * h
x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
loss = np.zeros((n)).astype('float32')
label_pos = 1.0 - 1.0 / class_num if use_label_smooth else 1.0
label_neg = 1.0 / class_num if use_label_smooth else 0.0
pred_box = x[:, :, :, :, :4].copy()
grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1))
grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w))
......@@ -146,22 +138,21 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs):
ty = gtbox[i, j, 1] * w - gj
tw = np.log(gtbox[i, j, 2] * input_size / mask_anchors[an_idx][0])
th = np.log(gtbox[i, j, 3] * input_size / mask_anchors[an_idx][1])
scale = (2.0 - gtbox[i, j, 2] * gtbox[i, j, 3]) * gtscore[i, j]
scale = (2.0 - gtbox[i, j, 2] * gtbox[i, j, 3])
loss[i] += sce(x[i, an_idx, gj, gi, 0], tx) * scale
loss[i] += sce(x[i, an_idx, gj, gi, 1], ty) * scale
loss[i] += l2loss(x[i, an_idx, gj, gi, 2], tw) * scale
loss[i] += l2loss(x[i, an_idx, gj, gi, 3], th) * scale
objness[i, an_idx * h * w + gj * w + gi] = gtscore[i, j]
objness[i, an_idx * h * w + gj * w + gi] = 1.0
for label_idx in range(class_num):
loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], label_pos
if label_idx == gtlabel[i, j] else
label_neg) * gtscore[i, j]
loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx],
float(label_idx == gtlabel[i, j]))
for j in range(mask_num * h * w):
if objness[i, j] > 0:
loss[i] += sce(pred_obj[i, j], 1.0) * objness[i, j]
loss[i] += sce(pred_obj[i, j], 1.0)
elif objness[i, j] == 0:
loss[i] += sce(pred_obj[i, j], 0.0)
......@@ -176,7 +167,6 @@ class TestYolov3LossOp(OpTest):
x = logit(np.random.uniform(0, 1, self.x_shape).astype('float32'))
gtbox = np.random.random(size=self.gtbox_shape).astype('float32')
gtlabel = np.random.randint(0, self.class_num, self.gtbox_shape[:2])
gtscore = np.random.random(self.gtbox_shape[:2]).astype('float32')
gtmask = np.random.randint(0, 2, self.gtbox_shape[:2])
gtbox = gtbox * gtmask[:, :, np.newaxis]
gtlabel = gtlabel * gtmask
......@@ -187,17 +177,14 @@ class TestYolov3LossOp(OpTest):
"class_num": self.class_num,
"ignore_thresh": self.ignore_thresh,
"downsample": self.downsample,
"use_label_smooth": self.use_label_smooth,
}
self.inputs = {
'X': x,
'GTBox': gtbox.astype('float32'),
'GTLabel': gtlabel.astype('int32'),
'GTScore': gtscore.astype('float32')
}
loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, gtscore,
self.attrs)
loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, self.attrs)
self.outputs = {
'Loss': loss,
'ObjectnessMask': objness,
......@@ -213,7 +200,7 @@ class TestYolov3LossOp(OpTest):
self.check_grad_with_place(
place, ['X'],
'Loss',
no_grad_set=set(["GTBox", "GTLabel", "GTScore"]),
no_grad_set=set(["GTBox", "GTLabel"]),
max_relative_error=0.3)
def initTestCase(self):
......@@ -224,12 +211,6 @@ class TestYolov3LossOp(OpTest):
self.downsample = 32
self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5)
self.gtbox_shape = (3, 5, 4)
self.use_label_smooth = True
class TestYolov3LossWithoutLabelSmooth(TestYolov3LossOp):
def set_label_smooth(self):
self.use_label_smooth = False
if __name__ == "__main__":
......
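For readers following the reference implementation in this test, here is a short NumPy sketch of how the objectness mask drives the confidence loss; the negative sentinel for ignored anchors is inferred from the obj > -0.5 branch in the kernel above and is an assumption here.

import numpy as np

def sce(x, label):
    # Numerically stable sigmoid cross entropy, as in the kernel and test above.
    return max(x, 0.0) - x * label + np.log(1.0 + np.exp(-abs(x)))

def objness_loss(pred_obj, objness):
    # objness: 1.0 for anchors matched to a gt box (positive samples),
    # 0.0 for negatives, and a negative value (assumed -1) for anchors whose
    # best IoU with any gt box exceeds ignore_thresh; those are skipped.
    loss = 0.0
    for p, o in zip(pred_obj, objness):
        if o > 1e-5:
            loss += sce(p, 1.0)
        elif o > -0.5:
            loss += sce(p, 0.0)
    return loss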