// lookup_table_v2_op.cu
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/lookup_table_v2_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

// Forward embedding lookup: output[k, :] = table[ids[k], :] for k in [0, K),
// where every row has D elements.
//
// Expected launch: a 1-D grid of GridDimX blocks of BlockDimX x BlockDimY
// threads (the template arguments must match the launch configuration).
// threadIdx.x walks the D columns of a row with stride BlockDimX, while
// (blockIdx.x, threadIdx.y) pick the starting row, which then advances by
// BlockDimY * GridDimX until all K rows are processed.
//
// When PaddingFlag is true, rows whose id equals padding_idx are filled with
// zeros instead of being copied from the table.
//
// NOTE(review): N is not read, so ids are not bounds-checked here; callers
// presumably guarantee 0 <= id < N -- confirm.
template <typename T, typename IdT, int BlockDimX, int BlockDimY, int GridDimX,
          bool PaddingFlag>
__global__ void LookupTableV2(T *output, const T *table, const IdT *ids,
                              const int64_t N, const int64_t K, const int64_t D,
                              const int64_t padding_idx) {
  const int idx = threadIdx.x;
  // 64-bit row index: with `int`, `idy < K` against an int64_t K and the
  // `idy += stride` update overflow once K approaches INT_MAX.
  int64_t idy = static_cast<int64_t>(blockIdx.x) +
                static_cast<int64_t>(threadIdx.y) * GridDimX;
  const int64_t row_stride = static_cast<int64_t>(BlockDimY) * GridDimX;

  while (idy < K) {
    auto id = static_cast<int64_t>(ids[idy]);
    T *out = output + idy * D;
    // The padding decision depends only on the row, so make it once per row
    // instead of once per element (PaddingFlag is a compile-time constant).
    if (PaddingFlag && id == padding_idx) {
      for (int64_t i = idx; i < D; i += BlockDimX) {
        out[i] = static_cast<T>(0);
      }
    } else {
      const T *tab = table + id * D;
      for (int64_t i = idx; i < D; i += BlockDimX) {
        out[i] = tab[i];
      }
    }
    idy += row_stride;
  }
}

// Dense backward of the embedding lookup: for every k in [0, K),
// table[ids[k], :] += output[k, :], where every row has D elements. Here
// `table` is the gradient buffer being accumulated into and `output` is the
// gradient of the forward output.
//
// Expected launch: a 1-D grid of GridDimX blocks of BlockDimX x BlockDimY
// threads (the template arguments must match the launch configuration), with
// the same row/column mapping as the forward kernel.
//
// Additions go through CudaAtomicAdd because the same id may occur at several
// positions of `ids`, so different threads can target the same table row.
// N is not read (no bounds check on ids).
template <typename T, typename IdT, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableV2Grad(T *table, const T *output, const IdT *ids,
                                  const int64_t N, const int64_t K,
                                  const int64_t D) {
  const int idx = threadIdx.x;
  // 64-bit row index so the loop cannot overflow when K approaches INT_MAX.
  int64_t idy = static_cast<int64_t>(blockIdx.x) +
                static_cast<int64_t>(threadIdx.y) * GridDimX;
  const int64_t row_stride = static_cast<int64_t>(BlockDimY) * GridDimX;

  while (idy < K) {
    auto id = static_cast<int64_t>(ids[idy]);
    const T *out = output + idy * D;
    T *tab = table + id * D;
    for (int64_t i = idx; i < D; i += BlockDimX) {
      paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
    }
    idy += row_stride;
  }
}

// Launches the forward lookup kernel for one concrete id type. Instances are
// handed to framework::VisitIntDataType, which invokes apply<IdT>() with the
// C++ integer type matching the dtype of the ids tensor.
template <typename T>
struct LookupTableV2CUDAFunctor {
  LookupTableV2CUDAFunctor(const framework::ExecutionContext &context,
                           const framework::Tensor *ids_t)
      : context_(context), ids_t_(ids_t) {}

