未验证 提交 a9c4cbb9 编写于 作者: B Bo Zhou 提交者: GitHub

Merge pull request #2054 from zenghsh3/develop

 fix a compatible problem in DQN introduced by the the newest version of fluid
......@@ -71,10 +71,14 @@ class DQNModel(object):
optimizer.minimize(cost)
vars = list(self.train_program.list_vars())
policy_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
target_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
policy_vars_name = [
x.name.replace('target', 'policy') for x in target_vars]
policy_vars = list(filter(
lambda x: x.name in policy_vars_name, vars))
policy_vars.sort(key=lambda x: x.name)
target_vars.sort(key=lambda x: x.name)
......
......@@ -79,10 +79,14 @@ class DoubleDQNModel(object):
optimizer.minimize(cost)
vars = list(self.train_program.list_vars())
policy_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
target_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
policy_vars_name = [
x.name.replace('target', 'policy') for x in target_vars]
policy_vars = list(filter(
lambda x: x.name in policy_vars_name, vars))
policy_vars.sort(key=lambda x: x.name)
target_vars.sort(key=lambda x: x.name)
......
......@@ -71,10 +71,14 @@ class DuelingDQNModel(object):
optimizer.minimize(cost)
vars = list(self.train_program.list_vars())
policy_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
target_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
policy_vars_name = [
x.name.replace('target', 'policy') for x in target_vars]
policy_vars = list(filter(
lambda x: x.name in policy_vars_name, vars))
policy_vars.sort(key=lambda x: x.name)
target_vars.sort(key=lambda x: x.name)
......
......@@ -2,4 +2,4 @@ numpy
gym
tqdm
opencv-python
paddlepaddle-gpu==0.12.0
paddlepaddle-gpu>=1.0.0
......@@ -19,7 +19,6 @@ from collections import deque
UPDATE_FREQ = 4
#MEMORY_WARMUP_SIZE = 2000
MEMORY_SIZE = 1e6
MEMORY_WARMUP_SIZE = MEMORY_SIZE // 20
IMAGE_SIZE = (84, 84)
......@@ -109,7 +108,7 @@ def train_agent():
print('Input algorithm name error!')
return
with tqdm(total=MEMORY_WARMUP_SIZE) as pbar:
with tqdm(total=MEMORY_WARMUP_SIZE, desc='Memory warmup') as pbar:
while len(exp) < MEMORY_WARMUP_SIZE:
total_reward, step = run_train_episode(agent, env, exp)
pbar.update(step)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册