/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <cstring>
#include <string>
#include <type_traits>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = phi::SelectedRows;
using DDim = framework::DDim;

constexpr int64_t kNoPadding = -1;

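// Copies the contents of an integer index tensor into a std::vector<OutT>,
// converting element by element when the source and destination index types
// differ.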
template <typename InT, typename OutT>
static std::vector<OutT> CopyIdsToVector(const Tensor &ids) {
  auto numel = ids.numel();
  const auto *src = ids.data<InT>();
  std::vector<OutT> ret(numel);
  if (std::is_same<InT, OutT>::value) {
    std::memcpy(ret.data(), src, numel * sizeof(InT));
  } else {
    for (decltype(numel) i = 0; i < numel; ++i) {
      ret[i] = src[i];
    }
  }
  return ret;
}

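// Forward functor for the CPU lookup_table_v2 (embedding) kernel. For each id
// it copies the corresponding row of the table W into Out, writing zeros for
// ids equal to padding_idx. W may be a dense LoDTensor or a SelectedRows.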
template <typename T>
struct LookupTableV2CPUFunctor {
  LookupTableV2CPUFunctor(const framework::ExecutionContext &context,
                          const Tensor *ids_t)
      : context_(context), ids_t_(ids_t) {}

  template <typename IdT>
  void apply() {
    auto *output_t = context_.Output<LoDTensor>("Out");  // float tensor
    auto *table_var = context_.InputVar("W");

    int64_t padding_idx = context_.Attr<int64_t>("padding_idx");

    auto ids = CopyIdsToVector<IdT, int64_t>(*ids_t_);
    auto ids_numel = static_cast<int64_t>(ids.size());

    if (table_var->template IsType<LoDTensor>()) {
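      // Dense table: ids index rows of W directly.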
      const auto &table_t = table_var->template Get<LoDTensor>();
      int64_t row_number = table_t.dims()[0];
      int64_t row_width = table_t.dims()[1];

      auto *table = table_t.template data<T>();
      auto *output = output_t->template mutable_data<T>(context_.GetPlace());

      for (int64_t i = 0; i < ids_numel; ++i) {
        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
          memset(output + i * row_width, 0, row_width * sizeof(T));
        } else {
          PADDLE_ENFORCE_LT(
              ids[i],
              row_number,
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
                  row_number,
                  ids[i]));
          PADDLE_ENFORCE_GE(
              ids[i],
              0,
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
                  row_number,
                  ids[i]));
          memcpy(output + i * row_width,
                 table + ids[i] * row_width,
                 row_width * sizeof(T));
        }
      }
    } else if (table_var->template IsType<phi::SelectedRows>()) {
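      // Sparse table: translate each id to its row index in the SelectedRows
      // value tensor before copying the row.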
      const auto &table_t = table_var->template Get<phi::SelectedRows>();
      int64_t row_width = table_t.value().dims()[1];
      const auto *table = table_t.value().template data<T>();
      auto *output = output_t->template mutable_data<T>(context_.GetPlace());
      auto input_data_type =
          framework::TransToProtoVarType(table_t.value().dtype());

      for (int64_t i = 0; i < ids_numel; ++i) {
        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
          memset(output + i * row_width, 0, row_width * sizeof(T));
        } else {
          PADDLE_ENFORCE_GE(
              ids[i],
              0,
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0. But received %ld",
                  ids[i]));
          auto id_index = table_t.Index(ids[i]);
          PADDLE_ENFORCE_GE(
              id_index,
              0,
              platform::errors::InvalidArgument(
                  "the input key should be exists. But received %d.",
                  id_index));

          if (input_data_type == framework::proto::VarType::BF16) {
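            // BLAS VCOPY is not used for bfloat16 here; copy the row bytes
            // directly instead.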
            memcpy(output + i * row_width,
                   table + id_index * row_width,
                   row_width * sizeof(T));
          } else {
            auto blas =
                phi::funcs::GetBlas<platform::CPUDeviceContext, T>(context_);
            blas.VCOPY(row_width,
                       table + id_index * row_width,
                       output + i * row_width);
          }
        }
      }
    }
  }

 private:
  const framework::ExecutionContext &context_;
  const Tensor *ids_t_;
};

template <typename T>
class LookupTableV2Kernel : public framework::OpKernel<T> {
 public:
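  // Dispatches LookupTableV2CPUFunctor on the integer dtype (int32 or int64)
  // of the Ids tensor.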
  void Compute(const framework::ExecutionContext &context) const override {
    const auto *ids = context.Input<Tensor>("Ids");
    LookupTableV2CPUFunctor<T> functor(context, ids);
    framework::VisitIntDataType(framework::TransToProtoVarType(ids->dtype()),
                                functor);
  }
};

template <typename T>
struct LookupTableV2GradCPUFunctor {
  LookupTableV2GradCPUFunctor(const framework::ExecutionContext &context,
                              const Tensor *ids_t)
      : context_(context), ids_t_(ids_t) {}

  template <typename IdT>
  void apply() {
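    // Resolve the embedding table shape; W may be a dense LoDTensor or a
    // SelectedRows parameter.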
    auto *table_var = context_.InputVar("W");
    DDim table_dim;
    if (table_var->template IsType<LoDTensor>()) {
      table_dim = context_.Input<LoDTensor>("W")->dims();
    } else if (table_var->template IsType<phi::SelectedRows>()) {
      auto *table_t = context_.Input<phi::SelectedRows>("W");
      table_dim = table_t->value().dims();
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The parameter W of a LookupTableV2 "
          "must be either LoDTensor or SelectedRows"));
    }

    int64_t padding_idx = context_.Attr<int64_t>("padding_idx");
    bool is_sparse = context_.Attr<bool>("is_sparse");

    auto ids = CopyIdsToVector<IdT, int64_t>(*ids_t_);
    auto ids_num = static_cast<int64_t>(ids.size());

    // Padding entries are not trainable and are fixed in the forward pass, so
    // no gradient is computed for them in the backward pass.
    if (is_sparse) {
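      // Sparse gradient: W@GRAD is a SelectedRows whose rows are the input ids
      // and whose value tensor is a copy of Out@GRAD reshaped to 2-D.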
      auto *d_output = context_.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table =
          context_.Output<phi::SelectedRows>(framework::GradVarName("W"));

      d_table->set_rows(ids);

      auto *d_table_value = d_table->mutable_value();
      d_table_value->Resize({ids_num, table_dim[1]});

      d_table_value->template mutable_data<T>(context_.GetPlace());

      d_table->set_height(table_dim[0]);

      auto *d_output_data = d_output->template data<T>();
      auto *d_table_data = d_table_value->template data<T>();

      auto d_output_dims = d_output->dims();
      auto d_output_dims_2d =
          phi::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
      PADDLE_ENFORCE_EQ(d_table_value->dims(),
                        d_output_dims_2d,
                        platform::errors::InvalidArgument(
                            "ShapeError: The shape of lookup_table@Grad and "
                            "output@Grad should be same. "
                            "But received lookup_table@Grad's shape = [%s], "
                            "output@Grad's shape = [%s].",
                            d_table_value->dims(),
                            d_output_dims_2d));
      memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());

    } else {
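      // Dense gradient: zero-initialize W@GRAD and accumulate the rows of
      // Out@GRAD into the table rows selected by ids, skipping padding_idx.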
      auto *d_output = context_.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table = context_.Output<LoDTensor>(framework::GradVarName("W"));
      auto *ids_data = ids.data();

      int64_t N = table_dim[0];
      int64_t D = table_dim[1];

      auto *d_output_data = d_output->template data<T>();
      auto *d_table_data =
          d_table->template mutable_data<T>(context_.GetPlace());

      memset(d_table_data, 0, d_table->numel() * sizeof(T));

      for (int64_t i = 0; i < ids_num; ++i) {
        if (padding_idx != kNoPadding && ids_data[i] == padding_idx) {
          // the gradient of padding_idx should be 0, already done by memset, so
          // do nothing.
        } else {
          PADDLE_ENFORCE_LT(
              ids_data[i],
              N,
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
                  N,
                  ids_data[i]));
          PADDLE_ENFORCE_GE(
              ids_data[i],
              0,
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
                  N,
                  ids_data[i]));
          for (int j = 0; j < D; ++j) {
            d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j];
          }
        }
      }
    }
  }

 private:
  const framework::ExecutionContext &context_;
  const Tensor *ids_t_;
};

template <typename T>
class LookupTableV2GradKernel : public framework::OpKernel<T> {
 public:
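  // Dispatches LookupTableV2GradCPUFunctor on the integer dtype of Ids.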
  void Compute(const framework::ExecutionContext &context) const override {
    const auto *ids = context.Input<Tensor>("Ids");
    LookupTableV2GradCPUFunctor<T> functor(context, ids);
    framework::VisitIntDataType(framework::TransToProtoVarType(ids->dtype()),
                                functor);
  }
};

}  // namespace operators
}  // namespace paddle