nll_loss_op.h 11.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

// Forward NLL loss for a 2-D input x of shape [batch_size, n_classes]
// (row-major: element (i, c) is x_data[i * n_classes + c]).
//
// For every sample i whose label c differs from ignore_index, the sample loss
// is -x[i][c] * w, where w = weight_data[c], or 1 when weight_data is null.
//
//  * reduction == "none": out_data[i] receives the loss of sample i
//    (0 for ignored samples); total_weight_data is not written.
//  * any other reduction: the losses are summed into out_data[0] and the
//    summed weight of all non-ignored samples is stored in
//    total_weight_data[0]. For "mean" the sum is divided by that total
//    weight when it is non-zero.
//
// A non-ignored label outside [0, n_classes) raises InvalidArgument.
template <typename T>
static void nll_loss_1D(T* out_data, T* total_weight_data, const T* x_data,
                        const int64_t* label_data, const T* weight_data,
                        const int64_t batch_size, const int64_t n_classes,
                        const std::string& reduction,
                        const int64_t ignore_index) {
  const bool no_reduction = (reduction == "none");

  T output_val = 0;
  T total_weight_val = 0;

  for (int64_t i = 0; i < batch_size; ++i) {
    const auto cur_label = label_data[i];
    if (cur_label == ignore_index) {
      // Only the unreduced output has one slot per sample. In the reduced
      // modes only out_data[0] is produced (the output is presumably a
      // single-element tensor), so writing out_data[i] here would run past
      // the end of the buffer.
      if (no_reduction) {
        out_data[i] = 0;
      }
      continue;
    }
    PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes, true,
                      platform::errors::InvalidArgument(
                          "label should not be out of bounds."));

    const auto cur_weight =
        weight_data ? weight_data[cur_label] : static_cast<T>(1);
    const T sample_loss = -x_data[i * n_classes + cur_label] * cur_weight;
    if (no_reduction) {
      out_data[i] = sample_loss;
    } else {
      total_weight_val += cur_weight;
      output_val += sample_loss;
    }
  }

  if (no_reduction) {
    return;
  }
  // A batch whose non-ignored weights sum to zero keeps the plain sum instead
  // of dividing by zero.
  if (reduction == "mean" && total_weight_val != 0) {
    output_val /= total_weight_val;
  }
  *out_data = output_val;
  *total_weight_data = total_weight_val;
}

// Forward NLL loss for a 4-D input x laid out as
// [batch_size, n_classes, in_dim2, in_dim3], i.e. element (i, c, h, w) is
// x_data[i * n_classes * in_dim2 * in_dim3 + c * in_dim2 * in_dim3 +
//        h * in_dim3 + w].
// label_data (and the unreduced output) are indexed as
// [batch_size, in_dim2, in_dim3].
//
// For every position whose label c differs from ignore_index, the loss is
// -x[i][c][h][w] * wt, where wt = weight_data[c], or 1 when weight_data is
// null.
//
//  * reduction == "none": out_data[index] receives the per-position loss
//    (0 for ignored positions); total_weight_data is not written.
//  * any other reduction: the losses are summed into out_data[0] and the
//    summed weight of all non-ignored positions is stored in
//    total_weight_data[0]. For "mean" the sum is divided by that total
//    weight when it is non-zero.
//
// A non-ignored label outside [0, n_classes) raises InvalidArgument.
template <typename T>
static void nll_loss_2D(T* out_data, T* total_weight_data, const T* x_data,
                        const int64_t* label_data, const T* weight_data,
                        const int64_t batch_size, const int64_t n_classes,
                        const int64_t in_dim2, const int64_t in_dim3,
                        const std::string& reduction,
                        const int64_t ignore_index) {
  const auto map_size = in_dim2 * in_dim3;
  const auto sample_size = n_classes * map_size;
  const bool no_reduction = (reduction == "none");

  T output_val = 0;
  T total_weight_val = 0;

  // 64-bit indices: the extents are int64_t and a per-sample map can exceed
  // the range of int.
  for (int64_t i = 0; i < batch_size; ++i) {
    for (int64_t h = 0; h < in_dim2; ++h) {
      for (int64_t w = 0; w < in_dim3; ++w) {
        const auto index = i * map_size + h * in_dim3 + w;
        const auto cur_label = label_data[index];
        if (cur_label == ignore_index) {
          // Only the unreduced output has one slot per position. In the
          // reduced modes only out_data[0] is produced, so writing
          // out_data[index] would run past the end of the buffer.
          if (no_reduction) {
            out_data[index] = 0;
          }
          continue;
        }
        PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes, true,
                          platform::errors::InvalidArgument(
                              "label should not be out of bounds."));
        const auto cur_weight =
            weight_data ? weight_data[cur_label] : static_cast<T>(1);
        const T sample_loss =
            -x_data[i * sample_size + cur_label * map_size + h * in_dim3 + w] *
            cur_weight;
        if (no_reduction) {
          out_data[index] = sample_loss;
        } else {
          total_weight_val += cur_weight;
          output_val += sample_loss;
        }
      }
    }
  }

  if (no_reduction) {
    return;
  }
  // A batch whose non-ignored weights sum to zero keeps the plain sum instead
  // of dividing by zero.
  if (reduction == "mean" && total_weight_val != 0) {
    output_val /= total_weight_val;
  }
  *out_data = output_val;
  *total_weight_data = total_weight_val;
}