  // Gathers rows of input "W" selected by the ids into output "Out".
  // A padding_idx of -1 selects the kernel variant without padding handling.
  template <typename IdT>
  void apply() {
    // Launch geometry; the same constants feed both the dim3 configuration
    // and the kernel's template arguments so they cannot drift apart.
    constexpr int kBlockX = 256;
    constexpr int kBlockY = 4;
    constexpr int kGridX = 80;

    const auto *weight = context_.Input<framework::Tensor>("W");
    auto *out = context_.Output<framework::Tensor>("Out");
    const int64_t pad_id = context_.Attr<int64_t>("padding_idx");

    const auto &weight_dims = weight->dims();
    const size_t vocab = weight_dims[0];
    const size_t width = weight_dims[1];
    const size_t num_ids = ids_t_->numel();

    const dim3 block_shape(kBlockX, kBlockY);
    const dim3 grid_shape(kGridX, 1);

    const auto *weight_data = weight->template data<T>();
    const auto *id_data = ids_t_->template data<IdT>();
    auto *out_data = out->template mutable_data<T>(context_.GetPlace());
    auto stream = context_.cuda_device_context().stream();

    const bool use_padding = (pad_id != -1);
    if (!use_padding) {
      LookupTableV2<T, IdT, kBlockX, kBlockY, kGridX, false>
          <<<grid_shape, block_shape, 0, stream>>>(
              out_data, weight_data, id_data, vocab, num_ids, width, pad_id);
    } else {
      LookupTableV2<T, IdT, kBlockX, kBlockY, kGridX, true>
          <<<grid_shape, block_shape, 0, stream>>>(
              out_data, weight_data, id_data, vocab, num_ids, width, pad_id);
    }
  }

 private:
  const framework::ExecutionContext &context_;
  const framework::Tensor *ids_t_;
};

// Forward op kernel: reads the "Ids" tensor's dtype and lets
// framework::VisitIntDataType run LookupTableV2CUDAFunctor<T>::apply<IdT>()
// for the matching integer type.
template <typename T>
class LookupTableV2CUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const auto *ids = context.Input<framework::Tensor>("Ids");
    const auto ids_type = framework::TransToProtoVarType(ids->dtype());
    LookupTableV2CUDAFunctor<T> launcher(context, ids);
    framework::VisitIntDataType(ids_type, launcher);
  }
};

// Element-wise cast of K ids: out_ids[i] = static_cast<OutT>(in_ids[i]) for
// i in [0, K).
//
// Uses a grid-stride loop over the fully flattened (x, y, z) thread index, so
// it is correct for any grid/block shape and every element is written by
// exactly one thread. The loop counter is 64-bit to match K.
template <typename InT, typename OutT>
__global__ void InputTypeConvert(const InT *in_ids, const int64_t K,
                                 OutT *out_ids) {
  const int64_t threads_per_block =
      static_cast<int64_t>(blockDim.x) * blockDim.y * blockDim.z;
  const int64_t thread_in_block =
      (static_cast<int64_t>(threadIdx.z) * blockDim.y + threadIdx.y) *
          blockDim.x +
      threadIdx.x;
  const int64_t block_in_grid =
      (static_cast<int64_t>(blockIdx.z) * gridDim.y + blockIdx.y) * gridDim.x +
      blockIdx.x;
  const int64_t stride =
      threads_per_block * gridDim.x * gridDim.y * gridDim.z;

  for (int64_t i = block_in_grid * threads_per_block + thread_in_block; i < K;
       i += stride) {
    out_ids[i] = static_cast<OutT>(in_ids[i]);
  }
}

// Launches the backward computation for one concrete id type. Instances are
// handed to framework::VisitIntDataType, which invokes apply<IdT>() with the
// C++ integer type matching the dtype of the ids tensor.
//
// The "is_sparse" attribute selects the gradient layout:
//  * sparse: the W gradient is a SelectedRows whose rows are the ids (as
//    int64) and whose value is a copy of the Out gradient, shaped
//    [ids_num, W.dims()[1]];
//  * dense:  the W gradient is a Tensor that is zero-filled and then
//    accumulated into by LookupTableV2Grad.
template <typename T>
struct LookupTableV2GradCUDAFunctor {
  LookupTableV2GradCUDAFunctor(const framework::ExecutionContext &context,
                               const framework::Tensor *ids_t)
      : context_(context), ids_t_(ids_t) {}

