Unverified commit 025053b4 authored by zhaoyingli, committed by GitHub

Adapt auto search (#37490)

* adapt auto search

* adapt auto search

* fix matmulv2 compatibility

* del debug
Parent 5ff1ff5a
@@ -273,6 +273,7 @@ message DistributedStrategy {
optional bool fuse_grad_merge = 34 [ default = false ];
optional bool semi_auto = 35 [ default = false ];
optional bool adam_d2sum = 36 [ default = true ];
optional bool auto_search = 37 [ default = false ];
optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102;
......
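The new auto_search field sits next to semi_auto in the DistributedStrategy message and surfaces on the Python side as the strategy.auto_search property added at the end of this diff. A minimal sketch of toggling it (assumes a Paddle build that includes this commit):

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
print(strategy.auto_search)   # False, mirroring the proto default
strategy.auto_search = True   # flips field 37 of the underlying message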
@@ -715,6 +715,27 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
grad_op_dist_attr.process_mesh = forward_op_process_mesh
# var
for input_name in grad_op.input_arg_names:
input_var = vars[input_name]
ref_dims_mapping = None
if "@GRAD" in input_name:
forward_name = _get_forward_varname_from_grad_varname(
input_name)
ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping(
forward_name)
else:
if forward_op_dist_attr.get_input_dims_mapping(input_name):
ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping(
input_name)
else:
ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping(
input_name)
assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format(
input_var.name)
grad_op_dist_attr.set_input_dims_mapping(input_name,
ref_dims_mapping)
for output_name in grad_op.desc.output_names():
assert len(grad_op.desc.output(output_name)) in [0, 1]
if _is_grad_var_name(output_name):
@@ -726,41 +747,25 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
]
input_name = "X"
assert input_name in forward_op.desc.input_names(
), "var [{}] in op [{}]'s output but coulf not find [{}] in its forward op".format(
), "var [{}] in op [{}]'s output but could not find [{}] in its forward op".format(
output_name, grad_op.type, input_name)
if len(grad_op.desc.output(output_name)) == 1:
assert len(forward_op.desc.input(input_name)) == 1
input_var = vars[forward_op.desc.input(input_name)[0]]
input_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
input_var)
assert input_var_dist_attr is not None, "[{}] has not dist attribute".format(
input_var.name)
ref_dims_mapping = input_var_dist_attr.dims_mapping
# tensor dist attr
output_var = vars[grad_op.desc.output(output_name)[0]]
forward_name = _get_forward_varname_from_grad_varname(
output_var.name)
ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping(
forward_name)
output_var_dist_attr = TensorDistributedAttribute()
output_var_dist_attr.dims_mapping = ref_dims_mapping
output_var_dist_attr.process_mesh = forward_op_process_mesh
dist_context.set_tensor_dist_attr_for_program(
output_var, output_var_dist_attr)
# op dist attr
grad_op_dist_attr.set_output_dims_mapping(output_var.name,
ref_dims_mapping)
for input_name in grad_op.input_arg_names:
input_var = vars[input_name]
input_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
input_var)
assert input_var_dist_attr is not None, "[{}] has not dist attribute".format(
input_var.name)
ref_dims_mapping = input_var_dist_attr.dims_mapping
assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format(
input_var.name)
grad_op_dist_attr.set_input_dims_mapping(input_name,
ref_dims_mapping)
dist_context.set_op_dist_attr_for_program(grad_op,
grad_op_dist_attr)
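The loop above keys each grad input off its forward counterpart: names containing @GRAD are mapped back to the forward variable whose output dims mapping is reused, while plain inputs fall back to the forward op's own input or output mapping. A minimal sketch of what the name-mapping helper does (the real _get_forward_varname_from_grad_varname in completion.py may also handle renamed gradient suffixes):

def _get_forward_varname_from_grad_varname(grad_var_name):
    # e.g. "linear_0.w_0@GRAD" -> "linear_0.w_0"
    assert "@GRAD" in grad_var_name, \
        "{} is not a grad var name".format(grad_var_name)
    return grad_var_name[:grad_var_name.find("@GRAD")]

print(_get_forward_varname_from_grad_varname("linear_0.w_0@GRAD"))  # linear_0.w_0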
@@ -828,13 +833,7 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context):
param_dist_attr = dist_context.get_tensor_dist_attr_for_program(
param)
grad_dist_attr = dist_context.get_tensor_dist_attr_for_program(
grad_var)
assert param_dist_attr is not None
assert grad_dist_attr is not None
assert param_dist_attr.dims_mapping == grad_dist_attr.dims_mapping
ref_process_mesh = dist_context.get_tensor_dist_attr_for_program(
param).process_mesh
assert ref_process_mesh is not None
......
@@ -335,6 +335,17 @@ class DistributedContext:
dist_op.serial_op.type, dist_tensor.dist_attr)
return True
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == "_serial_program" or k == "_serial_graph":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
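The same selective __deepcopy__ pattern is added to DistributedOperatorContext, DistributedOperator, and DistributedTensor below: heavyweight Program/Graph/op handles are shared by reference while the dist-attr bookkeeping is cloned, so an auto-search step can mutate a forked context freely. A hedged usage sketch (dist_context stands for any existing DistributedContext instance):

import copy

ctx_copy = copy.deepcopy(dist_context)
# the serial Program/Graph handles are intentionally shared, not cloned ...
assert ctx_copy._serial_program is dist_context._serial_program
# ... while every other attribute is rebuilt via copy.deepcopy, so edits to
# the copy's dist attrs do not leak back into the original context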
class DistributedOperatorContext:
"""
@@ -352,6 +363,17 @@ class DistributedOperatorContext:
self.gradopidx2opidx = {}
self.already_init_sync_vars = set()
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == "_dst_main_program" or k == "_dst_startup_program" or k == "_cur_src_op":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
def set_dst_main_program(self, prog):
self._dst_main_program = prog
......
@@ -219,6 +219,17 @@ class DistributedOperator:
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == "_serial_op" or k == "_serial_inputs" or k == "_serial_outputs":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
class DistributedModule:
def __init__(self, serial_module, dist_attr=None):
......
@@ -66,6 +66,17 @@ class DistributedTensor:
return False
return True
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == "_serial_tensor":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
def __str__(self):
str = "{{tensor name: {}, tensor id: {}".format(
self.serial_tensor.desc.name(), self.serial_tensor.desc.id())
......
@@ -111,37 +111,27 @@ def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):
return best_compatible_impl, idx
def copy_distributed_attr_for_var(dist_context, dst_var, src_var):
"""
copy src var's dist_attr to dst var
"""
dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
dist_context.set_tensor_dist_attr_for_program(dst_var, dist_attr)
def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
var_shape = block.var(src_var.name).shape
var_topology = src_var_dist_attr.process_mesh.topology
var_dims_mapping = src_var_dist_attr.dims_mapping
complete_shape = []
for idx, shape in enumerate(var_shape):
if var_dims_mapping[idx] == -1:
complete_shape.append(shape)
else:
new_shape = shape * var_topology[var_dims_mapping[idx]]
complete_shape.append(new_shape)
exact_shape = []
input_topology = op_input_dist_attr.process_mesh.topology
input_dims_mapping = op_input_dist_attr.dims_mapping
for idx, shape in enumerate(complete_shape):
if input_dims_mapping[idx] == -1:
exact_shape.append(shape)
else:
new_shape = shape // input_topology[input_dims_mapping[idx]]
exact_shape.append(new_shape)
def copy_distributed_attr_for_dist_op(dist_context, dist_op, dst_block,
src_op_dist_attr):
"""
copy src op's dist_attr to dst dist op
"""
from ..dist_attribute import OperatorDistributedAttribute
# need check dist op attr and its inputs and outputs
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = src_op_dist_attr.process_mesh
op_dist_attr.impl_idx = src_op_dist_attr.impl_idx
for input_varname in dist_op.desc.input_arg_names():
input_var = dst_block.var(input_varname)
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
input_var)
op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr)
for output_varname in dist_op.desc.output_arg_names():
output_var = dst_block.var(output_varname)
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
output_var)
op_dist_attr.set_output_dist_attr(output_varname, tensor_dist_attr)
dist_context.set_op_dist_attr_for_program(dist_op, op_dist_attr)
op_dist_attr = dist_context.get_op_dist_attr_for_program(dist_op)
return exact_shape
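infer_shape first reconstructs the complete (global) shape from the source tensor's dims mapping and process-mesh topology, then projects that global shape onto the dims mapping of the consuming op input. A framework-free sketch of the same arithmetic, with made-up shapes and a made-up 2-way mesh:

def _infer_shape_demo(local_shape, src_dims_mapping, src_topology,
                      in_dims_mapping, in_topology):
    # local (possibly sharded) shape -> complete (global) shape
    complete = [s if m == -1 else s * src_topology[m]
                for s, m in zip(local_shape, src_dims_mapping)]
    # complete shape -> shape expected by the consuming op input
    return [s if m == -1 else s // in_topology[m]
            for s, m in zip(complete, in_dims_mapping)]

# a [4, 1024] row shard on a 2-way mesh, consumed column-parallel:
print(_infer_shape_demo([4, 1024], [0, -1], [2], [-1, 0], [2]))  # [8, 512]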
@@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License
from .common import infer_shape
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import copy_distributed_attr_for_var
from .common import copy_distributed_attr_for_dist_op
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
@@ -172,6 +171,14 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
check_variable_and_dtype(Ids_var, 'input', ['int32', 'int64'],
'c_embedding')
# infer new var shape with op dist attr
out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
assert out_tensor_dist_attr is not None
out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert out_var_dist_attr is not None
ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr,
out_var_dist_attr)
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_embedding", 'tmp'])),
@@ -180,9 +187,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=Out_var.stop_gradient)
# copy Out_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(ctx, intermediate_var_0, Out_var)
# set intermediate_var_0's dist_attr with Out_var's dist_attr
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
out_var_dist_attr)
check_variable_and_dtype(
Out_var, 'tensor',
@@ -195,6 +202,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
'W': [Weight_var]},
outputs={'Out': [intermediate_var_0]},
attrs={"start_index": relative_idx})
if intermediate_var_0.shape != ref_shape:
intermediate_var_0.desc.set_shape(ref_shape)
# use_model_parallel
c_allreduce_sum_op = main_block.append_op(
@@ -206,12 +215,46 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
'use_calc_stream': True,
'use_model_parallel': True,
})
# copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(ctx, c_embedding_op, main_block,
if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape)
# set dist op's dist_attr with serial op's dist_attr
# c_embedding
embedding_op_dist_attr = OperatorDistributedAttribute()
embedding_op_dist_attr.process_mesh = op_dist_attr.process_mesh
embedding_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in c_embedding_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
embedding_op_dist_attr.set_input_dist_attr(input_varname,
input_dist_attr)
output_varname = c_embedding_op.desc.output_arg_names()[0]
output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
copy_distributed_attr_for_dist_op(ctx, c_allreduce_sum_op, main_block,
embedding_op_dist_attr.set_output_dist_attr(output_varname,
output_dist_attr)
ctx.set_op_dist_attr_for_program(c_embedding_op, embedding_op_dist_attr)
# allreduce
allreduce_op_dist_attr = OperatorDistributedAttribute()
allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in c_allreduce_sum_op.desc.input_arg_names():
input_var = main_block.var(input_varname)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
assert tensor_dist_attr is not None
allreduce_op_dist_attr.set_input_dist_attr(input_varname,
tensor_dist_attr)
for output_varname in c_allreduce_sum_op.desc.output_arg_names():
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, "dist_attr is {}".format(
op_dist_attr)
allreduce_op_dist_attr.set_output_dist_attr(output_varname,
output_dist_attr)
ctx.set_op_dist_attr_for_program(c_allreduce_sum_op,
allreduce_op_dist_attr)
# param initialization sync
assert Weight_var.name not in dist_op_context.already_init_sync_vars
......
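Functionally, the two inserted ops implement a vocabulary-parallel embedding: each rank owns a slice of the table, c_embedding looks ids up relative to start_index and yields zeros for ids outside its slice, and c_allreduce_sum merges the partial results across ranks. A toy, framework-free simulation of that behaviour (made-up table and ids, embedding dim 1):

def _vocab_parallel_lookup(ids, shards, per_rank):
    out = [0.0] * len(ids)
    for rank, table in enumerate(shards):            # one vocab slice per rank
        start = rank * per_rank                      # the op's "start_index"
        partial = [table[i - start] if start <= i < start + per_rank else 0.0
                   for i in ids]                     # the c_embedding step
        out = [a + b for a, b in zip(out, partial)]  # the c_allreduce_sum step
    return out

full_table = [10.0, 11.0, 12.0, 13.0]                # vocab size 4
shards = [full_table[:2], full_table[2:]]            # 2-way vocab shard
print(_vocab_parallel_lookup([0, 3, 2], shards, per_rank=2))  # [10.0, 13.0, 12.0]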
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import paddle
from paddle.distributed.utils import get_logger
from paddle.distributed.fleet import cloud_utils
import paddle.fluid.core as core
from .dist_context import DistributedContext
@@ -22,7 +24,11 @@ from .completion import complete_annotation, complete_backward_annotation
from .partitioner import Partitioner
from .process_group import get_all_process_groups
from .utils import make_data_unshard
from .utils import set_grad_var_shape
from .reshard import reshard
# from .auto_search import auto_search
_logger = get_logger(logging.INFO)
class AutoParallelizer:
@@ -59,9 +65,19 @@ class AutoParallelizer:
assert startup_program is not None
main_program = loss.block.program
if self._dist_strategy.auto_search:
# auto search
_logger.info("Start search dist attr.")
# self._dist_context, _ = auto_search(main_program, startup_program,
# loss, self._optimizer)
# completed_main_program = main_program
raise NotImplementedError("Auto search has not been implemented")
else:
# Annotation completion
_logger.info("Start annotation dist attr.")
completed_main_program = complete_annotation(main_program,
self._dist_context)
# Logical partition
rank = paddle.distributed.get_rank()
partitioner = Partitioner(self._dist_strategy, self._dist_context, rank)
@@ -74,13 +90,8 @@ class AutoParallelizer:
self._optimizer, dist_params_grads, partitioned_main_prog,
partitioned_startup_prog)
# Traverse different rank programs and traverse each op of them,
# instantiate communication by process_mapping.
all_process_groups = get_all_process_groups()
for process_group in all_process_groups:
if rank not in process_group._ranks:
continue
process_group.instantiate()
# set the grad var shape
set_grad_var_shape(partitioned_main_prog, self._dist_context)
# The last step: remove all distributed attributes to be compatible
# with inference.
@@ -91,6 +102,14 @@ class AutoParallelizer:
reshard(partitioned_main_prog, partitioned_startup_prog, rank,
self._dist_context)
# Traverse different rank programs and traverse each op of them,
# instantiate communication by process_mapping.
all_process_groups = get_all_process_groups()
for process_group in all_process_groups:
if rank not in process_group._ranks:
continue
process_group.instantiate()
# Copy distributed info to the default context
set_default_distributed_context(self._dist_context)
......
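For context, this parallelize path is reached through fleet: the user enables the strategy flag, initializes fleet, and wraps the optimizer, after which AutoParallelizer runs completion, partitioning, set_grad_var_shape, reshard, and process-group instantiation in the order shown above. A hedged user-side sketch (model and optimizer omitted; auto_search still raises NotImplementedError at this commit):

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
strategy = fleet.DistributedStrategy()
strategy.semi_auto = True         # annotation-based path exercised in this diff
# strategy.auto_search = True     # search path, not yet implemented here
fleet.init(is_collective=True, strategy=strategy)
# optimizer = fleet.distributed_optimizer(paddle.optimizer.Adam(), strategy)
# optimizer.minimize(loss)        # would drive AutoParallelizer.parallelize()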
@@ -981,3 +981,58 @@ def _get_split_indices(complete_shape, dims_mapping, process_shape,
complete_shape))
split_indices_list = [sorted(x) for x in split_indices_list]
return split_indices_list
def set_grad_var_shape(program, dist_context):
from .operators.common import infer_shape
from paddle.distributed.fleet.meta_optimizers.common import OpRole
block = program.global_block()
vars = block.vars
for op in block.ops:
if op.type == "sum":
continue
if int(op.attr('op_role')) == int(OpRole.Backward):
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr is not None
for var_name in op.output_arg_names:
assert "@GRAD" in var_name
forward_var_name = var_name[:var_name.find("@GRAD")]
if op.type == "c_allreduce_sum" or op.type == "c_identity" or op.type == "scale":
forward_var_name = op.input_arg_names[0]
need_set_shape_list = [
"reshape2_grad", "softmax_with_cross_entropy_grad",
"transpose2_grad", "softmax_grad", "cross_entropy_grad2",
"dropout_grad"
]
forward_list = [
"reshape2", "softmax_with_cross_entropy", "transpose2",
"softmax", "cross_entropy2", "dropout"
]
if op.type in need_set_shape_list:
for forward_op in block.ops:
assert int(forward_op.attr('op_role')) != int(
OpRole.Backward)
idx = need_set_shape_list.index(op.type)
forward_op_name = forward_list[idx]
if forward_op.type == forward_op_name and forward_var_name in forward_op.input_arg_names:
op_dist_attr = dist_context.get_op_dist_attr_for_program(
forward_op)
break
forward_input_dist_attr = op_dist_attr.get_input_dist_attr(
forward_var_name)
assert forward_input_dist_attr is not None, f"{forward_var_name}"
forward_var = vars[forward_var_name]
forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
forward_var)
assert forward_var_dist_attr is not None
grad_var = vars[var_name]
ref_shape = infer_shape(block, forward_var,
forward_var_dist_attr,
forward_input_dist_attr)
if list(grad_var.shape) != ref_shape:
grad_var.desc.set_shape(ref_shape)
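set_grad_var_shape repairs a mismatch introduced by logical partitioning: backward-generated @GRAD vars can keep the serial (unsharded) shape while their forward counterparts have already been resharded, so the pass recomputes the expected shape with infer_shape and resets the grad var's desc. A minimal, framework-free illustration of the mismatch being repaired (shapes are made up):

serial_grad_shape = [8, 1024]   # stale shape recorded on a var like "tmp_2@GRAD"
sharded_fwd_shape = [8, 512]    # infer_shape(...) result for the forward "tmp_2"
if serial_grad_shape != sharded_fwd_shape:
    # mirrors grad_var.desc.set_shape(ref_shape) in the pass above
    print("reset grad shape", serial_grad_shape, "->", sharded_fwd_shape)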
@@ -1631,6 +1631,29 @@ class DistributedStrategy(object):
else:
print("WARNING: semi-auto should have value of bool type")
@property
def auto_search(self):
"""
Indicates whether the auto-search parallel function is enabled.
For details, please refer to the code example below.
Default Value: False
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.auto_search = True
"""
return self.strategy.auto_search
@auto_search.setter
def auto_search(self, flag):
if isinstance(flag, bool):
self.strategy.auto_search = flag
else:
print("WARNING: auto-search should have value of bool type")
@property
def cudnn_exhaustive_search(self):
"""
......