diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py
index 97023a43ccfe40199c2793a385bb3641f948ccef..c109c861a5e579288fc0f51a5da3c578d8e989aa 100644
--- a/python/paddle/distributed/auto_parallel/completion.py
+++ b/python/paddle/distributed/auto_parallel/completion.py
@@ -18,7 +18,7 @@ import logging
 from paddle.fluid import core
 
 from .utils import is_naive_data_parallel, get_logger
-from .utils import is_gradient_clip_op, __not_shape_var_type__
+from .utils import is_gradient_clip_op, __no_shape_var_type__
 from .operators import find_compatible_distributed_operator_impls
 from .dist_context import _node_id
 from .dist_attribute import TensorDistributedAttribute
@@ -151,11 +151,7 @@ class Completer:
                 return False
             tensor_desc = tensor_node.var()
             # Skip reader tensor
-            if (
-                tensor_desc.type() == core.VarDesc.VarType.READER
-                or tensor_desc.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-                or tensor_desc.type == core.VarDesc.VarType.STEP_SCOPES
-            ):
+            if tensor_desc.type() in __no_shape_var_type__:
                 return False
             tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
                 tensor_node
@@ -621,7 +617,7 @@ class Completer:
                         ):
                             if (
                                 tensor_node.var().type()
-                                in __not_shape_var_type__
+                                in __no_shape_var_type__
                                 or len(tensor_node.var().shape()) != 1
                             ):
                                 flag = False
@@ -633,7 +629,7 @@ class Completer:
                         ):
                             if (
                                 tensor_node.var().type()
-                                in __not_shape_var_type__
+                                in __no_shape_var_type__
                                 or len(tensor_node.var().shape()) != 1
                             ):
                                 flag = False
diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py
index f410468f45b8584c8152c317cdaf4977a93df939..199f27934d728e736c27168a6d0a8791c2e8619d 100644
--- a/python/paddle/distributed/auto_parallel/dist_context.py
+++ b/python/paddle/distributed/auto_parallel/dist_context.py
@@ -22,7 +22,7 @@ from .dist_tensor import DistributedTensor
 from .dist_op import DistributedOperator
 from .process_mesh import ProcessMesh
 from .utils import _copy_dist_attr_to_cpp
-from .utils import is_loss_grad_op
+from .utils import is_loss_grad_op, __no_shape_var_type__
 
 
 # There always exists a default context for user. And user can set it to another one.
@@ -862,11 +862,7 @@ class DistributedContext:
         for dist_tensor in self._dist_tensors_for_program.values():
             serial_tensor = dist_tensor.serial_tensor
             dist_attr = dist_tensor.dist_attr
-            if (
-                serial_tensor.type == core.VarDesc.VarType.READER
-                or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-                or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES
-            ):
+            if serial_tensor.type in __no_shape_var_type__:
                 tensor_shape = []
             else:
                 tensor_shape = serial_tensor.shape
@@ -896,10 +892,7 @@ class DistributedContext:
                 else:
                     if (
                         dist_op.get_serial_input(arg_name).type
-                        == core.VarDesc.VarType.READER
-                        or dist_op.get_serial_input(arg_name).type
-                        == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-                        or dist_op.serial_op.type == "create_py_reader"
+                        in __no_shape_var_type__
                     ):
                         tensor_shape = []
                     else:
@@ -923,11 +916,7 @@ class DistributedContext:
             for arg_name in serial_op.output_arg_names:
                 if (
                     dist_op.get_serial_output(arg_name).type
-                    == core.VarDesc.VarType.READER
-                    or dist_op.get_serial_output(arg_name).type
-                    == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-                    or dist_op.get_serial_output(arg_name).type
-                    == core.VarDesc.VarType.STEP_SCOPES
+                    in __no_shape_var_type__
                 ):
                     tensor_shape = []
                 else:
diff --git a/python/paddle/distributed/auto_parallel/dist_op.py b/python/paddle/distributed/auto_parallel/dist_op.py
index 41b4696174a5a6b168ba2b233a7ee18100888b81..80141730bc1a168c4776a7c4530e61fb4675f599 100644
--- a/python/paddle/distributed/auto_parallel/dist_op.py
+++ b/python/paddle/distributed/auto_parallel/dist_op.py
@@ -14,12 +14,15 @@
 import copy
 
 import paddle
-from paddle.fluid import core
 from paddle.fluid.framework import Variable
 from .dist_attribute import OperatorDistributedAttribute
 from .dist_attribute import append_op_input_suffix
 from .dist_attribute import append_op_output_suffix
-from .utils import convert_to_shard_spec, verify_shard_spec
+from .utils import (
+    convert_to_shard_spec,
+    verify_shard_spec,
+    __no_shape_var_type__,
+)
 
 
 class DistributedOperator:
@@ -73,10 +76,7 @@ class DistributedOperator:
             if tensor is None:
                 tensor_shape = []
             else:
-                if (
-                    tensor.type == core.VarDesc.VarType.READER
-                    or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-                ):
+                if tensor.type in __no_shape_var_type__:
                     tensor_shape = []
                 else:
                     tensor_shape = tensor.shape
@@ -87,11 +87,7 @@ class DistributedOperator:
                 )
         for tensor_name in self._serial_op.output_arg_names:
             tensor = self._serial_op.block._var_recursive(tensor_name)
-            if (
-                tensor.type == core.VarDesc.VarType.READER
-                or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-                or tensor.type == core.VarDesc.VarType.STEP_SCOPES
-            ):
+            if tensor.type in __no_shape_var_type__:
                 tensor_shape = []
             else:
                 tensor_shape = tensor.shape
@@ -151,10 +147,7 @@ class DistributedOperator:
         for name in self.serial_op.input_arg_names:
             input_dist_attr = self.dist_attr.get_input_dist_attr(name)
             dims_mapping = input_dist_attr.dims_mapping
-            if (
-                self.get_serial_input(name).type
-                == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-            ):
+            if self.get_serial_input(name).type in __no_shape_var_type__:
                 shape = []
             else:
                 shape = self.get_serial_input(name).shape
@@ -174,12 +167,7 @@ class DistributedOperator:
         for name in self.serial_op.output_arg_names:
             output_dist_attr = self.dist_attr.get_output_dist_attr(name)
             dims_mapping = output_dist_attr.dims_mapping
-            if (
-                self.get_serial_output(name).type
-                == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-                or self.get_serial_output(name).type
-                == core.VarDesc.VarType.STEP_SCOPES
-            ):
+            if self.get_serial_output(name).type in __no_shape_var_type__:
                 shape = []
             else:
                 shape = self.get_serial_output(name).shape
@@ -337,12 +325,7 @@ class DistributedOperatorHelper:
                     if tensor is None:
                         tensor_shape = []
                     else:
-                        if (
-                            tensor.type == core.VarDesc.VarType.READER
-                            or tensor.type
-                            == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-                            or tensor.type == core.VarDesc.VarType.STEP_SCOPES
-                        ):
+                        if tensor.type in __no_shape_var_type__:
                             tensor_shape = []
                         else:
                             tensor_shape = tensor.shape
@@ -368,12 +351,7 @@ class DistributedOperatorHelper:
                     if tensor is None:
                         tensor_shape = []
                     else:
-                        if (
-                            tensor.type == core.VarDesc.VarType.READER
-                            or tensor.type
-                            == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-                            or tensor.type == core.VarDesc.VarType.STEP_SCOPES
-                        ):
+                        if tensor.type in __no_shape_var_type__:
                             tensor_shape = []
                         else:
                             tensor_shape = tensor.shape
diff --git a/python/paddle/distributed/auto_parallel/dist_tensor.py b/python/paddle/distributed/auto_parallel/dist_tensor.py
index 88c754f06f3d8a6dfc3a578fab5560cd13d356c2..9a6f9c41154e14d9a5b5807b27268d365cccfdb1 100644
--- a/python/paddle/distributed/auto_parallel/dist_tensor.py
+++ b/python/paddle/distributed/auto_parallel/dist_tensor.py
@@ -16,10 +16,9 @@ import copy
 import inspect
 
 import paddle
-from paddle.fluid import core
 from paddle.fluid.framework import Parameter, Block, Variable
 from .dist_attribute import TensorDistributedAttribute
-from .utils import _linear_idx2coordinate
+from .utils import _linear_idx2coordinate, __no_shape_var_type__
 
 
 class DistributedTensor:
@@ -208,12 +207,7 @@ class DistributedTensor:
 
     def _init_default_dist_attr(self):
         if self._dist_attr.dims_mapping is None:
-            if (
-                self.serial_tensor.type == core.VarDesc.VarType.READER
-                or self.serial_tensor.type
-                == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-                or self.serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES
-            ):
+            if self.serial_tensor.type in __no_shape_var_type__:
                 tensor_shape = []
             else:
                 tensor_shape = self._serial_tensor.shape
@@ -221,11 +215,7 @@ class DistributedTensor:
             self._dist_attr.dims_mapping = tensor_dims_mapping
 
     def validate_dist_attr(self):
-        if (
-            self.serial_tensor.type == core.VarDesc.VarType.READER
-            or self.serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-            or self.serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES
-        ):
+        if self.serial_tensor.type in __no_shape_var_type__:
             return True
         tensor_shape = self.serial_tensor.shape
         if len(tensor_shape) != len(self.dist_attr.dims_mapping):
diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py
index b8870106acb0a683a37e29b97aeb1c3a453ba1bc..116eaa97f1088530df16dc71b3f60167be532936 100644
--- a/python/paddle/distributed/auto_parallel/engine.py
+++ b/python/paddle/distributed/auto_parallel/engine.py
@@ -580,7 +580,7 @@ class Engine:
                                     metric.compute(*(outputs + self._labels))
                                 )
                             )
-            else:
+            elif mode == "train":
                 assert isinstance(
                     self._loss, Variable
                 ), "the type of `loss` of the Engine arguments should be Variable."
diff --git a/python/paddle/distributed/auto_parallel/interface.py b/python/paddle/distributed/auto_parallel/interface.py
index d2f7e894149c79ba2a9061e021be3bacf6491dd7..cc8afb4f27173b674a325a4042f0735ed0435f31 100644
--- a/python/paddle/distributed/auto_parallel/interface.py
+++ b/python/paddle/distributed/auto_parallel/interface.py
@@ -13,13 +13,16 @@
 # limitations under the License.
 
 import paddle
-from paddle.fluid import core
 from .process_mesh import ProcessMesh
 from .process_mesh import get_current_process_mesh
 from .dist_context import get_default_distributed_context
 from .dist_tensor import DistributedTensor
 from .dist_op import DistributedOperatorHelper
-from .utils import verify_shard_spec, convert_to_dims_mapping
+from .utils import (
+    verify_shard_spec,
+    convert_to_dims_mapping,
+    __no_shape_var_type__,
+)
 
 
 def shard_tensor(x, process_mesh=None, shard_spec=None):
@@ -79,11 +82,7 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
     dist_tensor = DistributedTensor(x)
     serial_tensor = dist_tensor.serial_tensor
    dist_tensor.dist_attr.process_mesh = process_mesh
-    if (
-        serial_tensor.type == core.VarDesc.VarType.READER
-        or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-        or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES
-    ):
+    if serial_tensor.type in __no_shape_var_type__:
         tensor_shape = []
     else:
         tensor_shape = serial_tensor.shape
diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py
index 6ec52ff69796fe183eccfb12e7b9040df194f16e..cad9fe1d4277ea1190dcf44f5402334cc790d0ae 100644
--- a/python/paddle/distributed/auto_parallel/partitioner.py
+++ b/python/paddle/distributed/auto_parallel/partitioner.py
@@ -22,14 +22,16 @@ from paddle.distributed.auto_parallel.operators.common import (
 )
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from .dist_attribute import OperatorDistributedAttribute
-from .utils import is_backward_op, is_forward_op, is_loss_op, is_optimize_op
 from .operators.common import BACKWARD_ONLY_DIST_OPS
+from .utils import (
+    is_backward_op,
+    is_forward_op,
+    is_loss_op,
+    is_optimize_op,
+    __no_shape_var_type__,
+)
 
 __varname_not_in_block__ = ["lod_tensor_blocking_queue"]
-__not_shape_var_type__ = [
-    core.VarDesc.VarType.READER,
-    core.VarDesc.VarType.STEP_SCOPES,
-]
 
 
 class Partitioner:
@@ -363,7 +365,7 @@ class Partitioner:
         var_dist_attrs = [
             self._dist_context.get_tensor_dist_attr_for_program(var)
             for var in vars_
-            if (var.type not in __not_shape_var_type__)
+            if (var.type not in __no_shape_var_type__)
         ]
 
         all_ops_annotated = all(
@@ -468,7 +470,7 @@ def _partition_var(
     """
     src_var = src_block.var(src_varname)
 
-    if src_var.type in __not_shape_var_type__:
+    if src_var.type in __no_shape_var_type__:
         persist = getattr(src_var, 'persistable', False)
         new_var = dst_block.create_var(
             type=src_var.type,
diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py
index bf4aa34303a722dba8a5f12d691d3d5084b233d8..35b3483a31481c5af164cf5ec5a9358c2e6347a8 100644
--- a/python/paddle/distributed/auto_parallel/utils.py
+++ b/python/paddle/distributed/auto_parallel/utils.py
@@ -33,9 +33,12 @@ from paddle.distributed.auto_parallel.dist_attribute import (
     OperatorDistributedAttribute,
 )
 
-__not_shape_var_type__ = [
+__no_shape_var_type__ = [
     core.VarDesc.VarType.READER,
     core.VarDesc.VarType.STEP_SCOPES,
+    core.VarDesc.VarType.LOD_TENSOR_ARRAY,
+    core.VarDesc.VarType.FEED_MINIBATCH,
+    core.VarDesc.VarType.FETCH_LIST,
 ]
 
 __not_naive_data_parallel_op__ = ["expand_v2"]
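The recurring change in this patch is replacing chains of `== core.VarDesc.VarType.X` comparisons with a single membership test against the shared `__no_shape_var_type__` list defined in `utils.py`. The sketch below illustrates that pattern in isolation; it is not part of the patch, and the `VarType` enum and `tensor_shape_or_empty` helper are stand-ins rather than Paddle APIs.

```python
from enum import Enum, auto


class VarType(Enum):
    # Stand-in for core.VarDesc.VarType; members are illustrative only.
    LOD_TENSOR = auto()
    READER = auto()
    STEP_SCOPES = auto()
    LOD_TENSOR_ARRAY = auto()
    FEED_MINIBATCH = auto()
    FETCH_LIST = auto()


# Mirrors utils.__no_shape_var_type__: variable types that carry no static shape.
NO_SHAPE_VAR_TYPES = [
    VarType.READER,
    VarType.STEP_SCOPES,
    VarType.LOD_TENSOR_ARRAY,
    VarType.FEED_MINIBATCH,
    VarType.FETCH_LIST,
]


def tensor_shape_or_empty(var_type, shape):
    """Return [] for shapeless variable types, as the patched call sites do."""
    # One membership test replaces the repeated equality chains removed above.
    if var_type in NO_SHAPE_VAR_TYPES:
        return []
    return shape


assert tensor_shape_or_empty(VarType.READER, [8, 16]) == []
assert tensor_shape_or_empty(VarType.LOD_TENSOR, [8, 16]) == [8, 16]
```

Keeping the list in one module means newly added shapeless types (here LOD_TENSOR_ARRAY, FEED_MINIBATCH, FETCH_LIST) only need to be registered once instead of at every call site.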