未验证 提交 186f5e0f 编写于 作者: J JYChen 提交者: GitHub

hack 1-D tensor to Scalar (#53552)

上级 e9882514
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from ..utils import compute_compatible_dim_mapping, is_dim_shard
from .common import (
DistributedOperatorImpl,
......@@ -70,7 +72,14 @@ class DistributedSliceImpl(DistributedOperatorImpl):
if i not in decrease_axis:
ref_indices.append(i)
if ref_indices == []:
assert len(out_dims_mapping) == 0
# NOTE(zoooo0820): When all axes are decreased, the output will be 1-D
# with FLAGS_set_to_1d=True.
if paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d']:
assert len(out_dims_mapping) == 1
if is_dim_shard(out_dims_mapping[0]):
return False
else:
assert len(out_dims_mapping) == 0
else:
for i in range(len(out_dims_mapping)):
ref_index = ref_indices[i]
......@@ -140,6 +149,11 @@ class DistributedSliceImpl(DistributedOperatorImpl):
ref_indices.append(i)
if ref_dims_mapping == []:
# NOTE(zoooo0820): When all axes are decreased, the output will be 1-D
# with FLAGS_set_to_1d=True.
if paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d']:
ref_dims_mapping = [-1]
assert ref_dims_mapping[0] == out_dims_mapping[0]
assert len(ref_dims_mapping) == len(out_dims_mapping)
changed = False
else:
......
......@@ -284,6 +284,15 @@ def is_integer_or_scalar_tensor(ele):
if isinstance(ele, int):
return True
elif isinstance(ele, Variable):
# NOTE(zoooo0820): For compatibility, if FLAGS_set_to_1d is set to True,
# 1-D tensor is still treated as a scalar, which means basic indexing.
# This will be removed in future.
if paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d']:
if len(ele.shape) == 1 and ele.shape[0] == 1:
warnings.warn(
"1-D Tensor will be treat as advanced indexing in future version. Currently, 1-D Tensor means a scalar, not vector, and please modify it to 0-D Tensor. If advanced indexing is needed, please use `export FLAGS_set_to_1d=False` to set the flag."
)
return True
if len(ele.shape) == 0:
return True
return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册