/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <cstring>
#include <string>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/pten/kernels/funcs/blas/blas.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = pten::SelectedRows;
using DDim = framework::DDim;

constexpr int64_t kNoPadding = -1;

37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
template <typename InT, typename OutT>
static std::vector<OutT> CopyIdsToVector(const Tensor &ids) {
  auto numel = ids.numel();
  const auto *src = ids.data<InT>();
  std::vector<OutT> ret(numel);
  if (std::is_same<InT, OutT>::value) {
    std::memcpy(ret.data(), src, numel * sizeof(InT));
  } else {
    for (decltype(numel) i = 0; i < numel; ++i) {
      ret[i] = src[i];
    }
  }
  return ret;
}

52
template <typename T>
53 54 55 56
struct LookupTableV2CPUFunctor {
  LookupTableV2CPUFunctor(const framework::ExecutionContext &context,
                          const Tensor *ids_t)
      : context_(context), ids_t_(ids_t) {}
57

58 59 60 61
  template <typename IdT>
  void apply() {
    auto *output_t = context_.Output<LoDTensor>("Out");  // float tensor
    auto *table_var = context_.InputVar("W");
62

63
    int64_t padding_idx = context_.Attr<int64_t>("padding_idx");
64

65 66
    auto ids = CopyIdsToVector<IdT, int64_t>(*ids_t_);
    auto ids_numel = static_cast<int64_t>(ids.size());
T
tangwei12 已提交
67

68 69 70 71
    if (table_var->template IsType<LoDTensor>()) {
      const auto &table_t = table_var->template Get<LoDTensor>();
      int64_t row_number = table_t.dims()[0];
      int64_t row_width = table_t.dims()[1];
T
tangwei12 已提交
72

73 74
      auto *table = table_t.template data<T>();
      auto *output = output_t->template mutable_data<T>(context_.GetPlace());
T
tangwei12 已提交
75 76 77 78 79 80 81

      for (int64_t i = 0; i < ids_numel; ++i) {
        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
          memset(output + i * row_width, 0, row_width * sizeof(T));
        } else {
          PADDLE_ENFORCE_LT(
              ids[i], row_number,
82 83 84 85 86
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
                  row_number, ids[i]));
T
tangwei12 已提交
87 88
          PADDLE_ENFORCE_GE(
              ids[i], 0,
89 90 91 92 93
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
                  row_number, ids[i]));
T
tangwei12 已提交
94 95
          memcpy(output + i * row_width, table + ids[i] * row_width,
                 row_width * sizeof(T));
96
        }
T
tangwei12 已提交
97
      }
98 99
    } else if (table_var->template IsType<pten::SelectedRows>()) {
      const auto &table_t = table_var->template Get<pten::SelectedRows>();
T
tangwei12 已提交
100
      int64_t row_width = table_t.value().dims()[1];
101 102
      const auto *table = table_t.value().template data<T>();
      auto *output = output_t->template mutable_data<T>(context_.GetPlace());
103 104
      auto input_data_type =
          framework::TransToProtoVarType(table_t.value().dtype());
T
tangwei12 已提交
105 106 107 108 109 110 111

      for (int64_t i = 0; i < ids_numel; ++i) {
        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
          memset(output + i * row_width, 0, row_width * sizeof(T));
        } else {
          PADDLE_ENFORCE_GE(
              ids[i], 0,
112 113 114 115
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0. But received %ld",
                  ids[i]));
T
tangwei12 已提交
116
          auto id_index = table_t.Index(ids[i]);
117 118 119 120 121
          PADDLE_ENFORCE_GE(
              id_index, 0,
              platform::errors::InvalidArgument(
                  "the input key should be exists. But received %d.",
                  id_index));
122 123 124 125 126

          if (input_data_type == framework::proto::VarType::BF16) {
            memcpy(output + i * row_width, table + id_index * row_width,
                   row_width * sizeof(T));
          } else {
127 128
            auto blas =
                pten::funcs::GetBlas<platform::CPUDeviceContext, T>(context_);
129 130 131
            blas.VCOPY(row_width, table + id_index * row_width,
                       output + i * row_width);
          }
132 133 134 135
        }
      }
    }
  }
136 137 138 139

 private:
  const framework::ExecutionContext &context_;
  const Tensor *ids_t_;
140 141 142
};

template <typename T>
143
class LookupTableV2Kernel : public framework::OpKernel<T> {
144 145
 public:
  void Compute(const framework::ExecutionContext &context) const override {
146 147
    const auto *ids = context.Input<Tensor>("Ids");
    LookupTableV2CPUFunctor<T> functor(context, ids);
148 149
    framework::VisitIntDataType(framework::TransToProtoVarType(ids->dtype()),
                                functor);
150 151 152 153 154 155 156 157 158 159 160 161
  }
};

