未验证 提交 c4fdb057 编写于 作者: J JZ-LIANG 提交者: GitHub

[Auto Paralle] partitioner refactor (#37853)

上级 b463dff4
......@@ -404,7 +404,7 @@ class DistributedOperatorContext:
def get_cur_src_op(self):
return self._cur_src_op
def prepare_forward_context(self, src_op):
def prepare_context(self, src_op):
self.set_cur_src_op(src_op)
......@@ -413,6 +413,7 @@ class DistributedOperatorContext:
for input_name in src_op.desc.input_names():
varnames = []
for varname in src_op.desc.input(input_name):
assert varname in self._varname_mapping
varnames.append(self._varname_mapping[varname])
kinputs[input_name] = varnames
......@@ -421,29 +422,8 @@ class DistributedOperatorContext:
for output_name in src_op.desc.output_names():
varnames = []
for varname in src_op.desc.output(output_name):
assert varname in self._varname_mapping
varnames.append(self._varname_mapping[varname])
koutputs[output_name] = varnames
return kinputs, koutputs
def prepare_backward_context(self, backward_op):
self.set_cur_src_op(backward_op)
# build input varname mapping
kinputs = {}
for input_name in backward_op.desc.input_names():
varnames = []
for varname in backward_op.desc.input(input_name):
varnames.append(varname)
kinputs[input_name] = varnames
# build output varname mapping
koutputs = {}
for output_name in backward_op.desc.output_names():
varnames = []
for varname in backward_op.desc.output(output_name):
varnames.append(varname)
koutputs[output_name] = varnames
return kinputs, koutputs
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License
from ..dist_attribute import OperatorDistributedAttribute
_g_distributed_operator_impl_registries = {}
......@@ -138,3 +140,46 @@ def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
exact_shape.append(new_shape)
return exact_shape
def set_comm_op_dist_attr_for_program(new_op, process_mesh, tensor_dist_attr,
ctx):
assert process_mesh is not None
assert tensor_dist_attr is not None
new_op_dist_attr = OperatorDistributedAttribute()
new_op_dist_attr.process_mesh = process_mesh
for input_varname in new_op.desc.input_arg_names():
new_op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr)
for output_varname in new_op.desc.output_arg_names():
new_op_dist_attr.set_output_dist_attr(output_varname, tensor_dist_attr)
ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx):
ref_dist_attr = ctx.get_op_dist_attr_for_program(ref_op)
new_op_dist_attr = OperatorDistributedAttribute()
new_op_dist_attr.process_mesh = ref_dist_attr.process_mesh
for input_name in ref_op.input_names:
assert input_name in new_op.input_names
assert len(ref_op.input(input_name)) == 1
assert len(new_op.input(input_name)) == 1
ref_tensor_dist_attr = ref_dist_attr.get_input_dist_attr(
ref_op.input(input_name)[0])
new_op_dist_attr.set_input_dist_attr(
new_op.input(input_name)[0], ref_tensor_dist_attr)
for output_name in ref_op.output_names:
assert output_name in new_op.output_names
assert len(ref_op.output(output_name)) == 1
assert len(new_op.output(output_name)) == 1
ref_tensor_dist_attr = ref_dist_attr.get_output_dist_attr(
ref_op.output(output_name)[0])
new_op_dist_attr.set_output_dist_attr(
new_op.output(output_name)[0], ref_tensor_dist_attr)
ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
......@@ -66,7 +66,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
main_block = dist_op_context.get_dst_main_program().global_block()
startup_block = dist_op_context.get_dst_startup_program().global_block()
src_op = dist_op_context.get_cur_src_op()
varname_mapping = dist_op_context.get_varname_mapping()
rank_id = dist_op_context.get_rank_id()
# check validation of inputs / outputs
......@@ -153,6 +152,31 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
str(backward_op))
rank_id = dist_op_context.get_rank_id()
# check validation of inputs / outputs
for input_name in backward_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format(
input_name)
assert len(kwargs[input_name]) == len(
backward_op.desc.input(input_name)
), "number of tensor for input [{}] is not match".format(input_name)
for output_name in backward_op.desc.output_names():
assert output_name in kwargs, "input [{}] is not given".format(
output_name)
assert len(kwargs[output_name]) == len(
backward_op.desc.output(output_name)
), "number of tensor for input [{}] is not match".format(
output_name)
# replicate op in dist program
dist_op_desc = main_block.desc.append_op()
dist_op_desc.copy_from(backward_op.desc)
for input_name in backward_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in backward_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])
main_block._sync_with_cpp()
# check if need gradient allreduce
# if there is a non-gradient & non-parameter input and its batch dimension is splited,
# we need insert gradient allreduce for the gradient of parameter in its output
......
......@@ -16,14 +16,14 @@ from .common import infer_shape
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import register_distributed_operator_impl, set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..dist_attribute import OperatorDistributedAttribute
from ..dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
......@@ -329,9 +329,6 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh,
rank_id)
# check if need gradient allreduce
need_gradient_allreduce = False
assert 'Ids' in kwargs, "input [{}] is not given".format('Ids')
assert 'W' in kwargs, "input [{}] is not given".format('W')
assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out')
......@@ -355,6 +352,84 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
kwargs['W@GRAD'])
Ids_var = main_block.var(kwargs['Ids'][0])
Weight_var = main_block.var(kwargs['W'][0])
Out_grad = main_block.var(kwargs['Out@GRAD'][0])
Weight_grad = main_block.var(kwargs['W@GRAD'][0])
embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
Weight_var.name)[0]
assert embedding_row_dim_mapping >= 0, "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format(
embedding_row_dim_mapping)
process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_group = dist_attr.process_mesh.processes
# A generalized method to caculate embedding offset using cartisian product
relative_idx = _get_idx_in_axis(process_mesh_group, process_mesh_shape,
embedding_row_dim_mapping, rank_id)
per_part_size = Weight_var.shape[0]
relative_idx = relative_idx * per_part_size
check_variable_and_dtype(
Out_grad, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity')
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_embedding", '@tmp_0@GRAD'])),
dtype=Out_grad.dtype,
shape=Out_grad.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=Out_grad.stop_gradient)
# copy X_var's dist_attr to intermediate_var_0's dist_attr
out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name)
assert out_grad_dist_attr is not None
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
out_grad_dist_attr)
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
embedding_row_dim_mapping, rank_id)
group = new_process_group(group_ranks)
c_identity_op = main_block.append_op(
type='c_identity',
inputs={'X': [Out_grad]},
outputs={'Out': intermediate_var_0},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: OpRole.Backward,
})
check_variable_and_dtype(intermediate_var_0, 'x',
['float16', 'float32', 'float64'], 'linear')
check_dtype(intermediate_var_0.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
set_comm_op_dist_attr_for_program(c_identity_op, dist_attr.process_mesh,
out_grad_dist_attr, ctx)
main_block._sync_with_cpp()
c_embedding_grad_op_desc = main_block.desc.append_op()
c_embedding_grad_op_desc.set_type("c_embedding_grad")
c_embedding_grad_op_desc.set_input('Ids', [Ids_var.name])
c_embedding_grad_op_desc.set_input('W', [Weight_var.name])
c_embedding_grad_op_desc.set_input('Out@GRAD',
[intermediate_var_0.name])
c_embedding_grad_op_desc.set_output('W@GRAD', [Weight_grad.name])
c_embedding_grad_op_desc._set_attr('start_index', relative_idx)
c_embedding_grad_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward)
main_block._sync_with_cpp()
c_embedding_grad_op = main_block.ops[-1]
assert c_embedding_grad_op.type == "c_embedding_grad"
naive_copy_op_dist_attr_for_program(c_embedding_grad_op, backward_op,
ctx)
# check if need gradient allreduce
need_gradient_allreduce = False
process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(Ids_var.name)
mesh_shape = process_mesh.topology
......
......@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from .common import infer_shape
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
......@@ -33,6 +35,20 @@ from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
def copy_op_with_new_input_output(block, src_op, **kwargs):
dist_op_desc = block.desc.append_op()
dist_op_desc.copy_from(src_op.desc)
for input_name in src_op.desc.input_names():
assert input_name in kwargs
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in src_op.desc.output_names():
assert input_name in kwargs
dist_op_desc.set_output(output_name, kwargs[output_name])
block._sync_with_cpp()
return dist_op_desc
def _update_dims_mapping_for_matmul(dist_op):
changed = False
op_desc = dist_op.serial_op.desc
......@@ -141,15 +157,11 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
if rank_id not in dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id)
# check if need gradient allreduce
need_gradient_allreduce = False
assert 'Y' in kwargs, "input [{}] is not given".format('Y')
assert 'X' in kwargs, "input [{}] is not given".format('X')
assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD')
assert 'Y@GRAD' in kwargs, "output [{}] is not given".format('Y@GRAD')
assert 'X@GRAD' in kwargs, "output [{}] is not given".format('X@GRAD')
assert len(
kwargs['Y']
) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format(
......@@ -166,15 +178,138 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
kwargs['Y@GRAD']
) == 1, "row_parallel_embedding output Ids take 1 variable but got {}".format(
kwargs['Y@GRAD'])
assert len(
kwargs['X@GRAD']
) == 1, "row_parallel_embedding output Ids take 1 variable but got {}".format(
kwargs['X@GRAD'])
X_var = main_block.var(kwargs['X'][0])
Y_var = main_block.var(kwargs['Y'][0])
Out_grad = main_block.var(kwargs['Out@GRAD'][0])
Y_grad = main_block.var(kwargs['Y@GRAD'][0])
assert not X_var.is_parameter, "left operand(X) [{}] of dist matmul should not be parameter".format(
X_var.name)
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name)
process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_group = dist_attr.process_mesh.processes
assert len(
Y_var_dim_mapping
) == 2, "dist matmual only support Y operand with 2 dims now but Y({})'s dim is [{}]".format(
Y_var.name, Y_var_dim_mapping)
Y_var_partitioned = False
for dim in Y_var_dim_mapping:
if dim >= 0 and process_mesh_shape[dim] > 0:
Y_var_partitioned = True
break
if Y_var.is_parameter and Y_var_partitioned:
if Y_var_dim_mapping[0] >= 0:
# row parallel: c_identity + matmul
assert Y_var_dim_mapping[1] < 0
parallel_axis = Y_var_dim_mapping[0]
check_variable_and_dtype(
Out_grad, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_identity')
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_identity", 'tmp'])) + "@GRAD",
dtype=Out_grad.dtype,
shape=Out_grad.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=Out_grad.stop_gradient)
# copy X_var's dist_attr to intermediate_var_0's dist_attr
out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name)
assert out_grad_dist_attr is not None
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
out_grad_dist_attr)
group_ranks = _get_comm_group(
process_mesh_group, process_mesh_shape, parallel_axis, rank_id)
group = new_process_group(group_ranks)
c_identity_op = main_block.append_op(
type='c_identity',
inputs={'X': [Out_grad]},
outputs={'Out': intermediate_var_0},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: OpRole.Backward,
})
check_variable_and_dtype(intermediate_var_0, 'x',
['float16', 'float32', 'float64'],
'linear')
check_dtype(intermediate_var_0.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
set_comm_op_dist_attr_for_program(
c_identity_op, dist_attr.process_mesh, out_grad_dist_attr, ctx)
new_kwargs = copy.deepcopy(kwargs)
new_kwargs['Out@GRAD'] = [intermediate_var_0.name]
matmul_op_desc = copy_op_with_new_input_output(
main_block, backward_op, **new_kwargs)
else:
# col parallel: matmul + allreduce
assert Y_var_dim_mapping[0] < 0
parallel_axis = Y_var_dim_mapping[1]
new_kwargs = copy.deepcopy(kwargs)
# NOTE (JZ-LIANG) should allow left operand be empty for matmul grad
has_x_grad = len(kwargs['X@GRAD']) > 0
if has_x_grad:
assert len(kwargs['X@GRAD']) == 1
X_grad = main_block.var(kwargs['X@GRAD'][0])
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_identity", 'tmp'])) + "@GRAD",
dtype=X_grad.dtype,
shape=X_grad.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=X_grad.stop_gradient)
X_grad_dist_attr = dist_attr.get_output_dist_attr(X_grad.name)
assert X_grad_dist_attr is not None
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
X_grad_dist_attr)
new_kwargs['X@GRAD'] = [intermediate_var_0.name]
matmul_op_desc = copy_op_with_new_input_output(
main_block, backward_op, **new_kwargs)
# NOTE (JZ-LIANG) trick to skip one allreduce if left operand has not grad
if has_x_grad:
group_ranks = _get_comm_group(process_mesh_group,
process_mesh_shape, parallel_axis,
rank_id)
group = new_process_group(group_ranks)
c_allreduce_sum_op = main_block.append_op(
type='c_allreduce_sum',
inputs={'X': [intermediate_var_0.name]},
outputs={'Out': kwargs['X@GRAD']},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: OpRole.Backward
})
set_comm_op_dist_attr_for_program(c_allreduce_sum_op,
dist_attr.process_mesh,
X_grad_dist_attr, ctx)
else:
# replicate
matmul_op_desc = copy_op_with_new_input_output(main_block, backward_op,
**kwargs)
main_block._sync_with_cpp()
# check if need gradient allreduce
need_gradient_allreduce = False
process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(X_var.name)
mesh_shape = process_mesh.topology
......@@ -187,7 +322,6 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
dp_degree = len(group_ranks)
dp_group = new_process_group(group_ranks)
Y_var = main_block.var(kwargs['Y'][0])
if need_gradient_allreduce and Y_var.is_parameter:
Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0])
allreduce_op = main_block.append_op(
......
......@@ -43,7 +43,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
super(DistributedReshapeImpl0, self).__init__()
self._name = name
self._forward_implemented = True
self._backward_implemented = True
self._backward_implemented = False
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
......@@ -200,7 +200,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
super(DistributedReshapeImpl1, self).__init__()
self._name = name
self._forward_implemented = True
self._backward_implemented = True
self._backward_implemented = False
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
......
......@@ -39,7 +39,7 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
super(DistributedSoftmaxImpl, self).__init__()
self._name = name
self._forward_implemented = False
self._backward_implemented = True
self._backward_implemented = False
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
......
......@@ -39,7 +39,7 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
super(DistributedTranspose2Impl, self).__init__()
self._name = name
self._forward_implemented = False
self._backward_implemented = True
self._backward_implemented = False
def is_input_compatible(self, dist_op):
return True
......
......@@ -22,15 +22,15 @@ import subprocess
import logging
import pickle
import time
import paddle
from paddle.distributed.utils import get_logger
from paddle.distributed.fleet import cloud_utils
import paddle.fluid.core as core
from paddle.fluid import program_guard
from .dist_context import DistributedContext
from .dist_context import get_default_distributed_context
from .dist_context import set_default_distributed_context
from .completion import complete_annotation, complete_backward_annotation
from .completion import complete_annotation, complete_backward_annotation, complete_update_annotation
from .partitioner import Partitioner
from .process_group import get_all_process_groups
from .process_group import get_process_group
......@@ -79,6 +79,7 @@ class AutoParallelizer:
self._need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING")
self._need_rank_mapping = True if self._need_rank_mapping and \
self._need_rank_mapping.lower() == 'true' else False
self._pass_context = None
def _remove_distributed_attrs(self, main_program):
suffix = core.kAutoParallelSuffix()
......@@ -90,28 +91,112 @@ class AutoParallelizer:
if suffix in attr_name:
op._remove_attr(attr_name)
def _apply_serial_forward_pass(self, main_program, startup_program):
# apply amp forward pass
if self._dist_strategy.amp:
auto_parallel_amp_pass = new_pass("auto_parallel_amp_pass",
self._dist_strategy.amp_configs)
auto_parallel_amp_pass.apply_forward(main_program, startup_program,
self._pass_context)
# apply recompute forward pass
if self._dist_strategy.recompute:
auto_parallel_recompute_pass = new_pass(
"auto_parallel_recompute_pass",
self._dist_strategy.recompute_configs)
auto_parallel_recompute_pass.apply_forward(
main_program, startup_program, self._pass_context)
def _generate_backward(self, main_program, startup_program, loss,
parameter_list, no_grad_set, callbacks):
# apply recompute backward pass
if self._dist_strategy.recompute:
assert auto_parallel_recompute_pass
auto_parallel_recompute_pass.apply_forward(
main_program, startup_program, parameter_list, no_grad_set,
self._pass_context)
else:
from paddle.fluid.backward import append_backward
with program_guard(main_program, startup_program):
params_grads = append_backward(
loss,
parameter_list,
no_grad_set,
callbacks,
distop_context=self._dist_context.dist_op_context)
complete_backward_annotation(
main_program, dist_context=self._dist_context)
# apply amp forward pass
if self._dist_strategy.amp:
assert auto_parallel_amp_pass
auto_parallel_amp_pass.apply_backward(main_program, startup_program,
self._pass_context)
return params_grads
def _apply_optimize(self, main_program, startup_program, params_grads):
if self._dist_strategy.sharding:
auto_parallel_sharding_pass = new_pass(
"auto_parallel_sharding_pass", self._dist_strategy)
params_grads = auto_parallel_sharding_pass.apply(
main_program, startup_program, params_grads, self._pass_context)
if self._dist_strategy.gradient_merge:
auto_parallel_gradient_merge_pass = new_pass(
"auto_parallel_gradient_merge_pass",
self._dist_strategy.gradient_merge_configs)
auto_parallel_gradient_merge_pass.apply(
main_program, startup_program, params_grads, self._pass_context)
else:
with program_guard(main_program, startup_program):
optimizer = copy.deepcopy(self._optimizer)
optimize_ops = optimizer.apply_gradients(params_grads)
# update completion
complete_update_annotation(
main_program, dist_context=self._dist_context)
return optimize_ops
def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
completed_main_program = None
serial_main_program = self._main_program.clone()
serial_startup_program = self._startup_program.clone()
serial_loss = serial_main_program.global_block().var(self._loss.name)
# generating serial
if dist_context is None:
# Annotation completion
self._dist_context = DistributedContext()
_logger.info("Start annotation dist attr.")
completed_main_program = complete_annotation(self._main_program,
completed_main_program = complete_annotation(serial_main_program,
self._dist_context)
else:
completed_main_program = self._main_program
completed_main_program = serial_main_program
self._dist_context = copy.deepcopy(dist_context)
# Logical partition
partitioner = Partitioner(self._dist_strategy, self._dist_context, rank)
dist_main_prog, dist_startup_prog = partitioner.transpile_forward(
completed_main_program, self._startup_program)
dist_params_grads = partitioner.apply_backward(
self._loss, completed_main_program, self._startup_program,
dist_main_prog, dist_startup_prog)
dist_optimize_ops = partitioner.apply_optimize(
copy.deepcopy(self._optimizer), dist_params_grads, dist_main_prog,
dist_startup_prog)
# serial forward pass
self._apply_serial_forward_pass(completed_main_program,
serial_startup_program)
# serial backward pass
params_grads = self._generate_backward(
completed_main_program, serial_startup_program, serial_loss,
self._parameter_list, self._no_grad_set, self._callbacks)
# Logical partition
rank = paddle.distributed.get_rank()
partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
completed_main_program, serial_startup_program, params_grads)
# TODO refactor the placement of optimizer
# generate optimize program
dist_optimize_ops = self._apply_optimize(
dist_main_prog, dist_startup_prog, dist_params_grads)
set_grad_var_shape(dist_main_prog, self._dist_context)
......@@ -133,13 +218,15 @@ class AutoParallelizer:
loss,
startup_program,
parameter_list=None,
no_grad_set=None):
no_grad_set=None,
callbacks=None):
assert startup_program is not None
self._loss = loss
self._startup_program = startup_program
self._main_program = loss.block.program
self._parameter_list = parameter_list
self._no_grad_set = no_grad_set
self._callbacks = callbacks
if self._enable_auto_mapping and self._need_rank_mapping:
# Do the mapping pass before parallelization
......@@ -156,6 +243,7 @@ class AutoParallelizer:
self._optimizer, self._cluster)
planner = Planner(
serial_program_info,
self,
algorithm_config={"name": "mcmc",
"max_search_times": 5})
dist_context, _ = planner.search()
......@@ -262,6 +350,7 @@ class AutoParallelizer:
cluster=self._cluster)
planner = Planner(
serial_program_info,
self,
algorithm_config={
"name": "mcmc",
"max_search_times": 5
......@@ -303,3 +392,14 @@ class AutoParallelizer:
self._remove_distributed_attrs(dist_main_prog)
return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == "_main_program" or k == "_startup_program" or k == "_dist_context" or k == "_fleet" or k == "_loss":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
此差异已折叠。
......@@ -386,15 +386,20 @@ class SearchAlgorithm:
class MCMC(SearchAlgorithm):
def __init__(self, serial_program_info, max_search_times=5):
def __init__(self, serial_program_info, parallelizer, max_search_times=5):
super(MCMC, self).__init__("mcmc")
self._serial_program_info = serial_program_info
self._max_search_times = max_search_times
self._parallelizer = parallelizer
@property
def serial_program_info(self):
return self._serial_program_info
@property
def parallelizer(self):
return self._parallelizer
@property
def max_search_times(self):
return self._max_search_times
......@@ -483,7 +488,7 @@ class MCMC(SearchAlgorithm):
cost = None
# get all distributed programs
all_dist_main_program = get_all_distributed_main_program(
self.serial_program_info, dist_context)
self.serial_program_info, dist_context, self.parallelizer)
pipeline_config = [
process_mesh.processes for process_mesh in pipeline_process_meshes
] if pipeline_process_meshes is not None else None
......@@ -829,8 +834,10 @@ class MCMC(SearchAlgorithm):
class Planner:
def __init__(self, serial_program_info, algorithm_config=None):
def __init__(self, serial_program_info, parallelizer,
algorithm_config=None):
self._serial_program_info = serial_program_info
self._parallelizer = parallelizer
self._algorithm_config = algorithm_config
self._algorithm_searcher = self.create_algorithm_searcher(
algorithm_config)
......@@ -847,6 +854,10 @@ class Planner:
def algorithm_searcher(self):
return self._algorithm_searcher
@property
def parallelizer(self):
return self._parallelizer
def create_algorithm_searcher(self, algorithm_config):
name = algorithm_config.get("name", None)
assert name is not None, "Invalid algorithm config."
......@@ -856,9 +867,9 @@ class Planner:
# NOTE: Only GPU clusters are supported now.
max_search_times = algorithm_config.get("max_search_times", None)
algorithm_searcher = MCMC(
self.serial_program_info,
self.serial_program_info, self.parallelizer,
max_search_times) if max_search_times is not None else MCMC(
self.serial_program_info)
self.serial_program_info, self.parallelizer)
else:
raise NotImplementedError(
"Other search algorithms have not been supported now.")
......
......@@ -993,7 +993,9 @@ def set_grad_var_shape(program, dist_context):
block = program.global_block()
vars = block.vars
for op in block.ops:
if op.type == "sum":
if op.type in [
"sum", "check_finite_and_unscale", "update_loss_scaling"
]:
continue
if int(op.attr('op_role')) == int(OpRole.Backward):
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
......@@ -1004,15 +1006,24 @@ def set_grad_var_shape(program, dist_context):
forward_var_name = var_name[:var_name.find("@GRAD")]
if op.type == "c_allreduce_sum" or op.type == "c_identity" or op.type == "scale":
forward_var_name = op.input_arg_names[0]
elif op.type == "matmul_v2_grad":
forward_var_name = None
for output_name in op.output_names:
if var_name in op.output(output_name):
assert "@GRAD" in output_name
input_name = output_name[:output_name.find("@GRAD")]
assert len(op.input(input_name)) == 1
forward_var_name = op.input(input_name)[0]
assert forward_var_name is not None
need_set_shape_list = [
"reshape2_grad", "softmax_with_cross_entropy_grad",
"transpose2_grad", "softmax_grad", "cross_entropy_grad2",
"dropout_grad", "unsqueeze2_grad"
"dropout_grad"
]
forward_list = [
"reshape2", "softmax_with_cross_entropy", "transpose2",
"softmax", "cross_entropy2", "dropout", "unsqueeze2"
"softmax", "cross_entropy2", "dropout"
]
if op.type in need_set_shape_list:
for forward_op in block.ops:
......@@ -1041,6 +1052,23 @@ def set_grad_var_shape(program, dist_context):
grad_var.desc.set_shape(ref_shape)
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
OpRole = core.op_proto_and_checker_maker.OpRole
def is_forward_op(op):
ref_role1 = int(core.op_proto_and_checker_maker.OpRole.Forward)
ref_role2 = int(core.op_proto_and_checker_maker.OpRole.Loss)
op_role = int(op.attr('op_role'))
return OP_ROLE_KEY in op.attr_names and (op_role == ref_role1 or
op_role == ref_role2)
def is_backward_op(op):
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward)
def update_op_dims_mapping_by_default_dist_impl(dist_op):
changed = False
op_dist_attr = dist_op.dist_attr
......@@ -1177,57 +1205,25 @@ def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op):
return changed
def get_all_distributed_main_program(serial_program_info, dist_context):
def get_all_distributed_main_program(serial_program_info, dist_context,
parallelizer):
"Get all distributed main programs by dist_context."
from .dist_context import DistributedOperatorContext
from .dist_context import DistributedOperatorContext, DistributedContext
cluster = serial_program_info.cluster
copied_parallelizer = copy.deepcopy(parallelizer)
all_dist_main_program = []
ranks = paddle.distributed.get_world_size() if cluster is None else len(
cluster.get_all_devices("GPU"))
for rank_id in range(ranks):
used_dist_context = copy.deepcopy(dist_context)
used_dist_context._dist_op_context = DistributedOperatorContext()
dist_main_program, dist_startup_program = get_specified_distributed_main_program(
serial_program_info, used_dist_context, rank_id)
_, _, dist_startup_program, dist_main_program, _ = copied_parallelizer._get_dist_program(
rank_id, used_dist_context)
all_dist_main_program.append(dist_main_program)
return all_dist_main_program
def get_specified_distributed_main_program(serial_program_info, dist_context,
rank_id):
"Get distributed main program by the given dist_context and rank_id."
from .partitioner import Partitioner
from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
from .process_group import _g_process_group_map, ProcessGroup
dist_strategy = paddle.distributed.fleet.DistributedStrategy()
train_program = serial_program_info.train_program
startup_program = serial_program_info.startup_program
loss = serial_program_info.loss
optimizer = serial_program_info.optimizer
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
dist_main_program, dist_startup_program = partitioner.transpile_forward(
train_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, train_program, startup_program, dist_main_program,
dist_startup_program)
opt_ops = partitioner.apply_optimize(
copy.deepcopy(optimizer), dist_params_grads, dist_main_program,
dist_startup_program)
set_grad_var_shape(dist_main_program, dist_context)
make_data_unshard(dist_main_program, dist_startup_program, dist_context)
reshard(dist_main_program, dist_startup_program, rank_id, dist_context)
HAS_SENT.clear()
HAS_RECV.clear()
HAS_ALLGATHER.clear()
_g_process_group_map.clear()
_g_process_group_map[0] = ProcessGroup(0, [])
return dist_main_program, dist_startup_program
class SerialProgramInfo:
def __init__(self,
train_program,
......@@ -1286,7 +1282,6 @@ def get_standalone_cost_data(distributed_programs):
shape = list(map(lambda x: int(x.strip()), shape))
dtype_factor = 1
total_static_input_size += reduce(lambda x, y: x * y, shape)
# print(arg_name_lower)
if op.type == "c_embedding":
arg_name_lower = "w" if arg_name_lower == "weight" else "ids"
for arg_name in op.input_names:
......@@ -1301,7 +1296,8 @@ def get_standalone_cost_data(distributed_programs):
actual_runtime = total_actual_input_size / total_static_input_size * runtime
return actual_runtime
cost_model = paddle.cost_model.CostModel()
import paddle.cost_model as cm
cost_model = cm.CostModel()
cost_model.static_cost_data()
DEFAULT_MULTIPLE = 2
OP_NAME_MAPPING = {
......
......@@ -26,7 +26,7 @@ import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.completion import complete_backward_annotation
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.cost_model import estimate_cost
import paddle.fluid.core as core
......@@ -148,22 +148,33 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
dist_strategy = fleet.DistributedStrategy()
fleet._user_defined_strategy = fleet.DistributedStrategy()
fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
parallelizer = AutoParallelizer(fleet)
parallelizer._dist_context = dist_context
# auto completion
# serial forward & backward completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
loss,
parameter_list=None,
no_grad_set=None,
callbacks=None)
# logical partition
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
complete_train_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, startup_program, auto_parallel_main_prog,
auto_parallel_startup_prog)
optimizer = paddle.fluid.optimizer.AdamOptimizer()
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog,
auto_parallel_startup_prog)
partitioner = Partitioner(dist_context, rank_id)
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
complete_train_program, startup_program, params_grads)
partitioned_optimize_ops = parallelizer._apply_optimize(
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
return auto_parallel_main_prog, auto_parallel_startup_prog
......
......@@ -36,6 +36,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard
......@@ -469,21 +470,31 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
dist_strategy = fleet.DistributedStrategy()
fleet._user_defined_strategy = fleet.DistributedStrategy()
fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
parallelizer = AutoParallelizer(fleet)
parallelizer._dist_context = dist_context
# auto completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
# logical partition
dist_train_program, dist_startup_prog = partitioner.transpile_forward(
complete_train_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, startup_program, dist_train_program,
dist_startup_prog)
optimizer = paddle.fluid.optimizer.AdamOptimizer()
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
dist_train_program, dist_startup_prog)
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
loss,
parameter_list=None,
no_grad_set=None,
callbacks=None)
partitioner = Partitioner(dist_context, rank_id)
dist_train_program, dist_startup_prog, dist_params_grads = partitioner.partition(
complete_train_program, startup_program, params_grads)
partitioned_optimize_ops = parallelizer._apply_optimize(
dist_train_program, dist_startup_prog, dist_params_grads)
reshard(dist_train_program, dist_startup_prog, rank_id, dist_context)
return dist_train_program, dist_startup_prog
......
......@@ -54,9 +54,9 @@ def get_programs(annotated_func):
rank_id = 3
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog = partitioner.transpile_forward(
complete_train_program, start_program)
partitioner = Partitioner(dist_context, rank_id)
test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog, _ = partitioner.partition(
complete_train_program, start_program, [])
return complete_train_program, start_program, test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog, dist_context
......
......@@ -35,6 +35,7 @@ from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_pr
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.utils import _get_comm_group
from paddle.distributed.auto_parallel.process_group import new_process_group
......@@ -790,9 +791,9 @@ class GPTPretrainingCriterion(nn.Layer):
return loss
def gpt_pretrain_forward(train_program, start_program):
def gpt_pretrain_forward(train_program, startup_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
startup_program), utils.unique_name.guard():
batch_size = 16
sequence_len = 512
input_ids = static.data(
......@@ -848,7 +849,19 @@ def gpt_pretrain_forward(train_program, start_program):
loss = criterion(preds, labels, loss_mask)
return train_program, start_program, loss
return train_program, startup_program, loss
class FakeStrategy(object):
def __init__(self):
self.amp = False
self.recompute = False
class FakeFleet(object):
def __init__(self):
self.user_defined_optimizer = None
self._user_defined_strategy = FakeStrategy()
class TestGPTPartitioner(unittest.TestCase):
......@@ -861,38 +874,41 @@ class TestGPTPartitioner(unittest.TestCase):
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
startup_program = static.Program()
parallelizer = AutoParallelizer(FakeFleet())
dist_context = parallelizer._dist_context
dist_context.process_mesh = _global_process_mesh
train_program, start_program, loss = gpt_pretrain_forward(train_program,
start_program)
train_program, startup_program, loss = gpt_pretrain_forward(
train_program, startup_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# serial forward pass
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
# serial backward pass
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
loss,
parameter_list=None,
no_grad_set=None,
callbacks=None)
rank_id = 3
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
complete_train_program, start_program)
dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, start_program,
auto_parallel_main_prog, auto_parallel_startup_prog)
partitioner = Partitioner(dist_context, rank_id)
auto_parallel_main_prog, auto_parallel_startup_prog, params_grads = partitioner.partition(
complete_train_program, startup_program, params_grads)
with open("./test_auto_parallel_partitioner_serial_main_new.txt",
"w") as fw:
fw.write(str(train_program))
with open("./test_auto_parallel_partitioner_serial_startup_new.txt",
"w") as fw:
fw.write(str(start_program))
optimizer = paddle.fluid.optimizer.AdamOptimizer(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog,
auto_parallel_startup_prog)
fw.write(str(startup_program))
from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
set_default_distributed_context(dist_context)
with open("./test_auto_parallel_partitioner_main_new.txt1", "w") as fw:
......@@ -927,7 +943,7 @@ class TestGPTPartitioner(unittest.TestCase):
complete_train_program, weights, 0, 1))
all_params = sorted(
[param.name for param in start_program.all_parameters()])
[param.name for param in startup_program.all_parameters()])
allreduce_grads = [
'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2',
'layer_norm_6.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_7.tmp_2',
......
......@@ -24,6 +24,7 @@ import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.process_group import _g_process_group_map
......@@ -145,22 +146,34 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
# auto completion
fleet._user_defined_strategy = fleet.DistributedStrategy()
fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
parallelizer = AutoParallelizer(fleet)
parallelizer._dist_context = dist_context
# serial forward & backward completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
loss,
parameter_list=None,
no_grad_set=None,
callbacks=None)
# logical partition
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
complete_train_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, startup_program, auto_parallel_main_prog,
auto_parallel_startup_prog)
optimizer = paddle.fluid.optimizer.AdamOptimizer()
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog,
auto_parallel_startup_prog)
partitioner = Partitioner(dist_context, rank_id)
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
complete_train_program, startup_program, params_grads)
partitioned_optimize_ops = parallelizer._apply_optimize(
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
return auto_parallel_main_prog, auto_parallel_startup_prog
......
......@@ -24,6 +24,7 @@ import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
......@@ -109,22 +110,34 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
# auto completion
fleet._user_defined_strategy = fleet.DistributedStrategy()
fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
parallelizer = AutoParallelizer(fleet)
parallelizer._dist_context = dist_context
# serial forward & backward completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
loss,
parameter_list=None,
no_grad_set=None,
callbacks=None)
# logical partition
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
complete_train_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, startup_program, auto_parallel_main_prog,
auto_parallel_startup_prog)
optimizer = paddle.fluid.optimizer.AdamOptimizer()
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog,
auto_parallel_startup_prog)
partitioner = Partitioner(dist_context, rank_id)
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
complete_train_program, startup_program, params_grads)
partitioned_optimize_ops = parallelizer._apply_optimize(
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
return auto_parallel_main_prog, auto_parallel_startup_prog
......
......@@ -24,6 +24,7 @@ import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
......@@ -125,22 +126,32 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
# auto completion
fleet._user_defined_strategy = fleet.DistributedStrategy()
fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
parallelizer = AutoParallelizer(fleet)
parallelizer._dist_context = dist_context
# serial forward & backward completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
loss,
parameter_list=None,
no_grad_set=None,
callbacks=None)
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
# logical partition
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
complete_train_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, startup_program, auto_parallel_main_prog,
auto_parallel_startup_prog)
optimizer = paddle.fluid.optimizer.AdamOptimizer()
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog,
auto_parallel_startup_prog)
partitioner = Partitioner(dist_context, rank_id)
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
complete_train_program, startup_program, params_grads)
partitioned_optimize_ops = parallelizer._apply_optimize(
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
return auto_parallel_main_prog, auto_parallel_startup_prog
......@@ -253,14 +264,15 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 0
dist_context = DistributedContext()
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
partitioner = Partitioner(dist_context, rank_id)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
complete_train_program, startup_program)
reshard(auto_parallel_main_prog, startup_program, rank_id, dist_context)
partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
complete_train_program, startup_program, [])
reshard(partitioned_main_prog, partitioned_startup_prog, rank_id,
dist_context)
# the x should not be slice
self.assertTrue(check_allgather(auto_parallel_main_prog))
self.assertTrue(check_allgather(partitioned_main_prog))
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册