Unverified commit 339aefac, authored by zhaoyingli, committed by GitHub

manage no shape var type (#47775)

Parent 692a9632
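The commit consolidates the repeated chains of core.VarDesc.VarType equality checks scattered across the auto-parallel modules into a single shared list, __no_shape_var_type__, defined in utils.py and queried with a membership test at each call site. A minimal standalone sketch of the pattern (the tensor_shape_of helper is hypothetical, for illustration only):

from paddle.fluid import core

# Shared list of variable types that carry no meaningful shape,
# as extended in the utils.py hunk at the end of this diff.
__no_shape_var_type__ = [
    core.VarDesc.VarType.READER,
    core.VarDesc.VarType.STEP_SCOPES,
    core.VarDesc.VarType.LOD_TENSOR_ARRAY,
    core.VarDesc.VarType.FEED_MINIBATCH,
    core.VarDesc.VarType.FETCH_LIST,
]

def tensor_shape_of(serial_tensor):
    # One membership test replaces the old chain of
    # `type == core.VarDesc.VarType.READER or type == ...` comparisons.
    if serial_tensor.type in __no_shape_var_type__:
        return []
    return serial_tensor.shape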
@@ -18,7 +18,7 @@ import logging
 from paddle.fluid import core
 from .utils import is_naive_data_parallel, get_logger
-from .utils import is_gradient_clip_op, __not_shape_var_type__
+from .utils import is_gradient_clip_op, __no_shape_var_type__
 from .operators import find_compatible_distributed_operator_impls
 from .dist_context import _node_id
 from .dist_attribute import TensorDistributedAttribute
@@ -151,11 +151,7 @@ class Completer:
             return False
         tensor_desc = tensor_node.var()
         # Skip reader tensor
-        if (
-            tensor_desc.type() == core.VarDesc.VarType.READER
-            or tensor_desc.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-            or tensor_desc.type == core.VarDesc.VarType.STEP_SCOPES
-        ):
+        if tensor_desc.type() in __no_shape_var_type__:
             return False
         tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
             tensor_node
@@ -621,7 +617,7 @@ class Completer:
                 ):
                     if (
                         tensor_node.var().type()
-                        in __not_shape_var_type__
+                        in __no_shape_var_type__
                         or len(tensor_node.var().shape()) != 1
                     ):
                         flag = False
@@ -633,7 +629,7 @@ class Completer:
                 ):
                     if (
                         tensor_node.var().type()
-                        in __not_shape_var_type__
+                        in __no_shape_var_type__
                         or len(tensor_node.var().shape()) != 1
                     ):
                         flag = False
...
@@ -22,7 +22,7 @@ from .dist_tensor import DistributedTensor
 from .dist_op import DistributedOperator
 from .process_mesh import ProcessMesh
 from .utils import _copy_dist_attr_to_cpp
-from .utils import is_loss_grad_op
+from .utils import is_loss_grad_op, __no_shape_var_type__
 
 # There always exists a default context for user. And user can set it to another one.
@@ -862,11 +862,7 @@ class DistributedContext:
         for dist_tensor in self._dist_tensors_for_program.values():
             serial_tensor = dist_tensor.serial_tensor
             dist_attr = dist_tensor.dist_attr
-            if (
-                serial_tensor.type == core.VarDesc.VarType.READER
-                or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-                or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES
-            ):
+            if serial_tensor.type in __no_shape_var_type__:
                 tensor_shape = []
             else:
                 tensor_shape = serial_tensor.shape
@@ -896,10 +892,7 @@ class DistributedContext:
                 else:
                     if (
                         dist_op.get_serial_input(arg_name).type
-                        == core.VarDesc.VarType.READER
-                        or dist_op.get_serial_input(arg_name).type
-                        == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-                        or dist_op.serial_op.type == "create_py_reader"
+                        in __no_shape_var_type__
                     ):
                         tensor_shape = []
                     else:
@@ -923,11 +916,7 @@ class DistributedContext:
             for arg_name in serial_op.output_arg_names:
                 if (
                     dist_op.get_serial_output(arg_name).type
-                    == core.VarDesc.VarType.READER
-                    or dist_op.get_serial_output(arg_name).type
-                    == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-                    or dist_op.get_serial_output(arg_name).type
-                    == core.VarDesc.VarType.STEP_SCOPES
+                    in __no_shape_var_type__
                 ):
                     tensor_shape = []
                 else:
...
@@ -14,12 +14,15 @@
 import copy
 import paddle
-from paddle.fluid import core
 from paddle.fluid.framework import Variable
 from .dist_attribute import OperatorDistributedAttribute
 from .dist_attribute import append_op_input_suffix
 from .dist_attribute import append_op_output_suffix
-from .utils import convert_to_shard_spec, verify_shard_spec
+from .utils import (
+    convert_to_shard_spec,
+    verify_shard_spec,
+    __no_shape_var_type__,
+)
 
 
 class DistributedOperator:
@@ -73,10 +76,7 @@ class DistributedOperator:
             if tensor is None:
                 tensor_shape = []
             else:
-                if (
-                    tensor.type == core.VarDesc.VarType.READER
-                    or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-                ):
+                if tensor.type in __no_shape_var_type__:
                     tensor_shape = []
                 else:
                     tensor_shape = tensor.shape
@@ -87,11 +87,7 @@ class DistributedOperator:
             )
             for tensor_name in self._serial_op.output_arg_names:
                 tensor = self._serial_op.block._var_recursive(tensor_name)
-                if (
-                    tensor.type == core.VarDesc.VarType.READER
-                    or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-                    or tensor.type == core.VarDesc.VarType.STEP_SCOPES
-                ):
+                if tensor.type in __no_shape_var_type__:
                     tensor_shape = []
                 else:
                     tensor_shape = tensor.shape
@@ -151,10 +147,7 @@ class DistributedOperator:
         for name in self.serial_op.input_arg_names:
             input_dist_attr = self.dist_attr.get_input_dist_attr(name)
             dims_mapping = input_dist_attr.dims_mapping
-            if (
-                self.get_serial_input(name).type
-                == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-            ):
+            if self.get_serial_input(name).type in __no_shape_var_type__:
                 shape = []
             else:
                 shape = self.get_serial_input(name).shape
@@ -174,12 +167,7 @@ class DistributedOperator:
         for name in self.serial_op.output_arg_names:
             output_dist_attr = self.dist_attr.get_output_dist_attr(name)
             dims_mapping = output_dist_attr.dims_mapping
-            if (
-                self.get_serial_output(name).type
-                == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-                or self.get_serial_output(name).type
-                == core.VarDesc.VarType.STEP_SCOPES
-            ):
+            if self.get_serial_output(name).type in __no_shape_var_type__:
                 shape = []
             else:
                 shape = self.get_serial_output(name).shape
@@ -337,12 +325,7 @@ class DistributedOperatorHelper:
                 if tensor is None:
                     tensor_shape = []
                 else:
-                    if (
-                        tensor.type == core.VarDesc.VarType.READER
-                        or tensor.type
-                        == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-                        or tensor.type == core.VarDesc.VarType.STEP_SCOPES
-                    ):
+                    if tensor.type in __no_shape_var_type__:
                         tensor_shape = []
                     else:
                         tensor_shape = tensor.shape
@@ -368,12 +351,7 @@
                 if tensor is None:
                     tensor_shape = []
                 else:
-                    if (
-                        tensor.type == core.VarDesc.VarType.READER
-                        or tensor.type
-                        == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-                        or tensor.type == core.VarDesc.VarType.STEP_SCOPES
-                    ):
+                    if tensor.type in __no_shape_var_type__:
                         tensor_shape = []
                     else:
                         tensor_shape = tensor.shape
...
@@ -16,10 +16,9 @@ import copy
 import inspect
 import paddle
-from paddle.fluid import core
 from paddle.fluid.framework import Parameter, Block, Variable
 from .dist_attribute import TensorDistributedAttribute
-from .utils import _linear_idx2coordinate
+from .utils import _linear_idx2coordinate, __no_shape_var_type__
 
 
 class DistributedTensor:
@@ -208,12 +207,7 @@ class DistributedTensor:
     def _init_default_dist_attr(self):
         if self._dist_attr.dims_mapping is None:
-            if (
-                self.serial_tensor.type == core.VarDesc.VarType.READER
-                or self.serial_tensor.type
-                == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-                or self.serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES
-            ):
+            if self.serial_tensor.type in __no_shape_var_type__:
                 tensor_shape = []
             else:
                 tensor_shape = self._serial_tensor.shape
@@ -221,11 +215,7 @@ class DistributedTensor:
             self._dist_attr.dims_mapping = tensor_dims_mapping
 
     def validate_dist_attr(self):
-        if (
-            self.serial_tensor.type == core.VarDesc.VarType.READER
-            or self.serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-            or self.serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES
-        ):
+        if self.serial_tensor.type in __no_shape_var_type__:
             return True
         tensor_shape = self.serial_tensor.shape
         if len(tensor_shape) != len(self.dist_attr.dims_mapping):
...
@@ -580,7 +580,7 @@ class Engine:
                         metric.compute(*(outputs + self._labels))
                     )
                 )
-        else:
+        elif mode == "train":
             assert isinstance(
                 self._loss, Variable
             ), "the type of `loss` of the Engine arguments should be Variable."
...
@@ -13,13 +13,16 @@
 # limitations under the License.
 
 import paddle
-from paddle.fluid import core
 from .process_mesh import ProcessMesh
 from .process_mesh import get_current_process_mesh
 from .dist_context import get_default_distributed_context
 from .dist_tensor import DistributedTensor
 from .dist_op import DistributedOperatorHelper
-from .utils import verify_shard_spec, convert_to_dims_mapping
+from .utils import (
+    verify_shard_spec,
+    convert_to_dims_mapping,
+    __no_shape_var_type__,
+)
 
 
 def shard_tensor(x, process_mesh=None, shard_spec=None):
@@ -79,11 +82,7 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
         dist_tensor = DistributedTensor(x)
     serial_tensor = dist_tensor.serial_tensor
     dist_tensor.dist_attr.process_mesh = process_mesh
-    if (
-        serial_tensor.type == core.VarDesc.VarType.READER
-        or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-        or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES
-    ):
+    if serial_tensor.type in __no_shape_var_type__:
         tensor_shape = []
     else:
         tensor_shape = serial_tensor.shape
...
@@ -22,14 +22,16 @@ from paddle.distributed.auto_parallel.operators.common import (
 )
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from .dist_attribute import OperatorDistributedAttribute
-from .utils import is_backward_op, is_forward_op, is_loss_op, is_optimize_op
 from .operators.common import BACKWARD_ONLY_DIST_OPS
+from .utils import (
+    is_backward_op,
+    is_forward_op,
+    is_loss_op,
+    is_optimize_op,
+    __no_shape_var_type__,
+)
 
 __varname_not_in_block__ = ["lod_tensor_blocking_queue"]
-__not_shape_var_type__ = [
-    core.VarDesc.VarType.READER,
-    core.VarDesc.VarType.STEP_SCOPES,
-]
 
 
 class Partitioner:
@@ -363,7 +365,7 @@ class Partitioner:
         var_dist_attrs = [
             self._dist_context.get_tensor_dist_attr_for_program(var)
             for var in vars_
-            if (var.type not in __not_shape_var_type__)
+            if (var.type not in __no_shape_var_type__)
         ]
 
         all_ops_annotated = all(
@@ -468,7 +470,7 @@ def _partition_var(
     """
     src_var = src_block.var(src_varname)
-    if src_var.type in __not_shape_var_type__:
+    if src_var.type in __no_shape_var_type__:
         persist = getattr(src_var, 'persistable', False)
         new_var = dst_block.create_var(
             type=src_var.type,
...
@@ -33,9 +33,12 @@ from paddle.distributed.auto_parallel.dist_attribute import (
     OperatorDistributedAttribute,
 )
 
-__not_shape_var_type__ = [
+__no_shape_var_type__ = [
     core.VarDesc.VarType.READER,
     core.VarDesc.VarType.STEP_SCOPES,
+    core.VarDesc.VarType.LOD_TENSOR_ARRAY,
+    core.VarDesc.VarType.FEED_MINIBATCH,
+    core.VarDesc.VarType.FETCH_LIST,
 ]
 
 __not_naive_data_parallel_op__ = ["expand_v2"]
...