diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py b/python/paddle/distributed/auto_parallel/operators/__init__.py index 9f84df2d8963432b76abd8a5c03efae03bc3560c..db6f909f8ca7da66366656b33c02fa4f647ad5bb 100644 --- a/python/paddle/distributed/auto_parallel/operators/__init__.py +++ b/python/paddle/distributed/auto_parallel/operators/__init__.py @@ -27,3 +27,4 @@ from . import dist_eltwise from . import dist_check_finite_and_unscale from . import dist_update_loss_scaling from . import dist_split +from . import dist_fill_constant_batch_size_like diff --git a/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py b/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py old mode 100755 new mode 100644 diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index 94eb0d2d469f0595fdc8cb31821d6cded9ad064a..32f8e2acef5e103c870b6861ade4dc334c7329b5 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -155,7 +155,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): kwargs['Out']) Ids_var = main_block.var(kwargs['Ids'][0]) - Weight_var = main_block.var(kwargs['W'][0]) + Weight_var = main_block._var_recursive(kwargs['W'][0]) Out_var = main_block.var(kwargs['Out'][0]) # got dist attribute info @@ -277,7 +277,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): # param initialization sync if Weight_var.is_parameter and not op_dist_attr.is_recompute: - assert Weight_var.name not in dist_op_context.already_init_sync_vars + if Weight_var.name in dist_op_context.already_init_sync_vars: + return dist_op_context.already_init_sync_vars.add(Weight_var.name) param = startup_block.var(Weight_var.name) param_dist_attr = ctx.get_tensor_dist_attr_for_program(param) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py b/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py new file mode 100644 index 0000000000000000000000000000000000000000..0c9d9eda02e1bfbc855c0a7ad0e943e44f362fce --- /dev/null +++ b/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py @@ -0,0 +1,127 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +from .common import DistributedOperatorImplContainer +from .common import DistributedOperatorImpl +from .common import register_distributed_operator_impl_container +from .common import register_distributed_operator_impl +from ..utils import is_dim_shard +from ..utils import is_dim_replicate +from ..utils import is_valid_list_index +from ..utils import compute_compatible_dim_mapping +from ..utils import compute_compatible_dims_mapping +from ..utils import compute_compatible_and_update_dim_mapping +from ..utils import set_dist_op_desc_original_id +from paddle.fluid import core, unique_name +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.framework import Program, Parameter, Variable, program_guard +from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype +from .dist_default import DistributedDefaultImpl0 + + +class DistributedFillConstantBatchSizeLike(DistributedOperatorImplContainer): + def __init__(self, op_type): + super(DistributedFillConstantBatchSizeLike, self).__init__(op_type) + + +register_distributed_operator_impl_container( + DistributedFillConstantBatchSizeLike("fill_constant_batch_size_like")) + + +class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl): + def __init__(self, name): + super(DistributedFillConstantBatchSizeLikeImpl0, self).__init__(name) + self._forward_implemented = True + self._backward_implemented = True + + def is_input_compatible(self, dist_op): + + return True + + def is_output_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + out_name = op_desc.output('Out')[0] + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + shape_list = op_desc.attr("shape") + + if len(shape_list) != len(out_dims_mapping): + return False + + return True + + def is_auto_compatible(self, dist_op): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): + return False + + out_name = op_desc.output('Out')[0] + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + in_name = op_desc.input('Input')[0] + in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name) + + # the dim_mapping of batch dimension should be the same + return out_dims_mapping[0] == in_dims_mapping[0] + + def update_dims_mapping(self, dist_op): + changed = False + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + x_name = op_desc.input('X')[0] + out_name = op_desc.output('Out')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + + # only the batch size dimemsion of input and output are relative. + dim_changed = compute_compatible_and_update_dim_mapping( + [x_dims_mapping, out_dims_mapping], [0, 0]) + if dim_changed: + changed = True + + return changed + + @staticmethod + def forward(ctx, *args, **kwargs): + """ + kwargs: inputname_mapping & outputname_mapping + """ + DistributedDefaultImpl0.forward(ctx, *args, **kwargs) + dist_op_context = ctx.dist_op_context + src_op = dist_op_context.cur_src_op + op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) + main_block = dist_op_context.work_block + op = main_block.ops[-1] + assert op.type == "fill_constant_batch_size_like" + + # modify shape attr according to how output are partitioned + out_name = op.output('Out')[0] + dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + process_mesh_shape = op_dist_attr.process_mesh.topology + shape_list = op.attr("shape") + # modify target shape + for idx, axis in enumerate(dims_mapping): + if axis >= 0: + shape_list[idx] = shape_list[idx] // process_mesh_shape[axis] + + op._set_attr("shape", shape_list) + main_block._sync_with_cpp() + + @staticmethod + def backward(ctx, *args, **kwargs): + DistributedDefaultImpl0.backward(ctx, *args, **kwargs) + + +register_distributed_operator_impl( + "fill_constant_batch_size_like", + DistributedFillConstantBatchSizeLikeImpl0("fill_by_shape")) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index 9eb24a65e608c22573342f32dfd0dc96a601e3ac..058ae1d0a9fd5c25ec83ea15ed9c2e479322957c 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -433,8 +433,8 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id): - assert Weight_var.name not in dist_op_context.already_init_sync_vars, "{} is in {}.".format( - Weight_var.name, dist_op_context.already_init_sync_vars) + if Weight_var.name in dist_op_context.already_init_sync_vars: + return assert startup_block.has_var(Weight_var.name) dist_op_context.already_init_sync_vars.add(Weight_var.name) param = startup_block.var(Weight_var.name) @@ -819,6 +819,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): out_var_dist_attr) intermediate_var_0 = main_block.create_var( + name=unique_name.generate_with_ignorable_key(".".join( + ["c_allreduce_sum", 'tmp'])), shape=Out_var.shape, dtype=Out_var.dtype, type=Out_var.type, @@ -1323,6 +1325,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): out_var_dist_attr) intermediate_var_0 = main_block.create_var( + name=unique_name.generate_with_ignorable_key(".".join( + ["c_allreduce_sum", 'tmp'])), shape=Out_var.shape, dtype=Out_var.dtype, type=Out_var.type, diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py index 2f88407c093a534d1d67133aece636127ff29626..ed5ec85d84f224a38358dd531f3df490e3c160f1 100644 --- a/python/paddle/distributed/auto_parallel/partitioner.py +++ b/python/paddle/distributed/auto_parallel/partitioner.py @@ -285,6 +285,9 @@ def _get_dist_shape(var, dist_attr): var_shape = var.shape mapping = dist_attr.dims_mapping mesh = dist_attr.process_mesh.topology + if mapping == []: + return var_shape + assert len(var_shape) == len( mapping ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format( diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_auto_parallel_while_op.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_auto_parallel_while_op.py index 1cd8f8f3e7083d61bd4a30ca114d0ac2a099ba47..07e6a2c4346da42fd9ff0aafef12bf453cc6f463 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_auto_parallel_while_op.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_auto_parallel_while_op.py @@ -174,6 +174,7 @@ def get_program(): dtype='float32') label = static.data( name="label", shape=[batch_size, sequence_len, 1], dtype='float32') + data_holder = [input, label] # dataloader dataloader = paddle.io.DataLoader.from_generator( @@ -194,6 +195,17 @@ def get_program(): "dims_mapping": [-1, -1, -1] }) + # fill constant bsz like + tmp = paddle.fluid.layers.fill_constant_batch_size_like( + input=input, shape=[-1, 16, 0, 48], dtype='float32', value=0) + auto.shard_tensor( + tmp, + dist_attr={ + "process_mesh": _g_process_mesh, + "dims_mapping": [-1, 0, -1, -1] + }) + + # model mlp_start = MLPLayer( hidden_size=hidden_size, intermediate_size=4 * hidden_size, @@ -395,6 +407,9 @@ def completion(train_program, start_program, dist_context): op_dist_attr.impl_idx = 0 else: op_dist_attr.impl_idx = 1 + elif op.type == "fill_constant_batch_size_like": + op_dist_attr.impl_type = "fill_constant_batch_size_like" + op_dist_attr.impl_idx = 0 else: op_dist_attr.impl_type = "default" op_dist_attr.impl_idx = 0 @@ -428,6 +443,12 @@ class TestMLP(unittest.TestCase): dist_main_prog, dist_startup_prog = partition( train_program, start_program, dist_context) global_block_ops = dist_main_prog.blocks[0].ops + + fill_op = None + for op in global_block_ops: + if op.type == "fill_constant_batch_size_like": + fill_op = op + global_block_ops = [op.type for op in global_block_ops] sub_block_ops = dist_main_prog.blocks[1].ops sub_block_ops = [op.type for op in sub_block_ops] @@ -435,6 +456,13 @@ class TestMLP(unittest.TestCase): self.assertTrue("c_allreduce_sum" in global_block_ops) self.assertTrue("c_allreduce_sum" in sub_block_ops) + # test fill_constant_batch_size_like + + self.assertTrue(fill_op is not None) + ref_shape = [-1, 8, 0, 48] + shape = fill_op.attr("shape") + self.assertTrue(ref_shape == shape) + if __name__ == "__main__": unittest.main()