From 4157cdae95aa3cf87467c6733cd92aa82fa95249 Mon Sep 17 00:00:00 2001 From: Will-Nie <61083608+Will-Nie@users.noreply.github.com> Date: Thu, 25 Nov 2021 16:05:30 +0800 Subject: [PATCH] feature(nyp): add apple key to door treasure env(#128) * add apple key to door treasure and polish * add test, revise reward, build four envs * add 7x7-1 ADTKT --- dizoo/minigrid/__init__.py | 13 ++ dizoo/minigrid/config/minigrid_ngu_config.py | 22 +- .../minigrid/config/minigrid_onppo_config.py | 5 +- dizoo/minigrid/config/minigrid_r2d2_config.py | 16 +- .../config/minigrid_rnd_onppo_config.py | 11 +- dizoo/minigrid/envs/__init__.py | 1 + .../minigrid/envs/app_key_to_door_treasure.py | 220 ++++++++++++++++++ dizoo/minigrid/envs/minigrid_env.py | 150 +++++++++++- dizoo/minigrid/envs/test_minigrid_env.py | 74 ++++++ 9 files changed, 481 insertions(+), 31 deletions(-) create mode 100644 dizoo/minigrid/envs/app_key_to_door_treasure.py diff --git a/dizoo/minigrid/__init__.py b/dizoo/minigrid/__init__.py index e69de29..f2832cb 100644 --- a/dizoo/minigrid/__init__.py +++ b/dizoo/minigrid/__init__.py @@ -0,0 +1,13 @@ +from gym.envs.registration import register + +register(id='MiniGrid-AKTDT-7x7-1-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_7x7_1') + +register(id='MiniGrid-AKTDT-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure') + +register(id='MiniGrid-AKTDT-13x13-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_13x13') + +register(id='MiniGrid-AKTDT-13x13-1-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_13x13_1') + +register(id='MiniGrid-AKTDT-19x19-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_19x19') + +register(id='MiniGrid-AKTDT-19x19-3-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_19x19_3') \ No newline at end of file diff --git a/dizoo/minigrid/config/minigrid_ngu_config.py b/dizoo/minigrid/config/minigrid_ngu_config.py index 7f75f31..5f1ccf4 100644 --- a/dizoo/minigrid/config/minigrid_ngu_config.py +++ b/dizoo/minigrid/config/minigrid_ngu_config.py @@ -4,7 +4,7 @@ from easydict import EasyDict from ding.entry import serial_pipeline_reward_model_ngu print(torch.cuda.is_available(), torch.__version__) -collector_env_num = 32 #TODO +collector_env_num = 32 #TODO evaluator_env_num = 5 nstep = 5 minigrid_ppo_rnd_config = dict( @@ -25,15 +25,13 @@ minigrid_ppo_rnd_config = dict( learning_rate=5e-4, obs_shape=2739, action_shape=7, - batch_size=320, # transitions - + batch_size=320, # transitions update_per_collect=int(10), # 32*100/320=10 only_use_last_five_frames_for_icm_rnd=False, clear_buffer_per_iters=10, nstep=nstep, hidden_size_list=[128, 128, 64], type='rnd', - ), episodic_reward_model=dict( intrinsic_reward_type='add', @@ -41,7 +39,6 @@ minigrid_ppo_rnd_config = dict( obs_shape=2739, action_shape=7, batch_size=320, # transitions - update_per_collect=int(10), # 32*100/64=50 only_use_last_five_frames_for_icm_rnd=False, clear_buffer_per_iters=10, @@ -84,12 +81,13 @@ minigrid_ppo_rnd_config = dict( end=0.05, decay=1e5, ), - replay_buffer=dict(replay_buffer_size=30000, - # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization - alpha=0.6, - # (Float type) How much correction is used: 0 means no correction while 1 means full correction - beta=0.4, - ) + replay_buffer=dict( + replay_buffer_size=30000, + # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization + alpha=0.6, + # (Float type) How much correction is used: 0 means no correction while 1 means full correction + beta=0.4, + ) ), ), ) @@ -105,7 +103,7 @@ minigrid_ppo_rnd_create_config = dict( policy=dict(type='ngu'), rnd_reward_model=dict(type='rnd'), episodic_reward_model=dict(type='episodic'), - collector=dict(type='sample_ngu',) + collector=dict(type='sample_ngu', ) ) minigrid_ppo_rnd_create_config = EasyDict(minigrid_ppo_rnd_create_config) create_config = minigrid_ppo_rnd_create_config diff --git a/dizoo/minigrid/config/minigrid_onppo_config.py b/dizoo/minigrid/config/minigrid_onppo_config.py index 631295c..19e0073 100644 --- a/dizoo/minigrid/config/minigrid_onppo_config.py +++ b/dizoo/minigrid/config/minigrid_onppo_config.py @@ -6,7 +6,6 @@ minigrid_ppo_config = dict( exp_name="minigrid_fourrooms_onppo", # exp_name="minigrid_doorkey88_onppo", # exp_name="minigrid_doorkey_onppo", - env=dict( collector_env_num=8, evaluator_env_num=5, @@ -36,11 +35,11 @@ minigrid_ppo_config = dict( entropy_weight=0.001, clip_ratio=0.2, adv_norm=True, - value_norm=True, + value_norm=True, ), collect=dict( collector_env_num=collector_env_num, - n_sample=int(3200), + n_sample=int(3200), # here self.traj_length = 3200//8 = 400, because in minigrid env the max_length is 300. # in ding/worker/collector/sample_serial_collector.py # self._traj_len = max( diff --git a/dizoo/minigrid/config/minigrid_r2d2_config.py b/dizoo/minigrid/config/minigrid_r2d2_config.py index a9b97ce..e0ab553 100644 --- a/dizoo/minigrid/config/minigrid_r2d2_config.py +++ b/dizoo/minigrid/config/minigrid_r2d2_config.py @@ -22,7 +22,6 @@ minigrid_r2d2_config = dict( obs_shape=2739, action_shape=7, encoder_hidden_size_list=[128, 128, 512], - ), discount_factor=0.997, burnin_step=2, # TODO(pu) 20 @@ -54,13 +53,14 @@ minigrid_r2d2_config = dict( start=0.95, end=0.05, decay=1e5, - ), - replay_buffer=dict(replay_buffer_size=100000, - # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization - alpha=0.6, - # (Float type) How much correction is used: 0 means no correction while 1 means full correction - beta=0.4, - ) + ), + replay_buffer=dict( + replay_buffer_size=100000, + # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization + alpha=0.6, + # (Float type) How much correction is used: 0 means no correction while 1 means full correction + beta=0.4, + ) ), ), ) diff --git a/dizoo/minigrid/config/minigrid_rnd_onppo_config.py b/dizoo/minigrid/config/minigrid_rnd_onppo_config.py index 9d3f862..d96b8f5 100644 --- a/dizoo/minigrid/config/minigrid_rnd_onppo_config.py +++ b/dizoo/minigrid/config/minigrid_rnd_onppo_config.py @@ -1,8 +1,8 @@ from easydict import EasyDict from ding.entry import serial_pipeline_reward_model_onpolicy import torch -print(torch.__version__,torch.cuda.is_available()) -collector_env_num=8 +print(torch.__version__, torch.cuda.is_available()) +collector_env_num = 8 minigrid_ppo_rnd_config = dict( # exp_name='minigrid_empty8_rnd_onppo_b01_weight1000_maxlen100', # exp_name='minigrid_fourrooms_rnd_onppo_b01_weight1000_maxlen100', @@ -10,7 +10,6 @@ minigrid_ppo_rnd_config = dict( # exp_name='minigrid_doorkey_rnd_onppo_b01_weight1000_maxlen300', # exp_name='minigrid_kcs3r3_rnd_onppo_b01', # exp_name='minigrid_om2dlh_rnd_onppo_b01', - env=dict( collector_env_num=collector_env_num, evaluator_env_num=5, @@ -54,14 +53,14 @@ minigrid_ppo_rnd_config = dict( batch_size=320, # 64, learning_rate=3e-4, value_weight=0.5, - entropy_weight=0.001, + entropy_weight=0.001, clip_ratio=0.2, adv_norm=True, - value_norm=True, + value_norm=True, ), collect=dict( collector_env_num=collector_env_num, - n_sample=int(3200), + n_sample=int(3200), # here self.traj_length = 3200//8 = 400, because in minigrid env the max_length is 300. # in ding/worker/collector/sample_serial_collector.py # self._traj_len = max( diff --git a/dizoo/minigrid/envs/__init__.py b/dizoo/minigrid/envs/__init__.py index 9f8da54..8db2e59 100644 --- a/dizoo/minigrid/envs/__init__.py +++ b/dizoo/minigrid/envs/__init__.py @@ -1 +1,2 @@ from .minigrid_env import MiniGridEnv +from dizoo.minigrid.envs.app_key_to_door_treasure import AppleKeyToDoorTreasure, AppleKeyToDoorTreasure_13x13, AppleKeyToDoorTreasure_19x19, AppleKeyToDoorTreasure_13x13_1, AppleKeyToDoorTreasure_19x19_3, AppleKeyToDoorTreasure_7x7_1 \ No newline at end of file diff --git a/dizoo/minigrid/envs/app_key_to_door_treasure.py b/dizoo/minigrid/envs/app_key_to_door_treasure.py new file mode 100644 index 0000000..46c3831 --- /dev/null +++ b/dizoo/minigrid/envs/app_key_to_door_treasure.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from gym_minigrid.minigrid import * +from gym_minigrid.minigrid import WorldObj + + +class Ball(WorldObj): + + def __init__(self, color='blue'): + super(Ball, self).__init__('ball', color) + + def can_pickup(self): + return False + + def render(self, img): + fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color]) + + +class AppleKeyToDoorTreasure(MiniGridEnv): + """ + Classic 4 rooms gridworld environment. + Can specify agent and goal position, if not it set at random. + """ + + def __init__(self, agent_pos=None, goal_pos=None, grid_size=19, apple=2): + self._agent_default_pos = agent_pos + self._goal_default_pos = goal_pos + self.apple = apple + super().__init__(grid_size=grid_size, max_steps=100) + + def _gen_grid( + self, width, height + ): # Note that it is inherited from MiniGridEnv that if width and height == None, width = grid_size , height = grid_size + # Create the grid + self.grid = Grid(width, height) + + # Generate the surrounding walls + self.grid.horz_wall(0, 0) + self.grid.horz_wall(0, height - 1) + self.grid.vert_wall(0, 0) + self.grid.vert_wall(width - 1, 0) + + room_w = width // 2 + room_h = height // 2 + + # For each row of rooms + for j in range(0, 2): + + # For each column + for i in range(0, 2): + xL = i * room_w + yT = j * room_h + xR = xL + room_w + yB = yT + room_h + + # Bottom wall and door + if i + 1 < 2: + if j + 1 < 2: + self.grid.vert_wall(xR, yT, room_h) + #pos = (xR, self._rand_int(yT + 1, yB)) + else: + self.grid.vert_wall(xR, yT, room_h) + pos = (xR, self._rand_int(yT + 1, yB)) + self.grid.set(*pos, None) + + # Bottom wall and door + if j + 1 < 2: + if i + 1 < 2: + self.grid.horz_wall(xL, yB, room_w) + pos = (self._rand_int(xL + 1, xR), yB) + self.grid.set(*pos, None) + else: + self.grid.horz_wall(xL, yB, room_w) + pos = (self._rand_int(xL + 1, xR), yB) + self.put_obj(Door('yellow', is_locked=True), *pos) + + # Place a yellow key on the left side + pos1 = (self._rand_int(room_w + 1, 2 * room_w), self._rand_int(room_h + 1, 2 * room_h)) # self._rand_int: [) + self.put_obj(Key('yellow'), *pos1) + pos2_dummy_list = [] # to avoid overlap of apples + for i in range(self.apple): + pos2 = (self._rand_int(1, room_w), self._rand_int(1, room_h)) + while pos2 in pos2_dummy_list: + pos2 = (self._rand_int(1, room_w), self._rand_int(1, room_h)) + self.put_obj(Ball('red'), *pos2) + pos2_dummy_list.append(pos2) + # Randomize the player start position and orientation + if self._agent_default_pos is not None: + self.agent_pos = self._agent_default_pos + self.grid.set(*self._agent_default_pos, None) + self.agent_dir = self._rand_int(0, 4) # assuming random start direction + else: + self.place_agent() + + if self._goal_default_pos is not None: + goal = Goal() + self.put_obj(goal, *self._goal_default_pos) + goal.init_pos, goal.cur_pos = self._goal_default_pos + else: + self.place_obj(Goal()) + + self.mission = 'Reach the goal' + + def _reward_ball(self): + """ + Compute the reward to be given upon finding the apple + """ + + return 1 + + def _reward_goal(self): + """ + Compute the reward to be given upon success + """ + + return 10 + + def step(self, action): + self.step_count += 1 + + reward = 0 + done = False + + # Get the position in front of the agent + fwd_pos = self.front_pos + + # Get the contents of the cell in front of the agent + fwd_cell = self.grid.get(*fwd_pos) + + # Rotate left + if action == self.actions.left: + self.agent_dir -= 1 + if self.agent_dir < 0: + self.agent_dir += 4 + + # Rotate right + elif action == self.actions.right: + self.agent_dir = (self.agent_dir + 1) % 4 + + # Move forward + elif action == self.actions.forward: + if fwd_cell == None or fwd_cell.can_overlap(): # Ball and keys' can_overlap are False + self.agent_pos = fwd_pos + if fwd_cell != None and fwd_cell.type == 'goal': + done = True + reward = self._reward_goal() + if fwd_cell != None and fwd_cell.type == 'ball': + reward = self._reward_ball() + self.grid.set(*fwd_pos, None) + self.agent_pos = fwd_pos + if fwd_cell != None and fwd_cell.type == 'lava': + done = True + + # Pick up an object + elif action == self.actions.pickup: + if fwd_cell and fwd_cell.can_pickup(): + if self.carrying is None: + self.carrying = fwd_cell + self.carrying.cur_pos = np.array([-1, -1]) + self.grid.set(*fwd_pos, None) + + # Drop an object + elif action == self.actions.drop: + if not fwd_cell and self.carrying: + self.grid.set(*fwd_pos, self.carrying) + self.carrying.cur_pos = fwd_pos + self.carrying = None + + # Toggle/activate an object: Here, this will open the door if you have the right key + elif action == self.actions.toggle: + if fwd_cell: + fwd_cell.toggle(self, fwd_pos) + + # Done action (not used by default) + elif action == self.actions.done: + pass + + else: + assert False, "unknown action" + + if self.step_count >= self.max_steps: + done = True + + obs = self.gen_obs() + + return obs, reward, done, {} + + +class AppleKeyToDoorTreasure_13x13(AppleKeyToDoorTreasure): + + def __init__(self): + super().__init__(agent_pos=(2, 8), goal_pos=(7, 1), grid_size=13, apple=2) + + +class AppleKeyToDoorTreasure_19x19(AppleKeyToDoorTreasure): + + def __init__(self): + super().__init__(agent_pos=(2, 14), goal_pos=(10, 1), grid_size=19, apple=2) + + +class AppleKeyToDoorTreasure_13x13_1(AppleKeyToDoorTreasure): + + def __init__(self): + super().__init__(agent_pos=(2, 8), goal_pos=(7, 1), grid_size=13, apple=1) + + +class AppleKeyToDoorTreasure_7x7_1(AppleKeyToDoorTreasure): + + def __init__(self): + super().__init__(agent_pos=(1, 5), goal_pos=(4, 1), grid_size=7, apple=1) + +class AppleKeyToDoorTreasure_19x19_3(AppleKeyToDoorTreasure): + + def __init__(self): + super().__init__(agent_pos=(2, 14), goal_pos=(10, 1), grid_size=19, apple=3) + + +if __name__ == '__main__': + AppleKeyToDoorTreasure()._gen_grid(13, 13) # Note that Minigrid has set seeds automatically diff --git a/dizoo/minigrid/envs/minigrid_env.py b/dizoo/minigrid/envs/minigrid_env.py index e9e130e..a3d7634 100644 --- a/dizoo/minigrid/envs/minigrid_env.py +++ b/dizoo/minigrid/envs/minigrid_env.py @@ -8,7 +8,7 @@ import gym import numpy as np from matplotlib import animation import matplotlib.pyplot as plt -from gym_minigrid.wrappers import FlatObsWrapper, RGBImgPartialObsWrapper, ImgObsWrapper +from gym_minigrid.wrappers import FlatObsWrapper, RGBImgPartialObsWrapper, ImgObsWrapper, ViewSizeWrapper from gym_minigrid.window import Window from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo @@ -89,6 +89,144 @@ MINIGRID_INFO_DICT = { max_step=100, use_wrappers=None, ), + 'MiniGrid-AKTDT-v0': MiniGridEnvInfo( + agent_num=1, + obs_space=EnvElementInfo(shape=(2739, ), value={ + 'min': 0, + 'max': 8, + 'dtype': np.float32 + }), + act_space=EnvElementInfo( + shape=(1, ), + value={ + 'min': 0, + 'max': 7, # [0, 7) + 'dtype': np.int64, + } + ), + rew_space=EnvElementInfo(shape=(1, ), value={ + 'min': 0, + 'max': 1, + 'dtype': np.float32 + }), + max_step=500, + use_wrappers=None, + ), + 'MiniGrid-AKTDT-7x7-1-v0': MiniGridEnvInfo( + agent_num=1, + obs_space=EnvElementInfo(shape=(2619, ), value={ + 'min': 0, + 'max': 8, + 'dtype': np.float32 + }), + act_space=EnvElementInfo( + shape=(1, ), + value={ + 'min': 0, + 'max': 7, # [0, 7) + 'dtype': np.int64, + } + ), + rew_space=EnvElementInfo(shape=(1, ), value={ + 'min': 0, + 'max': 1, + 'dtype': np.float32 + }), + max_step=500, + use_wrappers=None, + ), + 'MiniGrid-AKTDT-13x13-v0': MiniGridEnvInfo( + agent_num=1, + obs_space=EnvElementInfo(shape=(2667, ), value={ + 'min': 0, + 'max': 8, + 'dtype': np.float32 + }), + act_space=EnvElementInfo( + shape=(1, ), + value={ + 'min': 0, + 'max': 7, # [0, 7) + 'dtype': np.int64, + } + ), + rew_space=EnvElementInfo(shape=(1, ), value={ + 'min': 0, + 'max': 1, + 'dtype': np.float32 + }), + max_step=500, + use_wrappers=None, + ), + 'MiniGrid-AKTDT-13x13-1-v0': MiniGridEnvInfo( + agent_num=1, + obs_space=EnvElementInfo(shape=(2667, ), value={ + 'min': 0, + 'max': 8, + 'dtype': np.float32 + }), + act_space=EnvElementInfo( + shape=(1, ), + value={ + 'min': 0, + 'max': 7, # [0, 7) + 'dtype': np.int64, + } + ), + rew_space=EnvElementInfo(shape=(1, ), value={ + 'min': 0, + 'max': 1, + 'dtype': np.float32 + }), + max_step=500, + use_wrappers=None, + ), + 'MiniGrid-AKTDT-19x19-v0': MiniGridEnvInfo( + agent_num=1, + obs_space=EnvElementInfo(shape=(2739, ), value={ + 'min': 0, + 'max': 8, + 'dtype': np.float32 + }), + act_space=EnvElementInfo( + shape=(1, ), + value={ + 'min': 0, + 'max': 7, # [0, 7) + 'dtype': np.int64, + } + ), + rew_space=EnvElementInfo(shape=(1, ), value={ + 'min': 0, + 'max': 1, + 'dtype': np.float32 + }), + max_step=500, + use_wrappers=None, + ), + 'MiniGrid-AKTDT-19x19-3-v0': MiniGridEnvInfo( + agent_num=1, + obs_space=EnvElementInfo(shape=(2739, ), value={ + 'min': 0, + 'max': 8, + 'dtype': np.float32 + }), + act_space=EnvElementInfo( + shape=(1, ), + value={ + 'min': 0, + 'max': 7, # [0, 7) + 'dtype': np.int64, + } + ), + rew_space=EnvElementInfo(shape=(1, ), value={ + 'min': 0, + 'max': 1, + 'dtype': np.float32 + }), + max_step=500, + use_wrappers=None, + ), 'MiniGrid-DoorKey-8x8-v0': MiniGridEnvInfo( agent_num=1, obs_space=EnvElementInfo(shape=(2739, ), value={ @@ -231,6 +369,14 @@ class MiniGridEnv(BaseEnv): def reset(self) -> np.ndarray: if not self._init_flag: self._env = gym.make(self._env_id) + if self._env_id in ['MiniGrid-AKTDT-13x13-v0' or 'MiniGrid-AKTDT-13x13-1-v0']: + self._env = ViewSizeWrapper( + self._env, agent_view_size=5 + ) # customize the agent field of view size, note this must be an odd number # This also related to the observation space, see gym_minigrid.wrappers for more details + if self._env_id == 'MiniGrid-AKTDT-7x7-1-v0': + self._env = ViewSizeWrapper( + self._env, agent_view_size=3 + ) if self._flat_obs: self._env = FlatObsWrapper(self._env) # self._env = RGBImgPartialObsWrapper(self._env) @@ -282,7 +428,7 @@ class MiniGridEnv(BaseEnv): self.display_frames_as_gif(self._frames, path) self._save_replay_count += 1 obs = to_ndarray(obs).astype(np.float32) - rew = to_ndarray([rew]) # wrapped to be transfered to a array with shape (1,) + rew = to_ndarray([rew]) # wrapped to be transferred to a array with shape (1,) return BaseEnvTimestep(obs, rew, done, info) def info(self) -> MiniGridEnvInfo: diff --git a/dizoo/minigrid/envs/test_minigrid_env.py b/dizoo/minigrid/envs/test_minigrid_env.py index a76e2b0..1748ef0 100644 --- a/dizoo/minigrid/envs/test_minigrid_env.py +++ b/dizoo/minigrid/envs/test_minigrid_env.py @@ -2,7 +2,23 @@ import pytest import os import numpy as np from dizoo.minigrid.envs import MiniGridEnv +from easydict import EasyDict +import copy +# The following two cfg can be tested through TestMiniGridAKTDTnv +config = dict( + env_id='MiniGrid-AKTDT-13x13-v0', + flat_obs=True, +) +cfg = EasyDict(copy.deepcopy(config)) +cfg.cfg_type = 'MiniGridEnvDict' + +config2 = dict( + env_id='MiniGrid-AKTDT-7x7-1-v0', + flat_obs=True, +) +cfg2 = EasyDict(copy.deepcopy(config2)) +cfg2.cfg_type = 'MiniGridEnvDict' @pytest.mark.envtest class TestMiniGridEnv: @@ -33,3 +49,61 @@ class TestMiniGridEnv: env.reset() print(env.info()) env.close() + + +@pytest.mark.envtest +class TestMiniGridAKTDTnv: + + def test_adtkt_13(self): + env = MiniGridEnv(cfg2) + env.seed(314) + path = './video' + if not os.path.exists(path): + os.mkdir(path) + env.enable_save_replay(path) + assert env._seed == 314 + obs = env.reset() + act_val = env.info().act_space.value + min_val, max_val = act_val['min'], act_val['max'] + for i in range(env._max_step): + random_action = np.random.randint(min_val, max_val, size=(1, )) + timestep = env.step(random_action) + print(timestep) + print(timestep.obs.max()) + assert isinstance(timestep.obs, np.ndarray) + assert isinstance(timestep.done, bool) + assert timestep.obs.shape == (2667, ) + assert timestep.reward.shape == (1, ) + assert timestep.reward >= env.info().rew_space.value['min'] + assert timestep.reward <= env.info().rew_space.value['max'] + if timestep.done: + env.reset() + print(env.info()) + env.close() + + def test_adtkt_7(self): + env = MiniGridEnv(cfg2) + env.seed(314) + path = './video' + if not os.path.exists(path): + os.mkdir(path) + env.enable_save_replay(path) + assert env._seed == 314 + obs = env.reset() + act_val = env.info().act_space.value + min_val, max_val = act_val['min'], act_val['max'] + for i in range(env._max_step): + random_action = np.random.randint(min_val, max_val, size=(1, )) + timestep = env.step(random_action) + print(timestep) + print(timestep.obs.max()) + assert isinstance(timestep.obs, np.ndarray) + assert isinstance(timestep.done, bool) + assert timestep.obs.shape == (2619, ) + assert timestep.reward.shape == (1, ) + assert timestep.reward >= env.info().rew_space.value['min'] + assert timestep.reward <= env.info().rew_space.value['max'] + if timestep.done: + env.reset() + print(env.info()) + env.close() -- GitLab