Commit 35e6b5e1 authored by Xin Pan

polish

test=develop
Parent b80fe826
@@ -75,16 +75,6 @@ class Autograd {
   }

  private:
-  void AccumGrads(int grad_idx, Variable* grad,
-                  std::vector<Variable*>* op_grads) {
-    if (!(*op_grads)[grad_idx]) {
-      // FIXME(panyx0718): This should be a deep copy.
-      (*op_grads)[grad_idx] = grad;
-      return;
-    }
-    AddTo(grad, (*op_grads)[grad_idx]);
-  }
   std::map<OpBase*, int> ComputeDepCounts(OpBase* op) {
     std::map<OpBase*, int> ret;
@@ -108,14 +98,6 @@ class Autograd {
     return ret;
   }
-  std::vector<Variable*> CreateOpGrads(size_t count) {
-    std::vector<Variable*> op_grads;
-    for (size_t i = 0; i < count; ++i) {
-      op_grads.push_back(nullptr);
-    }
-    return op_grads;
-  }
   framework::Scope* scope_;
 };
@@ -133,7 +115,7 @@ framework::Variable* CreateVariable(const std::string& name,
     varname = string::Sprintf("%s@%d", varname, id);
   }
-  LOG(ERROR) << "creating var " << varname;
+  VLOG(3) << "creating var " << varname;
   framework::Variable* var = scope->Var(varname);
   framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
@@ -165,22 +147,25 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
   for (const std::string& grad_invar : grad_op_desc_->InputArgumentNames()) {
     if (grad_to_var_->find(grad_invar) == grad_to_var_->end()) {
+      // grad op inputs can be forward inputs, so not in grad_to_var.
       continue;
     }
-    LOG(ERROR) << "op grad in var " << grad_invar;
+    VLOG(3) << "op grad in var " << grad_invar;
     block_->FindRecursiveOrCreateVar(grad_invar);
     framework::Variable* var = scope->Var(grad_invar);
     const std::string& invar = grad_to_var_->at(grad_invar);
     for (VarBase* varbase : *output_vars_) {
+      // Use the accumulated grads_ by sharing the input with grads_.
       if (varbase->var_desc_->Name() == invar) {
         var->GetMutable<framework::LoDTensor>()->ShareDataWith(
             varbase->grads_->Get<framework::LoDTensor>());
+        break;
       }
     }
   }
   for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) {
-    LOG(ERROR) << "grad outvar " << outvar;
+    VLOG(3) << "grad outvar " << outvar;
     block_->FindRecursiveOrCreateVar(outvar);
     framework::Variable* var = scope->Var(outvar);
     if (!var->IsInitialized()) {
@@ -199,6 +184,7 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
   opbase->Run(*scope, platform::CPUPlace());
+  // `ret` matches exactly with `input_vars_` of forward op.
   std::vector<Variable*> ret;
   for (size_t i = 0; i < input_vars_->size(); ++i) {
     bool found = false;
@@ -207,7 +193,7 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
       VarBase* origin_var = (*input_vars_)[i];
       std::string orig_var = grad_to_var_->at(outvar);
       PADDLE_ENFORCE(origin_var->var_desc_->Name() == orig_var);
-      LOG(ERROR) << "apply grad " << outvar << " with origin " << orig_var;
+      VLOG(3) << "apply grad " << outvar << " with origin " << orig_var;
       origin_var->ApplyGrad(scope, var);
       found = true;
       ret.push_back(var);
...
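Note on the hunk above: the removed AccumGrads/CreateOpGrads helpers kept a separate per-op gradient buffer (with a pending deep-copy FIXME), while the new ApplyGrad path, per the added comment, lets the backward op's input variable share storage with the forward output's accumulated grads_ via ShareDataWith, so no copy or later AddTo is needed. Below is a minimal standalone C++ sketch of that buffer-sharing idea; the Tensor type and names are hypothetical stand-ins, not Paddle's framework::LoDTensor API.

    #include <iostream>
    #include <memory>
    #include <vector>

    // Hypothetical mini-tensor: "sharing" means both tensors point at one buffer.
    struct Tensor {
      std::shared_ptr<std::vector<float>> buf;
      void ShareDataWith(const Tensor& other) { buf = other.buf; }  // no copy
    };

    int main() {
      Tensor accumulated_grad;  // plays the role of VarBase::grads_
      accumulated_grad.buf = std::make_shared<std::vector<float>>(3, 1.0f);

      Tensor grad_op_input;  // plays the role of the backward op's input variable
      grad_op_input.ShareDataWith(accumulated_grad);  // same storage, no deep copy

      (*grad_op_input.buf)[0] += 2.0f;                  // writes are visible to both
      std::cout << (*accumulated_grad.buf)[0] << "\n";  // prints 3
      return 0;
    }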
@@ -36,10 +36,7 @@ class VarBase {
         var_(nullptr),
         grads_(nullptr) {}
-  virtual ~VarBase() {
-    LOG(ERROR) << "deleting var";
-    LOG(ERROR) << "done deleting var";
-  }
+  virtual ~VarBase() {}
   void ApplyGrad(framework::Scope* scope, framework::Variable* grad);
...
@@ -55,7 +55,7 @@ class Tracer {
              framework::BlockDesc* block) {
     framework::Scope* scope = GetScope(block);
     framework::OpDesc* op_desc = op->op_desc_;
-    LOG(ERROR) << "tracer tracing " << op_desc->Type();
+    VLOG(3) << "tracer tracing " << op_desc->Type();
     op_desc->InferShape(*block);
     op_desc->InferVarType(block);
     std::unique_ptr<framework::OperatorBase> op_base =
...
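The LOG(ERROR)-to-VLOG(3) changes in the hunks above move the tracing messages from the always-on error log to verbose logging, which is emitted only when the verbosity level is at least 3 (in glog this is controlled by FLAGS_v / the --v flag; in Paddle it is commonly raised via the GLOG_v environment variable). A minimal glog sketch of the difference, outside of Paddle:

    #include <glog/logging.h>

    int main(int argc, char* argv[]) {
      google::InitGoogleLogging(argv[0]);
      FLAGS_v = 0;                       // default: verbose logging disabled
      VLOG(3) << "tracer tracing mul";   // suppressed
      FLAGS_v = 3;                       // e.g. the effect of --v=3 or GLOG_v=3
      VLOG(3) << "tracer tracing mul";   // now emitted
      LOG(ERROR) << "always emitted regardless of verbosity";
      return 0;
    }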
@@ -27,6 +27,8 @@ import pydoc
 member_dict = collections.OrderedDict()
+experimental_namespace = {"paddle.fluid.imperative"}
 def visit_member(parent_name, member):
     cur_name = ".".join([parent_name, member.__name__])
@@ -50,6 +52,8 @@ def visit_member(parent_name, member):
 def visit_all_module(mod):
+    if (mod.__name__ in experimental_namespace):
+        return
     for member_name in (
             name
             for name in (mod.__all__ if hasattr(mod, "__all__") else dir(mod))
...