// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/layer.h"

#include <algorithm>
#include <deque>
#include <limits>
#include <map>
#include <random>
#include <unordered_set>
#include <utility>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace imperative {

using framework::Variable;

namespace detail {

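// Visitor that accumulates one gradient buffer into another (y += x) with a
// BLAS AXPY call, dispatched on the place the gradient lives on via
// boost::apply_visitor.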
template <typename T>
class TensorAddToFunctor : public boost::static_visitor<> {
 public:
  TensorAddToFunctor(int64_t numel, const T* x, T* y)
      : numel_(numel), x_(x), y_(y) {}

  void operator()(const platform::CPUPlace& place) {
    platform::CPUDeviceContext* ctx = dynamic_cast<platform::CPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CPUDeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }

#ifdef PADDLE_WITH_CUDA
  void operator()(const platform::CUDAPlace& place) {
    platform::CUDADeviceContext* ctx =
        dynamic_cast<platform::CUDADeviceContext*>(
            platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CUDADeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }
#else
  void operator()(const platform::CUDAPlace& place) {
    PADDLE_THROW("Gradient merge is not supported on place %s", place);
  }
#endif

  // There is no BLAS on CUDAPinnedPlace, so gradient merge is not supported
  // there either.
  void operator()(const platform::CUDAPinnedPlace& place) {
    PADDLE_THROW("Gradient merge is not supported on place %s", place);
  }

 private:
  int64_t numel_;
  const T* x_;
  T* y_;
};

}  // namespace detail

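// Accumulates the gradient held by `src` into `dst` on `place`. If `dst` is
// not initialized yet, the source variable is simply moved into it; otherwise
// the two tensors are summed element-wise (float only, see the functor above).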
void AddTo(std::shared_ptr<VarBase> src, std::shared_ptr<VarBase> dst,
           platform::Place place) {
  if (!dst->IsInitialize()) {
    VLOG(2) << "im here1";
    PADDLE_ENFORCE(src->IsInitialize(), "Using uninitialized VarBase");
    dst->var_ = std::move(src->var_);
    dst->SetInitialize(true);
    return;
  } else {
    framework::Tensor* dst_tensor =
        dst->var_->GetMutable<framework::LoDTensor>();
    framework::Tensor* src_tensor =
        src->var_->GetMutable<framework::LoDTensor>();

    // FIXME(minqiyang): the loss_grad op passes a zero-sized grad for the
    // label; skip it here as a temporary workaround.
    if (src_tensor->numel() == 0) {
      return;
    }

    PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
                   "dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
                   src_tensor->numel());

    detail::TensorAddToFunctor<float> func(
        src_tensor->numel(), src_tensor->data<float>(),
        dst_tensor->mutable_data<float>(place));
    boost::apply_visitor(func, place);
  }
}

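// Fills the gradient tensor of `vb` with zeros on the given place.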
void ZeroGrads(const std::shared_ptr<imperative::VarBase> vb,
               const platform::Place& place) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);
  auto grad_t = vb->var_->GetMutable<framework::LoDTensor>();
  operators::math::set_constant(*dev_ctx, grad_t, 0.0);
}

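// Merges all pending sub-gradients recorded for `target` in `bck_map`,
// sorted by trace id in descending order, into `target` one by one.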
void AddGradBySort(BackwardSumMap* bck_map,
                   std::shared_ptr<imperative::VarBase> target) {
  PADDLE_ENFORCE(bck_map->find(target.get()) != bck_map->end(),
                 "Can't find %s in backward grad map", target->Name());
  std::pair<platform::Place,
            std::vector<std::pair<int, std::shared_ptr<imperative::VarBase>>>>&
      current = bck_map->at(target.get());
  std::sort(current.second.begin(), current.second.end(),
            [](const std::pair<int, std::shared_ptr<imperative::VarBase>>& a,
               const std::pair<int, std::shared_ptr<imperative::VarBase>>& b) {
              return a.first > b.first;
            });
  for (auto& var_pair : current.second) {
    VLOG(10) << "add origin_grad: " << target->Name();
    VLOG(10) << "added grad: " << var_pair.second->Name()
             << " trace id is: " << var_pair.first;
    AddTo(var_pair.second, target, current.first);
    var_pair.second.reset();
  }
}

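// Drives the backward pass: walks the recorded op trace from a variable's
// PreOp, using the dependency counts computed by ComputeDepCounts to decide
// when an op is ready, so grad ops run in topological order.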
class Autograd {
 public:
  Autograd() {}

  void RunBackward(VarBase* var, const detail::BackwardStrategy& bck_stratedy) {
    if (var->IsStopGradient()) {
      return;
    }
    VLOG(2) << "start autograd";
    BackwardSumMap bck_map;
    GradientRef grad_ref;
    std::deque<OpBase*> ready;
    ready.push_back(var->PreOp());

    std::map<OpBase*, int> dep_counts =
        ComputeDepCounts(var->PreOp(), bck_stratedy, &grad_ref);

    while (!ready.empty()) {
      OpBase* ready_op = ready.front();
      ready.pop_front();
      std::vector<VarBasePtrMap> grads_outputs =
          ready_op->ApplyGrad(&bck_map, &grad_ref, bck_stratedy);

      for (const auto& map : grads_outputs) {
        for (auto it = map.rbegin(); it != map.rend(); ++it) {
          const std::vector<std::shared_ptr<VarBase>>& grad_outs = it->second;
          for (size_t i = 0; i < grad_outs.size(); ++i) {
            if (!grad_outs[i] || grad_outs[i]->IsStopGradient()) continue;
            OpBase* pre_op = grad_outs[i]->PreOp();
            if (!pre_op) continue;
            dep_counts[pre_op] -= 1;
            PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
            bool pre_op_ready = dep_counts[pre_op] == 0;
            if (pre_op_ready) {
              ready.push_back(pre_op);
            }
          }
        }
      }

      ready_op->InvokeBackwardHooks();
    }
  }

 private:
  std::map<OpBase*, int> ComputeDepCounts(
      OpBase* op, const detail::BackwardStrategy& bck_stratedy,
      GradientRef* grad_ref) {
    if (bck_stratedy.sorted_sum_gradient_) {
      PADDLE_ENFORCE_NOT_NULL(grad_ref,
                              "grad_ref should not be null when "
                              "using sorted grad backward strategy");
    }
    std::map<OpBase*, int> ret;

    std::deque<OpBase*> queue;
    queue.push_back(op);
    std::unordered_set<OpBase*> visited;
    visited.insert(op);
    while (!queue.empty()) {
      OpBase* candidate = queue.front();
      queue.pop_front();
      if (bck_stratedy.sorted_sum_gradient_) {
        for (const auto& map : candidate->grad_output_vars_) {
          for (const auto& it : map) {
            for (const auto& vb : it.second) {
              ++(*grad_ref)[vb.get()];
            }
          }
        }
      }
      for (auto it : candidate->pre_ops_) {
        for (OpBase* pre_op : it.second) {
          if (!pre_op) continue;
          VLOG(2) << "op dep " << candidate->Type() << " trace id "
                  << candidate->trace_id_ << " <---- " << it.first << " <---- "
                  << pre_op->Type() << " trace id " << pre_op->trace_id_;
          if (visited.find(pre_op) == visited.end()) {
            visited.insert(pre_op);
            queue.push_back(pre_op);
          }
          ret[pre_op] += 1;
        }
      }
    }
    return ret;
  }
};

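// Copies this variable's tensor to `dst_place`. When `blocking` is true, the
// destination (and, for cross-place copies, the source) device context is
// waited on so the copy has finished before returning.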
std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
                                             const bool blocking) const {
  PADDLE_ENFORCE(var_->IsInitialized(),
                 "Variable must be initialized when getting numpy tensor");

  // TODO(minqiyang): change this after moving the unique_name generator to C++
  const framework::LoDTensor& self_tensor = var_->Get<framework::LoDTensor>();
  std::unique_ptr<VarBase> new_var(new VarBase(
      "Itmp", self_tensor.type(), self_tensor.dims(), dst_place, true, false));
  framework::LoDTensor* tensor =
      new_var->var_->GetMutable<framework::LoDTensor>();
  tensor->set_lod(var_->Get<framework::LoDTensor>().lod());

  const auto& src_tensor = var_->Get<framework::LoDTensor>();
  framework::TensorCopy(src_tensor, dst_place, tensor);
  if (blocking) {
    platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
    auto src_place = src_tensor.place();
    if (!(src_place == dst_place)) {
      platform::DeviceContextPool::Instance().Get(src_place)->Wait();
    }
  }

  if (platform::is_gpu_place(dst_place)) {
    VLOG(3) << "copy tensor " << Name() << " from gpu";
  }

  return new_var;
}

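// Returns a mutable reference to the gradient tensor; the variable must have
// a gradient VarBase attached.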
framework::LoDTensor& VarBase::GradValue() {
  VLOG(3) << "get var grad " << Name();
  PADDLE_ENFORCE_NOT_NULL(grads_,
                          "Could not get grad value from no grad variable");
  return *(grads_->var_->GetMutable<framework::LoDTensor>());
}

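// Runs every registered grad op desc of this op: allocates temporary output
// gradient variables, infers variable types, executes the grad kernels, and
// finally merges the temporary results into the original grad variables,
// either immediately or via AddGradBySort when sorted_sum_gradient_ is set.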
std::vector<VarBasePtrMap> OpBase::ApplyGrad(
    BackwardSumMap* bck_map, GradientRef* grad_ref,
    const detail::BackwardStrategy& bck_stratedy) {
  PADDLE_ENFORCE(!grad_op_descs_.empty(), "%s has no backward implementation",
                 Type());
  VLOG(3) << "apply op grad: " << Type();
  std::vector<VarBasePtrMap> tmp_grad_outputs;
  const size_t grad_op_count = grad_op_descs_.size();

  tmp_grad_outputs.resize(grad_op_count);
  for (size_t k = 0; k < grad_op_count; ++k) {
    framework::OpDesc* grad_op_desc = grad_op_descs_[k];
    platform::RecordEvent record_event(grad_op_desc->Type());
    auto& grad_output_variable_map = grad_output_vars_[k];
    VLOG(3) << "apply grad op " << grad_op_desc->Type();

    // Allocate tmp grad output variable
    for (const auto& it : grad_output_variable_map) {
      auto& outputs = tmp_grad_outputs[k][it.first];
      outputs.reserve(it.second.size());
      for (const std::shared_ptr<imperative::VarBase>& origin_grad_var_base :
           it.second) {
        // Allocate a new variable
        std::shared_ptr<imperative::VarBase> tmp_grad_var_base(new VarBase(
            string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
            origin_grad_var_base->DataType(), origin_grad_var_base->Dims(),
            place_, true, false));
        outputs.emplace_back(std::move(tmp_grad_var_base));
      }
    }

    // No need to do compile time infer shape here.
    // grad_op_desc_->InferShape(*block_);
    // grad_op_desc->InferVarType(block_);

    std::unique_ptr<framework::OperatorBase> opbase =
        framework::OpRegistry::CreateOp(*grad_op_desc);

    auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
    if (info.infer_var_type_) {
      RuntimeInferVarTypeContext infer_var_type_ctx(
          &grad_input_vars_[k], &tmp_grad_outputs[k], &(opbase->Attrs()));
      info.infer_var_type_(&infer_var_type_ctx);
    }

    framework::OperatorWithKernel* op_kernel =
        dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
    PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");

    // Run grad op
    framework::VariableValueMap grad_invars_map;
    framework::VariableValueMap grad_outvars_map;

    for (const auto& it : grad_input_vars_[k]) {
      auto& grad_invars = grad_invars_map[it.first];
      grad_invars.reserve(it.second.size());
      for (const std::shared_ptr<imperative::VarBase>& grad_inp : it.second) {
        PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
                                grad_op_desc->Type(), grad_inp->Name());
        if (!grad_inp->IsInitialize()) {
          grad_inp->InitBuffer();
          ZeroGrads(grad_inp, place_);
        }
        const std::shared_ptr<imperative::VarBase>& const_grad_inp = grad_inp;
        grad_invars.emplace_back(const_grad_inp->var_.get());
      }
    }

    for (const auto& it : tmp_grad_outputs[k]) {
      auto& grad_outvars = grad_outvars_map[it.first];
      grad_outvars.reserve(it.second.size());
      for (const std::shared_ptr<imperative::VarBase>& grad_out : it.second) {
        PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr",
                                grad_op_desc->Type(), grad_out->Name());

        grad_outvars.emplace_back(grad_out->var_.get());
      }
    }

    framework::RuntimeContext ctx(grad_invars_map, grad_outvars_map);
    framework::Scope scope;
    PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
    p.op.RuntimeInferShape(scope, place_, ctx);
    p.func(
        framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx, nullptr));
  }

  platform::RecordEvent record_event("merge_grads");
  // Add tmp grad outputs to original grad vars
  for (size_t k = 0; k < grad_output_vars_.size(); ++k) {
    for (const auto& it : grad_output_vars_[k]) {
      auto& outputs = tmp_grad_outputs[k][it.first];
      const auto& origin_outputs = it.second;
      PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());

      for (size_t i = 0; i < outputs.size(); ++i) {
        // track outputs used by sum
        if (bck_stratedy.sorted_sum_gradient_) {
          if (bck_map->find(origin_outputs[i].get()) != bck_map->end()) {
            VLOG(10) << "add sub grad to " << origin_outputs[i]->Name();
            bck_map->at(origin_outputs[i].get())
                .second.emplace_back(
                    std::pair<int, std::shared_ptr<imperative::VarBase>>(
                        this->trace_id_, std::move(outputs[i])));
          } else {
            VLOG(10) << "insert new map for " << origin_outputs[i]->Name();
            std::pair<platform::Place,
                      std::vector<
                          std::pair<int, std::shared_ptr<imperative::VarBase>>>>
                tmp(place_,
                    {std::make_pair(this->trace_id_, std::move(outputs[i]))});
            bck_map->insert(std::make_pair(origin_outputs[i].get(), tmp));
          }

          PADDLE_ENFORCE(
              grad_ref->find(origin_outputs[i].get()) != grad_ref->end(),
              "Can't find  %s in grad_reference count map",
              origin_outputs[i]->Name());
          PADDLE_ENFORCE(grad_ref->at(origin_outputs[i].get()) >= 1,
                         "Backward error when calculate grad reference");
          if (grad_ref->at(origin_outputs[i].get()) > 1) {
            VLOG(10) << "remove ref for " << origin_outputs[i]->Name();
            grad_ref->at(origin_outputs[i].get())--;
          } else {
            VLOG(10) << "Add grad for: " << origin_outputs[i]->Name();
            AddGradBySort(bck_map, origin_outputs[i]);
            grad_ref->at(origin_outputs[i].get())--;
          }
        } else {
          VLOG(10) << "AddTo Called with orig_grad is: "
                   << origin_outputs[i]->name_ << " Grad to be added is "
                   << outputs[i]->name_;
          AddTo(outputs[i], origin_outputs[i], place_);
          outputs[i].reset();
        }
      }
    }
  }

  return grad_output_vars_;
}

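// Invokes every Python callable registered via RegisterBackwardHooks,
// passing this op as the argument.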
void OpBase::InvokeBackwardHooks() {
  VLOG(3) << "call backward hooks, hooks num: " << backward_hooks_.size();

  // call backward hooks
  for (py::object& callable : backward_hooks_) {
    callable(this);
  }
}

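// Registers a Python callable that is invoked after this op's gradient has
// been applied during the backward pass.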
void OpBase::RegisterBackwardHooks(const py::object& callable) {
  VLOG(3) << "Register backward hooks " << trace_id_;

  // TODO(minqiyang): check the callable format
  backward_hooks_.push_back(callable);
}

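// Entry point of the imperative backward pass: seeds this variable's
// gradient with ones and hands control to Autograd::RunBackward.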
void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) {
  if (!pre_op_) return;
  platform::RecordEvent record_event("Imperative Backward");
  VLOG(3) << "start backward";
  grads_->InitBuffer();
  auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
  operators::math::set_constant(
      *(platform::DeviceContextPool::Instance().Get(
          var_->GetMutable<framework::LoDTensor>()->place())),
      grads_t, 1.0);

  Autograd().RunBackward(this, bck_stratedy);
}

}  // namespace imperative
}  // namespace paddle