未验证 提交 ec6b8fbd 编写于 作者: Y Yulong Ao 提交者: GitHub

[Auto Parallel] Add the support for the auto completion of while_op (#39939)

* [Auto Parallel] Support the auto completion of while_op

* [Auto Parallel] Improve the completion algorithms

* [Auto Parallel] Fix bugs for ernie inference

* [Auto Parallel] Remove attrs which cannot be pickled

* [Auto Parallel] make the dims_mappings of LodTensorArray vars empty

* [Auto Parallel] Fix bugs for the ernie inference in the pipeline parallel

* [Auto Parallel] Remove unnecessary comments

* [Auto Parallel] Fix a bug of the CMakeLists

* [Auto Parallel] Use the newest APIs to write the unit test

* [Auto Parallel] Remove unnecessary statements
上级 31858263
......@@ -95,6 +95,7 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
std::unordered_map<std::string, std::pair<VarDesc *, int>>
name_to_desc_block_id;
block_id_ = block.ID();
const BlockDesc *block_var_visible = &block;
while (block_var_visible != nullptr) {
for (auto *var : block_var_visible->AllVars()) {
......
......@@ -230,6 +230,7 @@ class Graph {
auto *x =
AddNode(new ir::Node(var_desc, block_id == -1 ? block_id_ : block_id));
x->SetId(num_node_created_++);
x->SetGraphId(block_id_);
return x;
}
......@@ -245,6 +246,7 @@ class Graph {
"The OpDesc used to create operator node is null."));
auto *x = AddNode(new ir::Node(op_desc));
x->SetId(num_node_created_++);
x->SetGraphId(block_id_);
return x;
}
......@@ -263,6 +265,7 @@ class Graph {
num_node_created_);
auto *x = AddNode(new ir::Node(name, ir::Node::Type::kVariable, block_id_));
x->SetId(num_node_created_++);
x->SetGraphId(block_id_);
return x;
}
......@@ -276,6 +279,7 @@ class Graph {
}
auto *x = AddNode(new ir::Node(name, type, block_id_));
x->SetId(num_node_created_++);
x->SetGraphId(block_id_);
return x;
}
......
......@@ -125,6 +125,7 @@ class Node {
// Only use this for auto parallel.
// A node does not have original desc if the return is zero.
uint64_t OriginalDescId() const { return original_desc_id_; }
// Returns the graph id stored by SetGraphId(); the member is initialized to
// -1, so -1 is returned if no id has been set.
int GraphId() const { return graph_id_; }
bool IsOp() const { return type_ == Type::kOperation; }
bool IsVar() const { return type_ == Type::kVariable; }
......@@ -246,10 +247,12 @@ class Node {
// Store the original id of var desc or op desc.
// Only use this for auto parallel.
uint64_t original_desc_id_{0};
int graph_id_{-1};
private:
// ID can only set by a Graph.
void SetId(int id) { id_ = id; }
// Stores the id later returned by GraphId(). Declared in the private section
// next to SetId, so it is not callable through the public interface.
void SetGraphId(int graph_id) { graph_id_ = graph_id; }
// desc_order can only set by a Graph when constructing a Graph from a
// BlockDesc.
......
......@@ -143,6 +143,7 @@ void BindNode(py::module *m) {
.def("var", &Node::Var, return_value_policy::reference)
.def("op", &Node::Op, return_value_policy::reference)
.def("id", &Node::id)
.def("graph_id", &Node::GraphId)
.def("original_desc_id", &Node::OriginalDescId)
.def("is_op", &Node::IsOp)
.def("is_var", &Node::IsVar)
......
......@@ -21,11 +21,12 @@ from paddle.fluid import framework
from .utils import print_program_with_dist_attr
from .operators import find_best_compatible_distributed_operator_impl
from .dist_context import get_default_distributed_context
from .dist_context import get_default_distributed_context, _node_id
from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from .process_mesh import ProcessMesh
from paddle.distributed.fleet.meta_optimizers.common import OpRole
......@@ -108,6 +109,20 @@ def compute_compatible_dims_mapping(dims_mapping_list):
return compatible_result
def merge_process_mesh_two(pm1, pm2):
    """Merge two process meshes into one that covers the processes of both.

    Args:
        pm1: An object exposing an iterable ``processes`` attribute, or None.
        pm2: An object exposing an iterable ``processes`` attribute, or None.

    Returns:
        None when both arguments are None; otherwise a new ``ProcessMesh``
        built from the de-duplicated union of the processes, in ascending
        order.
    """
    if pm1 is None and pm2 is None:
        return None
    merged_processes = set()
    for process_mesh in (pm1, pm2):
        if process_mesh is not None:
            merged_processes.update(process_mesh.processes)
    # The iteration order of a set is an implementation detail. Sorting makes
    # the merged mesh independent of argument order and of how the set was
    # populated, so equal process sets always yield the same process list.
    return ProcessMesh(sorted(merged_processes))
class Completer:
def __init__(self, dist_context):
assert dist_context is not None
......@@ -119,7 +134,9 @@ class Completer:
return False
tensor_desc = tensor_node.var()
# Skip reader tensor
if tensor_desc.type() == core.VarDesc.VarType.READER:
if tensor_desc.type() == core.VarDesc.VarType.READER \
or tensor_desc.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or tensor_desc.type == core.VarDesc.VarType.STEP_SCOPES:
return False
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
......@@ -185,7 +202,7 @@ class Completer:
op_dist_attr = dist_op.dist_attr
if fwd:
for tensor_node in op_node.inputs:
if tensor_node.var() is not None:
if tensor_node.is_var() and tensor_node.var() is not None:
if tensor_node.var().type() == core.VarDesc.VarType.READER:
continue
tensor_desc = tensor_node.var()
......@@ -208,7 +225,7 @@ class Completer:
# Find the most compatible implementations from the distributed operator
op_dist_impl = find_best_compatible_distributed_operator_impl(
dist_op, fwd=True)
assert op_dist_impl is not None, "Cannot find the dist op implementation."
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
......@@ -220,7 +237,7 @@ class Completer:
op_dist_attr.impl_idx = op_dist_impl.idx
else:
for tensor_node in op_node.outputs:
if tensor_node.var() is not None:
if tensor_node.is_var() and tensor_node.var() is not None:
if tensor_node.var().type() == core.VarDesc.VarType.READER:
continue
tensor_desc = tensor_node.var()
......@@ -243,7 +260,7 @@ class Completer:
# Find the most compatible implementations from the distributed operator
op_dist_impl = find_best_compatible_distributed_operator_impl(
dist_op, fwd=False)
assert op_dist_impl is not None, "Cannot find the dist op implementation."
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
......@@ -255,49 +272,26 @@ class Completer:
op_dist_attr.impl_idx = op_dist_impl.idx
return changed
def _update_process_mesh(self):
def _find_nearset_node(nodes, idx):
for node in reversed(nodes[:idx]):
node_dist_attr = self._dist_context.get_dist_attr_for_graph(
node)
if node_dist_attr.process_mesh is not None:
return node
total_reach_fix_point = False
while not total_reach_fix_point:
total_changed = False
for is_fwd in [True, False]:
all_nodes = self._dist_context.serial_ordered_nodes \
if is_fwd else reversed(self._dist_context.serial_ordered_nodes)
reach_fix_point = False
while not reach_fix_point:
def _update_dims_mapping_between_graphs(self):
changed = False
for idx, node in enumerate(all_nodes):
nearest_node = _find_nearset_node(
self._dist_context.serial_ordered_nodes, idx)
if nearest_node is None:
continue
nearest_node_dis_attr = self._dist_context.get_dist_attr_for_graph(
nearest_node)
nearest_process_mesh = nearest_node_dis_attr.process_mesh
cur_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
node)
cur_process_mesh = cur_node_dist_attr.process_mesh
compatible_process_mesh = compute_compatible_process_mesh(
[cur_process_mesh, nearest_process_mesh])
if compatible_process_mesh is not None \
and cur_process_mesh != compatible_process_mesh:
cur_node_dist_attr.process_mesh = compatible_process_mesh
for parent_node, child_node in self._node_pairs_between_graphs:
parent_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
parent_node)
child_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
child_node)
parent_node_dims_mapping = parent_node_dist_attr.dims_mapping
child_node_dims_mapping = child_node_dist_attr.dims_mapping
compatible_dims_mapping = compute_compatible_dims_mapping(
[parent_node_dims_mapping, child_node_dims_mapping])
if (compatible_dims_mapping is not None) \
and (compatible_dims_mapping != parent_node_dims_mapping):
parent_node_dist_attr.dims_mapping = compatible_dims_mapping
changed = True
if changed:
reach_fix_point = False
total_changed = True
else:
reach_fix_point = True
if total_changed:
total_reach_fix_point = False
else:
total_reach_fix_point = True
if (compatible_dims_mapping is not None) \
and (compatible_dims_mapping != child_node_dims_mapping):
parent_node_dist_attr.dims_mapping = compatible_dims_mapping
changed = True
return changed
def _update_dims_mapping(self):
# Complete dims_mapping for each node
......@@ -318,11 +312,314 @@ class Completer:
node, fwd=is_fwd)
if op_changed:
changed = True
graph_changed = self._update_dims_mapping_between_graphs()
if graph_changed:
changed = True
if changed:
reach_fix_point = False
else:
reach_fix_point = True
def _update_process_mesh_by_nearest(self, op_node, nearest_op_node):
    """Derive the process mesh of ``op_node`` from ``nearest_op_node``.

    The op itself is only touched when its process mesh is not annotated.
    Afterwards, unless the op is a "while" op, the op's process mesh is
    offered to its un-annotated leaf input tensors (var nodes without
    inputs) and to all of its un-annotated output tensors.

    Args:
        op_node: Graph op node whose process mesh should be completed.
        nearest_op_node: Graph op node supplying the candidate process mesh.
    """
    dist_context = self._dist_context
    op_dist_attr = dist_context.get_dist_attr_for_graph(op_node)

    # Complete the op's own process mesh from its neighbour.
    if not op_dist_attr.is_annotated("process_mesh"):
        current_mesh = op_dist_attr.process_mesh
        neighbour_dist_attr = dist_context.get_dist_attr_for_graph(
            nearest_op_node)
        neighbour_mesh = neighbour_dist_attr.process_mesh
        candidate_mesh = compute_compatible_process_mesh(
            [current_mesh, neighbour_mesh])
        if candidate_mesh is not None and current_mesh != candidate_mesh:
            op_dist_attr.process_mesh = candidate_mesh

    # The inputs and outputs of a while op are not given the op's mesh here.
    if op_dist_attr.op_type == "while":
        return

    def _offer_op_mesh(var_nodes, leaf_only):
        # Give the op's process mesh to every eligible tensor in var_nodes.
        for var_node in var_nodes:
            if not var_node.is_var() or var_node.var() is None:
                continue
            var_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
                var_node)
            if var_dist_attr.is_annotated("process_mesh"):
                continue
            # For inputs, only leaf vars (no producer) are considered.
            if leaf_only and len(var_node.inputs) != 0:
                continue
            candidate = compute_compatible_process_mesh(
                [var_dist_attr.process_mesh, op_dist_attr.process_mesh])
            if candidate is not None \
                    and var_dist_attr.process_mesh != candidate:
                var_dist_attr.process_mesh = candidate

    _offer_op_mesh(op_node.inputs, leaf_only=True)
    _offer_op_mesh(op_node.outputs, leaf_only=False)
