// lookup_table_v2_op.h
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cstring>
#include <string>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/math/blas.h"

#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#endif

namespace paddle {
namespace operators {

// Short aliases for the framework types used by the lookup_table_v2 kernels.
using DDim = framework::DDim;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using Tensor = framework::Tensor;

// Sentinel value of the "padding_idx" attribute meaning "no padding id".
constexpr int64_t kNoPadding{-1};

template <typename T>
class LookupTableV2Kernel : public framework::OpKernel<T> {
 public:
  // Forward pass of the embedding lookup (lookup_table_v2).
  //
  // Inputs:
  //   "Ids" - LoDTensor of int64 ids.
  //   "W"   - the embedding table, either a LoDTensor or a SelectedRows.
  // Output:
  //   "Out" - for each i, the row starting at Out + i * row_width receives
  //           the table row selected by Ids[i], or zeros when Ids[i] equals
  //           the "padding_idx" attribute (unless padding_idx == kNoPadding).
  //
  // When the "remote_prefetch" attribute is true and "epmap" is non-empty,
  // the rows are fetched from remote parameter servers instead. That path
  // requires a PADDLE_WITH_DISTRIBUTE build and throws otherwise.
  //
  // Throws when an id is outside the valid range for the table, or when "W"
  // is neither a LoDTensor nor a SelectedRows.
  void Compute(const framework::ExecutionContext &context) const override {
    auto *ids_t = context.Input<LoDTensor>("Ids");      // int tensor
    auto *output_t = context.Output<LoDTensor>("Out");  // float tensor
    auto *table_var = context.InputVar("W");

    auto id_name = context.Inputs("Ids").front();
    auto embedding_name = context.Inputs("W").front();
    auto out_name = context.Outputs("Out").front();

    // Attributes consumed only by the remote-prefetch path.
    auto epmap = context.Attr<std::vector<std::string>>("epmap");
    auto remote_prefetch = context.Attr<bool>("remote_prefetch");
    auto height_sections =
        context.Attr<std::vector<int64_t>>("height_sections");
    auto table_names = context.Attr<std::vector<std::string>>("table_names");

    if (remote_prefetch && !epmap.empty()) {
// A non-empty epmap means the parameter is fetched from remote parameter
// servers rather than read from the local scope.
#ifdef PADDLE_WITH_DISTRIBUTE
      operators::distributed::prefetch(id_name, out_name, embedding_name, false,
                                       table_names, epmap, height_sections,
                                       context, context.scope());
#else
      PADDLE_THROW(
          "paddle is not compiled with distribute support, can not do "
          "parameter prefetch!");
#endif
      return;
    }

    const int64_t padding_idx = context.Attr<int64_t>("padding_idx");
    // Ids are only read here, so no const_cast is required.
    const int64_t *ids = ids_t->data<int64_t>();
    const int64_t ids_numel = ids_t->numel();

    if (table_var->IsType<LoDTensor>()) {
      auto *table_t = context.Input<LoDTensor>("W");
      const int64_t row_number = table_t->dims()[0];
      const int64_t row_width = table_t->dims()[1];

      const T *table = table_t->data<T>();
      T *output = output_t->mutable_data<T>(context.GetPlace());

      for (int64_t i = 0; i < ids_numel; ++i) {
        T *out_row = output + i * row_width;
        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
          std::memset(out_row, 0, row_width * sizeof(T));
          continue;
        }
        PADDLE_ENFORCE_LT(
            ids[i], row_number,
            "Variable value (input) of OP(fluid.layers.embedding) "
            "expected >= 0 and < %ld, but got %ld. Please check input "
            "value.",
            row_number, ids[i]);
        PADDLE_ENFORCE_GE(
            ids[i], 0,
            "Variable value (input) of OP(fluid.layers.embedding) "
            "expected >= 0 and < %ld, but got %ld. Please check input "
            "value.",
            row_number, ids[i]);
        std::memcpy(out_row, table + ids[i] * row_width,
                    row_width * sizeof(T));
      }
    } else if (table_var->IsType<SelectedRows>()) {
      const auto &table_t = table_var->Get<SelectedRows>();
      const int64_t row_width = table_t.value().dims()[1];
      const T *table = table_t.value().data<T>();
      T *output = output_t->mutable_data<T>(context.GetPlace());

      auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
      for (int64_t i = 0; i < ids_numel; ++i) {
        T *out_row = output + i * row_width;
        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
          std::memset(out_row, 0, row_width * sizeof(T));
          continue;
        }
        PADDLE_ENFORCE_GE(
            ids[i], 0,
            "Variable value (input) of OP(fluid.layers.embedding) "
            "expected >= 0. But received %ld",
            ids[i]);
        // Index() maps the id to its row position inside table_t.value().
        auto id_index = table_t.Index(ids[i]);
        PADDLE_ENFORCE_GE(
            id_index, 0, "the input key should be exists. But received %d.",
            id_index);
        blas.VCOPY(row_width, table + id_index * row_width, out_row);
      }
    } else {
      // Previously an unsupported table type silently left "Out" unwritten.
      PADDLE_THROW(
          "The parameter W of a LookupTableV2 "
          "must be either LoDTensor or SelectedRows");
    }
  }
};

