diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index f4c3e5a5800ee798b42873c3300fa45d13d0ef01..3be84c55126bff6f6b5984ef5831e4ce3b5c9199 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -44,6 +44,15 @@ from ..cost import MatmulV2GradOpCost, MatmulGradOpCost, MulGradOpCost from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost, IdentityOpCost +def trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping): + if trans_x: + x_dims_mapping[-1], x_dims_mapping[-2] = x_dims_mapping[ + -2], x_dims_mapping[-1] + if trans_y: + y_dims_mapping[-1], y_dims_mapping[-2] = y_dims_mapping[ + -2], y_dims_mapping[-1] + + def copy_op_with_new_input_output(ctx, block, src_op, **kwargs): dist_op_desc = block.append_op(type='nop').desc dist_op_desc.copy_from(src_op.desc) @@ -90,6 +99,8 @@ def _update_dims_mapping_for_matmul(dist_op): y_dims_mapping.insert(1, -1) out_dims_mapping.insert(out_dims_mapping_len, 0) + trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping) + new_x_dims_mapping_len = len(x_dims_mapping) new_y_dims_mapping_len = len(y_dims_mapping) new_out_dims_mapping_len = len(out_dims_mapping) @@ -117,6 +128,8 @@ def _update_dims_mapping_for_matmul(dist_op): broadcast_out_dims_mapping ]) if compatible_dims_mapping is None: + trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, + y_dims_mapping) return False for i in range(new_x_dims_mapping_len - 2): @@ -136,13 +149,6 @@ def _update_dims_mapping_for_matmul(dist_op): out_dims_mapping[i] = compatible_dims_mapping[i] changed = True - if trans_x: - x_dims_mapping[-1], x_dims_mapping[-2] = x_dims_mapping[ - -2], x_dims_mapping[-1] - if trans_y: - y_dims_mapping[-1], y_dims_mapping[-2] = y_dims_mapping[ - -2], y_dims_mapping[-1] - # The following which uses negative index can be work # when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2 dim_changed = compute_compatible_and_update_dim_mapping( @@ -160,12 +166,7 @@ def _update_dims_mapping_for_matmul(dist_op): if dim_changed: changed = True - if trans_x: - x_dims_mapping[-1], x_dims_mapping[-2] = x_dims_mapping[ - -2], x_dims_mapping[-1] - if trans_y: - y_dims_mapping[-1], y_dims_mapping[-2] = y_dims_mapping[ - -2], y_dims_mapping[-1] + trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping) # Remove unnecessary dim mapping to make sure the length of dims_mapping is same as its tensor if x_dims_mapping_len == 1: @@ -188,6 +189,15 @@ def _is_auto_compatible_for_matmul(dist_op): x_name = op_desc.input('X')[0] y_name = op_desc.input('Y')[0] out_name = op_desc.output('Out')[0] + trans_x = None + trans_y = None + if op_desc.type() == "matmul_v2": + trans_x = op_desc.attr('trans_x') + trans_y = op_desc.attr('trans_y') + elif op_desc.type() == "matmul": + trans_x = op_desc.attr('transpose_X') + trans_y = op_desc.attr('transpose_Y') + # Deep copy these dims_mappings for keeping them unchanged. x_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(x_name)) y_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(y_name)) @@ -203,17 +213,7 @@ def _is_auto_compatible_for_matmul(dist_op): if y_dims_mapping_len == 1: y_dims_mapping.insert(1, -1) - # NOTE: Partition is not supported if matmul op has trans. 
- if op_desc.type() == "matmul_v2": - if op_desc.attr('trans_x') or op_desc.attr('trans_y'): - if x_dims_mapping[-2:] != [-1, -1 - ] or y_dims_mapping[-2:] != [-1, -1]: - return False - elif op_desc.type() == "matmul": - if op_desc.attr('transpose_X') or op_desc.attr('transpose_Y'): - if x_dims_mapping[-2:] != [-1, -1 - ] or y_dims_mapping[-2:] != [-1, -1]: - return False + trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping) # Deal with dim > 2 and take care of broadcasting if out_dims_mapping_len > 2: @@ -304,9 +304,23 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ), "left operand(X) [{}] of dist matmul should not be parameter".format( X_var.name) + X_var_dims_mapping = dist_attr.get_input_dims_mapping(X_var.name) Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name) process_mesh_shape = dist_attr.process_mesh.topology process_mesh_group = dist_attr.process_mesh.processes + + trans_x = None + trans_y = None + if backward_op.desc.type() == "matmul_v2_grad": + trans_x = backward_op.desc.attr('trans_x') + trans_y = backward_op.desc.attr('trans_y') + elif backward_op.desc.type() == "matmul_grad": + trans_x = backward_op.desc.attr('transpose_X') + trans_y = backward_op.desc.attr('transpose_Y') + + if trans_y: + trans_x_y_dims_mapping(False, True, None, Y_var_dim_mapping) + # assert len( # Y_var_dim_mapping # ) == 2, "dist matmual only support Y operand with 2 dims now but Y({})'s dim is [{}]".format( @@ -431,9 +445,17 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): if is_parameter_related(Y_var.name, main_block): out_grad_names = [kwargs['Y@GRAD'][0]] + if trans_x: + trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None) + gradient_synchronization(ctx, backward_op, act_grad_names, out_grad_names, rank_id) + if trans_x: + trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None) + if trans_y: + trans_x_y_dims_mapping(False, True, None, Y_var_dim_mapping) + def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id): @@ -583,8 +605,13 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] y_name = op_desc.input('Y')[0] - x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) - y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) + x_dims_mapping = copy.deepcopy( + op_dist_attr.get_input_dims_mapping(x_name)) + y_dims_mapping = copy.deepcopy( + op_dist_attr.get_input_dims_mapping(y_name)) + trans_x = op_desc.attr('transpose_X') + trans_y = op_desc.attr('transpose_Y') + trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping) if is_dim_shard(x_dims_mapping[-1]): return False if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate( @@ -660,10 +687,15 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): X_var = main_block.var(kwargs['X'][0]) Weight_var = main_block.var(kwargs['Y'][0]) Out_var = main_block.var(kwargs['Out'][0]) + trans_x = src_op.attr("transpose_X") + trans_y = src_op.attr("transpose_Y") # TODO infer logic comm presentation matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping( Weight_var.name)[-1] + if trans_y: + matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping( + Weight_var.name)[-2] assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( matmul_col_dim_mapping) process_mesh_shape = op_dist_attr.process_mesh.topology @@ -723,10 +755,10 @@ class 
DistributedMatmulImpl0(DistributedOperatorImpl): check_dtype(intermediate_var_0.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear') attrs = { - 'transpose_X': False, - 'transpose_Y': False, + 'transpose_X': trans_x, + 'transpose_Y': trans_y, 'alpha': 1, - OP_ROLE_KEY: src_op('op_role') + OP_ROLE_KEY: src_op.attr('op_role') } inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} matmul_op = main_block.append_op(type='matmul', @@ -902,8 +934,13 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] y_name = op_desc.input('Y')[0] - x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) - y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) + x_dims_mapping = copy.deepcopy( + op_dist_attr.get_input_dims_mapping(x_name)) + y_dims_mapping = copy.deepcopy( + op_dist_attr.get_input_dims_mapping(y_name)) + trans_x = op_desc.attr('transpose_X') + trans_y = op_desc.attr('transpose_Y') + trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping) if is_dim_replicate(x_dims_mapping[-1]): return False if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard( @@ -932,10 +969,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): if (not self.is_input_compatible(dist_op)) or \ (not self.is_output_compatible(dist_op)): return False - if not _is_auto_compatible_for_matmul(dist_op): return False - return True def update_dims_mapping(self, dist_op): @@ -983,10 +1018,15 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): X_var = main_block.var(kwargs['X'][0]) Weight_var = main_block.var(kwargs['Y'][0]) Out_var = main_block.var(kwargs['Out'][0]) + trans_x = src_op.attr('transpose_X') + trans_y = src_op.attr('transpose_Y') # TODO infer logic comm presentation matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping( Weight_var.name)[-2] + if trans_y: + matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping( + Weight_var.name)[-1] assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( matmul_row_dim_mapping) process_mesh_shape = op_dist_attr.process_mesh.topology @@ -1002,8 +1042,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear') attrs = { - 'transpose_X': False, - 'transpose_Y': False, + 'transpose_X': trans_x, + 'transpose_Y': trans_y, 'alpha': 1, OP_ROLE_KEY: src_op.attr('op_role') } @@ -1354,8 +1394,13 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] y_name = op_desc.input('Y')[0] - x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) - y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) + x_dims_mapping = copy.deepcopy( + op_dist_attr.get_input_dims_mapping(x_name)) + y_dims_mapping = copy.deepcopy( + op_dist_attr.get_input_dims_mapping(y_name)) + trans_x = op_desc.attr('trans_x') + trans_y = op_desc.attr('trans_y') + trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping) if is_dim_shard(x_dims_mapping[-1]): return False if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate( @@ -1382,10 +1427,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): if (not self.is_input_compatible(dist_op)) or \ (not self.is_output_compatible(dist_op)): return False - if not _is_auto_compatible_for_matmul(dist_op): return False - return True def update_dims_mapping(self, dist_op): @@ -1433,10 +1476,15 @@ class 
DistributedMatmulV2Impl0(DistributedOperatorImpl): X_var = main_block.var(kwargs['X'][0]) Weight_var = main_block._var_recursive(kwargs['Y'][0]) Out_var = main_block.var(kwargs['Out'][0]) + trans_x = src_op.attr('trans_x') + trans_y = src_op.attr('trans_y') # TODO infer logic comm presentation matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping( Weight_var.name)[-1] + if trans_y: + matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping( + Weight_var.name)[-2] assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( matmul_col_dim_mapping) process_mesh_shape = op_dist_attr.process_mesh.topology @@ -1495,8 +1543,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): check_dtype(intermediate_var_0.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear') attrs = { - 'trans_x': False, - 'trans_y': False, + 'trans_x': trans_x, + 'trans_y': trans_y, OP_ROLE_KEY: src_op.attr('op_role') } inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} @@ -1670,8 +1718,13 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] y_name = op_desc.input('Y')[0] - x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) - y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) + x_dims_mapping = copy.deepcopy( + op_dist_attr.get_input_dims_mapping(x_name)) + y_dims_mapping = copy.deepcopy( + op_dist_attr.get_input_dims_mapping(y_name)) + trans_x = op_desc.attr('trans_x') + trans_y = op_desc.attr('trans_y') + trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping) if is_dim_replicate(x_dims_mapping[-1]): return False if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard( @@ -1700,10 +1753,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): if (not self.is_input_compatible(dist_op)) or \ (not self.is_output_compatible(dist_op)): return False - if not _is_auto_compatible_for_matmul(dist_op): return False - return True def update_dims_mapping(self, dist_op): @@ -1751,10 +1802,15 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): X_var = main_block.var(kwargs['X'][0]) Weight_var = main_block._var_recursive(kwargs['Y'][0]) Out_var = main_block.var(kwargs['Out'][0]) + trans_x = src_op.attr('trans_x') + trans_y = src_op.attr('trans_y') # TODO infer logic comm presentation matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping( Weight_var.name)[-2] + if trans_y: + matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping( + Weight_var.name)[-1] assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( matmul_row_dim_mapping) process_mesh_shape = op_dist_attr.process_mesh.topology @@ -1770,8 +1826,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear') attrs = { - 'trans_x': False, - 'trans_y': False, + 'trans_x': trans_x, + 'trans_y': trans_y, OP_ROLE_KEY: src_op.attr('op_role') } inputs = {'X': X_var, 'Y': Weight_var} diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index b78540701eadce8872258377333cfb3829b0c69b..422b3db42c3323875896973592730acf82a7fd8a 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -71,4 +71,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU) 
py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2) py_test_modules(test_lr_grad_clip MODULES test_lr_grad_clip) py_test_modules(test_quantization MODULES test_quantization) + py_test_modules(test_dist_matmul MODULES test_dist_matmul) endif() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_matmul.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..8cf2b47660fe5cbdb44c280ab831099ead66e37a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_matmul.py @@ -0,0 +1,374 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import paddle.distributed.auto_parallel as auto + +from paddle.fluid import program_guard +from paddle.fluid.backward import append_backward +from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr + +paddle.enable_static() + +mesh = [[0, 1], [2, 3]] + + +def init_x_row(trans_x): + if trans_x: + x = paddle.static.data(name='x', shape=[10, 6, 8], dtype='float32') + auto.shard_tensor(x, + dist_attr={ + "process_mesh": mesh, + "dims_mapping": [0, 1, -1] + }) + return x + else: + x = paddle.static.data(name='x', shape=[10, 8, 6], dtype='float32') + auto.shard_tensor(x, + dist_attr={ + "process_mesh": mesh, + "dims_mapping": [0, -1, 1] + }) + return x + + +def init_x_col(trans_x): + if trans_x: + x = paddle.static.data(name='x', shape=[6, 8], dtype='float32') + auto.shard_tensor(x, + dist_attr={ + "process_mesh": mesh, + "dims_mapping": [-1, 0] + }) + return x + else: + x = paddle.static.data(name='x', shape=[8, 6], dtype='float32') + auto.shard_tensor(x, + dist_attr={ + "process_mesh": mesh, + "dims_mapping": [0, -1] + }) + return x + + +def init_y_row(trans_y): + if trans_y: + y = paddle.static.data(name='y', shape=[4, 6], dtype='float32') + auto.shard_tensor(y, + dist_attr={ + "process_mesh": mesh, + "dims_mapping": [-1, 1] + }) + return y + else: + y = paddle.static.data(name='y', shape=[6, 4], dtype='float32') + auto.shard_tensor(y, + dist_attr={ + "process_mesh": mesh, + "dims_mapping": [1, -1] + }) + return y + + +def init_y_col(trans_y): + if trans_y: + y = paddle.static.data(name='y', shape=[4, 6], dtype='float32') + auto.shard_tensor(y, + dist_attr={ + "process_mesh": mesh, + "dims_mapping": [1, -1] + }) + return y + else: + y = paddle.static.data(name='y', shape=[6, 4], dtype='float32') + auto.shard_tensor(y, + dist_attr={ + "process_mesh": mesh, + "dims_mapping": [-1, 1] + }) + return y + + +def matmul_dp2mp2(init_x, init_y, trans_x, trans_y): + main_program = paddle.fluid.Program() + start_program = paddle.fluid.Program() + with paddle.static.program_guard(main_program, start_program): + x = init_x(trans_x) + y = init_y(trans_y) + x.stop_gradient = False + y.stop_gradient = False + out = paddle.fluid.layers.matmul(x, + y, + transpose_x=trans_x, + transpose_y=trans_y) + loss = paddle.mean(out) + return 
main_program, start_program, loss + + +def matmulv2_dp2mp2(init_x, init_y, trans_x, trans_y): + main_program = paddle.fluid.Program() + start_program = paddle.fluid.Program() + with paddle.static.program_guard(main_program, start_program): + x = init_x(trans_x) + y = init_y(trans_y) + x.stop_gradient = False + y.stop_gradient = False + out = paddle.matmul(x, y, transpose_x=trans_x, transpose_y=trans_y) + loss = paddle.mean(out) + return main_program, start_program, loss + + +def parallelizer(program_func, *args, **kwargs): + from paddle.distributed.auto_parallel.completion import Completer + from paddle.distributed.auto_parallel.partitioner import Partitioner + from paddle.distributed.auto_parallel.dist_context import DistributedContext + + main_program, start_program, loss = program_func(*args, **kwargs) + + dist_context = DistributedContext() + completer = Completer(dist_context) + completer.complete_forward_annotation(main_program) + dist_context.block_state.parse_forward_blocks(main_program) + + with program_guard(main_program, start_program): + append_backward(loss, distop_context=dist_context.dist_op_context) + completer.complete_backward_annotation(main_program) + dist_context.block_state.parse_backward_blocks(main_program) + + partitioner = Partitioner(dist_context, 0) + dist_main_prog, _, _ = partitioner.partition(main_program, start_program, + []) + + return dist_main_prog, dist_context + + +class TestDistMatmul(unittest.TestCase): + + def check_col_program(self, main_program, dist_ctx): + # [0, -1] * [-1, 1] --> [0, 1] + ref_ops = [ + "c_identity", "matmul", "reduce_mean", "fill_constant", + "reduce_mean_grad", "matmul_grad" + ] + ops = [] + block = main_program.global_block() + for op in block.ops: + ops.append(op.type) + if op.type == "matmul": + out_name = op.output('Out')[0] + out_var = block.vars[out_name] + op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op) + assert op_dist_attr.impl_idx == 0 + assert op_dist_attr.impl_type == "matmul" + out_dims_mapping = op_dist_attr.get_output_dims_mapping( + out_name) + assert out_dims_mapping == [0, 1] + tensor_dist_attr = dist_ctx.get_tensor_dist_attr_for_program( + out_var) + assert tensor_dist_attr.dims_mapping == [0, 1] + if op.type == "matmul_grad": + op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op) + assert op_dist_attr.impl_idx == 0 + assert op_dist_attr.impl_type == "matmul" + + assert ops == ref_ops + + def check_row_program(self, main_program, dist_ctx): + # [0, -1, 1] * [1, -1] --> [0, -1, -1] + ref_ops = [ + "matmul", "c_allreduce_sum", "reduce_mean", "fill_constant", + "reduce_mean_grad", "matmul_grad" + ] + ops = [] + block = main_program.global_block() + for op in block.ops: + ops.append(op.type) + if op.type == "matmul": + out_name = op.output('Out')[0] + out_var = block.vars[out_name] + op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op) + assert op_dist_attr.impl_idx == 1 + assert op_dist_attr.impl_type == "matmul" + out_dims_mapping = op_dist_attr.get_output_dims_mapping( + out_name) + assert out_dims_mapping == [0, -1, -1] + tensor_dist_attr = dist_ctx.get_tensor_dist_attr_for_program( + out_var) + assert tensor_dist_attr.dims_mapping == [0, -1, -1] + if op.type == "matmul_grad": + op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op) + assert op_dist_attr.impl_idx == 1 + assert op_dist_attr.impl_type == "matmul" + assert ops == ref_ops + + +class TestDistMatmulCol(TestDistMatmul): + + def init(self, trans_x, trans_y): + dist_main_prog, dist_ctx = parallelizer(matmul_dp2mp2, init_x_col, + 
init_y_col, trans_x, trans_y) + return dist_main_prog, dist_ctx + + def test_matmul_col(self): + dist_main_prog, dist_ctx = self.init(False, False) + self.check_col_program(dist_main_prog, dist_ctx) + + def test_trans_x(self): + dist_main_prog, dist_ctx = self.init(True, False) + self.check_col_program(dist_main_prog, dist_ctx) + + def test_trans_y(self): + dist_main_prog, dist_ctx = self.init(False, True) + self.check_col_program(dist_main_prog, dist_ctx) + + def test_trans_x_trans_y(self): + dist_main_prog, dist_ctx = self.init(True, True) + self.check_col_program(dist_main_prog, dist_ctx) + + +class TestDistMatmulRow(TestDistMatmul): + + def init(self, trans_x, trans_y): + dist_main_prog, dist_ctx = parallelizer(matmul_dp2mp2, init_x_row, + init_y_row, trans_x, trans_y) + return dist_main_prog, dist_ctx + + def test_matmul_row(self): + dist_main_prog, dist_ctx = self.init(False, False) + self.check_row_program(dist_main_prog, dist_ctx) + + def test_trans_x(self): + dist_main_prog, dist_ctx = self.init(True, False) + self.check_row_program(dist_main_prog, dist_ctx) + + def test_trans_y(self): + dist_main_prog, dist_ctx = self.init(False, True) + self.check_row_program(dist_main_prog, dist_ctx) + + def test_trans_x_trans_y(self): + dist_main_prog, dist_ctx = self.init(True, True) + self.check_row_program(dist_main_prog, dist_ctx) + + +class TestDistMatmulV2(unittest.TestCase): + + def check_col_program(self, main_program, dist_ctx): + # [0, -1] * [-1, 1] --> [0, 1] + ref_ops = [ + "c_identity", "matmul_v2", "reduce_mean", "fill_constant", + "reduce_mean_grad", "matmul_v2_grad" + ] + ops = [] + block = main_program.global_block() + for op in block.ops: + ops.append(op.type) + if op.type == "matmul_v2": + out_name = op.output('Out')[0] + out_var = block.vars[out_name] + op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op) + assert op_dist_attr.impl_idx == 0 + assert op_dist_attr.impl_type == "matmul_v2" + out_dims_mapping = op_dist_attr.get_output_dims_mapping( + out_name) + assert out_dims_mapping == [0, 1] + tensor_dist_attr = dist_ctx.get_tensor_dist_attr_for_program( + out_var) + assert tensor_dist_attr.dims_mapping == [0, 1] + if op.type == "matmul_v2_grad": + op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op) + assert op_dist_attr.impl_idx == 0 + assert op_dist_attr.impl_type == "matmul_v2" + + assert ops == ref_ops + + def check_row_program(self, main_program, dist_ctx): + # [0, -1, 1] * [1, -1] --> [0, -1, -1] + ref_ops = [ + "matmul_v2", "c_allreduce_sum", "reduce_mean", "fill_constant", + "reduce_mean_grad", "matmul_v2_grad" + ] + ops = [] + block = main_program.global_block() + for op in block.ops: + ops.append(op.type) + if op.type == "matmul_v2": + out_name = op.output('Out')[0] + out_var = block.vars[out_name] + op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op) + assert op_dist_attr.impl_idx == 1 + assert op_dist_attr.impl_type == "matmul_v2" + out_dims_mapping = op_dist_attr.get_output_dims_mapping( + out_name) + assert out_dims_mapping == [0, -1, -1] + tensor_dist_attr = dist_ctx.get_tensor_dist_attr_for_program( + out_var) + assert tensor_dist_attr.dims_mapping == [0, -1, -1] + if op.type == "matmul_v2_grad": + op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op) + assert op_dist_attr.impl_idx == 1 + assert op_dist_attr.impl_type == "matmul_v2" + assert ops == ref_ops + + +class TestDistMatmulV2Col(TestDistMatmulV2): + + def init(self, trans_x, trans_y): + dist_main_prog, dist_ctx = parallelizer(matmulv2_dp2mp2, init_x_col, + init_y_col, trans_x, 
trans_y) + return dist_main_prog, dist_ctx + + def test_matmul_col(self): + dist_main_prog, dist_ctx = self.init(False, False) + self.check_col_program(dist_main_prog, dist_ctx) + + def test_trans_x(self): + dist_main_prog, dist_ctx = self.init(True, False) + self.check_col_program(dist_main_prog, dist_ctx) + + def test_trans_y(self): + dist_main_prog, dist_ctx = self.init(False, True) + self.check_col_program(dist_main_prog, dist_ctx) + + def test_trans_x_trans_y(self): + dist_main_prog, dist_ctx = self.init(True, True) + self.check_col_program(dist_main_prog, dist_ctx) + + +class TestDistMatmulV2Row(TestDistMatmulV2): + + def init(self, trans_x, trans_y): + dist_main_prog, dist_ctx = parallelizer(matmulv2_dp2mp2, init_x_row, + init_y_row, trans_x, trans_y) + return dist_main_prog, dist_ctx + + def test_matmul_row(self): + dist_main_prog, dist_ctx = self.init(False, False) + self.check_row_program(dist_main_prog, dist_ctx) + + def test_trans_x(self): + dist_main_prog, dist_ctx = self.init(True, False) + self.check_row_program(dist_main_prog, dist_ctx) + + def test_trans_y(self): + dist_main_prog, dist_ctx = self.init(False, True) + self.check_row_program(dist_main_prog, dist_ctx) + + def test_trans_x_trans_y(self): + dist_main_prog, dist_ctx = self.init(True, True) + self.check_row_program(dist_main_prog, dist_ctx) + + +if __name__ == "__main__": + unittest.main()
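
Note on the core change (a reviewer-style sketch, not part of the diff): the new trans_x_y_dims_mapping helper swaps the last two entries of an operand's dims_mapping in place when that operand is transposed, so the downstream rules (compatibility checks, the column/row-parallel impls, gradient synchronization) can keep reasoning in the non-transposed layout. Below is a minimal standalone illustration in plain Python; the local names x_dm / y_dm are just example variables, and the mapping values match the column-parallel case set up by init_x_col / init_y_col in the new test.

def trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping):
    # Swap the last two sharding dims of an operand iff it is transposed,
    # mapping the stored (transposed) layout back to the logical matmul layout.
    if trans_x:
        x_dims_mapping[-1], x_dims_mapping[-2] = x_dims_mapping[-2], x_dims_mapping[-1]
    if trans_y:
        y_dims_mapping[-1], y_dims_mapping[-2] = y_dims_mapping[-2], y_dims_mapping[-1]

# Column-parallel Y with transpose_Y=True: the stored mapping is [1, -1]
# (see init_y_col(True) in the test); after the swap the rules see the logical
# [-1, 1], i.e. the last dim sharded, which is what the column-parallel impls
# (DistributedMatmulImpl0 / DistributedMatmulV2Impl0) check for.
x_dm, y_dm = [0, -1], [1, -1]
trans_x_y_dims_mapping(False, True, x_dm, y_dm)
assert x_dm == [0, -1] and y_dm == [-1, 1]

Because the helper mutates the lists in place, the diff applies it symmetrically where the stored attributes must not change: for example, once before and once after the shared logic in _update_dims_mapping_for_matmul (including the early-return path), and around gradient_synchronization in the backward rule, so the op's dims_mappings end up back in their original, transposed layout.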