未验证 提交 d56a0a1b 编写于 作者: W wangguanqun 提交者: GitHub

fix bug in new the_one_ps (#39505)

* fix benchmark and communicator config

* fix bugs of the_one_ps

* multi program and fix bug in optimizer

* multi program in the_one_ps

* public commcontext
上级 a6abb6e7
......@@ -31,7 +31,8 @@ struct CommContext {
const std::vector<std::string> &origin_names, int id,
bool merge_add_ = true, bool is_sparse_ = true,
bool is_distributed_ = false, int table_id_ = -1,
bool is_tensor_table_ = false)
bool is_tensor_table_ = false, bool is_datanorm_table_ = false,
int64_t program_id_ = -1)
: var_name(name),
splited_varnames(names),
epmap(emap),
......@@ -42,7 +43,9 @@ struct CommContext {
is_sparse(is_sparse_),
is_distributed(is_distributed_),
table_id(table_id_),
is_tensor_table(is_tensor_table_) {}
program_id(program_id_),
is_tensor_table(is_tensor_table_),
is_datanorm_table(is_datanorm_table_) {}
CommContext(const CommContext &ctx) {
var_name = ctx.var_name;
......@@ -55,7 +58,9 @@ struct CommContext {
origin_varnames = ctx.origin_varnames;
is_distributed = ctx.is_distributed;
table_id = ctx.table_id;
program_id = ctx.program_id;
is_tensor_table = ctx.is_tensor_table;
is_datanorm_table = ctx.is_datanorm_table;
}
std::string print() const {
......@@ -78,7 +83,9 @@ struct CommContext {
ss << " is_sparse: " << is_sparse;
ss << " is_distributed: " << is_distributed << "\n";
ss << " table_id: " << table_id << "\n";
ss << " program_id: " << program_id << "\n";
ss << " is_tensor_table: " << is_tensor_table << "\n";
ss << " is_datanorm_table: " << is_datanorm_table << "\n";
return ss.str();
}
......@@ -93,7 +100,9 @@ struct CommContext {
bool is_sparse;
bool is_distributed;
int table_id;
int64_t program_id;
bool is_tensor_table;
bool is_datanorm_table;
};
} // namespace distributed
......
......@@ -103,11 +103,13 @@ void BindCommunicatorContext(py::module* m) {
py::init<const std::string&, const std::vector<std::string>&,
const std::vector<std::string>&, const std::vector<int64_t>&,
const std::vector<std::string>&, int, bool, bool, bool, int,
bool>())
bool, bool, int64_t>())
.def("var_name", [](const CommContext& self) { return self.var_name; })
.def("trainer_id",
[](const CommContext& self) { return self.trainer_id; })
.def("table_id", [](const CommContext& self) { return self.table_id; })
.def("program_id",
[](const CommContext& self) { return self.program_id; })
.def("split_varnames",
[](const CommContext& self) { return self.splited_varnames; })
.def("split_endpoints",
......@@ -122,6 +124,8 @@ void BindCommunicatorContext(py::module* m) {
[](const CommContext& self) { return self.origin_varnames; })
.def("is_tensor_table",
[](const CommContext& self) { return self.is_tensor_table; })
.def("is_datanorm_table",
[](const CommContext& self) { return self.is_datanorm_table; })
.def("__str__", [](const CommContext& self) { return self.print(); });
}
......
......@@ -46,7 +46,9 @@ class ParameterServerOptimizer(MetaOptimizerBase):
attrs['loss'] = loss
attrs['min_block_size'] = 81920
attrs['origin_main_program'] = loss.block.program
attrs['origin_main_programs'] = [loss.block.program]
attrs['origin_startup_program'] = startup_program
attrs['origin_startup_programs'] = [startup_program]
attrs['cloned_main'] = attrs['origin_main_program'].clone()
attrs['cloned_startup'] = attrs['origin_startup_program'].clone()
......
......@@ -560,9 +560,9 @@ class FakeInitOpsPass(PassBase):
return True
def _get_sparse_table_names(self, attrs):
dist_varnames = get_sparse_tablenames(attrs['origin_main_program'],
dist_varnames = get_sparse_tablenames(attrs['origin_main_programs'],
True)
sparse_varnames = get_sparse_tablenames(attrs['origin_main_program'],
sparse_varnames = get_sparse_tablenames(attrs['origin_main_programs'],
False)
return list(set(dist_varnames + sparse_varnames))
......
......@@ -24,8 +24,8 @@ from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.executor import Executor
from paddle.fluid.parallel_executor import ParallelExecutor
from paddle.fluid.framework import Variable, Parameter
from .runtime_base import RuntimeBase
from ..base.private_helper_function import wait_server_ready
from paddle.distributed.fleet.runtime.runtime_base import RuntimeBase
from paddle.distributed.fleet.base.private_helper_function import wait_server_ready
from paddle.fluid.communicator import Communicator, HeterClient
from google.protobuf import text_format
......@@ -39,8 +39,17 @@ def conv_indent(indent):
PSERVER_SAVE_SUFFIX = ".shard"
def parse_table_class(varname, o_main_program):
for op in o_main_program.global_block().ops:
def get_program_by_id(context, program_id):
programs = context["origin_main_programs"]
for i, program in enumerate(programs):
if id(program) == program_id:
return program, context["origin_startup_programs"][i]
return None, None
def parse_table_class(varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id)
for op in main_program.global_block().ops:
if not is_distributed_sparse_op(op) and not is_sparse_op(op):
continue
......@@ -53,9 +62,10 @@ def parse_table_class(varname, o_main_program):
return "MemorySparseTable"
def get_default_accessor_proto(accessor, varname, o_main_program):
def get_default_accessor_proto(accessor, varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id)
embedding_dim = 0
for var in o_main_program.list_vars():
for var in main_program.list_vars():
if var.name == varname:
embedding_dim = var.shape[1]
break
......@@ -123,9 +133,10 @@ def get_default_accessor_proto(accessor, varname, o_main_program):
sgd_param.adam.weight_bounds.extend([-10.0, 10.0])
def check_embedding_dim(accessor, varname, o_main_program):
def check_embedding_dim(accessor, varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id)
embedding_dim = 0
for var in o_main_program.list_vars():
for var in main_program.list_vars():
if var.name == varname:
embedding_dim = var.shape[1]
break
......@@ -172,6 +183,8 @@ class CommonAccessor:
self.dims = []
self.trainer_num = 0
self.sync = "false"
self.table_num = None
self.table_dim = None
self.initializers = []
self.opt_input_map = {}
self.opt_attr_map = {}
......@@ -192,6 +205,7 @@ class CommonAccessor:
opt_input_map["sum"] = [("Param", None)]
opt_input_map["naive_adagrad"] = [("Param", None), ("G2Sum", 1),
("LearningRate", 1)]
opt_input_map["summary"] = [("Param", None), ("SummaryDecayRate", 1)]
opt_attr_map = {}
opt_attr_map["sgd"] = []
......@@ -201,6 +215,7 @@ class CommonAccessor:
("epsilon", "f")]
opt_attr_map["adam_d2sum"] = [("beta1", "f"), ("beta2", "f"),
("epsilon", "f")]
opt_attr_map["summary"] = []
opt_init_map = {}
opt_init_map["gaussian_random"] = ["seed", "mean", "std"]
......@@ -212,8 +227,9 @@ class CommonAccessor:
self.opt_input_map = opt_input_map
self.opt_init_map = opt_init_map
def parse_entry(self, varname, o_main_program):
for op in o_main_program.global_block().ops:
def parse_entry(self, varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id)
for op in main_program.global_block().ops:
if not is_distributed_sparse_op(op) and not is_sparse_op(op):
continue
......@@ -243,23 +259,36 @@ class CommonAccessor:
attr_str = ""
origin_var_name = value_name
print("get_initializer_attr param name:", value_name)
for op in o_startup_program.global_block().ops:
if op.type in self.opt_init_map.keys(
) and origin_var_name == op.output("Out")[0]:
init_attr = [op.type]
print("get_initializer_attr op type:", op.type)
for attr in self.opt_init_map[op.type]:
print("get_initializer_attr opt_init_map attr:", attr)
init_attr.append(str(op.attr(attr)))
print("get_initializer_attr op attr:", str(op.attr(attr)))
attr_str = l_in.join(init_attr)
break
return attr_str
def parse_by_optimizer(self, grad_name, is_sparse, total_dims, context,
adam_d2sum):
main_program = context['origin_main_program']
startup_program = context['startup_main_program']
def parse_by_optimizer(self, ctx, context):
grad_name = ctx.origin_varnames()[0]
is_sparse = ctx.is_sparse()
size = ctx.sections()[0]
single_dim = ctx.sections()[1] if ctx.is_sparse() else 1
adam_d2sum = context["user_defined_strategy"].adam_d2sum
print("parse_by_optimizer table_id:{} is_datanorm:{}".format(
ctx.table_id(), ctx.is_datanorm_table()))
main_program, startup_program = get_program_by_id(context,
ctx.program_id())
pserver_id = get_role_id(context['role_maker'])
pserver_num = len(get_ps_endpoints(context['role_maker']))
optimizer_ops = get_optimize_ops(main_program)
print("the one ps optimizer_ops:", optimizer_ops)
print("the one ps parse_by_optimizer grad_name:", grad_name)
oop = None
for op in optimizer_ops:
......@@ -278,6 +307,8 @@ class CommonAccessor:
initializers = []
self.trainer_num = get_trainers(context['role_maker'])
self.table_num = size
self.table_dim = single_dim
if oop.type != 'adam' and adam_d2sum == True:
print('optimization algorithm is not adam, set adam_d2sum False')
......@@ -291,7 +322,11 @@ class CommonAccessor:
param_varnames = self.opt_input_map["naive_adagrad"]
attr_varnames = self.opt_attr_map["naive_adagrad"]
self.accessor_class = "sgd"
elif adam_d2sum:
elif ctx.is_datanorm_table():
param_varnames = self.opt_input_map["summary"]
attr_varnames = self.opt_attr_map["summary"]
self.accessor_class = "summary"
elif adam_d2sum and not is_sparse:
param_varnames = self.opt_input_map["adam_d2sum"]
attr_varnames = self.opt_attr_map["adam_d2sum"]
self.accessor_class = "adam_d2sum"
......@@ -306,10 +341,9 @@ class CommonAccessor:
#for dims
if shape is None:
if is_sparse:
shape = total_dims
shape = single_dim
else:
shape = self.get_shard(total_dims, pserver_num,
pserver_id)
shape = self.get_shard(size, pserver_num, pserver_id)
dims.append(shape)
#for initializers
......@@ -333,6 +367,27 @@ class CommonAccessor:
else:
initializer = "fill_constant&0"
initializers.append(initializer)
elif self.accessor_class == "summary":
#for dims
if shape is None:
if is_sparse:
shape = single_dim
else:
shape = self.get_shard(size, pserver_num, pserver_id)
dims.append(shape)
#for initializers
if formal_name == "Param":
param = main_program.global_block().vars[oop.input(
formal_name)[0]]
initializer = self.get_initializer_attr(param.name,
startup_program)
elif formal_name == "SummaryDecayRate":
initializer = "fill_constant&0.99999"
else:
initializer = "fill_constant&0"
initializers.append(initializer)
else:
if formal_name == "G2Sum":
dims.append(1)
......@@ -348,9 +403,9 @@ class CommonAccessor:
if shape is None:
if is_sparse:
shape = total_dims
shape = single_dim
else:
shape = self.get_shard(total_dims, pserver_num,
shape = self.get_shard(size, pserver_num,
pserver_id)
dims.append(shape)
......@@ -379,6 +434,10 @@ class CommonAccessor:
attrs += "entry: \"{}\" ".format(self.entry)
attrs += "trainer_num: {} ".format(self.trainer_num)
attrs += "sync: {} ".format(self.sync)
if self.table_num:
attrs += "table_num: {} ".format(self.table_num)
if self.table_dim:
attrs += "table_dim: {} ".format(self.table_dim)
for param in self.params:
attrs += "params: \"{}\" ".format(param)
......@@ -448,10 +507,7 @@ class Table:
accessor_str = accessor_str.format(
conv_indent(indent), self.accessor_proto, conv_indent(indent))
attrs += accessor_str + "\n"
return table_str.format(
conv_indent(indent), attrs, conv_indent(indent))
if self.accessor is not None:
elif self.accessor is not None:
attrs += self.accessor.to_string(indent)
attrs += "\n"
......@@ -607,7 +663,9 @@ class TheOnePSRuntime(RuntimeBase):
def _set_basic_info(self, context):
self.context = context
self.role_maker = context["role_maker"]
self.origin_main_program = context["origin_main_program"]
self.origin_main_programs = context["origin_main_programs"]
self.context[
'is_heter_ps_mode'] = self.role_maker._is_heter_parameter_server_mode
......@@ -615,10 +673,13 @@ class TheOnePSRuntime(RuntimeBase):
self.context['trainer'] = TrainerRuntimeConfig(context[
'valid_strategy'])
self.context['ps_mode'] = self.context['trainer'].mode
self.context['use_ps_gpu'] = context['valid_strategy'].use_ps_gpu
self.context['use_ps_gpu'] = context['valid_strategy'].a_sync_configs[
'use_ps_gpu']
self.is_sync = True if self.context[
'ps_mode'] == DistributedMode.SYNC else False
self.context['grad_name_to_param_name'] = {}
self.context['tensor_table'] = {}
build_var_distributed(self.context)
def _init_worker(self):
worker = self._get_fleet_proto(is_server=False, is_sync=self.is_sync)
......@@ -689,6 +750,7 @@ class TheOnePSRuntime(RuntimeBase):
sync_kwargs = sync_strategy_envs()
kwargs.update(sync_kwargs)
print("communicator config:", trainer_config.get_communicator_flags())
self._communicator = Communicator(
trainer_config.mode, kwargs,
trainer_config.get_communicator_flags())
......@@ -893,7 +955,7 @@ class TheOnePSRuntime(RuntimeBase):
common.table_name = self.context['grad_name_to_param_name'][
ctx.origin_varnames()[0]]
if self.ps_mode == DistributedMode.GEO:
if self.context['ps_mode'] == DistributedMode.GEO:
table.table_class = "SparseGeoTable"
else:
all_table_proto = self.context[
......@@ -907,7 +969,8 @@ class TheOnePSRuntime(RuntimeBase):
table.table_class = table_proto.table_class
else:
table.table_class = parse_table_class(
common.table_name, self.origin_main_program)
common.table_name,
ctx.program_id(), self.context)
if table.table_class != 'MemorySparseTable':
table.table_class = 'MemorySparseTable'
warnings.warn(
......@@ -925,12 +988,12 @@ class TheOnePSRuntime(RuntimeBase):
warnings.warn(
"The accessor of sparse table is not set, use default value."
)
get_default_accessor_proto(table_proto.accessor,
common.table_name,
self.origin_main_program)
get_default_accessor_proto(
table_proto.accessor, common.table_name,
ctx.program_id(), self.context)
check_embedding_dim(table_proto.accessor,
common.table_name,
self.origin_main_program)
ctx.program_id(), self.context)
table.accessor_proto = text_format.MessageToString(
table_proto.accessor)
else:
......@@ -940,15 +1003,11 @@ class TheOnePSRuntime(RuntimeBase):
common.table_name = "MergedDense"
adam_d2sum = self.context["user_defined_strategy"].adam_d2sum
common.parse_by_optimizer(ctx.origin_varnames()[0],
ctx.is_sparse(),
ctx.sections()[1] if ctx.is_sparse()
else ctx.sections()[0], self.context,
adam_d2sum)
common.parse_by_optimizer(ctx, self.context)
if ctx.is_sparse():
common.parse_entry(common.table_name,
self.origin_main_program)
ctx.program_id(), self.context)
if is_sync:
common.sync = "true"
......@@ -1023,8 +1082,9 @@ class TheOnePSRuntime(RuntimeBase):
self._server.init_server(proto_txt, string_hosts, role_id, trainers,
self._server_sub_program)
dist_varnames = get_sparse_tablenames(self.origin_main_program, True)
sparse_varnames = get_sparse_tablenames(self.origin_main_program, False)
dist_varnames = get_sparse_tablenames(self.origin_main_programs, True)
sparse_varnames = get_sparse_tablenames(self.origin_main_programs,
False)
distributed_varnames = dist_varnames + sparse_varnames
......@@ -1070,6 +1130,7 @@ class TheOnePSRuntime(RuntimeBase):
if var.name in exclude_var_names:
return False
from .utils.public import _get_varname_parts
origin_varname, _, _ = _get_varname_parts(var.name)
if origin_varname.endswith("@GRAD"):
return False
......@@ -1085,16 +1146,24 @@ class TheOnePSRuntime(RuntimeBase):
return is_valid
def _get_inference_model_path(self, dirname):
if dirname.startswith("afs:") or dirname.startswith("hdfs:"):
model_path = "./dnn_plugin"
else:
model_path = os.path.join(dirname, "dnn_plugin")
return model_path
def _save_sparse_params(self, executor, dirname, context, main_program,
mode):
distributed_varnames = get_sparse_tablenames(
self.context['origin_main_program'], True)
distributed_varnames = get_sparse_tablenames(self.origin_main_programs,
True)
values = []
model_path = self._get_inference_model_path(dirname)
for id, names in context.items():
if names[0] not in distributed_varnames:
# only save sparse param to local
try:
self._worker.recv_and_save_model(id, dirname)
self._worker.recv_and_save_model(id, model_path)
except:
pass
# save sparse & distributed param on server
......@@ -1221,10 +1290,7 @@ class TheOnePSRuntime(RuntimeBase):
infer_program._copy_dist_param_info_from(program)
if dirname.startswith("afs:") or dirname.startswith("hdfs:"):
model_path = "./dnn_plugin"
else:
model_path = os.path.join(dirname, "dnn_plugin")
model_path = self._get_inference_model_path(dirname)
model_basename = "__model__"
model_basename = os.path.join(model_path, model_basename)
paddle.save(infer_program, model_basename)
......@@ -1266,7 +1332,7 @@ class TheOnePSRuntime(RuntimeBase):
self._ps_inference_save_persistables(*args, **kwargs)
def _load_sparse_params(self, dirname, context, main_program, mode):
distributed_varnames = get_sparse_tablenames(self.origin_main_program,
distributed_varnames = get_sparse_tablenames(self.origin_main_programs,
True)
values = []
for id, names in context.items():
......
......@@ -79,7 +79,7 @@ class GeoPsProgramBuilder(PsProgramBuilder): # 仅 CPU 模式
super(GeoPsProgramBuilder, self).__init__(pass_ctx)
if self.ps_mode != DistributedMode.GEO:
raise ValueError("ps mode: {} not matched {}",
format(ps_mode, "GeoPsProgramBuilder"))
format(self.ps_mode, "GeoPsProgramBuilder"))
def _build_trainer_programs(self):
append_send_ops_pass = new_pass("append_send_ops_pass", self.attrs)
......@@ -97,9 +97,9 @@ class CpuSyncPsProgramBuilder(PsProgramBuilder):
def __init__(self, pass_ctx):
logger.info("start building cpu-sync-ps program")
super(CpuSyncPsProgramBuilder, self).__init__(pass_ctx)
if self.ps_mode != DistributedMode.SYNC:
if self.ps_mode != DistributedMode.SYNC and self.ps_mode != DistributedMode.ASYNC:
raise ValueError("ps mode: {} not matched {}",
format(ps_mode, "CpuSyncPsProgramBuilder"))
format(self.ps_mode, "CpuSyncPsProgramBuilder"))
def _build_trainer_programs(self):
add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass",
......@@ -178,7 +178,7 @@ class HeterAsyncPsProgramBuilder(PsProgramBuilder):
if self.use_ps_gpu or self.ps_mode == DistributedMode.GEO or self.attrs[
'is_heter_ps_mode'] == False:
raise ValueError("ps mode: {} not matched {}",
format(ps_mode, "HeterAsyncPsProgramBuilder"))
format(self.ps_mode, "HeterAsyncPsProgramBuilder"))
def _build_trainer_programs(self):
add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass",
......
......@@ -577,7 +577,7 @@ class CompileTimeStrategy(object):
sparse_ctx = CommContext(grad_name, [grad_name],
["127.0.0.1:6071"], [var_numel],
[grad_name], trainer_id, True, True,
is_distributed, idx, False)
is_distributed, idx, False, False, -1)
idx += 1
send_ctx[sparse_ctx.var_name()] = sparse_ctx
......@@ -615,7 +615,8 @@ class CompileTimeStrategy(object):
aggregate = True
dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"],
[var_numel], origin_varnames, trainer_id,
aggregate, False, False, idx, False)
aggregate, False, False, idx, False, False,
-1)
send_ctx[grad_name] = dense_ctx
idx += 1
else:
......@@ -630,7 +631,7 @@ class CompileTimeStrategy(object):
dense_ctx = CommContext(grad_name, [grad_name],
["127.0.0.1:6071"], [var_numel],
[origin_varname], trainer_id, aggregate,
False, False, idx, False)
False, False, idx, False, False, -1)
send_ctx[grad_name] = dense_ctx
idx += 1
return idx
......@@ -672,7 +673,7 @@ class CompileTimeStrategy(object):
sparse_ctx = CommContext(grad_name, splited_varname, ep_list, shape,
[grad_name], trainer_id, True, True,
is_distributed, idx, False)
is_distributed, idx, False, False, -1)
idx += 1
send_ctx[sparse_ctx.var_name()] = sparse_ctx
......@@ -750,7 +751,7 @@ class CompileTimeStrategy(object):
sections = [1] * len(endpoints)
names = [name] * len(endpoints)
ctx = CommContext(name, names, endpoints, sections, [name], trainer_id,
True, False, False, idx, True)
True, False, False, idx, True, False, -1)
return name, ctx
def _create_vars_from_blocklist(self, block_list):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册