Unverified commit e35ad3ee, authored by danleifeng, committed by GitHub

【paddle.fleet】support running python train.py for fleet tasks (#26249)

* support running python train.py for fleet-task; test=develop
Parent 9cb57f94
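In short, this change lets a fleet program fall back to single-process execution when the fleetrun/PaddleCloud environment variables are absent, so the same script can be started with either fleetrun train.py or plain python train.py. A minimal sketch of such a script follows (the file name train.py and the toy network are illustrative, mirroring the tests added below; it is not part of this commit):

import paddle
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet

x = paddle.static.data(name="x", shape=[-1, 32], dtype="float32")
y = paddle.static.data(name="y", shape=[-1, 1], dtype="int64")
fc = fluid.layers.fc(input=x, size=64, act="tanh")
prediction = fluid.layers.fc(input=fc, size=2, act="softmax")
avg_cost = paddle.mean(fluid.layers.cross_entropy(input=prediction, label=y))

fleet.init(is_collective=True)  # works with or without a launcher
optimizer = fleet.distributed_optimizer(fluid.optimizer.SGD(learning_rate=0.001))
optimizer.minimize(avg_cost)  # after this patch: falls back to a local
                              # CompiledProgram when no cluster env is found

exe = fluid.Executor(fluid.CPUPlace())
exe.run(paddle.static.default_startup_program())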
......@@ -13,7 +13,9 @@
# limitations under the License.
from __future__ import print_function
import warnings
import paddle
from paddle.fluid import compiler
from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
from .strategy_compiler import StrategyCompiler
from .distributed_strategy import DistributedStrategy
......@@ -35,7 +37,24 @@ def _inited_runtime_handler_(func):
return __impl__
def _is_non_distributed_check_(func):
def __impl__(*args, **kwargs):
cls = args[0]
if cls._role_maker is not None and cls._role_maker._is_non_distributed(
) is True:
warnings.warn(
"%s() function doesn't work when use non_distributed fleet." %
(func.__name__))
return
return func(*args, **kwargs)
return __impl__
inited_runtime_handler = wrap_decorator(_inited_runtime_handler_)
is_non_distributed_check = wrap_decorator(_is_non_distributed_check_)
class Fleet(object):
......@@ -367,6 +386,7 @@ class Fleet(object):
"""
self._role_maker.barrier_worker()
@is_non_distributed_check
@inited_runtime_handler
def init_worker(self):
"""
......@@ -391,6 +411,7 @@ class Fleet(object):
"""
self._runtime_handle._init_worker()
@is_non_distributed_check
@inited_runtime_handler
def init_server(self, *args, **kwargs):
"""
......@@ -416,6 +437,7 @@ class Fleet(object):
"""
self._runtime_handle._init_server(*args, **kwargs)
@is_non_distributed_check
@inited_runtime_handler
def run_server(self):
"""
......@@ -440,6 +462,7 @@ class Fleet(object):
"""
self._runtime_handle._run_server()
@is_non_distributed_check
@inited_runtime_handler
def stop_worker(self):
"""
......@@ -593,8 +616,8 @@ class Fleet(object):
tuple: tuple (optimize_ops, params_grads), A list of operators appended
by minimize and a list of (param, grad) variable pairs, param is
``Parameter``, grad is the gradient value corresponding to the parameter.
The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
indicate program pruning. If so, the program will be pruned by ``feed`` and
``fetch_list`` before run, see details in ``Executor``.
Examples:
......@@ -672,6 +695,20 @@ class Fleet(object):
optimize_ops = []
params_grads = []
if self._role_maker._is_non_distributed() and not self._is_collective:
if self._runtime_handle is None:
self._runtime_handle = RuntimeFactory()._create_runtime(context)
compiled_program = compiler.CompiledProgram(
self.origin_main_program).with_data_parallel(
loss_name=loss.name, share_vars_from=None)
loss.block.program._graph = compiled_program
return self.user_defined_optimizer.minimize(
loss,
startup_program=startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set)
if meta_optimizer:
optimize_ops, params_grads = meta_optimizer.minimize(
loss,
......
......@@ -232,6 +232,8 @@ class PaddleCloudRoleMaker(RoleMakerBase):
self._node_type_comm = None
self._all_comm = None
self._non_distributed = False
if not self._is_collective:
self._hdfs_name = kwargs.get("hdfs_name", "")
self._hdfs_ugi = kwargs.get("hdfs_ugi", "")
......@@ -373,6 +375,15 @@ class PaddleCloudRoleMaker(RoleMakerBase):
self.generate_role()
return self._server_endpoints
def _is_non_distributed(self):
"""
Return True if the environment variables required by fleetrun are not found
(i.e. the fleet code was launched directly with python).
"""
if not self._role_is_generated:
self.generate_role()
return self._non_distributed
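For context, a standalone sketch (hypothetical helper, not part of this commit) of the environment check this flag boils down to; the variable names match the ones used in the code below:

import os

def _launched_by_fleetrun(is_collective):
    # fleetrun and the PaddleCloud launcher export these endpoint lists,
    # while a plain "python train.py" run leaves them unset, which is
    # what PaddleCloudRoleMaker detects below to set _non_distributed.
    key = ("PADDLE_TRAINER_ENDPOINTS" if is_collective
           else "PADDLE_PSERVERS_IP_PORT_LIST")
    return os.getenv(key) is not None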
def _heter_worker_num(self):
"""
get heter worker nums
......@@ -409,13 +420,22 @@ class PaddleCloudRoleMaker(RoleMakerBase):
try:
# Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set
# format: string(ip:port,ip:port), e.g. 127.0.0.1:6001,127.0.0.1:6002
self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST",
"").split(",")
assert self._server_endpoints != ""
self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST")
self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS",
"").split(",")
assert self._server_endpoints != ""
if self._server_endpoints is None:
# fall back to non-distributed execution.
self._server_endpoints = ""
self._trainers_num = 1
self._role = Role.WORKER
self._current_id = 0
self._node_num = 1
self._heter_trainers_num = 0
self._heter_trainer_endpoints = None
self._non_distributed = True
return
self._server_endpoints = self._server_endpoints.split(",")
trainers_num = int(os.environ["PADDLE_TRAINERS_NUM"])
training_role = os.environ["TRAINING_ROLE"]
......@@ -488,7 +508,11 @@ class PaddleCloudRoleMaker(RoleMakerBase):
assert (self._training_role == "TRAINER")
self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS")
self._cur_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
assert self._worker_endpoints is not None, "can't find PADDLE_TRAINER_ENDPOINTS"
if self._worker_endpoints is None:
# fall back to non-distributed execution.
self._worker_endpoints = "127.0.0.1:6170"
self._cur_endpoint = self._worker_endpoints
self._non_distributed = True
self._worker_endpoints = self._worker_endpoints.split(",")
self._trainers_num = len(self._worker_endpoints)
self._node_num = len(
......
......@@ -18,6 +18,7 @@ import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
import os
import paddle.fluid as fluid
import numpy as np
class TestFleetBase(unittest.TestCase):
......@@ -125,5 +126,84 @@ class TestFleetBase(unittest.TestCase):
self.assertRaises(Exception, fleet.init_worker)
class TestFleetBaseSingleRunCollective(unittest.TestCase):
def setUp(self):
os.environ.pop("PADDLE_TRAINER_ENDPOINTS")
def gen_data(self):
return {
"x": np.random.random(size=(128, 32)).astype('float32'),
"y": np.random.randint(
2, size=(128, 1)).astype('int64')
}
def test_single_run_collective_minimize(self):
input_x = paddle.static.data(name="x", shape=[-1, 32], dtype='float32')
input_y = paddle.static.data(name="y", shape=[-1, 1], dtype='int64')
fc_1 = fluid.layers.fc(input=input_x, size=64, act='tanh')
prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
avg_cost = paddle.mean(x=cost)
fleet.init(is_collective=True)
optimizer = fluid.optimizer.SGD(learning_rate=0.001)
optimizer = fleet.distributed_optimizer(optimizer)
optimizer.minimize(avg_cost)
place = fluid.CUDAPlace(0) if paddle.fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(paddle.static.default_startup_program())
for i in range(10):
cost_val = exe.run(feed=self.gen_data(), fetch_list=[avg_cost.name])
print("cost of step[{}] = {}".format(i, cost_val))
class TestFleetBaseSingleRunPS(unittest.TestCase):
def setUp(self):
os.environ.pop("PADDLE_PSERVERS_IP_PORT_LIST")
def gen_data(self):
return {
"x": np.random.random(size=(128, 32)).astype('float32'),
"y": np.random.randint(
2, size=(128, 1)).astype('int64')
}
def test_single_run_ps_minimize(self):
input_x = paddle.static.data(name="x", shape=[-1, 32], dtype='float32')
input_y = paddle.static.data(name="y", shape=[-1, 1], dtype='int64')
fc_1 = fluid.layers.fc(input=input_x, size=64, act='tanh')
prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
avg_cost = paddle.mean(x=cost)
fleet.init()
strategy = paddle.distributed.fleet.DistributedStrategy()
optimizer = fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
if fleet.is_server():
fleet.init_server()
fleet.run_server()
elif fleet.is_worker():
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(paddle.static.default_startup_program())
step = 100
for i in range(step):
cost_val = exe.run(program=fluid.default_main_program(),
feed=self.gen_data(),
fetch_list=[avg_cost.name])
print("worker_index: %d, step%d cost = %f" %
(fleet.worker_index(), i, cost_val[0]))
fleet.save_persistables(exe, "fleet_single_model/")
print("save fleet models done.")
if __name__ == "__main__":
unittest.main()