template <typename T>
struct LookupTableV2GradCPUFunctor {
  LookupTableV2GradCPUFunctor(const framework::ExecutionContext &context,
                              const Tensor *ids_t)
      : context_(context), ids_t_(ids_t) {}

  template <typename IdT>
  void apply() {
    auto *table_var = context_.InputVar("W");
162
    DDim table_dim;
163 164 165 166
    if (table_var->template IsType<LoDTensor>()) {
      table_dim = context_.Input<LoDTensor>("W")->dims();
    } else if (table_var->template IsType<pten::SelectedRows>()) {
      auto *table_t = context_.Input<pten::SelectedRows>("W");
167 168
      table_dim = table_t->value().dims();
    } else {
169
      PADDLE_THROW(platform::errors::InvalidArgument(
170
          "The parameter W of a LookupTableV2 "
171
          "must be either LoDTensor or SelectedRows"));
172 173
    }

174 175 176 177 178 179
    int64_t padding_idx = context_.Attr<int64_t>("padding_idx");
    bool is_sparse = context_.Attr<bool>("is_sparse");

    auto ids = CopyIdsToVector<IdT, int64_t>(*ids_t_);
    auto ids_num = static_cast<int64_t>(ids.size());

180 181 182
    // Since paddings are not trainable and fixed in forward, the gradient of
    // paddings makes no sense and we don't deal with it in backward.
    if (is_sparse) {
183
      auto *d_output = context_.Input<LoDTensor>(framework::GradVarName("Out"));
184
      auto *d_table =
185
          context_.Output<pten::SelectedRows>(framework::GradVarName("W"));
186

T
tangwei12 已提交
187
      d_table->set_rows(ids);
188 189 190 191

      auto *d_table_value = d_table->mutable_value();
      d_table_value->Resize({ids_num, table_dim[1]});

192
      d_table_value->template mutable_data<T>(context_.GetPlace());
193 194 195

      d_table->set_height(table_dim[0]);

196 197
      auto *d_output_data = d_output->template data<T>();
      auto *d_table_data = d_table_value->template data<T>();
198 199

      auto d_output_dims = d_output->dims();
200
      auto d_output_dims_2d =
201
          pten::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
202
      PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output_dims_2d,
203 204 205 206 207 208
                        platform::errors::InvalidArgument(
                            "ShapeError: The shape of lookup_table@Grad and "
                            "output@Grad should be same. "
                            "But received lookup_table@Grad's shape = [%s], "
                            "output@Grad's shape = [%s].",
                            d_table_value->dims(), d_output_dims_2d));
209 210 211
      memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());

    } else {
212 213
      auto *d_output = context_.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table = context_.Output<LoDTensor>(framework::GradVarName("W"));
T
tangwei12 已提交
214
      auto *ids_data = ids.data();
215 216 217 218

      int64_t N = table_dim[0];
      int64_t D = table_dim[1];

219 220 221
      auto *d_output_data = d_output->template data<T>();
      auto *d_table_data =
          d_table->template mutable_data<T>(context_.GetPlace());
222 223 224

      memset(d_table_data, 0, d_table->numel() * sizeof(T));

T
tangwei12 已提交
225
      for (int64_t i = 0; i < ids_num; ++i) {
226 227 228 229 230 231
        if (padding_idx != kNoPadding && ids_data[i] == padding_idx) {
          // the gradient of padding_idx should be 0, already done by memset, so
          // do nothing.
        } else {
          PADDLE_ENFORCE_LT(
              ids_data[i], N,
232 233 234 235 236
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
                  N, ids_data[i]));
237 238
          PADDLE_ENFORCE_GE(
              ids_data[i], 0,
239 240 241 242 243
              platform::errors::InvalidArgument(
                  "Variable value (input) of OP(fluid.layers.embedding) "
                  "expected >= 0 and < %ld, but got %ld. Please check input "
                  "value.",
                  N, ids_data[i]));
244 245 246 247 248 249 250
          for (int j = 0; j < D; ++j) {
            d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j];
          }
        }
      }
    }
  }
251 252 253 254 255 256 257 258 259 260 261 262

 private:
  const framework::ExecutionContext &context_;
  const Tensor *ids_t_;
};

// CPU kernel entry point for lookup_table_v2_grad: dispatches on the integral
// dtype of the Ids tensor and delegates to LookupTableV2GradCPUFunctor.
template <typename T>
class LookupTableV2GradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *ids_tensor = context.Input<Tensor>("Ids");
    auto ids_dtype = framework::TransToProtoVarType(ids_tensor->dtype());
    LookupTableV2GradCPUFunctor<T> grad_functor(context, ids_tensor);
    framework::VisitIntDataType(ids_dtype, grad_functor);
  }
};

}  // namespace operators
}  // namespace paddle