layer.cc 16.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/layer.h"

#include <algorithm>
#include <deque>
#include <limits>
#include <map>
#include <random>
#include <unordered_set>
#include <utility>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace imperative {

using framework::Variable;

M
minqiyang 已提交
39 40 41 42 43 44 45 46 47 48 49
namespace detail {

template <typename T>
class TensorAddToFunctor : public boost::static_visitor<> {
 public:
  TensorAddToFunctor(int64_t numel, const T* x, T* y)
      : numel_(numel), x_(x), y_(y) {}

  void operator()(const platform::CPUPlace& place) {
    platform::CPUDeviceContext* ctx = dynamic_cast<platform::CPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
P
Paddle CI 已提交
50
    auto blas = operators::math::GetBlas<platform::CPUDeviceContext, T>(*ctx);
M
minqiyang 已提交
51 52 53 54 55 56 57 58
    blas.AXPY(numel_, 1., x_, y_);
  }

#ifdef PADDLE_WITH_CUDA
  void operator()(const platform::CUDAPlace& place) {
    platform::CUDADeviceContext* ctx =
        dynamic_cast<platform::CUDADeviceContext*>(
            platform::DeviceContextPool::Instance().Get(place));
P
Paddle CI 已提交
59
    auto blas = operators::math::GetBlas<platform::CUDADeviceContext, T>(*ctx);
M
minqiyang 已提交
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
    blas.AXPY(numel_, 1., x_, y_);
  }
#else
  void operator()(const platform::CUDAPlace& place) {
    PADDLE_THROW("Do NOT support gradient merge in place %s", place);
  }
#endif

  // there is NO blas in CUDAPinnedPlace
  void operator()(const platform::CUDAPinnedPlace& place) {
    PADDLE_THROW("Do NOT support gradient merge in place %s", place);
  }

 private:
  int64_t numel_;
  const T* x_;
  T* y_;
};

}  // namespace detail

81
void AddTo(std::shared_ptr<VarBase> src, std::shared_ptr<VarBase> dst,
82 83 84 85
           platform::Place place, GradientRef* grad_ref) {
  PADDLE_ENFORCE(grad_ref->find(dst.get()) != grad_ref->end(),
                 "gradient %s are not found in grad_ref", dst->Name());
  if ((*grad_ref)[dst.get()].second) {
86 87
    PADDLE_ENFORCE(src->IsInitialize(), "Using uninitialized VarBase");
    dst->var_ = std::move(src->var_);
88 89 90 91
    (*grad_ref)[dst.get()].second = false;
    if (!dst->IsInitialize()) {
      dst->SetInitialize(true);
    }
M
minqiyang 已提交
92
    return;
93 94 95 96 97 98 99 100 101 102 103
  } else {
    framework::Tensor* dst_tensor =
        dst->var_->GetMutable<framework::LoDTensor>();
    framework::Tensor* src_tensor =
        src->var_->GetMutable<framework::LoDTensor>();

    // FIXME(minqiyang): loss_grad op will pass a zero grad of label
    // ugly fix for it
    if (src_tensor->numel() == 0) {
      return;
    }
M
minqiyang 已提交
104

105 106 107
    PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
                   "dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
                   src_tensor->numel());
M
minqiyang 已提交
108

109 110 111 112 113
    detail::TensorAddToFunctor<float> func(
        src_tensor->numel(), src_tensor->data<float>(),
        dst_tensor->mutable_data<float>(place));
    boost::apply_visitor(func, place);
  }
114 115
}

116 117
void ZeroGrads(const std::shared_ptr<imperative::VarBase> vb,
               const platform::Place& place) {
118 119 120 121 122 123
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);
  auto grad_t = vb->var_->GetMutable<framework::LoDTensor>();
  operators::math::set_constant(*dev_ctx, grad_t, 0.0);
}

