// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/layer.h"

#include <algorithm>
#include <deque>
#include <limits>
#include <map>
#include <random>
#include <unordered_set>
#include <utility>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace imperative {

using framework::Variable;

namespace detail {

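// Visitor over platform::Place that accumulates y += x (element-wise) via a
// BLAS AXPY call on the device that owns the place; places without BLAS
// support (e.g. CUDAPinnedPlace) throw.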
template <typename T>
class TensorAddToFunctor : public boost::static_visitor<> {
 public:
  TensorAddToFunctor(int64_t numel, const T* x, T* y)
      : numel_(numel), x_(x), y_(y) {}

  void operator()(const platform::CPUPlace& place) {
    platform::CPUDeviceContext* ctx = dynamic_cast<platform::CPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CPUDeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }

#ifdef PADDLE_WITH_CUDA
  void operator()(const platform::CUDAPlace& place) {
    platform::CUDADeviceContext* ctx =
        dynamic_cast<platform::CUDADeviceContext*>(
            platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CUDADeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }
#else
  void operator()(const platform::CUDAPlace& place) {
    PADDLE_THROW("Do NOT support gradient merge in place %s", place);
  }
#endif

  // there is NO blas in CUDAPinnedPlace
  void operator()(const platform::CUDAPinnedPlace& place) {
    PADDLE_THROW("Do NOT support gradient merge in place %s", place);
  }

 private:
  int64_t numel_;
  const T* x_;
  T* y_;
};

}  // namespace detail

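// Accumulates the gradient tensor held in `src` into `dst` on `place`.
// An empty source tensor (e.g. the zero label grad passed by loss_grad) is
// skipped.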
void AddTo(Variable* src, Variable* dst, platform::Place place) {
  framework::Tensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
  framework::Tensor* src_tensor = src->GetMutable<framework::LoDTensor>();

  // FIXME(minqiyang): loss_grad op will pass a zero grad of label
  // ugly fix for it
  if (src_tensor->numel() == 0) {
    return;
  }

  PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
                 "dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
                 src_tensor->numel());

  detail::TensorAddToFunctor<float> func(
      src_tensor->numel(), src_tensor->data<float>(),
      dst_tensor->mutable_data<float>(place));
  boost::apply_visitor(func, place);
}

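// Merges the pending sub-gradients recorded for `target` in `bck_map` into
// its gradient variable, ordered by descending trace id, freeing each
// temporary gradient after it has been added.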
void AddGradBySort(BackwardSumMap* bck_map, VarBase* target) {
  PADDLE_ENFORCE(bck_map->find(target) != bck_map->end(),
                 "Can't find %s in backward grad map", target->Name());
  std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>& current =
      bck_map->at(target);
  std::sort(
      current.second.begin(), current.second.end(),
      [](const std::pair<int, VarBase*>& a, const std::pair<int, VarBase*>& b) {
        return a.first > b.first;
      });
  for (auto& var_pair : current.second) {
    Variable* origin_grad = target->var_.get();
    Variable* grad_to_add = var_pair.second->var_.get();
    VLOG(2) << "add origin_grad: " << target->Name();
    VLOG(2) << "added grad: " << var_pair.second->Name()
            << " trace id is: " << var_pair.first;
    AddTo(grad_to_add, origin_grad, current.first);
    delete var_pair.second;
    var_pair.second = nullptr;
  }
}

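// Walks the imperative trace backwards from a variable: an op becomes ready
// once every op that consumed its outputs has applied its gradient.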
class Autograd {
 public:
  Autograd() {}

  void RunBackward(VarBase* var, const detail::BackwardStrategy& bck_stratedy) {
    if (var->IsStopGradient()) {
      return;
    }
    VLOG(3) << "start autograd";
    BackwardSumMap bck_map;
    GradientRef grad_ref;
    std::deque<OpBase*> ready;
    ready.push_back(var->PreOp());

    std::map<OpBase*, int> dep_counts =
        ComputeDepCounts(var->PreOp(), bck_stratedy, &grad_ref);

    while (!ready.empty()) {
      OpBase* ready_op = ready.front();
      ready.pop_front();
      std::map<std::string, std::vector<VarBase*>> input_grads =
          ready_op->ApplyGrad(&bck_map, &grad_ref, bck_stratedy);

      for (auto it = input_grads.rbegin(); it != input_grads.rend(); ++it) {
        const std::vector<VarBase*>& ingrads = it->second;
        for (size_t i = 0; i < ingrads.size(); ++i) {
          if (!ingrads[i]) continue;
          auto p = ready_op->input_vars_[it->first][i];

          if (p->IsStopGradient()) continue;
          OpBase* pre_op = ready_op->pre_ops_[it->first][i];
          if (!pre_op) continue;

          dep_counts[pre_op] -= 1;
          PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
          bool pre_op_ready = dep_counts[pre_op] == 0;
          if (pre_op_ready) {
            ready.push_back(pre_op);
          }
        }
      }

      ready_op->InvokeBackwardHooks();
    }
  }

 private:
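  // BFS over pre_ops_ starting from `op`, counting for each reachable op how
  // many downstream ops must finish before its gradient can run; when the
  // sorted-sum strategy is enabled, also counts how many times each gradient
  // output variable will be referenced.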
  std::map<OpBase*, int> ComputeDepCounts(
      OpBase* op, const detail::BackwardStrategy& bck_stratedy,
      GradientRef* grad_ref) {
    if (bck_stratedy.sorted_sum_gradient_) {
      PADDLE_ENFORCE_NOT_NULL(grad_ref,
                              "grad_ref should not be null when "
                              "using sorted grad backward strategy");
    }
    std::map<OpBase*, int> ret;

    std::deque<OpBase*> queue;
    queue.push_back(op);
    std::unordered_set<OpBase*> visited;
    visited.insert(op);
    while (!queue.empty()) {
      OpBase* candidate = queue.front();
      queue.pop_front();
      if (bck_stratedy.sorted_sum_gradient_) {
        for (const auto& map : candidate->grad_output_vars_) {
          for (const auto& it : map) {
            for (const auto& vb : it.second) {
              ++(*grad_ref)[vb];
            }
          }
        }
      }
      for (auto it : candidate->pre_ops_) {
        for (OpBase* pre_op : it.second) {
          if (!pre_op) continue;
          VLOG(2) << "op dep " << candidate->Type() << " trace id "
                  << candidate->trace_id_ << " <---- " << it.first << " <---- "
                  << pre_op->Type() << " trace id " << pre_op->trace_id_;
          if (visited.find(pre_op) == visited.end()) {
            visited.insert(pre_op);
            queue.push_back(pre_op);
          }
          ret[pre_op] += 1;
        }
      }
    }
    return ret;
  }
};

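// Deep-copies this variable's tensor (and LoD) to `dst_place`; when
// `blocking` is true, waits on the destination device context (and, for
// cross-device copies, the source) before returning.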
std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
                                             const bool blocking) const {
  PADDLE_ENFORCE(var_->IsInitialized(),
                 "Variable must be initialized when getting numpy tensor");

  // TODO(minqiyang): change this after move unique_name generator to CXX
  const framework::LoDTensor& self_tensor = var_->Get<framework::LoDTensor>();
  std::unique_ptr<VarBase> new_var(new VarBase(
      "Itmp", self_tensor.type(), self_tensor.dims(), dst_place, true, false));
  framework::LoDTensor* tensor =
      new_var->var_->GetMutable<framework::LoDTensor>();
  tensor->set_lod(var_->Get<framework::LoDTensor>().lod());

  const auto& src_tensor = var_->Get<framework::LoDTensor>();
  framework::TensorCopy(src_tensor, dst_place, tensor);
  if (blocking) {
    platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
    auto src_place = src_tensor.place();
    if (!(src_place == dst_place)) {
      platform::DeviceContextPool::Instance().Get(src_place)->Wait();
    }
  }

  if (platform::is_gpu_place(dst_place)) {
    VLOG(3) << "copy tensor " << Name() << " from gpu";
  }

  return new_var;
}

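// Returns the LoDTensor that holds this variable's gradient.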
framework::LoDTensor& VarBase::GradValue() {
  VLOG(3) << "get var grad " << Name();
  PADDLE_ENFORCE_NOT_NULL(grads_,
                          "Could not get grad value from no grad variable");
  return *(grads_->var_->GetMutable<framework::LoDTensor>());
}

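// Runs this op's recorded grad op descs: allocates temporary output grad
// variables, infers var types, executes each grad kernel, and finally merges
// the temporaries into the original grad variables (sorted by trace id when
// the sorted-sum strategy is enabled).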
std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
    BackwardSumMap* bck_map, GradientRef* grad_ref,
    const detail::BackwardStrategy& bck_stratedy) {
  PADDLE_ENFORCE(!grad_op_descs_.empty(), "%s has no backward implementation",
                 Type());
  VLOG(3) << "apply op grad: " << Type();
  std::vector<VarBasePtrMap> tmp_grad_outputs;
  const size_t grad_op_count = grad_op_descs_.size();

  tmp_grad_outputs.resize(grad_op_count);
  for (size_t k = 0; k < grad_op_count; ++k) {
    framework::OpDesc* grad_op_desc = grad_op_descs_[k];
    platform::RecordEvent record_event(grad_op_desc->Type());
    auto& grad_output_variable_map = grad_output_vars_[k];
    VLOG(3) << "apply grad op " << grad_op_desc->Type();

    // Allocate tmp grad output variable
    for (const auto& it : grad_output_variable_map) {
      auto& outputs = tmp_grad_outputs[k][it.first];
      outputs.reserve(it.second.size());
      for (size_t i = 0; i < it.second.size(); ++i) {
        VarBase* origin_grad_var_base = it.second[i];

        // Allocate a new variable
        VarBase* tmp_grad_var_base = new VarBase(
            string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
            origin_grad_var_base->DataType(), origin_grad_var_base->Dims(),
            place_, true, false);
        outputs.emplace_back(tmp_grad_var_base);
      }
    }

    // No need to do compile time infer shape here.
    // grad_op_desc_->InferShape(*block_);
    // grad_op_desc->InferVarType(block_);

    std::unique_ptr<framework::OperatorBase> opbase =
        framework::OpRegistry::CreateOp(*grad_op_desc);

    auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
    if (info.infer_var_type_) {
      RuntimeInferVarTypeContext infer_var_type_ctx(
          &grad_input_vars_[k], &tmp_grad_outputs[k], &attrs_);
      info.infer_var_type_(&infer_var_type_ctx);
    }

    framework::OperatorWithKernel* op_kernel =
        dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
    PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");

    // Run grad op
    framework::VariableValueMap grad_invars_map;
    framework::VariableValueMap grad_outvars_map;

    for (const auto& it : grad_input_vars_[k]) {
      auto& grad_invars = grad_invars_map[it.first];
      grad_invars.reserve(it.second.size());
      for (const VarBase* grad_inp : it.second) {
        PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
                                grad_op_desc->Type(), grad_inp->Name());

        grad_invars.emplace_back(grad_inp->var_.get());
      }
    }

    for (const auto& it : tmp_grad_outputs[k]) {
      auto& grad_outvars = grad_outvars_map[it.first];
      grad_outvars.reserve(it.second.size());
      for (VarBase* grad_out : it.second) {
        PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr",
                                grad_op_desc->Type(), grad_out->Name());

        grad_outvars.emplace_back(grad_out->var_.get());
      }
    }

    framework::RuntimeContext ctx(grad_invars_map, grad_outvars_map);
    framework::Scope scope;
    PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
    p.op.RuntimeInferShape(scope, place_, ctx);
    p.func(
        framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx, nullptr));
  }

  platform::RecordEvent record_event("merge_grads");
  // Add tmp grad outputs to original grad vars
  for (size_t k = 0; k < grad_output_vars_.size(); ++k) {
    for (const auto& it : grad_output_vars_[k]) {
      auto& outputs = tmp_grad_outputs[k][it.first];
      const auto& origin_outputs = it.second;
      PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());

      for (size_t i = 0; i < outputs.size(); ++i) {
        // track outputs used by sum
        if (bck_stratedy.sorted_sum_gradient_) {
#ifndef PADDLE_WITH_CUDA
          VLOG(2) << "origin_outputs is : " << origin_outputs[i]->Name() << " ";
          VLOG(2) << origin_outputs[i]
                         ->var_->GetMutable<framework::LoDTensor>()
                         ->data<float>()[0];
          VLOG(2) << "outputs is : " << outputs[i]->Name() << " ";
          VLOG(2) << outputs[i]
                         ->var_->GetMutable<framework::LoDTensor>()
                         ->data<float>()[0];
#endif
          if (bck_map->find(origin_outputs[i]) != bck_map->end()) {
            VLOG(2) << "add sub grad to " << origin_outputs[i]->Name();
            bck_map->at(origin_outputs[i])
                .second.emplace_back(
                    std::pair<int, VarBase*>(this->trace_id_, outputs[i]));
          } else {
            VLOG(2) << "insert new map for " << origin_outputs[i]->Name();
            std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>
                tmp(place_, {std::make_pair(this->trace_id_, outputs[i])});
            bck_map->insert(std::make_pair(origin_outputs[i], tmp));
          }

          PADDLE_ENFORCE(grad_ref->find(origin_outputs[i]) != grad_ref->end(),
                         "Can't find  %s in grad_reference count map",
                         origin_outputs[i]->Name());
          PADDLE_ENFORCE(grad_ref->at(origin_outputs[i]) >= 1,
                         "Backward error when calculate grad reference");
          if (grad_ref->at(origin_outputs[i]) > 1) {
            VLOG(2) << "remove ref for " << origin_outputs[i]->Name();
            grad_ref->at(origin_outputs[i])--;
          } else {
            VLOG(2) << "Add grad for: " << origin_outputs[i]->Name();
            AddGradBySort(bck_map, origin_outputs[i]);
            grad_ref->at(origin_outputs[i])--;
          }
        } else {
          framework::Variable* grad = outputs[i]->var_.get();
          framework::Variable* orig_grad = origin_outputs[i]->var_.get();
          VLOG(2) << "AddTo Called with orig_grad is: "
                  << origin_outputs[i]->name_ << " Grad to be added is "
                  << outputs[i]->name_;
          AddTo(grad, orig_grad, place_);
          delete outputs[i];
        }
      }
    }
  }

  return input_vars_;
}

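// Calls every registered Python backward hook with this op.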
void OpBase::InvokeBackwardHooks() {
  VLOG(3) << "call backward hooks, hooks num: " << backward_hooks_.size();

  // call backward hooks
  for (py::object& callable : backward_hooks_) {
    callable(this);
  }
}

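// Registers a Python callable to be invoked after this op's gradients have
// been applied.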
void OpBase::RegisterBackwardHooks(const py::object& callable) {
  VLOG(3) << "Register backward hooks " << trace_id_;

  // TODO(minqiyang): check the callable format
  backward_hooks_.push_back(callable);
}

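// Entry point of dygraph backward: seeds this variable's gradient with ones
// and lets Autograd propagate gradients back through the trace.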
void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) {
  if (!pre_op_) return;
  platform::RecordEvent record_event("Imperative Backward");
  VLOG(3) << "start backward";
  auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
  operators::math::set_constant(
      *(platform::DeviceContextPool::Instance().Get(
          var_->GetMutable<framework::LoDTensor>()->place())),
      grads_t, 1.0);

  PADDLE_ENFORCE(
      grads_ ==
      pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
  Autograd().RunBackward(this, bck_stratedy);
}

}  // namespace imperative
}  // namespace paddle