Unverified commit c4fdb057, authored by J JZ-LIANG, committed by GitHub

[Auto Parallel] partitioner refactor (#37853)

Parent b463dff4
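This refactor collapses the Partitioner's old three-step API (transpile_forward / apply_backward / apply_optimize) into a single partition() call, while serial backward generation and the optimizer pass move into AutoParallelizer. The sketch below is a minimal, hypothetical condensation of the new call order in AutoParallelizer._get_dist_program; the helper name build_dist_program and the simplified append_backward / apply_gradients arguments are illustrative assumptions, not part of this change.

import copy

from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.distributed.auto_parallel.completion import (
    complete_annotation, complete_backward_annotation, complete_update_annotation)
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner


def build_dist_program(serial_main_prog, serial_startup_prog, loss, optimizer, rank):
    # 1. annotation completion on the serial program
    dist_context = DistributedContext()
    completed_main = complete_annotation(serial_main_prog, dist_context)

    # 2. serial backward pass, recorded in the dist op context for later matching
    with program_guard(completed_main, serial_startup_prog):
        params_grads = append_backward(
            loss, distop_context=dist_context.dist_op_context)
    complete_backward_annotation(completed_main, dist_context=dist_context)

    # 3. logical partition: one call now returns both programs and the dist grads
    partitioner = Partitioner(dist_context, rank)
    dist_main, dist_startup, dist_params_grads = partitioner.partition(
        completed_main, serial_startup_prog, params_grads)

    # 4. optimizer ops are appended to the *distributed* program afterwards
    with program_guard(dist_main, dist_startup):
        copy.deepcopy(optimizer).apply_gradients(dist_params_grads)
    complete_update_annotation(dist_main, dist_context=dist_context)

    return dist_main, dist_startup, dist_params_grads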
......@@ -404,7 +404,7 @@ class DistributedOperatorContext:
def get_cur_src_op(self):
return self._cur_src_op
def prepare_forward_context(self, src_op):
def prepare_context(self, src_op):
self.set_cur_src_op(src_op)
......@@ -413,6 +413,7 @@ class DistributedOperatorContext:
for input_name in src_op.desc.input_names():
varnames = []
for varname in src_op.desc.input(input_name):
assert varname in self._varname_mapping
varnames.append(self._varname_mapping[varname])
kinputs[input_name] = varnames
......@@ -421,29 +422,8 @@ class DistributedOperatorContext:
for output_name in src_op.desc.output_names():
varnames = []
for varname in src_op.desc.output(output_name):
assert varname in self._varname_mapping
varnames.append(self._varname_mapping[varname])
koutputs[output_name] = varnames
return kinputs, koutputs
def prepare_backward_context(self, backward_op):
self.set_cur_src_op(backward_op)
# build input varname mapping
kinputs = {}
for input_name in backward_op.desc.input_names():
varnames = []
for varname in backward_op.desc.input(input_name):
varnames.append(varname)
kinputs[input_name] = varnames
# build output varname mapping
koutputs = {}
for output_name in backward_op.desc.output_names():
varnames = []
for varname in backward_op.desc.output(output_name):
varnames.append(varname)
koutputs[output_name] = varnames
return kinputs, koutputs
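With this change, the separate prepare_forward_context / prepare_backward_context pair becomes a single prepare_context, which assumes every serial varname an op touches, forward or backward, is already present in _varname_mapping. A stand-alone sketch of the remapping it performs (a hypothetical helper for illustration, not the actual class method):

def remap_op_varnames(op_desc, varname_mapping):
    # map each input/output slot of a serial op desc to its dist varnames
    kinputs = {}
    for input_name in op_desc.input_names():
        kinputs[input_name] = [varname_mapping[v] for v in op_desc.input(input_name)]
    koutputs = {}
    for output_name in op_desc.output_names():
        koutputs[output_name] = [varname_mapping[v] for v in op_desc.output(output_name)]
    return kinputs, koutputs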
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License
from ..dist_attribute import OperatorDistributedAttribute
_g_distributed_operator_impl_registries = {}
......@@ -138,3 +140,46 @@ def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
exact_shape.append(new_shape)
return exact_shape
def set_comm_op_dist_attr_for_program(new_op, process_mesh, tensor_dist_attr,
ctx):
assert process_mesh is not None
assert tensor_dist_attr is not None
new_op_dist_attr = OperatorDistributedAttribute()
new_op_dist_attr.process_mesh = process_mesh
for input_varname in new_op.desc.input_arg_names():
new_op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr)
for output_varname in new_op.desc.output_arg_names():
new_op_dist_attr.set_output_dist_attr(output_varname, tensor_dist_attr)
ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx):
ref_dist_attr = ctx.get_op_dist_attr_for_program(ref_op)
new_op_dist_attr = OperatorDistributedAttribute()
new_op_dist_attr.process_mesh = ref_dist_attr.process_mesh
for input_name in ref_op.input_names:
assert input_name in new_op.input_names
assert len(ref_op.input(input_name)) == 1
assert len(new_op.input(input_name)) == 1
ref_tensor_dist_attr = ref_dist_attr.get_input_dist_attr(
ref_op.input(input_name)[0])
new_op_dist_attr.set_input_dist_attr(
new_op.input(input_name)[0], ref_tensor_dist_attr)
for output_name in ref_op.output_names:
assert output_name in new_op.output_names
assert len(ref_op.output(output_name)) == 1
assert len(new_op.output(output_name)) == 1
ref_tensor_dist_attr = ref_dist_attr.get_output_dist_attr(
ref_op.output(output_name)[0])
new_op_dist_attr.set_output_dist_attr(
new_op.output(output_name)[0], ref_tensor_dist_attr)
ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
......@@ -66,7 +66,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
main_block = dist_op_context.get_dst_main_program().global_block()
startup_block = dist_op_context.get_dst_startup_program().global_block()
src_op = dist_op_context.get_cur_src_op()
varname_mapping = dist_op_context.get_varname_mapping()
rank_id = dist_op_context.get_rank_id()
# check validity of inputs / outputs
......@@ -153,6 +152,31 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
str(backward_op))
rank_id = dist_op_context.get_rank_id()
# check validity of inputs / outputs
for input_name in backward_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format(
input_name)
assert len(kwargs[input_name]) == len(
backward_op.desc.input(input_name)
), "number of tensors for input [{}] does not match".format(input_name)
for output_name in backward_op.desc.output_names():
assert output_name in kwargs, "output [{}] is not given".format(
output_name)
assert len(kwargs[output_name]) == len(
backward_op.desc.output(output_name)
), "number of tensors for output [{}] does not match".format(
output_name)
# replicate op in dist program
dist_op_desc = main_block.desc.append_op()
dist_op_desc.copy_from(backward_op.desc)
for input_name in backward_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in backward_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])
main_block._sync_with_cpp()
# check if need gradient allreduce
# if there is a non-gradient & non-parameter input and its batch dimension is split,
# we need to insert a gradient allreduce for the gradient of the parameter in its output
......
......@@ -16,14 +16,14 @@ from .common import infer_shape
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import register_distributed_operator_impl, set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..dist_attribute import OperatorDistributedAttribute
from ..dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
......@@ -329,9 +329,6 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh,
rank_id)
# check if need gradient allreduce
need_gradient_allreduce = False
assert 'Ids' in kwargs, "input [{}] is not given".format('Ids')
assert 'W' in kwargs, "input [{}] is not given".format('W')
assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD')
......@@ -355,6 +352,84 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
kwargs['W@GRAD'])
Ids_var = main_block.var(kwargs['Ids'][0])
Weight_var = main_block.var(kwargs['W'][0])
Out_grad = main_block.var(kwargs['Out@GRAD'][0])
Weight_grad = main_block.var(kwargs['W@GRAD'][0])
embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
Weight_var.name)[0]
assert embedding_row_dim_mapping >= 0, "row_parallel_embedding's row should be sharded along a specific mesh axis, but got [{}]".format(
embedding_row_dim_mapping)
process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_group = dist_attr.process_mesh.processes
# A generalized method to calculate the embedding offset using a cartesian product
relative_idx = _get_idx_in_axis(process_mesh_group, process_mesh_shape,
embedding_row_dim_mapping, rank_id)
per_part_size = Weight_var.shape[0]
relative_idx = relative_idx * per_part_size
check_variable_and_dtype(
Out_grad, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity')
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_embedding", '@tmp_0@GRAD'])),
dtype=Out_grad.dtype,
shape=Out_grad.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=Out_grad.stop_gradient)
# copy Out_grad's dist_attr to intermediate_var_0's dist_attr
out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name)
assert out_grad_dist_attr is not None
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
out_grad_dist_attr)
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
embedding_row_dim_mapping, rank_id)
group = new_process_group(group_ranks)
c_identity_op = main_block.append_op(
type='c_identity',
inputs={'X': [Out_grad]},
outputs={'Out': intermediate_var_0},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: OpRole.Backward,
})
check_variable_and_dtype(intermediate_var_0, 'x',
['float16', 'float32', 'float64'], 'linear')
check_dtype(intermediate_var_0.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
set_comm_op_dist_attr_for_program(c_identity_op, dist_attr.process_mesh,
out_grad_dist_attr, ctx)
main_block._sync_with_cpp()
c_embedding_grad_op_desc = main_block.desc.append_op()
c_embedding_grad_op_desc.set_type("c_embedding_grad")
c_embedding_grad_op_desc.set_input('Ids', [Ids_var.name])
c_embedding_grad_op_desc.set_input('W', [Weight_var.name])
c_embedding_grad_op_desc.set_input('Out@GRAD',
[intermediate_var_0.name])
c_embedding_grad_op_desc.set_output('W@GRAD', [Weight_grad.name])
c_embedding_grad_op_desc._set_attr('start_index', relative_idx)
c_embedding_grad_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward)
main_block._sync_with_cpp()
c_embedding_grad_op = main_block.ops[-1]
assert c_embedding_grad_op.type == "c_embedding_grad"
naive_copy_op_dist_attr_for_program(c_embedding_grad_op, backward_op,
ctx)
# check if need gradient allreduce
need_gradient_allreduce = False
process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(Ids_var.name)
mesh_shape = process_mesh.topology
......
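The start_index passed to c_embedding_grad above comes from this rank's position along the mesh axis that shards the embedding rows. Below is a self-contained illustration of that offset computation; it is a hypothetical stand-in for _get_idx_in_axis, and the function name and numpy-based lookup are assumptions made for the example.

import numpy as np

def embedding_row_offset(process_group, mesh_shape, shard_axis, rank_id, per_part_size):
    # lay the flat rank list out as a cartesian grid with the mesh's shape,
    # find this rank's coordinate along the sharding axis, and scale it by the
    # number of embedding rows each shard holds
    coords = np.array(process_group).reshape(mesh_shape)
    idx_in_axis = int(np.argwhere(coords == rank_id)[0][shard_axis])
    return idx_in_axis * per_part_size

# e.g. a [2, 4] mesh over ranks 0..7, rows sharded along axis 0, 512 rows per shard:
# embedding_row_offset(list(range(8)), [2, 4], 0, 5, 512) -> 512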
......@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from .common import infer_shape
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
......@@ -33,6 +35,20 @@ from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
def copy_op_with_new_input_output(block, src_op, **kwargs):
dist_op_desc = block.desc.append_op()
dist_op_desc.copy_from(src_op.desc)
for input_name in src_op.desc.input_names():
assert input_name in kwargs
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in src_op.desc.output_names():
assert output_name in kwargs
dist_op_desc.set_output(output_name, kwargs[output_name])
block._sync_with_cpp()
return dist_op_desc
def _update_dims_mapping_for_matmul(dist_op):
changed = False
op_desc = dist_op.serial_op.desc
......@@ -141,15 +157,11 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
if rank_id not in dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id)
# check if need gradient allreduce
need_gradient_allreduce = False
assert 'Y' in kwargs, "input [{}] is not given".format('Y')
assert 'X' in kwargs, "input [{}] is not given".format('X')
assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD')
assert 'Y@GRAD' in kwargs, "output [{}] is not given".format('Y@GRAD')
assert 'X@GRAD' in kwargs, "output [{}] is not given".format('X@GRAD')
assert len(
kwargs['Y']
) == 1, "dist matmul input Y takes 1 variable but got {}".format(
......@@ -166,15 +178,138 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
kwargs['Y@GRAD']
) == 1, "dist matmul output Y@GRAD takes 1 variable but got {}".format(
kwargs['Y@GRAD'])
assert len(
kwargs['X@GRAD']
) == 1, "dist matmul output X@GRAD takes 1 variable but got {}".format(
kwargs['X@GRAD'])
X_var = main_block.var(kwargs['X'][0])
Y_var = main_block.var(kwargs['Y'][0])
Out_grad = main_block.var(kwargs['Out@GRAD'][0])
Y_grad = main_block.var(kwargs['Y@GRAD'][0])
assert not X_var.is_parameter, "left operand(X) [{}] of dist matmul should not be a parameter".format(
X_var.name)
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name)
process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_group = dist_attr.process_mesh.processes
assert len(
Y_var_dim_mapping
) == 2, "dist matmul only supports a Y operand with 2 dims for now, but Y({})'s dims mapping is [{}]".format(
Y_var.name, Y_var_dim_mapping)
Y_var_partitioned = False
for dim in Y_var_dim_mapping:
if dim >= 0 and process_mesh_shape[dim] > 0:
Y_var_partitioned = True
break
if Y_var.is_parameter and Y_var_partitioned:
if Y_var_dim_mapping[0] >= 0:
# row parallel: c_identity + matmul
assert Y_var_dim_mapping[1] < 0
parallel_axis = Y_var_dim_mapping[0]
check_variable_and_dtype(
Out_grad, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_identity')
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_identity", 'tmp'])) + "@GRAD",
dtype=Out_grad.dtype,
shape=Out_grad.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=Out_grad.stop_gradient)
# copy Out_grad's dist_attr to intermediate_var_0's dist_attr
out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name)
assert out_grad_dist_attr is not None
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
out_grad_dist_attr)
group_ranks = _get_comm_group(
process_mesh_group, process_mesh_shape, parallel_axis, rank_id)
group = new_process_group(group_ranks)
c_identity_op = main_block.append_op(
type='c_identity',
inputs={'X': [Out_grad]},
outputs={'Out': intermediate_var_0},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: OpRole.Backward,
})
check_variable_and_dtype(intermediate_var_0, 'x',
['float16', 'float32', 'float64'],
'linear')
check_dtype(intermediate_var_0.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
set_comm_op_dist_attr_for_program(
c_identity_op, dist_attr.process_mesh, out_grad_dist_attr, ctx)
new_kwargs = copy.deepcopy(kwargs)
new_kwargs['Out@GRAD'] = [intermediate_var_0.name]
matmul_op_desc = copy_op_with_new_input_output(
main_block, backward_op, **new_kwargs)
else:
# col parallel: matmul + allreduce
assert Y_var_dim_mapping[0] < 0
parallel_axis = Y_var_dim_mapping[1]
new_kwargs = copy.deepcopy(kwargs)
# NOTE (JZ-LIANG) we should allow the left operand to be empty for matmul grad
has_x_grad = len(kwargs['X@GRAD']) > 0
if has_x_grad:
assert len(kwargs['X@GRAD']) == 1
X_grad = main_block.var(kwargs['X@GRAD'][0])
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_identity", 'tmp'])) + "@GRAD",
dtype=X_grad.dtype,
shape=X_grad.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=X_grad.stop_gradient)
X_grad_dist_attr = dist_attr.get_output_dist_attr(X_grad.name)
assert X_grad_dist_attr is not None
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
X_grad_dist_attr)
new_kwargs['X@GRAD'] = [intermediate_var_0.name]
matmul_op_desc = copy_op_with_new_input_output(
main_block, backward_op, **new_kwargs)
# NOTE (JZ-LIANG) trick to skip one allreduce if the left operand has no grad
if has_x_grad:
group_ranks = _get_comm_group(process_mesh_group,
process_mesh_shape, parallel_axis,
rank_id)
group = new_process_group(group_ranks)
c_allreduce_sum_op = main_block.append_op(
type='c_allreduce_sum',
inputs={'X': [intermediate_var_0.name]},
outputs={'Out': kwargs['X@GRAD']},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: OpRole.Backward
})
set_comm_op_dist_attr_for_program(c_allreduce_sum_op,
dist_attr.process_mesh,
X_grad_dist_attr, ctx)
else:
# replicate
matmul_op_desc = copy_op_with_new_input_output(main_block, backward_op,
**kwargs)
main_block._sync_with_cpp()
# check if need gradient allreduce
need_gradient_allreduce = False
process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(X_var.name)
mesh_shape = process_mesh.topology
......@@ -187,7 +322,6 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
dp_degree = len(group_ranks)
dp_group = new_process_group(group_ranks)
Y_var = main_block.var(kwargs['Y'][0])
if need_gradient_allreduce and Y_var.is_parameter:
Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0])
allreduce_op = main_block.append_op(
......
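The branches above choose the rewrite for matmul_v2_grad from how its Y operand (the parameter) is sharded: row-parallel inserts a c_identity on Out@GRAD before the grad op, column-parallel inserts a c_allreduce_sum on X@GRAD after it, and an unsharded Y simply replicates the grad op. A compact sketch of that decision follows; it is a hypothetical stand-alone helper, not part of the framework API.

def matmul_grad_parallel_mode(y_is_parameter, y_dims_mapping, mesh_shape):
    # a dim mapped to a mesh axis (>= 0) means that dimension of Y is sharded
    partitioned = any(dim >= 0 and mesh_shape[dim] > 0 for dim in y_dims_mapping)
    if not (y_is_parameter and partitioned):
        return "replicate"            # plain copy of matmul_v2_grad
    if y_dims_mapping[0] >= 0:
        return "row_parallel"         # c_identity on Out@GRAD, then matmul_v2_grad
    return "column_parallel"          # matmul_v2_grad, then c_allreduce_sum on X@GRAD

# e.g. matmul_grad_parallel_mode(True, [-1, 0], [2, 4]) -> "column_parallel"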
......@@ -43,7 +43,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
super(DistributedReshapeImpl0, self).__init__()
self._name = name
self._forward_implemented = True
self._backward_implemented = True
self._backward_implemented = False
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
......@@ -200,7 +200,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
super(DistributedReshapeImpl1, self).__init__()
self._name = name
self._forward_implemented = True
self._backward_implemented = True
self._backward_implemented = False
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
......
......@@ -39,7 +39,7 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
super(DistributedSoftmaxImpl, self).__init__()
self._name = name
self._forward_implemented = False
self._backward_implemented = True
self._backward_implemented = False
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
......
......@@ -39,7 +39,7 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
super(DistributedTranspose2Impl, self).__init__()
self._name = name
self._forward_implemented = False
self._backward_implemented = True
self._backward_implemented = False
def is_input_compatible(self, dist_op):
return True
......
......@@ -22,15 +22,15 @@ import subprocess
import logging
import pickle
import time
import paddle
from paddle.distributed.utils import get_logger
from paddle.distributed.fleet import cloud_utils
import paddle.fluid.core as core
from paddle.fluid import program_guard
from .dist_context import DistributedContext
from .dist_context import get_default_distributed_context
from .dist_context import set_default_distributed_context
from .completion import complete_annotation, complete_backward_annotation
from .completion import complete_annotation, complete_backward_annotation, complete_update_annotation
from .partitioner import Partitioner
from .process_group import get_all_process_groups
from .process_group import get_process_group
......@@ -79,6 +79,7 @@ class AutoParallelizer:
self._need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING")
self._need_rank_mapping = True if self._need_rank_mapping and \
self._need_rank_mapping.lower() == 'true' else False
self._pass_context = None
def _remove_distributed_attrs(self, main_program):
suffix = core.kAutoParallelSuffix()
......@@ -90,28 +91,112 @@ class AutoParallelizer:
if suffix in attr_name:
op._remove_attr(attr_name)
def _apply_serial_forward_pass(self, main_program, startup_program):
# apply amp forward pass
if self._dist_strategy.amp:
auto_parallel_amp_pass = new_pass("auto_parallel_amp_pass",
self._dist_strategy.amp_configs)
auto_parallel_amp_pass.apply_forward(main_program, startup_program,
self._pass_context)
# apply recompute forward pass
if self._dist_strategy.recompute:
auto_parallel_recompute_pass = new_pass(
"auto_parallel_recompute_pass",
self._dist_strategy.recompute_configs)
auto_parallel_recompute_pass.apply_forward(
main_program, startup_program, self._pass_context)
def _generate_backward(self, main_program, startup_program, loss,
parameter_list, no_grad_set, callbacks):
# apply recompute backward pass
if self._dist_strategy.recompute:
assert auto_parallel_recompute_pass
auto_parallel_recompute_pass.apply_forward(
main_program, startup_program, parameter_list, no_grad_set,
self._pass_context)
else:
from paddle.fluid.backward import append_backward
with program_guard(main_program, startup_program):
params_grads = append_backward(
loss,
parameter_list,
no_grad_set,
callbacks,
distop_context=self._dist_context.dist_op_context)
complete_backward_annotation(
main_program, dist_context=self._dist_context)
# apply amp backward pass
if self._dist_strategy.amp:
assert auto_parallel_amp_pass
auto_parallel_amp_pass.apply_backward(main_program, startup_program,
self._pass_context)
return params_grads
def _apply_optimize(self, main_program, startup_program, params_grads):
if self._dist_strategy.sharding:
auto_parallel_sharding_pass = new_pass(
"auto_parallel_sharding_pass", self._dist_strategy)
params_grads = auto_parallel_sharding_pass.apply(
main_program, startup_program, params_grads, self._pass_context)
if self._dist_strategy.gradient_merge:
auto_parallel_gradient_merge_pass = new_pass(
"auto_parallel_gradient_merge_pass",
self._dist_strategy.gradient_merge_configs)
auto_parallel_gradient_merge_pass.apply(
main_program, startup_program, params_grads, self._pass_context)
else:
with program_guard(main_program, startup_program):
optimizer = copy.deepcopy(self._optimizer)
optimize_ops = optimizer.apply_gradients(params_grads)
# update completion
complete_update_annotation(
main_program, dist_context=self._dist_context)
return optimize_ops
def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
completed_main_program = None
serial_main_program = self._main_program.clone()
serial_startup_program = self._startup_program.clone()
serial_loss = serial_main_program.global_block().var(self._loss.name)
# generate the serial annotated program
if dist_context is None:
# Annotation completion
self._dist_context = DistributedContext()
_logger.info("Start annotation dist attr.")
completed_main_program = complete_annotation(self._main_program,
completed_main_program = complete_annotation(serial_main_program,
self._dist_context)
else:
completed_main_program = self._main_program
completed_main_program = serial_main_program
self._dist_context = copy.deepcopy(dist_context)
# serial forward pass
self._apply_serial_forward_pass(completed_main_program,
serial_startup_program)
# serial backward pass
params_grads = self._generate_backward(
completed_main_program, serial_startup_program, serial_loss,
self._parameter_list, self._no_grad_set, self._callbacks)
# Logical partition
partitioner = Partitioner(self._dist_strategy, self._dist_context, rank)
dist_main_prog, dist_startup_prog = partitioner.transpile_forward(
completed_main_program, self._startup_program)
dist_params_grads = partitioner.apply_backward(
self._loss, completed_main_program, self._startup_program,
dist_main_prog, dist_startup_prog)
dist_optimize_ops = partitioner.apply_optimize(
copy.deepcopy(self._optimizer), dist_params_grads, dist_main_prog,
dist_startup_prog)
rank = paddle.distributed.get_rank()
partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
completed_main_program, serial_startup_program, params_grads)
# TODO refactor the placement of optimizer
# generate optimize program
dist_optimize_ops = self._apply_optimize(
dist_main_prog, dist_startup_prog, dist_params_grads)
set_grad_var_shape(dist_main_prog, self._dist_context)
......@@ -133,13 +218,15 @@ class AutoParallelizer:
loss,
startup_program,
parameter_list=None,
no_grad_set=None):
no_grad_set=None,
callbacks=None):
assert startup_program is not None
self._loss = loss
self._startup_program = startup_program
self._main_program = loss.block.program
self._parameter_list = parameter_list
self._no_grad_set = no_grad_set
self._callbacks = callbacks
if self._enable_auto_mapping and self._need_rank_mapping:
# Do the mapping pass before parallelization
......@@ -156,6 +243,7 @@ class AutoParallelizer:
self._optimizer, self._cluster)
planner = Planner(
serial_program_info,
self,
algorithm_config={"name": "mcmc",
"max_search_times": 5})
dist_context, _ = planner.search()
......@@ -262,6 +350,7 @@ class AutoParallelizer:
cluster=self._cluster)
planner = Planner(
serial_program_info,
self,
algorithm_config={
"name": "mcmc",
"max_search_times": 5
......@@ -303,3 +392,14 @@ class AutoParallelizer:
self._remove_distributed_attrs(dist_main_prog)
return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == "_main_program" or k == "_startup_program" or k == "_dist_context" or k == "_fleet" or k == "_loss":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
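AutoParallelizer defines __deepcopy__ because the MCMC planner deep-copies the parallelizer (see get_all_distributed_main_program below) to replay _get_dist_program per rank; sharing the programs, dist context, fleet handle and loss by reference keeps that copy cheap. The same pattern in isolation (an illustrative class, not part of this change):

import copy

class SelectiveDeepcopy:
    # members shared by reference instead of being deep-copied
    _SHARED = ("_main_program", "_startup_program", "_dist_context", "_fleet", "_loss")

    def __deepcopy__(self, memo):
        result = self.__class__.__new__(self.__class__)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, v if k in self._SHARED else copy.deepcopy(v, memo))
        return result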
......@@ -20,18 +20,11 @@ from paddle.fluid import core
from paddle.fluid import framework as framework
from paddle.fluid import core, unique_name
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid.backward import append_backward, _some_in_set_, _append_grad_suffix_
from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
from paddle.fluid.clip import GradientClipBase, GradientClipByNorm, error_clip_callback, append_gradient_clip_ops, ClipGradByGlobalNorm
from paddle.distributed.fleet.base.distributed_strategy import DistributedStrategy
from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op, is_optimizer_op
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from .dist_attribute import OperatorDistributedAttribute
from .process_group import new_process_group
from .utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.completion import complete_backward_annotation, complete_update_annotation
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op
__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]
......@@ -48,279 +41,85 @@ class Partitioner(object):
2. partition var: if a var is sharded, modify the shape of var according to its shard annotation
Partitioner is supposed to be called by the auto parallel framework, and not supposed to be called directly by the user.
Example:
....
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
# create serial program with forward only
with static.program_guard(serial_main_program, serial_start_program):
model = create_model(config)
tokens = static.data(name="tokens", shape=[batch_size, sequence_len], dtype='int64')
labels = static.data(name="labels", shape=[batch_size, sequence_len], dtype='int64')
loss_mask = static.data(name="loss_mask", shape=[batch_size, sequence_len], dtype='int64')
preds = model(tokens)
loss = criterion(preds, labels, loss_mask)
# auto completion
auto.ProcessMesh(shape=[2, 4], process_group=[0, 1, 2, 3, 4, 5, 6, 7])
annotated_main_program = auto.complete_annotation(serial_main_program)
dist_context = get_default_distributed_context()
# distributed strategy & rank info
rank_id = paddle.distributed.get_rank()
dist_strategy = fleet.DistributedStrategy()
# create partitioner
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
# create dist program with forward only
# for distributed inference, using partitioned_main_prog from here
partitioned_main_prog, partitioned_startup_prog = partitioner.transpile_forward(complete_train_program, start_program)
# create dist program with forward/backward/update
# for distributed training, using partitioned_main_prog from here
dist_params_grads = partitioner.apply_backward(loss, complete_train_program, start_program, partitioned_main_prog, partitioned_startup_prog)
optimizer = paddle.fluid.optimizer.AdamOptimizer(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, partitioned_main_prog, partitioned_startup_prog)
"""
def __init__(self, dist_strategy, dist_context, rank_id=0):
def __init__(self, dist_context, rank_id=0):
"""
Args:
dist_strategy (paddle.fleet.distributed_strategy): used to determine the user defined distributed strategy.
dist_context (paddle.fluid.DistributedContext): used to access the distributed_attr of var & op, every Partitioner object could maintain its own DistributedContext member, and partition program base on that shard scenario.
rank_id (int): global rank id to which the partitioned distributed program belongs.
"""
if not isinstance(dist_strategy, DistributedStrategy):
raise TypeError(
"dist_strategy be paddle.fleet.base.DistributedStrategy, got %s here"
% type(dist_strategy))
if not isinstance(dist_context, DistributedContext):
raise TypeError(
"dist_context be paddle.fluid.DistributedContext, got %s here" %
type(dist_context))
self._dist_strategy = dist_strategy
self._dist_context = dist_context
self._rank_id = rank_id
self._serial2dist_varname_mapping = {}
self._dist_varname_suffix = ""
# TODO if there is some dist op that is not compatible
# with auto_backward in forward, the following flag
# should be set to False
self._compatible_with_auto_backward = True
def transpile_forward(self, serial_main_program, serial_startup_program):
"""
takes serial forward programs with shard annotations and creates new distributed forward programs based on the serial ones.
instead of modifying the input programs in place, this function preserves the inputs and creates new programs for output.
besides replacing each serial op with its dist op, if the user has defined other strategies in fleet.distributed_strategy, and if
those strategies need to transpile (modify) the forward network program, those forward program modifications should also be done within this
function in the auto parallel scenario, in order to facilitate distributed inference/evaluation, which needs to DECOUPLE strategy-specific forward transpilation from fleet.distributed_optimizer.minimize().
by now, the fleet.distributed_strategy options that need to transpile the forward program are the following:
1. (optimizer) sharding
Args:
main_program (paddle.fluid.framework.program): serial main program with forward network only
startup_program (paddle.fluid.framework.program): serial startup program with forward network only
return:
main_program (paddle.fluid.framework.program): distributed main program with forward network only
startup_program (paddle.fluid.framework.program): distributed startup program with forward network only
"""
dist_main_program, dist_startup_program = self.transpile_forward_impl(
serial_main_program, serial_startup_program)
return dist_main_program, dist_startup_program
def apply_backward(self,
serial_loss,
serial_main_program,
serial_startup_program,
dist_main_program,
dist_startup_program,
parameter_list=None,
no_grad_set=None,
callbacks=None):
"""
A complete training neural network is made up of forward and backward propagation.
This function is to generate the dist backward program for the distributed forward program.
By now, the automatic backward mechanism in the paddle framework might NOT handle the backward generation for
some dist ops correctly, so we now have two ways to generate the backward program:
1. dist_forward_program --> auto_backward --> dist_backward_program (if auto_backward can handle all dist ops)
2. serial_forward_program --> auto_backward --> serial_backward_program --> dist_op_backward_transpile --> dist_backward_program (if auto_backward cannot handle all dist ops)
the backward program is appended to the input dist program in place.
Args:
serial_loss (Variable) the loss in serial program that to be minimized
serial_main_program (paddle.fluid.framework.program): serial main program with forward network only
serial_startup_program (paddle.fluid.framework.program): serial startup program with forward network only
dist_main_program (paddle.fluid.framework.program): dist main program with forward network only
dist_startup_program (paddle.fluid.framework.program): dist startup program with forward network only
parameter_list (Iterable, optional): Iterable of ``Variable`` or ``Variable.name`` to update
to minimize ``loss``. The default value is None, at this time all parameters
will be updated.
no_grad_set (set, optional): Set of ``Variable`` or ``Variable.name`` that don't need
to be updated. The default value is None.
callbacks (list, optional): list of callable objects to run when appending backward
operator for one parameter. The default value is None.
return:
params_grads (list) list of tuple that contain param and its grad variable
"""
params_grads = self.apply_backward_impl(
serial_loss, serial_main_program, serial_startup_program,
dist_main_program, dist_startup_program)
return params_grads
def apply_optimize(self, user_define_optimizer, params_grads,
dist_main_program, dist_startup_program):
"""
append update-related ops to the program: clip, weight decay, optimize ops
filter optimize ops if sharding is enabled
naive gradient synchronization before update
Args:
user_define_optimizer (paddle.fluid.optimizer):
params_grads (list) list of tuple that contain param and its grad variable
dist_main_program (paddle.fluid.framework.program): dist main program with forward & backward network
dist_startup_program (paddle.fluid.framework.program): dist startup program with forward & backward network
"""
optimize_ops = self.apply_optimize_impl(user_define_optimizer,
params_grads, dist_main_program,
dist_startup_program)
return optimize_ops
def transpile_forward_impl(self, main_program, startup_program):
if not isinstance(main_program, (Program)):
raise TypeError(
"dist_strategy be paddle.fluid.framework.program, got %s here" %
type(main_program))
def partition(self, serial_main_program, serial_startup_program,
params_grads):
if not isinstance(startup_program, (Program)):
if not isinstance(serial_main_program, (Program)):
raise TypeError(
"dist_context be paddle.fluid.framework.program, got %s here" %
type(startup_program))
"main_program be paddle.fluid.framework.program, got %s here" %
type(serial_main_program))
# check if shard annotated serial program valid
if not self._is_valid_annotated_program(main_program):
if not self._is_valid_annotated_program(serial_main_program):
raise RuntimeError(
"Not all vars or ops are annotated in main program !")
# dist op & partition vars
new_main_prog, new_startup_program = self._dist_var_op_forward_transpile(
main_program, startup_program)
# Sharding
if self._dist_strategy.sharding:
new_main_prog, new_startup_program = self._sharding_forward_transpile(
new_main_prog, new_startup_program)
return new_main_prog, new_startup_program
def apply_backward_impl(self,
serial_loss,
serial_main_program,
serial_startup_program,
dist_main_program,
dist_startup_program,
parameter_list=None,
no_grad_set=None,
callbacks=None):
"""
"""
params_grads = self._dist_var_op_backward_transpile(
serial_loss, serial_main_program, serial_startup_program,
dist_main_program, dist_startup_program)
# Sharding
if self._dist_strategy.sharding:
self._sharding_backward_transpile(new_main_prog,
new_startup_program)
return params_grads
def apply_optimize_impl(self, user_define_optimizer, params_grads,
dist_main_program, dist_startup_program):
"""
append update-related ops to the program: clip, weight decay, optimize ops
filter optimize ops if sharding is enabled
naive gradient synchronization before update
Args:
user_define_optimizer (paddle.fluid.optimizer):
params_grads (list) list of tuple that contain param and its grad variable
dist_main_program (paddle.fluid.framework.program): dist main program with forward & backward network
dist_startup_program (paddle.fluid.framework.program): dist startup program with forward & backward network
"""
# init distop helper
dist_op_context = self._dist_context.dist_op_context
dist_op_context.set_varname_mapping(self._serial2dist_varname_mapping)
dist_op_context.set_rank_id(self._rank_id)
if self._dist_strategy.sharding:
params_grads = sharding_optimize_transpile(
params_grads, dist_main_program, dist_startup_program)
# partition startup program
if serial_startup_program is None:
partitioned_startup_prog = None
else:
partitioned_startup_prog = self.partition_startup_program(
serial_main_program, serial_startup_program)
dist_op_context.set_dst_startup_program(partitioned_startup_prog)
optimize_ops = self._optimize_transpile(user_define_optimizer,
params_grads, dist_main_program,
dist_startup_program)
# partition main program
partitioned_main_prog, partitioned_params_grads = self.partition_main_program(
serial_main_program, params_grads)
return optimize_ops
return partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads
def _dist_var_op_forward_transpile(self,
serial_main_program,
serial_startup_program=None):
"""
1. partition variables
2. replace local op with corresponding dist op
"""
def partition_startup_program(self, serial_main_program,
serial_startup_program):
partitioned_main_prog = fluid.Program()
partitioned_global_block = partitioned_main_prog.global_block()
serial_main_block = serial_main_program.global_block()
serial_ops = serial_main_program.global_block().ops
if not isinstance(serial_startup_program, (Program)):
raise TypeError(
"dist_context be paddle.fluid.framework.program, got %s here" %
type(serial_startup_program))
# transpile startup program
if serial_startup_program == None:
partitioned_startup_prog = None
else:
partitioned_startup_prog = fluid.Program()
# create parameter
partitioned_startup_global_block = partitioned_startup_prog.global_block(
)
ref_block = serial_main_program.global_block()
target_block = partitioned_startup_prog.global_block()
param2shape = {}
temp_varname_map = {}
# tensors
for var in serial_startup_program.list_vars():
if isinstance(var, Parameter):
# TODO if the var does not belong to this rank, it should be filtered
serial_main_var = serial_main_block.var(var.name)
serial_main_var = ref_block.var(var.name)
dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
serial_main_var)
target_shape = _get_dist_shape(serial_main_var, dist_attr)
new_name = var.name + self._dist_varname_suffix
temp_varname_map[var.name] = new_name
_partition_parameter(self._dist_context, serial_main_var,
partitioned_startup_global_block,
new_name, target_shape)
target_block, new_name, target_shape)
param2shape[new_name] = target_shape
# copy initializer
# ops
for op in serial_startup_program.global_block().ops:
# TODO if the var does not belong to this rank, it should be filtered
output_vars = op.desc.output_arg_names()
......@@ -331,20 +130,19 @@ class Partitioner(object):
assert temp_varname_map[output_vars[
0]] in param2shape, "try to initialize [{}] which is not a Parameter".format(
output_vars[0])
new_op_desc = partitioned_startup_global_block.desc.append_op()
new_op_desc = target_block.desc.append_op()
new_op_desc.copy_from(op.desc)
new_op_desc._rename_output(output_vars[0],
temp_varname_map[output_vars[0]])
new_op_desc._set_attr(
"shape", param2shape[temp_varname_map[output_vars[0]]])
partitioned_startup_global_block._sync_with_cpp()
new_op_desc._set_attr("shape",
param2shape[temp_varname_map[output_vars[0]]])
target_block._sync_with_cpp()
# set distributed attribute
new_op = partitioned_startup_global_block.ops[-1]
new_op = target_block.ops[-1]
assert new_op.type == new_op_desc.type()
assert new_op.desc == new_op_desc
output_var = partitioned_startup_global_block.var(output_vars[
0])
output_var = target_block.var(output_vars[0])
output_var_attr = self._dist_context.get_tensor_dist_attr_for_program(
output_var)
op_attr = OperatorDistributedAttribute()
......@@ -355,24 +153,40 @@ class Partitioner(object):
output_var_attr.dims_mapping)
self._dist_context.set_op_dist_attr_for_program(new_op, op_attr)
# TODO move helper init to a common place
return partitioned_startup_prog
def partition_main_program(self, serial_main_program, params_and_grads):
"""
1. partition variables
2. replace local op with corresponding dist op
"""
dist_op_context = self._dist_context.dist_op_context
partitioned_main_prog = fluid.Program()
dist_op_context.set_dst_main_program(partitioned_main_prog)
dist_op_context.set_dst_startup_program(partitioned_startup_prog)
dist_op_context.set_varname_mapping(self._serial2dist_varname_mapping)
dist_op_context.set_rank_id(self._rank_id)
target_block = partitioned_main_prog.global_block()
ref_block = serial_main_program.global_block()
serial_ops = serial_main_program.global_block().ops
# transpile main program
# init mapping
first_backward_op_idx = -1
forward_op_id2forward_op = {}
for idx in range(len(serial_ops)):
if is_forward_op(serial_ops[idx]):
forward_op_id2forward_op[serial_ops[idx].desc.id(
)] = serial_ops[idx]
# partition
for op in serial_ops:
# partition input variables
for serial_input_varname in op.desc.input_arg_names():
if serial_input_varname not in self._serial2dist_varname_mapping:
new_varname = serial_input_varname + self._dist_varname_suffix
if serial_main_block.has_var(serial_input_varname):
_partition_var(self._dist_context, serial_main_block,
partitioned_global_block,
serial_input_varname, new_varname)
if ref_block.has_var(serial_input_varname):
_partition_var(self._dist_context, ref_block,
target_block, serial_input_varname,
new_varname)
else:
assert serial_input_varname in __varname_not_in_block__
......@@ -383,145 +197,47 @@ class Partitioner(object):
for serial_output_varname in op.desc.output_arg_names():
if serial_output_varname not in self._serial2dist_varname_mapping:
new_varname = serial_output_varname + self._dist_varname_suffix
_partition_var(self._dist_context, serial_main_block,
partitioned_global_block,
_partition_var(self._dist_context, ref_block, target_block,
serial_output_varname, new_varname)
self._serial2dist_varname_mapping[
serial_output_varname] = new_varname
# partition op
kinputs, koutputs = dist_op_context.prepare_forward_context(op)
dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
if _is_dist_op_forward_implement(self._dist_context, op):
dist_ops = get_distributed_operator_impl_container(op.type)
dist_op_impl = dist_ops.get_impl(dist_attr.impl_idx)
dist_op_impl.forward(self._dist_context, **kinputs, **koutputs)
else:
# replicate op
dist_ops = get_distributed_operator_impl_container("default")
dist_op_impl = dist_ops.get_impl(0)
dist_op_impl.forward(self._dist_context, **kinputs, **koutputs)
return partitioned_main_prog, partitioned_startup_prog
def _dist_var_op_backward_transpile(self,
serial_loss,
serial_main_program,
serial_startup_program,
dist_main_program,
dist_startup_program,
parameter_list=None,
no_grad_set=None,
callbacks=None):
"""
so far, the auto_backward case only guarantees the correctness of backward ops for certain dist ops:
1. NV-Megatron-like parallel embedding
2. NV-Megatron-like row parallel linear
3. NV-Megatron-like col parallel linear
"""
if self._compatible_with_auto_backward:
assert isinstance(
serial_loss, Variable), "The target loss should be a Variable."
dist_loss = self._serial_varname2dist_var(serial_loss.name,
dist_main_program)
assert len(dist_loss.shape) == 1 and dist_loss.shape[0] == 1, \
"The dist loss.shape should be (1L,), but the current dist loss.shape is {}. " \
"Maybe that you should call fluid.layers.mean to process the current loss.".format(
dist_loss.shape)
# update parameter list
if parameter_list:
parameter_list = [
self._serial_varname2dist_var(param.name, dist_main_program)
for param in parameter_list
]
# update parameter no_grad_set
if no_grad_set:
no_grad_set = [
self._serial_varname2dist_var(param.name, dist_main_program)
for param in no_grad_set
]
if is_forward_op(op):
kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_forward_impl = _get_dist_op_forward_implement(
op, self._dist_context)
dist_op_forward_impl.forward(self._dist_context, **kinputs,
**koutputs)
dist_op_context = self._dist_context.dist_op_context
params_and_grads = _auto_backward(
dist_loss,
dist_startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set,
callbacks=callbacks,
distop_context=dist_op_context)
# backward completion
complete_backward_annotation(
dist_main_program, dist_context=self._dist_context)
# transpiler backward for dist op
# get backward ops
ops = dist_main_program.global_block().ops
first_backward_op_idx = -1
forward_op_id2forward_op = {}
for idx in range(len(ops)):
if is_forward_op(ops[idx]):
forward_op_id2forward_op[ops[idx].desc.id()] = ops[idx]
if int(ops[idx].attr('op_role')) == int(OpRole.Backward):
first_backward_op_idx = idx
break
assert first_backward_op_idx >= 0, "no backward ops found in program"
assert len(forward_op_id2forward_op
) > 0, "no forward ops found in program"
backward_ops = ops[first_backward_op_idx:]
for backward_op in backward_ops:
# if the backward op has a corresponding forward op
if backward_op.desc.id() in dist_op_context.gradopidx2opidx:
forward_op_id = dist_op_context.gradopidx2opidx[
backward_op.desc.id()]
forward_op = forward_op_id2forward_op[forward_op_id]
# TODO backward attr should have _impl_idx
forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
forward_op)
# TODO use the backward op itself to find the dist op
dist_ops = get_distributed_operator_impl_container(
forward_op.type)
kinputs, koutputs = dist_op_context.prepare_backward_context(
backward_op)
# TODO use backward op itself to determine impl idx
if _is_dist_op_backward_implement(self._dist_context,
forward_op):
dist_op_impl = dist_ops.get_impl(
forward_op_dist_attr.impl_idx)
dist_op_impl.backward(self._dist_context, **kinputs,
elif is_backward_op(op):
print(str(op))
kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_backward_impl = _get_dist_op_backward_implement(
op, self._dist_context, forward_op_id2forward_op)
dist_op_backward_impl.backward(self._dist_context, **kinputs,
**koutputs)
else:
# replicate op
dist_ops = get_distributed_operator_impl_container(
"default")
dist_op_impl = dist_ops.get_impl(0)
dist_op_impl.backward(self._dist_context, **kinputs,
**koutputs)
return params_and_grads
# replace dist grad ops
raise NotImplementedError(
"partitioner only support forward op and backward op, but got {}".
format(str(op)))
partitioned_params_and_grads = []
for p, g in params_and_grads:
assert p.name in self._serial2dist_varname_mapping
dist_p_name = self._serial2dist_varname_mapping[p.name]
assert target_block.has_var(dist_p_name)
dist_p = target_block.var(dist_p_name)
if g is None:
dist_g = None
else:
raise RuntimeError("transpile NOT implemented !")
def _optimize_transpile(self, user_define_optimizer, params_grads,
main_program, startup_program):
assert g.name in self._serial2dist_varname_mapping
dist_g_name = self._serial2dist_varname_mapping[g.name]
assert target_block.has_var(dist_g_name)
dist_g = target_block.var(dist_g_name)
partitioned_params_and_grads.append((dist_p, dist_g))
with program_guard(main_program, startup_program):
optimize_ops = user_define_optimizer.apply_gradients(params_grads)
# update completion
complete_update_annotation(
main_program, dist_context=self._dist_context)
return optimize_ops
return partitioned_main_prog, partitioned_params_and_grads
def _is_valid_annotated_program(self, program):
......@@ -543,154 +259,6 @@ class Partitioner(object):
return all_ops_annotated and all_vars_annotated
def _serial_varname2dist_var(self, serial_varname, dist_program):
assert serial_varname in self._serial2dist_varname_mapping, "The serial var [{}] is not found in var name mapping".format(
serial_varname)
dist_varname = self._serial2dist_varname_mapping[serial_varname]
assert dist_program.global_block().has_var(
dist_varname
), "The dist var [{}] is not found in dist program".format(dist_varname)
dist_var = dist_program.global_block().var(dist_varname)
return dist_var
def _is_var_distributed(self, var):
dist_attr = self._dist_context.get_tensor_dist_attr_for_program(var)
assert dist_attr is not None, "dist_attr of var [{}] is None".format(
var.name)
return _is_distributed(dist_attr)
def _sharding_forward_transpile(self, main_prog, startup_program):
"""
this transpile conducts the modifications to the forward program needed by the sharding strategy,
which mainly include:
1. partition the parameter
2. insert broadcast op
3. insert sync op
NOTE the transpile modification is done in place on the input program
"""
raise NotImplementedError(
"Sharding is NOT support in AutoParallel yet!")
def _sharding_backward_transpile(self, main_prog, startup_program):
"""
this transpile conducts the modifications to the backward program needed by the sharding strategy,
which mainly include:
1. partition the gradient
2. insert broadcast op
3. insert sync op
NOTE the transpile modification is done in place on the input program
"""
raise NotImplementedError(
"Sharding is NOT support in AutoParallel yet!")
def _sharding_optimize_transpile(self, params_grads, dist_main_program,
dist_startup_program):
"""
shard params_grads
append the broadcast to sync parameters
"""
raise RuntimeError("sharding transpile is NOT implemented !")
def _get_no_grad_set_name(no_grad_set):
no_grad_set_name = set()
if no_grad_set is not None:
if isinstance(no_grad_set, (set, list, tuple)):
for i, no_grad_var in enumerate(no_grad_set):
if isinstance(no_grad_var, framework.Variable):
no_grad_set_name.add(no_grad_var.name)
elif isinstance(no_grad_var, six.string_types):
no_grad_set_name.add(no_grad_var)
else:
raise TypeError(
"The type of no_grad_set's member must be paddle.fluid.Variable or str, but received %s."
% (type(no_grad_var)))
else:
raise TypeError(
"The type of no_grad_set should be set or list or tuple, but received {}".
format(type(no_grad_set)))
return no_grad_set_name
def _get_no_grad_set(loss, no_grad_set=None):
no_grad_set = _get_no_grad_set_name(no_grad_set)
parameters = loss.block.program.global_block().all_parameters()
param_no_trainable = set(
[param.name for param in parameters if param.trainable is False])
# If the parameter is not trainable, it should not have a gradient.
no_grad_set.update(param_no_trainable)
return no_grad_set
def _is_dist_op_forward_implement(dist_context, op):
dist_attr = dist_context.get_op_dist_attr_for_program(op)
dist_ops = get_distributed_operator_impl_container(op.type)
return dist_ops and dist_attr.impl_idx >= 0 and dist_ops.get_impl( \
dist_attr.impl_idx)._forward_implemented
def _is_dist_op_backward_implement(dist_context, op):
dist_attr = dist_context.get_op_dist_attr_for_program(op)
dist_ops = get_distributed_operator_impl_container(op.type)
return dist_ops and dist_attr.impl_idx >= 0 and dist_ops.get_impl( \
dist_attr.impl_idx)._backward_implemented
def _auto_backward(loss,
startup_program=None,
parameter_list=None,
no_grad_set=None,
callbacks=None,
distop_context=None):
"""
the modification is done in place
"""
act_no_grad_set = _get_no_grad_set(loss, no_grad_set)
assert isinstance(loss, Variable), "The target loss should be a Variable."
if callbacks is None:
callbacks = [error_clip_callback]
else:
assert (isinstance(callbacks, list))
assert len(loss.shape) == 1 and loss.shape[0] == 1, \
"The loss.shape should be (1L,), but the current loss.shape is {}. " \
"Maybe that you should call fluid.layers.mean to process the current loss.".format(
loss.shape)
program = loss.block.program
with program_guard(program, startup_program):
params_grads = append_backward(
loss,
parameter_list,
act_no_grad_set,
callbacks,
distop_context=distop_context)
return params_grads
def _is_distributed(dist_attr):
mapping = dist_attr.dims_mapping
mesh = dist_attr.process_mesh.topology
for idx in range(len(mapping)):
if mapping[idx] >= 0 and mesh[mapping[idx]] > 1:
return True
return False
def _get_dist_shape(var, dist_attr):
......@@ -795,52 +363,33 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
dst_varname, target_shape)
def _insert_src_op(src_op, dst_block, varname_mapping):
new_op_desc = dst_block.desc.append_op()
new_op_desc.copy_from(src_op.desc)
for local_varname in src_op.desc.input_arg_names():
new_op_desc._rename_input(local_varname, varname_mapping[local_varname])
for local_varname in src_op.desc.output_arg_names():
new_op_desc._rename_output(local_varname,
varname_mapping[local_varname])
dst_block._sync_with_cpp()
def _insert_dist_op(src_op, dst_block, varname_mapping, dist_context, rank_id):
# build input varname mapping
input_mapping = {}
for input_name in src_op.desc.input_names():
varnames = []
for varname in src_op.desc.input(input_name):
varnames.append(varname_mapping[varname])
input_mapping[input_name] = varnames
# build output varname mapping
output_mapping = {}
for output_name in src_op.desc.output_names():
varnames = []
for varname in src_op.desc.output(output_name):
varnames.append(varname_mapping[varname])
output_mapping[output_name] = varnames
# append dist op
dist_attr = dist_context.get_op_dist_attr_for_program(src_op)
dist_ops = get_distributed_operator_impl_container(src_op.type)
append_op_handle = dist_ops.get_impl(dist_attr.impl_idx).forward(src_op)
append_op_handle(
dst_block,
src_op,
dist_attr,
input_mapping,
output_mapping,
rank_id=rank_id)
def is_forward_op(op):
role1 = int(core.op_proto_and_checker_maker.OpRole.Forward) | int(
core.op_proto_and_checker_maker.OpRole.Loss)
role2 = int(core.op_proto_and_checker_maker.OpRole.Forward)
op_role = int(op.attr('op_role'))
return op_role == role2 or op_role == role1
def _get_dist_op_backward_implement(backward_op, dist_context,
forward_op_id2forward_op):
dist_op_context = dist_context.dist_op_context
if backward_op.desc.id() in dist_op_context.gradopidx2opidx:
forward_op_id = dist_op_context.gradopidx2opidx[backward_op.desc.id()]
forward_op = forward_op_id2forward_op[forward_op_id]
forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
forward_op)
dist_ops = get_distributed_operator_impl_container(forward_op.type)
# TODO backward should have its own impl_idx
if dist_ops and forward_op_dist_attr.impl_idx >= 0 and dist_ops.get_impl( \
forward_op_dist_attr.impl_idx)._backward_implemented:
return dist_ops.get_impl(forward_op_dist_attr.impl_idx)
dist_ops = get_distributed_operator_impl_container("default")
return dist_ops.get_impl(0)
def _get_dist_op_forward_implement(forward_op, dist_context):
dist_attr = dist_context.get_op_dist_attr_for_program(forward_op)
dist_ops = get_distributed_operator_impl_container(forward_op.type)
if dist_ops and dist_attr.impl_idx >= 0 and dist_ops.get_impl(
dist_attr.impl_idx)._forward_implemented:
return dist_ops.get_impl(dist_attr.impl_idx)
else:
dist_ops = get_distributed_operator_impl_container("default")
return dist_ops.get_impl(0)
......@@ -386,15 +386,20 @@ class SearchAlgorithm:
class MCMC(SearchAlgorithm):
def __init__(self, serial_program_info, max_search_times=5):
def __init__(self, serial_program_info, parallelizer, max_search_times=5):
super(MCMC, self).__init__("mcmc")
self._serial_program_info = serial_program_info
self._max_search_times = max_search_times
self._parallelizer = parallelizer
@property
def serial_program_info(self):
return self._serial_program_info
@property
def parallelizer(self):
return self._parallelizer
@property
def max_search_times(self):
return self._max_search_times
......@@ -483,7 +488,7 @@ class MCMC(SearchAlgorithm):
cost = None
# get all distributed programs
all_dist_main_program = get_all_distributed_main_program(
self.serial_program_info, dist_context)
self.serial_program_info, dist_context, self.parallelizer)
pipeline_config = [
process_mesh.processes for process_mesh in pipeline_process_meshes
] if pipeline_process_meshes is not None else None
......@@ -829,8 +834,10 @@ class MCMC(SearchAlgorithm):
class Planner:
def __init__(self, serial_program_info, algorithm_config=None):
def __init__(self, serial_program_info, parallelizer,
algorithm_config=None):
self._serial_program_info = serial_program_info
self._parallelizer = parallelizer
self._algorithm_config = algorithm_config
self._algorithm_searcher = self.create_algorithm_searcher(
algorithm_config)
......@@ -847,6 +854,10 @@ class Planner:
def algorithm_searcher(self):
return self._algorithm_searcher
@property
def parallelizer(self):
return self._parallelizer
def create_algorithm_searcher(self, algorithm_config):
name = algorithm_config.get("name", None)
assert name is not None, "Invalid algorithm config."
......@@ -856,9 +867,9 @@ class Planner:
# NOTE: Only GPU clusters are supported now.
max_search_times = algorithm_config.get("max_search_times", None)
algorithm_searcher = MCMC(
self.serial_program_info,
self.serial_program_info, self.parallelizer,
max_search_times) if max_search_times is not None else MCMC(
self.serial_program_info)
self.serial_program_info, self.parallelizer)
else:
raise NotImplementedError(
"Other search algorithms have not been supported now.")
......
......@@ -993,7 +993,9 @@ def set_grad_var_shape(program, dist_context):
block = program.global_block()
vars = block.vars
for op in block.ops:
if op.type == "sum":
if op.type in [
"sum", "check_finite_and_unscale", "update_loss_scaling"
]:
continue
if int(op.attr('op_role')) == int(OpRole.Backward):
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
......@@ -1004,15 +1006,24 @@ def set_grad_var_shape(program, dist_context):
forward_var_name = var_name[:var_name.find("@GRAD")]
if op.type == "c_allreduce_sum" or op.type == "c_identity" or op.type == "scale":
forward_var_name = op.input_arg_names[0]
elif op.type == "matmul_v2_grad":
forward_var_name = None
for output_name in op.output_names:
if var_name in op.output(output_name):
assert "@GRAD" in output_name
input_name = output_name[:output_name.find("@GRAD")]
assert len(op.input(input_name)) == 1
forward_var_name = op.input(input_name)[0]
assert forward_var_name is not None
need_set_shape_list = [
"reshape2_grad", "softmax_with_cross_entropy_grad",
"transpose2_grad", "softmax_grad", "cross_entropy_grad2",
"dropout_grad", "unsqueeze2_grad"
"dropout_grad"
]
forward_list = [
"reshape2", "softmax_with_cross_entropy", "transpose2",
"softmax", "cross_entropy2", "dropout", "unsqueeze2"
"softmax", "cross_entropy2", "dropout"
]
if op.type in need_set_shape_list:
for forward_op in block.ops:
......@@ -1041,6 +1052,23 @@ def set_grad_var_shape(program, dist_context):
grad_var.desc.set_shape(ref_shape)
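
The matmul_v2_grad branch above recovers the forward variable by locating which output slot holds the gradient name, stripping the @GRAD suffix from that slot name, and reading the single input in the corresponding slot; the gradient variable's shape is then set to that forward variable's shape. A standalone sketch of the slot lookup (plain dicts stand in for the op descriptor):

GRAD_SUFFIX = "@GRAD"

def _forward_name_for_grad(grad_name, op_outputs, op_inputs):
    # op_outputs / op_inputs map slot names (e.g. "X@GRAD", "X") to lists of var names.
    for output_slot, out_vars in op_outputs.items():
        if grad_name in out_vars:
            assert GRAD_SUFFIX in output_slot
            input_slot = output_slot[:output_slot.find(GRAD_SUFFIX)]
            assert len(op_inputs[input_slot]) == 1
            return op_inputs[input_slot][0]
    return None

outputs = {"X@GRAD": ["x@GRAD"], "Y@GRAD": ["w@GRAD"]}
inputs = {"X": ["x"], "Y": ["w"], "Out@GRAD": ["out@GRAD"]}
print(_forward_name_for_grad("w@GRAD", outputs, inputs))  # w
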
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
OpRole = core.op_proto_and_checker_maker.OpRole
def is_forward_op(op):
ref_role1 = int(core.op_proto_and_checker_maker.OpRole.Forward)
ref_role2 = int(core.op_proto_and_checker_maker.OpRole.Loss)
op_role = int(op.attr('op_role'))
return OP_ROLE_KEY in op.attr_names and (op_role == ref_role1 or
op_role == ref_role2)
def is_backward_op(op):
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward)
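
is_forward_op above compares the role value directly against Forward and Loss, while is_backward_op masks the Backward bit so an op tagged with a combined role still matches. A small sketch with stand-in role constants (the real values come from core.op_proto_and_checker_maker.OpRole):

# Stand-in role values for illustration only.
FORWARD, BACKWARD, LOSS = 0, 1, 256

def _is_forward_role(op_role):
    return op_role == FORWARD or op_role == LOSS

def _is_backward_role(op_role):
    # Bitwise test: an op tagged Loss | Backward still counts as backward.
    return bool(op_role & BACKWARD)

print(_is_backward_role(BACKWARD))         # True
print(_is_backward_role(LOSS | BACKWARD))  # True
print(_is_backward_role(FORWARD))          # False
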
def update_op_dims_mapping_by_default_dist_impl(dist_op):
changed = False
op_dist_attr = dist_op.dist_attr
......@@ -1177,57 +1205,25 @@ def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op):
return changed
def get_all_distributed_main_program(serial_program_info, dist_context):
def get_all_distributed_main_program(serial_program_info, dist_context,
parallelizer):
"Get all distributed main programs by dist_context."
from .dist_context import DistributedOperatorContext
from .dist_context import DistributedOperatorContext, DistributedContext
cluster = serial_program_info.cluster
copied_parallelizer = copy.deepcopy(parallelizer)
all_dist_main_program = []
ranks = paddle.distributed.get_world_size() if cluster is None else len(
cluster.get_all_devices("GPU"))
for rank_id in range(ranks):
used_dist_context = copy.deepcopy(dist_context)
used_dist_context._dist_op_context = DistributedOperatorContext()
dist_main_program, dist_startup_program = get_specified_distributed_main_program(
serial_program_info, used_dist_context, rank_id)
_, _, dist_startup_program, dist_main_program, _ = copied_parallelizer._get_dist_program(
rank_id, used_dist_context)
all_dist_main_program.append(dist_main_program)
return all_dist_main_program
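
get_all_distributed_main_program now delegates per-rank program generation to a deep-copied parallelizer instead of re-running the partitioner passes by hand (the helper removed below). A standalone sketch of the loop shape, with a stand-in for AutoParallelizer._get_dist_program (the names of the unused tuple slots are assumed):

import copy

class _FakeParallelizer:
    def _get_dist_program(self, rank_id, dist_context):
        # The real method returns a 5-tuple whose third and fourth items are the
        # startup and main programs; strings stand in for them here.
        return None, None, "startup_rank%d" % rank_id, "main_rank%d" % rank_id, None

def _all_dist_main_programs(parallelizer, base_dist_context, nranks):
    programs = []
    copied_parallelizer = copy.deepcopy(parallelizer)
    for rank_id in range(nranks):
        # Every rank partitions against its own copy of the dist_context so that
        # state from one rank cannot leak into another.
        used_dist_context = copy.deepcopy(base_dist_context)
        _, _, _, dist_main, _ = copied_parallelizer._get_dist_program(rank_id,
                                                                      used_dist_context)
        programs.append(dist_main)
    return programs

print(_all_dist_main_programs(_FakeParallelizer(), {"mesh": [0, 1]}, 2))
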
def get_specified_distributed_main_program(serial_program_info, dist_context,
rank_id):
"Get distributed main program by the given dist_context and rank_id."
from .partitioner import Partitioner
from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
from .process_group import _g_process_group_map, ProcessGroup
dist_strategy = paddle.distributed.fleet.DistributedStrategy()
train_program = serial_program_info.train_program
startup_program = serial_program_info.startup_program
loss = serial_program_info.loss
optimizer = serial_program_info.optimizer
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
dist_main_program, dist_startup_program = partitioner.transpile_forward(
train_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, train_program, startup_program, dist_main_program,
dist_startup_program)
opt_ops = partitioner.apply_optimize(
copy.deepcopy(optimizer), dist_params_grads, dist_main_program,
dist_startup_program)
set_grad_var_shape(dist_main_program, dist_context)
make_data_unshard(dist_main_program, dist_startup_program, dist_context)
reshard(dist_main_program, dist_startup_program, rank_id, dist_context)
HAS_SENT.clear()
HAS_RECV.clear()
HAS_ALLGATHER.clear()
_g_process_group_map.clear()
_g_process_group_map[0] = ProcessGroup(0, [])
return dist_main_program, dist_startup_program
class SerialProgramInfo:
def __init__(self,
train_program,
......@@ -1286,7 +1282,6 @@ def get_standalone_cost_data(distributed_programs):
shape = list(map(lambda x: int(x.strip()), shape))
dtype_factor = 1
total_static_input_size += reduce(lambda x, y: x * y, shape)
# print(arg_name_lower)
if op.type == "c_embedding":
arg_name_lower = "w" if arg_name_lower == "weight" else "ids"
for arg_name in op.input_names:
......@@ -1301,7 +1296,8 @@ def get_standalone_cost_data(distributed_programs):
actual_runtime = total_actual_input_size / total_static_input_size * runtime
return actual_runtime
cost_model = paddle.cost_model.CostModel()
import paddle.cost_model as cm
cost_model = cm.CostModel()
cost_model.static_cost_data()
DEFAULT_MULTIPLE = 2
OP_NAME_MAPPING = {
......
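
The standalone cost data above is profiled with fixed (static) input shapes, so the profiled runtime is rescaled by the ratio of the op's actual input volume to the profiled volume. A worked example under assumed numbers:

# Assumed numbers: the op was profiled on a 1024 x 1024 input (static) but runs on a
# 512 x 1024 shard here (actual), with a profiled cost of 2.0 ms.
total_static_input_size = 1024 * 1024
total_actual_input_size = 512 * 1024
runtime_ms = 2.0

actual_runtime_ms = total_actual_input_size / total_static_input_size * runtime_ms
print(actual_runtime_ms)  # 1.0 -> half the data volume, half the estimated time
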
......@@ -26,7 +26,7 @@ import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.completion import complete_backward_annotation
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.cost_model import estimate_cost
import paddle.fluid.core as core
......@@ -148,22 +148,33 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
dist_strategy = fleet.DistributedStrategy()
fleet._user_defined_strategy = fleet.DistributedStrategy()
fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
parallelizer = AutoParallelizer(fleet)
parallelizer._dist_context = dist_context
# auto completion
# serial forward & backward completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
loss,
parameter_list=None,
no_grad_set=None,
callbacks=None)
# logical partition
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
complete_train_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, startup_program, auto_parallel_main_prog,
auto_parallel_startup_prog)
optimizer = paddle.fluid.optimizer.AdamOptimizer()
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog,
auto_parallel_startup_prog)
partitioner = Partitioner(dist_context, rank_id)
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
complete_train_program, startup_program, params_grads)
partitioned_optimize_ops = parallelizer._apply_optimize(
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
return auto_parallel_main_prog, auto_parallel_startup_prog
......
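
The same refactor repeats in the remaining tests below: the old Partitioner(dist_strategy, dist_context, rank_id) three-pass API (transpile_forward / apply_backward / apply_optimize) becomes a single partition call, with backward generation and optimizer application moved onto AutoParallelizer. Condensed from the diff for reference; it assumes train_program, startup_program and loss come from an annotated model and that a fleet-style object is available:

import paddle
import paddle.distributed.auto_parallel as auto
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner

def build_dist_program(train_program, startup_program, loss, dist_context, rank_id):
    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # serial forward & backward completion on the annotated program
    complete_train_program = auto.complete_annotation(train_program, dist_context)
    parallelizer._apply_serial_forward_pass(complete_train_program, startup_program)
    params_grads = parallelizer._generate_backward(
        complete_train_program, startup_program, loss,
        parameter_list=None, no_grad_set=None, callbacks=None)

    # one-shot logical partition for this rank
    partitioner = Partitioner(dist_context, rank_id)
    dist_main, dist_startup, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    # optimizer ops are applied by the parallelizer on the partitioned program
    parallelizer._apply_optimize(dist_main, dist_startup, dist_params_grads)
    return dist_main, dist_startup
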
......@@ -36,6 +36,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard
......@@ -469,21 +470,31 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
dist_strategy = fleet.DistributedStrategy()
fleet._user_defined_strategy = fleet.DistributedStrategy()
fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
parallelizer = AutoParallelizer(fleet)
parallelizer._dist_context = dist_context
# auto completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
# logical partition
dist_train_program, dist_startup_prog = partitioner.transpile_forward(
complete_train_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, startup_program, dist_train_program,
dist_startup_prog)
optimizer = paddle.fluid.optimizer.AdamOptimizer()
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
dist_train_program, dist_startup_prog)
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
loss,
parameter_list=None,
no_grad_set=None,
callbacks=None)
partitioner = Partitioner(dist_context, rank_id)
dist_train_program, dist_startup_prog, dist_params_grads = partitioner.partition(
complete_train_program, startup_program, params_grads)
partitioned_optimize_ops = parallelizer._apply_optimize(
dist_train_program, dist_startup_prog, dist_params_grads)
reshard(dist_train_program, dist_startup_prog, rank_id, dist_context)
return dist_train_program, dist_startup_prog
......
......@@ -54,9 +54,9 @@ def get_programs(annotated_func):
rank_id = 3
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog = partitioner.transpile_forward(
complete_train_program, start_program)
partitioner = Partitioner(dist_context, rank_id)
test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog, _ = partitioner.partition(
complete_train_program, start_program, [])
return complete_train_program, start_program, test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog, dist_context
......
......@@ -35,6 +35,7 @@ from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_pr
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.utils import _get_comm_group
from paddle.distributed.auto_parallel.process_group import new_process_group
......@@ -790,9 +791,9 @@ class GPTPretrainingCriterion(nn.Layer):
return loss
def gpt_pretrain_forward(train_program, start_program):
def gpt_pretrain_forward(train_program, startup_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
startup_program), utils.unique_name.guard():
batch_size = 16
sequence_len = 512
input_ids = static.data(
......@@ -848,7 +849,19 @@ def gpt_pretrain_forward(train_program, start_program):
loss = criterion(preds, labels, loss_mask)
return train_program, start_program, loss
return train_program, startup_program, loss
class FakeStrategy(object):
def __init__(self):
self.amp = False
self.recompute = False
class FakeFleet(object):
def __init__(self):
self.user_defined_optimizer = None
self._user_defined_strategy = FakeStrategy()
class TestGPTPartitioner(unittest.TestCase):
......@@ -861,38 +874,41 @@ class TestGPTPartitioner(unittest.TestCase):
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
startup_program = static.Program()
parallelizer = AutoParallelizer(FakeFleet())
dist_context = parallelizer._dist_context
dist_context.process_mesh = _global_process_mesh
train_program, start_program, loss = gpt_pretrain_forward(train_program,
start_program)
train_program, startup_program, loss = gpt_pretrain_forward(
train_program, startup_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# serial forward pass
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
# serial backward pass
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
loss,
parameter_list=None,
no_grad_set=None,
callbacks=None)
rank_id = 3
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
complete_train_program, start_program)
dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, start_program,
auto_parallel_main_prog, auto_parallel_startup_prog)
partitioner = Partitioner(dist_context, rank_id)
auto_parallel_main_prog, auto_parallel_startup_prog, params_grads = partitioner.partition(
complete_train_program, startup_program, params_grads)
with open("./test_auto_parallel_partitioner_serial_main_new.txt",
"w") as fw:
fw.write(str(train_program))
with open("./test_auto_parallel_partitioner_serial_startup_new.txt",
"w") as fw:
fw.write(str(start_program))
optimizer = paddle.fluid.optimizer.AdamOptimizer(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog,
auto_parallel_startup_prog)
fw.write(str(startup_program))
from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
set_default_distributed_context(dist_context)
with open("./test_auto_parallel_partitioner_main_new.txt1", "w") as fw:
......@@ -927,7 +943,7 @@ class TestGPTPartitioner(unittest.TestCase):
complete_train_program, weights, 0, 1))
all_params = sorted(
[param.name for param in start_program.all_parameters()])
[param.name for param in startup_program.all_parameters()])
allreduce_grads = [
'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2',
'layer_norm_6.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_7.tmp_2',
......
......@@ -24,6 +24,7 @@ import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.process_group import _g_process_group_map
......@@ -145,22 +146,34 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
# auto completion
fleet._user_defined_strategy = fleet.DistributedStrategy()
fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
parallelizer = AutoParallelizer(fleet)
parallelizer._dist_context = dist_context
# serial forward & backward completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
loss,
parameter_list=None,
no_grad_set=None,
callbacks=None)
# logical partition
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
complete_train_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, startup_program, auto_parallel_main_prog,
auto_parallel_startup_prog)
optimizer = paddle.fluid.optimizer.AdamOptimizer()
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog,
auto_parallel_startup_prog)
partitioner = Partitioner(dist_context, rank_id)
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
complete_train_program, startup_program, params_grads)
partitioned_optimize_ops = parallelizer._apply_optimize(
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
return auto_parallel_main_prog, auto_parallel_startup_prog
......
......@@ -24,6 +24,7 @@ import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
......@@ -109,22 +110,34 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
# auto completion
fleet._user_defined_strategy = fleet.DistributedStrategy()
fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
parallelizer = AutoParallelizer(fleet)
parallelizer._dist_context = dist_context
# serial forward & backward completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
loss,
parameter_list=None,
no_grad_set=None,
callbacks=None)
# logical partition
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
complete_train_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, startup_program, auto_parallel_main_prog,
auto_parallel_startup_prog)
optimizer = paddle.fluid.optimizer.AdamOptimizer()
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog,
auto_parallel_startup_prog)
partitioner = Partitioner(dist_context, rank_id)
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
complete_train_program, startup_program, params_grads)
partitioned_optimize_ops = parallelizer._apply_optimize(
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
return auto_parallel_main_prog, auto_parallel_startup_prog
......
......@@ -24,6 +24,7 @@ import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
......@@ -125,22 +126,32 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
# auto completion
fleet._user_defined_strategy = fleet.DistributedStrategy()
fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
parallelizer = AutoParallelizer(fleet)
parallelizer._dist_context = dist_context
# serial forward & backward completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
loss,
parameter_list=None,
no_grad_set=None,
callbacks=None)
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
# logical partition
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
complete_train_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, startup_program, auto_parallel_main_prog,
auto_parallel_startup_prog)
optimizer = paddle.fluid.optimizer.AdamOptimizer()
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog,
auto_parallel_startup_prog)
partitioner = Partitioner(dist_context, rank_id)
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
complete_train_program, startup_program, params_grads)
partitioned_optimize_ops = parallelizer._apply_optimize(
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
return auto_parallel_main_prog, auto_parallel_startup_prog
......@@ -253,14 +264,15 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 0
dist_context = DistributedContext()
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
partitioner = Partitioner(dist_context, rank_id)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
complete_train_program, startup_program)
reshard(auto_parallel_main_prog, startup_program, rank_id, dist_context)
partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
complete_train_program, startup_program, [])
reshard(partitioned_main_prog, partitioned_startup_prog, rank_id,
dist_context)
# the x should not be slice
self.assertTrue(check_allgather(auto_parallel_main_prog))
self.assertTrue(check_allgather(partitioned_main_prog))
if __name__ == "__main__":
......