/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <vector>
#include "dgc/dgc.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"

namespace paddle {
namespace operators {

inline float get_period_sparcity(const std::vector<float>& sparsity,
                                 float cur_step, float rampup_steps) {
27
  PADDLE_ENFORCE_GE(static_cast<int>(cur_step), 0);
28 29 30 31 32 33

  size_t idx = static_cast<int>(cur_step * sparsity.size() / rampup_steps);
  if (idx >= sparsity.size()) {
    return 0.999;
  }

34
  PADDLE_ENFORCE_LT(idx, sparsity.size());
35 36 37 38 39 40 41 42 43 44 45
  return sparsity[idx];
}

// Kernel of the DGC op.
//
// Compute() performs the following steps, in this order:
//   1. Grad_out = nranks * Grad, plus an optional L1/L2 regularization term
//      computed from Param.
//   2. If current_step is still before rampup_begin_step, stop here.
//   3. k = numel(Grad) * (1 - sparsity of the current ramp-up period); k is
//      written to the output "k".
//   4. Update U_out and V_out from U, V and the gradient (plain or Nesterov
//      form, selected by the "use_nesterov" attribute).
//   5. Call dgc::k_select on V_out / U_out, writing into EncodeGrad, which is
//      allocated with 2 * k elements.
//   6. Fill Grad_out with zeros.
//
// DeviceContext must provide eigen_device() and stream().
template <typename DeviceContext, typename T>
class DGCOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto u = ctx.Input<framework::Tensor>("U");
    auto v = ctx.Input<framework::Tensor>("V");
    auto g = ctx.Input<framework::Tensor>("Grad");

    auto grad_out = ctx.Output<framework::Tensor>("Grad_out");

    // attrs
    float m = ctx.Attr<float>("m");
    bool use_nesterov = ctx.Attr<bool>("use_nesterov");
    auto sparsity = ctx.Attr<std::vector<float>>("sparsity");
    auto rampup_begin_step = ctx.Attr<float>("rampup_begin_step");
    auto rampup_step = ctx.Attr<float>("rampup_step");

    // nranks
    // NOTE(review): the value is read by dereferencing data<float>() on the
    // host, so this tensor is presumably CPU-resident -- confirm.
    auto nranks_tensor = ctx.Input<framework::Tensor>("nranks");
    const int nranks = static_cast<int>(*nranks_tensor->data<float>());
    PADDLE_ENFORCE_GT(nranks, 1, "DGC is not useful when num_trainers <= 1");

    // regularization
    auto p = ctx.Input<framework::Tensor>("Param");
    float regular_coeff = ctx.Attr<float>("regular_coeff");
    int regular_type = ctx.Attr<int>("regular_type");

    auto p_e = framework::EigenVector<T>::Flatten(*p);
    auto g_e = framework::EigenVector<T>::Flatten(*g);
    auto grad_out_e = framework::EigenVector<T>::Flatten(*grad_out);

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto& eigen_ctx = *dev_ctx.eigen_device();

    // NOTE. In paddle the loss has already been divided by nranks. Because
    // dgc_op runs before allreduce, the local regular_coeff would need to be
    // divided by nranks as well. Instead the gradient is multiplied by nranks
    // here, so regular_coeff can be used unchanged, which avoids precision
    // loss: coeff is often 1e-4, and with nranks=32 coeff/nranks would be
    // 3.125e-6.
    PADDLE_ENFORCE_EQ(regular_type >= 0 && regular_type <= 2, true,
                      platform::errors::InvalidArgument(
                          "DGC only support one of None|L1Decay|L2Decay "
                          "Regularization for now."));
    if (regular_type == 0) {
      // No regularization. grad = nranks * grad
      grad_out_e.device(eigen_ctx) = (1.0 * nranks) * g_e;
    } else if (regular_type == 1) {
      // L1Decay. grad = nranks * grad + coeff * sign(param)
      grad_out_e.device(eigen_ctx) =
          (1.0 * nranks) * g_e + regular_coeff * p_e.sign();
    } else if (regular_type == 2) {
      // L2Decay. grad = nranks * grad + coeff * param
      grad_out_e.device(eigen_ctx) = (1.0 * nranks) * g_e + regular_coeff * p_e;
    }

    // current step
    // NOTE(review): read through a host pointer, like nranks above.
    auto current_step_tensor = ctx.Input<framework::Tensor>("current_step");
    const float* current_step = current_step_tensor->data<float>();

