Unverified commit c043a21b, authored by zhaoyingli, committed by GitHub

[AutoParallel] adapt for 2d laplace (#41601)

* add default_ctx in backward.py

* record grad_var_to_var with grad_times

* fix backward

* update annotation

* add complete_high_order_grad in complete_forward

* add dist slice op

* update grad_var_to_var type

* update partition_block init mapping before loss op

* update compatible for 'XShape' & update 'allreduce_vars'

* add dist reshape op when input dim equal to output dim

* update 'set_grad_var_shape' with grad_var_to_var

* fix dist slice

* fix set_grad_var_shape

* add dist pnorm op

* fix dist pnorm dist_attr

* fix engine startprogram & adapt highorder grad

* fix set_grad_var_shape when mp

* update unittest

* update cmakelist

* default strategy in engine: dp

* bug fix

* tiny fix

* flatten outputs

* fix default strategy

* init default ctx

* tiny fix

* test=allcase
Parent 2d29d833
......@@ -648,6 +648,9 @@ class Completer:
self._dist_context.copy_dist_attr_from_graph_to_program()
self._dist_context.clear_dist_info_for_graph()
# NOTE: [HighOrderGrad] update the distributed attributes of vars and ops for high-order gradients
self.complete_high_order_grad_annotation(serial_main_program)
# Do the validation check and amend some completion
self._dist_context.amend_dist_attr_for_program()
......@@ -655,6 +658,164 @@ class Completer:
return serial_main_program
def complete_high_order_grad_annotation(self, serial_main_program):
"""
NOTE:
[HighOrderGrad] Complete the annotation of vars and ops only for high-order gradients.
This function is a temporary workaround for high-order gradients and will be removed in the future.
"""
def _is_grad_var_name(name):
if "@GRAD" in name:
return True
return False
def _get_op_by_id(ops, id):
for op in ops:
if op.desc.id() == id:
return op
return None
ops = list(serial_main_program.global_block().ops)
vars = serial_main_program.global_block().vars
dist_op_context = self._dist_context.dist_op_context
grad_var_to_var = dist_op_context.grad_var_to_var
appended_grad_times = 0
for idx in range(0, len(ops)):
op = ops[idx]
if int(op.attr('op_role')) == int(
core.op_proto_and_checker_maker.OpRole.Forward):
continue
if int(op.attr('op_role')) == int(
core.op_proto_and_checker_maker.OpRole.Backward) and int(
ops[idx - 1].attr('op_role')) == int(
core.op_proto_and_checker_maker.OpRole.Forward):
appended_grad_times += 1
# complete the annotation of grad op (xxx_grad op or sum op)
# xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id
grad_op = ops[idx]
if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
# TODO: support the case where one forward op corresponds to multiple xxx_grad ops
forward_op = _get_op_by_id(
ops, dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()])
assert forward_op is not None
fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
forward_op)
fwd_op_process_mesh = fwd_op_dist_attr.process_mesh
grad_op_dist_attr = OperatorDistributedAttribute()
grad_op_dist_attr.process_mesh = fwd_op_process_mesh
for input_name in grad_op.input_arg_names:
if input_name not in forward_op.input_arg_names and input_name not in forward_op.output_arg_names:
if input_name in grad_var_to_var[appended_grad_times]:
fwd_name = grad_var_to_var[appended_grad_times][
input_name]
ref_dims_mapping = fwd_op_dist_attr.get_output_dims_mapping(
fwd_name)
else:
input_var = vars[input_name]
ref_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program(
input_var).dims_mapping
else:
if fwd_op_dist_attr.get_input_dims_mapping(input_name):
ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping(
input_name)
else:
ref_dims_mapping = fwd_op_dist_attr.get_output_dims_mapping(
input_name)
assert ref_dims_mapping is not None, "[{}]'s dims mapping is None".format(
input_name)
grad_op_dist_attr.set_input_dims_mapping(input_name,
ref_dims_mapping)
for output_name in grad_op.output_arg_names:
assert output_name in grad_var_to_var[appended_grad_times]
fwd_name = grad_var_to_var[appended_grad_times][output_name]
ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping(
fwd_name)
# var
output_var = vars[output_name]
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.dims_mapping = ref_dims_mapping
tensor_dist_attr.process_mesh = fwd_op_process_mesh
self._dist_context.set_tensor_dist_attr_for_program(
output_var, tensor_dist_attr)
# op
grad_op_dist_attr.set_output_dims_mapping(output_name,
ref_dims_mapping)
self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr)
# grad ops that have no corresponding mapping in grad_op_id_to_op_id
else:
if grad_op.type == 'sum':
assert all(map(_is_grad_var_name, grad_op.input_arg_names))
output_name = grad_op.output_arg_names[0]
assert output_name in grad_var_to_var[appended_grad_times], \
"sum op's output '{}' has no corresponding var".format(
output_name)
ref_fwd_var_name = grad_var_to_var[appended_grad_times][
output_name]
ref_fwd_var = vars[ref_fwd_var_name]
ref_fwd_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
ref_fwd_var)
ref_fwd_dims_mapping = ref_fwd_dist_attr.dims_mapping
ref_fwd_process_mesh = ref_fwd_dist_attr.process_mesh
# output
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.dims_mapping = ref_fwd_dims_mapping
tensor_dist_attr.process_mesh = ref_fwd_process_mesh
output_var = vars[output_name]
self._dist_context.set_tensor_dist_attr_for_program(
output_var, tensor_dist_attr)
# op
grad_op_dist_attr = OperatorDistributedAttribute()
grad_op_dist_attr.process_mesh = ref_fwd_process_mesh
for var_name in grad_op.input_arg_names:
grad_op_dist_attr.set_input_dims_mapping(
var_name, ref_fwd_dims_mapping)
grad_op_dist_attr.set_output_dims_mapping(
output_name, ref_fwd_dims_mapping)
elif grad_op.type == 'fill_zeros_like':
ref_var_name = grad_op.input_arg_names[0]
ref_var = vars[ref_var_name]
ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
ref_var)
ref_dims_mapping = ref_dist_attr.dims_mapping
ref_process_mesh = ref_dist_attr.process_mesh
# output
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.dims_mapping = ref_dims_mapping
tensor_dist_attr.process_mesh = ref_process_mesh
output_var_name = grad_op.output_arg_names[0]
output_var = vars[output_var_name]
self._dist_context.set_tensor_dist_attr_for_program(
output_var, tensor_dist_attr)
# op
grad_op_dist_attr = OperatorDistributedAttribute()
grad_op_dist_attr.process_mesh = ref_process_mesh
grad_op_dist_attr.set_input_dims_mapping(ref_var_name,
ref_dims_mapping)
grad_op_dist_attr.set_output_dims_mapping(output_var_name,
ref_dims_mapping)
elif grad_op.type in ['shape', 'fill_constant']:
continue
else:
raise ValueError("got unexpect op [{}]".format(
str(grad_op.type)))
self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr)
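For reference, `dist_op_context.grad_var_to_var` is a `defaultdict(dict)` keyed by how many times backward has been appended (`appended_grad_times`); each inner dict maps a gradient variable name to the forward variable whose dist attr it should inherit. A minimal illustration with hypothetical variable names:

```python
# Illustration only; variable names are hypothetical.
grad_var_to_var = {
    1: {  # first backward append (first-order grads)
        "linear_0.w_0@GRAD": "linear_0.w_0",
        "tanh_0.tmp_0@GRAD": "tanh_0.tmp_0",
    },
    2: {  # second backward append (high-order grads)
        "tanh_0.tmp_0@GRAD@RENAME@block0@0": "tanh_0.tmp_0",
    },
}
# The completer copies the dims mapping of the forward var onto the grad var:
fwd_name = grad_var_to_var[1]["linear_0.w_0@GRAD"]
```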
def complete_backward_annotation(self, serial_main_program):
"""Complete the annotation of vars and ops in the backward phase for parallel program."""
......@@ -689,6 +850,8 @@ class Completer:
ops = list(serial_main_program.global_block().ops)
vars = serial_main_program.global_block().vars
dist_op_context = self._dist_context.dist_op_context
# use the mapping recorded by the latest backward append (keys start at 1)
grad_var_to_var = dist_op_context.grad_var_to_var[len(
dist_op_context.grad_var_to_var)]
for idx in range(first_backward_op_idx, len(ops)):
......@@ -765,102 +928,111 @@ class Completer:
grad_op, grad_op_dist_attr)
continue
# op dist attr
forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
forward_op)
forward_op_process_mesh = forward_op_dist_attr.process_mesh
fwd_op_process_mesh = fwd_op_dist_attr.process_mesh
grad_op_dist_attr = OperatorDistributedAttribute()
grad_op_dist_attr.process_mesh = forward_op_process_mesh
grad_op_dist_attr.process_mesh = fwd_op_process_mesh
# var
for input_name in grad_op.input_arg_names:
input_var = vars[input_name]
ref_dims_mapping = None
if "@GRAD" in input_name:
forward_name = _get_forward_varname_from_grad_varname(
input_name)
ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping(
forward_name)
if input_name not in forward_op.input_arg_names and input_name not in forward_op.output_arg_names:
if input_name in grad_var_to_var:
fwd_name = grad_var_to_var[input_name]
ref_dims_mapping = fwd_op_dist_attr.get_output_dims_mapping(
fwd_name)
else:
input_var = vars[input_name]
ref_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program(
input_var).dims_mapping
else:
if forward_op_dist_attr.get_input_dims_mapping(
input_name):
ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping(
if fwd_op_dist_attr.get_input_dims_mapping(input_name):
ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping(
input_name)
else:
ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping(
ref_dims_mapping = fwd_op_dist_attr.get_output_dims_mapping(
input_name)
assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format(
input_var.name)
input_name)
grad_op_dist_attr.set_input_dims_mapping(input_name,
ref_dims_mapping)
for output_name in grad_op.desc.output_names():
assert len(grad_op.desc.output(output_name)) in [0, 1]
if _is_grad_var_name(output_name):
input_name = _get_forward_varname_from_grad_varname(
output_name)
else:
assert grad_op.type in [
"cast", "c_identity", "c_allreduce_sum"
]
input_name = "X"
assert input_name in forward_op.desc.input_names(
), "var [{}] in op [{}]'s output but could not find [{}] in its forward op".format(
output_name, grad_op.type, input_name)
if len(grad_op.desc.output(output_name)) == 1:
# tensor dist attr
output_var = vars[grad_op.desc.output(output_name)[0]]
forward_name = _get_forward_varname_from_grad_varname(
output_var.name)
ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping(
forward_name)
output_var_dist_attr = TensorDistributedAttribute()
output_var_dist_attr.dims_mapping = ref_dims_mapping
output_var_dist_attr.process_mesh = forward_op_process_mesh
self._dist_context.set_tensor_dist_attr_for_program(
output_var, output_var_dist_attr)
grad_op_dist_attr.set_output_dims_mapping(
output_var.name, ref_dims_mapping)
for output_name in grad_op.output_arg_names:
assert output_name in grad_var_to_var
fwd_name = grad_var_to_var[output_name]
ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping(
fwd_name)
# var
output_var = vars[output_name]
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.dims_mapping = ref_dims_mapping
tensor_dist_attr.process_mesh = fwd_op_process_mesh
self._dist_context.set_tensor_dist_attr_for_program(
output_var, tensor_dist_attr)
# op
grad_op_dist_attr.set_output_dims_mapping(output_name,
ref_dims_mapping)
self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr)
# only the sum op that merges multiple versions of a grad var has no corresponding mapping in grad_op_id_to_op_id
# grad ops that have no corresponding mapping in grad_op_id_to_op_id
else:
assert grad_op.type == "sum", "got unexpect op [{}]".format(
str(grad_op.type))
assert all(map(_is_grad_var_name, grad_op.input_arg_names))
assert len(grad_op.output_arg_names) == 1
ref_forward_var_name = _get_forward_varname_from_grad_varname(
grad_op.output_arg_names[0])
forward_var = vars[ref_forward_var_name]
ref_forward_var_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program(
forward_var).dims_mapping
ref_forward_var_process_mesh = self._dist_context.get_tensor_dist_attr_for_program(
forward_var).process_mesh
if grad_op.type == 'sum':
assert all(map(_is_grad_var_name, grad_op.input_arg_names))
output_name = grad_op.output_arg_names[0]
assert output_name in grad_var_to_var, "sum op's output '{}' has no corresponding var".format(
output_name)
ref_fwd_var_name = grad_var_to_var[output_name]
ref_fwd_var = vars[ref_fwd_var_name]
ref_fwd_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
ref_fwd_var)
ref_fwd_dims_mapping = ref_fwd_dist_attr.dims_mapping
ref_fwd_process_mesh = ref_fwd_dist_attr.process_mesh
# output
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.dims_mapping = ref_fwd_dims_mapping
tensor_dist_attr.process_mesh = ref_fwd_process_mesh
output_var = vars[output_name]
self._dist_context.set_tensor_dist_attr_for_program(
output_var, tensor_dist_attr)
# output
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.dims_mapping = ref_forward_var_dims_mapping
tensor_dist_attr.process_mesh = ref_forward_var_process_mesh
self._dist_context.set_tensor_dist_attr_for_program(
vars[grad_op.output_arg_names[0]], tensor_dist_attr)
# op
grad_op_dist_attr = OperatorDistributedAttribute()
grad_op_dist_attr.process_mesh = ref_fwd_process_mesh
for var_name in grad_op.input_arg_names:
grad_op_dist_attr.set_input_dims_mapping(
var_name, ref_fwd_dims_mapping)
grad_op_dist_attr.set_output_dims_mapping(
output_name, ref_fwd_dims_mapping)
elif grad_op.type == 'fill_zeros_like':
ref_var_name = grad_op.input_arg_names[0]
ref_var = vars[ref_var_name]
ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
ref_var)
ref_dims_mapping = ref_dist_attr.dims_mapping
ref_process_mesh = ref_dist_attr.process_mesh
# output
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.dims_mapping = ref_dims_mapping
tensor_dist_attr.process_mesh = ref_process_mesh
output_var_name = grad_op.output_arg_names[0]
output_var = vars[output_var_name]
self._dist_context.set_tensor_dist_attr_for_program(
output_var, tensor_dist_attr)
# op
grad_op_dist_attr = OperatorDistributedAttribute()
grad_op_dist_attr.process_mesh = ref_process_mesh
grad_op_dist_attr.set_input_dims_mapping(ref_var_name,
ref_dims_mapping)
grad_op_dist_attr.set_output_dims_mapping(output_var_name,
ref_dims_mapping)
else:
raise ValueError("got unexpect op [{}]".format(
str(grad_op.type)))
# op
grad_op_dist_attr = OperatorDistributedAttribute()
grad_op_dist_attr.process_mesh = ref_forward_var_process_mesh
for var_name in grad_op.input_arg_names:
assert _get_forward_varname_from_grad_varname(
var_name) == ref_forward_var_name
grad_op_dist_attr.set_input_dims_mapping(
var_name, ref_forward_var_dims_mapping)
grad_op_dist_attr.set_output_dims_mapping(
grad_op.output_arg_names[0], ref_forward_var_dims_mapping)
self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr)
......
......@@ -120,6 +120,11 @@ class DistributedContext:
def dist_startup_programs(self):
return self._dist_startup_programs
@property
def is_annotation(self):
return len(self._dist_tensors_for_program) or len(
self._dist_ops_for_program)
def add_process_mesh(self, process_mesh):
assert isinstance(process_mesh, ProcessMesh), \
'The type of dim_mapping must be ProcessMesh.'
......@@ -577,6 +582,7 @@ class DistributedOperatorContext:
self._cur_src_op = None
self._cur_dist_attr = None
self.grad_op_id_to_op_id = {}
self.grad_var_to_var = defaultdict(dict)
self._work_block = None
self.already_init_sync_vars = set()
self.varname_mapping = None
......
......@@ -16,6 +16,7 @@ import abc
import numpy as np
import paddle
from .utils import to_list
from paddle.fluid.layers.utils import flatten
from paddle.io import DataLoader, DistributedBatchSampler
......@@ -56,16 +57,17 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
data_parallel_world_size=None,
data_parallel_rank=None,
drop_last=False,
inputs=[]):
sample_generator=True):
self.feed_list = feed_list
self.places = places
self.steps_per_epoch = steps_per_epoch
self._sample_generator = sample_generator
super(NonIterableGeneratorLoader, self).__init__(
dataset, batch_size, epochs, data_parallel_world_size,
data_parallel_rank, drop_last)
self._inner_dataloader = self._create_inner_dataloader()
self._steps = self._infer_steps()
self._inputs = inputs
def __iter__(self):
self._cur_step = 0
......@@ -91,27 +93,28 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
return steps_per_epoch
def _create_inner_dataloader(self):
def data_generator():
def sample_data_generator():
batch_data = None
for step, data in enumerate(self.dataset):
if not isinstance(data, list):
data = to_list(data)
if self.batch_size == 1:
yield data
data = flatten(data)
if batch_data is None:
batch_data = [[] for i in range(len(data))]
for idx in range(len(data)):
batch_data[idx].append(data[idx])
if (step + 1) % self.batch_size == 0:
yield batch_data
batch_data = None
else:
if batch_data is None:
batch_data = [[] for i in range(len(data))]
for idx in range(len(data)):
batch_data[idx].append(data[idx])
if (step + 1) % self.batch_size == 0:
yield batch_data
batch_data = None
def batch_data_generator():
for data in self.dataset:
data = flatten(data)
yield data
dataloader = paddle.fluid.io.DataLoader.from_generator(
feed_list=self.feed_list, capacity=70, iterable=False)
dataloader.set_batch_generator(data_generator, self.places)
if self._sample_generator:
dataloader.set_batch_generator(sample_data_generator, self.places)
else:
dataloader.set_batch_generator(batch_data_generator, self.places)
return dataloader
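The two generator modes above differ in granularity: `sample_data_generator` packs `batch_size` flattened samples into one batch, while `batch_data_generator` assumes each dataset item is already a full batch (as the Laplace dataset below does). A rough standalone sketch of the same logic, assuming each sample is a tuple of numpy arrays:

```python
import numpy as np

def sample_batches(dataset, batch_size):
    # Mirrors sample_data_generator: gather batch_size samples per field.
    batch = None
    for step, sample in enumerate(dataset):
        fields = list(sample) if isinstance(sample, (list, tuple)) else [sample]
        if batch is None:
            batch = [[] for _ in fields]
        for i, field in enumerate(fields):
            batch[i].append(field)
        if (step + 1) % batch_size == 0:
            yield batch
            batch = None

def batch_batches(dataset):
    # Mirrors batch_data_generator: each dataset item is already one batch.
    for sample in dataset:
        yield list(sample) if isinstance(sample, (list, tuple)) else [sample]

data = [(np.ones(4, "float32"), np.zeros(1, "float32")) for _ in range(6)]
for batch in sample_batches(data, batch_size=2):
    pass  # each batch is [[x0, x1], [y0, y1]]
```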
......@@ -17,18 +17,22 @@ import logging
from collections import defaultdict
import paddle
import paddle.distributed.auto_parallel as auto
from paddle import fluid
from paddle.io import Dataset
from paddle.metric import Metric
from paddle.static import InputSpec
from paddle.fluid import core
from paddle.fluid import program_guard
from paddle.fluid.layers.utils import flatten
from paddle.fluid.executor import global_scope
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Operator
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed.passes import new_pass, PassContext
from paddle.distributed.utils import get_logger
from paddle.distributed.passes import new_pass, PassContext
from .mapper import mapping
from .cluster import Cluster
......@@ -61,6 +65,12 @@ class Engine:
self.strategy = strategy
self._executor = None
self._cur_rank = paddle.distributed.get_rank()
self._nranks = paddle.distributed.get_world_size()
self._saver = DistributedSaver()
self._logger = get_logger(logging.INFO)
self._default_strategy = None
self._orig_main_prog = fluid.default_main_program()
self._orig_startup_prog = fluid.default_startup_program()
self._orig_dist_context = get_default_distributed_context()
......@@ -70,9 +80,6 @@ class Engine:
self._dist_startup_progs = defaultdict(dict) # dist startup programs
self._dist_contexts = {}
self._pass_contexts = {}
self._cur_rank = paddle.distributed.get_rank()
self._logger = get_logger(logging.INFO)
self._saver = DistributedSaver()
self._feed_vars = {}
self._fetch_vars = {}
......@@ -86,13 +93,11 @@ class Engine:
# TODO: check loss type
self._loss = loss
self._metrics = to_list(metrics)
for m in ['train', 'predict']:
self.mode = m
self._build(m) # build forward program
self._plan(m) # completion & planner
self._parallel(m, all_ranks) # parallel
self._initialize(m) # init comm and startup program
self.mode = mode
self._mode = mode
self._build(mode) # build forward program
self._plan(mode) # completion & planner
self._parallel(mode, all_ranks) # parallel
self._initialize(mode) # init comm and startup program
def _build(self, mode):
serial_main_prog = self._serial_main_progs.get(mode, None)
......@@ -112,10 +117,16 @@ class Engine:
if mode != "predict" and self._loss:
losses = to_list(self._loss(*(outputs + labels)))
default_ctx = get_default_distributed_context()
if not default_ctx.is_annotation or self._default_strategy:
inputs = [self._set_data_parallel(var) for var in inputs]
labels = [self._set_data_parallel(var) for var in labels]
self._feed_vars[mode] = {"inputs": inputs, "labels": labels}
self._fetch_vars[mode] = {
"outputs": outputs,
"outputs": flatten(outputs),
"loss": losses,
"metrics": metrics
}
......@@ -128,6 +139,12 @@ class Engine:
self._pass_contexts[mode] = PassContext()
def _plan(self, mode):
# NOTE: [HighOrderGrad] There are grad ops in the forward phase, and forward
# completion needs the dependency between backward and forward ops.
default_ctx = get_default_distributed_context()
self._dist_contexts[mode]._dist_op_context = default_ctx.dist_op_context
# Complete the distributed annotation
serial_main_prog = self._serial_main_progs[mode]
self._completer = Completer(self._dist_contexts[mode])
......@@ -147,13 +164,14 @@ class Engine:
self._parallel_program(mode, rank)
def _initialize(self, mode):
# Traverse different rank programs and traverse each op of them,
# instantiate communication by process_mapping.
all_process_groups = get_all_process_groups()
for process_group in all_process_groups:
if self._cur_rank not in process_group.ranks:
continue
process_group.instantiate()
if self._nranks > 1:
# Traverse different rank programs and traverse each op of them,
# instantiate communication by process_mapping.
all_process_groups = get_all_process_groups()
for process_group in all_process_groups:
if self._cur_rank not in process_group.ranks:
continue
process_group.instantiate()
# initialize
self._place = _get_device()
......@@ -161,8 +179,16 @@ class Engine:
self._place = fluid.CUDAPlace(ParallelEnv().dev_id)
if self._executor is None:
self._executor = paddle.static.Executor(self._place)
dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
self._executor.run(dist_startup_prog)
uninitialized = []
dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
for var in dist_startup_prog.list_vars():
scope_var = global_scope().find_var(var.name)
if scope_var and scope_var.get_tensor()._is_initialized():
continue
uninitialized.append(var)
if uninitialized:
prune_startup_prog = dist_startup_prog._prune(uninitialized)
self._executor.run(prune_startup_prog)
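One consequence of the pruning above: calling `prepare` for a second mode reuses parameters that are already initialized in the global scope instead of re-running their initializers. A hedged usage sketch, assuming an `engine` and datasets set up as in the tests below:

```python
# Hypothetical usage: the second prepare() prunes already-initialized vars
# from the startup program, so 'eval' shares the trained parameters.
engine.prepare(optimizer, loss)               # mode='train': full init
engine.fit(train_dataset, batch_size=1)
engine.prepare(optimizer, loss, mode='eval')  # init only the new vars
engine.evaluate(eval_dataset, batch_size=1)
```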
def _parallel_program(self, mode, rank):
serial_main_program = self._serial_main_progs[mode]
......@@ -246,12 +272,13 @@ class Engine:
if config["use_pure_fp16"]:
config["base_opt"] = self._optimizer
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply(
[main_program], [startup_program], self._pass_context)
auto_parallel_fp16_pass.apply([main_program],
[startup_program],
self._pass_contexts[self.mode])
else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply([main_program], [startup_program],
self._pass_context)
self._pass_contexts[self.mode])
# apply recompute pass
if self.strategy.recompute:
......@@ -288,18 +315,26 @@ class Engine:
[main_program], [startup_program],
self._pass_contexts[self.mode])
def fit(self, train_data, batch_size=1, epochs=1, steps_per_epoch=None):
def fit(self,
train_data,
batch_size=1,
epochs=1,
steps_per_epoch=None,
use_program_cache=False,
return_numpy=True,
sample_generator=True):
# TODO: callbacks
# TODO: evaluate after training
self.mode = 'train'
assert isinstance(train_data, Dataset)
train_dataloader = self._create_dataloader(train_data, batch_size,
epochs, steps_per_epoch)
assert self.mode in self._dist_main_progs, "train model is not ready, please call `engine.prepare(mode='train')` first."
train_dataloader = self._create_dataloader(
train_data, batch_size, epochs, steps_per_epoch, sample_generator)
outputs = []
for epoch in range(epochs):
for step, data in enumerate(train_dataloader):
logs, loss = self._train_step(data)
logs, loss = self._train_step(data, use_program_cache,
return_numpy)
outputs.append(loss)
train_logs = {
"train_" + name: val
......@@ -308,14 +343,35 @@ class Engine:
self._logger.info(train_logs)
return outputs
def evaluate(self,
eval_data,
batch_size=1,
use_program_cache=False,
return_numpy=True,
sample_generator=True):
self.mode = 'eval'
assert self.mode in self._dist_main_progs, "eval model is not ready, please call `engine.prepare(mode='eval')` first."
eval_dataloader = self._create_dataloader(
eval_data, batch_size, sample_generator=sample_generator)
outputs = []
for step, data in enumerate(eval_dataloader):
logs, outs = self._eval_step(data, use_program_cache, return_numpy)
outputs.append(outs)
predict_logs = {"eval_" + name: val for name, val in logs.items()}
self._logger.info(predict_logs)
return outputs
def predict(self,
test_data,
batch_size=1,
use_program_cache=False,
return_numpy=True):
return_numpy=True,
sample_generator=True):
self.mode = 'predict'
# TODO: need to check the dataset
test_dataloader = self._create_dataloader(test_data, batch_size)
assert self.mode in self._dist_main_progs, "predict model is not ready, please call `engine.prepare(mode='predict')` first."
test_dataloader = self._create_dataloader(
test_data, batch_size, sample_generator=sample_generator)
outputs = []
for step, data in enumerate(test_dataloader):
......@@ -329,19 +385,39 @@ class Engine:
self._logger.info(predict_logs)
return outputs
def _train_step(self, data):
def _train_step(self, data, use_program_cache=False, return_numpy=True):
logs = {}
dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
fetch_var = self._fetch_vars[self.mode]["loss"][0]
if fetch_var.name not in dist_main_prog.global_block().vars:
loss = self._executor.run(dist_main_prog)
loss = self._executor.run(dist_main_prog,
use_program_cache=use_program_cache)
logs["loss"] = None
else:
loss = self._executor.run(dist_main_prog,
fetch_list=to_list(fetch_var))
fetch_list=to_list(fetch_var),
use_program_cache=use_program_cache,
return_numpy=return_numpy)
logs["loss"] = loss
return logs, loss
def _eval_step(self, data, use_program_cache=False, return_numpy=True):
logs = {}
dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
fetch_var = self._fetch_vars[self.mode]["loss"][0]
if fetch_var.name not in dist_main_prog.global_block().vars:
outs = self._executor.run(dist_main_prog,
use_program_cache=use_program_cache)
logs["loss"] = outs
else:
outs = self._executor.run(dist_main_prog,
fetch_list=fetch_var,
use_program_cache=use_program_cache,
return_numpy=return_numpy)
logs["loss"] = outs
return logs, outs
def _predict_step(self, data, use_program_cache=False, return_numpy=True):
logs = {}
dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
......@@ -366,7 +442,8 @@ class Engine:
dataset,
batch_size,
epochs=1,
steps_per_epoch=None):
steps_per_epoch=None,
sample_generator=True):
feed_list = self._feed_vars[self.mode]["inputs"] + self._feed_vars[
self.mode]["labels"]
dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
......@@ -376,9 +453,12 @@ class Engine:
serial_main_prog = self._serial_main_progs[self.mode]
serial_main_block = serial_main_prog.global_block()
op_size = len(dist_main_block.ops)
if dist_main_block.ops[0].type == 'create_py_reader':
op_size -= 3
for _ in range(3):
dist_main_block._remove_op(0, sync=False)
places = paddle.static.cuda_places()
with fluid.program_guard(dist_main_prog, dist_startup_prog):
inputs = self._feed_vars[self.mode]["inputs"]
dataloader = NonIterableGeneratorLoader(
dataset,
feed_list,
......@@ -386,7 +466,7 @@ class Engine:
batch_size,
epochs,
steps_per_epoch,
inputs=inputs)
sample_generator=sample_generator)
new_op_size = len(dist_main_block.ops)
for _ in range(new_op_size - 1, op_size - 1, -1):
op = dist_main_block.ops[new_op_size - 1]
......@@ -396,7 +476,7 @@ class Engine:
dist_main_block, new_op_desc, type=new_op_desc.type())
dist_main_block.ops.insert(0, new_op)
for in_name in new_op.input_arg_names:
if in_name == "lod_tensor_blocking_queue_0":
if "lod_tensor_blocking_queue" in in_name:
continue
if in_name not in dist_main_block.vars:
in_var = serial_main_block._var_recursive(in_name)
......@@ -424,6 +504,27 @@ class Engine:
.format(i, spec))
return specs
def _set_data_parallel(self, var):
if self._nranks == 1:
self._default_strategy = 'serial'
auto.shard_tensor(
var,
dist_attr={
"process_mesh": [0],
"dims_mapping": [-1 for _ in range(len(var.shape))]
})
else:
self._default_strategy = 'dp'
auto.shard_tensor(
var,
dist_attr={
"process_mesh": list(range(self._nranks)),
"dims_mapping":
[0] + [-1 for _ in range(len(var.shape) - 1)]
})
return var
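In a `dims_mapping`, `-1` means the tensor axis is replicated and a value `k` shards that axis along dim `k` of the process mesh, so the default 'dp' strategy above shards only the batch axis. A minimal sketch of the same annotation, assuming 4 ranks and a 2-D variable `var`:

```python
import paddle.distributed.auto_parallel as auto

# Sketch of the default 'dp' annotation, assuming nranks == 4 and a
# variable `var` of shape [batch, hidden]; -1 means "replicated".
auto.shard_tensor(
    var,
    dist_attr={
        "process_mesh": [0, 1, 2, 3],  # 1-D mesh over 4 ranks
        "dims_mapping": [0, -1],       # shard the batch axis over mesh dim 0
    })
```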
def save(self, path, training=True, mode=None):
if not mode:
mode = self.mode
......@@ -459,3 +560,35 @@ class Engine:
dist_context = self._dist_contexts[mode]
self._saver.load(path, dist_main_prog, dist_context, strict,
load_optimizer)
@property
def mode(self):
return self._mode
@mode.setter
def mode(self, mode):
self._mode = mode
@property
def metrics(self):
return self._metrics
@property
def main_program(self):
return self._dist_main_progs[self.mode][self._cur_rank]
@property
def startup_program(self):
return self._dist_startup_progs[self.mode][self._cur_rank]
@property
def dist_context(self):
return self._dist_contexts[self.mode]
@property
def serial_main_program(self):
return self._serial_main_progs[self.mode]
@property
def serial_startup_program(self):
return self._serial_startup_progs[self.mode]
......@@ -53,6 +53,10 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
input_names = op_desc.input_names()
xshape_arg_names = []
if "XShape" in input_names:
xshape_arg_names = op_desc.input("XShape")
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
......@@ -63,10 +67,18 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
# continue
# if len(dims_mapping) < 1:
# continue
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
else:
if dims_mapping[0] != -1:
return False
if len(dims_mapping) > 2:
for mapping in dims_mapping[2:]:
if mapping != -1:
return False
return True
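The `XShape` branch above exists because ops such as `reshape2` and `squeeze2` carry an auxiliary `XShape` tensor whose shape is the input shape prefixed with a dummy 0 dim, so its dims mapping is shifted by one position. For example, with hypothetical shapes:

```python
# X       shape [B, H]     dims_mapping [0, -1]      # batch axis sharded
# XShape  shape [0, B, H]  dims_mapping [-1, 0, -1]  # dummy dim never sharded
# hence the XShape checks start at dims_mapping[1] rather than [0].
```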
def is_output_compatible(self, dist_op):
......@@ -105,17 +117,31 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
op_dist_attr = dist_op.dist_attr
batch_dim_mappings = []
# Check input compatibility
input_names = op_desc.input_names()
xshape_arg_names = []
if "XShape" in input_names:
xshape_arg_names = op_desc.input("XShape")
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
else:
if dims_mapping[0] != -1:
return False
if len(dims_mapping) > 2:
for mapping in dims_mapping[2:]:
if mapping != -1:
return False
if len(dims_mapping) >= 2:
batch_dim_mappings.append(dims_mapping[1])
# Check output compatibility
output_names = op_desc.output_names()
......@@ -160,24 +186,39 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
or op_desc.type() == "slice" \
or op_desc.type() == "while":
return False
input_names = op_desc.input_names()
input_xshape_arg_names = []
if "XShape" in input_names:
input_xshape_arg_names = op_desc.input("XShape")
output_names = op_desc.output_names()
xshape_arg_names = []
output_xshape_arg_names = []
if "XShape" in output_names:
xshape_arg_names = op_desc.output("XShape")
output_xshape_arg_names = op_desc.output("XShape")
batch_dim_mappings = []
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
if arg_name not in input_xshape_arg_names:
if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
else:
batch_dim_mappings.append(dims_mapping[1])
for arg_name in op_desc.output_arg_names():
if op_desc.type() == "fill_zeros_like":
input_tensor = dist_op.get_serial_input(op_desc.input_arg_names(
)[0])
if input_tensor.is_parameter:
continue
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if arg_name not in output_xshape_arg_names:
if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
else:
......@@ -194,16 +235,27 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping
) >= 1 and compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
if arg_name not in input_xshape_arg_names:
if len(dims_mapping) >= 1 and \
compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
else:
if len(dims_mapping) >= 2 and \
compatible_dim_mapping != dims_mapping[1]:
dims_mapping[1] = compatible_dim_mapping
changed = True
for arg_name in op_desc.output_arg_names():
if op_desc.type() == "fill_zeros_like":
input_tensor = dist_op.get_serial_input(op_desc.input_arg_names(
)[0])
if input_tensor.is_parameter:
continue
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if arg_name not in output_xshape_arg_names:
if len(dims_mapping
) >= 1 and compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
......@@ -371,30 +423,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
if need_gradient_allreduce:
allreduce_vars = []
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and is_parameter_related(
varname, main_block):
# NOTE: When the amp and recompute passes are effective at the same time,
# if a parameter is cast and recomputed, the 'parameter@GRAD' can not
# be found in the grad_op's output.
if "subprog_" in varname:
varname = varname[:varname.index(".subprog_")]
assert len(
backward_op.desc.input(input_name)
) == 1, "parameter input to grad op should be length 1, but got [{}]".format(
backward_op.desc.input(input_name))
assert varname + "@GRAD" in backward_op.desc.output_arg_names(
), "parameter's grad [{}] not found in the grad op's output".format(
varname + "@GRAD")
assert len(
backward_op.desc.output(input_name + "@GRAD")
) == 1, "parameter grad of grad op should be length 1, but got [{}]".format(
backward_op.desc.output(input_name + "@GRAD"))
allreduce_vars.append(
backward_op.desc.output(input_name + "@GRAD")[0])
for output_name in backward_op.desc.output_names():
for varname in backward_op.desc.output(output_name):
if varname in kwargs["grad_var_to_var"]:
fwd_name = kwargs["grad_var_to_var"][varname]
if fwd_name not in main_block.vars:
continue
if is_parameter_related(fwd_name, main_block):
allreduce_vars.append(varname)
if len(allreduce_vars) > 0:
......
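The rewritten selection above walks the grad op's outputs through `grad_var_to_var` and all-reduces only gradients whose forward counterpart is a parameter in the main block. A small illustration with hypothetical names:

```python
# Hypothetical mapping passed in through kwargs["grad_var_to_var"]:
grad_var_to_var = {
    "linear_0.w_0@GRAD": "linear_0.w_0",  # forward var is a parameter
    "tmp_3@GRAD": "tmp_3",                # forward var is an activation
}
# -> allreduce_vars == ["linear_0.w_0@GRAD"]; "tmp_3@GRAD" is skipped.
```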
......@@ -25,7 +25,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext, Di
from .dist_attribute import OperatorDistributedAttribute
from .process_group import new_process_group
from .utils import set_dist_op_desc_original_id
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_loss_op
from .operators.common import BACKWARD_ONLY_DIST_OPS
__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]
......@@ -198,15 +198,29 @@ class Partitioner(object):
dist_op_context = self._dist_context.dist_op_context
serial_ops = ref_block.ops
last_fwd_op_idx = -1
for idx, op in enumerate(ref_block.ops):
if is_loss_op(op):
last_fwd_op_idx = idx
break
if last_fwd_op_idx == -1:
last_fwd_op_idx = len(ref_block.ops)
# init mapping
forward_op_id2forward_op = {}
for idx in range(len(serial_ops)):
if is_forward_op(serial_ops[idx]):
if idx <= last_fwd_op_idx:
forward_op_id2forward_op[serial_ops[idx].desc.id(
)] = serial_ops[idx]
appended_grad_times = 0
# partition
for op in serial_ops:
for idx, op in enumerate(serial_ops):
if is_backward_op(op) and (is_forward_op(serial_ops[idx - 1]) or
is_loss_op(serial_ops[idx - 1])):
appended_grad_times += 1
# partition input variables
for serial_input_varname in op.desc.input_arg_names():
......@@ -244,8 +258,11 @@ class Partitioner(object):
kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_backward_impl = _get_dist_op_backward_implement(
op, self._dist_context, forward_op_id2forward_op)
dist_op_backward_impl.backward(self._dist_context, **kinputs,
**koutputs)
grad_var_to_var = self._dist_context.dist_op_context.grad_var_to_var[
appended_grad_times]
dist_op_backward_impl.backward(
self._dist_context, **kinputs, **koutputs,
**{"grad_var_to_var": grad_var_to_var})
else:
raise NotImplementedError(
"partitioner only support forward op and backward op, but got {}".
......
......@@ -996,69 +996,87 @@ def set_grad_var_shape(program, dist_context):
block = program.global_block()
vars = block.vars
for op in block.ops:
appended_grad_times = 0
grad_var_to_var = dist_context.dist_op_context.grad_var_to_var
for idx, op in enumerate(block.ops):
if int(op.attr('op_role')) != int(OpRole.Backward):
continue
if int(block.ops[idx-1].attr('op_role')) == int(OpRole.Forward) or \
int(block.ops[idx-1].attr('op_role')) == 257:  # 257 == OpRole.Backward | OpRole.Loss
appended_grad_times += 1
if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
break
if op.type in ["sum", "concat"]:
if op.type in ["sum", "concat", "shape"]:
continue
if int(op.attr('op_role')) == int(OpRole.Backward):
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr is not None
for var_name in op.output_arg_names:
if "@GRAD" not in var_name:
continue
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr is not None
for var_name in op.output_arg_names:
if "@GRAD" not in var_name:
continue
if var_name in grad_var_to_var[appended_grad_times]:
forward_var_name = grad_var_to_var[appended_grad_times][
var_name]
else:
forward_var_name = var_name[:var_name.find("@GRAD")]
if op.type in [
"c_allreduce_sum", "c_identity", "scale", "cast"
]:
forward_var_name = op.input_arg_names[0]
elif op.type == "matmul_v2_grad":
forward_var_name = None
for output_name in op.output_names:
if var_name in op.output(output_name):
assert "@GRAD" in output_name
input_name = output_name[:output_name.find("@GRAD")]
assert len(op.input(input_name)) == 1
forward_var_name = op.input(input_name)[0]
assert forward_var_name is not None
need_set_shape_list = [
"reshape2_grad", "softmax_with_cross_entropy_grad",
"transpose2_grad", "softmax_grad", "cross_entropy_grad2",
"dropout_grad"
]
forward_list = [
"reshape2", "softmax_with_cross_entropy", "transpose2",
"softmax", "cross_entropy2", "dropout"
]
if op.type in need_set_shape_list:
for forward_op in block.ops:
assert int(forward_op.attr('op_role')) != int(
OpRole.Backward)
idx = need_set_shape_list.index(op.type)
forward_op_name = forward_list[idx]
if forward_op.type == forward_op_name and forward_var_name in forward_op.input_arg_names:
op_dist_attr = dist_context.get_op_dist_attr_for_program(
forward_op)
break
forward_input_dist_attr = op_dist_attr.get_input_dist_attr(
forward_var_name)
assert forward_input_dist_attr is not None, f"{forward_var_name, str(op)}"
forward_var = vars[forward_var_name]
forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
forward_var)
assert forward_var_dist_attr is not None
grad_var = vars[var_name]
ref_shape = infer_shape(block, forward_var,
forward_var_dist_attr,
forward_input_dist_attr)
if list(grad_var.shape) != ref_shape:
grad_var.desc.set_shape(ref_shape)
if op.type in [
"c_allreduce_sum", "c_identity", "scale", "cast",
"fill_zeros_like"
]:
forward_var_name = op.input_arg_names[0]
elif op.type == "matmul_v2_grad":
forward_var_name = None
for output_name in op.output_names:
if var_name in op.output(output_name):
assert "@GRAD" in output_name
input_name = output_name[:output_name.find("@GRAD")]
assert len(op.input(input_name)) == 1
forward_var_name = op.input(input_name)[0]
assert forward_var_name is not None
need_set_shape_list = [
"reshape2_grad", "softmax_with_cross_entropy_grad",
"transpose2_grad", "softmax_grad", "cross_entropy_grad2",
"dropout_grad", "tanh_grad", "slice", "assign",
"matmul_v2_triple_grad", "elementwise_add_triple_grad",
"fill_constant", "sqrt_grad"
]
forward_list = [
"reshape2", "softmax_with_cross_entropy", "transpose2",
"softmax", "cross_entropy2", "dropout", "tanh",
["slice_grad", "c_allgather"], "assign", "matmul_v2_grad_grad",
"elementwise_add_grad_grad", "shape", "sqrt"
]
if op.type in need_set_shape_list:
for forward_op in block.ops:
idx = need_set_shape_list.index(op.type)
forward_op_name = forward_list[idx]
if forward_op.type in forward_op_name and forward_var_name in forward_op.input_arg_names:
op_dist_attr = dist_context.get_op_dist_attr_for_program(
forward_op)
break
forward_input_dist_attr = op_dist_attr.get_input_dist_attr(
forward_var_name)
assert forward_input_dist_attr is not None, f"{forward_var_name, str(op)}"
forward_var = vars[forward_var_name]
forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
forward_var)
assert forward_var_dist_attr is not None
grad_var = vars[var_name]
ref_shape = infer_shape(block, forward_var, forward_var_dist_attr,
forward_input_dist_attr)
if list(grad_var.shape) != ref_shape:
grad_var.desc.set_shape(ref_shape)
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
......
......@@ -478,12 +478,16 @@ def _accumulate_gradients_by_add_ops_(var_name,
renamed_vars[var_name] = [var_name]
def _addup_repetitive_outputs_(op_descs, block_idx):
def _addup_repetitive_outputs_(op_descs, block_idx, grad_var_to_var=None):
"""
In the backward pass, a variable may be the output of more than one op.
And one op may yield its multiple outputs to the same variable.
In these cases, the variable should be the accumulation of all the outputs.
`sum_op`s are added to implement the accumulation.
Args:
grad_var_to_var(dict): used to build the mapping between a grad var name and its forward var name.
Only used by auto parallel.
"""
_MAX_ADD_NUM_ = framework._global_flags()['FLAGS_max_inplace_grad_add']
#pending_sum_ops = []
......@@ -531,6 +535,13 @@ def _addup_repetitive_outputs_(op_descs, block_idx):
new_name = var_name + "@RENAME@block" + str(block_idx) + "@" + \
str(var_rename_count[var_name])
var_rename_count[var_name] += 1
# Build the mapping between the new_name and var_name (Only for auto parallel)
if grad_var_to_var is not None:
if var_name in grad_var_to_var:
grad_var_to_var[new_name] = grad_var_to_var[
var_name]
else:
grad_var_to_var[new_name] = var_name
# rename original var_name
renamed_vars[var_name][0] = new_name
# before change: _rename_arg_(op_descs, var_name,
......@@ -557,6 +568,13 @@ def _addup_repetitive_outputs_(op_descs, block_idx):
new_name = var_name + "@RENAME@block" + str(block_idx) + "@" + \
str(var_rename_count[var_name])
var_rename_count[var_name] += 1
# Build the mapping between the new_name and var_name (Only for auto parallel)
if grad_var_to_var is not None:
if var_name in grad_var_to_var:
grad_var_to_var[new_name] = grad_var_to_var[
var_name]
else:
grad_var_to_var[new_name] = var_name
arg_names[arg_idx] = new_name
op_desc.set_output(param_name, arg_names)
renamed_vars[var_name].append(new_name)
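The effect of the two blocks above, with hypothetical names: each extra write to a grad var is renamed to `<name>@RENAME@block<idx>@<n>` before the accumulating `sum_op`, and every rename is pointed back at the original forward var so the completer can still look up its dist attr:

```python
# Before: two ops both write "x@GRAD" (illustration only).
grad_var_to_var = {"x@GRAD": "x"}
# After _addup_repetitive_outputs_:
# grad_var_to_var == {
#     "x@GRAD": "x",
#     "x@GRAD@RENAME@block0@0": "x",
#     "x@GRAD@RENAME@block0@1": "x",
# }
```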
......@@ -1081,6 +1099,16 @@ def _append_backward_ops_(block,
rename_var_map(dict): used to associate a target_grad var name with the first grad_op input name.
Only used for high-order gradients.
"""
# Build the mapping between the forward op and backward op (Only for auto parallel)
def update_distop_context(distop_context, op_grad_to_var,
appending_grad_times):
distop_context.grad_var_to_var[appending_grad_times].update(
op_grad_to_var)
for op_desc in grad_op_desc:
assert op_desc.id() not in distop_context.grad_op_id_to_op_id
distop_context.grad_op_id_to_op_id[op_desc.id()] = op.desc.id()
if callbacks is not None:
assert (isinstance(callbacks, (list, tuple)))
for cb in callbacks:
......@@ -1118,11 +1146,18 @@ def _append_backward_ops_(block,
# Getting op's corresponding grad_op
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
# Build the mapping between the forward op and backward op (Only for auto parallel)
if distop_context is not None:
for op_desc in grad_op_desc:
assert op_desc.id() not in distop_context.grad_op_id_to_op_id
distop_context.grad_op_id_to_op_id[op_desc.id()] = op.desc.id()
update_distop_context(distop_context, op_grad_to_var,
program._appending_grad_times)
else:
default_ctx = getattr(paddle.distributed.auto_parallel.dist_context,
'_g_default_distributed_context', None)
if default_ctx is not None:
distop_context = default_ctx.dist_op_context
update_distop_context(distop_context, op_grad_to_var,
program._appending_grad_times)
# Set device for grad_op according to forward Op
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
......@@ -1155,6 +1190,11 @@ def _append_backward_ops_(block,
rename_var_map[name] = new_name
if name in op_grad_to_var:
# Build the mapping between the grad var name and var name (Only for auto parallel)
if distop_context is not None:
distop_context.grad_var_to_var[
program._appending_grad_times][
new_name] = op_grad_to_var[name]
op_grad_to_var[new_name] = op_grad_to_var[name]
op_grad_to_var.pop(name)
......@@ -1187,8 +1227,14 @@ def _append_backward_ops_(block,
grad_op_descs.extend(grad_op_desc)
grad_to_var.update(op_grad_to_var)
# record the mapping between grad var name and var name (Only for auto parallel)
grad_var_to_var = None
if distop_context is not None:
grad_var_to_var = distop_context.grad_var_to_var[
program._appending_grad_times]
# sum a parameter's gradients when multiple grad vars exist for it
grad_op_descs = _addup_repetitive_outputs_(grad_op_descs, block.idx)
grad_op_descs = _addup_repetitive_outputs_(grad_op_descs, block.idx,
grad_var_to_var)
# if all outputs of the grad op are in no_grad_set, then just remove and fill zero
# if all inputs of the grad op are in no_grad_set, just remove this op
......
......@@ -12,6 +12,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS})
py_test_modules(test_converter MODULES test_converter ENVS ${dist_ENVS})
set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_high_order_grad MODULES test_high_order_grad ENVS ${dist_ENVS})
set_tests_properties(test_high_order_grad PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_tunable_variable MODULES test_tunable_variable ENVS ${dist_ENVS})
py_test_modules(test_tunable_space MODULES test_tunable_space ENVS ${dist_ENVS})
......
......@@ -127,9 +127,16 @@ def train():
engine.prepare(optimizer, loss)
engine.fit(dataset,
batch_size=batch_size,
steps_per_epoch=batch_num * batch_size)
engine.save('./mlp')
engine.load('./mlp')
steps_per_epoch=batch_num * batch_size,
sample_generator=True)
eval_dataset = MyDataset(batch_size)
engine.prepare(optimizer, loss, mode='eval')
engine.evaluate(eval_dataset, batch_size)
test_dataset = MyDataset(batch_size)
engine.prepare(mode='predict')
engine.predict(test_dataset, batch_size)
engine.save('./mlp_inf', training=False, mode='predict')
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import time
import paddle.fluid as fluid
import copy
import os
import numpy as np
import subprocess
import paddle
import paddle.nn as nn
import paddle.fluid as fluid
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.static import InputSpec
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
paddle.enable_static()
global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
batch_size = 1
batch_num = 10
hidden_size = 1024
image_size = hidden_size
paddle.seed(44)
class MyDataset(Dataset):
def __init__(self, num_samples):
super(MyDataset, self).__init__()
self.num_samples = num_samples
def __getitem__(self, index):
input = np.random.uniform(size=image_size).astype("float32")
return input
def __len__(self):
return self.num_samples
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
out = self.norm(input)
out = self.linear0(input)
auto.shard_tensor(
self.linear0.weight,
dist_attr={
"process_mesh": global_process_mesh,
"dims_mapping": [-1, 0]
})
out = F.gelu(out, approximate=True)
out = self.linear1(out)
auto.shard_tensor(
self.linear1.weight,
dist_attr={
"process_mesh": global_process_mesh,
"dims_mapping": [0, -1]
})
out = self.dropout(out)
out = self.linear2(out)
return out
def train():
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
dataset = MyDataset(batch_num * batch_size)
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
dist_strategy = fleet.DistributedStrategy()
# init parallel optimizer
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
engine = Engine(mlp, inputs_spec=inputs_spec, strategy=dist_strategy)
engine.prepare(mode='predict')
engine.predict(dataset, batch_size=batch_size)
if __name__ == "__main__":
train()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import paddle
import unittest
import numpy as np
import paddle.distributed.auto_parallel as auto
from paddle.static import InputSpec
from paddle.distributed import fleet
from paddle.incubate.autograd import Hessian
from paddle.distributed.auto_parallel.engine import Engine
class FCNet:
def __init__(self, num_ins, num_outs, num_layers, hidden_size):
self.num_ins = num_ins
self.num_outs = num_outs
self.num_layers = num_layers
self.hidden_size = hidden_size
self.activation = paddle.tanh
self.weights = []
self.biases = []
for i in range(self.num_layers):
if i == 0:
lsize = self.num_ins
rsize = self.hidden_size
elif i == (self.num_layers - 1):
lsize = self.hidden_size
rsize = self.num_outs
else:
lsize = self.hidden_size
rsize = self.hidden_size
w = paddle.static.create_parameter(
shape=[lsize, rsize], dtype="float32", is_bias=False)
b = paddle.static.create_parameter(
shape=[rsize], dtype="float32", is_bias=True)
self.weights.append(w)
self.biases.append(b)
def nn_func(self, ins):
u = ins
for i in range(self.num_layers - 1):
u = paddle.nn.functional.linear(u, self.weights[i], self.biases[i])
u = self.activation(u)
u = paddle.nn.functional.linear(u, self.weights[-1], self.biases[-1])
return u
class LaplaceModel(paddle.nn.Layer):
def __init__(self, num_ins=2, num_outs=1, num_layers=5, hidden_size=20):
super(LaplaceModel, self).__init__()
self.net = FCNet(
num_ins=num_ins,
num_outs=num_outs,
num_layers=num_layers,
hidden_size=hidden_size)
def forward(self, inputs, bc_index):
inputs.stop_gradient = False
outputs = self.net.nn_func(inputs)
# eq_loss
hes = Hessian(self.net.nn_func, inputs, is_batched=True)
eq_loss = paddle.norm(hes[:, 0, 0] + hes[:, 1, 1], p=2)
# bc_loss
bc_u = paddle.index_select(outputs, bc_index)
return eq_loss, bc_u
class LaplaceDataset:
def __init__(self, num_sample):
self.num_sample = num_sample
def __getitem__(self, index):
x = np.linspace(0, 0.9, 10)
y = np.linspace(0, 0.9, 10)
bc_value = np.random.rand(36).reshape(36, 1).astype('float32')
domain_space = []
bc_index = []
for j in range(len(y)):
for i in range(len(x)):
domain_space.append([x[i], y[j]])
if i == 0 or i == 9 or j == 0 or j == 9:
bc_index.append(i + 10 * j)
domain_space = np.array(domain_space, dtype='float32')
bc_index = np.array(bc_index, dtype='int64')
return domain_space, bc_index, bc_value
def __len__(self):
return self.num_sample
def loss_func(eq_loss, bc_u, bc_value):
bc_diff = bc_u - bc_value
bc_loss = paddle.norm(bc_diff, p=2)
loss = eq_loss + bc_loss
return loss
def main():
# dataset
train_dataset = LaplaceDataset(10)
# optimizer
optimizer = paddle.optimizer.Adam(learning_rate=0.001)
# model
laplace = LaplaceModel()
# spec
inputs_spec = [
InputSpec([100, 2], 'float32', 'x'), InputSpec([36], 'int64', 'bc_idx')
]
labels_spec = InputSpec([36, 1], 'float32', 'bc_v')
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
engine = Engine(
laplace,
inputs_spec=inputs_spec,
labels_spec=labels_spec,
strategy=dist_strategy)
paddle.seed(1234 + engine._cur_rank)
engine.prepare(optimizer=optimizer, loss=loss_func)
res = engine.fit(train_dataset, sample_generator=False)
assert np.allclose(res[-1], 2.840593)
dist_context = engine.dist_context
block = engine.main_program.global_block()
ops = block.ops
for op in ops:
if op.type == 'p_norm':
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_type == 'p_norm'
if 'x' in op.input_arg_names:
out_name = op.output_arg_names[0]
assert block.vars[out_name].shape[0] == 50
if __name__ == "__main__":
main()
......@@ -49,28 +49,6 @@ class TestEngineAPI(unittest.TestCase):
if os.path.exists('rank_mapping.csv'):
os.remove('rank_mapping.csv')
def test_engine_predict(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "engine_predict_api.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--gpus", "0,1", launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
# Remove unnecessary files
log_path = os.path.join(file_dir, "log")
if os.path.exists(log_path):
shutil.rmtree(log_path)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import sys
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class TestHighOrderGrad(unittest.TestCase):
def test_dp2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "high_order_grad.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--gpus", "0,1", launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
# Remove unnecessary files
log_path = os.path.join(file_dir, "log")
if os.path.exists(log_path):
shutil.rmtree(log_path)
if __name__ == "__main__":
unittest.main()