未验证 提交 d56a0a1b 编写于 作者: W wangguanqun 提交者: GitHub

fix bug in new the_one_ps (#39505)

* fix benchmark and communicator config

* fix bugs of the_one_ps

* multi program and fix bug in optimizer

* multi program in the_one_ps

* public commcontext
上级 a6abb6e7
......@@ -31,7 +31,8 @@ struct CommContext {
const std::vector<std::string> &origin_names, int id,
bool merge_add_ = true, bool is_sparse_ = true,
bool is_distributed_ = false, int table_id_ = -1,
bool is_tensor_table_ = false)
bool is_tensor_table_ = false, bool is_datanorm_table_ = false,
int64_t program_id_ = -1)
: var_name(name),
splited_varnames(names),
epmap(emap),
......@@ -42,7 +43,9 @@ struct CommContext {
is_sparse(is_sparse_),
is_distributed(is_distributed_),
table_id(table_id_),
is_tensor_table(is_tensor_table_) {}
program_id(program_id_),
is_tensor_table(is_tensor_table_),
is_datanorm_table(is_datanorm_table_) {}
CommContext(const CommContext &ctx) {
var_name = ctx.var_name;
......@@ -55,7 +58,9 @@ struct CommContext {
origin_varnames = ctx.origin_varnames;
is_distributed = ctx.is_distributed;
table_id = ctx.table_id;
program_id = ctx.program_id;
is_tensor_table = ctx.is_tensor_table;
is_datanorm_table = ctx.is_datanorm_table;
}
std::string print() const {
......@@ -78,7 +83,9 @@ struct CommContext {
ss << " is_sparse: " << is_sparse;
ss << " is_distributed: " << is_distributed << "\n";
ss << " table_id: " << table_id << "\n";
ss << " program_id: " << program_id << "\n";
ss << " is_tensor_table: " << is_tensor_table << "\n";
ss << " is_datanorm_table: " << is_datanorm_table << "\n";
return ss.str();
}
......@@ -93,7 +100,9 @@ struct CommContext {
bool is_sparse;
bool is_distributed;
int table_id;
int64_t program_id;
bool is_tensor_table;
bool is_datanorm_table;
};
} // namespace distributed
......
......@@ -103,11 +103,13 @@ void BindCommunicatorContext(py::module* m) {
py::init<const std::string&, const std::vector<std::string>&,
const std::vector<std::string>&, const std::vector<int64_t>&,
const std::vector<std::string>&, int, bool, bool, bool, int,
bool>())
bool, bool, int64_t>())
.def("var_name", [](const CommContext& self) { return self.var_name; })
.def("trainer_id",
[](const CommContext& self) { return self.trainer_id; })
.def("table_id", [](const CommContext& self) { return self.table_id; })
.def("program_id",
[](const CommContext& self) { return self.program_id; })
.def("split_varnames",
[](const CommContext& self) { return self.splited_varnames; })
.def("split_endpoints",
......@@ -122,6 +124,8 @@ void BindCommunicatorContext(py::module* m) {
[](const CommContext& self) { return self.origin_varnames; })
.def("is_tensor_table",
[](const CommContext& self) { return self.is_tensor_table; })
.def("is_datanorm_table",
[](const CommContext& self) { return self.is_datanorm_table; })
.def("__str__", [](const CommContext& self) { return self.print(); });
}
......
......@@ -46,7 +46,9 @@ class ParameterServerOptimizer(MetaOptimizerBase):
attrs['loss'] = loss
attrs['min_block_size'] = 81920
attrs['origin_main_program'] = loss.block.program
attrs['origin_main_programs'] = [loss.block.program]
attrs['origin_startup_program'] = startup_program
attrs['origin_startup_programs'] = [startup_program]
attrs['cloned_main'] = attrs['origin_main_program'].clone()
attrs['cloned_startup'] = attrs['origin_startup_program'].clone()
......
......@@ -560,9 +560,9 @@ class FakeInitOpsPass(PassBase):
return True
def _get_sparse_table_names(self, attrs):
dist_varnames = get_sparse_tablenames(attrs['origin_main_program'],
dist_varnames = get_sparse_tablenames(attrs['origin_main_programs'],
True)
sparse_varnames = get_sparse_tablenames(attrs['origin_main_program'],
sparse_varnames = get_sparse_tablenames(attrs['origin_main_programs'],
False)
return list(set(dist_varnames + sparse_varnames))
......
......@@ -24,8 +24,8 @@ from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.executor import Executor
from paddle.fluid.parallel_executor import ParallelExecutor
from paddle.fluid.framework import Variable, Parameter
from .runtime_base import RuntimeBase
from ..base.private_helper_function import wait_server_ready
from paddle.distributed.fleet.runtime.runtime_base import RuntimeBase
from paddle.distributed.fleet.base.private_helper_function import wait_server_ready
from paddle.fluid.communicator import Communicator, HeterClient
from google.protobuf import text_format
......@@ -39,8 +39,17 @@ def conv_indent(indent):
PSERVER_SAVE_SUFFIX = ".shard"
def parse_table_class(varname, o_main_program):
for op in o_main_program.global_block().ops:
def get_program_by_id(context, program_id):
programs = context["origin_main_programs"]
for i, program in enumerate(programs):
if id(program) == program_id:
return program, context["origin_startup_programs"][i]
return None, None
def parse_table_class(varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id)
for op in main_program.global_block().ops:
if not is_distributed_sparse_op(op) and not is_sparse_op(op):
continue
......@@ -53,9 +62,10 @@ def parse_table_class(varname, o_main_program):
return "MemorySparseTable"
def get_default_accessor_proto(accessor, varname, o_main_program):
def get_default_accessor_proto(accessor, varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id)
embedding_dim = 0
for var in o_main_program.list_vars():
for var in main_program.list_vars():
if var.name == varname:
embedding_dim = var.shape[1]
break
......@@ -123,9 +133,10 @@ def get_default_accessor_proto(accessor, varname, o_main_program):
sgd_param.adam.weight_bounds.extend([-10.0, 10.0])
def check_embedding_dim(accessor, varname, o_main_program):
def check_embedding_dim(accessor, varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id)
embedding_dim = 0
for var in o_main_program.list_vars():
for var in main_program.list_vars():
if var.name == varname:
embedding_dim = var.shape[1]
break
......@@ -172,6 +183,8 @@ class CommonAccessor:
self.dims = []
self.trainer_num = 0
self.sync = "false"
self.table_num = None
self.table_dim = None
self.initializers = []
self.opt_input_map = {}
self.opt_attr_map = {}
......@@ -192,6 +205,7 @@ class CommonAccessor:
opt_input_map["sum"] = [("Param", None)]
opt_input_map["naive_adagrad"] = [("Param", None), ("G2Sum", 1),
("LearningRate", 1)]
opt_input_map["summary"] = [("Param", None), ("SummaryDecayRate", 1)]
opt_attr_map = {}
opt_attr_map["sgd"] = []
......@@ -201,6 +215,7 @@ class CommonAccessor:
("epsilon", "f")]
opt_attr_map["adam_d2sum"] = [("beta1", "f"), ("beta2", "f"),
("epsilon", "f")]
opt_attr_map["summary"] = []
opt_init_map = {}
opt_init_map["gaussian_random"] = ["seed", "mean", "std"]
......@@ -212,8 +227,9 @@ class CommonAccessor:
self.opt_input_map = opt_input_map
self.opt_init_map = opt_init_map
def parse_entry(self, varname, o_main_program):
for op in o_main_program.global_block().ops:
def parse_entry(self, varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id)
for op in main_program.global_block().ops:
if not is_distributed_sparse_op(op) and not is_sparse_op(op):
continue
......@@ -243,23 +259,36 @@ class CommonAccessor:
attr_str = ""
origin_var_name = value_name
print("get_initializer_attr param name:", value_name)
for op in o_startup_program.global_block().ops:
if op.type in self.opt_init_map.keys(
) and origin_var_name == op.output("Out")[0]:
init_attr = [op.type]
print("get_initializer_attr op type:", op.type)
for attr in self.opt_init_map[op.type]:
print("get_initializer_attr opt_init_map attr:", attr)
init_attr.append(str(op.attr(attr)))
print("get_initializer_attr op attr:", str(op.attr(attr)))
attr_str = l_in.join(init_attr)
break
return attr_str
def parse_by_optimizer(self, grad_name, is_sparse, total_dims, context,
adam_d2sum):
main_program = context['origin_main_program']
startup_program = context['startup_main_program']
def parse_by_optimizer(self, ctx, context):
grad_name = ctx.origin_varnames()[0]
is_sparse = ctx.is_sparse()
size = ctx.sections()[0]
single_dim = ctx.sections()[1] if ctx.is_sparse() else 1
adam_d2sum = context["user_defined_strategy"].adam_d2sum
print("parse_by_optimizer table_id:{} is_datanorm:{}".format(
ctx.table_id(), ctx.is_datanorm_table()))
main_program, startup_program = get_program_by_id(context,
ctx.program_id())
pserver_id = get_role_id(context['role_maker'])
pserver_num = len(get_ps_endpoints(context['role_maker']))
optimizer_ops = get_optimize_ops(main_program)
print("the one ps optimizer_ops:", optimizer_ops)
print("the one ps parse_by_optimizer grad_name:", grad_name)
oop = None
for op in optimizer_ops:
......@@ -278,6 +307,8 @@ class CommonAccessor:
initializers = []
self.trainer_num = get_trainers(context['role_maker'])
self.table_num = size
self.table_dim = single_dim
if oop.type != 'adam' and adam_d2sum == True:
print('optimization algorithm is not adam, set adam_d2sum False')
......@@ -291,7 +322,11 @@ class CommonAccessor:
param_varnames = self.opt_input_map["naive_adagrad"]
attr_varnames = self.opt_attr_map["naive_adagrad"]
self.accessor_class = "sgd"
elif adam_d2sum:
elif ctx.is_datanorm_table():
param_varnames = self.opt_input_map["summary"]
attr_varnames = self.opt_attr_map["summary"]
self.accessor_class = "summary"
elif adam_d2sum and not is_sparse:
param_varnames = self.opt_input_map["adam_d2sum"]
attr_varnames = self.opt_attr_map["adam_d2sum"]
self.accessor_class = "adam_d2sum"
......@@ -306,10 +341,9 @@ class CommonAccessor:
#for dims
if shape is None:
if is_sparse:
shape = total_dims
shape = single_dim
else:
shape = self.get_shard(total_dims, pserver_num,
pserver_id)
shape = self.get_shard(size, pserver_num, pserver_id)
dims.append(shape)
#for initializers
......@@ -333,6 +367,27 @@ class CommonAccessor:
else:
initializer = "fill_constant&0"
initializers.append(initializer)
elif self.accessor_class == "summary":
#for dims
if shape is None:
if is_sparse:
shape = single_dim
else:
shape = self.get_shard(size, pserver_num, pserver_id)
dims.append(shape)
#for initializers
if formal_name == "Param":
param = main_program.global_block().vars[oop.input(
formal_name)[0]]
initializer = self.get_initializer_attr(param.name,
startup_program)
elif formal_name == "SummaryDecayRate":
initializer = "fill_constant&0.99999"
else:
initializer = "fill_constant&0"
initializers.append(initializer)
else:
if formal_name == "G2Sum":
dims.append(1)
......@@ -348,9 +403,9 @@ class CommonAccessor:
if shape is None:
if is_sparse:
shape = total_dims
shape = single_dim
else:
shape = self.get_shard(total_dims, pserver_num,
shape = self.get_shard(size, pserver_num,
pserver_id)
dims.append(shape)
......@@ -379,6 +434,10 @@ class CommonAccessor:
attrs += "entry: \"{}\" ".format(self.entry)
attrs += "trainer_num: {} ".format(self.trainer_num)
attrs += "sync: {} ".format(self.sync)
if self.table_num:
attrs += "table_num: {} ".format(self.table_num)
if self.table_dim:
attrs += "table_dim: {} ".format(self.table_dim)
for param in self.params:
attrs += "params: \"{}\" ".format(param)
......@@ -448,10 +507,7 @@ class Table:
accessor_str = accessor_str.format(
conv_indent(indent), self.accessor_proto, conv_indent(indent))
attrs += accessor_str + "\n"
return table_str.format(
conv_indent(indent), attrs, conv_indent(indent))
if self.accessor is not None:
elif self.accessor is not None:
attrs += self.accessor.to_string(indent)
attrs += "\n"
......@@ -607,7 +663,9 @@ class TheOnePSRuntime(RuntimeBase):
def _set_basic_info(self, context):
self.context = context
self.role_maker = context["role_maker"]
self.origin_main_program = context["origin_main_program"]
self.origin_main_programs = context["origin_main_programs"]
self.context[
'is_heter_ps_mode'] = self.role_maker._is_heter_parameter_server_mode
......@@ -615,10 +673,13 @@ class TheOnePSRuntime(RuntimeBase):
self.context['trainer'] = TrainerRuntimeConfig(context[
'valid_strategy'])
self.context['ps_mode'] = self.context['trainer'].mode
self.context['use_ps_gpu'] = context['valid_strategy'].use_ps_gpu
self.context['use_ps_gpu'] = context['valid_strategy'].a_sync_configs[
'use_ps_gpu']
self.is_sync = True if self.context[
'ps_mode'] == DistributedMode.SYNC else False
self.context['grad_name_to_param_name'] = {}
self.context['tensor_table'] = {}
build_var_distributed(self.context)
def _init_worker(self):
worker = self._get_fleet_proto(is_server=False, is_sync=self.is_sync)
......@@ -689,6 +750,7 @@ class TheOnePSRuntime(RuntimeBase):
sync_kwargs = sync_strategy_envs()
kwargs.update(sync_kwargs)
print("communicator config:", trainer_config.get_communicator_flags())
self._communicator = Communicator(
trainer_config.mode, kwargs,
trainer_config.get_communicator_flags())
......@@ -893,7 +955,7 @@ class TheOnePSRuntime(RuntimeBase):
common.table_name = self.context['grad_name_to_param_name'][
ctx.origin_varnames()[0]]
if self.ps_mode == DistributedMode.GEO:
if self.context['ps_mode'] == DistributedMode.GEO:
table.table_class = "SparseGeoTable"
else:
all_table_proto = self.context[
......@@ -907,7 +969,8 @@ class TheOnePSRuntime(RuntimeBase):
table.table_class = table_proto.table_class
else:
table.table_class = parse_table_class(
common.table_name, self.origin_main_program)
common.table_name,
ctx.program_id(), self.context)
if table.table_class != 'MemorySparseTable':
table.table_class = 'MemorySparseTable'
warnings.warn(
......@@ -925,12 +988,12 @@ class TheOnePSRuntime(RuntimeBase):
warnings.warn(
"The accessor of sparse table is not set, use default value."
)
get_default_accessor_proto(table_proto.accessor,
common.table_name,
self.origin_main_program)
get_default_accessor_proto(
table_proto.accessor, common.table_name,
ctx.program_id(), self.context)
check_embedding_dim(table_proto.accessor,
common.table_name,
self.origin_main_program)
ctx.program_id(), self.context)
table.accessor_proto = text_format.MessageToString(
table_proto.accessor)
else:
......@@ -940,15 +1003,11 @@ class TheOnePSRuntime(RuntimeBase):
common.table_name = "MergedDense"
adam_d2sum = self.context["user_defined_strategy"].adam_d2sum
common.parse_by_optimizer(ctx.origin_varnames()[0],
ctx.is_sparse(),
ctx.sections()[1] if ctx.is_sparse()
else ctx.sections()[0], self.context,
adam_d2sum)
common.parse_by_optimizer(ctx, self.context)
if ctx.is_sparse():
common.parse_entry(common.table_name,
self.origin_main_program)
ctx.program_id(), self.context)
if is_sync:
common.sync = "true"
......@@ -1023,8 +1082,9 @@ class TheOnePSRuntime(RuntimeBase):
self._server.init_server(proto_txt, string_hosts, role_id, trainers,
self._server_sub_program)
dist_varnames = get_sparse_tablenames(self.origin_main_program, True)
sparse_varnames = get_sparse_tablenames(self.origin_main_program, False)
dist_varnames = get_sparse_tablenames(self.origin_main_programs, True)
sparse_varnames = get_sparse_tablenames(self.origin_main_programs,
False)
distributed_varnames = dist_varnames + sparse_varnames
......@@ -1070,6 +1130,7 @@ class TheOnePSRuntime(RuntimeBase):
if var.name in exclude_var_names:
return False
from .utils.public import _get_varname_parts
origin_varname, _, _ = _get_varname_parts(var.name)
if origin_varname.endswith("@GRAD"):
return False
......@@ -1085,16 +1146,24 @@ class TheOnePSRuntime(RuntimeBase):
return is_valid
def _get_inference_model_path(self, dirname):
if dirname.startswith("afs:") or dirname.startswith("hdfs:"):
model_path = "./dnn_plugin"
else:
model_path = os.path.join(dirname, "dnn_plugin")
return model_path
def _save_sparse_params(self, executor, dirname, context, main_program,
mode):
distributed_varnames = get_sparse_tablenames(
self.context['origin_main_program'], True)
distributed_varnames = get_sparse_tablenames(self.origin_main_programs,
True)
values = []
model_path = self._get_inference_model_path(dirname)
for id, names in context.items():
if names[0] not in distributed_varnames:
# only save sparse param to local
try:
self._worker.recv_and_save_model(id, dirname)
self._worker.recv_and_save_model(id, model_path)
except:
pass
# save sparse & distributed param on server
......@@ -1221,10 +1290,7 @@ class TheOnePSRuntime(RuntimeBase):
infer_program._copy_dist_param_info_from(program)
if dirname.startswith("afs:") or dirname.startswith("hdfs:"):
model_path = "./dnn_plugin"
else:
model_path = os.path.join(dirname, "dnn_plugin")
model_path = self._get_inference_model_path(dirname)
model_basename = "__model__"
model_basename = os.path.join(model_path, model_basename)
paddle.save(infer_program, model_basename)
......@@ -1266,7 +1332,7 @@ class TheOnePSRuntime(RuntimeBase):
self._ps_inference_save_persistables(*args, **kwargs)
def _load_sparse_params(self, dirname, context, main_program, mode):
distributed_varnames = get_sparse_tablenames(self.origin_main_program,
distributed_varnames = get_sparse_tablenames(self.origin_main_programs,
True)
values = []
for id, names in context.items():
......
......@@ -79,7 +79,7 @@ class GeoPsProgramBuilder(PsProgramBuilder): # 仅 CPU 模式
super(GeoPsProgramBuilder, self).__init__(pass_ctx)
if self.ps_mode != DistributedMode.GEO:
raise ValueError("ps mode: {} not matched {}",
format(ps_mode, "GeoPsProgramBuilder"))
format(self.ps_mode, "GeoPsProgramBuilder"))
def _build_trainer_programs(self):
append_send_ops_pass = new_pass("append_send_ops_pass", self.attrs)
......@@ -97,9 +97,9 @@ class CpuSyncPsProgramBuilder(PsProgramBuilder):
def __init__(self, pass_ctx):
logger.info("start building cpu-sync-ps program")
super(CpuSyncPsProgramBuilder, self).__init__(pass_ctx)
if self.ps_mode != DistributedMode.SYNC:
if self.ps_mode != DistributedMode.SYNC and self.ps_mode != DistributedMode.ASYNC:
raise ValueError("ps mode: {} not matched {}",
format(ps_mode, "CpuSyncPsProgramBuilder"))
format(self.ps_mode, "CpuSyncPsProgramBuilder"))
def _build_trainer_programs(self):
add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass",
......@@ -178,7 +178,7 @@ class HeterAsyncPsProgramBuilder(PsProgramBuilder):
if self.use_ps_gpu or self.ps_mode == DistributedMode.GEO or self.attrs[
'is_heter_ps_mode'] == False:
raise ValueError("ps mode: {} not matched {}",
format(ps_mode, "HeterAsyncPsProgramBuilder"))
format(self.ps_mode, "HeterAsyncPsProgramBuilder"))
def _build_trainer_programs(self):
add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass",
......
......@@ -54,6 +54,9 @@ SPARSE_GRAD_OP_TYPE_DICT = {
}
DEFAULT_DEVICE = 'cpu'
DATA_NORM_NAME = [".batch_size", ".batch_sum", ".batch_square_sum"]
DATA_NORM_GRAD_NAME = [x + "@GRAD" for x in DATA_NORM_NAME]
def logger_config(log_path, logging_name):
logger = logging.getLogger(logging_name)
......@@ -84,6 +87,8 @@ class DistributedMode:
class TrainerRuntimeConfig(object):
def __init__(self, valid_strategy):
self.mode = None
num_threads = os.getenv("CPU_NUM", "1")
send_queue_size = num_threads
k_steps = valid_strategy.a_sync_configs["k_steps"]
logger.info("ps mode in strategy: {}, {}".format(
valid_strategy.a_sync, valid_strategy.a_sync_configs["k_steps"]))
......@@ -95,14 +100,13 @@ class TrainerRuntimeConfig(object):
if valid_strategy.a_sync and k_steps > 0:
self.mode = DistributedMode.GEO
num_threads = os.getenv("CPU_NUM", "1")
send_queue_size = k_steps
self.runtime_configs = {}
self.runtime_configs['communicator_max_merge_var_num'] = os.getenv(
"FLAGS_communicator_max_merge_var_num", num_threads)
"FLAGS_communicator_max_merge_var_num", send_queue_size)
self.runtime_configs['communicator_send_queue_size'] = os.getenv(
"FLAGS_communicator_send_queue_size", num_threads)
"FLAGS_communicator_send_queue_size", send_queue_size)
self.runtime_configs[
'communicator_independent_recv_thread'] = os.getenv(
"FLAGS_communicator_independent_recv_thread", "1")
......@@ -116,6 +120,55 @@ class TrainerRuntimeConfig(object):
self.runtime_configs['communicator_is_sgd_optimizer'] = os.getenv(
"FLAGS_communicator_is_sgd_optimizer", "1")
def get_communicator_flags(self):
need_keys = []
num_threads = os.getenv("CPU_NUM", "1")
mode_str = ""
if self.mode is None or self.mode == DistributedMode.ASYNC:
need_keys = self.runtime_configs.keys()
mode_str = "async"
elif self.mode == DistributedMode.SYNC or self.mode == DistributedMode.HALF_ASYNC:
mode_str = "sync or half_async"
need_keys = [
'communicator_max_merge_var_num',
'communicator_send_wait_times', 'communicator_thread_pool_size',
'communicator_send_queue_size'
]
elif self.mode == DistributedMode.GEO:
mode_str = "GEO"
need_keys = [
'communicator_thread_pool_size', 'communicator_send_wait_times',
'communicator_max_merge_var_num', 'communicator_send_queue_size'
]
else:
raise ValueError("Unsupported Mode")
if self.mode == DistributedMode.SYNC or self.mode == DistributedMode.HALF_ASYNC:
max_merge_var_num = self.runtime_configs[
'communicator_max_merge_var_num']
send_queue_size = self.runtime_configs[
'communicator_send_queue_size']
if max_merge_var_num != num_threads:
print('WARNING: In {} mode, communicator_max_merge_var_num '
'must be equal to CPU_NUM. But received, '
'communicator_max_merge_var_num = {}, CPU_NUM = '
'{}. communicator_max_merge_var_num will be forced to {}.'
.format(mode_str, max_merge_var_num, num_threads,
num_threads))
self.runtime_configs[
'communicator_max_merge_var_num'] = num_threads
if send_queue_size != num_threads:
print('WARNING: In {} mode, communicator_send_queue_size '
'must be equal to CPU_NUM. But received, '
'communicator_send_queue_size = {}, CPU_NUM = '
'{}. communicator_send_queue_size will be forced to {}.'
.format(mode_str, send_queue_size, num_threads,
num_threads))
self.runtime_configs[
'communicator_send_queue_size'] = num_threads
return dict((key, str(self.runtime_configs[key])) for key in need_keys)
def get_lr_ops(program):
lr_ops = []
......@@ -176,6 +229,13 @@ def get_ps_endpoint(role_maker):
return role_maker.get_pserver_endpoints()[get_role_id(role_maker)]
def get_ps_endpoints(role_maker):
try:
return role_maker._get_pserver_endpoints()
except Exception:
return role_maker.get_pserver_endpoints()
def get_heter_worker_endpoint(role_maker):
try:
return role_maker._get_heter_worker_endpoint()
......@@ -224,8 +284,9 @@ def is_sparse_op(op):
return False
def get_sparse_tablenames(program, is_distributed):
def get_sparse_tablenames(programs, is_distributed):
tablenames = set()
for program in programs:
if is_distributed:
for op in program.global_block().ops:
if is_distributed_sparse_op(op):
......@@ -237,13 +298,6 @@ def get_sparse_tablenames(program, is_distributed):
return list(tablenames)
def get_ps_endpoints(role_maker):
try:
return role_maker._get_pserver_endpoints()
except Exception:
return role_maker.get_pserver_endpoints()
def get_trainers(role_maker):
try:
return role_maker._worker_num()
......@@ -251,7 +305,7 @@ def get_trainers(role_maker):
return role_maker.worker_num()
def get_dense_send_context(context,
def get_dense_send_context(program,
send_ctx,
idx,
merged_dense_pairs,
......@@ -260,34 +314,72 @@ def get_dense_send_context(context,
if len(merged_dense_pairs) < 1:
return idx
if not split_dense_table:
dense_pairs = []
data_norm_pairs = []
for merged in merged_dense_pairs:
is_data_norm = False
grad = merged[1]
varname = grad.merged_var.name
for name in DATA_NORM_GRAD_NAME:
if varname.endswith(name):
is_data_norm = True
if is_data_norm:
data_norm_pairs.append(merged)
else:
dense_pairs.append(merged)
# simple dense table
origin_varnames = []
var_numel = 0
for merged in merged_dense_pairs:
for merged in dense_pairs:
grad = merged[1]
origin_varnames.append(grad.merged_var.name)
var = context['origin_main_program'].global_block().vars[
grad.merged_var.name]
var = program.global_block().vars[grad.merged_var.name]
var_numel += reduce(lambda x, y: x * y, var.shape)
grad_name = "Dense@Grad"
trainer_id = get_role_id(context['role_maker'])
grad_name = "Dense@GRAD_" + str(idx)
aggregate = True
print("public get_dense_send_context dense_table:", grad_name,
var_numel, origin_varnames)
dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"],
[var_numel], origin_varnames, trainer_id,
aggregate, False, False, idx, False)
aggregate, False, False, idx, False, False,
id(program))
send_ctx[grad_name] = dense_ctx
idx += 1
if len(data_norm_pairs) <= 0:
return idx
# data norm table
origin_varnames = []
var_numel = 0
for merged in data_norm_pairs:
grad = merged[1]
origin_varnames.append(grad.merged_var.name)
var = program.global_block().vars[grad.merged_var.name]
var_numel += reduce(lambda x, y: x * y, var.shape)
grad_name = "DataNorm@GRAD_" + str(idx)
aggregate = True
print("public get_dense_send_context data_norm table:", grad_name,
var_numel, origin_varnames)
data_norm_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"],
[var_numel], origin_varnames, trainer_id,
aggregate, False, False, idx, False, True,
id(program))
send_ctx[grad_name] = data_norm_ctx
idx += 1
else:
for merged in merged_dense_pairs:
grad = merged[1]
origin_varname = grad.merged_var.name
var = context['origin_main_program'].global_block().vars[
origin_varname]
var = program.global_block().vars[origin_varname]
var_numel = reduce(lambda x, y: x * y, var.shape)
grad_name = origin_varname
aggregate = True
dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"],
[var_numel], [origin_varname], trainer_id,
aggregate, False, False, idx, False)
aggregate, False, False, idx, False, False,
id(program))
send_ctx[grad_name] = dense_ctx
idx += 1
return idx
......@@ -299,23 +391,26 @@ def get_geo_trainer_send_context(context):
format(ps_mode, "get_geo_trainer_send_context"))
send_ctx = {}
trainer_id = get_role_id(context['role_maker'])
origin_programs = context['origin_main_programs']
idx = 0
distibuted_varnames = get_sparse_tablenames(context['origin_main_program'],
True)
for merged in context['merged_sparse_pairs']:
distibuted_varnames = get_sparse_tablenames(origin_programs, True)
for i, program in enumerate(origin_programs):
merged_sparse_pairs = context['merged_sparse_pairs'][i]
for merged in merged_sparse_pairs:
param, grad = merged
grad_name = grad.merged_var.name
param_name = param.merged_var.name
is_distributed = True if param_name in distibuted_varnames else False
var = context['origin_main_program'].global_block().vars[
grad.merged_var.name]
var = program.global_block().vars[grad.merged_var.name]
var_numel = reduce(lambda x, y: x * y, var.shape[1:])
sparse_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"],
[var_numel], [grad_name], trainer_id, True,
True, is_distributed, idx, False)
sparse_ctx = CommContext(grad_name, [grad_name],
["127.0.0.1:6071"], [var_numel],
[grad_name], trainer_id, True, True,
is_distributed, idx, False, False,
id(program))
idx += 1
send_ctx[sparse_ctx.var_name()] = sparse_ctx
......@@ -336,7 +431,7 @@ def _step_ctx(idx, role_maker):
sections = [1] * len(endpoints)
names = [name] * len(endpoints)
ctx = CommContext(name, names, endpoints, sections, [name], trainer_id,
True, False, False, idx, True)
True, False, False, idx, True, False, -1)
return name, ctx
......@@ -348,14 +443,19 @@ def get_the_one_send_context(context,
ep_list = ["127.0.0.1:6071"]
send_ctx = {}
trainer_id = get_role_id(context['role_maker'])
origin_programs = context['origin_main_programs']
idx = 0
idx += get_dense_send_context(context, send_ctx, idx,
context['merged_dense_pairs'], trainer_id,
for i, program in enumerate(origin_programs):
merged_dense_pairs = context['merged_dense_pairs'][i]
idx += get_dense_send_context(program, send_ctx, idx,
merged_dense_pairs, trainer_id,
split_dense_table)
distibuted_varnames = get_sparse_tablenames(context['origin_main_program'],
True)
for merged in context['merged_sparse_pairs']:
distibuted_varnames = get_sparse_tablenames(origin_programs, True)
print("public distibuted_varnames:", distibuted_varnames)
for i, program in enumerate(origin_programs):
merged_sparse_pairs = context['merged_sparse_pairs'][i]
for merged in merged_sparse_pairs:
param, grad = merged
grad_name = grad.merged_var.name
param_name = param.merged_var.name
......@@ -366,15 +466,19 @@ def get_the_one_send_context(context,
is_distributed = True if param_name in distibuted_varnames else False
var = context['origin_main_program'].global_block().vars[
grad.merged_var.name]
var = program.global_block().vars[grad.merged_var.name]
shape = list(var.shape)
shape[0] = 0 if is_distributed else shape[0]
print("public get_the_one_send_context sparse:", grad_name,
splited_varname, shape)
if grad_name in send_ctx:
continue
sparse_ctx = CommContext(grad_name, splited_varname, ep_list, shape,
[grad_name], trainer_id, True, True,
is_distributed, idx, False)
is_distributed, idx, False, False,
id(program))
idx += 1
send_ctx[sparse_ctx.var_name()] = sparse_ctx
......@@ -1073,7 +1177,7 @@ def get_the_one_recv_context(context,
param_names = []
for grad_varname in origin_grad_varnames:
param_name = grad_name_to_param_name[grad_varname]
param_name = context["grad_name_to_param_name"][grad_varname]
param_names.append(param_name)
recv_id_maps[ctx.table_id()] = param_names
else:
......@@ -1090,7 +1194,7 @@ def get_the_one_recv_context(context,
param_names = []
for grad_varname in origin_grad_varnames:
param_name = grad_name_to_param_name[grad_varname]
param_name = context["grad_name_to_param_name"][grad_varname]
param_names.append(param_name)
recv_id_maps[ctx.table_id()] = param_names
return recv_id_maps
......@@ -1141,15 +1245,25 @@ class MergedVariable:
def build_var_distributed(context):
sparse_pairs, dense_pairs = get_param_grads(context['origin_main_program'])
origin_for_sparse = []
origin_for_dense = []
param_name_grad_name = {}
origin_programs = context['origin_main_programs']
param_name_to_grad_name = {}
grad_name_to_param_name = {}
context["merged_variables_pairs"] = []
context["origin_sparse_pairs"] = []
context["origin_dense_pairs"] = []
context["merged_sparse_pairs"] = []
context['merged_dense_pairs'] = []
context["merged_variables_pairs"] = []
context["merged_variable_map"] = {}
for origin_program in origin_programs:
sparse_pairs, dense_pairs = get_param_grads(origin_program)
print("public build_var_distributed sparse_pairs:", sparse_pairs)
print("public build_var_distributed dense_pairs:", dense_pairs)
origin_for_sparse = []
origin_for_dense = []
merged_sparse_pairs = []
merged_dense_pairs = []
merged_variables_pairs = []
for param, grad in sparse_pairs:
origin_for_sparse.append((param, grad))
......@@ -1162,18 +1276,22 @@ def build_var_distributed(context):
m_param = MergedVariable(param, [param], [0])
m_grad = MergedVariable(grad, [grad], [0])
context["merged_variables_pairs"].append((m_param, m_grad))
context["merged_dense_pairs"].append((m_param, m_grad))
merged_variables_pairs.append((m_param, m_grad))
merged_dense_pairs.append((m_param, m_grad))
print("public build_var_distributed merged_dense_pairs:",
merged_dense_pairs)
for sparse_pair in origin_for_sparse:
param, grad = sparse_pair
m_param = MergedVariable(param, [param], [0])
m_grad = MergedVariable(grad, [grad], [0])
context["merged_variables_pairs"].append((m_param, m_grad))
context["merged_sparse_pairs"].append((m_param, m_grad))
merged_variables_pairs.append((m_param, m_grad))
merged_sparse_pairs.append((m_param, m_grad))
print("public build_var_distributed merged_sparse_pairs:",
merged_sparse_pairs)
for merged in context["merged_variables_pairs"]:
for merged in merged_variables_pairs:
m_param, m_grad = merged
context["merged_variable_map"][
m_param.merged_var.name] = m_param.merged_var
......@@ -1185,14 +1303,30 @@ def build_var_distributed(context):
param_merges.extend(origin_for_dense)
for param, grad in param_merges:
param_name_grad_name[param.name] = grad.name
param_name_to_grad_name[param.name] = grad.name
grad_name_to_param_name[grad.name] = param.name
context["origin_sparse_pairs"] = origin_for_sparse
context["origin_dense_pairs"] = origin_for_dense
context["param_name_to_grad_name"] = param_name_grad_name
context["origin_sparse_pairs"].append(origin_for_sparse)
context["origin_dense_pairs"].append(origin_for_dense)
context["merged_sparse_pairs"].append(merged_sparse_pairs)
context['merged_dense_pairs'].append(merged_dense_pairs)
context["param_name_to_grad_name"] = param_name_to_grad_name
context["grad_name_to_param_name"] = grad_name_to_param_name
print("public build_var_distributed origin_sparse_pairs:",
context["origin_sparse_pairs"])
print("public build_var_distributed origin_for_dense:",
context["origin_dense_pairs"])
print("public build_var_distributed merged_sparse_pairs:",
context["merged_sparse_pairs"])
print("public build_var_distributed merged_dense_pairs:",
context['merged_dense_pairs'])
print("public build_var_distributed param_name_to_grad_name:",
param_name_to_grad_name)
print("public build_var_distributed grad_name_to_param_name:",
grad_name_to_param_name)
def _is_opt_role_op(op):
# NOTE : depend on oprole to find out whether this op is for
......
......@@ -577,7 +577,7 @@ class CompileTimeStrategy(object):
sparse_ctx = CommContext(grad_name, [grad_name],
["127.0.0.1:6071"], [var_numel],
[grad_name], trainer_id, True, True,
is_distributed, idx, False)
is_distributed, idx, False, False, -1)
idx += 1
send_ctx[sparse_ctx.var_name()] = sparse_ctx
......@@ -615,7 +615,8 @@ class CompileTimeStrategy(object):
aggregate = True
dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"],
[var_numel], origin_varnames, trainer_id,
aggregate, False, False, idx, False)
aggregate, False, False, idx, False, False,
-1)
send_ctx[grad_name] = dense_ctx
idx += 1
else:
......@@ -630,7 +631,7 @@ class CompileTimeStrategy(object):
dense_ctx = CommContext(grad_name, [grad_name],
["127.0.0.1:6071"], [var_numel],
[origin_varname], trainer_id, aggregate,
False, False, idx, False)
False, False, idx, False, False, -1)
send_ctx[grad_name] = dense_ctx
idx += 1
return idx
......@@ -672,7 +673,7 @@ class CompileTimeStrategy(object):
sparse_ctx = CommContext(grad_name, splited_varname, ep_list, shape,
[grad_name], trainer_id, True, True,
is_distributed, idx, False)
is_distributed, idx, False, False, -1)
idx += 1
send_ctx[sparse_ctx.var_name()] = sparse_ctx
......@@ -750,7 +751,7 @@ class CompileTimeStrategy(object):
sections = [1] * len(endpoints)
names = [name] * len(endpoints)
ctx = CommContext(name, names, endpoints, sections, [name], trainer_id,
True, False, False, idx, True)
True, False, False, idx, True, False, -1)
return name, ctx
def _create_vars_from_blocklist(self, block_list):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册