Unverified commit 91a0acdb authored by WangXi, committed by GitHub

static support mp_layers (#33700)

Parent 58e465aa
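In short, the commit lets the model-parallel layers under fleet.meta_parallel (VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear, ParallelCrossEntropy) build static-graph programs in addition to dygraph ones, with fleet.init setting up a _CommunicateGroup when a sharding strategy with mp_degree > 1 is configured. Below is a minimal usage sketch mirroring the new unit test added later in this commit; it assumes the trainer environment (PADDLE_TRAINER_ID, PADDLE_TRAINER_ENDPOINTS) is already set, and the sizes and names are illustrative only:

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()

# illustrative 2-way model-parallel setup, mirroring the commit's new test
strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {"mp_degree": 2, "sharding_degree": 2}
fleet.init(is_collective=True, strategy=strategy)

main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # each rank holds an [in_features, out_features // mp_degree] weight slice
    linear = fleet.meta_parallel.ColumnParallelLinear(
        in_features=28,
        out_features=64,
        weight_attr=None,
        has_bias=True,
        gather_output=True,
        name="demo_column_linear")
    x = paddle.static.data(name='x', shape=[None, 28])
    # the program now contains c_identity, matmul, elementwise_add, c_concat ops
    y = linear(x)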
@@ -92,8 +92,6 @@ class Group():
        return True

    def get_group_rank(self, rank):
        if self.id == 0:
            return rank
        if self.is_member() and rank in self.ranks:
            return self.ranks.index(rank)
        else:
@@ -126,7 +124,8 @@ def _get_group_map():
    global _group_map
    if not _group_map:
        genv = _get_global_env()
        _group_map[0] = Group(genv.rank, genv.world_size, 0)
        _group_map[0] = Group(genv.rank, genv.world_size,
                              list(range(genv.world_size)))
    return _group_map
@@ -1014,6 +1013,27 @@ def _c_softmax_with_cross_entropy(logits,
        else:
            return loss, softmax

    attrs = {
        'ring_id': ring_id,
        'rank': rank,
        'nranks': nranks,
    }

    helper = LayerHelper('c_softmax_with_cross_entropy', **locals())
    softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
    loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
    helper.append_op(
        type='c_softmax_with_cross_entropy',
        inputs={'Logits': logits,
                'Label': label},
        outputs={'Softmax': softmax,
                 'Loss': loss},
        attrs=attrs)

    if return_softmax:
        return loss, softmax

    return loss


def _linear(x, weight, bias=None, name=None):
    """
......
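The static branch above appends a c_softmax_with_cross_entropy op whose rank and nranks attributes describe which slice of the class dimension this rank's logits hold. A standalone sketch of that bookkeeping; shard_range is a hypothetical helper for illustration, not a Paddle API:

def shard_range(class_size, nranks, rank):
    # hypothetical helper: half-open interval of class ids owned by `rank`
    assert class_size % nranks == 0
    per_card = class_size // nranks
    return rank * per_card, (rank + 1) * per_card

# with class_size=1000 and nranks=2, as in the new test below:
print(shard_range(1000, 2, 0))  # (0, 500)
print(shard_range(1000, 2, 1))  # (500, 1000)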
@@ -253,6 +253,40 @@ class Fleet(object):
                warnings.warn(
                    "The dygraph hybrid parallel environment has been initialized."
                )
        elif self._is_collective:
            use_sharding = self._user_defined_strategy.sharding

            # global group
            global_rank = self.worker_index()
            global_world_size = self.worker_num()
            # NOTE(wangxi): see sharding_optimizer
            global_ring_id = 3 if use_sharding else 0
            global_ranks = list(range(global_world_size))

            if tp._HYBRID_PARALLEL_GROUP is None: tp._CommunicateGroup()
            cg = tp._HYBRID_PARALLEL_GROUP
            self._hcg = cg
            cg.set_comm_group('global', global_rank, global_world_size,
                              global_ring_id, global_ranks)

            # hybrid group
            if use_sharding is False: return

            sharding_configs = self._user_defined_strategy.sharding_configs
            mp_degree = int(sharding_configs['mp_degree'])

            if mp_degree > 1:
                assert global_world_size % mp_degree == 0
                # NOTE(wangxi): mp_ring_id sync with sharding_optimizer.py _build_groups
                mp_ring_id = 0
                mp_rank = global_rank % mp_degree
                mp_group_id = global_rank // mp_degree
                mp_group_ranks = [
                    idx for idx in global_ranks
                    if idx // mp_degree == mp_group_id
                ]
                cg.set_comm_group('model', mp_rank, mp_degree, mp_ring_id,
                                  mp_group_ranks)

    def _init_hybrid_parallel_env(self):
        """initialize the hybrid environment
......
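The model-parallel group layout computed in the elif branch above can be checked with plain Python; for example, with global_world_size = 4 and mp_degree = 2, consecutive pairs of ranks form the groups (same arithmetic as the diff, values illustrative):

global_world_size = 4
mp_degree = 2
global_ranks = list(range(global_world_size))

for global_rank in global_ranks:
    mp_rank = global_rank % mp_degree
    mp_group_id = global_rank // mp_degree
    mp_group_ranks = [
        idx for idx in global_ranks if idx // mp_degree == mp_group_id
    ]
    print(global_rank, mp_rank, mp_group_ranks)
# 0 0 [0, 1]
# 1 1 [0, 1]
# 2 0 [2, 3]
# 3 1 [2, 3]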
@@ -262,3 +262,31 @@ class HybridCommunicateGroup(object):
    def get_rank_from_stage(self, stage_id, **kwargs):
        return self._topo.get_rank_from_stage(
            self.global_rank, pipe=stage_id, **kwargs)


class _CommunicateGroup(object):
    """ tmp for static """

    def __init__(self):
        global _HYBRID_PARALLEL_GROUP
        _HYBRID_PARALLEL_GROUP = self
        self.groups = dict()

    def set_comm_group(self, group_name, group_rank, group_size, ring_id,
                       group_ranks):
        group = paddle.distributed.collective.Group(group_rank, group_size,
                                                    ring_id, group_ranks)
        self.groups[group_name] = group

    def get_group(self, group_name):
        assert group_name in self.groups
        return self.groups[group_name]

    def get_model_parallel_group(self):
        return self.get_group('model')

    def get_model_parallel_world_size(self):
        return self.get_group('model').nranks

    def get_model_parallel_rank(self):
        return self.get_group('model').rank
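As its docstring says, _CommunicateGroup is a temporary static-graph stand-in for the dygraph HybridCommunicateGroup. A minimal sketch of how it is populated and queried, assuming the topology module is importable as paddle.distributed.fleet.base.topology (the tp alias used in fleet_base above); ranks and sizes are illustrative:

import paddle.distributed.fleet.base.topology as tp

cg = tp._CommunicateGroup()  # registers itself as tp._HYBRID_PARALLEL_GROUP
cg.set_comm_group('global', group_rank=0, group_size=4, ring_id=0,
                  group_ranks=[0, 1, 2, 3])
cg.set_comm_group('model', group_rank=0, group_size=2, ring_id=0,
                  group_ranks=[0, 1])

assert cg.get_model_parallel_world_size() == 2
assert cg.get_model_parallel_rank() == 0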
@@ -56,7 +56,7 @@ class VocabParallelEmbedding(Layer):
        self._weight_attr = weight_attr
        self._name = name

        if self.is_mp:
        if self.is_mp and paddle.in_dynamic_mode():
            with get_rng_state_tracker().rng_state():
                self.weight = self.create_parameter(
                    attr=self._weight_attr,
@@ -121,7 +121,7 @@ class ColumnParallelLinear(Layer):
        self._weight_attr = weight_attr
        self._dtype = self._helper.get_default_dtype()

        if self.is_mp:
        if self.is_mp and paddle.in_dynamic_mode():
            with get_rng_state_tracker().rng_state():
                self.weight = self.create_parameter(
                    shape=[in_features, self.output_size_per_partition],
@@ -198,7 +198,7 @@ class RowParallelLinear(Layer):
        self.input_size_per_partition = in_features // self.world_size

        if self.is_mp:
        if self.is_mp and paddle.in_dynamic_mode():
            with get_rng_state_tracker().rng_state():
                self.weight = self.create_parameter(
                    shape=[self.input_size_per_partition, self.out_features],
......
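The guard added to all three layers skips the dygraph RNG-state tracker while a static program is being built. A trivial way to see which branch applies:

import paddle

paddle.enable_static()
print(paddle.in_dynamic_mode())  # False: rng_state tracker branch is skipped

paddle.disable_static()
print(paddle.in_dynamic_mode())  # True: dygraph path uses the tracker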
@@ -70,6 +70,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_graph_executor)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_meta_optimizer_base)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_distributed_strategy)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_auto)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_static_mp_layers)
foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
    list(REMOVE_ITEM TEST_OPS ${TEST_OP})
endforeach()
@@ -525,6 +526,7 @@ if(WITH_DISTRIBUTE)
    py_test_modules(test_fleet_private_function MODULES test_fleet_private_function ENVS ${dist_ENVS})
    py_test_modules(test_fleet_meta_optimizer_base MODULES test_fleet_meta_optimizer_base ENVS ${dist_ENVS})
    py_test_modules(test_fleet_distributed_strategy MODULES test_fleet_distributed_strategy)
    py_test_modules(test_fleet_static_mp_layers MODULES test_fleet_static_mp_layers)
    #py_test_modules(test_fleet_auto MODULES test_fleet_auto ENVS ${dist_ENVS})
    if(NOT WIN32)
        py_test_modules(test_fleet_localsgd_meta_optimizer MODULES test_fleet_localsgd_meta_optimizer ENVS ${dist_ENVS})
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import print_function

import unittest
import paddle
import numpy as np
import random
import paddle.distributed as dist
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet
from paddle import framework
import os

paddle.enable_static()


class ColumnLinearNet(fluid.dygraph.Layer):
    def __init__(self, input_size, output_size):
        super(ColumnLinearNet, self).__init__()
        self.parallel_linear = fleet.meta_parallel.ColumnParallelLinear(
            in_features=input_size,
            out_features=output_size,
            weight_attr=None,
            has_bias=True,
            gather_output=True,
            name="test_column_linear")

    def forward(self, x):
        output = self.parallel_linear(x)
        return output


class RowLinearNet(fluid.dygraph.Layer):
    def __init__(self, input_size, output_size):
        super(RowLinearNet, self).__init__()
        self.parallel_linear = fleet.meta_parallel.RowParallelLinear(
            in_features=input_size,
            out_features=output_size,
            has_bias=True,
            input_is_parallel=False,
            name="test_row_linear")

    def forward(self, x):
        output = self.parallel_linear(x)
        return output


class EmbeddingNet(fluid.dygraph.Layer):
    def __init__(self, vocab_size, hidden_size):
        super(EmbeddingNet, self).__init__()
        self.embedding = fleet.meta_parallel.VocabParallelEmbedding(vocab_size,
                                                                    hidden_size)

    def forward(self, x):
        output = self.embedding(x)
        return output


class TestDistTraning(unittest.TestCase):
    def setUp(self):
        os.environ["PADDLE_TRAINER_ID"] = "2"
        os.environ[
            "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002,127.0.0.1:36003,127.0.0.1:36004"

        strategy = fleet.DistributedStrategy()
        self.model_parallel_size = 2
        strategy.sharding = True
        strategy.sharding_configs = {
            "mp_degree": self.model_parallel_size,
            "sharding_degree": 2,
        }
        fleet.init(is_collective=True, strategy=strategy)

    def get_program(self):
        return paddle.static.Program(), paddle.static.Program()

    def test_column_parallel_layer(self):
        main_program, startup_program = self.get_program()
        with paddle.static.program_guard(main_program, startup_program):
            input_size, output_size = 28, 64
            model_a = ColumnLinearNet(input_size, output_size)

            x = paddle.static.data(name='x', shape=[None, input_size])
            y = model_a(x)

            #print(main_program)
            ops = main_program.global_block().ops
            ops = [op.type for op in ops]
            self.assertEqual(
                ops, ['c_identity', 'matmul', 'elementwise_add', 'c_concat'])

            weight = model_a.parallel_linear.weight
            bias = model_a.parallel_linear.bias
            self.assertEqual(weight.shape, (input_size, output_size //
                                            self.model_parallel_size))
            self.assertEqual(bias.shape,
                             (output_size // self.model_parallel_size, ))

    def test_row_parallel_layer(self):
        main_program, startup_program = self.get_program()
        with paddle.static.program_guard(main_program, startup_program):
            input_size, output_size = 28, 64
            model_a = RowLinearNet(input_size, output_size)

            x = paddle.static.data(name='x', shape=[None, input_size])
            y = model_a(x)

            #print(main_program)
            ops = main_program.global_block().ops
            ops = [op.type for op in ops]
            self.assertEqual(
                ops,
                ['c_split', 'matmul', 'c_allreduce_sum', 'elementwise_add'])

            weight = model_a.parallel_linear.weight
            bias = model_a.parallel_linear.bias
            self.assertEqual(weight.shape, (
                input_size // self.model_parallel_size, output_size))
            self.assertEqual(bias.shape, (output_size, ))

    def test_parallel_embedding(self):
        main_program, startup_program = self.get_program()
        with paddle.static.program_guard(main_program, startup_program):
            vocab_size, hidden_size = 1000, 512
            seq_len = 128

            # model_a
            model_a = EmbeddingNet(vocab_size, hidden_size)

            x = paddle.static.data(
                name='x', shape=[None, seq_len], dtype='int64')
            y = model_a(x)

            #print(main_program)
            ops = main_program.global_block().ops
            ops = [op.type for op in ops]
            self.assertEqual(ops, ['c_embedding', 'c_allreduce_sum'])

            weight = model_a.embedding.weight
            self.assertEqual(weight.shape, (
                vocab_size // self.model_parallel_size, hidden_size))

    def test_parallel_cross_entropy(self):
        main_program, startup_program = self.get_program()
        with paddle.static.program_guard(main_program, startup_program):
            batch_size = 8
            seq_length = 16
            class_size = 1000
            class_size_per_card = class_size // self.model_parallel_size

            # model_a
            model_a = fleet.meta_parallel.ParallelCrossEntropy()

            x = paddle.static.data(
                name='x', shape=[batch_size, seq_length, class_size_per_card])
            label = paddle.static.data(
                name='label', shape=[batch_size, seq_length], dtype='int64')

            loss_a = model_a(x, label)

            #print(main_program)
            ops = main_program.global_block().ops
            ops = [op.type for op in ops]
            self.assertEqual(ops,
                             ['unsqueeze2', 'c_softmax_with_cross_entropy'])


if __name__ == '__main__':
    unittest.main()
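The parameter shapes asserted in the tests follow directly from the model-parallel degree; a small sanity sketch of the same arithmetic (plain Python, no Paddle required):

model_parallel_size = 2
input_size, output_size = 28, 64
vocab_size, hidden_size = 1000, 512

# ColumnParallelLinear splits the output dimension across the mp group
assert (input_size, output_size // model_parallel_size) == (28, 32)
# RowParallelLinear splits the input dimension instead
assert (input_size // model_parallel_size, output_size) == (14, 64)
# VocabParallelEmbedding splits the vocabulary dimension
assert (vocab_size // model_parallel_size, hidden_size) == (500, 512)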