未验证 提交 3b70f870 编写于 作者: J Jiabin Yang 提交者: GitHub

Using Smart pointer to optimizer memory usage of dyGraph (#17768)

* for debug

* test=develop, memory optimize for dygraph using shared_ptr

* test=develop, fix travis ci showed error

* test=develop, fix bug for recurrent usage of varbase

* test=develop, init varbase when it need to be Add
上级 82358bfd
......@@ -14,6 +14,7 @@
#include "paddle/fluid/imperative/layer.h"
#include <algorithm>
#include <deque>
#include <limits>
#include <map>
......@@ -77,9 +78,19 @@ class TensorAddToFunctor : public boost::static_visitor<> {
} // namespace detail
void AddTo(Variable* src, Variable* dst, platform::Place place) {
framework::Tensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
framework::Tensor* src_tensor = src->GetMutable<framework::LoDTensor>();
void AddTo(std::shared_ptr<VarBase> src, std::shared_ptr<VarBase> dst,
platform::Place place) {
if (!dst->IsInitialize()) {
VLOG(2) << "im here1";
PADDLE_ENFORCE(src->IsInitialize(), "Using uninitialized VarBase");
dst->var_ = std::move(src->var_);
dst->SetInitialize(true);
return;
} else {
framework::Tensor* dst_tensor =
dst->var_->GetMutable<framework::LoDTensor>();
framework::Tensor* src_tensor =
src->var_->GetMutable<framework::LoDTensor>();
// FIXME(minqiyang): loss_grad op will pass a zero grad of label
// ugly fix for it
......@@ -95,34 +106,35 @@ void AddTo(Variable* src, Variable* dst, platform::Place place) {
src_tensor->numel(), src_tensor->data<float>(),
dst_tensor->mutable_data<float>(place));
boost::apply_visitor(func, place);
}
}
void ZeroGrads(VarBase* vb, const platform::Place& place) {
void ZeroGrads(const std::shared_ptr<imperative::VarBase> vb,
const platform::Place& place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
auto grad_t = vb->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant(*dev_ctx, grad_t, 0.0);
}
void AddGradBySort(BackwardSumMap* bck_map, VarBase* target) {
PADDLE_ENFORCE(bck_map->find(target) != bck_map->end(),
void AddGradBySort(BackwardSumMap* bck_map,
std::shared_ptr<imperative::VarBase> target) {
PADDLE_ENFORCE(bck_map->find(target.get()) != bck_map->end(),
"Can't find %s in backward grad map", target->Name());
std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>& current =
bck_map->at(target);
std::sort(
current.second.begin(), current.second.end(),
[](const std::pair<int, VarBase*>& a, const std::pair<int, VarBase*>& b) {
std::pair<platform::Place,
std::vector<std::pair<int, std::shared_ptr<imperative::VarBase>>>>&
current = bck_map->at(target.get());
std::sort(current.second.begin(), current.second.end(),
[](const std::pair<int, std::shared_ptr<imperative::VarBase>>& a,
const std::pair<int, std::shared_ptr<imperative::VarBase>>& b) {
return a.first > b.first;
});
for (auto& var_pair : current.second) {
Variable* origin_grad = target->var_.get();
Variable* grad_to_add = var_pair.second->var_.get();
VLOG(10) << "add origin_grad: " << target->Name();
VLOG(10) << "added grad: " << var_pair.second->Name()
<< " trace id is: " << var_pair.first;
AddTo(grad_to_add, origin_grad, current.first);
delete var_pair.second;
var_pair.second = nullptr;
AddTo(var_pair.second, target, current.first);
var_pair.second.reset();
}
}
......@@ -146,19 +158,16 @@ class Autograd {
while (!ready.empty()) {
OpBase* ready_op = ready.front();
ready.pop_front();
std::map<std::string, std::vector<VarBase*>> input_grads =
std::vector<VarBasePtrMap> grads_outputs =
ready_op->ApplyGrad(&bck_map, &grad_ref, bck_stratedy);
for (auto it = input_grads.rbegin(); it != input_grads.rend(); ++it) {
const std::vector<VarBase*>& ingrads = it->second;
for (size_t i = 0; i < ingrads.size(); ++i) {
if (!ingrads[i]) continue;
auto p = ready_op->input_vars_[it->first][i];
if (p->IsStopGradient()) continue;
OpBase* pre_op = ready_op->pre_ops_[it->first][i];
for (const auto& map : grads_outputs) {
for (auto it = map.rbegin(); it != map.rend(); ++it) {
const std::vector<std::shared_ptr<VarBase>>& grad_outs = it->second;
for (size_t i = 0; i < grad_outs.size(); ++i) {
if (!grad_outs[i] || grad_outs[i]->IsStopGradient()) continue;
OpBase* pre_op = grad_outs[i]->PreOp();
if (!pre_op) continue;
dep_counts[pre_op] -= 1;
PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
bool pre_op_ready = dep_counts[pre_op] == 0;
......@@ -167,6 +176,7 @@ class Autograd {
}
}
}
}
ready_op->InvokeBackwardHooks();
}
......@@ -194,7 +204,7 @@ class Autograd {
for (const auto& map : candidate->grad_output_vars_) {
for (const auto& it : map) {
for (const auto& vb : it.second) {
++(*grad_ref)[vb];
++(*grad_ref)[vb.get()];
}
}
}
......@@ -202,7 +212,7 @@ class Autograd {
for (auto it : candidate->pre_ops_) {
for (OpBase* pre_op : it.second) {
if (!pre_op) continue;
VLOG(9) << "op dep " << candidate->Type() << " trace id "
VLOG(2) << "op dep " << candidate->Type() << " trace id "
<< candidate->trace_id_ << " <---- " << it.first << " <---- "
<< pre_op->Type() << " trace id " << pre_op->trace_id_;
if (visited.find(pre_op) == visited.end()) {
......@@ -254,7 +264,7 @@ framework::LoDTensor& VarBase::GradValue() {
return *(grads_->var_->GetMutable<framework::LoDTensor>());
}
std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
std::vector<VarBasePtrMap> OpBase::ApplyGrad(
BackwardSumMap* bck_map, GradientRef* grad_ref,
const detail::BackwardStrategy& bck_stratedy) {
PADDLE_ENFORCE(!grad_op_descs_.empty(), "%s has no backward implementation",
......@@ -274,17 +284,14 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
for (const auto& it : grad_output_variable_map) {
auto& outputs = tmp_grad_outputs[k][it.first];
outputs.reserve(it.second.size());
for (VarBase* origin_grad_var_base : it.second) {
if (!origin_grad_var_base->IsInitialize()) {
origin_grad_var_base->InitBuffer();
ZeroGrads(origin_grad_var_base, place_);
}
for (const std::shared_ptr<imperative::VarBase>& origin_grad_var_base :
it.second) {
// Allocate a new variable
VarBase* tmp_grad_var_base = new VarBase(
std::shared_ptr<imperative::VarBase> tmp_grad_var_base(new VarBase(
string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
origin_grad_var_base->DataType(), origin_grad_var_base->Dims(),
place_, true, false);
outputs.emplace_back(tmp_grad_var_base);
place_, true, false));
outputs.emplace_back(std::move(tmp_grad_var_base));
}
}
......@@ -298,7 +305,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
if (info.infer_var_type_) {
RuntimeInferVarTypeContext infer_var_type_ctx(
&grad_input_vars_[k], &tmp_grad_outputs[k], &attrs_);
&grad_input_vars_[k], &tmp_grad_outputs[k], &(opbase->Attrs()));
info.infer_var_type_(&infer_var_type_ctx);
}
......@@ -313,14 +320,14 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
for (const auto& it : grad_input_vars_[k]) {
auto& grad_invars = grad_invars_map[it.first];
grad_invars.reserve(it.second.size());
for (VarBase* grad_inp : it.second) {
for (const std::shared_ptr<imperative::VarBase>& grad_inp : it.second) {
PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
grad_op_desc->Type(), grad_inp->Name());
if (!grad_inp->IsInitialize()) {
grad_inp->InitBuffer();
ZeroGrads(grad_inp, place_);
}
const VarBase* const_grad_inp = grad_inp;
const std::shared_ptr<imperative::VarBase>& const_grad_inp = grad_inp;
grad_invars.emplace_back(const_grad_inp->var_.get());
}
}
......@@ -328,7 +335,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
for (const auto& it : tmp_grad_outputs[k]) {
auto& grad_outvars = grad_outvars_map[it.first];
grad_outvars.reserve(it.second.size());
for (VarBase* grad_out : it.second) {
for (const std::shared_ptr<imperative::VarBase>& grad_out : it.second) {
PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr",
grad_op_desc->Type(), grad_out->Name());
......@@ -355,56 +362,48 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
for (size_t i = 0; i < outputs.size(); ++i) {
// track outputs used by sum
if (bck_stratedy.sorted_sum_gradient_) {
#ifndef PADDLE_WITH_CUDA
VLOG(10) << "origin_outputs is : " << origin_outputs[i]->Name()
<< " ";
VLOG(10) << origin_outputs[i]
->var_->GetMutable<framework::LoDTensor>()
->data<float>()[0];
VLOG(10) << "outputs is : " << outputs[i]->Name() << " ";
VLOG(10) << outputs[i]
->var_->GetMutable<framework::LoDTensor>()
->data<float>()[0];
#endif
if (bck_map->find(origin_outputs[i]) != bck_map->end()) {
if (bck_map->find(origin_outputs[i].get()) != bck_map->end()) {
VLOG(10) << "add sub grad to " << origin_outputs[i]->Name();
bck_map->at(origin_outputs[i])
bck_map->at(origin_outputs[i].get())
.second.emplace_back(
std::pair<int, VarBase*>(this->trace_id_, outputs[i]));
std::pair<int, std::shared_ptr<imperative::VarBase>>(
this->trace_id_, std::move(outputs[i])));
} else {
VLOG(10) << "insert new map for " << origin_outputs[i]->Name();
std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>
tmp(place_, {std::make_pair(this->trace_id_, outputs[i])});
bck_map->insert(std::make_pair(origin_outputs[i], tmp));
std::pair<platform::Place,
std::vector<
std::pair<int, std::shared_ptr<imperative::VarBase>>>>
tmp(place_,
{std::make_pair(this->trace_id_, std::move(outputs[i]))});
bck_map->insert(std::make_pair(origin_outputs[i].get(), tmp));
}
PADDLE_ENFORCE(grad_ref->find(origin_outputs[i]) != grad_ref->end(),
PADDLE_ENFORCE(
grad_ref->find(origin_outputs[i].get()) != grad_ref->end(),
"Can't find %s in grad_reference count map",
origin_outputs[i]->Name());
PADDLE_ENFORCE(grad_ref->at(origin_outputs[i]) >= 1,
PADDLE_ENFORCE(grad_ref->at(origin_outputs[i].get()) >= 1,
"Backward error when calculate grad reference");
if (grad_ref->at(origin_outputs[i]) > 1) {
if (grad_ref->at(origin_outputs[i].get()) > 1) {
VLOG(10) << "remove ref for " << origin_outputs[i]->Name();
grad_ref->at(origin_outputs[i])--;
grad_ref->at(origin_outputs[i].get())--;
} else {
VLOG(10) << "Add grad for: " << origin_outputs[i]->Name();
AddGradBySort(bck_map, origin_outputs[i]);
grad_ref->at(origin_outputs[i])--;
grad_ref->at(origin_outputs[i].get())--;
}
} else {
framework::Variable* grad = outputs[i]->var_.get();
framework::Variable* orig_grad = origin_outputs[i]->var_.get();
VLOG(10) << "AddTo Called with orig_grad is: "
<< origin_outputs[i]->name_ << " Grad to be added is "
<< outputs[i]->name_;
AddTo(grad, orig_grad, place_);
delete outputs[i];
AddTo(outputs[i], origin_outputs[i], place_);
outputs[i].reset();
}
}
}
}
return input_vars_;
return grad_output_vars_;
}
void OpBase::InvokeBackwardHooks() {
......@@ -434,9 +433,6 @@ void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) {
var_->GetMutable<framework::LoDTensor>()->place())),
grads_t, 1.0);
PADDLE_ENFORCE(
grads_ ==
pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
Autograd().RunBackward(this, bck_stratedy);
}
......
......@@ -171,32 +171,27 @@ class VarBase {
if (need_initialize) {
tensor->mutable_data(place, dtype);
is_initialized_ = true;
VLOG(2) << "initialized varbase: " << name_ << " type: " << dtype
VLOG(8) << "initialized varbase: " << name_ << " type: " << dtype
<< " place: " << place;
} else {
is_initialized_ = false;
VLOG(2) << "not initialized varbase: " << name_;
VLOG(8) << "not initialized varbase: " << name_;
}
VLOG(2) << "create varbase: " << name_ << " type: " << dtype
<< " place: " << place;
VLOG(8) << "create varbase: " << name_ << " type: " << dtype
<< " place: " << place << "Stop gradient: " << stop_gradient_;
}
public:
virtual ~VarBase() {
if (grads_) {
delete grads_;
grads_ = nullptr;
}
pre_op_ = nullptr;
pre_op_out_idx_ = -1;
VLOG(2) << "destruct varbase: " << name_;
VLOG(8) << "destruct varbase: " << name_;
}
inline void SetName(const std::string& name) { name_ = name; }
inline std::string Name() const { return name_; }
inline bool IsInitialize() const { return is_initialized_; }
inline void SetInitialize(bool inited) { is_initialized_ = inited; }
inline std::vector<int64_t> Shape() const {
if (var_->IsInitialized()) {
return framework::vectorize(var_->Get<framework::LoDTensor>().dims());
......@@ -214,10 +209,7 @@ class VarBase {
auto tensor = var_->GetMutable<framework::LoDTensor>();
tensor->mutable_data(tensor->place(), type);
}
inline framework::proto::VarType::Type DataType() const {
auto tensor = var_->Get<framework::LoDTensor>();
return tensor.type();
}
inline framework::proto::VarType::Type DataType() const { return dtype_; }
// tensor type. e.g.. LoDTensor
inline void SetType(framework::proto::VarType::Type type) { type_ = type; }
......@@ -225,11 +217,15 @@ class VarBase {
inline void SetStopGradient(bool stop_gradient) {
stop_gradient_ = stop_gradient;
if (grads_) {
grads_->stop_gradient_ = stop_gradient;
}
}
inline bool IsStopGradient() const { return stop_gradient_; }
inline void SetPersistable(bool persistable) { persistable_ = persistable; }
inline bool IsPersistable() const { return persistable_; }
inline void SetPreOp(OpBase* op) { pre_op_ = op; }
inline platform::Place GetPlace() { return place_; }
inline OpBase* PreOp() const { return pre_op_; }
inline int PreOpOutIdx() const { return pre_op_out_idx_; }
......@@ -248,10 +244,10 @@ class VarBase {
if (!is_initialized_) {
var_->GetMutable<framework::LoDTensor>()->mutable_data(place_, dtype_);
is_initialized_ = true;
VLOG(2) << "initialized varbase: " << name_ << " type: " << dtype_
VLOG(8) << "initialized varbase: " << name_ << " type: " << dtype_
<< " place: " << place_;
} else {
VLOG(2) << "var: " << name_ << " has already been initialized ";
VLOG(8) << "var: " << name_ << " has already been initialized ";
}
}
......@@ -290,7 +286,7 @@ class VarBase {
platform::Place place_;
std::unique_ptr<framework::Variable> var_;
VarBase* grads_;
std::shared_ptr<VarBase> grads_;
private:
framework::proto::VarType::Type dtype_;
......@@ -314,22 +310,23 @@ class PYBIND11_HIDDEN OpBase {
backward_hooks_() {}
virtual ~OpBase() {
// TODO(minqiyang): remove op_desc from block_desc in tracer
//
// reset all output vars' pre op
for (auto iter : output_vars_) {
for (VarBase* var : iter.second) {
var->ResetPreOp(this);
for (const auto& iter : outputs_ref) {
for (const auto& var : iter.second) {
auto vb = var.lock();
if (vb) {
VLOG(3) << "Op reset by" << vb->name_;
vb->ResetPreOp(this);
}
}
}
// TODO(minqiyang): remove op_desc from block_desc in tracer
// release resource
for (framework::OpDesc* desc : grad_op_descs_) {
delete desc;
}
}
std::map<std::string, std::vector<VarBase*>> ApplyGrad(
std::vector<VarBasePtrMap> ApplyGrad(
BackwardSumMap* bck_map, GradientRef* grad_ref,
const detail::BackwardStrategy& bck_stratedy);
......@@ -343,12 +340,13 @@ class PYBIND11_HIDDEN OpBase {
void InvokeBackwardHooks();
void TrackPreOp(const std::string& inp_name,
const std::vector<VarBase*>& inputs) {
void TrackPreOp(
const std::string& inp_name,
const std::vector<std::shared_ptr<imperative::VarBase>>& inputs) {
auto& pre_ops_list = pre_ops_[inp_name];
pre_ops_list.reserve(inputs.size());
auto& pre_ops_out_idx_list = pre_ops_out_idx_[inp_name];
for (VarBase* inp_var : inputs) {
for (std::shared_ptr<imperative::VarBase> inp_var : inputs) {
if (inp_var->PreOp() && !inp_var->IsStopGradient()) {
VLOG(3) << "add pre op " << inp_var->PreOp()->Type() << " in slot "
<< inp_name;
......@@ -371,11 +369,10 @@ class PYBIND11_HIDDEN OpBase {
platform::Place place_;
VarBasePtrMap input_vars_;
VarBasePtrMap output_vars_;
OpBasePtrMap pre_ops_;
std::map<std::string, std::vector<int>> pre_ops_out_idx_;
VarBaseWeakPtrMap outputs_ref;
// Inputs to a vector of bwd ops.
std::vector<VarBasePtrMap> grad_input_vars_;
// Outputs to a vector of bwd ops.
......@@ -390,8 +387,9 @@ class Layer {
public:
virtual ~Layer() {}
virtual std::vector<VarBase*> Forward(const std::vector<VarBase*>& inputs) {
std::vector<VarBase*> vars;
virtual std::vector<std::shared_ptr<VarBase>> Forward(
const std::vector<std::shared_ptr<VarBase>>& inputs) {
std::vector<std::shared_ptr<VarBase>> vars;
return vars;
}
};
......@@ -412,7 +410,7 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext
var_set_() {
input_names_.reserve(inputs_->size());
for (auto& it : *inputs_) {
for (imperative::VarBase* var : it.second) {
for (std::shared_ptr<imperative::VarBase> var : it.second) {
input_names_[it.first].emplace_back(var->Name());
var_set_[var->Name()] = var;
}
......@@ -420,7 +418,7 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext
output_names_.reserve(outputs_->size());
for (auto& it : *outputs_) {
for (imperative::VarBase* var : it.second) {
for (std::shared_ptr<imperative::VarBase> var : it.second) {
output_names_[it.first].emplace_back(var->Name());
var_set_[var->Name()] = var;
}
......@@ -516,7 +514,8 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext
const framework::AttributeMap* attrs_;
std::unordered_map<std::string, std::vector<std::string>> input_names_;
std::unordered_map<std::string, std::vector<std::string>> output_names_;
std::unordered_map<std::string, imperative::VarBase*> var_set_;
std::unordered_map<std::string, std::shared_ptr<imperative::VarBase>>
var_set_;
};
} // namespace imperative
......
......@@ -46,23 +46,25 @@ void CreateGradOp(const framework::OpDesc& op_desc,
}
}
void CreateNoBuffuerGrad(VarBase* var, platform::DeviceContext* dev_ctx) {
void CreateNoBuffuerGrad(std::shared_ptr<imperative::VarBase> var,
platform::DeviceContext* dev_ctx) {
PADDLE_ENFORCE_NOT_NULL(var, "Could not get valid var base");
PADDLE_ENFORCE_NOT_NULL(dev_ctx,
"Could not get valid device from forward op");
if (var->grads_ == nullptr) {
auto& var_t = var->var_->Get<framework::LoDTensor>();
var->grads_ = new VarBase(var->GradName(), framework::proto::VarType::FP32,
framework::vectorize(var_t.dims()),
dev_ctx->GetPlace(), true, false, false);
var->grads_ = std::shared_ptr<imperative::VarBase>(
new VarBase(var->GradName(), framework::proto::VarType::FP32,
framework::vectorize(var_t.dims()), dev_ctx->GetPlace(),
var->IsStopGradient(), false, false));
}
}
platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
platform::Place result = place;
for (auto it : inputs) {
for (VarBase* var : it.second) {
for (const auto& it : inputs) {
for (const std::shared_ptr<imperative::VarBase>& var : it.second) {
platform::Place tmp_place =
var->var_->Get<framework::LoDTensor>().place();
if (!platform::is_same_place(tmp_place, result)) {
......@@ -96,7 +98,7 @@ framework::VariableNameMap CreateInputVarNameMap(
auto var_vector = it->second;
std::vector<std::string> args;
args.reserve(var_vector.size());
for (VarBase* var_base : var_vector) {
for (std::shared_ptr<imperative::VarBase> var_base : var_vector) {
args.emplace_back(var_base->Name());
}
result[in.name()] = args;
......@@ -124,7 +126,7 @@ framework::VariableNameMap CreateOutputVarNameMap(
auto var_vector = it->second;
std::vector<std::string> args;
args.reserve(var_vector.size());
for (VarBase* var_base : var_vector) {
for (const std::shared_ptr<imperative::VarBase>& var_base : var_vector) {
args.emplace_back(var_base->Name());
}
result[out.name()] = args;
......@@ -135,9 +137,8 @@ framework::VariableNameMap CreateOutputVarNameMap(
Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {}
std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
VarBasePtrMap* outputs,
framework::AttributeMap attrs_map,
void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
VarBasePtrMap* outputs, framework::AttributeMap attrs_map,
const platform::Place expected_place,
const bool stop_gradient) {
platform::RecordEvent record_event(op->type_);
......@@ -145,12 +146,11 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
framework::VariableValueMap outvars_map;
// Construct input_vars_map and output_vars_map
std::map<std::string, VarBase*> current_vars_map;
op->input_vars_ = inputs;
for (auto it : op->input_vars_) {
std::map<std::string, std::shared_ptr<imperative::VarBase>> current_vars_map;
for (auto it : inputs) {
auto& invars = invars_map[it.first];
invars.reserve(it.second.size());
for (VarBase* inp : it.second) {
for (std::shared_ptr<imperative::VarBase> inp : it.second) {
PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr", op->Type(),
inp->Name());
......@@ -165,13 +165,15 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
op->TrackPreOp(it.first, it.second);
}
op->output_vars_ = *outputs;
for (auto it : op->output_vars_) {
for (const auto& it : *outputs) {
auto& outvars = outvars_map[it.first];
const std::vector<VarBase*>& outputs = it.second;
outvars.reserve(outputs.size());
for (size_t i = 0U; i < outputs.size(); ++i) {
VarBase* out = outputs[i];
const std::vector<std::shared_ptr<imperative::VarBase>>& outputs_tmp =
it.second;
outvars.reserve(outputs_tmp.size());
for (size_t i = 0U; i < outputs_tmp.size(); ++i) {
// Add weak_ptr to track outputs
op->outputs_ref[it.first].emplace_back(outputs_tmp[i]);
std::shared_ptr<imperative::VarBase> out = outputs_tmp[i];
outvars.emplace_back(out->var_.get());
out->TrackPreOp(op, it.first, i, stop_gradient);
if (!stop_gradient) {
......@@ -223,8 +225,6 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
framework::ExecutionContext(prepared_op.op, scope, *prepared_op.dev_ctx,
prepared_op.ctx, prepared_op.kernel_configs));
// construct backward op
std::set<std::string> vars_saved_for_backward;
if (!stop_gradient) {
VLOG(5) << "start construct backward op";
......@@ -258,13 +258,13 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
// Forward inputs or outputs.
grad_in_vars.emplace_back(fwd_var_it->second);
} else {
VarBase* var = current_vars_map[var_it->second];
std::shared_ptr<imperative::VarBase> var =
current_vars_map[var_it->second];
CreateNoBuffuerGrad(var, prepared_op.GetDeviceContext());
// Douts.
var->grads_->SetPreOp(var->PreOp());
grad_in_vars.emplace_back(var->grads_);
}
vars_saved_for_backward.insert(it.first);
}
}
......@@ -276,16 +276,17 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
"Could not found the grad op output var, should this "
"operator %s's stop gradient be True",
op->Type());
VarBase* var = current_vars_map[var_it->second];
std::shared_ptr<imperative::VarBase> var =
current_vars_map[var_it->second];
CreateNoBuffuerGrad(var, prepared_op.GetDeviceContext());
var->grads_->SetPreOp(var->PreOp());
grad_out_vars.push_back(var->grads_);
VLOG(3) << "grads output var name: " << var->name_;
}
}
}
}
return vars_saved_for_backward;
}
} // namespace imperative
} // namespace paddle
......@@ -36,9 +36,6 @@ void CreateGradOp(const framework::OpDesc& op_desc,
framework::OpDesc** grad_op_desc,
std::unordered_map<std::string, std::string>* grad_to_var);
void InitVar(const VarBase* var, framework::Variable* grad_var,
platform::DeviceContext* dev_ctx);
platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs);
class Tracer {
......@@ -47,7 +44,7 @@ class Tracer {
virtual ~Tracer() {}
std::set<std::string> Trace(OpBase* op, const VarBasePtrMap& inputs,
void Trace(OpBase* op, const VarBasePtrMap& inputs,
VarBasePtrMap* outputs, // NOLINT
framework::AttributeMap attrs_map,
const platform::Place expected_place,
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
......@@ -26,12 +27,17 @@ namespace imperative {
class VarBase;
class OpBase;
typedef std::map<std::string, std::vector<VarBase*>> VarBasePtrMap;
typedef std::map<std::string, std::vector<const VarBase*>> ConstVarBasePtrMap;
typedef std::map<std::string, std::vector<std::shared_ptr<VarBase>>>
VarBasePtrMap;
typedef std::map<std::string, std::vector<std::weak_ptr<VarBase>>>
VarBaseWeakPtrMap;
typedef std::map<std::string, std::vector<const std::shared_ptr<VarBase>>>
ConstVarBasePtrMap;
typedef std::map<std::string, std::vector<OpBase*>> OpBasePtrMap;
typedef std::unordered_map<
const VarBase*,
std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>>
std::pair<platform::Place,
std::vector<std::pair<int, std::shared_ptr<VarBase>>>>>
BackwardSumMap; // var_grad -> {place, {id -> var_grad@rename}}
typedef std::unordered_map<const VarBase*, int> GradientRef;
......
......@@ -35,9 +35,11 @@ class Layer : public imperative::Layer {
public:
using imperative::Layer::Layer; // Inherit constructors
std::vector<imperative::VarBase *> Forward(
const std::vector<imperative::VarBase *> &inputs) override {
PYBIND11_OVERLOAD(std::vector<imperative::VarBase *>, Layer, Forward,
std::vector<std::shared_ptr<imperative::VarBase>> Forward(
const std::vector<std::shared_ptr<imperative::VarBase>> &inputs)
override {
PYBIND11_OVERLOAD(std::vector<std::shared_ptr<imperative::VarBase>>, Layer,
Forward,
inputs); // NOLINT
}
};
......@@ -72,7 +74,8 @@ void BindImperative(pybind11::module *m_ptr) {
m.def("stop_imperative_gperf_profiler", []() { imperative::StopProfile(); });
py::class_<imperative::VarBase>(m, "VarBase", R"DOC()DOC")
py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
m, "VarBase", R"DOC()DOC")
.def(
py::init<const std::string &, paddle::framework::proto::VarType::Type,
const std::vector<int64_t>, const paddle::platform::CPUPlace,
......@@ -136,8 +139,9 @@ void BindImperative(pybind11::module *m_ptr) {
py::class_<imperative::Layer, Layer /* <--- trampoline*/> layer(m, "Layer");
layer.def(py::init<>())
.def("forward", [](imperative::Layer &self,
const std::vector<imperative::VarBase *> &inputs) {
.def("forward",
[](imperative::Layer &self,
const std::vector<std::shared_ptr<imperative::VarBase>> &inputs) {
return self.Forward(inputs);
});
......@@ -154,7 +158,7 @@ void BindImperative(pybind11::module *m_ptr) {
const platform::CPUPlace expected_place,
const bool stop_gradient = false) {
py::gil_scoped_release release;
return self.Trace(op, inputs, outputs, attrs_map, expected_place,
self.Trace(op, inputs, outputs, attrs_map, expected_place,
stop_gradient);
})
.def("trace", [](imperative::Tracer &self, imperative::OpBase *op,
......@@ -164,7 +168,7 @@ void BindImperative(pybind11::module *m_ptr) {
const platform::CUDAPlace expected_place,
const bool stop_gradient = false) {
py::gil_scoped_release release;
return self.Trace(op, inputs, outputs, attrs_map, expected_place,
self.Trace(op, inputs, outputs, attrs_map, expected_place,
stop_gradient);
});
......
......@@ -24,9 +24,7 @@ __all__ = ['Tracer']
def release_op(op):
del framework._dygraph_tracer()._ops[op._trace_id].inputs
del framework._dygraph_tracer()._ops[op._trace_id].outputs
del framework._dygraph_tracer()._ops[op._trace_id].backward_refs
del framework._dygraph_tracer()._ops[op._trace_id]
class Tracer(core.Tracer):
......@@ -55,7 +53,6 @@ class Tracer(core.Tracer):
def trace_op(self, op, inputs, outputs, stop_gradient=False):
# TODO(hy): previous version will cause memory failed
op.inputs = inputs
inps = defaultdict(list)
for k, vars in six.iteritems(inputs):
if isinstance(vars, framework.Variable):
......@@ -64,7 +61,6 @@ class Tracer(core.Tracer):
for var in vars:
inps[k].append(var._ivar)
op.outputs = outputs
outs = defaultdict(list)
for k, vars in six.iteritems(outputs):
if isinstance(vars, framework.Variable):
......@@ -76,29 +72,16 @@ class Tracer(core.Tracer):
# record op's trace id
op.iop._trace_id = self._trace_id
backward_refs = self.trace(op.iop, inps, outs, op.attrs,
framework._current_expected_place(),
stop_gradient)
self.trace(op.iop, inps, outs, op.attrs,
framework._current_expected_place(), stop_gradient)
if not stop_gradient and self._train_mode:
self._trace_id += 1
self._ops[op.iop._trace_id] = op
# register backward hooks and variables if needed
if len(backward_refs) > 0:
op.iop.register_backward_hooks(release_op)
# TODO(minqiyang): remove all inputs and outputs after separate
# var and grad
op.backward_refs = defaultdict(list)
for k, v in six.iteritems(inputs):
if k in backward_refs:
op.backward_refs[k] = inputs[k]
for k, v in six.iteritems(outputs):
if k in backward_refs:
op.backward_refs[k] = outputs[k]
def train_mode(self):
self._train_mode = True
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.dygraph.nn import Embedding
import paddle.fluid.framework as framework
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope
import numpy as np
import six
class RecurrentTest(fluid.Layer):
def __init__(self, name_scope):
super(RecurrentTest, self).__init__(name_scope)
def forward(self, in1, in2):
out = fluid.layers.mul(in1, in2)
sum_out = fluid.layers.reduce_sum(out)
return sum_out, out
class TestRecurrentFeed(unittest.TestCase):
def test_recurrent_feed(self):
seed = 90
original_np1 = np.arange(1, 5).reshape(2, 2).astype("float32")
original_np2 = np.arange(5, 9).reshape(2, 2).astype("float32")
with fluid.dygraph.guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
original_in1 = to_variable(original_np1)
original_in2 = to_variable(original_np2)
rt = RecurrentTest("RecurrentTest")
for i in range(3):
sum_out, out = rt(original_in1, original_in2)
original_in1 = out
sum_out_value = sum_out.numpy()
sum_out.backward()
rt.clear_gradients()
with new_program_scope():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
in1 = fluid.layers.data(
name="inp1", shape=[2, 2], append_batch_size=False)
in2 = fluid.layers.data(
name="inp2", shape=[2, 2], append_batch_size=False)
rt1 = RecurrentTest("RecurrentTest")
static_sum_out, static_out = rt1(in1, in2)
fluid.backward.append_backward(static_sum_out)
exe = fluid.Executor(fluid.CPUPlace(
) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
fetch_list = [static_sum_out, static_out]
for i in range(3):
out = exe.run(
fluid.default_main_program(),
feed={"inp1": original_np1,
"inp2": original_np2},
fetch_list=fetch_list)
static_out_value = out[1]
static_sum_out = out[0]
original_np1 = static_out_value
self.assertTrue(np.array_equal(static_sum_out, sum_out_value))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册