// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/gradient_accumulator.h"
#include <algorithm>
#include <memory>
#include <utility>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace imperative {

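// Visitor that accumulates x into y element-wise (y += x) via the BLAS AXPY
// routine of the device selected by the visited place; places without BLAS
// support throw.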
template <typename T>
class TensorAddFunctor : public boost::static_visitor<> {
 public:
  TensorAddFunctor(int64_t numel, const T* x, T* y)
      : numel_(numel), x_(x), y_(y) {}

  void operator()(const platform::CPUPlace& place) {
    platform::CPUDeviceContext* ctx = dynamic_cast<platform::CPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CPUDeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }

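  // The CUDA overload below is only compiled in when PADDLE_WITH_CUDA is
  // set; otherwise visiting a CUDAPlace throws.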
#ifdef PADDLE_WITH_CUDA
  void operator()(const platform::CUDAPlace& place) {
    platform::CUDADeviceContext* ctx =
        dynamic_cast<platform::CUDADeviceContext*>(
            platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CUDADeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }
#else
  void operator()(const platform::CUDAPlace& place) {
    PADDLE_THROW("Do NOT support gradient merge in place %s", place);
  }
#endif

  // There is no BLAS support for CUDAPinnedPlace.
  void operator()(const platform::CUDAPinnedPlace& place) {
    PADDLE_THROW("Gradient accumulation on place %s is not supported", place);
  }

 private:
  int64_t numel_;
  const T* x_;
  T* y_;
};

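// Accumulates src into dst (dst += src) for LoDTensor-holding variables.
// Both tensors must have the same number of elements and live on the same
// place; only float and double element types are supported.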
void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
  auto* dst_tensor = dst->GetMutable<framework::LoDTensor>();
  auto& src_tensor = src.Get<framework::LoDTensor>();

  auto numel = src_tensor.numel();

  // FIXME(minqiyang): loss_grad op will pass a zero grad of label
  // ugly fix for it
  if (numel == 0) {
    return;
  }

  PADDLE_ENFORCE_EQ(dst_tensor->numel(), numel,
                    "dst_numel %d vs. src_numel %d", dst_tensor->numel(),
                    numel);

  auto data_type = src_tensor.type();
  auto place = src_tensor.place();

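// Dispatch TensorAddFunctor on the runtime element type of the tensors.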
#define PADDLE_TENSOR_ADD_MACRO(cpp_type)                            \
  if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \
    TensorAddFunctor<cpp_type> func(                                 \
        numel, src_tensor.data<cpp_type>(),                          \
        dst_tensor->mutable_data<cpp_type>(place));                  \
    boost::apply_visitor(func, place);                               \
    return;                                                          \
  }

  PADDLE_TENSOR_ADD_MACRO(float);
  PADDLE_TENSOR_ADD_MACRO(double);

#undef PADDLE_TENSOR_ADD_MACRO

  PADDLE_THROW("Not supported data type %s for AddTo",
               framework::DataTypeToString(data_type));
}

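// Eagerly accumulates gradients in arrival order: the first incoming grad
// is moved into the accumulator, and each later grad is added in place.
// Stop-gradient variables instead get a zero-filled grad tensor.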
void EagerGradientAccumulator::Add(std::shared_ptr<VarBase> var,
                                   size_t trace_id) {
  auto* dst_var = var_->MutableVar();
  auto place = var->Var().Get<framework::LoDTensor>().place();
  if (!var_->OverridedStopGradient()) {
    VLOG(3) << "Sum Gradient for: " << var_->Name();
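    // The first gradient can take ownership of the incoming buffer; later
    // ones are summed into it with TensorAdd.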
    if (cur_cnt_ == 0) {
      *dst_var = std::move(*(var->MutableVar()));
    } else {
      TensorAdd(var->Var(), dst_var);
    }
  } else {
    if (!var_->Var().IsInitialized() ||
        !var_->Var().Get<framework::LoDTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << var_->Name() << " as zero ";

121
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
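      // A stop-gradient variable keeps a zero grad: allocate the tensor
      // (resizing it first if the variable was never initialized) and fill
      // it with zeros.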
      if (!var_->Var().IsInitialized()) {
        auto* tensor = var_->MutableVar()->GetMutable<framework::LoDTensor>();
        VLOG(6) << "Dims of " << var_->Name() << " is set as: "
                << var->Var().Get<framework::LoDTensor>().dims();
        tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
        tensor->mutable_data(place, var->DataType());
        operators::math::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor = var_->MutableVar()->GetMutable<framework::LoDTensor>();
        tensor->mutable_data(place, var->DataType());
        operators::math::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
  }
  ++cur_cnt_;
}

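// Buffers all ref_cnt_ incoming gradients, then sums them in descending
// trace_id order so the accumulation order is deterministic regardless of
// the order in which the gradients arrive.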
void SortedGradientAccumulator::Add(std::shared_ptr<VarBase> var,
                                    size_t trace_id) {
  auto* dst_var = var_->MutableVar();
  auto place = var->Var().Get<framework::LoDTensor>().place();
  if (!var_->OverridedStopGradient()) {
    if (ref_cnt_ == 1) {
      *dst_var = std::move(*(var->MutableVar()));
    } else {
      if (tmp_grad_vars_.empty()) {
        tmp_grad_vars_.reserve(ref_cnt_);
      }

      tmp_grad_vars_.emplace_back(std::move(var), trace_id);

      if (tmp_grad_vars_.size() != ref_cnt_) {
        return;
      }

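      // Every expected gradient has arrived; sort by trace_id (newest
      // first) so the summation order is reproducible.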
      std::sort(tmp_grad_vars_.begin(), tmp_grad_vars_.end(),
                [](const std::pair<std::shared_ptr<VarBase>, size_t>& p1,
                   const std::pair<std::shared_ptr<VarBase>, size_t>& p2) {
                  return p1.second > p2.second;
                });

      *dst_var = std::move(*(tmp_grad_vars_[0].first->MutableVar()));
      for (size_t i = 1; i < tmp_grad_vars_.size(); ++i) {
        TensorAdd(tmp_grad_vars_[i].first->Var(), dst_var);
      }

      tmp_grad_vars_.clear();
    }
  } else {
    if (!var_->Var().IsInitialized() ||
        !var_->Var().Get<framework::LoDTensor>().IsInitialized()) {
      VLOG(6) << "Set StopGradient Grad: " << var->Name() << " as zero";
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
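      // Zero-fill the grad of a stop-gradient variable, mirroring the eager
      // accumulator above.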
      if (!var_->Var().IsInitialized()) {
        auto* tensor = var_->MutableVar()->GetMutable<framework::LoDTensor>();
        VLOG(6) << "Dims of " << var_->Name() << " is set as: "
                << var->Var().Get<framework::LoDTensor>().dims();
        tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
        tensor->mutable_data(place, var->DataType());
        operators::math::set_constant(*dev_ctx, tensor, 0.0);
      } else {
        auto* tensor = var_->MutableVar()->GetMutable<framework::LoDTensor>();
        tensor->mutable_data(place, var->DataType());
        operators::math::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
    // tmp_grad_vars_ should already be empty here, but clear it just in case.
    tmp_grad_vars_.clear();
  }
}

}  // namespace imperative
}  // namespace paddle