From 7ccbdb1b274308c9c11df06d4f8db2d07e491ea9 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Tue, 6 Feb 2018 15:04:37 +0800 Subject: [PATCH] for test --- paddle/framework/block_desc.cc | 24 +++++++++++++++++++ paddle/framework/block_desc.h | 2 ++ paddle/pybind/protobuf.cc | 6 +++++ .../paddle/v2/fluid/distribute_transpiler.py | 2 +- python/paddle/v2/fluid/framework.py | 14 ++++------- 5 files changed, 37 insertions(+), 11 deletions(-) diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc index dd2ed872521..8579582e7e4 100644 --- a/paddle/framework/block_desc.cc +++ b/paddle/framework/block_desc.cc @@ -42,6 +42,30 @@ bool BlockDesc::HasVar(const std::string &name) const { return vars_.find(name) != vars_.end(); } +void BlockDesc::RenameVar(const std::string &old_name, + const std::string &new_name) { + if (this->HasVar(old_name)) { + auto *var = this->Var(old_name); + var->SetName(new_name); + vars_[new_name].reset(var); + vars_.erase(old_name); + // rename inputs and outputs + for (const auto &op : ops_) { + auto *it = op.get(); + for (auto in_name : it->InputArgumentNames()) { + if (in_name == old_name) { + it->RenameInput(old_name, new_name); + } + } + for (auto out_name : it->OutputArgumentNames()) { + if (out_name == old_name) { + it->RenameOutput(old_name, new_name); + } + } + } + } +} + VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const { if (name == kEmptyVarName) return nullptr; diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h index 4b609e4bcb6..e87a543909d 100644 --- a/paddle/framework/block_desc.h +++ b/paddle/framework/block_desc.h @@ -55,6 +55,8 @@ class BlockDesc { bool HasVar(const std::string &var_name) const; + void RenameVar(const std::string &old_name, const std::string &new_name); + VarDesc *FindVarRecursive(const std::string &name_bytes) const; VarDesc &FindRecursiveOrCreateVar(const std::string &name_bytes); diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 371d6119d4a..f39dc472629 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -171,6 +171,12 @@ void BindBlockDesc(py::module &m) { std::string name = byte_name; return self.HasVar(name); }) + .def("rename_var", + [](BlockDesc &self, py::bytes byte_name, py::bytes byte_name_new) { + std::string name = byte_name; + std::string new_name = byte_name_new; + return self.RenameVar(name, new_name); + }) .def("has_var_recursive", [](BlockDesc &self, py::bytes byte_name) { std::string name = byte_name; diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index 4533405e461..89e467b0bdb 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -203,7 +203,7 @@ class DistributeTranspiler: block_map[varname] = [] block_map[varname].append((long(offset), long(size))) for varname, splited in block_map.iteritems(): - orig_var = program.global_block().vars[varname] + orig_var = program.global_block().var(varname) if len(splited) == 1: # rename var to the trainer_id var diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index 415960f512f..5e7dd983732 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -740,15 +740,9 @@ class Block(object): """ if not self.has_var(name): raise ValueError("var %s is not in current" % name) - orig_var = self.var(name) - del self.vars[name] - orig_var.name = new_name - self.vars[new_name] = orig_var - for op in self.ops: - if name in op.input_arg_names: - op.rename_input(name, new_name) - if name in op.output_arg_names: - op.rename_output(name, new_name) + self.desc.rename_var(name, new_name) + self.sync_with_cpp() + print("renamed var: ", self.var(new_name)) def create_parameter(self, *args, **kwargs): global_block = self.program.global_block() @@ -837,7 +831,7 @@ class Block(object): for p in other.iter_parameters(): assert isinstance(p, Parameter) v = self.vars.get(p.name, None) - print("var shape to copy", v) + print("var shape to copy", v, p) if v is None: raise ValueError("copy_param_info_from should be invoked with " "same topology") -- GitLab