未验证 提交 ec6b8fbd 编写于 作者: Y Yulong Ao 提交者: GitHub

[Auto Parallel] Add the support for the auto completion of while_op (#39939)

* [Auto Parallel] Support the auto completion of while_op

* [Auto Parallel] Improve the completion algorithms

* [Auto Parallel] Fix bugs for ernie inference

* [Auto Parallel] Remove attrs which cannot be pickled

* [Auto Parallel] make the dims_mappings of LodTensorArray vars empty

* [Auto Parallel] Fix bugs for the ernie inference in the pipeline parallel

* [Auto Parallel] Remove unncessary comments

* [Auto Parallel] Fix a bug of the CMakeLists

* [Auto Parallel] Use the newest APIs to write the unit test

* [Auto Parallel] Remove unnecessary statements
上级 31858263
...@@ -95,6 +95,7 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock( ...@@ -95,6 +95,7 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
std::unordered_map<std::string, std::pair<VarDesc *, int>> std::unordered_map<std::string, std::pair<VarDesc *, int>>
name_to_desc_block_id; name_to_desc_block_id;
block_id_ = block.ID();
const BlockDesc *block_var_visible = &block; const BlockDesc *block_var_visible = &block;
while (block_var_visible != nullptr) { while (block_var_visible != nullptr) {
for (auto *var : block_var_visible->AllVars()) { for (auto *var : block_var_visible->AllVars()) {
......
...@@ -230,6 +230,7 @@ class Graph { ...@@ -230,6 +230,7 @@ class Graph {
auto *x = auto *x =
AddNode(new ir::Node(var_desc, block_id == -1 ? block_id_ : block_id)); AddNode(new ir::Node(var_desc, block_id == -1 ? block_id_ : block_id));
x->SetId(num_node_created_++); x->SetId(num_node_created_++);
x->SetGraphId(block_id_);
return x; return x;
} }
...@@ -245,6 +246,7 @@ class Graph { ...@@ -245,6 +246,7 @@ class Graph {
"The OpDesc used to create operator node is null.")); "The OpDesc used to create operator node is null."));
auto *x = AddNode(new ir::Node(op_desc)); auto *x = AddNode(new ir::Node(op_desc));
x->SetId(num_node_created_++); x->SetId(num_node_created_++);
x->SetGraphId(block_id_);
return x; return x;
} }
...@@ -263,6 +265,7 @@ class Graph { ...@@ -263,6 +265,7 @@ class Graph {
num_node_created_); num_node_created_);
auto *x = AddNode(new ir::Node(name, ir::Node::Type::kVariable, block_id_)); auto *x = AddNode(new ir::Node(name, ir::Node::Type::kVariable, block_id_));
x->SetId(num_node_created_++); x->SetId(num_node_created_++);
x->SetGraphId(block_id_);
return x; return x;
} }
...@@ -276,6 +279,7 @@ class Graph { ...@@ -276,6 +279,7 @@ class Graph {
} }
auto *x = AddNode(new ir::Node(name, type, block_id_)); auto *x = AddNode(new ir::Node(name, type, block_id_));
x->SetId(num_node_created_++); x->SetId(num_node_created_++);
x->SetGraphId(block_id_);
return x; return x;
} }
......
...@@ -125,6 +125,7 @@ class Node { ...@@ -125,6 +125,7 @@ class Node {
// Only use this for auto parallel. // Only use this for auto parallel.
// A node does not have original desc if the return is zero. // A node does not have original desc if the return is zero.
uint64_t OriginalDescId() const { return original_desc_id_; } uint64_t OriginalDescId() const { return original_desc_id_; }
int GraphId() const { return graph_id_; }
bool IsOp() const { return type_ == Type::kOperation; } bool IsOp() const { return type_ == Type::kOperation; }
bool IsVar() const { return type_ == Type::kVariable; } bool IsVar() const { return type_ == Type::kVariable; }
...@@ -246,10 +247,12 @@ class Node { ...@@ -246,10 +247,12 @@ class Node {
// Store the original id of var desc or op desc. // Store the original id of var desc or op desc.
// Only use this for auto parallel. // Only use this for auto parallel.
uint64_t original_desc_id_{0}; uint64_t original_desc_id_{0};
int graph_id_{-1};
private: private:
// ID can only set by a Graph. // ID can only set by a Graph.
void SetId(int id) { id_ = id; } void SetId(int id) { id_ = id; }
void SetGraphId(int graph_id) { graph_id_ = graph_id; }
// desc_order can only set by a Graph when constructing a Graph from a // desc_order can only set by a Graph when constructing a Graph from a
// BlockDesc. // BlockDesc.
......
...@@ -143,6 +143,7 @@ void BindNode(py::module *m) { ...@@ -143,6 +143,7 @@ void BindNode(py::module *m) {
.def("var", &Node::Var, return_value_policy::reference) .def("var", &Node::Var, return_value_policy::reference)
.def("op", &Node::Op, return_value_policy::reference) .def("op", &Node::Op, return_value_policy::reference)
.def("id", &Node::id) .def("id", &Node::id)
.def("graph_id", &Node::GraphId)
.def("original_desc_id", &Node::OriginalDescId) .def("original_desc_id", &Node::OriginalDescId)
.def("is_op", &Node::IsOp) .def("is_op", &Node::IsOp)
.def("is_var", &Node::IsVar) .def("is_var", &Node::IsVar)
......
...@@ -175,6 +175,7 @@ class TensorDistributedAttribute: ...@@ -175,6 +175,7 @@ class TensorDistributedAttribute:
class OperatorDistributedAttribute: class OperatorDistributedAttribute:
def __init__(self): def __init__(self):
self._process_mesh = None self._process_mesh = None
self._op_type = None
self._impl_type = None self._impl_type = None
self._impl_idx = None self._impl_idx = None
self._inputs_dist_attrs = {} self._inputs_dist_attrs = {}
...@@ -194,11 +195,23 @@ class OperatorDistributedAttribute: ...@@ -194,11 +195,23 @@ class OperatorDistributedAttribute:
if isinstance(process_mesh, list): if isinstance(process_mesh, list):
process_mesh = ProcessMesh(process_mesh) process_mesh = ProcessMesh(process_mesh)
self._process_mesh = copy.deepcopy(process_mesh) self._process_mesh = copy.deepcopy(process_mesh)
# In while op, the proess mesh is not shared by all inputs and outputs
if self._op_type == "while":
return None
for dist_attr in self._inputs_dist_attrs.values(): for dist_attr in self._inputs_dist_attrs.values():
dist_attr.process_mesh = process_mesh dist_attr.process_mesh = process_mesh
for dist_attr in self._outputs_dist_attrs.values(): for dist_attr in self._outputs_dist_attrs.values():
dist_attr.process_mesh = process_mesh dist_attr.process_mesh = process_mesh
@property
def op_type(self):
return self._op_type
@op_type.setter
def op_type(self, op_type):
if op_type is not None:
self._op_type = op_type
@property @property
def impl_type(self): def impl_type(self):
return self._impl_type return self._impl_type
...@@ -326,6 +339,8 @@ class OperatorDistributedAttribute: ...@@ -326,6 +339,8 @@ class OperatorDistributedAttribute:
assert False, "No setter for {} in args {}.".format( assert False, "No setter for {} in args {}.".format(
key, dist_attr) key, dist_attr)
# Make sure proscess_meshes in dist op be same # Make sure proscess_meshes in dist op be same
if self.op_type == "while":
return None
process_meshes = [] process_meshes = []
process_meshes.append(self.process_mesh) process_meshes.append(self.process_mesh)
for tensor_dist_attr in self.inputs_dist_attrs.values(): for tensor_dist_attr in self.inputs_dist_attrs.values():
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import copy import copy
from collections import defaultdict from collections import defaultdict
from paddle.fluid import framework from paddle.fluid import framework
from paddle.fluid.framework import get_flags, set_flags
from paddle.fluid import core from paddle.fluid import core
from .dist_attribute import TensorDistributedAttribute from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute from .dist_attribute import OperatorDistributedAttribute
...@@ -39,6 +40,10 @@ def set_default_distributed_context(dist_context): ...@@ -39,6 +40,10 @@ def set_default_distributed_context(dist_context):
_g_default_distributed_context = dist_context _g_default_distributed_context = dist_context
def _node_id(node):
return (node.node.graph_id(), node.node.id())
class DistributedContext: class DistributedContext:
""" """
DistributedContext is used to collect related distributed information for program and graph. DistributedContext is used to collect related distributed information for program and graph.
...@@ -146,7 +151,7 @@ class DistributedContext: ...@@ -146,7 +151,7 @@ class DistributedContext:
return None return None
def get_dist_tensor_for_graph(self, serial_tensor_node): def get_dist_tensor_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id() serial_tensor_node_id = _node_id(serial_tensor_node)
return self._dist_tensors_for_graph.get(serial_tensor_node_id, None) return self._dist_tensors_for_graph.get(serial_tensor_node_id, None)
def get_dist_op_for_program(self, serial_op): def get_dist_op_for_program(self, serial_op):
...@@ -168,7 +173,7 @@ class DistributedContext: ...@@ -168,7 +173,7 @@ class DistributedContext:
del self._dist_ops_for_program[serial_tensor_id] del self._dist_ops_for_program[serial_tensor_id]
def get_dist_op_for_graph(self, serial_op_node): def get_dist_op_for_graph(self, serial_op_node):
serial_op_node_id = serial_op_node.id() serial_op_node_id = _node_id(serial_op_node)
return self._dist_ops_for_graph.get(serial_op_node_id, None) return self._dist_ops_for_graph.get(serial_op_node_id, None)
def get_tensor_dist_attr_for_program(self, serial_tensor): def get_tensor_dist_attr_for_program(self, serial_tensor):
...@@ -197,7 +202,7 @@ class DistributedContext: ...@@ -197,7 +202,7 @@ class DistributedContext:
self.add_dist_tensor_for_program(dist_tensor) self.add_dist_tensor_for_program(dist_tensor)
def get_tensor_dist_attr_for_graph(self, serial_tensor_node): def get_tensor_dist_attr_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id() serial_tensor_node_id = _node_id(serial_tensor_node)
dist_tensor = self._dist_tensors_for_graph.get(serial_tensor_node_id, dist_tensor = self._dist_tensors_for_graph.get(serial_tensor_node_id,
None) None)
if dist_tensor: if dist_tensor:
...@@ -242,7 +247,7 @@ class DistributedContext: ...@@ -242,7 +247,7 @@ class DistributedContext:
self.add_dist_op_for_program(dist_op) self.add_dist_op_for_program(dist_op)
def get_op_dist_attr_for_graph(self, serial_op_node): def get_op_dist_attr_for_graph(self, serial_op_node):
serial_op_node_id = serial_op_node.id() serial_op_node_id = _node_id(serial_op_node)
dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None) dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
if dist_op: if dist_op:
return dist_op.dist_attr return dist_op.dist_attr
...@@ -262,7 +267,7 @@ class DistributedContext: ...@@ -262,7 +267,7 @@ class DistributedContext:
def get_dist_attr_for_graph(self, serial_node): def get_dist_attr_for_graph(self, serial_node):
if serial_node.is_var() and serial_node.var() is not None: if serial_node.is_var() and serial_node.var() is not None:
serial_tensor_node_id = serial_node.id() serial_tensor_node_id = _node_id(serial_node)
dist_tensor = self._dist_tensors_for_graph.get( dist_tensor = self._dist_tensors_for_graph.get(
serial_tensor_node_id, None) serial_tensor_node_id, None)
if dist_tensor: if dist_tensor:
...@@ -270,7 +275,7 @@ class DistributedContext: ...@@ -270,7 +275,7 @@ class DistributedContext:
else: else:
return None return None
if serial_node.is_op() and serial_node.op() is not None: if serial_node.is_op() and serial_node.op() is not None:
serial_op_node_id = serial_node.id() serial_op_node_id = _node_id(serial_node)
dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None) dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
if dist_op: if dist_op:
return dist_op.dist_attr return dist_op.dist_attr
...@@ -311,40 +316,69 @@ class DistributedContext: ...@@ -311,40 +316,69 @@ class DistributedContext:
def order_nodes_by_program_order(self): def order_nodes_by_program_order(self):
def _contains(nodes, target_node): def _contains(nodes, target_node):
for node in nodes: for node in nodes:
if node.id() == target_node.id(): if _node_id(node) == _node_id(target_node):
return True return True
return False return False
ordered_tensor_nodes = [] serial_ordered_tensor_nodes = []
ordered_op_nodes = [] serial_ordered_op_nodes = []
all_nodes = self._serial_graph.all_nodes() all_nodes = []
# for idx, graph in enumerate(self._serial_graph.all_sub_graphs()):
for idx, graph in enumerate(self._serial_graph.all_sub_graphs()):
for node in graph.all_nodes():
all_nodes.append(node)
for node in all_nodes: for node in all_nodes:
if node.is_var() and node.var() is not None: if node.is_var() and node.var() is not None:
ordered_tensor_nodes.append(node) serial_ordered_tensor_nodes.append(node)
if node.is_op() and node.op() is not None: if node.is_op() and node.op() is not None:
ordered_op_nodes.append(node) serial_ordered_op_nodes.append(node)
ordered_tensor_nodes.sort(key=lambda node: node.node.original_desc_id()) serial_ordered_tensor_nodes.sort(
ordered_op_nodes.sort(key=lambda node: node.node.original_desc_id()) key=lambda node: node.node.original_desc_id())
for op_node in ordered_op_nodes: serial_ordered_op_nodes.sort(
key=lambda node: node.node.original_desc_id())
num_nodes_before = len(serial_ordered_tensor_nodes) + len(
serial_ordered_op_nodes)
new_serial_ordered_tensor_nodes = []
new_serial_ordered_op_nodes = []
for op_node in serial_ordered_op_nodes:
tensor_nodes = [] tensor_nodes = []
for tensor_node in op_node.inputs: for tensor_node in op_node.inputs:
if tensor_node.is_var() \ if tensor_node.is_var() \
and tensor_node.var() is not None \ and tensor_node.var() is not None \
and not _contains(self._serial_ordered_nodes, tensor_node): and not _contains(self._serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node) tensor_nodes.append(tensor_node)
new_serial_ordered_tensor_nodes.append(tensor_node)
tensor_nodes.sort(key=lambda node: node.node.original_desc_id()) tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
self._serial_ordered_nodes.extend(tensor_nodes) self._serial_ordered_nodes.extend(tensor_nodes)
self._serial_ordered_nodes.append(op_node) self._serial_ordered_nodes.append(op_node)
new_serial_ordered_op_nodes.append(op_node)
tensor_nodes = [] tensor_nodes = []
for tensor_node in op_node.outputs: for tensor_node in op_node.outputs:
if tensor_node.is_var() \ if tensor_node.is_var() \
and tensor_node.var() is not None \ and tensor_node.var() is not None \
and not _contains(self._serial_ordered_nodes, tensor_node): and not _contains(self._serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node) tensor_nodes.append(tensor_node)
new_serial_ordered_tensor_nodes.append(tensor_node)
tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
self._serial_ordered_nodes.extend(tensor_nodes) self._serial_ordered_nodes.extend(tensor_nodes)
num_nodes_before = len(ordered_tensor_nodes) + len(ordered_op_nodes) new_serial_ordered_tensor_nodes.sort(
assert len(self._serial_ordered_nodes) == num_nodes_before, \ key=lambda node: node.node.original_desc_id())
"The number of nodes before ordering is not the same after ordering." new_serial_ordered_op_nodes.sort(
key=lambda node: node.node.original_desc_id())
self._serial_ordered_tensor_nodes = new_serial_ordered_tensor_nodes
self._serial_ordered_op_nodes = new_serial_ordered_op_nodes
assert len(self._serial_ordered_nodes) == len(
self._serial_ordered_tensor_nodes) + len(
self._serial_ordered_op_nodes)
self._serial_orphan_tensor_nodes = []
for tensor_node in serial_ordered_tensor_nodes:
if not _contains(self._serial_ordered_tensor_nodes, tensor_node):
self._serial_orphan_tensor_nodes.append(tensor_node)
if len(self._serial_ordered_nodes) != num_nodes_before:
print(
"WARNING: there are some orphan tensors or ops which are not used in the execution."
)
def init_dist_attr_for_graph(self): def init_dist_attr_for_graph(self):
assert self._is_initialized_for_program, \ assert self._is_initialized_for_program, \
...@@ -352,9 +386,9 @@ class DistributedContext: ...@@ -352,9 +386,9 @@ class DistributedContext:
if self._is_initialized_for_graph: if self._is_initialized_for_graph:
return return
# Convert program to graph # Convert program to graph
set_flags({"FLAGS_convert_all_blocks": True})
self._serial_graph = framework.IrGraph( self._serial_graph = framework.IrGraph(
core.Graph(self._serial_program.desc)) core.Graph(self._serial_program.desc))
all_nodes = self._serial_graph.all_nodes()
self.order_nodes_by_program_order() self.order_nodes_by_program_order()
for node in self.serial_ordered_nodes: for node in self.serial_ordered_nodes:
if node.is_var() and node.var() is not None: if node.is_var() and node.var() is not None:
...@@ -365,10 +399,11 @@ class DistributedContext: ...@@ -365,10 +399,11 @@ class DistributedContext:
if tensor_id == cur_tensor_id \ if tensor_id == cur_tensor_id \
or tensor_id == cur_dist_tensor.serial_tensor.desc.original_id(): or tensor_id == cur_dist_tensor.serial_tensor.desc.original_id():
dist_tensor = cur_dist_tensor dist_tensor = cur_dist_tensor
self._node_id_to_tensor_id[node.id()] = cur_tensor_id self._node_id_to_tensor_id[_node_id(
node)] = cur_tensor_id
assert dist_tensor is not None, \ assert dist_tensor is not None, \
"Tensor must have a distributed tensor after the initialization for program." "Tensor must have a distributed tensor after the initialization for program."
serial_tensor_node_id = node.id() serial_tensor_node_id = _node_id(node)
new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor, new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
dist_tensor.dist_attr) dist_tensor.dist_attr)
self._dist_tensors_for_graph[ self._dist_tensors_for_graph[
...@@ -381,10 +416,10 @@ class DistributedContext: ...@@ -381,10 +416,10 @@ class DistributedContext:
if op_id == cur_op_id \ if op_id == cur_op_id \
or op_id == cur_dist_op.serial_op.desc.original_id(): or op_id == cur_dist_op.serial_op.desc.original_id():
dist_op = cur_dist_op dist_op = cur_dist_op
self._node_id_to_op_id[node.id()] = cur_op_id self._node_id_to_op_id[_node_id(node)] = cur_op_id
assert dist_op is not None, \ assert dist_op is not None, \
"Operator must have a distributed operator after the initialization for program." "Operator must have a distributed operator after the initialization for program."
serial_op_node_id = node.id() serial_op_node_id = _node_id(node)
new_dist_op = DistributedOperator(dist_op.serial_op, new_dist_op = DistributedOperator(dist_op.serial_op,
dist_op.dist_attr) dist_op.dist_attr)
self._dist_ops_for_graph[serial_op_node_id] = new_dist_op self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
...@@ -402,10 +437,11 @@ class DistributedContext: ...@@ -402,10 +437,11 @@ class DistributedContext:
assert self._is_initialized_for_program and self._is_initialized_for_graph, \ assert self._is_initialized_for_program and self._is_initialized_for_graph, \
"Both program and graph must be initialized." "Both program and graph must be initialized."
updated_tensors = {} updated_tensors = {}
all_nodes = self._serial_graph.all_nodes() # all_nodes = self._serial_graph.all_nodes()
all_nodes = self._serial_ordered_nodes
for node in all_nodes: for node in all_nodes:
if node.is_var() and node.var() is not None: if node.is_var() and node.var() is not None:
tensor_id = self._node_id_to_tensor_id[node.id()] tensor_id = self._node_id_to_tensor_id[_node_id(node)]
updated = updated_tensors.get(tensor_id, False) updated = updated_tensors.get(tensor_id, False)
# If a var has multiples var nodes in graph, only use the first one for now # If a var has multiples var nodes in graph, only use the first one for now
if not updated: if not updated:
...@@ -416,16 +452,31 @@ class DistributedContext: ...@@ -416,16 +452,31 @@ class DistributedContext:
dist_tensor_for_program.dist_attr = tensor_dist_attr_for_graph dist_tensor_for_program.dist_attr = tensor_dist_attr_for_graph
updated_tensors[tensor_id] = True updated_tensors[tensor_id] = True
if node.is_op() and node.op() is not None: if node.is_op() and node.op() is not None:
op_id = self._node_id_to_op_id[node.id()] op_id = self._node_id_to_op_id[_node_id(node)]
op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node) op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node)
dist_op_for_program = self._dist_ops_for_program[op_id] dist_op_for_program = self._dist_ops_for_program[op_id]
dist_op_for_program.dist_attr = op_dist_attr_for_graph dist_op_for_program.dist_attr = op_dist_attr_for_graph
# TODO: the completion algorithm will skip orphan tensors,
# here we just set there process_mesh to the first one.
for orphan_node in self._serial_orphan_tensor_nodes:
serial_tensor_id = orphan_node.var().id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id,
None)
if dist_tensor:
dist_tensor.dist_attr.process_mesh = self._process_meshes[0]
else:
serial_tensor_id = orphan_node.var().original_id()
dist_tensor = self._dist_tensors_for_program.get(
serial_tensor_id, None)
dist_tensor.dist_attr.process_mesh = self._process_meshes[0]
def amend_dist_attr_for_program(self): def amend_dist_attr_for_program(self):
for dist_tensor in self._dist_tensors_for_program.values(): for dist_tensor in self._dist_tensors_for_program.values():
serial_tensor = dist_tensor.serial_tensor serial_tensor = dist_tensor.serial_tensor
dist_attr = dist_tensor.dist_attr dist_attr = dist_tensor.dist_attr
if serial_tensor.type == core.VarDesc.VarType.READER: if serial_tensor.type == core.VarDesc.VarType.READER \
or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = [] tensor_shape = []
else: else:
tensor_shape = serial_tensor.shape tensor_shape = serial_tensor.shape
...@@ -446,6 +497,7 @@ class DistributedContext: ...@@ -446,6 +497,7 @@ class DistributedContext:
tensor_shape = [] tensor_shape = []
else: else:
if dist_op.get_serial_input(arg_name).type == core.VarDesc.VarType.READER \ if dist_op.get_serial_input(arg_name).type == core.VarDesc.VarType.READER \
or dist_op.get_serial_input(arg_name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or dist_op.serial_op.type == "create_py_reader": or dist_op.serial_op.type == "create_py_reader":
tensor_shape = [] tensor_shape = []
else: else:
...@@ -459,8 +511,9 @@ class DistributedContext: ...@@ -459,8 +511,9 @@ class DistributedContext:
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1 dims_mapping[i] = -1
for arg_name in serial_op.output_arg_names: for arg_name in serial_op.output_arg_names:
if dist_op.get_serial_output( if dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.READER \
arg_name).type == core.VarDesc.VarType.READER: or dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = [] tensor_shape = []
else: else:
tensor_shape = dist_op.get_serial_output(arg_name).shape tensor_shape = dist_op.get_serial_output(arg_name).shape
...@@ -498,7 +551,8 @@ class DistributedContext: ...@@ -498,7 +551,8 @@ class DistributedContext:
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
if k == "_serial_program" or k == "_serial_graph" \ if k == "_serial_program" or k == "_serial_graph" \
or k == "_dist_main_programs" or k == "_dist_startup_programs" \ or k == "_dist_main_programs" or k == "_dist_startup_programs" \
or k == "_serial_ordered_nodes": or k == "_serial_ordered_nodes" or k == "_serial_ordered_tensor_nodes" \
or k == "_serial_ordered_op_nodes":
setattr(result, k, v) setattr(result, k, v)
else: else:
setattr(result, k, copy.deepcopy(v, memo)) setattr(result, k, copy.deepcopy(v, memo))
......
...@@ -76,7 +76,8 @@ class DistributedOperator: ...@@ -76,7 +76,8 @@ class DistributedOperator:
if tensor is None: if tensor is None:
tensor_shape = [] tensor_shape = []
else: else:
if tensor.type == core.VarDesc.VarType.READER: if tensor.type == core.VarDesc.VarType.READER \
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
tensor_shape = [] tensor_shape = []
else: else:
tensor_shape = tensor.shape tensor_shape = tensor.shape
...@@ -86,7 +87,9 @@ class DistributedOperator: ...@@ -86,7 +87,9 @@ class DistributedOperator:
tensor_dims_mapping) tensor_dims_mapping)
for tensor_name in self._serial_op.output_arg_names: for tensor_name in self._serial_op.output_arg_names:
tensor = self._serial_op.block._var_recursive(tensor_name) tensor = self._serial_op.block._var_recursive(tensor_name)
if tensor.type == core.VarDesc.VarType.READER or tensor.type == core.VarDesc.VarType.STEP_SCOPES: if tensor.type == core.VarDesc.VarType.READER \
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = [] tensor_shape = []
else: else:
tensor_shape = tensor.shape tensor_shape = tensor.shape
...@@ -95,6 +98,8 @@ class DistributedOperator: ...@@ -95,6 +98,8 @@ class DistributedOperator:
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))] tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
self._dist_attr.set_output_dims_mapping(tensor_name, self._dist_attr.set_output_dims_mapping(tensor_name,
tensor_dims_mapping) tensor_dims_mapping)
if self._dist_attr.op_type is None:
self._dist_attr.op_type = self.serial_op.type
if self._dist_attr.impl_type is None: if self._dist_attr.impl_type is None:
self._dist_attr.impl_type = "default" self._dist_attr.impl_type = "default"
if self._dist_attr.impl_idx is None: if self._dist_attr.impl_idx is None:
...@@ -134,12 +139,16 @@ class DistributedOperator: ...@@ -134,12 +139,16 @@ class DistributedOperator:
return new_dist_attr return new_dist_attr
def validate_dist_attr(self): def validate_dist_attr(self):
if "read" in self.serial_op.type: if "read" in self.serial_op.type or "while" == self.serial_op.type:
return True return True
for name in self.serial_op.input_arg_names: for name in self.serial_op.input_arg_names:
input_dist_attr = self.dist_attr.get_input_dist_attr(name) input_dist_attr = self.dist_attr.get_input_dist_attr(name)
dims_mapping = input_dist_attr.dims_mapping dims_mapping = input_dist_attr.dims_mapping
shape = self.get_serial_input(name).shape if self.get_serial_input(
name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
shape = []
else:
shape = self.get_serial_input(name).shape
if len(shape) != len(dims_mapping): if len(shape) != len(dims_mapping):
return False return False
for i in range(len(dims_mapping)): for i in range(len(dims_mapping)):
...@@ -155,7 +164,11 @@ class DistributedOperator: ...@@ -155,7 +164,11 @@ class DistributedOperator:
for name in self.serial_op.output_arg_names: for name in self.serial_op.output_arg_names:
output_dist_attr = self.dist_attr.get_output_dist_attr(name) output_dist_attr = self.dist_attr.get_output_dist_attr(name)
dims_mapping = output_dist_attr.dims_mapping dims_mapping = output_dist_attr.dims_mapping
shape = self.get_serial_output(name).shape if self.get_serial_output(name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY\
or self.get_serial_output(name).type == core.VarDesc.VarType.STEP_SCOPES:
shape = []
else:
shape = self.get_serial_output(name).shape
if len(shape) != len(dims_mapping): if len(shape) != len(dims_mapping):
return False return False
for i in range(len(dims_mapping)): for i in range(len(dims_mapping)):
...@@ -241,14 +254,14 @@ class DistributedModule: ...@@ -241,14 +254,14 @@ class DistributedModule:
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
from .dist_context import get_default_distributed_context from .dist_context import get_default_distributed_context
main_prog = paddle.fluid.default_main_program() default_prog = paddle.fluid.default_main_program()
main_block = main_prog.global_block() cur_block = default_prog.current_block()
op_size = len(main_block.ops) op_size = len(cur_block.ops)
output = self._serial_module(*args, **kwargs) output = self._serial_module(*args, **kwargs)
new_op_size = len(main_block.ops) new_op_size = len(cur_block.ops)
default_dist_ctx = get_default_distributed_context() default_dist_ctx = get_default_distributed_context()
for idx in range(op_size, new_op_size): for idx in range(op_size, new_op_size):
op = main_block.ops[idx] op = cur_block.ops[idx]
dist_op = DistributedOperator(op, self._dist_attr) dist_op = DistributedOperator(op, self._dist_attr)
dist_op.dist_attr.mark_annotated_as(self._dist_attr) dist_op.dist_attr.mark_annotated_as(self._dist_attr)
default_dist_ctx.add_dist_op_for_program(dist_op) default_dist_ctx.add_dist_op_for_program(dist_op)
......
...@@ -184,7 +184,9 @@ class DistributedTensor: ...@@ -184,7 +184,9 @@ class DistributedTensor:
def _init_default_dist_attr(self): def _init_default_dist_attr(self):
if self._dist_attr.dims_mapping is None: if self._dist_attr.dims_mapping is None:
if self.serial_tensor.type == core.VarDesc.VarType.READER: if self.serial_tensor.type == core.VarDesc.VarType.READER \
or self.serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or self.serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = [] tensor_shape = []
else: else:
tensor_shape = self._serial_tensor.shape tensor_shape = self._serial_tensor.shape
...@@ -192,7 +194,9 @@ class DistributedTensor: ...@@ -192,7 +194,9 @@ class DistributedTensor:
self._dist_attr.dims_mapping = tensor_dims_mapping self._dist_attr.dims_mapping = tensor_dims_mapping
def validate_dist_attr(self): def validate_dist_attr(self):
if self.serial_tensor.type == core.VarDesc.VarType.READER: if self.serial_tensor.type == core.VarDesc.VarType.READER \
or self.serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or self.serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES:
return True return True
tensor_shape = self.serial_tensor.shape tensor_shape = self.serial_tensor.shape
if len(tensor_shape) != len(self.dist_attr.dims_mapping): if len(tensor_shape) != len(self.dist_attr.dims_mapping):
......
...@@ -259,7 +259,7 @@ class Engine: ...@@ -259,7 +259,7 @@ class Engine:
"train_" + name: val "train_" + name: val
for name, val in logs.items() for name, val in logs.items()
} }
self._logger.info(logs) self._logger.info(train_logs)
def _train_step(self, data): def _train_step(self, data):
logs = {} logs = {}
......
...@@ -17,7 +17,9 @@ from ..dist_attribute import OperatorDistributedAttribute ...@@ -17,7 +17,9 @@ from ..dist_attribute import OperatorDistributedAttribute
_g_distributed_operator_impl_containers = {} _g_distributed_operator_impl_containers = {}
_g_elementwise_ops = ["elementwise_add", "gelu", "dropout", "cast"] _g_elementwise_ops = [
"elementwise_add", "gelu", "dropout", "cast", "gather", "concat"
]
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'} BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}
......
...@@ -55,9 +55,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -55,9 +55,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
for arg_name in op_desc.input_arg_names(): for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name) serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if serial_tensor.is_parameter:
for mapping in dims_mapping:
if mapping != -1:
return False
# continue
# if len(dims_mapping) < 1:
# continue
if len(dims_mapping) > 1: if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]: for mapping in dims_mapping[1:]:
if mapping != -1: if mapping != -1:
...@@ -73,9 +78,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -73,9 +78,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
xshape_arg_names = op_desc.output("XShape") xshape_arg_names = op_desc.output("XShape")
for arg_name in op_desc.output_arg_names(): for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name) serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if serial_tensor.is_parameter:
for mapping in dims_mapping:
if mapping != -1:
return False
# continue
# if len(dims_mapping) < 1:
# continue
if arg_name not in xshape_arg_names: if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1: if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]: for mapping in dims_mapping[1:]:
...@@ -104,7 +114,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -104,7 +114,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for mapping in dims_mapping[1:]: for mapping in dims_mapping[1:]:
if mapping != -1: if mapping != -1:
return False return False
batch_dim_mappings.append(dims_mapping[0]) if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
# Check output compatibility # Check output compatibility
output_names = op_desc.output_names() output_names = op_desc.output_names()
...@@ -121,7 +132,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -121,7 +132,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for mapping in dims_mapping[1:]: for mapping in dims_mapping[1:]:
if mapping != -1: if mapping != -1:
return False return False
batch_dim_mappings.append(dims_mapping[0]) if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
else: else:
if dims_mapping[0] != -1: if dims_mapping[0] != -1:
return False return False
...@@ -129,7 +141,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -129,7 +141,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for mapping in dims_mapping[2:]: for mapping in dims_mapping[2:]:
if mapping != -1: if mapping != -1:
return False return False
batch_dim_mappings.append(dims_mapping[1]) if len(dims_mapping) >= 2:
batch_dim_mappings.append(dims_mapping[1])
# Check batch dim mapping compatibility # Check batch dim mapping compatibility
if not all(batch_dim_mappings[0] == dim_mapping if not all(batch_dim_mappings[0] == dim_mapping
...@@ -143,7 +156,9 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -143,7 +156,9 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
# The following statement will be replaced by a more elegent way # The following statement will be replaced by a more elegent way
if op_desc.type() == "shape" or op_desc.type() == "slice": if op_desc.type() == "shape" \
or op_desc.type() == "slice" \
or op_desc.type() == "while":
return False return False
output_names = op_desc.output_names() output_names = op_desc.output_names()
xshape_arg_names = [] xshape_arg_names = []
...@@ -155,17 +170,22 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -155,17 +170,22 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
if serial_tensor.is_parameter: if serial_tensor.is_parameter:
continue continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
batch_dim_mappings.append(dims_mapping[0]) if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
for arg_name in op_desc.output_arg_names(): for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name) serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter: if serial_tensor.is_parameter:
continue continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names: if arg_name not in xshape_arg_names:
batch_dim_mappings.append(dims_mapping[0]) if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
else: else:
batch_dim_mappings.append(dims_mapping[1]) batch_dim_mappings.append(dims_mapping[1])
if not batch_dim_mappings:
return changed
compatible_dim_mapping = compute_compatible_dim_mapping( compatible_dim_mapping = compute_compatible_dim_mapping(
batch_dim_mappings) batch_dim_mappings)
assert compatible_dim_mapping is not None, "There is no compatible dim mapping." assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
...@@ -174,7 +194,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -174,7 +194,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
if serial_tensor.is_parameter: if serial_tensor.is_parameter:
continue continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if compatible_dim_mapping != dims_mapping[0]: if len(dims_mapping
) >= 1 and compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping dims_mapping[0] = compatible_dim_mapping
changed = True changed = True
for arg_name in op_desc.output_arg_names(): for arg_name in op_desc.output_arg_names():
...@@ -183,11 +204,13 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -183,11 +204,13 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
continue continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names: if arg_name not in xshape_arg_names:
if compatible_dim_mapping != dims_mapping[0]: if len(dims_mapping
) >= 1 and compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping dims_mapping[0] = compatible_dim_mapping
changed = True changed = True
else: else:
if compatible_dim_mapping != dims_mapping[1]: if len(dims_mapping
) >= 2 and compatible_dim_mapping != dims_mapping[1]:
dims_mapping[1] = compatible_dim_mapping dims_mapping[1] = compatible_dim_mapping
changed = True changed = True
......
...@@ -1432,7 +1432,6 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): ...@@ -1432,7 +1432,6 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
if is_valid_list_index(y_dims_mapping, if is_valid_list_index(y_dims_mapping,
-2) and is_dim_shard(y_dims_mapping[-2]): -2) and is_dim_shard(y_dims_mapping[-2]):
return False return False
return True return True
def is_output_compatible(self, dist_op): def is_output_compatible(self, dist_op):
......
...@@ -1271,7 +1271,6 @@ def get_all_distributed_main_program(serial_program_info, dist_context, ...@@ -1271,7 +1271,6 @@ def get_all_distributed_main_program(serial_program_info, dist_context,
used_dist_context._dist_op_context = DistributedOperatorContext() used_dist_context._dist_op_context = DistributedOperatorContext()
_, _, dist_startup_program, dist_main_program, _ = copied_parallelizer._get_dist_program( _, _, dist_startup_program, dist_main_program, _ = copied_parallelizer._get_dist_program(
rank_id, used_dist_context) rank_id, used_dist_context)
# print("dist_main_program: ", dist_main_program)
all_dist_main_program.append(dist_main_program) all_dist_main_program.append(dist_main_program)
return all_dist_main_program return all_dist_main_program
......
...@@ -9,6 +9,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ...@@ -9,6 +9,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
set_tests_properties(test_relaunch_with_gpt_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 240) set_tests_properties(test_relaunch_with_gpt_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 240)
py_test_modules(test_engine_api MODULES test_engine_api ENVS ${dist_ENVS}) py_test_modules(test_engine_api MODULES test_engine_api ENVS ${dist_ENVS})
set_tests_properties(test_engine_api PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80) set_tests_properties(test_engine_api PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80)
py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS})
py_test_modules(test_converter MODULES test_converter ENVS ${dist_ENVS}) py_test_modules(test_converter MODULES test_converter ENVS ${dist_ENVS})
set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import numpy as np
import paddle.nn as nn
import paddle.utils as utils
import paddle.static as static
import paddle.nn.functional as F
import paddle.distributed.auto_parallel as auto
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import make_data_unshard
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context
from paddle.distributed.auto_parallel.operators import find_best_compatible_distributed_operator_impl
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
batch_size = 4
epoch_num = 10
hidden_size = 1024
sequence_len = 512
_g_process_mesh = [[0, 1], [2, 3]]
def get_random_inputs_and_labels(input_shape, label_shape):
input = np.random.random(size=input_shape).astype('float32')
label = np.random.random(size=label_shape).astype('float32')
return input, label
def batch_generator_creator():
def __reader__():
for _ in range(batch_size):
batch_input, batch_label = get_random_inputs_and_labels(
[batch_size, sequence_len, hidden_size],
[batch_size, sequence_len, 1])
yield batch_input, batch_label
return __reader__
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
param_initializer = nn.initializer.Normal(
mean=0.0, std=initializer_range)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.linear0 = nn.Linear(
d_model,
dim_feedforward,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None)
self.linear1 = nn.Linear(
dim_feedforward,
d_model,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None)
def forward(self, input):
out = self.norm(input)
auto.shard_tensor(
self.linear0.weight,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [-1, 0]
})
out = self.linear0(out)
out = F.gelu(out, approximate=True)
auto.shard_tensor(
self.linear1.weight,
dist_attr={
"process_mesh": _g_process_mesh[1],
"dims_mapping": [0, -1]
})
out = self.linear1(out)
return out
def loop_cond(i, loop_len, input_array):
return i < loop_len
def loop_body(i, loop_len, input_array):
pre_input = paddle.tensor.array_read(array=input_array, i=i)
mlp_while0 = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
mlp_while1 = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
output = mlp_while0(pre_input)
cur_pred = mlp_while1(output)
# 更新循环条件
i = paddle.increment(x=i, value=1)
paddle.tensor.array_write(cur_pred, array=input_array, i=i)
return i, loop_len, input_array
def get_program():
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
# fleet.init(is_collective=True, strategy=dist_strategy)
train_program = static.Program()
start_program = static.Program()
with static.program_guard(train_program, start_program):
# 循环计数器
i = paddle.full(shape=[1], fill_value=0, dtype='int64')
# 循环次数
loop_len = paddle.full(shape=[1], fill_value=epoch_num, dtype='int64')
# input
input = static.data(
name="input",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32')
label = static.data(
name="label", shape=[batch_size, sequence_len, 1], dtype='float32')
data_holder = [input, label]
# dataloader
dataloader = paddle.io.DataLoader.from_generator(
feed_list=data_holder, capacity=4 * batch_size, iterable=False)
dataloader.set_batch_generator(
batch_generator_creator(), places=paddle.static.cuda_places())
# data dist_attr
auto.shard_tensor(
input,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [-1, -1, -1]
})
auto.shard_tensor(
label,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [-1, -1, -1]
})
mlp_start = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
pred = mlp_start(input)
input_array = paddle.tensor.array_write(pred, i)
i, loop_len, input_array = static.nn.while_loop(
cond=loop_cond,
body=loop_body,
loop_vars=[i, loop_len, input_array])
end_pred = paddle.tensor.array_read(array=input_array, i=i)
mlp_end = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
pred = mlp_end(end_pred)
error_cost = paddle.nn.functional.square_error_cost(pred, label)
loss = paddle.mean(error_cost)
return train_program, start_program, dataloader, i, loss
class TestMLP(unittest.TestCase):
def test_completer(self):
train_program, start_program, dataloader, i, loss = get_program()
dist_context = DistributedContext()
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
# print_program_with_dist_attr(complete_train_program, dist_context)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册