未验证 提交 186f5e0f 编写于 作者: J JYChen 提交者: GitHub

hack 1-D tensor to Scalar (#53552)

上级 e9882514
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle
from ..utils import compute_compatible_dim_mapping, is_dim_shard from ..utils import compute_compatible_dim_mapping, is_dim_shard
from .common import ( from .common import (
DistributedOperatorImpl, DistributedOperatorImpl,
...@@ -70,7 +72,14 @@ class DistributedSliceImpl(DistributedOperatorImpl): ...@@ -70,7 +72,14 @@ class DistributedSliceImpl(DistributedOperatorImpl):
if i not in decrease_axis: if i not in decrease_axis:
ref_indices.append(i) ref_indices.append(i)
if ref_indices == []: if ref_indices == []:
assert len(out_dims_mapping) == 0 # NOTE(zoooo0820): When all axes are decreased, the output will be 1-D
# with FLAGS_set_to_1d=True.
if paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d']:
assert len(out_dims_mapping) == 1
if is_dim_shard(out_dims_mapping[0]):
return False
else:
assert len(out_dims_mapping) == 0
else: else:
for i in range(len(out_dims_mapping)): for i in range(len(out_dims_mapping)):
ref_index = ref_indices[i] ref_index = ref_indices[i]
...@@ -140,6 +149,11 @@ class DistributedSliceImpl(DistributedOperatorImpl): ...@@ -140,6 +149,11 @@ class DistributedSliceImpl(DistributedOperatorImpl):
ref_indices.append(i) ref_indices.append(i)
if ref_dims_mapping == []: if ref_dims_mapping == []:
# NOTE(zoooo0820): When all axes are decreased, the output will be 1-D
# with FLAGS_set_to_1d=True.
if paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d']:
ref_dims_mapping = [-1]
assert ref_dims_mapping[0] == out_dims_mapping[0]
assert len(ref_dims_mapping) == len(out_dims_mapping) assert len(ref_dims_mapping) == len(out_dims_mapping)
changed = False changed = False
else: else:
......
...@@ -284,6 +284,15 @@ def is_integer_or_scalar_tensor(ele): ...@@ -284,6 +284,15 @@ def is_integer_or_scalar_tensor(ele):
if isinstance(ele, int): if isinstance(ele, int):
return True return True
elif isinstance(ele, Variable): elif isinstance(ele, Variable):
# NOTE(zoooo0820): For compatibility, if FLAGS_set_to_1d is set to True,
# 1-D tensor is still treated as a scalar, which means basic indexing.
# This will be removed in future.
if paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d']:
if len(ele.shape) == 1 and ele.shape[0] == 1:
warnings.warn(
"1-D Tensor will be treat as advanced indexing in future version. Currently, 1-D Tensor means a scalar, not vector, and please modify it to 0-D Tensor. If advanced indexing is needed, please use `export FLAGS_set_to_1d=False` to set the flag."
)
return True
if len(ele.shape) == 0: if len(ele.shape) == 0:
return True return True
return False return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册