Unverified · Commit 85bb1a85 · Authored by: Guoxia Wang · Committed by: GitHub

support auto parallel data shard (#36055)

Parent: 817f9ef0
...@@ -20,6 +20,7 @@ from .context import get_default_distributed_context
from .completion import complete_annotation, complete_backward_annotation
from .partitioner import Partitioner
from .process import get_all_process_groups
from .utils import make_data_unshard
from .reshard import reshard
...@@ -95,6 +96,8 @@ class AutoParallelizer:
        self._remove_distributed_attrs(partitioned_main_prog)
        complete_backward_annotation(partitioned_main_prog, self._dist_context)

        make_data_unshard(partitioned_main_prog, partitioned_startup_prog)

        reshard(partitioned_main_prog, partitioned_startup_prog, rank,
                self._dist_context)
......
...@@ -277,3 +277,40 @@ def _linear_idx2coordinate(mesh_shape, linear_idx):
    # row major order
    return coordinate


def _get_unshard_dist_shape(var, dist_attr):
var_shape = var.shape
mapping = dist_attr.get_dims_mapping()
mesh = dist_attr.get_process_mesh().topology
    assert len(var_shape) == len(
        mapping
    ), "variable shape {} and dims_mapping {} do not match!".format(
        var_shape, mapping)
new_shape = []
for idx in range(len(var_shape)):
if var_shape[idx] == -1 or mapping[idx] == -1:
new_shape.append(var_shape[idx])
else:
new_shape.append(var_shape[idx] * mesh[mapping[idx]])
    return new_shape


def make_data_unshard(dist_main_prog, dist_startup_prog):
from .context import get_default_distributed_context
dist_context = get_default_distributed_context()
for var in dist_main_prog.list_vars():
if var.is_data:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
var)
inverse_shape = _get_unshard_dist_shape(var, tensor_dist_attr)
var.desc.set_shape(inverse_shape)
dim_mapping = tensor_dist_attr.get_dims_mapping()
dim_mapping = [-1] * len(dim_mapping)
tensor_dist_attr.set_dims_mapping(dim_mapping)
dist_context.set_tensor_distributed_attr_for_program(
var, tensor_dist_attr)
var._set_attr('dim_mapping' + core.kAutoParallelSuffix(),
dim_mapping)
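A minimal sketch of the shape arithmetic above (plain Python, independent of Paddle; the unshard_shape helper and its inputs are hypothetical and mirror the 2-GPU test case further down): a data variable whose partitioned, per-rank shape is [1, 8] with dims_mapping [0, -1] on a mesh of topology [2] is restored to the unsharded shape [2, 8].

# Standalone illustration of _get_unshard_dist_shape's arithmetic; the helper
# name and example values are assumptions for illustration only.
def unshard_shape(local_shape, dims_mapping, mesh_topology):
    new_shape = []
    for dim, mapping in zip(local_shape, dims_mapping):
        if dim == -1 or mapping == -1:
            # dynamic or unsharded dimension: keep it unchanged
            new_shape.append(dim)
        else:
            # sharded dimension: multiply by the mesh size along that axis
            new_shape.append(dim * mesh_topology[mapping])
    return new_shape

print(unshard_shape([1, 8], [0, -1], [2]))   # [2, 8]
print(unshard_shape([-1, 8], [0, -1], [2]))  # [-1, 8]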
...@@ -1423,6 +1423,7 @@ class Fleet(object):
            auto_parallelizer = AutoParallelizer(self)
            optimize_ops, params_grads, dist_startup_prog, dist_main_prog = auto_parallelizer.parallelize(
                loss, startup_program, parameter_list, no_grad_set)

            return optimize_ops, params_grads, dist_startup_prog, dist_main_prog

        # compile time
......
...@@ -36,6 +36,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers)
list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper)
list(APPEND DIST_TEST_OPS test_parallel_class_center_sample)
list(APPEND DIST_TEST_OPS test_parallel_margin_cross_entropy)
list(APPEND DIST_TEST_OPS test_auto_parallel_data_unshard)
set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS})
#remove distribute unittests.
list(APPEND MIXED_DIST_TEST_OPS test_dgc_op)
...@@ -233,6 +234,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
    LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_serial)
    LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_mppp)
    LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_dpmppp)
    LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_data_unshard)
elseif(WITH_GPU)
    if (${CUDNN_VERSION} VERSION_LESS 7100)
        LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
...@@ -1001,6 +1003,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
    set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120)
    set_tests_properties(test_parallel_class_center_sample PROPERTIES TIMEOUT 120)
    set_tests_properties(test_parallel_margin_cross_entropy PROPERTIES TIMEOUT 120)
    set_tests_properties(test_auto_parallel_data_unshard PROPERTIES TIMEOUT 120)
    if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212)
        set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120)
        set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 120)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import copy
import numpy as np
import random
import paddle
import paddle.nn as nn
import paddle.fluid.core as core
import paddle.distributed.auto_parallel as auto
import paddle.nn.functional as F
from paddle.distributed import fleet
paddle.enable_static()
paddle.distributed.init_parallel_env()

class TestDataUnshard(unittest.TestCase):
    def test_dp2pp1mp1(self):
        def create_model(train_program, start_program):
            with paddle.static.program_guard(train_program, start_program):
                ROOT_MESH = auto.ProcessMesh([0, 1])
                MESH_0 = auto.ProcessMesh([0, 1], ROOT_MESH)
                input = paddle.static.data(name='input', shape=[2, 8])
                label = paddle.static.data(name='label', shape=[2, 8])

                weight_attr = paddle.ParamAttr(
                    initializer=nn.initializer.Normal(
                        mean=0.0, std=0.02))
                linear0 = nn.Linear(8, 8, weight_attr)
                linear1 = nn.Linear(8, 8, weight_attr)

                auto.shard_tensor(input, MESH_0, dim_mapping=[0, -1])
                auto.shard_tensor(label, MESH_0, dim_mapping=[0, -1])
                auto.shard_tensor(linear0.weight, MESH_0, dim_mapping=[-1, -1])
                auto.shard_tensor(linear1.weight, MESH_0, dim_mapping=[-1, -1])

                linear0_out = linear0(input)
                gelu_out = F.gelu(linear0_out)
                linear1_out = linear1(gelu_out)
                error_cost = paddle.nn.functional.square_error_cost(linear1_out,
                                                                    label)
                loss = paddle.mean(error_cost)
            return train_program, start_program, loss, input, label

        train_program = paddle.static.Program()
        start_program = paddle.static.Program()
        # serial program
        train_program, start_program, loss, input, label = create_model(
            train_program, start_program)

        dist_strategy = fleet.DistributedStrategy()
        dist_strategy.semi_auto = True
        fleet.init(is_collective=True, strategy=dist_strategy)

        optimizer = paddle.fluid.optimizer.AdamOptimizer(
            learning_rate=0.00001,
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-08,
            grad_clip=None)
        optimizer = fleet.distributed_optimizer(optimizer)
        _, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
            loss, start_program)

        worker_index = paddle.distributed.get_rank()
        paddle.seed(worker_index + 2021)
        random.seed(worker_index + 2021)
        np.random.seed(worker_index + 2021)

        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(distributed_startup_program)

        input_data = np.array(range(2 * 8)).reshape([2, 8]).astype("float32")
        label_data = np.random.randint(0, 10, [2, 8]).astype("float32")

        fetchs = [loss.name, 'input@RESHARD_0']
        loss_np, shard_data_np = exe.run(
            distributed_main_program,
            feed={"input": input_data,
                  "label": label_data},
            fetch_list=fetchs)
        desired = input_data[worker_index].reshape(shard_data_np.shape)
        np.testing.assert_allclose(shard_data_np, desired)
    def dp1pp1mp2(self):
        def create_model(train_program, start_program):
            with paddle.static.program_guard(train_program, start_program):
                ROOT_MESH = auto.ProcessMesh([0, 1])
                MESH_0 = auto.ProcessMesh([0, 1], ROOT_MESH)
                input = paddle.static.data(name='input', shape=[8, 8])
                label = paddle.static.data(name='label', shape=[8, 8])

                weight_attr = paddle.ParamAttr(
                    initializer=nn.initializer.Normal(
                        mean=0.0, std=0.02))
                linear0 = nn.Linear(8, 8, weight_attr)
                linear1 = nn.Linear(8, 8, weight_attr)

                auto.shard_tensor(input, MESH_0, dim_mapping=[-1, -1])
                auto.shard_tensor(label, MESH_0, dim_mapping=[-1, -1])
                auto.shard_tensor(linear0.weight, MESH_0, dim_mapping=[-1, 0])
                auto.shard_tensor(linear1.weight, MESH_0, dim_mapping=[0, -1])

                linear0_out = linear0(input)
                gelu_out = F.gelu(linear0_out)
                linear1_out = linear1(gelu_out)
                error_cost = paddle.nn.functional.square_error_cost(linear1_out,
                                                                    label)
                loss = paddle.mean(error_cost)
            return train_program, start_program, loss, input, label

        train_program = paddle.static.Program()
        start_program = paddle.static.Program()
        # serial program
        train_program, start_program, loss, input, label = create_model(
            train_program, start_program)

        dist_strategy = fleet.DistributedStrategy()
        dist_strategy.semi_auto = True
        fleet.init(is_collective=True, strategy=dist_strategy)

        optimizer = paddle.fluid.optimizer.AdamOptimizer(
            learning_rate=0.00001,
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-08,
            grad_clip=None)
        optimizer = fleet.distributed_optimizer(optimizer)
        _, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
            loss, start_program)

        worker_index = paddle.distributed.get_rank()
        paddle.seed(worker_index + 2021)
        random.seed(worker_index + 2021)
        np.random.seed(worker_index + 2021)

        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(distributed_startup_program)

        input_data = np.array(range(8 * 8)).reshape([8, 8]).astype("float32")
        label_data = np.random.randint(0, 10, [8, 8]).astype("float32")

        fetchs = [loss.name, 'input']
        loss_np, shard_data_np = exe.run(
            distributed_main_program,
            feed={"input": input_data,
                  "label": label_data},
            fetch_list=fetchs)
        desired = input_data.reshape(shard_data_np.shape)
        np.testing.assert_allclose(shard_data_np, desired)


if __name__ == "__main__":
    unittest.main()
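For context on what test_dp2pp1mp1 asserts, here is a rough numpy-only illustration (the shard_rows helper is hypothetical and only stands in for the slice that reshard inserts for a dims_mapping of [0, -1] over a 2-process mesh): every rank feeds the same full [2, 8] batch, and the resharded tensor 'input@RESHARD_0' on rank i should equal row i of that batch.

import numpy as np

# Hypothetical stand-in for the per-rank slice of an unsharded data batch:
# rank i keeps its contiguous block of rows along dimension 0.
def shard_rows(full_batch, rank, num_ranks):
    rows_per_rank = full_batch.shape[0] // num_ranks
    return full_batch[rank * rows_per_rank:(rank + 1) * rows_per_rank]

full_batch = np.arange(2 * 8, dtype="float32").reshape([2, 8])
for rank in range(2):  # the two GPUs in the process mesh
    shard = shard_rows(full_batch, rank, 2)
    np.testing.assert_allclose(shard, full_batch[rank].reshape(shard.shape))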
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
from test_parallel_dygraph_dataparallel import TestMultipleGpus

class TestAutoParallelDataUnshard(TestMultipleGpus):
    def test_auto_parallel_data_unshard(self):
        self.run_mnist_2gpu('auto_parallel_data_unshard.py')


if __name__ == "__main__":
    unittest.main()