Unverified commit 3869a3b4, authored by JZ-LIANG, committed by GitHub

Cherry Pick Random Ctrl (#52778)

Parent 6959eae5
@@ -19,5 +19,6 @@ from .interface import shard_tensor
 from .interface import shard_op
 from .interface import recompute
 from .interface import fetch
+from .random import parallel_manual_seed

 __all__ = []
@@ -51,7 +51,7 @@ from .dist_loader import (
 from .process_group import new_process_group, get_all_process_groups
 from .dist_context import DistributedContext, get_default_distributed_context
 from .strategy import Strategy
-from .interface import CollectionNames, get_collection
+from .interface import CollectionNames, get_collection, fetch
 from .utils import to_list, get_dist_attr, get_lr, validate_opt
 from .utils import initialize_pg_in_full_mode, get_input_split_info
 from .cost.estimate_cost import get_cost_from_engine
@@ -438,6 +438,8 @@ class Engine:
             ), "user_fetches must be a list, but receive {}".format(
                 type(user_fetches).__name__
             )
+        else:
+            user_fetches = []
         fetch_names = []
         fetch_indices = []
@@ -464,10 +466,13 @@ class Engine:
                 _process_fetch_group("metrics_" + str(i), var_list)
         if mode == "predict":
             _process_fetch_group("outputs", fetch_vars["outputs"])
+        for usr_fetch in user_fetches:
+            var_name = _to_name_str(usr_fetch)
+            fetch(var_name)
         user_fetches_collection = [
             item[1] for item in get_collection(CollectionNames.FETCHES)
         ]
-        var_list = (user_fetches_collection or []) + (user_fetches or [])
+        var_list = user_fetches_collection or []
         _process_fetch_group("fetches", var_list)
         return fetch_names, fetch_indices
@@ -522,10 +527,10 @@ class Engine:
         # logging user fetches
         collect_fetches = get_collection(CollectionNames.FETCHES)
         logs_fetch = {}
-        for name, var in collect_fetches:
-            if var.name in fetch_names:
-                idx = fetch_names.index(var.name)
-                logs_fetch[name or var.name] = outs[idx]
+        for name, var_name in collect_fetches:
+            if var_name in fetch_names:
+                idx = fetch_names.index(var_name)
+                logs_fetch[name or var_name] = outs[idx]
         logs["fetches"] = logs_fetch
         return logs
...
@@ -258,6 +258,16 @@ def add_to_collection(collection_name, value, name=None):
 def fetch(tensor, name=None, logging=False):
+    if isinstance(tensor, paddle.fluid.framework.Variable):
+        tensor = tensor.name
+    elif isinstance(tensor, str):
+        tensor = tensor
+    else:
+        raise TypeError(
+            "Only support fetch `Variable` or `str`[`Variable`'s name], but got `{}`".format(
+                type(tensor)
+            )
+        )
     add_to_collection(CollectionNames.FETCHES, tensor, name)
     if logging:
         add_to_collection(CollectionNames.LOGGING, tensor, name)
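A hedged usage sketch of the reworked fetch interface above (the variable name below is hypothetical, and the import path is the module this hunk touches); it only illustrates the new type handling and is not part of the commit:

from paddle.distributed.auto_parallel.interface import (
    CollectionNames,
    fetch,
    get_collection,
)

# Register a fetch either by passing a Variable (its .name is taken) or a str name.
fetch("linear_0.tmp_0", name="linear_out", logging=True)

# The collection stores (name, var_name) pairs that Engine._prepare_fetch later consumes,
# roughly [('linear_out', 'linear_0.tmp_0')].
print(get_collection(CollectionNames.FETCHES))

# Anything other than a Variable or str is rejected.
try:
    fetch(123)
except TypeError as err:
    print(err)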
@@ -36,3 +36,4 @@ from . import dist_reduce_sum_p
 from . import dist_shape
 from . import dist_assign
 from . import dist_scale
+from . import dist_dropout
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import logging
import paddle
from paddle.framework import core
from paddle.utils import unique_name
from ...utils.log_utils import get_logger
_logger = get_logger(logging.INFO)
from ..random import determinate_rng, is_enable_auto_rand_ctrl
from ..utils import (
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr,
)
from .common import (
DistributedOperatorImplContainer,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
)
from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0
class DistributedDropout(DistributedOperatorImplContainer):
def __init__(self, op_type):
super().__init__(op_type)
register_distributed_operator_impl_container(DistributedDropout("dropout"))
# Dist Dropout with Random Control
# Dropout reuses the compatibility and cost functions of elementwise
class DistributedDropoutImpl0(DistributedElementwiseImpl0):
def __init__(self, name):
super().__init__(name)
self._forward_implemented = True
self._backward_implemented = True
@staticmethod
def forward(ctx, *args, **kwargs):
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.work_block
startup_block = dist_op_context.startup_block
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
if is_enable_auto_rand_ctrl() and not op_dist_attr.is_recompute:
assert (
op_dist_attr is not None
), f"forward op [{str(src_op)}] don't have dist attribute !"
# check validation of inputs / outputs
assert 'X' in kwargs, "input [{}] is not given".format('X')
assert (
len(kwargs['X']) == 1
), "input X should be only one tensor but got {}".format(
kwargs['X']
)
assert 'Seed' in kwargs, "input [{}] is not given".format('Seed')
if (
src_op.has_attr("fix_seed")
and src_op.attr("fix_seed")
and src_op.has_attr("seed")
and src_op.attr("seed")
):
_logger.info(
"Auto Parallel Random Control Skiped Since manul seed is set by user: {}".format(
src_op
)
)
elif rank_id not in op_dist_attr.process_mesh.process_ids:
pass
# NOTE Adopt for recompute
# If the user has already set a seed, we should not modify it. But if the seed was added by the recompute pass, it should be brought under control.
# TODO: in the future the recompute pass should happen after the parallel partition; remove this workaround at that time.
elif len(kwargs['Seed']) > 0 or len(src_op.input("Seed")) > 0:
seed_var_name = kwargs['Seed'][0]
if seed_var_name.startswith('rc_seed'):
pre_op = main_block.ops[-1]
assert (
pre_op.type == "seed"
and len(pre_op.attr("rng_name")) == 0
), f"found exception op {str(pre_op)}"
# determinate rng
X_var = main_block._var_recursive(kwargs['X'][0])
X_dims_mapping = op_dist_attr.get_input_dims_mapping(
X_var.name
)
process_mesh = op_dist_attr.process_mesh
rng_name = determinate_rng(
rank_id, X_dims_mapping, process_mesh
)
# make recompute seed under control
pre_op._set_attr("rng_name", rng_name)
pre_op._set_attr("deterministic", True)
pre_op._set_attr("force_cpu", True)
else:
_logger.info(
"Auto Parallel Random Control Skiped Since manul seed is set by user: {}".format(
src_op
)
)
else:
# determinate rng
X_var = main_block._var_recursive(kwargs['X'][0])
X_dims_mapping = op_dist_attr.get_input_dims_mapping(X_var.name)
process_mesh = op_dist_attr.process_mesh
rng_name = determinate_rng(
rank_id, X_dims_mapping, process_mesh
)
assert rng_name is not None and rng_name != ""
# insert seed op
seed_var = main_block.create_var(
name=paddle.fluid.unique_name.generate_with_ignorable_key(
".".join(["tensor_parallel_seed", 'tmp'])
),
dtype=paddle.int32,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False,
)
# set new seed_var's dist_attr
seed_var_dims_mapping = [-1]
seed_var_dist_attr = set_var_dist_attr(
ctx, seed_var, seed_var_dims_mapping, process_mesh
)
# adopt for recompute
# force_cpu to reduce sync copy from CPU->GPU->CPU, and reduce pipeline hang
seed_op = main_block.append_op(
type='seed',
outputs={'Out': seed_var},
attrs={
'deterministic': True,
'rng_name': rng_name,
'force_cpu': True,
},
)
seed_op._set_attr('op_namescope', 'auto_tensor_parallel_seed')
# set new seed op's dist_attr
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
seed_op, process_mesh, seed_var_dims_mapping, ctx
)
# modify dropout op
src_op.desc.set_input("Seed", [seed_var.name])
src_op._remove_attr("fix_seed")
src_op._remove_attr("seed")
op_dist_attr.set_input_dist_attr(
seed_var.name, seed_var_dist_attr
)
kwargs['Seed'] = [seed_var.name]
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
# dropout backward is deterministic by mask, and not need for random state control
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl(
"dropout", DistributedDropoutImpl0("random_control")
)
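What this operator impl buys in practice: every dropout draws its mask from an RNG stream named after the rank's position in the mesh, so replicated ranks reproduce the same mask while sharded ranks get independent ones. A standalone NumPy sketch of that property (illustrative seeds and mesh layout, not the Paddle API):

import numpy as np

# Hypothetical per-rank seeds for a [2, 4] mesh where the dropped tensor is
# sharded along mesh axis 1 and replicated along mesh axis 0: ranks in the same
# column share a seed (consistency), ranks in the same row differ (independence).
rank_to_seed = {rank: 1000 + rank % 4 for rank in range(8)}

def dropout_mask(seed, shape, p=0.5):
    # Draw the keep/drop mask from the rank's dedicated RNG stream.
    rng = np.random.default_rng(seed)
    return rng.random(shape) >= p

masks = {rank: dropout_mask(seed, (8, 8)) for rank, seed in rank_to_seed.items()}
assert (masks[0] == masks[4]).all()      # replicated pair: identical masks
assert not (masks[0] == masks[1]).all()  # sharded neighbours: independent masks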
...@@ -17,47 +17,74 @@ from .common import DistributedOperatorImplContainer ...@@ -17,47 +17,74 @@ from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl_container
from .common import gradient_synchronization from .common import gradient_synchronization
from .common import register_distributed_operator_impl, set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related from .common import (
register_distributed_operator_impl,
set_comm_op_dist_attr_for_program,
naive_copy_op_dist_attr_for_program,
is_parameter_related,
)
from ..utils import is_dim_shard from ..utils import is_dim_shard
from ..utils import is_dim_replicate from ..utils import is_dim_replicate
from ..utils import is_valid_list_index from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping
from ..dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute from ..dist_attribute import (
OperatorDistributedAttribute,
TensorDistributedAttribute,
)
from paddle.fluid import core, unique_name from paddle.fluid import core, unique_name
from paddle.fluid.framework import _non_static_mode from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import Program, Parameter, Variable from paddle.fluid.framework import Program, Parameter, Variable
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from paddle.distributed.fleet.meta_optimizers.common import (
OpRole,
OP_ROLE_KEY,
OP_ROLE_VAR_KEY,
)
from ..process_group import new_process_group from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank, set_var_dist_attr from ..utils import (
_get_comm_group,
_get_idx_in_axis,
_get_corresponding_rank,
set_var_dist_attr,
)
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op
from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs, build_dp_costs from ..cost import (
build_comm_costs_from_descs,
build_comp_costs_from_descs,
build_dp_costs,
)
from ..cost import EmbeddingOpCost, EmbeddingGradOpCost from ..cost import EmbeddingOpCost, EmbeddingGradOpCost
from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost, IdentityOpCost from paddle.distributed.auto_parallel.cost.comm_op_cost import (
AllreduceSumOpCost,
IdentityOpCost,
)
class DistributedEmbedding(DistributedOperatorImplContainer): class DistributedEmbedding(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedEmbedding, self).__init__(op_type) super(DistributedEmbedding, self).__init__(op_type)
register_distributed_operator_impl_container( register_distributed_operator_impl_container(
DistributedEmbedding("lookup_table_v2")) DistributedEmbedding("lookup_table_v2")
)
register_distributed_operator_impl_container( register_distributed_operator_impl_container(
DistributedEmbedding("c_embedding")) DistributedEmbedding("c_embedding")
)
register_distributed_operator_impl_container( register_distributed_operator_impl_container(
DistributedEmbedding("lookup_table")) DistributedEmbedding("lookup_table")
)
def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var): def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var):
assert len( assert (
Ids_var.shape len(Ids_var.shape) == 3
) == 3, "input Ids to lookup_table should have 3 dimensions but got [{}] with shape [{}]".format( ), "input Ids to lookup_table should have 3 dimensions but got [{}] with shape [{}]".format(
Ids_var.name, Ids_var.shape) Ids_var.name, Ids_var.shape
)
if not Ids_var.stop_gradient: if not Ids_var.stop_gradient:
raise NotImplementedError( raise NotImplementedError(
'Requiring the gradient of Ids of lookup_table(v1)dist op is not currently supported. Please open an issue with details on your use case so that we can prioritize adding this (for instance, adversarial training for language model).' 'Requiring the gradient of Ids of lookup_table(v1)dist op is not currently supported. Please open an issue with details on your use case so that we can prioritize adding this (for instance, adversarial training for language model).'
...@@ -65,59 +92,72 @@ def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var): ...@@ -65,59 +92,72 @@ def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var):
target_shape = list(Ids_var.shape[:-1]) target_shape = list(Ids_var.shape[:-1])
intermediate_var_0 = main_block.create_var( intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join( name=unique_name.generate_with_ignorable_key(
["dist_reshape", 'tmp'])), ".".join(["dist_reshape", 'tmp'])
),
dtype=Ids_var.dtype, dtype=Ids_var.dtype,
shape=target_shape, shape=target_shape,
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False, persistable=False,
stop_gradient=True) stop_gradient=True,
)
target_shape = [0] + list(Ids_var.shape[:-1]) target_shape = [0] + list(Ids_var.shape[:-1])
xshape_var = main_block.create_var( xshape_var = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join( name=unique_name.generate_with_ignorable_key(
["dist_Xshape", 'tmp'])), ".".join(["dist_Xshape", 'tmp'])
),
dtype=Ids_var.dtype, dtype=Ids_var.dtype,
shape=target_shape, shape=target_shape,
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False, persistable=False,
stop_gradient=True) stop_gradient=True,
)
# TODO use inplace reshape for memory saving # TODO use inplace reshape for memory saving
reshape_op = main_block.append_op(type='reshape2', reshape_op = main_block.append_op(
inputs={'X': [Ids_var]}, type='reshape2',
outputs={ inputs={'X': [Ids_var]},
'Out': [intermediate_var_0], outputs={'Out': [intermediate_var_0], 'XShape': [xshape_var]},
'XShape': [xshape_var] attrs={
}, "shape": [0, -1],
attrs={ },
"shape": [0, -1], )
})
# set dist attr # set dist attr
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
Ids_var_dist_attr = op_dist_attr.get_input_dist_attr(Ids_var.name) Ids_var_dist_attr = op_dist_attr.get_input_dist_attr(Ids_var.name)
assert Ids_var_dist_attr is not None assert Ids_var_dist_attr is not None
intermediate_var_0_dist_attr = set_var_dist_attr( intermediate_var_0_dist_attr = set_var_dist_attr(
ctx, intermediate_var_0, Ids_var_dist_attr.dims_mapping, ctx,
Ids_var_dist_attr.process_mesh) intermediate_var_0,
set_var_dist_attr(ctx, xshape_var, Ids_var_dist_attr.dims_mapping,
[-1] + list(Ids_var_dist_attr.dims_mapping), Ids_var_dist_attr.process_mesh,
Ids_var_dist_attr.process_mesh) )
set_var_dist_attr(
ctx,
xshape_var,
[-1] + list(Ids_var_dist_attr.dims_mapping),
Ids_var_dist_attr.process_mesh,
)
op_dist_attr.del_input_dist_attr(Ids_var.name) op_dist_attr.del_input_dist_attr(Ids_var.name)
op_dist_attr.set_input_dist_attr(intermediate_var_0.name, op_dist_attr.set_input_dist_attr(
intermediate_var_0_dist_attr) intermediate_var_0.name, intermediate_var_0_dist_attr
)
new_op_dist_attr = OperatorDistributedAttribute() new_op_dist_attr = OperatorDistributedAttribute()
new_op_dist_attr.process_mesh = Ids_var_dist_attr.process_mesh new_op_dist_attr.process_mesh = Ids_var_dist_attr.process_mesh
new_op_dist_attr.impl_type = "default" new_op_dist_attr.impl_type = "default"
new_op_dist_attr.impl_idx = 0 new_op_dist_attr.impl_idx = 0
new_op_dist_attr.set_input_dims_mapping(Ids_var.name, new_op_dist_attr.set_input_dims_mapping(
Ids_var_dist_attr.dims_mapping) Ids_var.name, Ids_var_dist_attr.dims_mapping
new_op_dist_attr.set_output_dims_mapping(intermediate_var_0.name, )
Ids_var_dist_attr.dims_mapping) new_op_dist_attr.set_output_dims_mapping(
intermediate_var_0.name, Ids_var_dist_attr.dims_mapping
)
new_op_dist_attr.set_output_dims_mapping( new_op_dist_attr.set_output_dims_mapping(
xshape_var.name, [-1] + list(Ids_var_dist_attr.dims_mapping)) xshape_var.name, [-1] + list(Ids_var_dist_attr.dims_mapping)
)
ctx.set_op_dist_attr_for_program(reshape_op, new_op_dist_attr) ctx.set_op_dist_attr_for_program(reshape_op, new_op_dist_attr)
return intermediate_var_0 return intermediate_var_0
...@@ -125,7 +165,6 @@ def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var): ...@@ -125,7 +165,6 @@ def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var):
# RowParallel # RowParallel
class DistributedEmbeddingImpl(DistributedOperatorImpl): class DistributedEmbeddingImpl(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedEmbeddingImpl, self).__init__(name) super(DistributedEmbeddingImpl, self).__init__(name)
self._forward_implemented = True self._forward_implemented = True
...@@ -143,17 +182,19 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -143,17 +182,19 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
def calc_fwd_cost(self, dist_op, ctx, cluster): def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost # calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes processes = dist_op.dist_attr.process_mesh.processes
# embedding need start_index # embedding need start_index
cost_mapping = build_comp_costs_from_descs(EmbeddingOpCost, ctx, cost_mapping = build_comp_costs_from_descs(
processes, desc_mapping, EmbeddingOpCost, ctx, processes, desc_mapping, cluster
cluster) )
serial_op = dist_op.serial_op serial_op = dist_op.serial_op
parallel_axis = dist_op.dist_attr.get_input_dims_mapping( parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("W")[0])[0] serial_op.input("W")[0]
)[0]
attrs = {"use_calc_stream": True, "use_model_parallel": True} attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.output("Out") var_names = serial_op.output("Out")
c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op( c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
...@@ -162,11 +203,16 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -162,11 +203,16 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
ctx, ctx,
var_names, var_names,
attrs=attrs, attrs=attrs,
parallel_axis=parallel_axis) parallel_axis=parallel_axis,
)
comm_op_cost_list = build_comm_costs_from_descs( comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping, AllreduceSumOpCost,
cluster) ctx,
processes,
c_allreduce_sum_desc_mapping,
cluster,
)
res_cost = [cost_mapping, comm_op_cost_list] res_cost = [cost_mapping, comm_op_cost_list]
...@@ -180,7 +226,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -180,7 +226,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
dist_attr = dist_op.dist_attr dist_attr = dist_op.dist_attr
embedding_row_dim_mapping = dist_attr.get_input_dims_mapping( embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("W")[0])[0] backward_op.input("W")[0]
)[0]
parallel_axis = embedding_row_dim_mapping parallel_axis = embedding_row_dim_mapping
attrs = {"use_calc_stream": True, "use_model_parallel": True} attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = [backward_op.input("Out@GRAD")[0]] var_names = [backward_op.input("Out@GRAD")[0]]
...@@ -190,33 +237,38 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -190,33 +237,38 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
ctx, ctx,
var_names, var_names,
attrs=attrs, attrs=attrs,
parallel_axis=parallel_axis) parallel_axis=parallel_axis,
)
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
processes = process_mesh.processes processes = process_mesh.processes
comm_op_cost_list = build_comm_costs_from_descs( comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster) IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
)
res.append(comm_op_cost_list) res.append(comm_op_cost_list)
# calc comp op cost # calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, desc_mapping = build_comp_desc_from_dist_op(
dist_context=ctx) dist_op=dist_op, dist_context=ctx
cost_mapping = build_comp_costs_from_descs(EmbeddingGradOpCost, ctx, )
processes, desc_mapping, cost_mapping = build_comp_costs_from_descs(
cluster) EmbeddingGradOpCost, ctx, processes, desc_mapping, cluster
)
res.append(cost_mapping) res.append(cost_mapping)
# need gradient allreduce # need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping( var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Ids")[0]) backward_op.input("Ids")[0]
)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True} attrs = {"use_calc_stream": True}
var_names = [backward_op.output('W@GRAD')[0]] var_names = [backward_op.output('W@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis, build_dp_costs(
cluster) res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
)
return res return res
...@@ -228,7 +280,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -228,7 +280,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name) ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name) w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard( if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard(
w_dims_mapping[-1]): w_dims_mapping[-1]
):
return False return False
# Other dimensions must be replicate except the batch dimension # Other dimensions must be replicate except the batch dimension
for mapping in ids_dims_mapping[1:]: for mapping in ids_dims_mapping[1:]:
...@@ -248,8 +301,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -248,8 +301,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
return True return True
def is_auto_compatible(self, dist_op): def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \ if (not self.is_input_compatible(dist_op)) or (
(not self.is_output_compatible(dist_op)): not self.is_output_compatible(dist_op)
):
return False return False
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
...@@ -261,7 +315,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -261,7 +315,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name) ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name) w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
if ids_dims_mapping != out_dims_mapping[:len(ids_dims_mapping)]: if ids_dims_mapping != out_dims_mapping[: len(ids_dims_mapping)]:
return False return False
return True return True
...@@ -279,12 +333,14 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -279,12 +333,14 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
for i in range(len(ids_dims_mapping)): for i in range(len(ids_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping( dim_changed = compute_compatible_and_update_dim_mapping(
[ids_dims_mapping, out_dims_mapping], [i, i]) [ids_dims_mapping, out_dims_mapping], [i, i]
)
if dim_changed: if dim_changed:
changed = True changed = True
dim_changed = compute_compatible_and_update_dim_mapping( dim_changed = compute_compatible_and_update_dim_mapping(
[w_dims_mapping, out_dims_mapping], [-1, -1]) [w_dims_mapping, out_dims_mapping], [-1, -1]
)
if dim_changed: if dim_changed:
changed = True changed = True
...@@ -302,26 +358,30 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -302,26 +358,30 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
src_op = dist_op_context.cur_src_op src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert (
str(src_op)) op_dist_attr is not None
), "forward op [{}] don't have dist attribute !".format(str(src_op))
# check validation of inputs / outputs # check validation of inputs / outputs
assert 'Ids' in kwargs, "input [{}] is not given".format('Ids') assert 'Ids' in kwargs, "input [{}] is not given".format('Ids')
assert 'W' in kwargs, "input [{}] is not given".format('W') assert 'W' in kwargs, "input [{}] is not given".format('W')
assert 'Out' in kwargs, "output [{}] is not given".format('Out') assert 'Out' in kwargs, "output [{}] is not given".format('Out')
assert len( assert (
len(kwargs['Ids']) == 1
), "row_parallel_embedding input Ids take 1 variable but got {}".format(
kwargs['Ids'] kwargs['Ids']
) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( )
kwargs['Ids']) assert (
assert len( len(kwargs['W']) == 1
), "row_parallel_embedding input W take 1 variable but got {}".format(
kwargs['W'] kwargs['W']
) == 1, "row_parallel_embedding input W take 1 variable but got {}".format( )
kwargs['W']) assert (
assert len( len(kwargs['Out']) == 1
), "row_parallel_embedding output Out take 1 variable but got {}".format(
kwargs['Out'] kwargs['Out']
) == 1, "row_parallel_embedding output Out take 1 variable but got {}".format( )
kwargs['Out'])
Ids_var = main_block.var(kwargs['Ids'][0]) Ids_var = main_block.var(kwargs['Ids'][0])
Weight_var = main_block._var_recursive(kwargs['W'][0]) Weight_var = main_block._var_recursive(kwargs['W'][0])
...@@ -333,70 +393,85 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -333,70 +393,85 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
# got dist attribute info # got dist attribute info
embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping( embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[0] Weight_var.name
assert embedding_row_dim_mapping >= 0, "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format( )[0]
embedding_row_dim_mapping) assert (
embedding_row_dim_mapping >= 0
), "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format(
embedding_row_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes process_mesh_group = op_dist_attr.process_mesh.processes
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh_group: if rank_id not in process_mesh_group:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh, rank_id = _get_corresponding_rank(
rank_id) ctx, op_dist_attr.process_mesh, rank_id
)
# A generalized method to calculate the embedding offset using a cartesian product # A generalized method to calculate the embedding offset using a cartesian product
relative_idx = _get_idx_in_axis(process_mesh_group, process_mesh_shape, relative_idx = _get_idx_in_axis(
embedding_row_dim_mapping, rank_id) process_mesh_group,
process_mesh_shape,
embedding_row_dim_mapping,
rank_id,
)
per_part_size = Weight_var.shape[0] per_part_size = Weight_var.shape[0]
relative_idx = relative_idx * per_part_size relative_idx = relative_idx * per_part_size
# TODO calculate ring id # TODO calculate ring id
parallel_axis = embedding_row_dim_mapping parallel_axis = embedding_row_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, group_ranks = _get_comm_group(
parallel_axis, rank_id) process_mesh_group, process_mesh_shape, parallel_axis, rank_id
)
group = new_process_group(group_ranks) group = new_process_group(group_ranks)
# append op # append op
check_variable_and_dtype(Ids_var, 'input', ['int32', 'int64'], check_variable_and_dtype(
'c_embedding') Ids_var, 'input', ['int32', 'int64'], 'c_embedding'
)
# infer new var shape with op dist attr # infer new var shape with op dist attr
out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var) out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
assert out_tensor_dist_attr is not None assert out_tensor_dist_attr is not None
out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert out_var_dist_attr is not None assert out_var_dist_attr is not None
ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr, ref_shape = infer_shape(
out_var_dist_attr) main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
)
intermediate_var_0 = main_block.create_var( intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join( name=unique_name.generate_with_ignorable_key(
["c_embedding", 'tmp'])), ".".join(["c_embedding", 'tmp'])
),
dtype=Weight_var.dtype, dtype=Weight_var.dtype,
shape=Out_var.shape, shape=Out_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False, persistable=False,
stop_gradient=Out_var.stop_gradient) stop_gradient=Out_var.stop_gradient,
)
# set intermediate_var_0's dist_attr with Out_var's dist_attr # set intermediate_var_0's dist_attr with Out_var's dist_attr
ctx.set_tensor_dist_attr_for_program(intermediate_var_0, ctx.set_tensor_dist_attr_for_program(
out_var_dist_attr) intermediate_var_0, out_var_dist_attr
)
check_variable_and_dtype( check_variable_and_dtype(
Out_var, 'tensor', Out_var,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], ['float16', 'float32', 'float64', 'int32', 'int64'],
'c_allreduce_sum') 'c_allreduce_sum',
)
c_embedding_op = main_block.append_op( c_embedding_op = main_block.append_op(
type='c_embedding', type='c_embedding',
inputs={ inputs={'Ids': [Ids_var], 'W': [Weight_var]},
'Ids': [Ids_var],
'W': [Weight_var]
},
outputs={'Out': [intermediate_var_0]}, outputs={'Out': [intermediate_var_0]},
attrs={ attrs={
"start_index": relative_idx, "start_index": relative_idx,
OP_ROLE_KEY: src_op.attr('op_role') OP_ROLE_KEY: src_op.attr('op_role'),
}) },
)
if intermediate_var_0.shape != ref_shape: if intermediate_var_0.shape != ref_shape:
intermediate_var_0.desc.set_shape(ref_shape) intermediate_var_0.desc.set_shape(ref_shape)
...@@ -409,8 +484,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -409,8 +484,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
'ring_id': group.id, 'ring_id': group.id,
'use_calc_stream': True, 'use_calc_stream': True,
'use_model_parallel': True, 'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role') OP_ROLE_KEY: src_op.attr('op_role'),
}) },
)
if Out_var.shape != ref_shape: if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape) Out_var.desc.set_shape(ref_shape)
...@@ -423,15 +499,19 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -423,15 +499,19 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
for input_varname in c_embedding_op.desc.input_arg_names(): for input_varname in c_embedding_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, "dist_attr is {}".format( assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
embedding_op_dist_attr.set_input_dist_attr(input_varname, )
input_dist_attr) embedding_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
output_varname = c_embedding_op.desc.output_arg_names()[0] output_varname = c_embedding_op.desc.output_arg_names()[0]
output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert output_dist_attr is not None, "dist_attr is {}".format( assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
embedding_op_dist_attr.set_output_dist_attr(output_varname, )
output_dist_attr) embedding_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(c_embedding_op, embedding_op_dist_attr) ctx.set_op_dist_attr_for_program(c_embedding_op, embedding_op_dist_attr)
# allreduce # allreduce
...@@ -443,16 +523,20 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -443,16 +523,20 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
input_var = main_block.var(input_varname) input_var = main_block.var(input_varname)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var) tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
assert tensor_dist_attr is not None assert tensor_dist_attr is not None
allreduce_op_dist_attr.set_input_dist_attr(input_varname, allreduce_op_dist_attr.set_input_dist_attr(
tensor_dist_attr) input_varname, tensor_dist_attr
)
for output_varname in c_allreduce_sum_op.desc.output_arg_names(): for output_varname in c_allreduce_sum_op.desc.output_arg_names():
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, "dist_attr is {}".format( assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr) op_dist_attr
allreduce_op_dist_attr.set_output_dist_attr(output_varname, )
output_dist_attr) allreduce_op_dist_attr.set_output_dist_attr(
ctx.set_op_dist_attr_for_program(c_allreduce_sum_op, output_varname, output_dist_attr
allreduce_op_dist_attr) )
ctx.set_op_dist_attr_for_program(
c_allreduce_sum_op, allreduce_op_dist_attr
)
# param initialization sync # param initialization sync
if Weight_var.is_parameter and not op_dist_attr.is_recompute: if Weight_var.is_parameter and not op_dist_attr.is_recompute:
...@@ -469,20 +553,25 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -469,20 +553,25 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
if size <= 1 or axis in dim_mapping: if size <= 1 or axis in dim_mapping:
pass pass
else: else:
group_ranks = _get_comm_group(process_mesh.processes, group_ranks = _get_comm_group(
process_mesh.topology, axis, process_mesh.processes,
rank_id) process_mesh.topology,
axis,
rank_id,
)
sync_group = new_process_group(group_ranks) sync_group = new_process_group(group_ranks)
startup_block.append_op(type='c_broadcast', startup_block.append_op(
inputs={'X': param}, type='c_broadcast',
outputs={'Out': param}, inputs={'X': param},
attrs={ outputs={'Out': param},
'ring_id': sync_group.id, attrs={
'root': 0, 'ring_id': sync_group.id,
'use_calc_stream': True, 'root': 0,
OP_ROLE_KEY: OpRole.Forward 'use_calc_stream': True,
}) OP_ROLE_KEY: OpRole.Forward,
},
)
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
...@@ -493,35 +582,43 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -493,35 +582,43 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
backward_op = dist_op_context.cur_src_op backward_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id rank_id = dist_op_context.rank_id
dist_attr = ctx.get_op_dist_attr_for_program(backward_op) dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert (
str(backward_op)) dist_attr is not None
), "backward op [{}] don't have dist attribute !".format(
str(backward_op)
)
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in dist_attr.process_mesh.processes: if rank_id not in dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id = _get_corresponding_rank(
rank_id) ctx, dist_attr.process_mesh, rank_id
)
assert 'Ids' in kwargs, "input [{}] is not given".format('Ids') assert 'Ids' in kwargs, "input [{}] is not given".format('Ids')
assert 'W' in kwargs, "input [{}] is not given".format('W') assert 'W' in kwargs, "input [{}] is not given".format('W')
assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out') assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out')
assert 'W@GRAD' in kwargs, "output [{}] is not given".format('W@GRAD') assert 'W@GRAD' in kwargs, "output [{}] is not given".format('W@GRAD')
assert len( assert (
len(kwargs['Ids']) == 1
), "row_parallel_embedding input Ids take 1 variable but got {}".format(
kwargs['Ids'] kwargs['Ids']
) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( )
kwargs['Ids']) assert (
assert len( len(kwargs['W']) == 1
), "row_parallel_embedding input Ids take 1 variable but got {}".format(
kwargs['W'] kwargs['W']
) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( )
kwargs['W']) assert (
assert len( len(kwargs['Out@GRAD']) == 1
kwargs['Out@GRAD'] ), "row_parallel_embedding input Ids take 1 variable but got {}".format(
) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( kwargs['Out']
kwargs['Out']) )
assert len( assert (
len(kwargs['W@GRAD']) == 1
), "row_parallel_embedding output Ids take 1 variable but got {}".format(
kwargs['W@GRAD'] kwargs['W@GRAD']
) == 1, "row_parallel_embedding output Ids take 1 variable but got {}".format( )
kwargs['W@GRAD'])
Ids_var = main_block.var(kwargs['Ids'][0]) Ids_var = main_block.var(kwargs['Ids'][0])
Weight_var = main_block.var(kwargs['W'][0]) Weight_var = main_block.var(kwargs['W'][0])
...@@ -529,39 +626,57 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -529,39 +626,57 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
Weight_grad = main_block.var(kwargs['W@GRAD'][0]) Weight_grad = main_block.var(kwargs['W@GRAD'][0])
embedding_row_dim_mapping = dist_attr.get_input_dims_mapping( embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
Weight_var.name)[0] Weight_var.name
assert embedding_row_dim_mapping >= 0, "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format( )[0]
embedding_row_dim_mapping) assert (
embedding_row_dim_mapping >= 0
), "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format(
embedding_row_dim_mapping
)
process_mesh_shape = dist_attr.process_mesh.topology process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_group = dist_attr.process_mesh.processes process_mesh_group = dist_attr.process_mesh.processes
# A generalized method to calculate the embedding offset using a cartesian product # A generalized method to calculate the embedding offset using a cartesian product
relative_idx = _get_idx_in_axis(process_mesh_group, process_mesh_shape, relative_idx = _get_idx_in_axis(
embedding_row_dim_mapping, rank_id) process_mesh_group,
process_mesh_shape,
embedding_row_dim_mapping,
rank_id,
)
per_part_size = Weight_var.shape[0] per_part_size = Weight_var.shape[0]
relative_idx = relative_idx * per_part_size relative_idx = relative_idx * per_part_size
check_variable_and_dtype( check_variable_and_dtype(
Out_grad, 'tensor', Out_grad,
['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity') 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_identity',
)
intermediate_var_0 = main_block.create_var( intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join( name=unique_name.generate_with_ignorable_key(
["c_embedding", '@tmp_0@GRAD'])), ".".join(["c_embedding", '@tmp_0@GRAD'])
),
dtype=Out_grad.dtype, dtype=Out_grad.dtype,
shape=Out_grad.shape, shape=Out_grad.shape,
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False, persistable=False,
stop_gradient=Out_grad.stop_gradient) stop_gradient=Out_grad.stop_gradient,
)
# copy X_var's dist_attr to intermediate_var_0's dist_attr # copy X_var's dist_attr to intermediate_var_0's dist_attr
out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name) out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name)
assert out_grad_dist_attr is not None assert out_grad_dist_attr is not None
ctx.set_tensor_dist_attr_for_program(intermediate_var_0, ctx.set_tensor_dist_attr_for_program(
out_grad_dist_attr) intermediate_var_0, out_grad_dist_attr
)
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, group_ranks = _get_comm_group(
embedding_row_dim_mapping, rank_id) process_mesh_group,
process_mesh_shape,
embedding_row_dim_mapping,
rank_id,
)
group = new_process_group(group_ranks) group = new_process_group(group_ranks)
c_identity_op = main_block.append_op( c_identity_op = main_block.append_op(
...@@ -573,41 +688,54 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -573,41 +688,54 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
'use_calc_stream': True, 'use_calc_stream': True,
'use_model_parallel': True, 'use_model_parallel': True,
OP_ROLE_KEY: OpRole.Backward, OP_ROLE_KEY: OpRole.Backward,
}) },
check_variable_and_dtype(intermediate_var_0, 'x', )
['float16', 'float32', 'float64'], 'linear') check_variable_and_dtype(
check_dtype(intermediate_var_0.dtype, 'dtype', intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
['float16', 'float32', 'float64'], 'linear') )
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64'],
'linear',
)
set_comm_op_dist_attr_for_program(c_identity_op, dist_attr.process_mesh, set_comm_op_dist_attr_for_program(
out_grad_dist_attr, ctx) c_identity_op, dist_attr.process_mesh, out_grad_dist_attr, ctx
)
c_embedding_grad_op_desc = main_block.append_op(type='nop').desc c_embedding_grad_op_desc = main_block.append_op(type='nop').desc
c_embedding_grad_op_desc.set_type("c_embedding_grad") c_embedding_grad_op_desc.set_type("c_embedding_grad")
c_embedding_grad_op_desc.set_input('Ids', [Ids_var.name]) c_embedding_grad_op_desc.set_input('Ids', [Ids_var.name])
c_embedding_grad_op_desc.set_input('W', [Weight_var.name]) c_embedding_grad_op_desc.set_input('W', [Weight_var.name])
c_embedding_grad_op_desc.set_input('Out@GRAD', c_embedding_grad_op_desc.set_input(
[intermediate_var_0.name]) 'Out@GRAD', [intermediate_var_0.name]
)
c_embedding_grad_op_desc.set_output('W@GRAD', [Weight_grad.name]) c_embedding_grad_op_desc.set_output('W@GRAD', [Weight_grad.name])
c_embedding_grad_op_desc._set_attr('start_index', relative_idx) c_embedding_grad_op_desc._set_attr('start_index', relative_idx)
c_embedding_grad_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward) c_embedding_grad_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward)
c_embedding_grad_op = main_block.ops[-1] c_embedding_grad_op = main_block.ops[-1]
assert c_embedding_grad_op.type == "c_embedding_grad" assert c_embedding_grad_op.type == "c_embedding_grad"
naive_copy_op_dist_attr_for_program(c_embedding_grad_op, backward_op, naive_copy_op_dist_attr_for_program(
ctx) c_embedding_grad_op, backward_op, ctx
)
# data parallel gradient synchronization # data parallel gradient synchronization
act_grad_names = [Ids_var.name] act_grad_names = [Ids_var.name]
out_grad_names = [kwargs['W@GRAD'][0]] out_grad_names = [kwargs['W@GRAD'][0]]
gradient_synchronization(ctx, backward_op, act_grad_names, gradient_synchronization(
out_grad_names, rank_id) ctx, backward_op, act_grad_names, out_grad_names, rank_id
)
register_distributed_operator_impl("lookup_table_v2", register_distributed_operator_impl(
DistributedEmbeddingImpl("row_parallel")) "lookup_table_v2", DistributedEmbeddingImpl("row_parallel")
register_distributed_operator_impl("c_embedding", )
DistributedEmbeddingImpl("row_parallel")) register_distributed_operator_impl(
register_distributed_operator_impl("lookup_table", "c_embedding", DistributedEmbeddingImpl("row_parallel")
DistributedEmbeddingImpl("row_parallel")) )
register_distributed_operator_impl(
"lookup_table", DistributedEmbeddingImpl("row_parallel")
)
@@ -25,6 +25,7 @@ from .reshard import Resharder
 from .partitioner import Partitioner
 from .utils import set_grad_var_shape
 from .process_group import get_world_process_group
+from .random import init_auto_parallel_rng
 from ..utils.log_utils import get_logger

@@ -84,6 +85,9 @@ class Parallelizer:
             ) = partitioner.partition(
                 serial_main_program, serial_startup_program, params_grads
             )
+
+            init_auto_parallel_rng()
+
             self._logger.debug(
                 "within parallel partitioner time: {}, mode {}".format(
                     time.time() - time0, self._mode
...
@@ -19,6 +19,8 @@ import paddle
 # Use to store the previous and current process mesh
 _g_previous_process_mesh = None
 _g_current_process_mesh = None
+# {shape_process_ids : unique_id}
+_g_unique_process_mesh_map = {}

 def get_current_process_mesh():

@@ -39,6 +41,30 @@ def reset_current_process_mesh():
     _g_current_process_mesh = _g_previous_process_mesh
def get_unique_id_for_process_mesh(shape, process_ids):
key = f"shape {shape}, process_ids {process_ids}"
global _g_unique_process_mesh_map
if key in _g_unique_process_mesh_map:
unique_id = _g_unique_process_mesh_map[key]
else:
unique_id = len(_g_unique_process_mesh_map) + 1
_g_unique_process_mesh_map[key] = unique_id
return unique_id
def retrive_unique_id_for_process_mesh(shape, process_ids):
key = f"shape {shape}, process_ids {process_ids}"
global _g_unique_process_mesh_map
assert key in _g_unique_process_mesh_map
return _g_unique_process_mesh_map[key]
def get_unique_process_mesh_map():
global _g_unique_process_mesh_map
return _g_unique_process_mesh_map
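The three helpers above keep a process-global registry so that meshes with the same (shape, process_ids) resolve to the same integer id, which determinate_rng later folds into the seed. A minimal standalone sketch of the registry behaviour (plain dict logic mirroring the code above, not a call into Paddle):

_registry = {}

def unique_id_for(shape, process_ids):
    # Same key scheme as get_unique_id_for_process_mesh: ids start at 1 and
    # repeated (shape, process_ids) pairs reuse the id assigned on first sight.
    key = f"shape {shape}, process_ids {process_ids}"
    if key not in _registry:
        _registry[key] = len(_registry) + 1
    return _registry[key]

assert unique_id_for([2, 4], list(range(8))) == 1   # first mesh registered
assert unique_id_for([4], [0, 1, 2, 3]) == 2        # a different mesh gets a new id
assert unique_id_for([2, 4], list(range(8))) == 1   # same shape/ranks -> same id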
 class ProcessMesh(object):
     """
     The `ProcessMesh` object describes the topology of the used processes.
@@ -113,6 +139,11 @@ class ProcessMesh(object):
         pg0 = get_process_group(0)
         pg0.add_ranks(self.processes)
+
+        # Unique mesh id
+        self._unique_id = get_unique_id_for_process_mesh(
+            self._shape, self._process_ids
+        )
    @property
    def shape(self):
        """
@@ -148,6 +179,16 @@ class ProcessMesh(object):
        """
        return self._mesh
    @property
    def unique_id(self):
        """
        Get the unique id of the ProcessMesh.
        NOTE: the unique id only takes process_ids and shape into account.
        Different ProcessMesh objects with the same process_ids and shape share the same unique id.
        """
        return self._unique_id
    @property
    def topology(self):
        return self._shape
...
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import paddle
from ..utils.log_utils import get_logger
from .process_mesh import retrive_unique_id_for_process_mesh
from .utils import _get_idx_in_axis
_logger = get_logger(logging.INFO)
_rng_name_to_seed = {}
_inited_rng_name_to_seed = {}
_enable_random_control = False
_basic_seed = 42
# use prime numbers as offsets to avoid conflicts
_mesh_offset = 173
_dim_offsets = [11, 23, 37, 73]
def is_enable_auto_rand_ctrl():
global _enable_random_control
return _enable_random_control
def enable_auto_rand_ctrl():
global _enable_random_control
_enable_random_control = True
def parallel_manual_seed(seed):
"""Enable auto parallel random control.
Random control maintain the randomness when tensor is distributed across devices on a Mesh(any order).
* Independency: If tensor is **Sharded** on a Mesh dimension, Devices along that Mesh dimension should have Different randomness.
* Consistency: Meanwhile if the tensor is **Replicated** on another Mesh dimension, randomness of Devices along that Mesh dimension should be Consistent.
For instance: rank0 ~ rank7 consist a Mesh of shape of [2, 4]; A 2D tensor is distributed in that Mesh using dims_mapping [-1, 1].
Randomness for rank0-rank1-rank2-rank3 (rank4-rank5-rank6-rank7) should be Independent;
Randomness for rank0 and rank4 (rank1 and rank5, ...) should be Consistent.
This function should be called only once before auto parallel compiles the computation graph (e.g. auto_parallel.engine.prepare() or fit()).
This seed only affects how randomness-relative **operators** (dropout, fuse op with dropout inside, etc) are execute amonge mesh, and would NOT affect other processe like Parameter initialization.
Examples:
# seed relative to training step
auto_parallel_random_seed((step + 13) * 257)
...
engine.prepare()
"""
enable_auto_rand_ctrl()
global _basic_seed
_basic_seed = seed
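A hedged usage sketch, assuming the package-level import this PR adds in auto_parallel/__init__.py; the engine setup lines are placeholders, not part of the commit:

from paddle.distributed.auto_parallel import parallel_manual_seed

# Call once, before the auto-parallel engine compiles the graph
# (i.e. before engine.prepare() / engine.fit()); parameter initialization is
# unaffected, only dropout-like ops pick up the controlled rng streams.
step = 0
parallel_manual_seed((step + 13) * 257)

# engine = auto.Engine(model, loss, optimizer, metrics, strategy=strategy)  # hypothetical setup
# engine.prepare(...)  # rng names and seeds are determined during partitioning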
def determinate_rng(rank, dims_mapping, process_mesh):
# TODO(JZ-LIANG) support meshes of arbitrarily high rank:
# use a string-to-unique-integer hashing algorithm for seed computation
# instead of using offsets to coordinate seeds across devices.
if len(process_mesh.shape) > 4:
raise NotImplementedError(
"Auto Parallel Random Control for Mesh's rank > 4 is NOT supported! Got {}".format(
str(process_mesh)
)
)
global _basic_seed
seed_ = _basic_seed
# FIXME
# unique_id = process_mesh.unique_id
unique_id = retrive_unique_id_for_process_mesh(
process_mesh.shape, process_mesh.process_ids
)
sharding_expr = f'mesh:{unique_id}'
seed_ += _mesh_offset * (unique_id + 1)
for i in range(len(process_mesh.shape)):
if i not in dims_mapping:
relative_idx = -1
else:
relative_idx = _get_idx_in_axis(
process_mesh.process_ids,
process_mesh.shape,
i,
rank,
)
sharding_expr += f"_dim{i}:{relative_idx}"
seed_ += _dim_offsets[i] * (relative_idx + 1)
global _rng_name_to_seed
if sharding_expr in _rng_name_to_seed:
assert _rng_name_to_seed[sharding_expr] == seed_
else:
assert (
seed_ not in _rng_name_to_seed.values()
), "Seed Confilt! current seed: {}, current sharding expr: {}, generated seed: {}".format(
seed_, sharding_expr, _rng_name_to_seed
)
_rng_name_to_seed[sharding_expr] = seed_
return sharding_expr
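To make the offset arithmetic above concrete, here is a standalone re-implementation applied to the docstring's example (a [2, 4] mesh, dims_mapping [-1, 1], and a mesh unique id assumed to be 1); it mirrors the constants defined at the top of this file and is an illustration, not a call into Paddle:

BASIC_SEED, MESH_OFFSET, DIM_OFFSETS = 42, 173, [11, 23, 37, 73]

def rng_name_and_seed(rank, dims_mapping, mesh_shape, process_ids, unique_id=1):
    # Coordinate of `rank` along each mesh axis (row-major layout of process_ids).
    remaining = process_ids.index(rank)
    coords = []
    for size in reversed(mesh_shape):
        coords.append(remaining % size)
        remaining //= size
    coords.reverse()

    name = f"mesh:{unique_id}"
    seed = BASIC_SEED + MESH_OFFSET * (unique_id + 1)
    for axis in range(len(mesh_shape)):
        relative = coords[axis] if axis in dims_mapping else -1
        name += f"_dim{axis}:{relative}"
        seed += DIM_OFFSETS[axis] * (relative + 1)
    return name, seed

# [2, 4] mesh, tensor sharded along mesh axis 1, replicated along axis 0.
for rank in range(8):
    print(rank, rng_name_and_seed(rank, [-1, 1], [2, 4], list(range(8))))
# ranks 0 and 4 -> ('mesh:1_dim0:-1_dim1:0', 411): same stream (consistency)
# ranks 0..3   -> seeds 411, 434, 457, 480: distinct streams (independence)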
def init_auto_parallel_rng():
if not is_enable_auto_rand_ctrl():
return
global _rng_name_to_seed
# NOTE: init may be called multiple times; avoid initializing the same rng twice
global _inited_rng_name_to_seed
for rng_name, seed in _rng_name_to_seed.items():
if rng_name in _inited_rng_name_to_seed:
assert _inited_rng_name_to_seed[rng_name] == seed
else:
_logger.info(
f"Init Auto Parallel RNG: {rng_name}, with seed {seed}"
)
paddle.framework.random.set_random_seed_generator(rng_name, seed)
_inited_rng_name_to_seed[rng_name] = seed
@@ -22,13 +22,20 @@ from paddle.fluid.framework import Variable, Operator
 from paddle.fluid.backward import _append_grad_suffix_, _get_no_grad_set_name
 from paddle.fluid.backward import ProgramStats, _rename_arg_, _find_op_path_
 from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
-from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
-from paddle.distributed.auto_parallel.utils import get_loss_op, set_var_dist_attr, set_dist_op_desc_original_id
-from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping
+from paddle.distributed.auto_parallel.dist_attribute import (
+    OperatorDistributedAttribute,
+)
+from paddle.distributed.auto_parallel.utils import (
+    get_loss_op,
+    set_var_dist_attr,
+    set_dist_op_desc_original_id,
+)
+from paddle.distributed.auto_parallel.utils import (
+    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
+)

 class RecomputeState(ProgramStats):
     def __init__(self, block, ops):
         super(RecomputeState, self).__init__(block=block, ops=ops)
         self._block = block
@@ -54,7 +61,7 @@ class RecomputeState(ProgramStats):
             self.var_op_deps[name]["var_as_output_ops"] = [i]

     def get_recompute_segments(self, checkpoints):
-        """ get recompute segments from checkpoints """
+        """get recompute segments from checkpoints"""
         segments = []
         start_idx = -1
         pre_segment_end_idx = -1
...@@ -69,33 +76,43 @@ class RecomputeState(ProgramStats): ...@@ -69,33 +76,43 @@ class RecomputeState(ProgramStats):
segments.append([0, max(op_idx_list) + 1]) segments.append([0, max(op_idx_list) + 1])
else: else:
flag, min_idx, max_idx = self.is_subgraph( flag, min_idx, max_idx = self.is_subgraph(
[checkpoints[start_idx]], [checkpoints[start_idx + 1]]) [checkpoints[start_idx]], [checkpoints[start_idx + 1]]
)
if flag: if flag:
min_idx = self._update_segment_start( min_idx = self._update_segment_start(
min_idx, pre_segment_end_idx) min_idx, pre_segment_end_idx
)
segments.append([min_idx, max_idx + 1]) segments.append([min_idx, max_idx + 1])
else: else:
logging.info( logging.info(
"Could not recompute op range [{}] - [{}] ".format( "Could not recompute op range [{}] - [{}] ".format(
min_idx, max_idx + 1)) min_idx, max_idx + 1
)
)
start_idx += 1 start_idx += 1
for i, (idx1, idx2) in enumerate(segments): for i, (idx1, idx2) in enumerate(segments):
logging.info("recompute segment[{}]".format(i)) logging.info("recompute segment[{}]".format(i))
logging.info("segment start op: [{}]: [{}] [{}]".format( logging.info(
self._ops[idx1].desc.type(), "segment start op: [{}]: [{}] [{}]".format(
self._ops[idx1].desc.input_arg_names(), self._ops[idx1].desc.type(),
self._ops[idx1].desc.output_arg_names())) self._ops[idx1].desc.input_arg_names(),
logging.info("segment end op: [{}]: [{}] [{}]".format( self._ops[idx1].desc.output_arg_names(),
self._ops[idx2 - 1].desc.type(), )
self._ops[idx2 - 1].desc.input_arg_names(), )
self._ops[idx2 - 1].desc.output_arg_names())) logging.info(
"segment end op: [{}]: [{}] [{}]".format(
self._ops[idx2 - 1].desc.type(),
self._ops[idx2 - 1].desc.input_arg_names(),
self._ops[idx2 - 1].desc.output_arg_names(),
)
)
return segments return segments
def modify_forward_desc_for_recompute(self, dist_context): def modify_forward_desc_for_recompute(self, dist_context):
""" """
        If program's forward part has a 'dropout' op, this function will insert If program's forward part has a 'dropout' op, this function will insert
        a seed op before it to guarantee that the two dropout ops have the same outputs. a seed op before it to guarantee that the two dropout ops have the same outputs.
""" """
op_types = [op.desc.type() for op in self._ops] op_types = [op.desc.type() for op in self._ops]
...@@ -116,45 +133,52 @@ class RecomputeState(ProgramStats): ...@@ -116,45 +133,52 @@ class RecomputeState(ProgramStats):
cur_op_dist_attr = dist_context.get_op_dist_attr_for_program(cur_op) cur_op_dist_attr = dist_context.get_op_dist_attr_for_program(cur_op)
            # insert seed op to guarantee that two dropout ops have the same outputs # insert seed op to guarantee that two dropout ops have the same outputs
            op_unique_name = unique_name.generate("seed") # NOTE Hack to adapt recompute for random control; for more info see dist_dropout.py
            var_unique_name = unique_name.generate_with_ignorable_key(".".join( # a new seed added by recompute should have a prefix to distinguish it from seeds added by the user or other modules.
[op_unique_name, 'tmp'])) op_unique_name = unique_name.generate("rc_seed")
var_unique_name = unique_name.generate_with_ignorable_key(
".".join([op_unique_name, 'tmp'])
)
seed_var = self._block.create_var( seed_var = self._block.create_var(
name=var_unique_name, name=var_unique_name,
dtype='int32', dtype='int32',
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False, persistable=False,
stop_gradient=False) stop_gradient=False,
)
# set new seed_var's dist_attr # set new seed_var's dist_attr
ref_dims_mapping = [-1] ref_dims_mapping = [-1]
ref_process_mesh = cur_op_dist_attr.process_mesh ref_process_mesh = cur_op_dist_attr.process_mesh
seed_var_dist_attr = set_var_dist_attr(dist_context, seed_var, seed_var_dist_attr = set_var_dist_attr(
ref_dims_mapping, dist_context, seed_var, ref_dims_mapping, ref_process_mesh
ref_process_mesh) )
seed = 0 if cur_op.attr("fix_seed") is False else int( seed = (
cur_op.attr("seed")) 0
if cur_op.attr("fix_seed") is False
else int(cur_op.attr("seed"))
)
seed_op = self._block._insert_op_without_sync( seed_op = self._block._insert_op_without_sync(
index=cur_op.idx, index=cur_op.idx,
type="seed", type="seed",
inputs={}, inputs={},
outputs={"Out": seed_var}, outputs={"Out": seed_var},
attrs={ attrs={"seed": seed, "force_cpu": True},
"seed": seed, )
"force_cpu": True
})
# set new seed op's dist_attr # set new seed op's dist_attr
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
seed_op, ref_process_mesh, ref_dims_mapping, dist_context) seed_op, ref_process_mesh, ref_dims_mapping, dist_context
)
# modify dropout op's desc # modify dropout op's desc
self._ops.insert(op_idx, seed_op) self._ops.insert(op_idx, seed_op)
cur_op.desc.set_input("Seed", [var_unique_name]) cur_op.desc.set_input("Seed", [var_unique_name])
cur_op._remove_attr("fix_seed") cur_op._remove_attr("fix_seed")
cur_op._remove_attr("seed") cur_op._remove_attr("seed")
cur_op_dist_attr.set_input_dist_attr(seed_var.name, cur_op_dist_attr.set_input_dist_attr(
seed_var_dist_attr) seed_var.name, seed_var_dist_attr
)
op_idx += 2 op_idx += 2
self._block._sync_with_cpp() self._block._sync_with_cpp()
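# Why the inserted seed op is enough: a dropout mask is fully determined by its seed,
# so once the forward dropout and its recomputed copy read the same explicit Seed
# input, the two masks match bit for bit. A tiny NumPy illustration of that property
# (illustration only, not the pass itself; dropout_mask is a stand-in for the kernel):
#
#   import numpy as np
#
#   def dropout_mask(shape, drop_prob, seed):
#       rng = np.random.RandomState(seed)  # explicit seed, like the inserted seed op
#       return (rng.uniform(size=shape) >= drop_prob).astype(np.float32)
#
#   fw_mask = dropout_mask((4, 8), 0.1, seed=2023)  # forward pass
#   rc_mask = dropout_mask((4, 8), 0.1, seed=2023)  # recompute reuses the same seed
#   assert np.array_equal(fw_mask, rc_mask)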
...@@ -168,7 +192,7 @@ def _find_op_index(block, cur_op): ...@@ -168,7 +192,7 @@ def _find_op_index(block, cur_op):
def _get_stop_gradients(program, no_grad_set): def _get_stop_gradients(program, no_grad_set):
""" get no grad var """ """get no grad var"""
if no_grad_set is None: if no_grad_set is None:
no_grad_set = set() no_grad_set = set()
else: else:
...@@ -185,8 +209,9 @@ def _get_stop_gradients(program, no_grad_set): ...@@ -185,8 +209,9 @@ def _get_stop_gradients(program, no_grad_set):
return no_grad_set_name return no_grad_set_name
def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars, def _add_needed_descs_to_block(
dist_context): descs, block, main_block, in_memory_vars, dist_context
):
""" """
Get the recomputed ops which will insert the backward part Get the recomputed ops which will insert the backward part
""" """
...@@ -217,7 +242,6 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars, ...@@ -217,7 +242,6 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars,
@register_pass("auto_parallel_recompute") @register_pass("auto_parallel_recompute")
class RecomputePass(PassBase): class RecomputePass(PassBase):
def __init__(self): def __init__(self):
super(RecomputePass, self).__init__() super(RecomputePass, self).__init__()
self.set_attr("checkpoints", None) self.set_attr("checkpoints", None)
...@@ -261,12 +285,15 @@ class RecomputePass(PassBase): ...@@ -261,12 +285,15 @@ class RecomputePass(PassBase):
vars_should_be_hold = [] vars_should_be_hold = []
for segment in segments: for segment in segments:
vars_should_be_hold.extend( vars_should_be_hold.extend(
rc_state.get_out_of_subgraph_vars(segment[0], segment[1])) rc_state.get_out_of_subgraph_vars(segment[0], segment[1])
)
cross_vars = set(vars_should_be_hold) - set(checkpoints) cross_vars = set(vars_should_be_hold) - set(checkpoints)
logging.info( logging.info(
"found [{}] vars which cross recompute segment: [{}]," "found [{}] vars which cross recompute segment: [{}],"
"better checkpoints might be set to reduce those vars".format( "better checkpoints might be set to reduce those vars".format(
len(cross_vars), cross_vars)) len(cross_vars), cross_vars
)
)
vars_should_be_hold.extend(rc_state.get_reserved_vars()) vars_should_be_hold.extend(rc_state.get_reserved_vars())
vars_should_be_hold.extend(rc_state.get_input_nodes()) vars_should_be_hold.extend(rc_state.get_input_nodes())
vars_should_be_hold = list(set(vars_should_be_hold)) vars_should_be_hold = list(set(vars_should_be_hold))
...@@ -277,14 +304,15 @@ class RecomputePass(PassBase): ...@@ -277,14 +304,15 @@ class RecomputePass(PassBase):
ckpt_ops_dict = {} ckpt_ops_dict = {}
buffer_block = main_block.program._create_block() buffer_block = main_block.program._create_block()
for i, segment in enumerate(segments[::-1]): for i, segment in enumerate(segments[::-1]):
fwd_ops = op_path[segment[0]:segment[1]] fwd_ops = op_path[segment[0] : segment[1]]
var_suffix = ".subprog_%d" % i var_suffix = ".subprog_%d" % i
for op in fwd_ops: for op in fwd_ops:
input_and_output_names = [] input_and_output_names = []
input_and_output_names.extend(op.desc.input_arg_names()) input_and_output_names.extend(op.desc.input_arg_names())
input_and_output_names.extend(op.desc.output_arg_names()) input_and_output_names.extend(op.desc.output_arg_names())
cur_op_dist_attr = self._dist_context.get_op_dist_attr_for_program( cur_op_dist_attr = (
op) self._dist_context.get_op_dist_attr_for_program(op)
)
assert cur_op_dist_attr is not None assert cur_op_dist_attr is not None
for name in input_and_output_names: for name in input_and_output_names:
if main_block.var(name).persistable or name in checkpoints: if main_block.var(name).persistable or name in checkpoints:
...@@ -294,11 +322,13 @@ class RecomputePass(PassBase): ...@@ -294,11 +322,13 @@ class RecomputePass(PassBase):
if name not in var_name_dict: if name not in var_name_dict:
ref_process_mesh = cur_op_dist_attr.process_mesh ref_process_mesh = cur_op_dist_attr.process_mesh
if name in op.desc.input_arg_names(): if name in op.desc.input_arg_names():
ref_dims_mapping = cur_op_dist_attr.get_input_dims_mapping( ref_dims_mapping = (
name) cur_op_dist_attr.get_input_dims_mapping(name)
)
else: else:
ref_dims_mapping = cur_op_dist_attr.get_output_dims_mapping( ref_dims_mapping = (
name) cur_op_dist_attr.get_output_dims_mapping(name)
)
# record recomputed var's old_name and new_name (old_name.subprog_XXX) # record recomputed var's old_name and new_name (old_name.subprog_XXX)
# create new var with new name # create new var with new name
var_name_dict[name] = name + var_suffix var_name_dict[name] = name + var_suffix
...@@ -309,15 +339,23 @@ class RecomputePass(PassBase): ...@@ -309,15 +339,23 @@ class RecomputePass(PassBase):
dtype=ref_var.dtype, dtype=ref_var.dtype,
type=ref_var.type, type=ref_var.type,
persistable=ref_var.persistable, persistable=ref_var.persistable,
stop_gradient=ref_var.stop_gradient) stop_gradient=ref_var.stop_gradient,
)
# set new recomputed var's dist attr # set new recomputed var's dist attr
set_var_dist_attr(self._dist_context, rc_var, set_var_dist_attr(
ref_dims_mapping, ref_process_mesh) self._dist_context,
rc_var,
ref_dims_mapping,
ref_process_mesh,
)
# get recomputed segment's descs # get recomputed segment's descs
segment_descs = _add_needed_descs_to_block(fwd_ops, buffer_block, segment_descs = _add_needed_descs_to_block(
main_block, fwd_ops,
vars_in_memory, buffer_block,
self._dist_context) main_block,
vars_in_memory,
self._dist_context,
)
# rename recomputed ops' input and output var name # rename recomputed ops' input and output var name
for key in var_name_dict: for key in var_name_dict:
_rename_arg_(segment_descs, key, var_name_dict[key]) _rename_arg_(segment_descs, key, var_name_dict[key])
...@@ -345,7 +383,10 @@ class RecomputePass(PassBase): ...@@ -345,7 +383,10 @@ class RecomputePass(PassBase):
# rename grad op's var_name which is not in 'vars_in_memory' # rename grad op's var_name which is not in 'vars_in_memory'
for key in var_name_dict: for key in var_name_dict:
if key not in grad_op.input_arg_names + grad_op.output_arg_names: if (
key
not in grad_op.input_arg_names + grad_op.output_arg_names
):
continue continue
self.reset_op_dist_attr(grad_op, var_name_dict) self.reset_op_dist_attr(grad_op, var_name_dict)
_rename_arg_([grad_op.desc], key, var_name_dict[key]) _rename_arg_([grad_op.desc], key, var_name_dict[key])
...@@ -360,17 +401,20 @@ class RecomputePass(PassBase): ...@@ -360,17 +401,20 @@ class RecomputePass(PassBase):
idx -= 1 idx -= 1
segment_descs = ckpt_ops_dict[fwd_op_id][1] segment_descs = ckpt_ops_dict[fwd_op_id][1]
for _, op_desc in reversed(list(enumerate(segment_descs))): for _, op_desc in reversed(list(enumerate(segment_descs))):
rc_op = main_block._insert_op_without_sync(idx, rc_op = main_block._insert_op_without_sync(
type='nop') idx, type='nop'
)
rc_desc = rc_op.desc rc_desc = rc_op.desc
rc_desc.copy_from(op_desc) rc_desc.copy_from(op_desc)
rc_desc.set_original_id(rc_desc.id()) rc_desc.set_original_id(rc_desc.id())
# set recomputed ops' dist attr # set recomputed ops' dist attr
fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id( fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id(
op_desc.original_id()) op_desc.original_id()
)
assert fwd_op_dist_attr is not None assert fwd_op_dist_attr is not None
self.set_op_dist_attr(rc_op, fwd_op_dist_attr, self.set_op_dist_attr(
var_name_dict) rc_op, fwd_op_dist_attr, var_name_dict
)
ckpt_ops_dict[fwd_op_id][0] = False ckpt_ops_dict[fwd_op_id][0] = False
...@@ -382,13 +426,15 @@ class RecomputePass(PassBase): ...@@ -382,13 +426,15 @@ class RecomputePass(PassBase):
for input in op.desc.input_arg_names(): for input in op.desc.input_arg_names():
if input in var_name_dict.keys(): if input in var_name_dict.keys():
in_dist_attr = op_dist_attr.get_input_dist_attr(input) in_dist_attr = op_dist_attr.get_input_dist_attr(input)
op_dist_attr.set_input_dist_attr(var_name_dict[input], op_dist_attr.set_input_dist_attr(
in_dist_attr) var_name_dict[input], in_dist_attr
)
for output in op.desc.output_arg_names(): for output in op.desc.output_arg_names():
if output in var_name_dict.keys(): if output in var_name_dict.keys():
out_dist_attr = op_dist_attr.get_output_dist_attr(output) out_dist_attr = op_dist_attr.get_output_dist_attr(output)
op_dist_attr.set_output_dist_attr(var_name_dict[output], op_dist_attr.set_output_dist_attr(
out_dist_attr) var_name_dict[output], out_dist_attr
)
def set_op_dist_attr(self, op, old_dist_attr, var_name_dict): def set_op_dist_attr(self, op, old_dist_attr, var_name_dict):
new_dist_attr = OperatorDistributedAttribute() new_dist_attr = OperatorDistributedAttribute()
...@@ -399,16 +445,18 @@ class RecomputePass(PassBase): ...@@ -399,16 +445,18 @@ class RecomputePass(PassBase):
for input in old_dist_attr.inputs_dist_attrs.keys(): for input in old_dist_attr.inputs_dist_attrs.keys():
if input in var_name_dict.keys(): if input in var_name_dict.keys():
in_dist_attr = old_dist_attr.inputs_dist_attrs[input] in_dist_attr = old_dist_attr.inputs_dist_attrs[input]
new_dist_attr.set_input_dist_attr(var_name_dict[input], new_dist_attr.set_input_dist_attr(
in_dist_attr) var_name_dict[input], in_dist_attr
)
else: else:
in_dist_attr = old_dist_attr.inputs_dist_attrs[input] in_dist_attr = old_dist_attr.inputs_dist_attrs[input]
new_dist_attr.set_input_dist_attr(input, in_dist_attr) new_dist_attr.set_input_dist_attr(input, in_dist_attr)
for output in old_dist_attr.outputs_dist_attrs.keys(): for output in old_dist_attr.outputs_dist_attrs.keys():
if output in var_name_dict.keys(): if output in var_name_dict.keys():
out_dist_attr = old_dist_attr.outputs_dist_attrs[output] out_dist_attr = old_dist_attr.outputs_dist_attrs[output]
new_dist_attr.set_output_dist_attr(var_name_dict[output], new_dist_attr.set_output_dist_attr(
out_dist_attr) var_name_dict[output], out_dist_attr
)
else: else:
out_dist_attr = old_dist_attr.outputs_dist_attrs[output] out_dist_attr = old_dist_attr.outputs_dist_attrs[output]
new_dist_attr.set_output_dist_attr(output, out_dist_attr) new_dist_attr.set_output_dist_attr(output, out_dist_attr)
......
...@@ -37,6 +37,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ...@@ -37,6 +37,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
${dist_ENVS}) ${dist_ENVS})
set_tests_properties(test_high_order_grad set_tests_properties(test_high_order_grad
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_random_ctrl MODULES test_random_ctrl ENVS ${dist_ENVS})
set_tests_properties(test_random_ctrl PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS
${dist_ENVS}) ${dist_ENVS})
set_tests_properties(test_iterable_dataset set_tests_properties(test_iterable_dataset
......
...@@ -21,14 +21,17 @@ from paddle.distributed.fleet import auto ...@@ -21,14 +21,17 @@ from paddle.distributed.fleet import auto
sys.path.append("..") sys.path.append("..")
import auto_parallel_gpt_model as modeling import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion from auto_parallel_gpt_model import (
GPTModel,
GPTForPretraining,
GPTPretrainingCriterion,
)
sequence_len = 512 sequence_len = 512
vocab_size = 1000 vocab_size = 1000
class FakeDataset(paddle.io.Dataset): class FakeDataset(paddle.io.Dataset):
def __init__(self, num_samples): def __init__(self, num_samples):
self.num_samples = num_samples self.num_samples = num_samples
self.sequence_len = sequence_len self.sequence_len = sequence_len
...@@ -40,8 +43,11 @@ class FakeDataset(paddle.io.Dataset): ...@@ -40,8 +43,11 @@ class FakeDataset(paddle.io.Dataset):
random.seed(2021) random.seed(2021)
tokens = np.random.randint(self.vocab_size, size=self.sequence_len) tokens = np.random.randint(self.vocab_size, size=self.sequence_len)
position_ids = np.arange(self.sequence_len) position_ids = np.arange(self.sequence_len)
attention_mask = np.tril(np.ones(self.sequence_len)).reshape( attention_mask = (
(1, self.sequence_len, self.sequence_len)).astype(np.float32) np.tril(np.ones(self.sequence_len))
.reshape((1, self.sequence_len, self.sequence_len))
.astype(np.float32)
)
labels = np.random.randint(self.vocab_size, size=self.sequence_len) labels = np.random.randint(self.vocab_size, size=self.sequence_len)
loss_mask = np.ones(self.sequence_len).astype(np.float32) loss_mask = np.ones(self.sequence_len).astype(np.float32)
return tokens, position_ids, attention_mask, labels, loss_mask return tokens, position_ids, attention_mask, labels, loss_mask
...@@ -51,30 +57,32 @@ class FakeDataset(paddle.io.Dataset): ...@@ -51,30 +57,32 @@ class FakeDataset(paddle.io.Dataset):
def create_data_holder(batch_size): def create_data_holder(batch_size):
tokens = paddle.static.InputSpec(name="tokens", tokens = paddle.static.InputSpec(
shape=[batch_size, sequence_len], name="tokens", shape=[batch_size, sequence_len], dtype='int64'
dtype='int64') )
position_ids = paddle.static.InputSpec(name="position_ids", position_ids = paddle.static.InputSpec(
shape=[batch_size, sequence_len], name="position_ids", shape=[batch_size, sequence_len], dtype='int64'
dtype='int64') )
attention_mask = paddle.static.InputSpec( attention_mask = paddle.static.InputSpec(
name="attention_mask", name="attention_mask",
shape=[batch_size, 1, sequence_len, sequence_len], shape=[batch_size, 1, sequence_len, sequence_len],
dtype='float32') dtype='float32',
labels = paddle.static.InputSpec(name="labels", )
shape=[batch_size, sequence_len], labels = paddle.static.InputSpec(
dtype='int64') name="labels", shape=[batch_size, sequence_len], dtype='int64'
loss_mask = paddle.static.InputSpec(name="loss_mask", )
shape=[batch_size, sequence_len], loss_mask = paddle.static.InputSpec(
dtype='float32') name="loss_mask", shape=[batch_size, sequence_len], dtype='float32'
)
return [tokens, position_ids, attention_mask], [labels, loss_mask] return [tokens, position_ids, attention_mask], [labels, loss_mask]
def generate_model(strategy): def generate_model(strategy, dropout_prob=0.0):
modeling.init_global() modeling.init_global()
ranks = list(range(paddle.distributed.get_world_size())) ranks = list(range(paddle.distributed.get_world_size()))
modeling._global_process_mesh = auto.ProcessMesh(mesh=ranks, modeling._global_process_mesh = auto.ProcessMesh(
dim_names=["x"]) mesh=ranks, dim_names=["x"]
)
if strategy == "serial": if strategy == "serial":
modeling._global_parallel_strategy = "serial" modeling._global_parallel_strategy = "serial"
elif strategy == "mp": elif strategy == "mp":
...@@ -84,24 +92,25 @@ def generate_model(strategy): ...@@ -84,24 +92,25 @@ def generate_model(strategy):
else: else:
raise ValueError("Only support serial, mp2 and dp2.") raise ValueError("Only support serial, mp2 and dp2.")
gpt = GPTModel(vocab_size=1000, gpt = GPTModel(
hidden_size=64, vocab_size=1000,
num_hidden_layers=2, hidden_size=64,
num_attention_heads=8, num_hidden_layers=2,
intermediate_size=256, num_attention_heads=8,
hidden_act="gelu", intermediate_size=256,
hidden_dropout_prob=0.0, hidden_act="gelu",
attention_probs_dropout_prob=0.0, hidden_dropout_prob=dropout_prob,
max_position_embeddings=1024, attention_probs_dropout_prob=dropout_prob,
type_vocab_size=1, max_position_embeddings=1024,
initializer_range=0.02, type_vocab_size=1,
pad_token_id=0, initializer_range=0.02,
eos_token_id=7, pad_token_id=0,
bos_token_id=0, eos_token_id=7,
eol_token_id=3) bos_token_id=0,
model = GPTForPretraining(gpt, eol_token_id=3,
vocab_size=1000, )
hidden_size=64, model = GPTForPretraining(
initializer_range=0.02) gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
)
criterion = GPTPretrainingCriterion() criterion = GPTPretrainingCriterion()
return model, criterion return model, criterion
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
import numpy as np
from get_gpt_model import FakeDataset, generate_model
import paddle
paddle.enable_static()
from paddle import _legacy_C_ops
from paddle.distributed.fleet import auto
def dy_broadcast_helper(tensor):
_legacy_C_ops.c_broadcast(
tensor, tensor, 'root', 1, 'use_calc_stream', True, 'ring_id', 1000
)
_legacy_C_ops.c_sync_calc_stream(tensor, tensor)
def apply_pass(use_recompute=False, no_recompute_segments=[]):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_recompute:
recompute = strategy.recompute
recompute.enable = True
recompute.no_recompute_segments = no_recompute_segments
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestRandomControl(unittest.TestCase):
def setUp(self):
self.rtol = 1e-6
self.atol = 1e-8
self.batch_size = 1
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
paddle.distributed.auto_parallel.parallel_manual_seed(100)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_recompute=False, no_recompute_segments=[]):
reset_prog()
strategy = apply_pass(use_recompute, no_recompute_segments)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("mp", dropout_prob=0.1)
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def compare_mask_between_ranks(
        self, rank, mask_np_list, compare_idx, equal
    ):
        for np_mask in [mask_np_list[i] for i in compare_idx]:
mask_tensor_local = paddle.to_tensor(np_mask.astype("float32"))
if rank == 0:
mask_tensor_remote = paddle.ones_like(mask_tensor_local)
dy_broadcast_helper(mask_tensor_remote)
if equal:
assert np.array_equal(
mask_tensor_remote.numpy(), mask_tensor_local.numpy()
)
else:
assert not np.array_equal(
mask_tensor_remote.numpy(),
mask_tensor_local.numpy(),
)
else:
dy_broadcast_helper(mask_tensor_local)
def test_random_ctrl_vanilla(self):
        # mp2 training without recompute
rc_engine = self.get_engine(False)
train_dataloader = rc_engine.dataloader(
self.dataset,
batch_size=self.batch_size,
mode="train",
sample_split=3,
)
rc_engine.prepare(mode="train")
mask_name_list = [f'dropout_{i}.tmp_1' for i in range(7)]
mask_var_list = [
rc_engine.main_program.global_block().var(varname)
for varname in mask_name_list
]
for data in train_dataloader:
outs = rc_engine.run(data, fetch_list=mask_var_list, mode="train")
mask_np_list = [outs['fetches'][varname] for varname in mask_name_list]
paddle.disable_static()
rank = paddle.distributed.get_rank()
        # check global masks are consistent across ranks
global_index = [0, 2, 3, 5, 6]
self.compare_mask_between_ranks(
rank, mask_np_list, global_index, equal=True
)
local_index = [1, 4]
        # check local masks differ across ranks
self.compare_mask_between_ranks(
rank, mask_np_list, local_index, equal=False
)
paddle.enable_static()
# check program
ops = rc_engine.main_program.global_block().ops
rng_names = []
seed_var_names = []
for op in ops:
if op.type == "seed":
rng_names.append(op.attr('rng_name'))
if op.type == "dropout":
seed_var_names.append(op.input("Seed")[0])
rank = paddle.distributed.get_rank()
self.assertEqual(
rng_names,
[
'mesh:1_dim0:-1',
f'mesh:1_dim0:{rank}',
'mesh:1_dim0:-1',
'mesh:1_dim0:-1',
f'mesh:1_dim0:{rank}',
'mesh:1_dim0:-1',
'mesh:1_dim0:-1',
],
)
self.assertEqual(
seed_var_names,
[
'tensor_parallel_seed.tmp_0',
'tensor_parallel_seed.tmp_1',
'tensor_parallel_seed.tmp_2',
'tensor_parallel_seed.tmp_3',
'tensor_parallel_seed.tmp_4',
'tensor_parallel_seed.tmp_5',
'tensor_parallel_seed.tmp_6',
],
)
def test_random_ctrl_with_recompute(self):
# mp2 recompute training
rc_engine = self.get_engine(True)
train_dataloader = rc_engine.dataloader(
self.dataset,
batch_size=self.batch_size,
mode="train",
sample_split=3,
)
rc_engine.prepare(mode="train")
mask_name_list = [f'dropout_{i}.tmp_1' for i in range(7)]
recompute_mask_name_list = [
'dropout_0.tmp_1.subprog_1',
'dropout_1.tmp_1.subprog_1',
'dropout_2.tmp_1.subprog_1',
'dropout_3.tmp_1.subprog_1',
'dropout_4.tmp_1.subprog_0',
'dropout_5.tmp_1.subprog_0',
'dropout_6.tmp_1.subprog_0',
]
mask_var_list = [
rc_engine.main_program.global_block().var(varname)
for varname in mask_name_list + recompute_mask_name_list
]
for data in train_dataloader:
outs = rc_engine.run(data, fetch_list=mask_var_list, mode="train")
mask_np_list = [
outs['fetches'][varname]
for varname in mask_name_list + recompute_mask_name_list
]
        # check that each recomputed mask matches its forward mask on the local device
for i in range(7):
mask_fw = mask_np_list[i].astype("float32")
mask_rc = mask_np_list[i + 7].astype("float32")
assert np.array_equal(
mask_fw,
mask_rc,
)
paddle.disable_static()
        # check global masks are consistent across ranks
rank = paddle.distributed.get_rank()
global_index = [0, 2, 3, 5, 6]
self.compare_mask_between_ranks(
rank, mask_np_list, global_index, equal=True
)
local_index = [1, 4]
        # check local masks differ across ranks
self.compare_mask_between_ranks(
rank, mask_np_list, local_index, equal=False
)
paddle.enable_static()
# check program
rank = paddle.distributed.get_rank()
ops = rc_engine.main_program.global_block().ops
rng_names = []
seed_var_names = []
for op in ops:
if op.type == "seed":
rng_names.append(op.attr('rng_name'))
if op.type == "dropout":
seed_var_names.append(op.input("Seed")[0])
self.assertEqual(
rng_names,
[
'mesh:1_dim0:-1',
f'mesh:1_dim0:{rank}',
'mesh:1_dim0:-1',
'mesh:1_dim0:-1',
f'mesh:1_dim0:{rank}',
'mesh:1_dim0:-1',
'mesh:1_dim0:-1',
],
)
self.assertEqual(
seed_var_names,
[
'rc_seed_0.tmp_0',
'rc_seed_1.tmp_0',
'rc_seed_2.tmp_0',
'rc_seed_3.tmp_0',
'rc_seed_4.tmp_0',
'rc_seed_5.tmp_0',
'rc_seed_6.tmp_0',
'rc_seed_4.tmp_0',
'rc_seed_5.tmp_0',
'rc_seed_6.tmp_0',
'rc_seed_0.tmp_0',
'rc_seed_1.tmp_0',
'rc_seed_2.tmp_0',
'rc_seed_3.tmp_0',
],
)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import tempfile
import unittest
class TestRandomCtrlPass(unittest.TestCase):
def test_mp2_with_recompute(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "random_control_unittest.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = (
[sys.executable, "-u"]
+ coverage_args
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()