# Post-pass that overrides process meshes around control flow. It walks the
# while ops recorded in self._while_op_nodes and the node groups recorded in
# self._array_nodes, and writes process_mesh on the dist attrs it looks up.
# Takes no argument besides self and contains no return statement of its own.
# NOTE(review): leading indentation is missing in this view, so the nesting of
# the loops below has to be read from the statements themselves.
def _update_process_mesh_for_specials(self):
# Helper: scan nodes[:idx] backwards and return the first var node whose
# name equals var_name; there is no explicit return when nothing matches.
def _find_nearest_tensor_node_before(nodes, idx, var_name):
for node in reversed(nodes[:idx]):
if node.is_var() and node.var() is not None \
and node.var().name() == var_name:
return node
# Helper: same lookup, scanning nodes[idx + 1:] forwards.
def _find_nearest_tensor_node_after(nodes, idx, var_name):
for node in nodes[idx + 1:]:
if node.is_var() and node.var() is not None \
and node.var().name() == var_name:
return node
# Helper: breadth-first walk over `.inputs` starting at source_node that
# collects and returns the visited var/op nodes passing the filters below.
# Nodes already seen are identified by _node_id.
def _find_nodes_related_to_cond(source_node):
related_nodes = []
visited = set()
frontier = list()
frontier.append(source_node)
# BFS
while len(frontier) != 0:
# Pop the head of the queue.
cur = frontier[0]
frontier = frontier[1:]
if _node_id(cur) in visited:
continue
# TODO: need more restrictions
for node in cur.inputs:
if node.is_var() and node.var() is not None:
# Var nodes are followed when they are not READER vars and their
# shape has exactly one dimension.
if node.var().type() != core.VarDesc.VarType.READER \
and len(node.var().shape()) == 1:
frontier.append(node)
related_nodes.append(node)
if node.is_op() and node.op() is not None:
# flag is cleared for the three reader-related op types and when an
# input/output var is a READER or has a shape whose length is not 1.
flag = True
if node.op().type() == "create_py_reader" \
or node.op().type() == "create_double_buffer_reader" \
or node.op().type() == "read":
flag = False
for tensor_node in node.inputs:
if tensor_node.is_var() and tensor_node.var(
) is not None:
if tensor_node.var().type() == core.VarDesc.VarType.READER \
or len(tensor_node.var().shape()) != 1:
flag = False
break
for tensor_node in node.outputs:
if tensor_node.is_var() and tensor_node.var(
) is not None:
if tensor_node.var().type() == core.VarDesc.VarType.READER \
or len(tensor_node.var().shape()) != 1:
flag = False
break
if flag:
frontier.append(node)
related_nodes.append(node)
visited.add(_node_id(cur))
return related_nodes
# Amend the process meshes related to while_op
for while_op_node, while_op_node_idx in self._while_op_nodes.values():
# The sub-block of the while op is looked up as a sub graph by its id.
sub_graph_id = while_op_node.op()._block_attr_id("sub_block")
sub_graph = self._dist_context._serial_graph.get_sub_graph(
sub_graph_id)
sub_graph_nodes = list(sub_graph.all_nodes())
while_dist_op = self._dist_context.get_dist_op_for_graph(
while_op_node)
while_op_dist_attr = while_dist_op.dist_attr
# Step 1: set the process mesh of while_op to the merged process mesh of its subblock
# The merge starts from the while op's current mesh and folds in the mesh
# of every var/op node of the sub graph via merge_process_mesh_two.
merged_process_mesh = while_op_dist_attr.process_mesh
for node in sub_graph_nodes:
if (node.is_var() and node.var() is not None) \
or (node.is_op() and node.op() is not None):
dist_attr = self._dist_context.get_dist_attr_for_graph(node)
merged_process_mesh = merge_process_mesh_two(
merged_process_mesh, dist_attr.process_mesh)
while_op_dist_attr.process_mesh = merged_process_mesh
# Step 2: set the related nodes of while_op to the process mesh of while_op
# Step 2.1: Find related nodes of cond var the graph of while_op
cond_tensor_related_nodes = []
cond_tensor_name = while_op_node.op().input("Condition")[0]
cond_tensor_node = None
for node in while_op_node.inputs:
if node.is_var() and node.var() is not None \
and node.var().name() == cond_tensor_name:
cond_tensor_node = node
cond_tensor_related_nodes.append(cond_tensor_node)
break
# NOTE(review): cond_tensor_node is still None here if no input matched;
# the helper would then fail on attribute access - confirm this cannot happen.
cond_tensor_related_nodes.extend(
_find_nodes_related_to_cond(cond_tensor_node))
# Step 2.2: Find related nodes of cond var in the subgraph of while_op
# Searched from the end of the sub graph node list; only a node with the
# condition's name and no outputs is accepted.
cond_tensor_node = None
for node in reversed(sub_graph_nodes):
if node.is_var() and node.var() is not None \
and node.var().name() == cond_tensor_name \
and len(node.outputs) == 0:
cond_tensor_node = node
break
# NOTE(review): same None concern as in Step 2.1.
cond_tensor_related_nodes.extend(
_find_nodes_related_to_cond(cond_tensor_node))
# Step 2.3: Add the StepScopes output of while_op
stepscopes_tensor_name = while_op_node.op().output("StepScopes")[0]
stepscopes_tensor_node = None
for output_node in while_op_node.outputs:
if output_node.is_var() and output_node.var() is not None \
and output_node.var().name() == stepscopes_tensor_name:
stepscopes_tensor_node = output_node
cond_tensor_related_nodes.append(stepscopes_tensor_node)
# Step 2.4: Set the process meshes of all nodes related to cond var to the process mesh of while op
for node in cond_tensor_related_nodes:
tensor_dist_attr = self._dist_context.get_dist_attr_for_graph(
node)
tensor_dist_attr.process_mesh = merged_process_mesh
# Step 3: set the process meshes of the inputs in while_op to the process meshes of the outside input nodes
# Each input's mesh is copied from the closest same-named var node that
# precedes the while op in serial_ordered_nodes.
# NOTE(review): no fallback when no preceding node is found - confirm every
# while input has one.
while_op_inputs_dist_attrs = while_op_dist_attr.inputs_dist_attrs
for tensor_name, tensor_dist_attr in while_op_inputs_dist_attrs.items(
):
nearest_tensor_node = _find_nearest_tensor_node_before(
self._dist_context.serial_ordered_nodes, while_op_node_idx,
tensor_name)
nearest_tensor_dist_attr = self._dist_context.get_dist_attr_for_graph(
nearest_tensor_node)
tensor_dist_attr.process_mesh = nearest_tensor_dist_attr.process_mesh
# Step 4: set the process meshes of the outputs in while_op to the process meshes of the outside output nodes
# Outputs prefer a same-named var node before the while op and fall back to
# the first one after it.
while_op_outputs_dist_attrs = while_op_dist_attr.outputs_dist_attrs
for tensor_name, tensor_dist_attr in while_op_outputs_dist_attrs.items(
):
nearest_tensor_node = _find_nearest_tensor_node_before(
self._dist_context.serial_ordered_nodes, while_op_node_idx,
tensor_name)
if nearest_tensor_node is None:
nearest_tensor_node = _find_nearest_tensor_node_after(
self._dist_context.serial_ordered_nodes,
while_op_node_idx, tensor_name)
nearest_tensor_dist_attr = self._dist_context.get_dist_attr_for_graph(
nearest_tensor_node)
tensor_dist_attr.process_mesh = nearest_tensor_dist_attr.process_mesh
# Amend the process meshes related to array
# All nodes grouped under one array name receive the merge of their meshes:
# the first loop accumulates the merge, the second writes it back.
for array_node_list in self._array_nodes.values():
merged_process_mesh = None
for array_node in array_node_list:
dist_attr = self._dist_context.get_dist_attr_for_graph(
array_node)
merged_process_mesh = merge_process_mesh_two(
merged_process_mesh, dist_attr.process_mesh)
for array_node in array_node_list:
dist_attr = self._dist_context.get_dist_attr_for_graph(
array_node)
dist_attr.process_mesh = merged_process_mesh
def _update_process_mesh(self):
    """Complete the process meshes of all op nodes and their tensors.

    Works on the op and tensor node lists kept by the distributed context:

    1. Every tensor with an annotated process mesh offers it to the first
       non-"while" op that takes the tensor as input, if that op's process
       mesh is not annotated.
    2. The first op that has a process mesh is located. The ops after it
       take their mesh from the op directly before them, and the ops before
       it take their mesh from that first op, all via
       ``_update_process_mesh_by_nearest``.
    3. ``_update_process_mesh_for_specials`` adjusts the special ops.

    Raises:
        AssertionError: if, while walking forward in step 2, the op directly
            before the current one has no process mesh.
    """
    dist_context = self._dist_context
    ordered_op_nodes = dist_context._serial_ordered_op_nodes
    ordered_tensor_nodes = dist_context._serial_ordered_tensor_nodes

    # Step 1: Set the annotated process meshes from tensors to the first ops using them
    for tensor_node in ordered_tensor_nodes:
        tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
            tensor_node)
        if not tensor_dist_attr.is_annotated("process_mesh"):
            continue
        # Computed once per tensor instead of once per compared input.
        tensor_node_id = _node_id(tensor_node)
        first_op_node = None
        for op_node in ordered_op_nodes:
            # TODO: Need a better rule for the control flow ops.
            # For now, do not set the process mesh of while_op from its inputs
            if op_node.op().type() == "while":
                continue
            if any(
                    _node_id(input_tensor_node) == tensor_node_id
                    for input_tensor_node in op_node.inputs):
                first_op_node = op_node
                break
        if first_op_node is None:
            continue
        op_dist_attr = dist_context.get_dist_attr_for_graph(first_op_node)
        if op_dist_attr is None or op_dist_attr.is_annotated("process_mesh"):
            continue
        compatible_process_mesh = compute_compatible_process_mesh(
            [tensor_dist_attr.process_mesh, op_dist_attr.process_mesh])
        if compatible_process_mesh is not None \
                and op_dist_attr.process_mesh != compatible_process_mesh:
            op_dist_attr.process_mesh = compatible_process_mesh

    # Step 2: set the process meshes of ops with the nearest op before them
    # Step 2.1: find the first op node which has the process mesh
    idx_of_first_op_node_has_process_mesh = -1
    for idx, op_node in enumerate(ordered_op_nodes):
        op_dist_attr = dist_context.get_dist_attr_for_graph(op_node)
        if op_dist_attr.process_mesh is not None:
            idx_of_first_op_node_has_process_mesh = idx
            # Reuse the following method to set the related tensors for same op node
            self._update_process_mesh_by_nearest(op_node, op_node)
            # Only the first hit matters, so the scan can stop here.
            break

    # Step 2.2: set the process meshes of ops by the nearest op node after the first op node
    # Walk by index so the predecessor is simply idx - 1. The index found in
    # step 2.1 is at most len - 1, so no bounds guard is needed before this.
    for idx in range(idx_of_first_op_node_has_process_mesh + 1,
                     len(ordered_op_nodes)):
        op_node = ordered_op_nodes[idx]
        nearest_op_node = ordered_op_nodes[idx - 1]
        nearest_op_dist_attr = dist_context.get_dist_attr_for_graph(
            nearest_op_node)
        assert nearest_op_dist_attr.process_mesh is not None, \
            "The op node before the current one must already have a process mesh."
        self._update_process_mesh_by_nearest(op_node, nearest_op_node)

    # Step 2.3: set the process meshes of ops by the nearest op node before the first op node
    nearest_op_node = ordered_op_nodes[
        idx_of_first_op_node_has_process_mesh]
    for op_node in ordered_op_nodes[:idx_of_first_op_node_has_process_mesh]:
        self._update_process_mesh_by_nearest(op_node, nearest_op_node)

    # Step 3: adjust the process meshes for special ops
    self._update_process_mesh_for_specials()
