diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 3111674abbf92e3f608b625023993c69cf326351..18ee6b0c92a81d89ced9e81db440e9da507b7a17 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -1319,7 +1319,13 @@ class Program(object):
         self._seed = 0
         self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
         self._op_role_var = []
+
+        # for distributed training
+        self._is_distributed = False
+        self._is_chief = False
         self._slice_vars_and_atts = []
+        self._endpoints = []
+        self._distributed_lookup_table = None
 
     @property
     def op_role(self):
diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py
index 80b258c9c96d53c4ee7196fcbb59e4a22b04ddef..83e1b6bef8b4880f90eae1f7574e14e908ab1630 100644
--- a/python/paddle/fluid/io.py
+++ b/python/paddle/fluid/io.py
@@ -666,11 +666,22 @@ def save_inference_model(dirname,
 
     save_persistables(executor, dirname, inference_program, params_filename)
 
+    # if a distributed lookup table exists, trainer 0 asks all pservers to save it
+    if main_program._is_distributed and main_program._is_chief:
+        if main_program._distributed_lookup_table:
+            lookup_table_filename = os.path.join(dirname, "__lookup_table__")
+            _save_lookup_tables_by_notify(
+                executor, lookup_table_filename,
+                main_program._distributed_lookup_table, main_program._endpoints)
+
 
 def load_inference_model(dirname,
                          executor,
                          model_filename=None,
-                         params_filename=None):
+                         params_filename=None,
+                         training_role=None,
+                         role_id=None,
+                         pserver_endpoints=None):
     """
     Load inference model from a directory
 
@@ -736,6 +747,12 @@ def load_inference_model(dirname,
     program = Program.parse_from_string(program_desc_str)
     load_persistables(executor, dirname, program, params_filename)
+
+    if pserver_endpoints:
+        _endpoints_replacement(program, pserver_endpoints)
+
+    if training_role == "PSERVER":
+        _load_lookup_table_vars(executor, dirname, program, role_id)
 
     feed_target_names = program.desc.get_feed_target_names()
     fetch_target_names = program.desc.get_fetch_target_names()
     fetch_targets = [
@@ -745,6 +762,118 @@ def load_inference_model(dirname,
     return [program, feed_target_names, fetch_targets]
 
 
+def _save_lookup_tables_by_notify(executor, dirname, lookup_table,
+                                  pserver_endpoints):
+    """
+    Send a checkpoint notify message from trainer 0 to all the pservers.
+    The message contains the lookup table name and the absolute path on
+    each pserver where the lookup table shard should be saved.
+
+    Args:
+        executor(Executor): The executor that runs the checkpoint notify.
+        dirname(str): The directory to save the lookup table in.
+        lookup_table(str): The lookup table name. When the distributed lookup
+            table is used, it is available as DistributeTranspiler.table_name.
+        pserver_endpoints(list|str): The parameter server endpoints, either a
+            list of "ip:port" strings or a comma separated string.
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: python
+
+            exe = fluid.Executor(fluid.CPUPlace())
+            param_path = "./my_paddle_model"
+            table_name = "share_w"
+            ps_endpoints = ["127.0.0.1:6000", "127.0.0.1:6001"]
+
+            _save_lookup_tables_by_notify(executor=exe,
+                dirname=param_path, lookup_table=table_name,
+                pserver_endpoints=ps_endpoints)
+    """
+
+    pserver_notify_program = Program()
+    pserver_notify_block = pserver_notify_program.global_block()
+
+    attrs = {}
+    # the endpoints may be passed as a list or as a comma separated string
+    if not isinstance(pserver_endpoints, list):
+        pserver_endpoints = pserver_endpoints.split(",")
+    attrs['epmap'] = pserver_endpoints
+    attrs['dir'] = dirname
+    attrs['lookup_table'] = lookup_table
+
+    pserver_notify_block.append_op(
+        type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
+    executor.run(pserver_notify_program)
+
+
+def _load_lookup_table_vars(executor, dirname, program, pserver_id):
+    """
+    On a parameter server, load this pserver's shard of the distributed
+    lookup table from a local file into a SelectedRows variable.
+
+    Args:
+        executor(Executor): The executor that runs the load op.
+        dirname(str): The directory path.
+        program(Program): The program that contains the distributed lookup
+            table variable.
+        pserver_id(int): The index of this pserver in the pserver endpoints
+            list.
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: python
+
+            exe = fluid.Executor(fluid.CPUPlace())
+            dirname = "./checkpoints/checkpoint_9/"
+            prog = fluid.default_main_program()
+            pserver_id = 1
+
+            _load_lookup_table_vars(executor=exe, dirname=dirname,
+                program=prog, pserver_id=pserver_id)
+    """
+
+    LOOKUP_TABLE_TYPE = "lookup_table"
+    lookup_table_var_name = None
+
+    for op in program.global_block().ops:
+        if op.type == LOOKUP_TABLE_TYPE:
+            if op.attrs['is_distributed'] is True:
+                if lookup_table_var_name is None:
+                    lookup_table_var_name = op.input("W")[0]
+                if lookup_table_var_name != op.input("W")[0]:
+                    raise RuntimeError("all distributed lookup_table ops"
+                                       " must share the same table")
+
+    # nothing to load if the program has no distributed lookup table
+    if lookup_table_var_name is None:
+        return
+    lookup_table_var = program.global_block().vars[lookup_table_var_name]
+
+    lookup_table_dir = os.path.join(dirname, "__lookup_table__")
+    table_file = "{}.{}".format(lookup_table_var.name, pserver_id)
+
+    load_prog = Program()
+    load_block = load_prog.global_block()
+
+    load_block.append_op(
+        type='load',
+        inputs={},
+        outputs={'Out': [lookup_table_var]},
+        attrs={'file_path': os.path.join(lookup_table_dir, table_file)})
+
+    executor.run(load_prog)
+
+
+def _endpoints_replacement(program, endpoints):
+    # rewrite every "epmap" attribute so the ops talk to the endpoints of
+    # the current cluster instead of the ones recorded at save time
+    ENDPOINT_MAP = "epmap"
+    for op in program.global_block().ops:
+        if ENDPOINT_MAP in op.attrs:
+            op.attrs[ENDPOINT_MAP] = endpoints
+
+
 def get_parameter_value(para, executor):
     """
     Get the LoDTensor value of the given parameter.
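
Below is a minimal usage sketch of the save/load flow this diff introduces. It assumes the DistributeTranspiler changes (not part of this diff) set _is_distributed, _is_chief, _endpoints and _distributed_lookup_table on the transpiled main program; the toy network, the "./dist_model" directory, the endpoint list and role_id are illustrative only, and the snippet shows the intended call sequence rather than a complete distributed job.

import paddle.fluid as fluid

# toy network with a distributed lookup table (illustrative only)
ids = fluid.layers.data(name="ids", shape=[1], dtype="int64")
emb = fluid.layers.embedding(
    input=ids, size=[100000, 8], is_sparse=True, is_distributed=True)
prediction = fluid.layers.fc(input=emb, size=2, act="softmax")

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# On trainer 0 (the chief): saving also emits the checkpoint_notify op added
# above, provided the transpiler has marked the program as distributed/chief
# and filled in _endpoints and _distributed_lookup_table.
fluid.io.save_inference_model(
    dirname="./dist_model",
    feeded_var_names=["ids"],
    target_vars=[prediction],
    executor=exe)

# On a parameter server: the new keyword arguments reload this pserver's
# shard from ./dist_model/__lookup_table__/<table_name>.<role_id> and
# rewrite every "epmap" attribute to the current endpoints.
program, feed_names, fetch_targets = fluid.io.load_inference_model(
    dirname="./dist_model",
    executor=exe,
    training_role="PSERVER",
    role_id=1,
    pserver_endpoints=["127.0.0.1:6000", "127.0.0.1:6001"])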