Unverified commit c5cc4278, authored by Yulong Ao, committed by GitHub

[Cherry-pick][Auto Parallel] Improve the APIs (#46164)

* [AutoParallel] adapt gradient merge pass (#45915)

* adapt gradient merge

* fix op_role

* fix strategy

* [Auto Parallel] Gradient Fuse Allreduce (#45643)

* bugfix (#45332)

* dist embedding support lookup table v1

* add unitest

* customize wait_comm

* group gradients

* bugfix

* update program

* [Auto Parallel] Improve the APIs (#45776)

* [Auto Parallel] Use c++ dist attr in the completion process

* [Auto Parallel] Add minor changes

* [Auto Parallel] Use c++ dist attr in the completion process

* [Auto Parallel] Add minor changes

* [Auto Parallel] Add the serialization process for dist attrs

* [Auto Parallel] Remove unnecessary comments

* [Auto Parallel] Fix some bugs

* [Auto Parallel] Fix the code style

* [Auto Parallel] Remove unnecessary impls

* [Auto Parallel] Fix the importing error

* [Auto Parallel] Fix the copy from bugs of op dist attr

* [Auto Parallel] Replace the use of constexpr if

* [Auto Parallel] Redesign the shard_tensor, shard_op and ProcessMesh

* [Auto Parallel] Change API of the completion unittest

* [Auto Parallel] Fix the bug when set_attr an int

* [Auto Parallel] Add the unittest for the serialization

* [Auto Parallel] Add some unit tests

* [Auto Paralle] Unify the strategy

* [Auto Parallel] Improve the engine api

* [Auto Parallel] Reset the changes made to the framework

* [Auto Parallel] Change the engine unittest

* [Auto Parallel] Update API of the completion and partitioner

* [Auto Parallel] Update unit tests using engine api

* update shard annotation

* [Auto Parallel] Remove the modifications of other modules

* [Auto Parallel] Add docs for APIs

* add new strategy

* [Auto Parallel] Replace the logger

* [Auto Parallel] Restore the test_program.py

* [Auto Parallel] Change the import rules

* [Auto Parallel] Add the examples for Engine

* [Auto Parallel] Do some minor changes

* [Auto Parallel] Remove yaml dependency

* [Auto Parallel] Fix the unittests

* add valid after train

* bug fix
Co-authored-by: zhaoyingli <zhaoyingli@baidu.com>
Co-authored-by: caozhou <caozhou@radi.ac.cn>
Co-authored-by: caozhou <48191911+Caozhou1995@users.noreply.github.com>

* [Auto Parallel] Bugfix allreduce fuse for MP (#46086)

* bugfix

* bugfix

* typos fixed

* update strategy (#46138)
Co-authored-by: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Co-authored-by: JZ-LIANG <jianzhongliang10@gmail.com>
Co-authored-by: zhaoyingli <zhaoyingli@baidu.com>
Co-authored-by: caozhou <caozhou@radi.ac.cn>
Co-authored-by: caozhou <48191911+Caozhou1995@users.noreply.github.com>
Parent 860f6077
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .interface import shard_tensor # noqa: F401
+from .strategy import Strategy
-from .interface import shard_op # noqa: F401
from .process_mesh import ProcessMesh
-from .reshard import Resharder # noqa: F401
+from .engine import Engine
-from .cost_model import estimate_cost
+from .interface import shard_tensor
+from .interface import shard_op
+from .interface import recompute
+from .interface import fetch
__all__ = []
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from collections import defaultdict
# _g_default_config[category][field] = default_value
_g_default_config = defaultdict(dict)
def get_category_default_config(category):
return _g_default_config[category]
def set_category_default_config(category, default_value):
_g_default_config[category] = default_value
def get_field_default_config(category, field):
return _g_default_config[category][field]
def set_field_default_config(category, field, default_value):
_g_default_config[category][field] = default_value
NOT_FOUND = "not_found"
#########################################
# base configuration
#########################################
BASE = "base"
set_field_default_config(BASE, "auto_mode", "semi")
set_field_default_config(BASE, "gradient_scale", True)
set_field_default_config(BASE, "use_cache", True)
set_field_default_config(BASE, "return_numpy", True)
set_field_default_config(BASE, "all_ranks", False)
set_field_default_config(BASE, "split_data", False)
set_field_default_config(BASE, "seed", None)
set_field_default_config(BASE, "reinit", False) # Only for debug
#########################################
# recompute configuration
#########################################
RECOMPUTE = "recompute"
set_field_default_config(RECOMPUTE, "enable", False)
set_field_default_config(RECOMPUTE, "checkpoints", None)
set_field_default_config(RECOMPUTE, "enable_tuning", False)
#########################################
# AMP configuration
#########################################
AMP = "amp"
set_field_default_config(AMP, "enable", False)
set_field_default_config(AMP, "init_loss_scaling", 32768.0)
set_field_default_config(AMP, "incr_every_n_steps", 1000)
set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2)
set_field_default_config(AMP, "incr_ratio", 2.0)
set_field_default_config(AMP, "decr_ratio", 0.8)
set_field_default_config(AMP, "use_dynamic_loss_scaling", True)
set_field_default_config(AMP, "custom_white_list", [])
set_field_default_config(AMP, "custom_black_list", [])
set_field_default_config(AMP, "custom_black_varnames", [])
set_field_default_config(AMP, "use_pure_fp16", False)
set_field_default_config(AMP, "use_fp16_guard", True)
set_field_default_config(AMP, "use_optimizer_fp16", False)
#########################################
# sharding configuration
#########################################
SHARDING = "sharding"
set_field_default_config(SHARDING, "enable", False)
set_field_default_config(SHARDING, "stage", 1)
set_field_default_config(SHARDING, "degree", 8)
set_field_default_config(SHARDING, "segment_broadcast_MB", 32.0)
set_field_default_config(SHARDING, "enable_tuning", False)
set_field_default_config(SHARDING, "tuning_range", [])
#########################################
# gradient merge configuration
#########################################
GRADIENT_MERGE = "gradient_merge"
set_field_default_config(GRADIENT_MERGE, "enable", False)
set_field_default_config(GRADIENT_MERGE, "k_steps", 1)
set_field_default_config(GRADIENT_MERGE, "avg", True)
#########################################
# quantization configuration
#########################################
QAT = "qat"
set_field_default_config(QAT, "enable", False)
set_field_default_config(QAT, "channel_wise_abs_max", True)
set_field_default_config(QAT, "weight_bits", 8)
set_field_default_config(QAT, "activation_bits", 8)
set_field_default_config(QAT, "not_quant_pattern", ['skip_quant'])
set_field_default_config(QAT, "algo", None)
# #########################################
# auto tuning configuration
# #########################################
TUNING = "tuning"
set_field_default_config(TUNING, "enable", False)
set_field_default_config(TUNING, "batch_size", 1)
set_field_default_config(TUNING, "dataset", None)
set_field_default_config(TUNING, "profile_start_step", 1)
set_field_default_config(TUNING, "profile_end_step", 1)
set_field_default_config(TUNING, "run_after_tuning", True)
set_field_default_config(TUNING, "verbose", True)
@@ -16,7 +16,7 @@ import paddle
import warnings
import logging
import numpy as np
-from ..utils import get_logger
+from .utils import get_logger
class Converter(object):
...
@@ -173,6 +173,17 @@ class TensorDistributedAttribute:
def clear_annotated(self):
self._is_annotated.clear()
def __eq__(self, other):
if not isinstance(other, TensorDistributedAttribute):
return False
if self.process_mesh != other.process_mesh:
return False
if self.dims_mapping != other.dims_mapping:
return False
if self._is_annotated != other._is_annotated:
return False
return True
def __str__(self):
str = "\n\ttensor_dist_attr = {"
if self.is_annotated("process_mesh"):
@@ -486,6 +497,27 @@ class OperatorDistributedAttribute:
else:
return False
def __eq__(self, other):
if not isinstance(other, OperatorDistributedAttribute):
return False
if self.process_mesh != other.process_mesh:
return False
if self.op_type != other.op_type:
return False
if self.impl_type != other.impl_type:
return False
if self.impl_idx != other.impl_idx:
return False
if self._is_annotated != other._is_annotated:
return False
if self._is_recompute != other._is_recompute:
return False
if self.inputs_dist_attrs != other.inputs_dist_attrs:
return False
if self.outputs_dist_attrs != other.outputs_dist_attrs:
return False
return True
def __str__(self):
str = "\n\top_dist_attr = {"
if self.is_annotated("process_mesh"):
...
@@ -126,9 +126,6 @@ class DistributedContext:
# A flag indicates whether the used parallelism is data parallel
self._data_parallel = False
-# flag whether using `to_static`
-self._dygraph_mode = False
@property
def serial_main_program(self):
return self._serial_main_program
...
@@ -23,6 +23,7 @@ from .dist_attribute import append_op_input_suffix
from .dist_attribute import append_op_output_suffix
from .dist_attribute import get_tensor_dist_attr_field_keys
from .dist_attribute import get_op_dist_attr_field_keys
+from .utils import convert_to_shard_spec, verify_shard_spec
class DistributedOperator:
@@ -248,23 +249,106 @@ class DistributedOperator:
return result
-class DistributedModule:
+class DistributedOperatorHelper:
-def __init__(self, serial_module, dist_attr=None):
+def __init__(self, serial_op, process_mesh, in_dims_mappings,
-self._serial_module = serial_module
+out_dims_mappings):
-self._dist_attr = dist_attr
+self._serial_op = serial_op
+self._process_mesh = process_mesh
+self._in_dims_mappings = in_dims_mappings
+self._out_dims_mappings = out_dims_mappings
def __call__(self, *args, **kwargs):
-from .dist_context import get_default_distributed_context
+tensor_to_dims_mapping = {}
index = 0
if self._in_dims_mappings:
assert len(args) + len(kwargs) == len(self._in_dims_mappings), \
"The length of dims_mapping {} does not matching the length output {}.".format(len(self._in_dims_mappings), len(args) + len(kwargs))
for arg in args:
if isinstance(arg, Variable) and self._in_dims_mappings:
tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index]
index += 1
for arg in kwargs.values() and self._in_dims_mappings:
if isinstance(arg, Variable):
tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index]
index += 1
default_prog = paddle.fluid.default_main_program()
cur_block = default_prog.current_block()
op_size = len(cur_block.ops)
-output = self._serial_module(*args, **kwargs)
+output = self._serial_op(*args, **kwargs)
new_op_size = len(cur_block.ops)
if isinstance(output, tuple) or isinstance(output, list):
new_output = list(output)
elif isinstance(output, Variable):
new_output = [output]
else:
raise ValueError("Unrecognized outpout.")
if self._out_dims_mappings:
assert len(new_output) == len(self._out_dims_mappings), \
"The length of dims_mapping {} does not matching the length output {}.".format(len(self._out_dims_mappings), len(new_output))
for i, item in enumerate(new_output):
if isinstance(item, Variable) and self._out_dims_mappings:
tensor_to_dims_mapping[item.name] = self._out_dims_mappings[i]
from .dist_context import get_default_distributed_context
default_dist_ctx = get_default_distributed_context()
for idx in range(op_size, new_op_size):
op = cur_block.ops[idx]
-dist_op = DistributedOperator(op, self._dist_attr)
+dist_op = DistributedOperator(op)
-dist_op.dist_attr.mark_annotated_as(self._dist_attr)
+for name in dist_op.serial_op.input_arg_names:
if name in tensor_to_dims_mapping.keys():
tensor = dist_op.get_serial_input(name)
tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(
name)
dims_mapping = tensor_to_dims_mapping[name]
if tensor is None:
tensor_shape = []
else:
if tensor.type == core.VarDesc.VarType.READER \
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = []
else:
tensor_shape = tensor.shape
if dims_mapping is not None:
dims_mapping = tensor_to_dims_mapping[name]
shard_spec = convert_to_shard_spec(
dims_mapping, self._process_mesh)
assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
name, shard_spec, tensor_shape, self._process_mesh)
tensor_dist_attr.dims_mapping = dims_mapping
tensor_dist_attr.mark_annotated("dims_mapping")
for name in dist_op.serial_op.output_arg_names:
if name in tensor_to_dims_mapping.keys():
tensor = dist_op.get_serial_output(name)
tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(
name)
dims_mapping = tensor_to_dims_mapping[name]
if tensor is None:
tensor_shape = []
else:
if tensor.type == core.VarDesc.VarType.READER \
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = []
else:
tensor_shape = tensor.shape
if dims_mapping is not None:
dims_mapping = tensor_to_dims_mapping[name]
shard_spec = convert_to_shard_spec(
dims_mapping, self._process_mesh)
assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
name, shard_spec, tensor_shape, self._process_mesh)
tensor_dist_attr.dims_mapping = dims_mapping
tensor_dist_attr.mark_annotated("dims_mapping")
dist_op.dist_attr.process_mesh = self._process_mesh
if self._process_mesh is not None:
dist_op.dist_attr.mark_annotated("process_mesh")
default_dist_ctx.add_dist_op_for_program(dist_op)
return output
@@ -27,7 +27,7 @@ from paddle.fluid.framework import static_only
from .utils import get_dist_attr
from .converter import Converter
from .process_group import _g_process_group_map
-from ..utils import get_logger
+from .utils import get_logger
def check_filename(re_exp, filename):
@@ -59,6 +59,14 @@ class DistributedSaver:
def save(self, path, serial_program, dist_main_program, dist_context):
def _save_state(program, path, mode="param"):
state = {
k: np.array(v)
for k, v in program.state_dict(mode).items()
}
with open(path, "wb") as f:
pickle.dump(state, f)
dirname, filename = _process_path(path)
rank_id = paddle.distributed.get_rank()
@@ -76,16 +84,6 @@
with open(dist_model_path, "wb") as f:
f.write(dist_main_program.desc.serialize_to_string())
# save distributed params
dist_param_filename = filename + "_dist" + str(rank_id) + ".pdparams"
dist_param_path = os.path.join(dirname, dist_param_filename)
dist_param = {
k: np.array(v)
for k, v in dist_main_program.state_dict().items()
}
with open(dist_param_path, "wb") as f:
pickle.dump(dist_param, f)
# save distributed attribute
dist_attr_filename = filename + "_dist" + str(rank_id) + ".pdattr"
dist_attr_path = os.path.join(dirname, dist_attr_filename)
@@ -93,65 +91,69 @@
with open(dist_attr_path, "wb") as f:
pickle.dump(dist_attrs, f)
# save distributed params
dist_param_filename = filename + "_dist" + str(rank_id) + ".pdparams"
dist_param_path = os.path.join(dirname, dist_param_filename)
_save_state(dist_main_program, dist_param_path)
# save distributed opt states
dist_opt_filename = filename + "_dist" + str(rank_id) + ".pdopt"
dist_opt_path = os.path.join(dirname, dist_opt_filename)
_save_state(dist_main_program, dist_opt_path, "opt")
# TODO:save cluster.json
-def load(self,
-path,
-program,
-dist_context,
-strict=True,
-load_optimizer=True):
+def load(self, path, load_optimizer=True):
# TODO: if `program` is None, load `path.pdmodel`.
def _load_file(filename, dirname, suffix="pdparams"):
file_list = []
for file in os.listdir(dirname):
if check_filename('{}(.*)_dist(.*).{}'.format(filename, suffix),
file):
file_list.append(os.path.join(dirname, file))
file_list.sort()
return file_list
def _load_state(filename, dirname, suffix="pdparams"):
file_list = _load_file(filename, dirname, suffix)
state_dict = {}
for file in file_list:
with open(file, 'rb') as f:
state_dict_info = pickle.load(f, encoding='latin1')
for name, value in state_dict_info.items():
if name in state_dict:
state_dict[name].append(np.array(value))
else:
state_dict[name] = [np.array(value)]
self._logger.info("Load param file: {}".format(file_list))
return state_dict
filename = os.path.basename(path)
if filename == "":
raise ValueError(
"path should be of 'dirname/filename' format, but received filename is empty string"
)
dirname = os.path.dirname(path)
-# load path.pdparam
-param_file_list = []
-for param_file in os.listdir(dirname):
-if check_filename('{}(.*)_dist(.*).pdparams'.format(filename),
-param_file):
-param_file_list.append(os.path.join(dirname, param_file))
-param_file_list.sort()
-self._logger.info(
-"Load distributed attribute file: {}".format(param_file_list))
-param_dict = {}
-for param_file in param_file_list:
-with open(param_file, 'rb') as f:
-state_dict_info = pickle.load(f, encoding='latin1')
-for name, value in state_dict_info.items():
-if name in param_dict:
-param_dict[name].append(np.array(value))
-else:
-param_dict[name] = [np.array(value)]
+# load path.pdparam and path.pdopt
+param_state_dict = _load_state(filename, dirname)
+opt_state_dict = _load_state(filename, dirname,
+"pdopt") if load_optimizer else {}
+state_dict = dict(param_state_dict, **opt_state_dict)
# load path.pdattr
-dist_attr_file_list = []
-for dist_attr_file in os.listdir(dirname):
-if check_filename('{}(.*)_dist(.*).pdattr'.format(filename),
-dist_attr_file):
-dist_attr_file_list.append(os.path.join(dirname,
-dist_attr_file))
-dist_attr_file_list.sort()
+dist_attr_file_list = _load_file(filename, dirname, "pdattr")
self._logger.info(
"Load distributed attribute file: {}".format(dist_attr_file_list))
-pre_dist_attr = {}
+dist_attr = {}
for dist_attr_file in dist_attr_file_list:
with open(dist_attr_file, 'rb') as f:
-dist_attr = pickle.load(f, encoding='latin1')
+dist_attr_info = pickle.load(f, encoding='latin1')
-for name, attr in dist_attr.items():
+for name, attr in dist_attr_info.items():
-if name not in pre_dist_attr:
+if name not in dist_attr:
-pre_dist_attr[name] = attr
+dist_attr[name] = attr
-# get current dist_attr
-cur_dist_attr = get_dist_attr(program, dist_context)
-# param convert
-converter = Converter(param_dict, pre_dist_attr, cur_dist_attr)
-param_dict = converter.convert(strict=strict)
-program.set_state_dict(param_dict)
+return state_dict, dist_attr
def save_inference_model(self, path, feed_vars, fetch_vars, exe, **kwargs):
...
@@ -19,13 +19,13 @@ import paddle
from paddle.nn import Layer
from paddle.jit import to_static, not_to_static
-from paddle.distributed.utils import get_logger
from paddle.fluid.framework import Operator, Parameter, _non_static_mode
from paddle.fluid.framework import program_guard
from paddle.fluid.executor import global_scope
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction
from .utils import to_list
+from .utils import get_logger
from .converter import Converter
...
@@ -12,101 +12,199 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import numpy
-import copy
import paddle
-import paddle.fluid.core as core
+from paddle.fluid import core
-from paddle.fluid.framework import Variable
+from .process_mesh import ProcessMesh
-from paddle.fluid.framework import _non_static_mode
+from .process_mesh import get_current_process_mesh
+from .process_mesh import set_current_process_mesh
+from .process_mesh import reset_current_process_mesh
from .dist_context import get_default_distributed_context
from .dist_tensor import DistributedTensor
-from .dist_op import DistributedModule
+from .dist_op import DistributedOperatorHelper
-from .dist_attribute import TensorDistributedAttribute
+from .utils import verify_shard_spec, convert_to_dims_mapping
-from .dist_attribute import OperatorDistributedAttribute
-def _static_mode_check():
+def shard_tensor(x, process_mesh=None, shard_spec=None):
-if _non_static_mode():
-raise RuntimeError("Auto-parallel only supports static mode for now, "
-"please use paddle.enable_static() first.")
-def shard_tensor(x, dist_attr=None):
""" """
Add distributed attributes for a tensors. Shard a tensor on a process mesh according to the shard specification.
Args: Args:
x (Tensor): the tensor to be sharded. x (Tensor): the tensor to be sharded.
dist_attr (dict): the tensor distributed attributes. The accepted attributes are as follow: process_mesh (ProcessMesh, optional): An instance of ProcessMesh describes a mesh
"process_mesh": a nested list an to describe the mesh topology of logical processes. topology of the used logical processes where the tensor is sharded. If it is None,
"dims_mapping": a list to describe the mapping between `x` and `process_mesh`, the dimension the found current process mesh will be used. And an error will be raised if the
`i` of `x` is split across the dimension `dims_mapping[i]` of `process_mesh`, current process mesh cannot be found. Default: None.
where -1 means that tensor dimension is not split. shard_spec (list, optional): a list to describe the sharding mapping between `x` and `process_mesh`,
Both process_mesh and dims_mapping are optional and users can specify as need. which means the dimension `i` of `x` is split across the dimension `shard_spec[i]` of `process_mesh`,
where `None` means that tensor dimension is not split. For example, given a tensor wih
the shape [6, 12] and a process mesh with the shape [2, 3] and the dimension names ["x", "y"]:
If `shard_spec=["x", "y"]`, each shard of the tensor will have a shape [3, 4];
If `shard_spec=["y", "x"]`, each shard of the tensor will have a shape [2, 6];
If `shard_spec=["x", None]`, each shard of the tensor will have a shape [3, 12];
If `shard_spec=[None, "x"]`, each shard of the tensor will have a shape [6, 4];
If `shard_spec=["y", None]`, each shard of the tensor will have a shape [2, 12];
If `shard_spec=[None, "y"]`, each shard of the tensor will have a shape [6, 4];
If `shard_spec=[None, None]`, each shard of the tensor will have a shape [6, 12];
If the `shard_spec` is None, the tensor will be replicated across all the processes of `process_mesh`.
In the above example, the `shard_spec=None` is same as 'shard_spec=[None, None]'. Defaults: None.
Returns: Returns:
Tensor: the tensor `x` annotated with distributed attributes. Tensor: the tensor `x` annotated with sharding information.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.distributed as dist import paddle.distributed.auto_parallel as auto
paddle.enable_static()
mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
x = paddle.ones([4, 6]) x = paddle.ones([4, 6])
dist.shard_tensor(x, dist_attr={"process_mesh": [[0, 1], [2, 3]], shard_spec = ["x", "y"]
"dims_mapping": [0, -1]}) auto.shard_tensor(x, mesh, shard_spec)
""" """
-_static_mode_check()
-assert dist_attr is None or isinstance(dist_attr, (dict, TensorDistributedAttribute)), \
-"The type of dist_attr must be None, dict or TensorDistributedAttribute."
-dist_tensor = DistributedTensor(x, dist_attr)
-dist_tensor.dist_attr.mark_annotated_as(dist_attr)
+if process_mesh is not None:
+assert isinstance(process_mesh, ProcessMesh), \
+"Argument process_mesh {} is not an instance of ProcessMesh".format(process_mesh)
+else:
process_mesh = get_current_process_mesh()
assert process_mesh is not None, \
"Specify the process mesh argument or use ProcessMesh context manager first."
assert isinstance(shard_spec, list), \
"Argument shard_spec {} is not an instance of list".format(shard_spec)
dist_tensor = DistributedTensor(x)
serial_tensor = dist_tensor.serial_tensor
dist_tensor.dist_attr.process_mesh = process_mesh
if serial_tensor.type == core.VarDesc.VarType.READER \
or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = []
else:
tensor_shape = serial_tensor.shape
if shard_spec is not None:
assert verify_shard_spec(shard_spec, tensor_shape, process_mesh), \
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
serial_tensor.name, shard_spec, tensor_shape, process_mesh)
dist_tensor.dist_attr.dims_mapping = convert_to_dims_mapping(
shard_spec, process_mesh)
if process_mesh is not None:
dist_tensor.dist_attr.mark_annotated("process_mesh")
if shard_spec is not None:
dist_tensor.dist_attr.mark_annotated("dims_mapping")
default_dist_ctx = get_default_distributed_context()
default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
+dist_tensor = default_dist_ctx.get_dist_tensor_for_program(x)
return x
-def shard_op(op_fn, dist_attr=None):
+def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None):
"""
-Call a functioin and add distributed attributes for ops added by the function.
+Shard an operation on a process mesh according to its input and output shard specification.
Args:
-op_fn (callable): a callable operator or module to be sharded.
+op (Callable): a callable operator or module to be sharded.
-dist_attr (dict): the operator distributed attributes. The accepted attributes are classified into
+process_mesh (ProcessMesh, optional): An instance of ProcessMesh describes a mesh
-two categories. The first category decsribes the distributed attributes shared by all inputs and
+topology of the used logical processes where the op is sharded. All of its inputs and
-outputs, and only `process_mesh` can be specified now. The second category describes distributed
+outputs are sharded by this process mesh. If it is None, the found current process mesh
-attributes for inputs or outputs same as the `dist_attr` of `shard_tensor`. All of them are
+will be used. And an error will be raised if the current process mesh cannot be found.
-optional and users can specify them as need. Note that `process_mesh` for operators must be the
+Default: None.
-same as these process_meshes for inputs and outputs.
+in_shard_specs (list of list, optional): a list of list to describe the sharding specifications
+for the inputs. Each item of `in_shard_specs` is a `shard_spec` between the correspoinding input
+and `process_mesh`. If one item is None, the cooresponding input is replicated across all processes
+If it is None, all inputs are replicated accross all processes. Note that the lenght of the
+`in_shard_specs` should be equal to the actual number of inputs when calling this operation.
+Default: None.
+out_shard_specs (list of list, optional): a list of list to describe the sharding specifications
+for the outputs. Each item of `out_shard_specs` is a `shard_spec` between the correspoinding output
+and `process_mesh`. If one item is None, the cooresponding output is replicated across all processes
+If it is None, all outputs are replicated accross all processes. Note that the lenght of the
+`in_shard_specs` should be equal to the actual number of inputs when calling this operation.
+Default: None.
Returns:
-list: the outputs of the function `op_fn`, which are annotated with distributed attributes.
+Outputs of `op`, each of which is annotated with sharding information.
Examples:
.. code-block:: python
import paddle
-import paddle.distributed as dist
+import paddle.distributed.auto_parallel as auto
-paddle.enable_static()
x = paddle.ones([4, 6])
y = paddle.zeros([4, 6])
-dist_add = dist.shard_op(paddle.add,
-dist_attr={
-"process_mesh": [[2, 3, 1], [0, 4, 5]],
-x: {"dims_mapping": [-1, 0]},
-y: {"dims_mapping": [0, -1]}
-})
+mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
+dist_add = auto.shard_op(paddle.add,
+in_shard_specs=[["x", "y"], ["y", None]],
+out_shard_specs=[[None, "x"]])
dist_add(x, y)
"""
-_static_mode_check()
-assert dist_attr is None or isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \
-"The type of dist_attr must be dict or OperatorDistributedAttribute."
-dist_module = DistributedModule(op_fn, dist_attr)
-return dist_module
+if process_mesh is not None:
+assert isinstance(process_mesh, ProcessMesh), \
+"Argument process_mesh {} is not an instance of ProcessMesh".format(process_mesh)
+else:
process_mesh = get_current_process_mesh()
assert process_mesh is not None, \
"Specify the process mesh argument or use ProcessMesh context manager first."
in_dims_mappings = []
if in_shard_specs is not None:
assert all((isinstance(shard_spec, list) or shard_spec is None) for shard_spec in in_shard_specs), \
"in_shard_spec {} is not a list of list or None".format(in_shard_specs)
for shard_spec in in_shard_specs:
if shard_spec is not None:
in_dims_mappings.append(
convert_to_dims_mapping(shard_spec, process_mesh))
else:
in_dims_mappings.append(None)
out_dims_mappings = []
if out_shard_specs is not None:
assert all((isinstance(shard_spec, list) or shard_spec is None) for shard_spec in out_shard_specs), \
"out_shard_spec {} is not a list of list or None".format(out_shard_specs)
for shard_spec in out_shard_specs:
if shard_spec is not None:
out_dims_mappings.append(
convert_to_dims_mapping(shard_spec, process_mesh))
else:
out_dims_mappings.append(None)
op = DistributedOperatorHelper(op, process_mesh, in_dims_mappings,
out_dims_mappings)
return op
def recompute(op):
class RecomputeOperator:
def __init__(self, op):
self._op = op
def __call__(self, *args, **kwargs):
default_prog = paddle.fluid.default_main_program()
cur_block = default_prog.current_block()
op_size = len(cur_block.ops)
output = self._op(*args, **kwargs)
new_op_size = len(cur_block.ops)
for idx in range(op_size, new_op_size):
op = cur_block.ops[idx]
op._set_attr("is_recompute@auto_parallel", True)
return output
return RecomputeOperator(op)
_g_fetched_tensors = {}
def fetch(tensor, name=None):
if name is None:
_g_fetched_tensors[tensor.name] = tensor
else:
_g_fetched_tensors[name] = tensor
def _get_fetches():
return _g_fetched_tensors
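Taken together, a minimal usage sketch of the new interface, following the docstrings and implementations above. The [6, 12] shapes are chosen so that every shard_spec divides evenly, and paddle.enable_static() is added because the annotations operate on static-graph Variables; treat this as an illustrative sketch rather than the official example.

import paddle
import paddle.distributed.auto_parallel as auto

paddle.enable_static()

mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])

x = paddle.ones([6, 12])
y = paddle.zeros([6, 12])

# Dim 0 of x is split along mesh dim "x", dim 1 along "y".
auto.shard_tensor(x, mesh, ["x", "y"])

# Wrap paddle.add so the ops it adds are annotated when it is called.
dist_add = auto.shard_op(paddle.add,
                         mesh,
                         in_shard_specs=[["x", "y"], ["y", None]],
                         out_shard_specs=[[None, "x"]])
out = dist_add(x, y)

# Mark the ops created by a callable for recomputation and register a fetch.
out = auto.recompute(paddle.nn.functional.relu)(out)
auto.fetch(out, "relu_out")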
@@ -42,6 +42,7 @@ from .utils import make_data_unshard
from .utils import set_grad_var_shape
from .utils import print_program_with_dist_attr
from .utils import SerialProgramInfo
+from .utils import get_logger
from .reshard import Resharder
from .cluster import Cluster
from .mapper import mapping
@@ -84,7 +85,7 @@ class AutoParallelizer:
self._need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING")
self._need_rank_mapping = True if self._need_rank_mapping and \
self._need_rank_mapping.lower() == 'true' else False
-self._pass_context = None
+# self._pass_context = None
def _remove_distributed_attrs(self, main_program):
suffix = core.kAutoParallelSuffix()
@@ -143,10 +144,11 @@ class AutoParallelizer:
def _apply_optimize(self, main_program, startup_program, params_grads):
+optimizer = copy.deepcopy(self._optimizer)
with program_guard(main_program, startup_program):
-optimize_ops = copy.deepcopy(
-self._optimizer).apply_gradients(params_grads)
+optimize_ops = optimizer.apply_gradients(params_grads)
+self._dist_context._lr_optimizer = optimizer
# update completion
self._completer = Completer(self._dist_context)
self._completer.complete_update_annotation(main_program)
@@ -165,6 +167,15 @@
config)
auto_parallel_sharding_pass.apply([main_program], [startup_program],
self._pass_context)
params_grads = self._pass_context.get_attr("params_grads")
config = copy.deepcopy(self._dist_strategy.sharding_configs)
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["rank_id"] = rank
auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip", config)
auto_parallel_clip_pass.apply([main_program], [startup_program],
self._pass_context)
if self._dist_strategy.gradient_merge:
config = copy.deepcopy(self._dist_strategy.gradient_merge_configs)
...
@@ -22,7 +22,6 @@ from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import _non_static_mode, unique_name
from paddle.distributed.passes import new_pass
-from paddle.distributed.utils import get_logger
from .reshard import Resharder
from .partitioner import Partitioner
@@ -31,6 +30,7 @@ from .dist_saver import DistributedSaver
from .dist_loader import NonIterableGeneratorLoader
from .utils import make_data_unshard, set_grad_var_shape
from .utils import print_program_with_dist_attr, to_list
+from .utils import get_logger
from .process_group import get_all_process_groups, get_world_process_group
from .dist_context import DistributedContext, get_default_distributed_context
@@ -160,8 +160,8 @@
# apply quantization pass
# The pass can be applied when mode must be 'train'
-if self._mode == 'train' and self._strategy.qat:
+if self._mode == 'train' and self._strategy.qat.enable:
-config = copy.deepcopy(self._strategy.qat_configs)
+config = copy.deepcopy(self._strategy.qat.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
auto_parallel_quantization_pass = new_pass(
@@ -176,8 +176,8 @@
# apply amp pass
# FIXME we disenable amp for eval since it has a little bug with
# eval program and which will be fixed in future
-if self._mode == 'train' and self._strategy.amp:
+if self._mode == 'train' and self._strategy.amp.enable:
-config = copy.deepcopy(self._strategy.amp_configs)
+config = copy.deepcopy(self._strategy.amp.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["loss"] = loss
@@ -195,8 +195,8 @@
# apply recompute pass
# recompute is then train-only optimization
-if self._mode == "train" and self._strategy.recompute:
+if self._mode == "train" and self._strategy.recompute.enable:
-config = copy.deepcopy(self._strategy.recompute_configs)
+config = copy.deepcopy(self._strategy.recompute.to_dict())
config["dist_context"] = self._dist_context
config["no_grad_set"] = None
config["loss"] = loss
@@ -217,12 +217,12 @@
config = {}
config["dist_context"] = self._dist_context
config["global_rank"] = rank
-config["use_sharding"] = self._strategy.sharding
+config["use_sharding"] = self._strategy.sharding.enable
dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
dp_pass.apply([main_program], [startup_program], self._pass_context)
-if self._strategy.sharding:
+if self._strategy.sharding.enable:
-config = copy.deepcopy(self._strategy.sharding_configs)
+config = copy.deepcopy(self._strategy.sharding.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["global_rank"] = rank
@@ -230,11 +230,11 @@
config)
auto_parallel_sharding_pass.apply([main_program], [startup_program],
self._pass_context)
+params_grads = self._pass_context.get_attr("params_grads")
# GradClip is train-only optimization
if self._mode == "train":
-config = copy.deepcopy(self._strategy.sharding_configs)
+config = copy.deepcopy(self._strategy.sharding.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["rank_id"] = rank
@@ -244,8 +244,8 @@
self._pass_context)
# gradient_merge is then train-only optimization
-if self._mode == "train" and self._strategy.gradient_merge:
+if self._mode == "train" and self._strategy.gradient_merge.enable:
-config = copy.deepcopy(self._strategy.gradient_merge_configs)
+config = copy.deepcopy(self._strategy.gradient_merge.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
auto_parallel_gradient_merge_pass = new_pass(
...
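Each optimization block in this method follows the same pattern: read the per-category config from the strategy, extend it with runtime objects, and hand it to a named pass. A condensed sketch of the grad-clip case above; the helper name is hypothetical, while new_pass and the "auto_parallel_grad_clip" pass name are taken from the diff.

import copy
from paddle.distributed.passes import new_pass

def apply_grad_clip(strategy, dist_context, params_grads, rank,
                    main_program, startup_program, pass_context):
    # Hypothetical helper mirroring the train-only grad-clip block above.
    config = copy.deepcopy(strategy.sharding.to_dict())
    config["dist_context"] = dist_context
    config["params_grads"] = params_grads
    config["rank_id"] = rank
    grad_clip_pass = new_pass("auto_parallel_grad_clip", config)
    grad_clip_pass.apply([main_program], [startup_program], pass_context)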
@@ -12,86 +12,90 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import numpy
+import numpy as np
import copy
import paddle
# Use to store the previous and current process mesh
_g_previous_process_mesh = None
_g_current_process_mesh = None
def _get_nested_list_shape(nested_list):
"""
Get the shape of a nested_list.
"""
result = []
while isinstance(nested_list, list):
result.append(len(nested_list))
nested_list = nested_list[0]
return result
def get_current_process_mesh():
global _g_current_process_mesh
return _g_current_process_mesh
def _flatten_nested_list(nested_list):
"""
Get a list of all items in a nested_list.
Ref: https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-a-list-of-lists
"""
result = numpy.array(nested_list).flatten().tolist()
return result
def set_current_process_mesh(process_mesh):
global _g_previous_process_mesh
global _g_current_process_mesh
_g_previous_process_mesh = _g_current_process_mesh
_g_current_process_mesh = process_mesh
class ProcessMesh(object):
r"""
The class `Processmesh` describes the topology of logical processes.
A mesh is an N-dimensional array. The shape of the N-dimensional
array represents the topology of logical processes and every
element of the N-dimensional array represent a logical process. For
example, the 2-dimensional array [[2, 4, 5], [0, 1, 3]]
illustrates six logical processes organized as the topology [2, 3],
i.e., the shape of the 2-dimensional array. With the above topology,
there are two parallel groups, where the first parallel group has a
parallel degree of 2 and the second one has a parallel degree of 3.
And the first logical process is the one with id=2.
Args: def reset_current_process_mesh():
mesh (list): an N-dimensional array (nested list) describes the toplogy global _g_previous_process_mesh
of logical processes. The shape of the N-dimensional array global _g_current_process_mesh
represents the topology of logical processes and every _g_current_process_mesh = _g_previous_process_mesh
element of the N-dimensional array represents a logical process.
Returns: class ProcessMesh(object):
None """
The `Processmesh` object describes the topology of the used processes.
Raises: Args:
ValueError: If `mesh` is not an instance of list. mesh (list|numpy.array): an n-dimensional array describes the toplogy
of the processes.
dim_names (list, optional): the i-th element of this list gives the name of the
i-th dimension of the mesh.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]]) mesh = auto.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
assert mesh.topology == [2, 3] assert mesh.shape == [2, 3]
assert mesh.processes == [2, 4, 5, 0, 1, 3] assert mesh.processe_ids == [2, 4, 5, 0, 1, 3]
""" """
def __init__(self, mesh): def __init__(self, mesh=None, dim_names=None, shape=None, process_ids=None):
if mesh is None or not isinstance(mesh, list): # Use shape and process_ids just for compatibility
raise ValueError('mesh must be an instance of list.') # Users should not use these directly
if mesh is None:
processes = _flatten_nested_list(mesh) assert shape is not None
assert process_ids is not None
assert all(isinstance(p, int) for p in processes), \ mesh = np.array(process_ids).reshape(shape)
("All elements of mesh must be integer")
if not isinstance(mesh, list) and \
assert min(processes) >= 0, ('All elements of mesh must be >= 0.') not isinstance(mesh, np.ndarray):
raise ValueError(
unique_processes = set(processes) 'The mesh must be an instance of list or np.ndarray.')
assert len(unique_processes) == len(processes), ( if isinstance(mesh, list):
'All elements of mesh must be unique.') mesh = np.array(mesh)
self._topology = _get_nested_list_shape(mesh) self._mesh = mesh
self._processes = processes self._shape = list(self._mesh.shape)
self._process_ids = self._mesh.flatten().tolist()
assert all(isinstance(p, int) for p in self._process_ids), \
("All elements of the mesh must be integer")
assert min(
self._process_ids) >= 0, ('All elements of the mesh must be >= 0.')
unique_process_ids = set(self._process_ids)
assert len(unique_process_ids) == len(
self._process_ids), ('All elements of the mesh must be unique.')
if dim_names is not None:
assert len(dim_names) == len(self._shape), \
("The length of dims_names must be same as the shape of the mesh.")
self._dim_names = copy.deepcopy(dim_names)
else:
self._dim_names = ["d" + str(i) for i in range(len(self._shape))]
unique_dim_names = set(self._dim_names)
assert len(unique_dim_names) == len(self._dim_names), (
'All dim_names {} must be unique.'.format(dim_names))
# Store all process meshes # Store all process meshes
from .dist_context import get_default_distributed_context from .dist_context import get_default_distributed_context
...@@ -103,31 +107,117 @@ class ProcessMesh(object): ...@@ -103,31 +107,117 @@ class ProcessMesh(object):
pg0.add_ranks(self.processes) pg0.add_ranks(self.processes)
@property @property
def topology(self): def shape(self):
r""" """
Get the topology of logical processes belonging to this ProcessMesh. Get the shape of this ProcessMesh.
This is the shape of `mesh` used to initialized this ProcessMesh.
""" """
return self._topology return self._shape
@property @property
def processes(self): def process_ids(self):
r""" """
Get a list of all processes belonging to this ProcessMesh. Get the process ids belonging to this ProcessMesh.
""" """
return self._processes return self._process_ids
@property
def dim_names(self):
"""
Get the dimension names of this ProcessMesh.
"""
return self._dim_names
@property @property
def ndim(self): def ndim(self):
r"""
Get the number of dimension of ProcessMesh.
""" """
return len(self._topology) Get the number of dimension of this ProcessMesh.
"""
return len(self._shape)
@property
def mesh(self):
"""
Get the underlying mesh of ProcessMesh.
"""
return self._mesh
@property
def topology(self):
return self._shape
@property
def processes(self):
return self._process_ids
def __getitem__(self, index):
if isinstance(index, tuple):
new_dim_names = []
for i, item in enumerate(index):
if isinstance(item, slice):
new_dim_names.append(self._dim_names[i])
new_mesh = self._mesh[index]
if new_mesh.shape:
return ProcessMesh(new_mesh, new_dim_names)
else:
# Wrap a scalar into a list but without dim_names
return ProcessMesh([new_mesh])
elif isinstance(index, slice):
new_mesh = self._mesh[index]
new_dim_names = self._dim_names
return ProcessMesh(new_mesh, new_dim_names)
else:
new_mesh = self._mesh[index]
new_dim_names = self._dim_names[1:]
return ProcessMesh(new_mesh, new_dim_names)
def __enter__(self):
set_current_process_mesh(self)
default_prog = paddle.fluid.default_main_program()
cur_block = default_prog.current_block()
self._old_var_names = list(cur_block.vars.keys())
self._old_op_size = len(cur_block.ops)
def __exit__(self, exc_type, exc_value, exc_traceback):
from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator
default_prog = paddle.fluid.default_main_program()
cur_block = default_prog.current_block()
new_var_names = list(cur_block.vars.keys())
new_op_size = len(cur_block.ops)
from .dist_context import get_default_distributed_context
default_dist_ctx = get_default_distributed_context()
for name in new_var_names:
if name not in self._old_var_names:
tensor = cur_block.vars[name]
dist_tensor = default_dist_ctx.get_dist_tensor_for_program(
tensor)
if dist_tensor is None:
dist_tensor = DistributedTensor(cur_block.vars[name],
{"process_mesh": self})
dist_tensor.dist_attr.mark_annotated("process_mesh")
default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
else:
if dist_tensor.dist_attr.process_mesh is None:
dist_tensor.dist_attr.process_mesh = self
dist_tensor.dist_attr.mark_annotated("process_mesh")
for idx in range(self._old_op_size, new_op_size):
op = cur_block.ops[idx]
dist_op = default_dist_ctx.get_dist_op_for_program(op)
if dist_op is None:
dist_op = DistributedOperator(op, {"process_mesh": self})
dist_op.dist_attr.mark_annotated("process_mesh")
default_dist_ctx.add_dist_op_for_program(dist_op)
else:
if dist_op.dist_attr.process_mesh is None:
dist_op.dist_attr.process_mesh = self
dist_op.dist_attr.mark_annotated("process_mesh")
reset_current_process_mesh()
def __eq__(self, other):
if not isinstance(other, ProcessMesh):
return False
-if self.topology != other.topology or self.processes != other.processes:
+if self.shape != other.shape or self.process_ids != other.process_ids:
return False
return True
@@ -135,6 +225,6 @@ class ProcessMesh(object):
return not self.__eq__(other)
def __str__(self):
-str = "shape {} and process group {}".format(self.topology,
-self.processes)
+str = "shape {}, process_ids {}, dim_nams {}".format(
+self.shape, self.process_ids, self.dim_names)
return str
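A minimal sketch of the redesigned ProcessMesh in use, based on the class above; paddle.enable_static() is added because the tensor created inside the context is a static-graph Variable.

import paddle
import paddle.distributed.auto_parallel as auto

paddle.enable_static()

mesh = auto.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
assert mesh.shape == [2, 3]
assert mesh.process_ids == [2, 4, 5, 0, 1, 3]

# Indexing returns a sub-mesh; an integer index drops the first dim name.
sub_mesh = mesh[0]
assert sub_mesh.process_ids == [2, 4, 5]

# Tensors and ops created inside the context are annotated with this mesh.
with mesh:
    x = paddle.ones([4, 6])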
...@@ -81,54 +81,57 @@ class ProcessMesh(core.ProcessMesh): ...@@ -81,54 +81,57 @@ class ProcessMesh(core.ProcessMesh):
return self._mesh return self._mesh
# def compute_compatible_process_meshes(process_meshes): def compute_compatible_process_mesh(process_meshes):
# """Compute the compatible process mesh given a list of process meshes.""" """Compute the compatible process mesh given a list of process meshes."""
# if not process_meshes: if not process_meshes:
# return None return None
# def _compute_compatible_two_process_meshes(pm1, pm2): def _compute_compatible_of_two_process_meshes(pm1, pm2):
# if pm1 is None: if pm1 is None:
# return True, pm2 return True, pm2
# if pm2 is None: if pm2 is None:
# return True, pm1 return True, pm1
# if pm1 == pm2: if pm1 == pm2:
# return True, pm1 return True, pm1
# if pm1.device_mesh != pm2.device_mesh: if pm1.process_ids == pm2.process_ids:
# return False, None if len(pm1.shape) >= len(pm2.shape):
# if pm1.process_ids == pm2.process_ids: return True, pm1
# if len(pm1.shape) >= len(pm2.shape): else:
# return True, pm1 return True, pm2
# else: process_set1 = set(pm1.process_ids)
# return True, pm2 process_set2 = set(pm2.process_ids)
# process_set1 = set(pm1.process_ids) if process_set1.issubset(process_set2):
# process_set2 = set(pm2.process_ids) return True, pm2
# if process_set1.issubset(process_set2): if process_set2.issubset(process_set1):
# return True, pm2 return True, pm1
# if process_set2.issubset(process_set1): return False, None
# return True, pm1
# return False, None compatible_result = None
for process_mesh in process_meshes:
# compatible_result = None compatible, compatible_result = _compute_compatible_of_two_process_meshes(
# for process_mesh in process_meshes: compatible_result, process_mesh)
# compatible, compatible_result = _compute_compatible_two_process_meshes( if not compatible:
# compatible_result, process_mesh) return None
# if not compatible: if compatible_result.empty():
# return None return None
# return ProcessMesh(compatible_result.mesh, compatible_result.dim_names) if isinstance(compatible_result, core.ProcessMesh):
mesh = np.array(compatible_result.process_ids).reshape(
# def merge_process_meshes(process_meshes): compatible_result.shape)
# """Merge a list of process meshes.""" return ProcessMesh(mesh, compatible_result.dim_names)
# merged_process_mesh = None elif isinstance(compatible_result, ProcessMesh):
# merged_process_ids = set() return ProcessMesh(compatible_result.mesh, compatible_result.dim_names)
# device_type = "" else:
# for process_mesh in process_meshes: raise ValueError("Unrecognized ProcessMesh.")
# if process_mesh is not None:
# process_ids = set(process_mesh.process_ids)
# if not device_type: def merge_process_mesh(process_meshes):
# device_type = process_mesh.device_type """Merge a list of process meshes."""
# assert device_type != process_mesh.device_type, \ merged_process_mesh = None
# "All process meshes must have the same device_type." merged_process_ids = set()
# merged_process_ids.union(process_ids) for process_mesh in process_meshes:
# if len(merged_process_ids) != 0: if process_mesh is not None:
# merged_process_mesh = ProcessMesh(list(merged_process_ids)) process_ids = set(process_mesh.process_ids)
# return merged_process_mesh merged_process_ids = merged_process_ids.union(process_ids)
if len(merged_process_ids) != 0:
merged_process_mesh = ProcessMesh(list(merged_process_ids))
return merged_process_mesh
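For illustration, a sketch of how these two helpers combine meshes. The module path and the ProcessMesh constructor arguments are assumptions, since the diff shows only part of this file.

# Assumed location of the C++-backed ProcessMesh wrapper shown above.
from paddle.distributed.auto_parallel.process_mesh_v2 import (
    ProcessMesh, compute_compatible_process_mesh, merge_process_mesh)

pm1 = ProcessMesh([[0, 1], [2, 3]], ["x", "y"])
pm2 = ProcessMesh([0, 1, 2, 3], ["x"])

# Same process ids: the mesh with more dimensions is kept as the compatible one.
compatible = compute_compatible_process_mesh([pm1, pm2])
assert len(compatible.shape) == 2

# Merging unions all process ids into a flat, one-dimensional mesh.
merged = merge_process_mesh([pm1, pm2])
assert set(merged.process_ids) == {0, 1, 2, 3}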
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import os
import copy
import argparse
from . import constants
class BaseConfig(object):
def __init__(self, category, config_dict=None):
self._category = category
self._config_dict = None
if config_dict is not None:
if isinstance(config_dict, dict):
self._config_dict = config_dict
else:
raise ValueError(
"Expected a dictionary. But received: {}".format(
config_dict))
# Initialize attributes by the default config
config = constants.get_category_default_config(self._category)
for field, default_value in config.items():
setattr(self, field, default_value)
        # Override attributes with the values in config_dict
if self._config_dict:
self.from_dict(self._config_dict)
def from_dict(self, config_dict):
config = constants.get_category_default_config(self._category)
for field in config.keys():
value = config_dict.get(field, constants.NOT_FOUND)
            # Use the default value if we cannot find the value
if value != constants.NOT_FOUND:
setattr(self, field, value)
def to_dict(self):
result_dict = {}
config = constants.get_category_default_config(self._category)
for field in config.keys():
value = getattr(self, field)
result_dict[field] = value
for field, value in self.__dict__.items():
if isinstance(value, BaseConfig):
result_dict[field] = value.to_dict()
return result_dict
def __repr__(self):
return yaml.dump(self.to_dict(),
default_flow_style=False,
sort_keys=True,
indent=4)
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
setattr(result, k, copy.deepcopy(v, memo))
return result
class RecomputeConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.RECOMPUTE
super(RecomputeConfig, self).__init__(category, config_dict)
class AMPConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.AMP
super(AMPConfig, self).__init__(category, config_dict)
class ShardingConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.SHARDING
super(ShardingConfig, self).__init__(category, config_dict)
class GradientMergeConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.GRADIENT_MERGE
super(GradientMergeConfig, self).__init__(category, config_dict)
class QATConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.QAT
super(QATConfig, self).__init__(category, config_dict)
class TuningConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.TUNING
super(TuningConfig, self).__init__(category, config_dict)
class Strategy(BaseConfig):
"""
    The `Strategy` object is used to configure the parallelization and optimization behaviors.
    Args:
        config (dict|string, optional): If this is None, the default configurations will be used.
        If this is a dictionary, its recognized key-value pairs will be used to override the default
        configurations while other default configurations are left unchanged. If this is a string,
        it is interpreted as the path to a YAML configuration file and will be loaded to override the
        corresponding default configurations.
Examples:
.. code-block:: python
import paddle
import paddle.distributed.auto_parallel as auto
strategy = auto.Strategy()
sharding = strategy.sharding
            assert sharding.enable is False
            assert sharding.stage == 1
            assert sharding.degree == 8
            sharding.enable = True
            sharding.stage = 2
            sharding.degree = 2
            assert sharding.enable is True
            assert sharding.stage == 2
            assert sharding.degree == 2
"""
def __init__(self, config=None):
if config is not None:
if isinstance(config, dict):
self._config_dict = copy.deepcopy(config)
# elif os.path.exists(config):
# with open(config, "rb") as yaml_file:
# self._config_dict = yaml.load(yaml_file, Loader=yaml.Loader)
else:
raise ValueError(
"Expected a dictionary. But received: {}".format(config))
else:
self._config_dict = {}
category = constants.BASE
super(Strategy, self).__init__(category, self._config_dict)
config_dict = self._config_dict.get(constants.RECOMPUTE, None)
self.recompute = RecomputeConfig(config_dict)
config_dict = self._config_dict.get(constants.AMP, None)
self.amp = AMPConfig(config_dict)
config_dict = self._config_dict.get(constants.SHARDING, None)
self.sharding = ShardingConfig(config_dict)
config_dict = self._config_dict.get(constants.GRADIENT_MERGE, None)
self.gradient_merge = GradientMergeConfig(config_dict)
config_dict = self._config_dict.get(constants.QAT, None)
self.qat = QATConfig(config_dict)
config_dict = self._config_dict.get(constants.TUNING, None)
self.tuning = TuningConfig(config_dict)
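# A brief usage sketch of the config classes above (illustrative only; it
# assumes the category keys and default fields defined in the constants module,
# e.g. a "sharding" category with "enable", "degree" and "stage" fields):
#     strategy = Strategy({"sharding": {"enable": True, "degree": 2, "stage": 2}})
#     assert strategy.sharding.enable and strategy.sharding.stage == 2
#     print(strategy.sharding.to_dict())  # nested configs round-trip as plain dicts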
...@@ -16,7 +16,7 @@ import copy ...@@ -16,7 +16,7 @@ import copy
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import logging import logging
from paddle.distributed.utils import get_logger from ..utils import get_logger
from .trial import TrialStatus from .trial import TrialStatus
from .trial import OptimizationTunerTrial as Trial from .trial import OptimizationTunerTrial as Trial
...@@ -110,13 +110,13 @@ class ShardingStageAlgorithm(AlgorithmBase): ...@@ -110,13 +110,13 @@ class ShardingStageAlgorithm(AlgorithmBase):
# TODO import trial class & copy strategy # TODO import trial class & copy strategy
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self._changed_configs = ["sharding_configs"] self._changed_configs = ["sharding"]
def _init_spaces(self): def _init_spaces(self):
self._max_stage = 3 self._max_stage = 3
self._trial_idx = 0 self._trial_idx = 0
stage_range = self._config.sharding_configs.get("stage_range", None) stage_range = self._config.sharding.to_dict().get("tuning_range", None)
if stage_range: if stage_range:
assert set(stage_range).issubset( assert set(stage_range).issubset(
set([0, 1, 2, 3]) set([0, 1, 2, 3])
...@@ -136,9 +136,8 @@ class ShardingStageAlgorithm(AlgorithmBase): ...@@ -136,9 +136,8 @@ class ShardingStageAlgorithm(AlgorithmBase):
stage = self._stage_range[self._trial_idx] stage = self._stage_range[self._trial_idx]
new_strategy = copy.deepcopy(self._config.dist_strategy) new_strategy = copy.deepcopy(self._config.dist_strategy)
config_dict = new_strategy.sharding_configs sharding = new_strategy.sharding
config_dict["stage"] = stage sharding.stage = stage
new_strategy.sharding_configs = config_dict
name = "trial-sharding-stage{}".format(stage) name = "trial-sharding-stage{}".format(stage)
trial = Trial(new_strategy, name, self.changed_configs) trial = Trial(new_strategy, name, self.changed_configs)
......
...@@ -17,15 +17,13 @@ import copy ...@@ -17,15 +17,13 @@ import copy
import pathlib import pathlib
import paddle import paddle
from paddle.distributed import fleet from ..strategy import Strategy
_tuning_supported_passes = ["sharding", "recompute"] _tuning_supported_passes = ["sharding", "recompute"]
_strategy_config_suffiex = "_configs"
def _get_pass_config(strategy, pass_name): def _get_pass_config(strategy, pass_name):
config_name = pass_name + _strategy_config_suffiex config = getattr(strategy, pass_name)
config = getattr(strategy, config_name)
return config return config
...@@ -38,10 +36,8 @@ class TuningConfig(object): ...@@ -38,10 +36,8 @@ class TuningConfig(object):
def __init__(self, user_config, strategy): def __init__(self, user_config, strategy):
if not isinstance(strategy, fleet.DistributedStrategy): if not isinstance(strategy, Strategy):
raise TypeError( raise TypeError("'strategy' must be object of class `Strategy`.")
"'strategy' must be object of class `fleet.DistributedStrategy`."
)
if not user_config: if not user_config:
user_config = {} user_config = {}
...@@ -116,11 +112,11 @@ class TuningConfig(object): ...@@ -116,11 +112,11 @@ class TuningConfig(object):
for p in _tuning_supported_passes: for p in _tuning_supported_passes:
if getattr(self._dist_strategy, p) and _get_pass_config( if getattr(self._dist_strategy, p) and _get_pass_config(
self._dist_strategy, p)["enable_tuning"]: self._dist_strategy, p).enable_tuning:
# TODO distinguish different args of each passes # TODO distinguish different args of each passes
self._tuning_passes_name.add(p) self._tuning_passes_name.add(p)
config_name = p + _strategy_config_suffiex config_name = p
p_dict = getattr(self._dist_strategy, config_name) p_dict = getattr(self._dist_strategy, config_name)
self.__dict__[config_name] = p_dict self.__dict__[config_name] = p_dict
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# import yaml
import os import os
import sys import sys
import copy import copy
...@@ -29,7 +30,6 @@ import paddle ...@@ -29,7 +30,6 @@ import paddle
from paddle.fluid import program_guard from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward from paddle.fluid.backward import append_backward
from paddle.distributed.passes import new_pass, PassContext from paddle.distributed.passes import new_pass, PassContext
from paddle.distributed.utils import get_logger
from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.completion import Completer
...@@ -39,6 +39,7 @@ from paddle.distributed.auto_parallel.process_group import clear_all_process_gro ...@@ -39,6 +39,7 @@ from paddle.distributed.auto_parallel.process_group import clear_all_process_gro
from paddle.distributed.auto_parallel.utils import debug_program from paddle.distributed.auto_parallel.utils import debug_program
from paddle.distributed.auto_parallel.utils import make_data_unshard, set_grad_var_shape from paddle.distributed.auto_parallel.utils import make_data_unshard, set_grad_var_shape
from ..utils import get_logger
from .config import TuningConfig from .config import TuningConfig
from .algorithms import new_algorithm from .algorithms import new_algorithm
from .trial import TrialStatus from .trial import TrialStatus
...@@ -256,8 +257,8 @@ class OptimizationTuner: ...@@ -256,8 +257,8 @@ class OptimizationTuner:
startup_program = dist_context.serial_startup_program startup_program = dist_context.serial_startup_program
# applying optimization pass # applying optimization pass
if new_strategy.amp: if new_strategy.amp.enable:
config = copy.deepcopy(new_strategy.amp_configs) config = copy.deepcopy(new_strategy.amp.to_dict())
config["dist_context"] = dist_context config["dist_context"] = dist_context
config["params_grads"] = dist_context._params_grads config["params_grads"] = dist_context._params_grads
...@@ -275,8 +276,8 @@ class OptimizationTuner: ...@@ -275,8 +276,8 @@ class OptimizationTuner:
auto_parallel_amp_pass.apply([main_program], [startup_program], auto_parallel_amp_pass.apply([main_program], [startup_program],
pass_context) pass_context)
if new_strategy.recompute: if new_strategy.recompute.enable:
config = copy.deepcopy(new_strategy.recompute_configs) config = copy.deepcopy(new_strategy.recompute.to_dict())
config["dist_context"] = dist_context config["dist_context"] = dist_context
config["no_grad_set"] = None config["no_grad_set"] = None
config["loss"] = dist_context.serial_loss config["loss"] = dist_context.serial_loss
...@@ -303,8 +304,8 @@ class OptimizationTuner: ...@@ -303,8 +304,8 @@ class OptimizationTuner:
dist_context, dist_params_grads) dist_context, dist_params_grads)
resharder.reshard() resharder.reshard()
if new_strategy.sharding: if new_strategy.sharding.enable:
config = copy.deepcopy(new_strategy.sharding_configs) config = copy.deepcopy(new_strategy.sharding.to_dict())
config["dist_context"] = dist_context config["dist_context"] = dist_context
config["params_grads"] = dist_params_grads config["params_grads"] = dist_params_grads
config["global_rank"] = self.rank config["global_rank"] = self.rank
...@@ -313,8 +314,8 @@ class OptimizationTuner: ...@@ -313,8 +314,8 @@ class OptimizationTuner:
auto_parallel_sharding_pass.apply([dist_main_prog], auto_parallel_sharding_pass.apply([dist_main_prog],
[dist_startup_prog], pass_context) [dist_startup_prog], pass_context)
if new_strategy.gradient_merge: if new_strategy.gradient_merge.enable:
config = copy.deepcopy(new_strategy.gradient_merge_configs) config = copy.deepcopy(new_strategy.gradient_merge.to_dict())
config["dist_context"] = dist_context config["dist_context"] = dist_context
config["params_grads"] = dist_params_grads config["params_grads"] = dist_params_grads
auto_parallel_gradient_merge_pass = new_pass( auto_parallel_gradient_merge_pass = new_pass(
...@@ -492,9 +493,10 @@ The best trial is: [{}], whose configuration is following: ...@@ -492,9 +493,10 @@ The best trial is: [{}], whose configuration is following:
for line in summary_.split("\n"): for line in summary_.split("\n"):
fw.write(line + "\n") fw.write(line + "\n")
full_strategy = self.get_best_config() # full_strategy = self.get_best_config()
full_strategy.save_to_prototxt( # path = os.path.join(self.project_dir, "tuned_dist_strategy.yaml")
os.path.join(self.project_dir, "tuned_dist_strategy.prototxt")) # with open(path, 'w') as outfile:
# yaml.dump(full_strategy, outfile, default_flow_style=False)
def clear(self): def clear(self):
""" """
......
...@@ -156,9 +156,10 @@ class OptimizationTunerTrial(Trial): ...@@ -156,9 +156,10 @@ class OptimizationTunerTrial(Trial):
draws += h1_format.format("{} auto=True <-> {}".format(name, name)) draws += h1_format.format("{} auto=True <-> {}".format(name, name))
draws += line + "\n" draws += line + "\n"
my_configs = getattr(self.space, name) my_configs = getattr(self.space, name)
keys = my_configs.keys() keys = my_configs.to_dict().keys()
for key in keys: for key in keys:
draws += h2_format.format(key, str(my_configs.get(key, None))) draws += h2_format.format(
key, str(my_configs.to_dict().get(key, None)))
result_res = draws + border result_res = draws + border
return result_res return result_res
......
...@@ -28,6 +28,19 @@ from paddle.fluid.io import is_parameter, is_belong_to_optimizer ...@@ -28,6 +28,19 @@ from paddle.fluid.io import is_parameter, is_belong_to_optimizer
from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute
def get_logger(log_level, name="auto_parallel"):
logger = logging.getLogger(name)
logger.propagate = False
if not logger.handlers:
logger.setLevel(log_level)
log_handler = logging.StreamHandler()
log_format = logging.Formatter(
'%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')
log_handler.setFormatter(log_format)
logger.addHandler(log_handler)
return logger
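# Typical call site of the helper above (for illustration only):
#     logger = get_logger(logging.INFO)
#     logger.info("auto parallel utilities loaded")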
def is_valid_list_index(list, index): def is_valid_list_index(list, index):
if index >= -len(list) and index < len(list): if index >= -len(list) and index < len(list):
return True return True
...@@ -49,6 +62,58 @@ def is_dim_replicate(mapping): ...@@ -49,6 +62,58 @@ def is_dim_replicate(mapping):
return False return False
def verify_dims_mapping(dims_mapping, process_mesh):
if dims_mapping is None:
return False
if not all(isinstance(d, int) for d in dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(process_mesh.shape):
return False
for i in range(len(process_mesh.shape)):
if dims_mapping.count(i) > 1:
return False
return True
def convert_to_dims_mapping(shard_spec, process_mesh):
dims_mapping = []
for shard in shard_spec:
if shard is None:
dims_mapping.append(-1)
else:
dims_mapping.append(process_mesh.dim_names.index(shard))
return dims_mapping
def convert_to_shard_spec(dims_mapping, process_mesh):
shard_spec = []
for dim_mapping in dims_mapping:
if dim_mapping == -1:
shard_spec.append(None)
else:
shard_spec.append(process_mesh.dim_names[dim_mapping])
return shard_spec
def verify_shard_spec(shard_spec, tensor_shape, process_mesh):
if len(shard_spec) != len(tensor_shape):
return False
for shard in shard_spec:
if shard is not None and not isinstance(shard, str):
return False
if shard is not None and shard not in process_mesh.dim_names:
return False
dims_mapping = convert_to_dims_mapping(shard_spec, process_mesh)
if not verify_dims_mapping(dims_mapping, process_mesh):
return False
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and tensor_shape[i] % process_mesh.shape[dims_mapping[i]] != 0:
return False
return True
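# A small worked example of the helpers above (illustrative only; it assumes a
# process mesh built with the new API, e.g. mesh = ProcessMesh([[0, 1], [2, 3]],
# dim_names=["x", "y"]), whose shape is [2, 2]):
#     shard_spec = ["x", None]                     # shard dim 0 along "x", replicate dim 1
#     convert_to_dims_mapping(shard_spec, mesh)    # -> [0, -1]
#     convert_to_shard_spec([0, -1], mesh)         # -> ["x", None]
#     verify_shard_spec(shard_spec, [4, 8], mesh)  # -> True, since 4 % 2 == 0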
def compute_compatible_dim_mapping(dim_mappings): def compute_compatible_dim_mapping(dim_mappings):
if not dim_mappings: if not dim_mappings:
return None return None
...@@ -1040,7 +1105,7 @@ def set_grad_var_shape(program, dist_context): ...@@ -1040,7 +1105,7 @@ def set_grad_var_shape(program, dist_context):
if op.type in [ if op.type in [
"c_allreduce_sum", "c_identity", "scale", "cast", "c_allreduce_sum", "c_identity", "scale", "cast",
'fill_any_like' "fill_any_like"
]: ]:
forward_var_name = op.input_arg_names[0] forward_var_name = op.input_arg_names[0]
elif op.type == "matmul_v2_grad" or op.type == "matmul_grad" or op.type == "mul_grad": elif op.type == "matmul_v2_grad" or op.type == "matmul_grad" or op.type == "mul_grad":
...@@ -1504,3 +1569,15 @@ def ring_id_to_process_group(ring_id): ...@@ -1504,3 +1569,15 @@ def ring_id_to_process_group(ring_id):
if g.id == ring_id: if g.id == ring_id:
return g return g
return None return None
def find_higher_order_backward_op(program):
higher_order_op_suffix = ['_grad_grad', 'triple_grad']
for block in program.blocks:
for op in block.ops:
for suffix in higher_order_op_suffix:
if suffix in op.type:
return True
return False
...@@ -314,7 +314,9 @@ class AMPState(object): ...@@ -314,7 +314,9 @@ class AMPState(object):
consume_op_attr.set_input_dist_attr( consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr) cast_name, in_var_dist_attr)
else: else:
assert in_var.dtype == dst_dtype assert in_var.dtype == dst_dtype, "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format(
grad_op.type, in_name, dst_dtype, in_var.dtype,
str(grad_op))
for out_name in grad_op.output_names: for out_name in grad_op.output_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output( if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
......
...@@ -13,12 +13,14 @@ ...@@ -13,12 +13,14 @@
# limitations under the License. # limitations under the License.
from collections import OrderedDict from collections import OrderedDict
import numpy as np
import paddle import paddle
from paddle.fluid import core, unique_name
from paddle.fluid.framework import default_main_program from paddle.fluid.framework import default_main_program
from paddle.distributed.fleet.meta_optimizers.common import OpRole from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op
from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, ring_id_to_process_group from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, is_backward_op, ring_id_to_process_group, find_higher_order_backward_op
from .pass_base import PassBase, PassType, register_pass from .pass_base import PassBase, PassType, register_pass
# add new optimizers supporting rescale_grad here # add new optimizers supporting rescale_grad here
...@@ -31,6 +33,10 @@ __rescale_grad_supported_opts__ = [ ...@@ -31,6 +33,10 @@ __rescale_grad_supported_opts__ = [
__max_stream_num_allow__ = 16 __max_stream_num_allow__ = 16
def numel(var):
return np.prod(list(var.shape))
@register_pass("auto_parallel_data_parallel_optimization") @register_pass("auto_parallel_data_parallel_optimization")
class DataParallelOptimizationPass(PassBase): class DataParallelOptimizationPass(PassBase):
""" """
...@@ -78,7 +84,9 @@ class DataParallelOptimizationPass(PassBase): ...@@ -78,7 +84,9 @@ class DataParallelOptimizationPass(PassBase):
self._analyze_program() self._analyze_program()
self._prune_grad_scaling() self._prune_grad_scaling()
self._calc_comm_overlap() self._calc_comm_overlap()
self._fuse_allreduce() grad_group = self._fuse_allreduce()
# self.summary(grad_group)
def _prune_grad_scaling(self): def _prune_grad_scaling(self):
...@@ -99,7 +107,14 @@ class DataParallelOptimizationPass(PassBase): ...@@ -99,7 +107,14 @@ class DataParallelOptimizationPass(PassBase):
self._calc_wait_comms() self._calc_wait_comms()
def _fuse_allreduce(self): def _fuse_allreduce(self):
pass
if not self._could_be_fuse():
return []
grad_group = self._group_grads()
self._update_program(grad_group)
return grad_group
def _analyze_program(self): def _analyze_program(self):
""" """
...@@ -154,7 +169,7 @@ class DataParallelOptimizationPass(PassBase): ...@@ -154,7 +169,7 @@ class DataParallelOptimizationPass(PassBase):
def _could_be_prune(self): def _could_be_prune(self):
return self.dist_context._gradient_scale and ( return self.dist_context.gradient_scale and (
self._support_rescale_grad or self._all_dp_groups_same_degree()) self._support_rescale_grad or self._all_dp_groups_same_degree())
def _all_dp_groups_same_degree(self): def _all_dp_groups_same_degree(self):
...@@ -316,3 +331,252 @@ class DataParallelOptimizationPass(PassBase): ...@@ -316,3 +331,252 @@ class DataParallelOptimizationPass(PassBase):
'op_role': OpRole.Backward, 'op_role': OpRole.Backward,
'ring_id': ring_id 'ring_id': ring_id
}) })
def _could_be_fuse(self):
        # TODO support gradient fuse for higher-order gradients;
        # should analyze the dependencies of gradients in the backward pass.
if find_higher_order_backward_op(default_main_program()):
return False
if self.use_sharding:
return False
return True
def _group_grads(self):
"""
conditions for gradients to be grouped:
1. group size < max_fuse_numel
2. same dp group
3. same dtype
        4. dependency: the grad is NOT used by other ops within the group segment
        Gradients inside the same group will be fused into one coalesced tensor.
"""
block = default_main_program().global_block()
ops = block.ops
# group individual grad vars
# TODO consider fuse gradient for sharding reduce
        # TODO let the user set fuse_grad_size
# emb = 50000 * h, ffn = 8 * h * h, mha = 4 * h * h
h = 2048
ffn_numel = 2 * (4 * h) * h
mha_numel = 3 * h * h + h * h
max_fuse_numel = ffn_numel + mha_numel
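        # For reference: with h = 2048 this gives ffn_numel = 33,554,432 and
        # mha_numel = 16,777,216, so max_fuse_numel is roughly 50M elements per
        # fused gradient group (an illustrative estimate, not a tuned value).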
grad_groups = []
cur_group = GradientsGroup(ops, max_fuse_numel)
grouped_grad_names = set()
def collect_group(cur_group, grad_var, ring_id, i):
if len(cur_group.gradients) == 0:
cur_group = None
elif len(cur_group.gradients) == 1:
grouped_grad_names.remove(cur_group.gradients[0].name)
else:
cur_group.finalize()
grad_groups.append(cur_group)
new_group = GradientsGroup(ops, max_fuse_numel)
if grad_var:
new_group.add(grad_var, ring_id, i)
grouped_grad_names.add(grad_var.name)
return new_group
def op_depend_on_group(op, group):
vars_ = set(op.input_arg_names + op.output_arg_names)
grad_names = set([grad.name for grad in group.gradients])
return len(vars_.intersection(grad_names)) > 0
for i, op in enumerate(ops):
if is_data_parallel_reduce_op(op):
ring_id = op.attr("ring_id")
grad_name = op.output_arg_names[0]
grad_var = block.var(grad_name)
grad_numel = numel(grad_var)
if cur_group.acceptable(grad_var, ring_id):
assert grad_name not in grouped_grad_names
grouped_grad_names.add(grad_name)
cur_group.add(grad_var, ring_id, i)
else:
cur_group = collect_group(cur_group, grad_var, ring_id, i)
else:
if op_depend_on_group(op, cur_group):
cur_group = collect_group(cur_group, None, None, None)
# collect last group
collect_group(cur_group, None, None, None)
return grad_groups
def _update_program(self, grad_groups):
block = default_main_program().global_block()
remove_op_types = ['scale', 'c_allreduce_sum', 'c_wait_compute']
for i, group in enumerate(grad_groups[::-1]):
            # create the coalesced tensor
group.coalesce_var = block.create_var(name=unique_name.generate(
'coalecse_grad_{}'.format(i)),
dtype=group.dtype,
persistable=False,
stop_gradient=True)
# update allreduce & scale op
if group.scale_op_idx != -1:
scale_op = block.ops[group.scale_op_idx]
                assert scale_op.type == 'scale', "expected a scale op but found {}".format(
str(scale_op))
scale_op._rename_input(scale_op.input_arg_names[0],
group.coalesce_var.name)
scale_op._rename_output(scale_op.output_arg_names[0],
group.coalesce_var.name)
allreduce_op = block.ops[group.allreduce_op_idx]
            assert allreduce_op.type == 'c_allreduce_sum', "expected a c_allreduce_sum op but found {}".format(
str(allreduce_op))
allreduce_op._rename_input(allreduce_op.input_arg_names[0],
group.coalesce_var.name)
allreduce_op._rename_output(allreduce_op.output_arg_names[0],
group.coalesce_var.name)
            # remove unused ops
remove_op_indices = group.remove_wait_op_indices + group.remove_allreduce_op_indices + group.remove_scale_op_indices
for idx in sorted(remove_op_indices, reverse=True):
                assert block.ops[
                    idx].type in remove_op_types, "unexpected attempt to remove op {}".format(
                        str(block.ops[idx]))
block._remove_op(idx)
            # insert the coalesce_tensor op
concated_shapes = []
concated_ranks = []
for grad_ in group.gradients:
shape = grad_.shape
concated_shapes.extend(shape)
concated_ranks.append(len(shape))
grad_names = [grad.name for grad in group.gradients]
block._insert_op_without_sync(group.coalesce_op_idx,
type="coalesce_tensor",
inputs={"Input": grad_names},
outputs={
"Output": grad_names,
"FusedOutput": group.coalesce_var
},
attrs={
"copy_data": False,
"use_align": True,
"dtype": group.dtype,
"concated_shapes":
concated_shapes,
"concated_ranks": concated_ranks,
OP_ROLE_KEY: OpRole.Backward
})
block._sync_with_cpp()
# TODO update dist attr
def summary(self, grad_groups=[]):
# TODO: add logger module
import logging
self._logger = logging.getLogger()
self._logger.propagate = False
if not self._logger.handlers:
self._logger.setLevel(logging.INFO)
log_handler = logging.StreamHandler()
log_format = logging.Formatter(
'[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
)
log_handler.setFormatter(log_format)
self._logger.addHandler(log_handler)
if len(grad_groups) > 0:
self._logger.info(
"origin {} allreduce ops are fused into {} coalecse allreduce ops."
.format(len(self._grad_name_to_group_map.keys()),
len(grad_groups)))
self._logger.info("gradient fusing group are following: ")
fused_grads = set()
for i, group in enumerate(grad_groups):
self._logger.info(
"coalecse gradient [{}] is composed by: {}".format(
i, [grad.name for grad in group.gradients]))
fused_grads.update([grad.name for grad in group.gradients])
individual_grads = set(
self._grad_name_to_group_map.keys()) - set(fused_grads)
self._logger.info(
"the following [{}] gradients are not fused: ".format(
len(individual_grads)))
self._logger.info("individual gradient {}".format(individual_grads))
class GradientsGroup(object):
def __init__(self, ops, max_group_size):
self.max_group_size = max_group_size
self.ops = ops
self.gradients = []
self.numel = 0
self.dtype = None
self.ring_id = None
self.coalesce_var = None
self.coalesce_op_idx = -1
self.allreduce_op_idx = -1
self.scale_op_idx = -1
self.remove_wait_op_indices = []
self.remove_allreduce_op_indices = []
self.remove_scale_op_indices = []
def acceptable(self, grad_var, ring_id):
if len(self.gradients) == 0:
return True
if ring_id != self.ring_id:
return False
if numel(grad_var) + self.numel > self.max_group_size:
return False
if grad_var.dtype != self.dtype:
return False
return True
def add(self, grad_var, ring_id, i):
self.gradients.append(grad_var)
self.ring_id = ring_id
self.dtype = grad_var.dtype
self.numel += numel(grad_var)
# remove auxiliary ops in non-fuse dp allreduce
self.remove_allreduce_op_indices.append(i)
        # NOTE this pass relies on the original synchronization added by previous passes
        # (same stream or calc_wait_comm & comm_wait_calc)
        # to guarantee the correctness of the comm-calc execution order,
        # so the calc_wait_comm ops should be kept.
grad_op_idx = i - 1
if i > 0 and self.ops[i - 1].type == 'c_wait_compute':
self.remove_wait_op_indices.append(i - 1)
grad_op_idx -= 1
        if i + 1 < len(self.ops) and is_data_parallel_scale_op(self.ops[i + 1]):
self.remove_scale_op_indices.append(i + 1)
if len(self.gradients) == 1:
            # TODO remove this temporary hack for Tensor Parallel; the logic
            # for finding the grad_op should be more general.
if self.ops[grad_op_idx].type == "c_allreduce_sum":
grad_op_idx -= 1
grad_op = self.ops[grad_op_idx]
assert grad_var.name in grad_op.output_arg_names, "grad [{}] should be output of {}".format(
grad_var.name, str(grad_op))
self.coalesce_op_idx = grad_op_idx
def finalize(self):
self.allreduce_op_idx = self.remove_allreduce_op_indices.pop()
if len(self.remove_wait_op_indices) > 1:
self.remove_wait_op_indices.pop()
if len(self.remove_scale_op_indices) > 1:
self.scale_op_idx = self.remove_scale_op_indices.pop()
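# Schematic of the transformation performed by this pass (an illustrative
# sketch, not generated output). For each fused gradient group, the per-grad
# pattern
#     c_wait_compute -> c_allreduce_sum(g_k) -> scale(g_k)    (one per gradient)
# is replaced by a single sequence on one fused buffer:
#     coalesce_tensor([g_0, ..., g_n]) -> FusedOutput
#     c_allreduce_sum(FusedOutput) -> scale(FusedOutput)
# where the wait/scale ops appear only when overlap/gradient scaling are in use.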
...@@ -16,6 +16,7 @@ from collections import defaultdict ...@@ -16,6 +16,7 @@ from collections import defaultdict
import paddle import paddle
from paddle.framework import core from paddle.framework import core
from paddle.fluid.framework import default_main_program, default_startup_program
from paddle.fluid import unique_name from paddle.fluid import unique_name
from .pass_base import register_pass from .pass_base import register_pass
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
...@@ -379,6 +380,10 @@ class FP16State(object): ...@@ -379,6 +380,10 @@ class FP16State(object):
# create cast grad # create cast grad
grad_slot_name = slot_name + "@GRAD" grad_slot_name = slot_name + "@GRAD"
assert grad_slot_name in op.output_names assert grad_slot_name in op.output_names
if len(op.output(grad_slot_name)) == 0:
var = block.var(src_name)
assert var.stop_gradient is True
continue
assert len(op.output(grad_slot_name)) == 1 assert len(op.output(grad_slot_name)) == 1
grad_name = op.output(grad_slot_name)[0] grad_name = op.output(grad_slot_name)[0]
grad = block.var(grad_name) grad = block.var(grad_name)
...@@ -442,7 +447,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context): ...@@ -442,7 +447,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
inputs = {'X': grads, 'Scale': loss_scaling} inputs = {'X': grads, 'Scale': loss_scaling}
outputs = {'Out': grads, 'FoundInfinite': found_inf} outputs = {'Out': grads, 'FoundInfinite': found_inf}
attrs = {'op_role': OpRole.Backward} attrs = {'op_role': OpRole.Optimize}
new_op = main_block.append_op(type='check_finite_and_unscale', new_op = main_block.append_op(type='check_finite_and_unscale',
inputs=inputs, inputs=inputs,
outputs=outputs, outputs=outputs,
...@@ -536,6 +541,39 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"): ...@@ -536,6 +541,39 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
return output_var return output_var
def cast_startup_program():
main_program = default_main_program()
startup_program = default_startup_program()
param_to_dtype = {}
for block in main_program.blocks:
for p in block.all_parameters():
param_to_dtype[p.name] = p.dtype
def is_initialization_op(op):
comm_op_prefix = "c_"
op_type = op.type
if op_type.startswith(comm_op_prefix):
return False
if len(op.output_arg_names) != 1 and len(op.input_arg_names) != 0:
return False
return True
for op in startup_program.global_block().ops:
if is_initialization_op(op):
output_name = op.output_arg_names[0]
if param_to_dtype.get(output_name,
None) == core.VarDesc.VarType.FP16:
assert op.has_attr(
'dtype'
), "initialization op is supported to has dtype attribute but got {}.".format(
str(op))
if op.attr('dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16)
@register_pass("auto_parallel_fp16") @register_pass("auto_parallel_fp16")
class FP16Pass(AMPPass): class FP16Pass(AMPPass):
...@@ -563,6 +601,8 @@ class FP16Pass(AMPPass): ...@@ -563,6 +601,8 @@ class FP16Pass(AMPPass):
input_data_var_names) input_data_var_names)
is_train = fp16_state._build_state() is_train = fp16_state._build_state()
cast_startup_program()
if is_train: if is_train:
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
# TODO (JZ-LIANG)support cast forward program only when inference # TODO (JZ-LIANG)support cast forward program only when inference
...@@ -575,18 +615,18 @@ class FP16Pass(AMPPass): ...@@ -575,18 +615,18 @@ class FP16Pass(AMPPass):
) or self.get_attr("init_loss_scaling") != 1.0: ) or self.get_attr("init_loss_scaling") != 1.0:
found_infs = [] found_infs = []
if fp32_grads: if fp32_grads:
with main_program._backward_role_guard(): with main_program._optimized_guard([]):
_, found_inf_fp32 = _check_and_update_gradient( _, found_inf_fp32 = _check_and_update_gradient(
fp32_grads, self._loss_scaling, "@fp32", fp32_grads, self._loss_scaling, "@fp32",
self.dist_context) self.dist_context)
found_infs.append(found_inf_fp32) found_infs.append(found_inf_fp32)
if fp16_grads: if fp16_grads:
with main_program._backward_role_guard(): with main_program._optimized_guard([]):
_, found_inf_fp16 = _check_and_update_gradient( _, found_inf_fp16 = _check_and_update_gradient(
fp16_grads, self._loss_scaling, "@fp16", fp16_grads, self._loss_scaling, "@fp16",
self.dist_context) self.dist_context)
found_infs.append(found_inf_fp16) found_infs.append(found_inf_fp16)
with main_program._backward_role_guard(): with main_program._optimized_guard([]):
block = main_program.global_block() block = main_program.global_block()
all_infs = paddle.fluid.layers.concat(found_infs) all_infs = paddle.fluid.layers.concat(found_infs)
...@@ -608,7 +648,7 @@ class FP16Pass(AMPPass): ...@@ -608,7 +648,7 @@ class FP16Pass(AMPPass):
block, self.dist_context) block, self.dist_context)
if self.get_attr("use_dynamic_loss_scaling"): if self.get_attr("use_dynamic_loss_scaling"):
with main_program._backward_role_guard(): with main_program._optimized_guard([]):
if fp32_grads: if fp32_grads:
self._update_loss_scaling(fp32_grads, found_inf) self._update_loss_scaling(fp32_grads, found_inf)
if fp16_grads: if fp16_grads:
......
...@@ -207,6 +207,7 @@ class ClipGradByGloblNormPass(PassBase): ...@@ -207,6 +207,7 @@ class ClipGradByGloblNormPass(PassBase):
super(ClipGradByGloblNormPass, self).__init__() super(ClipGradByGloblNormPass, self).__init__()
self.set_attr("rank_id", None) self.set_attr("rank_id", None)
self.set_attr("dist_context", None) self.set_attr("dist_context", None)
self.set_attr("params_grads", None)
def _check_self(self): def _check_self(self):
if self.get_attr("dist_context") is None: if self.get_attr("dist_context") is None:
...@@ -214,6 +215,8 @@ class ClipGradByGloblNormPass(PassBase): ...@@ -214,6 +215,8 @@ class ClipGradByGloblNormPass(PassBase):
dist_context = self.get_attr("dist_context") dist_context = self.get_attr("dist_context")
if dist_context._lr_optimizer._grad_clip is None: if dist_context._lr_optimizer._grad_clip is None:
return False return False
if self.get_attr("params_grads") is None:
return False
return True return True
def _check_conflict(self, other_pass): def _check_conflict(self, other_pass):
...@@ -223,7 +226,8 @@ class ClipGradByGloblNormPass(PassBase): ...@@ -223,7 +226,8 @@ class ClipGradByGloblNormPass(PassBase):
dist_context = self.get_attr("dist_context", None) dist_context = self.get_attr("dist_context", None)
rank_id = self.get_attr("rank_id", None) rank_id = self.get_attr("rank_id", None)
block = main_program.global_block() block = main_program.global_block()
dist_params_grads = _get_params_grads(block) dist_params_grads = self.get_attr("params_grads", None)
# dist_params_grads = _get_params_grads(block)
self.clip_helper = ClipHelper(dist_params_grads, rank_id, block, self.clip_helper = ClipHelper(dist_params_grads, rank_id, block,
dist_context) dist_context)
......
...@@ -55,13 +55,6 @@ def _remove_and_get_optimizer_op(main_program, dist_context): ...@@ -55,13 +55,6 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
return optimize_ops_desc return optimize_ops_desc
def _remove_op_role_var(param, grad):
op_maker = core.op_proto_and_checker_maker
op = grad.op
if op and op.has_attr(op_maker.kOpRoleVarAttrName()):
op._remove_attr(op_maker.kOpRoleVarAttrName())
def _get_gm_cond_var(main_program, k_steps, dist_context): def _get_gm_cond_var(main_program, k_steps, dist_context):
main_block = main_program.global_block() main_block = main_program.global_block()
# Add const var # Add const var
...@@ -147,8 +140,6 @@ def _append_gradient_merge_backward_op( ...@@ -147,8 +140,6 @@ def _append_gradient_merge_backward_op(
param.type != core.VarDesc.VarType.SELECTED_ROWS param.type != core.VarDesc.VarType.SELECTED_ROWS
), "SELECTED_ROWS is not supported in GradientMergeOptimizer for now" ), "SELECTED_ROWS is not supported in GradientMergeOptimizer for now"
_remove_op_role_var(param, grad)
# {grad.name: gradient_merge_var.name} to rename opt inputs # {grad.name: gradient_merge_var.name} to rename opt inputs
grad_to_gradient_merge = {} grad_to_gradient_merge = {}
# {param: gradient_merge_var} to insert scale op and fill_constant op # {param: gradient_merge_var} to insert scale op and fill_constant op
......
...@@ -51,7 +51,8 @@ class ShardingPass(PassBase): ...@@ -51,7 +51,8 @@ class ShardingPass(PassBase):
super(ShardingPass, self).__init__() super(ShardingPass, self).__init__()
self.set_attr("dist_context", None) self.set_attr("dist_context", None)
self.set_attr("stage", None) self.set_attr("stage", None)
self.set_attr("sharding_degree", None) self.set_attr("sharding_degree", None) # for parallelizer
self.set_attr("degree", None) # for parallelizer_v2
self.set_attr("params_grads", []) self.set_attr("params_grads", [])
self.set_attr("global_rank", -1) self.set_attr("global_rank", -1)
self.dp_groups = set() self.dp_groups = set()
...@@ -59,6 +60,7 @@ class ShardingPass(PassBase): ...@@ -59,6 +60,7 @@ class ShardingPass(PassBase):
self.varname_to_sharding_info = {} self.varname_to_sharding_info = {}
self.partial_sharding = False self.partial_sharding = False
self.outer_dp_group = None self.outer_dp_group = None
self.shared_params_grads = []
def _check_self(self): def _check_self(self):
if self.get_attr("dist_context") is None: if self.get_attr("dist_context") is None:
...@@ -66,8 +68,15 @@ class ShardingPass(PassBase): ...@@ -66,8 +68,15 @@ class ShardingPass(PassBase):
if self.get_attr("stage") not in [1, 2, 3]: if self.get_attr("stage") not in [1, 2, 3]:
return False return False
if (not isinstance(self.get_attr("sharding_degree"), if self.get_attr("sharding_degree") is not None:
int)) or self.get_attr("sharding_degree") <= 1: if (not isinstance(self.get_attr("sharding_degree"), int)) \
or self.get_attr("sharding_degree") <= 1:
return False
elif self.get_attr("degree") is not None:
if (not isinstance(self.get_attr("degree"), int)) \
or self.get_attr("degree") <= 1:
return False
else:
return False return False
if len(self.get_attr("params_grads")) <= 0: if len(self.get_attr("params_grads")) <= 0:
return False return False
...@@ -82,7 +91,8 @@ class ShardingPass(PassBase): ...@@ -82,7 +91,8 @@ class ShardingPass(PassBase):
def _apply_single_impl(self, main_program, startup_program, context): def _apply_single_impl(self, main_program, startup_program, context):
self._dist_context = self.get_attr("dist_context") self._dist_context = self.get_attr("dist_context")
self.sharding_world_size = int(self.get_attr("sharding_degree")) self.sharding_world_size = int(
self.get_attr("sharding_degree") or self.get_attr("degree"))
self.stage = int(self.get_attr("stage")) self.stage = int(self.get_attr("stage"))
self.global_rank = int(self.get_attr("global_rank")) self.global_rank = int(self.get_attr("global_rank"))
params_grads = self.get_attr("params_grads") params_grads = self.get_attr("params_grads")
...@@ -94,6 +104,8 @@ class ShardingPass(PassBase): ...@@ -94,6 +104,8 @@ class ShardingPass(PassBase):
self._shard_gradient_synchronization(main_block) self._shard_gradient_synchronization(main_block)
self._shard_parameter(main_block, startup_block) self._shard_parameter(main_block, startup_block)
context.set_attr("params_grads", self.shared_params_grads)
def _build_sharding_groups(self, main_block, params_grads): def _build_sharding_groups(self, main_block, params_grads):
self._collective_data_parallel_groups(main_block) self._collective_data_parallel_groups(main_block)
self._build_sharding_infos(params_grads) self._build_sharding_infos(params_grads)
...@@ -148,13 +160,10 @@ class ShardingPass(PassBase): ...@@ -148,13 +160,10 @@ class ShardingPass(PassBase):
self._dist_context._sharding_group = sharding_group self._dist_context._sharding_group = sharding_group
# TODO(JZ-LIANG) when support multiple dp groups in future, should group param and bind them to corresponding dp group # TODO(JZ-LIANG) when support multiple dp groups in future, should group param and bind them to corresponding dp group
params_in_group = [p for p, g in params_grads]
assert len(params_in_group) == len(
set(params_in_group)), "found duplicated param in params_grads"
sharding_info = ShardingInfo(sharding_group, self.global_rank, sharding_info = ShardingInfo(sharding_group, self.global_rank,
params_in_group) params_grads)
self.sharding_infos.append(sharding_info) self.sharding_infos.append(sharding_info)
for param in params_in_group: for param in sharding_info.params:
self.varname_to_sharding_info[param.name] = sharding_info self.varname_to_sharding_info[param.name] = sharding_info
def _shard_optimizer(self, main_block, startup_block, params_grads, def _shard_optimizer(self, main_block, startup_block, params_grads,
...@@ -201,6 +210,7 @@ class ShardingPass(PassBase): ...@@ -201,6 +210,7 @@ class ShardingPass(PassBase):
op.desc.set_output('Out', reversed_x) op.desc.set_output('Out', reversed_x)
else: else:
if op.type == "check_finite_and_unscale": if op.type == "check_finite_and_unscale":
op_role = op.attr('op_role')
out_name = op.output_arg_names[0] out_name = op.output_arg_names[0]
out_var = main_block.vars[out_name] out_var = main_block.vars[out_name]
main_block._remove_op(idx, sync=False) main_block._remove_op(idx, sync=False)
...@@ -212,6 +222,7 @@ class ShardingPass(PassBase): ...@@ -212,6 +222,7 @@ class ShardingPass(PassBase):
"shape": out_var.shape, "shape": out_var.shape,
"dtype": out_var.dtype, "dtype": out_var.dtype,
"value": 0, "value": 0,
OP_ROLE_KEY: op_role,
}) })
else: else:
main_block._remove_op(idx, sync=False) main_block._remove_op(idx, sync=False)
...@@ -313,6 +324,9 @@ class ShardingPass(PassBase): ...@@ -313,6 +324,9 @@ class ShardingPass(PassBase):
if varname != param_name if varname != param_name
]) ])
main_block._remove_op(idx, sync=False) main_block._remove_op(idx, sync=False)
else:
self.shared_params_grads.append(
self._get_param_grad(param_name))
for idx, op in reversed(list(enumerate(startup_block.ops))): for idx, op in reversed(list(enumerate(startup_block.ops))):
if len(op.output_arg_names) == 1 and op.output_arg_names[ if len(op.output_arg_names) == 1 and op.output_arg_names[
...@@ -365,6 +379,13 @@ class ShardingPass(PassBase): ...@@ -365,6 +379,13 @@ class ShardingPass(PassBase):
sharding_info = self.varname_to_sharding_info[param_name] sharding_info = self.varname_to_sharding_info[param_name]
return sharding_info.is_in_local_shard(param_name) return sharding_info.is_in_local_shard(param_name)
def _get_param_grad(self, param_name):
assert param_name in self.varname_to_sharding_info
sharding_info = self.varname_to_sharding_info[param_name]
p_g = sharding_info.get_param_grad(param_name)
assert p_g is not None
return p_g
def _shard_gradient_synchronization(self, main_block): def _shard_gradient_synchronization(self, main_block):
if self.stage < 2: if self.stage < 2:
...@@ -705,9 +726,13 @@ def shard_parameters(params, group_size): ...@@ -705,9 +726,13 @@ def shard_parameters(params, group_size):
class ShardingInfo(object): class ShardingInfo(object):
def __init__(self, group, rank, params): def __init__(self, group, rank, params_grads):
self.group = group self.group = group
self.params = params self.params_grads = dict([(p.name, (p, g)) for p, g in params_grads])
assert len(self.params_grads) == len(set(
self.params_grads)), "found duplicated param in params_grads"
self.params = [p for p, _ in params_grads]
self.param_names = [p.name for p in self.params] self.param_names = [p.name for p in self.params]
self.group_size = group.nranks self.group_size = group.nranks
self.global_rank = rank self.global_rank = rank
...@@ -762,3 +787,11 @@ class ShardingInfo(object): ...@@ -762,3 +787,11 @@ class ShardingInfo(object):
if usage > 0: if usage > 0:
broadcast_vars.add(param) broadcast_vars.add(param)
return broadcast_vars, param_usage return broadcast_vars, param_usage
def get_param_grad(self, param_name):
if not self.is_in_local_shard(param_name):
raise ValueError(
"param[{}] not in current rank.".format(param_name))
if param_name not in self.params_grads:
raise ValueError('param[{}] not in params_grads'.format(param_name))
return self.params_grads.get(param_name, None)
...@@ -37,8 +37,28 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ...@@ -37,8 +37,28 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
${dist_ENVS}) ${dist_ENVS})
set_tests_properties(test_high_order_grad set_tests_properties(test_high_order_grad
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_grad_clip MODULES test_grad_clip ENVS ${dist_ENVS}) py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS
set_tests_properties(test_grad_clip PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" ${dist_ENVS})
set_tests_properties(test_iterable_dataset
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80)
py_test_modules(test_pass_grad_clip MODULES test_pass_grad_clip ENVS
${dist_ENVS})
set_tests_properties(test_pass_grad_clip
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_pass_gradient_merge MODULES test_pass_gradient_merge
ENVS ${dist_ENVS})
set_tests_properties(test_pass_gradient_merge
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_pass_recompute MODULES test_pass_recompute ENVS
${dist_ENVS})
set_tests_properties(test_pass_recompute
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_pass_sharding MODULES test_pass_sharding ENVS
${dist_ENVS})
set_tests_properties(test_pass_sharding
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_pass_amp MODULES test_pass_amp ENVS ${dist_ENVS})
set_tests_properties(test_pass_amp PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50) TIMEOUT 50)
py_test_modules(test_while_op_completion MODULES test_while_op_completion py_test_modules(test_while_op_completion MODULES test_while_op_completion
...@@ -70,11 +90,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ...@@ -70,11 +90,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_process_mesh_v2 MODULES test_process_mesh_v2) py_test_modules(test_process_mesh_v2 MODULES test_process_mesh_v2)
py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2) py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2)
py_test_modules(test_lr_grad_clip MODULES test_lr_grad_clip) py_test_modules(test_lr_grad_clip MODULES test_lr_grad_clip)
py_test_modules(test_quantization MODULES test_quantization)
py_test_modules(test_dist_matmul MODULES test_dist_matmul) py_test_modules(test_dist_matmul MODULES test_dist_matmul)
py_test_modules(test_process_mesh MODULES test_process_mesh)
py_test_modules(test_interface MODULES test_interface)
py_test_modules(test_strategy MODULES test_strategy)
py_test_modules(test_pass_quantization MODULES test_pass_quantization)
py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS
${dist_ENVS})
set_tests_properties(test_iterable_dataset
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80)
endif() endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
import random
import numpy as np
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
def apply_pass(use_amp=False, level=None):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_amp:
amp = strategy.amp
amp.enable = True
amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
amp.custom_black_list = [
'c_softmax_with_cross_entropy', 'elementwise_div', 'reduce_sum'
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.use_pure_fp16 = level in ["o2", "o3"]
amp.use_optimizer_fp16 = level == "o3"
print("amp level: ", level)
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestAMPPass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-5
self.atol = 1e-8
self.batch_size = 1
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2021)
np.random.seed(2021)
random.seed(2021)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_amp=False, level=None):
reset_prog()
strategy = apply_pass(use_amp, level)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("mp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses, rtol=None, atol=None):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=rtol or self.rtol,
atol=atol or self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses))
def test_amp_pass(self):
# mp2 training
mp_engine = self.get_engine()
mp_losses = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(mp_losses["loss"])
# mp2 amp-o1 training
amp_o1_engine = self.get_engine(True, "o1")
amp_o1_losses = amp_o1_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
amp_o1_losses = np.array(amp_o1_losses["loss"])
# self.check_results(mp_losses, amp_o1_losses)
# mp2 amp-o2 training
amp_o2_engine = self.get_engine(True, "o2")
amp_o2_losses = amp_o2_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
amp_o2_losses = np.array(amp_o2_losses["loss"])
# self.check_results(mp_losses, amp_o2_losses)
# mp2 amp-o3 training
amp_o3_engine = self.get_engine(True, "o3")
amp_o3_losses = amp_o3_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
amp_o3_losses = np.array(amp_o3_losses["loss"])
# self.check_results(mp_losses, amp_o3_losses)
if __name__ == "__main__":
unittest.main()
...@@ -32,7 +32,7 @@ import paddle.distributed.auto_parallel as auto ...@@ -32,7 +32,7 @@ import paddle.distributed.auto_parallel as auto
paddle.enable_static() paddle.enable_static()
_global_parallel_strategy = None _global_parallel_strategy = None
_global_process_mesh = None _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
batch_size = 4 batch_size = 4
hidden_size = 1024 hidden_size = 1024
sequence_len = 512 sequence_len = 512
...@@ -103,11 +103,7 @@ def mlp_pretrain_forward(train_program, start_program): ...@@ -103,11 +103,7 @@ def mlp_pretrain_forward(train_program, start_program):
shape=[batch_size, sequence_len, 1], shape=[batch_size, sequence_len, 1],
dtype='float32') dtype='float32')
auto.shard_tensor(input, auto.shard_tensor(input, _global_process_mesh, [None, None, None])
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mappig": [-1, -1, -1]
})
mlp = MLPLayer(hidden_size=hidden_size, mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size, intermediate_size=4 * hidden_size,
...@@ -126,9 +122,6 @@ def mlp_pretrain_forward(train_program, start_program): ...@@ -126,9 +122,6 @@ def mlp_pretrain_forward(train_program, start_program):
def train(): def train():
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
dist_strategy = fleet.DistributedStrategy() dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False dist_strategy.amp = False
dist_strategy.pipeline = False dist_strategy.pipeline = False
......
...@@ -18,24 +18,21 @@ import random ...@@ -18,24 +18,21 @@ import random
import numpy as np import numpy as np
import paddle import paddle
import paddle.distributed.fleet as fleet
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed.auto_parallel.engine import Engine
from get_gpt_model import generate_model, create_data_holder, FakeDataset from get_gpt_model import generate_model, create_data_holder, FakeDataset
paddle.enable_static() paddle.enable_static()
def apply_pass(use_sharding=False): def apply_pass(use_sharding=False):
strategy = fleet.DistributedStrategy() strategy = auto.Strategy()
strategy.semi_auto = True strategy.auto_mode = "semi"
strategy.reinit = True
if use_sharding: if use_sharding:
strategy.sharding = True sharding = strategy.sharding
strategy.sharding_configs = { sharding.degree = 2
"sharding_degree": 2, sharding.stage = 2
"stage": 2,
}
return strategy return strategy
...@@ -76,34 +73,17 @@ class TestGradientClipByGlobalNorm(unittest.TestCase): ...@@ -76,34 +73,17 @@ class TestGradientClipByGlobalNorm(unittest.TestCase):
paddle.seed(2022) paddle.seed(2022)
np.random.seed(2022) np.random.seed(2022)
random.seed(2022) random.seed(2022)
engine.mode = "train" place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor.run(engine.startup_program) engine._executor = paddle.static.Executor(place)
def get_dp2_engine(self): def get_engine(self, use_sharding=False):
reset_prog() reset_prog()
strategy = apply_pass() strategy = apply_pass(use_sharding)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("dp") model, loss = generate_model("dp")
inputs_spec, labels_spec = create_data_holder(self.batch_size) engine = auto.Engine(model, loss, opt, strategy=strategy)
engine = Engine(model, inputs_spec, labels_spec, strategy=strategy)
engine.prepare(optimizer=opt, loss=loss)
self.init(engine)
return engine
def get_dp2sharding2_engine(self):
reset_prog()
strategy = apply_pass(True)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("dp")
inputs_spec, labels_spec = create_data_holder(self.batch_size)
engine = Engine(model, inputs_spec, labels_spec, strategy=strategy)
engine.prepare(optimizer=opt, loss=loss)
self.init(engine) self.init(engine)
return engine return engine
...@@ -121,15 +101,13 @@ class TestGradientClipByGlobalNorm(unittest.TestCase): ...@@ -121,15 +101,13 @@ class TestGradientClipByGlobalNorm(unittest.TestCase):
def test_grad_clip(self): def test_grad_clip(self):
# dp2 training # dp2 training
dp_engine = self.get_dp2_engine() dp_engine = self.get_engine()
dp_engine.fit(self.dataset, batch_size=self.batch_size, use_cache=True) dp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
dp_param_values = get_parameter_value(dp_engine.main_program) dp_param_values = get_parameter_value(dp_engine.main_program)
# dp2sharding2 training # dp2sharding2 training
sharding_engine = self.get_dp2sharding2_engine() sharding_engine = self.get_engine(True)
sharding_engine.fit(self.dataset, sharding_engine.fit(self.dataset, 3, batch_size=self.batch_size)
batch_size=self.batch_size,
use_cache=True)
sharding_param_values = get_parameter_value( sharding_param_values = get_parameter_value(
sharding_engine.main_program) sharding_engine.main_program)
......
...@@ -27,10 +27,8 @@ import paddle.nn.functional as F ...@@ -27,10 +27,8 @@ import paddle.nn.functional as F
import paddle.utils as utils import paddle.utils as utils
from paddle.fluid import layers from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.static import InputSpec
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
from paddle.optimizer.lr import CosineAnnealingDecay from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.fluid.dataloader.collate import default_collate_fn from paddle.fluid.dataloader.collate import default_collate_fn
...@@ -47,6 +45,8 @@ class_num = 10 ...@@ -47,6 +45,8 @@ class_num = 10
paddle.seed(44) paddle.seed(44)
is_fetch = True
class MyDataset(Dataset): class MyDataset(Dataset):
...@@ -90,19 +90,20 @@ class MLPLayer(nn.Layer): ...@@ -90,19 +90,20 @@ class MLPLayer(nn.Layer):
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input): def forward(self, input):
out = auto.shard_op(self.norm, dist_attr={"process_mesh": out = auto.shard_op(self.norm, PP_MESH_0)(input)
PP_MESH_0})(input)
out = self.linear0(out) out = self.linear0(out)
out = F.gelu(out, approximate=True) out = F.gelu(out, approximate=True)
out = auto.shard_op(self.linear1, dist_attr={"process_mesh": out = auto.shard_op(self.linear1, PP_MESH_1)(out)
PP_MESH_1})(out)
out = self.dropout(out) out = self.dropout(out)
out = self.linear2(out) out = self.linear2(out)
self.out = out if is_fetch:
auto.fetch(out, "out")
return out return out
def train(fetch): def train(fetch):
global is_fetch
is_fetch = fetch
mlp = MLPLayer(hidden_size=hidden_size, mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size, intermediate_size=4 * hidden_size,
dropout_ratio=0.1, dropout_ratio=0.1,
...@@ -113,46 +114,34 @@ def train(fetch): ...@@ -113,46 +114,34 @@ def train(fetch):
beta2=0.999, beta2=0.999,
epsilon=1e-08, epsilon=1e-08,
grad_clip=None) grad_clip=None)
metric = paddle.metric.Accuracy()
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x') strategy = auto.Strategy()
labels_spec = InputSpec([batch_size], 'int64', 'label') strategy.auto_mode = "semi"
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
# init engine
engine = Engine(mlp,
inputs_spec=inputs_spec,
labels_spec=labels_spec,
strategy=dist_strategy)
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
# fetch engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy)
if fetch:
fetches = {'out': mlp.out}
else:
fetches = None
# train # train
train_dataset = MyDataset(batch_num * batch_size) train_dataset = MyDataset(batch_num * batch_size)
engine.fit(train_dataset, eval_dataset1 = MyDataset(5 * batch_size)
engine.fit(train_data=train_dataset,
epochs=2,
batch_size=batch_size, batch_size=batch_size,
steps_per_epoch=batch_num * batch_size, valid_data=eval_dataset1)
fetches=fetches)
# eval # eval
eval_dataset = MyDataset(batch_size) eval_dataset2 = MyDataset(batch_size)
engine.evaluate(eval_dataset, batch_size, fetches=fetches) engine.evaluate(eval_dataset2, batch_size=batch_size)
# predict # predict
test_dataset = MyDataset(batch_size) test_dataset = MyDataset(batch_size)
engine.predict(test_dataset, batch_size, fetches=fetches) engine.predict(test_dataset, batch_size=batch_size)
# save # save
temp_dir = tempfile.TemporaryDirectory() temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp_inf') model_filename = os.path.join(temp_dir.name, 'mlp')
engine.save(model_filename, training=False, mode='predict') engine.save(model_filename, training=True)
engine.load(model_filename)
temp_dir.cleanup() temp_dir.cleanup()
......
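The change above collapses the old `Engine(model, inputs_spec, labels_spec)` constructor plus a separate `prepare(optimizer, loss, metrics=...)` call into a single `auto.Engine(model, loss, optimizer, metric, strategy=...)` with Keras-like `fit`/`evaluate`/`predict`/`save`/`load` methods. A condensed, self-contained sketch of that workflow; the toy model, random dataset, save path, and hyper-parameters are illustrative assumptions, not part of the test above:

# Self-contained sketch of the new Engine workflow; model/dataset/paths are
# placeholders, the Engine/Strategy calls follow the new API in this diff.
import numpy as np
import paddle
import paddle.distributed.auto_parallel as auto

paddle.enable_static()

class RandomDataset(paddle.io.Dataset):
    def __init__(self, num_samples, feature_dim=64, class_num=10):
        self.num_samples = num_samples
        self.feature_dim = feature_dim
        self.class_num = class_num

    def __getitem__(self, idx):
        x = np.random.uniform(size=self.feature_dim).astype("float32")
        y = np.random.randint(0, self.class_num - 1, dtype="int64")
        return x, y

    def __len__(self):
        return self.num_samples

model = paddle.nn.Sequential(paddle.nn.Linear(64, 10))
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=1e-5)
metric = paddle.metric.Accuracy()

strategy = auto.Strategy()
strategy.auto_mode = "semi"

engine = auto.Engine(model, loss, optimizer, metric, strategy=strategy)
engine.fit(train_data=RandomDataset(64), epochs=2, batch_size=8,
           valid_data=RandomDataset(16))
engine.evaluate(RandomDataset(16), batch_size=8)
engine.predict(RandomDataset(16), batch_size=8)
engine.save("mlp", training=True)   # "mlp" is a placeholder path; training=True keeps a reloadable checkpoint
engine.load("mlp")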
...@@ -26,11 +26,9 @@ import paddle.static as static ...@@ -26,11 +26,9 @@ import paddle.static as static
import paddle.nn.functional as F import paddle.nn.functional as F
import paddle.utils as utils import paddle.utils as utils
from paddle.fluid import layers from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader from paddle.io import Dataset, DataLoader
from paddle.static import InputSpec
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
paddle.enable_static() paddle.enable_static()
batch_size = 2 batch_size = 2
...@@ -91,6 +89,7 @@ class MLPLayer(nn.Layer): ...@@ -91,6 +89,7 @@ class MLPLayer(nn.Layer):
out = self.linear1(out) out = self.linear1(out)
out = self.dropout(out) out = self.dropout(out)
out = self.linear2(out) out = self.linear2(out)
auto.fetch(out, "out")
self.out = out self.out = out
return out return out
...@@ -107,46 +106,32 @@ def train(fetch): ...@@ -107,46 +106,32 @@ def train(fetch):
epsilon=1e-08, epsilon=1e-08,
grad_clip=None) grad_clip=None)
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x') dist_strategy = auto.Strategy()
labels_spec = InputSpec([batch_size], 'int64', 'label') dist_strategy.auto_mode = "semi"
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False
dist_strategy.pipeline = False
dist_strategy.recompute = False
# init parallel optimizer
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
# init engine # init engine
engine = Engine(mlp, engine = auto.Engine(mlp,
inputs_spec=inputs_spec, loss,
labels_spec=labels_spec, optimizer,
paddle.metric.Accuracy(),
strategy=dist_strategy) strategy=dist_strategy)
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
# fetch
if fetch:
fetches = {'out': mlp.out}
else:
fetches = None
# train # train
train_dataset = MyDataset(batch_num * batch_size) train_dataset = MyDataset(batch_num * batch_size)
engine.fit(train_dataset, batch_size=batch_size, fetches=fetches) engine.fit(train_dataset, batch_size=batch_size)
# eval # eval
eval_dataset = MyDataset(batch_size) eval_dataset = MyDataset(batch_size)
engine.evaluate(eval_dataset, batch_size, fetches=fetches) engine.evaluate(eval_dataset, batch_size=batch_size)
# predict # predict
test_dataset = MyDataset(batch_size) test_dataset = MyDataset(batch_size)
engine.predict(test_dataset, batch_size, fetches=fetches) engine.predict(test_dataset, batch_size=batch_size)
# save # save
temp_dir = tempfile.TemporaryDirectory() temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp_inf') model_filename = os.path.join(temp_dir.name, 'mlp_inf')
engine.save(model_filename, training=False, mode='predict') engine.save(model_filename, training=False)
temp_dir.cleanup() temp_dir.cleanup()
......
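The `fetches={'out': mlp.out}` keyword that used to be threaded through `fit`/`evaluate`/`predict` is gone in the hunks above; the forward pass now registers the tensor once with `auto.fetch(out, "out")`. A hedged sketch of that registration inside a layer, assuming (based only on this diff) that `auto.fetch` tags the tensor by name for later fetching:

# Hedged sketch: auto.fetch(tensor, name) replaces the old
# fetches={'name': tensor} argument of fit/evaluate/predict.
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.distributed.auto_parallel as auto

class TinyMLP(nn.Layer):
    def __init__(self, hidden=64):
        super().__init__()
        self.linear = nn.Linear(hidden, hidden)

    def forward(self, x):
        out = F.gelu(self.linear(x), approximate=True)
        auto.fetch(out, "out")   # tag the intermediate tensor by name
        return out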
...@@ -14,8 +14,10 @@ ...@@ -14,8 +14,10 @@
import sys import sys
import numpy as np import numpy as np
import random
import paddle import paddle
import paddle.distributed.auto_parallel as auto
sys.path.append("..") sys.path.append("..")
import auto_parallel_gpt_model as modeling import auto_parallel_gpt_model as modeling
...@@ -25,7 +27,7 @@ sequence_len = 512 ...@@ -25,7 +27,7 @@ sequence_len = 512
vocab_size = 1000 vocab_size = 1000
class FakeDataset: class FakeDataset(paddle.io.Dataset):
def __init__(self, num_samples): def __init__(self, num_samples):
self.num_samples = num_samples self.num_samples = num_samples
...@@ -33,6 +35,9 @@ class FakeDataset: ...@@ -33,6 +35,9 @@ class FakeDataset:
self.vocab_size = vocab_size self.vocab_size = vocab_size
def __getitem__(self, idx): def __getitem__(self, idx):
paddle.seed(2021)
np.random.seed(2021)
random.seed(2021)
tokens = np.random.randint(self.vocab_size, size=self.sequence_len) tokens = np.random.randint(self.vocab_size, size=self.sequence_len)
position_ids = np.arange(self.sequence_len) position_ids = np.arange(self.sequence_len)
attention_mask = np.tril(np.ones(self.sequence_len)).reshape( attention_mask = np.tril(np.ones(self.sequence_len)).reshape(
...@@ -67,8 +72,9 @@ def create_data_holder(batch_size): ...@@ -67,8 +72,9 @@ def create_data_holder(batch_size):
def generate_model(strategy): def generate_model(strategy):
modeling.init_global() modeling.init_global()
modeling._global_process_mesh = list( ranks = list(range(paddle.distributed.get_world_size()))
range(paddle.distributed.get_world_size())) modeling._global_process_mesh = auto.ProcessMesh(mesh=ranks,
dim_names=["x"])
if strategy == "serial": if strategy == "serial":
modeling._global_parallel_strategy = "serial" modeling._global_parallel_strategy = "serial"
elif strategy == "mp": elif strategy == "mp":
......
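In `generate_model`, the global process mesh is now an `auto.ProcessMesh` with named dimensions instead of a bare rank list. A small sketch of the two mesh shapes that appear in these tests; the 4-rank world size is an assumption for illustration:

# Named-dimension meshes; a 4-rank world is assumed here.
import paddle.distributed.auto_parallel as auto

ranks = list(range(4))                                              # e.g. paddle.distributed.get_world_size() == 4
mesh_1d = auto.ProcessMesh(mesh=ranks, dim_names=["x"])             # flat mesh with one named axis
mesh_2d = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])  # 2x2 mesh with two named axes

The dimension names are what the `shard_tensor` and `shard_op` placements in the later hunks refer to.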
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
import random
import numpy as np
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
paddle.enable_static()
def apply_pass(use_gradient_merge=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_gradient_merge:
gradient_merge = strategy.gradient_merge
gradient_merge.enable = True
gradient_merge.k_steps = 4
gradient_merge.avg = True
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestGradientMergePass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-5
self.atol = 1e-8
self.batch_size = 8
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2021)
np.random.seed(2021)
random.seed(2021)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_gradient_merge=False):
reset_prog()
strategy = apply_pass(use_gradient_merge)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("dp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses))
def test_gradient_merge_pass(self):
# dp2 training
dp_engine = self.get_engine()
dp_losses = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
dp_losses = np.array(dp_losses["loss"])
# dp2 gradient merge training
gm_engine = self.get_engine(True)
gm_losses = gm_engine.fit(self.dataset, 3, batch_size=self.batch_size)
gm_losses = np.array(gm_losses["loss"])
avg_loss = 0
pass_avg_ret_list = []
for i, pass_ret in enumerate(gm_losses):
if (i + 1) % 4 == 0:
avg_loss += pass_ret
pass_avg_ret_list.append(avg_loss / 4)
avg_loss = 0
else:
avg_loss += pass_ret
self.check_results(dp_losses, np.array(pass_avg_ret_list))
if __name__ == "__main__":
unittest.main()
...@@ -17,11 +17,7 @@ import paddle ...@@ -17,11 +17,7 @@ import paddle
import unittest import unittest
import numpy as np import numpy as np
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.static import InputSpec
from paddle.distributed import fleet
from paddle.incubate.autograd import Hessian from paddle.incubate.autograd import Hessian
from paddle.distributed.auto_parallel.engine import Engine
np.random.seed(1234) np.random.seed(1234)
paddle.seed(1234) paddle.seed(1234)
...@@ -87,7 +83,7 @@ class LaplaceModel(paddle.nn.Layer): ...@@ -87,7 +83,7 @@ class LaplaceModel(paddle.nn.Layer):
return eq_loss, bc_u return eq_loss, bc_u
class LaplaceDataset: class LaplaceDataset(paddle.io.Dataset):
def __init__(self, num_sample): def __init__(self, num_sample):
self.num_sample = num_sample self.num_sample = num_sample
...@@ -129,23 +125,14 @@ def main(): ...@@ -129,23 +125,14 @@ def main():
# model # model
laplace = LaplaceModel() laplace = LaplaceModel()
# spec dist_strategy = auto.Strategy()
inputs_spec = [ dist_strategy.auto_mode = "semi"
InputSpec([100, 2], 'float32', 'x'),
InputSpec([36], 'int64', 'bc_idx')
]
labels_spec = InputSpec([36, 1], 'float32', 'bc_v')
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
engine = Engine(laplace, engine = auto.Engine(laplace,
inputs_spec=inputs_spec, loss=loss_func,
labels_spec=labels_spec, optimizer=optimizer,
strategy=dist_strategy) strategy=dist_strategy)
engine.prepare(optimizer=optimizer, loss=loss_func) engine.fit(train_dataset, train_sample_split=2, batch_size=None)
engine.fit(train_dataset, batch_size=None)
dist_context = engine.dist_context dist_context = engine.dist_context
block = engine.main_program.global_block() block = engine.main_program.global_block()
......
...@@ -28,9 +28,8 @@ import paddle.utils as utils ...@@ -28,9 +28,8 @@ import paddle.utils as utils
from paddle.fluid import layers from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.static import InputSpec from paddle.static import InputSpec
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
from paddle.optimizer.lr import CosineAnnealingDecay from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.fluid.dataloader.collate import default_collate_fn from paddle.fluid.dataloader.collate import default_collate_fn
...@@ -48,10 +47,9 @@ class_num = 10 ...@@ -48,10 +47,9 @@ class_num = 10
paddle.seed(44) paddle.seed(44)
class MyDataset(IterableDataset): class MyDataset(paddle.io.IterableDataset):
def __init__(self, num_samples): def __init__(self, num_samples):
super(MyDataset, self).__init__()
self.num_samples = num_samples self.num_samples = num_samples
def __iter__(self): def __iter__(self):
...@@ -61,10 +59,9 @@ class MyDataset(IterableDataset): ...@@ -61,10 +59,9 @@ class MyDataset(IterableDataset):
yield input, label yield input, label
class MyDataset1(Dataset): class MyDataset1(paddle.io.Dataset):
def __init__(self, num_samples): def __init__(self, num_samples):
super(MyDataset1, self).__init__()
self.num_samples = num_samples self.num_samples = num_samples
self.data = [] self.data = []
for i in range(self.num_samples): for i in range(self.num_samples):
...@@ -112,12 +109,10 @@ class MLPLayer(nn.Layer): ...@@ -112,12 +109,10 @@ class MLPLayer(nn.Layer):
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input): def forward(self, input):
out = auto.shard_op(self.norm, dist_attr={"process_mesh": out = auto.shard_op(self.norm, PP_MESH_0)(input)
PP_MESH_0})(input)
out = self.linear0(out) out = self.linear0(out)
out = F.gelu(out, approximate=True) out = F.gelu(out, approximate=True)
out = auto.shard_op(self.linear1, dist_attr={"process_mesh": out = auto.shard_op(self.linear1, PP_MESH_1)(out)
PP_MESH_1})(out)
out = self.dropout(out) out = self.dropout(out)
out = self.linear2(out) out = self.linear2(out)
self.out = out self.out = out
...@@ -136,54 +131,36 @@ def train(fetch): ...@@ -136,54 +131,36 @@ def train(fetch):
epsilon=1e-08, epsilon=1e-08,
grad_clip=None) grad_clip=None)
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x') dist_strategy = auto.Strategy()
labels_spec = InputSpec([batch_size], 'int64', 'label') dist_strategy.auto_mode = "semi"
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
dist_strategy.split_data = True dist_strategy.split_data = True
fleet.init(is_collective=True, strategy=dist_strategy)
# init engine # init engine
engine = Engine(mlp, engine = auto.Engine(mlp,
inputs_spec=inputs_spec, loss,
labels_spec=labels_spec, optimizer,
paddle.metric.Accuracy(),
strategy=dist_strategy) strategy=dist_strategy)
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
# fetch
if fetch:
fetches = {'out': mlp.out}
else:
fetches = None
# train # train
train_dataset = MyDataset(batch_num * batch_size) train_dataset = MyDataset(batch_num * batch_size)
train_dataset1 = MyDataset1(batch_num) engine.fit(train_dataset, epochs=2, batch_size=batch_size)
engine.fit(train_dataset,
epochs=2, train_dataset1 = MyDataset1(batch_size * batch_num)
batch_size=batch_size, engine.fit(train_dataset1, epochs=2, batch_size=None)
steps_per_epoch=batch_num,
fetches=fetches)
engine.fit(train_dataset1,
epochs=2,
batch_size=None,
steps_per_epoch=batch_num,
fetches=fetches)
# eval # eval
eval_dataset = MyDataset(batch_size) eval_dataset = MyDataset(batch_size)
engine.evaluate(eval_dataset, batch_size, fetches=fetches) engine.evaluate(eval_dataset, batch_size=batch_size)
# predict # predict
test_dataset = MyDataset(batch_size) test_dataset = MyDataset(batch_size)
engine.predict(test_dataset, batch_size, fetches=fetches) engine.predict(test_dataset, batch_size=batch_size)
# save # save
temp_dir = tempfile.TemporaryDirectory() temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp_inf') model_filename = os.path.join(temp_dir.name, 'mlp_inf')
engine.save(model_filename, training=False, mode='predict') engine.save(model_filename, training=False)
temp_dir.cleanup() temp_dir.cleanup()
......
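`auto.shard_op` now takes the process mesh directly instead of a `dist_attr` dict: wrapping a callable and then calling the wrapper binds the resulting ops to that mesh, which is how the forward pass above pins `norm` and `linear1` to the two pipeline meshes. A hedged, self-contained sketch of the pattern; the two single-rank meshes and the toy layer are assumptions:

# Sketch of the new shard_op call style from the forward() above.
import paddle.nn as nn
import paddle.distributed.auto_parallel as auto

PP_MESH_0 = auto.ProcessMesh([0], dim_names=["x"])
PP_MESH_1 = auto.ProcessMesh([1], dim_names=["x"])

class TwoStage(nn.Layer):
    def __init__(self, hidden=64):
        super().__init__()
        self.norm = nn.LayerNorm(hidden)
        self.linear = nn.Linear(hidden, hidden)

    def forward(self, x):
        # shard_op(callable, mesh) returns a wrapper; calling the wrapper places
        # the resulting ops on that mesh (old form passed dist_attr={"process_mesh": ...}).
        out = auto.shard_op(self.norm, PP_MESH_0)(x)
        out = auto.shard_op(self.linear, PP_MESH_1)(out)
        return out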
...@@ -27,10 +27,8 @@ import paddle.nn.functional as F ...@@ -27,10 +27,8 @@ import paddle.nn.functional as F
import paddle.utils as utils import paddle.utils as utils
from paddle.fluid import layers from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.static import InputSpec
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
from engine_api_dp import MyDataset from engine_api_dp import MyDataset
paddle.enable_static() paddle.enable_static()
...@@ -43,20 +41,6 @@ class_num = 10 ...@@ -43,20 +41,6 @@ class_num = 10
paddle.seed(44) paddle.seed(44)
# class MyDataset(Dataset):
# def __init__(self, num_samples):
# super(MyDataset, self).__init__()
# self.num_samples = num_samples
# def __getitem__(self, index):
# input = np.random.uniform(size=image_size).astype("float32")
# label = np.random.randint(0, class_num - 1, dtype="int64")
# return input, label
# def __len__(self):
# return self.num_samples
class MLPLayer(nn.Layer): class MLPLayer(nn.Layer):
...@@ -107,50 +91,33 @@ def train(fetch): ...@@ -107,50 +91,33 @@ def train(fetch):
epsilon=1e-08, epsilon=1e-08,
grad_clip=None) grad_clip=None)
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x') dist_strategy = auto.Strategy()
labels_spec = InputSpec([batch_size], 'int64', 'label') dist_strategy.auto_mode = "semi"
# sharding config
dist_strategy = fleet.DistributedStrategy() sharding = dist_strategy.sharding
dist_strategy.amp = False sharding.enable = True
dist_strategy.pipeline = False sharding.degree = 2
dist_strategy.recompute = False sharding.stage = 3
# init parallel optimizer sharding.enable_tuning = True
dist_strategy.semi_auto = True sharding.tuning_range = [0, 1, 2, 3]
dist_strategy.sharding = True
dist_strategy.sharding_configs = {
"sharding_degree": 2,
"stage": 3,
"enable_tuning": True,
}
fleet.init(is_collective=True, strategy=dist_strategy)
# init engine
import tempfile
tmp_dir = tempfile.TemporaryDirectory()
dataset = MyDataset(batch_num * batch_size)
# Tuning configuration # Tuning configuration
tuning_config = { tuning = dist_strategy.tuning
"batch_size": batch_size, tuning.enable = True
"dataset": dataset, tuning.profile_start_step = 1
"profile_start_step": 1, tuning.profile_end_step = 5
"profile_end_step": 5, tuning.run_after_tuning = True
"run_after_tuning": True, tuning.verbose = True
"sharding": {
"stage_range": [0, 1, 2, 3] dataset = MyDataset(batch_num * batch_size)
}, engine = auto.Engine(mlp,
"verbose": True, loss,
} optimizer,
engine = Engine(mlp, paddle.metric.Accuracy(),
inputs_spec=inputs_spec, strategy=dist_strategy)
labels_spec=labels_spec, engine._tune(dataset, batch_size=batch_size)
strategy=dist_strategy,
user_tuning_config=tuning_config)
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
# check tuned # check tuned
assert (engine._dist_contexts['train'].strategy.sharding_configs['stage'] != assert (engine._dist_contexts['train'].strategy.sharding.stage != 3)
3)
if __name__ == "__main__": if __name__ == "__main__":
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
import random
import numpy as np
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
def apply_pass(use_recompute=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_recompute:
recompute = strategy.recompute
recompute.enable = True
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestRecomputePass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-6
self.atol = 1e-8
self.batch_size = 1
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_recompute=False):
reset_prog()
strategy = apply_pass(use_recompute)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("mp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses))
def test_recompute_pass(self):
# mp2 training
mp_engine = self.get_engine()
mp_losses = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(mp_losses["loss"])
# mp2 recompute training
rc_engine = self.get_engine(True)
rc_losses = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size)
rc_losses = np.array(rc_losses["loss"])
self.check_results(mp_losses, rc_losses)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
import random
import numpy as np
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
paddle.enable_static()
def apply_pass(use_sharding=False, stage=None):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_sharding:
sharding = strategy.sharding
sharding.enable = True
sharding.degree = 2
sharding.stage = stage
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestShardingPass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-6
self.atol = 1e-8
self.batch_size = 2
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_sharding=False, stage=None):
reset_prog()
strategy = apply_pass(use_sharding, stage)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("dp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses))
def test_sharding_pass(self):
# dp2 training
dp_engine = self.get_engine()
dp_losses = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
dp_losses = np.array(dp_losses["loss"])
# sharding2 stage1 training
sharding1_engine = self.get_engine(True, 1)
sharding1_losses = sharding1_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
sharding1_losses = np.array(sharding1_losses["loss"])
self.check_results(dp_losses, sharding1_losses)
# sharding2 stage2 training
sharding2_engine = self.get_engine(True, 2)
sharding2_losses = sharding2_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
sharding2_losses = np.array(sharding2_losses["loss"])
self.check_results(dp_losses, sharding2_losses)
# sharding2 stage3 training
sharding3_engine = self.get_engine(True, 3)
sharding3_losses = sharding3_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
sharding3_losses = np.array(sharding3_losses["loss"])
self.check_results(dp_losses, sharding3_losses)
if __name__ == "__main__":
unittest.main()
...@@ -45,9 +45,10 @@ from test_cluster import cluster_json ...@@ -45,9 +45,10 @@ from test_cluster import cluster_json
paddle.enable_static() paddle.enable_static()
_global_parallel_strategy = "dp_mp_pp" _global_parallel_strategy = "dp_mp_pp"
_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]]) _global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]],
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]]) dim_names=["x", "y", "z"])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]]) PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], dim_names=["x", "y"])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], dim_names=["x", "y"])
class MLPLayer(nn.Layer): class MLPLayer(nn.Layer):
...@@ -74,16 +75,8 @@ class MLPLayer(nn.Layer): ...@@ -74,16 +75,8 @@ class MLPLayer(nn.Layer):
self.norm = nn.LayerNorm(d_model, epsilon=1e-5) self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, input): def forward(self, input):
auto.shard_tensor(self.linear0.weight, auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, "y"])
dist_attr={ auto.shard_tensor(self.linear1.weight, PP_MESH_1, ["y", None])
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [1, -1]
})
out = self.norm(input) out = self.norm(input)
out = self.linear0(out) out = self.linear0(out)
...@@ -111,16 +104,8 @@ def mlp_forward(train_program, start_program): ...@@ -111,16 +104,8 @@ def mlp_forward(train_program, start_program):
embedding = paddle.nn.Embedding(10, hidden_size, sparse=True) embedding = paddle.nn.Embedding(10, hidden_size, sparse=True)
embedding_out = embedding(fill_constant_out) embedding_out = embedding(fill_constant_out)
auto.shard_tensor(input, auto.shard_tensor(input, PP_MESH_0, ["x", None])
dist_attr={ auto.shard_tensor(label, PP_MESH_1, ["x", None])
"process_mesh": PP_MESH_0,
"dims_mapping": [0, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]
})
mlp = MLPLayer(hidden_size=hidden_size, mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size, intermediate_size=4 * hidden_size,
......
...@@ -34,7 +34,10 @@ paddle.enable_static() ...@@ -34,7 +34,10 @@ paddle.enable_static()
batch_size = 4 batch_size = 4
hidden_size = 1024 hidden_size = 1024
sequence_len = 512 sequence_len = 512
_g_process_mesh = [[0, 1], [2, 3]] _g_process_mesh = [
auto.ProcessMesh([0, 1], dim_names=["x"]),
auto.ProcessMesh([2, 3], dim_names=["x"])
]
def get_random_inputs_and_labels(input_shape, label_shape): def get_random_inputs_and_labels(input_shape, label_shape):
...@@ -82,18 +85,10 @@ class MLPLayer(nn.Layer): ...@@ -82,18 +85,10 @@ class MLPLayer(nn.Layer):
def forward(self, input): def forward(self, input):
out = self.norm(input) out = self.norm(input)
auto.shard_tensor(self.linear0.weight, auto.shard_tensor(self.linear0.weight, _g_process_mesh[0], [None, "x"])
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [-1, 0]
})
out = self.linear0(out) out = self.linear0(out)
out = F.gelu(out, approximate=True) out = F.gelu(out, approximate=True)
auto.shard_tensor(self.linear1.weight, auto.shard_tensor(self.linear1.weight, _g_process_mesh[1], ["x", None])
dist_attr={
"process_mesh": _g_process_mesh[1],
"dims_mapping": [0, -1]
})
out = self.linear1(out) out = self.linear1(out)
return out return out
...@@ -123,16 +118,8 @@ def get_program(): ...@@ -123,16 +118,8 @@ def get_program():
dataloader.set_batch_generator(batch_generator_creator(), dataloader.set_batch_generator(batch_generator_creator(),
places=paddle.static.cuda_places()) places=paddle.static.cuda_places())
# data dist_attr # data dist_attr
auto.shard_tensor(input, auto.shard_tensor(input, _g_process_mesh[0], ["x", None, None])
dist_attr={ auto.shard_tensor(label, _g_process_mesh[0], ["x", None, None])
"process_mesh": _g_process_mesh[0],
"dims_mapping": [0, -1, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [0, -1, -1]
})
mlp_start = MLPLayer(hidden_size=hidden_size, mlp_start = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size, intermediate_size=4 * hidden_size,
......
...@@ -42,19 +42,13 @@ def make_program_lookup_table_v1_mp_dp(): ...@@ -42,19 +42,13 @@ def make_program_lookup_table_v1_mp_dp():
is_sparse=False) is_sparse=False)
loss = paddle.fluid.layers.reduce_mean(emb_out) loss = paddle.fluid.layers.reduce_mean(emb_out)
auto.shard_tensor(src_ids, auto.shard_tensor(
dist_attr={ src_ids, auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]),
"process_mesh": auto.ProcessMesh([[0, 1], [2, ["x", None, None])
3]]),
"dims_mapping": [0, -1, -1]
})
emb_weight = block.vars["emb_weight"] emb_weight = block.vars["emb_weight"]
auto.shard_tensor(emb_weight, auto.shard_tensor(
dist_attr={ emb_weight, auto.ProcessMesh([[0, 1], [2, 3]],
"process_mesh": auto.ProcessMesh([[0, 1], [2, dim_names=["x", "y"]), ["y", None])
3]]),
"dims_mapping": [1, -1]
})
return main_program, start_program, loss return main_program, start_program, loss
......
...@@ -22,82 +22,58 @@ from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr ...@@ -22,82 +22,58 @@ from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static() paddle.enable_static()
mesh = [[0, 1], [2, 3]] mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
def init_x_row(trans_x): def init_x_row(trans_x):
if trans_x: if trans_x:
x = paddle.static.data(name='x', shape=[10, 6, 8], dtype='float32') x = paddle.static.data(name='x', shape=[10, 6, 8], dtype='float32')
auto.shard_tensor(x, auto.shard_tensor(x, mesh, ["x", "y", None])
dist_attr={
"process_mesh": mesh,
"dims_mapping": [0, 1, -1]
})
return x return x
else: else:
x = paddle.static.data(name='x', shape=[10, 8, 6], dtype='float32') x = paddle.static.data(name='x', shape=[10, 8, 6], dtype='float32')
auto.shard_tensor(x, auto.shard_tensor(x, mesh, ["x", None, "y"])
dist_attr={
"process_mesh": mesh,
"dims_mapping": [0, -1, 1]
})
return x return x
def init_x_col(trans_x): def init_x_col(trans_x):
if trans_x: if trans_x:
x = paddle.static.data(name='x', shape=[6, 8], dtype='float32') x = paddle.static.data(name='x', shape=[6, 8], dtype='float32')
auto.shard_tensor(x, auto.shard_tensor(x, mesh, [None, "x"])
dist_attr={
"process_mesh": mesh,
"dims_mapping": [-1, 0]
})
return x return x
else: else:
x = paddle.static.data(name='x', shape=[8, 6], dtype='float32') x = paddle.static.data(name='x', shape=[8, 6], dtype='float32')
auto.shard_tensor(x, auto.shard_tensor(x, mesh, ["x", None])
dist_attr={
"process_mesh": mesh,
"dims_mapping": [0, -1]
})
return x return x
def init_y_row(trans_y): def init_y_row(trans_y):
if trans_y: if trans_y:
y = paddle.static.data(name='y', shape=[4, 6], dtype='float32') y = paddle.static.data(name='y', shape=[4, 6], dtype='float32')
auto.shard_tensor(y, auto.shard_tensor(y, mesh, [None, "y"])
dist_attr={
"process_mesh": mesh,
"dims_mapping": [-1, 1]
})
return y return y
else: else:
y = paddle.static.data(name='y', shape=[6, 4], dtype='float32') y = paddle.static.data(name='y', shape=[6, 4], dtype='float32')
auto.shard_tensor(y, auto.shard_tensor(y, mesh, ["y", None])
dist_attr={
"process_mesh": mesh,
"dims_mapping": [1, -1]
})
return y return y
def init_y_col(trans_y): def init_y_col(trans_y):
if trans_y: if trans_y:
y = paddle.static.data(name='y', shape=[4, 6], dtype='float32') y = paddle.static.data(name='y', shape=[4, 6], dtype='float32')
auto.shard_tensor(y, auto.shard_tensor(y, mesh, ["y", None])
dist_attr={
"process_mesh": mesh,
"dims_mapping": [1, -1]
})
return y return y
else: else:
y = paddle.static.data(name='y', shape=[6, 4], dtype='float32') y = paddle.static.data(name='y', shape=[6, 4], dtype='float32')
auto.shard_tensor(y, auto.shard_tensor(y, mesh, [None, "y"])
dist_attr={
"process_mesh": mesh,
"dims_mapping": [-1, 1]
})
return y return y
......
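Throughout the file above, `shard_tensor`'s `dist_attr` dict is replaced by a positional mesh plus a placements list: each entry names the mesh dimension that the corresponding tensor axis is split along, and `None` marks a replicated axis (the old `-1`). A small sketch mapping two of the old `dims_mapping` values to the new form, using the same 2x2 "x"/"y" mesh defined above:

# Sketch of the new shard_tensor signature used throughout this file.
import paddle
import paddle.distributed.auto_parallel as auto

paddle.enable_static()
mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])

x = paddle.static.data(name='x', shape=[8, 6], dtype='float32')
# old: dist_attr={"process_mesh": mesh, "dims_mapping": [0, -1]}
auto.shard_tensor(x, mesh, ["x", None])   # split dim 0 along "x", replicate dim 1

y = paddle.static.data(name='y', shape=[6, 4], dtype='float32')
# old: dist_attr={"process_mesh": mesh, "dims_mapping": [-1, 1]}
auto.shard_tensor(y, mesh, [None, "y"])   # replicate dim 0, split dim 1 along "y"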
...@@ -71,11 +71,8 @@ class TestDistOpCost(unittest.TestCase): ...@@ -71,11 +71,8 @@ class TestDistOpCost(unittest.TestCase):
shape=[4, 1], shape=[4, 1],
dtype='float32') dtype='float32')
label.stop_gradient = True label.stop_gradient = True
auto.shard_tensor(x, auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
dist_attr={ ["x", None])
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
tmp = paddle.fluid.layers.fill_constant_batch_size_like( tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[2, 8], value=1, dtype='float32') input=x, shape=[2, 8], value=1, dtype='float32')
weight_attr = paddle.ParamAttr() weight_attr = paddle.ParamAttr()
...@@ -121,17 +118,12 @@ class TestDistOpCost(unittest.TestCase): ...@@ -121,17 +118,12 @@ class TestDistOpCost(unittest.TestCase):
shape=[8, 1], shape=[8, 1],
dtype='float32') dtype='float32')
label.stop_gradient = True label.stop_gradient = True
auto.shard_tensor(x, auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
dist_attr={ ["x"])
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0]
})
auto.shard_tensor(label, auto.shard_tensor(label,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), ["x", None])
"dims_mapping": [0, -1]
})
# embedding # embedding
tmp = paddle.fluid.layers.fill_constant_batch_size_like( tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[4], value=1, dtype='int32') input=x, shape=[4], value=1, dtype='int32')
...@@ -141,12 +133,9 @@ class TestDistOpCost(unittest.TestCase): ...@@ -141,12 +133,9 @@ class TestDistOpCost(unittest.TestCase):
for op in main_program.global_block().ops: for op in main_program.global_block().ops:
if op.type == "lookup_table_v2": if op.type == "lookup_table_v2":
W = main_program.global_block().vars[op.input("W")[0]] W = main_program.global_block().vars[op.input("W")[0]]
auto.shard_tensor(W, auto.shard_tensor(
dist_attr={ W, auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": ["x", None])
auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
out = paddle.fluid.layers.transpose(out, out = paddle.fluid.layers.transpose(out,
[1, 0]) # [8, 2] [-1, 0] [1, 0]) # [8, 2] [-1, 0]
...@@ -154,26 +143,20 @@ class TestDistOpCost(unittest.TestCase): ...@@ -154,26 +143,20 @@ class TestDistOpCost(unittest.TestCase):
param1 = paddle.fluid.layers.create_parameter( param1 = paddle.fluid.layers.create_parameter(
[4, 8], paddle.float32) # [2, 8] [0, -1] [4, 8], paddle.float32) # [2, 8] [0, -1]
auto.shard_tensor(param1, auto.shard_tensor(param1,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), ["x", None])
"dims_mapping": [0, -1]
})
param2 = paddle.fluid.layers.create_parameter( param2 = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 4] [-1, 0] [8, 8], paddle.float32) # [8, 4] [-1, 0]
auto.shard_tensor(param2, auto.shard_tensor(param2,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), [None, "x"])
"dims_mapping": [-1, 0]
})
out1 = paddle.fluid.layers.matmul(out, out1 = paddle.fluid.layers.matmul(out,
param1) # [8, 8] [-1, -1] param1) # [8, 8] [-1, -1]
tmp_param = paddle.fluid.layers.create_parameter( tmp_param = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 8] [-1, -1] [8, 8], paddle.float32) # [8, 8] [-1, -1]
auto.shard_tensor(param2, auto.shard_tensor(param2,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), [None, None])
"dims_mapping": [-1, -1]
})
tmp_out = paddle.fluid.layers.matmul(out1, tmp_param) tmp_out = paddle.fluid.layers.matmul(out1, tmp_param)
out2 = paddle.fluid.layers.matmul(tmp_out, out2 = paddle.fluid.layers.matmul(tmp_out,
param2) # [8, 4] [-1, 0] param2) # [8, 4] [-1, 0]
...@@ -227,17 +210,12 @@ class TestDistOpCost(unittest.TestCase): ...@@ -227,17 +210,12 @@ class TestDistOpCost(unittest.TestCase):
shape=[8, 1], shape=[8, 1],
dtype='float32') dtype='float32')
label.stop_gradient = True label.stop_gradient = True
auto.shard_tensor(x, auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
dist_attr={ ["x"])
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0]
})
auto.shard_tensor(label, auto.shard_tensor(label,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), ["x", None])
"dims_mapping": [0, -1]
})
# embedding # embedding
tmp = paddle.fluid.layers.fill_constant_batch_size_like( tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[4], value=1, dtype='int32') input=x, shape=[4], value=1, dtype='int32')
...@@ -247,12 +225,9 @@ class TestDistOpCost(unittest.TestCase): ...@@ -247,12 +225,9 @@ class TestDistOpCost(unittest.TestCase):
for op in main_program.global_block().ops: for op in main_program.global_block().ops:
if op.type == "lookup_table_v2": if op.type == "lookup_table_v2":
W = main_program.global_block().vars[op.input("W")[0]] W = main_program.global_block().vars[op.input("W")[0]]
auto.shard_tensor(W, auto.shard_tensor(
dist_attr={ W, auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": ["x", None])
auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
out = paddle.fluid.layers.transpose(out, out = paddle.fluid.layers.transpose(out,
[1, 0]) # [8, 2] [-1, 0] [1, 0]) # [8, 2] [-1, 0]
...@@ -260,25 +235,20 @@ class TestDistOpCost(unittest.TestCase): ...@@ -260,25 +235,20 @@ class TestDistOpCost(unittest.TestCase):
param1 = paddle.fluid.layers.create_parameter( param1 = paddle.fluid.layers.create_parameter(
[4, 8], paddle.float32) # [2, 8] [0, -1] [4, 8], paddle.float32) # [2, 8] [0, -1]
auto.shard_tensor(param1, auto.shard_tensor(param1,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), ["x", None])
"dims_mapping": [0, -1]
})
param2 = paddle.fluid.layers.create_parameter( param2 = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 4] [-1, 0] [8, 8], paddle.float32) # [8, 4] [-1, 0]
auto.shard_tensor(param2, auto.shard_tensor(param2,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), [None, "x"])
"dims_mapping": [-1, 0]
})
out1 = paddle.matmul(out, param1) # [8, 8] [-1, -1] out1 = paddle.matmul(out, param1) # [8, 8] [-1, -1]
tmp_param = paddle.fluid.layers.create_parameter( tmp_param = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 8] [-1, -1] [8, 8], paddle.float32) # [8, 8] [-1, -1]
auto.shard_tensor(param2, auto.shard_tensor(param2,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), [None, None])
"dims_mapping": [-1, -1]
})
tmp_out = paddle.matmul(out1, tmp_param) tmp_out = paddle.matmul(out1, tmp_param)
out2 = paddle.matmul(tmp_out, param2) # [8, 4] [-1, 0] out2 = paddle.matmul(tmp_out, param2) # [8, 4] [-1, 0]
...@@ -331,17 +301,11 @@ class TestDistOpCost(unittest.TestCase): ...@@ -331,17 +301,11 @@ class TestDistOpCost(unittest.TestCase):
shape=[8, 1], shape=[8, 1],
dtype='float32') dtype='float32')
label.stop_gradient = True label.stop_gradient = True
auto.shard_tensor(x, auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
dist_attr={ ["x"])
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0]
})
auto.shard_tensor(label, auto.shard_tensor(label,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), ["x", None])
"dims_mapping": [0, -1]
})
# embedding # embedding
tmp = paddle.fluid.layers.fill_constant_batch_size_like( tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[4], value=1, dtype='int32') input=x, shape=[4], value=1, dtype='int32')
...@@ -351,12 +315,9 @@ class TestDistOpCost(unittest.TestCase): ...@@ -351,12 +315,9 @@ class TestDistOpCost(unittest.TestCase):
for op in main_program.global_block().ops: for op in main_program.global_block().ops:
if op.type == "lookup_table_v2": if op.type == "lookup_table_v2":
W = main_program.global_block().vars[op.input("W")[0]] W = main_program.global_block().vars[op.input("W")[0]]
auto.shard_tensor(W, auto.shard_tensor(
dist_attr={ W, auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": ["x", None])
auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
out = paddle.fluid.layers.transpose(out, out = paddle.fluid.layers.transpose(out,
[1, 0]) # [8, 2] [-1, 0] [1, 0]) # [8, 2] [-1, 0]
...@@ -364,25 +325,21 @@ class TestDistOpCost(unittest.TestCase): ...@@ -364,25 +325,21 @@ class TestDistOpCost(unittest.TestCase):
param1 = paddle.fluid.layers.create_parameter( param1 = paddle.fluid.layers.create_parameter(
[4, 8], paddle.float32) # [2, 8] [0, -1] [4, 8], paddle.float32) # [2, 8] [0, -1]
auto.shard_tensor(param1, auto.shard_tensor(param1,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), ["x", None])
"dims_mapping": [0, -1]
})
param2 = paddle.fluid.layers.create_parameter( param2 = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 4] [-1, 0] [8, 8], paddle.float32) # [8, 4] [-1, 0]
auto.shard_tensor(param2, auto.shard_tensor(param2,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), [None, "x"])
"dims_mapping": [-1, 0]
})
out1 = paddle.fluid.layers.mul(out, param1) # [8, 8] [-1, -1] out1 = paddle.fluid.layers.mul(out, param1) # [8, 8] [-1, -1]
tmp_param = paddle.fluid.layers.create_parameter( tmp_param = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 8] [-1, -1] [8, 8], paddle.float32) # [8, 8] [-1, -1]
auto.shard_tensor(param2, auto.shard_tensor(param2,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), [None, None])
"dims_mapping": [-1, -1]
})
tmp_out = paddle.fluid.layers.mul(out1, tmp_param) tmp_out = paddle.fluid.layers.mul(out1, tmp_param)
out2 = paddle.fluid.layers.mul(tmp_out, out2 = paddle.fluid.layers.mul(tmp_out,
param2) # [8, 4] [-1, 0] param2) # [8, 4] [-1, 0]
......
...@@ -29,11 +29,8 @@ def make_program_dp2(): ...@@ -29,11 +29,8 @@ def make_program_dp2():
with paddle.static.program_guard(main_program, start_program): with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32') x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
x.stop_gradient = False x.stop_gradient = False
auto.shard_tensor(x, auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
dist_attr={ ["x", None, None])
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1, -1]
})
tmp_0 = paddle.norm(x, p=2) tmp_0 = paddle.norm(x, p=2)
return main_program, start_program, tmp_0 return main_program, start_program, tmp_0
...@@ -44,11 +41,8 @@ def make_program_serial(): ...@@ -44,11 +41,8 @@ def make_program_serial():
with paddle.static.program_guard(main_program, start_program): with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32') x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
x.stop_gradient = False x.stop_gradient = False
auto.shard_tensor(x, auto.shard_tensor(x, auto.ProcessMesh([0], dim_names=["x"]),
dist_attr={ [None, None, None])
"process_mesh": auto.ProcessMesh([0]),
"dims_mapping": [-1, -1, -1]
})
tmp_0 = paddle.norm(x, p=2) tmp_0 = paddle.norm(x, p=2)
return main_program, start_program, tmp_0 return main_program, start_program, tmp_0
......
...@@ -29,11 +29,9 @@ def make_program_dp2(): ...@@ -29,11 +29,9 @@ def make_program_dp2():
with paddle.static.program_guard(main_program, start_program): with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 4, 8], dtype='float32') x = paddle.static.data(name='x', shape=[4, 4, 8], dtype='float32')
x.stop_gradient = False x.stop_gradient = False
auto.shard_tensor(x, auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
dist_attr={ ["x", None, None])
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1, -1]
})
tmp_0 = paddle.reshape(x, shape=[0, 0, 4, 2]) tmp_0 = paddle.reshape(x, shape=[0, 0, 4, 2])
tmp_1 = paddle.reshape(tmp_0, shape=[0, 0, 8]) tmp_1 = paddle.reshape(tmp_0, shape=[0, 0, 8])
tmp_2 = tmp_1.reshape((tmp_1.shape[0], tmp_1.shape[1], -1)) tmp_2 = tmp_1.reshape((tmp_1.shape[0], tmp_1.shape[1], -1))
......
...@@ -71,7 +71,7 @@ class TestDistributedTensor(unittest.TestCase): ...@@ -71,7 +71,7 @@ class TestDistributedTensor(unittest.TestCase):
def test_new_local_tensor(self): def test_new_local_tensor(self):
test_auto_parallel_reshard._global_process_mesh = auto.ProcessMesh( test_auto_parallel_reshard._global_process_mesh = auto.ProcessMesh(
mesh=[0, 1]) mesh=[0, 1], dim_names=["x"])
test_auto_parallel_reshard._global_parallel_strategy = "dp" test_auto_parallel_reshard._global_parallel_strategy = "dp"
train_program = paddle.static.Program() train_program = paddle.static.Program()
startup_program = paddle.static.Program() startup_program = paddle.static.Program()
......