Unverified commit 325fdf1d authored by caozhou, committed by GitHub

[Auto Parallel] Update rule based tuner (#51908)

* add patterns

* update rule based tuner

* add forward sub program completion

* add unittest

* add bwd sub program completion
Parent 13b8b5e0
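A minimal usage sketch of the flow this commit wires together, mirroring the new unittest added below; the serial program, loss, and optimizer are assumed to exist already (e.g. built by the test's get_gpt_model helper), and none of these calls are a stable public API:

from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import RuleBasedTuner

# train_program, start_program, loss and opt are assumed to be built already.
dist_context = DistributedContext(
    serial_main_prog=train_program,
    serial_startup_prog=start_program,
    serial_optimizer=opt,
    serial_loss=loss,
)
dist_context.initialize()

tuner = RuleBasedTuner(dist_context)
tuner.cluster_operators()                    # group ops into layers
tuner.gen_full_program()                     # clone the fwd program, append bwd/update ops
tuner.match_program(tuner._dist_context.serial_main_program)  # pattern matching -> shard specs
tuner.gen_fwd_sub_programs_by_clone()        # one fwd sub program per layer
tuner.complete_sub_fwd_programs(ProcessMesh([0, 1]))
tuner.complete_sub_bwd_programs()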
......@@ -64,6 +64,7 @@ class DistributedContext:
fetch_vars={},
cluster=None,
strategy=None,
json_config=None,
):
# Data members related to original programs (unchanged)
self._original_serial_main_program = serial_main_prog
......@@ -129,6 +130,8 @@ class DistributedContext:
# A flag indicating whether the parallelism used is data parallel
self._data_parallel = False
self._json_config = json_config
@property
def serial_main_program(self):
return self._serial_main_program
......@@ -181,6 +184,10 @@ class DistributedContext:
def process_meshes(self):
return self._process_meshes
@process_meshes.setter
def process_meshes(self, val):
self._process_meshes = val
@property
def pass_context(self):
return self._pass_context
......@@ -397,7 +404,7 @@ class DistributedContext:
if dist:
self._restore_dist_info(dist_mode)
def initialize(self, with_graph=True, with_cpp=False, no_default=False):
if not self._is_initialized:
if not self._serial_main_program:
if self._original_serial_main_program:
......@@ -418,7 +425,7 @@ class DistributedContext:
if not self._serial_fetch_vars:
self._restore_serial_fetch_vars()
self._init_dist_attr_for_program(no_default)
# Backup the original distributed information for later restore
self._original_dist_tensors_for_program = copy.deepcopy(
self._dist_tensors_for_program
......
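A small sketch of how the new DistributedContext hooks are exercised by the tuner later in this commit; the json_config key and path value shown are assumptions for illustration, not a documented schema:

# "tuner_save_path" is the only json_config key the rule based tuner reads in this commit;
# the path value here is hypothetical, and main_prog/startup_prog are assumed to exist.
dist_context = DistributedContext(
    serial_main_prog=main_prog,
    serial_startup_prog=startup_prog,
    json_config={"tuner_save_path": "./rule_based_tuner_strategy"},
)

# no_default is forwarded to _init_dist_attr_for_program; the tuner calls it with True
# so that only the dist attrs it annotated from pattern matching are kept before completion.
dist_context.initialize(no_default=True)

# process_meshes is now settable, letting the tuner reset it per candidate process mesh.
dist_context.process_meshes = []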
......@@ -174,7 +174,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
varname
)
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [varname + "@GRAD"]
......
......@@ -278,6 +278,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
for mapping in ids_dims_mapping[1:]:
if is_dim_shard(mapping):
return False
if is_dim_shard(ids_dims_mapping[0]) and is_dim_shard(
w_dims_mapping[-2]
):
if ids_dims_mapping[0] == w_dims_mapping[-2]:
return False
return True
def is_output_compatible(self, dist_op):
......
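A standalone sketch of the new embedding rule above, assuming the usual dims_mapping convention where -1 means replicated and any other value names a mesh axis; the helper names are made up for illustration:

def _is_dim_shard(mapping):
    # -1 marks a replicated tensor dim; anything else is a mesh axis index.
    return mapping != -1

def _ids_rows_conflict(ids_dims_mapping, w_dims_mapping):
    # Mirrors the check added above: the ids' leading dim and the embedding weight's
    # row dim must not be sharded along the same mesh axis.
    return (
        _is_dim_shard(ids_dims_mapping[0])
        and _is_dim_shard(w_dims_mapping[-2])
        and ids_dims_mapping[0] == w_dims_mapping[-2]
    )

assert _ids_rows_conflict([0, -1], [0, -1]) is True    # same mesh axis -> incompatible
assert _ids_rows_conflict([0, -1], [1, -1]) is False   # different axes -> still allowed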
......@@ -1507,7 +1507,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
processes = process_mesh.process_ids
# col parallel: matmul + allreduce
if backward_op.attr("trans_y"):
Y_var_dim_mapping = list(reversed(Y_var_dim_mapping))
assert Y_var_dim_mapping[0] < 0
parallel_axis = Y_var_dim_mapping[1]
......
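The hunk above swaps an in-place reverse() for list(reversed(...)), presumably so the dims mapping held by the dist attr is not mutated as a side effect; a minimal plain-Python illustration of the aliasing difference:

stored = [-1, 1]          # stands in for the dims mapping kept by the dist attr
view = stored             # what the old code effectively held: the same list object
view.reverse()            # in-place reversal also changes the stored mapping
assert stored == [1, -1]

stored = [-1, 1]
view = list(reversed(stored))   # the new code builds a copy instead
assert stored == [-1, 1]        # the stored mapping is left untouched
assert view == [1, -1]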
......@@ -12,10 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from ..cost import (
_g_op_cost_factory,
build_comp_costs_from_descs,
build_comp_desc_from_dist_op,
build_dp_costs,
)
from ..utils import compute_compatible_and_update_dim_mapping
from .common import (
DistributedOperatorImpl,
DistributedOperatorImplContainer,
is_parameter_related,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
)
......@@ -42,6 +51,84 @@ class DistributedScaleImpl(DistributedOperatorImpl):
def is_input_compatible(self, dist_op):
return True
def calc_cost(self, op_role, dist_op, ctx, cluster):
"""Calculate the cost by the op role."""
cost = None
if int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
else:
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.process_ids
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(
_g_op_cost_factory[op_type], ctx, processes, desc_mapping, cluster
)
res_cost = [cost_mapping]
return res_cost
def calc_bwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
res = []
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.process_ids
backward_op = dist_op.serial_op
op_type = backward_op.type
cost_mapping = build_comp_costs_from_descs(
_g_op_cost_factory[op_type], ctx, processes, desc_mapping, cluster
)
res.append(cost_mapping)
main_block = backward_op.block
need_gradient_allreduce = False
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and not is_parameter_related(
varname, main_block
):
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True
break
if need_gradient_allreduce:
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and is_parameter_related(
varname, main_block
):
var_dim_mapping = dist_attr.get_input_dims_mapping(
varname
)
mesh_shape = process_mesh.shape
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [varname + "@GRAD"]
build_dp_costs(
res,
dist_op,
ctx,
var_names,
attrs,
parallel_axis,
cluster,
)
return res
def is_output_compatible(self, dist_op):
return True
......
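The backward cost above only adds data-parallel gradient allreduce costs when the batch dimension is actually sharded; a small sketch of that decision with plain lists standing in for the dist attr and the process mesh:

def needs_gradient_allreduce(activation_dims_mapping, mesh_shape):
    # Mirrors the loop above: the activation's batch dim (index 0) must map to a
    # mesh axis whose size is greater than 1 for an allreduce cost to be added.
    batch_size_axis = activation_dims_mapping[0]
    return batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1

assert needs_gradient_allreduce([0, -1], [2, 4]) is True    # batch dim sharded over 2 ranks
assert needs_gradient_allreduce([-1, -1], [2, 4]) is False  # batch dim replicated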
......@@ -12,9 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import math
import os
from abc import abstractmethod
from collections import OrderedDict
import paddle
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistAttr,
TensorDistAttr,
)
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Parameter, unique_name
from ...utils.log_utils import get_logger
from ..graph import Graph
_PATTERNS = {}
......@@ -548,6 +565,7 @@ class GraphUtil:
def _match_core(src_node, tgt_node):
nonlocal not_matched
# one input name or output name corresponding to multiple vars is not supported
if not_matched:
return
......@@ -998,20 +1016,168 @@ def convert_to_process_meshes(device_mesh: list) -> list:
class RuleBasedTuner:
def __init__(self, dist_context, mode="train", level="o1"):
"""
A tuner that uses rules from expert experience to search for a good parallel strategy.
Args:
dist_context (DistributedContext): The distributed context.
mode (str): The mode of the current task; it can be "train" or "eval". Default: "train".
level (str): The level of this tuner; it can be "o1" or "o2".
The "o2" level may find a better strategy but needs more time than "o1".
With "o1", all layers share the same parallelism and layers are placed evenly under pipeline parallelism.
With "o2", each layer can have its own parallelism and layers may not be placed evenly.
Default: "o1".
"""
self._dist_context = dist_context
self._cluster = self._dist_context.cluster
self._mode = mode
assert level in ["o1", "o2"]
self._level = level
self._logger = get_logger(logging.INFO)
self._use_dp = False
# forward sub program
self.fwd_sub_programs = OrderedDict()
# dist_context of sub program
self.sub_programs_dist_context = OrderedDict()
# graph of forward sub program
self.fwd_sub_program_graphs = OrderedDict()
# full main program
self.full_main_program = None
# full startup program
self.full_startup_program = None
# full main program dist context
self.full_main_program_dist_context = None
# tensor dist attribute from pattern setting
self.tensor_dist_attrs = {}
# op original id to op mapping
self.op_original_id_to_op = {}
# op original id to op idx in program
self.op_original_id_to_idx = {}
# op original id to grad op original id mapping
self.op_original_id_to_grad_op_original_id = {}
# all process meshes that the cluster can express
self.process_meshes = []
# all device meshes into which the cluster can be partitioned
self.device_meshes_list = []
# the best cost of stage in a given device mesh
self.stage_best_cost_of_dm = {}
# the best cost of stage in a given process mesh
self.stage_best_cost_of_pm = {}
# the op clustering result
self.layers = []
self._is_run = True
if os.getenv("PADDLE_AUTO_PARALLEL_STAGE") != "tuner":
self._is_run = True
else:
self._is_run = False
self._strategy_path = None
if self._dist_context._json_config:
try:
self._strategy_path = self._dist_context._json_config[
"tuner_save_path"
]
except:
self._strategy_path = None
@property
def dist_context(self):
return self._dist_context
@property
def cluster(self):
return self._cluster
@property
def mode(self):
return self._mode
@property
def level(self):
return self._level
def convert_process_mesh_to_key(self, process_mesh):
"""Convert process mesh object to str."""
processes = ",".join([str(x) for x in process_mesh._process_ids])
topology = ",".join([str(x) for x in process_mesh._shape])
key = processes + ";" + topology
return key
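# Illustrative example: a process mesh with _process_ids [0, 1, 2, 3] and
# _shape [2, 2] is keyed as "0,1,2,3;2,2".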
def gen_full_program(self):
"""Generate full program that contain backward and update phase program if mode is train."""
self.full_main_program = self.dist_context.serial_main_program.clone()
if self.mode == "train":
self.full_startup_program = (
self.dist_context.serial_startup_program.clone()
)
loss = self.full_main_program.global_block().vars[
self.dist_context.serial_loss.name
]
serial_optimizer = self._dist_context.serial_optimizer
optimizer = copy.deepcopy(serial_optimizer)
self.full_main_program_dist_context = DistributedContext(
serial_main_prog=self.full_main_program,
serial_startup_prog=self.full_startup_program,
serial_loss=loss,
)
# if in train mode, generate backward and update program.
with program_guard(
self.full_main_program, self.full_startup_program
):
params_grads = append_backward(
loss,
distop_context=self.full_main_program_dist_context.dist_op_context,
)
with program_guard(
self.full_main_program, self.full_startup_program
):
with unique_name.guard("opt_"):
optimizer_ops = optimizer.apply_gradients(params_grads)
# op original id to grad op id
for idx, op in enumerate(self.full_main_program.global_block().ops):
self.op_original_id_to_op[op.desc.original_id()] = op
self.op_original_id_to_idx[op.desc.original_id()] = idx
grad_op_id_to_op_id = (
self.full_main_program_dist_context.dist_op_context.grad_op_id_to_op_id
)
for grad_op_original_id in grad_op_id_to_op_id:
op_id = grad_op_id_to_op_id[grad_op_original_id]
self.op_original_id_to_grad_op_original_id[
op_id
] = grad_op_original_id
def cluster_operators(self):
"""
Group operators to layers.
Returns:
List: The list contains the lists of operators that belong to the same layer.
"""
ops = self._dist_context._serial_main_program.global_block().ops
# clear op dist attrs in case the user sharded tensors or ops while the full auto parallel mode is used.
for op in ops:
op.dist_attr = OperatorDistAttr(op.desc)
vars = self._dist_context._serial_main_program.global_block().vars
for var_name in vars:
vars[var_name].dist_attr = TensorDistAttr(vars[var_name].desc)
seq = [op.type for op in ops]
while not OperatorClusteringUtil.stop_replace(seq):
......@@ -1061,6 +1227,7 @@ class RuleBasedTuner:
to_replace_seq = OperatorClusteringUtil.replace_by_decomposed_seq(
decomposed_sub_seq, to_replace_seq
)
result = seq[: to_replace_idxes[0]]
if not has_merged:
result.extend(to_replace_seq)
......@@ -1077,3 +1244,369 @@ class RuleBasedTuner:
layers.append(layer)
return layers
def match_program(self, program):
"""Use patterns to match the program and get tensor shard spec when pattern matched."""
graph = GraphUtil.convert_to_graph(program.global_block())
results = GraphUtil.match_all_patterns(graph)
if results:
for pattern_name in results.keys():
pattern = _PATTERNS[pattern_name]
for parallelism in pattern.attrs["shard_spec"].keys():
shard_spec = pattern.attrs["shard_spec"][parallelism]
for pattern_node_id in shard_spec.keys():
for item in results[pattern_name]:
var_id = item[pattern_node_id]
var_desc_id = graph.attrs["id_to_var_desc_id"][
var_id
]
if var_desc_id not in self.tensor_dist_attrs:
self.tensor_dist_attrs[var_desc_id] = {}
self.tensor_dist_attrs[var_desc_id][
parallelism
] = shard_spec[pattern_node_id]
tensor_name = graph.attrs["id_to_var_name"][var_id]
self._logger.info(
"{}'s shard_spec may be {} when under {} parallelism.".format(
tensor_name,
shard_spec[pattern_node_id],
parallelism,
)
)
else:
self._logger.info(
"No pattern has be matched by this program. Currently, only the transformer-based models are supported. Data parallelism will be used."
)
self._use_dp = True
def gen_fwd_sub_programs_by_clone(self):
"""Generate all forward sub programs by cloned from the original program."""
for idx, layer in enumerate(self.layers):
sub_fwd_program = self._gen_fwd_sub_program_by_clone(layer)
self.fwd_sub_programs[idx] = sub_fwd_program
def _gen_fwd_sub_program_by_clone(self, ops):
"""Generate the forward sub program of the given ops."""
program = paddle.static.Program()
block = ops[0].block
vars = block.vars
target_block = program.global_block()
with paddle.static.program_guard(program):
has_cloned_vars = set()
for op in ops:
new_op_desc = target_block.desc.append_op()
new_op_desc.copy_from(op.desc)
for var_name in op.input_arg_names:
if var_name not in has_cloned_vars:
if vars[var_name].is_parameter:
src_var = vars[var_name]
copied_kwargs = {}
copied_kwargs['trainable'] = src_var.trainable
copied_kwargs[
'optimize_attr'
] = src_var.optimize_attr
copied_kwargs['regularizer'] = src_var.regularizer
copied_kwargs[
'do_model_average'
] = src_var.do_model_average
copied_kwargs['need_clip'] = src_var.need_clip
param = Parameter(
block=target_block,
type=src_var.type,
name=src_var.name,
shape=src_var.shape,
dtype=src_var.dtype,
lod_level=src_var.lod_level,
error_clip=src_var.error_clip,
stop_gradient=src_var.stop_gradient,
is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer,
**copied_kwargs
)
else:
target_block._clone_variable(vars[var_name])
target_block.vars[var_name].persistable = vars[
var_name
].persistable
target_block.vars[var_name].desc.set_original_id(
vars[var_name].desc.original_id()
)
has_cloned_vars.add(var_name)
for var_name in op.output_arg_names:
if var_name not in has_cloned_vars:
target_block._clone_variable(vars[var_name])
target_block.vars[var_name].persistable = vars[
var_name
].persistable
target_block.vars[var_name].desc.set_original_id(
vars[var_name].desc.original_id()
)
has_cloned_vars.add(var_name)
target_block._sync_with_cpp()
return program
def _compelte_sub_fwd_program(self, idx, sub_fwd_program, process_mesh):
"""Compelete forward sub program."""
selective_parallelisms = (
["dp", "mp"] if len(process_mesh.shape) == 1 else ["dp_mp", "mp_dp"]
)
for parallelism in selective_parallelisms:
has_set_tensor_count = 0
dist_context = DistributedContext(sub_fwd_program)
has_set_dist_attr_tensors = set()
dist_context.process_meshes = []
dist_context.add_process_mesh(process_mesh)
vars = sub_fwd_program.global_block().vars
# clear op dist attr
ops = sub_fwd_program.global_block().ops
for op in ops:
op.dist_attr = OperatorDistAttr(op.desc)
# clear tensor dist attr
for var_name in vars:
vars[var_name].dist_attr = TensorDistAttr(vars[var_name].desc)
for var_name in vars:
var_id = vars[var_name].desc.original_id()
if var_id in self.tensor_dist_attrs:
if parallelism in self.tensor_dist_attrs[var_id]:
dims_mapping = self.tensor_dist_attrs[var_id][
parallelism
]
dist_tensor = DistributedTensor(vars[var_name])
dist_tensor.dist_attr.process_mesh = process_mesh
dist_tensor.dist_attr.dims_mapping = dims_mapping
dist_tensor.dist_attr.mark_annotated("dims_mapping")
dist_tensor.dist_attr.mark_annotated("process_mesh")
dist_context.add_dist_tensor_for_program(dist_tensor)
has_set_tensor_count += 1
has_set_dist_attr_tensors.add(var_id)
# check whether any dist attr has been set in the dist context
if has_set_tensor_count > 0:
dist_context.initialize(no_default=True)
completer = Completer(dist_context)
completer.complete_forward_annotation()
if parallelism not in self.sub_programs_dist_context[idx]:
self.sub_programs_dist_context[idx][parallelism] = {}
key = self.convert_process_mesh_to_key(process_mesh)
self.sub_programs_dist_context[idx][parallelism][
key
] = dist_context
else:
self._logger.info(
"No pattern has be matched under {} parallelism whe sub program is {}.".format(
parallelism, sub_fwd_program
)
)
def complete_sub_fwd_programs(self, process_mesh):
"""Complete all forward sub programs."""
for idx in self.fwd_sub_programs.keys():
sub_fwd_program = self.fwd_sub_programs[idx]
if idx not in self.sub_programs_dist_context:
self.sub_programs_dist_context[idx] = {}
self._compelte_sub_fwd_program(idx, sub_fwd_program, process_mesh)
def _complete_sub_bwd_program(self, sub_program_dist_context):
"""
Complete the backward OP according to the forward OP.
Most of the logic is the same as the backward completion in the completer.
The difference is that here the backward OP is found from the forward OP,
while the completer finds the forward OP from the backward OP.
"""
def _is_grad_var_name(name):
if "@GRAD" in name:
return True
return False
sub_fwd_program = sub_program_dist_context.serial_main_program
block = sub_fwd_program.global_block()
vars = self.full_main_program.global_block().vars
ops = self.full_main_program.global_block().ops
grad_var_to_var = (
self.full_main_program_dist_context.dist_op_context.grad_var_to_var[
1
]
)
for forward_op in block.ops:
if (
forward_op.desc.original_id()
not in self.op_original_id_to_grad_op_original_id
):
continue
grad_op_id = self.op_original_id_to_grad_op_original_id[
forward_op.desc.original_id()
]
# e.g. the unsqueeze2 op in GPT has no grad op,
# or the op does not need a backward pass
if grad_op_id not in self.op_original_id_to_op:
continue
grad_op = self.op_original_id_to_op[grad_op_id]
if grad_op.type == "concat" and forward_op.type == "split":
forward_op_dist_attr = (
sub_program_dist_context.get_op_dist_attr_for_program(
forward_op
)
)
output_var = vars[grad_op.desc.output('Out')[0]]
split_input_var_name = forward_op.input("X")[0]
ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping(
split_input_var_name
)
ref_mesh = forward_op_dist_attr.process_mesh
grad_op_dist_attr = OperatorDistAttr()
for input_name in grad_op.input_arg_names:
grad_op_dist_attr.set_input_dims_mapping(
input_name, ref_dims_mapping
)
output_var_dist_attr = TensorDistAttr()
output_var_dist_attr.dims_mapping = ref_dims_mapping
output_var_dist_attr.process_mesh = ref_mesh
sub_program_dist_context.set_tensor_dist_attr_for_program(
output_var, output_var_dist_attr
)
grad_op_dist_attr.set_output_dims_mapping(
output_var.name, ref_dims_mapping
)
grad_op_dist_attr.process_mesh = ref_mesh
sub_program_dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr
)
grad_op_dist_attr.impl_type = forward_op_dist_attr.impl_type
grad_op_dist_attr.impl_idx = forward_op_dist_attr.impl_idx
continue
fwd_op_dist_attr = (
sub_program_dist_context.get_op_dist_attr_for_program(
forward_op
)
)
fwd_op_process_mesh = fwd_op_dist_attr.process_mesh
grad_op_dist_attr = OperatorDistAttr()
grad_op_dist_attr.process_mesh = fwd_op_process_mesh
for input_name in grad_op.input_arg_names:
if (
input_name not in forward_op.input_arg_names
and input_name not in forward_op.output_arg_names
):
if input_name in grad_var_to_var.keys():
fwd_name = grad_var_to_var[input_name]
ref_dims_mapping = (
fwd_op_dist_attr.get_output_dims_mapping(fwd_name)
)
else:
input_var = vars[input_name]
ref_dims_mapping = sub_program_dist_context.get_tensor_dist_attr_for_program(
input_var
).dims_mapping
else:
if input_name in forward_op.input_arg_names:
ref_dims_mapping = (
fwd_op_dist_attr.get_input_dims_mapping(input_name)
)
else:
ref_dims_mapping = (
fwd_op_dist_attr.get_output_dims_mapping(input_name)
)
assert (
ref_dims_mapping is not None
), "[{}] 's dims mapping is NONE".format(input_name)
grad_op_dist_attr.set_input_dims_mapping(
input_name, ref_dims_mapping
)
for output_name in grad_op.output_arg_names:
assert output_name in grad_var_to_var
fwd_name = grad_var_to_var[output_name]
ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping(
fwd_name
)
# var
output_var = vars[output_name]
tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.dims_mapping = ref_dims_mapping
tensor_dist_attr.process_mesh = fwd_op_process_mesh
sub_program_dist_context.set_tensor_dist_attr_for_program(
output_var, tensor_dist_attr
)
# op
grad_op_dist_attr.set_output_dims_mapping(
output_name, ref_dims_mapping
)
grad_op_dist_attr.impl_type = fwd_op_dist_attr.impl_type
grad_op_dist_attr.impl_idx = fwd_op_dist_attr.impl_idx
sub_program_dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr
)
grad_op_idx = self.op_original_id_to_idx[grad_op_id]
if grad_op_idx + 1 < len(ops):
grad_op_next_op = ops[grad_op_idx + 1]
if grad_op_next_op.type == "sum":
assert all(
map(_is_grad_var_name, grad_op_next_op.input_arg_names)
)
output_name = grad_op_next_op.output_arg_names[0]
assert (
output_name in grad_var_to_var
), "sum op's output '{}' has no corresponding var".format(
output_name
)
ref_fwd_var_name = grad_var_to_var[output_name]
ref_fwd_var = vars[ref_fwd_var_name]
ref_fwd_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
ref_fwd_var
)
ref_fwd_dims_mapping = ref_fwd_dist_attr.dims_mapping
ref_fwd_process_mesh = ref_fwd_dist_attr.process_mesh
# output
tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.dims_mapping = ref_fwd_dims_mapping
tensor_dist_attr.process_mesh = ref_fwd_process_mesh
output_var = vars[output_name]
sub_program_dist_context.set_tensor_dist_attr_for_program(
output_var, tensor_dist_attr
)
# op
grad_op_dist_attr = OperatorDistAttr()
grad_op_dist_attr.process_mesh = ref_fwd_process_mesh
for var_name in grad_op_next_op.input_arg_names:
grad_op_dist_attr.set_input_dims_mapping(
var_name, ref_fwd_dims_mapping
)
grad_op_dist_attr.set_output_dims_mapping(
output_name, ref_fwd_dims_mapping
)
grad_op_dist_attr.impl_type = "default"
grad_op_dist_attr.impl_idx = 0
sub_program_dist_context.set_op_dist_attr_for_program(
grad_op_next_op, grad_op_dist_attr
)
def complete_sub_bwd_programs(self):
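"""Complete all backward sub programs."""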
for idx in self.sub_programs_dist_context:
for parallelism in self.sub_programs_dist_context[idx]:
for key in self.sub_programs_dist_context[idx][parallelism]:
sub_program_dist_context = self.sub_programs_dist_context[
idx
][parallelism][key]
self._complete_sub_bwd_program(sub_program_dist_context)
......@@ -127,6 +127,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_pass_bf16 MODULES test_pass_bf16)
py_test_modules(test_dist_saver MODULES test_dist_saver)
py_test_modules(test_engine_save_load MODULES test_engine_save_load)
py_test_modules(test_rule_based_tuner MODULES test_rule_based_tuner)
# End of unittests WITH single card WITHOUT timeout
endif()
......@@ -178,6 +178,7 @@ class TestDistOpCost(unittest.TestCase):
[None, None],
)
tmp_out = paddle.matmul(out1, tmp_param)
tmp_out = paddle.scale(tmp_out, 0.5)
out2 = paddle.matmul(tmp_out, param2) # [8, 4] [-1, 0]
out8 = paddle.transpose(out2, [1, 0]) # [4, 8] [0, -1]
......@@ -286,6 +287,7 @@ class TestDistOpCost(unittest.TestCase):
)
tmp_out = paddle.matmul(out1, tmp_param)
tmp_out = paddle.scale(tmp_out, 0.5)
out2 = paddle.matmul(tmp_out, param2) # [8, 4] [-1, 0]
out8 = paddle.transpose(out2, [1, 0]) # [4, 8] [0, -1]
......
......@@ -119,9 +119,10 @@ class TestGroupOperators(unittest.TestCase):
RuleBasedTuner,
)
dist_context = DistributedContext(train_program)
dist_context.initialize()
tuner = RuleBasedTuner(dist_context)
layers = tuner.cluster_operators()
op_types = []
for layer in layers:
tmp = []
......
......@@ -112,18 +112,11 @@ class TestGroupOperatorsAndPatterns(unittest.TestCase):
sequence_len,
vocab_size,
)
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
_PATTERNS,
GraphUtil,
)
graph = GraphUtil.convert_to_graph(train_program.global_block())
print("graph: ", graph)
print("qkv: ", _PATTERNS["qkv"].attrs["shard_spec"])
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
import paddle
import paddle.static as static
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import (
GPTForPretraining,
GPTModel,
GPTPretrainingCriterion,
)
def get_gpt_model(
train_program, start_program, place, batch_size, sequence_len, vocab_size
):
with static.program_guard(train_program, start_program):
tokens = paddle.static.data(
name="tokens", shape=[batch_size, sequence_len], dtype='int64'
)
position_ids = paddle.static.data(
name="position_ids", shape=[batch_size, sequence_len], dtype='int64'
)
attention_mask = paddle.static.data(
name="attention_mask",
shape=[batch_size, 1, sequence_len, sequence_len],
dtype='float32',
)
labels = paddle.static.data(
name="labels", shape=[batch_size, sequence_len], dtype='int64'
)
loss_mask = paddle.static.data(
name="loss_mask", shape=[batch_size, sequence_len], dtype='float32'
)
gpt = GPTModel(
vocab_size=1000,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=8,
intermediate_size=256,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=1024,
type_vocab_size=1,
initializer_range=0.02,
pad_token_id=0,
eos_token_id=7,
bos_token_id=0,
eol_token_id=3,
)
model = GPTForPretraining(
gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
)
preds = model(tokens, position_ids, attention_mask)
criterion = GPTPretrainingCriterion()
loss = criterion(preds, labels, loss_mask)
def gen_data():
np.random.seed(2021)
tokens = []
position_ids = []
attention_mask = []
labels = []
loss_mask = []
for _ in range(batch_size):
tokens.append(np.random.randint(vocab_size, size=sequence_len))
position_ids.append(np.arange(sequence_len))
attention_mask.append([np.tril(np.ones(sequence_len))])
labels.append(np.random.randint(vocab_size, size=sequence_len))
loss_mask.append(np.ones(sequence_len))
return tokens, position_ids, attention_mask, labels, loss_mask
return train_program, start_program, loss, gen_data
class TestRuleBasedTuner(unittest.TestCase):
def test_gpt(self):
modeling.init_global()
train_program = static.Program()
start_program = static.Program()
place = paddle.set_device("gpu")
batch_size = 8
sequence_len = 512
vocab_size = 1000
train_program, start_program, loss, gen_data = get_gpt_model(
train_program,
start_program,
place,
batch_size,
sequence_len,
vocab_size,
)
from paddle.distributed.auto_parallel.dist_context import (
DistributedContext,
)
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
RuleBasedTuner,
)
clip = paddle.nn.ClipGradByGlobalNorm(0.2)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
dist_context = DistributedContext(
serial_main_prog=train_program,
serial_startup_prog=start_program,
serial_optimizer=opt,
serial_loss=loss,
)
dist_context.initialize()
tuner = RuleBasedTuner(dist_context)
tuner.cluster_operators()
tuner.gen_full_program()
tuner.match_program(tuner._dist_context.serial_main_program)
process_mesh = ProcessMesh([0, 1])
tuner.gen_fwd_sub_programs_by_clone()
tuner.complete_sub_fwd_programs(process_mesh)
tuner.complete_sub_bwd_programs()
if __name__ == "__main__":
unittest.main()
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
_FLOPS_COMPUTE_FUNC_MAP = {}
......@@ -244,8 +245,12 @@ def _matmul_flops(input_shapes, attrs):
equation: flops = 2 * numel(output) * dim_n
"""
x_shape = input_shapes.get("X", input_shapes.get("x", [[0]]))[0]
y_shape = input_shapes.get("Y", input_shapes.get("y", [[0]]))[0]
x_shape = copy.deepcopy(
input_shapes.get("X", input_shapes.get("x", [[0]]))[0]
)
y_shape = copy.deepcopy(
input_shapes.get("Y", input_shapes.get("y", [[0]]))[0]
)
if attrs.get('transpose_X') or attrs.get('transpose_x'):
x_shape[-1], x_shape[-2] = x_shape[-2], x_shape[-1]
......@@ -276,11 +281,11 @@ def _matmul_v2_flops(input_shapes, attrs):
shape_of_output = [dim1, dim2 ... max(dim(n-m), odim(n-m)), max(dim(n-m+1), odim(n-m+1))...dim_n_1, dim_m]
equation: flops = 2 * numel(outputs) * dim_n
"""
x_shape = copy.deepcopy(input_shapes.get('X')[0])
y_shape = copy.deepcopy(input_shapes.get('Y')[0])
if attrs.get('trans_x'):
x_shape[-1], x_shape[-2] = x_shape[-2], x_shape[-1]
if attrs.get('trans_y'):
y_shape[-1], y_shape[-2] = y_shape[-2], y_shape[-1]
dim_x = len(x_shape)
dim_y = len(y_shape)
......
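A worked instance of the documented formula flops = 2 * numel(output) * dim_n, restricted to the 2-D case for brevity. It also shows why the deepcopy added above matters (without it the transpose swaps would mutate the caller's cached shape lists); note too that the condition changed from `is not None` to a plain truthiness test, so trans_x=False no longer triggers a swap. The function name here is illustrative, not the library's API:

import copy

def matmul_flops_2d(x_shape, y_shape, trans_x=False, trans_y=False):
    # Copy first: the swaps below would otherwise mutate the lists passed in,
    # which is exactly what the deepcopy added in this commit guards against.
    x_shape = copy.deepcopy(x_shape)
    y_shape = copy.deepcopy(y_shape)
    if trans_x:
        x_shape[-1], x_shape[-2] = x_shape[-2], x_shape[-1]
    if trans_y:
        y_shape[-1], y_shape[-2] = y_shape[-2], y_shape[-1]
    out_numel = x_shape[0] * y_shape[-1]   # output shape is [x0, y1]
    dim_n = x_shape[-1]                    # contracted dimension
    return 2 * out_numel * dim_n

shapes = [8, 4]
assert matmul_flops_2d(shapes, [4, 16]) == 2 * (8 * 16) * 4   # == 1024
assert shapes == [8, 4]   # the caller's list is unchanged thanks to the copies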