124
void AddGradBySort(BackwardSumMap* bck_map,
125 126
                   std::shared_ptr<imperative::VarBase> target,
                   GradientRef* grad_ref) {
127
  PADDLE_ENFORCE(bck_map->find(target.get()) != bck_map->end(),
128
                 "Can't find %s in backward grad map", target->Name());
129 130 131 132 133 134 135 136
  std::pair<platform::Place,
            std::vector<std::pair<int, std::shared_ptr<imperative::VarBase>>>>&
      current = bck_map->at(target.get());
  std::sort(current.second.begin(), current.second.end(),
            [](const std::pair<int, std::shared_ptr<imperative::VarBase>>& a,
               const std::pair<int, std::shared_ptr<imperative::VarBase>>& b) {
              return a.first > b.first;
            });
137
  for (auto& var_pair : current.second) {
138 139 140
    VLOG(10) << "add origin_grad: " << target->Name();
    VLOG(10) << "added grad: " << var_pair.second->Name()
             << " trace id is: " << var_pair.first;
141
    AddTo(var_pair.second, target, current.first, grad_ref);
142
    var_pair.second.reset();
143 144 145
  }
}

146 147
class Autograd {
 public:
X
Xin Pan 已提交
148
  Autograd() {}
149

150
  void RunBackward(VarBase* var, const detail::BackwardStrategy& bck_stratedy) {
X
Xin Pan 已提交
151
    if (var->IsStopGradient()) {
152 153
      return;
    }
154
    VLOG(2) << "start autograd";
155
    BackwardSumMap bck_map;
156
    std::deque<OpBase*> ready;
X
Xin Pan 已提交
157
    ready.push_back(var->PreOp());
158

159
    std::map<OpBase*, int> dep_counts =
160
        ComputeDepCounts(var->PreOp(), bck_stratedy, &grad_ref);
161 162 163 164

    while (!ready.empty()) {
      OpBase* ready_op = ready.front();
      ready.pop_front();
165
      std::vector<VarBasePtrMap> grads_outputs =
166
          ready_op->ApplyGrad(&bck_map, &grad_ref, bck_stratedy);
X
Xin Pan 已提交
167

168 169 170 171 172 173 174 175 176 177 178 179 180
      for (const auto& map : grads_outputs) {
        for (auto it = map.rbegin(); it != map.rend(); ++it) {
          const std::vector<std::shared_ptr<VarBase>>& grad_outs = it->second;
          for (size_t i = 0; i < grad_outs.size(); ++i) {
            if (!grad_outs[i] || grad_outs[i]->IsStopGradient()) continue;
            OpBase* pre_op = grad_outs[i]->PreOp();
            if (!pre_op) continue;
            dep_counts[pre_op] -= 1;
            PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
            bool pre_op_ready = dep_counts[pre_op] == 0;
            if (pre_op_ready) {
              ready.push_back(pre_op);
            }
X
Xin Pan 已提交
181
          }
182 183
        }
      }
184 185

      ready_op->InvokeBackwardHooks();
186 187 188 189
    }
  }

 private:
190
  std::map<OpBase*, int> ComputeDepCounts(
191 192 193 194 195 196 197
      OpBase* op, const detail::BackwardStrategy& bck_stratedy,
      GradientRef* grad_ref) {
    if (bck_stratedy.sorted_sum_gradient_) {
      PADDLE_ENFORCE_NOT_NULL(grad_ref,
                              "grad_ref should not be null when "
                              "using sorted grad backward strategy");
    }
198 199 200 201 202 203 204 205 206
    std::map<OpBase*, int> ret;

    std::deque<OpBase*> queue;
    queue.push_back(op);
    std::unordered_set<OpBase*> visited;
    visited.insert(op);
    while (!queue.empty()) {
      OpBase* candidate = queue.front();
      queue.pop_front();
207 208 209 210 211
      for (const auto& map : candidate->grad_output_vars_) {
        for (const auto& it : map) {
          for (const auto& vb : it.second) {
            if (bck_stratedy.sorted_sum_gradient_) {
              ++(*grad_ref)[vb.get()].first;
212
            }
213 214
            // init the state of the grad_
            (*grad_ref)[vb.get()].second = true;
215 216 217
          }
        }
      }
X
Xin Pan 已提交
218
      for (auto it : candidate->pre_ops_) {
X
Xin Pan 已提交
219 220
        for (OpBase* pre_op : it.second) {
          if (!pre_op) continue;
221
          VLOG(2) << "op dep " << candidate->Type() << " trace id "
222
                  << candidate->trace_id_ << " <---- " << it.first << " <---- "
223
                  << pre_op->Type() << " trace id " << pre_op->trace_id_;
X
Xin Pan 已提交
224 225 226 227 228
          if (visited.find(pre_op) == visited.end()) {
            visited.insert(pre_op);
            queue.push_back(pre_op);
          }
          ret[pre_op] += 1;
229 230 231 232 233
        }
      }
    }
    return ret;
  }
234 235

