提交 0e856867 编写于 作者: T typhoonzero

wip

上级 136a5919
...@@ -42,28 +42,30 @@ bool BlockDesc::HasVar(const std::string &name) const { ...@@ -42,28 +42,30 @@ bool BlockDesc::HasVar(const std::string &name) const {
return vars_.find(name) != vars_.end(); return vars_.find(name) != vars_.end();
} }
void BlockDesc::RenameVar(const std::string &old_name, VarDesc *BlockDesc::RenameVar(const std::string &old_name,
const std::string &new_name) { const std::string &new_name) {
if (this->HasVar(old_name)) { if (!this->HasVar(old_name)) {
auto *var = this->Var(old_name); return nullptr;
var->SetName(new_name); }
vars_[new_name].reset(var); need_update_ = true;
vars_.erase(old_name); auto *var = this->Var(old_name);
// rename inputs and outputs VarDesc *new_var = new VarDesc(*(var->Proto()));
for (const auto &op : ops_) { new_var->SetName(new_name);
auto *it = op.get(); // new_var->SetShape(var->GetShape());
for (auto in_name : it->InputArgumentNames()) { // new_var->SetType(var->GetType());
if (in_name == old_name) { // new_var->SetDataType(var->GetDataType());
it->RenameInput(old_name, new_name); // new_var->SetLoDLevel(var->GetLoDLevel());
} // new_var->SetPersistable(var->Persistable());
}
for (auto out_name : it->OutputArgumentNames()) { vars_[new_name].reset(new_var);
if (out_name == old_name) {
it->RenameOutput(old_name, new_name); // rename inputs and outputs
} for (const auto &op : ops_) {
} auto *it = op.get();
} it->Rename(old_name, new_name);
} }
vars_.erase(old_name);
return new_var;
} }
VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const { VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
......
...@@ -55,7 +55,7 @@ class BlockDesc { ...@@ -55,7 +55,7 @@ class BlockDesc {
bool HasVar(const std::string &var_name) const; bool HasVar(const std::string &var_name) const;
void RenameVar(const std::string &old_name, const std::string &new_name); VarDesc *RenameVar(const std::string &old_name, const std::string &new_name);
VarDesc *FindVarRecursive(const std::string &name_bytes) const; VarDesc *FindVarRecursive(const std::string &name_bytes) const;
......
...@@ -170,12 +170,14 @@ void BindBlockDesc(py::module &m) { ...@@ -170,12 +170,14 @@ void BindBlockDesc(py::module &m) {
[](BlockDesc &self, py::bytes byte_name) { [](BlockDesc &self, py::bytes byte_name) {
std::string name = byte_name; std::string name = byte_name;
return self.HasVar(name); return self.HasVar(name);
}) },
py::return_value_policy::reference)
.def("rename_var", .def("rename_var",
[](BlockDesc &self, py::bytes byte_name, py::bytes byte_name_new) { [](BlockDesc &self, const py::bytes &byte_name,
const py::bytes &byte_name_new) {
std::string name = byte_name; std::string name = byte_name;
std::string new_name = byte_name_new; std::string new_name = byte_name_new;
return self.RenameVar(name, new_name); self.RenameVar(name, new_name);
}) })
.def("has_var_recursive", .def("has_var_recursive",
[](BlockDesc &self, py::bytes byte_name) { [](BlockDesc &self, py::bytes byte_name) {
...@@ -213,7 +215,7 @@ void BindVarDsec(py::module &m) { ...@@ -213,7 +215,7 @@ void BindVarDsec(py::module &m) {
py::class_<VarDesc> var_desc(m, "VarDesc", ""); py::class_<VarDesc> var_desc(m, "VarDesc", "");
var_desc var_desc
.def("name", .def("name",
[](const VarDesc &self) { [](VarDesc &self) {
py::bytes name = self.Name(); py::bytes name = self.Name();
return name; return name;
}, },
......
...@@ -74,6 +74,8 @@ def download(url, module_name, md5sum, save_name=None): ...@@ -74,6 +74,8 @@ def download(url, module_name, md5sum, save_name=None):
retry = 0 retry = 0
retry_limit = 3 retry_limit = 3
while not (os.path.exists(filename) and md5file(filename) == md5sum): while not (os.path.exists(filename) and md5file(filename) == md5sum):
if os.path.exists(filename):
print "file md5", md5file(filename), md5sum
if retry < retry_limit: if retry < retry_limit:
retry += 1 retry += 1
else: else:
......
...@@ -175,6 +175,7 @@ class DistributeTranspiler: ...@@ -175,6 +175,7 @@ class DistributeTranspiler:
shape=[0]) shape=[0])
# create send_op # create send_op
print("send inputs: ", send_inputs)
send_op = program.global_block().append_op( send_op = program.global_block().append_op(
type="send", type="send",
inputs={"X": send_inputs}, inputs={"X": send_inputs},
...@@ -204,12 +205,12 @@ class DistributeTranspiler: ...@@ -204,12 +205,12 @@ class DistributeTranspiler:
block_map[varname].append((long(offset), long(size))) block_map[varname].append((long(offset), long(size)))
for varname, splited in block_map.iteritems(): for varname, splited in block_map.iteritems():
orig_var = program.global_block().var(varname) orig_var = program.global_block().var(varname)
if len(splited) == 1: if len(splited) == 1:
# rename var to the trainer_id var # rename var to the trainer_id var
new_var_name = "%s.trainer_%d" % \ new_var_name = "%s.trainer_%d" % \
(orig_var.name, self.trainer_id) (orig_var.name, self.trainer_id)
program.global_block().rename_var(varname, new_var_name) program.global_block().rename_var(varname, new_var_name)
print("renaming OK...", varname, new_var_name)
var_mapping[varname] = \ var_mapping[varname] = \
[program.global_block().var(new_var_name)] [program.global_block().var(new_var_name)]
continue continue
...@@ -375,7 +376,10 @@ class DistributeTranspiler: ...@@ -375,7 +376,10 @@ class DistributeTranspiler:
new_inputs = dict() new_inputs = dict()
# update param/grad shape first, then other inputs like # update param/grad shape first, then other inputs like
# moment can use the updated shape # moment can use the updated shape
print("mark1")
for key in opt_op.input_names: for key in opt_op.input_names:
# print("opt type: ", opt_op.type)
# print("opt op input: ", key)
if key == "Grad": if key == "Grad":
grad_block = None grad_block = None
for g in self.param_grad_ep_mapping[endpoint]["grads"]: for g in self.param_grad_ep_mapping[endpoint]["grads"]:
...@@ -422,6 +426,7 @@ class DistributeTranspiler: ...@@ -422,6 +426,7 @@ class DistributeTranspiler:
new_inputs[key] = tmpvar new_inputs[key] = tmpvar
print("mark2")
for key in opt_op.input_names: for key in opt_op.input_names:
if key in ["Param", "Grad"]: if key in ["Param", "Grad"]:
continue continue
...@@ -453,6 +458,7 @@ class DistributeTranspiler: ...@@ -453,6 +458,7 @@ class DistributeTranspiler:
inputs=new_inputs, inputs=new_inputs,
outputs=outputs, outputs=outputs,
attrs=opt_op.attrs) attrs=opt_op.attrs)
print("mark3")
def _append_pserver_non_opt_ops(self, program, pserver_program, opt_op): def _append_pserver_non_opt_ops(self, program, pserver_program, opt_op):
# Append the ops for parameters that do not need to be optimized/updated # Append the ops for parameters that do not need to be optimized/updated
...@@ -523,6 +529,11 @@ class DistributeTranspiler: ...@@ -523,6 +529,11 @@ class DistributeTranspiler:
optimize_sub_program = Program() optimize_sub_program = Program()
# Iterate through the ops and append ops as needed # Iterate through the ops and append ops as needed
for idx, opt_op in enumerate(self.optimize_ops): for idx, opt_op in enumerate(self.optimize_ops):
print("mark0")
print(opt_op.inputs.keys())
for v in opt_op.inputs.values():
print(v.name)
print(v.shape)
is_op_on_pserver = self._is_op_on_pserver(endpoint, is_op_on_pserver = self._is_op_on_pserver(endpoint,
self.optimize_ops, idx) self.optimize_ops, idx)
if not is_op_on_pserver: if not is_op_on_pserver:
......
...@@ -741,9 +741,75 @@ class Block(object): ...@@ -741,9 +741,75 @@ class Block(object):
""" """
if not self.has_var(name): if not self.has_var(name):
raise ValueError("var %s is not in current" % name) raise ValueError("var %s is not in current" % name)
v = self.var(name)
stop_gradient = None
trainable = None
optimize_attr = None
regularizer = None
gradient_clip_attr = None
error_clip = None
if type(v) == Parameter:
stop_gradient = v.stop_gradient
trainable = v.trainable
optimize_attr = v.optimize_attr
regularizer = v.regularizer
gradient_clip_attr = v.gradient_clip_attr
error_clip = v.error_clip
elif type(v) == Variable:
error_clip = v.error_clip
stop_gradient = v.stop_gradient
else:
raise ValueError("unsupported var type: %s", type(v))
def _clear_op_io_for_var(name):
for op in self.ops:
for k in op.inputs.keys():
if op.inputs[k].name == name:
op.inputs[k] = None
for k in op.outputs.keys():
if op.outputs[k].name == name:
op.outputs[k] = None
_clear_op_io_for_var(name)
self.desc.rename_var(name, new_name) self.desc.rename_var(name, new_name)
d = self.desc.find_var(new_name)
var = None
if type(v) == Parameter:
var = Parameter(
self,
d.shape(),
d.dtype(),
name=new_name,
stop_gradient=stop_gradient,
trainable=trainable,
optimize_attr=optimize_attr,
regularizer=regularizer,
gradient_clip_attr=gradient_clip_attr,
error_clip=error_clip)
elif type(v) == Variable:
var = Variable(
self,
name=new_name,
error_clip=error_clip,
stop_gradient=stop_gradient)
# rename the python side, sync_with_cpp will only add
# new vars/ops to python side.
self.vars[new_name] = var
for op in self.ops:
print("### rename op i/o ", name, op.inputs)
if op.inputs:
for k in op.inputs.keys():
if op.inputs[k] == None:
print("rename input: ", name, var)
op.inputs[k] = var
if op.outputs:
for k in op.outputs.keys():
if op.outputs[k] == None:
op.outputs[k] = var
del self.vars[name]
self.sync_with_cpp() self.sync_with_cpp()
print("renamed var: ", self.var(new_name))
def create_parameter(self, *args, **kwargs): def create_parameter(self, *args, **kwargs):
global_block = self.program.global_block() global_block = self.program.global_block()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册