未验证 提交 3b70f870 编写于 作者: J Jiabin Yang 提交者: GitHub

Using Smart pointer to optimizer memory usage of dyGraph (#17768)

* for debug

* test=develop, memory optimize for dygraph using shared_ptr

* test=develop, fix travis ci showed error

* test=develop, fix bug for recurrent usage of varbase

* test=develop, init varbase when it need to be Add
上级 82358bfd
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/layer.h"
#include <algorithm>
#include <deque> #include <deque>
#include <limits> #include <limits>
#include <map> #include <map>
...@@ -77,52 +78,63 @@ class TensorAddToFunctor : public boost::static_visitor<> { ...@@ -77,52 +78,63 @@ class TensorAddToFunctor : public boost::static_visitor<> {
} // namespace detail } // namespace detail
void AddTo(Variable* src, Variable* dst, platform::Place place) { void AddTo(std::shared_ptr<VarBase> src, std::shared_ptr<VarBase> dst,
framework::Tensor* dst_tensor = dst->GetMutable<framework::LoDTensor>(); platform::Place place) {
framework::Tensor* src_tensor = src->GetMutable<framework::LoDTensor>(); if (!dst->IsInitialize()) {
VLOG(2) << "im here1";
// FIXME(minqiyang): loss_grad op will pass a zero grad of label PADDLE_ENFORCE(src->IsInitialize(), "Using uninitialized VarBase");
// ugly fix for it dst->var_ = std::move(src->var_);
if (src_tensor->numel() == 0) { dst->SetInitialize(true);
return; return;
} } else {
framework::Tensor* dst_tensor =
dst->var_->GetMutable<framework::LoDTensor>();
framework::Tensor* src_tensor =
src->var_->GetMutable<framework::LoDTensor>();
// FIXME(minqiyang): loss_grad op will pass a zero grad of label
// ugly fix for it
if (src_tensor->numel() == 0) {
return;
}
PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(), PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
"dst_numel %lld vs. src_numel %lld", dst_tensor->numel(), "dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
src_tensor->numel()); src_tensor->numel());
detail::TensorAddToFunctor<float> func( detail::TensorAddToFunctor<float> func(
src_tensor->numel(), src_tensor->data<float>(), src_tensor->numel(), src_tensor->data<float>(),
dst_tensor->mutable_data<float>(place)); dst_tensor->mutable_data<float>(place));
boost::apply_visitor(func, place); boost::apply_visitor(func, place);
}
} }
void ZeroGrads(VarBase* vb, const platform::Place& place) { void ZeroGrads(const std::shared_ptr<imperative::VarBase> vb,
const platform::Place& place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
auto grad_t = vb->var_->GetMutable<framework::LoDTensor>(); auto grad_t = vb->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant(*dev_ctx, grad_t, 0.0); operators::math::set_constant(*dev_ctx, grad_t, 0.0);
} }
void AddGradBySort(BackwardSumMap* bck_map, VarBase* target) { void AddGradBySort(BackwardSumMap* bck_map,
PADDLE_ENFORCE(bck_map->find(target) != bck_map->end(), std::shared_ptr<imperative::VarBase> target) {
PADDLE_ENFORCE(bck_map->find(target.get()) != bck_map->end(),
"Can't find %s in backward grad map", target->Name()); "Can't find %s in backward grad map", target->Name());
std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>& current = std::pair<platform::Place,
bck_map->at(target); std::vector<std::pair<int, std::shared_ptr<imperative::VarBase>>>>&
std::sort( current = bck_map->at(target.get());
current.second.begin(), current.second.end(), std::sort(current.second.begin(), current.second.end(),
[](const std::pair<int, VarBase*>& a, const std::pair<int, VarBase*>& b) { [](const std::pair<int, std::shared_ptr<imperative::VarBase>>& a,
return a.first > b.first; const std::pair<int, std::shared_ptr<imperative::VarBase>>& b) {
}); return a.first > b.first;
});
for (auto& var_pair : current.second) { for (auto& var_pair : current.second) {
Variable* origin_grad = target->var_.get();
Variable* grad_to_add = var_pair.second->var_.get();
VLOG(10) << "add origin_grad: " << target->Name(); VLOG(10) << "add origin_grad: " << target->Name();
VLOG(10) << "added grad: " << var_pair.second->Name() VLOG(10) << "added grad: " << var_pair.second->Name()
<< " trace id is: " << var_pair.first; << " trace id is: " << var_pair.first;
AddTo(grad_to_add, origin_grad, current.first); AddTo(var_pair.second, target, current.first);
delete var_pair.second; var_pair.second.reset();
var_pair.second = nullptr;
} }
} }
...@@ -146,24 +158,22 @@ class Autograd { ...@@ -146,24 +158,22 @@ class Autograd {
while (!ready.empty()) { while (!ready.empty()) {
OpBase* ready_op = ready.front(); OpBase* ready_op = ready.front();
ready.pop_front(); ready.pop_front();
std::map<std::string, std::vector<VarBase*>> input_grads = std::vector<VarBasePtrMap> grads_outputs =
ready_op->ApplyGrad(&bck_map, &grad_ref, bck_stratedy); ready_op->ApplyGrad(&bck_map, &grad_ref, bck_stratedy);
for (auto it = input_grads.rbegin(); it != input_grads.rend(); ++it) { for (const auto& map : grads_outputs) {
const std::vector<VarBase*>& ingrads = it->second; for (auto it = map.rbegin(); it != map.rend(); ++it) {
for (size_t i = 0; i < ingrads.size(); ++i) { const std::vector<std::shared_ptr<VarBase>>& grad_outs = it->second;
if (!ingrads[i]) continue; for (size_t i = 0; i < grad_outs.size(); ++i) {
auto p = ready_op->input_vars_[it->first][i]; if (!grad_outs[i] || grad_outs[i]->IsStopGradient()) continue;
OpBase* pre_op = grad_outs[i]->PreOp();
if (p->IsStopGradient()) continue; if (!pre_op) continue;
OpBase* pre_op = ready_op->pre_ops_[it->first][i]; dep_counts[pre_op] -= 1;
if (!pre_op) continue; PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
bool pre_op_ready = dep_counts[pre_op] == 0;
dep_counts[pre_op] -= 1; if (pre_op_ready) {
PADDLE_ENFORCE(dep_counts[pre_op] >= 0); ready.push_back(pre_op);
bool pre_op_ready = dep_counts[pre_op] == 0; }
if (pre_op_ready) {
ready.push_back(pre_op);
} }
} }
} }
...@@ -194,7 +204,7 @@ class Autograd { ...@@ -194,7 +204,7 @@ class Autograd {
for (const auto& map : candidate->grad_output_vars_) { for (const auto& map : candidate->grad_output_vars_) {
for (const auto& it : map) { for (const auto& it : map) {
for (const auto& vb : it.second) { for (const auto& vb : it.second) {
++(*grad_ref)[vb]; ++(*grad_ref)[vb.get()];
} }
} }
} }
...@@ -202,7 +212,7 @@ class Autograd { ...@@ -202,7 +212,7 @@ class Autograd {
for (auto it : candidate->pre_ops_) { for (auto it : candidate->pre_ops_) {
for (OpBase* pre_op : it.second) { for (OpBase* pre_op : it.second) {
if (!pre_op) continue; if (!pre_op) continue;
VLOG(9) << "op dep " << candidate->Type() << " trace id " VLOG(2) << "op dep " << candidate->Type() << " trace id "
<< candidate->trace_id_ << " <---- " << it.first << " <---- " << candidate->trace_id_ << " <---- " << it.first << " <---- "
<< pre_op->Type() << " trace id " << pre_op->trace_id_; << pre_op->Type() << " trace id " << pre_op->trace_id_;
if (visited.find(pre_op) == visited.end()) { if (visited.find(pre_op) == visited.end()) {
...@@ -254,7 +264,7 @@ framework::LoDTensor& VarBase::GradValue() { ...@@ -254,7 +264,7 @@ framework::LoDTensor& VarBase::GradValue() {
return *(grads_->var_->GetMutable<framework::LoDTensor>()); return *(grads_->var_->GetMutable<framework::LoDTensor>());
} }
std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad( std::vector<VarBasePtrMap> OpBase::ApplyGrad(
BackwardSumMap* bck_map, GradientRef* grad_ref, BackwardSumMap* bck_map, GradientRef* grad_ref,
const detail::BackwardStrategy& bck_stratedy) { const detail::BackwardStrategy& bck_stratedy) {
PADDLE_ENFORCE(!grad_op_descs_.empty(), "%s has no backward implementation", PADDLE_ENFORCE(!grad_op_descs_.empty(), "%s has no backward implementation",
...@@ -274,17 +284,14 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad( ...@@ -274,17 +284,14 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
for (const auto& it : grad_output_variable_map) { for (const auto& it : grad_output_variable_map) {
auto& outputs = tmp_grad_outputs[k][it.first]; auto& outputs = tmp_grad_outputs[k][it.first];
outputs.reserve(it.second.size()); outputs.reserve(it.second.size());
for (VarBase* origin_grad_var_base : it.second) { for (const std::shared_ptr<imperative::VarBase>& origin_grad_var_base :
if (!origin_grad_var_base->IsInitialize()) { it.second) {
origin_grad_var_base->InitBuffer();
ZeroGrads(origin_grad_var_base, place_);
}
// Allocate a new variable // Allocate a new variable
VarBase* tmp_grad_var_base = new VarBase( std::shared_ptr<imperative::VarBase> tmp_grad_var_base(new VarBase(
string::Sprintf("%s@IGrad", origin_grad_var_base->Name()), string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
origin_grad_var_base->DataType(), origin_grad_var_base->Dims(), origin_grad_var_base->DataType(), origin_grad_var_base->Dims(),
place_, true, false); place_, true, false));
outputs.emplace_back(tmp_grad_var_base); outputs.emplace_back(std::move(tmp_grad_var_base));
} }
} }
...@@ -298,7 +305,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad( ...@@ -298,7 +305,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type()); auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
if (info.infer_var_type_) { if (info.infer_var_type_) {
RuntimeInferVarTypeContext infer_var_type_ctx( RuntimeInferVarTypeContext infer_var_type_ctx(
&grad_input_vars_[k], &tmp_grad_outputs[k], &attrs_); &grad_input_vars_[k], &tmp_grad_outputs[k], &(opbase->Attrs()));
info.infer_var_type_(&infer_var_type_ctx); info.infer_var_type_(&infer_var_type_ctx);
} }
...@@ -313,14 +320,14 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad( ...@@ -313,14 +320,14 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
for (const auto& it : grad_input_vars_[k]) { for (const auto& it : grad_input_vars_[k]) {
auto& grad_invars = grad_invars_map[it.first]; auto& grad_invars = grad_invars_map[it.first];
grad_invars.reserve(it.second.size()); grad_invars.reserve(it.second.size());
for (VarBase* grad_inp : it.second) { for (const std::shared_ptr<imperative::VarBase>& grad_inp : it.second) {
PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr", PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
grad_op_desc->Type(), grad_inp->Name()); grad_op_desc->Type(), grad_inp->Name());
if (!grad_inp->IsInitialize()) { if (!grad_inp->IsInitialize()) {
grad_inp->InitBuffer(); grad_inp->InitBuffer();
ZeroGrads(grad_inp, place_); ZeroGrads(grad_inp, place_);
} }
const VarBase* const_grad_inp = grad_inp; const std::shared_ptr<imperative::VarBase>& const_grad_inp = grad_inp;
grad_invars.emplace_back(const_grad_inp->var_.get()); grad_invars.emplace_back(const_grad_inp->var_.get());
} }
} }
...@@ -328,7 +335,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad( ...@@ -328,7 +335,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
for (const auto& it : tmp_grad_outputs[k]) { for (const auto& it : tmp_grad_outputs[k]) {
auto& grad_outvars = grad_outvars_map[it.first]; auto& grad_outvars = grad_outvars_map[it.first];
grad_outvars.reserve(it.second.size()); grad_outvars.reserve(it.second.size());
for (VarBase* grad_out : it.second) { for (const std::shared_ptr<imperative::VarBase>& grad_out : it.second) {
PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr", PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr",
grad_op_desc->Type(), grad_out->Name()); grad_op_desc->Type(), grad_out->Name());
...@@ -355,56 +362,48 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad( ...@@ -355,56 +362,48 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
for (size_t i = 0; i < outputs.size(); ++i) { for (size_t i = 0; i < outputs.size(); ++i) {
// track outputs used by sum // track outputs used by sum
if (bck_stratedy.sorted_sum_gradient_) { if (bck_stratedy.sorted_sum_gradient_) {
#ifndef PADDLE_WITH_CUDA if (bck_map->find(origin_outputs[i].get()) != bck_map->end()) {
VLOG(10) << "origin_outputs is : " << origin_outputs[i]->Name()
<< " ";
VLOG(10) << origin_outputs[i]
->var_->GetMutable<framework::LoDTensor>()
->data<float>()[0];
VLOG(10) << "outputs is : " << outputs[i]->Name() << " ";
VLOG(10) << outputs[i]
->var_->GetMutable<framework::LoDTensor>()
->data<float>()[0];
#endif
if (bck_map->find(origin_outputs[i]) != bck_map->end()) {
VLOG(10) << "add sub grad to " << origin_outputs[i]->Name(); VLOG(10) << "add sub grad to " << origin_outputs[i]->Name();
bck_map->at(origin_outputs[i]) bck_map->at(origin_outputs[i].get())
.second.emplace_back( .second.emplace_back(
std::pair<int, VarBase*>(this->trace_id_, outputs[i])); std::pair<int, std::shared_ptr<imperative::VarBase>>(
this->trace_id_, std::move(outputs[i])));
} else { } else {
VLOG(10) << "insert new map for " << origin_outputs[i]->Name(); VLOG(10) << "insert new map for " << origin_outputs[i]->Name();
std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>> std::pair<platform::Place,
tmp(place_, {std::make_pair(this->trace_id_, outputs[i])}); std::vector<
bck_map->insert(std::make_pair(origin_outputs[i], tmp)); std::pair<int, std::shared_ptr<imperative::VarBase>>>>
tmp(place_,
{std::make_pair(this->trace_id_, std::move(outputs[i]))});
bck_map->insert(std::make_pair(origin_outputs[i].get(), tmp));
} }
PADDLE_ENFORCE(grad_ref->find(origin_outputs[i]) != grad_ref->end(), PADDLE_ENFORCE(
"Can't find %s in grad_reference count map", grad_ref->find(origin_outputs[i].get()) != grad_ref->end(),
origin_outputs[i]->Name()); "Can't find %s in grad_reference count map",
PADDLE_ENFORCE(grad_ref->at(origin_outputs[i]) >= 1, origin_outputs[i]->Name());
PADDLE_ENFORCE(grad_ref->at(origin_outputs[i].get()) >= 1,
"Backward error when calculate grad reference"); "Backward error when calculate grad reference");
if (grad_ref->at(origin_outputs[i]) > 1) { if (grad_ref->at(origin_outputs[i].get()) > 1) {
VLOG(10) << "remove ref for " << origin_outputs[i]->Name(); VLOG(10) << "remove ref for " << origin_outputs[i]->Name();
grad_ref->at(origin_outputs[i])--; grad_ref->at(origin_outputs[i].get())--;
} else { } else {
VLOG(10) << "Add grad for: " << origin_outputs[i]->Name(); VLOG(10) << "Add grad for: " << origin_outputs[i]->Name();
AddGradBySort(bck_map, origin_outputs[i]); AddGradBySort(bck_map, origin_outputs[i]);
grad_ref->at(origin_outputs[i])--; grad_ref->at(origin_outputs[i].get())--;
} }
} else { } else {
framework::Variable* grad = outputs[i]->var_.get();
framework::Variable* orig_grad = origin_outputs[i]->var_.get();
VLOG(10) << "AddTo Called with orig_grad is: " VLOG(10) << "AddTo Called with orig_grad is: "
<< origin_outputs[i]->name_ << " Grad to be added is " << origin_outputs[i]->name_ << " Grad to be added is "
<< outputs[i]->name_; << outputs[i]->name_;
AddTo(grad, orig_grad, place_); AddTo(outputs[i], origin_outputs[i], place_);
delete outputs[i]; outputs[i].reset();
} }
} }
} }
} }
return input_vars_; return grad_output_vars_;
} }
void OpBase::InvokeBackwardHooks() { void OpBase::InvokeBackwardHooks() {
...@@ -434,9 +433,6 @@ void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) { ...@@ -434,9 +433,6 @@ void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) {
var_->GetMutable<framework::LoDTensor>()->place())), var_->GetMutable<framework::LoDTensor>()->place())),
grads_t, 1.0); grads_t, 1.0);
PADDLE_ENFORCE(
grads_ ==
pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
Autograd().RunBackward(this, bck_stratedy); Autograd().RunBackward(this, bck_stratedy);
} }
......
...@@ -171,32 +171,27 @@ class VarBase { ...@@ -171,32 +171,27 @@ class VarBase {
if (need_initialize) { if (need_initialize) {
tensor->mutable_data(place, dtype); tensor->mutable_data(place, dtype);
is_initialized_ = true; is_initialized_ = true;
VLOG(2) << "initialized varbase: " << name_ << " type: " << dtype VLOG(8) << "initialized varbase: " << name_ << " type: " << dtype
<< " place: " << place; << " place: " << place;
} else { } else {
is_initialized_ = false; is_initialized_ = false;
VLOG(2) << "not initialized varbase: " << name_; VLOG(8) << "not initialized varbase: " << name_;
} }
VLOG(2) << "create varbase: " << name_ << " type: " << dtype VLOG(8) << "create varbase: " << name_ << " type: " << dtype
<< " place: " << place; << " place: " << place << "Stop gradient: " << stop_gradient_;
} }
public: public:
virtual ~VarBase() { virtual ~VarBase() {
if (grads_) {
delete grads_;
grads_ = nullptr;
}
pre_op_ = nullptr; pre_op_ = nullptr;
pre_op_out_idx_ = -1; pre_op_out_idx_ = -1;
VLOG(2) << "destruct varbase: " << name_; VLOG(8) << "destruct varbase: " << name_;
} }
inline void SetName(const std::string& name) { name_ = name; } inline void SetName(const std::string& name) { name_ = name; }
inline std::string Name() const { return name_; } inline std::string Name() const { return name_; }
inline bool IsInitialize() const { return is_initialized_; } inline bool IsInitialize() const { return is_initialized_; }
inline void SetInitialize(bool inited) { is_initialized_ = inited; }
inline std::vector<int64_t> Shape() const { inline std::vector<int64_t> Shape() const {
if (var_->IsInitialized()) { if (var_->IsInitialized()) {
return framework::vectorize(var_->Get<framework::LoDTensor>().dims()); return framework::vectorize(var_->Get<framework::LoDTensor>().dims());
...@@ -214,10 +209,7 @@ class VarBase { ...@@ -214,10 +209,7 @@ class VarBase {
auto tensor = var_->GetMutable<framework::LoDTensor>(); auto tensor = var_->GetMutable<framework::LoDTensor>();
tensor->mutable_data(tensor->place(), type); tensor->mutable_data(tensor->place(), type);
} }
inline framework::proto::VarType::Type DataType() const { inline framework::proto::VarType::Type DataType() const { return dtype_; }
auto tensor = var_->Get<framework::LoDTensor>();
return tensor.type();
}
// tensor type. e.g.. LoDTensor // tensor type. e.g.. LoDTensor
inline void SetType(framework::proto::VarType::Type type) { type_ = type; } inline void SetType(framework::proto::VarType::Type type) { type_ = type; }
...@@ -225,11 +217,15 @@ class VarBase { ...@@ -225,11 +217,15 @@ class VarBase {
inline void SetStopGradient(bool stop_gradient) { inline void SetStopGradient(bool stop_gradient) {
stop_gradient_ = stop_gradient; stop_gradient_ = stop_gradient;
if (grads_) {
grads_->stop_gradient_ = stop_gradient;
}
} }
inline bool IsStopGradient() const { return stop_gradient_; } inline bool IsStopGradient() const { return stop_gradient_; }
inline void SetPersistable(bool persistable) { persistable_ = persistable; } inline void SetPersistable(bool persistable) { persistable_ = persistable; }
inline bool IsPersistable() const { return persistable_; } inline bool IsPersistable() const { return persistable_; }
inline void SetPreOp(OpBase* op) { pre_op_ = op; }
inline platform::Place GetPlace() { return place_; } inline platform::Place GetPlace() { return place_; }
inline OpBase* PreOp() const { return pre_op_; } inline OpBase* PreOp() const { return pre_op_; }
inline int PreOpOutIdx() const { return pre_op_out_idx_; } inline int PreOpOutIdx() const { return pre_op_out_idx_; }
...@@ -248,10 +244,10 @@ class VarBase { ...@@ -248,10 +244,10 @@ class VarBase {
if (!is_initialized_) { if (!is_initialized_) {
var_->GetMutable<framework::LoDTensor>()->mutable_data(place_, dtype_); var_->GetMutable<framework::LoDTensor>()->mutable_data(place_, dtype_);
is_initialized_ = true; is_initialized_ = true;
VLOG(2) << "initialized varbase: " << name_ << " type: " << dtype_ VLOG(8) << "initialized varbase: " << name_ << " type: " << dtype_
<< " place: " << place_; << " place: " << place_;
} else { } else {
VLOG(2) << "var: " << name_ << " has already been initialized "; VLOG(8) << "var: " << name_ << " has already been initialized ";
} }
} }
...@@ -290,7 +286,7 @@ class VarBase { ...@@ -290,7 +286,7 @@ class VarBase {
platform::Place place_; platform::Place place_;
std::unique_ptr<framework::Variable> var_; std::unique_ptr<framework::Variable> var_;
VarBase* grads_; std::shared_ptr<VarBase> grads_;
private: private:
framework::proto::VarType::Type dtype_; framework::proto::VarType::Type dtype_;
...@@ -314,22 +310,23 @@ class PYBIND11_HIDDEN OpBase { ...@@ -314,22 +310,23 @@ class PYBIND11_HIDDEN OpBase {
backward_hooks_() {} backward_hooks_() {}
virtual ~OpBase() { virtual ~OpBase() {
// TODO(minqiyang): remove op_desc from block_desc in tracer for (const auto& iter : outputs_ref) {
// for (const auto& var : iter.second) {
// reset all output vars' pre op auto vb = var.lock();
for (auto iter : output_vars_) { if (vb) {
for (VarBase* var : iter.second) { VLOG(3) << "Op reset by" << vb->name_;
var->ResetPreOp(this); vb->ResetPreOp(this);
}
} }
} }
// TODO(minqiyang): remove op_desc from block_desc in tracer
// release resource // release resource
for (framework::OpDesc* desc : grad_op_descs_) { for (framework::OpDesc* desc : grad_op_descs_) {
delete desc; delete desc;
} }
} }
std::map<std::string, std::vector<VarBase*>> ApplyGrad( std::vector<VarBasePtrMap> ApplyGrad(
BackwardSumMap* bck_map, GradientRef* grad_ref, BackwardSumMap* bck_map, GradientRef* grad_ref,
const detail::BackwardStrategy& bck_stratedy); const detail::BackwardStrategy& bck_stratedy);
...@@ -343,12 +340,13 @@ class PYBIND11_HIDDEN OpBase { ...@@ -343,12 +340,13 @@ class PYBIND11_HIDDEN OpBase {
void InvokeBackwardHooks(); void InvokeBackwardHooks();
void TrackPreOp(const std::string& inp_name, void TrackPreOp(
const std::vector<VarBase*>& inputs) { const std::string& inp_name,
const std::vector<std::shared_ptr<imperative::VarBase>>& inputs) {
auto& pre_ops_list = pre_ops_[inp_name]; auto& pre_ops_list = pre_ops_[inp_name];
pre_ops_list.reserve(inputs.size()); pre_ops_list.reserve(inputs.size());
auto& pre_ops_out_idx_list = pre_ops_out_idx_[inp_name]; auto& pre_ops_out_idx_list = pre_ops_out_idx_[inp_name];
for (VarBase* inp_var : inputs) { for (std::shared_ptr<imperative::VarBase> inp_var : inputs) {
if (inp_var->PreOp() && !inp_var->IsStopGradient()) { if (inp_var->PreOp() && !inp_var->IsStopGradient()) {
VLOG(3) << "add pre op " << inp_var->PreOp()->Type() << " in slot " VLOG(3) << "add pre op " << inp_var->PreOp()->Type() << " in slot "
<< inp_name; << inp_name;
...@@ -371,11 +369,10 @@ class PYBIND11_HIDDEN OpBase { ...@@ -371,11 +369,10 @@ class PYBIND11_HIDDEN OpBase {
platform::Place place_; platform::Place place_;
VarBasePtrMap input_vars_;
VarBasePtrMap output_vars_;
OpBasePtrMap pre_ops_; OpBasePtrMap pre_ops_;
std::map<std::string, std::vector<int>> pre_ops_out_idx_; std::map<std::string, std::vector<int>> pre_ops_out_idx_;
VarBaseWeakPtrMap outputs_ref;
// Inputs to a vector of bwd ops. // Inputs to a vector of bwd ops.
std::vector<VarBasePtrMap> grad_input_vars_; std::vector<VarBasePtrMap> grad_input_vars_;
// Outputs to a vector of bwd ops. // Outputs to a vector of bwd ops.
...@@ -390,8 +387,9 @@ class Layer { ...@@ -390,8 +387,9 @@ class Layer {
public: public:
virtual ~Layer() {} virtual ~Layer() {}
virtual std::vector<VarBase*> Forward(const std::vector<VarBase*>& inputs) { virtual std::vector<std::shared_ptr<VarBase>> Forward(
std::vector<VarBase*> vars; const std::vector<std::shared_ptr<VarBase>>& inputs) {
std::vector<std::shared_ptr<VarBase>> vars;
return vars; return vars;
} }
}; };
...@@ -412,7 +410,7 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext ...@@ -412,7 +410,7 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext
var_set_() { var_set_() {
input_names_.reserve(inputs_->size()); input_names_.reserve(inputs_->size());
for (auto& it : *inputs_) { for (auto& it : *inputs_) {
for (imperative::VarBase* var : it.second) { for (std::shared_ptr<imperative::VarBase> var : it.second) {
input_names_[it.first].emplace_back(var->Name()); input_names_[it.first].emplace_back(var->Name());
var_set_[var->Name()] = var; var_set_[var->Name()] = var;
} }
...@@ -420,7 +418,7 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext ...@@ -420,7 +418,7 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext
output_names_.reserve(outputs_->size()); output_names_.reserve(outputs_->size());
for (auto& it : *outputs_) { for (auto& it : *outputs_) {
for (imperative::VarBase* var : it.second) { for (std::shared_ptr<imperative::VarBase> var : it.second) {
output_names_[it.first].emplace_back(var->Name()); output_names_[it.first].emplace_back(var->Name());
var_set_[var->Name()] = var; var_set_[var->Name()] = var;
} }
...@@ -516,7 +514,8 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext ...@@ -516,7 +514,8 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext
const framework::AttributeMap* attrs_; const framework::AttributeMap* attrs_;
std::unordered_map<std::string, std::vector<std::string>> input_names_; std::unordered_map<std::string, std::vector<std::string>> input_names_;
std::unordered_map<std::string, std::vector<std::string>> output_names_; std::unordered_map<std::string, std::vector<std::string>> output_names_;
std::unordered_map<std::string, imperative::VarBase*> var_set_; std::unordered_map<std::string, std::shared_ptr<imperative::VarBase>>
var_set_;
}; };
} // namespace imperative } // namespace imperative
......
...@@ -46,23 +46,25 @@ void CreateGradOp(const framework::OpDesc& op_desc, ...@@ -46,23 +46,25 @@ void CreateGradOp(const framework::OpDesc& op_desc,
} }
} }
void CreateNoBuffuerGrad(VarBase* var, platform::DeviceContext* dev_ctx) { void CreateNoBuffuerGrad(std::shared_ptr<imperative::VarBase> var,
platform::DeviceContext* dev_ctx) {
PADDLE_ENFORCE_NOT_NULL(var, "Could not get valid var base"); PADDLE_ENFORCE_NOT_NULL(var, "Could not get valid var base");
PADDLE_ENFORCE_NOT_NULL(dev_ctx, PADDLE_ENFORCE_NOT_NULL(dev_ctx,
"Could not get valid device from forward op"); "Could not get valid device from forward op");
if (var->grads_ == nullptr) { if (var->grads_ == nullptr) {
auto& var_t = var->var_->Get<framework::LoDTensor>(); auto& var_t = var->var_->Get<framework::LoDTensor>();
var->grads_ = new VarBase(var->GradName(), framework::proto::VarType::FP32, var->grads_ = std::shared_ptr<imperative::VarBase>(
framework::vectorize(var_t.dims()), new VarBase(var->GradName(), framework::proto::VarType::FP32,
dev_ctx->GetPlace(), true, false, false); framework::vectorize(var_t.dims()), dev_ctx->GetPlace(),
var->IsStopGradient(), false, false));
} }
} }
platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) { platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
platform::Place result = place; platform::Place result = place;
for (auto it : inputs) { for (const auto& it : inputs) {
for (VarBase* var : it.second) { for (const std::shared_ptr<imperative::VarBase>& var : it.second) {
platform::Place tmp_place = platform::Place tmp_place =
var->var_->Get<framework::LoDTensor>().place(); var->var_->Get<framework::LoDTensor>().place();
if (!platform::is_same_place(tmp_place, result)) { if (!platform::is_same_place(tmp_place, result)) {
...@@ -96,7 +98,7 @@ framework::VariableNameMap CreateInputVarNameMap( ...@@ -96,7 +98,7 @@ framework::VariableNameMap CreateInputVarNameMap(
auto var_vector = it->second; auto var_vector = it->second;
std::vector<std::string> args; std::vector<std::string> args;
args.reserve(var_vector.size()); args.reserve(var_vector.size());
for (VarBase* var_base : var_vector) { for (std::shared_ptr<imperative::VarBase> var_base : var_vector) {
args.emplace_back(var_base->Name()); args.emplace_back(var_base->Name());
} }
result[in.name()] = args; result[in.name()] = args;
...@@ -124,7 +126,7 @@ framework::VariableNameMap CreateOutputVarNameMap( ...@@ -124,7 +126,7 @@ framework::VariableNameMap CreateOutputVarNameMap(
auto var_vector = it->second; auto var_vector = it->second;
std::vector<std::string> args; std::vector<std::string> args;
args.reserve(var_vector.size()); args.reserve(var_vector.size());
for (VarBase* var_base : var_vector) { for (const std::shared_ptr<imperative::VarBase>& var_base : var_vector) {
args.emplace_back(var_base->Name()); args.emplace_back(var_base->Name());
} }
result[out.name()] = args; result[out.name()] = args;
...@@ -135,22 +137,20 @@ framework::VariableNameMap CreateOutputVarNameMap( ...@@ -135,22 +137,20 @@ framework::VariableNameMap CreateOutputVarNameMap(
Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {} Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {}
std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
VarBasePtrMap* outputs, VarBasePtrMap* outputs, framework::AttributeMap attrs_map,
framework::AttributeMap attrs_map, const platform::Place expected_place,
const platform::Place expected_place, const bool stop_gradient) {
const bool stop_gradient) {
platform::RecordEvent record_event(op->type_); platform::RecordEvent record_event(op->type_);
framework::VariableValueMap invars_map; framework::VariableValueMap invars_map;
framework::VariableValueMap outvars_map; framework::VariableValueMap outvars_map;
// Construct input_vars_map and output_vars_map // Construct input_vars_map and output_vars_map
std::map<std::string, VarBase*> current_vars_map; std::map<std::string, std::shared_ptr<imperative::VarBase>> current_vars_map;
op->input_vars_ = inputs; for (auto it : inputs) {
for (auto it : op->input_vars_) {
auto& invars = invars_map[it.first]; auto& invars = invars_map[it.first];
invars.reserve(it.second.size()); invars.reserve(it.second.size());
for (VarBase* inp : it.second) { for (std::shared_ptr<imperative::VarBase> inp : it.second) {
PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr", op->Type(), PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr", op->Type(),
inp->Name()); inp->Name());
...@@ -165,13 +165,15 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, ...@@ -165,13 +165,15 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
op->TrackPreOp(it.first, it.second); op->TrackPreOp(it.first, it.second);
} }
op->output_vars_ = *outputs; for (const auto& it : *outputs) {
for (auto it : op->output_vars_) {
auto& outvars = outvars_map[it.first]; auto& outvars = outvars_map[it.first];
const std::vector<VarBase*>& outputs = it.second; const std::vector<std::shared_ptr<imperative::VarBase>>& outputs_tmp =
outvars.reserve(outputs.size()); it.second;
for (size_t i = 0U; i < outputs.size(); ++i) { outvars.reserve(outputs_tmp.size());
VarBase* out = outputs[i]; for (size_t i = 0U; i < outputs_tmp.size(); ++i) {
// Add weak_ptr to track outputs
op->outputs_ref[it.first].emplace_back(outputs_tmp[i]);
std::shared_ptr<imperative::VarBase> out = outputs_tmp[i];
outvars.emplace_back(out->var_.get()); outvars.emplace_back(out->var_.get());
out->TrackPreOp(op, it.first, i, stop_gradient); out->TrackPreOp(op, it.first, i, stop_gradient);
if (!stop_gradient) { if (!stop_gradient) {
...@@ -223,8 +225,6 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, ...@@ -223,8 +225,6 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
framework::ExecutionContext(prepared_op.op, scope, *prepared_op.dev_ctx, framework::ExecutionContext(prepared_op.op, scope, *prepared_op.dev_ctx,
prepared_op.ctx, prepared_op.kernel_configs)); prepared_op.ctx, prepared_op.kernel_configs));
// construct backward op
std::set<std::string> vars_saved_for_backward;
if (!stop_gradient) { if (!stop_gradient) {
VLOG(5) << "start construct backward op"; VLOG(5) << "start construct backward op";
...@@ -258,13 +258,13 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, ...@@ -258,13 +258,13 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
// Forward inputs or outputs. // Forward inputs or outputs.
grad_in_vars.emplace_back(fwd_var_it->second); grad_in_vars.emplace_back(fwd_var_it->second);
} else { } else {
VarBase* var = current_vars_map[var_it->second]; std::shared_ptr<imperative::VarBase> var =
current_vars_map[var_it->second];
CreateNoBuffuerGrad(var, prepared_op.GetDeviceContext()); CreateNoBuffuerGrad(var, prepared_op.GetDeviceContext());
// Douts. // Douts.
var->grads_->SetPreOp(var->PreOp());
grad_in_vars.emplace_back(var->grads_); grad_in_vars.emplace_back(var->grads_);
} }
vars_saved_for_backward.insert(it.first);
} }
} }
...@@ -276,16 +276,17 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, ...@@ -276,16 +276,17 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
"Could not found the grad op output var, should this " "Could not found the grad op output var, should this "
"operator %s's stop gradient be True", "operator %s's stop gradient be True",
op->Type()); op->Type());
VarBase* var = current_vars_map[var_it->second];
std::shared_ptr<imperative::VarBase> var =
current_vars_map[var_it->second];
CreateNoBuffuerGrad(var, prepared_op.GetDeviceContext()); CreateNoBuffuerGrad(var, prepared_op.GetDeviceContext());
var->grads_->SetPreOp(var->PreOp());
grad_out_vars.push_back(var->grads_); grad_out_vars.push_back(var->grads_);
VLOG(3) << "grads output var name: " << var->name_; VLOG(3) << "grads output var name: " << var->name_;
} }
} }
} }
} }
return vars_saved_for_backward;
} }
} // namespace imperative } // namespace imperative
} // namespace paddle } // namespace paddle
...@@ -36,9 +36,6 @@ void CreateGradOp(const framework::OpDesc& op_desc, ...@@ -36,9 +36,6 @@ void CreateGradOp(const framework::OpDesc& op_desc,
framework::OpDesc** grad_op_desc, framework::OpDesc** grad_op_desc,
std::unordered_map<std::string, std::string>* grad_to_var); std::unordered_map<std::string, std::string>* grad_to_var);
void InitVar(const VarBase* var, framework::Variable* grad_var,
platform::DeviceContext* dev_ctx);
platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs); platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs);
class Tracer { class Tracer {
...@@ -47,11 +44,11 @@ class Tracer { ...@@ -47,11 +44,11 @@ class Tracer {
virtual ~Tracer() {} virtual ~Tracer() {}
std::set<std::string> Trace(OpBase* op, const VarBasePtrMap& inputs, void Trace(OpBase* op, const VarBasePtrMap& inputs,
VarBasePtrMap* outputs, // NOLINT VarBasePtrMap* outputs, // NOLINT
framework::AttributeMap attrs_map, framework::AttributeMap attrs_map,
const platform::Place expected_place, const platform::Place expected_place,
const bool stop_gradient = false); const bool stop_gradient = false);
private: private:
platform::Place GetPlace(const VarBasePtrMap& inputs); platform::Place GetPlace(const VarBasePtrMap& inputs);
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <map> #include <map>
#include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
...@@ -26,12 +27,17 @@ namespace imperative { ...@@ -26,12 +27,17 @@ namespace imperative {
class VarBase; class VarBase;
class OpBase; class OpBase;
typedef std::map<std::string, std::vector<VarBase*>> VarBasePtrMap; typedef std::map<std::string, std::vector<std::shared_ptr<VarBase>>>
typedef std::map<std::string, std::vector<const VarBase*>> ConstVarBasePtrMap; VarBasePtrMap;
typedef std::map<std::string, std::vector<std::weak_ptr<VarBase>>>
VarBaseWeakPtrMap;
typedef std::map<std::string, std::vector<const std::shared_ptr<VarBase>>>
ConstVarBasePtrMap;
typedef std::map<std::string, std::vector<OpBase*>> OpBasePtrMap; typedef std::map<std::string, std::vector<OpBase*>> OpBasePtrMap;
typedef std::unordered_map< typedef std::unordered_map<
const VarBase*, const VarBase*,
std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>> std::pair<platform::Place,
std::vector<std::pair<int, std::shared_ptr<VarBase>>>>>
BackwardSumMap; // var_grad -> {place, {id -> var_grad@rename}} BackwardSumMap; // var_grad -> {place, {id -> var_grad@rename}}
typedef std::unordered_map<const VarBase*, int> GradientRef; typedef std::unordered_map<const VarBase*, int> GradientRef;
......
...@@ -35,9 +35,11 @@ class Layer : public imperative::Layer { ...@@ -35,9 +35,11 @@ class Layer : public imperative::Layer {
public: public:
using imperative::Layer::Layer; // Inherit constructors using imperative::Layer::Layer; // Inherit constructors
std::vector<imperative::VarBase *> Forward( std::vector<std::shared_ptr<imperative::VarBase>> Forward(
const std::vector<imperative::VarBase *> &inputs) override { const std::vector<std::shared_ptr<imperative::VarBase>> &inputs)
PYBIND11_OVERLOAD(std::vector<imperative::VarBase *>, Layer, Forward, override {
PYBIND11_OVERLOAD(std::vector<std::shared_ptr<imperative::VarBase>>, Layer,
Forward,
inputs); // NOLINT inputs); // NOLINT
} }
}; };
...@@ -72,7 +74,8 @@ void BindImperative(pybind11::module *m_ptr) { ...@@ -72,7 +74,8 @@ void BindImperative(pybind11::module *m_ptr) {
m.def("stop_imperative_gperf_profiler", []() { imperative::StopProfile(); }); m.def("stop_imperative_gperf_profiler", []() { imperative::StopProfile(); });
py::class_<imperative::VarBase>(m, "VarBase", R"DOC()DOC") py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
m, "VarBase", R"DOC()DOC")
.def( .def(
py::init<const std::string &, paddle::framework::proto::VarType::Type, py::init<const std::string &, paddle::framework::proto::VarType::Type,
const std::vector<int64_t>, const paddle::platform::CPUPlace, const std::vector<int64_t>, const paddle::platform::CPUPlace,
...@@ -136,10 +139,11 @@ void BindImperative(pybind11::module *m_ptr) { ...@@ -136,10 +139,11 @@ void BindImperative(pybind11::module *m_ptr) {
py::class_<imperative::Layer, Layer /* <--- trampoline*/> layer(m, "Layer"); py::class_<imperative::Layer, Layer /* <--- trampoline*/> layer(m, "Layer");
layer.def(py::init<>()) layer.def(py::init<>())
.def("forward", [](imperative::Layer &self, .def("forward",
const std::vector<imperative::VarBase *> &inputs) { [](imperative::Layer &self,
return self.Forward(inputs); const std::vector<std::shared_ptr<imperative::VarBase>> &inputs) {
}); return self.Forward(inputs);
});
py::class_<imperative::Tracer>(*m, "Tracer", "") py::class_<imperative::Tracer>(*m, "Tracer", "")
.def("__init__", .def("__init__",
...@@ -154,8 +158,8 @@ void BindImperative(pybind11::module *m_ptr) { ...@@ -154,8 +158,8 @@ void BindImperative(pybind11::module *m_ptr) {
const platform::CPUPlace expected_place, const platform::CPUPlace expected_place,
const bool stop_gradient = false) { const bool stop_gradient = false) {
py::gil_scoped_release release; py::gil_scoped_release release;
return self.Trace(op, inputs, outputs, attrs_map, expected_place, self.Trace(op, inputs, outputs, attrs_map, expected_place,
stop_gradient); stop_gradient);
}) })
.def("trace", [](imperative::Tracer &self, imperative::OpBase *op, .def("trace", [](imperative::Tracer &self, imperative::OpBase *op,
const imperative::VarBasePtrMap &inputs, const imperative::VarBasePtrMap &inputs,
...@@ -164,8 +168,8 @@ void BindImperative(pybind11::module *m_ptr) { ...@@ -164,8 +168,8 @@ void BindImperative(pybind11::module *m_ptr) {
const platform::CUDAPlace expected_place, const platform::CUDAPlace expected_place,
const bool stop_gradient = false) { const bool stop_gradient = false) {
py::gil_scoped_release release; py::gil_scoped_release release;
return self.Trace(op, inputs, outputs, attrs_map, expected_place, self.Trace(op, inputs, outputs, attrs_map, expected_place,
stop_gradient); stop_gradient);
}); });
// define parallel context // define parallel context
......
...@@ -24,9 +24,7 @@ __all__ = ['Tracer'] ...@@ -24,9 +24,7 @@ __all__ = ['Tracer']
def release_op(op): def release_op(op):
del framework._dygraph_tracer()._ops[op._trace_id].inputs del framework._dygraph_tracer()._ops[op._trace_id]
del framework._dygraph_tracer()._ops[op._trace_id].outputs
del framework._dygraph_tracer()._ops[op._trace_id].backward_refs
class Tracer(core.Tracer): class Tracer(core.Tracer):
...@@ -55,7 +53,6 @@ class Tracer(core.Tracer): ...@@ -55,7 +53,6 @@ class Tracer(core.Tracer):
def trace_op(self, op, inputs, outputs, stop_gradient=False): def trace_op(self, op, inputs, outputs, stop_gradient=False):
# TODO(hy): previous version will cause memory failed # TODO(hy): previous version will cause memory failed
op.inputs = inputs
inps = defaultdict(list) inps = defaultdict(list)
for k, vars in six.iteritems(inputs): for k, vars in six.iteritems(inputs):
if isinstance(vars, framework.Variable): if isinstance(vars, framework.Variable):
...@@ -64,7 +61,6 @@ class Tracer(core.Tracer): ...@@ -64,7 +61,6 @@ class Tracer(core.Tracer):
for var in vars: for var in vars:
inps[k].append(var._ivar) inps[k].append(var._ivar)
op.outputs = outputs
outs = defaultdict(list) outs = defaultdict(list)
for k, vars in six.iteritems(outputs): for k, vars in six.iteritems(outputs):
if isinstance(vars, framework.Variable): if isinstance(vars, framework.Variable):
...@@ -76,28 +72,15 @@ class Tracer(core.Tracer): ...@@ -76,28 +72,15 @@ class Tracer(core.Tracer):
# record op's trace id # record op's trace id
op.iop._trace_id = self._trace_id op.iop._trace_id = self._trace_id
backward_refs = self.trace(op.iop, inps, outs, op.attrs, self.trace(op.iop, inps, outs, op.attrs,
framework._current_expected_place(), framework._current_expected_place(), stop_gradient)
stop_gradient)
if not stop_gradient and self._train_mode: if not stop_gradient and self._train_mode:
self._trace_id += 1 self._trace_id += 1
self._ops[op.iop._trace_id] = op self._ops[op.iop._trace_id] = op
# register backward hooks and variables if needed # register backward hooks and variables if needed
if len(backward_refs) > 0: op.iop.register_backward_hooks(release_op)
op.iop.register_backward_hooks(release_op)
# TODO(minqiyang): remove all inputs and outputs after separate
# var and grad
op.backward_refs = defaultdict(list)
for k, v in six.iteritems(inputs):
if k in backward_refs:
op.backward_refs[k] = inputs[k]
for k, v in six.iteritems(outputs):
if k in backward_refs:
op.backward_refs[k] = outputs[k]
def train_mode(self): def train_mode(self):
self._train_mode = True self._train_mode = True
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.dygraph.nn import Embedding
import paddle.fluid.framework as framework
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope
import numpy as np
import six
class RecurrentTest(fluid.Layer):
    """Minimal layer used to test recurrent feeding of a VarBase.

    ``forward`` multiplies its two inputs and reduces the product to a
    scalar; the caller feeds the product back in as the next iteration's
    first input, exercising repeated use of the same output variable.
    """

    def __init__(self, name_scope):
        super(RecurrentTest, self).__init__(name_scope)

    def forward(self, in1, in2):
        # Matrix product of the two inputs, then its scalar sum.
        product = fluid.layers.mul(in1, in2)
        total = fluid.layers.reduce_sum(product)
        # Return order matters: callers unpack (sum_out, out).
        return total, product
class TestRecurrentFeed(unittest.TestCase):
    """Checks that feeding a layer's output back as its own input gives
    identical results in dygraph (imperative) mode and static-graph mode.

    This guards the shared_ptr-based VarBase lifetime change: the same
    output variable must stay valid when reused as an input on the next
    iteration.
    """

    def test_recurrent_feed(self):
        seed = 90
        original_np1 = np.arange(1, 5).reshape(2, 2).astype("float32")
        original_np2 = np.arange(5, 9).reshape(2, 2).astype("float32")
        # --- Dygraph run: reuse the layer's output as next input. ---
        with fluid.dygraph.guard():
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed
            original_in1 = to_variable(original_np1)
            original_in2 = to_variable(original_np2)
            rt = RecurrentTest("RecurrentTest")

            for i in range(3):
                sum_out, out = rt(original_in1, original_in2)
                # Feed this iteration's output back as the next input.
                original_in1 = out
                # Only the last iteration's value survives the loop and
                # is compared against the static-graph result below.
                sum_out_value = sum_out.numpy()
                sum_out.backward()
                rt.clear_gradients()

        # --- Static-graph run: same computation via explicit feeds. ---
        with new_program_scope():
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed
            in1 = fluid.layers.data(
                name="inp1", shape=[2, 2], append_batch_size=False)
            in2 = fluid.layers.data(
                name="inp2", shape=[2, 2], append_batch_size=False)
            rt1 = RecurrentTest("RecurrentTest")
            static_sum_out, static_out = rt1(in1, in2)
            fluid.backward.append_backward(static_sum_out)
            exe = fluid.Executor(fluid.CPUPlace(
            ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))

            fetch_list = [static_sum_out, static_out]
            for i in range(3):
                out = exe.run(
                    fluid.default_main_program(),
                    feed={"inp1": original_np1,
                          "inp2": original_np2},
                    fetch_list=fetch_list)
                static_out_value = out[1]
                static_sum_out = out[0]
                # Mirror the dygraph loop: feed the output back in.
                original_np1 = static_out_value

        # Compare the final-iteration scalar sums of both modes.
        self.assertTrue(np.array_equal(static_sum_out, sum_out_value))
# Run the test suite when executed directly.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册