Unverified commit bc2265f8, authored by zhaoyingli, committed by GitHub

[AutoParallel] adapt lazyinit & fix pass (#45840)

* adapt lazy init and fix pass

* add unittest

* update comment

* fix amp and sharding

* remove clip_by_norm
Parent abc85c50
@@ -26,14 +26,7 @@ from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collat
 class DistributedDataLoader(metaclass=abc.ABCMeta):
 
-    def __init__(self,
-                 dataset,
-                 batch_size=1,
-                 epochs=1,
-                 data_parallel_world_size=None,
-                 data_parallel_rank=None,
-                 drop_last=False,
-                 split_data=True):
+    def __init__(self, dataset, batch_size=1, epochs=1, drop_last=False):
         if isinstance(dataset, IterableDataset):
             self.dataset_kind = _DatasetKind.ITER
         else:
@@ -42,19 +35,11 @@ class DistributedDataLoader(metaclass=abc.ABCMeta):
         self.dataset = dataset
         self.epochs = epochs
         self.drop_lost = drop_last
-        self.data_parallel_world_size = data_parallel_world_size
-        self.data_parallel_rank = data_parallel_rank
-        self.split_data = split_data
         if batch_size is None:
             self.batch_size = None
             self.batch_sampler = None
         else:
-            if data_parallel_world_size is not None:
-                for dp_world_size in data_parallel_world_size:
-                    if dp_world_size is not None:
-                        assert batch_size % dp_world_size == 0, \
-                            "batch_size must be divisible by dp_world_size value {}".format(str(dp_world_size))
             self.batch_size = batch_size
             if isinstance(dataset, IterableDataset):
                 self.batch_sampler = _InfiniteIterableSampler(
@@ -97,18 +82,22 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
                  epochs=1,
                  steps_per_epoch=None,
                  collate_fn=None,
-                 data_parallel_world_size=None,
-                 data_parallel_rank=None,
+                 data_parallel_world_size=[],
+                 data_parallel_rank=[],
                  drop_last=False,
                  split_data=True):
         self.feed_list = feed_list
         self.places = places
         self.steps_per_epoch = steps_per_epoch
+        assert len(data_parallel_world_size) == len(feed_list)
+        assert len(data_parallel_rank) == len(feed_list)
+        self.dp_world_sizes = data_parallel_world_size
+        self.dp_ranks = data_parallel_rank
+        self.split_data = split_data
         super(NonIterableGeneratorLoader,
-              self).__init__(dataset, batch_size, epochs,
-                             data_parallel_world_size, data_parallel_rank,
-                             drop_last, split_data)
+              self).__init__(dataset, batch_size, epochs, drop_last)
         if self.auto_collate_batch:
             self.collate_fn = collate_fn or default_collate_fn
@@ -154,13 +143,12 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
     def _create_inner_dataloader(self):
 
-        def sample_data_generator():
+        def data_generator():
             while True:
                 try:
                     indices = next(self.sampler_iter)
                     batch = self.dataset_fetcher.fetch(indices)
                     if batch is None: break
                 except StopIteration:
                     self.dataset_fetcher = _DatasetKind.create_fetcher(
                         self.dataset_kind, self.dataset,
@@ -169,53 +157,23 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
                     break
 
                 partial_data = []
-                for i, d in enumerate(batch[:len(self.feed_list)]):
+                for i, d in enumerate(batch):
                     array = np.array(d)
                     if not self.split_data:
                         partial_data.append(array)
-                    elif self.dp_world_sizes[i] is not None:
-                        partial_data.append(
-                            np.split(array,
-                                     self.dp_world_sizes[i])[self.dp_ranks[i]])
-                    else:
-                        partial_data.append(array)
-                yield partial_data
-
-        def batch_data_generator():
-            while True:
-                try:
-                    indices = next(self.sampler_iter)
-                    batch = self.dataset_fetcher.fetch(indices)
-                    if batch is None: break
-                except StopIteration:
-                    break
-
-                partial_data = []
-                for i, d in enumerate(batch[:len(self.feed_list)]):
-                    array = np.array(d)
-                    if not self.split_data:
-                        partial_data.append(array)
-                    elif self.dp_world_sizes[i] is not None:
-                        partial_data.append(
-                            np.split(array,
-                                     self.dp_world_sizes[i])[self.dp_ranks[i]])
-                    else:
-                        partial_data.append(array)
-                yield partial_data
+                        continue
 
-        self.dp_world_sizes = [
-            1 for _ in range(len(self.feed_list))
-        ] if self.data_parallel_world_size is None else self.data_parallel_world_size
-        self.dp_ranks = [
-            0 for _ in range(len(self.feed_list))
-        ] if self.data_parallel_rank is None else self.data_parallel_rank
+                    batch_size = array.shape[0]
+                    assert batch_size % self.dp_world_sizes[i] == 0, \
+                        "batch_size [{}] is not divisible by dp_world_size [{}]".format(str(batch_size), str(self.dp_world_sizes[i]))
+                    partial_data.append(
+                        np.split(array,
+                                 self.dp_world_sizes[i])[self.dp_ranks[i]])
+
+                yield partial_data
 
         dataloader = paddle.fluid.io.DataLoader.from_generator(
             feed_list=self.feed_list, capacity=70, iterable=False)
-        if self.batch_size is not None:
-            dataloader.set_batch_generator(sample_data_generator, self.places)
-        else:
-            dataloader.set_batch_generator(batch_data_generator, self.places)
+        dataloader.set_batch_generator(data_generator, self.places)
         return dataloader
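
As a standalone illustration of what the new `data_generator` does for each feed variable, here is a minimal NumPy-only sketch of the per-rank splitting (the function name and the example batch are illustrative and not part of this patch):

```python
import numpy as np

def shard_batch(array, dp_world_size, dp_rank, split_data=True):
    # Mirrors the splitting in data_generator above: when split_data is False
    # every rank feeds the full batch; otherwise the batch is split evenly
    # along axis 0 and only this rank's shard is kept.
    if not split_data:
        return array
    batch_size = array.shape[0]
    assert batch_size % dp_world_size == 0, \
        "batch_size [{}] is not divisible by dp_world_size [{}]".format(
            batch_size, dp_world_size)
    return np.split(array, dp_world_size)[dp_rank]

# a global batch of 4 samples split across 2 data-parallel ranks
full_batch = np.arange(8).reshape(4, 2)
print(shard_batch(full_batch, dp_world_size=2, dp_rank=0))  # rows 0-1
print(shard_batch(full_batch, dp_world_size=2, dp_rank=1))  # rows 2-3
```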
...@@ -36,7 +36,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv ...@@ -36,7 +36,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.passes import new_pass, PassContext from paddle.distributed.passes import new_pass, PassContext
-from .hepler import ProgramHelper
+from .helper import ProgramHelper
from ..collective import _get_global_env from ..collective import _get_global_env
from .cluster import Cluster, get_default_cluster from .cluster import Cluster, get_default_cluster
from .planner_v2 import Planner from .planner_v2 import Planner
...@@ -118,8 +118,7 @@ class Engine: ...@@ -118,8 +118,7 @@ class Engine:
"'optimizer' must be object of class `paddle.optimizer.Optimizer`" \ "'optimizer' must be object of class `paddle.optimizer.Optimizer`" \
" or `paddle.fluid.optimizer.Optimizer`." " or `paddle.fluid.optimizer.Optimizer`."
) )
-        self._optimizer = optimizer
-        self._all_ranks = all_ranks
+        self._optimizer = self._validate_opt(optimizer)
if loss and not isinstance(loss, if loss and not isinstance(loss,
paddle.nn.Layer) and not callable(loss): paddle.nn.Layer) and not callable(loss):
...@@ -136,6 +135,7 @@ class Engine: ...@@ -136,6 +135,7 @@ class Engine:
self._metrics = to_list(metrics) self._metrics = to_list(metrics)
self._gradient_scale = gradient_scale self._gradient_scale = gradient_scale
self._planned_mode = None self._planned_mode = None
self._all_ranks = all_ranks
self._prepare_single_mode("train") self._prepare_single_mode("train")
def _prepare_single_mode(self, mode): def _prepare_single_mode(self, mode):
...@@ -161,21 +161,23 @@ class Engine: ...@@ -161,21 +161,23 @@ class Engine:
self._dygraph_mode = True self._dygraph_mode = True
self._logger.info("Building model with 'to_static' method.") self._logger.info("Building model with 'to_static' method.")
-            program_helper = ProgramHelper(self.model, self._loss,
-                                           self._metrics, self.inputs_spec,
-                                           self.labels_spec)
+            inputs_spec = self.inputs_spec
+            labels_spec = self.labels_spec if self.labels_spec else []
+            self.program_helper = ProgramHelper(self.model, self._loss,
+                                                self._metrics, inputs_spec,
+                                                labels_spec)
             # build forward main program
-            program_helper.build_program(mode)
+            self.program_helper.build_program(mode)
 
-            self.concrete_program = program_helper.concrete_program
-            serial_main_prog = program_helper.main_program
-            serial_startup_prog = program_helper.startup_program
-            inputs = program_helper.input_vars
-            outputs = program_helper.output_vars
-            labels = program_helper.label_vars
-            losses = program_helper.loss_vars
-            metrics = program_helper.metric_vars
+            self.concrete_program = self.program_helper.concrete_program
+            serial_main_prog = self.program_helper.main_program
+            serial_startup_prog = self.program_helper.startup_program
+            inputs = self.program_helper.input_vars
+            outputs = self.program_helper.output_vars
+            labels = self.program_helper.label_vars
+            losses = self.program_helper.loss_vars
+            metrics = self.program_helper.metric_vars
paddle.enable_static() paddle.enable_static()
else: else:
...@@ -334,40 +336,17 @@ class Engine: ...@@ -334,40 +336,17 @@ class Engine:
continue continue
process_group.instantiate() process_group.instantiate()
-        self._place = _get_device()
-        if isinstance(self._place, fluid.CUDAPlace):
-            self._place = fluid.CUDAPlace(ParallelEnv().dev_id)
+        place = _get_device()
+        if isinstance(place, fluid.CUDAPlace):
+            place = fluid.CUDAPlace(ParallelEnv().dev_id)
 
         if self._dygraph_mode:
-            paddle.disable_static()
-            main_program = self._dist_main_progs[mode][self._cur_rank]
-            for param in self.concrete_program.parameters:
-                # create var in scope and share parameters to scope
-                if param.name not in main_program.global_block().vars:
-                    continue
-                # get param_var's dist_attr
-                var = main_program.global_block().vars[param.name]
-                var_dist_attr = self._dist_contexts[
-                    mode].get_tensor_dist_attr_for_program(var)
-                dist_attr = {
-                    "dims_mapping": var_dist_attr.dims_mapping,
-                    "process_shape": var_dist_attr.process_mesh.topology,
-                    "process_group": var_dist_attr.process_mesh.processes
-                }
-                # slice param_value with dist_attr
-                # share sliced_param_value with param_tensor in global_scope
-                from .converter import Converter
-                param_tensor = global_scope().var(param.name).get_tensor()
-                sliced_param = Converter.slice_with_dist_attr(
-                    param.numpy(), dist_attr)
-                shared_tensor = paddle.to_tensor(sliced_param,
-                                                 place=self._place)
-                param_tensor._share_data_with(
-                    shared_tensor.value().get_tensor())
-            paddle.enable_static()
+            dist_context = self._dist_contexts[mode]
+            dist_main_program = self._dist_main_progs[mode][self._cur_rank]
+            self.program_helper.init(dist_main_program, place, dist_context)
if self._executor is None: if self._executor is None:
-            self._executor = paddle.static.Executor(self._place)
+            self._executor = paddle.static.Executor(place)
uninitialized = [] uninitialized = []
dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank] dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
for var in dist_startup_prog.list_vars(): for var in dist_startup_prog.list_vars():
...@@ -411,7 +390,7 @@ class Engine: ...@@ -411,7 +390,7 @@ class Engine:
data = np.array(param_t) data = np.array(param_t)
param_t.set(np.float16(data), place) param_t.set(np.float16(data), place)
-            cast_parameters_to_fp16(self._place, prune_startup_prog)
+            cast_parameters_to_fp16(place, prune_startup_prog)
def fit(self, def fit(self,
train_data, train_data,
...@@ -577,15 +556,20 @@ class Engine: ...@@ -577,15 +556,20 @@ class Engine:
dist_context = self._dist_contexts[self.mode] dist_context = self._dist_contexts[self.mode]
dist_main_block = dist_main_prog.global_block() dist_main_block = dist_main_prog.global_block()
-        # NOTE: Get feed_list from dist_program, then insert dataloader op
-        # with sharded var shape. Because predict_program does not contain
-        # labels var, so we will filter dataset's value with length of feed_list.
+        # NOTE: Get feed_list, then insert dataloader op with sharded var shape.
+        # Because predict_program does not contain the labels var, we add the
+        # labels var from serial_program to dist_program so that the length of
+        # feed_list matches the length of the dataset's values.
inputs_var = self._feed_vars[self.mode]["inputs"] inputs_var = self._feed_vars[self.mode]["inputs"]
labels_var = self._feed_vars[self.mode]["labels"] labels_var = self._feed_vars[self.mode]["labels"]
feed_list = [] feed_list = []
for var in inputs_var + labels_var: for var in inputs_var + labels_var:
if var.name in dist_main_block.vars: if var.name in dist_main_block.vars:
feed_list.append(dist_main_block.vars[var.name]) feed_list.append(dist_main_block.vars[var.name])
else:
copy_var = dist_main_block._clone_variable(var, var.persistable)
copy_var.desc.set_original_id(var.desc.original_id())
feed_list.append(copy_var)
# remove the first three ops if multi run fit/evaluate/predict # remove the first three ops if multi run fit/evaluate/predict
op_size = len(dist_main_block.ops) op_size = len(dist_main_block.ops)
...@@ -688,7 +672,7 @@ class Engine: ...@@ -688,7 +672,7 @@ class Engine:
batch_size_axis, rank_id) batch_size_axis, rank_id)
return len(group_ranks), group_ranks.index(rank_id) return len(group_ranks), group_ranks.index(rank_id)
-        return None, None
+        return 1, 0
def _set_recompute_ckpts(self): def _set_recompute_ckpts(self):
# NOTE hack to enable recompute in engine api for GPT-3 # NOTE hack to enable recompute in engine api for GPT-3
...@@ -717,6 +701,11 @@ class Engine: ...@@ -717,6 +701,11 @@ class Engine:
} }
self._logger.info(logs) self._logger.info(logs)
def _validate_opt(self, optimizer):
optimizer._parameter_list = None
optimizer._param_groups = None
return optimizer
def save(self, path, training=True, mode=None): def save(self, path, training=True, mode=None):
if not mode: if not mode:
mode = self.mode mode = self.mode
......
...@@ -15,14 +15,18 @@ ...@@ -15,14 +15,18 @@
import logging import logging
from collections import defaultdict from collections import defaultdict
import paddle
from paddle.nn import Layer from paddle.nn import Layer
from paddle.jit import to_static, not_to_static from paddle.jit import to_static, not_to_static
from paddle.distributed.utils import get_logger from paddle.distributed.utils import get_logger
from paddle.fluid.framework import Operator, Parameter, _non_static_mode from paddle.fluid.framework import Operator, Parameter, _non_static_mode
from paddle.fluid.framework import program_guard from paddle.fluid.framework import program_guard
from paddle.fluid.executor import global_scope
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction
from .utils import to_list from .utils import to_list
from .converter import Converter
class ProxyLayer(Layer): class ProxyLayer(Layer):
...@@ -89,13 +93,14 @@ class ProxyLayer(Layer): ...@@ -89,13 +93,14 @@ class ProxyLayer(Layer):
# step 4. calculate metrics if needed # step 4. calculate metrics if needed
self._metric_vars[mode] = self.call_metrics(new_inputs) self._metric_vars[mode] = self.call_metrics(new_inputs)
-    def _predict(self, inputs):
+    def _predict(self, inputs, labels):
""" """
Predict process of inner_layer with forward logic. Predict process of inner_layer with forward logic.
""" """
# step 1. save feed variables of Program # step 1. save feed variables of Program
mode = 'predict' mode = 'predict'
self._input_vars[mode] = inputs self._input_vars[mode] = inputs
self._label_vars[mode] = labels
# step 2. call inner_layer.forward # step 2. call inner_layer.forward
self._output_vars[mode] = self.inner_layer(*inputs) self._output_vars[mode] = self.inner_layer(*inputs)
...@@ -165,6 +170,10 @@ class ProxyLayer(Layer): ...@@ -165,6 +170,10 @@ class ProxyLayer(Layer):
def metric_vars(self): def metric_vars(self):
return self._metric_vars[self.mode] return self._metric_vars[self.mode]
@property
def startup_program(self):
return self.inner_layer._startup_program()
class BuildInfo: class BuildInfo:
...@@ -199,6 +208,7 @@ class ProgramHelper(object): ...@@ -199,6 +208,7 @@ class ProgramHelper(object):
self.build_info = BuildInfo() self.build_info = BuildInfo()
self._logger = get_logger(logging.INFO) self._logger = get_logger(logging.INFO)
self.lazy_init = False
def reset(self): def reset(self):
""" """
...@@ -221,8 +231,7 @@ class ProgramHelper(object): ...@@ -221,8 +231,7 @@ class ProgramHelper(object):
return return
self._logger.info("start to build program for mode = %s." % mode) self._logger.info("start to build program for mode = %s." % mode)
-        input_spec = [self.inputs_spec, self.labels_spec
-                      ] if mode != 'predict' else [self.inputs_spec]
+        input_spec = [self.inputs_spec, self.labels_spec]
static_func = to_static(self.static_func(), input_spec=input_spec) static_func = to_static(self.static_func(), input_spec=input_spec)
func_name = '_' + mode func_name = '_' + mode
...@@ -238,6 +247,9 @@ class ProgramHelper(object): ...@@ -238,6 +247,9 @@ class ProgramHelper(object):
""" """
Create and Sync parameters into startup program. Create and Sync parameters into startup program.
""" """
if len(self.startup_program.global_block().ops) > 1:
self.lazy_init = True
return
for param in self.concrete_program.parameters: for param in self.concrete_program.parameters:
Parameter(name=param.name, Parameter(name=param.name,
desc=param, desc=param,
...@@ -294,6 +306,28 @@ class ProgramHelper(object): ...@@ -294,6 +306,28 @@ class ProgramHelper(object):
func_name = '_' + self.proxy_layer.mode func_name = '_' + self.proxy_layer.mode
return getattr(self.proxy_layer, func_name) return getattr(self.proxy_layer, func_name)
def init(self, main_program, place, dist_context):
if self.lazy_init:
return
for param in self.concrete_program.parameters:
# create var in scope and share parameters to scope
if param.name not in main_program.global_block().vars:
continue
# get param_var's dist_attr
var = main_program.global_block().vars[param.name]
var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
dist_attr = {
"dims_mapping": var_dist_attr.dims_mapping,
"process_shape": var_dist_attr.process_mesh.topology,
"process_group": var_dist_attr.process_mesh.processes
}
# slice param_value with dist_attr
# share sliced_param_value with param_tensor in global_scope
param_tensor = global_scope().var(param.name).get_tensor()
sliced_param = Converter.slice_with_dist_attr(
param.numpy(), dist_attr)
param_tensor.set(sliced_param, place)
@property @property
def concrete_program(self): def concrete_program(self):
return self.static_func().concrete_program return self.static_func().concrete_program
...@@ -304,7 +338,12 @@ class ProgramHelper(object): ...@@ -304,7 +338,12 @@ class ProgramHelper(object):
@property @property
def startup_program(self): def startup_program(self):
try:
return self.proxy_layer.startup_program
except Exception as err:
if isinstance(err, AssertionError):
return self.concrete_program.startup_program return self.concrete_program.startup_program
raise err
@property @property
def input_vars(self): def input_vars(self):
......
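
For reference, a simplified sketch of the slicing that `ProgramHelper.init` delegates to `Converter.slice_with_dist_attr` with the `dist_attr` dict built above. This is not the actual Converter implementation, only an illustration of how `dims_mapping`, `process_shape` and `process_group` determine a rank's shard:

```python
import numpy as np

def slice_param(value, dims_mapping, process_shape, process_group, rank):
    # Split the full parameter along every axis whose dims_mapping entry is
    # not -1, then return the block owned by `rank` in the process mesh.
    index = process_group.index(rank)
    coord = list(np.unravel_index(index, process_shape))
    sliced = value
    for axis, mesh_dim in enumerate(dims_mapping):
        if mesh_dim == -1:
            continue  # replicated along this axis
        parts = np.split(sliced, process_shape[mesh_dim], axis=axis)
        sliced = parts[coord[mesh_dim]]
    return sliced

full = np.arange(32, dtype="float32").reshape(4, 8)
# column-parallel layout on a 1-D mesh of two ranks
print(slice_param(full, [-1, 0], [2], [0, 1], rank=1).shape)  # (4, 4)
```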
...@@ -145,11 +145,6 @@ class Parallelizer: ...@@ -145,11 +145,6 @@ class Parallelizer:
params_grads): params_grads):
# NOTE: `apply_gradients` will add an Accumulator for a parameter only once, # NOTE: `apply_gradients` will add an Accumulator for a parameter only once,
# but optimizer will be called repeatedly in re-launch, so optimizer need to be copied. # but optimizer will be called repeatedly in re-launch, so optimizer need to be copied.
if self._dist_context._dygraph_mode:
paddle.disable_static()
optimizer = copy.deepcopy(optimizer)
paddle.enable_static()
else:
optimizer = copy.deepcopy(optimizer) optimizer = copy.deepcopy(optimizer)
self._dist_context._lr_optimizer = optimizer self._dist_context._lr_optimizer = optimizer
with program_guard(main_program, startup_program): with program_guard(main_program, startup_program):
...@@ -222,6 +217,7 @@ class Parallelizer: ...@@ -222,6 +217,7 @@ class Parallelizer:
config = {} config = {}
config["dist_context"] = self._dist_context config["dist_context"] = self._dist_context
config["global_rank"] = rank config["global_rank"] = rank
config["use_sharding"] = self._strategy.sharding
dp_pass = new_pass("auto_parallel_data_parallel_optimization", config) dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
dp_pass.apply([main_program], [startup_program], self._pass_context) dp_pass.apply([main_program], [startup_program], self._pass_context)
......
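
A small, hypothetical illustration of why `Engine._validate_opt` clears `_parameter_list` and `_param_groups` before the parallelizer deep-copies the optimizer. The class below is only a stand-in, not a Paddle API, and the rationale is an assumption based on this diff:

```python
import copy

class _FakeOptimizer:
    # Stand-in object: a dygraph optimizer keeps references to live parameter
    # tensors, so clearing these fields first (what _validate_opt does) keeps
    # the deepcopy performed per re-launch cheap and safe.
    def __init__(self, parameter_list):
        self._parameter_list = parameter_list
        self._param_groups = parameter_list

opt = _FakeOptimizer(parameter_list=["linear_0.w_0", "linear_0.b_0"])
opt._parameter_list = None   # what _validate_opt does
opt._param_groups = None
opt_copy = copy.deepcopy(opt)  # copied before apply_gradients is re-run
assert opt_copy._parameter_list is None
```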
...@@ -270,7 +270,7 @@ class Inserter: ...@@ -270,7 +270,7 @@ class Inserter:
dtype=tensor_type, dtype=tensor_type,
type=tensor.type, type=tensor.type,
lod_level=tensor.lod_level) lod_level=tensor.lod_level)
-        block._insert_op(idx,
+        cast_op = block._insert_op(idx,
type='cast', type='cast',
inputs={'X': [tensor]}, inputs={'X': [tensor]},
outputs={'Out': [out]}, outputs={'Out': [out]},
...@@ -279,6 +279,7 @@ class Inserter: ...@@ -279,6 +279,7 @@ class Inserter:
'out_dtype': out.dtype, 'out_dtype': out.dtype,
'op_role': op_role 'op_role': op_role
}) })
cast_op._set_attr('op_namescope', "/auto_parallel/reshard")
return out return out
@staticmethod @staticmethod
...@@ -287,7 +288,7 @@ class Inserter: ...@@ -287,7 +288,7 @@ class Inserter:
op_type = 'send_v2' op_type = 'send_v2'
# use pair comm group # use pair comm group
process_group = new_process_group([src, dst]) process_group = new_process_group([src, dst])
-        block._insert_op(idx,
+        send_op = block._insert_op(idx,
type=op_type, type=op_type,
inputs={'X': [tensor]}, inputs={'X': [tensor]},
attrs={ attrs={
...@@ -297,6 +298,7 @@ class Inserter: ...@@ -297,6 +298,7 @@ class Inserter:
'op_role': op_role, 'op_role': op_role,
'dynamic_shape': True 'dynamic_shape': True
}) })
send_op._set_attr('op_namescope', "/auto_parallel/reshard")
@staticmethod @staticmethod
def insert_recv_op(block, idx, tensor, src, dst, op_role): def insert_recv_op(block, idx, tensor, src, dst, op_role):
...@@ -304,7 +306,7 @@ class Inserter: ...@@ -304,7 +306,7 @@ class Inserter:
op_type = 'recv_v2' op_type = 'recv_v2'
# use pair group # use pair group
process_group = new_process_group([src, dst]) process_group = new_process_group([src, dst])
-        block._insert_op(idx,
+        recv_op = block._insert_op(idx,
type=op_type, type=op_type,
inputs={'X': [tensor]}, inputs={'X': [tensor]},
outputs={'Out': [tensor]}, outputs={'Out': [tensor]},
...@@ -317,6 +319,7 @@ class Inserter: ...@@ -317,6 +319,7 @@ class Inserter:
'op_role': op_role, 'op_role': op_role,
'dynamic_shape': True 'dynamic_shape': True
}) })
recv_op._set_attr('op_namescope', "/auto_parallel/reshard")
@staticmethod @staticmethod
def insert_reset_lod_op(block, idx, X, Y, op_role): def insert_reset_lod_op(block, idx, X, Y, op_role):
...@@ -330,7 +333,7 @@ class Inserter: ...@@ -330,7 +333,7 @@ class Inserter:
dtype=X.dtype, dtype=X.dtype,
lod_level=X.lod_level) lod_level=X.lod_level)
-        block._insert_op(idx,
+        reset_op = block._insert_op(idx,
type="lod_reset", type="lod_reset",
inputs={ inputs={
'X': X, 'X': X,
...@@ -338,6 +341,7 @@ class Inserter: ...@@ -338,6 +341,7 @@ class Inserter:
}, },
outputs={'Out': reset_lod_out}, outputs={'Out': reset_lod_out},
attrs={'op_role': op_role}) attrs={'op_role': op_role})
reset_op._set_attr('op_namescope', "/auto_parallel/reshard")
return reset_lod_out return reset_lod_out
@staticmethod @staticmethod
...@@ -359,11 +363,12 @@ class Inserter: ...@@ -359,11 +363,12 @@ class Inserter:
type=tensors[0].type, type=tensors[0].type,
persistable=False, persistable=False,
stop_gradient=False) stop_gradient=False)
-        block._insert_op(idx,
+        concat_op = block._insert_op(idx,
type='concat', type='concat',
inputs=inputs, inputs=inputs,
outputs={'Out': [out]}, outputs={'Out': [out]},
attrs=attrs) attrs=attrs)
concat_op._set_attr('op_namescope', "/auto_parallel/reshard")
return out return out
@staticmethod @staticmethod
...@@ -391,11 +396,12 @@ class Inserter: ...@@ -391,11 +396,12 @@ class Inserter:
inputs = {'X': [tensor]} inputs = {'X': [tensor]}
outputs = {"Out": [out]} outputs = {"Out": [out]}
attrs = {"in_place": False} attrs = {"in_place": False}
-            block._insert_op(idx,
+            slice_op = block._insert_op(idx,
type="assign", type="assign",
inputs=inputs, inputs=inputs,
outputs=outputs, outputs=outputs,
attrs=attrs) attrs=attrs)
slice_op._set_attr('op_namescope', "/auto_parallel/reshard")
return out return out
# use split once # use split once
...@@ -427,11 +433,12 @@ class Inserter: ...@@ -427,11 +433,12 @@ class Inserter:
for i in range(num_or_sections) for i in range(num_or_sections)
] ]
out = outs[cur_idx] out = outs[cur_idx]
-            op = block._insert_op(idx,
+            split_op = block._insert_op(idx,
type="split", type="split",
inputs=inputs, inputs=inputs,
outputs={'Out': outs}, outputs={'Out': outs},
attrs=attrs) attrs=attrs)
split_op._set_attr('op_namescope', "/auto_parallel/reshard")
return out return out
# use slice # use slice
...@@ -449,12 +456,12 @@ class Inserter: ...@@ -449,12 +456,12 @@ class Inserter:
dtype=tensor.dtype, dtype=tensor.dtype,
type=tensor.type, type=tensor.type,
lod_level=tensor.lod_level) lod_level=tensor.lod_level)
-        block._insert_op(idx,
+        slice_op = block._insert_op(idx,
type="slice", type="slice",
inputs=inputs, inputs=inputs,
outputs={'Out': [out]}, outputs={'Out': [out]},
attrs=attrs) attrs=attrs)
slice_op._set_attr('op_namescope', "/auto_parallel/reshard")
return out return out
@staticmethod @staticmethod
...@@ -482,11 +489,12 @@ class Inserter: ...@@ -482,11 +489,12 @@ class Inserter:
persistable=False, persistable=False,
stop_gradient=False) for i in range(num_or_sections) stop_gradient=False) for i in range(num_or_sections)
] ]
-        block._insert_op(idx,
+        split_op = block._insert_op(idx,
type="split", type="split",
inputs=inputs, inputs=inputs,
outputs={'Out': outs}, outputs={'Out': outs},
attrs=attrs) attrs=attrs)
split_op._set_attr('op_namescope', "/auto_parallel/reshard")
return outs return outs
@staticmethod @staticmethod
...@@ -514,12 +522,13 @@ class Inserter: ...@@ -514,12 +522,13 @@ class Inserter:
attrs=attrs, attrs=attrs,
shape=[0], shape=[0],
op_type='fill_constant') op_type='fill_constant')
-        block._insert_op(idx,
+        fillconstant_op = block._insert_op(idx,
type='fill_constant', type='fill_constant',
inputs=inputs, inputs=inputs,
outputs={'Out': [out]}, outputs={'Out': [out]},
attrs=attrs) attrs=attrs)
out.stop_gradient = True out.stop_gradient = True
fillconstant_op._set_attr('op_namescope', "/auto_parallel/reshard")
return out return out
@staticmethod @staticmethod
...@@ -537,7 +546,8 @@ class Inserter: ...@@ -537,7 +546,8 @@ class Inserter:
fill_constant_out.stop_gradient = True fill_constant_out.stop_gradient = True
# insert c_allreduce_sum op # insert c_allreduce_sum op
-        block._insert_op(idx + 1,
+        allreduce_op = block._insert_op(
+            idx + 1,
type="c_allreduce_sum", type="c_allreduce_sum",
inputs={'X': [fill_constant_out]}, inputs={'X': [fill_constant_out]},
outputs={'Out': [fill_constant_out]}, outputs={'Out': [fill_constant_out]},
...@@ -546,13 +556,15 @@ class Inserter: ...@@ -546,13 +556,15 @@ class Inserter:
'use_calc_stream': True, 'use_calc_stream': True,
'op_role': op_role 'op_role': op_role
}) })
allreduce_op._set_attr('op_namescope', "/auto_parallel/reshard")
# insert c_sync_calc_stream op # insert c_sync_calc_stream op
-        block._insert_op(idx + 2,
+        sync_calc_op = block._insert_op(
+            idx + 2,
type="c_sync_calc_stream", type="c_sync_calc_stream",
inputs={'X': [fill_constant_out]}, inputs={'X': [fill_constant_out]},
outputs={'Out': [fill_constant_out]}, outputs={'Out': [fill_constant_out]},
attrs={'op_role': op_role}) attrs={'op_role': op_role})
sync_calc_op._set_attr('op_namescope', "/auto_parallel/reshard")
idx_offset = 3 idx_offset = 3
# insert c_allgather op # insert c_allgather op
...@@ -569,7 +581,7 @@ class Inserter: ...@@ -569,7 +581,7 @@ class Inserter:
type=tensor.type, type=tensor.type,
persistable=False, persistable=False,
stop_gradient=False) stop_gradient=False)
-            block._insert_op(idx + idx_offset,
+            allgather_op = block._insert_op(idx + idx_offset,
type=op_type, type=op_type,
inputs={'X': [tensor]}, inputs={'X': [tensor]},
outputs={'Out': [allgather_out]}, outputs={'Out': [allgather_out]},
...@@ -579,6 +591,7 @@ class Inserter: ...@@ -579,6 +591,7 @@ class Inserter:
'nranks': group.nranks, 'nranks': group.nranks,
'op_role': op_role 'op_role': op_role
}) })
allgather_op._set_attr('op_namescope', "/auto_parallel/reshard")
idx_offset += 1 idx_offset += 1
# insert split op # insert split op
......
...@@ -26,6 +26,7 @@ from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_fp32_input, _k ...@@ -26,6 +26,7 @@ from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_fp32_input, _k
from paddle.fluid.contrib.mixed_precision.fp16_utils import _valid_types, find_true_post_op, find_true_prev_op from paddle.fluid.contrib.mixed_precision.fp16_utils import _valid_types, find_true_post_op, find_true_prev_op
from paddle.fluid.contrib.mixed_precision.fp16_utils import _is_in_black_varnames, _dtype_to_str, _rename_arg from paddle.fluid.contrib.mixed_precision.fp16_utils import _is_in_black_varnames, _dtype_to_str, _rename_arg
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
from ..auto_parallel.utils import is_forward_op, is_backward_op, is_loss_op
world_process_group = get_world_process_group() world_process_group = get_world_process_group()
...@@ -222,21 +223,33 @@ class AMPState(object): ...@@ -222,21 +223,33 @@ class AMPState(object):
loss_op = get_loss_op(self._block) loss_op = get_loss_op(self._block)
loss_op_index = find_op_index(self._block.desc, loss_op.desc) loss_op_index = find_op_index(self._block.desc, loss_op.desc)
appended_grad_times = 0
idx = loss_op_index + 1 idx = loss_op_index + 1
while idx < len(ops): while idx < len(ops):
num_cast_ops = 0 num_cast_ops = 0
grad_op = ops[idx] grad_op = ops[idx]
# NOTE: the map in `grad_var_to_var` may be changed when the var is casted,
# which will affect the dist_op to insert allreduce_sum op.
op_dist_attr = dist_context.get_op_dist_attr_for_program(grad_op)
if is_backward_op(grad_op) and (is_forward_op(ops[idx - 1])
or is_loss_op(ops[idx - 1])):
if not op_dist_attr.is_recompute:
appended_grad_times += 1
grad_op_orig_id = grad_op.desc.original_id() grad_op_orig_id = grad_op.desc.original_id()
dist_op_context = dist_context.dist_op_context dist_op_context = dist_context.dist_op_context
if grad_op_orig_id in dist_op_context.grad_op_id_to_op_id: if grad_op_orig_id in dist_op_context.grad_op_id_to_op_id:
if self._is_fp16_op(grad_op_orig_id) == False: # fp32 if self._is_fp16_op(grad_op_orig_id) == False: # fp32
num_cast_ops = self._insert_cast_op_backward( num_cast_ops = self._insert_cast_op_backward(
grad_op, idx, core.VarDesc.VarType.FP16, grad_op, idx, core.VarDesc.VarType.FP16,
-                        core.VarDesc.VarType.FP32, dist_context)
+                        core.VarDesc.VarType.FP32, dist_context,
+                        appended_grad_times)
elif self._is_fp16_op(grad_op_orig_id) == True: # fp16 elif self._is_fp16_op(grad_op_orig_id) == True: # fp16
num_cast_ops = self._insert_cast_op_backward( num_cast_ops = self._insert_cast_op_backward(
grad_op, idx, core.VarDesc.VarType.FP32, grad_op, idx, core.VarDesc.VarType.FP32,
-                        core.VarDesc.VarType.FP16, dist_context)
+                        core.VarDesc.VarType.FP16, dist_context,
+                        appended_grad_times)
elif grad_op.type == "sum": elif grad_op.type == "sum":
in_var_name = grad_op.desc.input_arg_names()[0] in_var_name = grad_op.desc.input_arg_names()[0]
src_dtype = self._block.var(in_var_name).dtype src_dtype = self._block.var(in_var_name).dtype
...@@ -258,7 +271,7 @@ class AMPState(object): ...@@ -258,7 +271,7 @@ class AMPState(object):
_update_backward_cast_ops(params_grads, dist_context) _update_backward_cast_ops(params_grads, dist_context)
def _insert_cast_op_backward(self, grad_op, idx, src_dtype, dst_dtype, def _insert_cast_op_backward(self, grad_op, idx, src_dtype, dst_dtype,
-                                 dist_context):
+                                 dist_context, appended_grad_times):
""" only for backward cast """ """ only for backward cast """
def _keep_fp32_input(op, in_name): def _keep_fp32_input(op, in_name):
...@@ -328,7 +341,10 @@ class AMPState(object): ...@@ -328,7 +341,10 @@ class AMPState(object):
grad_op) grad_op)
fwd_cast_name = self._var_name_dict[fwd_op_id][ fwd_cast_name = self._var_name_dict[fwd_op_id][
out_var_name_prefix] out_var_name_prefix]
-                    cast_name = fwd_cast_name + "@GRAD"
+                    suffix = ""
+                    if "@RENAME" in out_var_name:
+                        suffix = out_var_name[out_var_name.find("@RENAME"):]
+                    cast_name = fwd_cast_name + "@GRAD" + suffix
cast_var = self._block.vars.get(cast_name) cast_var = self._block.vars.get(cast_name)
if cast_var is None or cast_var.dtype != dst_dtype: if cast_var is None or cast_var.dtype != dst_dtype:
grad_op.desc._rename_output(out_var_name, cast_name) grad_op.desc._rename_output(out_var_name, cast_name)
...@@ -347,6 +363,8 @@ class AMPState(object): ...@@ -347,6 +363,8 @@ class AMPState(object):
stop_gradient=out_var.stop_gradient) stop_gradient=out_var.stop_gradient)
set_var_dist_attr(dist_context, cast_var, set_var_dist_attr(dist_context, cast_var,
ref_mapping, ref_mesh) ref_mapping, ref_mesh)
dist_op_context.grad_var_to_var[
appended_grad_times][cast_name] = fwd_cast_name
cast_op = self._block._insert_op( cast_op = self._block._insert_op(
idx + 1, idx + 1,
......
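
The `@RENAME` handling added above can be summarized by the following small, illustrative helper (not part of the patch). It shows how the suffix keeps cast gradient names distinct when one forward cast output has several renamed gradients, using the `var@GRAD@RENAME@k` naming that the sharding note later in this commit also refers to:

```python
def backward_cast_name(fwd_cast_name, out_var_name):
    # Keep any @RENAME@... decoration of the backward output so gradients of
    # the same forward cast var, written by different ops, get distinct names.
    suffix = ""
    if "@RENAME" in out_var_name:
        suffix = out_var_name[out_var_name.find("@RENAME"):]
    return fwd_cast_name + "@GRAD" + suffix

assert backward_cast_name("cast_fp16_x", "x@GRAD") == "cast_fp16_x@GRAD"
assert backward_cast_name("cast_fp16_x", "x@GRAD@RENAME@0") \
    == "cast_fp16_x@GRAD@RENAME@0"
```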
...@@ -45,6 +45,7 @@ class DataParallelOptimizationPass(PassBase): ...@@ -45,6 +45,7 @@ class DataParallelOptimizationPass(PassBase):
# NOTE not use depence on loss and param_grads # NOTE not use depence on loss and param_grads
self.set_attr("dist_context", None) self.set_attr("dist_context", None)
self.set_attr("global_rank", -1) self.set_attr("global_rank", -1)
self.set_attr("use_sharding", False)
# {grad1: group1, grad2: group1, grad3: group2} # {grad1: group1, grad2: group1, grad3: group2}
# record the order for fuse grad data memory # record the order for fuse grad data memory
self._grad_name_to_group_map = OrderedDict() self._grad_name_to_group_map = OrderedDict()
...@@ -71,6 +72,7 @@ class DataParallelOptimizationPass(PassBase): ...@@ -71,6 +72,7 @@ class DataParallelOptimizationPass(PassBase):
self.dist_context = self.get_attr("dist_context") self.dist_context = self.get_attr("dist_context")
self.global_rank = int(self.get_attr("global_rank")) self.global_rank = int(self.get_attr("global_rank"))
self.use_sharding = self.get_attr("use_sharding")
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
self._analyze_program() self._analyze_program()
...@@ -224,7 +226,8 @@ class DataParallelOptimizationPass(PassBase): ...@@ -224,7 +226,8 @@ class DataParallelOptimizationPass(PassBase):
num_dp_comm_stream = len(set(self._group_to_grad_name_map.keys())) num_dp_comm_stream = len(set(self._group_to_grad_name_map.keys()))
if num_dp_comm_stream > __max_stream_num_allow__: if num_dp_comm_stream > __max_stream_num_allow__:
return False return False
if self.use_sharding:
return False
return True return True
def _comms_overlap_calc(self): def _comms_overlap_calc(self):
......
...@@ -151,13 +151,14 @@ class RecomputeState(ProgramStats): ...@@ -151,13 +151,14 @@ class RecomputeState(ProgramStats):
# modify dropout op's desc # modify dropout op's desc
self._ops.insert(op_idx, seed_op) self._ops.insert(op_idx, seed_op)
cur_op.desc.set_input("Seed", [var_unique_name]) cur_op.desc.set_input("Seed", [var_unique_name])
cur_op.desc.remove_attr("fix_seed") cur_op._remove_attr("fix_seed")
cur_op.desc.remove_attr("seed") cur_op._remove_attr("seed")
cur_op_dist_attr.set_input_dist_attr(seed_var.name, cur_op_dist_attr.set_input_dist_attr(seed_var.name,
seed_var_dist_attr) seed_var_dist_attr)
self._block._sync_with_cpp()
op_idx += 2 op_idx += 2
self._block._sync_with_cpp()
def _find_op_index(block, cur_op): def _find_op_index(block, cur_op):
for idx in range(block.desc.op_size()): for idx in range(block.desc.op_size()):
...@@ -339,12 +340,13 @@ class RecomputePass(PassBase): ...@@ -339,12 +340,13 @@ class RecomputePass(PassBase):
grad_op = ops[i] grad_op = ops[i]
# remove some attrs of dropout_grad op's desc # remove some attrs of dropout_grad op's desc
if grad_op.type == "dropout_grad": if grad_op.type == "dropout_grad":
grad_op.desc.remove_attr("fix_seed") grad_op._remove_attr("fix_seed")
grad_op.desc.remove_attr("seed") grad_op._remove_attr("seed")
main_block._sync_with_cpp()
# rename grad op's var_name which is not in 'vars_in_memory' # rename grad op's var_name which is not in 'vars_in_memory'
for key in var_name_dict: for key in var_name_dict:
if key not in grad_op.input_arg_names + grad_op.output_arg_names:
continue
self.reset_op_dist_attr(grad_op, var_name_dict) self.reset_op_dist_attr(grad_op, var_name_dict)
_rename_arg_([grad_op.desc], key, var_name_dict[key]) _rename_arg_([grad_op.desc], key, var_name_dict[key])
...@@ -358,11 +360,11 @@ class RecomputePass(PassBase): ...@@ -358,11 +360,11 @@ class RecomputePass(PassBase):
idx -= 1 idx -= 1
segment_descs = ckpt_ops_dict[fwd_op_id][1] segment_descs = ckpt_ops_dict[fwd_op_id][1]
for _, op_desc in reversed(list(enumerate(segment_descs))): for _, op_desc in reversed(list(enumerate(segment_descs))):
-                rc_desc = main_block.desc._insert_op(idx)
+                rc_op = main_block._insert_op_without_sync(idx,
+                                                           type='nop')
+                rc_desc = rc_op.desc
                 rc_desc.copy_from(op_desc)
                 rc_desc.set_original_id(rc_desc.id())
-                rc_op = Operator(main_block, rc_desc)
-                main_block.ops.insert(idx, rc_op)
# set recomputed ops' dist attr # set recomputed ops' dist attr
fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id( fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id(
op_desc.original_id()) op_desc.original_id())
...@@ -371,7 +373,6 @@ class RecomputePass(PassBase): ...@@ -371,7 +373,6 @@ class RecomputePass(PassBase):
var_name_dict) var_name_dict)
ckpt_ops_dict[fwd_op_id][0] = False ckpt_ops_dict[fwd_op_id][0] = False
main_block._sync_with_cpp()
main_program._sync_with_cpp() main_program._sync_with_cpp()
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from functools import reduce from functools import reduce
from collections import OrderedDict from collections import OrderedDict, defaultdict
import numpy as np import numpy as np
import paddle import paddle
...@@ -27,10 +27,7 @@ from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_di ...@@ -27,10 +27,7 @@ from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_di
OpRole = core.op_proto_and_checker_maker.OpRole OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
_skip_ops = [ _skip_ops = ['create_py_reader', 'create_double_buffer_reader', 'read']
'create_py_reader', 'create_double_buffer_reader', 'read', 'slice', 'split',
'assign', "send_v2"
]
# update here to support new optimizers # update here to support new optimizers
_supported_optimizer_type = [ _supported_optimizer_type = [
"adam", "adamax", "adamw", "decayed_adagrad", "momentum", "dgc_momentum", "adam", "adamax", "adamw", "decayed_adagrad", "momentum", "dgc_momentum",
...@@ -38,6 +35,11 @@ _supported_optimizer_type = [ ...@@ -38,6 +35,11 @@ _supported_optimizer_type = [
] ]
def _is_reshard_op(op):
return op.desc.has_attr("op_namescope") and \
"/auto_parallel/reshard" in op.desc.attr('op_namescope')
# NOTE we add the "auto_parallel" prefix to the pass in order to # NOTE we add the "auto_parallel" prefix to the pass in order to
# indicate that this pass should obey some constrains by auto_parallel # indicate that this pass should obey some constrains by auto_parallel
# for example all ops and vars should has dist attr before and after pass # for example all ops and vars should has dist attr before and after pass
...@@ -100,6 +102,10 @@ class ShardingPass(PassBase): ...@@ -100,6 +102,10 @@ class ShardingPass(PassBase):
for op in main_block.ops: for op in main_block.ops:
if not _is_forward_op(op) or op.type in _skip_ops: if not _is_forward_op(op) or op.type in _skip_ops:
continue continue
# NOTE: there aren't dist_attr in the ops which reshard insert,
# and should be skip in sharding.
if _is_reshard_op(op):
continue
group = _inference_data_parallel_group_for_operator( group = _inference_data_parallel_group_for_operator(
self.global_rank, op, self._dist_context) self.global_rank, op, self._dist_context)
if group is not None: if group is not None:
...@@ -187,8 +193,28 @@ class ShardingPass(PassBase): ...@@ -187,8 +193,28 @@ class ShardingPass(PassBase):
if self._is_parameter_in_local_shard(param_name): if self._is_parameter_in_local_shard(param_name):
reversed_x.append(input_name) reversed_x.append(input_name)
# NOTE: When `reversed_x` is [], check_finite_and_unscale will be replaced by `fill_constant` op.
            # The output of check_finite_and_unscale is set to False.
if reversed_x:
op.desc.set_input('X', reversed_x) op.desc.set_input('X', reversed_x)
op.desc.set_output('Out', reversed_x) op.desc.set_output('Out', reversed_x)
else:
if op.type == "check_finite_and_unscale":
out_name = op.output_arg_names[0]
out_var = main_block.vars[out_name]
main_block._remove_op(idx, sync=False)
main_block._insert_op_without_sync(
idx,
type="fill_constant",
outputs={"Out": out_var},
attrs={
"shape": out_var.shape,
"dtype": out_var.dtype,
"value": 0,
})
else:
main_block._remove_op(idx, sync=False)
main_block._sync_with_cpp() main_block._sync_with_cpp()
...@@ -359,6 +385,17 @@ class ShardingPass(PassBase): ...@@ -359,6 +385,17 @@ class ShardingPass(PassBase):
else: else:
op._set_attr("ring_id", self.outer_dp_group.id) op._set_attr("ring_id", self.outer_dp_group.id)
        # NOTE:
        # var@GRAD = sum(var@GRAD@RENAME@0, var@GRAD@RENAME@1)
        # If the var is not in the local shard and it is the output of many ops
        # (in other words, it has been renamed), the sum op should be removed.
if _is_param_grad_sum_op(op, main_block):
out_name = op.output_arg_names[0]
base_name = _get_base_name_from_grad_name(out_name)
sharding_info = self.varname_to_sharding_info[base_name]
if not sharding_info.is_in_local_shard(base_name):
main_block._remove_op(idx, sync=False)
main_block._sync_with_cpp() main_block._sync_with_cpp()
def _shard_parameter(self, main_block, startup_block): def _shard_parameter(self, main_block, startup_block):
...@@ -606,6 +643,22 @@ def _is_param_grad_allreduce_op(op, block, dp_ring_ids): ...@@ -606,6 +643,22 @@ def _is_param_grad_allreduce_op(op, block, dp_ring_ids):
return block.var(base_name).is_parameter return block.var(base_name).is_parameter
def _is_param_grad_sum_op(op, block):
if not is_backward_op(op):
return False
if op.type != "sum":
return False
output_name = op.output_arg_names[0]
base_name = _get_base_name_from_grad_name(output_name)
if not block.has_var(base_name):
return False
return block.var(base_name).is_parameter
def _is_forward_op(op): def _is_forward_op(op):
return op.attr("op_role") == 0 return op.attr("op_role") == 0
......
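
To make the sum-removal note above concrete, here is an illustrative stand-in for `_get_base_name_from_grad_name` (that helper exists in the pass but is not shown in this diff, so this is only a sketch of the expected behavior used by `_is_param_grad_sum_op`):

```python
def get_base_name_from_grad_name(grad_name):
    # Strip @GRAD and any @RENAME@k suffix to recover the parameter name that
    # _is_param_grad_sum_op looks up in the block.
    if "@GRAD" in grad_name:
        return grad_name[:grad_name.find("@GRAD")]
    return grad_name

assert get_base_name_from_grad_name("linear_0.w_0@GRAD@RENAME@0") == "linear_0.w_0"
assert get_base_name_from_grad_name("linear_0.w_0@GRAD") == "linear_0.w_0"
```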
...@@ -33,7 +33,7 @@ import paddle.distributed.auto_parallel as auto ...@@ -33,7 +33,7 @@ import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine from paddle.distributed.auto_parallel.engine import Engine
paddle.enable_static() paddle.enable_static()
-batch_size = 1
+batch_size = 2
batch_num = 10 batch_num = 10
hidden_size = 1024 hidden_size = 1024
sequence_len = 512 sequence_len = 512
...@@ -133,10 +133,7 @@ def train(fetch): ...@@ -133,10 +133,7 @@ def train(fetch):
# train # train
train_dataset = MyDataset(batch_num * batch_size) train_dataset = MyDataset(batch_num * batch_size)
-    engine.fit(train_dataset,
-               batch_size=batch_size,
-               steps_per_epoch=batch_num * batch_size,
-               fetches=fetches)
+    engine.fit(train_dataset, batch_size=batch_size, fetches=fetches)
# eval # eval
eval_dataset = MyDataset(batch_size) eval_dataset = MyDataset(batch_size)
......
...@@ -67,15 +67,14 @@ def create_data_holder(batch_size): ...@@ -67,15 +67,14 @@ def create_data_holder(batch_size):
def generate_model(strategy): def generate_model(strategy):
modeling.init_global() modeling.init_global()
modeling._global_process_mesh = list(
range(paddle.distributed.get_world_size()))
if strategy == "serial": if strategy == "serial":
modeling._global_parallel_strategy = "serial" modeling._global_parallel_strategy = "serial"
modeling._global_process_mesh = [0]
elif strategy == "mp": elif strategy == "mp":
modeling._global_parallel_strategy = "mp" modeling._global_parallel_strategy = "mp"
modeling._global_process_mesh = [0, 1]
elif strategy == "dp": elif strategy == "dp":
modeling._global_parallel_strategy = "dp" modeling._global_parallel_strategy = "dp"
modeling._global_process_mesh = [0, 1]
else: else:
raise ValueError("Only support serial, mp2 and dp2.") raise ValueError("Only support serial, mp2 and dp2.")
......
...@@ -27,7 +27,6 @@ from paddle.io import Dataset ...@@ -27,7 +27,6 @@ from paddle.io import Dataset
from paddle.static import InputSpec from paddle.static import InputSpec
from paddle.fluid.framework import _non_static_mode from paddle.fluid.framework import _non_static_mode
from paddle.distributed.auto_parallel.engine import Engine from paddle.distributed.auto_parallel.engine import Engine
from paddle.distributed.auto_parallel.hepler import ProgramHelper
from test_to_static import MLPLayer, MyDataset from test_to_static import MLPLayer, MyDataset
......
...@@ -23,11 +23,12 @@ import paddle.nn.functional as F ...@@ -23,11 +23,12 @@ import paddle.nn.functional as F
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
import paddle.distributed.fleet as fleet import paddle.distributed.fleet as fleet
from paddle import LazyGuard
from paddle.io import Dataset from paddle.io import Dataset
from paddle.static import InputSpec from paddle.static import InputSpec
from paddle.fluid.framework import _non_static_mode from paddle.fluid.framework import _non_static_mode
from paddle.distributed.auto_parallel.engine import Engine from paddle.distributed.auto_parallel.engine import Engine
-from paddle.distributed.auto_parallel.hepler import ProgramHelper
+from paddle.distributed.auto_parallel.helper import ProgramHelper
batch_size = 4 batch_size = 4
batch_num = 30 batch_num = 30
...@@ -158,5 +159,29 @@ class TestToStatic(unittest.TestCase): ...@@ -158,5 +159,29 @@ class TestToStatic(unittest.TestCase):
engine.predict(dataset, batch_size=batch_size) engine.predict(dataset, batch_size=batch_size)
class TestLazyInit(unittest.TestCase):
def test_lazy_init(self):
with LazyGuard():
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
metrics = paddle.metric.Accuracy()
loss = paddle.nn.CrossEntropyLoss()
inputs = InputSpec([batch_size, hidden_size], 'float32', 'x')
labels = InputSpec([batch_size], 'int64', 'label')
program_helper = ProgramHelper(mlp, loss, [metrics], [inputs], [labels])
program_helper.build_program(mode='train')
ops = program_helper.startup_program.block(0).ops
vars = program_helper.startup_program.block(0).vars
assert len(vars.keys()) == len(ops)
program_helper.reset()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -914,12 +914,6 @@ class GPTForPretraining(nn.Layer): ...@@ -914,12 +914,6 @@ class GPTForPretraining(nn.Layer):
initializer_range=0.02, initializer_range=0.02,
): ):
super(GPTForPretraining, self).__init__() super(GPTForPretraining, self).__init__()
self.output_embeddings = nn.Embedding(
vocab_size,
hidden_size,
weight_attr=paddle.ParamAttr(name="output_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)))
self.gpt = gpt self.gpt = gpt
def forward(self, def forward(self,
...@@ -938,9 +932,45 @@ class GPTForPretraining(nn.Layer): ...@@ -938,9 +932,45 @@ class GPTForPretraining(nn.Layer):
encoder_outputs, cached_kvs = outputs[:2] encoder_outputs, cached_kvs = outputs[:2]
else: else:
encoder_outputs = outputs encoder_outputs = outputs
-        logits = paddle.matmul(encoder_outputs,
-                               self.output_embeddings.weight,
-                               transpose_y=True)
+        x = encoder_outputs
+        w = self.gpt.embeddings.word_embeddings.weight
mesh = _global_process_mesh
x_dims_mapping = [-1 for i in range(len(x.shape))]
w_dims_mapping = [-1 for i in range(len(w.shape))]
if _global_parallel_strategy == "pp":
mesh = PP_MESH_LIST[-1]
elif _global_parallel_strategy == "dp":
x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)]
elif _global_parallel_strategy == "mp":
w_dims_mapping = [0] + [-1 for i in range(len(w.shape) - 1)]
elif _global_parallel_strategy == "dp_mp":
x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)]
w_dims_mapping = [1] + [-1 for i in range(len(w.shape) - 1)]
elif _global_parallel_strategy == "dp_pp":
mesh = DPPP_MESH_LIST[-1]
x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)]
elif _global_parallel_strategy == "mp_pp":
mesh = MPPP_MESH_LIST[-1]
w_dims_mapping = [0] + [-1 for i in range(len(w.shape) - 1)]
elif _global_parallel_strategy == "dp_mp_pp":
mesh = DPMPPP_MESH_LIST[-1]
x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)]
w_dims_mapping = [1] + [-1 for i in range(len(w.shape) - 1)]
matmul = auto.shard_op(paddle.matmul,
dist_attr={
'process_mesh': mesh,
x: {
"dims_mapping": x_dims_mapping
},
w: {
"dims_mapping": w_dims_mapping
}
})
logits = matmul(x, w, transpose_y=True)
if use_cache: if use_cache:
return logits, cached_kvs return logits, cached_kvs
else: else:
...@@ -958,6 +988,26 @@ class GPTPretrainingCriterion(nn.Layer): ...@@ -958,6 +988,26 @@ class GPTPretrainingCriterion(nn.Layer):
self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none") self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none")
def forward(self, prediction_scores, masked_lm_labels, loss_mask): def forward(self, prediction_scores, masked_lm_labels, loss_mask):
mesh = _global_process_mesh
dims_mapping = [-1 for i in range(len(loss_mask.shape))]
if _global_parallel_strategy == "dp":
dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)]
elif _global_parallel_strategy == "dp_mp":
dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)]
elif _global_parallel_strategy == "dp_pp":
mesh = DPPP_MESH_LIST[-1]
dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)]
elif _global_parallel_strategy == "dp_mp_pp":
mesh = DPMPPP_MESH_LIST[-1]
dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)]
auto.shard_tensor(loss_mask,
dist_attr={
"process_mesh": mesh,
"dims_mapping": dims_mapping
})
masked_lm_loss = self.loss_func(prediction_scores, masked_lm_loss = self.loss_func(prediction_scores,
masked_lm_labels.unsqueeze(2)) masked_lm_labels.unsqueeze(2))
loss_mask = loss_mask.reshape([-1]) loss_mask = loss_mask.reshape([-1])
......
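
As a rough aid for reading the `dims_mapping` values chosen above, the following illustrative helper (not part of the patch) computes the per-rank shape implied by a mapping on a given process-mesh topology; `-1` means the axis is replicated:

```python
def local_shape(global_shape, dims_mapping, mesh_topology):
    # Each tensor axis mapped to a mesh dimension is divided by that mesh
    # dimension's size; axes mapped to -1 stay whole on every rank.
    shape = list(global_shape)
    for axis, mesh_dim in enumerate(dims_mapping):
        if mesh_dim != -1:
            assert shape[axis] % mesh_topology[mesh_dim] == 0
            shape[axis] //= mesh_topology[mesh_dim]
    return shape

# "dp_mp" on a 2x2 mesh: batch split over dp (mesh dim 0),
# vocab rows of the shared word embedding split over mp (mesh dim 1)
print(local_shape([8, 512, 1024], [0, -1, -1], [2, 2]))  # [4, 512, 1024]
print(local_shape([50304, 1024], [1, -1], [2, 2]))       # [25152, 1024]
```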
...@@ -178,7 +178,6 @@ class AutoPallelPassTestBase(DistPassTestBase): ...@@ -178,7 +178,6 @@ class AutoPallelPassTestBase(DistPassTestBase):
preds = model(tokens, position_ids, attention_mask) preds = model(tokens, position_ids, attention_mask)
criterion = GPTPretrainingCriterion() criterion = GPTPretrainingCriterion()
loss = criterion(preds, labels, loss_mask) loss = criterion(preds, labels, loss_mask)
clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
if kwargs.get('optimizer', None) == "LarsMomentum": if kwargs.get('optimizer', None) == "LarsMomentum":
optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer( optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer(
...@@ -189,7 +188,7 @@ class AutoPallelPassTestBase(DistPassTestBase): ...@@ -189,7 +188,7 @@ class AutoPallelPassTestBase(DistPassTestBase):
beta1=0.9, beta1=0.9,
beta2=0.999, beta2=0.999,
epsilon=1e-08, epsilon=1e-08,
-                grad_clip=clip)
+                grad_clip=None)
optimizer = fleet.distributed_optimizer(optimizer) optimizer = fleet.distributed_optimizer(optimizer)
startup_program = paddle.static.default_startup_program() startup_program = paddle.static.default_startup_program()
_, _, dist_startup_prog, dist_main_prog = optimizer.minimize( _, _, dist_startup_prog, dist_main_prog = optimizer.minimize(
......