/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/lars_momentum_op.h"

namespace paddle {
namespace operators {
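// LARS (Layer-wise Adaptive Rate Scaling) momentum optimizer, with optional
// multi-precision support: parameters may be stored in a low-precision type T
// (e.g. float16) while the velocity and an optional master copy of the
// parameters are kept in the higher-precision type MT.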

template <typename T>
using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;

template <typename T, typename MT>
__global__ void MomentumLarsKernel(
    const T* p, const T* g, const MT* v,
    const MultiPrecisionType<T>* learning_rate, const MT mu, const int64_t num,
    const MT lars_coeff, const MT lars_weight_decay,
    const MultiPrecisionType<T>* p_norm, const MultiPrecisionType<T>* g_norm,
    T* p_out, MT* v_out, const MT epsilon, const MT* master_p, MT* master_p_out,
    const MultiPrecisionType<T> rescale_grad) {
  const MT lr = static_cast<MT>(learning_rate[0]);
  MT local_lr = lr;
  const MT p_n = static_cast<MT>(p_norm[0]);
  const MT g_n = static_cast<MT>(g_norm[0]);

  if (lars_weight_decay > static_cast<MT>(0) && p_n > static_cast<MT>(0) &&
      g_n > static_cast<MT>(0)) {
    local_lr =
        lr * lars_coeff * p_n / (g_n + lars_weight_decay * p_n + epsilon);
  }
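  // local_lr is now the LARS trust-ratio-scaled learning rate. Apply the
  // momentum update element-wise in a grid-stride loop, computing in MT and
  // writing the parameter back in its storage type T (and to the master copy,
  // if one is provided).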
  CUDA_KERNEL_LOOP(i, num) {
    MT grad = static_cast<MT>(g[i]) * static_cast<MT>(rescale_grad);
    MT param = master_p ? master_p[i] : static_cast<MT>(p[i]);

    MT v_new = v[i] * mu + local_lr * (grad + lars_weight_decay * param);
    MT p_new = param - v_new;

    v_out[i] = v_new;
    p_out[i] = static_cast<T>(p_new);
    if (master_p_out) master_p_out[i] = p_new;
  }
}

template <typename DeviceContext, typename T>
class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
  using MPDType = MultiPrecisionType<T>;

 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
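    // Dispatch on the `multi_precision` attribute: run the update math in the
    // higher-precision MPDType when it is set, otherwise in the parameter
    // type T itself.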
    const bool multi_precision = ctx.Attr<bool>("multi_precision");
    if (multi_precision) {
      InnerCompute<MPDType>(ctx, multi_precision);
    } else {
      InnerCompute<T>(ctx, multi_precision);
    }
  }

 private:
  template <typename MT>
  void InnerCompute(const framework::ExecutionContext& ctx,
                    const bool multi_precision) const {
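    // Fetch inputs/outputs and attributes, compute the parameter and gradient
    // L2 norms on the device, then launch the fused LARS momentum kernel.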
    auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
    auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
    auto param = ctx.Input<framework::LoDTensor>("Param");
    auto velocity = ctx.Input<framework::LoDTensor>("Velocity");
    auto grad = ctx.Input<framework::LoDTensor>("Grad");
    auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");

    const framework::Tensor* master_param = nullptr;
    framework::Tensor* master_param_out = nullptr;
    if (multi_precision) {
      bool has_master =
          ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
      PADDLE_ENFORCE_EQ(has_master, true,
                        platform::errors::InvalidArgument(
                            "The Input(MasterParam) and Output(MasterParamOut) "
                            "should not be null when "
                            "the attr `multi_precision` is true"));
      master_param = ctx.Input<framework::Tensor>("MasterParam");
      master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
    }

    const MT* master_p = multi_precision ? master_param->data<MT>() : nullptr;
    MT* master_p_out = multi_precision
                           ? master_param_out->mutable_data<MT>(ctx.GetPlace())
                           : nullptr;

    T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
    MT* v_out = velocity_out->mutable_data<MT>(ctx.GetPlace());

    MT mu = static_cast<MT>(ctx.Attr<float>("mu"));
    MT lars_coeff = static_cast<MT>(ctx.Attr<float>("lars_coeff"));
    MT lars_weight_decay =
        static_cast<MT>(ctx.Attr<float>("lars_weight_decay"));
    MT epsilon = static_cast<MT>(ctx.Attr<float>("epsilon"));
    MPDType rescale_grad =
        static_cast<MPDType>(ctx.Attr<float>("rescale_grad"));

    auto* p = param->data<T>();
    auto* g = grad->data<T>();
    auto* v = velocity->data<MT>();
    auto* lr = learning_rate->data<MPDType>();

    int block = 512;
    int grid = (param->numel() + block - 1) / block;

    auto eigen_p = framework::EigenVector<T>::Flatten(*param);
    auto eigen_g = framework::EigenVector<T>::Flatten(*grad);
    // Calculate the parameter and gradient norms with Eigen, then launch the
    // kernel.
    framework::Tensor p_norm_t, g_norm_t;
    p_norm_t.Resize({1});
    g_norm_t.Resize({1});
    auto* p_norm_data = p_norm_t.mutable_data<MPDType>(ctx.GetPlace());
    auto* g_norm_data = g_norm_t.mutable_data<MPDType>(ctx.GetPlace());
    auto ep_norm = framework::EigenScalar<MPDType>::From(p_norm_t);
    auto eg_norm = framework::EigenScalar<MPDType>::From(g_norm_t);

    auto* place = ctx.template device_context<DeviceContext>().eigen_device();

    // Eigen does not support an fp16 l2-norm, so cast to MPDType before
    // reducing.
    ep_norm.device(*place) =
        eigen_p.template cast<MPDType>().square().sum().sqrt();
    eg_norm.device(*place) =
        (eigen_g.template cast<MPDType>() * rescale_grad).square().sum().sqrt();

    MomentumLarsKernel<
        T, MT><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
        p, g, v, lr, mu, param->numel(), lars_coeff, lars_weight_decay,
        p_norm_data, g_norm_data, p_out, v_out, epsilon, master_p, master_p_out,
        rescale_grad);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
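// Register CUDA kernels for float, double and float16 parameters.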
REGISTER_OP_CUDA_KERNEL(
    lars_momentum,
    ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, double>,
    ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                  paddle::platform::float16>);