rmsprop_kernel_impl.h 11.2 KB
Newer Older
H
hong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <math.h>

19
#include "paddle/phi/common/amp_type_traits.h"
H
hong 已提交
20 21 22
#include "paddle/phi/kernels/funcs/algorithm.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/for_range.h"
23
#include "paddle/phi/kernels/funcs/selected_rows_functor.h"
24
#include "paddle/phi/kernels/rmsprop_kernel.h"
H
hong 已提交
25 26
namespace phi {

27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
// Dense RMSProp update functor (declaration only).
//
// Only a constructor is declared; no definition is visible in this header
// (presumably each backend provides its own specialization -- TODO confirm).
// RmspropDenseKernel in this file constructs one instance and immediately
// lets it go out of scope, so the constructor is expected to carry out the
// whole parameter update.
//
// The constructor's parameter list matches RmspropDenseKernel's one-to-one:
// inputs (param, mean_square, grad, moment, learning_rate, optional
// mean_grad / master_param), scalar hyper-parameters (epsilon, decay,
// momentum), mode flags (centered, multi_precision) and the output tensors.
template <typename T, typename Context>
struct RmsFunctor {
  RmsFunctor(const Context &ctx,
             const DenseTensor &param,
             const DenseTensor &mean_square,
             const DenseTensor &grad,
             const DenseTensor &moment,
             const DenseTensor &learning_rate,
             const paddle::optional<DenseTensor> &mean_grad_opt,
             const paddle::optional<DenseTensor> &master_param,
             float epsilon_t,
             float decay_t,
             float momentum_t,
             bool centered,
             bool multi_precision,
             DenseTensor *param_out,
             DenseTensor *moment_out,
             DenseTensor *mean_square_out,
             DenseTensor *mean_grad_out,
             DenseTensor *master_param_outs);
};

H
hong 已提交
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
// Gradient accessor over a contiguous (dense) gradient buffer.
//
// Stores the raw pointer it is given and hands back the element at the
// requested flat offset unchanged. No bounds checking is performed and the
// pointer is never freed here.
template <typename T>
struct DenseRmspropGradFunctor {
  explicit DenseRmspropGradFunctor(const T *grad_data) : grad_(grad_data) {}

  // Returns the gradient element located at flat offset `idx`.
  HOSTDEVICE inline T operator()(int64_t idx) const { return *(grad_ + idx); }

  const T *grad_;  // Start of the gradient buffer.
};

// Gradient accessor over a row-sparse gradient.
//
// The gradient is described by `row_count_` row ids (`rows_`) and a packed
// value buffer (`grad_`) holding `row_numel_` elements per stored row. A flat
// index into the full (dense) layout is mapped to its row id and in-row
// column; rows that are not found in `rows_` yield a zero gradient.
template <typename T>
struct SparseRmspropGradFunctor {
  SparseRmspropGradFunctor(const T *grad_values,
                           const int64_t *row_ids,
                           int64_t numel_per_row,
                           int64_t num_rows)
      : grad_(grad_values),
        rows_(row_ids),
        row_numel_(numel_per_row),
        row_count_(num_rows) {}

  // Returns the gradient for dense flat offset `idx`, or 0 when the row that
  // contains `idx` has no entry in `rows_`.
  HOSTDEVICE inline T operator()(int64_t idx) const {
    // A negative search result means the dense row is absent from `rows_`.
    const auto packed_row =
        phi::funcs::BinarySearch(rows_, row_count_, idx / row_numel_);
    if (packed_row < 0) {
      return static_cast<T>(0);
    }
    return grad_[packed_row * row_numel_ + idx % row_numel_];
  }

  const T *grad_;        // Packed values, row_numel_ elements per stored row.
  const int64_t *rows_;  // Row ids searched by BinarySearch.
  int64_t row_numel_;    // Number of elements in one row.
  int64_t row_count_;    // Number of entries in rows_.
};

82
template <typename T, typename MT, typename GradFunctor>
H
hong 已提交
83 84
struct UncenteredRmspropFunctor {
  UncenteredRmspropFunctor(T *param,
85 86 87 88 89 90 91
                           MT *ms,
                           MT *mom,
                           const MT *lr,
                           MT *master_p,
                           MT rho,
                           MT epsilon,
                           MT momentum,
H
hong 已提交
92 93 94 95
                           const GradFunctor &grad_functor)
      : param_(param),
        ms_(ms),
        mom_(mom),
96
        master_p_(master_p),
H
hong 已提交
97 98 99 100 101 102 103
        lr_(lr),
        rho_(rho),
        epsilon_(epsilon),
        momentum_(momentum),
        grad_functor_(grad_functor) {}