template <typename DeviceContext, typename T>
class NLLLossOpKernel : public framework::OpKernel<T> {
 public:
  // Forward kernel for nll_loss.
  //
  // Reads "X", "Label" (int64) and the optional per-class "Weight", plus the
  // "reduction" and "ignore_index" attributes. It zeroes "Total_weight", then
  // dispatches on the rank of X: rank 2 goes to nll_loss_1D, rank 4 to
  // nll_loss_2D. Any other rank is left unprocessed (presumably rejected
  // earlier by shape inference — TODO confirm).
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* input = ctx.Input<Tensor>("X");
    const auto* label_tensor = ctx.Input<Tensor>("Label");
    const auto* weight_tensor = ctx.Input<Tensor>("Weight");
    auto* loss_tensor = ctx.Output<Tensor>("Out");
    auto* total_weight_tensor = ctx.Output<Tensor>("Total_weight");
    const std::string reduction_mode = ctx.Attr<std::string>("reduction");
    const int64_t ignored_label = ctx.Attr<int64_t>("ignore_index");

    const T* input_ptr = input->data<T>();
    const int64_t* label_ptr = label_tensor->data<int64_t>();
    // Weight is optional; a null pointer tells the helpers to use weight 1.
    const T* weight_ptr = nullptr;
    if (weight_tensor != nullptr) {
      weight_ptr = weight_tensor->data<T>();
    }
    T* loss_ptr = loss_tensor->mutable_data<T>(ctx.GetPlace());
    T* total_weight_ptr =
        total_weight_tensor->mutable_data<T>(ctx.GetPlace());
    total_weight_ptr[0] = 0;

    const auto& input_dims = input->dims();
    const int64_t batch_size = input_dims[0];
    const int64_t n_classes = input_dims[1];

    switch (input_dims.size()) {
      case 2:
        nll_loss_1D<T>(loss_ptr, total_weight_ptr, input_ptr, label_ptr,
                       weight_ptr, batch_size, n_classes, reduction_mode,
                       ignored_label);
        break;
      case 4:
        nll_loss_2D<T>(loss_ptr, total_weight_ptr, input_ptr, label_ptr,
                       weight_ptr, batch_size, n_classes, input_dims[2],
                       input_dims[3], reduction_mode, ignored_label);
        break;
      default:
        break;
    }
  }
};

// Backward of the 2-D ([batch_size, n_classes]) NLL loss.
//
// Only the entry dx[i][label_i] of each non-ignored sample is written, so
// dx_data is expected to be zero-filled by the caller. With
// w = weight_data[label_i] (or 1 when weight_data is null):
//  * reduction == "none": dx[i][label_i] = -dout[i] * w
//  * "mean":              dx[i][label_i] = -dout[0] * w / total_weight
//  * any other value:     dx[i][label_i] = -dout[0] * w
// total_weight_data is the total weight produced by the forward pass. When
// it is 0 the forward pass did not divide, so no division is applied here
// either; this keeps the gradient consistent and avoids 0/0 = NaN.
//
// Labels are not range-checked here; presumably they were already validated
// by the forward pass — TODO confirm for all call paths.
template <typename T>
static void nll_loss_grad_1D(T* dx_data, const T* dout_data,
                             const int64_t* label_data, const T* weight_data,
                             const T* total_weight_data,
                             const int64_t batch_size, const int64_t n_classes,
                             const std::string& reduction,
                             const int64_t ignore_index) {
  if (reduction == "none") {
    for (int64_t i = 0; i < batch_size; ++i) {
      const auto cur_label = label_data[i];
      if (cur_label == ignore_index) {
        continue;
      }
      const auto cur_weight =
          weight_data ? weight_data[cur_label] : static_cast<T>(1);
      dx_data[i * n_classes + cur_label] = -dout_data[i] * cur_weight;
    }
    return;
  }

  const T dout_val = *dout_data;
  const T total_weight_val = *total_weight_data;
  // Mirror the forward pass, which divides by the total weight only when it
  // is non-zero. Evaluated once instead of per sample.
  const bool divide_by_total = (reduction == "mean") && total_weight_val != 0;
  for (int64_t i = 0; i < batch_size; ++i) {
    const auto cur_label = label_data[i];
    if (cur_label == ignore_index) {
      continue;
    }
    const auto cur_weight =
        weight_data ? weight_data[cur_label] : static_cast<T>(1);
    T grad = -dout_val * cur_weight;
    if (divide_by_total) {
      grad /= total_weight_val;
    }
    dx_data[i * n_classes + cur_label] = grad;
  }
}

