From 6d62a71f7b7991b6acac77e19621768eac2930bf Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Thu, 30 Sep 2021 15:28:06 +0800 Subject: [PATCH] fix(nyz): fix pytorch1.9.0 compatibility bug and change naive buffer log freq --- ding/entry/tests/test_serial_entry_il.py | 2 +- ding/utils/data/collate_fn.py | 4 +++- ding/worker/replay_buffer/naive_buffer.py | 2 +- setup.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/ding/entry/tests/test_serial_entry_il.py b/ding/entry/tests/test_serial_entry_il.py index 50bd1522..2540d335 100644 --- a/ding/entry/tests/test_serial_entry_il.py +++ b/ding/entry/tests/test_serial_entry_il.py @@ -60,7 +60,7 @@ def test_serial_pipeline_il_ppo(): # il training 1 il_config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)] - il_config[0].policy.learn.train_epoch = 10 + il_config[0].policy.learn.train_epoch = 20 il_config[0].policy.type = 'ppo_il' _, converge_stop_flag = serial_pipeline_il(il_config, seed=314, data_path=expert_data_path) assert converge_stop_flag diff --git a/ding/utils/data/collate_fn.py b/ding/utils/data/collate_fn.py index 702a24fb..9e171a13 100644 --- a/ding/utils/data/collate_fn.py +++ b/ding/utils/data/collate_fn.py @@ -3,8 +3,10 @@ from typing import List, Dict, Union, Any import torch import re -from torch._six import container_abcs, string_classes, int_classes +from torch._six import string_classes +import collections.abc as container_abcs +int_classes = int np_str_obj_array_pattern = re.compile(r'[SaUO]') default_collate_err_msg_format = ( diff --git a/ding/worker/replay_buffer/naive_buffer.py b/ding/worker/replay_buffer/naive_buffer.py index 1318495b..993cfd01 100644 --- a/ding/worker/replay_buffer/naive_buffer.py +++ b/ding/worker/replay_buffer/naive_buffer.py @@ -77,7 +77,7 @@ class NaiveReplayBuffer(IBuffer): ) # Periodic thruput. Here by default, monitor range is 60 seconds. You can modify it for free. self._periodic_thruput_monitor = PeriodicThruputMonitor( - self._instance_name, EasyDict(seconds=3), self._logger, self._tb_logger + self._instance_name, EasyDict(seconds=60), self._logger, self._tb_logger ) def start(self) -> None: diff --git a/setup.py b/setup.py index 53abfecf..5167ca4f 100755 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ setup( 'requests>=2.25.1', 'six', 'gym>=0.20.0', # pypy incompatible - 'torch>=1.3.1,<=1.8.0', # PyTorch 1.9.0 is available, but you need to do something like https://github.com/opendilab/DI-engine/discussions/81 + 'torch>=1.3.1,<=1.9.0', # PyTorch 1.9.0 is available, if some errors, you need to do something like https://github.com/opendilab/DI-engine/discussions/81 'pyyaml', 'easydict==1.9', 'tensorboardX>=2.1,<=2.2', -- GitLab