提交 051eaa5f 编写于 作者: T tangwei12

add ditriubted attrs

上级 057a6450
...@@ -733,6 +733,9 @@ def load_inference_model(dirname, ...@@ -733,6 +733,9 @@ def load_inference_model(dirname,
if not os.path.isdir(dirname): if not os.path.isdir(dirname):
raise ValueError("There is no directory named '%s'", dirname) raise ValueError("There is no directory named '%s'", dirname)
if training_role == "PSERVER":
_load_lookup_table_vars(executor, dirname, program, role_id)
if model_filename is not None: if model_filename is not None:
model_filename = os.path.basename(model_filename) model_filename = os.path.basename(model_filename)
else: else:
...@@ -749,10 +752,7 @@ def load_inference_model(dirname, ...@@ -749,10 +752,7 @@ def load_inference_model(dirname,
load_persistables(executor, dirname, program, params_filename) load_persistables(executor, dirname, program, params_filename)
if pserver_endpoints: if pserver_endpoints:
_endpoints_replacement(program, pserver_endpoints) program = _endpoints_replacement(program, pserver_endpoints)
if training_role == "PSERVER":
_load_lookup_table_vars(executor, dirname, program, role_id)
feed_target_names = program.desc.get_feed_target_names() feed_target_names = program.desc.get_feed_target_names()
fetch_target_names = program.desc.get_fetch_target_names() fetch_target_names = program.desc.get_fetch_target_names()
...@@ -871,8 +871,10 @@ def _load_lookup_table_vars(executor, dirname, program, pserver_id): ...@@ -871,8 +871,10 @@ def _load_lookup_table_vars(executor, dirname, program, pserver_id):
def _endpoints_replacement(program, endpoints): def _endpoints_replacement(program, endpoints):
ENDPOINT_MAP = "epmap" ENDPOINT_MAP = "epmap"
for op in program.global_block().ops: for op in program.global_block().ops:
if op.attrs.has_key(ENDPOINT_MAP): if op.has_attr(ENDPOINT_MAP):
op.attrs[ENDPOINT_MAP] = endpoints op.set_attr(ENDPOINT_MAP, endpoints)
program = program.clone()
return program
def get_parameter_value(para, executor): def get_parameter_value(para, executor):
......
...@@ -501,6 +501,8 @@ class DistributeTranspiler(object): ...@@ -501,6 +501,8 @@ class DistributeTranspiler(object):
checkpoint_block_id = self._create_checkpoint_save_block( checkpoint_block_id = self._create_checkpoint_save_block(
pserver_program, table_opt_block.idx) pserver_program, table_opt_block.idx)
pserver_program._distributed_lookup_table = self.table_name
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will # NOTE: if has_distributed_lookup_table is False, then prefetch_block will
# not be executed, so it's safe to use optimize_block to hold the place # not be executed, so it's safe to use optimize_block to hold the place
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
...@@ -527,9 +529,13 @@ class DistributeTranspiler(object): ...@@ -527,9 +529,13 @@ class DistributeTranspiler(object):
outputs={}, outputs={},
attrs=attrs) attrs=attrs)
# add slice vars # add distributed attrs
slice_vars_and_atts = self._get_slice_vars_and_atts(endpoint) pserver_program._slice_vars_and_atts = self._get_slice_vars_and_atts(
pserver_program._slice_vars_and_atts = slice_vars_and_atts endpoint)
pserver_program._is_distributed = True
pserver_program._endpoints = self.pserver_endpoints
pserver_program._is_chief = self.trainer_id == 0
pserver_program._distributed_lookup_table = self.table_name if self.table_name else None
pserver_program._sync_with_cpp() pserver_program._sync_with_cpp()
return pserver_program return pserver_program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册