Commit 6d62a71f authored by niuyazhe

fix(nyz): fix pytorch1.9.0 compatibility bug and change naive buffer log freq

Parent 21767320
@@ -60,7 +60,7 @@ def test_serial_pipeline_il_ppo():
     # il training 1
     il_config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
-    il_config[0].policy.learn.train_epoch = 10
+    il_config[0].policy.learn.train_epoch = 20
     il_config[0].policy.type = 'ppo_il'
     _, converge_stop_flag = serial_pipeline_il(il_config, seed=314, data_path=expert_data_path)
     assert converge_stop_flag
......
@@ -3,8 +3,10 @@ from typing import List, Dict, Union, Any
 import torch
 import re
-from torch._six import container_abcs, string_classes, int_classes
+from torch._six import string_classes
+import collections.abc as container_abcs
+int_classes = int
 np_str_obj_array_pattern = re.compile(r'[SaUO]')
 default_collate_err_msg_format = (
......
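Note: PyTorch 1.9.0 removed `container_abcs` and `int_classes` from `torch._six`, which is what breaks the import in the collate helpers above. A version-tolerant alternative to hard-coding the 1.9 layout is a small import guard like the sketch below; the try/except structure and the `(str, bytes)` fallback are illustrative, not the code this commit adds.

```python
# Sketch of a version-tolerant replacement for the torch._six imports.
# Assumption: only the standard library and torch itself are available.
import collections.abc

try:
    # PyTorch <= 1.8.x exposes these helpers in torch._six.
    from torch._six import container_abcs, int_classes, string_classes
except ImportError:
    # PyTorch >= 1.9.0 dropped container_abcs and int_classes,
    # so fall back to the standard library equivalents.
    container_abcs = collections.abc
    int_classes = int
    try:
        from torch._six import string_classes
    except ImportError:
        # Newer releases removed torch._six entirely.
        string_classes = (str, bytes)
```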
@@ -77,7 +77,7 @@ class NaiveReplayBuffer(IBuffer):
         )
         # Periodic thruput. Here by default, monitor range is 60 seconds. You can modify it for free.
         self._periodic_thruput_monitor = PeriodicThruputMonitor(
-            self._instance_name, EasyDict(seconds=3), self._logger, self._tb_logger
+            self._instance_name, EasyDict(seconds=60), self._logger, self._tb_logger
         )

     def start(self) -> None:
......
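For context on what the `seconds` value controls: the monitor is constructed with a window length in seconds, so raising it from 3 to 60 reduces how often throughput statistics are written to the logger and tb_logger roughly 20-fold. The toy class below is only a conceptual sketch of such a periodic throughput monitor, not DI-engine's `PeriodicThruputMonitor`.

```python
import time


class ThroughputSketch:
    """Conceptual sketch of a periodic throughput monitor: count events and
    report a rate once per window. Not DI-engine's PeriodicThruputMonitor."""

    def __init__(self, name: str, seconds: float = 60) -> None:
        self._name = name
        self._seconds = seconds
        self._count = 0
        self._window_start = time.time()

    def record(self, n: int = 1) -> None:
        # Accumulate events; flush a rate line whenever the window elapses.
        self._count += n
        now = time.time()
        elapsed = now - self._window_start
        if elapsed >= self._seconds:
            print(f'[{self._name}] throughput: {self._count / elapsed:.2f} items/s over {elapsed:.0f}s')
            self._count = 0
            self._window_start = now
```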
@@ -48,7 +48,7 @@ setup(
         'requests>=2.25.1',
         'six',
         'gym>=0.20.0', # pypy incompatible
-        'torch>=1.3.1,<=1.8.0', # PyTorch 1.9.0 is available, but you need to do something like https://github.com/opendilab/DI-engine/discussions/81
+        'torch>=1.3.1,<=1.9.0', # PyTorch 1.9.0 is available; if errors occur, you need to do something like https://github.com/opendilab/DI-engine/discussions/81
         'pyyaml',
         'easydict==1.9',
         'tensorboardX>=2.1,<=2.2',
......
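Since the commit only relaxes the install-time bound, a runtime check can still be useful when an environment already has a newer wheel installed. A minimal sketch, assuming the third-party `packaging` helper is available (it ships with modern pip/setuptools but is not a DI-engine dependency):

```python
# Rough runtime guard mirroring the 'torch>=1.3.1,<=1.9.0' bound from setup.py.
# Assumption: the 'packaging' module is importable.
import torch
from packaging import version

_torch_version = version.parse(torch.__version__.split('+')[0])  # drop local tags like '+cu111'
if not (version.parse('1.3.1') <= _torch_version <= version.parse('1.9.0')):
    raise RuntimeError(
        f'PyTorch {torch.__version__} is outside the tested range 1.3.1-1.9.0; '
        'see https://github.com/opendilab/DI-engine/discussions/81 for workarounds.'
    )
```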