未验证 提交 1fa1d114 编写于 作者: J JZ-LIANG 提交者: GitHub

[Auto Parallel] Dist embedding Supports Ernie BigTable with 3d input (#45395)

* bugfix (#45332)

* dist embedding support lookup table v1

* add unittest

* update unittest cmake
上级 11e62d68
...@@ -31,7 +31,7 @@ from paddle.fluid.framework import Program, Parameter, Variable ...@@ -31,7 +31,7 @@ from paddle.fluid.framework import Program, Parameter, Variable
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process_group import new_process_group from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank, set_var_dist_attr
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op
from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs, build_dp_costs from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs, build_dp_costs
from ..cost import EmbeddingOpCost, EmbeddingGradOpCost from ..cost import EmbeddingOpCost, EmbeddingGradOpCost
...@@ -48,6 +48,79 @@ register_distributed_operator_impl_container( ...@@ -48,6 +48,79 @@ register_distributed_operator_impl_container(
DistributedEmbedding("lookup_table_v2")) DistributedEmbedding("lookup_table_v2"))
register_distributed_operator_impl_container( register_distributed_operator_impl_container(
DistributedEmbedding("c_embedding")) DistributedEmbedding("c_embedding"))
# Also register a DistributedEmbedding container for the "lookup_table" op
# type, next to the "lookup_table_v2" and "c_embedding" containers above.
register_distributed_operator_impl_container(
DistributedEmbedding("lookup_table"))
def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var):
    """Insert a ``reshape2`` op that drops the last axis of a 3-D Ids input.

    The function appends a ``reshape2`` op to ``main_block`` whose input is
    ``Ids_var`` and whose output is a new temporary variable shaped like
    ``Ids_var.shape[:-1]``. It then swaps the Ids entry in the distributed
    attribute of ``src_op`` for that temporary variable and registers a
    distributed attribute for the new reshape op.

    Args:
        ctx: Distributed context used to look up and store dist attributes.
        main_block: Block that receives the new variables and the reshape op.
        src_op: The ``lookup_table`` op being adapted.
        Ids_var: The Ids input variable of ``src_op``; must have 3 dimensions
            and ``stop_gradient`` set.

    Returns:
        The newly created reshaped Ids variable.

    Raises:
        AssertionError: If ``Ids_var`` is not 3-D, or if ``src_op`` has no
            input dist attribute recorded for ``Ids_var``.
        NotImplementedError: If ``Ids_var`` requires a gradient.
    """
    ids_rank = len(Ids_var.shape)
    assert ids_rank == 3, "input Ids to lookup_table should have 3 dimensions but got [{}] with shape [{}]".format(
        Ids_var.name, Ids_var.shape)
    if not Ids_var.stop_gradient:
        raise NotImplementedError(
            'Requiring the gradient of Ids of lookup_table(v1)dist op is not currently supported. Please open an issue with details on your use case so that we can prioritize adding this (for instance, adversarial training for language model).'
        )

    # Output of the reshape: the Ids shape without its last axis.
    squeezed_shape = list(Ids_var.shape[:-1])
    reshaped_ids = main_block.create_var(
        name=unique_name.generate_with_ignorable_key("dist_reshape.tmp"),
        dtype=Ids_var.dtype,
        shape=squeezed_shape,
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=True)

    # Companion "XShape" output of reshape2.
    xshape_var = main_block.create_var(
        name=unique_name.generate_with_ignorable_key("dist_Xshape.tmp"),
        dtype=Ids_var.dtype,
        shape=[0] + list(Ids_var.shape[:-1]),
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=True)

    # TODO use inplace reshape for memory saving
    # NOTE(review): shape attr [0, -1] presumably means "keep dim 0, infer the
    # rest" per reshape2 semantics — confirm against the op documentation.
    reshape_inputs = {'X': [Ids_var]}
    reshape_outputs = {'Out': [reshaped_ids], 'XShape': [xshape_var]}
    reshape_op = main_block.append_op(type='reshape2',
                                      inputs=reshape_inputs,
                                      outputs=reshape_outputs,
                                      attrs={"shape": [0, -1]})

    # Propagate the distributed attributes of Ids to the new variables.
    src_op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
    ids_dist_attr = src_op_dist_attr.get_input_dist_attr(Ids_var.name)
    assert ids_dist_attr is not None
    # NOTE(review): the reshaped variable has one axis fewer than Ids_var but
    # receives the Ids dims_mapping unchanged; likewise the XShape mapping is
    # one entry longer than its declared shape — confirm this is intended.
    reshaped_ids_dist_attr = set_var_dist_attr(ctx, reshaped_ids,
                                               ids_dist_attr.dims_mapping,
                                               ids_dist_attr.process_mesh)
    set_var_dist_attr(ctx, xshape_var,
                      [-1] + list(ids_dist_attr.dims_mapping),
                      ids_dist_attr.process_mesh)

    # The source op's dist attr now refers to the reshaped Ids, not the raw one.
    src_op_dist_attr.del_input_dist_attr(Ids_var.name)
    src_op_dist_attr.set_input_dist_attr(reshaped_ids.name,
                                         reshaped_ids_dist_attr)

    # Distributed attribute of the inserted reshape op itself.
    reshape_dist_attr = OperatorDistributedAttribute()
    reshape_dist_attr.process_mesh = ids_dist_attr.process_mesh
    reshape_dist_attr.impl_type = "default"
    reshape_dist_attr.impl_idx = 0
    reshape_dist_attr.set_input_dims_mapping(Ids_var.name,
                                             ids_dist_attr.dims_mapping)
    reshape_dist_attr.set_output_dims_mapping(reshaped_ids.name,
                                              ids_dist_attr.dims_mapping)
    reshape_dist_attr.set_output_dims_mapping(
        xshape_var.name, [-1] + list(ids_dist_attr.dims_mapping))
    ctx.set_op_dist_attr_for_program(reshape_op, reshape_dist_attr)

    return reshaped_ids
# RowParallel # RowParallel
...@@ -254,6 +327,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -254,6 +327,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
Weight_var = main_block._var_recursive(kwargs['W'][0]) Weight_var = main_block._var_recursive(kwargs['W'][0])
Out_var = main_block.var(kwargs['Out'][0]) Out_var = main_block.var(kwargs['Out'][0])
# support lookup_table_v1
if src_op.type == 'lookup_table':
Ids_var = adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var)
# got dist attribute info # got dist attribute info
embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping( embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[0] Weight_var.name)[0]
...@@ -532,3 +609,5 @@ register_distributed_operator_impl("lookup_table_v2", ...@@ -532,3 +609,5 @@ register_distributed_operator_impl("lookup_table_v2",
DistributedEmbeddingImpl("row_parallel")) DistributedEmbeddingImpl("row_parallel"))
register_distributed_operator_impl("c_embedding", register_distributed_operator_impl("c_embedding",
DistributedEmbeddingImpl("row_parallel")) DistributedEmbeddingImpl("row_parallel"))
# Register the "row_parallel" DistributedEmbeddingImpl for the "lookup_table"
# op type, mirroring the "lookup_table_v2" and "c_embedding" registrations.
register_distributed_operator_impl("lookup_table",
DistributedEmbeddingImpl("row_parallel"))
...@@ -55,6 +55,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ...@@ -55,6 +55,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
${dist_ENVS}) ${dist_ENVS})
py_test_modules(test_dist_reshape MODULES test_dist_reshape ENVS ${dist_ENVS}) py_test_modules(test_dist_reshape MODULES test_dist_reshape ENVS ${dist_ENVS})
py_test_modules(test_dist_pnorm MODULES test_dist_pnorm ENVS ${dist_ENVS}) py_test_modules(test_dist_pnorm MODULES test_dist_pnorm ENVS ${dist_ENVS})
py_test_modules(test_dist_embedding MODULES test_dist_embedding ENVS
${dist_ENVS})
py_test_modules(test_dist_slice MODULES test_dist_slice ENVS ${dist_ENVS}) py_test_modules(test_dist_slice MODULES test_dist_slice ENVS ${dist_ENVS})
py_test_modules(test_cluster MODULES test_cluster ENVS ${dist_ENVS}) py_test_modules(test_cluster MODULES test_cluster ENVS ${dist_ENVS})
py_test_modules(test_comm_cost MODULES test_comm_cost ENVS ${dist_ENVS}) py_test_modules(test_comm_cost MODULES test_comm_cost ENVS ${dist_ENVS})
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from test_dist_pnorm import parallelizer
# Switch Paddle to static-graph mode at import time; the program builder below
# constructs paddle.fluid.Program objects.
paddle.enable_static()
def make_program_lookup_table_v1_mp_dp():
    """Build a static program with a sharded v1 embedding lookup.

    The program feeds an int64 ``src_ids`` tensor of shape ``[12, 512, 1]``
    through ``paddle.fluid.layers.embedding`` (table size ``[64, 128]``,
    parameter name ``emb_weight``) and reduces the result with
    ``reduce_mean``. ``src_ids`` is sharded with dims_mapping ``[0, -1, -1]``
    and ``emb_weight`` with ``[1, -1]``, both on the process mesh
    ``[[0, 1], [2, 3]]``.

    Returns:
        A ``(main_program, startup_program, loss)`` tuple.
    """

    def _sharding(dims_mapping):
        # A fresh ProcessMesh is built for every annotation.
        return {
            "process_mesh": auto.ProcessMesh([[0, 1], [2, 3]]),
            "dims_mapping": dims_mapping
        }

    main_program = paddle.fluid.Program()
    startup_program = paddle.fluid.Program()
    global_block = main_program.global_block()

    with paddle.static.program_guard(main_program, startup_program):
        src_ids = paddle.static.data(name='src_ids',
                                     shape=[12, 512, 1],
                                     dtype='int64')
        src_ids.stop_gradient = True

        weight_attr = paddle.fluid.ParamAttr(name="emb_weight")
        emb_out = paddle.fluid.layers.embedding(input=src_ids,
                                                size=[64, 128],
                                                param_attr=weight_attr,
                                                dtype="float32",
                                                is_sparse=False)
        loss = paddle.fluid.layers.reduce_mean(emb_out)

        # Annotate the input ids and the embedding table.
        auto.shard_tensor(src_ids, dist_attr=_sharding([0, -1, -1]))
        emb_weight = global_block.vars["emb_weight"]
        auto.shard_tensor(emb_weight, dist_attr=_sharding([1, -1]))

    return main_program, startup_program, loss
class TestDistPNorm(unittest.TestCase):
    """Checks the ops produced when parallelizing a v1 embedding program.

    NOTE(review): the class name looks carried over from the p-norm test even
    though this case covers the embedding op; it is kept unchanged so existing
    test selectors keep working.
    """

    def test_lookup_table_v1_mp_dp(self):
        """Every rank in 0..3 must get the same expected op sequence."""
        expected_op_types = [
            'reshape2', 'c_embedding', 'c_allreduce_sum', 'reduce_mean',
            'fill_constant', 'reduce_mean_grad', 'c_identity',
            'c_embedding_grad', 'c_allreduce_sum', 'scale'
        ]
        for rank in range(4):
            # subTest names the failing rank and lets the other ranks run.
            with self.subTest(rank=rank):
                dist_main_prog, _ = parallelizer(
                    make_program_lookup_table_v1_mp_dp, rank)
                op_types = [
                    op.type for op in dist_main_prog.global_block().ops
                ]
                # assertEqual is not stripped under ``python -O`` and prints
                # a list diff on mismatch, unlike a bare ``assert``.
                self.assertEqual(op_types, expected_op_types)
if __name__ == "__main__":
    # Run this module's test cases when it is executed as a script.
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册