diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc
index dd2ed87252102aee6d384f37365d19305f19b281..8579582e7e46a1bf097cb51054de7d563bc423f0 100644
--- a/paddle/framework/block_desc.cc
+++ b/paddle/framework/block_desc.cc
@@ -42,6 +42,34 @@
 bool BlockDesc::HasVar(const std::string &name) const {
   return vars_.find(name) != vars_.end();
 }
 
+void BlockDesc::RenameVar(const std::string &old_name,
+                          const std::string &new_name) {
+  if (this->HasVar(old_name)) {
+    need_update_ = true;
+    // Move ownership of the VarDesc to the new key before erasing the old
+    // entry; reset(var) followed by erase(old_name) would delete the
+    // VarDesc and leave vars_[new_name] holding a dangling pointer.
+    auto var = std::move(vars_[old_name]);
+    vars_.erase(old_name);
+    var->SetName(new_name);
+    vars_[new_name] = std::move(var);
+    // rename the inputs and outputs of every op that references old_name
+    for (const auto &op : ops_) {
+      auto *it = op.get();
+      for (const auto &in_name : it->InputArgumentNames()) {
+        if (in_name == old_name) {
+          it->RenameInput(old_name, new_name);
+        }
+      }
+      for (const auto &out_name : it->OutputArgumentNames()) {
+        if (out_name == old_name) {
+          it->RenameOutput(old_name, new_name);
+        }
+      }
+    }
+  }
+}
+
 VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
   if (name == kEmptyVarName) return nullptr;
diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h
index 4b609e4bcb67bb8dda5924a639e7a8165eda4353..e87a543909d57c6839171ac296963f6e9ac3ef52 100644
--- a/paddle/framework/block_desc.h
+++ b/paddle/framework/block_desc.h
@@ -55,6 +55,8 @@ class BlockDesc {
 
   bool HasVar(const std::string &var_name) const;
 
+  void RenameVar(const std::string &old_name, const std::string &new_name);
+
   VarDesc *FindVarRecursive(const std::string &name_bytes) const;
 
   VarDesc &FindRecursiveOrCreateVar(const std::string &name_bytes);
diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc
index 371d6119d4ab73e683821d0dc5db5194f44a64ce..f39dc47262903219d3c952743fb77346911e9c9d 100644
--- a/paddle/pybind/protobuf.cc
+++ b/paddle/pybind/protobuf.cc
@@ -171,6 +171,12 @@ void BindBlockDesc(py::module &m) {
              std::string name = byte_name;
             return self.HasVar(name);
           })
+      .def("rename_var",
+           [](BlockDesc &self, py::bytes byte_name, py::bytes byte_name_new) {
+             std::string name = byte_name;
+             std::string new_name = byte_name_new;
+             return self.RenameVar(name, new_name);
+           })
       .def("has_var_recursive",
            [](BlockDesc &self, py::bytes byte_name) {
              std::string name = byte_name;
diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py
index 4533405e4613804b54ed9c4fba3bd76b4990cc32..89e467b0bdb244d21c32de968a608d92b732e5a6 100644
--- a/python/paddle/v2/fluid/distribute_transpiler.py
+++ b/python/paddle/v2/fluid/distribute_transpiler.py
@@ -203,7 +203,7 @@ class DistributeTranspiler:
             block_map[varname] = []
             block_map[varname].append((long(offset), long(size)))
         for varname, splited in block_map.iteritems():
-            orig_var = program.global_block().vars[varname]
+            orig_var = program.global_block().var(varname)
             if len(splited) == 1:
                 # rename var to the trainer_id var
diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py
index 415960f512f4fb2f90583ba0cbf5f59e19197f0d..5e7dd9837323bd23a19c79d34ba7a291eb7928c8 100644
--- a/python/paddle/v2/fluid/framework.py
+++ b/python/paddle/v2/fluid/framework.py
@@ -740,15 +740,11 @@
         """
         if not self.has_var(name):
             raise ValueError("var %s is not in current" % name)
-        orig_var = self.var(name)
-        del self.vars[name]
-        orig_var.name = new_name
-        self.vars[new_name] = orig_var
-        for op in self.ops:
-            if name in op.input_arg_names:
-                op.rename_input(name, new_name)
-            if name in op.output_arg_names:
-                op.rename_output(name, new_name)
+        self.desc.rename_var(name, new_name)
+        # sync_with_cpp below only adds wrappers for new C++ vars, so
+        # drop the stale Python-side wrapper for the old name first
+        del self.vars[name]
+        self.sync_with_cpp()
 
     def create_parameter(self, *args, **kwargs):
         global_block = self.program.global_block()
@@ -837,7 +833,6 @@
         for p in other.iter_parameters():
             assert isinstance(p, Parameter)
             v = self.vars.get(p.name, None)
-            print("var shape to copy", v)
             if v is None:
                 raise ValueError("copy_param_info_from should be invoked with "
                                  "same topology")
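
For reference, a minimal sketch of how the new rename path could be exercised end to end from the Python side. The variable names, shapes, and the mul op below are illustrative assumptions, not part of this patch; rename_var, has_var, ops, and input_arg_names are the APIs the diffs above touch, and create_var/append_op are pre-existing Block helpers.

import paddle.v2.fluid as fluid

prog = fluid.Program()
block = prog.global_block()

# A variable and an op that consumes it.
w = block.create_var(name="fc_w", shape=[4, 4], dtype="float32")
out = block.create_var(name="fc_out", shape=[4, 4], dtype="float32")
block.append_op(type="mul", inputs={"X": w, "Y": w}, outputs={"Out": out})

# rename_var delegates to the C++ BlockDesc and re-syncs the Python
# wrappers, so the var map and the op's argument lists stay consistent.
block.rename_var("fc_w", "fc_w_renamed")
assert block.has_var("fc_w_renamed")
assert "fc_w_renamed" in block.ops[0].input_arg_names

Routing the rename through BlockDesc::RenameVar keeps the C++ desc as the single source of truth; the Python layer only re-syncs afterwards instead of mutating its var dict and op argument lists in parallel.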