未验证 提交 6f3c9643 编写于 作者: J JZ-LIANG 提交者: GitHub

Eb118 BF16 Adoption (#52827)

* pr1

* pr2

* pr3

* fixed unitest

* adopt for scale
上级 8cbc75ca
...@@ -62,6 +62,8 @@ set_field_default_config(RECOMPUTE, "enable_tuning", False) ...@@ -62,6 +62,8 @@ set_field_default_config(RECOMPUTE, "enable_tuning", False)
######################################### #########################################
AMP = "amp" AMP = "amp"
set_field_default_config(AMP, "enable", False) set_field_default_config(AMP, "enable", False)
set_field_default_config(AMP, "dtype", "float16")
set_field_default_config(AMP, "level", "o1")
set_field_default_config(AMP, "init_loss_scaling", 32768.0) set_field_default_config(AMP, "init_loss_scaling", 32768.0)
set_field_default_config(AMP, "incr_every_n_steps", 1000) set_field_default_config(AMP, "incr_every_n_steps", 1000)
set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2) set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2)
...@@ -71,8 +73,8 @@ set_field_default_config(AMP, "use_dynamic_loss_scaling", True) ...@@ -71,8 +73,8 @@ set_field_default_config(AMP, "use_dynamic_loss_scaling", True)
set_field_default_config(AMP, "custom_white_list", []) set_field_default_config(AMP, "custom_white_list", [])
set_field_default_config(AMP, "custom_black_list", []) set_field_default_config(AMP, "custom_black_list", [])
set_field_default_config(AMP, "custom_black_varnames", []) set_field_default_config(AMP, "custom_black_varnames", [])
set_field_default_config(AMP, "use_pure_fp16", False) set_field_default_config(AMP, "use_fp16_guard", False)
set_field_default_config(AMP, "use_fp16_guard", True) set_field_default_config(AMP, "use_bf16_guard", False)
set_field_default_config(AMP, "use_optimizer_fp16", False) set_field_default_config(AMP, "use_optimizer_fp16", False)
######################################### #########################################
......
...@@ -459,7 +459,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -459,7 +459,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
check_variable_and_dtype( check_variable_and_dtype(
Out_var, Out_var,
'tensor', 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'c_allreduce_sum', 'c_allreduce_sum',
) )
...@@ -649,7 +649,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -649,7 +649,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
check_variable_and_dtype( check_variable_and_dtype(
Out_grad, Out_grad,
'tensor', 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity', '_c_identity',
) )
...@@ -691,12 +691,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -691,12 +691,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
}, },
) )
check_variable_and_dtype( check_variable_and_dtype(
intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear' intermediate_var_0,
'x',
['float16', 'float32', 'float64', 'uint16'],
'linear',
) )
check_dtype( check_dtype(
intermediate_var_0.dtype, intermediate_var_0.dtype,
'dtype', 'dtype',
['float16', 'float32', 'float64'], ['float16', 'float32', 'float64', 'uint16'],
'linear', 'linear',
) )
......
...@@ -20,7 +20,11 @@ from .common import DistributedOperatorImpl ...@@ -20,7 +20,11 @@ from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from .common import gradient_synchronization from .common import gradient_synchronization
from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related from .common import (
set_comm_op_dist_attr_for_program,
naive_copy_op_dist_attr_for_program,
is_parameter_related,
)
from ..utils import is_dim_shard from ..utils import is_dim_shard
from ..utils import is_dim_replicate from ..utils import is_dim_replicate
from ..utils import is_valid_list_index from ..utils import is_valid_list_index
...@@ -33,24 +37,39 @@ from paddle.fluid import core, unique_name ...@@ -33,24 +37,39 @@ from paddle.fluid import core, unique_name
from paddle.fluid.framework import _non_static_mode from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from paddle.distributed.fleet.meta_optimizers.common import (
OpRole,
OP_ROLE_KEY,
OP_ROLE_VAR_KEY,
)
from ..process_group import new_process_group from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank from ..utils import _get_comm_group, _get_corresponding_rank
from .dist_default import DistributedDefaultImpl0 from .dist_default import DistributedDefaultImpl0
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op, build_dp_costs from ..cost import (
build_comp_desc_from_dist_op,
build_comm_desc_from_dist_op,
build_dp_costs,
)
from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs
from ..cost import MatmulV2OpCost, MatmulOpCost, MulOpCost from ..cost import MatmulV2OpCost, MatmulOpCost, MulOpCost
from ..cost import MatmulV2GradOpCost, MatmulGradOpCost, MulGradOpCost from ..cost import MatmulV2GradOpCost, MatmulGradOpCost, MulGradOpCost
from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost, IdentityOpCost from paddle.distributed.auto_parallel.cost.comm_op_cost import (
AllreduceSumOpCost,
IdentityOpCost,
)
def trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping): def trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping):
if trans_x: if trans_x:
x_dims_mapping[-1], x_dims_mapping[-2] = x_dims_mapping[ x_dims_mapping[-1], x_dims_mapping[-2] = (
-2], x_dims_mapping[-1] x_dims_mapping[-2],
x_dims_mapping[-1],
)
if trans_y: if trans_y:
y_dims_mapping[-1], y_dims_mapping[-2] = y_dims_mapping[ y_dims_mapping[-1], y_dims_mapping[-2] = (
-2], y_dims_mapping[-1] y_dims_mapping[-2],
y_dims_mapping[-1],
)
def copy_op_with_new_input_output(ctx, block, src_op, **kwargs): def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
...@@ -123,13 +142,17 @@ def _update_dims_mapping_for_matmul(dist_op): ...@@ -123,13 +142,17 @@ def _update_dims_mapping_for_matmul(dist_op):
for i in range(new_out_dims_mapping_len - 2): for i in range(new_out_dims_mapping_len - 2):
broadcast_out_dims_mapping.append(out_dims_mapping[i]) broadcast_out_dims_mapping.append(out_dims_mapping[i])
compatible_dims_mapping = compute_compatible_dims_mapping([ compatible_dims_mapping = compute_compatible_dims_mapping(
broadcast_x_dims_mapping, broadcast_y_dims_mapping, [
broadcast_out_dims_mapping broadcast_x_dims_mapping,
]) broadcast_y_dims_mapping,
broadcast_out_dims_mapping,
]
)
if compatible_dims_mapping is None: if compatible_dims_mapping is None:
trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, trans_x_y_dims_mapping(
y_dims_mapping) trans_x, trans_y, x_dims_mapping, y_dims_mapping
)
return False return False
for i in range(new_x_dims_mapping_len - 2): for i in range(new_x_dims_mapping_len - 2):
...@@ -152,17 +175,20 @@ def _update_dims_mapping_for_matmul(dist_op): ...@@ -152,17 +175,20 @@ def _update_dims_mapping_for_matmul(dist_op):
# The following which uses negative index can be work # The following which uses negative index can be work
# when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2 # when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2
dim_changed = compute_compatible_and_update_dim_mapping( dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, y_dims_mapping], [-1, -2]) [x_dims_mapping, y_dims_mapping], [-1, -2]
)
if dim_changed: if dim_changed:
changed = True changed = True
dim_changed = compute_compatible_and_update_dim_mapping( dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [-2, -2]) [x_dims_mapping, out_dims_mapping], [-2, -2]
)
if dim_changed: if dim_changed:
changed = True changed = True
dim_changed = compute_compatible_and_update_dim_mapping( dim_changed = compute_compatible_and_update_dim_mapping(
[y_dims_mapping, out_dims_mapping], [-1, -1]) [y_dims_mapping, out_dims_mapping], [-1, -1]
)
if dim_changed: if dim_changed:
changed = True changed = True
...@@ -202,7 +228,8 @@ def _is_auto_compatible_for_matmul(dist_op): ...@@ -202,7 +228,8 @@ def _is_auto_compatible_for_matmul(dist_op):
x_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(x_name)) x_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(x_name))
y_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(y_name)) y_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(y_name))
out_dims_mapping = copy.deepcopy( out_dims_mapping = copy.deepcopy(
op_dist_attr.get_output_dims_mapping(out_name)) op_dist_attr.get_output_dims_mapping(out_name)
)
x_dims_mapping_len = len(x_dims_mapping) x_dims_mapping_len = len(x_dims_mapping)
y_dims_mapping_len = len(y_dims_mapping) y_dims_mapping_len = len(y_dims_mapping)
out_dims_mapping_len = len(out_dims_mapping) out_dims_mapping_len = len(out_dims_mapping)
...@@ -234,22 +261,23 @@ def _is_auto_compatible_for_matmul(dist_op): ...@@ -234,22 +261,23 @@ def _is_auto_compatible_for_matmul(dist_op):
for i in range(out_dims_mapping_len - 2): for i in range(out_dims_mapping_len - 2):
broadcast_out_dims_mapping.append(out_dims_mapping[i]) broadcast_out_dims_mapping.append(out_dims_mapping[i])
is_same = ((broadcast_x_dims_mapping == broadcast_y_dims_mapping) is_same = (broadcast_x_dims_mapping == broadcast_y_dims_mapping) and (
and (broadcast_x_dims_mapping == broadcast_out_dims_mapping)) broadcast_x_dims_mapping == broadcast_out_dims_mapping
)
if not is_same: if not is_same:
return False return False
# The following which uses negative index can be work # The following which uses negative index can be work
# when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2 # when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2
is_same = (x_dims_mapping[-1] == y_dims_mapping[-2]) is_same = x_dims_mapping[-1] == y_dims_mapping[-2]
if not is_same: if not is_same:
return False return False
is_same = (x_dims_mapping[-2] == out_dims_mapping[-2]) is_same = x_dims_mapping[-2] == out_dims_mapping[-2]
if not is_same: if not is_same:
return False return False
is_same = (y_dims_mapping[-1] == out_dims_mapping[-1]) is_same = y_dims_mapping[-1] == out_dims_mapping[-1]
if not is_same: if not is_same:
return False return False
...@@ -265,8 +293,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -265,8 +293,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
backward_op = dist_op_context.cur_src_op backward_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id rank_id = dist_op_context.rank_id
dist_attr = ctx.get_op_dist_attr_for_program(backward_op) dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert (
str(backward_op)) dist_attr is not None
), "backward op [{}] don't have dist attribute !".format(str(backward_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in dist_attr.process_mesh.processes: if rank_id not in dist_attr.process_mesh.processes:
...@@ -277,22 +306,26 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -277,22 +306,26 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD') assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD')
assert 'Y@GRAD' in kwargs, "output [{}] is not given".format('Y@GRAD') assert 'Y@GRAD' in kwargs, "output [{}] is not given".format('Y@GRAD')
assert 'X@GRAD' in kwargs, "output [{}] is not given".format('X@GRAD') assert 'X@GRAD' in kwargs, "output [{}] is not given".format('X@GRAD')
assert len( assert (
len(kwargs['Y']) == 1
), "row_parallel_embedding input Ids take 1 variable but got {}".format(
kwargs['Y'] kwargs['Y']
) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( )
kwargs['Y']) assert (
assert len( len(kwargs['X']) == 1
), "row_parallel_embedding input Ids take 1 variable but got {}".format(
kwargs['X'] kwargs['X']
) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( )
kwargs['X']) assert (
assert len( len(kwargs['Out@GRAD']) == 1
kwargs['Out@GRAD'] ), "row_parallel_embedding input Ids take 1 variable but got {}".format(
) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( kwargs['Out']
kwargs['Out']) )
assert len( assert (
len(kwargs['Y@GRAD']) == 1
), "row_parallel_embedding output Ids take 1 variable but got {}".format(
kwargs['Y@GRAD'] kwargs['Y@GRAD']
) == 1, "row_parallel_embedding output Ids take 1 variable but got {}".format( )
kwargs['Y@GRAD'])
X_var = main_block.var(kwargs['X'][0]) X_var = main_block.var(kwargs['X'][0])
Y_var = main_block._var_recursive(kwargs['Y'][0]) Y_var = main_block._var_recursive(kwargs['Y'][0])
...@@ -302,7 +335,8 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -302,7 +335,8 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
assert not is_parameter_related( assert not is_parameter_related(
X_var.name, main_block X_var.name, main_block
), "left operand(X) [{}] of dist matmul should not be parameter".format( ), "left operand(X) [{}] of dist matmul should not be parameter".format(
X_var.name) X_var.name
)
X_var_dims_mapping = dist_attr.get_input_dims_mapping(X_var.name) X_var_dims_mapping = dist_attr.get_input_dims_mapping(X_var.name)
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name) Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name)
...@@ -339,28 +373,34 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -339,28 +373,34 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
parallel_axis = Y_var_dim_mapping[0] parallel_axis = Y_var_dim_mapping[0]
check_variable_and_dtype( check_variable_and_dtype(
Out_grad, 'tensor', Out_grad,
['float16', 'float32', 'float64', 'int32', 'int64'], 'tensor',
'_c_identity') ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)
intermediate_var_0 = main_block.create_var( intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join( name=unique_name.generate_with_ignorable_key(
["c_identity", 'tmp'])) + "@GRAD", ".".join(["c_identity", 'tmp'])
)
+ "@GRAD",
dtype=Out_grad.dtype, dtype=Out_grad.dtype,
shape=Out_grad.shape, shape=Out_grad.shape,
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False, persistable=False,
stop_gradient=Out_grad.stop_gradient) stop_gradient=Out_grad.stop_gradient,
)
# copy X_var's dist_attr to intermediate_var_0's dist_attr # copy X_var's dist_attr to intermediate_var_0's dist_attr
out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name) out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name)
assert out_grad_dist_attr is not None assert out_grad_dist_attr is not None
ctx.set_tensor_dist_attr_for_program(intermediate_var_0, ctx.set_tensor_dist_attr_for_program(
out_grad_dist_attr) intermediate_var_0, out_grad_dist_attr
)
group_ranks = _get_comm_group(process_mesh_group, group_ranks = _get_comm_group(
process_mesh_shape, parallel_axis, process_mesh_group, process_mesh_shape, parallel_axis, rank_id
rank_id) )
group = new_process_group(group_ranks) group = new_process_group(group_ranks)
c_identity_op = main_block.append_op( c_identity_op = main_block.append_op(
type='c_identity', type='c_identity',
...@@ -371,20 +411,29 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -371,20 +411,29 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
'use_calc_stream': True, 'use_calc_stream': True,
'use_model_parallel': True, 'use_model_parallel': True,
OP_ROLE_KEY: OpRole.Backward, OP_ROLE_KEY: OpRole.Backward,
}) },
check_variable_and_dtype(intermediate_var_0, 'x', )
['float16', 'float32', 'float64'], check_variable_and_dtype(
'linear') intermediate_var_0,
check_dtype(intermediate_var_0.dtype, 'dtype', 'x',
['float16', 'float32', 'float64'], 'linear') ['float16', 'float32', 'float64', 'uint16'],
set_comm_op_dist_attr_for_program(c_identity_op, 'linear',
dist_attr.process_mesh, )
out_grad_dist_attr, ctx) check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
set_comm_op_dist_attr_for_program(
c_identity_op, dist_attr.process_mesh, out_grad_dist_attr, ctx
)
new_kwargs = copy.deepcopy(kwargs) new_kwargs = copy.deepcopy(kwargs)
new_kwargs['Out@GRAD'] = [intermediate_var_0.name] new_kwargs['Out@GRAD'] = [intermediate_var_0.name]
matmul_op_desc = copy_op_with_new_input_output( matmul_op_desc = copy_op_with_new_input_output(
ctx, main_block, backward_op, **new_kwargs) ctx, main_block, backward_op, **new_kwargs
)
else: else:
# col parallel: matmul + allreduce # col parallel: matmul + allreduce
assert Y_var_dim_mapping[0] < 0 assert Y_var_dim_mapping[0] < 0
...@@ -397,28 +446,36 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -397,28 +446,36 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
assert len(kwargs['X@GRAD']) == 1 assert len(kwargs['X@GRAD']) == 1
X_grad = main_block.var(kwargs['X@GRAD'][0]) X_grad = main_block.var(kwargs['X@GRAD'][0])
intermediate_var_0 = main_block.create_var( intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join( name=unique_name.generate_with_ignorable_key(
["c_identity", 'tmp'])) + "@GRAD", ".".join(["c_identity", 'tmp'])
)
+ "@GRAD",
dtype=X_grad.dtype, dtype=X_grad.dtype,
shape=X_grad.shape, shape=X_grad.shape,
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False, persistable=False,
stop_gradient=X_grad.stop_gradient) stop_gradient=X_grad.stop_gradient,
)
X_grad_dist_attr = dist_attr.get_output_dist_attr(X_grad.name) X_grad_dist_attr = dist_attr.get_output_dist_attr(X_grad.name)
assert X_grad_dist_attr is not None assert X_grad_dist_attr is not None
ctx.set_tensor_dist_attr_for_program(intermediate_var_0, ctx.set_tensor_dist_attr_for_program(
X_grad_dist_attr) intermediate_var_0, X_grad_dist_attr
)
new_kwargs['X@GRAD'] = [intermediate_var_0.name] new_kwargs['X@GRAD'] = [intermediate_var_0.name]
matmul_op_desc = copy_op_with_new_input_output( matmul_op_desc = copy_op_with_new_input_output(
ctx, main_block, backward_op, **new_kwargs) ctx, main_block, backward_op, **new_kwargs
)
# NOTE (JZ-LIANG) trick to skip one allreduce if left operand has not grad # NOTE (JZ-LIANG) trick to skip one allreduce if left operand has not grad
if has_x_grad: if has_x_grad:
group_ranks = _get_comm_group(process_mesh_group, group_ranks = _get_comm_group(
process_mesh_shape, parallel_axis, process_mesh_group,
rank_id) process_mesh_shape,
parallel_axis,
rank_id,
)
group = new_process_group(group_ranks) group = new_process_group(group_ranks)
c_allreduce_sum_op = main_block.append_op( c_allreduce_sum_op = main_block.append_op(
type='c_allreduce_sum', type='c_allreduce_sum',
...@@ -428,15 +485,20 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -428,15 +485,20 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
'ring_id': group.id, 'ring_id': group.id,
'use_calc_stream': True, 'use_calc_stream': True,
'use_model_parallel': True, 'use_model_parallel': True,
OP_ROLE_KEY: OpRole.Backward OP_ROLE_KEY: OpRole.Backward,
}) },
set_comm_op_dist_attr_for_program(c_allreduce_sum_op, )
set_comm_op_dist_attr_for_program(
c_allreduce_sum_op,
dist_attr.process_mesh, dist_attr.process_mesh,
X_grad_dist_attr, ctx) X_grad_dist_attr,
ctx,
)
else: else:
# replicate # replicate
matmul_op_desc = copy_op_with_new_input_output(ctx, main_block, matmul_op_desc = copy_op_with_new_input_output(
backward_op, **kwargs) ctx, main_block, backward_op, **kwargs
)
# data parallel gradient synchronization # data parallel gradient synchronization
act_grad_names = [X_var.name] act_grad_names = [X_var.name]
...@@ -448,8 +510,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -448,8 +510,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
if trans_x: if trans_x:
trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None) trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None)
gradient_synchronization(ctx, backward_op, act_grad_names, out_grad_names, gradient_synchronization(
rank_id) ctx, backward_op, act_grad_names, out_grad_names, rank_id
)
if trans_x: if trans_x:
trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None) trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None)
...@@ -472,23 +535,25 @@ def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id): ...@@ -472,23 +535,25 @@ def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
if size <= 1 or axis in dim_mapping: if size <= 1 or axis in dim_mapping:
pass pass
else: else:
group_ranks = _get_comm_group(process_mesh.processes, group_ranks = _get_comm_group(
process_mesh.topology, axis, rank_id) process_mesh.processes, process_mesh.topology, axis, rank_id
)
sync_group = new_process_group(group_ranks) sync_group = new_process_group(group_ranks)
startup_block.append_op(type='c_broadcast', startup_block.append_op(
type='c_broadcast',
inputs={'X': param}, inputs={'X': param},
outputs={'Out': param}, outputs={'Out': param},
attrs={ attrs={
'ring_id': sync_group.id, 'ring_id': sync_group.id,
'root': 0, 'root': 0,
'use_calc_stream': True, 'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward OP_ROLE_KEY: OpRole.Forward,
}) },
)
class DistributedMatmul(DistributedOperatorImplContainer): class DistributedMatmul(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedMatmul, self).__init__(op_type) super(DistributedMatmul, self).__init__(op_type)
...@@ -498,7 +563,6 @@ register_distributed_operator_impl_container(DistributedMatmul("matmul")) ...@@ -498,7 +563,6 @@ register_distributed_operator_impl_container(DistributedMatmul("matmul"))
# ColumnParallel # ColumnParallel
class DistributedMatmulImpl0(DistributedOperatorImpl): class DistributedMatmulImpl0(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedMatmulImpl0, self).__init__(name) super(DistributedMatmulImpl0, self).__init__(name)
self._forward_implemented = True self._forward_implemented = True
...@@ -521,7 +585,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -521,7 +585,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
main_block = backward_op.block main_block = backward_op.block
vars = main_block.vars vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping( Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0]) backward_op.input("Y")[0]
)
# col parallel: matmul + allreduce # col parallel: matmul + allreduce
assert Y_var_dim_mapping[0] < 0 assert Y_var_dim_mapping[0] < 0
parallel_axis = Y_var_dim_mapping[1] parallel_axis = Y_var_dim_mapping[1]
...@@ -531,13 +596,14 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -531,13 +596,14 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
assert len(backward_op.output("X@GRAD")) == 1 assert len(backward_op.output("X@GRAD")) == 1
# calc comp op cost # calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
)
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
processes = process_mesh.processes processes = process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulGradOpCost, ctx, cost_mapping = build_comp_costs_from_descs(
processes, desc_mapping, MatmulGradOpCost, ctx, processes, desc_mapping, cluster
cluster) )
res.append(cost_mapping) res.append(cost_mapping)
# calc comm op cost # calc comm op cost
...@@ -550,40 +616,52 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -550,40 +616,52 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
ctx, ctx,
var_names, var_names,
attrs=attrs, attrs=attrs,
parallel_axis=parallel_axis) parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs( comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes, AllreduceSumOpCost,
c_allreduce_sum_desc_mapping, cluster) ctx,
processes,
c_allreduce_sum_desc_mapping,
cluster,
)
res.append(comm_op_cost_list) res.append(comm_op_cost_list)
# need gradient allreduce # need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping( var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0]) backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[ if (
batch_size_axis] > 1 and is_parameter_related( batch_size_axis > -1
backward_op.input("Y")[0], main_block): and mesh_shape[batch_size_axis] > 1
and is_parameter_related(backward_op.input("Y")[0], main_block)
):
parallel_axis = batch_size_axis parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True} attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]] var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, build_dp_costs(
cluster) res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res return res
def calc_fwd_cost(self, dist_op, ctx, cluster): def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost # calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulOpCost, ctx, processes, cost_mapping = build_comp_costs_from_descs(
desc_mapping, cluster) MatmulOpCost, ctx, processes, desc_mapping, cluster
)
# calc comm op cost # calc comm op cost
serial_op = dist_op.serial_op serial_op = dist_op.serial_op
vars = serial_op.block.vars vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping( parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-1] serial_op.input("Y")[0]
)[-1]
attrs = {"use_calc_stream": True, "use_model_parallel": True} attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.input("X") var_names = serial_op.input("X")
c_identity_desc_mapping = build_comm_desc_from_dist_op( c_identity_desc_mapping = build_comm_desc_from_dist_op(
...@@ -592,10 +670,12 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -592,10 +670,12 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
ctx, ctx,
var_names, var_names,
attrs=attrs, attrs=attrs,
parallel_axis=parallel_axis) parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs( comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
)
res_cost = [comm_op_cost_list, cost_mapping] res_cost = [comm_op_cost_list, cost_mapping]
return res_cost return res_cost
...@@ -606,16 +686,19 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -606,16 +686,19 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0] y_name = op_desc.input('Y')[0]
x_dims_mapping = copy.deepcopy( x_dims_mapping = copy.deepcopy(
op_dist_attr.get_input_dims_mapping(x_name)) op_dist_attr.get_input_dims_mapping(x_name)
)
y_dims_mapping = copy.deepcopy( y_dims_mapping = copy.deepcopy(
op_dist_attr.get_input_dims_mapping(y_name)) op_dist_attr.get_input_dims_mapping(y_name)
)
trans_x = op_desc.attr('transpose_X') trans_x = op_desc.attr('transpose_X')
trans_y = op_desc.attr('transpose_Y') trans_y = op_desc.attr('transpose_Y')
trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping) trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
if is_dim_shard(x_dims_mapping[-1]): if is_dim_shard(x_dims_mapping[-1]):
return False return False
if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate( if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
y_dims_mapping[-1]): y_dims_mapping[-1]
):
return False return False
for mapping in x_dims_mapping[1:-1]: for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping): if is_dim_shard(mapping):
...@@ -635,8 +718,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -635,8 +718,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
return True return True
def is_auto_compatible(self, dist_op): def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \ if (not self.is_input_compatible(dist_op)) or (
(not self.is_output_compatible(dist_op)): not self.is_output_compatible(dist_op)
):
return False return False
if not _is_auto_compatible_for_matmul(dist_op): if not _is_auto_compatible_for_matmul(dist_op):
return False return False
...@@ -661,28 +745,33 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -661,28 +745,33 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
src_op = dist_op_context.cur_src_op src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert (
str(src_op)) op_dist_attr is not None
), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes: if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh, rank_id = _get_corresponding_rank(
rank_id) ctx, op_dist_attr.process_mesh, rank_id
)
# check validation of inputs / outputs # check validation of inputs / outputs
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format( assert input_name in kwargs, "input [{}] is not given".format(
input_name) input_name
)
assert len(kwargs[input_name]) == len( assert len(kwargs[input_name]) == len(
src_op.desc.input(input_name) src_op.desc.input(input_name)
), "number of tensor for input [{}] is not match".format(input_name) ), "number of tensor for input [{}] is not match".format(input_name)
for output_name in src_op.desc.output_names(): for output_name in src_op.desc.output_names():
assert output_name in kwargs, "input [{}] is not given".format( assert output_name in kwargs, "input [{}] is not given".format(
output_name) output_name
)
assert len(kwargs[output_name]) == len( assert len(kwargs[output_name]) == len(
src_op.desc.output(output_name) src_op.desc.output(output_name)
), "number of tensor for input [{}] is not match".format( ), "number of tensor for input [{}] is not match".format(
output_name) output_name
)
X_var = main_block.var(kwargs['X'][0]) X_var = main_block.var(kwargs['X'][0])
Weight_var = main_block.var(kwargs['Y'][0]) Weight_var = main_block.var(kwargs['Y'][0])
...@@ -692,18 +781,24 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -692,18 +781,24 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
# TODO infer logic comm presentation # TODO infer logic comm presentation
matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping( matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-1] Weight_var.name
)[-1]
if trans_y: if trans_y:
matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping( matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-2] Weight_var.name
assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( )[-2]
matmul_col_dim_mapping) assert (
matmul_col_dim_mapping >= 0
), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_col_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_col_dim_mapping parallel_axis = matmul_col_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, group_ranks = _get_comm_group(
parallel_axis, rank_id) process_mesh_group, process_mesh_shape, parallel_axis, rank_id
)
group = new_process_group(group_ranks) group = new_process_group(group_ranks)
# infer new var shape with op dist attr # infer new var shape with op dist attr
...@@ -711,31 +806,39 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -711,31 +806,39 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
assert x_tensor_dist_attr is not None assert x_tensor_dist_attr is not None
identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name) identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
assert identity_var_dist_attr is not None assert identity_var_dist_attr is not None
ref_shape_x = infer_shape(main_block, X_var, x_tensor_dist_attr, ref_shape_x = infer_shape(
identity_var_dist_attr) main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
)
# infer out var shape with op dist attr # infer out var shape with op dist attr
out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var) out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
assert out_tensor_dist_attr is not None assert out_tensor_dist_attr is not None
out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert out_var_dist_attr is not None assert out_var_dist_attr is not None
ref_shape_out = infer_shape(main_block, Out_var, out_tensor_dist_attr, ref_shape_out = infer_shape(
out_var_dist_attr) main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
)
intermediate_var_0 = main_block.create_var( intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join( name=unique_name.generate_with_ignorable_key(
["c_identity", 'tmp'])), ".".join(["c_identity", 'tmp'])
),
dtype=X_var.dtype, dtype=X_var.dtype,
shape=X_var.shape, shape=X_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False, persistable=False,
stop_gradient=X_var.stop_gradient) stop_gradient=X_var.stop_gradient,
)
# set intermediate_var_0's dist_attr with X_var's dist_attr # set intermediate_var_0's dist_attr with X_var's dist_attr
ctx.set_tensor_dist_attr_for_program(intermediate_var_0, ctx.set_tensor_dist_attr_for_program(
identity_var_dist_attr) intermediate_var_0, identity_var_dist_attr
)
check_variable_and_dtype( check_variable_and_dtype(
X_var, 'tensor', X_var,
['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity') 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)
c_identity_op = main_block.append_op( c_identity_op = main_block.append_op(
type='c_identity', type='c_identity',
...@@ -745,26 +848,34 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -745,26 +848,34 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
'ring_id': group.id, 'ring_id': group.id,
'use_calc_stream': True, 'use_calc_stream': True,
'use_model_parallel': True, 'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role') OP_ROLE_KEY: src_op.attr('op_role'),
}) },
)
if intermediate_var_0.shape != ref_shape_x: if intermediate_var_0.shape != ref_shape_x:
intermediate_var_0.desc.set_shape(ref_shape_x) intermediate_var_0.desc.set_shape(ref_shape_x)
check_variable_and_dtype(intermediate_var_0, 'x', check_variable_and_dtype(
['float16', 'float32', 'float64'], 'linear') intermediate_var_0,
check_dtype(intermediate_var_0.dtype, 'dtype', 'x',
['float16', 'float32', 'float64'], 'linear') ['float16', 'float32', 'float64', 'uint16'],
'linear',
)
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
attrs = { attrs = {
'transpose_X': trans_x, 'transpose_X': trans_x,
'transpose_Y': trans_y, 'transpose_Y': trans_y,
'alpha': 1, 'alpha': 1,
OP_ROLE_KEY: src_op.attr('op_role') OP_ROLE_KEY: src_op.attr('op_role'),
} }
inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
matmul_op = main_block.append_op(type='matmul', matmul_op = main_block.append_op(
inputs=inputs, type='matmul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs
outputs={'Out': Out_var}, )
attrs=attrs)
if Out_var.shape != ref_shape_out: if Out_var.shape != ref_shape_out:
Out_var.desc.set_shape(ref_shape_out) Out_var.desc.set_shape(ref_shape_out)
...@@ -778,13 +889,16 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -778,13 +889,16 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
input_varname = c_identity_op.desc.input_arg_names()[0] input_varname = c_identity_op.desc.input_arg_names()[0]
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, "dist_attr is {}".format( assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
identity_op_dist_attr.set_input_dist_attr(input_varname, )
input_dist_attr) identity_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
# output # output
output_varname = c_identity_op.desc.output_arg_names()[0] output_varname = c_identity_op.desc.output_arg_names()[0]
identity_op_dist_attr.set_output_dist_attr(output_varname, identity_op_dist_attr.set_output_dist_attr(
input_dist_attr) output_varname, input_dist_attr
)
# set op dist attr # set op dist attr
ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr) ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)
...@@ -797,31 +911,39 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -797,31 +911,39 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
for input_varname in matmul_op.desc.input_arg_names(): for input_varname in matmul_op.desc.input_arg_names():
if input_varname in src_op.desc.input_arg_names(): if input_varname in src_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr( input_dist_attr = op_dist_attr.get_input_dist_attr(
input_varname) input_varname
)
assert input_dist_attr is not None, "dist_attr is {}".format( assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
matmul_op_dist_attr.set_input_dist_attr(input_varname, )
input_dist_attr) matmul_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
else: else:
input_var = main_block.var(input_varname) input_var = main_block.var(input_varname)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program( tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
input_var) input_var
matmul_op_dist_attr.set_input_dist_attr(input_varname, )
tensor_dist_attr) matmul_op_dist_attr.set_input_dist_attr(
input_varname, tensor_dist_attr
)
# output # output
output_varname = matmul_op.desc.output_arg_names()[0] output_varname = matmul_op.desc.output_arg_names()[0]
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, "dist_attr is {}".format( assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
matmul_op_dist_attr.set_output_dist_attr(output_varname, )
output_dist_attr) matmul_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
# set op dist attr # set op dist attr
ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr) ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)
# init param sync # init param sync
if Weight_var.is_parameter and not op_dist_attr.is_recompute: if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx, _init_param_sync(
rank_id) Weight_var, dist_op_context, startup_block, ctx, rank_id
)
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
...@@ -830,7 +952,6 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -830,7 +952,6 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
# RowParallel # RowParallel
class DistributedMatmulImpl1(DistributedOperatorImpl): class DistributedMatmulImpl1(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedMatmulImpl1, self).__init__(name) super(DistributedMatmulImpl1, self).__init__(name)
self._forward_implemented = True self._forward_implemented = True
...@@ -853,7 +974,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -853,7 +974,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
main_block = backward_op.block main_block = backward_op.block
vars = main_block.vars vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping( Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0]) backward_op.input("Y")[0]
)
assert Y_var_dim_mapping[1] < 0 assert Y_var_dim_mapping[1] < 0
parallel_axis = Y_var_dim_mapping[0] parallel_axis = Y_var_dim_mapping[0]
...@@ -866,50 +988,60 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -866,50 +988,60 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
ctx, ctx,
var_names, var_names,
attrs=attrs, attrs=attrs,
parallel_axis=parallel_axis) parallel_axis=parallel_axis,
)
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
processes = process_mesh.processes processes = process_mesh.processes
comm_op_cost_list = build_comm_costs_from_descs( comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
)
res.append(comm_op_cost_list) res.append(comm_op_cost_list)
# calc comp op cost # calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
cost_mapping = build_comp_costs_from_descs(MatmulGradOpCost, ctx, )
processes, desc_mapping, cost_mapping = build_comp_costs_from_descs(
cluster) MatmulGradOpCost, ctx, processes, desc_mapping, cluster
)
res.append(cost_mapping) res.append(cost_mapping)
# need gradient allreduce # need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping( var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0]) backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[ if (
batch_size_axis] > 1 and is_parameter_related( batch_size_axis > -1
backward_op.input("Y")[0], main_block): and mesh_shape[batch_size_axis] > 1
and is_parameter_related(backward_op.input("Y")[0], main_block)
):
parallel_axis = batch_size_axis parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True} attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]] var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, build_dp_costs(
cluster) res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res return res
def calc_fwd_cost(self, dist_op, ctx, cluster): def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost # calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulOpCost, ctx, processes, cost_mapping = build_comp_costs_from_descs(
desc_mapping, cluster) MatmulOpCost, ctx, processes, desc_mapping, cluster
)
# calc comm op cost # calc comm op cost
serial_op = dist_op.serial_op serial_op = dist_op.serial_op
vars = serial_op.block.vars vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping( parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-2] serial_op.input("Y")[0]
)[-2]
attrs = {"use_calc_stream": True, "use_model_parallel": True} attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.output("Out") var_names = serial_op.output("Out")
...@@ -919,11 +1051,16 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -919,11 +1051,16 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
ctx, ctx,
var_names, var_names,
attrs=attrs, attrs=attrs,
parallel_axis=parallel_axis) parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs( comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping, AllreduceSumOpCost,
cluster) ctx,
processes,
c_allreduce_sum_desc_mapping,
cluster,
)
res_cost = [cost_mapping, comm_op_cost_list] res_cost = [cost_mapping, comm_op_cost_list]
...@@ -935,16 +1072,19 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -935,16 +1072,19 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0] y_name = op_desc.input('Y')[0]
x_dims_mapping = copy.deepcopy( x_dims_mapping = copy.deepcopy(
op_dist_attr.get_input_dims_mapping(x_name)) op_dist_attr.get_input_dims_mapping(x_name)
)
y_dims_mapping = copy.deepcopy( y_dims_mapping = copy.deepcopy(
op_dist_attr.get_input_dims_mapping(y_name)) op_dist_attr.get_input_dims_mapping(y_name)
)
trans_x = op_desc.attr('transpose_X') trans_x = op_desc.attr('transpose_X')
trans_y = op_desc.attr('transpose_Y') trans_y = op_desc.attr('transpose_Y')
trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping) trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
if is_dim_replicate(x_dims_mapping[-1]): if is_dim_replicate(x_dims_mapping[-1]):
return False return False
if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard( if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
y_dims_mapping[-1]): y_dims_mapping[-1]
):
return False return False
# Other dimensions must be replicate except the batch dimension # Other dimensions must be replicate except the batch dimension
for mapping in x_dims_mapping[1:-1]: for mapping in x_dims_mapping[1:-1]:
...@@ -966,8 +1106,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -966,8 +1106,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
return True return True
def is_auto_compatible(self, dist_op): def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \ if (not self.is_input_compatible(dist_op)) or (
(not self.is_output_compatible(dist_op)): not self.is_output_compatible(dist_op)
):
return False return False
if not _is_auto_compatible_for_matmul(dist_op): if not _is_auto_compatible_for_matmul(dist_op):
return False return False
...@@ -992,28 +1133,33 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -992,28 +1133,33 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
src_op = dist_op_context.cur_src_op src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert (
str(src_op)) op_dist_attr is not None
), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes: if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh, rank_id = _get_corresponding_rank(
rank_id) ctx, op_dist_attr.process_mesh, rank_id
)
# check validation of inputs / outputs # check validation of inputs / outputs
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format( assert input_name in kwargs, "input [{}] is not given".format(
input_name) input_name
)
assert len(kwargs[input_name]) == len( assert len(kwargs[input_name]) == len(
src_op.desc.input(input_name) src_op.desc.input(input_name)
), "number of tensor for input [{}] is not match".format(input_name) ), "number of tensor for input [{}] is not match".format(input_name)
for output_name in src_op.desc.output_names(): for output_name in src_op.desc.output_names():
assert output_name in kwargs, "input [{}] is not given".format( assert output_name in kwargs, "input [{}] is not given".format(
output_name) output_name
)
assert len(kwargs[output_name]) == len( assert len(kwargs[output_name]) == len(
src_op.desc.output(output_name) src_op.desc.output(output_name)
), "number of tensor for input [{}] is not match".format( ), "number of tensor for input [{}] is not match".format(
output_name) output_name
)
X_var = main_block.var(kwargs['X'][0]) X_var = main_block.var(kwargs['X'][0])
Weight_var = main_block.var(kwargs['Y'][0]) Weight_var = main_block.var(kwargs['Y'][0])
...@@ -1023,29 +1169,40 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -1023,29 +1169,40 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
# TODO infer logic comm presentation # TODO infer logic comm presentation
matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping( matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-2] Weight_var.name
)[-2]
if trans_y: if trans_y:
matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping( matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-1] Weight_var.name
assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( )[-1]
matmul_row_dim_mapping) assert (
matmul_row_dim_mapping >= 0
), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_row_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_row_dim_mapping parallel_axis = matmul_row_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, group_ranks = _get_comm_group(
parallel_axis, rank_id) process_mesh_group, process_mesh_shape, parallel_axis, rank_id
)
group = new_process_group(group_ranks) group = new_process_group(group_ranks)
check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'], check_variable_and_dtype(
'linear') X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], )
'linear') check_dtype(
X_var.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
attrs = { attrs = {
'transpose_X': trans_x, 'transpose_X': trans_x,
'transpose_Y': trans_y, 'transpose_Y': trans_y,
'alpha': 1, 'alpha': 1,
OP_ROLE_KEY: src_op.attr('op_role') OP_ROLE_KEY: src_op.attr('op_role'),
} }
inputs = {'X': X_var, 'Y': Weight_var} inputs = {'X': X_var, 'Y': Weight_var}
...@@ -1054,27 +1211,33 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -1054,27 +1211,33 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
assert out_tensor_dist_attr is not None assert out_tensor_dist_attr is not None
out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert out_var_dist_attr is not None assert out_var_dist_attr is not None
ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr, ref_shape = infer_shape(
out_var_dist_attr) main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
)
intermediate_var_0 = main_block.create_var( intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join( name=unique_name.generate_with_ignorable_key(
["c_allreduce_sum", 'tmp'])), ".".join(["c_allreduce_sum", 'tmp'])
),
shape=Out_var.shape, shape=Out_var.shape,
dtype=Out_var.dtype, dtype=Out_var.dtype,
type=Out_var.type, type=Out_var.type,
lod_level=Out_var.lod_level, lod_level=Out_var.lod_level,
persistable=False, persistable=False,
is_data=False, is_data=False,
need_check_feed=Out_var.desc.need_check_feed()) need_check_feed=Out_var.desc.need_check_feed(),
)
# set intermediate_var_0's dist_attr with Out_var's dist_attr # set intermediate_var_0's dist_attr with Out_var's dist_attr
ctx.set_tensor_dist_attr_for_program(intermediate_var_0, ctx.set_tensor_dist_attr_for_program(
out_var_dist_attr) intermediate_var_0, out_var_dist_attr
)
matmul_op = main_block.append_op(type='matmul', matmul_op = main_block.append_op(
type='matmul',
inputs=inputs, inputs=inputs,
outputs={'Out': intermediate_var_0}, outputs={'Out': intermediate_var_0},
attrs=attrs) attrs=attrs,
)
if intermediate_var_0.shape != ref_shape: if intermediate_var_0.shape != ref_shape:
intermediate_var_0.desc.set_shape(ref_shape) intermediate_var_0.desc.set_shape(ref_shape)
...@@ -1086,8 +1249,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -1086,8 +1249,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
'ring_id': group.id, 'ring_id': group.id,
'use_calc_stream': True, 'use_calc_stream': True,
'use_model_parallel': True, 'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role') OP_ROLE_KEY: src_op.attr('op_role'),
}) },
)
if Out_var.shape != ref_shape: if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape) Out_var.desc.set_shape(ref_shape)
...@@ -1100,15 +1264,19 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -1100,15 +1264,19 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
for input_varname in matmul_op.desc.input_arg_names(): for input_varname in matmul_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, "dist_attr is {}".format( assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
matmul_op_dist_attr.set_input_dist_attr(input_varname, )
input_dist_attr) matmul_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
output_varname = matmul_op.desc.output_arg_names()[0] output_varname = matmul_op.desc.output_arg_names()[0]
output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert output_dist_attr is not None, "dist_attr is {}".format( assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
matmul_op_dist_attr.set_output_dist_attr(output_varname, )
output_dist_attr) matmul_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr) ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)
# allreduce # allreduce
...@@ -1120,21 +1288,26 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -1120,21 +1288,26 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
input_var = main_block.var(input_varname) input_var = main_block.var(input_varname)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var) tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
assert tensor_dist_attr is not None assert tensor_dist_attr is not None
allreduce_op_dist_attr.set_input_dist_attr(input_varname, allreduce_op_dist_attr.set_input_dist_attr(
tensor_dist_attr) input_varname, tensor_dist_attr
)
for output_varname in c_allreduce_sum_op.desc.output_arg_names(): for output_varname in c_allreduce_sum_op.desc.output_arg_names():
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, "dist_attr is {}".format( assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
allreduce_op_dist_attr.set_output_dist_attr(output_varname, )
output_dist_attr) allreduce_op_dist_attr.set_output_dist_attr(
ctx.set_op_dist_attr_for_program(c_allreduce_sum_op, output_varname, output_dist_attr
allreduce_op_dist_attr) )
ctx.set_op_dist_attr_for_program(
c_allreduce_sum_op, allreduce_op_dist_attr
)
# init param sync # init param sync
if Weight_var.is_parameter and not op_dist_attr.is_recompute: if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx, _init_param_sync(
rank_id) Weight_var, dist_op_context, startup_block, ctx, rank_id
)
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
...@@ -1143,7 +1316,6 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -1143,7 +1316,6 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
# ReplicateParallel # ReplicateParallel
class DistributedMatmulImpl2(DistributedOperatorImpl): class DistributedMatmulImpl2(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedMatmulImpl2, self).__init__(name) super(DistributedMatmulImpl2, self).__init__(name)
...@@ -1164,38 +1336,45 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): ...@@ -1164,38 +1336,45 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
vars = main_block.vars vars = main_block.vars
# calc comp op cost # calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
)
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
processes = process_mesh.processes processes = process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulGradOpCost, ctx, cost_mapping = build_comp_costs_from_descs(
processes, desc_mapping, MatmulGradOpCost, ctx, processes, desc_mapping, cluster
cluster) )
res.append(cost_mapping) res.append(cost_mapping)
# need gradient allreduce # need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping( var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0]) backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[ if (
batch_size_axis] > 1 and is_parameter_related( batch_size_axis > -1
backward_op.input("Y")[0], main_block): and mesh_shape[batch_size_axis] > 1
and is_parameter_related(backward_op.input("Y")[0], main_block)
):
parallel_axis = batch_size_axis parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True} attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]] var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, build_dp_costs(
cluster) res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res return res
def calc_fwd_cost(self, dist_op, ctx, cluster): def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost # calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulOpCost, ctx, processes, cost_mapping = build_comp_costs_from_descs(
desc_mapping, cluster) MatmulOpCost, ctx, processes, desc_mapping, cluster
)
res_cost = [cost_mapping] res_cost = [cost_mapping]
return res_cost return res_cost
...@@ -1211,13 +1390,15 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): ...@@ -1211,13 +1390,15 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
if is_dim_shard(x_dims_mapping[-1]): if is_dim_shard(x_dims_mapping[-1]):
return False return False
if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard( if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
x_dims_mapping[-2]): x_dims_mapping[-2]
):
return False return False
if is_dim_shard(y_dims_mapping[-1]): if is_dim_shard(y_dims_mapping[-1]):
return False return False
if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard( if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
y_dims_mapping[-2]): y_dims_mapping[-2]
):
return False return False
return True return True
...@@ -1231,14 +1412,16 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): ...@@ -1231,14 +1412,16 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
if is_dim_shard(out_dims_mapping[-1]): if is_dim_shard(out_dims_mapping[-1]):
return False return False
if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard( if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
out_dims_mapping[-2]): out_dims_mapping[-2]
):
return False return False
return True return True
def is_auto_compatible(self, dist_op): def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \ if (not self.is_input_compatible(dist_op)) or (
(not self.is_output_compatible(dist_op)): not self.is_output_compatible(dist_op)
):
return False return False
if not _is_auto_compatible_for_matmul(dist_op): if not _is_auto_compatible_for_matmul(dist_op):
...@@ -1262,16 +1445,18 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): ...@@ -1262,16 +1445,18 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
_right_operand_parameter_matmul_backward(ctx, *args, **kwargs) _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
register_distributed_operator_impl("matmul", register_distributed_operator_impl(
DistributedMatmulImpl0("column_parallel")) "matmul", DistributedMatmulImpl0("column_parallel")
register_distributed_operator_impl("matmul", )
DistributedMatmulImpl1("row_parallel")) register_distributed_operator_impl(
register_distributed_operator_impl("matmul", "matmul", DistributedMatmulImpl1("row_parallel")
DistributedMatmulImpl2("replicate_parallel")) )
register_distributed_operator_impl(
"matmul", DistributedMatmulImpl2("replicate_parallel")
)
class DistributedMatmulV2(DistributedOperatorImplContainer): class DistributedMatmulV2(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedMatmulV2, self).__init__(op_type) super(DistributedMatmulV2, self).__init__(op_type)
...@@ -1281,7 +1466,6 @@ register_distributed_operator_impl_container(DistributedMatmulV2("matmul_v2")) ...@@ -1281,7 +1466,6 @@ register_distributed_operator_impl_container(DistributedMatmulV2("matmul_v2"))
# ColumnParallel # ColumnParallel
class DistributedMatmulV2Impl0(DistributedOperatorImpl): class DistributedMatmulV2Impl0(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedMatmulV2Impl0, self).__init__(name) super(DistributedMatmulV2Impl0, self).__init__(name)
self._forward_implemented = True self._forward_implemented = True
...@@ -1304,7 +1488,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -1304,7 +1488,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
main_block = backward_op.block main_block = backward_op.block
vars = main_block.vars vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping( Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0]) backward_op.input("Y")[0]
)
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
processes = process_mesh.processes processes = process_mesh.processes
# col parallel: matmul + allreduce # col parallel: matmul + allreduce
...@@ -1318,12 +1503,13 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -1318,12 +1503,13 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
assert len(backward_op.output("X@GRAD")) == 1 assert len(backward_op.output("X@GRAD")) == 1
# calc comp op cost # calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
)
cost_mapping = build_comp_costs_from_descs(MatmulV2GradOpCost, ctx, cost_mapping = build_comp_costs_from_descs(
processes, desc_mapping, MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
cluster) )
res.append(cost_mapping) res.append(cost_mapping)
# calc comm op cost # calc comm op cost
...@@ -1336,45 +1522,55 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -1336,45 +1522,55 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
ctx, ctx,
var_names, var_names,
attrs=attrs, attrs=attrs,
parallel_axis=parallel_axis) parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs( comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes, AllreduceSumOpCost,
c_allreduce_sum_desc_mapping, cluster) ctx,
processes,
c_allreduce_sum_desc_mapping,
cluster,
)
res.append(comm_op_cost_list) res.append(comm_op_cost_list)
# need gradient allreduce # need gradient allreduce
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping( var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0]) backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[ if (
batch_size_axis] > 1 and is_parameter_related( batch_size_axis > -1
backward_op.input("Y")[0], main_block): and mesh_shape[batch_size_axis] > 1
and is_parameter_related(backward_op.input("Y")[0], main_block)
):
parallel_axis = batch_size_axis parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True} attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]] var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, build_dp_costs(
cluster) res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res return res
def calc_fwd_cost(self, dist_op, ctx, cluster): def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost # calc comp op cost
# TODO: trans shape if trans_x or trans_y is True # TODO: trans shape if trans_x or trans_y is True
comp_desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, comp_desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes processes = dist_op.dist_attr.process_mesh.processes
comp_cost_mapping = build_comp_costs_from_descs(MatmulV2OpCost, ctx, comp_cost_mapping = build_comp_costs_from_descs(
processes, MatmulV2OpCost, ctx, processes, comp_desc_mapping, cluster
comp_desc_mapping, )
cluster)
# calc comm op cost # calc comm op cost
serial_op = dist_op.serial_op serial_op = dist_op.serial_op
vars = serial_op.block.vars vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping( parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-1] serial_op.input("Y")[0]
)[-1]
attrs = {"use_calc_stream": True, "use_model_parallel": True} attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.input("X") var_names = serial_op.input("X")
...@@ -1384,9 +1580,11 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -1384,9 +1580,11 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
ctx, ctx,
var_names, var_names,
attrs=attrs, attrs=attrs,
parallel_axis=parallel_axis) parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs( comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
)
res_cost = [comm_op_cost_list, comp_cost_mapping] res_cost = [comm_op_cost_list, comp_cost_mapping]
return res_cost return res_cost
...@@ -1397,16 +1595,19 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -1397,16 +1595,19 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0] y_name = op_desc.input('Y')[0]
x_dims_mapping = copy.deepcopy( x_dims_mapping = copy.deepcopy(
op_dist_attr.get_input_dims_mapping(x_name)) op_dist_attr.get_input_dims_mapping(x_name)
)
y_dims_mapping = copy.deepcopy( y_dims_mapping = copy.deepcopy(
op_dist_attr.get_input_dims_mapping(y_name)) op_dist_attr.get_input_dims_mapping(y_name)
)
trans_x = op_desc.attr('trans_x') trans_x = op_desc.attr('trans_x')
trans_y = op_desc.attr('trans_y') trans_y = op_desc.attr('trans_y')
trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping) trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
if is_dim_shard(x_dims_mapping[-1]): if is_dim_shard(x_dims_mapping[-1]):
return False return False
if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate( if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
y_dims_mapping[-1]): y_dims_mapping[-1]
):
return False return False
for mapping in x_dims_mapping[1:-1]: for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping): if is_dim_shard(mapping):
...@@ -1426,8 +1627,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -1426,8 +1627,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
return True return True
def is_auto_compatible(self, dist_op): def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \ if (not self.is_input_compatible(dist_op)) or (
(not self.is_output_compatible(dist_op)): not self.is_output_compatible(dist_op)
):
return False return False
if not _is_auto_compatible_for_matmul(dist_op): if not _is_auto_compatible_for_matmul(dist_op):
return False return False
...@@ -1452,28 +1654,33 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -1452,28 +1654,33 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
src_op = dist_op_context.cur_src_op src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert (
str(src_op)) op_dist_attr is not None
), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes: if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh, rank_id = _get_corresponding_rank(
rank_id) ctx, op_dist_attr.process_mesh, rank_id
)
# check validation of inputs / outputs # check validation of inputs / outputs
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format( assert input_name in kwargs, "input [{}] is not given".format(
input_name) input_name
)
assert len(kwargs[input_name]) == len( assert len(kwargs[input_name]) == len(
src_op.desc.input(input_name) src_op.desc.input(input_name)
), "number of tensor for input [{}] is not match".format(input_name) ), "number of tensor for input [{}] is not match".format(input_name)
for output_name in src_op.desc.output_names(): for output_name in src_op.desc.output_names():
assert output_name in kwargs, "input [{}] is not given".format( assert output_name in kwargs, "input [{}] is not given".format(
output_name) output_name
)
assert len(kwargs[output_name]) == len( assert len(kwargs[output_name]) == len(
src_op.desc.output(output_name) src_op.desc.output(output_name)
), "number of tensor for input [{}] is not match".format( ), "number of tensor for input [{}] is not match".format(
output_name) output_name
)
X_var = main_block.var(kwargs['X'][0]) X_var = main_block.var(kwargs['X'][0])
Weight_var = main_block._var_recursive(kwargs['Y'][0]) Weight_var = main_block._var_recursive(kwargs['Y'][0])
...@@ -1483,18 +1690,24 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -1483,18 +1690,24 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
# TODO infer logic comm presentation # TODO infer logic comm presentation
matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping( matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-1] Weight_var.name
)[-1]
if trans_y: if trans_y:
matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping( matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-2] Weight_var.name
assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( )[-2]
matmul_col_dim_mapping) assert (
matmul_col_dim_mapping >= 0
), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_col_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_col_dim_mapping parallel_axis = matmul_col_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, group_ranks = _get_comm_group(
parallel_axis, rank_id) process_mesh_group, process_mesh_shape, parallel_axis, rank_id
)
group = new_process_group(group_ranks) group = new_process_group(group_ranks)
# infer new var shape with op dist attr # infer new var shape with op dist attr
...@@ -1502,31 +1715,39 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -1502,31 +1715,39 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
assert x_tensor_dist_attr is not None assert x_tensor_dist_attr is not None
identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name) identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
assert identity_var_dist_attr is not None assert identity_var_dist_attr is not None
ref_shape_x = infer_shape(main_block, X_var, x_tensor_dist_attr, ref_shape_x = infer_shape(
identity_var_dist_attr) main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
)
# infer out var shape with op dist attr # infer out var shape with op dist attr
out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var) out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
assert out_tensor_dist_attr is not None assert out_tensor_dist_attr is not None
out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert out_var_dist_attr is not None assert out_var_dist_attr is not None
ref_shape_out = infer_shape(main_block, Out_var, out_tensor_dist_attr, ref_shape_out = infer_shape(
out_var_dist_attr) main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
)
intermediate_var_0 = main_block.create_var( intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join( name=unique_name.generate_with_ignorable_key(
["c_identity", 'tmp'])), ".".join(["c_identity", 'tmp'])
),
dtype=X_var.dtype, dtype=X_var.dtype,
shape=X_var.shape, shape=X_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False, persistable=False,
stop_gradient=X_var.stop_gradient) stop_gradient=X_var.stop_gradient,
)
# set intermediate_var_0's dist_attr with X_var's dist_attr # set intermediate_var_0's dist_attr with X_var's dist_attr
ctx.set_tensor_dist_attr_for_program(intermediate_var_0, ctx.set_tensor_dist_attr_for_program(
identity_var_dist_attr) intermediate_var_0, identity_var_dist_attr
)
check_variable_and_dtype( check_variable_and_dtype(
X_var, 'tensor', X_var,
['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity') 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)
c_identity_op = main_block.append_op( c_identity_op = main_block.append_op(
type='c_identity', type='c_identity',
inputs={'X': [X_var]}, inputs={'X': [X_var]},
...@@ -1536,24 +1757,35 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -1536,24 +1757,35 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
'use_calc_stream': True, 'use_calc_stream': True,
'use_model_parallel': True, 'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role'), OP_ROLE_KEY: src_op.attr('op_role'),
}) },
)
if intermediate_var_0.shape != ref_shape_x: if intermediate_var_0.shape != ref_shape_x:
intermediate_var_0.desc.set_shape(ref_shape_x) intermediate_var_0.desc.set_shape(ref_shape_x)
check_variable_and_dtype(intermediate_var_0, 'x', check_variable_and_dtype(
['float16', 'float32', 'float64'], 'linear') intermediate_var_0,
check_dtype(intermediate_var_0.dtype, 'dtype', 'x',
['float16', 'float32', 'float64'], 'linear') ['float16', 'float32', 'float64', 'uint16'],
'linear',
)
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
attrs = { attrs = {
'trans_x': trans_x, 'trans_x': trans_x,
'trans_y': trans_y, 'trans_y': trans_y,
OP_ROLE_KEY: src_op.attr('op_role') OP_ROLE_KEY: src_op.attr('op_role'),
} }
inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
matmul_v2_op = main_block.append_op(type='matmul_v2', matmul_v2_op = main_block.append_op(
type='matmul_v2',
inputs=inputs, inputs=inputs,
outputs={'Out': Out_var}, outputs={'Out': Out_var},
attrs=attrs) attrs=attrs,
)
if Out_var.shape != ref_shape_out: if Out_var.shape != ref_shape_out:
Out_var.desc.set_shape(ref_shape_out) Out_var.desc.set_shape(ref_shape_out)
...@@ -1567,13 +1799,16 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -1567,13 +1799,16 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
input_varname = c_identity_op.desc.input_arg_names()[0] input_varname = c_identity_op.desc.input_arg_names()[0]
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, "dist_attr is {}".format( assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
identity_op_dist_attr.set_input_dist_attr(input_varname, )
input_dist_attr) identity_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
# output # output
output_varname = c_identity_op.desc.output_arg_names()[0] output_varname = c_identity_op.desc.output_arg_names()[0]
identity_op_dist_attr.set_output_dist_attr(output_varname, identity_op_dist_attr.set_output_dist_attr(
input_dist_attr) output_varname, input_dist_attr
)
ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr) ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)
# matmulv2 # matmulv2
...@@ -1584,29 +1819,37 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -1584,29 +1819,37 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
for input_varname in matmul_v2_op.desc.input_arg_names(): for input_varname in matmul_v2_op.desc.input_arg_names():
if input_varname in src_op.desc.input_arg_names(): if input_varname in src_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr( input_dist_attr = op_dist_attr.get_input_dist_attr(
input_varname) input_varname
)
assert input_dist_attr is not None, "dist_attr is {}".format( assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
)
matmulv2_op_dist_attr.set_input_dist_attr( matmulv2_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr) input_varname, input_dist_attr
)
else: else:
input_var = main_block.var(input_varname) input_var = main_block.var(input_varname)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program( tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
input_var) input_var
)
matmulv2_op_dist_attr.set_input_dist_attr( matmulv2_op_dist_attr.set_input_dist_attr(
input_varname, tensor_dist_attr) input_varname, tensor_dist_attr
)
for output_varname in matmul_v2_op.desc.output_arg_names(): for output_varname in matmul_v2_op.desc.output_arg_names():
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, "dist_attr is {}".format( assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
matmulv2_op_dist_attr.set_output_dist_attr(output_varname, )
output_dist_attr) matmulv2_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr) ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)
# init param sync # init param sync
if Weight_var.is_parameter and not op_dist_attr.is_recompute: if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx, _init_param_sync(
rank_id) Weight_var, dist_op_context, startup_block, ctx, rank_id
)
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
...@@ -1615,7 +1858,6 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -1615,7 +1858,6 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
# RowParallel # RowParallel
class DistributedMatmulV2Impl1(DistributedOperatorImpl): class DistributedMatmulV2Impl1(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedMatmulV2Impl1, self).__init__(name) super(DistributedMatmulV2Impl1, self).__init__(name)
self._forward_implemented = True self._forward_implemented = True
...@@ -1638,7 +1880,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -1638,7 +1880,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
main_block = backward_op.block main_block = backward_op.block
vars = main_block.vars vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping( Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0]) backward_op.input("Y")[0]
)
assert Y_var_dim_mapping[1] < 0 assert Y_var_dim_mapping[1] < 0
parallel_axis = Y_var_dim_mapping[0] parallel_axis = Y_var_dim_mapping[0]
...@@ -1653,50 +1896,59 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -1653,50 +1896,59 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
ctx, ctx,
var_names, var_names,
attrs=attrs, attrs=attrs,
parallel_axis=parallel_axis) parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs( comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
)
res.append(comm_op_cost_list) res.append(comm_op_cost_list)
# calc comp op cost # calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
cost_mapping = build_comp_costs_from_descs(MatmulV2GradOpCost, ctx, )
processes, desc_mapping, cost_mapping = build_comp_costs_from_descs(
cluster) MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
)
res.append(cost_mapping) res.append(cost_mapping)
# need gradient allreduce # need gradient allreduce
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping( var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0]) backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[ if (
batch_size_axis] > 1 and is_parameter_related( batch_size_axis > -1
backward_op.input("Y")[0], main_block): and mesh_shape[batch_size_axis] > 1
and is_parameter_related(backward_op.input("Y")[0], main_block)
):
parallel_axis = batch_size_axis parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True} attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]] var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, build_dp_costs(
cluster) res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res return res
def calc_fwd_cost(self, dist_op, ctx, cluster): def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost # calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulV2OpCost, ctx, cost_mapping = build_comp_costs_from_descs(
processes, desc_mapping, MatmulV2OpCost, ctx, processes, desc_mapping, cluster
cluster) )
# calc comm op cost # calc comm op cost
serial_op = dist_op.serial_op serial_op = dist_op.serial_op
vars = serial_op.block.vars vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping( parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-2] serial_op.input("Y")[0]
)[-2]
attrs = {"use_calc_stream": True, "use_model_parallel": True} attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.output("Out") var_names = serial_op.output("Out")
...@@ -1706,11 +1958,16 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -1706,11 +1958,16 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
ctx, ctx,
var_names, var_names,
attrs=attrs, attrs=attrs,
parallel_axis=parallel_axis) parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs( comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping, AllreduceSumOpCost,
cluster) ctx,
processes,
c_allreduce_sum_desc_mapping,
cluster,
)
res_cost = [cost_mapping, comm_op_cost_list] res_cost = [cost_mapping, comm_op_cost_list]
return res_cost return res_cost
...@@ -1721,16 +1978,19 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -1721,16 +1978,19 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0] y_name = op_desc.input('Y')[0]
x_dims_mapping = copy.deepcopy( x_dims_mapping = copy.deepcopy(
op_dist_attr.get_input_dims_mapping(x_name)) op_dist_attr.get_input_dims_mapping(x_name)
)
y_dims_mapping = copy.deepcopy( y_dims_mapping = copy.deepcopy(
op_dist_attr.get_input_dims_mapping(y_name)) op_dist_attr.get_input_dims_mapping(y_name)
)
trans_x = op_desc.attr('trans_x') trans_x = op_desc.attr('trans_x')
trans_y = op_desc.attr('trans_y') trans_y = op_desc.attr('trans_y')
trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping) trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
if is_dim_replicate(x_dims_mapping[-1]): if is_dim_replicate(x_dims_mapping[-1]):
return False return False
if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard( if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
y_dims_mapping[-1]): y_dims_mapping[-1]
):
return False return False
# Other dimensions must be replicate except the batch dimension # Other dimensions must be replicate except the batch dimension
for mapping in x_dims_mapping[1:-1]: for mapping in x_dims_mapping[1:-1]:
...@@ -1752,8 +2012,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -1752,8 +2012,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
return True return True
def is_auto_compatible(self, dist_op): def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \ if (not self.is_input_compatible(dist_op)) or (
(not self.is_output_compatible(dist_op)): not self.is_output_compatible(dist_op)
):
return False return False
if not _is_auto_compatible_for_matmul(dist_op): if not _is_auto_compatible_for_matmul(dist_op):
return False return False
...@@ -1778,28 +2039,33 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -1778,28 +2039,33 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
src_op = dist_op_context.cur_src_op src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert (
str(src_op)) op_dist_attr is not None
), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes: if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh, rank_id = _get_corresponding_rank(
rank_id) ctx, op_dist_attr.process_mesh, rank_id
)
# check validation of inputs / outputs # check validation of inputs / outputs
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format( assert input_name in kwargs, "input [{}] is not given".format(
input_name) input_name
)
assert len(kwargs[input_name]) == len( assert len(kwargs[input_name]) == len(
src_op.desc.input(input_name) src_op.desc.input(input_name)
), "number of tensor for input [{}] is not match".format(input_name) ), "number of tensor for input [{}] is not match".format(input_name)
for output_name in src_op.desc.output_names(): for output_name in src_op.desc.output_names():
assert output_name in kwargs, "input [{}] is not given".format( assert output_name in kwargs, "input [{}] is not given".format(
output_name) output_name
)
assert len(kwargs[output_name]) == len( assert len(kwargs[output_name]) == len(
src_op.desc.output(output_name) src_op.desc.output(output_name)
), "number of tensor for input [{}] is not match".format( ), "number of tensor for input [{}] is not match".format(
output_name) output_name
)
X_var = main_block.var(kwargs['X'][0]) X_var = main_block.var(kwargs['X'][0])
Weight_var = main_block._var_recursive(kwargs['Y'][0]) Weight_var = main_block._var_recursive(kwargs['Y'][0])
...@@ -1809,28 +2075,39 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -1809,28 +2075,39 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
# TODO infer logic comm presentation # TODO infer logic comm presentation
matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping( matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-2] Weight_var.name
)[-2]
if trans_y: if trans_y:
matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping( matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-1] Weight_var.name
assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( )[-1]
matmul_row_dim_mapping) assert (
matmul_row_dim_mapping >= 0
), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_row_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_row_dim_mapping parallel_axis = matmul_row_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, group_ranks = _get_comm_group(
parallel_axis, rank_id) process_mesh_group, process_mesh_shape, parallel_axis, rank_id
)
group = new_process_group(group_ranks) group = new_process_group(group_ranks)
check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'], check_variable_and_dtype(
'linear') X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], )
'linear') check_dtype(
X_var.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
attrs = { attrs = {
'trans_x': trans_x, 'trans_x': trans_x,
'trans_y': trans_y, 'trans_y': trans_y,
OP_ROLE_KEY: src_op.attr('op_role') OP_ROLE_KEY: src_op.attr('op_role'),
} }
inputs = {'X': X_var, 'Y': Weight_var} inputs = {'X': X_var, 'Y': Weight_var}
...@@ -1839,27 +2116,33 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -1839,27 +2116,33 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
assert out_tensor_dist_attr is not None assert out_tensor_dist_attr is not None
out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert out_var_dist_attr is not None assert out_var_dist_attr is not None
ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr, ref_shape = infer_shape(
out_var_dist_attr) main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
)
intermediate_var_0 = main_block.create_var( intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join( name=unique_name.generate_with_ignorable_key(
["c_allreduce_sum", 'tmp'])), ".".join(["c_allreduce_sum", 'tmp'])
),
shape=Out_var.shape, shape=Out_var.shape,
dtype=Out_var.dtype, dtype=Out_var.dtype,
type=Out_var.type, type=Out_var.type,
lod_level=Out_var.lod_level, lod_level=Out_var.lod_level,
persistable=False, persistable=False,
is_data=False, is_data=False,
need_check_feed=Out_var.desc.need_check_feed()) need_check_feed=Out_var.desc.need_check_feed(),
)
# set intermediate_var_0's dist_attr with Out_var's dist_attr # set intermediate_var_0's dist_attr with Out_var's dist_attr
ctx.set_tensor_dist_attr_for_program(intermediate_var_0, ctx.set_tensor_dist_attr_for_program(
out_var_dist_attr) intermediate_var_0, out_var_dist_attr
)
matmul_v2_op = main_block.append_op(type='matmul_v2', matmul_v2_op = main_block.append_op(
type='matmul_v2',
inputs=inputs, inputs=inputs,
outputs={'Out': intermediate_var_0}, outputs={'Out': intermediate_var_0},
attrs=attrs) attrs=attrs,
)
if intermediate_var_0.shape != ref_shape: if intermediate_var_0.shape != ref_shape:
intermediate_var_0.desc.set_shape(ref_shape) intermediate_var_0.desc.set_shape(ref_shape)
...@@ -1871,8 +2154,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -1871,8 +2154,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
'ring_id': group.id, 'ring_id': group.id,
'use_calc_stream': True, 'use_calc_stream': True,
'use_model_parallel': True, 'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role') OP_ROLE_KEY: src_op.attr('op_role'),
}) },
)
if Out_var.shape != ref_shape: if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape) Out_var.desc.set_shape(ref_shape)
...@@ -1885,15 +2169,19 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -1885,15 +2169,19 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
for input_varname in matmul_v2_op.desc.input_arg_names(): for input_varname in matmul_v2_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, "dist_attr is {}".format( assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
matmulv2_op_dist_attr.set_input_dist_attr(input_varname, )
input_dist_attr) matmulv2_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
output_varname = matmul_v2_op.desc.output_arg_names()[0] output_varname = matmul_v2_op.desc.output_arg_names()[0]
output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert output_dist_attr is not None, "dist_attr is {}".format( assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
matmulv2_op_dist_attr.set_output_dist_attr(output_varname, )
output_dist_attr) matmulv2_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr) ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)
# allreduce # allreduce
...@@ -1905,21 +2193,26 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -1905,21 +2193,26 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
input_var = main_block.var(input_varname) input_var = main_block.var(input_varname)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var) tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
assert tensor_dist_attr is not None assert tensor_dist_attr is not None
allreduce_op_dist_attr.set_input_dist_attr(input_varname, allreduce_op_dist_attr.set_input_dist_attr(
tensor_dist_attr) input_varname, tensor_dist_attr
)
for output_varname in c_allreduce_sum_op.desc.output_arg_names(): for output_varname in c_allreduce_sum_op.desc.output_arg_names():
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, "dist_attr is {}".format( assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
allreduce_op_dist_attr.set_output_dist_attr(output_varname, )
output_dist_attr) allreduce_op_dist_attr.set_output_dist_attr(
ctx.set_op_dist_attr_for_program(c_allreduce_sum_op, output_varname, output_dist_attr
allreduce_op_dist_attr) )
ctx.set_op_dist_attr_for_program(
c_allreduce_sum_op, allreduce_op_dist_attr
)
# init param sync # init param sync
if Weight_var.is_parameter and not op_dist_attr.is_recompute: if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx, _init_param_sync(
rank_id) Weight_var, dist_op_context, startup_block, ctx, rank_id
)
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
...@@ -1928,7 +2221,6 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -1928,7 +2221,6 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
# ReplicateParallel # ReplicateParallel
class DistributedMatmulV2Impl2(DistributedOperatorImpl): class DistributedMatmulV2Impl2(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedMatmulV2Impl2, self).__init__(name) super(DistributedMatmulV2Impl2, self).__init__(name)
...@@ -1950,38 +2242,44 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): ...@@ -1950,38 +2242,44 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
# calc comp op cost # calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
)
processes = process_mesh.processes processes = process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulV2GradOpCost, ctx, cost_mapping = build_comp_costs_from_descs(
processes, desc_mapping, MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
cluster) )
res.append(cost_mapping) res.append(cost_mapping)
# need gradient allreduce # need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping( var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0]) backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[ if (
batch_size_axis] > 1 and is_parameter_related( batch_size_axis > -1
backward_op.input("Y")[0], main_block): and mesh_shape[batch_size_axis] > 1
and is_parameter_related(backward_op.input("Y")[0], main_block)
):
parallel_axis = batch_size_axis parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True} attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]] var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, build_dp_costs(
cluster) res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res return res
def calc_fwd_cost(self, dist_op, ctx, cluster): def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost # calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulV2OpCost, ctx, cost_mapping = build_comp_costs_from_descs(
processes, desc_mapping, MatmulV2OpCost, ctx, processes, desc_mapping, cluster
cluster) )
res_cost = [cost_mapping] res_cost = [cost_mapping]
...@@ -1998,13 +2296,15 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): ...@@ -1998,13 +2296,15 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
if is_dim_shard(x_dims_mapping[-1]): if is_dim_shard(x_dims_mapping[-1]):
return False return False
if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard( if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
x_dims_mapping[-2]): x_dims_mapping[-2]
):
return False return False
if is_dim_shard(y_dims_mapping[-1]): if is_dim_shard(y_dims_mapping[-1]):
return False return False
if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard( if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
y_dims_mapping[-2]): y_dims_mapping[-2]
):
return False return False
return True return True
...@@ -2019,14 +2319,16 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): ...@@ -2019,14 +2319,16 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
if is_dim_shard(out_dims_mapping[-1]): if is_dim_shard(out_dims_mapping[-1]):
return False return False
if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard( if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
out_dims_mapping[-2]): out_dims_mapping[-2]
):
return False return False
return True return True
def is_auto_compatible(self, dist_op): def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \ if (not self.is_input_compatible(dist_op)) or (
(not self.is_output_compatible(dist_op)): not self.is_output_compatible(dist_op)
):
return False return False
if not _is_auto_compatible_for_matmul(dist_op): if not _is_auto_compatible_for_matmul(dist_op):
...@@ -2050,16 +2352,18 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): ...@@ -2050,16 +2352,18 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
_right_operand_parameter_matmul_backward(ctx, *args, **kwargs) _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
register_distributed_operator_impl("matmul_v2",
DistributedMatmulV2Impl0("column_parallel"))
register_distributed_operator_impl("matmul_v2",
DistributedMatmulV2Impl1("row_parallel"))
register_distributed_operator_impl( register_distributed_operator_impl(
"matmul_v2", DistributedMatmulV2Impl2("replicate_parallel")) "matmul_v2", DistributedMatmulV2Impl0("column_parallel")
)
register_distributed_operator_impl(
"matmul_v2", DistributedMatmulV2Impl1("row_parallel")
)
register_distributed_operator_impl(
"matmul_v2", DistributedMatmulV2Impl2("replicate_parallel")
)
class DistributedMul(DistributedOperatorImplContainer): class DistributedMul(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedMul, self).__init__(op_type) super(DistributedMul, self).__init__(op_type)
...@@ -2069,7 +2373,6 @@ register_distributed_operator_impl_container(DistributedMul("mul")) ...@@ -2069,7 +2373,6 @@ register_distributed_operator_impl_container(DistributedMul("mul"))
# ColumnParallel # ColumnParallel
class DistributedMulImpl0(DistributedOperatorImpl): class DistributedMulImpl0(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedMulImpl0, self).__init__(name) super(DistributedMulImpl0, self).__init__(name)
self._forward_implemented = True self._forward_implemented = True
...@@ -2092,7 +2395,8 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2092,7 +2395,8 @@ class DistributedMulImpl0(DistributedOperatorImpl):
main_block = backward_op.block main_block = backward_op.block
vars = main_block.vars vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping( Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0]) backward_op.input("Y")[0]
)
# col parallel: matmul + allreduce # col parallel: matmul + allreduce
assert Y_var_dim_mapping[0] < 0 assert Y_var_dim_mapping[0] < 0
parallel_axis = Y_var_dim_mapping[1] parallel_axis = Y_var_dim_mapping[1]
...@@ -2102,13 +2406,14 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2102,13 +2406,14 @@ class DistributedMulImpl0(DistributedOperatorImpl):
assert len(backward_op.output("X@GRAD")) == 1 assert len(backward_op.output("X@GRAD")) == 1
# calc comp op cost # calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
)
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
processes = process_mesh.processes processes = process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MulGradOpCost, ctx, cost_mapping = build_comp_costs_from_descs(
processes, desc_mapping, MulGradOpCost, ctx, processes, desc_mapping, cluster
cluster) )
res.append(cost_mapping) res.append(cost_mapping)
# calc comm op cost # calc comm op cost
...@@ -2121,40 +2426,52 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2121,40 +2426,52 @@ class DistributedMulImpl0(DistributedOperatorImpl):
ctx, ctx,
var_names, var_names,
attrs=attrs, attrs=attrs,
parallel_axis=parallel_axis) parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs( comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes, AllreduceSumOpCost,
c_allreduce_sum_desc_mapping, cluster) ctx,
processes,
c_allreduce_sum_desc_mapping,
cluster,
)
res.append(comm_op_cost_list) res.append(comm_op_cost_list)
# need gradient allreduce # need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping( var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0]) backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[ if (
batch_size_axis] > 1 and is_parameter_related( batch_size_axis > -1
backward_op.input("Y")[0], main_block): and mesh_shape[batch_size_axis] > 1
and is_parameter_related(backward_op.input("Y")[0], main_block)
):
parallel_axis = batch_size_axis parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True} attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]] var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, build_dp_costs(
cluster) res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res return res
def calc_fwd_cost(self, dist_op, ctx, cluster): def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost # calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MulOpCost, ctx, processes, cost_mapping = build_comp_costs_from_descs(
desc_mapping, cluster) MulOpCost, ctx, processes, desc_mapping, cluster
)
# calc comm op cost # calc comm op cost
serial_op = dist_op.serial_op serial_op = dist_op.serial_op
vars = serial_op.block.vars vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping( parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-1] serial_op.input("Y")[0]
)[-1]
attrs = {"use_calc_stream": True, "use_model_parallel": True} attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.input("X") var_names = serial_op.input("X")
c_identity_desc_mapping = build_comm_desc_from_dist_op( c_identity_desc_mapping = build_comm_desc_from_dist_op(
...@@ -2163,10 +2480,12 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2163,10 +2480,12 @@ class DistributedMulImpl0(DistributedOperatorImpl):
ctx, ctx,
var_names, var_names,
attrs=attrs, attrs=attrs,
parallel_axis=parallel_axis) parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs( comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
)
res_cost = [comm_op_cost_list, cost_mapping] res_cost = [comm_op_cost_list, cost_mapping]
return res_cost return res_cost
...@@ -2181,7 +2500,8 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2181,7 +2500,8 @@ class DistributedMulImpl0(DistributedOperatorImpl):
if is_dim_shard(x_dims_mapping[-1]): if is_dim_shard(x_dims_mapping[-1]):
return False return False
if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate( if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
y_dims_mapping[-1]): y_dims_mapping[-1]
):
return False return False
for mapping in x_dims_mapping[1:-1]: for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping): if is_dim_shard(mapping):
...@@ -2201,8 +2521,9 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2201,8 +2521,9 @@ class DistributedMulImpl0(DistributedOperatorImpl):
return True return True
def is_auto_compatible(self, dist_op): def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \ if (not self.is_input_compatible(dist_op)) or (
(not self.is_output_compatible(dist_op)): not self.is_output_compatible(dist_op)
):
return False return False
if not _is_auto_compatible_for_matmul(dist_op): if not _is_auto_compatible_for_matmul(dist_op):
...@@ -2229,28 +2550,33 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2229,28 +2550,33 @@ class DistributedMulImpl0(DistributedOperatorImpl):
src_op = dist_op_context.cur_src_op src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert (
str(src_op)) op_dist_attr is not None
), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes: if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh, rank_id = _get_corresponding_rank(
rank_id) ctx, op_dist_attr.process_mesh, rank_id
)
# check validation of inputs / outputs # check validation of inputs / outputs
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format( assert input_name in kwargs, "input [{}] is not given".format(
input_name) input_name
)
assert len(kwargs[input_name]) == len( assert len(kwargs[input_name]) == len(
src_op.desc.input(input_name) src_op.desc.input(input_name)
), "number of tensor for input [{}] is not match".format(input_name) ), "number of tensor for input [{}] is not match".format(input_name)
for output_name in src_op.desc.output_names(): for output_name in src_op.desc.output_names():
assert output_name in kwargs, "input [{}] is not given".format( assert output_name in kwargs, "input [{}] is not given".format(
output_name) output_name
)
assert len(kwargs[output_name]) == len( assert len(kwargs[output_name]) == len(
src_op.desc.output(output_name) src_op.desc.output(output_name)
), "number of tensor for input [{}] is not match".format( ), "number of tensor for input [{}] is not match".format(
output_name) output_name
)
X_var = main_block.var(kwargs['X'][0]) X_var = main_block.var(kwargs['X'][0])
Weight_var = main_block._var_recursive(kwargs['Y'][0]) Weight_var = main_block._var_recursive(kwargs['Y'][0])
...@@ -2258,15 +2584,20 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2258,15 +2584,20 @@ class DistributedMulImpl0(DistributedOperatorImpl):
# TODO infer logic comm presentation # TODO infer logic comm presentation
matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping( matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-1] Weight_var.name
assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( )[-1]
matmul_col_dim_mapping) assert (
matmul_col_dim_mapping >= 0
), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_col_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_col_dim_mapping parallel_axis = matmul_col_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, group_ranks = _get_comm_group(
parallel_axis, rank_id) process_mesh_group, process_mesh_shape, parallel_axis, rank_id
)
group = new_process_group(group_ranks) group = new_process_group(group_ranks)
# infer new var shape with op dist attr # infer new var shape with op dist attr
...@@ -2274,31 +2605,39 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2274,31 +2605,39 @@ class DistributedMulImpl0(DistributedOperatorImpl):
assert x_tensor_dist_attr is not None assert x_tensor_dist_attr is not None
identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name) identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
assert identity_var_dist_attr is not None assert identity_var_dist_attr is not None
ref_shape_x = infer_shape(main_block, X_var, x_tensor_dist_attr, ref_shape_x = infer_shape(
identity_var_dist_attr) main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
)
# infer out var shape with op dist attr # infer out var shape with op dist attr
out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var) out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
assert out_tensor_dist_attr is not None assert out_tensor_dist_attr is not None
out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert out_var_dist_attr is not None assert out_var_dist_attr is not None
ref_shape_out = infer_shape(main_block, Out_var, out_tensor_dist_attr, ref_shape_out = infer_shape(
out_var_dist_attr) main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
)
intermediate_var_0 = main_block.create_var( intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join( name=unique_name.generate_with_ignorable_key(
["c_identity", 'tmp'])), ".".join(["c_identity", 'tmp'])
),
dtype=X_var.dtype, dtype=X_var.dtype,
shape=X_var.shape, shape=X_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False, persistable=False,
stop_gradient=X_var.stop_gradient) stop_gradient=X_var.stop_gradient,
)
# set intermediate_var_0's dist_attr with X_var's dist_attr # set intermediate_var_0's dist_attr with X_var's dist_attr
ctx.set_tensor_dist_attr_for_program(intermediate_var_0, ctx.set_tensor_dist_attr_for_program(
identity_var_dist_attr) intermediate_var_0, identity_var_dist_attr
)
check_variable_and_dtype( check_variable_and_dtype(
X_var, 'tensor', X_var,
['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity') 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)
c_identity_op = main_block.append_op( c_identity_op = main_block.append_op(
type='c_identity', type='c_identity',
inputs={'X': [X_var]}, inputs={'X': [X_var]},
...@@ -2307,20 +2646,29 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2307,20 +2646,29 @@ class DistributedMulImpl0(DistributedOperatorImpl):
'ring_id': group.id, 'ring_id': group.id,
'use_calc_stream': True, 'use_calc_stream': True,
'use_model_parallel': True, 'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role') OP_ROLE_KEY: src_op.attr('op_role'),
}) },
)
if intermediate_var_0.shape != ref_shape_x: if intermediate_var_0.shape != ref_shape_x:
intermediate_var_0.desc.set_shape(ref_shape_x) intermediate_var_0.desc.set_shape(ref_shape_x)
check_variable_and_dtype(intermediate_var_0, 'x', check_variable_and_dtype(
['float16', 'float32', 'float64'], 'linear') intermediate_var_0,
check_dtype(intermediate_var_0.dtype, 'dtype', 'x',
['float16', 'float32', 'float64'], 'linear') ['float16', 'float32', 'float64', 'uint16'],
'linear',
)
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
# attrs = {'trans_x': False, 'trans_y': False} # attrs = {'trans_x': False, 'trans_y': False}
attrs = { attrs = {
"x_num_col_dims": src_op.desc.attr("x_num_col_dims"), "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
"y_num_col_dims": src_op.desc.attr("y_num_col_dims"), "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
OP_ROLE_KEY: src_op.attr('op_role') OP_ROLE_KEY: src_op.attr('op_role'),
} }
inputs = {'X': intermediate_var_0, 'Y': Weight_var} inputs = {'X': intermediate_var_0, 'Y': Weight_var}
...@@ -2334,16 +2682,15 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2334,16 +2682,15 @@ class DistributedMulImpl0(DistributedOperatorImpl):
inputs_original_shape[var_name] = var.shape inputs_original_shape[var_name] = var.shape
input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var) input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name) input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
input_ref_shape = infer_shape(main_block, var, input_ref_shape = infer_shape(
input_tensor_dist_attr, main_block, var, input_tensor_dist_attr, input_var_dist_attr
input_var_dist_attr) )
inputs_ref_shape[var_name] = input_ref_shape inputs_ref_shape[var_name] = input_ref_shape
var.desc.set_shape(input_ref_shape) var.desc.set_shape(input_ref_shape)
mul_op = main_block.append_op(type='mul', mul_op = main_block.append_op(
inputs=inputs, type='mul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs
outputs={'Out': Out_var}, )
attrs=attrs)
if Out_var.shape != ref_shape_out: if Out_var.shape != ref_shape_out:
Out_var.desc.set_shape(ref_shape_out) Out_var.desc.set_shape(ref_shape_out)
...@@ -2362,13 +2709,16 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2362,13 +2709,16 @@ class DistributedMulImpl0(DistributedOperatorImpl):
input_varname = c_identity_op.desc.input_arg_names()[0] input_varname = c_identity_op.desc.input_arg_names()[0]
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, "dist_attr is {}".format( assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
identity_op_dist_attr.set_input_dist_attr(input_varname, )
input_dist_attr) identity_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
# output # output
output_varname = c_identity_op.desc.output_arg_names()[0] output_varname = c_identity_op.desc.output_arg_names()[0]
identity_op_dist_attr.set_output_dist_attr(output_varname, identity_op_dist_attr.set_output_dist_attr(
input_dist_attr) output_varname, input_dist_attr
)
ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr) ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)
# matmulv2 # matmulv2
...@@ -2379,29 +2729,37 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2379,29 +2729,37 @@ class DistributedMulImpl0(DistributedOperatorImpl):
for input_varname in mul_op.desc.input_arg_names(): for input_varname in mul_op.desc.input_arg_names():
if input_varname in src_op.desc.input_arg_names(): if input_varname in src_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr( input_dist_attr = op_dist_attr.get_input_dist_attr(
input_varname) input_varname
)
assert input_dist_attr is not None, "dist_attr is {}".format( assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
)
matmulv2_op_dist_attr.set_input_dist_attr( matmulv2_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr) input_varname, input_dist_attr
)
else: else:
input_var = main_block.var(input_varname) input_var = main_block.var(input_varname)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program( tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
input_var) input_var
)
matmulv2_op_dist_attr.set_input_dist_attr( matmulv2_op_dist_attr.set_input_dist_attr(
input_varname, tensor_dist_attr) input_varname, tensor_dist_attr
)
for output_varname in mul_op.desc.output_arg_names(): for output_varname in mul_op.desc.output_arg_names():
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, "dist_attr is {}".format( assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
matmulv2_op_dist_attr.set_output_dist_attr(output_varname, )
output_dist_attr) matmulv2_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr) ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)
# init param sync # init param sync
if Weight_var.is_parameter and not op_dist_attr.is_recompute: if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx, _init_param_sync(
rank_id) Weight_var, dist_op_context, startup_block, ctx, rank_id
)
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
...@@ -2410,7 +2768,6 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2410,7 +2768,6 @@ class DistributedMulImpl0(DistributedOperatorImpl):
# RowParallel # RowParallel
class DistributedMulImpl1(DistributedOperatorImpl): class DistributedMulImpl1(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedMulImpl1, self).__init__(name) super(DistributedMulImpl1, self).__init__(name)
self._forward_implemented = True self._forward_implemented = True
...@@ -2434,7 +2791,8 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -2434,7 +2791,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
main_block = backward_op.block main_block = backward_op.block
vars = main_block.vars vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping( Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0]) backward_op.input("Y")[0]
)
assert Y_var_dim_mapping[1] < 0 assert Y_var_dim_mapping[1] < 0
parallel_axis = Y_var_dim_mapping[0] parallel_axis = Y_var_dim_mapping[0]
...@@ -2447,49 +2805,59 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -2447,49 +2805,59 @@ class DistributedMulImpl1(DistributedOperatorImpl):
ctx, ctx,
var_names, var_names,
attrs=attrs, attrs=attrs,
parallel_axis=parallel_axis) parallel_axis=parallel_axis,
)
processes = process_mesh.processes processes = process_mesh.processes
comm_op_cost_list = build_comm_costs_from_descs( comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
)
res.append(comm_op_cost_list) res.append(comm_op_cost_list)
# calc comp op cost # calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
cost_mapping = build_comp_costs_from_descs(MulGradOpCost, ctx, )
processes, desc_mapping, cost_mapping = build_comp_costs_from_descs(
cluster) MulGradOpCost, ctx, processes, desc_mapping, cluster
)
res.append(cost_mapping) res.append(cost_mapping)
# need gradient allreduce # need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping( var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0]) backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[ if (
batch_size_axis] > 1 and is_parameter_related( batch_size_axis > -1
backward_op.input("Y")[0], main_block): and mesh_shape[batch_size_axis] > 1
and is_parameter_related(backward_op.input("Y")[0], main_block)
):
parallel_axis = batch_size_axis parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True} attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]] var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, build_dp_costs(
cluster) res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res return res
def calc_fwd_cost(self, dist_op, ctx, cluster): def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost # calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MulOpCost, ctx, processes, cost_mapping = build_comp_costs_from_descs(
desc_mapping, cluster) MulOpCost, ctx, processes, desc_mapping, cluster
)
# calc comm op cost # calc comm op cost
serial_op = dist_op.serial_op serial_op = dist_op.serial_op
vars = serial_op.block.vars vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping( parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-2] serial_op.input("Y")[0]
)[-2]
attrs = {"use_calc_stream": True, "use_model_parallel": True} attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.output("Out") var_names = serial_op.output("Out")
...@@ -2499,12 +2867,17 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -2499,12 +2867,17 @@ class DistributedMulImpl1(DistributedOperatorImpl):
ctx, ctx,
var_names, var_names,
attrs=attrs, attrs=attrs,
parallel_axis=parallel_axis) parallel_axis=parallel_axis,
)
# print("dist_matmul.py dist_op: ", dist_op) # print("dist_matmul.py dist_op: ", dist_op)
comm_op_cost_list = build_comm_costs_from_descs( comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping, AllreduceSumOpCost,
cluster) ctx,
processes,
c_allreduce_sum_desc_mapping,
cluster,
)
res_cost = [cost_mapping, comm_op_cost_list] res_cost = [cost_mapping, comm_op_cost_list]
...@@ -2520,7 +2893,8 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -2520,7 +2893,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
if is_dim_replicate(x_dims_mapping[-1]): if is_dim_replicate(x_dims_mapping[-1]):
return False return False
if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard( if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
y_dims_mapping[-1]): y_dims_mapping[-1]
):
return False return False
# Other dimensions must be replicate except the batch dimension # Other dimensions must be replicate except the batch dimension
for mapping in x_dims_mapping[1:-1]: for mapping in x_dims_mapping[1:-1]:
...@@ -2542,8 +2916,9 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -2542,8 +2916,9 @@ class DistributedMulImpl1(DistributedOperatorImpl):
return True return True
def is_auto_compatible(self, dist_op): def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \ if (not self.is_input_compatible(dist_op)) or (
(not self.is_output_compatible(dist_op)): not self.is_output_compatible(dist_op)
):
return False return False
if not _is_auto_compatible_for_matmul(dist_op): if not _is_auto_compatible_for_matmul(dist_op):
...@@ -2570,28 +2945,33 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -2570,28 +2945,33 @@ class DistributedMulImpl1(DistributedOperatorImpl):
src_op = dist_op_context.cur_src_op src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert (
str(src_op)) op_dist_attr is not None
), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes: if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh, rank_id = _get_corresponding_rank(
rank_id) ctx, op_dist_attr.process_mesh, rank_id
)
# check validation of inputs / outputs # check validation of inputs / outputs
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format( assert input_name in kwargs, "input [{}] is not given".format(
input_name) input_name
)
assert len(kwargs[input_name]) == len( assert len(kwargs[input_name]) == len(
src_op.desc.input(input_name) src_op.desc.input(input_name)
), "number of tensor for input [{}] is not match".format(input_name) ), "number of tensor for input [{}] is not match".format(input_name)
for output_name in src_op.desc.output_names(): for output_name in src_op.desc.output_names():
assert output_name in kwargs, "input [{}] is not given".format( assert output_name in kwargs, "input [{}] is not given".format(
output_name) output_name
)
assert len(kwargs[output_name]) == len( assert len(kwargs[output_name]) == len(
src_op.desc.output(output_name) src_op.desc.output(output_name)
), "number of tensor for input [{}] is not match".format( ), "number of tensor for input [{}] is not match".format(
output_name) output_name
)
X_var = main_block.var(kwargs['X'][0]) X_var = main_block.var(kwargs['X'][0])
Weight_var = main_block._var_recursive(kwargs['Y'][0]) Weight_var = main_block._var_recursive(kwargs['Y'][0])
...@@ -2599,26 +2979,36 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -2599,26 +2979,36 @@ class DistributedMulImpl1(DistributedOperatorImpl):
# TODO infer logic comm presentation # TODO infer logic comm presentation
matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping( matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-2] Weight_var.name
assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( )[-2]
matmul_row_dim_mapping) assert (
matmul_row_dim_mapping >= 0
), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_row_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_row_dim_mapping parallel_axis = matmul_row_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, group_ranks = _get_comm_group(
parallel_axis, rank_id) process_mesh_group, process_mesh_shape, parallel_axis, rank_id
)
group = new_process_group(group_ranks) group = new_process_group(group_ranks)
check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'], check_variable_and_dtype(
'linear') X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], )
'linear') check_dtype(
X_var.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
# attrs = {'trans_x': False, 'trans_y': False} # attrs = {'trans_x': False, 'trans_y': False}
attrs = { attrs = {
"x_num_col_dims": src_op.desc.attr("x_num_col_dims"), "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
"y_num_col_dims": src_op.desc.attr("y_num_col_dims"), "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
OP_ROLE_KEY: src_op.attr('op_role') OP_ROLE_KEY: src_op.attr('op_role'),
} }
inputs = {'X': X_var, 'Y': Weight_var} inputs = {'X': X_var, 'Y': Weight_var}
...@@ -2627,22 +3017,26 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -2627,22 +3017,26 @@ class DistributedMulImpl1(DistributedOperatorImpl):
assert out_tensor_dist_attr is not None assert out_tensor_dist_attr is not None
out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert out_var_dist_attr is not None assert out_var_dist_attr is not None
ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr, ref_shape = infer_shape(
out_var_dist_attr) main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
)
intermediate_var_0 = main_block.create_var( intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join( name=unique_name.generate_with_ignorable_key(
["c_allreduce_sum", 'tmp'])), ".".join(["c_allreduce_sum", 'tmp'])
),
shape=Out_var.shape, shape=Out_var.shape,
dtype=Out_var.dtype, dtype=Out_var.dtype,
type=Out_var.type, type=Out_var.type,
lod_level=Out_var.lod_level, lod_level=Out_var.lod_level,
persistable=False, persistable=False,
is_data=False, is_data=False,
need_check_feed=Out_var.desc.need_check_feed()) need_check_feed=Out_var.desc.need_check_feed(),
)
# set intermediate_var_0's dist_attr with Out_var's dist_attr # set intermediate_var_0's dist_attr with Out_var's dist_attr
ctx.set_tensor_dist_attr_for_program(intermediate_var_0, ctx.set_tensor_dist_attr_for_program(
out_var_dist_attr) intermediate_var_0, out_var_dist_attr
)
inputs_ref_shape = {} inputs_ref_shape = {}
inputs_original_shape = {} inputs_original_shape = {}
...@@ -2651,16 +3045,18 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -2651,16 +3045,18 @@ class DistributedMulImpl1(DistributedOperatorImpl):
inputs_original_shape[var_name] = var.shape inputs_original_shape[var_name] = var.shape
input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var) input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name) input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
input_ref_shape = infer_shape(main_block, var, input_ref_shape = infer_shape(
input_tensor_dist_attr, main_block, var, input_tensor_dist_attr, input_var_dist_attr
input_var_dist_attr) )
inputs_ref_shape[var_name] = input_ref_shape inputs_ref_shape[var_name] = input_ref_shape
var.desc.set_shape(input_ref_shape) var.desc.set_shape(input_ref_shape)
mul_op = main_block.append_op(type='mul', mul_op = main_block.append_op(
type='mul',
inputs=inputs, inputs=inputs,
outputs={'Out': intermediate_var_0}, outputs={'Out': intermediate_var_0},
attrs=attrs) attrs=attrs,
)
if intermediate_var_0.shape != ref_shape: if intermediate_var_0.shape != ref_shape:
intermediate_var_0.desc.set_shape(ref_shape) intermediate_var_0.desc.set_shape(ref_shape)
...@@ -2678,8 +3074,9 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -2678,8 +3074,9 @@ class DistributedMulImpl1(DistributedOperatorImpl):
'ring_id': group.id, 'ring_id': group.id,
'use_calc_stream': True, 'use_calc_stream': True,
'use_model_parallel': True, 'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role') OP_ROLE_KEY: src_op.attr('op_role'),
}) },
)
if Out_var.shape != ref_shape: if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape) Out_var.desc.set_shape(ref_shape)
...@@ -2693,15 +3090,19 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -2693,15 +3090,19 @@ class DistributedMulImpl1(DistributedOperatorImpl):
for input_varname in mul_op.desc.input_arg_names(): for input_varname in mul_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, "dist_attr is {}".format( assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
matmulv2_op_dist_attr.set_input_dist_attr(input_varname, )
input_dist_attr) matmulv2_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
output_varname = mul_op.desc.output_arg_names()[0] output_varname = mul_op.desc.output_arg_names()[0]
output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert output_dist_attr is not None, "dist_attr is {}".format( assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
matmulv2_op_dist_attr.set_output_dist_attr(output_varname, )
output_dist_attr) matmulv2_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr) ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)
# allreduce # allreduce
...@@ -2713,21 +3114,26 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -2713,21 +3114,26 @@ class DistributedMulImpl1(DistributedOperatorImpl):
input_var = main_block.var(input_varname) input_var = main_block.var(input_varname)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var) tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
assert tensor_dist_attr is not None assert tensor_dist_attr is not None
allreduce_op_dist_attr.set_input_dist_attr(input_varname, allreduce_op_dist_attr.set_input_dist_attr(
tensor_dist_attr) input_varname, tensor_dist_attr
)
for output_varname in c_allreduce_sum_op.desc.output_arg_names(): for output_varname in c_allreduce_sum_op.desc.output_arg_names():
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, "dist_attr is {}".format( assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
allreduce_op_dist_attr.set_output_dist_attr(output_varname, )
output_dist_attr) allreduce_op_dist_attr.set_output_dist_attr(
ctx.set_op_dist_attr_for_program(c_allreduce_sum_op, output_varname, output_dist_attr
allreduce_op_dist_attr) )
ctx.set_op_dist_attr_for_program(
c_allreduce_sum_op, allreduce_op_dist_attr
)
# init param sync # init param sync
if Weight_var.is_parameter and not op_dist_attr.is_recompute: if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx, _init_param_sync(
rank_id) Weight_var, dist_op_context, startup_block, ctx, rank_id
)
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
...@@ -2736,7 +3142,6 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -2736,7 +3142,6 @@ class DistributedMulImpl1(DistributedOperatorImpl):
# ReplicateParallel # ReplicateParallel
class DistributedMulImpl2(DistributedOperatorImpl): class DistributedMulImpl2(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedMulImpl2, self).__init__(name) super(DistributedMulImpl2, self).__init__(name)
...@@ -2757,38 +3162,45 @@ class DistributedMulImpl2(DistributedOperatorImpl): ...@@ -2757,38 +3162,45 @@ class DistributedMulImpl2(DistributedOperatorImpl):
vars = main_block.vars vars = main_block.vars
# calc comp op cost # calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
)
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
processes = process_mesh.processes processes = process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MulGradOpCost, ctx, cost_mapping = build_comp_costs_from_descs(
processes, desc_mapping, MulGradOpCost, ctx, processes, desc_mapping, cluster
cluster) )
res.append(cost_mapping) res.append(cost_mapping)
# need gradient allreduce # need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping( var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0]) backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[ if (
batch_size_axis] > 1 and is_parameter_related( batch_size_axis > -1
backward_op.input("Y")[0], main_block): and mesh_shape[batch_size_axis] > 1
and is_parameter_related(backward_op.input("Y")[0], main_block)
):
parallel_axis = batch_size_axis parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True} attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]] var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, build_dp_costs(
cluster) res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res return res
def calc_fwd_cost(self, dist_op, ctx, cluster): def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost # calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MulOpCost, ctx, processes, cost_mapping = build_comp_costs_from_descs(
desc_mapping, cluster) MulOpCost, ctx, processes, desc_mapping, cluster
)
res_cost = [cost_mapping] res_cost = [cost_mapping]
return res_cost return res_cost
...@@ -2804,12 +3216,14 @@ class DistributedMulImpl2(DistributedOperatorImpl): ...@@ -2804,12 +3216,14 @@ class DistributedMulImpl2(DistributedOperatorImpl):
if is_dim_shard(x_dims_mapping[-1]): if is_dim_shard(x_dims_mapping[-1]):
return False return False
if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard( if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
x_dims_mapping[-2]): x_dims_mapping[-2]
):
return False return False
if is_dim_shard(y_dims_mapping[-1]): if is_dim_shard(y_dims_mapping[-1]):
return False return False
if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard( if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
y_dims_mapping[-2]): y_dims_mapping[-2]
):
return False return False
return True return True
...@@ -2824,14 +3238,16 @@ class DistributedMulImpl2(DistributedOperatorImpl): ...@@ -2824,14 +3238,16 @@ class DistributedMulImpl2(DistributedOperatorImpl):
if is_dim_shard(out_dims_mapping[-1]): if is_dim_shard(out_dims_mapping[-1]):
return False return False
if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard( if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
out_dims_mapping[-2]): out_dims_mapping[-2]
):
return False return False
return True return True
def is_auto_compatible(self, dist_op): def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \ if (not self.is_input_compatible(dist_op)) or (
(not self.is_output_compatible(dist_op)): not self.is_output_compatible(dist_op)
):
return False return False
if not _is_auto_compatible_for_matmul(dist_op): if not _is_auto_compatible_for_matmul(dist_op):
...@@ -2855,8 +3271,10 @@ class DistributedMulImpl2(DistributedOperatorImpl): ...@@ -2855,8 +3271,10 @@ class DistributedMulImpl2(DistributedOperatorImpl):
_right_operand_parameter_matmul_backward(ctx, *args, **kwargs) _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
register_distributed_operator_impl("mul", register_distributed_operator_impl(
DistributedMulImpl0("column_parallel")) "mul", DistributedMulImpl0("column_parallel")
)
register_distributed_operator_impl("mul", DistributedMulImpl1("row_parallel")) register_distributed_operator_impl("mul", DistributedMulImpl1("row_parallel"))
register_distributed_operator_impl("mul", register_distributed_operator_impl(
DistributedMulImpl2("replicate_parallel")) "mul", DistributedMulImpl2("replicate_parallel")
)
...@@ -254,17 +254,26 @@ class Parallelizer: ...@@ -254,17 +254,26 @@ class Parallelizer:
self._dist_context.serial_feed_vars["inputs"] self._dist_context.serial_feed_vars["inputs"]
+ self._dist_context.serial_feed_vars["labels"] + self._dist_context.serial_feed_vars["labels"]
) )
if config["use_pure_fp16"]: self._logger.info(
"Applying AMP-{}-{} ...".format(
config["dtype"], config['level']
),
)
if config['level'] == "o1":
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply(
[main_program], [startup_program], self._pass_context
)
loss = auto_parallel_amp_pass.get_loss()
elif config['level'] in ['o2', 'o3']:
config["base_opt"] = optimizer config["base_opt"] = optimizer
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config) auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply( auto_parallel_fp16_pass.apply(
[main_program], [startup_program], self._pass_context [main_program], [startup_program], self._pass_context
) )
loss = auto_parallel_fp16_pass.get_loss()
else: else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) raise ValueError("AMP level should be one of o1, o2, o3")
auto_parallel_amp_pass.apply(
[main_program], [startup_program], self._pass_context
)
# apply recompute pass # apply recompute pass
# recompute is then train-only optimization # recompute is then train-only optimization
......
...@@ -18,25 +18,48 @@ from paddle.fluid import unique_name ...@@ -18,25 +18,48 @@ from paddle.fluid import unique_name
from .pass_base import PassBase, register_pass from .pass_base import PassBase, register_pass
from paddle.distributed.fleet.meta_optimizers.common import OpRole from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
from paddle.distributed.auto_parallel.utils import get_loss_op, set_var_dist_attr from paddle.distributed.auto_parallel.utils import (
from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping get_loss_op,
from paddle.distributed.auto_parallel.process_group import get_world_process_group set_var_dist_attr,
from paddle.fluid.contrib.mixed_precision.fp16_utils import AutoMixedPrecisionLists )
from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_fp32_input, _keep_fp32_output, find_op_index from paddle.distributed.auto_parallel.utils import (
from paddle.fluid.contrib.mixed_precision.fp16_utils import _valid_types, find_true_post_op, find_true_prev_op naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
from paddle.fluid.contrib.mixed_precision.fp16_utils import _is_in_black_varnames, _dtype_to_str, _rename_arg )
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute from paddle.distributed.auto_parallel.process_group import (
get_world_process_group,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
AutoMixedPrecisionLists,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
_keep_fp32_input,
_keep_fp32_output,
find_op_index,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
_valid_types,
find_true_post_op,
find_true_prev_op,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
_is_in_black_varnames,
_dtype_to_str,
_rename_arg,
)
from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistributedAttribute,
)
from ..auto_parallel.utils import is_forward_op, is_backward_op, is_loss_op from ..auto_parallel.utils import is_forward_op, is_backward_op, is_loss_op
world_process_group = get_world_process_group() world_process_group = get_world_process_group()
class AMPState(object): class AMPState(object):
def __init__(self, block): def __init__(self, block):
self._block = block self._block = block
self._op_fp16_dict = { self._op_fp16_dict = (
} # op_id --> True/False. 'True' means that the current op is in fp16 mode. {}
) # op_id --> True/False. 'True' means that the current op is in fp16 mode.
self._var_name_dict = {} # fwd_op_id --> {old_name: cast_name} self._var_name_dict = {} # fwd_op_id --> {old_name: cast_name}
self.is_train = False self.is_train = False
...@@ -55,7 +78,8 @@ class AMPState(object): ...@@ -55,7 +78,8 @@ class AMPState(object):
elif int(op.attr('op_role')) == int(OpRole.Backward): elif int(op.attr('op_role')) == int(OpRole.Backward):
if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id: if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
fwd_op_id = dist_op_context.grad_op_id_to_op_id[ fwd_op_id = dist_op_context.grad_op_id_to_op_id[
op.desc.original_id()] op.desc.original_id()
]
if self._is_fp16_op(fwd_op_id) == True: if self._is_fp16_op(fwd_op_id) == True:
self._op_fp16_dict[op.desc.original_id()] = True self._op_fp16_dict[op.desc.original_id()] = True
elif self._is_fp16_op(fwd_op_id) == False: elif self._is_fp16_op(fwd_op_id) == False:
...@@ -78,7 +102,8 @@ class AMPState(object): ...@@ -78,7 +102,8 @@ class AMPState(object):
if op.type == 'create_py_reader' or op.type == 'read': if op.type == 'create_py_reader' or op.type == 'read':
continue continue
if amp_lists.black_varnames is not None and _is_in_black_varnames( if amp_lists.black_varnames is not None and _is_in_black_varnames(
op, amp_lists): op, amp_lists
):
self._op_fp16_dict[op.desc.original_id()] = False self._op_fp16_dict[op.desc.original_id()] = False
continue continue
if op.type in amp_lists.black_list: if op.type in amp_lists.black_list:
...@@ -98,17 +123,24 @@ class AMPState(object): ...@@ -98,17 +123,24 @@ class AMPState(object):
continue continue
elif in_var.op is op: elif in_var.op is op:
prev_op = find_true_prev_op( prev_op = find_true_prev_op(
ops, op, in_var_name) ops, op, in_var_name
)
if prev_op is None: if prev_op is None:
continue continue
else: else:
prev_op = in_var.op prev_op = in_var.op
# if it's one of inputs # if it's one of inputs
if self._is_fp16_op(prev_op.desc.original_id()) == False or \ if (
prev_op.type in amp_lists.black_list: self._is_fp16_op(prev_op.desc.original_id())
== False
or prev_op.type in amp_lists.black_list
):
is_black_op = True is_black_op = True
elif self._is_fp16_op(prev_op.desc.original_id()) == True or \ elif (
prev_op.type in amp_lists.white_list: self._is_fp16_op(prev_op.desc.original_id())
== True
or prev_op.type in amp_lists.white_list
):
is_white_op = True is_white_op = True
if is_black_op: if is_black_op:
self._op_fp16_dict[op.desc.original_id()] = False self._op_fp16_dict[op.desc.original_id()] = False
...@@ -131,19 +163,28 @@ class AMPState(object): ...@@ -131,19 +163,28 @@ class AMPState(object):
break break
if self._is_fp16_op(op.desc.original_id()) == False: if self._is_fp16_op(op.desc.original_id()) == False:
num_cast_ops = self._insert_cast_op_forward( num_cast_ops = self._insert_cast_op_forward(
op, idx, core.VarDesc.VarType.FP16, op,
core.VarDesc.VarType.FP32, dist_context) idx,
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32,
dist_context,
)
elif self._is_fp16_op(op.desc.original_id()) == True: elif self._is_fp16_op(op.desc.original_id()) == True:
num_cast_ops = self._insert_cast_op_forward( num_cast_ops = self._insert_cast_op_forward(
op, idx, core.VarDesc.VarType.FP32, op,
core.VarDesc.VarType.FP16, dist_context) idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
dist_context,
)
else: else:
pass pass
idx += num_cast_ops + 1 idx += num_cast_ops + 1
self._block._sync_with_cpp() self._block._sync_with_cpp()
def _insert_cast_op_forward(self, op, idx, src_dtype, dst_dtype, def _insert_cast_op_forward(
dist_context): self, op, idx, src_dtype, dst_dtype, dist_context
):
""" """
only for forward cast only for forward cast
modified from paddle.fluid.contrib.mixed_precision modified from paddle.fluid.contrib.mixed_precision
...@@ -152,38 +193,45 @@ class AMPState(object): ...@@ -152,38 +193,45 @@ class AMPState(object):
var_name_dict = {} var_name_dict = {}
for in_name in op.input_names: for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input( if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
op, in_name): op, in_name
):
continue continue
for in_var_name in op.input(in_name): for in_var_name in op.input(in_name):
in_var = self._block._find_var_recursive(in_var_name) in_var = self._block._find_var_recursive(in_var_name)
if in_var.type not in _valid_types or in_var.dtype == dst_dtype: if in_var.type not in _valid_types or in_var.dtype == dst_dtype:
continue continue
if in_var.dtype == src_dtype: if in_var.dtype == src_dtype:
cast_name = in_var.name + '.cast_' + _dtype_to_str( cast_name = (
dst_dtype) in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
)
out_var = self._block.vars.get(cast_name) out_var = self._block.vars.get(cast_name)
var_name_dict[in_var.name] = cast_name var_name_dict[in_var.name] = cast_name
consume_op_attr = dist_context.get_op_dist_attr_for_program( consume_op_attr = dist_context.get_op_dist_attr_for_program(
op) op
)
assert consume_op_attr is not None assert consume_op_attr is not None
if out_var is None or out_var.dtype != dst_dtype: if out_var is None or out_var.dtype != dst_dtype:
# NOTE we make the cast op and var's dist attr as the op that consume the # NOTE we make the cast op and var's dist attr as the op that consume the
# cast var instead of the op which generates the var # cast var instead of the op which generates the var
in_var_dist_attr = consume_op_attr.get_input_dist_attr( in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var.name) in_var.name
)
assert in_var_dist_attr is not None assert in_var_dist_attr is not None
ref_mesh = in_var_dist_attr.process_mesh ref_mesh = in_var_dist_attr.process_mesh
ref_mapping = in_var_dist_attr.dims_mapping ref_mapping = in_var_dist_attr.dims_mapping
consume_op_attr.set_input_dist_attr( consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr) cast_name, in_var_dist_attr
)
out_var = self._block.create_var( out_var = self._block.create_var(
name=cast_name, name=cast_name,
dtype=dst_dtype, dtype=dst_dtype,
persistable=False, persistable=False,
stop_gradient=in_var.stop_gradient) stop_gradient=in_var.stop_gradient,
set_var_dist_attr(dist_context, out_var, ref_mapping, )
ref_mesh) set_var_dist_attr(
dist_context, out_var, ref_mapping, ref_mesh
)
cast_op = self._block._insert_op_without_sync( cast_op = self._block._insert_op_without_sync(
idx, idx,
...@@ -193,22 +241,29 @@ class AMPState(object): ...@@ -193,22 +241,29 @@ class AMPState(object):
attrs={ attrs={
"in_dtype": in_var.dtype, "in_dtype": in_var.dtype,
"out_dtype": out_var.dtype, "out_dtype": out_var.dtype,
}) },
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context) cast_op, ref_mesh, ref_mapping, dist_context
)
num_cast_ops += 1 num_cast_ops += 1
else: else:
in_var_dist_attr = consume_op_attr.get_input_dist_attr( in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var.name) in_var.name
)
consume_op_attr.set_input_dist_attr( consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr) cast_name, in_var_dist_attr
)
_rename_arg(op, in_var.name, cast_name) _rename_arg(op, in_var.name, cast_name)
else: else:
if op.has_attr('in_dtype'): if op.has_attr('in_dtype'):
op._set_attr('in_dtype', dst_dtype) op._set_attr('in_dtype', dst_dtype)
self._var_name_dict[op.desc.original_id()] = var_name_dict self._var_name_dict[op.desc.original_id()] = var_name_dict
if src_dtype == core.VarDesc.VarType.FP32 and dst_dtype == core.VarDesc.VarType.FP16: if (
src_dtype == core.VarDesc.VarType.FP32
and dst_dtype == core.VarDesc.VarType.FP16
):
for out_name in op.output_names: for out_name in op.output_names:
if _keep_fp32_output(op, out_name): if _keep_fp32_output(op, out_name):
continue continue
...@@ -238,8 +293,9 @@ class AMPState(object): ...@@ -238,8 +293,9 @@ class AMPState(object):
# NOTE: the map in `grad_var_to_var` may be changed when the var is casted, # NOTE: the map in `grad_var_to_var` may be changed when the var is casted,
# which will affect the dist_op to insert allreduce_sum op. # which will affect the dist_op to insert allreduce_sum op.
op_dist_attr = dist_context.get_op_dist_attr_for_program(grad_op) op_dist_attr = dist_context.get_op_dist_attr_for_program(grad_op)
if is_backward_op(grad_op) and (is_forward_op(ops[idx - 1]) if is_backward_op(grad_op) and (
or is_loss_op(ops[idx - 1])): is_forward_op(ops[idx - 1]) or is_loss_op(ops[idx - 1])
):
if not op_dist_attr.is_recompute: if not op_dist_attr.is_recompute:
appended_grad_times += 1 appended_grad_times += 1
...@@ -248,14 +304,22 @@ class AMPState(object): ...@@ -248,14 +304,22 @@ class AMPState(object):
if grad_op_orig_id in dist_op_context.grad_op_id_to_op_id: if grad_op_orig_id in dist_op_context.grad_op_id_to_op_id:
if self._is_fp16_op(grad_op_orig_id) == False: # fp32 if self._is_fp16_op(grad_op_orig_id) == False: # fp32
num_cast_ops = self._insert_cast_op_backward( num_cast_ops = self._insert_cast_op_backward(
grad_op, idx, core.VarDesc.VarType.FP16, grad_op,
core.VarDesc.VarType.FP32, dist_context, idx,
appended_grad_times) core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32,
dist_context,
appended_grad_times,
)
elif self._is_fp16_op(grad_op_orig_id) == True: # fp16 elif self._is_fp16_op(grad_op_orig_id) == True: # fp16
num_cast_ops = self._insert_cast_op_backward( num_cast_ops = self._insert_cast_op_backward(
grad_op, idx, core.VarDesc.VarType.FP32, grad_op,
core.VarDesc.VarType.FP16, dist_context, idx,
appended_grad_times) core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
dist_context,
appended_grad_times,
)
elif grad_op.type == "sum": elif grad_op.type == "sum":
in_var_name = grad_op.desc.input_arg_names()[0] in_var_name = grad_op.desc.input_arg_names()[0]
src_dtype = self._block.var(in_var_name).dtype src_dtype = self._block.var(in_var_name).dtype
...@@ -270,15 +334,24 @@ class AMPState(object): ...@@ -270,15 +334,24 @@ class AMPState(object):
else: else:
raise ValueError( raise ValueError(
"'{}' op is not supported in the complete amp pass.".format( "'{}' op is not supported in the complete amp pass.".format(
grad_op.type)) grad_op.type
)
)
idx += num_cast_ops + 1 idx += num_cast_ops + 1
self._block._sync_with_cpp() self._block._sync_with_cpp()
_update_backward_cast_ops(params_grads, dist_context) _update_backward_cast_ops(params_grads, dist_context)
def _insert_cast_op_backward(self, grad_op, idx, src_dtype, dst_dtype, def _insert_cast_op_backward(
dist_context, appended_grad_times): self,
""" only for backward cast """ grad_op,
idx,
src_dtype,
dst_dtype,
dist_context,
appended_grad_times,
):
"""only for backward cast"""
def _keep_fp32_input(op, in_name): def _keep_fp32_input(op, in_name):
op_type = op.type op_type = op.type
...@@ -299,7 +372,8 @@ class AMPState(object): ...@@ -299,7 +372,8 @@ class AMPState(object):
for in_name in grad_op.input_names: for in_name in grad_op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input( if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
grad_op, in_name): grad_op, in_name
):
for in_var_name in grad_op.input(in_name): for in_var_name in grad_op.input(in_name):
in_var = self._block._find_var_recursive(in_var_name) in_var = self._block._find_var_recursive(in_var_name)
assert in_var.dtype == core.VarDesc.VarType.FP32 assert in_var.dtype == core.VarDesc.VarType.FP32
...@@ -309,24 +383,34 @@ class AMPState(object): ...@@ -309,24 +383,34 @@ class AMPState(object):
in_var = self._block._find_var_recursive(in_var_name) in_var = self._block._find_var_recursive(in_var_name)
if in_var.dtype == src_dtype: if in_var.dtype == src_dtype:
consume_op_attr = dist_context.get_op_dist_attr_for_program( consume_op_attr = dist_context.get_op_dist_attr_for_program(
grad_op) grad_op
)
if in_var_name in self._var_name_dict[fwd_op_id]: if in_var_name in self._var_name_dict[fwd_op_id]:
# NOTE: if in_var of consume grad_op has been casted before, # NOTE: if in_var of consume grad_op has been casted before,
# it should be renamed and reset dist_attr. # it should be renamed and reset dist_attr.
cast_name = self._var_name_dict[fwd_op_id][in_var_name] cast_name = self._var_name_dict[fwd_op_id][in_var_name]
grad_op.desc._rename_input(in_var_name, cast_name) grad_op.desc._rename_input(in_var_name, cast_name)
in_var_dist_attr = consume_op_attr.get_input_dist_attr( in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var_name) in_var_name
)
consume_op_attr.set_input_dist_attr( consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr) cast_name, in_var_dist_attr
)
else: else:
assert in_var.dtype == dst_dtype, "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format( assert (
grad_op.type, in_name, dst_dtype, in_var.dtype, in_var.dtype == dst_dtype
str(grad_op)) ), "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format(
grad_op.type,
in_name,
dst_dtype,
in_var.dtype,
str(grad_op),
)
for out_name in grad_op.output_names: for out_name in grad_op.output_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output( if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
grad_op, out_name): grad_op, out_name
):
for out_var_name in grad_op.output(out_name): for out_var_name in grad_op.output(out_name):
out_var = self._block._find_var_recursive(out_var_name) out_var = self._block._find_var_recursive(out_var_name)
assert out_var.dtype == core.VarDesc.VarType.FP32 assert out_var.dtype == core.VarDesc.VarType.FP32
...@@ -334,7 +418,7 @@ class AMPState(object): ...@@ -334,7 +418,7 @@ class AMPState(object):
for out_var_name in grad_op.output(out_name): for out_var_name in grad_op.output(out_name):
out_var = self._block._find_var_recursive(out_var_name) out_var = self._block._find_var_recursive(out_var_name)
out_var_name_prefix = out_var_name[:out_var_name.find("@")] out_var_name_prefix = out_var_name[: out_var_name.find("@")]
fwd_var = self._block._find_var_recursive(out_var_name_prefix) fwd_var = self._block._find_var_recursive(out_var_name_prefix)
# NOTE: the out_var's dtype of consume grad_op should equal to the fwd_var's dtype # NOTE: the out_var's dtype of consume grad_op should equal to the fwd_var's dtype
if out_var.dtype != fwd_var.dtype: if out_var.dtype != fwd_var.dtype:
...@@ -345,34 +429,45 @@ class AMPState(object): ...@@ -345,34 +429,45 @@ class AMPState(object):
# NOTE: if out_var of consume grad_op has been casted before, # NOTE: if out_var of consume grad_op has been casted before,
# it should be renamed and reset dist_attr, then we insert cast op to # it should be renamed and reset dist_attr, then we insert cast op to
# convert the cast_var to original dtype # convert the cast_var to original dtype
consume_op_attr = dist_context.get_op_dist_attr_for_program( consume_op_attr = (
grad_op) dist_context.get_op_dist_attr_for_program(grad_op)
)
fwd_cast_name = self._var_name_dict[fwd_op_id][ fwd_cast_name = self._var_name_dict[fwd_op_id][
out_var_name_prefix] out_var_name_prefix
]
suffix = "" suffix = ""
if "@RENAME" in out_var_name: if "@RENAME" in out_var_name:
suffix = out_var_name[out_var_name.find("@RENAME"):] suffix = out_var_name[
out_var_name.find("@RENAME") :
]
cast_name = fwd_cast_name + "@GRAD" + suffix cast_name = fwd_cast_name + "@GRAD" + suffix
cast_var = self._block.vars.get(cast_name) cast_var = self._block.vars.get(cast_name)
if cast_var is None or cast_var.dtype != dst_dtype: if cast_var is None or cast_var.dtype != dst_dtype:
grad_op.desc._rename_output(out_var_name, cast_name) grad_op.desc._rename_output(out_var_name, cast_name)
out_var_dist_attr = consume_op_attr.get_output_dist_attr( out_var_dist_attr = (
out_var_name) consume_op_attr.get_output_dist_attr(
out_var_name
)
)
ref_mesh = out_var_dist_attr.process_mesh ref_mesh = out_var_dist_attr.process_mesh
ref_mapping = out_var_dist_attr.dims_mapping ref_mapping = out_var_dist_attr.dims_mapping
consume_op_attr.set_output_dist_attr( consume_op_attr.set_output_dist_attr(
cast_name, out_var_dist_attr) cast_name, out_var_dist_attr
)
assert ref_mapping is not None assert ref_mapping is not None
cast_var = self._block.create_var( cast_var = self._block.create_var(
name=cast_name, name=cast_name,
shape=out_var.shape, shape=out_var.shape,
dtype=dst_dtype, dtype=dst_dtype,
persistable=False, persistable=False,
stop_gradient=out_var.stop_gradient) stop_gradient=out_var.stop_gradient,
set_var_dist_attr(dist_context, cast_var, )
ref_mapping, ref_mesh) set_var_dist_attr(
dist_context, cast_var, ref_mapping, ref_mesh
)
dist_op_context.grad_var_to_var[ dist_op_context.grad_var_to_var[
appended_grad_times][cast_name] = fwd_cast_name appended_grad_times
][cast_name] = fwd_cast_name
cast_op = self._block._insert_op( cast_op = self._block._insert_op(
idx + 1, idx + 1,
...@@ -382,13 +477,15 @@ class AMPState(object): ...@@ -382,13 +477,15 @@ class AMPState(object):
attrs={ attrs={
"in_dtype": cast_var.dtype, "in_dtype": cast_var.dtype,
"out_dtype": out_var.dtype, "out_dtype": out_var.dtype,
"op_role": OpRole.Backward "op_role": OpRole.Backward,
}) },
)
cast_op._remove_attr("op_role_var") cast_op._remove_attr("op_role_var")
cast_op._remove_attr("op_namescope") cast_op._remove_attr("op_namescope")
cast_op._remove_attr("with_quant_attr") cast_op._remove_attr("with_quant_attr")
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context) cast_op, ref_mesh, ref_mapping, dist_context
)
num_cast_ops += 1 num_cast_ops += 1
else: else:
assert out_var.dtype == dst_dtype assert out_var.dtype == dst_dtype
...@@ -409,15 +506,18 @@ def _update_backward_cast_ops(params_grads, dist_context): ...@@ -409,15 +506,18 @@ def _update_backward_cast_ops(params_grads, dist_context):
for p, g in params_grads: for p, g in params_grads:
op = g.op op = g.op
if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast': if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
if int(op.attr('op_role')) == int( if int(op.attr('op_role')) == int(OpRole.Backward) and op.has_attr(
OpRole.Backward) and op.has_attr('op_role_var'): 'op_role_var'
):
op._remove_attr("op_role_var") op._remove_attr("op_role_var")
post_ops = find_true_post_op(main_block.ops, op, g.name) post_ops = find_true_post_op(main_block.ops, op, g.name)
if post_ops: if post_ops:
raise ValueError("The cast op {0}'s output should not be" raise ValueError(
"The cast op {0}'s output should not be"
"used by a non-optimize op, however, it" "used by a non-optimize op, however, it"
"is used by {1}".format(op, post_ops[0])) "is used by {1}".format(op, post_ops[0])
)
if op == main_block.ops[-1]: if op == main_block.ops[-1]:
continue continue
...@@ -425,23 +525,29 @@ def _update_backward_cast_ops(params_grads, dist_context): ...@@ -425,23 +525,29 @@ def _update_backward_cast_ops(params_grads, dist_context):
# add new op in the python and cpp at the same time # add new op in the python and cpp at the same time
new_op_desc = main_block.desc.append_op() new_op_desc = main_block.desc.append_op()
new_op_desc.copy_from(op.desc) new_op_desc.copy_from(op.desc)
new_op = paddle.fluid.framework.Operator(block=main_block, new_op = paddle.fluid.framework.Operator(
block=main_block,
desc=new_op_desc, desc=new_op_desc,
type=None, type=None,
inputs=None, inputs=None,
outputs=None, outputs=None,
attrs=None) attrs=None,
)
main_block.ops.append(new_op) main_block.ops.append(new_op)
# dist attr # dist attr
param_dist_attr = dist_context.get_tensor_dist_attr_for_program(p) param_dist_attr = dist_context.get_tensor_dist_attr_for_program(p)
output_dist_attr = dist_context.get_tensor_dist_attr_for_program( output_dist_attr = dist_context.get_tensor_dist_attr_for_program(
main_block.var(op.output_arg_names[0])) main_block.var(op.output_arg_names[0])
)
assert param_dist_attr is not None assert param_dist_attr is not None
assert output_dist_attr is not None assert output_dist_attr is not None
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_op, param_dist_attr.process_mesh, new_op,
param_dist_attr.dims_mapping, dist_context) param_dist_attr.process_mesh,
param_dist_attr.dims_mapping,
dist_context,
)
output_dist_attr.process_mesh = param_dist_attr.process_mesh output_dist_attr.process_mesh = param_dist_attr.process_mesh
output_dist_attr.dims_mapping = param_dist_attr.dims_mapping output_dist_attr.dims_mapping = param_dist_attr.dims_mapping
...@@ -462,26 +568,34 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context): ...@@ -462,26 +568,34 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
grads = [g for _, g in params_grads] grads = [g for _, g in params_grads]
check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale') check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
for e in grads: for e in grads:
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'], check_variable_and_dtype(
'check_finite_and_unscale') e,
"x",
['float16', 'float32', 'float64'],
'check_finite_and_unscale',
)
found_inf = main_block.create_var( found_inf = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join( name=unique_name.generate_with_ignorable_key(
['find_infinite_scale', 'tmp'])), ".".join(['find_infinite_scale', 'tmp'])
),
shape=[1], shape=[1],
dtype='bool', dtype='bool',
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False, persistable=False,
stop_gradient=False) stop_gradient=False,
)
set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks) set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks)
inputs = {'X': grads, 'Scale': loss_scaling} inputs = {'X': grads, 'Scale': loss_scaling}
outputs = {'Out': grads, 'FoundInfinite': found_inf} outputs = {'Out': grads, 'FoundInfinite': found_inf}
attrs = {'op_role': OpRole.Optimize} attrs = {'op_role': OpRole.Optimize}
new_op = main_block.append_op(type='check_finite_and_unscale', new_op = main_block.append_op(
type='check_finite_and_unscale',
inputs=inputs, inputs=inputs,
outputs=outputs, outputs=outputs,
attrs=attrs) attrs=attrs,
)
new_op_dist_attr = OperatorDistributedAttribute() new_op_dist_attr = OperatorDistributedAttribute()
new_op_dist_attr.process_mesh = world_process_group.ranks new_op_dist_attr.process_mesh = world_process_group.ranks
...@@ -491,17 +605,18 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context): ...@@ -491,17 +605,18 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
for g in grads: for g in grads:
g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g) g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g)
assert g_dist_attr is not None assert g_dist_attr is not None
new_op_dist_attr.set_input_dims_mapping(g.name, new_op_dist_attr.set_input_dims_mapping(
g_dist_attr.dims_mapping) g.name, g_dist_attr.dims_mapping
new_op_dist_attr.set_output_dims_mapping(g.name, )
g_dist_attr.dims_mapping) new_op_dist_attr.set_output_dims_mapping(
g.name, g_dist_attr.dims_mapping
)
dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr) dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
return grads, found_inf return grads, found_inf
@register_pass("auto_parallel_amp") @register_pass("auto_parallel_amp")
class AMPPass(PassBase): class AMPPass(PassBase):
def __init__(self): def __init__(self):
super(AMPPass, self).__init__() super(AMPPass, self).__init__()
self.set_attr("loss", None) self.set_attr("loss", None)
...@@ -517,6 +632,7 @@ class AMPPass(PassBase): ...@@ -517,6 +632,7 @@ class AMPPass(PassBase):
self.set_attr("use_dynamic_loss_scaling", False) self.set_attr("use_dynamic_loss_scaling", False)
self.set_attr("input_data", []) self.set_attr("input_data", [])
self.set_attr("params_grads", []) self.set_attr("params_grads", [])
self.set_attr("dtype", "") # fp16/bf16
self._loss = None self._loss = None
self._loss_scaling = None self._loss_scaling = None
self._num_good_steps = None self._num_good_steps = None
...@@ -524,6 +640,8 @@ class AMPPass(PassBase): ...@@ -524,6 +640,8 @@ class AMPPass(PassBase):
self._loss = None self._loss = None
def _check_self(self): def _check_self(self):
if self.get_attr("dtype") not in ["float16", "bfloat16"]:
return False
if self.get_attr("init_loss_scaling") < 0: if self.get_attr("init_loss_scaling") < 0:
return False return False
if self.get_attr("incr_every_n_steps") < 0: if self.get_attr("incr_every_n_steps") < 0:
...@@ -548,11 +666,13 @@ class AMPPass(PassBase): ...@@ -548,11 +666,13 @@ class AMPPass(PassBase):
def _apply_single_impl(self, main_program, startup_program, context): def _apply_single_impl(self, main_program, startup_program, context):
self.dist_context = self.get_attr("dist_context") self.dist_context = self.get_attr("dist_context")
params_grads = self.get_attr("params_grads") params_grads = self.get_attr("params_grads")
self.amp_dtype = self.get_attr("dtype")
amp_lists = AutoMixedPrecisionLists( amp_lists = AutoMixedPrecisionLists(
set(self.get_attr("custom_white_list")), set(self.get_attr("custom_white_list")),
set(self.get_attr("custom_black_list")), set(self.get_attr("custom_black_list")),
set(self.get_attr("custom_black_varnames"))) set(self.get_attr("custom_black_varnames")),
)
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
amp_state = AMPState(main_program.global_block()) amp_state = AMPState(main_program.global_block())
...@@ -566,10 +686,13 @@ class AMPPass(PassBase): ...@@ -566,10 +686,13 @@ class AMPPass(PassBase):
self._init_amp_var() self._init_amp_var()
self._scale_loss() self._scale_loss()
if self.get_attr("use_dynamic_loss_scaling" if (
) or self.get_attr("init_loss_scaling") != 1.0: self.get_attr("use_dynamic_loss_scaling")
or self.get_attr("init_loss_scaling") != 1.0
):
grads, found_inf = _check_and_update_gradient( grads, found_inf = _check_and_update_gradient(
params_grads, self._loss_scaling, self.dist_context) params_grads, self._loss_scaling, self.dist_context
)
if self.get_attr("use_dynamic_loss_scaling"): if self.get_attr("use_dynamic_loss_scaling"):
self._update_loss_scaling(grads, found_inf) self._update_loss_scaling(grads, found_inf)
...@@ -580,9 +703,14 @@ class AMPPass(PassBase): ...@@ -580,9 +703,14 @@ class AMPPass(PassBase):
shape=[1], shape=[1],
value=self.get_attr("init_loss_scaling"), value=self.get_attr("init_loss_scaling"),
dtype='float32', dtype='float32',
persistable=True) persistable=True,
set_var_dist_attr(self.dist_context, self._loss_scaling, [-1], )
world_process_group.ranks) set_var_dist_attr(
self.dist_context,
self._loss_scaling,
[-1],
world_process_group.ranks,
)
if self.get_attr("use_dynamic_loss_scaling"): if self.get_attr("use_dynamic_loss_scaling"):
self._num_good_steps = paddle.static.create_global_var( self._num_good_steps = paddle.static.create_global_var(
...@@ -590,18 +718,28 @@ class AMPPass(PassBase): ...@@ -590,18 +718,28 @@ class AMPPass(PassBase):
shape=[1], shape=[1],
value=0, value=0,
dtype='int32', dtype='int32',
persistable=True) persistable=True,
set_var_dist_attr(self.dist_context, self._num_good_steps, [-1], )
world_process_group.ranks) set_var_dist_attr(
self.dist_context,
self._num_good_steps,
[-1],
world_process_group.ranks,
)
self._num_bad_steps = paddle.static.create_global_var( self._num_bad_steps = paddle.static.create_global_var(
name=unique_name.generate("num_bad_steps"), name=unique_name.generate("num_bad_steps"),
shape=[1], shape=[1],
value=0, value=0,
dtype='int32', dtype='int32',
persistable=True) persistable=True,
set_var_dist_attr(self.dist_context, self._num_bad_steps, [-1], )
world_process_group.ranks) set_var_dist_attr(
self.dist_context,
self._num_bad_steps,
[-1],
world_process_group.ranks,
)
def _scale_loss(self): def _scale_loss(self):
...@@ -613,7 +751,8 @@ class AMPPass(PassBase): ...@@ -613,7 +751,8 @@ class AMPPass(PassBase):
assert loss is not None assert loss is not None
loss_op = loss.op loss_op = loss.op
loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program( loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
loss_op) loss_op
)
if loss.dtype != core.VarDesc.VarType.FP32: if loss.dtype != core.VarDesc.VarType.FP32:
# cast loss here will change the effective loss tensor for the computation graph # cast loss here will change the effective loss tensor for the computation graph
...@@ -626,10 +765,12 @@ class AMPPass(PassBase): ...@@ -626,10 +765,12 @@ class AMPPass(PassBase):
tmp_name = unique_name.generate(loss.name + ".cast_fp32") tmp_name = unique_name.generate(loss.name + ".cast_fp32")
cast_loss = main_block.create_var(name=tmp_name, dtype=dtype) cast_loss = main_block.create_var(name=tmp_name, dtype=dtype)
loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program( loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
loss) loss
)
ref_mesh = loss_op_dist_attr.process_mesh ref_mesh = loss_op_dist_attr.process_mesh
self.dist_context.set_tensor_dist_attr_for_program( self.dist_context.set_tensor_dist_attr_for_program(
cast_loss, loss_dist_attr) cast_loss, loss_dist_attr
)
loss_op_idx = find_op_index(main_block.desc, loss_op.desc) loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
cast_op = main_block._insert_op( cast_op = main_block._insert_op(
...@@ -641,16 +782,21 @@ class AMPPass(PassBase): ...@@ -641,16 +782,21 @@ class AMPPass(PassBase):
"in_dtype": loss.dtype, "in_dtype": loss.dtype,
"out_dtype": core.VarDesc.VarType.FP32, "out_dtype": core.VarDesc.VarType.FP32,
'op_role': loss_op.all_attrs()[OP_ROLE_KEY], 'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
}) },
)
loss_op._set_attr(OP_ROLE_KEY, loss_op._set_attr(
core.op_proto_and_checker_maker.OpRole.Forward) OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, [-1], self.dist_context) cast_op, ref_mesh, [-1], self.dist_context
)
loss = loss.astype('float32') loss = loss.astype('float32')
if self.get_attr("use_dynamic_loss_scaling" if self.amp_dtype == "float16" and (
) or self.get_attr("init_loss_scaling") != 1.0: self.get_attr("use_dynamic_loss_scaling")
or self.get_attr("init_loss_scaling") != 1.0
):
loss_op_idx = find_op_index(main_block.desc, loss_op.desc) loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
...@@ -660,63 +806,76 @@ class AMPPass(PassBase): ...@@ -660,63 +806,76 @@ class AMPPass(PassBase):
name=unique_name.generate("scaled_loss"), name=unique_name.generate("scaled_loss"),
shape=loss.shape, shape=loss.shape,
dtype=loss.dtype, dtype=loss.dtype,
persistable=loss.persistable) persistable=loss.persistable,
set_var_dist_attr(self.dist_context, self._scaled_loss, [-1], )
ref_mesh) set_var_dist_attr(
self.dist_context, self._scaled_loss, [-1], ref_mesh
)
elementwise_mul_op = main_block._insert_op( elementwise_mul_op = main_block._insert_op(
loss_op_idx + 1, loss_op_idx + 1,
type='elementwise_mul', type='elementwise_mul',
inputs={ inputs={'X': [loss], 'Y': [self._loss_scaling]},
'X': [loss],
'Y': [self._loss_scaling]
},
outputs={'Out': [self._scaled_loss]}, outputs={'Out': [self._scaled_loss]},
attrs={ attrs={
'op_role': loss_op.all_attrs()[OP_ROLE_KEY], 'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
}) },
loss_op._set_attr(OP_ROLE_KEY, )
core.op_proto_and_checker_maker.OpRole.Forward) loss_op._set_attr(
OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
elementwise_mul_op, ref_mesh, [-1], self.dist_context) elementwise_mul_op, ref_mesh, [-1], self.dist_context
)
# backward # backward
first_backward_op = main_block.ops[loss_op_idx + 2] first_backward_op = main_block.ops[loss_op_idx + 2]
assert first_backward_op.type == "fill_constant" and int( assert (
first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257 first_backward_op.type == "fill_constant"
and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
)
self._scaled_loss_grad = main_block.create_var( self._scaled_loss_grad = main_block.create_var(
name=unique_name.generate("scaled_loss") + "@GRAD", name=unique_name.generate("scaled_loss") + "@GRAD",
shape=loss.shape, shape=loss.shape,
dtype=loss.dtype, dtype=loss.dtype,
persistable=loss.persistable) persistable=loss.persistable,
set_var_dist_attr(self.dist_context, self._scaled_loss_grad, [-1], )
ref_mesh) set_var_dist_attr(
self.dist_context, self._scaled_loss_grad, [-1], ref_mesh
)
pre_grad_name = first_backward_op.output_arg_names[0] pre_grad_name = first_backward_op.output_arg_names[0]
first_backward_op._rename_output(pre_grad_name, first_backward_op._rename_output(
self._scaled_loss_grad.name) pre_grad_name, self._scaled_loss_grad.name
)
# FIXME(JZ-LIANG) a trick to insert backward op # FIXME(JZ-LIANG) a trick to insert backward op
main_block._sync_with_cpp() main_block._sync_with_cpp()
elementwise_mul_grad_op_desc = main_block.desc._insert_op( elementwise_mul_grad_op_desc = main_block.desc._insert_op(
loss_op_idx + 3) loss_op_idx + 3
)
elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad") elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad")
elementwise_mul_grad_op_desc.set_input( elementwise_mul_grad_op_desc.set_input(
'Out@GRAD', [self._scaled_loss_grad.name]) 'Out@GRAD', [self._scaled_loss_grad.name]
)
elementwise_mul_grad_op_desc.set_input('X', [loss.name]) elementwise_mul_grad_op_desc.set_input('X', [loss.name])
elementwise_mul_grad_op_desc.set_input('Y', elementwise_mul_grad_op_desc.set_input(
[self._loss_scaling.name]) 'Y', [self._loss_scaling.name]
)
elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name]) elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name])
elementwise_mul_grad_op_desc.set_output('Y@GRAD', []) elementwise_mul_grad_op_desc.set_output('Y@GRAD', [])
elementwise_mul_grad_op_desc._set_attr( elementwise_mul_grad_op_desc._set_attr(
OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward) OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward
)
elementwise_mul_grad_op_desc._set_attr('axis', -1) elementwise_mul_grad_op_desc._set_attr('axis', -1)
elementwise_mul_grad_op = paddle.fluid.framework.Operator( elementwise_mul_grad_op = paddle.fluid.framework.Operator(
main_block, elementwise_mul_grad_op_desc) main_block, elementwise_mul_grad_op_desc
)
main_block.ops.insert(loss_op_idx + 3, elementwise_mul_grad_op) main_block.ops.insert(loss_op_idx + 3, elementwise_mul_grad_op)
main_block._sync_with_cpp() main_block._sync_with_cpp()
elementwise_mul_grad_op = main_block.ops[loss_op_idx + 3] elementwise_mul_grad_op = main_block.ops[loss_op_idx + 3]
assert elementwise_mul_grad_op.type == "elementwise_mul_grad" assert elementwise_mul_grad_op.type == "elementwise_mul_grad"
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context) elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context
)
else: else:
self._scaled_loss = loss self._scaled_loss = loss
...@@ -728,31 +887,39 @@ class AMPPass(PassBase): ...@@ -728,31 +887,39 @@ class AMPPass(PassBase):
main_block = paddle.static.default_main_program().global_block() main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp() main_block._sync_with_cpp()
check_variable_and_dtype(self._loss_scaling, "prev_loss_scaling", check_variable_and_dtype(
['float32', 'float64'], "update_loss_scaling") self._loss_scaling,
"prev_loss_scaling",
['float32', 'float64'],
"update_loss_scaling",
)
check_type(grads, 'x', (tuple, list), 'update_loss_scaling') check_type(grads, 'x', (tuple, list), 'update_loss_scaling')
for e in grads: for e in grads:
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'], check_variable_and_dtype(
'update_loss_scaling') e, "x", ['float16', 'float32', 'float64'], 'update_loss_scaling'
)
if e.dtype == core.VarDesc.VarType.FP16: if e.dtype == core.VarDesc.VarType.FP16:
assert self._loss_scaling.dtype == core.VarDesc.VarType.FP32, \ assert (
"The dtype of prev_loss_scaling should be float32 when the dtype of x is float16." self._loss_scaling.dtype == core.VarDesc.VarType.FP32
), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
else: else:
assert self._loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x." assert (
self._loss_scaling.dtype == e.dtype
), "The dtype of prev_loss_scaling should be equal to the dtype of x."
inputs = { inputs = {
'X': grads, 'X': grads,
'FoundInfinite': found_inf, 'FoundInfinite': found_inf,
'PrevLossScaling': self._loss_scaling, 'PrevLossScaling': self._loss_scaling,
'InGoodSteps': self._num_good_steps, 'InGoodSteps': self._num_good_steps,
'InBadSteps': self._num_bad_steps 'InBadSteps': self._num_bad_steps,
} }
outputs = { outputs = {
'Out': grads, 'Out': grads,
'LossScaling': self._loss_scaling, 'LossScaling': self._loss_scaling,
'OutGoodSteps': self._num_good_steps, 'OutGoodSteps': self._num_good_steps,
'OutBadSteps': self._num_bad_steps 'OutBadSteps': self._num_bad_steps,
} }
attrs = { attrs = {
...@@ -761,13 +928,15 @@ class AMPPass(PassBase): ...@@ -761,13 +928,15 @@ class AMPPass(PassBase):
'incr_ratio': self.get_attr("incr_ratio"), 'incr_ratio': self.get_attr("incr_ratio"),
'decr_ratio': self.get_attr("decr_ratio"), 'decr_ratio': self.get_attr("decr_ratio"),
'stop_update': self.get_attr("stop_update"), 'stop_update': self.get_attr("stop_update"),
'op_role': OpRole.Optimize 'op_role': OpRole.Optimize,
} }
new_op = main_block.append_op(type='update_loss_scaling', new_op = main_block.append_op(
type='update_loss_scaling',
inputs=inputs, inputs=inputs,
outputs=outputs, outputs=outputs,
attrs=attrs) attrs=attrs,
)
new_op_dist_attr = OperatorDistributedAttribute() new_op_dist_attr = OperatorDistributedAttribute()
new_op_dist_attr.process_mesh = world_process_group.ranks new_op_dist_attr.process_mesh = world_process_group.ranks
...@@ -777,10 +946,22 @@ class AMPPass(PassBase): ...@@ -777,10 +946,22 @@ class AMPPass(PassBase):
for g in grads: for g in grads:
g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g) g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g)
assert g_dist_attr is not None assert g_dist_attr is not None
new_op_dist_attr.set_input_dims_mapping(g.name, new_op_dist_attr.set_input_dims_mapping(
g_dist_attr.dims_mapping) g.name, g_dist_attr.dims_mapping
new_op_dist_attr.set_output_dims_mapping(g.name, )
g_dist_attr.dims_mapping) new_op_dist_attr.set_output_dims_mapping(
g.name, g_dist_attr.dims_mapping
)
self.dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr) self.dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
main_block._sync_with_cpp() main_block._sync_with_cpp()
def get_loss(self):
# the amp might change the effective loss variable for network and
# therefore would affect the subsequent passes that rely on the loss.
# return the effective loss after amp pass.
if self._loss:
return self._loss
else:
return self.get_attr("loss")
...@@ -27,14 +27,13 @@ from paddle.distributed.auto_parallel.utils import ( ...@@ -27,14 +27,13 @@ from paddle.distributed.auto_parallel.utils import (
from paddle.distributed.auto_parallel.process_group import ( from paddle.distributed.auto_parallel.process_group import (
get_world_process_group, get_world_process_group,
) )
from paddle.fluid.contrib.mixed_precision.fp16_utils import ( from paddle.fluid.contrib.mixed_precision.fp16_lists import (
AutoMixedPrecisionLists, AutoMixedPrecisionLists,
) )
from paddle.fluid.contrib.mixed_precision.fp16_utils import ( from paddle.fluid.contrib.mixed_precision.fp16_utils import (
_keep_layer_norm_scale_bias_to_fp32, _keep_layer_norm_scale_bias_to_fp32,
_need_keep_fp32, _need_keep_fp32,
_valid_types, _valid_types,
_dtype_to_str,
) )
from paddle.distributed.auto_parallel.dist_attribute import ( from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistributedAttribute, OperatorDistributedAttribute,
...@@ -55,6 +54,23 @@ __amp_skip_ops__ = [ ...@@ -55,6 +54,23 @@ __amp_skip_ops__ = [
'while', 'while',
'cast', 'cast',
] ]
__target_dtype__ = None
def _dtype_to_str(dtype):
"""
Convert specific variable type to its corresponding string.
Args:
dtype (VarType): Variable type.
"""
if dtype == core.VarDesc.VarType.FP16:
# TODO(Xreki): change the returned str to "bf16" for BF16 data type.
# Currently too many codes use "cast_fp16" as key.
return 'fp16'
elif dtype == core.VarDesc.VarType.BF16:
return 'bf16'
else:
return 'fp32'
def set_op_dtype_to_fp16(op): def set_op_dtype_to_fp16(op):
...@@ -62,14 +78,20 @@ def set_op_dtype_to_fp16(op): ...@@ -62,14 +78,20 @@ def set_op_dtype_to_fp16(op):
op.has_attr('in_dtype') op.has_attr('in_dtype')
and op.attr('in_dtype') == core.VarDesc.VarType.FP32 and op.attr('in_dtype') == core.VarDesc.VarType.FP32
): ):
op._set_attr('in_dtype', core.VarDesc.VarType.FP16) op._set_attr('in_dtype', __target_dtype__)
if ( if (
op.has_attr('out_dtype') op.has_attr('out_dtype')
and op.attr('out_dtype') == core.VarDesc.VarType.FP32 and op.attr('out_dtype') == core.VarDesc.VarType.FP32
): ):
op._set_attr('out_dtype', core.VarDesc.VarType.FP16) op._set_attr('out_dtype', __target_dtype__)
if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32: if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16) op._set_attr('dtype', __target_dtype__)
if __target_dtype__ == core.VarDesc.VarType.BF16:
if op.has_attr('use_mkldnn'):
op._set_attr('use_mkldnn', True)
if op.has_attr('mkldnn_data_type'):
op._set_attr('mkldnn_data_type', 'bfloat16')
# adapot for backward op # adapot for backward op
...@@ -156,6 +178,7 @@ class FP16State(object): ...@@ -156,6 +178,7 @@ class FP16State(object):
list list
) # {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]} ) # {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]}
self.is_train = False self.is_train = False
self.out_var_op_deps = {}
def _is_fp16_op(self, op_id): def _is_fp16_op(self, op_id):
return self._op_fp16_dict.get(op_id, None) return self._op_fp16_dict.get(op_id, None)
...@@ -169,6 +192,14 @@ class FP16State(object): ...@@ -169,6 +192,14 @@ class FP16State(object):
# assume all backward block are behind forward blocks # assume all backward block are behind forward blocks
for block in self.program.blocks: for block in self.program.blocks:
for op in block.ops: for op in block.ops:
for name in op.output_arg_names:
if name not in self.out_var_op_deps:
self.out_var_op_deps[name] = [op.desc.original_id()]
else:
self.out_var_op_deps[name].extend(
[op.desc.original_id()]
)
self._mark_op(op) self._mark_op(op)
# set forward tensor dtype # set forward tensor dtype
...@@ -192,6 +223,18 @@ class FP16State(object): ...@@ -192,6 +223,18 @@ class FP16State(object):
if op.type == "assign" and "array_" in op.input_arg_names[0]: if op.type == "assign" and "array_" in op.input_arg_names[0]:
self._op_fp16_dict[op.desc.original_id()] = False self._op_fp16_dict[op.desc.original_id()] = False
return return
# If assign op is inplace-operation, assign op exec mode should be same with the created op of output_var.
if op.type == "assign":
out_name = op.output_arg_names[0]
if len(self.out_var_op_deps[out_name]) > 1:
if not self._op_fp16_dict[
self.out_var_op_deps[out_name][0]
]:
self._op_fp16_dict[op.desc.original_id()] = False
else:
self._op_fp16_dict[op.desc.original_id()] = True
return
if _need_keep_fp32( if _need_keep_fp32(
op, self.amp_list.unsupported_list, self.use_fp16_guard op, self.amp_list.unsupported_list, self.use_fp16_guard
): ):
...@@ -228,7 +271,7 @@ class FP16State(object): ...@@ -228,7 +271,7 @@ class FP16State(object):
return return
if var.dtype == core.VarDesc.VarType.FP32: if var.dtype == core.VarDesc.VarType.FP32:
var.desc.set_dtype(core.VarDesc.VarType.FP16) var.desc.set_dtype(__target_dtype__)
def resolute_tensor_dtype(self, block): def resolute_tensor_dtype(self, block):
...@@ -260,7 +303,7 @@ class FP16State(object): ...@@ -260,7 +303,7 @@ class FP16State(object):
out_var = block.vars.get(out_var_name) out_var = block.vars.get(out_var_name)
if out_var is None or out_var.type not in _valid_types: if out_var is None or out_var.type not in _valid_types:
continue continue
if out_var.dtype == core.VarDesc.VarType.FP16: if out_var.dtype == __target_dtype__:
out_var.desc.set_dtype(core.VarDesc.VarType.FP32) out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
elif is_backward_op(op): elif is_backward_op(op):
if self._is_fp16_op(op.desc.original_id()) == True: if self._is_fp16_op(op.desc.original_id()) == True:
...@@ -276,7 +319,7 @@ class FP16State(object): ...@@ -276,7 +319,7 @@ class FP16State(object):
out_var = block.vars.get(out_var_name) out_var = block.vars.get(out_var_name)
if out_var is None or out_var.type not in _valid_types: if out_var is None or out_var.type not in _valid_types:
continue continue
if out_var.dtype == core.VarDesc.VarType.FP16: if out_var.dtype == __target_dtype__:
out_var.desc.set_dtype(core.VarDesc.VarType.FP32) out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
def cast_block(self, block): def cast_block(self, block):
...@@ -295,7 +338,7 @@ class FP16State(object): ...@@ -295,7 +338,7 @@ class FP16State(object):
op, op,
idx, idx,
block, block,
core.VarDesc.VarType.FP16, __target_dtype__,
core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP32,
self.dist_context, self.dist_context,
) )
...@@ -305,7 +348,7 @@ class FP16State(object): ...@@ -305,7 +348,7 @@ class FP16State(object):
idx, idx,
block, block,
core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16, __target_dtype__,
self.dist_context, self.dist_context,
) )
elif is_backward_op(op): elif is_backward_op(op):
...@@ -315,7 +358,7 @@ class FP16State(object): ...@@ -315,7 +358,7 @@ class FP16State(object):
op, op,
idx, idx,
block, block,
core.VarDesc.VarType.FP16, __target_dtype__,
core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP32,
self.dist_context, self.dist_context,
) )
...@@ -325,7 +368,7 @@ class FP16State(object): ...@@ -325,7 +368,7 @@ class FP16State(object):
idx, idx,
block, block,
core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16, __target_dtype__,
self.dist_context, self.dist_context,
) )
elif op.type == "sum": elif op.type == "sum":
...@@ -399,6 +442,9 @@ class FP16State(object): ...@@ -399,6 +442,9 @@ class FP16State(object):
dist_context, cast_var, ref_mapping, ref_mesh dist_context, cast_var, ref_mapping, ref_mesh
) )
op_namescope = "/"
if op.has_attr('op_namescope'):
op_namescope = op.attr('op_namescope')
cast_op = block._insert_op_without_sync( cast_op = block._insert_op_without_sync(
idx, idx,
type="cast", type="cast",
...@@ -410,6 +456,9 @@ class FP16State(object): ...@@ -410,6 +456,9 @@ class FP16State(object):
OP_ROLE_KEY: OpRole.Forward, OP_ROLE_KEY: OpRole.Forward,
}, },
) )
cast_op._set_attr(
'op_namescope', op_namescope
) # for recompute
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context cast_op, ref_mesh, ref_mapping, dist_context
) )
...@@ -455,22 +504,36 @@ class FP16State(object): ...@@ -455,22 +504,36 @@ class FP16State(object):
) in self.forward_input_cast_ops[forward_op_id]: ) in self.forward_input_cast_ops[forward_op_id]:
# rename input # rename input
# some forward output is not need by backward computation, e.g. logit in softmax_with_cross_entropy
if op.type != "scale" and slot_name in op.input_names:
assert src_name in op.input( assert src_name in op.input(
slot_name slot_name
), "var: {} not in op's {}. {}".format(src_name, slot_name, str(op)) ), "var: {} not in op's {}. {}".format(
src_name, slot_name, str(op)
)
src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name) src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name)
assert src_var_dist_attr is not None assert src_var_dist_attr is not None
op._rename_input(src_name, cast_name) op._rename_input(src_name, cast_name)
grad_op_attr.set_input_dist_attr(cast_name, src_var_dist_attr) grad_op_attr.set_input_dist_attr(cast_name, src_var_dist_attr)
# NOTE Special for scale op, scale op's grad op is scale,
# so slot name map rule could not apply to grad scale op
# cast_name: mean_0.tmp_0.cast_bf16, src_name: mean_0.tmp_0, dst_dtype: paddle.bfloat16, src_dtype: paddle.float32, slot_name: X.
if op.type == "scale":
grad_slot_name = "X"
# create cast grad # create cast grad
else:
grad_slot_name = slot_name + "@GRAD" grad_slot_name = slot_name + "@GRAD"
assert grad_slot_name in op.output_names
if grad_slot_name in op.output_names:
# some forward input maybe stop_gradient=True, e.g. input_mask
if len(op.output(grad_slot_name)) == 0: if len(op.output(grad_slot_name)) == 0:
var = block.var(src_name)
assert var.stop_gradient is True
continue continue
assert len(op.output(grad_slot_name)) == 1 assert (
len(op.output(grad_slot_name)) == 1
), "[{}], Current Op: {}".format(grad_slot_name, str(op))
grad_name = op.output(grad_slot_name)[0] grad_name = op.output(grad_slot_name)[0]
grad = block.var(grad_name) grad = block.var(grad_name)
grad_dist_attr = grad_op_attr.get_output_dist_attr(grad_name) grad_dist_attr = grad_op_attr.get_output_dist_attr(grad_name)
...@@ -492,7 +555,9 @@ class FP16State(object): ...@@ -492,7 +555,9 @@ class FP16State(object):
cast_grad, grad_dist_attr cast_grad, grad_dist_attr
) )
op._rename_output(grad_name, cast_grad.name) op._rename_output(grad_name, cast_grad.name)
grad_op_attr.set_output_dist_attr(cast_grad.name, grad_dist_attr) grad_op_attr.set_output_dist_attr(
cast_grad.name, grad_dist_attr
)
# add cast # add cast
cast_op = block._insert_op_without_sync( cast_op = block._insert_op_without_sync(
...@@ -573,7 +638,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context): ...@@ -573,7 +638,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
def _split_grads(params_grads): def _split_grads(params_grads):
grads = [g for _, g in params_grads] grads = [g for _, g in params_grads]
fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32] fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16] fp16_grads = [g for g in grads if g.dtype == __target_dtype__]
assert len(fp32_grads) + len(fp16_grads) == len( assert len(fp32_grads) + len(fp16_grads) == len(
grads grads
), "Data types of all grads must be either fp16 or fp32." ), "Data types of all grads must be either fp16 or fp32."
...@@ -633,17 +698,15 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"): ...@@ -633,17 +698,15 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
# TODO to support CUDAPinned/NPU/XPU Places # TODO to support CUDAPinned/NPU/XPU Places
if direction == "D2H": if direction == "D2H":
dst_place_type = 0 dst_place_type = 0
elif direction == "D2H":
dst_place_type = 1
else: else:
raise NotImplementedError( raise NotImplementedError(
"direction [{}] is not supported yet.".format(direction) f"direction [{direction}] is not supported yet."
) )
attrs = {'dst_place_type': dst_place_type} attrs = {'dst_place_type': dst_place_type}
new_op = block._insert_op_without_sync( new_op = block._insert_op_without_sync(
index=idx, index=idx,
type='memcpy', type='memcpy_d2h',
inputs={'X': [src_var]}, inputs={'X': [src_var]},
outputs={'Out': [output_var]}, outputs={'Out': [output_var]},
attrs=attrs, attrs=attrs,
...@@ -678,17 +741,17 @@ def cast_startup_program(): ...@@ -678,17 +741,17 @@ def cast_startup_program():
for op in startup_program.global_block().ops: for op in startup_program.global_block().ops:
if is_initialization_op(op): if is_initialization_op(op):
output_name = op.output_arg_names[0] output_name = op.output_arg_names[0]
if ( if param_to_dtype.get(output_name, None) == __target_dtype__:
param_to_dtype.get(output_name, None)
== core.VarDesc.VarType.FP16
):
assert op.has_attr( assert op.has_attr(
'dtype' 'dtype'
), "initialization op is supported to has dtype attribute but got {}.".format( ), "initialization op is supported to has dtype attribute but got {}.".format(
str(op) str(op)
) )
out_var = startup_program.global_block().var(output_name)
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(__target_dtype__)
if op.attr('dtype') == core.VarDesc.VarType.FP32: if op.attr('dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16) op._set_attr('dtype', __target_dtype__)
@register_pass("auto_parallel_fp16") @register_pass("auto_parallel_fp16")
...@@ -701,14 +764,44 @@ class FP16Pass(AMPPass): ...@@ -701,14 +764,44 @@ class FP16Pass(AMPPass):
# in distributed scenario, all ranks should have the same modification. # in distributed scenario, all ranks should have the same modification.
def _apply_single_impl(self, main_program, startup_program, context): def _apply_single_impl(self, main_program, startup_program, context):
self.dist_context = self.get_attr("dist_context") self.dist_context = self.get_attr("dist_context")
self.target_dtype = self.get_attr("dtype")
params_grads = self.get_attr("params_grads") params_grads = self.get_attr("params_grads")
self.use_optimizer_fp16 = self.get_attr("use_optimizer_fp16", None)
if self.use_optimizer_fp16 is None:
self.use_optimizer_fp16 = self.get_attr("level", None) == "o3"
# swith enviroment for fp16 / bf16.
if self.target_dtype == "float16":
__target_dtype = core.VarDesc.VarType.FP16
elif self.target_dtype == "bfloat16":
__target_dtype = core.VarDesc.VarType.BF16
else:
raise NotImplementedError(
"target dtype [{}] is for amp o2 not supported yet.".format(
self.target_dtype
)
)
global __target_dtype__
__target_dtype__ = __target_dtype
amp_list = AutoMixedPrecisionLists( amp_list = AutoMixedPrecisionLists(
set(self.get_attr("custom_white_list")), set(self.get_attr("custom_white_list")),
set(self.get_attr("custom_black_list")), set(self.get_attr("custom_black_list")),
None, dtype=self.target_dtype,
) )
amp_list.unsupported_list -= {
"conditional_block_grad",
"conditional_block",
"conditional_block_infer",
"select_input",
"while",
"while_grad",
"cast",
"tensor_array_to_tensor",
"lod_array_length",
"write_to_array",
}
# NOTE don't not change input data dtype, since it is controled by dataloader # NOTE don't not change input data dtype, since it is controled by dataloader
# and which is out of control of FP16 Pass # and which is out of control of FP16 Pass
input_data_var_names = [var.name for var in self.get_attr("input_data")] input_data_var_names = [var.name for var in self.get_attr("input_data")]
...@@ -726,6 +819,7 @@ class FP16Pass(AMPPass): ...@@ -726,6 +819,7 @@ class FP16Pass(AMPPass):
cast_startup_program() cast_startup_program()
if is_train: if is_train:
if self.target_dtype == "fp16":
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
# TODO (JZ-LIANG)support cast forward program only when inference # TODO (JZ-LIANG)support cast forward program only when inference
self._init_amp_var() self._init_amp_var()
...@@ -801,10 +895,12 @@ class FP16Pass(AMPPass): ...@@ -801,10 +895,12 @@ class FP16Pass(AMPPass):
# modify optimizer # modify optimizer
base_opt = self.get_attr("base_opt") base_opt = self.get_attr("base_opt")
base_opt._multi_precision = True base_opt._multi_precision = True
if self.get_attr("use_optimizer_fp16"): if self.use_optimizer_fp16:
base_opt._multi_precision = False base_opt._multi_precision = False
if self.target_dtype == "fp16":
if isinstance( if isinstance(
base_opt, (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW) base_opt,
(paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW),
): ):
with main_program._optimized_guard([]): with main_program._optimized_guard([]):
# found_inf = paddle.tensor.creation._memcpy( # found_inf = paddle.tensor.creation._memcpy(
......
...@@ -40,6 +40,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ...@@ -40,6 +40,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_random_ctrl MODULES test_random_ctrl ENVS ${dist_ENVS}) py_test_modules(test_random_ctrl MODULES test_random_ctrl ENVS ${dist_ENVS})
set_tests_properties(test_random_ctrl PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" set_tests_properties(test_random_ctrl PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50) TIMEOUT 50)
py_test_modules(test_amp_o2_pass MODULES test_amp_o2_pass ENVS ${dist_ENVS})
set_tests_properties(test_amp_o2_pass PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS
${dist_ENVS}) ${dist_ENVS})
set_tests_properties(test_iterable_dataset set_tests_properties(test_iterable_dataset
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import re
import unittest
import numpy as np
from get_gpt_model import FakeDataset, generate_model
import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.framework import core
paddle.enable_static()
def get_cuda_version():
result = os.popen("nvcc --version").read()
regex = r'release (\S+),'
match = re.search(regex, result)
if match:
num = str(match.group(1))
integer, decimal = num.split('.')
return int(integer) * 1000 + int(float(decimal) * 10)
else:
return -1
def apply_pass(use_amp=False, amp_dtype="bfloat16"):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_amp:
amp = strategy.amp
amp.enable = True
amp.dtype = amp_dtype
amp.level = "o2"
amp.custom_black_list = [
'c_softmax_with_cross_entropy',
'elementwise_div',
'reduce_sum',
]
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestShardingStage2WithNewEXE(unittest.TestCase):
def setUp(self):
self.batch_size = 2
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_amp=False, amp_dtype="bfloat16"):
reset_prog()
strategy = apply_pass(use_amp, amp_dtype)
# clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
clip = None
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("mp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_bf16(self, program):
num_bf16 = 0
num_fp16 = 0
num_fp32 = 0
for p in program.all_parameters():
if p.dtype == core.VarDesc.VarType.FP32:
num_fp32 += 1
if p.dtype == core.VarDesc.VarType.FP16:
num_fp16 += 1
if p.dtype == core.VarDesc.VarType.BF16:
num_bf16 += 1
self.assertEqual(num_bf16, 25)
self.assertEqual(num_fp16, 0)
self.assertEqual(num_fp32, 11)
def test_param_grad_fuse_overlap(self):
# std
mp_engine = self.get_engine(use_amp=False)
mp_history = mp_engine.fit(
self.dataset,
3,
epochs=1,
steps_per_epoch=self.batch_num,
log_freq=1,
batch_size=self.batch_size,
)
loss0 = mp_history.history['loss'][0]
# bf16
mp_bf16_engine = self.get_engine(use_amp=True)
if not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000:
return
mp_bf16_history = mp_bf16_engine.fit(
self.dataset,
3,
epochs=1,
steps_per_epoch=self.batch_num,
log_freq=1,
batch_size=self.batch_size,
)
loss1 = mp_bf16_history.history['loss'][0]
np.testing.assert_allclose(loss0, loss1, atol=1e-3, rtol=1e-2)
self.check_bf16(mp_bf16_engine.main_program)
if __name__ == "__main__":
unittest.main()
...@@ -38,7 +38,7 @@ def apply_pass(use_amp=False, level=None): ...@@ -38,7 +38,7 @@ def apply_pass(use_amp=False, level=None):
] ]
amp.init_loss_scaling = 32768 amp.init_loss_scaling = 32768
amp.use_fp16_guard = False amp.use_fp16_guard = False
amp.use_pure_fp16 = level in ["o2", "o3"] amp.level = level
amp.use_optimizer_fp16 = level == "o3" amp.use_optimizer_fp16 = level == "o3"
print("amp level: ", level) print("amp level: ", level)
return strategy return strategy
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import tempfile
import unittest
class TestAMPO2(unittest.TestCase):
def test_bf16(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "amp_o2_pass.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = (
[sys.executable, "-u"]
+ coverage_args
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
...@@ -13,13 +13,13 @@ ...@@ -13,13 +13,13 @@
# limitations under the License. # limitations under the License.
import os import os
# import yaml # import yaml
import unittest import unittest
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
class TestStrategy(unittest.TestCase): class TestStrategy(unittest.TestCase):
def test_default_config(self): def test_default_config(self):
strategy = auto.Strategy() strategy = auto.Strategy()
...@@ -29,6 +29,8 @@ class TestStrategy(unittest.TestCase): ...@@ -29,6 +29,8 @@ class TestStrategy(unittest.TestCase):
amp = strategy.amp amp = strategy.amp
self.assertEqual(amp.enable, False) self.assertEqual(amp.enable, False)
self.assertAlmostEqual(amp.dtype, "float16")
self.assertAlmostEqual(amp.level, "o1")
self.assertAlmostEqual(amp.init_loss_scaling, 32768.0) self.assertAlmostEqual(amp.init_loss_scaling, 32768.0)
self.assertEqual(amp.incr_every_n_steps, 1000) self.assertEqual(amp.incr_every_n_steps, 1000)
self.assertEqual(amp.decr_every_n_nan_or_inf, 2) self.assertEqual(amp.decr_every_n_nan_or_inf, 2)
...@@ -38,8 +40,7 @@ class TestStrategy(unittest.TestCase): ...@@ -38,8 +40,7 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(amp.custom_black_list, []) self.assertEqual(amp.custom_black_list, [])
self.assertEqual(amp.custom_white_list, []) self.assertEqual(amp.custom_white_list, [])
self.assertEqual(amp.custom_black_varnames, []) self.assertEqual(amp.custom_black_varnames, [])
self.assertEqual(amp.use_pure_fp16, False) self.assertEqual(amp.use_fp16_guard, False)
self.assertEqual(amp.use_fp16_guard, True)
self.assertEqual(amp.use_optimizer_fp16, False) self.assertEqual(amp.use_optimizer_fp16, False)
sharding = strategy.sharding sharding = strategy.sharding
...@@ -92,7 +93,6 @@ class TestStrategy(unittest.TestCase): ...@@ -92,7 +93,6 @@ class TestStrategy(unittest.TestCase):
amp.custom_white_list = ["x"] amp.custom_white_list = ["x"]
amp.custom_black_list = ["y"] amp.custom_black_list = ["y"]
amp.custom_black_varnames = ["z"] amp.custom_black_varnames = ["z"]
amp.use_pure_fp16 = True
amp.use_fp16_guard = False amp.use_fp16_guard = False
amp.use_optimizer_fp16 = True amp.use_optimizer_fp16 = True
self.assertEqual(amp.enable, True) self.assertEqual(amp.enable, True)
...@@ -105,7 +105,6 @@ class TestStrategy(unittest.TestCase): ...@@ -105,7 +105,6 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(amp.custom_white_list, ["x"]) self.assertEqual(amp.custom_white_list, ["x"])
self.assertEqual(amp.custom_black_list, ["y"]) self.assertEqual(amp.custom_black_list, ["y"])
self.assertEqual(amp.custom_black_varnames, ["z"]) self.assertEqual(amp.custom_black_varnames, ["z"])
self.assertEqual(amp.use_pure_fp16, True)
self.assertEqual(amp.use_fp16_guard, False) self.assertEqual(amp.use_fp16_guard, False)
self.assertEqual(amp.use_optimizer_fp16, True) self.assertEqual(amp.use_optimizer_fp16, True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册