Unverified commit 42c1297e, authored by WangXi, committed by GitHub

[HybridParallel] update collective split to use c_embedding and mp_allreduce (#33411)

Parent 9cda9ec2
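For orientation, a minimal usage sketch (not part of the diff) of the embedding split path this commit rewires; it assumes a two-GPU model-parallel job launched with fleet and borrows the (12, 8) size from the updated test below:

# Sketch only: vocab-parallel embedding through paddle.distributed.split after
# this change. The launcher setup (fleet, 2 GPUs) is assumed, not shown.
import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

ids = paddle.static.data(name='ids', shape=[10, 4], dtype='int64')
# size[0] (the vocabulary) must now divide evenly by num_partitions; each rank
# holds a [6, 8] weight shard, looks up only its own rows with c_embedding, and
# the partial results are summed across ranks with mp_allreduce.
emb = paddle.distributed.split(
    ids, size=(12, 8), operation="embedding", num_partitions=2)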
@@ -894,8 +894,25 @@ def _mp_allreduce(tensor,
                "use_model_parallel", use_model_parallel)
        else:
            raise ValueError("Unknown parameter: {}.".format(op))
    else:
        raise NotImplementedError("No support _mp_allreduce in dygraph mode.")

    op_type = 'c_allreduce_sum'
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
    check_variable_and_dtype(
        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
        op_type)
    helper.append_op(
        type=op_type,
        inputs={'X': tensor},
        outputs={'Out': out},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': use_calc_stream,
            'use_model_parallel': use_model_parallel,
        })
    return out
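A hedged sketch of calling the new static-graph branch while a program is being built; _mp_allreduce is an internal helper, and the program and variable names here are assumptions, not part of the commit:

# Sketch, not part of the diff: exercising the static-graph branch added above.
import paddle
from paddle.distributed import collective

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    partial = paddle.static.data(name='partial', shape=[10, 8], dtype='float32')
    # Appends a c_allreduce_sum op with use_model_parallel=True to main_prog;
    # at run time it sums this rank's partial result with the other ranks'.
    full = collective._mp_allreduce(
        partial, group=None, use_calc_stream=True, use_model_parallel=True)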
def _c_lookup_table(table, index, start_index=0, name=None):
@@ -915,6 +932,19 @@ def _c_lookup_table(table, index, start_index=0, name=None):
    if in_dygraph_mode():
        return core.ops.c_embedding(table, index, "start_index", start_index)

    op_type = 'c_embedding'
    helper = LayerHelper(op_type, **locals())
    dtype = helper.input_dtype(input_param_name='table')
    check_variable_and_dtype(index, 'input', ['int32', 'int64'], op_type)
    tmp = helper.create_variable_for_type_inference(dtype)
    helper.append_op(
        type='c_embedding',
        inputs={'Ids': index,
                'W': table},
        outputs={'Out': tmp},
        attrs={"start_index": start_index})
    return tmp
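This pairs with _mp_allreduce because c_embedding is expected to write zero rows for ids outside a rank's [start_index, start_index + shard_size) range, so summing the per-rank partial outputs reproduces a full lookup. A NumPy sketch of that arithmetic (illustrative only, shapes borrowed from the tests):

# NumPy illustration of the sharded-lookup-plus-sum scheme; not Paddle code.
import numpy as np

vocab, dim, num_parts = 12, 8, 2
table = np.random.rand(vocab, dim)          # full embedding table
ids = np.array([0, 3, 7, 11])

per_part = vocab // num_parts
partials = []
for rank in range(num_parts):
    start = rank * per_part                 # this rank's start_index
    shard = table[start:start + per_part]   # this rank's weight slice
    local = ids - start
    owned = (local >= 0) & (local < per_part)
    out = np.zeros((len(ids), dim))
    out[owned] = shard[local[owned]]        # ids owned by other ranks stay zero
    partials.append(out)

# c_allreduce_sum across the ranks recovers the ordinary lookup.
assert np.allclose(sum(partials), table[ids])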
class _Linear(layers.Layer):
    """
@@ -1136,47 +1166,34 @@ def _parallel_embedding(x,
        return
    ring_id = 0 if group is None else group.id

    origin_num_embeddings = origin_size[0]
    embedding = paddle.nn.Embedding(
        per_part_embeddings,
        origin_size[1],
        padding_idx=per_part_embeddings - 1,
        sparse=False,
        weight_attr=param_attr,
        name=name)

    origin_input_shape = x.shape
    if len(origin_input_shape) == 2:
        x = paddle.unsqueeze(x, axis=-1)
    else:
        assert origin_input_shape[-1] == 1, (
            "The last dimension size of x must be 1.")
    x_shard = paddle.shard_index(x, origin_num_embeddings, num_partitions,
                                 inner_rank, per_part_embeddings - 1)
    if len(origin_input_shape) == 2:
        x_shard = paddle.squeeze(x_shard, axis=-1)
    emb_out = embedding(x_shard)

    helper = LayerHelper("_parallel_embedding", **locals())
    per_part_size = per_part_embeddings
    rank = inner_rank

    vocab_start_index = rank * per_part_size
    dtype = helper.get_default_dtype()
    size = [per_part_size, origin_size[1]]

    weight = helper.create_parameter(
        attr=param_attr, shape=size, dtype=dtype, is_bias=False)

    if num_partitions == 1:
        return paddle.nn.functional.embedding(
            x, weight=weight, padding_idx=None, sparse=False, name=name)

    startup_block = paddle.static.default_startup_program().global_block()
    main_block = paddle.static.default_main_program().global_block()
    startup_block.vars[embedding.weight.name].is_distributed = True
    main_block.vars[embedding.weight.name].is_distributed = True
    out = main_block.create_var(
        shape=emb_out.shape,
        dtype=emb_out.dtype,
        type=emb_out.type,
        lod_level=emb_out.lod_level,
        persistable=False,
        is_data=False,
        need_check_feed=emb_out.desc.need_check_feed())
    main_block.append_op(
        type='c_allreduce_sum',
        inputs={'X': emb_out},
        outputs={'Out': out},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': True,
            'use_model_parallel': True
        })

    startup_block.vars[weight.name].is_distributed = True
    main_block.vars[weight.name].is_distributed = True

    output_parallel = paddle.distributed.collective._c_lookup_table(
        weight, x, start_index=vocab_start_index, name=name)
    out = paddle.distributed.collective._mp_allreduce(
        output_parallel,
        group=group,
        use_calc_stream=True,
        use_model_parallel=True)
    return out
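A quick check of the sharding arithmetic the new path relies on, using the sizes from the updated tests (sketch only):

# With a 12-row vocabulary split over 2 ranks, split() now requires
# 12 % 2 == 0 and hands each rank per_part_size = 6 rows; rank 0 owns
# rows [0, 6) (vocab_start_index 0), rank 1 owns rows [6, 12).
vocab, num_partitions = 12, 2
assert vocab % num_partitions == 0            # new precondition in split()
per_part_size = vocab // num_partitions
for inner_rank in range(num_partitions):
    vocab_start_index = inner_rank * per_part_size
    print("rank", inner_rank, "owns rows", vocab_start_index, "to",
          vocab_start_index + per_part_size)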
@@ -1288,11 +1305,11 @@ def split(x,
    if operation == "embedding":
        assert axis == 0, ("We only support to split the weight of embedding "
                           "along the first axis now.")
        per_part_size = (size[0] + num_partitions - 1) // num_partitions
        last_part_size = size[0] - per_part_size * (num_partitions - 1)
        if inner_rank == num_partitions - 1: per_part_size = last_part_size
        per_part_size += 1  # make the last row as the padding index

        assert size[0] % num_partitions == 0, \
            "The length of the vocabulary must be divisible by num_partitions " \
            "but received vocabulary={} num_partitions={}".format(size[0], num_partitions)

        per_part_size = size[0] // num_partitions
        emb_out = _parallel_embedding(
            x,
            per_part_size,
......
@@ -48,23 +48,27 @@ class TestParallelEmbeddingAPI(TestCollectiveAPIRunnerBase):
        with fluid.program_guard(main_prog, startup_program):
            fleet.init(is_collective=True)
            np.random.seed(2020)
            np_array = np.random.rand(10, 8)
            # (num_embeddings, embedding_dim) = (12, 8)
            size = (12, 8)
            np_array = np.random.rand(size[0], size[1])
            paddle.seed(2020)
            data_in = paddle.randint(0, 8, shape=(10, 4))
            data_in = paddle.randint(0, size[0], shape=(10, 4))
            data = paddle.static.data(
                name='tindata', shape=[10, 1000], dtype="float32")
            per_part_size = size[0] // 2
            if rank == 0:
                param_attr = paddle.fluid.ParamAttr(
                    initializer=paddle.fluid.initializer.NumpyArrayInitializer(
                        np_array[0:5, :]), )
                        np_array[0:per_part_size, :]), )
            else:
                param_attr = paddle.fluid.ParamAttr(
                    initializer=paddle.fluid.initializer.NumpyArrayInitializer(
                        np_array[5:10, :]), )
                        np_array[per_part_size:size[0], :]), )
            emb_out = paddle.distributed.split(
                data_in, (8, 8),
                data_in,
                size,
                operation="embedding",
                num_partitions=2,
                weight_attr=param_attr)
......
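For reference, a standalone NumPy check (not part of the test) that the two per-rank initializer slices above tile the full table:

# Sketch: rank-0 rows [0, 6) and rank-1 rows [6, 12) reassemble the (12, 8) table.
import numpy as np

np.random.seed(2020)
size = (12, 8)
np_array = np.random.rand(size[0], size[1])
per_part_size = size[0] // 2
shards = [np_array[0:per_part_size, :], np_array[per_part_size:size[0], :]]
assert np.array_equal(np.concatenate(shards, axis=0), np_array)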
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import paddle.distributed.fleet as fleet
from paddle.fluid.incubate.fleet.base import role_maker
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main

paddle.enable_static()


class TestParallelEmbeddingAPINoneDivisible(TestCollectiveAPIRunnerBase):
    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program, rank):
        with fluid.program_guard(main_prog, startup_program):
            fleet.init(is_collective=True)
            np.random.seed(2020)
            np_array = np.random.rand(9, 8)
            paddle.seed(2020)
            data_in = paddle.randint(0, 7, shape=(10, 4))
            data = paddle.static.data(
                name='tindata', shape=[10, 1000], dtype="float32")
            if rank == 0:
                param_attr = paddle.fluid.ParamAttr(
                    initializer=paddle.fluid.initializer.NumpyArrayInitializer(
                        np_array[0:5, :]), )
            else:
                param_attr = paddle.fluid.ParamAttr(
                    initializer=paddle.fluid.initializer.NumpyArrayInitializer(
                        np_array[5:9, :]), )
            emb_out = paddle.distributed.split(
                data_in, (7, 8),
                operation="embedding",
                num_partitions=2,
                weight_attr=param_attr)
            return [data_in, emb_out]


if __name__ == "__main__":
    runtime_main(TestParallelEmbeddingAPINoneDivisible, "parallel_embedding")
@@ -257,11 +257,10 @@ class TestDistBase(unittest.TestCase):
        elif col_type == "parallel_embedding":
            result_data = tr0_out[0]
            np.random.seed(2020)
            need_result = np.random.rand(10, 8)
            need_result = np.random.rand(12, 8)
            for i in range(result_data.shape[0]):
                for j in range(result_data.shape[1]):
                    data = result_data[i][j]
                    if data >= 4: data += 1
                    assert np.allclose(
                        tr0_out[1][i][j], need_result[data], atol=1e-08)
        elif col_type == "row_parallel_linear":
......
@@ -16,20 +16,24 @@ from __future__ import print_function

import unittest
import numpy as np
import paddle

from test_collective_api_base import TestDistBase
from paddle.distributed import fleet

paddle.enable_static()


class TestParallelEmbeddingNoneDivisibleAPI(TestDistBase):
    def _setup_config(self):
        pass


class TestCollectiveSplitAssert(unittest.TestCase):
    def network(self):
        fleet.init()
        data = paddle.static.data(
            name='tindata', shape=[10, 1000], dtype="float32")
        emb_out = paddle.distributed.split(
            data, (7, 8), operation="embedding", num_partitions=2)

    def test_parallel_embedding_none_divisible(self):
        self.check_with_place("parallel_embedding_api_none_divisible.py",
                              "parallel_embedding", "nccl")

    def test_assert(self):
        with self.assertRaises(AssertionError):
            self.network()


if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()