  GradientRef grad_ref;
236 237
};

M
minqiyang 已提交
238 239
std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
                                             const bool blocking) const {
M
minqiyang 已提交
240 241 242
  PADDLE_ENFORCE(var_->IsInitialized(),
                 "Variable must be initialized when getting numpy tensor");

243 244 245 246
  // TODO(minqiyang): change this after move unique_name generator to CXX
  const framework::LoDTensor& self_tensor = var_->Get<framework::LoDTensor>();
  std::unique_ptr<VarBase> new_var(new VarBase(
      "Itmp", self_tensor.type(), self_tensor.dims(), dst_place, true, false));
P
Paddle CI 已提交
247 248 249
  framework::LoDTensor* tensor =
      new_var->var_->GetMutable<framework::LoDTensor>();
  tensor->set_lod(var_->Get<framework::LoDTensor>().lod());
M
minqiyang 已提交
250

251 252
  const auto& src_tensor = var_->Get<framework::LoDTensor>();
  framework::TensorCopy(src_tensor, dst_place, tensor);
P
Paddle CI 已提交
253
  if (blocking) {
254 255 256 257 258
    platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
    auto src_place = src_tensor.place();
    if (!(src_place == dst_place)) {
      platform::DeviceContextPool::Instance().Get(src_place)->Wait();
    }
P
Paddle CI 已提交
259 260 261
  }

  if (platform::is_gpu_place(dst_place)) {
262
    VLOG(3) << "copy tensor " << Name() << " from gpu";
M
minqiyang 已提交
263 264
  }

P
Paddle CI 已提交
265
  return new_var;
M
minqiyang 已提交
266 267
}

M
minqiyang 已提交
268
framework::LoDTensor& VarBase::GradValue() {
269 270 271
  VLOG(3) << "get var grad " << Name();
  PADDLE_ENFORCE_NOT_NULL(grads_,
                          "Could not get grad value from no grad variable");
M
minqiyang 已提交
272
  return *(grads_->var_->GetMutable<framework::LoDTensor>());
273 274
}

