Unverified commit 196dbfc2, authored by ziyoujiyi, committed by GitHub

ps optimize refactor (#38982)

* delete gloo connect retry

* the_one_ps dirs reconstruct

* .

* .

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* the one ps dirs modify

* the one ps dirs modify

* the one ps dirs modify

* the one ps dirs modify

* refactor ps optimize

* refactor ps optimize

* refactor ps optimize

* .

* .

* .

* .

* .

* .

* refactor theoneps

* the_one_ps

* add ps pass unittest

* add ps pass unittest

* ps unitest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* add cpu_async_ps_mode test

* add cpu_async_ps_mode test

* add cpu_async_ps_mode test

* ps unittest ready

* ps unittest ready

* solve dist_pass init conflict

* solve import CommContext error

* unittest ok

* implement AllocateFrom

* solve setup.py.in conflict

* solve conflict

* solve conflict

* solve conflict

* .

* .
Co-authored-by: Nzkh2016 <zhangkaihuo@baidu.com>
Parent: de0bad2a
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import fluid
import paddle.distributed.passes
from .meta_optimizer_base import MetaOptimizerBase
from paddle.fluid import core
import subprocess
import re
import os
import platform
from paddle.distributed.ps.utils.public import *
from paddle.distributed.passes import PassContext
from ..base.private_helper_function import wait_server_ready
from paddle.distributed.ps.utils.ps_factory import PsProgramBuilderFactory
class ParameterServerOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
super(ParameterServerOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
# we do not allow meta optimizer to be inner optimizer currently
self.meta_optimizers_white_list = []
self.attrs = {}
self.pass_ctx = PassContext()
def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
user_defined_strategy):
super(ParameterServerOptimizer, self)._set_basic_info(
loss, role_maker, user_defined_optimizer, user_defined_strategy)
def _init_ps_pass_context(self, loss, startup_program):
# trainer
self.attrs["env"] = get_dist_env()
self.attrs['loss'] = loss
self.attrs['min_block_size'] = 81920
self.attrs['origin_main_program'] = loss.block.program
self.attrs['origin_startup_program'] = startup_program
self.attrs['cloned_main'] = loss.block.program.clone()
self.attrs['cloned_startup'] = startup_program.clone()
self.attrs['user_defined_strategy'] = self.user_defined_strategy
self.attrs['trainer'] = TrainerRuntimeConfig(self.user_defined_strategy)
self.attrs['ps_mode'] = self.attrs['trainer'].mode
self.attrs['role_maker'] = self.role_maker
self.attrs[
'is_heter_ps_mode'] = self.role_maker._is_heter_parameter_server_mode
self.attrs['is_worker'] = self.role_maker._is_worker()
self.attrs['is_server'] = self.role_maker._is_server()
self.attrs['is_heter_worker'] = self.role_maker._is_heter_worker()
self.attrs['use_ps_gpu'] = self.user_defined_strategy.a_sync_configs[
"use_ps_gpu"]
self.attrs[
'lr_decay_steps'] = self.user_defined_strategy.a_sync_configs[
"lr_decay_steps"]
self.attrs['k_steps'] = self.user_defined_strategy.a_sync_configs[
"k_steps"]
self.attrs[
'launch_barrier'] = self.user_defined_strategy.a_sync_configs[
"launch_barrier"]
self.attrs['launch_barrier_flag'] = int(
os.getenv("FLAGS_LAUNCH_BARRIER", "1"))
build_var_distributed(self.attrs)
# server
self.attrs['_main_server'] = fluid.Program()
self.attrs['_startup_server'] = fluid.Program()
self.attrs['tensor_table'] = {}
self.pass_ctx._attrs = self.attrs
def _is_graph_out(self):
return False
def _can_apply(self):
        if self.role_maker._is_collective:
            return False
        k_steps = self.user_defined_strategy.a_sync_configs["k_steps"]
        return True if k_steps >= 0 else False
def minimize_impl(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
self.inner_opt.minimize(loss, startup_program, parameter_list,
no_grad_set)
        if startup_program is None:
startup_program = paddle.static.default_startup_program()
self._init_ps_pass_context(loss, startup_program)
ps_builder = PsProgramBuilderFactory()._create_ps_program_builder(
self.pass_ctx)
ps_builder._build_programs()
return None, None
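    # _can_apply_geo (below) estimates whether GEO mode fits into host memory: it only
    # considers the plain SGD optimizer, sums the sizes of all persistable LoDTensor
    # parameters, multiplies that by 5.0 as a rough upper bound for optimizer and
    # communication copies, adds an estimate for temporary op outputs (a negative dim
    # is priced at an eval batch size of 1024), and compares the total against the
    # free system memory reported by the OS.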
def _can_apply_geo(self, program):
def get_sys_free_mem():
            plat = platform.system()
            if plat == "Darwin":
                vm = subprocess.Popen(
                    ['vm_stat'],
                    stdout=subprocess.PIPE).communicate()[0].decode()
                # parse vm_stat output into a dict of byte counts
                vmLines = vm.split('\n')
sep = re.compile(r':[\s]+')
vmStats = {}
for row in range(1, len(vmLines) - 2):
rowText = vmLines[row].strip()
rowElements = sep.split(rowText)
vmStats[(rowElements[0]
)] = int(rowElements[1].strip(r'\.')) * 4096
return vmStats["Pages free"]
elif platform.system() == "Linux":
mems = {}
with open('/proc/meminfo', 'rb') as f:
for line in f:
fields = line.split()
mems[fields[0]] = int(fields[1]) * 1024
free = mems[b'MemFree:']
return free
else:
                raise ValueError(
                    "%s platform is unsupported in parameter server optimizer"
                    % (platform.system()))
if not isinstance(self.inner_opt, fluid.optimizer.SGDOptimizer):
return False
free = get_sys_free_mem()
processed_var_names = set(["@EMPTY@"])
param_memory_size = 0
for varname in program.global_block().vars:
var = program.global_block().vars[varname]
if not var.persistable or var.desc.type(
) != core.VarDesc.VarType.LOD_TENSOR:
continue
set_var_lod_type(var)
param_memory_size += get_var_mem_size(var)
processed_var_names.add(varname)
upper_mem_use = param_memory_size * 5.0
program_tmp_vars = dict()
eval_batch_size = 1024
for op in program.global_block().ops:
for var_name in op.output_arg_names:
if var_name in processed_var_names:
continue
processed_var_names.add(var_name)
var = program.global_block().vars[var_name]
if var.desc.type() != core.VarDesc.VarType.LOD_TENSOR:
continue
data_count = 1
neg_dim_count = 0
for x in var.shape:
if x < 0:
if neg_dim_count >= 1:
raise ValueError(
"Var %s has more than one negative dim." %
(var_name))
neg_dim_count += 1
data_count *= (-x)
else:
data_count *= x
program_tmp_vars[var_name] = (
data_count, neg_dim_count,
vars_metatools.dtype_to_size[var.dtype])
for varname in program_tmp_vars:
data_count, neg_dim_count, type_size = program_tmp_vars[varname]
if neg_dim_count == 1:
data_count *= eval_batch_size
var_memory = data_count * type_size
upper_mem_use += var_memory
if upper_mem_use < free:
return True
else:
return False
def _enable_strategy(self, dist_strategy, context):
if dist_strategy.a_sync_configs["k_steps"] >= 0:
return
dist_strategy.a_sync = True
is_geo = self._can_apply_geo(context["origin_main_program"])
dist_strategy.a_sync_configs["k_steps"] = 800 if is_geo else 0
def _disable_strategy(self, dist_strategy):
dist_strategy.a_sync = False
dist_strategy.a_sync_configs["k_steps"] = -1
@@ -19,6 +19,10 @@ from .auto_parallel_sharding import *
from .auto_parallel_amp import *
from .auto_parallel_recompute import *
from .cpp_pass import *
import os
if os.getenv("WITH_DISTRIBUTE") == "ON":
from .ps_trainer_pass import *
from .ps_server_pass import *
__all__ = [
'new_pass',
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import logging
from paddle import fluid
from ..ps.utils.public import *
from paddle.framework import core
from .pass_base import PassBase, register_pass
from paddle.optimizer.lr import LRScheduler
from paddle.optimizer.lr import ExponentialDecay, NoamDecay, PiecewiseDecay, NaturalExpDecay, InverseTimeDecay
from paddle.fluid.layers.learning_rate_scheduler import exponential_decay, noam_decay, piecewise_decay, natural_exp_decay, inverse_time_decay
@register_pass("add_lr_decay_table_pass")
class AddLrDecayTablePass(PassBase):
def __init__(self):
super(AddLrDecayTablePass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _add_tensor_table(self,
attrs,
feed_var_name,
fetch_var_name="",
startup_program=None,
main_program=None,
tensor_table_class=""):
tensor_table_dict = {}
tensor_table_dict[feed_var_name] = {}
tensor_table_dict[feed_var_name]["feed_var_name"] = feed_var_name
tensor_table_dict[feed_var_name]["fetch_var_name"] = fetch_var_name
tensor_table_dict[feed_var_name]["startup_program"] = startup_program
tensor_table_dict[feed_var_name]["main_program"] = main_program
tensor_table_dict[feed_var_name][
"tensor_table_class"] = tensor_table_class
attrs['tensor_table'] = tensor_table_dict
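    # _get_lr_sheduler_program (below) builds a small standalone main/startup program
    # pair that recomputes the learning rate on the server side for the supported
    # decay schedulers; _apply_single_impl then registers it as a "GlobalStepTable"
    # tensor table keyed by @LR_DECAY_COUNTER@.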
def _get_lr_sheduler_program(self, lr_sheduler, lr_decay_steps):
        scheduler_decay = [
            'NoamDecay', 'NaturalExpDecay', 'InverseTimeDecay',
            'ExponentialDecay'
        ]
decay_main_program = fluid.framework.Program()
decay_startup_program = fluid.framework.Program()
lr_name = ""
if isinstance(lr_sheduler, ExponentialDecay):
with fluid.program_guard(decay_main_program, decay_startup_program):
lr = exponential_decay(1.0, lr_decay_steps, lr_sheduler.gamma,
True)
lr_name = lr.name
logging.warn(
"ExponentialDecay is set, staircase = True, global learning rate decay step is [ %d ], Change decay steps as follow: \n"
"\t strategy = paddle.distributed.fleet.DistributedStrategy() \n "
"\t strategy.a_sync = True \n"
"\t strategy.a_sync_configs= { 'lr_decay_steps' : YOUR_DECAY_STEP } \n"
% lr_decay_steps)
elif isinstance(lr_sheduler, NoamDecay):
with fluid.program_guard(decay_main_program, decay_startup_program):
lr = noam_decay(lr_sheduler.d_model, lr_sheduler.warmup_steps,
1.0)
lr_name = lr.name
logging.warn("NoamDecay is set, warmup steps is [ %d ]" %
lr_sheduler.warmup_steps)
elif isinstance(lr_sheduler, NaturalExpDecay):
with fluid.program_guard(decay_main_program, decay_startup_program):
lr = natural_exp_decay(1.0, lr_decay_steps, lr_sheduler.gamma,
True)
lr_name = lr.name
logging.warn(
"NaturalExpDecay is set, staircase = True, global learning rate decay step is [ %d ], Change decay steps as follow: \n"
"\t strategy = paddle.distributed.fleet.DistributedStrategy() \n "
"\t strategy.a_sync = True \n"
"\t strategy.a_sync_configs= { 'lr_decay_steps' : YOUR_DECAY_STEP } \n"
% lr_decay_steps)
elif isinstance(lr_sheduler, InverseTimeDecay):
with fluid.program_guard(decay_main_program, decay_startup_program):
lr = inverse_time_decay(1.0, lr_decay_steps, lr_sheduler.gamma,
True)
lr_name = lr.name
logging.warn(
"InverseTimeDecay is set, staircase = True, global learning rate decay step is [ %d ], Change decay steps as follow: \n"
"\t strategy = paddle.distributed.fleet.DistributedStrategy() \n "
"\t strategy.a_sync = True \n"
"\t strategy.a_sync_configs= { 'lr_decay_steps' : YOUR_DECAY_STEP } \n"
% lr_decay_steps)
        else:
            raise ValueError(
                "Unsupported LearningRate strategy, please use one of the following decay strategies: {}".
                format(scheduler_decay))
        return decay_main_program, decay_startup_program, lr_name
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
attrs = pass_ctx._attrs
        if not hasattr(attrs['origin_main_program'], 'lr_sheduler'):
return
assert isinstance(attrs['origin_main_program'].lr_sheduler,
LRScheduler), "must be LRScheduler"
        ops = get_optimize_ops(attrs['origin_main_program'])
        lr_decay_main_program, lr_decay_startup_program, lr_name = self._get_lr_sheduler_program(
            attrs['origin_main_program'].lr_sheduler, attrs['lr_decay_steps'])
        self._add_tensor_table(attrs, "@LR_DECAY_COUNTER@", lr_name,
                               lr_decay_startup_program,
                               lr_decay_main_program, "GlobalStepTable")
return
@register_pass("add_listen_and_serv_pass")
class AddListenAndServPass(PassBase):
def __init__(self):
super(AddListenAndServPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
attrs = pass_ctx._attrs
opt = {
"grad_to_block_id": None,
"sparse_grad_to_param": None,
"lr_decay_block_id": None,
"dense_optimize_blocks": None,
"sparse_optimize_blocks": None,
# runtime attribute
"endpoint": get_ps_endpoint(attrs['role_maker']),
"pserver_id": get_role_id(attrs['role_maker']),
"Fanin": get_trainers(attrs['role_maker']),
"distributed_mode": attrs['ps_mode'],
"rpc_get_thread_num": -1,
"rpc_send_thread_num": -1,
"rpc_prefetch_thread_num": -1
}
main_program.global_block().append_op(
type="listen_and_serv", inputs={'X': []}, outputs={}, attrs=opt)
attrs['cloned_main'] = main_program
@register_pass("add_rpc_global_flags_pass")
class AddRpcGlobalFlagsPass(PassBase):
def __init__(self):
super(AddRpcGlobalFlagsPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
pass
@register_pass("add_optimizer_pass")
class AddOptimizerPass(PassBase):
def __init__(self):
super(AddOptimizerPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
pass
@register_pass("add_geo_optimizer_pass")
class AddGeoOptimizerPass(PassBase):
def __init__(self):
super(AddGeoOptimizerPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
pass
@register_pass("build_pserver_startup_program_pass")
class BuildPserverStartupProgramPass(PassBase):
def __init__(self):
super(BuildPserverStartupProgramPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
pass
@register_pass("delete_unused_in_startup_pass")
class DeleteUnusedInStartupPass(PassBase):
def __init__(self):
super(DeleteUnusedInStartupPass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
pass
(This diff is collapsed.)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......
@@ -11,3 +11,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from .ps_program_builder import *
from .public import *
__all__ = [
'PsProgramBuilder', 'GeoPsProgramBuilder', 'CpuSyncPsProgramBuilder',
'CpuAsyncPsProgramBuilder', 'GpuPsProgramBuilder',
'HeterAsyncPsProgramBuilder', 'FlPsProgramBuilder'
]
class PsProgramBuilderFactory(object):
def __init__(self):
pass
def _create_ps_program_builder(self, pass_ctx):
attrs = pass_ctx._attrs
if attrs['ps_mode'] == DistributedMode.GEO:
return globals()['GeoPsProgramBuilder'](pass_ctx)
elif attrs['use_ps_gpu']:
return globals()['GpuPsProgramBuilder'](pass_ctx)
elif attrs['is_heter_ps_mode']:
return globals()['HeterAsyncPsProgramBuilder'](pass_ctx)
elif 'is_fl_ps_mode' in attrs and attrs[
'is_fl_ps_mode'] == DistributedMode.FL:
return globals()['FlPsProgramBuilder'](pass_ctx)
else:
return globals()['CpuSyncPsProgramBuilder'](pass_ctx)
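# Typical call flow, mirroring ParameterServerOptimizer.minimize_impl in ps_optimizer.py:
#     ps_builder = PsProgramBuilderFactory()._create_ps_program_builder(pass_ctx)
#     ps_builder._build_programs()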
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from .public import *
from paddle.distributed.fleet.base.private_helper_function import wait_server_ready
from paddle.distributed.passes import new_pass, PassContext
class PsProgramBuilder(object):
def __init__(self, pass_ctx):
self.pass_ctx = pass_ctx
self.attrs = self.pass_ctx._attrs
self.loss = self.attrs['loss']
self.cloned_main = self.attrs['cloned_main']
self.cloned_startup = self.attrs['cloned_startup']
self.use_ps_gpu = self.attrs['use_ps_gpu']
self.use_heter_ps = self.attrs['is_heter_ps_mode']
self.is_worker = self.attrs['is_worker']
self.is_heter_worker = self.attrs['is_heter_worker']
self.ps_mode = self.attrs['ps_mode']
self.launch_barrier = self.attrs['launch_barrier']
self.launch_barrier_flag = self.attrs['launch_barrier_flag']
self.server_endpoints = self.attrs['role_maker']._get_pserver_endpoints(
)
def _optimize_programs(self):
pass
def _build_trainer_programs(self):
pass
def _build_pserver_programs(self):
is_sgd_adam = False
ops = get_optimize_ops(self.attrs['origin_main_program'])
if len(ops) == 0:
return
add_lr_decay_table_pass = new_pass('add_lr_decay_table_pass',
self.attrs)
add_lr_decay_table_pass.apply([], [], self.pass_ctx)
for op in ops:
if op.type in ["sgd", "adam"]:
is_sgd_adam = True
break
if is_sgd_adam:
return
def _build_programs(self):
if self.attrs['is_worker']:
self._build_trainer_programs()
fluid.framework.switch_startup_program(self.cloned_startup)
self.loss.block.program = self.cloned_main
elif self.attrs['is_server']:
self._build_pserver_programs()
self.loss.block.program = self.attrs['_main_server']
fluid.framework.switch_startup_program(self.attrs[
'_startup_server'])
class GeoPsProgramBuilder(PsProgramBuilder):  # CPU mode only
def __init__(self, pass_ctx):
super(GeoPsProgramBuilder, self).__init__(pass_ctx)
        if self.ps_mode != DistributedMode.GEO:
            raise ValueError("ps mode: {} not matched {}".format(
                self.ps_mode, "GeoPsProgramBuilder"))
def _build_trainer_programs(self):
append_send_ops_pass = new_pass("append_send_ops_pass", self.attrs)
append_send_ops_pass.apply([self.cloned_main], [None], self.pass_ctx)
self.attrs['origin_main_program'] = self.cloned_main
if self.launch_barrier and self.launch_barrier_flag:
            wait_server_ready(self.server_endpoints)
return
class CpuSyncPsProgramBuilder(PsProgramBuilder):
def __init__(self, pass_ctx):
super(CpuSyncPsProgramBuilder, self).__init__(pass_ctx)
        if self.ps_mode == DistributedMode.GEO:
            raise ValueError("ps mode: {} not matched {}".format(
                self.ps_mode, "CpuSyncPsProgramBuilder"))
def _build_trainer_programs(self):
add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass",
self.attrs)
add_lr_decay_table_pass.apply([], [], self.pass_ctx)
distributed_ops_pass = new_pass("distributed_ops_pass", self.attrs)
distributed_ops_pass.apply([self.cloned_main], [None], self.pass_ctx)
delete_optimizer_pass = new_pass("delete_optimizer_pass", self.attrs)
delete_optimizer_pass.apply([self.cloned_main], [None], self.pass_ctx)
append_send_ops_pass = new_pass("append_send_ops_pass", self.attrs)
append_send_ops_pass.apply([self.cloned_main], [None], self.pass_ctx)
delete_extra_optimizer_pass = new_pass("delete_extra_optimizer_pass",
self.attrs)
delete_extra_optimizer_pass.apply([self.attrs['origin_main_program']],
[self.cloned_startup], self.pass_ctx)
fake_init_ops_pass = new_pass("fake_init_ops_pass", self.attrs)
fake_init_ops_pass.apply([None], [self.cloned_startup], self.pass_ctx)
self.attrs['origin_main_program'] = self.cloned_main
self.attrs['origin_startup_program'] = self.cloned_startup
if self.launch_barrier and self.launch_barrier_flag:
            wait_server_ready(self.server_endpoints)
return
class CpuAsyncPsProgramBuilder(CpuSyncPsProgramBuilder):
def __init__(self, pass_ctx):
super(CpuAsyncPsProgramBuilder, self).__init__(pass_ctx)
class GpuPsProgramBuilder(PsProgramBuilder):  # independent of the geo/sync/async modes
def __init__(self, pass_ctx):
super(GpuPsProgramBuilder, self).__init__(pass_ctx)
def _build_trainer_programs(self):
add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass",
self.attrs)
add_lr_decay_table_pass.apply([], [], self.pass_ctx)
distributed_ops_pass = new_pass("distributed_ops_pass", self.attrs)
distributed_ops_pass.apply([self.cloned_main], [None], self.pass_ctx)
fake_init_ops_pass = new_pass("fake_init_ops_pass", self.attrs)
fake_init_ops_pass.apply([None], [self.cloned_startup], self.pass_ctx)
ps_gpu_pass = new_pass("ps_gpu_pass", self.attrs)
ps_gpu_pass.apply([self.cloned_main], [None], self.pass_ctx)
        ps_transpile_pass = new_pass("ps_transpile_pass", self.attrs)
        # assuming the cloned trainer programs are the intended transpile targets here
        ps_transpile_pass.apply([self.cloned_main], [self.cloned_startup],
                                self.pass_ctx)
self.attrs['origin_main_program'] = self.cloned_main
self.attrs['origin_startup_program'] = self.cloned_startup
if self.launch_barrier and self.launch_barrier_flag:
            wait_server_ready(self.server_endpoints)
return
class HeterAsyncPsProgramBuilder(PsProgramBuilder):
def __init__(self, pass_ctx):
super(HeterAsyncPsProgramBuilder, self).__init__(pass_ctx)
        if self.use_ps_gpu or self.ps_mode == DistributedMode.GEO or not self.attrs[
                'is_heter_ps_mode']:
            raise ValueError("ps mode: {} not matched {}".format(
                self.ps_mode, "HeterAsyncPsProgramBuilder"))
def _build_trainer_programs(self):
add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass",
self.attrs)
add_lr_decay_table_pass.apply([], [], self.pass_ctx)
distributed_ops_pass = new_pass("distributed_ops_pass", self.attrs)
distributed_ops_pass.apply([self.cloned_main], [], self.pass_ctx)
delete_optimizer_pass = new_pass("delete_optimizer_pass", self.attrs)
        delete_optimizer_pass.apply([None], [self.cloned_startup],
                                    self.pass_ctx)
append_send_ops_pass = new_pass("append_send_ops_pass", self.attrs)
append_send_ops_pass.apply([self.cloned_main], [None], self.pass_ctx)
delete_extra_optimizer_pass = new_pass("delete_extra_optimizer_pass",
self.attrs)
delete_extra_optimizer_pass.apply([self.attrs['origin_main_program']],
[self.cloned_startup], self.pass_ctx)
fake_init_ops_pass = new_pass("fake_init_ops_pass", self.attrs)
fake_init_ops_pass.apply([None], [self.cloned_startup], self.pass_ctx)
if self.is_heter_worker:
split_heter_worker_ops_pass = new_pass(
"split_heter_worker_ops_pass", self.attrs)
split_heter_worker_ops_pass.apply([self.cloned_main], [None],
self.pass_ctx)
else:
split_trainer_ops_pass = new_pass("split_trainer_ops_pass",
self.attrs)
            split_trainer_ops_pass.apply([self.cloned_main], [],
                                         self.pass_ctx)
set_heter_pipeline_opt_pass = new_pass('set_heter_pipeline_opt_pass',
self.attrs)
        set_heter_pipeline_opt_pass.apply([self.cloned_main],
                                          [self.cloned_startup],
                                          self.pass_ctx)
if self.launch_barrier and self.launch_barrier_flag:
            wait_server_ready(self.server_endpoints)
return
def _build_programs(self):
if self.attrs['is_worker'] or self.attrs['is_heter_worker']:
self._build_trainer_programs()
ps_set_heter_pipeline_opt_pass = new_pass(
"set_heter_pipeline_opt_pass", self.attrs)
            ps_set_heter_pipeline_opt_pass.apply(
                [self.loss.block.program], [self.cloned_startup],
                self.pass_ctx)
elif self.attrs['is_server']:
self._build_pserver_programs()
self.loss.block.program = self.attrs['_main_server']
fluid.framework.switch_startup_program(self.attrs[
'_startup_server'])
class FlPsProgramBuilder(PsProgramBuilder):
def __init__(self, pass_ctx):
super(FlPsProgramBuilder, self).__init__(pass_ctx)
def _build_trainer_programs(self):
pass
def _build_pserver_programs(self):
pass
def _build_programs(self):
pass
(This diff is collapsed.)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import sys
import pickle
import shlex
import shutil
import inspect
import unittest
import numpy as np
from collections import OrderedDict
from dist_pass_test_base import prepare_python_path_and_return_module, remove_path_if_exists
import paddle.distributed.fleet as fleet
class PsPassTestBase(unittest.TestCase):
def init(self):
raise NotImplementedError
def setUp(self):
print('Ps setUp...')
def tearDown(self):
print('Ps tearDown...')
def ps_launch(self, config):
cmd = [
sys.executable,
"-u",
] + [
"-m", "launch", "--log_dir", config['log_dir'], "--worker_num",
config['worker_num'], "--server_num", config['server_num'],
"../ps/ps_dnn_trainer.py", "-m", config['ps_mode_config'],
"--run_minimize", config['run_minimize'], "--run_single_pass",
config['run_single_pass'], "--debug_new_pass",
config['debug_new_pass'], "--debug_new_minimize",
config['debug_new_minimize'], "--applied_pass_name",
config['applied_pass_name']
]
cmd = [shlex.quote(c) for c in cmd]
prepare_python_path_and_return_module(__file__)
exitcode = os.system(' '.join(cmd))
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import os
import unittest
import numpy as np
import paddle
from ps_pass_test_base import *
from paddle.fluid.tests.unittests.ps.ps_dnn_trainer import DnnTrainer
class TestPsServerPass(PsPassTestBase):
def init(self):
pass
def setUp(self):
print('TestPsServerPass setUp...')
def tearDown(self):
print('TestPsServerPass tearDown...')
def test_add_lr_decay_table_passs(self):
pass
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import os
import unittest
import numpy as np
import paddle
from ps_pass_test_base import *
from paddle.fluid.tests.unittests.ps.ps_dnn_trainer import DnnTrainer
class TestPsTrainerPass(PsPassTestBase):
def init(self):
self.config = {}
self.config['ps_mode_config'] = "../ps/cpu_async_ps_config.yaml"
self.config['worker_num'] = "1"
self.config['server_num'] = "1"
self.config['run_minimize'] = "0"
self.config['run_single_pass'] = "0"
self.config['debug_new_minimize'] = "0"
self.config['debug_new_pass'] = "0"
self.config['log_dir'] = ""
self.config['applied_pass_name'] = ""
def setUp(self):
print('TestPsTrainerPass setUp...')
def tearDown(self):
print('TestPsTrainerPass tearDown...')
def check(self):
pass
def test_ps_optimizer_minimize(self):
self.init()
self.config['run_minimize'] = '1'
self.config['debug_new_minimize'] = '0'
self.config['log_dir'] = "/log_old_minimize"
remove_path_if_exists(self.config['log_dir'])
self.ps_launch(self.config)
self.config['debug_new_minimize'] = '1'
self.config['log_dir'] = "/log_new_minimize"
remove_path_if_exists(self.config['log_dir'])
self.ps_launch(self.config)
self.check()
def test_append_send_ops_pass(self):
self.init()
self.config['run_single_pass'] = '1'
self.config['applied_pass_name'] = "append_send_ops_pass"
self.config['debug_new_pass'] = '0'
self.config['log_dir'] = "/log_old_" + self.config['applied_pass_name']
remove_path_if_exists(self.config['log_dir'])
self.ps_launch(self.config)
self.config['debug_new_pass'] = '1'
self.config['log_dir'] = "/log_new_" + self.config['applied_pass_name']
remove_path_if_exists(self.config['log_dir'])
self.ps_launch(self.config)
self.check()
def test_distributed_ops_pass(self):
pass
if __name__ == '__main__':
unittest.main()
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
set_tests_properties(test_the_one_ps PROPERTIES TIMEOUT 50)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: On Windows, when importing from subdirectories such as dirA() -> dirB(), the
# current directory will still be dirA(), but it should be dirB(), so a
# ModuleNotFoundError will be raised.
# please refer to https://stackoverflow.com/questions/8953844/import-module-from-subfolder
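# A minimal sketch of the workaround referenced above (hypothetical, not executed by
# this package): extend sys.path explicitly so that sibling test modules resolve on
# Windows as well, e.g.
#
#     import os, sys
#     sys.path.append(os.path.dirname(os.path.abspath(__file__)))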
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
hyper_parameters:
optimizer:
class: Adam
learning_rate: 0.0001
adam_lazy_mode: True
sparse_inputs_slots: 27
sparse_feature_number: 1000001
sparse_feature_dim: 10
dense_input_dim: 13
fc_sizes: [400, 400, 400]
runner:
geo_step: 400
sync_mode: "async" # sync / async / geo / heter
thread_num: 16
use_gpu: 0
model_path: "../ps_dnn_model.py"
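# Note (see get_user_defined_strategy in ps_dnn_trainer.py): sync_mode "async" sets
# strategy.a_sync = True; "geo" additionally forwards runner.geo_step to
# a_sync_configs["k_steps"]; "gpubox" sets a_sync_configs["use_ps_gpu"] = 1.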
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.ps.utils.ps_program_builder import *
import paddle.distributed.fleet as fleet
import argparse
import time
import sys
import yaml, six, copy
import paddle
import os
import warnings
import logging
import ast
import numpy as np
import struct
sys.path.append("..")
from ps_dnn_model import StaticModel
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
def is_distributed_env():
node_role = os.getenv("TRAINING_ROLE")
logger.info("-- Role: {} --".format(node_role))
if node_role is None:
return False
else:
return True
class YamlHelper(object):
def load_yaml(self, yaml_file, other_part=None):
part_list = ["runner", "hyper_parameters"]
if other_part:
part_list += other_part
running_config = self.get_all_inters_from_yaml(yaml_file, part_list)
running_config = self.workspace_adapter(running_config)
return running_config
def print_yaml(self, config):
print(self.pretty_print_envs(config))
def parse_yaml(self, config):
vs = [int(i) for i in yaml.__version__.split(".")]
if vs[0] < 5:
use_full_loader = False
elif vs[0] > 5:
use_full_loader = True
else:
if vs[1] >= 1:
use_full_loader = True
else:
use_full_loader = False
if os.path.isfile(config):
if six.PY2:
with open(config, 'r') as rb:
if use_full_loader:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
_config = yaml.load(rb.read())
return _config
else:
with open(config, 'r', encoding="utf-8") as rb:
if use_full_loader:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
_config = yaml.load(rb.read())
return _config
else:
raise ValueError("config {} can not be supported".format(config))
def get_all_inters_from_yaml(self, file, filters):
_envs = self.parse_yaml(file)
all_flattens = {}
def fatten_env_namespace(namespace_nests, local_envs):
for k, v in local_envs.items():
if isinstance(v, dict):
nests = copy.deepcopy(namespace_nests)
nests.append(k)
fatten_env_namespace(nests, v)
else:
global_k = ".".join(namespace_nests + [k])
all_flattens[global_k] = v
fatten_env_namespace([], _envs)
ret = {}
for k, v in all_flattens.items():
for f in filters:
if k.startswith(f):
ret[k] = v
return ret
def workspace_adapter(self, config):
workspace = config.get("workspace")
for k, v in config.items():
if isinstance(v, str) and "{workspace}" in v:
config[k] = v.replace("{workspace}", workspace)
return config
def pretty_print_envs(self, envs, header=None):
spacing = 2
max_k = 40
max_v = 45
for k, v in envs.items():
max_k = max(max_k, len(k))
h_format = " " + "|{{:>{}s}}{}{{:^{}s}}|\n".format(max_k, " " *
spacing, max_v)
l_format = " " + "|{{:>{}s}}{{}}{{:^{}s}}|\n".format(max_k, max_v)
length = max_k + max_v + spacing
border = " +" + "".join(["="] * length) + "+"
line = " +" + "".join(["-"] * length) + "+"
draws = ""
draws += border + "\n"
if header:
draws += h_format.format(header[0], header[1])
else:
draws += h_format.format("PaddleRec Benchmark Envs", "Value")
draws += line + "\n"
for k, v in sorted(envs.items()):
if isinstance(v, str) and len(v) >= max_v:
str_v = "... " + v[-41:]
else:
str_v = v
draws += l_format.format(k, " " * spacing, str(str_v))
draws += border
_str = "\n{}\n".format(draws)
return _str
def get_user_defined_strategy(config):
if not is_distributed_env():
        logger.warn(
            "Distributed env not found, falling back to local train mode. If you want to train with fleet, please use the [fleetrun] command."
        )
return None
sync_mode = config.get("runner.sync_mode")
assert sync_mode in ["async", "sync", "geo", "heter", "gpubox"]
if sync_mode == "sync":
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = False
elif sync_mode == "async":
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True
elif sync_mode == "geo":
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True
strategy.a_sync_configs = {"k_steps": config.get("runner.geo_step")}
elif sync_mode == "heter":
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True
strategy.a_sync_configs = {"heter_worker_device_guard": "gpu"}
elif sync_mode == "gpubox":
print("sync_mode = {}".format(sync_mode))
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True
strategy.a_sync_configs = {"use_ps_gpu": 1}
strategy.trainer_desc_configs = {
"dump_fields_path": config.get("runner.dump_fields_path", ""),
"dump_fields": config.get("runner.dump_fields", []),
"dump_param": config.get("runner.dump_param", []),
"stat_var_names": config.get("stat_var_names", [])
}
print("strategy:", strategy.trainer_desc_configs)
if config.get("runner.fs_client.uri") is not None:
strategy.fs_client_param = {
"uri": config.get("runner.fs_client.uri", ""),
"user": config.get("runner.fs_client.user", ""),
"passwd": config.get("runner.fs_client.passwd", ""),
"hadoop_bin": config.get("runner.fs_client.hadoop_bin", "hadoop")
}
print("strategy:", strategy.fs_client_param)
strategy.adam_d2sum = config.get("hyper_parameters.adam_d2sum", True)
table_config = {}
for x in config:
if x.startswith("table_parameters"):
table_name = x.split('.')[1]
if table_name not in table_config:
table_config[table_name] = {}
table_config[table_name][x] = config[x]
print("table_config:", table_config)
strategy.sparse_table_configs = table_config
print("strategy table config:", strategy.sparse_table_configs)
a_sync_configs = strategy.a_sync_configs
a_sync_configs["launch_barrier"] = False
strategy.a_sync_configs = a_sync_configs
print("launch_barrier: ", strategy.a_sync_configs["launch_barrier"])
return strategy
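# get_distributed_strategy maps the (a_sync, k_steps) pair onto the legacy
# StrategyFactory strategies: sync when a_sync is False and k_steps == 0, async when
# a_sync is True and k_steps == 0, and geo(k_steps) when a_sync is True and k_steps > 0.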
def get_distributed_strategy(user_defined_strategy):
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory
k_steps = user_defined_strategy.a_sync_configs["k_steps"]
strategy = None
if not user_defined_strategy.a_sync and k_steps == 0:
strategy = StrategyFactory.create_sync_strategy()
if user_defined_strategy.a_sync and k_steps == 0:
strategy = StrategyFactory.create_async_strategy()
if user_defined_strategy.a_sync and k_steps > 0:
strategy = StrategyFactory.create_geo_strategy(k_steps)
if not strategy:
raise ValueError("k_steps must be invalid value, please check")
return strategy
def get_model(config):
abs_dir = config['config_abs_dir']
sys.path.append(abs_dir)
static_model = StaticModel(config)
return static_model
def parse_args():
parser = argparse.ArgumentParser("PsTest train script")
parser.add_argument(
'-m', '--config_yaml', type=str, required=True, help='config file path')
parser.add_argument(
'-bf16',
'--pure_bf16',
type=ast.literal_eval,
default=False,
help="whether use bf16")
parser.add_argument(
'--run_minimize', type=int, default=0, help="test single pass")
parser.add_argument(
'--run_single_pass', type=int, default=0, help="test single pass")
parser.add_argument(
'--debug_new_minimize', type=int, default=0, help="test single pass")
parser.add_argument(
'--debug_new_pass', type=int, default=0, help="test single pass")
parser.add_argument(
'--applied_pass_name', type=str, default="", help="test single pass")
args = parser.parse_args()
args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
yaml_helper = YamlHelper()
config = yaml_helper.load_yaml(args.config_yaml)
config["yaml_path"] = args.config_yaml
config["config_abs_dir"] = args.abs_dir
config["pure_bf16"] = args.pure_bf16
config['run_minimize'] = args.run_minimize
config['run_single_pass'] = args.run_single_pass
config['debug_new_minimize'] = args.debug_new_minimize
config['debug_new_pass'] = args.debug_new_pass
config['applied_pass_name'] = args.applied_pass_name
yaml_helper.print_yaml(config)
return config
def bf16_to_fp32(val):
return np.float32(struct.unpack('<f', struct.pack('<I', val << 16))[0])
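# bf16_to_fp32 reinterprets a bfloat16 value held in the low 16 bits of an integer as a
# float32 by shifting it into the high half of the 32-bit pattern, e.g.
# bf16_to_fp32(0x3F80) == 1.0 because 0x3F80 << 16 == 0x3F800000, the IEEE-754 float32
# encoding of 1.0.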
class DnnTrainer(object):
def __init__(self, config):
self.metrics = {}
self.config = config
self.input_data = None
self.reader = None
self.exe = None
self.train_result_dict = {}
self.train_result_dict["speed"] = []
self.model = None
self.pure_bf16 = self.config['pure_bf16']
self.role_maker = role_maker.PaddleCloudRoleMaker()
def init_fleet_with_gloo(self, use_gloo=False):
if use_gloo:
os.environ["PADDLE_WITH_GLOO"] = "1"
fleet.init(self.role_maker)
else:
fleet.init()
if fleet.is_server():
logger.info("server: {} started".format(fleet.server_index()))
else:
logger.info("worker: {} started".format(fleet.worker_index()))
def run_minimize(self):
logger.info("entering run_minimize")
self.init_fleet_with_gloo()
self.model = get_model(self.config)
logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
self.input_data = self.model.create_feeds()
self.metrics = self.model.net(self.input_data)
loss = self.model._cost
user_defined_strategy = get_user_defined_strategy(self.config)
learning_rate = self.config.get(
"hyper_parameters.optimizer.learning_rate")
inner_optimizer = paddle.optimizer.Adam(learning_rate, lazy_mode=True)
if self.config['debug_new_minimize'] == 1:
from paddle.distributed.fleet.meta_optimizers.ps_optimizer import ParameterServerOptimizer
ps_optimizer = ParameterServerOptimizer(inner_optimizer)
ps_optimizer._set_basic_info(loss, self.role_maker, inner_optimizer,
user_defined_strategy)
ps_optimizer.minimize_impl(loss)
else:
            fleet_obj = fleet.distributed_optimizer(
                inner_optimizer, user_defined_strategy)  # the Fleet object
fleet_obj.minimize(loss)
if fleet.is_server():
_main_file = '/' + 'run_minimize' + '_debug_minimize:_' + str(
self.config['debug_new_minimize']) + '_server_main.prototxt'
debug_program(_main_file, loss.block.program, 0)
elif fleet.is_worker():
_main_file = '/' + 'run_minimize' + '_debug_minimize:_' + str(
self.config['debug_new_minimize']) + '_worker_main.prototxt'
debug_program(_main_file, loss.block.program, 1)
'''
if fleet.is_server():
logger.info("Run Server Begin")
fleet.init_server()
fleet.run_server()
'''
def run_single_pass(self):
logger.info("entering run_single_pass")
self.init_fleet_with_gloo()
self.model = get_model(config)
input_data = self.model.create_feeds()
metrics = self.model.net(input_data)
loss = self.model._cost
user_defined_strategy = get_user_defined_strategy(config)
learning_rate = config.get("hyper_parameters.optimizer.learning_rate")
inner_optimizer = paddle.optimizer.Adam(learning_rate, lazy_mode=True)
startup_program = paddle.static.default_startup_program()
inner_optimizer.minimize(loss, startup_program)
if self.config['debug_new_pass'] == 1:
from paddle.distributed.fleet.meta_optimizers.ps_optimizer import ParameterServerOptimizer
ps_optimizer = ParameterServerOptimizer(inner_optimizer)
ps_optimizer._set_basic_info(loss, self.role_maker, inner_optimizer,
user_defined_strategy)
ps_optimizer._init_ps_pass_context(loss, startup_program)
_main = ps_optimizer.attrs['cloned_main']
append_send_ops_pass = new_pass(config["applied_pass_name"],
ps_optimizer.attrs)
append_send_ops_pass.apply([_main], [None], ps_optimizer.pass_ctx)
else:
from paddle.fluid.incubate.fleet.parameter_server.ir import public as public
dist_strategy = get_distributed_strategy(user_defined_strategy)
compiled_config = public.CompileTimeStrategy(
loss.block.program, startup_program, dist_strategy,
self.role_maker)
_main = compiled_config.origin_main_program.clone()
_startup = compiled_config.origin_startup_program.clone()
from paddle.fluid.incubate.fleet.parameter_server.ir import trainer_pass as worker
_main = worker.append_send_ops_pass(_main, compiled_config)
if fleet.is_server():
_main_file = '/' + str(config[
"applied_pass_name"]) + '_debug_pass:_' + str(self.config[
'debug_new_pass']) + '_server_main.prototxt'
debug_program(_main_file, _main, 0)
elif fleet.is_worker():
_main_file = '/' + str(config[
"applied_pass_name"]) + '_debug_pass:_' + str(self.config[
'debug_new_pass']) + '_worker_main.prototxt'
debug_program(_main_file, _main, 1)
if __name__ == "__main__":
paddle.enable_static()
config = parse_args()
os.environ["CPU_NUM"] = str(config.get("runner.thread_num"))
benchmark_main = DnnTrainer(config)
if config['run_single_pass'] == 1:
benchmark_main.run_single_pass()
elif config['run_minimize'] == 1:
benchmark_main.run_minimize()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import os
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
class TestTheOnePs(unittest.TestCase):
def setUp(self):
print('setUp...')
def tearDown(self):
print('tearDown...')
def test_main(self):
pass
if __name__ == '__main__':
unittest.main()
(This diff is collapsed.)
@@ -270,6 +270,8 @@ packages=['paddle',
'paddle.reader',
'paddle.distributed',
'paddle.distributed.metric',
'paddle.distributed.ps',
'paddle.distributed.ps.utils',
'paddle.incubate',
'paddle.incubate.optimizer',
'paddle.incubate.checkpoint',
......