// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/embedding_grad_kernel.h"

#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"

namespace phi {

template <typename T, typename Context>
void EmbeddingGradKernel(const Context& ctx,
                         const DenseTensor& input,
                         const DenseTensor& weight,
                         const DenseTensor& out_grad,
                         int64_t padding_idx,
                         DenseTensor* weight_grad) {
  DDim table_dim = weight.dims();

  auto ids_t = &input;
  auto d_output_t = &out_grad;
  auto d_table_t = weight_grad;

  int64_t ids_numel = ids_t->numel();
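  // The XPU kernel below takes 32-bit shape arguments, so the total number
  // of ids must fit in an int32_t.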
  PADDLE_ENFORCE_EQ(
      ids_numel <= std::numeric_limits<int32_t>::max(),
      true,
      phi::errors::OutOfRange(
          "Number of ids is greater than int32_t::max; please check the "
          "number of ids in LookupTableV2GradXPUKernel."));

  auto& dev_ctx = ctx;
  xpu::ctx_guard RAII_GUARD(ctx.x_context());
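  // Normalize the ids to int64 so a single xpu::embedding_grad<T, int64_t>
  // instantiation handles both id dtypes; int32 ids are cast into a
  // temporary buffer that the RAII guard frees when the kernel returns.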
  const int64_t* ids_data;
  if (ids_t->dtype() == phi::DataType::INT64) {
    ids_data = ids_t->data<int64_t>();
  } else {
    int64_t* ids_tt = RAII_GUARD.alloc_l3_or_gm<int64_t>(ids_t->numel());
    int r = xpu::cast<int32_t, int64_t>(
        ctx.x_context(), ids_t->data<int>(), ids_tt, ids_t->numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
    ids_data = ids_tt;
  }

  const T* d_output_data = d_output_t->data<T>();
  T* d_table_data = dev_ctx.template Alloc<T>(d_table_t);
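  // xm: number of rows in the embedding table, n: embedding width,
  // ym: number of ids being looked up.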
  int xm = d_table_t->dims()[0];
  int ym = static_cast<int>(ids_numel);
  int n = d_table_t->dims()[1];

  int r = xpu::embedding_grad<T, int64_t>(dev_ctx.x_context(),
                                          d_output_data,
                                          ids_data,
                                          d_table_data,
                                          xm,
                                          n,
                                          ym,
                                          padding_idx);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_grad");
}

template <typename T, typename Context>
void EmbeddingSparseGradKernel(const Context& ctx,
                               const DenseTensor& input,
                               const DenseTensor& weight,
                               const DenseTensor& out_grad,
                               int64_t padding_idx,
                               SelectedRows* weight_grad) {
  DDim table_dim = weight.dims();
  auto xpu_place = ctx.GetPlace();

  xpu::ctx_guard RAII_GUARD(ctx.x_context());
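  // The sparse gradient is assembled on the host: SelectedRows stores its
  // row indices as a host-side vector, so the ids must be brought to CPU.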
  std::vector<int64_t> ids;
  DenseTensor ids_cpu;
  ids_cpu.Resize(input.dims());
  ctx.template HostAlloc(
      &ids_cpu, input.dtype(), input.numel() * sizeof(int64_t));
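  // int64 ids are copied to the host directly; int32 ids are first cast to
  // int64 on the device and then copied.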
  if (input.dtype() == phi::DataType::INT64) {
    phi::Copy(ctx, input, CPUPlace(), false, &ids_cpu);

    ids = CopyIdsToVector<int64_t, int64_t>(ids_cpu);

  } else if (input.dtype() == phi::DataType::INT32) {
    int64_t* id_t = RAII_GUARD.alloc_l3_or_gm<int64_t>(input.numel());
    int r = xpu::cast<int32_t, int64_t>(
        ctx.x_context(), input.data<int>(), id_t, input.numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
    paddle::memory::Copy(CPUPlace(),
                         ids_cpu.data(),
                         input.place(),
                         id_t,
                         sizeof(int64_t) * input.numel());
    ids = CopyIdsToVector<int, int64_t>(ids_cpu);
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "embedding input only supports int32 and int64"));
  }

  auto ids_num = static_cast<int64_t>(input.numel());
  // Since paddings are not trainable and are fixed in the forward pass,
  // their gradient is meaningless and is not computed in the backward pass.
  auto* d_table = weight_grad;
  auto* d_output = &out_grad;
  d_table->set_rows(ids);

  auto* d_table_value = d_table->mutable_value();
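  // One gradient row per looked-up id (duplicates included): the value
  // tensor is shaped [ids_num, embedding_width].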
  d_table_value->Resize({ids_num, table_dim[1]});

  ctx.template HostAlloc<T>(d_table_value);

  d_table->set_height(table_dim[0]);

  auto* d_output_data = d_output->template data<T>();
  auto* d_table_data = d_table_value->template data<T>();

  auto d_output_dims = d_output->dims();
  auto d_output_dims_2d =
      flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
  PADDLE_ENFORCE_EQ(d_table_value->dims(),
                    d_output_dims_2d,
                    phi::errors::InvalidArgument(
                        "ShapeError: The shape of lookup_table@Grad and "
                        "output@Grad should be the same. "
                        "But received lookup_table@Grad's shape = [%s], "
                        "output@Grad's shape = [%s].",
                        d_table_value->dims(),
                        d_output_dims_2d));

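  // Copy the upstream gradient from the XPU device into the host-side value
  // tensor; the shapes match exactly, so a single flat memcpy suffices.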
  paddle::memory::Copy(CPUPlace(),
                       d_table_data,
                       xpu_place,
                       d_output_data,
                       d_output->numel() * sizeof(T));
}
}  // namespace phi

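// Only float is registered for the dense and sparse XPU embedding gradient
// kernels.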
PD_REGISTER_KERNEL(
    embedding_grad, XPU, ALL_LAYOUT, phi::EmbeddingGradKernel, float) {}
PD_REGISTER_KERNEL(embedding_sparse_grad,
                   XPU,
                   ALL_LAYOUT,
                   phi::EmbeddingSparseGradKernel,
                   float) {}