Commit dd4472e4 authored by niuyazhe
@@ -3,3 +3,4 @@ from .one_vs_one_league import OneVsOneLeague
from .player import Player, ActivePlayer, HistoricalPlayer, create_player
from .starcraft_player import MainPlayer, MainExploiter, LeagueExploiter
from .shared_payoff import create_payoff
from .metric import get_elo, get_elo_array, LeagueMetricEnv
@@ -7,7 +7,9 @@ import os.path as osp
from ding.league.player import ActivePlayer, HistoricalPlayer, create_player
from ding.league.shared_payoff import create_payoff
-from ding.utils import import_module, read_file, save_file, LockContext, LockContextType, LEAGUE_REGISTRY
+from ding.utils import import_module, read_file, save_file, LockContext, LockContextType, LEAGUE_REGISTRY, \
+    deep_merge_dicts
from .metric import LeagueMetricEnv
class BaseLeague:
@@ -27,6 +29,41 @@ class BaseLeague:
cfg.cfg_type = cls.__name__ + 'Dict'
return cfg
config = dict(
league_type='base',
import_names=["ding.league.base_league"],
# ---player----
# "player_category" is just a name. Depends on the env.
# For example, in StarCraft, this can be ['zerg', 'terran', 'protoss'].
player_category=['default'],
# Support different types of active players for solo and battle league.
# For solo league, supports ['solo_active_player'].
# For battle league, supports ['battle_active_player', 'main_player', 'main_exploiter', 'league_exploiter'].
# active_players=dict(),
# "use_pretrain" means whether to use pretrain model to initialize active player.
use_pretrain=False,
# "use_pretrain_init_historical" means whether to use pretrain model to initialize historical player.
# "pretrain_checkpoint_path" is the pretrain checkpoint path used in "use_pretrain" and
# "use_pretrain_init_historical". If both are False, "pretrain_checkpoint_path" can be omitted as well.
# Otherwise, "pretrain_checkpoint_path" should list paths of all player categories.
use_pretrain_init_historical=False,
pretrain_checkpoint_path=dict(default='default_cate_pretrain.pth', ),
# ---payoff---
payoff=dict(
# Supports ['battle']
type='battle',
decay=0.99,
min_win_rate_games=8,
),
metric=dict(
mu=0,
sigma=25 / 3,
beta=25 / 3 / 2,
tau=0.0,
draw_probability=0.02,
),
)
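A hedged sketch of how this default config is meant to be used (the helper below is an illustrative stand-in for ding.utils.deep_merge_dicts, which __init__ applies right after this block): a user config only needs to carry the fields it overrides, e.g. path_policy and a single payoff value, and everything else falls back to the defaults above.

from easydict import EasyDict
from ding.league import BaseLeague

def naive_merge(defaults: dict, overrides: dict) -> EasyDict:
    # illustrative stand-in for deep_merge_dicts: override keys win, missing keys keep their defaults
    merged = dict(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = naive_merge(merged[key], value)
        else:
            merged[key] = value
    return EasyDict(merged)

user_cfg = dict(path_policy='./league_policy', payoff=dict(decay=0.95))  # hypothetical override
cfg = naive_merge(BaseLeague.config, user_cfg)
assert cfg.payoff.decay == 0.95 and cfg.payoff.min_win_rate_games == 8  # defaults survive the merge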
def __init__(self, cfg: EasyDict) -> None:
"""
Overview:
@@ -34,16 +71,19 @@ class BaseLeague:
Arguments:
- cfg (:obj:`EasyDict`): League config.
"""
-self.cfg = cfg
+self.cfg = deep_merge_dicts(self.default_config(), cfg)
self.path_policy = cfg.path_policy
if not osp.exists(self.path_policy):
os.mkdir(self.path_policy)
self.league_uid = str(uuid.uuid1())
# TODO dict players
self.active_players = []
self.historical_players = []
self.player_path = "./league"
self.payoff = create_payoff(self.cfg.payoff)
metric_cfg = self.cfg.metric
self.metric_env = LeagueMetricEnv(metric_cfg.mu, metric_cfg.sigma, metric_cfg.tau, metric_cfg.draw_probability)
self._active_players_lock = LockContext(type_=LockContextType.THREAD_LOCK)
self._init_players()
@@ -58,7 +98,9 @@ class BaseLeague:
for i in range(n): # This type's active player number
name = '{}_{}_{}'.format(k, cate, i)
ckpt_path = osp.join(self.path_policy, '{}_ckpt.pth'.format(name))
-player = create_player(self.cfg, k, self.cfg[k], cate, self.payoff, ckpt_path, name, 0)
+player = create_player(
+    self.cfg, k, self.cfg[k], cate, self.payoff, ckpt_path, name, 0, self.metric_env.create_rating()
+)
if self.cfg.use_pretrain:
self.save_checkpoint(self.cfg.pretrain_checkpoint_path[cate], ckpt_path)
self.active_players.append(player)
@@ -79,6 +121,7 @@ class BaseLeague:
self.cfg.pretrain_checkpoint_path[cate],
name,
0,
self.metric_env.create_rating(),
parent_id=parent_name
)
self.historical_players.append(hp)
@@ -140,7 +183,7 @@ class BaseLeague:
player = self.active_players[idx]
if force or player.is_trained_enough():
# Snapshot
-hp = player.snapshot()
+hp = player.snapshot(self.metric_env)
self.save_checkpoint(player.checkpoint_path, hp.checkpoint_path)
self.historical_players.append(hp)
self.payoff.add_player(hp)
@@ -197,6 +240,21 @@ class BaseLeague:
"""
# TODO(nyz) more fine-grained job info
self.payoff.update(job_info)
if 'eval_flag' in job_info and job_info['eval_flag']:
home_id, away_id = job_info['player_id']
home_player, away_player = self.get_player_by_id(home_id), self.get_player_by_id(away_id)
job_info_result = job_info['result']
if isinstance(job_info_result[0], list):
job_info_result = sum(job_info_result, [])
home_player.rating, away_player.rating = self.metric_env.rate_1vs1(
home_player.rating, away_player.rating, result=job_info_result
)
def get_player_by_id(self, player_id: str) -> 'Player': # noqa
if 'historical' in player_id:
return [p for p in self.historical_players if p.player_id == player_id][0]
else:
return [p for p in self.active_players if p.player_id == player_id][0]
@staticmethod
def save_checkpoint(src_checkpoint, dst_checkpoint) -> None:
...
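For orientation, a hedged sketch of the job_info payload that the eval-rating branch added above consumes; the field names are taken from this diff (the league demo at the end of the commit builds exactly these keys), while the player ids are only illustrative:

job_info = {
    'eval_flag': True,  # only evaluation jobs trigger the TrueSkill/Elo update above
    'launch_player': 'main_player_default_0',  # hypothetical '{player_type}_{category}_{index}' id
    'player_id': ['main_player_default_0', 'main_player_default_0_2000_historical'],  # home, away
    'result': [['wins', 'losses'], ['draws']],  # per-env episode results; nested lists are flattened
}
# Payoff bookkeeping aside, the rating update above then boils down to:
# home, away = (league.get_player_by_id(pid) for pid in job_info['player_id'])
# home.rating, away.rating = league.metric_env.rate_1vs1(
#     home.rating, away.rating, result=sum(job_info['result'], [])
# )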
from typing import Tuple, Union, List
import math
import numpy as np
from trueskill import TrueSkill, Rating, rate_1vs1
class EloCalculator(object):
score = {
1: 1.0, # win
0: 0.5, # draw
-1: 0.0, # lose
}
@classmethod
def get_new_rating(cls,
rating_a: int,
rating_b: int,
result: int,
k_factor: int = 32,
beta: int = 200) -> Tuple[int, int]:
assert result in [1, 0, -1]
expect_a = 1. / (1. + math.pow(10, (rating_b - rating_a) / (2. * beta)))
expect_b = 1. / (1. + math.pow(10, (rating_a - rating_b) / (2. * beta)))
new_rating_a = rating_a + k_factor * (EloCalculator.score[result] - expect_a)
new_rating_b = rating_b + k_factor * (1 - EloCalculator.score[result] - expect_b)
return round(new_rating_a), round(new_rating_b)
@classmethod
def get_new_rating_array(
cls,
rating: np.ndarray,
result: np.ndarray,
game_count: np.ndarray,
k_factor: int = 32,
beta: int = 200
) -> np.ndarray:
"""
Shapes:
rating: :math:`(N, )`, N is the number of players
result: :math:`(N, N)`
game_count: :math:`(N, N)`
"""
rating_diff = np.expand_dims(rating, 0) - np.expand_dims(rating, 1)
expect = 1. / (1. + np.power(10, rating_diff / (2. * beta))) * game_count
delta = ((result + 1.) / 2 - expect) * (game_count > 0)
delta = delta.sum(axis=1)
return np.round(rating + k_factor * delta).astype(np.int64)
class PlayerRating(Rating):
def __init__(self, mu: float = None, sigma: float = None, elo_init: int = None) -> None:
super(PlayerRating, self).__init__(mu, sigma)
self.elo = elo_init
def __repr__(self) -> str:
c = type(self)
args = ('.'.join([c.__module__, c.__name__]), self.mu, self.sigma, self.exposure, self.elo)
return '%s(mu=%.3f, sigma=%.3f, exposure=%.3f, elo=%d)' % args
class LeagueMetricEnv(TrueSkill):
"""
Overview:
TrueSkill rating system among game players, for more details please refer to ``https://trueskill.org/``
"""
def __init__(self, *args, elo_init: int = 1200, **kwargs) -> None:
super(LeagueMetricEnv, self).__init__(*args, **kwargs)
self.elo_init = elo_init
def create_rating(self, mu: float = None, sigma: float = None, elo_init: int = None) -> PlayerRating:
if mu is None:
mu = self.mu
if sigma is None:
sigma = self.sigma
if elo_init is None:
elo_init = self.elo_init
return PlayerRating(mu, sigma, elo_init)
@staticmethod
def _rate_1vs1(t1, t2, **kwargs):
t1_elo, t2_elo = t1.elo, t2.elo
t1, t2 = rate_1vs1(t1, t2, **kwargs)
if 'drawn' in kwargs:
result = 0
else:
result = 1
t1_elo, t2_elo = EloCalculator.get_new_rating(t1_elo, t2_elo, result)
t1 = PlayerRating(t1.mu, t1.sigma, t1_elo)
t2 = PlayerRating(t2.mu, t2.sigma, t2_elo)
return t1, t2
def rate_1vs1(self,
team1: PlayerRating,
team2: PlayerRating,
result: List[str] = None,
**kwargs) -> Tuple[PlayerRating, PlayerRating]:
if result is None:
return self._rate_1vs1(team1, team2, **kwargs)
else:
for r in result:
if r == 'wins':
team1, team2 = self._rate_1vs1(team1, team2)
elif r == 'draws':
team1, team2 = self._rate_1vs1(team1, team2, drawn=True)
elif r == 'losses':
team2, team1 = self._rate_1vs1(team2, team1)
else:
raise RuntimeError("invalid result: {}".format(r))
return team1, team2
def rate_1vsC(self, team1: PlayerRating, team2: PlayerRating, result: List[str]) -> PlayerRating:
for r in result:
if r == 'wins':
team1, _ = self._rate_1vs1(team1, team2)
elif r == 'draws':
team1, _ = self._rate_1vs1(team1, team2, drawn=True)
elif r == 'losses':
_, team1 = self._rate_1vs1(team2, team1)
else:
raise RuntimeError("invalid result: {}".format(r))
return team1
get_elo = EloCalculator.get_new_rating
get_elo_array = EloCalculator.get_new_rating_array
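A short usage sketch of this new module; the numbers mirror the unit test further below. With k_factor=32 and beta=200, a 1613-rated player who loses to a 1573-rated player has an expected score of 1 / (1 + 10 ** ((1573 - 1613) / 400)) ≈ 0.557, so roughly 18 points move from the loser to the winner.

from ding.league import get_elo, LeagueMetricEnv

new_a, new_b = get_elo(1613, 1573, -1)  # result=-1: the first player loses
assert (new_a, new_b) == (1595, 1591)

env = LeagueMetricEnv(mu=0, sigma=25 / 3, beta=25 / 3 / 2, tau=0.0, draw_probability=0.02, elo_init=1200)
r1, r2 = env.create_rating(), env.create_rating()
r1, r2 = env.rate_1vs1(r1, r2, result=['wins', 'losses', 'wins'])  # r1 takes a best-of-three 2:1
print(r1)  # PlayerRating(mu=..., sigma=..., exposure=..., elo=...) as defined by __repr__ above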
@@ -53,6 +53,13 @@ class OneVsOneLeague(BaseLeague):
decay=0.99,
min_win_rate_games=8,
),
metric=dict(
mu=0,
sigma=25 / 3,
beta=25 / 3 / 2,
tau=0.0,
draw_probability=0.02,
),
)
# override
...
@@ -25,7 +25,8 @@ class Player:
init_payoff: 'BattleSharedPayoff', # noqa
checkpoint_path: str,
player_id: str,
-total_agent_step: int
+total_agent_step: int,
rating: 'PlayerRating', # noqa
) -> None:
"""
Overview:
@@ -39,6 +40,7 @@ class Player:
- player_id (:obj:`str`): Player id in string format.
- total_agent_step (:obj:`int`): For active player, it should be 0; \
For historical player, it should be parent player's ``_total_agent_step`` when ``snapshot``.
- rating (:obj:`PlayerRating`): player rating information in total league
""" """
self._cfg = cfg self._cfg = cfg
self._category = category self._category = category
...@@ -48,6 +50,7 @@ class Player: ...@@ -48,6 +50,7 @@ class Player:
self._player_id = player_id self._player_id = player_id
assert isinstance(total_agent_step, int), (total_agent_step, type(total_agent_step)) assert isinstance(total_agent_step, int), (total_agent_step, type(total_agent_step))
self._total_agent_step = total_agent_step self._total_agent_step = total_agent_step
self._rating = rating
@property
def category(self) -> str:
@@ -73,6 +76,14 @@ class Player:
def total_agent_step(self, step: int) -> None:
self._total_agent_step = step
@property
def rating(self) -> 'PlayerRating': # noqa
return self._rating
@rating.setter
def rating(self, _rating: 'PlayerRating') -> None: # noqa
self._rating = _rating
@PLAYER_REGISTRY.register('historical_player')
class HistoricalPlayer(Player):
@@ -168,10 +179,12 @@ class ActivePlayer(Player):
else:
return False
-def snapshot(self) -> HistoricalPlayer:
+def snapshot(self, metric_env: 'LeagueMetricEnv') -> HistoricalPlayer: # noqa
"""
Overview:
Generate a snapshot historical player from the current player, called in league's ``_snapshot``.
Argument:
- metric_env (:obj:`LeagueMetricEnv`): player rating environment, one league one env
Returns:
- snapshot_player (:obj:`HistoricalPlayer`): new instantiated historical player
@@ -187,6 +200,7 @@ class ActivePlayer(Player):
path,
self.player_id + '_{}_historical'.format(int(self._total_agent_step)),
self._total_agent_step,
metric_env.create_rating(mu=self.rating.mu),
parent_id=self.player_id
)
...
-from typing import Optional
+from typing import Optional, Union
import numpy as np
from ding.utils import PLAYER_REGISTRY
@@ -105,7 +105,7 @@ class MainPlayer(ActivePlayer):
Overview:
MainPlayer does not mutate
"""
-return None
+pass
@PLAYER_REGISTRY.register('main_exploiter')
@@ -168,9 +168,9 @@ class MainExploiter(ActivePlayer):
Overview:
Main exploiter is sure to mutate(reset) to the supervised learning player
Returns:
-- ckpt_path (:obj:`str`): the pretrained model's ckpt path
+- mutate_ckpt_path (:obj:`str`): mutation target checkpoint path
"""
-return info['pretrain_checkpoint_path']
+return info['reset_checkpoint_path']
@PLAYER_REGISTRY.register('league_exploiter')
@@ -220,15 +220,15 @@ class LeagueExploiter(ActivePlayer):
return super().is_trained_enough(select_fn=lambda p: isinstance(p, HistoricalPlayer))
# override
-def mutate(self, info) -> Optional[str]:
+def mutate(self, info) -> Union[str, None]:
"""
Overview:
League exploiter can mutate to the supervised learning player with 0.25 prob
Returns:
-- ckpt_path (:obj:`str`): with ``mutate_prob`` prob returns the pretrained model's ckpt path, \
+- ckpt_path (:obj:`Union[str, None]`): with ``mutate_prob`` prob returns the pretrained model's ckpt path, \
with left 1 - ``mutate_prob`` prob returns None, which means no mutation
"""
p = np.random.uniform()
if p < self.mutate_prob:
-return info['pretrain_checkpoint_path']
+return info['reset_checkpoint_path']
return None
import pytest
import numpy as np
from ding.league import get_elo, get_elo_array, LeagueMetricEnv
@pytest.mark.unittest
def test_elo_calculator():
game_count = np.array([[0, 1, 2], [1, 0, 0], [2, 0, 0]])
rating = np.array([1613, 1573, 1601])
result = np.array([[0, -1, -1 + 1], [1, 0, 0], [1 + (-1), 0, 0]])
new_rating0, new_rating1 = get_elo(rating[0], rating[1], result[0][1])
assert new_rating0 == 1595
assert new_rating1 == 1591
old_rating = np.copy(rating)
new_rating = get_elo_array(rating, result, game_count)
assert (rating == old_rating).all() # no inplace modification
assert new_rating.dtype == np.int64
assert new_rating[0] == 1578
assert new_rating[1] == 1591
assert new_rating[2] == 1586
@pytest.mark.unittest
def test_league_metric():
sigma = 25 / 3
env = LeagueMetricEnv(mu=0, sigma=sigma, beta=sigma / 2, tau=0.0, draw_probability=0.02, elo_init=1000)
r1 = env.create_rating(elo_init=1613)
r2 = env.create_rating(elo_init=1573)
assert r1.mu == 0
assert r2.mu == 0
assert r1.sigma == sigma
assert r2.sigma == sigma
assert r1.elo == 1613
assert r2.elo == 1573
# r1 draw r2
r1, r2 = env.rate_1vs1(r1, r2, drawn=True)
assert r1.mu == r2.mu
assert r1.elo == 1611
assert r2.elo == 1575
# r1 win r2
new_r1, new_r2 = env.rate_1vs1(r1, r2)
assert new_r1.mu > r1.mu
assert new_r2.mu < r2.mu
assert new_r1.mu + new_r2.mu == 0
assert pytest.approx(new_r1.mu, 3.230)
assert pytest.approx(new_r2.mu, -3.230)
assert new_r1.elo == 1625
assert new_r2.elo == 1561
# multi result
new_r1, new_r2 = env.rate_1vs1(r1, r2, result=['wins', 'wins', 'losses'])
assert new_r1.elo > 1611
# 1vsConstant
new_r1 = env.rate_1vsC(r1, env.create_rating(elo_init=1800), result=['losses', 'losses'])
assert new_r1.elo < 1611
print('final rating is: ', new_r1)
@@ -8,6 +8,9 @@ from easydict import EasyDict
from ding.league.player import Player
from ding.league.shared_payoff import BattleRecordDict, create_payoff
from ding.league.metric import LeagueMetricEnv
env = LeagueMetricEnv()
@pytest.mark.unittest
@@ -42,7 +45,8 @@ def get_shared_payoff_player(payoff):
init_payoff=payoff,
checkpoint_path='sp_ckpt_{}.pth'.format(sp_player_count),
player_id='sp_player_{}'.format(sp_player_count),
-total_agent_step=0
+total_agent_step=0,
rating=env.create_rating(),
)
sp_player_count += 1
return player
...
@@ -8,8 +8,10 @@ from ding.league.player import Player, HistoricalPlayer, ActivePlayer, create_pl
from ding.league.shared_payoff import create_payoff
from ding.league.starcraft_player import MainPlayer, MainExploiter, LeagueExploiter
from ding.league.tests.league_test_default_config import league_test_config
from ding.league.metric import LeagueMetricEnv
ONE_PHASE_STEP = 2000
env = LeagueMetricEnv()
@pytest.fixture(scope='function')
@@ -27,7 +29,7 @@ def setup_league(setup_payoff):
players.append(
create_player(
league_test_config.league, 'main_player', league_test_config.league.main_player, category, setup_payoff,
-'ckpt_{}.pth'.format(main_player_name), main_player_name, 0
+'ckpt_{}.pth'.format(main_player_name), main_player_name, 0, env.create_rating()
)
)
# main_exploiter
@@ -35,7 +37,7 @@ def setup_league(setup_payoff):
players.append(
create_player(
league_test_config.league, 'main_exploiter', league_test_config.league.main_exploiter, category,
-setup_payoff, 'ckpt_{}.pth'.format(main_exploiter_name), main_exploiter_name, 0
+setup_payoff, 'ckpt_{}.pth'.format(main_exploiter_name), main_exploiter_name, 0, env.create_rating()
)
)
# league_exploiter
@@ -51,6 +53,7 @@ def setup_league(setup_payoff):
'ckpt_{}.pth'.format(league_exploiter_name),
league_exploiter_name,
0,
env.create_rating(),
)
)
# historical player: sl player is used as initial HistoricalPlayer
@@ -65,7 +68,8 @@ def setup_league(setup_payoff):
'ckpt_sl_{}'.format(sl_hp_name),
sl_hp_name,
0,
-parent_id=main_player_name
+env.create_rating(),
+parent_id=main_player_name,
)
)
for p in players:
@@ -94,7 +98,7 @@ class TestMainPlayer:
for p in setup_league:
if isinstance(p, ActivePlayer):
p.total_agent_step = 2 * ONE_PHASE_STEP
-hp = p.snapshot()
+hp = p.snapshot(env)
hp_list.append(hp)
setup_payoff.add_player(hp)
setup_league += hp_list # 12+3 + 12
@@ -122,7 +126,7 @@ class TestMainPlayer:
for p in setup_league:
for i in range(N):
if isinstance(p, ActivePlayer):
-hp = p.snapshot()
+hp = p.snapshot(env)
assert isinstance(hp, HistoricalPlayer)
assert id(hp.payoff) == id(p.payoff)
assert hp.parent_id == p.player_id
@@ -146,7 +150,7 @@ class TestMainPlayer:
hp_list = []
for p in setup_league:
if isinstance(p, MainPlayer):
-hp = p.snapshot()
+hp = p.snapshot(env)
setup_payoff.add_player(hp)
hp_list.append(hp)
setup_league += hp_list
@@ -243,7 +247,7 @@ class TestMainExploiter:
for p in setup_league:
if isinstance(p, MainPlayer):
p.total_agent_step = (i + 1) * 2 * ONE_PHASE_STEP
-hp = p.snapshot()
+hp = p.snapshot(env)
setup_payoff.add_player(hp)
hp_list.append(hp)
setup_league += hp_list
@@ -272,9 +276,9 @@ class TestMainExploiter:
def test_mutate(self, setup_league):
assert isinstance(setup_league[1], MainExploiter)
-info = {'pretrain_checkpoint_path': 'pretrain_checkpoint.pth'}
+info = {'reset_checkpoint_path': 'pretrain_checkpoint.pth'}
for _ in range(10):
-assert setup_league[1].mutate(info) == info['pretrain_checkpoint_path']
+assert setup_league[1].mutate(info) == info['reset_checkpoint_path']
@pytest.mark.unittest
@@ -296,7 +300,7 @@ class TestLeagueExploiter:
def test_mutate(self, setup_league):
assert isinstance(setup_league[2], LeagueExploiter)
-info = {'pretrain_checkpoint_path': 'pretrain_checkpoint.pth'}
+info = {'reset_checkpoint_path': 'pretrain_checkpoint.pth'}
results = []
for _ in range(1000):
results.append(setup_league[2].mutate(info))
...
@@ -171,7 +171,7 @@ class OnevOneEvaluator(object):
train_iter: int = -1,
envstep: int = -1,
n_episode: Optional[int] = None
-) -> Tuple[bool, float]:
+) -> Tuple[bool, float, list]:
'''
Overview:
Evaluate policy and store the best policy based on whether it reaches the highest historical reward.
@@ -183,12 +183,14 @@ class OnevOneEvaluator(object):
Returns:
- stop_flag (:obj:`bool`): Whether this training program can be ended.
- eval_reward (:obj:`float`): Current eval_reward.
- return_info (:obj:`list`): Environment information of each finished episode
'''
if n_episode is None:
n_episode = self._default_n_episode
assert n_episode is not None, "please indicate eval n_episode"
envstep_count = 0
info = {}
return_info = [[] for _ in range(2)]
eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode)
self._env.reset()
for p in self._policy:
@@ -219,6 +221,8 @@ class OnevOneEvaluator(object):
if 'episode_info' in t.info[0]:
eval_monitor.update_info(env_id, t.info[0]['episode_info'])
eval_monitor.update_reward(env_id, reward)
for policy_id in range(2):
return_info[policy_id].append(t.info[policy_id])
self._logger.info(
"[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format(
env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode()
@@ -266,7 +270,7 @@ class OnevOneEvaluator(object):
"Current eval_reward: {} is greater than stop_value: {}".format(eval_reward, self._stop_value) +
", so your RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details."
)
-return stop_flag, eval_reward
+return stop_flag, eval_reward, return_info
class VectorEvalMonitor(object):
...
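Callers of OnevOneEvaluator.eval now unpack three values. A hedged call-site sketch (variable names are placeholders; the pattern mirrors how league_demo_ppo_main.py below consumes the new return value):

stop_flag, eval_reward, return_info = evaluator.eval(save_ckpt_fn, train_iter, envstep)
# return_info[0] / return_info[1] hold the per-episode info dicts of the two battle policies
win_loss_result = [e['result'] for e in return_info[0]]  # e.g. ['wins', 'losses', 'draws', ...]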
import os
import shutil
from easydict import EasyDict
from ding.league import BaseLeague, ActivePlayer
class DemoLeague(BaseLeague):
def __init__(self, cfg):
super(DemoLeague, self).__init__(cfg)
self.reset_checkpoint_path = os.path.join(self.path_policy, 'reset_ckpt.pth')
# override
def _get_job_info(self, player: ActivePlayer, eval_flag: bool = False) -> dict:
assert isinstance(player, ActivePlayer), player.__class__
@@ -18,7 +24,12 @@ class DemoLeague(BaseLeague):
# override
def _mutate_player(self, player: ActivePlayer):
-pass
+for p in self.active_players:
result = p.mutate({'reset_checkpoint_path': self.reset_checkpoint_path})
if result is not None:
p.rating = self.metric_env.create_rating()
self.load_checkpoint(p.player_id, result) # load_checkpoint is set by the caller of league
self.save_checkpoint(result, p.checkpoint_path)
# override
def _update_player(self, player: ActivePlayer, player_info: dict) -> None:
...
@@ -7,6 +7,10 @@ class GameEnv(BaseEnv):
def __init__(self, game_type='prisoner_dilemma'):
self.game_type = game_type
assert self.game_type in ['zero_sum', 'prisoner_dilemma']
if self.game_type == 'prisoner_dilemma':
self.optimal_policy = [0, 1]
elif self.game_type == 'zero_sum':
self.optimal_policy = [0.375, 0.625]
def seed(self, seed, dynamic_seed=False):
pass
@@ -36,10 +40,10 @@ class GameEnv(BaseEnv):
results = "draws", "draws"
elif actions == [0, 1]:
rewards = -20, 0
-results = "wins", "losses"
+results = "losses", "wins"
elif actions == [1, 0]:
rewards = 0, -20
-results = "losses", "wins"
+results = "wins", "losses"
elif actions == [1, 1]:
rewards = -10, -10
results = 'draws', 'draws'
...
@@ -6,7 +6,8 @@ league_demo_ppo_config = dict(
collector_env_num=8,
evaluator_env_num=10,
n_evaluator_episode=100,
-stop_value=[-0.01, -5],
+env_type='prisoner_dilemma', # ['zero_sum', 'prisoner_dilemma']
+stop_value=[-10.1, -5.05], # prisoner_dilemma
manager=dict(shared_memory=False, ),
),
policy=dict(
@@ -58,7 +59,7 @@ league_demo_ppo_config = dict(
one_phase_step=200,
branch_probs=dict(pfsp=1.0, ),
strong_win_rate=0.7,
-mutate_prob=0.0,
+mutate_prob=0.5,
),
use_pretrain=False,
use_pretrain_init_historical=False,
@@ -66,7 +67,14 @@ league_demo_ppo_config = dict(
type='battle',
decay=0.99,
min_win_rate_games=8,
-)
+),
metric=dict(
mu=0,
sigma=25 / 3,
beta=25 / 3 / 2,
tau=0.0,
draw_probability=0.02,
),
),
),
),
...
@@ -18,8 +18,17 @@ from dizoo.league_demo.league_demo_ppo_config import league_demo_ppo_config
class EvalPolicy1:
def __init__(self, optimal_policy: list) -> None:
assert len(optimal_policy) == 2
self.optimal_policy = optimal_policy
def forward(self, data: dict) -> dict:
-return {env_id: {'action': torch.zeros(1)} for env_id in data.keys()}
+return {
env_id: {
'action': torch.from_numpy(np.random.choice([0, 1], p=self.optimal_policy, size=(1, )))
}
for env_id in data.keys()
}
def reset(self, data_id: list = []) -> None:
pass
@@ -50,17 +59,26 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
NaiveReplayBuffer,
save_cfg=True
)
env_type = cfg.env.env_type
collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
-evaluator_env1 = BaseEnvManager(env_fn=[GameEnv for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
-evaluator_env2 = BaseEnvManager(env_fn=[GameEnv for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
+evaluator_env1 = BaseEnvManager(
+    env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
)
evaluator_env2 = BaseEnvManager(
env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
)
evaluator_env3 = BaseEnvManager(
env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
)
evaluator_env1.seed(seed, dynamic_seed=False)
evaluator_env2.seed(seed, dynamic_seed=False)
evaluator_env3.seed(seed, dynamic_seed=False)
set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
league = DemoLeague(cfg.policy.other.league)
-eval_policy1 = EvalPolicy1()
+eval_policy1 = EvalPolicy1(evaluator_env1._env_ref.optimal_policy)
eval_policy2 = EvalPolicy2()
policies = {}
learners = {}
@@ -70,7 +88,9 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
model = VAC(**cfg.policy.model)
policy = PPOPolicy(cfg.policy, model=model)
policies[player_id] = policy
-collector_env = BaseEnvManager(env_fn=[GameEnv for _ in range(collector_env_num)], cfg=cfg.env.manager)
+collector_env = BaseEnvManager(
env_fn=[lambda: GameEnv(env_type) for _ in range(collector_env_num)], cfg=cfg.env.manager
)
collector_env.seed(seed)
learners[player_id] = BaseLearner(
@@ -90,8 +110,11 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
model = VAC(**cfg.policy.model)
policy = PPOPolicy(cfg.policy, model=model)
policies['historical'] = policy
# use initial policy as another eval_policy
eval_policy3 = PPOPolicy(cfg.policy, model=copy.deepcopy(model)).collect_mode
main_key = [k for k in learners.keys() if k.startswith('main_player')][0]
main_player = league.get_player_by_id(main_key)
main_learner = learners[main_key]
main_collector = collectors[main_key]
# collect_mode ppo uses multinomial sample for selecting action
@@ -113,25 +136,68 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
exp_name=cfg.exp_name,
instance_name='uniform_evaluator'
)
evaluator3_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
evaluator3_cfg.stop_value = 99999999 # stop_value of evaluator3 is a placeholder
evaluator3 = OnevOneEvaluator(
evaluator3_cfg,
evaluator_env3, [policies[main_key].collect_mode, eval_policy3],
tb_logger,
exp_name=cfg.exp_name,
instance_name='init_evaluator'
)
def load_checkpoint_fn(player_id: str, ckpt_path: str):
state_dict = torch.load(ckpt_path)
policies[player_id].learn_mode.load_state_dict(state_dict)
torch.save(policies['historical'].learn_mode.state_dict(), league.reset_checkpoint_path)
league.load_checkpoint = load_checkpoint_fn
for player_id, player_ckpt_path in zip(league.active_players_ids, league.active_players_ckpts):
torch.save(policies[player_id].collect_mode.state_dict(), player_ckpt_path)
league.judge_snapshot(player_id, force=True)
init_main_player_rating = league.metric_env.create_rating(mu=0)
for run_iter in range(max_iterations):
if evaluator1.should_eval(main_learner.train_iter):
-stop_flag1, reward = evaluator1.eval(
+stop_flag1, reward, episode_info = evaluator1.eval(
main_learner.save_checkpoint, main_learner.train_iter, main_collector.envstep
)
win_loss_result = [e['result'] for e in episode_info[0]]
# set fixed NE policy trueskill(exposure) equal 10
main_player.rating = league.metric_env.rate_1vsC(
main_player.rating, league.metric_env.create_rating(mu=10, sigma=1e-8), win_loss_result
)
tb_logger.add_scalar('fixed_evaluator_step/reward_mean', reward, main_collector.envstep)
if evaluator2.should_eval(main_learner.train_iter):
-stop_flag2, reward = evaluator2.eval(
+stop_flag2, reward, episode_info = evaluator2.eval(
main_learner.save_checkpoint, main_learner.train_iter, main_collector.envstep
)
win_loss_result = [e['result'] for e in episode_info[0]]
# set random(uniform) policy trueskill(exposure) equal 0
main_player.rating = league.metric_env.rate_1vsC(
main_player.rating, league.metric_env.create_rating(mu=0, sigma=1e-8), win_loss_result
)
tb_logger.add_scalar('uniform_evaluator_step/reward_mean', reward, main_collector.envstep)
if evaluator3.should_eval(main_learner.train_iter):
_, reward, episode_info = evaluator3.eval(
main_learner.save_checkpoint, main_learner.train_iter, main_collector.envstep
)
win_loss_result = [e['result'] for e in episode_info[0]]
# use init main player as another evaluator metric
main_player.rating, init_main_player_rating = league.metric_env.rate_1vs1(
main_player.rating, init_main_player_rating, win_loss_result
)
tb_logger.add_scalar('init_evaluator_step/reward_mean', reward, main_collector.envstep)
tb_logger.add_scalar(
'league/init_main_player_trueskill', init_main_player_rating.exposure, main_collector.envstep
)
if stop_flag1 and stop_flag2:
break
for player_id, player_ckpt_path in zip(league.active_players_ids, league.active_players_ckpts):
tb_logger.add_scalar(
'league/{}_trueskill'.format(player_id),
league.get_player_by_id(player_id).rating.exposure, main_collector.envstep
)
collector, learner = collectors[player_id], learners[player_id]
job = league.get_job_info(player_id)
opponent_player_id = job['player_id'][1]
@@ -144,7 +210,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
opponent_policy = policies[opponent_player_id].collect_mode
collector.reset_policy([policies[player_id].collect_mode, opponent_policy])
train_data, episode_info = collector.collect(train_iter=learner.train_iter)
-train_data, episode_info = train_data[0], episode_info[0] # only use launer player data for training
+train_data, episode_info = train_data[0], episode_info[0] # only use launch player data for training
for d in train_data:
d['adv'] = d['reward']
@@ -156,7 +222,9 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
player_info['player_id'] = player_id
league.update_active_player(player_info)
league.judge_snapshot(player_id)
# set eval_flag=True to enable trueskill update
job_finish_info = {
'eval_flag': True,
'launch_player': job['launch_player'],
'player_id': job['player_id'],
'result': [e['result'] for e in episode_info],
...
@@ -50,10 +50,17 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
NaiveReplayBuffer,
save_cfg=True
)
env_type = cfg.env.env_type
collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
-collector_env = BaseEnvManager(env_fn=[GameEnv for _ in range(collector_env_num)], cfg=cfg.env.manager)
-evaluator_env1 = BaseEnvManager(env_fn=[GameEnv for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
-evaluator_env2 = BaseEnvManager(env_fn=[GameEnv for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
+collector_env = BaseEnvManager(
+    env_fn=[lambda: GameEnv(env_type) for _ in range(collector_env_num)], cfg=cfg.env.manager
+)
evaluator_env1 = BaseEnvManager(
env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
)
evaluator_env2 = BaseEnvManager(
env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
)
collector_env.seed(seed)
evaluator_env1.seed(seed, dynamic_seed=False)
@@ -102,10 +109,10 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
for _ in range(max_iterations):
if evaluator1.should_eval(learner1.train_iter):
-stop_flag1, reward = evaluator1.eval(learner1.save_checkpoint, learner1.train_iter, collector.envstep)
+stop_flag1, reward, _ = evaluator1.eval(learner1.save_checkpoint, learner1.train_iter, collector.envstep)
tb_logger.add_scalar('fixed_evaluator_step/reward_mean', reward, collector.envstep)
if evaluator2.should_eval(learner1.train_iter):
-stop_flag2, reward = evaluator2.eval(learner1.save_checkpoint, learner1.train_iter, collector.envstep)
+stop_flag2, reward, _ = evaluator2.eval(learner1.save_checkpoint, learner1.train_iter, collector.envstep)
tb_logger.add_scalar('uniform_evaluator_step/reward_mean', reward, collector.envstep)
if stop_flag1 and stop_flag2:
break
...
@@ -68,6 +68,7 @@ setup(
'opencv-python', # pypy incompatible
'enum_tools',
'scipy',
'trueskill',
],
extras_require={
'test': [
...