Commit 8acad27e authored by typhoonzero

refine code

Parent 4b91cb52
@@ -75,8 +75,8 @@ class ListenAndServOp : public framework::OperatorBase {
     server_thread_->join();
   }
 
-  void Run(const framework::Scope &scope,
-           const platform::Place &dev_place) const override {
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(dev_place);
     framework::Scope &recv_scope = scope.NewScope();
@@ -101,7 +101,6 @@ class ListenAndServOp : public framework::OperatorBase {
       // the gradients arrives, just add suffix 0~n and merge the gradient.
       rpc_service_->SetCond(0);
       size_t recv_var_cnt = 0;
-      size_t update_param_cnt = 0;
       int batch_barrier = 0;
       while (batch_barrier != fan_in) {
         const detail::MessageWithName &v = rpc_service_->Get();
@@ -128,29 +127,26 @@ class ListenAndServOp : public framework::OperatorBase {
           }
         }
       }
-      VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier.";
       if (exit_flag) {
         rpc_service_->ShutDown();
       }
-      VLOG(3) << "run optimize graph...";
       try {
         executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
                      false /*create_local_scope*/, false /*create_vars*/);
       } catch (std::exception &e) {
         LOG(ERROR) << "run sub program error " << e.what();
       }
       // Reset the received sparse variables, the sum operator would not
       // sum the input sparse variables which rows is empty at the next
       // mini-batch.
-      // TOOD(Yancey1989): move the reset action into an operator, we couldn't
+      // TODO(Yancey1989): move the reset action into an operator, we couldn't
       // have any hide logic in the operator.
       for (auto &var : sparse_vars) {
         var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
       }
       rpc_service_->SetCond(1);
-      rpc_service_->WaitClientGet(update_param_cnt);
-      grads_counter_.clear();
+      // FIXME(typhoonzero): use another condition to sync wait clients get.
+      rpc_service_->WaitClientGet(ins.size());
       sparse_vars.clear();
     }  // while(true)
   }
@@ -158,7 +154,6 @@ class ListenAndServOp : public framework::OperatorBase {
  protected:
   std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
   std::shared_ptr<std::thread> server_thread_;
-  mutable std::unordered_map<std::string, int> grads_counter_;
 };
 
 class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -32,8 +32,8 @@ class RecvOp : public framework::OperatorBase {
              const framework::AttributeMap& attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope& scope,
-           const platform::Place& place) const override {
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& place) const override {
     auto outs = Outputs("Out");
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
...
@@ -48,8 +48,8 @@ class SendOp : public framework::OperatorBase {
              const framework::AttributeMap& attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope& scope,
-           const platform::Place& place) const override {
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& place) const override {
     auto ins = Inputs("X");
     auto outs = Outputs("Out");
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
...
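Both RecvOp and SendOp above read an `epmap` attribute. The transpiler builds this list in parallel with the send/recv variable lists, so that the i-th variable is pushed to or fetched from the i-th endpoint. The sketch below only illustrates that pairing; the endpoint strings, variable names, and the round-robin assignment are illustrative assumptions (the real assignment comes from the transpiler's split_method, which is not shown in this diff).

```python
# Hypothetical illustration of the epmap pairing; not the transpiler's real split_method.
pserver_endpoints = ["192.168.0.1:6174", "192.168.0.2:6174"]  # assumed endpoints

send_inputs = ["fc_0.w_0@GRAD.block0.trainer_0",   # names follow the suffix scheme
               "fc_0.w_0@GRAD.block1.trainer_0",   # introduced in this commit
               "fc_0.b_0@GRAD.trainer_0"]

# epmap[i] is the endpoint that receives send_inputs[i].
epmap = [pserver_endpoints[i % len(pserver_endpoints)]
         for i in range(len(send_inputs))]
print(list(zip(send_inputs, epmap)))
```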
@@ -147,6 +147,21 @@ class DistributeTranspiler:
         Use different methods to split trainable variables to different
         parameter servers.
 
+        Steps to transpile trainer:
+        1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
+        2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
+        3. modify trainer program add split_op to each grad variable.
+        4. append send_op to send splited variables to server and fetch
+           params(splited blocks or origin param) from server.
+        5. append concat_op to merge splited blocks to update local weights.
+
+        Steps to transpile pserver:
+        1. create new program for parameter server.
+        2. create params and grad variables that assigned to current server instance.
+        3. create a sub-block in the server side program
+        4. append ops that should run on current server instance.
+        5. add listen_and_serv op
+
         :param optimize_ops: op list of optimization, should be the
                              return value of Optimizer.minimize
         :type optimize_ops: list
@@ -154,7 +169,7 @@ class DistributeTranspiler:
         :type params_grads: list
         :param trainer_id: one unique id for each trainer in a job.
         :type trainer_id: int
-        :param program: program to optimize, default is default_main_program
+        :param program: program to transpile, default is default_main_program
         :type program: Program
         :param pservers: parameter server endpoints like "m1:6174,m2:6174"
         :type pservers: string
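The trainer/pserver split this commit reshapes is easiest to see end to end. The sketch below is a minimal usage outline, not code from this commit: the import path, the tiny network, the endpoint strings, and how a process learns its role are illustrative assumptions; only the `transpile`, `get_trainer_program`, `get_pserver_program`, and `get_startup_program` calls come from the code in this diff.

```python
import paddle.v2.fluid as fluid  # import path assumed for the Fluid version of this diff

# A tiny network just so minimize() has something to work on (illustrative).
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_pred = fluid.layers.fc(input=x, size=1)
avg_cost = fluid.layers.mean(x=fluid.layers.square_error_cost(input=y_pred, label=y))
optimize_ops, params_grads = fluid.optimizer.SGD(learning_rate=0.01).minimize(avg_cost)

t = fluid.DistributeTranspiler()
t.transpile(
    optimize_ops,
    params_grads,
    trainer_id=0,                                   # unique id of this trainer
    pservers="192.168.0.1:6174,192.168.0.2:6174",   # assumed endpoints
    trainers=2)

role = "TRAINER"  # or "PSERVER"; how the role is chosen is deployment-specific
exe = fluid.Executor(fluid.CPUPlace())
if role == "PSERVER":
    current_endpoint = "192.168.0.1:6174"
    pserver_prog = t.get_pserver_program(current_endpoint)
    pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
    exe.run(pserver_startup)
    exe.run(pserver_prog)   # blocks in listen_and_serv, serving the trainers
else:
    trainer_prog = t.get_trainer_program()
    # feed data and run trainer_prog as usual; its send/recv ops talk to the pservers
```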
@@ -174,27 +189,15 @@ class DistributeTranspiler:
         # like Kubernetes, we should port this to use etcd later when developing
         # fluid distributed training with fault-tolerance.
         self.trainer_id = trainer_id
-        # steps to transpile:
-        # 1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
-        # 2. modify trainer program add split_op to each Grad.
-        # 3. append send_op to trainer.
-        # 4. append concat_op to trainer to update local weights.
-        # 5. create new program for parameter server.
-        # 6. create parameter server program by split_method generated endpoint->VarBlock
-        # 7. update startup_program, rename variables to variables with trainer_id
         pserver_endpoints = pservers.split(",")
         # step1
         param_list = [pg[0] for pg in params_grads]
         grad_list = [pg[1] for pg in params_grads]
-        # TODO: add split selected rows support
         grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
         param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
         # step2
         grad_var_mapping = self._append_split_op(program, grad_blocks)
         # step3
         send_inputs = []
         send_outputs = []
@@ -222,12 +225,12 @@ class DistributeTranspiler:
         rpc_client_var = program.global_block().create_var(
             name="RPC_CLIENT_VAR",
-            psersistable=True,
+            persistable=True,
             dtype='float32',  # dtype and shape is not used in fact
             shape=[0])
 
         # create send_op
-        send_op = program.global_block().append_op(
+        program.global_block().append_op(
             type="send",
             inputs={"X": send_inputs},
             outputs={"Out": send_outputs,
@@ -239,23 +242,158 @@ class DistributeTranspiler:
             if len(splited_var) <= 1:
                 continue
             orig_param = program.global_block().vars[varname]
-            concat = program.global_block().append_op(
+            program.global_block().append_op(
                 type="concat",
                 inputs={"X": splited_var},
                 outputs={"Out": [orig_param]},
                 attrs={"axis": 0})
-        # step 7
-        startup_prog = default_startup_program()
-        for varname in startup_prog.global_block().vars.keys():
-            if varname in param_var_mapping and \
-                len(param_var_mapping[varname]) == 1:
-                new_var_name = "%s.trainer_%d" % \
-                    (varname, self.trainer_id)
-                startup_prog.global_block().rename_var(varname, new_var_name)
-
-    def _create_vars_from_blocklist(self, program, block_list):
-        # Create respective variables using the block_list
+
+    def get_trainer_program(self):
+        # remove optimize ops and add a send op to main_program
+        self.program.global_block().delete_ops(self.optimize_ops)
+        return self.program
+
+    def get_pserver_program(self, endpoint):
+        """
+        Get pserver side program using the endpoint.
+        NOTE: assume blocks of the same variable is not distributed
+        on the same pserver, only change param/grad varnames for
+        trainers to fetch.
+        """
+        # step1
+        pserver_program = Program()
+        # step2
+        recv_inputs = []
+        for v in self.param_grad_ep_mapping[endpoint]["params"]:
+            self._clone_var(pserver_program.global_block(), v)
+        for v in self.param_grad_ep_mapping[endpoint]["grads"]:
+            # create vars for each trainer in global scope, so
+            # we don't need to create them when grad arrives.
+            # change client side var name to origin name by
+            # removing ".trainer_%d" suffix
+            suff_idx = v.name.find(".trainer_")
+            if suff_idx >= 0:
+                orig_var_name = v.name[:suff_idx]
+            pserver_program.global_block().create_var(
+                name=orig_var_name,
+                persistable=True,
+                dtype=v.dtype,
+                shape=v.shape)
+            print("create origin var: ", orig_var_name)
+            for trainer_id in xrange(self.trainers):
+                var = pserver_program.global_block().create_var(
+                    name="%s.trainer_%d" % (orig_var_name, trainer_id),
+                    persistable=False,
+                    dtype=v.dtype,
+                    shape=v.shape)
+                recv_inputs.append(var)
+                print("create per trainer var: ", var.name)
+        # step3
+        optimize_block = pserver_program.create_block(0)
+        # step 4
+        # Create a union-find data struct from optimize ops,
+        # If two ops are connected, we could add these two ops
+        # into one set.
+        ufind = self._create_ufind(self.optimize_ops)
+        # step 4.2
+        # Iterate through the ops and append optimize op which
+        # located on current pserver
+        opt_op_on_pserver = []
+        for _, op in enumerate(self.optimize_ops):
+            if self._is_opt_op(op) and self._is_opt_op_on_pserver(endpoint, op):
+                opt_op_on_pserver.append(op)
+        # step 4.3
+        # Iterate through the ops, and if an op and the optimize ops
+        # which located on current pserver are in one set, then
+        # append it into the sub program.
+        for _, op in enumerate(self.optimize_ops):
+            for _, opt_op in enumerate(opt_op_on_pserver):
+                if ufind.is_connected(op, opt_op):
+                    if self._is_opt_op(op):
+                        self._append_pserver_ops(optimize_block, op, endpoint)
+                    else:
+                        self._append_pserver_non_opt_ops(optimize_block, op)
+                    break
+        # step5 append the listen_and_serv op
+        pserver_program.global_block().append_op(
+            type="listen_and_serv",
+            inputs={'X': recv_inputs},
+            outputs={},
+            attrs={
+                "OptimizeBlock": optimize_block,
+                "endpoint": endpoint,
+                "Fanin": self.trainers
+            })
+        pserver_program.sync_with_cpp()
+        return pserver_program
+
+    def get_startup_program(self, endpoint, pserver_program):
+        """
+        Get startup program for current parameter server.
+        Modify operator input variables if there are variables that
+        were split to several blocks.
+        """
+        s_prog = Program()
+        orig_s_prog = framework.default_startup_program()
+        params = self.param_grad_ep_mapping[endpoint]["params"]
+
+        def _get_splited_name_and_shape(varname):
+            for idx, splited_param in enumerate(params):
+                pname = splited_param.name
+                if same_or_split_var(pname, varname) and varname != pname:
+                    return pname, splited_param.shape
+            return "", []
+
+        # 1. create vars in pserver program to startup program
+        pserver_vars = pserver_program.global_block().vars
+        created_var_map = dict()
+        for _, var in pserver_vars.iteritems():
+            tmpvar = s_prog.global_block().create_var(
+                name=var.name,
+                persistable=var.persistable,
+                dtype=var.dtype,
+                shape=var.shape)
+            created_var_map[var.name] = tmpvar
+
+        # 2. rename op outputs
+        for op in orig_s_prog.global_block().ops:
+            new_inputs = dict()
+            new_outputs = dict()
+            # do not append startup op if var is not on this pserver
+            op_on_pserver = False
+            for key in op.output_names:
+                newname, _ = _get_splited_name_and_shape(op.output(key)[0])
+                if newname:
+                    op_on_pserver = True
+                    new_outputs[key] = created_var_map[newname]
+                elif op.output(key)[0] in pserver_vars:
+                    op_on_pserver = True
+                    new_outputs[key] = pserver_vars[op.output(key)[0]]
+            # most startup program ops have no inputs
+            new_inputs = self._get_input_map_from_op(pserver_vars, op)
+
+            if op_on_pserver:
+                if op.type in [
+                        "gaussian_random", "fill_constant", "uniform_random"
+                ]:
+                    op.attrs["shape"] = new_outputs["Out"].shape
+                s_prog.global_block().append_op(
+                    type=op.type,
+                    inputs=new_inputs,
+                    outputs=new_outputs,
+                    attrs=op.attrs)
+        return s_prog
+
+    # ====================== private transpiler functions =====================
+
+    def _create_vars_from_blocklist(self,
+                                    program,
+                                    block_list,
+                                    add_trainer_suffix=False):
+        """
+        NOTE: only grads need to be named for different trainers, use
+        add_trainer_suffix to rename the grad vars.
+        """
         block_map = dict()
         var_mapping = dict()
         for block_str in block_list:
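Step 4 of `get_pserver_program` above groups ops with a union-find structure so that every non-optimize op that is transitively connected to an optimize op placed on this pserver is pulled into the same optimize sub-block. `_create_ufind` itself is not part of this hunk; the sketch below only illustrates the grouping idea with a hypothetical minimal union-find, using made-up op names.

```python
class UnionFind(object):
    """Minimal union-find over hashable items (illustrative, not the real _create_ufind)."""

    def __init__(self, items):
        self._parent = {i: i for i in items}

    def _find(self, x):
        while self._parent[x] != x:
            self._parent[x] = self._parent[self._parent[x]]  # path halving
            x = self._parent[x]
        return x

    def union(self, a, b):
        self._parent[self._find(a)] = self._find(b)

    def is_connected(self, a, b):
        return self._find(a) == self._find(b)


# Hypothetical ops: a scale op feeds an SGD update, so they fall into one set;
# an unrelated op stays in its own set and is not appended to this pserver.
ops = ["scale_grad", "sgd_update", "unrelated_op"]
ufind = UnionFind(ops)
ufind.union("scale_grad", "sgd_update")   # connected via shared input/output vars
print(ufind.is_connected("scale_grad", "sgd_update"))    # True
print(ufind.is_connected("unrelated_op", "sgd_update"))  # False
```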
@@ -266,12 +404,15 @@ class DistributeTranspiler:
         for varname, splited in block_map.iteritems():
             orig_var = program.global_block().var(varname)
             if len(splited) == 1:
-                # rename var to the trainer_id var
-                new_var_name = "%s.trainer_%d" % \
-                    (orig_var.name, self.trainer_id)
-                program.global_block().rename_var(varname, new_var_name)
-                var_mapping[varname] = \
-                    [program.global_block().var(new_var_name)]
+                if add_trainer_suffix:
+                    new_var_name = "%s.trainer_%d" % \
+                        (orig_var.name, self.trainer_id)
+                    program.global_block().rename_var(varname, new_var_name)
+                    var_mapping[varname] = \
+                        [program.global_block().var(new_var_name)]
+                else:
+                    var_mapping[varname] = \
+                        [program.global_block().var(orig_var.name)]
                 continue
             var_mapping[varname] = []
@@ -286,10 +427,16 @@ class DistributeTranspiler:
                 splited_shape = [rows]
                 if len(orig_shape) >= 2:
                     splited_shape.extend(orig_shape[1:])
+                new_var_name = ""
+                if add_trainer_suffix:
+                    new_var_name = "%s.block%d.trainer_%d" % \
+                        (varname, i, self.trainer_id)
+                else:
+                    new_var_name = "%s.block%d" % \
+                        (varname, i)
                 var = program.global_block().create_var(
-                    name="%s.block%d.trainer_%d" %
-                    (varname, i, self.trainer_id),
-                    psersistable=False,
+                    name=new_var_name,
+                    persistable=False,
                     dtype=orig_var.dtype,
                     type=orig_var.type,
                     shape=splited_shape)  # flattend splited var
@@ -305,13 +452,12 @@ class DistributeTranspiler:
             dtype=var.dtype,
             type=var.type,
             lod_level=var.lod_level,
-            # HACK: let all param in pserver be persistable so the child
-            # program in recv can get them
             persistable=True)
 
     def _append_split_op(self, program, gradblocks):
         # Split variables that need to be split and append respective ops
-        var_mapping = self._create_vars_from_blocklist(program, gradblocks)
+        var_mapping = self._create_vars_from_blocklist(
+            program, gradblocks, add_trainer_suffix=True)
         for varname, splited_vars in var_mapping.iteritems():
             # variable that don't need to split have empty splited_vars
             if len(splited_vars) <= 1:
@@ -341,24 +487,6 @@ class DistributeTranspiler:
                                 "[LOD_TENSOR, SELECTED_ROWS]")
         return var_mapping
 
-    def get_trainer_program(self):
-        # remove optimize ops and add a send op to main_program
-        self.program.global_block().delete_ops(self.optimize_ops)
-        return self.program
-
-    def _create_var_for_trainers(self, block, var, trainers):
-        # For each trainer, create the necessary variables
-        var_list = []
-        for i in xrange(trainers):
-            var_each = block.create_var(
-                name="%s.trainer_%d" % (var.name, i),
-                psersistable=var.persistable,
-                dtype=var.dtype,
-                type=var.type,
-                shape=var.shape)
-            var_list.append(var_each)
-        return var_list
-
     def _get_optimizer_input_shape(self, op_type, varkey, orig_shape,
                                    param_shape):
         """
@@ -386,6 +514,13 @@ class DistributeTranspiler:
             pass
         return orig_shape
 
+    def _orig_varname(self, varname):
+        suff_idx = varname.find(".trainer_")
+        orig_var_name = ""
+        if suff_idx >= 0:
+            orig_var_name = varname[:suff_idx]
+        return orig_var_name
+
     def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
         program = optimize_block.program
         pserver_block = program.global_block()
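The `.block<i>` / `.trainer_<id>` suffix scheme used throughout this patch can be exercised standalone. A small sketch follows; the gradient names are made up for illustration, and the helper mirrors the behavior of `_orig_varname` added above (including returning an empty string when there is no trainer suffix).

```python
def orig_varname(varname):
    # Same behavior as DistributeTranspiler._orig_varname in the hunk above:
    # strip a ".trainer_<id>" suffix if present, else return "".
    suff_idx = varname.find(".trainer_")
    return varname[:suff_idx] if suff_idx >= 0 else ""

# Trainer side with add_trainer_suffix=True: split blocks get ".block<i>.trainer_<id>",
# grads that are not split get just ".trainer_<id>".
print(orig_varname("fc_0.w_0@GRAD.block1.trainer_3"))  # fc_0.w_0@GRAD.block1
print(orig_varname("fc_0.b_0@GRAD.trainer_3"))         # fc_0.b_0@GRAD
print(orig_varname("fc_0.w_0"))                        # '' (no trainer suffix)
```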
@@ -396,18 +531,23 @@ class DistributeTranspiler:
             if key == "Grad":
                 grad_block = None
                 for g in self.param_grad_ep_mapping[endpoint]["grads"]:
-                    if same_or_split_var(g.name, opt_op.input(key)[0]):
+                    if same_or_split_var(
+                            self._orig_varname(g.name), opt_op.input(key)[0]):
                         grad_block = g
                         break
                 if not grad_block:
                     # do not append this op if current endpoint
                     # is not dealing with this grad block
                     return
-                merged_var = pserver_block.vars[grad_block.name]
-                # append merging ops if trainers > 1
+                merged_var = \
+                    pserver_block.vars[self._orig_varname(grad_block.name)]
                 if self.trainers > 1:
-                    vars2merge = self._create_var_for_trainers(
-                        pserver_block, grad_block, self.trainers)
+                    vars2merge = []
+                    for i in xrange(self.trainers):
+                        per_trainer_name = "%s.trainer_%d" % \
+                            (self._orig_varname(grad_block.name), i)
+                        vars2merge.append(pserver_block.vars[per_trainer_name])
                     optimize_block.append_op(
                         type="sum",
                         inputs={"X": vars2merge},
@@ -550,76 +690,6 @@ class DistributeTranspiler:
                 return False
         return False
 
-    def get_pserver_program(self, endpoint):
-        """
-        Get pserver side program using the endpoint
-        NOTE: assume blocks of the same variable is not distributed
-        on the same pserver, only change param/grad varnames for
-        trainers to fetch. For each pserver endpoint, server side
-        program must be a sub-set of the original optimization program.
-        """
-        # step5
-        pserver_program = Program()
-        recv_inputs = []
-        for v in self.param_grad_ep_mapping[endpoint]["params"]:
-            self._clone_var(pserver_program.global_block(), v)
-        for v in self.param_grad_ep_mapping[endpoint]["grads"]:
-            # create vars for each trainer in global scope, so
-            # we don't need to create them when grad arrives.
-            pserver_program.global_block().create_var(
-                name=v.name, persistable=True, dtype=v.dtype, shape=v.shape)
-            for trainer_id in xrange(self.trainers):
-                # change client side var name to origin name by
-                # removing ".trainer_%d" suffix
-                suff_idx = v.name.find(".trainer_")
-                if suff_idx >= 0:
-                    orig_var_name = v.name[:suff_idx]
-                var = pserver_program.global_block().create_var(
-                    name="%s.trainer_%d" % (orig_var_name, trainer_id),
-                    persistable=True,
-                    dtype=v.dtype,
-                    shape=v.shape)
-                recv_inputs.append(var)
-        # step6
-        optimize_block = pserver_program.create_block(0)
-        # step 6.1
-        # Create a union-find data struct by optimize ops,
-        # If two ops are connected, we could add these two ops
-        # into one set.
-        ufind = self._create_ufind(self.optimize_ops)
-        # step 6.2
-        # Iterate through the ops and append optimize op which
-        # located on current pserver
-        opt_op_on_pserver = []
-        for _, op in enumerate(self.optimize_ops):
-            if self._is_opt_op(op) and self._is_opt_op_on_pserver(endpoint, op):
-                opt_op_on_pserver.append(op)
-        # step 6.3
-        # Iterate through the ops, and if an op and the optimize ops
-        # which located on current pserver are in one set, then
-        # append it into the sub program.
-        for _, op in enumerate(self.optimize_ops):
-            for _, opt_op in enumerate(opt_op_on_pserver):
-                if ufind.is_connected(op, opt_op):
-                    if self._is_opt_op(op):
-                        self._append_pserver_ops(optimize_block, op, endpoint)
-                    else:
-                        self._append_pserver_non_opt_ops(optimize_block, op)
-                    break
-        # Append the listen_and_serv op
-        pserver_program.global_block().append_op(
-            type="listen_and_serv",
-            inputs={'X': recv_inputs},
-            outputs={},
-            attrs={
-                "OptimizeBlock": optimize_block,
-                "endpoint": endpoint,
-                "Fanin": self.trainers
-            })
-        pserver_program.sync_with_cpp()
-        return pserver_program
-
     def _get_input_map_from_op(self, varmap, op):
         iomap = dict()
         for key in op.input_names:
@@ -643,61 +713,3 @@ class DistributeTranspiler:
             else:
                 iomap[key] = vars
         return iomap
-
-    def get_startup_program(self, endpoint, pserver_program):
-        """
-        Get startup program for current parameter server.
-        Modify operator input variables if there are variables that
-        were split to several blocks.
-        """
-        s_prog = Program()
-        orig_s_prog = framework.default_startup_program()
-        params = self.param_grad_ep_mapping[endpoint]["params"]
-
-        def _get_splited_name_and_shape(varname):
-            for idx, splited_param in enumerate(params):
-                pname = splited_param.name
-                if same_or_split_var(pname, varname) and varname != pname:
-                    return pname, splited_param.shape
-            return "", []
-
-        # 1. create vars in pserver program to startup program
-        pserver_vars = pserver_program.global_block().vars
-        created_var_map = dict()
-        for _, var in pserver_vars.iteritems():
-            tmpvar = s_prog.global_block().create_var(
-                name=var.name,
-                persistable=var.persistable,
-                dtype=var.dtype,
-                shape=var.shape)
-            created_var_map[var.name] = tmpvar
-
-        # 2. rename op outputs
-        for op in orig_s_prog.global_block().ops:
-            new_inputs = dict()
-            new_outputs = dict()
-            # do not append startup op if var is not on this pserver
-            op_on_pserver = False
-            for key in op.output_names:
-                newname, _ = _get_splited_name_and_shape(op.output(key)[0])
-                if newname:
-                    op_on_pserver = True
-                    new_outputs[key] = created_var_map[newname]
-                elif op.output(key)[0] in pserver_vars:
-                    op_on_pserver = True
-                    new_outputs[key] = pserver_vars[op.output(key)[0]]
-            # most startup program ops have no inputs
-            new_inputs = self._get_input_map_from_op(pserver_vars, op)
-
-            if op_on_pserver:
-                if op.type in [
-                        "gaussian_random", "fill_constant", "uniform_random"
-                ]:
-                    op.attrs["shape"] = new_outputs["Out"].shape
-                s_prog.global_block().append_op(
-                    type=op.type,
-                    inputs=new_inputs,
-                    outputs=new_outputs,
-                    attrs=op.attrs)
-        return s_prog