/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

namespace paddle {
namespace operators {

29
using SelectedRows = phi::SelectedRows;
30 31
using DDim = framework::DDim;

Q
qiaolongfei 已提交
32
constexpr int64_t kNoPadding = -1;
33 34

template <typename T>
Y
Yu Yang 已提交
35
class LookupTableKernel : public framework::OpKernel<T> {
36
 public:
37
  void Compute(const framework::ExecutionContext &context) const override {
38 39
    auto *ids_t = context.Input<phi::DenseTensor>("Ids");      // int tensor
    auto *output_t = context.Output<phi::DenseTensor>("Out");  // float tensor
40
    auto *table_var = context.InputVar("W");
41

H
hong 已提交
42 43 44
    auto id_name = context.InputNames("Ids").front();
    auto embedding_name = context.InputNames("W").front();
    auto out_name = context.OutputNames("Out").front();
Q
Qiao Longfei 已提交
45

46 47
    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
    bool is_test = context.Attr<bool>("is_test");
Q
Qiao Longfei 已提交
48

49 50
    int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
    int64_t ids_numel = ids_t->numel();
Q
Qiao Longfei 已提交
51

52 53
    if (table_var->IsType<phi::DenseTensor>()) {
      auto *table_t = context.Input<phi::DenseTensor>("W");
54 55 56 57 58 59 60 61 62 63 64
      int64_t row_number = table_t->dims()[0];
      int64_t row_width = table_t->dims()[1];

      auto *table = table_t->data<T>();
      auto *output = output_t->mutable_data<T>(context.GetPlace());

      for (int64_t i = 0; i < ids_numel; ++i) {
        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
          memset(output + i * row_width, 0, row_width * sizeof(T));
        } else {
          PADDLE_ENFORCE_LT(
65 66
              ids[i],
              row_number,
67 68 69 70
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
71 72
                  row_number,
                  ids[i]));
73
          PADDLE_ENFORCE_GE(
74 75
              ids[i],
              0,
76 77 78 79
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
80 81 82 83
                  row_number,
                  ids[i]));
          memcpy(output + i * row_width,
                 table + ids[i] * row_width,
84
                 row_width * sizeof(T));
85
        }
86 87
      }

88 89
    } else if (table_var->IsType<phi::SelectedRows>()) {
      const auto &table_t = table_var->Get<phi::SelectedRows>();
90 91 92
      int64_t row_width = table_t.value().dims()[1];
      const auto *table = table_t.value().data<T>();
      auto *output = output_t->mutable_data<T>(context.GetPlace());
93 94
      auto input_data_type =
          framework::TransToProtoVarType(table_t.value().dtype());
95 96 97 98 99
      for (int64_t i = 0; i < ids_numel; ++i) {
        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
          memset(output + i * row_width, 0, row_width * sizeof(T));
        } else {
          PADDLE_ENFORCE_GE(
100 101
              ids[i],
              0,
102 103 104 105 106 107 108 109
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0. But received %ld",
                  ids[i]));
          if (is_test) {
            auto id_index = table_t.GetIndexFromId(ids[i]);

            if (id_index != -1) {
110
              if (input_data_type == framework::proto::VarType::INT8 ||
111
                  input_data_type == framework::proto::VarType::INT16 ||
112
                  input_data_type == framework::proto::VarType::BF16) {
113 114
                memcpy(output + i * row_width,
                       table + id_index * row_width,
115 116
                       row_width * sizeof(T));
              } else {
117 118 119
                auto &dev_ctx =
                    context.template device_context<phi::CPUContext>();
                auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
120 121
                blas.VCOPY(row_width,
                           table + id_index * row_width,
122 123 124 125 126
                           output + i * row_width);
              }
            } else {
              memset(output + i * row_width, 0, row_width * sizeof(T));
            }
Q
Qiao Longfei 已提交
127
          } else {
128
            auto id_index = table_t.Index(ids[i]);
129
            PADDLE_ENFORCE_GE(
130 131
                ids[i],
                0,
132 133 134 135
                platform::errors::InvalidArgument(
                    "Variable value (input) of OP(fluid.layers.embedding) "
                    "expected >= 0. But received %ld",
                    ids[i]));
136
            PADDLE_ENFORCE_GE(
137 138
                id_index,
                0,
139 140 141
                platform::errors::InvalidArgument(
                    "the input key should be exists. But received %d.",
                    id_index));
142

143
            if (input_data_type == framework::proto::VarType::INT8 ||
144
                input_data_type == framework::proto::VarType::INT16 ||
145
                input_data_type == framework::proto::VarType::BF16) {
146 147
              memcpy(output + i * row_width,
                     table + id_index * row_width,
148 149
                     row_width * sizeof(T));
            } else {
150 151 152
              auto &dev_ctx =
                  context.template device_context<phi::CPUContext>();
              auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
153 154
              blas.VCOPY(row_width,
                         table + id_index * row_width,
155 156
                         output + i * row_width);
            }
Q
Qiao Longfei 已提交
157
          }
158 159
        }
      }
