Unverified commit c043a21b, authored by zhaoyingli, committed by GitHub

[AutoParallel] adapt for 2d laplace (#41601)

* add default_ctx in backward.py

* record grad_var_to_var with grad_times

* fix backward

* update annotation

* add complete_high_order_grad in complete_forward

* add dist slice op

* update grad_var_to_var type

* update partition_block init mapping before loss op

* update compatible for 'XShape' & update 'allreduce_vars'

* add dist reshape op when input dim equal to output dim

* update 'set_grad_var_shape' with grad_var_to_var

* fix dist slice

* fix set_grad_var_shape

* add dist pnorm op

* fix dist pnorm dist_attr

* fix engine startprogram & adapt highorder grad

* fix set_grad_var_shape when mp

* update unittest

* update cmakelist

* default strategy in engine: dp

* bug fix

* tiny fix

* flatten outputs

* fix default strategy

* init default ctx

* tiny fix

* test=allcase
Parent 2d29d833
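For context, the workload this change targets is a PDE-style loss that needs second-order (grad-of-grad) derivatives, for example a 2-D Laplace residual. The sketch below is illustrative only and is not part of the patch; the two-layer network and all variable names are assumptions. What matters is that the second `paddle.static.gradients` call appends high-order grad ops (e.g. `fill_zeros_like`, `matmul_v2_grad_grad`, `elementwise_add_grad_grad`) into the forward program, which is exactly what the completion, partitioning, and grad-shape changes below have to handle.

import paddle
import paddle.static as static

paddle.enable_static()

main_prog, startup_prog = static.Program(), static.Program()
with static.program_guard(main_prog, startup_prog):
    # x holds 2-D coordinates; u = net(x) approximates the PDE solution.
    x = static.data(name="x", shape=[None, 2], dtype="float32")
    x.stop_gradient = False
    hidden = static.nn.fc(x, size=64, activation="tanh")
    u = static.nn.fc(hidden, size=1)

    # First-order gradients du/dx ...
    du_dx = static.gradients([u], [x])[0]
    # ... and gradients of those gradients. This second call inserts
    # grad-of-grad ops into the forward block of the program.
    d2u_dx2 = static.gradients([du_dx], [x])[0]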
...
@@ -648,6 +648,9 @@ class Completer:
         self._dist_context.copy_dist_attr_from_graph_to_program()
         self._dist_context.clear_dist_info_for_graph()
+        # NOTE: [HighOrderGrad] update vars and ops distributed attribute in high order gradient
+        self.complete_high_order_grad_annotation(serial_main_program)
         # Do the validation check and amend some completion
         self._dist_context.amend_dist_attr_for_program()
...
@@ -655,6 +658,164 @@ class Completer:
         return serial_main_program
def complete_high_order_grad_annotation(self, serial_main_program):
"""
NOTE:
[HighOrderGrad] Complete the annotation of vars and ops only for high order gradient.
This function is temporary to support high order gradient, and will be removed in the future.
"""
def _is_grad_var_name(name):
if "@GRAD" in name:
return True
return False
def _get_op_by_id(ops, id):
for op in ops:
if op.desc.id() == id:
return op
return None
ops = list(serial_main_program.global_block().ops)
vars = serial_main_program.global_block().vars
dist_op_context = self._dist_context.dist_op_context
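# grad_var_to_var: dict keyed by the gradient-appending round (cf.
# program._appending_grad_times); each inner dict maps a grad var name
# such as "x@GRAD" to the name of the forward var it is the gradient of.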
grad_var_to_var = dist_op_context.grad_var_to_var
appended_grad_times = 0
for idx in range(0, len(ops)):
op = ops[idx]
if int(op.attr('op_role')) == int(
core.op_proto_and_checker_maker.OpRole.Forward):
continue
if int(op.attr('op_role')) == int(
core.op_proto_and_checker_maker.OpRole.Backward) and int(
ops[idx - 1].attr('op_role')) == int(
core.op_proto_and_checker_maker.OpRole.Forward):
appended_grad_times += 1
# complete the annotation of grad op (xxx_grad op or sum op)
# xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id
grad_op = ops[idx]
if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
# TODO support the case where one forward op corresponding to multiple xxx_grad op
forward_op = _get_op_by_id(
ops, dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()])
assert forward_op is not None
fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
forward_op)
fwd_op_process_mesh = fwd_op_dist_attr.process_mesh
grad_op_dist_attr = OperatorDistributedAttribute()
grad_op_dist_attr.process_mesh = fwd_op_process_mesh
for input_name in grad_op.input_arg_names:
if input_name not in forward_op.input_arg_names and input_name not in forward_op.output_arg_names:
if input_name in grad_var_to_var[appended_grad_times]:
fwd_name = grad_var_to_var[appended_grad_times][
input_name]
ref_dims_mapping = fwd_op_dist_attr.get_output_dims_mapping(
fwd_name)
else:
input_var = vars[input_name]
ref_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program(
input_var).dims_mapping
else:
if fwd_op_dist_attr.get_input_dims_mapping(input_name):
ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping(
input_name)
else:
ref_dims_mapping = fwd_op_dist_attr.get_output_dims_mapping(
input_name)
assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format(
input_name)
grad_op_dist_attr.set_input_dims_mapping(input_name,
ref_dims_mapping)
for output_name in grad_op.output_arg_names:
assert output_name in grad_var_to_var[appended_grad_times]
fwd_name = grad_var_to_var[appended_grad_times][output_name]
ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping(
fwd_name)
# var
output_var = vars[output_name]
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.dims_mapping = ref_dims_mapping
tensor_dist_attr.process_mesh = fwd_op_process_mesh
self._dist_context.set_tensor_dist_attr_for_program(
output_var, tensor_dist_attr)
# op
grad_op_dist_attr.set_output_dims_mapping(output_name,
ref_dims_mapping)
self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr)
# grad ops that have not a corresponding mapping in grad_op_id_to_op_id
else:
if grad_op.type == 'sum':
assert all(map(_is_grad_var_name, grad_op.input_arg_names))
output_name = grad_op.output_arg_names[0]
assert output_name in grad_var_to_var[appended_grad_times], \
"sum op's output '{}' has no corresponding var".format(
output_name)
ref_fwd_var_name = grad_var_to_var[appended_grad_times][
output_name]
ref_fwd_var = vars[ref_fwd_var_name]
ref_fwd_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
ref_fwd_var)
ref_fwd_dims_mapping = ref_fwd_dist_attr.dims_mapping
ref_fwd_process_mesh = ref_fwd_dist_attr.process_mesh
# output
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.dims_mapping = ref_fwd_dims_mapping
tensor_dist_attr.process_mesh = ref_fwd_process_mesh
output_var = vars[output_name]
self._dist_context.set_tensor_dist_attr_for_program(
output_var, tensor_dist_attr)
# op
grad_op_dist_attr = OperatorDistributedAttribute()
grad_op_dist_attr.process_mesh = ref_fwd_process_mesh
for var_name in grad_op.input_arg_names:
grad_op_dist_attr.set_input_dims_mapping(
var_name, ref_fwd_dims_mapping)
grad_op_dist_attr.set_output_dims_mapping(
output_name, ref_fwd_dims_mapping)
elif grad_op.type == 'fill_zeros_like':
ref_var_name = grad_op.input_arg_names[0]
ref_var = vars[ref_var_name]
ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
ref_var)
ref_dims_mapping = ref_dist_attr.dims_mapping
ref_process_mesh = ref_dist_attr.process_mesh
# output
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.dims_mapping = ref_dims_mapping
tensor_dist_attr.process_mesh = ref_process_mesh
output_var_name = grad_op.output_arg_names[0]
output_var = vars[output_var_name]
self._dist_context.set_tensor_dist_attr_for_program(
output_var, tensor_dist_attr)
# op
grad_op_dist_attr = OperatorDistributedAttribute()
grad_op_dist_attr.process_mesh = ref_process_mesh
grad_op_dist_attr.set_input_dims_mapping(ref_var_name,
ref_dims_mapping)
grad_op_dist_attr.set_output_dims_mapping(output_var_name,
ref_dims_mapping)
elif grad_op.type in ['shape', 'fill_constant']:
continue
else:
raise ValueError("got unexpect op [{}]".format(
str(grad_op.type)))
self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr)
    def complete_backward_annotation(self, serial_main_program):
        """Complete the annotation of vars and ops in the backward phase for parallel program."""
...
@@ -689,6 +850,8 @@ class Completer:
        ops = list(serial_main_program.global_block().ops)
        vars = serial_main_program.global_block().vars
        dist_op_context = self._dist_context.dist_op_context
+        grad_var_to_var = dist_op_context.grad_var_to_var[len(
+            dist_op_context.grad_var_to_var)]

        for idx in range(first_backward_op_idx, len(ops)):
...
@@ -765,102 +928,111 @@ class Completer:
                        grad_op, grad_op_dist_attr)
                    continue
-                # op dist attr
-                forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
-                    forward_op)
-                forward_op_process_mesh = forward_op_dist_attr.process_mesh
+                fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
+                    forward_op)
+                fwd_op_process_mesh = fwd_op_dist_attr.process_mesh
                 grad_op_dist_attr = OperatorDistributedAttribute()
-                grad_op_dist_attr.process_mesh = forward_op_process_mesh
+                grad_op_dist_attr.process_mesh = fwd_op_process_mesh

-                # var
                 for input_name in grad_op.input_arg_names:
-                    input_var = vars[input_name]
-                    ref_dims_mapping = None
-                    if "@GRAD" in input_name:
-                        forward_name = _get_forward_varname_from_grad_varname(
-                            input_name)
-                        ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping(
-                            forward_name)
+                    if input_name not in forward_op.input_arg_names and input_name not in forward_op.output_arg_names:
+                        if input_name in grad_var_to_var:
+                            fwd_name = grad_var_to_var[input_name]
+                            ref_dims_mapping = fwd_op_dist_attr.get_output_dims_mapping(
+                                fwd_name)
+                        else:
+                            input_var = vars[input_name]
+                            ref_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program(
+                                input_var).dims_mapping
                     else:
-                        if forward_op_dist_attr.get_input_dims_mapping(
-                                input_name):
-                            ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping(
+                        if fwd_op_dist_attr.get_input_dims_mapping(input_name):
+                            ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping(
                                 input_name)
                         else:
-                            ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping(
+                            ref_dims_mapping = fwd_op_dist_attr.get_output_dims_mapping(
                                 input_name)
                     assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format(
-                        input_var.name)
+                        input_name)
                     grad_op_dist_attr.set_input_dims_mapping(input_name,
                                                              ref_dims_mapping)

-                for output_name in grad_op.desc.output_names():
-                    assert len(grad_op.desc.output(output_name)) in [0, 1]
-                    if _is_grad_var_name(output_name):
-                        input_name = _get_forward_varname_from_grad_varname(
-                            output_name)
-                    else:
-                        assert grad_op.type in [
-                            "cast", "c_identity", "c_allreduce_sum"
-                        ]
-                        input_name = "X"
-                    assert input_name in forward_op.desc.input_names(
-                    ), "var [{}] in op [{}]'s output but could not find [{}] in its forward op".format(
-                        output_name, grad_op.type, input_name)
-                    if len(grad_op.desc.output(output_name)) == 1:
-                        # tensor dist attr
-                        output_var = vars[grad_op.desc.output(output_name)[0]]
-                        forward_name = _get_forward_varname_from_grad_varname(
-                            output_var.name)
-                        ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping(
-                            forward_name)
-                        output_var_dist_attr = TensorDistributedAttribute()
-                        output_var_dist_attr.dims_mapping = ref_dims_mapping
-                        output_var_dist_attr.process_mesh = forward_op_process_mesh
-                        self._dist_context.set_tensor_dist_attr_for_program(
-                            output_var, output_var_dist_attr)
-                        grad_op_dist_attr.set_output_dims_mapping(
-                            output_var.name, ref_dims_mapping)
+                for output_name in grad_op.output_arg_names:
+                    assert output_name in grad_var_to_var
+                    fwd_name = grad_var_to_var[output_name]
+                    ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping(
+                        fwd_name)
+                    # var
+                    output_var = vars[output_name]
+                    tensor_dist_attr = TensorDistributedAttribute()
+                    tensor_dist_attr.dims_mapping = ref_dims_mapping
+                    tensor_dist_attr.process_mesh = fwd_op_process_mesh
+                    self._dist_context.set_tensor_dist_attr_for_program(
+                        output_var, tensor_dist_attr)
+                    # op
+                    grad_op_dist_attr.set_output_dims_mapping(output_name,
+                                                              ref_dims_mapping)

                 self._dist_context.set_op_dist_attr_for_program(
                     grad_op, grad_op_dist_attr)

-            # only sum op for merge mutiple version grad has no a corresponding mapping in grad_op_id_to_op_id
+            # grad ops that have not a corresponding mapping in grad_op_id_to_op_id
             else:
-                assert grad_op.type == "sum", "got unexpect op [{}]".format(
-                    str(grad_op.type))
-                assert all(map(_is_grad_var_name, grad_op.input_arg_names))
-                assert len(grad_op.output_arg_names) == 1
-
-                ref_forward_var_name = _get_forward_varname_from_grad_varname(
-                    grad_op.output_arg_names[0])
-                forward_var = vars[ref_forward_var_name]
-                ref_forward_var_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program(
-                    forward_var).dims_mapping
-                ref_forward_var_process_mesh = self._dist_context.get_tensor_dist_attr_for_program(
-                    forward_var).process_mesh
-
-                # output
-                tensor_dist_attr = TensorDistributedAttribute()
-                tensor_dist_attr.dims_mapping = ref_forward_var_dims_mapping
-                tensor_dist_attr.process_mesh = ref_forward_var_process_mesh
-                self._dist_context.set_tensor_dist_attr_for_program(
-                    vars[grad_op.output_arg_names[0]], tensor_dist_attr)
-
-                # op
-                grad_op_dist_attr = OperatorDistributedAttribute()
-                grad_op_dist_attr.process_mesh = ref_forward_var_process_mesh
-                for var_name in grad_op.input_arg_names:
-                    assert _get_forward_varname_from_grad_varname(
-                        var_name) == ref_forward_var_name
-                    grad_op_dist_attr.set_input_dims_mapping(
-                        var_name, ref_forward_var_dims_mapping)
-                grad_op_dist_attr.set_output_dims_mapping(
-                    grad_op.output_arg_names[0], ref_forward_var_dims_mapping)
+                if grad_op.type == 'sum':
+                    assert all(map(_is_grad_var_name, grad_op.input_arg_names))
+                    output_name = grad_op.output_arg_names[0]
+                    assert output_name in grad_var_to_var, "sum op's output '{}' has no corresponding var".format(
+                        output_name)
+                    ref_fwd_var_name = grad_var_to_var[output_name]
+                    ref_fwd_var = vars[ref_fwd_var_name]
+                    ref_fwd_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
+                        ref_fwd_var)
+                    ref_fwd_dims_mapping = ref_fwd_dist_attr.dims_mapping
+                    ref_fwd_process_mesh = ref_fwd_dist_attr.process_mesh
+
+                    # output
+                    tensor_dist_attr = TensorDistributedAttribute()
+                    tensor_dist_attr.dims_mapping = ref_fwd_dims_mapping
+                    tensor_dist_attr.process_mesh = ref_fwd_process_mesh
+                    output_var = vars[output_name]
+                    self._dist_context.set_tensor_dist_attr_for_program(
+                        output_var, tensor_dist_attr)
+
+                    # op
+                    grad_op_dist_attr = OperatorDistributedAttribute()
+                    grad_op_dist_attr.process_mesh = ref_fwd_process_mesh
+                    for var_name in grad_op.input_arg_names:
+                        grad_op_dist_attr.set_input_dims_mapping(
+                            var_name, ref_fwd_dims_mapping)
+                    grad_op_dist_attr.set_output_dims_mapping(
+                        output_name, ref_fwd_dims_mapping)
+
+                elif grad_op.type == 'fill_zeros_like':
+                    ref_var_name = grad_op.input_arg_names[0]
+                    ref_var = vars[ref_var_name]
+                    ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
+                        ref_var)
+                    ref_dims_mapping = ref_dist_attr.dims_mapping
+                    ref_process_mesh = ref_dist_attr.process_mesh
+
+                    # output
+                    tensor_dist_attr = TensorDistributedAttribute()
+                    tensor_dist_attr.dims_mapping = ref_dims_mapping
+                    tensor_dist_attr.process_mesh = ref_process_mesh
+                    output_var_name = grad_op.output_arg_names[0]
+                    output_var = vars[output_var_name]
+                    self._dist_context.set_tensor_dist_attr_for_program(
+                        output_var, tensor_dist_attr)
+
+                    # op
+                    grad_op_dist_attr = OperatorDistributedAttribute()
+                    grad_op_dist_attr.process_mesh = ref_process_mesh
+                    grad_op_dist_attr.set_input_dims_mapping(ref_var_name,
+                                                             ref_dims_mapping)
+                    grad_op_dist_attr.set_output_dims_mapping(output_var_name,
+                                                              ref_dims_mapping)
+
+                else:
+                    raise ValueError("got unexpect op [{}]".format(
+                        str(grad_op.type)))

                 self._dist_context.set_op_dist_attr_for_program(
                     grad_op, grad_op_dist_attr)
...
@@ -120,6 +120,11 @@ class DistributedContext:
     def dist_startup_programs(self):
         return self._dist_startup_programs

+    @property
+    def is_annotation(self):
+        return len(self._dist_tensors_for_program) or len(
+            self._dist_ops_for_program)
+
     def add_process_mesh(self, process_mesh):
         assert isinstance(process_mesh, ProcessMesh), \
             'The type of dim_mapping must be ProcessMesh.'
...
@@ -577,6 +582,7 @@ class DistributedOperatorContext:
         self._cur_src_op = None
         self._cur_dist_attr = None
         self.grad_op_id_to_op_id = {}
+        self.grad_var_to_var = defaultdict(dict)
         self._work_block = None
         self.already_init_sync_vars = set()
         self.varname_mapping = None
...
@@ -16,6 +16,7 @@ import abc
 import numpy as np
 import paddle
 from .utils import to_list
+from paddle.fluid.layers.utils import flatten
 from paddle.io import DataLoader, DistributedBatchSampler
...
@@ -56,16 +57,17 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
                 data_parallel_world_size=None,
                 data_parallel_rank=None,
                 drop_last=False,
-                 inputs=[]):
+                 sample_generator=True):
        self.feed_list = feed_list
        self.places = places
        self.steps_per_epoch = steps_per_epoch
+        self._sample_generator = sample_generator
        super(NonIterableGeneratorLoader, self).__init__(
            dataset, batch_size, epochs, data_parallel_world_size,
            data_parallel_rank, drop_last)
        self._inner_dataloader = self._create_inner_dataloader()
        self._steps = self._infer_steps()
-        self._inputs = inputs

    def __iter__(self):
        self._cur_step = 0
...
@@ -91,27 +93,28 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
        return steps_per_epoch

    def _create_inner_dataloader(self):
-        def data_generator():
+        def sample_data_generator():
            batch_data = None
            for step, data in enumerate(self.dataset):
-                if not isinstance(data, list):
-                    data = to_list(data)
-
-                if self.batch_size == 1:
-                    yield data
-                    batch_data = None
-                else:
-                    if batch_data is None:
-                        batch_data = [[] for i in range(len(data))]
-                    for idx in range(len(data)):
-                        batch_data[idx].append(data[idx])
-                    if (step + 1) % self.batch_size == 0:
-                        yield batch_data
-                        batch_data = None
+                data = flatten(data)
+                if batch_data is None:
+                    batch_data = [[] for i in range(len(data))]
+                for idx in range(len(data)):
+                    batch_data[idx].append(data[idx])
+                if (step + 1) % self.batch_size == 0:
+                    yield batch_data
+                    batch_data = None
+
+        def batch_data_generator():
+            for data in self.dataset:
+                data = flatten(data)
+                yield data

        dataloader = paddle.fluid.io.DataLoader.from_generator(
            feed_list=self.feed_list, capacity=70, iterable=False)
-        dataloader.set_batch_generator(data_generator, self.places)
+        if self._sample_generator:
+            dataloader.set_batch_generator(sample_data_generator, self.places)
+        else:
+            dataloader.set_batch_generator(batch_data_generator, self.places)

        return dataloader
...@@ -17,18 +17,22 @@ import logging ...@@ -17,18 +17,22 @@ import logging
from collections import defaultdict from collections import defaultdict
import paddle import paddle
import paddle.distributed.auto_parallel as auto
from paddle import fluid from paddle import fluid
from paddle.io import Dataset from paddle.io import Dataset
from paddle.metric import Metric from paddle.metric import Metric
from paddle.static import InputSpec from paddle.static import InputSpec
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid import program_guard from paddle.fluid import program_guard
from paddle.fluid.layers.utils import flatten
from paddle.fluid.executor import global_scope
from paddle.fluid.backward import append_backward from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Operator from paddle.fluid.framework import Operator
from paddle.fluid.framework import _current_expected_place as _get_device from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed.passes import new_pass, PassContext
from paddle.distributed.utils import get_logger from paddle.distributed.utils import get_logger
from paddle.distributed.passes import new_pass, PassContext
from .mapper import mapping from .mapper import mapping
from .cluster import Cluster from .cluster import Cluster
...@@ -61,6 +65,12 @@ class Engine: ...@@ -61,6 +65,12 @@ class Engine:
self.strategy = strategy self.strategy = strategy
self._executor = None self._executor = None
self._cur_rank = paddle.distributed.get_rank()
self._nranks = paddle.distributed.get_world_size()
self._saver = DistributedSaver()
self._logger = get_logger(logging.INFO)
self._default_strategy = None
self._orig_main_prog = fluid.default_main_program() self._orig_main_prog = fluid.default_main_program()
self._orig_startup_prog = fluid.default_startup_program() self._orig_startup_prog = fluid.default_startup_program()
self._orig_dist_context = get_default_distributed_context() self._orig_dist_context = get_default_distributed_context()
...@@ -70,9 +80,6 @@ class Engine: ...@@ -70,9 +80,6 @@ class Engine:
self._dist_startup_progs = defaultdict(dict) # dist startup programs self._dist_startup_progs = defaultdict(dict) # dist startup programs
self._dist_contexts = {} self._dist_contexts = {}
self._pass_contexts = {} self._pass_contexts = {}
self._cur_rank = paddle.distributed.get_rank()
self._logger = get_logger(logging.INFO)
self._saver = DistributedSaver()
self._feed_vars = {} self._feed_vars = {}
self._fetch_vars = {} self._fetch_vars = {}
...@@ -86,13 +93,11 @@ class Engine: ...@@ -86,13 +93,11 @@ class Engine:
# TODO: check loss type # TODO: check loss type
self._loss = loss self._loss = loss
self._metrics = to_list(metrics) self._metrics = to_list(metrics)
for m in ['train', 'predict']: self._mode = mode
self.mode = m self._build(mode) # build forward program
self._build(m) # build forward program self._plan(mode) # completion & planner
self._plan(m) # completion & planner self._parallel(mode, all_ranks) # parallel
self._parallel(m, all_ranks) # parallel self._initialize(mode) # init comm and startup program
self._initialize(m) # init comm and startup program
self.mode = mode
def _build(self, mode): def _build(self, mode):
serial_main_prog = self._serial_main_progs.get(mode, None) serial_main_prog = self._serial_main_progs.get(mode, None)
...@@ -112,10 +117,16 @@ class Engine: ...@@ -112,10 +117,16 @@ class Engine:
if mode != "predict" and self._loss: if mode != "predict" and self._loss:
losses = to_list(self._loss(*(outputs + labels))) losses = to_list(self._loss(*(outputs + labels)))
default_ctx = get_default_distributed_context()
if not default_ctx.is_annotation or self._default_strategy:
inputs = [self._set_data_parallel(var) for var in inputs]
labels = [self._set_data_parallel(var) for var in labels]
# print(serial_main_prog)
self._feed_vars[mode] = {"inputs": inputs, "labels": labels} self._feed_vars[mode] = {"inputs": inputs, "labels": labels}
self._fetch_vars[mode] = { self._fetch_vars[mode] = {
"outputs": outputs, "outputs": flatten(outputs),
"loss": losses, "loss": losses,
"metrics": metrics "metrics": metrics
} }
...@@ -128,6 +139,12 @@ class Engine: ...@@ -128,6 +139,12 @@ class Engine:
self._pass_contexts[mode] = PassContext() self._pass_contexts[mode] = PassContext()
def _plan(self, mode): def _plan(self, mode):
# NOTE: [HighOrderGrad]. There are grad ops in forward phase, and it need
# dependency of backward-forward ops in forward completition.
defualt_ctx = get_default_distributed_context()
self._dist_contexts[mode]._dist_op_context = defualt_ctx.dist_op_context
# Complete the distributed annotation # Complete the distributed annotation
serial_main_prog = self._serial_main_progs[mode] serial_main_prog = self._serial_main_progs[mode]
self._completer = Completer(self._dist_contexts[mode]) self._completer = Completer(self._dist_contexts[mode])
...@@ -147,13 +164,14 @@ class Engine: ...@@ -147,13 +164,14 @@ class Engine:
self._parallel_program(mode, rank) self._parallel_program(mode, rank)
def _initialize(self, mode): def _initialize(self, mode):
# Traverse different rank programs and traverse each op of them, if self._nranks > 1:
# instantiate communication by process_mapping. # Traverse different rank programs and traverse each op of them,
all_process_groups = get_all_process_groups() # instantiate communication by process_mapping.
for process_group in all_process_groups: all_process_groups = get_all_process_groups()
if self._cur_rank not in process_group.ranks: for process_group in all_process_groups:
continue if self._cur_rank not in process_group.ranks:
process_group.instantiate() continue
process_group.instantiate()
# initialize # initialize
self._place = _get_device() self._place = _get_device()
...@@ -161,8 +179,16 @@ class Engine: ...@@ -161,8 +179,16 @@ class Engine:
self._place = fluid.CUDAPlace(ParallelEnv().dev_id) self._place = fluid.CUDAPlace(ParallelEnv().dev_id)
if self._executor is None: if self._executor is None:
self._executor = paddle.static.Executor(self._place) self._executor = paddle.static.Executor(self._place)
dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank] uninitialized = []
self._executor.run(dist_startup_prog) dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
for var in dist_startup_prog.list_vars():
scope_var = global_scope().find_var(var.name)
if scope_var and scope_var.get_tensor()._is_initialized():
continue
uninitialized.append(var)
if uninitialized:
prune_startup_prog = dist_startup_prog._prune(uninitialized)
self._executor.run(prune_startup_prog)
def _parallel_program(self, mode, rank): def _parallel_program(self, mode, rank):
serial_main_program = self._serial_main_progs[mode] serial_main_program = self._serial_main_progs[mode]
...@@ -246,12 +272,13 @@ class Engine: ...@@ -246,12 +272,13 @@ class Engine:
if config["use_pure_fp16"]: if config["use_pure_fp16"]:
config["base_opt"] = self._optimizer config["base_opt"] = self._optimizer
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config) auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply( auto_parallel_fp16_pass.apply([main_program],
[main_program], [startup_program], self._pass_context) [startup_program],
self._pass_contexts[self.mode])
else: else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply([main_program], [startup_program], auto_parallel_amp_pass.apply([main_program], [startup_program],
self._pass_context) self._pass_contexts[self.mode])
# apply recompute pass # apply recompute pass
if self.strategy.recompute: if self.strategy.recompute:
...@@ -288,18 +315,26 @@ class Engine: ...@@ -288,18 +315,26 @@ class Engine:
[main_program], [startup_program], [main_program], [startup_program],
self._pass_contexts[self.mode]) self._pass_contexts[self.mode])
def fit(self, train_data, batch_size=1, epochs=1, steps_per_epoch=None): def fit(self,
train_data,
batch_size=1,
epochs=1,
steps_per_epoch=None,
use_program_cache=False,
return_numpy=True,
sample_generator=True):
# TODO: callbacks # TODO: callbacks
# TODO: evaluate after training # TODO: evaluate after training
self.mode = 'train' self.mode = 'train'
assert isinstance(train_data, Dataset) assert self.mode in self._dist_main_progs, "train model is not ready, please call `engine.prepare(mode='train')` first."
train_dataloader = self._create_dataloader(train_data, batch_size, train_dataloader = self._create_dataloader(
epochs, steps_per_epoch) train_data, batch_size, epochs, steps_per_epoch, sample_generator)
outputs = [] outputs = []
for epoch in range(epochs): for epoch in range(epochs):
for step, data in enumerate(train_dataloader): for step, data in enumerate(train_dataloader):
logs, loss = self._train_step(data) logs, loss = self._train_step(data, use_program_cache,
return_numpy)
outputs.append(loss) outputs.append(loss)
train_logs = { train_logs = {
"train_" + name: val "train_" + name: val
...@@ -308,14 +343,35 @@ class Engine: ...@@ -308,14 +343,35 @@ class Engine:
self._logger.info(train_logs) self._logger.info(train_logs)
return outputs return outputs
def evaluate(self,
eval_data,
batch_size=1,
use_program_cache=False,
return_numpy=True,
sample_generator=True):
self.mode = 'eval'
assert self.mode in self._dist_main_progs, "eval model is not ready, please call `engine.prepare(mode='eval')` first."
eval_dataloader = self._create_dataloader(
eval_data, batch_size, sample_generator=sample_generator)
outputs = []
for step, data in enumerate(eval_dataloader):
logs, outs = self._eval_step(data, use_program_cache, return_numpy)
outputs.append(outs)
predict_logs = {"eval_" + name: val for name, val in logs.items()}
self._logger.info(predict_logs)
return outputs
def predict(self, def predict(self,
test_data, test_data,
batch_size=1, batch_size=1,
use_program_cache=False, use_program_cache=False,
return_numpy=True): return_numpy=True,
sample_generator=True):
self.mode = 'predict' self.mode = 'predict'
# TODO: need check dataset assert self.mode in self._dist_main_progs, "predict model is not ready, please call `engine.prepare(mode='predict')` first."
test_dataloader = self._create_dataloader(test_data, batch_size) test_dataloader = self._create_dataloader(
test_data, batch_size, sample_generator=sample_generator)
outputs = [] outputs = []
for step, data in enumerate(test_dataloader): for step, data in enumerate(test_dataloader):
...@@ -329,19 +385,39 @@ class Engine: ...@@ -329,19 +385,39 @@ class Engine:
self._logger.info(predict_logs) self._logger.info(predict_logs)
return outputs return outputs
def _train_step(self, data): def _train_step(self, data, use_program_cache=False, return_numpy=True):
logs = {} logs = {}
dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank] dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
fetch_var = self._fetch_vars[self.mode]["loss"][0] fetch_var = self._fetch_vars[self.mode]["loss"][0]
if fetch_var.name not in dist_main_prog.global_block().vars: if fetch_var.name not in dist_main_prog.global_block().vars:
loss = self._executor.run(dist_main_prog) loss = self._executor.run(dist_main_prog,
use_program_cache=use_program_cache)
logs["loss"] = None logs["loss"] = None
else: else:
loss = self._executor.run(dist_main_prog, loss = self._executor.run(dist_main_prog,
fetch_list=to_list(fetch_var)) fetch_list=to_list(fetch_var),
use_program_cache=use_program_cache,
return_numpy=return_numpy)
logs["loss"] = loss logs["loss"] = loss
return logs, loss return logs, loss
def _eval_step(self, data, use_program_cache=False, return_numpy=True):
logs = {}
dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
fetch_var = self._fetch_vars[self.mode]["loss"][0]
if fetch_var.name not in dist_main_prog.global_block().vars:
outs = self._executor.run(dist_main_prog,
use_program_cache=use_program_cache)
logs["loss"] = outs
else:
outs = self._executor.run(dist_main_prog,
fetch_list=fetch_var,
use_program_cache=use_program_cache,
return_numpy=return_numpy)
logs["loss"] = outs
return logs, outs
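    # NOTE: like _train_step, _eval_step above fetches only the loss variable
    # and logs it under the "loss" key; metric variables are not fetched here yet.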
def _predict_step(self, data, use_program_cache=False, return_numpy=True): def _predict_step(self, data, use_program_cache=False, return_numpy=True):
logs = {} logs = {}
dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank] dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
...@@ -366,7 +442,8 @@ class Engine: ...@@ -366,7 +442,8 @@ class Engine:
dataset, dataset,
batch_size, batch_size,
epochs=1, epochs=1,
steps_per_epoch=None): steps_per_epoch=None,
sample_generator=True):
feed_list = self._feed_vars[self.mode]["inputs"] + self._feed_vars[ feed_list = self._feed_vars[self.mode]["inputs"] + self._feed_vars[
self.mode]["labels"] self.mode]["labels"]
dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank] dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
...@@ -376,9 +453,12 @@ class Engine: ...@@ -376,9 +453,12 @@ class Engine:
serial_main_prog = self._serial_main_progs[self.mode] serial_main_prog = self._serial_main_progs[self.mode]
serial_main_block = serial_main_prog.global_block() serial_main_block = serial_main_prog.global_block()
op_size = len(dist_main_block.ops) op_size = len(dist_main_block.ops)
if dist_main_block.ops[0].type == 'create_py_reader':
op_size -= 3
for _ in range(3):
dist_main_block._remove_op(0, sync=False)
places = paddle.static.cuda_places() places = paddle.static.cuda_places()
with fluid.program_guard(dist_main_prog, dist_startup_prog): with fluid.program_guard(dist_main_prog, dist_startup_prog):
inputs = self._feed_vars[self.mode]["inputs"]
dataloader = NonIterableGeneratorLoader( dataloader = NonIterableGeneratorLoader(
dataset, dataset,
feed_list, feed_list,
...@@ -386,7 +466,7 @@ class Engine: ...@@ -386,7 +466,7 @@ class Engine:
batch_size, batch_size,
epochs, epochs,
steps_per_epoch, steps_per_epoch,
inputs=inputs) sample_generator=sample_generator)
new_op_size = len(dist_main_block.ops) new_op_size = len(dist_main_block.ops)
for _ in range(new_op_size - 1, op_size - 1, -1): for _ in range(new_op_size - 1, op_size - 1, -1):
op = dist_main_block.ops[new_op_size - 1] op = dist_main_block.ops[new_op_size - 1]
...@@ -396,7 +476,7 @@ class Engine: ...@@ -396,7 +476,7 @@ class Engine:
dist_main_block, new_op_desc, type=new_op_desc.type()) dist_main_block, new_op_desc, type=new_op_desc.type())
dist_main_block.ops.insert(0, new_op) dist_main_block.ops.insert(0, new_op)
for in_name in new_op.input_arg_names: for in_name in new_op.input_arg_names:
if in_name == "lod_tensor_blocking_queue_0": if "lod_tensor_blocking_queue" in in_name:
continue continue
if in_name not in dist_main_block.vars: if in_name not in dist_main_block.vars:
in_var = serial_main_block._var_recursive(in_name) in_var = serial_main_block._var_recursive(in_name)
...@@ -424,6 +504,27 @@ class Engine: ...@@ -424,6 +504,27 @@ class Engine:
.format(i, spec)) .format(i, spec))
return specs return specs
def _set_data_parallel(self, var):
if self._nranks == 1:
self._default_strategy = 'serial'
auto.shard_tensor(
var,
dist_attr={
"process_mesh": [0],
"dims_mapping": [-1 for _ in range(len(var.shape))]
})
else:
self._default_strategy = 'dp'
auto.shard_tensor(
var,
dist_attr={
"process_mesh": list(range(self._nranks)),
"dims_mapping":
[0] + [-1 for _ in range(len(var.shape) - 1)]
})
return var
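    # NOTE: in a dims_mapping, -1 marks a tensor dimension as replicated across
    # the process mesh, while an integer k shards that dimension along mesh dim k.
    # The 'dp' default therefore shards only the batch (first) dimension of each
    # data var across all ranks, while 'serial' replicates everything on one rank.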
def save(self, path, training=True, mode=None): def save(self, path, training=True, mode=None):
if not mode: if not mode:
mode = self.mode mode = self.mode
...@@ -459,3 +560,35 @@ class Engine: ...@@ -459,3 +560,35 @@ class Engine:
dist_context = self._dist_contexts[mode] dist_context = self._dist_contexts[mode]
self._saver.load(path, dist_main_prog, dist_context, strict, self._saver.load(path, dist_main_prog, dist_context, strict,
load_optimizer) load_optimizer)
@property
def mode(self):
return self._mode
@mode.setter
def mode(self, mode):
self._mode = mode
@property
def metrics(self):
return self._metrics
@property
def main_program(self):
return self._dist_main_progs[self.mode][self._cur_rank]
@property
def startup_program(self):
return self._dist_startup_progs[self.mode][self._cur_rank]
@property
def dist_context(self):
return self._dist_contexts[self.mode]
@property
def serial_main_program(self):
return self._serial_main_progs[self.mode]
@property
def serial_startup_program(self):
return self._serial_startup_progs[self.mode]
...@@ -53,6 +53,10 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -53,6 +53,10 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
def is_input_compatible(self, dist_op): def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
input_names = op_desc.input_names()
xshape_arg_names = []
if "XShape" in input_names:
xshape_arg_names = op_desc.input("XShape")
for arg_name in op_desc.input_arg_names(): for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name) serial_tensor = dist_op.get_serial_input(arg_name)
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
...@@ -63,10 +67,18 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -63,10 +67,18 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
# continue # continue
# if len(dims_mapping) < 1: # if len(dims_mapping) < 1:
# continue # continue
if len(dims_mapping) > 1: if arg_name not in xshape_arg_names:
for mapping in dims_mapping[1:]: if len(dims_mapping) > 1:
if mapping != -1: for mapping in dims_mapping[1:]:
return False if mapping != -1:
return False
else:
if dims_mapping[0] != -1:
return False
if len(dims_mapping) > 2:
for mapping in dims_mapping[2:]:
if mapping != -1:
return False
return True return True
def is_output_compatible(self, dist_op): def is_output_compatible(self, dist_op):
...@@ -105,17 +117,31 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -105,17 +117,31 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
batch_dim_mappings = [] batch_dim_mappings = []
# Check input compatibility # Check input compatibility
input_names = op_desc.input_names()
xshape_arg_names = []
if "XShape" in input_names:
xshape_arg_names = op_desc.input("XShape")
for arg_name in op_desc.input_arg_names(): for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name) serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter: if serial_tensor.is_parameter:
continue continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping) > 1: if arg_name not in xshape_arg_names:
for mapping in dims_mapping[1:]: if len(dims_mapping) > 1:
if mapping != -1: for mapping in dims_mapping[1:]:
return False if mapping != -1:
if len(dims_mapping) >= 1: return False
batch_dim_mappings.append(dims_mapping[0]) if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
else:
if dims_mapping[0] != -1:
return False
if len(dims_mapping) > 2:
for mapping in dims_mapping[2:]:
if mapping != -1:
return False
if len(dims_mapping) >= 2:
batch_dim_mappings.append(dims_mapping[1])
# Check output compatibility # Check output compatibility
output_names = op_desc.output_names() output_names = op_desc.output_names()
...@@ -160,24 +186,39 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -160,24 +186,39 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
or op_desc.type() == "slice" \ or op_desc.type() == "slice" \
or op_desc.type() == "while": or op_desc.type() == "while":
return False return False
input_names = op_desc.input_names()
input_xshape_arg_names = []
if "XShape" in input_names:
input_xshape_arg_names = op_desc.input("XShape")
output_names = op_desc.output_names() output_names = op_desc.output_names()
xshape_arg_names = [] output_xshape_arg_names = []
if "XShape" in output_names: if "XShape" in output_names:
xshape_arg_names = op_desc.output("XShape") output_xshape_arg_names = op_desc.output("XShape")
batch_dim_mappings = [] batch_dim_mappings = []
for arg_name in op_desc.input_arg_names(): for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name) serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter: if serial_tensor.is_parameter:
continue continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping) >= 1: if arg_name not in input_xshape_arg_names:
batch_dim_mappings.append(dims_mapping[0]) if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
else:
batch_dim_mappings.append(dims_mapping[1])
for arg_name in op_desc.output_arg_names(): for arg_name in op_desc.output_arg_names():
if op_desc.type() == "fill_zeros_like":
input_tensor = dist_op.get_serial_input(op_desc.input_arg_names(
)[0])
if input_tensor.is_parameter:
continue
serial_tensor = dist_op.get_serial_output(arg_name) serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter: if serial_tensor.is_parameter:
continue continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names: if arg_name not in output_xshape_arg_names:
if len(dims_mapping) >= 1: if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0]) batch_dim_mappings.append(dims_mapping[0])
else: else:
...@@ -194,16 +235,27 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -194,16 +235,27 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
if serial_tensor.is_parameter: if serial_tensor.is_parameter:
continue continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping if arg_name not in input_xshape_arg_names:
) >= 1 and compatible_dim_mapping != dims_mapping[0]: if len(dims_mapping) >= 1 and \
dims_mapping[0] = compatible_dim_mapping compatible_dim_mapping != dims_mapping[0]:
changed = True dims_mapping[0] = compatible_dim_mapping
changed = True
else:
if len(dims_mapping) >= 2 and \
compatible_dim_mapping != dims_mapping[1]:
dims_mapping[1] = compatible_dim_mapping
changed = True
for arg_name in op_desc.output_arg_names(): for arg_name in op_desc.output_arg_names():
if op_desc.type() == "fill_zeros_like":
input_tensor = dist_op.get_serial_input(op_desc.input_arg_names(
)[0])
if input_tensor.is_parameter:
continue
serial_tensor = dist_op.get_serial_output(arg_name) serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter: if serial_tensor.is_parameter:
continue continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names: if arg_name not in output_xshape_arg_names:
if len(dims_mapping if len(dims_mapping
) >= 1 and compatible_dim_mapping != dims_mapping[0]: ) >= 1 and compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping dims_mapping[0] = compatible_dim_mapping
...@@ -371,30 +423,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -371,30 +423,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
        if need_gradient_allreduce:
            allreduce_vars = []
-            for input_name in backward_op.desc.input_names():
-                for varname in backward_op.desc.input(input_name):
-                    if "@GRAD" not in varname and is_parameter_related(
-                            varname, main_block):
-                        # NOTE: When amp and recompute pass are effective at the same time,
-                        # if a parameter is casted and recomputed, the 'parameter@GARD' can not
-                        # be found in the grad_op's output.
-                        if "subprog_" in varname:
-                            varname = varname[:varname.index(".subprog_")]
-
-                        assert len(
-                            backward_op.desc.input(input_name)
-                        ) == 1, "parameter input to grad op should be length 1, but got [{}]".format(
-                            backward_op.desc.input(input_name))
-
-                        assert varname + "@GRAD" in backward_op.desc.output_arg_names(
-                        ), "parameter's grad [{}] not found in the grad op's output".format(
-                            varname + "@GRAD")
-
-                        assert len(
-                            backward_op.desc.output(input_name + "@GRAD")
-                        ) == 1, "parameter grad of grad op should be length 1, but got [{}]".format(
-                            backward_op.desc.output(input_name + "@GRAD"))
-                        allreduce_vars.append(
-                            backward_op.desc.output(input_name + "@GRAD")[0])
+            for output_name in backward_op.desc.output_names():
+                for varname in backward_op.desc.output(output_name):
+                    if varname in kwargs["grad_var_to_var"]:
+                        fwd_name = kwargs["grad_var_to_var"][varname]
+                        if fwd_name not in main_block.vars:
+                            continue
+                        if is_parameter_related(fwd_name, main_block):
+                            allreduce_vars.append(varname)

            if len(allreduce_vars) > 0:
...
@@ -25,7 +25,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext, Di
 from .dist_attribute import OperatorDistributedAttribute
 from .process_group import new_process_group
 from .utils import set_dist_op_desc_original_id
-from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op
+from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_loss_op
 from .operators.common import BACKWARD_ONLY_DIST_OPS

 __varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]
...
@@ -198,15 +198,29 @@ class Partitioner(object):
        dist_op_context = self._dist_context.dist_op_context
        serial_ops = ref_block.ops

+        last_fwd_op_idx = -1
+        for idx, op in enumerate(ref_block.ops):
+            if is_loss_op(op):
+                last_fwd_op_idx = idx
+                break
+
+        if last_fwd_op_idx == -1:
+            last_fwd_op_idx = len(ref_block.ops)
+
        # init mapping
        forward_op_id2forward_op = {}
        for idx in range(len(serial_ops)):
-            if is_forward_op(serial_ops[idx]):
+            if idx <= last_fwd_op_idx:
                forward_op_id2forward_op[serial_ops[idx].desc.id(
                )] = serial_ops[idx]

+        appended_grad_times = 0
        # partiiton
-        for op in serial_ops:
+        for idx, op in enumerate(serial_ops):
+            if is_backward_op(op) and (is_forward_op(serial_ops[idx - 1]) or
+                                       is_loss_op(serial_ops[idx - 1])):
+                appended_grad_times += 1

            # partititon input variables
            for serial_input_varname in op.desc.input_arg_names():
...
@@ -244,8 +258,11 @@ class Partitioner(object):
                kinputs, koutputs = dist_op_context.prepare_context(op)
                dist_op_backward_impl = _get_dist_op_backward_implement(
                    op, self._dist_context, forward_op_id2forward_op)
-                dist_op_backward_impl.backward(self._dist_context, **kinputs,
-                                               **koutputs)
+                grad_var_to_var = self._dist_context.dist_op_context.grad_var_to_var[
+                    appended_grad_times]
+                dist_op_backward_impl.backward(
+                    self._dist_context, **kinputs, **koutputs,
+                    **{"grad_var_to_var": grad_var_to_var})
            else:
                raise NotImplementedError(
                    "partitioner only support forward op and backward op, but got {}".
...
...@@ -996,69 +996,87 @@ def set_grad_var_shape(program, dist_context): ...@@ -996,69 +996,87 @@ def set_grad_var_shape(program, dist_context):
block = program.global_block() block = program.global_block()
vars = block.vars vars = block.vars
for op in block.ops: appended_grad_times = 0
grad_var_to_var = dist_context.dist_op_context.grad_var_to_var
for idx, op in enumerate(block.ops):
if int(op.attr('op_role')) != int(OpRole.Backward):
continue
if int(block.ops[idx-1].attr('op_role')) == int(OpRole.Forward) or \
int(block.ops[idx-1].attr('op_role')) == 257:
appended_grad_times += 1
if op.type in ["check_finite_and_unscale", "update_loss_scaling"]: if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
break break
if op.type in ["sum", "concat"]: if op.type in ["sum", "concat", "shape"]:
continue continue
if int(op.attr('op_role')) == int(OpRole.Backward):
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr is not None
for var_name in op.output_arg_names: op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
if "@GRAD" not in var_name: assert op_dist_attr is not None
continue
for var_name in op.output_arg_names:
if "@GRAD" not in var_name:
continue
if var_name in grad_var_to_var[appended_grad_times]:
forward_var_name = grad_var_to_var[appended_grad_times][
var_name]
else:
forward_var_name = var_name[:var_name.find("@GRAD")] forward_var_name = var_name[:var_name.find("@GRAD")]
if op.type in [
"c_allreduce_sum", "c_identity", "scale", "cast" if op.type in [
]: "c_allreduce_sum", "c_identity", "scale", "cast",
forward_var_name = op.input_arg_names[0] "fill_zeros_like"
elif op.type == "matmul_v2_grad": ]:
forward_var_name = None forward_var_name = op.input_arg_names[0]
for output_name in op.output_names: elif op.type == "matmul_v2_grad":
if var_name in op.output(output_name): forward_var_name = None
assert "@GRAD" in output_name for output_name in op.output_names:
input_name = output_name[:output_name.find("@GRAD")] if var_name in op.output(output_name):
assert len(op.input(input_name)) == 1 assert "@GRAD" in output_name
forward_var_name = op.input(input_name)[0] input_name = output_name[:output_name.find("@GRAD")]
assert forward_var_name is not None assert len(op.input(input_name)) == 1
forward_var_name = op.input(input_name)[0]
need_set_shape_list = [ assert forward_var_name is not None
"reshape2_grad", "softmax_with_cross_entropy_grad",
"transpose2_grad", "softmax_grad", "cross_entropy_grad2", need_set_shape_list = [
"dropout_grad" "reshape2_grad", "softmax_with_cross_entropy_grad",
] "transpose2_grad", "softmax_grad", "cross_entropy_grad2",
forward_list = [ "dropout_grad", "tanh_grad", "slice", "assign",
"reshape2", "softmax_with_cross_entropy", "transpose2", "matmul_v2_triple_grad", "elementwise_add_triple_grad",
"softmax", "cross_entropy2", "dropout" "fill_constant", "sqrt_grad"
] ]
if op.type in need_set_shape_list: forward_list = [
for forward_op in block.ops: "reshape2", "softmax_with_cross_entropy", "transpose2",
assert int(forward_op.attr('op_role')) != int( "softmax", "cross_entropy2", "dropout", "tanh",
OpRole.Backward) ["slice_grad", "c_allgather"], "assign", "matmul_v2_grad_grad",
idx = need_set_shape_list.index(op.type) "elementwise_add_grad_grad", "shape", "sqrt"
forward_op_name = forward_list[idx] ]
if forward_op.type == forward_op_name and forward_var_name in forward_op.input_arg_names: if op.type in need_set_shape_list:
op_dist_attr = dist_context.get_op_dist_attr_for_program( for forward_op in block.ops:
forward_op) idx = need_set_shape_list.index(op.type)
break forward_op_name = forward_list[idx]
if forward_op.type in forward_op_name and forward_var_name in forward_op.input_arg_names:
forward_input_dist_attr = op_dist_attr.get_input_dist_attr( op_dist_attr = dist_context.get_op_dist_attr_for_program(
forward_var_name) forward_op)
assert forward_input_dist_attr is not None, f"{forward_var_name, str(op)}" break
forward_var = vars[forward_var_name]
forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program( forward_input_dist_attr = op_dist_attr.get_input_dist_attr(
forward_var) forward_var_name)
assert forward_var_dist_attr is not None assert forward_input_dist_attr is not None, f"{forward_var_name, str(op)}"
grad_var = vars[var_name] forward_var = vars[forward_var_name]
ref_shape = infer_shape(block, forward_var, forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
forward_var_dist_attr, forward_var)
forward_input_dist_attr) assert forward_var_dist_attr is not None
grad_var = vars[var_name]
if list(grad_var.shape) != ref_shape: ref_shape = infer_shape(block, forward_var, forward_var_dist_attr,
grad_var.desc.set_shape(ref_shape) forward_input_dist_attr)
if list(grad_var.shape) != ref_shape:
grad_var.desc.set_shape(ref_shape)
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
......
...@@ -478,12 +478,16 @@ def _accumulate_gradients_by_add_ops_(var_name, ...@@ -478,12 +478,16 @@ def _accumulate_gradients_by_add_ops_(var_name,
renamed_vars[var_name] = [var_name] renamed_vars[var_name] = [var_name]
def _addup_repetitive_outputs_(op_descs, block_idx): def _addup_repetitive_outputs_(op_descs, block_idx, grad_var_to_var=None):
""" """
In backward part, an variable may be the output of more than one ops. In backward part, an variable may be the output of more than one ops.
And one op may yield its multiple outputs to the same variable. And one op may yield its multiple outputs to the same variable.
In these cases, the variable should be the accumulation of all the outputs. In these cases, the variable should be the accumulation of all the outputs.
`sum_op`s are added to implement the accumulate. `sum_op`s are added to implement the accumulate.
Args:
grad_var_to_var(dict): used to build the mapping between grad var name and forward var name.
Only for auto parallel.
""" """
_MAX_ADD_NUM_ = framework._global_flags()['FLAGS_max_inplace_grad_add'] _MAX_ADD_NUM_ = framework._global_flags()['FLAGS_max_inplace_grad_add']
#pending_sum_ops = [] #pending_sum_ops = []
...@@ -531,6 +535,13 @@ def _addup_repetitive_outputs_(op_descs, block_idx): ...@@ -531,6 +535,13 @@ def _addup_repetitive_outputs_(op_descs, block_idx):
new_name = var_name + "@RENAME@block" + str(block_idx) + "@" + \ new_name = var_name + "@RENAME@block" + str(block_idx) + "@" + \
str(var_rename_count[var_name]) str(var_rename_count[var_name])
var_rename_count[var_name] += 1 var_rename_count[var_name] += 1
# Build the mapping between the new_name and var_name (Only for auto parallel)
if grad_var_to_var is not None:
if var_name in grad_var_to_var:
grad_var_to_var[new_name] = grad_var_to_var[
var_name]
else:
grad_var_to_var[new_name] = var_name
# rename original var_name # rename original var_name
renamed_vars[var_name][0] = new_name renamed_vars[var_name][0] = new_name
# before change: _rename_arg_(op_descs, var_name, # before change: _rename_arg_(op_descs, var_name,
...@@ -557,6 +568,13 @@ def _addup_repetitive_outputs_(op_descs, block_idx): ...@@ -557,6 +568,13 @@ def _addup_repetitive_outputs_(op_descs, block_idx):
new_name = var_name + "@RENAME@block" + str(block_idx) + "@" + \ new_name = var_name + "@RENAME@block" + str(block_idx) + "@" + \
str(var_rename_count[var_name]) str(var_rename_count[var_name])
var_rename_count[var_name] += 1 var_rename_count[var_name] += 1
# Build the mapping between the new_name and var_name (Only for auto parallel)
if grad_var_to_var is not None:
if var_name in grad_var_to_var:
grad_var_to_var[new_name] = grad_var_to_var[
var_name]
else:
grad_var_to_var[new_name] = var_name
arg_names[arg_idx] = new_name arg_names[arg_idx] = new_name
op_desc.set_output(param_name, arg_names) op_desc.set_output(param_name, arg_names)
renamed_vars[var_name].append(new_name) renamed_vars[var_name].append(new_name)
...@@ -1081,6 +1099,16 @@ def _append_backward_ops_(block, ...@@ -1081,6 +1099,16 @@ def _append_backward_ops_(block,
rename_var_map(dict): used to associate target_grad var name with first grad_op input name. rename_var_map(dict): used to associate target_grad var name with first grad_op input name.
Only used in for high order gradient. Only used in for high order gradient.
""" """
# Build the mapping between the forward op and backward op (Only for auto parallel)
def update_distop_context(distop_context, op_grad_to_var,
appending_grad_times):
distop_context.grad_var_to_var[appending_grad_times].update(
op_grad_to_var)
for op_desc in grad_op_desc:
assert op_desc.id() not in distop_context.grad_op_id_to_op_id
distop_context.grad_op_id_to_op_id[op_desc.id()] = op.desc.id()
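    # The helper above records, per gradient-appending round, both the
    # grad-op-id -> forward-op-id mapping and the grad-var -> forward-var
    # mapping that the auto-parallel completer and partitioner read back later.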
if callbacks is not None:
assert (isinstance(callbacks, (list, tuple)))
for cb in callbacks:
...@@ -1118,11 +1146,18 @@ def _append_backward_ops_(block,
# Getting op's corresponding grad_op
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
# Build the mapping between the forward op and backward op (Only for auto parallel)
if distop_context is not None:
update_distop_context(distop_context, op_grad_to_var,
program._appending_grad_times)
else:
default_ctx = getattr(paddle.distributed.auto_parallel.dist_context,
'_g_default_distributed_context', None)
if default_ctx is not None:
distop_context = default_ctx.dist_op_context
update_distop_context(distop_context, op_grad_to_var,
program._appending_grad_times)
# Set device for grad_op according to forward Op
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
...@@ -1155,6 +1190,11 @@ def _append_backward_ops_(block,
rename_var_map[name] = new_name
if name in op_grad_to_var:
# Build the mapping between the grad var name and var name (Only for auto parallel)
if distop_context is not None:
distop_context.grad_var_to_var[
program._appending_grad_times][
new_name] = op_grad_to_var[name]
op_grad_to_var[new_name] = op_grad_to_var[name]
op_grad_to_var.pop(name)
...@@ -1187,8 +1227,14 @@ def _append_backward_ops_(block,
grad_op_descs.extend(grad_op_desc)
grad_to_var.update(op_grad_to_var)
# record mapping between grad var name and var name (Only for auto parallel)
grad_var_to_var = None
if distop_context is not None:
grad_var_to_var = distop_context.grad_var_to_var[
program._appending_grad_times]
# sum parameter's gradients' var given multiple var gradient
grad_op_descs = _addup_repetitive_outputs_(grad_op_descs, block.idx,
grad_var_to_var)
# if all outputs of the grad op are in no_grad_set, then just remove and fill zero
# if all inputs of the grad op are in no_grad_set, just remove this op
......
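The bookkeeping added above can be summarized with a small, self-contained sketch (illustrative only, not part of the patch; the plain dicts and helper names below stand in for the DistOpContext state): one grad-var-to-forward-var map is kept per backward application, and whenever a repeated gradient output is renamed before accumulation, the renamed copy is mapped back to the same forward variable.
from collections import defaultdict

# Illustrative sketch only: plain dicts stand in for the state that
# update_distop_context and _addup_repetitive_outputs_ maintain above.
grad_var_to_var = defaultdict(dict)  # one map per time backward is appended

def record_grad_to_var(appending_grad_times, op_grad_to_var):
    # mirrors update_distop_context: remember the forward var of each grad var
    grad_var_to_var[appending_grad_times].update(op_grad_to_var)

def rename_for_accumulation(appending_grad_times, var_name, rename_count):
    # when several grad ops write the same gradient, each output is renamed and
    # summed later; the renamed copy must keep pointing at the forward variable
    new_name = var_name + "@RENAME@block0@" + str(rename_count)
    mapping = grad_var_to_var[appending_grad_times]
    mapping[new_name] = mapping.get(var_name, var_name)
    return new_name

record_grad_to_var(1, {"x@GRAD": "x"})            # first backward pass
renamed = rename_for_accumulation(1, "x@GRAD", 0)
assert grad_var_to_var[1][renamed] == "x"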
...@@ -12,6 +12,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS})
py_test_modules(test_converter MODULES test_converter ENVS ${dist_ENVS})
set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_high_order_grad MODULES test_high_order_grad ENVS ${dist_ENVS})
set_tests_properties(test_high_order_grad PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_tunable_variable MODULES test_tunable_variable ENVS ${dist_ENVS})
py_test_modules(test_tunable_space MODULES test_tunable_space ENVS ${dist_ENVS})
......
...@@ -127,9 +127,16 @@ def train():
engine.prepare(optimizer, loss)
engine.fit(dataset,
batch_size=batch_size,
steps_per_epoch=batch_num * batch_size,
sample_generator=True)
engine.save('./mlp')
engine.load('./mlp')
eval_dataset = MyDataset(batch_size)
engine.prepare(optimizer, loss, mode='eval')
engine.evaluate(eval_dataset, batch_size)
test_dataset = MyDataset(batch_size)
engine.prepare(mode='predict')
engine.predict(test_dataset, batch_size)
engine.save('./mlp_inf', training=False, mode='predict')
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import time
import paddle.fluid as fluid
import copy
import os
import numpy as np
import subprocess
import paddle
import paddle.nn as nn
import paddle.fluid as fluid
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.static import InputSpec
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
paddle.enable_static()
global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
batch_size = 1
batch_num = 10
hidden_size = 1024
image_size = hidden_size
paddle.seed(44)
class MyDataset(Dataset):
def __init__(self, num_samples):
super(MyDataset, self).__init__()
self.num_samples = num_samples
def __getitem__(self, index):
input = np.random.uniform(size=image_size).astype("float32")
return input
def __len__(self):
return self.num_samples
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
out = self.norm(input)
out = self.linear0(input)
auto.shard_tensor(
self.linear0.weight,
dist_attr={
"process_mesh": global_process_mesh,
"dims_mapping": [-1, 0]
})
out = F.gelu(out, approximate=True)
out = self.linear1(out)
auto.shard_tensor(
self.linear1.weight,
dist_attr={
"process_mesh": global_process_mesh,
"dims_mapping": [0, -1]
})
out = self.dropout(out)
out = self.linear2(out)
return out
def train():
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
dataset = MyDataset(batch_num * batch_size)
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
dist_strategy = fleet.DistributedStrategy()
# init parallel optimizer
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
engine = Engine(mlp, inputs_spec=inputs_spec, strategy=dist_strategy)
engine.prepare(mode='predict')
engine.predict(dataset, batch_size=batch_size)
if __name__ == "__main__":
train()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import paddle
import unittest
import numpy as np
import paddle.distributed.auto_parallel as auto
from paddle.static import InputSpec
from paddle.distributed import fleet
from paddle.incubate.autograd import Hessian
from paddle.distributed.auto_parallel.engine import Engine
class FCNet:
def __init__(self, num_ins, num_outs, num_layers, hidden_size):
self.num_ins = num_ins
self.num_outs = num_outs
self.num_layers = num_layers
self.hidden_size = hidden_size
self.activation = paddle.tanh
self.weights = []
self.biases = []
for i in range(self.num_layers):
if i == 0:
lsize = self.num_ins
rsize = self.hidden_size
elif i == (self.num_layers - 1):
lsize = self.hidden_size
rsize = self.num_outs
else:
lsize = self.hidden_size
rsize = self.hidden_size
w = paddle.static.create_parameter(
shape=[lsize, rsize], dtype="float32", is_bias=False)
b = paddle.static.create_parameter(
shape=[rsize], dtype="float32", is_bias=True)
self.weights.append(w)
self.biases.append(b)
def nn_func(self, ins):
u = ins
for i in range(self.num_layers - 1):
u = paddle.nn.functional.linear(u, self.weights[i], self.biases[i])
u = self.activation(u)
u = paddle.nn.functional.linear(u, self.weights[-1], self.biases[-1])
return u
class LaplaceModel(paddle.nn.Layer):
def __init__(self, num_ins=2, num_outs=1, num_layers=5, hidden_size=20):
super(LaplaceModel, self).__init__()
self.net = FCNet(
num_ins=num_ins,
num_outs=num_outs,
num_layers=num_layers,
hidden_size=hidden_size)
def forward(self, inputs, bc_index):
inputs.stop_gradient = False
outputs = self.net.nn_func(inputs)
# eq_loss
hes = Hessian(self.net.nn_func, inputs, is_batched=True)
eq_loss = paddle.norm(hes[:, 0, 0] + hes[:, 1, 1], p=2)
# bc_loss
bc_u = paddle.index_select(outputs, bc_index)
return eq_loss, bc_u
class LaplaceDataset:
def __init__(self, num_sample):
self.num_sample = num_sample
def __getitem__(self, index):
x = np.linspace(0, 0.9, 10)
y = np.linspace(0, 0.9, 10)
bc_value = np.random.rand(36).reshape(36, 1).astype('float32')
domain_space = []
bc_index = []
for j in range(len(y)):
for i in range(len(x)):
domain_space.append([x[i], y[j]])
if i == 0 or i == 9 or j == 0 or j == 9:
bc_index.append(i + 10 * j)
domain_space = np.array(domain_space, dtype='float32')
bc_index = np.array(bc_index, dtype='int64')
return domain_space, bc_index, bc_value
def __len__(self):
return self.num_sample
def loss_func(eq_loss, bc_u, bc_value):
bc_diff = bc_u - bc_value
bc_loss = paddle.norm(bc_diff, p=2)
loss = eq_loss + bc_loss
return loss
def main():
# dataset
train_dataset = LaplaceDataset(10)
# optimizer
optimizer = paddle.optimizer.Adam(learning_rate=0.001)
# model
laplace = LaplaceModel()
# spec
inputs_spec = [
InputSpec([100, 2], 'float32', 'x'), InputSpec([36], 'int64', 'bc_idx')
]
labels_spec = InputSpec([36, 1], 'float32', 'bc_v')
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
engine = Engine(
laplace,
inputs_spec=inputs_spec,
labels_spec=labels_spec,
strategy=dist_strategy)
paddle.seed(1234 + engine._cur_rank)
engine.prepare(optimizer=optimizer, loss=loss_func)
res = engine.fit(train_dataset, sample_generator=False)
assert np.allclose(res[-1], 2.840593)
dist_context = engine.dist_context
block = engine.main_program.global_block()
ops = block.ops
for op in ops:
if op.type == 'p_norm':
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_type == 'p_norm'
if 'x' in op.input_arg_names:
out_name = op.output_arg_names[0]
assert block.vars[out_name].shape[0] == 50
if __name__ == "__main__":
main()
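For context on the test above (an aside, not part of the patch): eq_loss is the L2 norm of the Laplacian residual u_xx + u_yy read off the Hessian diagonal, the 36 boundary values match the 4*10 - 4 = 36 edge points of the 10x10 grid, and the shape check of 50 is consistent with the [100, 2] input being split evenly across the two data-parallel ranks (100 / 2 = 50). A minimal, Paddle-free numpy sketch of the same residual, using finite differences instead of autodiff:
import numpy as np

# Finite-difference analogue of hes[:, 0, 0] + hes[:, 1, 1] on a 10x10 grid
# (illustrative only; the test computes it with paddle.incubate.autograd.Hessian).
n, h = 10, 0.1
x = np.linspace(0, 0.9, n)
y = np.linspace(0, 0.9, n)
xx, yy = np.meshgrid(x, y, indexing="ij")

u = xx**2 - yy**2  # a harmonic function, so u_xx + u_yy should vanish
u_xx = (u[2:, 1:-1] - 2 * u[1:-1, 1:-1] + u[:-2, 1:-1]) / h**2
u_yy = (u[1:-1, 2:] - 2 * u[1:-1, 1:-1] + u[1:-1, :-2]) / h**2
eq_residual = u_xx + u_yy

print(np.linalg.norm(eq_residual.ravel(), ord=2))  # ~0 up to rounding error

# data-parallel arithmetic behind the p_norm shape assertion in main()
world_size, global_rows = 2, 100
assert global_rows // world_size == 50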
...@@ -49,28 +49,6 @@ class TestEngineAPI(unittest.TestCase):
if os.path.exists('rank_mapping.csv'):
os.remove('rank_mapping.csv')
def test_engine_predict(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "engine_predict_api.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--gpus", "0,1", launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
# Remove unnecessary files
log_path = os.path.join(file_dir, "log")
if os.path.exists(log_path):
shutil.rmtree(log_path)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import sys
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class TestHighOrderGrad(unittest.TestCase):
def test_dp2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "high_order_grad.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--gpus", "0,1", launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
# Remove unnecessary files
log_path = os.path.join(file_dir, "log")
if os.path.exists(log_path):
shutil.rmtree(log_path)
if __name__ == "__main__":
unittest.main()