/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

namespace paddle {
namespace operators {

using SelectedRows = phi::SelectedRows;
using DDim = framework::DDim;

constexpr int64_t kNoPadding = -1;

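// Forward CPU kernel of lookup_table: for every index in Ids, the matching
// row of the embedding table W is copied into Out. Indices equal to
// padding_idx produce zero-filled rows. W may be a phi::DenseTensor or a
// phi::SelectedRows.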
template <typename T>
class LookupTableKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *ids_t = context.Input<phi::DenseTensor>("Ids");      // int tensor
    auto *output_t = context.Output<phi::DenseTensor>("Out");  // float tensor
    auto *table_var = context.InputVar("W");

    auto id_name = context.InputNames("Ids").front();
    auto embedding_name = context.InputNames("W").front();
    auto out_name = context.OutputNames("Out").front();

    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
    bool is_test = context.Attr<bool>("is_test");

    const int64_t *ids = ids_t->data<int64_t>();
    int64_t ids_numel = ids_t->numel();

    if (table_var->IsType<phi::DenseTensor>()) {
      auto *table_t = context.Input<phi::DenseTensor>("W");
      int64_t row_number = table_t->dims()[0];
      int64_t row_width = table_t->dims()[1];

      auto *table = table_t->data<T>();
      auto *output = output_t->mutable_data<T>(context.GetPlace());

      for (int64_t i = 0; i < ids_numel; ++i) {
        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
          memset(output + i * row_width, 0, row_width * sizeof(T));
        } else {
          PADDLE_ENFORCE_LT(
              ids[i],
              row_number,
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
                  row_number,
                  ids[i]));
          PADDLE_ENFORCE_GE(
              ids[i],
              0,
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
                  row_number,
                  ids[i]));
          memcpy(output + i * row_width,
                 table + ids[i] * row_width,
                 row_width * sizeof(T));
        }
      }

    } else if (table_var->IsType<phi::SelectedRows>()) {
      const auto &table_t = table_var->Get<phi::SelectedRows>();
      int64_t row_width = table_t.value().dims()[1];
      const auto *table = table_t.value().data<T>();
      auto *output = output_t->mutable_data<T>(context.GetPlace());
      auto input_data_type =
          framework::TransToProtoVarType(table_t.value().dtype());
      for (int64_t i = 0; i < ids_numel; ++i) {
        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
          memset(output + i * row_width, 0, row_width * sizeof(T));
        } else {
          PADDLE_ENFORCE_GE(
              ids[i],
              0,
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0. But received %ld",
                  ids[i]));
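          // In inference (is_test) an id that is missing from the
          // SelectedRows table simply yields a zero row; in training the id
          // must already exist in the table.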
          if (is_test) {
            auto id_index = table_t.GetIndexFromId(ids[i]);

            if (id_index != -1) {
              if (input_data_type == framework::proto::VarType::INT8 ||
                  input_data_type == framework::proto::VarType::INT16 ||
                  input_data_type == framework::proto::VarType::BF16) {
                memcpy(output + i * row_width,
                       table + id_index * row_width,
                       row_width * sizeof(T));
              } else {
                auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
                blas.VCOPY(row_width,
                           table + id_index * row_width,
                           output + i * row_width);
              }
            } else {
              memset(output + i * row_width, 0, row_width * sizeof(T));
            }
          } else {
            auto id_index = table_t.Index(ids[i]);
            PADDLE_ENFORCE_GE(
                ids[i],
                0,
                platform::errors::InvalidArgument(
                    "Variable value (input) of OP(fluid.layers.embedding) "
                    "expected >= 0. But received %ld",
                    ids[i]));
            PADDLE_ENFORCE_GE(
                id_index,
                0,
                platform::errors::InvalidArgument(
                    "the input key should be exists. But received %d.",
                    id_index));

            if (input_data_type == framework::proto::VarType::INT8 ||
                input_data_type == framework::proto::VarType::INT16 ||
                input_data_type == framework::proto::VarType::BF16) {
              memcpy(output + i * row_width,
                     table + id_index * row_width,
                     row_width * sizeof(T));
            } else {
              auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(context);
              blas.VCOPY(row_width,
                         table + id_index * row_width,
                         output + i * row_width);
            }
          }
        }
      }
    }
  }
};

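// Backward CPU kernel of lookup_table. When is_sparse is set, the gradient of
// W is produced as a phi::SelectedRows with one row per looked-up id;
// otherwise the rows of Out@GRAD are accumulated into a dense W@GRAD.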
template <typename T>
class LookupTableGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *table_var = context.InputVar("W");
    DDim table_dim;
    if (table_var->IsType<phi::DenseTensor>()) {
      table_dim = context.Input<phi::DenseTensor>("W")->dims();
    } else if (table_var->IsType<phi::SelectedRows>()) {
      auto *table_t = context.Input<phi::SelectedRows>("W");
      table_dim = table_t->value().dims();
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The parameter W of a LookupTable "
          "must be either phi::DenseTensor or SelectedRows"));
    }

    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
    bool is_sparse = context.Attr<bool>("is_sparse");
    // Since the padding entries are not trainable and are fixed in the
    // forward pass, their gradient is meaningless and is not handled in the
    // backward pass.
    if (is_sparse) {
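      // Sparse path: the gradient rows are the raw ids (duplicates are not
      // merged here) and its value is a copy of Out@GRAD reshaped to
      // [ids_num, table_dim[1]].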
      auto *ids = context.Input<phi::DenseTensor>("Ids");
      auto *d_output =
          context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
      auto *d_table =
          context.Output<phi::SelectedRows>(framework::GradVarName("W"));

      auto *ids_data = ids->data<int64_t>();
      int64_t ids_num = ids->numel();

      std::vector<int64_t> new_rows;
      new_rows.resize(ids_num);
      std::memcpy(&new_rows[0], ids_data, ids_num * sizeof(int64_t));
      d_table->set_rows(new_rows);

      auto *d_table_value = d_table->mutable_value();
      d_table_value->Resize({ids_num, table_dim[1]});
      d_table_value->mutable_data<T>(context.GetPlace());
      d_table->set_height(table_dim[0]);

      auto *d_output_data = d_output->data<T>();
      auto *d_table_data = d_table_value->data<T>();

      auto d_output_dims = d_output->dims();
      auto d_output_dims_2d =
          phi::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
      PADDLE_ENFORCE_EQ(d_table_value->dims(),
                        d_output_dims_2d,
                        platform::errors::InvalidArgument(
                            "ShapeError: The shape of lookup_table@Grad and "
                            "output@Grad should be same. "
                            "But received lookup_table@Grad's shape = [%s], "
                            "output@Grad's shape = [%s].",
                            d_table_value->dims(),
                            d_output_dims_2d));
      memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
    } else {
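      // Dense path: start from an all-zero W@GRAD and add each row of
      // Out@GRAD into the row selected by the corresponding id; padding_idx
      // entries are skipped and keep a zero gradient.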
      auto *ids = context.Input<phi::DenseTensor>("Ids");
      auto *d_output =
          context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
      auto *d_table =
          context.Output<phi::DenseTensor>(framework::GradVarName("W"));

      auto *ids_data = ids->data<int64_t>();

      int64_t N = table_dim[0];
      int64_t D = table_dim[1];

      auto *d_output_data = d_output->data<T>();
      auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());

      memset(d_table_data, 0, d_table->numel() * sizeof(T));

      for (int64_t i = 0; i < ids->numel(); ++i) {
        if (padding_idx != kNoPadding && ids_data[i] == padding_idx) {
          // the gradient of padding_idx should be 0, already done by memset, so
          // do nothing.
        } else {
          PADDLE_ENFORCE_LT(
              ids_data[i],
              N,
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
                  N,
                  ids_data[i]));
          PADDLE_ENFORCE_GE(
              ids_data[i],
              0,
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input"
                  "value.",
                  N,
                  ids_data[i]));
          for (int j = 0; j < D; ++j) {
            d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j];
          }
        }
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle