// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/kernels/adagrad_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

template <typename DeviceContext, typename T>
struct SparseAdagradFunctor {
  void operator()(const DeviceContext& context,
                  const phi::SelectedRows& grad,
                  const DenseTensor& learning_rate,
                  T epsilon,
                  DenseTensor* moment,
                  DenseTensor* param);
};

// Adagrad update driven by a dense gradient.
//
// This header only declares the primary template: operator() has no body
// here, so every DeviceContext that is used must supply a definition or a
// specialization elsewhere. AdagradDenseKernel forwards all of its arguments
// to this functor unchanged.
//
// NOTE(review): `master_param`, `multi_precision` and `master_param_out`
// presumably drive a mixed-precision (master-weight) update. Confirm this
// in the device-specific definitions.
template <typename DeviceContext, typename T>
struct DenseAdagradFunctor {
  void operator()(const DeviceContext& dev_ctx,
                  const DenseTensor& param,
                  const DenseTensor& grad,
                  const DenseTensor& moment,
                  const DenseTensor& lr,
                  const paddle::optional<DenseTensor>& master_param,
                  float epsilon,
                  bool multi_precision,
                  DenseTensor* param_out,
                  DenseTensor* moment_out,
                  DenseTensor* master_param_out);
};

// Returns a new SelectedRows with the same rows and height as `input`, and
// with each element of its value tensor squared.
//
// The value tensor of the result is resized to the input value's dims and
// allocated as T on `context`. The squaring runs as an Eigen expression on
// the context's Eigen device.
template <typename DeviceContext, typename T>
phi::SelectedRows SquareSelectedRows(const DeviceContext& context,
                                     const phi::SelectedRows& input) {
  phi::SelectedRows squared;
  squared.set_rows(input.rows());
  squared.set_height(input.height());

  auto* squared_value = squared.mutable_value();
  const auto& input_value = input.value();
  squared_value->Resize(input_value.dims());
  context.template Alloc<T>(squared_value);

  // Both tensors are viewed as flat 1-D vectors, so the element-wise square
  // does not depend on their rank.
  auto dst = EigenVector<T>::Flatten(*squared_value);
  auto src = EigenVector<T>::Flatten(input_value);
  dst.device(*context.eigen_device()) = src.square();
  return squared;
}

template <typename T, typename Context>
void AdagradDenseKernel(const Context& ctx,
                        const DenseTensor& param_t,
                        const DenseTensor& grad_t,
                        const DenseTensor& moment_t,
                        const DenseTensor& learning_rate,
68
                        const paddle::optional<DenseTensor>& master_param,
H
hong 已提交
69
                        float epsilon_t,
70
                        bool multi_precision,
H
hong 已提交
71
                        DenseTensor* param_out_tensor,
72 73 74 75 76 77 78 79 80 81 82 83 84 85
                        DenseTensor* moment_out_tensor,
                        DenseTensor* master_param_outs) {
  DenseAdagradFunctor<Context, T> functor;
  functor(ctx,
          param_t,
          grad_t,
          moment_t,
          learning_rate,
          master_param,
          epsilon_t,
          multi_precision,
          param_out_tensor,
          moment_out_tensor,
          master_param_outs);
H
hong 已提交
86 87 88 89 90 91 92 93
}

template <typename T, typename Context>
void AdagradSparseKernel(const Context& ctx,
                         const DenseTensor& param_t,
                         const SelectedRows& grad_t,
                         const DenseTensor& moment_t,
                         const DenseTensor& learning_rate,
94
                         const paddle::optional<DenseTensor>& master_param,
H
hong 已提交
95
                         float epsilon_t,
96
                         bool multi_precision,
H
hong 已提交
97
                         DenseTensor* param_out,
98 99
                         DenseTensor* moment_out,
                         DenseTensor* master_param_outs) {
H
hong 已提交
100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
  auto* param_out_tensor = param_out;
  auto* moment_out_tensor = moment_out;

  ctx.template Alloc<T>(param_out_tensor);
  ctx.template Alloc<T>(moment_out_tensor);

  T epsilon = static_cast<T>(epsilon_t);

  auto* param_tensor = &param_t;
  PADDLE_ENFORCE_EQ(param_tensor,
                    param_out_tensor,
                    phi::errors::InvalidArgument(
                        "the input tensor not euqal with output tensor"));

  auto* moment_tensor = &moment_t;
  PADDLE_ENFORCE_EQ(moment_tensor,
                    moment_out_tensor,
                    phi::errors::InvalidArgument(
                        "the input moment not eual with output moment"));

  SparseAdagradFunctor<Context, T> functor;
  functor(
      ctx, grad_t, learning_rate, epsilon, moment_out_tensor, param_out_tensor);
}

}  // namespace phi