/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

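// Elementwise smooth L1 (Huber-like) transform with sigma2 = sigma * sigma:
//   f(x) = 0.5 * sigma2 * x^2     if |x| < 1 / sigma2,
//   f(x) = |x| - 0.5 / sigma2     otherwise.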
template <typename T>
struct SmoothL1LossForward {
  HOSTDEVICE SmoothL1LossForward(const T& sigma2) : sigma2(sigma2) {}

  HOSTDEVICE T operator()(const T& val) const {
    T abs_val = std::abs(val);
    if (abs_val < 1.0 / sigma2) {
      return 0.5 * val * val * sigma2;
    } else {
      return abs_val - 0.5 / sigma2;
    }
  }

  T sigma2;
};

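// Forward kernel: computes Diff = InsideWeight * (X - Y) (InsideWeight is
// optional), then Out[i] = sum over columns j of
// OutsideWeight[i][j] * f(Diff[i][j]), treating the first dimension of 'X'
// as the sample dimension.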
template <typename DeviceContext, typename T, typename AttrType = T>
class SmoothL1LossKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in0 = context.Input<Tensor>("X");
    auto* in1 = context.Input<Tensor>("Y");
    auto* in2 = context.Input<Tensor>("InsideWeight");
    auto* in3 = context.Input<Tensor>("OutsideWeight");
    auto* out0 = context.Output<Tensor>("Diff");
    auto* out1 = context.Output<Tensor>("Out");

    out0->mutable_data<T>(context.GetPlace());
    out1->mutable_data<T>(context.GetPlace());
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();

    auto sigma = static_cast<T>(context.Attr<AttrType>("sigma"));
    T sigma2 = sigma * sigma;
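    // The inside/outside weights are optional inputs; apply them only when
    // both are provided.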
    bool has_weight = (in2 != nullptr) && (in3 != nullptr);

    auto x = EigenVector<T>::Flatten(*in0);
    auto y = EigenVector<T>::Flatten(*in1);
    auto diff = EigenVector<T>::Flatten(*out0);

    diff.device(*place) = x - y;
    // multiply inside weight
    if (has_weight) {
      auto inside_weight = EigenVector<T>::Flatten(*in2);
      // cache diff, reused in the backward pass
      diff.device(*place) = diff * inside_weight;
    }

    auto in_counts = in0->numel();
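    // Temporary buffer for the elementwise loss values before the
    // per-sample reduction.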
    Tensor ptensor_errors;
    ptensor_errors.mutable_data<T>({static_cast<int>(in_counts)},
                                   context.GetPlace());
    auto errors = EigenVector<T>::Flatten(ptensor_errors);
    // apply smooth l1 forward
    errors.device(*place) = diff.unaryExpr(SmoothL1LossForward<T>(sigma2));

    // multiply outside weight
    if (has_weight) {
      auto outside_weight = EigenVector<T>::Flatten(*in3);
      errors.device(*place) = errors * outside_weight;
    }
    auto loss = EigenVector<T>::Flatten(*out1);
    // first dimension of 'X' is the number of samples
    auto mat_dims =
        framework::make_ddim({static_cast<int>(in0->dims()[0]),
                              static_cast<int>(in_counts / in0->dims()[0])});
    auto errors_mat_view = EigenMatrix<T>::From(ptensor_errors, mat_dims);
    loss.device(*place) = errors_mat_view.sum(Eigen::array<int, 1>({{1}}));
  }
};

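// Derivative of the smooth L1 transform above:
//   f'(x) = sigma2 * x            if |x| < 1 / sigma2,
//   f'(x) = sign(x)               otherwise.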
template <typename T>
struct SmoothL1LossBackward {
  HOSTDEVICE SmoothL1LossBackward(const T& sigma2) : sigma2(sigma2) {}

  HOSTDEVICE T operator()(const T& val) const {
    T abs_val = std::abs(val);
    if (abs_val < 1.0 / sigma2) {
      return sigma2 * val;
    } else {
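      // branchless sign(val): (0 < val) - (val < 0) yields +1, 0, or -1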
      return (0 < val) - (val < 0);
    }
  }

  T sigma2;
};

template <typename DeviceContext, typename T, typename AttrType = T>
class SmoothL1LossGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in0 = context.Input<Tensor>("InsideWeight");
    auto* in1 = context.Input<Tensor>("OutsideWeight");
    auto* in2 = context.Input<Tensor>("Diff");
    auto* og = context.Input<Tensor>(framework::GradVarName("Out"));
    auto sigma = static_cast<T>(context.Attr<AttrType>("sigma"));
    T sigma2 = sigma * sigma;
    bool has_weight = (in0 != nullptr) && (in1 != nullptr);

    auto* place =
        context.template device_context<DeviceContext>().eigen_device();

    auto in_dims = in2->dims();
    auto counts = in2->numel();
    auto cols = counts / in_dims[0];
    auto mat_dims = framework::make_ddim(
        {static_cast<int>(in_dims[0]), static_cast<int>(cols)});

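    // Temporary buffer holding f'(Diff) for every element.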
    Tensor ptensor_diff;
    ptensor_diff.mutable_data<T>({static_cast<int>(counts)},
                                 context.GetPlace());
    auto diff = EigenVector<T>::Flatten(ptensor_diff);
    // apply smooth l1 backward
    diff.device(*place) = EigenVector<T>::Flatten(*in2).unaryExpr(
        SmoothL1LossBackward<T>(sigma2));

    // compute weights
    Tensor ptensor_weights;
    ptensor_weights.mutable_data<T>(mat_dims, context.GetPlace());
    auto weights = EigenMatrix<T>::From(ptensor_weights);
    // initialize to 1.0
    weights.device(*place) = weights.constant(static_cast<T>(1.0));
    if (has_weight) {
      auto inside_weight = EigenMatrix<T>::From(*in0, mat_dims);
      auto outside_weight = EigenMatrix<T>::From(*in1, mat_dims);
      weights.device(*place) = inside_weight * outside_weight;
    }

    // compute gradients
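    // dX = dOut (broadcast across columns) * weights * f'(Diff), where
    // weights = InsideWeight * OutsideWeight (all ones when absent).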
    auto out_grad = EigenMatrix<T>::From(*og);
    auto diff_mat_view = EigenMatrix<T>::From(ptensor_diff, mat_dims);
    auto gradients = out_grad.broadcast(
                         Eigen::array<int, 2>({{1, static_cast<int>(cols)}})) *
                     weights * diff_mat_view;

    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
    auto* out1 = context.Output<Tensor>(framework::GradVarName("Y"));

    if (out0) {
      out0->mutable_data<T>(context.GetPlace());
      auto x_grad = EigenMatrix<T>::From(*out0, mat_dims);
      x_grad.device(*place) = gradients;
    }

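    // Since Diff is based on X - Y, the gradient w.r.t. Y is -dX.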
    if (out1) {
      out1->mutable_data<T>(context.GetPlace());
      auto y_grad = EigenMatrix<T>::From(*out1, mat_dims);
      y_grad.device(*place) = -1 * gradients;
    }
  }
};

}  // namespace operators
}  // namespace paddle