未验证 提交 9a3e1bce 编写于 作者: K kuizhiqing 提交者: GitHub

[LAUNCH] add distributed launch check tools (#44495)

* add launch test

* launch test for cpu

* bs 1
上级 067107ad
......@@ -101,6 +101,7 @@ class Context(object):
return False
def set_env_in_args(self):
# this logic may not propre to replace args with env, but ...
for k, v in env_args_mapping.items():
if k in self.envs:
setattr(self.args, v, self.envs[k])
setattr(self.args, v, type(getattr(self.args, v))(self.envs[k]))
......@@ -97,10 +97,14 @@ class CollectiveController(Controller):
"PADDLE_TRAINERS_NUM": "{}".format(global_size),
"PADDLE_RANK_IN_NODE": str(i),
}
if len(selected_dev_list) > 0:
if self.pod.replicas == 1:
e.update({selected_dev_key: ",".join(selected_dev_list)})
else:
e.update({selected_dev_key: selected_dev_list[i]})
else:
e.update({'PADDLE_DISTRI_BACKEND': 'gloo'})
self.add_container(envs=e, log_tag=i)
return True
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import six
import os
__all__ = []
......@@ -60,4 +61,15 @@ def rewrite_host_ip(ctx):
ctx.node.ip = ctx.args.host
enabled_plugins = [collective_compatible, rewrite_host_ip, process_args]
def test_mode(ctx):
if ctx.args.training_script == 'test':
ctx.logger.info('Paddle Distributed Test begin...')
if int(ctx.args.nnodes) < 2:
ctx.args.nnodes = 2
ctx.args.training_script = '{}/test.py'.format(
os.path.dirname(__file__))
enabled_plugins = [
test_mode, collective_compatible, rewrite_host_ip, process_args
]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.vision.models import ResNet
from paddle.vision.models.resnet import BottleneckBlock
from paddle.io import Dataset, BatchSampler, DataLoader
base_lr = 0.1
momentum_rate = 0.9
l2_decay = 1e-4
epoch = 3
batch_num = 1
batch_size = 1
class_dim = 102
# define a random dataset
class RandomDataset(Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __getitem__(self, idx):
image = np.random.random([3, 224, 224]).astype('float32')
label = np.random.randint(0, class_dim - 1, (1, )).astype('int64')
return image, label
def __len__(self):
return self.num_samples
def optimizer_setting(parameter_list=None):
optimizer = paddle.optimizer.Momentum(
learning_rate=base_lr,
momentum=momentum_rate,
weight_decay=paddle.regularizer.L2Decay(l2_decay),
parameters=parameter_list)
return optimizer
def train_resnet():
fleet.init(is_collective=True)
resnet = ResNet(BottleneckBlock, 18, num_classes=class_dim)
optimizer = optimizer_setting(parameter_list=resnet.parameters())
optimizer = fleet.distributed_optimizer(optimizer)
resnet = fleet.distributed_model(resnet)
dataset = RandomDataset(batch_num * batch_size)
train_loader = DataLoader(dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True,
num_workers=2)
print("Distributed training start...")
for eop in range(epoch):
resnet.train()
for batch_id, data in enumerate(train_loader()):
img, label = data
label.stop_gradient = True
out = resnet(img)
loss = paddle.nn.functional.cross_entropy(input=out, label=label)
avg_loss = paddle.mean(x=loss)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
avg_loss.backward()
optimizer.step()
resnet.clear_gradients()
print("[Epoch %d, batch %d] loss: %.5f, acc1: %.5f, acc5: %.5f" %
(eop, batch_id, avg_loss, acc_top1, acc_top5))
print("Distributed training completed")
if __name__ == '__main__':
import os
nnodes = os.getenv('PADDLE_NNODES')
cn = os.getenv('PADDLE_LOCAL_SIZE')
print(f"Prepare distributed training with {nnodes} nodes {cn} cards")
train_resnet()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册