/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/adam_op.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {

// Dense Adam step with beta1^t and beta2^t passed by value.
// The host reads these powers from CPU-resident Beta1Pow/Beta2Pow tensors.
//
// For every element id in [0, ndim):
//   mom1 = beta1 * mom1 + (1 - beta1) * g
//   mom2 = beta2 * mom2 + (1 - beta2) * g * g
//   p   += (mom1 / (sqrt(mom2) / sqrt(1 - beta2_pow) + epsilon))
//          * -(lr / (1 - beta1_pow))
//
// Types:
//   T  - storage type of grad / param / param_out.
//   MT - compute type of the moments, the learning rate and the optional
//        master weights.
//
// Master weights:
//   - When master_param is non-null it is read instead of param.
//   - When master_param_out is non-null, the updated value is also stored
//     there in MT precision.
//
// Launch: any 1-D grid. Each thread handles elements id, id + stride, ...
// where stride = gridDim.x * blockDim.x. Indices are 64-bit, so tensors with
// more than INT_MAX elements (the host passes an int64_t element count) are
// updated completely.
template <typename T, typename MT>
__global__ void AdamKernelREG(MT beta1, MT beta2, MT epsilon, MT beta1_pow_,
                              MT beta2_pow_, const MT* moment1, MT* moment1_out,
                              const MT* moment2, MT* moment2_out, const MT* lr_,
                              const T* grad, const T* param, T* param_out,
                              const MT* master_param, MT* master_param_out,
                              int64_t ndim) {
  const MT lr = *lr_;
  const MT beta1_pow = beta1_pow_;
  const MT beta2_pow = beta2_pow_;
  const MT one = static_cast<MT>(1.0);
  // Loop-invariant bias-correction terms.
  const MT bias_correction2 = sqrt(one - beta2_pow);
  const MT neg_step = -(lr / (one - beta1_pow));

  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t id = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       id < ndim; id += stride) {
    MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
    MT g = static_cast<MT>(grad[id]);
    MT mom1 = moment1[id];
    MT mom2 = moment2[id];
    mom1 = beta1 * mom1 + (one - beta1) * g;
    mom2 = beta2 * mom2 + (one - beta2) * g * g;

    MT denom = (sqrt(mom2) / bias_correction2) + epsilon;
    p += (mom1 / denom) * neg_step;

    moment1_out[id] = mom1;
    moment2_out[id] = mom2;
    param_out[id] = static_cast<T>(p);
    if (master_param_out) {
      master_param_out[id] = p;
    }
  }
}

// Dense Adam step with beta1^t and beta2^t read from device memory.
// This path is used when Beta1Pow/Beta2Pow are not CPU-resident.
//
// The update formula, T/MT roles and master_param handling are the same as in
// the by-value variant: for every element id in [0, ndim)
//   mom1 = beta1 * mom1 + (1 - beta1) * g
//   mom2 = beta2 * mom2 + (1 - beta2) * g * g
//   p   += (mom1 / (sqrt(mom2) / sqrt(1 - beta2_pow) + epsilon))
//          * -(lr / (1 - beta1_pow))
// where beta1_pow = *beta1_pow_ and beta2_pow = *beta2_pow_.
//
// Launch: any 1-D grid. Each thread handles elements id, id + stride, ...
// where stride = gridDim.x * blockDim.x. Indices are 64-bit so tensors with
// more than INT_MAX elements are updated completely.
template <typename T, typename MT>
__global__ void AdamKernelMEM(MT beta1, MT beta2, MT epsilon,
                              const MT* beta1_pow_, const MT* beta2_pow_,
                              const MT* moment1, MT* moment1_out,
                              const MT* moment2, MT* moment2_out, const MT* lr_,
                              const T* grad, const T* param, T* param_out,
                              const MT* master_param, MT* master_param_out,
                              int64_t ndim) {
  const MT lr = *lr_;
  const MT beta1_pow = *beta1_pow_;
  const MT beta2_pow = *beta2_pow_;
  const MT one = static_cast<MT>(1.0);
  // Loop-invariant bias-correction terms.
  const MT bias_correction2 = sqrt(one - beta2_pow);
  const MT neg_step = -(lr / (one - beta1_pow));

  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t id = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       id < ndim; id += stride) {
    MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
    MT g = static_cast<MT>(grad[id]);
    MT mom1 = moment1[id];
    MT mom2 = moment2[id];
    mom1 = beta1 * mom1 + (one - beta1) * g;
    mom2 = beta2 * mom2 + (one - beta2) * g * g;

    MT denom = (sqrt(mom2) / bias_correction2) + epsilon;
    p += (mom1 / denom) * neg_step;

    moment1_out[id] = mom1;
    moment2_out[id] = mom2;
    param_out[id] = static_cast<T>(p);
    if (master_param_out) {
      master_param_out[id] = p;
    }
  }
}