275
std::vector<VarBasePtrMap> OpBase::ApplyGrad(
276 277
    BackwardSumMap* bck_map, GradientRef* grad_ref,
    const detail::BackwardStrategy& bck_stratedy) {
278 279
  PADDLE_ENFORCE(!grad_op_descs_.empty(), "%s has no backward implementation",
                 Type());
280
  VLOG(3) << "apply op grad: " << Type();
M
minqiyang 已提交
281
  std::vector<VarBasePtrMap> tmp_grad_outputs;
282 283 284 285 286 287 288 289 290 291 292 293 294
  const size_t grad_op_count = grad_op_descs_.size();

  tmp_grad_outputs.resize(grad_op_count);
  for (size_t k = 0; k < grad_op_count; ++k) {
    framework::OpDesc* grad_op_desc = grad_op_descs_[k];
    platform::RecordEvent record_event(grad_op_desc->Type());
    auto& grad_output_variable_map = grad_output_vars_[k];
    VLOG(3) << "apply grad op " << grad_op_desc->Type();

    // Allocate tmp grad output variable
    for (const auto& it : grad_output_variable_map) {
      auto& outputs = tmp_grad_outputs[k][it.first];
      outputs.reserve(it.second.size());
295 296
      for (const std::shared_ptr<imperative::VarBase>& origin_grad_var_base :
           it.second) {
297
        // Allocate a new variable
298
        std::shared_ptr<imperative::VarBase> tmp_grad_var_base(new VarBase(
299 300
            string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
            origin_grad_var_base->DataType(), origin_grad_var_base->Dims(),
301 302
            place_, true, false));
        outputs.emplace_back(std::move(tmp_grad_var_base));
X
polish  
Xin Pan 已提交
303
      }
304
    }
305

306 307 308
    // No need to do compile time infer shape here.
    // grad_op_desc_->InferShape(*block_);
    // grad_op_desc->InferVarType(block_);
X
Xin Pan 已提交
309

310 311
    std::unique_ptr<framework::OperatorBase> opbase =
        framework::OpRegistry::CreateOp(*grad_op_desc);
M
minqiyang 已提交
312

313 314 315
    auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
    if (info.infer_var_type_) {
      RuntimeInferVarTypeContext infer_var_type_ctx(
316
          &grad_input_vars_[k], &tmp_grad_outputs[k], &(opbase->Attrs()));
317 318
      info.infer_var_type_(&infer_var_type_ctx);
    }
M
minqiyang 已提交
319

320 321 322
    framework::OperatorWithKernel* op_kernel =
        dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
    PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
X
Xin Pan 已提交
323

324 325 326
    // Run grad op
    framework::VariableValueMap grad_invars_map;
    framework::VariableValueMap grad_outvars_map;
M
minqiyang 已提交
327

328 329 330
    for (const auto& it : grad_input_vars_[k]) {
      auto& grad_invars = grad_invars_map[it.first];
      grad_invars.reserve(it.second.size());
331
      for (const std::shared_ptr<imperative::VarBase>& grad_inp : it.second) {
332 333
        PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
                                grad_op_desc->Type(), grad_inp->Name());
334 335 336 337
        if (!grad_inp->IsInitialize()) {
          grad_inp->InitBuffer();
          ZeroGrads(grad_inp, place_);
        }
338
        const std::shared_ptr<imperative::VarBase>& const_grad_inp = grad_inp;
339
        grad_invars.emplace_back(const_grad_inp->var_.get());
M
minqiyang 已提交
340
      }
341
    }
M
minqiyang 已提交
342

343 344 345
    for (const auto& it : tmp_grad_outputs[k]) {
      auto& grad_outvars = grad_outvars_map[it.first];
      grad_outvars.reserve(it.second.size());
346
      for (const std::shared_ptr<imperative::VarBase>& grad_out : it.second) {
347 348
        PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr",
                                grad_op_desc->Type(), grad_out->Name());
M
minqiyang 已提交
349

350
        grad_outvars.emplace_back(grad_out->var_.get());
M
minqiyang 已提交
351
      }
X
Xin Pan 已提交
352
    }
353 354 355 356 357 358 359

    framework::RuntimeContext ctx(grad_invars_map, grad_outvars_map);
    framework::Scope scope;
    PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
    p.op.RuntimeInferShape(scope, place_, ctx);
    p.func(
        framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx, nullptr));
X
Xin Pan 已提交
360
  }
X
Xin Pan 已提交
361

C
chengduo 已提交
362
  platform::RecordEvent record_event("merge_grads");
363
  // Add tmp grad outputs to original grad vars