160 161 162 163 164
    }
  }
};

template <typename T>
Y
Yu Yang 已提交
165
class LookupTableGradKernel : public framework::OpKernel<T> {
166
 public:
167
  void Compute(const framework::ExecutionContext &context) const override {
Q
qiaolongfei 已提交
168 169
    auto *table_var = context.InputVar("W");
    DDim table_dim;
170 171
    if (table_var->IsType<phi::DenseTensor>()) {
      table_dim = context.Input<phi::DenseTensor>("W")->dims();
172 173
    } else if (table_var->IsType<phi::SelectedRows>()) {
      auto *table_t = context.Input<phi::SelectedRows>("W");
Q
qiaolongfei 已提交
174 175
      table_dim = table_t->value().dims();
    } else {
176
      PADDLE_THROW(platform::errors::InvalidArgument(
Q
qiaolongfei 已提交
177
          "The parameter W of a LookupTable "
178
          "must be either phi::DenseTensor or SelectedRows"));
Q
qiaolongfei 已提交
179 180
    }

181
    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
182
    bool is_sparse = context.Attr<bool>("is_sparse");
183 184
    // Since paddings are not trainable and fixed in forward, the gradient of
    // paddings makes no sense and we don't deal with it in backward.
185
    if (is_sparse) {
186 187 188
      auto *ids = context.Input<phi::DenseTensor>("Ids");
      auto *d_output =
          context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
189
      auto *d_table =
190
          context.Output<phi::SelectedRows>(framework::GradVarName("W"));
191

192
      auto *ids_data = ids->data<int64_t>();
193
      int64_t ids_num = ids->numel();
194

M
minqiyang 已提交
195
      std::vector<int64_t> new_rows;
M
minqiyang 已提交
196 197
      new_rows.resize(ids_num);
      std::memcpy(&new_rows[0], ids_data, ids_num * sizeof(int64_t));
198
      d_table->set_rows(new_rows);
199

200
      auto *d_table_value = d_table->mutable_value();
201
      d_table_value->Resize({ids_num, table_dim[1]});
202 203 204 205 206 207 208 209
      d_table_value->mutable_data<T>(context.GetPlace());
      d_table->set_height(table_dim[0]);

      auto *d_output_data = d_output->data<T>();
      auto *d_table_data = d_table_value->data<T>();

      auto d_output_dims = d_output->dims();
      auto d_output_dims_2d =
210
          phi::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
211 212
      PADDLE_ENFORCE_EQ(d_table_value->dims(),
                        d_output_dims_2d,
213 214 215 216 217
                        platform::errors::InvalidArgument(
                            "ShapeError: The shape of lookup_table@Grad and "
                            "output@Grad should be same. "
                            "But received lookup_table@Grad's shape = [%s], "
                            "output@Grad's shape = [%s].",
218 219
                            d_table_value->dims(),
                            d_output_dims_2d));
220
      memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
221
    } else {
222 223 224 225 226
      auto *ids = context.Input<phi::DenseTensor>("Ids");
      auto *d_output =
          context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
      auto *d_table =
          context.Output<phi::DenseTensor>(framework::GradVarName("W"));
227

228
      auto *ids_data = ids->data<int64_t>();
229

230 231
      int64_t N = table_dim[0];
      int64_t D = table_dim[1];
232

233 234
      auto *d_output_data = d_output->data<T>();
      auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());
235

236 237
      memset(d_table_data, 0, d_table->numel() * sizeof(T));

238
      for (int64_t i = 0; i < ids->numel(); ++i) {
Q
Qiao Longfei 已提交
239 240 241 242
        if (padding_idx != kNoPadding && ids_data[i] == padding_idx) {
          // the gradient of padding_idx should be 0, already done by memset, so
          // do nothing.
        } else {
243
          PADDLE_ENFORCE_LT(
244 245
              ids_data[i],
              N,
246 247 248 249
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
250 251
                  N,
                  ids_data[i]));
252
          PADDLE_ENFORCE_GE(
253 254
              ids_data[i],
              0,
255 256 257 258
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input"
                  "value.",
259 260
                  N,
                  ids_data[i]));
261 262 263
          for (int j = 0; j < D; ++j) {
            d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j];
          }
264
        }
265 266 267 268 269 270 271
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle