未验证 提交 f2d68d3e 编写于 作者: 1 123malin 提交者: GitHub

【paddle.fleet】parameter_server_optimizer support auto_strategy (#26838)

* test=develop, add ps auto
上级 4c70e31a
......@@ -13,6 +13,10 @@
from paddle import fluid
from .meta_optimizer_base import MetaOptimizerBase
from paddle.fluid import core
import subprocess
import re
import platform
class ParameterServerOptimizer(MetaOptimizerBase):
......@@ -28,6 +32,9 @@ class ParameterServerOptimizer(MetaOptimizerBase):
def _can_apply(self):
if self.role_maker._is_collective:
return False
if self.user_defined_strategy.auto == True:
return True
k_steps = self.user_defined_strategy.a_sync_configs["k_steps"]
return True if k_steps >= 0 else False
......@@ -127,6 +134,105 @@ class ParameterServerOptimizer(MetaOptimizerBase):
return _main, _startup
def _try_auto_apply_geo(self, program, compiled_config):
def get_sys_free_mem():
plat = platform.system()
if platform.system() == "Darwin":
vm = subprocess.Popen(
['vm_stat'], stdout=subprocess.PIPE).communicate()[0]
# Process vm_stat
vmLines = vm.split('\n')
sep = re.compile(':[\s]+')
vmStats = {}
for row in range(1, len(vmLines) - 2):
rowText = vmLines[row].strip()
rowElements = sep.split(rowText)
vmStats[(rowElements[0]
)] = int(rowElements[1].strip('\.')) * 4096
return vmStats["Pages free"]
elif platform.system() == "Linux":
mems = {}
with open('/proc/meminfo', 'rb') as f:
for line in f:
fields = line.split()
mems[fields[0]] = int(fields[1]) * 1024
free = mems[b'MemFree:']
return free
else:
raise ValueError(
"%s platform is unsupported is parameter server optimizer" %
(platform.system()))
if self.user_defined_strategy.auto == False:
return
a_sync_configs = self.user_defined_strategy.a_sync_configs
if a_sync_configs["k_steps"] >= 0:
return
self.user_defined_strategy.a_sync = True
if not isinstance(self.inner_opt, fluid.optimizer.SGDOptimizer):
# auto async
a_sync_configs["k_steps"] = 0
self.user_defined_strategy.a_sync_configs = a_sync_configs
return
from paddle.fluid.incubate.fleet.parameter_server.ir.vars_metatools import dtype_to_size
free = get_sys_free_mem()
param_grad_pairs = compiled_config.origin_sparse_pairs + compiled_config.origin_dense_pairs
processed_var_names = set(["@EMPTY@"])
param_memory_size = 0
for param_grad_pair in param_grad_pairs:
param, grad = param_grad_pair
param_memory_size += param.m_size
processed_var_names.add(param.name)
upper_mem_use = param_memory_size * 5.0
program_tmp_vars = dict()
batch_size = 1024
for op in program.global_block().ops:
for var_name in op.output_arg_names:
if var_name in processed_var_names:
continue
processed_var_names.add(var_name)
var = program.global_block().vars[var_name]
if var.desc.type() != core.VarDesc.VarType.LOD_TENSOR:
continue
data_count = 1
neg_dim_count = 0
for x in var.shape:
if x < 0:
if neg_dim_count >= 1:
raise ValueError(
"Var %s has more than one negative dim." %
(var_name))
neg_dim_count += 1
data_count *= (-x)
else:
data_count *= x
program_tmp_vars[var_name] = (data_count, neg_dim_count,
dtype_to_size[var.dtype])
for varname in program_tmp_vars:
data_count, neg_dim_count, type_size = program_tmp_vars[varname]
if neg_dim_count == 1:
data_count *= batch_size
var_memory = data_count * type_size
upper_mem_use += var_memory
if upper_mem_use < free:
# auto geo
a_sync_configs["k_steps"] = 800
else:
# auto async
a_sync_configs["k_steps"] = 0
self.user_defined_strategy.a_sync_configs = a_sync_configs
def minimize_impl(self,
loss,
startup_program=None,
......@@ -134,7 +240,6 @@ class ParameterServerOptimizer(MetaOptimizerBase):
no_grad_set=None):
self.inner_opt.minimize(loss, startup_program, parameter_list,
no_grad_set)
strategy = self._get_distributed_strategy()
_origin_main_program = loss.block.program
_origin_startup_program = startup_program
......@@ -142,7 +247,12 @@ class ParameterServerOptimizer(MetaOptimizerBase):
compiled_config = public.CompileTimeStrategy(_origin_main_program,
_origin_startup_program,
strategy, self.role_maker)
None, self.role_maker)
self._try_auto_apply_geo(_origin_main_program, compiled_config)
strategy = self._get_distributed_strategy()
compiled_config.strategy = strategy
if self.role_maker.is_worker() or self.role_maker._is_heter_worker():
main_program, startup_program = self._build_trainer_programs(
......
......@@ -12,9 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from functools import reduce
from paddle.fluid.framework import Variable
from paddle.fluid import core
dtype_to_size = {
core.VarDesc.VarType.FP16: 2,
core.VarDesc.VarType.FP32: 4,
core.VarDesc.VarType.FP64: 8,
core.VarDesc.VarType.INT16: 2,
core.VarDesc.VarType.INT32: 4,
core.VarDesc.VarType.INT64: 8,
core.VarDesc.VarType.BOOL: 1,
core.VarDesc.VarType.UINT8: 1,
}
class VarBlock:
def __init__(self, varname, offset, size):
......@@ -51,11 +64,14 @@ class VarStruct(object):
self.type = type
self.lod_level = lod_level
self.persistable = persistable
self.m_size = 1
self.m_size = reduce(lambda x, y: x * y, shape)
self.m_size *= dtype_to_size[dtype]
def __str__(self):
return "N: {}, S: {}, D: {}, T: {}, LL: {}, P: {}".format(
return "N: {}, S: {}, D: {}, T: {}, LL: {}, P: {}, M: {}".format(
self.name, self.shape, self.dtype, self.type, self.lod_level,
self.persistable)
self.persistable, self.m_size)
class VarDistributed(object):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import os
import paddle.distributed.fleet.base.role_maker as role_maker
import time
class TestFleetGradientMergeMetaOptimizer(unittest.TestCase):
def setUp(self):
os.environ["PADDLE_PSERVER_NUMS"] = "2"
os.environ["PADDLE_TRAINERS_NUM"] = "2"
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36001"
os.environ["PADDLE_TRAINER_ID"] = "0"
os.environ["PADDLE_TRAINERS_NUM"] = "2"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \
"127.0.0.1:36001,127.0.0.2:36001"
def test_a_sync_optimizer1(self):
os.environ["TRAINING_ROLE"] = "TRAINER"
import paddle.distributed.fleet as fleet
main_program = paddle.fluid.Program()
startup_program = paddle.fluid.Program()
paddle.fluid.framework.switch_main_program(main_program)
paddle.fluid.framework.switch_startup_program(startup_program)
fleet.init(role_maker.PaddleCloudRoleMaker())
input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype='float32')
input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.auto = True
optimizer = paddle.fluid.optimizer.Adam(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
self.assertTrue(optimizer.user_defined_strategy.a_sync)
a_sync_configs = optimizer.user_defined_strategy.a_sync_configs
self.assertTrue(a_sync_configs['k_steps'] == 0)
def test_a_sync_optimizer2(self):
os.environ["TRAINING_ROLE"] = "TRAINER"
import paddle.distributed.fleet as fleet
main_program = paddle.fluid.Program()
startup_program = paddle.fluid.Program()
paddle.fluid.framework.switch_main_program(main_program)
paddle.fluid.framework.switch_startup_program(startup_program)
fleet.init(role_maker.PaddleCloudRoleMaker())
input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype='float32')
input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.auto = True
optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
self.assertTrue(optimizer.user_defined_strategy.a_sync)
a_sync_configs = optimizer.user_defined_strategy.a_sync_configs
self.assertTrue(a_sync_configs['k_steps'] == 800)
def test_a_sync_optimizer3(self):
os.environ["TRAINING_ROLE"] = "TRAINER"
import paddle.distributed.fleet as fleet
main_program = paddle.fluid.Program()
startup_program = paddle.fluid.Program()
paddle.fluid.framework.switch_main_program(main_program)
paddle.fluid.framework.switch_startup_program(startup_program)
fleet.init(role_maker.PaddleCloudRoleMaker())
input_x = paddle.fluid.layers.data(
name="x",
shape=[-1, 1],
dtype="int64",
lod_level=1,
append_batch_size=False)
x_embedding = paddle.fluid.layers.embedding(
is_distributed=False,
input=input_x,
size=[1000000000, 100000],
param_attr=paddle.fluid.ParamAttr(
name="embedding",
initializer=paddle.fluid.initializer.Constant(value=0.01)),
is_sparse=True)
input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
fc_1 = paddle.fluid.layers.fc(input=x_embedding, size=64, act='tanh')
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.auto = True
optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
self.assertTrue(optimizer.user_defined_strategy.a_sync)
a_sync_configs = optimizer.user_defined_strategy.a_sync_configs
self.assertTrue(a_sync_configs['k_steps'] == 0)
if __name__ == "__main__":
unittest.main()
......@@ -76,9 +76,10 @@ class FleetDistRunnerBase(object):
return role
def build_strategy(self, args):
self.strategy = paddle.distributed.fleet.DistributedStrategy()
self.strategy.a_sync = False
if args.mode == "async":
if args.mode == "sync":
self.strategy = paddle.distributed.fleet.DistributedStrategy()
self.strategy.a_sync = False
elif args.mode == "async":
self.strategy = paddle.distributed.fleet.DistributedStrategy()
self.strategy.a_sync = True
elif args.mode == "geo":
......@@ -87,6 +88,10 @@ class FleetDistRunnerBase(object):
self.strategy.a_sync_configs = {
"k_steps": args.geo_sgd_need_push_nums
}
elif args.mode == "auto":
self.strategy = paddle.distributed.fleet.DistributedStrategy()
self.strategy.auto = True
self.dump_param = os.getenv("dump_param", "").split(",")
self.dump_fields = os.getenv("dump_fields", "").split(",")
self.dump_fields_path = os.getenv("dump_fields_path", "")
......@@ -232,14 +237,17 @@ class TestFleetBase(unittest.TestCase):
tr0_pipe = open(tempfile.gettempdir() + "/tr0_err.log", "wb+")
tr1_pipe = open(tempfile.gettempdir() + "/tr1_err.log", "wb+")
tr0_out = open(tempfile.gettempdir() + "/tr0_stdout.log", "wb+")
tr1_out = open(tempfile.gettempdir() + "/tr1_stdout.log", "wb+")
tr0_proc = subprocess.Popen(
tr0_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stdout=tr0_out,
stderr=tr0_pipe,
env=required_envs)
tr1_proc = subprocess.Popen(
tr1_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stdout=tr1_out,
stderr=tr1_pipe,
env=required_envs)
......
......@@ -52,6 +52,38 @@ class TestDistMnistSync2x2(TestFleetBase):
"dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
class TestDistMnistAuto2x2(TestFleetBase):
def _setup_config(self):
self._mode = "auto"
self._reader = "pyreader"
def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": "",
"CPU_NUM": "2"
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
def test_dist_train(self):
self.check_with_place(
"dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
class TestDistMnistAsync2x2(TestFleetBase):
def _setup_config(self):
self._mode = "async"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册