未验证 提交 d56a0a1b 编写于 作者: W wangguanqun 提交者: GitHub

fix bug in new the_one_ps (#39505)

* fix benchmark and communicator config

* fix bugs of the_one_ps

* multi program and fix bug in optimizer

* multi program in the_one_ps

* public commcontext
上级 a6abb6e7
...@@ -31,7 +31,8 @@ struct CommContext { ...@@ -31,7 +31,8 @@ struct CommContext {
const std::vector<std::string> &origin_names, int id, const std::vector<std::string> &origin_names, int id,
bool merge_add_ = true, bool is_sparse_ = true, bool merge_add_ = true, bool is_sparse_ = true,
bool is_distributed_ = false, int table_id_ = -1, bool is_distributed_ = false, int table_id_ = -1,
bool is_tensor_table_ = false) bool is_tensor_table_ = false, bool is_datanorm_table_ = false,
int64_t program_id_ = -1)
: var_name(name), : var_name(name),
splited_varnames(names), splited_varnames(names),
epmap(emap), epmap(emap),
...@@ -42,7 +43,9 @@ struct CommContext { ...@@ -42,7 +43,9 @@ struct CommContext {
is_sparse(is_sparse_), is_sparse(is_sparse_),
is_distributed(is_distributed_), is_distributed(is_distributed_),
table_id(table_id_), table_id(table_id_),
is_tensor_table(is_tensor_table_) {} program_id(program_id_),
is_tensor_table(is_tensor_table_),
is_datanorm_table(is_datanorm_table_) {}
CommContext(const CommContext &ctx) { CommContext(const CommContext &ctx) {
var_name = ctx.var_name; var_name = ctx.var_name;
...@@ -55,7 +58,9 @@ struct CommContext { ...@@ -55,7 +58,9 @@ struct CommContext {
origin_varnames = ctx.origin_varnames; origin_varnames = ctx.origin_varnames;
is_distributed = ctx.is_distributed; is_distributed = ctx.is_distributed;
table_id = ctx.table_id; table_id = ctx.table_id;
program_id = ctx.program_id;
is_tensor_table = ctx.is_tensor_table; is_tensor_table = ctx.is_tensor_table;
is_datanorm_table = ctx.is_datanorm_table;
} }
std::string print() const { std::string print() const {
...@@ -78,7 +83,9 @@ struct CommContext { ...@@ -78,7 +83,9 @@ struct CommContext {
ss << " is_sparse: " << is_sparse; ss << " is_sparse: " << is_sparse;
ss << " is_distributed: " << is_distributed << "\n"; ss << " is_distributed: " << is_distributed << "\n";
ss << " table_id: " << table_id << "\n"; ss << " table_id: " << table_id << "\n";
ss << " program_id: " << program_id << "\n";
ss << " is_tensor_table: " << is_tensor_table << "\n"; ss << " is_tensor_table: " << is_tensor_table << "\n";
ss << " is_datanorm_table: " << is_datanorm_table << "\n";
return ss.str(); return ss.str();
} }
...@@ -93,7 +100,9 @@ struct CommContext { ...@@ -93,7 +100,9 @@ struct CommContext {
bool is_sparse; bool is_sparse;
bool is_distributed; bool is_distributed;
int table_id; int table_id;
int64_t program_id;
bool is_tensor_table; bool is_tensor_table;
bool is_datanorm_table;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -103,11 +103,13 @@ void BindCommunicatorContext(py::module* m) { ...@@ -103,11 +103,13 @@ void BindCommunicatorContext(py::module* m) {
py::init<const std::string&, const std::vector<std::string>&, py::init<const std::string&, const std::vector<std::string>&,
const std::vector<std::string>&, const std::vector<int64_t>&, const std::vector<std::string>&, const std::vector<int64_t>&,
const std::vector<std::string>&, int, bool, bool, bool, int, const std::vector<std::string>&, int, bool, bool, bool, int,
bool>()) bool, bool, int64_t>())
.def("var_name", [](const CommContext& self) { return self.var_name; }) .def("var_name", [](const CommContext& self) { return self.var_name; })
.def("trainer_id", .def("trainer_id",
[](const CommContext& self) { return self.trainer_id; }) [](const CommContext& self) { return self.trainer_id; })
.def("table_id", [](const CommContext& self) { return self.table_id; }) .def("table_id", [](const CommContext& self) { return self.table_id; })
.def("program_id",
[](const CommContext& self) { return self.program_id; })
.def("split_varnames", .def("split_varnames",
[](const CommContext& self) { return self.splited_varnames; }) [](const CommContext& self) { return self.splited_varnames; })
.def("split_endpoints", .def("split_endpoints",
...@@ -122,6 +124,8 @@ void BindCommunicatorContext(py::module* m) { ...@@ -122,6 +124,8 @@ void BindCommunicatorContext(py::module* m) {
[](const CommContext& self) { return self.origin_varnames; }) [](const CommContext& self) { return self.origin_varnames; })
.def("is_tensor_table", .def("is_tensor_table",
[](const CommContext& self) { return self.is_tensor_table; }) [](const CommContext& self) { return self.is_tensor_table; })
.def("is_datanorm_table",
[](const CommContext& self) { return self.is_datanorm_table; })
.def("__str__", [](const CommContext& self) { return self.print(); }); .def("__str__", [](const CommContext& self) { return self.print(); });
} }
......
...@@ -46,7 +46,9 @@ class ParameterServerOptimizer(MetaOptimizerBase): ...@@ -46,7 +46,9 @@ class ParameterServerOptimizer(MetaOptimizerBase):
attrs['loss'] = loss attrs['loss'] = loss
attrs['min_block_size'] = 81920 attrs['min_block_size'] = 81920
attrs['origin_main_program'] = loss.block.program attrs['origin_main_program'] = loss.block.program
attrs['origin_main_programs'] = [loss.block.program]
attrs['origin_startup_program'] = startup_program attrs['origin_startup_program'] = startup_program
attrs['origin_startup_programs'] = [startup_program]
attrs['cloned_main'] = attrs['origin_main_program'].clone() attrs['cloned_main'] = attrs['origin_main_program'].clone()
attrs['cloned_startup'] = attrs['origin_startup_program'].clone() attrs['cloned_startup'] = attrs['origin_startup_program'].clone()
......
...@@ -560,9 +560,9 @@ class FakeInitOpsPass(PassBase): ...@@ -560,9 +560,9 @@ class FakeInitOpsPass(PassBase):
return True return True
def _get_sparse_table_names(self, attrs): def _get_sparse_table_names(self, attrs):
dist_varnames = get_sparse_tablenames(attrs['origin_main_program'], dist_varnames = get_sparse_tablenames(attrs['origin_main_programs'],
True) True)
sparse_varnames = get_sparse_tablenames(attrs['origin_main_program'], sparse_varnames = get_sparse_tablenames(attrs['origin_main_programs'],
False) False)
return list(set(dist_varnames + sparse_varnames)) return list(set(dist_varnames + sparse_varnames))
......
...@@ -24,8 +24,8 @@ from paddle.fluid.compiler import CompiledProgram ...@@ -24,8 +24,8 @@ from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.executor import Executor from paddle.fluid.executor import Executor
from paddle.fluid.parallel_executor import ParallelExecutor from paddle.fluid.parallel_executor import ParallelExecutor
from paddle.fluid.framework import Variable, Parameter from paddle.fluid.framework import Variable, Parameter
from .runtime_base import RuntimeBase from paddle.distributed.fleet.runtime.runtime_base import RuntimeBase
from ..base.private_helper_function import wait_server_ready from paddle.distributed.fleet.base.private_helper_function import wait_server_ready
from paddle.fluid.communicator import Communicator, HeterClient from paddle.fluid.communicator import Communicator, HeterClient
from google.protobuf import text_format from google.protobuf import text_format
...@@ -39,8 +39,17 @@ def conv_indent(indent): ...@@ -39,8 +39,17 @@ def conv_indent(indent):
PSERVER_SAVE_SUFFIX = ".shard" PSERVER_SAVE_SUFFIX = ".shard"
def parse_table_class(varname, o_main_program): def get_program_by_id(context, program_id):
for op in o_main_program.global_block().ops: programs = context["origin_main_programs"]
for i, program in enumerate(programs):
if id(program) == program_id:
return program, context["origin_startup_programs"][i]
return None, None
def parse_table_class(varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id)
for op in main_program.global_block().ops:
if not is_distributed_sparse_op(op) and not is_sparse_op(op): if not is_distributed_sparse_op(op) and not is_sparse_op(op):
continue continue
...@@ -53,9 +62,10 @@ def parse_table_class(varname, o_main_program): ...@@ -53,9 +62,10 @@ def parse_table_class(varname, o_main_program):
return "MemorySparseTable" return "MemorySparseTable"
def get_default_accessor_proto(accessor, varname, o_main_program): def get_default_accessor_proto(accessor, varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id)
embedding_dim = 0 embedding_dim = 0
for var in o_main_program.list_vars(): for var in main_program.list_vars():
if var.name == varname: if var.name == varname:
embedding_dim = var.shape[1] embedding_dim = var.shape[1]
break break
...@@ -123,9 +133,10 @@ def get_default_accessor_proto(accessor, varname, o_main_program): ...@@ -123,9 +133,10 @@ def get_default_accessor_proto(accessor, varname, o_main_program):
sgd_param.adam.weight_bounds.extend([-10.0, 10.0]) sgd_param.adam.weight_bounds.extend([-10.0, 10.0])
def check_embedding_dim(accessor, varname, o_main_program): def check_embedding_dim(accessor, varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id)
embedding_dim = 0 embedding_dim = 0
for var in o_main_program.list_vars(): for var in main_program.list_vars():
if var.name == varname: if var.name == varname:
embedding_dim = var.shape[1] embedding_dim = var.shape[1]
break break
...@@ -172,6 +183,8 @@ class CommonAccessor: ...@@ -172,6 +183,8 @@ class CommonAccessor:
self.dims = [] self.dims = []
self.trainer_num = 0 self.trainer_num = 0
self.sync = "false" self.sync = "false"
self.table_num = None
self.table_dim = None
self.initializers = [] self.initializers = []
self.opt_input_map = {} self.opt_input_map = {}
self.opt_attr_map = {} self.opt_attr_map = {}
...@@ -192,6 +205,7 @@ class CommonAccessor: ...@@ -192,6 +205,7 @@ class CommonAccessor:
opt_input_map["sum"] = [("Param", None)] opt_input_map["sum"] = [("Param", None)]
opt_input_map["naive_adagrad"] = [("Param", None), ("G2Sum", 1), opt_input_map["naive_adagrad"] = [("Param", None), ("G2Sum", 1),
("LearningRate", 1)] ("LearningRate", 1)]
opt_input_map["summary"] = [("Param", None), ("SummaryDecayRate", 1)]
opt_attr_map = {} opt_attr_map = {}
opt_attr_map["sgd"] = [] opt_attr_map["sgd"] = []
...@@ -201,6 +215,7 @@ class CommonAccessor: ...@@ -201,6 +215,7 @@ class CommonAccessor:
("epsilon", "f")] ("epsilon", "f")]
opt_attr_map["adam_d2sum"] = [("beta1", "f"), ("beta2", "f"), opt_attr_map["adam_d2sum"] = [("beta1", "f"), ("beta2", "f"),
("epsilon", "f")] ("epsilon", "f")]
opt_attr_map["summary"] = []
opt_init_map = {} opt_init_map = {}
opt_init_map["gaussian_random"] = ["seed", "mean", "std"] opt_init_map["gaussian_random"] = ["seed", "mean", "std"]
...@@ -212,8 +227,9 @@ class CommonAccessor: ...@@ -212,8 +227,9 @@ class CommonAccessor:
self.opt_input_map = opt_input_map self.opt_input_map = opt_input_map
self.opt_init_map = opt_init_map self.opt_init_map = opt_init_map
def parse_entry(self, varname, o_main_program): def parse_entry(self, varname, program_id, context):
for op in o_main_program.global_block().ops: main_program, startup_program = get_program_by_id(context, program_id)
for op in main_program.global_block().ops:
if not is_distributed_sparse_op(op) and not is_sparse_op(op): if not is_distributed_sparse_op(op) and not is_sparse_op(op):
continue continue
...@@ -243,23 +259,36 @@ class CommonAccessor: ...@@ -243,23 +259,36 @@ class CommonAccessor:
attr_str = "" attr_str = ""
origin_var_name = value_name origin_var_name = value_name
print("get_initializer_attr param name:", value_name)
for op in o_startup_program.global_block().ops: for op in o_startup_program.global_block().ops:
if op.type in self.opt_init_map.keys( if op.type in self.opt_init_map.keys(
) and origin_var_name == op.output("Out")[0]: ) and origin_var_name == op.output("Out")[0]:
init_attr = [op.type] init_attr = [op.type]
print("get_initializer_attr op type:", op.type)
for attr in self.opt_init_map[op.type]: for attr in self.opt_init_map[op.type]:
print("get_initializer_attr opt_init_map attr:", attr)
init_attr.append(str(op.attr(attr))) init_attr.append(str(op.attr(attr)))
print("get_initializer_attr op attr:", str(op.attr(attr)))
attr_str = l_in.join(init_attr) attr_str = l_in.join(init_attr)
break break
return attr_str return attr_str
def parse_by_optimizer(self, grad_name, is_sparse, total_dims, context, def parse_by_optimizer(self, ctx, context):
adam_d2sum): grad_name = ctx.origin_varnames()[0]
main_program = context['origin_main_program'] is_sparse = ctx.is_sparse()
startup_program = context['startup_main_program'] size = ctx.sections()[0]
single_dim = ctx.sections()[1] if ctx.is_sparse() else 1
adam_d2sum = context["user_defined_strategy"].adam_d2sum
print("parse_by_optimizer table_id:{} is_datanorm:{}".format(
ctx.table_id(), ctx.is_datanorm_table()))
main_program, startup_program = get_program_by_id(context,
ctx.program_id())
pserver_id = get_role_id(context['role_maker']) pserver_id = get_role_id(context['role_maker'])
pserver_num = len(get_ps_endpoints(context['role_maker'])) pserver_num = len(get_ps_endpoints(context['role_maker']))
optimizer_ops = get_optimize_ops(main_program) optimizer_ops = get_optimize_ops(main_program)
print("the one ps optimizer_ops:", optimizer_ops)
print("the one ps parse_by_optimizer grad_name:", grad_name)
oop = None oop = None
for op in optimizer_ops: for op in optimizer_ops:
...@@ -278,6 +307,8 @@ class CommonAccessor: ...@@ -278,6 +307,8 @@ class CommonAccessor:
initializers = [] initializers = []
self.trainer_num = get_trainers(context['role_maker']) self.trainer_num = get_trainers(context['role_maker'])
self.table_num = size
self.table_dim = single_dim
if oop.type != 'adam' and adam_d2sum == True: if oop.type != 'adam' and adam_d2sum == True:
print('optimization algorithm is not adam, set adam_d2sum False') print('optimization algorithm is not adam, set adam_d2sum False')
...@@ -291,7 +322,11 @@ class CommonAccessor: ...@@ -291,7 +322,11 @@ class CommonAccessor:
param_varnames = self.opt_input_map["naive_adagrad"] param_varnames = self.opt_input_map["naive_adagrad"]
attr_varnames = self.opt_attr_map["naive_adagrad"] attr_varnames = self.opt_attr_map["naive_adagrad"]
self.accessor_class = "sgd" self.accessor_class = "sgd"
elif adam_d2sum: elif ctx.is_datanorm_table():
param_varnames = self.opt_input_map["summary"]
attr_varnames = self.opt_attr_map["summary"]
self.accessor_class = "summary"
elif adam_d2sum and not is_sparse:
param_varnames = self.opt_input_map["adam_d2sum"] param_varnames = self.opt_input_map["adam_d2sum"]
attr_varnames = self.opt_attr_map["adam_d2sum"] attr_varnames = self.opt_attr_map["adam_d2sum"]
self.accessor_class = "adam_d2sum" self.accessor_class = "adam_d2sum"
...@@ -306,10 +341,9 @@ class CommonAccessor: ...@@ -306,10 +341,9 @@ class CommonAccessor:
#for dims #for dims
if shape is None: if shape is None:
if is_sparse: if is_sparse:
shape = total_dims shape = single_dim
else: else:
shape = self.get_shard(total_dims, pserver_num, shape = self.get_shard(size, pserver_num, pserver_id)
pserver_id)
dims.append(shape) dims.append(shape)
#for initializers #for initializers
...@@ -333,6 +367,27 @@ class CommonAccessor: ...@@ -333,6 +367,27 @@ class CommonAccessor:
else: else:
initializer = "fill_constant&0" initializer = "fill_constant&0"
initializers.append(initializer) initializers.append(initializer)
elif self.accessor_class == "summary":
#for dims
if shape is None:
if is_sparse:
shape = single_dim
else:
shape = self.get_shard(size, pserver_num, pserver_id)
dims.append(shape)
#for initializers
if formal_name == "Param":
param = main_program.global_block().vars[oop.input(
formal_name)[0]]
initializer = self.get_initializer_attr(param.name,
startup_program)
elif formal_name == "SummaryDecayRate":
initializer = "fill_constant&0.99999"
else:
initializer = "fill_constant&0"
initializers.append(initializer)
else: else:
if formal_name == "G2Sum": if formal_name == "G2Sum":
dims.append(1) dims.append(1)
...@@ -348,9 +403,9 @@ class CommonAccessor: ...@@ -348,9 +403,9 @@ class CommonAccessor:
if shape is None: if shape is None:
if is_sparse: if is_sparse:
shape = total_dims shape = single_dim
else: else:
shape = self.get_shard(total_dims, pserver_num, shape = self.get_shard(size, pserver_num,
pserver_id) pserver_id)
dims.append(shape) dims.append(shape)
...@@ -379,6 +434,10 @@ class CommonAccessor: ...@@ -379,6 +434,10 @@ class CommonAccessor:
attrs += "entry: \"{}\" ".format(self.entry) attrs += "entry: \"{}\" ".format(self.entry)
attrs += "trainer_num: {} ".format(self.trainer_num) attrs += "trainer_num: {} ".format(self.trainer_num)
attrs += "sync: {} ".format(self.sync) attrs += "sync: {} ".format(self.sync)
if self.table_num:
attrs += "table_num: {} ".format(self.table_num)
if self.table_dim:
attrs += "table_dim: {} ".format(self.table_dim)
for param in self.params: for param in self.params:
attrs += "params: \"{}\" ".format(param) attrs += "params: \"{}\" ".format(param)
...@@ -448,10 +507,7 @@ class Table: ...@@ -448,10 +507,7 @@ class Table:
accessor_str = accessor_str.format( accessor_str = accessor_str.format(
conv_indent(indent), self.accessor_proto, conv_indent(indent)) conv_indent(indent), self.accessor_proto, conv_indent(indent))
attrs += accessor_str + "\n" attrs += accessor_str + "\n"
return table_str.format( elif self.accessor is not None:
conv_indent(indent), attrs, conv_indent(indent))
if self.accessor is not None:
attrs += self.accessor.to_string(indent) attrs += self.accessor.to_string(indent)
attrs += "\n" attrs += "\n"
...@@ -607,7 +663,9 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -607,7 +663,9 @@ class TheOnePSRuntime(RuntimeBase):
def _set_basic_info(self, context): def _set_basic_info(self, context):
self.context = context self.context = context
self.role_maker = context["role_maker"] self.role_maker = context["role_maker"]
self.origin_main_program = context["origin_main_program"] self.origin_main_program = context["origin_main_program"]
self.origin_main_programs = context["origin_main_programs"]
self.context[ self.context[
'is_heter_ps_mode'] = self.role_maker._is_heter_parameter_server_mode 'is_heter_ps_mode'] = self.role_maker._is_heter_parameter_server_mode
...@@ -615,10 +673,13 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -615,10 +673,13 @@ class TheOnePSRuntime(RuntimeBase):
self.context['trainer'] = TrainerRuntimeConfig(context[ self.context['trainer'] = TrainerRuntimeConfig(context[
'valid_strategy']) 'valid_strategy'])
self.context['ps_mode'] = self.context['trainer'].mode self.context['ps_mode'] = self.context['trainer'].mode
self.context['use_ps_gpu'] = context['valid_strategy'].use_ps_gpu self.context['use_ps_gpu'] = context['valid_strategy'].a_sync_configs[
'use_ps_gpu']
self.is_sync = True if self.context[ self.is_sync = True if self.context[
'ps_mode'] == DistributedMode.SYNC else False 'ps_mode'] == DistributedMode.SYNC else False
self.context['grad_name_to_param_name'] = {} self.context['grad_name_to_param_name'] = {}
self.context['tensor_table'] = {}
build_var_distributed(self.context)
def _init_worker(self): def _init_worker(self):
worker = self._get_fleet_proto(is_server=False, is_sync=self.is_sync) worker = self._get_fleet_proto(is_server=False, is_sync=self.is_sync)
...@@ -689,6 +750,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -689,6 +750,7 @@ class TheOnePSRuntime(RuntimeBase):
sync_kwargs = sync_strategy_envs() sync_kwargs = sync_strategy_envs()
kwargs.update(sync_kwargs) kwargs.update(sync_kwargs)
print("communicator config:", trainer_config.get_communicator_flags())
self._communicator = Communicator( self._communicator = Communicator(
trainer_config.mode, kwargs, trainer_config.mode, kwargs,
trainer_config.get_communicator_flags()) trainer_config.get_communicator_flags())
...@@ -893,7 +955,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -893,7 +955,7 @@ class TheOnePSRuntime(RuntimeBase):
common.table_name = self.context['grad_name_to_param_name'][ common.table_name = self.context['grad_name_to_param_name'][
ctx.origin_varnames()[0]] ctx.origin_varnames()[0]]
if self.ps_mode == DistributedMode.GEO: if self.context['ps_mode'] == DistributedMode.GEO:
table.table_class = "SparseGeoTable" table.table_class = "SparseGeoTable"
else: else:
all_table_proto = self.context[ all_table_proto = self.context[
...@@ -907,7 +969,8 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -907,7 +969,8 @@ class TheOnePSRuntime(RuntimeBase):
table.table_class = table_proto.table_class table.table_class = table_proto.table_class
else: else:
table.table_class = parse_table_class( table.table_class = parse_table_class(
common.table_name, self.origin_main_program) common.table_name,
ctx.program_id(), self.context)
if table.table_class != 'MemorySparseTable': if table.table_class != 'MemorySparseTable':
table.table_class = 'MemorySparseTable' table.table_class = 'MemorySparseTable'
warnings.warn( warnings.warn(
...@@ -925,12 +988,12 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -925,12 +988,12 @@ class TheOnePSRuntime(RuntimeBase):
warnings.warn( warnings.warn(
"The accessor of sparse table is not set, use default value." "The accessor of sparse table is not set, use default value."
) )
get_default_accessor_proto(table_proto.accessor, get_default_accessor_proto(
common.table_name, table_proto.accessor, common.table_name,
self.origin_main_program) ctx.program_id(), self.context)
check_embedding_dim(table_proto.accessor, check_embedding_dim(table_proto.accessor,
common.table_name, common.table_name,
self.origin_main_program) ctx.program_id(), self.context)
table.accessor_proto = text_format.MessageToString( table.accessor_proto = text_format.MessageToString(
table_proto.accessor) table_proto.accessor)
else: else:
...@@ -940,15 +1003,11 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -940,15 +1003,11 @@ class TheOnePSRuntime(RuntimeBase):
common.table_name = "MergedDense" common.table_name = "MergedDense"
adam_d2sum = self.context["user_defined_strategy"].adam_d2sum adam_d2sum = self.context["user_defined_strategy"].adam_d2sum
common.parse_by_optimizer(ctx.origin_varnames()[0], common.parse_by_optimizer(ctx, self.context)
ctx.is_sparse(),
ctx.sections()[1] if ctx.is_sparse()
else ctx.sections()[0], self.context,
adam_d2sum)
if ctx.is_sparse(): if ctx.is_sparse():
common.parse_entry(common.table_name, common.parse_entry(common.table_name,
self.origin_main_program) ctx.program_id(), self.context)
if is_sync: if is_sync:
common.sync = "true" common.sync = "true"
...@@ -1023,8 +1082,9 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -1023,8 +1082,9 @@ class TheOnePSRuntime(RuntimeBase):
self._server.init_server(proto_txt, string_hosts, role_id, trainers, self._server.init_server(proto_txt, string_hosts, role_id, trainers,
self._server_sub_program) self._server_sub_program)
dist_varnames = get_sparse_tablenames(self.origin_main_program, True) dist_varnames = get_sparse_tablenames(self.origin_main_programs, True)
sparse_varnames = get_sparse_tablenames(self.origin_main_program, False) sparse_varnames = get_sparse_tablenames(self.origin_main_programs,
False)
distributed_varnames = dist_varnames + sparse_varnames distributed_varnames = dist_varnames + sparse_varnames
...@@ -1070,6 +1130,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -1070,6 +1130,7 @@ class TheOnePSRuntime(RuntimeBase):
if var.name in exclude_var_names: if var.name in exclude_var_names:
return False return False
from .utils.public import _get_varname_parts
origin_varname, _, _ = _get_varname_parts(var.name) origin_varname, _, _ = _get_varname_parts(var.name)
if origin_varname.endswith("@GRAD"): if origin_varname.endswith("@GRAD"):
return False return False
...@@ -1085,16 +1146,24 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -1085,16 +1146,24 @@ class TheOnePSRuntime(RuntimeBase):
return is_valid return is_valid
def _get_inference_model_path(self, dirname):
if dirname.startswith("afs:") or dirname.startswith("hdfs:"):
model_path = "./dnn_plugin"
else:
model_path = os.path.join(dirname, "dnn_plugin")
return model_path
def _save_sparse_params(self, executor, dirname, context, main_program, def _save_sparse_params(self, executor, dirname, context, main_program,
mode): mode):
distributed_varnames = get_sparse_tablenames( distributed_varnames = get_sparse_tablenames(self.origin_main_programs,
self.context['origin_main_program'], True) True)
values = [] values = []
model_path = self._get_inference_model_path(dirname)
for id, names in context.items(): for id, names in context.items():
if names[0] not in distributed_varnames: if names[0] not in distributed_varnames:
# only save sparse param to local # only save sparse param to local
try: try:
self._worker.recv_and_save_model(id, dirname) self._worker.recv_and_save_model(id, model_path)
except: except:
pass pass
# save sparse & distributed param on server # save sparse & distributed param on server
...@@ -1221,10 +1290,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -1221,10 +1290,7 @@ class TheOnePSRuntime(RuntimeBase):
infer_program._copy_dist_param_info_from(program) infer_program._copy_dist_param_info_from(program)
if dirname.startswith("afs:") or dirname.startswith("hdfs:"): model_path = self._get_inference_model_path(dirname)
model_path = "./dnn_plugin"
else:
model_path = os.path.join(dirname, "dnn_plugin")
model_basename = "__model__" model_basename = "__model__"
model_basename = os.path.join(model_path, model_basename) model_basename = os.path.join(model_path, model_basename)
paddle.save(infer_program, model_basename) paddle.save(infer_program, model_basename)
...@@ -1266,7 +1332,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -1266,7 +1332,7 @@ class TheOnePSRuntime(RuntimeBase):
self._ps_inference_save_persistables(*args, **kwargs) self._ps_inference_save_persistables(*args, **kwargs)
def _load_sparse_params(self, dirname, context, main_program, mode): def _load_sparse_params(self, dirname, context, main_program, mode):
distributed_varnames = get_sparse_tablenames(self.origin_main_program, distributed_varnames = get_sparse_tablenames(self.origin_main_programs,
True) True)
values = [] values = []
for id, names in context.items(): for id, names in context.items():
......
...@@ -79,7 +79,7 @@ class GeoPsProgramBuilder(PsProgramBuilder): # 仅 CPU 模式 ...@@ -79,7 +79,7 @@ class GeoPsProgramBuilder(PsProgramBuilder): # 仅 CPU 模式
super(GeoPsProgramBuilder, self).__init__(pass_ctx) super(GeoPsProgramBuilder, self).__init__(pass_ctx)
if self.ps_mode != DistributedMode.GEO: if self.ps_mode != DistributedMode.GEO:
raise ValueError("ps mode: {} not matched {}", raise ValueError("ps mode: {} not matched {}",
format(ps_mode, "GeoPsProgramBuilder")) format(self.ps_mode, "GeoPsProgramBuilder"))
def _build_trainer_programs(self): def _build_trainer_programs(self):
append_send_ops_pass = new_pass("append_send_ops_pass", self.attrs) append_send_ops_pass = new_pass("append_send_ops_pass", self.attrs)
...@@ -97,9 +97,9 @@ class CpuSyncPsProgramBuilder(PsProgramBuilder): ...@@ -97,9 +97,9 @@ class CpuSyncPsProgramBuilder(PsProgramBuilder):
def __init__(self, pass_ctx): def __init__(self, pass_ctx):
logger.info("start building cpu-sync-ps program") logger.info("start building cpu-sync-ps program")
super(CpuSyncPsProgramBuilder, self).__init__(pass_ctx) super(CpuSyncPsProgramBuilder, self).__init__(pass_ctx)
if self.ps_mode != DistributedMode.SYNC: if self.ps_mode != DistributedMode.SYNC and self.ps_mode != DistributedMode.ASYNC:
raise ValueError("ps mode: {} not matched {}", raise ValueError("ps mode: {} not matched {}",
format(ps_mode, "CpuSyncPsProgramBuilder")) format(self.ps_mode, "CpuSyncPsProgramBuilder"))
def _build_trainer_programs(self): def _build_trainer_programs(self):
add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass", add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass",
...@@ -178,7 +178,7 @@ class HeterAsyncPsProgramBuilder(PsProgramBuilder): ...@@ -178,7 +178,7 @@ class HeterAsyncPsProgramBuilder(PsProgramBuilder):
if self.use_ps_gpu or self.ps_mode == DistributedMode.GEO or self.attrs[ if self.use_ps_gpu or self.ps_mode == DistributedMode.GEO or self.attrs[
'is_heter_ps_mode'] == False: 'is_heter_ps_mode'] == False:
raise ValueError("ps mode: {} not matched {}", raise ValueError("ps mode: {} not matched {}",
format(ps_mode, "HeterAsyncPsProgramBuilder")) format(self.ps_mode, "HeterAsyncPsProgramBuilder"))
def _build_trainer_programs(self): def _build_trainer_programs(self):
add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass", add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass",
......
...@@ -54,6 +54,9 @@ SPARSE_GRAD_OP_TYPE_DICT = { ...@@ -54,6 +54,9 @@ SPARSE_GRAD_OP_TYPE_DICT = {
} }
DEFAULT_DEVICE = 'cpu' DEFAULT_DEVICE = 'cpu'
DATA_NORM_NAME = [".batch_size", ".batch_sum", ".batch_square_sum"]
DATA_NORM_GRAD_NAME = [x + "@GRAD" for x in DATA_NORM_NAME]
def logger_config(log_path, logging_name): def logger_config(log_path, logging_name):
logger = logging.getLogger(logging_name) logger = logging.getLogger(logging_name)
...@@ -84,6 +87,8 @@ class DistributedMode: ...@@ -84,6 +87,8 @@ class DistributedMode:
class TrainerRuntimeConfig(object): class TrainerRuntimeConfig(object):
def __init__(self, valid_strategy): def __init__(self, valid_strategy):
self.mode = None self.mode = None
num_threads = os.getenv("CPU_NUM", "1")
send_queue_size = num_threads
k_steps = valid_strategy.a_sync_configs["k_steps"] k_steps = valid_strategy.a_sync_configs["k_steps"]
logger.info("ps mode in strategy: {}, {}".format( logger.info("ps mode in strategy: {}, {}".format(
valid_strategy.a_sync, valid_strategy.a_sync_configs["k_steps"])) valid_strategy.a_sync, valid_strategy.a_sync_configs["k_steps"]))
...@@ -95,14 +100,13 @@ class TrainerRuntimeConfig(object): ...@@ -95,14 +100,13 @@ class TrainerRuntimeConfig(object):
if valid_strategy.a_sync and k_steps > 0: if valid_strategy.a_sync and k_steps > 0:
self.mode = DistributedMode.GEO self.mode = DistributedMode.GEO
send_queue_size = k_steps
num_threads = os.getenv("CPU_NUM", "1")
self.runtime_configs = {} self.runtime_configs = {}
self.runtime_configs['communicator_max_merge_var_num'] = os.getenv( self.runtime_configs['communicator_max_merge_var_num'] = os.getenv(
"FLAGS_communicator_max_merge_var_num", num_threads) "FLAGS_communicator_max_merge_var_num", send_queue_size)
self.runtime_configs['communicator_send_queue_size'] = os.getenv( self.runtime_configs['communicator_send_queue_size'] = os.getenv(
"FLAGS_communicator_send_queue_size", num_threads) "FLAGS_communicator_send_queue_size", send_queue_size)
self.runtime_configs[ self.runtime_configs[
'communicator_independent_recv_thread'] = os.getenv( 'communicator_independent_recv_thread'] = os.getenv(
"FLAGS_communicator_independent_recv_thread", "1") "FLAGS_communicator_independent_recv_thread", "1")
...@@ -116,6 +120,55 @@ class TrainerRuntimeConfig(object): ...@@ -116,6 +120,55 @@ class TrainerRuntimeConfig(object):
self.runtime_configs['communicator_is_sgd_optimizer'] = os.getenv( self.runtime_configs['communicator_is_sgd_optimizer'] = os.getenv(
"FLAGS_communicator_is_sgd_optimizer", "1") "FLAGS_communicator_is_sgd_optimizer", "1")
def get_communicator_flags(self):
need_keys = []
num_threads = os.getenv("CPU_NUM", "1")
mode_str = ""
if self.mode is None or self.mode == DistributedMode.ASYNC:
need_keys = self.runtime_configs.keys()
mode_str = "async"
elif self.mode == DistributedMode.SYNC or self.mode == DistributedMode.HALF_ASYNC:
mode_str = "sync or half_async"
need_keys = [
'communicator_max_merge_var_num',
'communicator_send_wait_times', 'communicator_thread_pool_size',
'communicator_send_queue_size'
]
elif self.mode == DistributedMode.GEO:
mode_str = "GEO"
need_keys = [
'communicator_thread_pool_size', 'communicator_send_wait_times',
'communicator_max_merge_var_num', 'communicator_send_queue_size'
]
else:
raise ValueError("Unsupported Mode")
if self.mode == DistributedMode.SYNC or self.mode == DistributedMode.HALF_ASYNC:
max_merge_var_num = self.runtime_configs[
'communicator_max_merge_var_num']
send_queue_size = self.runtime_configs[
'communicator_send_queue_size']
if max_merge_var_num != num_threads:
print('WARNING: In {} mode, communicator_max_merge_var_num '
'must be equal to CPU_NUM. But received, '
'communicator_max_merge_var_num = {}, CPU_NUM = '
'{}. communicator_max_merge_var_num will be forced to {}.'
.format(mode_str, max_merge_var_num, num_threads,
num_threads))
self.runtime_configs[
'communicator_max_merge_var_num'] = num_threads
if send_queue_size != num_threads:
print('WARNING: In {} mode, communicator_send_queue_size '
'must be equal to CPU_NUM. But received, '
'communicator_send_queue_size = {}, CPU_NUM = '
'{}. communicator_send_queue_size will be forced to {}.'
.format(mode_str, send_queue_size, num_threads,
num_threads))
self.runtime_configs[
'communicator_send_queue_size'] = num_threads
return dict((key, str(self.runtime_configs[key])) for key in need_keys)
def get_lr_ops(program): def get_lr_ops(program):
lr_ops = [] lr_ops = []
...@@ -176,6 +229,13 @@ def get_ps_endpoint(role_maker): ...@@ -176,6 +229,13 @@ def get_ps_endpoint(role_maker):
return role_maker.get_pserver_endpoints()[get_role_id(role_maker)] return role_maker.get_pserver_endpoints()[get_role_id(role_maker)]
def get_ps_endpoints(role_maker):
try:
return role_maker._get_pserver_endpoints()
except Exception:
return role_maker.get_pserver_endpoints()
def get_heter_worker_endpoint(role_maker): def get_heter_worker_endpoint(role_maker):
try: try:
return role_maker._get_heter_worker_endpoint() return role_maker._get_heter_worker_endpoint()
...@@ -224,8 +284,9 @@ def is_sparse_op(op): ...@@ -224,8 +284,9 @@ def is_sparse_op(op):
return False return False
def get_sparse_tablenames(program, is_distributed): def get_sparse_tablenames(programs, is_distributed):
tablenames = set() tablenames = set()
for program in programs:
if is_distributed: if is_distributed:
for op in program.global_block().ops: for op in program.global_block().ops:
if is_distributed_sparse_op(op): if is_distributed_sparse_op(op):
...@@ -237,13 +298,6 @@ def get_sparse_tablenames(program, is_distributed): ...@@ -237,13 +298,6 @@ def get_sparse_tablenames(program, is_distributed):
return list(tablenames) return list(tablenames)
def get_ps_endpoints(role_maker):
try:
return role_maker._get_pserver_endpoints()
except Exception:
return role_maker.get_pserver_endpoints()
def get_trainers(role_maker): def get_trainers(role_maker):
try: try:
return role_maker._worker_num() return role_maker._worker_num()
...@@ -251,7 +305,7 @@ def get_trainers(role_maker): ...@@ -251,7 +305,7 @@ def get_trainers(role_maker):
return role_maker.worker_num() return role_maker.worker_num()
def get_dense_send_context(context, def get_dense_send_context(program,
send_ctx, send_ctx,
idx, idx,
merged_dense_pairs, merged_dense_pairs,
...@@ -260,34 +314,72 @@ def get_dense_send_context(context, ...@@ -260,34 +314,72 @@ def get_dense_send_context(context,
if len(merged_dense_pairs) < 1: if len(merged_dense_pairs) < 1:
return idx return idx
if not split_dense_table: if not split_dense_table:
dense_pairs = []
data_norm_pairs = []
for merged in merged_dense_pairs:
is_data_norm = False
grad = merged[1]
varname = grad.merged_var.name
for name in DATA_NORM_GRAD_NAME:
if varname.endswith(name):
is_data_norm = True
if is_data_norm:
data_norm_pairs.append(merged)
else:
dense_pairs.append(merged)
# simple dense table
origin_varnames = [] origin_varnames = []
var_numel = 0 var_numel = 0
for merged in merged_dense_pairs: for merged in dense_pairs:
grad = merged[1] grad = merged[1]
origin_varnames.append(grad.merged_var.name) origin_varnames.append(grad.merged_var.name)
var = context['origin_main_program'].global_block().vars[ var = program.global_block().vars[grad.merged_var.name]
grad.merged_var.name]
var_numel += reduce(lambda x, y: x * y, var.shape) var_numel += reduce(lambda x, y: x * y, var.shape)
grad_name = "Dense@Grad" grad_name = "Dense@GRAD_" + str(idx)
trainer_id = get_role_id(context['role_maker'])
aggregate = True aggregate = True
print("public get_dense_send_context dense_table:", grad_name,
var_numel, origin_varnames)
dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"], dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"],
[var_numel], origin_varnames, trainer_id, [var_numel], origin_varnames, trainer_id,
aggregate, False, False, idx, False) aggregate, False, False, idx, False, False,
id(program))
send_ctx[grad_name] = dense_ctx send_ctx[grad_name] = dense_ctx
idx += 1 idx += 1
if len(data_norm_pairs) <= 0:
return idx
# data norm table
origin_varnames = []
var_numel = 0
for merged in data_norm_pairs:
grad = merged[1]
origin_varnames.append(grad.merged_var.name)
var = program.global_block().vars[grad.merged_var.name]
var_numel += reduce(lambda x, y: x * y, var.shape)
grad_name = "DataNorm@GRAD_" + str(idx)
aggregate = True
print("public get_dense_send_context data_norm table:", grad_name,
var_numel, origin_varnames)
data_norm_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"],
[var_numel], origin_varnames, trainer_id,
aggregate, False, False, idx, False, True,
id(program))
send_ctx[grad_name] = data_norm_ctx
idx += 1
else: else:
for merged in merged_dense_pairs: for merged in merged_dense_pairs:
grad = merged[1] grad = merged[1]
origin_varname = grad.merged_var.name origin_varname = grad.merged_var.name
var = context['origin_main_program'].global_block().vars[ var = program.global_block().vars[origin_varname]
origin_varname]
var_numel = reduce(lambda x, y: x * y, var.shape) var_numel = reduce(lambda x, y: x * y, var.shape)
grad_name = origin_varname grad_name = origin_varname
aggregate = True aggregate = True
dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"], dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"],
[var_numel], [origin_varname], trainer_id, [var_numel], [origin_varname], trainer_id,
aggregate, False, False, idx, False) aggregate, False, False, idx, False, False,
id(program))
send_ctx[grad_name] = dense_ctx send_ctx[grad_name] = dense_ctx
idx += 1 idx += 1
return idx return idx
...@@ -299,23 +391,26 @@ def get_geo_trainer_send_context(context): ...@@ -299,23 +391,26 @@ def get_geo_trainer_send_context(context):
format(ps_mode, "get_geo_trainer_send_context")) format(ps_mode, "get_geo_trainer_send_context"))
send_ctx = {} send_ctx = {}
trainer_id = get_role_id(context['role_maker']) trainer_id = get_role_id(context['role_maker'])
origin_programs = context['origin_main_programs']
idx = 0 idx = 0
distibuted_varnames = get_sparse_tablenames(context['origin_main_program'], distibuted_varnames = get_sparse_tablenames(origin_programs, True)
True) for i, program in enumerate(origin_programs):
for merged in context['merged_sparse_pairs']: merged_sparse_pairs = context['merged_sparse_pairs'][i]
for merged in merged_sparse_pairs:
param, grad = merged param, grad = merged
grad_name = grad.merged_var.name grad_name = grad.merged_var.name
param_name = param.merged_var.name param_name = param.merged_var.name
is_distributed = True if param_name in distibuted_varnames else False is_distributed = True if param_name in distibuted_varnames else False
var = context['origin_main_program'].global_block().vars[ var = program.global_block().vars[grad.merged_var.name]
grad.merged_var.name]
var_numel = reduce(lambda x, y: x * y, var.shape[1:]) var_numel = reduce(lambda x, y: x * y, var.shape[1:])
sparse_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"], sparse_ctx = CommContext(grad_name, [grad_name],
[var_numel], [grad_name], trainer_id, True, ["127.0.0.1:6071"], [var_numel],
True, is_distributed, idx, False) [grad_name], trainer_id, True, True,
is_distributed, idx, False, False,
id(program))
idx += 1 idx += 1
send_ctx[sparse_ctx.var_name()] = sparse_ctx send_ctx[sparse_ctx.var_name()] = sparse_ctx
...@@ -336,7 +431,7 @@ def _step_ctx(idx, role_maker): ...@@ -336,7 +431,7 @@ def _step_ctx(idx, role_maker):
sections = [1] * len(endpoints) sections = [1] * len(endpoints)
names = [name] * len(endpoints) names = [name] * len(endpoints)
ctx = CommContext(name, names, endpoints, sections, [name], trainer_id, ctx = CommContext(name, names, endpoints, sections, [name], trainer_id,
True, False, False, idx, True) True, False, False, idx, True, False, -1)
return name, ctx return name, ctx
...@@ -348,14 +443,19 @@ def get_the_one_send_context(context, ...@@ -348,14 +443,19 @@ def get_the_one_send_context(context,
ep_list = ["127.0.0.1:6071"] ep_list = ["127.0.0.1:6071"]
send_ctx = {} send_ctx = {}
trainer_id = get_role_id(context['role_maker']) trainer_id = get_role_id(context['role_maker'])
origin_programs = context['origin_main_programs']
idx = 0 idx = 0
idx += get_dense_send_context(context, send_ctx, idx, for i, program in enumerate(origin_programs):
context['merged_dense_pairs'], trainer_id, merged_dense_pairs = context['merged_dense_pairs'][i]
idx += get_dense_send_context(program, send_ctx, idx,
merged_dense_pairs, trainer_id,
split_dense_table) split_dense_table)
distibuted_varnames = get_sparse_tablenames(context['origin_main_program'], distibuted_varnames = get_sparse_tablenames(origin_programs, True)
True) print("public distibuted_varnames:", distibuted_varnames)
for merged in context['merged_sparse_pairs']: for i, program in enumerate(origin_programs):
merged_sparse_pairs = context['merged_sparse_pairs'][i]
for merged in merged_sparse_pairs:
param, grad = merged param, grad = merged
grad_name = grad.merged_var.name grad_name = grad.merged_var.name
param_name = param.merged_var.name param_name = param.merged_var.name
...@@ -366,15 +466,19 @@ def get_the_one_send_context(context, ...@@ -366,15 +466,19 @@ def get_the_one_send_context(context,
is_distributed = True if param_name in distibuted_varnames else False is_distributed = True if param_name in distibuted_varnames else False
var = context['origin_main_program'].global_block().vars[ var = program.global_block().vars[grad.merged_var.name]
grad.merged_var.name]
shape = list(var.shape) shape = list(var.shape)
shape[0] = 0 if is_distributed else shape[0] shape[0] = 0 if is_distributed else shape[0]
print("public get_the_one_send_context sparse:", grad_name,
splited_varname, shape)
if grad_name in send_ctx:
continue
sparse_ctx = CommContext(grad_name, splited_varname, ep_list, shape, sparse_ctx = CommContext(grad_name, splited_varname, ep_list, shape,
[grad_name], trainer_id, True, True, [grad_name], trainer_id, True, True,
is_distributed, idx, False) is_distributed, idx, False, False,
id(program))
idx += 1 idx += 1
send_ctx[sparse_ctx.var_name()] = sparse_ctx send_ctx[sparse_ctx.var_name()] = sparse_ctx
...@@ -1073,7 +1177,7 @@ def get_the_one_recv_context(context, ...@@ -1073,7 +1177,7 @@ def get_the_one_recv_context(context,
param_names = [] param_names = []
for grad_varname in origin_grad_varnames: for grad_varname in origin_grad_varnames:
param_name = grad_name_to_param_name[grad_varname] param_name = context["grad_name_to_param_name"][grad_varname]
param_names.append(param_name) param_names.append(param_name)
recv_id_maps[ctx.table_id()] = param_names recv_id_maps[ctx.table_id()] = param_names
else: else:
...@@ -1090,7 +1194,7 @@ def get_the_one_recv_context(context, ...@@ -1090,7 +1194,7 @@ def get_the_one_recv_context(context,
param_names = [] param_names = []
for grad_varname in origin_grad_varnames: for grad_varname in origin_grad_varnames:
param_name = grad_name_to_param_name[grad_varname] param_name = context["grad_name_to_param_name"][grad_varname]
param_names.append(param_name) param_names.append(param_name)
recv_id_maps[ctx.table_id()] = param_names recv_id_maps[ctx.table_id()] = param_names
return recv_id_maps return recv_id_maps
...@@ -1141,15 +1245,25 @@ class MergedVariable: ...@@ -1141,15 +1245,25 @@ class MergedVariable:
def build_var_distributed(context): def build_var_distributed(context):
sparse_pairs, dense_pairs = get_param_grads(context['origin_main_program']) origin_programs = context['origin_main_programs']
origin_for_sparse = []
origin_for_dense = [] param_name_to_grad_name = {}
param_name_grad_name = {}
grad_name_to_param_name = {} grad_name_to_param_name = {}
context["merged_variables_pairs"] = [] context["origin_sparse_pairs"] = []
context["origin_dense_pairs"] = []
context["merged_sparse_pairs"] = [] context["merged_sparse_pairs"] = []
context['merged_dense_pairs'] = [] context['merged_dense_pairs'] = []
context["merged_variables_pairs"] = []
context["merged_variable_map"] = {} context["merged_variable_map"] = {}
for origin_program in origin_programs:
sparse_pairs, dense_pairs = get_param_grads(origin_program)
print("public build_var_distributed sparse_pairs:", sparse_pairs)
print("public build_var_distributed dense_pairs:", dense_pairs)
origin_for_sparse = []
origin_for_dense = []
merged_sparse_pairs = []
merged_dense_pairs = []
merged_variables_pairs = []
for param, grad in sparse_pairs: for param, grad in sparse_pairs:
origin_for_sparse.append((param, grad)) origin_for_sparse.append((param, grad))
...@@ -1162,18 +1276,22 @@ def build_var_distributed(context): ...@@ -1162,18 +1276,22 @@ def build_var_distributed(context):
m_param = MergedVariable(param, [param], [0]) m_param = MergedVariable(param, [param], [0])
m_grad = MergedVariable(grad, [grad], [0]) m_grad = MergedVariable(grad, [grad], [0])
context["merged_variables_pairs"].append((m_param, m_grad)) merged_variables_pairs.append((m_param, m_grad))
context["merged_dense_pairs"].append((m_param, m_grad)) merged_dense_pairs.append((m_param, m_grad))
print("public build_var_distributed merged_dense_pairs:",
merged_dense_pairs)
for sparse_pair in origin_for_sparse: for sparse_pair in origin_for_sparse:
param, grad = sparse_pair param, grad = sparse_pair
m_param = MergedVariable(param, [param], [0]) m_param = MergedVariable(param, [param], [0])
m_grad = MergedVariable(grad, [grad], [0]) m_grad = MergedVariable(grad, [grad], [0])
context["merged_variables_pairs"].append((m_param, m_grad)) merged_variables_pairs.append((m_param, m_grad))
context["merged_sparse_pairs"].append((m_param, m_grad)) merged_sparse_pairs.append((m_param, m_grad))
print("public build_var_distributed merged_sparse_pairs:",
merged_sparse_pairs)
for merged in context["merged_variables_pairs"]: for merged in merged_variables_pairs:
m_param, m_grad = merged m_param, m_grad = merged
context["merged_variable_map"][ context["merged_variable_map"][
m_param.merged_var.name] = m_param.merged_var m_param.merged_var.name] = m_param.merged_var
...@@ -1185,14 +1303,30 @@ def build_var_distributed(context): ...@@ -1185,14 +1303,30 @@ def build_var_distributed(context):
param_merges.extend(origin_for_dense) param_merges.extend(origin_for_dense)
for param, grad in param_merges: for param, grad in param_merges:
param_name_grad_name[param.name] = grad.name param_name_to_grad_name[param.name] = grad.name
grad_name_to_param_name[grad.name] = param.name grad_name_to_param_name[grad.name] = param.name
context["origin_sparse_pairs"] = origin_for_sparse context["origin_sparse_pairs"].append(origin_for_sparse)
context["origin_dense_pairs"] = origin_for_dense context["origin_dense_pairs"].append(origin_for_dense)
context["param_name_to_grad_name"] = param_name_grad_name context["merged_sparse_pairs"].append(merged_sparse_pairs)
context['merged_dense_pairs'].append(merged_dense_pairs)
context["param_name_to_grad_name"] = param_name_to_grad_name
context["grad_name_to_param_name"] = grad_name_to_param_name context["grad_name_to_param_name"] = grad_name_to_param_name
print("public build_var_distributed origin_sparse_pairs:",
context["origin_sparse_pairs"])
print("public build_var_distributed origin_for_dense:",
context["origin_dense_pairs"])
print("public build_var_distributed merged_sparse_pairs:",
context["merged_sparse_pairs"])
print("public build_var_distributed merged_dense_pairs:",
context['merged_dense_pairs'])
print("public build_var_distributed param_name_to_grad_name:",
param_name_to_grad_name)
print("public build_var_distributed grad_name_to_param_name:",
grad_name_to_param_name)
def _is_opt_role_op(op): def _is_opt_role_op(op):
# NOTE : depend on oprole to find out whether this op is for # NOTE : depend on oprole to find out whether this op is for
......
...@@ -577,7 +577,7 @@ class CompileTimeStrategy(object): ...@@ -577,7 +577,7 @@ class CompileTimeStrategy(object):
sparse_ctx = CommContext(grad_name, [grad_name], sparse_ctx = CommContext(grad_name, [grad_name],
["127.0.0.1:6071"], [var_numel], ["127.0.0.1:6071"], [var_numel],
[grad_name], trainer_id, True, True, [grad_name], trainer_id, True, True,
is_distributed, idx, False) is_distributed, idx, False, False, -1)
idx += 1 idx += 1
send_ctx[sparse_ctx.var_name()] = sparse_ctx send_ctx[sparse_ctx.var_name()] = sparse_ctx
...@@ -615,7 +615,8 @@ class CompileTimeStrategy(object): ...@@ -615,7 +615,8 @@ class CompileTimeStrategy(object):
aggregate = True aggregate = True
dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"], dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"],
[var_numel], origin_varnames, trainer_id, [var_numel], origin_varnames, trainer_id,
aggregate, False, False, idx, False) aggregate, False, False, idx, False, False,
-1)
send_ctx[grad_name] = dense_ctx send_ctx[grad_name] = dense_ctx
idx += 1 idx += 1
else: else:
...@@ -630,7 +631,7 @@ class CompileTimeStrategy(object): ...@@ -630,7 +631,7 @@ class CompileTimeStrategy(object):
dense_ctx = CommContext(grad_name, [grad_name], dense_ctx = CommContext(grad_name, [grad_name],
["127.0.0.1:6071"], [var_numel], ["127.0.0.1:6071"], [var_numel],
[origin_varname], trainer_id, aggregate, [origin_varname], trainer_id, aggregate,
False, False, idx, False) False, False, idx, False, False, -1)
send_ctx[grad_name] = dense_ctx send_ctx[grad_name] = dense_ctx
idx += 1 idx += 1
return idx return idx
...@@ -672,7 +673,7 @@ class CompileTimeStrategy(object): ...@@ -672,7 +673,7 @@ class CompileTimeStrategy(object):
sparse_ctx = CommContext(grad_name, splited_varname, ep_list, shape, sparse_ctx = CommContext(grad_name, splited_varname, ep_list, shape,
[grad_name], trainer_id, True, True, [grad_name], trainer_id, True, True,
is_distributed, idx, False) is_distributed, idx, False, False, -1)
idx += 1 idx += 1
send_ctx[sparse_ctx.var_name()] = sparse_ctx send_ctx[sparse_ctx.var_name()] = sparse_ctx
...@@ -750,7 +751,7 @@ class CompileTimeStrategy(object): ...@@ -750,7 +751,7 @@ class CompileTimeStrategy(object):
sections = [1] * len(endpoints) sections = [1] * len(endpoints)
names = [name] * len(endpoints) names = [name] * len(endpoints)
ctx = CommContext(name, names, endpoints, sections, [name], trainer_id, ctx = CommContext(name, names, endpoints, sections, [name], trainer_id,
True, False, False, idx, True) True, False, False, idx, True, False, -1)
return name, ctx return name, ctx
def _create_vars_from_blocklist(self, block_list): def _create_vars_from_blocklist(self, block_list):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册