diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 5e9901bb87c9a454a393a913b6da6e82266ee719..170e0f839719c71d56008abefb79c7814d0f3e76 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -350,6 +350,22 @@ paddle.fluid.contrib.QuantizeTranspiler.__init__ ArgSpec(args=['self', 'weight_b paddle.fluid.contrib.QuantizeTranspiler.convert_to_int8 ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.contrib.QuantizeTranspiler.freeze_program ArgSpec(args=['self', 'program', 'place', 'fuse_bn', 'scope'], varargs=None, keywords=None, defaults=(False, None)) paddle.fluid.contrib.QuantizeTranspiler.training_transpile ArgSpec(args=['self', 'program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None)) +paddle.fluid.contrib.load_persistables_for_increment ArgSpec(args=['dirname', 'executor', 'program', 'lookup_table_var', 'lookup_table_var_path'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.load_persistables_for_inference ArgSpec(args=['dirname', 'executor', 'program', 'lookup_table_var_name'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.convert_dist_to_sparse_program ArgSpec(args=['program'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.HDFSClient.__init__ ArgSpec(args=['self', 'hadoop_home', 'configs'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.HDFSClient.delete ArgSpec(args=['self', 'hdfs_path'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.HDFSClient.download ArgSpec(args=['self', 'hdfs_path', 'local_path', 'overwrite', 'unzip'], varargs=None, keywords=None, defaults=(False, False)) +paddle.fluid.contrib.HDFSClient.is_dir ArgSpec(args=['self', 'hdfs_path'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.contrib.HDFSClient.is_exist ArgSpec(args=['self', 'hdfs_path'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.contrib.HDFSClient.ls ArgSpec(args=['self', 'hdfs_path'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.HDFSClient.lsr ArgSpec(args=['self', 'hdfs_path', 'only_file', 'sort'], varargs=None, keywords=None, defaults=(True, True)) +paddle.fluid.contrib.HDFSClient.make_local_dirs ArgSpec(args=['local_path'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.HDFSClient.makedirs ArgSpec(args=['self', 'hdfs_path'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.HDFSClient.rename ArgSpec(args=['self', 'hdfs_src_path', 'hdfs_dst_path', 'overwrite'], varargs=None, keywords=None, defaults=(False,)) +paddle.fluid.contrib.HDFSClient.upload ArgSpec(args=['self', 'hdfs_path', 'local_path', 'overwrite', 'retry_times'], varargs=None, keywords=None, defaults=(False, 5)) +paddle.fluid.contrib.multi_download ArgSpec(args=['client', 'hdfs_path', 'local_path', 'trainer_id', 'trainers', 'multi_processes'], varargs=None, keywords=None, defaults=(5,)) +paddle.fluid.contrib.multi_upload ArgSpec(args=['client', 'hdfs_path', 'local_path', 'multi_processes', 'overwrite', 'sync'], varargs=None, keywords=None, defaults=(5, False, True)) paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'endpoint'], varargs=None, 
keywords=None, defaults=None) diff --git a/python/paddle/fluid/contrib/__init__.py b/python/paddle/fluid/contrib/__init__.py index 3bf2fe5db0cb2126295ebfda822eeac8762dbdb7..ece97b661fd7d60f8822439a84ee4403b9e3d81c 100644 --- a/python/paddle/fluid/contrib/__init__.py +++ b/python/paddle/fluid/contrib/__init__.py @@ -22,9 +22,12 @@ from . import op_frequence from .op_frequence import * from . import quantize from .quantize import * +from . import utils +from .utils import * __all__ = [] __all__ += decoder.__all__ __all__ += memory_usage_calc.__all__ __all__ += op_frequence.__all__ __all__ += quantize.__all__ +__all__ += utils.__all__ diff --git a/python/paddle/fluid/contrib/utils/__init__.py b/python/paddle/fluid/contrib/utils/__init__.py index 20b2cc381aaa1b837ce106410246bc8cedb2fc88..1c1c2fb22709189ca03dc543ca551257c8031c1a 100644 --- a/python/paddle/fluid/contrib/utils/__init__.py +++ b/python/paddle/fluid/contrib/utils/__init__.py @@ -13,10 +13,11 @@ # limitations under the License. from __future__ import print_function -#from . import lookup_table_utils -#from .lookup_table_utils import * +from . import lookup_table_utils +from .lookup_table_utils import * from . import hdfs_utils from .hdfs_utils import * -#__all__ = lookup_table_utils.__all__ -__all__ = hdfs_utils.__all__ +__all__ = [] +__all__ += lookup_table_utils.__all__ +__all__ += hdfs_utils.__all__ diff --git a/python/paddle/fluid/contrib/utils/hdfs_utils.py b/python/paddle/fluid/contrib/utils/hdfs_utils.py index baea57ccce0e9ca3a8fab244e43a107a89cfe67d..35ddf97ff2361d8abd34b16761be99990fc3880d 100644 --- a/python/paddle/fluid/contrib/utils/hdfs_utils.py +++ b/python/paddle/fluid/contrib/utils/hdfs_utils.py @@ -14,6 +14,7 @@ """HDFS Utils""" import os +import sys import subprocess import multiprocessing from datetime import datetime @@ -24,7 +25,7 @@ import errno import logging -__all__ = ["HDFSClient", "multi_download"] +__all__ = ["HDFSClient", "multi_download", "multi_upload"] logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') _logger = logging.getLogger("hdfs_utils") @@ -93,13 +94,15 @@ class HDFSClient(object): def upload(self, hdfs_path, local_path, overwrite=False, retry_times=5): """ - upload the local file to hdfs - Args: - hdfs_path: hdfs path, target path - local_path: local file path, source path - overwrite: will overwrite the original file - retry_times: max times retry to upload - Returns: + upload a local file to HDFS + + Args: + hdfs_path(str): the target HDFS path + local_path(str): the source local file path + overwrite(bool|False): whether to overwrite the existing file on HDFS + retry_times(int|5): max retry times for the upload + + Returns: True or False """ assert hdfs_path is not None @@ -109,7 +112,7 @@ class HDFSClient(object): _logger.warn( "The Local path: {} is dir and I will support it later, return". format(local_path)) - return + return False base = os.path.basename(local_path) if not self.is_exist(hdfs_path): @@ -141,14 +144,16 @@ class HDFSClient(object): def download(self, hdfs_path, local_path, overwrite=False, unzip=False): """ - download from hdfs - Args: - hdfs_path: hdfs path, target path - local_path: local file path, source path - overwrite: will remove original file and overwrite it. - unzip: ignore this param - Returns - True or False + download a file from HDFS + + Args: + hdfs_path(str): the source HDFS path + local_path(str): the target local path + overwrite(bool|False): whether to overwrite the existing local file + unzip(bool|False): whether to unzip the downloaded file if it is a zip archive. 
+ + Returns: + True or False """ _logger.info('Downloading %r to %r.', hdfs_path, local_path) _logger.info('Download of %s to %r complete.', hdfs_path, local_path) @@ -188,13 +193,13 @@ class HDFSClient(object): def is_exist(self, hdfs_path=None): """ - whether the remote hdfs path exists? - Args: - hdfs_path: default value(${OUTPUT_PATH}/${SYS_USER_ID}/${SYS_JOB_ID}/tmp) - fs_name: The default values are the same as in the job configuration - fs_ugi: The default values are the same as in the job configuration - Returns: - True or False + whether the remote HDFS path exists + + Args: + hdfs_path(str): the hdfs file path + + Returns: + True or False """ exist_cmd = ['-test', '-e', hdfs_path] returncode, output, errors = self.__run_hdfs_cmd( @@ -211,13 +216,13 @@ class HDFSClient(object): def is_dir(self, hdfs_path=None): """ - whether the remote hdfs path exists? - Args: - remote_file_path: default value(${OUTPUT_PATH}/${SYS_USER_ID}/${SYS_JOB_ID}/tmp) - fs_name: The default values are the same as in the job configuration - fs_ugi: The default values are the same as in the job configuration - Returns: - True or False + whether the remote HDFS path is a directory + + Args: + hdfs_path(str): the hdfs file path + + Returns: + True or False """ if not self.is_exist(hdfs_path): @@ -237,17 +242,17 @@ class HDFSClient(object): def delete(self, hdfs_path): """ - Remove a file or directory from HDFS. + Remove a file or directory from HDFS. + Args: - param hdfs_path: HDFS path. - param recursive: Recursively delete files and directories. By default, - this method will raise an :class:`HdfsError` if trying to delete a - non-empty directory. + hdfs_path(str): HDFS path. + Returns: + True or False This function returns `True` if the deletion was successful and `False` if no file or directory previously existed at `hdfs_path`. - """ _logger.info('Deleting %r.', hdfs_path) @@ -273,16 +278,14 @@ class HDFSClient(object): def rename(self, hdfs_src_path, hdfs_dst_path, overwrite=False): """ - Rename a file or folder. - Args: - :param hdfs_src_path: Source path. - :param hdfs_dst_path: Destination path. If the path already exists and is - a directory, the source will be moved into it. If the path exists and is - a file, or if a parent destination directory is missing, this method will - raise an :class:`HdfsError`. + Move a file or folder on HDFS. + + Args: + hdfs_src_path(str): HDFS source path. + hdfs_dst_path(str): HDFS destination path. + overwrite(bool|False): If the destination already exists and overwrite is False, will return False. + Returns: - This function returns `True` if the rename was successful and `False` if - rename was faild. + True or False """ assert hdfs_src_path is not None assert hdfs_dst_path is not None @@ -320,17 +323,20 @@ class HDFSClient(object): raise def makedirs(self, hdfs_path): - """Create a remote directory, recursively if necessary. + """ + Create a remote directory, recursively if necessary. + Args: - :param hdfs_path: Remote path. Intermediate directories will be created - appropriately. + hdfs_path(str): Remote path. Intermediate directories will be created appropriately. + Returns: - True if make a directories was successful, False when make a directiries was failed. + True or False """ _logger.info('Creating directories to %r.', hdfs_path) assert hdfs_path is not None if self.is_exist(hdfs_path): + _logger.error("HDFS path already exists: {}".format(hdfs_path)) return mkdirs_commands = ['-mkdir', hdfs_path] @@ -346,11 +352,13 @@ class HDFSClient(object): def ls(self, hdfs_path): """ - ls a hdfs_path. 
- Args: - :param hdfs_path: hdfs_path will be ls. + list the directory contents of the HDFS path + + Args: + hdfs_path(str): the remote HDFS path to list. + + Returns: - This function returns a `list` that contaion all files in the hdfs_path. + List: the contents of hdfs_path. """ assert hdfs_path is not None @@ -378,11 +386,15 @@ class HDFSClient(object): def lsr(self, hdfs_path, only_file=True, sort=True): """ - ls a hdfs_path sort by time. - Args: - :param hdfs_path: hdfs_path will be ls. + list the directory contents of the HDFS path recursively + + Args: + hdfs_path(str): Remote HDFS path. + only_file(bool|True): only list files, discarding folders. + sort(bool|True): sort the results by creation time. + Returns: - This function returns a `list` that contaion all files sorted by time in the hdfs_path. + List: the contents of hdfs_path. """ def sort_by_time(v1, v2): @@ -422,21 +434,106 @@ class HDFSClient(object): return ret_lines +def multi_download(client, + hdfs_path, + local_path, + trainer_id, + trainers, + multi_processes=5): + """ + Download files from HDFS using multiple processes. + + Args: + client(HDFSClient): instance of HDFSClient + hdfs_path(str): path on HDFS + local_path(str): path on the local file system + trainer_id(int): current trainer id + trainers(int): total number of trainers + multi_processes(int|5): number of concurrent download processes, default=5 + + Returns: + List: + paths of the downloaded files in the local folder. + """ + + def __subprocess_download(datas): + for data in datas: + re_path = os.path.relpath(os.path.dirname(data), hdfs_path) + if re_path == os.curdir: + sub_local_re_path = local_path + else: + sub_local_re_path = os.path.join(local_path, re_path) + client.download(data, sub_local_re_path) + + assert isinstance(client, HDFSClient) + + client.make_local_dirs(local_path) + _logger.info("Make local dir {} successfully".format(local_path)) + + all_need_download = client.lsr(hdfs_path, sort=True) + need_download = all_need_download[trainer_id::trainers] + _logger.info("Get {} files from all {} files that need to be downloaded from {}". + format(len(need_download), len(all_need_download), hdfs_path)) + + _logger.info("Start {} processes to download data".format( + multi_processes)) + procs = [] + for i in range(multi_processes): + process_datas = need_download[i::multi_processes] + p = multiprocessing.Process( + target=__subprocess_download, args=(process_datas, )) + procs.append(p) + p.start() + + # wait for all download processes to complete + for proc in procs: + proc.join() + + _logger.info("Finish {} processes to download data".format( + multi_processes)) + + local_downloads = [] + for data in need_download: + data_name = os.path.basename(data) + re_path = os.path.relpath(os.path.dirname(data), hdfs_path) + if re_path == os.curdir: + local_re_path = os.path.join(local_path, data_name) + else: + local_re_path = os.path.join(local_path, re_path, data_name) + local_downloads.append(local_re_path) + + return local_downloads + + +def getfilelist(path): + rlist = [] + for dir, folder, file in os.walk(path): + for i in file: + t = os.path.join(dir, i) + rlist.append(t) + for r in rlist: + print(r) + + + def multi_upload(client, hdfs_path, local_path, multi_processes=5, - overwrite=False): + overwrite=False, + sync=True): """ - Upload file to hdfs. + Upload files to HDFS using multiple processes. 
+ Args: - :param overwrite: will overwrite hdfs file or not - :param multi_processes: the upload data process at the same time, default=5 - :param client: instance of HDFSClient - :param hdfs_path: path on hdfs - :param local_path: path on local + client(HDFSClient): instance of HDFSClient + hdfs_path(str): path on HDFS + local_path(str): path on the local file system + multi_processes(int|5): number of concurrent upload processes, default=5 + overwrite(bool|False): whether to overwrite the existing file on HDFS + sync(bool|True): whether to wait until all upload processes finish. + Returns: - + None """ def __subprocess_upload(datas): @@ -446,13 +543,6 @@ def multi_upload(client, client.upload(hdfs_re_path, data, overwrite, retry_times=5) def get_local_files(path): - """ - Get all local files - Args: - path: local file path - Returns: - A list that contation all files in the path. - """ rlist = [] if not os.path.isdir(path): @@ -488,71 +578,6 @@ def multi_upload(client, multi_processes)) -def multi_download(client, - hdfs_path, - local_path, - trainer_id, - trainers, - file_cnt, - multi_processes=5): - """ - multi_download - Args: - :param client: instance of HDFSClient - :param hdfs_path: path on hdfs - :param local_path: path on local - :param trainer_id: current trainer id - :param trainers: all trainers number - :param file_cnt: all file number - :param multi_processes: the download data process at the same time, default=5 - :return: None - Returns: - A list that be downloaded. - """ - - def __subprocess_download(datas): - for data in datas: - re_path = os.path.relpath(os.path.dirname(data), hdfs_path) - local_re_path = os.path.join(local_path, re_path) - client.download(data, local_re_path) - - assert isinstance(client, HDFSClient) - - client.make_local_dirs(local_path) - _logger.info("Make local dir {} successfully".format(local_path)) - - all_need_download = client.lsr(hdfs_path, sort=True)[:file_cnt] - need_download = all_need_download[trainer_id::trainers] - _logger.info("Get {} files From all {} files need to be download from {}". 
- format(len(need_download), len(all_need_download), hdfs_path)) - - _logger.info("Start {} multi process to download datas".format( - multi_processes)) - procs = [] - for i in range(multi_processes): - process_datas = need_download[i::multi_processes] - p = multiprocessing.Process( - target=__subprocess_download, args=(process_datas, )) - procs.append(p) - p.start() - - # complete the processes - for proc in procs: - proc.join() - - _logger.info("Finish {} multi process to download datas".format( - multi_processes)) - - local_downloads = [] - for data in need_download: - data_name = os.path.basename(data) - re_path = os.path.relpath(os.path.dirname(data), hdfs_path) - local_re_path = os.path.join(local_path, re_path, data_name) - local_downloads.append(local_re_path) - - return local_downloads - - if __name__ == "__main__": hadoop_home = "/home/client/hadoop-client/hadoop/" diff --git a/python/paddle/fluid/contrib/utils/lookup_table_utils.py b/python/paddle/fluid/contrib/utils/lookup_table_utils.py index cc2418238f98d8e2b9af0cf4290f6088c11e1b92..20e6328d81cc727340ea4a16012f6ee9967ea1e6 100644 --- a/python/paddle/fluid/contrib/utils/lookup_table_utils.py +++ b/python/paddle/fluid/contrib/utils/lookup_table_utils.py @@ -18,14 +18,12 @@ import os import time import logging -import paddle -import paddle.fluid as fluid from paddle.fluid import core from paddle.fluid import io from paddle.fluid import Program __all__ = [ - "load_inference_model", "load_persistable_vars", + "load_persistables_for_increment", "load_persistables_for_inference", "convert_dist_to_sparse_program" ] @@ -80,19 +78,28 @@ def __get_prefetch_op_tuples(main_program): return prefetch_op_tuples -def convert_dist_to_sparse_program(main_program): - if not main_program._distributed_lookup_table: +def convert_dist_to_sparse_program(program): + """ + WARNING: this function should only be used for distributed training with a distributed lookup table. + When we train a model with a distributed lookup table but want to run local inference, we can use + this function to convert the train program with a distributed lookup table into one with a sparse lookup table. + + :param program(Program): the trainer program, which is obtained from the distribute transpiler. + :return: + program(Program): the converted program, with the distributed lookup table replaced by a sparse lookup table. 
+ """ + if not program._distributed_lookup_table: _logger.warn( "There are no distributed lookup tables need to be converted") return # create table param and grad var in pserver program - origin_emb_var = "{}.origin".format(main_program._distributed_lookup_table) - emb_var = main_program._distributed_lookup_table - main_program.global_block()._rename_var(emb_var, origin_emb_var) - origin_param_var = main_program.global_block().vars[origin_emb_var] + origin_emb_var = "{}.origin".format(program._distributed_lookup_table) + emb_var = program._distributed_lookup_table + program.global_block()._rename_var(emb_var, origin_emb_var) + origin_param_var = program.global_block().vars[origin_emb_var] - param_var = main_program.global_block().create_var( + param_var = program.global_block().create_var( name=emb_var, shape=origin_param_var.shape, dtype=origin_param_var.dtype, @@ -100,28 +107,28 @@ def convert_dist_to_sparse_program(main_program): persistable=True) # parameter must be selected rows param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS) - main_program._sync_with_cpp() + program._sync_with_cpp() - prefetch_op_tuples = __get_prefetch_op_tuples(main_program) + prefetch_op_tuples = __get_prefetch_op_tuples(program) split_ids_id = prefetch_op_tuples[0] for idx in range(split_ids_id + 2, split_ids_id - 1, -1): - main_program.global_block()._remove_op(idx) - main_program.desc.flush() + program.global_block()._remove_op(idx) + program.desc.flush() in_out_pairs = zip(prefetch_op_tuples[1], prefetch_op_tuples[2]) for in_out_pair in in_out_pairs: idx = split_ids_id - ids = main_program.global_block().vars[in_out_pair[0]] - out = main_program.global_block().vars[in_out_pair[1]] - __insert_lookup_sparse_table_op(main_program, idx, ids, param_var, out) - main_program.desc.flush() - return main_program + ids = program.global_block().vars[in_out_pair[0]] + out = program.global_block().vars[in_out_pair[1]] + __insert_lookup_sparse_table_op(program, idx, ids, param_var, out) + program.desc.flush() + return program -def load_persistable_vars(executor, dirname, program, lookup_table_var): +def _load_persistable_vars(executor, dirname, program, lookup_table_vars): def _is_checkpoint_var(exclude_fluid_vars=None): """ the checkpoint will not save or load all the variables. @@ -159,8 +166,82 @@ def load_persistable_vars(executor, dirname, program, lookup_table_var): return is_valid - def _load_lookup_table_vars(executor, dirname, main_program, - lookup_table_vars): + io.load_vars( + executor, + dirname=dirname, + main_program=program, + predicate=_is_checkpoint_var(lookup_table_vars), + filename=None) + + +def load_persistables_for_increment(dirname, executor, program, + lookup_table_var, lookup_table_var_path): + """ + WARNING: this function will only be used for distributed training with distributed lookup table. + for increment trainning, the pserver will not only load dense variables, + but also load the suitable lookup table var. Because of slice lookup table + var with HASH, we must load the correct slice var. + + + :param dirname(str): The directory path + :param executor(Executor): The executor to run for loading inference model. + :param program(Program): The parameter server program, which will run on Pserver. + :param lookup_table_var: the distributed lookup tables var name. + :param lookup_table_var_path: the the distributed lookup tables var location. 
+ :return: None + """ + + def __load_lookup_table_vars(executor, main_program, lookup_table_var, + lookup_table_var_path): + emb_var = main_program.global_block().var(lookup_table_var) + + load_program = Program() + load_block = load_program.global_block() + load_block.append_op( + type='load', + inputs={}, + outputs={'Out': [emb_var]}, + attrs={'file_path': lookup_table_var_path}) + executor.run(load_program) + + if not os.path.isdir(dirname): + raise ValueError("There is no directory named '%s'", dirname) + + if not os.path.exists(lookup_table_var_path): + raise ValueError("There is no file named '%s'", lookup_table_var_path) + + if not isinstance(program, Program): + raise ValueError("program must be an instance of fluid.Program") + + _logger.info("Start Load Sparse Program With " + "Distributed Lookup Table Vars from {}, time = {}".format( + dirname, time.ctime())) + + _load_persistable_vars(executor, dirname, program, [lookup_table_var]) + __load_lookup_table_vars(executor, program, lookup_table_var, + lookup_table_var_path) + + _logger.info("Finish Load Sparse Program With " + "Distributed Lookup Table Vars from {}, time = {}".format( + dirname, time.ctime())) + + +def load_persistables_for_inference(dirname, executor, program, + lookup_table_var_name): + """ + WARNING: this function will only be used for inference with distributed lookup table. + Inference with distributed lookup table is a little funky, this function will load distributed + lookup table vars into sparse var, can be used in local inference mode. + + :param dirname(str): The directory path + :param executor(Executor): The executor to run for loading inference model. + :param program(Program): The parameter server program, which will run on Pserver. + :param lookup_table_var_name: the distributed lookup tables var name. 
+ :return: None + """ + + def __load_lookup_table_vars(executor, dirname, main_program, + lookup_table_vars): if not os.path.isdir(dirname): raise ValueError("There is no directory named '%s'", dirname) @@ -209,48 +290,34 @@ def load_persistable_vars(executor, dirname, program, lookup_table_var): global_block.append_op(type='delete_var', inputs={'X': sums}) executor.run(convert_program) - _logger.info("Start Load Sparse Program With " - "Distributed Lookup Table Vars from {}, time = {}".format( - dirname, time.ctime())) - - lookup_table_vars = [lookup_table_var] - - io.load_vars( - executor, - dirname=dirname, - main_program=program, - predicate=_is_checkpoint_var(lookup_table_vars), - filename=None) - - _load_lookup_table_vars(executor, dirname, program, lookup_table_vars) - - _logger.info("Finish Load Sparse Program With " - "Distributed Lookup Table Vars from {}, time = {}".format( - dirname, time.ctime())) - - -def load_inference_model(dirname, executor, lookup_table_var_name): if not os.path.isdir(dirname): raise ValueError("There is no directory named '%s'", dirname) - local_model = os.path.join(dirname, model_filename) + if program: + if not isinstance(program, Program): + raise ValueError("program must be an instance of fluid.Program") + else: + local_model = os.path.join(dirname, model_filename) - with open(local_model, "rb") as f: - program_desc_str = f.read() + with open(local_model, "rb") as f: + program_desc_str = f.read() - program = Program.parse_from_string(program_desc_str) + program = Program.parse_from_string(program_desc_str) - if not core._is_program_version_supported(program._version()): - raise ValueError("Unsupported program version: %d\n" % - program._version()) + if not core._is_program_version_supported(program._version()): + raise ValueError("Unsupported program version: %d\n" % + program._version()) - # Binary data also need version. - load_persistable_vars(executor, dirname, program, lookup_table_var_name) + _logger.info("Start Load Sparse Program With " + "Distributed Lookup Table Vars from {}, time = {}".format( + dirname, time.ctime())) + + _load_persistable_vars(executor, dirname, program, [lookup_table_var_name]) + __load_lookup_table_vars(executor, dirname, program, + [lookup_table_var_name]) - feed_target_names = program.desc.get_feed_target_names() - fetch_target_names = program.desc.get_fetch_target_names() - fetch_targets = [ - program.global_block().var(name) for name in fetch_target_names - ] + _logger.info("Finish Load Sparse Program With " + "Distributed Lookup Table Vars from {}, time = {}".format( + dirname, time.ctime())) - return [program, feed_target_names, fetch_targets] + return program diff --git a/python/setup.py.in b/python/setup.py.in index 22b9537a90e79c2571f61ec0dc156b602df784d6..5d5f2dd0f18cd3e707dca8b9f337f2f2a07d47aa 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -107,9 +107,9 @@ packages=['paddle', 'paddle.fluid.distributed', 'paddle.fluid.layers', 'paddle.fluid.contrib', - 'paddle.fluid.contrib.utils', 'paddle.fluid.contrib.decoder', 'paddle.fluid.contrib.quantize', + 'paddle.fluid.contrib.utils', 'paddle.fluid.transpiler', 'paddle.fluid.transpiler.details']