From 621d8a71203a2290e273164c965c0314c93e31d8 Mon Sep 17 00:00:00 2001
From: Baibaifan <39549453+Baibaifan@users.noreply.github.com>
Date: Mon, 28 Jun 2021 19:35:37 +0800
Subject: [PATCH] mode_c_embeding_bugs (#33801)

---
 python/paddle/distributed/collective.py | 91 ++++++++++++++++++++++---
 1 file changed, 81 insertions(+), 10 deletions(-)

diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index 3f0d97075c8..cdad59cabf1 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -1219,6 +1219,65 @@ def _parallel_embedding(x,
     return out
 
 
+def _parallel_embedding_npu(x,
+                            per_part_embeddings,
+                            origin_size,
+                            param_attr,
+                            inner_rank,
+                            num_partitions,
+                            name,
+                            group=None):
+    """
+    NPU Parallel Embedding.
+    """
+    if group is not None and not group.is_member():
+        return
+    ring_id = 0 if group is None else group.id
+
+    origin_num_embeddings = origin_size[0]
+    embedding = paddle.nn.Embedding(
+        per_part_embeddings,
+        origin_size[1],
+        padding_idx=per_part_embeddings - 1,
+        sparse=False,
+        weight_attr=param_attr,
+        name=name)
+
+    origin_input_shape = x.shape
+    if len(origin_input_shape) == 2:
+        x = paddle.unsqueeze(x, axis=-1)
+    else:
+        assert origin_input_shape[-1] == 1, (
+            "The last dimension size of x must be 1.")
+    x_shard = paddle.shard_index(x, origin_num_embeddings, num_partitions,
+                                 inner_rank, per_part_embeddings - 1)
+    if len(origin_input_shape) == 2:
+        x_shard = paddle.squeeze(x_shard, axis=-1)
+    emb_out = embedding(x_shard)
+    startup_block = paddle.static.default_startup_program().global_block()
+    main_block = paddle.static.default_main_program().global_block()
+    startup_block.vars[embedding.weight.name].is_distributed = True
+    main_block.vars[embedding.weight.name].is_distributed = True
+    out = main_block.create_var(
+        shape=emb_out.shape,
+        dtype=emb_out.dtype,
+        type=emb_out.type,
+        lod_level=emb_out.lod_level,
+        persistable=False,
+        is_data=False,
+        need_check_feed=emb_out.desc.need_check_feed())
+    main_block.append_op(
+        type='c_allreduce_sum',
+        inputs={'X': emb_out},
+        outputs={'Out': out},
+        attrs={
+            'ring_id': ring_id,
+            'use_calc_stream': True,
+            'use_model_parallel': True
+        })
+    return out
+
+
 def split(x,
           size,
           operation,
@@ -1332,16 +1391,28 @@ def split(x,
             "but received vocabulary={} num_partitions={}".format(size[0],
                                                                   num_partitions)
         per_part_size = size[0] // num_partitions
-        emb_out = _parallel_embedding(
-            x,
-            per_part_size,
-            size,
-            weight_attr,
-            inner_rank,
-            num_partitions,
-            name,
-            group=None)
-        return emb_out
+        if core.is_compiled_with_npu():
+            emb_out = _parallel_embedding_npu(
+                x,
+                per_part_size,
+                size,
+                weight_attr,
+                inner_rank,
+                num_partitions,
+                name,
+                group=None)
+            return emb_out
+        else:
+            emb_out = _parallel_embedding(
+                x,
+                per_part_size,
+                size,
+                weight_attr,
+                inner_rank,
+                num_partitions,
+                name,
+                group=None)
+            return emb_out
     else:
         should_split = False
         if axis == 0:
-- 
GitLab
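
For context, a minimal usage sketch (not part of the patch) of the public entry point this change touches, paddle.distributed.split with operation="embedding". On a build where core.is_compiled_with_npu() returns True, the call below would take the new _parallel_embedding_npu path; otherwise it falls back to _parallel_embedding. The device string, partition count, and the (8, 8) vocabulary/embedding sizes are illustrative assumptions, not values taken from the patch.

    # Sketch only: assumes a 2-device collective job (e.g. launched with fleetrun)
    # and static graph mode, matching how split() is normally used.
    import paddle
    import paddle.distributed.fleet as fleet

    paddle.enable_static()
    # On an NPU build the device would be 'npu:<id>'; 'gpu' shown as the generic case.
    paddle.set_device('gpu:%d' % paddle.distributed.ParallelEnv().dev_id)
    fleet.init(is_collective=True)

    # Integer ids in [0, 8); vocabulary of 8 split across 2 partitions,
    # so each rank holds per_part_size = 4 rows of the embedding table.
    data = paddle.randint(0, 8, shape=[10, 4], dtype='int64')
    emb_out = paddle.distributed.split(
        data, (8, 8), operation="embedding", num_partitions=2)

Each rank embeds only the ids that fall in its shard (out-of-range ids map to the padding row) and a c_allreduce_sum combines the partial results, which is the same structure the new NPU helper in this patch implements.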