Commit 051eaa5f authored by T tangwei12

add distributed attrs

Parent 057a6450
@@ -733,6 +733,9 @@ def load_inference_model(dirname,
     if not os.path.isdir(dirname):
         raise ValueError("There is no directory named '%s'", dirname)

+    if training_role == "PSERVER":
+        _load_lookup_table_vars(executor, dirname, program, role_id)
+
     if model_filename is not None:
         model_filename = os.path.basename(model_filename)
     else:
@@ -749,10 +752,7 @@ def load_inference_model(dirname,
     load_persistables(executor, dirname, program, params_filename)

-    if training_role == "PSERVER":
-        _load_lookup_table_vars(executor, dirname, program, role_id)
-        program = _endpoints_replacement(program, pserver_endpoints)
+    if pserver_endpoints:
+        _endpoints_replacement(program, pserver_endpoints)

     feed_target_names = program.desc.get_feed_target_names()
     fetch_target_names = program.desc.get_fetch_target_names()
@@ -871,8 +871,10 @@ def _load_lookup_table_vars(executor, dirname, program, pserver_id):
 def _endpoints_replacement(program, endpoints):
     ENDPOINT_MAP = "epmap"
     for op in program.global_block().ops:
-        if op.attrs.has_key(ENDPOINT_MAP):
-            op.attrs[ENDPOINT_MAP] = endpoints
+        if op.has_attr(ENDPOINT_MAP):
+            op.set_attr(ENDPOINT_MAP, endpoints)
+    program = program.clone()
     return program


 def get_parameter_value(para, executor):
......
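The io.py change above replaces the Python 2-only dict.has_key check on op.attrs with the operator's has_attr/set_attr methods, so any op that carries an "epmap" attribute has its endpoint list rewritten to the live pserver endpoints when the model is loaded for serving. Below is a minimal, self-contained sketch of that rewrite pattern; FakeOp and FakeProgram are hypothetical stand-ins for illustration, not the real fluid Operator/Program classes.

# Sketch of the "epmap" rewrite that _endpoints_replacement performs.
# FakeOp and FakeProgram are hypothetical stand-ins, not real fluid classes.
ENDPOINT_MAP = "epmap"


class FakeOp(object):
    def __init__(self, op_type, attrs):
        self.type = op_type
        self.attrs = dict(attrs)

    def has_attr(self, name):
        return name in self.attrs

    def set_attr(self, name, value):
        self.attrs[name] = value


class FakeProgram(object):
    def __init__(self, ops):
        self.ops = ops


def endpoints_replacement(program, endpoints):
    # Point every op that holds an endpoint map at the live pserver endpoints.
    for op in program.ops:
        if op.has_attr(ENDPOINT_MAP):
            op.set_attr(ENDPOINT_MAP, endpoints)
    return program


if __name__ == "__main__":
    prog = FakeProgram([
        FakeOp("send", {ENDPOINT_MAP: ["127.0.0.1:6170"]}),
        FakeOp("scale", {}),
    ])
    endpoints_replacement(prog, ["10.0.0.1:6170", "10.0.0.2:6170"])
    print(prog.ops[0].attrs[ENDPOINT_MAP])  # updated endpoint list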
@@ -501,6 +501,8 @@ class DistributeTranspiler(object):
             checkpoint_block_id = self._create_checkpoint_save_block(
                 pserver_program, table_opt_block.idx)

+            pserver_program._distributed_lookup_table = self.table_name
+
         # NOTE: if has_distributed_lookup_table is False, then prefetch_block will
         # not be executed, so it's safe to use optimize_block to hold the place
         if self.has_distributed_lookup_table:
@@ -527,9 +529,13 @@ class DistributeTranspiler(object):
                 outputs={},
                 attrs=attrs)

-        # add slice vars
-        slice_vars_and_atts = self._get_slice_vars_and_atts(endpoint)
-        pserver_program._slice_vars_and_atts = slice_vars_and_atts
+        # add distributed attrs
+        pserver_program._slice_vars_and_atts = self._get_slice_vars_and_atts(
+            endpoint)
+        pserver_program._is_distributed = True
+        pserver_program._endpoints = self.pserver_endpoints
+        pserver_program._is_chief = self.trainer_id == 0
+        pserver_program._distributed_lookup_table = self.table_name if self.table_name else None

         pserver_program._sync_with_cpp()
         return pserver_program
......
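The transpiler change tags the generated pserver program with distributed metadata (_slice_vars_and_atts, _is_distributed, _endpoints, _is_chief, _distributed_lookup_table) instead of only the slice vars. Below is a hedged sketch of how a later save/checkpoint step might read such flags; DistributedMeta and describe are illustrative helpers invented here, not part of fluid.

# Hypothetical consumer of the distributed attrs attached to pserver_program;
# DistributedMeta and describe are illustrative only, not fluid API.
class DistributedMeta(object):
    def __init__(self, program):
        self.is_distributed = getattr(program, "_is_distributed", False)
        self.endpoints = getattr(program, "_endpoints", [])
        self.is_chief = getattr(program, "_is_chief", False)
        self.lookup_table = getattr(program, "_distributed_lookup_table", None)
        self.slice_vars = getattr(program, "_slice_vars_and_atts", [])


def describe(program):
    meta = DistributedMeta(program)
    if not meta.is_distributed:
        return "local program, plain save/load is enough"
    role = "chief pserver" if meta.is_chief else "pserver"
    return "%s, %d endpoints, lookup table: %s" % (
        role, len(meta.endpoints), meta.lookup_table)


if __name__ == "__main__":
    class _FakeProgram(object):
        pass

    prog = _FakeProgram()
    prog._is_distributed = True
    prog._endpoints = ["10.0.0.1:6170", "10.0.0.2:6170"]
    prog._is_chief = True
    prog._distributed_lookup_table = "emb_table"
    print(describe(prog))  # chief pserver, 2 endpoints, lookup table: emb_table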