diff --git a/x2paddle/project_convertor/pytorch/torch2paddle/varbase.py b/x2paddle/project_convertor/pytorch/torch2paddle/varbase.py index 166653ee1b7c90c6261041e883f6fc1efebe8e96..b42648e7772e73b48500e1d59580953239bcf3be 100644 --- a/x2paddle/project_convertor/pytorch/torch2paddle/varbase.py +++ b/x2paddle/project_convertor/pytorch/torch2paddle/varbase.py @@ -24,7 +24,7 @@ def is_condition_one(idx): a[mask, :] a[mask, ...] """ - if not (isinstance(idx[0], paddle.Tensor) and \ + if not (isinstance(idx[0], paddle.Tensor) and idx[0].dtype == paddle_dtypes.t_bool): return False if len(idx) == 1: @@ -94,6 +94,7 @@ def __getitem__(self, idx): else: return out + VarBase = core.eager.Tensor VarBase.__getitem__ = __getitem__