提交 35241df3 编写于 作者: N niuyazhe

feature(nyz): add vim in docker and add multiple seed cli

上级 58084df3
......@@ -3,7 +3,7 @@ FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime
WORKDIR /ding
RUN apt update \
&& apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git gcc \g++ make locales -y \
&& apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc \g++ make locales -y \
&& apt clean \
&& rm -rf /var/cache/apt/* \
&& sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen \
......
......@@ -414,6 +414,8 @@ def compile_config(
cfg.policy.eval.evaluator.n_episode = cfg.env.n_evaluator_episode
if 'exp_name' not in cfg:
cfg.exp_name = 'default_experiment'
# add seed as suffix of exp_name
cfg.exp_name = cfg.exp_name + '_seed{}'.format(seed)
if save_cfg:
if not os.path.exists(cfg.exp_name):
try:
......@@ -524,6 +526,8 @@ def compile_config_parallel(
cfg.system.coordinator = deep_merge_dicts(Coordinator.default_config(), cfg.system.coordinator)
# seed
cfg.seed = seed
# add seed as suffix of exp_name
cfg.exp_name = cfg.exp_name + '_seed{}'.format(seed)
if save_cfg:
save_config(cfg, save_path)
......
from typing import List, Union
import click
from click.core import Context, Option
import numpy as np
from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__
from .predefined_config import get_predefined_config
......@@ -65,7 +67,8 @@ CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
'-s',
'--seed',
type=int,
default=0,
default=[0],
multiple=True,
help='random generator seed(for all the possible package: random, numpy, torch and user env)'
)
@click.option('-e', '--env', type=str, help='RL env name')
......@@ -117,7 +120,7 @@ def cli(
# serial/eval
mode: str,
config: str,
seed: int,
seed: Union[int, List],
env: str,
policy: str,
train_iter: int,
......@@ -155,6 +158,8 @@ def cli(
from ..utils.profiler_helper import Profiler
profiler = Profiler()
profiler.profile(profile)
def run_single_pipeline(seed, config):
if mode == 'serial':
from .serial_entry import serial_pipeline
if config is None:
......@@ -212,8 +217,8 @@ def cli(
dist_add_replicas, dist_delete_replicas, dist_restart_replicas
if module == 'config':
dist_prepare_config(
config, seed, platform, coordinator_host, learner_host, collector_host, coordinator_port, learner_port,
collector_port
config, seed, platform, coordinator_host, learner_host, collector_host, coordinator_port,
learner_port, collector_port
)
elif module == 'coordinator':
dist_launch_coordinator(config, seed, coordinator_port, disable_flask_log)
......@@ -241,3 +246,10 @@ def cli(
if config is None:
config = get_predefined_config(env, policy)
eval(config, seed, load_path=load_path, replay_path=replay_path)
if isinstance(seed, (list, tuple)):
assert len(seed) > 0, "Please input at least 1 seed"
for s in seed:
run_single_pipeline(s, config)
else:
raise TypeError("invalid seed type: {}".format(type(seed)))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册