diff --git a/python/paddle/distributed/auto_parallel/operators/dist_slice.py b/python/paddle/distributed/auto_parallel/operators/dist_slice.py index 69ccd8d7bc868ba55809bac9e41320420bec193b..0110f54d481a0ffcb4d8e4ecf36aea7692ffe453 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_slice.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_slice.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import paddle + from ..utils import compute_compatible_dim_mapping, is_dim_shard from .common import ( DistributedOperatorImpl, @@ -70,7 +72,14 @@ class DistributedSliceImpl(DistributedOperatorImpl): if i not in decrease_axis: ref_indices.append(i) if ref_indices == []: - assert len(out_dims_mapping) == 0 + # NOTE(zoooo0820): When all axes are decreased, the output will be 1-D + # with FLAGS_set_to_1d=True. + if paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d']: + assert len(out_dims_mapping) == 1 + if is_dim_shard(out_dims_mapping[0]): + return False + else: + assert len(out_dims_mapping) == 0 else: for i in range(len(out_dims_mapping)): ref_index = ref_indices[i] @@ -140,6 +149,11 @@ class DistributedSliceImpl(DistributedOperatorImpl): ref_indices.append(i) if ref_dims_mapping == []: + # NOTE(zoooo0820): When all axes are decreased, the output will be 1-D + # with FLAGS_set_to_1d=True. + if paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d']: + ref_dims_mapping = [-1] + assert ref_dims_mapping[0] == out_dims_mapping[0] assert len(ref_dims_mapping) == len(out_dims_mapping) changed = False else: diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 889da4ffaa82bf067f16502eacc5c1652c240650..fbd55ba83aa3551a0f45bf21a8fc2c23eed61d45 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -284,6 +284,15 @@ def is_integer_or_scalar_tensor(ele): if isinstance(ele, int): return True elif isinstance(ele, Variable): + # NOTE(zoooo0820): For compatibility, if FLAGS_set_to_1d is set to True, + # 1-D tensor is still treated as a scalar, which means basic indexing. + # This will be removed in future. + if paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d']: + if len(ele.shape) == 1 and ele.shape[0] == 1: + warnings.warn( + "1-D Tensor will be treat as advanced indexing in future version. Currently, 1-D Tensor means a scalar, not vector, and please modify it to 0-D Tensor. If advanced indexing is needed, please use `export FLAGS_set_to_1d=False` to set the flag." + ) + return True if len(ele.shape) == 0: return True return False