/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <cstring>
#include <string>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/operators/math/blas.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = pten::SelectedRows;
using DDim = framework::DDim;

constexpr int64_t kNoPadding = -1;

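// Copies the ids tensor into a std::vector<OutT> (int64_t for the kernels
// below): a raw memcpy when the source and destination element types match,
// otherwise an element-wise cast. This lets the lookup code handle a single
// index type regardless of whether Ids holds int32 or int64 values.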
template <typename InT, typename OutT>
static std::vector<OutT> CopyIdsToVector(const Tensor &ids) {
  auto numel = ids.numel();
  const auto *src = ids.data<InT>();
  std::vector<OutT> ret(numel);
  if (std::is_same<InT, OutT>::value) {
    std::memcpy(ret.data(), src, numel * sizeof(InT));
  } else {
    for (decltype(numel) i = 0; i < numel; ++i) {
      ret[i] = src[i];
    }
  }
  return ret;
}

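// Forward functor: gathers rows of the embedding table W into Out. It is
// templated over the index type IdT (dispatched at runtime via
// VisitIntDataType) and supports both a dense LoDTensor table and a sparse
// SelectedRows table.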
template <typename T>
struct LookupTableV2CPUFunctor {
  LookupTableV2CPUFunctor(const framework::ExecutionContext &context,
                          const Tensor *ids_t)
      : context_(context), ids_t_(ids_t) {}

  template <typename IdT>
  void apply() {
    auto *output_t = context_.Output<LoDTensor>("Out");  // float tensor
    auto *table_var = context_.InputVar("W");

    int64_t padding_idx = context_.Attr<int64_t>("padding_idx");

    auto ids = CopyIdsToVector<IdT, int64_t>(*ids_t_);
    auto ids_numel = static_cast<int64_t>(ids.size());

    if (table_var->template IsType<LoDTensor>()) {
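      // Dense table: all rows are resident in one contiguous buffer, so a
      // lookup is a bounds check plus a memcpy of a single row (or a
      // memset(0) for the padding index).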
      const auto &table_t = table_var->template Get<LoDTensor>();
      int64_t row_number = table_t.dims()[0];
      int64_t row_width = table_t.dims()[1];

      auto *table = table_t.template data<T>();
      auto *output = output_t->template mutable_data<T>(context_.GetPlace());

      for (int64_t i = 0; i < ids_numel; ++i) {
        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
          memset(output + i * row_width, 0, row_width * sizeof(T));
        } else {
          PADDLE_ENFORCE_LT(
              ids[i], row_number,
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
                  row_number, ids[i]));
          PADDLE_ENFORCE_GE(
              ids[i], 0,
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
                  row_number, ids[i]));
          memcpy(output + i * row_width, table + ids[i] * row_width,
                 row_width * sizeof(T));
        }
      }
    } else if (table_var->template IsType<pten::SelectedRows>()) {
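      // Sparse table: only a subset of rows is materialized, so each id must
      // first be translated to its position in the value tensor via Index().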
      const auto &table_t = table_var->template Get<pten::SelectedRows>();
      int64_t row_width = table_t.value().dims()[1];
      const auto *table = table_t.value().template data<T>();
      auto *output = output_t->template mutable_data<T>(context_.GetPlace());
      auto input_data_type = table_t.value().type();

      for (int64_t i = 0; i < ids_numel; ++i) {
        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
          memset(output + i * row_width, 0, row_width * sizeof(T));
        } else {
          PADDLE_ENFORCE_GE(
              ids[i], 0,
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0. But received %ld",
                  ids[i]));
          auto id_index = table_t.Index(ids[i]);
          PADDLE_ENFORCE_GE(
              id_index, 0,
              platform::errors::InvalidArgument(
                  "the input key should be exists. But received %d.",
                  id_index));

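          // Assumption: BLAS VCOPY is not available for bfloat16, so BF16
          // tables fall back to a plain memcpy of the row.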
          if (input_data_type == framework::proto::VarType::BF16) {
            memcpy(output + i * row_width, table + id_index * row_width,
                   row_width * sizeof(T));
          } else {
            auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context_);
            blas.VCOPY(row_width, table + id_index * row_width,
                       output + i * row_width);
          }
        }
      }
    }
  }

 private:
  const framework::ExecutionContext &context_;
  const Tensor *ids_t_;
};

template <typename T>
class LookupTableV2Kernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
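    // Ids may hold int32 or int64 values; VisitIntDataType instantiates
    // apply<IdT> for whichever index type the tensor actually carries.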
    const auto *ids = context.Input<Tensor>("Ids");
    LookupTableV2CPUFunctor<T> functor(context, ids);
    framework::VisitIntDataType(ids->type(), functor);
  }
};

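// Backward functor: scatters rows of Out@GRAD back into W@GRAD. Depending on
// the is_sparse attribute, the gradient is emitted either as a SelectedRows
// holding one row per input id or as a dense, zero-initialized table.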
template <typename T>
struct LookupTableV2GradCPUFunctor {
  LookupTableV2GradCPUFunctor(const framework::ExecutionContext &context,
                              const Tensor *ids_t)
      : context_(context), ids_t_(ids_t) {}

  template <typename IdT>
  void apply() {
    auto *table_var = context_.InputVar("W");
    DDim table_dim;
    if (table_var->template IsType<LoDTensor>()) {
      table_dim = context_.Input<LoDTensor>("W")->dims();
    } else if (table_var->template IsType<pten::SelectedRows>()) {
      auto *table_t = context_.Input<pten::SelectedRows>("W");
      table_dim = table_t->value().dims();
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The parameter W of a LookupTableV2 "
          "must be either LoDTensor or SelectedRows"));
    }

    int64_t padding_idx = context_.Attr<int64_t>("padding_idx");
    bool is_sparse = context_.Attr<bool>("is_sparse");

    auto ids = CopyIdsToVector<IdT, int64_t>(*ids_t_);
    auto ids_num = static_cast<int64_t>(ids.size());

    // Padding entries are not trainable and stay fixed in the forward pass,
    // so their gradient is meaningless and the backward pass skips them.
    if (is_sparse) {
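      // Sparse gradient: the SelectedRows' rows are exactly the input ids
      // (duplicates included), and its value tensor is a flat copy of
      // Out@GRAD reshaped to 2-D.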
      auto *d_output = context_.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table =
          context_.Output<pten::SelectedRows>(framework::GradVarName("W"));

      d_table->set_rows(ids);

      auto *d_table_value = d_table->mutable_value();
      d_table_value->Resize({ids_num, table_dim[1]});

      d_table_value->template mutable_data<T>(context_.GetPlace());

      d_table->set_height(table_dim[0]);

      auto *d_output_data = d_output->template data<T>();
      auto *d_table_data = d_table_value->template data<T>();

      auto d_output_dims = d_output->dims();
      auto d_output_dims_2d =
          framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
      PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output_dims_2d,
                        platform::errors::InvalidArgument(
                            "ShapeError: The shape of lookup_table@Grad and "
                            "output@Grad should be same. "
                            "But received lookup_table@Grad's shape = [%s], "
                            "output@Grad's shape = [%s].",
                            d_table_value->dims(), d_output_dims_2d));
      memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());

    } else {
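      // Dense gradient: zero the whole table gradient, then accumulate each
      // row of Out@GRAD into the row selected by the corresponding id.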
      auto *d_output = context_.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table = context_.Output<LoDTensor>(framework::GradVarName("W"));
      auto *ids_data = ids.data();

      int64_t N = table_dim[0];
      int64_t D = table_dim[1];

      auto *d_output_data = d_output->template data<T>();
      auto *d_table_data =
          d_table->template mutable_data<T>(context_.GetPlace());

      memset(d_table_data, 0, d_table->numel() * sizeof(T));

      for (int64_t i = 0; i < ids_num; ++i) {
        if (padding_idx != kNoPadding && ids_data[i] == padding_idx) {
          // The gradient w.r.t. padding_idx is zero; the memset above has
          // already zeroed it, so there is nothing to do here.
        } else {
          PADDLE_ENFORCE_LT(
              ids_data[i], N,
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
                  N, ids_data[i]));
          PADDLE_ENFORCE_GE(
              ids_data[i], 0,
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
                  N, ids_data[i]));
          for (int j = 0; j < D; ++j) {
            d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j];
          }
        }
      }
    }
  }

 private:
  const framework::ExecutionContext &context_;
  const Tensor *ids_t_;
};

template <typename T>
class LookupTableV2GradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const auto *ids = context.Input<Tensor>("Ids");
    LookupTableV2GradCPUFunctor<T> functor(context, ids);
    framework::VisitIntDataType(ids->type(), functor);
  }
};
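
// A minimal sketch of how these kernels are typically registered (the actual
// registration lives in lookup_table_v2_op.cc, not in this header; the type
// list here is illustrative only):
//
//   namespace ops = paddle::operators;
//   REGISTER_OP_CPU_KERNEL(lookup_table_v2,
//                          ops::LookupTableV2Kernel<float>,
//                          ops::LookupTableV2Kernel<double>);
//   REGISTER_OP_CPU_KERNEL(lookup_table_v2_grad,
//                          ops::LookupTableV2GradKernel<float>,
//                          ops::LookupTableV2GradKernel<double>);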

}  // namespace operators
}  // namespace paddle