/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/adamw_op.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

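// Dense AdamW update kernel used when Beta1Pow/Beta2Pow live on the CPU:
// the accumulated beta powers are passed by value as kernel arguments.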
template <typename T, typename MT>
__global__ void AdamWKernelREG(MT beta1, MT beta2, MT epsilon, MT coeff,
                               MT lr_ratio, MT beta1_pow_, MT beta2_pow_,
                               const MT* moment1, MT* moment1_out,
                               const MT* moment2, MT* moment2_out,
                               const MT* lr_, const T* grad, const T* param,
                               T* param_out, const MT* master_param,
                               MT* master_param_out, int ndim) {
  MT lr = *lr_ * lr_ratio;
  MT beta1_pow = beta1_pow_;
  MT beta2_pow = beta2_pow_;

  int id = blockIdx.x * blockDim.x + threadIdx.x;

  for (; id < ndim; id += gridDim.x * blockDim.x) {
    MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
    MT g = static_cast<MT>(grad[id]);
    MT mom1 = static_cast<MT>(moment1[id]);
    MT mom2 = static_cast<MT>(moment2[id]);

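    // Decoupled weight decay: scale the parameter by (1 - lr * coeff)
    // before applying the Adam step.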
    p *= (static_cast<MT>(1.0) - lr * coeff);

    mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
    mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;

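    // Bias-corrected Adam step: the second moment is corrected by
    // (1 - beta2_pow) inside the denominator and the step size by
    // (1 - beta1_pow).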
    MT denom = (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;

    p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));

    moment1_out[id] = mom1;
    moment2_out[id] = mom2;
    param_out[id] = static_cast<T>(p);
    if (master_param_out) {
      master_param_out[id] = p;
    }
  }
}

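// Dense AdamW update kernel used when Beta1Pow/Beta2Pow live on the GPU:
// the accumulated beta powers are read from device memory.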
template <typename T, typename MT>
__global__ void AdamWKernelMEM(
    MT beta1, MT beta2, MT epsilon, MT coeff, MT lr_ratio, const MT* beta1_pow_,
    const MT* beta2_pow_, const MT* moment1, MT* moment1_out, const MT* moment2,
    MT* moment2_out, const MT* lr_, const T* grad, const T* param, T* param_out,
    const MT* master_param, MT* master_param_out, int ndim) {
  MT lr = *lr_ * lr_ratio;
  MT beta1_pow = *beta1_pow_;
  MT beta2_pow = *beta2_pow_;

  int id = blockIdx.x * blockDim.x + threadIdx.x;

  for (; id < ndim; id += gridDim.x * blockDim.x) {
    MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
    MT g = static_cast<MT>(grad[id]);
    MT mom1 = static_cast<MT>(moment1[id]);
    MT mom2 = static_cast<MT>(moment2[id]);

    p *= (static_cast<MT>(1.0) - lr * coeff);

    mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
    mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;

    MT denom = (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;

    p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));

    moment1_out[id] = mom1;
    moment2_out[id] = mom2;
    param_out[id] = static_cast<T>(p);
    if (master_param_out) {
      master_param_out[id] = p;
    }
  }
}
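
// Single-thread kernel that advances the accumulated beta powers on the
// device: beta*_pow_out = beta* * beta*_pow.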
template <typename T>
__global__ void UpdateAdamWBetaPow(T beta1, T beta2, const T* beta1_pow_,
                                   const T* beta2_pow_, T* beta1_pow_out,
                                   T* beta2_pow_out) {
  *beta1_pow_out = beta1 * beta1_pow_[0];
  *beta2_pow_out = beta2 * beta2_pow_[0];
}

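// Sparse (SelectedRows) AdamW update kernel for the case where
// Beta1Pow/Beta2Pow live on the CPU. Rows without a gradient entry are
// skipped in lazy_mode, otherwise they are updated with a zero gradient.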
template <typename T, typename MT>
__global__ void SparseAdamWCUDAKernelREG(
    MT beta1, MT beta2, MT epsilon, MT coeff, MT lr_ratio, const MT beta1_pow,
    const MT beta2_pow, const MT* mom1_, MT* mom1_out_, const MT* mom2_,
    MT* mom2_out_, const MT* lr_, const T* grad_, const T* param_,
    T* param_out_, const MT* master_param, MT* master_param_out,
    const int64_t* rows_, int64_t row_numel, int64_t row_count, bool lazy_mode,
    int ndim) {
  int id = blockIdx.x * blockDim.x + threadIdx.x;
  MT lr = *lr_ * lr_ratio;

  for (; id < ndim; id += blockDim.x * gridDim.x) {
    auto row_idx =
        phi::funcs::BinarySearch<int64_t>(rows_, row_count, id / row_numel);
    if (lazy_mode && row_idx < 0) {
      return;
    } else {
      MT mom1 = static_cast<MT>(mom1_[id]);
      MT mom2 = static_cast<MT>(mom2_[id]);

      MT p = master_param ? master_param[id] : static_cast<MT>(param_[id]);
      MT g = row_idx >= 0
                 ? static_cast<MT>(grad_[row_idx * row_numel + id % row_numel])
                 : static_cast<MT>(0);

      p *= (static_cast<MT>(1.0) - lr * coeff);

      mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
      mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;

      MT denom =
          (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;

      p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));

      // Write back to global memory
      mom1_out_[id] = mom1;
      mom2_out_[id] = mom2;
      param_out_[id] = static_cast<T>(p);
      if (master_param_out) {
        master_param_out[id] = p;
      }
    }
  }
}

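// CUDA compute kernel for the adamw op. It dispatches on the gradient type
// (dense LoDTensor vs. SelectedRows) and on where the beta pow accumulators
// live (CPU vs. GPU), and optionally keeps an FP32 master copy of the
// parameter when multi_precision is enabled.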
template <typename T>
class AdamWOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* param_var = ctx.InputVar("Param");
    PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
                      platform::errors::InvalidArgument(
                          "The Var(%s)'s type should be LoDTensor, "
                          "but the received is %s",
                          ctx.InputNames("Param").front(),
                          framework::ToTypeName(param_var->Type())));

    using paddle::framework::LoDTensor;
    using MPDType = typename details::MPTypeTrait<T>::Type;

    int64_t min_row_size_to_use_multithread =
        ctx.Attr<int64_t>("min_row_size_to_use_multithread");
    bool lazy_mode = ctx.Attr<bool>("lazy_mode");
    bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
    VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;

    MPDType coeff = static_cast<MPDType>(ctx.Attr<float>("coeff"));
    MPDType lr_ratio = static_cast<MPDType>(ctx.Attr<float>("lr_ratio"));

    auto* param = ctx.Input<LoDTensor>("Param");
    auto* grad_var = ctx.InputVar("Grad");
    auto* mom1 = ctx.Input<LoDTensor>("Moment1");
    auto* mom2 = ctx.Input<LoDTensor>("Moment2");
    auto* lr = ctx.Input<LoDTensor>("LearningRate");

    auto* beta1_pow = ctx.Input<LoDTensor>("Beta1Pow");
    auto* beta2_pow = ctx.Input<LoDTensor>("Beta2Pow");

    auto* param_out = ctx.Output<LoDTensor>("ParamOut");
    auto* mom1_out = ctx.Output<LoDTensor>("Moment1Out");
    auto* mom2_out = ctx.Output<LoDTensor>("Moment2Out");
    auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
    auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");

    bool skip_update = false;
    if (ctx.HasInput("SkipUpdate")) {
      auto* skip_update_tensor = ctx.Input<framework::Tensor>("SkipUpdate");
      PADDLE_ENFORCE_EQ(skip_update_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(SkipUpdate) size must be 1, but get %d",
                            skip_update_tensor->numel()));
      std::vector<bool> skip_update_vec;
      paddle::framework::TensorToVector(*skip_update_tensor,
                                        ctx.device_context(), &skip_update_vec);
      skip_update = skip_update_vec[0];
    }

    // skip_update=true, just copy input to output, and TensorCopy will call
    // mutable_data
    if (skip_update) {
      VLOG(4) << "Adamw skip update";
      framework::TensorCopy(
          *param, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), param_out);
      framework::TensorCopy(
          *mom1, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), mom1_out);
      framework::TensorCopy(
          *mom2, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), mom2_out);
      framework::TensorCopy(
          *beta1_pow, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(),
          beta1_pow_out);
      framework::TensorCopy(
          *beta2_pow, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(),
          beta2_pow_out);
      return;
    }

    // if with_decay = false, coeff = 0
    bool with_decay = ctx.Attr<bool>("with_decay");
    if (!with_decay) {
      coeff = static_cast<MPDType>(0.0);
    }

    MPDType beta1 = static_cast<MPDType>(ctx.Attr<float>("beta1"));
    if (ctx.HasInput("Beta1Tensor")) {
      auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
      PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(Beta1Tensor) size must be 1, but get %d",
                            beta1_tensor->numel()));
      beta1 = static_cast<MPDType>(GetAttrFromTensor(beta1_tensor));
    }
    MPDType beta2 = static_cast<MPDType>(ctx.Attr<float>("beta2"));
    if (ctx.HasInput("Beta2Tensor")) {
      auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
      PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(Beta2Tensor) size must be 1, but get %d",
                            beta2_tensor->numel()));
      beta2 = static_cast<MPDType>(GetAttrFromTensor(beta2_tensor));
    }
    MPDType epsilon = static_cast<MPDType>(ctx.Attr<float>("epsilon"));
    if (ctx.HasInput("EpsilonTensor")) {
      auto* epsilon_tensor = ctx.Input<framework::Tensor>("EpsilonTensor");
      PADDLE_ENFORCE_EQ(epsilon_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(EpsilonTensor) size must be 1, but get %d",
                            epsilon_tensor->numel()));
      epsilon = static_cast<MPDType>(GetAttrFromTensor(epsilon_tensor));
    }
    VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
            << "beta2_pow.numel() : " << beta2_pow->numel();
    VLOG(3) << "param.numel(): " << param->numel();
    PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
                      platform::errors::InvalidArgument(
                          "beta1 pow output size should be 1, but received "
                          "value is:%d.",
                          beta1_pow_out->numel()));

    PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1,
                      platform::errors::InvalidArgument(
                          "beta2 pow output size should be 1, but received "
                          "value is:%d.",
                          beta2_pow_out->numel()));

    const bool multi_precision = ctx.Attr<bool>("multi_precision");
    const LoDTensor* master_param = nullptr;
    LoDTensor* master_param_out = nullptr;
    if (multi_precision) {
      bool has_master =
          ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
      PADDLE_ENFORCE_EQ(has_master, true,
                        platform::errors::InvalidArgument(
                            "The Input(MasterParam) and Output(MasterParamOut) "
                            "should not be null when "
                            "the attr `multi_precision` is true"));
      master_param = ctx.Input<LoDTensor>("MasterParam");
      master_param_out = ctx.Output<LoDTensor>("MasterParamOut");
    }
    const MPDType* master_in_data =
        multi_precision ? master_param->data<MPDType>() : nullptr;
    MPDType* master_out_data =
        multi_precision
            ? master_param_out->mutable_data<MPDType>(ctx.GetPlace())
            : nullptr;

    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    if (grad_var->IsType<framework::LoDTensor>()) {
      auto* grad = ctx.Input<LoDTensor>("Grad");

      // update param and moment
      int threads = 512;
      int blocks = (param->numel() + threads - 1) / threads;

      if (beta1_pow->place() == platform::CPUPlace() &&
          beta2_pow->place() == platform::CPUPlace()) {
        // Compute with betapow in REG
        AdamWKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
            beta1, beta2, epsilon, coeff, lr_ratio, *beta1_pow->data<MPDType>(),
            *beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad->data<T>(), param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, param->numel());
        if (!use_global_beta_pow) {
          // Cpu update
          beta1_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
              beta1 * beta1_pow->data<MPDType>()[0];
          beta2_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
              beta2 * beta2_pow->data<MPDType>()[0];
        }
      } else {
        AdamWKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
            beta1, beta2, epsilon, coeff, lr_ratio, beta1_pow->data<MPDType>(),
            beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad->data<T>(), param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, param->numel());
        if (!use_global_beta_pow) {
          // Update with gpu
          UpdateAdamWBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
              beta1, beta2, beta1_pow->data<MPDType>(),
              beta2_pow->data<MPDType>(),
              beta1_pow_out->mutable_data<MPDType>(ctx.GetPlace()),
              beta2_pow_out->mutable_data<MPDType>(ctx.GetPlace()));
        }
      }
    } else if (grad_var->IsType<phi::SelectedRows>()) {
      auto* grad = ctx.Input<phi::SelectedRows>("Grad");
      if (grad->rows().size() == 0) {
        VLOG(3) << "grad row size is 0!!";
        return;
      }

      std::vector<int64_t> cpu_rows(grad->rows().begin(), grad->rows().end());
      bool is_strict_sorted = true;
      for (size_t i = 1; i < cpu_rows.size(); ++i) {
        if (cpu_rows[i - 1] >= cpu_rows[i]) {
          is_strict_sorted = false;
          break;
        }
      }

      phi::SelectedRows tmp_grad_merge;
      const phi::SelectedRows* grad_merge_ptr;
      if (is_strict_sorted) {
        grad_merge_ptr = grad;
      } else {
        // merge duplicated rows if any.
        // The rows of grad_merge have been sorted inside MergeAdd functor
        scatter::MergeAdd<platform::CUDADeviceContext, T> merge_func;
        merge_func(ctx.template device_context<platform::CUDADeviceContext>(),
                   *grad, &tmp_grad_merge, true);
        grad_merge_ptr = &tmp_grad_merge;
      }
      auto& grad_merge = *grad_merge_ptr;
      auto& grad_tensor = grad_merge.value();
      const T* grad_data = grad_tensor.template data<T>();
      auto* grad_merge_rows = &grad_merge.rows();
      paddle::framework::MixVector<int64_t> mixv_grad_merge_rows(
          grad_merge_rows);
      const int64_t* rows = mixv_grad_merge_rows.Data(ctx.GetPlace());
      auto row_numel = grad_tensor.numel() / grad_merge.rows().size();

      if (beta1_pow->place() == platform::CPUPlace() &&
          beta2_pow->place() == platform::CPUPlace()) {
        int threads = 512;
        int ndim = param->numel();
        int blocks = (ndim + threads - 1) / threads;

        SparseAdamWCUDAKernelREG<
            T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
            beta1, beta2, epsilon, coeff, lr_ratio, *beta1_pow->data<MPDType>(),
            *beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad_data, param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, rows, row_numel, grad_merge.rows().size(),
            lazy_mode, ndim);
        if (!use_global_beta_pow) {
          // Update with cpu
          beta1_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
              beta1 * beta1_pow->data<MPDType>()[0];
          beta2_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
              beta2 * beta2_pow->data<MPDType>()[0];
        }
      } else {
        SparseAdamWFunctor<T, GPUAdamW, MPDType> functor(
            beta1, beta2, epsilon, coeff, lr_ratio, beta1_pow->data<MPDType>(),
            beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad_data, param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, rows, row_numel, grad_merge.rows().size(),
            lazy_mode);

        // FIXME(minqiyang): remove BinarySearch in GPU later
        platform::ForRange<platform::CUDADeviceContext> for_range(
            static_cast<const platform::CUDADeviceContext&>(
                ctx.device_context()),
            param->numel());
        for_range(functor);
        if (!use_global_beta_pow) {
          // update beta1 and beta2
          UpdateAdamWBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
              beta1, beta2, beta1_pow->data<MPDType>(),
              beta2_pow->data<MPDType>(),
              beta1_pow_out->mutable_data<MPDType>(ctx.GetPlace()),
              beta2_pow_out->mutable_data<MPDType>(ctx.GetPlace()));
        }
      }
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Variable type not supported by adamw_op"));
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(adamw, ops::AdamWOpCUDAKernel<float>,
                        ops::AdamWOpCUDAKernel<double>,
                        ops::AdamWOpCUDAKernel<plat::float16>);