未验证 提交 6b8efc45 编写于 作者: J JZ-LIANG 提交者: GitHub

[Auto Parallel] Support Primitive operators with Data Parallel (#42709)

* auto parallel support primitive op with data parallel

* add primitive change

* 5 loss 3D cylinder acc aligned

* add unitest
上级 a7778930
...@@ -1250,3 +1250,70 @@ class Completer: ...@@ -1250,3 +1250,70 @@ class Completer:
self._dist_context.set_op_dist_attr_for_program( self._dist_context.set_op_dist_attr_for_program(
op, op_dist_attr) op, op_dist_attr)
continue continue
def complete_prim_annotation(self, serial_main_program=None):
"""
fill default data parallel annotation for program with primitive operators.
Arguments:
serial_main_program: partial annotated serial_main_program.
Returns:
serial_main_program: completed annotated serial_main_program.
"""
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context.serial_main_program = serial_main_program
import time
start_time = time.time()
self._dist_context._is_initialized = True
start_time = time.time()
self._dist_context._init_dist_attr_for_program()
start_time = time.time()
self._init_global_mesh_for_program()
# Do the validation check and amend some completion
start_time = time.time()
self._dist_context.amend_dist_attr_for_program()
self._dist_context.validate_dist_attr_for_program()
def _init_global_mesh_for_program(self):
# Copy the dist tensors and dist ops annotated by users from the default context
# global mesh
from paddle.distributed.auto_parallel.process_group import get_world_process_group
world_ranks = get_world_process_group().ranks
for block in self._dist_context._serial_main_program.blocks:
for tensor in block.vars.values():
# Copy the distributed tensors in the default context
dist_tensor = self._dist_context.get_dist_tensor_for_program(
tensor)
assert dist_tensor is not None
dist_tensor.dist_attr.process_mesh = world_ranks
for op in block.ops:
# Copy the distributed operators in the default context
dist_op = self._dist_context.get_dist_op_for_program(op)
assert dist_op is not None
dist_op.dist_attr.process_mesh = world_ranks
# Find the most compatible implemenetations from the distributed operator
op_dist_impls = find_best_compatible_distributed_operator_impl(
dist_op, fwd=True)
if op_dist_impls is not None:
backup_op_dist_attr = copy.deepcopy(dist_op.dist_attr)
for op_dist_impl in op_dist_impls:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if op_dist_impl.is_auto_compatible(dist_op):
if op_dist_impl.type == "elementwise":
dist_op.dist_attr.impl_type = "default"
else:
dist_op.dist_attr.impl_type = op_dist_impl.type
# op_dist_attr.impl_type = op_dist_impl.type
dist_op.dist_attr.impl_idx = op_dist_impl.idx
break
else:
dist_op.dist_attr = backup_op_dist_attr
...@@ -32,3 +32,4 @@ from . import dist_pnorm ...@@ -32,3 +32,4 @@ from . import dist_pnorm
from . import dist_slice from . import dist_slice
from . import dist_fused_feedforward from . import dist_fused_feedforward
from . import dist_fused_attention from . import dist_fused_attention
from . import dist_reduce_p
...@@ -24,9 +24,10 @@ BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'} ...@@ -24,9 +24,10 @@ BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}
def is_elementwise_op(op_type): def is_elementwise_op(op_type):
for eltwise_op in _g_elementwise_ops: if op_type in _g_elementwise_ops:
if eltwise_op in op_type: return True
return True if "elementwise" in op_type:
return True
return False return False
......
...@@ -18,7 +18,7 @@ from .common import register_distributed_operator_impl_container ...@@ -18,7 +18,7 @@ from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl, is_parameter_related from .common import register_distributed_operator_impl, is_parameter_related
from ..utils import is_dim_shard from ..utils import is_dim_shard
from ..utils import is_dim_replicate from ..utils import is_dim_replicate
from ..utils import is_valid_list_index from ..utils import is_valid_list_index, is_prim_op
from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping
...@@ -35,6 +35,55 @@ from ..utils import _get_comm_group, _get_corresponding_rank ...@@ -35,6 +35,55 @@ from ..utils import _get_comm_group, _get_corresponding_rank
__op_not_need_param_init__ = ["while", "cond"] __op_not_need_param_init__ = ["while", "cond"]
def prim_operator_data_parallel_functor(ctx, src_op):
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.work_block
startup_block = dist_op_context.startup_block
var_name = src_op.output_arg_names[0]
if var_name in ctx.grads_params:
assert var_name not in ctx.synced_gradient, "in primtive mode, grad is already {} synced".format(
var_name)
ctx.synced_gradient.add(var_name)
sync_group = new_process_group(ctx.data_parallel_group)
allreduce_op = main_block.append_op(
type='c_allreduce_sum',
inputs={'X': [var_name]},
outputs={'Out': [var_name]},
attrs={
'ring_id': sync_group.id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward
})
param = ctx.grads_params[var_name]
startup_block = dist_op_context.startup_block
new_op = startup_block.append_op(
type='c_broadcast',
inputs={'X': [param]},
outputs={'Out': [param]},
attrs={
'ring_id': sync_group.id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
grad_var = main_block.var(var_name)
dims_mapping = ctx.get_tensor_dist_attr_for_program(
grad_var).dims_mapping
dist_attr = ctx.get_op_dist_attr_for_program(src_op)
process_mesh = dist_attr.process_mesh
op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
ctx.set_op_dist_attr_for_program(allreduce_op, op_attr)
return
class DistributedDefault(DistributedOperatorImplContainer): class DistributedDefault(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedDefault, self).__init__(op_type) super(DistributedDefault, self).__init__(op_type)
...@@ -292,7 +341,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -292,7 +341,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
@staticmethod @staticmethod
def forward(ctx, *args, **kwargs): def forward(ctx, *args, **kwargs):
dist_op_context = ctx.dist_op_context dist_op_context = ctx.dist_op_context
main_block = dist_op_context.work_block main_block = dist_op_context.work_block
startup_block = dist_op_context.startup_block startup_block = dist_op_context.startup_block
...@@ -315,7 +363,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -315,7 +363,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
output_name) output_name)
# replicate op in dist program # replicate op in dist program
dist_op_desc = main_block.desc.append_op() dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(src_op.desc) dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx) set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
...@@ -323,7 +371,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -323,7 +371,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for output_name in src_op.desc.output_names(): for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name]) dist_op_desc.set_output(output_name, kwargs[output_name])
main_block._sync_with_cpp() # data parallel synchronization for primtive operators
from paddle.incubate.autograd import prim_enabled
if prim_enabled():
assert is_prim_op(src_op)
prim_operator_data_parallel_functor(ctx, src_op)
return
# param initialization sync # param initialization sync
if src_op.type in __op_not_need_param_init__: if src_op.type in __op_not_need_param_init__:
...@@ -373,8 +426,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -373,8 +426,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
op_attr.set_input_dims_mapping(param.name, dims_mapping) op_attr.set_input_dims_mapping(param.name, dims_mapping)
ctx.set_op_dist_attr_for_program(new_op, op_attr) ctx.set_op_dist_attr_for_program(new_op, op_attr)
startup_block._sync_with_cpp()
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl, is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..utils import set_dist_op_desc_original_id
from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
class DistributedReducePrimtive(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedReducePrimtive, self).__init__(op_type)
register_distributed_operator_impl_container(
DistributedReducePrimtive("reduce_p"))
# Batch Dimension Reduce Primitive
class DistributedReducePrimtiveImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedReducePrimtiveImpl0, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
return len(op_desc.input_arg_names()) == 1
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
outputs = op_desc.output_arg_names()
if len(outputs) != 1:
return False
output_name = outputs[0]
output_var = dist_op.serial_op.block.var(output_name)
if output_var.shape != (1, ):
return False
return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
return self.is_input_compatible(dist_op) and self.is_output_compatible(
dist_op)
def update_dims_mapping(self, dist_op):
changed = False
return changed
@staticmethod
def forward(ctx, *args, **kwargs):
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.work_block
startup_block = dist_op_context.startup_block
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
# check validation of inputs / outputs
for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format(
input_name)
assert len(kwargs[input_name]) == len(
src_op.desc.input(input_name)
), "number of tensor for input [{}] is not match".format(input_name)
for output_name in src_op.desc.output_names():
assert output_name in kwargs, "input [{}] is not given".format(
output_name)
assert len(kwargs[output_name]) == len(
src_op.desc.output(output_name)
), "number of tensor for input [{}] is not match".format(
output_name)
# replicate op in dist program
dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])
# batch dimension synchronization
var_name = src_op.output_arg_names[0]
sync_group = new_process_group(ctx.data_parallel_group)
allreduce_op = main_block.append_op(
type='c_allreduce_sum',
inputs={'X': [var_name]},
outputs={'Out': [var_name]},
attrs={
'ring_id': sync_group.id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
# dist attr
var = main_block.var(var_name)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
new_op_attr = OperatorDistributedAttribute()
new_op_attr.process_mesh = op_dist_attr.process_mesh
new_op_attr.set_output_dims_mapping(var.name,
tensor_dist_attr.dims_mapping)
new_op_attr.set_input_dims_mapping(var.name,
tensor_dist_attr.dims_mapping)
ctx.set_op_dist_attr_for_program(allreduce_op, new_op_attr)
@staticmethod
def backward(ctx, *args, **kwargs):
raise RuntimeError(
"primitive operator does NOT have backward function, op type: {}".
format(str(op.type)))
register_distributed_operator_impl(
"reduce_p", DistributedReducePrimtiveImpl0("batch_dimension_reduce_p"))
...@@ -263,6 +263,11 @@ class Partitioner(object): ...@@ -263,6 +263,11 @@ class Partitioner(object):
dist_op_backward_impl.backward( dist_op_backward_impl.backward(
self._dist_context, **kinputs, **koutputs, self._dist_context, **kinputs, **koutputs,
**{"grad_var_to_var": grad_var_to_var}) **{"grad_var_to_var": grad_var_to_var})
elif int(op.attr('op_role')) == 2:
kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_impl = get_distributed_operator_impl_container(
"default").get_impl(0)
dist_op_impl.backward(self._dist_context, **kinputs, **koutputs)
else: else:
raise NotImplementedError( raise NotImplementedError(
"partitioner only support forward op and backward op, but got {}". "partitioner only support forward op and backward op, but got {}".
......
...@@ -1101,6 +1101,10 @@ def is_loss_op(op): ...@@ -1101,6 +1101,10 @@ def is_loss_op(op):
int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss)) int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss))
def is_prim_op(op):
return op.type.endswith("_p")
def get_loss_op(block): def get_loss_op(block):
loss_ops = [] loss_ops = []
for op in block.ops: for op in block.ops:
...@@ -1118,6 +1122,9 @@ def set_var_dist_attr(dist_context, var, dims_mapping, process_mesh, **kwargs): ...@@ -1118,6 +1122,9 @@ def set_var_dist_attr(dist_context, var, dims_mapping, process_mesh, **kwargs):
tensor_dist_attr.dims_mapping = dims_mapping tensor_dist_attr.dims_mapping = dims_mapping
# TODO get global mesh group # TODO get global mesh group
tensor_dist_attr.process_mesh = process_mesh tensor_dist_attr.process_mesh = process_mesh
if "mark_annotated" in kwargs and kwargs["mark_annotated"]:
tensor_dist_attr.mark_annotated("dims_mapping")
tensor_dist_attr.mark_annotated("process_mesh")
dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr) dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr)
return tensor_dist_attr return tensor_dist_attr
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.fluid import program_guard
from paddle.incubate.autograd import prim2orig, enable_prim, prim_enabled
from paddle.fluid.layer_helper import LayerHelper
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import set_var_dist_attr
from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context, set_default_distributed_context
paddle.enable_static()
enable_prim()
nranks = 2
rank = 0
class TestPrimDistOp(unittest.TestCase):
def setUp(self):
self.main_program = paddle.static.Program()
self.startup_program = paddle.static.Program()
self.layer_help = LayerHelper('TestPrimDistOp')
with paddle.static.program_guard(self.main_program,
self.startup_program):
self.init_prog()
def init_prog(self):
# block = self.main_program.global_block()
# block = self.main_program.global_block()
self.w = self.layer_help.create_parameter(
dtype="float", shape=[20], attr=None)
self.w_grad = paddle.static.data(
name='w_grad', shape=[20], dtype='float')
self.tmp1 = paddle.static.data(name='tmp1', shape=[20], dtype='float')
self.tmp2 = paddle.static.data(name='tmp2', shape=[20], dtype='float')
self.batch_reduced = paddle.static.data(
name='batch_reduced', shape=[1], dtype='float')
self.attrs = {}
default_dist_context = get_default_distributed_context()
_global_process_mesh = auto.ProcessMesh(list(range(nranks)))
tensor_dist_attr = set_var_dist_attr(
default_dist_context,
self.tmp1, [-1],
_global_process_mesh,
mark_annotated=True)
tensor_dist_attr = set_var_dist_attr(
default_dist_context,
self.tmp1, [-1],
_global_process_mesh,
mark_annotated=True)
op = self.layer_help.append_op(
type="add_p",
inputs={'X': self.tmp1,
'Y': self.w},
outputs={'Z': self.w_grad},
attrs=self.attrs)
op = self.layer_help.append_op(
type="reduce_p",
inputs={'X': self.tmp2},
outputs={'Y': self.batch_reduced},
attrs={"axis": [0]})
def test_loss_and_grad_allreduce(self):
dist_context = DistributedContext(self.main_program,
self.startup_program)
completer = Completer(dist_context)
completer.complete_prim_annotation(self.main_program)
dist_context.block_state.parse_forward_blocks(self.main_program)
dist_context.block_state.parse_backward_blocks(self.main_program)
dist_context.grads_params = dict()
dist_context.grads_params[self.w_grad.name] = self.w.name
dist_context.synced_gradient = set()
dist_context.data_parallel_group = list(range(nranks))
partitioner = Partitioner(dist_context, rank)
dist_main_prog, dist_startup_prog, _ = partitioner.partition(
self.main_program, self.startup_program, [(self.w, self.w_grad)])
ops = dist_main_prog.global_block().ops
self.assertTrue(ops[1].type == "c_allreduce_sum")
self.assertTrue(ops[3].type == "c_allreduce_sum")
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册