diff --git a/ding/__init__.py b/ding/__init__.py index 629deae22a585ec78bff333f107afd98afe1e912..065366d27489b4fe7cc641a01091d042db18c8f9 100644 --- a/ding/__init__.py +++ b/ding/__init__.py @@ -1,5 +1,4 @@ import os -import torch __TITLE__ = 'DI-engine' __VERSION__ = 'v0.2.0' @@ -11,7 +10,3 @@ __version__ = __VERSION__ enable_hpc_rl = False enable_linklink = os.environ.get('ENABLE_LINKLINK', 'false').lower() == 'true' enable_numba = True - - -def torch_gt_131(): - return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 131 diff --git a/ding/compatibility.py b/ding/compatibility.py new file mode 100644 index 0000000000000000000000000000000000000000..086512af94f5a0a55311ad44e350407452f97553 --- /dev/null +++ b/ding/compatibility.py @@ -0,0 +1,5 @@ +import torch + + +def torch_gt_131(): + return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 131 diff --git a/ding/torch_utils/network/nn_module.py b/ding/torch_utils/network/nn_module.py index 76cb1e3644f8d6cf1de1afd46c3c44e94fb6f886..c7298c5b3d5eceb8144e99d9e9659e7a08056748 100644 --- a/ding/torch_utils/network/nn_module.py +++ b/ding/torch_utils/network/nn_module.py @@ -4,7 +4,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.nn.init import xavier_normal_, kaiming_normal_, orthogonal_ from typing import Union, Tuple, List, Callable -from ding import torch_gt_131 +from ding.compatibility import torch_gt_131 from .normalization import build_normalization diff --git a/ding/utils/data/collate_fn.py b/ding/utils/data/collate_fn.py index 57052b22316c086e2dc75846c0e667517c43f3ff..e871b91f39aae75e13fc6217b6b096e37b61bc30 100644 --- a/ding/utils/data/collate_fn.py +++ b/ding/utils/data/collate_fn.py @@ -5,7 +5,7 @@ import torch import re from torch._six import string_classes import collections.abc as container_abcs -from ding import torch_gt_131 +from ding.compatibility import torch_gt_131 int_classes = int np_str_obj_array_pattern = re.compile(r'[SaUO]')