def _prepare(self):
    """Index the ordered graph nodes used by the later completion passes.

    Rebuilds three attributes from ``self._dist_context.serial_ordered_nodes``:

    * ``_while_op_nodes``: ``_node_id(op node) -> (op node, position)`` for
      every op of type "while".
    * ``_array_nodes``: array var name -> list of nodes. A
      "read_from_array" op is filed under its first "X" input name; a
      "write_to_array" op and its first output node are filed under its
      first "Out" output name.
    * ``_node_pairs_between_graphs``: ``(other node, node)`` tuples pairing
      each var node whose graph id is non-zero with every same-named var
      node whose graph id is exactly one less. Nodes earlier in the order
      are listed first (nearest first), then the later ones.
    """
    self._while_op_nodes = {}
    self._array_nodes = {}
    self._node_pairs_between_graphs = []
    ordered_nodes = self._dist_context.serial_ordered_nodes

    def _is_var_named(candidate, var_name, graph_id):
        # True for a var node with the given name living in the given graph.
        return candidate.is_var() and candidate.var() is not None \
            and candidate.node.graph_id() == graph_id \
            and candidate.var().name() == var_name

    for position, node in enumerate(ordered_nodes):
        if node.is_op():
            op_type = node.op().type()
            if op_type == "while":
                self._while_op_nodes[_node_id(node)] = (node, position)
            elif op_type == "read_from_array":
                array_name = node.op().input("X")[0]
                self._array_nodes.setdefault(array_name, []).append(node)
            elif op_type == "write_to_array":
                array_name = node.op().output("Out")[0]
                group = self._array_nodes.setdefault(array_name, [])
                group.append(node)
                group.append(node.outputs[0])

        if node.is_var() and node.var() is not None:
            graph_id = node.node.graph_id()
            if graph_id == 0:
                continue
            var_name = node.var().name()
            outer_graph_id = graph_id - 1
            earlier_nodes = reversed(ordered_nodes[:position])
            later_nodes = ordered_nodes[position + 1:]
            self._node_pairs_between_graphs.extend(
                (other, node) for other in earlier_nodes
                if _is_var_named(other, var_name, outer_graph_id))
            self._node_pairs_between_graphs.extend(
                (other, node) for other in later_nodes
                if _is_var_named(other, var_name, outer_graph_id))
