Unverified commit 7c13645d, authored by 沉潜的鱼儿, committed by GitHub

dist matmul op compatible (#37949)

* dist matmul op compatible

* modify common dist op

* modify common

* add a space
Parent commit: 2360406d
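
The hunks below add an is_auto_compatible check to each distributed matmul / matmul_v2 implementation. Every check reasons about dims mappings, where each tensor dimension is mapped either to a process-mesh axis or to -1 when it is replicated. As a point of reference, here is a minimal sketch of that convention (not part of this patch; the helper semantics are inferred from how is_dim_shard and is_dim_replicate are used in the hunks, and the example values are illustrative only):

# Minimal sketch (illustrative only): a dims mapping assigns each tensor dim
# either a process-mesh axis index or -1 for "replicated".
def is_dim_shard(mapping):
    return mapping != -1

def is_dim_replicate(mapping):
    return mapping == -1

# Column-parallel pattern accepted by DistributedMatmulImpl0:
# X keeps its reduction dim replicated, Y is sharded along its column dim,
# and Out is therefore sharded along its last dim on the same mesh axis.
x_dims_mapping = [-1, -1]    # X: [m, k], fully replicated
y_dims_mapping = [-1, 0]     # Y: [k, n], columns on mesh axis 0
out_dims_mapping = [-1, 0]   # Out: [m, n], columns on mesh axis 0

assert not is_dim_shard(x_dims_mapping[-1])
assert is_dim_replicate(y_dims_mapping[0]) and is_dim_shard(y_dims_mapping[-1])
assert is_dim_shard(out_dims_mapping[-1])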
@@ -296,6 +296,83 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
                return False
        return True
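    # Verify that the current dims mappings already describe a column-parallel
    # matmul: X's reduction dim stays replicated, Y is sharded along its column
    # dim, and Out is sharded along its last dim on a matching mesh axis.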
    def is_auto_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        assert len(x_dims_mapping) >= len(
            y_dims_mapping), "now just support x dims > y dims"
        if len(x_dims_mapping) == len(y_dims_mapping) and len(
                x_dims_mapping) == 4:
            if x_dims_mapping[:2] != y_dims_mapping[:2]:
                return False
            if x_dims_mapping[:2] != out_dims_mapping[:2]:
                return False
            x_dims_mapping = x_dims_mapping[-2:]
            y_dims_mapping = y_dims_mapping[-2:]
            out_dims_mapping = out_dims_mapping[-2:]
        elif len(x_dims_mapping) != len(y_dims_mapping) and len(
                x_dims_mapping) == 3:
            if x_dims_mapping[0] != out_dims_mapping[0]:
                return False
            x_dims_mapping = x_dims_mapping[-2:]
            y_dims_mapping = y_dims_mapping[-2:]
            out_dims_mapping = out_dims_mapping[-2:]
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        input_dims_mapping = []
        ordered_input_shard_dims_mapping = []
        for dim in (x_dims_mapping + y_dims_mapping):
            input_dims_mapping.append(dim)
        for item in input_dims_mapping:
            if item not in ordered_input_shard_dims_mapping and item != -1:
                ordered_input_shard_dims_mapping.append(item)
        for mapping in out_dims_mapping:
            if mapping not in input_dims_mapping:
                return False
        if is_dim_shard(x_dims_mapping[0]):
            order_index = 0
            for idx, item in enumerate(out_dims_mapping):
                if item != -1:
                    if item != ordered_input_shard_dims_mapping[order_index]:
                        return False
                    else:
                        order_index += 1
            if order_index != len(ordered_input_shard_dims_mapping):
                return False
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[
                1]):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        if is_dim_shard(x_dims_mapping[0]):
            for mapping in y_dims_mapping[1:]:
                if is_dim_shard(mapping) and mapping == x_dims_mapping[0]:
                    return False
        return True
    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
@@ -510,6 +587,95 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
                return False
        return True
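    # Verify that the current dims mappings already describe a row-parallel
    # matmul: X's reduction dim and Y's row dim are sharded on the same mesh
    # axis, while the column dims of Y and Out stay replicated.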
    def is_auto_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if op_desc.attr('transpose_X') or op_desc.attr('transpose_Y'):
            return False
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        # for gpt2, x dims > y dims, this is a temporary solution
        assert len(x_dims_mapping) >= len(
            y_dims_mapping), "now just support x dims > y dims"
        if len(x_dims_mapping) == len(y_dims_mapping) and len(
                x_dims_mapping) == 4:
            if x_dims_mapping[:2] != y_dims_mapping[:2]:
                return False
            if x_dims_mapping[:2] != out_dims_mapping[:2]:
                return False
            x_dims_mapping = x_dims_mapping[-2:]
            y_dims_mapping = y_dims_mapping[-2:]
            out_dims_mapping = out_dims_mapping[-2:]
        elif len(x_dims_mapping) != len(y_dims_mapping) and len(
                x_dims_mapping) == 3:
            if x_dims_mapping[0] != out_dims_mapping[0]:
                return False
            x_dims_mapping = x_dims_mapping[-2:]
            y_dims_mapping = y_dims_mapping[-2:]
            out_dims_mapping = out_dims_mapping[-2:]
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(y_dims_mapping[
                -1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        x_shard_dim_count = 0
        x_shard_dims = []
        y_shard_dim_count = 0
        y_shard_dims = []
        for dim in x_dims_mapping:
            if is_dim_shard(dim):
                x_shard_dim_count += 1
                x_shard_dims.append(dim)
        for dim in y_dims_mapping:
            if is_dim_shard(dim):
                y_shard_dim_count += 1
                y_shard_dims.append(dim)
        if not x_shard_dims and not y_shard_dims:
            return False
        if x_shard_dims[-1] != y_shard_dims[0]:
            return False
        if x_shard_dim_count == y_shard_dim_count:
            for dim in out_dims_mapping:
                if is_dim_shard(dim):
                    return False
            if x_shard_dims != y_shard_dims:
                return False
        else:
            if x_shard_dim_count < y_shard_dim_count:
                return False
            output_shard_dims = []
            for dim in out_dims_mapping:
                if is_dim_shard(dim):
                    output_shard_dims.append(dim)
            if not output_shard_dims or output_shard_dims[0] != x_shard_dims[0]:
                return False
        return True
    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
@@ -710,6 +876,59 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
        return True
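    # Verify that the matmul is computed on replicated operands: none of the
    # last two dims of X, Y, or Out may be sharded.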
    def is_auto_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        assert len(x_dims_mapping) >= len(
            y_dims_mapping
        ), "now just support x dims > y dims,but x:{0} and y:{1}".format(
            x_dims_mapping, y_dims_mapping)
        if len(x_dims_mapping) == len(y_dims_mapping) and len(
                x_dims_mapping) == 4:
            if x_dims_mapping[:2] != y_dims_mapping[:2]:
                return False
            if x_dims_mapping[:2] != out_dims_mapping[:2]:
                return False
            x_dims_mapping = x_dims_mapping[-2:]
            y_dims_mapping = y_dims_mapping[-2:]
            out_dims_mapping = out_dims_mapping[-2:]
        elif len(x_dims_mapping) != len(y_dims_mapping) and len(
                x_dims_mapping) == 3:
            if x_dims_mapping[0] != out_dims_mapping[0]:
                return False
            x_dims_mapping = x_dims_mapping[-2:]
            y_dims_mapping = y_dims_mapping[-2:]
            out_dims_mapping = out_dims_mapping[-2:]
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping,
                               -2) and is_dim_shard(out_dims_mapping[-2]):
            return False
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_valid_list_index(x_dims_mapping,
                               -2) and is_dim_shard(x_dims_mapping[-2]):
            return False
        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping,
                               -2) and is_dim_shard(y_dims_mapping[-2]):
            return False
        return True
    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
@@ -777,6 +996,86 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
                return False
        return True
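    # matmul_v2 counterpart of the column-parallel check above; it additionally
    # rejects transposed operands via the trans_x / trans_y attributes.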
    def is_auto_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if op_desc.attr('trans_x') or op_desc.attr('trans_y'):
            return False
        assert len(x_dims_mapping) >= len(
            y_dims_mapping), "now just support x dims > y dims"
        if len(x_dims_mapping) == len(y_dims_mapping) and len(
                x_dims_mapping) == 4:
            if x_dims_mapping[:2] != y_dims_mapping[:2]:
                return False
            if x_dims_mapping[:2] != out_dims_mapping[:2]:
                return False
            x_dims_mapping = x_dims_mapping[-2:]
            y_dims_mapping = y_dims_mapping[-2:]
            out_dims_mapping = out_dims_mapping[-2:]
        elif len(x_dims_mapping) != len(y_dims_mapping) and len(
                x_dims_mapping) == 3:
            if x_dims_mapping[0] != out_dims_mapping[0]:
                return False
            x_dims_mapping = x_dims_mapping[-2:]
            y_dims_mapping = y_dims_mapping[-2:]
            out_dims_mapping = out_dims_mapping[-2:]
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        input_dims_mapping = []
        ordered_input_shard_dims_mapping = []
        for dim in (x_dims_mapping + y_dims_mapping):
            input_dims_mapping.append(dim)
        for item in input_dims_mapping:
            if item not in ordered_input_shard_dims_mapping and item != -1:
                ordered_input_shard_dims_mapping.append(item)
        for mapping in out_dims_mapping:
            if mapping not in input_dims_mapping:
                return False
        if is_dim_shard(x_dims_mapping[0]):
            order_index = 0
            for idx, item in enumerate(out_dims_mapping):
                if item != -1:
                    if item != ordered_input_shard_dims_mapping[order_index]:
                        return False
                    else:
                        order_index += 1
            if order_index != len(ordered_input_shard_dims_mapping):
                return False
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[
                1]):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        if is_dim_shard(x_dims_mapping[0]):
            for mapping in y_dims_mapping[1:]:
                if is_dim_shard(mapping) and mapping == x_dims_mapping[0]:
                    return False
        return True
    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
@@ -985,6 +1284,94 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
                return False
        return True
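    # matmul_v2 counterpart of the row-parallel check above, with the same
    # trans_x / trans_y restriction.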
    def is_auto_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if op_desc.attr('trans_x') or op_desc.attr('trans_y'):
            return False
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        assert len(x_dims_mapping) >= len(
            y_dims_mapping), "now just support x dims > y dims"
        if len(x_dims_mapping) == len(y_dims_mapping) and len(
                x_dims_mapping) == 4:
            if x_dims_mapping[:2] != y_dims_mapping[:2]:
                return False
            if x_dims_mapping[:2] != out_dims_mapping[:2]:
                return False
            x_dims_mapping = x_dims_mapping[-2:]
            y_dims_mapping = y_dims_mapping[-2:]
            out_dims_mapping = out_dims_mapping[-2:]
        elif len(x_dims_mapping) != len(y_dims_mapping) and len(
                x_dims_mapping) == 3:
            if x_dims_mapping[0] != out_dims_mapping[0]:
                return False
            x_dims_mapping = x_dims_mapping[-2:]
            y_dims_mapping = y_dims_mapping[-2:]
            out_dims_mapping = out_dims_mapping[-2:]
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(y_dims_mapping[
                -1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        x_shard_dim_count = 0
        x_shard_dims = []
        y_shard_dim_count = 0
        y_shard_dims = []
        for dim in x_dims_mapping:
            if is_dim_shard(dim):
                x_shard_dim_count += 1
                x_shard_dims.append(dim)
        for dim in y_dims_mapping:
            if is_dim_shard(dim):
                y_shard_dim_count += 1
                y_shard_dims.append(dim)
        if not x_shard_dims and not y_shard_dims:
            return False
        if x_shard_dims[-1] != y_shard_dims[0]:
            return False
        if x_shard_dim_count == y_shard_dim_count:
            for dim in out_dims_mapping:
                if is_dim_shard(dim):
                    return False
            if x_shard_dims != y_shard_dims:
                return False
        else:
            if x_shard_dim_count < y_shard_dim_count:
                return False
            output_shard_dims = []
            for dim in out_dims_mapping:
                if is_dim_shard(dim):
                    output_shard_dims.append(dim)
            if not output_shard_dims or output_shard_dims[0] != x_shard_dims[0]:
                return False
        return True
    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
@@ -1183,6 +1570,61 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
        return True
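    # matmul_v2 counterpart of the replicated-matmul check above.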
    def is_auto_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        assert len(x_dims_mapping) >= len(
            y_dims_mapping
        ), "now just support x dims > y dims,but x:{0} and y:{1}".format(
            x_dims_mapping, y_dims_mapping)
        if len(x_dims_mapping) == len(y_dims_mapping) and len(
                x_dims_mapping) == 4:
            if x_dims_mapping[:2] != y_dims_mapping[:2]:
                return False
            if x_dims_mapping[:2] != out_dims_mapping[:2]:
                return False
            x_dims_mapping = x_dims_mapping[-2:]
            y_dims_mapping = y_dims_mapping[-2:]
            out_dims_mapping = out_dims_mapping[-2:]
        elif len(x_dims_mapping) != len(y_dims_mapping) and len(
                x_dims_mapping) == 3:
            if x_dims_mapping[0] != out_dims_mapping[0]:
                return False
            x_dims_mapping = x_dims_mapping[-2:]
            y_dims_mapping = y_dims_mapping[-2:]
            out_dims_mapping = out_dims_mapping[-2:]
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping,
                               -2) and is_dim_shard(out_dims_mapping[-2]):
            return False
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_valid_list_index(x_dims_mapping,
                               -2) and is_dim_shard(x_dims_mapping[-2]):
            return False
        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping,
                               -2) and is_dim_shard(y_dims_mapping[-2]):
            return False
        return True
    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
...