/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <limits>
#include <vector>

#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/adam_op.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

// Dense Adam update; the accumulated beta powers are passed by value.
//
// For every element i in [0, ndim):
//   m1 = beta1 * m1 + (1 - beta1) * g
//   m2 = beta2 * m2 + (1 - beta2) * g * g
//   p -= lr_t * m1 / (sqrt(m2) + epsilon * sqrt(1 - beta2_pow))
// where lr_t = (*lr_) * sqrt(1 - beta2_pow) / (1 - beta1_pow).
//
// T is the element type of grad / param / param_out; MT is the type of the
// moments, the learning rate and of all arithmetic.
//
// master_param may be null, in which case the parameter is read from `param`
// and converted to MT. master_param_out may be null, in which case only
// param_out is written.
//
// Launch layout: 1-D grid of 1-D blocks. The loop advances by
// gridDim.x * blockDim.x, so any grid size covers all `ndim` elements.
template <typename T, typename MT>
__global__ void AdamKernelREG(MT beta1, MT beta2, MT epsilon, MT beta1_pow_,
                              MT beta2_pow_, const MT* moment1, MT* moment1_out,
                              const MT* moment2, MT* moment2_out, const MT* lr_,
                              const T* grad, const T* param, T* param_out,
                              const MT* master_param, MT* master_param_out,
                              int ndim) {
  MT lr = *lr_;
  MT beta1_pow = beta1_pow_;
  MT beta2_pow = beta2_pow_;

  // Fold the bias correction of both moments into the step size.
  lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
        (static_cast<MT>(1.0) - beta1_pow);

  // Index and stride are 64-bit on purpose: with an `int` index,
  // `id += gridDim.x * blockDim.x` exceeds INT_MAX for large tensors, wraps to
  // a negative value that still satisfies `id < ndim`, and the loop body then
  // accesses memory in front of the buffers.
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  int64_t id = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  for (; id < ndim; id += stride) {
    MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
    MT g = static_cast<MT>(grad[id]);
    MT mom1 = moment1[id];
    MT mom2 = moment2[id];
    mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
    mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
    p -= lr * (mom1 /
               (sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));

    moment1_out[id] = mom1;
    moment2_out[id] = mom2;
    param_out[id] = static_cast<T>(p);
    if (master_param_out) {
      master_param_out[id] = p;
    }
  }
}

// Dense Adam update; the accumulated beta powers are read through the
// pointers beta1_pow_ / beta2_pow_, which the kernel dereferences.
//
// For every element i in [0, ndim):
//   m1 = beta1 * m1 + (1 - beta1) * g
//   m2 = beta2 * m2 + (1 - beta2) * g * g
//   p -= lr_t * m1 / (sqrt(m2) + epsilon * sqrt(1 - beta2_pow))
// where lr_t = (*lr_) * sqrt(1 - beta2_pow) / (1 - beta1_pow).
//
// T is the element type of grad / param / param_out; MT is the type of the
// moments, the learning rate, the beta powers and of all arithmetic.
//
// master_param may be null, in which case the parameter is read from `param`
// and converted to MT. master_param_out may be null, in which case only
// param_out is written.
//
// Launch layout: 1-D grid of 1-D blocks. The loop advances by
// gridDim.x * blockDim.x, so any grid size covers all `ndim` elements.
template <typename T, typename MT>
__global__ void AdamKernelMEM(MT beta1, MT beta2, MT epsilon,
                              const MT* beta1_pow_, const MT* beta2_pow_,
                              const MT* moment1, MT* moment1_out,
                              const MT* moment2, MT* moment2_out, const MT* lr_,
                              const T* grad, const T* param, T* param_out,
                              const MT* master_param, MT* master_param_out,
                              int ndim) {
  MT lr = *lr_;
  MT beta1_pow = *beta1_pow_;
  MT beta2_pow = *beta2_pow_;

  // Fold the bias correction of both moments into the step size.
  lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
        (static_cast<MT>(1.0) - beta1_pow);

  // Index and stride are 64-bit on purpose: with an `int` index,
  // `id += gridDim.x * blockDim.x` exceeds INT_MAX for large tensors, wraps to
  // a negative value that still satisfies `id < ndim`, and the loop body then
  // accesses memory in front of the buffers.
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  int64_t id = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  for (; id < ndim; id += stride) {
    MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
    MT g = static_cast<MT>(grad[id]);
    MT mom1 = moment1[id];
    MT mom2 = moment2[id];
    mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
    mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
    p -= lr * (mom1 /
               (sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));

    moment1_out[id] = mom1;
    moment2_out[id] = mom2;
    param_out[id] = static_cast<T>(p);
    if (master_param_out) {
      master_param_out[id] = p;
    }
  }
}