  HOSTDEVICE inline void operator()(int64_t idx) const {
104 105 106 107 108 109 110 111
    MT g = static_cast<MT>(grad_functor_(idx));
    MT l_rho = static_cast<MT>(1) - rho_;
    MT ms_out = rho_ * ms_[idx] + l_rho * g * g;
    MT mom_out = momentum_ * mom_[idx] +
                 static_cast<MT>(lr_[0]) * g / sqrt(ms_out + epsilon_);
    MT p = master_p_ ? master_p_[idx] : static_cast<MT>(param_[idx]);
    MT p_m = p - mom_out;
    param_[idx] = static_cast<T>(p_m);
H
hong 已提交
112 113
    ms_[idx] = ms_out;
    mom_[idx] = mom_out;
114
    if (master_p_) master_p_[idx] = p_m;
H
hong 已提交
115 116 117
  }

  T *param_;
118 119 120 121 122 123 124
  MT *ms_;
  MT *mom_;
  MT *master_p_;
  const MT *lr_;
  MT rho_;
  MT epsilon_;
  MT momentum_;
H
hong 已提交
125 126 127
  GradFunctor grad_functor_;
};

128
template <typename T, typename MT, typename GradFunctor>
H
hong 已提交
129 130
struct CenteredRmspropFunctor {
  CenteredRmspropFunctor(T *param,
131 132 133 134 135 136 137 138
                         MT *ms,
                         MT *mom,
                         MT *mean_grad,
                         const MT *lr,
                         MT *master_param,
                         MT rho,
                         MT epsilon,
                         MT momentum,
H
hong 已提交
139 140 141 142
                         const GradFunctor &grad_functor)
      : param_(param),
        ms_(ms),
        mom_(mom),
143
        master_p_(master_param),
H
hong 已提交
144 145 146 147 148 149 150 151
        mean_grad_(mean_grad),
        lr_(lr),
        rho_(rho),
        epsilon_(epsilon),
        momentum_(momentum),
        grad_functor_(grad_functor) {}

  HOSTDEVICE inline void operator()(int64_t idx) const {
152 153 154 155 156 157 158 159 160 161 162
    MT g = static_cast<MT>(grad_functor_(idx));
    MT l_rho = static_cast<MT>(1) - rho_;
    MT ms_out = rho_ * ms_[idx] + l_rho * g * g;
    MT mg_out = rho_ * mean_grad_[idx] + l_rho * g;
    MT mom_out =
        momentum_ * mom_[idx] +
        static_cast<MT>(lr_[0]) * g / sqrt(ms_out - mg_out * mg_out + epsilon_);

    MT p = master_p_ ? master_p_[idx] : static_cast<MT>(param_[idx]);
    MT p_m = p - mom_out;
    param_[idx] = static_cast<T>(p_m);
H
hong 已提交
163 164 165
    ms_[idx] = ms_out;
    mom_[idx] = mom_out;
    mean_grad_[idx] = mg_out;
166
    if (master_p_) master_p_[idx] = p_m;
H
hong 已提交
167 168 169
  }

