Unverified commit 4a9895b1, authored by zhaoyingli, committed by GitHub

[AutoParallel] dist_matmul trans_x or trans_y (#45678)

* dist_matmul trans

* update unittest

* update cmakelist
Parent commit: a12c806f
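
The core of this change is a small helper, trans_x_y_dims_mapping, that swaps the last two entries of an operand's dims_mapping in place whenever that operand is transposed, so the existing column/row-parallel shard checks keep reasoning about the effective (non-transposed) layout instead of bailing out when transpose_X/transpose_Y (or trans_x/trans_y for matmul_v2) is set. A minimal, self-contained sketch of that behavior on plain Python lists (illustrative only; the real helper is the one added in the first hunk below):

def trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping):
    # Swap the last two sharding axes in place for each transposed operand.
    if trans_x:
        x_dims_mapping[-1], x_dims_mapping[-2] = x_dims_mapping[
            -2], x_dims_mapping[-1]
    if trans_y:
        y_dims_mapping[-1], y_dims_mapping[-2] = y_dims_mapping[
            -2], y_dims_mapping[-1]


if __name__ == "__main__":
    # Y is stored transposed and sharded as [1, -1]; after the swap the checks
    # see [-1, 1], i.e. the column-parallel layout they already understand.
    x_dims_mapping, y_dims_mapping = [0, -1], [1, -1]
    trans_x_y_dims_mapping(False, True, x_dims_mapping, y_dims_mapping)
    assert x_dims_mapping == [0, -1] and y_dims_mapping == [-1, 1]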
@@ -44,6 +44,15 @@ from ..cost import MatmulV2GradOpCost, MatmulGradOpCost, MulGradOpCost
 from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost, IdentityOpCost
 
 
+def trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping):
+    if trans_x:
+        x_dims_mapping[-1], x_dims_mapping[-2] = x_dims_mapping[
+            -2], x_dims_mapping[-1]
+    if trans_y:
+        y_dims_mapping[-1], y_dims_mapping[-2] = y_dims_mapping[
+            -2], y_dims_mapping[-1]
+
+
 def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
     dist_op_desc = block.append_op(type='nop').desc
     dist_op_desc.copy_from(src_op.desc)
@@ -90,6 +99,8 @@ def _update_dims_mapping_for_matmul(dist_op):
         y_dims_mapping.insert(1, -1)
         out_dims_mapping.insert(out_dims_mapping_len, 0)
 
+    trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
+
     new_x_dims_mapping_len = len(x_dims_mapping)
     new_y_dims_mapping_len = len(y_dims_mapping)
     new_out_dims_mapping_len = len(out_dims_mapping)
@@ -117,6 +128,8 @@ def _update_dims_mapping_for_matmul(dist_op):
             broadcast_out_dims_mapping
         ])
         if compatible_dims_mapping is None:
+            trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping,
+                                   y_dims_mapping)
             return False
 
         for i in range(new_x_dims_mapping_len - 2):
@@ -136,13 +149,6 @@ def _update_dims_mapping_for_matmul(dist_op):
                 out_dims_mapping[i] = compatible_dims_mapping[i]
                 changed = True
 
-    if trans_x:
-        x_dims_mapping[-1], x_dims_mapping[-2] = x_dims_mapping[
-            -2], x_dims_mapping[-1]
-    if trans_y:
-        y_dims_mapping[-1], y_dims_mapping[-2] = y_dims_mapping[
-            -2], y_dims_mapping[-1]
-
     # The following which uses negative index can be work
     # when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2
     dim_changed = compute_compatible_and_update_dim_mapping(
@@ -160,12 +166,7 @@ def _update_dims_mapping_for_matmul(dist_op):
     if dim_changed:
         changed = True
 
-    if trans_x:
-        x_dims_mapping[-1], x_dims_mapping[-2] = x_dims_mapping[
-            -2], x_dims_mapping[-1]
-    if trans_y:
-        y_dims_mapping[-1], y_dims_mapping[-2] = y_dims_mapping[
-            -2], y_dims_mapping[-1]
+    trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
 
     # Remove unnecessary dim mapping to make sure the length of dims_mapping is same as its tensor
     if x_dims_mapping_len == 1:
@@ -188,6 +189,15 @@ def _is_auto_compatible_for_matmul(dist_op):
     x_name = op_desc.input('X')[0]
     y_name = op_desc.input('Y')[0]
     out_name = op_desc.output('Out')[0]
+    trans_x = None
+    trans_y = None
+    if op_desc.type() == "matmul_v2":
+        trans_x = op_desc.attr('trans_x')
+        trans_y = op_desc.attr('trans_y')
+    elif op_desc.type() == "matmul":
+        trans_x = op_desc.attr('transpose_X')
+        trans_y = op_desc.attr('transpose_Y')
+
     # Deep copy these dims_mappings for keeping them unchanged.
     x_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(x_name))
     y_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(y_name))
@@ -203,17 +213,7 @@ def _is_auto_compatible_for_matmul(dist_op):
     if y_dims_mapping_len == 1:
         y_dims_mapping.insert(1, -1)
 
-    # NOTE: Partition is not supported if matmul op has trans.
-    if op_desc.type() == "matmul_v2":
-        if op_desc.attr('trans_x') or op_desc.attr('trans_y'):
-            if x_dims_mapping[-2:] != [-1, -1
-                                       ] or y_dims_mapping[-2:] != [-1, -1]:
-                return False
-    elif op_desc.type() == "matmul":
-        if op_desc.attr('transpose_X') or op_desc.attr('transpose_Y'):
-            if x_dims_mapping[-2:] != [-1, -1
-                                       ] or y_dims_mapping[-2:] != [-1, -1]:
-                return False
+    trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
 
     # Deal with dim > 2 and take care of broadcasting
     if out_dims_mapping_len > 2:
@@ -304,9 +304,23 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
     ), "left operand(X) [{}] of dist matmul should not be parameter".format(
         X_var.name)
 
+    X_var_dims_mapping = dist_attr.get_input_dims_mapping(X_var.name)
     Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name)
     process_mesh_shape = dist_attr.process_mesh.topology
     process_mesh_group = dist_attr.process_mesh.processes
+
+    trans_x = None
+    trans_y = None
+    if backward_op.desc.type() == "matmul_v2_grad":
+        trans_x = backward_op.desc.attr('trans_x')
+        trans_y = backward_op.desc.attr('trans_y')
+    elif backward_op.desc.type() == "matmul_grad":
+        trans_x = backward_op.desc.attr('transpose_X')
+        trans_y = backward_op.desc.attr('transpose_Y')
+
+    if trans_y:
+        trans_x_y_dims_mapping(False, True, None, Y_var_dim_mapping)
+
     # assert len(
     #     Y_var_dim_mapping
     # ) == 2, "dist matmual only support Y operand with 2 dims now but Y({})'s dim is [{}]".format(
@@ -431,9 +445,17 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
     if is_parameter_related(Y_var.name, main_block):
         out_grad_names = [kwargs['Y@GRAD'][0]]
 
+    if trans_x:
+        trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None)
+
     gradient_synchronization(ctx, backward_op, act_grad_names, out_grad_names,
                              rank_id)
 
+    if trans_x:
+        trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None)
+
+    if trans_y:
+        trans_x_y_dims_mapping(False, True, None, Y_var_dim_mapping)
 
 
 def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
@@ -583,8 +605,13 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
         op_dist_attr = dist_op.dist_attr
         x_name = op_desc.input('X')[0]
         y_name = op_desc.input('Y')[0]
-        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
-        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
+        x_dims_mapping = copy.deepcopy(
+            op_dist_attr.get_input_dims_mapping(x_name))
+        y_dims_mapping = copy.deepcopy(
+            op_dist_attr.get_input_dims_mapping(y_name))
+        trans_x = op_desc.attr('transpose_X')
+        trans_y = op_desc.attr('transpose_Y')
+        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
         if is_dim_shard(x_dims_mapping[-1]):
             return False
         if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
@@ -660,10 +687,15 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
         X_var = main_block.var(kwargs['X'][0])
         Weight_var = main_block.var(kwargs['Y'][0])
         Out_var = main_block.var(kwargs['Out'][0])
+        trans_x = src_op.attr("transpose_X")
+        trans_y = src_op.attr("transpose_Y")
 
         # TODO infer logic comm presentation
         matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
             Weight_var.name)[-1]
+        if trans_y:
+            matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
+                Weight_var.name)[-2]
         assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
             matmul_col_dim_mapping)
         process_mesh_shape = op_dist_attr.process_mesh.topology
@@ -723,10 +755,10 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
         check_dtype(intermediate_var_0.dtype, 'dtype',
                     ['float16', 'float32', 'float64'], 'linear')
         attrs = {
-            'transpose_X': False,
-            'transpose_Y': False,
+            'transpose_X': trans_x,
+            'transpose_Y': trans_y,
             'alpha': 1,
-            OP_ROLE_KEY: src_op('op_role')
+            OP_ROLE_KEY: src_op.attr('op_role')
         }
         inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
         matmul_op = main_block.append_op(type='matmul',
@@ -902,8 +934,13 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
         op_dist_attr = dist_op.dist_attr
         x_name = op_desc.input('X')[0]
         y_name = op_desc.input('Y')[0]
-        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
-        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
+        x_dims_mapping = copy.deepcopy(
+            op_dist_attr.get_input_dims_mapping(x_name))
+        y_dims_mapping = copy.deepcopy(
+            op_dist_attr.get_input_dims_mapping(y_name))
+        trans_x = op_desc.attr('transpose_X')
+        trans_y = op_desc.attr('transpose_Y')
+        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
         if is_dim_replicate(x_dims_mapping[-1]):
             return False
         if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
@@ -932,10 +969,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
         if (not self.is_input_compatible(dist_op)) or \
             (not self.is_output_compatible(dist_op)):
             return False
-
         if not _is_auto_compatible_for_matmul(dist_op):
             return False
-
         return True
 
     def update_dims_mapping(self, dist_op):
@@ -983,10 +1018,15 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
         X_var = main_block.var(kwargs['X'][0])
         Weight_var = main_block.var(kwargs['Y'][0])
         Out_var = main_block.var(kwargs['Out'][0])
+        trans_x = src_op.attr('transpose_X')
+        trans_y = src_op.attr('transpose_Y')
 
         # TODO infer logic comm presentation
         matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
             Weight_var.name)[-2]
+        if trans_y:
+            matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
+                Weight_var.name)[-1]
         assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
             matmul_row_dim_mapping)
         process_mesh_shape = op_dist_attr.process_mesh.topology
@@ -1002,8 +1042,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
         check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'],
                     'linear')
         attrs = {
-            'transpose_X': False,
-            'transpose_Y': False,
+            'transpose_X': trans_x,
+            'transpose_Y': trans_y,
             'alpha': 1,
             OP_ROLE_KEY: src_op.attr('op_role')
         }
@@ -1354,8 +1394,13 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
         op_dist_attr = dist_op.dist_attr
         x_name = op_desc.input('X')[0]
         y_name = op_desc.input('Y')[0]
-        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
-        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
+        x_dims_mapping = copy.deepcopy(
+            op_dist_attr.get_input_dims_mapping(x_name))
+        y_dims_mapping = copy.deepcopy(
+            op_dist_attr.get_input_dims_mapping(y_name))
+        trans_x = op_desc.attr('trans_x')
+        trans_y = op_desc.attr('trans_y')
+        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
         if is_dim_shard(x_dims_mapping[-1]):
             return False
         if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
@@ -1382,10 +1427,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
         if (not self.is_input_compatible(dist_op)) or \
             (not self.is_output_compatible(dist_op)):
             return False
-
         if not _is_auto_compatible_for_matmul(dist_op):
             return False
-
         return True
 
     def update_dims_mapping(self, dist_op):
@@ -1433,10 +1476,15 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
         X_var = main_block.var(kwargs['X'][0])
         Weight_var = main_block._var_recursive(kwargs['Y'][0])
         Out_var = main_block.var(kwargs['Out'][0])
+        trans_x = src_op.attr('trans_x')
+        trans_y = src_op.attr('trans_y')
 
         # TODO infer logic comm presentation
         matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
             Weight_var.name)[-1]
+        if trans_y:
+            matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
+                Weight_var.name)[-2]
         assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
             matmul_col_dim_mapping)
         process_mesh_shape = op_dist_attr.process_mesh.topology
@@ -1495,8 +1543,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
         check_dtype(intermediate_var_0.dtype, 'dtype',
                     ['float16', 'float32', 'float64'], 'linear')
         attrs = {
-            'trans_x': False,
-            'trans_y': False,
+            'trans_x': trans_x,
+            'trans_y': trans_y,
             OP_ROLE_KEY: src_op.attr('op_role')
         }
         inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
@@ -1670,8 +1718,13 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
         op_dist_attr = dist_op.dist_attr
         x_name = op_desc.input('X')[0]
         y_name = op_desc.input('Y')[0]
-        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
-        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
+        x_dims_mapping = copy.deepcopy(
+            op_dist_attr.get_input_dims_mapping(x_name))
+        y_dims_mapping = copy.deepcopy(
+            op_dist_attr.get_input_dims_mapping(y_name))
+        trans_x = op_desc.attr('trans_x')
+        trans_y = op_desc.attr('trans_y')
+        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
         if is_dim_replicate(x_dims_mapping[-1]):
             return False
         if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
@@ -1700,10 +1753,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
         if (not self.is_input_compatible(dist_op)) or \
             (not self.is_output_compatible(dist_op)):
             return False
-
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
-
        return True

    def update_dims_mapping(self, dist_op):
@@ -1751,10 +1802,15 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
         X_var = main_block.var(kwargs['X'][0])
         Weight_var = main_block._var_recursive(kwargs['Y'][0])
         Out_var = main_block.var(kwargs['Out'][0])
+        trans_x = src_op.attr('trans_x')
+        trans_y = src_op.attr('trans_y')
 
         # TODO infer logic comm presentation
         matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
             Weight_var.name)[-2]
+        if trans_y:
+            matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
+                Weight_var.name)[-1]
         assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
             matmul_row_dim_mapping)
         process_mesh_shape = op_dist_attr.process_mesh.topology
@@ -1770,8 +1826,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
         check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'],
                     'linear')
         attrs = {
-            'trans_x': False,
-            'trans_y': False,
+            'trans_x': trans_x,
+            'trans_y': trans_y,
             OP_ROLE_KEY: src_op.attr('op_role')
         }
         inputs = {'X': X_var, 'Y': Weight_var}
...
@@ -71,4 +71,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2)
   py_test_modules(test_lr_grad_clip MODULES test_lr_grad_clip)
   py_test_modules(test_quantization MODULES test_quantization)
+  py_test_modules(test_dist_matmul MODULES test_dist_matmul)
 endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
mesh = [[0, 1], [2, 3]]
def init_x_row(trans_x):
if trans_x:
x = paddle.static.data(name='x', shape=[10, 6, 8], dtype='float32')
auto.shard_tensor(x,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [0, 1, -1]
})
return x
else:
x = paddle.static.data(name='x', shape=[10, 8, 6], dtype='float32')
auto.shard_tensor(x,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [0, -1, 1]
})
return x
def init_x_col(trans_x):
if trans_x:
x = paddle.static.data(name='x', shape=[6, 8], dtype='float32')
auto.shard_tensor(x,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [-1, 0]
})
return x
else:
x = paddle.static.data(name='x', shape=[8, 6], dtype='float32')
auto.shard_tensor(x,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [0, -1]
})
return x
def init_y_row(trans_y):
if trans_y:
y = paddle.static.data(name='y', shape=[4, 6], dtype='float32')
auto.shard_tensor(y,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [-1, 1]
})
return y
else:
y = paddle.static.data(name='y', shape=[6, 4], dtype='float32')
auto.shard_tensor(y,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [1, -1]
})
return y
def init_y_col(trans_y):
if trans_y:
y = paddle.static.data(name='y', shape=[4, 6], dtype='float32')
auto.shard_tensor(y,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [1, -1]
})
return y
else:
y = paddle.static.data(name='y', shape=[6, 4], dtype='float32')
auto.shard_tensor(y,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [-1, 1]
})
return y
def matmul_dp2mp2(init_x, init_y, trans_x, trans_y):
main_program = paddle.fluid.Program()
start_program = paddle.fluid.Program()
with paddle.static.program_guard(main_program, start_program):
x = init_x(trans_x)
y = init_y(trans_y)
x.stop_gradient = False
y.stop_gradient = False
out = paddle.fluid.layers.matmul(x,
y,
transpose_x=trans_x,
transpose_y=trans_y)
loss = paddle.mean(out)
return main_program, start_program, loss
def matmulv2_dp2mp2(init_x, init_y, trans_x, trans_y):
main_program = paddle.fluid.Program()
start_program = paddle.fluid.Program()
with paddle.static.program_guard(main_program, start_program):
x = init_x(trans_x)
y = init_y(trans_y)
x.stop_gradient = False
y.stop_gradient = False
out = paddle.matmul(x, y, transpose_x=trans_x, transpose_y=trans_y)
loss = paddle.mean(out)
return main_program, start_program, loss
def parallelizer(program_func, *args, **kwargs):
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.dist_context import DistributedContext
main_program, start_program, loss = program_func(*args, **kwargs)
dist_context = DistributedContext()
completer = Completer(dist_context)
completer.complete_forward_annotation(main_program)
dist_context.block_state.parse_forward_blocks(main_program)
with program_guard(main_program, start_program):
append_backward(loss, distop_context=dist_context.dist_op_context)
completer.complete_backward_annotation(main_program)
dist_context.block_state.parse_backward_blocks(main_program)
partitioner = Partitioner(dist_context, 0)
dist_main_prog, _, _ = partitioner.partition(main_program, start_program,
[])
return dist_main_prog, dist_context
class TestDistMatmul(unittest.TestCase):
def check_col_program(self, main_program, dist_ctx):
# [0, -1] * [-1, 1] --> [0, 1]
ref_ops = [
"c_identity", "matmul", "reduce_mean", "fill_constant",
"reduce_mean_grad", "matmul_grad"
]
ops = []
block = main_program.global_block()
for op in block.ops:
ops.append(op.type)
if op.type == "matmul":
out_name = op.output('Out')[0]
out_var = block.vars[out_name]
op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_idx == 0
assert op_dist_attr.impl_type == "matmul"
out_dims_mapping = op_dist_attr.get_output_dims_mapping(
out_name)
assert out_dims_mapping == [0, 1]
tensor_dist_attr = dist_ctx.get_tensor_dist_attr_for_program(
out_var)
assert tensor_dist_attr.dims_mapping == [0, 1]
if op.type == "matmul_grad":
op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_idx == 0
assert op_dist_attr.impl_type == "matmul"
assert ops == ref_ops
def check_row_program(self, main_program, dist_ctx):
# [0, -1, 1] * [1, -1] --> [0, -1, -1]
ref_ops = [
"matmul", "c_allreduce_sum", "reduce_mean", "fill_constant",
"reduce_mean_grad", "matmul_grad"
]
ops = []
block = main_program.global_block()
for op in block.ops:
ops.append(op.type)
if op.type == "matmul":
out_name = op.output('Out')[0]
out_var = block.vars[out_name]
op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_idx == 1
assert op_dist_attr.impl_type == "matmul"
out_dims_mapping = op_dist_attr.get_output_dims_mapping(
out_name)
assert out_dims_mapping == [0, -1, -1]
tensor_dist_attr = dist_ctx.get_tensor_dist_attr_for_program(
out_var)
assert tensor_dist_attr.dims_mapping == [0, -1, -1]
if op.type == "matmul_grad":
op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_idx == 1
assert op_dist_attr.impl_type == "matmul"
assert ops == ref_ops
class TestDistMatmulCol(TestDistMatmul):
def init(self, trans_x, trans_y):
dist_main_prog, dist_ctx = parallelizer(matmul_dp2mp2, init_x_col,
init_y_col, trans_x, trans_y)
return dist_main_prog, dist_ctx
def test_matmul_col(self):
dist_main_prog, dist_ctx = self.init(False, False)
self.check_col_program(dist_main_prog, dist_ctx)
def test_trans_x(self):
dist_main_prog, dist_ctx = self.init(True, False)
self.check_col_program(dist_main_prog, dist_ctx)
def test_trans_y(self):
dist_main_prog, dist_ctx = self.init(False, True)
self.check_col_program(dist_main_prog, dist_ctx)
def test_trans_x_trans_y(self):
dist_main_prog, dist_ctx = self.init(True, True)
self.check_col_program(dist_main_prog, dist_ctx)
class TestDistMatmulRow(TestDistMatmul):
def init(self, trans_x, trans_y):
dist_main_prog, dist_ctx = parallelizer(matmul_dp2mp2, init_x_row,
init_y_row, trans_x, trans_y)
return dist_main_prog, dist_ctx
def test_matmul_row(self):
dist_main_prog, dist_ctx = self.init(False, False)
self.check_row_program(dist_main_prog, dist_ctx)
def test_trans_x(self):
dist_main_prog, dist_ctx = self.init(True, False)
self.check_row_program(dist_main_prog, dist_ctx)
def test_trans_y(self):
dist_main_prog, dist_ctx = self.init(False, True)
self.check_row_program(dist_main_prog, dist_ctx)
def test_trans_x_trans_y(self):
dist_main_prog, dist_ctx = self.init(True, True)
self.check_row_program(dist_main_prog, dist_ctx)
class TestDistMatmulV2(unittest.TestCase):
def check_col_program(self, main_program, dist_ctx):
# [0, -1] * [-1, 1] --> [0, 1]
ref_ops = [
"c_identity", "matmul_v2", "reduce_mean", "fill_constant",
"reduce_mean_grad", "matmul_v2_grad"
]
ops = []
block = main_program.global_block()
for op in block.ops:
ops.append(op.type)
if op.type == "matmul_v2":
out_name = op.output('Out')[0]
out_var = block.vars[out_name]
op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_idx == 0
assert op_dist_attr.impl_type == "matmul_v2"
out_dims_mapping = op_dist_attr.get_output_dims_mapping(
out_name)
assert out_dims_mapping == [0, 1]
tensor_dist_attr = dist_ctx.get_tensor_dist_attr_for_program(
out_var)
assert tensor_dist_attr.dims_mapping == [0, 1]
if op.type == "matmul_v2_grad":
op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_idx == 0
assert op_dist_attr.impl_type == "matmul_v2"
assert ops == ref_ops
def check_row_program(self, main_program, dist_ctx):
# [0, -1, 1] * [1, -1] --> [0, -1, -1]
ref_ops = [
"matmul_v2", "c_allreduce_sum", "reduce_mean", "fill_constant",
"reduce_mean_grad", "matmul_v2_grad"
]
ops = []
block = main_program.global_block()
for op in block.ops:
ops.append(op.type)
if op.type == "matmul_v2":
out_name = op.output('Out')[0]
out_var = block.vars[out_name]
op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_idx == 1
assert op_dist_attr.impl_type == "matmul_v2"
out_dims_mapping = op_dist_attr.get_output_dims_mapping(
out_name)
assert out_dims_mapping == [0, -1, -1]
tensor_dist_attr = dist_ctx.get_tensor_dist_attr_for_program(
out_var)
assert tensor_dist_attr.dims_mapping == [0, -1, -1]
if op.type == "matmul_v2_grad":
op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_idx == 1
assert op_dist_attr.impl_type == "matmul_v2"
assert ops == ref_ops
class TestDistMatmulV2Col(TestDistMatmulV2):
def init(self, trans_x, trans_y):
dist_main_prog, dist_ctx = parallelizer(matmulv2_dp2mp2, init_x_col,
init_y_col, trans_x, trans_y)
return dist_main_prog, dist_ctx
def test_matmul_col(self):
dist_main_prog, dist_ctx = self.init(False, False)
self.check_col_program(dist_main_prog, dist_ctx)
def test_trans_x(self):
dist_main_prog, dist_ctx = self.init(True, False)
self.check_col_program(dist_main_prog, dist_ctx)
def test_trans_y(self):
dist_main_prog, dist_ctx = self.init(False, True)
self.check_col_program(dist_main_prog, dist_ctx)
def test_trans_x_trans_y(self):
dist_main_prog, dist_ctx = self.init(True, True)
self.check_col_program(dist_main_prog, dist_ctx)
class TestDistMatmulV2Row(TestDistMatmulV2):
def init(self, trans_x, trans_y):
dist_main_prog, dist_ctx = parallelizer(matmulv2_dp2mp2, init_x_row,
init_y_row, trans_x, trans_y)
return dist_main_prog, dist_ctx
def test_matmul_row(self):
dist_main_prog, dist_ctx = self.init(False, False)
self.check_row_program(dist_main_prog, dist_ctx)
def test_trans_x(self):
dist_main_prog, dist_ctx = self.init(True, False)
self.check_row_program(dist_main_prog, dist_ctx)
def test_trans_y(self):
dist_main_prog, dist_ctx = self.init(False, True)
self.check_row_program(dist_main_prog, dist_ctx)
def test_trans_x_trans_y(self):
dist_main_prog, dist_ctx = self.init(True, True)
self.check_row_program(dist_main_prog, dist_ctx)
if __name__ == "__main__":
unittest.main()