// Advances the accumulated beta powers by one step:
//   *beta1_pow_out = beta1 * beta1_pow_[0]
//   *beta2_pow_out = beta2 * beta2_pow_[0]
//
// Only thread 0 of block 0 performs the update, whatever the launch
// configuration. Letting every launched thread run the same
// read-multiply-write is redundant and, if an output pointer aliases its
// input, is a potential read-after-write race between threads.
// NOTE(review): whether callers pass aliasing pointers is not visible here.
template <typename T>
__global__ void UpdateBetaPow(T beta1, T beta2, const T* beta1_pow_,
                              const T* beta2_pow_, T* beta1_pow_out,
                              T* beta2_pow_out) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    *beta1_pow_out = beta1 * beta1_pow_[0];
    *beta2_pow_out = beta2 * beta2_pow_[0];
  }
}

// Adam update for a row-sparse gradient; the accumulated beta powers are
// passed by value.
//
// `ndim` is the number of parameter elements. Element `id` belongs to row
// `id / row_numel`; that row is looked up in `rows_` (row_count entries) with
// math::BinarySearch, and a negative result is treated here as "this row has
// no gradient".
//   - lazy_mode == true:  elements of rows without a gradient are skipped and
//                         none of their outputs are written.
//   - lazy_mode == false: such elements are updated with a gradient of 0.
// For rows that are present, the gradient is read from
// grad_[row_idx * row_numel + id % row_numel].
//
// The update itself is
//   m1 = beta1 * m1 + (1 - beta1) * g
//   m2 = beta2 * m2 + (1 - beta2) * g * g
//   p -= lr_t * m1 / (sqrt(m2) + epsilon * sqrt(1 - beta2_pow))
// where lr_t = (*lr_) * sqrt(1 - beta2_pow) / (1 - beta1_pow).
//
// master_param / master_param_out may be null; see the checks in the loop.
//
// Launch layout: 1-D grid of 1-D blocks. The loop advances by
// gridDim.x * blockDim.x, so any grid size covers all `ndim` elements.
template <typename T, typename MT>
__global__ void SparseAdamCUDAKernelREG(
    MT beta1, MT beta2, MT epsilon, const MT beta1_pow, const MT beta2_pow,
    const MT* mom1_, MT* mom1_out_, const MT* mom2_, MT* mom2_out_,
    const MT* lr_, const T* grad_, const T* param_, T* param_out_,
    const MT* master_param, MT* master_param_out, const int64_t* rows_,
    int64_t row_numel, int64_t row_count, bool lazy_mode, int ndim) {
  // Index and stride are 64-bit on purpose: with an `int` index,
  // `id += blockDim.x * gridDim.x` exceeds INT_MAX for large tensors, wraps to
  // a negative value that still satisfies `id < ndim`, and the loop body then
  // accesses memory in front of the buffers.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t id = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  MT lr = *lr_;
  // Fold the bias correction of both moments into the step size.
  lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
        (static_cast<MT>(1.0) - beta1_pow);

  for (; id < ndim; id += stride) {
    auto row_idx =
        math::BinarySearch<int64_t>(rows_, row_count, id / row_numel);
    if (lazy_mode && row_idx < 0) {
      // Skip only this element. Returning here would also drop every later
      // element this thread is responsible for when the grid is smaller than
      // ndim.
      continue;
    }

    MT mom1 = mom1_[id];
    MT mom2 = mom2_[id];
    MT p = master_param ? master_param[id] : static_cast<MT>(param_[id]);
    MT g = row_idx >= 0
               ? static_cast<MT>(grad_[row_idx * row_numel + id % row_numel])
               : static_cast<MT>(0);
    mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
    mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
    p -= lr * (mom1 / (sqrt(mom2) +
                       epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));

    // Write back to global memory
    mom1_out_[id] = mom1;
    mom2_out_[id] = mom2;
    param_out_[id] = static_cast<T>(p);
    if (master_param_out) {
      master_param_out[id] = p;
    }
  }
}

