未验证 提交 621d8a71 编写于 作者: B Baibaifan 提交者: GitHub

mode_c_embeding_bugs (#33801)

上级 95007981
......@@ -1219,6 +1219,65 @@ def _parallel_embedding(x,
return out
def _parallel_embedding_npu(x,
per_part_embeddings,
origin_size,
param_attr,
inner_rank,
num_partitions,
name,
group=None):
"""
NPU Parallel Embedding
"""
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
origin_num_embeddings = origin_size[0]
embedding = paddle.nn.Embedding(
per_part_embeddings,
origin_size[1],
padding_idx=per_part_embeddings - 1,
sparse=False,
weight_attr=param_attr,
name=name)
origin_input_shape = x.shape
if len(origin_input_shape) == 2:
x = paddle.unsqueeze(x, axis=-1)
else:
assert origin_input_shape[-1] == 1, (
"The last dimension size of x must be 1.")
x_shard = paddle.shard_index(x, origin_num_embeddings, num_partitions,
inner_rank, per_part_embeddings - 1)
if len(origin_input_shape) == 2:
x_shard = paddle.squeeze(x_shard, axis=-1)
emb_out = embedding(x_shard)
startup_block = paddle.static.default_startup_program().global_block()
main_block = paddle.static.default_main_program().global_block()
startup_block.vars[embedding.weight.name].is_distributed = True
main_block.vars[embedding.weight.name].is_distributed = True
out = main_block.create_var(
shape=emb_out.shape,
dtype=emb_out.dtype,
type=emb_out.type,
lod_level=emb_out.lod_level,
persistable=False,
is_data=False,
need_check_feed=emb_out.desc.need_check_feed())
main_block.append_op(
type='c_allreduce_sum',
inputs={'X': emb_out},
outputs={'Out': out},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
'use_model_parallel': True
})
return out
def split(x,
size,
operation,
......@@ -1332,6 +1391,18 @@ def split(x,
"but received vocabulary={} num_partitions={}".format(size[0], num_partitions)
per_part_size = size[0] // num_partitions
if core.is_compiled_with_npu():
emb_out = _parallel_embedding_npu(
x,
per_part_size,
size,
weight_attr,
inner_rank,
num_partitions,
name,
group=None)
return emb_out
else:
emb_out = _parallel_embedding(
x,
per_part_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册