Unverified  Commit df944772 authored by zhaoyingli, committed by GitHub

[AutoParallel] adapt for clip (#49249)

* [AutoParallel] adapt for clip

* fix unittest

* enable_static

* fix dist_fill_constant_batch_size_like

* fix process_mesh.shape

* update cond of modifying shape_list
Parent 836a662c
......@@ -39,6 +39,7 @@ from .common import (
)
__op_not_need_param_init__ = ["while", "cond"]
__op_has_shape_attr__ = ["fill_constant_batch_size_like", "fill_constant"]
def prim_operator_data_parallel_functor(ctx, src_op):
......@@ -476,6 +477,26 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])
if (
src_op.has_attr('shape')
and src_op.attr('shape')
and src_op.type in __op_has_shape_attr__
):
shape_list = src_op.attr('shape')
Out_var = main_block._var_recursive(kwargs['Out'][0])
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.process_mesh.shape
assert len(shape_list) == len(dim_mapping)
# modify target shape
for idx, axis in enumerate(dim_mapping):
if axis >= 0:
if len(shape_list) > idx:
shape_list[idx] = (
shape_list[idx] // process_mesh_shape[axis]
)
dist_op_desc._set_attr('shape', shape_list)
# data parallel synchronization for primitive operators
from paddle.incubate.autograd import prim_enabled
......
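The block added above shrinks the target shape attribute to its per-rank size: every dimension whose dims_mapping entry is non-negative is divided by the size of the mesh dimension it is sharded on. A minimal, framework-free sketch of that computation (the helper name local_shape is illustrative, not part of this commit):

# Illustrative helper (assumed name, not part of the diff): compute the
# per-rank shape implied by a dims_mapping over a process mesh.
def local_shape(shape_list, dims_mapping, process_mesh_shape):
    shape_list = list(shape_list)
    for idx, axis in enumerate(dims_mapping):
        # axis >= 0 means this dimension is sharded along mesh dimension `axis`
        if axis >= 0 and len(shape_list) > idx:
            shape_list[idx] = shape_list[idx] // process_mesh_shape[axis]
    return shape_list

# A [4, 5, 6] tensor sharded on its first dimension over a 2-process mesh
# keeps only half of that dimension per rank.
assert local_shape([4, 5, 6], [0, -1, -1], [2]) == [2, 5, 6]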
......@@ -129,24 +129,6 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping
"""
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
dist_op_context = ctx.dist_op_context
src_op = dist_op_context.cur_src_op
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
main_block = dist_op_context.work_block
op = main_block.ops[-1]
assert op.type == "fill_constant_batch_size_like"
# modify shape attr according to how the output is partitioned
out_name = op.output('Out')[0]
dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
process_mesh_shape = op_dist_attr.process_mesh.shape
shape_list = op.attr("shape")
# modify target shape
for idx, axis in enumerate(dims_mapping):
if axis >= 0:
shape_list[idx] = shape_list[idx] // process_mesh_shape[axis]
op._set_attr("shape", shape_list)
@staticmethod
def backward(ctx, *args, **kwargs):
......
......@@ -47,8 +47,23 @@ class DistributedPNorm(DistributedOperatorImplContainer):
register_distributed_operator_impl_container(DistributedPNorm("p_norm"))
# Row Parallel
class DistributedPNormImpl(DistributedOperatorImpl):
# Data Parallel
class DistributedPNormImpl0(DistributedOperatorImpl):
"""
TODO: p_norm scene
1. axis == None, isinstance(p, (int, float)), asvector = True
1.1 x_dims_mapping == [0, -1, -1]
allgather input if it is splited by dp group
1.2 x_dims_mapping == [-1, 0, -1]
allgather, split and concat input if it is splited by mp group
2. isinstance(axis, int), asvector = False
1.1 axis == 0 and x_dims_mapping == [0, -1, -1]
allgather input if it's input[0] is splited by dp group.
1.2 axis == 1 and x_dims_mapping == [-1, 0, -1]
allgather, split and concat input if it's input[1] is splited by mp group
"""
def __init__(self, name):
super().__init__(name)
self._forward_implemented = True
......@@ -57,6 +72,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
axis = op_desc.attr('axis')
asvector = op_desc.attr('asvector')
x_name = op_desc.input('X')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
if is_dim_replicate(x_dims_mapping[0]):
......@@ -65,6 +82,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
for mapping in x_dims_mapping[1:]:
if is_dim_shard(mapping):
return False
if not (axis == -1 and asvector) and not (axis == 0 and not asvector):
return False
return True
def is_output_compatible(self, dist_op):
......@@ -90,6 +109,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
changed = False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
axis = op_desc.attr('axis')
keepdim = op_desc.attr('keepdim')
batch_dim_mappings = []
for arg_name in op_desc.input_arg_names():
......@@ -115,14 +136,22 @@ class DistributedPNormImpl(DistributedOperatorImpl):
):
dims_mapping[0] = compatible_dim_mapping
changed = True
for arg_name in op_desc.output_arg_names():
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if (
len(dims_mapping) >= 1
and compatible_dim_mapping != dims_mapping[0]
):
dims_mapping[0] = compatible_dim_mapping
changed = True
if axis == 0 and not keepdim:
for arg_name in op_desc.output_arg_names():
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if len(dims_mapping) >= 1 and dims_mapping[0] != -1:
dims_mapping[0] = -1
changed = True
else:
for arg_name in op_desc.output_arg_names():
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if (
len(dims_mapping) >= 1
and compatible_dim_mapping != dims_mapping[0]
):
dims_mapping[0] = compatible_dim_mapping
changed = True
return changed
......@@ -350,5 +379,5 @@ class DistributedPNormImpl(DistributedOperatorImpl):
register_distributed_operator_impl(
"p_norm", DistributedPNormImpl("row_parallel")
"p_norm", DistributedPNormImpl0("data_parallel")
)
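As referenced in the class docstring, here is a standalone sketch of the compatibility rule added to is_input_compatible: the data-parallel implementation only claims a p_norm whose input is not sharded on any non-batch dimension and whose reduction is either a full flatten (axis == -1 with asvector=True) or a reduction over the batch axis (axis == 0 with asvector=False). The helper name below is illustrative only, not part of the diff.

# Illustrative-only restatement (assumed helper name) of the added check in
# DistributedPNormImpl0.is_input_compatible.
def pnorm_dp_compatible(axis, asvector, x_dims_mapping):
    # the input may not be sharded on any dimension other than the first
    if any(m >= 0 for m in x_dims_mapping[1:]):
        return False
    # only the full-flatten case and the axis-0 reduction are supported
    if not (axis == -1 and asvector) and not (axis == 0 and not asvector):
        return False
    return True

assert pnorm_dp_compatible(-1, True, [0, -1, -1])       # paddle.norm(x, p=2)
assert pnorm_dp_compatible(0, False, [0, -1, -1])       # paddle.norm(x, p=2, axis=0)
assert not pnorm_dp_compatible(1, False, [0, -1, -1])   # axis=1 falls back to the default impl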
......@@ -1261,6 +1261,8 @@ def set_grad_var_shape(program, dist_context):
"fused_softmax_mask_upper_triangle_grad",
"flatten_contiguous_range_grad",
"relu_grad",
"exp_grad",
"sigmoid_grad",
]
forward_list = [
"reshape2",
......@@ -1279,6 +1281,8 @@ def set_grad_var_shape(program, dist_context):
"fused_softmax_mask_upper_triangle",
"flatten_contiguous_range",
"relu",
"exp",
"sigmoid",
]
if op.type in need_set_shape_list:
for forward_op in block.ops:
......
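For context, need_set_shape_list and forward_list appear to be kept position-aligned, so each newly added grad op type (exp_grad, sigmoid_grad) pairs with the forward op type added at the same position in the second list. A hypothetical illustration of that pairing (not code from utils.py):

# Hypothetical illustration: with the two lists kept in matching order, a grad
# op type maps back to its forward op type by index.
need_set_shape_list = ["relu_grad", "exp_grad", "sigmoid_grad"]
forward_list = ["relu", "exp", "sigmoid"]
grad_to_forward = dict(zip(need_set_shape_list, forward_list))
assert grad_to_forward["sigmoid_grad"] == "sigmoid"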
......@@ -22,7 +22,7 @@ from paddle.fluid.backward import append_backward
paddle.enable_static()
def make_program_dp2():
def make_program_dp2_axis_None():
main_program = paddle.fluid.Program()
start_program = paddle.fluid.Program()
with paddle.static.program_guard(main_program, start_program):
......@@ -35,6 +35,32 @@ def make_program_dp2():
return main_program, start_program, tmp_0
def make_program_dp2_axis_0():
main_program = paddle.fluid.Program()
start_program = paddle.fluid.Program()
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
x.stop_gradient = False
auto.shard_tensor(
x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x", None, None]
)
tmp_0 = paddle.norm(x, p=2, axis=0)
return main_program, start_program, tmp_0
def make_program_dp2_axis_1():
main_program = paddle.fluid.Program()
start_program = paddle.fluid.Program()
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
x.stop_gradient = False
auto.shard_tensor(
x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x", None, None]
)
tmp_0 = paddle.norm(x, p=2, axis=1)
return main_program, start_program, tmp_0
def make_program_serial():
main_program = paddle.fluid.Program()
start_program = paddle.fluid.Program()
......@@ -76,47 +102,71 @@ def parallelizer(program_func, rank):
class TestDistPNorm(unittest.TestCase):
def test_dist_pnorm_dp2(self):
for rank in range(2):
dist_main_prog, dist_context = parallelizer(make_program_dp2, rank)
ops = dist_main_prog.global_block().ops
op_types = []
for op in ops:
op_types.append(op.type)
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
if op.type == "p_norm":
assert op_dist_attr.impl_type == "p_norm"
if op.type in ["p_norm", "p_norm_grad"]:
for input_attr in op_dist_attr.inputs_dist_attrs.values():
assert set(input_attr.dims_mapping) == set([-1])
for output_attr in op_dist_attr.outputs_dist_attrs.values():
assert set(output_attr.dims_mapping) == set([-1])
if op.type == 'c_allgather':
for input_attr in op_dist_attr.inputs_dist_attrs.values():
assert input_attr.dims_mapping[0] == 0
assert set(input_attr.dims_mapping[1:]) == set([-1])
for output_attr in op_dist_attr.outputs_dist_attrs.values():
assert set(output_attr.dims_mapping) == set([-1])
if op.type == 'slice':
for input_attr in op_dist_attr.inputs_dist_attrs.values():
assert set(input_attr.dims_mapping) == set([-1])
for output_attr in op_dist_attr.outputs_dist_attrs.values():
assert output_attr.dims_mapping[0] == 0
assert set(output_attr.dims_mapping[1:]) == set([-1])
assert op_types == [
"c_allgather",
"p_norm",
"fill_constant",
"p_norm_grad",
"slice",
]
def test_dist_pnorm_serial(self):
dist_main_prog, dist_context = parallelizer(make_program_serial, 0)
ops = dist_main_prog.global_block().ops
for op in ops:
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
def prepare(self, func):
self.dist_main_prog, self.dist_context = parallelizer(func, 0)
self.ops = self.dist_main_prog.global_block().ops
def test_dist_pnorm(self):
pass
class TestDistPNormDP(TestDistPNorm):
def test_dist_pnorm(self):
self.prepare(make_program_dp2_axis_None)
self.check_program()
def check_program(self):
op_types = []
for op in self.ops:
op_types.append(op.type)
op_dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
if op.type == "p_norm":
assert op_dist_attr.impl_type == "p_norm"
if op.type in ["p_norm", "p_norm_grad"]:
for input_attr in op_dist_attr.inputs_dist_attrs.values():
assert set(input_attr.dims_mapping) == set([-1])
for output_attr in op_dist_attr.outputs_dist_attrs.values():
assert set(output_attr.dims_mapping) == set([-1])
if op.type == 'c_allgather':
for input_attr in op_dist_attr.inputs_dist_attrs.values():
assert input_attr.dims_mapping[0] == 0
assert set(input_attr.dims_mapping[1:]) == set([-1])
for output_attr in op_dist_attr.outputs_dist_attrs.values():
assert set(output_attr.dims_mapping) == set([-1])
if op.type == 'slice':
for input_attr in op_dist_attr.inputs_dist_attrs.values():
assert set(input_attr.dims_mapping) == set([-1])
for output_attr in op_dist_attr.outputs_dist_attrs.values():
assert output_attr.dims_mapping[0] == 0
assert set(output_attr.dims_mapping[1:]) == set([-1])
assert op_types == [
"c_allgather",
"p_norm",
"fill_constant",
"p_norm_grad",
"slice",
]
class TestDistPNormDP1(TestDistPNormDP):
def test_dist_pnorm(self):
self.prepare(make_program_dp2_axis_0)
self.check_program()
class TestDistPNormSerial(TestDistPNorm):
def test_dist_pnorm(self):
self.prepare(make_program_serial)
for op in self.ops:
op_dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_type == "default"
class TestDistPNormDPAxis1(TestDistPNorm):
def test_dist_pnorm(self):
self.prepare(make_program_dp2_axis_1)
for op in self.ops:
op_dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_type == "default"
......
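check_program above asserts that the data-parallel p_norm is lowered to c_allgather -> p_norm -> fill_constant (presumably initializing the output gradient) -> p_norm_grad -> slice. Below is an illustrative numpy emulation of that sequence for the dp2 case (2 ranks, x sharded on its first dimension); it is a shape-level walk-through under those assumptions, not the generated program.

import numpy as np

# two ranks each hold a [2, 5, 6] shard of the global [4, 5, 6] input
rank_shards = [np.ones((2, 5, 6), dtype="float32") for _ in range(2)]

x_global = np.concatenate(rank_shards, axis=0)   # c_allgather: recover the global tensor
out = np.sqrt((x_global ** 2).sum())             # p_norm: 2-norm over the flattened input
grad_x_global = x_global / out                   # p_norm_grad: d||x||_2 / dx = x / ||x||_2
grad_x_local = grad_x_global[0:2]                # slice: keep only rank 0's shard of the grad

assert out.shape == ()
assert grad_x_local.shape == rank_shards[0].shape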