# league_demo_ppo_main.py
import os
import copy
import gym
import numpy as np
import torch
from tensorboardX import SummaryWriter

from ding.config import compile_config
from ding.worker import BaseLearner, Episode1v1Collector, OnevOneEvaluator, NaiveReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPOPolicy
from ding.model import VAC
from ding.utils import set_pkg_seed
from dizoo.league_demo.game_env import GameEnv
from dizoo.league_demo.demo_league import DemoLeague
from dizoo.league_demo.league_demo_ppo_config import league_demo_ppo_config


class EvalPolicy1:
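    """Fixed evaluation opponent that samples binary actions from the game's optimal (Nash equilibrium) action distribution."""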

    def __init__(self, optimal_policy: list) -> None:
        assert len(optimal_policy) == 2
        self.optimal_policy = optimal_policy

    def forward(self, data: dict) -> dict:
        return {
            env_id: {
                'action': torch.from_numpy(np.random.choice([0, 1], p=self.optimal_policy, size=(1, )))
            }
            for env_id in data.keys()
        }

    def reset(self, data_id: list = []) -> None:
        pass


class EvalPolicy2:
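    """Random evaluation opponent that samples binary actions uniformly at random."""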

    def forward(self, data: dict) -> dict:
        return {
            env_id: {
                'action': torch.from_numpy(np.random.choice([0, 1], p=[0.5, 0.5], size=(1, )))
            }
            for env_id in data.keys()
        }

    def reset(self, data_id: list = []) -> None:
        pass


def main(cfg, seed=0, max_iterations=int(1e10)):
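    """Run the league training demo: one PPO learner/collector per active player, plus three evaluators
    that pit the main player against a fixed NE policy, a uniform random policy and a frozen initial policy.
    """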
    cfg = compile_config(
        cfg,
        BaseEnvManager,
        PPOPolicy,
        BaseLearner,
        Episode1v1Collector,
        OnevOneEvaluator,
        NaiveReplayBuffer,
        save_cfg=True
    )
    env_type = cfg.env.env_type
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
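    # Three separate evaluator environment managers, one for each reference opponent.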
    evaluator_env1 = BaseEnvManager(
        env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
    )
    evaluator_env2 = BaseEnvManager(
        env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
    )
    evaluator_env3 = BaseEnvManager(
        env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
    )

    evaluator_env1.seed(seed, dynamic_seed=False)
    evaluator_env2.seed(seed, dynamic_seed=False)
    evaluator_env3.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)

    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
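    # DemoLeague schedules 1v1 jobs between active players and historical snapshots and tracks payoff/TrueSkill.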
    league = DemoLeague(cfg.policy.other.league)
    eval_policy1 = EvalPolicy1(evaluator_env1._env_ref.optimal_policy)
    eval_policy2 = EvalPolicy2()
    policies = {}
    learners = {}
    collectors = {}
    for player_id in league.active_players_ids:
        # By default every active player uses the same model architecture (with different initial weights).
        model = VAC(**cfg.policy.model)
        policy = PPOPolicy(cfg.policy, model=model)
        policies[player_id] = policy
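        # Each active player gets its own collector environment, learner and episode collector.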
        collector_env = BaseEnvManager(
            env_fn=[lambda: GameEnv(env_type) for _ in range(collector_env_num)], cfg=cfg.env.manager
        )
        collector_env.seed(seed)

        learners[player_id] = BaseLearner(
            cfg.policy.learn.learner,
            policy.learn_mode,
            tb_logger=tb_logger,
            exp_name=cfg.exp_name,
            instance_name=player_id + '_learner'
        )
        collectors[player_id] = Episode1v1Collector(
            cfg.policy.collect.collector,
            collector_env,
            tb_logger=tb_logger,
            exp_name=cfg.exp_name,
            instance_name=player_id + '_collector',
        )
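    # A single shared 'historical' policy slot: its weights are overwritten from a snapshot checkpoint
    # whenever a historical opponent is scheduled by the league.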
    model = VAC(**cfg.policy.model)
    policy = PPOPolicy(cfg.policy, model=model)
    policies['historical'] = policy
    # use a copy of the freshly initialized policy as another eval_policy
    eval_policy3 = PPOPolicy(cfg.policy, model=copy.deepcopy(model)).collect_mode

    main_key = [k for k in learners.keys() if k.startswith('main_player')][0]
    main_player = league.get_player_by_id(main_key)
    main_learner = learners[main_key]
    main_collector = collectors[main_key]
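    # Three evaluators for the main player: against the fixed NE policy, a uniform random policy,
    # and a frozen copy of an untrained (initial) policy.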
    # collect_mode PPO uses multinomial sampling to select actions
    evaluator1_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
    evaluator1_cfg.stop_value = cfg.env.stop_value[0]
    evaluator1 = OnevOneEvaluator(
        evaluator1_cfg,
        evaluator_env1, [policies[main_key].collect_mode, eval_policy1],
        tb_logger,
        exp_name=cfg.exp_name,
        instance_name='fixed_evaluator'
    )
    evaluator2_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
    evaluator2_cfg.stop_value = cfg.env.stop_value[1]
    evaluator2 = OnevOneEvaluator(
        evaluator2_cfg,
        evaluator_env2, [policies[main_key].collect_mode, eval_policy2],
        tb_logger,
        exp_name=cfg.exp_name,
        instance_name='uniform_evaluator'
    )
    evaluator3_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
    evaluator3_cfg.stop_value = 99999999  # stop_value of evaluator3 is a placeholder
    evaluator3 = OnevOneEvaluator(
        evaluator3_cfg,
        evaluator_env3, [policies[main_key].collect_mode, eval_policy3],
        tb_logger,
        exp_name=cfg.exp_name,
        instance_name='init_evaluator'
    )

    def load_checkpoint_fn(player_id: str, ckpt_path: str):
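        """Hook used by the league to restore a player's policy weights from a checkpoint on disk."""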
        state_dict = torch.load(ckpt_path)
        policies[player_id].learn_mode.load_state_dict(state_dict)

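    # Save the freshly initialized weights as the league's reset checkpoint and register the load hook.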
    torch.save(policies['historical'].learn_mode.state_dict(), league.reset_checkpoint_path)
    league.load_checkpoint = load_checkpoint_fn
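    # Write each active player's initial checkpoint and force an immediate snapshot into the historical pool.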
    for player_id, player_ckpt_path in zip(league.active_players_ids, league.active_players_ckpts):
        torch.save(policies[player_id].collect_mode.state_dict(), player_ckpt_path)
        league.judge_snapshot(player_id, force=True)
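    # TrueSkill rating assigned to the frozen initial policy used by the init_evaluator.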
    init_main_player_rating = league.metric_env.create_rating(mu=0)

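    # Main loop: periodically evaluate the main player against the three reference opponents,
    # then run one league job (collect + train) for every active player.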
    stop_flag1, stop_flag2 = False, False  # defined up front so the stop check below is always valid
    for run_iter in range(max_iterations):
        if evaluator1.should_eval(main_learner.train_iter):
            stop_flag1, reward, episode_info = evaluator1.eval(
                main_learner.save_checkpoint, main_learner.train_iter, main_collector.envstep
            )
            win_loss_result = [e['result'] for e in episode_info[0]]
            # set the fixed NE policy's TrueSkill (exposure) to 10
            main_player.rating = league.metric_env.rate_1vsC(
                main_player.rating, league.metric_env.create_rating(mu=10, sigma=1e-8), win_loss_result
            )
            tb_logger.add_scalar('fixed_evaluator_step/reward_mean', reward, main_collector.envstep)
        if evaluator2.should_eval(main_learner.train_iter):
            stop_flag2, reward, episode_info = evaluator2.eval(
                main_learner.save_checkpoint, main_learner.train_iter, main_collector.envstep
            )
            win_loss_result = [e['result'] for e in episode_info[0]]
            # set the random (uniform) policy's TrueSkill (exposure) to 0
            main_player.rating = league.metric_env.rate_1vsC(
                main_player.rating, league.metric_env.create_rating(mu=0, sigma=1e-8), win_loss_result
            )
            tb_logger.add_scalar('uniform_evaluator_step/reward_mean', reward, main_collector.envstep)
        if evaluator3.should_eval(main_learner.train_iter):
            _, reward, episode_info = evaluator3.eval(
                main_learner.save_checkpoint, main_learner.train_iter, main_collector.envstep
            )
            win_loss_result = [e['result'] for e in episode_info[0]]
            # rate the main player 1v1 against the frozen initial policy as an additional metric
            main_player.rating, init_main_player_rating = league.metric_env.rate_1vs1(
                main_player.rating, init_main_player_rating, win_loss_result
            )
            tb_logger.add_scalar('init_evaluator_step/reward_mean', reward, main_collector.envstep)
            tb_logger.add_scalar(
                'league/init_main_player_trueskill', init_main_player_rating.exposure, main_collector.envstep
            )
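        # Stop once the main player's reward passes both the fixed and the uniform evaluators' stop values.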
        if stop_flag1 and stop_flag2:
            break
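        # One league job per active player: log its TrueSkill, fetch an opponent, collect episodes and train.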
        for player_id, player_ckpt_path in zip(league.active_players_ids, league.active_players_ckpts):
            tb_logger.add_scalar(
                'league/{}_trueskill'.format(player_id),
                league.get_player_by_id(player_id).rating.exposure, main_collector.envstep
            )
            collector, learner = collectors[player_id], learners[player_id]
            job = league.get_job_info(player_id)
            opponent_player_id = job['player_id'][1]
            # print('job player: {}'.format(job['player_id']))
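            # Historical opponents replay weights from saved checkpoints; active opponents use their live policies.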
            if 'historical' in opponent_player_id:
                opponent_policy = policies['historical'].collect_mode
                opponent_path = job['checkpoint_path'][1]
                opponent_policy.load_state_dict(torch.load(opponent_path, map_location='cpu'))
            else:
                opponent_policy = policies[opponent_player_id].collect_mode
            collector.reset_policy([policies[player_id].collect_mode, opponent_policy])
            train_data, episode_info = collector.collect(train_iter=learner.train_iter)
            train_data, episode_info = train_data[0], episode_info[0]  # only use launch player data for training
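            # The episode reward is used directly as the advantage for PPO training.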
            for d in train_data:
                d['adv'] = d['reward']

            for i in range(cfg.policy.learn.update_per_collect):
                learner.train(train_data, collector.envstep)
            torch.save(learner.policy.state_dict(), player_ckpt_path)

            player_info = learner.learn_info
            player_info['player_id'] = player_id
            league.update_active_player(player_info)
            league.judge_snapshot(player_id)
            # set eval_flag=True to enable the TrueSkill update
            job_finish_info = {
                'eval_flag': True,
                'launch_player': job['launch_player'],
                'player_id': job['player_id'],
                'result': [e['result'] for e in episode_info],
            }
            league.finish_job(job_finish_info)
        if run_iter % 100 == 0:
            print(repr(league.payoff))


if __name__ == "__main__":
    main(league_demo_ppo_config)