/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once
#include <algorithm>
#include <cmath>
#include <cstring>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

using Array5 = Eigen::DSizes<int64_t, 5>;

template <typename T>
static inline bool isZero(T x) {
  return std::fabs(x) < 1e-6;
}

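// Accumulates weighted squared-error terms into the per-sample loss
// (summed, not averaged, despite the MSE name):
// loss[i] += (y - x)^2 * weight * loss_weight, over all elements of
// sample i (stride = numel / batch size).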
template <typename T>
static inline void CalcMSEWithWeight(const Tensor& x, const Tensor& y,
                                     const Tensor& weight, const T loss_weight,
                                     T* loss) {
  int n = x.dims()[0];
  int stride = x.numel() / n;
  const T* x_data = x.data<T>();
  const T* y_data = y.data<T>();
  const T* weight_data = weight.data<T>();

  for (int i = 0; i < n; i++) {
    for (int j = 0; j < stride; j++) {
      loss[i] +=
          std::pow(y_data[j] - x_data[j], 2) * weight_data[j] * loss_weight;
    }
    x_data += stride;
    y_data += stride;
    weight_data += stride;
  }
}

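// Backward counterpart of CalcMSEWithWeight: d(loss)/dx = 2 * weight *
// (x - y), scaled by the upstream per-sample gradient loss_grad[i].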
template <typename T>
static void CalcMSEGradWithWeight(const T* loss_grad, Tensor* grad,
                                  const Tensor& x, const Tensor& y,
                                  const Tensor& weight) {
  int n = x.dims()[0];
  int stride = x.numel() / n;
  T* grad_data = grad->data<T>();
  const T* x_data = x.data<T>();
  const T* y_data = y.data<T>();
  const T* weight_data = weight.data<T>();

  for (int i = 0; i < n; i++) {
    for (int j = 0; j < stride; j++) {
      grad_data[j] =
          2.0 * weight_data[j] * (x_data[j] - y_data[j]) * loss_grad[i];
    }
    grad_data += stride;
    x_data += stride;
    y_data += stride;
    weight_data += stride;
  }
}

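// Accumulates weighted sigmoid cross-entropy terms, using the numerically
// stable form max(x, 0) - x * label + log(1 + exp(-|x|)) so that exp()
// never overflows for large |x|.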
template <typename T>
static inline void CalcSCEWithWeight(const Tensor& x, const Tensor& label,
                                     const Tensor& weight, const T loss_weight,
                                     T* loss) {
  int n = x.dims()[0];
  int stride = x.numel() / n;
  const T* x_data = x.data<T>();
  const T* label_data = label.data<T>();
  const T* weight_data = weight.data<T>();

  for (int i = 0; i < n; i++) {
    for (int j = 0; j < stride; j++) {
      T term1 = (x_data[j] > 0) ? x_data[j] : 0;
      T term2 = x_data[j] * label_data[j];
      T term3 = std::log(1.0 + std::exp(-std::abs(x_data[j])));
      loss[i] += (term1 - term2 + term3) * weight_data[j] * loss_weight;
    }
    x_data += stride;
    label_data += stride;
    weight_data += stride;
  }
}

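// Backward counterpart of CalcSCEWithWeight: d(loss)/dx =
// (sigmoid(x) - label) * weight, scaled by the upstream gradient.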
template <typename T>
static inline void CalcSCEGradWithWeight(const T* loss_grad, Tensor* grad,
                                         const Tensor& x, const Tensor& label,
                                         const Tensor& weight) {
  int n = x.dims()[0];
  int stride = x.numel() / n;
  T* grad_data = grad->data<T>();
  const T* x_data = x.data<T>();
  const T* label_data = label.data<T>();
  const T* weight_data = weight.data<T>();

  for (int i = 0; i < n; i++) {
    for (int j = 0; j < stride; j++) {
      grad_data[j] = (1.0 / (1.0 + std::exp(-x_data[j])) - label_data[j]) *
                     weight_data[j] * loss_grad[i];
    }
    grad_data += stride;
    x_data += stride;
    label_data += stride;
    weight_data += stride;
  }
}

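// Splits the raw network output of shape [N, anchor_num * (5 + class_num),
// H, W] into separate tensors for box center (x, y), box size (w, h),
// objectness confidence, and class scores, one entry per anchor.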
template <typename T>
static void SplitPredResult(const Tensor& input, Tensor* pred_conf,
                            Tensor* pred_class, Tensor* pred_x, Tensor* pred_y,
                            Tensor* pred_w, Tensor* pred_h,
                            const int anchor_num, const int class_num) {
  const int n = input.dims()[0];
  const int h = input.dims()[2];
  const int w = input.dims()[3];
  const int box_attr_num = 5 + class_num;

  auto input_t = EigenTensor<T, 4>::From(input);
  auto pred_conf_t = EigenTensor<T, 4>::From(*pred_conf);
  auto pred_class_t = EigenTensor<T, 5>::From(*pred_class);
  auto pred_x_t = EigenTensor<T, 4>::From(*pred_x);
  auto pred_y_t = EigenTensor<T, 4>::From(*pred_y);
  auto pred_w_t = EigenTensor<T, 4>::From(*pred_w);
  auto pred_h_t = EigenTensor<T, 4>::From(*pred_h);

  for (int i = 0; i < n; i++) {
    for (int an_idx = 0; an_idx < anchor_num; an_idx++) {
      for (int j = 0; j < h; j++) {
        for (int k = 0; k < w; k++) {
          pred_x_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx, j, k);
          pred_y_t(i, an_idx, j, k) =
              input_t(i, box_attr_num * an_idx + 1, j, k);
          pred_w_t(i, an_idx, j, k) =
              input_t(i, box_attr_num * an_idx + 2, j, k);
          pred_h_t(i, an_idx, j, k) =
              input_t(i, box_attr_num * an_idx + 3, j, k);

          pred_conf_t(i, an_idx, j, k) =
              input_t(i, box_attr_num * an_idx + 4, j, k);

          for (int c = 0; c < class_num; c++) {
            pred_class_t(i, an_idx, j, k, c) =
                input_t(i, box_attr_num * an_idx + 5 + c, j, k);
          }
        }
      }
    }
  }
}

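// Computes the IoU of two boxes given in (center_x, center_y, w, h) form.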
template <typename T>
static T CalcBoxIoU(std::vector<T> box1, std::vector<T> box2) {
  T b1_x1 = box1[0] - box1[2] / 2;
  T b1_x2 = box1[0] + box1[2] / 2;
  T b1_y1 = box1[1] - box1[3] / 2;
  T b1_y2 = box1[1] + box1[3] / 2;
  T b2_x1 = box2[0] - box2[2] / 2;
  T b2_x2 = box2[0] + box2[2] / 2;
  T b2_y1 = box2[1] - box2[3] / 2;
  T b2_y2 = box2[1] + box2[3] / 2;

  T b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1);
  T b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1);

  T inter_rect_x1 = std::max(b1_x1, b2_x1);
  T inter_rect_y1 = std::max(b1_y1, b2_y1);
  T inter_rect_x2 = std::min(b1_x2, b2_x2);
  T inter_rect_y2 = std::min(b1_y2, b2_y2);
  T inter_area = std::max(inter_rect_x2 - inter_rect_x1, static_cast<T>(0.0)) *
                 std::max(inter_rect_y2 - inter_rect_y1, static_cast<T>(0.0));

  return inter_area / (b1_area + b2_area - inter_area);
}

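// Builds the training targets from the ground-truth boxes: for each valid
// GT box, the anchor with the highest IoU becomes responsible for it
// (obj_mask = 1), anchors whose IoU exceeds ignore_thresh are excluded
// from the no-object loss (noobj_mask = 0), and tx/ty/tw/th follow the
// YOLOv3 box parameterization. tweight (2 - w * h) up-weights small boxes.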
template <typename T>
static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label,
                            const float ignore_thresh, std::vector<int> anchors,
                            const int input_size, const int grid_size,
                            Tensor* obj_mask, Tensor* noobj_mask, Tensor* tx,
                            Tensor* ty, Tensor* tw, Tensor* th, Tensor* tweight,
                            Tensor* tconf, Tensor* tclass) {
  const int n = gt_box.dims()[0];
  const int b = gt_box.dims()[1];
  const int anchor_num = anchors.size() / 2;
  auto gt_box_t = EigenTensor<T, 3>::From(gt_box);
  auto gt_label_t = EigenTensor<int, 2>::From(gt_label);
  auto obj_mask_t = EigenTensor<T, 4>::From(*obj_mask).setConstant(0);
  auto noobj_mask_t = EigenTensor<T, 4>::From(*noobj_mask).setConstant(1);
  auto tx_t = EigenTensor<T, 4>::From(*tx).setConstant(0.0);
  auto ty_t = EigenTensor<T, 4>::From(*ty).setConstant(0.0);
  auto tw_t = EigenTensor<T, 4>::From(*tw).setConstant(0.0);
  auto th_t = EigenTensor<T, 4>::From(*th).setConstant(0.0);
  auto tweight_t = EigenTensor<T, 4>::From(*tweight).setConstant(0.0);
  auto tconf_t = EigenTensor<T, 4>::From(*tconf).setConstant(0.0);
  auto tclass_t = EigenTensor<T, 5>::From(*tclass).setConstant(0.0);

  for (int i = 0; i < n; i++) {
    for (int j = 0; j < b; j++) {
      // An all-zero box marks an empty (padding) ground-truth slot.
      if (isZero<T>(gt_box_t(i, j, 0)) && isZero<T>(gt_box_t(i, j, 1)) &&
          isZero<T>(gt_box_t(i, j, 2)) && isZero<T>(gt_box_t(i, j, 3))) {
        continue;
      }

      int cur_label = gt_label_t(i, j);
      T gx = gt_box_t(i, j, 0) * grid_size;
      T gy = gt_box_t(i, j, 1) * grid_size;
      T gw = gt_box_t(i, j, 2) * input_size;
      T gh = gt_box_t(i, j, 3) * input_size;
      int gi = static_cast<int>(gx);
      int gj = static_cast<int>(gy);

      T max_iou = static_cast<T>(0);
      T iou;
      int best_an_index = -1;
      std::vector<T> gt_box_shape({0, 0, gw, gh});
      for (int an_idx = 0; an_idx < anchor_num; an_idx++) {
        std::vector<T> anchor_shape({0, 0, static_cast<T>(anchors[2 * an_idx]),
                                     static_cast<T>(anchors[2 * an_idx + 1])});
        iou = CalcBoxIoU<T>(gt_box_shape, anchor_shape);
        if (iou > max_iou) {
          max_iou = iou;
          best_an_index = an_idx;
        }
        if (iou > ignore_thresh) {
          noobj_mask_t(i, an_idx, gj, gi) = static_cast<T>(0.0);
        }
      }
      obj_mask_t(i, best_an_index, gj, gi) = static_cast<T>(1.0);
      noobj_mask_t(i, best_an_index, gj, gi) = static_cast<T>(0.0);
      tx_t(i, best_an_index, gj, gi) = gx - gi;
      ty_t(i, best_an_index, gj, gi) = gy - gj;
      tw_t(i, best_an_index, gj, gi) =
          std::log(gw / anchors[2 * best_an_index]);
      th_t(i, best_an_index, gj, gi) =
          std::log(gh / anchors[2 * best_an_index + 1]);
      tweight_t(i, best_an_index, gj, gi) =
          2.0 - gt_box_t(i, j, 2) * gt_box_t(i, j, 3);
      tclass_t(i, best_an_index, gj, gi, cur_label) = 1;
      tconf_t(i, best_an_index, gj, gi) = 1;
    }
  }
}

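// Inverse of SplitPredResult for gradients: writes each per-attribute
// gradient back to its channel in the [N, anchor_num * (5 + class_num),
// H, W] input-gradient tensor, applying the per-term loss weights. The
// confidence channel receives the sum of the target and no-target terms.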
template <typename T>
static void AddAllGradToInputGrad(
    Tensor* grad, const Tensor& grad_x, const Tensor& grad_y,
    const Tensor& grad_w, const Tensor& grad_h, const Tensor& grad_conf_target,
    const Tensor& grad_conf_notarget, const Tensor& grad_class,
    const int class_num, const float loss_weight_xy, const float loss_weight_wh,
    const float loss_weight_conf_target, const float loss_weight_conf_notarget,
    const float loss_weight_class) {
  const int n = grad_x.dims()[0];
  const int an_num = grad_x.dims()[1];
  const int h = grad_x.dims()[2];
  const int w = grad_x.dims()[3];
  const int attr_num = class_num + 5;
  auto grad_t = EigenTensor<T, 4>::From(*grad).setConstant(0.0);
  auto grad_x_t = EigenTensor<T, 4>::From(grad_x);
  auto grad_y_t = EigenTensor<T, 4>::From(grad_y);
  auto grad_w_t = EigenTensor<T, 4>::From(grad_w);
  auto grad_h_t = EigenTensor<T, 4>::From(grad_h);
  auto grad_conf_target_t = EigenTensor<T, 4>::From(grad_conf_target);
  auto grad_conf_notarget_t = EigenTensor<T, 4>::From(grad_conf_notarget);
  auto grad_class_t = EigenTensor<T, 5>::From(grad_class);

  for (int i = 0; i < n; i++) {
    for (int j = 0; j < an_num; j++) {
      for (int k = 0; k < h; k++) {
        for (int l = 0; l < w; l++) {
          grad_t(i, j * attr_num, k, l) = grad_x_t(i, j, k, l) * loss_weight_xy;
          grad_t(i, j * attr_num + 1, k, l) =
              grad_y_t(i, j, k, l) * loss_weight_xy;
          grad_t(i, j * attr_num + 2, k, l) =
              grad_w_t(i, j, k, l) * loss_weight_wh;
          grad_t(i, j * attr_num + 3, k, l) =
              grad_h_t(i, j, k, l) * loss_weight_wh;
          grad_t(i, j * attr_num + 4, k, l) =
              grad_conf_target_t(i, j, k, l) * loss_weight_conf_target;
          grad_t(i, j * attr_num + 4, k, l) +=
              grad_conf_notarget_t(i, j, k, l) * loss_weight_conf_notarget;

          for (int c = 0; c < class_num; c++) {
            grad_t(i, j * attr_num + 5 + c, k, l) =
                grad_class_t(i, j, k, l, c) * loss_weight_class;
          }
        }
      }
    }
  }
}

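// Forward kernel: computes the YOLOv3 loss as the weighted sum of the
// box-center SCE loss, box-size MSE loss, objectness SCE loss (split into
// target and no-target terms), and classification SCE loss, producing one
// loss value per batch sample.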
template <typename T>
class Yolov3LossKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<Tensor>("X");
    auto* gt_box = ctx.Input<Tensor>("GTBox");
    auto* gt_label = ctx.Input<Tensor>("GTLabel");
    auto* loss = ctx.Output<Tensor>("Loss");
    auto anchors = ctx.Attr<std::vector<int>>("anchors");
    int class_num = ctx.Attr<int>("class_num");
    int input_size = ctx.Attr<int>("input_size");
    float ignore_thresh = ctx.Attr<float>("ignore_thresh");
    float loss_weight_xy = ctx.Attr<float>("loss_weight_xy");
    float loss_weight_wh = ctx.Attr<float>("loss_weight_wh");
    float loss_weight_conf_target = ctx.Attr<float>("loss_weight_conf_target");
    float loss_weight_conf_notarget =
        ctx.Attr<float>("loss_weight_conf_notarget");
    float loss_weight_class = ctx.Attr<float>("loss_weight_class");

    const int n = input->dims()[0];
    const int h = input->dims()[2];
    const int w = input->dims()[3];
    const int an_num = anchors.size() / 2;

    // Split the raw prediction tensor into per-attribute tensors.
    Tensor pred_x, pred_y, pred_w, pred_h;
    Tensor pred_conf, pred_class;
    pred_x.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    pred_y.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    pred_w.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    pred_h.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    pred_conf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    pred_class.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
    SplitPredResult<T>(*input, &pred_conf, &pred_class, &pred_x, &pred_y,
                       &pred_w, &pred_h, an_num, class_num);

    // Build training targets and masks from the ground-truth boxes.
    Tensor obj_mask, noobj_mask;
    Tensor tx, ty, tw, th, tweight, tconf, tclass;
    obj_mask.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    noobj_mask.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    tx.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    ty.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    tw.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    th.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    tweight.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
    PreProcessGTBox<T>(*gt_box, *gt_label, ignore_thresh, anchors, input_size,
                       h, &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tweight,
                       &tconf, &tclass);

    // Weight the box-regression loss by tweight on responsible anchors.
    Tensor obj_weight;
    obj_weight.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    auto obj_weight_t = EigenTensor<T, 4>::From(obj_weight);
    auto obj_mask_t = EigenTensor<T, 4>::From(obj_mask);
    auto tweight_t = EigenTensor<T, 4>::From(tweight);
    obj_weight_t = obj_mask_t * tweight_t;

    // Broadcast the objectness mask over the class dimension.
    Tensor obj_mask_expand;
    obj_mask_expand.mutable_data<T>({n, an_num, h, w, class_num},
                                    ctx.GetPlace());
    auto obj_mask_expand_t = EigenTensor<T, 5>::From(obj_mask_expand);
    obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1))
                            .broadcast(Array5(1, 1, 1, 1, class_num));

    // Accumulate all weighted loss terms into the per-sample loss.
    T* loss_data = loss->mutable_data<T>({n}, ctx.GetPlace());
    memset(loss_data, 0, n * sizeof(T));
    CalcSCEWithWeight<T>(pred_x, tx, obj_weight, loss_weight_xy, loss_data);
    CalcSCEWithWeight<T>(pred_y, ty, obj_weight, loss_weight_xy, loss_data);
    CalcMSEWithWeight<T>(pred_w, tw, obj_weight, loss_weight_wh, loss_data);
    CalcMSEWithWeight<T>(pred_h, th, obj_weight, loss_weight_wh, loss_data);
    CalcSCEWithWeight<T>(pred_conf, tconf, obj_mask, loss_weight_conf_target,
                         loss_data);
    CalcSCEWithWeight<T>(pred_conf, tconf, noobj_mask,
                         loss_weight_conf_notarget, loss_data);
    CalcSCEWithWeight<T>(pred_class, tclass, obj_mask_expand, loss_weight_class,
                         loss_data);
  }
};

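// Backward kernel: recomputes the prediction splits and targets, then
// propagates the per-sample upstream gradient through each loss term and
// assembles the gradient with respect to the input tensor X.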
template <typename T>
class Yolov3LossGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<Tensor>("X");
    auto* gt_box = ctx.Input<Tensor>("GTBox");
    auto* gt_label = ctx.Input<Tensor>("GTLabel");
    auto anchors = ctx.Attr<std::vector<int>>("anchors");
    int class_num = ctx.Attr<int>("class_num");
    float ignore_thresh = ctx.Attr<float>("ignore_thresh");
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
    const T* loss_grad_data = loss_grad->data<T>();
    int input_size = ctx.Attr<int>("input_size");
    float loss_weight_xy = ctx.Attr<float>("loss_weight_xy");
    float loss_weight_wh = ctx.Attr<float>("loss_weight_wh");
    float loss_weight_conf_target = ctx.Attr<float>("loss_weight_conf_target");
    float loss_weight_conf_notarget =
        ctx.Attr<float>("loss_weight_conf_notarget");
    float loss_weight_class = ctx.Attr<float>("loss_weight_class");

    const int n = input->dims()[0];
    const int c = input->dims()[1];
    const int h = input->dims()[2];
    const int w = input->dims()[3];
    const int an_num = anchors.size() / 2;

    // Recompute the prediction splits and targets from the forward pass;
    // they are needed to evaluate the loss gradients below.
    Tensor pred_x, pred_y, pred_w, pred_h;
    Tensor pred_conf, pred_class;
    pred_x.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    pred_y.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    pred_w.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    pred_h.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    pred_conf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    pred_class.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
    SplitPredResult<T>(*input, &pred_conf, &pred_class, &pred_x, &pred_y,
                       &pred_w, &pred_h, an_num, class_num);

    Tensor obj_mask, noobj_mask;
    Tensor tx, ty, tw, th, tweight, tconf, tclass;
    obj_mask.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    noobj_mask.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    tx.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    ty.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    tw.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    th.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    tweight.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
    PreProcessGTBox<T>(*gt_box, *gt_label, ignore_thresh, anchors, input_size,
                       h, &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tweight,
                       &tconf, &tclass);

    Tensor obj_weight;
    obj_weight.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    auto obj_weight_t = EigenTensor<T, 4>::From(obj_weight);
    auto obj_mask_t = EigenTensor<T, 4>::From(obj_mask);
    auto tweight_t = EigenTensor<T, 4>::From(tweight);
    obj_weight_t = obj_mask_t * tweight_t;

    Tensor obj_mask_expand;
    obj_mask_expand.mutable_data<T>({n, an_num, h, w, class_num},
                                    ctx.GetPlace());
    auto obj_mask_expand_t = EigenTensor<T, 5>::From(obj_mask_expand);
    obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1))
                            .broadcast(Array5(1, 1, 1, 1, class_num));

    // Gradient of each loss term with respect to its prediction slice.
    Tensor grad_x, grad_y, grad_w, grad_h;
    Tensor grad_conf_target, grad_conf_notarget, grad_class;
    grad_x.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    grad_y.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    grad_w.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    grad_h.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    grad_conf_target.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    grad_conf_notarget.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
    grad_class.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
    CalcSCEGradWithWeight<T>(loss_grad_data, &grad_x, pred_x, tx, obj_weight);
    CalcSCEGradWithWeight<T>(loss_grad_data, &grad_y, pred_y, ty, obj_weight);
    CalcMSEGradWithWeight<T>(loss_grad_data, &grad_w, pred_w, tw, obj_weight);
    CalcMSEGradWithWeight<T>(loss_grad_data, &grad_h, pred_h, th, obj_weight);
    CalcSCEGradWithWeight<T>(loss_grad_data, &grad_conf_target, pred_conf,
                             tconf, obj_mask);
    CalcSCEGradWithWeight<T>(loss_grad_data, &grad_conf_notarget, pred_conf,
                             tconf, noobj_mask);
    CalcSCEGradWithWeight<T>(loss_grad_data, &grad_class, pred_class, tclass,
                             obj_mask_expand);

    // Scatter the per-attribute gradients back into the input layout.
    input_grad->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
    AddAllGradToInputGrad<T>(input_grad, grad_x, grad_y, grad_w, grad_h,
                             grad_conf_target, grad_conf_notarget, grad_class,
                             class_num, loss_weight_xy, loss_weight_wh,
                             loss_weight_conf_target, loss_weight_conf_notarget,
                             loss_weight_class);
  }
};

}  // namespace operators
}  // namespace paddle