未验证 提交 89bced5e 编写于 作者: 沉潜的鱼儿's avatar 沉潜的鱼儿 提交者: GitHub

Dist op compatible (#37994)

* dist matmul op compatible

* dist op unittest

* modify dist matmul

* modify dist reshape

* modify dist reshape

* add a space

* add a space

* delete dist matmul op

* modify reshape

* add dist op unittest

* modify dist op unittest
上级 698fca80
...@@ -57,6 +57,9 @@ class DistributedOperatorImpl: ...@@ -57,6 +57,9 @@ class DistributedOperatorImpl:
return self.is_input_compatible(dist_op) and \ return self.is_input_compatible(dist_op) and \
self.is_output_compatible(dist_op) self.is_output_compatible(dist_op)
def is_auto_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.")
def update_dims_mapping(self, dist_op): def update_dims_mapping(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.") raise NotImplementedError("Please Implement this method in Subclass.")
......
...@@ -80,6 +80,32 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -80,6 +80,32 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
return False return False
return True return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
ids_name = op_desc.input('Ids')[0]
w_name = op_desc.input('W')[0]
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard(w_dims_mapping[
-1]):
return False
# Other dimensions must be replicate except the batch dimension
for mapping in ids_dims_mapping[1:]:
if is_dim_shard(mapping):
return False
for mapping in out_dims_mapping[1:]:
if is_dim_shard(mapping):
return False
if w_dims_mapping[-1] != out_dims_mapping[-1]:
return False
if ids_dims_mapping != out_dims_mapping[:len(ids_dims_mapping)]:
return False
return True
def update_dims_mapping(self, dist_op): def update_dims_mapping(self, dist_op):
changed = False changed = False
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
......
...@@ -74,6 +74,36 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -74,6 +74,36 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
return True return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping(
x_shape_name)
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if len(x_dims_mapping) != len(out_dims_mapping) - 1:
return False
if is_dim_shard(out_dims_mapping[-1]):
return False
for idx, item in enumerate(out_dims_mapping[:-2]):
if x_dims_mapping[idx] != item:
return False
if out_dims_mapping[-2] != x_dims_mapping[-1]:
return False
if x_shape_dims_mapping[0] != -1:
return False
if x_shape_dims_mapping[1:] != x_dims_mapping[:]:
return False
return True
def update_dims_mapping(self, dist_op): def update_dims_mapping(self, dist_op):
changed = False changed = False
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
...@@ -201,6 +231,43 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -201,6 +231,43 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
return True return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping(
x_shape_name)
if len(x_dims_mapping) == len(out_dims_mapping) + 2:
if out_dims_mapping[0] != x_dims_mapping[0]:
return False
if x_dims_mapping[-1] != -1 or x_dims_mapping[-2] != -1:
return False
elif len(x_dims_mapping) != len(out_dims_mapping) + 1:
return False
if is_dim_shard(x_dims_mapping[-1]):
return False
for idx, item in enumerate(x_dims_mapping[:-2]):
if out_dims_mapping[idx] != item:
return False
if x_dims_mapping[-2] != out_dims_mapping[-1]:
return False
if x_shape_dims_mapping[0] != -1:
return False
if x_shape_dims_mapping[1:] != x_dims_mapping[:]:
return False
return True
def update_dims_mapping(self, dist_op): def update_dims_mapping(self, dist_op):
changed = False changed = False
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
......
...@@ -71,6 +71,25 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): ...@@ -71,6 +71,25 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
return True return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
axis = op_desc.attr('axis')
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if axis != -1 and axis != len(x_dims_mapping) - 1:
return False
if is_dim_shard(x_dims_mapping[axis]):
return False
if x_dims_mapping != out_dims_mapping:
return False
return True
def update_dims_mapping(self, dist_op): def update_dims_mapping(self, dist_op):
changed = False changed = False
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
......
...@@ -47,6 +47,35 @@ class DistributedTranspose2Impl(DistributedOperatorImpl): ...@@ -47,6 +47,35 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
def is_output_compatible(self, dist_op): def is_output_compatible(self, dist_op):
return True return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
perm = op_desc.attr('axis')
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping(
x_shape_name)
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
new_dims_mapping = [-1 for i in range(len(x_dims_mapping))]
for i in range(len(x_dims_mapping)):
new_dims_mapping[i] = x_dims_mapping[perm[i]]
if len(x_dims_mapping) != len(out_dims_mapping):
return False
if new_dims_mapping != out_dims_mapping:
return False
if x_shape_dims_mapping[0] != -1:
return False
if x_shape_dims_mapping[1:] != x_dims_mapping[:]:
return False
return True
def update_dims_mapping(self, dist_op): def update_dims_mapping(self, dist_op):
changed = False changed = False
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册