    // Before the ramp-up begins only Grad_out is produced; none of the other
    // outputs are touched.
    if (static_cast<int>(*current_step) < static_cast<int>(rampup_begin_step)) {
      VLOG(10) << "current_step:" << *current_step
               << " < rampup_begin_step:" << rampup_begin_step
               << " so does't use dgc";
      return;
    }

    // ratio is the fraction of gradient elements kept by the top-k selection.
    float ratio =
        1 - get_period_sparcity(sparsity, static_cast<float>(*current_step),
                                rampup_step);
    PADDLE_ENFORCE_GE(ratio, 0.0);
    PADDLE_ENFORCE_LT(ratio, 1.0);
    int k = static_cast<int>(g->numel() * ratio);

    VLOG(10) << "m:" << m << ", use_nesterov:" << use_nesterov
             << ", rampup_begin_step:" << rampup_begin_step
             << ", rampup_step:" << rampup_step
             << ",  current_step:" << *current_step << ", ratio:" << ratio
             << ", k:" << k << ", nranks:" << nranks;

    auto k_out = ctx.Output<framework::Tensor>("k");
    T* k_out_data = k_out->data<T>();
    *k_out_data = k;

    auto u_out = ctx.Output<framework::Tensor>("U_out");
    auto v_out = ctx.Output<framework::Tensor>("V_out");
    auto encode_grad_out = ctx.Output<framework::Tensor>("EncodeGrad");
    auto gather_buff = ctx.Output<framework::Tensor>("GatherBuff");

    // FIXME(gongwb): use cublas.
    auto u_out_e = framework::EigenVector<T>::Flatten(*u_out);
    auto u_e = framework::EigenVector<T>::Flatten(*u);

    // calc local momentum from global momentum
    // NOTE. If grad is not multiplied by nranks, the code below is needed.
    // if (static_cast<int>(*current_step) ==
    //     static_cast<int>(rampup_begin_step)) {
    //   u_out_e.device(eigen_ctx) = (1.0 / nranks) * u_e;
    // }

    // The accumulation into V_out below reads the freshly written U_out and
    // accumulates on V_out itself. When U/U_out and V/V_out share storage
    // this is the same computation as reading U and V; when they do not, it
    // keeps the formulas in the comments true (previously the second add of
    // the Nesterov branch overwrote the result of the first one).
    if (use_nesterov) {
      // u = m * (u + g)
      u_out_e.device(eigen_ctx) = m * (u_e + grad_out_e);

      // v = u + v + g, built in two adds.
      ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
          ctx, u_out, v, 0, AddFunctor<T>(), v_out);

      // NOTE(review): this adds the raw Grad, not Grad_out (which carries the
      // nranks scaling and regularization used for u) -- confirm intended.
      ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
          ctx, g, v_out, 0, AddFunctor<T>(), v_out);
    } else {
      // u = m * u + g
      u_out_e.device(eigen_ctx) = m * u_e + grad_out_e;

      // v = u + v
      ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
          ctx, u_out, v, 0, AddFunctor<T>(), v_out);
    }

    T* v_out_data = v_out->mutable_data<T>(ctx.GetPlace());
    T* u_out_data = u_out->mutable_data<T>(ctx.GetPlace());
    // EncodeGrad holds 2 * k elements; GatherBuff holds 2 * k elements per
    // rank.
    T* encode_grad_out_data = encode_grad_out->mutable_data<T>(
        framework::DDim{2 * k}, ctx.GetPlace());
    gather_buff->mutable_data<T>(framework::DDim{2 * k * nranks},
                                 ctx.GetPlace());

    // Scratch buffer for k_select; sized by the dgc library and released when
    // tmp_ious_data goes out of scope.
    int buf_size = paddle::communication::dgc::get_buffer_size(k);
    auto tmp_ious_data = memory::Alloc(dev_ctx, buf_size);
    void* buf = reinterpret_cast<void*>(tmp_ious_data->ptr());

    // A failed selection is treated as unrecoverable.
    if (!paddle::communication::dgc::k_select(
            static_cast<void*>(encode_grad_out_data), k, v_out_data,
            static_cast<int>(v_out->numel()), buf, dev_ctx.stream(),
            u_out_data)) {
      LOG(FATAL) << "v_out numel:" << v_out->numel();
    }

    // Grad_out is cleared once the selection has been made.
    math::SetConstant<DeviceContext, T> tset;
    tset(dev_ctx, grad_out, static_cast<T>(0));
  }
};
}  // namespace operators
}  // namespace paddle