From 3639d99f9910e98f4cccb0e2de1f583d62ab0028 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 26 Nov 2018 14:30:21 +0800 Subject: [PATCH] Fix save and load lookup table/optimizer vars (#14301) * fix mkdir conflict * fix load/save lookup tables test=develop * add lookup_table_utils * fix load optimize vars on pserver * delete lookup table utils * fix save and load lookup tables * fix load optimizer var * fix load optimizer var, test=develop * fix python 3 style, test=develop * move lookup_table_utils to contrib utils --- .../fluid/operators/lookup_sparse_table_op.cc | 1 + paddle/fluid/operators/sum_op.h | 3 + python/paddle/fluid/contrib/utils/__init__.py | 4 +- .../fluid/contrib/utils/lookup_table_utils.py | 256 ++++++++++++++++++ python/paddle/fluid/framework.py | 20 ++ python/paddle/fluid/io.py | 51 +++- .../fluid/transpiler/distribute_transpiler.py | 29 +- 7 files changed, 342 insertions(+), 22 deletions(-) create mode 100644 python/paddle/fluid/contrib/utils/lookup_table_utils.py diff --git a/paddle/fluid/operators/lookup_sparse_table_op.cc b/paddle/fluid/operators/lookup_sparse_table_op.cc index a6843f20a..1b55527fd 100644 --- a/paddle/fluid/operators/lookup_sparse_table_op.cc +++ b/paddle/fluid/operators/lookup_sparse_table_op.cc @@ -67,6 +67,7 @@ class LookupSparseTableOp : public framework::OperatorBase { framework::proto::VarType::FP32, "The sparse table only support FP32"); w_t->Get(ids_t, out_t, true, is_test); + out_t->set_lod(ids_t.lod()); } }; diff --git a/paddle/fluid/operators/sum_op.h b/paddle/fluid/operators/sum_op.h index 19b2c68c8..76cc796a9 100644 --- a/paddle/fluid/operators/sum_op.h +++ b/paddle/fluid/operators/sum_op.h @@ -127,6 +127,9 @@ class SumKernel : public framework::OpKernel { math::scatter::MergeAdd merge_add; merge_add(context.template device_context(), inputs, out); + + out->SyncIndex(); + } else { // no data, just set a empty out tensor. out->mutable_value()->mutable_data(framework::make_ddim({0}), diff --git a/python/paddle/fluid/contrib/utils/__init__.py b/python/paddle/fluid/contrib/utils/__init__.py index df6d36778..6e479bdc2 100644 --- a/python/paddle/fluid/contrib/utils/__init__.py +++ b/python/paddle/fluid/contrib/utils/__init__.py @@ -13,8 +13,10 @@ # limitations under the License. from __future__ import print_function - +from . import lookup_table_utils +from .lookup_table_utils import * from . import hdfs_utils from .hdfs_utils import * +__all__ = lookup_table_utils.__all__ __all__ = hdfs_utils.__all__ diff --git a/python/paddle/fluid/contrib/utils/lookup_table_utils.py b/python/paddle/fluid/contrib/utils/lookup_table_utils.py new file mode 100644 index 000000000..cc2418238 --- /dev/null +++ b/python/paddle/fluid/contrib/utils/lookup_table_utils.py @@ -0,0 +1,256 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import time +import logging + +import paddle +import paddle.fluid as fluid +from paddle.fluid import core +from paddle.fluid import io +from paddle.fluid import Program + +__all__ = [ + "load_inference_model", "load_persistable_vars", + "convert_dist_to_sparse_program" +] + +logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') +_logger = logging.getLogger("lookup_table_utils") +_logger.setLevel(logging.INFO) + +model_filename = "__model__" +lookup_table_dir = "__lookup_table__" + + +def __insert_lookup_sparse_table_op(main_program, idx, ids, w, out): + main_program.global_block()._insert_op( + index=idx, + type="lookup_sparse_table", + inputs={"Ids": [ids], + "W": [w]}, + outputs={"Out": [out]}, + attrs={ + "is_distributed": False, + "is_sparse": True, + "grad_inplace": False + }) + + +def __get_prefetch_op_tuples(main_program): + # current lookup tables op is split_ids->prefetch->merge_ids + prefetch_op_tuples = None + op_types = [op.type for op in main_program.global_block().ops] + + for i in range(len(op_types)): + if op_types[i] == "prefetch": + if op_types[i - 1] == "split_ids" and op_types[i + + 1] == "merge_ids": + split_ids_op_id = i - 1 + split_ids_inputs = main_program.global_block().ops[i - 1].input( + "Ids") + prefetch_op_inputs = main_program.global_block().ops[i].input( + "X") + prefetch_op_outputs = main_program.global_block().ops[i].output( + "Out") + merge_ids_outputs = main_program.global_block().ops[ + i + 1].output("Out") + + need_delete_vars = [] + need_delete_vars.extend(prefetch_op_inputs) + need_delete_vars.extend(prefetch_op_outputs) + + prefetch_op_tuples = (split_ids_op_id, split_ids_inputs, + merge_ids_outputs, need_delete_vars) + break + return prefetch_op_tuples + + +def convert_dist_to_sparse_program(main_program): + if not main_program._distributed_lookup_table: + _logger.warn( + "There are no distributed lookup tables need to be converted") + return + + # create table param and grad var in pserver program + origin_emb_var = "{}.origin".format(main_program._distributed_lookup_table) + emb_var = main_program._distributed_lookup_table + main_program.global_block()._rename_var(emb_var, origin_emb_var) + origin_param_var = main_program.global_block().vars[origin_emb_var] + + param_var = main_program.global_block().create_var( + name=emb_var, + shape=origin_param_var.shape, + dtype=origin_param_var.dtype, + type=core.VarDesc.VarType.SELECTED_ROWS, + persistable=True) + # parameter must be selected rows + param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS) + main_program._sync_with_cpp() + + prefetch_op_tuples = __get_prefetch_op_tuples(main_program) + + split_ids_id = prefetch_op_tuples[0] + + for idx in range(split_ids_id + 2, split_ids_id - 1, -1): + main_program.global_block()._remove_op(idx) + main_program.desc.flush() + + in_out_pairs = zip(prefetch_op_tuples[1], prefetch_op_tuples[2]) + + for in_out_pair in in_out_pairs: + idx = split_ids_id + ids = main_program.global_block().vars[in_out_pair[0]] + out = main_program.global_block().vars[in_out_pair[1]] + __insert_lookup_sparse_table_op(main_program, idx, ids, param_var, out) + main_program.desc.flush() + return main_program + + +def load_persistable_vars(executor, dirname, program, lookup_table_var): + def _is_checkpoint_var(exclude_fluid_vars=None): + """ + the checkpoint will not save or load all the variables. + var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded. + + : param var(Variable) + """ + + if exclude_fluid_vars is None: + exclude_fluid_vars = [] + + def is_valid(var): + if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ + var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ + var.desc.type() == core.VarDesc.VarType.RAW: + return False + # @GRAD are named for gradient variables, checkpoint will not save it. + if "@GRAD" in var.name: + return False + # .trainer_ are named for distribute train variables, checkpoint will not save it. + if ".trainer_" in var.name: + return False + + # .block is named for distribute train variables, checkpoint will not save it. + if ".block" in var.name: + return False + + if "tmp_" in var.name: + return False + + if var.name in exclude_fluid_vars: + return False + + return var.persistable + + return is_valid + + def _load_lookup_table_vars(executor, dirname, main_program, + lookup_table_vars): + if not os.path.isdir(dirname): + raise ValueError("There is no directory named '%s'", dirname) + + lookup_table_dirname = os.path.join(dirname, lookup_table_dir) + + emb_var_name = lookup_table_vars[0] + emb_var = main_program.global_block().var(emb_var_name) + + emb_files = [] + for emb_name in os.listdir(lookup_table_dirname): + if emb_var_name in emb_name: + emb_files.append(emb_name) + + convert_program = Program() + global_block = convert_program.global_block() + + emb_var = global_block.create_var( + name=emb_var.name, + shape=emb_var.shape, + dtype=emb_var.dtype, + type=core.VarDesc.VarType.SELECTED_ROWS, + persistable=True) + emb_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS) + + sums = [] + + for i, emb_file in enumerate(emb_files): + var_name = "{}_{}".format(emb_var.name, i) + param_var = global_block.create_var( + name=var_name, + shape=emb_var.shape, + dtype=emb_var.dtype, + type=core.VarDesc.VarType.SELECTED_ROWS, + persistable=True) + param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS) + global_block.append_op( + type='load', + inputs={}, + outputs={'Out': [param_var]}, + attrs={ + 'file_path': os.path.join(lookup_table_dirname, var_name) + }) + sums.append(param_var) + global_block.append_op( + type='sum', inputs={"X": sums}, outputs={'Out': emb_var}, attrs={}) + global_block.append_op(type='delete_var', inputs={'X': sums}) + executor.run(convert_program) + + _logger.info("Start Load Sparse Program With " + "Distributed Lookup Table Vars from {}, time = {}".format( + dirname, time.ctime())) + + lookup_table_vars = [lookup_table_var] + + io.load_vars( + executor, + dirname=dirname, + main_program=program, + predicate=_is_checkpoint_var(lookup_table_vars), + filename=None) + + _load_lookup_table_vars(executor, dirname, program, lookup_table_vars) + + _logger.info("Finish Load Sparse Program With " + "Distributed Lookup Table Vars from {}, time = {}".format( + dirname, time.ctime())) + + +def load_inference_model(dirname, executor, lookup_table_var_name): + if not os.path.isdir(dirname): + raise ValueError("There is no directory named '%s'", dirname) + + local_model = os.path.join(dirname, model_filename) + + with open(local_model, "rb") as f: + program_desc_str = f.read() + + program = Program.parse_from_string(program_desc_str) + + if not core._is_program_version_supported(program._version()): + raise ValueError("Unsupported program version: %d\n" % + program._version()) + + # Binary data also need version. + load_persistable_vars(executor, dirname, program, lookup_table_var_name) + + feed_target_names = program.desc.get_feed_target_names() + fetch_target_names = program.desc.get_fetch_target_names() + fetch_targets = [ + program.global_block().var(name) for name in fetch_target_names + ] + + return [program, feed_target_names, fetch_targets] diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index fd03dff38..b991187d4 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1698,6 +1698,7 @@ class Program(object): p._copy_param_info_from(self) p._copy_data_info_from(self) + p._copy_dist_param_info_from(self) return p def _prune(self, targets): @@ -1938,6 +1939,25 @@ class Program(object): "program, with represent the same topology") self.global_block()._copy_param_info_from(other.global_block()) + def _copy_dist_param_info_from(self, other): + """ + Copy the information of distributed information from other program. + + Args: + other(Program): Other program + + Returns: + None + """ + if not isinstance(other, Program): + raise TypeError("_copy_dist_param_info_from should be invoked with " + "Program") + self._is_distributed = other._is_distributed + self._is_chief = other._is_chief + self._slice_vars_and_attrs = other._slice_vars_and_attrs + self._endpoints = other._endpoints + self._distributed_lookup_table = other._distributed_lookup_table + def _copy_data_info_from(self, other): """ Copy the information of data variables from other program. diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 8936d884d..26d7af87b 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -165,6 +165,7 @@ def save_vars(executor, save_vars( executor, + main_program=main_program, dirname=dirname, vars=list(filter(predicate, main_program.list_vars())), filename=filename) @@ -172,11 +173,18 @@ def save_vars(executor, save_program = Program() save_block = save_program.global_block() + if main_program is None: + main_program = default_main_program() + if not isinstance(main_program, Program): + raise TypeError("program should be as Program type or None") + save_var_map = {} for each_var in vars: # NOTE: don't save the variable which type is RAW if each_var.type == core.VarDesc.VarType.RAW: continue + if each_var.name == main_program._distributed_lookup_table: + continue new_var = _clone_var_in_block_(save_block, each_var) if filename is None: save_block.append_op( @@ -198,6 +206,16 @@ def save_vars(executor, outputs={}, attrs={'file_path': os.path.join(dirname, filename)}) + # if there is lookup table, the trainer 0 will notify all pserver to save. + if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table: + lookup_table_filename = os.path.join(dirname, "__lookup_table__") + attrs = {} + attrs['epmap'] = main_program._endpoints + attrs['dir'] = lookup_table_filename + attrs['lookup_table'] = main_program._distributed_lookup_table + save_block.append_op( + type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs) + executor.run(save_program) @@ -379,11 +397,22 @@ def load_vars(executor, load_prog = Program() load_block = load_prog.global_block() + if main_program is None: + main_program = default_main_program() + if not isinstance(main_program, Program): + raise TypeError("program should be as Program type or None") + + load_slice_vars = [] + for each_var in main_program._slice_vars_and_attrs: + load_slice_vars.append(each_var[2].name) + load_var_map = {} for each_var in vars: assert isinstance(each_var, Variable) if each_var.type == core.VarDesc.VarType.RAW: continue + if each_var.name in load_slice_vars: + continue new_var = _clone_var_in_block_(load_block, each_var) if filename is None: load_block.append_op( @@ -406,9 +435,6 @@ def load_vars(executor, attrs={'file_path': os.path.join(dirname, filename)}) executor.run(load_prog) - if main_program is None: - main_program = default_main_program() - # load slice vars on pserver, if have it. _load_slice_up_vars(executor, dirname, main_program._slice_vars_and_attrs) @@ -618,13 +644,6 @@ def save_inference_model(dirname, if main_program is None: main_program = default_main_program() - # if there is lookup table, the trainer 0 will notify all pserver to save. - if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table: - lookup_table_filename = os.path.join(dirname, "__lookup_table__") - _save_lookup_tables_by_notify(executor, lookup_table_filename, - main_program._distributed_lookup_table, - main_program._endpoints) - # when a pserver and a trainer running on the same machine, mkdir may conflict try: os.makedirs(dirname) @@ -642,6 +661,9 @@ def save_inference_model(dirname, # it can only be loaded for inference directly. If it's false, the whole # original program and related meta are saved so that future usage can be # more flexible. + + origin_program = main_program.clone() + if export_for_deployment: main_program = main_program.clone() global_block = main_program.global_block() @@ -666,8 +688,11 @@ def save_inference_model(dirname, with open(model_basename + ".main_program", "wb") as f: f.write(main_program.desc.serialize_to_string()) + main_program._copy_dist_param_info_from(origin_program) + if params_filename is not None: params_filename = os.path.basename(params_filename) + save_persistables(executor, dirname, main_program, params_filename) @@ -897,6 +922,9 @@ def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs): slice_var = var_tuple[2] end = start + slice_var.shape[0] + orig_var_name = orig_var.name + orig_var.name = "{}.origin".format(orig_var_name) + clone_orig_var = load_block.create_var( name=orig_var.name, type=orig_var.type, @@ -915,7 +943,7 @@ def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs): type='load', inputs={}, outputs={'Out': [clone_orig_var]}, - attrs={'file_path': os.path.join(dirname, clone_orig_var.name)}) + attrs={'file_path': os.path.join(dirname, orig_var_name)}) load_block.append_op( type="slice", inputs={'Input': clone_orig_var}, @@ -924,6 +952,7 @@ def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs): 'starts': [start], 'ends': [end]}) need_delete_vars.append(clone_orig_var) + load_block.append_op( type='delete_var', inputs={'X': need_delete_vars}, ) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 89bc24802..ebd0d18d3 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -644,6 +644,9 @@ in a single call.") else: recv_inputs.append(single_trainer_var) + self._slice_params_and_optimizes = self._get_slice_vars_and_attrs( + endpoint) + # step 3 # Create a union-find data structure from optimize ops, # If two ops are connected, we could add these two ops @@ -766,7 +769,7 @@ in a single call.") grad_to_block_id, merged_var, lr_ops) -# dedup grad to ids list + # dedup grad to ids list grad_to_block_id = list(set(grad_to_block_id)) # append global ops if global_ops: @@ -827,8 +830,8 @@ in a single call.") attrs=attrs) # add distributed attrs - pserver_program._slice_vars_and_attrs = self._get_slice_vars_and_attrs( - endpoint) + pserver_program._slice_vars_and_attrs = list( + self._slice_params_and_optimizes.values()) pserver_program._sync_with_cpp() # save pserver program to generate pserver side startup relatively. @@ -941,12 +944,12 @@ to transpile() call.") outputs={"Out": startup_tmpvar}) # add slice vars - s_prog._slice_vars_and_attrs = self._get_slice_vars_and_attrs(endpoint) + s_prog._slice_vars_and_attrs = pserver_program._slice_vars_and_attrs return s_prog def _get_slice_vars_and_attrs(self, endpoint): - slice_vars_and_attrs = [] + slice_vars_and_attrs = {} block_suffix = "block" for param in self.param_grad_ep_mapping[endpoint]["params"]: orig_var_name, block_name, _ = self._get_varname_parts(param.name) @@ -960,8 +963,7 @@ to transpile() call.") slice_vars = self.param_var_mapping[orig_var_name] for slice_var in slice_vars[:block_idx]: skip_dim0 += slice_var.shape[0] - slice_vars_and_attrs.append([orig_var, skip_dim0, param]) - + slice_vars_and_attrs[param.name] = [orig_var, skip_dim0, param] return slice_vars_and_attrs # ====================== private transpiler functions ===================== @@ -1662,10 +1664,10 @@ to transpile() call.") if key in ["Param", "Grad", "LearningRate"]: continue var = self.origin_program.global_block().vars[opt_op.input(key)[0]] + param_var = new_inputs["Param"] # update accumulator variable shape - param_shape = new_inputs["Param"].shape - new_shape = self._get_optimizer_input_shape(opt_op.type, key, - var.shape, param_shape) + new_shape = self._get_optimizer_input_shape( + opt_op.type, key, var.shape, param_var.shape) tmpvar = pserver_block.create_var( name=var.name, persistable=var.persistable, @@ -1673,6 +1675,13 @@ to transpile() call.") shape=new_shape) new_inputs[key] = tmpvar + # var shape been changed + if new_shape != var.shape: + slice_var_args = self._slice_params_and_optimizes[ + param_var.name] + self._slice_params_and_optimizes[ + var.name] = [var, slice_var_args[1], tmpvar] + # change output's ParamOut variable outputs = self._get_output_map_from_op( self.origin_program.global_block().vars, opt_op) -- GitLab