diff --git a/benchmark/torch/dqn/train.py b/benchmark/torch/dqn/train.py index ba64b95c93a9b4879621331ad30cce3cbcbcac16..0579b2154049c66d74f0e599483a9226e8eea4e1 100644 --- a/benchmark/torch/dqn/train.py +++ b/benchmark/torch/dqn/train.py @@ -121,7 +121,7 @@ def main(): model = AtariModel(CONTEXT_LEN, act_dim, args.algo) if args.algo in ['DQN', 'Dueling']: algorithm = DQN(model, gamma=GAMMA, lr=args.lr) - elif args.algo is 'Double': + elif args.algo == 'Double': algorithm = DDQN(model, gamma=GAMMA, lr=args.lr) agent = AtariAgent(algorithm, act_dim=act_dim)