Unverified · Commit 339aefac authored by zhaoyingli, committed by GitHub

manage no shape var type (#47775)

Parent 692a9632
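This commit renames the shared list of shapeless variable types from `__not_shape_var_type__` to `__no_shape_var_type__` in `utils.py`, extends it with `LOD_TENSOR_ARRAY`, `FEED_MINIBATCH`, and `FETCH_LIST`, and replaces the chained `core.VarDesc.VarType` comparisons at each call site with a single membership test. A minimal sketch of the resulting pattern (the `_shape_of` helper is illustrative, not part of the diff):

```python
from paddle.fluid import core

# Shared list from utils.py after this change: variable types that carry
# no meaningful shape and should be treated as shape [].
__no_shape_var_type__ = [
    core.VarDesc.VarType.READER,
    core.VarDesc.VarType.STEP_SCOPES,
    core.VarDesc.VarType.LOD_TENSOR_ARRAY,
    core.VarDesc.VarType.FEED_MINIBATCH,
    core.VarDesc.VarType.FETCH_LIST,
]


def _shape_of(var):
    # Before: var.type == core.VarDesc.VarType.READER or var.type == ... (chained)
    # After: one membership test against the shared list.
    if var.type in __no_shape_var_type__:
        return []
    return var.shape
```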
......@@ -18,7 +18,7 @@ import logging
from paddle.fluid import core
from .utils import is_naive_data_parallel, get_logger
-from .utils import is_gradient_clip_op, __not_shape_var_type__
+from .utils import is_gradient_clip_op, __no_shape_var_type__
from .operators import find_compatible_distributed_operator_impls
from .dist_context import _node_id
from .dist_attribute import TensorDistributedAttribute
......@@ -151,11 +151,7 @@ class Completer:
return False
tensor_desc = tensor_node.var()
# Skip reader tensor
-if (
-    tensor_desc.type() == core.VarDesc.VarType.READER
-    or tensor_desc.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-    or tensor_desc.type == core.VarDesc.VarType.STEP_SCOPES
-):
+if tensor_desc.type() in __no_shape_var_type__:
return False
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
tensor_node
......@@ -621,7 +617,7 @@ class Completer:
):
if (
tensor_node.var().type()
-in __not_shape_var_type__
+in __no_shape_var_type__
or len(tensor_node.var().shape()) != 1
):
flag = False
......@@ -633,7 +629,7 @@ class Completer:
):
if (
tensor_node.var().type()
-in __not_shape_var_type__
+in __no_shape_var_type__
or len(tensor_node.var().shape()) != 1
):
flag = False
......
......@@ -22,7 +22,7 @@ from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator
from .process_mesh import ProcessMesh
from .utils import _copy_dist_attr_to_cpp
-from .utils import is_loss_grad_op
+from .utils import is_loss_grad_op, __no_shape_var_type__
# There always exists a default context for user. And user can set it to another one.
......@@ -862,11 +862,7 @@ class DistributedContext:
for dist_tensor in self._dist_tensors_for_program.values():
serial_tensor = dist_tensor.serial_tensor
dist_attr = dist_tensor.dist_attr
-if (
-    serial_tensor.type == core.VarDesc.VarType.READER
-    or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-    or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES
-):
+if serial_tensor.type in __no_shape_var_type__:
tensor_shape = []
else:
tensor_shape = serial_tensor.shape
......@@ -896,10 +892,7 @@ class DistributedContext:
else:
if (
dist_op.get_serial_input(arg_name).type
-== core.VarDesc.VarType.READER
-or dist_op.get_serial_input(arg_name).type
-== core.VarDesc.VarType.LOD_TENSOR_ARRAY
-or dist_op.serial_op.type == "create_py_reader"
+in __no_shape_var_type__
):
tensor_shape = []
else:
......@@ -923,11 +916,7 @@ class DistributedContext:
for arg_name in serial_op.output_arg_names:
if (
dist_op.get_serial_output(arg_name).type
-== core.VarDesc.VarType.READER
-or dist_op.get_serial_output(arg_name).type
-== core.VarDesc.VarType.LOD_TENSOR_ARRAY
-or dist_op.get_serial_output(arg_name).type
-== core.VarDesc.VarType.STEP_SCOPES
+in __no_shape_var_type__
):
tensor_shape = []
else:
......
......@@ -14,12 +14,15 @@
import copy
import paddle
from paddle.fluid import core
from paddle.fluid.framework import Variable
from .dist_attribute import OperatorDistributedAttribute
from .dist_attribute import append_op_input_suffix
from .dist_attribute import append_op_output_suffix
-from .utils import convert_to_shard_spec, verify_shard_spec
+from .utils import (
+    convert_to_shard_spec,
+    verify_shard_spec,
+    __no_shape_var_type__,
+)
class DistributedOperator:
......@@ -73,10 +76,7 @@ class DistributedOperator:
if tensor is None:
tensor_shape = []
else:
-if (
-    tensor.type == core.VarDesc.VarType.READER
-    or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-):
+if tensor.type in __no_shape_var_type__:
tensor_shape = []
else:
tensor_shape = tensor.shape
......@@ -87,11 +87,7 @@ class DistributedOperator:
)
for tensor_name in self._serial_op.output_arg_names:
tensor = self._serial_op.block._var_recursive(tensor_name)
-if (
-    tensor.type == core.VarDesc.VarType.READER
-    or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-    or tensor.type == core.VarDesc.VarType.STEP_SCOPES
-):
+if tensor.type in __no_shape_var_type__:
tensor_shape = []
else:
tensor_shape = tensor.shape
......@@ -151,10 +147,7 @@ class DistributedOperator:
for name in self.serial_op.input_arg_names:
input_dist_attr = self.dist_attr.get_input_dist_attr(name)
dims_mapping = input_dist_attr.dims_mapping
-if (
-    self.get_serial_input(name).type
-    == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-):
+if self.get_serial_input(name).type in __no_shape_var_type__:
shape = []
else:
shape = self.get_serial_input(name).shape
......@@ -174,12 +167,7 @@ class DistributedOperator:
for name in self.serial_op.output_arg_names:
output_dist_attr = self.dist_attr.get_output_dist_attr(name)
dims_mapping = output_dist_attr.dims_mapping
-if (
-    self.get_serial_output(name).type
-    == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-    or self.get_serial_output(name).type
-    == core.VarDesc.VarType.STEP_SCOPES
-):
+if self.get_serial_output(name).type in __no_shape_var_type__:
shape = []
else:
shape = self.get_serial_output(name).shape
......@@ -337,12 +325,7 @@ class DistributedOperatorHelper:
if tensor is None:
tensor_shape = []
else:
-if (
-    tensor.type == core.VarDesc.VarType.READER
-    or tensor.type
-    == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-    or tensor.type == core.VarDesc.VarType.STEP_SCOPES
-):
+if tensor.type in __no_shape_var_type__:
tensor_shape = []
else:
tensor_shape = tensor.shape
......@@ -368,12 +351,7 @@ class DistributedOperatorHelper:
if tensor is None:
tensor_shape = []
else:
-if (
-    tensor.type == core.VarDesc.VarType.READER
-    or tensor.type
-    == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-    or tensor.type == core.VarDesc.VarType.STEP_SCOPES
-):
+if tensor.type in __no_shape_var_type__:
tensor_shape = []
else:
tensor_shape = tensor.shape
......
......@@ -16,10 +16,9 @@ import copy
import inspect
import paddle
from paddle.fluid import core
from paddle.fluid.framework import Parameter, Block, Variable
from .dist_attribute import TensorDistributedAttribute
-from .utils import _linear_idx2coordinate
+from .utils import _linear_idx2coordinate, __no_shape_var_type__
class DistributedTensor:
......@@ -208,12 +207,7 @@ class DistributedTensor:
def _init_default_dist_attr(self):
if self._dist_attr.dims_mapping is None:
-if (
-    self.serial_tensor.type == core.VarDesc.VarType.READER
-    or self.serial_tensor.type
-    == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-    or self.serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES
-):
+if self.serial_tensor.type in __no_shape_var_type__:
tensor_shape = []
else:
tensor_shape = self._serial_tensor.shape
......@@ -221,11 +215,7 @@ class DistributedTensor:
self._dist_attr.dims_mapping = tensor_dims_mapping
def validate_dist_attr(self):
-if (
-    self.serial_tensor.type == core.VarDesc.VarType.READER
-    or self.serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-    or self.serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES
-):
+if self.serial_tensor.type in __no_shape_var_type__:
return True
tensor_shape = self.serial_tensor.shape
if len(tensor_shape) != len(self.dist_attr.dims_mapping):
......
......@@ -580,7 +580,7 @@ class Engine:
metric.compute(*(outputs + self._labels))
)
)
-else:
+elif mode == "train":
assert isinstance(
self._loss, Variable
), "the type of `loss` of the Engine arguments should be Variable."
......
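The `Engine` change above narrows the loss check: previously any mode without metrics asserted that `self._loss` is a `Variable`, which also fired for evaluation and prediction programs built without a loss. A rough, hypothetical condensation of the resulting branch (not the actual method signature):

```python
def _pick_fetches(mode, metrics, loss):
    # Hypothetical sketch of the control flow after the change.
    if metrics:
        return [m.compute() for m in metrics]  # metric outputs, unchanged path
    elif mode == "train":
        # Only training still requires a loss Variable.
        assert loss is not None, "the `loss` of the Engine should be a Variable"
        return [loss]
    return []  # "eval"/"predict" without metrics: nothing extra to fetch
```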
......@@ -13,13 +13,16 @@
# limitations under the License.
import paddle
from paddle.fluid import core
from .process_mesh import ProcessMesh
from .process_mesh import get_current_process_mesh
from .dist_context import get_default_distributed_context
from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperatorHelper
-from .utils import verify_shard_spec, convert_to_dims_mapping
+from .utils import (
+    verify_shard_spec,
+    convert_to_dims_mapping,
+    __no_shape_var_type__,
+)
def shard_tensor(x, process_mesh=None, shard_spec=None):
......@@ -79,11 +82,7 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
dist_tensor = DistributedTensor(x)
serial_tensor = dist_tensor.serial_tensor
dist_tensor.dist_attr.process_mesh = process_mesh
-if (
-    serial_tensor.type == core.VarDesc.VarType.READER
-    or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
-    or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES
-):
+if serial_tensor.type in __no_shape_var_type__:
tensor_shape = []
else:
tensor_shape = serial_tensor.shape
......
......@@ -22,14 +22,16 @@ from paddle.distributed.auto_parallel.operators.common import (
)
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from .dist_attribute import OperatorDistributedAttribute
-from .utils import is_backward_op, is_forward_op, is_loss_op, is_optimize_op
from .operators.common import BACKWARD_ONLY_DIST_OPS
+from .utils import (
+    is_backward_op,
+    is_forward_op,
+    is_loss_op,
+    is_optimize_op,
+    __no_shape_var_type__,
+)
__varname_not_in_block__ = ["lod_tensor_blocking_queue"]
-__not_shape_var_type__ = [
-    core.VarDesc.VarType.READER,
-    core.VarDesc.VarType.STEP_SCOPES,
-]
class Partitioner:
......@@ -363,7 +365,7 @@ class Partitioner:
var_dist_attrs = [
self._dist_context.get_tensor_dist_attr_for_program(var)
for var in vars_
-if (var.type not in __not_shape_var_type__)
+if (var.type not in __no_shape_var_type__)
]
all_ops_annotated = all(
......@@ -468,7 +470,7 @@ def _partition_var(
"""
src_var = src_block.var(src_varname)
-if src_var.type in __not_shape_var_type__:
+if src_var.type in __no_shape_var_type__:
persist = getattr(src_var, 'persistable', False)
new_var = dst_block.create_var(
type=src_var.type,
......
......@@ -33,9 +33,12 @@ from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistributedAttribute,
)
-__not_shape_var_type__ = [
+__no_shape_var_type__ = [
core.VarDesc.VarType.READER,
core.VarDesc.VarType.STEP_SCOPES,
+core.VarDesc.VarType.LOD_TENSOR_ARRAY,
+core.VarDesc.VarType.FEED_MINIBATCH,
+core.VarDesc.VarType.FETCH_LIST,
]
__not_naive_data_parallel_op__ = ["expand_v2"]
......