// Advances the bias-correction powers on the device:
//   *beta1_pow_out = beta1 * beta1_pow_[0]
//   *beta2_pow_out = beta2 * beta2_pow_[0]
//
// Exactly one thread performs the read-modify-write, so the result does not
// depend on the launch configuration. This matters when the outputs alias the
// inputs. NOTE(review): Beta*PowOut is presumably updated in place — confirm
// against the op definition. If several threads each read the input and
// write the output without running in lockstep, the multiplication could be
// applied more than once.
template <typename T>
__global__ void UpdateBetaPow(T beta1, T beta2, const T* beta1_pow_,
                              const T* beta2_pow_, T* beta1_pow_out,
                              T* beta2_pow_out) {
  const bool is_first_thread = blockIdx.x == 0 && blockIdx.y == 0 &&
                               blockIdx.z == 0 && threadIdx.x == 0 &&
                               threadIdx.y == 0 && threadIdx.z == 0;
  if (!is_first_thread) {
    return;
  }
  *beta1_pow_out = beta1 * beta1_pow_[0];
  *beta2_pow_out = beta2 * beta2_pow_[0];
}

// Sparse (SelectedRows gradient) Adam step with beta1^t / beta2^t passed by
// value.
//
// Layout:
//   - param / moments have ndim elements, viewed as rows of row_numel
//     elements.
//   - rows_ holds the row_count row indices present in the gradient. It is
//     binary-searched, so it is assumed sorted ascending.
//   - grad_ holds the values of those rows, row by row.
//
// For each element id, the row id / row_numel is looked up in rows_:
//   * found:                 g = grad_[row_idx * row_numel + id % row_numel]
//   * missing and lazy_mode: the element is skipped (its outputs are not
//                            written)
//   * missing otherwise:     g = 0, so the moments still decay and p moves
// The update formula and the master_param handling match the dense kernels.
//
// Launch: any 1-D grid. Each thread handles elements id, id + stride, ...
// where stride = gridDim.x * blockDim.x.
//   - A skipped element only skips that element (`continue`), so the thread
//     still processes the rest of its elements when the grid is smaller than
//     ndim.
//   - Indices are 64-bit.
template <typename T, typename MT>
__global__ void SparseAdamCUDAKernelREG(
    MT beta1, MT beta2, MT epsilon, const MT beta1_pow, const MT beta2_pow,
    const MT* mom1_, MT* mom1_out_, const MT* mom2_, MT* mom2_out_,
    const MT* lr_, const T* grad_, const T* param_, T* param_out_,
    const MT* master_param, MT* master_param_out, const int64_t* rows_,
    int64_t row_numel, int64_t row_count, bool lazy_mode, int64_t ndim) {
  const MT lr = *lr_;
  const MT one = static_cast<MT>(1.0);
  // Loop-invariant bias-correction terms.
  const MT bias_correction2 = sqrt(one - beta2_pow);
  const MT neg_step = -(lr / (one - beta1_pow));

  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t id = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       id < ndim; id += stride) {
    auto row_idx =
        math::BinarySearch<int64_t>(rows_, row_count, id / row_numel);
    if (lazy_mode && row_idx < 0) {
      // Lazy mode leaves rows without a gradient untouched; move on to this
      // thread's next element rather than leaving the loop.
      continue;
    }

    MT mom1 = mom1_[id];
    MT mom2 = mom2_[id];
    MT p = master_param ? master_param[id] : static_cast<MT>(param_[id]);
    MT g = row_idx >= 0
               ? static_cast<MT>(grad_[row_idx * row_numel + id % row_numel])
               : static_cast<MT>(0);
    mom1 = beta1 * mom1 + (one - beta1) * g;
    mom2 = beta2 * mom2 + (one - beta2) * g * g;

    MT denom = (sqrt(mom2) / bias_correction2) + epsilon;
    p += (mom1 / denom) * neg_step;

    // Write back to global memory
    mom1_out_[id] = mom1;
    mom2_out_[id] = mom2;
    param_out_[id] = static_cast<T>(p);
    if (master_param_out) {
      master_param_out[id] = p;
    }
  }
}

