未验证 提交 d1188d71 编写于 作者: J jayyoung0802 提交者: GitHub

polish(yzj): add DataParallel and DataDistributedParallel (#123)

* add spaceinvaders multi gpu

* add dp and ddp

* Update __init__.py

* recover init
上级 cbee45b4
......@@ -8,3 +8,4 @@ from .network import *
from .optimizer_helper import Adam, RMSprop
from .nn_test_helper import is_differentiable
from .math_helper import cov
from .dataparallel import DataParallel
import torch
import torch.nn as nn
class DataParallel(nn.DataParallel):
def __init__(self, module, device_ids=None, output_device=None, dim=0):
super().__init__(module,device_ids=None, output_device=None, dim=0)
self.module = module
def parameters(self, recurse: bool = True):
return self.module.parameters(recurse = True)
from copy import deepcopy
from ding.entry import serial_pipeline
from easydict import EasyDict
from ding.utils import DistContext
space_invaders_dqn_config = dict(
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=10000000000,
env_id='SpaceInvadersNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=False, )
),
policy=dict(
cuda=True,
priority=False,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=3,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
multi_gpu=True,
),
collect=dict(n_sample=100, ),
eval=dict(evaluator=dict(eval_freq=4000, )),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=1000000,
),
replay_buffer=dict(replay_buffer_size=400000, ),
),
),
)
space_invaders_dqn_config = EasyDict(space_invaders_dqn_config)
main_config = space_invaders_dqn_config
space_invaders_dqn_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqn'),
)
space_invaders_dqn_create_config = EasyDict(space_invaders_dqn_create_config)
create_config = space_invaders_dqn_create_config
if __name__ == '__main__':
with DistContext():
serial_pipeline((main_config, create_config), seed=0)
from copy import deepcopy
from ding.entry import serial_pipeline
from easydict import EasyDict
from ding.model.template.q_learning import DQN
from ding.torch_utils import DataParallel
space_invaders_dqn_config = dict(
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=10000000000,
env_id='SpaceInvadersNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=False, )
),
policy=dict(
cuda=True,
priority=False,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=3,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
),
collect=dict(n_sample=100, ),
eval=dict(evaluator=dict(eval_freq=4000, )),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=1000000,
),
replay_buffer=dict(replay_buffer_size=400000, ),
),
),
)
space_invaders_dqn_config = EasyDict(space_invaders_dqn_config)
main_config = space_invaders_dqn_config
space_invaders_dqn_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqn'),
)
space_invaders_dqn_create_config = EasyDict(space_invaders_dqn_create_config)
create_config = space_invaders_dqn_create_config
if __name__ == '__main__':
model = DataParallel(DQN(obs_shape=[4, 84, 84],action_shape=6))
serial_pipeline((main_config, create_config), seed=0, model=model)
import os
import torch
from tensorboardX import SummaryWriter
from ding.config import compile_config
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
from ding.policy import DQNPolicy
from ding.model import DQN
from ding.utils import set_pkg_seed
from ding.rl_utils import get_epsilon_greedy_fn
from dizoo.atari.config.serial.spaceinvaders.spaceinvaders_dqn_config_multi_gpu_ddp import space_invaders_dqn_config,create_config
from ding.utils import DistContext
from functools import partial
from ding.envs import get_vec_env_setting, create_env_manager
def main(cfg, create_cfg, seed=0):
cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
# Create main components: env, policy
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
# Set random seed for all package and instance
collector_env.seed(seed)
evaluator_env.seed(seed, dynamic_seed=False)
set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
# Set up RL Policy
model = DQN(**cfg.policy.model)
policy = DQNPolicy(cfg.policy, model=model)
# Set up collection, training and evaluation utilities
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = SampleSerialCollector(
cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
# Set up other modules, etc. epsilon greedy
eps_cfg = cfg.policy.other.eps
epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
# Training & Evaluation loop
while True:
# Evaluating at the beginning and with specific frequency
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
# Update other modules
eps = epsilon_greedy(collector.envstep)
# Sampling data from environments
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
# Training
for i in range(cfg.policy.learn.update_per_collect):
train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
if train_data is None:
break
learner.train(train_data, collector.envstep)
if __name__ == "__main__":
with DistContext():
main(space_invaders_dqn_config,create_config)
import os
import torch
from tensorboardX import SummaryWriter
from ding.config import compile_config
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
from ding.policy import DQNPolicy
from ding.model import DQN
from ding.utils import set_pkg_seed
from ding.rl_utils import get_epsilon_greedy_fn
from dizoo.atari.config.serial.spaceinvaders.spaceinvaders_dqn_config_multi_gpu_dp import space_invaders_dqn_config, create_config
from ding.torch_utils import DataParallel
from functools import partial
from ding.envs import get_vec_env_setting, create_env_manager
def main(cfg, create_cfg, seed=0):
cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
# Create main components: env, policy
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
# Set random seed for all package and instance
collector_env.seed(seed)
evaluator_env.seed(seed, dynamic_seed=False)
set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
# Set up RL Policy
model = DQN(**cfg.policy.model)
model = DataParallel(model)
policy = DQNPolicy(cfg.policy, model=model)
# Set up collection, training and evaluation utilities
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = SampleSerialCollector(
cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
# Set up other modules, etc. epsilon greedy
eps_cfg = cfg.policy.other.eps
epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
# Training & Evaluation loop
while True:
# Evaluating at the beginning and with specific frequency
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
# Update other modules
eps = epsilon_greedy(collector.envstep)
# Sampling data from environments
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
# Training
for i in range(cfg.policy.learn.update_per_collect):
train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
if train_data is None:
break
learner.train(train_data, collector.envstep)
if __name__ == "__main__":
main(space_invaders_dqn_config,create_config)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册