// GPU implementation of the `adam` operator.
//
// Compute() reads Param, Grad, Moment1, Moment2, LearningRate, Beta1Pow and
// Beta2Pow and writes ParamOut, Moment1Out, Moment2Out, Beta1PowOut and
// Beta2PowOut. beta1 / beta2 / epsilon come from the attributes of the same
// name unless the optional single-element inputs Beta1Tensor / Beta2Tensor /
// EpsilonTensor are present, which then take precedence.
//
// Grad may be a dense LoDTensor or a SelectedRows. When the attribute
// `multi_precision` is true, MasterParam / MasterParamOut (in MPDType) are
// required and are passed to the kernels in addition to Param / ParamOut.
//
// If Beta1Pow and Beta2Pow are both on CPUPlace, their values are passed to
// the kernel by value and the *PowOut tensors are updated on the host;
// otherwise the kernels read them through pointers and UpdateBetaPow advances
// them on the device.
template <typename T>
class AdamOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* param_var = ctx.InputVar("Param");
    PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
                      platform::errors::InvalidArgument(
                          "The Var(%s)'s type should be LoDTensor, "
                          "but the received is %s",
                          ctx.InputNames("Param").front(),
                          framework::ToTypeName(param_var->Type())));

    using paddle::framework::LoDTensor;
    using MPDType = typename details::MPTypeTrait<T>::Type;

    bool lazy_mode = ctx.Attr<bool>("lazy_mode");

    auto* param = ctx.Input<LoDTensor>("Param");
    auto* grad_var = ctx.InputVar("Grad");
    auto* mom1 = ctx.Input<LoDTensor>("Moment1");
    auto* mom2 = ctx.Input<LoDTensor>("Moment2");
    auto* lr = ctx.Input<LoDTensor>("LearningRate");

    auto* beta1_pow = ctx.Input<LoDTensor>("Beta1Pow");
    auto* beta2_pow = ctx.Input<LoDTensor>("Beta2Pow");

    auto* param_out = ctx.Output<LoDTensor>("ParamOut");
    auto* mom1_out = ctx.Output<LoDTensor>("Moment1Out");
    auto* mom2_out = ctx.Output<LoDTensor>("Moment2Out");
    auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
    auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");

    MPDType beta1 = static_cast<MPDType>(ctx.Attr<float>("beta1"));
    if (ctx.HasInput("Beta1Tensor")) {
      auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
      PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(Beta1Tensor) size must be 1, but get %d",
                            beta1_tensor->numel()));
      beta1 = static_cast<MPDType>(GetAttrFromTensor(beta1_tensor));
    }
    MPDType beta2 = static_cast<MPDType>(ctx.Attr<float>("beta2"));
    if (ctx.HasInput("Beta2Tensor")) {
      auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
      PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(Beta2Tensor) size must be 1, but get %d",
                            beta2_tensor->numel()));
      beta2 = static_cast<MPDType>(GetAttrFromTensor(beta2_tensor));
    }
    MPDType epsilon = static_cast<MPDType>(ctx.Attr<float>("epsilon"));
    if (ctx.HasInput("EpsilonTensor")) {
      auto* epsilon_tensor = ctx.Input<framework::Tensor>("EpsilonTensor");
      PADDLE_ENFORCE_EQ(epsilon_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(EpsilonTensor) size must be 1, but get %d",
                            epsilon_tensor->numel()));
      epsilon = static_cast<MPDType>(GetAttrFromTensor(epsilon_tensor));
    }
    VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
            << ", beta2_pow.numel() : " << beta2_pow->numel();
    VLOG(3) << "param.numel(): " << param->numel();
    PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
                      platform::errors::InvalidArgument(
                          "beta1 pow output size should be 1, but received "
                          "value is:%d.",
                          beta1_pow_out->numel()));

    PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1,
                      platform::errors::InvalidArgument(
                          "beta2 pow output size should be 1, but received "
                          "value is:%d.",
                          beta2_pow_out->numel()));

    const bool multi_precision = ctx.Attr<bool>("multi_precision");
    const LoDTensor* master_param = nullptr;
    LoDTensor* master_param_out = nullptr;
    if (multi_precision) {
      bool has_master =
          ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
      PADDLE_ENFORCE_EQ(has_master, true,
                        platform::errors::InvalidArgument(
                            "The Input(MasterParam) and Output(MasterParamOut) "
                            "should not be null when "
                            "the attr `multi_precision` is true"));
      master_param = ctx.Input<LoDTensor>("MasterParam");
      master_param_out = ctx.Output<LoDTensor>("MasterParamOut");
    }
    const MPDType* master_in_data =
        multi_precision ? master_param->data<MPDType>() : nullptr;
    MPDType* master_out_data =
        multi_precision
            ? master_param_out->mutable_data<MPDType>(ctx.GetPlace())
            : nullptr;

    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    // The CUDA kernels in this file take the element count as `int`. Reject
    // tensors that do not fit instead of silently truncating the count, which
    // would leave part of the parameter un-updated.
    const auto kernel_ndim = [param]() -> int {
      const int64_t numel = param->numel();
      PADDLE_ENFORCE_EQ(
          numel <= static_cast<int64_t>(std::numeric_limits<int>::max()), true,
          platform::errors::InvalidArgument(
              "The CUDA Adam kernel indexes Input(Param) with int, so its "
              "number of elements should not exceed %d, but received %d.",
              std::numeric_limits<int>::max(), numel));
      return static_cast<int>(numel);
    };

    // Returns true when both beta pow inputs are on CPUPlace. A mixed
    // placement is rejected: the non-CPU code path hands both pointers to
    // device kernels, which cannot dereference a CPUPlace pointer.
    const auto beta_pow_on_cpu = [beta1_pow, beta2_pow]() -> bool {
      const bool beta1_pow_on_cpu = beta1_pow->place() == platform::CPUPlace();
      const bool beta2_pow_on_cpu = beta2_pow->place() == platform::CPUPlace();
      PADDLE_ENFORCE_EQ(
          beta1_pow_on_cpu, beta2_pow_on_cpu,
          platform::errors::InvalidArgument(
              "Input(Beta1Pow) and Input(Beta2Pow) should either both be on "
              "CPUPlace or both not be on CPUPlace."));
      return beta1_pow_on_cpu;
    };

    if (grad_var->IsType<framework::LoDTensor>()) {
      auto* grad = ctx.Input<LoDTensor>("Grad");

      // update param and moment
      const int threads = 512;
      const int ndim = kernel_ndim();
      // Ceil-div in 64 bits so that `ndim + threads - 1` cannot overflow.
      const int blocks = static_cast<int>(
          (static_cast<int64_t>(ndim) + threads - 1) / threads);

      if (beta_pow_on_cpu()) {
        // Compute with betapow in REG
        AdamKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
            beta1, beta2, epsilon, *beta1_pow->data<MPDType>(),
            *beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad->data<T>(), param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, ndim);
        // Cpu update
        beta1_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
            beta1 * beta1_pow->data<MPDType>()[0];
        beta2_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
            beta2 * beta2_pow->data<MPDType>()[0];
      } else {
        AdamKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
            beta1, beta2, epsilon, beta1_pow->data<MPDType>(),
            beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad->data<T>(), param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, ndim);
        // Update with gpu
        UpdateBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
            beta1, beta2, beta1_pow->data<MPDType>(),
            beta2_pow->data<MPDType>(),
            beta1_pow_out->mutable_data<MPDType>(ctx.GetPlace()),
            beta2_pow_out->mutable_data<MPDType>(ctx.GetPlace()));
      }

    } else if (grad_var->IsType<framework::SelectedRows>()) {
      auto* grad = ctx.Input<framework::SelectedRows>("Grad");
      if (grad->rows().size() == 0) {
        // Nothing to apply; note that no output is written on this path.
        VLOG(3) << "grad row size is 0!!";
        return;
      }

      // The kernels binary-search the row list, so it has to be strictly
      // increasing (sorted, no duplicates).
      std::vector<int64_t> cpu_rows(grad->rows().begin(), grad->rows().end());
      bool is_strict_sorted = true;
      for (size_t i = 1; i < cpu_rows.size(); ++i) {
        if (cpu_rows[i - 1] >= cpu_rows[i]) {
          is_strict_sorted = false;
          break;
        }
      }

      framework::SelectedRows tmp_grad_merge;
      const framework::SelectedRows* grad_merge_ptr;
      if (is_strict_sorted) {
        grad_merge_ptr = grad;
      } else {
        // merge duplicated rows if any.
        // The rows of grad_merge have been sorted inside MergeAdd functor
        scatter::MergeAdd<platform::CUDADeviceContext, T> merge_func;
        merge_func(dev_ctx, *grad, &tmp_grad_merge, true);
        grad_merge_ptr = &tmp_grad_merge;
      }
      auto& grad_merge = *grad_merge_ptr;
      auto& grad_tensor = grad_merge.value();
      const T* grad_data = grad_tensor.template data<T>();
      const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
      auto row_numel = grad_tensor.numel() / grad_merge.rows().size();

      if (beta_pow_on_cpu()) {
        const int threads = 512;
        const int ndim = kernel_ndim();
        // Ceil-div in 64 bits so that `ndim + threads - 1` cannot overflow.
        const int blocks = static_cast<int>(
            (static_cast<int64_t>(ndim) + threads - 1) / threads);

        SparseAdamCUDAKernelREG<
            T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
            beta1, beta2, epsilon, *beta1_pow->data<MPDType>(),
            *beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad_data, param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, rows, row_numel, grad_merge.rows().size(),
            lazy_mode, ndim);
        // Update with cpu
        beta1_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
            beta1 * beta1_pow->data<MPDType>()[0];
        beta2_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
            beta2 * beta2_pow->data<MPDType>()[0];
      } else {
        SparseAdamFunctor<T, GPUAdam, MPDType> functor(
            beta1, beta2, epsilon, beta1_pow->data<MPDType>(),
            beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad_data, param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, rows, row_numel, grad_merge.rows().size(),
            lazy_mode);

        // FIXME(minqiyang): remove BinarySearch in GPU later
        platform::ForRange<platform::CUDADeviceContext> for_range(
            dev_ctx, param->numel());
        for_range(functor);
        // update beta1 and beta2
        UpdateBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
            beta1, beta2, beta1_pow->data<MPDType>(),
            beta2_pow->data<MPDType>(),
            beta1_pow_out->mutable_data<MPDType>(ctx.GetPlace()),
            beta2_pow_out->mutable_data<MPDType>(ctx.GetPlace()));
      }
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Variable type not supported by adam_op"));
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the CUDA kernel of the `adam` operator for float, double and
// float16 parameters.
REGISTER_OP_CUDA_KERNEL(adam, ops::AdamOpCUDAKernel<float>,
                        ops::AdamOpCUDAKernel<double>,
                        ops::AdamOpCUDAKernel<plat::float16>);