/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <string>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

namespace paddle {
namespace operators {

using SelectedRows = phi::SelectedRows;
using DDim = framework::DDim;

constexpr int64_t kNoPadding = -1;

35
template <typename InT, typename OutT>
36
static std::vector<OutT> CopyIdsToVector(const phi::DenseTensor &ids) {
37 38 39 40 41 42 43 44 45 46 47 48 49
  auto numel = ids.numel();
  const auto *src = ids.data<InT>();
  std::vector<OutT> ret(numel);
  if (std::is_same<InT, OutT>::value) {
    std::memcpy(ret.data(), src, numel * sizeof(InT));
  } else {
    for (decltype(numel) i = 0; i < numel; ++i) {
      ret[i] = src[i];
    }
  }
  return ret;
}

50
template <typename T>
51 52
struct LookupTableV2CPUFunctor {
  LookupTableV2CPUFunctor(const framework::ExecutionContext &context,
53
                          const phi::DenseTensor *ids_t)
54
      : context_(context), ids_t_(ids_t) {}
55

56 57
  template <typename IdT>
  void apply() {
58
    auto *output_t = context_.Output<phi::DenseTensor>("Out");  // float tensor
59
    auto *table_var = context_.InputVar("W");
60

61
    int64_t padding_idx = context_.Attr<int64_t>("padding_idx");
62

63 64
    auto ids = CopyIdsToVector<IdT, int64_t>(*ids_t_);
    auto ids_numel = static_cast<int64_t>(ids.size());
T
tangwei12 已提交
65

66 67
    if (table_var->template IsType<phi::DenseTensor>()) {
      const auto &table_t = table_var->template Get<phi::DenseTensor>();
68 69
      int64_t row_number = table_t.dims()[0];
      int64_t row_width = table_t.dims()[1];
T
tangwei12 已提交
70

71 72
      auto *table = table_t.template data<T>();
      auto *output = output_t->template mutable_data<T>(context_.GetPlace());
T
tangwei12 已提交
73 74 75 76 77 78

      for (int64_t i = 0; i < ids_numel; ++i) {
        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
          memset(output + i * row_width, 0, row_width * sizeof(T));
        } else {
          PADDLE_ENFORCE_LT(
79 80
              ids[i],
              row_number,
81 82 83 84
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
85 86
                  row_number,
                  ids[i]));
T
tangwei12 已提交
87
          PADDLE_ENFORCE_GE(
88 89
              ids[i],
              0,
90 91 92 93
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
94 95 96 97
                  row_number,
                  ids[i]));
          memcpy(output + i * row_width,
                 table + ids[i] * row_width,
T
tangwei12 已提交
98
                 row_width * sizeof(T));
99
        }
T
tangwei12 已提交
100
      }
101 102
    } else if (table_var->template IsType<phi::SelectedRows>()) {
      const auto &table_t = table_var->template Get<phi::SelectedRows>();
T
tangwei12 已提交
103
      int64_t row_width = table_t.value().dims()[1];
104 105
      const auto *table = table_t.value().template data<T>();
      auto *output = output_t->template mutable_data<T>(context_.GetPlace());
106 107
      auto input_data_type =
          framework::TransToProtoVarType(table_t.value().dtype());
T
tangwei12 已提交
108 109 110 111 112 113

      for (int64_t i = 0; i < ids_numel; ++i) {
        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
          memset(output + i * row_width, 0, row_width * sizeof(T));
        } else {
          PADDLE_ENFORCE_GE(
114 115
              ids[i],
              0,
116 117 118 119
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0. But received %ld",
                  ids[i]));
T
tangwei12 已提交
120
          auto id_index = table_t.Index(ids[i]);
121
          PADDLE_ENFORCE_GE(
122 123
              id_index,
              0,
124 125 126
              platform::errors::InvalidArgument(
                  "the input key should be exists. But received %d.",
                  id_index));
127 128

          if (input_data_type == framework::proto::VarType::BF16) {
129 130
            memcpy(output + i * row_width,
                   table + id_index * row_width,
131 132
                   row_width * sizeof(T));
          } else {
133 134
            auto &dev_ctx = context_.template device_context<phi::CPUContext>();
            auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
135 136
            blas.VCOPY(row_width,
                       table + id_index * row_width,
137 138
                       output + i * row_width);
          }
139 140 141 142
        }
      }
    }
  }
143 144 145

 private:
  const framework::ExecutionContext &context_;
146
  const phi::DenseTensor *ids_t_;
147 148 149
};

template <typename T>
150
class LookupTableV2Kernel : public framework::OpKernel<T> {
151 152
 public:
  void Compute(const framework::ExecutionContext &context) const override {
153
    const auto *ids = context.Input<phi::DenseTensor>("Ids");
154
    LookupTableV2CPUFunctor<T> functor(context, ids);
155 156
    framework::VisitIntDataType(framework::TransToProtoVarType(ids->dtype()),
                                functor);
157 158 159 160 161 162
  }
};

