Created by: danleifeng
PR types
New features
PR changes
APIs
Describe
This PR makes fleet support dygraph mode.
distributed_model
new api for fleet
1. add 2. add 6 api for fitting new 2.0 dygraph-mode optimizer :
set_lr
get_lr
state_dict
set_state_dict
step
clear_grad
3. add ditributed dygraph unittest for fleet code
Sample code:
import paddle
import paddle.nn as nn
from paddle.distributed import fleet
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)
def forward(self, x):
return self._linear2(self._linear1(x))
def train():
# 1. enable dynamic mode
paddle.disable_static()
# 2. initialize fleet environment
fleet.init(is_collective=True)
# 3. create layer & optimizer
layer = LinearNet()
loss_fn = nn.MSELoss()
adam = paddle.optimizer.Adam(
learning_rate=0.001, parameters=layer.parameters())
# 4. get data_parallel model using fleet
adam = fleet.distributed_optimizer(adam)
dp_layer = fleet.distributed_model(layer)
# 5. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)
print("loss:", loss.numpy())
loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()
adam.step()
adam.clear_grad()
if __name__ == '__main__':
train()