Commit 7ccbdb1b authored by typhoonzero

for test

Parent c32040c3
@@ -42,6 +42,30 @@ bool BlockDesc::HasVar(const std::string &name) const {
   return vars_.find(name) != vars_.end();
 }

+void BlockDesc::RenameVar(const std::string &old_name,
+                          const std::string &new_name) {
+  if (!this->HasVar(old_name)) {
+    return;
+  }
+  auto *var = this->Var(old_name);
+  var->SetName(new_name);
+  // Move ownership to the new key before erasing the old entry.
+  // (Resetting vars_[new_name] with the raw pointer and then erasing
+  // old_name would delete the VarDesc and leave a dangling entry.)
+  vars_[new_name] = std::move(vars_[old_name]);
+  vars_.erase(old_name);
+  // Rename matching inputs and outputs of every op in this block.
+  for (const auto &op : ops_) {
+    auto *op_desc = op.get();
+    for (const auto &in_name : op_desc->InputArgumentNames()) {
+      if (in_name == old_name) {
+        op_desc->RenameInput(old_name, new_name);
+      }
+    }
+    for (const auto &out_name : op_desc->OutputArgumentNames()) {
+      if (out_name == old_name) {
+        op_desc->RenameOutput(old_name, new_name);
+      }
+    }
+  }
+}
+
 VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
   if (name == kEmptyVarName) return nullptr;
...
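For reference, the rename above amounts to the dict-based Python sketch below (illustration only; `input_names`/`rename_input` are stand-ins for the C++ `InputArgumentNames`/`RenameInput`, not Paddle APIs). The pop-then-reinsert order is the Python analogue of moving the unique_ptr before erasing the old key.

# Sketch: dict-based analogue of BlockDesc::RenameVar (illustration only).
def rename_var(vars, ops, old_name, new_name):
    if old_name not in vars:
        return
    var = vars.pop(old_name)   # take the entry out first (the "move")
    var.name = new_name
    vars[new_name] = var       # re-insert under the new key
    for op in ops:             # patch every op that referenced old_name
        if old_name in op.input_names:
            op.rename_input(old_name, new_name)
        if old_name in op.output_names:
            op.rename_output(old_name, new_name)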
@@ -55,6 +55,8 @@ class BlockDesc {
   bool HasVar(const std::string &var_name) const;

+  void RenameVar(const std::string &old_name, const std::string &new_name);
+
   VarDesc *FindVarRecursive(const std::string &name_bytes) const;
   VarDesc &FindRecursiveOrCreateVar(const std::string &name_bytes);
...
@@ -171,6 +171,12 @@ void BindBlockDesc(py::module &m) {
            std::string name = byte_name;
            return self.HasVar(name);
          })
+      .def("rename_var",
+           [](BlockDesc &self, py::bytes byte_name, py::bytes byte_name_new) {
+             std::string name = byte_name;
+             std::string new_name = byte_name_new;
+             self.RenameVar(name, new_name);
+           })
       .def("has_var_recursive",
            [](BlockDesc &self, py::bytes byte_name) {
              std::string name = byte_name;
...
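Because the binding takes `py::bytes`, Python callers pass byte strings. A hedged usage sketch, assuming `program_desc` is a bound `ProgramDesc` handle:

# Sketch: exercising the new binding (names assumed for illustration).
block_desc = program_desc.block(0)
block_desc.rename_var(b"fc_0.w_0", b"fc_0.w_0.renamed")
assert block_desc.has_var(b"fc_0.w_0.renamed")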
@@ -203,7 +203,7 @@ class DistributeTranspiler:
                 block_map[varname] = []
             block_map[varname].append((long(offset), long(size)))
         for varname, splited in block_map.iteritems():
-            orig_var = program.global_block().vars[varname]
+            orig_var = program.global_block().var(varname)
             if len(splited) == 1:
                 # rename var to the trainer_id var
...
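The switch from `vars[varname]` to `var(varname)` trades a bare `KeyError` for `Block.var`'s descriptive `ValueError` when a name is missing. In context, the loop then renames single-block variables per trainer, roughly as below (the `trainer_id` naming pattern is an assumption for illustration, not the transpiler's exact code):

# Sketch of the surrounding transpiler flow (Python 2 era, as in the file).
for varname, splited in block_map.iteritems():
    orig_var = program.global_block().var(varname)  # raises ValueError if absent
    if len(splited) == 1:
        # one block only: rename the var for this trainer (pattern assumed)
        new_name = "%s.trainer_%d" % (varname, trainer_id)
        program.global_block().rename_var(varname, new_name)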
@@ -740,15 +740,9 @@ class Block(object):
         """
         if not self.has_var(name):
             raise ValueError("var %s is not in current block" % name)
-        orig_var = self.var(name)
-        del self.vars[name]
-        orig_var.name = new_name
-        self.vars[new_name] = orig_var
-        for op in self.ops:
-            if name in op.input_arg_names:
-                op.rename_input(name, new_name)
-            if name in op.output_arg_names:
-                op.rename_output(name, new_name)
+        self.desc.rename_var(name, new_name)
+        self.sync_with_cpp()
+        print("renamed var: ", self.var(new_name))

     def create_parameter(self, *args, **kwargs):
         global_block = self.program.global_block()
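With the C++ path in place, `Block.rename_var` becomes a thin wrapper: rename inside the `BlockDesc`, then call `sync_with_cpp` so the Python-side `Variable`/`Operator` wrappers are rebuilt from the C++ desc. An end-to-end sketch (the `paddle.v2.fluid` import path is an assumption for this era of the codebase):

import paddle.v2.fluid as fluid  # import path assumed

prog = fluid.Program()
block = prog.global_block()
block.create_var(name="w", shape=[784, 100], dtype="float32")
block.rename_var("w", "w.trainer_0")  # delegates to desc.rename_var + sync
assert block.has_var("w.trainer_0")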
@@ -837,7 +831,7 @@ class Block(object):
         for p in other.iter_parameters():
             assert isinstance(p, Parameter)
             v = self.vars.get(p.name, None)
-            print("var shape to copy", v)
+            print("var shape to copy", v, p)
             if v is None:
                 raise ValueError("copy_param_info_from should be invoked with "
                                  "same topology")
...