// Backward of the 4-D ([batch_size, n_classes, in_dim2, in_dim3]) NLL loss.
//
// For every non-ignored position (i, h, w) with label c, only
// dx[i][c][h][w] is written, so dx_data is expected to be zero-filled by the
// caller. With wt = weight_data[c] (or 1 when weight_data is null):
//  * reduction == "none": dx = -wt * dout[i][h][w]
//  * "mean":              dx = -dout[0] * wt / total_weight
//  * any other value:     dx = -dout[0] * wt
// total_weight_data is the total weight produced by the forward pass. When
// it is 0 the forward pass did not divide, so no division is applied here
// either; this keeps the gradient consistent and avoids 0/0 = NaN.
//
// Labels are not range-checked here; presumably they were already validated
// by the forward pass — TODO confirm for all call paths.
template <typename T>
static void nll_loss_grad_2D(T* dx_data, const T* dout_data,
                             const int64_t* label_data, const T* weight_data,
                             const T* total_weight_data,
                             const int64_t batch_size, const int64_t n_classes,
                             const int64_t in_dim2, const int64_t in_dim3,
                             const std::string& reduction,
                             const int64_t ignore_index) {
  const auto map_size = in_dim2 * in_dim3;
  const auto sample_size = n_classes * map_size;

  if (reduction == "none") {
    for (int64_t i = 0; i < batch_size; ++i) {
      for (int64_t h = 0; h < in_dim2; ++h) {
        for (int64_t w = 0; w < in_dim3; ++w) {
          const auto index = i * map_size + h * in_dim3 + w;
          const auto cur_label = label_data[index];
          if (cur_label == ignore_index) {
            continue;
          }
          const auto cur_weight =
              weight_data ? weight_data[cur_label] : static_cast<T>(1);
          dx_data[i * sample_size + cur_label * map_size + h * in_dim3 + w] =
              -cur_weight * dout_data[index];
        }
      }
    }
    return;
  }

  const T dout_val = *dout_data;
  const T total_weight_val = *total_weight_data;
  // Mirror the forward pass, which divides by the total weight only when it
  // is non-zero. Evaluated once instead of per position.
  const bool divide_by_total = (reduction == "mean") && total_weight_val != 0;
  for (int64_t i = 0; i < batch_size; ++i) {
    for (int64_t h = 0; h < in_dim2; ++h) {
      for (int64_t w = 0; w < in_dim3; ++w) {
        const auto index = i * map_size + h * in_dim3 + w;
        const auto cur_label = label_data[index];
        if (cur_label == ignore_index) {
          continue;
        }
        const auto cur_weight =
            weight_data ? weight_data[cur_label] : static_cast<T>(1);
        T grad = -dout_val * cur_weight;
        if (divide_by_total) {
          grad /= total_weight_val;
        }
        dx_data[i * sample_size + cur_label * map_size + h * in_dim3 + w] =
            grad;
      }
    }
  }
}

template <typename DeviceContext, typename T>
class NLLLossGradOpKernel : public framework::OpKernel<T> {
 public:
  // Backward kernel for nll_loss.
  //
  // Reads "X" (only for its shape), "Label", the optional "Weight", the
  // upstream gradient of "Out" and the forward "Total_weight". It zero-fills
  // the gradient of "X", then dispatches on the rank of X: rank 2 goes to
  // nll_loss_grad_1D, rank 4 to nll_loss_grad_2D. Any other rank leaves the
  // gradient all zeros (presumably rejected earlier by shape inference —
  // TODO confirm).
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* input = ctx.Input<Tensor>("X");
    const auto* label_tensor = ctx.Input<Tensor>("Label");
    const auto* weight_tensor = ctx.Input<Tensor>("Weight");
    const auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    const auto* total_weight_tensor = ctx.Input<Tensor>("Total_weight");
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    const int64_t ignored_label = ctx.Attr<int64_t>("ignore_index");
    const std::string reduction_mode = ctx.Attr<std::string>("reduction");

    T* input_grad_ptr = input_grad->mutable_data<T>(ctx.GetPlace());
    const T* loss_grad_ptr = loss_grad->data<T>();
    const int64_t* label_ptr = label_tensor->data<int64_t>();
    // Weight is optional; a null pointer tells the helpers to use weight 1.
    const T* weight_ptr = nullptr;
    if (weight_tensor != nullptr) {
      weight_ptr = weight_tensor->data<T>();
    }
    const T* total_weight_ptr = total_weight_tensor->data<T>();
    // The helpers only write the entries selected by the labels, so every
    // other gradient entry must start at zero.
    memset(input_grad_ptr, 0, input_grad->numel() * sizeof(T));

    const auto& input_dims = input->dims();
    const int64_t batch_size = input_dims[0];
    const int64_t n_classes = input_dims[1];

    switch (input_dims.size()) {
      case 2:
        nll_loss_grad_1D<T>(input_grad_ptr, loss_grad_ptr, label_ptr,
                            weight_ptr, total_weight_ptr, batch_size,
                            n_classes, reduction_mode, ignored_label);
        break;
      case 4:
        nll_loss_grad_2D<T>(input_grad_ptr, loss_grad_ptr, label_ptr,
                            weight_ptr, total_weight_ptr, batch_size,
                            n_classes, input_dims[2], input_dims[3],
                            reduction_mode, ignored_label);
        break;
      default:
        break;
    }
  }
};

}  // namespace operators
}  // namespace paddle