def complete_forward_annotation(self, serial_main_program):
""" Complete annotation for the partial annotated serial_main_program.
Arguments:
......@@ -336,24 +633,24 @@ class Completer:
# Initialize distributed attributes for all var and op node in serial_main_program
self._dist_context.init_dist_attr_for_program()
# print_program_with_dist_attr(serial_main_program, self._dist_context)
# Initialize distributed attributes for all var and op node in graph
self._dist_context.init_dist_attr_for_graph()
self._prepare()
self._update_process_mesh()
# Complete dims_mapping for each node
self._update_dims_mapping()
# Copy the corresponding distributed attribute from graph to serial_main_program
self._dist_context.copy_dist_attr_from_graph_to_program()
self._dist_context.clear_dist_info_for_graph()
# print_serial_main_program_with_dist_attr(serial_main_program, self._dist_context)
# Do the validation check and amend some completion
self._dist_context.amend_dist_attr_for_program()
# print_serial_main_program_with_dist_attr(serial_main_program, self._dist_context)
self._dist_context.validate_dist_attr_for_program()
return serial_main_program
......
......@@ -175,6 +175,7 @@ class TensorDistributedAttribute:
class OperatorDistributedAttribute:
def __init__(self):
self._process_mesh = None
self._op_type = None
self._impl_type = None
self._impl_idx = None
self._inputs_dist_attrs = {}
......@@ -194,11 +195,23 @@ class OperatorDistributedAttribute:
if isinstance(process_mesh, list):
process_mesh = ProcessMesh(process_mesh)
self._process_mesh = copy.deepcopy(process_mesh)
# In while op, the process mesh is not shared by all inputs and outputs
if self._op_type == "while":
return None
for dist_attr in self._inputs_dist_attrs.values():
dist_attr.process_mesh = process_mesh
for dist_attr in self._outputs_dist_attrs.values():
dist_attr.process_mesh = process_mesh
@property
def op_type(self):
return self._op_type
@op_type.setter
def op_type(self, op_type):
if op_type is not None:
self._op_type = op_type
@property
def impl_type(self):
return self._impl_type
......@@ -326,6 +339,8 @@ class OperatorDistributedAttribute:
assert False, "No setter for {} in args {}.".format(
key, dist_attr)
# Make sure the process meshes in the dist op are the same
if self.op_type == "while":
return None
process_meshes = []
process_meshes.append(self.process_mesh)
for tensor_dist_attr in self.inputs_dist_attrs.values():
......
......@@ -15,6 +15,7 @@
import copy
from collections import defaultdict
from paddle.fluid import framework
from paddle.fluid.framework import get_flags, set_flags
from paddle.fluid import core
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
......@@ -39,6 +40,10 @@ def set_default_distributed_context(dist_context):
_g_default_distributed_context = dist_context
def _node_id(node):
return (node.node.graph_id(), node.node.id())
class DistributedContext:
"""
DistributedContext is used to collect related distributed information for program and graph.
......@@ -146,7 +151,7 @@ class DistributedContext:
return None
def get_dist_tensor_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id()
serial_tensor_node_id = _node_id(serial_tensor_node)
return self._dist_tensors_for_graph.get(serial_tensor_node_id, None)
def get_dist_op_for_program(self, serial_op):
......@@ -168,7 +173,7 @@ class DistributedContext:
del self._dist_ops_for_program[serial_tensor_id]
def get_dist_op_for_graph(self, serial_op_node):
serial_op_node_id = serial_op_node.id()
serial_op_node_id = _node_id(serial_op_node)
return self._dist_ops_for_graph.get(serial_op_node_id, None)
def get_tensor_dist_attr_for_program(self, serial_tensor):
......@@ -197,7 +202,7 @@ class DistributedContext:
self.add_dist_tensor_for_program(dist_tensor)
def get_tensor_dist_attr_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id()
serial_tensor_node_id = _node_id(serial_tensor_node)
dist_tensor = self._dist_tensors_for_graph.get(serial_tensor_node_id,
None)
if dist_tensor:
......@@ -242,7 +247,7 @@ class DistributedContext:
self.add_dist_op_for_program(dist_op)
def get_op_dist_attr_for_graph(self, serial_op_node):
serial_op_node_id = serial_op_node.id()
serial_op_node_id = _node_id(serial_op_node)
dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
if dist_op:
return dist_op.dist_attr
......@@ -262,7 +267,7 @@ class DistributedContext:
def get_dist_attr_for_graph(self, serial_node):
if serial_node.is_var() and serial_node.var() is not None:
serial_tensor_node_id = serial_node.id()
serial_tensor_node_id = _node_id(serial_node)
dist_tensor = self._dist_tensors_for_graph.get(
serial_tensor_node_id, None)
if dist_tensor:
......@@ -270,7 +275,7 @@ class DistributedContext:
else:
return None
if serial_node.is_op() and serial_node.op() is not None:
serial_op_node_id = serial_node.id()
serial_op_node_id = _node_id(serial_node)
dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
if dist_op:
return dist_op.dist_attr
......@@ -311,40 +316,69 @@ class DistributedContext:
def order_nodes_by_program_order(self):
def _contains(nodes, target_node):
for node in nodes:
if node.id() == target_node.id():
if _node_id(node) == _node_id(target_node):
return True
return False
ordered_tensor_nodes = []
ordered_op_nodes = []
all_nodes = self._serial_graph.all_nodes()
serial_ordered_tensor_nodes = []
serial_ordered_op_nodes = []
all_nodes = []
# for idx, graph in enumerate(self._serial_graph.all_sub_graphs()):
for idx, graph in enumerate(self._serial_graph.all_sub_graphs()):
for node in graph.all_nodes():
all_nodes.append(node)
for node in all_nodes:
if node.is_var() and node.var() is not None:
ordered_tensor_nodes.append(node)
serial_ordered_tensor_nodes.append(node)
if node.is_op() and node.op() is not None:
ordered_op_nodes.append(node)
ordered_tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
ordered_op_nodes.sort(key=lambda node: node.node.original_desc_id())
for op_node in ordered_op_nodes:
serial_ordered_op_nodes.append(node)
serial_ordered_tensor_nodes.sort(
key=lambda node: node.node.original_desc_id())
serial_ordered_op_nodes.sort(
key=lambda node: node.node.original_desc_id())
num_nodes_before = len(serial_ordered_tensor_nodes) + len(
serial_ordered_op_nodes)
new_serial_ordered_tensor_nodes = []
new_serial_ordered_op_nodes = []
for op_node in serial_ordered_op_nodes:
tensor_nodes = []
for tensor_node in op_node.inputs:
if tensor_node.is_var() \
and tensor_node.var() is not None \
and not _contains(self._serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node)
new_serial_ordered_tensor_nodes.append(tensor_node)
tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
self._serial_ordered_nodes.extend(tensor_nodes)
self._serial_ordered_nodes.append(op_node)
new_serial_ordered_op_nodes.append(op_node)
tensor_nodes = []
for tensor_node in op_node.outputs:
if tensor_node.is_var() \
and tensor_node.var() is not None \
and not _contains(self._serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node)
new_serial_ordered_tensor_nodes.append(tensor_node)
tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
self._serial_ordered_nodes.extend(tensor_nodes)
num_nodes_before = len(ordered_tensor_nodes) + len(ordered_op_nodes)
assert len(self._serial_ordered_nodes) == num_nodes_before, \
"The number of nodes before ordering is not the same after ordering."
new_serial_ordered_tensor_nodes.sort(
key=lambda node: node.node.original_desc_id())
new_serial_ordered_op_nodes.sort(
key=lambda node: node.node.original_desc_id())
self._serial_ordered_tensor_nodes = new_serial_ordered_tensor_nodes
self._serial_ordered_op_nodes = new_serial_ordered_op_nodes
assert len(self._serial_ordered_nodes) == len(
self._serial_ordered_tensor_nodes) + len(
self._serial_ordered_op_nodes)
self._serial_orphan_tensor_nodes = []
for tensor_node in serial_ordered_tensor_nodes:
if not _contains(self._serial_ordered_tensor_nodes, tensor_node):
self._serial_orphan_tensor_nodes.append(tensor_node)
if len(self._serial_ordered_nodes) != num_nodes_before:
print(
"WARNING: there are some orphan tensors or ops which are not used in the execution."
)
def init_dist_attr_for_graph(self):
assert self._is_initialized_for_program, \
......@@ -352,9 +386,9 @@ class DistributedContext:
if self._is_initialized_for_graph:
return
# Convert program to graph
set_flags({"FLAGS_convert_all_blocks": True})
self._serial_graph = framework.IrGraph(
core.Graph(self._serial_program.desc))
all_nodes = self._serial_graph.all_nodes()
self.order_nodes_by_program_order()
for node in self.serial_ordered_nodes:
if node.is_var() and node.var() is not None:
......@@ -365,10 +399,11 @@ class DistributedContext:
if tensor_id == cur_tensor_id \
or tensor_id == cur_dist_tensor.serial_tensor.desc.original_id():
dist_tensor = cur_dist_tensor
self._node_id_to_tensor_id[node.id()] = cur_tensor_id
self._node_id_to_tensor_id[_node_id(
node)] = cur_tensor_id
assert dist_tensor is not None, \
"Tensor must have a distributed tensor after the initialization for program."
serial_tensor_node_id = node.id()
serial_tensor_node_id = _node_id(node)
new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
dist_tensor.dist_attr)
self._dist_tensors_for_graph[
......@@ -381,10 +416,10 @@ class DistributedContext:
if op_id == cur_op_id \
or op_id == cur_dist_op.serial_op.desc.original_id():
dist_op = cur_dist_op
self._node_id_to_op_id[node.id()] = cur_op_id
self._node_id_to_op_id[_node_id(node)] = cur_op_id
assert dist_op is not None, \
"Operator must have a distributed operator after the initialization for program."
serial_op_node_id = node.id()
serial_op_node_id = _node_id(node)
new_dist_op = DistributedOperator(dist_op.serial_op,
dist_op.dist_attr)
self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
......@@ -402,10 +437,11 @@ class DistributedContext:
assert self._is_initialized_for_program and self._is_initialized_for_graph, \
"Both program and graph must be initialized."
updated_tensors = {}
all_nodes = self._serial_graph.all_nodes()
# all_nodes = self._serial_graph.all_nodes()
all_nodes = self._serial_ordered_nodes
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_id = self._node_id_to_tensor_id[node.id()]
tensor_id = self._node_id_to_tensor_id[_node_id(node)]
updated = updated_tensors.get(tensor_id, False)
# If a var has multiples var nodes in graph, only use the first one for now
if not updated:
......@@ -416,16 +452,31 @@ class DistributedContext:
dist_tensor_for_program.dist_attr = tensor_dist_attr_for_graph
updated_tensors[tensor_id] = True
if node.is_op() and node.op() is not None:
op_id = self._node_id_to_op_id[node.id()]
op_id = self._node_id_to_op_id[_node_id(node)]
op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node)
dist_op_for_program = self._dist_ops_for_program[op_id]
dist_op_for_program.dist_attr = op_dist_attr_for_graph
# TODO: the completion algorithm will skip orphan tensors,
# here we just set their process_mesh to the first one.
for orphan_node in self._serial_orphan_tensor_nodes:
serial_tensor_id = orphan_node.var().id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id,
None)
if dist_tensor:
dist_tensor.dist_attr.process_mesh = self._process_meshes[0]
else:
serial_tensor_id = orphan_node.var().original_id()
dist_tensor = self._dist_tensors_for_program.get(
serial_tensor_id, None)
dist_tensor.dist_attr.process_mesh = self._process_meshes[0]
def amend_dist_attr_for_program(self):
for dist_tensor in self._dist_tensors_for_program.values():
serial_tensor = dist_tensor.serial_tensor
dist_attr = dist_tensor.dist_attr
if serial_tensor.type == core.VarDesc.VarType.READER:
if serial_tensor.type == core.VarDesc.VarType.READER \
or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = []
else:
tensor_shape = serial_tensor.shape
......@@ -446,6 +497,7 @@ class DistributedContext:
tensor_shape = []
else:
if dist_op.get_serial_input(arg_name).type == core.VarDesc.VarType.READER \
or dist_op.get_serial_input(arg_name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or dist_op.serial_op.type == "create_py_reader":
tensor_shape = []
else:
......@@ -459,8 +511,9 @@ class DistributedContext:
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
for arg_name in serial_op.output_arg_names:
if dist_op.get_serial_output(
arg_name).type == core.VarDesc.VarType.READER:
if dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.READER \
or dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = []
else:
tensor_shape = dist_op.get_serial_output(arg_name).shape
......@@ -498,7 +551,8 @@ class DistributedContext:
for k, v in self.__dict__.items():
if k == "_serial_program" or k == "_serial_graph" \
or k == "_dist_main_programs" or k == "_dist_startup_programs" \
or k == "_serial_ordered_nodes":
or k == "_serial_ordered_nodes" or k == "_serial_ordered_tensor_nodes" \
or k == "_serial_ordered_op_nodes":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
......
......@@ -76,7 +76,8 @@ class DistributedOperator:
if tensor is None:
tensor_shape = []
else:
if tensor.type == core.VarDesc.VarType.READER:
if tensor.type == core.VarDesc.VarType.READER \
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
tensor_shape = []
else:
tensor_shape = tensor.shape
......@@ -86,7 +87,9 @@ class DistributedOperator:
tensor_dims_mapping)
for tensor_name in self._serial_op.output_arg_names:
tensor = self._serial_op.block._var_recursive(tensor_name)
if tensor.type == core.VarDesc.VarType.READER or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
if tensor.type == core.VarDesc.VarType.READER \
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = []
else:
tensor_shape = tensor.shape
......@@ -95,6 +98,8 @@ class DistributedOperator:
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
self._dist_attr.set_output_dims_mapping(tensor_name,
tensor_dims_mapping)
if self._dist_attr.op_type is None:
self._dist_attr.op_type = self.serial_op.type
if self._dist_attr.impl_type is None:
self._dist_attr.impl_type = "default"
if self._dist_attr.impl_idx is None:
......@@ -134,11 +139,15 @@ class DistributedOperator:
return new_dist_attr
def validate_dist_attr(self):
if "read" in self.serial_op.type:
if "read" in self.serial_op.type or "while" == self.serial_op.type:
return True
for name in self.serial_op.input_arg_names:
input_dist_attr = self.dist_attr.get_input_dist_attr(name)
dims_mapping = input_dist_attr.dims_mapping
if self.get_serial_input(
name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
shape = []
else:
shape = self.get_serial_input(name).shape
if len(shape) != len(dims_mapping):
return False
......@@ -155,6 +164,10 @@ class DistributedOperator:
for name in self.serial_op.output_arg_names:
output_dist_attr = self.dist_attr.get_output_dist_attr(name)
dims_mapping = output_dist_attr.dims_mapping
if self.get_serial_output(name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY\
or self.get_serial_output(name).type == core.VarDesc.VarType.STEP_SCOPES:
shape = []
else:
shape = self.get_serial_output(name).shape
if len(shape) != len(dims_mapping):
return False
......@@ -241,14 +254,14 @@ class DistributedModule:
def __call__(self, *args, **kwargs):
from .dist_context import get_default_distributed_context
main_prog = paddle.fluid.default_main_program()
main_block = main_prog.global_block()
op_size = len(main_block.ops)
default_prog = paddle.fluid.default_main_program()
cur_block = default_prog.current_block()
op_size = len(cur_block.ops)
output = self._serial_module(*args, **kwargs)
new_op_size = len(main_block.ops)
new_op_size = len(cur_block.ops)
default_dist_ctx = get_default_distributed_context()
for idx in range(op_size, new_op_size):
op = main_block.ops[idx]
op = cur_block.ops[idx]
dist_op = DistributedOperator(op, self._dist_attr)
dist_op.dist_attr.mark_annotated_as(self._dist_attr)
default_dist_ctx.add_dist_op_for_program(dist_op)
......
......@@ -184,7 +184,9 @@ class DistributedTensor:
def _init_default_dist_attr(self):
if self._dist_attr.dims_mapping is None:
if self.serial_tensor.type == core.VarDesc.VarType.READER:
if self.serial_tensor.type == core.VarDesc.VarType.READER \
or self.serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or self.serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = []
else:
tensor_shape = self._serial_tensor.shape
......@@ -192,7 +194,9 @@ class DistributedTensor:
self._dist_attr.dims_mapping = tensor_dims_mapping
def validate_dist_attr(self):
if self.serial_tensor.type == core.VarDesc.VarType.READER:
if self.serial_tensor.type == core.VarDesc.VarType.READER \
or self.serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or self.serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES:
return True
tensor_shape = self.serial_tensor.shape
if len(tensor_shape) != len(self.dist_attr.dims_mapping):
......
......@@ -259,7 +259,7 @@ class Engine:
"train_" + name: val
for name, val in logs.items()
}
self._logger.info(logs)
self._logger.info(train_logs)
def _train_step(self, data):
logs = {}
......
......@@ -17,7 +17,9 @@ from ..dist_attribute import OperatorDistributedAttribute
_g_distributed_operator_impl_containers = {}
_g_elementwise_ops = ["elementwise_add", "gelu", "dropout", "cast"]
_g_elementwise_ops = [
"elementwise_add", "gelu", "dropout", "cast", "gather", "concat"
]
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}
......
......@@ -55,9 +55,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
op_dist_attr = dist_op.dist_attr
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if serial_tensor.is_parameter:
for mapping in dims_mapping:
if mapping != -1:
return False
# continue
# if len(dims_mapping) < 1:
# continue
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
......@@ -73,9 +78,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
xshape_arg_names = op_desc.output("XShape")
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if serial_tensor.is_parameter:
for mapping in dims_mapping:
if mapping != -1:
return False
# continue
# if len(dims_mapping) < 1:
# continue
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
......@@ -104,6 +114,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
# Check output compatibility
......@@ -121,6 +132,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
else:
if dims_mapping[0] != -1:
......@@ -129,6 +141,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for mapping in dims_mapping[2:]:
if mapping != -1:
return False
if len(dims_mapping) >= 2:
batch_dim_mappings.append(dims_mapping[1])
# Check batch dim mapping compatibility
......@@ -143,7 +156,9 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
# The following statement will be replaced by a more elegant way
if op_desc.type() == "shape" or op_desc.type() == "slice":
if op_desc.type() == "shape" \
or op_desc.type() == "slice" \
or op_desc.type() == "while":
return False
output_names = op_desc.output_names()
xshape_arg_names = []
......@@ -155,6 +170,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
......@@ -162,10 +178,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
else:
batch_dim_mappings.append(dims_mapping[1])
if not batch_dim_mappings:
return changed
compatible_dim_mapping = compute_compatible_dim_mapping(
batch_dim_mappings)
assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
......@@ -174,7 +194,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if compatible_dim_mapping != dims_mapping[0]:
if len(dims_mapping
) >= 1 and compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
for arg_name in op_desc.output_arg_names():
......@@ -183,11 +204,13 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if compatible_dim_mapping != dims_mapping[0]:
if len(dims_mapping
) >= 1 and compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
else:
if compatible_dim_mapping != dims_mapping[1]:
if len(dims_mapping
) >= 2 and compatible_dim_mapping != dims_mapping[1]:
dims_mapping[1] = compatible_dim_mapping
changed = True
......
......@@ -1432,7 +1432,6 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
if is_valid_list_index(y_dims_mapping,
-2) and is_dim_shard(y_dims_mapping[-2]):
return False
return True
def is_output_compatible(self, dist_op):
......
......@@ -1271,7 +1271,6 @@ def get_all_distributed_main_program(serial_program_info, dist_context,
used_dist_context._dist_op_context = DistributedOperatorContext()
_, _, dist_startup_program, dist_main_program, _ = copied_parallelizer._get_dist_program(
rank_id, used_dist_context)
# print("dist_main_program: ", dist_main_program)
all_dist_main_program.append(dist_main_program)
return all_dist_main_program
......
......@@ -9,6 +9,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
set_tests_properties(test_relaunch_with_gpt_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 240)
py_test_modules(test_engine_api MODULES test_engine_api ENVS ${dist_ENVS})
set_tests_properties(test_engine_api PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80)
py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS})
py_test_modules(test_converter MODULES test_converter ENVS ${dist_ENVS})
set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import numpy as np
import paddle.nn as nn
import paddle.utils as utils
import paddle.static as static
import paddle.nn.functional as F
import paddle.distributed.auto_parallel as auto
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import make_data_unshard
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context
from paddle.distributed.auto_parallel.operators import find_best_compatible_distributed_operator_impl
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
# The programs in this test are built with the static-graph API.
paddle.enable_static()
# Leading dimension of the fed data; also used as the number of batches the
# reader yields and to size the data loader capacity.
batch_size = 4
# Fill value of the `loop_len` tensor that bounds the while loop.
epoch_num = 10
# Last dimension of the input data and hidden size of every MLPLayer.
hidden_size = 1024
# Middle dimension of the input and label data.
sequence_len = 512
# Two process meshes; index 0 and index 1 are used in the shard_tensor
# annotations below.
_g_process_mesh = [[0, 1], [2, 3]]
def get_random_inputs_and_labels(input_shape, label_shape):
    """Create one random (input, label) pair.

    Args:
        input_shape: Shape of the input array to generate.
        label_shape: Shape of the label array to generate.

    Returns:
        tuple: ``(input, label)`` float32 numpy arrays drawn with
        ``np.random.random``; the input is sampled before the label.
    """
    features = np.random.random(size=input_shape).astype('float32')
    targets = np.random.random(size=label_shape).astype('float32')
    return features, targets
