未验证 提交 22c67d14 编写于 作者: Z ziyoujiyi 提交者: GitHub

统一 ps 开发 - python (#39431)

* delete gloo connect retry

* the_one_ps dirs reconstruct

* .

* .

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* the one ps dirs modify

* the one ps dirs modify

* the one ps dirs modify

* the one ps dirs modify

* refactor ps optimize

* refactor ps optimize

* refactor ps optimize

* .

* .

* .

* .

* .

* .

* refactor theoneps

* the_one_ps

* add ps pass unittest

* add ps pass unittest

* ps unitest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* add cpu_async_ps_mode test

* add cpu_async_ps_mode test

* add cpu_async_ps_mode test

* ps unittest ready

* ps unittest ready

* solve dist_pass init conflict

* solve import CommContext error

* unittest ok

* implement AllocateFrom

* solve setup.py.in conflict

* solve conflict

* solve conflict

* solve conflict

* .

* .

* cpu-async-ps minimize test ok & gpu minimize test ok
Co-authored-by: Nzkh2016 <zhangkaihuo@baidu.com>
上级 1c44d3e2
......@@ -31,7 +31,6 @@ class ParameterServerOptimizer(MetaOptimizerBase):
self.inner_opt = optimizer
# we do not allow meta optimizer to be inner optimizer currently
self.meta_optimizers_white_list = []
self.attrs = {}
self.pass_ctx = PassContext()
def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
......@@ -40,50 +39,48 @@ class ParameterServerOptimizer(MetaOptimizerBase):
loss, role_maker, user_defined_optimizer, user_defined_strategy)
def _init_ps_pass_context(self, loss, startup_program):
attrs = {}
# trainer
self.attrs["env"] = get_dist_env()
attrs["env"] = get_dist_env()
self.attrs['loss'] = loss
self.attrs['min_block_size'] = 81920
self.attrs['origin_main_program'] = loss.block.program
self.attrs['origin_startup_program'] = startup_program
attrs['loss'] = loss
attrs['min_block_size'] = 81920
attrs['origin_main_program'] = loss.block.program
attrs['origin_startup_program'] = startup_program
self.attrs['cloned_main'] = loss.block.program.clone()
self.attrs['cloned_startup'] = startup_program.clone()
attrs['cloned_main'] = attrs['origin_main_program'].clone()
attrs['cloned_startup'] = attrs['origin_startup_program'].clone()
self.attrs['user_defined_strategy'] = self.user_defined_strategy
self.attrs['trainer'] = TrainerRuntimeConfig(self.user_defined_strategy)
self.attrs['ps_mode'] = self.attrs['trainer'].mode
attrs['user_defined_strategy'] = self.user_defined_strategy
attrs['trainer'] = TrainerRuntimeConfig(self.user_defined_strategy)
attrs['ps_mode'] = attrs['trainer'].mode
self.attrs['role_maker'] = self.role_maker
self.attrs[
attrs['role_maker'] = self.role_maker
attrs[
'is_heter_ps_mode'] = self.role_maker._is_heter_parameter_server_mode
self.attrs['is_worker'] = self.role_maker._is_worker()
self.attrs['is_server'] = self.role_maker._is_server()
self.attrs['is_heter_worker'] = self.role_maker._is_heter_worker()
attrs['is_worker'] = self.role_maker._is_worker()
attrs['is_server'] = self.role_maker._is_server()
attrs['is_heter_worker'] = self.role_maker._is_heter_worker()
self.attrs['use_ps_gpu'] = self.user_defined_strategy.a_sync_configs[
attrs['use_ps_gpu'] = self.user_defined_strategy.a_sync_configs[
"use_ps_gpu"]
self.attrs[
'lr_decay_steps'] = self.user_defined_strategy.a_sync_configs[
"lr_decay_steps"]
self.attrs['k_steps'] = self.user_defined_strategy.a_sync_configs[
"k_steps"]
self.attrs[
'launch_barrier'] = self.user_defined_strategy.a_sync_configs[
"launch_barrier"]
self.attrs['launch_barrier_flag'] = int(
attrs['lr_decay_steps'] = self.user_defined_strategy.a_sync_configs[
"lr_decay_steps"]
attrs['k_steps'] = self.user_defined_strategy.a_sync_configs["k_steps"]
attrs['launch_barrier'] = self.user_defined_strategy.a_sync_configs[
"launch_barrier"]
attrs['launch_barrier_flag'] = int(
os.getenv("FLAGS_LAUNCH_BARRIER", "1"))
build_var_distributed(self.attrs)
build_var_distributed(attrs)
# server
self.attrs['_main_server'] = fluid.Program()
self.attrs['_startup_server'] = fluid.Program()
self.attrs['tensor_table'] = {}
attrs['_main_server'] = fluid.Program()
attrs['_startup_server'] = fluid.Program()
attrs['tensor_table'] = {}
self.pass_ctx._attrs = self.attrs
self.pass_ctx._attrs = attrs
def _is_graph_out(self):
return False
......
......@@ -115,11 +115,11 @@ class AddLrDecayTablePass(PassBase):
LRScheduler), "must be LRScheduler"
ops = get_optimize_ops(attrs['origin_main_program'])
lr_decay_main_program, lr_decay_startup_program, lr_name = _get_lr_sheduler_program(
lr_decay_main_program, lr_decay_startup_program, lr_name = self._get_lr_sheduler_program(
attrs['origin_main_program'].lr_sheduler, attrs['lr_decay_steps'])
_add_tensor_table(attrs, "@LR_DECAY_COUNTER@", lr_name,
lr_decay_startup_program, lr_decay_main_program,
"GlobalStepTable")
self._add_tensor_table(attrs, "@LR_DECAY_COUNTER@", lr_name,
lr_decay_startup_program, lr_decay_main_program,
"GlobalStepTable")
return
......
......@@ -41,7 +41,7 @@ class PsProgramBuilder(object):
pass
def _build_trainer_programs(self):
pass
raise NotImplementedError
def _build_pserver_programs(self):
is_sgd_adam = False
......@@ -60,11 +60,13 @@ class PsProgramBuilder(object):
def _build_programs(self):
if self.attrs['is_worker']:
logger.info("start building trainer program")
self._build_trainer_programs()
fluid.framework.switch_startup_program(self.cloned_startup)
self.loss.block.program = self.cloned_main
elif self.attrs['is_server']:
logger.info("start building pserver program")
self._build_pserver_programs()
self.loss.block.program = self.attrs['_main_server']
fluid.framework.switch_startup_program(self.attrs[
......@@ -73,6 +75,7 @@ class PsProgramBuilder(object):
class GeoPsProgramBuilder(PsProgramBuilder): # 仅 CPU 模式
def __init__(self, pass_ctx):
logger.info("start building geo-ps program")
super(GeoPsProgramBuilder, self).__init__(pass_ctx)
if self.ps_mode != DistributedMode.GEO:
raise ValueError("ps mode: {} not matched {}",
......@@ -92,6 +95,7 @@ class GeoPsProgramBuilder(PsProgramBuilder): # 仅 CPU 模式
class CpuSyncPsProgramBuilder(PsProgramBuilder):
def __init__(self, pass_ctx):
logger.info("start building cpu-sync-ps program")
super(CpuSyncPsProgramBuilder, self).__init__(pass_ctx)
if self.ps_mode == DistributedMode.GEO:
raise ValueError("ps mode: {} not matched {}",
......@@ -130,14 +134,17 @@ class CpuSyncPsProgramBuilder(PsProgramBuilder):
class CpuAsyncPsProgramBuilder(CpuSyncPsProgramBuilder):
def __init__(self, pass_ctx):
logger.info("start building cpu-async-ps program")
super(CpuAsyncPsProgramBuilder, self).__init__(pass_ctx)
class GpuPsProgramBuilder(PsProgramBuilder): # 和 geo、sync、async 等模式无关
class GpuPsProgramBuilder(PsProgramBuilder):
def __init__(self, pass_ctx):
logger.info("start building gpu-ps program")
super(GpuPsProgramBuilder, self).__init__(pass_ctx)
def _build_trainer_programs(self):
add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass",
self.attrs)
add_lr_decay_table_pass.apply([], [], self.pass_ctx)
......@@ -152,7 +159,8 @@ class GpuPsProgramBuilder(PsProgramBuilder): # 和 geo、sync、async 等模式
ps_gpu_pass.apply([self.cloned_main], [None], self.pass_ctx)
ps_transpile_pass = new_pass("ps_transpile_pass", self.attrs)
ps_transpile_pass.apply([_main], [_startup], self.pass_ctx)
ps_transpile_pass.apply([self.cloned_main], [self.cloned_startup],
self.pass_ctx)
self.attrs['origin_main_program'] = self.cloned_main
self.attrs['origin_startup_program'] = self.cloned_startup
......@@ -165,6 +173,7 @@ class GpuPsProgramBuilder(PsProgramBuilder): # 和 geo、sync、async 等模式
class HeterAsyncPsProgramBuilder(PsProgramBuilder):
def __init__(self, pass_ctx):
logger.info("start building heter-async-ps program")
super(HeterAsyncPsProgramBuilder, self).__init__(pass_ctx)
if self.use_ps_gpu or self.ps_mode == DistributedMode.GEO or self.attrs[
'is_heter_ps_mode'] == False:
......
......@@ -27,6 +27,10 @@ from paddle.fluid.core import CommContext
import paddle.fluid.framework as framework
import paddle.distributed.fleet as fleet
#logging.basicConfig(
# format='%(levelname)s - %(asctime)s - %(pathname)s: %(lineno)s - %(message)s', level=logging.INFO)
#logger = logging.getLogger(__name__)
OP_NAME_SCOPE = "op_namescope"
CLIP_OP_NAME_SCOPE = "gradient_clip"
STEP_COUNTER = "@PS_STEP_COUNTER@"
......@@ -43,6 +47,24 @@ SPARSE_OP_LIST = ["lookup_table", "lookup_table_v2"]
SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}
def logger_config(log_path, logging_name):
logger = logging.getLogger(logging_name)
logger.setLevel(level=logging.DEBUG)
handler = logging.FileHandler(log_path, mode='a', encoding='UTF-8')
handler.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(levelname)s - %(asctime)s - %(pathname)s: %(lineno)s - %(message)s')
handler.setFormatter(formatter)
console = logging.StreamHandler()
console.setLevel(logging.DEBUG)
logger.addHandler(handler)
logger.addHandler(console)
return logger
logger = logger_config(log_path='/ps_log', logging_name='ps_log')
class DistributedMode:
SYNC = 0
ASYNC = 1
......
......@@ -36,20 +36,48 @@ class PsPassTestBase(unittest.TestCase):
def tearDown(self):
print('Ps tearDown...')
def ps_launch(self, config):
cmd = [
sys.executable,
"-u",
] + [
"-m", "launch", "--log_dir", config['log_dir'], "--worker_num",
config['worker_num'], "--server_num", config['server_num'],
"../ps/ps_dnn_trainer.py", "-m", config['ps_mode_config'],
"--run_minimize", config['run_minimize'], "--run_single_pass",
config['run_single_pass'], "--debug_new_pass",
config['debug_new_pass'], "--debug_new_minimize",
config['debug_new_minimize'], "--applied_pass_name",
config['applied_pass_name']
]
def ps_launch(self, config, ps_mode="cpu-ps"):
if ps_mode == "cpu-ps":
os.environ['WITH_DISTRIBUTE'] = 'ON'
cmd = [
sys.executable,
"-u",
] + [
"-m", "launch", "--log_dir", config['log_dir'], "--worker_num",
config['worker_num'], "--server_num", config['server_num'],
"../ps/ps_dnn_trainer.py", "-m", config['ps_mode_config'],
"--run_minimize", config['run_minimize'], "--run_single_pass",
config['run_single_pass'], "--debug_new_pass",
config['debug_new_pass'], "--debug_new_minimize",
config['debug_new_minimize'], "--applied_pass_name",
config['applied_pass_name']
]
elif ps_mode == "gpu-ps":
os.environ['FLAGS_LAUNCH_BARRIER'] = '0'
os.environ['PADDLE_PSERVER_NUMS'] = '1'
os.environ['PADDLE_TRAINERS_NUM'] = '1'
os.environ['POD_IP'] = '127.0.0.1'
os.environ['PADDLE_PSERVERS_IP_PORT_LIST'] = '127.0.0.1:29011'
os.environ['PADDLE_PORT'] = '29011'
os.environ['FLAGS_selected_gpus'] = '0,1,2,3,4,5,6,7'
# pserver
# os.environ['TRAINING_ROLE'] = 'PSERVER'
# trainer
os.environ['TRAINING_ROLE'] = 'TRAINER'
os.environ['PADDLE_TRAINER_ID'] = '0'
cmd = [
sys.executable, "-u", "../ps/ps_dnn_trainer.py", "-m",
config['ps_mode_config'], "--run_minimize",
config['run_minimize'], "--run_single_pass",
config['run_single_pass'], "--debug_new_pass",
config['debug_new_pass'], "--debug_new_minimize",
config['debug_new_minimize'], "--applied_pass_name",
config['applied_pass_name']
]
cmd = [shlex.quote(c) for c in cmd]
prepare_python_path_and_return_module(__file__)
exitcode = os.system(' '.join(cmd))
......@@ -21,6 +21,7 @@ import numpy as np
import paddle
from ps_pass_test_base import *
from paddle.distributed.ps.utils.public import logger
from paddle.fluid.tests.unittests.ps.ps_dnn_trainer import DnnTrainer
......@@ -38,30 +39,43 @@ class TestPsTrainerPass(PsPassTestBase):
self.config['applied_pass_name'] = ""
def setUp(self):
print('TestPsTrainerPass setUp...')
pass
def tearDown(self):
print('TestPsTrainerPass tearDown...')
pass
def check(self):
pass
def test_ps_optimizer_minimize(self):
def test_ps_optimizer_minimize_cpu(self):
self.init()
self.config['run_minimize'] = '1'
self.config['debug_new_minimize'] = '0'
self.config['log_dir'] = "/log_old_minimize"
self.config['log_dir'] = "/cpu_log_old_minimize"
remove_path_if_exists(self.config['log_dir'])
self.ps_launch(self.config)
self.config['debug_new_minimize'] = '1'
self.config['log_dir'] = "/log_new_minimize"
self.config['log_dir'] = "/cpu_log_new_minimize"
remove_path_if_exists(self.config['log_dir'])
self.ps_launch(self.config)
self.check()
def test_ps_optimizer_minimize_gpu(self):
self.init()
self.config['run_minimize'] = '1'
self.config['ps_mode_config'] = "../ps/gpu_ps_config.yaml"
self.config['debug_new_minimize'] = '0'
self.ps_launch(self.config, "gpu-ps")
self.config['debug_new_minimize'] = '1'
self.ps_launch(self.config, "gpu-ps")
self.check()
def test_append_send_ops_pass(self):
self.init()
self.config['run_single_pass'] = '1'
......@@ -70,12 +84,12 @@ class TestPsTrainerPass(PsPassTestBase):
self.config['debug_new_pass'] = '0'
self.config['log_dir'] = "/log_old_" + self.config['applied_pass_name']
remove_path_if_exists(self.config['log_dir'])
self.ps_launch(self.config)
self.ps_launch(self.config, "cpu-ps")
self.config['debug_new_pass'] = '1'
self.config['log_dir'] = "/log_new_" + self.config['applied_pass_name']
remove_path_if_exists(self.config['log_dir'])
self.ps_launch(self.config)
self.ps_launch(self.config, "cpu-ps")
self.check()
......@@ -84,4 +98,5 @@ class TestPsTrainerPass(PsPassTestBase):
if __name__ == '__main__':
remove_path_if_exists('/ps_log')
unittest.main()
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# refer to PaddleRec/models/rank/dnn/benchmark.yaml
hyper_parameters:
optimizer:
class: Adam
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# refer to PaddleRec/models/rank/dnn/config_gpubox.yaml
hyper_parameters:
optimizer:
class: Adam
learning_rate: 0.001
strategy: async
sparse_inputs_slots: 27
sparse_feature_number: 1024
sparse_feature_dim: 11
dense_input_dim: 13
fc_sizes: [512, 256, 128, 32]
distributed_embedding: 0
runner:
geo_step: 400
sync_mode: "gpubox"
thread_num: 16
use_gpu: 1
model_path: "../ps_dnn_model.py"
......@@ -33,10 +33,6 @@ from ps_dnn_model import StaticModel
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
def is_distributed_env():
node_role = os.getenv("TRAINING_ROLE")
......@@ -140,7 +136,7 @@ class YamlHelper(object):
if header:
draws += h_format.format(header[0], header[1])
else:
draws += h_format.format("PaddleRec Benchmark Envs", "Value")
draws += h_format.format("Ps Benchmark Envs", "Value")
draws += line + "\n"
......@@ -163,7 +159,7 @@ def get_user_defined_strategy(config):
logger.warn(
"Not Find Distributed env, Change To local train mode. If you want train with fleet, please use [fleetrun] command."
)
return None
#return None
sync_mode = config.get("runner.sync_mode")
assert sync_mode in ["async", "sync", "geo", "heter", "gpubox"]
if sync_mode == "sync":
......@@ -318,7 +314,6 @@ class DnnTrainer(object):
logger.info("worker: {} started".format(fleet.worker_index()))
def run_minimize(self):
logger.info("entering run_minimize")
self.init_fleet_with_gloo()
self.model = get_model(self.config)
logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
......@@ -328,36 +323,32 @@ class DnnTrainer(object):
user_defined_strategy = get_user_defined_strategy(self.config)
learning_rate = self.config.get(
"hyper_parameters.optimizer.learning_rate")
sync_mode = self.config.get("runner.sync_mode")
inner_optimizer = paddle.optimizer.Adam(learning_rate, lazy_mode=True)
if self.config['debug_new_minimize'] == 1:
logger.info("entering run_minimize -- new")
from paddle.distributed.fleet.meta_optimizers.ps_optimizer import ParameterServerOptimizer
ps_optimizer = ParameterServerOptimizer(inner_optimizer)
ps_optimizer._set_basic_info(loss, self.role_maker, inner_optimizer,
user_defined_strategy)
ps_optimizer.minimize_impl(loss)
else:
logger.info("entering run_minimize -- old")
fleet_obj = fleet.distributed_optimizer(
inner_optimizer, user_defined_strategy) ## Fleet 对象
fleet_obj.minimize(loss)
if fleet.is_server():
_main_file = '/' + 'run_minimize' + '_debug_minimize:_' + str(
_main_file = '/' + sync_mode + '_run_minimize' + '_debug:_' + str(
self.config['debug_new_minimize']) + '_server_main.prototxt'
debug_program(_main_file, loss.block.program, 0)
elif fleet.is_worker():
_main_file = '/' + 'run_minimize' + '_debug_minimize:_' + str(
_main_file = '/' + sync_mode + '_run_minimize' + '_debug:_' + str(
self.config['debug_new_minimize']) + '_worker_main.prototxt'
debug_program(_main_file, loss.block.program, 1)
'''
if fleet.is_server():
logger.info("Run Server Begin")
fleet.init_server()
fleet.run_server()
'''
def run_single_pass(self):
logger.info("entering run_single_pass")
self.init_fleet_with_gloo()
self.model = get_model(config)
input_data = self.model.create_feeds()
......@@ -365,22 +356,26 @@ class DnnTrainer(object):
loss = self.model._cost
user_defined_strategy = get_user_defined_strategy(config)
learning_rate = config.get("hyper_parameters.optimizer.learning_rate")
sync_mode = self.config.get("runner.sync_mode")
inner_optimizer = paddle.optimizer.Adam(learning_rate, lazy_mode=True)
startup_program = paddle.static.default_startup_program()
inner_optimizer.minimize(loss, startup_program)
if self.config['debug_new_pass'] == 1:
logger.info("entering run {} - new".format(
str(config["applied_pass_name"])))
from paddle.distributed.fleet.meta_optimizers.ps_optimizer import ParameterServerOptimizer
ps_optimizer = ParameterServerOptimizer(inner_optimizer)
ps_optimizer._set_basic_info(loss, self.role_maker, inner_optimizer,
user_defined_strategy)
ps_optimizer._init_ps_pass_context(loss, startup_program)
_main = ps_optimizer.attrs['cloned_main']
_main = ps_optimizer.pass_ctx._attrs['cloned_main']
append_send_ops_pass = new_pass(config["applied_pass_name"],
ps_optimizer.attrs)
ps_optimizer.pass_ctx._attrs)
append_send_ops_pass.apply([_main], [None], ps_optimizer.pass_ctx)
else:
logger.info("entering run {} - old".format(
str(config["applied_pass_name"])))
from paddle.fluid.incubate.fleet.parameter_server.ir import public as public
dist_strategy = get_distributed_strategy(user_defined_strategy)
compiled_config = public.CompileTimeStrategy(
......@@ -393,13 +388,13 @@ class DnnTrainer(object):
_main = worker.append_send_ops_pass(_main, compiled_config)
if fleet.is_server():
_main_file = '/' + str(config[
"applied_pass_name"]) + '_debug_pass:_' + str(self.config[
_main_file = '/' + sync_mode + "_" + str(config[
"applied_pass_name"]) + '_debug:_' + str(self.config[
'debug_new_pass']) + '_server_main.prototxt'
debug_program(_main_file, _main, 0)
elif fleet.is_worker():
_main_file = '/' + str(config[
"applied_pass_name"]) + '_debug_pass:_' + str(self.config[
_main_file = '/' + sync_mode + "_" + str(config[
"applied_pass_name"]) + '_debug:_' + str(self.config[
'debug_new_pass']) + '_worker_main.prototxt'
debug_program(_main_file, _main, 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册