From bae722ecf56bdb6e1db9437fbad69e6189cacd49 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Wed, 12 Apr 2023 14:11:33 +0800 Subject: [PATCH] fixed import models bug --- x2paddle/project_convertor/pytorch/torch2paddle/varbase.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/x2paddle/project_convertor/pytorch/torch2paddle/varbase.py b/x2paddle/project_convertor/pytorch/torch2paddle/varbase.py index 166653e..b42648e 100644 --- a/x2paddle/project_convertor/pytorch/torch2paddle/varbase.py +++ b/x2paddle/project_convertor/pytorch/torch2paddle/varbase.py @@ -24,7 +24,7 @@ def is_condition_one(idx): a[mask, :] a[mask, ...] """ - if not (isinstance(idx[0], paddle.Tensor) and \ + if not (isinstance(idx[0], paddle.Tensor) and idx[0].dtype == paddle_dtypes.t_bool): return False if len(idx) == 1: @@ -94,6 +94,7 @@ def __getitem__(self, idx): else: return out + VarBase = core.eager.Tensor VarBase.__getitem__ = __getitem__ -- GitLab