提交 21071f71 编写于 作者: T typhoonzero

Do not create trainer vars in listen_and_serv

上级 b0096361
...@@ -85,7 +85,7 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -85,7 +85,7 @@ class ListenAndServOp : public framework::OperatorBase {
rpc_service_->SetScope(&recv_scope); rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx); rpc_service_->SetDevCtx(&dev_ctx);
auto ins = Inputs("X"); auto ins = Inputs("X");
auto fan_in = ins.size(); auto fan_in = Attr<int>("Fanin");
auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock); auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *program = block->Program(); auto *program = block->Program();
...@@ -163,6 +163,8 @@ from send_op and send back variables to recv_op. ...@@ -163,6 +163,8 @@ from send_op and send back variables to recv_op.
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<framework::BlockDesc *>(kOptimizeBlock, AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side."); "BlockID to run on server side.");
AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1);
} }
}; };
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from __future__ import print_function from __future__ import print_function
import framework import framework
from framework import Program, default_main_program, Parameter, Variable from framework import Program, default_main_program, default_startup_program, Parameter, Variable
import optimizer import optimizer
from layer_helper import LayerHelper from layer_helper import LayerHelper
from distributed_spliter import * from distributed_spliter import *
...@@ -131,6 +131,7 @@ class DistributeTranspiler: ...@@ -131,6 +131,7 @@ class DistributeTranspiler:
# 4. append concat_op to trainer to update local weights. # 4. append concat_op to trainer to update local weights.
# 5. create new program for parameter server. # 5. create new program for parameter server.
# 6. create parameter server program by split_method generated endpoint->VarBlock # 6. create parameter server program by split_method generated endpoint->VarBlock
# 7. update startup_program: rename variables so that their names carry the trainer_id suffix
pserver_endpoints = pservers.split(",") pserver_endpoints = pservers.split(",")
...@@ -175,7 +176,6 @@ class DistributeTranspiler: ...@@ -175,7 +176,6 @@ class DistributeTranspiler:
shape=[0]) shape=[0])
# create send_op # create send_op
print("send inputs: ", send_inputs)
send_op = program.global_block().append_op( send_op = program.global_block().append_op(
type="send", type="send",
inputs={"X": send_inputs}, inputs={"X": send_inputs},
...@@ -194,6 +194,15 @@ class DistributeTranspiler: ...@@ -194,6 +194,15 @@ class DistributeTranspiler:
outputs={"Out": [orig_param]}, outputs={"Out": [orig_param]},
attrs={"axis": 0}) attrs={"axis": 0})
# step 7
startup_prog = default_startup_program()
for varname in startup_prog.global_block().vars.keys():
if varname in param_var_mapping and \
len(param_var_mapping[varname]) == 1:
new_var_name = "%s.trainer_%d" % \
(varname, self.trainer_id)
startup_prog.global_block().rename_var(varname, new_var_name)
def _create_vars_from_blocklist(self, program, block_list): def _create_vars_from_blocklist(self, program, block_list):
# Create respective variables using the block_list # Create respective variables using the block_list
block_map = dict() block_map = dict()
...@@ -210,7 +219,6 @@ class DistributeTranspiler: ...@@ -210,7 +219,6 @@ class DistributeTranspiler:
new_var_name = "%s.trainer_%d" % \ new_var_name = "%s.trainer_%d" % \
(orig_var.name, self.trainer_id) (orig_var.name, self.trainer_id)
program.global_block().rename_var(varname, new_var_name) program.global_block().rename_var(varname, new_var_name)
print("renaming OK...", varname, new_var_name)
var_mapping[varname] = \ var_mapping[varname] = \
[program.global_block().var(new_var_name)] [program.global_block().var(new_var_name)]
continue continue
...@@ -377,10 +385,7 @@ class DistributeTranspiler: ...@@ -377,10 +385,7 @@ class DistributeTranspiler:
new_inputs = dict() new_inputs = dict()
# update param/grad shape first, then other inputs like # update param/grad shape first, then other inputs like
# moment can use the updated shape # moment can use the updated shape
print("mark1")
for key in opt_op.input_names: for key in opt_op.input_names:
# print("opt type: ", opt_op.type)
# print("opt op input: ", key)
if key == "Grad": if key == "Grad":
grad_block = None grad_block = None
for g in self.param_grad_ep_mapping[endpoint]["grads"]: for g in self.param_grad_ep_mapping[endpoint]["grads"]:
...@@ -427,7 +432,6 @@ class DistributeTranspiler: ...@@ -427,7 +432,6 @@ class DistributeTranspiler:
new_inputs[key] = tmpvar new_inputs[key] = tmpvar
print("mark2")
for key in opt_op.input_names: for key in opt_op.input_names:
if key in ["Param", "Grad"]: if key in ["Param", "Grad"]:
continue continue
...@@ -451,7 +455,6 @@ class DistributeTranspiler: ...@@ -451,7 +455,6 @@ class DistributeTranspiler:
inputs=new_inputs, inputs=new_inputs,
outputs=outputs, outputs=outputs,
attrs=opt_op.attrs) attrs=opt_op.attrs)
print("mark3")
def _append_pserver_non_opt_ops(self, optimize_block, opt_op): def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
program = optimize_block.program program = optimize_block.program
...@@ -505,8 +508,6 @@ class DistributeTranspiler: ...@@ -505,8 +508,6 @@ class DistributeTranspiler:
suff_idx = v.name.find(".trainer_") suff_idx = v.name.find(".trainer_")
if suff_idx >= 0: if suff_idx >= 0:
orig_var_name = v.name[:suff_idx] orig_var_name = v.name[:suff_idx]
print("create variable for program: %s.trainer_%d" %
(orig_var_name, trainer_id))
var = pserver_program.global_block().create_var( var = pserver_program.global_block().create_var(
name="%s.trainer_%d" % (orig_var_name, trainer_id), name="%s.trainer_%d" % (orig_var_name, trainer_id),
persistable=True, persistable=True,
...@@ -517,11 +518,6 @@ class DistributeTranspiler: ...@@ -517,11 +518,6 @@ class DistributeTranspiler:
optimize_block = pserver_program.create_block(0) optimize_block = pserver_program.create_block(0)
# Iterate through the ops and append ops as needed # Iterate through the ops and append ops as needed
for idx, opt_op in enumerate(self.optimize_ops): for idx, opt_op in enumerate(self.optimize_ops):
print("mark0")
print(opt_op.inputs.keys())
for v in opt_op.inputs.values():
print(v.name)
print(v.shape)
is_op_on_pserver = self._is_op_on_pserver(endpoint, is_op_on_pserver = self._is_op_on_pserver(endpoint,
self.optimize_ops, idx) self.optimize_ops, idx)
if not is_op_on_pserver: if not is_op_on_pserver:
...@@ -547,7 +543,7 @@ class DistributeTranspiler: ...@@ -547,7 +543,7 @@ class DistributeTranspiler:
# p.name # p.name
# for p in self.param_grad_ep_mapping[endpoint]["grads"] # for p in self.param_grad_ep_mapping[endpoint]["grads"]
# ], # ],
# "Fanin": self.trainers "Fanin": self.trainers
}) })
pserver_program.sync_with_cpp() pserver_program.sync_with_cpp()
return pserver_program return pserver_program
......
...@@ -761,17 +761,6 @@ class Block(object): ...@@ -761,17 +761,6 @@ class Block(object):
else: else:
raise ValueError("unsupported var type: %s", type(v)) raise ValueError("unsupported var type: %s", type(v))
def _clear_op_io_for_var(name):
for op in self.ops:
for k in op.inputs.keys():
if op.inputs[k].name == name:
op.inputs[k] = None
for k in op.outputs.keys():
if op.outputs[k].name == name:
op.outputs[k] = None
_clear_op_io_for_var(name)
self.desc.rename_var(name, new_name) self.desc.rename_var(name, new_name)
d = self.desc.find_var(new_name) d = self.desc.find_var(new_name)
var = None var = None
...@@ -797,17 +786,6 @@ class Block(object): ...@@ -797,17 +786,6 @@ class Block(object):
# rename the python side, sync_with_cpp will only add # rename the python side, sync_with_cpp will only add
# new vars/ops to python side. # new vars/ops to python side.
self.vars[new_name] = var self.vars[new_name] = var
for op in self.ops:
print("### rename op i/o ", name, op.inputs)
if op.inputs:
for k in op.inputs.keys():
if op.inputs[k] == None:
print("rename input: ", name, var)
op.inputs[k] = var
if op.outputs:
for k in op.outputs.keys():
if op.outputs[k] == None:
op.outputs[k] = var
del self.vars[name] del self.vars[name]
self.sync_with_cpp() self.sync_with_cpp()
...@@ -901,7 +879,6 @@ class Block(object): ...@@ -901,7 +879,6 @@ class Block(object):
for p in other.iter_parameters(): for p in other.iter_parameters():
assert isinstance(p, Parameter) assert isinstance(p, Parameter)
v = self.vars.get(p.name, None) v = self.vars.get(p.name, None)
print("var shape to copy", v, p)
if v is None: if v is None:
raise ValueError("copy_param_info_from should be invoked with " raise ValueError("copy_param_info_from should be invoked with "
"same topology") "same topology")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册