// CUDA kernel of the `adam` operator.
//
// Inputs read:
//   Param, Grad (LoDTensor or SelectedRows), Moment1, Moment2, LearningRate,
//   Beta1Pow, Beta2Pow.
//   When present: SkipUpdate, Beta1Tensor, Beta2Tensor, EpsilonTensor and
//   (with multi_precision) MasterParam.
// Outputs written:
//   ParamOut, Moment1Out, Moment2Out, Beta1PowOut, Beta2PowOut.
//   With multi_precision: MasterParamOut.
//
// Dispatch:
//  * SkipUpdate is true: every input is copied unchanged to its output.
//  * Beta1Pow and Beta2Pow both on the CPU:
//      - the powers are passed to the *REG kernels by value;
//      - unless use_global_beta_pow is set, they are advanced on the host.
//  * Otherwise:
//      - the powers are read from device memory;
//      - unless use_global_beta_pow is set, they are advanced on the GPU by
//        UpdateBetaPow.
//
// Types: moments, learning rate, beta powers and master weights are MPDType
// (the multi-precision type of T); Param and Grad are T.
template <typename T>
class AdamOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* param_var = ctx.InputVar("Param");
    PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
                      platform::errors::InvalidArgument(
                          "The Var(%s)'s type should be LoDTensor, "
                          "but the received is %s",
                          ctx.InputNames("Param").front(),
                          framework::ToTypeName(param_var->Type())));

    using paddle::framework::LoDTensor;
    using MPDType = typename details::MPTypeTrait<T>::Type;

    bool lazy_mode = ctx.Attr<bool>("lazy_mode");
    bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
    VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;

    auto* param = ctx.Input<LoDTensor>("Param");
    auto* grad_var = ctx.InputVar("Grad");
    auto* mom1 = ctx.Input<LoDTensor>("Moment1");
    auto* mom2 = ctx.Input<LoDTensor>("Moment2");
    auto* lr = ctx.Input<LoDTensor>("LearningRate");

    auto* beta1_pow = ctx.Input<LoDTensor>("Beta1Pow");
    auto* beta2_pow = ctx.Input<LoDTensor>("Beta2Pow");

    auto* param_out = ctx.Output<LoDTensor>("ParamOut");
    auto* mom1_out = ctx.Output<LoDTensor>("Moment1Out");
    auto* mom2_out = ctx.Output<LoDTensor>("Moment2Out");
    auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
    auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");

    // SkipUpdate is a one-element tensor; it is copied to the host to decide.
    bool skip_update = false;
    if (ctx.HasInput("SkipUpdate")) {
      auto* skip_update_tensor = ctx.Input<framework::Tensor>("SkipUpdate");
      PADDLE_ENFORCE_EQ(skip_update_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(SkipUpdate) size must be 1, but get %d",
                            skip_update_tensor->numel()));
      std::vector<bool> skip_update_vec;
      TensorToVector(*skip_update_tensor, ctx.device_context(),
                     &skip_update_vec);
      skip_update = skip_update_vec[0];
    }
    // skip_update=true: copy every input to its output. TensorCopy calls
    // mutable_data on the destination. The beta pow tensors are copied to the
    // place they already live on.
    if (skip_update) {
      VLOG(4) << "Adam skip update";
      const auto& copy_ctx =
          ctx.template device_context<platform::DeviceContext>();
      framework::TensorCopy(*param, ctx.GetPlace(), copy_ctx, param_out);
      framework::TensorCopy(*mom1, ctx.GetPlace(), copy_ctx, mom1_out);
      framework::TensorCopy(*mom2, ctx.GetPlace(), copy_ctx, mom2_out);
      framework::TensorCopy(*beta1_pow, beta1_pow->place(), copy_ctx,
                            beta1_pow_out);
      framework::TensorCopy(*beta2_pow, beta2_pow->place(), copy_ctx,
                            beta2_pow_out);
      return;
    }

    // Hyper-parameters: the attribute value unless a one-element tensor
    // input overrides it.
    MPDType beta1 = static_cast<MPDType>(ctx.Attr<float>("beta1"));
    if (ctx.HasInput("Beta1Tensor")) {
      auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
      PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(Beta1Tensor) size must be 1, but get %d",
                            beta1_tensor->numel()));
      beta1 = static_cast<MPDType>(GetAttrFromTensor(beta1_tensor));
    }
    MPDType beta2 = static_cast<MPDType>(ctx.Attr<float>("beta2"));
    if (ctx.HasInput("Beta2Tensor")) {
      auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
      PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(Beta2Tensor) size must be 1, but get %d",
                            beta2_tensor->numel()));
      beta2 = static_cast<MPDType>(GetAttrFromTensor(beta2_tensor));
    }
    MPDType epsilon = static_cast<MPDType>(ctx.Attr<float>("epsilon"));
    if (ctx.HasInput("EpsilonTensor")) {
      auto* epsilon_tensor = ctx.Input<framework::Tensor>("EpsilonTensor");
      PADDLE_ENFORCE_EQ(epsilon_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(EpsilonTensor) size must be 1, but get %d",
                            epsilon_tensor->numel()));
      epsilon = static_cast<MPDType>(GetAttrFromTensor(epsilon_tensor));
    }
    VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
            << ", beta2_pow.numel() : " << beta2_pow->numel();
    VLOG(3) << "param.numel(): " << param->numel();
    PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
                      platform::errors::InvalidArgument(
                          "beta1 pow output size should be 1, but received "
                          "value is:%d.",
                          beta1_pow_out->numel()));

    PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1,
                      platform::errors::InvalidArgument(
                          "beta2 pow output size should be 1, but received "
                          "value is:%d.",
                          beta2_pow_out->numel()));

    const bool multi_precision = ctx.Attr<bool>("multi_precision");
    const LoDTensor* master_param = nullptr;
    LoDTensor* master_param_out = nullptr;
    if (multi_precision) {
      bool has_master =
          ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
      PADDLE_ENFORCE_EQ(has_master, true,
                        platform::errors::InvalidArgument(
                            "The Input(MasterParam) and Output(MasterParamOut) "
                            "should not be null when "
                            "the attr `multi_precision` is true"));
      master_param = ctx.Input<LoDTensor>("MasterParam");
      master_param_out = ctx.Output<LoDTensor>("MasterParamOut");
    }
    const MPDType* master_in_data =
        multi_precision ? master_param->data<MPDType>() : nullptr;
    MPDType* master_out_data =
        multi_precision
            ? master_param_out->mutable_data<MPDType>(ctx.GetPlace())
            : nullptr;

    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    constexpr int kThreads = 512;
    // The by-value (REG) kernels are only usable when both powers are on CPU.
    const bool beta_pow_on_cpu = beta1_pow->place() == platform::CPUPlace() &&
                                 beta2_pow->place() == platform::CPUPlace();

    // Advances Beta1PowOut/Beta2PowOut on the host (powers live on the CPU).
    auto update_beta_pow_on_cpu = [&]() {
      beta1_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
          beta1 * beta1_pow->data<MPDType>()[0];
      beta2_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
          beta2 * beta2_pow->data<MPDType>()[0];
    };
    // Advances Beta1PowOut/Beta2PowOut on the GPU. The kernel writes one
    // scalar per output, so a single thread is launched.
    auto update_beta_pow_on_gpu = [&]() {
      UpdateBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
          beta1, beta2, beta1_pow->data<MPDType>(), beta2_pow->data<MPDType>(),
          beta1_pow_out->mutable_data<MPDType>(ctx.GetPlace()),
          beta2_pow_out->mutable_data<MPDType>(ctx.GetPlace()));
    };

    if (grad_var->IsType<framework::LoDTensor>()) {
      auto* grad = ctx.Input<LoDTensor>("Grad");

      // update param and moment
      const int64_t numel = param->numel();
      const int blocks = static_cast<int>((numel + kThreads - 1) / kThreads);

      if (beta_pow_on_cpu) {
        // Compute with betapow in REG
        AdamKernelREG<T, MPDType><<<blocks, kThreads, 0, dev_ctx.stream()>>>(
            beta1, beta2, epsilon, *beta1_pow->data<MPDType>(),
            *beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad->data<T>(), param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, numel);
        if (!use_global_beta_pow) {
          update_beta_pow_on_cpu();
        }
      } else {
        AdamKernelMEM<T, MPDType><<<blocks, kThreads, 0, dev_ctx.stream()>>>(
            beta1, beta2, epsilon, beta1_pow->data<MPDType>(),
            beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad->data<T>(), param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, numel);
        if (!use_global_beta_pow) {
          update_beta_pow_on_gpu();
        }
      }
    } else if (grad_var->IsType<framework::SelectedRows>()) {
      auto* grad = ctx.Input<framework::SelectedRows>("Grad");
      if (grad->rows().size() == 0) {
        VLOG(3) << "grad row size is 0!!";
        return;
      }

      // The kernels binary-search the row list, so it must be strictly
      // increasing; otherwise duplicates are merged (and rows sorted) first.
      std::vector<int64_t> cpu_rows(grad->rows().begin(), grad->rows().end());
      bool is_strict_sorted = true;
      for (size_t i = 1; i < cpu_rows.size(); ++i) {
        if (cpu_rows[i - 1] >= cpu_rows[i]) {
          is_strict_sorted = false;
          break;
        }
      }

      framework::SelectedRows tmp_grad_merge;
      const framework::SelectedRows* grad_merge_ptr;
      if (is_strict_sorted) {
        grad_merge_ptr = grad;
      } else {
        // merge duplicated rows if any.
        // The rows of grad_merge have been sorted inside MergeAdd functor
        scatter::MergeAdd<platform::CUDADeviceContext, T> merge_func;
        merge_func(dev_ctx, *grad, &tmp_grad_merge, true);
        grad_merge_ptr = &tmp_grad_merge;
      }
      auto& grad_merge = *grad_merge_ptr;
      auto& grad_tensor = grad_merge.value();
      const T* grad_data = grad_tensor.template data<T>();
      const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
      auto row_numel = grad_tensor.numel() / grad_merge.rows().size();

      if (beta_pow_on_cpu) {
        const int64_t ndim = param->numel();
        const int blocks = static_cast<int>((ndim + kThreads - 1) / kThreads);

        SparseAdamCUDAKernelREG<
            T, MPDType><<<blocks, kThreads, 0, dev_ctx.stream()>>>(
            beta1, beta2, epsilon, *beta1_pow->data<MPDType>(),
            *beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad_data, param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, rows, row_numel, grad_merge.rows().size(),
            lazy_mode, ndim);
        if (!use_global_beta_pow) {
          update_beta_pow_on_cpu();
        }
      } else {
        SparseAdamFunctor<T, GPUAdam, MPDType> functor(
            beta1, beta2, epsilon, beta1_pow->data<MPDType>(),
            beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad_data, param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, rows, row_numel, grad_merge.rows().size(),
            lazy_mode);

        // FIXME(minqiyang): remove BinarySearch in GPU later
        platform::ForRange<platform::CUDADeviceContext> for_range(
            dev_ctx, param->numel());
        for_range(functor);
        if (!use_global_beta_pow) {
          update_beta_pow_on_gpu();
        }
      }
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Variable type not supported by adam_op"));
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Registers the CUDA kernels of the `adam` operator for float, double and
// float16 parameters.
REGISTER_OP_CUDA_KERNEL(adam, ops::AdamOpCUDAKernel<float>,
                        ops::AdamOpCUDAKernel<double>,
                        ops::AdamOpCUDAKernel<plat::float16>);