X
Xin Pan 已提交
364
  for (size_t k = 0; k < grad_output_vars_.size(); ++k) {
M
minqiyang 已提交
365
    for (const auto& it : grad_output_vars_[k]) {
366
      auto& outputs = tmp_grad_outputs[k][it.first];
M
minqiyang 已提交
367
      const auto& origin_outputs = it.second;
X
Xin Pan 已提交
368 369 370
      PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());

      for (size_t i = 0; i < outputs.size(); ++i) {
371 372
        // track outputs used by sum
        if (bck_stratedy.sorted_sum_gradient_) {
373
          if (bck_map->find(origin_outputs[i].get()) != bck_map->end()) {
374
            VLOG(10) << "add sub grad to " << origin_outputs[i]->Name();
375
            bck_map->at(origin_outputs[i].get())
376
                .second.emplace_back(
377 378
                    std::pair<int, std::shared_ptr<imperative::VarBase>>(
                        this->trace_id_, std::move(outputs[i])));
379
          } else {
380
            VLOG(10) << "insert new map for " << origin_outputs[i]->Name();
381 382 383 384 385 386
            std::pair<platform::Place,
                      std::vector<
                          std::pair<int, std::shared_ptr<imperative::VarBase>>>>
                tmp(place_,
                    {std::make_pair(this->trace_id_, std::move(outputs[i]))});
            bck_map->insert(std::make_pair(origin_outputs[i].get(), tmp));
387 388
          }

389 390 391 392
          PADDLE_ENFORCE(
              grad_ref->find(origin_outputs[i].get()) != grad_ref->end(),
              "Can't find  %s in grad_reference count map",
              origin_outputs[i]->Name());
393
          PADDLE_ENFORCE(grad_ref->at(origin_outputs[i].get()).first >= 1,
394
                         "Backward error when calculate grad reference");
395
          if (grad_ref->at(origin_outputs[i].get()).first > 1) {
396
            VLOG(10) << "remove ref for " << origin_outputs[i]->Name();
397
            grad_ref->at(origin_outputs[i].get()).first--;
398
          } else {
399
            VLOG(10) << "Add grad for: " << origin_outputs[i]->Name();
400 401
            AddGradBySort(bck_map, origin_outputs[i], grad_ref);
            grad_ref->at(origin_outputs[i].get()).first--;
402 403
          }
        } else {
404 405 406
          VLOG(10) << "AddTo Called with orig_grad is: "
                   << origin_outputs[i]->name_ << " Grad to be added is "
                   << outputs[i]->name_;
407
          AddTo(outputs[i], origin_outputs[i], place_, grad_ref);
408
          outputs[i].reset();
409
        }
X
Xin Pan 已提交
410
      }
411 412
    }
  }
X
Xin Pan 已提交
413

414
  return grad_output_vars_;
415 416
}

417
void OpBase::InvokeBackwardHooks() {
M
minqiyang 已提交
418
  VLOG(3) << "call backward hooks, hooks num: " << backward_hooks_.size();
419 420 421 422 423 424 425

  // call backward hooks
  for (py::object& callable : backward_hooks_) {
    callable(this);
  }
}

Y
Yan Xu 已提交
426
void OpBase::RegisterBackwardHooks(const py::object& callable) {
M
minqiyang 已提交
427
  VLOG(3) << "Register backward hooks " << trace_id_;
428 429

  // TODO(minqiyang): check the callable format
Y
Yan Xu 已提交
430
  backward_hooks_.push_back(callable);
431 432
}

433
void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) {
434
  if (!pre_op_) return;
C
chengduo 已提交
435
  platform::RecordEvent record_event("Imperative Backward");
X
Xin Pan 已提交
436
  VLOG(3) << "start backward";
437
  grads_->InitBuffer();
M
minqiyang 已提交
438
  auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
M
minqiyang 已提交
439 440 441 442
  operators::math::set_constant(
      *(platform::DeviceContextPool::Instance().Get(
          var_->GetMutable<framework::LoDTensor>()->place())),
      grads_t, 1.0);
X
Xin Pan 已提交
443

444
  Autograd().RunBackward(this, bck_stratedy);
445 446 447 448
}

}  // namespace imperative
}  // namespace paddle