template <typename T>
class LookupTableV2GradKernel : public framework::OpKernel<T> {
 public:
  // Backward pass of the embedding lookup (lookup_table_v2).
  //
  // Reads "Ids" and the output gradient "Out@GRAD", and writes "W@GRAD".
  //
  // - is_sparse == true: "W@GRAD" is a SelectedRows. Its rows are a verbatim
  //   copy of the ids (duplicates are kept as separate rows), its height is
  //   the table's first dimension, and its value is a copy of Out@GRAD with
  //   shape [ids_num, table_width]. Rows for padding ids are NOT zeroed on
  //   this path.
  // - is_sparse == false: "W@GRAD" is a dense LoDTensor. It is zero-filled,
  //   then each Out@GRAD row is accumulated into the table row selected by
  //   its id. Padding ids are skipped, so their gradient stays zero.
  //
  // Throws when "W" is neither a LoDTensor nor a SelectedRows, when the sparse
  // gradient shape does not match the flattened Out@GRAD, or when a dense-path
  // id is outside [0, table_rows).
  void Compute(const framework::ExecutionContext &context) const override {
    auto *table_var = context.InputVar("W");
    DDim table_dim;
    if (table_var->IsType<LoDTensor>()) {
      table_dim = context.Input<LoDTensor>("W")->dims();
    } else if (table_var->IsType<SelectedRows>()) {
      auto *table_t = context.Input<SelectedRows>("W");
      table_dim = table_t->value().dims();
    } else {
      PADDLE_THROW(
          "The parameter W of a LookupTableV2 "
          "must be either LoDTensor or SelectedRows");
    }

    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
    bool is_sparse = context.Attr<bool>("is_sparse");
    // Since paddings are not trainable and fixed in forward, the gradient of
    // paddings makes no sense and we don't deal with it in backward.
    if (is_sparse) {
      auto *ids = context.Input<LoDTensor>("Ids");
      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));

      const int64_t *ids_data = ids->data<int64_t>();
      const int64_t ids_num = ids->numel();

      // Build the row list straight from the id range. The previous
      // memcpy into &new_rows[0] was undefined behavior when ids_num == 0.
      std::vector<int64_t> new_rows(ids_data, ids_data + ids_num);
      d_table->set_rows(new_rows);

      auto *d_table_value = d_table->mutable_value();
      d_table_value->Resize({ids_num, table_dim[1]});
      T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());

      d_table->set_height(table_dim[0]);

      const T *d_output_data = d_output->data<T>();

      auto d_output_dims = d_output->dims();
      auto d_output_dims_2d =
          framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
      PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output_dims_2d,
                        "ShapeError: The shape of lookup_table@Grad and "
                        "output@Grad should be same. "
                        "But received lookup_table@Grad's shape = [%s], "
                        "output@Grad's shape = [%s].",
                        d_table_value->dims(), d_output_dims_2d);
      std::memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());

    } else {
      auto *ids = context.Input<LoDTensor>("Ids");
      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));

      const int64_t *ids_data = ids->data<int64_t>();

      const int64_t N = table_dim[0];
      const int64_t D = table_dim[1];

      const T *d_output_data = d_output->data<T>();
      T *d_table_data = d_table->mutable_data<T>(context.GetPlace());

      // Start from zero so gradients of repeated ids accumulate correctly and
      // padding rows end up as zero.
      std::memset(d_table_data, 0, d_table->numel() * sizeof(T));

      for (int64_t i = 0; i < ids->numel(); ++i) {
        if (padding_idx != kNoPadding && ids_data[i] == padding_idx) {
          // the gradient of padding_idx should be 0, already done by memset, so
          // do nothing.
          continue;
        }
        PADDLE_ENFORCE_LT(
            ids_data[i], N,
            "Variable value (input) of OP(fluid.layers.embedding) "
            "expected >= 0 and < %ld, but got %ld. Please check input value.",
            N, ids_data[i]);
        PADDLE_ENFORCE_GE(
            ids_data[i], 0,
            "Variable value (input) of OP(fluid.layers.embedding) "
            "expected >= 0 and < %ld, but got %ld. Please check input value.",
            N, ids_data[i]);
        // int64_t index: the width D is int64_t, so an int counter could
        // overflow.
        T *d_table_row = d_table_data + ids_data[i] * D;
        const T *d_output_row = d_output_data + i * D;
        for (int64_t j = 0; j < D; ++j) {
          d_table_row[j] += d_output_row[j];
        }
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle