未验证 提交 d56a0a1b 编写于 作者: W wangguanqun 提交者: GitHub

fix bug in new the_one_ps (#39505)

* fix benchmark and communicator config

* fix bugs of the_one_ps

* multi program and fix bug in optimizer

* multi program in the_one_ps

* public commcontext
上级 a6abb6e7
...@@ -31,7 +31,8 @@ struct CommContext { ...@@ -31,7 +31,8 @@ struct CommContext {
const std::vector<std::string> &origin_names, int id, const std::vector<std::string> &origin_names, int id,
bool merge_add_ = true, bool is_sparse_ = true, bool merge_add_ = true, bool is_sparse_ = true,
bool is_distributed_ = false, int table_id_ = -1, bool is_distributed_ = false, int table_id_ = -1,
bool is_tensor_table_ = false) bool is_tensor_table_ = false, bool is_datanorm_table_ = false,
int64_t program_id_ = -1)
: var_name(name), : var_name(name),
splited_varnames(names), splited_varnames(names),
epmap(emap), epmap(emap),
...@@ -42,7 +43,9 @@ struct CommContext { ...@@ -42,7 +43,9 @@ struct CommContext {
is_sparse(is_sparse_), is_sparse(is_sparse_),
is_distributed(is_distributed_), is_distributed(is_distributed_),
table_id(table_id_), table_id(table_id_),
is_tensor_table(is_tensor_table_) {} program_id(program_id_),
is_tensor_table(is_tensor_table_),
is_datanorm_table(is_datanorm_table_) {}
CommContext(const CommContext &ctx) { CommContext(const CommContext &ctx) {
var_name = ctx.var_name; var_name = ctx.var_name;
...@@ -55,7 +58,9 @@ struct CommContext { ...@@ -55,7 +58,9 @@ struct CommContext {
origin_varnames = ctx.origin_varnames; origin_varnames = ctx.origin_varnames;
is_distributed = ctx.is_distributed; is_distributed = ctx.is_distributed;
table_id = ctx.table_id; table_id = ctx.table_id;
program_id = ctx.program_id;
is_tensor_table = ctx.is_tensor_table; is_tensor_table = ctx.is_tensor_table;
is_datanorm_table = ctx.is_datanorm_table;
} }
std::string print() const { std::string print() const {
...@@ -78,7 +83,9 @@ struct CommContext { ...@@ -78,7 +83,9 @@ struct CommContext {
ss << " is_sparse: " << is_sparse; ss << " is_sparse: " << is_sparse;
ss << " is_distributed: " << is_distributed << "\n"; ss << " is_distributed: " << is_distributed << "\n";
ss << " table_id: " << table_id << "\n"; ss << " table_id: " << table_id << "\n";
ss << " program_id: " << program_id << "\n";
ss << " is_tensor_table: " << is_tensor_table << "\n"; ss << " is_tensor_table: " << is_tensor_table << "\n";
ss << " is_datanorm_table: " << is_datanorm_table << "\n";
return ss.str(); return ss.str();
} }
...@@ -93,7 +100,9 @@ struct CommContext { ...@@ -93,7 +100,9 @@ struct CommContext {
bool is_sparse; bool is_sparse;
bool is_distributed; bool is_distributed;
int table_id; int table_id;
int64_t program_id;
bool is_tensor_table; bool is_tensor_table;
bool is_datanorm_table;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -103,11 +103,13 @@ void BindCommunicatorContext(py::module* m) { ...@@ -103,11 +103,13 @@ void BindCommunicatorContext(py::module* m) {
py::init<const std::string&, const std::vector<std::string>&, py::init<const std::string&, const std::vector<std::string>&,
const std::vector<std::string>&, const std::vector<int64_t>&, const std::vector<std::string>&, const std::vector<int64_t>&,
const std::vector<std::string>&, int, bool, bool, bool, int, const std::vector<std::string>&, int, bool, bool, bool, int,
bool>()) bool, bool, int64_t>())
.def("var_name", [](const CommContext& self) { return self.var_name; }) .def("var_name", [](const CommContext& self) { return self.var_name; })
.def("trainer_id", .def("trainer_id",
[](const CommContext& self) { return self.trainer_id; }) [](const CommContext& self) { return self.trainer_id; })
.def("table_id", [](const CommContext& self) { return self.table_id; }) .def("table_id", [](const CommContext& self) { return self.table_id; })
.def("program_id",
[](const CommContext& self) { return self.program_id; })
.def("split_varnames", .def("split_varnames",
[](const CommContext& self) { return self.splited_varnames; }) [](const CommContext& self) { return self.splited_varnames; })
.def("split_endpoints", .def("split_endpoints",
...@@ -122,6 +124,8 @@ void BindCommunicatorContext(py::module* m) { ...@@ -122,6 +124,8 @@ void BindCommunicatorContext(py::module* m) {
[](const CommContext& self) { return self.origin_varnames; }) [](const CommContext& self) { return self.origin_varnames; })
.def("is_tensor_table", .def("is_tensor_table",
[](const CommContext& self) { return self.is_tensor_table; }) [](const CommContext& self) { return self.is_tensor_table; })
.def("is_datanorm_table",
[](const CommContext& self) { return self.is_datanorm_table; })
.def("__str__", [](const CommContext& self) { return self.print(); }); .def("__str__", [](const CommContext& self) { return self.print(); });
} }
......
...@@ -46,7 +46,9 @@ class ParameterServerOptimizer(MetaOptimizerBase): ...@@ -46,7 +46,9 @@ class ParameterServerOptimizer(MetaOptimizerBase):
attrs['loss'] = loss attrs['loss'] = loss
attrs['min_block_size'] = 81920 attrs['min_block_size'] = 81920
attrs['origin_main_program'] = loss.block.program attrs['origin_main_program'] = loss.block.program
attrs['origin_main_programs'] = [loss.block.program]
attrs['origin_startup_program'] = startup_program attrs['origin_startup_program'] = startup_program
attrs['origin_startup_programs'] = [startup_program]
attrs['cloned_main'] = attrs['origin_main_program'].clone() attrs['cloned_main'] = attrs['origin_main_program'].clone()
attrs['cloned_startup'] = attrs['origin_startup_program'].clone() attrs['cloned_startup'] = attrs['origin_startup_program'].clone()
......
...@@ -560,9 +560,9 @@ class FakeInitOpsPass(PassBase): ...@@ -560,9 +560,9 @@ class FakeInitOpsPass(PassBase):
return True return True
def _get_sparse_table_names(self, attrs): def _get_sparse_table_names(self, attrs):
dist_varnames = get_sparse_tablenames(attrs['origin_main_program'], dist_varnames = get_sparse_tablenames(attrs['origin_main_programs'],
True) True)
sparse_varnames = get_sparse_tablenames(attrs['origin_main_program'], sparse_varnames = get_sparse_tablenames(attrs['origin_main_programs'],
False) False)
return list(set(dist_varnames + sparse_varnames)) return list(set(dist_varnames + sparse_varnames))
......
...@@ -24,8 +24,8 @@ from paddle.fluid.compiler import CompiledProgram ...@@ -24,8 +24,8 @@ from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.executor import Executor from paddle.fluid.executor import Executor
from paddle.fluid.parallel_executor import ParallelExecutor from paddle.fluid.parallel_executor import ParallelExecutor
from paddle.fluid.framework import Variable, Parameter from paddle.fluid.framework import Variable, Parameter
from .runtime_base import RuntimeBase from paddle.distributed.fleet.runtime.runtime_base import RuntimeBase
from ..base.private_helper_function import wait_server_ready from paddle.distributed.fleet.base.private_helper_function import wait_server_ready
from paddle.fluid.communicator import Communicator, HeterClient from paddle.fluid.communicator import Communicator, HeterClient
from google.protobuf import text_format from google.protobuf import text_format
...@@ -39,8 +39,17 @@ def conv_indent(indent): ...@@ -39,8 +39,17 @@ def conv_indent(indent):
PSERVER_SAVE_SUFFIX = ".shard" PSERVER_SAVE_SUFFIX = ".shard"
def parse_table_class(varname, o_main_program): def get_program_by_id(context, program_id):
for op in o_main_program.global_block().ops: programs = context["origin_main_programs"]
for i, program in enumerate(programs):
if id(program) == program_id:
return program, context["origin_startup_programs"][i]
return None, None
def parse_table_class(varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id)
for op in main_program.global_block().ops:
if not is_distributed_sparse_op(op) and not is_sparse_op(op): if not is_distributed_sparse_op(op) and not is_sparse_op(op):
continue continue
...@@ -53,9 +62,10 @@ def parse_table_class(varname, o_main_program): ...@@ -53,9 +62,10 @@ def parse_table_class(varname, o_main_program):
return "MemorySparseTable" return "MemorySparseTable"
def get_default_accessor_proto(accessor, varname, o_main_program): def get_default_accessor_proto(accessor, varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id)
embedding_dim = 0 embedding_dim = 0
for var in o_main_program.list_vars(): for var in main_program.list_vars():
if var.name == varname: if var.name == varname:
embedding_dim = var.shape[1] embedding_dim = var.shape[1]
break break
...@@ -123,9 +133,10 @@ def get_default_accessor_proto(accessor, varname, o_main_program): ...@@ -123,9 +133,10 @@ def get_default_accessor_proto(accessor, varname, o_main_program):
sgd_param.adam.weight_bounds.extend([-10.0, 10.0]) sgd_param.adam.weight_bounds.extend([-10.0, 10.0])
def check_embedding_dim(accessor, varname, o_main_program): def check_embedding_dim(accessor, varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id)
embedding_dim = 0 embedding_dim = 0
for var in o_main_program.list_vars(): for var in main_program.list_vars():
if var.name == varname: if var.name == varname:
embedding_dim = var.shape[1] embedding_dim = var.shape[1]
break break
...@@ -172,6 +183,8 @@ class CommonAccessor: ...@@ -172,6 +183,8 @@ class CommonAccessor:
self.dims = [] self.dims = []
self.trainer_num = 0 self.trainer_num = 0
self.sync = "false" self.sync = "false"
self.table_num = None
self.table_dim = None
self.initializers = [] self.initializers = []
self.opt_input_map = {} self.opt_input_map = {}
self.opt_attr_map = {} self.opt_attr_map = {}
...@@ -192,6 +205,7 @@ class CommonAccessor: ...@@ -192,6 +205,7 @@ class CommonAccessor:
opt_input_map["sum"] = [("Param", None)] opt_input_map["sum"] = [("Param", None)]
opt_input_map["naive_adagrad"] = [("Param", None), ("G2Sum", 1), opt_input_map["naive_adagrad"] = [("Param", None), ("G2Sum", 1),
("LearningRate", 1)] ("LearningRate", 1)]
opt_input_map["summary"] = [("Param", None), ("SummaryDecayRate", 1)]
opt_attr_map = {} opt_attr_map = {}
opt_attr_map["sgd"] = [] opt_attr_map["sgd"] = []
...@@ -201,6 +215,7 @@ class CommonAccessor: ...@@ -201,6 +215,7 @@ class CommonAccessor:
("epsilon", "f")] ("epsilon", "f")]
opt_attr_map["adam_d2sum"] = [("beta1", "f"), ("beta2", "f"), opt_attr_map["adam_d2sum"] = [("beta1", "f"), ("beta2", "f"),
("epsilon", "f")] ("epsilon", "f")]
opt_attr_map["summary"] = []
opt_init_map = {} opt_init_map = {}
opt_init_map["gaussian_random"] = ["seed", "mean", "std"] opt_init_map["gaussian_random"] = ["seed", "mean", "std"]
...@@ -212,8 +227,9 @@ class CommonAccessor: ...@@ -212,8 +227,9 @@ class CommonAccessor:
self.opt_input_map = opt_input_map self.opt_input_map = opt_input_map
self.opt_init_map = opt_init_map self.opt_init_map = opt_init_map
def parse_entry(self, varname, o_main_program): def parse_entry(self, varname, program_id, context):
for op in o_main_program.global_block().ops: main_program, startup_program = get_program_by_id(context, program_id)
for op in main_program.global_block().ops:
if not is_distributed_sparse_op(op) and not is_sparse_op(op): if not is_distributed_sparse_op(op) and not is_sparse_op(op):
continue continue
...@@ -243,23 +259,36 @@ class CommonAccessor: ...@@ -243,23 +259,36 @@ class CommonAccessor:
attr_str = "" attr_str = ""
origin_var_name = value_name origin_var_name = value_name
print("get_initializer_attr param name:", value_name)
for op in o_startup_program.global_block().ops: for op in o_startup_program.global_block().ops:
if op.type in self.opt_init_map.keys( if op.type in self.opt_init_map.keys(
) and origin_var_name == op.output("Out")[0]: ) and origin_var_name == op.output("Out")[0]:
init_attr = [op.type] init_attr = [op.type]
print("get_initializer_attr op type:", op.type)
for attr in self.opt_init_map[op.type]: for attr in self.opt_init_map[op.type]:
print("get_initializer_attr opt_init_map attr:", attr)
init_attr.append(str(op.attr(attr))) init_attr.append(str(op.attr(attr)))
print("get_initializer_attr op attr:", str(op.attr(attr)))
attr_str = l_in.join(init_attr) attr_str = l_in.join(init_attr)
break break
return attr_str return attr_str
def parse_by_optimizer(self, grad_name, is_sparse, total_dims, context, def parse_by_optimizer(self, ctx, context):
adam_d2sum): grad_name = ctx.origin_varnames()[0]
main_program = context['origin_main_program'] is_sparse = ctx.is_sparse()
startup_program = context['startup_main_program'] size = ctx.sections()[0]
single_dim = ctx.sections()[1] if ctx.is_sparse() else 1
adam_d2sum = context["user_defined_strategy"].adam_d2sum
print("parse_by_optimizer table_id:{} is_datanorm:{}".format(
ctx.table_id(), ctx.is_datanorm_table()))
main_program, startup_program = get_program_by_id(context,
ctx.program_id())
pserver_id = get_role_id(context['role_maker']) pserver_id = get_role_id(context['role_maker'])
pserver_num = len(get_ps_endpoints(context['role_maker'])) pserver_num = len(get_ps_endpoints(context['role_maker']))
optimizer_ops = get_optimize_ops(main_program) optimizer_ops = get_optimize_ops(main_program)
print("the one ps optimizer_ops:", optimizer_ops)
print("the one ps parse_by_optimizer grad_name:", grad_name)
oop = None oop = None
for op in optimizer_ops: for op in optimizer_ops:
...@@ -278,6 +307,8 @@ class CommonAccessor: ...@@ -278,6 +307,8 @@ class CommonAccessor:
initializers = [] initializers = []
self.trainer_num = get_trainers(context['role_maker']) self.trainer_num = get_trainers(context['role_maker'])
self.table_num = size
self.table_dim = single_dim
if oop.type != 'adam' and adam_d2sum == True: if oop.type != 'adam' and adam_d2sum == True:
print('optimization algorithm is not adam, set adam_d2sum False') print('optimization algorithm is not adam, set adam_d2sum False')
...@@ -291,7 +322,11 @@ class CommonAccessor: ...@@ -291,7 +322,11 @@ class CommonAccessor:
param_varnames = self.opt_input_map["naive_adagrad"] param_varnames = self.opt_input_map["naive_adagrad"]
attr_varnames = self.opt_attr_map["naive_adagrad"] attr_varnames = self.opt_attr_map["naive_adagrad"]
self.accessor_class = "sgd" self.accessor_class = "sgd"
elif adam_d2sum: elif ctx.is_datanorm_table():
param_varnames = self.opt_input_map["summary"]
attr_varnames = self.opt_attr_map["summary"]
self.accessor_class = "summary"
elif adam_d2sum and not is_sparse:
param_varnames = self.opt_input_map["adam_d2sum"] param_varnames = self.opt_input_map["adam_d2sum"]
attr_varnames = self.opt_attr_map["adam_d2sum"] attr_varnames = self.opt_attr_map["adam_d2sum"]
self.accessor_class = "adam_d2sum" self.accessor_class = "adam_d2sum"
...@@ -306,10 +341,9 @@ class CommonAccessor: ...@@ -306,10 +341,9 @@ class CommonAccessor:
#for dims #for dims
if shape is None: if shape is None:
if is_sparse: if is_sparse:
shape = total_dims shape = single_dim
else: else:
shape = self.get_shard(total_dims, pserver_num, shape = self.get_shard(size, pserver_num, pserver_id)
pserver_id)
dims.append(shape) dims.append(shape)
#for initializers #for initializers
...@@ -333,6 +367,27 @@ class CommonAccessor: ...@@ -333,6 +367,27 @@ class CommonAccessor:
else: else:
initializer = "fill_constant&0" initializer = "fill_constant&0"
initializers.append(initializer) initializers.append(initializer)
elif self.accessor_class == "summary":
#for dims
if shape is None:
if is_sparse:
shape = single_dim
else:
shape = self.get_shard(size, pserver_num, pserver_id)
dims.append(shape)
#for initializers
if formal_name == "Param":
param = main_program.global_block().vars[oop.input(
formal_name)[0]]
initializer = self.get_initializer_attr(param.name,
startup_program)
elif formal_name == "SummaryDecayRate":
initializer = "fill_constant&0.99999"
else:
initializer = "fill_constant&0"
initializers.append(initializer)
else: else:
if formal_name == "G2Sum": if formal_name == "G2Sum":
dims.append(1) dims.append(1)
...@@ -348,9 +403,9 @@ class CommonAccessor: ...@@ -348,9 +403,9 @@ class CommonAccessor:
if shape is None: if shape is None:
if is_sparse: if is_sparse:
shape = total_dims shape = single_dim
else: else:
shape = self.get_shard(total_dims, pserver_num, shape = self.get_shard(size, pserver_num,
pserver_id) pserver_id)
dims.append(shape) dims.append(shape)
...@@ -379,6 +434,10 @@ class CommonAccessor: ...@@ -379,6 +434,10 @@ class CommonAccessor:
attrs += "entry: \"{}\" ".format(self.entry) attrs += "entry: \"{}\" ".format(self.entry)
attrs += "trainer_num: {} ".format(self.trainer_num) attrs += "trainer_num: {} ".format(self.trainer_num)
attrs += "sync: {} ".format(self.sync) attrs += "sync: {} ".format(self.sync)
if self.table_num:
attrs += "table_num: {} ".format(self.table_num)
if self.table_dim:
attrs += "table_dim: {} ".format(self.table_dim)
for param in self.params: for param in self.params:
attrs += "params: \"{}\" ".format(param) attrs += "params: \"{}\" ".format(param)
...@@ -448,10 +507,7 @@ class Table: ...@@ -448,10 +507,7 @@ class Table:
accessor_str = accessor_str.format( accessor_str = accessor_str.format(
conv_indent(indent), self.accessor_proto, conv_indent(indent)) conv_indent(indent), self.accessor_proto, conv_indent(indent))
attrs += accessor_str + "\n" attrs += accessor_str + "\n"
return table_str.format( elif self.accessor is not None:
conv_indent(indent), attrs, conv_indent(indent))
if self.accessor is not None:
attrs += self.accessor.to_string(indent) attrs += self.accessor.to_string(indent)
attrs += "\n" attrs += "\n"
...@@ -607,7 +663,9 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -607,7 +663,9 @@ class TheOnePSRuntime(RuntimeBase):
def _set_basic_info(self, context): def _set_basic_info(self, context):
self.context = context self.context = context
self.role_maker = context["role_maker"] self.role_maker = context["role_maker"]
self.origin_main_program = context["origin_main_program"] self.origin_main_program = context["origin_main_program"]
self.origin_main_programs = context["origin_main_programs"]
self.context[ self.context[
'is_heter_ps_mode'] = self.role_maker._is_heter_parameter_server_mode 'is_heter_ps_mode'] = self.role_maker._is_heter_parameter_server_mode
...@@ -615,10 +673,13 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -615,10 +673,13 @@ class TheOnePSRuntime(RuntimeBase):
self.context['trainer'] = TrainerRuntimeConfig(context[ self.context['trainer'] = TrainerRuntimeConfig(context[
'valid_strategy']) 'valid_strategy'])
self.context['ps_mode'] = self.context['trainer'].mode self.context['ps_mode'] = self.context['trainer'].mode
self.context['use_ps_gpu'] = context['valid_strategy'].use_ps_gpu self.context['use_ps_gpu'] = context['valid_strategy'].a_sync_configs[
'use_ps_gpu']
self.is_sync = True if self.context[ self.is_sync = True if self.context[
'ps_mode'] == DistributedMode.SYNC else False 'ps_mode'] == DistributedMode.SYNC else False
self.context['grad_name_to_param_name'] = {} self.context['grad_name_to_param_name'] = {}
self.context['tensor_table'] = {}
build_var_distributed(self.context)
def _init_worker(self): def _init_worker(self):
worker = self._get_fleet_proto(is_server=False, is_sync=self.is_sync) worker = self._get_fleet_proto(is_server=False, is_sync=self.is_sync)
...@@ -689,6 +750,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -689,6 +750,7 @@ class TheOnePSRuntime(RuntimeBase):
sync_kwargs = sync_strategy_envs() sync_kwargs = sync_strategy_envs()
kwargs.update(sync_kwargs) kwargs.update(sync_kwargs)
print("communicator config:", trainer_config.get_communicator_flags())
self._communicator = Communicator( self._communicator = Communicator(
trainer_config.mode, kwargs, trainer_config.mode, kwargs,
trainer_config.get_communicator_flags()) trainer_config.get_communicator_flags())
...@@ -893,7 +955,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -893,7 +955,7 @@ class TheOnePSRuntime(RuntimeBase):
common.table_name = self.context['grad_name_to_param_name'][ common.table_name = self.context['grad_name_to_param_name'][
ctx.origin_varnames()[0]] ctx.origin_varnames()[0]]
if self.ps_mode == DistributedMode.GEO: if self.context['ps_mode'] == DistributedMode.GEO:
table.table_class = "SparseGeoTable" table.table_class = "SparseGeoTable"
else: else:
all_table_proto = self.context[ all_table_proto = self.context[
...@@ -907,7 +969,8 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -907,7 +969,8 @@ class TheOnePSRuntime(RuntimeBase):
table.table_class = table_proto.table_class table.table_class = table_proto.table_class
else: else:
table.table_class = parse_table_class( table.table_class = parse_table_class(
common.table_name, self.origin_main_program) common.table_name,
ctx.program_id(), self.context)
if table.table_class != 'MemorySparseTable': if table.table_class != 'MemorySparseTable':
table.table_class = 'MemorySparseTable' table.table_class = 'MemorySparseTable'
warnings.warn( warnings.warn(
...@@ -925,12 +988,12 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -925,12 +988,12 @@ class TheOnePSRuntime(RuntimeBase):
warnings.warn( warnings.warn(
"The accessor of sparse table is not set, use default value." "The accessor of sparse table is not set, use default value."
) )
get_default_accessor_proto(table_proto.accessor, get_default_accessor_proto(
common.table_name, table_proto.accessor, common.table_name,
self.origin_main_program) ctx.program_id(), self.context)
check_embedding_dim(table_proto.accessor, check_embedding_dim(table_proto.accessor,
common.table_name, common.table_name,
self.origin_main_program) ctx.program_id(), self.context)
table.accessor_proto = text_format.MessageToString( table.accessor_proto = text_format.MessageToString(
table_proto.accessor) table_proto.accessor)
else: else:
...@@ -940,15 +1003,11 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -940,15 +1003,11 @@ class TheOnePSRuntime(RuntimeBase):
common.table_name = "MergedDense" common.table_name = "MergedDense"
adam_d2sum = self.context["user_defined_strategy"].adam_d2sum adam_d2sum = self.context["user_defined_strategy"].adam_d2sum
common.parse_by_optimizer(ctx.origin_varnames()[0], common.parse_by_optimizer(ctx, self.context)
ctx.is_sparse(),
ctx.sections()[1] if ctx.is_sparse()
else ctx.sections()[0], self.context,
adam_d2sum)
if ctx.is_sparse(): if ctx.is_sparse():
common.parse_entry(common.table_name, common.parse_entry(common.table_name,
self.origin_main_program) ctx.program_id(), self.context)
if is_sync: if is_sync:
common.sync = "true" common.sync = "true"
...@@ -1023,8 +1082,9 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -1023,8 +1082,9 @@ class TheOnePSRuntime(RuntimeBase):
self._server.init_server(proto_txt, string_hosts, role_id, trainers, self._server.init_server(proto_txt, string_hosts, role_id, trainers,
self._server_sub_program) self._server_sub_program)
dist_varnames = get_sparse_tablenames(self.origin_main_program, True) dist_varnames = get_sparse_tablenames(self.origin_main_programs, True)
sparse_varnames = get_sparse_tablenames(self.origin_main_program, False) sparse_varnames = get_sparse_tablenames(self.origin_main_programs,
False)
distributed_varnames = dist_varnames + sparse_varnames distributed_varnames = dist_varnames + sparse_varnames
...@@ -1070,6 +1130,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -1070,6 +1130,7 @@ class TheOnePSRuntime(RuntimeBase):
if var.name in exclude_var_names: if var.name in exclude_var_names:
return False return False
from .utils.public import _get_varname_parts
origin_varname, _, _ = _get_varname_parts(var.name) origin_varname, _, _ = _get_varname_parts(var.name)
if origin_varname.endswith("@GRAD"): if origin_varname.endswith("@GRAD"):
return False return False
...@@ -1085,16 +1146,24 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -1085,16 +1146,24 @@ class TheOnePSRuntime(RuntimeBase):
return is_valid return is_valid
def _get_inference_model_path(self, dirname):
if dirname.startswith("afs:") or dirname.startswith("hdfs:"):
model_path = "./dnn_plugin"
else:
model_path = os.path.join(dirname, "dnn_plugin")
return model_path
def _save_sparse_params(self, executor, dirname, context, main_program, def _save_sparse_params(self, executor, dirname, context, main_program,
mode): mode):
distributed_varnames = get_sparse_tablenames( distributed_varnames = get_sparse_tablenames(self.origin_main_programs,
self.context['origin_main_program'], True) True)
values = [] values = []
model_path = self._get_inference_model_path(dirname)
for id, names in context.items(): for id, names in context.items():
if names[0] not in distributed_varnames: if names[0] not in distributed_varnames:
# only save sparse param to local # only save sparse param to local
try: try:
self._worker.recv_and_save_model(id, dirname) self._worker.recv_and_save_model(id, model_path)
except: except:
pass pass
# save sparse & distributed param on server # save sparse & distributed param on server
...@@ -1221,10 +1290,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -1221,10 +1290,7 @@ class TheOnePSRuntime(RuntimeBase):
infer_program._copy_dist_param_info_from(program) infer_program._copy_dist_param_info_from(program)
if dirname.startswith("afs:") or dirname.startswith("hdfs:"): model_path = self._get_inference_model_path(dirname)
model_path = "./dnn_plugin"
else:
model_path = os.path.join(dirname, "dnn_plugin")
model_basename = "__model__" model_basename = "__model__"
model_basename = os.path.join(model_path, model_basename) model_basename = os.path.join(model_path, model_basename)
paddle.save(infer_program, model_basename) paddle.save(infer_program, model_basename)
...@@ -1266,7 +1332,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -1266,7 +1332,7 @@ class TheOnePSRuntime(RuntimeBase):
self._ps_inference_save_persistables(*args, **kwargs) self._ps_inference_save_persistables(*args, **kwargs)
def _load_sparse_params(self, dirname, context, main_program, mode): def _load_sparse_params(self, dirname, context, main_program, mode):
distributed_varnames = get_sparse_tablenames(self.origin_main_program, distributed_varnames = get_sparse_tablenames(self.origin_main_programs,
True) True)
values = [] values = []
for id, names in context.items(): for id, names in context.items():
......
...@@ -79,7 +79,7 @@ class GeoPsProgramBuilder(PsProgramBuilder): # 仅 CPU 模式 ...@@ -79,7 +79,7 @@ class GeoPsProgramBuilder(PsProgramBuilder): # 仅 CPU 模式
super(GeoPsProgramBuilder, self).__init__(pass_ctx) super(GeoPsProgramBuilder, self).__init__(pass_ctx)
if self.ps_mode != DistributedMode.GEO: if self.ps_mode != DistributedMode.GEO:
raise ValueError("ps mode: {} not matched {}", raise ValueError("ps mode: {} not matched {}",
format(ps_mode, "GeoPsProgramBuilder")) format(self.ps_mode, "GeoPsProgramBuilder"))
def _build_trainer_programs(self): def _build_trainer_programs(self):
append_send_ops_pass = new_pass("append_send_ops_pass", self.attrs) append_send_ops_pass = new_pass("append_send_ops_pass", self.attrs)
...@@ -97,9 +97,9 @@ class CpuSyncPsProgramBuilder(PsProgramBuilder): ...@@ -97,9 +97,9 @@ class CpuSyncPsProgramBuilder(PsProgramBuilder):
def __init__(self, pass_ctx): def __init__(self, pass_ctx):
logger.info("start building cpu-sync-ps program") logger.info("start building cpu-sync-ps program")
super(CpuSyncPsProgramBuilder, self).__init__(pass_ctx) super(CpuSyncPsProgramBuilder, self).__init__(pass_ctx)
if self.ps_mode != DistributedMode.SYNC: if self.ps_mode != DistributedMode.SYNC and self.ps_mode != DistributedMode.ASYNC:
raise ValueError("ps mode: {} not matched {}", raise ValueError("ps mode: {} not matched {}",
format(ps_mode, "CpuSyncPsProgramBuilder")) format(self.ps_mode, "CpuSyncPsProgramBuilder"))
def _build_trainer_programs(self): def _build_trainer_programs(self):
add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass", add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass",
...@@ -178,7 +178,7 @@ class HeterAsyncPsProgramBuilder(PsProgramBuilder): ...@@ -178,7 +178,7 @@ class HeterAsyncPsProgramBuilder(PsProgramBuilder):
if self.use_ps_gpu or self.ps_mode == DistributedMode.GEO or self.attrs[ if self.use_ps_gpu or self.ps_mode == DistributedMode.GEO or self.attrs[
'is_heter_ps_mode'] == False: 'is_heter_ps_mode'] == False:
raise ValueError("ps mode: {} not matched {}", raise ValueError("ps mode: {} not matched {}",
format(ps_mode, "HeterAsyncPsProgramBuilder")) format(self.ps_mode, "HeterAsyncPsProgramBuilder"))
def _build_trainer_programs(self): def _build_trainer_programs(self):
add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass", add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass",
......
...@@ -577,7 +577,7 @@ class CompileTimeStrategy(object): ...@@ -577,7 +577,7 @@ class CompileTimeStrategy(object):
sparse_ctx = CommContext(grad_name, [grad_name], sparse_ctx = CommContext(grad_name, [grad_name],
["127.0.0.1:6071"], [var_numel], ["127.0.0.1:6071"], [var_numel],
[grad_name], trainer_id, True, True, [grad_name], trainer_id, True, True,
is_distributed, idx, False) is_distributed, idx, False, False, -1)
idx += 1 idx += 1
send_ctx[sparse_ctx.var_name()] = sparse_ctx send_ctx[sparse_ctx.var_name()] = sparse_ctx
...@@ -615,7 +615,8 @@ class CompileTimeStrategy(object): ...@@ -615,7 +615,8 @@ class CompileTimeStrategy(object):
aggregate = True aggregate = True
dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"], dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"],
[var_numel], origin_varnames, trainer_id, [var_numel], origin_varnames, trainer_id,
aggregate, False, False, idx, False) aggregate, False, False, idx, False, False,
-1)
send_ctx[grad_name] = dense_ctx send_ctx[grad_name] = dense_ctx
idx += 1 idx += 1
else: else:
...@@ -630,7 +631,7 @@ class CompileTimeStrategy(object): ...@@ -630,7 +631,7 @@ class CompileTimeStrategy(object):
dense_ctx = CommContext(grad_name, [grad_name], dense_ctx = CommContext(grad_name, [grad_name],
["127.0.0.1:6071"], [var_numel], ["127.0.0.1:6071"], [var_numel],
[origin_varname], trainer_id, aggregate, [origin_varname], trainer_id, aggregate,
False, False, idx, False) False, False, idx, False, False, -1)
send_ctx[grad_name] = dense_ctx send_ctx[grad_name] = dense_ctx
idx += 1 idx += 1
return idx return idx
...@@ -672,7 +673,7 @@ class CompileTimeStrategy(object): ...@@ -672,7 +673,7 @@ class CompileTimeStrategy(object):
sparse_ctx = CommContext(grad_name, splited_varname, ep_list, shape, sparse_ctx = CommContext(grad_name, splited_varname, ep_list, shape,
[grad_name], trainer_id, True, True, [grad_name], trainer_id, True, True,
is_distributed, idx, False) is_distributed, idx, False, False, -1)
idx += 1 idx += 1
send_ctx[sparse_ctx.var_name()] = sparse_ctx send_ctx[sparse_ctx.var_name()] = sparse_ctx
...@@ -750,7 +751,7 @@ class CompileTimeStrategy(object): ...@@ -750,7 +751,7 @@ class CompileTimeStrategy(object):
sections = [1] * len(endpoints) sections = [1] * len(endpoints)
names = [name] * len(endpoints) names = [name] * len(endpoints)
ctx = CommContext(name, names, endpoints, sections, [name], trainer_id, ctx = CommContext(name, names, endpoints, sections, [name], trainer_id,
True, False, False, idx, True) True, False, False, idx, True, False, -1)
return name, ctx return name, ctx
def _create_vars_from_blocklist(self, block_list): def _create_vars_from_blocklist(self, block_list):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册