Unverified · Commit 3639d99f authored by tangwei12, committed by GitHub

Fix save and load lookup table/optimizer vars (#14301)

*  fix mkdir conflict

*  fix load/save lookup tables

 test=develop

* add lookup_table_utils

* fix load optimize vars on pserver

* delete lookup table utils

* fix save and load lookup tables

* fix load optimizer var

* fix load optimizer var, test=develop

* fix python 3 style, test=develop

* move lookup_table_utils to contrib utils
Parent 2fc32b17
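In short: `save_inference_model` now keeps the program's distributed attributes on its clone, so trainer 0 emits a `checkpoint_notify` op telling every pserver to dump its shard of the distributed lookup table, and the new contrib utils reassemble those shards at load time. A minimal usage sketch, hedged: `feed_names`, `fetch_vars`, `main_program`, and the `"emb"` table name are placeholders from an assumed transpiled job.

import paddle.fluid as fluid
from paddle.fluid.contrib.utils import lookup_table_utils

exe = fluid.Executor(fluid.CPUPlace())

# Trainer 0: saving also notifies every pserver to dump its
# lookup-table shard under dist_model/__lookup_table__.
fluid.io.save_inference_model(
    "dist_model", feed_names, fetch_vars, exe, main_program=main_program)

# Later, for inference: merge the per-pserver shards back into one
# SelectedRows embedding while loading the model.
program, feed_names, fetch_targets = lookup_table_utils.load_inference_model(
    "dist_model", exe, lookup_table_var_name="emb")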
paddle/fluid/operators/lookup_sparse_table_op.cc
@@ -67,6 +67,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
                      framework::proto::VarType::FP32,
                      "The sparse table only support FP32");
    w_t->Get(ids_t, out_t, true, is_test);
+   out_t->set_lod(ids_t.lod());
  }
};
......
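The one-line fix above makes `lookup_sparse_table` copy the LoD from `Ids` onto `Out`, so downstream sequence ops still see the sequence boundaries. A small sketch of the kind of offset-style LoD being propagated; the id values and shapes are illustrative.

import numpy as np
import paddle.fluid as fluid
from paddle.fluid import core

# Five lookups that belong to two sequences (lengths 2 and 3).
ids = core.LoDTensor()
ids.set(np.array([[1], [3], [2], [5], [4]], dtype="int64"),
        fluid.CPUPlace())
ids.set_lod([[0, 2, 5]])  # seq 0 = rows [0, 2), seq 1 = rows [2, 5)

# Before the fix, Out came back with an empty LoD; now it carries the
# same [[0, 2, 5]], so sequence ops after the lookup keep working.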
paddle/fluid/operators/sum_op.h
@@ -127,6 +127,9 @@ class SumKernel : public framework::OpKernel<T> {
      math::scatter::MergeAdd<DeviceContext, T> merge_add;
      merge_add(context.template device_context<DeviceContext>(), inputs,
                out);
+     out->SyncIndex();
    } else {
      // no data, just set a empty out tensor.
      out->mutable_value()->mutable_data<T>(framework::make_ddim({0}),
......
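`SyncIndex()` is called right after `MergeAdd` so the `SelectedRows` id-to-offset index matches the freshly merged rows; without it, a later `Get()` on the sparse table would consult a stale index. A plain-Python picture of what merge-plus-reindex does, purely illustrative, not the C++ API.

# Two sparse inputs, each mapping a row id to its (scalar) value.
a = {3: 1.0, 7: 2.0}
b = {7: 0.5, 1: 4.0}

# MergeAdd: sum values that share a row id.
merged = {}
for rows in (a, b):
    for row_id, val in rows.items():
        merged[row_id] = merged.get(row_id, 0.0) + val

# SyncIndex: rebuild the id -> offset map over the merged rows so
# later lookups by id find the right row in the value tensor.
index = {row_id: offset for offset, row_id in enumerate(sorted(merged))}
assert merged[7] == 2.5 and index[7] == len(index) - 1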
python/paddle/fluid/contrib/utils/__init__.py
@@ -13,8 +13,10 @@
# limitations under the License.
from __future__ import print_function
+from . import lookup_table_utils
+from .lookup_table_utils import *
from . import hdfs_utils
from .hdfs_utils import *
+__all__ = lookup_table_utils.__all__
__all__ = hdfs_utils.__all__

New file: python/paddle/fluid/contrib/utils/lookup_table_utils.py
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import time
import logging
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid import io
from paddle.fluid import Program
__all__ = [
    "load_inference_model", "load_persistable_vars",
    "convert_dist_to_sparse_program"
]
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
_logger = logging.getLogger("lookup_table_utils")
_logger.setLevel(logging.INFO)
model_filename = "__model__"
lookup_table_dir = "__lookup_table__"
def __insert_lookup_sparse_table_op(main_program, idx, ids, w, out):
    main_program.global_block()._insert_op(
        index=idx,
        type="lookup_sparse_table",
        inputs={"Ids": [ids],
                "W": [w]},
        outputs={"Out": [out]},
        attrs={
            "is_distributed": False,
            "is_sparse": True,
            "grad_inplace": False
        })
def __get_prefetch_op_tuples(main_program):
    # current lookup tables op is split_ids->prefetch->merge_ids
    prefetch_op_tuples = None
    op_types = [op.type for op in main_program.global_block().ops]

    for i in range(len(op_types)):
        if op_types[i] == "prefetch":
            if op_types[i - 1] == "split_ids" and \
                    op_types[i + 1] == "merge_ids":
                split_ids_op_id = i - 1
                split_ids_inputs = main_program.global_block().ops[
                    i - 1].input("Ids")
                prefetch_op_inputs = main_program.global_block().ops[
                    i].input("X")
                prefetch_op_outputs = main_program.global_block().ops[
                    i].output("Out")
                merge_ids_outputs = main_program.global_block().ops[
                    i + 1].output("Out")

                need_delete_vars = []
                need_delete_vars.extend(prefetch_op_inputs)
                need_delete_vars.extend(prefetch_op_outputs)

                prefetch_op_tuples = (split_ids_op_id, split_ids_inputs,
                                      merge_ids_outputs, need_delete_vars)
                break
    return prefetch_op_tuples
def convert_dist_to_sparse_program(main_program):
    if not main_program._distributed_lookup_table:
        _logger.warn(
            "There are no distributed lookup tables need to be converted")
        return

    # create table param and grad var in pserver program
    origin_emb_var = "{}.origin".format(main_program._distributed_lookup_table)
    emb_var = main_program._distributed_lookup_table
    main_program.global_block()._rename_var(emb_var, origin_emb_var)
    origin_param_var = main_program.global_block().vars[origin_emb_var]

    param_var = main_program.global_block().create_var(
        name=emb_var,
        shape=origin_param_var.shape,
        dtype=origin_param_var.dtype,
        type=core.VarDesc.VarType.SELECTED_ROWS,
        persistable=True)
    # parameter must be selected rows
    param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS)
    main_program._sync_with_cpp()

    prefetch_op_tuples = __get_prefetch_op_tuples(main_program)
    split_ids_id = prefetch_op_tuples[0]

    for idx in range(split_ids_id + 2, split_ids_id - 1, -1):
        main_program.global_block()._remove_op(idx)
    main_program.desc.flush()

    in_out_pairs = zip(prefetch_op_tuples[1], prefetch_op_tuples[2])

    for in_out_pair in in_out_pairs:
        idx = split_ids_id
        ids = main_program.global_block().vars[in_out_pair[0]]
        out = main_program.global_block().vars[in_out_pair[1]]
        __insert_lookup_sparse_table_op(main_program, idx, ids, param_var, out)
        main_program.desc.flush()
    return main_program
def load_persistable_vars(executor, dirname, program, lookup_table_var):
    def _is_checkpoint_var(exclude_fluid_vars=None):
        """
        the checkpoint will not save or load all the variables.
        vars whose type is FEED_MINIBATCH/FETCH_LIST/RAW or whose name
        contains @GRAD are discarded.

        : param var(Variable)
        """
        if exclude_fluid_vars is None:
            exclude_fluid_vars = []

        def is_valid(var):
            if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
                    var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
                    var.desc.type() == core.VarDesc.VarType.RAW:
                return False
            # @GRAD names gradient variables; the checkpoint will not save them.
            if "@GRAD" in var.name:
                return False
            # .trainer_ names distributed-training variables; not saved either.
            if ".trainer_" in var.name:
                return False
            # .block names distributed-training variables; not saved either.
            if ".block" in var.name:
                return False
            if "tmp_" in var.name:
                return False
            if var.name in exclude_fluid_vars:
                return False
            return var.persistable

        return is_valid

    def _load_lookup_table_vars(executor, dirname, main_program,
                                lookup_table_vars):
        if not os.path.isdir(dirname):
            raise ValueError("There is no directory named '%s'" % dirname)

        lookup_table_dirname = os.path.join(dirname, lookup_table_dir)

        emb_var_name = lookup_table_vars[0]
        emb_var = main_program.global_block().var(emb_var_name)

        emb_files = []
        for emb_name in os.listdir(lookup_table_dirname):
            if emb_var_name in emb_name:
                emb_files.append(emb_name)

        convert_program = Program()
        global_block = convert_program.global_block()

        emb_var = global_block.create_var(
            name=emb_var.name,
            shape=emb_var.shape,
            dtype=emb_var.dtype,
            type=core.VarDesc.VarType.SELECTED_ROWS,
            persistable=True)
        emb_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS)

        sums = []
        for i, emb_file in enumerate(emb_files):
            var_name = "{}_{}".format(emb_var.name, i)
            param_var = global_block.create_var(
                name=var_name,
                shape=emb_var.shape,
                dtype=emb_var.dtype,
                type=core.VarDesc.VarType.SELECTED_ROWS,
                persistable=True)
            param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS)
            global_block.append_op(
                type='load',
                inputs={},
                outputs={'Out': [param_var]},
                attrs={
                    'file_path': os.path.join(lookup_table_dirname, var_name)
                })
            sums.append(param_var)
        global_block.append_op(
            type='sum', inputs={"X": sums}, outputs={'Out': emb_var}, attrs={})
        global_block.append_op(type='delete_var', inputs={'X': sums})
        executor.run(convert_program)

    _logger.info("Start Load Sparse Program With "
                 "Distributed Lookup Table Vars from {}, time = {}".format(
                     dirname, time.ctime()))

    lookup_table_vars = [lookup_table_var]

    io.load_vars(
        executor,
        dirname=dirname,
        main_program=program,
        predicate=_is_checkpoint_var(lookup_table_vars),
        filename=None)

    _load_lookup_table_vars(executor, dirname, program, lookup_table_vars)

    _logger.info("Finish Load Sparse Program With "
                 "Distributed Lookup Table Vars from {}, time = {}".format(
                     dirname, time.ctime()))
def load_inference_model(dirname, executor, lookup_table_var_name):
    if not os.path.isdir(dirname):
        raise ValueError("There is no directory named '%s'" % dirname)

    local_model = os.path.join(dirname, model_filename)

    with open(local_model, "rb") as f:
        program_desc_str = f.read()

    program = Program.parse_from_string(program_desc_str)

    if not core._is_program_version_supported(program._version()):
        raise ValueError("Unsupported program version: %d\n" %
                         program._version())

    # Binary data also needs a version.
    load_persistable_vars(executor, dirname, program, lookup_table_var_name)

    feed_target_names = program.desc.get_feed_target_names()
    fetch_target_names = program.desc.get_fetch_target_names()
    fetch_targets = [
        program.global_block().var(name) for name in fetch_target_names
    ]

    return [program, feed_target_names, fetch_targets]
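Taken together, the new module covers the pserver side (rewrite the distributed program into a local sparse one) and the load side (filter checkpoint vars, then sum the per-pserver shard files into one table). A hedged usage sketch; `pserver_program`, the checkpoint directory, and the `"emb"` name are assumptions.

import paddle.fluid as fluid
from paddle.fluid.contrib.utils.lookup_table_utils import (
    convert_dist_to_sparse_program, load_persistable_vars)

exe = fluid.Executor(fluid.CPUPlace())

# Rewrite split_ids -> prefetch -> merge_ids into local
# lookup_sparse_table ops over a SelectedRows parameter.
sparse_program = convert_dist_to_sparse_program(pserver_program)

# Reload the persistables, merging the shard files saved under
# checkpoint_dir/__lookup_table__ back into the "emb" table.
load_persistable_vars(exe, "checkpoint_dir", sparse_program,
                      lookup_table_var="emb")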
python/paddle/fluid/framework.py
@@ -1698,6 +1698,7 @@ class Program(object):
        p._copy_param_info_from(self)
        p._copy_data_info_from(self)
+       p._copy_dist_param_info_from(self)
        return p

    def _prune(self, targets):
@@ -1938,6 +1939,25 @@ class Program(object):
                "program, with represent the same topology")
        self.global_block()._copy_param_info_from(other.global_block())
+   def _copy_dist_param_info_from(self, other):
+       """
+       Copy the distributed information from another program.
+
+       Args:
+           other(Program): Other program
+
+       Returns:
+           None
+       """
+       if not isinstance(other, Program):
+           raise TypeError("_copy_dist_param_info_from should be invoked with "
+                           "Program")
+       self._is_distributed = other._is_distributed
+       self._is_chief = other._is_chief
+       self._slice_vars_and_attrs = other._slice_vars_and_attrs
+       self._endpoints = other._endpoints
+       self._distributed_lookup_table = other._distributed_lookup_table
    def _copy_data_info_from(self, other):
        """
        Copy the information of data variables from other program.
......
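The distributed attributes live as plain Python attributes on `Program`, not in the serialized desc, which is why `clone()` used to drop them and why `save_inference_model` silently skipped the pserver notification. A minimal sketch of the repaired contract; the attribute values are illustrative, since the transpiler sets them in real jobs.

import paddle.fluid as fluid

main = fluid.default_main_program()
main._is_distributed = True        # normally set by DistributeTranspiler
main._is_chief = True
main._distributed_lookup_table = "emb"

# clone() now calls _copy_dist_param_info_from(self), so the clone
# keeps the metadata that save_vars needs to emit checkpoint_notify.
cloned = main.clone()
assert cloned._is_distributed and cloned._distributed_lookup_table == "emb"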
python/paddle/fluid/io.py
@@ -165,6 +165,7 @@ def save_vars(executor,
        save_vars(
            executor,
+           main_program=main_program,
            dirname=dirname,
            vars=list(filter(predicate, main_program.list_vars())),
            filename=filename)
@@ -172,11 +173,18 @@ def save_vars(executor,
        save_program = Program()
        save_block = save_program.global_block()

+       if main_program is None:
+           main_program = default_main_program()
+       if not isinstance(main_program, Program):
+           raise TypeError("program should be as Program type or None")
+
        save_var_map = {}
        for each_var in vars:
            # NOTE: don't save the variable which type is RAW
            if each_var.type == core.VarDesc.VarType.RAW:
                continue
+           if each_var.name == main_program._distributed_lookup_table:
+               continue
            new_var = _clone_var_in_block_(save_block, each_var)
            if filename is None:
                save_block.append_op(
@@ -198,6 +206,16 @@ def save_vars(executor,
                outputs={},
                attrs={'file_path': os.path.join(dirname, filename)})

+       # if there is a lookup table, trainer 0 will notify all pservers to save.
+       if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
+           lookup_table_filename = os.path.join(dirname, "__lookup_table__")
+           attrs = {}
+           attrs['epmap'] = main_program._endpoints
+           attrs['dir'] = lookup_table_filename
+           attrs['lookup_table'] = main_program._distributed_lookup_table
+
+           save_block.append_op(
+               type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
+
        executor.run(save_program)
@@ -379,11 +397,22 @@ def load_vars(executor,
        load_prog = Program()
        load_block = load_prog.global_block()

+       if main_program is None:
+           main_program = default_main_program()
+       if not isinstance(main_program, Program):
+           raise TypeError("program should be as Program type or None")
+
+       load_slice_vars = []
+       for each_var in main_program._slice_vars_and_attrs:
+           load_slice_vars.append(each_var[2].name)
+
        load_var_map = {}
        for each_var in vars:
            assert isinstance(each_var, Variable)
            if each_var.type == core.VarDesc.VarType.RAW:
                continue
+           if each_var.name in load_slice_vars:
+               continue
            new_var = _clone_var_in_block_(load_block, each_var)
            if filename is None:
                load_block.append_op(
@@ -406,9 +435,6 @@ def load_vars(executor,
                    attrs={'file_path': os.path.join(dirname, filename)})

        executor.run(load_prog)

-       if main_program is None:
-           main_program = default_main_program()
-
        # load slice vars on pserver, if have it.
        _load_slice_up_vars(executor, dirname,
                            main_program._slice_vars_and_attrs)
@@ -618,13 +644,6 @@ def save_inference_model(dirname,
    if main_program is None:
        main_program = default_main_program()

-   # if there is lookup table, the trainer 0 will notify all pserver to save.
-   if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
-       lookup_table_filename = os.path.join(dirname, "__lookup_table__")
-       _save_lookup_tables_by_notify(executor, lookup_table_filename,
-                                     main_program._distributed_lookup_table,
-                                     main_program._endpoints)
-
    # when a pserver and a trainer running on the same machine, mkdir may conflict
    try:
        os.makedirs(dirname)
@@ -642,6 +661,9 @@ def save_inference_model(dirname,
    # it can only be loaded for inference directly. If it's false, the whole
    # original program and related meta are saved so that future usage can be
    # more flexible.
+
+   origin_program = main_program.clone()
+
    if export_for_deployment:
        main_program = main_program.clone()
        global_block = main_program.global_block()
@@ -666,8 +688,11 @@ def save_inference_model(dirname,
    with open(model_basename + ".main_program", "wb") as f:
        f.write(main_program.desc.serialize_to_string())

+   main_program._copy_dist_param_info_from(origin_program)
+
    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    save_persistables(executor, dirname, main_program, params_filename)
@@ -897,6 +922,9 @@ def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs):
        slice_var = var_tuple[2]
        end = start + slice_var.shape[0]

+       orig_var_name = orig_var.name
+       orig_var.name = "{}.origin".format(orig_var_name)
+
        clone_orig_var = load_block.create_var(
            name=orig_var.name,
            type=orig_var.type,
@@ -915,7 +943,7 @@ def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs):
            type='load',
            inputs={},
            outputs={'Out': [clone_orig_var]},
-           attrs={'file_path': os.path.join(dirname, clone_orig_var.name)})
+           attrs={'file_path': os.path.join(dirname, orig_var_name)})
        load_block.append_op(
            type="slice",
            inputs={'Input': clone_orig_var},
@@ -924,6 +952,7 @@ def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs):
                   'starts': [start],
                   'ends': [end]})
        need_delete_vars.append(clone_orig_var)
+
    load_block.append_op(
        type='delete_var',
        inputs={'X': need_delete_vars}, )
......
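The `.origin` rename keeps the in-memory alias distinct while the `load` op still reads the file saved under the original name; the `slice` op then carves out this pserver's rows. The row arithmetic, in an illustrative numpy sketch with made-up shapes:

import numpy as np

# The full parameter as saved by the chief trainer (illustrative shape).
full_param = np.arange(40, dtype="float32").reshape(10, 4)

# Recorded per shard by the transpiler: rows owned by earlier shards,
# and this shard's own row count.
skip_dim0, shard_rows = 4, 3
start, end = skip_dim0, skip_dim0 + shard_rows  # matches end = start + slice_var.shape[0]

shard = full_param[start:end]  # what the slice op's starts/ends produce
assert shard.shape == (shard_rows, 4)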
python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -644,6 +644,9 @@ in a single call.")
            else:
                recv_inputs.append(single_trainer_var)

+       self._slice_params_and_optimizes = self._get_slice_vars_and_attrs(
+           endpoint)
+
        # step 3
        # Create a union-find data structure from optimize ops,
        # If two ops are connected, we could add these two ops
@@ -766,7 +769,7 @@ in a single call.")
                                                grad_to_block_id, merged_var,
                                                lr_ops)

        # dedup grad to ids list
        grad_to_block_id = list(set(grad_to_block_id))

        # append global ops
        if global_ops:
@@ -827,8 +830,8 @@ in a single call.")
                attrs=attrs)

        # add distributed attrs
-       pserver_program._slice_vars_and_attrs = self._get_slice_vars_and_attrs(
-           endpoint)
+       pserver_program._slice_vars_and_attrs = list(
+           self._slice_params_and_optimizes.values())

        pserver_program._sync_with_cpp()
        # save pserver program to generate pserver side startup relatively.
@@ -941,12 +944,12 @@ to transpile() call.")
            outputs={"Out": startup_tmpvar})

        # add slice vars
-       s_prog._slice_vars_and_attrs = self._get_slice_vars_and_attrs(endpoint)
+       s_prog._slice_vars_and_attrs = pserver_program._slice_vars_and_attrs

        return s_prog

    def _get_slice_vars_and_attrs(self, endpoint):
-       slice_vars_and_attrs = []
+       slice_vars_and_attrs = {}
        block_suffix = "block"
        for param in self.param_grad_ep_mapping[endpoint]["params"]:
            orig_var_name, block_name, _ = self._get_varname_parts(param.name)
@@ -960,8 +963,7 @@ to transpile() call.")
            slice_vars = self.param_var_mapping[orig_var_name]
            for slice_var in slice_vars[:block_idx]:
                skip_dim0 += slice_var.shape[0]
-           slice_vars_and_attrs.append([orig_var, skip_dim0, param])
+           slice_vars_and_attrs[param.name] = [orig_var, skip_dim0, param]
        return slice_vars_and_attrs

    # ====================== private transpiler functions =====================
@@ -1662,10 +1664,10 @@ to transpile() call.")
            if key in ["Param", "Grad", "LearningRate"]:
                continue
            var = self.origin_program.global_block().vars[opt_op.input(key)[0]]
+           param_var = new_inputs["Param"]
            # update accumulator variable shape
-           param_shape = new_inputs["Param"].shape
-           new_shape = self._get_optimizer_input_shape(opt_op.type, key,
-                                                       var.shape, param_shape)
+           new_shape = self._get_optimizer_input_shape(
+               opt_op.type, key, var.shape, param_var.shape)
            tmpvar = pserver_block.create_var(
                name=var.name,
                persistable=var.persistable,
@@ -1673,6 +1675,13 @@ to transpile() call.")
                shape=new_shape)
            new_inputs[key] = tmpvar
+           # var shape has been changed
+           if new_shape != var.shape:
+               slice_var_args = self._slice_params_and_optimizes[
+                   param_var.name]
+               self._slice_params_and_optimizes[
+                   var.name] = [var, slice_var_args[1], tmpvar]
+
            # change output's ParamOut variable
            outputs = self._get_output_map_from_op(
                self.origin_program.global_block().vars, opt_op)
......
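Switching `_get_slice_vars_and_attrs` from a list to a dict keyed by variable name is what lets the optimizer pass above register a resized accumulator under its own name while reusing the Param entry's row offset. An illustrative sketch of the resulting layout; strings stand in for `Variable` objects and the names are made up:

# Placeholders for Variable objects (illustrative names).
orig_emb, emb_shard = "emb.origin", "emb.block0"
orig_mom, mom_shard = "emb_moment.origin", "emb_moment.block0"

slice_params_and_optimizes = {
    # filled by _get_slice_vars_and_attrs: [orig_var, skip_dim0, slice_var]
    "emb.block0": [orig_emb, 4, emb_shard],
}

# Later, while appending optimizer ops: a resized accumulator is
# registered under its own name, reusing the Param entry's offset.
param_args = slice_params_and_optimizes["emb.block0"]
slice_params_and_optimizes["emb_moment"] = [orig_mom, param_args[1], mom_shard]

# What finally lands on the pserver program:
pserver_slice_vars_and_attrs = list(slice_params_and_optimizes.values())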