// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/layer.h"

#include <algorithm>
#include <deque>
#include <limits>
#include <map>
#include <random>
#include <unordered_set>
#include <utility>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace imperative {

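// Thread-safe set of the names of currently tracked VarBase objects, exposed
// through VarBase::AliveVarNames().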
void ThreadSafeNameSet::Insert(const std::string& name) {
  std::lock_guard<std::mutex> guard(mtx_);
  set_.insert(name);
}

void ThreadSafeNameSet::Remove(const std::string& name) {
  std::lock_guard<std::mutex> guard(mtx_);
  auto iter = set_.find(name);
  PADDLE_ENFORCE(iter != set_.end(), "%s does not exist", name);
  set_.erase(iter);
}

std::vector<std::string> ThreadSafeNameSet::Names() const {
  std::lock_guard<std::mutex> guard(mtx_);
  return std::vector<std::string>(set_.begin(), set_.end());
}

ThreadSafeNameSet VarBase::name_set_;

std::vector<std::string> VarBase::AliveVarNames() { return name_set_.Names(); }

using framework::Variable;

namespace detail {

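// Functor that adds x into y element-wise (y += x) through a BLAS AXPY call,
// dispatched on the place the tensors live on via boost::apply_visitor.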
template <typename T>
class TensorAddToFunctor : public boost::static_visitor<> {
 public:
  TensorAddToFunctor(int64_t numel, const T* x, T* y)
      : numel_(numel), x_(x), y_(y) {}

  void operator()(const platform::CPUPlace& place) {
    platform::CPUDeviceContext* ctx = dynamic_cast<platform::CPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CPUDeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }

#ifdef PADDLE_WITH_CUDA
  void operator()(const platform::CUDAPlace& place) {
    platform::CUDADeviceContext* ctx =
        dynamic_cast<platform::CUDADeviceContext*>(
            platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CUDADeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }
#else
  void operator()(const platform::CUDAPlace& place) {
    PADDLE_THROW("Do NOT support gradient merge in place %s", place);
  }
#endif

  // there is NO blas in CUDAPinnedPlace
  void operator()(const platform::CUDAPinnedPlace& place) {
    PADDLE_THROW("Do NOT support gradient merge in place %s", place);
  }

 private:
  int64_t numel_;
  const T* x_;
  T* y_;
};

}  // namespace detail

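// Accumulate the gradient held in src into dst. On the first contribution
// (while the grad_ref entry is still marked pending) the buffer is moved
// instead of added; later contributions are summed element-wise.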
void AddTo(std::shared_ptr<VarBase> src, std::shared_ptr<VarBase> dst,
           platform::Place place, GradientRef* grad_ref) {
  PADDLE_ENFORCE(grad_ref->find(dst.get()) != grad_ref->end(),
                 "gradient %s is not found in grad_ref", dst->Name());
  if ((*grad_ref)[dst.get()].second) {
    PADDLE_ENFORCE(src->IsInitialize(), "Using uninitialized VarBase");
    dst->var_ = std::move(src->var_);
    (*grad_ref)[dst.get()].second = false;
    if (!dst->IsInitialize()) {
      dst->SetInitialize(true);
    }
    return;
  } else {
    framework::Tensor* dst_tensor =
        dst->var_->GetMutable<framework::LoDTensor>();
    framework::Tensor* src_tensor =
        src->var_->GetMutable<framework::LoDTensor>();

    // FIXME(minqiyang): loss_grad op will pass a zero grad of label
    // ugly fix for it
    if (src_tensor->numel() == 0) {
      return;
    }

    PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
                   "dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
                   src_tensor->numel());

    detail::TensorAddToFunctor<float> func(
        src_tensor->numel(), src_tensor->data<float>(),
        dst_tensor->mutable_data<float>(place));
    boost::apply_visitor(func, place);
  }
}

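// Fill the gradient tensor of vb with zeros on the given place.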
void ZeroGrads(const std::shared_ptr<imperative::VarBase> vb,
               const platform::Place& place) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);
  auto grad_t = vb->var_->GetMutable<framework::LoDTensor>();
  operators::math::set_constant(*dev_ctx, grad_t, 0.0);
}

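// Merge every gradient contribution recorded for `target` in bck_map,
// applying them in descending trace-id order so the sum order is fixed.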
void AddGradBySort(BackwardSumMap* bck_map,
                   std::shared_ptr<imperative::VarBase> target,
                   GradientRef* grad_ref) {
  PADDLE_ENFORCE(bck_map->find(target.get()) != bck_map->end(),
                 "Can't find %s in backward grad map", target->Name());
  std::pair<platform::Place,
            std::vector<std::pair<int, std::shared_ptr<imperative::VarBase>>>>&
      current = bck_map->at(target.get());
  std::sort(current.second.begin(), current.second.end(),
            [](const std::pair<int, std::shared_ptr<imperative::VarBase>>& a,
               const std::pair<int, std::shared_ptr<imperative::VarBase>>& b) {
              return a.first > b.first;
            });
  for (auto& var_pair : current.second) {
    VLOG(10) << "add origin_grad: " << target->Name();
    VLOG(10) << "added grad: " << var_pair.second->Name()
             << " trace id is: " << var_pair.first;
    AddTo(var_pair.second, target, current.first, grad_ref);
    var_pair.second.reset();
  }
}

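// Walks the recorded op graph backwards from a starting variable and applies
// each op's grad ops once all ops that depend on it have been processed.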
class Autograd {
 public:
  Autograd() {}

  void RunBackward(VarBase* var, const detail::BackwardStrategy& bck_stratedy) {
    if (var->IsStopGradient()) {
      return;
    }
    VLOG(2) << "start autograd";
    BackwardSumMap bck_map;
    std::deque<OpBase*> ready;
    ready.push_back(var->PreOp());

    std::map<OpBase*, int> dep_counts =
        ComputeDepCounts(var->PreOp(), bck_stratedy, &grad_ref);

    while (!ready.empty()) {
      OpBase* ready_op = ready.front();
      ready.pop_front();
      std::vector<VarBasePtrMap> grads_outputs =
          ready_op->ApplyGrad(&bck_map, &grad_ref, bck_stratedy);

      for (const auto& map : grads_outputs) {
        for (auto it = map.rbegin(); it != map.rend(); ++it) {
          const std::vector<std::shared_ptr<VarBase>>& grad_outs = it->second;
          for (size_t i = 0; i < grad_outs.size(); ++i) {
            if (!grad_outs[i] || grad_outs[i]->IsStopGradient()) continue;
            OpBase* pre_op = grad_outs[i]->PreOp();
            if (!pre_op) continue;
            dep_counts[pre_op] -= 1;
            PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
            bool pre_op_ready = dep_counts[pre_op] == 0;
            if (pre_op_ready) {
              ready.push_back(pre_op);
            }
          }
        }
      }

      ready_op->InvokeBackwardHooks();
    }
  }

 private:
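  // BFS over pre_ops_ starting from `op`: counts how many times each reachable
  // op is depended on and, when the sorted-sum strategy is enabled, how many
  // gradient contributions each output variable will receive.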
  std::map<OpBase*, int> ComputeDepCounts(
      OpBase* op, const detail::BackwardStrategy& bck_stratedy,
      GradientRef* grad_ref) {
    if (bck_stratedy.sorted_sum_gradient_) {
      PADDLE_ENFORCE_NOT_NULL(grad_ref,
                              "grad_ref should not be null when "
                              "using sorted grad backward strategy");
    }
    std::map<OpBase*, int> ret;

    std::deque<OpBase*> queue;
    queue.push_back(op);
    std::unordered_set<OpBase*> visited;
    visited.insert(op);
    while (!queue.empty()) {
      OpBase* candidate = queue.front();
      queue.pop_front();
      for (const auto& map : candidate->grad_output_vars_) {
        for (const auto& it : map) {
          for (const auto& vb : it.second) {
            if (bck_stratedy.sorted_sum_gradient_) {
              ++(*grad_ref)[vb.get()].first;
            }
            // init the state of the grad_
            (*grad_ref)[vb.get()].second = true;
          }
        }
      }
      for (auto it : candidate->pre_ops_) {
        for (OpBase* pre_op : it.second) {
          if (!pre_op) continue;
          VLOG(2) << "op dep " << candidate->Type() << " trace id "
                  << candidate->trace_id_ << " <---- " << it.first << " <---- "
                  << pre_op->Type() << " trace id " << pre_op->trace_id_;
          if (visited.find(pre_op) == visited.end()) {
            visited.insert(pre_op);
            queue.push_back(pre_op);
          }
          ret[pre_op] += 1;
        }
      }
    }
    return ret;
  }

  GradientRef grad_ref;
};

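// Copy this variable's tensor to dst_place and wrap it in a fresh VarBase.
// When blocking is true, wait for the copy (and the source device, if it
// differs from dst_place) to finish before returning.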
std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
                                             const bool blocking) const {
  PADDLE_ENFORCE(var_->IsInitialized(),
                 "Variable must be initialized when getting numpy tensor");

  // TODO(minqiyang): change this after move unique_name generator to CXX
  const framework::LoDTensor& self_tensor = var_->Get<framework::LoDTensor>();
  std::unique_ptr<VarBase> new_var(new VarBase(
      "Itmp", self_tensor.type(), self_tensor.dims(), dst_place, true, false));
  framework::LoDTensor* tensor =
      new_var->var_->GetMutable<framework::LoDTensor>();
  tensor->set_lod(var_->Get<framework::LoDTensor>().lod());

  const auto& src_tensor = var_->Get<framework::LoDTensor>();
  framework::TensorCopy(src_tensor, dst_place, tensor);
  if (blocking) {
    platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
    auto src_place = src_tensor.place();
    if (!(src_place == dst_place)) {
      platform::DeviceContextPool::Instance().Get(src_place)->Wait();
    }
  }

  if (platform::is_gpu_place(dst_place)) {
    VLOG(3) << "copy tensor " << Name() << " from gpu";
  }

  return new_var;
}

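// Return a mutable reference to the gradient tensor; fails if this variable
// has no gradient.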
framework::LoDTensor& VarBase::GradValue() {
  VLOG(3) << "get var grad " << Name();
  PADDLE_ENFORCE_NOT_NULL(grads_,
                          "Could not get grad value from no grad variable");
  return *(grads_->var_->GetMutable<framework::LoDTensor>());
}

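// Run all grad ops recorded for this forward op: allocate temporary output
// variables, infer their types, execute every grad op kernel, and then merge
// the temporaries back into the original gradient variables.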
std::vector<VarBasePtrMap> OpBase::ApplyGrad(
    BackwardSumMap* bck_map, GradientRef* grad_ref,
    const detail::BackwardStrategy& bck_stratedy) {
  PADDLE_ENFORCE(!grad_op_descs_.empty(), "%s has no backward implementation",
                 Type());
  VLOG(3) << "apply op grad: " << Type();
  std::vector<VarBasePtrMap> tmp_grad_outputs;
  const size_t grad_op_count = grad_op_descs_.size();

  tmp_grad_outputs.resize(grad_op_count);
  for (size_t k = 0; k < grad_op_count; ++k) {
    framework::OpDesc* grad_op_desc = grad_op_descs_[k];
    platform::RecordEvent record_event(grad_op_desc->Type());
    auto& grad_output_variable_map = grad_output_vars_[k];
    VLOG(3) << "apply grad op " << grad_op_desc->Type();

    // Allocate tmp grad output variable
    for (const auto& it : grad_output_variable_map) {
      auto& outputs = tmp_grad_outputs[k][it.first];
      outputs.reserve(it.second.size());
      for (const std::shared_ptr<imperative::VarBase>& origin_grad_var_base :
           it.second) {
        // Allocate a new variable
        std::shared_ptr<imperative::VarBase> tmp_grad_var_base(new VarBase(
            string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
            origin_grad_var_base->DataType(), origin_grad_var_base->Dims(),
            place_, true, false));
        outputs.emplace_back(std::move(tmp_grad_var_base));
      }
    }

    // No need to do compile time infer shape here.
    // grad_op_desc_->InferShape(*block_);
    // grad_op_desc->InferVarType(block_);

    std::unique_ptr<framework::OperatorBase> opbase =
        framework::OpRegistry::CreateOp(*grad_op_desc);

    auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
    if (info.infer_var_type_) {
      RuntimeInferVarTypeContext infer_var_type_ctx(
          &grad_input_vars_[k], &tmp_grad_outputs[k], &(opbase->Attrs()));
      info.infer_var_type_(&infer_var_type_ctx);
    }

    framework::OperatorWithKernel* op_kernel =
        dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
    PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");

    // Run grad op
    framework::VariableValueMap grad_invars_map;
    framework::VariableValueMap grad_outvars_map;

    for (const auto& it : grad_input_vars_[k]) {
      auto& grad_invars = grad_invars_map[it.first];
      grad_invars.reserve(it.second.size());
      for (const std::shared_ptr<imperative::VarBase>& grad_inp : it.second) {
        PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
                                grad_op_desc->Type(), grad_inp->Name());
        if (!grad_inp->IsInitialize()) {
          grad_inp->InitBuffer();
          ZeroGrads(grad_inp, place_);
        }
        const std::shared_ptr<imperative::VarBase>& const_grad_inp = grad_inp;
        grad_invars.emplace_back(const_grad_inp->var_.get());
      }
    }

    for (const auto& it : tmp_grad_outputs[k]) {
      auto& grad_outvars = grad_outvars_map[it.first];
      grad_outvars.reserve(it.second.size());
      for (const std::shared_ptr<imperative::VarBase>& grad_out : it.second) {
        PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr",
                                grad_op_desc->Type(), grad_out->Name());

        grad_outvars.emplace_back(grad_out->var_.get());
      }
    }

    framework::RuntimeContext ctx(grad_invars_map, grad_outvars_map);
    framework::Scope scope;
    PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
    p.op.RuntimeInferShape(scope, place_, ctx);
    p.func(
        framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx, nullptr));
  }

  platform::RecordEvent record_event("merge_grads");
  // Add tmp grad outputs to original grad vars
  for (size_t k = 0; k < grad_output_vars_.size(); ++k) {
    for (const auto& it : grad_output_vars_[k]) {
      auto& outputs = tmp_grad_outputs[k][it.first];
      const auto& origin_outputs = it.second;
      PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());

      for (size_t i = 0; i < outputs.size(); ++i) {
        // track outputs used by sum
        if (bck_stratedy.sorted_sum_gradient_) {
          if (bck_map->find(origin_outputs[i].get()) != bck_map->end()) {
            VLOG(10) << "add sub grad to " << origin_outputs[i]->Name();
            bck_map->at(origin_outputs[i].get())
                .second.emplace_back(
                    std::pair<int, std::shared_ptr<imperative::VarBase>>(
                        this->trace_id_, std::move(outputs[i])));
          } else {
            VLOG(10) << "insert new map for " << origin_outputs[i]->Name();
            std::pair<platform::Place,
                      std::vector<
                          std::pair<int, std::shared_ptr<imperative::VarBase>>>>
                tmp(place_,
                    {std::make_pair(this->trace_id_, std::move(outputs[i]))});
            bck_map->insert(std::make_pair(origin_outputs[i].get(), tmp));
          }

          PADDLE_ENFORCE(
              grad_ref->find(origin_outputs[i].get()) != grad_ref->end(),
              "Can't find %s in grad_reference count map",
              origin_outputs[i]->Name());
          PADDLE_ENFORCE(grad_ref->at(origin_outputs[i].get()).first >= 1,
                         "Backward error when calculating grad reference");
          if (grad_ref->at(origin_outputs[i].get()).first > 1) {
            VLOG(10) << "remove ref for " << origin_outputs[i]->Name();
            grad_ref->at(origin_outputs[i].get()).first--;
          } else {
            VLOG(10) << "Add grad for: " << origin_outputs[i]->Name();
            AddGradBySort(bck_map, origin_outputs[i], grad_ref);
            grad_ref->at(origin_outputs[i].get()).first--;
          }
        } else {
          VLOG(10) << "AddTo Called with orig_grad is: "
                   << origin_outputs[i]->name_ << " Grad to be added is "
                   << outputs[i]->name_;
          AddTo(outputs[i], origin_outputs[i], place_, grad_ref);
          outputs[i].reset();
        }
      }
    }
  }

  return grad_output_vars_;
}

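// Call every registered Python backward hook with this op.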
void OpBase::InvokeBackwardHooks() {
  VLOG(3) << "call backward hooks, hooks num: " << backward_hooks_.size();

  // call backward hooks
  for (py::object& callable : backward_hooks_) {
    callable(this);
  }
}

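// Append a Python callable to be invoked after this op's backward pass.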
void OpBase::RegisterBackwardHooks(const py::object& callable) {
  VLOG(3) << "Register backward hooks " << trace_id_;

  // TODO(minqiyang): check the callable format
  backward_hooks_.push_back(callable);
}

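// Seed this variable's gradient with ones and run autograd from its
// creating op.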
void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) {
  if (!pre_op_) return;
  platform::RecordEvent record_event("Imperative Backward");
  VLOG(3) << "start backward";
  grads_->InitBuffer();
  auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
  operators::math::set_constant(
      *(platform::DeviceContextPool::Instance().Get(
          var_->GetMutable<framework::LoDTensor>()->place())),
      grads_t, 1.0);

  Autograd().RunBackward(this, bck_stratedy);
}

}  // namespace imperative
}  // namespace paddle