未验证 提交 df944772 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] adapt for clip (#49249)

* [AutoParallel] adapt for clip

* fix unittest

* enable_static

* fix dist_fill_constant_batch_size_like

* fix process_mesh.shape

* update cond of modifying shape_list
上级 836a662c
...@@ -39,6 +39,7 @@ from .common import ( ...@@ -39,6 +39,7 @@ from .common import (
) )
__op_not_need_param_init__ = ["while", "cond"] __op_not_need_param_init__ = ["while", "cond"]
__op_has_shape_attr__ = ["fill_constant_batch_size_like", "fill_constant"]
def prim_operator_data_parallel_functor(ctx, src_op): def prim_operator_data_parallel_functor(ctx, src_op):
...@@ -476,6 +477,26 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -476,6 +477,26 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for output_name in src_op.desc.output_names(): for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name]) dist_op_desc.set_output(output_name, kwargs[output_name])
if (
src_op.has_attr('shape')
and src_op.attr('shape')
and src_op.type in __op_has_shape_attr__
):
shape_list = src_op.attr('shape')
Out_var = main_block._var_recursive(kwargs['Out'][0])
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.process_mesh.shape
assert len(shape_list) == len(dim_mapping)
# modify target shape
for idx, axis in enumerate(dim_mapping):
if axis >= 0:
if len(shape_list) > idx:
shape_list[idx] = (
shape_list[idx] // process_mesh_shape[axis]
)
dist_op_desc._set_attr('shape', shape_list)
# data parallel synchronization for primtive operators # data parallel synchronization for primtive operators
from paddle.incubate.autograd import prim_enabled from paddle.incubate.autograd import prim_enabled
......
...@@ -129,24 +129,6 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl): ...@@ -129,24 +129,6 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping kwargs: inputname_mapping & outputname_mapping
""" """
DistributedDefaultImpl0.forward(ctx, *args, **kwargs) DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
dist_op_context = ctx.dist_op_context
src_op = dist_op_context.cur_src_op
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
main_block = dist_op_context.work_block
op = main_block.ops[-1]
assert op.type == "fill_constant_batch_size_like"
# modify shape attr according to how output are partitioned
out_name = op.output('Out')[0]
dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
process_mesh_shape = op_dist_attr.process_mesh.shape
shape_list = op.attr("shape")
# modify target shape
for idx, axis in enumerate(dims_mapping):
if axis >= 0:
shape_list[idx] = shape_list[idx] // process_mesh_shape[axis]
op._set_attr("shape", shape_list)
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
......
...@@ -47,8 +47,23 @@ class DistributedPNorm(DistributedOperatorImplContainer): ...@@ -47,8 +47,23 @@ class DistributedPNorm(DistributedOperatorImplContainer):
register_distributed_operator_impl_container(DistributedPNorm("p_norm")) register_distributed_operator_impl_container(DistributedPNorm("p_norm"))
# Row Parallel # Data Parallel
class DistributedPNormImpl(DistributedOperatorImpl): class DistributedPNormImpl0(DistributedOperatorImpl):
"""
TODO: p_norm scene
1. axis == None, isinstance(p, (int, float)), asvector = True
1.1 x_dims_mapping == [0, -1, -1]
allgather input if it is splited by dp group
1.2 x_dims_mapping == [-1, 0, -1]
allgather, split and concat input if it is splited by mp group
2. isinstance(axis, int), asvector = False
1.1 axis == 0 and x_dims_mapping == [0, -1, -1]
allgather input if it's input[0] is splited by dp group.
1.2 axis == 1 and x_dims_mapping == [-1, 0, -1]
allgather, split and concat input if it's input[1] is splited by mp group
"""
def __init__(self, name): def __init__(self, name):
super().__init__(name) super().__init__(name)
self._forward_implemented = True self._forward_implemented = True
...@@ -57,6 +72,8 @@ class DistributedPNormImpl(DistributedOperatorImpl): ...@@ -57,6 +72,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
def is_input_compatible(self, dist_op): def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
axis = op_desc.attr('axis')
asvector = op_desc.attr('asvector')
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
if is_dim_replicate(x_dims_mapping[0]): if is_dim_replicate(x_dims_mapping[0]):
...@@ -65,6 +82,8 @@ class DistributedPNormImpl(DistributedOperatorImpl): ...@@ -65,6 +82,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
for mapping in x_dims_mapping[1:]: for mapping in x_dims_mapping[1:]:
if is_dim_shard(mapping): if is_dim_shard(mapping):
return False return False
if not (axis == -1 and asvector) and not (axis == 0 and not asvector):
return False
return True return True
def is_output_compatible(self, dist_op): def is_output_compatible(self, dist_op):
...@@ -90,6 +109,8 @@ class DistributedPNormImpl(DistributedOperatorImpl): ...@@ -90,6 +109,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
changed = False changed = False
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
axis = op_desc.attr('axis')
keepdim = op_desc.attr('keepdim')
batch_dim_mappings = [] batch_dim_mappings = []
for arg_name in op_desc.input_arg_names(): for arg_name in op_desc.input_arg_names():
...@@ -115,6 +136,14 @@ class DistributedPNormImpl(DistributedOperatorImpl): ...@@ -115,6 +136,14 @@ class DistributedPNormImpl(DistributedOperatorImpl):
): ):
dims_mapping[0] = compatible_dim_mapping dims_mapping[0] = compatible_dim_mapping
changed = True changed = True
if axis == 0 and not keepdim:
for arg_name in op_desc.output_arg_names():
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if len(dims_mapping) >= 1 and dims_mapping[0] != -1:
dims_mapping[0] = -1
changed = True
else:
for arg_name in op_desc.output_arg_names(): for arg_name in op_desc.output_arg_names():
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if ( if (
...@@ -350,5 +379,5 @@ class DistributedPNormImpl(DistributedOperatorImpl): ...@@ -350,5 +379,5 @@ class DistributedPNormImpl(DistributedOperatorImpl):
register_distributed_operator_impl( register_distributed_operator_impl(
"p_norm", DistributedPNormImpl("row_parallel") "p_norm", DistributedPNormImpl0("data_parallel")
) )
...@@ -1261,6 +1261,8 @@ def set_grad_var_shape(program, dist_context): ...@@ -1261,6 +1261,8 @@ def set_grad_var_shape(program, dist_context):
"fused_softmax_mask_upper_triangle_grad", "fused_softmax_mask_upper_triangle_grad",
"flatten_contiguous_range_grad", "flatten_contiguous_range_grad",
"relu_grad", "relu_grad",
"exp_grad",
"sigmoid_grad",
] ]
forward_list = [ forward_list = [
"reshape2", "reshape2",
...@@ -1279,6 +1281,8 @@ def set_grad_var_shape(program, dist_context): ...@@ -1279,6 +1281,8 @@ def set_grad_var_shape(program, dist_context):
"fused_softmax_mask_upper_triangle", "fused_softmax_mask_upper_triangle",
"flatten_contiguous_range", "flatten_contiguous_range",
"relu", "relu",
"exp",
"sigmoid",
] ]
if op.type in need_set_shape_list: if op.type in need_set_shape_list:
for forward_op in block.ops: for forward_op in block.ops:
......
...@@ -22,7 +22,7 @@ from paddle.fluid.backward import append_backward ...@@ -22,7 +22,7 @@ from paddle.fluid.backward import append_backward
paddle.enable_static() paddle.enable_static()
def make_program_dp2(): def make_program_dp2_axis_None():
main_program = paddle.fluid.Program() main_program = paddle.fluid.Program()
start_program = paddle.fluid.Program() start_program = paddle.fluid.Program()
with paddle.static.program_guard(main_program, start_program): with paddle.static.program_guard(main_program, start_program):
...@@ -35,6 +35,32 @@ def make_program_dp2(): ...@@ -35,6 +35,32 @@ def make_program_dp2():
return main_program, start_program, tmp_0 return main_program, start_program, tmp_0
def make_program_dp2_axis_0():
main_program = paddle.fluid.Program()
start_program = paddle.fluid.Program()
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
x.stop_gradient = False
auto.shard_tensor(
x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x", None, None]
)
tmp_0 = paddle.norm(x, p=2, axis=0)
return main_program, start_program, tmp_0
def make_program_dp2_axis_1():
main_program = paddle.fluid.Program()
start_program = paddle.fluid.Program()
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
x.stop_gradient = False
auto.shard_tensor(
x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x", None, None]
)
tmp_0 = paddle.norm(x, p=2, axis=1)
return main_program, start_program, tmp_0
def make_program_serial(): def make_program_serial():
main_program = paddle.fluid.Program() main_program = paddle.fluid.Program()
start_program = paddle.fluid.Program() start_program = paddle.fluid.Program()
...@@ -76,15 +102,24 @@ def parallelizer(program_func, rank): ...@@ -76,15 +102,24 @@ def parallelizer(program_func, rank):
class TestDistPNorm(unittest.TestCase): class TestDistPNorm(unittest.TestCase):
def test_dist_pnorm_dp2(self): def prepare(self, func):
self.dist_main_prog, self.dist_context = parallelizer(func, 0)
self.ops = self.dist_main_prog.global_block().ops
def test_dist_pnorm(self):
pass
for rank in range(2): class TestDistPNormDP(TestDistPNorm):
dist_main_prog, dist_context = parallelizer(make_program_dp2, rank) def test_dist_pnorm(self):
ops = dist_main_prog.global_block().ops self.prepare(make_program_dp2_axis_None)
self.check_program()
def check_program(self):
op_types = [] op_types = []
for op in ops: for op in self.ops:
op_types.append(op.type) op_types.append(op.type)
op_dist_attr = dist_context.get_op_dist_attr_for_program(op) op_dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
if op.type == "p_norm": if op.type == "p_norm":
assert op_dist_attr.impl_type == "p_norm" assert op_dist_attr.impl_type == "p_norm"
if op.type in ["p_norm", "p_norm_grad"]: if op.type in ["p_norm", "p_norm_grad"]:
...@@ -112,11 +147,26 @@ class TestDistPNorm(unittest.TestCase): ...@@ -112,11 +147,26 @@ class TestDistPNorm(unittest.TestCase):
"slice", "slice",
] ]
def test_dist_pnorm_serial(self):
dist_main_prog, dist_context = parallelizer(make_program_serial, 0) class TestDistPNormDP1(TestDistPNormDP):
ops = dist_main_prog.global_block().ops def test_dist_pnorm(self):
for op in ops: self.prepare(make_program_dp2_axis_0)
op_dist_attr = dist_context.get_op_dist_attr_for_program(op) self.check_program()
class TestDistPNormSerial(TestDistPNorm):
def test_dist_pnorm(self):
self.prepare(make_program_serial)
for op in self.ops:
op_dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_type == "default"
class TestDistPNormDPAxis1(TestDistPNorm):
def test_dist_pnorm(self):
self.prepare(make_program_dp2_axis_1)
for op in self.ops:
op_dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_type == "default" assert op_dist_attr.impl_type == "default"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册