Commit 171dddc4 authored by niuyazhe

polish(nyz): add torch1.1.0 compatibility for torch.utils.data

Parent 8df82e01
@@ -4,6 +4,7 @@ import sys
 import traceback
 from typing import Callable
 import torch
+import torch.utils.data  # torch1.1.0 compatibility
 from ding.utils import read_file, save_file
 logger = logging.getLogger('default_logger')
......
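The commit's own comment indicates that on torch 1.1.0 the `torch.utils.data` submodule must be imported explicitly before its attributes (such as `get_worker_info`) can be used; a bare `import torch` is not enough on that release. A minimal sketch of that access pattern, outside the commit, using the hypothetical helper name `get_worker_info_safe`:

```python
import torch
# Explicit submodule import: per the commit's comment, torch 1.1.0 does not
# expose `torch.utils.data` after a bare `import torch` alone.
import torch.utils.data


def get_worker_info_safe():
    """Hypothetical helper: return DataLoader worker info if the API exists, else None."""
    get_info = getattr(torch.utils.data, 'get_worker_info', None)
    return get_info() if get_info is not None else None


if __name__ == '__main__':
    # Called from the main process (not a DataLoader worker), this prints None.
    print(get_worker_info_safe())
```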
@@ -15,6 +15,10 @@ default_collate_err_msg_format = (
 )
+def torch_gt_131():
+    return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 131
 def default_collate(batch: Sequence,
                     cat_1dim: bool = True,
                     ignore_prefix: list = ['collate_ignore']) -> Union[torch.Tensor, Mapping, Sequence]:
@@ -50,7 +54,7 @@ def default_collate(batch: Sequence,
     elem_type = type(elem)
     if isinstance(elem, torch.Tensor):
         out = None
-        if torch.utils.data.get_worker_info() is not None:
+        if torch_gt_131() and torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, directly concatenate into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
......
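Pulling the two changes together, here is a self-contained sketch of the guarded collate path. It assumes only that the version check mirrors `torch_gt_131` above; `collate_tensors` is an illustrative name rather than part of the ding codebase, and the shared-memory branch follows the familiar upstream `default_collate` pattern, not necessarily the exact code elided from this hunk.

```python
from typing import Sequence

import torch
import torch.utils.data  # torch1.1.0 compatibility


def torch_gt_131() -> bool:
    # Strip non-digits from the version string and compare, e.g. '1.3.1' -> 131.
    return int("".join(filter(str.isdigit, torch.__version__))) >= 131


def collate_tensors(batch: Sequence[torch.Tensor]) -> torch.Tensor:
    # Illustrative collate: only touch get_worker_info on torch versions that
    # are new enough to provide it, otherwise fall back to a plain stack.
    out = None
    if torch_gt_131() and torch.utils.data.get_worker_info() is not None:
        # Inside a DataLoader worker process, write directly into shared memory
        # so the batch is not copied again when sent back to the main process.
        numel = sum(x.numel() for x in batch)
        storage = batch[0].storage()._new_shared(numel)
        out = batch[0].new(storage)
    return torch.stack(batch, 0, out=out)


if __name__ == '__main__':
    print(collate_tensors([torch.randn(3) for _ in range(4)]).shape)  # torch.Size([4, 3])
```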
@@ -74,13 +74,6 @@ def main(cfg, seed=0):
             if train_data is None:
                 break
             learner.train(train_data, collector.envstep)
-    # evaluate
-    evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
-    evaluator_env.enable_save_replay(cfg.env.replay_path)  # switch save replay interface
-    evaluator = InteractionSerialEvaluator(
-        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
-    )
-    evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if __name__ == "__main__":
......