Unverified commit 3869a3b4 authored by JZ-LIANG and committed by GitHub

Cherry Pick Random Ctrl (#52778)

Parent 6959eae5
......@@ -19,5 +19,6 @@ from .interface import shard_tensor
from .interface import shard_op
from .interface import recompute
from .interface import fetch
from .random import parallel_manual_seed
__all__ = []
......@@ -51,7 +51,7 @@ from .dist_loader import (
from .process_group import new_process_group, get_all_process_groups
from .dist_context import DistributedContext, get_default_distributed_context
from .strategy import Strategy
from .interface import CollectionNames, get_collection
from .interface import CollectionNames, get_collection, fetch
from .utils import to_list, get_dist_attr, get_lr, validate_opt
from .utils import initialize_pg_in_full_mode, get_input_split_info
from .cost.estimate_cost import get_cost_from_engine
......@@ -438,6 +438,8 @@ class Engine:
), "user_fetches must be a list, but receive {}".format(
type(user_fetches).__name__
)
else:
user_fetches = []
fetch_names = []
fetch_indices = []
......@@ -464,10 +466,13 @@ class Engine:
_process_fetch_group("metrics_" + str(i), var_list)
if mode == "predict":
_process_fetch_group("outputs", fetch_vars["outputs"])
for usr_fetch in user_fetches:
var_name = _to_name_str(usr_fetch)
fetch(var_name)
user_fetches_collection = [
item[1] for item in get_collection(CollectionNames.FETCHES)
]
var_list = (user_fetches_collection or []) + (user_fetches or [])
var_list = user_fetches_collection or []
_process_fetch_group("fetches", var_list)
return fetch_names, fetch_indices
......@@ -522,10 +527,10 @@ class Engine:
# logging user fetches
collect_fetches = get_collection(CollectionNames.FETCHES)
logs_fetch = {}
for name, var in collect_fetches:
if var.name in fetch_names:
idx = fetch_names.index(var.name)
logs_fetch[name or var.name] = outs[idx]
for name, var_name in collect_fetches:
if var_name in fetch_names:
idx = fetch_names.index(var_name)
logs_fetch[name or var_name] = outs[idx]
logs["fetches"] = logs_fetch
return logs
......
......@@ -258,6 +258,16 @@ def add_to_collection(collection_name, value, name=None):
def fetch(tensor, name=None, logging=False):
if isinstance(tensor, paddle.fluid.framework.Variable):
tensor = tensor.name
elif isinstance(tensor, str):
tensor = tensor
else:
raise TypeError(
"Only support fetch `Variable` or `str`[`Variable`'s name], but got `{}`".format(
type(tensor)
)
)
add_to_collection(CollectionNames.FETCHES, tensor, name)
if logging:
add_to_collection(CollectionNames.LOGGING, tensor, name)
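# Illustration (not part of this commit): a hedged usage sketch of the extended `fetch`
# above. With this change it accepts either a Variable or a plain variable name (str);
# it only records the entry in CollectionNames.FETCHES, which Engine._prepare_fetch
# later resolves and Engine._prepare_logger reports under logs["fetches"]. The variable
# name and the `engine`/`data` objects below are assumptions.
#
#     from paddle.distributed.auto_parallel.interface import fetch
#
#     fetch("dropout_0.tmp_1", name="mask0", logging=True)   # register by name (str)
#     logs = engine.run(data, mode="train")
#     mask0_value = logs["fetches"]["mask0"]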
......@@ -36,3 +36,4 @@ from . import dist_reduce_sum_p
from . import dist_shape
from . import dist_assign
from . import dist_scale
from . import dist_dropout
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import logging
import paddle
from paddle.framework import core
from paddle.utils import unique_name
from ...utils.log_utils import get_logger
_logger = get_logger(logging.INFO)
from ..random import determinate_rng, is_enable_auto_rand_ctrl
from ..utils import (
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr,
)
from .common import (
DistributedOperatorImplContainer,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
)
from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0
class DistributedDropout(DistributedOperatorImplContainer):
def __init__(self, op_type):
super().__init__(op_type)
register_distributed_operator_impl_container(DistributedDropout("dropout"))
# Dist Dropout with Random Control
# Dropout reuses the compatibility and cost functions of elementwise
class DistributedDropoutImpl0(DistributedElementwiseImpl0):
def __init__(self, name):
super().__init__(name)
self._forward_implemented = True
self._backward_implemented = True
@staticmethod
def forward(ctx, *args, **kwargs):
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.work_block
startup_block = dist_op_context.startup_block
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
if is_enable_auto_rand_ctrl() and not op_dist_attr.is_recompute:
assert (
op_dist_attr is not None
), f"forward op [{str(src_op)}] don't have dist attribute !"
# check validation of inputs / outputs
assert 'X' in kwargs, "input [{}] is not given".format('X')
assert (
len(kwargs['X']) == 1
), "input X should be only one tensor but got {}".format(
kwargs['X']
)
assert 'Seed' in kwargs, "input [{}] is not given".format('Seed')
if (
src_op.has_attr("fix_seed")
and src_op.attr("fix_seed")
and src_op.has_attr("seed")
and src_op.attr("seed")
):
_logger.info(
"Auto Parallel Random Control Skiped Since manul seed is set by user: {}".format(
src_op
)
)
elif rank_id not in op_dist_attr.process_mesh.process_ids:
pass
# NOTE Adapted for recompute.
# If the user already set a seed, we should not modify it. But if the seed was added by the recompute pass, it should be brought under control.
# TODO in the future the recompute pass should happen after the parallel partition; remove this at that time.
elif len(kwargs['Seed']) > 0 or len(src_op.input("Seed")) > 0:
seed_var_name = kwargs['Seed'][0]
if seed_var_name.startswith('rc_seed'):
pre_op = main_block.ops[-1]
assert (
pre_op.type == "seed"
and len(pre_op.attr("rng_name")) == 0
), f"found exception op {str(pre_op)}"
# determinate rng
X_var = main_block._var_recursive(kwargs['X'][0])
X_dims_mapping = op_dist_attr.get_input_dims_mapping(
X_var.name
)
process_mesh = op_dist_attr.process_mesh
rng_name = determinate_rng(
rank_id, X_dims_mapping, process_mesh
)
# make recompute seed under control
pre_op._set_attr("rng_name", rng_name)
pre_op._set_attr("deterministic", True)
pre_op._set_attr("force_cpu", True)
else:
_logger.info(
"Auto Parallel Random Control Skiped Since manul seed is set by user: {}".format(
src_op
)
)
else:
# determinate rng
X_var = main_block._var_recursive(kwargs['X'][0])
X_dims_mapping = op_dist_attr.get_input_dims_mapping(X_var.name)
process_mesh = op_dist_attr.process_mesh
rng_name = determinate_rng(
rank_id, X_dims_mapping, process_mesh
)
assert rng_name is not None and rng_name != ""
# insert seed op
seed_var = main_block.create_var(
name=paddle.fluid.unique_name.generate_with_ignorable_key(
".".join(["tensor_parallel_seed", 'tmp'])
),
dtype=paddle.int32,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False,
)
# set new seed_var's dist_attr
seed_var_dims_mapping = [-1]
seed_var_dist_attr = set_var_dist_attr(
ctx, seed_var, seed_var_dims_mapping, process_mesh
)
# adapted for recompute
# force_cpu to avoid the sync copy CPU->GPU->CPU and reduce pipeline hangs
seed_op = main_block.append_op(
type='seed',
outputs={'Out': seed_var},
attrs={
'deterministic': True,
'rng_name': rng_name,
'force_cpu': True,
},
)
seed_op._set_attr('op_namescope', 'auto_tensor_parallel_seed')
# set new seed op's dist_attr
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
seed_op, process_mesh, seed_var_dims_mapping, ctx
)
# modify dropout op
src_op.desc.set_input("Seed", [seed_var.name])
src_op._remove_attr("fix_seed")
src_op._remove_attr("seed")
op_dist_attr.set_input_dist_attr(
seed_var.name, seed_var_dist_attr
)
kwargs['Seed'] = [seed_var.name]
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
# dropout backward is made deterministic by the mask and does not need random state control
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl(
"dropout", DistributedDropoutImpl0("random_control")
)
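# Illustration (not part of this commit): a rough, hedged sketch of what this impl does to
# a dropout op in the serial program when auto random control is enabled and no manual
# seed is set (op/attr values below are examples, not literal output):
#
#     before:  out, mask = dropout(X, fix_seed=..., seed=...)
#     after:   seed_var = seed(rng_name='mesh:1_dim0:-1_dim1:0',
#                              deterministic=True, force_cpu=True)
#              out, mask = dropout(X, Seed=seed_var)   # fix_seed/seed attrs removed
#
# Ranks holding different shards of X get different rng_name values (independent masks),
# while ranks holding replicas share the same rng_name (identical masks).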
......@@ -17,47 +17,74 @@ from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import gradient_synchronization
from .common import register_distributed_operator_impl, set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related
from .common import (
register_distributed_operator_impl,
set_comm_op_dist_attr_for_program,
naive_copy_op_dist_attr_for_program,
is_parameter_related,
)
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from ..dist_attribute import (
OperatorDistributedAttribute,
TensorDistributedAttribute,
)
from paddle.fluid import core, unique_name
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import Program, Parameter, Variable
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from paddle.distributed.fleet.meta_optimizers.common import (
OpRole,
OP_ROLE_KEY,
OP_ROLE_VAR_KEY,
)
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank, set_var_dist_attr
from ..utils import (
_get_comm_group,
_get_idx_in_axis,
_get_corresponding_rank,
set_var_dist_attr,
)
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op
from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs, build_dp_costs
from ..cost import (
build_comm_costs_from_descs,
build_comp_costs_from_descs,
build_dp_costs,
)
from ..cost import EmbeddingOpCost, EmbeddingGradOpCost
from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost, IdentityOpCost
from paddle.distributed.auto_parallel.cost.comm_op_cost import (
AllreduceSumOpCost,
IdentityOpCost,
)
class DistributedEmbedding(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedEmbedding, self).__init__(op_type)
register_distributed_operator_impl_container(
DistributedEmbedding("lookup_table_v2"))
DistributedEmbedding("lookup_table_v2")
)
register_distributed_operator_impl_container(
DistributedEmbedding("c_embedding"))
DistributedEmbedding("c_embedding")
)
register_distributed_operator_impl_container(
DistributedEmbedding("lookup_table"))
DistributedEmbedding("lookup_table")
)
def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var):
assert len(
Ids_var.shape
) == 3, "input Ids to lookup_table should have 3 dimensions but got [{}] with shape [{}]".format(
Ids_var.name, Ids_var.shape)
assert (
len(Ids_var.shape) == 3
), "input Ids to lookup_table should have 3 dimensions but got [{}] with shape [{}]".format(
Ids_var.name, Ids_var.shape
)
if not Ids_var.stop_gradient:
raise NotImplementedError(
'Requiring the gradient of Ids of the lookup_table(v1) dist op is not currently supported. Please open an issue with details on your use case so that we can prioritize adding this (for instance, adversarial training for language model).'
......@@ -65,59 +92,72 @@ def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var):
target_shape = list(Ids_var.shape[:-1])
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["dist_reshape", 'tmp'])),
name=unique_name.generate_with_ignorable_key(
".".join(["dist_reshape", 'tmp'])
),
dtype=Ids_var.dtype,
shape=target_shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=True)
stop_gradient=True,
)
target_shape = [0] + list(Ids_var.shape[:-1])
xshape_var = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["dist_Xshape", 'tmp'])),
name=unique_name.generate_with_ignorable_key(
".".join(["dist_Xshape", 'tmp'])
),
dtype=Ids_var.dtype,
shape=target_shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=True)
stop_gradient=True,
)
# TODO use inplace reshape for memory saving
reshape_op = main_block.append_op(type='reshape2',
reshape_op = main_block.append_op(
type='reshape2',
inputs={'X': [Ids_var]},
outputs={
'Out': [intermediate_var_0],
'XShape': [xshape_var]
},
outputs={'Out': [intermediate_var_0], 'XShape': [xshape_var]},
attrs={
"shape": [0, -1],
})
},
)
# set dist attr
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
Ids_var_dist_attr = op_dist_attr.get_input_dist_attr(Ids_var.name)
assert Ids_var_dist_attr is not None
intermediate_var_0_dist_attr = set_var_dist_attr(
ctx, intermediate_var_0, Ids_var_dist_attr.dims_mapping,
Ids_var_dist_attr.process_mesh)
set_var_dist_attr(ctx, xshape_var,
ctx,
intermediate_var_0,
Ids_var_dist_attr.dims_mapping,
Ids_var_dist_attr.process_mesh,
)
set_var_dist_attr(
ctx,
xshape_var,
[-1] + list(Ids_var_dist_attr.dims_mapping),
Ids_var_dist_attr.process_mesh)
Ids_var_dist_attr.process_mesh,
)
op_dist_attr.del_input_dist_attr(Ids_var.name)
op_dist_attr.set_input_dist_attr(intermediate_var_0.name,
intermediate_var_0_dist_attr)
op_dist_attr.set_input_dist_attr(
intermediate_var_0.name, intermediate_var_0_dist_attr
)
new_op_dist_attr = OperatorDistributedAttribute()
new_op_dist_attr.process_mesh = Ids_var_dist_attr.process_mesh
new_op_dist_attr.impl_type = "default"
new_op_dist_attr.impl_idx = 0
new_op_dist_attr.set_input_dims_mapping(Ids_var.name,
Ids_var_dist_attr.dims_mapping)
new_op_dist_attr.set_output_dims_mapping(intermediate_var_0.name,
Ids_var_dist_attr.dims_mapping)
new_op_dist_attr.set_input_dims_mapping(
Ids_var.name, Ids_var_dist_attr.dims_mapping
)
new_op_dist_attr.set_output_dims_mapping(
xshape_var.name, [-1] + list(Ids_var_dist_attr.dims_mapping))
intermediate_var_0.name, Ids_var_dist_attr.dims_mapping
)
new_op_dist_attr.set_output_dims_mapping(
xshape_var.name, [-1] + list(Ids_var_dist_attr.dims_mapping)
)
ctx.set_op_dist_attr_for_program(reshape_op, new_op_dist_attr)
return intermediate_var_0
......@@ -125,7 +165,6 @@ def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var):
# RowParallel
class DistributedEmbeddingImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedEmbeddingImpl, self).__init__(name)
self._forward_implemented = True
......@@ -143,17 +182,19 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
# embedding need start_index
cost_mapping = build_comp_costs_from_descs(EmbeddingOpCost, ctx,
processes, desc_mapping,
cluster)
cost_mapping = build_comp_costs_from_descs(
EmbeddingOpCost, ctx, processes, desc_mapping, cluster
)
serial_op = dist_op.serial_op
parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("W")[0])[0]
serial_op.input("W")[0]
)[0]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.output("Out")
c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
......@@ -162,11 +203,16 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping,
cluster)
AllreduceSumOpCost,
ctx,
processes,
c_allreduce_sum_desc_mapping,
cluster,
)
res_cost = [cost_mapping, comm_op_cost_list]
......@@ -180,7 +226,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
dist_attr = dist_op.dist_attr
embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("W")[0])[0]
backward_op.input("W")[0]
)[0]
parallel_axis = embedding_row_dim_mapping
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = [backward_op.input("Out@GRAD")[0]]
......@@ -190,33 +237,38 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
parallel_axis=parallel_axis,
)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster)
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
)
res.append(comm_op_cost_list)
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
cost_mapping = build_comp_costs_from_descs(EmbeddingGradOpCost, ctx,
processes, desc_mapping,
cluster)
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
cost_mapping = build_comp_costs_from_descs(
EmbeddingGradOpCost, ctx, processes, desc_mapping, cluster
)
res.append(cost_mapping)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Ids")[0])
backward_op.input("Ids")[0]
)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('W@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
build_dp_costs(
res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res
......@@ -228,7 +280,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard(
w_dims_mapping[-1]):
w_dims_mapping[-1]
):
return False
# Other dimensions must be replicate except the batch dimension
for mapping in ids_dims_mapping[1:]:
......@@ -248,8 +301,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
if (not self.is_input_compatible(dist_op)) or (
not self.is_output_compatible(dist_op)
):
return False
op_desc = dist_op.serial_op.desc
......@@ -261,7 +315,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
if ids_dims_mapping != out_dims_mapping[:len(ids_dims_mapping)]:
if ids_dims_mapping != out_dims_mapping[: len(ids_dims_mapping)]:
return False
return True
......@@ -279,12 +333,14 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
for i in range(len(ids_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[ids_dims_mapping, out_dims_mapping], [i, i])
[ids_dims_mapping, out_dims_mapping], [i, i]
)
if dim_changed:
changed = True
dim_changed = compute_compatible_and_update_dim_mapping(
[w_dims_mapping, out_dims_mapping], [-1, -1])
[w_dims_mapping, out_dims_mapping], [-1, -1]
)
if dim_changed:
changed = True
......@@ -302,26 +358,30 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op))
assert (
op_dist_attr is not None
), "forward op [{}] don't have dist attribute !".format(str(src_op))
# check validation of inputs / outputs
assert 'Ids' in kwargs, "input [{}] is not given".format('Ids')
assert 'W' in kwargs, "input [{}] is not given".format('W')
assert 'Out' in kwargs, "output [{}] is not given".format('Out')
assert len(
assert (
len(kwargs['Ids']) == 1
), "row_parallel_embedding input Ids take 1 variable but got {}".format(
kwargs['Ids']
) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format(
kwargs['Ids'])
assert len(
)
assert (
len(kwargs['W']) == 1
), "row_parallel_embedding input W take 1 variable but got {}".format(
kwargs['W']
) == 1, "row_parallel_embedding input W take 1 variable but got {}".format(
kwargs['W'])
assert len(
)
assert (
len(kwargs['Out']) == 1
), "row_parallel_embedding output Out take 1 variable but got {}".format(
kwargs['Out']
) == 1, "row_parallel_embedding output Out take 1 variable but got {}".format(
kwargs['Out'])
)
Ids_var = main_block.var(kwargs['Ids'][0])
Weight_var = main_block._var_recursive(kwargs['W'][0])
......@@ -333,70 +393,85 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
# got dist attribute info
embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[0]
assert embedding_row_dim_mapping >= 0, "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format(
embedding_row_dim_mapping)
Weight_var.name
)[0]
assert (
embedding_row_dim_mapping >= 0
), "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format(
embedding_row_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh_group:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id)
rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id
)
# A generalized method to calculate the embedding offset using a cartesian product
relative_idx = _get_idx_in_axis(process_mesh_group, process_mesh_shape,
embedding_row_dim_mapping, rank_id)
relative_idx = _get_idx_in_axis(
process_mesh_group,
process_mesh_shape,
embedding_row_dim_mapping,
rank_id,
)
per_part_size = Weight_var.shape[0]
relative_idx = relative_idx * per_part_size
# TODO calculate ring id
parallel_axis = embedding_row_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
parallel_axis, rank_id)
group_ranks = _get_comm_group(
process_mesh_group, process_mesh_shape, parallel_axis, rank_id
)
group = new_process_group(group_ranks)
# append op
check_variable_and_dtype(Ids_var, 'input', ['int32', 'int64'],
'c_embedding')
check_variable_and_dtype(
Ids_var, 'input', ['int32', 'int64'], 'c_embedding'
)
# infer new var shape with op dist attr
out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
assert out_tensor_dist_attr is not None
out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert out_var_dist_attr is not None
ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr,
out_var_dist_attr)
ref_shape = infer_shape(
main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
)
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_embedding", 'tmp'])),
name=unique_name.generate_with_ignorable_key(
".".join(["c_embedding", 'tmp'])
),
dtype=Weight_var.dtype,
shape=Out_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=Out_var.stop_gradient)
stop_gradient=Out_var.stop_gradient,
)
# set intermediate_var_0's dist_attr with Out_var's dist_attr
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
out_var_dist_attr)
ctx.set_tensor_dist_attr_for_program(
intermediate_var_0, out_var_dist_attr
)
check_variable_and_dtype(
Out_var, 'tensor',
Out_var,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'c_allreduce_sum')
'c_allreduce_sum',
)
c_embedding_op = main_block.append_op(
type='c_embedding',
inputs={
'Ids': [Ids_var],
'W': [Weight_var]
},
inputs={'Ids': [Ids_var], 'W': [Weight_var]},
outputs={'Out': [intermediate_var_0]},
attrs={
"start_index": relative_idx,
OP_ROLE_KEY: src_op.attr('op_role')
})
OP_ROLE_KEY: src_op.attr('op_role'),
},
)
if intermediate_var_0.shape != ref_shape:
intermediate_var_0.desc.set_shape(ref_shape)
......@@ -409,8 +484,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role')
})
OP_ROLE_KEY: src_op.attr('op_role'),
},
)
if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape)
......@@ -423,15 +499,19 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
for input_varname in c_embedding_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
embedding_op_dist_attr.set_input_dist_attr(input_varname,
input_dist_attr)
op_dist_attr
)
embedding_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
output_varname = c_embedding_op.desc.output_arg_names()[0]
output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
embedding_op_dist_attr.set_output_dist_attr(output_varname,
output_dist_attr)
op_dist_attr
)
embedding_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(c_embedding_op, embedding_op_dist_attr)
# allreduce
......@@ -443,16 +523,20 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
input_var = main_block.var(input_varname)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
assert tensor_dist_attr is not None
allreduce_op_dist_attr.set_input_dist_attr(input_varname,
tensor_dist_attr)
allreduce_op_dist_attr.set_input_dist_attr(
input_varname, tensor_dist_attr
)
for output_varname in c_allreduce_sum_op.desc.output_arg_names():
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
allreduce_op_dist_attr.set_output_dist_attr(output_varname,
output_dist_attr)
ctx.set_op_dist_attr_for_program(c_allreduce_sum_op,
allreduce_op_dist_attr)
op_dist_attr
)
allreduce_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(
c_allreduce_sum_op, allreduce_op_dist_attr
)
# param initialization sync
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
......@@ -469,20 +553,25 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
if size <= 1 or axis in dim_mapping:
pass
else:
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology, axis,
rank_id)
group_ranks = _get_comm_group(
process_mesh.processes,
process_mesh.topology,
axis,
rank_id,
)
sync_group = new_process_group(group_ranks)
startup_block.append_op(type='c_broadcast',
startup_block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': sync_group.id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
OP_ROLE_KEY: OpRole.Forward,
},
)
@staticmethod
def backward(ctx, *args, **kwargs):
......@@ -493,35 +582,43 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
backward_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(backward_op))
assert (
dist_attr is not None
), "backward op [{}] don't have dist attribute !".format(
str(backward_op)
)
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh,
rank_id)
rank_id = _get_corresponding_rank(
ctx, dist_attr.process_mesh, rank_id
)
assert 'Ids' in kwargs, "input [{}] is not given".format('Ids')
assert 'W' in kwargs, "input [{}] is not given".format('W')
assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out')
assert 'W@GRAD' in kwargs, "output [{}] is not given".format('W@GRAD')
assert len(
assert (
len(kwargs['Ids']) == 1
), "row_parallel_embedding input Ids take 1 variable but got {}".format(
kwargs['Ids']
) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format(
kwargs['Ids'])
assert len(
)
assert (
len(kwargs['W']) == 1
), "row_parallel_embedding input Ids take 1 variable but got {}".format(
kwargs['W']
) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format(
kwargs['W'])
assert len(
kwargs['Out@GRAD']
) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format(
kwargs['Out'])
assert len(
)
assert (
len(kwargs['Out@GRAD']) == 1
), "row_parallel_embedding input Ids take 1 variable but got {}".format(
kwargs['Out']
)
assert (
len(kwargs['W@GRAD']) == 1
), "row_parallel_embedding output Ids take 1 variable but got {}".format(
kwargs['W@GRAD']
) == 1, "row_parallel_embedding output Ids take 1 variable but got {}".format(
kwargs['W@GRAD'])
)
Ids_var = main_block.var(kwargs['Ids'][0])
Weight_var = main_block.var(kwargs['W'][0])
......@@ -529,39 +626,57 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
Weight_grad = main_block.var(kwargs['W@GRAD'][0])
embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
Weight_var.name)[0]
assert embedding_row_dim_mapping >= 0, "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format(
embedding_row_dim_mapping)
Weight_var.name
)[0]
assert (
embedding_row_dim_mapping >= 0
), "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format(
embedding_row_dim_mapping
)
process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_group = dist_attr.process_mesh.processes
# A generalized method to calculate the embedding offset using a cartesian product
relative_idx = _get_idx_in_axis(process_mesh_group, process_mesh_shape,
embedding_row_dim_mapping, rank_id)
relative_idx = _get_idx_in_axis(
process_mesh_group,
process_mesh_shape,
embedding_row_dim_mapping,
rank_id,
)
per_part_size = Weight_var.shape[0]
relative_idx = relative_idx * per_part_size
check_variable_and_dtype(
Out_grad, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity')
Out_grad,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_identity',
)
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_embedding", '@tmp_0@GRAD'])),
name=unique_name.generate_with_ignorable_key(
".".join(["c_embedding", '@tmp_0@GRAD'])
),
dtype=Out_grad.dtype,
shape=Out_grad.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=Out_grad.stop_gradient)
stop_gradient=Out_grad.stop_gradient,
)
# copy X_var's dist_attr to intermediate_var_0's dist_attr
out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name)
assert out_grad_dist_attr is not None
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
out_grad_dist_attr)
ctx.set_tensor_dist_attr_for_program(
intermediate_var_0, out_grad_dist_attr
)
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
embedding_row_dim_mapping, rank_id)
group_ranks = _get_comm_group(
process_mesh_group,
process_mesh_shape,
embedding_row_dim_mapping,
rank_id,
)
group = new_process_group(group_ranks)
c_identity_op = main_block.append_op(
......@@ -573,41 +688,54 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: OpRole.Backward,
})
check_variable_and_dtype(intermediate_var_0, 'x',
['float16', 'float32', 'float64'], 'linear')
check_dtype(intermediate_var_0.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
},
)
check_variable_and_dtype(
intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
)
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64'],
'linear',
)
set_comm_op_dist_attr_for_program(c_identity_op, dist_attr.process_mesh,
out_grad_dist_attr, ctx)
set_comm_op_dist_attr_for_program(
c_identity_op, dist_attr.process_mesh, out_grad_dist_attr, ctx
)
c_embedding_grad_op_desc = main_block.append_op(type='nop').desc
c_embedding_grad_op_desc.set_type("c_embedding_grad")
c_embedding_grad_op_desc.set_input('Ids', [Ids_var.name])
c_embedding_grad_op_desc.set_input('W', [Weight_var.name])
c_embedding_grad_op_desc.set_input('Out@GRAD',
[intermediate_var_0.name])
c_embedding_grad_op_desc.set_input(
'Out@GRAD', [intermediate_var_0.name]
)
c_embedding_grad_op_desc.set_output('W@GRAD', [Weight_grad.name])
c_embedding_grad_op_desc._set_attr('start_index', relative_idx)
c_embedding_grad_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward)
c_embedding_grad_op = main_block.ops[-1]
assert c_embedding_grad_op.type == "c_embedding_grad"
naive_copy_op_dist_attr_for_program(c_embedding_grad_op, backward_op,
ctx)
naive_copy_op_dist_attr_for_program(
c_embedding_grad_op, backward_op, ctx
)
# data parallel gradient synchronization
act_grad_names = [Ids_var.name]
out_grad_names = [kwargs['W@GRAD'][0]]
gradient_synchronization(ctx, backward_op, act_grad_names,
out_grad_names, rank_id)
gradient_synchronization(
ctx, backward_op, act_grad_names, out_grad_names, rank_id
)
register_distributed_operator_impl("lookup_table_v2",
DistributedEmbeddingImpl("row_parallel"))
register_distributed_operator_impl("c_embedding",
DistributedEmbeddingImpl("row_parallel"))
register_distributed_operator_impl("lookup_table",
DistributedEmbeddingImpl("row_parallel"))
register_distributed_operator_impl(
"lookup_table_v2", DistributedEmbeddingImpl("row_parallel")
)
register_distributed_operator_impl(
"c_embedding", DistributedEmbeddingImpl("row_parallel")
)
register_distributed_operator_impl(
"lookup_table", DistributedEmbeddingImpl("row_parallel")
)
......@@ -25,6 +25,7 @@ from .reshard import Resharder
from .partitioner import Partitioner
from .utils import set_grad_var_shape
from .process_group import get_world_process_group
from .random import init_auto_parallel_rng
from ..utils.log_utils import get_logger
......@@ -84,6 +85,9 @@ class Parallelizer:
) = partitioner.partition(
serial_main_program, serial_startup_program, params_grads
)
init_auto_parallel_rng()
self._logger.debug(
"within parallel partitioner time: {}, mode {}".format(
time.time() - time0, self._mode
......
......@@ -19,6 +19,8 @@ import paddle
# Use to store the previous and current process mesh
_g_previous_process_mesh = None
_g_current_process_mesh = None
# {shape_process_ids : unique_id}
_g_unique_process_mesh_map = {}
def get_current_process_mesh():
......@@ -39,6 +41,30 @@ def reset_current_process_mesh():
_g_current_process_mesh = _g_previous_process_mesh
def get_unique_id_for_process_mesh(shape, process_ids):
key = f"shape {shape}, process_ids {process_ids}"
global _g_unique_process_mesh_map
if key in _g_unique_process_mesh_map:
unique_id = _g_unique_process_mesh_map[key]
else:
unique_id = len(_g_unique_process_mesh_map) + 1
_g_unique_process_mesh_map[key] = unique_id
return unique_id
def retrive_unique_id_for_process_mesh(shape, process_ids):
key = f"shape {shape}, process_ids {process_ids}"
global _g_unique_process_mesh_map
assert key in _g_unique_process_mesh_map
return _g_unique_process_mesh_map[key]
def get_unique_process_mesh_map():
global _g_unique_process_mesh_map
return _g_unique_process_mesh_map
class ProcessMesh(object):
"""
The `ProcessMesh` object describes the topology of the processes used.
......@@ -113,6 +139,11 @@ class ProcessMesh(object):
pg0 = get_process_group(0)
pg0.add_ranks(self.processes)
# Unique Mesh Id
self._unique_id = get_unique_id_for_process_mesh(
self._shape, self._process_ids
)
@property
def shape(self):
"""
......@@ -148,6 +179,16 @@ class ProcessMesh(object):
"""
return self._mesh
@property
def unique_id(self):
"""
Get the unique id of ProcessMesh.
NOTE:
The unique id only takes process_ids and shape into account.
Different ProcessMesh objects with the same process_ids and shape share the same unique id.
"""
return self._unique_id
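# Illustration (not part of this commit): a hedged sketch of the property above. Meshes
# built over the same process_ids and shape share one unique id, while a different layout
# gets a new one (assumes an initialized distributed environment with 8 ranks):
#
#     mesh_a = ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])   # shape [2, 4]
#     mesh_b = ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
#     mesh_c = ProcessMesh([0, 1, 2, 3, 4, 5, 6, 7])       # shape [8]
#     assert mesh_a.unique_id == mesh_b.unique_id
#     assert mesh_a.unique_id != mesh_c.unique_id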
@property
def topology(self):
return self._shape
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import paddle
from ..utils.log_utils import get_logger
from .process_mesh import retrive_unique_id_for_process_mesh
from .utils import _get_idx_in_axis
_logger = get_logger(logging.INFO)
_rng_name_to_seed = {}
_inited_rng_name_to_seed = {}
_enable_random_control = False
_basic_seed = 42
# use prime numbers as offsets to avoid conflicts
_mesh_offset = 173
_dim_offsets = [11, 23, 37, 73]
def is_enable_auto_rand_ctrl():
global _enable_random_control
return _enable_random_control
def enable_auto_rand_ctrl():
global _enable_random_control
_enable_random_control = True
def parallel_manual_seed(seed):
"""Enable auto parallel random control.
Random control maintain the randomness when tensor is distributed across devices on a Mesh(any order).
* Independency: If tensor is **Sharded** on a Mesh dimension, Devices along that Mesh dimension should have Different randomness.
* Consistency: Meanwhile if the tensor is **Replicated** on another Mesh dimension, randomness of Devices along that Mesh dimension should be Consistent.
For instance: rank0 ~ rank7 consist a Mesh of shape of [2, 4]; A 2D tensor is distributed in that Mesh using dims_mapping [-1, 1].
Randomness for rank0-rank1-rank2-rank3 (rank4-rank5-rank6-rank7) should be Independent;
Randomness for rank0 and rank4 (rank1 and rank5, ...) should be Consistent.
This function should be called only once before auto parallel compiles the computation graph (e.g. auto_parallel.engine.prepare() or fit()).
This seed only affects how randomness-relative **operators** (dropout, fuse op with dropout inside, etc) are execute amonge mesh, and would NOT affect other processe like Parameter initialization.
Examples:
# seed relative to training step
auto_parallel_random_seed((step + 13) * 257)
...
engine.prepare()
"""
enable_auto_rand_ctrl()
global _basic_seed
_basic_seed = seed
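# Illustration (not part of this commit): a hedged usage sketch of the API above. The call
# only stores the base seed and enables the control flag; the named generators are created
# later by init_auto_parallel_rng() inside Parallelizer.parallel(). Model/optimizer/strategy
# construction is elided and those names are assumptions; the import path mirrors the
# unit test added in this commit.
#
#     import paddle
#     from paddle.distributed.fleet import auto
#
#     paddle.distributed.auto_parallel.parallel_manual_seed(100)  # once, before prepare()/fit()
#     engine = auto.Engine(model, loss, opt, strategy=strategy)
#     engine.prepare(mode="train")   # dropout seed ops now carry per-shard rng_name attrs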
def determinate_rng(rank, dims_mapping, process_mesh):
# TODO(JZ-LIANG) support a Mesh of any (higher) rank:
# use a string-to-unique-integer hashing algorithm for seed computation,
# instead of using offsets to coordinate seeds across devices.
if len(process_mesh.shape) > 4:
raise NotImplementedError(
"Auto Parallel Random Control for Mesh's rank > 4 is NOT supported! Got {}".format(
str(process_mesh)
)
)
global _basic_seed
seed_ = _basic_seed
# FIXME
# unique_id = process_mesh.unique_id
unique_id = retrive_unique_id_for_process_mesh(
process_mesh.shape, process_mesh.process_ids
)
sharding_expr = f'mesh:{unique_id}'
seed_ += _mesh_offset * (unique_id + 1)
for i in range(len(process_mesh.shape)):
if i not in dims_mapping:
relative_idx = -1
else:
relative_idx = _get_idx_in_axis(
process_mesh.process_ids,
process_mesh.shape,
i,
rank,
)
sharding_expr += f"_dim{i}:{relative_idx}"
seed_ += _dim_offsets[i] * (relative_idx + 1)
global _rng_name_to_seed
if sharding_expr in _rng_name_to_seed:
assert _rng_name_to_seed[sharding_expr] == seed_
else:
assert (
seed_ not in _rng_name_to_seed.values()
), "Seed Confilt! current seed: {}, current sharding expr: {}, generated seed: {}".format(
seed_, sharding_expr, _rng_name_to_seed
)
_rng_name_to_seed[sharding_expr] = seed_
return sharding_expr
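# Illustration (not part of this commit): a self-contained toy re-implementation of the
# seed arithmetic above, to show the resulting rng_name/seed for a [2, 4] mesh with
# process_ids 0..7, mesh unique_id 1 and dims_mapping [-1, 1] (replicated on mesh dim 0,
# sharded on mesh dim 1). A row-major mesh layout is assumed for the axis index.
import math

def toy_determinate_rng(rank, dims_mapping, shape, process_ids, unique_id=1):
    seed = 42 + 173 * (unique_id + 1)            # _basic_seed + _mesh_offset * (id + 1)
    name = f"mesh:{unique_id}"
    for i, size in enumerate(shape):
        if i not in dims_mapping:
            idx = -1                             # replicated along this mesh dim
        else:                                    # sharded: index of rank along dim i
            idx = (process_ids.index(rank) // math.prod(shape[i + 1:])) % size
        name += f"_dim{i}:{idx}"
        seed += [11, 23, 37, 73][i] * (idx + 1)  # _dim_offsets
    return name, seed

# ranks 0 and 4 share the dim-1 index -> same generator (Consistency):
#   toy_determinate_rng(0, [-1, 1], [2, 4], list(range(8))) == ('mesh:1_dim0:-1_dim1:0', 411)
#   toy_determinate_rng(4, [-1, 1], [2, 4], list(range(8))) == ('mesh:1_dim0:-1_dim1:0', 411)
# ranks 0 and 1 differ in the dim-1 index -> different generators (Independency):
#   toy_determinate_rng(1, [-1, 1], [2, 4], list(range(8))) == ('mesh:1_dim0:-1_dim1:1', 434)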
def init_auto_parallel_rng():
if not is_enable_auto_rand_ctrl():
return
global _rng_name_to_seed
# NOTE init_auto_parallel_rng may be called multiple times; avoid initializing the same rng twice
global _inited_rng_name_to_seed
for rng_name, seed in _rng_name_to_seed.items():
if rng_name in _inited_rng_name_to_seed:
assert _inited_rng_name_to_seed[rng_name] == seed
else:
_logger.info(
f"Init Auto Parallel RNG: {rng_name}, with seed {seed}"
)
paddle.framework.random.set_random_seed_generator(rng_name, seed)
_inited_rng_name_to_seed[rng_name] = seed
......@@ -22,13 +22,20 @@ from paddle.fluid.framework import Variable, Operator
from paddle.fluid.backward import _append_grad_suffix_, _get_no_grad_set_name
from paddle.fluid.backward import ProgramStats, _rename_arg_, _find_op_path_
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
from paddle.distributed.auto_parallel.utils import get_loss_op, set_var_dist_attr, set_dist_op_desc_original_id
from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping
from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistributedAttribute,
)
from paddle.distributed.auto_parallel.utils import (
get_loss_op,
set_var_dist_attr,
set_dist_op_desc_original_id,
)
from paddle.distributed.auto_parallel.utils import (
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
)
class RecomputeState(ProgramStats):
def __init__(self, block, ops):
super(RecomputeState, self).__init__(block=block, ops=ops)
self._block = block
......@@ -54,7 +61,7 @@ class RecomputeState(ProgramStats):
self.var_op_deps[name]["var_as_output_ops"] = [i]
def get_recompute_segments(self, checkpoints):
""" get recompute segments from checkpoints """
"""get recompute segments from checkpoints"""
segments = []
start_idx = -1
pre_segment_end_idx = -1
......@@ -69,27 +76,37 @@ class RecomputeState(ProgramStats):
segments.append([0, max(op_idx_list) + 1])
else:
flag, min_idx, max_idx = self.is_subgraph(
[checkpoints[start_idx]], [checkpoints[start_idx + 1]])
[checkpoints[start_idx]], [checkpoints[start_idx + 1]]
)
if flag:
min_idx = self._update_segment_start(
min_idx, pre_segment_end_idx)
min_idx, pre_segment_end_idx
)
segments.append([min_idx, max_idx + 1])
else:
logging.info(
"Could not recompute op range [{}] - [{}] ".format(
min_idx, max_idx + 1))
min_idx, max_idx + 1
)
)
start_idx += 1
for i, (idx1, idx2) in enumerate(segments):
logging.info("recompute segment[{}]".format(i))
logging.info("segment start op: [{}]: [{}] [{}]".format(
logging.info(
"segment start op: [{}]: [{}] [{}]".format(
self._ops[idx1].desc.type(),
self._ops[idx1].desc.input_arg_names(),
self._ops[idx1].desc.output_arg_names()))
logging.info("segment end op: [{}]: [{}] [{}]".format(
self._ops[idx1].desc.output_arg_names(),
)
)
logging.info(
"segment end op: [{}]: [{}] [{}]".format(
self._ops[idx2 - 1].desc.type(),
self._ops[idx2 - 1].desc.input_arg_names(),
self._ops[idx2 - 1].desc.output_arg_names()))
self._ops[idx2 - 1].desc.output_arg_names(),
)
)
return segments
......@@ -116,45 +133,52 @@ class RecomputeState(ProgramStats):
cur_op_dist_attr = dist_context.get_op_dist_attr_for_program(cur_op)
# insert seed op to guarantee that two dropout op have the same outputs
op_unique_name = unique_name.generate("seed")
var_unique_name = unique_name.generate_with_ignorable_key(".".join(
[op_unique_name, 'tmp']))
# NOTE Hack to adapt recompute for random control; for more info see dist_dropout.py.
# A new seed added by recompute should have a prefix to distinguish it from seeds added by the user or other modules.
op_unique_name = unique_name.generate("rc_seed")
var_unique_name = unique_name.generate_with_ignorable_key(
".".join([op_unique_name, 'tmp'])
)
seed_var = self._block.create_var(
name=var_unique_name,
dtype='int32',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
stop_gradient=False,
)
# set new seed_var's dist_attr
ref_dims_mapping = [-1]
ref_process_mesh = cur_op_dist_attr.process_mesh
seed_var_dist_attr = set_var_dist_attr(dist_context, seed_var,
ref_dims_mapping,
ref_process_mesh)
seed = 0 if cur_op.attr("fix_seed") is False else int(
cur_op.attr("seed"))
seed_var_dist_attr = set_var_dist_attr(
dist_context, seed_var, ref_dims_mapping, ref_process_mesh
)
seed = (
0
if cur_op.attr("fix_seed") is False
else int(cur_op.attr("seed"))
)
seed_op = self._block._insert_op_without_sync(
index=cur_op.idx,
type="seed",
inputs={},
outputs={"Out": seed_var},
attrs={
"seed": seed,
"force_cpu": True
})
attrs={"seed": seed, "force_cpu": True},
)
# set new seed op's dist_attr
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
seed_op, ref_process_mesh, ref_dims_mapping, dist_context)
seed_op, ref_process_mesh, ref_dims_mapping, dist_context
)
# modify dropout op's desc
self._ops.insert(op_idx, seed_op)
cur_op.desc.set_input("Seed", [var_unique_name])
cur_op._remove_attr("fix_seed")
cur_op._remove_attr("seed")
cur_op_dist_attr.set_input_dist_attr(seed_var.name,
seed_var_dist_attr)
cur_op_dist_attr.set_input_dist_attr(
seed_var.name, seed_var_dist_attr
)
op_idx += 2
self._block._sync_with_cpp()
......@@ -168,7 +192,7 @@ def _find_op_index(block, cur_op):
def _get_stop_gradients(program, no_grad_set):
""" get no grad var """
"""get no grad var"""
if no_grad_set is None:
no_grad_set = set()
else:
......@@ -185,8 +209,9 @@ def _get_stop_gradients(program, no_grad_set):
return no_grad_set_name
def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars,
dist_context):
def _add_needed_descs_to_block(
descs, block, main_block, in_memory_vars, dist_context
):
"""
Get the recomputed ops which will insert the backward part
"""
......@@ -217,7 +242,6 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars,
@register_pass("auto_parallel_recompute")
class RecomputePass(PassBase):
def __init__(self):
super(RecomputePass, self).__init__()
self.set_attr("checkpoints", None)
......@@ -261,12 +285,15 @@ class RecomputePass(PassBase):
vars_should_be_hold = []
for segment in segments:
vars_should_be_hold.extend(
rc_state.get_out_of_subgraph_vars(segment[0], segment[1]))
rc_state.get_out_of_subgraph_vars(segment[0], segment[1])
)
cross_vars = set(vars_should_be_hold) - set(checkpoints)
logging.info(
"found [{}] vars which cross recompute segment: [{}],"
"better checkpoints might be set to reduce those vars".format(
len(cross_vars), cross_vars))
len(cross_vars), cross_vars
)
)
vars_should_be_hold.extend(rc_state.get_reserved_vars())
vars_should_be_hold.extend(rc_state.get_input_nodes())
vars_should_be_hold = list(set(vars_should_be_hold))
......@@ -277,14 +304,15 @@ class RecomputePass(PassBase):
ckpt_ops_dict = {}
buffer_block = main_block.program._create_block()
for i, segment in enumerate(segments[::-1]):
fwd_ops = op_path[segment[0]:segment[1]]
fwd_ops = op_path[segment[0] : segment[1]]
var_suffix = ".subprog_%d" % i
for op in fwd_ops:
input_and_output_names = []
input_and_output_names.extend(op.desc.input_arg_names())
input_and_output_names.extend(op.desc.output_arg_names())
cur_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
op)
cur_op_dist_attr = (
self._dist_context.get_op_dist_attr_for_program(op)
)
assert cur_op_dist_attr is not None
for name in input_and_output_names:
if main_block.var(name).persistable or name in checkpoints:
......@@ -294,11 +322,13 @@ class RecomputePass(PassBase):
if name not in var_name_dict:
ref_process_mesh = cur_op_dist_attr.process_mesh
if name in op.desc.input_arg_names():
ref_dims_mapping = cur_op_dist_attr.get_input_dims_mapping(
name)
ref_dims_mapping = (
cur_op_dist_attr.get_input_dims_mapping(name)
)
else:
ref_dims_mapping = cur_op_dist_attr.get_output_dims_mapping(
name)
ref_dims_mapping = (
cur_op_dist_attr.get_output_dims_mapping(name)
)
# record recomputed var's old_name and new_name (old_name.subprog_XXX)
# create new var with new name
var_name_dict[name] = name + var_suffix
......@@ -309,15 +339,23 @@ class RecomputePass(PassBase):
dtype=ref_var.dtype,
type=ref_var.type,
persistable=ref_var.persistable,
stop_gradient=ref_var.stop_gradient)
stop_gradient=ref_var.stop_gradient,
)
# set new recomputed var's dist attr
set_var_dist_attr(self._dist_context, rc_var,
ref_dims_mapping, ref_process_mesh)
set_var_dist_attr(
self._dist_context,
rc_var,
ref_dims_mapping,
ref_process_mesh,
)
# get recomputed segment's descs
segment_descs = _add_needed_descs_to_block(fwd_ops, buffer_block,
segment_descs = _add_needed_descs_to_block(
fwd_ops,
buffer_block,
main_block,
vars_in_memory,
self._dist_context)
self._dist_context,
)
# rename recomputed ops' input and output var name
for key in var_name_dict:
_rename_arg_(segment_descs, key, var_name_dict[key])
......@@ -345,7 +383,10 @@ class RecomputePass(PassBase):
# rename grad op's var_name which is not in 'vars_in_memory'
for key in var_name_dict:
if key not in grad_op.input_arg_names + grad_op.output_arg_names:
if (
key
not in grad_op.input_arg_names + grad_op.output_arg_names
):
continue
self.reset_op_dist_attr(grad_op, var_name_dict)
_rename_arg_([grad_op.desc], key, var_name_dict[key])
......@@ -360,17 +401,20 @@ class RecomputePass(PassBase):
idx -= 1
segment_descs = ckpt_ops_dict[fwd_op_id][1]
for _, op_desc in reversed(list(enumerate(segment_descs))):
rc_op = main_block._insert_op_without_sync(idx,
type='nop')
rc_op = main_block._insert_op_without_sync(
idx, type='nop'
)
rc_desc = rc_op.desc
rc_desc.copy_from(op_desc)
rc_desc.set_original_id(rc_desc.id())
# set recomputed ops' dist attr
fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id(
op_desc.original_id())
op_desc.original_id()
)
assert fwd_op_dist_attr is not None
self.set_op_dist_attr(rc_op, fwd_op_dist_attr,
var_name_dict)
self.set_op_dist_attr(
rc_op, fwd_op_dist_attr, var_name_dict
)
ckpt_ops_dict[fwd_op_id][0] = False
......@@ -382,13 +426,15 @@ class RecomputePass(PassBase):
for input in op.desc.input_arg_names():
if input in var_name_dict.keys():
in_dist_attr = op_dist_attr.get_input_dist_attr(input)
op_dist_attr.set_input_dist_attr(var_name_dict[input],
in_dist_attr)
op_dist_attr.set_input_dist_attr(
var_name_dict[input], in_dist_attr
)
for output in op.desc.output_arg_names():
if output in var_name_dict.keys():
out_dist_attr = op_dist_attr.get_output_dist_attr(output)
op_dist_attr.set_output_dist_attr(var_name_dict[output],
out_dist_attr)
op_dist_attr.set_output_dist_attr(
var_name_dict[output], out_dist_attr
)
def set_op_dist_attr(self, op, old_dist_attr, var_name_dict):
new_dist_attr = OperatorDistributedAttribute()
......@@ -399,16 +445,18 @@ class RecomputePass(PassBase):
for input in old_dist_attr.inputs_dist_attrs.keys():
if input in var_name_dict.keys():
in_dist_attr = old_dist_attr.inputs_dist_attrs[input]
new_dist_attr.set_input_dist_attr(var_name_dict[input],
in_dist_attr)
new_dist_attr.set_input_dist_attr(
var_name_dict[input], in_dist_attr
)
else:
in_dist_attr = old_dist_attr.inputs_dist_attrs[input]
new_dist_attr.set_input_dist_attr(input, in_dist_attr)
for output in old_dist_attr.outputs_dist_attrs.keys():
if output in var_name_dict.keys():
out_dist_attr = old_dist_attr.outputs_dist_attrs[output]
new_dist_attr.set_output_dist_attr(var_name_dict[output],
out_dist_attr)
new_dist_attr.set_output_dist_attr(
var_name_dict[output], out_dist_attr
)
else:
out_dist_attr = old_dist_attr.outputs_dist_attrs[output]
new_dist_attr.set_output_dist_attr(output, out_dist_attr)
......
......@@ -37,6 +37,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
${dist_ENVS})
set_tests_properties(test_high_order_grad
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_random_ctrl MODULES test_random_ctrl ENVS ${dist_ENVS})
set_tests_properties(test_random_ctrl PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS
${dist_ENVS})
set_tests_properties(test_iterable_dataset
......
......@@ -21,14 +21,17 @@ from paddle.distributed.fleet import auto
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion
from auto_parallel_gpt_model import (
GPTModel,
GPTForPretraining,
GPTPretrainingCriterion,
)
sequence_len = 512
vocab_size = 1000
class FakeDataset(paddle.io.Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
self.sequence_len = sequence_len
......@@ -40,8 +43,11 @@ class FakeDataset(paddle.io.Dataset):
random.seed(2021)
tokens = np.random.randint(self.vocab_size, size=self.sequence_len)
position_ids = np.arange(self.sequence_len)
attention_mask = np.tril(np.ones(self.sequence_len)).reshape(
(1, self.sequence_len, self.sequence_len)).astype(np.float32)
attention_mask = (
np.tril(np.ones(self.sequence_len))
.reshape((1, self.sequence_len, self.sequence_len))
.astype(np.float32)
)
labels = np.random.randint(self.vocab_size, size=self.sequence_len)
loss_mask = np.ones(self.sequence_len).astype(np.float32)
return tokens, position_ids, attention_mask, labels, loss_mask
......@@ -51,30 +57,32 @@ class FakeDataset(paddle.io.Dataset):
def create_data_holder(batch_size):
tokens = paddle.static.InputSpec(name="tokens",
shape=[batch_size, sequence_len],
dtype='int64')
position_ids = paddle.static.InputSpec(name="position_ids",
shape=[batch_size, sequence_len],
dtype='int64')
tokens = paddle.static.InputSpec(
name="tokens", shape=[batch_size, sequence_len], dtype='int64'
)
position_ids = paddle.static.InputSpec(
name="position_ids", shape=[batch_size, sequence_len], dtype='int64'
)
attention_mask = paddle.static.InputSpec(
name="attention_mask",
shape=[batch_size, 1, sequence_len, sequence_len],
dtype='float32')
labels = paddle.static.InputSpec(name="labels",
shape=[batch_size, sequence_len],
dtype='int64')
loss_mask = paddle.static.InputSpec(name="loss_mask",
shape=[batch_size, sequence_len],
dtype='float32')
dtype='float32',
)
labels = paddle.static.InputSpec(
name="labels", shape=[batch_size, sequence_len], dtype='int64'
)
loss_mask = paddle.static.InputSpec(
name="loss_mask", shape=[batch_size, sequence_len], dtype='float32'
)
return [tokens, position_ids, attention_mask], [labels, loss_mask]
def generate_model(strategy):
def generate_model(strategy, dropout_prob=0.0):
modeling.init_global()
ranks = list(range(paddle.distributed.get_world_size()))
modeling._global_process_mesh = auto.ProcessMesh(mesh=ranks,
dim_names=["x"])
modeling._global_process_mesh = auto.ProcessMesh(
mesh=ranks, dim_names=["x"]
)
if strategy == "serial":
modeling._global_parallel_strategy = "serial"
elif strategy == "mp":
......@@ -84,24 +92,25 @@ def generate_model(strategy):
else:
raise ValueError("Only support serial, mp2 and dp2.")
gpt = GPTModel(vocab_size=1000,
gpt = GPTModel(
vocab_size=1000,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=8,
intermediate_size=256,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
hidden_dropout_prob=dropout_prob,
attention_probs_dropout_prob=dropout_prob,
max_position_embeddings=1024,
type_vocab_size=1,
initializer_range=0.02,
pad_token_id=0,
eos_token_id=7,
bos_token_id=0,
eol_token_id=3)
model = GPTForPretraining(gpt,
vocab_size=1000,
hidden_size=64,
initializer_range=0.02)
eol_token_id=3,
)
model = GPTForPretraining(
gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
)
criterion = GPTPretrainingCriterion()
return model, criterion
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
import numpy as np
from get_gpt_model import FakeDataset, generate_model
import paddle
paddle.enable_static()
from paddle import _legacy_C_ops
from paddle.distributed.fleet import auto
def dy_broadcast_helper(tensor):
_legacy_C_ops.c_broadcast(
tensor, tensor, 'root', 1, 'use_calc_stream', True, 'ring_id', 1000
)
_legacy_C_ops.c_sync_calc_stream(tensor, tensor)
def apply_pass(use_recompute=False, no_recompute_segments=[]):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_recompute:
recompute = strategy.recompute
recompute.enable = True
recompute.no_recompute_segments = no_recompute_segments
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestRandomControl(unittest.TestCase):
def setUp(self):
self.rtol = 1e-6
self.atol = 1e-8
self.batch_size = 1
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
paddle.distributed.auto_parallel.parallel_manual_seed(100)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_recompute=False, no_recompute_segments=[]):
reset_prog()
strategy = apply_pass(use_recompute, no_recompute_segments)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("mp", dropout_prob=0.1)
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def compare_mask_between_ranks(
self, rank, mask_np_list, compare_idx, equal
):
for np_mask in [mask_np_list[i] for i in compare_idx]:
mask_tensor_local = paddle.to_tensor(np_mask.astype("float32"))
if rank == 0:
mask_tensor_remote = paddle.ones_like(mask_tensor_local)
dy_broadcast_helper(mask_tensor_remote)
if equal:
assert np.array_equal(
mask_tensor_remote.numpy(), mask_tensor_local.numpy()
)
else:
assert not np.array_equal(
mask_tensor_remote.numpy(),
mask_tensor_local.numpy(),
)
else:
dy_broadcast_helper(mask_tensor_local)
def test_random_ctrl_vanilla(self):
# mp2 recompute training
rc_engine = self.get_engine(False)
train_dataloader = rc_engine.dataloader(
self.dataset,
batch_size=self.batch_size,
mode="train",
sample_split=3,
)
rc_engine.prepare(mode="train")
mask_name_list = [f'dropout_{i}.tmp_1' for i in range(7)]
mask_var_list = [
rc_engine.main_program.global_block().var(varname)
for varname in mask_name_list
]
for data in train_dataloader:
outs = rc_engine.run(data, fetch_list=mask_var_list, mode="train")
mask_np_list = [outs['fetches'][varname] for varname in mask_name_list]
paddle.disable_static()
rank = paddle.distributed.get_rank()
# check global masks are consistent across ranks
global_index = [0, 2, 3, 5, 6]
self.compare_mask_between_ranks(
rank, mask_np_list, global_index, equal=True
)
local_index = [1, 4]
# check local masks differ across ranks
self.compare_mask_between_ranks(
rank, mask_np_list, local_index, equal=False
)
paddle.enable_static()
# check program
ops = rc_engine.main_program.global_block().ops
rng_names = []
seed_var_names = []
for op in ops:
if op.type == "seed":
rng_names.append(op.attr('rng_name'))
if op.type == "dropout":
seed_var_names.append(op.input("Seed")[0])
rank = paddle.distributed.get_rank()
self.assertEqual(
rng_names,
[
'mesh:1_dim0:-1',
f'mesh:1_dim0:{rank}',
'mesh:1_dim0:-1',
'mesh:1_dim0:-1',
f'mesh:1_dim0:{rank}',
'mesh:1_dim0:-1',
'mesh:1_dim0:-1',
],
)
self.assertEqual(
seed_var_names,
[
'tensor_parallel_seed.tmp_0',
'tensor_parallel_seed.tmp_1',
'tensor_parallel_seed.tmp_2',
'tensor_parallel_seed.tmp_3',
'tensor_parallel_seed.tmp_4',
'tensor_parallel_seed.tmp_5',
'tensor_parallel_seed.tmp_6',
],
)
def test_random_ctrl_with_recompute(self):
# mp2 recompute training
rc_engine = self.get_engine(True)
train_dataloader = rc_engine.dataloader(
self.dataset,
batch_size=self.batch_size,
mode="train",
sample_split=3,
)
rc_engine.prepare(mode="train")
mask_name_list = [f'dropout_{i}.tmp_1' for i in range(7)]
recompute_mask_name_list = [
'dropout_0.tmp_1.subprog_1',
'dropout_1.tmp_1.subprog_1',
'dropout_2.tmp_1.subprog_1',
'dropout_3.tmp_1.subprog_1',
'dropout_4.tmp_1.subprog_0',
'dropout_5.tmp_1.subprog_0',
'dropout_6.tmp_1.subprog_0',
]
mask_var_list = [
rc_engine.main_program.global_block().var(varname)
for varname in mask_name_list + recompute_mask_name_list
]
for data in train_dataloader:
outs = rc_engine.run(data, fetch_list=mask_var_list, mode="train")
mask_np_list = [
outs['fetches'][varname]
for varname in mask_name_list + recompute_mask_name_list
]
# check that recompute masks are the same within the local device
for i in range(7):
mask_fw = mask_np_list[i].astype("float32")
mask_rc = mask_np_list[i + 7].astype("float32")
assert np.array_equal(
mask_fw,
mask_rc,
)
paddle.disable_static()
# check global masks are consistent across ranks
rank = paddle.distributed.get_rank()
global_index = [0, 2, 3, 5, 6]
self.compare_mask_between_ranks(
rank, mask_np_list, global_index, equal=True
)
local_index = [1, 4]
# check local masks differ across ranks
self.compare_mask_between_ranks(
rank, mask_np_list, local_index, equal=False
)
paddle.enable_static()
# check program
rank = paddle.distributed.get_rank()
ops = rc_engine.main_program.global_block().ops
rng_names = []
seed_var_names = []
for op in ops:
if op.type == "seed":
rng_names.append(op.attr('rng_name'))
if op.type == "dropout":
seed_var_names.append(op.input("Seed")[0])
self.assertEqual(
rng_names,
[
'mesh:1_dim0:-1',
f'mesh:1_dim0:{rank}',
'mesh:1_dim0:-1',
'mesh:1_dim0:-1',
f'mesh:1_dim0:{rank}',
'mesh:1_dim0:-1',
'mesh:1_dim0:-1',
],
)
self.assertEqual(
seed_var_names,
[
'rc_seed_0.tmp_0',
'rc_seed_1.tmp_0',
'rc_seed_2.tmp_0',
'rc_seed_3.tmp_0',
'rc_seed_4.tmp_0',
'rc_seed_5.tmp_0',
'rc_seed_6.tmp_0',
'rc_seed_4.tmp_0',
'rc_seed_5.tmp_0',
'rc_seed_6.tmp_0',
'rc_seed_0.tmp_0',
'rc_seed_1.tmp_0',
'rc_seed_2.tmp_0',
'rc_seed_3.tmp_0',
],
)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import tempfile
import unittest
class TestRandomCtrlPass(unittest.TestCase):
def test_mp2_with_recompute(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "random_control_unittest.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = (
[sys.executable, "-u"]
+ coverage_args
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()