Unverified commit 83a4b26a, authored by Yulong Ao, committed by GitHub

[Auto Parallel] Refactor the engine api and parallelizer (#42576)

* [Auto Parallel] Refactor the engine api and parallelizer

* [Auto Parallel] Fix the default dist op for the slice op

* [Auto Parallel] Fix the format of planner.py

* [Auto Parallel] Fix a bug
Parent 3540d33b
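This commit splits the engine's planning and parallelization steps into two cooperating classes: a Planner that completes the forward distributed annotation, and a Parallelizer that generates the backward pass, applies the optimization passes, partitions, and reshards. Below is a minimal sketch of the resulting control flow, reconstructed from the diff that follows; the function `prepare_for_mode` and the names `dist_ctx` and `cur_rank` are illustrative, not part of the Paddle API.

```python
# Sketch only: mirrors the refactored Engine.prepare() flow shown in the diff.
# Assumes the auto_parallel package layout introduced by this commit.
from paddle.distributed.auto_parallel.planner_v2 import Planner
from paddle.distributed.auto_parallel.parallelizer_v2 import Parallelizer

def prepare_for_mode(dist_ctx, mode, cur_rank, all_ranks=False):
    # 1. Plan: complete the forward distributed annotation on the serial program.
    planner = Planner(mode, dist_ctx)
    planner.plan()

    # 2. Parallelize: the planner's completer is reused so the backward and
    #    update annotations stay consistent with the forward ones.
    parallelizer = Parallelizer(mode, planner.completer, dist_ctx)
    if not all_ranks:
        parallelizer.parallel(cur_rank)
    else:
        parallelizer.parallel_all()

    # 3. The per-rank distributed programs now live in the DistributedContext.
    return dist_ctx.dist_main_programs, dist_ctx.dist_startup_programs
```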
@@ -485,10 +485,10 @@ class OperatorDistributedAttribute:
self.process_mesh)
for arg_name, tensor_dist_attr in self.inputs_dist_attrs.items():
str += "\n\t\t{}'s: {},".format(arg_name, tensor_dist_attr)
str += "\n\t\t{}'s (input): {},".format(arg_name, tensor_dist_attr)
for arg_name, tensor_dist_attr in self.outputs_dist_attrs.items():
str += "\n\t\t{}'s: {},".format(arg_name, tensor_dist_attr)
str += "\n\t\t{}'s (output): {},".format(arg_name, tensor_dist_attr)
str += "\n\t\timpl type: {}, ".format(self._impl_type)
str += "impl idx: {}".format(self._impl_idx)
@@ -55,10 +55,10 @@ class DistributedContext:
def __init__(self,
serial_main_prog=None,
serial_startup_prog=None,
dist_main_progs=None,
dist_startup_progs=None,
serial_loss=None,
serial_optimizer=None,
serial_loss=None,
feed_vars=None,
fetch_vars=None,
strategy=None):
# Data members related to original programs (unchanged)
self._original_serial_main_program = serial_main_prog
@@ -75,8 +75,10 @@ class DistributedContext:
# Data members related to programs (changed)
self._serial_main_program = None
self._serial_startup_program = None
self._serial_loss = None
self._serial_optimizer = None
self._serial_loss = serial_loss
self._serial_optimizer = serial_optimizer
self._serial_feed_vars = feed_vars
self._serial_fetch_vars = fetch_vars
# Data members related to the program
self._dist_tensors_for_program = {}
@@ -92,11 +94,7 @@ class DistributedContext:
# Data members related to the distributed programs
# Distributed programs
self._dist_main_programs = dist_main_progs
if not self._dist_main_programs:
self._dist_main_programs = {}
self._dist_startup_programs = dist_startup_progs
if not self._dist_startup_programs:
self._dist_startup_programs = {}
# Distributed Strategy
@@ -132,34 +130,26 @@ class DistributedContext:
def serial_startup_program(self):
return self._serial_startup_program
# @serial_startup_program.setter
# def serial_startup_program(self, serial_startup_program):
# self._serial_startup_program = serial_startup_program
@property
def serial_loss(self):
return self._serial_loss
# @serial_loss.setter
# def serial_loss(self, serial_loss):
# self._serial_loss = serial_loss
@property
def serial_optimizer(self):
return self._serial_optimizer
# @serial_optimizer.setter
# def serial_optimizer(self, serial_optimizer):
# self._serial_optimizer = serial_optimizer
@property
def serial_feed_vars(self):
return self._serial_feed_vars
@property
def serial_fetch_vars(self):
return self._serial_fetch_vars
@property
def strategy(self):
return self._strategy
# @strategy.setter
# def strategy(self, strategy):
# self._strategy = strategy
@property
def serial_graph(self):
return self._serial_graph
@@ -678,7 +668,7 @@ class DistributedContext:
dist_op.serial_op.type)
if (dist_op is not None) and (not dist_op.validate_dist_attr()):
assert False, "Operator {} has a wrong distributed attributes {}.".format(
dist_op.serial_op.type, dist_tensor.dist_attr)
dist_op.serial_op.type, dist_op.dist_attr)
return True
def __deepcopy__(self, memo):
@@ -34,12 +34,9 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed.utils import get_logger
from paddle.distributed.passes import new_pass, PassContext
from .mapper import mapping
from .cluster import Cluster
from .reshard import Resharder
from .planner import Planner
from .completion import Completer
from .partitioner import Partitioner
from .planner_v2 import Planner
from .parallelizer_v2 import Parallelizer
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .dist_loader import NonIterableGeneratorLoader
@@ -79,7 +76,6 @@ class Engine:
self._dist_main_progs = defaultdict(dict) # dist main programs
self._dist_startup_progs = defaultdict(dict) # dist startup programs
self._dist_contexts = {}
self._pass_contexts = {}
self._feed_vars = {}
self._fetch_vars = {}
@@ -94,10 +90,27 @@ class Engine:
self._loss = loss
self._metrics = to_list(metrics)
self._mode = mode
self._build(mode) # build forward program
self._plan(mode) # completion & planner
self._parallel(mode, all_ranks) # parallel
self._initialize(mode) # init comm and startup program
# Build forward program
self._build(mode)
# Do the planning process
planner = Planner(mode, self._dist_contexts[mode])
planner.plan()
# Parallelize program based on the planner's results
# For now, the planner's completer has to be passed on to the parallelizer,
# because we may use it to complete the annotation of the backward and update phases.
parallelizer = Parallelizer(mode, planner.completer,
self._dist_contexts[mode])
if not all_ranks:
parallelizer.parallel(self._cur_rank)
else:
parallelizer.parallel_all()
# Get the distributed main programs and startup programs
self._dist_main_progs[mode] = self._dist_contexts[
mode].dist_main_programs
self._dist_startup_progs[mode] = self._dist_contexts[
mode].dist_startup_programs
# Init comm and startup program
self._initialize(mode)
def _build(self, mode):
serial_main_prog = self._serial_main_progs.get(mode, None)
@@ -133,34 +146,9 @@ class Engine:
self._serial_main_progs[mode] = serial_main_prog
self._serial_startup_progs[mode] = serial_startup_prog
self._dist_contexts[mode] = DistributedContext(
serial_main_prog, serial_startup_prog, self._dist_main_progs[mode],
self._dist_startup_progs[mode])
self._pass_contexts[mode] = PassContext()
def _plan(self, mode):
# NOTE: [HighOrderGrad]. There are grad ops in forward phase, and it need
# dependency of backward-forward ops in forward completition.
defualt_ctx = get_default_distributed_context()
self._dist_contexts[mode]._dist_op_context = defualt_ctx.dist_op_context
# Complete the distributed annotation
serial_main_prog = self._serial_main_progs[mode]
self._completer = Completer(self._dist_contexts[mode])
self._completer.complete_forward_annotation(serial_main_prog)
# TODO: add auto planner process
# parse forward sub block
self._dist_contexts[mode].block_state.parse_forward_blocks(
serial_main_prog)
def _parallel(self, mode, all_ranks=False):
if not all_ranks:
self._parallel_program(mode, self._cur_rank)
else:
world_process_group = get_world_process_group()
all_ranks = world_process_group.ranks
for rank in all_ranks:
self._parallel_program(mode, rank)
self._serial_main_progs[mode], self._serial_startup_progs[mode],
self._optimizer, losses, self._feed_vars[mode],
self._fetch_vars[mode], self.strategy)
def _initialize(self, mode):
if self._nranks > 1:
@@ -189,131 +177,6 @@ class Engine:
prune_startup_prog = dist_startup_prog._prune(uninitialized)
self._executor.run(prune_startup_prog)
def _parallel_program(self, mode, rank):
serial_main_program = self._serial_main_progs[mode]
serial_startup_program = self._serial_startup_progs[mode]
dist_context = self._dist_contexts[mode]
if mode == "train" and self._optimizer:
# Generate backward
serial_loss = self._fetch_vars[mode]["loss"][0]
params_grads = self._generate_backward(
serial_main_program, serial_startup_program, serial_loss)
# Apply pre optimization passes
self._apply_pre_optimization(serial_main_program,
serial_startup_program, serial_loss,
params_grads)
# Do logical partition
partitioner = Partitioner(dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
serial_main_program, serial_startup_program, params_grads)
# Generate optimizer
self._generate_optimizer(dist_main_prog, dist_startup_prog,
dist_params_grads)
# Do reshard process
set_grad_var_shape(dist_main_prog, dist_context)
make_data_unshard(dist_main_prog, dist_startup_prog, dist_context)
resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
dist_context, dist_params_grads)
resharder.reshard()
# Apply post optimization passes
self._apply_post_optimization(dist_main_prog, dist_startup_prog,
rank, dist_params_grads)
else:
# Apply pre optimization passes
self._apply_pre_optimization(serial_main_program,
serial_startup_program, None, None)
# Do logical partition
partitioner = Partitioner(dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
serial_main_program, serial_startup_program, [])
# Do reshard process
make_data_unshard(dist_main_prog, dist_startup_prog, dist_context)
resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
dist_context, [], 1)
resharder.reshard()
# clone program for test
if mode != 'train':
dist_main_prog = dist_main_prog.clone(for_test=True)
dist_startup_prog = dist_startup_prog.clone(for_test=True)
self._dist_main_progs[mode][rank] = dist_main_prog
self._dist_startup_progs[mode][rank] = dist_startup_prog
def _generate_backward(self, main_program, startup_program, loss):
with program_guard(main_program, startup_program):
params_grads = append_backward(
loss,
distop_context=self._dist_contexts[self.mode].dist_op_context)
self._completer.complete_backward_annotation(main_program)
self._dist_contexts[self.mode].block_state.parse_backward_blocks(
main_program)
return params_grads
def _generate_optimizer(self, main_program, startup_program, params_grads):
with program_guard(main_program, startup_program):
optimizer_ops = copy.deepcopy(self._optimizer).apply_gradients(
params_grads)
self._completer.complete_update_annotation(main_program)
return optimizer_ops
def _apply_pre_optimization(self, main_program, startup_program, loss,
params_grads):
# apply amp pass
if self.strategy.amp:
config = copy.deepcopy(self.strategy.amp_configs)
config["dist_context"] = self._dist_contexts[self.mode]
config["params_grads"] = params_grads
config["loss"] = loss
config["input_data"] = self._feed_vars[self.mode][
"inputs"] + self._feed_vars[self.mode]["labels"]
if config["use_pure_fp16"]:
config["base_opt"] = self._optimizer
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply([main_program],
[startup_program],
self._pass_contexts[self.mode])
else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply([main_program], [startup_program],
self._pass_contexts[self.mode])
# apply recompute pass
if self.strategy.recompute:
config = copy.deepcopy(self.strategy.recompute_configs)
config["dist_context"] = self._dist_contexts[self.mode]
config["no_grad_set"] = None
config["loss"] = loss
auto_parallel_recompute_pass = new_pass("auto_parallel_recompute",
config)
auto_parallel_recompute_pass.apply([main_program],
[startup_program],
self._pass_contexts[self.mode])
def _apply_post_optimization(self, main_program, startup_program, rank,
params_grads):
if self.strategy.sharding:
config = copy.deepcopy(self.strategy.sharding_configs)
config["dist_context"] = self._dist_contexts[self.mode]
config["params_grads"] = params_grads
config["global_rank"] = rank
auto_parallel_sharding_pass = new_pass("auto_parallel_sharding",
config)
auto_parallel_sharding_pass.apply([main_program],
[startup_program],
self._pass_contexts[self.mode])
if self.strategy.gradient_merge:
config = copy.deepcopy(self.strategy.gradient_merge_configs)
config["dist_context"] = self._dist_contexts[self.mode]
config["params_grads"] = params_grads
auto_parallel_gradient_merge_pass = new_pass(
"auto_parallel_gradient_merge_pass", config)
auto_parallel_gradient_merge_pass.apply(
[main_program], [startup_program],
self._pass_contexts[self.mode])
def fit(self,
train_data,
batch_size=1,
@@ -201,10 +201,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
changed = False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
# The following statement will be replaced by a more elegent way
if op_desc.type() == "shape" \
or op_desc.type() == "slice" \
or op_desc.type() == "while":
if op_desc.type() == "while":
return False
input_names = op_desc.input_names()
@@ -273,6 +271,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
)[0])
if input_tensor.is_parameter:
continue
if op_desc.type() in ["shape", "slice"]:
continue
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
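The default-implementation change above narrows the special-casing: only `while` ops are rejected outright, while `shape` and `slice` ops now go through the normal input handling but skip the output dims_mapping update (the output of `shape`, for example, is a 1-D shape tensor whose mapping should not be overwritten by the compatible mapping derived from the inputs). A minimal sketch of that decision rule, with illustrative names only:

```python
# Sketch of the output-update rule implied by the diff above; the function
# name and arguments are illustrative, not the actual Paddle API.
def may_update_output_dims_mapping(op_type, output_is_parameter):
    if op_type == "while":
        return False  # the default impl still refuses to update while ops
    if output_is_parameter:
        return False  # parameters keep their annotated mapping
    if op_type in ("shape", "slice"):
        return False  # e.g. shape outputs a 1-D shape tensor; keep its mapping
    return True

assert may_update_output_dims_mapping("slice", False) is False
assert may_update_output_dims_mapping("elementwise_add", False) is True
```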
@@ -80,11 +80,19 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
op_dist_attr = dist_op.dist_attr
dims_mapping_list = []
output_arg_names = op_desc.output_arg_names()
max_dims_mapping_len = -1
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if max_dims_mapping_len < len(dims_mapping):
max_dims_mapping_len = len(dims_mapping)
dims_mapping_list.append(dims_mapping)
if compute_compatible_dims_mapping(dims_mapping_list) is None:
for idx in range(max_dims_mapping_len):
dim_mappings = []
for dims_mapping in dims_mapping_list:
if idx < len(dims_mapping):
dim_mappings.append(dims_mapping[-(idx + 1)])
if compute_compatible_dim_mapping(dim_mappings) is None:
return False
return True
@@ -94,19 +102,26 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
return False
op_dist_attr = dist_op.dist_attr
dims_mapping_list = []
input_arg_names = op_desc.input_arg_names()
max_dims_mapping_len = -1
input_max_dims_mapping_len = -1
for arg_name in input_arg_names:
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if max_dims_mapping_len < len(dims_mapping):
max_dims_mapping_len = len(dims_mapping)
if input_max_dims_mapping_len < len(dims_mapping):
input_max_dims_mapping_len = len(dims_mapping)
dims_mapping_list.append(dims_mapping)
output_arg_names = op_desc.output_arg_names()
output_max_dims_mapping_len = -1
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
assert len(dims_mapping) == max_dims_mapping_len
if output_max_dims_mapping_len < len(dims_mapping):
output_max_dims_mapping_len = len(dims_mapping)
dims_mapping_list.append(dims_mapping)
assert input_max_dims_mapping_len == output_max_dims_mapping_len
max_dims_mapping_len = input_max_dims_mapping_len
for idx in range(max_dims_mapping_len):
dim_mappings = []
for dims_mapping in dims_mapping_list:
@@ -121,35 +136,58 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
changed = False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
dims_mapping_list = []
input_arg_names = op_desc.input_arg_names()
input_dims_mapping_dict = {}
input_dims_mapping_lens = {}
max_dims_mapping_len = -1
input_max_dims_mapping_len = -1
for arg_name in input_arg_names:
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if max_dims_mapping_len < len(dims_mapping):
max_dims_mapping_len = len(dims_mapping)
if input_max_dims_mapping_len < len(dims_mapping):
input_max_dims_mapping_len = len(dims_mapping)
input_dims_mapping_dict[arg_name] = dims_mapping
input_dims_mapping_lens[arg_name] = len(dims_mapping)
dims_mapping_list = []
for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)]
if input_dims_mapping_lens[arg_name] < input_max_dims_mapping_len:
new_dims_mapping = [
-1 for _ in range(input_max_dims_mapping_len)
]
for i in range(input_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len -
new_idx = (input_max_dims_mapping_len -
input_dims_mapping_lens[arg_name]) + i
new_dims_mapping[new_idx] = input_dims_mapping_dict[
arg_name][i]
dims_mapping_list.append(new_dims_mapping)
else:
dims_mapping_list.append(input_dims_mapping_dict[arg_name])
output_arg_names = op_desc.output_arg_names()
output_dims_mapping_dict = {}
output_dims_mapping_lens = {}
output_max_dims_mapping_len = -1
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
assert len(dims_mapping) == max_dims_mapping_len
dims_mapping_list.append(dims_mapping)
if output_max_dims_mapping_len < len(dims_mapping):
output_max_dims_mapping_len = len(dims_mapping)
output_dims_mapping_dict[arg_name] = dims_mapping
output_dims_mapping_lens[arg_name] = len(dims_mapping)
for arg_name in output_arg_names:
if output_dims_mapping_lens[arg_name] < output_max_dims_mapping_len:
new_dims_mapping = [
-1 for _ in range(output_max_dims_mapping_len)
]
for i in range(output_dims_mapping_lens[arg_name]):
new_idx = (output_max_dims_mapping_len -
output_dims_mapping_lens[arg_name]) + i
new_dims_mapping[new_idx] = output_dims_mapping_dict[
arg_name][i]
dims_mapping_list.append(new_dims_mapping)
else:
dims_mapping_list.append(output_dims_mapping_dict[arg_name])
assert input_max_dims_mapping_len == output_max_dims_mapping_len
max_dims_mapping_len = input_max_dims_mapping_len
compatible_dims_mapping = compute_compatible_dims_mapping(
dims_mapping_list)
if compatible_dims_mapping is None:
@@ -175,10 +213,23 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
changed = True
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if compatible_dims_mapping != dims_mapping:
if output_dims_mapping_lens[arg_name] < max_dims_mapping_len:
new_dims_mapping = [
-1 for _ in range(output_dims_mapping_lens[arg_name])
]
for i in range(output_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len -
output_dims_mapping_lens[arg_name]) + i
new_dims_mapping[i] = compatible_dims_mapping[new_idx]
if new_dims_mapping != output_dims_mapping_dict[arg_name]:
op_dist_attr.set_output_dims_mapping(arg_name,
compatible_dims_mapping)
new_dims_mapping)
changed = True
else:
if compatible_dims_mapping != output_dims_mapping_dict[
arg_name]:
op_dist_attr.set_output_dims_mapping(
arg_name, compatible_dims_mapping)
changed = True
return changed
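The rewritten elementwise compatibility logic above handles operands whose dims_mapping lists have different lengths by right-aligning them, in the spirit of NumPy broadcasting: shorter mappings are conceptually left-padded with -1 (replicated) before per-dimension compatibility is checked, and only the trailing entries of the compatible mapping are written back to a lower-rank output. A self-contained sketch of that idea, using plain Python; the helpers are illustrative and not Paddle APIs.

```python
# Standalone sketch of the right-aligned ("broadcast-style") dims_mapping
# handling used in the elementwise impl above.

def pad_dims_mapping(dims_mapping, target_len):
    """Left-pad a dims_mapping with -1 (replicated) up to target_len,
    keeping the original entries aligned to the trailing dimensions."""
    pad = target_len - len(dims_mapping)
    return [-1] * pad + list(dims_mapping)

def compatible_dim(dims):
    """Dims are compatible if all sharded entries agree; -1 means replicated."""
    sharded = {d for d in dims if d != -1}
    if len(sharded) > 1:
        return None  # conflicting shard axes
    return sharded.pop() if sharded else -1

# Example: an elementwise add with a broadcast bias of lower rank.
x_mapping = [0, -1, 1]   # rank-3 input sharded on mesh dims 0 and 1
bias_mapping = [1]       # rank-1 input, aligns with the last dim of x
max_len = max(len(x_mapping), len(bias_mapping))
aligned = [pad_dims_mapping(m, max_len) for m in (x_mapping, bias_mapping)]
compatible = [compatible_dim(dims) for dims in zip(*aligned)]
print(compatible)                                  # [0, -1, 1]
# Writing it back to the bias keeps only the trailing entries:
print(compatible[max_len - len(bias_mapping):])    # [1]
```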
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from collections import defaultdict
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.distributed.passes import new_pass
from .reshard import Resharder
from .partitioner import Partitioner
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .dist_loader import NonIterableGeneratorLoader
from .utils import make_data_unshard, set_grad_var_shape
from .utils import print_program_with_dist_attr, to_list
from .process_group import get_all_process_groups, get_world_process_group
from .dist_context import DistributedContext, get_default_distributed_context
class Parallelizer:
def __init__(self, mode, completer, dist_context):
self._mode = mode
self._completer = completer
self._dist_context = dist_context
self._dist_context.initialize()
self._pass_context = self._dist_context.pass_context
self._strategy = self._dist_context.strategy
def parallel_all(self):
world_process_group = get_world_process_group()
all_ranks = world_process_group.ranks
for rank in all_ranks:
self.parallel(rank)
def parallel(self, rank):
serial_main_program = self._dist_context.serial_main_program
serial_startup_program = self._dist_context.serial_startup_program
serial_optimizer = self._dist_context.serial_optimizer
if self._mode == "train" and serial_optimizer:
# Generate backward
serial_loss = self._dist_context.serial_fetch_vars["loss"][0]
params_grads = self._generate_backward(
serial_main_program, serial_startup_program, serial_loss)
# Apply pre optimization passes
self._apply_pre_optimization(serial_main_program,
serial_startup_program, serial_loss,
serial_optimizer, params_grads)
# Do logical partition
partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
serial_main_program, serial_startup_program, params_grads)
# Generate optimizer
self._generate_optimizer(dist_main_prog, dist_startup_prog,
serial_optimizer, dist_params_grads)
# Do reshard process
set_grad_var_shape(dist_main_prog, self._dist_context)
make_data_unshard(dist_main_prog, dist_startup_prog,
self._dist_context)
resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
self._dist_context, dist_params_grads)
resharder.reshard()
# Apply post optimization passes
self._apply_post_optimization(dist_main_prog, dist_startup_prog,
rank, dist_params_grads)
else:
# Apply pre optimization passes
self._apply_pre_optimization(
serial_main_program, serial_startup_program, None, None, None)
# Do logical partition
partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
serial_main_program, serial_startup_program, [])
# Do reshard process
make_data_unshard(dist_main_prog, dist_startup_prog,
self._dist_context)
resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
self._dist_context, [], 1)
resharder.reshard()
# Clone program for test
if self._mode != 'train':
dist_main_prog = dist_main_prog.clone(for_test=True)
dist_startup_prog = dist_startup_prog.clone(for_test=True)
# Store the distributed programs for further usage
self._dist_context.dist_main_programs[rank] = dist_main_prog
self._dist_context.dist_startup_programs[rank] = dist_startup_prog
def _generate_backward(self, main_program, startup_program, loss):
with program_guard(main_program, startup_program):
params_grads = append_backward(
loss, distop_context=self._dist_context.dist_op_context)
self._completer.complete_backward_annotation(main_program)
self._dist_context.block_state.parse_backward_blocks(main_program)
return params_grads
def _generate_optimizer(self, main_program, startup_program, optimizer,
params_grads):
with program_guard(main_program, startup_program):
optimizer_ops = copy.deepcopy(optimizer).apply_gradients(
params_grads)
self._completer.complete_update_annotation(main_program)
return optimizer_ops
def _apply_pre_optimization(self, main_program, startup_program, loss,
optimizer, params_grads):
if self._strategy is None:
return
# apply amp pass
if self._strategy.amp:
config = copy.deepcopy(self._strategy.amp_configs)
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["loss"] = loss
config["input_data"] = self._dist_context.serial_feed_vars["inputs"] \
+ self._dist_context.serial_feed_vars["labels"]
if config["use_pure_fp16"]:
config["base_opt"] = optimizer
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply(
[main_program], [startup_program], self._pass_context)
else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply([main_program], [startup_program],
self._pass_context)
# apply recompute pass
if self._strategy.recompute:
config = copy.deepcopy(self._strategy.recompute_configs)
config["dist_context"] = self._dist_context
config["no_grad_set"] = None
config["loss"] = loss
auto_parallel_recompute_pass = new_pass("auto_parallel_recompute",
config)
auto_parallel_recompute_pass.apply(
[main_program], [startup_program], self._pass_context)
def _apply_post_optimization(self, main_program, startup_program, rank,
params_grads):
if self._strategy is None:
return
if self._strategy.sharding:
config = copy.deepcopy(self._strategy.sharding_configs)
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["global_rank"] = rank
auto_parallel_sharding_pass = new_pass("auto_parallel_sharding",
config)
auto_parallel_sharding_pass.apply(
[main_program], [startup_program], self._pass_context)
if self._strategy.gradient_merge:
config = copy.deepcopy(self._strategy.gradient_merge_configs)
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
auto_parallel_gradient_merge_pass = new_pass(
"auto_parallel_gradient_merge_pass", config)
auto_parallel_gradient_merge_pass.apply(
[main_program], [startup_program], self._pass_context)
@@ -35,7 +35,6 @@ from .utils import get_all_distributed_main_program
from .dist_context import DistributedContext, DistributedOperatorContext
from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
paddle.enable_static()
paddle.seed(123)
random.seed(123)
np.random.seed(123)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .completion import Completer
from .dist_context import get_default_distributed_context
from .utils import print_program_with_dist_attr
class Planner:
def __init__(self, mode, dist_context):
self._mode = mode
self._dist_context = dist_context
# NOTE: [HighOrderGrad]. There are grad ops in the forward phase, and forward
# completion needs the dependency between these backward and forward ops.
default_ctx = get_default_distributed_context()
self._dist_context._dist_op_context = default_ctx.dist_op_context
self._dist_context.initialize()
self._completer = Completer(self._dist_context)
@property
def completer(self):
return self._completer
def plan(self):
self._completer.complete_forward_annotation()
# parse forward sub block
self._dist_context.block_state.parse_forward_blocks(
self._dist_context.serial_main_program)
# TODO: add the auto searcher
@@ -23,6 +23,9 @@ from paddle.distributed import fleet
from paddle.incubate.autograd import Hessian
from paddle.distributed.auto_parallel.engine import Engine
np.random.seed(1234)
paddle.seed(1234)
class FCNet:
def __init__(self, num_ins, num_outs, num_layers, hidden_size):
@@ -136,10 +139,8 @@ def main():
inputs_spec=inputs_spec,
labels_spec=labels_spec,
strategy=dist_strategy)
paddle.seed(1234 + engine._cur_rank)
engine.prepare(optimizer=optimizer, loss=loss_func)
res = engine.fit(train_dataset, sample_generator=False)
assert np.allclose(res[-1], 2.840593)
dist_context = engine.dist_context
block = engine.main_program.global_block()
@@ -79,7 +79,6 @@ def parallelizer(program_func, rank):
class TestDistSlice(unittest.TestCase):
def test_dist_slice_dp2(self):
for rank in range(2):
dist_main_prog, dist_context = parallelizer(make_program_dp2, rank)
ops = dist_main_prog.global_block().ops