Unverified commit c5cc4278, authored by Yulong Ao, committed by GitHub

[Cherry-pick][Auto Parallel] Improve the APIs (#46164)

* [AutoParallel] adapt gradient merge pass (#45915)

* adapt gradient merge

* fix op_role

* fix strategy

* [Auto Parallel] Gradient Fuse Allreduce (#45643)

* bugfix (#45332)

* dist embedding support lookup table v1

* add unit test

* customize wait_comm

* group gradients

* bugfix

* update program

* [Auto Parallel] Improve the APIs (#45776)

* [Auto Parallel] Use c++ dist attr in the completion process

* [Auto Parallel] Add minor changes

* [Auto Parallel] Use c++ dist attr in the completion process

* [Auto Parallel] Add minor changes

* [Auto Parallel] Add the serialization process for dist attrs

* [Auto Parallel] Remove unnecessary comments

* [Auto Parallel] Fix some bugs

* [Auto Parallel] Fix the code style

* [Auto Parallel] Remove unnecessary impls

* [Auto Parallel] Fix the importing error

* [Auto Parallel] Fix the copy from bugs of op dist attr

* [Auto Parallel] Replace the use of constexpr if

* [Auto Parallel] Redesign the shard_tensor, shard_op and ProcessMesh

* [Auto Parallel] Change API of the completion unittest

* [Auto Parallel] Fix the bug when set_attr an int

* [Auto Parallel] Add the unittest for the serialization

* [Auto Parallel] Add some unit tests

* [Auto Parallel] Unify the strategy

* [Auto Parallel] Improve the engine api

* [Auto Parallel] Reset the changes made to the framework

* [Auto Parallel] Change the engine unittest

* [Auto Parallel] Update API of the completion and partitioner

* [Auto Parallel] Update unit tests using engine api

* update shard annotation

* [Auto Parallel] Remove the modifications of other modules

* [Auto Parallel] Add docs for APIs

* add new strategy

* [Auto Parallel] Replace the logger

* [Auto Parallel] Restore the test_program.py

* [Auto Parallel] Change the import rules

* [Auto Parallel] Add the examples for Engine

* [Auto Parallel] Do some minor changes

* [Auto Parallel] Remove yaml dependency

* [Auto Parallel] Fix the unittests

* add validation after training

* bug fix
Co-authored-by: zhaoyingli <zhaoyingli@baidu.com>
Co-authored-by: caozhou <caozhou@radi.ac.cn>
Co-authored-by: caozhou <48191911+Caozhou1995@users.noreply.github.com>

* [Auto Parallel] Bugfix allreduce fuse for MP (#46086)

* bugfix

* bugfix

* typos fixed

* update strategy (#46138)
Co-authored-by: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Co-authored-by: JZ-LIANG <jianzhongliang10@gmail.com>
Co-authored-by: zhaoyingli <zhaoyingli@baidu.com>
Co-authored-by: caozhou <caozhou@radi.ac.cn>
Co-authored-by: caozhou <48191911+Caozhou1995@users.noreply.github.com>
Parent 860f6077
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .interface import shard_tensor  # noqa: F401
-from .interface import shard_op  # noqa: F401
+from .strategy import Strategy
 from .process_mesh import ProcessMesh
-from .reshard import Resharder  # noqa: F401
-from .cost_model import estimate_cost
+from .engine import Engine
+from .interface import shard_tensor
+from .interface import shard_op
+from .interface import recompute
+from .interface import fetch
 __all__ = []
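For reference, the reworked package surface above means user code reaches the new API through a single namespace. A minimal usage sketch, assuming `paddle.distributed.auto_parallel` is imported as `auto` and that `Strategy` and `Engine` behave as documented in engine.py further below; the model and loss objects are illustrative:

    # Illustrative sketch of the new top-level API (not part of the diff).
    import paddle
    import paddle.distributed.auto_parallel as auto

    strategy = auto.Strategy()              # unified optimization config
    model = paddle.vision.models.LeNet()    # any paddle.nn.Layer
    loss = paddle.nn.CrossEntropyLoss()
    optimizer = paddle.optimizer.Adam(parameters=model.parameters())

    # Engine replaces the old prepare()/fit() two-step flow.
    engine = auto.Engine(model, loss, optimizer, strategy=strategy)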
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from collections import defaultdict
# _g_default_config[category][field] = default_value
_g_default_config = defaultdict(dict)
def get_category_default_config(category):
return _g_default_config[category]
def set_category_default_config(category, default_value):
_g_default_config[category] = default_value
def get_field_default_config(category, field):
return _g_default_config[category][field]
def set_field_default_config(category, field, default_value):
_g_default_config[category][field] = default_value
NOT_FOUND = "not_found"
#########################################
# base configuration
#########################################
BASE = "base"
set_field_default_config(BASE, "auto_mode", "semi")
set_field_default_config(BASE, "gradient_scale", True)
set_field_default_config(BASE, "use_cache", True)
set_field_default_config(BASE, "return_numpy", True)
set_field_default_config(BASE, "all_ranks", False)
set_field_default_config(BASE, "split_data", False)
set_field_default_config(BASE, "seed", None)
set_field_default_config(BASE, "reinit", False) # Only for debug
#########################################
# recompute configuration
#########################################
RECOMPUTE = "recompute"
set_field_default_config(RECOMPUTE, "enable", False)
set_field_default_config(RECOMPUTE, "checkpoints", None)
set_field_default_config(RECOMPUTE, "enable_tuning", False)
#########################################
# AMP configuration
#########################################
AMP = "amp"
set_field_default_config(AMP, "enable", False)
set_field_default_config(AMP, "init_loss_scaling", 32768.0)
set_field_default_config(AMP, "incr_every_n_steps", 1000)
set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2)
set_field_default_config(AMP, "incr_ratio", 2.0)
set_field_default_config(AMP, "decr_ratio", 0.8)
set_field_default_config(AMP, "use_dynamic_loss_scaling", True)
set_field_default_config(AMP, "custom_white_list", [])
set_field_default_config(AMP, "custom_black_list", [])
set_field_default_config(AMP, "custom_black_varnames", [])
set_field_default_config(AMP, "use_pure_fp16", False)
set_field_default_config(AMP, "use_fp16_guard", True)
set_field_default_config(AMP, "use_optimizer_fp16", False)
#########################################
# sharding configuration
#########################################
SHARDING = "sharding"
set_field_default_config(SHARDING, "enable", False)
set_field_default_config(SHARDING, "stage", 1)
set_field_default_config(SHARDING, "degree", 8)
set_field_default_config(SHARDING, "segment_broadcast_MB", 32.0)
set_field_default_config(SHARDING, "enable_tuning", False)
set_field_default_config(SHARDING, "tuning_range", [])
#########################################
# gradient merge configuration
#########################################
GRADIENT_MERGE = "gradient_merge"
set_field_default_config(GRADIENT_MERGE, "enable", False)
set_field_default_config(GRADIENT_MERGE, "k_steps", 1)
set_field_default_config(GRADIENT_MERGE, "avg", True)
#########################################
# quantization configuration
#########################################
QAT = "qat"
set_field_default_config(QAT, "enable", False)
set_field_default_config(QAT, "channel_wise_abs_max", True)
set_field_default_config(QAT, "weight_bits", 8)
set_field_default_config(QAT, "activation_bits", 8)
set_field_default_config(QAT, "not_quant_pattern", ['skip_quant'])
set_field_default_config(QAT, "algo", None)
#########################################
# auto tuning configuration
#########################################
TUNING = "tuning"
set_field_default_config(TUNING, "enable", False)
set_field_default_config(TUNING, "batch_size", 1)
set_field_default_config(TUNING, "dataset", None)
set_field_default_config(TUNING, "profile_start_step", 1)
set_field_default_config(TUNING, "profile_end_step", 1)
set_field_default_config(TUNING, "run_after_tuning", True)
set_field_default_config(TUNING, "verbose", True)
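These defaults back the new `Strategy` object referenced in the commit message. A hedged sketch of how a user might override them, assuming `Strategy` exposes each category as an attribute group carrying the fields registered above (the exact attribute API lives in strategy.py, which this commit adds but this hunk does not show):

    import paddle.distributed.auto_parallel as auto

    strategy = auto.Strategy()
    strategy.amp.enable = True              # flips the AMP "enable" default (False)
    strategy.gradient_merge.enable = True
    strategy.gradient_merge.k_steps = 4     # micro-batches per optimizer step
    strategy.recompute.enable = True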
@@ -16,7 +16,7 @@
 import paddle
 import warnings
 import logging
 import numpy as np
-from ..utils import get_logger
+from .utils import get_logger

 class Converter(object):
...
@@ -173,6 +173,17 @@ class TensorDistributedAttribute:
     def clear_annotated(self):
         self._is_annotated.clear()
def __eq__(self, other):
if not isinstance(other, TensorDistributedAttribute):
return False
if self.process_mesh != other.process_mesh:
return False
if self.dims_mapping != other.dims_mapping:
return False
if self._is_annotated != other._is_annotated:
return False
return True
     def __str__(self):
         str = "\n\ttensor_dist_attr = {"
         if self.is_annotated("process_mesh"):
@@ -486,6 +497,27 @@ class OperatorDistributedAttribute:
         else:
             return False
def __eq__(self, other):
if not isinstance(other, OperatorDistributedAttribute):
return False
if self.process_mesh != other.process_mesh:
return False
if self.op_type != other.op_type:
return False
if self.impl_type != other.impl_type:
return False
if self.impl_idx != other.impl_idx:
return False
if self._is_annotated != other._is_annotated:
return False
if self._is_recompute != other._is_recompute:
return False
if self.inputs_dist_attrs != other.inputs_dist_attrs:
return False
if self.outputs_dist_attrs != other.outputs_dist_attrs:
return False
return True
     def __str__(self):
         str = "\n\top_dist_attr = {"
         if self.is_annotated("process_mesh"):
...
@@ -126,9 +126,6 @@ class DistributedContext:
         # A flag indicates whether the used parallelism is data parallel
         self._data_parallel = False
-        # flag whether using `to_static`
-        self._dygraph_mode = False
     @property
     def serial_main_program(self):
         return self._serial_main_program
...
@@ -23,6 +23,7 @@ from .dist_attribute import append_op_input_suffix
 from .dist_attribute import append_op_output_suffix
 from .dist_attribute import get_tensor_dist_attr_field_keys
 from .dist_attribute import get_op_dist_attr_field_keys
from .utils import convert_to_shard_spec, verify_shard_spec
 class DistributedOperator:
@@ -248,23 +249,106 @@ class DistributedOperator:
         return result
-class DistributedModule:
+class DistributedOperatorHelper:

-    def __init__(self, serial_module, dist_attr=None):
-        self._serial_module = serial_module
-        self._dist_attr = dist_attr
+    def __init__(self, serial_op, process_mesh, in_dims_mappings,
+                 out_dims_mappings):
+        self._serial_op = serial_op
+        self._process_mesh = process_mesh
+        self._in_dims_mappings = in_dims_mappings
+        self._out_dims_mappings = out_dims_mappings
     def __call__(self, *args, **kwargs):
-        from .dist_context import get_default_distributed_context
+        tensor_to_dims_mapping = {}
+        index = 0
+        if self._in_dims_mappings:
+            assert len(args) + len(kwargs) == len(self._in_dims_mappings), \
+                "The length of dims_mapping {} does not match the number of inputs {}.".format(
+                    len(self._in_dims_mappings), len(args) + len(kwargs))
+        for arg in args:
+            if isinstance(arg, Variable) and self._in_dims_mappings:
+                tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index]
+            index += 1
+        # Note: the committed code iterated over
+        # `kwargs.values() and self._in_dims_mappings`, which walks the dims
+        # mappings instead of the keyword arguments; the intended loop is
+        # over kwargs.values(), guarded by the mappings:
+        if self._in_dims_mappings:
+            for arg in kwargs.values():
+                if isinstance(arg, Variable):
+                    tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index]
+                index += 1
         default_prog = paddle.fluid.default_main_program()
         cur_block = default_prog.current_block()
         op_size = len(cur_block.ops)
-        output = self._serial_module(*args, **kwargs)
+        output = self._serial_op(*args, **kwargs)
         new_op_size = len(cur_block.ops)
if isinstance(output, tuple) or isinstance(output, list):
new_output = list(output)
elif isinstance(output, Variable):
new_output = [output]
else:
raise ValueError("Unrecognized output.")
if self._out_dims_mappings:
assert len(new_output) == len(self._out_dims_mappings), \
"The length of dims_mapping {} does not matching the length output {}.".format(len(self._out_dims_mappings), len(new_output))
for i, item in enumerate(new_output):
if isinstance(item, Variable) and self._out_dims_mappings:
tensor_to_dims_mapping[item.name] = self._out_dims_mappings[i]
from .dist_context import get_default_distributed_context
         default_dist_ctx = get_default_distributed_context()
         for idx in range(op_size, new_op_size):
             op = cur_block.ops[idx]
-            dist_op = DistributedOperator(op, self._dist_attr)
-            dist_op.dist_attr.mark_annotated_as(self._dist_attr)
+            dist_op = DistributedOperator(op)
+            for name in dist_op.serial_op.input_arg_names:
if name in tensor_to_dims_mapping.keys():
tensor = dist_op.get_serial_input(name)
tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(
name)
dims_mapping = tensor_to_dims_mapping[name]
if tensor is None:
tensor_shape = []
else:
if tensor.type == core.VarDesc.VarType.READER \
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = []
else:
tensor_shape = tensor.shape
if dims_mapping is not None:
dims_mapping = tensor_to_dims_mapping[name]
shard_spec = convert_to_shard_spec(
dims_mapping, self._process_mesh)
assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
name, shard_spec, tensor_shape, self._process_mesh)
tensor_dist_attr.dims_mapping = dims_mapping
tensor_dist_attr.mark_annotated("dims_mapping")
for name in dist_op.serial_op.output_arg_names:
if name in tensor_to_dims_mapping.keys():
tensor = dist_op.get_serial_output(name)
tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(
name)
dims_mapping = tensor_to_dims_mapping[name]
if tensor is None:
tensor_shape = []
else:
if tensor.type == core.VarDesc.VarType.READER \
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = []
else:
tensor_shape = tensor.shape
if dims_mapping is not None:
dims_mapping = tensor_to_dims_mapping[name]
shard_spec = convert_to_shard_spec(
dims_mapping, self._process_mesh)
assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
name, shard_spec, tensor_shape, self._process_mesh)
tensor_dist_attr.dims_mapping = dims_mapping
tensor_dist_attr.mark_annotated("dims_mapping")
dist_op.dist_attr.process_mesh = self._process_mesh
if self._process_mesh is not None:
dist_op.dist_attr.mark_annotated("process_mesh")
             default_dist_ctx.add_dist_op_for_program(dist_op)
         return output
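This helper is what the redesigned `shard_op` is expected to return: calling it runs the wrapped callable, then annotates every op it appended to the program with the given process mesh and dims mappings. A hedged usage sketch; the `shard_op` argument names and the `ProcessMesh(mesh, dim_names=...)` form follow the redesign described in the commit message and should be treated as assumptions:

    import paddle
    import paddle.distributed.auto_parallel as auto

    mesh = auto.ProcessMesh([0, 1], dim_names=["x"])   # assumed constructor
    a = paddle.static.data(name="a", shape=[4, 8])
    b = paddle.static.data(name="b", shape=[8, 4])
    # shard_op wraps the callable and returns a DistributedOperatorHelper;
    # in_shard_specs is an assumed parameter name for the input shardings.
    dist_matmul = auto.shard_op(paddle.matmul, mesh,
                                in_shard_specs=[["x", None], [None, None]])
    out = dist_matmul(a, b)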
@@ -27,7 +27,7 @@ from paddle.fluid.framework import static_only
 from .utils import get_dist_attr
 from .converter import Converter
 from .process_group import _g_process_group_map
-from ..utils import get_logger
+from .utils import get_logger
 def check_filename(re_exp, filename):
@@ -59,6 +59,14 @@ class DistributedSaver:
     def save(self, path, serial_program, dist_main_program, dist_context):
def _save_state(program, path, mode="param"):
state = {
k: np.array(v)
for k, v in program.state_dict(mode).items()
}
with open(path, "wb") as f:
pickle.dump(state, f)
         dirname, filename = _process_path(path)
         rank_id = paddle.distributed.get_rank()
@@ -76,16 +84,6 @@
         with open(dist_model_path, "wb") as f:
             f.write(dist_main_program.desc.serialize_to_string())
-        # save distributed params
-        dist_param_filename = filename + "_dist" + str(rank_id) + ".pdparams"
-        dist_param_path = os.path.join(dirname, dist_param_filename)
-        dist_param = {
-            k: np.array(v)
-            for k, v in dist_main_program.state_dict().items()
-        }
-        with open(dist_param_path, "wb") as f:
-            pickle.dump(dist_param, f)
         # save distributed attribute
         dist_attr_filename = filename + "_dist" + str(rank_id) + ".pdattr"
         dist_attr_path = os.path.join(dirname, dist_attr_filename)
@@ -93,65 +91,69 @@
         with open(dist_attr_path, "wb") as f:
             pickle.dump(dist_attrs, f)
# save distributed params
dist_param_filename = filename + "_dist" + str(rank_id) + ".pdparams"
dist_param_path = os.path.join(dirname, dist_param_filename)
_save_state(dist_main_program, dist_param_path)
# save distributed opt states
dist_opt_filename = filename + "_dist" + str(rank_id) + ".pdopt"
dist_opt_path = os.path.join(dirname, dist_opt_filename)
_save_state(dist_main_program, dist_opt_path, "opt")
         # TODO: save cluster.json

-    def load(self,
-             path,
-             program,
-             dist_context,
-             strict=True,
-             load_optimizer=True):
+    def load(self, path, load_optimizer=True):
         # TODO: if `program` is None, load `path.pdmodel`.
def _load_file(filename, dirname, suffix="pdparams"):
file_list = []
for file in os.listdir(dirname):
if check_filename('{}(.*)_dist(.*).{}'.format(filename, suffix),
file):
file_list.append(os.path.join(dirname, file))
file_list.sort()
return file_list
def _load_state(filename, dirname, suffix="pdparams"):
file_list = _load_file(filename, dirname, suffix)
state_dict = {}
for file in file_list:
with open(file, 'rb') as f:
state_dict_info = pickle.load(f, encoding='latin1')
for name, value in state_dict_info.items():
if name in state_dict:
state_dict[name].append(np.array(value))
else:
state_dict[name] = [np.array(value)]
self._logger.info("Load param file: {}".format(file_list))
return state_dict
         filename = os.path.basename(path)
         if filename == "":
             raise ValueError(
                 "path should be of 'dirname/filename' format, but received filename is empty string"
             )
         dirname = os.path.dirname(path)
-        # load path.pdparam
-        param_file_list = []
-        for param_file in os.listdir(dirname):
-            if check_filename('{}(.*)_dist(.*).pdparams'.format(filename),
-                              param_file):
-                param_file_list.append(os.path.join(dirname, param_file))
-        param_file_list.sort()
-        self._logger.info(
-            "Load distributed attribute file: {}".format(param_file_list))
-        param_dict = {}
-        for param_file in param_file_list:
-            with open(param_file, 'rb') as f:
-                state_dict_info = pickle.load(f, encoding='latin1')
-            for name, value in state_dict_info.items():
-                if name in param_dict:
-                    param_dict[name].append(np.array(value))
-                else:
-                    param_dict[name] = [np.array(value)]
+        # load path.pdparam and path.pdopt
+        param_state_dict = _load_state(filename, dirname)
+        opt_state_dict = _load_state(filename, dirname,
+                                     "pdopt") if load_optimizer else {}
+        state_dict = dict(param_state_dict, **opt_state_dict)
         # load path.pdattr
-        dist_attr_file_list = []
-        for dist_attr_file in os.listdir(dirname):
-            if check_filename('{}(.*)_dist(.*).pdattr'.format(filename),
-                              dist_attr_file):
-                dist_attr_file_list.append(os.path.join(dirname,
-                                                        dist_attr_file))
-        dist_attr_file_list.sort()
+        dist_attr_file_list = _load_file(filename, dirname, "pdattr")
         self._logger.info(
             "Load distributed attribute file: {}".format(dist_attr_file_list))
-        pre_dist_attr = {}
+        dist_attr = {}
         for dist_attr_file in dist_attr_file_list:
             with open(dist_attr_file, 'rb') as f:
-                dist_attr = pickle.load(f, encoding='latin1')
-            for name, attr in dist_attr.items():
-                if name not in pre_dist_attr:
-                    pre_dist_attr[name] = attr
+                dist_attr_info = pickle.load(f, encoding='latin1')
+            for name, attr in dist_attr_info.items():
+                if name not in dist_attr:
+                    dist_attr[name] = attr
-        # get current dist_attr
-        cur_dist_attr = get_dist_attr(program, dist_context)
-        # param convert
-        converter = Converter(param_dict, pre_dist_attr, cur_dist_attr)
-        param_dict = converter.convert(strict=strict)
-        program.set_state_dict(param_dict)
+        return state_dict, dist_attr
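Note the contract change: `load` no longer converts parameters and writes them into a program; it returns the merged state dict plus the saved dist attrs, and the conversion moves to the caller (the Engine's `_set_state_dict`, seen in `_initialize` below). A hedged sketch of the resulting round trip, reconstructing the converter step from the code removed above; `program` and `dist_context` are placeholders:

    # Sketch only: save on each rank, then load and convert at restore time.
    saver = DistributedSaver()
    saver.save("./ckpt/model", serial_program, dist_main_program, dist_context)

    state_dict, pre_dist_attr = saver.load("./ckpt/model", load_optimizer=True)
    cur_dist_attr = get_dist_attr(program, dist_context)   # current layout
    converter = Converter(state_dict, pre_dist_attr, cur_dist_attr)
    program.set_state_dict(converter.convert(strict=True))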
     def save_inference_model(self, path, feed_vars, fetch_vars, exe, **kwargs):
...
@@ -12,80 +12,169 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import time
 import copy
 import logging
+import random
+import numpy as np
 from collections import defaultdict
 import paddle
 import paddle.utils as utils
 from paddle import fluid, static
-from paddle.io import Dataset
 from paddle.jit import to_static
 from paddle.metric import Metric
 from paddle.static import InputSpec
 from paddle.fluid import core
-from paddle.fluid import program_guard
+from paddle.fluid import Variable
 from paddle.fluid.layers.utils import flatten
 from paddle.fluid.executor import global_scope, _to_name_str
-from paddle.fluid.backward import append_backward
 from paddle.fluid.framework import Operator, Parameter, _non_static_mode
 from paddle.fluid.framework import _current_expected_place as _get_device
 from paddle.fluid.dygraph.parallel import ParallelEnv
 from paddle.distributed import fleet
-from paddle.distributed.passes import new_pass, PassContext
-from .converter import Converter
 from .helper import ProgramHelper
-from ..collective import _get_global_env
 from .cluster import Cluster, get_default_cluster
 from .planner_v2 import Planner
 from .parallelizer_v2 import Parallelizer
 from .dist_op import DistributedOperator
 from .dist_saver import DistributedSaver
 from .dist_loader import NonIterableGeneratorLoader
-from .utils import make_data_unshard, set_grad_var_shape
 from .utils import print_program_with_dist_attr, to_list
-from .process_group import new_process_group, get_all_process_groups, get_world_process_group
+from .utils import get_logger, get_dist_attr
+from .process_group import new_process_group, get_all_process_groups
 from .dist_context import DistributedContext, get_default_distributed_context
+from .strategy import Strategy
+from .interface import _get_fetches
 class Engine:
"""
    An Engine object provides the full power of auto parallel to users.
    With it, users can easily obtain the abilities of distributed training
    and inference. It also supports both the dynamic graph and the static
    graph at the same time.
Args:
model (paddle.nn.Layer, optional): The model is an instance of
paddle.nn.Layer.
loss (Loss|Callable|None, optional): The loss can be a `paddle.nn.Layer`
            instance or any callable function taking the predicted values and
ground truth values as input. It can be None when there is no loss.
Default: None.
        optimizer (Optimizer|None, optional): The optimizer needs to be set in training
and should be None in eval and predict mode. Default: None.
metrics (Metric|list[Metric]|None, optional): If metrics is set, all
metrics will be calculated and output in train/eval mode. Default: None.
cluster (Cluster|None, optional): The cluster represents the topology information
about the used physical devices. Default: None. (Unused for now)
strategy (Strategy|None, optional): The strategy is used to configure the
parallelization and optimization behaviors. Default: None.
Examples:
.. code-block:: python
import paddle
import paddle.vision.transforms as T
import paddle.distributed.auto_parallel as auto
from paddle.vision.datasets import MNIST
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = MNIST(mode='train', transform=transform)
valid_dataset = MNIST(mode='test', transform=transform)
model = paddle.vision.models.LeNet()
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
metrics = paddle.metric.Accuracy(topk=(1, 2))
engine = auto.Engine(model, loss, optimizer, metrics)
# fit
engine.fit(train_dataset,
epochs=2,
batch_size=64)
# evaluate
engine.evaluate(valid_dataset,
batch_size=64)
# predict
engine.predict(valid_dataset,
batch_size=64)
# save
engine.save("./my_model")
# load
engine.load("./my_model")
"""
     def __init__(self,
                  model=None,
-                 inputs_spec=None,
-                 labels_spec=None,
+                 loss=None,
+                 optimizer=None,
+                 metrics=None,
                  cluster=None,
-                 strategy=None,
-                 user_tuning_config=None):
-        self.model = model
-        self.inputs_spec = self._validate_spec(inputs_spec)
-        self.labels_spec = self._validate_spec(labels_spec)
-        self.cluster = cluster
-        if self.cluster is None:
-            self.cluster = get_default_cluster()
-        self.strategy = strategy
-        if self.strategy is None:
-            self.strategy = fleet.DistributedStrategy()
-        self._user_tuning_config = user_tuning_config
+                 strategy=None):
+
+        if model and not isinstance(model,
+                                    paddle.nn.Layer) and not callable(model):
+            raise TypeError(
+                "'model' must be a subclass of `paddle.nn.Layer` or a callable function."
+            )
+        self._model = model
+
+        if loss and not isinstance(loss,
+                                   paddle.nn.Layer) and not callable(loss):
+            raise TypeError(
+                "'loss' must be a subclass of `paddle.nn.Layer` or a callable function."
+            )
+        self._loss = loss
+
+        if optimizer and not isinstance(
+                optimizer,
+                (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)):
+            raise TypeError(
+                "'optimizer' must be an object of class `paddle.optimizer.Optimizer`"
+                " or `paddle.fluid.optimizer.Optimizer`.")
+        self._optimizer = self._validate_opt(optimizer)
+
+        metrics = metrics or []
+        for metric in to_list(metrics):
+            assert isinstance(metric, Metric), \
+                "{} is not a subclass of Metric".format(
+                    metric.__class__.__name__)
+        self._metrics = to_list(metrics)
+
+        if cluster and not isinstance(cluster, Cluster):
+            raise TypeError(
+                "'cluster' must be an object of class `paddle.distributed.auto_parallel.Cluster`"
+            )
+        self._cluster = cluster or get_default_cluster()
+
+        if strategy and not isinstance(strategy, Strategy):
+            raise TypeError(
+                "'strategy' must be an object of class `paddle.distributed.auto_parallel.Strategy`"
+            )
+        self._strategy = strategy or Strategy()
+
+        if os.getenv("POD_NAME"):
+            print("Distributed training launched by paddle.distributed.launch",
+                  flush=True)
+            fleet.init(is_collective=True)
         self._executor = None
         self._cur_rank = paddle.distributed.get_rank()
         self._nranks = paddle.distributed.get_world_size()
         self._saver = DistributedSaver()

-        # TODO: add logger module
-        self._logger = logging.getLogger()
-        self._logger.propagate = False
-        if not self._logger.handlers:
-            self._logger.setLevel(logging.INFO)
-            log_handler = logging.StreamHandler()
-            log_format = logging.Formatter(
-                '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
-            )
-            log_handler.setFormatter(log_format)
-            self._logger.addHandler(log_handler)
+        self._logger = get_logger(logging.INFO)

         self._orig_main_prog = static.default_main_program()
         self._orig_startup_prog = static.default_startup_program()
@@ -103,54 +192,18 @@ class Engine:
             "eval": False,
             "predict": False
         }
-        self._dygraph_mode = False
-    def prepare(self,
-                optimizer=None,
-                loss=None,
-                gradient_scale=True,
-                metrics=None,
-                all_ranks=False):
-        if optimizer and not isinstance(
-                optimizer,
-                (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)):
-            raise TypeError(
-                "'optimizer' must be object of class `paddle.optimizer.Optimizer`" \
-                " or `paddle.fluid.optimizer.Optimizer`."
-            )
-        self._optimizer = self._validate_opt(optimizer)
-
-        if loss and not isinstance(loss,
-                                   paddle.nn.Layer) and not callable(loss):
-            raise TypeError(
-                "'loss' must be sub classes of `paddle.nn.Layer` or any callable function."
-            )
-        self._loss = loss
-
-        metrics = metrics or []
-        for metric in to_list(metrics):
-            assert isinstance(metric, Metric), \
-                "{} is not sub class of Metric".format(
-                    metric.__class__.__name__)
-        self._metrics = to_list(metrics)
-        self._gradient_scale = gradient_scale
         self._planned_mode = None
-        self._all_ranks = all_ranks
-        self._prepare_single_mode("train")
+        self._dygraph_mode = False
+        self._tuning = self._strategy.tuning
     def _prepare_single_mode(self, mode):
-        # Do the build process
         self._build(mode)
         # Do the planning process
         self._plan(mode)
-        # Do the Optimization tuning
-        if self._user_tuning_config and mode == "train":
-            self._optimization_tuning(mode)
         # Do the parallel process
-        self._parallel(mode, self._all_ranks)
+        self._parallel(mode)
         # Init comm and startup program
         self._initialize(mode)
         self._mode_init_states[mode] = True
@@ -163,7 +216,7 @@ class Engine:
         inputs_spec = self.inputs_spec
         labels_spec = self.labels_spec if self.labels_spec else []
-        self.program_helper = ProgramHelper(self.model, self._loss,
+        self.program_helper = ProgramHelper(self._model, self._loss,
                                             self._metrics, inputs_spec,
                                             labels_spec)
         # build forward main program
@@ -190,14 +243,13 @@
             metrics = []
             serial_main_prog = self._orig_main_prog.clone()
             serial_startup_prog = self._orig_startup_prog.clone()
-            # FIXME to support grad clip
             with static.program_guard(serial_main_prog, serial_startup_prog), \
                 utils.unique_name.guard():
                 inputs_spec = self.inputs_spec
                 labels_spec = self.labels_spec if self.labels_spec else []
                 inputs = [s._create_feed_layer() for s in inputs_spec]
                 labels = [s._create_feed_layer() for s in labels_spec]
-                outputs = to_list(self.model(*inputs))
+                outputs = to_list(self._model(*inputs))
if mode != "predict" and self._loss: if mode != "predict" and self._loss:
losses = to_list(self._loss(*(outputs + labels))) losses = to_list(self._loss(*(outputs + labels)))
...@@ -221,25 +273,30 @@ class Engine: ...@@ -221,25 +273,30 @@ class Engine:
"metrics": metrics "metrics": metrics
} }
if mode != "train":
serial_main_prog = serial_main_prog.clone(for_test=True)
self._set_recompute_ckpts() self._set_recompute_ckpts()
self._dist_contexts[mode] = DistributedContext( self._dist_contexts[mode] = DistributedContext(
serial_main_prog, serial_startup_prog, self._optimizer, losses, serial_main_prog, serial_startup_prog, self._optimizer, losses,
feed_vars, fetch_vars, self.cluster, self.strategy) feed_vars, fetch_vars, self._cluster, self._strategy)
self._dist_contexts[mode].gradient_scale = self._gradient_scale self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale
self._dist_contexts[mode]._dygraph_mode = self._dygraph_mode
-    def _optimization_tuning(self, mode):
-        self.mode = mode
-        assert "batch_size" in self._user_tuning_config, "Optimization Tuning should provide with batch size."
-        assert "dataset" in self._user_tuning_config, "Optimization Tuning should provide with dataset."
-        batch_size = self._user_tuning_config["batch_size"]
-        dataset = self._user_tuning_config["dataset"]
-        dataset.dp_world_size = self.dp_world_sizes
-        dataset.dp_rank = self.dp_ranks
+    def _optimization_tuning(self, mode, dataset, batch_size):
+        if not self._tuning.enable:
+            raise ValueError("Please set `tuning.enable=True`.")
+        assert mode == "train"
+        # Do the build process
+        self._build(mode)
+        # Do the planning process
+        self._plan(mode)
+        dataset.dp_world_size = self._dp_world_sizes
+        dataset.dp_rank = self._dp_ranks
         from .tuner.optimization_tuner import OptimizationTuner
-        self._optimization_tuner = OptimizationTuner(self._user_tuning_config,
+        self._optimization_tuner = OptimizationTuner(self._tuning.to_dict(),
                                                      self._dist_contexts[mode],
                                                      dataset,
                                                      self.inputs_spec,
@@ -249,12 +306,10 @@
         self._optimization_tuner.tune()

-        if self._user_tuning_config["run_after_tuning"]:
+        if self._tuning.run_after_tuning:
             # update the strategy
             self._dist_contexts[
                 mode]._strategy = self._optimization_tuner.get_best_config()
-        else:
-            return
     def _plan(self, mode):
         if self._planned_mode is None:
@@ -274,15 +329,15 @@
             if var.name in block.vars:
                 feed_list.append(block.vars[var.name])

-        self.dp_world_sizes = []
-        self.dp_ranks = []
+        self._dp_world_sizes = []
+        self._dp_ranks = []
         for feed_var in feed_list:
             dp_world_size, dp_rank = self._get_input_split_info(
                 feed_var, self._dist_contexts[mode])
-            self.dp_world_sizes.append(dp_world_size)
-            self.dp_ranks.append(dp_rank)
+            self._dp_world_sizes.append(dp_world_size)
+            self._dp_ranks.append(dp_rank)
-    def _parallel(self, mode, all_ranks):
+    def _parallel(self, mode, all_ranks=False):
         # Parallelize program based on the planner's results
         # For now, the completer has to be passed to the planner,
         # because we may use it to complete the annotation of the backward and update.
@@ -340,6 +395,11 @@
         if isinstance(place, fluid.CUDAPlace):
             place = fluid.CUDAPlace(ParallelEnv().dev_id)
if self._strategy.seed:
paddle.seed(self._strategy.seed + self._dp_ranks[0])
np.random.seed(self._strategy.seed + self._dp_ranks[0])
random.seed(self._strategy.seed + self._dp_ranks[0])
         if self._dygraph_mode:
             dist_context = self._dist_contexts[mode]
             dist_main_program = self._dist_main_progs[mode][self._cur_rank]
@@ -358,136 +418,303 @@
             prune_startup_prog = dist_startup_prog._prune(uninitialized)
             self._executor.run(prune_startup_prog)
-        if self.strategy.amp and self.strategy.amp_configs['use_pure_fp16']:
-            # from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16
-            def cast_parameters_to_fp16(place,
-                                        program,
-                                        scope=None,
-                                        to_fp16_var_names=None):
-                """
-                Traverse all parameters in the whole model and set them to the FP16 data type.
-                Whereas, this function will keep parameters of batchnorms in FP32.
-                Args:
-                    place(fluid.CPUPlace|fluid.CUDAPlace): `place` is used to restore the FP16 weight tensors.
-                    program (Program): The used program.
-                    scope(fluid.Scope, optional): `scope` is used to get the FP32 weight tensor values.
-                        Default is None.
-                    to_fp16_var_names(set|list, optional): The data types of vars in `to_fp16_var_names`
-                        will be set to FP16. Usually, it is the returned
-                        value of `cast_model_to_fp16` API.
-                """
-                from paddle.framework import core
-                import numpy as np
-                all_parameters = []
-                for block in program.blocks:
-                    all_parameters.extend(block.all_parameters())
-
-                var_scope = scope if scope else paddle.static.global_scope()
-                for param in all_parameters:
-                    if param.dtype == core.VarDesc.VarType.FP16:
-                        param_t = var_scope.find_var(
-                            param.name).get_tensor()
-                        data = np.array(param_t)
-                        param_t.set(np.float16(data), place)
-
-            cast_parameters_to_fp16(place, prune_startup_prog)
+        if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"):
+            self._set_state_dict(mode, self._strict, self._state_dict,
+                                 self._dist_attr)
+
+        if self._strategy.reinit:
+            self._logger.info("NOTE: parameters will be re-initialized.")
+            dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
+            self._executor.run(dist_startup_prog)
+
+    def _infer_sample_spec(self, data, batch_size, split):
+        if isinstance(data, paddle.io.IterableDataset):
+            if split is None:
+                input, label = next(iter(data))
+            else:
+                sample = next(iter(data))
+                input = sample[:split]
+                label = sample[split:]
+        elif isinstance(data, paddle.io.Dataset):
+            if split is None:
+                input, label = data[0]
+            else:
+                sample = data[0]
+                input = sample[:split]
+                label = sample[split:]
+        else:
+            raise ValueError(
+                "Data should be a Dataset or IterableDataset, but received {}.".
+                format(type(data).__name__))
+
+        self.inputs_spec = []
+        self.labels_spec = []
+        input_list = to_list(input)
+        label_list = to_list(label)
def _infer_item_spec(item, name, batch_size, specs):
if isinstance(item, np.ndarray):
spec = InputSpec.from_numpy(item, name)
if batch_size is None:
specs.append(spec)
else:
specs.append(spec.batch(batch_size))
elif isinstance(item, (Variable, core.VarBase, core.eager.Tensor)):
spec = InputSpec.from_tensor(item, name)
if batch_size is None:
specs.append(spec)
else:
specs.append(spec.batch(batch_size))
else:
specs.append(InputSpec([batch_size], type(item), name))
if input_list is not None:
for i, item in enumerate(input_list):
assert item is not None, "Received a None input."
name = "input" + str(i)
_infer_item_spec(item, name, batch_size, self.inputs_spec)
if label_list is not None:
for i, item in enumerate(label_list):
assert item is not None, "Received a None input."
name = "label" + str(i)
_infer_item_spec(item, name, batch_size, self.labels_spec)
self.inputs_spec = self._validate_spec(self.inputs_spec)
self.labels_spec = self._validate_spec(self.labels_spec)
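A concrete example of the `split` semantics above: when a dataset sample has more than two items, the split index separates inputs from labels. The three-item sample below is purely illustrative:

    # Hypothetical 3-item sample; split=2 puts the first two items in
    # inputs and the remainder in labels.
    sample = (tokens, position_ids, labels)
    split = 2
    inputs = sample[:split]     # (tokens, position_ids)
    targets = sample[split:]    # (labels,)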
     def fit(self,
             train_data,
+            train_sample_split=None,
             batch_size=1,
             epochs=1,
-            fetches=None,
             steps_per_epoch=None,
+            valid_data=None,
+            valid_sample_split=None,
+            valid_freq=1,
+            valid_steps=None,
             collate_fn=None,
-            use_cache=False,
-            return_numpy=True):
-        # TODO: callbacks
-        # TODO: evaluate after training
-        if not self._mode_init_states['train']:
-            raise Exception(
-                "train program is not initialized yet, please call engine.prepare() before calling fit() funtion."
-            )
+            callbacks=None):
+        """
+        Trains the model for a fixed number of epochs. If `valid_data` is set,
+        evaluation will be done at the end of each epoch.
Args:
            train_data (Dataset): An instance of paddle.io.Dataset. Default: None.
train_sample_split (int, optional): Each sample of the train dataset is assumed
to be a (input, label) pair by default and has two items. If each sample has
more than two items, train_sample_split specifies how to split these items into
                input and label. The items before it are input and the rest are label. Default: None.
batch_size (int, optional): The batch size of train_data and valid_data if provided.
The user's data will be used directly without batching if set to None. Default: 1.
epochs (int, optional): The number of epochs to train the model. Default: 1.
steps_per_epoch (int, optional): The total number of steps (batches of samples)
                executed in one epoch before starting the next one. If None, it is equal to
                the number of samples in your dataset divided by the batch size. Default: None.
            valid_data (Dataset, optional): An instance of paddle.io.Dataset used for
evaluation at the end of epoch. No evaluation will be done if set to None.
Default: None. (Unsupported for now)
valid_freq (int, optional): Only relevant if valid_data is provided. This specifies
how many training epochs before a new evaluation is performed. Default: 1.
valid_sample_split (int, optional): Only relevant if valid_data is provided.
Each sample of the valid dataset is assumed to be a (input, label) pair
by default and has two items. If each sample has more than two items,
valid_sample_split specifies how to split these items into input and label.
                The items before it are input and the rest are label. Default: None.
valid_steps (int, optional): Only relevant if valid_data is provided.
It is the total number of steps (batches of samples) to draw before
stopping validation at the end of every epoch. If None, validation will run until the
`valid_data` dataset is exhausted. The validation will start from the
beginning of the dataset at each epoch. Default: None.
collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list. None to only stack each field of the samples in axis
                0. Default: None.
callbacks (Callback|None, optional): A list of `Callback` instances to apply
during training. Default: None. (Unused for now)
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.vision.transforms as T
import paddle.distributed.auto_parallel as auto
from paddle.vision.datasets import MNIST
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = MNIST(mode='train', transform=transform)
model = paddle.vision.models.LeNet()
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
metrics = paddle.metric.Accuracy(topk=(1, 2))
engine = auto.Engine(model, loss, optimizer, metrics)
engine.fit(train_dataset,
epochs=2,
batch_size=64)
"""
         self.mode = 'train'
+        self._infer_sample_spec(train_data, batch_size, train_sample_split)
+        if not self._mode_init_states[self.mode]:
+            self._prepare_single_mode(self.mode)
+        else:
+            self._switch_mode("train")

         assert self.mode in self._dist_main_progs, \
-            "train model is not ready, please call `engine.prepare()` first."
+            "train model is not ready, please call `engine._prepare_single_mode('train')` first."
         train_dataloader = self._create_dataloader(train_data, batch_size,
                                                    epochs, steps_per_epoch,
                                                    collate_fn)

-        usr_fetch = self._validate_fetches(fetches)
         fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
-        fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch)
-        lr_scheduler = self.get_lr_scheduler(self.main_program)
+        fetch_metrics = self._validate_fetches(self.fetch_vars["metrics"])
+        inner_fetch = dict(fetch_loss, **fetch_metrics)
+        usr_fetch = self._validate_fetches(_get_fetches())
+        fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch)
+        lr_scheduler = self._get_lr_scheduler(self.main_program)
+        outputs = defaultdict(list)
         for epoch in range(epochs):
             train_logs = {"epoch: {:d} ": epoch}
             for step, _ in enumerate(train_dataloader):
                 try:
-                    outs = self._executor.run(self.main_program,
-                                              fetch_list=fetch_list,
-                                              use_program_cache=use_cache,
-                                              return_numpy=return_numpy)
-                except fluid.core.EOFException:
+                    outs = self._executor.run(
+                        self.main_program,
+                        fetch_list=fetch_list,
+                        use_program_cache=self._strategy.use_cache,
+                        return_numpy=self._strategy.return_numpy)
+                except core.EOFException:
                     break
                 train_logs["step: {:d} "] = step
-                if lr_scheduler is not None:
+                # update lr
+                if lr_scheduler and step % self._k_steps == 0:
                     lr_scheduler.step()
-                try:
-                    train_logs["lr: {:5e} "] = self._lr_optimizer.get_lr()
-                except:
-                    train_logs[
-                        "lr: {:5e} "] = self._lr_optimizer._learning_rate.get_lr(
-                        )
+                train_logs["lr: {:5e} "] = self._get_lr(self._lr_optimizer)
                 # inner fetches
                 if fetch_loss:
-                    train_logs["loss: {:9f} "] = outs[0][0]
+                    train_logs["loss: {:8f} "] = outs[0][0]
outputs["loss"].append(outs[0][0])
# Metric
if fetch_metrics:
metric_out = outs[len(fetch_loss):len(inner_fetch)]
for metric in self._metrics:
metric.update(*metric_out)
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
train_logs[metric.name()[i] + ": {:8f} "] = res
outputs[metric.name()[i]].append(outs[0][0])
                # user fetches
-                user_outs = outs[len(fetch_loss):]
-                user_fetch_list = fetch_list[len(fetch_loss):]
+                user_outs = outs[len(inner_fetch):]
+                user_fetch_list = fetch_list[len(inner_fetch):]
                 for i, out in enumerate(user_outs):
                     train_logs[fetch_map[user_fetch_list[i]] + ": {}"] = out
                 # logger
                 string = '[train] ' + ''.join(list(train_logs.keys()))
                 self._logger.info(string.format(*list(train_logs.values())))
if valid_data and epoch % valid_freq == 0:
self.evaluate(valid_data, valid_sample_split, batch_size,
valid_steps, collate_fn, callbacks)
self._switch_mode("train")
else:
self._reset_metrics()
return outputs
     def evaluate(self,
-                 eval_data,
+                 valid_data,
+                 valid_sample_split=None,
                  batch_size=1,
-                 fetches=None,
+                 steps=None,
                  collate_fn=None,
-                 use_cache=False,
-                 return_numpy=True):
+                 callbacks=None):
+        """
Evaluate the loss and metrics of the model on evaluation data.
Args:
            valid_data (Dataset): An instance of paddle.io.Dataset. Default: None.
valid_sample_split (int, optional): Each sample of the eval dataset is assumed
to be a (input, label) pair by default and has two items. If each sample has
more than two items, valid_sample_split specifies how to split these items into
                input and label. The items before it are input and the rest are label. Default: None.
batch_size (int, optional): The batch size of valid_data. The user's data will
be used directly without batching if set to None. Default: 1.
steps (int, optional): It is the total number of steps (batches of samples) to draw before
stopping evaluation. If None, evaluation will run until the `valid_data` dataset is exhausted.
The evaluation will start from the beginning of the dataset in each run. Default: None.
collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list. None to only stack each field of the samples in axis
                0. Default: None.
callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during evaluation. Default: None. (Unused for now)
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.vision.transforms as T
import paddle.distributed.auto_parallel as auto
from paddle.vision.datasets import MNIST
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
valid_dataset = MNIST(mode='test', transform=transform)
model = paddle.vision.models.LeNet()
loss = paddle.nn.CrossEntropyLoss()
metrics = paddle.metric.Accuracy(topk=(1, 2))
engine = auto.Engine(model, loss, metrics=metrics)
engine.evaluate(valid_dataset, batch_size=64)
"""
         self.mode = 'eval'
+        self._infer_sample_spec(valid_data, batch_size, valid_sample_split)
         if not self._mode_init_states[self.mode]:
             self._prepare_single_mode(self.mode)
+        else:
+            self._switch_mode("eval")

         assert self.mode in self._dist_main_progs, \
-            "eval model is not ready, please call `engine.prepare()` first."
-        eval_dataloader = self._create_dataloader(eval_data,
-                                                  batch_size,
-                                                  collate_fn=collate_fn)
+            "eval model is not ready, please call `engine._prepare_single_mode('eval')` first."
+        valid_dataloader = self._create_dataloader(valid_data,
+                                                   batch_size,
+                                                   steps_per_epoch=steps,
+                                                   collate_fn=collate_fn)

-        usr_fetch = self._validate_fetches(fetches)
         fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
         fetch_metrics = self._validate_fetches(self.fetch_vars["metrics"])
         inner_fetch = dict(fetch_loss, **fetch_metrics)
+        usr_fetch = self._validate_fetches(_get_fetches())
         fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch)
+        outputs = defaultdict(list)
-        for step, _ in enumerate(eval_dataloader):
-            eval_logs = {"step: {:d} ": step}
+        for step, _ in enumerate(valid_dataloader):
             try:
-                outs = self._executor.run(self.main_program,
-                                          fetch_list=fetch_list,
-                                          use_program_cache=use_cache,
-                                          return_numpy=return_numpy)
-            except fluid.core.EOFException:
+                outs = self._executor.run(
+                    self.main_program,
+                    fetch_list=fetch_list,
+                    use_program_cache=self._strategy.use_cache,
+                    return_numpy=self._strategy.return_numpy)
+            except core.EOFException:
                 break
+            eval_logs = {"step: {:d} ": step}
             # inner fetches
             if fetch_loss:
-                eval_logs["loss: {:9f} "] = outs[0][0]
+                eval_logs["loss: {:8f} "] = outs[0][0]
+                outputs["eval_loss"].append(outs[0][0])
             # Metric
             if fetch_metrics:
                 metric_out = outs[len(fetch_loss):len(inner_fetch)]
@@ -495,8 +722,9 @@
                 metric.update(*metric_out)
                 results = metric.accumulate()
                 for i, res in enumerate(to_list(results)):
-                    eval_logs[metric.name()[i] + ": {:9f} "] = res
-            # usr fetches
+                    eval_logs[metric.name()[i] + ": {:8f} "] = res
+                    outputs["eval_" + metric.name()[i]].append(res)
+            # user fetches
             usr_outs = outs[len(inner_fetch):]
             usr_fetch_list = fetch_list[len(inner_fetch):]
             for i, out in enumerate(usr_outs):
@@ -504,38 +732,88 @@
             # logger
             string = '[eval] ' + ''.join(list(eval_logs.keys()))
             self._logger.info(string.format(*list(eval_logs.values())))
self._reset_metrics()
return outputs
     def predict(self,
                 test_data,
+                test_sample_split=None,
                 batch_size=1,
-                fetches=None,
+                steps=None,
                 collate_fn=None,
-                use_cache=False,
-                return_numpy=True):
+                callbacks=None):
+        """
Compute the output predictions on testing data.
Args:
            test_data (Dataset): An instance of paddle.io.Dataset. Default: None.
test_sample_split (int, optional): Each sample of the test dataset is assumed
to be a (input, label) pair by default and has two items. If each sample has
more than two items, test_sample_split specifies how to split these items into
                input and label. The items before it are input and the rest are label. Default: None.
batch_size (int, optional): The batch size of test_data. The user's data will
be used directly without batching if set to None. Default: 1.
steps (int, optional): It is the total number of steps (batches of samples) to draw before
stopping predict. If None, predict will run until the `test_data` dataset is exhausted.
The predict will start from the beginning of the dataset in each run. Default: None.
collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list. None to only stack each field of the samples in axis
                0. Default: None.
callbacks (Callback|None, optional): A list of `Callback` instances to apply
during testing. Default: None. (Unused for now)
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.vision.transforms as T
import paddle.distributed.auto_parallel as auto
from paddle.vision.datasets import MNIST
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
valid_dataset = MNIST(mode='test', transform=transform)
model = paddle.vision.models.LeNet()
engine = auto.Engine(model)
engine.predict(valid_dataset, batch_size=64)
"""
         self.mode = 'predict'
+        self._infer_sample_spec(test_data, batch_size, test_sample_split)
         if not self._mode_init_states[self.mode]:
             self._prepare_single_mode(self.mode)
+        else:
+            self._switch_mode("predict")

         assert self.mode in self._dist_main_progs, \
-            "predict model is not ready, please call `engine.prepare()` first."
+            "predict model is not ready, please call `engine._prepare_single_mode('predict')` first."
         test_dataloader = self._create_dataloader(test_data,
                                                   batch_size,
+                                                  steps_per_epoch=steps,
                                                   collate_fn=collate_fn)

-        usr_fetch = self._validate_fetches(fetches)
         fetch_outputs = self._validate_fetches(self.fetch_vars["outputs"])
+        usr_fetch = self._validate_fetches(_get_fetches())
         fetch_list, fetch_map = self._fetch_map(fetch_outputs, usr_fetch)
         outputs = []
         for step, _ in enumerate(test_dataloader):
-            predict_logs = {"step: {:d} ": step}
             try:
-                outs = self._executor.run(self.main_program,
-                                          fetch_list=fetch_list,
-                                          use_program_cache=use_cache,
-                                          return_numpy=return_numpy)
-            except fluid.core.EOFException:
+                outs = self._executor.run(
+                    self.main_program,
+                    fetch_list=fetch_list,
+                    use_program_cache=self._strategy.use_cache,
+                    return_numpy=self._strategy.return_numpy)
+            except core.EOFException:
                 break
+            predict_logs = {"step: {:d} ": step}
             outputs.append(outs[:len(fetch_outputs)])
             for i, out in enumerate(outs):
                 predict_logs[fetch_map[fetch_list[i]] + ": {}"] = out
@@ -545,12 +823,23 @@
return outputs return outputs
def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
self.mode = 'train'
self._infer_sample_spec(tune_data, batch_size, tune_sample_split)
self._optimization_tuning(self.mode, tune_data, batch_size)
def _create_dataloader(self, def _create_dataloader(self,
dataset, dataset,
batch_size, batch_size,
epochs=1, epochs=1,
steps_per_epoch=None, steps_per_epoch=None,
collate_fn=None): collate_fn=None):
if self._strategy.gradient_merge and batch_size is not None:
assert batch_size % self._k_steps == 0, \
"Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self._k_steps)
batch_size //= self._k_steps
dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank] dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
dist_startup_prog = self._dist_startup_progs[self.mode][self._cur_rank] dist_startup_prog = self._dist_startup_progs[self.mode][self._cur_rank]
dist_context = self._dist_contexts[self.mode] dist_context = self._dist_contexts[self.mode]
...@@ -589,9 +878,9 @@ class Engine: ...@@ -589,9 +878,9 @@ class Engine:
epochs, epochs,
steps_per_epoch, steps_per_epoch,
collate_fn, collate_fn,
data_parallel_world_size=self.dp_world_sizes, data_parallel_world_size=self._dp_world_sizes,
data_parallel_rank=self.dp_ranks, data_parallel_rank=self._dp_ranks,
split_data=self.strategy.split_data) split_data=self._strategy.split_data)
# move read op from the end of program to the start of program # move read op from the end of program to the start of program
new_op_size = len(dist_main_block.ops) new_op_size = len(dist_main_block.ops)
...@@ -612,6 +901,7 @@ class Engine: ...@@ -612,6 +901,7 @@ class Engine:
def _validate_spec(self, specs): def _validate_spec(self, specs):
specs = to_list(specs) specs = to_list(specs)
self._k_steps = self._strategy.gradient_merge.k_steps
if specs is not None: if specs is not None:
for i, spec in enumerate(specs): for i, spec in enumerate(specs):
assert isinstance(spec, InputSpec) assert isinstance(spec, InputSpec)
...@@ -619,6 +909,12 @@ class Engine: ...@@ -619,6 +909,12 @@ class Engine:
raise ValueError( raise ValueError(
"Requires Input[{}].name != None, but receive `None` with {}." "Requires Input[{}].name != None, but receive `None` with {}."
.format(i, spec)) .format(i, spec))
if self._k_steps > 1:
shape = list(spec.shape)
assert shape[0] % self._k_steps == 0, \
"Requires batch_size[{}] to be divisible by k_steps[{}].".format(spec.shape[0], self._k_steps)
shape[0] //= self._k_steps
spec.shape = shape
return specs return specs
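For reference, here is a minimal standalone sketch of the micro-batch arithmetic used by `_create_dataloader` and `_validate_spec` above (plain Python; `micro_batch_size` is a hypothetical helper for illustration): with gradient merge enabled, a user-visible batch of size `batch_size` is fed as `k_steps` micro-batches, so both the dataloader batch size and the first dimension of each `InputSpec` are divided by `k_steps`.

def micro_batch_size(batch_size, k_steps):
    # Mirrors the divisibility check above: the global batch must
    # split evenly into k_steps micro-batches.
    assert batch_size % k_steps == 0, \
        "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
            batch_size, k_steps)
    return batch_size // k_steps

assert micro_batch_size(64, 4) == 16  # 4 micro-batches of 16 samples each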
def _is_local_var(self, var): def _is_local_var(self, var):
...@@ -678,41 +974,98 @@ class Engine: ...@@ -678,41 +974,98 @@ class Engine:
# NOTE hack to enable recompute in engine api for GPT-3 # NOTE hack to enable recompute in engine api for GPT-3
# TODO support more PaddleNLP/CV models here # TODO support more PaddleNLP/CV models here
config = self.strategy.recompute_configs recompute = self._strategy.recompute
# extract ckpts by specific model # extract ckpts by specific model
if isinstance(self.model, paddle.nn.Layer): if isinstance(self._model, paddle.nn.Layer):
if hasattr( if hasattr(
self.model, "gpt" self._model, "gpt"
) and self.model.__class__.__name__ == 'GPTForPretraining': ) and self._model.__class__.__name__ == 'GPTForPretraining':
exact_ckpts = self.model.gpt.checkpoints exact_ckpts = self._model.gpt.checkpoints
else: else:
exact_ckpts = config["checkpoints"] exact_ckpts = recompute.checkpoints
else: else:
exact_ckpts = config["checkpoints"] exact_ckpts = recompute.checkpoints
# modify strategy # modify strategy
if self.strategy.recompute: if recompute.enable:
config["checkpoints"] = exact_ckpts[:] recompute.checkpoints = exact_ckpts[:]
self.strategy.recompute_configs = config
logs = { logs = {
'Model Class': self.model.__class__.__name__, 'Model Class': self._model.__class__.__name__,
'Applied Recompute ckpts': exact_ckpts 'Applied Recompute ckpts': exact_ckpts
} }
self._logger.info(logs) self._logger.info(logs)
def _validate_opt(self, optimizer): def _validate_opt(self, optimizer):
if optimizer is not None:
optimizer._parameter_list = None optimizer._parameter_list = None
optimizer._param_groups = None optimizer._param_groups = None
return optimizer return optimizer
def save(self, path, training=True, mode=None): def _reset_metrics(self):
if not mode: for metric in self._metrics:
mode = self.mode metric.reset()
def _switch_mode(self, mode):
self.mode = mode
self._initialize(mode)
def _set_state_dict(self, mode, strict, state_dict, dist_attr):
program = self._dist_main_progs[mode][self._cur_rank]
dist_context = self._dist_contexts[mode]
cur_dist_attr = get_dist_attr(program, dist_context)
converter = Converter(state_dict, dist_attr, cur_dist_attr)
state_dict = converter.convert(strict=strict)
program.set_state_dict(state_dict)
def save(self, path, training=True):
"""
Saves the model, parameters, and optimizer state to path.
If `training` is set to False, only the inference model will be saved.
Args:
path (str): The file prefix to save the model. The format
is 'dirname/file_prefix' or 'file_prefix'. If empty str,
an exception will be raised.
training (bool, optional): Whether to save for training. If not, save
for inference only. If `training` is set to True, the optimizer state
will be saved. Otherwise, only the model and parameters are saved.
This function will silently overwrite any existing file at the target
location. Default: True.
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.vision.transforms as T
import paddle.distributed.auto_parallel as auto
from paddle.vision.datasets import MNIST
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = MNIST(mode='train', transform=transform)
model = paddle.vision.models.LeNet()
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
metrics = paddle.metric.Accuracy(topk=(1, 2))
engine = auto.Engine(model, loss, optimizer, metrics)
engine.fit(train_dataset,
epochs=1,
batch_size=64)
engine.save("./my_model")
"""
if training: if training:
assert 'train' in self._serial_main_progs, \ assert 'train' in self._serial_main_progs, \
"training model is not ready, please call `engine.prepare()` first." "training model is not ready, please call `engine._prepare_single_mode('train')` first."
serial_program = self._serial_main_progs["train"] serial_program = self._serial_main_progs["train"]
dist_main_prog = self._dist_main_progs["train"][self._cur_rank] dist_main_prog = self._dist_main_progs["train"][self._cur_rank]
dist_context = self._dist_contexts["train"] dist_context = self._dist_contexts["train"]
...@@ -721,7 +1074,7 @@ class Engine: ...@@ -721,7 +1074,7 @@ class Engine:
dist_main_program=dist_main_prog, dist_main_program=dist_main_prog,
dist_context=dist_context) dist_context=dist_context)
else: else:
assert mode, "Please set the 'mode' you want to save." mode = "predict"
feed_vars = self._feed_vars[mode]['inputs'] feed_vars = self._feed_vars[mode]['inputs']
fetch_vars = self._fetch_vars[mode]['outputs'] fetch_vars = self._fetch_vars[mode]['outputs']
dist_main_prog = self._dist_main_progs[mode][self._cur_rank] dist_main_prog = self._dist_main_progs[mode][self._cur_rank]
...@@ -731,18 +1084,59 @@ class Engine: ...@@ -731,18 +1084,59 @@ class Engine:
self._executor, self._executor,
program=dist_main_prog) program=dist_main_prog)
def load(self, path, strict=True, load_optimizer=True, mode=None): def load(self, path, strict=True, load_optimizer=True):
if not mode: """
mode = self.mode Load the stored model, parameters and optimizer states.
assert mode, "Please set the 'mode' you want to load."
dist_main_prog = self._dist_main_progs[mode][self._cur_rank] Args:
dist_context = self._dist_contexts[mode] path (str): The prefix of files storing the model states and
self._saver.load(path, dist_main_prog, dist_context, strict, optimizer states.
load_optimizer) strict (bool, optional): Whether to skip the loading of mismatch
parameter or raise an error when mismatch happens (not found
the parameter in file storing model states of or receives a
mismatch shape). Default: False.
load_optimizer (bool, optional): If True, the stored optimizer
states is restored. Otherwise, the optimizer states is intialized
from scratch. Default: False.
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.vision.transforms as T
import paddle.distributed.auto_parallel as auto
from paddle.vision.datasets import MNIST
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = MNIST(mode='train', transform=transform)
model = paddle.vision.models.LeNet()
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
metrics = paddle.metric.Accuracy(topk=(1, 2))
engine = auto.Engine(model, loss, optimizer, metrics)
engine.fit(train_dataset,
epochs=1,
batch_size=64)
engine.save("./my_model")
engine.load("./my_model")
"""
self._strict = strict
self._state_dict, self._dist_attr = self._saver.load(
path, load_optimizer)
return self._state_dict, self._dist_attr
@staticmethod @staticmethod
def get_lr_scheduler(program): def _get_lr_scheduler(program):
lr_sheduler = None lr_sheduler = None
if hasattr(program, 'lr_sheduler'): if hasattr(program, 'lr_sheduler'):
from paddle.optimizer.lr import LRScheduler from paddle.optimizer.lr import LRScheduler
...@@ -750,6 +1144,20 @@ class Engine: ...@@ -750,6 +1144,20 @@ class Engine:
assert isinstance(lr_sheduler, LRScheduler), "must be LRScheduler" assert isinstance(lr_sheduler, LRScheduler), "must be LRScheduler"
return lr_sheduler return lr_sheduler
def _get_lr(self, optimizer):
if isinstance(optimizer, paddle.optimizer.Optimizer):
return optimizer.get_lr()
elif isinstance(optimizer, paddle.fluid.optimizer.Optimizer):
if isinstance(optimizer._learning_rate, float):
return optimizer._learning_rate
else:
return optimizer._learning_rate()
else:
raise TypeError(
"'optimizer' must be object of class `paddle.optimizer.Optimizer`" \
" or `paddle.fluid.optimizer.Optimizer`, but got {}.".format(type(optimizer))
)
@property @property
def mode(self): def mode(self):
return self._mode return self._mode
...@@ -781,3 +1189,11 @@ class Engine: ...@@ -781,3 +1189,11 @@ class Engine:
@property @property
def fetch_vars(self): def fetch_vars(self):
return self._fetch_vars[self.mode] return self._fetch_vars[self.mode]
@property
def inputs(self):
return self.inputs_spec
@property
def labels(self):
return self.labels_spec
...@@ -19,13 +19,13 @@ import paddle ...@@ -19,13 +19,13 @@ import paddle
from paddle.nn import Layer from paddle.nn import Layer
from paddle.jit import to_static, not_to_static from paddle.jit import to_static, not_to_static
from paddle.distributed.utils import get_logger
from paddle.fluid.framework import Operator, Parameter, _non_static_mode from paddle.fluid.framework import Operator, Parameter, _non_static_mode
from paddle.fluid.framework import program_guard from paddle.fluid.framework import program_guard
from paddle.fluid.executor import global_scope from paddle.fluid.executor import global_scope
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction
from .utils import to_list from .utils import to_list
from .utils import get_logger
from .converter import Converter from .converter import Converter
......
...@@ -12,101 +12,199 @@ ...@@ -12,101 +12,199 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy
import copy
import paddle import paddle
import paddle.fluid.core as core from paddle.fluid import core
from paddle.fluid.framework import Variable from .process_mesh import ProcessMesh
from paddle.fluid.framework import _non_static_mode from .process_mesh import get_current_process_mesh
from .process_mesh import set_current_process_mesh
from .process_mesh import reset_current_process_mesh
from .dist_context import get_default_distributed_context from .dist_context import get_default_distributed_context
from .dist_tensor import DistributedTensor from .dist_tensor import DistributedTensor
from .dist_op import DistributedModule from .dist_op import DistributedOperatorHelper
from .dist_attribute import TensorDistributedAttribute from .utils import verify_shard_spec, convert_to_dims_mapping
from .dist_attribute import OperatorDistributedAttribute
def _static_mode_check(): def shard_tensor(x, process_mesh=None, shard_spec=None):
if _non_static_mode():
raise RuntimeError("Auto-parallel only supports static mode for now, "
"please use paddle.enable_static() first.")
def shard_tensor(x, dist_attr=None):
""" """
Add distributed attributes for a tensor. Shard a tensor on a process mesh according to the shard specification.
Args: Args:
x (Tensor): the tensor to be sharded. x (Tensor): the tensor to be sharded.
dist_attr (dict): the tensor distributed attributes. The accepted attributes are as follow: process_mesh (ProcessMesh, optional): An instance of ProcessMesh describes a mesh
"process_mesh": a nested list an to describe the mesh topology of logical processes. topology of the used logical processes where the tensor is sharded. If it is None,
"dims_mapping": a list to describe the mapping between `x` and `process_mesh`, the dimension the found current process mesh will be used. And an error will be raised if the
`i` of `x` is split across the dimension `dims_mapping[i]` of `process_mesh`, current process mesh cannot be found. Default: None.
where -1 means that tensor dimension is not split. shard_spec (list, optional): a list to describe the sharding mapping between `x` and `process_mesh`,
Both process_mesh and dims_mapping are optional and users can specify as need. which means the dimension `i` of `x` is split across the dimension `shard_spec[i]` of `process_mesh`,
where `None` means that tensor dimension is not split. For example, given a tensor with
the shape [6, 12] and a process mesh with the shape [2, 3] and the dimension names ["x", "y"]:
If `shard_spec=["x", "y"]`, each shard of the tensor will have a shape [3, 4];
If `shard_spec=["y", "x"]`, each shard of the tensor will have a shape [2, 6];
If `shard_spec=["x", None]`, each shard of the tensor will have a shape [3, 12];
If `shard_spec=[None, "x"]`, each shard of the tensor will have a shape [6, 6];
If `shard_spec=["y", None]`, each shard of the tensor will have a shape [2, 12];
If `shard_spec=[None, "y"]`, each shard of the tensor will have a shape [6, 4];
If `shard_spec=[None, None]`, each shard of the tensor will have a shape [6, 12];
If the `shard_spec` is None, the tensor will be replicated across all the processes of `process_mesh`.
In the above example, `shard_spec=None` is the same as `shard_spec=[None, None]`. Default: None.
Returns: Returns:
Tensor: the tensor `x` annotated with distributed attributes. Tensor: the tensor `x` annotated with sharding information.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.distributed as dist import paddle.distributed.auto_parallel as auto
paddle.enable_static()
mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
x = paddle.ones([4, 6]) x = paddle.ones([4, 6])
dist.shard_tensor(x, dist_attr={"process_mesh": [[0, 1], [2, 3]], shard_spec = ["x", "y"]
"dims_mapping": [0, -1]}) auto.shard_tensor(x, mesh, shard_spec)
""" """
_static_mode_check()
assert dist_attr is None or isinstance(dist_attr, (dict, TensorDistributedAttribute)), \ if process_mesh is not None:
"The type of dist_attr must be None, dict or TensorDistributedAttribute." assert isinstance(process_mesh, ProcessMesh), \
dist_tensor = DistributedTensor(x, dist_attr) "Argument process_mesh {} is not an instance of ProcessMesh".format(process_mesh)
dist_tensor.dist_attr.mark_annotated_as(dist_attr) else:
process_mesh = get_current_process_mesh()
assert process_mesh is not None, \
"Specify the process mesh argument or use ProcessMesh context manager first."
if shard_spec is not None:
    assert isinstance(shard_spec, list), \
        "Argument shard_spec {} is not an instance of list".format(shard_spec)
dist_tensor = DistributedTensor(x)
serial_tensor = dist_tensor.serial_tensor
dist_tensor.dist_attr.process_mesh = process_mesh
if serial_tensor.type == core.VarDesc.VarType.READER \
or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = []
else:
tensor_shape = serial_tensor.shape
if shard_spec is not None:
assert verify_shard_spec(shard_spec, tensor_shape, process_mesh), \
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
serial_tensor.name, shard_spec, tensor_shape, process_mesh)
dist_tensor.dist_attr.dims_mapping = convert_to_dims_mapping(
shard_spec, process_mesh)
if process_mesh is not None:
dist_tensor.dist_attr.mark_annotated("process_mesh")
if shard_spec is not None:
dist_tensor.dist_attr.mark_annotated("dims_mapping")
default_dist_ctx = get_default_distributed_context() default_dist_ctx = get_default_distributed_context()
default_dist_ctx.add_dist_tensor_for_program(dist_tensor) default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
dist_tensor = default_dist_ctx.get_dist_tensor_for_program(x)
return x return x
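A minimal sketch tying the shard shapes listed in the docstring to code, assuming a static-mode program and six available processes arranged as a 2x3 mesh:

import paddle
import paddle.distributed.auto_parallel as auto

paddle.enable_static()

mesh = auto.ProcessMesh([[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"])

t = paddle.ones([6, 12])
# Dim 0 is split over "x" (degree 2) and dim 1 over "y" (degree 3),
# so each process holds a [3, 4] shard.
auto.shard_tensor(t, mesh, ["x", "y"])

r = paddle.ones([6, 12])
# [None, None] replicates the tensor on every process of the mesh.
auto.shard_tensor(r, mesh, [None, None])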
def shard_op(op_fn, dist_attr=None): def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None):
""" """
Call a function and add distributed attributes for ops added by the function. Shard an operation on a process mesh according to its input and output shard specifications.
Args: Args:
op_fn (callable): a callable operator or module to be sharded. op (Callable): a callable operator or module to be sharded.
dist_attr (dict): the operator distributed attributes. The accepted attributes are classified into process_mesh (ProcessMesh, optional): An instance of ProcessMesh describes a mesh
two categories. The first category describes the distributed attributes shared by all inputs and topology of the used logical processes where the op is sharded. All of its inputs and
outputs, and only `process_mesh` can be specified now. The second category describes distributed outputs are sharded by this process mesh. If it is None, the found current process mesh
attributes for inputs or outputs same as the `dist_attr` of `shard_tensor`. All of them are will be used. And an error will be raised if the current process mesh cannot be found.
optional and users can specify them as need. Note that `process_mesh` for operators must be the Default: None.
same as these process_meshes for inputs and outputs. in_shard_specs (list of list, optional): a list of list to describe the sharding specifications
for the inputs. Each item of `in_shard_specs` is a `shard_spec` between the correspoinding input
and `process_mesh`. If one item is None, the cooresponding input is replicated across all processes
If it is None, all inputs are replicated accross all processes. Note that the lenght of the
`in_shard_specs` should be equal to the actual number of inputs when calling this operation.
Default: None.
out_shard_specs (list of list, optional): a list of list to describe the sharding specifications
for the outputs. Each item of `out_shard_specs` is a `shard_spec` between the corresponding output
and `process_mesh`. If one item is None, the corresponding output is replicated across all processes.
If it is None, all outputs are replicated across all processes. Note that the length of the
`out_shard_specs` should be equal to the actual number of outputs when calling this operation.
Default: None. Default: None.
Returns: Returns:
list: the outputs of the function `op_fn`, which are annotated with distributed attributes. Outputs of `op`, each of which is annotated with sharding information.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.distributed as dist import paddle.distributed.auto_parallel as auto
paddle.enable_static()
x = paddle.ones([4, 6]) x = paddle.ones([4, 6])
y = paddle.zeros([4, 6]) y = paddle.zeros([4, 6])
dist_add = dist.shard_op(paddle.add, mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
dist_attr={ dist_add = auto.shard_op(paddle.add, mesh,
"process_mesh": [[2, 3, 1], [0, 4, 5]], in_shard_specs=[["x", "y"], ["y", None]],
x: {"dims_mapping": [-1, 0]}, out_shard_specs=[[None, "x"]])
y: {"dims_mapping": [0, -1]}
})
dist_add(x, y) dist_add(x, y)
""" """
_static_mode_check()
assert dist_attr is None or isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \ if process_mesh is not None:
"The type of dist_attr must be dict or OperatorDistributedAttribute." assert isinstance(process_mesh, ProcessMesh), \
dist_module = DistributedModule(op_fn, dist_attr) "Argument process_mesh {} is not an instance of ProcessMesh".format(process_mesh)
return dist_module else:
process_mesh = get_current_process_mesh()
assert process_mesh is not None, \
"Specify the process mesh argument or use ProcessMesh context manager first."
in_dims_mappings = []
if in_shard_specs is not None:
assert all((isinstance(shard_spec, list) or shard_spec is None) for shard_spec in in_shard_specs), \
"in_shard_spec {} is not a list of list or None".format(in_shard_specs)
for shard_spec in in_shard_specs:
if shard_spec is not None:
in_dims_mappings.append(
convert_to_dims_mapping(shard_spec, process_mesh))
else:
in_dims_mappings.append(None)
out_dims_mappings = []
if out_shard_specs is not None:
assert all((isinstance(shard_spec, list) or shard_spec is None) for shard_spec in out_shard_specs), \
"out_shard_spec {} is not a list of list or None".format(out_shard_specs)
for shard_spec in out_shard_specs:
if shard_spec is not None:
out_dims_mappings.append(
convert_to_dims_mapping(shard_spec, process_mesh))
else:
out_dims_mappings.append(None)
op = DistributedOperatorHelper(op, process_mesh, in_dims_mappings,
out_dims_mappings)
return op
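A minimal usage sketch of the new `shard_op` API under static mode, assuming four processes arranged as a 2x2 mesh (the shard specs follow the docstring example above):

import paddle
import paddle.distributed.auto_parallel as auto

paddle.enable_static()

mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])

x = paddle.ones([4, 6])
y = paddle.zeros([4, 6])
# The returned helper is called in place of paddle.add; its inputs
# and output carry the given shard specifications.
dist_add = auto.shard_op(paddle.add,
                         process_mesh=mesh,
                         in_shard_specs=[["x", "y"], ["y", None]],
                         out_shard_specs=[[None, "x"]])
out = dist_add(x, y)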
def recompute(op):
class RecomputeOperator:
def __init__(self, op):
self._op = op
def __call__(self, *args, **kwargs):
default_prog = paddle.fluid.default_main_program()
cur_block = default_prog.current_block()
op_size = len(cur_block.ops)
output = self._op(*args, **kwargs)
new_op_size = len(cur_block.ops)
for idx in range(op_size, new_op_size):
op = cur_block.ops[idx]
op._set_attr("is_recompute@auto_parallel", True)
return output
return RecomputeOperator(op)
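A short sketch of how `recompute` is meant to be used (static mode; the layer is only illustrative): every op created inside the wrapped call gets the `is_recompute@auto_parallel` attribute, which the recompute pass later reads to choose checkpoints.

import paddle
import paddle.distributed.auto_parallel as auto

paddle.enable_static()

x = paddle.static.data(name="x", shape=[4, 16], dtype="float32")
mlp = paddle.nn.Linear(16, 16)
# Ops created by this call are tagged for the recompute pass.
y = auto.recompute(mlp)(x)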
_g_fetched_tensors = {}
def fetch(tensor, name=None):
if name is None:
_g_fetched_tensors[tensor.name] = tensor
else:
_g_fetched_tensors[name] = tensor
def _get_fetches():
return _g_fetched_tensors
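`fetch` simply registers user tensors in a module-level dict keyed by an optional alias; `predict`/`evaluate` later pull them via `_get_fetches` and append them to the executor fetch list. A minimal sketch (static mode):

import paddle
import paddle.distributed.auto_parallel as auto

paddle.enable_static()

x = paddle.static.data(name="x", shape=[4, 16], dtype="float32")
probs = paddle.nn.functional.softmax(x)
# Fetched every step and reported in the logs under the key "probs".
auto.fetch(probs, name="probs")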
...@@ -42,6 +42,7 @@ from .utils import make_data_unshard ...@@ -42,6 +42,7 @@ from .utils import make_data_unshard
from .utils import set_grad_var_shape from .utils import set_grad_var_shape
from .utils import print_program_with_dist_attr from .utils import print_program_with_dist_attr
from .utils import SerialProgramInfo from .utils import SerialProgramInfo
from .utils import get_logger
from .reshard import Resharder from .reshard import Resharder
from .cluster import Cluster from .cluster import Cluster
from .mapper import mapping from .mapper import mapping
...@@ -84,7 +85,7 @@ class AutoParallelizer: ...@@ -84,7 +85,7 @@ class AutoParallelizer:
self._need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING") self._need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING")
self._need_rank_mapping = True if self._need_rank_mapping and \ self._need_rank_mapping = True if self._need_rank_mapping and \
self._need_rank_mapping.lower() == 'true' else False self._need_rank_mapping.lower() == 'true' else False
self._pass_context = None # self._pass_context = None
def _remove_distributed_attrs(self, main_program): def _remove_distributed_attrs(self, main_program):
suffix = core.kAutoParallelSuffix() suffix = core.kAutoParallelSuffix()
...@@ -143,10 +144,11 @@ class AutoParallelizer: ...@@ -143,10 +144,11 @@ class AutoParallelizer:
def _apply_optimize(self, main_program, startup_program, params_grads): def _apply_optimize(self, main_program, startup_program, params_grads):
optimizer = copy.deepcopy(self._optimizer)
with program_guard(main_program, startup_program): with program_guard(main_program, startup_program):
optimize_ops = copy.deepcopy( optimize_ops = optimizer.apply_gradients(params_grads)
self._optimizer).apply_gradients(params_grads)
self._dist_context._lr_optimizer = optimizer
# update completion # update completion
self._completer = Completer(self._dist_context) self._completer = Completer(self._dist_context)
self._completer.complete_update_annotation(main_program) self._completer.complete_update_annotation(main_program)
...@@ -165,6 +167,15 @@ class AutoParallelizer: ...@@ -165,6 +167,15 @@ class AutoParallelizer:
config) config)
auto_parallel_sharding_pass.apply([main_program], [startup_program], auto_parallel_sharding_pass.apply([main_program], [startup_program],
self._pass_context) self._pass_context)
params_grads = self._pass_context.get_attr("params_grads")
config = copy.deepcopy(self._dist_strategy.sharding_configs)
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["rank_id"] = rank
auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip", config)
auto_parallel_clip_pass.apply([main_program], [startup_program],
self._pass_context)
if self._dist_strategy.gradient_merge: if self._dist_strategy.gradient_merge:
config = copy.deepcopy(self._dist_strategy.gradient_merge_configs) config = copy.deepcopy(self._dist_strategy.gradient_merge_configs)
......
...@@ -22,7 +22,6 @@ from paddle.fluid import program_guard ...@@ -22,7 +22,6 @@ from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward from paddle.fluid.backward import append_backward
from paddle.fluid.framework import _non_static_mode, unique_name from paddle.fluid.framework import _non_static_mode, unique_name
from paddle.distributed.passes import new_pass from paddle.distributed.passes import new_pass
from paddle.distributed.utils import get_logger
from .reshard import Resharder from .reshard import Resharder
from .partitioner import Partitioner from .partitioner import Partitioner
...@@ -31,6 +30,7 @@ from .dist_saver import DistributedSaver ...@@ -31,6 +30,7 @@ from .dist_saver import DistributedSaver
from .dist_loader import NonIterableGeneratorLoader from .dist_loader import NonIterableGeneratorLoader
from .utils import make_data_unshard, set_grad_var_shape from .utils import make_data_unshard, set_grad_var_shape
from .utils import print_program_with_dist_attr, to_list from .utils import print_program_with_dist_attr, to_list
from .utils import get_logger
from .process_group import get_all_process_groups, get_world_process_group from .process_group import get_all_process_groups, get_world_process_group
from .dist_context import DistributedContext, get_default_distributed_context from .dist_context import DistributedContext, get_default_distributed_context
...@@ -160,8 +160,8 @@ class Parallelizer: ...@@ -160,8 +160,8 @@ class Parallelizer:
# apply quantization pass # apply quantization pass
# The pass can be applied when mode must be 'train' # The pass can be applied when mode must be 'train'
if self._mode == 'train' and self._strategy.qat: if self._mode == 'train' and self._strategy.qat.enable:
config = copy.deepcopy(self._strategy.qat_configs) config = copy.deepcopy(self._strategy.qat.to_dict())
config["dist_context"] = self._dist_context config["dist_context"] = self._dist_context
config["params_grads"] = params_grads config["params_grads"] = params_grads
auto_parallel_quantization_pass = new_pass( auto_parallel_quantization_pass = new_pass(
...@@ -176,8 +176,8 @@ class Parallelizer: ...@@ -176,8 +176,8 @@ class Parallelizer:
# apply amp pass # apply amp pass
# FIXME we disenable amp for eval since it has a little bug with # FIXME we disenable amp for eval since it has a little bug with
# eval program and which will be fixed in future # eval program and which will be fixed in future
if self._mode == 'train' and self._strategy.amp: if self._mode == 'train' and self._strategy.amp.enable:
config = copy.deepcopy(self._strategy.amp_configs) config = copy.deepcopy(self._strategy.amp.to_dict())
config["dist_context"] = self._dist_context config["dist_context"] = self._dist_context
config["params_grads"] = params_grads config["params_grads"] = params_grads
config["loss"] = loss config["loss"] = loss
...@@ -195,8 +195,8 @@ class Parallelizer: ...@@ -195,8 +195,8 @@ class Parallelizer:
# apply recompute pass # apply recompute pass
# recompute is then train-only optimization # recompute is then train-only optimization
if self._mode == "train" and self._strategy.recompute: if self._mode == "train" and self._strategy.recompute.enable:
config = copy.deepcopy(self._strategy.recompute_configs) config = copy.deepcopy(self._strategy.recompute.to_dict())
config["dist_context"] = self._dist_context config["dist_context"] = self._dist_context
config["no_grad_set"] = None config["no_grad_set"] = None
config["loss"] = loss config["loss"] = loss
...@@ -217,12 +217,12 @@ class Parallelizer: ...@@ -217,12 +217,12 @@ class Parallelizer:
config = {} config = {}
config["dist_context"] = self._dist_context config["dist_context"] = self._dist_context
config["global_rank"] = rank config["global_rank"] = rank
config["use_sharding"] = self._strategy.sharding config["use_sharding"] = self._strategy.sharding.enable
dp_pass = new_pass("auto_parallel_data_parallel_optimization", config) dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
dp_pass.apply([main_program], [startup_program], self._pass_context) dp_pass.apply([main_program], [startup_program], self._pass_context)
if self._strategy.sharding: if self._strategy.sharding.enable:
config = copy.deepcopy(self._strategy.sharding_configs) config = copy.deepcopy(self._strategy.sharding.to_dict())
config["dist_context"] = self._dist_context config["dist_context"] = self._dist_context
config["params_grads"] = params_grads config["params_grads"] = params_grads
config["global_rank"] = rank config["global_rank"] = rank
...@@ -230,11 +230,11 @@ class Parallelizer: ...@@ -230,11 +230,11 @@ class Parallelizer:
config) config)
auto_parallel_sharding_pass.apply([main_program], [startup_program], auto_parallel_sharding_pass.apply([main_program], [startup_program],
self._pass_context) self._pass_context)
params_grads = self._pass_context.get_attr("params_grads")
# GradClip is train-only optimization # GradClip is train-only optimization
if self._mode == "train": if self._mode == "train":
config = copy.deepcopy(self._strategy.sharding_configs) config = copy.deepcopy(self._strategy.sharding.to_dict())
config["dist_context"] = self._dist_context config["dist_context"] = self._dist_context
config["params_grads"] = params_grads config["params_grads"] = params_grads
config["rank_id"] = rank config["rank_id"] = rank
...@@ -244,8 +244,8 @@ class Parallelizer: ...@@ -244,8 +244,8 @@ class Parallelizer:
self._pass_context) self._pass_context)
# gradient_merge is then train-only optimization # gradient_merge is then train-only optimization
if self._mode == "train" and self._strategy.gradient_merge: if self._mode == "train" and self._strategy.gradient_merge.enable:
config = copy.deepcopy(self._strategy.gradient_merge_configs) config = copy.deepcopy(self._strategy.gradient_merge.to_dict())
config["dist_context"] = self._dist_context config["dist_context"] = self._dist_context
config["params_grads"] = params_grads config["params_grads"] = params_grads
auto_parallel_gradient_merge_pass = new_pass( auto_parallel_gradient_merge_pass = new_pass(
......
...@@ -12,86 +12,90 @@ ...@@ -12,86 +12,90 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy import numpy as np
import copy import copy
import paddle
# Use to store the previous and current process mesh
_g_previous_process_mesh = None
_g_current_process_mesh = None
def _get_nested_list_shape(nested_list):
"""
Get the shape of a nested_list.
"""
result = []
while isinstance(nested_list, list):
result.append(len(nested_list))
nested_list = nested_list[0]
return result
def get_current_process_mesh():
global _g_current_process_mesh
return _g_current_process_mesh
def _flatten_nested_list(nested_list):
"""
Get a list of all items in a nested_list.
Ref: https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-a-list-of-lists
"""
result = numpy.array(nested_list).flatten().tolist()
return result
def set_current_process_mesh(process_mesh):
global _g_previous_process_mesh
global _g_current_process_mesh
_g_previous_process_mesh = _g_current_process_mesh
_g_current_process_mesh = process_mesh
class ProcessMesh(object):
r"""
The class `Processmesh` describes the topology of logical processes.
A mesh is an N-dimensional array. The shape of the N-dimensional
array represents the topology of logical processes and every
element of the N-dimensional array represent a logical process. For
example, the 2-dimensional array [[2, 4, 5], [0, 1, 3]]
illustrates six logical processes organized as the topology [2, 3],
i.e., the shape of the 2-dimensional array. With the above topology,
there are two parallel groups, where the first parallel group has a
parallel degree of 2 and the second one has a parallel degree of 3.
And the first logical process is the one with id=2.
Args: def reset_current_process_mesh():
mesh (list): an N-dimensional array (nested list) describes the toplogy global _g_previous_process_mesh
of logical processes. The shape of the N-dimensional array global _g_current_process_mesh
represents the topology of logical processes and every _g_current_process_mesh = _g_previous_process_mesh
element of the N-dimensional array represents a logical process.
Returns: class ProcessMesh(object):
None """
The `ProcessMesh` object describes the topology of the used processes.
Raises: Args:
ValueError: If `mesh` is not an instance of list. mesh (list|numpy.array): an n-dimensional array that describes the topology
of the processes.
dim_names (list, optional): the i-th element of this list gives the name of the
i-th dimension of the mesh.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.distributed as dist import paddle.distributed.auto_parallel as auto
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]]) mesh = auto.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
assert mesh.topology == [2, 3] assert mesh.shape == [2, 3]
assert mesh.processes == [2, 4, 5, 0, 1, 3] assert mesh.process_ids == [2, 4, 5, 0, 1, 3]
""" """
def __init__(self, mesh): def __init__(self, mesh=None, dim_names=None, shape=None, process_ids=None):
if mesh is None or not isinstance(mesh, list): # Use shape and process_ids just for compatibility
raise ValueError('mesh must be an instance of list.') # Users should not use these directly
if mesh is None:
processes = _flatten_nested_list(mesh) assert shape is not None
assert process_ids is not None
assert all(isinstance(p, int) for p in processes), \ mesh = np.array(process_ids).reshape(shape)
("All elements of mesh must be integer")
if not isinstance(mesh, list) and \
assert min(processes) >= 0, ('All elements of mesh must be >= 0.') not isinstance(mesh, np.ndarray):
raise ValueError(
unique_processes = set(processes) 'The mesh must be an instance of list or np.ndarray.')
assert len(unique_processes) == len(processes), ( if isinstance(mesh, list):
'All elements of mesh must be unique.') mesh = np.array(mesh)
self._topology = _get_nested_list_shape(mesh) self._mesh = mesh
self._processes = processes self._shape = list(self._mesh.shape)
self._process_ids = self._mesh.flatten().tolist()
assert all(isinstance(p, int) for p in self._process_ids), \
("All elements of the mesh must be integer")
assert min(
self._process_ids) >= 0, ('All elements of the mesh must be >= 0.')
unique_process_ids = set(self._process_ids)
assert len(unique_process_ids) == len(
self._process_ids), ('All elements of the mesh must be unique.')
if dim_names is not None:
assert len(dim_names) == len(self._shape), \
("The length of dims_names must be same as the shape of the mesh.")
self._dim_names = copy.deepcopy(dim_names)
else:
self._dim_names = ["d" + str(i) for i in range(len(self._shape))]
unique_dim_names = set(self._dim_names)
assert len(unique_dim_names) == len(self._dim_names), (
'All dim_names {} must be unique.'.format(dim_names))
# Store all process meshes # Store all process meshes
from .dist_context import get_default_distributed_context from .dist_context import get_default_distributed_context
...@@ -103,31 +107,117 @@ class ProcessMesh(object): ...@@ -103,31 +107,117 @@ class ProcessMesh(object):
pg0.add_ranks(self.processes) pg0.add_ranks(self.processes)
@property @property
def topology(self): def shape(self):
r""" """
Get the topology of logical processes belonging to this ProcessMesh. Get the shape of this ProcessMesh.
This is the shape of `mesh` used to initialized this ProcessMesh.
""" """
return self._topology return self._shape
@property @property
def processes(self): def process_ids(self):
r""" """
Get a list of all processes belonging to this ProcessMesh. Get the process ids belonging to this ProcessMesh.
""" """
return self._processes return self._process_ids
@property
def dim_names(self):
"""
Get the dimension names of this ProcessMesh.
"""
return self._dim_names
@property @property
def ndim(self): def ndim(self):
r"""
Get the number of dimension of ProcessMesh.
""" """
return len(self._topology) Get the number of dimension of this ProcessMesh.
"""
return len(self._shape)
@property
def mesh(self):
"""
Get the underlying mesh of ProcessMesh.
"""
return self._mesh
@property
def topology(self):
return self._shape
@property
def processes(self):
return self._process_ids
def __getitem__(self, index):
if isinstance(index, tuple):
new_dim_names = []
for i, item in enumerate(index):
if isinstance(item, slice):
new_dim_names.append(self._dim_names[i])
new_mesh = self._mesh[index]
if new_mesh.shape:
return ProcessMesh(new_mesh, new_dim_names)
else:
# Wrap a scalar into a list but without dim_names
return ProcessMesh([new_mesh])
elif isinstance(index, slice):
new_mesh = self._mesh[index]
new_dim_names = self._dim_names
return ProcessMesh(new_mesh, new_dim_names)
else:
new_mesh = self._mesh[index]
new_dim_names = self._dim_names[1:]
return ProcessMesh(new_mesh, new_dim_names)
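The indexing semantics in a nutshell (a sketch; the asserts assume the 2x3 mesh below): an integer index drops the leading dimension together with its name, while slices keep the dimension names of the sliced axes.

import paddle.distributed.auto_parallel as auto

mesh = auto.ProcessMesh([[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"])
row = mesh[0]       # sub-mesh over [0, 1, 2] with dim_names ["y"]
col = mesh[:, 1]    # sub-mesh over [1, 4] keeping dim_names ["x"]
assert row.process_ids == [0, 1, 2]
assert col.process_ids == [1, 4]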
def __enter__(self):
set_current_process_mesh(self)
default_prog = paddle.fluid.default_main_program()
cur_block = default_prog.current_block()
self._old_var_names = list(cur_block.vars.keys())
self._old_op_size = len(cur_block.ops)
def __exit__(self, exc_type, exc_value, exc_traceback):
from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator
default_prog = paddle.fluid.default_main_program()
cur_block = default_prog.current_block()
new_var_names = list(cur_block.vars.keys())
new_op_size = len(cur_block.ops)
from .dist_context import get_default_distributed_context
default_dist_ctx = get_default_distributed_context()
for name in new_var_names:
if name not in self._old_var_names:
tensor = cur_block.vars[name]
dist_tensor = default_dist_ctx.get_dist_tensor_for_program(
tensor)
if dist_tensor is None:
dist_tensor = DistributedTensor(cur_block.vars[name],
{"process_mesh": self})
dist_tensor.dist_attr.mark_annotated("process_mesh")
default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
else:
if dist_tensor.dist_attr.process_mesh is None:
dist_tensor.dist_attr.process_mesh = self
dist_tensor.dist_attr.mark_annotated("process_mesh")
for idx in range(self._old_op_size, new_op_size):
op = cur_block.ops[idx]
dist_op = default_dist_ctx.get_dist_op_for_program(op)
if dist_op is None:
dist_op = DistributedOperator(op, {"process_mesh": self})
dist_op.dist_attr.mark_annotated("process_mesh")
default_dist_ctx.add_dist_op_for_program(dist_op)
else:
if dist_op.dist_attr.process_mesh is None:
dist_op.dist_attr.process_mesh = self
dist_op.dist_attr.mark_annotated("process_mesh")
reset_current_process_mesh()
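Used as a context manager, the mesh annotates every tensor and op created inside the `with` block that does not yet carry a process mesh; this is what `shard_tensor`/`shard_op` rely on when their `process_mesh` argument is None. A minimal sketch (static mode):

import paddle
import paddle.distributed.auto_parallel as auto

paddle.enable_static()

with auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]):
    # Both tensors and the reshape op get this mesh on exit.
    a = paddle.ones([4, 6])
    b = paddle.reshape(a, [2, 12])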
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, ProcessMesh): if not isinstance(other, ProcessMesh):
return False return False
if self.topology != other.topology or self.processes != other.processes: if self.shape != other.shape or self.process_ids != other.process_ids:
return False return False
return True return True
...@@ -135,6 +225,6 @@ class ProcessMesh(object): ...@@ -135,6 +225,6 @@ class ProcessMesh(object):
return not self.__eq__(other) return not self.__eq__(other)
def __str__(self): def __str__(self):
str = "shape {} and process group {}".format(self.topology, str = "shape {}, process_ids {}, dim_nams {}".format(
self.processes) self.shape, self.process_ids, self.dim_names)
return str return str
...@@ -81,54 +81,57 @@ class ProcessMesh(core.ProcessMesh): ...@@ -81,54 +81,57 @@ class ProcessMesh(core.ProcessMesh):
return self._mesh return self._mesh
# def compute_compatible_process_meshes(process_meshes): def compute_compatible_process_mesh(process_meshes):
# """Compute the compatible process mesh given a list of process meshes.""" """Compute the compatible process mesh given a list of process meshes."""
# if not process_meshes: if not process_meshes:
# return None return None
# def _compute_compatible_two_process_meshes(pm1, pm2): def _compute_compatible_of_two_process_meshes(pm1, pm2):
# if pm1 is None: if pm1 is None:
# return True, pm2 return True, pm2
# if pm2 is None: if pm2 is None:
# return True, pm1 return True, pm1
# if pm1 == pm2: if pm1 == pm2:
# return True, pm1 return True, pm1
# if pm1.device_mesh != pm2.device_mesh: if pm1.process_ids == pm2.process_ids:
# return False, None if len(pm1.shape) >= len(pm2.shape):
# if pm1.process_ids == pm2.process_ids: return True, pm1
# if len(pm1.shape) >= len(pm2.shape): else:
# return True, pm1 return True, pm2
# else: process_set1 = set(pm1.process_ids)
# return True, pm2 process_set2 = set(pm2.process_ids)
# process_set1 = set(pm1.process_ids) if process_set1.issubset(process_set2):
# process_set2 = set(pm2.process_ids) return True, pm2
# if process_set1.issubset(process_set2): if process_set2.issubset(process_set1):
# return True, pm2 return True, pm1
# if process_set2.issubset(process_set1): return False, None
# return True, pm1
# return False, None compatible_result = None
for process_mesh in process_meshes:
# compatible_result = None compatible, compatible_result = _compute_compatible_of_two_process_meshes(
# for process_mesh in process_meshes: compatible_result, process_mesh)
# compatible, compatible_result = _compute_compatible_two_process_meshes( if not compatible:
# compatible_result, process_mesh) return None
# if not compatible: if compatible_result.empty():
# return None return None
# return ProcessMesh(compatible_result.mesh, compatible_result.dim_names) if isinstance(compatible_result, core.ProcessMesh):
mesh = np.array(compatible_result.process_ids).reshape(
# def merge_process_meshes(process_meshes): compatible_result.shape)
# """Merge a list of process meshes.""" return ProcessMesh(mesh, compatible_result.dim_names)
# merged_process_mesh = None elif isinstance(compatible_result, ProcessMesh):
# merged_process_ids = set() return ProcessMesh(compatible_result.mesh, compatible_result.dim_names)
# device_type = "" else:
# for process_mesh in process_meshes: raise ValueError("Unrecognized ProcessMesh.")
# if process_mesh is not None:
# process_ids = set(process_mesh.process_ids)
# if not device_type: def merge_process_mesh(process_meshes):
# device_type = process_mesh.device_type """Merge a list of process meshes."""
# assert device_type != process_mesh.device_type, \ merged_process_mesh = None
# "All process meshes must have the same device_type." merged_process_ids = set()
# merged_process_ids.union(process_ids) for process_mesh in process_meshes:
# if len(merged_process_ids) != 0: if process_mesh is not None:
# merged_process_mesh = ProcessMesh(list(merged_process_ids)) process_ids = set(process_mesh.process_ids)
# return merged_process_mesh merged_process_ids = merged_process_ids.union(process_ids)
if len(merged_process_ids) != 0:
merged_process_mesh = ProcessMesh(list(merged_process_ids))
return merged_process_mesh
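To illustrate the two helpers (a sketch, run in the context of this module; the meshes are built from plain process-id lists): compatibility prefers the higher-rank mesh when one process set contains the other, while merging just unions the process ids into a flat 1-D mesh.

pm1 = ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
pm2 = ProcessMesh([0, 1, 2, 3])
# Same process set, so the mesh with more dimensions wins.
assert compute_compatible_process_mesh([pm1, pm2]).shape == [2, 2]
# Merging unions the ids of all meshes into a flat mesh.
assert set(merge_process_mesh([pm1, pm2]).process_ids) == {0, 1, 2, 3}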
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import os
import copy
import argparse
import yaml
from . import constants
class BaseConfig(object):
def __init__(self, category, config_dict=None):
self._category = category
self._config_dict = None
if config_dict is not None:
if isinstance(config_dict, dict):
self._config_dict = config_dict
else:
raise ValueError(
"Expected a dictionary. But received: {}".format(
config_dict))
# Initialize attributes by the default config
config = constants.get_category_default_config(self._category)
for field, default_value in config.items():
setattr(self, field, default_value)
# Override attributes by the config_dict
if self._config_dict:
self.from_dict(self._config_dict)
def from_dict(self, config_dict):
config = constants.get_category_default_config(self._category)
for field in config.keys():
value = config_dict.get(field, constants.NOT_FOUND)
# Keep the default value if the field is not found in config_dict
if value != constants.NOT_FOUND:
setattr(self, field, value)
def to_dict(self):
result_dict = {}
config = constants.get_category_default_config(self._category)
for field in config.keys():
value = getattr(self, field)
result_dict[field] = value
for field, value in self.__dict__.items():
if isinstance(value, BaseConfig):
result_dict[field] = value.to_dict()
return result_dict
def __repr__(self):
return yaml.dump(self.to_dict(),
default_flow_style=False,
sort_keys=True,
indent=4)
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
setattr(result, k, copy.deepcopy(v, memo))
return result
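The override mechanics above in toy form (plain Python; the field names are hypothetical): only keys that exist in the category's default config are honored, and unknown keys in the user dict are silently ignored.

defaults = {"enable": False, "k_steps": 1}

def apply_overrides(fields, config_dict):
    # Mirror of from_dict: copy a value only for known fields.
    for field in defaults:
        if field in config_dict:
            fields[field] = config_dict[field]
    return fields

fields = apply_overrides(dict(defaults), {"enable": True, "typo_key": 3})
assert fields == {"enable": True, "k_steps": 1}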
class RecomputeConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.RECOMPUTE
super(RecomputeConfig, self).__init__(category, config_dict)
class AMPConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.AMP
super(AMPConfig, self).__init__(category, config_dict)
class ShardingConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.SHARDING
super(ShardingConfig, self).__init__(category, config_dict)
class GradientMergeConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.GRADIENT_MERGE
super(GradientMergeConfig, self).__init__(category, config_dict)
class QATConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.QAT
super(QATConfig, self).__init__(category, config_dict)
class TuningConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.TUNING
super(TuningConfig, self).__init__(category, config_dict)
class Strategy(BaseConfig):
"""
The `Strategy` object is used to configure the parallelization and optimization behaviors.
Args:
config (dict|string, optional): If this is None, the default configurations will be used.
If this is a dictionary, the recognized key-value of it will be used to override the default
configurations while other default configurations are left unchanged. If this is a string,
it is interpreted as the path to a YAML configuration and will be loaded to override the
corresponding default configurations.
Examples:
.. code-block:: python
import paddle
import paddle.distributed.auto_parallel as auto
strategy = auto.Strategy()
sharding = strategy.sharding
assert sharding.enable == False
assert sharding.stage == 1
assert sharding.degree == 8
sharding.enable = True
sharding.stage = 2
sharding.degree = 2
assert sharding.enable == True
assert sharding.stage == 2
assert sharding.degree == 2
"""
def __init__(self, config=None):
if config is not None:
if isinstance(config, dict):
self._config_dict = copy.deepcopy(config)
# elif os.path.exists(config):
# with open(config, "rb") as yaml_file:
# self._config_dict = yaml.load(yaml_file, Loader=yaml.Loader)
else:
raise ValueError(
"Expected a dictionary. But received: {}".format(config))
else:
self._config_dict = {}
category = constants.BASE
super(Strategy, self).__init__(category, self._config_dict)
config_dict = self._config_dict.get(constants.RECOMPUTE, None)
self.recompute = RecomputeConfig(config_dict)
config_dict = self._config_dict.get(constants.AMP, None)
self.amp = AMPConfig(config_dict)
config_dict = self._config_dict.get(constants.SHARDING, None)
self.sharding = ShardingConfig(config_dict)
config_dict = self._config_dict.get(constants.GRADIENT_MERGE, None)
self.gradient_merge = GradientMergeConfig(config_dict)
config_dict = self._config_dict.get(constants.QAT, None)
self.qat = QATConfig(config_dict)
config_dict = self._config_dict.get(constants.TUNING, None)
self.tuning = TuningConfig(config_dict)
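A short usage sketch (assuming the nested dict keys match the attribute names above): recognized keys override the category defaults, and everything else keeps its default value.

import paddle.distributed.auto_parallel as auto

strategy = auto.Strategy({
    "gradient_merge": {"enable": True, "k_steps": 4},
    "recompute": {"enable": True},
})
assert strategy.gradient_merge.enable
assert strategy.gradient_merge.k_steps == 4
assert not strategy.sharding.enable  # untouched default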
...@@ -16,7 +16,7 @@ import copy ...@@ -16,7 +16,7 @@ import copy
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import logging import logging
from paddle.distributed.utils import get_logger from ..utils import get_logger
from .trial import TrialStatus from .trial import TrialStatus
from .trial import OptimizationTunerTrial as Trial from .trial import OptimizationTunerTrial as Trial
...@@ -110,13 +110,13 @@ class ShardingStageAlgorithm(AlgorithmBase): ...@@ -110,13 +110,13 @@ class ShardingStageAlgorithm(AlgorithmBase):
# TODO import trial class & copy strategy # TODO import trial class & copy strategy
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self._changed_configs = ["sharding_configs"] self._changed_configs = ["sharding"]
def _init_spaces(self): def _init_spaces(self):
self._max_stage = 3 self._max_stage = 3
self._trial_idx = 0 self._trial_idx = 0
stage_range = self._config.sharding_configs.get("stage_range", None) stage_range = self._config.sharding.to_dict().get("tuning_range", None)
if stage_range: if stage_range:
assert set(stage_range).issubset( assert set(stage_range).issubset(
set([0, 1, 2, 3]) set([0, 1, 2, 3])
...@@ -136,9 +136,8 @@ class ShardingStageAlgorithm(AlgorithmBase): ...@@ -136,9 +136,8 @@ class ShardingStageAlgorithm(AlgorithmBase):
stage = self._stage_range[self._trial_idx] stage = self._stage_range[self._trial_idx]
new_strategy = copy.deepcopy(self._config.dist_strategy) new_strategy = copy.deepcopy(self._config.dist_strategy)
config_dict = new_strategy.sharding_configs sharding = new_strategy.sharding
config_dict["stage"] = stage sharding.stage = stage
new_strategy.sharding_configs = config_dict
name = "trial-sharding-stage{}".format(stage) name = "trial-sharding-stage{}".format(stage)
trial = Trial(new_strategy, name, self.changed_configs) trial = Trial(new_strategy, name, self.changed_configs)
......
...@@ -17,15 +17,13 @@ import copy ...@@ -17,15 +17,13 @@ import copy
import pathlib import pathlib
import paddle import paddle
from paddle.distributed import fleet from ..strategy import Strategy
_tuning_supported_passes = ["sharding", "recompute"] _tuning_supported_passes = ["sharding", "recompute"]
_strategy_config_suffiex = "_configs"
def _get_pass_config(strategy, pass_name): def _get_pass_config(strategy, pass_name):
config_name = pass_name + _strategy_config_suffiex config = getattr(strategy, pass_name)
config = getattr(strategy, config_name)
return config return config
...@@ -38,10 +36,8 @@ class TuningConfig(object): ...@@ -38,10 +36,8 @@ class TuningConfig(object):
def __init__(self, user_config, strategy): def __init__(self, user_config, strategy):
if not isinstance(strategy, fleet.DistributedStrategy): if not isinstance(strategy, Strategy):
raise TypeError( raise TypeError("'strategy' must be object of class `Strategy`.")
"'strategy' must be object of class `fleet.DistributedStrategy`."
)
if not user_config: if not user_config:
user_config = {} user_config = {}
...@@ -116,11 +112,11 @@ class TuningConfig(object): ...@@ -116,11 +112,11 @@ class TuningConfig(object):
for p in _tuning_supported_passes: for p in _tuning_supported_passes:
if getattr(self._dist_strategy, p) and _get_pass_config( if getattr(self._dist_strategy, p) and _get_pass_config(
self._dist_strategy, p)["enable_tuning"]: self._dist_strategy, p).enable_tuning:
# TODO distinguish different args of each passes # TODO distinguish different args of each passes
self._tuning_passes_name.add(p) self._tuning_passes_name.add(p)
config_name = p + _strategy_config_suffiex config_name = p
p_dict = getattr(self._dist_strategy, config_name) p_dict = getattr(self._dist_strategy, config_name)
self.__dict__[config_name] = p_dict self.__dict__[config_name] = p_dict
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# import yaml
import os import os
import sys import sys
import copy import copy
...@@ -29,7 +30,6 @@ import paddle ...@@ -29,7 +30,6 @@ import paddle
from paddle.fluid import program_guard from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward from paddle.fluid.backward import append_backward
from paddle.distributed.passes import new_pass, PassContext from paddle.distributed.passes import new_pass, PassContext
from paddle.distributed.utils import get_logger
from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.completion import Completer
...@@ -39,6 +39,7 @@ from paddle.distributed.auto_parallel.process_group import clear_all_process_gro ...@@ -39,6 +39,7 @@ from paddle.distributed.auto_parallel.process_group import clear_all_process_gro
from paddle.distributed.auto_parallel.utils import debug_program from paddle.distributed.auto_parallel.utils import debug_program
from paddle.distributed.auto_parallel.utils import make_data_unshard, set_grad_var_shape from paddle.distributed.auto_parallel.utils import make_data_unshard, set_grad_var_shape
from ..utils import get_logger
from .config import TuningConfig from .config import TuningConfig
from .algorithms import new_algorithm from .algorithms import new_algorithm
from .trial import TrialStatus from .trial import TrialStatus
@@ -256,8 +257,8 @@ class OptimizationTuner:
startup_program = dist_context.serial_startup_program
# applying optimization pass
- if new_strategy.amp:
+ if new_strategy.amp.enable:
- config = copy.deepcopy(new_strategy.amp_configs)
+ config = copy.deepcopy(new_strategy.amp.to_dict())
config["dist_context"] = dist_context
config["params_grads"] = dist_context._params_grads
@@ -275,8 +276,8 @@
auto_parallel_amp_pass.apply([main_program], [startup_program],
pass_context)
- if new_strategy.recompute:
+ if new_strategy.recompute.enable:
- config = copy.deepcopy(new_strategy.recompute_configs)
+ config = copy.deepcopy(new_strategy.recompute.to_dict())
config["dist_context"] = dist_context
config["no_grad_set"] = None
config["loss"] = dist_context.serial_loss
@@ -303,8 +304,8 @@
dist_context, dist_params_grads)
resharder.reshard()
- if new_strategy.sharding:
+ if new_strategy.sharding.enable:
- config = copy.deepcopy(new_strategy.sharding_configs)
+ config = copy.deepcopy(new_strategy.sharding.to_dict())
config["dist_context"] = dist_context
config["params_grads"] = dist_params_grads
config["global_rank"] = self.rank
@@ -313,8 +314,8 @@
auto_parallel_sharding_pass.apply([dist_main_prog],
[dist_startup_prog], pass_context)
- if new_strategy.gradient_merge:
+ if new_strategy.gradient_merge.enable:
- config = copy.deepcopy(new_strategy.gradient_merge_configs)
+ config = copy.deepcopy(new_strategy.gradient_merge.to_dict())
config["dist_context"] = dist_context
config["params_grads"] = dist_params_grads
auto_parallel_gradient_merge_pass = new_pass(
@@ -492,9 +493,10 @@ The best trial is: [{}], whose configuration is following:
for line in summary_.split("\n"):
fw.write(line + "\n")
- full_strategy = self.get_best_config()
- full_strategy.save_to_prototxt(
-     os.path.join(self.project_dir, "tuned_dist_strategy.prototxt"))
+ # full_strategy = self.get_best_config()
+ # path = os.path.join(self.project_dir, "tuned_dist_strategy.yaml")
+ # with open(path, 'w') as outfile:
+ #     yaml.dump(full_strategy, outfile, default_flow_style=False)
def clear(self):
"""
...
@@ -156,9 +156,10 @@ class OptimizationTunerTrial(Trial):
draws += h1_format.format("{} auto=True <-> {}".format(name, name))
draws += line + "\n"
my_configs = getattr(self.space, name)
- keys = my_configs.keys()
+ keys = my_configs.to_dict().keys()
for key in keys:
- draws += h2_format.format(key, str(my_configs.get(key, None)))
+ draws += h2_format.format(
+     key, str(my_configs.to_dict().get(key, None)))
result_res = draws + border
return result_res
...
@@ -28,6 +28,19 @@ from paddle.fluid.io import is_parameter, is_belong_to_optimizer
from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute
def get_logger(log_level, name="auto_parallel"):
logger = logging.getLogger(name)
logger.propagate = False
if not logger.handlers:
logger.setLevel(log_level)
log_handler = logging.StreamHandler()
log_format = logging.Formatter(
'%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')
log_handler.setFormatter(log_format)
logger.addHandler(log_handler)
return logger
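# A minimal usage sketch for the helper above (illustrative only; assumes
# the standard logging module is available in this file):
#
#     logger = get_logger(logging.INFO)
#     logger.info("auto parallel tuning started")
#     # emits a line shaped by the formatter above, roughly:
#     # INFO <time> <file>:<line>] auto parallel tuning started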
def is_valid_list_index(list, index):
if index >= -len(list) and index < len(list):
return True
@@ -49,6 +62,58 @@ def is_dim_replicate(mapping):
return False
def verify_dims_mapping(dims_mapping, process_mesh):
if dims_mapping is None:
return False
if not all(isinstance(d, int) for d in dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(process_mesh.shape):
return False
for i in range(len(process_mesh.shape)):
if dims_mapping.count(i) > 1:
return False
return True
def convert_to_dims_mapping(shard_spec, process_mesh):
dims_mapping = []
for shard in shard_spec:
if shard is None:
dims_mapping.append(-1)
else:
dims_mapping.append(process_mesh.dim_names.index(shard))
return dims_mapping
def convert_to_shard_spec(dims_mapping, process_mesh):
shard_spec = []
for dim_mapping in dims_mapping:
if dim_mapping == -1:
shard_spec.append(None)
else:
shard_spec.append(process_mesh.dim_names[dim_mapping])
return shard_spec
def verify_shard_spec(shard_spec, tensor_shape, process_mesh):
if len(shard_spec) != len(tensor_shape):
return False
for shard in shard_spec:
if shard is not None and not isinstance(shard, str):
return False
if shard is not None and shard not in process_mesh.dim_names:
return False
dims_mapping = convert_to_dims_mapping(shard_spec, process_mesh)
if not verify_dims_mapping(dims_mapping, process_mesh):
return False
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and tensor_shape[i] % process_mesh.shape[dims_mapping[i]] != 0:
return False
return True
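# An illustrative round trip between shard_spec and dims_mapping using the
# helpers above (assumed 2 x 2 mesh with dim_names ["x", "y"]):
#
#     mesh = ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
#     convert_to_dims_mapping(["x", None], mesh)    # -> [0, -1]
#     convert_to_shard_spec([0, -1], mesh)          # -> ["x", None]
#     verify_shard_spec(["x", None], [8, 4], mesh)  # -> True: dim 0 (size 8)
#     # is divisible by the size (2) of mesh dim "x"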
def compute_compatible_dim_mapping(dim_mappings):
if not dim_mappings:
return None
@@ -1040,7 +1105,7 @@ def set_grad_var_shape(program, dist_context):
if op.type in [
"c_allreduce_sum", "c_identity", "scale", "cast",
- 'fill_any_like'
+ "fill_any_like"
]:
forward_var_name = op.input_arg_names[0]
elif op.type == "matmul_v2_grad" or op.type == "matmul_grad" or op.type == "mul_grad":
@@ -1504,3 +1569,15 @@ def ring_id_to_process_group(ring_id):
if g.id == ring_id:
return g
return None
def find_higher_order_backward_op(program):
higher_order_op_suffix = ['_grad_grad', 'triple_grad']
for block in program.blocks:
for op in block.ops:
for suffix in higher_order_op_suffix:
if suffix in op.type:
return True
return False
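# Illustrative: a program containing an op of type "matmul_v2_grad_grad"
# or "tanh_triple_grad" would make the function above return True.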
@@ -314,7 +314,9 @@ class AMPState(object):
consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr)
else:
- assert in_var.dtype == dst_dtype
+ assert in_var.dtype == dst_dtype, "op [{}] expects input [{}] to be of dtype [{}] but got [{}]. {}".format(
+     grad_op.type, in_name, dst_dtype, in_var.dtype,
+     str(grad_op))
for out_name in grad_op.output_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
...
@@ -13,12 +13,14 @@
# limitations under the License.
from collections import OrderedDict
+ import numpy as np
import paddle
+ from paddle.fluid import core, unique_name
from paddle.fluid.framework import default_main_program
- from paddle.distributed.fleet.meta_optimizers.common import OpRole
+ from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op
- from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, ring_id_to_process_group
+ from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, is_backward_op, ring_id_to_process_group, find_higher_order_backward_op
from .pass_base import PassBase, PassType, register_pass
# add new optimizers supporting rescale_grad here
@@ -31,6 +33,10 @@ __rescale_grad_supported_opts__ = [
__max_stream_num_allow__ = 16
def numel(var):
return np.prod(list(var.shape))
@register_pass("auto_parallel_data_parallel_optimization") @register_pass("auto_parallel_data_parallel_optimization")
class DataParallelOptimizationPass(PassBase): class DataParallelOptimizationPass(PassBase):
""" """
...@@ -78,7 +84,9 @@ class DataParallelOptimizationPass(PassBase): ...@@ -78,7 +84,9 @@ class DataParallelOptimizationPass(PassBase):
self._analyze_program() self._analyze_program()
self._prune_grad_scaling() self._prune_grad_scaling()
self._calc_comm_overlap() self._calc_comm_overlap()
self._fuse_allreduce() grad_group = self._fuse_allreduce()
# self.summary(grad_group)
def _prune_grad_scaling(self):
@@ -99,7 +107,14 @@
self._calc_wait_comms()
def _fuse_allreduce(self):
- pass
+ if not self._could_be_fuse():
+     return []
+ grad_group = self._group_grads()
+ self._update_program(grad_group)
+ return grad_group
def _analyze_program(self):
"""
@@ -154,7 +169,7 @@
def _could_be_prune(self):
- return self.dist_context._gradient_scale and (
+ return self.dist_context.gradient_scale and (
self._support_rescale_grad or self._all_dp_groups_same_degree())
def _all_dp_groups_same_degree(self):
@@ -316,3 +331,252 @@
'op_role': OpRole.Backward,
'ring_id': ring_id
})
def _could_be_fuse(self):
# TODO: support gradient fusion for higher-order gradients;
# this requires analysing the gradient dependencies in the backward pass.
if find_higher_order_backward_op(default_main_program()):
return False
if self.use_sharding:
return False
return True
def _group_grads(self):
"""
conditions for gradients to be grouped:
1. group size < max_fuse_numel
2. same dp group
3. same dtype
4. dependency: grad would NOT be used by other ops within group segment
gradients inside the same group will be fused into one coalesced tensor
"""
block = default_main_program().global_block()
ops = block.ops
# group individual grad vars
# TODO: consider fusing gradients for sharding reduce
# TODO: let the user set fuse_grad_size
# emb = 50000 * h, ffn = 8 * h * h, mha = 4 * h * h
h = 2048
ffn_numel = 2 * (4 * h) * h
mha_numel = 3 * h * h + h * h
max_fuse_numel = ffn_numel + mha_numel
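# Worked numbers for the assumed h = 2048 above (illustrative only):
#   ffn_numel = 2 * (4 * 2048) * 2048         = 33,554,432
#   mha_numel = 3 * 2048 * 2048 + 2048 * 2048 = 16,777,216
#   max_fuse_numel                            = 50,331,648 elements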
grad_groups = []
cur_group = GradientsGroup(ops, max_fuse_numel)
grouped_grad_names = set()
def collect_group(cur_group, grad_var, ring_id, i):
if len(cur_group.gradients) == 0:
cur_group = None
elif len(cur_group.gradients) == 1:
grouped_grad_names.remove(cur_group.gradients[0].name)
else:
cur_group.finalize()
grad_groups.append(cur_group)
new_group = GradientsGroup(ops, max_fuse_numel)
if grad_var:
new_group.add(grad_var, ring_id, i)
grouped_grad_names.add(grad_var.name)
return new_group
def op_depend_on_group(op, group):
vars_ = set(op.input_arg_names + op.output_arg_names)
grad_names = set([grad.name for grad in group.gradients])
return len(vars_.intersection(grad_names)) > 0
for i, op in enumerate(ops):
if is_data_parallel_reduce_op(op):
ring_id = op.attr("ring_id")
grad_name = op.output_arg_names[0]
grad_var = block.var(grad_name)
grad_numel = numel(grad_var)
if cur_group.acceptable(grad_var, ring_id):
assert grad_name not in grouped_grad_names
grouped_grad_names.add(grad_name)
cur_group.add(grad_var, ring_id, i)
else:
cur_group = collect_group(cur_group, grad_var, ring_id, i)
else:
if op_depend_on_group(op, cur_group):
cur_group = collect_group(cur_group, None, None, None)
# collect last group
collect_group(cur_group, None, None, None)
return grad_groups
def _update_program(self, grad_groups):
block = default_main_program().global_block()
remove_op_types = ['scale', 'c_allreduce_sum', 'c_wait_compute']
for i, group in enumerate(grad_groups[::-1]):
# create the coalesce tensor
group.coalesce_var = block.create_var(name=unique_name.generate(
'coalecse_grad_{}'.format(i)),
dtype=group.dtype,
persistable=False,
stop_gradient=True)
# update allreduce & scale op
if group.scale_op_idx != -1:
scale_op = block.ops[group.scale_op_idx]
assert scale_op.type == 'scale', "expected a scale op but found {}".format(
    str(scale_op))
scale_op._rename_input(scale_op.input_arg_names[0],
group.coalesce_var.name)
scale_op._rename_output(scale_op.output_arg_names[0],
group.coalesce_var.name)
allreduce_op = block.ops[group.allreduce_op_idx]
assert allreduce_op.type == 'c_allreduce_sum', "expected a c_allreduce_sum op but found {}".format(
    str(allreduce_op))
allreduce_op._rename_input(allreduce_op.input_arg_names[0],
group.coalesce_var.name)
allreduce_op._rename_output(allreduce_op.output_arg_names[0],
group.coalesce_var.name)
# remove unused ops
remove_op_indices = group.remove_wait_op_indices + group.remove_allreduce_op_indices + group.remove_scale_op_indices
for idx in sorted(remove_op_indices, reverse=True):
assert block.ops[idx].type in remove_op_types, \
    "Unexpected: trying to remove op {}".format(str(block.ops[idx]))
block._remove_op(idx)
# insert the coalesce op
concated_shapes = []
concated_ranks = []
for grad_ in group.gradients:
shape = grad_.shape
concated_shapes.extend(shape)
concated_ranks.append(len(shape))
grad_names = [grad.name for grad in group.gradients]
block._insert_op_without_sync(group.coalesce_op_idx,
type="coalesce_tensor",
inputs={"Input": grad_names},
outputs={
"Output": grad_names,
"FusedOutput": group.coalesce_var
},
attrs={
"copy_data": False,
"use_align": True,
"dtype": group.dtype,
"concated_shapes":
concated_shapes,
"concated_ranks": concated_ranks,
OP_ROLE_KEY: OpRole.Backward
})
block._sync_with_cpp()
# TODO update dist attr
def summary(self, grad_groups=[]):
# TODO: add logger module
import logging
self._logger = logging.getLogger()
self._logger.propagate = False
if not self._logger.handlers:
self._logger.setLevel(logging.INFO)
log_handler = logging.StreamHandler()
log_format = logging.Formatter(
'[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
)
log_handler.setFormatter(log_format)
self._logger.addHandler(log_handler)
if len(grad_groups) > 0:
self._logger.info(
"origin {} allreduce ops are fused into {} coalecse allreduce ops."
.format(len(self._grad_name_to_group_map.keys()),
len(grad_groups)))
self._logger.info("gradient fusing group are following: ")
fused_grads = set()
for i, group in enumerate(grad_groups):
self._logger.info(
"coalecse gradient [{}] is composed by: {}".format(
i, [grad.name for grad in group.gradients]))
fused_grads.update([grad.name for grad in group.gradients])
individual_grads = set(
self._grad_name_to_group_map.keys()) - set(fused_grads)
self._logger.info(
"the following [{}] gradients are not fused: ".format(
len(individual_grads)))
self._logger.info("individual gradient {}".format(individual_grads))
class GradientsGroup(object):
def __init__(self, ops, max_group_size):
self.max_group_size = max_group_size
self.ops = ops
self.gradients = []
self.numel = 0
self.dtype = None
self.ring_id = None
self.coalesce_var = None
self.coalesce_op_idx = -1
self.allreduce_op_idx = -1
self.scale_op_idx = -1
self.remove_wait_op_indices = []
self.remove_allreduce_op_indices = []
self.remove_scale_op_indices = []
def acceptable(self, grad_var, ring_id):
if len(self.gradients) == 0:
return True
if ring_id != self.ring_id:
return False
if numel(grad_var) + self.numel > self.max_group_size:
return False
if grad_var.dtype != self.dtype:
return False
return True
def add(self, grad_var, ring_id, i):
self.gradients.append(grad_var)
self.ring_id = ring_id
self.dtype = grad_var.dtype
self.numel += numel(grad_var)
# remove auxiliary ops in non-fuse dp allreduce
self.remove_allreduce_op_indices.append(i)
# NOTE: this pass relies on the synchronization added by previous passes
# (same stream, or calc_wait_comm & comm_wait_calc)
# to guarantee the correctness of the comm-calc execution order,
# so the calc_wait_comm ops must be kept.
grad_op_idx = i - 1
if i > 0 and self.ops[i - 1].type == 'c_wait_compute':
self.remove_wait_op_indices.append(i - 1)
grad_op_idx -= 1
if i + 1 < len(self.ops) and is_data_parallel_scale_op(self.ops[i + 1]):
self.remove_scale_op_indices.append(i + 1)
if len(self.gradients) == 1:
# TODO: remove this; it is a temporary hack for Tensor Parallel. The logic
# for finding grad_op should be more general.
if self.ops[grad_op_idx].type == "c_allreduce_sum":
grad_op_idx -= 1
grad_op = self.ops[grad_op_idx]
assert grad_var.name in grad_op.output_arg_names, "grad [{}] should be output of {}".format(
grad_var.name, str(grad_op))
self.coalesce_op_idx = grad_op_idx
def finalize(self):
self.allreduce_op_idx = self.remove_allreduce_op_indices.pop()
if len(self.remove_wait_op_indices) > 1:
self.remove_wait_op_indices.pop()
if len(self.remove_scale_op_indices) > 1:
self.scale_op_idx = self.remove_scale_op_indices.pop()
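# A minimal sketch of how the pass drives GradientsGroup (assumed grads g0
# and g1 allreduced back to back on ring 0 at op indices i0 < i1):
#
#     group = GradientsGroup(ops, max_fuse_numel)
#     group.add(g0, 0, i0)          # marks g0's allreduce for removal
#     if group.acceptable(g1, 0):   # same ring, same dtype, within budget
#         group.add(g1, 0, i1)
#     group.finalize()              # keeps the last allreduce as the fused one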
@@ -16,6 +16,7 @@ from collections import defaultdict
import paddle
from paddle.framework import core
+ from paddle.fluid.framework import default_main_program, default_startup_program
from paddle.fluid import unique_name
from .pass_base import register_pass
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
@@ -379,6 +380,10 @@ class FP16State(object):
# create cast grad
grad_slot_name = slot_name + "@GRAD"
assert grad_slot_name in op.output_names
+ if len(op.output(grad_slot_name)) == 0:
+     var = block.var(src_name)
+     assert var.stop_gradient is True
+     continue
assert len(op.output(grad_slot_name)) == 1
grad_name = op.output(grad_slot_name)[0]
grad = block.var(grad_name)
@@ -442,7 +447,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
inputs = {'X': grads, 'Scale': loss_scaling}
outputs = {'Out': grads, 'FoundInfinite': found_inf}
- attrs = {'op_role': OpRole.Backward}
+ attrs = {'op_role': OpRole.Optimize}
new_op = main_block.append_op(type='check_finite_and_unscale',
inputs=inputs,
outputs=outputs,
@@ -536,6 +541,39 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
return output_var
def cast_startup_program():
main_program = default_main_program()
startup_program = default_startup_program()
param_to_dtype = {}
for block in main_program.blocks:
for p in block.all_parameters():
param_to_dtype[p.name] = p.dtype
def is_initialization_op(op):
comm_op_prefix = "c_"
op_type = op.type
if op_type.startswith(comm_op_prefix):
return False
if len(op.output_arg_names) != 1 or len(op.input_arg_names) != 0:
return False
return True
for op in startup_program.global_block().ops:
if is_initialization_op(op):
output_name = op.output_arg_names[0]
if param_to_dtype.get(output_name,
None) == core.VarDesc.VarType.FP16:
assert op.has_attr(
'dtype'
), "initialization op is supported to has dtype attribute but got {}.".format(
str(op))
if op.attr('dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16)
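# Illustrative effect (assumed): a fill_constant op in the startup program
# that initializes a parameter cast to FP16 in the main program, but whose
# dtype attribute is still FP32, gets that attribute rewritten to FP16 so
# the initial value matches the parameter's dtype.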
@register_pass("auto_parallel_fp16") @register_pass("auto_parallel_fp16")
class FP16Pass(AMPPass): class FP16Pass(AMPPass):
...@@ -563,6 +601,8 @@ class FP16Pass(AMPPass): ...@@ -563,6 +601,8 @@ class FP16Pass(AMPPass):
input_data_var_names) input_data_var_names)
is_train = fp16_state._build_state() is_train = fp16_state._build_state()
cast_startup_program()
if is_train: if is_train:
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
# TODO (JZ-LIANG)support cast forward program only when inference # TODO (JZ-LIANG)support cast forward program only when inference
...@@ -575,18 +615,18 @@ class FP16Pass(AMPPass): ...@@ -575,18 +615,18 @@ class FP16Pass(AMPPass):
) or self.get_attr("init_loss_scaling") != 1.0: ) or self.get_attr("init_loss_scaling") != 1.0:
found_infs = [] found_infs = []
if fp32_grads: if fp32_grads:
with main_program._backward_role_guard(): with main_program._optimized_guard([]):
_, found_inf_fp32 = _check_and_update_gradient( _, found_inf_fp32 = _check_and_update_gradient(
fp32_grads, self._loss_scaling, "@fp32", fp32_grads, self._loss_scaling, "@fp32",
self.dist_context) self.dist_context)
found_infs.append(found_inf_fp32) found_infs.append(found_inf_fp32)
if fp16_grads: if fp16_grads:
with main_program._backward_role_guard(): with main_program._optimized_guard([]):
_, found_inf_fp16 = _check_and_update_gradient( _, found_inf_fp16 = _check_and_update_gradient(
fp16_grads, self._loss_scaling, "@fp16", fp16_grads, self._loss_scaling, "@fp16",
self.dist_context) self.dist_context)
found_infs.append(found_inf_fp16) found_infs.append(found_inf_fp16)
with main_program._backward_role_guard(): with main_program._optimized_guard([]):
block = main_program.global_block() block = main_program.global_block()
all_infs = paddle.fluid.layers.concat(found_infs) all_infs = paddle.fluid.layers.concat(found_infs)
...@@ -608,7 +648,7 @@ class FP16Pass(AMPPass): ...@@ -608,7 +648,7 @@ class FP16Pass(AMPPass):
block, self.dist_context) block, self.dist_context)
if self.get_attr("use_dynamic_loss_scaling"): if self.get_attr("use_dynamic_loss_scaling"):
with main_program._backward_role_guard(): with main_program._optimized_guard([]):
if fp32_grads: if fp32_grads:
self._update_loss_scaling(fp32_grads, found_inf) self._update_loss_scaling(fp32_grads, found_inf)
if fp16_grads: if fp16_grads:
......
@@ -207,6 +207,7 @@ class ClipGradByGloblNormPass(PassBase):
super(ClipGradByGloblNormPass, self).__init__()
self.set_attr("rank_id", None)
self.set_attr("dist_context", None)
+ self.set_attr("params_grads", None)
def _check_self(self):
if self.get_attr("dist_context") is None:
@@ -214,6 +215,8 @@ class ClipGradByGloblNormPass(PassBase):
dist_context = self.get_attr("dist_context")
if dist_context._lr_optimizer._grad_clip is None:
return False
+ if self.get_attr("params_grads") is None:
+     return False
return True
def _check_conflict(self, other_pass):
@@ -223,7 +226,8 @@ class ClipGradByGloblNormPass(PassBase):
dist_context = self.get_attr("dist_context", None)
rank_id = self.get_attr("rank_id", None)
block = main_program.global_block()
- dist_params_grads = _get_params_grads(block)
+ dist_params_grads = self.get_attr("params_grads", None)
+ # dist_params_grads = _get_params_grads(block)
self.clip_helper = ClipHelper(dist_params_grads, rank_id, block,
dist_context)
...
@@ -55,13 +55,6 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
return optimize_ops_desc
- def _remove_op_role_var(param, grad):
-     op_maker = core.op_proto_and_checker_maker
-     op = grad.op
-     if op and op.has_attr(op_maker.kOpRoleVarAttrName()):
-         op._remove_attr(op_maker.kOpRoleVarAttrName())
def _get_gm_cond_var(main_program, k_steps, dist_context):
main_block = main_program.global_block()
# Add const var
@@ -147,8 +140,6 @@ def _append_gradient_merge_backward_op(
param.type != core.VarDesc.VarType.SELECTED_ROWS
), "SELECTED_ROWS is not supported in GradientMergeOptimizer for now"
- _remove_op_role_var(param, grad)
# {grad.name: gradient_merge_var.name} to rename opt inputs
grad_to_gradient_merge = {}
# {param: gradient_merge_var} to insert scale op and fill_constant op
...
@@ -51,7 +51,8 @@ class ShardingPass(PassBase):
super(ShardingPass, self).__init__()
self.set_attr("dist_context", None)
self.set_attr("stage", None)
- self.set_attr("sharding_degree", None)
+ self.set_attr("sharding_degree", None)  # for parallelizer
+ self.set_attr("degree", None)  # for parallelizer_v2
self.set_attr("params_grads", [])
self.set_attr("global_rank", -1)
self.dp_groups = set()
@@ -59,6 +60,7 @@ class ShardingPass(PassBase):
self.varname_to_sharding_info = {}
self.partial_sharding = False
self.outer_dp_group = None
+ self.shared_params_grads = []
def _check_self(self):
if self.get_attr("dist_context") is None:
@@ -66,8 +68,15 @@
if self.get_attr("stage") not in [1, 2, 3]:
return False
- if (not isinstance(self.get_attr("sharding_degree"),
-     int)) or self.get_attr("sharding_degree") <= 1:
-     return False
+ if self.get_attr("sharding_degree") is not None:
+     if (not isinstance(self.get_attr("sharding_degree"), int)) \
+         or self.get_attr("sharding_degree") <= 1:
+         return False
+ elif self.get_attr("degree") is not None:
+     if (not isinstance(self.get_attr("degree"), int)) \
+         or self.get_attr("degree") <= 1:
+         return False
+ else:
+     return False
if len(self.get_attr("params_grads")) <= 0:
return False
@@ -82,7 +91,8 @@
def _apply_single_impl(self, main_program, startup_program, context):
self._dist_context = self.get_attr("dist_context")
- self.sharding_world_size = int(self.get_attr("sharding_degree"))
+ self.sharding_world_size = int(
+     self.get_attr("sharding_degree") or self.get_attr("degree"))
self.stage = int(self.get_attr("stage"))
self.global_rank = int(self.get_attr("global_rank"))
params_grads = self.get_attr("params_grads")
@@ -94,6 +104,8 @@
self._shard_gradient_synchronization(main_block)
self._shard_parameter(main_block, startup_block)
+ context.set_attr("params_grads", self.shared_params_grads)
def _build_sharding_groups(self, main_block, params_grads):
self._collective_data_parallel_groups(main_block)
self._build_sharding_infos(params_grads)
@@ -148,13 +160,10 @@
self._dist_context._sharding_group = sharding_group
# TODO(JZ-LIANG) when supporting multiple dp groups in the future, group params and bind them to the corresponding dp group
- params_in_group = [p for p, g in params_grads]
- assert len(params_in_group) == len(
-     set(params_in_group)), "found duplicated param in params_grads"
sharding_info = ShardingInfo(sharding_group, self.global_rank,
- params_in_group)
+ params_grads)
self.sharding_infos.append(sharding_info)
- for param in params_in_group:
+ for param in sharding_info.params:
self.varname_to_sharding_info[param.name] = sharding_info
def _shard_optimizer(self, main_block, startup_block, params_grads,
@@ -201,6 +210,7 @@
op.desc.set_output('Out', reversed_x)
else:
if op.type == "check_finite_and_unscale":
+ op_role = op.attr('op_role')
out_name = op.output_arg_names[0]
out_var = main_block.vars[out_name]
main_block._remove_op(idx, sync=False)
@@ -212,6 +222,7 @@
"shape": out_var.shape,
"dtype": out_var.dtype,
"value": 0,
+ OP_ROLE_KEY: op_role,
})
else:
main_block._remove_op(idx, sync=False)
@@ -313,6 +324,9 @@
if varname != param_name
])
main_block._remove_op(idx, sync=False)
+ else:
+     self.shared_params_grads.append(
+         self._get_param_grad(param_name))
for idx, op in reversed(list(enumerate(startup_block.ops))):
if len(op.output_arg_names) == 1 and op.output_arg_names[
@@ -365,6 +379,13 @@
sharding_info = self.varname_to_sharding_info[param_name]
return sharding_info.is_in_local_shard(param_name)
def _get_param_grad(self, param_name):
assert param_name in self.varname_to_sharding_info
sharding_info = self.varname_to_sharding_info[param_name]
p_g = sharding_info.get_param_grad(param_name)
assert p_g is not None
return p_g
def _shard_gradient_synchronization(self, main_block):
if self.stage < 2:
@@ -705,9 +726,13 @@ def shard_parameters(params, group_size):
class ShardingInfo(object):
- def __init__(self, group, rank, params):
+ def __init__(self, group, rank, params_grads):
self.group = group
- self.params = params
+ self.params_grads = dict([(p.name, (p, g)) for p, g in params_grads])
+ assert len(self.params_grads) == len(set(
+     self.params_grads)), "found duplicated param in params_grads"
+ self.params = [p for p, _ in params_grads]
self.param_names = [p.name for p in self.params]
self.group_size = group.nranks
self.global_rank = rank
@@ -762,3 +787,11 @@
if usage > 0:
broadcast_vars.add(param)
return broadcast_vars, param_usage
def get_param_grad(self, param_name):
if not self.is_in_local_shard(param_name):
raise ValueError(
"param[{}] not in current rank.".format(param_name))
if param_name not in self.params_grads:
raise ValueError('param[{}] not in params_grads'.format(param_name))
return self.params_grads.get(param_name, None)
@@ -37,8 +37,28 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
${dist_ENVS})
set_tests_properties(test_high_order_grad
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
- py_test_modules(test_grad_clip MODULES test_grad_clip ENVS ${dist_ENVS})
- set_tests_properties(test_grad_clip PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
-     TIMEOUT 50)
+ py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS
+     ${dist_ENVS})
+ set_tests_properties(test_iterable_dataset
+     PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80)
+ py_test_modules(test_pass_grad_clip MODULES test_pass_grad_clip ENVS
+     ${dist_ENVS})
+ set_tests_properties(test_pass_grad_clip
+     PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
+ py_test_modules(test_pass_gradient_merge MODULES test_pass_gradient_merge
+     ENVS ${dist_ENVS})
+ set_tests_properties(test_pass_gradient_merge
+     PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
+ py_test_modules(test_pass_recompute MODULES test_pass_recompute ENVS
+     ${dist_ENVS})
+ set_tests_properties(test_pass_recompute
+     PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
+ py_test_modules(test_pass_sharding MODULES test_pass_sharding ENVS
+     ${dist_ENVS})
+ set_tests_properties(test_pass_sharding
+     PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
+ py_test_modules(test_pass_amp MODULES test_pass_amp ENVS ${dist_ENVS})
+ set_tests_properties(test_pass_amp PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
+     TIMEOUT 50)
py_test_modules(test_while_op_completion MODULES test_while_op_completion
@@ -70,11 +90,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_process_mesh_v2 MODULES test_process_mesh_v2)
py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2)
py_test_modules(test_lr_grad_clip MODULES test_lr_grad_clip)
- py_test_modules(test_quantization MODULES test_quantization)
py_test_modules(test_dist_matmul MODULES test_dist_matmul)
+ py_test_modules(test_process_mesh MODULES test_process_mesh)
+ py_test_modules(test_interface MODULES test_interface)
+ py_test_modules(test_strategy MODULES test_strategy)
+ py_test_modules(test_pass_quantization MODULES test_pass_quantization)
- py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS
-     ${dist_ENVS})
- set_tests_properties(test_iterable_dataset
-     PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80)
endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
import random
import numpy as np
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
def apply_pass(use_amp=False, level=None):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_amp:
amp = strategy.amp
amp.enable = True
amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
amp.custom_black_list = [
'c_softmax_with_cross_entropy', 'elementwise_div', 'reduce_sum'
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.use_pure_fp16 = level in ["o2", "o3"]
amp.use_optimizer_fp16 = level == "o3"
print("amp level: ", level)
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestAMPPass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-5
self.atol = 1e-8
self.batch_size = 1
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2021)
np.random.seed(2021)
random.seed(2021)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_amp=False, level=None):
reset_prog()
strategy = apply_pass(use_amp, level)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("mp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses, rtol=None, atol=None):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=rtol or self.rtol,
atol=atol or self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses))
def test_amp_pass(self):
# mp2 training
mp_engine = self.get_engine()
mp_losses = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(mp_losses["loss"])
# mp2 amp-o1 training
amp_o1_engine = self.get_engine(True, "o1")
amp_o1_losses = amp_o1_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
amp_o1_losses = np.array(amp_o1_losses["loss"])
# self.check_results(mp_losses, amp_o1_losses)
# mp2 amp-o2 training
amp_o2_engine = self.get_engine(True, "o2")
amp_o2_losses = amp_o2_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
amp_o2_losses = np.array(amp_o2_losses["loss"])
# self.check_results(mp_losses, amp_o2_losses)
# mp2 amp-o3 training
amp_o3_engine = self.get_engine(True, "o3")
amp_o3_losses = amp_o3_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
amp_o3_losses = np.array(amp_o3_losses["loss"])
# self.check_results(mp_losses, amp_o3_losses)
if __name__ == "__main__":
unittest.main()
@@ -32,7 +32,7 @@ import paddle.distributed.auto_parallel as auto
paddle.enable_static()
_global_parallel_strategy = None
- _global_process_mesh = None
+ _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
batch_size = 4
hidden_size = 1024
sequence_len = 512
@@ -103,11 +103,7 @@ def mlp_pretrain_forward(train_program, start_program):
shape=[batch_size, sequence_len, 1],
dtype='float32')
- auto.shard_tensor(input,
-     dist_attr={
-         "process_mesh": _global_process_mesh,
-         "dims_mappig": [-1, -1, -1]
-     })
+ auto.shard_tensor(input, _global_process_mesh, [None, None, None])
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
@@ -126,9 +122,6 @@ def mlp_pretrain_forward(train_program, start_program):
def train():
- global _global_process_mesh
- _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False
dist_strategy.pipeline = False
...
@@ -18,24 +18,21 @@ import random
import numpy as np
import paddle
- import paddle.distributed.fleet as fleet
import paddle.distributed.auto_parallel as auto
+ from paddle.fluid.dygraph.parallel import ParallelEnv
- from paddle.distributed.auto_parallel.engine import Engine
from get_gpt_model import generate_model, create_data_holder, FakeDataset
paddle.enable_static()
def apply_pass(use_sharding=False):
- strategy = fleet.DistributedStrategy()
- strategy.semi_auto = True
+ strategy = auto.Strategy()
+ strategy.auto_mode = "semi"
+ strategy.reinit = True
if use_sharding:
- strategy.sharding = True
- strategy.sharding_configs = {
-     "sharding_degree": 2,
-     "stage": 2,
- }
+ sharding = strategy.sharding
+ sharding.degree = 2
+ sharding.stage = 2
return strategy
@@ -76,34 +73,17 @@ class TestGradientClipByGlobalNorm(unittest.TestCase):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
- engine.mode = "train"
- engine._executor.run(engine.startup_program)
+ place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
+ engine._executor = paddle.static.Executor(place)
- def get_dp2_engine(self):
+ def get_engine(self, use_sharding=False):
reset_prog()
- strategy = apply_pass()
+ strategy = apply_pass(use_sharding)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("dp")
- inputs_spec, labels_spec = create_data_holder(self.batch_size)
- engine = Engine(model, inputs_spec, labels_spec, strategy=strategy)
- engine.prepare(optimizer=opt, loss=loss)
- self.init(engine)
- return engine
- def get_dp2sharding2_engine(self):
-     reset_prog()
-     strategy = apply_pass(True)
-     clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
-     opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
-     model, loss = generate_model("dp")
-     inputs_spec, labels_spec = create_data_holder(self.batch_size)
-     engine = Engine(model, inputs_spec, labels_spec, strategy=strategy)
-     engine.prepare(optimizer=opt, loss=loss)
+ engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
@@ -121,15 +101,13 @@ class TestGradientClipByGlobalNorm(unittest.TestCase):
def test_grad_clip(self):
# dp2 training
- dp_engine = self.get_dp2_engine()
+ dp_engine = self.get_engine()
- dp_engine.fit(self.dataset, batch_size=self.batch_size, use_cache=True)
+ dp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
dp_param_values = get_parameter_value(dp_engine.main_program)
# dp2sharding2 training
- sharding_engine = self.get_dp2sharding2_engine()
+ sharding_engine = self.get_engine(True)
- sharding_engine.fit(self.dataset,
-     batch_size=self.batch_size,
-     use_cache=True)
+ sharding_engine.fit(self.dataset, 3, batch_size=self.batch_size)
sharding_param_values = get_parameter_value(
sharding_engine.main_program)
...
@@ -27,10 +27,8 @@ import paddle.nn.functional as F
import paddle.utils as utils
from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader
- from paddle.static import InputSpec
- from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
- from paddle.distributed.auto_parallel.engine import Engine
from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.fluid.dataloader.collate import default_collate_fn
@@ -47,6 +45,8 @@ class_num = 10
paddle.seed(44)
+ is_fetch = True
class MyDataset(Dataset):
@@ -90,19 +90,20 @@ class MLPLayer(nn.Layer):
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
- out = auto.shard_op(self.norm, dist_attr={"process_mesh":
-     PP_MESH_0})(input)
+ out = auto.shard_op(self.norm, PP_MESH_0)(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
- out = auto.shard_op(self.linear1, dist_attr={"process_mesh":
-     PP_MESH_1})(out)
+ out = auto.shard_op(self.linear1, PP_MESH_1)(out)
out = self.dropout(out)
out = self.linear2(out)
- self.out = out
+ if is_fetch:
+     auto.fetch(out, "out")
return out
def train(fetch):
+ global is_fetch
+ is_fetch = fetch
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
@@ -113,46 +114,34 @@ def train(fetch):
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
+ metric = paddle.metric.Accuracy()
- inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
- labels_spec = InputSpec([batch_size], 'int64', 'label')
+ strategy = auto.Strategy()
+ strategy.auto_mode = "semi"
- dist_strategy = fleet.DistributedStrategy()
- dist_strategy.semi_auto = True
- fleet.init(is_collective=True, strategy=dist_strategy)
- # init engine
- engine = Engine(mlp,
-     inputs_spec=inputs_spec,
-     labels_spec=labels_spec,
-     strategy=dist_strategy)
- engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
- # fetch
- if fetch:
-     fetches = {'out': mlp.out}
- else:
-     fetches = None
+ engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy)
# train
train_dataset = MyDataset(batch_num * batch_size)
- engine.fit(train_dataset,
-     batch_size=batch_size,
-     steps_per_epoch=batch_num * batch_size,
-     fetches=fetches)
+ eval_dataset1 = MyDataset(5 * batch_size)
+ engine.fit(train_data=train_dataset,
+     epochs=2,
+     batch_size=batch_size,
+     valid_data=eval_dataset1)
# eval
- eval_dataset = MyDataset(batch_size)
- engine.evaluate(eval_dataset, batch_size, fetches=fetches)
+ eval_dataset2 = MyDataset(batch_size)
+ engine.evaluate(eval_dataset2, batch_size=batch_size)
# predict
test_dataset = MyDataset(batch_size)
- engine.predict(test_dataset, batch_size, fetches=fetches)
+ engine.predict(test_dataset, batch_size=batch_size)
# save
temp_dir = tempfile.TemporaryDirectory()
- model_filename = os.path.join(temp_dir.name, 'mlp_inf')
- engine.save(model_filename, training=False, mode='predict')
+ model_filename = os.path.join(temp_dir.name, 'mlp')
+ engine.save(model_filename, training=True)
+ engine.load(model_filename)
temp_dir.cleanup()
...
@@ -26,11 +26,9 @@ import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
from paddle.fluid import layers
- from paddle.io import Dataset, IterableDataset, DataLoader
+ from paddle.io import Dataset, DataLoader
- from paddle.static import InputSpec
- from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
- from paddle.distributed.auto_parallel.engine import Engine
paddle.enable_static()
batch_size = 2
@@ -91,6 +89,7 @@ class MLPLayer(nn.Layer):
out = self.linear1(out)
out = self.dropout(out)
out = self.linear2(out)
+ auto.fetch(out, "out")
self.out = out
return out
@@ -107,46 +106,32 @@ def train(fetch):
epsilon=1e-08,
grad_clip=None)
- inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
- labels_spec = InputSpec([batch_size], 'int64', 'label')
+ dist_strategy = auto.Strategy()
+ dist_strategy.auto_mode = "semi"
- dist_strategy = fleet.DistributedStrategy()
- dist_strategy.amp = False
- dist_strategy.pipeline = False
- dist_strategy.recompute = False
- # init parallel optimizer
- dist_strategy.semi_auto = True
- fleet.init(is_collective=True, strategy=dist_strategy)
# init engine
- engine = Engine(mlp,
-     inputs_spec=inputs_spec,
-     labels_spec=labels_spec,
-     strategy=dist_strategy)
- engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
+ engine = auto.Engine(mlp,
+     loss,
+     optimizer,
+     paddle.metric.Accuracy(),
+     strategy=dist_strategy)
- # fetch
- if fetch:
-     fetches = {'out': mlp.out}
- else:
-     fetches = None
# train
train_dataset = MyDataset(batch_num * batch_size)
- engine.fit(train_dataset, batch_size=batch_size, fetches=fetches)
+ engine.fit(train_dataset, batch_size=batch_size)
# eval
eval_dataset = MyDataset(batch_size)
- engine.evaluate(eval_dataset, batch_size, fetches=fetches)
+ engine.evaluate(eval_dataset, batch_size=batch_size)
# predict
test_dataset = MyDataset(batch_size)
- engine.predict(test_dataset, batch_size, fetches=fetches)
+ engine.predict(test_dataset, batch_size=batch_size)
# save
temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp_inf')
- engine.save(model_filename, training=False, mode='predict')
+ engine.save(model_filename, training=False)
temp_dir.cleanup()
...
@@ -14,8 +14,10 @@
import sys
import numpy as np
+ import random
import paddle
+ import paddle.distributed.auto_parallel as auto
sys.path.append("..")
import auto_parallel_gpt_model as modeling
@@ -25,7 +27,7 @@ sequence_len = 512
vocab_size = 1000
- class FakeDataset:
+ class FakeDataset(paddle.io.Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
@@ -33,6 +35,9 @@ class FakeDataset:
self.vocab_size = vocab_size
def __getitem__(self, idx):
+ paddle.seed(2021)
+ np.random.seed(2021)
+ random.seed(2021)
tokens = np.random.randint(self.vocab_size, size=self.sequence_len)
position_ids = np.arange(self.sequence_len)
attention_mask = np.tril(np.ones(self.sequence_len)).reshape(
@@ -67,8 +72,9 @@ def create_data_holder(batch_size):
def generate_model(strategy):
modeling.init_global()
- modeling._global_process_mesh = list(
-     range(paddle.distributed.get_world_size()))
+ ranks = list(range(paddle.distributed.get_world_size()))
+ modeling._global_process_mesh = auto.ProcessMesh(mesh=ranks,
+     dim_names=["x"])
if strategy == "serial":
modeling._global_parallel_strategy = "serial"
elif strategy == "mp":
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
import random
import numpy as np
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
paddle.enable_static()
def apply_pass(use_gradient_merge=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_gradient_merge:
gradient_merge = strategy.gradient_merge
gradient_merge.enable = True
gradient_merge.k_steps = 4
gradient_merge.avg = True
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestGradientMergePass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-5
self.atol = 1e-8
self.batch_size = 8
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2021)
np.random.seed(2021)
random.seed(2021)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_gradient_merge=False):
reset_prog()
strategy = apply_pass(use_gradient_merge)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("dp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses))
def test_gradient_merge_pass(self):
# dp2 training
dp_engine = self.get_engine()
dp_losses = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
dp_losses = np.array(dp_losses["loss"])
# dp2 gradient merge training
gm_engine = self.get_engine(True)
gm_losses = gm_engine.fit(self.dataset, 3, batch_size=self.batch_size)
gm_losses = np.array(gm_losses["loss"])
avg_loss = 0
pass_avg_ret_list = []
for i, pass_ret in enumerate(gm_losses):
if (i + 1) % 4 == 0:
avg_loss += pass_ret
pass_avg_ret_list.append(avg_loss / 4)
avg_loss = 0
else:
avg_loss += pass_ret
self.check_results(dp_losses, np.array(pass_avg_ret_list))
if __name__ == "__main__":
unittest.main()
...@@ -17,11 +17,7 @@ import paddle ...@@ -17,11 +17,7 @@ import paddle
import unittest import unittest
import numpy as np import numpy as np
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.static import InputSpec
from paddle.distributed import fleet
from paddle.incubate.autograd import Hessian from paddle.incubate.autograd import Hessian
from paddle.distributed.auto_parallel.engine import Engine
np.random.seed(1234) np.random.seed(1234)
paddle.seed(1234) paddle.seed(1234)
...@@ -87,7 +83,7 @@ class LaplaceModel(paddle.nn.Layer): ...@@ -87,7 +83,7 @@ class LaplaceModel(paddle.nn.Layer):
return eq_loss, bc_u return eq_loss, bc_u
class LaplaceDataset: class LaplaceDataset(paddle.io.Dataset):
def __init__(self, num_sample): def __init__(self, num_sample):
self.num_sample = num_sample self.num_sample = num_sample
...@@ -129,23 +125,14 @@ def main(): ...@@ -129,23 +125,14 @@ def main():
# model # model
laplace = LaplaceModel() laplace = LaplaceModel()
# spec dist_strategy = auto.Strategy()
inputs_spec = [ dist_strategy.auto_mode = "semi"
InputSpec([100, 2], 'float32', 'x'),
InputSpec([36], 'int64', 'bc_idx')
]
labels_spec = InputSpec([36, 1], 'float32', 'bc_v')
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
engine = Engine(laplace, engine = auto.Engine(laplace,
inputs_spec=inputs_spec, loss=loss_func,
labels_spec=labels_spec, optimizer=optimizer,
strategy=dist_strategy) strategy=dist_strategy)
engine.prepare(optimizer=optimizer, loss=loss_func) engine.fit(train_dataset, train_sample_split=2, batch_size=None)
engine.fit(train_dataset, batch_size=None)
dist_context = engine.dist_context dist_context = engine.dist_context
block = engine.main_program.global_block() block = engine.main_program.global_block()
......
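As the hunk above shows, the Engine now takes the loss and optimizer in its constructor and the separate `prepare()` call is gone; `train_sample_split` tells `fit()` how many leading fields of each dataset sample are inputs, with the rest treated as labels. A hedged sketch of the new flow (the model, loss, optimizer and dataset below are placeholders, not from this diff):

import paddle
import paddle.distributed.auto_parallel as auto

strategy = auto.Strategy()
strategy.auto_mode = "semi"

model = paddle.nn.Linear(8, 1)              # placeholder network
loss = paddle.nn.MSELoss()                  # placeholder loss
opt = paddle.optimizer.Adam(learning_rate=1e-3)

engine = auto.Engine(model, loss, opt, strategy=strategy)
# Each sample above yields (x, bc_idx, bc_v); the first two fields are
# inputs, so train_sample_split=2. The dataset itself is omitted here:
# engine.fit(train_dataset, train_sample_split=2, batch_size=None)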
...@@ -28,9 +28,8 @@ import paddle.utils as utils ...@@ -28,9 +28,8 @@ import paddle.utils as utils
from paddle.fluid import layers from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.static import InputSpec from paddle.static import InputSpec
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
from paddle.optimizer.lr import CosineAnnealingDecay from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.fluid.dataloader.collate import default_collate_fn from paddle.fluid.dataloader.collate import default_collate_fn
...@@ -48,10 +47,9 @@ class_num = 10 ...@@ -48,10 +47,9 @@ class_num = 10
paddle.seed(44) paddle.seed(44)
class MyDataset(IterableDataset): class MyDataset(paddle.io.IterableDataset):
def __init__(self, num_samples): def __init__(self, num_samples):
super(MyDataset, self).__init__()
self.num_samples = num_samples self.num_samples = num_samples
def __iter__(self): def __iter__(self):
...@@ -61,10 +59,9 @@ class MyDataset(IterableDataset): ...@@ -61,10 +59,9 @@ class MyDataset(IterableDataset):
yield input, label yield input, label
class MyDataset1(Dataset): class MyDataset1(paddle.io.Dataset):
def __init__(self, num_samples): def __init__(self, num_samples):
super(MyDataset1, self).__init__()
self.num_samples = num_samples self.num_samples = num_samples
self.data = [] self.data = []
for i in range(self.num_samples): for i in range(self.num_samples):
...@@ -112,12 +109,10 @@ class MLPLayer(nn.Layer): ...@@ -112,12 +109,10 @@ class MLPLayer(nn.Layer):
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input): def forward(self, input):
out = auto.shard_op(self.norm, dist_attr={"process_mesh": out = auto.shard_op(self.norm, PP_MESH_0)(input)
PP_MESH_0})(input)
out = self.linear0(out) out = self.linear0(out)
out = F.gelu(out, approximate=True) out = F.gelu(out, approximate=True)
out = auto.shard_op(self.linear1, dist_attr={"process_mesh": out = auto.shard_op(self.linear1, PP_MESH_1)(out)
PP_MESH_1})(out)
out = self.dropout(out) out = self.dropout(out)
out = self.linear2(out) out = self.linear2(out)
self.out = out self.out = out
...@@ -136,54 +131,36 @@ def train(fetch): ...@@ -136,54 +131,36 @@ def train(fetch):
epsilon=1e-08, epsilon=1e-08,
grad_clip=None) grad_clip=None)
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x') dist_strategy = auto.Strategy()
labels_spec = InputSpec([batch_size], 'int64', 'label') dist_strategy.auto_mode = "semi"
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
dist_strategy.split_data = True dist_strategy.split_data = True
fleet.init(is_collective=True, strategy=dist_strategy)
# init engine # init engine
engine = Engine(mlp, engine = auto.Engine(mlp,
inputs_spec=inputs_spec, loss,
labels_spec=labels_spec, optimizer,
paddle.metric.Accuracy(),
strategy=dist_strategy) strategy=dist_strategy)
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
# fetch
if fetch:
fetches = {'out': mlp.out}
else:
fetches = None
# train # train
train_dataset = MyDataset(batch_num * batch_size) train_dataset = MyDataset(batch_num * batch_size)
train_dataset1 = MyDataset1(batch_num) engine.fit(train_dataset, epochs=2, batch_size=batch_size)
engine.fit(train_dataset,
epochs=2, train_dataset1 = MyDataset1(batch_size * batch_num)
batch_size=batch_size, engine.fit(train_dataset1, epochs=2, batch_size=None)
steps_per_epoch=batch_num,
fetches=fetches)
engine.fit(train_dataset1,
epochs=2,
batch_size=None,
steps_per_epoch=batch_num,
fetches=fetches)
# eval # eval
eval_dataset = MyDataset(batch_size) eval_dataset = MyDataset(batch_size)
engine.evaluate(eval_dataset, batch_size, fetches=fetches) engine.evaluate(eval_dataset, batch_size=batch_size)
# predict # predict
test_dataset = MyDataset(batch_size) test_dataset = MyDataset(batch_size)
engine.predict(test_dataset, batch_size, fetches=fetches) engine.predict(test_dataset, batch_size=batch_size)
# save # save
temp_dir = tempfile.TemporaryDirectory() temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp_inf') model_filename = os.path.join(temp_dir.name, 'mlp_inf')
engine.save(model_filename, training=False, mode='predict') engine.save(model_filename, training=False)
temp_dir.cleanup() temp_dir.cleanup()
......
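The dict-based `dist_attr` argument to `shard_op` is likewise gone; the new form takes the process mesh positionally and returns a callable wrapper, as the forward() rewrite above shows. A minimal static-graph sketch (the layer and shapes are illustrative):

import paddle
import paddle.distributed.auto_parallel as auto

paddle.enable_static()

mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
norm = paddle.nn.LayerNorm(1024)
x = paddle.static.data(name="x", shape=[4, 1024], dtype="float32")

# shard_op(fn, mesh) wraps fn; calling the wrapper builds fn's ops and
# annotates them with the given process mesh.
out = auto.shard_op(norm, mesh)(x)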
...@@ -27,10 +27,8 @@ import paddle.nn.functional as F ...@@ -27,10 +27,8 @@ import paddle.nn.functional as F
import paddle.utils as utils import paddle.utils as utils
from paddle.fluid import layers from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.static import InputSpec
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
from engine_api_dp import MyDataset from engine_api_dp import MyDataset
paddle.enable_static() paddle.enable_static()
...@@ -43,20 +41,6 @@ class_num = 10 ...@@ -43,20 +41,6 @@ class_num = 10
paddle.seed(44) paddle.seed(44)
# class MyDataset(Dataset):
# def __init__(self, num_samples):
# super(MyDataset, self).__init__()
# self.num_samples = num_samples
# def __getitem__(self, index):
# input = np.random.uniform(size=image_size).astype("float32")
# label = np.random.randint(0, class_num - 1, dtype="int64")
# return input, label
# def __len__(self):
# return self.num_samples
class MLPLayer(nn.Layer): class MLPLayer(nn.Layer):
...@@ -107,50 +91,33 @@ def train(fetch): ...@@ -107,50 +91,33 @@ def train(fetch):
epsilon=1e-08, epsilon=1e-08,
grad_clip=None) grad_clip=None)
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x') dist_strategy = auto.Strategy()
labels_spec = InputSpec([batch_size], 'int64', 'label') dist_strategy.auto_mode = "semi"
# sharding config
dist_strategy = fleet.DistributedStrategy() sharding = dist_strategy.sharding
dist_strategy.amp = False sharding.enable = True
dist_strategy.pipeline = False sharding.degree = 2
dist_strategy.recompute = False sharding.stage = 3
# init parallel optimizer sharding.enable_tuning = True
dist_strategy.semi_auto = True sharding.tuning_range = [0, 1, 2, 3]
dist_strategy.sharding = True
dist_strategy.sharding_configs = {
"sharding_degree": 2,
"stage": 3,
"enable_tuning": True,
}
fleet.init(is_collective=True, strategy=dist_strategy)
# init engine
import tempfile
tmp_dir = tempfile.TemporaryDirectory()
dataset = MyDataset(batch_num * batch_size)
# Tuning configuration # Tuning configuration
tuning_config = { tuning = dist_strategy.tuning
"batch_size": batch_size, tuning.enable = True
"dataset": dataset, tuning.profile_start_step = 1
"profile_start_step": 1, tuning.profile_end_step = 5
"profile_end_step": 5, tuning.run_after_tuning = True
"run_after_tuning": True, tuning.verbose = True
"sharding": {
"stage_range": [0, 1, 2, 3] dataset = MyDataset(batch_num * batch_size)
}, engine = auto.Engine(mlp,
"verbose": True, loss,
} optimizer,
engine = Engine(mlp, paddle.metric.Accuracy(),
inputs_spec=inputs_spec, strategy=dist_strategy)
labels_spec=labels_spec, engine._tune(dataset, batch_size=batch_size)
strategy=dist_strategy,
user_tuning_config=tuning_config)
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
# check tuned # check tuned
assert (engine._dist_contexts['train'].strategy.sharding_configs['stage'] != assert (engine._dist_contexts['train'].strategy.sharding.stage != 3)
3)
if __name__ == "__main__": if __name__ == "__main__":
......
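The free-form `sharding_configs` and `user_tuning_config` dicts collapse into typed sub-configs on `auto.Strategy`, so option names become attributes rather than string keys. Condensed from the diff above (no options beyond those shown are introduced):

import paddle.distributed.auto_parallel as auto

strategy = auto.Strategy()
strategy.auto_mode = "semi"

sharding = strategy.sharding            # typed sub-config, not a dict
sharding.enable = True
sharding.degree = 2
sharding.stage = 3
sharding.enable_tuning = True
sharding.tuning_range = [0, 1, 2, 3]    # candidate stages for the tuner

tuning = strategy.tuning
tuning.enable = True
tuning.profile_start_step = 1
tuning.profile_end_step = 5
tuning.run_after_tuning = True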
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
import random
import numpy as np
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
paddle.enable_static()
def apply_pass(use_recompute=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_recompute:
recompute = strategy.recompute
recompute.enable = True
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestRecomputePass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-6
self.atol = 1e-8
self.batch_size = 1
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_recompute=False):
reset_prog()
strategy = apply_pass(use_recompute)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("mp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!\nref={}\ncheck={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses))
def test_recompute_pass(self):
# mp2 training
mp_engine = self.get_engine()
mp_losses = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(mp_losses["loss"])
# mp2 recompute training
rc_engine = self.get_engine(True)
rc_losses = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size)
rc_losses = np.array(rc_losses["loss"])
self.check_results(mp_losses, rc_losses)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
import random
import numpy as np
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
paddle.enable_static()
def apply_pass(use_sharding=False, stage=None):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_sharding:
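# ZeRO-style sharding stages: stage 1 shards optimizer states across the
# sharding group, stage 2 additionally shards gradients, and stage 3 also
# shards the parameters themselves.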
sharding = strategy.sharding
sharding.enable = True
sharding.degree = 2
sharding.stage = stage
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestShardingPass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-6
self.atol = 1e-8
self.batch_size = 2
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_sharding=False, stage=None):
reset_prog()
strategy = apply_pass(use_sharding, stage)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("dp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!\nref={}\ncheck={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses))
def test_sharding_pass(self):
# dp2 training
dp_engine = self.get_engine()
dp_losses = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
dp_losses = np.array(dp_losses["loss"])
# sharding2 stage1 training
sharding1_engine = self.get_engine(True, 1)
sharding1_losses = sharding1_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
sharding1_losses = np.array(sharding1_losses["loss"])
self.check_results(dp_losses, sharding1_losses)
# sharding2 stage2 training
sharding2_engine = self.get_engine(True, 2)
sharding2_losses = sharding2_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
sharding2_losses = np.array(sharding2_losses["loss"])
self.check_results(dp_losses, sharding2_losses)
# sharding2 stage3 training
sharding3_engine = self.get_engine(True, 3)
sharding3_losses = sharding3_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
sharding3_losses = np.array(sharding3_losses["loss"])
self.check_results(dp_losses, sharding3_losses)
if __name__ == "__main__":
unittest.main()
...@@ -45,9 +45,10 @@ from test_cluster import cluster_json ...@@ -45,9 +45,10 @@ from test_cluster import cluster_json
paddle.enable_static() paddle.enable_static()
_global_parallel_strategy = "dp_mp_pp" _global_parallel_strategy = "dp_mp_pp"
_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]]) _global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]],
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]]) dim_names=["x", "y", "z"])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]]) PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], dim_names=["x", "y"])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], dim_names=["x", "y"])
class MLPLayer(nn.Layer): class MLPLayer(nn.Layer):
...@@ -74,16 +75,8 @@ class MLPLayer(nn.Layer): ...@@ -74,16 +75,8 @@ class MLPLayer(nn.Layer):
self.norm = nn.LayerNorm(d_model, epsilon=1e-5) self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, input): def forward(self, input):
auto.shard_tensor(self.linear0.weight, auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, "y"])
dist_attr={ auto.shard_tensor(self.linear1.weight, PP_MESH_1, ["y", None])
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [1, -1]
})
out = self.norm(input) out = self.norm(input)
out = self.linear0(out) out = self.linear0(out)
...@@ -111,16 +104,8 @@ def mlp_forward(train_program, start_program): ...@@ -111,16 +104,8 @@ def mlp_forward(train_program, start_program):
embedding = paddle.nn.Embedding(10, hidden_size, sparse=True) embedding = paddle.nn.Embedding(10, hidden_size, sparse=True)
embedding_out = embedding(fill_constant_out) embedding_out = embedding(fill_constant_out)
auto.shard_tensor(input, auto.shard_tensor(input, PP_MESH_0, ["x", None])
dist_attr={ auto.shard_tensor(label, PP_MESH_1, ["x", None])
"process_mesh": PP_MESH_0,
"dims_mapping": [0, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]
})
mlp = MLPLayer(hidden_size=hidden_size, mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size, intermediate_size=4 * hidden_size,
......
...@@ -34,7 +34,10 @@ paddle.enable_static() ...@@ -34,7 +34,10 @@ paddle.enable_static()
batch_size = 4 batch_size = 4
hidden_size = 1024 hidden_size = 1024
sequence_len = 512 sequence_len = 512
_g_process_mesh = [[0, 1], [2, 3]] _g_process_mesh = [
auto.ProcessMesh([0, 1], dim_names=["x"]),
auto.ProcessMesh([2, 3], dim_names=["x"])
]
def get_random_inputs_and_labels(input_shape, label_shape): def get_random_inputs_and_labels(input_shape, label_shape):
...@@ -82,18 +85,10 @@ class MLPLayer(nn.Layer): ...@@ -82,18 +85,10 @@ class MLPLayer(nn.Layer):
def forward(self, input): def forward(self, input):
out = self.norm(input) out = self.norm(input)
auto.shard_tensor(self.linear0.weight, auto.shard_tensor(self.linear0.weight, _g_process_mesh[0], [None, "x"])
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [-1, 0]
})
out = self.linear0(out) out = self.linear0(out)
out = F.gelu(out, approximate=True) out = F.gelu(out, approximate=True)
auto.shard_tensor(self.linear1.weight, auto.shard_tensor(self.linear1.weight, _g_process_mesh[1], ["x", None])
dist_attr={
"process_mesh": _g_process_mesh[1],
"dims_mapping": [0, -1]
})
out = self.linear1(out) out = self.linear1(out)
return out return out
...@@ -123,16 +118,8 @@ def get_program(): ...@@ -123,16 +118,8 @@ def get_program():
dataloader.set_batch_generator(batch_generator_creator(), dataloader.set_batch_generator(batch_generator_creator(),
places=paddle.static.cuda_places()) places=paddle.static.cuda_places())
# data dist_attr # data dist_attr
auto.shard_tensor(input, auto.shard_tensor(input, _g_process_mesh[0], ["x", None, None])
dist_attr={ auto.shard_tensor(label, _g_process_mesh[0], ["x", None, None])
"process_mesh": _g_process_mesh[0],
"dims_mapping": [0, -1, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [0, -1, -1]
})
mlp_start = MLPLayer(hidden_size=hidden_size, mlp_start = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size, intermediate_size=4 * hidden_size,
......
...@@ -42,19 +42,13 @@ def make_program_lookup_table_v1_mp_dp(): ...@@ -42,19 +42,13 @@ def make_program_lookup_table_v1_mp_dp():
is_sparse=False) is_sparse=False)
loss = paddle.fluid.layers.reduce_mean(emb_out) loss = paddle.fluid.layers.reduce_mean(emb_out)
auto.shard_tensor(src_ids, auto.shard_tensor(
dist_attr={ src_ids, auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]),
"process_mesh": auto.ProcessMesh([[0, 1], [2, ["x", None, None])
3]]),
"dims_mapping": [0, -1, -1]
})
emb_weight = block.vars["emb_weight"] emb_weight = block.vars["emb_weight"]
auto.shard_tensor(emb_weight, auto.shard_tensor(
dist_attr={ emb_weight, auto.ProcessMesh([[0, 1], [2, 3]],
"process_mesh": auto.ProcessMesh([[0, 1], [2, dim_names=["x", "y"]), ["y", None])
3]]),
"dims_mapping": [1, -1]
})
return main_program, start_program, loss return main_program, start_program, loss
......
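For reference, the shard specs above translate to the old dims_mapping notation as follows (a reading aid only, not new behavior):

# Under ProcessMesh dim_names=["x", "y"]: "x" -> 0, "y" -> 1, None -> -1.
#   src_ids    ["x", None, None]  ==  dims_mapping [0, -1, -1]  (batch split on "x")
#   emb_weight ["y", None]        ==  dims_mapping [1, -1]      (vocab rows split on "y")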
...@@ -22,82 +22,58 @@ from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr ...@@ -22,82 +22,58 @@ from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static() paddle.enable_static()
mesh = [[0, 1], [2, 3]] mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
def init_x_row(trans_x): def init_x_row(trans_x):
if trans_x: if trans_x:
x = paddle.static.data(name='x', shape=[10, 6, 8], dtype='float32') x = paddle.static.data(name='x', shape=[10, 6, 8], dtype='float32')
auto.shard_tensor(x, auto.shard_tensor(x, mesh, ["x", "y", None])
dist_attr={
"process_mesh": mesh,
"dims_mapping": [0, 1, -1]
})
return x return x
else: else:
x = paddle.static.data(name='x', shape=[10, 8, 6], dtype='float32') x = paddle.static.data(name='x', shape=[10, 8, 6], dtype='float32')
auto.shard_tensor(x, auto.shard_tensor(x, mesh, ["x", None, "y"])
dist_attr={
"process_mesh": mesh,
"dims_mapping": [0, -1, 1]
})
return x return x
def init_x_col(trans_x): def init_x_col(trans_x):
if trans_x: if trans_x:
x = paddle.static.data(name='x', shape=[6, 8], dtype='float32') x = paddle.static.data(name='x', shape=[6, 8], dtype='float32')
auto.shard_tensor(x, auto.shard_tensor(x, mesh, [None, "x"])
dist_attr={
"process_mesh": mesh,
"dims_mapping": [-1, 0]
})
return x return x
else: else:
x = paddle.static.data(name='x', shape=[8, 6], dtype='float32') x = paddle.static.data(name='x', shape=[8, 6], dtype='float32')
auto.shard_tensor(x, auto.shard_tensor(x, mesh, ["x", None])
dist_attr={
"process_mesh": mesh,
"dims_mapping": [0, -1]
})
return x return x
def init_y_row(trans_y): def init_y_row(trans_y):
if trans_y: if trans_y:
y = paddle.static.data(name='y', shape=[4, 6], dtype='float32') y = paddle.static.data(name='y', shape=[4, 6], dtype='float32')
auto.shard_tensor(y, auto.shard_tensor(y, mesh, [None, "y"])
dist_attr={
"process_mesh": mesh,
"dims_mapping": [-1, 1]
})
return y return y
else: else:
y = paddle.static.data(name='y', shape=[6, 4], dtype='float32') y = paddle.static.data(name='y', shape=[6, 4], dtype='float32')
auto.shard_tensor(y, auto.shard_tensor(y, mesh, ["y", None])
dist_attr={
"process_mesh": mesh,
"dims_mapping": [1, -1]
})
return y return y
def init_y_col(trans_y): def init_y_col(trans_y):
if trans_y: if trans_y:
y = paddle.static.data(name='y', shape=[4, 6], dtype='float32') y = paddle.static.data(name='y', shape=[4, 6], dtype='float32')
auto.shard_tensor(y, auto.shard_tensor(y, mesh, ["y", None])
dist_attr={
"process_mesh": mesh,
"dims_mapping": [1, -1]
})
return y return y
else: else:
y = paddle.static.data(name='y', shape=[6, 4], dtype='float32') y = paddle.static.data(name='y', shape=[6, 4], dtype='float32')
auto.shard_tensor(y, auto.shard_tensor(y, mesh, [None, "y"])
dist_attr={
"process_mesh": mesh,
"dims_mapping": [-1, 1]
})
return y return y
......
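All of the rewrites above follow the same pattern: `auto.shard_tensor(tensor, mesh, shard_spec)`, where each spec entry names the mesh dimension that tensor axis is split along, or None for a replicated axis. A minimal sketch (the shapes are illustrative):

import paddle
import paddle.distributed.auto_parallel as auto

paddle.enable_static()

mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
w = paddle.static.data(name="w", shape=[8, 6], dtype="float32")

# Split rows of w across mesh dim "x"; replicate columns.
# Equivalent to the old dist_attr {"dims_mapping": [0, -1]}.
auto.shard_tensor(w, mesh, ["x", None])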
...@@ -71,11 +71,8 @@ class TestDistOpCost(unittest.TestCase): ...@@ -71,11 +71,8 @@ class TestDistOpCost(unittest.TestCase):
shape=[4, 1], shape=[4, 1],
dtype='float32') dtype='float32')
label.stop_gradient = True label.stop_gradient = True
auto.shard_tensor(x, auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
dist_attr={ ["x", None])
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
tmp = paddle.fluid.layers.fill_constant_batch_size_like( tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[2, 8], value=1, dtype='float32') input=x, shape=[2, 8], value=1, dtype='float32')
weight_attr = paddle.ParamAttr() weight_attr = paddle.ParamAttr()
...@@ -121,17 +118,12 @@ class TestDistOpCost(unittest.TestCase): ...@@ -121,17 +118,12 @@ class TestDistOpCost(unittest.TestCase):
shape=[8, 1], shape=[8, 1],
dtype='float32') dtype='float32')
label.stop_gradient = True label.stop_gradient = True
auto.shard_tensor(x, auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
dist_attr={ ["x"])
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0]
})
auto.shard_tensor(label, auto.shard_tensor(label,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), ["x", None])
"dims_mapping": [0, -1]
})
# embedding # embedding
tmp = paddle.fluid.layers.fill_constant_batch_size_like( tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[4], value=1, dtype='int32') input=x, shape=[4], value=1, dtype='int32')
...@@ -141,12 +133,9 @@ class TestDistOpCost(unittest.TestCase): ...@@ -141,12 +133,9 @@ class TestDistOpCost(unittest.TestCase):
for op in main_program.global_block().ops: for op in main_program.global_block().ops:
if op.type == "lookup_table_v2": if op.type == "lookup_table_v2":
W = main_program.global_block().vars[op.input("W")[0]] W = main_program.global_block().vars[op.input("W")[0]]
auto.shard_tensor(W, auto.shard_tensor(
dist_attr={ W, auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": ["x", None])
auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
out = paddle.fluid.layers.transpose(out, out = paddle.fluid.layers.transpose(out,
[1, 0]) # [8, 2] [-1, 0] [1, 0]) # [8, 2] [-1, 0]
...@@ -154,26 +143,20 @@ class TestDistOpCost(unittest.TestCase): ...@@ -154,26 +143,20 @@ class TestDistOpCost(unittest.TestCase):
param1 = paddle.fluid.layers.create_parameter( param1 = paddle.fluid.layers.create_parameter(
[4, 8], paddle.float32) # [2, 8] [0, -1] [4, 8], paddle.float32) # [2, 8] [0, -1]
auto.shard_tensor(param1, auto.shard_tensor(param1,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), ["x", None])
"dims_mapping": [0, -1]
})
param2 = paddle.fluid.layers.create_parameter( param2 = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 4] [-1, 0] [8, 8], paddle.float32) # [8, 4] [-1, 0]
auto.shard_tensor(param2, auto.shard_tensor(param2,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), [None, "x"])
"dims_mapping": [-1, 0]
})
out1 = paddle.fluid.layers.matmul(out, out1 = paddle.fluid.layers.matmul(out,
param1) # [8, 8] [-1, -1] param1) # [8, 8] [-1, -1]
tmp_param = paddle.fluid.layers.create_parameter( tmp_param = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 8] [-1, -1] [8, 8], paddle.float32) # [8, 8] [-1, -1]
auto.shard_tensor(param2, auto.shard_tensor(param2,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), [None, None])
"dims_mapping": [-1, -1]
})
tmp_out = paddle.fluid.layers.matmul(out1, tmp_param) tmp_out = paddle.fluid.layers.matmul(out1, tmp_param)
out2 = paddle.fluid.layers.matmul(tmp_out, out2 = paddle.fluid.layers.matmul(tmp_out,
param2) # [8, 4] [-1, 0] param2) # [8, 4] [-1, 0]
...@@ -227,17 +210,12 @@ class TestDistOpCost(unittest.TestCase): ...@@ -227,17 +210,12 @@ class TestDistOpCost(unittest.TestCase):
shape=[8, 1], shape=[8, 1],
dtype='float32') dtype='float32')
label.stop_gradient = True label.stop_gradient = True
auto.shard_tensor(x, auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
dist_attr={ ["x"])
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0]
})
auto.shard_tensor(label, auto.shard_tensor(label,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), ["x", None])
"dims_mapping": [0, -1]
})
# embedding # embedding
tmp = paddle.fluid.layers.fill_constant_batch_size_like( tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[4], value=1, dtype='int32') input=x, shape=[4], value=1, dtype='int32')
...@@ -247,12 +225,9 @@ class TestDistOpCost(unittest.TestCase): ...@@ -247,12 +225,9 @@ class TestDistOpCost(unittest.TestCase):
for op in main_program.global_block().ops: for op in main_program.global_block().ops:
if op.type == "lookup_table_v2": if op.type == "lookup_table_v2":
W = main_program.global_block().vars[op.input("W")[0]] W = main_program.global_block().vars[op.input("W")[0]]
auto.shard_tensor(W, auto.shard_tensor(
dist_attr={ W, auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": ["x", None])
auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
out = paddle.fluid.layers.transpose(out, out = paddle.fluid.layers.transpose(out,
[1, 0]) # [8, 2] [-1, 0] [1, 0]) # [8, 2] [-1, 0]
...@@ -260,25 +235,20 @@ class TestDistOpCost(unittest.TestCase): ...@@ -260,25 +235,20 @@ class TestDistOpCost(unittest.TestCase):
param1 = paddle.fluid.layers.create_parameter( param1 = paddle.fluid.layers.create_parameter(
[4, 8], paddle.float32) # [2, 8] [0, -1] [4, 8], paddle.float32) # [2, 8] [0, -1]
auto.shard_tensor(param1, auto.shard_tensor(param1,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), ["x", None])
"dims_mapping": [0, -1]
})
param2 = paddle.fluid.layers.create_parameter( param2 = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 4] [-1, 0] [8, 8], paddle.float32) # [8, 4] [-1, 0]
auto.shard_tensor(param2, auto.shard_tensor(param2,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), [None, "x"])
"dims_mapping": [-1, 0]
})
out1 = paddle.matmul(out, param1) # [8, 8] [-1, -1] out1 = paddle.matmul(out, param1) # [8, 8] [-1, -1]
tmp_param = paddle.fluid.layers.create_parameter( tmp_param = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 8] [-1, -1] [8, 8], paddle.float32) # [8, 8] [-1, -1]
auto.shard_tensor(param2, auto.shard_tensor(param2,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), [None, None])
"dims_mapping": [-1, -1]
})
tmp_out = paddle.matmul(out1, tmp_param) tmp_out = paddle.matmul(out1, tmp_param)
out2 = paddle.matmul(tmp_out, param2) # [8, 4] [-1, 0] out2 = paddle.matmul(tmp_out, param2) # [8, 4] [-1, 0]
...@@ -331,17 +301,11 @@ class TestDistOpCost(unittest.TestCase): ...@@ -331,17 +301,11 @@ class TestDistOpCost(unittest.TestCase):
shape=[8, 1], shape=[8, 1],
dtype='float32') dtype='float32')
label.stop_gradient = True label.stop_gradient = True
auto.shard_tensor(x, auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
dist_attr={ ["x"])
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0]
})
auto.shard_tensor(label, auto.shard_tensor(label,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), ["x", None])
"dims_mapping": [0, -1]
})
# embedding # embedding
tmp = paddle.fluid.layers.fill_constant_batch_size_like( tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[4], value=1, dtype='int32') input=x, shape=[4], value=1, dtype='int32')
...@@ -351,12 +315,9 @@ class TestDistOpCost(unittest.TestCase): ...@@ -351,12 +315,9 @@ class TestDistOpCost(unittest.TestCase):
for op in main_program.global_block().ops: for op in main_program.global_block().ops:
if op.type == "lookup_table_v2": if op.type == "lookup_table_v2":
W = main_program.global_block().vars[op.input("W")[0]] W = main_program.global_block().vars[op.input("W")[0]]
auto.shard_tensor(W, auto.shard_tensor(
dist_attr={ W, auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": ["x", None])
auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
out = paddle.fluid.layers.transpose(out, out = paddle.fluid.layers.transpose(out,
[1, 0]) # [8, 2] [-1, 0] [1, 0]) # [8, 2] [-1, 0]
...@@ -364,25 +325,21 @@ class TestDistOpCost(unittest.TestCase): ...@@ -364,25 +325,21 @@ class TestDistOpCost(unittest.TestCase):
param1 = paddle.fluid.layers.create_parameter( param1 = paddle.fluid.layers.create_parameter(
[4, 8], paddle.float32) # [2, 8] [0, -1] [4, 8], paddle.float32) # [2, 8] [0, -1]
auto.shard_tensor(param1, auto.shard_tensor(param1,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), ["x", None])
"dims_mapping": [0, -1]
})
param2 = paddle.fluid.layers.create_parameter( param2 = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 4] [-1, 0] [8, 8], paddle.float32) # [8, 4] [-1, 0]
auto.shard_tensor(param2, auto.shard_tensor(param2,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), [None, "x"])
"dims_mapping": [-1, 0]
})
out1 = paddle.fluid.layers.mul(out, param1) # [8, 8] [-1, -1] out1 = paddle.fluid.layers.mul(out, param1) # [8, 8] [-1, -1]
tmp_param = paddle.fluid.layers.create_parameter( tmp_param = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 8] [-1, -1] [8, 8], paddle.float32) # [8, 8] [-1, -1]
auto.shard_tensor(param2, auto.shard_tensor(param2,
dist_attr={ auto.ProcessMesh([0, 1], dim_names=["x"]),
"process_mesh": auto.ProcessMesh([0, 1]), [None, None])
"dims_mapping": [-1, -1]
})
tmp_out = paddle.fluid.layers.mul(out1, tmp_param) tmp_out = paddle.fluid.layers.mul(out1, tmp_param)
out2 = paddle.fluid.layers.mul(tmp_out, out2 = paddle.fluid.layers.mul(tmp_out,
param2) # [8, 4] [-1, 0] param2) # [8, 4] [-1, 0]
......
...@@ -29,11 +29,8 @@ def make_program_dp2(): ...@@ -29,11 +29,8 @@ def make_program_dp2():
with paddle.static.program_guard(main_program, start_program): with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32') x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
x.stop_gradient = False x.stop_gradient = False
auto.shard_tensor(x, auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
dist_attr={ ["x", None, None])
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1, -1]
})
tmp_0 = paddle.norm(x, p=2) tmp_0 = paddle.norm(x, p=2)
return main_program, start_program, tmp_0 return main_program, start_program, tmp_0
...@@ -44,11 +41,8 @@ def make_program_serial(): ...@@ -44,11 +41,8 @@ def make_program_serial():
with paddle.static.program_guard(main_program, start_program): with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32') x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
x.stop_gradient = False x.stop_gradient = False
auto.shard_tensor(x, auto.shard_tensor(x, auto.ProcessMesh([0], dim_names=["x"]),
dist_attr={ [None, None, None])
"process_mesh": auto.ProcessMesh([0]),
"dims_mapping": [-1, -1, -1]
})
tmp_0 = paddle.norm(x, p=2) tmp_0 = paddle.norm(x, p=2)
return main_program, start_program, tmp_0 return main_program, start_program, tmp_0
......
...@@ -29,11 +29,9 @@ def make_program_dp2(): ...@@ -29,11 +29,9 @@ def make_program_dp2():
with paddle.static.program_guard(main_program, start_program): with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 4, 8], dtype='float32') x = paddle.static.data(name='x', shape=[4, 4, 8], dtype='float32')
x.stop_gradient = False x.stop_gradient = False
auto.shard_tensor(x, auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
dist_attr={ ["x", None, None])
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1, -1]
})
tmp_0 = paddle.reshape(x, shape=[0, 0, 4, 2]) tmp_0 = paddle.reshape(x, shape=[0, 0, 4, 2])
tmp_1 = paddle.reshape(tmp_0, shape=[0, 0, 8]) tmp_1 = paddle.reshape(tmp_0, shape=[0, 0, 8])
tmp_2 = tmp_1.reshape((tmp_1.shape[0], tmp_1.shape[1], -1)) tmp_2 = tmp_1.reshape((tmp_1.shape[0], tmp_1.shape[1], -1))
......
...@@ -25,11 +25,9 @@ def make_program_dp2(): ...@@ -25,11 +25,9 @@ def make_program_dp2():
start_program = paddle.fluid.Program() start_program = paddle.fluid.Program()
with paddle.static.program_guard(main_program, start_program): with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32') x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
auto.shard_tensor(x, auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
dist_attr={ ["x", None, None])
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1, -1]
})
tmp_0 = x[0] tmp_0 = x[0]
tmp_1 = x[:, 0, :] tmp_1 = x[:, 0, :]
tmp_2 = x[:, :, 1] tmp_2 = x[:, :, 1]
...@@ -42,11 +40,9 @@ def make_program_serial(): ...@@ -42,11 +40,9 @@ def make_program_serial():
start_program = paddle.fluid.Program() start_program = paddle.fluid.Program()
with paddle.static.program_guard(main_program, start_program): with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32') x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
auto.shard_tensor(x, auto.shard_tensor(x, auto.ProcessMesh([0], dim_names=["x"]),
dist_attr={ [None, None, None])
"process_mesh": auto.ProcessMesh([0]),
"dims_mapping": [-1, -1, -1]
})
tmp_0 = x[0] tmp_0 = x[0]
tmp_1 = x[:, 0, :] tmp_1 = x[:, 0, :]
tmp_2 = x[:, :, 1] tmp_2 = x[:, :, 1]
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import paddle.fluid as fluid
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.static as static
import paddle.distributed as dist
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
batch_size = 4
epoch_num = 10
hidden_size = 1024
sequence_len = 512
process_mesh1 = ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]],
dim_names=["x", "y"])
process_mesh2 = ProcessMesh(mesh=[0, 1, 2, 3], dim_names=["x"])
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
param_initializer = nn.initializer.Normal(mean=0.0,
std=initializer_range)
self.linear0 = nn.Linear(
d_model,
dim_feedforward,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None)
self.linear1 = nn.Linear(
dim_feedforward,
d_model,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None)
def forward(self, input):
auto.shard_tensor(self.linear0.weight, process_mesh1[0], [None, "y"])
linear0 = auto.shard_op(self.linear0, process_mesh1,
[["y", None, None]], [[None, "x", None]])
linear0_out = linear0(input)
gelu = auto.shard_op(F.gelu, process_mesh1, [["y", "x", None], None])
gelu_out = gelu(linear0_out, approximate=True)
auto.shard_tensor(self.linear1.weight, shard_spec=["y", None])
linear1 = auto.shard_op(self.linear1,
process_mesh1[1],
out_shard_specs=[["y", None, None]])
linear1_out = linear1(gelu_out)
return self.linear0, self.linear1, linear0_out, gelu_out, linear1_out
class TestAutoParallelAPI(unittest.TestCase):
def test_api(self):
# input
input = static.data(name="input",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32')
label = static.data(name="label",
shape=[batch_size, sequence_len, 1],
dtype='float32')
auto.shard_tensor(input, process_mesh1, ["x", None, None])
auto.shard_tensor(label, process_mesh1, ["y", None, None])
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
with ProcessMesh(process_mesh1.mesh, process_mesh1.dim_names):
linear0, linear1, linear0_out, gelu_out, linear1_out = mlp(input)
default_program = paddle.fluid.default_main_program()
default_dist_context = get_default_distributed_context()
self.assertEqual(len(default_program.blocks[0].ops), 5)
matmul0 = default_program.blocks[0].ops[0]
self.assertEqual(matmul0.type, "matmul_v2")
ewise_add0 = default_program.blocks[0].ops[1]
self.assertEqual(ewise_add0.type, "elementwise_add")
gelu = default_program.blocks[0].ops[2]
self.assertEqual(gelu.type, "gelu")
matmul1 = default_program.blocks[0].ops[3]
self.assertEqual(matmul1.type, "matmul_v2")
ewise_add1 = default_program.blocks[0].ops[4]
self.assertEqual(ewise_add1.type, "elementwise_add")
dist_input = default_dist_context.get_dist_tensor_for_program(input)
self.assertEqual(dist_input.dist_attr.process_mesh, process_mesh1)
self.assertEqual(dist_input.dist_attr.dims_mapping, [0, -1, -1])
self.assertTrue(dist_input.dist_attr.is_annotated("process_mesh"))
self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping"))
dist_input = default_dist_context.get_dist_tensor_for_program(label)
self.assertEqual(dist_input.dist_attr.process_mesh, process_mesh1)
self.assertEqual(dist_input.dist_attr.dims_mapping, [1, -1, -1])
self.assertTrue(dist_input.dist_attr.is_annotated("process_mesh"))
self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping"))
dist_linear0_weight = default_dist_context.get_dist_tensor_for_program(
linear0.weight)
self.assertEqual(dist_linear0_weight.dist_attr.process_mesh,
process_mesh1[0])
self.assertEqual(dist_linear0_weight.dist_attr.dims_mapping, [-1, 0])
self.assertTrue(
dist_linear0_weight.dist_attr.is_annotated("process_mesh"))
self.assertTrue(
dist_linear0_weight.dist_attr.is_annotated("dims_mapping"))
dist_linear1_weight = default_dist_context.get_dist_tensor_for_program(
linear1.weight)
self.assertEqual(dist_linear1_weight.dist_attr.process_mesh,
process_mesh1)
self.assertEqual(dist_linear1_weight.dist_attr.dims_mapping, [1, -1])
self.assertTrue(
dist_linear1_weight.dist_attr.is_annotated("process_mesh"))
self.assertTrue(
dist_linear1_weight.dist_attr.is_annotated("dims_mapping"))
dist_linear1_out = default_dist_context.get_dist_tensor_for_program(
linear1_out)
self.assertEqual(dist_linear1_out.dist_attr.process_mesh, process_mesh1)
self.assertEqual(dist_linear1_out.dist_attr.dims_mapping, [-1, -1, -1])
self.assertTrue(dist_linear1_out.dist_attr.is_annotated("process_mesh"))
self.assertFalse(
dist_linear1_out.dist_attr.is_annotated("dims_mapping"))
dist_op = default_dist_context.get_dist_op_for_program(matmul0)
self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1)
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(input.name)
self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1)
self.assertEqual(tensor_dist_attr.dims_mapping, [1, -1, -1])
self.assertTrue(tensor_dist_attr.is_annotated("process_mesh"))
self.assertTrue(tensor_dist_attr.is_annotated("dims_mapping"))
dist_op = default_dist_context.get_dist_op_for_program(ewise_add0)
self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1)
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(
linear0_out.name)
self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1)
self.assertEqual(tensor_dist_attr.dims_mapping, [-1, 0, -1])
self.assertTrue(tensor_dist_attr.is_annotated("process_mesh"))
self.assertTrue(tensor_dist_attr.is_annotated("dims_mapping"))
self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
dist_op = default_dist_context.get_dist_op_for_program(gelu)
self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1)
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(
linear0_out.name)
self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1)
self.assertEqual(tensor_dist_attr.dims_mapping, [1, 0, -1])
self.assertTrue(tensor_dist_attr.is_annotated("process_mesh"))
self.assertTrue(tensor_dist_attr.is_annotated("dims_mapping"))
tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(gelu_out.name)
self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1)
self.assertEqual(tensor_dist_attr.dims_mapping, [-1, -1, -1])
self.assertTrue(tensor_dist_attr.is_annotated("process_mesh"))
self.assertFalse(tensor_dist_attr.is_annotated("dims_mapping"))
dist_op = default_dist_context.get_dist_op_for_program(matmul1)
self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1[1])
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(gelu_out.name)
self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1[1])
self.assertEqual(tensor_dist_attr.dims_mapping, [-1, -1, -1])
self.assertTrue(tensor_dist_attr.is_annotated("process_mesh"))
self.assertFalse(tensor_dist_attr.is_annotated("dims_mapping"))
dist_op = default_dist_context.get_dist_op_for_program(ewise_add1)
self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1[1])
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(
linear1_out.name)
self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1[1])
self.assertEqual(tensor_dist_attr.dims_mapping, [0, -1, -1])
self.assertTrue(tensor_dist_attr.is_annotated("process_mesh"))
self.assertTrue(tensor_dist_attr.is_annotated("dims_mapping"))
if __name__ == '__main__':
unittest.main()
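The assertions above encode the name-to-index rule: under `dim_names=["x", "y"]`, "x" maps to dims_mapping index 0, "y" to 1, and None to -1, both for `shard_tensor` specs and for the per-tensor specs of `shard_op`. A condensed sketch mirroring the calls in this test (the tensors are omitted, so the wrapped calls are left commented):

import paddle.nn.functional as F
import paddle.distributed.auto_parallel as auto

mesh = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"])

# One spec per input; None marks a non-tensor argument (here: approximate).
gelu = auto.shard_op(F.gelu, mesh, [["y", "x", None], None])
# gelu_out = gelu(some_tensor, approximate=True)

# Output specs go through the out_shard_specs keyword, as in the test:
# linear1 = auto.shard_op(layer, mesh[1], out_shard_specs=[["y", None, None]])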
...@@ -26,7 +26,6 @@ import paddle.distributed.fleet as fleet ...@@ -26,7 +26,6 @@ import paddle.distributed.fleet as fleet
from paddle.io import Dataset from paddle.io import Dataset
from paddle.static import InputSpec from paddle.static import InputSpec
from paddle.fluid.framework import _non_static_mode from paddle.fluid.framework import _non_static_mode
from paddle.distributed.auto_parallel.engine import Engine
from test_to_static import MLPLayer, MyDataset from test_to_static import MLPLayer, MyDataset
...@@ -60,14 +59,12 @@ class TestEngineBase(unittest.TestCase): ...@@ -60,14 +59,12 @@ class TestEngineBase(unittest.TestCase):
self.dataset = MyDataset(self.batch_num * self.batch_size) self.dataset = MyDataset(self.batch_num * self.batch_size)
def init_engine(self): def init_engine(self):
inputs = InputSpec([self.batch_size, self.hidden_size], 'float32', 'x') # inputs = InputSpec([self.batch_size, self.hidden_size], 'float32', 'x')
labels = InputSpec([self.batch_size], 'int64', 'label') # labels = InputSpec([self.batch_size], 'int64', 'label')
self.engine = Engine(model=self.mlp, self.engine = auto.Engine(model=self.mlp,
inputs_spec=inputs,
labels_spec=labels)
self.engine.prepare(optimizer=self.optimizer,
loss=self.loss, loss=self.loss,
optimizer=self.optimizer,
metrics=paddle.metric.Accuracy()) metrics=paddle.metric.Accuracy())
...@@ -80,9 +77,9 @@ class TestLRScheduler(TestEngineBase): ...@@ -80,9 +77,9 @@ class TestLRScheduler(TestEngineBase):
def test_lr_scheduler(self): def test_lr_scheduler(self):
self.init_engine() self.init_engine()
lr = self.engine._optimizer._learning_rate
assert isinstance(lr, paddle.optimizer.lr.LRScheduler)
self.engine.fit(self.dataset, batch_size=self.batch_size) self.engine.fit(self.dataset, batch_size=self.batch_size)
lr = self.engine._lr_optimizer._learning_rate
assert isinstance(lr, paddle.optimizer.lr.LRScheduler)
class TestGradClipByGlobalNorm(TestEngineBase): class TestGradClipByGlobalNorm(TestEngineBase):
...@@ -94,7 +91,6 @@ class TestGradClipByGlobalNorm(TestEngineBase): ...@@ -94,7 +91,6 @@ class TestGradClipByGlobalNorm(TestEngineBase):
def test_grad_clip(self): def test_grad_clip(self):
clip = self.engine._optimizer._grad_clip
self.engine.fit(self.dataset, batch_size=self.batch_size) self.engine.fit(self.dataset, batch_size=self.batch_size)
self.check_program() self.check_program()
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import os
import sys
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class TestAMPPass(unittest.TestCase):
def test_mp2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "amp_pass_unittest.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir",
tmp_dir.name, launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import os
import sys
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class TestGradientMergePass(unittest.TestCase):
def test_dp2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir,
"gradient_merge_pass_unittest.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir",
tmp_dir.name, launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
...@@ -14,123 +14,41 @@ ...@@ -14,123 +14,41 @@
import unittest import unittest
import sys import sys
import random
import numpy as np import numpy as np
import paddle import paddle
import paddle.distributed.fleet as fleet
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from get_gpt_model import generate_model, create_data_holder, FakeDataset
from paddle.distributed.auto_parallel.engine import Engine
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion
paddle.enable_static() paddle.enable_static()
class FakeDataset:
def __init__(self, num_samples, sequence_len, vocab_size):
self.num_samples = num_samples
self.sequence_len = sequence_len
self.vocab_size = vocab_size
def __getitem__(self, idx):
tokens = np.random.randint(self.vocab_size, size=self.sequence_len)
position_ids = np.arange(self.sequence_len)
attention_mask = np.tril(np.ones(self.sequence_len)).reshape(
(1, self.sequence_len, self.sequence_len)).astype(np.float32)
labels = np.random.randint(self.vocab_size, size=self.sequence_len)
loss_mask = np.ones(self.sequence_len).astype(np.float32)
return tokens, position_ids, attention_mask, labels, loss_mask
def __len__(self):
return self.num_samples
def apply_pass(): def apply_pass():
dist_strategy = fleet.DistributedStrategy() dist_strategy = auto.Strategy()
dist_strategy.semi_auto = True dist_strategy.auto_mode = "semi"
dist_strategy.qat = True qat = dist_strategy.qat
dist_strategy.qat_configs = { qat.enable = True
'channel_wise_abs_max': True, qat.channel_wise_abs_max = True
'weight_bits': 8, qat.weight_bits = 8
'activation_bits': 8, qat.activation_bits = 8
'not_quant_pattern': ['skip_quant'], qat.not_quant_pattern = ['skip_quant']
}
return dist_strategy return dist_strategy
def create_data_holder(batch_size, sequence_len):
tokens = paddle.static.InputSpec(name="tokens",
shape=[batch_size, sequence_len],
dtype='int64')
position_ids = paddle.static.InputSpec(name="position_ids",
shape=[batch_size, sequence_len],
dtype='int64')
attention_mask = paddle.static.InputSpec(
name="attention_mask",
shape=[batch_size, 1, sequence_len, sequence_len],
dtype='float32')
labels = paddle.static.InputSpec(name="labels",
shape=[batch_size, sequence_len],
dtype='int64')
loss_mask = paddle.static.InputSpec(name="loss_mask",
shape=[batch_size, sequence_len],
dtype='float32')
return [tokens, position_ids, attention_mask], [labels, loss_mask]
def get_gpt_model():
modeling.init_global()
modeling._global_parallel_strategy = "serial"
modeling._global_process_mesh = auto.ProcessMesh(mesh=[0])
gpt = GPTModel(vocab_size=1000,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=8,
intermediate_size=256,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=1024,
type_vocab_size=1,
initializer_range=0.02,
pad_token_id=0,
eos_token_id=7,
bos_token_id=0,
eol_token_id=3)
model = GPTForPretraining(gpt,
vocab_size=1000,
hidden_size=64,
initializer_range=0.02)
criterion = GPTPretrainingCriterion()
return model, criterion
class TestQuantizationPass(unittest.TestCase): class TestQuantizationPass(unittest.TestCase):
def test_qat_pass(self): def test_qat_pass(self):
batch_size = 8 batch_size = 8
batch_num = 10 batch_num = 10
sequence_len = 512
vocab_size = 1000
strategy = apply_pass() strategy = apply_pass()
model, loss = get_gpt_model() model, loss = generate_model("serial")
opt = paddle.optimizer.AdamW(learning_rate=0.00001) opt = paddle.optimizer.AdamW(learning_rate=0.00001)
inputs_spec, labels_spec = create_data_holder(batch_size=batch_size, engine = auto.Engine(model, loss, opt, strategy=strategy)
sequence_len=sequence_len) dataset = FakeDataset(batch_size * batch_num)
engine.fit(dataset, 3, batch_size=batch_size)
engine = Engine(model, inputs_spec, labels_spec, strategy=strategy)
engine.prepare(optimizer=opt, loss=loss)
dataset = FakeDataset(batch_size * batch_num, sequence_len, vocab_size)
engine.fit(train_data=dataset, batch_size=batch_size)
self.check_program(engine.main_program) self.check_program(engine.main_program)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import os
import sys
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class TestRecomputePass(unittest.TestCase):
def test_mp2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "recompute_pass_unittest.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir",
tmp_dir.name, launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import os
import sys
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class TestShardingPass(unittest.TestCase):
def test_dp2sharding2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "sharding_pass_unittest.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir",
tmp_dir.name, launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.static as static
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
batch_size = 4
epoch_num = 10
hidden_size = 1024
sequence_len = 512
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
param_initializer = nn.initializer.Normal(mean=0.0,
std=initializer_range)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.linear0 = nn.Linear(
d_model,
dim_feedforward,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None)
self.linear1 = nn.Linear(
dim_feedforward,
d_model,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None)
def forward(self, input):
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
return out
class TestProcessMesh(unittest.TestCase):
def test_construction(self):
mesh = [[0, 1, 2], [3, 4, 5]]
process_mesh = ProcessMesh(mesh, dim_names=["x", "y"])
self.assertEqual(process_mesh.shape, [2, 3])
self.assertEqual(process_mesh.process_ids, [0, 1, 2, 3, 4, 5])
self.assertEqual(process_mesh.dim_names, ["x", "y"])
self.assertEqual(process_mesh.ndim, 2)
self.assertEqual(process_mesh, process_mesh)
self.assertEqual(str(process_mesh), str(process_mesh))
sub_process_mesh1 = process_mesh[0]
self.assertEqual(sub_process_mesh1.shape, [3])
self.assertEqual(sub_process_mesh1.process_ids, [0, 1, 2])
self.assertEqual(sub_process_mesh1.dim_names, ["y"])
self.assertEqual(sub_process_mesh1.ndim, 1)
sub_process_mesh2 = process_mesh[:, 1]
self.assertEqual(sub_process_mesh2.shape, [2])
self.assertEqual(sub_process_mesh2.process_ids, [1, 4])
self.assertEqual(sub_process_mesh2.dim_names, ["x"])
self.assertEqual(sub_process_mesh2.ndim, 1)
sub_process_mesh3 = sub_process_mesh2[:]
self.assertEqual(sub_process_mesh3.shape, [2])
self.assertEqual(sub_process_mesh3.process_ids, [1, 4])
self.assertEqual(sub_process_mesh3.dim_names, ["x"])
self.assertEqual(sub_process_mesh3.ndim, 1)
sub_process_mesh4 = process_mesh[1, 1]
self.assertEqual(sub_process_mesh4.shape, [1])
self.assertEqual(sub_process_mesh4.process_ids, [4])
self.assertEqual(sub_process_mesh4.dim_names, ["d0"])
self.assertEqual(sub_process_mesh4.ndim, 1)
def test_context_manager(self):
mesh = np.array([1, 2, 3, 4])
input = static.data(name="input",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32')
label = static.data(name="label",
shape=[batch_size, sequence_len, 1],
dtype='float32')
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
with ProcessMesh(mesh, "d"):
out = mlp(input)
default_program = paddle.fluid.default_main_program()
default_dist_context = get_default_distributed_context()
for block in default_program.blocks:
for tensor in block.vars.values():
dist_tensor = default_dist_context.get_dist_tensor_for_program(
tensor)
if dist_tensor is not None:
self.assertEqual(dist_tensor.dist_attr.process_mesh,
ProcessMesh(mesh))
for op in block.ops:
dist_op = default_dist_context.get_dist_op_for_program(op)
if dist_op is not None:
self.assertEqual(dist_op.dist_attr.process_mesh,
ProcessMesh(mesh))
if __name__ == "__main__":
unittest.main()
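The assertions above pin down the ProcessMesh indexing rules: integer indexing drops the indexed dimension together with its name, slicing keeps the dimension, and indexing down to a single process yields a 1-D mesh with the default name "d0". Condensed:

mesh = ProcessMesh([[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"])

row = mesh[0]      # shape [3], process_ids [0, 1, 2], dim_names ["y"]
col = mesh[:, 1]   # shape [2], process_ids [1, 4],    dim_names ["x"]
one = mesh[1, 1]   # shape [1], process_ids [4],       dim_names ["d0"]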
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
# limitations under the License # limitations under the License
import unittest import unittest
from paddle.distributed.auto_parallel.process_mesh_v2 import ProcessMesh from paddle.distributed.auto_parallel.process_mesh_v2 import (
ProcessMesh, compute_compatible_process_mesh, merge_process_mesh)
class TestProcessMesh(unittest.TestCase): class TestProcessMesh(unittest.TestCase):
...@@ -39,6 +40,54 @@ class TestProcessMesh(unittest.TestCase): ...@@ -39,6 +40,54 @@ class TestProcessMesh(unittest.TestCase):
self.assertNotEqual(process_mesh, process_mesh2) self.assertNotEqual(process_mesh, process_mesh2)
self.assertEqual(str(process_mesh), str(process_mesh)) self.assertEqual(str(process_mesh), str(process_mesh))
def test_compute_compatible_process_mesh(self):
process_mesh1 = ProcessMesh([[0, 1, 2], [3, 4, 5]],
dim_names=["x", "y"])
compatible_process_mesh = compute_compatible_process_mesh(
[process_mesh1, None])
self.assertEqual(compatible_process_mesh, process_mesh1)
compatible_process_mesh = compute_compatible_process_mesh(
[None, process_mesh1])
self.assertEqual(compatible_process_mesh, process_mesh1)
process_mesh2 = ProcessMesh([[0, 1, 2], [3, 4, 5]])
compatible_process_mesh = compute_compatible_process_mesh(
[process_mesh1, process_mesh2])
self.assertEqual(compatible_process_mesh, process_mesh1)
self.assertEqual(compatible_process_mesh, process_mesh2)
process_mesh2 = ProcessMesh([[0, 1, 2, 3, 4, 5]])
compatible_process_mesh = compute_compatible_process_mesh(
[process_mesh1, process_mesh2])
self.assertEqual(compatible_process_mesh, process_mesh1)
process_mesh2 = ProcessMesh([[0, 1, 2]])
compatible_process_mesh = compute_compatible_process_mesh(
[process_mesh1, process_mesh2])
self.assertEqual(compatible_process_mesh, process_mesh1)
def test_merge_process_mesh(self):
process_mesh1 = ProcessMesh([[0, 1, 2], [3, 4, 5]],
dim_names=["x", "y"])
merged_process_mesh = merge_process_mesh([process_mesh1, None])
print(merged_process_mesh)
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
merged_process_mesh = merge_process_mesh([None, process_mesh1])
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
process_mesh2 = ProcessMesh([[0, 1, 2], [3, 4, 5]])
merged_process_mesh = merge_process_mesh([process_mesh1, process_mesh2])
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
process_mesh2 = ProcessMesh([[0, 1, 2]])
merged_process_mesh = merge_process_mesh([process_mesh1, process_mesh2])
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
process_mesh2 = ProcessMesh([[6, 7]])
merged_process_mesh = merge_process_mesh([process_mesh1, process_mesh2])
self.assertEqual(merged_process_mesh,
ProcessMesh([0, 1, 2, 3, 4, 5, 6, 7]))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
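Distilled from the assertions: compute_compatible_process_mesh treats None as compatible with anything and ignores dim_names when comparing meshes, while merge_process_mesh always returns the union of the process ids flattened into a 1-D mesh. A sketch:

a = ProcessMesh([[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"])
b = ProcessMesh([[6, 7]])

compute_compatible_process_mesh([a, None])  # -> a
merge_process_mesh([a, None])               # -> ProcessMesh([0, 1, 2, 3, 4, 5])
merge_process_mesh([a, b])                  # -> ProcessMesh([0, 1, 2, 3, 4, 5, 6, 7])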
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
# import yaml
import unittest
import paddle.distributed.auto_parallel as auto
class TestStrategy(unittest.TestCase):
def test_default_config(self):
strategy = auto.Strategy()
recompute = strategy.recompute
self.assertEqual(recompute.enable, False)
self.assertEqual(recompute.checkpoints, None)
amp = strategy.amp
self.assertEqual(amp.enable, False)
self.assertAlmostEqual(amp.init_loss_scaling, 32768.0)
self.assertEqual(amp.incr_every_n_steps, 1000)
self.assertEqual(amp.decr_every_n_nan_or_inf, 2)
self.assertAlmostEqual(amp.incr_ratio, 2.0)
self.assertAlmostEqual(amp.decr_ratio, 0.8)
self.assertEqual(amp.use_dynamic_loss_scaling, True)
self.assertEqual(amp.custom_black_list, [])
self.assertEqual(amp.custom_white_list, [])
self.assertEqual(amp.custom_black_varnames, [])
self.assertEqual(amp.use_pure_fp16, False)
self.assertEqual(amp.use_fp16_guard, True)
self.assertEqual(amp.use_optimizer_fp16, False)
sharding = strategy.sharding
self.assertEqual(sharding.enable, False)
self.assertEqual(sharding.stage, 1)
self.assertEqual(sharding.degree, 8)
self.assertAlmostEqual(sharding.segment_broadcast_MB, 32.0)
self.assertEqual(sharding.enable_tuning, False)
self.assertEqual(sharding.tuning_range, [])
gradient_merge = strategy.gradient_merge
self.assertEqual(gradient_merge.enable, False)
self.assertEqual(gradient_merge.k_steps, 1)
self.assertEqual(gradient_merge.avg, True)
qat = strategy.qat
self.assertEqual(qat.enable, False)
self.assertEqual(qat.channel_wise_abs_max, True)
self.assertEqual(qat.weight_bits, 8)
self.assertEqual(qat.activation_bits, 8)
self.assertEqual(qat.not_quant_pattern, ['skip_quant'])
self.assertEqual(qat.algo, None)
tuning = strategy.tuning
self.assertEqual(tuning.enable, False)
self.assertEqual(tuning.batch_size, 1)
self.assertEqual(tuning.dataset, None)
self.assertEqual(tuning.profile_start_step, 1)
self.assertEqual(tuning.profile_end_step, 1)
self.assertEqual(tuning.run_after_tuning, True)
self.assertEqual(tuning.verbose, True)
def test_modify_config(self):
strategy = auto.Strategy()
recompute = strategy.recompute
recompute.enable = True
recompute.checkpoints = ["x"]
self.assertEqual(recompute.enable, True)
self.assertEqual(recompute.checkpoints, ["x"])
amp = strategy.amp
amp.enable = True
amp.init_loss_scaling = 16384.0
amp.incr_every_n_steps = 2000
amp.decr_every_n_nan_or_inf = 4
amp.incr_ratio = 4.0
amp.decr_ratio = 0.4
amp.use_dynamic_loss_scaling = False
amp.custom_white_list = ["x"]
amp.custom_black_list = ["y"]
amp.custom_black_varnames = ["z"]
amp.use_pure_fp16 = True
amp.use_fp16_guard = False
amp.use_optimizer_fp16 = True
self.assertEqual(amp.enable, True)
self.assertAlmostEqual(amp.init_loss_scaling, 16384.0)
self.assertEqual(amp.incr_every_n_steps, 2000)
self.assertEqual(amp.decr_every_n_nan_or_inf, 4)
self.assertAlmostEqual(amp.incr_ratio, 4.0)
self.assertAlmostEqual(amp.decr_ratio, 0.4)
self.assertEqual(amp.use_dynamic_loss_scaling, False)
self.assertEqual(amp.custom_white_list, ["x"])
self.assertEqual(amp.custom_black_list, ["y"])
self.assertEqual(amp.custom_black_varnames, ["z"])
self.assertEqual(amp.use_pure_fp16, True)
self.assertEqual(amp.use_fp16_guard, False)
self.assertEqual(amp.use_optimizer_fp16, True)
sharding = strategy.sharding
sharding.enable = True
sharding.stage = 2
sharding.degree = 2
sharding.segment_broadcast_MB = 64.0
sharding.enable_tuning = True
sharding.tuning_range = [1, 2, 3]
self.assertEqual(sharding.enable, True)
self.assertEqual(sharding.stage, 2)
self.assertEqual(sharding.degree, 2)
self.assertAlmostEqual(sharding.segment_broadcast_MB, 64.0)
self.assertEqual(sharding.enable_tuning, True)
self.assertEqual(sharding.tuning_range, [1, 2, 3])
gradient_merge = strategy.gradient_merge
gradient_merge.enable = True
gradient_merge.k_steps = 4
gradient_merge.avg = False
self.assertEqual(gradient_merge.enable, True)
self.assertEqual(gradient_merge.k_steps, 4)
self.assertEqual(gradient_merge.avg, False)
# def test_file_config(self):
# yaml_data = """
# all_ranks: false
# amp:
# custom_black_list:
# - y
# custom_black_varnames:
# - z
# custom_white_list:
# - x
# decr_every_n_nan_or_inf: 4
# decr_ratio: 0.4
# enable: false
# incr_every_n_steps: 2000
# incr_ratio: 4.0
# init_loss_scaling: 16384.0
# use_dynamic_loss_scaling: false
# use_fp16_guard: false
# use_optimizer_fp16: true
# use_pure_fp16: true
# auto_mode: semi
# gradient_merge:
# avg: false
# enable: false
# k_steps: 4
# gradient_scale: true
# qat:
# activation_bits: 8
# algo: null
# channel_wise_abs_max: true
# enable: false
# not_quant_pattern:
# - skip_quant
# weight_bits: 8
# recompute:
# checkpoints: null
# enable: false
# enable_tuning: false
# return_numpy: true
# seed: null
# sharding:
# enable: false
# enable_tuning: true
# segment_broadcast_MB: 64.0
# degree: 8
# stage: 2
# tuning_range: null
# split_data: false
# tuning:
# batch_size: 1
# dataset: null
# enable: false
# profile_end_step: 1
# profile_start_step: 1
# run_after_tuning: true
# verbose: true
# use_cache: true
# """
# yaml_path = "./strategy.yml"
# yaml_dict = yaml.load(yaml_data, Loader=yaml.Loader)
# with open(yaml_path, 'w') as outfile:
# yaml.dump(yaml_dict, outfile, default_flow_style=False)
# strategy = auto.Strategy(yaml_path)
# self.assertEqual(yaml_dict, strategy.to_dict())
# # Remove the created file
# if os.path.exists(yaml_path):
# os.remove(yaml_path)
if __name__ == '__main__':
unittest.main()
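The pattern these two tests establish: every pass hangs off the Strategy as a nested config with its own enable flag, so turning one on is plain attribute mutation, for example:

strategy = auto.Strategy()
gm = strategy.gradient_merge
gm.enable = True    # disabled by default
gm.k_steps = 4      # default 1
gm.avg = False      # default True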
...@@ -27,7 +27,6 @@ from paddle import LazyGuard ...@@ -27,7 +27,6 @@ from paddle import LazyGuard
from paddle.io import Dataset from paddle.io import Dataset
from paddle.static import InputSpec from paddle.static import InputSpec
from paddle.fluid.framework import _non_static_mode from paddle.fluid.framework import _non_static_mode
from paddle.distributed.auto_parallel.engine import Engine
from paddle.distributed.auto_parallel.helper import ProgramHelper from paddle.distributed.auto_parallel.helper import ProgramHelper
batch_size = 4 batch_size = 4
...@@ -140,23 +139,19 @@ class TestToStatic(unittest.TestCase): ...@@ -140,23 +139,19 @@ class TestToStatic(unittest.TestCase):
dataset = MyDataset(batch_num * batch_size) dataset = MyDataset(batch_num * batch_size)
inputs = InputSpec([batch_size, hidden_size], 'float32', 'x') # inputs = InputSpec([batch_size, hidden_size], 'float32', 'x')
labels = InputSpec([batch_size], 'int64', 'label') # labels = InputSpec([batch_size], 'int64', 'label')
engine = Engine(model=mlp,
inputs_spec=inputs,
labels_spec=labels,
strategy=None)
assert _non_static_mode() == True assert _non_static_mode() == True
engine = auto.Engine(model=mlp,
engine.prepare(optimizer=optimizer,
loss=loss, loss=loss,
metrics=paddle.metric.Accuracy()) optimizer=optimizer,
metrics=paddle.metric.Accuracy(),
assert _non_static_mode() == False strategy=None)
engine.fit(dataset, batch_size=batch_size) engine.fit(dataset, batch_size=batch_size)
engine.evaluate(dataset, batch_size=batch_size) engine.evaluate(dataset, batch_size=batch_size)
engine.predict(dataset, batch_size=batch_size) engine.predict(dataset, batch_size=batch_size)
assert _non_static_mode() == False
class TestLazyInit(unittest.TestCase): class TestLazyInit(unittest.TestCase):
......
...@@ -36,7 +36,7 @@ batch_size = 4 ...@@ -36,7 +36,7 @@ batch_size = 4
epoch_num = 10 epoch_num = 10
hidden_size = 1024 hidden_size = 1024
sequence_len = 512 sequence_len = 512
_g_process_mesh = [[0, 1], [2, 3]] _g_process_mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y'])
def get_random_inputs_and_labels(input_shape, label_shape): def get_random_inputs_and_labels(input_shape, label_shape):
...@@ -84,18 +84,12 @@ class MLPLayer(nn.Layer): ...@@ -84,18 +84,12 @@ class MLPLayer(nn.Layer):
def forward(self, input): def forward(self, input):
out = self.norm(input) out = self.norm(input)
auto.shard_tensor(self.linear0.weight, auto.shard_tensor(self.linear0.weight, _g_process_mesh[:, 0],
dist_attr={ [None, 'x'])
"process_mesh": _g_process_mesh[0],
"dims_mapping": [-1, 0]
})
out = self.linear0(out) out = self.linear0(out)
out = F.gelu(out, approximate=True) out = F.gelu(out, approximate=True)
auto.shard_tensor(self.linear1.weight, auto.shard_tensor(self.linear1.weight, _g_process_mesh[:, 1],
dist_attr={ ['x', None])
"process_mesh": _g_process_mesh[1],
"dims_mapping": [0, -1]
})
out = self.linear1(out) out = self.linear1(out)
return out return out
...@@ -155,16 +149,8 @@ def get_program(): ...@@ -155,16 +149,8 @@ def get_program():
dataloader.set_batch_generator(batch_generator_creator(), dataloader.set_batch_generator(batch_generator_creator(),
places=paddle.static.cuda_places()) places=paddle.static.cuda_places())
# data dist_attr # data dist_attr
auto.shard_tensor(input, auto.shard_tensor(input, _g_process_mesh[:, 0], [None, None, None])
dist_attr={ auto.shard_tensor(label, _g_process_mesh[:, 0], [None, None, None])
"process_mesh": _g_process_mesh[0],
"dims_mapping": [-1, -1, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [-1, -1, -1]
})
mlp_start = MLPLayer(hidden_size=hidden_size, mlp_start = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size, intermediate_size=4 * hidden_size,
......
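The rewritten annotations above use the new shard_tensor signature, shard_tensor(tensor, process_mesh, dims_mapping), where the mapping names one mesh dimension per tensor axis (None means replicated) instead of the old integer dims_mapping inside a dist_attr dict. A sketch against a hypothetical weight of shape [d_model, dim_feedforward]:

mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y'])

# old: dist_attr={"process_mesh": ..., "dims_mapping": [-1, 0]}
# new: shard the second axis of weight along mesh dimension 'x'
auto.shard_tensor(weight, mesh[:, 0], [None, 'x'])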
...@@ -37,7 +37,7 @@ batch_size = 4 ...@@ -37,7 +37,7 @@ batch_size = 4
epoch_num = 10 epoch_num = 10
hidden_size = 1024 hidden_size = 1024
sequence_len = 512 sequence_len = 512
_g_process_mesh = auto.ProcessMesh([0, 1]) _g_process_mesh = auto.ProcessMesh([0, 1], dim_names=['x'])
def get_random_inputs_and_labels(input_shape, label_shape): def get_random_inputs_and_labels(input_shape, label_shape):
...@@ -85,61 +85,21 @@ class MLPLayer(nn.Layer): ...@@ -85,61 +85,21 @@ class MLPLayer(nn.Layer):
def forward(self, input): def forward(self, input):
auto.shard_tensor(self.norm.weight, auto.shard_tensor(self.norm.weight, _g_process_mesh, [None])
dist_attr={ auto.shard_tensor(self.norm.bias, _g_process_mesh, [None])
"process_mesh": _g_process_mesh, auto.shard_tensor(self.linear0.weight, _g_process_mesh, [None, 'x'])
"dims_mapping": [-1] auto.shard_tensor(self.linear0.bias, _g_process_mesh, ['x'])
}) auto.shard_tensor(self.linear1.weight, _g_process_mesh, ['x', None])
auto.shard_tensor(self.norm.bias, auto.shard_tensor(self.linear1.bias, _g_process_mesh, [None])
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1]
})
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.linear0.bias,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [0]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.linear1.bias,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1]
})
out = self.norm(input) out = self.norm(input)
auto.shard_tensor(out, auto.shard_tensor(out, _g_process_mesh, [None, None, None])
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
out = self.linear0(out) out = self.linear0(out)
auto.shard_tensor(out, auto.shard_tensor(out, _g_process_mesh, [None, None, 'x'])
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, 0]
})
out = F.gelu(out, approximate=True) out = F.gelu(out, approximate=True)
auto.shard_tensor(out, auto.shard_tensor(out, _g_process_mesh, [None, None, 'x'])
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, 0]
})
out = self.linear1(out) out = self.linear1(out)
auto.shard_tensor(out, auto.shard_tensor(out, _g_process_mesh, [None, None, None])
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
return out return out
...@@ -155,21 +115,13 @@ def get_program(): ...@@ -155,21 +115,13 @@ def get_program():
# loop counter # loop counter
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0) i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
auto.shard_tensor(i, auto.shard_tensor(i, _g_process_mesh, [None])
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1]
})
# loop count # loop count
loop_len = fluid.layers.fill_constant(shape=[1], loop_len = fluid.layers.fill_constant(shape=[1],
dtype='int64', dtype='int64',
value=epoch_num) value=epoch_num)
auto.shard_tensor(loop_len, auto.shard_tensor(loop_len, _g_process_mesh, [None])
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1]
})
# input # input
input = static.data(name="input", input = static.data(name="input",
...@@ -188,25 +140,13 @@ def get_program(): ...@@ -188,25 +140,13 @@ def get_program():
dataloader.set_batch_generator(batch_generator_creator(), dataloader.set_batch_generator(batch_generator_creator(),
places=paddle.static.cuda_places()) places=paddle.static.cuda_places())
# data dist_attr # data dist_attr
auto.shard_tensor(input, auto.shard_tensor(input, _g_process_mesh, [None, None, None])
dist_attr={ auto.shard_tensor(label, _g_process_mesh, [None, None, None])
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
# fill constant bsz like # fill constant bsz like
tmp = paddle.fluid.layers.fill_constant_batch_size_like( tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=input, shape=[-1, 16, 0, 48], dtype='float32', value=0) input=input, shape=[-1, 16, 0, 48], dtype='float32', value=0)
auto.shard_tensor(tmp, auto.shard_tensor(tmp, _g_process_mesh, [None, 'x', None, None])
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, 0, -1, -1]
})
# model # model
mlp_start = MLPLayer(hidden_size=hidden_size, mlp_start = MLPLayer(hidden_size=hidden_size,
...@@ -216,28 +156,21 @@ def get_program(): ...@@ -216,28 +156,21 @@ def get_program():
pred = mlp_start(input) pred = mlp_start(input)
input_array = fluid.layers.array_write(pred, i) input_array = fluid.layers.array_write(pred, i)
auto.shard_tensor(input_array, # TODO: check whether this annotation is needed
dist_attr={ # auto.shard_tensor(input_array,
"process_mesh": _g_process_mesh, # dist_attr={
"dims_mapping": [-1, -1, -1] # "process_mesh": _g_process_mesh,
}) # "dims_mapping": [-1, -1, -1]
# })
cond = fluid.layers.less_than(x=i, y=loop_len) cond = fluid.layers.less_than(x=i, y=loop_len)
auto.shard_tensor(cond, auto.shard_tensor(cond, _g_process_mesh, [None])
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1]
})
while_op = fluid.layers.While(cond=cond) while_op = fluid.layers.While(cond=cond)
with while_op.block(): with while_op.block():
pre_input = fluid.layers.array_read(array=input_array, i=i) pre_input = fluid.layers.array_read(array=input_array, i=i)
auto.shard_tensor(pre_input, auto.shard_tensor(pre_input, _g_process_mesh, [None, None, None])
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
mlp_while = MLPLayer(hidden_size=hidden_size, mlp_while = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size, intermediate_size=4 * hidden_size,
...@@ -251,11 +184,7 @@ def get_program(): ...@@ -251,11 +184,7 @@ def get_program():
fluid.layers.less_than(x=i, y=loop_len, cond=cond) fluid.layers.less_than(x=i, y=loop_len, cond=cond)
end_pred = fluid.layers.array_read(array=input_array, i=i) end_pred = fluid.layers.array_read(array=input_array, i=i)
auto.shard_tensor(end_pred, auto.shard_tensor(end_pred, _g_process_mesh, [None, None, None])
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
mlp_end = MLPLayer(hidden_size=hidden_size, mlp_end = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size, intermediate_size=4 * hidden_size,
...@@ -264,18 +193,10 @@ def get_program(): ...@@ -264,18 +193,10 @@ def get_program():
pred = mlp_end(end_pred) pred = mlp_end(end_pred)
error_cost = paddle.nn.functional.square_error_cost(pred, label) error_cost = paddle.nn.functional.square_error_cost(pred, label)
auto.shard_tensor(error_cost, auto.shard_tensor(error_cost, _g_process_mesh, [None, None, None])
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
loss = paddle.mean(error_cost) loss = paddle.mean(error_cost)
auto.shard_tensor(loss, auto.shard_tensor(loss, _g_process_mesh, [None])
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1]
})
return train_program, start_program, dataloader, i, loss return train_program, start_program, dataloader, i, loss
......
...@@ -67,38 +67,18 @@ class MLPLayer(nn.Layer): ...@@ -67,38 +67,18 @@ class MLPLayer(nn.Layer):
def forward(self, input): def forward(self, input):
if _global_parallel_strategy == "pp": if _global_parallel_strategy == "pp":
auto.shard_tensor(self.linear0.weight, auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None])
dist_attr={ auto.shard_tensor(self.linear1.weight, PP_MESH_1, [None, None])
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
elif _global_parallel_strategy == "mp": elif _global_parallel_strategy == "mp":
auto.shard_tensor(self.linear0.weight, auto.shard_tensor(self.linear0.weight, _global_process_mesh,
dist_attr={ [None, "x"])
"process_mesh": _global_process_mesh, auto.shard_tensor(self.linear1.weight, _global_process_mesh,
"dims_mapping": [-1, 0] ["x", None])
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp": elif _global_parallel_strategy == "dp":
auto.shard_tensor(self.linear0.weight, auto.shard_tensor(self.linear0.weight, _global_process_mesh,
dist_attr={ [None, None])
"process_mesh": _global_process_mesh, auto.shard_tensor(self.linear1.weight, _global_process_mesh,
"dims_mapping": [-1, -1] [None, None])
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
out = self.norm(input) out = self.norm(input)
out = self.linear0(out) out = self.linear0(out)
...@@ -120,28 +100,12 @@ def mlp_forward(train_program, start_program): ...@@ -120,28 +100,12 @@ def mlp_forward(train_program, start_program):
dtype='float32') dtype='float32')
if _global_parallel_strategy == "pp": if _global_parallel_strategy == "pp":
auto.shard_tensor(input, auto.shard_tensor(input, PP_MESH_0, [None, None])
dist_attr={ auto.shard_tensor(label, PP_MESH_1, [None, None])
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
elif _global_parallel_strategy == "dp": elif _global_parallel_strategy == "dp":
auto.shard_tensor(input, auto.shard_tensor(input, _global_process_mesh, ["x", None])
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "mp": elif _global_parallel_strategy == "mp":
auto.shard_tensor(input, auto.shard_tensor(input, _global_process_mesh, [None, None])
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
mlp = MLPLayer(hidden_size=hidden_size, mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size, intermediate_size=4 * hidden_size,
...@@ -186,7 +150,7 @@ class TestMLPAutoConvert(unittest.TestCase): ...@@ -186,7 +150,7 @@ class TestMLPAutoConvert(unittest.TestCase):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "mp" _global_parallel_strategy = "mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh([0, 1]) _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
input = np.random.random(size=(80, 64)).astype('float32') input = np.random.random(size=(80, 64)).astype('float32')
label = np.random.random(size=(80, 1)).astype('float32') label = np.random.random(size=(80, 1)).astype('float32')
...@@ -212,11 +176,11 @@ class TestMLPAutoConvert(unittest.TestCase): ...@@ -212,11 +176,11 @@ class TestMLPAutoConvert(unittest.TestCase):
set_default_distributed_context(None) set_default_distributed_context(None)
_global_parallel_strategy = "pp" _global_parallel_strategy = "pp"
_global_process_mesh = auto.ProcessMesh([0, 1]) _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
global PP_MESH_0 global PP_MESH_0
PP_MESH_0 = auto.ProcessMesh(mesh=[0]) PP_MESH_0 = auto.ProcessMesh(mesh=[0], dim_names=["pp0"])
global PP_MESH_1 global PP_MESH_1
PP_MESH_1 = auto.ProcessMesh(mesh=[1]) PP_MESH_1 = auto.ProcessMesh(mesh=[1], dim_names=["pp1"])
dist_main_prog_load, dist_start_prog_load, loss_load = get_distributed_program( dist_main_prog_load, dist_start_prog_load, loss_load = get_distributed_program(
) )
...@@ -268,7 +232,7 @@ class TestMLPAutoConvert2(unittest.TestCase): ...@@ -268,7 +232,7 @@ class TestMLPAutoConvert2(unittest.TestCase):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "pp" _global_parallel_strategy = "pp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh([0, 1]) _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
global PP_MESH_0 global PP_MESH_0
PP_MESH_0 = auto.ProcessMesh(mesh=[0]) PP_MESH_0 = auto.ProcessMesh(mesh=[0])
global PP_MESH_1 global PP_MESH_1
...@@ -303,7 +267,7 @@ class TestMLPAutoConvert2(unittest.TestCase): ...@@ -303,7 +267,7 @@ class TestMLPAutoConvert2(unittest.TestCase):
set_default_distributed_context(None) set_default_distributed_context(None)
_global_parallel_strategy = "mp" _global_parallel_strategy = "mp"
_global_process_mesh = auto.ProcessMesh([0, 1]) _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
dist_main_prog_load, dist_start_prog_load, loss_load = get_distributed_program( dist_main_prog_load, dist_start_prog_load, loss_load = get_distributed_program(
) )
...@@ -350,7 +314,7 @@ class TestMLPAutoConvertInvalid(unittest.TestCase): ...@@ -350,7 +314,7 @@ class TestMLPAutoConvertInvalid(unittest.TestCase):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "mp" _global_parallel_strategy = "mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh([0, 1]) _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
dist_main_prog, _, _ = get_distributed_program() dist_main_prog, _, _ = get_distributed_program()
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
save_distributed_checkpoint(dist_main_prog, [""], [""], save_distributed_checkpoint(dist_main_prog, [""], [""],
......
...@@ -38,7 +38,7 @@ class TestDataUnshard(unittest.TestCase): ...@@ -38,7 +38,7 @@ class TestDataUnshard(unittest.TestCase):
def create_model(train_program, start_program): def create_model(train_program, start_program):
with paddle.static.program_guard(train_program, start_program): with paddle.static.program_guard(train_program, start_program):
MESH_0 = auto.ProcessMesh([0, 1]) MESH_0 = auto.ProcessMesh([0, 1], dim_names=["x"])
input = paddle.static.data(name='input', shape=[2, 8]) input = paddle.static.data(name='input', shape=[2, 8])
label = paddle.static.data(name='label', shape=[2, 8]) label = paddle.static.data(name='label', shape=[2, 8])
...@@ -47,26 +47,10 @@ class TestDataUnshard(unittest.TestCase): ...@@ -47,26 +47,10 @@ class TestDataUnshard(unittest.TestCase):
linear0 = nn.Linear(8, 8, weight_attr) linear0 = nn.Linear(8, 8, weight_attr)
linear1 = nn.Linear(8, 8, weight_attr) linear1 = nn.Linear(8, 8, weight_attr)
auto.shard_tensor(input, auto.shard_tensor(input, MESH_0, ["x", None])
dist_attr={ auto.shard_tensor(label, MESH_0, ["x", None])
"process_mesh": MESH_0, auto.shard_tensor(linear0.weight, MESH_0, [None, None])
"dims_mapping": [0, -1] auto.shard_tensor(linear1.weight, MESH_0, [None, None])
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [0, -1]
})
auto.shard_tensor(linear0.weight,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(linear1.weight,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [-1, -1]
})
linear0_out = linear0(input) linear0_out = linear0(input)
gelu_out = F.gelu(linear0_out) gelu_out = F.gelu(linear0_out)
...@@ -124,7 +108,7 @@ class TestDataUnshard(unittest.TestCase): ...@@ -124,7 +108,7 @@ class TestDataUnshard(unittest.TestCase):
def create_model(train_program, start_program): def create_model(train_program, start_program):
with paddle.static.program_guard(train_program, start_program): with paddle.static.program_guard(train_program, start_program):
MESH_0 = auto.ProcessMesh([0, 1]) MESH_0 = auto.ProcessMesh([0, 1], dim_names=["x"])
input = paddle.static.data(name='input', shape=[8, 8]) input = paddle.static.data(name='input', shape=[8, 8])
label = paddle.static.data(name='label', shape=[8, 8]) label = paddle.static.data(name='label', shape=[8, 8])
...@@ -133,27 +117,10 @@ class TestDataUnshard(unittest.TestCase): ...@@ -133,27 +117,10 @@ class TestDataUnshard(unittest.TestCase):
linear0 = nn.Linear(8, 8, weight_attr) linear0 = nn.Linear(8, 8, weight_attr)
linear1 = nn.Linear(8, 8, weight_attr) linear1 = nn.Linear(8, 8, weight_attr)
auto.shard_tensor(input, auto.shard_tensor(input, MESH_0, [None, None])
dist_attr={ auto.shard_tensor(label, MESH_0, [None, None])
"process_mesh": MESH_0, auto.shard_tensor(linear0.weight, MESH_0, [None, "x"])
"dims_mapping": [-1, -1] auto.shard_tensor(linear1.weight, MESH_0, ["x", None])
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(linear0.weight,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(linear1.weight,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [0, -1]
})
linear0_out = linear0(input) linear0_out = linear0(input)
gelu_out = F.gelu(linear0_out) gelu_out = F.gelu(linear0_out)
......
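Taken together, the two create_model variants above contrast the two annotation schemes on the same 2-layer MLP: data parallelism shards the batch axis of the inputs and replicates the weights, while model parallelism replicates the data and column/row-shards the two linears. Side by side, with the names from the tests:

mesh = auto.ProcessMesh([0, 1], dim_names=["x"])

# data parallel: split the batch, replicate weights
auto.shard_tensor(input, mesh, ["x", None])
auto.shard_tensor(linear0.weight, mesh, [None, None])

# model parallel: replicate data, shard the weights
auto.shard_tensor(input, mesh, [None, None])
auto.shard_tensor(linear0.weight, mesh, [None, "x"])  # column parallel
auto.shard_tensor(linear1.weight, mesh, ["x", None])  # row parallel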
...@@ -114,30 +114,18 @@ class MultiHeadAttention(nn.Layer): ...@@ -114,30 +114,18 @@ class MultiHeadAttention(nn.Layer):
""" """
q = self.q_proj(query) q = self.q_proj(query)
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor(self.q_proj.weight, auto.shard_tensor(self.q_proj.weight, _global_process_mesh,
dist_attr={ [None, "x"])
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.q_proj.weight, auto.shard_tensor(self.q_proj.weight, _global_process_mesh,
dist_attr={ [None, "y"])
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
elif _global_parallel_strategy == "mp_pp": elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(self.q_proj.weight, auto.shard_tensor(self.q_proj.weight, MPPP_MESH_LIST[self.mesh_idx],
dist_attr={ [None, "x"])
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp_pp": elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(self.q_proj.weight, auto.shard_tensor(self.q_proj.weight,
dist_attr={ DPMPPP_MESH_LIST[self.mesh_idx], [None, "y"])
"process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1]
})
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
if isinstance(cache, self.StaticCache): if isinstance(cache, self.StaticCache):
...@@ -165,56 +153,30 @@ class MultiHeadAttention(nn.Layer): ...@@ -165,56 +153,30 @@ class MultiHeadAttention(nn.Layer):
""" """
k = self.k_proj(key) k = self.k_proj(key)
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor(self.k_proj.weight, auto.shard_tensor(self.k_proj.weight, _global_process_mesh,
dist_attr={ [None, "x"])
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.k_proj.weight, auto.shard_tensor(self.k_proj.weight, _global_process_mesh,
dist_attr={ [None, "y"])
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
elif _global_parallel_strategy == "mp_pp": elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(self.k_proj.weight, auto.shard_tensor(self.k_proj.weight, MPPP_MESH_LIST[self.mesh_idx],
dist_attr={ [None, "x"])
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp_pp": elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(self.k_proj.weight, auto.shard_tensor(self.k_proj.weight,
dist_attr={ DPMPPP_MESH_LIST[self.mesh_idx], [None, "y"])
"process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1]
})
v = self.v_proj(value) v = self.v_proj(value)
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor(self.v_proj.weight, auto.shard_tensor(self.v_proj.weight, _global_process_mesh,
dist_attr={ [None, "x"])
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.v_proj.weight, auto.shard_tensor(self.v_proj.weight, _global_process_mesh,
dist_attr={ [None, "y"])
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
elif _global_parallel_strategy == "mp_pp": elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(self.v_proj.weight, auto.shard_tensor(self.v_proj.weight, MPPP_MESH_LIST[self.mesh_idx],
dist_attr={ [None, "x"])
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp_pp": elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(self.v_proj.weight, auto.shard_tensor(self.v_proj.weight,
dist_attr={ DPMPPP_MESH_LIST[self.mesh_idx], [None, "y"])
"process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1]
})
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
...@@ -287,30 +249,18 @@ class MultiHeadAttention(nn.Layer): ...@@ -287,30 +249,18 @@ class MultiHeadAttention(nn.Layer):
# project to output # project to output
out = self.out_proj(out) out = self.out_proj(out)
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor(self.out_proj.weight, auto.shard_tensor(self.out_proj.weight, _global_process_mesh,
dist_attr={ ["x", None])
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.out_proj.weight, auto.shard_tensor(self.out_proj.weight, _global_process_mesh,
dist_attr={ ["y", None])
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
elif _global_parallel_strategy == "mp_pp": elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(self.out_proj.weight, auto.shard_tensor(self.out_proj.weight,
dist_attr={ MPPP_MESH_LIST[self.mesh_idx], ["x", None])
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp_pp": elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(self.out_proj.weight, auto.shard_tensor(self.out_proj.weight,
dist_attr={ DPMPPP_MESH_LIST[self.mesh_idx], ["y", None])
"process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [1, -1]
})
outs = [out] outs = [out]
if self.need_weights: if self.need_weights:
outs.append(weights) outs.append(weights)
...@@ -352,96 +302,53 @@ class TransformerDecoder(nn.Layer): ...@@ -352,96 +302,53 @@ class TransformerDecoder(nn.Layer):
new_caches = [] new_caches = []
self.checkpoints = [] self.checkpoints = []
if _global_parallel_strategy == "pp": if _global_parallel_strategy == "pp":
auto.shard_tensor(output, auto.shard_tensor(output, PP_MESH_LIST[0],
dist_attr={ [None for i in range(len(output.shape))])
"process_mesh":
PP_MESH_LIST[0],
"dims_mapping":
[-1 for i in range(len(output.shape))]
})
if _global_parallel_strategy == "dp_pp": if _global_parallel_strategy == "dp_pp":
auto.shard_tensor(output, auto.shard_tensor(output, DPPP_MESH_LIST[0], ["x"] +
dist_attr={ [None for i in range(len(output.shape) - 1)])
"process_mesh":
DPPP_MESH_LIST[0],
"dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
})
if _global_parallel_strategy == "mp_pp": if _global_parallel_strategy == "mp_pp":
auto.shard_tensor(output, auto.shard_tensor(output, MPPP_MESH_LIST[0],
dist_attr={ [None for i in range(len(output.shape))])
"process_mesh":
MPPP_MESH_LIST[0],
"dims_mapping": [-1] +
[-1 for i in range(len(output.shape) - 1)]
})
if _global_parallel_strategy == "dp_mp_pp": if _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(output, auto.shard_tensor(output, DPMPPP_MESH_LIST[0], ["x"] +
dist_attr={ [None for i in range(len(output.shape) - 1)])
"process_mesh":
DPMPPP_MESH_LIST[0],
"dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
})
for i, mod in enumerate(self.layers): for i, mod in enumerate(self.layers):
if cache is None: if cache is None:
if use_cache: if use_cache:
if _global_parallel_strategy == "pp": if _global_parallel_strategy == "pp":
output, new_cache = auto.shard_op( output, new_cache = auto.shard_op(
mod, mod, PP_MESH_LIST[mod.mesh_idx])(output, memory,
dist_attr={ tgt_mask,
"process_mesh": PP_MESH_LIST[mod.mesh_idx] use_cache, cache)
})(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor( auto.shard_tensor(
output, output, PP_MESH_LIST[mod.mesh_idx],
dist_attr={ [None for i in range(len(output.shape))])
"process_mesh":
PP_MESH_LIST[mod.mesh_idx],
"dims_mapping":
[-1 for i in range(len(output.shape))]
})
elif _global_parallel_strategy == "dp_pp": elif _global_parallel_strategy == "dp_pp":
output, new_cache = auto.shard_op( output, new_cache = auto.shard_op(
mod, mod, DPPP_MESH_LIST[mod.mesh_idx])(output, memory,
dist_attr={ tgt_mask,
"process_mesh": DPPP_MESH_LIST[mod.mesh_idx] use_cache, cache)
})(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor( auto.shard_tensor(
output, output, DPPP_MESH_LIST[mod.mesh_idx], ["x"] +
dist_attr={ [None for i in range(len(output.shape) - 1)])
"process_mesh":
DPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
})
elif _global_parallel_strategy == "mp_pp": elif _global_parallel_strategy == "mp_pp":
output, new_cache = auto.shard_op( output, new_cache = auto.shard_op(
mod, mod, MPPP_MESH_LIST[mod.mesh_idx])(output, memory,
dist_attr={ tgt_mask,
"process_mesh": MPPP_MESH_LIST[mod.mesh_idx] use_cache, cache)
})(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor( auto.shard_tensor(
output, output, MPPP_MESH_LIST[mod.mesh_idx],
dist_attr={ [None for i in range(len(output.shape))])
"process_mesh":
MPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [-1] +
[-1 for i in range(len(output.shape) - 1)]
})
elif _global_parallel_strategy == "dp_mp_pp": elif _global_parallel_strategy == "dp_mp_pp":
output, new_cache = auto.shard_op( output, new_cache = auto.shard_op(
mod, mod,
dist_attr={ DPMPPP_MESH_LIST[mod.mesh_idx])(output, memory,
"process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx] tgt_mask, use_cache,
})(output, memory, tgt_mask, use_cache, cache) cache)
auto.shard_tensor( auto.shard_tensor(
output, output, DPMPPP_MESH_LIST[mod.mesh_idx],
dist_attr={ [None for i in range(len(output.shape))])
"process_mesh":
DPMPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
})
else: else:
output, new_cache = mod(output, output, new_cache = mod(output,
memory, memory,
...@@ -451,64 +358,36 @@ class TransformerDecoder(nn.Layer): ...@@ -451,64 +358,36 @@ class TransformerDecoder(nn.Layer):
new_caches.append(new_cache) new_caches.append(new_cache)
else: else:
if _global_parallel_strategy == "pp": if _global_parallel_strategy == "pp":
output = auto.shard_op(mod, output = auto.shard_op(mod, PP_MESH_LIST[mod.mesh_idx])(
dist_attr={ output, memory, tgt_mask, use_cache, cache)
"process_mesh":
PP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask,
use_cache, cache)
auto.shard_tensor( auto.shard_tensor(
output, output, PP_MESH_LIST[mod.mesh_idx],
dist_attr={ [None for i in range(len(output.shape))])
"process_mesh":
PP_MESH_LIST[mod.mesh_idx],
"dims_mapping":
[-1 for i in range(len(output.shape))]
})
elif _global_parallel_strategy == "dp_pp": elif _global_parallel_strategy == "dp_pp":
output = auto.shard_op(mod, output = auto.shard_op(
dist_attr={ mod, DPPP_MESH_LIST[mod.mesh_idx])(output, memory,
"process_mesh": tgt_mask,
DPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask,
use_cache, cache) use_cache, cache)
auto.shard_tensor( auto.shard_tensor(
output, output, DPPP_MESH_LIST[mod.mesh_idx], ["x"] +
dist_attr={ [None for i in range(len(output.shape) - 1)])
"process_mesh":
DPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
})
elif _global_parallel_strategy == "mp_pp": elif _global_parallel_strategy == "mp_pp":
output = auto.shard_op(mod, output = auto.shard_op(
dist_attr={ mod, MPPP_MESH_LIST[mod.mesh_idx])(output, memory,
"process_mesh": tgt_mask,
MPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask,
use_cache, cache) use_cache, cache)
auto.shard_tensor( auto.shard_tensor(
output, output, MPPP_MESH_LIST[mod.mesh_idx],
dist_attr={ [None for i in range(len(output.shape))])
"process_mesh":
MPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [-1] +
[-1 for i in range(len(output.shape) - 1)]
})
elif _global_parallel_strategy == "dp_mp_pp": elif _global_parallel_strategy == "dp_mp_pp":
output = auto.shard_op( output = auto.shard_op(mod,
mod, DPMPPP_MESH_LIST[mod.mesh_idx])(
dist_attr={ output, memory, tgt_mask,
"process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx] use_cache, cache)
})(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor( auto.shard_tensor(
output, output, DPMPPP_MESH_LIST[mod.mesh_idx],
dist_attr={ ["x"] +
"process_mesh": [None for i in range(len(output.shape) - 1)])
DPMPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
})
else: else:
output = mod(output, output = mod(output,
memory, memory,
...@@ -519,58 +398,33 @@ class TransformerDecoder(nn.Layer): ...@@ -519,58 +398,33 @@ class TransformerDecoder(nn.Layer):
if _global_parallel_strategy == "pp": if _global_parallel_strategy == "pp":
output, new_cache = auto.shard_op( output, new_cache = auto.shard_op(
mod, mod,
dist_attr={"process_mesh": PP_MESH_LIST[mod.mesh_idx] PP_MESH_LIST[mod.mesh_idx])(output, memory, tgt_mask,
})(output, memory, tgt_mask, use_cache, use_cache, cache)
cache) auto.shard_tensor(output, PP_MESH_LIST[mod.mesh_idx],
auto.shard_tensor( [None for i in range(len(output.shape))])
output,
dist_attr={
"process_mesh": PP_MESH_LIST[mod.mesh_idx],
"dims_mapping":
[-1 for i in range(len(output.shape))]
})
elif _global_parallel_strategy == "dp_pp": elif _global_parallel_strategy == "dp_pp":
output, new_cache = auto.shard_op( output, new_cache = auto.shard_op(
mod, mod,
dist_attr={ DPPP_MESH_LIST[mod.mesh_idx])(output, memory, tgt_mask,
"process_mesh": DPPP_MESH_LIST[mod.mesh_idx] use_cache, cache)
})(output, memory, tgt_mask, use_cache, cache) auto.shard_tensor(output, DPPP_MESH_LIST[mod.mesh_idx],
auto.shard_tensor( ["x"] +
output, [None for i in range(len(output.shape) - 1)])
dist_attr={
"process_mesh":
DPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping":
[0] + [-1 for i in range(len(output.shape) - 1)]
})
elif _global_parallel_strategy == "mp_pp": elif _global_parallel_strategy == "mp_pp":
output, new_cache = auto.shard_op( output, new_cache = auto.shard_op(
mod, mod,
dist_attr={ MPPP_MESH_LIST[mod.mesh_idx])(output, memory, tgt_mask,
"process_mesh": MPPP_MESH_LIST[mod.mesh_idx] use_cache, cache)
})(output, memory, tgt_mask, use_cache, cache) auto.shard_tensor(output, MPPP_MESH_LIST[mod.mesh_idx],
auto.shard_tensor( [None for i in range(len(output.shape))])
output,
dist_attr={
"process_mesh":
MPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping":
[-1] + [-1 for i in range(len(output.shape) - 1)]
})
elif _global_parallel_strategy == "dp_mp_pp": elif _global_parallel_strategy == "dp_mp_pp":
output, new_cache = auto.shard_op( output, new_cache = auto.shard_op(
mod, mod, DPMPPP_MESH_LIST[mod.mesh_idx])(output, memory,
dist_attr={ tgt_mask,
"process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx] use_cache, cache)
})(output, memory, tgt_mask, use_cache, cache) auto.shard_tensor(output, DPMPPP_MESH_LIST[mod.mesh_idx],
auto.shard_tensor( ["x"] +
output, [None for i in range(len(output.shape) - 1)])
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping":
[0] + [-1 for i in range(len(output.shape) - 1)]
})
else: else:
output, new_cache = mod(output, output, new_cache = mod(output,
memory, memory,
...@@ -661,55 +515,30 @@ class TransformerDecoderLayer(nn.Layer): ...@@ -661,55 +515,30 @@ class TransformerDecoderLayer(nn.Layer):
if self.normalize_before: if self.normalize_before:
tgt = self.norm2(tgt) tgt = self.norm2(tgt)
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor(self.linear1.weight, auto.shard_tensor(self.linear1.weight, _global_process_mesh,
dist_attr={ [None, "x"])
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.linear1.weight, auto.shard_tensor(self.linear1.weight, _global_process_mesh,
dist_attr={ [None, "y"])
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
elif _global_parallel_strategy == "mp_pp": elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(self.linear1.weight, auto.shard_tensor(self.linear1.weight,
dist_attr={ MPPP_MESH_LIST[self.mesh_idx], [None, "x"])
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0]
})
if _global_parallel_strategy == "dp_mp_pp": if _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(self.linear1.weight, auto.shard_tensor(self.linear1.weight,
dist_attr={ DPMPPP_MESH_LIST[self.mesh_idx], [None, "y"])
"process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1]
})
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor(self.linear2.weight, auto.shard_tensor(self.linear2.weight, _global_process_mesh,
dist_attr={ ["x", None])
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.linear2.weight, auto.shard_tensor(self.linear2.weight, _global_process_mesh,
dist_attr={ ["y", None])
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
elif _global_parallel_strategy == "mp_pp": elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(self.linear2.weight, auto.shard_tensor(self.linear2.weight,
dist_attr={ MPPP_MESH_LIST[self.mesh_idx], ["x", None])
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp_pp": elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(self.linear2.weight, auto.shard_tensor(self.linear2.weight,
dist_attr={ DPMPPP_MESH_LIST[self.mesh_idx], ["y", None])
"process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [1, -1]
})
tgt = self.dropout2( tgt = self.dropout2(
self.linear2(F.gelu(self.linear1(tgt), approximate=True))) self.linear2(F.gelu(self.linear1(tgt), approximate=True)))
tgt = residual + tgt tgt = residual + tgt
...@@ -757,29 +586,18 @@ class GPTEmbeddings(nn.Layer): ...@@ -757,29 +586,18 @@ class GPTEmbeddings(nn.Layer):
position_ids = seq_length - ones position_ids = seq_length - ones
input_embedings = self.word_embeddings(input_ids) input_embedings = self.word_embeddings(input_ids)
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor(self.word_embeddings.weight, auto.shard_tensor(self.word_embeddings.weight, _global_process_mesh,
dist_attr={ ["x", None])
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.word_embeddings.weight, auto.shard_tensor(self.word_embeddings.weight, _global_process_mesh,
dist_attr={ ["y", None])
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
elif _global_parallel_strategy == "mp_pp": elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(self.word_embeddings.weight, auto.shard_tensor(self.word_embeddings.weight, MPPP_MESH_LIST[0],
dist_attr={ ["x", None])
"process_mesh": MPPP_MESH_LIST[0],
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp_pp": elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(self.word_embeddings.weight, auto.shard_tensor(self.word_embeddings.weight, DPMPPP_MESH_LIST[0],
dist_attr={ ["y", None])
"process_mesh": DPMPPP_MESH_LIST[0],
"dims_mapping": [1, -1]
})
position_embeddings = self.position_embeddings(position_ids) position_embeddings = self.position_embeddings(position_ids)
embeddings = input_embedings + position_embeddings embeddings = input_embedings + position_embeddings
embeddings = self.dropout(embeddings) embeddings = self.dropout(embeddings)
...@@ -868,29 +686,14 @@ class GPTModel(nn.Layer): ...@@ -868,29 +686,14 @@ class GPTModel(nn.Layer):
embedding_output = self.embeddings(input_ids=input_ids, embedding_output = self.embeddings(input_ids=input_ids,
position_ids=position_ids) position_ids=position_ids)
if _global_parallel_strategy == "pp": if _global_parallel_strategy == "pp":
auto.shard_tensor(input_ids, auto.shard_tensor(input_ids, PP_MESH_LIST[0],
dist_attr={ [None for i in range(len(input_ids.shape))])
"process_mesh":
PP_MESH_LIST[0],
"dims_mapping":
[-1 for i in range(len(input_ids.shape))]
})
if _global_parallel_strategy == "dp_pp": if _global_parallel_strategy == "dp_pp":
auto.shard_tensor(input_ids, auto.shard_tensor(input_ids, DPPP_MESH_LIST[0], ["x"] +
dist_attr={ [None for i in range(len(input_ids.shape) - 1)])
"process_mesh":
DPPP_MESH_LIST[0],
"dims_mapping": [0] +
[-1 for i in range(len(input_ids.shape) - 1)]
})
if _global_parallel_strategy == "dp_mp_pp": if _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(input_ids, auto.shard_tensor(input_ids, DPMPPP_MESH_LIST[0], ["x"] +
dist_attr={ [None for i in range(len(input_ids.shape) - 1)])
"process_mesh":
DPMPPP_MESH_LIST[0],
"dims_mapping": [0] +
[-1 for i in range(len(input_ids.shape) - 1)]
})
encoder_outputs = self.decoder(embedding_output, encoder_outputs = self.decoder(embedding_output,
memory=None, memory=None,
tgt_mask=attention_mask, tgt_mask=attention_mask,
...@@ -923,6 +726,10 @@ class GPTForPretraining(nn.Layer):
                masked_positions=None,
                use_cache=False,
                cache=None):
        input_ids.stop_gradient = True
        position_ids.stop_gradient = True
        attention_mask.stop_gradient = True
        outputs = self.gpt(input_ids,
                           position_ids=position_ids,
                           attention_mask=attention_mask,
...@@ -936,40 +743,42 @@ class GPTForPretraining(nn.Layer):
        x = encoder_outputs
        w = self.gpt.embeddings.word_embeddings.weight
        mesh = None
        if _global_parallel_strategy == "pp":
            mesh = PP_MESH_LIST[-1]
            x_dims_mapping = [None for i in range(len(x.shape))]
            w_dims_mapping = [None for i in range(len(w.shape))]
        elif _global_parallel_strategy == "dp":
            mesh = _global_process_mesh
            x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)]
            w_dims_mapping = [None for i in range(len(w.shape))]
        elif _global_parallel_strategy == "mp":
            mesh = _global_process_mesh
            x_dims_mapping = [None for i in range(len(x.shape))]
            w_dims_mapping = ["x"] + [None for i in range(len(w.shape) - 1)]
        elif _global_parallel_strategy == "dp_mp":
            mesh = _global_process_mesh
            x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)]
            w_dims_mapping = ["y"] + [None for i in range(len(w.shape) - 1)]
        elif _global_parallel_strategy == "dp_pp":
            mesh = DPPP_MESH_LIST[-1]
            x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)]
            w_dims_mapping = [None for i in range(len(w.shape))]
        elif _global_parallel_strategy == "mp_pp":
            mesh = MPPP_MESH_LIST[-1]
            x_dims_mapping = [None for i in range(len(x.shape))]
            w_dims_mapping = ["x"] + [None for i in range(len(w.shape) - 1)]
        elif _global_parallel_strategy == "dp_mp_pp":
            mesh = DPMPPP_MESH_LIST[-1]
            x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)]
            w_dims_mapping = ["y"] + [None for i in range(len(w.shape) - 1)]
        if mesh:
            matmul = auto.shard_op(paddle.matmul, mesh,
                                   [x_dims_mapping, w_dims_mapping, None])
            logits = matmul(x, w, transpose_y=True)
        else:
            logits = paddle.matmul(x, w, transpose_y=True)
        if use_cache:
            return logits, cached_kvs
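`auto.shard_op` follows the same convention: mesh first, then one shard spec per input, with a trailing None mirroring the call above to leave that input unannotated. A condensed sketch under the same illustrative setup as the earlier example:

    x = paddle.static.data(name="x", shape=[8, 64], dtype="float32")
    w2 = paddle.static.data(name="w2", shape=[1024, 64], dtype="float32")
    mesh2 = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
    # One spec per input of paddle.matmul; the returned callable is then
    # invoked exactly like the wrapped op.
    matmul = auto.shard_op(paddle.matmul, mesh2,
                           [["x", None], ["y", None], None])
    logits = matmul(x, w2, transpose_y=True)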
...@@ -988,25 +797,29 @@ class GPTPretrainingCriterion(nn.Layer):
        self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none")

    def forward(self, prediction_scores, masked_lm_labels, loss_mask):
        masked_lm_labels.stop_gradient = True
        loss_mask.stop_gradient = True
        mesh = None
        if _global_parallel_strategy == "dp":
            mesh = _global_process_mesh
            dims_mapping = ["x"] + [None for i in range(len(loss_mask.shape) - 1)]
        elif _global_parallel_strategy == "dp_mp":
            mesh = _global_process_mesh
            dims_mapping = ["x"] + [None for i in range(len(loss_mask.shape) - 1)]
        elif _global_parallel_strategy == "dp_pp":
            mesh = DPPP_MESH_LIST[-1]
            dims_mapping = ["x"] + [None for i in range(len(loss_mask.shape) - 1)]
        elif _global_parallel_strategy == "dp_mp_pp":
            mesh = DPMPPP_MESH_LIST[-1]
            dims_mapping = ["x"] + [None for i in range(len(loss_mask.shape) - 1)]
        if mesh:
            auto.shard_tensor(loss_mask, mesh, dims_mapping)

        masked_lm_loss = self.loss_func(prediction_scores,
                                        masked_lm_labels.unsqueeze(2))
......
...@@ -64,38 +64,18 @@ class MLPLayer(nn.Layer):
    def forward(self, input):
        if _global_parallel_strategy == "pp":
            auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None])
            auto.shard_tensor(self.linear1.weight, PP_MESH_1, [None, None])
        elif _global_parallel_strategy == "mp":
            auto.shard_tensor(self.linear0.weight, _global_process_mesh,
                              [None, "x"])
            auto.shard_tensor(self.linear1.weight, _global_process_mesh,
                              ["x", None])
        elif _global_parallel_strategy == "dp":
            auto.shard_tensor(self.linear0.weight, _global_process_mesh,
                              [None, None])
            auto.shard_tensor(self.linear1.weight, _global_process_mesh,
                              [None, None])

        out = self.norm(input)
        out = self.linear0(out)
...@@ -119,28 +99,12 @@ def mlp_forward(train_program, start_program):
                            dtype='float32')

    if _global_parallel_strategy == "pp":
        auto.shard_tensor(input, PP_MESH_0, [None, None])
        auto.shard_tensor(label, PP_MESH_1, [None, None])
    elif _global_parallel_strategy == "dp":
        auto.shard_tensor(input, _global_process_mesh, ["x", None])
    elif _global_parallel_strategy == "mp":
        auto.shard_tensor(input, _global_process_mesh, [None, None])

    mlp = MLPLayer(hidden_size=hidden_size,
                   intermediate_size=4 * hidden_size,
...@@ -183,7 +147,7 @@ class TestMLPSaveLoad(unittest.TestCase):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
        dist_main_prog, dist_start_prog, loss = get_distributed_program()

        place = paddle.set_device("gpu")
...@@ -230,7 +194,7 @@ class TestMLPSaveLoad(unittest.TestCase):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
        dist_main_prog, dist_start_prog, loss = get_distributed_program()
...@@ -278,11 +242,11 @@ class TestMLPSaveLoad(unittest.TestCase):
        global _global_parallel_strategy
        _global_parallel_strategy = "pp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
        global PP_MESH_0
        PP_MESH_0 = auto.ProcessMesh(mesh=[0], dim_names=["x"])
        global PP_MESH_1
        PP_MESH_1 = auto.ProcessMesh(mesh=[1], dim_names=["x"])
        dist_main_prog, dist_start_prog, loss = get_distributed_program()
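Throughout these tests, `ProcessMesh` now carries `dim_names`, and a shard spec refers to mesh dimensions by those names rather than by integer index. A short sketch of the correspondence (names and sizes chosen for illustration):

    # 2 x 4 mesh: dim "dp" spans the rows, dim "mp" spans the columns.
    mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]],
                            dim_names=["dp", "mp"])
    w = paddle.static.data(name="w", shape=[1024, 1024], dtype="float32")
    # [None, "mp"]: columns of w are split across mesh dim "mp" and rows stay
    # replicated -- the named replacement for the old dims_mapping [-1, 0].
    auto.shard_tensor(w, mesh, [None, "mp"])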
......
...@@ -82,11 +82,7 @@ def mlp_pretrain_forward(train_program, start_program):
                              shape=[batch_size, sequence_len, 1],
                              dtype='float32')

    auto.shard_tensor(input, _global_process_mesh, [None, None, None])

    mlp = MLPLayer(hidden_size=hidden_size,
                   intermediate_size=4 * hidden_size,
...@@ -106,7 +102,7 @@ class TestMLPAutoParallelizer(unittest.TestCase):
    def test_mlp_serial(self):
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
        dist_strategy = fleet.DistributedStrategy()
        dist_strategy.amp = False
......
...@@ -86,7 +86,7 @@ class AutoPallelPassTestBase(DistPassTestBase):
                                     paddle.static.Program()):
            with paddle.static.scope_guard(scope):
                with paddle.fluid.unique_name.guard():
                    main_prog, startup_prog, inputs, outputs, data_loader = self.get_model(
                        place, **kwargs)
                    inputs = self._to_var_names(inputs)
                    outputs = self._to_var_names(outputs)
...@@ -95,27 +95,57 @@ class AutoPallelPassTestBase(DistPassTestBase):
        exe = paddle.static.Executor(place)
        with paddle.static.scope_guard(scope):
            exe.run(startup_prog)
            data_loader.start()
            batch_id = 0
            while True:
                try:
                    fetch_values = exe.run(main_prog, fetch_list=outputs)
                    if paddle.distributed.get_rank() == 0:
                        output_dict = OrderedDict(zip(outputs, fetch_values))
                        print('batch {}, outputs {}'.format(
                            batch_id, output_dict))
                    all_fetch_values.append(fetch_values)
                    batch_id += 1
                except paddle.fluid.core.EOFException:
                    data_loader.reset()
                    break
        with open(dump_file, "wb") as f:
            pickle.dump(all_fetch_values, f)
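The hand-rolled feed-dict loop is replaced above by a non-iterable `DataLoader`: the generator is attached once, `start()` arms it, and exhaustion surfaces as an `EOFException` that must be answered with `reset()`. The skeleton of that pattern, condensed (placeholder and generator names are illustrative):

    data_loader = paddle.fluid.io.DataLoader.from_generator(
        feed_list=[x_var, y_var],  # placeholders from paddle.static.data
        capacity=70,
        iterable=False)            # non-iterable: batches are fed internally
    data_loader.set_batch_generator(gen_batches, paddle.static.cuda_places())

    data_loader.start()
    while True:
        try:
            exe.run(main_prog, fetch_list=[loss])  # no feed= argument needed
        except paddle.fluid.core.EOFException:
            data_loader.reset()  # rearm the loader for the next pass
            break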
    def get_gpt_model(self, strategy, place, batch_size, sequence_len,
                      vocab_size, **kwargs):

        def gen_data():
            np.random.seed(2021)
            for _ in range(10):
                tokens = []
                position_ids = []
                attention_mask = []
                labels = []
                loss_mask = []
                for _ in range(batch_size):
                    tokens.append(
                        np.random.randint(vocab_size,
                                          size=sequence_len).astype("int64"))
                    position_ids.append(np.arange(sequence_len).astype("int64"))
                    attention_mask.append(
                        [np.tril(np.ones(sequence_len)).astype("float32")])
                    labels.append(
                        np.random.randint(vocab_size,
                                          size=sequence_len).astype("int64"))
                    loss_mask.append(np.ones(sequence_len).astype("float32"))
                yield tokens, position_ids, attention_mask, labels, loss_mask
        modeling.init_global()
        if strategy == "dp":
            modeling._global_parallel_strategy = "dp"
            modeling._global_process_mesh = auto.ProcessMesh(mesh=[0, 1],
                                                             dim_names=["x"])
        elif strategy == "mp":
            modeling._global_parallel_strategy = "mp"
            modeling._global_process_mesh = auto.ProcessMesh(mesh=[0, 1],
                                                             dim_names=["x"])
        else:
            raise ValueError("'get_gpt_model' only supports dp and mp.")
...@@ -137,23 +167,17 @@ class AutoPallelPassTestBase(DistPassTestBase):
                                dtype='float32')
        data_holder = [tokens, position_ids, attention_mask, labels, loss_mask]
        data_loader = paddle.fluid.io.DataLoader.from_generator(
            feed_list=data_holder, capacity=70, iterable=False)
        data_loader.set_batch_generator(gen_data, paddle.static.cuda_places())

        if modeling._global_parallel_strategy == "dp":
            auto.shard_tensor(tokens, modeling._global_process_mesh,
                              ["x", None])
        elif modeling._global_parallel_strategy == "pp":
            auto.shard_tensor(tokens, modeling.PP_MESH_LIST[0], [None, None])
            auto.shard_tensor(attention_mask, modeling.PP_MESH_LIST[0],
                              [None, None, None, None])
        gpt = GPTModel(vocab_size=1000,
                       hidden_size=64,
...@@ -179,38 +203,20 @@ class AutoPallelPassTestBase(DistPassTestBase):
        criterion = GPTPretrainingCriterion()
        loss = criterion(preds, labels, loss_mask)
        clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
        if kwargs.get('optimizer', None) == "LarsMomentum":
            optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer(
                learning_rate=0.001, momentum=0.9)
        else:
            optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
                                              beta1=0.9,
                                              beta2=0.999,
                                              epsilon=1e-08,
                                              grad_clip=clip)
        optimizer = fleet.distributed_optimizer(optimizer)
        startup_program = paddle.static.default_startup_program()
        _, _, dist_startup_prog, dist_main_prog = optimizer.minimize(
            loss, startup_program)
        return dist_main_prog, dist_startup_prog, data_holder, [loss], data_loader
...@@ -20,10 +20,19 @@ import unittest
import paddle
import paddle.distributed.fleet as fleet
from auto_parallel_pass_test_base import AutoPallelPassTestBase


class TestPF16Pass(AutoPallelPassTestBase):

    def init(self):
        if paddle.is_compiled_with_cuda():
            paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
        self.rtol = 1e-5
        self.atol = 1e-8
        paddle.seed(2021)
        random.seed(2021)
        np.random.seed(2021)
    def apply_passes(self):
        dist_strategy = fleet.DistributedStrategy()
...@@ -34,14 +43,30 @@ class TestPF16Pass(TestAMPPass):
                'layer_norm',
                'gelu',
            ],
            "custom_black_list":
            ['c_softmax_with_cross_entropy', 'elementwise_div', 'reduce_sum'],
            "init_loss_scaling": 32768,
            "use_dynamic_loss_scaling": True,
            "use_pure_fp16": True,
            "use_fp16_guard": False
        }
        dist_strategy.semi_auto = True
        fleet.init(is_collective=True, strategy=dist_strategy)
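For context, the fp16 pass under test is driven entirely by the strategy configuration; assuming the dict above is assigned to `dist_strategy.amp_configs` with `dist_strategy.amp = True` (the assignment sits in lines elided here), the full setup reads roughly as this sketch:

    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.amp = True
    dist_strategy.amp_configs = {
        "custom_white_list": ['layer_norm', 'gelu'],  # list truncated above
        "custom_black_list":
        ['c_softmax_with_cross_entropy', 'elementwise_div', 'reduce_sum'],
        "init_loss_scaling": 32768,
        "use_dynamic_loss_scaling": True,
        "use_pure_fp16": True,
        "use_fp16_guard": False
    }
    dist_strategy.semi_auto = True
    fleet.init(is_collective=True, strategy=dist_strategy)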
    def test_bs_8(self):
        self.check_main(gpus=[0, 1],
                        batch_size=8,
                        sequence_len=512,
                        vocab_size=1000)

    def get_model(self, place, batch_size, sequence_len, vocab_size):
        return self.get_gpt_model("mp", place, batch_size, sequence_len,
                                  vocab_size)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -97,11 +97,8 @@ class MLPLayer(nn.Layer):
def mlp_forward(input, label, hidden_size):
    auto.shard_tensor(input, auto.ProcessMesh([0], dim_names=["x"]),
                      [None, None])
    mlp = MLPLayer(hidden_size=hidden_size,
                   intermediate_size=4 * hidden_size,
                   initializer_range=0.02)
...@@ -160,6 +157,12 @@ class TestGradientMergePass(AutoPallelPassTestBase):
    def get_model(self, place, batch_size, hidden_size, max_step):

        def gen_data():
            for i in range(max_step):
                x_data = input_data[i * batch_size:(i + 1) * batch_size, :]
                y_data = label_data[i * batch_size:(i + 1) * batch_size, :]
                yield x_data, y_data

        train_program = static.Program()
        startup_program = static.Program()
        with static.program_guard(train_program, startup_program), \
...@@ -171,6 +174,12 @@ class TestGradientMergePass(AutoPallelPassTestBase):
                                      shape=[batch_size, 1],
                                      dtype='float32')
            input.stop_gradient = False
            data_holder = [input, label]
            data_loader = paddle.fluid.io.DataLoader.from_generator(
                feed_list=data_holder, capacity=70, iterable=False)
            data_loader.set_batch_generator(gen_data,
                                            paddle.static.cuda_places())
            loss = mlp_forward(input, label, hidden_size)
            optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.01)
...@@ -181,13 +190,8 @@ class TestGradientMergePass(AutoPallelPassTestBase):
        input_data = np.random.random(size=(128, hidden_size)).astype('float32')
        label_data = np.random.random(size=(128, 1)).astype('float32')

        return dist_main_prog, dist_startup_prog, [input, label], [loss], data_loader
if __name__ == "__main__": if __name__ == "__main__":
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import paddle.fluid as fluid
import paddle.nn as nn
import paddle.distributed as dist
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
paddle.enable_static()
process_mesh1 = [0, 1, 2, 3]
process_mesh2 = [[0, 1, 2], [3, 4, 5]]
class SimpleNet(nn.Layer):

    def __init__(self, vocab_size=128, hidden_size=4):
        super(SimpleNet, self).__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.dense1 = nn.Linear(hidden_size, hidden_size)
        self.dense2 = nn.Linear(hidden_size, hidden_size // 2)

    def forward(self, x, y):
        # Test shard_tensor interface with dist_attr arg
        x = dist.shard_tensor(x,
                              dist_attr={
                                  "process_mesh": process_mesh1,
                                  "dims_mapping": [0, -1]
                              })
        emb_out = self.word_embeddings(x)
        # Test shard_tensor interface with no dist_attr arg
        y = dist.shard_tensor(y)
        linear1 = self.dense1(y)
        out = self.dense2(linear1)
        return x, y
class TestAutoParallelAPI(unittest.TestCase):

    def test_api(self):
        dist_context = get_default_distributed_context()
        net = SimpleNet()
        data1 = fluid.layers.fill_constant(shape=[2, 4], value=1, dtype="int64")
        data2 = fluid.layers.fill_constant(shape=[2, 4],
                                           value=2,
                                           dtype="float32")
        data3 = fluid.layers.fill_constant(shape=[2, 4],
                                           value=4,
                                           dtype="float32")
        x, y = net.forward(data1, data2)

        dist_x = dist_context.get_dist_tensor_for_program(x)
        self.assertEqual(dist_x.dist_attr.process_mesh.processes, process_mesh1)
        self.assertEqual(dist_x.dist_attr.dims_mapping, [0, -1])
        self.assertEqual(dist_x.dist_attr.shard_sizes, None)
        self.assertEqual(dist_x.dist_attr.device_placement, None)
        self.assertTrue(dist_x.dist_attr.is_annotated("process_mesh"))
        self.assertTrue(dist_x.dist_attr.is_annotated("dims_mapping"))
        self.assertFalse(dist_x.dist_attr.is_annotated("shard_sizes"))
        self.assertFalse(dist_x.dist_attr.is_annotated("device_placement"))

        dist_y = dist_context.get_dist_tensor_for_program(y)
        self.assertEqual(dist_y.dist_attr.process_mesh, None)
        self.assertEqual(dist_y.dist_attr.dims_mapping, [-1, -1])
        self.assertEqual(dist_y.dist_attr.shard_sizes, None)
        self.assertEqual(dist_y.dist_attr.device_placement, None)
        self.assertFalse(dist_y.dist_attr.is_annotated("process_mesh"))
        self.assertFalse(dist_y.dist_attr.is_annotated("dims_mapping"))
        self.assertFalse(dist_y.dist_attr.is_annotated("shard_sizes"))
        self.assertFalse(dist_y.dist_attr.is_annotated("device_placement"))

        # Test shard_op interface with dist_attr
        dims_mapping1 = [0, 1]
        dims_mapping2 = [-1, 0]
        dist_add = dist.shard_op(paddle.add,
                                 dist_attr={
                                     data2: {
                                         "process_mesh": process_mesh2,
                                         "dims_mapping": dims_mapping1
                                     },
                                     data3: {
                                         "dims_mapping": dims_mapping2
                                     }
                                 })
        results = dist_add(data2, data3)
        ops = paddle.static.default_main_program().block(0).ops
        last_op = ops[-1]
        dist_op = dist_context.get_dist_op_for_program(last_op)
        self.assertEqual(dist_op.dist_attr.process_mesh,
                         ProcessMesh(process_mesh2))
        self.assertEqual(dist_op.dist_attr.impl_type, "default")
        self.assertEqual(dist_op.dist_attr.impl_idx, 0)
        self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))

        data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name)
        self.assertEqual(data2_dist_attr.process_mesh,
                         dist_op.dist_attr.process_mesh)
        self.assertEqual(data2_dist_attr.dims_mapping, dims_mapping1)
        self.assertEqual(data2_dist_attr.shard_sizes, None)
        self.assertEqual(data2_dist_attr.device_placement, None)
        self.assertTrue(data2_dist_attr.is_annotated("process_mesh"))
        self.assertTrue(data2_dist_attr.is_annotated("dims_mapping"))
        self.assertFalse(data2_dist_attr.is_annotated("shard_sizes"))
        self.assertFalse(data2_dist_attr.is_annotated("device_placement"))

        data3_dist_attr = dist_op.dist_attr.get_input_dist_attr(data3.name)
        self.assertEqual(data3_dist_attr.process_mesh,
                         dist_op.dist_attr.process_mesh)
        self.assertEqual(data3_dist_attr.dims_mapping, dims_mapping2)
        self.assertEqual(data3_dist_attr.shard_sizes, None)
        self.assertEqual(data3_dist_attr.device_placement, None)
        self.assertTrue(data3_dist_attr.is_annotated("process_mesh"))
        self.assertTrue(data3_dist_attr.is_annotated("dims_mapping"))
        self.assertFalse(data3_dist_attr.is_annotated("shard_sizes"))
        self.assertFalse(data3_dist_attr.is_annotated("device_placement"))

        # Test shard_op interface without dist_attr
        dist_add = dist.shard_op(paddle.add)
        results = dist_add(data2, data3)
        ops = paddle.static.default_main_program().block(0).ops
        last_op = ops[-1]
        dist_op = dist_context.get_dist_op_for_program(last_op)
        self.assertEqual(dist_op.dist_attr.process_mesh, None)
        self.assertEqual(dist_op.dist_attr.impl_type, "default")
        self.assertEqual(dist_op.dist_attr.impl_idx, 0)
        self.assertFalse(dist_op.dist_attr.is_annotated("process_mesh"))

        data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name)
        self.assertEqual(data2_dist_attr.process_mesh,
                         dist_op.dist_attr.process_mesh)
        self.assertEqual(data2_dist_attr.dims_mapping, [-1, -1])
        self.assertEqual(data2_dist_attr.shard_sizes, None)
        self.assertEqual(data2_dist_attr.device_placement, None)
        self.assertFalse(data2_dist_attr.is_annotated("process_mesh"))
        self.assertFalse(data2_dist_attr.is_annotated("dims_mapping"))
        self.assertFalse(data2_dist_attr.is_annotated("shard_sizes"))
        self.assertFalse(data2_dist_attr.is_annotated("device_placement"))

        data3_dist_attr = dist_op.dist_attr.get_input_dist_attr(data3.name)
        self.assertEqual(data3_dist_attr.process_mesh,
                         dist_op.dist_attr.process_mesh)
        self.assertEqual(data3_dist_attr.dims_mapping, [-1, -1])
        self.assertEqual(data3_dist_attr.shard_sizes, None)
        self.assertEqual(data3_dist_attr.device_placement, None)
        self.assertFalse(data3_dist_attr.is_annotated("process_mesh"))
        self.assertFalse(data3_dist_attr.is_annotated("dims_mapping"))
        self.assertFalse(data3_dist_attr.is_annotated("shard_sizes"))
        self.assertFalse(data3_dist_attr.is_annotated("device_placement"))


if __name__ == '__main__':
    unittest.main()
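This file exercises the pre-redesign interface; the rest of the commit moves to positional mesh and shard-spec arguments. A side-by-side sketch of the two styles for the same annotation (illustrative only; `auto` here stands for the redesigned module):

    # Old style, as tested above:
    x = dist.shard_tensor(x,
                          dist_attr={
                              "process_mesh": process_mesh1,
                              "dims_mapping": [0, -1]
                          })
    # New style, used throughout this commit:
    mesh = auto.ProcessMesh(process_mesh1, dim_names=["x"])
    x = auto.shard_tensor(x, mesh, ["x", None])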
...@@ -66,39 +66,13 @@ class MLPLayer(nn.Layer):
        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")

    def forward(self, input):
        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.linear0.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.linear1.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=["mp", None])

        out = self.norm(input)
        out = self.linear0(out)
...@@ -119,18 +93,10 @@ def mlp_pretrain_forward(train_program, start_program):
                        shape=[batch_size, sequence_len, hidden_size],
                        dtype='float32')

    if _global_parallel_strategy in ["dp", "dp_mp"]:
        auto.shard_tensor(input,
                          process_mesh=_global_process_mesh,
                          shard_spec=["dp", None, None])

    mlp = MLPLayer(hidden_size=hidden_size,
                   intermediate_size=4 * hidden_size,
...@@ -146,7 +112,8 @@ class TestMLPAutoCompletion(unittest.TestCase):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["dp"])
        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
...@@ -161,7 +128,8 @@ class TestMLPAutoCompletion(unittest.TestCase):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["mp"])
        train_program = static.Program()
        start_program = static.Program()
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "dp_mp" _global_parallel_strategy = "dp_mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) [4, 5, 6, 7]],
dim_names=["dp", "mp"])
train_program = static.Program() train_program = static.Program()
start_program = static.Program() start_program = static.Program()
...@@ -286,18 +255,10 @@ class AttentionLayer(nn.Layer):
                                bias_attr=bias_attr)

    def forward(self, input):
        if _global_parallel_strategy in ["dp", "dp_mp"]:
            auto.shard_tensor(input,
                              process_mesh=_global_process_mesh,
                              shard_spec=["dp", None, None])

        q = self.q_proj(input)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
...@@ -306,38 +267,16 @@ class AttentionLayer(nn.Layer):
        k = self.k_proj(input)
        v = self.v_proj(input)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.q_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.k_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.v_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
...@@ -369,18 +308,10 @@ class AttentionLayer(nn.Layer):
        # project to output
        out = self.out_proj(out)
        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.out_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=["mp", None])

        return out
...@@ -411,7 +342,8 @@ class TestAttentionAutoCompletion(unittest.TestCase):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["dp"])
        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
...@@ -420,15 +352,14 @@ class TestAttentionAutoCompletion(unittest.TestCase):
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_attn_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["mp"])
        train_program = static.Program()
        start_program = static.Program()
...@@ -444,8 +375,9 @@ class TestAttentionAutoCompletion(unittest.TestCase):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
                                                      [4, 5, 6, 7]],
                                                dim_names=["dp", "mp"])
        train_program = static.Program()
        start_program = static.Program()
...@@ -542,34 +474,18 @@ class DecoderLayer(nn.Layer):
        self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")

    def forward(self, input_ids, position_ids):
        if _global_parallel_strategy in ["dp", "dp_mp"]:
            auto.shard_tensor(input_ids,
                              process_mesh=_global_process_mesh,
                              shard_spec=["dp", None])

        input_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.word_embeddings.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=["mp", None])

        embeddings = input_embeddings + position_embeddings
        embeddings = self.dropout1(embeddings)
...@@ -585,38 +501,16 @@ class DecoderLayer(nn.Layer):
        k = self.k_proj(target)
        v = self.v_proj(target)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.q_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.k_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.v_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
...@@ -649,18 +543,10 @@ class DecoderLayer(nn.Layer):
        # project to output
        out = self.out_proj(out)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.out_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=["mp", None])

        # Add residual
        residual = embeddings + self.dropout2(out)
...@@ -673,28 +559,13 @@ class DecoderLayer(nn.Layer):
        out2 = F.gelu(out1, approximate=True)
        out3 = self.linear1(out2)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.linear0.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.linear1.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=["mp", None])

        # Add residual
        final = residual + self.dropout3(out3)
...@@ -732,7 +603,8 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["dp"])
        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
...@@ -747,7 +619,8 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["mp"])
        train_program = static.Program()
        start_program = static.Program()
...@@ -763,8 +636,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
                                                      [4, 5, 6, 7]],
                                                dim_names=["dp", "mp"])
        train_program = static.Program()
        start_program = static.Program()
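Each completion test above drives the same sequence; condensed, with the classes these tests already import, it is essentially:

    dist_context = DistributedContext()
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    # Every tensor and op should now carry a consistent dist_attr.
    assert dist_context.validate_dist_attr_for_program()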
......
...@@ -116,18 +116,10 @@ class MultiHeadAttention(nn.Layer):
        """
        q = self.q_proj(query)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.q_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])

        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
...@@ -158,34 +150,15 @@ class MultiHeadAttention(nn.Layer):
        to construct cache for inference.
        """
        k = self.k_proj(key)
        v = self.v_proj(value)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.k_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.v_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
...@@ -265,18 +238,10 @@ class MultiHeadAttention(nn.Layer):
        # project to output
        out = self.out_proj(out)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.out_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=["mp", None])

        outs = [out]
        if self.need_weights:
...@@ -439,31 +404,13 @@ class TransformerDecoderLayer(nn.Layer):
        if self.normalize_before:
            tgt = self.norm2(tgt)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.linear1.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.linear2.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=["mp", None])

        # tgt = self.dropout2(
        #     self.linear2(F.gelu(
...@@ -523,18 +470,10 @@ class GPTEmbeddings(nn.Layer):
        input_embedings = self.word_embeddings(input_ids)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.word_embeddings.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=["mp", None])

        position_embeddings = self.position_embeddings(position_ids)
        embeddings = input_embedings + position_embeddings
...@@ -757,18 +696,10 @@ def gpt_pretrain_forward(train_program, start_program):
                            shape=[batch_size, sequence_len],
                            dtype='float64')

    if _global_parallel_strategy in ["dp", "dp_mp"]:
        auto.shard_tensor(input_ids,
                          process_mesh=_global_process_mesh,
                          shard_spec=["dp", None])

    gpt = GPTModel(vocab_size=32768,
                   hidden_size=1024,
...@@ -801,7 +732,8 @@ class TestGPTAutoCompletion(unittest.TestCase):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["dp"])
        train_program = static.Program()
        start_program = static.Program()
...@@ -817,7 +749,8 @@ class TestGPTAutoCompletion(unittest.TestCase):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["mp"])
        train_program = static.Program()
        start_program = static.Program()
...@@ -833,8 +766,9 @@ class TestGPTAutoCompletion(unittest.TestCase):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
                                                      [4, 5, 6, 7]],
                                                dim_names=["dp", "mp"])
        train_program = static.Program()
        start_program = static.Program()
......
...@@ -35,8 +35,8 @@ from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()

_global_parallel_strategy = "dp_mp_pp"
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], dim_names=["x", "y"])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], dim_names=["x", "y"])
NUM_RANKS = 8
STAGE_0_CNT = 5
STAGE_1_CNT = 10
...@@ -73,16 +73,8 @@ class MLPLayer(nn.Layer):
    def forward(self, input):
        if self.is_distributed:
            auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None])
            auto.shard_tensor(self.linear1.weight, PP_MESH_1, ["y", None])

        out = self.norm(input)
        out = self.linear0(out)
...@@ -135,16 +127,8 @@ def mlp_forward(train_program, start_program, is_distributed=True):
                        dtype='float32')

    if is_distributed:
        auto.shard_tensor(input, PP_MESH_0, ["x", None])
        auto.shard_tensor(label, PP_MESH_1, ["x", None])

    mlp = MLPLayer(hidden_size=hidden_size,
                   intermediate_size=4 * hidden_size,
......
...@@ -71,7 +71,7 @@ class TestDistributedTensor(unittest.TestCase):
    def test_new_local_tensor(self):
        test_auto_parallel_reshard._global_process_mesh = auto.ProcessMesh(
            mesh=[0, 1], dim_names=["x"])
        test_auto_parallel_reshard._global_parallel_strategy = "dp"
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
......
...@@ -414,37 +414,25 @@ class MLPLayer(nn.Layer):
    def forward(self, input):
        if _global_parallel_strategy == "dp_mp_pp":
            auto.shard_tensor(self.linear0.weight, _global_process_mesh[0],
                              [None, "y"])
            auto.shard_tensor(self.linear1.weight, _global_process_mesh[0],
                              ["y", None])
            auto.shard_tensor(self.linear2.weight, _global_process_mesh[1],
                              [None, "y"])
            auto.shard_tensor(self.linear3.weight, _global_process_mesh[1],
                              ["y", None])

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)

        auto.shard_tensor(out, _global_process_mesh[1], ["x", None])

        out = self.linear2(out)
        out = F.gelu(out, approximate=True)
        out = self.linear3(out)
...@@ -464,11 +452,7 @@ def mlp_forward(train_program, start_program): ...@@ -464,11 +452,7 @@ def mlp_forward(train_program, start_program):
dtype='float32') dtype='float32')
if _global_parallel_strategy == "dp_mp_pp": if _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(input, auto.shard_tensor(input, _global_process_mesh[0], ["x", None])
dist_attr={
"process_mesh": _global_process_mesh[0],
"dims_mapping": [0, -1]
})
mlp = MLPLayer(hidden_size=hidden_size, mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size, intermediate_size=4 * hidden_size,
initializer_range=0.02) initializer_range=0.02)
...@@ -548,7 +532,10 @@ class TestAutoParallelMapper(unittest.TestCase): ...@@ -548,7 +532,10 @@ class TestAutoParallelMapper(unittest.TestCase):
global _global_num_stages global _global_num_stages
_global_num_stages = 2 _global_num_stages = 2
global _global_process_mesh global _global_process_mesh
_global_process_mesh = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]] _global_process_mesh = [
auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]),
auto.ProcessMesh([[4, 5], [6, 7]], dim_names=["x", "y"])
]
processes = [0, 1, 2, 3, 4, 5, 6, 7] processes = [0, 1, 2, 3, 4, 5, 6, 7]
dist_programs = {} dist_programs = {}
......
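For the multi-stage tests the global "mesh" is now a plain Python list holding
one ProcessMesh per pipeline stage, indexed as _global_process_mesh[0] and
_global_process_mesh[1] in the forward pass above. A sketch of the pattern
(variable names are illustrative, not from the diff):

    stage_meshes = [
        auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]),  # stage 0
        auto.ProcessMesh([[4, 5], [6, 7]], dim_names=["x", "y"]),  # stage 1
    ]
    # Stage-0 parameters are annotated on stage_meshes[0], stage-1 parameters
    # on stage_meshes[1]; tensors crossing stages are moved by resharding.
    auto.shard_tensor(linear0_weight, stage_meshes[0], [None, "y"])
    auto.shard_tensor(linear2_weight, stage_meshes[1], [None, "y"])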
@@ -276,39 +276,20 @@ class MLPLayer(nn.Layer):
         self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
     def forward(self, input):
-        if _global_parallel_strategy == "mp":
+        if _global_parallel_strategy in ["mp", "dp_mp"]:
             auto.shard_tensor(self.linear0.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 0]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=[None, "mp"])
             auto.shard_tensor(self.linear1.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [0, -1]
-                              })
-        elif _global_parallel_strategy == "dp_mp":
-            auto.shard_tensor(self.linear0.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 1]
-                              })
-            auto.shard_tensor(self.linear1.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [1, -1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=["mp", None])
         else:
             auto.shard_tensor(self.linear0.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, -1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=[None, None])
             auto.shard_tensor(self.linear1.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, -1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=[None, None])
         out = self.norm(input)
         out = self.linear0(out)
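Collapsing the separate "mp" and "dp_mp" branches into one is exactly what the
named axes buy: a spec such as [None, "mp"] is resolved by axis name against
whichever mesh is in scope, so it works unchanged on a 1-D mesh built with
dim_names=["mp"] and on a 2-D mesh built with dim_names=["dp", "mp"]. With the
old index-based dims_mapping the same intent had to be spelled [-1, 0] on the
1-D mesh but [-1, 1] on the 2-D mesh, which forced the duplicated branches.
An illustrative sketch (not from the diff):

    mesh_mp = auto.ProcessMesh([0, 1, 2, 3], dim_names=["mp"])
    mesh_dp_mp = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]],
                                  dim_names=["dp", "mp"])
    # The same named spec annotates the weight correctly on either mesh;
    # a real program would pick one mesh, both calls shown here for contrast.
    auto.shard_tensor(weight, process_mesh=mesh_mp, shard_spec=[None, "mp"])
    auto.shard_tensor(weight, process_mesh=mesh_dp_mp, shard_spec=[None, "mp"])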
@@ -329,18 +310,10 @@ def mlp_pretrain_forward(train_program, start_program):
                              shape=[batch_size, sequence_len, hidden_size],
                              dtype='float32')
-    if _global_parallel_strategy == "dp":
-        auto.shard_tensor(input,
-                          dist_attr={
-                              "process_mesh": _global_process_mesh,
-                              "dims_mapping": [0, -1, -1]
-                          })
-    elif _global_parallel_strategy == "dp_mp":
+    if _global_parallel_strategy in ["dp", "dp_mp"]:
         auto.shard_tensor(input,
-                          dist_attr={
-                              "process_mesh": _global_process_mesh,
-                              "dims_mapping": [0, -1, -1]
-                          })
+                          process_mesh=_global_process_mesh,
+                          shard_spec=["dp", None, None])
     mlp = MLPLayer(hidden_size=hidden_size,
                    intermediate_size=4 * hidden_size,
@@ -356,7 +329,8 @@ class TestMLPAutoPartitioner(unittest.TestCase):
         global _global_parallel_strategy
         _global_parallel_strategy = "dp"
         global _global_process_mesh
-        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
+        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
+                                                dim_names=["dp"])
         serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
             mlp_pretrain_forward)
@@ -391,7 +365,8 @@ class TestMLPAutoPartitioner(unittest.TestCase):
         global _global_parallel_strategy
         _global_parallel_strategy = "mp"
         global _global_process_mesh
-        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
+        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
+                                                dim_names=["mp"])
         serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
             mlp_pretrain_forward)
@@ -453,8 +428,9 @@ class TestMLPAutoPartitioner(unittest.TestCase):
         global _global_parallel_strategy
         _global_parallel_strategy = "dp_mp"
         global _global_process_mesh
-        _global_process_mesh = auto.ProcessMesh(
-            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
+        _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
+                                                      [4, 5, 6, 7]],
+                                                dim_names=["dp", "mp"])
         serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
             mlp_pretrain_forward)
@@ -558,18 +534,10 @@ class AttentionLayer(nn.Layer):
                              bias_attr=bias_attr)
     def forward(self, input):
-        if _global_parallel_strategy == "dp":
-            auto.shard_tensor(input,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [0, -1, -1]
-                              })
-        elif _global_parallel_strategy == "dp_mp":
+        if _global_parallel_strategy in ["dp", "dp_mp"]:
             auto.shard_tensor(input,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [0, -1, -1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=["dp", None, None])
         q = self.q_proj(input)
         q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
@@ -578,38 +546,16 @@ class AttentionLayer(nn.Layer):
         k = self.k_proj(input)
         v = self.v_proj(input)
-        if _global_parallel_strategy == "mp":
+        if _global_parallel_strategy in ["mp", "dp_mp"]:
             auto.shard_tensor(self.q_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 0]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=[None, "mp"])
             auto.shard_tensor(self.k_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 0]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=[None, "mp"])
             auto.shard_tensor(self.v_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 0]
-                              })
-        elif _global_parallel_strategy == "dp_mp":
-            auto.shard_tensor(self.q_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 1]
-                              })
-            auto.shard_tensor(self.k_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 1]
-                              })
-            auto.shard_tensor(self.v_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=[None, "mp"])
         k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
         k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
@@ -641,18 +587,11 @@ class AttentionLayer(nn.Layer):
         # project to output
         out = self.out_proj(out)
-        if _global_parallel_strategy == "mp":
-            auto.shard_tensor(self.out_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [0, -1]
-                              })
-        elif _global_parallel_strategy == "dp_mp":
+        if _global_parallel_strategy in ["mp", "dp_mp"]:
             auto.shard_tensor(self.out_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [1, -1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=["mp", None])
         return out
@@ -683,7 +622,8 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
         global _global_parallel_strategy
         _global_parallel_strategy = "dp"
         global _global_process_mesh
-        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
+        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
+                                                dim_names=["dp"])
         serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
             attn_pretrain_forward)
@@ -717,7 +657,8 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
         global _global_parallel_strategy
         _global_parallel_strategy = "mp"
         global _global_process_mesh
-        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
+        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
+                                                dim_names=["mp"])
         serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
             attn_pretrain_forward)
@@ -783,8 +724,9 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
         global _global_parallel_strategy
         _global_parallel_strategy = "dp_mp"
         global _global_process_mesh
-        _global_process_mesh = auto.ProcessMesh(
-            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
+        _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
+                                                      [4, 5, 6, 7]],
+                                                dim_names=["dp", "mp"])
         serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
             attn_pretrain_forward)
@@ -930,34 +872,18 @@ class DecoderLayer(nn.Layer):
         self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
     def forward(self, input_ids, position_ids):
-        if _global_parallel_strategy == "dp":
+        if _global_parallel_strategy in ["dp", "dp_mp"]:
             auto.shard_tensor(input_ids,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [0, -1]
-                              })
-        elif _global_parallel_strategy == "dp_mp":
-            auto.shard_tensor(input_ids,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [0, -1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=["dp", None])
         input_embeddings = self.word_embeddings(input_ids)
         position_embeddings = self.position_embeddings(position_ids)
-        if _global_parallel_strategy == "mp":
-            auto.shard_tensor(self.word_embeddings.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [0, -1]
-                              })
-        elif _global_parallel_strategy == "dp_mp":
+        if _global_parallel_strategy in ["mp", "dp_mp"]:
             auto.shard_tensor(self.word_embeddings.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [1, -1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=["mp", None])
         embeddings = input_embeddings + position_embeddings
         embeddings = self.dropout1(embeddings)
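Note the orientation of the two new specs in this hunk: the token ids are
sharded along the batch dimension (["dp", None]), while the embedding weight
uses ["mp", None], splitting its first dimension, the vocabulary rows, across
model-parallel ranks. A sketch of the embedding case (the mesh variable is
assumed to carry dim_names=["dp", "mp"]):

    # word_embeddings.weight has shape [vocab_size, hidden_size]; sharding the
    # first dimension over "mp" gives each rank a slice of the vocabulary.
    auto.shard_tensor(decoder.word_embeddings.weight,
                      process_mesh=mesh,
                      shard_spec=["mp", None])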
@@ -973,38 +899,16 @@ class DecoderLayer(nn.Layer):
         k = self.k_proj(target)
         v = self.v_proj(target)
-        if _global_parallel_strategy == "mp":
-            auto.shard_tensor(self.q_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 0]
-                              })
-            auto.shard_tensor(self.k_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 0]
-                              })
-            auto.shard_tensor(self.v_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 0]
-                              })
-        elif _global_parallel_strategy == "dp_mp":
+        if _global_parallel_strategy in ["mp", "dp_mp"]:
             auto.shard_tensor(self.q_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=[None, "mp"])
             auto.shard_tensor(self.k_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=[None, "mp"])
             auto.shard_tensor(self.v_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=[None, "mp"])
         k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
         k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
@@ -1037,24 +941,14 @@ class DecoderLayer(nn.Layer):
         # project to output
         out = self.out_proj(out)
-        if _global_parallel_strategy == "mp":
-            auto.shard_tensor(self.out_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [0, -1]
-                              })
-        elif _global_parallel_strategy == "dp_mp":
+        if _global_parallel_strategy in ["mp", "dp_mp"]:
             auto.shard_tensor(self.out_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [1, -1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=["mp", None])
         else:
             auto.shard_tensor(self.out_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, -1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=[None, None])
         # Add residual
         residual = embeddings + self.dropout2(out)
@@ -1067,28 +961,13 @@ class DecoderLayer(nn.Layer):
         out2 = F.gelu(out1, approximate=True)
         out3 = self.linear1(out2)
-        if _global_parallel_strategy == "mp":
-            auto.shard_tensor(self.linear0.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 0]
-                              })
-            auto.shard_tensor(self.linear1.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [0, -1]
-                              })
-        elif _global_parallel_strategy == "dp_mp":
+        if _global_parallel_strategy in ["mp", "dp_mp"]:
             auto.shard_tensor(self.linear0.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=[None, "mp"])
             auto.shard_tensor(self.linear1.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [1, -1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=["mp", None])
         # Add residual
         final = residual + self.dropout3(out3)
@@ -1126,8 +1005,9 @@ class TestDecoderLayerPartitioner(unittest.TestCase):
         global _global_parallel_strategy
         _global_parallel_strategy = "dp_mp"
         global _global_process_mesh
-        _global_process_mesh = auto.ProcessMesh(
-            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
+        _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
+                                                      [4, 5, 6, 7]],
+                                                dim_names=["dp", "mp"])
         serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
             decoder_pretrain_forward)
@@ -1208,8 +1088,9 @@ class TestDecoderLayerPartitioner(unittest.TestCase):
         global _global_parallel_strategy
         _global_parallel_strategy = "None"
         global _global_process_mesh
-        _global_process_mesh = auto.ProcessMesh(
-            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
+        _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
+                                                      [4, 5, 6, 7]],
+                                                dim_names=["x", "y"])
         serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
             decoder_pretrain_forward)
...
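Two details in the last hunk are easy to miss: a spec of [None, None] (the old
dims_mapping [-1, -1]) leaves the tensor fully replicated across the mesh, and
dim_names are free-form labels, so the "None"-strategy test can name its axes
"x"/"y" while the model-parallel tests use "dp"/"mp". Only the names referenced
by a shard_spec have to match the mesh. For example:

    mesh = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"])
    # Replicated on all eight ranks; no tensor dimension is tied to a mesh axis.
    auto.shard_tensor(decoder.out_proj.weight, process_mesh=mesh,
                      shard_spec=[None, None])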
@@ -163,18 +163,10 @@ class MultiHeadAttention(nn.Layer):
         """
         q = self.q_proj(query)
-        if _global_parallel_strategy == "mp":
+        if _global_parallel_strategy in ["mp", "dp_mp"]:
             auto.shard_tensor(self.q_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 0]
-                              })
-        elif _global_parallel_strategy == "dp_mp":
-            auto.shard_tensor(self.q_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=[None, "mp"])
         q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
         q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
@@ -205,34 +197,15 @@ class MultiHeadAttention(nn.Layer):
         to construct cache for inference.
         """
         k = self.k_proj(key)
-        if _global_parallel_strategy == "mp":
-            auto.shard_tensor(self.k_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 0]
-                              })
-        elif _global_parallel_strategy == "dp_mp":
-            auto.shard_tensor(self.k_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 1]
-                              })
         v = self.v_proj(value)
-        if _global_parallel_strategy == "mp":
-            auto.shard_tensor(self.v_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 0]
-                              })
-        elif _global_parallel_strategy == "dp_mp":
+        if _global_parallel_strategy in ["mp", "dp_mp"]:
+            auto.shard_tensor(self.k_proj.weight,
+                              process_mesh=_global_process_mesh,
+                              shard_spec=[None, "mp"])
             auto.shard_tensor(self.v_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=[None, "mp"])
         k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
         k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
@@ -312,18 +285,10 @@ class MultiHeadAttention(nn.Layer):
         # project to output
         out = self.out_proj(out)
-        if _global_parallel_strategy == "mp":
+        if _global_parallel_strategy in ["mp", "dp_mp"]:
             auto.shard_tensor(self.out_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [0, -1]
-                              })
-        elif _global_parallel_strategy == "dp_mp":
-            auto.shard_tensor(self.out_proj.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [1, -1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=["mp", None])
         outs = [out]
         if self.need_weights:
@@ -486,31 +451,13 @@ class TransformerDecoderLayer(nn.Layer):
         if self.normalize_before:
             tgt = self.norm2(tgt)
-        if _global_parallel_strategy == "mp":
-            auto.shard_tensor(self.linear1.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 0]
-                              })
-        elif _global_parallel_strategy == "dp_mp":
+        if _global_parallel_strategy in ["mp", "dp_mp"]:
             auto.shard_tensor(self.linear1.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, 1]
-                              })
-        if _global_parallel_strategy == "mp":
-            auto.shard_tensor(self.linear2.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [0, -1]
-                              })
-        elif _global_parallel_strategy == "dp_mp":
+                              process_mesh=_global_process_mesh,
+                              shard_spec=[None, "mp"])
             auto.shard_tensor(self.linear2.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [1, -1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=["mp", None])
         # tgt = self.dropout2(
         #     self.linear2(F.gelu(
@@ -570,18 +517,10 @@ class GPTEmbeddings(nn.Layer):
         input_embedings = self.word_embeddings(input_ids)
-        if _global_parallel_strategy == "mp":
-            auto.shard_tensor(self.word_embeddings.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [0, -1]
-                              })
-        elif _global_parallel_strategy == "dp_mp":
+        if _global_parallel_strategy in ["mp", "dp_mp"]:
             auto.shard_tensor(self.word_embeddings.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [1, -1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=["mp", None])
         position_embeddings = self.position_embeddings(position_ids)
         embeddings = input_embedings + position_embeddings
@@ -804,18 +743,10 @@ def gpt_pretrain_forward(train_program, startup_program):
                                  shape=[batch_size, sequence_len],
                                  dtype='float64')
-        if _global_parallel_strategy == "dp":
-            auto.shard_tensor(input_ids,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [0, -1]
-                              })
-        elif _global_parallel_strategy == "dp_mp":
+        if _global_parallel_strategy in ["dp", "dp_mp"]:
             auto.shard_tensor(input_ids,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [0, -1]
-                              })
+                              process_mesh=_global_process_mesh,
+                              shard_spec=["dp", None])
         gpt = GPTModel(vocab_size=32768,
                        hidden_size=768,
@@ -863,8 +794,9 @@ class TestGPTPartitioner(unittest.TestCase):
         _global_parallel_strategy = "dp_mp"
         global _global_process_mesh
-        _global_process_mesh = auto.ProcessMesh(
-            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
+        _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
+                                                      [4, 5, 6, 7]],
+                                                dim_names=["dp", "mp"])
         train_program = static.Program()
         startup_program = static.Program()
...
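Throughout these GPT hunks the feed-forward pair keeps the usual tensor-parallel
orientation: the first linear is column-sharded ([None, "mp"]) so its output
stays partitioned, and the second is row-sharded (["mp", None]) so it consumes
that partitioned activation and only its result needs to be reduced across the
"mp" axis. A sketch of the pairing (variable names are illustrative):

    # linear1.weight: [hidden, 4 * hidden], split along columns.
    auto.shard_tensor(linear1_weight, process_mesh=mesh,
                      shard_spec=[None, "mp"])
    # linear2.weight: [4 * hidden, hidden], split along rows; the partial
    # matmul results are summed over the "mp" axis by the framework.
    auto.shard_tensor(linear2_weight, process_mesh=mesh,
                      shard_spec=["mp", None])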
@@ -63,27 +63,13 @@ class MLPLayer(nn.Layer):
     def forward(self, input):
         if _global_parallel_strategy == "pp":
-            auto.shard_tensor(self.linear0.weight,
-                              dist_attr={
-                                  "process_mesh": PP_MESH_0,
-                                  "dims_mapping": [-1, -1]
-                              })
-            auto.shard_tensor(self.linear1.weight,
-                              dist_attr={
-                                  "process_mesh": PP_MESH_1,
-                                  "dims_mapping": [-1, -1]
-                              })
+            auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None])
+            auto.shard_tensor(self.linear1.weight, PP_MESH_1, [None, None])
         else:
-            auto.shard_tensor(self.linear0.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, -1]
-                              })
-            auto.shard_tensor(self.linear1.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, -1]
-                              })
+            auto.shard_tensor(self.linear0.weight, _global_process_mesh,
+                              [None, None])
+            auto.shard_tensor(self.linear1.weight, _global_process_mesh,
+                              [None, None])
         out = self.norm(input)
         out = self.linear0(out)
@@ -107,28 +93,12 @@ def mlp_forward(train_program, start_program):
                               dtype='float32')
     if _global_parallel_strategy == "pp":
-        auto.shard_tensor(input,
-                          dist_attr={
-                              "process_mesh": PP_MESH_0,
-                              "dims_mapping": [-1, -1]
-                          })
-        auto.shard_tensor(label,
-                          dist_attr={
-                              "process_mesh": PP_MESH_1,
-                              "dims_mapping": [-1, -1]
-                          })
+        auto.shard_tensor(input, PP_MESH_0, [None, None])
+        auto.shard_tensor(label, PP_MESH_1, [None, None])
     elif _global_parallel_strategy == "dp":
-        auto.shard_tensor(input,
-                          dist_attr={
-                              "process_mesh": _global_process_mesh,
-                              "dims_mapping": [0, -1]
-                          })
+        auto.shard_tensor(input, _global_process_mesh, ["x", None])
     else:
-        auto.shard_tensor(input,
-                          dist_attr={
-                              "process_mesh": _global_process_mesh,
-                              "dims_mapping": [-1, -1]
-                          })
+        auto.shard_tensor(input, _global_process_mesh, [None, None])
     mlp = MLPLayer(hidden_size=hidden_size,
                    intermediate_size=4 * hidden_size,
@@ -296,11 +266,11 @@ class TestMLPReshard(unittest.TestCase):
         global _global_parallel_strategy
         _global_parallel_strategy = "pp"
         global _global_process_mesh
-        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
+        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
         global PP_MESH_0
-        PP_MESH_0 = auto.ProcessMesh(mesh=[0])
+        PP_MESH_0 = auto.ProcessMesh(mesh=[0], dim_names=["x"])
         global PP_MESH_1
-        PP_MESH_1 = auto.ProcessMesh(mesh=[1])
+        PP_MESH_1 = auto.ProcessMesh(mesh=[1], dim_names=["x"])
         train_program = paddle.static.Program()
         startup_program = paddle.static.Program()
@@ -325,11 +295,11 @@ class TestMLPReshard(unittest.TestCase):
         global _global_parallel_strategy
         _global_parallel_strategy = "pp"
         global _global_process_mesh
-        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
+        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
         global PP_MESH_0
-        PP_MESH_0 = auto.ProcessMesh(mesh=[0])
+        PP_MESH_0 = auto.ProcessMesh(mesh=[0], dim_names=["x"])
         global PP_MESH_1
-        PP_MESH_1 = auto.ProcessMesh(mesh=[1])
+        PP_MESH_1 = auto.ProcessMesh(mesh=[1], dim_names=["x"])
         train_program = paddle.static.Program()
         startup_program = paddle.static.Program()
@@ -352,7 +322,7 @@ class TestMLPReshard(unittest.TestCase):
         global _global_parallel_strategy
         _global_parallel_strategy = "dp"
         global _global_process_mesh
-        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
+        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
         train_program = paddle.static.Program()
         startup_program = paddle.static.Program()
...
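These reshard tests are where the stage meshes actually matter: PP_MESH_0 and
PP_MESH_1 are disjoint single-rank meshes, and annotating the input on stage 0
while the label and its consumers live on stage 1 is what forces the Resharder
to insert the send/recv pair that check_send_recv_result verifies. A condensed
sketch of the setup:

    PP_MESH_0 = auto.ProcessMesh(mesh=[0], dim_names=["x"])
    PP_MESH_1 = auto.ProcessMesh(mesh=[1], dim_names=["x"])
    auto.shard_tensor(input, PP_MESH_0, [None, None])   # resides on rank 0
    auto.shard_tensor(label, PP_MESH_1, [None, None])   # resides on rank 1
    # Any tensor flowing from PP_MESH_0 ops to PP_MESH_1 ops is transferred
    # by the Resharder via inserted send/recv ops.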
@@ -34,9 +34,10 @@ from paddle.distributed.auto_parallel.cluster import Cluster
 paddle.enable_static()
 _global_parallel_strategy = "dp_mp_pp"
-_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]])
-PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]])
-PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]])
+_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]],
+                                        dim_names=["x", "y", "z"])
+PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], dim_names=["x", "y"])
+PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], dim_names=["x", "y"])
 class MLPLayer(nn.Layer):
@@ -63,16 +64,8 @@ class MLPLayer(nn.Layer):
         self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
     def forward(self, input):
-        auto.shard_tensor(self.linear0.weight,
-                          dist_attr={
-                              "process_mesh": PP_MESH_0,
-                              "dims_mapping": [-1, 1]
-                          })
-        auto.shard_tensor(self.linear1.weight,
-                          dist_attr={
-                              "process_mesh": PP_MESH_1,
-                              "dims_mapping": [1, -1]
-                          })
+        auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, "y"])
+        auto.shard_tensor(self.linear1.weight, PP_MESH_1, ["y", None])
         out = self.norm(input)
         out = self.linear0(out)
@@ -80,11 +73,7 @@ class MLPLayer(nn.Layer):
         out = self.linear1(out)
         param = paddle.fluid.layers.create_parameter([1024, 4096],
                                                      paddle.float32)
-        auto.shard_tensor(param,
-                          dist_attr={
-                              "process_mesh": PP_MESH_1,
-                              "dims_mapping": [-1, 1]
-                          })
+        auto.shard_tensor(param, PP_MESH_1, [None, "y"])
         out = paddle.fluid.layers.mul(out, param)
         return out
@@ -103,16 +92,8 @@ def mlp_forward(train_program, start_program):
                               shape=[batch_size, 1],
                               dtype='float32')
-    auto.shard_tensor(input,
-                      dist_attr={
-                          "process_mesh": PP_MESH_0,
-                          "dims_mapping": [0, -1]
-                      })
-    auto.shard_tensor(label,
-                      dist_attr={
-                          "process_mesh": PP_MESH_1,
-                          "dims_mapping": [0, -1]
-                      })
+    auto.shard_tensor(input, PP_MESH_0, ["x", None])
+    auto.shard_tensor(label, PP_MESH_1, ["x", None])
     mlp = MLPLayer(hidden_size=hidden_size,
                    intermediate_size=4 * hidden_size,
...
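For a three-level "dp_mp_pp" strategy the global mesh itself becomes 3-D, one
named axis per kind of parallelism, while each pipeline stage is still declared
as its own 2-D sub-mesh over the corresponding ranks. A sketch:

    # 2 x 2 x 2 arrangement of ranks 0-7; the outermost axis separates the
    # two pipeline stages.
    mesh3d = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]],
                              dim_names=["x", "y", "z"])
    # Stage sub-meshes are written out explicitly in these tests rather than
    # sliced from mesh3d.
    PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], dim_names=["x", "y"])
    PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], dim_names=["x", "y"])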
@@ -34,9 +34,9 @@ from paddle.distributed.auto_parallel.cluster import Cluster
 paddle.enable_static()
 _global_parallel_strategy = "mp_pp"
-_global_process_mesh = auto.ProcessMesh([[0, 1], [2, 3]])
-PP_MESH_0 = auto.ProcessMesh([0, 1])
-PP_MESH_1 = auto.ProcessMesh([2, 3])
+_global_process_mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
+PP_MESH_0 = auto.ProcessMesh([0, 1], dim_names=["x"])
+PP_MESH_1 = auto.ProcessMesh([2, 3], dim_names=["x"])
 class MLPLayer(nn.Layer):
@@ -73,35 +73,15 @@ class MLPLayer(nn.Layer):
                                      bias_attr=bias_attr)
     def forward(self, input):
-        auto.shard_tensor(self.word_embeddings.weight,
-                          dist_attr={
-                              "process_mesh": PP_MESH_0,
-                              "dims_mapping": [0, -1]
-                          })
-        auto.shard_tensor(self.linear0.weight,
-                          dist_attr={
-                              "process_mesh": PP_MESH_0,
-                              "dims_mapping": [-1, 0]
-                          })
-        auto.shard_tensor(self.linear1.weight,
-                          dist_attr={
-                              "process_mesh": PP_MESH_1,
-                              "dims_mapping": [0, -1]
-                          })
-        auto.shard_tensor(self.linear2.weight,
-                          dist_attr={
-                              "process_mesh": PP_MESH_1,
-                              "dims_mapping": [0, -1]
-                          })
+        auto.shard_tensor(self.word_embeddings.weight, PP_MESH_0, ["x", None])
+        auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, "x"])
+        auto.shard_tensor(self.linear1.weight, PP_MESH_1, ["x", None])
+        auto.shard_tensor(self.linear2.weight, PP_MESH_1, ["x", None])
         w_out = self.word_embeddings(input)
         out = self.linear0(w_out)
         param = paddle.fluid.layers.create_parameter([4096, 4096],
                                                      paddle.float32)
-        auto.shard_tensor(param,
-                          dist_attr={
-                              "process_mesh": PP_MESH_0,
-                              "dims_mapping": [0, -1]
-                          })
+        auto.shard_tensor(param, PP_MESH_0, ["x", None])
        out = paddle.fluid.layers.mul(out, param)
         gelu_out = F.gelu(out, approximate=True)
         out = self.linear1(gelu_out)
@@ -122,16 +102,8 @@ def mlp_forward(train_program, start_program):
                               shape=[batch_size, 1],
                               dtype='float32')
-    auto.shard_tensor(input,
-                      dist_attr={
-                          "process_mesh": PP_MESH_0,
-                          "dims_mapping": [-1]
-                      })
-    auto.shard_tensor(label,
-                      dist_attr={
-                          "process_mesh": PP_MESH_1,
-                          "dims_mapping": [-1, -1]
-                      })
+    auto.shard_tensor(input, PP_MESH_0, [None])
+    auto.shard_tensor(label, PP_MESH_1, [None, None])
     mlp = MLPLayer(hidden_size=hidden_size,
                    intermediate_size=4 * hidden_size,
@@ -238,7 +210,6 @@ class TestMLPReshard(unittest.TestCase):
         resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                               dist_context, dist_params_grads)
         resharder.reshard()
-        print_program_with_dist_attr(dist_main_prog, dist_context)
         # check send and recv result
         self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
@@ -249,32 +220,15 @@ class TestMLPReshard(unittest.TestCase):
     def test_allgather(self):
         train_program = paddle.static.Program()
         startup_program = paddle.static.Program()
-        process_mesh = auto.ProcessMesh(mesh=[0, 1])
+        process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
         with static.program_guard(train_program, startup_program):
             x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
-            x = auto.shard_tensor(x,
-                                  dist_attr={
-                                      "process_mesh": process_mesh,
-                                      "dims_mapping": [0, -1]
-                                  })
+            x = auto.shard_tensor(x, process_mesh, ["x", None])
             w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
-            w = auto.shard_tensor(w,
-                                  dist_attr={
-                                      "process_mesh": process_mesh,
-                                      "dims_mapping": [-1, -1]
-                                  })
+            w = auto.shard_tensor(w, process_mesh, [None, None])
-            y = paddle.distributed.shard_op(paddle.matmul,
-                                            dist_attr={
-                                                "process_mesh": process_mesh,
-                                                x: {
-                                                    "dims_mapping": [-1, -1]
-                                                },
-                                                w: {
-                                                    "dims_mapping": [-1, -1]
-                                                }
-                                            })(x, w)
+            y = paddle.distributed.shard_op(paddle.matmul, process_mesh,
+                                            [[None, None], [None, None]])(x, w)
         rank_id = 0
         dist_context = DistributedContext()
...
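shard_op follows the same migration as shard_tensor: the old dist_attr dict
keyed by tensors is replaced by the mesh plus a list of shard specs, one per
input of the wrapped op, in order. The call returns a callable that is then
applied to the real inputs. Condensed from the test above:

    process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
    x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
    w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
    # One spec per matmul input; [None, None] asks for both x and w replicated,
    # so the row-sharded x annotated earlier must first be gathered, which is
    # the allgather this test checks for.
    y = paddle.distributed.shard_op(paddle.matmul, process_mesh,
                                    [[None, None], [None, None]])(x, w)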
@@ -62,27 +62,13 @@ class MLPLayer(nn.Layer):
     def forward(self, input):
         if _global_parallel_strategy == "pp":
-            auto.shard_tensor(self.linear0.weight,
-                              dist_attr={
-                                  "process_mesh": PP_MESH_0,
-                                  "dims_mapping": [-1, -1]
-                              })
-            auto.shard_tensor(self.linear1.weight,
-                              dist_attr={
-                                  "process_mesh": PP_MESH_1,
-                                  "dims_mapping": [-1, -1]
-                              })
+            auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None])
+            auto.shard_tensor(self.linear1.weight, PP_MESH_1, [None, None])
         else:
-            auto.shard_tensor(self.linear0.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, -1]
-                              })
-            auto.shard_tensor(self.linear1.weight,
-                              dist_attr={
-                                  "process_mesh": _global_process_mesh,
-                                  "dims_mapping": [-1, -1]
-                              })
+            auto.shard_tensor(self.linear0.weight, _global_process_mesh,
+                              [None, None])
+            auto.shard_tensor(self.linear1.weight, _global_process_mesh,
+                              [None, None])
         out = self.norm(input)
         out = self.linear0(out)
@@ -106,28 +92,12 @@ def mlp_forward(train_program, start_program):
                               dtype='float32')
     if _global_parallel_strategy == "pp":
-        auto.shard_tensor(input,
-                          dist_attr={
-                              "process_mesh": PP_MESH_0,
-                              "dims_mapping": [-1, -1]
-                          })
-        auto.shard_tensor(label,
-                          dist_attr={
-                              "process_mesh": PP_MESH_1,
-                              "dims_mapping": [-1, -1]
-                          })
+        auto.shard_tensor(input, PP_MESH_0, [None, None])
+        auto.shard_tensor(label, PP_MESH_1, [None, None])
     elif _global_parallel_strategy == "dp":
-        auto.shard_tensor(input,
-                          dist_attr={
-                              "process_mesh": _global_process_mesh,
-                              "dims_mapping": [0, -1]
-                          })
+        auto.shard_tensor(input, _global_process_mesh, ["x", None])
     else:
-        auto.shard_tensor(input,
-                          dist_attr={
-                              "process_mesh": _global_process_mesh,
-                              "dims_mapping": [-1, -1]
-                          })
+        auto.shard_tensor(input, _global_process_mesh, [None, None])
     mlp = MLPLayer(hidden_size=hidden_size,
                    intermediate_size=4 * hidden_size,
@@ -196,7 +166,7 @@ class TestMLPReshard(unittest.TestCase):
         global _global_parallel_strategy
         _global_parallel_strategy = None
         global _global_process_mesh
-        _global_process_mesh = auto.ProcessMesh(mesh=[0])
+        _global_process_mesh = auto.ProcessMesh(mesh=[0], dim_names=["x"])
         train_program = paddle.static.Program()
         startup_program = paddle.static.Program()
...