未验证 提交 c5cc4278 编写于 作者: Y Yulong Ao 提交者: GitHub

[Cherry-pick][Auto Parallel] Improve the APIs (#46164)

* [AutoParallel] adapt gradient merge pass (#45915)

* adapt gradient merge

* fix op_role

* fix strategy

* [Auto Parallel] Gradient Fuse Allreduce (#45643)

* bugfix (#45332)

* dist embedding support lookup table v1

* add unitest

* customize wait_comm

* group gradients

* bugfix

* update program

* [Auto Parallel] Improve the APIs (#45776)

* [Auto Parallel] Use c++ dist attr in the completion process

* [Auto Parallel] Add minor changes

* [Auto Parallel] Use c++ dist attr in the completion process

* [Auto Parallel] Add minor changes

* [Auto Parallel] Add the serialization process for dist attrs

* [Auto Parallel] Remove unnecessary comments

* [Auto Parallel] Fix some bugs

* [Auto Parallel] Fix the code style

* [Auto Parallel] Remove unnecessary impls

* [Auto Parallel] Fix the importing error

* [Auto Parallel] Fix the copy from bugs of op dist attr

* [Auto Parallel] Replace the use of constexpr if

* [Auto Parallel] Redesign the shard_tensor, shard_op and ProcessMesh

* [Auto Parallel] Change API of the completion unittest

* [Auto Parallel] Fix the bug when set_attr an int

* [Auto Parallel] Add the unittest for the serialization

* [Auto Parallel] Add some unit tests

* [Auto Paralle] Unify the strategy

* [Auto Parallel] Improve the engine api

* [Auto Parallel] Reset the changes made to the framework

* [Auto Parallel] Change the engine unittest

* [Auto Parallel] Update API of the completion and partitioner

* [Auto Parallel] Update unit tests using engine api

* update shard annotation

* [Auto Parallel] Remove the modifications of other modules

* [Auto Parallel] Add docs for APIs

* add new strategy

* [Auto Parallel] Replace the logger

* [Auto Parallel] Restore the test_program.py

* [Auto Parallel] Change the import rules

* [Auto Parallel] Add the examples for Engine

* [Auto Parallel] Do some minor changes

* [Auto Parallel] Remove yaml dependency

* [Auto Parallel] Fix the unittests

* add valid after train

* bug fix
Co-authored-by: Nzhaoyingli <zhaoyingli@baidu.com>
Co-authored-by: Ncaozhou <caozhou@radi.ac.cn>
Co-authored-by: Ncaozhou <48191911+Caozhou1995@users.noreply.github.com>

* [Auto Parallel] Bugfix allreduce fuse for MP (#46086)

* bugfix

* bugfix

* typos fixed

* update strategy (#46138)
Co-authored-by: Nzhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Co-authored-by: NJZ-LIANG <jianzhongliang10@gmail.com>
Co-authored-by: Nzhaoyingli <zhaoyingli@baidu.com>
Co-authored-by: Ncaozhou <caozhou@radi.ac.cn>
Co-authored-by: Ncaozhou <48191911+Caozhou1995@users.noreply.github.com>
上级 860f6077
......@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .interface import shard_tensor # noqa: F401
from .interface import shard_op # noqa: F401
from .strategy import Strategy
from .process_mesh import ProcessMesh
from .reshard import Resharder # noqa: F401
from .cost_model import estimate_cost
from .engine import Engine
from .interface import shard_tensor
from .interface import shard_op
from .interface import recompute
from .interface import fetch
__all__ = []
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from collections import defaultdict
# _g_default_config[category][field] = default_value
_g_default_config = defaultdict(dict)
def get_category_default_config(category):
return _g_default_config[category]
def set_category_default_config(category, default_value):
_g_default_config[category] = default_value
def get_field_default_config(category, field):
return _g_default_config[category][field]
def set_field_default_config(category, field, default_value):
_g_default_config[category][field] = default_value
NOT_FOUND = "not_found"
#########################################
# base configuration
#########################################
BASE = "base"
set_field_default_config(BASE, "auto_mode", "semi")
set_field_default_config(BASE, "gradient_scale", True)
set_field_default_config(BASE, "use_cache", True)
set_field_default_config(BASE, "return_numpy", True)
set_field_default_config(BASE, "all_ranks", False)
set_field_default_config(BASE, "split_data", False)
set_field_default_config(BASE, "seed", None)
set_field_default_config(BASE, "reinit", False) # Only for debug
#########################################
# recompute configuration
#########################################
RECOMPUTE = "recompute"
set_field_default_config(RECOMPUTE, "enable", False)
set_field_default_config(RECOMPUTE, "checkpoints", None)
set_field_default_config(RECOMPUTE, "enable_tuning", False)
#########################################
# AMP configuration
#########################################
AMP = "amp"
set_field_default_config(AMP, "enable", False)
set_field_default_config(AMP, "init_loss_scaling", 32768.0)
set_field_default_config(AMP, "incr_every_n_steps", 1000)
set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2)
set_field_default_config(AMP, "incr_ratio", 2.0)
set_field_default_config(AMP, "decr_ratio", 0.8)
set_field_default_config(AMP, "use_dynamic_loss_scaling", True)
set_field_default_config(AMP, "custom_white_list", [])
set_field_default_config(AMP, "custom_black_list", [])
set_field_default_config(AMP, "custom_black_varnames", [])
set_field_default_config(AMP, "use_pure_fp16", False)
set_field_default_config(AMP, "use_fp16_guard", True)
set_field_default_config(AMP, "use_optimizer_fp16", False)
#########################################
# sharding configuration
#########################################
SHARDING = "sharding"
set_field_default_config(SHARDING, "enable", False)
set_field_default_config(SHARDING, "stage", 1)
set_field_default_config(SHARDING, "degree", 8)
set_field_default_config(SHARDING, "segment_broadcast_MB", 32.0)
set_field_default_config(SHARDING, "enable_tuning", False)
set_field_default_config(SHARDING, "tuning_range", [])
#########################################
# gradient merge configuration
#########################################
GRADIENT_MERGE = "gradient_merge"
set_field_default_config(GRADIENT_MERGE, "enable", False)
set_field_default_config(GRADIENT_MERGE, "k_steps", 1)
set_field_default_config(GRADIENT_MERGE, "avg", True)
#########################################
# quantization configuration
#########################################
QAT = "qat"
set_field_default_config(QAT, "enable", False)
set_field_default_config(QAT, "channel_wise_abs_max", True)
set_field_default_config(QAT, "weight_bits", 8)
set_field_default_config(QAT, "activation_bits", 8)
set_field_default_config(QAT, "not_quant_pattern", ['skip_quant'])
set_field_default_config(QAT, "algo", None)
# #########################################
# auto tuning configuration
# #########################################
TUNING = "tuning"
set_field_default_config(TUNING, "enable", False)
set_field_default_config(TUNING, "batch_size", 1)
set_field_default_config(TUNING, "dataset", None)
set_field_default_config(TUNING, "profile_start_step", 1)
set_field_default_config(TUNING, "profile_end_step", 1)
set_field_default_config(TUNING, "run_after_tuning", True)
set_field_default_config(TUNING, "verbose", True)
......@@ -16,7 +16,7 @@ import paddle
import warnings
import logging
import numpy as np
from ..utils import get_logger
from .utils import get_logger
class Converter(object):
......
......@@ -173,6 +173,17 @@ class TensorDistributedAttribute:
def clear_annotated(self):
self._is_annotated.clear()
def __eq__(self, other):
if not isinstance(other, TensorDistributedAttribute):
return False
if self.process_mesh != other.process_mesh:
return False
if self.dims_mapping != other.dims_mapping:
return False
if self._is_annotated != other._is_annotated:
return False
return True
def __str__(self):
str = "\n\ttensor_dist_attr = {"
if self.is_annotated("process_mesh"):
......@@ -486,6 +497,27 @@ class OperatorDistributedAttribute:
else:
return False
def __eq__(self, other):
if not isinstance(other, OperatorDistributedAttribute):
return False
if self.process_mesh != other.process_mesh:
return False
if self.op_type != other.op_type:
return False
if self.impl_type != other.impl_type:
return False
if self.impl_idx != other.impl_idx:
return False
if self._is_annotated != other._is_annotated:
return False
if self._is_recompute != other._is_recompute:
return False
if self.inputs_dist_attrs != other.inputs_dist_attrs:
return False
if self.outputs_dist_attrs != other.outputs_dist_attrs:
return False
return True
def __str__(self):
str = "\n\top_dist_attr = {"
if self.is_annotated("process_mesh"):
......
......@@ -126,9 +126,6 @@ class DistributedContext:
# A flag indicates whether the used parallelism is data parallel
self._data_parallel = False
# flag whether using `to_static`
self._dygraph_mode = False
@property
def serial_main_program(self):
return self._serial_main_program
......
......@@ -23,6 +23,7 @@ from .dist_attribute import append_op_input_suffix
from .dist_attribute import append_op_output_suffix
from .dist_attribute import get_tensor_dist_attr_field_keys
from .dist_attribute import get_op_dist_attr_field_keys
from .utils import convert_to_shard_spec, verify_shard_spec
class DistributedOperator:
......@@ -248,23 +249,106 @@ class DistributedOperator:
return result
class DistributedModule:
class DistributedOperatorHelper:
def __init__(self, serial_module, dist_attr=None):
self._serial_module = serial_module
self._dist_attr = dist_attr
def __init__(self, serial_op, process_mesh, in_dims_mappings,
out_dims_mappings):
self._serial_op = serial_op
self._process_mesh = process_mesh
self._in_dims_mappings = in_dims_mappings
self._out_dims_mappings = out_dims_mappings
def __call__(self, *args, **kwargs):
from .dist_context import get_default_distributed_context
tensor_to_dims_mapping = {}
index = 0
if self._in_dims_mappings:
assert len(args) + len(kwargs) == len(self._in_dims_mappings), \
"The length of dims_mapping {} does not matching the length output {}.".format(len(self._in_dims_mappings), len(args) + len(kwargs))
for arg in args:
if isinstance(arg, Variable) and self._in_dims_mappings:
tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index]
index += 1
for arg in kwargs.values() and self._in_dims_mappings:
if isinstance(arg, Variable):
tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index]
index += 1
default_prog = paddle.fluid.default_main_program()
cur_block = default_prog.current_block()
op_size = len(cur_block.ops)
output = self._serial_module(*args, **kwargs)
output = self._serial_op(*args, **kwargs)
new_op_size = len(cur_block.ops)
if isinstance(output, tuple) or isinstance(output, list):
new_output = list(output)
elif isinstance(output, Variable):
new_output = [output]
else:
raise ValueError("Unrecognized outpout.")
if self._out_dims_mappings:
assert len(new_output) == len(self._out_dims_mappings), \
"The length of dims_mapping {} does not matching the length output {}.".format(len(self._out_dims_mappings), len(new_output))
for i, item in enumerate(new_output):
if isinstance(item, Variable) and self._out_dims_mappings:
tensor_to_dims_mapping[item.name] = self._out_dims_mappings[i]
from .dist_context import get_default_distributed_context
default_dist_ctx = get_default_distributed_context()
for idx in range(op_size, new_op_size):
op = cur_block.ops[idx]
dist_op = DistributedOperator(op, self._dist_attr)
dist_op.dist_attr.mark_annotated_as(self._dist_attr)
dist_op = DistributedOperator(op)
for name in dist_op.serial_op.input_arg_names:
if name in tensor_to_dims_mapping.keys():
tensor = dist_op.get_serial_input(name)
tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(
name)
dims_mapping = tensor_to_dims_mapping[name]
if tensor is None:
tensor_shape = []
else:
if tensor.type == core.VarDesc.VarType.READER \
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = []
else:
tensor_shape = tensor.shape
if dims_mapping is not None:
dims_mapping = tensor_to_dims_mapping[name]
shard_spec = convert_to_shard_spec(
dims_mapping, self._process_mesh)
assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
name, shard_spec, tensor_shape, self._process_mesh)
tensor_dist_attr.dims_mapping = dims_mapping
tensor_dist_attr.mark_annotated("dims_mapping")
for name in dist_op.serial_op.output_arg_names:
if name in tensor_to_dims_mapping.keys():
tensor = dist_op.get_serial_output(name)
tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(
name)
dims_mapping = tensor_to_dims_mapping[name]
if tensor is None:
tensor_shape = []
else:
if tensor.type == core.VarDesc.VarType.READER \
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = []
else:
tensor_shape = tensor.shape
if dims_mapping is not None:
dims_mapping = tensor_to_dims_mapping[name]
shard_spec = convert_to_shard_spec(
dims_mapping, self._process_mesh)
assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
name, shard_spec, tensor_shape, self._process_mesh)
tensor_dist_attr.dims_mapping = dims_mapping
tensor_dist_attr.mark_annotated("dims_mapping")
dist_op.dist_attr.process_mesh = self._process_mesh
if self._process_mesh is not None:
dist_op.dist_attr.mark_annotated("process_mesh")
default_dist_ctx.add_dist_op_for_program(dist_op)
return output
......@@ -27,7 +27,7 @@ from paddle.fluid.framework import static_only
from .utils import get_dist_attr
from .converter import Converter
from .process_group import _g_process_group_map
from ..utils import get_logger
from .utils import get_logger
def check_filename(re_exp, filename):
......@@ -59,6 +59,14 @@ class DistributedSaver:
def save(self, path, serial_program, dist_main_program, dist_context):
def _save_state(program, path, mode="param"):
state = {
k: np.array(v)
for k, v in program.state_dict(mode).items()
}
with open(path, "wb") as f:
pickle.dump(state, f)
dirname, filename = _process_path(path)
rank_id = paddle.distributed.get_rank()
......@@ -76,16 +84,6 @@ class DistributedSaver:
with open(dist_model_path, "wb") as f:
f.write(dist_main_program.desc.serialize_to_string())
# save distributed params
dist_param_filename = filename + "_dist" + str(rank_id) + ".pdparams"
dist_param_path = os.path.join(dirname, dist_param_filename)
dist_param = {
k: np.array(v)
for k, v in dist_main_program.state_dict().items()
}
with open(dist_param_path, "wb") as f:
pickle.dump(dist_param, f)
# save distributed attribute
dist_attr_filename = filename + "_dist" + str(rank_id) + ".pdattr"
dist_attr_path = os.path.join(dirname, dist_attr_filename)
......@@ -93,65 +91,69 @@ class DistributedSaver:
with open(dist_attr_path, "wb") as f:
pickle.dump(dist_attrs, f)
# save distributed params
dist_param_filename = filename + "_dist" + str(rank_id) + ".pdparams"
dist_param_path = os.path.join(dirname, dist_param_filename)
_save_state(dist_main_program, dist_param_path)
# save distributed opt states
dist_opt_filename = filename + "_dist" + str(rank_id) + ".pdopt"
dist_opt_path = os.path.join(dirname, dist_opt_filename)
_save_state(dist_main_program, dist_opt_path, "opt")
# TODO:save cluster.json
def load(self,
path,
program,
dist_context,
strict=True,
load_optimizer=True):
def load(self, path, load_optimizer=True):
# TODO: if `program` is None, load `path.pdmodel`.
def _load_file(filename, dirname, suffix="pdparams"):
file_list = []
for file in os.listdir(dirname):
if check_filename('{}(.*)_dist(.*).{}'.format(filename, suffix),
file):
file_list.append(os.path.join(dirname, file))
file_list.sort()
return file_list
def _load_state(filename, dirname, suffix="pdparams"):
file_list = _load_file(filename, dirname, suffix)
state_dict = {}
for file in file_list:
with open(file, 'rb') as f:
state_dict_info = pickle.load(f, encoding='latin1')
for name, value in state_dict_info.items():
if name in state_dict:
state_dict[name].append(np.array(value))
else:
state_dict[name] = [np.array(value)]
self._logger.info("Load param file: {}".format(file_list))
return state_dict
filename = os.path.basename(path)
if filename == "":
raise ValueError(
"path should be of 'dirname/filename' format, but received filename is empty string"
)
dirname = os.path.dirname(path)
# load path.pdparam
param_file_list = []
for param_file in os.listdir(dirname):
if check_filename('{}(.*)_dist(.*).pdparams'.format(filename),
param_file):
param_file_list.append(os.path.join(dirname, param_file))
param_file_list.sort()
self._logger.info(
"Load distributed attribute file: {}".format(param_file_list))
param_dict = {}
for param_file in param_file_list:
with open(param_file, 'rb') as f:
state_dict_info = pickle.load(f, encoding='latin1')
for name, value in state_dict_info.items():
if name in param_dict:
param_dict[name].append(np.array(value))
else:
param_dict[name] = [np.array(value)]
# load path.pdparam and path.pdopt
param_state_dict = _load_state(filename, dirname)
opt_state_dict = _load_state(filename, dirname,
"pdopt") if load_optimizer else {}
state_dict = dict(param_state_dict, **opt_state_dict)
# load path.pdattr
dist_attr_file_list = []
for dist_attr_file in os.listdir(dirname):
if check_filename('{}(.*)_dist(.*).pdattr'.format(filename),
dist_attr_file):
dist_attr_file_list.append(os.path.join(dirname,
dist_attr_file))
dist_attr_file_list.sort()
dist_attr_file_list = _load_file(filename, dirname, "pdattr")
self._logger.info(
"Load distributed attribute file: {}".format(dist_attr_file_list))
pre_dist_attr = {}
dist_attr = {}
for dist_attr_file in dist_attr_file_list:
with open(dist_attr_file, 'rb') as f:
dist_attr = pickle.load(f, encoding='latin1')
for name, attr in dist_attr.items():
if name not in pre_dist_attr:
pre_dist_attr[name] = attr
# get current dist_attr
cur_dist_attr = get_dist_attr(program, dist_context)
# param convert
converter = Converter(param_dict, pre_dist_attr, cur_dist_attr)
param_dict = converter.convert(strict=strict)
program.set_state_dict(param_dict)
dist_attr_info = pickle.load(f, encoding='latin1')
for name, attr in dist_attr_info.items():
if name not in dist_attr:
dist_attr[name] = attr
return state_dict, dist_attr
def save_inference_model(self, path, feed_vars, fetch_vars, exe, **kwargs):
......
......@@ -19,13 +19,13 @@ import paddle
from paddle.nn import Layer
from paddle.jit import to_static, not_to_static
from paddle.distributed.utils import get_logger
from paddle.fluid.framework import Operator, Parameter, _non_static_mode
from paddle.fluid.framework import program_guard
from paddle.fluid.executor import global_scope
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction
from .utils import to_list
from .utils import get_logger
from .converter import Converter
......
......@@ -12,101 +12,199 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy
import copy
import paddle
import paddle.fluid.core as core
from paddle.fluid.framework import Variable
from paddle.fluid.framework import _non_static_mode
from paddle.fluid import core
from .process_mesh import ProcessMesh
from .process_mesh import get_current_process_mesh
from .process_mesh import set_current_process_mesh
from .process_mesh import reset_current_process_mesh
from .dist_context import get_default_distributed_context
from .dist_tensor import DistributedTensor
from .dist_op import DistributedModule
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from .dist_op import DistributedOperatorHelper
from .utils import verify_shard_spec, convert_to_dims_mapping
def _static_mode_check():
if _non_static_mode():
raise RuntimeError("Auto-parallel only supports static mode for now, "
"please use paddle.enable_static() first.")
def shard_tensor(x, dist_attr=None):
def shard_tensor(x, process_mesh=None, shard_spec=None):
"""
Add distributed attributes for a tensors.
Shard a tensor on a process mesh according to the shard specification.
Args:
x (Tensor): the tensor to be sharded.
dist_attr (dict): the tensor distributed attributes. The accepted attributes are as follow:
"process_mesh": a nested list an to describe the mesh topology of logical processes.
"dims_mapping": a list to describe the mapping between `x` and `process_mesh`, the dimension
`i` of `x` is split across the dimension `dims_mapping[i]` of `process_mesh`,
where -1 means that tensor dimension is not split.
Both process_mesh and dims_mapping are optional and users can specify as need.
process_mesh (ProcessMesh, optional): An instance of ProcessMesh describes a mesh
topology of the used logical processes where the tensor is sharded. If it is None,
the found current process mesh will be used. And an error will be raised if the
current process mesh cannot be found. Default: None.
shard_spec (list, optional): a list to describe the sharding mapping between `x` and `process_mesh`,
which means the dimension `i` of `x` is split across the dimension `shard_spec[i]` of `process_mesh`,
where `None` means that tensor dimension is not split. For example, given a tensor wih
the shape [6, 12] and a process mesh with the shape [2, 3] and the dimension names ["x", "y"]:
If `shard_spec=["x", "y"]`, each shard of the tensor will have a shape [3, 4];
If `shard_spec=["y", "x"]`, each shard of the tensor will have a shape [2, 6];
If `shard_spec=["x", None]`, each shard of the tensor will have a shape [3, 12];
If `shard_spec=[None, "x"]`, each shard of the tensor will have a shape [6, 4];
If `shard_spec=["y", None]`, each shard of the tensor will have a shape [2, 12];
If `shard_spec=[None, "y"]`, each shard of the tensor will have a shape [6, 4];
If `shard_spec=[None, None]`, each shard of the tensor will have a shape [6, 12];
If the `shard_spec` is None, the tensor will be replicated across all the processes of `process_mesh`.
In the above example, the `shard_spec=None` is same as 'shard_spec=[None, None]'. Defaults: None.
Returns:
Tensor: the tensor `x` annotated with distributed attributes.
Tensor: the tensor `x` annotated with sharding information.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
import paddle.distributed.auto_parallel as auto
mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
x = paddle.ones([4, 6])
dist.shard_tensor(x, dist_attr={"process_mesh": [[0, 1], [2, 3]],
"dims_mapping": [0, -1]})
shard_spec = ["x", "y"]
auto.shard_tensor(x, mesh, shard_spec)
"""
_static_mode_check()
assert dist_attr is None or isinstance(dist_attr, (dict, TensorDistributedAttribute)), \
"The type of dist_attr must be None, dict or TensorDistributedAttribute."
dist_tensor = DistributedTensor(x, dist_attr)
dist_tensor.dist_attr.mark_annotated_as(dist_attr)
if process_mesh is not None:
assert isinstance(process_mesh, ProcessMesh), \
"Argument process_mesh {} is not an instance of ProcessMesh".format(process_mesh)
else:
process_mesh = get_current_process_mesh()
assert process_mesh is not None, \
"Specify the process mesh argument or use ProcessMesh context manager first."
assert isinstance(shard_spec, list), \
"Argument shard_spec {} is not an instance of list".format(shard_spec)
dist_tensor = DistributedTensor(x)
serial_tensor = dist_tensor.serial_tensor
dist_tensor.dist_attr.process_mesh = process_mesh
if serial_tensor.type == core.VarDesc.VarType.READER \
or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = []
else:
tensor_shape = serial_tensor.shape
if shard_spec is not None:
assert verify_shard_spec(shard_spec, tensor_shape, process_mesh), \
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
serial_tensor.name, shard_spec, tensor_shape, process_mesh)
dist_tensor.dist_attr.dims_mapping = convert_to_dims_mapping(
shard_spec, process_mesh)
if process_mesh is not None:
dist_tensor.dist_attr.mark_annotated("process_mesh")
if shard_spec is not None:
dist_tensor.dist_attr.mark_annotated("dims_mapping")
default_dist_ctx = get_default_distributed_context()
default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
dist_tensor = default_dist_ctx.get_dist_tensor_for_program(x)
return x
def shard_op(op_fn, dist_attr=None):
def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None):
"""
Call a functioin and add distributed attributes for ops added by the function.
Shard an operation on a process mesh according to its input and output shard specification.
Args:
op_fn (callable): a callable operator or module to be sharded.
dist_attr (dict): the operator distributed attributes. The accepted attributes are classified into
two categories. The first category decsribes the distributed attributes shared by all inputs and
outputs, and only `process_mesh` can be specified now. The second category describes distributed
attributes for inputs or outputs same as the `dist_attr` of `shard_tensor`. All of them are
optional and users can specify them as need. Note that `process_mesh` for operators must be the
same as these process_meshes for inputs and outputs.
op (Callable): a callable operator or module to be sharded.
process_mesh (ProcessMesh, optional): An instance of ProcessMesh describes a mesh
topology of the used logical processes where the op is sharded. All of its inputs and
outputs are sharded by this process mesh. If it is None, the found current process mesh
will be used. And an error will be raised if the current process mesh cannot be found.
Default: None.
in_shard_specs (list of list, optional): a list of list to describe the sharding specifications
for the inputs. Each item of `in_shard_specs` is a `shard_spec` between the correspoinding input
and `process_mesh`. If one item is None, the cooresponding input is replicated across all processes
If it is None, all inputs are replicated accross all processes. Note that the lenght of the
`in_shard_specs` should be equal to the actual number of inputs when calling this operation.
Default: None.
out_shard_specs (list of list, optional): a list of list to describe the sharding specifications
for the outputs. Each item of `out_shard_specs` is a `shard_spec` between the correspoinding output
and `process_mesh`. If one item is None, the cooresponding output is replicated across all processes
If it is None, all outputs are replicated accross all processes. Note that the lenght of the
`in_shard_specs` should be equal to the actual number of inputs when calling this operation.
Default: None. Default: None.
Returns:
list: the outputs of the function `op_fn`, which are annotated with distributed attributes.
Outputs of `op`, each of which is annotated with sharding information.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
import paddle.distributed.auto_parallel as auto
x = paddle.ones([4, 6])
y = paddle.zeros([4, 6])
dist_add = dist.shard_op(paddle.add,
dist_attr={
"process_mesh": [[2, 3, 1], [0, 4, 5]],
x: {"dims_mapping": [-1, 0]},
y: {"dims_mapping": [0, -1]}
})
mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
dist_add = auto.shard_op(paddle.add,
in_shard_specs=[["x", "y"], ["y", None]],
out_shard_specs=[[None, "x"]])
dist_add(x, y)
"""
_static_mode_check()
assert dist_attr is None or isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \
"The type of dist_attr must be dict or OperatorDistributedAttribute."
dist_module = DistributedModule(op_fn, dist_attr)
return dist_module
if process_mesh is not None:
assert isinstance(process_mesh, ProcessMesh), \
"Argument process_mesh {} is not an instance of ProcessMesh".format(process_mesh)
else:
process_mesh = get_current_process_mesh()
assert process_mesh is not None, \
"Specify the process mesh argument or use ProcessMesh context manager first."
in_dims_mappings = []
if in_shard_specs is not None:
assert all((isinstance(shard_spec, list) or shard_spec is None) for shard_spec in in_shard_specs), \
"in_shard_spec {} is not a list of list or None".format(in_shard_specs)
for shard_spec in in_shard_specs:
if shard_spec is not None:
in_dims_mappings.append(
convert_to_dims_mapping(shard_spec, process_mesh))
else:
in_dims_mappings.append(None)
out_dims_mappings = []
if out_shard_specs is not None:
assert all((isinstance(shard_spec, list) or shard_spec is None) for shard_spec in out_shard_specs), \
"out_shard_spec {} is not a list of list or None".format(out_shard_specs)
for shard_spec in out_shard_specs:
if shard_spec is not None:
out_dims_mappings.append(
convert_to_dims_mapping(shard_spec, process_mesh))
else:
out_dims_mappings.append(None)
op = DistributedOperatorHelper(op, process_mesh, in_dims_mappings,
out_dims_mappings)
return op
def recompute(op):
class RecomputeOperator:
def __init__(self, op):
self._op = op
def __call__(self, *args, **kwargs):
default_prog = paddle.fluid.default_main_program()
cur_block = default_prog.current_block()
op_size = len(cur_block.ops)
output = self._op(*args, **kwargs)
new_op_size = len(cur_block.ops)
for idx in range(op_size, new_op_size):
op = cur_block.ops[idx]
op._set_attr("is_recompute@auto_parallel", True)
return output
return RecomputeOperator(op)
_g_fetched_tensors = {}
def fetch(tensor, name=None):
if name is None:
_g_fetched_tensors[tensor.name] = tensor
else:
_g_fetched_tensors[name] = tensor
def _get_fetches():
return _g_fetched_tensors
......@@ -42,6 +42,7 @@ from .utils import make_data_unshard
from .utils import set_grad_var_shape
from .utils import print_program_with_dist_attr
from .utils import SerialProgramInfo
from .utils import get_logger
from .reshard import Resharder
from .cluster import Cluster
from .mapper import mapping
......@@ -84,7 +85,7 @@ class AutoParallelizer:
self._need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING")
self._need_rank_mapping = True if self._need_rank_mapping and \
self._need_rank_mapping.lower() == 'true' else False
self._pass_context = None
# self._pass_context = None
def _remove_distributed_attrs(self, main_program):
suffix = core.kAutoParallelSuffix()
......@@ -143,10 +144,11 @@ class AutoParallelizer:
def _apply_optimize(self, main_program, startup_program, params_grads):
optimizer = copy.deepcopy(self._optimizer)
with program_guard(main_program, startup_program):
optimize_ops = copy.deepcopy(
self._optimizer).apply_gradients(params_grads)
optimize_ops = optimizer.apply_gradients(params_grads)
self._dist_context._lr_optimizer = optimizer
# update completion
self._completer = Completer(self._dist_context)
self._completer.complete_update_annotation(main_program)
......@@ -165,6 +167,15 @@ class AutoParallelizer:
config)
auto_parallel_sharding_pass.apply([main_program], [startup_program],
self._pass_context)
params_grads = self._pass_context.get_attr("params_grads")
config = copy.deepcopy(self._dist_strategy.sharding_configs)
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["rank_id"] = rank
auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip", config)
auto_parallel_clip_pass.apply([main_program], [startup_program],
self._pass_context)
if self._dist_strategy.gradient_merge:
config = copy.deepcopy(self._dist_strategy.gradient_merge_configs)
......
......@@ -22,7 +22,6 @@ from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import _non_static_mode, unique_name
from paddle.distributed.passes import new_pass
from paddle.distributed.utils import get_logger
from .reshard import Resharder
from .partitioner import Partitioner
......@@ -31,6 +30,7 @@ from .dist_saver import DistributedSaver
from .dist_loader import NonIterableGeneratorLoader
from .utils import make_data_unshard, set_grad_var_shape
from .utils import print_program_with_dist_attr, to_list
from .utils import get_logger
from .process_group import get_all_process_groups, get_world_process_group
from .dist_context import DistributedContext, get_default_distributed_context
......@@ -160,8 +160,8 @@ class Parallelizer:
# apply quantization pass
# The pass can be applied when mode must be 'train'
if self._mode == 'train' and self._strategy.qat:
config = copy.deepcopy(self._strategy.qat_configs)
if self._mode == 'train' and self._strategy.qat.enable:
config = copy.deepcopy(self._strategy.qat.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
auto_parallel_quantization_pass = new_pass(
......@@ -176,8 +176,8 @@ class Parallelizer:
# apply amp pass
# FIXME we disenable amp for eval since it has a little bug with
# eval program and which will be fixed in future
if self._mode == 'train' and self._strategy.amp:
config = copy.deepcopy(self._strategy.amp_configs)
if self._mode == 'train' and self._strategy.amp.enable:
config = copy.deepcopy(self._strategy.amp.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["loss"] = loss
......@@ -195,8 +195,8 @@ class Parallelizer:
# apply recompute pass
# recompute is then train-only optimization
if self._mode == "train" and self._strategy.recompute:
config = copy.deepcopy(self._strategy.recompute_configs)
if self._mode == "train" and self._strategy.recompute.enable:
config = copy.deepcopy(self._strategy.recompute.to_dict())
config["dist_context"] = self._dist_context
config["no_grad_set"] = None
config["loss"] = loss
......@@ -217,12 +217,12 @@ class Parallelizer:
config = {}
config["dist_context"] = self._dist_context
config["global_rank"] = rank
config["use_sharding"] = self._strategy.sharding
config["use_sharding"] = self._strategy.sharding.enable
dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
dp_pass.apply([main_program], [startup_program], self._pass_context)
if self._strategy.sharding:
config = copy.deepcopy(self._strategy.sharding_configs)
if self._strategy.sharding.enable:
config = copy.deepcopy(self._strategy.sharding.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["global_rank"] = rank
......@@ -230,11 +230,11 @@ class Parallelizer:
config)
auto_parallel_sharding_pass.apply([main_program], [startup_program],
self._pass_context)
params_grads = self._pass_context.get_attr("params_grads")
# GradClip is train-only optimization
if self._mode == "train":
config = copy.deepcopy(self._strategy.sharding_configs)
config = copy.deepcopy(self._strategy.sharding.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["rank_id"] = rank
......@@ -244,8 +244,8 @@ class Parallelizer:
self._pass_context)
# gradient_merge is then train-only optimization
if self._mode == "train" and self._strategy.gradient_merge:
config = copy.deepcopy(self._strategy.gradient_merge_configs)
if self._mode == "train" and self._strategy.gradient_merge.enable:
config = copy.deepcopy(self._strategy.gradient_merge.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
auto_parallel_gradient_merge_pass = new_pass(
......
......@@ -12,86 +12,90 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy
import numpy as np
import copy
import paddle
# Use to store the previous and current process mesh
_g_previous_process_mesh = None
_g_current_process_mesh = None
def _get_nested_list_shape(nested_list):
"""
Get the shape of a nested_list.
"""
result = []
while isinstance(nested_list, list):
result.append(len(nested_list))
nested_list = nested_list[0]
return result
def get_current_process_mesh():
global _g_current_process_mesh
return _g_current_process_mesh
def _flatten_nested_list(nested_list):
"""
Get a list of all items in a nested_list.
Ref: https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-a-list-of-lists
"""
result = numpy.array(nested_list).flatten().tolist()
return result
def set_current_process_mesh(process_mesh):
global _g_previous_process_mesh
global _g_current_process_mesh
_g_previous_process_mesh = _g_current_process_mesh
_g_current_process_mesh = process_mesh
class ProcessMesh(object):
r"""
The class `Processmesh` describes the topology of logical processes.
A mesh is an N-dimensional array. The shape of the N-dimensional
array represents the topology of logical processes and every
element of the N-dimensional array represent a logical process. For
example, the 2-dimensional array [[2, 4, 5], [0, 1, 3]]
illustrates six logical processes organized as the topology [2, 3],
i.e., the shape of the 2-dimensional array. With the above topology,
there are two parallel groups, where the first parallel group has a
parallel degree of 2 and the second one has a parallel degree of 3.
And the first logical process is the one with id=2.
Args:
mesh (list): an N-dimensional array (nested list) describes the toplogy
of logical processes. The shape of the N-dimensional array
represents the topology of logical processes and every
element of the N-dimensional array represents a logical process.
def reset_current_process_mesh():
global _g_previous_process_mesh
global _g_current_process_mesh
_g_current_process_mesh = _g_previous_process_mesh
Returns:
None
class ProcessMesh(object):
"""
The `Processmesh` object describes the topology of the used processes.
Raises:
ValueError: If `mesh` is not an instance of list.
Args:
mesh (list|numpy.array): an n-dimensional array describes the toplogy
of the processes.
dim_names (list, optional): the i-th element of this list gives the name of the
i-th dimension of the mesh.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
assert mesh.topology == [2, 3]
assert mesh.processes == [2, 4, 5, 0, 1, 3]
mesh = auto.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
assert mesh.shape == [2, 3]
assert mesh.processe_ids == [2, 4, 5, 0, 1, 3]
"""
def __init__(self, mesh):
if mesh is None or not isinstance(mesh, list):
raise ValueError('mesh must be an instance of list.')
processes = _flatten_nested_list(mesh)
assert all(isinstance(p, int) for p in processes), \
("All elements of mesh must be integer")
assert min(processes) >= 0, ('All elements of mesh must be >= 0.')
unique_processes = set(processes)
assert len(unique_processes) == len(processes), (
'All elements of mesh must be unique.')
self._topology = _get_nested_list_shape(mesh)
self._processes = processes
def __init__(self, mesh=None, dim_names=None, shape=None, process_ids=None):
# Use shape and process_ids just for compatibility
# Users should not use these directly
if mesh is None:
assert shape is not None
assert process_ids is not None
mesh = np.array(process_ids).reshape(shape)
if not isinstance(mesh, list) and \
not isinstance(mesh, np.ndarray):
raise ValueError(
'The mesh must be an instance of list or np.ndarray.')
if isinstance(mesh, list):
mesh = np.array(mesh)
self._mesh = mesh
self._shape = list(self._mesh.shape)
self._process_ids = self._mesh.flatten().tolist()
assert all(isinstance(p, int) for p in self._process_ids), \
("All elements of the mesh must be integer")
assert min(
self._process_ids) >= 0, ('All elements of the mesh must be >= 0.')
unique_process_ids = set(self._process_ids)
assert len(unique_process_ids) == len(
self._process_ids), ('All elements of the mesh must be unique.')
if dim_names is not None:
assert len(dim_names) == len(self._shape), \
("The length of dims_names must be same as the shape of the mesh.")
self._dim_names = copy.deepcopy(dim_names)
else:
self._dim_names = ["d" + str(i) for i in range(len(self._shape))]
unique_dim_names = set(self._dim_names)
assert len(unique_dim_names) == len(self._dim_names), (
'All dim_names {} must be unique.'.format(dim_names))
# Store all process meshes
from .dist_context import get_default_distributed_context
......@@ -103,31 +107,117 @@ class ProcessMesh(object):
pg0.add_ranks(self.processes)
@property
def topology(self):
r"""
Get the topology of logical processes belonging to this ProcessMesh.
This is the shape of `mesh` used to initialized this ProcessMesh.
def shape(self):
"""
Get the shape of this ProcessMesh.
"""
return self._topology
return self._shape
@property
def processes(self):
r"""
Get a list of all processes belonging to this ProcessMesh.
def process_ids(self):
"""
Get the process ids belonging to this ProcessMesh.
"""
return self._processes
return self._process_ids
@property
def dim_names(self):
"""
Get the dimension names of this ProcessMesh.
"""
return self._dim_names
@property
def ndim(self):
r"""
Get the number of dimension of ProcessMesh.
"""
return len(self._topology)
Get the number of dimension of this ProcessMesh.
"""
return len(self._shape)
@property
def mesh(self):
"""
Get the underlying mesh of ProcessMesh.
"""
return self._mesh
@property
def topology(self):
return self._shape
@property
def processes(self):
return self._process_ids
def __getitem__(self, index):
if isinstance(index, tuple):
new_dim_names = []
for i, item in enumerate(index):
if isinstance(item, slice):
new_dim_names.append(self._dim_names[i])
new_mesh = self._mesh[index]
if new_mesh.shape:
return ProcessMesh(new_mesh, new_dim_names)
else:
# Wrap a scalar into a list but without dim_names
return ProcessMesh([new_mesh])
elif isinstance(index, slice):
new_mesh = self._mesh[index]
new_dim_names = self._dim_names
return ProcessMesh(new_mesh, new_dim_names)
else:
new_mesh = self._mesh[index]
new_dim_names = self._dim_names[1:]
return ProcessMesh(new_mesh, new_dim_names)
def __enter__(self):
set_current_process_mesh(self)
default_prog = paddle.fluid.default_main_program()
cur_block = default_prog.current_block()
self._old_var_names = list(cur_block.vars.keys())
self._old_op_size = len(cur_block.ops)
def __exit__(self, exc_type, exc_value, exc_traceback):
from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator
default_prog = paddle.fluid.default_main_program()
cur_block = default_prog.current_block()
new_var_names = list(cur_block.vars.keys())
new_op_size = len(cur_block.ops)
from .dist_context import get_default_distributed_context
default_dist_ctx = get_default_distributed_context()
for name in new_var_names:
if name not in self._old_var_names:
tensor = cur_block.vars[name]
dist_tensor = default_dist_ctx.get_dist_tensor_for_program(
tensor)
if dist_tensor is None:
dist_tensor = DistributedTensor(cur_block.vars[name],
{"process_mesh": self})
dist_tensor.dist_attr.mark_annotated("process_mesh")
default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
else:
if dist_tensor.dist_attr.process_mesh is None:
dist_tensor.dist_attr.process_mesh = self
dist_tensor.dist_attr.mark_annotated("process_mesh")
for idx in range(self._old_op_size, new_op_size):
op = cur_block.ops[idx]
dist_op = default_dist_ctx.get_dist_op_for_program(op)
if dist_op is None:
dist_op = DistributedOperator(op, {"process_mesh": self})
dist_op.dist_attr.mark_annotated("process_mesh")
default_dist_ctx.add_dist_op_for_program(dist_op)
else:
if dist_op.dist_attr.process_mesh is None:
dist_op.dist_attr.process_mesh = self
dist_op.dist_attr.mark_annotated("process_mesh")
reset_current_process_mesh()
def __eq__(self, other):
if not isinstance(other, ProcessMesh):
return False
if self.topology != other.topology or self.processes != other.processes:
if self.shape != other.shape or self.process_ids != other.process_ids:
return False
return True
......@@ -135,6 +225,6 @@ class ProcessMesh(object):
return not self.__eq__(other)
def __str__(self):
str = "shape {} and process group {}".format(self.topology,
self.processes)
str = "shape {}, process_ids {}, dim_nams {}".format(
self.shape, self.process_ids, self.dim_names)
return str
......@@ -81,54 +81,57 @@ class ProcessMesh(core.ProcessMesh):
return self._mesh
# def compute_compatible_process_meshes(process_meshes):
# """Compute the compatible process mesh given a list of process meshes."""
# if not process_meshes:
# return None
# def _compute_compatible_two_process_meshes(pm1, pm2):
# if pm1 is None:
# return True, pm2
# if pm2 is None:
# return True, pm1
# if pm1 == pm2:
# return True, pm1
# if pm1.device_mesh != pm2.device_mesh:
# return False, None
# if pm1.process_ids == pm2.process_ids:
# if len(pm1.shape) >= len(pm2.shape):
# return True, pm1
# else:
# return True, pm2
# process_set1 = set(pm1.process_ids)
# process_set2 = set(pm2.process_ids)
# if process_set1.issubset(process_set2):
# return True, pm2
# if process_set2.issubset(process_set1):
# return True, pm1
# return False, None
# compatible_result = None
# for process_mesh in process_meshes:
# compatible, compatible_result = _compute_compatible_two_process_meshes(
# compatible_result, process_mesh)
# if not compatible:
# return None
# return ProcessMesh(compatible_result.mesh, compatible_result.dim_names)
# def merge_process_meshes(process_meshes):
# """Merge a list of process meshes."""
# merged_process_mesh = None
# merged_process_ids = set()
# device_type = ""
# for process_mesh in process_meshes:
# if process_mesh is not None:
# process_ids = set(process_mesh.process_ids)
# if not device_type:
# device_type = process_mesh.device_type
# assert device_type != process_mesh.device_type, \
# "All process meshes must have the same device_type."
# merged_process_ids.union(process_ids)
# if len(merged_process_ids) != 0:
# merged_process_mesh = ProcessMesh(list(merged_process_ids))
# return merged_process_mesh
def compute_compatible_process_mesh(process_meshes):
"""Compute the compatible process mesh given a list of process meshes."""
if not process_meshes:
return None
def _compute_compatible_of_two_process_meshes(pm1, pm2):
if pm1 is None:
return True, pm2
if pm2 is None:
return True, pm1
if pm1 == pm2:
return True, pm1
if pm1.process_ids == pm2.process_ids:
if len(pm1.shape) >= len(pm2.shape):
return True, pm1
else:
return True, pm2
process_set1 = set(pm1.process_ids)
process_set2 = set(pm2.process_ids)
if process_set1.issubset(process_set2):
return True, pm2
if process_set2.issubset(process_set1):
return True, pm1
return False, None
compatible_result = None
for process_mesh in process_meshes:
compatible, compatible_result = _compute_compatible_of_two_process_meshes(
compatible_result, process_mesh)
if not compatible:
return None
if compatible_result.empty():
return None
if isinstance(compatible_result, core.ProcessMesh):
mesh = np.array(compatible_result.process_ids).reshape(
compatible_result.shape)
return ProcessMesh(mesh, compatible_result.dim_names)
elif isinstance(compatible_result, ProcessMesh):
return ProcessMesh(compatible_result.mesh, compatible_result.dim_names)
else:
raise ValueError("Unrecognized ProcessMesh.")
def merge_process_mesh(process_meshes):
"""Merge a list of process meshes."""
merged_process_mesh = None
merged_process_ids = set()
for process_mesh in process_meshes:
if process_mesh is not None:
process_ids = set(process_mesh.process_ids)
merged_process_ids = merged_process_ids.union(process_ids)
if len(merged_process_ids) != 0:
merged_process_mesh = ProcessMesh(list(merged_process_ids))
return merged_process_mesh
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import os
import copy
import argparse
from . import constants
class BaseConfig(object):
def __init__(self, category, config_dict=None):
self._category = category
self._config_dict = None
if config_dict is not None:
if isinstance(config_dict, dict):
self._config_dict = config_dict
else:
raise ValueError(
"Expected a dictionary. But received: {}".format(
config_dict))
# Initialize attributes by the default config
config = constants.get_category_default_config(self._category)
for field, default_value in config.items():
setattr(self, field, default_value)
# Overide attributes by the config_dict
if self._config_dict:
self.from_dict(self._config_dict)
def from_dict(self, config_dict):
config = constants.get_category_default_config(self._category)
for field in config.keys():
value = config_dict.get(field, constants.NOT_FOUND)
# Use the default value if we cannot found the value
if value != constants.NOT_FOUND:
setattr(self, field, value)
def to_dict(self):
result_dict = {}
config = constants.get_category_default_config(self._category)
for field in config.keys():
value = getattr(self, field)
result_dict[field] = value
for field, value in self.__dict__.items():
if isinstance(value, BaseConfig):
result_dict[field] = value.to_dict()
return result_dict
def __repr__(self):
return yaml.dump(self.to_dict(),
default_flow_style=False,
sort_keys=True,
indent=4)
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
setattr(result, k, copy.deepcopy(v, memo))
return result
class RecomputeConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.RECOMPUTE
super(RecomputeConfig, self).__init__(category, config_dict)
class AMPConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.AMP
super(AMPConfig, self).__init__(category, config_dict)
class ShardingConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.SHARDING
super(ShardingConfig, self).__init__(category, config_dict)
class GradientMergeConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.GRADIENT_MERGE
super(GradientMergeConfig, self).__init__(category, config_dict)
class QATConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.QAT
super(QATConfig, self).__init__(category, config_dict)
class TuningConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.TUNING
super(TuningConfig, self).__init__(category, config_dict)
class Strategy(BaseConfig):
"""
The `Strategy` object is used to configure the paralleization and optimization beheviors.
Args:
config (dict|string, optional): If this is None, the default configurations will used.
If this is a dictionary, the recognized key-value of it will be used to override the default
configurations while other default configurations are left unchanged. If this is a string,
it is interpreted as the path to a YAML configuration and will be loaded to override the
corresponding default configurations.
Examples:
.. code-block:: python
import paddle
import paddle.distributed.auto_parallel as auto
strategy = auto.Strategy()
sharding = strategy.sharding
self.assertEqual(sharding.enabled, False)
self.assertEqual(sharding.stage, 1)
self.assertEqual(sharding.degree, 8)
sharding.enabled = True
sharding.stage = 2
sharding.degree = 2
self.assertEqual(sharding.enabled, True)
self.assertEqual(sharding.stage, 2)
self.assertEqual(sharding.degree, 2)
"""
def __init__(self, config=None):
if config is not None:
if isinstance(config, dict):
self._config_dict = copy.deepcopy(config)
# elif os.path.exists(config):
# with open(config, "rb") as yaml_file:
# self._config_dict = yaml.load(yaml_file, Loader=yaml.Loader)
else:
raise ValueError(
"Expected a dictionary. But received: {}".format(config))
else:
self._config_dict = {}
category = constants.BASE
super(Strategy, self).__init__(category, self._config_dict)
config_dict = self._config_dict.get(constants.RECOMPUTE, None)
self.recompute = RecomputeConfig(config_dict)
config_dict = self._config_dict.get(constants.AMP, None)
self.amp = AMPConfig(config_dict)
config_dict = self._config_dict.get(constants.SHARDING, None)
self.sharding = ShardingConfig(config_dict)
config_dict = self._config_dict.get(constants.GRADIENT_MERGE, None)
self.gradient_merge = GradientMergeConfig(config_dict)
config_dict = self._config_dict.get(constants.QAT, None)
self.qat = QATConfig(config_dict)
config_dict = self._config_dict.get(constants.TUNING, None)
self.tuning = TuningConfig(config_dict)
......@@ -16,7 +16,7 @@ import copy
from abc import ABC, abstractmethod
import logging
from paddle.distributed.utils import get_logger
from ..utils import get_logger
from .trial import TrialStatus
from .trial import OptimizationTunerTrial as Trial
......@@ -110,13 +110,13 @@ class ShardingStageAlgorithm(AlgorithmBase):
# TODO import trial class & copy strategy
def __init__(self, config):
super().__init__(config)
self._changed_configs = ["sharding_configs"]
self._changed_configs = ["sharding"]
def _init_spaces(self):
self._max_stage = 3
self._trial_idx = 0
stage_range = self._config.sharding_configs.get("stage_range", None)
stage_range = self._config.sharding.to_dict().get("tuning_range", None)
if stage_range:
assert set(stage_range).issubset(
set([0, 1, 2, 3])
......@@ -136,9 +136,8 @@ class ShardingStageAlgorithm(AlgorithmBase):
stage = self._stage_range[self._trial_idx]
new_strategy = copy.deepcopy(self._config.dist_strategy)
config_dict = new_strategy.sharding_configs
config_dict["stage"] = stage
new_strategy.sharding_configs = config_dict
sharding = new_strategy.sharding
sharding.stage = stage
name = "trial-sharding-stage{}".format(stage)
trial = Trial(new_strategy, name, self.changed_configs)
......
......@@ -17,15 +17,13 @@ import copy
import pathlib
import paddle
from paddle.distributed import fleet
from ..strategy import Strategy
_tuning_supported_passes = ["sharding", "recompute"]
_strategy_config_suffiex = "_configs"
def _get_pass_config(strategy, pass_name):
config_name = pass_name + _strategy_config_suffiex
config = getattr(strategy, config_name)
config = getattr(strategy, pass_name)
return config
......@@ -38,10 +36,8 @@ class TuningConfig(object):
def __init__(self, user_config, strategy):
if not isinstance(strategy, fleet.DistributedStrategy):
raise TypeError(
"'strategy' must be object of class `fleet.DistributedStrategy`."
)
if not isinstance(strategy, Strategy):
raise TypeError("'strategy' must be object of class `Strategy`.")
if not user_config:
user_config = {}
......@@ -116,11 +112,11 @@ class TuningConfig(object):
for p in _tuning_supported_passes:
if getattr(self._dist_strategy, p) and _get_pass_config(
self._dist_strategy, p)["enable_tuning"]:
self._dist_strategy, p).enable_tuning:
# TODO distinguish different args of each passes
self._tuning_passes_name.add(p)
config_name = p + _strategy_config_suffiex
config_name = p
p_dict = getattr(self._dist_strategy, config_name)
self.__dict__[config_name] = p_dict
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# import yaml
import os
import sys
import copy
......@@ -29,7 +30,6 @@ import paddle
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.distributed.passes import new_pass, PassContext
from paddle.distributed.utils import get_logger
from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context
from paddle.distributed.auto_parallel.completion import Completer
......@@ -39,6 +39,7 @@ from paddle.distributed.auto_parallel.process_group import clear_all_process_gro
from paddle.distributed.auto_parallel.utils import debug_program
from paddle.distributed.auto_parallel.utils import make_data_unshard, set_grad_var_shape
from ..utils import get_logger
from .config import TuningConfig
from .algorithms import new_algorithm
from .trial import TrialStatus
......@@ -256,8 +257,8 @@ class OptimizationTuner:
startup_program = dist_context.serial_startup_program
# applying optimization pass
if new_strategy.amp:
config = copy.deepcopy(new_strategy.amp_configs)
if new_strategy.amp.enable:
config = copy.deepcopy(new_strategy.amp.to_dict())
config["dist_context"] = dist_context
config["params_grads"] = dist_context._params_grads
......@@ -275,8 +276,8 @@ class OptimizationTuner:
auto_parallel_amp_pass.apply([main_program], [startup_program],
pass_context)
if new_strategy.recompute:
config = copy.deepcopy(new_strategy.recompute_configs)
if new_strategy.recompute.enable:
config = copy.deepcopy(new_strategy.recompute.to_dict())
config["dist_context"] = dist_context
config["no_grad_set"] = None
config["loss"] = dist_context.serial_loss
......@@ -303,8 +304,8 @@ class OptimizationTuner:
dist_context, dist_params_grads)
resharder.reshard()
if new_strategy.sharding:
config = copy.deepcopy(new_strategy.sharding_configs)
if new_strategy.sharding.enable:
config = copy.deepcopy(new_strategy.sharding.to_dict())
config["dist_context"] = dist_context
config["params_grads"] = dist_params_grads
config["global_rank"] = self.rank
......@@ -313,8 +314,8 @@ class OptimizationTuner:
auto_parallel_sharding_pass.apply([dist_main_prog],
[dist_startup_prog], pass_context)
if new_strategy.gradient_merge:
config = copy.deepcopy(new_strategy.gradient_merge_configs)
if new_strategy.gradient_merge.enable:
config = copy.deepcopy(new_strategy.gradient_merge.to_dict())
config["dist_context"] = dist_context
config["params_grads"] = dist_params_grads
auto_parallel_gradient_merge_pass = new_pass(
......@@ -492,9 +493,10 @@ The best trial is: [{}], whose configuration is following:
for line in summary_.split("\n"):
fw.write(line + "\n")
full_strategy = self.get_best_config()
full_strategy.save_to_prototxt(
os.path.join(self.project_dir, "tuned_dist_strategy.prototxt"))
# full_strategy = self.get_best_config()
# path = os.path.join(self.project_dir, "tuned_dist_strategy.yaml")
# with open(path, 'w') as outfile:
# yaml.dump(full_strategy, outfile, default_flow_style=False)
def clear(self):
"""
......
......@@ -156,9 +156,10 @@ class OptimizationTunerTrial(Trial):
draws += h1_format.format("{} auto=True <-> {}".format(name, name))
draws += line + "\n"
my_configs = getattr(self.space, name)
keys = my_configs.keys()
keys = my_configs.to_dict().keys()
for key in keys:
draws += h2_format.format(key, str(my_configs.get(key, None)))
draws += h2_format.format(
key, str(my_configs.to_dict().get(key, None)))
result_res = draws + border
return result_res
......
......@@ -28,6 +28,19 @@ from paddle.fluid.io import is_parameter, is_belong_to_optimizer
from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute
def get_logger(log_level, name="auto_parallel"):
logger = logging.getLogger(name)
logger.propagate = False
if not logger.handlers:
logger.setLevel(log_level)
log_handler = logging.StreamHandler()
log_format = logging.Formatter(
'%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')
log_handler.setFormatter(log_format)
logger.addHandler(log_handler)
return logger
def is_valid_list_index(list, index):
if index >= -len(list) and index < len(list):
return True
......@@ -49,6 +62,58 @@ def is_dim_replicate(mapping):
return False
def verify_dims_mapping(dims_mapping, process_mesh):
if dims_mapping is None:
return False
if not all(isinstance(d, int) for d in dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(process_mesh.shape):
return False
for i in range(len(process_mesh.shape)):
if dims_mapping.count(i) > 1:
return False
return True
def convert_to_dims_mapping(shard_spec, process_mesh):
dims_mapping = []
for shard in shard_spec:
if shard is None:
dims_mapping.append(-1)
else:
dims_mapping.append(process_mesh.dim_names.index(shard))
return dims_mapping
def convert_to_shard_spec(dims_mapping, process_mesh):
shard_spec = []
for dim_mapping in dims_mapping:
if dim_mapping == -1:
shard_spec.append(None)
else:
shard_spec.append(process_mesh.dim_names[dim_mapping])
return shard_spec
def verify_shard_spec(shard_spec, tensor_shape, process_mesh):
if len(shard_spec) != len(tensor_shape):
return False
for shard in shard_spec:
if shard is not None and not isinstance(shard, str):
return False
if shard is not None and shard not in process_mesh.dim_names:
return False
dims_mapping = convert_to_dims_mapping(shard_spec, process_mesh)
if not verify_dims_mapping(dims_mapping, process_mesh):
return False
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and tensor_shape[i] % process_mesh.shape[dims_mapping[i]] != 0:
return False
return True
def compute_compatible_dim_mapping(dim_mappings):
if not dim_mappings:
return None
......@@ -1040,7 +1105,7 @@ def set_grad_var_shape(program, dist_context):
if op.type in [
"c_allreduce_sum", "c_identity", "scale", "cast",
'fill_any_like'
"fill_any_like"
]:
forward_var_name = op.input_arg_names[0]
elif op.type == "matmul_v2_grad" or op.type == "matmul_grad" or op.type == "mul_grad":
......@@ -1504,3 +1569,15 @@ def ring_id_to_process_group(ring_id):
if g.id == ring_id:
return g
return None
def find_higher_order_backward_op(program):
higher_order_op_suffix = ['_grad_grad', 'triple_grad']
for block in program.blocks:
for op in block.ops:
for suffix in higher_order_op_suffix:
if suffix in op.type:
return True
return False
......@@ -314,7 +314,9 @@ class AMPState(object):
consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr)
else:
assert in_var.dtype == dst_dtype
assert in_var.dtype == dst_dtype, "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format(
grad_op.type, in_name, dst_dtype, in_var.dtype,
str(grad_op))
for out_name in grad_op.output_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
......
......@@ -13,12 +13,14 @@
# limitations under the License.
from collections import OrderedDict
import numpy as np
import paddle
from paddle.fluid import core, unique_name
from paddle.fluid.framework import default_main_program
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op
from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, ring_id_to_process_group
from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, is_backward_op, ring_id_to_process_group, find_higher_order_backward_op
from .pass_base import PassBase, PassType, register_pass
# add new optimizers supporting rescale_grad here
......@@ -31,6 +33,10 @@ __rescale_grad_supported_opts__ = [
__max_stream_num_allow__ = 16
def numel(var):
return np.prod(list(var.shape))
@register_pass("auto_parallel_data_parallel_optimization")
class DataParallelOptimizationPass(PassBase):
"""
......@@ -78,7 +84,9 @@ class DataParallelOptimizationPass(PassBase):
self._analyze_program()
self._prune_grad_scaling()
self._calc_comm_overlap()
self._fuse_allreduce()
grad_group = self._fuse_allreduce()
# self.summary(grad_group)
def _prune_grad_scaling(self):
......@@ -99,7 +107,14 @@ class DataParallelOptimizationPass(PassBase):
self._calc_wait_comms()
def _fuse_allreduce(self):
pass
if not self._could_be_fuse():
return []
grad_group = self._group_grads()
self._update_program(grad_group)
return grad_group
def _analyze_program(self):
"""
......@@ -154,7 +169,7 @@ class DataParallelOptimizationPass(PassBase):
def _could_be_prune(self):
return self.dist_context._gradient_scale and (
return self.dist_context.gradient_scale and (
self._support_rescale_grad or self._all_dp_groups_same_degree())
def _all_dp_groups_same_degree(self):
......@@ -316,3 +331,252 @@ class DataParallelOptimizationPass(PassBase):
'op_role': OpRole.Backward,
'ring_id': ring_id
})
def _could_be_fuse(self):
# TODO support gradient fuse higher order gradient.
# should analyse the dependencies of gradient in backward.
if find_higher_order_backward_op(default_main_program()):
return False
if self.use_sharding:
return False
return True
def _group_grads(self):
"""
conditions for gradients to be grouped:
1. group size < max_fuse_numel
2. same dp group
3. same dtype
4. dependency: grad would NOT be used by other ops within group segment
gradients inside same group would be fuse into one coalesce tensor
"""
block = default_main_program().global_block()
ops = block.ops
# group individual grad vars
# TODO consider fuse gradient for sharding reduce
# TODO let user to set fuse_grad_size
# emb = 50000 * h, ffn = 8 * h * h, mha = 4 * h * h
h = 2048
ffn_numel = 2 * (4 * h) * h
mha_numel = 3 * h * h + h * h
max_fuse_numel = ffn_numel + mha_numel
grad_groups = []
cur_group = GradientsGroup(ops, max_fuse_numel)
grouped_grad_names = set()
def collect_group(cur_group, grad_var, ring_id, i):
if len(cur_group.gradients) == 0:
cur_group = None
elif len(cur_group.gradients) == 1:
grouped_grad_names.remove(cur_group.gradients[0].name)
else:
cur_group.finalize()
grad_groups.append(cur_group)
new_group = GradientsGroup(ops, max_fuse_numel)
if grad_var:
new_group.add(grad_var, ring_id, i)
grouped_grad_names.add(grad_var.name)
return new_group
def op_depend_on_group(op, group):
vars_ = set(op.input_arg_names + op.output_arg_names)
grad_names = set([grad.name for grad in group.gradients])
return len(vars_.intersection(grad_names)) > 0
for i, op in enumerate(ops):
if is_data_parallel_reduce_op(op):
ring_id = op.attr("ring_id")
grad_name = op.output_arg_names[0]
grad_var = block.var(grad_name)
grad_numel = numel(grad_var)
if cur_group.acceptable(grad_var, ring_id):
assert grad_name not in grouped_grad_names
grouped_grad_names.add(grad_name)
cur_group.add(grad_var, ring_id, i)
else:
cur_group = collect_group(cur_group, grad_var, ring_id, i)
else:
if op_depend_on_group(op, cur_group):
cur_group = collect_group(cur_group, None, None, None)
# collect last group
collect_group(cur_group, None, None, None)
return grad_groups
def _update_program(self, grad_groups):
block = default_main_program().global_block()
remove_op_types = ['scale', 'c_allreduce_sum', 'c_wait_compute']
for i, group in enumerate(grad_groups[::-1]):
# create coalecse tensor
group.coalesce_var = block.create_var(name=unique_name.generate(
'coalecse_grad_{}'.format(i)),
dtype=group.dtype,
persistable=False,
stop_gradient=True)
# update allreduce & scale op
if group.scale_op_idx != -1:
scale_op = block.ops[group.scale_op_idx]
assert scale_op.type == 'scale', "should found scale op but found {}".format(
str(scale_op))
scale_op._rename_input(scale_op.input_arg_names[0],
group.coalesce_var.name)
scale_op._rename_output(scale_op.output_arg_names[0],
group.coalesce_var.name)
allreduce_op = block.ops[group.allreduce_op_idx]
assert allreduce_op.type == 'c_allreduce_sum', "should found c_allreduce_sum op but found {}".format(
str(allreduce_op))
allreduce_op._rename_input(allreduce_op.input_arg_names[0],
group.coalesce_var.name)
allreduce_op._rename_output(allreduce_op.output_arg_names[0],
group.coalesce_var.name)
# remvoe un-used op
remove_op_indices = group.remove_wait_op_indices + group.remove_allreduce_op_indices + group.remove_scale_op_indices
for idx in sorted(remove_op_indices, reverse=True):
assert block.ops[
idx].type in remove_op_types, "Unexception: try to remove op {}".format(
str(op))
block._remove_op(idx)
# insert coalecse op
concated_shapes = []
concated_ranks = []
for grad_ in group.gradients:
shape = grad_.shape
concated_shapes.extend(shape)
concated_ranks.append(len(shape))
grad_names = [grad.name for grad in group.gradients]
block._insert_op_without_sync(group.coalesce_op_idx,
type="coalesce_tensor",
inputs={"Input": grad_names},
outputs={
"Output": grad_names,
"FusedOutput": group.coalesce_var
},
attrs={
"copy_data": False,
"use_align": True,
"dtype": group.dtype,
"concated_shapes":
concated_shapes,
"concated_ranks": concated_ranks,
OP_ROLE_KEY: OpRole.Backward
})
block._sync_with_cpp()
# TODO update dist attr
def summary(self, grad_groups=[]):
# TODO: add logger module
import logging
self._logger = logging.getLogger()
self._logger.propagate = False
if not self._logger.handlers:
self._logger.setLevel(logging.INFO)
log_handler = logging.StreamHandler()
log_format = logging.Formatter(
'[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
)
log_handler.setFormatter(log_format)
self._logger.addHandler(log_handler)
if len(grad_groups) > 0:
self._logger.info(
"origin {} allreduce ops are fused into {} coalecse allreduce ops."
.format(len(self._grad_name_to_group_map.keys()),
len(grad_groups)))
self._logger.info("gradient fusing group are following: ")
fused_grads = set()
for i, group in enumerate(grad_groups):
self._logger.info(
"coalecse gradient [{}] is composed by: {}".format(
i, [grad.name for grad in group.gradients]))
fused_grads.update([grad.name for grad in group.gradients])
individual_grads = set(
self._grad_name_to_group_map.keys()) - set(fused_grads)
self._logger.info(
"the following [{}] gradients are not fused: ".format(
len(individual_grads)))
self._logger.info("individual gradient {}".format(individual_grads))
class GradientsGroup(object):
def __init__(self, ops, max_group_size):
self.max_group_size = max_group_size
self.ops = ops
self.gradients = []
self.numel = 0
self.dtype = None
self.ring_id = None
self.coalesce_var = None
self.coalesce_op_idx = -1
self.allreduce_op_idx = -1
self.scale_op_idx = -1
self.remove_wait_op_indices = []
self.remove_allreduce_op_indices = []
self.remove_scale_op_indices = []
def acceptable(self, grad_var, ring_id):
if len(self.gradients) == 0:
return True
if ring_id != self.ring_id:
return False
if numel(grad_var) + self.numel > self.max_group_size:
return False
if grad_var.dtype != self.dtype:
return False
return True
def add(self, grad_var, ring_id, i):
self.gradients.append(grad_var)
self.ring_id = ring_id
self.dtype = grad_var.dtype
self.numel += numel(grad_var)
# remove auxiliary ops in non-fuse dp allreduce
self.remove_allreduce_op_indices.append(i)
# NOTE this pass rely on the original synchronization add in previous passes
# (same stream or calc_wait_comm & comm_wait_calc)
# to guarantee the correctness of comm_calc execution order.
# so the calc_wait_comm should be keep.
grad_op_idx = i - 1
if i > 0 and self.ops[i - 1].type == 'c_wait_compute':
self.remove_wait_op_indices.append(i - 1)
grad_op_idx -= 1
if i + 1 < len(self.ops) and is_data_parallel_scale_op(self.ops[i - 1]):
self.remove_scale_op_indices.append(i + 1)
if len(self.gradients) == 1:
# TODO Remove this is a temporary hack for Tensor Parallel. the logic
# for find grad_op should be more general.
if self.ops[grad_op_idx].type == "c_allreduce_sum":
grad_op_idx -= 1
grad_op = self.ops[grad_op_idx]
assert grad_var.name in grad_op.output_arg_names, "grad [{}] should be output of {}".format(
grad_var.name, str(grad_op))
self.coalesce_op_idx = grad_op_idx
def finalize(self):
self.allreduce_op_idx = self.remove_allreduce_op_indices.pop()
if len(self.remove_wait_op_indices) > 1:
self.remove_wait_op_indices.pop()
if len(self.remove_scale_op_indices) > 1:
self.scale_op_idx = self.remove_scale_op_indices.pop()
......@@ -16,6 +16,7 @@ from collections import defaultdict
import paddle
from paddle.framework import core
from paddle.fluid.framework import default_main_program, default_startup_program
from paddle.fluid import unique_name
from .pass_base import register_pass
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
......@@ -379,6 +380,10 @@ class FP16State(object):
# create cast grad
grad_slot_name = slot_name + "@GRAD"
assert grad_slot_name in op.output_names
if len(op.output(grad_slot_name)) == 0:
var = block.var(src_name)
assert var.stop_gradient is True
continue
assert len(op.output(grad_slot_name)) == 1
grad_name = op.output(grad_slot_name)[0]
grad = block.var(grad_name)
......@@ -442,7 +447,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
inputs = {'X': grads, 'Scale': loss_scaling}
outputs = {'Out': grads, 'FoundInfinite': found_inf}
attrs = {'op_role': OpRole.Backward}
attrs = {'op_role': OpRole.Optimize}
new_op = main_block.append_op(type='check_finite_and_unscale',
inputs=inputs,
outputs=outputs,
......@@ -536,6 +541,39 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
return output_var
def cast_startup_program():
main_program = default_main_program()
startup_program = default_startup_program()
param_to_dtype = {}
for block in main_program.blocks:
for p in block.all_parameters():
param_to_dtype[p.name] = p.dtype
def is_initialization_op(op):
comm_op_prefix = "c_"
op_type = op.type
if op_type.startswith(comm_op_prefix):
return False
if len(op.output_arg_names) != 1 and len(op.input_arg_names) != 0:
return False
return True
for op in startup_program.global_block().ops:
if is_initialization_op(op):
output_name = op.output_arg_names[0]
if param_to_dtype.get(output_name,
None) == core.VarDesc.VarType.FP16:
assert op.has_attr(
'dtype'
), "initialization op is supported to has dtype attribute but got {}.".format(
str(op))
if op.attr('dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16)
@register_pass("auto_parallel_fp16")
class FP16Pass(AMPPass):
......@@ -563,6 +601,8 @@ class FP16Pass(AMPPass):
input_data_var_names)
is_train = fp16_state._build_state()
cast_startup_program()
if is_train:
with paddle.static.program_guard(main_program, startup_program):
# TODO (JZ-LIANG)support cast forward program only when inference
......@@ -575,18 +615,18 @@ class FP16Pass(AMPPass):
) or self.get_attr("init_loss_scaling") != 1.0:
found_infs = []
if fp32_grads:
with main_program._backward_role_guard():
with main_program._optimized_guard([]):
_, found_inf_fp32 = _check_and_update_gradient(
fp32_grads, self._loss_scaling, "@fp32",
self.dist_context)
found_infs.append(found_inf_fp32)
if fp16_grads:
with main_program._backward_role_guard():
with main_program._optimized_guard([]):
_, found_inf_fp16 = _check_and_update_gradient(
fp16_grads, self._loss_scaling, "@fp16",
self.dist_context)
found_infs.append(found_inf_fp16)
with main_program._backward_role_guard():
with main_program._optimized_guard([]):
block = main_program.global_block()
all_infs = paddle.fluid.layers.concat(found_infs)
......@@ -608,7 +648,7 @@ class FP16Pass(AMPPass):
block, self.dist_context)
if self.get_attr("use_dynamic_loss_scaling"):
with main_program._backward_role_guard():
with main_program._optimized_guard([]):
if fp32_grads:
self._update_loss_scaling(fp32_grads, found_inf)
if fp16_grads:
......
......@@ -207,6 +207,7 @@ class ClipGradByGloblNormPass(PassBase):
super(ClipGradByGloblNormPass, self).__init__()
self.set_attr("rank_id", None)
self.set_attr("dist_context", None)
self.set_attr("params_grads", None)
def _check_self(self):
if self.get_attr("dist_context") is None:
......@@ -214,6 +215,8 @@ class ClipGradByGloblNormPass(PassBase):
dist_context = self.get_attr("dist_context")
if dist_context._lr_optimizer._grad_clip is None:
return False
if self.get_attr("params_grads") is None:
return False
return True
def _check_conflict(self, other_pass):
......@@ -223,7 +226,8 @@ class ClipGradByGloblNormPass(PassBase):
dist_context = self.get_attr("dist_context", None)
rank_id = self.get_attr("rank_id", None)
block = main_program.global_block()
dist_params_grads = _get_params_grads(block)
dist_params_grads = self.get_attr("params_grads", None)
# dist_params_grads = _get_params_grads(block)
self.clip_helper = ClipHelper(dist_params_grads, rank_id, block,
dist_context)
......
......@@ -55,13 +55,6 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
return optimize_ops_desc
def _remove_op_role_var(param, grad):
op_maker = core.op_proto_and_checker_maker
op = grad.op
if op and op.has_attr(op_maker.kOpRoleVarAttrName()):
op._remove_attr(op_maker.kOpRoleVarAttrName())
def _get_gm_cond_var(main_program, k_steps, dist_context):
main_block = main_program.global_block()
# Add const var
......@@ -147,8 +140,6 @@ def _append_gradient_merge_backward_op(
param.type != core.VarDesc.VarType.SELECTED_ROWS
), "SELECTED_ROWS is not supported in GradientMergeOptimizer for now"
_remove_op_role_var(param, grad)
# {grad.name: gradient_merge_var.name} to rename opt inputs
grad_to_gradient_merge = {}
# {param: gradient_merge_var} to insert scale op and fill_constant op
......
......@@ -51,7 +51,8 @@ class ShardingPass(PassBase):
super(ShardingPass, self).__init__()
self.set_attr("dist_context", None)
self.set_attr("stage", None)
self.set_attr("sharding_degree", None)
self.set_attr("sharding_degree", None) # for parallelizer
self.set_attr("degree", None) # for parallelizer_v2
self.set_attr("params_grads", [])
self.set_attr("global_rank", -1)
self.dp_groups = set()
......@@ -59,6 +60,7 @@ class ShardingPass(PassBase):
self.varname_to_sharding_info = {}
self.partial_sharding = False
self.outer_dp_group = None
self.shared_params_grads = []
def _check_self(self):
if self.get_attr("dist_context") is None:
......@@ -66,8 +68,15 @@ class ShardingPass(PassBase):
if self.get_attr("stage") not in [1, 2, 3]:
return False
if (not isinstance(self.get_attr("sharding_degree"),
int)) or self.get_attr("sharding_degree") <= 1:
if self.get_attr("sharding_degree") is not None:
if (not isinstance(self.get_attr("sharding_degree"), int)) \
or self.get_attr("sharding_degree") <= 1:
return False
elif self.get_attr("degree") is not None:
if (not isinstance(self.get_attr("degree"), int)) \
or self.get_attr("degree") <= 1:
return False
else:
return False
if len(self.get_attr("params_grads")) <= 0:
return False
......@@ -82,7 +91,8 @@ class ShardingPass(PassBase):
def _apply_single_impl(self, main_program, startup_program, context):
self._dist_context = self.get_attr("dist_context")
self.sharding_world_size = int(self.get_attr("sharding_degree"))
self.sharding_world_size = int(
self.get_attr("sharding_degree") or self.get_attr("degree"))
self.stage = int(self.get_attr("stage"))
self.global_rank = int(self.get_attr("global_rank"))
params_grads = self.get_attr("params_grads")
......@@ -94,6 +104,8 @@ class ShardingPass(PassBase):
self._shard_gradient_synchronization(main_block)
self._shard_parameter(main_block, startup_block)
context.set_attr("params_grads", self.shared_params_grads)
def _build_sharding_groups(self, main_block, params_grads):
self._collective_data_parallel_groups(main_block)
self._build_sharding_infos(params_grads)
......@@ -148,13 +160,10 @@ class ShardingPass(PassBase):
self._dist_context._sharding_group = sharding_group
# TODO(JZ-LIANG) when support multiple dp groups in future, should group param and bind them to corresponding dp group
params_in_group = [p for p, g in params_grads]
assert len(params_in_group) == len(
set(params_in_group)), "found duplicated param in params_grads"
sharding_info = ShardingInfo(sharding_group, self.global_rank,
params_in_group)
params_grads)
self.sharding_infos.append(sharding_info)
for param in params_in_group:
for param in sharding_info.params:
self.varname_to_sharding_info[param.name] = sharding_info
def _shard_optimizer(self, main_block, startup_block, params_grads,
......@@ -201,6 +210,7 @@ class ShardingPass(PassBase):
op.desc.set_output('Out', reversed_x)
else:
if op.type == "check_finite_and_unscale":
op_role = op.attr('op_role')
out_name = op.output_arg_names[0]
out_var = main_block.vars[out_name]
main_block._remove_op(idx, sync=False)
......@@ -212,6 +222,7 @@ class ShardingPass(PassBase):
"shape": out_var.shape,
"dtype": out_var.dtype,
"value": 0,
OP_ROLE_KEY: op_role,
})
else:
main_block._remove_op(idx, sync=False)
......@@ -313,6 +324,9 @@ class ShardingPass(PassBase):
if varname != param_name
])
main_block._remove_op(idx, sync=False)
else:
self.shared_params_grads.append(
self._get_param_grad(param_name))
for idx, op in reversed(list(enumerate(startup_block.ops))):
if len(op.output_arg_names) == 1 and op.output_arg_names[
......@@ -365,6 +379,13 @@ class ShardingPass(PassBase):
sharding_info = self.varname_to_sharding_info[param_name]
return sharding_info.is_in_local_shard(param_name)
def _get_param_grad(self, param_name):
assert param_name in self.varname_to_sharding_info
sharding_info = self.varname_to_sharding_info[param_name]
p_g = sharding_info.get_param_grad(param_name)
assert p_g is not None
return p_g
def _shard_gradient_synchronization(self, main_block):
if self.stage < 2:
......@@ -705,9 +726,13 @@ def shard_parameters(params, group_size):
class ShardingInfo(object):
def __init__(self, group, rank, params):
def __init__(self, group, rank, params_grads):
self.group = group
self.params = params
self.params_grads = dict([(p.name, (p, g)) for p, g in params_grads])
assert len(self.params_grads) == len(set(
self.params_grads)), "found duplicated param in params_grads"
self.params = [p for p, _ in params_grads]
self.param_names = [p.name for p in self.params]
self.group_size = group.nranks
self.global_rank = rank
......@@ -762,3 +787,11 @@ class ShardingInfo(object):
if usage > 0:
broadcast_vars.add(param)
return broadcast_vars, param_usage
def get_param_grad(self, param_name):
if not self.is_in_local_shard(param_name):
raise ValueError(
"param[{}] not in current rank.".format(param_name))
if param_name not in self.params_grads:
raise ValueError('param[{}] not in params_grads'.format(param_name))
return self.params_grads.get(param_name, None)
......@@ -37,8 +37,28 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
${dist_ENVS})
set_tests_properties(test_high_order_grad
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_grad_clip MODULES test_grad_clip ENVS ${dist_ENVS})
set_tests_properties(test_grad_clip PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS
${dist_ENVS})
set_tests_properties(test_iterable_dataset
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80)
py_test_modules(test_pass_grad_clip MODULES test_pass_grad_clip ENVS
${dist_ENVS})
set_tests_properties(test_pass_grad_clip
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_pass_gradient_merge MODULES test_pass_gradient_merge
ENVS ${dist_ENVS})
set_tests_properties(test_pass_gradient_merge
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_pass_recompute MODULES test_pass_recompute ENVS
${dist_ENVS})
set_tests_properties(test_pass_recompute
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_pass_sharding MODULES test_pass_sharding ENVS
${dist_ENVS})
set_tests_properties(test_pass_sharding
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_pass_amp MODULES test_pass_amp ENVS ${dist_ENVS})
set_tests_properties(test_pass_amp PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
py_test_modules(test_while_op_completion MODULES test_while_op_completion
......@@ -70,11 +90,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_process_mesh_v2 MODULES test_process_mesh_v2)
py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2)
py_test_modules(test_lr_grad_clip MODULES test_lr_grad_clip)
py_test_modules(test_quantization MODULES test_quantization)
py_test_modules(test_dist_matmul MODULES test_dist_matmul)
py_test_modules(test_process_mesh MODULES test_process_mesh)
py_test_modules(test_interface MODULES test_interface)
py_test_modules(test_strategy MODULES test_strategy)
py_test_modules(test_pass_quantization MODULES test_pass_quantization)
py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS
${dist_ENVS})
set_tests_properties(test_iterable_dataset
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80)
endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
import random
import numpy as np
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
def apply_pass(use_amp=False, level=None):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_amp:
amp = strategy.amp
amp.enable = True
amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
amp.custom_black_list = [
'c_softmax_with_cross_entropy', 'elementwise_div', 'reduce_sum'
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.use_pure_fp16 = level in ["o2", "o3"]
amp.use_optimizer_fp16 = level == "o3"
print("amp level: ", level)
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestAMPPass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-5
self.atol = 1e-8
self.batch_size = 1
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2021)
np.random.seed(2021)
random.seed(2021)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_amp=False, level=None):
reset_prog()
strategy = apply_pass(use_amp, level)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("mp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses, rtol=None, atol=None):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=rtol or self.rtol,
atol=atol or self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses))
def test_amp_pass(self):
# mp2 training
mp_engine = self.get_engine()
mp_losses = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(mp_losses["loss"])
# mp2 amp-o1 training
amp_o1_engine = self.get_engine(True, "o1")
amp_o1_losses = amp_o1_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
amp_o1_losses = np.array(amp_o1_losses["loss"])
# self.check_results(mp_losses, amp_o1_losses)
# mp2 amp-o2 training
amp_o2_engine = self.get_engine(True, "o2")
amp_o2_losses = amp_o2_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
amp_o2_losses = np.array(amp_o2_losses["loss"])
# self.check_results(mp_losses, amp_o2_losses)
# mp2 amp-o3 training
amp_o3_engine = self.get_engine(True, "o3")
amp_o3_losses = amp_o3_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
amp_o3_losses = np.array(amp_o3_losses["loss"])
# self.check_results(mp_losses, amp_o3_losses)
if __name__ == "__main__":
unittest.main()
......@@ -32,7 +32,7 @@ import paddle.distributed.auto_parallel as auto
paddle.enable_static()
_global_parallel_strategy = None
_global_process_mesh = None
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
batch_size = 4
hidden_size = 1024
sequence_len = 512
......@@ -103,11 +103,7 @@ def mlp_pretrain_forward(train_program, start_program):
shape=[batch_size, sequence_len, 1],
dtype='float32')
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mappig": [-1, -1, -1]
})
auto.shard_tensor(input, _global_process_mesh, [None, None, None])
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......@@ -126,9 +122,6 @@ def mlp_pretrain_forward(train_program, start_program):
def train():
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False
dist_strategy.pipeline = False
......
......@@ -18,24 +18,21 @@ import random
import numpy as np
import paddle
import paddle.distributed.fleet as fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
paddle.enable_static()
def apply_pass(use_sharding=False):
strategy = fleet.DistributedStrategy()
strategy.semi_auto = True
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_sharding:
strategy.sharding = True
strategy.sharding_configs = {
"sharding_degree": 2,
"stage": 2,
}
sharding = strategy.sharding
sharding.degree = 2
sharding.stage = 2
return strategy
......@@ -76,34 +73,17 @@ class TestGradientClipByGlobalNorm(unittest.TestCase):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
engine.mode = "train"
engine._executor.run(engine.startup_program)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_dp2_engine(self):
def get_engine(self, use_sharding=False):
reset_prog()
strategy = apply_pass()
strategy = apply_pass(use_sharding)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("dp")
inputs_spec, labels_spec = create_data_holder(self.batch_size)
engine = Engine(model, inputs_spec, labels_spec, strategy=strategy)
engine.prepare(optimizer=opt, loss=loss)
self.init(engine)
return engine
def get_dp2sharding2_engine(self):
reset_prog()
strategy = apply_pass(True)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("dp")
inputs_spec, labels_spec = create_data_holder(self.batch_size)
engine = Engine(model, inputs_spec, labels_spec, strategy=strategy)
engine.prepare(optimizer=opt, loss=loss)
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
......@@ -121,15 +101,13 @@ class TestGradientClipByGlobalNorm(unittest.TestCase):
def test_grad_clip(self):
# dp2 training
dp_engine = self.get_dp2_engine()
dp_engine.fit(self.dataset, batch_size=self.batch_size, use_cache=True)
dp_engine = self.get_engine()
dp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
dp_param_values = get_parameter_value(dp_engine.main_program)
# dp2sharding2 training
sharding_engine = self.get_dp2sharding2_engine()
sharding_engine.fit(self.dataset,
batch_size=self.batch_size,
use_cache=True)
sharding_engine = self.get_engine(True)
sharding_engine.fit(self.dataset, 3, batch_size=self.batch_size)
sharding_param_values = get_parameter_value(
sharding_engine.main_program)
......
......@@ -27,10 +27,8 @@ import paddle.nn.functional as F
import paddle.utils as utils
from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.static import InputSpec
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.fluid.dataloader.collate import default_collate_fn
......@@ -47,6 +45,8 @@ class_num = 10
paddle.seed(44)
is_fetch = True
class MyDataset(Dataset):
......@@ -90,19 +90,20 @@ class MLPLayer(nn.Layer):
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
out = auto.shard_op(self.norm, dist_attr={"process_mesh":
PP_MESH_0})(input)
out = auto.shard_op(self.norm, PP_MESH_0)(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = auto.shard_op(self.linear1, dist_attr={"process_mesh":
PP_MESH_1})(out)
out = auto.shard_op(self.linear1, PP_MESH_1)(out)
out = self.dropout(out)
out = self.linear2(out)
self.out = out
if is_fetch:
auto.fetch(out, "out")
return out
def train(fetch):
global is_fetch
is_fetch = fetch
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
......@@ -113,46 +114,34 @@ def train(fetch):
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
metric = paddle.metric.Accuracy()
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
labels_spec = InputSpec([batch_size], 'int64', 'label')
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
# init engine
engine = Engine(mlp,
inputs_spec=inputs_spec,
labels_spec=labels_spec,
strategy=dist_strategy)
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
strategy = auto.Strategy()
strategy.auto_mode = "semi"
# fetch
if fetch:
fetches = {'out': mlp.out}
else:
fetches = None
engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy)
# train
train_dataset = MyDataset(batch_num * batch_size)
engine.fit(train_dataset,
eval_dataset1 = MyDataset(5 * batch_size)
engine.fit(train_data=train_dataset,
epochs=2,
batch_size=batch_size,
steps_per_epoch=batch_num * batch_size,
fetches=fetches)
valid_data=eval_dataset1)
# eval
eval_dataset = MyDataset(batch_size)
engine.evaluate(eval_dataset, batch_size, fetches=fetches)
eval_dataset2 = MyDataset(batch_size)
engine.evaluate(eval_dataset2, batch_size=batch_size)
# predict
test_dataset = MyDataset(batch_size)
engine.predict(test_dataset, batch_size, fetches=fetches)
engine.predict(test_dataset, batch_size=batch_size)
# save
temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp_inf')
engine.save(model_filename, training=False, mode='predict')
model_filename = os.path.join(temp_dir.name, 'mlp')
engine.save(model_filename, training=True)
engine.load(model_filename)
temp_dir.cleanup()
......
......@@ -26,11 +26,9 @@ import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.static import InputSpec
from paddle.distributed import fleet
from paddle.io import Dataset, DataLoader
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
paddle.enable_static()
batch_size = 2
......@@ -91,6 +89,7 @@ class MLPLayer(nn.Layer):
out = self.linear1(out)
out = self.dropout(out)
out = self.linear2(out)
auto.fetch(out, "out")
self.out = out
return out
......@@ -107,46 +106,32 @@ def train(fetch):
epsilon=1e-08,
grad_clip=None)
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
labels_spec = InputSpec([batch_size], 'int64', 'label')
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False
dist_strategy.pipeline = False
dist_strategy.recompute = False
# init parallel optimizer
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
dist_strategy = auto.Strategy()
dist_strategy.auto_mode = "semi"
# init engine
engine = Engine(mlp,
inputs_spec=inputs_spec,
labels_spec=labels_spec,
engine = auto.Engine(mlp,
loss,
optimizer,
paddle.metric.Accuracy(),
strategy=dist_strategy)
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
# fetch
if fetch:
fetches = {'out': mlp.out}
else:
fetches = None
# train
train_dataset = MyDataset(batch_num * batch_size)
engine.fit(train_dataset, batch_size=batch_size, fetches=fetches)
engine.fit(train_dataset, batch_size=batch_size)
# eval
eval_dataset = MyDataset(batch_size)
engine.evaluate(eval_dataset, batch_size, fetches=fetches)
engine.evaluate(eval_dataset, batch_size=batch_size)
# predict
test_dataset = MyDataset(batch_size)
engine.predict(test_dataset, batch_size, fetches=fetches)
engine.predict(test_dataset, batch_size=batch_size)
# save
temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp_inf')
engine.save(model_filename, training=False, mode='predict')
engine.save(model_filename, training=False)
temp_dir.cleanup()
......
......@@ -14,8 +14,10 @@
import sys
import numpy as np
import random
import paddle
import paddle.distributed.auto_parallel as auto
sys.path.append("..")
import auto_parallel_gpt_model as modeling
......@@ -25,7 +27,7 @@ sequence_len = 512
vocab_size = 1000
class FakeDataset:
class FakeDataset(paddle.io.Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
......@@ -33,6 +35,9 @@ class FakeDataset:
self.vocab_size = vocab_size
def __getitem__(self, idx):
paddle.seed(2021)
np.random.seed(2021)
random.seed(2021)
tokens = np.random.randint(self.vocab_size, size=self.sequence_len)
position_ids = np.arange(self.sequence_len)
attention_mask = np.tril(np.ones(self.sequence_len)).reshape(
......@@ -67,8 +72,9 @@ def create_data_holder(batch_size):
def generate_model(strategy):
modeling.init_global()
modeling._global_process_mesh = list(
range(paddle.distributed.get_world_size()))
ranks = list(range(paddle.distributed.get_world_size()))
modeling._global_process_mesh = auto.ProcessMesh(mesh=ranks,
dim_names=["x"])
if strategy == "serial":
modeling._global_parallel_strategy = "serial"
elif strategy == "mp":
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
import random
import numpy as np
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
paddle.enable_static()
def apply_pass(use_gradient_merge=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_gradient_merge:
gradient_merge = strategy.gradient_merge
gradient_merge.enable = True
gradient_merge.k_steps = 4
gradient_merge.avg = True
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestGradientMergePass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-5
self.atol = 1e-8
self.batch_size = 8
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2021)
np.random.seed(2021)
random.seed(2021)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_gradient_merge=False):
reset_prog()
strategy = apply_pass(use_gradient_merge)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("dp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses))
def test_gradient_merge_pass(self):
# dp2 training
dp_engine = self.get_engine()
dp_losses = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
dp_losses = np.array(dp_losses["loss"])
# dp2 gradient merge training
gm_engine = self.get_engine(True)
gm_losses = gm_engine.fit(self.dataset, 3, batch_size=self.batch_size)
gm_losses = np.array(gm_losses["loss"])
avg_loss = 0
pass_avg_ret_list = []
for i, pass_ret in enumerate(gm_losses):
if (i + 1) % 4 == 0:
avg_loss += pass_ret
pass_avg_ret_list.append(avg_loss / 4)
avg_loss = 0
else:
avg_loss += pass_ret
self.check_results(dp_losses, np.array(pass_avg_ret_list))
if __name__ == "__main__":
unittest.main()
......@@ -17,11 +17,7 @@ import paddle
import unittest
import numpy as np
import paddle.distributed.auto_parallel as auto
from paddle.static import InputSpec
from paddle.distributed import fleet
from paddle.incubate.autograd import Hessian
from paddle.distributed.auto_parallel.engine import Engine
np.random.seed(1234)
paddle.seed(1234)
......@@ -87,7 +83,7 @@ class LaplaceModel(paddle.nn.Layer):
return eq_loss, bc_u
class LaplaceDataset:
class LaplaceDataset(paddle.io.Dataset):
def __init__(self, num_sample):
self.num_sample = num_sample
......@@ -129,23 +125,14 @@ def main():
# model
laplace = LaplaceModel()
# spec
inputs_spec = [
InputSpec([100, 2], 'float32', 'x'),
InputSpec([36], 'int64', 'bc_idx')
]
labels_spec = InputSpec([36, 1], 'float32', 'bc_v')
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
dist_strategy = auto.Strategy()
dist_strategy.auto_mode = "semi"
engine = Engine(laplace,
inputs_spec=inputs_spec,
labels_spec=labels_spec,
engine = auto.Engine(laplace,
loss=loss_func,
optimizer=optimizer,
strategy=dist_strategy)
engine.prepare(optimizer=optimizer, loss=loss_func)
engine.fit(train_dataset, batch_size=None)
engine.fit(train_dataset, train_sample_split=2, batch_size=None)
dist_context = engine.dist_context
block = engine.main_program.global_block()
......
......@@ -28,9 +28,8 @@ import paddle.utils as utils
from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.static import InputSpec
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.fluid.dataloader.collate import default_collate_fn
......@@ -48,10 +47,9 @@ class_num = 10
paddle.seed(44)
class MyDataset(IterableDataset):
class MyDataset(paddle.io.IterableDataset):
def __init__(self, num_samples):
super(MyDataset, self).__init__()
self.num_samples = num_samples
def __iter__(self):
......@@ -61,10 +59,9 @@ class MyDataset(IterableDataset):
yield input, label
class MyDataset1(Dataset):
class MyDataset1(paddle.io.Dataset):
def __init__(self, num_samples):
super(MyDataset1, self).__init__()
self.num_samples = num_samples
self.data = []
for i in range(self.num_samples):
......@@ -112,12 +109,10 @@ class MLPLayer(nn.Layer):
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
out = auto.shard_op(self.norm, dist_attr={"process_mesh":
PP_MESH_0})(input)
out = auto.shard_op(self.norm, PP_MESH_0)(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = auto.shard_op(self.linear1, dist_attr={"process_mesh":
PP_MESH_1})(out)
out = auto.shard_op(self.linear1, PP_MESH_1)(out)
out = self.dropout(out)
out = self.linear2(out)
self.out = out
......@@ -136,54 +131,36 @@ def train(fetch):
epsilon=1e-08,
grad_clip=None)
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
labels_spec = InputSpec([batch_size], 'int64', 'label')
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
dist_strategy = auto.Strategy()
dist_strategy.auto_mode = "semi"
dist_strategy.split_data = True
fleet.init(is_collective=True, strategy=dist_strategy)
# init engine
engine = Engine(mlp,
inputs_spec=inputs_spec,
labels_spec=labels_spec,
engine = auto.Engine(mlp,
loss,
optimizer,
paddle.metric.Accuracy(),
strategy=dist_strategy)
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
# fetch
if fetch:
fetches = {'out': mlp.out}
else:
fetches = None
# train
train_dataset = MyDataset(batch_num * batch_size)
train_dataset1 = MyDataset1(batch_num)
engine.fit(train_dataset,
epochs=2,
batch_size=batch_size,
steps_per_epoch=batch_num,
fetches=fetches)
engine.fit(train_dataset1,
epochs=2,
batch_size=None,
steps_per_epoch=batch_num,
fetches=fetches)
engine.fit(train_dataset, epochs=2, batch_size=batch_size)
train_dataset1 = MyDataset1(batch_size * batch_num)
engine.fit(train_dataset1, epochs=2, batch_size=None)
# eval
eval_dataset = MyDataset(batch_size)
engine.evaluate(eval_dataset, batch_size, fetches=fetches)
engine.evaluate(eval_dataset, batch_size=batch_size)
# predict
test_dataset = MyDataset(batch_size)
engine.predict(test_dataset, batch_size, fetches=fetches)
engine.predict(test_dataset, batch_size=batch_size)
# save
temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp_inf')
engine.save(model_filename, training=False, mode='predict')
engine.save(model_filename, training=False)
temp_dir.cleanup()
......
......@@ -27,10 +27,8 @@ import paddle.nn.functional as F
import paddle.utils as utils
from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.static import InputSpec
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
from engine_api_dp import MyDataset
paddle.enable_static()
......@@ -43,20 +41,6 @@ class_num = 10
paddle.seed(44)
# class MyDataset(Dataset):
# def __init__(self, num_samples):
# super(MyDataset, self).__init__()
# self.num_samples = num_samples
# def __getitem__(self, index):
# input = np.random.uniform(size=image_size).astype("float32")
# label = np.random.randint(0, class_num - 1, dtype="int64")
# return input, label
# def __len__(self):
# return self.num_samples
class MLPLayer(nn.Layer):
......@@ -107,50 +91,33 @@ def train(fetch):
epsilon=1e-08,
grad_clip=None)
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
labels_spec = InputSpec([batch_size], 'int64', 'label')
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False
dist_strategy.pipeline = False
dist_strategy.recompute = False
# init parallel optimizer
dist_strategy.semi_auto = True
dist_strategy.sharding = True
dist_strategy.sharding_configs = {
"sharding_degree": 2,
"stage": 3,
"enable_tuning": True,
}
fleet.init(is_collective=True, strategy=dist_strategy)
# init engine
import tempfile
tmp_dir = tempfile.TemporaryDirectory()
dataset = MyDataset(batch_num * batch_size)
dist_strategy = auto.Strategy()
dist_strategy.auto_mode = "semi"
# sharding config
sharding = dist_strategy.sharding
sharding.enable = True
sharding.degree = 2
sharding.stage = 3
sharding.enable_tuning = True
sharding.tuning_range = [0, 1, 2, 3]
# Tuning configuration
tuning_config = {
"batch_size": batch_size,
"dataset": dataset,
"profile_start_step": 1,
"profile_end_step": 5,
"run_after_tuning": True,
"sharding": {
"stage_range": [0, 1, 2, 3]
},
"verbose": True,
}
engine = Engine(mlp,
inputs_spec=inputs_spec,
labels_spec=labels_spec,
strategy=dist_strategy,
user_tuning_config=tuning_config)
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
tuning = dist_strategy.tuning
tuning.enable = True
tuning.profile_start_step = 1
tuning.profile_end_step = 5
tuning.run_after_tuning = True
tuning.verbose = True
dataset = MyDataset(batch_num * batch_size)
engine = auto.Engine(mlp,
loss,
optimizer,
paddle.metric.Accuracy(),
strategy=dist_strategy)
engine._tune(dataset, batch_size=batch_size)
# check tuned
assert (engine._dist_contexts['train'].strategy.sharding_configs['stage'] !=
3)
assert (engine._dist_contexts['train'].strategy.sharding.stage != 3)
if __name__ == "__main__":
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
import random
import numpy as np
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
def apply_pass(use_recompute=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_recompute:
recompute = strategy.recompute
recompute.enable = True
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestRecomputePass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-6
self.atol = 1e-8
self.batch_size = 1
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_recompute=False):
reset_prog()
strategy = apply_pass(use_recompute)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("mp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses))
def test_recompute_pass(self):
# mp2 training
mp_engine = self.get_engine()
mp_losses = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(mp_losses["loss"])
# mp2 recompute training
rc_engine = self.get_engine(True)
rc_losses = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size)
rc_losses = np.array(rc_losses["loss"])
self.check_results(mp_losses, rc_losses)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
import random
import numpy as np
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
paddle.enable_static()
def apply_pass(use_sharding=False, stage=None):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_sharding:
sharding = strategy.sharding
sharding.enable = True
sharding.degree = 2
sharding.stage = 1
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestShardingPass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-6
self.atol = 1e-8
self.batch_size = 2
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_sharding=False, stage=None):
reset_prog()
strategy = apply_pass(use_sharding, stage)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("dp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses))
def test_sharding_pass(self):
# dp2 training
dp_engine = self.get_engine()
dp_losses = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
dp_losses = np.array(dp_losses["loss"])
# sharding2 stage1 training
sharding1_engine = self.get_engine(True, 1)
sharding1_losses = sharding1_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
sharding1_losses = np.array(sharding1_losses["loss"])
self.check_results(dp_losses, sharding1_losses)
# sharding2 stage2 training
sharding2_engine = self.get_engine(True, 2)
sharding2_losses = sharding2_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
sharding2_losses = np.array(sharding2_losses["loss"])
self.check_results(dp_losses, sharding2_losses)
# sharding2 stage3 training
sharding3_engine = self.get_engine(True, 3)
sharding3_losses = sharding3_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
sharding3_losses = np.array(sharding3_losses["loss"])
self.check_results(dp_losses, sharding3_losses)
if __name__ == "__main__":
unittest.main()
......@@ -45,9 +45,10 @@ from test_cluster import cluster_json
paddle.enable_static()
_global_parallel_strategy = "dp_mp_pp"
_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]])
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]])
_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]],
dim_names=["x", "y", "z"])
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], dim_names=["x", "y"])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], dim_names=["x", "y"])
class MLPLayer(nn.Layer):
......@@ -74,16 +75,8 @@ class MLPLayer(nn.Layer):
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, input):
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [1, -1]
})
auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, "y"])
auto.shard_tensor(self.linear1.weight, PP_MESH_1, ["y", None])
out = self.norm(input)
out = self.linear0(out)
......@@ -111,16 +104,8 @@ def mlp_forward(train_program, start_program):
embedding = paddle.nn.Embedding(10, hidden_size, sparse=True)
embedding_out = embedding(fill_constant_out)
auto.shard_tensor(input,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [0, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]
})
auto.shard_tensor(input, PP_MESH_0, ["x", None])
auto.shard_tensor(label, PP_MESH_1, ["x", None])
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......
......@@ -34,7 +34,10 @@ paddle.enable_static()
batch_size = 4
hidden_size = 1024
sequence_len = 512
_g_process_mesh = [[0, 1], [2, 3]]
_g_process_mesh = [
auto.ProcessMesh([0, 1], dim_names=["x"]),
auto.ProcessMesh([2, 3], dim_names=["x"])
]
def get_random_inputs_and_labels(input_shape, label_shape):
......@@ -82,18 +85,10 @@ class MLPLayer(nn.Layer):
def forward(self, input):
out = self.norm(input)
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.linear0.weight, _g_process_mesh[0], [None, "x"])
out = self.linear0(out)
out = F.gelu(out, approximate=True)
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _g_process_mesh[1],
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.linear1.weight, _g_process_mesh[1], ["x", None])
out = self.linear1(out)
return out
......@@ -123,16 +118,8 @@ def get_program():
dataloader.set_batch_generator(batch_generator_creator(),
places=paddle.static.cuda_places())
# data dist_attr
auto.shard_tensor(input,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [0, -1, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [0, -1, -1]
})
auto.shard_tensor(input, _g_process_mesh[0], ["x", None, None])
auto.shard_tensor(label, _g_process_mesh[0], ["x", None, None])
mlp_start = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......
......@@ -42,19 +42,13 @@ def make_program_lookup_table_v1_mp_dp():
is_sparse=False)
loss = paddle.fluid.layers.reduce_mean(emb_out)
auto.shard_tensor(src_ids,
dist_attr={
"process_mesh": auto.ProcessMesh([[0, 1], [2,
3]]),
"dims_mapping": [0, -1, -1]
})
auto.shard_tensor(
src_ids, auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]),
["x", None, None])
emb_weight = block.vars["emb_weight"]
auto.shard_tensor(emb_weight,
dist_attr={
"process_mesh": auto.ProcessMesh([[0, 1], [2,
3]]),
"dims_mapping": [1, -1]
})
auto.shard_tensor(
emb_weight, auto.ProcessMesh([[0, 1], [2, 3]],
dim_names=["x", "y"]), ["y", None])
return main_program, start_program, loss
......
......@@ -22,82 +22,58 @@ from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
mesh = [[0, 1], [2, 3]]
mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
def init_x_row(trans_x):
if trans_x:
x = paddle.static.data(name='x', shape=[10, 6, 8], dtype='float32')
auto.shard_tensor(x,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [0, 1, -1]
})
auto.shard_tensor(x, mesh, ["x", "y", None])
return x
else:
x = paddle.static.data(name='x', shape=[10, 8, 6], dtype='float32')
auto.shard_tensor(x,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [0, -1, 1]
})
auto.shard_tensor(x, mesh, ["x", None, "y"])
return x
def init_x_col(trans_x):
if trans_x:
x = paddle.static.data(name='x', shape=[6, 8], dtype='float32')
auto.shard_tensor(x,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(x, mesh, [None, "x"])
return x
else:
x = paddle.static.data(name='x', shape=[8, 6], dtype='float32')
auto.shard_tensor(x,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(x, mesh, ["x", None])
return x
def init_y_row(trans_y):
if trans_y:
y = paddle.static.data(name='y', shape=[4, 6], dtype='float32')
auto.shard_tensor(y,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(y, mesh, [None, "y"])
return y
else:
y = paddle.static.data(name='y', shape=[6, 4], dtype='float32')
auto.shard_tensor(y,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [1, -1]
})
auto.shard_tensor(y, mesh, ["y", None])
return y
def init_y_col(trans_y):
if trans_y:
y = paddle.static.data(name='y', shape=[4, 6], dtype='float32')
auto.shard_tensor(y,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [1, -1]
})
auto.shard_tensor(y, mesh, ["y", None])
return y
else:
y = paddle.static.data(name='y', shape=[6, 4], dtype='float32')
auto.shard_tensor(y,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(y, mesh, [None, "y"])
return y
......
......@@ -71,11 +71,8 @@ class TestDistOpCost(unittest.TestCase):
shape=[4, 1],
dtype='float32')
label.stop_gradient = True
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[2, 8], value=1, dtype='float32')
weight_attr = paddle.ParamAttr()
......@@ -121,17 +118,12 @@ class TestDistOpCost(unittest.TestCase):
shape=[8, 1],
dtype='float32')
label.stop_gradient = True
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0]
})
auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x"])
auto.shard_tensor(label,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
# embedding
tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[4], value=1, dtype='int32')
......@@ -141,12 +133,9 @@ class TestDistOpCost(unittest.TestCase):
for op in main_program.global_block().ops:
if op.type == "lookup_table_v2":
W = main_program.global_block().vars[op.input("W")[0]]
auto.shard_tensor(W,
dist_attr={
"process_mesh":
auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.shard_tensor(
W, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
out = paddle.fluid.layers.transpose(out,
[1, 0]) # [8, 2] [-1, 0]
......@@ -154,26 +143,20 @@ class TestDistOpCost(unittest.TestCase):
param1 = paddle.fluid.layers.create_parameter(
[4, 8], paddle.float32) # [2, 8] [0, -1]
auto.shard_tensor(param1,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
param2 = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 4] [-1, 0]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, 0]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
[None, "x"])
out1 = paddle.fluid.layers.matmul(out,
param1) # [8, 8] [-1, -1]
tmp_param = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 8] [-1, -1]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, -1]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
[None, None])
tmp_out = paddle.fluid.layers.matmul(out1, tmp_param)
out2 = paddle.fluid.layers.matmul(tmp_out,
param2) # [8, 4] [-1, 0]
......@@ -227,17 +210,12 @@ class TestDistOpCost(unittest.TestCase):
shape=[8, 1],
dtype='float32')
label.stop_gradient = True
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0]
})
auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x"])
auto.shard_tensor(label,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
# embedding
tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[4], value=1, dtype='int32')
......@@ -247,12 +225,9 @@ class TestDistOpCost(unittest.TestCase):
for op in main_program.global_block().ops:
if op.type == "lookup_table_v2":
W = main_program.global_block().vars[op.input("W")[0]]
auto.shard_tensor(W,
dist_attr={
"process_mesh":
auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.shard_tensor(
W, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
out = paddle.fluid.layers.transpose(out,
[1, 0]) # [8, 2] [-1, 0]
......@@ -260,25 +235,20 @@ class TestDistOpCost(unittest.TestCase):
param1 = paddle.fluid.layers.create_parameter(
[4, 8], paddle.float32) # [2, 8] [0, -1]
auto.shard_tensor(param1,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
param2 = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 4] [-1, 0]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, 0]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
[None, "x"])
out1 = paddle.matmul(out, param1) # [8, 8] [-1, -1]
tmp_param = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 8] [-1, -1]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, -1]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
[None, None])
tmp_out = paddle.matmul(out1, tmp_param)
out2 = paddle.matmul(tmp_out, param2) # [8, 4] [-1, 0]
......@@ -331,17 +301,11 @@ class TestDistOpCost(unittest.TestCase):
shape=[8, 1],
dtype='float32')
label.stop_gradient = True
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0]
})
auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x"])
auto.shard_tensor(label,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
# embedding
tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[4], value=1, dtype='int32')
......@@ -351,12 +315,9 @@ class TestDistOpCost(unittest.TestCase):
for op in main_program.global_block().ops:
if op.type == "lookup_table_v2":
W = main_program.global_block().vars[op.input("W")[0]]
auto.shard_tensor(W,
dist_attr={
"process_mesh":
auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.shard_tensor(
W, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
out = paddle.fluid.layers.transpose(out,
[1, 0]) # [8, 2] [-1, 0]
......@@ -364,25 +325,21 @@ class TestDistOpCost(unittest.TestCase):
param1 = paddle.fluid.layers.create_parameter(
[4, 8], paddle.float32) # [2, 8] [0, -1]
auto.shard_tensor(param1,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
param2 = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 4] [-1, 0]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, 0]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
[None, "x"])
out1 = paddle.fluid.layers.mul(out, param1) # [8, 8] [-1, -1]
tmp_param = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 8] [-1, -1]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, -1]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
[None, None])
tmp_out = paddle.fluid.layers.mul(out1, tmp_param)
out2 = paddle.fluid.layers.mul(tmp_out,
param2) # [8, 4] [-1, 0]
......
......@@ -29,11 +29,8 @@ def make_program_dp2():
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
x.stop_gradient = False
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1, -1]
})
auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None, None])
tmp_0 = paddle.norm(x, p=2)
return main_program, start_program, tmp_0
......@@ -44,11 +41,8 @@ def make_program_serial():
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
x.stop_gradient = False
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0]),
"dims_mapping": [-1, -1, -1]
})
auto.shard_tensor(x, auto.ProcessMesh([0], dim_names=["x"]),
[None, None, None])
tmp_0 = paddle.norm(x, p=2)
return main_program, start_program, tmp_0
......
......@@ -29,11 +29,9 @@ def make_program_dp2():
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 4, 8], dtype='float32')
x.stop_gradient = False
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1, -1]
})
auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None, None])
tmp_0 = paddle.reshape(x, shape=[0, 0, 4, 2])
tmp_1 = paddle.reshape(tmp_0, shape=[0, 0, 8])
tmp_2 = tmp_1.reshape((tmp_1.shape[0], tmp_1.shape[1], -1))
......
......@@ -71,7 +71,7 @@ class TestDistributedTensor(unittest.TestCase):
def test_new_local_tensor(self):
test_auto_parallel_reshard._global_process_mesh = auto.ProcessMesh(
mesh=[0, 1])
mesh=[0, 1], dim_names=["x"])
test_auto_parallel_reshard._global_parallel_strategy = "dp"
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册