template <typename T>
struct LookupTableV2GradCPUFunctor {
  LookupTableV2GradCPUFunctor(const framework::ExecutionContext &context,
163
                              const phi::DenseTensor *ids_t)
164 165 166 167 168
      : context_(context), ids_t_(ids_t) {}

  template <typename IdT>
  void apply() {
    auto *table_var = context_.InputVar("W");
169
    DDim table_dim;
170 171
    if (table_var->template IsType<phi::DenseTensor>()) {
      table_dim = context_.Input<phi::DenseTensor>("W")->dims();
172 173
    } else if (table_var->template IsType<phi::SelectedRows>()) {
      auto *table_t = context_.Input<phi::SelectedRows>("W");
174 175
      table_dim = table_t->value().dims();
    } else {
176
      PADDLE_THROW(platform::errors::InvalidArgument(
177
          "The parameter W of a LookupTableV2 "
178
          "must be either phi::DenseTensor or SelectedRows"));
179 180
    }

181 182 183 184 185 186
    int64_t padding_idx = context_.Attr<int64_t>("padding_idx");
    bool is_sparse = context_.Attr<bool>("is_sparse");

    auto ids = CopyIdsToVector<IdT, int64_t>(*ids_t_);
    auto ids_num = static_cast<int64_t>(ids.size());

187 188 189
    // Since paddings are not trainable and fixed in forward, the gradient of
    // paddings makes no sense and we don't deal with it in backward.
    if (is_sparse) {
190 191
      auto *d_output =
          context_.Input<phi::DenseTensor>(framework::GradVarName("Out"));
192
      auto *d_table =
193
          context_.Output<phi::SelectedRows>(framework::GradVarName("W"));
194

T
tangwei12 已提交
195
      d_table->set_rows(ids);
196 197 198 199

      auto *d_table_value = d_table->mutable_value();
      d_table_value->Resize({ids_num, table_dim[1]});

200
      d_table_value->template mutable_data<T>(context_.GetPlace());
201 202 203

      d_table->set_height(table_dim[0]);

204 205
      auto *d_output_data = d_output->template data<T>();
      auto *d_table_data = d_table_value->template data<T>();
206 207

      auto d_output_dims = d_output->dims();
208
      auto d_output_dims_2d =
209
          phi::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
210 211
      PADDLE_ENFORCE_EQ(d_table_value->dims(),
                        d_output_dims_2d,
212 213 214 215 216
                        platform::errors::InvalidArgument(
                            "ShapeError: The shape of lookup_table@Grad and "
                            "output@Grad should be same. "
                            "But received lookup_table@Grad's shape = [%s], "
                            "output@Grad's shape = [%s].",
217 218
                            d_table_value->dims(),
                            d_output_dims_2d));
219 220 221
      memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());

    } else {
222 223 224 225
      auto *d_output =
          context_.Input<phi::DenseTensor>(framework::GradVarName("Out"));
      auto *d_table =
          context_.Output<phi::DenseTensor>(framework::GradVarName("W"));
T
tangwei12 已提交
226
      auto *ids_data = ids.data();
227 228 229 230

      int64_t N = table_dim[0];
      int64_t D = table_dim[1];

231 232 233
      auto *d_output_data = d_output->template data<T>();
      auto *d_table_data =
          d_table->template mutable_data<T>(context_.GetPlace());
234 235 236

      memset(d_table_data, 0, d_table->numel() * sizeof(T));

T
tangwei12 已提交
237
      for (int64_t i = 0; i < ids_num; ++i) {
238 239 240 241 242
        if (padding_idx != kNoPadding && ids_data[i] == padding_idx) {
          // the gradient of padding_idx should be 0, already done by memset, so
          // do nothing.
        } else {
          PADDLE_ENFORCE_LT(
243 244
              ids_data[i],
              N,
245 246 247 248
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
249 250
                  N,
                  ids_data[i]));
251
          PADDLE_ENFORCE_GE(
252 253
              ids_data[i],
              0,
254 255 256 257
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
258 259
                  N,
                  ids_data[i]));
260 261 262 263 264 265 266
          for (int j = 0; j < D; ++j) {
            d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j];
          }
        }
      }
    }
  }
267 268 269

 private:
  const framework::ExecutionContext &context_;
270
  const phi::DenseTensor *ids_t_;
271 272 273 274 275 276
};

// CPU kernel entry point for lookup_table_v2_grad: dispatches the backward
// functor on the (integer) element type of the Ids tensor.
template <typename T>
class LookupTableV2GradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const auto *ids_t = context.Input<phi::DenseTensor>("Ids");
    LookupTableV2GradCPUFunctor<T> functor(context, ids_t);
    framework::VisitIntDataType(framework::TransToProtoVarType(ids_t->dtype()),
                                functor);
  }
};

}  // namespace operators
}  // namespace paddle