  T *param_;
170 171 172 173 174 175 176 177
  MT *ms_;
  MT *mom_;
  MT *master_p_;
  MT *mean_grad_;
  const MT *lr_;
  MT rho_;
  MT epsilon_;
  MT momentum_;
H
hong 已提交
178 179 180 181 182 183 184 185 186 187
  GradFunctor grad_functor_;
};

template <typename T, typename Context>
void RmspropDenseKernel(const Context &ctx,
                        const DenseTensor &param,
                        const DenseTensor &mean_square,
                        const DenseTensor &grad,
                        const DenseTensor &moment,
                        const DenseTensor &learning_rate,
188
                        const paddle::optional<DenseTensor> &mean_grad_opt,
189
                        const paddle::optional<DenseTensor> &master_param,
H
hong 已提交
190 191 192 193
                        float epsilon_t,
                        float decay_t,
                        float momentum_t,
                        bool centered,
194
                        bool multi_precision,
H
hong 已提交
195 196 197
                        DenseTensor *param_out,
                        DenseTensor *moment_out,
                        DenseTensor *mean_square_out,
198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217
                        DenseTensor *mean_grad_out,
                        DenseTensor *master_param_outs) {
  RmsFunctor<T, Context> functor(ctx,
                                 param,
                                 mean_square,
                                 grad,
                                 moment,
                                 learning_rate,
                                 mean_grad_opt,
                                 master_param,
                                 epsilon_t,
                                 decay_t,
                                 momentum_t,
                                 centered,
                                 multi_precision,
                                 param_out,
                                 moment_out,
                                 mean_square_out,
                                 mean_grad_out,
                                 master_param_outs);
H
hong 已提交
218 219 220 221 222 223 224 225 226
}

template <typename T, typename Context>
void RmspropSparseKernel(const Context &ctx,
                         const DenseTensor &param,
                         const DenseTensor &mean_square,
                         const SelectedRows &grad,
                         const DenseTensor &moment,
                         const DenseTensor &learning_rate,
227
                         const paddle::optional<DenseTensor> &mean_grad_opt,
228
                         const paddle::optional<DenseTensor> &master_param,
H
hong 已提交
229 230 231 232
                         float epsilon_t,
                         float decay_t,
                         float momentum_t,
                         bool centered,
233
                         bool multi_precision,
H
hong 已提交
234 235 236
                         DenseTensor *param_out,
                         DenseTensor *moment_out,
                         DenseTensor *mean_square_out,
237 238 239 240 241 242
                         DenseTensor *mean_grad_out,
                         DenseTensor *master_param_outs) {
  using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
  auto epsilon = static_cast<MPDType>(epsilon_t);
  auto rho = static_cast<MPDType>(decay_t);
  auto momentum = static_cast<MPDType>(momentum_t);
H
hong 已提交
243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265

  auto &p_tensor = param;
  auto &ms_tensor = mean_square;
  auto &lr_tensor = learning_rate;
  auto &mom_tensor = moment;

  PADDLE_ENFORCE_EQ(p_tensor.IsSharedBufferWith(*param_out),
                    true,
                    phi::errors::InvalidArgument(
                        "Param and ParamOut must be the same Tensor"));
  PADDLE_ENFORCE_EQ(mom_tensor.IsSharedBufferWith(*moment_out),
                    true,
                    phi::errors::InvalidArgument(
                        "Moment and MomentOut must be the same Tensor"));
  PADDLE_ENFORCE_EQ(
      ms_tensor.IsSharedBufferWith(*mean_square_out),
      true,
      phi::errors::InvalidArgument(
          "MeanSquare and MeanSquareOut must be the same Tensor"));
  size_t limit = static_cast<size_t>(ms_tensor.numel());

  phi::SelectedRows tmp_merged_grad;
  phi::SelectedRows *merged_grad = &tmp_merged_grad;
266
  phi::funcs::scatter::MergeAdd<Context, T> merge_func;
H
hong 已提交
267 268 269 270
  merge_func(ctx, grad, merged_grad);

  funcs::ForRange<Context> for_range(ctx, limit);
  auto &grad_merge_rows = merged_grad->rows();
H
Huang Jiyi 已提交
271
  phi::MixVector<int64_t> mixv_grad_merge_rows(&grad_merge_rows);
H
hong 已提交
272 273 274 275 276 277 278 279
  const int64_t *rows = mixv_grad_merge_rows.Data(ctx.GetPlace());

  auto &merged_tensor = merged_grad->value();
  int64_t row_count = merged_grad->rows().size();
  int64_t row_numel = merged_tensor.numel() / row_count;
  SparseRmspropGradFunctor<T> grad_func(
      merged_tensor.data<T>(), rows, row_numel, row_count);

280 281 282 283
  MPDType *master_out_data =
      multi_precision ? ctx.template Alloc<MPDType>(master_param_outs)
                      : nullptr;

H
hong 已提交
284 285
  if (centered) {
    auto mg_tensor = mean_grad_opt.get_ptr();
286 287 288 289 290 291 292 293 294 295 296 297 298
    if (mg_tensor) {
      PADDLE_ENFORCE_EQ(
          mg_tensor->Holder(),
          mean_grad_out->Holder(),
          phi::errors::InvalidArgument(
              "MeanGrad and MeanGradOut must be the same Tensor"));
    } else {
      PADDLE_ENFORCE_EQ(
          mg_tensor,
          mean_grad_out,
          phi::errors::InvalidArgument(
              "MeanGrad and MeanGradOut must be the same Tensor"));
    }
H
hong 已提交
299

300
    for_range(CenteredRmspropFunctor<T, MPDType, SparseRmspropGradFunctor<T>>(
H
hong 已提交
301
        ctx.template Alloc<T>(param_out),
302 303 304 305 306
        ctx.template Alloc<MPDType>(mean_square_out),
        ctx.template Alloc<MPDType>(moment_out),
        ctx.template Alloc<MPDType>(mean_grad_out),
        lr_tensor.data<MPDType>(),
        master_out_data,
H
hong 已提交
307 308 309 310 311
        rho,
        epsilon,
        momentum,
        grad_func));
  } else {
312
    for_range(UncenteredRmspropFunctor<T, MPDType, SparseRmspropGradFunctor<T>>(
H
hong 已提交
313
        ctx.template Alloc<T>(param_out),
314 315 316 317
        ctx.template Alloc<MPDType>(mean_square_out),
        ctx.template Alloc<MPDType>(moment_out),
        lr_tensor.data<MPDType>(),
        master_out_data,
H
hong 已提交
318 319 320 321 322 323 324 325
        rho,
        epsilon,
        momentum,
        grad_func));
  }
}

}  // namespace phi