未验证 提交 a01bc6b3 编写于 作者: D danleifeng 提交者: GitHub

【paddle.fleet】fleet support non_distributed training in dygraph mode (#27714)

* fleet support non_distributed training in dygraph mode; test=develop
上级 e496640b
......@@ -187,6 +187,8 @@ class Fleet(object):
self.strategy_compiler = StrategyCompiler()
if paddle.fluid.framework.in_dygraph_mode():
if self.worker_num() == 1:
return
if parallel_helper._is_parallel_ctx_initialized():
warnings.warn(
"The dygraph parallel environment has been initialized.")
......
......@@ -18,6 +18,7 @@ import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
import os
import paddle.fluid as fluid
import paddle.nn as nn
import numpy as np
......@@ -170,6 +171,44 @@ class TestFleetDygraph(unittest.TestCase):
final_strategy = fleet._final_strategy()
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)
def forward(self, x):
return self._linear2(self._linear1(x))
class TestFleetDygraphSingle(unittest.TestCase):
def setUp(self):
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36213"
os.environ["PADDLE_CURRENT_ENDPOINTS"] = "127.0.0.1:36213"
os.environ["PADDLE_TRAINERS_NUM"] = "1"
os.environ["PADDLE_TRAINER_ID"] = "0"
def test_dygraph_single(self):
paddle.disable_static()
fleet.init(is_collective=True)
layer = LinearNet()
loss_fn = nn.MSELoss()
adam = paddle.optimizer.Adam(
learning_rate=0.001, parameters=layer.parameters())
adam = fleet.distributed_optimizer(adam)
dp_layer = fleet.distributed_model(layer)
for step in range(2):
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)
loss.backward()
adam.step()
adam.clear_grad()
class TestFleetBaseSingleRunCollective(unittest.TestCase):
def setUp(self):
os.environ.pop("PADDLE_TRAINER_ENDPOINTS")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册