Commit 6c5a5d07 authored by D dengkaipeng

format code. test=develop

Parent e7e4f084
......@@ -324,7 +324,7 @@ paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes',
paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None))
paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'input_size', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None))
paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1))
......
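For reference, a minimal sketch of calling the updated Python API with the new anchor_mask and downsample arguments (anchor values, tensor shapes, and class count below are illustrative assumptions, not taken from this commit):

import paddle.fluid as fluid

# x channels = len(anchor_mask) * (5 + class_num) = 3 * 85 = 255 (illustrative)
x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32')
gtbox = fluid.layers.data(name='gtbox', shape=[6, 4], dtype='float32')
gtlabel = fluid.layers.data(name='gtlabel', shape=[6], dtype='int32')
loss = fluid.layers.yolov3_loss(
    x=x,
    gtbox=gtbox,
    gtlabel=gtlabel,
    anchors=[10, 13, 16, 30, 33, 23],  # example anchor widths/heights
    anchor_mask=[0, 1, 2],             # which anchors this detection head uses
    class_num=80,
    ignore_thresh=0.7,
    downsample=32)                     # downsample ratio of the input feature map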
......@@ -26,110 +26,9 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
using Array5 = Eigen::DSizes<int64_t, 5>;
template <typename T>
static inline bool isZero(T x) {
return fabs(x) < 1e-6;
}
template <typename T>
static T CalcBoxIoU(std::vector<T> box1, std::vector<T> box2) {
T b1_x1 = box1[0] - box1[2] / 2;
T b1_x2 = box1[0] + box1[2] / 2;
T b1_y1 = box1[1] - box1[3] / 2;
T b1_y2 = box1[1] + box1[3] / 2;
T b2_x1 = box2[0] - box2[2] / 2;
T b2_x2 = box2[0] + box2[2] / 2;
T b2_y1 = box2[1] - box2[3] / 2;
T b2_y2 = box2[1] + box2[3] / 2;
T b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1);
T b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1);
T inter_rect_x1 = std::max(b1_x1, b2_x1);
T inter_rect_y1 = std::max(b1_y1, b2_y1);
T inter_rect_x2 = std::min(b1_x2, b2_x2);
T inter_rect_y2 = std::min(b1_y2, b2_y2);
T inter_area = std::max(inter_rect_x2 - inter_rect_x1, static_cast<T>(0.0)) *
std::max(inter_rect_y2 - inter_rect_y1, static_cast<T>(0.0));
return inter_area / (b1_area + b2_area - inter_area);
}
template <typename T>
static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label,
const float ignore_thresh, std::vector<int> anchors,
const int input_size, const int grid_size,
Tensor* conf_mask, Tensor* obj_mask, Tensor* tx,
Tensor* ty, Tensor* tw, Tensor* th, Tensor* tweight,
Tensor* tconf, Tensor* tclass) {
const int n = gt_box.dims()[0];
const int b = gt_box.dims()[1];
const int an_num = anchors.size() / 2;
const int h = tclass->dims()[2];
const int w = tclass->dims()[3];
const int class_num = tclass->dims()[4];
const T* gt_box_data = gt_box.data<T>();
const int* gt_label_data = gt_label.data<int>();
T* conf_mask_data = conf_mask->data<T>();
T* obj_mask_data = obj_mask->data<T>();
T* tx_data = tx->data<T>();
T* ty_data = ty->data<T>();
T* tw_data = tw->data<T>();
T* th_data = th->data<T>();
T* tweight_data = tweight->data<T>();
T* tconf_data = tconf->data<T>();
T* tclass_data = tclass->data<T>();
for (int i = 0; i < n; i++) {
for (int j = 0; j < b; j++) {
int box_idx = (i * b + j) * 4;
if (isZero<T>(gt_box_data[box_idx + 2]) &&
isZero<T>(gt_box_data[box_idx + 3])) {
continue;
}
int cur_label = gt_label_data[i * b + j];
T gx = gt_box_data[box_idx] * grid_size;
T gy = gt_box_data[box_idx + 1] * grid_size;
T gw = gt_box_data[box_idx + 2] * input_size;
T gh = gt_box_data[box_idx + 3] * input_size;
int gi = static_cast<int>(gx);
int gj = static_cast<int>(gy);
T max_iou = static_cast<T>(0);
T iou;
int best_an_index = -1;
std::vector<T> gt_box_shape({0, 0, gw, gh});
for (int an_idx = 0; an_idx < an_num; an_idx++) {
std::vector<T> anchor_shape({0, 0, static_cast<T>(anchors[2 * an_idx]),
static_cast<T>(anchors[2 * an_idx + 1])});
iou = CalcBoxIoU<T>(gt_box_shape, anchor_shape);
if (iou > max_iou) {
max_iou = iou;
best_an_index = an_idx;
}
if (iou > ignore_thresh) {
int conf_idx = ((i * an_num + an_idx) * h + gj) * w + gi;
conf_mask_data[conf_idx] = static_cast<T>(0.0);
}
}
int obj_idx = ((i * an_num + best_an_index) * h + gj) * w + gi;
conf_mask_data[obj_idx] = static_cast<T>(1.0);
obj_mask_data[obj_idx] = static_cast<T>(1.0);
tx_data[obj_idx] = gx - gi;
ty_data[obj_idx] = gy - gj;
tw_data[obj_idx] = log(gw / anchors[2 * best_an_index]);
th_data[obj_idx] = log(gh / anchors[2 * best_an_index + 1]);
tweight_data[obj_idx] =
2.0 - gt_box_data[box_idx + 2] * gt_box_data[box_idx + 3];
tconf_data[obj_idx] = static_cast<T>(1.0);
tclass_data[obj_idx * class_num + cur_label] = static_cast<T>(1.0);
}
}
static inline bool LessEqualZero(T x) {
return x < 1e-6;
}
template <typename T>
......@@ -152,177 +51,8 @@ static T L1LossGrad(T x, T y) {
return x > y ? 1.0 : -1.0;
}
template <typename T>
static void CalcSCE(T* loss_data, const T* input, const T* target,
const T* weight, const T* mask, const int n,
const int an_num, const int grid_num, const int class_num,
const int num) {
for (int i = 0; i < n; i++) {
for (int j = 0; j < an_num; j++) {
for (int k = 0; k < grid_num; k++) {
int sub_idx = k * num;
for (int l = 0; l < num; l++) {
loss_data[i] += SCE<T>(input[l * grid_num + k], target[sub_idx + l]) *
weight[k] * mask[k];
}
}
input += (class_num + 5) * grid_num;
target += grid_num * num;
weight += grid_num;
mask += grid_num;
}
}
}
template <typename T>
static void CalcSCEGrad(T* input_grad, const T* loss_grad, const T* input,
const T* target, const T* weight, const T* mask,
const int n, const int an_num, const int grid_num,
const int class_num, const int num) {
for (int i = 0; i < n; i++) {
for (int j = 0; j < an_num; j++) {
for (int k = 0; k < grid_num; k++) {
int sub_idx = k * num;
for (int l = 0; l < num; l++) {
input_grad[l * grid_num + k] =
SCEGrad<T>(input[l * grid_num + k], target[sub_idx + l]) *
weight[k] * mask[k] * loss_grad[i];
}
}
input_grad += (class_num + 5) * grid_num;
input += (class_num + 5) * grid_num;
target += grid_num * num;
weight += grid_num;
mask += grid_num;
}
}
}
template <typename T>
static void CalcL1Loss(T* loss_data, const T* input, const T* target,
const T* weight, const T* mask, const int n,
const int an_num, const int grid_num,
const int class_num) {
for (int i = 0; i < n; i++) {
for (int j = 0; j < an_num; j++) {
for (int k = 0; k < grid_num; k++) {
loss_data[i] += L1Loss<T>(input[k], target[k]) * weight[k] * mask[k];
}
input += (class_num + 5) * grid_num;
target += grid_num;
weight += grid_num;
mask += grid_num;
}
}
}
template <typename T>
static void CalcL1LossGrad(T* input_grad, const T* loss_grad, const T* input,
const T* target, const T* weight, const T* mask,
const int n, const int an_num, const int grid_num,
const int class_num) {
for (int i = 0; i < n; i++) {
for (int j = 0; j < an_num; j++) {
for (int k = 0; k < grid_num; k++) {
input_grad[k] = L1LossGrad<T>(input[k], target[k]) * weight[k] *
mask[k] * loss_grad[i];
}
input_grad += (class_num + 5) * grid_num;
input += (class_num + 5) * grid_num;
target += grid_num;
weight += grid_num;
mask += grid_num;
}
}
}
template <typename T>
static void CalcYolov3Loss(T* loss_data, const Tensor& input, const Tensor& tx,
const Tensor& ty, const Tensor& tw, const Tensor& th,
const Tensor& tweight, const Tensor& tconf,
const Tensor& tclass, const Tensor& conf_mask,
const Tensor& obj_mask) {
const T* input_data = input.data<T>();
const T* tx_data = tx.data<T>();
const T* ty_data = ty.data<T>();
const T* tw_data = tw.data<T>();
const T* th_data = th.data<T>();
const T* tweight_data = tweight.data<T>();
const T* tconf_data = tconf.data<T>();
const T* tclass_data = tclass.data<T>();
const T* conf_mask_data = conf_mask.data<T>();
const T* obj_mask_data = obj_mask.data<T>();
const int n = tclass.dims()[0];
const int an_num = tclass.dims()[1];
const int h = tclass.dims()[2];
const int w = tclass.dims()[3];
const int class_num = tclass.dims()[4];
const int grid_num = h * w;
CalcSCE<T>(loss_data, input_data, tx_data, tweight_data, obj_mask_data, n,
an_num, grid_num, class_num, 1);
CalcSCE<T>(loss_data, input_data + grid_num, ty_data, tweight_data,
obj_mask_data, n, an_num, grid_num, class_num, 1);
CalcL1Loss<T>(loss_data, input_data + 2 * grid_num, tw_data, tweight_data,
obj_mask_data, n, an_num, grid_num, class_num);
CalcL1Loss<T>(loss_data, input_data + 3 * grid_num, th_data, tweight_data,
obj_mask_data, n, an_num, grid_num, class_num);
CalcSCE<T>(loss_data, input_data + 4 * grid_num, tconf_data, conf_mask_data,
conf_mask_data, n, an_num, grid_num, class_num, 1);
CalcSCE<T>(loss_data, input_data + 5 * grid_num, tclass_data, obj_mask_data,
obj_mask_data, n, an_num, grid_num, class_num, class_num);
}
template <typename T>
static void CalcYolov3LossGrad(T* input_grad_data, const Tensor& loss_grad,
const Tensor& input, const Tensor& tx,
const Tensor& ty, const Tensor& tw,
const Tensor& th, const Tensor& tweight,
const Tensor& tconf, const Tensor& tclass,
const Tensor& conf_mask,
const Tensor& obj_mask) {
const T* loss_grad_data = loss_grad.data<T>();
const T* input_data = input.data<T>();
const T* tx_data = tx.data<T>();
const T* ty_data = ty.data<T>();
const T* tw_data = tw.data<T>();
const T* th_data = th.data<T>();
const T* tweight_data = tweight.data<T>();
const T* tconf_data = tconf.data<T>();
const T* tclass_data = tclass.data<T>();
const T* conf_mask_data = conf_mask.data<T>();
const T* obj_mask_data = obj_mask.data<T>();
const int n = tclass.dims()[0];
const int an_num = tclass.dims()[1];
const int h = tclass.dims()[2];
const int w = tclass.dims()[3];
const int class_num = tclass.dims()[4];
const int grid_num = h * w;
CalcSCEGrad<T>(input_grad_data, loss_grad_data, input_data, tx_data,
tweight_data, obj_mask_data, n, an_num, grid_num, class_num,
1);
CalcSCEGrad<T>(input_grad_data + grid_num, loss_grad_data,
input_data + grid_num, ty_data, tweight_data, obj_mask_data, n,
an_num, grid_num, class_num, 1);
CalcL1LossGrad<T>(input_grad_data + 2 * grid_num, loss_grad_data,
input_data + 2 * grid_num, tw_data, tweight_data,
obj_mask_data, n, an_num, grid_num, class_num);
CalcL1LossGrad<T>(input_grad_data + 3 * grid_num, loss_grad_data,
input_data + 3 * grid_num, th_data, tweight_data,
obj_mask_data, n, an_num, grid_num, class_num);
CalcSCEGrad<T>(input_grad_data + 4 * grid_num, loss_grad_data,
input_data + 4 * grid_num, tconf_data, conf_mask_data,
conf_mask_data, n, an_num, grid_num, class_num, 1);
CalcSCEGrad<T>(input_grad_data + 5 * grid_num, loss_grad_data,
input_data + 5 * grid_num, tclass_data, obj_mask_data,
obj_mask_data, n, an_num, grid_num, class_num, class_num);
}
static int mask_index(std::vector<int> mask, int val) {
for (int i = 0; i < mask.size(); i++) {
static int GetMaskIndex(std::vector<int> mask, int val) {
for (size_t i = 0; i < mask.size(); i++) {
if (mask[i] == val) {
return i;
}
......@@ -341,16 +71,9 @@ static inline T sigmoid(T x) {
}
template <typename T>
static inline void sigmoid_arrray(T* arr, int len) {
for (int i = 0; i < len; i++) {
arr[i] = sigmoid(arr[i]);
}
}
template <typename T>
static inline Box<T> get_yolo_box(const T* x, std::vector<int> anchors, int i,
int j, int an_idx, int grid_size,
int input_size, int index, int stride) {
static inline Box<T> GetYoloBox(const T* x, std::vector<int> anchors, int i,
int j, int an_idx, int grid_size,
int input_size, int index, int stride) {
Box<T> b;
b.x = (i + sigmoid<T>(x[index])) / grid_size;
b.y = (j + sigmoid<T>(x[index + stride])) / grid_size;
......@@ -360,8 +83,7 @@ static inline Box<T> get_yolo_box(const T* x, std::vector<int> anchors, int i,
}
template <typename T>
static inline Box<T> get_gt_box(const T* gt, int batch, int max_boxes,
int idx) {
static inline Box<T> GetGtBox(const T* gt, int batch, int max_boxes, int idx) {
Box<T> b;
b.x = gt[(batch * max_boxes + idx) * 4];
b.y = gt[(batch * max_boxes + idx) * 4 + 1];
......@@ -371,7 +93,7 @@ static inline Box<T> get_gt_box(const T* gt, int batch, int max_boxes,
}
template <typename T>
static inline T overlap(T c1, T w1, T c2, T w2) {
static inline T BoxOverlap(T c1, T w1, T c2, T w2) {
T l1 = c1 - w1 / 2.0;
T l2 = c2 - w2 / 2.0;
T left = l1 > l2 ? l1 : l2;
......@@ -382,16 +104,16 @@ static inline T overlap(T c1, T w1, T c2, T w2) {
}
template <typename T>
static inline T box_iou(Box<T> b1, Box<T> b2) {
T w = overlap(b1.x, b1.w, b2.x, b2.w);
T h = overlap(b1.y, b1.h, b2.y, b2.h);
static inline T CalcBoxIoU(Box<T> b1, Box<T> b2) {
T w = BoxOverlap(b1.x, b1.w, b2.x, b2.w);
T h = BoxOverlap(b1.y, b1.h, b2.y, b2.h);
T inter_area = (w < 0 || h < 0) ? 0.0 : w * h;
T union_area = b1.w * b1.h + b2.w * b2.h - inter_area;
return inter_area / union_area;
}
static inline int entry_index(int batch, int an_idx, int hw_idx, int an_num,
int an_stride, int stride, int entry) {
static inline int GetEntryIndex(int batch, int an_idx, int hw_idx, int an_num,
int an_stride, int stride, int entry) {
return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx;
}
......@@ -523,7 +245,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
const T* gt_box_data = gt_box->data<T>();
const int* gt_label_data = gt_label->data<int>();
T* loss_data = loss->mutable_data<T>({n}, ctx.GetPlace());
memset(loss_data, 0, n * sizeof(int));
memset(loss_data, 0, loss->numel() * sizeof(T));
Tensor objness;
int* objness_data =
......@@ -538,22 +260,18 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
int box_idx =
entry_index(i, j, k * w + l, mask_num, an_stride, stride, 0);
Box<T> pred =
get_yolo_box(input_data, anchors, l, k, anchor_mask[j], h,
input_size, box_idx, stride);
GetEntryIndex(i, j, k * w + l, mask_num, an_stride, stride, 0);
Box<T> pred = GetYoloBox(input_data, anchors, l, k, anchor_mask[j],
h, input_size, box_idx, stride);
T best_iou = 0;
// int best_t = 0;
for (int t = 0; t < b; t++) {
if (isZero<T>(gt_box_data[i * b * 4 + t * 4]) &&
isZero<T>(gt_box_data[i * b * 4 + t * 4 + 1])) {
Box<T> gt = GetGtBox(gt_box_data, i, b, t);
if (LessEqualZero<T>(gt.w) || LessEqualZero<T>(gt.h)) {
continue;
}
Box<T> gt = get_gt_box(gt_box_data, i, b, t);
T iou = box_iou(pred, gt);
T iou = CalcBoxIoU(pred, gt);
if (iou > best_iou) {
best_iou = iou;
// best_t = t;
}
}
......@@ -565,11 +283,10 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
}
}
for (int t = 0; t < b; t++) {
if (isZero<T>(gt_box_data[i * b * 4 + t * 4]) &&
isZero<T>(gt_box_data[i * b * 4 + t * 4 + 1])) {
Box<T> gt = GetGtBox(gt_box_data, i, b, t);
if (LessEqualZero<T>(gt.w) || LessEqualZero<T>(gt.h)) {
continue;
}
Box<T> gt = get_gt_box(gt_box_data, i, b, t);
int gi = static_cast<int>(gt.x * w);
int gj = static_cast<int>(gt.y * h);
Box<T> gt_shift = gt;
......@@ -583,7 +300,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
an_box.y = 0.0;
an_box.w = anchors[2 * an_idx] / static_cast<T>(input_size);
an_box.h = anchors[2 * an_idx + 1] / static_cast<T>(input_size);
float iou = box_iou<T>(an_box, gt_shift);
float iou = CalcBoxIoU<T>(an_box, gt_shift);
// TO DO: iou > 0.5 ?
if (iou > best_iou) {
best_iou = iou;
......@@ -591,10 +308,10 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
}
}
int mask_idx = mask_index(anchor_mask, best_n);
int mask_idx = GetMaskIndex(anchor_mask, best_n);
if (mask_idx >= 0) {
int box_idx = entry_index(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 0);
int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 0);
CalcBoxLocationLoss<T>(loss_data + i, input_data, gt, anchors, best_n,
box_idx, gi, gj, h, input_size, stride);
......@@ -602,8 +319,8 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
objness_data[obj_idx] = 1;
int label = gt_label_data[i * b + t];
int label_idx = entry_index(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 5);
int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 5);
CalcLabelLoss<T>(loss_data + i, input_data, label_idx, label,
class_num, stride);
}
......@@ -612,52 +329,6 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
CalcObjnessLoss<T>(loss_data, input_data + 4 * stride, objness_data, n,
mask_num, h, w, stride, an_stride);
// Tensor conf_mask, obj_mask;
// Tensor tx, ty, tw, th, tweight, tconf, tclass;
// conf_mask.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
// obj_mask.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
// tx.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
// ty.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
// tw.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
// th.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
// tweight.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
// tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
// tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
//
// math::SetConstant<platform::CPUDeviceContext, T> constant;
// constant(ctx.template device_context<platform::CPUDeviceContext>(),
// &conf_mask, static_cast<T>(1.0));
// constant(ctx.template device_context<platform::CPUDeviceContext>(),
// &obj_mask, static_cast<T>(0.0));
// constant(ctx.template device_context<platform::CPUDeviceContext>(), &tx,
// static_cast<T>(0.0));
// constant(ctx.template device_context<platform::CPUDeviceContext>(), &ty,
// static_cast<T>(0.0));
// constant(ctx.template device_context<platform::CPUDeviceContext>(), &tw,
// static_cast<T>(0.0));
// constant(ctx.template device_context<platform::CPUDeviceContext>(), &th,
// static_cast<T>(0.0));
// constant(ctx.template device_context<platform::CPUDeviceContext>(),
// &tweight, static_cast<T>(0.0));
// constant(ctx.template device_context<platform::CPUDeviceContext>(),
// &tconf,
// static_cast<T>(0.0));
// constant(ctx.template device_context<platform::CPUDeviceContext>(),
// &tclass,
// static_cast<T>(0.0));
//
// PreProcessGTBox<T>(*gt_box, *gt_label, ignore_thresh, anchors,
// input_size,
// h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th,
// &tweight,
// &tconf, &tclass);
//
// T* loss_data = loss->mutable_data<T>({n}, ctx.GetPlace());
// memset(loss_data, 0, n * sizeof(T));
// CalcYolov3Loss<T>(loss_data, *input, tx, ty, tw, th, tweight, tconf,
// tclass,
// conf_mask, obj_mask);
}
};
......@@ -706,22 +377,18 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
int box_idx =
entry_index(i, j, k * w + l, mask_num, an_stride, stride, 0);
Box<T> pred =
get_yolo_box(input_data, anchors, l, k, anchor_mask[j], h,
input_size, box_idx, stride);
GetEntryIndex(i, j, k * w + l, mask_num, an_stride, stride, 0);
Box<T> pred = GetYoloBox(input_data, anchors, l, k, anchor_mask[j],
h, input_size, box_idx, stride);
T best_iou = 0;
// int best_t = 0;
for (int t = 0; t < b; t++) {
if (isZero<T>(gt_box_data[i * b * 4 + t * 4]) &&
isZero<T>(gt_box_data[i * b * 4 + t * 4 + 1])) {
Box<T> gt = GetGtBox(gt_box_data, i, b, t);
if (LessEqualZero<T>(gt.w) || LessEqualZero<T>(gt.h)) {
continue;
}
Box<T> gt = get_gt_box(gt_box_data, i, b, t);
T iou = box_iou(pred, gt);
T iou = CalcBoxIoU(pred, gt);
if (iou > best_iou) {
best_iou = iou;
// best_t = t;
}
}
......@@ -733,11 +400,10 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
}
}
for (int t = 0; t < b; t++) {
if (isZero<T>(gt_box_data[i * b * 4 + t * 4]) &&
isZero<T>(gt_box_data[i * b * 4 + t * 4 + 1])) {
Box<T> gt = GetGtBox(gt_box_data, i, b, t);
if (LessEqualZero<T>(gt.w) || LessEqualZero<T>(gt.h)) {
continue;
}
Box<T> gt = get_gt_box(gt_box_data, i, b, t);
int gi = static_cast<int>(gt.x * w);
int gj = static_cast<int>(gt.y * h);
Box<T> gt_shift = gt;
......@@ -751,7 +417,7 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
an_box.y = 0.0;
an_box.w = anchors[2 * an_idx] / static_cast<T>(input_size);
an_box.h = anchors[2 * an_idx + 1] / static_cast<T>(input_size);
float iou = box_iou<T>(an_box, gt_shift);
float iou = CalcBoxIoU<T>(an_box, gt_shift);
// TO DO: iou > 0.5 ?
if (iou > best_iou) {
best_iou = iou;
......@@ -759,10 +425,10 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
}
}
int mask_idx = mask_index(anchor_mask, best_n);
int mask_idx = GetMaskIndex(anchor_mask, best_n);
if (mask_idx >= 0) {
int box_idx = entry_index(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 0);
int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 0);
CalcBoxLocationLossGrad<T>(input_grad_data, loss_grad_data[i],
input_data, gt, anchors, best_n, box_idx,
gi, gj, h, input_size, stride);
......@@ -771,8 +437,8 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
objness_data[obj_idx] = 1;
int label = gt_label_data[i * b + t];
int label_idx = entry_index(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 5);
int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num,
an_stride, stride, 5);
CalcLabelLossGrad<T>(input_grad_data, loss_grad_data[i], input_data,
label_idx, label, class_num, stride);
}
......@@ -782,58 +448,6 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
CalcObjnessLossGrad<T>(input_grad_data + 4 * stride, loss_grad_data,
input_data + 4 * stride, objness_data, n, mask_num,
h, w, stride, an_stride);
// const int n = input->dims()[0];
// const int c = input->dims()[1];
// const int h = input->dims()[2];
// const int w = input->dims()[3];
// const int an_num = anchors.size() / 2;
//
// Tensor conf_mask, obj_mask;
// Tensor tx, ty, tw, th, tweight, tconf, tclass;
// conf_mask.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
// obj_mask.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
// tx.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
// ty.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
// tw.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
// th.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
// tweight.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
// tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
// tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
//
// math::SetConstant<platform::CPUDeviceContext, T> constant;
// constant(ctx.template device_context<platform::CPUDeviceContext>(),
// &conf_mask, static_cast<T>(1.0));
// constant(ctx.template device_context<platform::CPUDeviceContext>(),
// &obj_mask, static_cast<T>(0.0));
// constant(ctx.template device_context<platform::CPUDeviceContext>(), &tx,
// static_cast<T>(0.0));
// constant(ctx.template device_context<platform::CPUDeviceContext>(), &ty,
// static_cast<T>(0.0));
// constant(ctx.template device_context<platform::CPUDeviceContext>(), &tw,
// static_cast<T>(0.0));
// constant(ctx.template device_context<platform::CPUDeviceContext>(), &th,
// static_cast<T>(0.0));
// constant(ctx.template device_context<platform::CPUDeviceContext>(),
// &tweight, static_cast<T>(0.0));
// constant(ctx.template device_context<platform::CPUDeviceContext>(),
// &tconf,
// static_cast<T>(0.0));
// constant(ctx.template device_context<platform::CPUDeviceContext>(),
// &tclass,
// static_cast<T>(0.0));
//
// PreProcessGTBox<T>(*gt_box, *gt_label, ignore_thresh, anchors,
// input_size,
// h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th,
// &tweight,
// &tconf, &tclass);
//
// T* input_grad_data =
// input_grad->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
// CalcYolov3LossGrad<T>(input_grad_data, *loss_grad, *input, tx, ty, tw,
// th,
// tweight, tconf, tclass, conf_mask, obj_mask);
}
};
......
......@@ -22,32 +22,6 @@ from op_test import OpTest
from paddle.fluid import core
# def l1loss(x, y, weight):
# n = x.shape[0]
# x = x.reshape((n, -1))
# y = y.reshape((n, -1))
# weight = weight.reshape((n, -1))
# return (np.abs(y - x) * weight).sum(axis=1)
#
#
# def mse(x, y, weight):
# n = x.shape[0]
# x = x.reshape((n, -1))
# y = y.reshape((n, -1))
# weight = weight.reshape((n, -1))
# return ((y - x)**2 * weight).sum(axis=1)
#
#
# def sce(x, label, weight):
# n = x.shape[0]
# x = x.reshape((n, -1))
# label = label.reshape((n, -1))
# weight = weight.reshape((n, -1))
# sigmoid_x = expit(x)
# term1 = label * np.log(sigmoid_x)
# term2 = (1.0 - label) * np.log(1.0 - sigmoid_x)
# return ((-term1 - term2) * weight).sum(axis=1)
def l1loss(x, y):
return abs(x - y)
......@@ -60,116 +34,6 @@ def sce(x, label):
return -term1 - term2
def box_iou(box1, box2):
b1_x1 = box1[0] - box1[2] / 2
b1_x2 = box1[0] + box1[2] / 2
b1_y1 = box1[1] - box1[3] / 2
b1_y2 = box1[1] + box1[3] / 2
b2_x1 = box2[0] - box2[2] / 2
b2_x2 = box2[0] + box2[2] / 2
b2_y1 = box2[1] - box2[3] / 2
b2_y2 = box2[1] + box2[3] / 2
b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
inter_rect_x1 = max(b1_x1, b2_x1)
inter_rect_y1 = max(b1_y1, b2_y1)
inter_rect_x2 = min(b1_x2, b2_x2)
inter_rect_y2 = min(b1_y2, b2_y2)
inter_area = max(inter_rect_x2 - inter_rect_x1, 0) * max(
inter_rect_y2 - inter_rect_y1, 0)
return inter_area / (b1_area + b2_area + inter_area)
def build_target(gtboxes, gtlabel, attrs, grid_size):
n, b, _ = gtboxes.shape
ignore_thresh = attrs["ignore_thresh"]
anchors = attrs["anchors"]
class_num = attrs["class_num"]
input_size = attrs["input_size"]
an_num = len(anchors) // 2
conf_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32')
obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
tx = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
ty = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
tw = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
th = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
tweight = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
tconf = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
tcls = np.zeros(
(n, an_num, grid_size, grid_size, class_num)).astype('float32')
for i in range(n):
for j in range(b):
if gtboxes[i, j, :].sum() == 0:
continue
gt_label = gtlabel[i, j]
gx = gtboxes[i, j, 0] * grid_size
gy = gtboxes[i, j, 1] * grid_size
gw = gtboxes[i, j, 2] * input_size
gh = gtboxes[i, j, 3] * input_size
gi = int(gx)
gj = int(gy)
gtbox = [0, 0, gw, gh]
max_iou = 0
for k in range(an_num):
anchor_box = [0, 0, anchors[2 * k], anchors[2 * k + 1]]
iou = box_iou(gtbox, anchor_box)
if iou > max_iou:
max_iou = iou
best_an_index = k
if iou > ignore_thresh:
conf_mask[i, best_an_index, gj, gi] = 0
conf_mask[i, best_an_index, gj, gi] = 1
obj_mask[i, best_an_index, gj, gi] = 1
tx[i, best_an_index, gj, gi] = gx - gi
ty[i, best_an_index, gj, gi] = gy - gj
tw[i, best_an_index, gj, gi] = np.log(gw / anchors[2 *
best_an_index])
th[i, best_an_index, gj, gi] = np.log(
gh / anchors[2 * best_an_index + 1])
tweight[i, best_an_index, gj, gi] = 2.0 - gtboxes[
i, j, 2] * gtboxes[i, j, 3]
tconf[i, best_an_index, gj, gi] = 1
tcls[i, best_an_index, gj, gi, gt_label] = 1
return (tx, ty, tw, th, tweight, tconf, tcls, conf_mask, obj_mask)
def YoloV3Loss(x, gtbox, gtlabel, attrs):
n, c, h, w = x.shape
an_num = len(attrs['anchors']) // 2
class_num = attrs["class_num"]
x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
pred_x = x[:, :, :, :, 0]
pred_y = x[:, :, :, :, 1]
pred_w = x[:, :, :, :, 2]
pred_h = x[:, :, :, :, 3]
pred_conf = x[:, :, :, :, 4]
pred_cls = x[:, :, :, :, 5:]
tx, ty, tw, th, tweight, tconf, tcls, conf_mask, obj_mask = build_target(
gtbox, gtlabel, attrs, x.shape[2])
obj_weight = obj_mask * tweight
obj_mask_expand = np.tile(
np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num'])))
loss_x = sce(pred_x, tx, obj_weight)
loss_y = sce(pred_y, ty, obj_weight)
loss_w = l1loss(pred_w, tw, obj_weight)
loss_h = l1loss(pred_h, th, obj_weight)
loss_obj = sce(pred_conf, tconf, conf_mask)
loss_class = sce(pred_cls, tcls, obj_mask_expand)
return loss_x + loss_y + loss_w + loss_h + loss_obj + loss_class
def sigmoid(x):
return 1.0 / (1.0 + np.exp(-1.0 * x))
......@@ -291,8 +155,10 @@ class TestYolov3LossOp(OpTest):
self.op_type = 'yolov3_loss'
x = logit(np.random.uniform(0, 1, self.x_shape).astype('float32'))
gtbox = np.random.random(size=self.gtbox_shape).astype('float32')
gtlabel = np.random.randint(0, self.class_num,
self.gtbox_shape[:2]).astype('int32')
gtlabel = np.random.randint(0, self.class_num, self.gtbox_shape[:2])
gtmask = np.random.randint(0, 2, self.gtbox_shape[:2])
gtbox = gtbox * gtmask[:, :, np.newaxis]
gtlabel = gtlabel * gtmask
self.attrs = {
"anchors": self.anchors,
......@@ -302,7 +168,11 @@ class TestYolov3LossOp(OpTest):
"downsample": self.downsample,
}
self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel}
self.inputs = {
'X': x,
'GTBox': gtbox.astype('float32'),
'GTLabel': gtlabel.astype('int32')
}
self.outputs = {'Loss': YOLOv3Loss(x, gtbox, gtlabel, self.attrs)}
def test_check_output(self):
......