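"""
Self-play demo: two PPO policies are trained against each other on ``GameEnv``.

Policy 1 is additionally evaluated against two scripted opponents: one that
always plays action 0 (``EvalPolicy1``) and one that plays uniformly at random
(``EvalPolicy2``). Training stops once policy 1 reaches the configured stop
value against both opponents, or after ``max_iterations`` outer loops.
"""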
import os
from typing import Optional
import numpy as np
import copy
import torch
from tensorboardX import SummaryWriter

from ding.config import compile_config
from ding.worker import BaseLearner, Episode1v1Collector, OnevOneEvaluator, NaiveReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPOPolicy
from ding.model import VAC
from ding.utils import set_pkg_seed
from dizoo.league_demo.game_env import GameEnv
from dizoo.league_demo.league_demo_ppo_config import league_demo_ppo_config


class EvalPolicy1:
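    """Scripted opponent that always outputs action 0."""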

    def forward(self, data: dict) -> dict:
        return {env_id: {'action': torch.zeros(1)} for env_id in data.keys()}

    def reset(self, data_id: Optional[list] = None) -> None:
        pass


class EvalPolicy2:
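    """Scripted opponent that samples action 0 or 1 uniformly at random."""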

    def forward(self, data: dict) -> dict:
        return {
            env_id: {
                'action': torch.from_numpy(np.random.choice([0, 1], p=[0.5, 0.5], size=(1, )))
            }
            for env_id in data.keys()
        }

    def reset(self, data_id: Optional[list] = None) -> None:
        pass


def main(cfg, seed=0, max_iterations=int(1e10)):
    cfg.exp_name = 'selfplay_demo_ppo'
    cfg = compile_config(
        cfg,
        BaseEnvManager,
        PPOPolicy,
        BaseLearner,
        Episode1v1Collector,
        OnevOneEvaluator,
        NaiveReplayBuffer,
        save_cfg=True
    )
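    # Vectorized environments: one manager for self-play collection and one for each evaluator.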
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
    collector_env = BaseEnvManager(env_fn=[GameEnv for _ in range(collector_env_num)], cfg=cfg.env.manager)
    evaluator_env1 = BaseEnvManager(env_fn=[GameEnv for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
    evaluator_env2 = BaseEnvManager(env_fn=[GameEnv for _ in range(evaluator_env_num)], cfg=cfg.env.manager)

    collector_env.seed(seed)
    evaluator_env1.seed(seed, dynamic_seed=False)
    evaluator_env2.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)

    # Two independent PPO policies, one per player in the 1v1 self-play game.
    model1 = VAC(**cfg.policy.model)
    policy1 = PPOPolicy(cfg.policy, model=model1)
    model2 = VAC(**cfg.policy.model)
    policy2 = PPOPolicy(cfg.policy, model=model2)
    eval_policy1 = EvalPolicy1()
    eval_policy2 = EvalPolicy2()

    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
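    # One learner per policy; both share the same TensorBoard writer.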
    learner1 = BaseLearner(
        cfg.policy.learn.learner, policy1.learn_mode, tb_logger, exp_name=cfg.exp_name, instance_name='learner1'
    )
    learner2 = BaseLearner(
        cfg.policy.learn.learner, policy2.learn_mode, tb_logger, exp_name=cfg.exp_name, instance_name='learner2'
    )
    collector = Episode1v1Collector(
        cfg.policy.collect.collector,
        collector_env, [policy1.collect_mode, policy2.collect_mode],
        tb_logger,
        exp_name=cfg.exp_name
    )
    # collect_mode PPO uses multinomial sampling to select actions.
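    # Two evaluators test policy1 only: one against the fixed opponent and one
    # against the random opponent, each with its own stop value from the config.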
    evaluator1_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
    evaluator1_cfg.stop_value = cfg.env.stop_value[0]
    evaluator1 = OnevOneEvaluator(
        evaluator1_cfg,
        evaluator_env1, [policy1.collect_mode, eval_policy1],
        tb_logger,
        exp_name=cfg.exp_name,
        instance_name='fixed_evaluator'
    )
    evaluator2_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
    evaluator2_cfg.stop_value = cfg.env.stop_value[1]
    evaluator2 = OnevOneEvaluator(
        evaluator2_cfg,
        evaluator_env2, [policy1.collect_mode, eval_policy2],
        tb_logger,
        exp_name=cfg.exp_name,
        instance_name='uniform_evaluator'
    )

    # Training loop: periodically evaluate policy1 against both scripted opponents,
    # then collect self-play episodes and update both learners.
    stop_flag1, stop_flag2 = False, False
    for _ in range(max_iterations):
        if evaluator1.should_eval(learner1.train_iter):
            stop_flag1, reward = evaluator1.eval(learner1.save_checkpoint, learner1.train_iter, collector.envstep)
            tb_logger.add_scalar('fixed_evaluator_step/reward_mean', reward, collector.envstep)
        if evaluator2.should_eval(learner1.train_iter):
            stop_flag2, reward = evaluator2.eval(learner1.save_checkpoint, learner1.train_iter, collector.envstep)
            tb_logger.add_scalar('uniform_evaluator_step/reward_mean', reward, collector.envstep)
        if stop_flag1 and stop_flag2:
            break
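        # Collect self-play episodes; ``train_data`` contains one list of transitions
        # per player, and each step's reward is used directly as its advantage.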
        train_data, _ = collector.collect(train_iter=learner1.train_iter)
        for data in train_data:
            for d in data:
                d['adv'] = d['reward']
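        # Each learner updates on its own player's data ``update_per_collect`` times.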
        for _ in range(cfg.policy.learn.update_per_collect):
            learner1.train(train_data[0], collector.envstep)
            learner2.train(train_data[1], collector.envstep)


if __name__ == "__main__":
    main(league_demo_ppo_config)