def batch_generator_creator():
    """Return a generator function that produces random training batches.

    The returned callable yields ``batch_size`` tuples of
    ``(batch_input, batch_label)``; inputs have shape
    ``[batch_size, sequence_len, hidden_size]`` and labels have shape
    ``[batch_size, sequence_len, 1]``.
    """

    def __reader__():
        for _ in range(batch_size):
            input_shape = [batch_size, sequence_len, hidden_size]
            label_shape = [batch_size, sequence_len, 1]
            sample = get_random_inputs_and_labels(input_shape, label_shape)
            yield sample

    return __reader__
class MLPLayer(nn.Layer):
    """LayerNorm -> Linear -> GELU -> Linear block with sharding annotations.

    The weights of both linear layers are annotated with
    ``auto.shard_tensor`` inside ``forward``: the first on
    ``_g_process_mesh[0]`` with dims_mapping ``[-1, 0]`` and the second on
    ``_g_process_mesh[1]`` with dims_mapping ``[0, -1]``.
    """

    def __init__(self,
                 hidden_size=1024,
                 intermediate_size=4 * 1024,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        """Create the sub-layers.

        Args:
            hidden_size: Input/output feature size of the block.
            intermediate_size: Output size of the first linear layer.
            dropout_ratio: Accepted but not used by this layer.
            initializer_range: Std of the Normal weight initializer.
        """
        super().__init__()
        # A single initializer instance is shared by both linear layers.
        weight_init = nn.initializer.Normal(mean=0.0, std=initializer_range)

        self.norm = nn.LayerNorm(hidden_size, epsilon=1e-5)
        self.linear0 = nn.Linear(
            hidden_size,
            intermediate_size,
            weight_attr=paddle.ParamAttr(initializer=weight_init),
            bias_attr=None)
        self.linear1 = nn.Linear(
            intermediate_size,
            hidden_size,
            weight_attr=paddle.ParamAttr(initializer=weight_init),
            bias_attr=None)

    def forward(self, input):
        """Apply the block to ``input`` and return the second linear output."""
        normed = self.norm(input)

        # Annotate the first weight before the op that consumes it is built.
        auto.shard_tensor(
            self.linear0.weight,
            dist_attr={
                "process_mesh": _g_process_mesh[0],
                "dims_mapping": [-1, 0]
            })
        hidden = self.linear0(normed)
        activated = F.gelu(hidden, approximate=True)

        # Annotate the second weight before the op that consumes it is built.
        auto.shard_tensor(
            self.linear1.weight,
            dist_attr={
                "process_mesh": _g_process_mesh[1],
                "dims_mapping": [0, -1]
            })
        return self.linear1(activated)
def loop_cond(i, loop_len, input_array):
    """While-loop predicate: continue while ``i`` is less than ``loop_len``.

    ``input_array`` is not read here; it is accepted so that the signature
    matches the list of loop variables.
    """
    keep_going = i < loop_len
    return keep_going
def loop_body(i, loop_len, input_array):
    """Body of the while loop.

    Reads the entry of ``input_array`` at index ``i``, passes it through two
    freshly constructed ``MLPLayer`` instances, increments ``i`` and writes
    the prediction back into ``input_array`` at the new index.

    Returns:
        tuple: ``(i, loop_len, input_array)``, the updated loop variables.
    """
    pre_input = paddle.tensor.array_read(array=input_array, i=i)

    mlp_config = dict(
        hidden_size=hidden_size,
        intermediate_size=4 * hidden_size,
        dropout_ratio=0.1,
        initializer_range=0.02)
    # Both layers are created before either one is applied.
    first_mlp = MLPLayer(**mlp_config)
    second_mlp = MLPLayer(**mlp_config)

    intermediate = first_mlp(pre_input)
    cur_pred = second_mlp(intermediate)

    # Update the loop condition variable.
    i = paddle.increment(x=i, value=1)
    paddle.tensor.array_write(cur_pred, array=input_array, i=i)
    return i, loop_len, input_array
def get_program():
    """Build the static programs for the while_op completion test.

    The main program applies an ``MLPLayer`` to the input, writes the result
    into an array, runs ``static.nn.while_loop`` with ``loop_cond`` and
    ``loop_body`` over that array, reads the array at the final index,
    applies another ``MLPLayer`` and computes a mean squared-error loss
    against the label.

    Returns:
        tuple: ``(train_program, start_program, dataloader, i, loss)``.
    """
    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.semi_auto = True
    # fleet.init(is_collective=True, strategy=dist_strategy)

    train_program = static.Program()
    start_program = static.Program()

    with static.program_guard(train_program, start_program):
        # Loop counter.
        i = paddle.full(shape=[1], fill_value=0, dtype='int64')
        # Number of loop iterations.
        loop_len = paddle.full(shape=[1], fill_value=epoch_num, dtype='int64')

        # Feed variables.
        feed_input = static.data(
            name="input",
            shape=[batch_size, sequence_len, hidden_size],
            dtype='float32')
        feed_label = static.data(
            name="label", shape=[batch_size, sequence_len, 1], dtype='float32')
        data_holder = [feed_input, feed_label]

        # Data loader fed by the random batch generator.
        dataloader = paddle.io.DataLoader.from_generator(
            feed_list=data_holder, capacity=4 * batch_size, iterable=False)
        dataloader.set_batch_generator(
            batch_generator_creator(), places=paddle.static.cuda_places())

        # Distributed attributes of the feed variables.
        auto.shard_tensor(
            feed_input,
            dist_attr={
                "process_mesh": _g_process_mesh[0],
                "dims_mapping": [-1, -1, -1]
            })
        auto.shard_tensor(
            feed_label,
            dist_attr={
                "process_mesh": _g_process_mesh[0],
                "dims_mapping": [-1, -1, -1]
            })

        mlp_config = dict(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02)

        # Part before the loop: the first prediction is written into the array.
        mlp_start = MLPLayer(**mlp_config)
        pred = mlp_start(feed_input)
        input_array = paddle.tensor.array_write(pred, i)

        # The while loop itself.
        i, loop_len, input_array = static.nn.while_loop(
            cond=loop_cond,
            body=loop_body,
            loop_vars=[i, loop_len, input_array])

        # Part after the loop: read at the final index and compute the loss.
        end_pred = paddle.tensor.array_read(array=input_array, i=i)
        mlp_end = MLPLayer(**mlp_config)
        pred = mlp_end(end_pred)

        error_cost = F.square_error_cost(pred, feed_label)
        loss = paddle.mean(error_cost)

    return train_program, start_program, dataloader, i, loss
class TestMLP(unittest.TestCase):
    """Runs forward-annotation completion on the program with a while loop."""

    def test_completer(self):
        """Complete the training program; the test passes if nothing raises."""
        (train_program, _start_program, _dataloader, _counter,
         _loss) = get_program()

        completer = Completer(DistributedContext())
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        # print_program_with_dist_attr(complete_train_program, dist_context)
# Run the unittest runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册