未验证 提交 1f829f6e 编写于 作者: H Haohongxiang 提交者: GitHub

[Dygraph] Support process group in dp with fleet api (#41119)

* support process group in dp with fleet api

* update

* fix uts

* update
上级 7554f428
......@@ -217,6 +217,7 @@ def init_parallel_env():
"required to create a process group.")
master_addr = os.getenv("MASTER_ADDR", None)
master_port = os.getenv("MASTER_PORT", None)
endpoints = None
if not master_addr or not master_port:
endpoints = os.getenv("PADDLE_MASTER", None)
if endpoints is None:
......
......@@ -398,16 +398,6 @@ def sync_params_buffers(model,
'axis': 0})
@imperative_base.no_grad
@framework.dygraph_only
def sync_eager_params(model, comm_group=None, src_rank=0):
for _, param in model._obtain_parameters_buffers().items():
if not isinstance(param, core.eager.Tensor):
raise TypeError("The data type of '%s' must be '%s'" %
(param.name, core.eager.Tensor))
comm_group.broadcast(param, src_rank).synchronize()
class DataParallel(layers.Layer):
"""
Run the dygraph module with data parallelism.
......@@ -575,7 +565,7 @@ class DataParallel(layers.Layer):
comm_buffer_size=25,
last_comm_buffer_size=1,
find_unused_parameters=False,
process_group=None):
group=None):
super(DataParallel,
self).__init__(layers.full_name() + "_data_parallel")
......@@ -585,7 +575,7 @@ class DataParallel(layers.Layer):
self._layers = layers
self.find_unused_parameters = find_unused_parameters
self.grad_need_sync = True
self.process_group = process_group
self.group = group
self.var_dtype = core.eager.Tensor if in_dygraph_mode(
) else core.VarBase
......@@ -604,20 +594,18 @@ class DataParallel(layers.Layer):
"ParallelContext must be initialized before. You should use init_parallel_env() before" \
"constructing the DataParallel."
if self.process_group is None and in_dygraph_mode():
raise RuntimeError(
"Process group should be built for DataParallel in eager mode."
)
if in_dygraph_mode():
self.group = paddle.distributed.collective._get_default_group(
) if self.group is None else self.group
assert isinstance(self.group, paddle.distributed.collective.Group), \
"ProcessGroup must be an instance of Group in DataParallel."
# sync buffer and params
# TODO(liuyuhui) Currently not support xpu. xpu is
# still broadcasting parameters when calling layer
if not paddle.is_compiled_with_xpu():
if in_dygraph_mode():
sync_eager_params(
self._layers, comm_group=self.process_group)
elif _in_legacy_dygraph():
sync_params_buffers(self._layers)
sync_params_buffers(self._layers)
self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
# NOTE(shenliang03): We can set environment variables to control
......@@ -678,7 +666,7 @@ class DataParallel(layers.Layer):
self._reducer = core.EagerReducer(
trainable_parameters,
list(reversed(self.group_indices)), is_sparse_gradient,
self.process_group,
self.group.process_group,
[self.last_comm_buffer_size, self.comm_buffer_size],
self.find_unused_parameters)
elif _in_legacy_dygraph():
......
......@@ -39,9 +39,7 @@ if (WITH_GPU OR WITH_XPU OR WITH_ASCEND OR WITH_ASCEND_CL)
endif()
list(APPEND DIST_TEST_OPS test_parallel_dygraph_unused_variables)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_control_flow)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_control_flow_in_eager_mode)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_no_sync)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_no_sync_in_eager_mode)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_no_sync_gradient_check)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_pipeline_parallel)
......@@ -279,9 +277,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_transformer)
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sync_batch_norm)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_control_flow)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_control_flow_in_eager_mode)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_no_sync)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_no_sync_in_eager_mode)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_no_sync_gradient_check)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_pipeline_parallel)
......@@ -1128,12 +1124,11 @@ set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_split_program PROPERTIES TIMEOUT 120)
if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_mnist PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 300)
set_tests_properties(test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 200)
set_tests_properties(test_parallel_dygraph_control_flow_in_eager_mode PROPERTIES TIMEOUT 150)
set_tests_properties(test_parallel_dygraph_no_sync PROPERTIES TIMEOUT 150)
set_tests_properties(test_parallel_dygraph_no_sync_in_eager_mode PROPERTIES TIMEOUT 150)
set_tests_properties(test_parallel_dygraph_mnist PROPERTIES TIMEOUT 200)
set_tests_properties(test_parallel_dygraph_se_resnext PROPERTIES TIMEOUT 200)
set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 350)
set_tests_properties(test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 350)
set_tests_properties(test_parallel_dygraph_no_sync PROPERTIES TIMEOUT 300)
set_tests_properties(test_parallel_dygraph_no_sync_gradient_check PROPERTIES TIMEOUT 30)
set_tests_properties(test_parallel_dygraph_pipeline_parallel PROPERTIES TIMEOUT 200)
set_tests_properties(test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT 200)
......@@ -1155,8 +1150,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212)
set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 200)
set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 200)
set_tests_properties(test_parallel_dygraph_sparse_embedding_over_height PROPERTIES TIMEOUT 150)
set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 150)
endif()
endif()
......
......@@ -57,4 +57,6 @@ class TestDygraphFleetAPI(unittest.TestCase):
if __name__ == "__main__":
with _test_eager_guard():
pass
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import unittest
import os
import copy
import numpy as np
import random
import socket
import paddle
import paddle.nn as nn
from paddle.fluid.dygraph.nn import Linear
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard
import paddle.distributed as dist
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.optimizer import SGD
from paddle.fluid.initializer import NumpyArrayInitializer
from test_parallel_dygraph_dataparallel import get_dist_port_from_flags
def init_process_group(strategy=None):
nranks = ParallelEnv().nranks
rank = ParallelEnv().local_rank
is_master = True if rank == 0 else False
envs = copy.copy(os.environ.copy())
port = get_dist_port_from_flags()
store = paddle.fluid.core.TCPStore("127.0.0.1", port, is_master, nranks)
if 'PADDLE_DISTRI_BACKEND' in envs.keys() and envs[
'PADDLE_DISTRI_BACKEND'] == 'gloo':
group = core.ProcessGroupGloo(store, rank, nranks)
else:
group = core.ProcessGroupNCCL(store, rank, nranks)
return group
class LinearModel(nn.Layer):
def __init__(self, attr_list):
super(LinearModel, self).__init__()
self._linear1 = paddle.nn.Linear(
50, 30, weight_attr=attr_list[0], bias_attr=False)
self._linear2 = paddle.nn.Linear(
30, 10, weight_attr=attr_list[1], bias_attr=False)
self._linear3 = paddle.nn.Linear(
10, 10, weight_attr=attr_list[2], bias_attr=False)
def forward(self, x):
output = self._linear1(x)
output = self._linear2(output)
output = self._linear3(output)
return output
class TestDistTraning(unittest.TestCase):
def test_multiple_gpus(self):
process_group = init_process_group()
self.generate_reducer("float32", process_group)
if paddle.get_device() != "cpu":
self.generate_reducer("float16", process_group)
def generate_reducer(self, dtype, process_group):
local_rank = ParallelEnv().local_rank
np.random.seed(2022 + local_rank)
paddle.set_default_dtype(dtype)
w_1 = paddle.ParamAttr(initializer=NumpyArrayInitializer(
np.random.rand(50, 30).astype(dtype)))
w_2 = paddle.ParamAttr(initializer=NumpyArrayInitializer(
np.random.rand(30, 10).astype(dtype)))
w_3 = paddle.ParamAttr(initializer=NumpyArrayInitializer(
np.random.rand(10, 10).astype(dtype)))
attr_list = [w_1, w_2, w_3]
inp = np.random.rand(10, 50).astype(dtype)
# original reducer
params_a = self.model_train(attr_list, inp)
# refactored reducer in eager mode
with _test_eager_guard():
params_b = self.model_train(
attr_list, inp, process_group=process_group)
for i in range(len(params_a)):
np.testing.assert_allclose(params_a[i].numpy(), params_b[i].numpy())
def model_train(self, attr_list, inp, process_group=None):
model = LinearModel(attr_list)
model = paddle.DataParallel(model, process_group=process_group)
optimizer = SGD(learning_rate=0.0003, parameters=model.parameters())
x = paddle.to_tensor(inp)
x.stop_gradient = False
for step in range(10):
y = model(x)
loss = y.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
return model.parameters()
class TestCatchErrors1(unittest.TestCase):
def test_multiple_gpus(self):
linear = paddle.nn.Linear(2, 4)
with _test_eager_guard():
self.assertRaises(RuntimeError, paddle.DataParallel, linear)
class TestCatchErrors2(unittest.TestCase):
def test_multiple_gpus(self):
with _test_eager_guard():
linear = paddle.nn.Linear(2, 4)
self.assertRaises(RuntimeError, paddle.DataParallel, linear)
if __name__ == '__main__':
dist.init_parallel_env()
unittest.main()
......@@ -36,19 +36,6 @@ in_dim = 10
out_dim = 20
def init_process_group(strategy=None):
nranks = ParallelEnv().nranks
rank = ParallelEnv().local_rank
is_master = True if rank == 0 else False
current_env = copy.copy(os.environ.copy())
port = 6175
if 'PADDLE_DIST_UT_PORT' in current_env.keys():
port = int(current_env['PADDLE_DIST_UT_PORT'])
store = paddle.fluid.core.TCPStore("127.0.0.1", port, is_master, nranks)
group = core.ProcessGroupNCCL(store, rank, nranks)
return group
class SimpleNet(fluid.Layer):
def __init__(self, train_id):
super(SimpleNet, self).__init__()
......@@ -83,12 +70,9 @@ class SimpleNet(fluid.Layer):
class TestDistTraning(unittest.TestCase):
def test_multiple_gpus(self):
dist.init_parallel_env()
self.trainer_id = dist.get_rank()
process_group = init_process_group()
self.pg = process_group
with _test_eager_guard():
self.pg = dist.init_parallel_env()
model_a = SimpleNet(self.trainer_id)
model_b = SimpleNet(self.trainer_id)
......@@ -97,13 +81,9 @@ class TestDistTraning(unittest.TestCase):
model_b.set_state_dict(state_dict)
model_a = paddle.DataParallel(
model_a,
find_unused_parameters=True,
process_group=process_group)
model_a, find_unused_parameters=True, group=self.pg)
model_b = paddle.DataParallel(
model_b,
find_unused_parameters=True,
process_group=process_group)
model_b, find_unused_parameters=True, group=self.pg)
ones_input = paddle.ones(shape=(batch, in_dim))
ones_input.stop_gradient = True
......@@ -150,7 +130,7 @@ class TestDistTraning(unittest.TestCase):
print(*args)
def broadcast_param(self, param, root):
self.pg.broadcast(param, root)
self.pg.process_group.broadcast(param, root)
return param
def check_gradient(self, params):
......
......@@ -69,18 +69,6 @@ class TestNoSync(TestParallelDyGraphRunnerBase):
loss = out.sum() / len(batch)
return loss
def run_trainer(self, args):
if args.eager_mode:
self.run_trainer_in_eager_mode(args)
else:
self.run_trainer_func(args)
def run_trainer_with_spawn(self, args):
if args.eager_mode:
return self.run_trainer_with_spawn_in_eager_mode(args)
else:
return self.run_trainer_with_spawn_func(args)
def run_trainer_func(self, args):
if fluid.core.is_compiled_with_cuda():
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
......@@ -103,41 +91,36 @@ class TestNoSync(TestParallelDyGraphRunnerBase):
model = paddle.DataParallel(
model, find_unused_parameters=args.find_unused_parameters)
print_to_err(type(self).__name__, "model built in dygraph")
return self.model_train(args, model, opt, train_reader)
def run_trainer_in_eager_mode(self, args):
if fluid.core.is_compiled_with_cuda():
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
place = fluid.CUDAPlace(device_id)
else:
assert ("Only support CUDAPlace for now.")
with fluid.dygraph.guard(place):
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
np.random.seed(seed)
random.seed(seed)
with _test_eager_guard():
model, train_reader, opt = self.get_model()
if args.update_method == "nccl2":
dist.init_parallel_env()
print_to_err(
type(self).__name__,
"begin to prepare context in dygraph with nccl2")
nranks = ParallelEnv().nranks
rank = ParallelEnv().local_rank
is_master = True if rank == 0 else False
store = paddle.fluid.core.TCPStore(
"127.0.0.1", args.dist_port, is_master, nranks)
group = core.ProcessGroupNCCL(store, rank, nranks)
model = paddle.DataParallel(
model,
process_group=group,
find_unused_parameters=args.find_unused_parameters)
print_to_err(type(self).__name__, "model built in dygraph")
return self.model_train(args, model, opt, train_reader)
out_losses = self.model_train(args, model, opt, train_reader)
print_to_out(out_losses)
return out_losses
def run_trainer_with_spawn_func(self, args):
# 1. enable dygraph
paddle.disable_static()
# 2. init seed
seed = 90
paddle.static.default_startup_program().random_seed = seed
paddle.static.default_main_program().random_seed = seed
np.random.seed(seed)
random.seed(seed)
# get trainer id
args.trainer_id = paddle.distributed.get_rank()
# 3. init parallel env
if args.update_method in ["nccl2", "gloo"]:
paddle.distributed.init_parallel_env()
# 4. train model
model, train_reader, opt = self.get_model()
if args.update_method in ["nccl2", "gloo"]:
model = paddle.DataParallel(
model, find_unused_parameters=args.find_unused_parameters)
out_losses = self.model_train(args, model, opt, train_reader)
print_to_out(out_losses)
return out_losses
def model_train(self, args, model, opt, train_reader):
out_losses = []
......@@ -157,12 +140,8 @@ class TestNoSync(TestParallelDyGraphRunnerBase):
loss = self.run_one_loop(model, opt, data)
loss.backward()
opt.minimize(loss)
print_to_err(
type(self).__name__,
"loss at step %d: %f" % (step_id, loss.numpy()))
out_losses.append(loss.numpy())
model.clear_gradients()
print_to_out(out_losses)
return out_losses
......
......@@ -21,7 +21,7 @@ import paddle
# used by model.run_trainer in test_dist_base
from test_dist_base import RUN_STEP
from test_parallel_dygraph_dataparallel import get_dist_port_from_flags
from paddle.fluid.framework import _test_eager_guard
# NOTE: compatible TestParallelDyGraphRunnerBase args
......@@ -29,8 +29,6 @@ class SpawnAssistTestArgs(object):
update_method = "local"
trainer_id = 0
find_unused_parameters = False
eager_mode = False
dist_port = get_dist_port_from_flags()
class TestDistSpawnRunner(unittest.TestCase):
......@@ -55,14 +53,17 @@ class TestDistSpawnRunner(unittest.TestCase):
result_list.append(res_queue.get())
return result_list
def _args_config(self, args):
return
def check_dist_result_with_spawn(self, test_class, delta=1e-3):
with _test_eager_guard():
self.check_dist_result_with_spawn_func(
test_class=test_class, delta=delta)
self.check_dist_result_with_spawn_func(
test_class=test_class, delta=delta)
def check_dist_result_with_spawn_func(self, test_class, delta=1e-3):
# 0. prepare model and args
model = test_class()
args = SpawnAssistTestArgs()
self._args_config(args)
# 1. calc signal card loss
losses = self._run(model, args)
......
......@@ -36,7 +36,6 @@ import paddle.fluid.dygraph as dygraph
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.parallel import DataParallel, ParallelEnv
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
......@@ -543,12 +542,6 @@ class TestParallelDyGraphRunnerBase(object):
return batch
def run_trainer(self, args):
if args.eager_mode:
self.run_trainer_in_eager_mode(args)
else:
self.run_trainer_func(args)
def run_trainer_func(self, args):
seed = 90
if args.update_method == 'gloo':
place = fluid.CPUPlace()
......@@ -580,6 +573,7 @@ class TestParallelDyGraphRunnerBase(object):
strategy.local_rank = args.trainer_id
strategy.trainer_endpoints = args.endpoints.split(",")
strategy.current_endpoint = args.current_endpoint
paddle.distributed.init_parallel_env()
print_to_err(
type(self).__name__,
"begin to prepare context in dygraph with nccl2")
......@@ -621,82 +615,7 @@ class TestParallelDyGraphRunnerBase(object):
model.clear_gradients()
print_to_out(out_losses)
def run_trainer_in_eager_mode(self, args):
seed = 90
if args.update_method == 'gloo':
place = fluid.CPUPlace()
elif fluid.core.is_compiled_with_cuda():
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
place = fluid.CUDAPlace(device_id)
elif fluid.core.is_compiled_with_xpu():
device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
place = fluid.XPUPlace(device_id)
elif fluid.core.is_compiled_with_npu():
device_id = int(os.getenv("FLAGS_selected_npus", "0"))
place = fluid.NPUPlace(device_id)
else:
assert ("Only support CUDAPlace or XPUPlace or CPU(Gloo) for now.")
with _test_eager_guard():
with fluid.dygraph.guard(place):
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
np.random.seed(seed)
import random
random.seed(seed)
model, train_reader, opt = self.get_model()
#if args.update_method == "nccl2":
if args.update_method in ["nccl2", "gloo"]:
paddle.distributed.init_parallel_env()
nranks = ParallelEnv().nranks
rank = ParallelEnv().local_rank
is_master = True if rank == 0 else False
store = paddle.fluid.core.TCPStore(
"127.0.0.1", args.dist_port, is_master, nranks)
if args.update_method == "nccl2":
group = core.ProcessGroupNCCL(store, rank, nranks)
elif args.update_method == "gloo":
group = core.ProcessGroupGloo(store, rank, nranks)
print_to_err(
type(self).__name__,
"begin to prepare context in dygraph with nccl2")
model = dygraph.parallel.DataParallel(
model,
process_group=group,
find_unused_parameters=args.find_unused_parameters)
print_to_err(type(self).__name__, "model built in dygraph")
out_losses = []
print_to_err(
type(self).__name__, "begin to run dygraph training")
for step_id, data in enumerate(train_reader()):
data = self._get_data(data, args)
if step_id == RUN_STEP:
break
loss = self.run_one_loop(model, opt, data)
if step_id % 10 == 0:
print_to_err(
type(self).__name__,
"loss at step %d: %f" % (step_id, loss.numpy()))
out_losses.append(loss.numpy())
loss.backward()
opt.minimize(loss)
if not args.accumulate_gradient:
model.clear_gradients()
print_to_out(out_losses)
def run_trainer_with_spawn(self, args):
if args.eager_mode:
return self.run_trainer_with_spawn_in_eager_mode(args)
else:
return self.run_trainer_with_spawn_func(args)
def run_trainer_with_spawn_func(self, args):
# 1. enable dygraph
paddle.disable_static()
......@@ -733,64 +652,7 @@ class TestParallelDyGraphRunnerBase(object):
model.clear_gradients()
return out_losses
def run_trainer_with_spawn_in_eager_mode(self, args):
# 1. enable dygraph
paddle.disable_static()
# 2. init seed
seed = 90
paddle.static.default_startup_program().random_seed = seed
paddle.static.default_main_program().random_seed = seed
np.random.seed(seed)
random.seed(seed)
# get trainer id
args.trainer_id = paddle.distributed.get_rank()
# 3. init parallel env
if args.update_method in ["nccl2", "gloo"]:
paddle.distributed.init_parallel_env()
# 4. build process group
nranks = ParallelEnv().nranks
rank = ParallelEnv().local_rank
is_master = True if rank == 0 else False
store = paddle.fluid.core.TCPStore("127.0.0.1", args.dist_port,
is_master, nranks)
if args.update_method == "nccl2":
group = core.ProcessGroupNCCL(store, rank, nranks)
elif args.update_method == "gloo":
group = core.ProcessGroupGloo(store, rank, nranks)
# 5. train model
with _test_eager_guard():
model, train_reader, opt = self.get_model()
if args.update_method in ["nccl2", "gloo"]:
model = paddle.DataParallel(
model,
process_group=group,
find_unused_parameters=args.find_unused_parameters)
out_losses = []
for step_id, data in enumerate(train_reader()):
data = self._get_data(data, args)
if step_id == RUN_STEP:
break
loss = self.run_one_loop(model, opt, data)
out_losses.append(loss.numpy())
loss.backward()
opt.minimize(loss)
model.clear_gradients()
return out_losses
def run_use_fleet_api_trainer(self, args):
if args.eager_mode:
self.run_use_fleet_api_trainer_in_eager_mode(args)
else:
self.run_use_fleet_api_trainer_func(args)
def run_use_fleet_api_trainer_func(self, args):
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
# 1. enable dygraph
......@@ -835,52 +697,6 @@ class TestParallelDyGraphRunnerBase(object):
opt.clear_grad()
print_to_out(out_losses)
def run_use_fleet_api_trainer_in_eager_mode(self, args):
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
# 1. enable dygraph
paddle.disable_static()
# 2. init seed
seed = 90
paddle.static.default_startup_program().random_seed = seed
paddle.static.default_main_program().random_seed = seed
np.random.seed(seed)
random.seed(seed)
# get trainer id
args.trainer_id = paddle.distributed.get_rank()
# set strategy
strategy = fleet.DistributedStrategy()
if args.find_unused_parameters:
strategy.find_unused_parameters = True
# 3. init parallel env
if args.update_method == "nccl2" or "bkcl" or "hccl":
fleet.init(is_collective=True, strategy=strategy)
# 4. train model
with _test_eager_guard():
model, train_reader, opt = self.get_model()
if args.update_method == "nccl2" or "bkcl" or "hccl":
opt = fleet.distributed_optimizer(opt)
model = fleet.distributed_model(model)
out_losses = []
for step_id, data in enumerate(train_reader()):
data = self._get_data(data, args)
if step_id == RUN_STEP:
break
loss = self.run_one_loop(model, opt, data)
out_losses.append(loss.numpy())
loss.backward()
opt.step()
if not args.accumulate_gradient:
opt.clear_grad()
print_to_out(out_losses)
def runtime_main(test_class):
parser = argparse.ArgumentParser(description='Run dist test.')
......@@ -911,8 +727,6 @@ def runtime_main(test_class):
parser.add_argument(
'--current_endpoint', type=str, required=False, default="")
parser.add_argument('--sync_mode', action='store_true')
parser.add_argument('--eager_mode', action='store_true')
parser.add_argument('--dist_port', type=int, required=False, default=6175)
parser.add_argument('--use_cuda', action='store_true')
parser.add_argument('--use_cpu', action='store_true')
parser.add_argument('--use_xpu', action='store_true')
......@@ -1005,8 +819,6 @@ class TestDistBase(unittest.TestCase):
self._port_set = set()
self._python_interp = sys.executable
self._sync_mode = True
self._dist_port = 6175
self._eager_mode = False
self._hogwild_mode = False
self._enforce_place = None
self._use_reduce = False
......@@ -1168,10 +980,6 @@ class TestDistBase(unittest.TestCase):
if len(devices) > 1 and self._use_dgc:
cmd += " --use_dgc"
if self._eager_mode:
cmd += " --eager_mode"
cmd += " --dist_port {}".format(self._dist_port)
if self._accumulate_gradient:
cmd += " --accumulate_gradient"
......@@ -1245,11 +1053,6 @@ class TestDistBase(unittest.TestCase):
if self._sync_mode:
tr0_cmd += " --sync_mode"
tr1_cmd += " --sync_mode"
if self._eager_mode:
tr0_cmd += " --eager_mode"
tr1_cmd += " --eager_mode"
tr0_cmd += " --dist_port {}".format(self._dist_port)
tr1_cmd += " --dist_port {}".format(self._dist_port)
if self._hogwild_mode:
tr0_cmd += " --hogwild"
tr1_cmd += " --hogwild"
......@@ -1356,10 +1159,6 @@ class TestDistBase(unittest.TestCase):
assert self._use_dgc == False, "gloo not support use dgc"
if self._eager_mode:
tr_cmd += " --eager_mode"
tr_cmd += " --dist_port {}".format(self._dist_port)
if self._accumulate_gradient:
tr_cmd += " --accumulate_gradient"
......@@ -1437,10 +1236,6 @@ class TestDistBase(unittest.TestCase):
if self._use_dgc:
tr_cmd += " --use_dgc"
if self._eager_mode:
tr_cmd += " --eager_mode"
tr_cmd += " --dist_port {}".format(self._dist_port)
if self._accumulate_gradient:
tr_cmd += " --accumulate_gradient"
......@@ -1665,7 +1460,34 @@ class TestDistBase(unittest.TestCase):
check_error_log=False,
need_envs={},
log_name=""):
if self._dygraph and (self._gloo_mode or self._nccl2_mode):
with _test_eager_guard():
self.check_with_place_func(
model_file=model_file,
delta=delta,
check_error_log=check_error_log,
need_envs=need_envs,
log_name=log_name)
self.check_with_place_func(
model_file=model_file,
delta=delta,
check_error_log=check_error_log,
need_envs=need_envs,
log_name=log_name)
else:
self.check_with_place_func(
model_file=model_file,
delta=delta,
check_error_log=check_error_log,
need_envs=need_envs,
log_name=log_name)
def check_with_place_func(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={},
log_name=""):
required_envs = self._get_required_envs(check_error_log, need_envs)
if self._gloo_mode:
......
......@@ -26,7 +26,7 @@ import paddle.fluid.dygraph as dygraph
from paddle.fluid.dygraph.nn import Linear
import paddle.fluid.core as core
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph, in_dygraph_mode
class TestDataParallelGroup(unittest.TestCase):
......@@ -34,7 +34,10 @@ class TestDataParallelGroup(unittest.TestCase):
return paddle.rand(shape=shape, dtype=dtype)
def assign_group_by_size(self, *args):
return core.assign_group_by_size(*args)
if in_dygraph_mode():
return core.eager_assign_group_by_size(*args)
elif _in_legacy_dygraph():
return core.assign_group_by_size(*args)
def test_construct_group0(self):
# one dtype & one limit capability
......@@ -160,14 +163,19 @@ class TestDataParallelGroup(unittest.TestCase):
[300], [1, 0, 2, 3])
self.assertEqual([[1, 0], [3], [2]], res)
class TestDataParallelGroupEager(TestDataParallelGroup):
def create_varbase(self, dtype, shape):
def test_construct_group_in_legacy_mode(self):
with _test_eager_guard():
return paddle.rand(shape=shape, dtype=dtype)
def assign_group_by_size(self, *args):
return core.eager_assign_group_by_size(*args)
pass
self.test_construct_group0()
self.test_construct_group1()
self.test_construct_group2()
self.test_construct_group3()
self.test_construct_group4()
self.test_construct_group5()
self.test_construct_group6()
self.test_construct_group7()
self.test_construct_group8()
self.test_construct_group9()
if __name__ == '__main__':
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import sys
import unittest
import paddle.fluid as fluid
from test_dist_base import TestDistBase
from spawn_runner_base import TestDistSpawnRunner
flag_name = os.path.splitext(__file__)[0]
class TestDygraphControlFlowSameEager(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
self._eager_mode = True
self._dygraph = True
self._find_unused_parameters = True
def test_net(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_control_flow_same.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
class TestDygraphControlFlowSameAccGradEager(TestDygraphControlFlowSameEager):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
self._eager_mode = True
self._dygraph = True
self._accumulate_gradient = True
self._find_unused_parameters = True
class TestDygraphControlFlowDiffEager(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
self._eager_mode = True
self._dygraph = True
self._find_unused_parameters = True
def test_net(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_control_flow_different.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
class TestFleetDygraphControlFlowDiffAccGradEager(
TestDygraphControlFlowDiffEager):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
self._eager_mode = True
self._dygraph = True
self._accumulate_gradient = True
self._find_unused_parameters = True
if __name__ == "__main__":
unittest.main()
......@@ -208,11 +208,6 @@ class TestDataParallelWithPyLayer(TestMultipleGpus):
self.run_mnist_2gpu('parallel_dygraph_dataparallel_with_pylayer.py')
class TestDataParallelInEagerMode(TestMultipleGpus):
def test_multiple_gpus_dynamic(self):
self.run_mnist_2gpu('parallel_dygraph_dataparallel_in_eager_mode.py')
class TestGradientCheckInEagerMode(TestMultipleGpus):
def test_multiple_gpus_dynamic(self):
self.run_mnist_2gpu('parallel_dygraph_gradient_check_in_eager_mode.py')
......
......@@ -136,7 +136,7 @@ class TestDataParallelGradientCheck(TestMultipleGpus):
class TestDataParallelGradientCheckInEagerMode(TestMultipleGpus):
def test_multiple_gpus_dynamic(self):
self.run_mnist_2gpu('parallel_dygraph_dataparallel_in_eager_mode.py')
self.run_mnist_2gpu('parallel_dygraph_gradient_check_in_eager_mode.py')
if __name__ == "__main__":
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import sys
import unittest
import paddle.fluid as fluid
from test_dist_base import TestDistBase
from spawn_runner_base import TestDistSpawnRunner
from parallel_dygraph_no_sync import TestNoSync
from parallel_dygraph_no_sync_unused_params import TestNoSyncUnusedParam
from parallel_dygraph_no_sync_control_flow import TestNoSyncControlFlow
flag_name = os.path.splitext(__file__)[0]
class TestParallelDygraphNoSync(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._eager_mode = True
self._nccl2_mode = True
self._dygraph = True
self._find_unused_parameters = False
def test_no_sync(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_no_sync.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
class TestParallelDygraphNoSyncUnusedParam(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._eager_mode = True
self._nccl2_mode = True
self._dygraph = True
self._find_unused_parameters = True
def test_no_sync_ununsed_param(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_no_sync_unused_params.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
class TestParallelDygraphNoSyncControlFlow(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._eager_mode = True
self._nccl2_mode = True
self._dygraph = True
self._find_unused_parameters = True
def test_no_sync_control_flow(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_no_sync_control_flow.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
class TestParallelDygraphNoSyncSpawn(TestDistSpawnRunner):
def test_no_sync_with_spawn(self):
if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
self.check_dist_result_with_spawn(test_class=TestNoSync, delta=1e-5)
class TestParallelDygraphNoSyncUnusedParamSpawn(TestDistSpawnRunner):
def _args_config(self, args):
args.find_unused_parameters = True
args.eager_mode = True
def test_no_sync_with_spawn(self):
if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
self.check_dist_result_with_spawn(
test_class=TestNoSyncUnusedParam, delta=1e-5)
class TestParallelDygraphNoSyncControlFlowSpawn(TestDistSpawnRunner):
def _args_config(self, args):
args.find_unused_parameters = True
args.eager_mode = True
def test_no_sync_with_spawn(self):
if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
self.check_dist_result_with_spawn(
test_class=TestNoSyncControlFlow, delta=1e-5)
if __name__ == "__main__":
unittest.main()
......@@ -64,47 +64,5 @@ class TestParallelDygraphSparseEmdeddingSpawn(TestDistSpawnRunner):
test_class=TestSparseEmbedding, delta=1e-5)
class TestParallelDygraphSparseEmdeddingEager(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
self._eager_mode = True
self._dygraph = True
def test_sparse_embedding(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_sparse_embedding.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
class TestParallelDygraphSparseEmdeddingFP64Eager(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._eager_mode = True
self._nccl2_mode = True
self._dygraph = True
def test_sparse_embedding_fp64(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_sparse_embedding_fp64.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
class TestParallelDygraphSparseEmdeddingSpawnEager(TestDistSpawnRunner):
def _args_config(self, args):
args.eager_mode = True
def test_sparse_embedding_with_spawn(self):
if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
self.check_dist_result_with_spawn(
test_class=TestSparseEmbedding, delta=1e-5)
if __name__ == "__main__":
unittest.main()
......@@ -48,32 +48,5 @@ class TestParallelDygraphSparseEmdeddingOverHeightSpawn(TestDistSpawnRunner):
test_class=TestSparseEmbeddingOverHeight, delta=1e-5)
class TestParallelDygraphSparseEmdeddingOverHeightEager(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._eager_mode = True
self._nccl2_mode = True
self._dygraph = True
def test_sparse_embedding(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_sparse_embedding_over_height.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
class TestParallelDygraphSparseEmdeddingOverHeightSpawnEager(
TestDistSpawnRunner):
def _args_config(self, args):
args.eager_mode = True
def test_sparse_embedding_with_spawn(self):
if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
self.check_dist_result_with_spawn(
test_class=TestSparseEmbeddingOverHeight, delta=1e-5)
if __name__ == "__main__":
unittest.main()
......@@ -36,21 +36,5 @@ class TestParallelDygraphMnist(TestDistBase):
log_name=flag_name)
class TestParallelDygraphMnistEager(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._eager_mode = True
self._nccl2_mode = True
self._dygraph = True
def test_mnist(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_sync_batch_norm.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
if __name__ == "__main__":
unittest.main()
......@@ -41,13 +41,6 @@ class TestParallelDygraphTransformer(TestDistBase):
log_name=flag_name)
class TestParallelDygraphTransformerSpawn(TestDistSpawnRunner):
def test_transformer_with_spawn(self):
if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
self.check_dist_result_with_spawn(
test_class=TestTransformer, delta=1e-5)
class TestParallelDygraphTransformerAccGrad(TestDistBase):
def _setup_config(self):
self._sync_mode = False
......@@ -65,21 +58,5 @@ class TestParallelDygraphTransformerAccGrad(TestDistBase):
log_name=flag_name)
class TestParallelDygraphTransformerEager(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._eager_mode = True
self._nccl2_mode = True
self._dygraph = True
def test_transformer(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_transformer.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
if __name__ == "__main__":
unittest.main()
......@@ -86,71 +86,5 @@ class TestParallelDygraphSharedUnusedVariables(TestDistBase):
log_name=flag_name)
class TestParallelDygraphUnusedVarEager(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._eager_mode = True
self._nccl2_mode = True
self._dygraph = True
def test_net(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_unused_variables.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
class TestDygraphUnusedVarEager(TestParallelDygraphUnusedVar):
def _setup_config(self):
self._sync_mode = False
self._eager_mode = True
self._nccl2_mode = True
self._dygraph = True
class TestSparseEmbeddingUnusedVarsSpawnEager(TestDistSpawnRunner):
def _args_config(self, args):
args.eager_mode = True
def test_mnist_with_spawn(self):
if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
self.check_dist_result_with_spawn(
test_class=TestSparseEmbeddingUnusedVars, delta=1e-5)
class TestParallelDygraphNoVarEager(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._eager_mode = True
self._nccl2_mode = True
self._dygraph = True
def test_net(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_none_var.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
class TestParallelDygraphSharedUnusedVariablesEager(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._eager_mode = True
self._nccl2_mode = True
self._dygraph = True
def test_mnist(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_shared_unused_var.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册