Unverified commit c4fdb057, authored by: J JZ-LIANG, committed by: GitHub

[Auto Parallel] partitioner refactor (#37853)

Parent: b463dff4
...@@ -404,7 +404,7 @@ class DistributedOperatorContext: ...@@ -404,7 +404,7 @@ class DistributedOperatorContext:
def get_cur_src_op(self): def get_cur_src_op(self):
return self._cur_src_op return self._cur_src_op
def prepare_forward_context(self, src_op): def prepare_context(self, src_op):
self.set_cur_src_op(src_op) self.set_cur_src_op(src_op)
...@@ -413,6 +413,7 @@ class DistributedOperatorContext: ...@@ -413,6 +413,7 @@ class DistributedOperatorContext:
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
varnames = [] varnames = []
for varname in src_op.desc.input(input_name): for varname in src_op.desc.input(input_name):
assert varname in self._varname_mapping
varnames.append(self._varname_mapping[varname]) varnames.append(self._varname_mapping[varname])
kinputs[input_name] = varnames kinputs[input_name] = varnames
...@@ -421,29 +422,8 @@ class DistributedOperatorContext: ...@@ -421,29 +422,8 @@ class DistributedOperatorContext:
for output_name in src_op.desc.output_names(): for output_name in src_op.desc.output_names():
varnames = [] varnames = []
for varname in src_op.desc.output(output_name): for varname in src_op.desc.output(output_name):
assert varname in self._varname_mapping
varnames.append(self._varname_mapping[varname]) varnames.append(self._varname_mapping[varname])
koutputs[output_name] = varnames koutputs[output_name] = varnames
return kinputs, koutputs return kinputs, koutputs
def prepare_backward_context(self, backward_op):
self.set_cur_src_op(backward_op)
# build input varname mapping
kinputs = {}
for input_name in backward_op.desc.input_names():
varnames = []
for varname in backward_op.desc.input(input_name):
varnames.append(varname)
kinputs[input_name] = varnames
# build output varname mapping
koutputs = {}
for output_name in backward_op.desc.output_names():
varnames = []
for varname in backward_op.desc.output(output_name):
varnames.append(varname)
koutputs[output_name] = varnames
return kinputs, koutputs
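With prepare_backward_context removed, backward ops are dispatched through the same prepare_context as forward ops, so every serial varname is translated through _varname_mapping before an impl sees it (hence the new asserts). A hedged sketch of the returned dicts, assuming the mapping was filled in earlier by the partitioner; the concrete names below are illustrative only:

    # _varname_mapping (illustrative): {'x_0': 'x_0', 'w_0': 'w_0', 'out_0': 'out_0'}
    kinputs, koutputs = dist_op_context.prepare_context(src_op)
    # kinputs  -> {'X': ['x_0'], 'Y': ['w_0']}   one list of dist varnames per input slot
    # koutputs -> {'Out': ['out_0']}             same for every output slot
    # a distributed op impl is then invoked with the merged mapping, e.g.
    #     impl.forward(ctx, **kinputs, **koutputs)   or   impl.backward(ctx, **kinputs, **koutputs)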
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
from ..dist_attribute import OperatorDistributedAttribute
_g_distributed_operator_impl_registries = {} _g_distributed_operator_impl_registries = {}
...@@ -138,3 +140,46 @@ def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr): ...@@ -138,3 +140,46 @@ def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
exact_shape.append(new_shape) exact_shape.append(new_shape)
return exact_shape return exact_shape
def set_comm_op_dist_attr_for_program(new_op, process_mesh, tensor_dist_attr,
ctx):
assert process_mesh is not None
assert tensor_dist_attr is not None
new_op_dist_attr = OperatorDistributedAttribute()
new_op_dist_attr.process_mesh = process_mesh
for input_varname in new_op.desc.input_arg_names():
new_op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr)
for output_varname in new_op.desc.output_arg_names():
new_op_dist_attr.set_output_dist_attr(output_varname, tensor_dist_attr)
ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx):
ref_dist_attr = ctx.get_op_dist_attr_for_program(ref_op)
new_op_dist_attr = OperatorDistributedAttribute()
new_op_dist_attr.process_mesh = ref_dist_attr.process_mesh
for input_name in ref_op.input_names:
assert input_name in new_op.input_names
assert len(ref_op.input(input_name)) == 1
assert len(new_op.input(input_name)) == 1
ref_tensor_dist_attr = ref_dist_attr.get_input_dist_attr(
ref_op.input(input_name)[0])
new_op_dist_attr.set_input_dist_attr(
new_op.input(input_name)[0], ref_tensor_dist_attr)
for output_name in ref_op.output_names:
assert output_name in new_op.output_names
assert len(ref_op.output(output_name)) == 1
assert len(new_op.output(output_name)) == 1
ref_tensor_dist_attr = ref_dist_attr.get_output_dist_attr(
ref_op.output(output_name)[0])
new_op_dist_attr.set_output_dist_attr(
new_op.output(output_name)[0], ref_tensor_dist_attr)
ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
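These two helpers make sure that ops inserted while building the dist program never end up without a distributed attribute: the first stamps one tensor dist_attr onto every input and output of a communication op, the second mirrors the one-to-one attrs of the serial op a new op was copied from. A hedged usage sketch matching the call sites added later in this commit; main_block, process_mesh, group, grad_var, grad_dist_attr, new_op and ref_op are assumed to be in scope:

    # a collective inserted by a backward impl
    allreduce_op = main_block.append_op(
        type='c_allreduce_sum',
        inputs={'X': [grad_var]},
        outputs={'Out': [grad_var]},
        attrs={'ring_id': group.id, 'use_calc_stream': True})
    set_comm_op_dist_attr_for_program(allreduce_op, process_mesh, grad_dist_attr, ctx)

    # a compute op replicated from a serial op with single-tensor slots
    naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx)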
...@@ -66,7 +66,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -66,7 +66,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
main_block = dist_op_context.get_dst_main_program().global_block() main_block = dist_op_context.get_dst_main_program().global_block()
startup_block = dist_op_context.get_dst_startup_program().global_block() startup_block = dist_op_context.get_dst_startup_program().global_block()
src_op = dist_op_context.get_cur_src_op() src_op = dist_op_context.get_cur_src_op()
varname_mapping = dist_op_context.get_varname_mapping()
rank_id = dist_op_context.get_rank_id() rank_id = dist_op_context.get_rank_id()
# check validation of inputs / outputs # check validation of inputs / outputs
...@@ -153,6 +152,31 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -153,6 +152,31 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
str(backward_op)) str(backward_op))
rank_id = dist_op_context.get_rank_id() rank_id = dist_op_context.get_rank_id()
# check validation of inputs / outputs
for input_name in backward_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format(
input_name)
assert len(kwargs[input_name]) == len(
backward_op.desc.input(input_name)
), "number of tensor for input [{}] is not match".format(input_name)
for output_name in backward_op.desc.output_names():
assert output_name in kwargs, "input [{}] is not given".format(
output_name)
assert len(kwargs[output_name]) == len(
backward_op.desc.output(output_name)
), "number of tensor for input [{}] is not match".format(
output_name)
# replicate op in dist program
dist_op_desc = main_block.desc.append_op()
dist_op_desc.copy_from(backward_op.desc)
for input_name in backward_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in backward_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])
main_block._sync_with_cpp()
# check if need gradient allreduce # check if need gradient allreduce
# if there is a non-gradient & non-parameter input and its batch dimension is sharded,
# we need to insert a gradient allreduce for the parameter gradients in its outputs
......
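The truncated tail of this hunk is the data-parallel check the comments above describe; the pattern it relies on is the same one visible in dist_matmul.py further down. A rough sketch only, under the assumption that batch_size_axis is the mesh axis the batch dimension is sharded on and that parameter_grad_names is a pre-filtered list of parameter gradients:

    # hypothetical sketch, not the committed code
    if need_gradient_allreduce:
        group_ranks = _get_comm_group(process_mesh.processes, process_mesh.topology,
                                      batch_size_axis, rank_id)
        dp_group = new_process_group(group_ranks)
        for grad_name in parameter_grad_names:
            grad_var = main_block.var(grad_name)
            main_block.append_op(
                type='c_allreduce_sum',
                inputs={'X': [grad_var]},
                outputs={'Out': [grad_var]},
                attrs={'ring_id': dp_group.id,
                       'use_calc_stream': True,
                       OP_ROLE_KEY: OpRole.Backward})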
...@@ -16,14 +16,14 @@ from .common import infer_shape ...@@ -16,14 +16,14 @@ from .common import infer_shape
from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl, set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program
from ..utils import is_dim_shard from ..utils import is_dim_shard
from ..utils import is_dim_replicate from ..utils import is_dim_replicate
from ..utils import is_valid_list_index from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping
from ..dist_attribute import OperatorDistributedAttribute from ..dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.fluid import core, unique_name from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.framework import Program, Parameter, Variable, program_guard
...@@ -329,9 +329,6 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -329,9 +329,6 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh,
rank_id) rank_id)
# check if need gradient allreduce
need_gradient_allreduce = False
assert 'Ids' in kwargs, "input [{}] is not given".format('Ids') assert 'Ids' in kwargs, "input [{}] is not given".format('Ids')
assert 'W' in kwargs, "input [{}] is not given".format('W') assert 'W' in kwargs, "input [{}] is not given".format('W')
assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out') assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out')
...@@ -355,6 +352,84 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -355,6 +352,84 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
kwargs['W@GRAD']) kwargs['W@GRAD'])
Ids_var = main_block.var(kwargs['Ids'][0]) Ids_var = main_block.var(kwargs['Ids'][0])
Weight_var = main_block.var(kwargs['W'][0])
Out_grad = main_block.var(kwargs['Out@GRAD'][0])
Weight_grad = main_block.var(kwargs['W@GRAD'][0])
embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
Weight_var.name)[0]
assert embedding_row_dim_mapping >= 0, "row_parallel_embedding's row should be sharded along a specific mesh axis, but got [{}]".format(
embedding_row_dim_mapping)
process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_group = dist_attr.process_mesh.processes
# A generalized method to calculate the embedding offset using a Cartesian product
relative_idx = _get_idx_in_axis(process_mesh_group, process_mesh_shape,
embedding_row_dim_mapping, rank_id)
per_part_size = Weight_var.shape[0]
relative_idx = relative_idx * per_part_size
check_variable_and_dtype(
Out_grad, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity')
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_embedding", '@tmp_0@GRAD'])),
dtype=Out_grad.dtype,
shape=Out_grad.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=Out_grad.stop_gradient)
# copy Out_grad's dist_attr to intermediate_var_0's dist_attr
out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name)
assert out_grad_dist_attr is not None
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
out_grad_dist_attr)
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
embedding_row_dim_mapping, rank_id)
group = new_process_group(group_ranks)
c_identity_op = main_block.append_op(
type='c_identity',
inputs={'X': [Out_grad]},
outputs={'Out': intermediate_var_0},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: OpRole.Backward,
})
check_variable_and_dtype(intermediate_var_0, 'x',
['float16', 'float32', 'float64'], 'linear')
check_dtype(intermediate_var_0.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
set_comm_op_dist_attr_for_program(c_identity_op, dist_attr.process_mesh,
out_grad_dist_attr, ctx)
main_block._sync_with_cpp()
c_embedding_grad_op_desc = main_block.desc.append_op()
c_embedding_grad_op_desc.set_type("c_embedding_grad")
c_embedding_grad_op_desc.set_input('Ids', [Ids_var.name])
c_embedding_grad_op_desc.set_input('W', [Weight_var.name])
c_embedding_grad_op_desc.set_input('Out@GRAD',
[intermediate_var_0.name])
c_embedding_grad_op_desc.set_output('W@GRAD', [Weight_grad.name])
c_embedding_grad_op_desc._set_attr('start_index', relative_idx)
c_embedding_grad_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward)
main_block._sync_with_cpp()
c_embedding_grad_op = main_block.ops[-1]
assert c_embedding_grad_op.type == "c_embedding_grad"
naive_copy_op_dist_attr_for_program(c_embedding_grad_op, backward_op,
ctx)
# check if need gradient allreduce
need_gradient_allreduce = False
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(Ids_var.name) var_dim_mapping = dist_attr.get_input_dims_mapping(Ids_var.name)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.topology
......
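The offset calculation above reduces to: take this rank's coordinate along the mesh axis the embedding rows are sharded on, then multiply by the per-partition row count. A small self-contained illustration of that arithmetic, with _idx_in_axis as a stand-in for the _get_idx_in_axis utility (assumed to index a row-major process mesh):

    import numpy as np

    def _idx_in_axis(processes, mesh_shape, axis, rank):
        # coordinate of `rank` along `axis` in a row-major process mesh
        coord = np.unravel_index(processes.index(rank), mesh_shape)
        return int(coord[axis])

    processes, mesh_shape = [0, 1, 2, 3], (2, 2)   # 2 x 2 mesh
    per_part_size = 8                              # rows held by each shard
    for rank in processes:
        start_index = _idx_in_axis(processes, mesh_shape, 0, rank) * per_part_size
        print(rank, start_index)                   # ranks 0,1 -> 0 ; ranks 2,3 -> 8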
...@@ -12,11 +12,13 @@ ...@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
import copy
from .common import infer_shape from .common import infer_shape
from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program
from ..utils import is_dim_shard from ..utils import is_dim_shard
from ..utils import is_dim_replicate from ..utils import is_dim_replicate
from ..utils import is_valid_list_index from ..utils import is_valid_list_index
...@@ -33,6 +35,20 @@ from ..process_group import new_process_group ...@@ -33,6 +35,20 @@ from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank from ..utils import _get_comm_group, _get_corresponding_rank
def copy_op_with_new_input_output(block, src_op, **kwargs):
dist_op_desc = block.desc.append_op()
dist_op_desc.copy_from(src_op.desc)
for input_name in src_op.desc.input_names():
assert input_name in kwargs
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in src_op.desc.output_names():
assert output_name in kwargs
dist_op_desc.set_output(output_name, kwargs[output_name])
block._sync_with_cpp()
return dist_op_desc
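copy_op_with_new_input_output works purely at the OpDesc level, so any serial op can be re-emitted into the dist main block with renamed tensors. A hedged sketch of a generic call; block, serial_op and name_mapping are assumptions, and the kwargs must carry one list of new varnames per input/output slot:

    remapped = {}
    for slot in serial_op.desc.input_names():
        remapped[slot] = [name_mapping[v] for v in serial_op.desc.input(slot)]
    for slot in serial_op.desc.output_names():
        remapped[slot] = [name_mapping[v] for v in serial_op.desc.output(slot)]
    new_desc = copy_op_with_new_input_output(block, serial_op, **remapped)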
def _update_dims_mapping_for_matmul(dist_op): def _update_dims_mapping_for_matmul(dist_op):
changed = False changed = False
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
...@@ -141,15 +157,11 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -141,15 +157,11 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
if rank_id not in dist_attr.process_mesh.processes: if rank_id not in dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id) rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id)
# check if need gradient allreduce
need_gradient_allreduce = False
assert 'Y' in kwargs, "input [{}] is not given".format('Y') assert 'Y' in kwargs, "input [{}] is not given".format('Y')
assert 'X' in kwargs, "input [{}] is not given".format('X') assert 'X' in kwargs, "input [{}] is not given".format('X')
assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD') assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD')
assert 'Y@GRAD' in kwargs, "output [{}] is not given".format('Y@GRAD') assert 'Y@GRAD' in kwargs, "output [{}] is not given".format('Y@GRAD')
assert 'X@GRAD' in kwargs, "output [{}] is not given".format('X@GRAD') assert 'X@GRAD' in kwargs, "output [{}] is not given".format('X@GRAD')
assert len( assert len(
kwargs['Y'] kwargs['Y']
) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format(
...@@ -166,15 +178,138 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -166,15 +178,138 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
kwargs['Y@GRAD'] kwargs['Y@GRAD']
) == 1, "row_parallel_embedding output Ids take 1 variable but got {}".format( ) == 1, "row_parallel_embedding output Ids take 1 variable but got {}".format(
kwargs['Y@GRAD']) kwargs['Y@GRAD'])
assert len(
kwargs['X@GRAD']
) == 1, "row_parallel_embedding output Ids take 1 variable but got {}".format(
kwargs['X@GRAD'])
X_var = main_block.var(kwargs['X'][0]) X_var = main_block.var(kwargs['X'][0])
Y_var = main_block.var(kwargs['Y'][0])
Out_grad = main_block.var(kwargs['Out@GRAD'][0])
Y_grad = main_block.var(kwargs['Y@GRAD'][0])
assert not X_var.is_parameter, "left operand(X) [{}] of dist matmul should not be parameter".format( assert not X_var.is_parameter, "left operand(X) [{}] of dist matmul should not be parameter".format(
X_var.name) X_var.name)
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name)
process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_group = dist_attr.process_mesh.processes
assert len(
Y_var_dim_mapping
) == 2, "dist matmual only support Y operand with 2 dims now but Y({})'s dim is [{}]".format(
Y_var.name, Y_var_dim_mapping)
Y_var_partitioned = False
for dim in Y_var_dim_mapping:
if dim >= 0 and process_mesh_shape[dim] > 0:
Y_var_partitioned = True
break
if Y_var.is_parameter and Y_var_partitioned:
if Y_var_dim_mapping[0] >= 0:
# row parallel: c_identity + matmul
assert Y_var_dim_mapping[1] < 0
parallel_axis = Y_var_dim_mapping[0]
check_variable_and_dtype(
Out_grad, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_identity')
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_identity", 'tmp'])) + "@GRAD",
dtype=Out_grad.dtype,
shape=Out_grad.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=Out_grad.stop_gradient)
# copy Out_grad's dist_attr to intermediate_var_0's dist_attr
out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name)
assert out_grad_dist_attr is not None
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
out_grad_dist_attr)
group_ranks = _get_comm_group(
process_mesh_group, process_mesh_shape, parallel_axis, rank_id)
group = new_process_group(group_ranks)
c_identity_op = main_block.append_op(
type='c_identity',
inputs={'X': [Out_grad]},
outputs={'Out': intermediate_var_0},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: OpRole.Backward,
})
check_variable_and_dtype(intermediate_var_0, 'x',
['float16', 'float32', 'float64'],
'linear')
check_dtype(intermediate_var_0.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
set_comm_op_dist_attr_for_program(
c_identity_op, dist_attr.process_mesh, out_grad_dist_attr, ctx)
new_kwargs = copy.deepcopy(kwargs)
new_kwargs['Out@GRAD'] = [intermediate_var_0.name]
matmul_op_desc = copy_op_with_new_input_output(
main_block, backward_op, **new_kwargs)
else:
# col parallel: matmul + allreduce
assert Y_var_dim_mapping[0] < 0
parallel_axis = Y_var_dim_mapping[1]
new_kwargs = copy.deepcopy(kwargs)
# NOTE (JZ-LIANG) should allow the left operand to be empty for matmul grad
has_x_grad = len(kwargs['X@GRAD']) > 0
if has_x_grad:
assert len(kwargs['X@GRAD']) == 1
X_grad = main_block.var(kwargs['X@GRAD'][0])
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_identity", 'tmp'])) + "@GRAD",
dtype=X_grad.dtype,
shape=X_grad.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=X_grad.stop_gradient)
X_grad_dist_attr = dist_attr.get_output_dist_attr(X_grad.name)
assert X_grad_dist_attr is not None
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
X_grad_dist_attr)
new_kwargs['X@GRAD'] = [intermediate_var_0.name]
matmul_op_desc = copy_op_with_new_input_output(
main_block, backward_op, **new_kwargs)
# NOTE (JZ-LIANG) trick to skip one allreduce if the left operand has no grad
if has_x_grad:
group_ranks = _get_comm_group(process_mesh_group,
process_mesh_shape, parallel_axis,
rank_id)
group = new_process_group(group_ranks)
c_allreduce_sum_op = main_block.append_op(
type='c_allreduce_sum',
inputs={'X': [intermediate_var_0.name]},
outputs={'Out': kwargs['X@GRAD']},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: OpRole.Backward
})
set_comm_op_dist_attr_for_program(c_allreduce_sum_op,
dist_attr.process_mesh,
X_grad_dist_attr, ctx)
else:
# replicate
matmul_op_desc = copy_op_with_new_input_output(main_block, backward_op,
**kwargs)
main_block._sync_with_cpp()
# check if need gradient allreduce
need_gradient_allreduce = False
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(X_var.name) var_dim_mapping = dist_attr.get_input_dims_mapping(X_var.name)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.topology
...@@ -187,7 +322,6 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -187,7 +322,6 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
dp_degree = len(group_ranks) dp_degree = len(group_ranks)
dp_group = new_process_group(group_ranks) dp_group = new_process_group(group_ranks)
Y_var = main_block.var(kwargs['Y'][0])
if need_gradient_allreduce and Y_var.is_parameter: if need_gradient_allreduce and Y_var.is_parameter:
Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0]) Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0])
allreduce_op = main_block.append_op( allreduce_op = main_block.append_op(
......
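The backward branches added above mirror the forward split: which collective is required depends on which axis of the parameter Y is sharded. A compact summary of that decision as a sketch (not the committed control flow):

    def backward_collectives(y_dim_mapping):
        """Comm ops a dist matmul backward needs for a sharded parameter Y."""
        if y_dim_mapping[0] >= 0:
            # row parallel: Out@GRAD is passed through c_identity inside the group,
            # then the replicated matmul grad runs locally
            return ['c_identity', 'matmul_grad']
        if y_dim_mapping[1] >= 0:
            # column parallel: each rank produces a partial X@GRAD that is summed
            # afterwards (the allreduce is skipped entirely when X has no grad)
            return ['matmul_grad', 'c_allreduce_sum']
        # Y replicated: plain replicated matmul grad
        return ['matmul_grad']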
...@@ -43,7 +43,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -43,7 +43,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
super(DistributedReshapeImpl0, self).__init__() super(DistributedReshapeImpl0, self).__init__()
self._name = name self._name = name
self._forward_implemented = True self._forward_implemented = True
self._backward_implemented = True self._backward_implemented = False
def is_input_compatible(self, dist_op): def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
...@@ -200,7 +200,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -200,7 +200,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
super(DistributedReshapeImpl1, self).__init__() super(DistributedReshapeImpl1, self).__init__()
self._name = name self._name = name
self._forward_implemented = True self._forward_implemented = True
self._backward_implemented = True self._backward_implemented = False
def is_input_compatible(self, dist_op): def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
......
...@@ -39,7 +39,7 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): ...@@ -39,7 +39,7 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
super(DistributedSoftmaxImpl, self).__init__() super(DistributedSoftmaxImpl, self).__init__()
self._name = name self._name = name
self._forward_implemented = False self._forward_implemented = False
self._backward_implemented = True self._backward_implemented = False
def is_input_compatible(self, dist_op): def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
......
...@@ -39,7 +39,7 @@ class DistributedTranspose2Impl(DistributedOperatorImpl): ...@@ -39,7 +39,7 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
super(DistributedTranspose2Impl, self).__init__() super(DistributedTranspose2Impl, self).__init__()
self._name = name self._name = name
self._forward_implemented = False self._forward_implemented = False
self._backward_implemented = True self._backward_implemented = False
def is_input_compatible(self, dist_op): def is_input_compatible(self, dist_op):
return True return True
......
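Flipping _backward_implemented to False for reshape, softmax and transpose routes their grad ops to the replication-based backward that DistributedDefaultImpl0 gains in this commit. A hedged sketch of the dispatch rule this relies on; dist_op_impl is the matched impl and default_impl the registered default impl, and the actual partitioner code is not shown in this diff:

    # assumed fallback logic inside the partitioner
    if dist_op_impl._backward_implemented:
        dist_op_impl.backward(ctx, **kinputs, **koutputs)
    else:
        default_impl.backward(ctx, **kinputs, **koutputs)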
...@@ -22,15 +22,15 @@ import subprocess ...@@ -22,15 +22,15 @@ import subprocess
import logging import logging
import pickle import pickle
import time import time
import paddle import paddle
from paddle.distributed.utils import get_logger from paddle.distributed.utils import get_logger
from paddle.distributed.fleet import cloud_utils from paddle.distributed.fleet import cloud_utils
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid import program_guard
from .dist_context import DistributedContext from .dist_context import DistributedContext
from .dist_context import get_default_distributed_context from .dist_context import get_default_distributed_context
from .dist_context import set_default_distributed_context from .dist_context import set_default_distributed_context
from .completion import complete_annotation, complete_backward_annotation from .completion import complete_annotation, complete_backward_annotation, complete_update_annotation
from .partitioner import Partitioner from .partitioner import Partitioner
from .process_group import get_all_process_groups from .process_group import get_all_process_groups
from .process_group import get_process_group from .process_group import get_process_group
...@@ -79,6 +79,7 @@ class AutoParallelizer: ...@@ -79,6 +79,7 @@ class AutoParallelizer:
self._need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING") self._need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING")
self._need_rank_mapping = True if self._need_rank_mapping and \ self._need_rank_mapping = True if self._need_rank_mapping and \
self._need_rank_mapping.lower() == 'true' else False self._need_rank_mapping.lower() == 'true' else False
self._pass_context = None
def _remove_distributed_attrs(self, main_program): def _remove_distributed_attrs(self, main_program):
suffix = core.kAutoParallelSuffix() suffix = core.kAutoParallelSuffix()
...@@ -90,28 +91,112 @@ class AutoParallelizer: ...@@ -90,28 +91,112 @@ class AutoParallelizer:
if suffix in attr_name: if suffix in attr_name:
op._remove_attr(attr_name) op._remove_attr(attr_name)
def _apply_serial_forward_pass(self, main_program, startup_program):
# apply amp forward pass
if self._dist_strategy.amp:
auto_parallel_amp_pass = new_pass("auto_parallel_amp_pass",
self._dist_strategy.amp_configs)
auto_parallel_amp_pass.apply_forward(main_program, startup_program,
self._pass_context)
# apply recompute forward pass
if self._dist_strategy.recompute:
auto_parallel_recompute_pass = new_pass(
"auto_parallel_recompute_pass",
self._dist_strategy.recompute_configs)
auto_parallel_recompute_pass.apply_forward(
main_program, startup_program, self._pass_context)
def _generate_backward(self, main_program, startup_program, loss,
parameter_list, no_grad_set, callbacks):
# apply recompute backward pass
if self._dist_strategy.recompute:
assert auto_parallel_recompute_pass
auto_parallel_recompute_pass.apply_forward(
main_program, startup_program, parameter_list, no_grad_set,
self._pass_context)
else:
from paddle.fluid.backward import append_backward
with program_guard(main_program, startup_program):
params_grads = append_backward(
loss,
parameter_list,
no_grad_set,
callbacks,
distop_context=self._dist_context.dist_op_context)
complete_backward_annotation(
main_program, dist_context=self._dist_context)
# apply amp forward pass
if self._dist_strategy.amp:
assert auto_parallel_amp_pass
auto_parallel_amp_pass.apply_backward(main_program, startup_program,
self._pass_context)
return params_grads
def _apply_optimize(self, main_program, startup_program, params_grads):
if self._dist_strategy.sharding:
auto_parallel_sharding_pass = new_pass(
"auto_parallel_sharding_pass", self._dist_strategy)
params_grads = auto_parallel_sharding_pass.apply(
main_program, startup_program, params_grads, self._pass_context)
if self._dist_strategy.gradient_merge:
auto_parallel_gradient_merge_pass = new_pass(
"auto_parallel_gradient_merge_pass",
self._dist_strategy.gradient_merge_configs)
auto_parallel_gradient_merge_pass.apply(
main_program, startup_program, params_grads, self._pass_context)
else:
with program_guard(main_program, startup_program):
optimizer = copy.deepcopy(self._optimizer)
optimize_ops = optimizer.apply_gradients(params_grads)
# update completion
complete_update_annotation(
main_program, dist_context=self._dist_context)
return optimize_ops
def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False): def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
completed_main_program = None completed_main_program = None
serial_main_program = self._main_program.clone()
serial_startup_program = self._startup_program.clone()
serial_loss = serial_main_program.global_block().var(self._loss.name)
# generate the serial program
if dist_context is None: if dist_context is None:
# Annotation completion # Annotation completion
self._dist_context = DistributedContext() self._dist_context = DistributedContext()
_logger.info("Start annotation dist attr.") _logger.info("Start annotation dist attr.")
completed_main_program = complete_annotation(self._main_program, completed_main_program = complete_annotation(serial_main_program,
self._dist_context) self._dist_context)
else: else:
completed_main_program = self._main_program completed_main_program = serial_main_program
self._dist_context = copy.deepcopy(dist_context) self._dist_context = copy.deepcopy(dist_context)
# serial forward pass
self._apply_serial_forward_pass(completed_main_program,
serial_startup_program)
# serial backward pass
params_grads = self._generate_backward(
completed_main_program, serial_startup_program, serial_loss,
self._parameter_list, self._no_grad_set, self._callbacks)
# Logical partition # Logical partition
partitioner = Partitioner(self._dist_strategy, self._dist_context, rank) rank = paddle.distributed.get_rank()
dist_main_prog, dist_startup_prog = partitioner.transpile_forward( partitioner = Partitioner(self._dist_context, rank)
completed_main_program, self._startup_program) dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
dist_params_grads = partitioner.apply_backward( completed_main_program, serial_startup_program, params_grads)
self._loss, completed_main_program, self._startup_program,
dist_main_prog, dist_startup_prog) # TODO refactor the placement of optimizer
dist_optimize_ops = partitioner.apply_optimize( # generate optimize program
copy.deepcopy(self._optimizer), dist_params_grads, dist_main_prog, dist_optimize_ops = self._apply_optimize(
dist_startup_prog) dist_main_prog, dist_startup_prog, dist_params_grads)
set_grad_var_shape(dist_main_prog, self._dist_context) set_grad_var_shape(dist_main_prog, self._dist_context)
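Taken together, _get_dist_program now runs the serial passes itself and hands the partitioner a single job. A condensed sketch of the new pipeline, using only the calls visible in this diff (argument plumbing abbreviated):

    # serial side
    complete_main = complete_annotation(serial_main_program, self._dist_context)
    self._apply_serial_forward_pass(complete_main, serial_startup_program)
    params_grads = self._generate_backward(
        complete_main, serial_startup_program, serial_loss,
        self._parameter_list, self._no_grad_set, self._callbacks)

    # per-rank side
    partitioner = Partitioner(self._dist_context, rank)
    dist_main, dist_startup, dist_params_grads = partitioner.partition(
        complete_main, serial_startup_program, params_grads)
    dist_optimize_ops = self._apply_optimize(
        dist_main, dist_startup, dist_params_grads)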
...@@ -133,13 +218,15 @@ class AutoParallelizer: ...@@ -133,13 +218,15 @@ class AutoParallelizer:
loss, loss,
startup_program, startup_program,
parameter_list=None, parameter_list=None,
no_grad_set=None): no_grad_set=None,
callbacks=None):
assert startup_program is not None assert startup_program is not None
self._loss = loss self._loss = loss
self._startup_program = startup_program self._startup_program = startup_program
self._main_program = loss.block.program self._main_program = loss.block.program
self._parameter_list = parameter_list self._parameter_list = parameter_list
self._no_grad_set = no_grad_set self._no_grad_set = no_grad_set
self._callbacks = callbacks
if self._enable_auto_mapping and self._need_rank_mapping: if self._enable_auto_mapping and self._need_rank_mapping:
# Do the mapping pass before parallelization # Do the mapping pass before parallelization
...@@ -156,6 +243,7 @@ class AutoParallelizer: ...@@ -156,6 +243,7 @@ class AutoParallelizer:
self._optimizer, self._cluster) self._optimizer, self._cluster)
planner = Planner( planner = Planner(
serial_program_info, serial_program_info,
self,
algorithm_config={"name": "mcmc", algorithm_config={"name": "mcmc",
"max_search_times": 5}) "max_search_times": 5})
dist_context, _ = planner.search() dist_context, _ = planner.search()
...@@ -262,6 +350,7 @@ class AutoParallelizer: ...@@ -262,6 +350,7 @@ class AutoParallelizer:
cluster=self._cluster) cluster=self._cluster)
planner = Planner( planner = Planner(
serial_program_info, serial_program_info,
self,
algorithm_config={ algorithm_config={
"name": "mcmc", "name": "mcmc",
"max_search_times": 5 "max_search_times": 5
...@@ -303,3 +392,14 @@ class AutoParallelizer: ...@@ -303,3 +392,14 @@ class AutoParallelizer:
self._remove_distributed_attrs(dist_main_prog) self._remove_distributed_attrs(dist_main_prog)
return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == "_main_program" or k == "_startup_program" or k == "_dist_context" or k == "_fleet" or k == "_loss":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
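The custom __deepcopy__ exists because the MCMC planner deep-copies the parallelizer once per candidate (see get_all_distributed_main_program below), and cloning whole Programs or the fleet handle for every sample would be wasteful. A small illustration of the resulting behaviour, assuming parallelizer is an already-initialized AutoParallelizer:

    import copy

    clone = copy.deepcopy(parallelizer)
    assert clone is not parallelizer
    # heavyweight members are shared by reference ...
    assert clone._main_program is parallelizer._main_program
    assert clone._dist_context is parallelizer._dist_context
    # ... while ordinary attributes (strategy flags, pass context, ...) are deep-copied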
This diff is collapsed.
...@@ -386,15 +386,20 @@ class SearchAlgorithm: ...@@ -386,15 +386,20 @@ class SearchAlgorithm:
class MCMC(SearchAlgorithm): class MCMC(SearchAlgorithm):
def __init__(self, serial_program_info, max_search_times=5): def __init__(self, serial_program_info, parallelizer, max_search_times=5):
super(MCMC, self).__init__("mcmc") super(MCMC, self).__init__("mcmc")
self._serial_program_info = serial_program_info self._serial_program_info = serial_program_info
self._max_search_times = max_search_times self._max_search_times = max_search_times
self._parallelizer = parallelizer
@property @property
def serial_program_info(self): def serial_program_info(self):
return self._serial_program_info return self._serial_program_info
@property
def parallelizer(self):
return self._parallelizer
@property @property
def max_search_times(self): def max_search_times(self):
return self._max_search_times return self._max_search_times
...@@ -483,7 +488,7 @@ class MCMC(SearchAlgorithm): ...@@ -483,7 +488,7 @@ class MCMC(SearchAlgorithm):
cost = None cost = None
# get all distributed programs # get all distributed programs
all_dist_main_program = get_all_distributed_main_program( all_dist_main_program = get_all_distributed_main_program(
self.serial_program_info, dist_context) self.serial_program_info, dist_context, self.parallelizer)
pipeline_config = [ pipeline_config = [
process_mesh.processes for process_mesh in pipeline_process_meshes process_mesh.processes for process_mesh in pipeline_process_meshes
] if pipeline_process_meshes is not None else None ] if pipeline_process_meshes is not None else None
...@@ -829,8 +834,10 @@ class MCMC(SearchAlgorithm): ...@@ -829,8 +834,10 @@ class MCMC(SearchAlgorithm):
class Planner: class Planner:
def __init__(self, serial_program_info, algorithm_config=None): def __init__(self, serial_program_info, parallelizer,
algorithm_config=None):
self._serial_program_info = serial_program_info self._serial_program_info = serial_program_info
self._parallelizer = parallelizer
self._algorithm_config = algorithm_config self._algorithm_config = algorithm_config
self._algorithm_searcher = self.create_algorithm_searcher( self._algorithm_searcher = self.create_algorithm_searcher(
algorithm_config) algorithm_config)
...@@ -847,6 +854,10 @@ class Planner: ...@@ -847,6 +854,10 @@ class Planner:
def algorithm_searcher(self): def algorithm_searcher(self):
return self._algorithm_searcher return self._algorithm_searcher
@property
def parallelizer(self):
return self._parallelizer
def create_algorithm_searcher(self, algorithm_config): def create_algorithm_searcher(self, algorithm_config):
name = algorithm_config.get("name", None) name = algorithm_config.get("name", None)
assert name is not None, "Invalid algorithm config." assert name is not None, "Invalid algorithm config."
...@@ -856,9 +867,9 @@ class Planner: ...@@ -856,9 +867,9 @@ class Planner:
# NOTE: Only GPU clusters are supported now. # NOTE: Only GPU clusters are supported now.
max_search_times = algorithm_config.get("max_search_times", None) max_search_times = algorithm_config.get("max_search_times", None)
algorithm_searcher = MCMC( algorithm_searcher = MCMC(
self.serial_program_info, self.serial_program_info, self.parallelizer,
max_search_times) if max_search_times is not None else MCMC( max_search_times) if max_search_times is not None else MCMC(
self.serial_program_info) self.serial_program_info, self.parallelizer)
else: else:
raise NotImplementedError( raise NotImplementedError(
"Other search algorithms have not been supported now.") "Other search algorithms have not been supported now.")
......
...@@ -993,7 +993,9 @@ def set_grad_var_shape(program, dist_context): ...@@ -993,7 +993,9 @@ def set_grad_var_shape(program, dist_context):
block = program.global_block() block = program.global_block()
vars = block.vars vars = block.vars
for op in block.ops: for op in block.ops:
if op.type == "sum": if op.type in [
"sum", "check_finite_and_unscale", "update_loss_scaling"
]:
continue continue
if int(op.attr('op_role')) == int(OpRole.Backward): if int(op.attr('op_role')) == int(OpRole.Backward):
op_dist_attr = dist_context.get_op_dist_attr_for_program(op) op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
...@@ -1004,15 +1006,24 @@ def set_grad_var_shape(program, dist_context): ...@@ -1004,15 +1006,24 @@ def set_grad_var_shape(program, dist_context):
forward_var_name = var_name[:var_name.find("@GRAD")] forward_var_name = var_name[:var_name.find("@GRAD")]
if op.type == "c_allreduce_sum" or op.type == "c_identity" or op.type == "scale": if op.type == "c_allreduce_sum" or op.type == "c_identity" or op.type == "scale":
forward_var_name = op.input_arg_names[0] forward_var_name = op.input_arg_names[0]
elif op.type == "matmul_v2_grad":
forward_var_name = None
for output_name in op.output_names:
if var_name in op.output(output_name):
assert "@GRAD" in output_name
input_name = output_name[:output_name.find("@GRAD")]
assert len(op.input(input_name)) == 1
forward_var_name = op.input(input_name)[0]
assert forward_var_name is not None
need_set_shape_list = [ need_set_shape_list = [
"reshape2_grad", "softmax_with_cross_entropy_grad", "reshape2_grad", "softmax_with_cross_entropy_grad",
"transpose2_grad", "softmax_grad", "cross_entropy_grad2", "transpose2_grad", "softmax_grad", "cross_entropy_grad2",
"dropout_grad", "unsqueeze2_grad" "dropout_grad"
] ]
forward_list = [ forward_list = [
"reshape2", "softmax_with_cross_entropy", "transpose2", "reshape2", "softmax_with_cross_entropy", "transpose2",
"softmax", "cross_entropy2", "dropout", "unsqueeze2" "softmax", "cross_entropy2", "dropout"
] ]
if op.type in need_set_shape_list: if op.type in need_set_shape_list:
for forward_op in block.ops: for forward_op in block.ops:
...@@ -1041,6 +1052,23 @@ def set_grad_var_shape(program, dist_context): ...@@ -1041,6 +1052,23 @@ def set_grad_var_shape(program, dist_context):
grad_var.desc.set_shape(ref_shape) grad_var.desc.set_shape(ref_shape)
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
OpRole = core.op_proto_and_checker_maker.OpRole
def is_forward_op(op):
ref_role1 = int(core.op_proto_and_checker_maker.OpRole.Forward)
ref_role2 = int(core.op_proto_and_checker_maker.OpRole.Loss)
op_role = int(op.attr('op_role'))
return OP_ROLE_KEY in op.attr_names and (op_role == ref_role1 or
op_role == ref_role2)
def is_backward_op(op):
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward)
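is_backward_op tests the role with a bitwise AND rather than equality because some ops around the loss carry a combined role; a minimal check of that flag arithmetic, standalone and using only the core enum referenced above:

    from paddle.fluid import core

    OpRole = core.op_proto_and_checker_maker.OpRole
    combined = int(OpRole.Backward) | int(OpRole.Loss)   # e.g. the op filling loss@GRAD
    assert combined & int(OpRole.Backward)               # caught by the bitmask test
    assert combined != int(OpRole.Backward)              # an equality test would miss it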
def update_op_dims_mapping_by_default_dist_impl(dist_op): def update_op_dims_mapping_by_default_dist_impl(dist_op):
changed = False changed = False
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
...@@ -1177,57 +1205,25 @@ def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op): ...@@ -1177,57 +1205,25 @@ def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op):
return changed return changed
def get_all_distributed_main_program(serial_program_info, dist_context): def get_all_distributed_main_program(serial_program_info, dist_context,
parallelizer):
"Get all distributed main programs by dist_context." "Get all distributed main programs by dist_context."
from .dist_context import DistributedOperatorContext from .dist_context import DistributedOperatorContext, DistributedContext
cluster = serial_program_info.cluster cluster = serial_program_info.cluster
copied_parallelizer = copy.deepcopy(parallelizer)
all_dist_main_program = [] all_dist_main_program = []
ranks = paddle.distributed.get_world_size() if cluster is None else len( ranks = paddle.distributed.get_world_size() if cluster is None else len(
cluster.get_all_devices("GPU")) cluster.get_all_devices("GPU"))
for rank_id in range(ranks): for rank_id in range(ranks):
used_dist_context = copy.deepcopy(dist_context) used_dist_context = copy.deepcopy(dist_context)
used_dist_context._dist_op_context = DistributedOperatorContext() used_dist_context._dist_op_context = DistributedOperatorContext()
dist_main_program, dist_startup_program = get_specified_distributed_main_program( _, _, dist_startup_program, dist_main_program, _ = copied_parallelizer._get_dist_program(
serial_program_info, used_dist_context, rank_id) rank_id, used_dist_context)
all_dist_main_program.append(dist_main_program) all_dist_main_program.append(dist_main_program)
return all_dist_main_program return all_dist_main_program
def get_specified_distributed_main_program(serial_program_info, dist_context,
rank_id):
"Get distributed main program by the given dist_context and rank_id."
from .partitioner import Partitioner
from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
from .process_group import _g_process_group_map, ProcessGroup
dist_strategy = paddle.distributed.fleet.DistributedStrategy()
train_program = serial_program_info.train_program
startup_program = serial_program_info.startup_program
loss = serial_program_info.loss
optimizer = serial_program_info.optimizer
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
dist_main_program, dist_startup_program = partitioner.transpile_forward(
train_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, train_program, startup_program, dist_main_program,
dist_startup_program)
opt_ops = partitioner.apply_optimize(
copy.deepcopy(optimizer), dist_params_grads, dist_main_program,
dist_startup_program)
set_grad_var_shape(dist_main_program, dist_context)
make_data_unshard(dist_main_program, dist_startup_program, dist_context)
reshard(dist_main_program, dist_startup_program, rank_id, dist_context)
HAS_SENT.clear()
HAS_RECV.clear()
HAS_ALLGATHER.clear()
_g_process_group_map.clear()
_g_process_group_map[0] = ProcessGroup(0, [])
return dist_main_program, dist_startup_program
class SerialProgramInfo: class SerialProgramInfo:
def __init__(self, def __init__(self,
train_program, train_program,
...@@ -1286,7 +1282,6 @@ def get_standalone_cost_data(distributed_programs): ...@@ -1286,7 +1282,6 @@ def get_standalone_cost_data(distributed_programs):
shape = list(map(lambda x: int(x.strip()), shape)) shape = list(map(lambda x: int(x.strip()), shape))
dtype_factor = 1 dtype_factor = 1
total_static_input_size += reduce(lambda x, y: x * y, shape) total_static_input_size += reduce(lambda x, y: x * y, shape)
# print(arg_name_lower)
if op.type == "c_embedding": if op.type == "c_embedding":
arg_name_lower = "w" if arg_name_lower == "weight" else "ids" arg_name_lower = "w" if arg_name_lower == "weight" else "ids"
for arg_name in op.input_names: for arg_name in op.input_names:
...@@ -1301,7 +1296,8 @@ def get_standalone_cost_data(distributed_programs): ...@@ -1301,7 +1296,8 @@ def get_standalone_cost_data(distributed_programs):
actual_runtime = total_actual_input_size / total_static_input_size * runtime actual_runtime = total_actual_input_size / total_static_input_size * runtime
return actual_runtime return actual_runtime
cost_model = paddle.cost_model.CostModel() import paddle.cost_model as cm
cost_model = cm.CostModel()
cost_model.static_cost_data() cost_model.static_cost_data()
DEFAULT_MULTIPLE = 2 DEFAULT_MULTIPLE = 2
OP_NAME_MAPPING = { OP_NAME_MAPPING = {
......
...@@ -26,7 +26,7 @@ import paddle.distributed.auto_parallel as auto ...@@ -26,7 +26,7 @@ import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.completion import complete_backward_annotation from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.cost_model import estimate_cost from paddle.distributed.auto_parallel.cost_model import estimate_cost
import paddle.fluid.core as core import paddle.fluid.core as core
...@@ -148,22 +148,33 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): ...@@ -148,22 +148,33 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
loss, train_program, startup_program = mlp_forward(train_program, loss, train_program, startup_program = mlp_forward(train_program,
startup_program) startup_program)
dist_strategy = fleet.DistributedStrategy() fleet._user_defined_strategy = fleet.DistributedStrategy()
fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
parallelizer = AutoParallelizer(fleet)
parallelizer._dist_context = dist_context
# auto completion # serial forward & backward completion
complete_train_program = auto.complete_annotation(train_program, complete_train_program = auto.complete_annotation(train_program,
dist_context) dist_context)
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
loss,
parameter_list=None,
no_grad_set=None,
callbacks=None)
# logical partition # logical partition
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward( partitioner = Partitioner(dist_context, rank_id)
complete_train_program, startup_program) auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
dist_params_grads = partitioner.apply_backward( complete_train_program, startup_program, params_grads)
loss, complete_train_program, startup_program, auto_parallel_main_prog,
auto_parallel_startup_prog) partitioned_optimize_ops = parallelizer._apply_optimize(
optimizer = paddle.fluid.optimizer.AdamOptimizer() auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog,
auto_parallel_startup_prog)
return auto_parallel_main_prog, auto_parallel_startup_prog return auto_parallel_main_prog, auto_parallel_startup_prog
......
...@@ -36,6 +36,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer ...@@ -36,6 +36,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from paddle.distributed import fleet from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import reshard
...@@ -469,21 +470,31 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): ...@@ -469,21 +470,31 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
loss, train_program, startup_program = mlp_forward(train_program, loss, train_program, startup_program = mlp_forward(train_program,
startup_program) startup_program)
dist_strategy = fleet.DistributedStrategy() fleet._user_defined_strategy = fleet.DistributedStrategy()
fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
parallelizer = AutoParallelizer(fleet)
parallelizer._dist_context = dist_context
# auto completion # auto completion
complete_train_program = auto.complete_annotation(train_program, complete_train_program = auto.complete_annotation(train_program,
dist_context) dist_context)
partitioner = Partitioner(dist_strategy, dist_context, rank_id) parallelizer._apply_serial_forward_pass(complete_train_program,
# logical partition startup_program)
dist_train_program, dist_startup_prog = partitioner.transpile_forward( params_grads = parallelizer._generate_backward(
complete_train_program, startup_program) complete_train_program,
dist_params_grads = partitioner.apply_backward( startup_program,
loss, complete_train_program, startup_program, dist_train_program, loss,
dist_startup_prog) parameter_list=None,
optimizer = paddle.fluid.optimizer.AdamOptimizer() no_grad_set=None,
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, callbacks=None)
dist_train_program, dist_startup_prog)
partitioner = Partitioner(dist_context, rank_id)
dist_train_program, dist_startup_prog, dist_params_grads = partitioner.partition(
complete_train_program, startup_program, params_grads)
partitioned_optimize_ops = parallelizer._apply_optimize(
dist_train_program, dist_startup_prog, dist_params_grads)
reshard(dist_train_program, dist_startup_prog, rank_id, dist_context) reshard(dist_train_program, dist_startup_prog, rank_id, dist_context)
return dist_train_program, dist_startup_prog return dist_train_program, dist_startup_prog
......
...@@ -54,9 +54,9 @@ def get_programs(annotated_func): ...@@ -54,9 +54,9 @@ def get_programs(annotated_func):
rank_id = 3 rank_id = 3
dist_strategy = fleet.DistributedStrategy() dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_strategy, dist_context, rank_id) partitioner = Partitioner(dist_context, rank_id)
test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog = partitioner.transpile_forward( test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog, _ = partitioner.partition(
complete_train_program, start_program) complete_train_program, start_program, [])
return complete_train_program, start_program, test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog, dist_context return complete_train_program, start_program, test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog, dist_context
......
...@@ -35,6 +35,7 @@ from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_pr ...@@ -35,6 +35,7 @@ from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_pr
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.utils import _get_comm_group from paddle.distributed.auto_parallel.utils import _get_comm_group
from paddle.distributed.auto_parallel.process_group import new_process_group from paddle.distributed.auto_parallel.process_group import new_process_group
...@@ -790,9 +791,9 @@ class GPTPretrainingCriterion(nn.Layer): ...@@ -790,9 +791,9 @@ class GPTPretrainingCriterion(nn.Layer):
return loss return loss
def gpt_pretrain_forward(train_program, start_program): def gpt_pretrain_forward(train_program, startup_program):
with static.program_guard(train_program, with static.program_guard(train_program,
start_program), utils.unique_name.guard(): startup_program), utils.unique_name.guard():
batch_size = 16 batch_size = 16
sequence_len = 512 sequence_len = 512
input_ids = static.data( input_ids = static.data(
...@@ -848,7 +849,19 @@ def gpt_pretrain_forward(train_program, start_program): ...@@ -848,7 +849,19 @@ def gpt_pretrain_forward(train_program, start_program):
loss = criterion(preds, labels, loss_mask) loss = criterion(preds, labels, loss_mask)
return train_program, start_program, loss return train_program, startup_program, loss
class FakeStrategy(object):
def __init__(self):
self.amp = False
self.recompute = False
class FakeFleet(object):
def __init__(self):
self.user_defined_optimizer = None
self._user_defined_strategy = FakeStrategy()
class TestGPTPartitioner(unittest.TestCase): class TestGPTPartitioner(unittest.TestCase):
...@@ -861,38 +874,41 @@ class TestGPTPartitioner(unittest.TestCase): ...@@ -861,38 +874,41 @@ class TestGPTPartitioner(unittest.TestCase):
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
train_program = static.Program() train_program = static.Program()
start_program = static.Program() startup_program = static.Program()
dist_context = DistributedContext() parallelizer = AutoParallelizer(FakeFleet())
dist_context = parallelizer._dist_context
dist_context.process_mesh = _global_process_mesh dist_context.process_mesh = _global_process_mesh
train_program, start_program, loss = gpt_pretrain_forward(train_program, train_program, startup_program, loss = gpt_pretrain_forward(
start_program) train_program, startup_program)
complete_train_program = auto.complete_annotation(train_program, complete_train_program = auto.complete_annotation(train_program,
dist_context) dist_context)
# serial forward pass
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
# serial backward pass
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
loss,
parameter_list=None,
no_grad_set=None,
callbacks=None)
rank_id = 3 rank_id = 3
dist_strategy = fleet.DistributedStrategy() partitioner = Partitioner(dist_context, rank_id)
partitioner = Partitioner(dist_strategy, dist_context, rank_id) auto_parallel_main_prog, auto_parallel_startup_prog, params_grads = partitioner.partition(
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward( complete_train_program, startup_program, params_grads)
complete_train_program, start_program)
dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, start_program,
auto_parallel_main_prog, auto_parallel_startup_prog)
with open("./test_auto_parallel_partitioner_serial_main_new.txt", with open("./test_auto_parallel_partitioner_serial_main_new.txt",
"w") as fw: "w") as fw:
fw.write(str(train_program)) fw.write(str(train_program))
with open("./test_auto_parallel_partitioner_serial_startup_new.txt", with open("./test_auto_parallel_partitioner_serial_startup_new.txt",
"w") as fw: "w") as fw:
fw.write(str(start_program)) fw.write(str(startup_program))
optimizer = paddle.fluid.optimizer.AdamOptimizer(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog,
auto_parallel_startup_prog)
from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
set_default_distributed_context(dist_context) set_default_distributed_context(dist_context)
with open("./test_auto_parallel_partitioner_main_new.txt1", "w") as fw: with open("./test_auto_parallel_partitioner_main_new.txt1", "w") as fw:
...@@ -927,7 +943,7 @@ class TestGPTPartitioner(unittest.TestCase): ...@@ -927,7 +943,7 @@ class TestGPTPartitioner(unittest.TestCase):
complete_train_program, weights, 0, 1)) complete_train_program, weights, 0, 1))
all_params = sorted( all_params = sorted(
[param.name for param in start_program.all_parameters()]) [param.name for param in startup_program.all_parameters()])
allreduce_grads = [ allreduce_grads = [
'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2',
'layer_norm_6.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_6.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_7.tmp_2',
......
@@ -24,6 +24,7 @@ import paddle.utils as utils
 import paddle.distributed.auto_parallel as auto
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
+from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.partitioner import Partitioner
 from paddle.distributed.auto_parallel.reshard import reshard
 from paddle.distributed.auto_parallel.process_group import _g_process_group_map
@@ -145,22 +146,34 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     loss, train_program, startup_program = mlp_forward(train_program,
                                                        startup_program)
-    # auto completion
+    fleet._user_defined_strategy = fleet.DistributedStrategy()
+    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
+    parallelizer = AutoParallelizer(fleet)
+    parallelizer._dist_context = dist_context
+
+    # serial forward & backward completion
     complete_train_program = auto.complete_annotation(train_program,
                                                       dist_context)
-    dist_strategy = fleet.DistributedStrategy()
-    partitioner = Partitioner(dist_strategy, dist_context, rank_id)
+    parallelizer._apply_serial_forward_pass(complete_train_program,
+                                            startup_program)
+    params_grads = parallelizer._generate_backward(
+        complete_train_program,
+        startup_program,
+        loss,
+        parameter_list=None,
+        no_grad_set=None,
+        callbacks=None)
+
     # logical partition
-    auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
-        complete_train_program, startup_program)
-    dist_params_grads = partitioner.apply_backward(
-        loss, complete_train_program, startup_program, auto_parallel_main_prog,
-        auto_parallel_startup_prog)
-    optimizer = paddle.fluid.optimizer.AdamOptimizer()
-    opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
-                                         auto_parallel_main_prog,
-                                         auto_parallel_startup_prog)
+    partitioner = Partitioner(dist_context, rank_id)
+    auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
+        complete_train_program, startup_program, params_grads)
+
+    partitioned_optimize_ops = parallelizer._apply_optimize(
+        auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
+
     return auto_parallel_main_prog, auto_parallel_startup_prog
......
@@ -24,6 +24,7 @@ import paddle.utils as utils
 import paddle.distributed.auto_parallel as auto
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
+from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.partitioner import Partitioner
 from paddle.distributed.auto_parallel.reshard import reshard
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
@@ -109,22 +110,34 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     loss, train_program, startup_program = mlp_forward(train_program,
                                                        startup_program)
-    # auto completion
+    fleet._user_defined_strategy = fleet.DistributedStrategy()
+    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
+    parallelizer = AutoParallelizer(fleet)
+    parallelizer._dist_context = dist_context
+
+    # serial forward & backward completion
     complete_train_program = auto.complete_annotation(train_program,
                                                       dist_context)
-    dist_strategy = fleet.DistributedStrategy()
-    partitioner = Partitioner(dist_strategy, dist_context, rank_id)
+    parallelizer._apply_serial_forward_pass(complete_train_program,
+                                            startup_program)
+    params_grads = parallelizer._generate_backward(
+        complete_train_program,
+        startup_program,
+        loss,
+        parameter_list=None,
+        no_grad_set=None,
+        callbacks=None)
+
     # logical partition
-    auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
-        complete_train_program, startup_program)
-    dist_params_grads = partitioner.apply_backward(
-        loss, complete_train_program, startup_program, auto_parallel_main_prog,
-        auto_parallel_startup_prog)
-    optimizer = paddle.fluid.optimizer.AdamOptimizer()
-    opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
-                                         auto_parallel_main_prog,
-                                         auto_parallel_startup_prog)
+    partitioner = Partitioner(dist_context, rank_id)
+    auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
+        complete_train_program, startup_program, params_grads)
+
+    partitioned_optimize_ops = parallelizer._apply_optimize(
+        auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
+
     return auto_parallel_main_prog, auto_parallel_startup_prog
......
@@ -24,6 +24,7 @@ import paddle.utils as utils
 import paddle.distributed.auto_parallel as auto
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
+from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.partitioner import Partitioner
 from paddle.distributed.auto_parallel.reshard import reshard
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
@@ -125,22 +126,32 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     loss, train_program, startup_program = mlp_forward(train_program,
                                                        startup_program)
-    # auto completion
+    fleet._user_defined_strategy = fleet.DistributedStrategy()
+    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
+    parallelizer = AutoParallelizer(fleet)
+    parallelizer._dist_context = dist_context
+
+    # serial forward & backward completion
     complete_train_program = auto.complete_annotation(train_program,
                                                       dist_context)
+    parallelizer._apply_serial_forward_pass(complete_train_program,
+                                            startup_program)
+    params_grads = parallelizer._generate_backward(
+        complete_train_program,
+        startup_program,
+        loss,
+        parameter_list=None,
+        no_grad_set=None,
+        callbacks=None)
-    dist_strategy = fleet.DistributedStrategy()
-    partitioner = Partitioner(dist_strategy, dist_context, rank_id)
     # logical partition
-    auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
-        complete_train_program, startup_program)
-    dist_params_grads = partitioner.apply_backward(
-        loss, complete_train_program, startup_program, auto_parallel_main_prog,
-        auto_parallel_startup_prog)
-    optimizer = paddle.fluid.optimizer.AdamOptimizer()
-    opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
-                                         auto_parallel_main_prog,
-                                         auto_parallel_startup_prog)
+    partitioner = Partitioner(dist_context, rank_id)
+    auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
+        complete_train_program, startup_program, params_grads)
+    partitioned_optimize_ops = parallelizer._apply_optimize(
+        auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
     return auto_parallel_main_prog, auto_parallel_startup_prog
@@ -253,14 +264,15 @@ class TestMLPReshard(unittest.TestCase):
         rank_id = 0
         dist_context = DistributedContext()
         dist_strategy = fleet.DistributedStrategy()
-        partitioner = Partitioner(dist_strategy, dist_context, rank_id)
+        partitioner = Partitioner(dist_context, rank_id)
         complete_train_program = auto.complete_annotation(train_program,
                                                           dist_context)
-        auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
-            complete_train_program, startup_program)
-        reshard(auto_parallel_main_prog, startup_program, rank_id, dist_context)
+        partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
+            complete_train_program, startup_program, [])
+        reshard(partitioned_main_prog, partitioned_startup_prog, rank_id,
+                dist_context)
         # the x should not be slice
-        self.assertTrue(check_allgather(auto_parallel_main_prog))
+        self.assertTrue(check_allgather(partitioned_main_prog))


 if __name__ == "__main__":
......
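The reshard hunk above shows the forward-only side of the new signature: partition() always returns three values, so the test passes an empty params_grads list, and reshard() now consumes the partitioned startup program instead of the serial one. A minimal sketch under the same assumptions (train_program, startup_program, dist_context, and rank_id prepared as in the test; check_allgather is the helper defined in that test file):

# Sketch only -- forward-only partition followed by reshard, as in TestMLPReshard.
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard

complete_train_program = auto.complete_annotation(train_program, dist_context)
partitioner = Partitioner(dist_context, rank_id)
partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
    complete_train_program, startup_program, [])  # empty params_grads: no backward pass generated
reshard(partitioned_main_prog, partitioned_startup_prog, rank_id, dist_context)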