未验证 提交 bcaf88d2 编写于 作者: W wangguanqun 提交者: GitHub

Ps optimizer multi programs (#39883)

* fix benchmark and communicator config

* fix bugs of the_one_ps

* multi program and fix bug in optimizer

* multi program in the_one_ps

* public commcontext

* ps optimizer multi programs

* the one ps merge

* fix bug in test
上级 faece382
...@@ -29,7 +29,6 @@ from ..fluid.layers import utils ...@@ -29,7 +29,6 @@ from ..fluid.layers import utils
from ..fluid.dygraph import layers from ..fluid.dygraph import layers
from ..fluid.dygraph.parallel import prepare_context from ..fluid.dygraph.parallel import prepare_context
import paddle import paddle
from .fleet import fleet
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle import _C_ops from paddle import _C_ops
...@@ -1422,6 +1421,7 @@ def split(x, ...@@ -1422,6 +1421,7 @@ def split(x,
"graph mode, plese use ParallelEmbedding, ParallelRowLinear, " "graph mode, plese use ParallelEmbedding, ParallelRowLinear, "
"ParallelColumnLinear instead.") "ParallelColumnLinear instead.")
else: else:
from .fleet import fleet
assert fleet._role_maker, ("To use paddle.distributed.split, " assert fleet._role_maker, ("To use paddle.distributed.split, "
"you must call fleet.init() firstly.") "you must call fleet.init() firstly.")
rank = fleet.worker_index() rank = fleet.worker_index()
......
...@@ -31,14 +31,19 @@ class ParameterServerOptimizer(MetaOptimizerBase): ...@@ -31,14 +31,19 @@ class ParameterServerOptimizer(MetaOptimizerBase):
self.inner_opt = optimizer self.inner_opt = optimizer
# we do not allow meta optimizer to be inner optimizer currently # we do not allow meta optimizer to be inner optimizer currently
self.meta_optimizers_white_list = [] self.meta_optimizers_white_list = []
self.pass_ctx = PassContext()
def _set_basic_info(self, loss, role_maker, user_defined_optimizer, def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
user_defined_strategy): user_defined_strategy):
super(ParameterServerOptimizer, self)._set_basic_info( super(ParameterServerOptimizer, self)._set_basic_info(
loss, role_maker, user_defined_optimizer, user_defined_strategy) loss, role_maker, user_defined_optimizer, user_defined_strategy)
def _set_origin_programs(self, losses):
self.origin_main_programs = []
for loss in losses:
self.origin_main_programs.append(loss.block.program)
def _init_ps_pass_context(self, loss, startup_program): def _init_ps_pass_context(self, loss, startup_program):
self.pass_ctx = PassContext()
attrs = {} attrs = {}
# trainer # trainer
attrs["env"] = get_dist_env() attrs["env"] = get_dist_env()
...@@ -46,9 +51,9 @@ class ParameterServerOptimizer(MetaOptimizerBase): ...@@ -46,9 +51,9 @@ class ParameterServerOptimizer(MetaOptimizerBase):
attrs['loss'] = loss attrs['loss'] = loss
attrs['min_block_size'] = 81920 attrs['min_block_size'] = 81920
attrs['origin_main_program'] = loss.block.program attrs['origin_main_program'] = loss.block.program
attrs['origin_main_programs'] = [loss.block.program]
attrs['origin_startup_program'] = startup_program attrs['origin_startup_program'] = startup_program
attrs['origin_startup_programs'] = [startup_program]
attrs['origin_main_programs'] = self.origin_main_programs
attrs['cloned_main'] = attrs['origin_main_program'].clone() attrs['cloned_main'] = attrs['origin_main_program'].clone()
attrs['cloned_startup'] = attrs['origin_startup_program'].clone() attrs['cloned_startup'] = attrs['origin_startup_program'].clone()
...@@ -90,10 +95,11 @@ class ParameterServerOptimizer(MetaOptimizerBase): ...@@ -90,10 +95,11 @@ class ParameterServerOptimizer(MetaOptimizerBase):
return False return False
def _can_apply(self): def _can_apply(self):
if self._attrs['role_maker']._is_collective or self._attrs[ if self.role_maker._is_collective:
'k_steps'] < 0:
return False return False
return True
k_steps = self.user_defined_strategy.a_sync_configs["k_steps"]
return True if k_steps >= 0 else False
def minimize_impl(self, def minimize_impl(self,
loss, loss,
...@@ -104,12 +110,37 @@ class ParameterServerOptimizer(MetaOptimizerBase): ...@@ -104,12 +110,37 @@ class ParameterServerOptimizer(MetaOptimizerBase):
no_grad_set) no_grad_set)
if startup_program == None: if startup_program == None:
startup_program = paddle.static.default_startup_program() startup_program = paddle.static.default_startup_program()
print("program after inner optimizer minimize:",
str(loss.block.program))
self._set_origin_programs([loss])
self._init_ps_pass_context(loss, startup_program) self._init_ps_pass_context(loss, startup_program)
ps_builder = PsProgramBuilderFactory()._create_ps_program_builder( ps_builder = PsProgramBuilderFactory()._create_ps_program_builder(
self.pass_ctx) self.pass_ctx)
ps_builder._build_programs() ps_builder._build_programs()
return None, None return None, None
def minimize_losses_impl(self,
losses,
startup_program=None,
parameter_list=None,
no_grad_set=None):
if parameter_list is None:
parameter_list = [None] * len(losses)
for idx, loss in enumerate(losses):
startup_prog = startup_program[idx]
parameters = parameter_list[idx]
self.inner_opt.minimize(loss, startup_prog, parameters, no_grad_set)
self._set_origin_programs(losses)
for idx, loss in enumerate(losses):
print("ps_optimizer idx loss:", idx, loss)
startup_prog = startup_program[idx]
self._init_ps_pass_context(loss, startup_prog)
ps_builder = PsProgramBuilderFactory()._create_ps_program_builder(
self.pass_ctx)
ps_builder._build_programs()
startup_program[idx] = self.pass_ctx._attrs['cloned_startup']
return None, None
def _can_apply_geo(self, program): def _can_apply_geo(self, program):
def get_sys_free_mem(): def get_sys_free_mem():
plat = platform.system() plat = platform.system()
......
...@@ -74,6 +74,8 @@ class AppendSendOpsPass(PassBase): # 该 pass 被多种模式复用 ...@@ -74,6 +74,8 @@ class AppendSendOpsPass(PassBase): # 该 pass 被多种模式复用
def _apply_single_impl(self, main_program, startup_program, pass_ctx): def _apply_single_impl(self, main_program, startup_program, pass_ctx):
attrs = pass_ctx._attrs attrs = pass_ctx._attrs
print("pass loss program id:", id(attrs['loss'].block.program))
print("pass main program id:", id(main_program))
ps_mode = attrs['ps_mode'] ps_mode = attrs['ps_mode']
if ps_mode == DistributedMode.GEO: if ps_mode == DistributedMode.GEO:
send_ctx = get_geo_trainer_send_context(attrs) # geo 模式 send_ctx = get_geo_trainer_send_context(attrs) # geo 模式
...@@ -84,6 +86,8 @@ class AppendSendOpsPass(PassBase): # 该 pass 被多种模式复用 ...@@ -84,6 +86,8 @@ class AppendSendOpsPass(PassBase): # 该 pass 被多种模式复用
for merged_name, send in send_ctx.items(): for merged_name, send in send_ctx.items():
if send.is_sparse() and ps_mode != DistributedMode.GEO: if send.is_sparse() and ps_mode != DistributedMode.GEO:
continue continue
if send.program_id() != id(attrs['loss'].block.program):
continue
logger.info('merged_name, send: {}, {}'.format(merged_name, send)) logger.info('merged_name, send: {}, {}'.format(merged_name, send))
is_sparse = 1 if send.is_sparse() else 0 is_sparse = 1 if send.is_sparse() else 0
is_sparse = 2 if send.is_distributed() else is_sparse is_sparse = 2 if send.is_distributed() else is_sparse
...@@ -496,6 +500,7 @@ class DeleteOptimizesPass(PassBase): ...@@ -496,6 +500,7 @@ class DeleteOptimizesPass(PassBase):
persistable=True) persistable=True)
def _apply_single_impl(self, main_program, startup_program, pass_ctx): def _apply_single_impl(self, main_program, startup_program, pass_ctx):
print("delete_optimizer_pass")
attrs = pass_ctx._attrs attrs = pass_ctx._attrs
optimizer_ops = get_optimize_ops(main_program) optimizer_ops = get_optimize_ops(main_program)
lr_ops = get_lr_ops(main_program) lr_ops = get_lr_ops(main_program)
......
...@@ -40,12 +40,12 @@ def get_program_by_id(context, program_id): ...@@ -40,12 +40,12 @@ def get_program_by_id(context, program_id):
programs = context["origin_main_programs"] programs = context["origin_main_programs"]
for i, program in enumerate(programs): for i, program in enumerate(programs):
if id(program) == program_id: if id(program) == program_id:
return program, context["origin_startup_programs"][i] return program, context["origin_startup_programs"][i], i
return None, None return None, None, None
def parse_table_class(varname, program_id, context): def parse_table_class(varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id) main_program, startup_program, idx = get_program_by_id(context, program_id)
for op in main_program.global_block().ops: for op in main_program.global_block().ops:
if not is_distributed_sparse_op(op) and not is_sparse_op(op): if not is_distributed_sparse_op(op) and not is_sparse_op(op):
continue continue
...@@ -60,7 +60,7 @@ def parse_table_class(varname, program_id, context): ...@@ -60,7 +60,7 @@ def parse_table_class(varname, program_id, context):
def check_embedding_dim(accessor_proto, varname, program_id, context): def check_embedding_dim(accessor_proto, varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id) main_program, startup_program, idx = get_program_by_id(context, program_id)
embedding_dim = 0 embedding_dim = 0
for var in main_program.list_vars(): for var in main_program.list_vars():
if var.name == varname: if var.name == varname:
...@@ -94,10 +94,9 @@ class Service: ...@@ -94,10 +94,9 @@ class Service:
class GpuService(Service): class GpuService(Service):
def __init__(self): def __init__(self):
super(GpuService).__init__(self) super(GpuService, self).__init__()
def _set(self, service_proto): def _set(self, service_proto):
super(GpuService)._set(service_proto)
service_proto.server_class = 'PsLocalServer' service_proto.server_class = 'PsLocalServer'
service_proto.client_class = 'PsLocalClient' service_proto.client_class = 'PsLocalClient'
...@@ -111,7 +110,8 @@ class Accessor: ...@@ -111,7 +110,8 @@ class Accessor:
# TableAccessorParameter accessor # TableAccessorParameter accessor
def _set(self, accessor_proto, varname, program_id, context): def _set(self, accessor_proto, varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id) main_program, startup_program, idx = get_program_by_id(context,
program_id)
embedding_dim = 0 embedding_dim = 0
for var in main_program.list_vars(): for var in main_program.list_vars():
if var.name == varname: if var.name == varname:
...@@ -236,7 +236,8 @@ class CommonAccessor(Accessor): ...@@ -236,7 +236,8 @@ class CommonAccessor(Accessor):
self.opt_init_map = opt_init_map self.opt_init_map = opt_init_map
def parse_entry(self, varname, program_id, context): def parse_entry(self, varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id) main_program, startup_program, idx = get_program_by_id(context,
program_id)
for op in main_program.global_block().ops: for op in main_program.global_block().ops:
if not is_distributed_sparse_op(op) and not is_sparse_op(op): if not is_distributed_sparse_op(op) and not is_sparse_op(op):
continue continue
...@@ -290,7 +291,7 @@ class CommonAccessor(Accessor): ...@@ -290,7 +291,7 @@ class CommonAccessor(Accessor):
print("parse_by_optimizer table_id:{} is_datanorm:{}".format( print("parse_by_optimizer table_id:{} is_datanorm:{}".format(
ctx.table_id(), ctx.is_datanorm_table())) ctx.table_id(), ctx.is_datanorm_table()))
main_program, startup_program = get_program_by_id(context, main_program, startup_program, idx = get_program_by_id(context,
ctx.program_id()) ctx.program_id())
pserver_id = get_role_id(context['role_maker']) pserver_id = get_role_id(context['role_maker'])
pserver_num = len(get_ps_endpoints(context['role_maker'])) pserver_num = len(get_ps_endpoints(context['role_maker']))
...@@ -359,10 +360,11 @@ class CommonAccessor(Accessor): ...@@ -359,10 +360,11 @@ class CommonAccessor(Accessor):
param = main_program.global_block().vars[oop.input( param = main_program.global_block().vars[oop.input(
formal_name)[0]] formal_name)[0]]
#TODO: for dense learning_rate, can be different from sparse lr #TODO: for dense learning_rate, can be different from sparse lr
if formal_name == "LearningRate" and param.name != "learning_rate_0": if formal_name == "LearningRate" and param.name != "learning_rate_" + str(
idx):
warnings.warn("will support decay soon") warnings.warn("will support decay soon")
param = main_program.global_block().vars[ param = main_program.global_block().vars[
"learning_rate_0"] "learning_rate_" + str(idx)]
initializer = self.get_initializer_attr(param.name, initializer = self.get_initializer_attr(param.name,
startup_program) startup_program)
...@@ -404,10 +406,11 @@ class CommonAccessor(Accessor): ...@@ -404,10 +406,11 @@ class CommonAccessor(Accessor):
else: else:
param = main_program.global_block().vars[oop.input( param = main_program.global_block().vars[oop.input(
formal_name)[0]] formal_name)[0]]
if formal_name == "LearningRate" and param.name != "learning_rate_0": if formal_name == "LearningRate" and param.name != "learning_rate_" + str(
idx):
warnings.warn("will support decay soon") warnings.warn("will support decay soon")
param = main_program.global_block().vars[ param = main_program.global_block().vars[
"learning_rate_0"] "learning_rate_" + str(idx)]
if shape is None: if shape is None:
if is_sparse: if is_sparse:
...@@ -707,6 +710,7 @@ class PsDescBuilder(object): ...@@ -707,6 +710,7 @@ class PsDescBuilder(object):
self.ps_mode = context['ps_mode'] self.ps_mode = context['ps_mode']
self.is_heter_ps_mode = context['is_heter_ps_mode'] self.is_heter_ps_mode = context['is_heter_ps_mode']
self.use_ps_gpu = context['use_ps_gpu'] self.use_ps_gpu = context['use_ps_gpu']
self.barrier_table_id = None
self.send_ctx = get_the_one_send_context( self.send_ctx = get_the_one_send_context(
self.context, self.context,
use_origin_program=True, use_origin_program=True,
...@@ -767,6 +771,8 @@ class PsDescBuilder(object): ...@@ -767,6 +771,8 @@ class PsDescBuilder(object):
table_proto = self.ps_desc.server_param.downpour_server_param.downpour_table_param.add( table_proto = self.ps_desc.server_param.downpour_server_param.downpour_table_param.add(
) )
table._set(table_proto) table._set(table_proto)
if type(table) == BarrierTable and self.barrier_table_id is None:
self.barrier_table_id = table.idx
self.service._set( self.service._set(
self.ps_desc.server_param.downpour_server_param.service_param) self.ps_desc.server_param.downpour_server_param.service_param)
return text_format.MessageToString(self.ps_desc) return text_format.MessageToString(self.ps_desc)
...@@ -820,9 +826,9 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -820,9 +826,9 @@ class TheOnePSRuntime(RuntimeBase):
self.context['tensor_table'] = {} self.context['tensor_table'] = {}
build_var_distributed(self.context) build_var_distributed(self.context)
endpoints = get_ps_endpoints(self.role_maker) self.endpoints = get_ps_endpoints(self.role_maker)
self.string_hosts = [] self.string_hosts = []
for idx, ep in enumerate(endpoints): for idx, ep in enumerate(self.endpoints):
host, port = ep.split(":") host, port = ep.split(":")
pshost = fluid.core.PSHost(host, int(port), idx) pshost = fluid.core.PSHost(host, int(port), idx)
self.string_hosts.append(pshost.serialize_to_string()) self.string_hosts.append(pshost.serialize_to_string())
...@@ -848,7 +854,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -848,7 +854,7 @@ class TheOnePSRuntime(RuntimeBase):
kwargs["trainer_id"] = self.role_maker._worker_index() kwargs["trainer_id"] = self.role_maker._worker_index()
return kwargs return kwargs
proto_txt = worker_desc + "\n" + server_desc proto_txt = worker_desc
debug = bool(int(os.getenv("PSERVER_DEBUG", "0"))) debug = bool(int(os.getenv("PSERVER_DEBUG", "0")))
if debug: if debug:
print("worker: \n{}".format(proto_txt)) print("worker: \n{}".format(proto_txt))
...@@ -859,7 +865,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -859,7 +865,7 @@ class TheOnePSRuntime(RuntimeBase):
self.context, self.context,
split_dense_table=self.is_heter_ps_mode, split_dense_table=self.is_heter_ps_mode,
use_origin_program=self.is_heter_ps_mode, use_origin_program=self.is_heter_ps_mode,
ep_list=endpoints) ep_list=self.endpoints)
trainer_config = self.context['trainer'] trainer_config = self.context['trainer']
debug = bool(int(os.getenv("PSERVER_DEBUG", "0"))) debug = bool(int(os.getenv("PSERVER_DEBUG", "0")))
...@@ -876,10 +882,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -876,10 +882,7 @@ class TheOnePSRuntime(RuntimeBase):
kwargs["trainer_id"] = self.role_maker._role_id() kwargs["trainer_id"] = self.role_maker._role_id()
kwargs["trainers"] = self.role_maker._worker_num() kwargs["trainers"] = self.role_maker._worker_num()
for table in server.servers[0].tables: #TODO kwargs["barrier_table_id"] = self.ps_desc_builder.barrier_table_id
if table.table_class == "BarrierTable":
kwargs["barrier_table_id"] = table.id
break
if self.context['ps_mode'] == DistributedMode.SYNC: if self.context['ps_mode'] == DistributedMode.SYNC:
sync_kwargs = sync_strategy_envs() sync_kwargs = sync_strategy_envs()
...@@ -1009,7 +1012,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -1009,7 +1012,7 @@ class TheOnePSRuntime(RuntimeBase):
if origin_varname.endswith("@GRAD"): if origin_varname.endswith("@GRAD"):
return False return False
if origin_varname == "learning_rate_0": if origin_varname.startswith("learning_rate_"):
return False return False
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
...@@ -1113,7 +1116,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -1113,7 +1116,7 @@ class TheOnePSRuntime(RuntimeBase):
"in fleet.save() function, executor must be as Executor type") "in fleet.save() function, executor must be as Executor type")
if main_program is None: if main_program is None:
main_program = self.context['origin_ps_main_program'] main_program = self.context['origin_main_program']
if isinstance(main_program, CompiledProgram): if isinstance(main_program, CompiledProgram):
raise TypeError( raise TypeError(
......
...@@ -88,7 +88,7 @@ class GeoPsProgramBuilder(PsProgramBuilder): # 仅 CPU 模式 ...@@ -88,7 +88,7 @@ class GeoPsProgramBuilder(PsProgramBuilder): # 仅 CPU 模式
self.attrs['origin_main_program'] = self.cloned_main self.attrs['origin_main_program'] = self.cloned_main
if self.launch_barrier and self.launch_barrier_flag: if self.launch_barrier and self.launch_barrier_flag:
wait_server_ready(server_endpoints) wait_server_ready(self.server_endpoints)
return return
...@@ -103,10 +103,13 @@ class CpuSyncPsProgramBuilder(PsProgramBuilder): ...@@ -103,10 +103,13 @@ class CpuSyncPsProgramBuilder(PsProgramBuilder):
format(self.ps_mode, "PsProgramBuilder")) format(self.ps_mode, "PsProgramBuilder"))
def _build_trainer_programs(self): def _build_trainer_programs(self):
print("build trainer program entry")
print("before ps program builder program:", self.cloned_main)
add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass", add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass",
self.attrs) self.attrs)
add_lr_decay_table_pass.apply([], [], self.pass_ctx) add_lr_decay_table_pass.apply([], [], self.pass_ctx)
print("before distributed op pass")
distributed_ops_pass = new_pass("distributed_ops_pass", self.attrs) distributed_ops_pass = new_pass("distributed_ops_pass", self.attrs)
distributed_ops_pass.apply([self.cloned_main], [None], self.pass_ctx) distributed_ops_pass.apply([self.cloned_main], [None], self.pass_ctx)
...@@ -126,9 +129,10 @@ class CpuSyncPsProgramBuilder(PsProgramBuilder): ...@@ -126,9 +129,10 @@ class CpuSyncPsProgramBuilder(PsProgramBuilder):
self.attrs['origin_main_program'] = self.cloned_main self.attrs['origin_main_program'] = self.cloned_main
self.attrs['origin_startup_program'] = self.cloned_startup self.attrs['origin_startup_program'] = self.cloned_startup
print("after ps program builder program:", self.cloned_main)
if self.launch_barrier and self.launch_barrier_flag: if self.launch_barrier and self.launch_barrier_flag:
wait_server_ready(server_endpoints) wait_server_ready(self.server_endpoints)
return return
...@@ -167,7 +171,7 @@ class GpuPsProgramBuilder(PsProgramBuilder): ...@@ -167,7 +171,7 @@ class GpuPsProgramBuilder(PsProgramBuilder):
self.attrs['origin_startup_program'] = self.cloned_startup self.attrs['origin_startup_program'] = self.cloned_startup
if self.launch_barrier and self.launch_barrier_flag: if self.launch_barrier and self.launch_barrier_flag:
wait_server_ready(server_endpoints) wait_server_ready(self.server_endpoints)
return return
...@@ -220,7 +224,7 @@ class HeterAsyncPsProgramBuilder(PsProgramBuilder): ...@@ -220,7 +224,7 @@ class HeterAsyncPsProgramBuilder(PsProgramBuilder):
[self.cloned_startup], self.pass_ctx) [self.cloned_startup], self.pass_ctx)
if self.launch_barrier and self.launch_barrier_flag: if self.launch_barrier and self.launch_barrier_flag:
wait_server_ready(server_endpoints) wait_server_ready(self.server_endpoints)
return return
......
...@@ -450,9 +450,8 @@ def get_the_one_send_context(context, ...@@ -450,9 +450,8 @@ def get_the_one_send_context(context,
idx = 0 idx = 0
for i, program in enumerate(origin_programs): for i, program in enumerate(origin_programs):
merged_dense_pairs = context['merged_dense_pairs'][i] merged_dense_pairs = context['merged_dense_pairs'][i]
idx += get_dense_send_context(program, send_ctx, idx, idx = get_dense_send_context(program, send_ctx, idx, merged_dense_pairs,
merged_dense_pairs, trainer_id, trainer_id, split_dense_table)
split_dense_table)
distibuted_varnames = get_sparse_tablenames(origin_programs, True) distibuted_varnames = get_sparse_tablenames(origin_programs, True)
print("public distibuted_varnames:", distibuted_varnames) print("public distibuted_varnames:", distibuted_varnames)
for i, program in enumerate(origin_programs): for i, program in enumerate(origin_programs):
......
...@@ -146,9 +146,13 @@ class TestPsTrainerPass(PsPassTestBase): ...@@ -146,9 +146,13 @@ class TestPsTrainerPass(PsPassTestBase):
self.config['ps_mode_config'] = "../ps/gpu_ps_config.yaml" self.config['ps_mode_config'] = "../ps/gpu_ps_config.yaml"
self.config['debug_new_minimize'] = '0' self.config['debug_new_minimize'] = '0'
self.config['log_dir'] = ps_log_root_dir + "gpubox_log_old_minimize"
remove_path_if_exists(self.config['log_dir'])
self.ps_launch("gpu-ps") self.ps_launch("gpu-ps")
self.config['debug_new_minimize'] = '1' self.config['debug_new_minimize'] = '1'
self.config['log_dir'] = ps_log_root_dir + "gpubox_log_new_minimize"
remove_path_if_exists(self.config['log_dir'])
self.ps_launch("gpu-ps") self.ps_launch("gpu-ps")
file1 = '/ps_log/gpubox_run_minimize_debug:_0_worker_main.prototxt' file1 = '/ps_log/gpubox_run_minimize_debug:_0_worker_main.prototxt'
......
...@@ -382,6 +382,7 @@ class DnnTrainer(object): ...@@ -382,6 +382,7 @@ class DnnTrainer(object):
ps_optimizer = ParameterServerOptimizer(inner_optimizer) ps_optimizer = ParameterServerOptimizer(inner_optimizer)
ps_optimizer._set_basic_info(loss, self.role_maker, inner_optimizer, ps_optimizer._set_basic_info(loss, self.role_maker, inner_optimizer,
user_defined_strategy) user_defined_strategy)
ps_optimizer._set_origin_programs([loss])
ps_optimizer._init_ps_pass_context(loss, startup_program) ps_optimizer._init_ps_pass_context(loss, startup_program)
_main = ps_optimizer.pass_ctx._attrs['cloned_main'] _main = ps_optimizer.pass_ctx._attrs['cloned_main']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册