提交 77c1328f 编写于 作者: D dengkaipeng

add CPU kernel forward

上级 5d0b568e
...@@ -27,18 +27,8 @@ class Yolov3LossOp : public framework::OperatorWithKernel { ...@@ -27,18 +27,8 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
"Input(X) of Yolov3LossOp should not be null."); "Input(X) of Yolov3LossOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("GTBox"), PADDLE_ENFORCE(ctx->HasInput("GTBox"),
"Input(GTBox) of Yolov3LossOp should not be null."); "Input(GTBox) of Yolov3LossOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Loss"),
"Output(Out) of Yolov3LossOp should not be null."); "Output(Loss) of Yolov3LossOp should not be null.");
// PADDLE_ENFORCE(ctx->HasAttr("img_height"),
// "Attr(img_height) of Yolov3LossOp should not be null. ");
// PADDLE_ENFORCE(ctx->HasAttr("anchors"),
// "Attr(anchor) of Yolov3LossOp should not be null.")
// PADDLE_ENFORCE(ctx->HasAttr("class_num"),
// "Attr(class_num) of Yolov3LossOp should not be null.");
// PADDLE_ENFORCE(ctx->HasAttr(
// "ignore_thresh",
// "Attr(ignore_thresh) of Yolov3LossOp should not be null."));
auto dim_x = ctx->GetInputDim("X"); auto dim_x = ctx->GetInputDim("X");
auto dim_gt = ctx->GetInputDim("GTBox"); auto dim_gt = ctx->GetInputDim("GTBox");
...@@ -46,6 +36,14 @@ class Yolov3LossOp : public framework::OperatorWithKernel { ...@@ -46,6 +36,14 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors"); auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
auto box_num = ctx->Attrs().Get<int>("box_num"); auto box_num = ctx->Attrs().Get<int>("box_num");
auto class_num = ctx->Attrs().Get<int>("class_num"); auto class_num = ctx->Attrs().Get<int>("class_num");
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor.");
PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3],
"Input(X) dim[3] and dim[4] should be euqal.");
PADDLE_ENFORCE_EQ(dim_x[1], anchors.size() / 2 * (5 + class_num),
"Input(X) dim[1] should be equal to (anchor_number * (5 "
"+ class_num)).");
PADDLE_ENFORCE_EQ(dim_gt.size(), 3, "Input(GTBox) should be a 3-D tensor");
PADDLE_ENFORCE_EQ(dim_gt[2], 5, "Input(GTBox) dim[2] should be 5");
PADDLE_ENFORCE_GT(img_height, 0, PADDLE_ENFORCE_GT(img_height, 0,
"Attr(img_height) value should be greater then 0"); "Attr(img_height) value should be greater then 0");
PADDLE_ENFORCE_GT(anchors.size(), 0, PADDLE_ENFORCE_GT(anchors.size(), 0,
...@@ -56,14 +54,9 @@ class Yolov3LossOp : public framework::OperatorWithKernel { ...@@ -56,14 +54,9 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
"Attr(box_num) should be an integer greater then 0."); "Attr(box_num) should be an integer greater then 0.");
PADDLE_ENFORCE_GT(class_num, 0, PADDLE_ENFORCE_GT(class_num, 0,
"Attr(class_num) should be an integer greater then 0."); "Attr(class_num) should be an integer greater then 0.");
PADDLE_ENFORCE_EQ(dim_x[1], anchors.size() / 2 * (5 + class_num),
"Input(X) dim[1] should be equal to (anchor_number * (5 "
"+ class_num)).");
PADDLE_ENFORCE_EQ(dim_gt.size(), 3, "Input(GTBox) should be a 3-D tensor");
PADDLE_ENFORCE_EQ(dim_gt[2], 5, "Input(GTBox) dim[2] should be 5");
std::vector<int64_t> dim_out({dim_x[0], 1}); std::vector<int64_t> dim_out({1});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); ctx->SetOutputDim("Loss", framework::make_ddim(dim_out));
} }
protected: protected:
...@@ -80,12 +73,31 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -80,12 +73,31 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", AddInput("X",
"The input tensor of bilinear interpolation, " "The input tensor of bilinear interpolation, "
"This is a 4-D tensor with shape of [N, C, H, W]"); "This is a 4-D tensor with shape of [N, C, H, W]");
AddOutput("Out", AddInput(
"The output yolo loss tensor, " "GTBox",
"This is a 2-D tensor with shape of [N, 1]"); "The input tensor of ground truth boxes, "
"This is a 3-D tensor with shape of [N, max_box_num, 5 + class_num], "
"max_box_num is the max number of boxes in each image, "
"class_num is the number of classes in data set. "
"In the third dimention, stores x, y, w, h, confidence, classes "
"one-hot key. "
"x, y is the center cordinate of boxes and w, h is the width and "
"height, "
"and all of them should be divided by input image height to scale to "
"[0, 1].");
AddOutput("Loss",
"The output yolov3 loss tensor, "
"This is a 1-D tensor with shape of [1]");
AddAttr<int>("box_num", "The number of boxes generated in each grid."); AddAttr<int>("box_num", "The number of boxes generated in each grid.");
AddAttr<int>("class_num", "The number of classes to predict."); AddAttr<int>("class_num", "The number of classes to predict.");
AddAttr<std::vector<int>>("anchors",
"The anchor width and height, "
"it will be parsed pair by pair.");
AddAttr<int>("img_height",
"The input image height after crop of yolov3 network.");
AddAttr<float>("ignore_thresh",
"The ignore threshold to ignore confidence loss.");
AddComment(R"DOC( AddComment(R"DOC(
This operator generate yolov3 loss by given predict result and ground This operator generate yolov3 loss by given predict result and ground
truth boxes. truth boxes.
...@@ -100,8 +112,8 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel { ...@@ -100,8 +112,8 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
"Input(Out@GRAD) should not be null"); "Input(Loss@GRAD) should not be null");
auto dim_x = ctx->GetInputDim("X"); auto dim_x = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) { if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), dim_x); ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
......
...@@ -44,8 +44,16 @@ static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y, ...@@ -44,8 +44,16 @@ static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y,
auto x_t = EigenVector<T>::Flatten(x); auto x_t = EigenVector<T>::Flatten(x);
auto y_t = EigenVector<T>::Flatten(y); auto y_t = EigenVector<T>::Flatten(y);
auto mask_t = EigenVector<T>::Flatten(mask); auto mask_t = EigenVector<T>::Flatten(mask);
auto result = ((x_t - y_t) * mask_t).pow(2).sum().eval();
return result(0); T error_sum = 0.0;
T points = 0.0;
for (int i = 0; i < x_t.dimensions()[0]; i++) {
if (mask_t(i)) {
error_sum += pow(x_t(i) - y_t(i), 2);
points += 1;
}
}
return (error_sum / points);
} }
template <typename T> template <typename T>
...@@ -55,27 +63,24 @@ static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y, ...@@ -55,27 +63,24 @@ static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y,
auto y_t = EigenVector<T>::Flatten(y); auto y_t = EigenVector<T>::Flatten(y);
auto mask_t = EigenVector<T>::Flatten(mask); auto mask_t = EigenVector<T>::Flatten(mask);
auto result = T error_sum = 0.0;
((y_t * (x_t.log()) + (1.0 - y_t) * ((1.0 - x_t).log())) * mask_t) T points = 0.0;
.sum() for (int i = 0; i < x_t.dimensions()[0]; i++) {
.eval(); if (mask_t(i)) {
return result; error_sum +=
} -1.0 * (y_t(i) * log(x_t(i)) + (1.0 - y_t(i)) * log(1.0 - x_t(i)));
points += 1;
template <typename T> }
static inline T CalcCEWithMask(const Tensor& x, const Tensor& y, }
const Tensor& mask) { return (error_sum / points);
auto x_t = EigenVector<T>::Flatten(x);
auto y_t = EigenVector<T>::Flatten(y);
auto mask_t = EigenVector<T>::Flatten(mask);
} }
template <typename T> template <typename T>
static void CalcPredResult(const Tensor& input, Tensor* pred_boxes, static void CalcPredResult(const Tensor& input, Tensor* pred_confs,
Tensor* pred_confs, Tensor* pred_classes, Tensor* pred_classes, Tensor* pred_x, Tensor* pred_y,
Tensor* pred_x, Tensor* pred_y, Tensor* pred_w, Tensor* pred_w, Tensor* pred_h,
Tensor* pred_h, std::vector<int> anchors, std::vector<int> anchors, const int class_num,
const int class_num, const int stride) { const int stride) {
const int n = input.dims()[0]; const int n = input.dims()[0];
const int c = input.dims()[1]; const int c = input.dims()[1];
const int h = input.dims()[2]; const int h = input.dims()[2];
...@@ -84,7 +89,7 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_boxes, ...@@ -84,7 +89,7 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_boxes,
const int box_attr_num = 5 + class_num; const int box_attr_num = 5 + class_num;
auto input_t = EigenTensor<T, 4>::From(input); auto input_t = EigenTensor<T, 4>::From(input);
auto pred_boxes_t = EigenTensor<T, 5>::From(*pred_boxes); // auto pred_boxes_t = EigenTensor<T, 5>::From(*pred_boxes);
auto pred_confs_t = EigenTensor<T, 4>::From(*pred_confs); auto pred_confs_t = EigenTensor<T, 4>::From(*pred_confs);
auto pred_classes_t = EigenTensor<T, 5>::From(*pred_classes); auto pred_classes_t = EigenTensor<T, 5>::From(*pred_classes);
auto pred_x_t = EigenTensor<T, 4>::From(*pred_x); auto pred_x_t = EigenTensor<T, 4>::From(*pred_x);
...@@ -104,16 +109,16 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_boxes, ...@@ -104,16 +109,16 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_boxes,
pred_y_t(i, an_idx, j, k) = pred_y_t(i, an_idx, j, k) =
sigmod(input_t(i, box_attr_num * an_idx + 1, j, k)); sigmod(input_t(i, box_attr_num * an_idx + 1, j, k));
pred_w_t(i, an_idx, j, k) = pred_w_t(i, an_idx, j, k) =
sigmod(input_t(i, box_attr_num * an_idx + 2, j, k)); input_t(i, box_attr_num * an_idx + 2, j, k);
pred_h_t(i, an_idx, j, k) = pred_h_t(i, an_idx, j, k) =
sigmod(input_t(i, box_attr_num * an_idx + 3, j, k)); input_t(i, box_attr_num * an_idx + 3, j, k);
pred_boxes_t(i, an_idx, j, k, 0) = pred_x_t(i, an_idx, j, k) + k; // pred_boxes_t(i, an_idx, j, k, 0) = pred_x_t(i, an_idx, j, k) + k;
pred_boxes_t(i, an_idx, j, k, 1) = pred_y_t(i, an_idx, j, k) + j; // pred_boxes_t(i, an_idx, j, k, 1) = pred_y_t(i, an_idx, j, k) + j;
pred_boxes_t(i, an_idx, j, k, 2) = // pred_boxes_t(i, an_idx, j, k, 2) =
exp(pred_w_t(i, an_idx, j, k)) * an_w; // exp(pred_w_t(i, an_idx, j, k)) * an_w;
pred_boxes_t(i, an_idx, j, k, 3) = // pred_boxes_t(i, an_idx, j, k, 3) =
exp(pred_h_t(i, an_idx, j, k)) * an_h; // exp(pred_h_t(i, an_idx, j, k)) * an_h;
pred_confs_t(i, an_idx, j, k) = pred_confs_t(i, an_idx, j, k) =
sigmod(input_t(i, box_attr_num * an_idx + 4, j, k)); sigmod(input_t(i, box_attr_num * an_idx + 4, j, k));
...@@ -129,40 +134,27 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_boxes, ...@@ -129,40 +134,27 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_boxes,
} }
template <typename T> template <typename T>
static T CalcBoxIoU(std::vector<T> box1, std::vector<T> box2, static T CalcBoxIoU(std::vector<T> box1, std::vector<T> box2) {
bool center_mode) { T b1_x1 = box1[0] - box1[2] / 2;
T b1_x1, b1_x2, b1_y1, b1_y2; T b1_x2 = box1[0] + box1[2] / 2;
T b2_x1, b2_x2, b2_y1, b2_y2; T b1_y1 = box1[1] - box1[3] / 2;
if (center_mode) { T b1_y2 = box1[1] + box1[3] / 2;
b1_x1 = box1[0] - box1[2] / 2; T b2_x1 = box2[0] - box2[2] / 2;
b1_x2 = box1[0] + box1[2] / 2; T b2_x2 = box2[0] + box2[2] / 2;
b1_y1 = box1[1] - box1[3] / 2; T b2_y1 = box2[1] - box2[3] / 2;
b1_y2 = box1[1] + box1[3] / 2; T b2_y2 = box2[1] + box2[3] / 2;
b2_x1 = box2[0] - box2[2] / 2;
b2_x2 = box2[0] + box2[2] / 2; T b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1);
b2_y1 = box2[1] - box2[3] / 2; T b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1);
b2_y2 = box2[1] + box2[3] / 2;
} else {
b1_x1 = box1[0];
b1_x2 = box1[1];
b1_y1 = box1[2];
b1_y2 = box1[3];
b2_x1 = box2[0];
b2_x2 = box2[0];
b2_y1 = box2[1];
b2_y2 = box2[1];
}
T b1_area = (b1_x2 - b1_x1 + 1.0) * (b1_y2 - b1_y1 + 1.0);
T b2_area = (b2_x2 - b2_x1 + 1.0) * (b2_y2 - b2_y1 + 1.0);
T inter_rect_x1 = std::max(b1_x1, b2_x1); T inter_rect_x1 = std::max(b1_x1, b2_x1);
T inter_rect_y1 = std::max(b1_y1, b2_y1); T inter_rect_y1 = std::max(b1_y1, b2_y1);
T inter_rect_x2 = std::min(b1_x2, b2_x2); T inter_rect_x2 = std::min(b1_x2, b2_x2);
T inter_rect_y2 = std::min(b1_y2, b2_y2); T inter_rect_y2 = std::min(b1_y2, b2_y2);
T inter_area = std::max(inter_rect_x2 - inter_rect_x1 + 1.0, 0.0) * T inter_area = std::max(inter_rect_x2 - inter_rect_x1, static_cast<T>(0.0)) *
std::max(inter_rect_y2 - inter_rect_y1 + 1.0, 0.0); std::max(inter_rect_y2 - inter_rect_y1, static_cast<T>(0.0));
return inter_area / (b1_area + b2_area - inter_area + 1e-16); return inter_area / (b1_area + b2_area - inter_area);
} }
template <typename T> template <typename T>
...@@ -181,23 +173,18 @@ static inline int GetPredLabel(const Tensor& pred_classes, int n, ...@@ -181,23 +173,18 @@ static inline int GetPredLabel(const Tensor& pred_classes, int n,
} }
template <typename T> template <typename T>
static void CalcPredBoxWithGTBox( static void PrePorcessGTBox(const Tensor& gt_boxes, const float ignore_thresh,
const Tensor& pred_boxes, const Tensor& pred_confs, std::vector<int> anchors, const int img_height,
const Tensor& pred_classes, const Tensor& gt_boxes, const int grid_size, Tensor* obj_mask,
std::vector<int> anchors, const float ignore_thresh, const int img_height, Tensor* noobj_mask, Tensor* tx, Tensor* ty,
int* gt_num, int* correct_num, Tensor* mask_true, Tensor* mask_false, Tensor* tw, Tensor* th, Tensor* tconf,
Tensor* tx, Tensor* ty, Tensor* tw, Tensor* th, Tensor* tconf,
Tensor* tclass) { Tensor* tclass) {
const int n = gt_boxes.dims()[0]; const int n = gt_boxes.dims()[0];
const int b = gt_boxes.dims()[1]; const int b = gt_boxes.dims()[1];
const int grid_size = pred_boxes.dims()[1];
const int anchor_num = anchors.size() / 2; const int anchor_num = anchors.size() / 2;
auto pred_boxes_t = EigenTensor<T, 5>::From(pred_boxes);
auto pred_confs_t = EigenTensor<T, 4>::From(pred_confs);
auto pred_classes_t = EigenTensor<T, 5>::From(pred_classes);
auto gt_boxes_t = EigenTensor<T, 3>::From(gt_boxes); auto gt_boxes_t = EigenTensor<T, 3>::From(gt_boxes);
auto mask_true_t = EigenTensor<int, 4>::From(*mask_true).setConstant(0.0); auto obj_mask_t = EigenTensor<int, 4>::From(*obj_mask).setConstant(0);
auto mask_false_t = EigenTensor<int, 4>::From(*mask_false).setConstant(1.0); auto noobj_mask_t = EigenTensor<int, 4>::From(*noobj_mask).setConstant(1);
auto tx_t = EigenTensor<T, 4>::From(*tx).setConstant(0.0); auto tx_t = EigenTensor<T, 4>::From(*tx).setConstant(0.0);
auto ty_t = EigenTensor<T, 4>::From(*ty).setConstant(0.0); auto ty_t = EigenTensor<T, 4>::From(*ty).setConstant(0.0);
auto tw_t = EigenTensor<T, 4>::From(*tw).setConstant(0.0); auto tw_t = EigenTensor<T, 4>::From(*tw).setConstant(0.0);
...@@ -205,8 +192,6 @@ static void CalcPredBoxWithGTBox( ...@@ -205,8 +192,6 @@ static void CalcPredBoxWithGTBox(
auto tconf_t = EigenTensor<T, 4>::From(*tconf).setConstant(0.0); auto tconf_t = EigenTensor<T, 4>::From(*tconf).setConstant(0.0);
auto tclass_t = EigenTensor<T, 5>::From(*tclass).setConstant(0.0); auto tclass_t = EigenTensor<T, 5>::From(*tclass).setConstant(0.0);
*gt_num = 0;
*correct_num = 0;
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
for (int j = 0; j < b; j++) { for (int j = 0; j < b; j++) {
if (isZero(gt_boxes_t(i, j, 0)) && isZero(gt_boxes_t(i, j, 1)) && if (isZero(gt_boxes_t(i, j, 0)) && isZero(gt_boxes_t(i, j, 1)) &&
...@@ -214,12 +199,11 @@ static void CalcPredBoxWithGTBox( ...@@ -214,12 +199,11 @@ static void CalcPredBoxWithGTBox(
continue; continue;
} }
*(gt_num)++;
int gt_label = gt_boxes_t(i, j, 0); int gt_label = gt_boxes_t(i, j, 0);
T gx = gt_boxes_t(i, j, 1); T gx = gt_boxes_t(i, j, 1) * grid_size;
T gy = gt_boxes_t(i, j, 2); T gy = gt_boxes_t(i, j, 2) * grid_size;
T gw = gt_boxes_t(i, j, 3); T gw = gt_boxes_t(i, j, 3) * grid_size;
T gh = gt_boxes_t(i, j, 4); T gh = gt_boxes_t(i, j, 4) * grid_size;
int gi = static_cast<int>(gx); int gi = static_cast<int>(gx);
int gj = static_cast<int>(gy); int gj = static_cast<int>(gy);
...@@ -230,43 +214,26 @@ static void CalcPredBoxWithGTBox( ...@@ -230,43 +214,26 @@ static void CalcPredBoxWithGTBox(
for (int an_idx = 0; an_idx < anchor_num; an_idx++) { for (int an_idx = 0; an_idx < anchor_num; an_idx++) {
std::vector<T> anchor_shape({0, 0, static_cast<T>(anchors[2 * an_idx]), std::vector<T> anchor_shape({0, 0, static_cast<T>(anchors[2 * an_idx]),
static_cast<T>(anchors[2 * an_idx + 1])}); static_cast<T>(anchors[2 * an_idx + 1])});
iou = CalcBoxIoU(gt_box, anchor_shape, false); iou = CalcBoxIoU<T>(gt_box, anchor_shape);
if (iou > max_iou) { if (iou > max_iou) {
max_iou = iou; max_iou = iou;
best_an_index = an_idx; best_an_index = an_idx;
} }
if (iou > ignore_thresh) { if (iou > ignore_thresh) {
mask_false_t(b, an_idx, gj, gi) = 0; noobj_mask_t(b, an_idx, gj, gi) = 0;
} }
} }
mask_true_t(b, best_an_index, gj, gi) = 1; obj_mask_t(b, best_an_index, gj, gi) = 1;
mask_false_t(b, best_an_index, gj, gi) = 1; noobj_mask_t(b, best_an_index, gj, gi) = 1;
tx_t(i, best_an_index, gj, gi) = gx - gi; tx_t(i, best_an_index, gj, gi) = gx - gi;
ty_t(i, best_an_index, gj, gi) = gy - gj; ty_t(i, best_an_index, gj, gi) = gy - gj;
tw_t(i, best_an_index, gj, gi) = tw_t(i, best_an_index, gj, gi) = log(gw / anchors[2 * best_an_index]);
log(gw / anchors[2 * best_an_index] + 1e-16); th_t(i, best_an_index, gj, gi) = log(gh / anchors[2 * best_an_index + 1]);
th_t(i, best_an_index, gj, gi) =
log(gh / anchors[2 * best_an_index + 1] + 1e-16);
tclass_t(b, best_an_index, gj, gi, gt_label) = 1; tclass_t(b, best_an_index, gj, gi, gt_label) = 1;
tconf_t(b, best_an_index, gj, gi) = 1; tconf_t(b, best_an_index, gj, gi) = 1;
std::vector<T> pred_box({
pred_boxes_t(i, best_an_index, gj, gi, 0),
pred_boxes_t(i, best_an_index, gj, gi, 1),
pred_boxes_t(i, best_an_index, gj, gi, 2),
pred_boxes_t(i, best_an_index, gj, gi, 3),
});
gt_box[0] = gx;
gt_box[1] = gy;
iou = CalcBoxIoU(gt_box, pred_box, true);
int pred_label = GetPredLabel<T>(pred_classes, i, best_an_index, gj, gi);
T score = pred_confs_t(i, best_an_index, gj, gi);
if (iou > 0.5 && pred_label == gt_label && score > 0.5) {
(*correct_num)++;
}
} }
} }
mask_false_t = mask_true_t - mask_false_t; noobj_mask_t = noobj_mask_t - obj_mask_t;
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -275,7 +242,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> { ...@@ -275,7 +242,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X"); auto* input = ctx.Input<Tensor>("X");
auto* gt_boxes = ctx.Input<Tensor>("GTBox"); auto* gt_boxes = ctx.Input<Tensor>("GTBox");
auto* output = ctx.Output<Tensor>("Out"); auto* loss = ctx.Output<Tensor>("Loss");
int img_height = ctx.Attr<int>("img_height"); int img_height = ctx.Attr<int>("img_height");
auto anchors = ctx.Attr<std::vector<int>>("anchors"); auto anchors = ctx.Attr<std::vector<int>>("anchors");
int class_num = ctx.Attr<int>("class_num"); int class_num = ctx.Attr<int>("class_num");
...@@ -286,44 +253,44 @@ class Yolov3LossKernel : public framework::OpKernel<T> { ...@@ -286,44 +253,44 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
const int h = input->dims()[2]; const int h = input->dims()[2];
const int w = input->dims()[3]; const int w = input->dims()[3];
const int an_num = anchors.size() / 2; const int an_num = anchors.size() / 2;
const float stride = static_cast<float>(img_height) / h; const T stride = static_cast<T>(img_height) / h;
Tensor pred_x, pred_y, pred_w, pred_h; Tensor pred_x, pred_y, pred_w, pred_h;
Tensor pred_boxes, pred_confs, pred_classes; Tensor pred_confs, pred_classes;
pred_x.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace()); pred_x.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
pred_y.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace()); pred_y.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
pred_w.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace()); pred_w.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
pred_h.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace()); pred_h.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
pred_boxes.mutable_data<T>({n, an_num, h, w, 4}, ctx.GetPlace());
pred_confs.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace()); pred_confs.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
pred_classes.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace()); pred_classes.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
CalcPredResult<T>(*input, &pred_boxes, &pred_confs, &pred_classes, &pred_x, CalcPredResult<T>(*input, &pred_confs, &pred_classes, &pred_x, &pred_y,
&pred_y, &pred_w, &pred_h, anchors, class_num, stride); &pred_w, &pred_h, anchors, class_num, stride);
Tensor mask_true, mask_false; Tensor obj_mask, noobj_mask;
Tensor tx, ty, tw, th, tconf, tclass; Tensor tx, ty, tw, th, tconf, tclass;
mask_true.mutable_data<int>({n, an_num, h, w}, ctx.GetPlace()); obj_mask.mutable_data<int>({n, an_num, h, w}, ctx.GetPlace());
mask_false.mutable_data<int>({n, an_num, h, w}, ctx.GetPlace()); noobj_mask.mutable_data<int>({n, an_num, h, w}, ctx.GetPlace());
tx.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace()); tx.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
ty.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace()); ty.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
tw.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace()); tw.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
th.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace()); th.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace()); tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
int gt_num = 0; PrePorcessGTBox<T>(*gt_boxes, ignore_thresh, anchors, img_height, h,
int correct_num = 0; &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tconf,
CalcPredBoxWithGTBox<T>(pred_boxes, pred_confs, pred_classes, *gt_boxes, &tclass);
anchors, ignore_thresh, img_height, &gt_num,
&correct_num, &mask_true, &mask_false, &tx, &ty, T loss_x = CalcMSEWithMask<T>(pred_x, tx, obj_mask);
&tw, &th, &tconf, &tclass); T loss_y = CalcMSEWithMask<T>(pred_y, ty, obj_mask);
T loss_w = CalcMSEWithMask<T>(pred_w, tw, obj_mask);
T loss_x = CalcMSEWithMask<T>(pred_x, tx, mask_true); T loss_h = CalcMSEWithMask<T>(pred_h, th, obj_mask);
T loss_y = CalcMSEWithMask<T>(pred_y, ty, mask_true); T loss_conf_true = CalcBCEWithMask<T>(pred_confs, tconf, obj_mask);
T loss_w = CalcMSEWithMask<T>(pred_w, tw, mask_true); T loss_conf_false = CalcBCEWithMask<T>(pred_confs, tconf, noobj_mask);
T loss_h = CalcMSEWithMask<T>(pred_h, th, mask_true); T loss_class = CalcBCEWithMask<T>(pred_classes, tclass, obj_mask);
T loss_conf_true = CalcBCEWithMask<T>(pred_confs, tconf, mask_true);
T loss_conf_false = CalcBCEWithMask<T>(pred_confs, tconf, mask_false); auto* loss_data = loss->mutable_data<T>({1}, ctx.GetPlace());
// T loss_class = CalcCEWithMask<T>() loss_data[0] = loss_x + loss_y + loss_w + loss_h + loss_conf_true +
loss_conf_false + loss_class;
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册