From 1fa1d1149e70d740e00c8ba4754b8aee91da4caf Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Wed, 31 Aug 2022 16:36:38 +0800
Subject: [PATCH] [Auto Parallel] Dist embedding Supports Ernie BigTable with
 3d input (#45395)

* bugfix (#45332)
* dist embedding supports lookup_table v1
* add unittest
* update unittest cmake
---
 .../auto_parallel/operators/dist_embedding.py | 81 +++++++++++++++++-
 .../unittests/auto_parallel/CMakeLists.txt    |  2 +
 .../auto_parallel/test_dist_embedding.py      | 83 +++++++++++++++++++
 3 files changed, 165 insertions(+), 1 deletion(-)
 create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_dist_embedding.py

diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py
index cf7779a02a..856d9c36bb 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py
@@ -31,7 +31,7 @@ from paddle.fluid.framework import Program, Parameter, Variable
 from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
 from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
 from ..process_group import new_process_group
-from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank
+from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank, set_var_dist_attr
 from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op
 from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs, build_dp_costs
 from ..cost import EmbeddingOpCost, EmbeddingGradOpCost
@@ -48,6 +48,79 @@ register_distributed_operator_impl_container(
     DistributedEmbedding("lookup_table_v2"))
 register_distributed_operator_impl_container(
     DistributedEmbedding("c_embedding"))
+register_distributed_operator_impl_container(
+    DistributedEmbedding("lookup_table"))
+
+
+def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var):
+
+    assert len(
+        Ids_var.shape
+    ) == 3, "input Ids of lookup_table should have 3 dimensions, but got [{}] with shape [{}]".format(
+        Ids_var.name, Ids_var.shape)
+    if not Ids_var.stop_gradient:
+        raise NotImplementedError(
+            'Requiring the gradient of Ids of the lookup_table (v1) dist op is not currently supported. Please open an issue with details on your use case so that we can prioritize adding this (for instance, adversarial training for language models).'
+        )
+
+    target_shape = list(Ids_var.shape[:-1])
+    intermediate_var_0 = main_block.create_var(
+        name=unique_name.generate_with_ignorable_key(".".join(
+            ["dist_reshape", 'tmp'])),
+        dtype=Ids_var.dtype,
+        shape=target_shape,
+        type=core.VarDesc.VarType.LOD_TENSOR,
+        persistable=False,
+        stop_gradient=True)
+
+    target_shape = [0] + list(Ids_var.shape[:-1])
+    xshape_var = main_block.create_var(
+        name=unique_name.generate_with_ignorable_key(".".join(
+            ["dist_Xshape", 'tmp'])),
+        dtype=Ids_var.dtype,
+        shape=target_shape,
+        type=core.VarDesc.VarType.LOD_TENSOR,
+        persistable=False,
+        stop_gradient=True)
+
+    # TODO: use in-place reshape to save memory
+    reshape_op = main_block.append_op(type='reshape2',
+                                      inputs={'X': [Ids_var]},
+                                      outputs={
+                                          'Out': [intermediate_var_0],
+                                          'XShape': [xshape_var]
+                                      },
+                                      attrs={
+                                          "shape": [0, -1],
+                                      })
+
+    # set dist attr
+    op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
+    Ids_var_dist_attr = op_dist_attr.get_input_dist_attr(Ids_var.name)
+    assert Ids_var_dist_attr is not None
+    intermediate_var_0_dist_attr = set_var_dist_attr(
+        ctx, intermediate_var_0, Ids_var_dist_attr.dims_mapping,
+        Ids_var_dist_attr.process_mesh)
+    set_var_dist_attr(ctx, xshape_var,
+                      [-1] + list(Ids_var_dist_attr.dims_mapping),
+                      Ids_var_dist_attr.process_mesh)
+    op_dist_attr.del_input_dist_attr(Ids_var.name)
+    op_dist_attr.set_input_dist_attr(intermediate_var_0.name,
+                                     intermediate_var_0_dist_attr)
+
+    new_op_dist_attr = OperatorDistributedAttribute()
+    new_op_dist_attr.process_mesh = Ids_var_dist_attr.process_mesh
+    new_op_dist_attr.impl_type = "default"
+    new_op_dist_attr.impl_idx = 0
+    new_op_dist_attr.set_input_dims_mapping(Ids_var.name,
+                                            Ids_var_dist_attr.dims_mapping)
+    new_op_dist_attr.set_output_dims_mapping(intermediate_var_0.name,
+                                             Ids_var_dist_attr.dims_mapping)
+    new_op_dist_attr.set_output_dims_mapping(
+        xshape_var.name, [-1] + list(Ids_var_dist_attr.dims_mapping))
+    ctx.set_op_dist_attr_for_program(reshape_op, new_op_dist_attr)
+
+    return intermediate_var_0
 
 
 # RowParallel
@@ -254,6 +327,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         Weight_var = main_block._var_recursive(kwargs['W'][0])
         Out_var = main_block.var(kwargs['Out'][0])
 
+        # support lookup_table_v1: flatten the 3-D Ids before the c_embedding path
+        if src_op.type == 'lookup_table':
+            Ids_var = adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var)
+
         # got dist attribute info
         embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
             Weight_var.name)[0]
@@ -532,3 +609,5 @@ register_distributed_operator_impl("lookup_table_v2",
                                    DistributedEmbeddingImpl("row_parallel"))
 register_distributed_operator_impl("c_embedding",
                                    DistributedEmbeddingImpl("row_parallel"))
+register_distributed_operator_impl("lookup_table",
+                                   DistributedEmbeddingImpl("row_parallel"))
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
index beb1c722dd..b78540701e 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
@@ -55,6 +55,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
                   ${dist_ENVS})
   py_test_modules(test_dist_reshape MODULES test_dist_reshape ENVS ${dist_ENVS})
   py_test_modules(test_dist_pnorm MODULES test_dist_pnorm ENVS ${dist_ENVS})
+  py_test_modules(test_dist_embedding MODULES test_dist_embedding ENVS
+                  ${dist_ENVS})
   py_test_modules(test_dist_slice MODULES test_dist_slice ENVS ${dist_ENVS})
   py_test_modules(test_cluster MODULES test_cluster ENVS ${dist_ENVS})
   py_test_modules(test_comm_cost MODULES test_comm_cost ENVS ${dist_ENVS})
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_embedding.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_embedding.py
new file mode 100644
index 0000000000..0b81b5bd48
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_embedding.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle
+import paddle.distributed.auto_parallel as auto
+
+from paddle.fluid import program_guard
+from paddle.fluid.backward import append_backward
+from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
+from test_dist_pnorm import parallelizer
+
+paddle.enable_static()
+
+
+def make_program_lookup_table_v1_mp_dp():
+    main_program = paddle.fluid.Program()
+    start_program = paddle.fluid.Program()
+    block = main_program.global_block()
+    with paddle.static.program_guard(main_program, start_program):
+
+        src_ids = paddle.static.data(name='src_ids',
+                                     shape=[12, 512, 1],
+                                     dtype='int64')
+        src_ids.stop_gradient = True
+        emb_out = paddle.fluid.layers.embedding(
+            input=src_ids,
+            size=[64, 128],
+            param_attr=paddle.fluid.ParamAttr(name="emb_weight"),
+            dtype="float32",
+            is_sparse=False)
+        loss = paddle.fluid.layers.reduce_mean(emb_out)
+
+        auto.shard_tensor(src_ids,
+                          dist_attr={
+                              "process_mesh": auto.ProcessMesh([[0, 1], [2,
+                                                                         3]]),
+                              "dims_mapping": [0, -1, -1]
+                          })
+        emb_weight = block.vars["emb_weight"]
+        auto.shard_tensor(emb_weight,
+                          dist_attr={
+                              "process_mesh": auto.ProcessMesh([[0, 1], [2,
+                                                                         3]]),
+                              "dims_mapping": [1, -1]
+                          })
+
+    return main_program, start_program, loss
+
+
+class TestDistEmbedding(unittest.TestCase):
+
+    def test_lookup_table_v1_mp_dp(self):
+
+        for rank in range(4):
+            dist_main_prog, dist_context = parallelizer(
+                make_program_lookup_table_v1_mp_dp, rank)
+            ops = dist_main_prog.global_block().ops
+
+            op_types = []
+            for op in ops:
+                op_types.append(op.type)
+
+            assert op_types == [
+                'reshape2', 'c_embedding', 'c_allreduce_sum', 'reduce_mean',
+                'fill_constant', 'reduce_mean_grad', 'c_identity',
+                'c_embedding_grad', 'c_allreduce_sum', 'scale'
+            ]
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab
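
Note on the reshape2 op inserted by adopt_lookup_table_v1: its attrs {"shape": [0, -1]} follow Paddle's reshape convention in which 0 copies the corresponding input dimension and -1 infers the remaining size, so the 3-D Ids accepted by lookup_table (v1) collapse into the 2-D Ids the row-parallel c_embedding path expects. A minimal numpy sketch of that shape rule (illustration only, not the Paddle API; the array names are made up):

import numpy as np

# "shape": [0, -1] in reshape2: 0 keeps the matching input dim, -1 infers the rest,
# so [batch, seq_len, 1] ids collapse to [batch, seq_len].
ids_v1 = np.zeros((12, 512, 1), dtype=np.int64)  # 3-D Ids, same shape as in the unit test
batch = ids_v1.shape[0]                          # the "0" entry preserves dim 0 (= 12)
ids_v2 = ids_v1.reshape(batch, -1)               # the "-1" entry infers 512 * 1 = 512
assert ids_v2.shape == (12, 512)                 # 2-D Ids, as lookup_table_v2 takes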
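Note on the op list asserted by the unit test: emb_weight is sharded along its first (vocabulary) dimension over the model-parallel mesh axis (dims_mapping [1, -1]), so each rank holds only a block of rows and the lookup is lowered to c_embedding (a partial lookup that yields zeros for ids owned by other ranks) followed by c_allreduce_sum. A small numpy simulation of that row-parallel scheme, reusing the test's sizes (vocab 64, hidden 128, 2 model-parallel ranks) but none of the Paddle ops themselves:

import numpy as np

# Full (unsharded) embedding table and integer ids, sizes borrowed from the test.
vocab, hidden, mp_degree = 64, 128, 2
rng = np.random.default_rng(0)
full_table = rng.standard_normal((vocab, hidden)).astype("float32")
ids = rng.integers(0, vocab, size=(12, 512))

# Each model-parallel rank owns a contiguous block of rows (vocab / mp_degree).
per_rank = vocab // mp_degree
partials = []
for rank in range(mp_degree):
    start = rank * per_rank
    shard = full_table[start:start + per_rank]   # this rank's rows of the table
    local = ids - start                          # shift ids into shard-local range
    owned = (local >= 0) & (local < per_rank)    # ids whose row lives on this rank
    out = np.zeros(ids.shape + (hidden,), dtype="float32")
    out[owned] = shard[local[owned]]             # look up owned ids, zeros elsewhere
    partials.append(out)                         # roughly what c_embedding emits

# Summing the partial results across ranks (the c_allreduce_sum step)
# recovers the serial lookup.
assert np.allclose(sum(partials), full_table[ids])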