Unverified commit 6f3c9643, authored by JZ-LIANG, committed by GitHub

Eb118 BF16 Adoption (#52827)

* pr1

* pr2

* pr3

* fixed unittest

* adopt for scale
Parent 8cbc75ca
......@@ -62,6 +62,8 @@ set_field_default_config(RECOMPUTE, "enable_tuning", False)
#########################################
AMP = "amp"
set_field_default_config(AMP, "enable", False)
set_field_default_config(AMP, "dtype", "float16")
set_field_default_config(AMP, "level", "o1")
set_field_default_config(AMP, "init_loss_scaling", 32768.0)
set_field_default_config(AMP, "incr_every_n_steps", 1000)
set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2)
......@@ -71,8 +73,8 @@ set_field_default_config(AMP, "use_dynamic_loss_scaling", True)
set_field_default_config(AMP, "custom_white_list", [])
set_field_default_config(AMP, "custom_black_list", [])
set_field_default_config(AMP, "custom_black_varnames", [])
set_field_default_config(AMP, "use_pure_fp16", False)
set_field_default_config(AMP, "use_fp16_guard", True)
set_field_default_config(AMP, "use_fp16_guard", False)
set_field_default_config(AMP, "use_bf16_guard", False)
set_field_default_config(AMP, "use_optimizer_fp16", False)
#########################################
......
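
For context on the new AMP fields above: a minimal usage sketch, assuming the auto-parallel Strategy API (paddle.distributed.fleet.auto.Strategy) exposes these fields as attributes with the defaults shown in this hunk.

# Sketch only: driving the BF16 AMP path through the strategy fields defined above.
# The exact Strategy surface is an assumption; field names mirror this diff.
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
amp = strategy.amp
amp.enable = True
amp.dtype = "bfloat16"      # default above is "float16"; bfloat16 is the newly adopted option
amp.level = "o2"            # default above is "o1"
amp.use_bf16_guard = False  # new switch, parallel to use_fp16_guard
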
......@@ -459,7 +459,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
check_variable_and_dtype(
Out_var,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'c_allreduce_sum',
)
......@@ -649,7 +649,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
check_variable_and_dtype(
Out_grad,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)
......@@ -691,12 +691,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
},
)
check_variable_and_dtype(
intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
intermediate_var_0,
'x',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64'],
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
......
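
A note on the recurring 'uint16' additions in this file and in dist_matmul.py below: bfloat16 variables carry core.VarDesc.VarType.BF16, which the legacy fluid dtype helpers report as the string 'uint16' (BF16 reuses the 16-bit unsigned storage type), so appending 'uint16' to each allowed-dtype list is what lets BF16 tensors pass check_variable_and_dtype and check_dtype. A small illustration, assuming convert_dtype behaves as described:

# Illustration only: why 'uint16' stands in for bfloat16 in these dtype lists.
from paddle.fluid import core
from paddle.fluid.data_feeder import convert_dtype

# The BF16 variable type is surfaced as 'uint16', so any check whose allowed
# list now contains 'uint16' will accept bfloat16 inputs as well.
assert convert_dtype(core.VarDesc.VarType.BF16) == 'uint16'
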
......@@ -20,7 +20,11 @@ from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import gradient_synchronization
from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related
from .common import (
set_comm_op_dist_attr_for_program,
naive_copy_op_dist_attr_for_program,
is_parameter_related,
)
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
......@@ -33,24 +37,39 @@ from paddle.fluid import core, unique_name
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from paddle.distributed.fleet.meta_optimizers.common import (
OpRole,
OP_ROLE_KEY,
OP_ROLE_VAR_KEY,
)
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
from .dist_default import DistributedDefaultImpl0
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op, build_dp_costs
from ..cost import (
build_comp_desc_from_dist_op,
build_comm_desc_from_dist_op,
build_dp_costs,
)
from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs
from ..cost import MatmulV2OpCost, MatmulOpCost, MulOpCost
from ..cost import MatmulV2GradOpCost, MatmulGradOpCost, MulGradOpCost
from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost, IdentityOpCost
from paddle.distributed.auto_parallel.cost.comm_op_cost import (
AllreduceSumOpCost,
IdentityOpCost,
)
def trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping):
if trans_x:
x_dims_mapping[-1], x_dims_mapping[-2] = x_dims_mapping[
-2], x_dims_mapping[-1]
x_dims_mapping[-1], x_dims_mapping[-2] = (
x_dims_mapping[-2],
x_dims_mapping[-1],
)
if trans_y:
y_dims_mapping[-1], y_dims_mapping[-2] = y_dims_mapping[
-2], y_dims_mapping[-1]
y_dims_mapping[-1], y_dims_mapping[-2] = (
y_dims_mapping[-2],
y_dims_mapping[-1],
)
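
A quick illustration of the helper above, with hypothetical dims mappings: it swaps the last two entries in place for whichever operand is transposed, keeping the sharding aligned with the matmul semantics; passing None for the untouched operand is fine, as the backward code below does.

# Hypothetical values, illustration only.
x_dims_mapping = [-1, 0, 1]  # batch replicated, row dim on mesh axis 0, column dim on axis 1
trans_x_y_dims_mapping(True, False, x_dims_mapping, None)
assert x_dims_mapping == [-1, 1, 0]  # last two entries swapped to reflect transpose_X
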
def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
......@@ -123,13 +142,17 @@ def _update_dims_mapping_for_matmul(dist_op):
for i in range(new_out_dims_mapping_len - 2):
broadcast_out_dims_mapping.append(out_dims_mapping[i])
compatible_dims_mapping = compute_compatible_dims_mapping([
broadcast_x_dims_mapping, broadcast_y_dims_mapping,
broadcast_out_dims_mapping
])
compatible_dims_mapping = compute_compatible_dims_mapping(
[
broadcast_x_dims_mapping,
broadcast_y_dims_mapping,
broadcast_out_dims_mapping,
]
)
if compatible_dims_mapping is None:
trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping,
y_dims_mapping)
trans_x_y_dims_mapping(
trans_x, trans_y, x_dims_mapping, y_dims_mapping
)
return False
for i in range(new_x_dims_mapping_len - 2):
......@@ -152,17 +175,20 @@ def _update_dims_mapping_for_matmul(dist_op):
# The following which uses negative index can be work
# when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, y_dims_mapping], [-1, -2])
[x_dims_mapping, y_dims_mapping], [-1, -2]
)
if dim_changed:
changed = True
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [-2, -2])
[x_dims_mapping, out_dims_mapping], [-2, -2]
)
if dim_changed:
changed = True
dim_changed = compute_compatible_and_update_dim_mapping(
[y_dims_mapping, out_dims_mapping], [-1, -1])
[y_dims_mapping, out_dims_mapping], [-1, -1]
)
if dim_changed:
changed = True
......@@ -202,7 +228,8 @@ def _is_auto_compatible_for_matmul(dist_op):
x_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(x_name))
y_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(y_name))
out_dims_mapping = copy.deepcopy(
op_dist_attr.get_output_dims_mapping(out_name))
op_dist_attr.get_output_dims_mapping(out_name)
)
x_dims_mapping_len = len(x_dims_mapping)
y_dims_mapping_len = len(y_dims_mapping)
out_dims_mapping_len = len(out_dims_mapping)
......@@ -234,22 +261,23 @@ def _is_auto_compatible_for_matmul(dist_op):
for i in range(out_dims_mapping_len - 2):
broadcast_out_dims_mapping.append(out_dims_mapping[i])
is_same = ((broadcast_x_dims_mapping == broadcast_y_dims_mapping)
and (broadcast_x_dims_mapping == broadcast_out_dims_mapping))
is_same = (broadcast_x_dims_mapping == broadcast_y_dims_mapping) and (
broadcast_x_dims_mapping == broadcast_out_dims_mapping
)
if not is_same:
return False
# The following which uses negative index can be work
# when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2
is_same = (x_dims_mapping[-1] == y_dims_mapping[-2])
is_same = x_dims_mapping[-1] == y_dims_mapping[-2]
if not is_same:
return False
is_same = (x_dims_mapping[-2] == out_dims_mapping[-2])
is_same = x_dims_mapping[-2] == out_dims_mapping[-2]
if not is_same:
return False
is_same = (y_dims_mapping[-1] == out_dims_mapping[-1])
is_same = y_dims_mapping[-1] == out_dims_mapping[-1]
if not is_same:
return False
......@@ -265,8 +293,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
backward_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(backward_op))
assert (
dist_attr is not None
), "backward op [{}] don't have dist attribute !".format(str(backward_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in dist_attr.process_mesh.processes:
......@@ -277,22 +306,26 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD')
assert 'Y@GRAD' in kwargs, "output [{}] is not given".format('Y@GRAD')
assert 'X@GRAD' in kwargs, "output [{}] is not given".format('X@GRAD')
assert len(
assert (
len(kwargs['Y']) == 1
), "row_parallel_embedding input Ids take 1 variable but got {}".format(
kwargs['Y']
) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format(
kwargs['Y'])
assert len(
)
assert (
len(kwargs['X']) == 1
), "row_parallel_embedding input Ids take 1 variable but got {}".format(
kwargs['X']
) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format(
kwargs['X'])
assert len(
kwargs['Out@GRAD']
) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format(
kwargs['Out'])
assert len(
)
assert (
len(kwargs['Out@GRAD']) == 1
), "row_parallel_embedding input Ids take 1 variable but got {}".format(
kwargs['Out']
)
assert (
len(kwargs['Y@GRAD']) == 1
), "row_parallel_embedding output Ids take 1 variable but got {}".format(
kwargs['Y@GRAD']
) == 1, "row_parallel_embedding output Ids take 1 variable but got {}".format(
kwargs['Y@GRAD'])
)
X_var = main_block.var(kwargs['X'][0])
Y_var = main_block._var_recursive(kwargs['Y'][0])
......@@ -302,7 +335,8 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
assert not is_parameter_related(
X_var.name, main_block
), "left operand(X) [{}] of dist matmul should not be parameter".format(
X_var.name)
X_var.name
)
X_var_dims_mapping = dist_attr.get_input_dims_mapping(X_var.name)
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name)
......@@ -339,28 +373,34 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
parallel_axis = Y_var_dim_mapping[0]
check_variable_and_dtype(
Out_grad, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_identity')
Out_grad,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_identity", 'tmp'])) + "@GRAD",
name=unique_name.generate_with_ignorable_key(
".".join(["c_identity", 'tmp'])
)
+ "@GRAD",
dtype=Out_grad.dtype,
shape=Out_grad.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=Out_grad.stop_gradient)
stop_gradient=Out_grad.stop_gradient,
)
# copy X_var's dist_attr to intermediate_var_0's dist_attr
out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name)
assert out_grad_dist_attr is not None
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
out_grad_dist_attr)
ctx.set_tensor_dist_attr_for_program(
intermediate_var_0, out_grad_dist_attr
)
group_ranks = _get_comm_group(process_mesh_group,
process_mesh_shape, parallel_axis,
rank_id)
group_ranks = _get_comm_group(
process_mesh_group, process_mesh_shape, parallel_axis, rank_id
)
group = new_process_group(group_ranks)
c_identity_op = main_block.append_op(
type='c_identity',
......@@ -371,20 +411,29 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: OpRole.Backward,
})
check_variable_and_dtype(intermediate_var_0, 'x',
['float16', 'float32', 'float64'],
'linear')
check_dtype(intermediate_var_0.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
set_comm_op_dist_attr_for_program(c_identity_op,
dist_attr.process_mesh,
out_grad_dist_attr, ctx)
},
)
check_variable_and_dtype(
intermediate_var_0,
'x',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
set_comm_op_dist_attr_for_program(
c_identity_op, dist_attr.process_mesh, out_grad_dist_attr, ctx
)
new_kwargs = copy.deepcopy(kwargs)
new_kwargs['Out@GRAD'] = [intermediate_var_0.name]
matmul_op_desc = copy_op_with_new_input_output(
ctx, main_block, backward_op, **new_kwargs)
ctx, main_block, backward_op, **new_kwargs
)
else:
# col parallel: matmul + allreduce
assert Y_var_dim_mapping[0] < 0
......@@ -397,28 +446,36 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
assert len(kwargs['X@GRAD']) == 1
X_grad = main_block.var(kwargs['X@GRAD'][0])
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_identity", 'tmp'])) + "@GRAD",
name=unique_name.generate_with_ignorable_key(
".".join(["c_identity", 'tmp'])
)
+ "@GRAD",
dtype=X_grad.dtype,
shape=X_grad.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=X_grad.stop_gradient)
stop_gradient=X_grad.stop_gradient,
)
X_grad_dist_attr = dist_attr.get_output_dist_attr(X_grad.name)
assert X_grad_dist_attr is not None
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
X_grad_dist_attr)
ctx.set_tensor_dist_attr_for_program(
intermediate_var_0, X_grad_dist_attr
)
new_kwargs['X@GRAD'] = [intermediate_var_0.name]
matmul_op_desc = copy_op_with_new_input_output(
ctx, main_block, backward_op, **new_kwargs)
ctx, main_block, backward_op, **new_kwargs
)
# NOTE (JZ-LIANG) trick to skip one allreduce if left operand has not grad
if has_x_grad:
group_ranks = _get_comm_group(process_mesh_group,
process_mesh_shape, parallel_axis,
rank_id)
group_ranks = _get_comm_group(
process_mesh_group,
process_mesh_shape,
parallel_axis,
rank_id,
)
group = new_process_group(group_ranks)
c_allreduce_sum_op = main_block.append_op(
type='c_allreduce_sum',
......@@ -428,15 +485,20 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: OpRole.Backward
})
set_comm_op_dist_attr_for_program(c_allreduce_sum_op,
OP_ROLE_KEY: OpRole.Backward,
},
)
set_comm_op_dist_attr_for_program(
c_allreduce_sum_op,
dist_attr.process_mesh,
X_grad_dist_attr, ctx)
X_grad_dist_attr,
ctx,
)
else:
# replicate
matmul_op_desc = copy_op_with_new_input_output(ctx, main_block,
backward_op, **kwargs)
matmul_op_desc = copy_op_with_new_input_output(
ctx, main_block, backward_op, **kwargs
)
# data parallel gradient synchronization
act_grad_names = [X_var.name]
......@@ -448,8 +510,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
if trans_x:
trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None)
gradient_synchronization(ctx, backward_op, act_grad_names, out_grad_names,
rank_id)
gradient_synchronization(
ctx, backward_op, act_grad_names, out_grad_names, rank_id
)
if trans_x:
trans_x_y_dims_mapping(True, False, X_var_dims_mapping, None)
......@@ -472,23 +535,25 @@ def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
if size <= 1 or axis in dim_mapping:
pass
else:
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology, axis, rank_id)
group_ranks = _get_comm_group(
process_mesh.processes, process_mesh.topology, axis, rank_id
)
sync_group = new_process_group(group_ranks)
startup_block.append_op(type='c_broadcast',
startup_block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': sync_group.id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
OP_ROLE_KEY: OpRole.Forward,
},
)
class DistributedMatmul(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedMatmul, self).__init__(op_type)
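
Before the per-implementation classes below, a toy NumPy sketch (editorial, with made-up shapes) of the two sharding schemes they realize: column parallel (DistributedMatmulImpl0) runs c_identity on X and a local matmul against a column slice of Y, while row parallel (DistributedMatmulImpl1) multiplies row slices of Y with matching column slices of X and then c_allreduce_sums the partial results.

# Editorial sketch; two "ranks" emulated with array slices.
import numpy as np

X = np.random.rand(4, 6).astype("float32")
W = np.random.rand(6, 8).astype("float32")

# Column parallel: each rank holds a column slice of W and the full X, so the
# local matmuls produce disjoint column slices of the output.
out_col = np.concatenate([X @ w for w in np.split(W, 2, axis=1)], axis=1)

# Row parallel: each rank holds a row slice of W and the matching column slice
# of X; the partial products are summed (the c_allreduce_sum step).
out_row = sum(x @ w for x, w in zip(np.split(X, 2, axis=1), np.split(W, 2, axis=0)))

assert np.allclose(out_col, X @ W, atol=1e-4)
assert np.allclose(out_row, X @ W, atol=1e-4)
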
......@@ -498,7 +563,6 @@ register_distributed_operator_impl_container(DistributedMatmul("matmul"))
# ColumnParallel
class DistributedMatmulImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl0, self).__init__(name)
self._forward_implemented = True
......@@ -521,7 +585,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
main_block = backward_op.block
vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0])
backward_op.input("Y")[0]
)
# col parallel: matmul + allreduce
assert Y_var_dim_mapping[0] < 0
parallel_axis = Y_var_dim_mapping[1]
......@@ -531,13 +596,14 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
assert len(backward_op.output("X@GRAD")) == 1
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulGradOpCost, ctx,
processes, desc_mapping,
cluster)
cost_mapping = build_comp_costs_from_descs(
MatmulGradOpCost, ctx, processes, desc_mapping, cluster
)
res.append(cost_mapping)
# calc comm op cost
......@@ -550,40 +616,52 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes,
c_allreduce_sum_desc_mapping, cluster)
AllreduceSumOpCost,
ctx,
processes,
c_allreduce_sum_desc_mapping,
cluster,
)
res.append(comm_op_cost_list)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
if (
batch_size_axis > -1
and mesh_shape[batch_size_axis] > 1
and is_parameter_related(backward_op.input("Y")[0], main_block)
):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
build_dp_costs(
res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulOpCost, ctx, processes,
desc_mapping, cluster)
cost_mapping = build_comp_costs_from_descs(
MatmulOpCost, ctx, processes, desc_mapping, cluster
)
# calc comm op cost
serial_op = dist_op.serial_op
vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-1]
serial_op.input("Y")[0]
)[-1]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.input("X")
c_identity_desc_mapping = build_comm_desc_from_dist_op(
......@@ -592,10 +670,12 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster)
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
)
res_cost = [comm_op_cost_list, cost_mapping]
return res_cost
......@@ -606,16 +686,19 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = copy.deepcopy(
op_dist_attr.get_input_dims_mapping(x_name))
op_dist_attr.get_input_dims_mapping(x_name)
)
y_dims_mapping = copy.deepcopy(
op_dist_attr.get_input_dims_mapping(y_name))
op_dist_attr.get_input_dims_mapping(y_name)
)
trans_x = op_desc.attr('transpose_X')
trans_y = op_desc.attr('transpose_Y')
trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
y_dims_mapping[-1]):
y_dims_mapping[-1]
):
return False
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
......@@ -635,8 +718,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
if (not self.is_input_compatible(dist_op)) or (
not self.is_output_compatible(dist_op)
):
return False
if not _is_auto_compatible_for_matmul(dist_op):
return False
......@@ -661,28 +745,33 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op))
assert (
op_dist_attr is not None
), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id)
rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id
)
# check validation of inputs / outputs
for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format(
input_name)
input_name
)
assert len(kwargs[input_name]) == len(
src_op.desc.input(input_name)
), "number of tensor for input [{}] is not match".format(input_name)
for output_name in src_op.desc.output_names():
assert output_name in kwargs, "input [{}] is not given".format(
output_name)
output_name
)
assert len(kwargs[output_name]) == len(
src_op.desc.output(output_name)
), "number of tensor for input [{}] is not match".format(
output_name)
output_name
)
X_var = main_block.var(kwargs['X'][0])
Weight_var = main_block.var(kwargs['Y'][0])
......@@ -692,18 +781,24 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
# TODO infer logic comm presentation
matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-1]
Weight_var.name
)[-1]
if trans_y:
matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-2]
assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_col_dim_mapping)
Weight_var.name
)[-2]
assert (
matmul_col_dim_mapping >= 0
), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_col_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_col_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
parallel_axis, rank_id)
group_ranks = _get_comm_group(
process_mesh_group, process_mesh_shape, parallel_axis, rank_id
)
group = new_process_group(group_ranks)
# infer new var shape with op dist attr
......@@ -711,31 +806,39 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
assert x_tensor_dist_attr is not None
identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
assert identity_var_dist_attr is not None
ref_shape_x = infer_shape(main_block, X_var, x_tensor_dist_attr,
identity_var_dist_attr)
ref_shape_x = infer_shape(
main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
)
# infer out var shape with op dist attr
out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
assert out_tensor_dist_attr is not None
out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert out_var_dist_attr is not None
ref_shape_out = infer_shape(main_block, Out_var, out_tensor_dist_attr,
out_var_dist_attr)
ref_shape_out = infer_shape(
main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
)
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_identity", 'tmp'])),
name=unique_name.generate_with_ignorable_key(
".".join(["c_identity", 'tmp'])
),
dtype=X_var.dtype,
shape=X_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=X_var.stop_gradient)
stop_gradient=X_var.stop_gradient,
)
# set intermediate_var_0's dist_attr with X_var's dist_attr
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
identity_var_dist_attr)
ctx.set_tensor_dist_attr_for_program(
intermediate_var_0, identity_var_dist_attr
)
check_variable_and_dtype(
X_var, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity')
X_var,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)
c_identity_op = main_block.append_op(
type='c_identity',
......@@ -745,26 +848,34 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role')
})
OP_ROLE_KEY: src_op.attr('op_role'),
},
)
if intermediate_var_0.shape != ref_shape_x:
intermediate_var_0.desc.set_shape(ref_shape_x)
check_variable_and_dtype(intermediate_var_0, 'x',
['float16', 'float32', 'float64'], 'linear')
check_dtype(intermediate_var_0.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
check_variable_and_dtype(
intermediate_var_0,
'x',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
attrs = {
'transpose_X': trans_x,
'transpose_Y': trans_y,
'alpha': 1,
OP_ROLE_KEY: src_op.attr('op_role')
OP_ROLE_KEY: src_op.attr('op_role'),
}
inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
matmul_op = main_block.append_op(type='matmul',
inputs=inputs,
outputs={'Out': Out_var},
attrs=attrs)
matmul_op = main_block.append_op(
type='matmul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs
)
if Out_var.shape != ref_shape_out:
Out_var.desc.set_shape(ref_shape_out)
......@@ -778,13 +889,16 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
input_varname = c_identity_op.desc.input_arg_names()[0]
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
identity_op_dist_attr.set_input_dist_attr(input_varname,
input_dist_attr)
op_dist_attr
)
identity_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
# output
output_varname = c_identity_op.desc.output_arg_names()[0]
identity_op_dist_attr.set_output_dist_attr(output_varname,
input_dist_attr)
identity_op_dist_attr.set_output_dist_attr(
output_varname, input_dist_attr
)
# set op dist attr
ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)
......@@ -797,31 +911,39 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
for input_varname in matmul_op.desc.input_arg_names():
if input_varname in src_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(
input_varname)
input_varname
)
assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
matmul_op_dist_attr.set_input_dist_attr(input_varname,
input_dist_attr)
op_dist_attr
)
matmul_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
else:
input_var = main_block.var(input_varname)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
input_var)
matmul_op_dist_attr.set_input_dist_attr(input_varname,
tensor_dist_attr)
input_var
)
matmul_op_dist_attr.set_input_dist_attr(
input_varname, tensor_dist_attr
)
# output
output_varname = matmul_op.desc.output_arg_names()[0]
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
matmul_op_dist_attr.set_output_dist_attr(output_varname,
output_dist_attr)
op_dist_attr
)
matmul_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
# set op dist attr
ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)
# init param sync
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)
_init_param_sync(
Weight_var, dist_op_context, startup_block, ctx, rank_id
)
@staticmethod
def backward(ctx, *args, **kwargs):
......@@ -830,7 +952,6 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
# RowParallel
class DistributedMatmulImpl1(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl1, self).__init__(name)
self._forward_implemented = True
......@@ -853,7 +974,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
main_block = backward_op.block
vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0])
backward_op.input("Y")[0]
)
assert Y_var_dim_mapping[1] < 0
parallel_axis = Y_var_dim_mapping[0]
......@@ -866,50 +988,60 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
parallel_axis=parallel_axis,
)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster)
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
)
res.append(comm_op_cost_list)
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
cost_mapping = build_comp_costs_from_descs(MatmulGradOpCost, ctx,
processes, desc_mapping,
cluster)
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
cost_mapping = build_comp_costs_from_descs(
MatmulGradOpCost, ctx, processes, desc_mapping, cluster
)
res.append(cost_mapping)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
if (
batch_size_axis > -1
and mesh_shape[batch_size_axis] > 1
and is_parameter_related(backward_op.input("Y")[0], main_block)
):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
build_dp_costs(
res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulOpCost, ctx, processes,
desc_mapping, cluster)
cost_mapping = build_comp_costs_from_descs(
MatmulOpCost, ctx, processes, desc_mapping, cluster
)
# calc comm op cost
serial_op = dist_op.serial_op
vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-2]
serial_op.input("Y")[0]
)[-2]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.output("Out")
......@@ -919,11 +1051,16 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping,
cluster)
AllreduceSumOpCost,
ctx,
processes,
c_allreduce_sum_desc_mapping,
cluster,
)
res_cost = [cost_mapping, comm_op_cost_list]
......@@ -935,16 +1072,19 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = copy.deepcopy(
op_dist_attr.get_input_dims_mapping(x_name))
op_dist_attr.get_input_dims_mapping(x_name)
)
y_dims_mapping = copy.deepcopy(
op_dist_attr.get_input_dims_mapping(y_name))
op_dist_attr.get_input_dims_mapping(y_name)
)
trans_x = op_desc.attr('transpose_X')
trans_y = op_desc.attr('transpose_Y')
trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
if is_dim_replicate(x_dims_mapping[-1]):
return False
if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
y_dims_mapping[-1]):
y_dims_mapping[-1]
):
return False
# Other dimensions must be replicate except the batch dimension
for mapping in x_dims_mapping[1:-1]:
......@@ -966,8 +1106,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
if (not self.is_input_compatible(dist_op)) or (
not self.is_output_compatible(dist_op)
):
return False
if not _is_auto_compatible_for_matmul(dist_op):
return False
......@@ -992,28 +1133,33 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op))
assert (
op_dist_attr is not None
), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id)
rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id
)
# check validation of inputs / outputs
for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format(
input_name)
input_name
)
assert len(kwargs[input_name]) == len(
src_op.desc.input(input_name)
), "number of tensor for input [{}] is not match".format(input_name)
for output_name in src_op.desc.output_names():
assert output_name in kwargs, "input [{}] is not given".format(
output_name)
output_name
)
assert len(kwargs[output_name]) == len(
src_op.desc.output(output_name)
), "number of tensor for input [{}] is not match".format(
output_name)
output_name
)
X_var = main_block.var(kwargs['X'][0])
Weight_var = main_block.var(kwargs['Y'][0])
......@@ -1023,29 +1169,40 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
# TODO infer logic comm presentation
matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-2]
Weight_var.name
)[-2]
if trans_y:
matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-1]
assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_row_dim_mapping)
Weight_var.name
)[-1]
assert (
matmul_row_dim_mapping >= 0
), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_row_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_row_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
parallel_axis, rank_id)
group_ranks = _get_comm_group(
process_mesh_group, process_mesh_shape, parallel_axis, rank_id
)
group = new_process_group(group_ranks)
check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'],
'linear')
check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'],
'linear')
check_variable_and_dtype(
X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
)
check_dtype(
X_var.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
attrs = {
'transpose_X': trans_x,
'transpose_Y': trans_y,
'alpha': 1,
OP_ROLE_KEY: src_op.attr('op_role')
OP_ROLE_KEY: src_op.attr('op_role'),
}
inputs = {'X': X_var, 'Y': Weight_var}
......@@ -1054,27 +1211,33 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
assert out_tensor_dist_attr is not None
out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert out_var_dist_attr is not None
ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr,
out_var_dist_attr)
ref_shape = infer_shape(
main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
)
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_allreduce_sum", 'tmp'])),
name=unique_name.generate_with_ignorable_key(
".".join(["c_allreduce_sum", 'tmp'])
),
shape=Out_var.shape,
dtype=Out_var.dtype,
type=Out_var.type,
lod_level=Out_var.lod_level,
persistable=False,
is_data=False,
need_check_feed=Out_var.desc.need_check_feed())
need_check_feed=Out_var.desc.need_check_feed(),
)
# set intermediate_var_0's dist_attr with Out_var's dist_attr
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
out_var_dist_attr)
ctx.set_tensor_dist_attr_for_program(
intermediate_var_0, out_var_dist_attr
)
matmul_op = main_block.append_op(type='matmul',
matmul_op = main_block.append_op(
type='matmul',
inputs=inputs,
outputs={'Out': intermediate_var_0},
attrs=attrs)
attrs=attrs,
)
if intermediate_var_0.shape != ref_shape:
intermediate_var_0.desc.set_shape(ref_shape)
......@@ -1086,8 +1249,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role')
})
OP_ROLE_KEY: src_op.attr('op_role'),
},
)
if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape)
......@@ -1100,15 +1264,19 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
for input_varname in matmul_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
matmul_op_dist_attr.set_input_dist_attr(input_varname,
input_dist_attr)
op_dist_attr
)
matmul_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
output_varname = matmul_op.desc.output_arg_names()[0]
output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
matmul_op_dist_attr.set_output_dist_attr(output_varname,
output_dist_attr)
op_dist_attr
)
matmul_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)
# allreduce
......@@ -1120,21 +1288,26 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
input_var = main_block.var(input_varname)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
assert tensor_dist_attr is not None
allreduce_op_dist_attr.set_input_dist_attr(input_varname,
tensor_dist_attr)
allreduce_op_dist_attr.set_input_dist_attr(
input_varname, tensor_dist_attr
)
for output_varname in c_allreduce_sum_op.desc.output_arg_names():
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
allreduce_op_dist_attr.set_output_dist_attr(output_varname,
output_dist_attr)
ctx.set_op_dist_attr_for_program(c_allreduce_sum_op,
allreduce_op_dist_attr)
op_dist_attr
)
allreduce_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(
c_allreduce_sum_op, allreduce_op_dist_attr
)
# init param sync
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)
_init_param_sync(
Weight_var, dist_op_context, startup_block, ctx, rank_id
)
@staticmethod
def backward(ctx, *args, **kwargs):
......@@ -1143,7 +1316,6 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
# ReplicateParallel
class DistributedMatmulImpl2(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl2, self).__init__(name)
......@@ -1164,38 +1336,45 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
vars = main_block.vars
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulGradOpCost, ctx,
processes, desc_mapping,
cluster)
cost_mapping = build_comp_costs_from_descs(
MatmulGradOpCost, ctx, processes, desc_mapping, cluster
)
res.append(cost_mapping)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
if (
batch_size_axis > -1
and mesh_shape[batch_size_axis] > 1
and is_parameter_related(backward_op.input("Y")[0], main_block)
):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
build_dp_costs(
res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulOpCost, ctx, processes,
desc_mapping, cluster)
cost_mapping = build_comp_costs_from_descs(
MatmulOpCost, ctx, processes, desc_mapping, cluster
)
res_cost = [cost_mapping]
return res_cost
......@@ -1211,13 +1390,15 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
x_dims_mapping[-2]):
x_dims_mapping[-2]
):
return False
if is_dim_shard(y_dims_mapping[-1]):
return False
if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
y_dims_mapping[-2]):
y_dims_mapping[-2]
):
return False
return True
......@@ -1231,14 +1412,16 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
if is_dim_shard(out_dims_mapping[-1]):
return False
if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
out_dims_mapping[-2]):
out_dims_mapping[-2]
):
return False
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
if (not self.is_input_compatible(dist_op)) or (
not self.is_output_compatible(dist_op)
):
return False
if not _is_auto_compatible_for_matmul(dist_op):
......@@ -1262,16 +1445,18 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
_right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
register_distributed_operator_impl("matmul",
DistributedMatmulImpl0("column_parallel"))
register_distributed_operator_impl("matmul",
DistributedMatmulImpl1("row_parallel"))
register_distributed_operator_impl("matmul",
DistributedMatmulImpl2("replicate_parallel"))
register_distributed_operator_impl(
"matmul", DistributedMatmulImpl0("column_parallel")
)
register_distributed_operator_impl(
"matmul", DistributedMatmulImpl1("row_parallel")
)
register_distributed_operator_impl(
"matmul", DistributedMatmulImpl2("replicate_parallel")
)
class DistributedMatmulV2(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedMatmulV2, self).__init__(op_type)
......@@ -1281,7 +1466,6 @@ register_distributed_operator_impl_container(DistributedMatmulV2("matmul_v2"))
# ColumnParallel
class DistributedMatmulV2Impl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulV2Impl0, self).__init__(name)
self._forward_implemented = True
......@@ -1304,7 +1488,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
main_block = backward_op.block
vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0])
backward_op.input("Y")[0]
)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
# col parallel: matmul + allreduce
......@@ -1318,12 +1503,13 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
assert len(backward_op.output("X@GRAD")) == 1
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
cost_mapping = build_comp_costs_from_descs(MatmulV2GradOpCost, ctx,
processes, desc_mapping,
cluster)
cost_mapping = build_comp_costs_from_descs(
MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
)
res.append(cost_mapping)
# calc comm op cost
......@@ -1336,45 +1522,55 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes,
c_allreduce_sum_desc_mapping, cluster)
AllreduceSumOpCost,
ctx,
processes,
c_allreduce_sum_desc_mapping,
cluster,
)
res.append(comm_op_cost_list)
# need gradient allreduce
process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
if (
batch_size_axis > -1
and mesh_shape[batch_size_axis] > 1
and is_parameter_related(backward_op.input("Y")[0], main_block)
):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
build_dp_costs(
res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
# TODO: trans shape if trans_x or trans_y is True
comp_desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
comp_desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
comp_cost_mapping = build_comp_costs_from_descs(MatmulV2OpCost, ctx,
processes,
comp_desc_mapping,
cluster)
comp_cost_mapping = build_comp_costs_from_descs(
MatmulV2OpCost, ctx, processes, comp_desc_mapping, cluster
)
# calc comm op cost
serial_op = dist_op.serial_op
vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-1]
serial_op.input("Y")[0]
)[-1]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.input("X")
......@@ -1384,9 +1580,11 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster)
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
)
res_cost = [comm_op_cost_list, comp_cost_mapping]
return res_cost
......@@ -1397,16 +1595,19 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = copy.deepcopy(
op_dist_attr.get_input_dims_mapping(x_name))
op_dist_attr.get_input_dims_mapping(x_name)
)
y_dims_mapping = copy.deepcopy(
op_dist_attr.get_input_dims_mapping(y_name))
op_dist_attr.get_input_dims_mapping(y_name)
)
trans_x = op_desc.attr('trans_x')
trans_y = op_desc.attr('trans_y')
trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
y_dims_mapping[-1]):
y_dims_mapping[-1]
):
return False
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
......@@ -1426,8 +1627,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
if (not self.is_input_compatible(dist_op)) or (
not self.is_output_compatible(dist_op)
):
return False
if not _is_auto_compatible_for_matmul(dist_op):
return False
......@@ -1452,28 +1654,33 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op))
assert (
op_dist_attr is not None
), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id)
rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id
)
# check validation of inputs / outputs
for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format(
input_name)
input_name
)
assert len(kwargs[input_name]) == len(
src_op.desc.input(input_name)
), "number of tensor for input [{}] is not match".format(input_name)
for output_name in src_op.desc.output_names():
assert output_name in kwargs, "input [{}] is not given".format(
output_name)
output_name
)
assert len(kwargs[output_name]) == len(
src_op.desc.output(output_name)
), "number of tensor for input [{}] is not match".format(
output_name)
output_name
)
X_var = main_block.var(kwargs['X'][0])
Weight_var = main_block._var_recursive(kwargs['Y'][0])
......@@ -1483,18 +1690,24 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
# TODO infer logic comm presentation
matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-1]
Weight_var.name
)[-1]
if trans_y:
matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-2]
assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_col_dim_mapping)
Weight_var.name
)[-2]
assert (
matmul_col_dim_mapping >= 0
), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_col_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_col_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
parallel_axis, rank_id)
group_ranks = _get_comm_group(
process_mesh_group, process_mesh_shape, parallel_axis, rank_id
)
group = new_process_group(group_ranks)
# infer new var shape with op dist attr
......@@ -1502,31 +1715,39 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
assert x_tensor_dist_attr is not None
identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
assert identity_var_dist_attr is not None
ref_shape_x = infer_shape(main_block, X_var, x_tensor_dist_attr,
identity_var_dist_attr)
ref_shape_x = infer_shape(
main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
)
# infer out var shape with op dist attr
out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
assert out_tensor_dist_attr is not None
out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert out_var_dist_attr is not None
ref_shape_out = infer_shape(main_block, Out_var, out_tensor_dist_attr,
out_var_dist_attr)
ref_shape_out = infer_shape(
main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
)
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_identity", 'tmp'])),
name=unique_name.generate_with_ignorable_key(
".".join(["c_identity", 'tmp'])
),
dtype=X_var.dtype,
shape=X_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=X_var.stop_gradient)
stop_gradient=X_var.stop_gradient,
)
# set intermediate_var_0's dist_attr with X_var's dist_attr
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
identity_var_dist_attr)
ctx.set_tensor_dist_attr_for_program(
intermediate_var_0, identity_var_dist_attr
)
check_variable_and_dtype(
X_var, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity')
X_var,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)
c_identity_op = main_block.append_op(
type='c_identity',
inputs={'X': [X_var]},
......@@ -1536,24 +1757,35 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role'),
})
},
)
if intermediate_var_0.shape != ref_shape_x:
intermediate_var_0.desc.set_shape(ref_shape_x)
check_variable_and_dtype(intermediate_var_0, 'x',
['float16', 'float32', 'float64'], 'linear')
check_dtype(intermediate_var_0.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
check_variable_and_dtype(
intermediate_var_0,
'x',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
attrs = {
'trans_x': trans_x,
'trans_y': trans_y,
OP_ROLE_KEY: src_op.attr('op_role')
OP_ROLE_KEY: src_op.attr('op_role'),
}
inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
matmul_v2_op = main_block.append_op(type='matmul_v2',
matmul_v2_op = main_block.append_op(
type='matmul_v2',
inputs=inputs,
outputs={'Out': Out_var},
attrs=attrs)
attrs=attrs,
)
if Out_var.shape != ref_shape_out:
Out_var.desc.set_shape(ref_shape_out)
......@@ -1567,13 +1799,16 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
input_varname = c_identity_op.desc.input_arg_names()[0]
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
identity_op_dist_attr.set_input_dist_attr(input_varname,
input_dist_attr)
op_dist_attr
)
identity_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
# output
output_varname = c_identity_op.desc.output_arg_names()[0]
identity_op_dist_attr.set_output_dist_attr(output_varname,
input_dist_attr)
identity_op_dist_attr.set_output_dist_attr(
output_varname, input_dist_attr
)
ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)
# matmulv2
......@@ -1584,29 +1819,37 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
for input_varname in matmul_v2_op.desc.input_arg_names():
if input_varname in src_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(
input_varname)
input_varname
)
assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
op_dist_attr
)
matmulv2_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr)
input_varname, input_dist_attr
)
else:
input_var = main_block.var(input_varname)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
input_var)
input_var
)
matmulv2_op_dist_attr.set_input_dist_attr(
input_varname, tensor_dist_attr)
input_varname, tensor_dist_attr
)
for output_varname in matmul_v2_op.desc.output_arg_names():
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
matmulv2_op_dist_attr.set_output_dist_attr(output_varname,
output_dist_attr)
op_dist_attr
)
matmulv2_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)
# init param sync
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)
_init_param_sync(
Weight_var, dist_op_context, startup_block, ctx, rank_id
)
@staticmethod
def backward(ctx, *args, **kwargs):
......@@ -1615,7 +1858,6 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
# RowParallel
class DistributedMatmulV2Impl1(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulV2Impl1, self).__init__(name)
self._forward_implemented = True
......@@ -1638,7 +1880,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
main_block = backward_op.block
vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0])
backward_op.input("Y")[0]
)
assert Y_var_dim_mapping[1] < 0
parallel_axis = Y_var_dim_mapping[0]
......@@ -1653,50 +1896,59 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster)
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
)
res.append(comm_op_cost_list)
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
cost_mapping = build_comp_costs_from_descs(MatmulV2GradOpCost, ctx,
processes, desc_mapping,
cluster)
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
cost_mapping = build_comp_costs_from_descs(
MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
)
res.append(cost_mapping)
# need gradient allreduce
process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
if (
batch_size_axis > -1
and mesh_shape[batch_size_axis] > 1
and is_parameter_related(backward_op.input("Y")[0], main_block)
):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
build_dp_costs(
res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res
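# A small helper sketch (not from this patch) of the condition used above to
# decide whether the parameter gradient needs an extra data-parallel
# allreduce cost added via build_dp_costs.
def needs_grad_allreduce_ref(batch_dim_mapping, mesh_shape, y_is_parameter):
    # the batch axis of X is sharded over a mesh dimension larger than 1
    # and Y is a trainable parameter
    return (
        batch_dim_mapping > -1
        and mesh_shape[batch_dim_mapping] > 1
        and y_is_parameter
    )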
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulV2OpCost, ctx,
processes, desc_mapping,
cluster)
cost_mapping = build_comp_costs_from_descs(
MatmulV2OpCost, ctx, processes, desc_mapping, cluster
)
# calc comm op cost
serial_op = dist_op.serial_op
vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-2]
serial_op.input("Y")[0]
)[-2]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.output("Out")
......@@ -1706,11 +1958,16 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping,
cluster)
AllreduceSumOpCost,
ctx,
processes,
c_allreduce_sum_desc_mapping,
cluster,
)
res_cost = [cost_mapping, comm_op_cost_list]
return res_cost
......@@ -1721,16 +1978,19 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = copy.deepcopy(
op_dist_attr.get_input_dims_mapping(x_name))
op_dist_attr.get_input_dims_mapping(x_name)
)
y_dims_mapping = copy.deepcopy(
op_dist_attr.get_input_dims_mapping(y_name))
op_dist_attr.get_input_dims_mapping(y_name)
)
trans_x = op_desc.attr('trans_x')
trans_y = op_desc.attr('trans_y')
trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
if is_dim_replicate(x_dims_mapping[-1]):
return False
if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
y_dims_mapping[-1]):
y_dims_mapping[-1]
):
return False
# Other dimensions must be replicate except the batch dimension
for mapping in x_dims_mapping[1:-1]:
......@@ -1752,8 +2012,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
if (not self.is_input_compatible(dist_op)) or (
not self.is_output_compatible(dist_op)
):
return False
if not _is_auto_compatible_for_matmul(dist_op):
return False
......@@ -1778,28 +2039,33 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op))
assert (
op_dist_attr is not None
), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id)
rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id
)
# check validation of inputs / outputs
for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format(
input_name)
input_name
)
assert len(kwargs[input_name]) == len(
src_op.desc.input(input_name)
), "number of tensor for input [{}] is not match".format(input_name)
for output_name in src_op.desc.output_names():
assert output_name in kwargs, "input [{}] is not given".format(
output_name)
output_name
)
assert len(kwargs[output_name]) == len(
src_op.desc.output(output_name)
), "number of tensor for input [{}] is not match".format(
output_name)
output_name
)
X_var = main_block.var(kwargs['X'][0])
Weight_var = main_block._var_recursive(kwargs['Y'][0])
......@@ -1809,28 +2075,39 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
# TODO infer logic comm presentation
matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-2]
Weight_var.name
)[-2]
if trans_y:
matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-1]
assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_row_dim_mapping)
Weight_var.name
)[-1]
assert (
matmul_row_dim_mapping >= 0
), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_row_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_row_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
parallel_axis, rank_id)
group_ranks = _get_comm_group(
process_mesh_group, process_mesh_shape, parallel_axis, rank_id
)
group = new_process_group(group_ranks)
check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'],
'linear')
check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'],
'linear')
check_variable_and_dtype(
X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
)
check_dtype(
X_var.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
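# The 'uint16' entry above is how bfloat16 tensors show up in these dtype
# checks. A small numpy sketch (truncating, not round-to-nearest) of why that
# representation works: bfloat16 keeps the upper 16 bits of an IEEE float32.
import numpy as np
def float32_to_bf16_bits(x):
    # view the float32 bit pattern and keep the high half
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(
        np.uint16
    )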
attrs = {
'trans_x': trans_x,
'trans_y': trans_y,
OP_ROLE_KEY: src_op.attr('op_role')
OP_ROLE_KEY: src_op.attr('op_role'),
}
inputs = {'X': X_var, 'Y': Weight_var}
......@@ -1839,27 +2116,33 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
assert out_tensor_dist_attr is not None
out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert out_var_dist_attr is not None
ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr,
out_var_dist_attr)
ref_shape = infer_shape(
main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
)
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_allreduce_sum", 'tmp'])),
name=unique_name.generate_with_ignorable_key(
".".join(["c_allreduce_sum", 'tmp'])
),
shape=Out_var.shape,
dtype=Out_var.dtype,
type=Out_var.type,
lod_level=Out_var.lod_level,
persistable=False,
is_data=False,
need_check_feed=Out_var.desc.need_check_feed())
need_check_feed=Out_var.desc.need_check_feed(),
)
# set intermediate_var_0's dist_attr with Out_var's dist_attr
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
out_var_dist_attr)
ctx.set_tensor_dist_attr_for_program(
intermediate_var_0, out_var_dist_attr
)
matmul_v2_op = main_block.append_op(type='matmul_v2',
matmul_v2_op = main_block.append_op(
type='matmul_v2',
inputs=inputs,
outputs={'Out': intermediate_var_0},
attrs=attrs)
attrs=attrs,
)
if intermediate_var_0.shape != ref_shape:
intermediate_var_0.desc.set_shape(ref_shape)
......@@ -1871,8 +2154,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role')
})
OP_ROLE_KEY: src_op.attr('op_role'),
},
)
if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape)
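# A numpy sketch (not part of this patch) of the row-parallel forward built
# above: each rank multiplies its column slice of X by its row slice of Y,
# and the partial products are summed by c_allreduce_sum.
import numpy as np
def row_parallel_matmul_ref(x_col_shards, y_row_shards):
    partials = [x @ y for x, y in zip(x_col_shards, y_row_shards)]
    return sum(partials)  # what every rank holds after the allreduce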
......@@ -1885,15 +2169,19 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
for input_varname in matmul_v2_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
matmulv2_op_dist_attr.set_input_dist_attr(input_varname,
input_dist_attr)
op_dist_attr
)
matmulv2_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
output_varname = matmul_v2_op.desc.output_arg_names()[0]
output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
matmulv2_op_dist_attr.set_output_dist_attr(output_varname,
output_dist_attr)
op_dist_attr
)
matmulv2_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)
# allreduce
......@@ -1905,21 +2193,26 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
input_var = main_block.var(input_varname)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
assert tensor_dist_attr is not None
allreduce_op_dist_attr.set_input_dist_attr(input_varname,
tensor_dist_attr)
allreduce_op_dist_attr.set_input_dist_attr(
input_varname, tensor_dist_attr
)
for output_varname in c_allreduce_sum_op.desc.output_arg_names():
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
allreduce_op_dist_attr.set_output_dist_attr(output_varname,
output_dist_attr)
ctx.set_op_dist_attr_for_program(c_allreduce_sum_op,
allreduce_op_dist_attr)
op_dist_attr
)
allreduce_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(
c_allreduce_sum_op, allreduce_op_dist_attr
)
# init param sync
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)
_init_param_sync(
Weight_var, dist_op_context, startup_block, ctx, rank_id
)
@staticmethod
def backward(ctx, *args, **kwargs):
......@@ -1928,7 +2221,6 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
# ReplicateParallel
class DistributedMatmulV2Impl2(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulV2Impl2, self).__init__(name)
......@@ -1950,38 +2242,44 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
process_mesh = dist_attr.process_mesh
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulV2GradOpCost, ctx,
processes, desc_mapping,
cluster)
cost_mapping = build_comp_costs_from_descs(
MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
)
res.append(cost_mapping)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
if (
batch_size_axis > -1
and mesh_shape[batch_size_axis] > 1
and is_parameter_related(backward_op.input("Y")[0], main_block)
):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
build_dp_costs(
res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulV2OpCost, ctx,
processes, desc_mapping,
cluster)
cost_mapping = build_comp_costs_from_descs(
MatmulV2OpCost, ctx, processes, desc_mapping, cluster
)
res_cost = [cost_mapping]
......@@ -1998,13 +2296,15 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
x_dims_mapping[-2]):
x_dims_mapping[-2]
):
return False
if is_dim_shard(y_dims_mapping[-1]):
return False
if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
y_dims_mapping[-2]):
y_dims_mapping[-2]
):
return False
return True
......@@ -2019,14 +2319,16 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
if is_dim_shard(out_dims_mapping[-1]):
return False
if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
out_dims_mapping[-2]):
out_dims_mapping[-2]
):
return False
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
if (not self.is_input_compatible(dist_op)) or (
not self.is_output_compatible(dist_op)
):
return False
if not _is_auto_compatible_for_matmul(dist_op):
......@@ -2050,16 +2352,18 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
_right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
register_distributed_operator_impl("matmul_v2",
DistributedMatmulV2Impl0("column_parallel"))
register_distributed_operator_impl("matmul_v2",
DistributedMatmulV2Impl1("row_parallel"))
register_distributed_operator_impl(
"matmul_v2", DistributedMatmulV2Impl2("replicate_parallel"))
"matmul_v2", DistributedMatmulV2Impl0("column_parallel")
)
register_distributed_operator_impl(
"matmul_v2", DistributedMatmulV2Impl1("row_parallel")
)
register_distributed_operator_impl(
"matmul_v2", DistributedMatmulV2Impl2("replicate_parallel")
)
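# A rough sketch (an assumption, not the framework's actual dispatch code) of
# how the three implementations registered above line up with Y's
# dims_mapping, ignoring trans_x / trans_y for brevity.
def guess_matmul_v2_impl(y_dims_mapping):
    last, second_last = y_dims_mapping[-1], y_dims_mapping[-2]
    if second_last == -1 and last >= 0:
        return "column_parallel"  # Y split along its columns
    if second_last >= 0 and last == -1:
        return "row_parallel"  # Y split along its rows
    if second_last == -1 and last == -1:
        return "replicate_parallel"  # Y fully replicated
    return None  # not handled by these three impls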
class DistributedMul(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedMul, self).__init__(op_type)
......@@ -2069,7 +2373,6 @@ register_distributed_operator_impl_container(DistributedMul("mul"))
# ColumnParallel
class DistributedMulImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMulImpl0, self).__init__(name)
self._forward_implemented = True
......@@ -2092,7 +2395,8 @@ class DistributedMulImpl0(DistributedOperatorImpl):
main_block = backward_op.block
vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0])
backward_op.input("Y")[0]
)
# col parallel: matmul + allreduce
assert Y_var_dim_mapping[0] < 0
parallel_axis = Y_var_dim_mapping[1]
......@@ -2102,13 +2406,14 @@ class DistributedMulImpl0(DistributedOperatorImpl):
assert len(backward_op.output("X@GRAD")) == 1
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MulGradOpCost, ctx,
processes, desc_mapping,
cluster)
cost_mapping = build_comp_costs_from_descs(
MulGradOpCost, ctx, processes, desc_mapping, cluster
)
res.append(cost_mapping)
# calc comm op cost
......@@ -2121,40 +2426,52 @@ class DistributedMulImpl0(DistributedOperatorImpl):
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes,
c_allreduce_sum_desc_mapping, cluster)
AllreduceSumOpCost,
ctx,
processes,
c_allreduce_sum_desc_mapping,
cluster,
)
res.append(comm_op_cost_list)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
if (
batch_size_axis > -1
and mesh_shape[batch_size_axis] > 1
and is_parameter_related(backward_op.input("Y")[0], main_block)
):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
build_dp_costs(
res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MulOpCost, ctx, processes,
desc_mapping, cluster)
cost_mapping = build_comp_costs_from_descs(
MulOpCost, ctx, processes, desc_mapping, cluster
)
# calc comm op cost
serial_op = dist_op.serial_op
vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-1]
serial_op.input("Y")[0]
)[-1]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.input("X")
c_identity_desc_mapping = build_comm_desc_from_dist_op(
......@@ -2163,10 +2480,12 @@ class DistributedMulImpl0(DistributedOperatorImpl):
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster)
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
)
res_cost = [comm_op_cost_list, cost_mapping]
return res_cost
......@@ -2181,7 +2500,8 @@ class DistributedMulImpl0(DistributedOperatorImpl):
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
y_dims_mapping[-1]):
y_dims_mapping[-1]
):
return False
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
......@@ -2201,8 +2521,9 @@ class DistributedMulImpl0(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
if (not self.is_input_compatible(dist_op)) or (
not self.is_output_compatible(dist_op)
):
return False
if not _is_auto_compatible_for_matmul(dist_op):
......@@ -2229,28 +2550,33 @@ class DistributedMulImpl0(DistributedOperatorImpl):
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op))
assert (
op_dist_attr is not None
), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id)
rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id
)
# check validation of inputs / outputs
for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format(
input_name)
input_name
)
assert len(kwargs[input_name]) == len(
src_op.desc.input(input_name)
), "number of tensor for input [{}] is not match".format(input_name)
for output_name in src_op.desc.output_names():
assert output_name in kwargs, "input [{}] is not given".format(
output_name)
output_name
)
assert len(kwargs[output_name]) == len(
src_op.desc.output(output_name)
), "number of tensor for input [{}] is not match".format(
output_name)
output_name
)
X_var = main_block.var(kwargs['X'][0])
Weight_var = main_block._var_recursive(kwargs['Y'][0])
......@@ -2258,15 +2584,20 @@ class DistributedMulImpl0(DistributedOperatorImpl):
# TODO infer logic comm presentation
matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-1]
assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_col_dim_mapping)
Weight_var.name
)[-1]
assert (
matmul_col_dim_mapping >= 0
), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_col_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_col_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
parallel_axis, rank_id)
group_ranks = _get_comm_group(
process_mesh_group, process_mesh_shape, parallel_axis, rank_id
)
group = new_process_group(group_ranks)
# infer new var shape with op dist attr
......@@ -2274,31 +2605,39 @@ class DistributedMulImpl0(DistributedOperatorImpl):
assert x_tensor_dist_attr is not None
identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
assert identity_var_dist_attr is not None
ref_shape_x = infer_shape(main_block, X_var, x_tensor_dist_attr,
identity_var_dist_attr)
ref_shape_x = infer_shape(
main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
)
# infer out var shape with op dist attr
out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
assert out_tensor_dist_attr is not None
out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert out_var_dist_attr is not None
ref_shape_out = infer_shape(main_block, Out_var, out_tensor_dist_attr,
out_var_dist_attr)
ref_shape_out = infer_shape(
main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
)
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_identity", 'tmp'])),
name=unique_name.generate_with_ignorable_key(
".".join(["c_identity", 'tmp'])
),
dtype=X_var.dtype,
shape=X_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=X_var.stop_gradient)
stop_gradient=X_var.stop_gradient,
)
# set intermediate_var_0's dist_attr with X_var's dist_attr
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
identity_var_dist_attr)
ctx.set_tensor_dist_attr_for_program(
intermediate_var_0, identity_var_dist_attr
)
check_variable_and_dtype(
X_var, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity')
X_var,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)
c_identity_op = main_block.append_op(
type='c_identity',
inputs={'X': [X_var]},
......@@ -2307,20 +2646,29 @@ class DistributedMulImpl0(DistributedOperatorImpl):
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role')
})
OP_ROLE_KEY: src_op.attr('op_role'),
},
)
if intermediate_var_0.shape != ref_shape_x:
intermediate_var_0.desc.set_shape(ref_shape_x)
check_variable_and_dtype(intermediate_var_0, 'x',
['float16', 'float32', 'float64'], 'linear')
check_dtype(intermediate_var_0.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
check_variable_and_dtype(
intermediate_var_0,
'x',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
# attrs = {'trans_x': False, 'trans_y': False}
attrs = {
"x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
"y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
OP_ROLE_KEY: src_op.attr('op_role')
OP_ROLE_KEY: src_op.attr('op_role'),
}
inputs = {'X': intermediate_var_0, 'Y': Weight_var}
......@@ -2334,16 +2682,15 @@ class DistributedMulImpl0(DistributedOperatorImpl):
inputs_original_shape[var_name] = var.shape
input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
input_ref_shape = infer_shape(main_block, var,
input_tensor_dist_attr,
input_var_dist_attr)
input_ref_shape = infer_shape(
main_block, var, input_tensor_dist_attr, input_var_dist_attr
)
inputs_ref_shape[var_name] = input_ref_shape
var.desc.set_shape(input_ref_shape)
mul_op = main_block.append_op(type='mul',
inputs=inputs,
outputs={'Out': Out_var},
attrs=attrs)
mul_op = main_block.append_op(
type='mul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs
)
if Out_var.shape != ref_shape_out:
Out_var.desc.set_shape(ref_shape_out)
......@@ -2362,13 +2709,16 @@ class DistributedMulImpl0(DistributedOperatorImpl):
input_varname = c_identity_op.desc.input_arg_names()[0]
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
identity_op_dist_attr.set_input_dist_attr(input_varname,
input_dist_attr)
op_dist_attr
)
identity_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
# output
output_varname = c_identity_op.desc.output_arg_names()[0]
identity_op_dist_attr.set_output_dist_attr(output_varname,
input_dist_attr)
identity_op_dist_attr.set_output_dist_attr(
output_varname, input_dist_attr
)
ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)
# matmulv2
......@@ -2379,29 +2729,37 @@ class DistributedMulImpl0(DistributedOperatorImpl):
for input_varname in mul_op.desc.input_arg_names():
if input_varname in src_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(
input_varname)
input_varname
)
assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
op_dist_attr
)
matmulv2_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr)
input_varname, input_dist_attr
)
else:
input_var = main_block.var(input_varname)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
input_var)
input_var
)
matmulv2_op_dist_attr.set_input_dist_attr(
input_varname, tensor_dist_attr)
input_varname, tensor_dist_attr
)
for output_varname in mul_op.desc.output_arg_names():
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
matmulv2_op_dist_attr.set_output_dist_attr(output_varname,
output_dist_attr)
op_dist_attr
)
matmulv2_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)
# init param sync
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)
_init_param_sync(
Weight_var, dist_op_context, startup_block, ctx, rank_id
)
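# A numpy sketch of what x_num_col_dims / y_num_col_dims are usually taken to
# mean for the mul op (stated as an assumption, not taken from this patch):
# both operands are flattened to 2-D around those split points before the
# matrix multiply.
import numpy as np
def mul_ref(x, y, x_num_col_dims=1, y_num_col_dims=1):
    x2d = x.reshape(int(np.prod(x.shape[:x_num_col_dims])), -1)
    y2d = y.reshape(int(np.prod(y.shape[:y_num_col_dims])), -1)
    return x2d @ y2d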
@staticmethod
def backward(ctx, *args, **kwargs):
......@@ -2410,7 +2768,6 @@ class DistributedMulImpl0(DistributedOperatorImpl):
# RowParallel
class DistributedMulImpl1(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMulImpl1, self).__init__(name)
self._forward_implemented = True
......@@ -2434,7 +2791,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
main_block = backward_op.block
vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0])
backward_op.input("Y")[0]
)
assert Y_var_dim_mapping[1] < 0
parallel_axis = Y_var_dim_mapping[0]
......@@ -2447,49 +2805,59 @@ class DistributedMulImpl1(DistributedOperatorImpl):
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
parallel_axis=parallel_axis,
)
processes = process_mesh.processes
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster)
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
)
res.append(comm_op_cost_list)
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
cost_mapping = build_comp_costs_from_descs(MulGradOpCost, ctx,
processes, desc_mapping,
cluster)
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
cost_mapping = build_comp_costs_from_descs(
MulGradOpCost, ctx, processes, desc_mapping, cluster
)
res.append(cost_mapping)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
if (
batch_size_axis > -1
and mesh_shape[batch_size_axis] > 1
and is_parameter_related(backward_op.input("Y")[0], main_block)
):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
build_dp_costs(
res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MulOpCost, ctx, processes,
desc_mapping, cluster)
cost_mapping = build_comp_costs_from_descs(
MulOpCost, ctx, processes, desc_mapping, cluster
)
# calc comm op cost
serial_op = dist_op.serial_op
vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-2]
serial_op.input("Y")[0]
)[-2]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.output("Out")
......@@ -2499,12 +2867,17 @@ class DistributedMulImpl1(DistributedOperatorImpl):
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
parallel_axis=parallel_axis,
)
# print("dist_matmul.py dist_op: ", dist_op)
comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping,
cluster)
AllreduceSumOpCost,
ctx,
processes,
c_allreduce_sum_desc_mapping,
cluster,
)
res_cost = [cost_mapping, comm_op_cost_list]
......@@ -2520,7 +2893,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
if is_dim_replicate(x_dims_mapping[-1]):
return False
if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
y_dims_mapping[-1]):
y_dims_mapping[-1]
):
return False
# Other dimensions must be replicate except the batch dimension
for mapping in x_dims_mapping[1:-1]:
......@@ -2542,8 +2916,9 @@ class DistributedMulImpl1(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
if (not self.is_input_compatible(dist_op)) or (
not self.is_output_compatible(dist_op)
):
return False
if not _is_auto_compatible_for_matmul(dist_op):
......@@ -2570,28 +2945,33 @@ class DistributedMulImpl1(DistributedOperatorImpl):
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op))
assert (
op_dist_attr is not None
), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id)
rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id
)
# check validation of inputs / outputs
for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format(
input_name)
input_name
)
assert len(kwargs[input_name]) == len(
src_op.desc.input(input_name)
), "number of tensor for input [{}] is not match".format(input_name)
for output_name in src_op.desc.output_names():
assert output_name in kwargs, "input [{}] is not given".format(
output_name)
output_name
)
assert len(kwargs[output_name]) == len(
src_op.desc.output(output_name)
), "number of tensor for input [{}] is not match".format(
output_name)
output_name
)
X_var = main_block.var(kwargs['X'][0])
Weight_var = main_block._var_recursive(kwargs['Y'][0])
......@@ -2599,26 +2979,36 @@ class DistributedMulImpl1(DistributedOperatorImpl):
# TODO infer logic comm presentation
matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[-2]
assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_row_dim_mapping)
Weight_var.name
)[-2]
assert (
matmul_row_dim_mapping >= 0
), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_row_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_row_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
parallel_axis, rank_id)
group_ranks = _get_comm_group(
process_mesh_group, process_mesh_shape, parallel_axis, rank_id
)
group = new_process_group(group_ranks)
check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'],
'linear')
check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'],
'linear')
check_variable_and_dtype(
X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
)
check_dtype(
X_var.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
# attrs = {'trans_x': False, 'trans_y': False}
attrs = {
"x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
"y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
OP_ROLE_KEY: src_op.attr('op_role')
OP_ROLE_KEY: src_op.attr('op_role'),
}
inputs = {'X': X_var, 'Y': Weight_var}
......@@ -2627,22 +3017,26 @@ class DistributedMulImpl1(DistributedOperatorImpl):
assert out_tensor_dist_attr is not None
out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert out_var_dist_attr is not None
ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr,
out_var_dist_attr)
ref_shape = infer_shape(
main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
)
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_allreduce_sum", 'tmp'])),
name=unique_name.generate_with_ignorable_key(
".".join(["c_allreduce_sum", 'tmp'])
),
shape=Out_var.shape,
dtype=Out_var.dtype,
type=Out_var.type,
lod_level=Out_var.lod_level,
persistable=False,
is_data=False,
need_check_feed=Out_var.desc.need_check_feed())
need_check_feed=Out_var.desc.need_check_feed(),
)
# set intermediate_var_0's dist_attr with Out_var's dist_attr
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
out_var_dist_attr)
ctx.set_tensor_dist_attr_for_program(
intermediate_var_0, out_var_dist_attr
)
inputs_ref_shape = {}
inputs_original_shape = {}
......@@ -2651,16 +3045,18 @@ class DistributedMulImpl1(DistributedOperatorImpl):
inputs_original_shape[var_name] = var.shape
input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
input_ref_shape = infer_shape(main_block, var,
input_tensor_dist_attr,
input_var_dist_attr)
input_ref_shape = infer_shape(
main_block, var, input_tensor_dist_attr, input_var_dist_attr
)
inputs_ref_shape[var_name] = input_ref_shape
var.desc.set_shape(input_ref_shape)
mul_op = main_block.append_op(type='mul',
mul_op = main_block.append_op(
type='mul',
inputs=inputs,
outputs={'Out': intermediate_var_0},
attrs=attrs)
attrs=attrs,
)
if intermediate_var_0.shape != ref_shape:
intermediate_var_0.desc.set_shape(ref_shape)
......@@ -2678,8 +3074,9 @@ class DistributedMulImpl1(DistributedOperatorImpl):
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role')
})
OP_ROLE_KEY: src_op.attr('op_role'),
},
)
if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape)
......@@ -2693,15 +3090,19 @@ class DistributedMulImpl1(DistributedOperatorImpl):
for input_varname in mul_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
matmulv2_op_dist_attr.set_input_dist_attr(input_varname,
input_dist_attr)
op_dist_attr
)
matmulv2_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
output_varname = mul_op.desc.output_arg_names()[0]
output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
matmulv2_op_dist_attr.set_output_dist_attr(output_varname,
output_dist_attr)
op_dist_attr
)
matmulv2_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)
# allreduce
......@@ -2713,21 +3114,26 @@ class DistributedMulImpl1(DistributedOperatorImpl):
input_var = main_block.var(input_varname)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
assert tensor_dist_attr is not None
allreduce_op_dist_attr.set_input_dist_attr(input_varname,
tensor_dist_attr)
allreduce_op_dist_attr.set_input_dist_attr(
input_varname, tensor_dist_attr
)
for output_varname in c_allreduce_sum_op.desc.output_arg_names():
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
allreduce_op_dist_attr.set_output_dist_attr(output_varname,
output_dist_attr)
ctx.set_op_dist_attr_for_program(c_allreduce_sum_op,
allreduce_op_dist_attr)
op_dist_attr
)
allreduce_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(
c_allreduce_sum_op, allreduce_op_dist_attr
)
# init param sync
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)
_init_param_sync(
Weight_var, dist_op_context, startup_block, ctx, rank_id
)
@staticmethod
def backward(ctx, *args, **kwargs):
......@@ -2736,7 +3142,6 @@ class DistributedMulImpl1(DistributedOperatorImpl):
# ReplicateParallel
class DistributedMulImpl2(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMulImpl2, self).__init__(name)
......@@ -2757,38 +3162,45 @@ class DistributedMulImpl2(DistributedOperatorImpl):
vars = main_block.vars
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MulGradOpCost, ctx,
processes, desc_mapping,
cluster)
cost_mapping = build_comp_costs_from_descs(
MulGradOpCost, ctx, processes, desc_mapping, cluster
)
res.append(cost_mapping)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
if (
batch_size_axis > -1
and mesh_shape[batch_size_axis] > 1
and is_parameter_related(backward_op.input("Y")[0], main_block)
):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
build_dp_costs(
res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MulOpCost, ctx, processes,
desc_mapping, cluster)
cost_mapping = build_comp_costs_from_descs(
MulOpCost, ctx, processes, desc_mapping, cluster
)
res_cost = [cost_mapping]
return res_cost
......@@ -2804,12 +3216,14 @@ class DistributedMulImpl2(DistributedOperatorImpl):
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
x_dims_mapping[-2]):
x_dims_mapping[-2]
):
return False
if is_dim_shard(y_dims_mapping[-1]):
return False
if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
y_dims_mapping[-2]):
y_dims_mapping[-2]
):
return False
return True
......@@ -2824,14 +3238,16 @@ class DistributedMulImpl2(DistributedOperatorImpl):
if is_dim_shard(out_dims_mapping[-1]):
return False
if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
out_dims_mapping[-2]):
out_dims_mapping[-2]
):
return False
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
if (not self.is_input_compatible(dist_op)) or (
not self.is_output_compatible(dist_op)
):
return False
if not _is_auto_compatible_for_matmul(dist_op):
......@@ -2855,8 +3271,10 @@ class DistributedMulImpl2(DistributedOperatorImpl):
_right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
register_distributed_operator_impl("mul",
DistributedMulImpl0("column_parallel"))
register_distributed_operator_impl(
"mul", DistributedMulImpl0("column_parallel")
)
register_distributed_operator_impl("mul", DistributedMulImpl1("row_parallel"))
register_distributed_operator_impl("mul",
DistributedMulImpl2("replicate_parallel"))
register_distributed_operator_impl(
"mul", DistributedMulImpl2("replicate_parallel")
)
......@@ -254,17 +254,26 @@ class Parallelizer:
self._dist_context.serial_feed_vars["inputs"]
+ self._dist_context.serial_feed_vars["labels"]
)
if config["use_pure_fp16"]:
self._logger.info(
"Applying AMP-{}-{} ...".format(
config["dtype"], config['level']
),
)
if config['level'] == "o1":
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply(
[main_program], [startup_program], self._pass_context
)
loss = auto_parallel_amp_pass.get_loss()
elif config['level'] in ['o2', 'o3']:
config["base_opt"] = optimizer
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply(
[main_program], [startup_program], self._pass_context
)
loss = auto_parallel_fp16_pass.get_loss()
else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply(
[main_program], [startup_program], self._pass_context
)
raise ValueError("AMP level should be one of o1, o2, o3")
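# A condensed sketch (not the Parallelizer code itself) of the dispatch above:
# level "o1" runs the white/black-list cast pass, while "o2"/"o3" hand the
# base optimizer to the pure low-precision pass; anything else is rejected.
def pick_amp_pass_name(level):
    if level == "o1":
        return "auto_parallel_amp"
    if level in ("o2", "o3"):
        return "auto_parallel_fp16"
    raise ValueError("AMP level should be one of o1, o2, o3")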
# apply recompute pass
# recompute is a train-only optimization
......
......@@ -18,25 +18,48 @@ from paddle.fluid import unique_name
from .pass_base import PassBase, register_pass
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
from paddle.distributed.auto_parallel.utils import get_loss_op, set_var_dist_attr
from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping
from paddle.distributed.auto_parallel.process_group import get_world_process_group
from paddle.fluid.contrib.mixed_precision.fp16_utils import AutoMixedPrecisionLists
from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_fp32_input, _keep_fp32_output, find_op_index
from paddle.fluid.contrib.mixed_precision.fp16_utils import _valid_types, find_true_post_op, find_true_prev_op
from paddle.fluid.contrib.mixed_precision.fp16_utils import _is_in_black_varnames, _dtype_to_str, _rename_arg
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
from paddle.distributed.auto_parallel.utils import (
get_loss_op,
set_var_dist_attr,
)
from paddle.distributed.auto_parallel.utils import (
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
)
from paddle.distributed.auto_parallel.process_group import (
get_world_process_group,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
AutoMixedPrecisionLists,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
_keep_fp32_input,
_keep_fp32_output,
find_op_index,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
_valid_types,
find_true_post_op,
find_true_prev_op,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
_is_in_black_varnames,
_dtype_to_str,
_rename_arg,
)
from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistributedAttribute,
)
from ..auto_parallel.utils import is_forward_op, is_backward_op, is_loss_op
world_process_group = get_world_process_group()
class AMPState(object):
def __init__(self, block):
self._block = block
self._op_fp16_dict = {
} # op_id --> True/False. 'True' means that the current op is in fp16 mode.
self._op_fp16_dict = (
{}
) # op_id --> True/False. 'True' means that the current op is in fp16 mode.
self._var_name_dict = {} # fwd_op_id --> {old_name: cast_name}
self.is_train = False
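# A tiny sketch of the tri-state bookkeeping assumed for _op_fp16_dict: an op
# id maps to True (run the op in fp16/bf16), False (keep it in fp32 and cast
# low-precision inputs back up), or is simply absent (leave the op untouched).
def is_fp16_op_ref(op_fp16_dict, op_id):
    return op_fp16_dict.get(op_id, None)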
......@@ -55,7 +78,8 @@ class AMPState(object):
elif int(op.attr('op_role')) == int(OpRole.Backward):
if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
fwd_op_id = dist_op_context.grad_op_id_to_op_id[
op.desc.original_id()]
op.desc.original_id()
]
if self._is_fp16_op(fwd_op_id) == True:
self._op_fp16_dict[op.desc.original_id()] = True
elif self._is_fp16_op(fwd_op_id) == False:
......@@ -78,7 +102,8 @@ class AMPState(object):
if op.type == 'create_py_reader' or op.type == 'read':
continue
if amp_lists.black_varnames is not None and _is_in_black_varnames(
op, amp_lists):
op, amp_lists
):
self._op_fp16_dict[op.desc.original_id()] = False
continue
if op.type in amp_lists.black_list:
......@@ -98,17 +123,24 @@ class AMPState(object):
continue
elif in_var.op is op:
prev_op = find_true_prev_op(
ops, op, in_var_name)
ops, op, in_var_name
)
if prev_op is None:
continue
else:
prev_op = in_var.op
# if it's one of inputs
if self._is_fp16_op(prev_op.desc.original_id()) == False or \
prev_op.type in amp_lists.black_list:
if (
self._is_fp16_op(prev_op.desc.original_id())
== False
or prev_op.type in amp_lists.black_list
):
is_black_op = True
elif self._is_fp16_op(prev_op.desc.original_id()) == True or \
prev_op.type in amp_lists.white_list:
elif (
self._is_fp16_op(prev_op.desc.original_id())
== True
or prev_op.type in amp_lists.white_list
):
is_white_op = True
if is_black_op:
self._op_fp16_dict[op.desc.original_id()] = False
......@@ -131,19 +163,28 @@ class AMPState(object):
break
if self._is_fp16_op(op.desc.original_id()) == False:
num_cast_ops = self._insert_cast_op_forward(
op, idx, core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32, dist_context)
op,
idx,
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32,
dist_context,
)
elif self._is_fp16_op(op.desc.original_id()) == True:
num_cast_ops = self._insert_cast_op_forward(
op, idx, core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16, dist_context)
op,
idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
dist_context,
)
else:
pass
idx += num_cast_ops + 1
self._block._sync_with_cpp()
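# A stripped-down sketch (not the real pass) of the renaming scheme used by
# _insert_cast_op_forward: any input whose dtype does not match the op's
# target precision is replaced by a once-created "<name>.cast_<dtype>" var,
# and the mapping is remembered so the matching backward op can be renamed in
# the same way.
def plan_forward_casts_ref(input_names, dtype_str="fp16"):
    return {name: name + ".cast_" + dtype_str for name in input_names}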
def _insert_cast_op_forward(self, op, idx, src_dtype, dst_dtype,
dist_context):
def _insert_cast_op_forward(
self, op, idx, src_dtype, dst_dtype, dist_context
):
"""
only for forward cast
modified from paddle.fluid.contrib.mixed_precision
......@@ -152,38 +193,45 @@ class AMPState(object):
var_name_dict = {}
for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
op, in_name):
op, in_name
):
continue
for in_var_name in op.input(in_name):
in_var = self._block._find_var_recursive(in_var_name)
if in_var.type not in _valid_types or in_var.dtype == dst_dtype:
continue
if in_var.dtype == src_dtype:
cast_name = in_var.name + '.cast_' + _dtype_to_str(
dst_dtype)
cast_name = (
in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
)
out_var = self._block.vars.get(cast_name)
var_name_dict[in_var.name] = cast_name
consume_op_attr = dist_context.get_op_dist_attr_for_program(
op)
op
)
assert consume_op_attr is not None
if out_var is None or out_var.dtype != dst_dtype:
# NOTE: the dist attr of the cast op and its output var is taken from the op
# that consumes the cast var, not from the op that generates it
in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var.name)
in_var.name
)
assert in_var_dist_attr is not None
ref_mesh = in_var_dist_attr.process_mesh
ref_mapping = in_var_dist_attr.dims_mapping
consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr)
cast_name, in_var_dist_attr
)
out_var = self._block.create_var(
name=cast_name,
dtype=dst_dtype,
persistable=False,
stop_gradient=in_var.stop_gradient)
set_var_dist_attr(dist_context, out_var, ref_mapping,
ref_mesh)
stop_gradient=in_var.stop_gradient,
)
set_var_dist_attr(
dist_context, out_var, ref_mapping, ref_mesh
)
cast_op = self._block._insert_op_without_sync(
idx,
......@@ -193,22 +241,29 @@ class AMPState(object):
attrs={
"in_dtype": in_var.dtype,
"out_dtype": out_var.dtype,
})
},
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context)
cast_op, ref_mesh, ref_mapping, dist_context
)
num_cast_ops += 1
else:
in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var.name)
in_var.name
)
consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr)
cast_name, in_var_dist_attr
)
_rename_arg(op, in_var.name, cast_name)
else:
if op.has_attr('in_dtype'):
op._set_attr('in_dtype', dst_dtype)
self._var_name_dict[op.desc.original_id()] = var_name_dict
if src_dtype == core.VarDesc.VarType.FP32 and dst_dtype == core.VarDesc.VarType.FP16:
if (
src_dtype == core.VarDesc.VarType.FP32
and dst_dtype == core.VarDesc.VarType.FP16
):
for out_name in op.output_names:
if _keep_fp32_output(op, out_name):
continue
......@@ -238,8 +293,9 @@ class AMPState(object):
# NOTE: the mapping in `grad_var_to_var` may change when a var is casted,
# which affects how the dist_op inserts the allreduce_sum op.
op_dist_attr = dist_context.get_op_dist_attr_for_program(grad_op)
if is_backward_op(grad_op) and (is_forward_op(ops[idx - 1])
or is_loss_op(ops[idx - 1])):
if is_backward_op(grad_op) and (
is_forward_op(ops[idx - 1]) or is_loss_op(ops[idx - 1])
):
if not op_dist_attr.is_recompute:
appended_grad_times += 1
......@@ -248,14 +304,22 @@ class AMPState(object):
if grad_op_orig_id in dist_op_context.grad_op_id_to_op_id:
if self._is_fp16_op(grad_op_orig_id) == False: # fp32
num_cast_ops = self._insert_cast_op_backward(
grad_op, idx, core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32, dist_context,
appended_grad_times)
grad_op,
idx,
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32,
dist_context,
appended_grad_times,
)
elif self._is_fp16_op(grad_op_orig_id) == True: # fp16
num_cast_ops = self._insert_cast_op_backward(
grad_op, idx, core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16, dist_context,
appended_grad_times)
grad_op,
idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
dist_context,
appended_grad_times,
)
elif grad_op.type == "sum":
in_var_name = grad_op.desc.input_arg_names()[0]
src_dtype = self._block.var(in_var_name).dtype
......@@ -270,15 +334,24 @@ class AMPState(object):
else:
raise ValueError(
"'{}' op is not supported in the complete amp pass.".format(
grad_op.type))
grad_op.type
)
)
idx += num_cast_ops + 1
self._block._sync_with_cpp()
_update_backward_cast_ops(params_grads, dist_context)
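# A small sketch of the gradient-name convention applied below: the gradient
# of a casted forward var such as "x.cast_fp16" is renamed to
# "x.cast_fp16@GRAD", keeping any "@RENAME..." suffix that gradient
# accumulation may have appended.
def backward_cast_name_ref(fwd_cast_name, out_grad_name):
    suffix = ""
    if "@RENAME" in out_grad_name:
        suffix = out_grad_name[out_grad_name.find("@RENAME"):]
    return fwd_cast_name + "@GRAD" + suffix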
def _insert_cast_op_backward(self, grad_op, idx, src_dtype, dst_dtype,
dist_context, appended_grad_times):
""" only for backward cast """
def _insert_cast_op_backward(
self,
grad_op,
idx,
src_dtype,
dst_dtype,
dist_context,
appended_grad_times,
):
"""only for backward cast"""
def _keep_fp32_input(op, in_name):
op_type = op.type
......@@ -299,7 +372,8 @@ class AMPState(object):
for in_name in grad_op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
grad_op, in_name):
grad_op, in_name
):
for in_var_name in grad_op.input(in_name):
in_var = self._block._find_var_recursive(in_var_name)
assert in_var.dtype == core.VarDesc.VarType.FP32
......@@ -309,24 +383,34 @@ class AMPState(object):
in_var = self._block._find_var_recursive(in_var_name)
if in_var.dtype == src_dtype:
consume_op_attr = dist_context.get_op_dist_attr_for_program(
grad_op)
grad_op
)
if in_var_name in self._var_name_dict[fwd_op_id]:
# NOTE: if the in_var consumed by grad_op has been casted before,
# it should be renamed and its dist_attr reset.
cast_name = self._var_name_dict[fwd_op_id][in_var_name]
grad_op.desc._rename_input(in_var_name, cast_name)
in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var_name)
in_var_name
)
consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr)
cast_name, in_var_dist_attr
)
else:
assert in_var.dtype == dst_dtype, "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format(
grad_op.type, in_name, dst_dtype, in_var.dtype,
str(grad_op))
assert (
in_var.dtype == dst_dtype
), "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format(
grad_op.type,
in_name,
dst_dtype,
in_var.dtype,
str(grad_op),
)
for out_name in grad_op.output_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
grad_op, out_name):
grad_op, out_name
):
for out_var_name in grad_op.output(out_name):
out_var = self._block._find_var_recursive(out_var_name)
assert out_var.dtype == core.VarDesc.VarType.FP32
......@@ -334,7 +418,7 @@ class AMPState(object):
for out_var_name in grad_op.output(out_name):
out_var = self._block._find_var_recursive(out_var_name)
out_var_name_prefix = out_var_name[:out_var_name.find("@")]
out_var_name_prefix = out_var_name[: out_var_name.find("@")]
fwd_var = self._block._find_var_recursive(out_var_name_prefix)
# NOTE: the dtype of the out_var consumed by grad_op should equal the fwd_var's dtype
if out_var.dtype != fwd_var.dtype:
......@@ -345,34 +429,45 @@ class AMPState(object):
# NOTE: if the out_var consumed by grad_op has been casted before, it should
# be renamed and its dist_attr reset, and then a cast op is inserted to
# convert the cast_var back to the original dtype
consume_op_attr = dist_context.get_op_dist_attr_for_program(
grad_op)
consume_op_attr = (
dist_context.get_op_dist_attr_for_program(grad_op)
)
fwd_cast_name = self._var_name_dict[fwd_op_id][
out_var_name_prefix]
out_var_name_prefix
]
suffix = ""
if "@RENAME" in out_var_name:
suffix = out_var_name[out_var_name.find("@RENAME"):]
suffix = out_var_name[
out_var_name.find("@RENAME") :
]
cast_name = fwd_cast_name + "@GRAD" + suffix
cast_var = self._block.vars.get(cast_name)
if cast_var is None or cast_var.dtype != dst_dtype:
grad_op.desc._rename_output(out_var_name, cast_name)
out_var_dist_attr = consume_op_attr.get_output_dist_attr(
out_var_name)
out_var_dist_attr = (
consume_op_attr.get_output_dist_attr(
out_var_name
)
)
ref_mesh = out_var_dist_attr.process_mesh
ref_mapping = out_var_dist_attr.dims_mapping
consume_op_attr.set_output_dist_attr(
cast_name, out_var_dist_attr)
cast_name, out_var_dist_attr
)
assert ref_mapping is not None
cast_var = self._block.create_var(
name=cast_name,
shape=out_var.shape,
dtype=dst_dtype,
persistable=False,
stop_gradient=out_var.stop_gradient)
set_var_dist_attr(dist_context, cast_var,
ref_mapping, ref_mesh)
stop_gradient=out_var.stop_gradient,
)
set_var_dist_attr(
dist_context, cast_var, ref_mapping, ref_mesh
)
dist_op_context.grad_var_to_var[
appended_grad_times][cast_name] = fwd_cast_name
appended_grad_times
][cast_name] = fwd_cast_name
cast_op = self._block._insert_op(
idx + 1,
......@@ -382,13 +477,15 @@ class AMPState(object):
attrs={
"in_dtype": cast_var.dtype,
"out_dtype": out_var.dtype,
"op_role": OpRole.Backward
})
"op_role": OpRole.Backward,
},
)
cast_op._remove_attr("op_role_var")
cast_op._remove_attr("op_namescope")
cast_op._remove_attr("with_quant_attr")
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context)
cast_op, ref_mesh, ref_mapping, dist_context
)
num_cast_ops += 1
else:
assert out_var.dtype == dst_dtype
......@@ -409,15 +506,18 @@ def _update_backward_cast_ops(params_grads, dist_context):
for p, g in params_grads:
op = g.op
if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
if int(op.attr('op_role')) == int(
OpRole.Backward) and op.has_attr('op_role_var'):
if int(op.attr('op_role')) == int(OpRole.Backward) and op.has_attr(
'op_role_var'
):
op._remove_attr("op_role_var")
post_ops = find_true_post_op(main_block.ops, op, g.name)
if post_ops:
raise ValueError("The cast op {0}'s output should not be"
raise ValueError(
"The cast op {0}'s output should not be"
"used by a non-optimize op, however, it"
"is used by {1}".format(op, post_ops[0]))
"is used by {1}".format(op, post_ops[0])
)
if op == main_block.ops[-1]:
continue
......@@ -425,23 +525,29 @@ def _update_backward_cast_ops(params_grads, dist_context):
# add new op in the python and cpp at the same time
new_op_desc = main_block.desc.append_op()
new_op_desc.copy_from(op.desc)
new_op = paddle.fluid.framework.Operator(block=main_block,
new_op = paddle.fluid.framework.Operator(
block=main_block,
desc=new_op_desc,
type=None,
inputs=None,
outputs=None,
attrs=None)
attrs=None,
)
main_block.ops.append(new_op)
# dist attr
param_dist_attr = dist_context.get_tensor_dist_attr_for_program(p)
output_dist_attr = dist_context.get_tensor_dist_attr_for_program(
main_block.var(op.output_arg_names[0]))
main_block.var(op.output_arg_names[0])
)
assert param_dist_attr is not None
assert output_dist_attr is not None
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_op, param_dist_attr.process_mesh,
param_dist_attr.dims_mapping, dist_context)
new_op,
param_dist_attr.process_mesh,
param_dist_attr.dims_mapping,
dist_context,
)
output_dist_attr.process_mesh = param_dist_attr.process_mesh
output_dist_attr.dims_mapping = param_dist_attr.dims_mapping
......@@ -462,26 +568,34 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
grads = [g for _, g in params_grads]
check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
for e in grads:
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
'check_finite_and_unscale')
check_variable_and_dtype(
e,
"x",
['float16', 'float32', 'float64'],
'check_finite_and_unscale',
)
found_inf = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
['find_infinite_scale', 'tmp'])),
name=unique_name.generate_with_ignorable_key(
".".join(['find_infinite_scale', 'tmp'])
),
shape=[1],
dtype='bool',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
stop_gradient=False,
)
set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks)
inputs = {'X': grads, 'Scale': loss_scaling}
outputs = {'Out': grads, 'FoundInfinite': found_inf}
attrs = {'op_role': OpRole.Optimize}
new_op = main_block.append_op(type='check_finite_and_unscale',
new_op = main_block.append_op(
type='check_finite_and_unscale',
inputs=inputs,
outputs=outputs,
attrs=attrs)
attrs=attrs,
)
new_op_dist_attr = OperatorDistributedAttribute()
new_op_dist_attr.process_mesh = world_process_group.ranks
......@@ -491,17 +605,18 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
for g in grads:
g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g)
assert g_dist_attr is not None
new_op_dist_attr.set_input_dims_mapping(g.name,
g_dist_attr.dims_mapping)
new_op_dist_attr.set_output_dims_mapping(g.name,
g_dist_attr.dims_mapping)
new_op_dist_attr.set_input_dims_mapping(
g.name, g_dist_attr.dims_mapping
)
new_op_dist_attr.set_output_dims_mapping(
g.name, g_dist_attr.dims_mapping
)
dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
return grads, found_inf
@register_pass("auto_parallel_amp")
class AMPPass(PassBase):
def __init__(self):
super(AMPPass, self).__init__()
self.set_attr("loss", None)
......@@ -517,6 +632,7 @@ class AMPPass(PassBase):
self.set_attr("use_dynamic_loss_scaling", False)
self.set_attr("input_data", [])
self.set_attr("params_grads", [])
self.set_attr("dtype", "") # fp16/bf16
self._loss = None
self._loss_scaling = None
self._num_good_steps = None
......@@ -524,6 +640,8 @@ class AMPPass(PassBase):
self._loss = None
def _check_self(self):
if self.get_attr("dtype") not in ["float16", "bfloat16"]:
return False
if self.get_attr("init_loss_scaling") < 0:
return False
if self.get_attr("incr_every_n_steps") < 0:
......@@ -548,11 +666,13 @@ class AMPPass(PassBase):
def _apply_single_impl(self, main_program, startup_program, context):
self.dist_context = self.get_attr("dist_context")
params_grads = self.get_attr("params_grads")
self.amp_dtype = self.get_attr("dtype")
amp_lists = AutoMixedPrecisionLists(
set(self.get_attr("custom_white_list")),
set(self.get_attr("custom_black_list")),
set(self.get_attr("custom_black_varnames")))
set(self.get_attr("custom_black_varnames")),
)
with paddle.static.program_guard(main_program, startup_program):
amp_state = AMPState(main_program.global_block())
......@@ -566,10 +686,13 @@ class AMPPass(PassBase):
self._init_amp_var()
self._scale_loss()
if self.get_attr("use_dynamic_loss_scaling"
) or self.get_attr("init_loss_scaling") != 1.0:
if (
self.get_attr("use_dynamic_loss_scaling")
or self.get_attr("init_loss_scaling") != 1.0
):
grads, found_inf = _check_and_update_gradient(
params_grads, self._loss_scaling, self.dist_context)
params_grads, self._loss_scaling, self.dist_context
)
if self.get_attr("use_dynamic_loss_scaling"):
self._update_loss_scaling(grads, found_inf)
......@@ -580,9 +703,14 @@ class AMPPass(PassBase):
shape=[1],
value=self.get_attr("init_loss_scaling"),
dtype='float32',
persistable=True)
set_var_dist_attr(self.dist_context, self._loss_scaling, [-1],
world_process_group.ranks)
persistable=True,
)
set_var_dist_attr(
self.dist_context,
self._loss_scaling,
[-1],
world_process_group.ranks,
)
if self.get_attr("use_dynamic_loss_scaling"):
self._num_good_steps = paddle.static.create_global_var(
......@@ -590,18 +718,28 @@ class AMPPass(PassBase):
shape=[1],
value=0,
dtype='int32',
persistable=True)
set_var_dist_attr(self.dist_context, self._num_good_steps, [-1],
world_process_group.ranks)
persistable=True,
)
set_var_dist_attr(
self.dist_context,
self._num_good_steps,
[-1],
world_process_group.ranks,
)
self._num_bad_steps = paddle.static.create_global_var(
name=unique_name.generate("num_bad_steps"),
shape=[1],
value=0,
dtype='int32',
persistable=True)
set_var_dist_attr(self.dist_context, self._num_bad_steps, [-1],
world_process_group.ranks)
persistable=True,
)
set_var_dist_attr(
self.dist_context,
self._num_bad_steps,
[-1],
world_process_group.ranks,
)
def _scale_loss(self):
......@@ -613,7 +751,8 @@ class AMPPass(PassBase):
assert loss is not None
loss_op = loss.op
loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
loss_op)
loss_op
)
if loss.dtype != core.VarDesc.VarType.FP32:
# casting the loss here changes the effective loss tensor in the computation graph
......@@ -626,10 +765,12 @@ class AMPPass(PassBase):
tmp_name = unique_name.generate(loss.name + ".cast_fp32")
cast_loss = main_block.create_var(name=tmp_name, dtype=dtype)
loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
loss)
loss
)
ref_mesh = loss_op_dist_attr.process_mesh
self.dist_context.set_tensor_dist_attr_for_program(
cast_loss, loss_dist_attr)
cast_loss, loss_dist_attr
)
loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
cast_op = main_block._insert_op(
......@@ -641,16 +782,21 @@ class AMPPass(PassBase):
"in_dtype": loss.dtype,
"out_dtype": core.VarDesc.VarType.FP32,
'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
})
},
)
loss_op._set_attr(OP_ROLE_KEY,
core.op_proto_and_checker_maker.OpRole.Forward)
loss_op._set_attr(
OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, [-1], self.dist_context)
cast_op, ref_mesh, [-1], self.dist_context
)
loss = loss.astype('float32')
if self.get_attr("use_dynamic_loss_scaling"
) or self.get_attr("init_loss_scaling") != 1.0:
if self.amp_dtype == "float16" and (
self.get_attr("use_dynamic_loss_scaling")
or self.get_attr("init_loss_scaling") != 1.0
):
loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
......@@ -660,63 +806,76 @@ class AMPPass(PassBase):
name=unique_name.generate("scaled_loss"),
shape=loss.shape,
dtype=loss.dtype,
persistable=loss.persistable)
set_var_dist_attr(self.dist_context, self._scaled_loss, [-1],
ref_mesh)
persistable=loss.persistable,
)
set_var_dist_attr(
self.dist_context, self._scaled_loss, [-1], ref_mesh
)
elementwise_mul_op = main_block._insert_op(
loss_op_idx + 1,
type='elementwise_mul',
inputs={
'X': [loss],
'Y': [self._loss_scaling]
},
inputs={'X': [loss], 'Y': [self._loss_scaling]},
outputs={'Out': [self._scaled_loss]},
attrs={
'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
})
loss_op._set_attr(OP_ROLE_KEY,
core.op_proto_and_checker_maker.OpRole.Forward)
},
)
loss_op._set_attr(
OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
elementwise_mul_op, ref_mesh, [-1], self.dist_context)
elementwise_mul_op, ref_mesh, [-1], self.dist_context
)
# backward
first_backward_op = main_block.ops[loss_op_idx + 2]
assert first_backward_op.type == "fill_constant" and int(
first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
assert (
first_backward_op.type == "fill_constant"
and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
)
self._scaled_loss_grad = main_block.create_var(
name=unique_name.generate("scaled_loss") + "@GRAD",
shape=loss.shape,
dtype=loss.dtype,
persistable=loss.persistable)
set_var_dist_attr(self.dist_context, self._scaled_loss_grad, [-1],
ref_mesh)
persistable=loss.persistable,
)
set_var_dist_attr(
self.dist_context, self._scaled_loss_grad, [-1], ref_mesh
)
pre_grad_name = first_backward_op.output_arg_names[0]
first_backward_op._rename_output(pre_grad_name,
self._scaled_loss_grad.name)
first_backward_op._rename_output(
pre_grad_name, self._scaled_loss_grad.name
)
# FIXME(JZ-LIANG) a trick to insert backward op
main_block._sync_with_cpp()
elementwise_mul_grad_op_desc = main_block.desc._insert_op(
loss_op_idx + 3)
loss_op_idx + 3
)
elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad")
elementwise_mul_grad_op_desc.set_input(
'Out@GRAD', [self._scaled_loss_grad.name])
'Out@GRAD', [self._scaled_loss_grad.name]
)
elementwise_mul_grad_op_desc.set_input('X', [loss.name])
elementwise_mul_grad_op_desc.set_input('Y',
[self._loss_scaling.name])
elementwise_mul_grad_op_desc.set_input(
'Y', [self._loss_scaling.name]
)
elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name])
elementwise_mul_grad_op_desc.set_output('Y@GRAD', [])
elementwise_mul_grad_op_desc._set_attr(
OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward)
OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward
)
elementwise_mul_grad_op_desc._set_attr('axis', -1)
elementwise_mul_grad_op = paddle.fluid.framework.Operator(
main_block, elementwise_mul_grad_op_desc)
main_block, elementwise_mul_grad_op_desc
)
main_block.ops.insert(loss_op_idx + 3, elementwise_mul_grad_op)
main_block._sync_with_cpp()
elementwise_mul_grad_op = main_block.ops[loss_op_idx + 3]
assert elementwise_mul_grad_op.type == "elementwise_mul_grad"
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context)
elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context
)
else:
self._scaled_loss = loss
......@@ -728,31 +887,39 @@ class AMPPass(PassBase):
main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
check_variable_and_dtype(self._loss_scaling, "prev_loss_scaling",
['float32', 'float64'], "update_loss_scaling")
check_variable_and_dtype(
self._loss_scaling,
"prev_loss_scaling",
['float32', 'float64'],
"update_loss_scaling",
)
check_type(grads, 'x', (tuple, list), 'update_loss_scaling')
for e in grads:
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
'update_loss_scaling')
check_variable_and_dtype(
e, "x", ['float16', 'float32', 'float64'], 'update_loss_scaling'
)
if e.dtype == core.VarDesc.VarType.FP16:
assert self._loss_scaling.dtype == core.VarDesc.VarType.FP32, \
"The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
assert (
self._loss_scaling.dtype == core.VarDesc.VarType.FP32
), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
else:
assert self._loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x."
assert (
self._loss_scaling.dtype == e.dtype
), "The dtype of prev_loss_scaling should be equal to the dtype of x."
inputs = {
'X': grads,
'FoundInfinite': found_inf,
'PrevLossScaling': self._loss_scaling,
'InGoodSteps': self._num_good_steps,
'InBadSteps': self._num_bad_steps
'InBadSteps': self._num_bad_steps,
}
outputs = {
'Out': grads,
'LossScaling': self._loss_scaling,
'OutGoodSteps': self._num_good_steps,
'OutBadSteps': self._num_bad_steps
'OutBadSteps': self._num_bad_steps,
}
attrs = {
......@@ -761,13 +928,15 @@ class AMPPass(PassBase):
'incr_ratio': self.get_attr("incr_ratio"),
'decr_ratio': self.get_attr("decr_ratio"),
'stop_update': self.get_attr("stop_update"),
'op_role': OpRole.Optimize
'op_role': OpRole.Optimize,
}
new_op = main_block.append_op(type='update_loss_scaling',
new_op = main_block.append_op(
type='update_loss_scaling',
inputs=inputs,
outputs=outputs,
attrs=attrs)
attrs=attrs,
)
new_op_dist_attr = OperatorDistributedAttribute()
new_op_dist_attr.process_mesh = world_process_group.ranks
......@@ -777,10 +946,22 @@ class AMPPass(PassBase):
for g in grads:
g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g)
assert g_dist_attr is not None
new_op_dist_attr.set_input_dims_mapping(g.name,
g_dist_attr.dims_mapping)
new_op_dist_attr.set_output_dims_mapping(g.name,
g_dist_attr.dims_mapping)
new_op_dist_attr.set_input_dims_mapping(
g.name, g_dist_attr.dims_mapping
)
new_op_dist_attr.set_output_dims_mapping(
g.name, g_dist_attr.dims_mapping
)
self.dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
main_block._sync_with_cpp()
def get_loss(self):
# AMP might change the effective loss variable of the network and
# therefore affect subsequent passes that rely on the loss.
# Return the effective loss after the AMP pass.
if self._loss:
return self._loss
else:
return self.get_attr("loss")
......@@ -27,14 +27,13 @@ from paddle.distributed.auto_parallel.utils import (
from paddle.distributed.auto_parallel.process_group import (
get_world_process_group,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
from paddle.fluid.contrib.mixed_precision.fp16_lists import (
AutoMixedPrecisionLists,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
_keep_layer_norm_scale_bias_to_fp32,
_need_keep_fp32,
_valid_types,
_dtype_to_str,
)
from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistributedAttribute,
......@@ -55,6 +54,23 @@ __amp_skip_ops__ = [
'while',
'cast',
]
__target_dtype__ = None
def _dtype_to_str(dtype):
"""
Convert specific variable type to its corresponding string.
Args:
dtype (VarType): Variable type.
"""
if dtype == core.VarDesc.VarType.FP16:
# TODO(Xreki): change the returned str to "bf16" for BF16 data type.
# Currently too much code uses "cast_fp16" as the key.
return 'fp16'
elif dtype == core.VarDesc.VarType.BF16:
return 'bf16'
else:
return 'fp32'
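# Editor's sketch (not part of this change): the mapping above at a glance.
assert _dtype_to_str(core.VarDesc.VarType.FP16) == 'fp16'
assert _dtype_to_str(core.VarDesc.VarType.BF16) == 'bf16'
assert _dtype_to_str(core.VarDesc.VarType.FP32) == 'fp32'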
def set_op_dtype_to_fp16(op):
......@@ -62,14 +78,20 @@ def set_op_dtype_to_fp16(op):
op.has_attr('in_dtype')
and op.attr('in_dtype') == core.VarDesc.VarType.FP32
):
op._set_attr('in_dtype', core.VarDesc.VarType.FP16)
op._set_attr('in_dtype', __target_dtype__)
if (
op.has_attr('out_dtype')
and op.attr('out_dtype') == core.VarDesc.VarType.FP32
):
op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
op._set_attr('out_dtype', __target_dtype__)
if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16)
op._set_attr('dtype', __target_dtype__)
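# for BF16, also route the op to oneDNN bfloat16 kernels when the attributes exist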
if __target_dtype__ == core.VarDesc.VarType.BF16:
if op.has_attr('use_mkldnn'):
op._set_attr('use_mkldnn', True)
if op.has_attr('mkldnn_data_type'):
op._set_attr('mkldnn_data_type', 'bfloat16')
# adapt for backward ops
......@@ -156,6 +178,7 @@ class FP16State(object):
list
) # {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]}
self.is_train = False
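# maps each output var name to the original ids of the ops that write it;
# used below to keep in-place assign ops in the same precision as their producer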
self.out_var_op_deps = {}
def _is_fp16_op(self, op_id):
return self._op_fp16_dict.get(op_id, None)
......@@ -169,6 +192,14 @@ class FP16State(object):
# assume all backward blocks are behind forward blocks
for block in self.program.blocks:
for op in block.ops:
for name in op.output_arg_names:
if name not in self.out_var_op_deps:
self.out_var_op_deps[name] = [op.desc.original_id()]
else:
self.out_var_op_deps[name].extend(
[op.desc.original_id()]
)
self._mark_op(op)
# set forward tensor dtype
......@@ -192,6 +223,18 @@ class FP16State(object):
if op.type == "assign" and "array_" in op.input_arg_names[0]:
self._op_fp16_dict[op.desc.original_id()] = False
return
# If the assign op is an in-place operation, its execution mode should match that of the op that created its output var.
if op.type == "assign":
out_name = op.output_arg_names[0]
if len(self.out_var_op_deps[out_name]) > 1:
if not self._op_fp16_dict[
self.out_var_op_deps[out_name][0]
]:
self._op_fp16_dict[op.desc.original_id()] = False
else:
self._op_fp16_dict[op.desc.original_id()] = True
return
if _need_keep_fp32(
op, self.amp_list.unsupported_list, self.use_fp16_guard
):
......@@ -228,7 +271,7 @@ class FP16State(object):
return
if var.dtype == core.VarDesc.VarType.FP32:
var.desc.set_dtype(core.VarDesc.VarType.FP16)
var.desc.set_dtype(__target_dtype__)
def resolute_tensor_dtype(self, block):
......@@ -260,7 +303,7 @@ class FP16State(object):
out_var = block.vars.get(out_var_name)
if out_var is None or out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP16:
if out_var.dtype == __target_dtype__:
out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
elif is_backward_op(op):
if self._is_fp16_op(op.desc.original_id()) == True:
......@@ -276,7 +319,7 @@ class FP16State(object):
out_var = block.vars.get(out_var_name)
if out_var is None or out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP16:
if out_var.dtype == __target_dtype__:
out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
def cast_block(self, block):
......@@ -295,7 +338,7 @@ class FP16State(object):
op,
idx,
block,
core.VarDesc.VarType.FP16,
__target_dtype__,
core.VarDesc.VarType.FP32,
self.dist_context,
)
......@@ -305,7 +348,7 @@ class FP16State(object):
idx,
block,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
__target_dtype__,
self.dist_context,
)
elif is_backward_op(op):
......@@ -315,7 +358,7 @@ class FP16State(object):
op,
idx,
block,
core.VarDesc.VarType.FP16,
__target_dtype__,
core.VarDesc.VarType.FP32,
self.dist_context,
)
......@@ -325,7 +368,7 @@ class FP16State(object):
idx,
block,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
__target_dtype__,
self.dist_context,
)
elif op.type == "sum":
......@@ -399,6 +442,9 @@ class FP16State(object):
dist_context, cast_var, ref_mapping, ref_mesh
)
op_namescope = "/"
if op.has_attr('op_namescope'):
op_namescope = op.attr('op_namescope')
cast_op = block._insert_op_without_sync(
idx,
type="cast",
......@@ -410,6 +456,9 @@ class FP16State(object):
OP_ROLE_KEY: OpRole.Forward,
},
)
cast_op._set_attr(
'op_namescope', op_namescope
) # for recompute
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context
)
......@@ -455,22 +504,36 @@ class FP16State(object):
) in self.forward_input_cast_ops[forward_op_id]:
# rename input
# some forward outputs are not needed by the backward computation, e.g. logit in softmax_with_cross_entropy
if op.type != "scale" and slot_name in op.input_names:
assert src_name in op.input(
slot_name
), "var: {} not in op's {}. {}".format(src_name, slot_name, str(op))
), "var: {} not in op's {}. {}".format(
src_name, slot_name, str(op)
)
src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name)
assert src_var_dist_attr is not None
op._rename_input(src_name, cast_name)
grad_op_attr.set_input_dist_attr(cast_name, src_var_dist_attr)
# NOTE: special case for the scale op: the grad op of scale is also scale,
# so the slot-name mapping rule does not apply to the grad scale op.
# e.g. cast_name: mean_0.tmp_0.cast_bf16, src_name: mean_0.tmp_0, dst_dtype: paddle.bfloat16, src_dtype: paddle.float32, slot_name: X.
if op.type == "scale":
grad_slot_name = "X"
# create cast grad
else:
grad_slot_name = slot_name + "@GRAD"
assert grad_slot_name in op.output_names
if grad_slot_name in op.output_names:
# some forward inputs may be stop_gradient=True, e.g. input_mask
if len(op.output(grad_slot_name)) == 0:
var = block.var(src_name)
assert var.stop_gradient is True
continue
assert len(op.output(grad_slot_name)) == 1
assert (
len(op.output(grad_slot_name)) == 1
), "[{}], Current Op: {}".format(grad_slot_name, str(op))
grad_name = op.output(grad_slot_name)[0]
grad = block.var(grad_name)
grad_dist_attr = grad_op_attr.get_output_dist_attr(grad_name)
......@@ -492,7 +555,9 @@ class FP16State(object):
cast_grad, grad_dist_attr
)
op._rename_output(grad_name, cast_grad.name)
grad_op_attr.set_output_dist_attr(cast_grad.name, grad_dist_attr)
grad_op_attr.set_output_dist_attr(
cast_grad.name, grad_dist_attr
)
# add cast
cast_op = block._insert_op_without_sync(
......@@ -573,7 +638,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
def _split_grads(params_grads):
grads = [g for _, g in params_grads]
fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16]
fp16_grads = [g for g in grads if g.dtype == __target_dtype__]
assert len(fp32_grads) + len(fp16_grads) == len(
grads
), "Data types of all grads must be either fp16 or fp32."
......@@ -633,17 +698,15 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
# TODO to support CUDAPinned/NPU/XPU Places
if direction == "D2H":
dst_place_type = 0
elif direction == "D2H":
dst_place_type = 1
else:
raise NotImplementedError(
"direction [{}] is not supported yet.".format(direction)
f"direction [{direction}] is not supported yet."
)
attrs = {'dst_place_type': dst_place_type}
new_op = block._insert_op_without_sync(
index=idx,
type='memcpy',
type='memcpy_d2h',
inputs={'X': [src_var]},
outputs={'Out': [output_var]},
attrs=attrs,
......@@ -678,17 +741,17 @@ def cast_startup_program():
for op in startup_program.global_block().ops:
if is_initialization_op(op):
output_name = op.output_arg_names[0]
if (
param_to_dtype.get(output_name, None)
== core.VarDesc.VarType.FP16
):
if param_to_dtype.get(output_name, None) == __target_dtype__:
assert op.has_attr(
'dtype'
), "initialization op is supported to has dtype attribute but got {}.".format(
str(op)
)
out_var = startup_program.global_block().var(output_name)
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(__target_dtype__)
if op.attr('dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16)
op._set_attr('dtype', __target_dtype__)
@register_pass("auto_parallel_fp16")
......@@ -701,14 +764,44 @@ class FP16Pass(AMPPass):
# in a distributed scenario, all ranks should apply the same modification.
def _apply_single_impl(self, main_program, startup_program, context):
self.dist_context = self.get_attr("dist_context")
self.target_dtype = self.get_attr("dtype")
params_grads = self.get_attr("params_grads")
self.use_optimizer_fp16 = self.get_attr("use_optimizer_fp16", None)
if self.use_optimizer_fp16 is None:
self.use_optimizer_fp16 = self.get_attr("level", None) == "o3"
# switch environment for fp16 / bf16.
if self.target_dtype == "float16":
__target_dtype = core.VarDesc.VarType.FP16
elif self.target_dtype == "bfloat16":
__target_dtype = core.VarDesc.VarType.BF16
else:
raise NotImplementedError(
"target dtype [{}] is for amp o2 not supported yet.".format(
self.target_dtype
)
)
global __target_dtype__
__target_dtype__ = __target_dtype
amp_list = AutoMixedPrecisionLists(
set(self.get_attr("custom_white_list")),
set(self.get_attr("custom_black_list")),
None,
)
dtype=self.target_dtype,
)
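# control-flow and tensor-array ops are taken out of the unsupported list so
# they are not forced to stay in fp32 inside the low-precision region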
amp_list.unsupported_list -= {
"conditional_block_grad",
"conditional_block",
"conditional_block_infer",
"select_input",
"while",
"while_grad",
"cast",
"tensor_array_to_tensor",
"lod_array_length",
"write_to_array",
}
# NOTE do not change the input data dtype, since it is controlled by the dataloader
# and is outside the control of the FP16 pass
input_data_var_names = [var.name for var in self.get_attr("input_data")]
......@@ -726,6 +819,7 @@ class FP16Pass(AMPPass):
cast_startup_program()
if is_train:
if self.target_dtype == "fp16":
with paddle.static.program_guard(main_program, startup_program):
# TODO (JZ-LIANG) support casting only the forward program for inference
self._init_amp_var()
......@@ -801,10 +895,12 @@ class FP16Pass(AMPPass):
# modify optimizer
base_opt = self.get_attr("base_opt")
base_opt._multi_precision = True
if self.get_attr("use_optimizer_fp16"):
if self.use_optimizer_fp16:
base_opt._multi_precision = False
if self.target_dtype == "fp16":
if isinstance(
base_opt, (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW)
base_opt,
(paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW),
):
with main_program._optimized_guard([]):
# found_inf = paddle.tensor.creation._memcpy(
......
......@@ -40,6 +40,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_random_ctrl MODULES test_random_ctrl ENVS ${dist_ENVS})
set_tests_properties(test_random_ctrl PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
py_test_modules(test_amp_o2_pass MODULES test_amp_o2_pass ENVS ${dist_ENVS})
set_tests_properties(test_amp_o2_pass PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS
${dist_ENVS})
set_tests_properties(test_iterable_dataset
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import re
import unittest
import numpy as np
from get_gpt_model import FakeDataset, generate_model
import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.framework import core
paddle.enable_static()
def get_cuda_version():
result = os.popen("nvcc --version").read()
regex = r'release (\S+),'
match = re.search(regex, result)
if match:
num = str(match.group(1))
integer, decimal = num.split('.')
return int(integer) * 1000 + int(float(decimal) * 10)
else:
return -1
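# e.g. "release 11.2" parses to 11 * 1000 + 2 * 10 = 11020; -1 means nvcc is unavailable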
def apply_pass(use_amp=False, amp_dtype="bfloat16"):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_amp:
amp = strategy.amp
amp.enable = True
amp.dtype = amp_dtype
amp.level = "o2"
amp.custom_black_list = [
'c_softmax_with_cross_entropy',
'elementwise_div',
'reduce_sum',
]
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestShardingStage2WithNewEXE(unittest.TestCase):
def setUp(self):
self.batch_size = 2
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_amp=False, amp_dtype="bfloat16"):
reset_prog()
strategy = apply_pass(use_amp, amp_dtype)
# clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
clip = None
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("mp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_bf16(self, program):
num_bf16 = 0
num_fp16 = 0
num_fp32 = 0
for p in program.all_parameters():
if p.dtype == core.VarDesc.VarType.FP32:
num_fp32 += 1
if p.dtype == core.VarDesc.VarType.FP16:
num_fp16 += 1
if p.dtype == core.VarDesc.VarType.BF16:
num_bf16 += 1
self.assertEqual(num_bf16, 25)
self.assertEqual(num_fp16, 0)
self.assertEqual(num_fp32, 11)
def test_param_grad_fuse_overlap(self):
# std
mp_engine = self.get_engine(use_amp=False)
mp_history = mp_engine.fit(
self.dataset,
3,
epochs=1,
steps_per_epoch=self.batch_num,
log_freq=1,
batch_size=self.batch_size,
)
loss0 = mp_history.history['loss'][0]
# bf16
mp_bf16_engine = self.get_engine(use_amp=True)
if not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000:
return
mp_bf16_history = mp_bf16_engine.fit(
self.dataset,
3,
epochs=1,
steps_per_epoch=self.batch_num,
log_freq=1,
batch_size=self.batch_size,
)
loss1 = mp_bf16_history.history['loss'][0]
np.testing.assert_allclose(loss0, loss1, atol=1e-3, rtol=1e-2)
self.check_bf16(mp_bf16_engine.main_program)
if __name__ == "__main__":
unittest.main()
......@@ -38,7 +38,7 @@ def apply_pass(use_amp=False, level=None):
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.use_pure_fp16 = level in ["o2", "o3"]
amp.level = level
amp.use_optimizer_fp16 = level == "o3"
print("amp level: ", level)
return strategy
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import tempfile
import unittest
class TestAMPO2(unittest.TestCase):
def test_bf16(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "amp_o2_pass.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = (
[sys.executable, "-u"]
+ coverage_args
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
......@@ -13,13 +13,13 @@
# limitations under the License.
import os
# import yaml
import unittest
from paddle.distributed.fleet import auto
class TestStrategy(unittest.TestCase):
def test_default_config(self):
strategy = auto.Strategy()
......@@ -29,6 +29,8 @@ class TestStrategy(unittest.TestCase):
amp = strategy.amp
self.assertEqual(amp.enable, False)
self.assertAlmostEqual(amp.dtype, "float16")
self.assertAlmostEqual(amp.level, "o1")
self.assertAlmostEqual(amp.init_loss_scaling, 32768.0)
self.assertEqual(amp.incr_every_n_steps, 1000)
self.assertEqual(amp.decr_every_n_nan_or_inf, 2)
......@@ -38,8 +40,7 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(amp.custom_black_list, [])
self.assertEqual(amp.custom_white_list, [])
self.assertEqual(amp.custom_black_varnames, [])
self.assertEqual(amp.use_pure_fp16, False)
self.assertEqual(amp.use_fp16_guard, True)
self.assertEqual(amp.use_fp16_guard, False)
self.assertEqual(amp.use_optimizer_fp16, False)
sharding = strategy.sharding
......@@ -92,7 +93,6 @@ class TestStrategy(unittest.TestCase):
amp.custom_white_list = ["x"]
amp.custom_black_list = ["y"]
amp.custom_black_varnames = ["z"]
amp.use_pure_fp16 = True
amp.use_fp16_guard = False
amp.use_optimizer_fp16 = True
self.assertEqual(amp.enable, True)
......@@ -105,7 +105,6 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(amp.custom_white_list, ["x"])
self.assertEqual(amp.custom_black_list, ["y"])
self.assertEqual(amp.custom_black_varnames, ["z"])
self.assertEqual(amp.use_pure_fp16, True)
self.assertEqual(amp.use_fp16_guard, False)
self.assertEqual(amp.use_optimizer_fp16, True)
......