From 31a437b19635fb676948dfa5b2d3d8f51ed4a3b4 Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Fri, 14 Oct 2022 16:01:13 +0800 Subject: [PATCH] [AutoParallel] adapt for gpt-gen (#46771) * for gpt-gen * fix reshard * adapt assign and shape op * add dist_assign & unittest * add conditional block unittest * rename unittest --- .../distributed/auto_parallel/completion.py | 6 +- .../distributed/auto_parallel/engine.py | 8 +- .../auto_parallel/operators/__init__.py | 2 + .../auto_parallel/operators/dist_assign.py | 88 +++++++++++++++++ .../auto_parallel/operators/dist_shape.py | 73 ++++++++++++++ .../distributed/auto_parallel/reshard.py | 40 ++++---- .../paddle/distributed/auto_parallel/utils.py | 4 + .../unittests/auto_parallel/CMakeLists.txt | 4 + .../test_conditional_block_reshard.py | 96 +++++++++++++++++++ .../auto_parallel/test_dist_assign.py | 84 ++++++++++++++++ .../auto_parallel/test_dist_shape.py | 74 ++++++++++++++ 11 files changed, 458 insertions(+), 21 deletions(-) create mode 100644 python/paddle/distributed/auto_parallel/operators/dist_assign.py create mode 100644 python/paddle/distributed/auto_parallel/operators/dist_shape.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_conditional_block_reshard.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_dist_assign.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_dist_shape.py diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index 5b9d4d427b..a4bee7a4ad 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -17,7 +17,7 @@ import time from paddle.fluid import core -from .utils import is_gradient_clip_op +from .utils import is_gradient_clip_op, __not_shape_var_type__ from .operators import find_compatible_distributed_operator_impls from .dist_context import _node_id from .dist_attribute import TensorDistributedAttribute @@ -491,14 +491,14 @@ class Completer: for tensor_node in node.inputs: if tensor_node.is_var() and tensor_node.var( ) is not None: - if tensor_node.var().type() == core.VarDesc.VarType.READER \ + if tensor_node.var().type() in __not_shape_var_type__ \ or len(tensor_node.var().shape()) != 1: flag = False break for tensor_node in node.outputs: if tensor_node.is_var() and tensor_node.var( ) is not None: - if tensor_node.var().type() == core.VarDesc.VarType.READER \ + if tensor_node.var().type() in __not_shape_var_type__ \ or len(tensor_node.var().shape()) != 1: flag = False break diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 60ee7d0ba3..7c550ab578 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -1139,8 +1139,10 @@ class Engine: self.to_mode(mode) if inputs or labels: self._skip_build = True + self._inputs_spec = inputs_spec + self._labels_spec = labels_spec self._inputs, self._labels = self._prepare_data_tensor( - inputs_spec, labels_spec, inputs, labels) + self._inputs_spec, self._labels_spec, inputs, labels) self._orig_main_prog = main_program if self._orig_main_prog is None: self._orig_main_prog = static.default_main_program() @@ -1152,9 +1154,11 @@ class Engine: else: self._switch_mode(self._mode) elif inputs_spec or labels_spec: + self._inputs_spec = inputs_spec + self._labels_spec = labels_spec 
self._outside_dataloader = True self._inputs, self._labels = self._prepare_data_tensor( - inputs_spec, labels_spec) + self._inputs_spec, self._labels_spec) self._orig_main_prog = main_program if self._orig_main_prog is None: self._orig_main_prog = static.default_main_program() diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py b/python/paddle/distributed/auto_parallel/operators/__init__.py index 02b5138be2..4a0a05a4f1 100644 --- a/python/paddle/distributed/auto_parallel/operators/__init__.py +++ b/python/paddle/distributed/auto_parallel/operators/__init__.py @@ -33,3 +33,5 @@ from . import dist_slice from . import dist_fused_feedforward from . import dist_fused_attention from . import dist_reduce_sum_p +from . import dist_shape +from . import dist_assign diff --git a/python/paddle/distributed/auto_parallel/operators/dist_assign.py b/python/paddle/distributed/auto_parallel/operators/dist_assign.py new file mode 100644 index 0000000000..96923f461a --- /dev/null +++ b/python/paddle/distributed/auto_parallel/operators/dist_assign.py @@ -0,0 +1,88 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .common import DistributedOperatorImplContainer +from .common import DistributedOperatorImpl +from .common import register_distributed_operator_impl_container +from .common import register_distributed_operator_impl +from .dist_default import DistributedDefaultImpl0 +from ..utils import compute_compatible_and_update_dim_mapping + + +class DistributedAssign(DistributedOperatorImplContainer): + + def __init__(self, op_type): + super(DistributedAssign, self).__init__(op_type) + + +register_distributed_operator_impl_container(DistributedAssign("assign")) + + +class DistributedAssignImpl(DistributedOperatorImpl): + + def __init__(self, name): + super(DistributedAssignImpl, self).__init__(name) + self._forward_implemented = True + self._backward_implemented = True + + def is_input_compatible(self, dist_op): + return True + + def is_output_compatible(self, dist_op): + return True + + def is_auto_compatible(self, dist_op): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): + return False + + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + x_name = op_desc.input('X')[0] + out_name = op_desc.output('Out')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + + if x_dims_mapping != out_dims_mapping: + return False + + return True + + def update_dims_mapping(self, dist_op): + changed = False + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + x_name = op_desc.input('X')[0] + out_name = op_desc.output('Out')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + + for i in range(len(x_dims_mapping)): + dim_changed = compute_compatible_and_update_dim_mapping( + [x_dims_mapping, 
out_dims_mapping], [i, i]) + if dim_changed: + changed = True + + return changed + + @staticmethod + def forward(ctx, *args, **kwargs): + DistributedDefaultImpl0.forward(ctx, *args, **kwargs) + + @staticmethod + def backward(ctx, *args, **kwargs): + DistributedDefaultImpl0.backward(ctx, *args, **kwargs) + + +register_distributed_operator_impl("assign", DistributedAssignImpl("assign")) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_shape.py b/python/paddle/distributed/auto_parallel/operators/dist_shape.py new file mode 100644 index 0000000000..313f296ab9 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/operators/dist_shape.py @@ -0,0 +1,73 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .common import DistributedOperatorImplContainer +from .common import DistributedOperatorImpl +from .common import register_distributed_operator_impl_container +from .common import register_distributed_operator_impl +from .dist_default import DistributedDefaultImpl0 +from ..utils import is_dim_shard + + +class DistributedShape(DistributedOperatorImplContainer): + + def __init__(self, op_type): + super(DistributedShape, self).__init__(op_type) + + +register_distributed_operator_impl_container(DistributedShape("shape")) + + +class DistributedShapeImpl(DistributedOperatorImpl): + + def __init__(self, name): + super(DistributedShapeImpl, self).__init__(name) + self._forward_implemented = True + self._backward_implemented = True + + def is_input_compatible(self, dist_op): + return True + + def is_output_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + out_name = op_desc.output('Out')[0] + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + + assert len(out_dims_mapping) == 1 + if is_dim_shard(out_dims_mapping[0]): + return False + + return True + + def is_auto_compatible(self, dist_op): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): + return False + + return True + + def update_dims_mapping(self, dist_op): + return False + + @staticmethod + def forward(ctx, *args, **kwargs): + DistributedDefaultImpl0.forward(ctx, *args, **kwargs) + + @staticmethod + def backward(ctx, *args, **kwargs): + DistributedDefaultImpl0.backward(ctx, *args, **kwargs) + + +register_distributed_operator_impl("shape", DistributedShapeImpl("shape")) diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py index d7f2949444..46057ad97c 100644 --- a/python/paddle/distributed/auto_parallel/reshard.py +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -34,6 +34,7 @@ _g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling'] _g_gradient_clip_ops = [ "sum", "sqrt", "fill_constant", "elementwise_max", "elementwise_div" ] +_g_subblock_ops = ["while", "conditional_block"] def get_var_with_recursion(var_name, block, program): @@ -42,11 +43,11 @@ def 
get_var_with_recursion(var_name, block, program): if var_name in block.vars: var = block.vars[var_name] else: - parent_block = program.blocks[block.parent_idx] - if var_name in parent_block.vars: - var = parent_block.vars[var_name] - assert var is not None, \ - "{} is not found".format(var.name) + var = block._var_recursive(var_name) + # parent_block = program.blocks[block.parent_idx] + # if var_name in parent_block.vars: + # var = parent_block.vars[var_name] + assert var is not None, "{} is not found".format(var.name) return var @@ -1075,7 +1076,9 @@ class Resharder: new_Out = [] for var_name in while_op.output("Out"): for output_name in sub_block_op_outputs[::-1]: - if output_name.find(var_name) != -1: + if output_name.find(var_name) != -1 and ( + len(var_name) == len(output_name) + or "@RESHARD" in output_name): if output_name not in new_Out: new_Out.append(output_name) assert new_Out @@ -1104,13 +1107,15 @@ class Resharder: return False def is_condition_replicative(self, op): - assert op.type == "while" sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id] - dist_op = self.dist_context.get_dist_op_for_program(op) - op_dist_attr = dist_op.dist_attr + + if op.type == "while": + input_cond = op.input("Condition") + elif op.type == "conditional_block": + input_cond = op.input("Cond") # the dims mapping of condition tensor should be replicative - for var_name in op.input("Condition"): + for var_name in input_cond: var = get_var_with_recursion(var_name, sub_block, self.auto_parallel_main_prog) dist_tensor = self.dist_context.get_dist_tensor_for_program(var) @@ -1660,9 +1665,9 @@ class Resharder: op.desc.set_input(proto.inputs[0].name, op.input("X") + while_op_X_append) - def _get_while_op_input_attrs(self, op, var_name): + def _get_subblock_input_attrs(self, op, var_name): # NOTE: Multi while loop is not supported - assert op.type == "while" + assert op.type in _g_subblock_ops sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id] ops = sub_block.ops input_attrs = [] @@ -1713,8 +1718,8 @@ class Resharder: def get_op_input_attrs(self, op, var_name): op_input_attrs = [] - if op.type == "while": - op_input_attrs = self._get_while_op_input_attrs(op, var_name) + if op.type in _g_subblock_ops: + op_input_attrs = self._get_subblock_input_attrs(op, var_name) else: op_input_attrs = self._get_common_op_input_attrs(op, var_name) @@ -1818,7 +1823,7 @@ class Resharder: if dist_op is not None: op_input_dist_attrs = [ ] # [(op_process_mesh, op_input_dims_mapping), (op_process_mesh, op_input_dims_mapping)] - if op.type == "while": + if op.type in _g_subblock_ops: if not self.is_condition_replicative(op): raise ValueError( "Please check the condition due to the dims mapping is not replicative." 
@@ -1832,6 +1837,8 @@ class Resharder: if op.type == "while": # condition var process mesh is the same with op and dims_mapping is replicative, so it do not need reshard input_var_names = op.input("X") + elif op.type == "conditional_block": + input_var_names = op.input("Input") else: input_var_names = op.input_arg_names # to avoid while op X order different @@ -1984,11 +1991,12 @@ class Resharder: idx = 0 # skip reader and ops whose process mesh is union skip_ops = [ - "create_py_reader", "create_double_buffer_reader", "read", "while", + "create_py_reader", "create_double_buffer_reader", "read", "write_to_array", "read_from_array" ] global _g_special_ops skip_ops += _g_special_ops + skip_ops += _g_subblock_ops while idx < len(block.ops): pre_op_count = len(block.ops) op = block.ops[idx] diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index db2ecc56da..88b5a08422 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -27,6 +27,10 @@ from paddle.distributed.auto_parallel.process_group import get_all_process_group from paddle.fluid.io import is_parameter, is_belong_to_optimizer from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute +__not_shape_var_type__ = [ + core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES +] + def get_logger(log_level, name="auto_parallel"): logger = logging.getLogger(name) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 766974090d..808546482b 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -96,5 +96,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_interface MODULES test_interface) py_test_modules(test_strategy MODULES test_strategy) py_test_modules(test_pass_quantization MODULES test_pass_quantization) + py_test_modules(test_dist_shape MODULES test_dist_shape) + py_test_modules(test_dist_assign MODULES test_dist_assign) + py_test_modules(test_conditional_block_reshard MODULES + test_conditional_block_reshard) endif() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_conditional_block_reshard.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_conditional_block_reshard.py new file mode 100644 index 0000000000..86371cbae6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_conditional_block_reshard.py @@ -0,0 +1,96 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from paddle.static import InputSpec
+from paddle.distributed.fleet import auto
+
+
+class MLPLayer(nn.Layer):
+
+    def __init__(self,
+                 hidden_size=64,
+                 intermediate_size=4 * 64,
+                 initializer_range=0.02):
+        super(MLPLayer, self).__init__()
+        self.norm = nn.LayerNorm(hidden_size, epsilon=1e-5)
+        self.linear0 = nn.Linear(
+            hidden_size,
+            intermediate_size,
+            paddle.ParamAttr(initializer=nn.initializer.Normal(
+                mean=0.0, std=initializer_range)),
+            bias_attr=None)
+        self.linear1 = nn.Linear(
+            intermediate_size,
+            hidden_size,
+            paddle.ParamAttr(initializer=nn.initializer.Normal(
+                mean=0.0, std=initializer_range)),
+            bias_attr=None)
+
+    def forward(self, input):
+        out = self.norm(input)
+
+        auto.shard_tensor(self.linear0.weight, auto.ProcessMesh([0, 1], "x"),
+                          [None, "x"])
+        out = self.linear0(out)
+        out = F.gelu(out, approximate=True)
+
+        auto.shard_tensor(self.linear1.weight, auto.ProcessMesh([0, 1], "x"),
+                          ["x", None])
+        out = self.linear1(out)
+
+        if paddle.mean(out) < 2:
+            out = self.norm(out)
+            out = self.linear0(out)
+            out = F.gelu(out, approximate=True)
+            out = self.linear1(out)
+        else:
+            out = self.norm(out)
+            out = self.linear0(out)
+            out = self.linear1(out)
+
+        return out
+
+
+def loss_fn(predict, label):
+    error_cost = paddle.nn.functional.square_error_cost(predict, label)
+    loss = paddle.mean(error_cost)
+    return loss
+
+
+class TestSubblock(unittest.TestCase):
+
+    def test_subblock(self):
+
+        mlp = MLPLayer()
+
+        strategy = auto.Strategy()
+        strategy.auto_mode = "semi"
+
+        engine = auto.Engine(model=mlp, loss=loss_fn, strategy=strategy)
+
+        input_spec = InputSpec([4, 64], 'float32', 'input')
+        label_spec = InputSpec([4, 1], 'float32', 'label')
+        engine.prepare(inputs_spec=[input_spec],
+                       labels_spec=[label_spec],
+                       mode="predict")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_assign.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_assign.py
new file mode 100644
index 0000000000..b21dd606d8
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_assign.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle
+from paddle.distributed.fleet import auto
+
+paddle.enable_static()
+
+
+def make_program():
+    main_program = paddle.fluid.Program()
+    start_program = paddle.fluid.Program()
+    with paddle.static.program_guard(main_program, start_program):
+
+        x = paddle.static.data(name='x', shape=[4, 4, 8], dtype='float32')
+        y = paddle.static.data(name='y', shape=[4, 4, 8], dtype='float32')
+        auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["d"]),
+                          [None, "d", None])
+
+        z = paddle.add(x, y)
+        paddle.assign(x, output=z)
+
+    return main_program, start_program
+
+
+def parallelizer(program_func, rank):
+    from paddle.distributed.auto_parallel.completion import Completer
+    from paddle.distributed.auto_parallel.partitioner import Partitioner
+    from paddle.distributed.auto_parallel.dist_context import DistributedContext
+
+    main_program, start_program = program_func()
+
+    dist_context = DistributedContext()
+    completer = Completer(dist_context)
+    completer.complete_forward_annotation(main_program)
+    dist_context.block_state.parse_forward_blocks(main_program)
+
+    partitioner = Partitioner(dist_context, rank)
+    dist_main_prog, _, _ = partitioner.partition(main_program, start_program,
+                                                 [])
+
+    return dist_main_prog, dist_context
+
+
+class TestDistAssign(unittest.TestCase):
+
+    def test_dist_assign(self):
+
+        dist_main_prog, dist_context = parallelizer(make_program, 0)
+        ops = dist_main_prog.global_block().ops
+        for op in ops:
+            if op.type == "assign":
+                dist_op = dist_context.get_dist_op_for_program(op)
+                assert dist_op.dist_attr.impl_type == "assign"
+                assert dist_op.dist_attr.impl_idx == 0
+
+                x_name = op.input_arg_names[0]
+                out_name = op.output_arg_names[0]
+                out_var = dist_main_prog.global_block().vars[out_name]
+                dist_out = dist_context.get_dist_tensor_for_program(out_var)
+
+                x_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
+                    x_name)
+                out_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
+                    out_name)
+
+                assert x_dims_mapping == out_dims_mapping
+                assert out_dims_mapping == dist_out.dist_attr.dims_mapping
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_shape.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_shape.py
new file mode 100644
index 0000000000..5e18b7d90c
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_shape.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle
+from paddle.distributed.fleet import auto
+
+paddle.enable_static()
+
+
+def make_program():
+    main_program = paddle.fluid.Program()
+    start_program = paddle.fluid.Program()
+    with paddle.static.program_guard(main_program, start_program):
+        x = paddle.static.data(name='x', shape=[4, 4, 8], dtype='float32')
+        x.stop_gradient = False
+        auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
+                          ["x", None, None])
+        shape = paddle.shape(x)
+    return main_program, start_program
+
+
+def parallelizer(program_func, rank):
+    from paddle.distributed.auto_parallel.completion import Completer
+    from paddle.distributed.auto_parallel.partitioner import Partitioner
+    from paddle.distributed.auto_parallel.dist_context import DistributedContext
+
+    main_program, start_program = program_func()
+
+    dist_context = DistributedContext()
+    completer = Completer(dist_context)
+    completer.complete_forward_annotation(main_program)
+    dist_context.block_state.parse_forward_blocks(main_program)
+
+    partitioner = Partitioner(dist_context, rank)
+    dist_main_prog, _, _ = partitioner.partition(main_program, start_program,
+                                                 [])
+
+    return dist_main_prog, dist_context
+
+
+class TestDistShape(unittest.TestCase):
+
+    def test_dist_shape(self):
+
+        dist_main_prog, dist_context = parallelizer(make_program, 0)
+        ops = dist_main_prog.global_block().ops
+        shape_op = ops[0]
+        dist_op = dist_context.get_dist_op_for_program(shape_op)
+        assert dist_op.dist_attr.impl_type == "shape"
+        assert dist_op.dist_attr.impl_idx == 0
+
+        in_name = shape_op.input_arg_names[0]
+        out_name = shape_op.output_arg_names[0]
+        in_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(in_name)
+        out_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(out_name)
+
+        assert in_dims_mapping == [0, -1, -1]
+        assert out_dims_mapping == [-1]
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab