From 577a92d99203a67042f2b7fd6db25ecae09a1938 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 17 Dec 2018 11:45:16 +0800 Subject: [PATCH] use typename DeviceContext. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 12 +- paddle/fluid/operators/yolov3_loss_op.h | 301 ++++++------------ .../tests/unittests/test_yolov3_loss_op.py | 6 +- 3 files changed, 103 insertions(+), 216 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 3bd0db8b5..495a8f6c0 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -204,7 +204,11 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker, ops::Yolov3LossGradMaker); REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad); -REGISTER_OP_CPU_KERNEL(yolov3_loss, ops::Yolov3LossKernel, - ops::Yolov3LossKernel); -REGISTER_OP_CPU_KERNEL(yolov3_loss_grad, ops::Yolov3LossGradKernel, - ops::Yolov3LossGradKernel); +REGISTER_OP_CPU_KERNEL( + yolov3_loss, + ops::Yolov3LossKernel, + ops::Yolov3LossKernel); +REGISTER_OP_CPU_KERNEL( + yolov3_loss_grad, + ops::Yolov3LossGradKernel, + ops::Yolov3LossGradKernel); diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 5de5b4efc..f086e89a9 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -13,6 +13,7 @@ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" namespace paddle { namespace operators { @@ -32,183 +33,6 @@ static inline bool isZero(T x) { return fabs(x) < 1e-6; } -template -static inline void CalcL1LossWithWeight(const Tensor& x, const Tensor& y, - const Tensor& weight, - const T loss_weight, T* loss) { - int n = x.dims()[0]; - int stride = x.numel() / n; - const T* x_data = x.data(); - const T* y_data = y.data(); - const T* weight_data = weight.data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < stride; j++) { - loss[i] += fabs(y_data[j] - x_data[j]) * weight_data[j] * loss_weight; - } - x_data += stride; - y_data += stride; - weight_data += stride; - } -} - -template -static void CalcL1LossGradWithWeight(const T* loss_grad, Tensor* grad, - const Tensor& x, const Tensor& y, - const Tensor& weight) { - int n = x.dims()[0]; - int stride = x.numel() / n; - T* grad_data = grad->data(); - const T* x_data = x.data(); - const T* y_data = y.data(); - const T* weight_data = weight.data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < stride; j++) { - grad_data[j] = weight_data[j] * loss_grad[i]; - if (x_data[j] < y_data[j]) grad_data[j] *= -1.0; - } - grad_data += stride; - x_data += stride; - y_data += stride; - weight_data += stride; - } -} - -template -static inline void CalcMSEWithWeight(const Tensor& x, const Tensor& y, - const Tensor& weight, const T loss_weight, - T* loss) { - int n = x.dims()[0]; - int stride = x.numel() / n; - const T* x_data = x.data(); - const T* y_data = y.data(); - const T* weight_data = weight.data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < stride; j++) { - loss[i] += pow(y_data[j] - x_data[j], 2) * weight_data[j] * loss_weight; - } - x_data += stride; - y_data += stride; - weight_data += stride; - } -} - -template -static void CalcMSEGradWithWeight(const T* loss_grad, Tensor* grad, - const Tensor& x, const Tensor& y, - const Tensor& weight) { - int n = x.dims()[0]; - int stride = x.numel() / n; - T* grad_data = grad->data(); - const T* x_data = x.data(); - const T* y_data = y.data(); - const T* weight_data = weight.data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < stride; j++) { - grad_data[j] = - 2.0 * weight_data[j] * (x_data[j] - y_data[j]) * loss_grad[i]; - } - grad_data += stride; - x_data += stride; - y_data += stride; - weight_data += stride; - } -} - -template -static inline void CalcSCEWithWeight(const Tensor& x, const Tensor& label, - const Tensor& weight, const T loss_weight, - T* loss) { - int n = x.dims()[0]; - int stride = x.numel() / n; - const T* x_data = x.data(); - const T* label_data = label.data(); - const T* weight_data = weight.data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < stride; j++) { - T term1 = (x_data[j] > 0) ? x_data[j] : 0; - T term2 = x_data[j] * label_data[j]; - T term3 = std::log(1.0 + std::exp(-std::abs(x_data[j]))); - loss[i] += (term1 - term2 + term3) * weight_data[j] * loss_weight; - } - x_data += stride; - label_data += stride; - weight_data += stride; - } -} - -template -static inline void CalcSCEGradWithWeight(const T* loss_grad, Tensor* grad, - const Tensor& x, const Tensor& label, - const Tensor& weight) { - int n = x.dims()[0]; - int stride = x.numel() / n; - T* grad_data = grad->data(); - const T* x_data = x.data(); - const T* label_data = label.data(); - const T* weight_data = weight.data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < stride; j++) { - grad_data[j] = (1.0 / (1.0 + std::exp(-x_data[j])) - label_data[j]) * - weight_data[j] * loss_grad[i]; - } - grad_data += stride; - x_data += stride; - label_data += stride; - weight_data += stride; - } -} - -// template -// static void SplitPredResult(const Tensor& input, Tensor* pred_conf, -// Tensor* pred_class, Tensor* pred_x, Tensor* -// pred_y, -// Tensor* pred_w, Tensor* pred_h, -// const int anchor_num, const int class_num) { -// const int n = input.dims()[0]; -// const int h = input.dims()[2]; -// const int w = input.dims()[3]; -// const int box_attr_num = 5 + class_num; -// -// auto input_t = EigenTensor::From(input); -// auto pred_conf_t = EigenTensor::From(*pred_conf); -// auto pred_class_t = EigenTensor::From(*pred_class); -// auto pred_x_t = EigenTensor::From(*pred_x); -// auto pred_y_t = EigenTensor::From(*pred_y); -// auto pred_w_t = EigenTensor::From(*pred_w); -// auto pred_h_t = EigenTensor::From(*pred_h); -// -// for (int i = 0; i < n; i++) { -// for (int an_idx = 0; an_idx < anchor_num; an_idx++) { -// for (int j = 0; j < h; j++) { -// for (int k = 0; k < w; k++) { -// pred_x_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx, j, -// k); -// pred_y_t(i, an_idx, j, k) = -// input_t(i, box_attr_num * an_idx + 1, j, k); -// pred_w_t(i, an_idx, j, k) = -// input_t(i, box_attr_num * an_idx + 2, j, k); -// pred_h_t(i, an_idx, j, k) = -// input_t(i, box_attr_num * an_idx + 3, j, k); -// -// pred_conf_t(i, an_idx, j, k) = -// input_t(i, box_attr_num * an_idx + 4, j, k); -// -// for (int c = 0; c < class_num; c++) { -// pred_class_t(i, an_idx, j, k, c) = -// input_t(i, box_attr_num * an_idx + 5 + c, j, k); -// } -// } -// } -// } -// } -// } - template static T CalcBoxIoU(std::vector box1, std::vector box2) { T b1_x1 = box1[0] - box1[2] / 2; @@ -242,30 +66,36 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, Tensor* tconf, Tensor* tclass) { const int n = gt_box.dims()[0]; const int b = gt_box.dims()[1]; - const int anchor_num = anchors.size() / 2; - auto gt_box_t = EigenTensor::From(gt_box); - auto gt_label_t = EigenTensor::From(gt_label); - auto conf_mask_t = EigenTensor::From(*conf_mask).setConstant(1.0); - auto obj_mask_t = EigenTensor::From(*obj_mask).setConstant(0.0); - auto tx_t = EigenTensor::From(*tx).setConstant(0.0); - auto ty_t = EigenTensor::From(*ty).setConstant(0.0); - auto tw_t = EigenTensor::From(*tw).setConstant(0.0); - auto th_t = EigenTensor::From(*th).setConstant(0.0); - auto tweight_t = EigenTensor::From(*tweight).setConstant(0.0); - auto tconf_t = EigenTensor::From(*tconf).setConstant(0.0); - auto tclass_t = EigenTensor::From(*tclass).setConstant(0.0); + const int an_num = anchors.size() / 2; + const int h = tclass->dims()[2]; + const int w = tclass->dims()[3]; + const int class_num = tclass->dims()[4]; + + const T* gt_box_data = gt_box.data(); + const int* gt_label_data = gt_label.data(); + T* conf_mask_data = conf_mask->data(); + T* obj_mask_data = obj_mask->data(); + T* tx_data = tx->data(); + T* ty_data = ty->data(); + T* tw_data = tw->data(); + T* th_data = th->data(); + T* tweight_data = tweight->data(); + T* tconf_data = tconf->data(); + T* tclass_data = tclass->data(); for (int i = 0; i < n; i++) { for (int j = 0; j < b; j++) { - if (isZero(gt_box_t(i, j, 2)) && isZero(gt_box_t(i, j, 3))) { + int box_idx = (i * b + j) * 4; + if (isZero(gt_box_data[box_idx + 2]) && + isZero(gt_box_data[box_idx + 3])) { continue; } - int cur_label = gt_label_t(i, j); - T gx = gt_box_t(i, j, 0) * grid_size; - T gy = gt_box_t(i, j, 1) * grid_size; - T gw = gt_box_t(i, j, 2) * input_size; - T gh = gt_box_t(i, j, 3) * input_size; + int cur_label = gt_label_data[i * b + j]; + T gx = gt_box_data[box_idx] * grid_size; + T gy = gt_box_data[box_idx + 1] * grid_size; + T gw = gt_box_data[box_idx + 2] * input_size; + T gh = gt_box_data[box_idx + 3] * input_size; int gi = static_cast(gx); int gj = static_cast(gy); @@ -273,7 +103,7 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, T iou; int best_an_index = -1; std::vector gt_box_shape({0, 0, gw, gh}); - for (int an_idx = 0; an_idx < anchor_num; an_idx++) { + for (int an_idx = 0; an_idx < an_num; an_idx++) { std::vector anchor_shape({0, 0, static_cast(anchors[2 * an_idx]), static_cast(anchors[2 * an_idx + 1])}); iou = CalcBoxIoU(gt_box_shape, anchor_shape); @@ -282,19 +112,22 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, best_an_index = an_idx; } if (iou > ignore_thresh) { - conf_mask_t(i, an_idx, gj, gi) = static_cast(0.0); + int conf_idx = ((i * an_num + an_idx) * h + gj) * w + gi; + conf_mask_data[conf_idx] = static_cast(0.0); } } - conf_mask_t(i, best_an_index, gj, gi) = static_cast(1.0); - obj_mask_t(i, best_an_index, gj, gi) = static_cast(1.0); - tx_t(i, best_an_index, gj, gi) = gx - gi; - ty_t(i, best_an_index, gj, gi) = gy - gj; - tw_t(i, best_an_index, gj, gi) = log(gw / anchors[2 * best_an_index]); - th_t(i, best_an_index, gj, gi) = log(gh / anchors[2 * best_an_index + 1]); - tweight_t(i, best_an_index, gj, gi) = - 2.0 - gt_box_t(i, j, 2) * gt_box_t(i, j, 3); - tclass_t(i, best_an_index, gj, gi, cur_label) = 1; - tconf_t(i, best_an_index, gj, gi) = 1; + + int obj_idx = ((i * an_num + best_an_index) * h + gj) * w + gi; + conf_mask_data[obj_idx] = static_cast(1.0); + obj_mask_data[obj_idx] = static_cast(1.0); + tx_data[obj_idx] = gx - gi; + ty_data[obj_idx] = gy - gj; + tw_data[obj_idx] = log(gw / anchors[2 * best_an_index]); + th_data[obj_idx] = log(gh / anchors[2 * best_an_index + 1]); + tweight_data[obj_idx] = + 2.0 - gt_box_data[box_idx + 2] * gt_box_data[box_idx + 3]; + tconf_data[obj_idx] = static_cast(1.0); + tclass_data[obj_idx * class_num + cur_label] = static_cast(1.0); } } } @@ -427,18 +260,26 @@ static void CalcYolov3Loss(T* loss_data, const Tensor& input, const Tensor& tx, const int class_num = tclass.dims()[4]; const int grid_num = h * w; + // T l = 0.0; CalcSCE(loss_data, input_data, tx_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num, 1); CalcSCE(loss_data, input_data + grid_num, ty_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num, 1); + // LOG(ERROR) << "C++ xy: " << loss_data[0] - l; + // l = loss_data[0]; CalcL1Loss(loss_data, input_data + 2 * grid_num, tw_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num); CalcL1Loss(loss_data, input_data + 3 * grid_num, th_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num); + // LOG(ERROR) << "C++ wh: " << loss_data[0] - l; + // l = loss_data[0]; CalcSCE(loss_data, input_data + 4 * grid_num, tconf_data, conf_mask_data, conf_mask_data, n, an_num, grid_num, class_num, 1); + // LOG(ERROR) << "C++ conf: " << loss_data[0] - l; + // l = loss_data[0]; CalcSCE(loss_data, input_data + 5 * grid_num, tclass_data, obj_mask_data, obj_mask_data, n, an_num, grid_num, class_num, class_num); + // LOG(ERROR) << "C++ class: " << loss_data[0] - l; } template @@ -488,7 +329,7 @@ static void CalcYolov3LossGrad(T* input_grad_data, const Tensor& loss_grad, obj_mask_data, n, an_num, grid_num, class_num, class_num); } -template +template class Yolov3LossKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -517,6 +358,27 @@ class Yolov3LossKernel : public framework::OpKernel { tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + + math::SetConstant constant; + constant(ctx.template device_context(), &conf_mask, + static_cast(1.0)); + constant(ctx.template device_context(), &obj_mask, + static_cast(0.0)); + constant(ctx.template device_context(), &tx, + static_cast(0.0)); + constant(ctx.template device_context(), &ty, + static_cast(0.0)); + constant(ctx.template device_context(), &tw, + static_cast(0.0)); + constant(ctx.template device_context(), &th, + static_cast(0.0)); + constant(ctx.template device_context(), &tweight, + static_cast(0.0)); + constant(ctx.template device_context(), &tconf, + static_cast(0.0)); + constant(ctx.template device_context(), &tclass, + static_cast(0.0)); + PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, &tweight, &tconf, &tclass); @@ -528,7 +390,7 @@ class Yolov3LossKernel : public framework::OpKernel { } }; -template +template class Yolov3LossGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -559,6 +421,27 @@ class Yolov3LossGradKernel : public framework::OpKernel { tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + + math::SetConstant constant; + constant(ctx.template device_context(), &conf_mask, + static_cast(1.0)); + constant(ctx.template device_context(), &obj_mask, + static_cast(0.0)); + constant(ctx.template device_context(), &tx, + static_cast(0.0)); + constant(ctx.template device_context(), &ty, + static_cast(0.0)); + constant(ctx.template device_context(), &tw, + static_cast(0.0)); + constant(ctx.template device_context(), &th, + static_cast(0.0)); + constant(ctx.template device_context(), &tweight, + static_cast(0.0)); + constant(ctx.template device_context(), &tconf, + static_cast(0.0)); + constant(ctx.template device_context(), &tclass, + static_cast(0.0)); + PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, &tweight, &tconf, &tclass); diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index cf7e2c528..862e77e66 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -197,12 +197,12 @@ class TestYolov3LossOp(OpTest): max_relative_error=0.31) def initTestCase(self): - self.anchors = [12, 12, 11, 13] + self.anchors = [12, 12] self.class_num = 5 self.ignore_thresh = 0.5 self.input_size = 416 - self.x_shape = (3, len(self.anchors) // 2 * (5 + self.class_num), 5, 5) - self.gtbox_shape = (3, 5, 4) + self.x_shape = (1, len(self.anchors) // 2 * (5 + self.class_num), 3, 3) + self.gtbox_shape = (1, 5, 4) if __name__ == "__main__": -- GitLab