  template <typename IdT>
  void apply() {
    auto &dev_ctx =
        context_.template device_context<platform::CUDADeviceContext>();
    bool is_sparse = context_.Attr<bool>("is_sparse");

    // Since paddings are not trainable and fixed in forward, the gradient of
    // paddings makes no sense and we don't deal with it in backward.
    // NOTE(review): padding_idx is not read here, so ids equal to it are not
    // filtered; in the dense branch the padding row still accumulates
    // gradient. Confirm this is the intended behavior.
    if (is_sparse) {
      auto *table = context_.Input<framework::Tensor>("W");
      auto *d_output =
          context_.Input<framework::Tensor>(framework::GradVarName("Out"));
      auto *d_table =
          context_.Output<pten::SelectedRows>(framework::GradVarName("W"));

      const auto *ids_data = ids_t_->template data<IdT>();
      int64_t ids_num = ids_t_->numel();

      dim3 threads(128, 8);
      dim3 grids(8, 1);
      auto stream = dev_ctx.stream();
      // Row indices of the SelectedRows gradient: one int64 entry per id.
      framework::Vector<int64_t> new_rows;
      new_rows.resize(ids_num);
      auto gpu_place = context_.GetPlace();

      if (!std::is_same<IdT, int64_t>::value) {
        // Narrower id types must be widened element-wise on the device.
        InputTypeConvert<<<grids, threads, 0, stream>>>(
            ids_data, ids_num, new_rows.MutableData(gpu_place));
      } else {
        // int64 ids already match the row-index type; copy the bytes.
        memory::Copy(gpu_place, new_rows.CUDAMutableData(gpu_place), gpu_place,
                     ids_data, ids_num * sizeof(int64_t), stream);
      }

      d_table->set_rows(new_rows);

      auto *d_table_value = d_table->mutable_value();
      d_table_value->Resize({ids_num, table->dims()[1]});
      d_table_value->template mutable_data<T>(gpu_place);

      auto *d_table_data = d_table_value->template data<T>();
      auto *d_output_data = d_output->template data<T>();
      auto d_output_dims = d_output->dims();
      auto d_output_dims_2d =
          pten::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
      PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output_dims_2d,
                        platform::errors::InvalidArgument(
                            "ShapeError: The shape of lookup_table@Grad and "
                            "output@Grad should be same. "
                            "But received lookup_table@Grad's shape = [%s], "
                            "output@Grad's shape = [%s].",
                            d_table_value->dims(), d_output_dims_2d));
      // The sparse gradient value is the Out gradient itself, one row per id.
      memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
                   d_output->numel() * sizeof(T), stream);

    } else {
      auto d_output_t =
          context_.Input<framework::Tensor>(framework::GradVarName("Out"));
      auto d_table_t =
          context_.Output<framework::Tensor>(framework::GradVarName("W"));

      // Keep the sizes 64-bit: the kernel takes int64_t, and narrowing them to
      // `int` would truncate K once the ids tensor exceeds INT_MAX elements.
      const int64_t N = d_table_t->dims()[0];
      const int64_t D = d_table_t->dims()[1];
      const int64_t K = ids_t_->numel();

      dim3 threads(128, 8);
      dim3 grids(8, 1);
      const T *d_output = d_output_t->template data<T>();
      const auto *ids = ids_t_->template data<IdT>();
      T *d_table = d_table_t->mutable_data<T>(context_.GetPlace());

      // The kernel accumulates with atomic adds, so the gradient buffer must
      // start from zero.
      auto t = framework::EigenVector<T>::Flatten(*d_table_t);
      t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));

      // Template arguments (128, 8, 8) must match threads/grids above.
      LookupTableV2Grad<T, IdT, 128, 8,
                        8><<<grids, threads, 0, dev_ctx.stream()>>>(
          d_table, d_output, ids, N, K, D);
    }
  }

 private:
  const framework::ExecutionContext &context_;
  const framework::Tensor *ids_t_;
};

// Backward op kernel: reads the "Ids" tensor's dtype and lets
// framework::VisitIntDataType run LookupTableV2GradCUDAFunctor<T>::apply<IdT>()
// for the matching integer type.
template <typename T>
class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const auto *ids = context.Input<framework::Tensor>("Ids");
    const auto ids_type = framework::TransToProtoVarType(ids->dtype());
    LookupTableV2GradCUDAFunctor<T> launcher(context, ids);
    framework::VisitIntDataType(ids_type, launcher);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
// Register the forward kernel for float, double and float16 tables. The id
// dtype is not part of the registration: each kernel's Compute() dispatches
// on it at runtime via framework::VisitIntDataType.
REGISTER_OP_CUDA_KERNEL(lookup_table_v2, ops::LookupTableV2CUDAKernel<float>,
                        ops::LookupTableV2CUDAKernel<double>,
                        ops::LookupTableV2CUDAKernel<plat::float16>);
// Register the backward kernel for the same three table dtypes.
REGISTER_OP_CUDA_KERNEL(lookup_table_v2_grad,
                        ops::LookupTableV2GradCUDAKernel<float>,
                        ops::LookupTableV2GradCUDAKernel<double>,
                        ops::LookupTableV2GradCUDAKernel<plat::float16>);