/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/momentum_op.h"

namespace paddle {
namespace operators {

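// Dense momentum update. Each thread walks the flattened tensors with a
// grid-stride loop and applies
//   v_out = mu * v + g
//   p_out = p - lr * v_out             (plain momentum)
//   p_out = p - lr * (g + mu * v_out)  (Nesterov momentum)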
template <typename T>
__global__ void MomentumKernel(const T* p, const T* g, const T* v,
                               const T* learning_rate, const T mu,
                               const int64_t num, bool use_nesterov, T* p_out,
                               T* v_out) {
  T lr = learning_rate[0];
  if (use_nesterov) {
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
         i += blockDim.x * gridDim.x) {
      T g_val = g[i];
      T v_new = v[i] * mu + g_val;
      v_out[i] = v_new;
      p_out[i] = p[i] - (g_val + v_new * mu) * lr;
    }
  } else {
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
         i += blockDim.x * gridDim.x) {
      T v_new = v[i] * mu + g[i];
      v_out[i] = v_new;
      p_out[i] = p[i] - lr * v_new;
    }
  }
}

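// Sparse momentum update for a SelectedRows gradient. Block i processes the
// i-th stored gradient row; grad_rows[i] maps it to the matching row of the
// dense parameter, so only rows that actually received gradients are updated.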
template <typename T>
__global__ void SparseMomentumKernel(const T* p, const T* g, const T* v,
                                     const T* lr, const T mu,
                                     const int64_t* grad_rows,
                                     const size_t grad_row_numel,
                                     const size_t grad_row_size,
                                     const bool use_nesterov, T* p_out,
                                     T* v_out) {
  for (int i = blockIdx.x; i < grad_row_size; i += gridDim.x) {
    for (int j = threadIdx.x; j < grad_row_numel; j += blockDim.x) {
      size_t p_i = grad_rows[i] * grad_row_numel + j;
      size_t g_i = i * grad_row_numel + j;
      v_out[g_i] = v[g_i] * mu + g[g_i];
      if (use_nesterov) {
        p_out[p_i] = p[p_i] - (g[g_i] + v_out[g_i] * mu) * lr[0];
      } else {
        p_out[p_i] = p[p_i] - v_out[g_i] * lr[0];
      }
    }
  }
}

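// CUDA kernel for the momentum operator. Dispatches to the dense or the
// sparse kernel depending on whether Grad is a LoDTensor or a SelectedRows.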
template <typename T>
class MomentumOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    T mu = static_cast<T>(ctx.Attr<float>("mu"));
    bool use_nesterov = ctx.Attr<bool>("use_nesterov");

    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
    auto param = ctx.Input<framework::Tensor>("Param");
    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
    auto* velocity_var = ctx.InputVar("Velocity");
    auto* grad_var = ctx.InputVar("Grad");

    if (grad_var->IsType<framework::LoDTensor>()) {
      PADDLE_ENFORCE(velocity_var->IsType<framework::LoDTensor>(),
                     "Unmatched Type of Velocity and Grad");
      auto velocity = ctx.Input<framework::Tensor>("Velocity");
      auto grad = ctx.Input<framework::Tensor>("Grad");
      auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
      T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
      T* v_out = velocity_out->mutable_data<T>(ctx.GetPlace());
      auto* p = param->data<T>();
      auto* v = velocity->data<T>();
      auto* g = grad->data<T>();
      auto* lr = learning_rate->data<T>();

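      // Launch roughly one thread per parameter element; the grid-stride
      // loop inside MomentumKernel covers any remainder.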
      const int kThreadPerBlock = 256;
      int grid = (param->numel() + kThreadPerBlock - 1) / kThreadPerBlock;
      MomentumKernel<
          T><<<grid, kThreadPerBlock, 0, ctx.cuda_device_context().stream()>>>(
          p, g, v, lr, mu, param->numel(), use_nesterov, p_out, v_out);
    } else if (grad_var->IsType<framework::SelectedRows>()) {
      // Sparse update: the gradient is stored as SelectedRows (e.g. produced
      // by an embedding lookup).
      PADDLE_ENFORCE(velocity_var->IsType<framework::SelectedRows>(),
                     "Unmatched Type of Velocity and Grad");
      auto velocity = ctx.Input<framework::SelectedRows>("Velocity");
      auto grad = ctx.Input<framework::SelectedRows>("Grad");
      auto velocity_out = ctx.Output<framework::SelectedRows>("VelocityOut");

      // A sparse update may carry no rows; in that case there is nothing to do.
      if (grad->rows().size() == 0) {
        return;
      }
      PADDLE_ENFORCE(grad->height() == velocity->height(),
                     "Unmatched gradient and velocity.");
      auto* p_out = param_out->mutable_data<T>(ctx.GetPlace());
      auto* v_out =
          velocity_out->mutable_value()->mutable_data<T>(ctx.GetPlace());
      auto* lr = learning_rate->data<T>();
      auto* p = param->data<T>();
      auto* g = grad->value().data<T>();
      auto* v = velocity->value().data<T>();
      size_t grad_row_numel = grad->value().numel() / grad->rows().size();
      size_t grad_row_size = grad->rows().size();
      framework::Vector<int64_t> rows(grad->rows());

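      // The grid is sized from the full parameter, as in the dense path;
      // SparseMomentumKernel itself only iterates over grad_row_size rows.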
      const int kThreadPerBlock = 256;
      int grid = (param->numel() + kThreadPerBlock - 1) / kThreadPerBlock;
      SparseMomentumKernel<
          T><<<grid, kThreadPerBlock, 0, ctx.cuda_device_context().stream()>>>(
          p, g, v, lr, mu, rows.CUDAData(ctx.GetPlace()), grad_row_numel,
          grad->rows().size(), use_nesterov, p_out, v_out);
    } else {
      PADDLE_THROW("Unsupported Variable Type of Grad");
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(momentum, ops::MomentumOpCUDAKernel<float>,
                        ops::MomentumOpCUDAKernel<double>);