From 9dd223ccd9e06d37f02ac77592725eb5a67389d0 Mon Sep 17 00:00:00 2001
From: danleifeng
Date: Fri, 21 Aug 2020 09:36:35 +0000
Subject: [PATCH] fleet support dygraph in mnist/resnet/transformer

---
 dygraph/mnist/train.py       | 20 ++++++++++++++------
 dygraph/resnet/train.py      | 20 ++++++++++++++------
 dygraph/transformer/train.py | 17 +++++++++++------
 3 files changed, 39 insertions(+), 18 deletions(-)

diff --git a/dygraph/mnist/train.py b/dygraph/mnist/train.py
index f81df8f2..cbb5a16a 100644
--- a/dygraph/mnist/train.py
+++ b/dygraph/mnist/train.py
@@ -24,6 +24,8 @@ from paddle.fluid.optimizer import AdamOptimizer
 from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
 from paddle.fluid.dygraph.base import to_variable
 
+from paddle.distributed import fleet
+from paddle.distributed.fleet.base import role_maker
 
 def parse_args():
     parser = argparse.ArgumentParser("Training for Mnist.")
@@ -174,8 +176,11 @@ def train_mnist(args):
     epoch_num = args.epoch
     BATCH_SIZE = 64
 
-    place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
-        if args.use_data_parallel else fluid.CUDAPlace(0)
+    if args.use_data_parallel:
+        place_idx = int(os.environ['FLAGS_selected_gpus'])
+        place = fluid.CUDAPlace(place_idx)
+    else:
+        place = fluid.CUDAPlace(0)
     with fluid.dygraph.guard(place):
         if args.ce:
             print("ce mode")
@@ -184,12 +189,15 @@ def train_mnist(args):
             fluid.default_startup_program().random_seed = seed
             fluid.default_main_program().random_seed = seed
 
-        if args.use_data_parallel:
-            strategy = fluid.dygraph.parallel.prepare_context()
         mnist = MNIST()
         adam = AdamOptimizer(learning_rate=0.001, parameter_list=mnist.parameters())
         if args.use_data_parallel:
-            mnist = fluid.dygraph.parallel.DataParallel(mnist, strategy)
+            role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+            fleet.init(role)
+            dist_strategy = fleet.DistributedStrategy()
+            adam = fleet.distributed_optimizer(adam, dist_strategy)
+            # call after distributed_optimizer so as to apply dist_strategy
+            mnist = fleet.build_distributed_model(mnist)
 
         train_reader = paddle.batch(
             paddle.dataset.mnist.train(), batch_size=BATCH_SIZE, drop_last=True)
@@ -241,7 +249,7 @@ def train_mnist(args):
 
         save_parameters = (not args.use_data_parallel) or (
             args.use_data_parallel and
-            fluid.dygraph.parallel.Env().local_rank == 0)
+            fleet.worker_index() == 0)
         if save_parameters:
             fluid.save_dygraph(mnist.state_dict(), "save_temp")
 
diff --git a/dygraph/resnet/train.py b/dygraph/resnet/train.py
index e92a39bd..342c64df 100644
--- a/dygraph/resnet/train.py
+++ b/dygraph/resnet/train.py
@@ -15,6 +15,7 @@
 import numpy as np
 import argparse
 import ast
+import os
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid.layer_helper import LayerHelper
@@ -23,6 +24,8 @@ from paddle.fluid.dygraph.base import to_variable
 
 from paddle.fluid import framework
 
+from paddle.distributed import fleet
+from paddle.distributed.fleet.base import role_maker
 import math
 import sys
 import time
@@ -283,8 +286,11 @@ def eval(model, data):
 
 def train_resnet():
     epoch = args.epoch
-    place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
-        if args.use_data_parallel else fluid.CUDAPlace(0)
+    if args.use_data_parallel:
+        place_idx = int(os.environ['FLAGS_selected_gpus'])
+        place = fluid.CUDAPlace(place_idx)
+    else:
+        place = fluid.CUDAPlace(0)
     with fluid.dygraph.guard(place):
         if args.ce:
             print("ce mode")
@@ -293,14 +299,16 @@
             fluid.default_startup_program().random_seed = seed
             fluid.default_main_program().random_seed = seed
 
-        if args.use_data_parallel:
-            strategy = fluid.dygraph.parallel.prepare_context()
-
         resnet = ResNet()
         optimizer = optimizer_setting(parameter_list=resnet.parameters())
 
         if args.use_data_parallel:
-            resnet = fluid.dygraph.parallel.DataParallel(resnet, strategy)
+            role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+            fleet.init(role)
+            dist_strategy = fleet.DistributedStrategy()
+            optimizer = fleet.distributed_optimizer(optimizer, dist_strategy)
+            # call after distributed_optimizer so as to apply dist_strategy
+            resnet = fleet.build_distributed_model(resnet)
 
         train_reader = paddle.batch(
             paddle.dataset.flowers.train(use_xmap=False), batch_size=batch_size)
diff --git a/dygraph/transformer/train.py b/dygraph/transformer/train.py
index 75cbb277..7c5c2981 100644
--- a/dygraph/transformer/train.py
+++ b/dygraph/transformer/train.py
@@ -21,6 +21,8 @@ import time
 import numpy as np
 import paddle
 import paddle.fluid as fluid
+from paddle.distributed import fleet
+from paddle.distributed.fleet.base import role_maker
 
 from utils.configure import PDConfig
 from utils.check import check_gpu, check_version
@@ -32,9 +34,9 @@ from model import Transformer, CrossEntropyCriterion, NoamDecay
 
 def do_train(args):
     if args.use_cuda:
-        trainer_count = fluid.dygraph.parallel.Env().nranks
-        place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id
-                                ) if trainer_count > 1 else fluid.CUDAPlace(0)
+        trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", 1))
+        place_idx = int(os.getenv('FLAGS_selected_gpus', 0))
+        place = fluid.CUDAPlace(place_idx)
     else:
         trainer_count = 1
         place = fluid.CPUPlace()
@@ -130,9 +132,12 @@ def do_train(args):
             transformer.load_dict(model_dict)
 
         if trainer_count > 1:
-            strategy = fluid.dygraph.parallel.prepare_context()
-            transformer = fluid.dygraph.parallel.DataParallel(
-                transformer, strategy)
+            role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+            fleet.init(role)
+            dist_strategy = fleet.DistributedStrategy()
+            optimizer = fleet.distributed_optimizer(optimizer, dist_strategy)
+            # call after distributed_optimizer so as to apply dist_strategy
+            transformer = fleet.build_distributed_model(transformer)
 
         # the best cross-entropy value with label smoothing
         loss_normalizer = -(
--
GitLab
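
The three scripts above all apply the same fleet migration pattern: read the device id from FLAGS_selected_gpus, initialize fleet with a collective PaddleCloudRoleMaker, wrap the optimizer with fleet.distributed_optimizer(), and only then wrap the model. The standalone sketch below condenses that pattern; it assumes the same fleet API used in this diff (fleet.init, fleet.DistributedStrategy, fleet.distributed_optimizer, fleet.build_distributed_model, whose names may differ in other Paddle releases) and a made-up SimpleNet layer, so treat it as an illustration of the pattern rather than a verified example for any particular Paddle version.

# Minimal sketch of the fleet-based dygraph setup used in this patch.
# Assumptions: the fleet API is exactly the one shown in the diff above
# (fleet.build_distributed_model in particular), and SimpleNet is a
# placeholder standing in for MNIST/ResNet/Transformer.
import os

import numpy as np
import paddle.fluid as fluid
from paddle.distributed import fleet
from paddle.distributed.fleet.base import role_maker
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.optimizer import AdamOptimizer


class SimpleNet(fluid.dygraph.Layer):
    """Placeholder network; any dygraph Layer is handled the same way."""

    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = Linear(input_dim=784, output_dim=10)

    def forward(self, x):
        return self.fc(x)


def train_one_step(use_data_parallel):
    # Each launched worker reads its GPU id from FLAGS_selected_gpus,
    # mirroring the place selection added by the patch.
    if use_data_parallel:
        place = fluid.CUDAPlace(int(os.environ['FLAGS_selected_gpus']))
    else:
        place = fluid.CUDAPlace(0)

    with fluid.dygraph.guard(place):
        model = SimpleNet()
        opt = AdamOptimizer(learning_rate=0.001,
                            parameter_list=model.parameters())
        if use_data_parallel:
            # Same ordering as in the diff: wrap the optimizer first so the
            # DistributedStrategy is applied, then wrap the model.
            role = role_maker.PaddleCloudRoleMaker(is_collective=True)
            fleet.init(role)
            dist_strategy = fleet.DistributedStrategy()
            opt = fleet.distributed_optimizer(opt, dist_strategy)
            model = fleet.build_distributed_model(model)

        # One dummy step to show the wrapped model/optimizer in use.
        x = fluid.dygraph.to_variable(
            np.random.random((8, 784)).astype('float32'))
        loss = fluid.layers.reduce_mean(model(x))
        loss.backward()
        opt.minimize(loss)
        model.clear_gradients()


if __name__ == '__main__':
    train_one_step(use_data_parallel='PADDLE_TRAINERS_NUM' in os.environ)

Under these assumptions, each worker would be started through the distributed launcher (for example, python -m paddle.distributed.launch --selected_gpus=0,1,2,3 train.py --use_data_parallel 1), which is what populates FLAGS_selected_gpus and PADDLE_TRAINERS_NUM for every process; the exact launcher flags are an assumption and are not taken from this patch.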