提交 0e856867 编写于 作者: T typhoonzero

wip

上级 136a5919
......@@ -42,28 +42,30 @@ bool BlockDesc::HasVar(const std::string &name) const {
return vars_.find(name) != vars_.end();
}
void BlockDesc::RenameVar(const std::string &old_name,
VarDesc *BlockDesc::RenameVar(const std::string &old_name,
const std::string &new_name) {
if (this->HasVar(old_name)) {
if (!this->HasVar(old_name)) {
return nullptr;
}
need_update_ = true;
auto *var = this->Var(old_name);
var->SetName(new_name);
vars_[new_name].reset(var);
vars_.erase(old_name);
VarDesc *new_var = new VarDesc(*(var->Proto()));
new_var->SetName(new_name);
// new_var->SetShape(var->GetShape());
// new_var->SetType(var->GetType());
// new_var->SetDataType(var->GetDataType());
// new_var->SetLoDLevel(var->GetLoDLevel());
// new_var->SetPersistable(var->Persistable());
vars_[new_name].reset(new_var);
// rename inputs and outputs
for (const auto &op : ops_) {
auto *it = op.get();
for (auto in_name : it->InputArgumentNames()) {
if (in_name == old_name) {
it->RenameInput(old_name, new_name);
}
}
for (auto out_name : it->OutputArgumentNames()) {
if (out_name == old_name) {
it->RenameOutput(old_name, new_name);
}
}
}
it->Rename(old_name, new_name);
}
vars_.erase(old_name);
return new_var;
}
VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
......
......@@ -55,7 +55,7 @@ class BlockDesc {
bool HasVar(const std::string &var_name) const;
void RenameVar(const std::string &old_name, const std::string &new_name);
VarDesc *RenameVar(const std::string &old_name, const std::string &new_name);
VarDesc *FindVarRecursive(const std::string &name_bytes) const;
......
......@@ -170,12 +170,14 @@ void BindBlockDesc(py::module &m) {
[](BlockDesc &self, py::bytes byte_name) {
std::string name = byte_name;
return self.HasVar(name);
})
},
py::return_value_policy::reference)
.def("rename_var",
[](BlockDesc &self, py::bytes byte_name, py::bytes byte_name_new) {
[](BlockDesc &self, const py::bytes &byte_name,
const py::bytes &byte_name_new) {
std::string name = byte_name;
std::string new_name = byte_name_new;
return self.RenameVar(name, new_name);
self.RenameVar(name, new_name);
})
.def("has_var_recursive",
[](BlockDesc &self, py::bytes byte_name) {
......@@ -213,7 +215,7 @@ void BindVarDsec(py::module &m) {
py::class_<VarDesc> var_desc(m, "VarDesc", "");
var_desc
.def("name",
[](const VarDesc &self) {
[](VarDesc &self) {
py::bytes name = self.Name();
return name;
},
......
......@@ -74,6 +74,8 @@ def download(url, module_name, md5sum, save_name=None):
retry = 0
retry_limit = 3
while not (os.path.exists(filename) and md5file(filename) == md5sum):
if os.path.exists(filename):
print "file md5", md5file(filename), md5sum
if retry < retry_limit:
retry += 1
else:
......
......@@ -175,6 +175,7 @@ class DistributeTranspiler:
shape=[0])
# create send_op
print("send inputs: ", send_inputs)
send_op = program.global_block().append_op(
type="send",
inputs={"X": send_inputs},
......@@ -204,12 +205,12 @@ class DistributeTranspiler:
block_map[varname].append((long(offset), long(size)))
for varname, splited in block_map.iteritems():
orig_var = program.global_block().var(varname)
if len(splited) == 1:
# rename var to the trainer_id var
new_var_name = "%s.trainer_%d" % \
(orig_var.name, self.trainer_id)
program.global_block().rename_var(varname, new_var_name)
print("renaming OK...", varname, new_var_name)
var_mapping[varname] = \
[program.global_block().var(new_var_name)]
continue
......@@ -375,7 +376,10 @@ class DistributeTranspiler:
new_inputs = dict()
# update param/grad shape first, then other inputs like
# moment can use the updated shape
print("mark1")
for key in opt_op.input_names:
# print("opt type: ", opt_op.type)
# print("opt op input: ", key)
if key == "Grad":
grad_block = None
for g in self.param_grad_ep_mapping[endpoint]["grads"]:
......@@ -422,6 +426,7 @@ class DistributeTranspiler:
new_inputs[key] = tmpvar
print("mark2")
for key in opt_op.input_names:
if key in ["Param", "Grad"]:
continue
......@@ -453,6 +458,7 @@ class DistributeTranspiler:
inputs=new_inputs,
outputs=outputs,
attrs=opt_op.attrs)
print("mark3")
def _append_pserver_non_opt_ops(self, program, pserver_program, opt_op):
# Append the ops for parameters that do not need to be optimized/updated
......@@ -523,6 +529,11 @@ class DistributeTranspiler:
optimize_sub_program = Program()
# Iterate through the ops and append ops as needed
for idx, opt_op in enumerate(self.optimize_ops):
print("mark0")
print(opt_op.inputs.keys())
for v in opt_op.inputs.values():
print(v.name)
print(v.shape)
is_op_on_pserver = self._is_op_on_pserver(endpoint,
self.optimize_ops, idx)
if not is_op_on_pserver:
......
......@@ -741,9 +741,75 @@ class Block(object):
"""
if not self.has_var(name):
raise ValueError("var %s is not in current" % name)
v = self.var(name)
stop_gradient = None
trainable = None
optimize_attr = None
regularizer = None
gradient_clip_attr = None
error_clip = None
if type(v) == Parameter:
stop_gradient = v.stop_gradient
trainable = v.trainable
optimize_attr = v.optimize_attr
regularizer = v.regularizer
gradient_clip_attr = v.gradient_clip_attr
error_clip = v.error_clip
elif type(v) == Variable:
error_clip = v.error_clip
stop_gradient = v.stop_gradient
else:
raise ValueError("unsupported var type: %s", type(v))
def _clear_op_io_for_var(name):
for op in self.ops:
for k in op.inputs.keys():
if op.inputs[k].name == name:
op.inputs[k] = None
for k in op.outputs.keys():
if op.outputs[k].name == name:
op.outputs[k] = None
_clear_op_io_for_var(name)
self.desc.rename_var(name, new_name)
d = self.desc.find_var(new_name)
var = None
if type(v) == Parameter:
var = Parameter(
self,
d.shape(),
d.dtype(),
name=new_name,
stop_gradient=stop_gradient,
trainable=trainable,
optimize_attr=optimize_attr,
regularizer=regularizer,
gradient_clip_attr=gradient_clip_attr,
error_clip=error_clip)
elif type(v) == Variable:
var = Variable(
self,
name=new_name,
error_clip=error_clip,
stop_gradient=stop_gradient)
# rename the python side, sync_with_cpp will only add
# new vars/ops to python side.
self.vars[new_name] = var
for op in self.ops:
print("### rename op i/o ", name, op.inputs)
if op.inputs:
for k in op.inputs.keys():
if op.inputs[k] == None:
print("rename input: ", name, var)
op.inputs[k] = var
if op.outputs:
for k in op.outputs.keys():
if op.outputs[k] == None:
op.outputs[k] = var
del self.vars[name]
self.sync_with_cpp()
print("renamed var: ", self.var(new_name))
def create_parameter(self, *args, **kwargs):
global_block = self.program.global_block()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册