// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/sgd_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_helper.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/mixed_vector.h"

namespace phi {

// Element-wise SGD update for dense tensors:
//   param_out[i] = p - lr * grad[i]
// where lr = learning_rate[0] and p is master_param[i] when master_param is
// non-null, otherwise param[i]. The arithmetic is done in MT; the result is
// cast back to T for param_out and, when master_param_out is non-null, also
// stored there in MT.
//
// Launch: any 1-D grid/block shape. Each thread walks the elements in strides
// of gridDim.x * blockDim.x, so the grid does not have to cover all `num`
// elements. `num` and the element index are 64-bit so tensors with more than
// INT_MAX elements are updated completely (the caller passes numel() as
// int64_t).
template <typename T, typename MT>
__global__ void SGDKernelMT(const T* param,
                            const T* grad,
                            const T* learning_rate,
                            const int64_t num,
                            T* param_out,
                            const MT* master_param,
                            MT* master_param_out) {
  const MT lr = static_cast<MT>(learning_rate[0]);
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < num;
       i += stride) {
    MT p_data = master_param ? master_param[i] : static_cast<MT>(param[i]);
    const MT g_data = static_cast<MT>(grad[i]);
    p_data = p_data - lr * g_data;
    param_out[i] = static_cast<T>(p_data);
    if (master_param_out) {
      master_param_out[i] = p_data;
    }
  }
}

// Sparse SGD update into a dense tensor.
// For each gradient row r in [0, limit), with destination row rows[r]:
//   tensor_out[rows[r], c] += -1 * learning_rate[0] * selected_rows[r, c]
// for every column c in [0, row_numel).
// Blocks step over gradient rows by gridDim.x; inside a block, threads step
// over the columns of the current row by blockDim.x.
template <typename T>
__global__ void SparseSGDFunctorKernel(const T* selected_rows,
                                       const int64_t* rows,
                                       const T* learning_rate,
                                       T* tensor_out,
                                       int64_t row_numel,
                                       int64_t limit) {
  for (int64_t r = blockIdx.x; r < limit; r += gridDim.x) {
    const T* src_row = selected_rows + r * row_numel;
    T* dst_row = tensor_out + rows[r] * row_numel;
    for (int64_t c = threadIdx.x; c < row_numel; c += blockDim.x) {
      const T delta = -static_cast<T>(1.0) * learning_rate[0] * src_row[c];
      // Row ids in `rows` may repeat, so different blocks can target the same
      // output element; the atomic add keeps those updates from racing.
      phi::CudaAtomicAdd(dst_row + c, delta);
    }
  }
}

// Dense SGD on GPU: param_out = param - learning_rate[0] * grad, element-wise,
// launched as SGDKernelMT on dev_ctx's stream with 512-thread blocks.
//
// When multi_precision is true, the update starts from the MPDType tensor
// master_param and the MPDType result is also written to master_param_out;
// param_out receives the result cast back to T. Both master tensors must be
// provided in that case.
//
// NOTE(review): param, grad and param_out are assumed to hold the same number
// of elements; that is not checked here — confirm with the op's InferMeta.
// learning_rate is read with dtype T and only element 0 is used.
template <typename T, typename Context>
void SGDDenseKernel(const Context& dev_ctx,
                    const DenseTensor& param,
                    const DenseTensor& learning_rate,
                    const DenseTensor& grad,
                    const paddle::optional<DenseTensor>& master_param,
                    bool multi_precision,
                    DenseTensor* param_out,
                    DenseTensor* master_param_out) {
  using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;

  const MPDType* master_in_data = nullptr;
  MPDType* master_out_data = nullptr;
  if (multi_precision) {
    // Report a clear error instead of dereferencing an empty optional or a
    // null output pointer.
    PADDLE_ENFORCE_NOT_NULL(
        master_param.get_ptr(),
        phi::errors::InvalidArgument(
            "Input(MasterParam) of SgdOp should not be null when "
            "multi_precision is true."));
    PADDLE_ENFORCE_NOT_NULL(
        master_param_out,
        phi::errors::InvalidArgument(
            "Output(MasterParamOut) of SgdOp should not be null when "
            "multi_precision is true."));
    master_in_data = master_param->data<MPDType>();
    master_out_data = dev_ctx.template Alloc<MPDType>(master_param_out);
  }
  T* param_out_data = dev_ctx.template Alloc<T>(param_out);

  const int64_t numel = param.numel();
  // A launch with zero blocks is an invalid CUDA configuration. An empty
  // parameter needs no update once its outputs are allocated.
  if (numel == 0) {
    return;
  }

  constexpr int kThreadsPerBlock = 512;
  const int grid =
      static_cast<int>((numel + kThreadsPerBlock - 1) / kThreadsPerBlock);

  SGDKernelMT<T, MPDType><<<grid, kThreadsPerBlock, 0, dev_ctx.stream()>>>(
      param.data<T>(),
      grad.data<T>(),
      learning_rate.data<T>(),
      numel,
      param_out_data,
      master_in_data,
      master_out_data);
}

