/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/math/blas.h"

#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#endif

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;

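// padding_idx == kNoPadding (-1) means no id is treated as padding, so no
// output row is forced to zero.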
constexpr int64_t kNoPadding = -1;

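// Forward kernel of the lookup_table (embedding) op: for each index in Ids,
// copy the corresponding row of the table W into Out. For example, with
// ids = [3, 1] and W of shape [N, D], Out[0, :] = W[3, :] and
// Out[1, :] = W[1, :]. If "remote_prefetch" is set and "epmap" is non-empty,
// the rows are instead prefetched from a remote parameter server.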
template <typename T>
class LookupTableKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *ids_t = context.Input<LoDTensor>("Ids");      // int tensor
    auto *output_t = context.Output<LoDTensor>("Out");  // float tensor
    auto *table_var = context.InputVar("W");

    auto id_name = context.InputNames("Ids").front();
    auto embedding_name = context.InputNames("W").front();
    auto out_name = context.OutputNames("Out").front();

    // for remote prefetch
    auto epmap = context.Attr<std::vector<std::string>>("epmap");
    auto remote_prefetch = context.Attr<bool>("remote_prefetch");
    auto height_sections =
        context.Attr<std::vector<int64_t>>("height_sections");
    auto table_names = context.Attr<std::vector<std::string>>("table_names");

    if (remote_prefetch && !epmap.empty()) {
// If epmap is not empty, the parameter will be fetched from a remote
// parameter server

#ifdef PADDLE_WITH_DISTRIBUTE
      operators::distributed::prefetch(id_name, out_name, embedding_name, false,
                                       table_names, epmap, height_sections,
                                       context, context.scope());
#else
      PADDLE_THROW(
          "Paddle is not compiled with distributed support, cannot do "
          "parameter prefetch!");
#endif
    } else {
      int64_t padding_idx = context.Attr<int64_t>("padding_idx");
      const int64_t *ids = ids_t->data<int64_t>();
      int64_t ids_numel = ids_t->numel();

      if (table_var->IsType<LoDTensor>()) {
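        // Dense table (LoDTensor): ids index rows of W directly. Each output
        // row is a raw memcpy of the matching table row; rows whose id equals
        // padding_idx are zero-filled instead.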
        auto *table_t = context.Input<LoDTensor>("W");
        int64_t row_number = table_t->dims()[0];
        int64_t row_width = table_t->dims()[1];

        auto *table = table_t->data<T>();
        auto *output = output_t->mutable_data<T>(context.GetPlace());

        for (int64_t i = 0; i < ids_numel; ++i) {
          if (padding_idx != kNoPadding && ids[i] == padding_idx) {
            memset(output + i * row_width, 0, row_width * sizeof(T));
          } else {
            PADDLE_ENFORCE_LT(
                ids[i], row_number,
                platform::errors::InvalidArgument(
                    "Variable value (input) of OP(fluid.layers.embedding) "
                    "expected >= 0 and < %ld, but got %ld. Please check input "
                    "value.",
                    row_number, ids[i]));
            PADDLE_ENFORCE_GE(
                ids[i], 0,
                platform::errors::InvalidArgument(
                    "Variable value (input) of OP(fluid.layers.embedding) "
                    "expected >= 0 and < %ld, but got %ld. Please check input "
                    "value.",
                    row_number, ids[i]));
            memcpy(output + i * row_width, table + ids[i] * row_width,
                   row_width * sizeof(T));
          }
        }
      } else if (table_var->IsType<SelectedRows>()) {
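        // Sparse table (SelectedRows): only a subset of rows is materialized,
        // so each id is first translated to its local row offset via Index()
        // before the row is copied out.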
        const auto &table_t = table_var->Get<SelectedRows>();
        int64_t row_width = table_t.value().dims()[1];
        const auto *table = table_t.value().data<T>();
        auto *output = output_t->mutable_data<T>(context.GetPlace());
        auto input_data_type = table_t.value().type();
        for (int64_t i = 0; i < ids_numel; ++i) {
          if (padding_idx != kNoPadding && ids[i] == padding_idx) {
            memset(output + i * row_width, 0, row_width * sizeof(T));
          } else {
            PADDLE_ENFORCE_GE(
                ids[i], 0,
                platform::errors::InvalidArgument(
                    "Variable value (input) of OP(fluid.layers.embedding) "
                    "expected >= 0. But received %ld",
                    ids[i]));
            auto id_index = table_t.Index(ids[i]);
            PADDLE_ENFORCE_GE(
                id_index, 0,
                platform::errors::InvalidArgument(
                    "the input key should be exists. But received %d.",
                    id_index));
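            // int8 rows are copied byte-for-byte with memcpy; other element
            // types go through Blas VCOPY (presumably because VCOPY is not
            // instantiated for int8).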
            if (input_data_type == framework::proto::VarType::INT8) {
              memcpy(output + i * row_width, table + id_index * row_width,
                     row_width * sizeof(T));
            } else {
              auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
              blas.VCOPY(row_width, table + id_index * row_width,
                         output + i * row_width);
            }
          }
        }
      }
    }
  }
};

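// Backward kernel: with "is_sparse" set, the gradient of W is produced as a
// SelectedRows holding one row of Out@GRAD per looked-up id; otherwise a
// dense, zero-initialized gradient table is filled by accumulating rows of
// Out@GRAD (summing over repeated ids).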
template <typename T>
class LookupTableGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *table_var = context.InputVar("W");
    DDim table_dim;
    if (table_var->IsType<LoDTensor>()) {
      table_dim = context.Input<LoDTensor>("W")->dims();
    } else if (table_var->IsType<SelectedRows>()) {
      auto *table_t = context.Input<SelectedRows>("W");
      table_dim = table_t->value().dims();
    } else {
      PADDLE_THROW(
          "The parameter W of a LookupTable "
          "must be either LoDTensor or SelectedRows");
    }

    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
    bool is_sparse = context.Attr<bool>("is_sparse");
    // Since paddings are not trainable and are fixed in the forward pass,
    // the gradient of paddings makes no sense; we don't deal with it in
    // backward.
    if (is_sparse) {
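      // Sparse gradient: the rows of the output SelectedRows are exactly the
      // input ids, and its value tensor is a copy (or an in-place share) of
      // Out@GRAD.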
      auto *ids = context.Input<LoDTensor>("Ids");
      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));

      auto *ids_data = ids->data<int64_t>();
      int64_t ids_num = ids->numel();

      std::vector<int64_t> new_rows;
      new_rows.resize(ids_num);
      std::memcpy(&new_rows[0], ids_data, ids_num * sizeof(int64_t));
      d_table->set_rows(new_rows);

      auto *d_table_value = d_table->mutable_value();
      d_table_value->Resize({ids_num, table_dim[1]});
      // FIXME(minqiyang):
      // memory optimization will NOT reuse a Tensor with SelectedRows,
      // so we can just share the tensor here directly.
      // However, the InferVarType method will sometimes infer the output
      // SelectedRows to be a Tensor, which is a bug, so we add an attribute
      // here to indicate the in-place sharing and will remove it once the
      // InferVarType bug is fixed.
      bool grad_inplace = context.Attr<bool>("grad_inplace");
      if (grad_inplace) {
        d_table_value->ShareDataWith(*d_output);
      } else {
        d_table_value->mutable_data<T>(context.GetPlace());

        d_table->set_height(table_dim[0]);

        auto *d_output_data = d_output->data<T>();
        auto *d_table_data = d_table_value->data<T>();

        auto d_output_dims = d_output->dims();
        auto d_output_dims_2d =
            framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
        PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output_dims_2d,
                          platform::errors::InvalidArgument(
                              "ShapeError: The shape of lookup_table@Grad and "
                              "output@Grad should be same. "
                              "But received lookup_table@Grad's shape = [%s], "
                              "output@Grad's shape = [%s].",
                              d_table_value->dims(), d_output_dims_2d));
        memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
      }
    } else {
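      // Dense gradient: zero the whole table gradient, then scatter-add each
      // row of Out@GRAD into the table row indexed by the corresponding id.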
      auto *ids = context.Input<LoDTensor>("Ids");
      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));

      auto *ids_data = ids->data<int64_t>();

      int64_t N = table_dim[0];
      int64_t D = table_dim[1];

      auto *d_output_data = d_output->data<T>();
      auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());

      memset(d_table_data, 0, d_table->numel() * sizeof(T));

      for (int64_t i = 0; i < ids->numel(); ++i) {
        if (padding_idx != kNoPadding && ids_data[i] == padding_idx) {
          // the gradient of padding_idx should be 0, already done by memset, so
          // do nothing.
        } else {
          PADDLE_ENFORCE_LT(
              ids_data[i], N,
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
                  N, ids_data[i]));
          PADDLE_ENFORCE_GE(
              ids_data[i], 0,
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input"
                  "value.",
                  N, ids_data[i]));
          for (int j = 0; j < D; ++j) {
            d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j];
          }
        }
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle