// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/layer.h"

#include <algorithm>
#include <deque>
#include <limits>
#include <map>
#include <random>
#include <unordered_set>
#include <utility>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace imperative {

using framework::Variable;

namespace detail {

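// Accumulates the source buffer into the destination buffer (y += x) with an
// AXPY call, dispatched on the place where the gradient tensors live.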
template <typename T>
class TensorAddToFunctor : public boost::static_visitor<> {
 public:
  TensorAddToFunctor(int64_t numel, const T* x, T* y)
      : numel_(numel), x_(x), y_(y) {}

  void operator()(const platform::CPUPlace& place) {
    platform::CPUDeviceContext* ctx = dynamic_cast<platform::CPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CPUDeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }

#ifdef PADDLE_WITH_CUDA
  void operator()(const platform::CUDAPlace& place) {
    platform::CUDADeviceContext* ctx =
        dynamic_cast<platform::CUDADeviceContext*>(
            platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CUDADeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }
#else
  void operator()(const platform::CUDAPlace& place) {
    PADDLE_THROW("Gradient merge is not supported on place %s", place);
  }
#endif

  // There is no BLAS support for CUDAPinnedPlace, so gradient merge is not
  // supported there.
  void operator()(const platform::CUDAPinnedPlace& place) {
    PADDLE_THROW("Gradient merge is not supported on place %s", place);
  }

 private:
  int64_t numel_;
  const T* x_;
  T* y_;
};

}  // namespace detail

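// Adds the gradient held in `src` into `dst` in place. Both variables must
// hold float LoDTensors with the same number of elements.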
void AddTo(Variable* src, Variable* dst, platform::Place place) {
  framework::Tensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
  framework::Tensor* src_tensor = src->GetMutable<framework::LoDTensor>();

  // FIXME(minqiyang): the loss_grad op passes an empty (numel == 0) grad for
  // the label, so skip it here as a temporary workaround.
  if (src_tensor->numel() == 0) {
    return;
  }

  PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
                 "dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
                 src_tensor->numel());

  detail::TensorAddToFunctor<float> func(
      src_tensor->numel(), src_tensor->data<float>(),
      dst_tensor->mutable_data<float>(place));
  boost::apply_visitor(func, place);
}

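// Fills the gradient tensor of `vb` with zeros on the given place.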
void ZeroGrads(VarBase* vb, const platform::Place& place) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);
  auto grad_t = vb->var_->GetMutable<framework::LoDTensor>();
  operators::math::set_constant(*dev_ctx, grad_t, 0.0);
}

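// Merges all pending sub-gradients recorded for `target` in `bck_map` into
// its gradient variable, in descending trace-id order, and releases the
// temporary gradient holders afterwards.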
void AddGradBySort(BackwardSumMap* bck_map, VarBase* target) {
  PADDLE_ENFORCE(bck_map->find(target) != bck_map->end(),
                 "Can't find %s in backward grad map", target->Name());
  std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>& current =
      bck_map->at(target);
  std::sort(
      current.second.begin(), current.second.end(),
      [](const std::pair<int, VarBase*>& a, const std::pair<int, VarBase*>& b) {
        return a.first > b.first;
      });
  for (auto& var_pair : current.second) {
    Variable* origin_grad = target->var_.get();
    Variable* grad_to_add = var_pair.second->var_.get();
    VLOG(10) << "add origin_grad: " << target->Name();
    VLOG(10) << "added grad: " << var_pair.second->Name()
             << " trace id is: " << var_pair.first;
    AddTo(grad_to_add, origin_grad, current.first);
    delete var_pair.second;
    var_pair.second = nullptr;
  }
}

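// Walks the traced graph backwards from a variable, running each op's grad
// kernels once every op that depends on it has been processed.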
class Autograd {
 public:
  Autograd() {}

  void RunBackward(VarBase* var, const detail::BackwardStrategy& bck_stratedy) {
    if (var->IsStopGradient()) {
      return;
    }
    VLOG(2) << "start autograd";
    BackwardSumMap bck_map;
    GradientRef grad_ref;
    std::deque<OpBase*> ready;
    ready.push_back(var->PreOp());

    std::map<OpBase*, int> dep_counts =
        ComputeDepCounts(var->PreOp(), bck_stratedy, &grad_ref);

    while (!ready.empty()) {
      OpBase* ready_op = ready.front();
      ready.pop_front();
      std::map<std::string, std::vector<VarBase*>> input_grads =
          ready_op->ApplyGrad(&bck_map, &grad_ref, bck_stratedy);

      for (auto it = input_grads.rbegin(); it != input_grads.rend(); ++it) {
        const std::vector<VarBase*>& ingrads = it->second;
        for (size_t i = 0; i < ingrads.size(); ++i) {
          if (!ingrads[i]) continue;
          auto p = ready_op->input_vars_[it->first][i];

          if (p->IsStopGradient()) continue;
          OpBase* pre_op = ready_op->pre_ops_[it->first][i];
          if (!pre_op) continue;

          dep_counts[pre_op] -= 1;
          PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
          bool pre_op_ready = dep_counts[pre_op] == 0;
          if (pre_op_ready) {
            ready.push_back(pre_op);
          }
        }
      }

      ready_op->InvokeBackwardHooks();
    }
  }

 private:
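  // BFS over pre_ops_ that counts, for every op, how many of its successor
  // ops still have to run before it becomes ready. When the sorted-sum
  // gradient strategy is enabled, it also counts how often each grad output
  // variable is referenced.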
  std::map<OpBase*, int> ComputeDepCounts(
      OpBase* op, const detail::BackwardStrategy& bck_stratedy,
      GradientRef* grad_ref) {
    if (bck_stratedy.sorted_sum_gradient_) {
      PADDLE_ENFORCE_NOT_NULL(grad_ref,
                              "grad_ref should not be null when "
                              "using sorted grad backward strategy");
    }
    std::map<OpBase*, int> ret;

    std::deque<OpBase*> queue;
    queue.push_back(op);
    std::unordered_set<OpBase*> visited;
    visited.insert(op);
    while (!queue.empty()) {
      OpBase* candidate = queue.front();
      queue.pop_front();
      if (bck_stratedy.sorted_sum_gradient_) {
        for (const auto& map : candidate->grad_output_vars_) {
          for (const auto& it : map) {
            for (const auto& vb : it.second) {
              ++(*grad_ref)[vb];
            }
          }
        }
      }
      for (auto it : candidate->pre_ops_) {
        for (OpBase* pre_op : it.second) {
          if (!pre_op) continue;
          VLOG(9) << "op dep " << candidate->Type() << " trace id "
                  << candidate->trace_id_ << " <---- " << it.first << " <---- "
                  << pre_op->Type() << " trace id " << pre_op->trace_id_;
          if (visited.find(pre_op) == visited.end()) {
            visited.insert(pre_op);
            queue.push_back(pre_op);
          }
          ret[pre_op] += 1;
        }
      }
    }
    return ret;
  }
};

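// Copies this variable's tensor (and LoD) to `dst_place` and wraps the copy
// in a new VarBase. When `blocking` is true, waits until the copy finishes.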
std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
                                             const bool blocking) const {
  PADDLE_ENFORCE(var_->IsInitialized(),
                 "Variable must be initialized when getting numpy tensor");

  // TODO(minqiyang): change this after move unique_name generator to CXX
  const framework::LoDTensor& self_tensor = var_->Get<framework::LoDTensor>();
  std::unique_ptr<VarBase> new_var(new VarBase(
      "Itmp", self_tensor.type(), self_tensor.dims(), dst_place, true, false));
  framework::LoDTensor* tensor =
      new_var->var_->GetMutable<framework::LoDTensor>();
  tensor->set_lod(var_->Get<framework::LoDTensor>().lod());

  const auto& src_tensor = var_->Get<framework::LoDTensor>();
  framework::TensorCopy(src_tensor, dst_place, tensor);
  if (blocking) {
    platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
    auto src_place = src_tensor.place();
    if (!(src_place == dst_place)) {
      platform::DeviceContextPool::Instance().Get(src_place)->Wait();
    }
  }

  if (platform::is_gpu_place(dst_place)) {
    VLOG(3) << "copy tensor " << Name() << " from gpu";
  }

  return new_var;
}

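// Returns the LoDTensor that holds this variable's gradient.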
framework::LoDTensor& VarBase::GradValue() {
  VLOG(3) << "get var grad " << Name();
  PADDLE_ENFORCE_NOT_NULL(
      grads_, "Could not get grad value from a variable that has no grad");
  return *(grads_->var_->GetMutable<framework::LoDTensor>());
}

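// Runs every grad op recorded for this op: temporary gradient variables are
// allocated for the grad outputs, the grad kernels are executed, and the
// temporaries are then merged into the original gradient variables, either
// directly or sorted by trace id, depending on the backward strategy.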
std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
    BackwardSumMap* bck_map, GradientRef* grad_ref,
    const detail::BackwardStrategy& bck_stratedy) {
  PADDLE_ENFORCE(!grad_op_descs_.empty(), "%s has no backward implementation",
                 Type());
  VLOG(3) << "apply op grad: " << Type();
  std::vector<VarBasePtrMap> tmp_grad_outputs;
  const size_t grad_op_count = grad_op_descs_.size();

  tmp_grad_outputs.resize(grad_op_count);
  for (size_t k = 0; k < grad_op_count; ++k) {
    framework::OpDesc* grad_op_desc = grad_op_descs_[k];
    platform::RecordEvent record_event(grad_op_desc->Type());
    auto& grad_output_variable_map = grad_output_vars_[k];
    VLOG(3) << "apply grad op " << grad_op_desc->Type();

    // Allocate tmp grad output variable
    for (const auto& it : grad_output_variable_map) {
      auto& outputs = tmp_grad_outputs[k][it.first];
      outputs.reserve(it.second.size());
      for (VarBase* origin_grad_var_base : it.second) {
        if (!origin_grad_var_base->IsInitialize()) {
          origin_grad_var_base->InitBuffer();
          ZeroGrads(origin_grad_var_base, place_);
        }
        // Allocate a new variable
        VarBase* tmp_grad_var_base = new VarBase(
            string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
            origin_grad_var_base->DataType(), origin_grad_var_base->Dims(),
            place_, true, false);
        outputs.emplace_back(tmp_grad_var_base);
      }
    }

    // No need to do compile time infer shape here.
    // grad_op_desc_->InferShape(*block_);
    // grad_op_desc->InferVarType(block_);

    std::unique_ptr<framework::OperatorBase> opbase =
        framework::OpRegistry::CreateOp(*grad_op_desc);

    auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
    if (info.infer_var_type_) {
      RuntimeInferVarTypeContext infer_var_type_ctx(
          &grad_input_vars_[k], &tmp_grad_outputs[k], &attrs_);
      info.infer_var_type_(&infer_var_type_ctx);
    }

    framework::OperatorWithKernel* op_kernel =
        dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
    PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");

    // Run grad op
    framework::VariableValueMap grad_invars_map;
    framework::VariableValueMap grad_outvars_map;

    for (const auto& it : grad_input_vars_[k]) {
      auto& grad_invars = grad_invars_map[it.first];
      grad_invars.reserve(it.second.size());
      for (VarBase* grad_inp : it.second) {
        PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
                                grad_op_desc->Type(), grad_inp->Name());
        if (!grad_inp->IsInitialize()) {
          grad_inp->InitBuffer();
          ZeroGrads(grad_inp, place_);
        }
        const VarBase* const_grad_inp = grad_inp;
        grad_invars.emplace_back(const_grad_inp->var_.get());
      }
    }

    for (const auto& it : tmp_grad_outputs[k]) {
      auto& grad_outvars = grad_outvars_map[it.first];
      grad_outvars.reserve(it.second.size());
      for (VarBase* grad_out : it.second) {
        PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr",
                                grad_op_desc->Type(), grad_out->Name());

        grad_outvars.emplace_back(grad_out->var_.get());
      }
    }

    framework::RuntimeContext ctx(grad_invars_map, grad_outvars_map);
    framework::Scope scope;
    PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
    p.op.RuntimeInferShape(scope, place_, ctx);
    p.func(
        framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx, nullptr));
  }

  platform::RecordEvent record_event("merge_grads");
  // Add tmp grad outputs to original grad vars
  for (size_t k = 0; k < grad_output_vars_.size(); ++k) {
    for (const auto& it : grad_output_vars_[k]) {
      auto& outputs = tmp_grad_outputs[k][it.first];
      const auto& origin_outputs = it.second;
      PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());

      for (size_t i = 0; i < outputs.size(); ++i) {
        // track outputs used by sum
        if (bck_stratedy.sorted_sum_gradient_) {
#ifndef PADDLE_WITH_CUDA
          VLOG(10) << "origin_outputs is : " << origin_outputs[i]->Name()
                   << " ";
          VLOG(10) << origin_outputs[i]
                          ->var_->GetMutable<framework::LoDTensor>()
                          ->data<float>()[0];
          VLOG(10) << "outputs is : " << outputs[i]->Name() << " ";
          VLOG(10) << outputs[i]
                          ->var_->GetMutable<framework::LoDTensor>()
                          ->data<float>()[0];
#endif
          if (bck_map->find(origin_outputs[i]) != bck_map->end()) {
            VLOG(10) << "add sub grad to " << origin_outputs[i]->Name();
            bck_map->at(origin_outputs[i])
                .second.emplace_back(
                    std::pair<int, VarBase*>(this->trace_id_, outputs[i]));
          } else {
            VLOG(10) << "insert new map for " << origin_outputs[i]->Name();
            std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>
                tmp(place_, {std::make_pair(this->trace_id_, outputs[i])});
            bck_map->insert(std::make_pair(origin_outputs[i], tmp));
          }

          PADDLE_ENFORCE(grad_ref->find(origin_outputs[i]) != grad_ref->end(),
                         "Can't find %s in grad reference count map",
                         origin_outputs[i]->Name());
          PADDLE_ENFORCE(grad_ref->at(origin_outputs[i]) >= 1,
                         "Backward error when calculating grad reference");
          if (grad_ref->at(origin_outputs[i]) > 1) {
            VLOG(10) << "remove ref for " << origin_outputs[i]->Name();
            grad_ref->at(origin_outputs[i])--;
          } else {
            VLOG(10) << "Add grad for: " << origin_outputs[i]->Name();
            AddGradBySort(bck_map, origin_outputs[i]);
            grad_ref->at(origin_outputs[i])--;
          }
        } else {
          framework::Variable* grad = outputs[i]->var_.get();
          framework::Variable* orig_grad = origin_outputs[i]->var_.get();
          VLOG(10) << "AddTo called: orig_grad is "
                   << origin_outputs[i]->name_ << ", grad to be added is "
                   << outputs[i]->name_;
          AddTo(grad, orig_grad, place_);
          delete outputs[i];
        }
      }
    }
  }

  return input_vars_;
}

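// Calls every registered Python hook after this op's gradients are applied.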
void OpBase::InvokeBackwardHooks() {
  VLOG(3) << "call backward hooks, hooks num: " << backward_hooks_.size();

  // call backward hooks
  for (py::object& callable : backward_hooks_) {
    callable(this);
  }
}

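// Registers a Python callable to be invoked after this op's backward pass.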
void OpBase::RegisterBackwardHooks(const py::object& callable) {
  VLOG(3) << "Register backward hooks " << trace_id_;

  // TODO(minqiyang): check the callable format
  backward_hooks_.push_back(callable);
}

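// Seeds this variable's gradient with ones and runs autograd over the traced
// graph, starting from the op that produced this variable.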
void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) {
  if (!pre_op_) return;
  platform::RecordEvent record_event("Imperative Backward");
  VLOG(3) << "start backward";
  grads_->InitBuffer();
  auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
  operators::math::set_constant(
      *(platform::DeviceContextPool::Instance().Get(
          var_->GetMutable<framework::LoDTensor>()->place())),
      grads_t, 1.0);

  PADDLE_ENFORCE(
      grads_ ==
      pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
  Autograd().RunBackward(this, bck_stratedy);
}

}  // namespace imperative
}  // namespace paddle