// SGD on GPU for a dense Param and a SelectedRows (sparse) Grad.
// The update is in place: ParamOut must be the very same tensor as Param.
// For every gradient row r, with destination row id grad.rows()[r]:
//   param_out[rows[r], :] -= learning_rate[0] * grad.value()[r, :]
// SparseSGDFunctorKernel applies this with atomic adds, so repeated row ids
// accumulate.
template <typename T, typename Context>
void SGDDenseParamSparseGradKernel(
    const Context& dev_ctx,
    const DenseTensor& param,
    const DenseTensor& learning_rate,
    const SelectedRows& grad,
    const paddle::optional<DenseTensor>& master_param,
    bool multi_precision,
    DenseTensor* param_out,
    DenseTensor* master_param_out) {
  using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
  // NOTE(review): on this path master_in_data is never read by the update and
  // master_param_out is allocated but never written, so MPDType master
  // weights are not updated by a sparse gradient. Confirm this is intended.
  const MPDType* master_in_data =
      multi_precision ? master_param->data<MPDType>() : nullptr;
  MPDType* master_out_data =
      multi_precision ? dev_ctx.template Alloc<MPDType>(master_param_out)
                      : nullptr;
  (void)master_in_data;
  (void)master_out_data;

  PADDLE_ENFORCE_EQ(
      &param,
      param_out,
      phi::errors::InvalidArgument(
          "The input tensor Param of SgdOp should be equal with ParamOut "
          "if variable's type is SelectedRows."));

  auto in_height = grad.height();
  auto out_dims = param_out->dims();
  PADDLE_ENFORCE_EQ(in_height,
                    out_dims[0],
                    phi::errors::InvalidArgument(
                        "The input tensor Grad's height of SgdOp should be "
                        "equal with ParamOut's dims. But received Grad's "
                        "height [%s] and ParamOut's dims [%s]",
                        in_height,
                        out_dims[0]));

  auto& in_value = grad.value();
  auto& in_rows = grad.rows();
  const int64_t num_rows = static_cast<int64_t>(in_rows.size());
  // An empty sparse gradient touches no rows. ParamOut aliases Param (checked
  // above), so there is nothing to do. Returning here also avoids the
  // division by zero below.
  if (num_rows == 0) {
    return;
  }
  PADDLE_ENFORCE_GT(
      in_height,
      0,
      phi::errors::InvalidArgument(
          "The input tensor Grad's height of SgdOp should be greater than 0 "
          "when Grad has rows. But received Grad's height [%s]",
          in_height));

  const int64_t in_row_numel = in_value.numel() / num_rows;
  PADDLE_ENFORCE_EQ(in_row_numel,
                    param_out->numel() / in_height,
                    phi::errors::InvalidArgument(
                        "The in_row_numel of SgdOp should be equal with "
                        "param_out's numel / in_height."));

  auto* in_data = in_value.data<T>();
  auto* out_data = param_out->data<T>();

  const int kThreadsPerBlock = 256;
  const int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
  const int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
  phi::MixVector<int64_t> mixv_in_rows(&in_rows);
  SparseSGDFunctorKernel<<<max_blocks, kThreadsPerBlock, 0, dev_ctx.stream()>>>(
      in_data,
      mixv_in_rows.CUDAData(dev_ctx.GetPlace()),
      learning_rate.data<T>(),
      out_data,
      in_row_numel,
      num_rows);
}

// SGD for a SelectedRows Param with a SelectedRows Grad. This combination is
// not implemented on GPU: every call raises an Unimplemented error. All
// parameters are kept only to match the kernel signature.
template <typename T, typename Context>
void SGDSparseParamSparseGradKernel(
    const Context& dev_ctx,
    const SelectedRows& param,
    const DenseTensor& learning_rate,
    const SelectedRows& grad,
    const paddle::optional<SelectedRows>& master_param,
    bool multi_precision,
    SelectedRows* param_out,
    SelectedRows* master_param_out) {
  PADDLE_THROW(phi::errors::Unimplemented(
      "SgdOp with a SelectedRows Param and a SelectedRows Grad is not "
      "implemented on GPU yet."));
}

}  // namespace phi

// Registers SGDDenseKernel (dense Param, dense Grad) on GPU for float16,
// float and double.
PD_REGISTER_KERNEL(sgd,
                   GPU,
                   ALL_LAYOUT,
                   phi::SGDDenseKernel,
                   phi::dtype::float16,
                   float,
                   double) {
  // Output index 1 is master_param_out (second output of SGDDenseKernel).
  // For the float16 kernel its registered dtype is FLOAT32, not the kernel
  // dtype.
  const bool is_fp16_kernel = kernel_key.dtype() == phi::DataType::FLOAT16;
  if (is_fp16_kernel) {
    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
  }
}

// Registers SGDDenseParamSparseGradKernel (dense Param, SelectedRows Grad) on
// GPU for float16, float and double. Unlike the `sgd` registration, no dtype
// override is applied to output index 1 (master_param_out) for float16.
PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::SGDDenseParamSparseGradKernel,
                   phi::dtype::float16,
                   float,
                   double) {}

// Registers SGDSparseParamSparseGradKernel (SelectedRows Param and Grad) on
// GPU for float16, float and double. That kernel is not implemented: it
// raises an error whenever it is invoked.
PD_REGISTER_KERNEL(sgd_sparse_param_sparse_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::SGDSparseParamSparseGradKernel,
                   phi::dtype::float16,
                   float,
                   double) {}