未验证 提交 a9c4cbb9 编写于 作者: B Bo Zhou 提交者: GitHub

Merge pull request #2054 from zenghsh3/develop

 fix a compatible problem in DQN introduced by the the newest version of fluid
...@@ -71,10 +71,14 @@ class DQNModel(object): ...@@ -71,10 +71,14 @@ class DQNModel(object):
optimizer.minimize(cost) optimizer.minimize(cost)
vars = list(self.train_program.list_vars()) vars = list(self.train_program.list_vars())
policy_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
target_vars = list(filter( target_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'target' in x.name, vars)) lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
policy_vars_name = [
x.name.replace('target', 'policy') for x in target_vars]
policy_vars = list(filter(
lambda x: x.name in policy_vars_name, vars))
policy_vars.sort(key=lambda x: x.name) policy_vars.sort(key=lambda x: x.name)
target_vars.sort(key=lambda x: x.name) target_vars.sort(key=lambda x: x.name)
......
...@@ -79,10 +79,14 @@ class DoubleDQNModel(object): ...@@ -79,10 +79,14 @@ class DoubleDQNModel(object):
optimizer.minimize(cost) optimizer.minimize(cost)
vars = list(self.train_program.list_vars()) vars = list(self.train_program.list_vars())
policy_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
target_vars = list(filter( target_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'target' in x.name, vars)) lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
policy_vars_name = [
x.name.replace('target', 'policy') for x in target_vars]
policy_vars = list(filter(
lambda x: x.name in policy_vars_name, vars))
policy_vars.sort(key=lambda x: x.name) policy_vars.sort(key=lambda x: x.name)
target_vars.sort(key=lambda x: x.name) target_vars.sort(key=lambda x: x.name)
......
...@@ -71,10 +71,14 @@ class DuelingDQNModel(object): ...@@ -71,10 +71,14 @@ class DuelingDQNModel(object):
optimizer.minimize(cost) optimizer.minimize(cost)
vars = list(self.train_program.list_vars()) vars = list(self.train_program.list_vars())
policy_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
target_vars = list(filter( target_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'target' in x.name, vars)) lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
policy_vars_name = [
x.name.replace('target', 'policy') for x in target_vars]
policy_vars = list(filter(
lambda x: x.name in policy_vars_name, vars))
policy_vars.sort(key=lambda x: x.name) policy_vars.sort(key=lambda x: x.name)
target_vars.sort(key=lambda x: x.name) target_vars.sort(key=lambda x: x.name)
......
...@@ -2,4 +2,4 @@ numpy ...@@ -2,4 +2,4 @@ numpy
gym gym
tqdm tqdm
opencv-python opencv-python
paddlepaddle-gpu==0.12.0 paddlepaddle-gpu>=1.0.0
...@@ -19,7 +19,6 @@ from collections import deque ...@@ -19,7 +19,6 @@ from collections import deque
UPDATE_FREQ = 4 UPDATE_FREQ = 4
#MEMORY_WARMUP_SIZE = 2000
MEMORY_SIZE = 1e6 MEMORY_SIZE = 1e6
MEMORY_WARMUP_SIZE = MEMORY_SIZE // 20 MEMORY_WARMUP_SIZE = MEMORY_SIZE // 20
IMAGE_SIZE = (84, 84) IMAGE_SIZE = (84, 84)
...@@ -109,7 +108,7 @@ def train_agent(): ...@@ -109,7 +108,7 @@ def train_agent():
print('Input algorithm name error!') print('Input algorithm name error!')
return return
with tqdm(total=MEMORY_WARMUP_SIZE) as pbar: with tqdm(total=MEMORY_WARMUP_SIZE, desc='Memory warmup') as pbar:
while len(exp) < MEMORY_WARMUP_SIZE: while len(exp) < MEMORY_WARMUP_SIZE:
total_reward, step = run_train_episode(agent, env, exp) total_reward, step = run_train_episode(agent, env, exp)
pbar.update(step) pbar.update(step)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册