From 3a27f4072ef30786bad21318aa1ecdde23fc3292 Mon Sep 17 00:00:00 2001 From: Zheyue Tan Date: Mon, 27 Jul 2020 15:17:13 +0800 Subject: [PATCH] Add Prioritized DQN (#326) - add prioritized dqn - fix#239 --- .../DQN_variant/rom_files/battle_zone.bin | Bin 0 -> 8192 bytes examples/Prioritized_DQN/README.md | 30 +++ examples/Prioritized_DQN/atari.py | 1 + examples/Prioritized_DQN/atari_agent.py | 108 +++++++++ examples/Prioritized_DQN/atari_model.py | 51 ++++ examples/Prioritized_DQN/atari_wrapper.py | 1 + examples/Prioritized_DQN/per_alg.py | 127 ++++++++++ examples/Prioritized_DQN/proportional_per.py | 157 +++++++++++++ examples/Prioritized_DQN/result.png | Bin 0 -> 24307 bytes examples/Prioritized_DQN/rom_files | 1 + examples/Prioritized_DQN/train.py | 218 ++++++++++++++++++ examples/Prioritized_DQN/utils.py | 1 + parl/algorithms/fluid/ddqn.py | 22 +- 13 files changed, 700 insertions(+), 17 deletions(-) create mode 100644 examples/DQN_variant/rom_files/battle_zone.bin create mode 100644 examples/Prioritized_DQN/README.md create mode 120000 examples/Prioritized_DQN/atari.py create mode 100644 examples/Prioritized_DQN/atari_agent.py create mode 100644 examples/Prioritized_DQN/atari_model.py create mode 120000 examples/Prioritized_DQN/atari_wrapper.py create mode 100644 examples/Prioritized_DQN/per_alg.py create mode 100644 examples/Prioritized_DQN/proportional_per.py create mode 100644 examples/Prioritized_DQN/result.png create mode 120000 examples/Prioritized_DQN/rom_files create mode 100644 examples/Prioritized_DQN/train.py create mode 120000 examples/Prioritized_DQN/utils.py diff --git a/examples/DQN_variant/rom_files/battle_zone.bin b/examples/DQN_variant/rom_files/battle_zone.bin new file mode 100644 index 0000000000000000000000000000000000000000..662fe240d4af708c5b5dc887d359632574e96b39 GIT binary patch literal 8192 zcma($4Okr2wRdM`XMb3B2lBTGWOpFs2NHGrqg^c$Ow&~8)-SKGwbokQuPuGSSD*HK z)#$6jChlD+uW9#t6(zQjb)1>((4dPi8%xY4Bmv4`V+YL-=7)eL2$~p?z!ild*^1`WExYrL_ydj9n|gW9Z|Y6K?h1s$VUN}mAdORcHq^x zU>`cEII#s=6*tZ!BL;G!mU{n4K}EVbb*LW;srhH|^C{Ga%~GhC)~xc;;re1PJ{m5F zc?=CjIImT}`3lwXx<+Ad#~V<4>hNmChxuDKjS8j+(-ZMV4Cc18{a3kldQ_vgc}*Q} zWFU{JPO*YbS%OAP#9r)aHzm3w{G1xajG51?!mFi}YELiry2Py}oX3(C*@=To?u5N;=gVW5PcmtRcf^x>&gULT&8!U9c@qMRhi zT?E+zND7d*Zn|}oBw${89v4U=UKTDS7RbA${l&WKOXQ?pJncWe~2|?OB9^ZQ3P;!b{tZ!1Cezww`{>;6j6L zou!{s(u4$D9wmVblyb}~JT6oQ@G>F%jAFve>j)=^#I{X1dlPnU!Y%^T0rU)ZeK8Se zd$FsI>i~=X3_(F`0WEkb_6SVGhO@B~yFgEokbXwGfUVo+o4^9E`~rBRsQf1dwcxlRNRjqj6%+d^k@E=?3viAhd3M^f!845on0h97G^Y_5fKnhZv_hBx z2SgG|1V=*234Ibuqz|ElpNB1suy>RHGs=qae1d_jVU>okUHV^~CAAsJ1hDOE^oQ#@ zru2f2UHUTZAFaajR-C7x_BCV(&etDpl`-4ejJbA$uVBXMK zAiaxib)42s283eIg>Dgkv5>G+nBKxXY z+Clp9n5AedxFd)27KJ!Cf|NV(lyVPlR;nQO{~b2cFR`d>#)6SQ6dJEEd98Rk|%+`@Z3tLFDlvKu1NC2Kx1@J3`-!jr+NM(*m#$>h0S;sC7djgw} zLzbA$V}T22uVw>c)74xMZ!oe|5Rqce6a!wf=9+7-z0T+R(pO4~i;GvTSh1|IFc`e& zo_l_N|09t|_h0||+L6D#di>b2b?d+N?TtSO-n%*Cc9*(~YEbFAG%7uOxCXi1HR}L& zL%U_mmNJB_P|6AmYd}6A-R`1w>(dj^l3>^-w-GnnpD3b$Kmbmy30|Z18ijwPoHK z$5+;_%_Wt`0l&3O8Jn*0&N#hJFXLekr)$z+b)J`W?RF*uIGs#Q8Ym7Y@EIg_mbIuf zolcjg-T71bWwo`nWwj1xnbW(L2zGcD5-kQ~=rbUX!vi<~J&d!gjDZetk}*)g2FB^} zFlDud(9-;}n*Yc@3@~T_Hx?DSud;zjS1I7+H)qSqlR7EonPR{klwhV1@DBf~C3gGX z=Dp4F=C~Y}Wm%L(QUp;D1TqUUk1~g+rpC#jIcpuBvNEuam^o*sGh-+*4lKR9%wZTb zXKShZ6I~Jp;6W%gX^=cKGvo0%kfE!r7F<$OQ$rAQ%7|;qV20?bI`GRvCzA*3F{|GR zndH+5Clu!GP1L6gZR~-o75!;4WwkPylEME!?6W*pgoq0v`U!%fzJQw$fKJ#zX?Z!% z^JD;Uf1a6eavX4!muCzNg@B&}=Fjm0D=?OaLK#M=%Mp)&^-4bjFvyBH#Fu{pT}~u^ zqEA2~t^8BCcAw)@eS{po_RnCZKGR32?$gENc<7{b>W264ths#mFHU~-7f0t?eG9)? zpS!#gwXQG!pQ2lzw4z)FtwUCR%X*Xx4YFCyjM+NRmE~H*=Y1(}`M+IXa((GHc71$QzVFAI8(urOyJ`3EBlrF&aPyDqPW80KpKiO5@jgX;QR|#%6sN~wlzb=TCnrri zJv3A_?O~yKb8~Y^-@|bp!Vj$z*56{GD5DQJgWn6JxjD18h4UBMSX8XEDf3WC$=6H% zfS!M@;Q8m!(+f6#e=~Z3^KlP=??3&IbMk_ixj$Zvcdi31Yzxn4fe;=Nk{VV#=_?w&IKiPt#{z=X+V1sjM?r<`j zE##E4SV#l4#i*b~*ZsJD&snx4aMQx~yPmmkEqDF$1>l$0_SZ~JO(B|fWHt@#n%Tnh z1qHde4hKtHZ`-8A<1fC@vg_qNdx-jsqtdW<f&_6)*fy`DcbfFH5Ag=p?5T@>m4anb)c$UJP>Zj2jD`*6qR!3+qP1P zB}y~ifuEJ1OQqk(F#+*+Uyf{r5AeO3r)aCs+!CTRhx8oY3m2hUje(oeuzXNo+M~H6 zgAp~-8_^=&(q7Ffgd$t1*l0ubHIBX*B6sjnsQ(gN27xQZO4(-bzhF-}a{Ier+ z#)e+$GDgfYbiKmokP-f&SrJ3~G6;4EM8|3yg-*U|yiqt})O|+X54EARYnI`DgMrk2 zJgNJkmQD$$`QeGK0gjahix1#IV_L-?ZAr^ksCG0FcMb(?kP&( zBRoVy0bUL_cpG`jQF<^7pFEfYhtL7#1G&lfUs9;^yjI+Gj`PA4n|wdP3mYpHd(>2k z-EA+;;3dQGB(p$qX=~h%ZAjzu;yGmoUZ+N<2Ppk0UNAuGZx5I>Yolt3QX`glY7jnV z;X=Jg&sOLGq+3$t>lP4;6{@pIw<}0zz(9yk{b6v+J+gYPSL#eKa<6oRI3R{9p;Wjk z2bcsCUL4KW%#}j61|L`o6Isv!m!SxdTt%} zt~Bhf<`PWIq_LpR6+>;WrSZbHzo%hFAE)sWxk+W-P|^ljw=iViWsl= zS){=q>S;XKrN|2RdRJVGbKSxr{?}7N4^*&_sv~MM0te#k@5F>-u*?wHCyl^{$cUlv zh3?Gka+9>rnKZ$32cynJ8Rc4hp^iHTPHe?5;}`I5{1W~P#O3!O7=LEOCfNNgLpT6G zk~%w&B=Ti7vR~?jo15LdAMY1>V!QdKNr>$i0+S!gy+Bj1&<XLQER^8@@MsKJ~nA{INaYN)GoWy{2B zeBY}vEeBdx!w%K1S>Q0*(Zjtn;g)t)?Z7*tF7g(jGCHH$G-_zCUqMyR;%ECZL&-o} ztN{kYz-BUVB-rT7t>+kcnjVgb-Q68xlvdI~1qG8P_ai3MU`nD#5NaTO|74~=Iorol zlEyOlL*(ogX)H<-)KZetA;wpTJ_{?CaAQ6i8p0)u@a7Q0QViAm&&yGBbE8>`*V_|x zbClX3Z;-U%u-sT#>93M2MUk_@BgJ=sFX*oRPOx6HHTr0b$)BFB?v)taQ> zYf(yAB4U*ocx(cT{6}NNKrteMOi^lvJ{Id$EwK!uoK>0xr2DGGuoz2+TYv+Q4BO%x ztCw2>9TOwLW}wuDTfi(emS8rtXoMohw^OLij5JfQ-j`LGfX01qV!D~k9^euJMkFjc9!%0Sb*Y&%LtBm9lN)9@YopJsdO1@t zmKI1nlH&x6Nt(1(VlW(!>`cS(T1*i66s%b;P$LwB3ZjSwS(r$0Y8AE$^Q+!`>&__SMgXML7Jj z_nn-1^W03-JS27}#wF1hrl2CpM=93rX0!*Ta%Lb(5Pt!72LHo}r|6&@2VpGiC;@>? zA0=2py?Ku*^);A`V!FynEP}j~1*EE(8)>wiMnUO7Fa5~=EDyyM; z8%(Wo10n#;|87!lQNNsEl#LCvbol(M86QiX=Z(V~LWzq$p_DhUAb{3<3e&0TC7r$c z3NhC1iu&R7X7q)M8L%8+9ZGEjY}*EtV%G8%Msq7B;H!;t?Rvr`Hl{p?PeLa5kD18?VGIK+ z-=UT#QG%)n4F`rN3?d*GsWXXrapGG?duTU`(%lJD-IFLtFdviKiuPd=T^RsLFjg70 z!wrK?(w+T@`I4xHqRcpmJr4Te{tUWl{(!j!u+|_&oN}33%GV#7< zhur~pBtp!%&jXoe!e@u;D_b0@XxBVaNaZ=VI^dgM5z;ugXFY*iBD;uwsYTvJCKlt9 zG&tU#WUJW*7S>{*(fyzq?~1e#F5(BP+#(r6U@d?(ah+;U%xVSK79`F1Sj;s18TZ58 z9M(l(0(V~yN->D5Y`CU%#uBt*BF9gK#YQg=>6f$*qyvGUPDgga$zq&9vl)SGeE%ek zX1f9;`@}Vp2=z+XUlyeE0#a#xiOLw8<0*EJC3QN@|9V`YpRzz$Kn8&{;jX51Ms12x z@M66sc>j}!dJg^ZC7~yqf!8bV(vbC84o&z_yAREz(ZBm_&~_+y&h_-dzBv=#5!^p zdzY{NvcGJ7<;L6Z`pKsI{`29Q$A9y?KQ!(TcK$h$Y)c*NRFCu??K{;!cOCMMp~-_!H5T`r56Vp;I6l}tPaR zh55^hue;&qZ&d!v9ryhFz6T$}+ZuMrFSZ@}+ljZ{`#8d}t%j hZ{$CC;a}H&`}d}86X{@i2 + +

+ +## How to use + +### Dependencies: ++ [paddlepaddle>=1.6.1](https://github.com/PaddlePaddle/Paddle) ++ [parl](https://github.com/PaddlePaddle/PARL) ++ gym[atari]==0.17.2 ++ atari-py==0.2.6 ++ tqdm ++ [ale_python_interface](https://github.com/mgbellemare/Arcade-Learning-Environment) + + +### Start Training: +Train on BattleZone game: +```bash +python train.py --rom ./rom_files/battle_zone.bin +``` + +> To train on more games, you can install more rom files from [here](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms). diff --git a/examples/Prioritized_DQN/atari.py b/examples/Prioritized_DQN/atari.py new file mode 120000 index 0000000..e0e1b3c --- /dev/null +++ b/examples/Prioritized_DQN/atari.py @@ -0,0 +1 @@ +../DQN_variant/atari.py \ No newline at end of file diff --git a/examples/Prioritized_DQN/atari_agent.py b/examples/Prioritized_DQN/atari_agent.py new file mode 100644 index 0000000..f086baf --- /dev/null +++ b/examples/Prioritized_DQN/atari_agent.py @@ -0,0 +1,108 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle.fluid as fluid +import parl +from parl import layers + +IMAGE_SIZE = (84, 84) +CONTEXT_LEN = 4 + + +class AtariAgent(parl.Agent): + def __init__(self, algorithm, act_dim, update_freq): + super(AtariAgent, self).__init__(algorithm) + assert isinstance(act_dim, int) + self.act_dim = act_dim + self.exploration = 1.0 + self.global_step = 0 + self.update_target_steps = 10000 // 4 + self.update_freq = update_freq + + def build_program(self): + self.pred_program = fluid.Program() + self.learn_program = fluid.Program() + + with fluid.program_guard(self.pred_program): + obs = layers.data( + name='obs', + shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]], + dtype='float32') + self.value = self.alg.predict(obs) + + with fluid.program_guard(self.learn_program): + obs = layers.data( + name='obs', + shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]], + dtype='float32') + action = layers.data(name='act', shape=[1], dtype='int32') + reward = layers.data(name='reward', shape=[], dtype='float32') + next_obs = layers.data( + name='next_obs', + shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]], + dtype='float32') + terminal = layers.data(name='terminal', shape=[], dtype='bool') + sample_weight = layers.data( + name='sample_weight', shape=[1], dtype='float32') + self.cost, self.delta = self.alg.learn( + obs, action, reward, next_obs, terminal, sample_weight) + + def sample(self, obs, decay_exploration=True): + sample = np.random.random() + if sample < self.exploration: + act = np.random.randint(self.act_dim) + else: + if np.random.random() < 0.01: + act = np.random.randint(self.act_dim) + else: + obs = np.expand_dims(obs, axis=0) + pred_Q = self.fluid_executor.run( + self.pred_program, + feed={'obs': obs.astype('float32')}, + fetch_list=[self.value])[0] + pred_Q = np.squeeze(pred_Q, axis=0) + act = np.argmax(pred_Q) + if decay_exploration: + self.exploration = max(0.1, self.exploration - 1e-6) + return act + + def predict(self, obs): + obs = np.expand_dims(obs, axis=0) + pred_Q = self.fluid_executor.run( + self.pred_program, + feed={'obs': obs.astype('float32')}, + fetch_list=[self.value])[0] + pred_Q = np.squeeze(pred_Q, axis=0) + act = np.argmax(pred_Q) + return act + + def learn(self, obs, act, reward, next_obs, terminal, sample_weight): + if self.global_step % self.update_target_steps == 0: + self.alg.sync_target() + self.global_step += 1 + + act = np.expand_dims(act, -1) + reward = np.clip(reward, -1, 1) + feed = { + 'obs': obs.astype('float32'), + 'act': act.astype('int32'), + 'reward': reward.astype('float32'), + 'next_obs': next_obs.astype('float32'), + 'terminal': terminal.astype('bool'), + 'sample_weight': sample_weight.astype('float32') + } + cost, delta = self.fluid_executor.run( + self.learn_program, feed=feed, fetch_list=[self.cost, self.delta]) + return cost, delta diff --git a/examples/Prioritized_DQN/atari_model.py b/examples/Prioritized_DQN/atari_model.py new file mode 100644 index 0000000..bc61a8b --- /dev/null +++ b/examples/Prioritized_DQN/atari_model.py @@ -0,0 +1,51 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle.fluid as fluid +import parl +from parl import layers + + +class AtariModel(parl.Model): + def __init__(self, act_dim): + self.act_dim = act_dim + + self.conv1 = layers.conv2d( + num_filters=32, filter_size=5, stride=1, padding=2, act='relu') + self.conv2 = layers.conv2d( + num_filters=32, filter_size=5, stride=1, padding=2, act='relu') + self.conv3 = layers.conv2d( + num_filters=64, filter_size=4, stride=1, padding=1, act='relu') + self.conv4 = layers.conv2d( + num_filters=64, filter_size=3, stride=1, padding=1, act='relu') + + self.fc1 = layers.fc(size=act_dim) + + def value(self, obs): + obs = obs / 255.0 + out = self.conv1(obs) + out = layers.pool2d( + input=out, pool_size=2, pool_stride=2, pool_type='max') + out = self.conv2(out) + out = layers.pool2d( + input=out, pool_size=2, pool_stride=2, pool_type='max') + out = self.conv3(out) + out = layers.pool2d( + input=out, pool_size=2, pool_stride=2, pool_type='max') + out = self.conv4(out) + out = layers.flatten(out, axis=1) + + Q = self.fc1(out) + return Q diff --git a/examples/Prioritized_DQN/atari_wrapper.py b/examples/Prioritized_DQN/atari_wrapper.py new file mode 120000 index 0000000..2904fb3 --- /dev/null +++ b/examples/Prioritized_DQN/atari_wrapper.py @@ -0,0 +1 @@ +../DQN_variant/atari_wrapper.py \ No newline at end of file diff --git a/examples/Prioritized_DQN/per_alg.py b/examples/Prioritized_DQN/per_alg.py new file mode 100644 index 0000000..5b727cb --- /dev/null +++ b/examples/Prioritized_DQN/per_alg.py @@ -0,0 +1,127 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import numpy as np +import paddle.fluid as fluid + +import parl +from parl.core.fluid import layers + + +class PrioritizedDQN(parl.Algorithm): + def __init__(self, model, act_dim=None, gamma=None, lr=None): + """ DQN algorithm with prioritized experience replay. + + Args: + model (parl.Model): model defining forward network of Q function + act_dim (int): dimension of the action space + gamma (float): discounted factor for reward computation. + lr (float): learning rate. + """ + self.model = model + self.target_model = copy.deepcopy(model) + + assert isinstance(act_dim, int) + assert isinstance(gamma, float) + self.act_dim = act_dim + self.gamma = gamma + self.lr = lr + + def predict(self, obs): + """ use value model self.model to predict the action value + """ + return self.model.value(obs) + + def learn(self, obs, action, reward, next_obs, terminal, sample_weight): + """ update value model self.model with DQN algorithm + """ + + pred_value = self.model.value(obs) + next_pred_value = self.target_model.value(next_obs) + best_v = layers.reduce_max(next_pred_value, dim=1) + best_v.stop_gradient = True + target = reward + ( + 1.0 - layers.cast(terminal, dtype='float32')) * self.gamma * best_v + + action_onehot = layers.one_hot(action, self.act_dim) + action_onehot = layers.cast(action_onehot, dtype='float32') + pred_action_value = layers.reduce_sum( + action_onehot * pred_value, dim=1) + delta = layers.abs(target - pred_action_value) + cost = sample_weight * layers.square_error_cost( + pred_action_value, target) + cost = layers.reduce_mean(cost) + optimizer = fluid.optimizer.Adam(learning_rate=self.lr, epsilon=1e-3) + optimizer.minimize(cost) + return cost, delta # `delta` is the TD-error + + def sync_target(self): + """ sync weights of self.model to self.target_model + """ + self.model.sync_weights_to(self.target_model) + + +class PrioritizedDoubleDQN(parl.Algorithm): + def __init__(self, model, act_dim=None, gamma=None, lr=None): + """ Double DQN algorithm + + Args: + model (parl.Model): model defining forward network of Q function. + gamma (float): discounted factor for reward computation. + """ + self.model = model + self.target_model = copy.deepcopy(model) + + assert isinstance(act_dim, int) + assert isinstance(gamma, float) + + self.act_dim = act_dim + self.gamma = gamma + self.lr = lr + + def predict(self, obs): + return self.model.value(obs) + + def learn(self, obs, action, reward, next_obs, terminal, sample_weight): + pred_value = self.model.value(obs) + action_onehot = layers.one_hot(action, self.act_dim) + pred_action_value = layers.reduce_sum( + action_onehot * pred_value, dim=1) + + # calculate the target q value + next_action_value = self.model.value(next_obs) + greedy_action = layers.argmax(next_action_value, axis=-1) + greedy_action = layers.unsqueeze(greedy_action, axes=[1]) + greedy_action_onehot = layers.one_hot(greedy_action, self.act_dim) + next_pred_value = self.target_model.value(next_obs) + max_v = layers.reduce_sum( + greedy_action_onehot * next_pred_value, dim=1) + max_v.stop_gradient = True + + target = reward + ( + 1.0 - layers.cast(terminal, dtype='float32')) * self.gamma * max_v + delta = layers.abs(target - pred_action_value) + cost = sample_weight * layers.square_error_cost( + pred_action_value, target) + cost = layers.reduce_mean(cost) + optimizer = fluid.optimizer.Adam(learning_rate=self.lr, epsilon=1e-3) + optimizer.minimize(cost) + return cost, delta + + def sync_target(self): + """ sync weights of self.model to self.target_model + """ + self.model.sync_weights_to(self.target_model) diff --git a/examples/Prioritized_DQN/proportional_per.py b/examples/Prioritized_DQN/proportional_per.py new file mode 100644 index 0000000..90a16e4 --- /dev/null +++ b/examples/Prioritized_DQN/proportional_per.py @@ -0,0 +1,157 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + + +class SumTree(object): + def __init__(self, capacity): + self.capacity = capacity + self.elements = [None for _ in range(capacity)] + self.tree = [0 for _ in range(2 * capacity - 1)] + self._ptr = 0 + self._min = 10 + + def full(self): + return all(self.elements) # no `None` in self.elements + + def add(self, item, priority): + self.elements[self._ptr] = item + tree_idx = self._ptr + self.capacity - 1 + self.update(tree_idx, priority) + self._ptr = (self._ptr + 1) % self.capacity + self._min = min(self._min, priority) + + def update(self, tree_idx, priority): + diff = priority - self.tree[tree_idx] + self.tree[tree_idx] = priority + while tree_idx != 0: + tree_idx = (tree_idx - 1) >> 1 + self.tree[tree_idx] += diff + self._min = min(self._min, priority) + + def retrieve(self, value): + parent_idx = 0 + while True: + left_child_idx = 2 * parent_idx + 1 + right_child_idx = left_child_idx + 1 + if left_child_idx >= len(self.tree): + leaf_idx = parent_idx + break + else: + if value <= self.tree[left_child_idx]: + parent_idx = left_child_idx + else: + value -= self.tree[left_child_idx] + parent_idx = right_child_idx + elem_idx = leaf_idx - self.capacity + 1 + priority = self.tree[leaf_idx] + return self.elements[elem_idx], leaf_idx, priority + + def from_list(self, lst): + assert len(lst) == self.capacity + self.elements = list(lst) + for i in range(self.capacity - 1, 2 * self.capacity - 1): + self.update(i, 1.0) + + @property + def total_p(self): + return self.tree[0] + + +class ProportionalPER(object): + """Proportional Prioritized Experience Replay. + """ + + def __init__(self, + alpha, + seg_num, + size=1e6, + eps=0.01, + init_mem=None, + framestack=4): + self.alpha = alpha + self.seg_num = seg_num + self.size = int(size) + self.elements = SumTree(self.size) + if init_mem: + self.elements.from_list(init_mem) + self.framestack = framestack + self._max_priority = 1.0 + self.eps = eps + + def _get_stacked_item(self, idx): + """ For atari environment, we use a 4-frame-stack as input + """ + obs, act, reward, next_obs, done = self.elements.elements[idx] + stacked_obs = np.zeros((self.framestack, ) + obs.shape) + stacked_obs[-1] = obs + for i in range(self.framestack - 2, -1, -1): + elem_idx = (self.size + idx + i - self.framestack + 1) % self.size + obs, _, _, _, d = self.elements.elements[elem_idx] + if d: + break + stacked_obs[i] = obs + return (stacked_obs, act, reward, next_obs, done) + + def store(self, item, delta=None): + assert len(item) == 5 # (s, a, r, s', terminal) + if not delta: + delta = self._max_priority + assert delta >= 0 + ps = np.power(delta + self.eps, self.alpha) + self.elements.add(item, ps) + + def update(self, indices, priorities): + priorities = np.array(priorities) + self.eps + priorities_alpha = np.power(priorities, self.alpha) + for idx, priority in zip(indices, priorities_alpha): + self.elements.update(idx, priority) + self._max_priority = max(priority, self._max_priority) + + def sample_one(self): + assert self.elements.full(), "The replay memory is not full!" + sample_val = np.random.uniform(0, self.elements.total_p) + item, tree_idx, _ = self.elements.retrieve(sample_val) + return item, tree_idx + + def sample(self, beta=1): + """ sample a batch of `seg_num` transitions + Args: + beta: float, degree of using importance sampling weights, + 0 - no corrections, 1 - full correction + + Return: + items: sampled transitions + indices: idxs of sampled items, used to update priorities later + sample_weights: importance sampling weight + """ + assert self.elements.full(), "The replay memory is not full!" + seg_size = self.elements.total_p / self.seg_num + seg_bound = [(seg_size * i, seg_size * (i + 1)) + for i in range(self.seg_num)] + items, indices, priorities = [], [], [] + for low, high in seg_bound: + sample_val = np.random.uniform(low, high) + _, tree_idx, priority = self.elements.retrieve(sample_val) + elem_idx = tree_idx - self.elements.capacity + 1 + item = self._get_stacked_item(elem_idx) + items.append(item) + indices.append(tree_idx) + priorities.append(priority) + + batch_probs = self.size * np.array(priorities) / self.elements.total_p + min_prob = self.size * self.elements._min / self.elements.total_p + sample_weights = np.power(batch_probs / min_prob, -beta) + return np.array(items), np.array(indices), sample_weights diff --git a/examples/Prioritized_DQN/result.png b/examples/Prioritized_DQN/result.png new file mode 100644 index 0000000000000000000000000000000000000000..96271e70f9e8619e8f86f13d8cfb963887b4fa4b GIT binary patch literal 24307 zcmb?@byOTr@F&hL5*8=0$l?}U10k?D!98dQP9SJ-2r@V10RHeftiGZ z{m`Of%vJnwd7`bZuXuldzrTNWe0(wbZCmSIQ^(+%kdV;okNt{9l)b&Ze|leXa`O22 z`1SSm@$s?Gs<|5m1`~#=f}FnZ(qaDU8@k2EgX=ilZ)g&t(`hOq%@3W6v^fv)es?<# zI-eToVjXG784c0ev2M9+^fx{s$Wld(+SU(}w)WZ0H)?Z^e=oHtu}OUE<-e>g^DZroH>5mk#i?m%M?KZHyDXrn3@geiX&LayP9&>a7RxC^M&} zexNvptbI7p>h*XU7{+{94cZ+c_tj*`jIm3NiWJh=`NGaxmSuxWz|Mu@d--FQLK@yu zStt1or@iK>`b$C6u?%9<-A~;y8O*tFS~6@)j6jQG4wwE13fQR-kM*mr?Ocd_YWt8OM9uKny{`7|s?g<>;knRjo5k&G zhxo;sEi}mz_pjB?17AX+m#uMYFMfwY^(DU1liGdx4kqQf%c5`JU;Z2Bi}7rURpc%EJ|!AS^rt6UAaK z)GwfM0=!Jm`qWf?WIEQ_(^A&TUtWlX&+vbv`d8dIqi77wDH|5nv8LybRiSV%^J`77caccnbRUU;}<1tKzB-to1?45^qu1~B&dS} zK3}OWxyJquj%;cCkbE>#EhoZN2WUDtJ3F846DT%{%SELeg{V_)e%4ArotCimZIlh?AQPNS=jdzgi)>I^{Ri*nS3D1JV@2ADki6U|JlsY7e?)^o0=BH!GOoY&vxlin7G=q^6T z=T1M>D$|wpPwg&t3kbXTOK^=JK}=BPX1*lPeHTH@EfaXH{p*Hb^#Aq2G0q<+&mY~3 z1#xmL?Ym8J{`fI)`u)dcYLHeyqGtN7LZW!EYI{e~qT&6zMSajt%+Hnj_W=V8|AsXW zZ436EGr70+h;l4dzHh%-ZdgBguOUwB=dd?%Llqy+0f|_5icQWzop|IPzzL5>(kSZyuGZYt&3ZEJCRB67yu&t+{&#-`!ALsQ zc%E}@8KsLT?s4Ri^?#!^ePf4PH^cIHC%CyJ!SyL;FC_mrk&0;X^xPIsj4(xc8^*voIo}vGrKvTVsY35Ki#IB{Uqdx`S8WUcc7?nfSzs1iM#+s**8t*Y z#8+;(2%d4FLVpF5Ztf+6vg43qFcWm6hQs%d;Fmrq5*f)W9oLLbXl=Nl2Gm!Ps;Hq= zjp>dV`kBuLYC8`;S-Z3Vx{#XEAL%a;`;Ee5%0uHfBjsJl(bbNfT5m0UVONsO9rY+W zQj2zWo0H00(&jE?Fso6Hr)QI3KH{KP0^Q^LnB&ikO%3AllN`*Mc}HgPnrRVncdrX^ z<|_m{LSgvFtDt_c-8g5xVT35{*VZ>}9QI}acAWxe#kOLWzvFe&<;sb0yE`g-G9n(i z&)TeM+^oQeqQmAS;Fy`iV>cz5JCQJ@I^5!@q$PxCef&-)Gof@07N>pr+5KXG6Q1{F zb=iHE_-Po6U$|HdAqJO4^-=B8pw4hGlwSXCF#)*}>-Gal#XO5mil^D`@@osg-t1$3 zihf0hM0SxJxL;_-NvN5IHc_^M*$yP3t>DFztjDy}9$3Hq_!NiR$=Q!vqLChC7PeoA z{gsWztMhistanhNuwBHQLbQ8JR9>SZv`nnhfv#7g*!0O0-<5cBgazdpcWkJ8Jo4Ew z;2MK8rIHn}R;tex>d#-ss_MuQl z3zXlu$=u0RO1{=T$Be*pLFM<%O_FaE`J@#DHu|*eA{H&<9xE$1v^tTGI-bT7p95dr znBa%>NA4UH9Eh)h)6$61IQJvHI`;wyK6B?F5|r%&hiy8%{u#`}0XUwOSUt46w|MK3 z#3L8j56=WUyqX6l-t(YwR~~3B_@D@l1QA#5vl1R#GWv+<~HWI$wM32ZA`xlsDsrDjLKJpPT2< zFSKxKm4PK{UdNeQzgy?fg}T^3&a)_XU3Dw;iwgx+ksm{;VVjjb2Z4od4U*_@Gh&Lv z-Q;}R#ta;@-A1qR-fK#f1b;`63(+f+?@xR_A?HRk@t_zn2rh`9X=$klrVZ3{|C)CC>+phtUNxU7DueT z^2wXZ`szC=%>(5i${gFV%{vXuCyymn4v#^W3{ICl7ve#+*VtN;{&g9%6=qO=Ug?Qy z>!0RSjd-IYQak#k3B;fH@c#HW>9h&YR*R$Gm2-70vu8lh#D|Z)Z-!AW@sh>4% zb6Wm?0FF;5shMq zUKca5J=KHUwxAb!u|`=_I4A^kAPr2a&>)1~6_S#~PZ_^IW*Gv)KszF!|hs)IWuvayAp(HaQBE z9e&OH?y$h)k#;bwLo|QbhzoFxL4X0v>NV7bWCD=NRxd9Z z0LCt}k`A*k!?mGQn%eY7!_#^e&CbJLf$7#{5CF}1Eh{0YIDS%PP3bWm8ue_#{IDa*&J-$)6B)^`RUY9=%yeyBn z(wef!Kn?_)w2HrJ|LSn3GB{!jchPN_>r(hvYC#9eeyQV3BAG+;y9sLB;*xUGXe4Ey z&JG2ZO4n?EM(&{@0Gk}B>!G1kET?C` ze9A*nkb>+BQ;W1LKFU5v4TySdSBh1&(UNZnRvE!p^l9AJBBQ}|jC3Da?IOjn(rTX$hA?n=guOk4V77!J9~})5UTRwjA-Wy!$3!410zaI7u8UWA{CUphdw{9P;W?dt#*HB zFK}yMT1=z9e~FYeR*`%gS7DZ|4N>moS4Oj_p0z-~Cr*MN!EL~TTQEq$JM1<~p<==T z00bF4x?(0haN_khL+roQc_fr@=qHQVcc$L50vQzhTK*flzjgCSBmVdZHLNHuSv{fh zs~&1=>q4j1^F_Q#xf?24T_k6Ln11Vrenj3q-zIL1k7mLG)LD8_jLoeq`H~(K(zxJ1Yi{poUb{HvaigFrM#L10^#OqxApbY zoU3VFuO^Oa#S&ECze{^*VEN=~yJ~sa_6g^p5rSRe(^6W+lBeMT8cU-EdTD!a;2m_P zXO9VgGytelfcg@$lQQlfu<&k@f2((G_W@Q*N@R!DC>#oS(^sN1CC@8F=H}N?OJTg9 ztIxU37x3clJHX!Jo<&j#AD~a_r@!Utf8^kA2Y`M-*vHh7ZX5|A7&TaQ5(uub`xxpb z9~ApEtMOR!<3;D6>Q>i&5mq8o$fGZu6Y-j7OEt+mhxGhI^V#El`<^N+&$&Aa@t#>~ zVXi0a6Vp5bZ=*$1yEZ-_9a;6J-~Z~i;JJVHrPA1SUt*J4a7aV`S7R4D{dD$}P~N;| z=Gij}kVXNcEd#>|9jw@z9i}33PAJ#qDWcYf2j6GAyx*MFS6%)xmx~ur#x?aU;e`1w zCp-at|UwvLWLhqz9226;ihy_nG*~k-TK}=*Ti=>}dp3vwty8f7Z5kBr` z0NkLJQm+u(d-|)rYX$~5&NWQ4GdiN|e?ZUKJDGEiF13W_+O``&_NaQl+NC*y%RB;V z|s^@8Z&e@CW`7t(Ql`BVD%vxgHV zI^uL}^g#D$oSI^9DtHr-yWlHv98!jdg4c#Z4XoifZva0Afi((F4zxK)}Y1Yj(SYJtgaWVk3diIg;FrZgHtP+TIf zx&%%`@m6KWZdTZoap;qOIzwY0YCoGZg8^)9ileJ7PVRSVwa}xRji0~ujQauWRx(Q% zR4!b=S;iq!bgL2VGS$g^2NMUg%RS!Fd{gh=?q;zDuYoaeR?zU!mnaQV}_zSv~zo)P~ z3)ObZVac#Ul}>_Jh&NWk@(i9S3jf`wDkvKP)vEEEP0KmLL4cC?S=3O<{e9aL+4&GSf=G4K1VqY2rcQu9r_-f5aB=^NEL ztB>XB8jPw490hj%zN~($${g-L)I8mrdqlSH5|EDy0V>hJ0!#11F#D7I@N*GwVd& z;0gO#C1l_Kqr$$=&yrMYdz-ok{}wL5)3yaV6$8r3NxKyYWhKo6B>`JHcmnWCGR?^z z;V(>*E;+iW^{>xgd6&BS5|WKHiF?bXXmgGhY&Sx7JR>pyklf$ImCEqr+MFUmwKP5*%u9N zzIipitmXSh1CDPCQ~IaTQCANJW~wq!++u+IF-Vnzm&BaX#@A6CLNBWl2j;q!N7Hkn z>8<^IQnl~L<#CGLo25WLDoHZZI}Y#hLFvghSnHX_R6$h1o?dSp`(dGbl+dMGtDH!m()iWgo^zVuhu5>5EByvW-}w9xiP8qTG}RH0zi zV!tPS8}Tb}GCQgl8ibZ0B#gQ(GdtPqH@f@Frb!8iKb2kJXv&s`jD8$r0k7sXR^Lwq z3(Sou#TLBI6@`&Y8M3MCCXgQ+t2tQmsFL9}FMThcac0ZAiX7{EokOQK>cJaxba8M1 zGDtbAXeG0luis2*`fLgg5+(^HW8agR?jL(@)hfjnSgR!iCI?>M9b6)HE}&HL;6Two zJDh#el_tkA!>kL%>RVUK2?eihLSO!=Qh}Cpd|=;F6A`g=S}S^@2l`8jqR%#0-6)=`NX`a(wCby8@35GtYO39WQdjnMfA8d&W-~X zCDD}Z=D-*V9n4k8-dP{XskxAO;`U7bR%C>`-GvY_RVGw-K(3Xq5kXG zSsV|Qd9UfX?!dI+@})J5>&nNCYS*8u-;*+*?2S)6nn#v(*(GI3|3L2QpJYuJewO5G zZD@4=7NG3+0!(O_0a${YyPYE4DVY?!@O=ir6ocPe!klJIplOYa-=U!N!0cGI(8k&; zBs-h)Va|;s;MOt=eG;J3dG-$o?Soe)o*X_qn6X9Jy*+-{rSN88+Kt<)^m9YU;ljkU zN9^32CMK_&5{TGQ!u0FDN<8G_!#tc_I?fZ>_6<1aEfdAIeGkBON9Fq^-qA*+5-U{HiR4w>-zF!C=XCj?vFL?U#bdT)QS~8`!`M$JK z8(tsJO!-avzmDc4R4C(5U(Oc7rYb+m?k!qkMQDfV;7Yz>nbJDPEttczfxW|yncMYN zf~Mm*&80A}VdFu-7@0o`Zio^$8!>u@(09h>v=feuFZ{wfR~qNxLX1i+jumbFS7|e*{m^C{i=g3?qwm z>Sg4(NY8`8`r{quWdAsMU-}^64qFcqdcr~$JWOJx`B^x1-34*{{QL`39?Y-f_^S|Axb{ZG=*Gii&e8*qcRb!9 z8pDds+BP#)nw^!WWH6z+N+v~VfdIDn0(+;D5rsV|E#T>ul<}$OgyNk5q$y>{zO+yN z`USt=)|fI{2ma^-BPK9i{cqoAs17h0qBHPFZ?uz(4ua%-tgQgJ!P7i9@e1u@Rzs59 zMP5p2e$9p$pt`B<6fNv;IF==V5GQQW=Yv%Gi1>MjXJ*YNak-4(K@}OB@VQ$YvAA&UsB51-C@Af6Ns^aCPLV#{i+jn z|5HjeB^Me@pEFMHS$CQWPcqqd-QibKfey16mV|I76H2Ar^u}qso!gZ;%bi{30i(7O z^7VLp;$lZd=0jPL#5{P7dusp)d3h4+dIV9l3B4+b-7CBc(8=faY!S^rjelx4*D5BK zxCq^TDxdX{N&O0zd@Xze+me%rKjWf?KUiuG2g~ZV&3D*v2%5nK+^~ZZq$9Ld-e8d$+PKrs8g(2c@>^m^xFzu50x7KH)U^V@h zXB*FEl`hKsl4z10pfg=(_$g6S_<#2T_)Pk1u9<=-2o`Zb@dF+pN%Nm0bI#iE3F7aj zkgzhzc%1YcCk1-XBEWcC$3oBIRk5u6i;p5Sb*E49Adf~J%dmm9aq*L{PXw3S#4Ac( zWX_f3O7(IV!%XHoBGmz-m)2XHOkU3s0}k%XVW(Z)vJ8^b?sjn(91Mg|)XG_XFWssf zO~z*=A+^|ksGNf%D>Fo`KRu<759oCEg$!rd18VxON`HZL@ZX_801}lB%e<_i4+4r z!K(v=hvSqvy#pIXQ-8)sa_+qQJ0i#fYd8j5LF*a`LInow0wIr(M6#Sx3^Z6)R80D3 zjd(VW$4r7c!pyRjYw{^}b4lr8n2^yOf>3RZF{lAbfzER__eQ#wBohP_0h#}fM23u;uz)5}Li1~WgEo=0mJbnnsq+~>RJ zh7w%~^xS8FiQv`AP+5mgSqk`ci^QW@zK$23v-8*;+nYLusr~qc{G_I2O!^9JuBJt{ zisPz_A{upxW8QrH%vT(;D!JbWRU`THv|#-?Hz0#bd%RE##0EQv{F}v5lfAuoPx;#P z$F*OK*q4Yi9_QyypDcc=FnJLRzwbD-|C;0GlI%VY>^`IqIr3*c{@BZdQGzKVRe0Q! z9l9tUt6vC)ty=3a0VSS)XDXf6^&#<;8)_D$Pe9hTWLc9YF;|a(5W}ZY>79ezq9sLM zj9jvnJO7l|Zn|3CZQac}8sCmDl_432_o}?f$*>4{67j%tjJo-~eL;Hmc#of7j=Zt9 z|JiStF4|Ih;7+_~KAdX$As0SH`6wbnI-A(0m!FRFMV#?M+Bdx0*<;kV9;-=7pNqaG zXaz!}ZVU--j!&s_q-(4? z*!%Mn6WF8tc7~&{sBKturJ}+-t}!Q(|9QFjcWv!RU7ti z#`k}IthQHl=M__41gnyDNOR?fc_hw%Pj@{V(npum_ZVnZAS)!=-Y8tm)1)rRPG2yq7 zP*l>uVz8my5C;-xU9HM;Emgz-sA%sp8N6AxhaC5S7TJS+_TG zA(*y-glW?haEUh(?DBht*q{m$MjdF!xs`l$TWkx?6En{XlI!nOT?M9^2k|wWDV{CS zO`==Rz8gpKPE&N!0K4T)GJ%Ct5^4{t5fSm`ilE-B7f)e|j*gf>F0L-Bb}k#c4?;`U z+O#KVqvz1^^RtkuCsljWyVFkx=A<}QUxk`xJLK!!qv2q&x>=A zLs49lIH(m-i?nWYC9FmA(hLS&2V@a2Y2@^)1d@^1u1#cxr;tIAqOuJ^G1->KDwPl8 zLGX^0C@X*1>4?PR%1!CGDT&dO?@$6twy`ZOCKO89<2u-Yu=o`8*|vEh7I3lnB2(TF z6%F6(q~twIVcQ0mT~<^( zr2v$t?h3va-V3>9cON_A1EFvQb6F5aHAFCOC$?S)Rt{1R6@0`YFX7CleqSm`UM%M; zwo~cZ(8XzO{^+wGMQ>nlP2lmkPR{Jz8Zz-p>z%HvhMch^7F-p=9x(;eO4iDs`Id(7 zM-Brzomy8TlS%|W%R@?5kuWC0gLq^}UuW%}_LG!Vm zNQ01X5u+r_ak_3`yIH@4VJDVyK%~lt|LQUyGu2<~E#fo7ZU+<-w)Dq0o1Xv;4`yhG z8@cY*QOfem8j3EL&t|<}zyu^i39C^U;w*09j@V$;VD?Vdn5qx5QBVW37hG?YFiZ1T z?0AyJl~c>9aM`5EL7V2F88mb3O{qtDSYxyy8yCzkI#Mk`2cr8=Y>kgvVVfXb@K++Md^=3SzYnq$JfaB$ z&Y7%!gFrn_ifZ8eH}auwXMZ_!kWIO@9Q5CpWekGRHG!JrIv4+B-a~Mz_-pwF`2Jjp za9nD}5bzWS;Q-EVOu%`G9g7vNoTDRme7M3VA}d2V&L=i$xj)27Q6kqF;)({BdC&L7 z|L_56o~fw!MEk)y6f#hK)Le1CG7d`f#c%D2X;&Y^d#IczMbnX@JJ_@0UHW7UnI`Y? z-{v6AyX#{$=M8ae_DjtIFI`u>!|d#z1;0A*&+o`&2RtNWt9`}P=zr2yTTxjeDM!#e zdPAb3)!JeC>sVO7tAX8GV;q%yJCjGFR5;K1{eG_2 zLdQ907Lsc!vKUA-GW^Y;@CYU1o3ALj?^LHh zzmPo&u^f8+i41<7)BSc%jJKf`??&AE15T3yNtcOB##+%1rt@7{C++N=KX0V^U@)Ga z{i8w_pyQEo-Ftc5!tEU=*FlRB`~D?bz9%`V*KZD~ne*Qdeq|{J@{gU+cH`PQC;7wK z$e4nb%ee50&d&awy@h|Ne}_il%N=x5qa4yzWEQyA8G|RZ;pqan+6CF}v~H#jE0w4X z>^lPD+FjZ*0%9CCu7RO$uxV?z5|vY+oe4p}j*xRoW_#<~(FtglIN$WoXOfAonYhM=ue^dPfycxhCsUW33QDq|zBaRaBLIJkVaY9M-f^5)^OPhO^^sgOcxq&ov zJyNf9PhybB$NLh=$B*a_JGAvMC@(N`&V{ z_P=H(TsIc03^7`7q=7#{Cq3iNeVc(tk1CIQLEjuQr{NT$wXo_ZGCuuw=;#3&Wh37S zPaX_X#Q@SPN}vHYYpgBnPj}C{w4Pu8`i1P`A2cG-*nj5!ihYDqan;X0hWQ*@3HFkOJg$Wz=&Yur)p%|GnAaV7C5DQc&4)^`)8qcHV>Izy3u& znkV`2K9tC6Lj@#`L%Ca0pi2Hxi-5^f$o_+`?mXNtF$pQ}Ly*I4K;fiQK}AO)pZ*&& zU3ZcRNHb4vWJ|x+bxU6hf%Uh^v8jJSv&24d#}8GhA;)Tt$(RIb2^#qdRnW|8ZW_3`oWnk0t2x`w^`C-K z@izFMlny-w4PNcUKbE3EUZVO9%6q2EiAdXYd z>%ysOP6)h}9Y0OykCRE9FyOv$6`I!QR;a}^B|clu`!uU<&xY|Wl&75R{F>w zM7K?MwH|8^$7+f12lg?n;nQrChVAE7zBjHI(ZOue!MpMijrJhDN=R_1*q6EoWRHB( zu0Q(}gBW~`7QLU4lb2h_6+-UxX&Z9Alb97*@J?c)_*C=?+{#(h_iV-(OpMGIQ)8Jj zlM0z4WEx+dZR@}wKF;g;c}iia57#V5M_${shb@wze@QX21CCkr2+7Z}ZH0Tdc5@90 zjbw8^`>6a?1eZ#t?OB|PPpaJE_|z+K+;y=s_S>Y&I?M|F>0@79L- zs^&69ZgFNgC(k=fS{ztHG7PksdZA$mm*4i*Q-wUK*;eSuGk`qJaQ0>@%?zPq?N0v~u67OmPoOS&YB=Eb+Me`8<}O)SIX{5K z86iXe{@zb5K(iVE6B|%0Z^qB^D~x z$1(>O?q@4Y2*_(FS%4>?>Du3T;xnh+aXwOV*YDZl$9k`zP_CO}|t;muq?i_jP8 z81UF!0$f<3k^WQG8E}LV8^{H>csjCCP2euf5a&Gm7Kt1>4y@d9aCPm`#uje*>x0C# z!s0&r%|dMTSMqaR{zIjr$Ip&iTy6J_5?B$9gpFrg5H!)@6cZ$T_A)7ar(W{ayB@d8 zwiW}`o~-d(cJJ-<$TPSD3|L;NG>z5V3{=J#$A=hR2EKo_bDdA32M&;;jUYhJOU1>V zD2ze72t1plN2jNMZExSI8AI4rz(;(5<8g&~U86P(X!`Wlj|)u-R42xpmtl9;I+MMl zbZV$uGMP$qj&0PhipW_-98ZzBJSn~}?D{_o^|i|%t)Y*VtOeP%l2_Sm3%L{dg~Ft9 zEI+(noBd@AFG}L5OGr=pAgr^|nk44&tsLi1Y27e;lJ7WQ&{=A@Dm4*q#aM=r$d9)p z?CoiI89Rx4&whVR1|ial&HEfES%9G%@{QKX*W|(b}I}V5tnAv;8LZ;{^2;V z>O=<>-XN_6qI4GvDD6OgH0heNHK1TI9?G^bME$#*=nvH2*=yfae*?(xNqGcn4Mw@C zw$&fZadBIBB593tQzO9k6PO^5=je`-@86*coGm4_#sy^67v&lrAcKh!@_zg;4tg~V zK=jtBlfnE)yTYkwd!w5o%F6E~+1c_KO%{I&f$Ynz81(q%+dn>8MW}$EEW4jckMEX! zf?%S+!+*`bS;6l(Wh>q5btGf#|pgS2^Z7kDgs;7AS><9=Uv|@}AJIh@c zRnEFo9_?S6R}Q;AAzdx~;vtBw&9KJ6vyB+O2lTJ&e^^=Wy?xU`ihDo_&vU|n$%o2C z6GgsIGW;dPc4{3mwM85pqev-OVSnm@KqF#$Di zC@K{L_QB)`!SHHZ5H4|}+7F5By$v0E-56~BHtyZ}T2{w5E;i|KVV2J;V`>uDFYj-a zM0P4&nB>5iaFkjO1kTK;BPOu931ST;YP=q0B1*)+mTr-==>%k~Et7n3lSRvPN(yon?F;XHqyE6GQE~XUJvoY!T=!ypWygG$DfB4A>=K+Ymf1 zhsQ!kl@yZa)fU+8b74eAGJigK9p%P-Q+hqO8Va+bwvN(v;bc#og=#qGL)hJjEP{V= zjkB+;rR?YHqH1P-R$6TTE=yA-+kteMdZz7@br;=1bR(Mlq5{GboTi)3IK?RTNw%uM zqfThm$c~f??nmg9N2bDfQB(ZUPWteWQx-t|n#v4J$8<6BvG38!+qiU+`EE-s=bjt= z+xgp|pB-Z#a*=_LjCs`-jO)eBxK}&)^Lf|3u|d|1WF|qWbxKe-4}6$i^n)3#)+~xhf)PH(grdSAs7OFw^`?=2Af=?H}^;+n;8u0bC4YWt~LBDXsKZF zf8By%m)Gt^H$YTG5^aWm5}8_Oj;|VBX_t-0P1@U0-^XHjk^r=j)#G0&KZfQU*?6Fp zFyTu+;!Um?T7mFjVa71xQ5%Ywvbew^TE{uSkmug?Xz8=pQNz@hy9I6Zd;hz{^AFk2 zP=WBP|j7U<#L5{@?069Up8<>gp zI|iKJk#$?bkIcuT9%uoOOydcn@<_>BuD3o_aKM#cDm{Bvq7T2AOFaZv6T|3$3SlHa$|eh z@v#70J#WLuKn6xJ*27|@qd7i^s16*oe1Hy@0(LwNE-haH=uDx=(VdFLHpCqN1ixVFAe7AOW9ap*YS(hdQ#NGmK zVe%^J0@kGs9-yA2P^R?joHG3nBrxx{j;XkF;AHsTcV!@pt2X|6 z@#kUW?D(r?<>O`HAJo9of&98z38QdTE{GoemB@9y*!r%2bt+RNMcXp6Tl+HC?ACT&~;qK#SB-G05b5=?I2j3 zPL|8W^sDWdrRJrEU6|%|?GEb6V{stC3N;MtB*i;MP^6mwsa6gqvdtYif-MUHGv%TGzK+JF&*K~sp6TokdtqB`bue%W z)92;uQFsdo%xHom((e;($7^5ktC6r}^iwuVR=<{X0);P=irE2Q8gt#ck|LHLgMu;} zGt*B8W-7GWO2P#W*gB#iPPP3>a0%6PH@pLVT&gG8uC5;Zq4LowAc7poPW~JP3B2qf zl%hzO{xqB^gP(E&8a(!beG2gSk^ww)CErjP0RFuH{Dpbi7kaPz6I;(&V%N_aj*!*{ zSmyj>sdSiR4i<<<19H|v6(V3gEPNB+*@73mu@MGBvo@q~G)=&46e$MZo4_5g0Hte) zP!3beQVz2Ij&I!`fmOWlirvXm9gKxbG~oI zuo`v1wTescA=J9&{4aM2DQh(vtRFIfQZ+y9rZ@IK%B*-h@#*iye6uINpwyPVE(0kF zj$55%XPBD?i9NzZiH1!1Y`~K;8o|%n6EIMwoinsPv<6L6r9e>FpE3ndgaDcW6p`*@ z0yfw+GgDibGz_YW8`lXWLa!J0;=fM~G|p^oP`S|B_*e8GgRdtAnI!)n6-IyH)90ZS zSonBG@IB2e0x`wDuch;XQWHsfwWV#bFU?cc0zrKp^BC2s4{W&Lx#)Y^x`!P(wVg8% zO|c2AxTbA%uH|$hkGTy3FX)-SlriI$pdRSWdB=nQ-Xb12Rj!Z*lZA2d33)Y9Gx0VJ zaRGX7hD#ubgw%K+1EF@AvY8p@37QwbLc_-!#!E$0^F1PQ>rv3^?C(XbOBTfXUGsvW+*7tos_w^J`kBkDyhkK*!1?z5agS+{hQ_0M}ep#j%6PCZMw|Jzv`oX z@)BQRK$(fop%5Yv^YB@wpVTWGG8;Bi+2J$ml%RbpCJG}4Y+#}N z#ZOOx-dlQmlNGvr1UX9KW~=NiZNgRRPis$o+Q5bPlP9gMKZ8w6EXcObZnM+w=EJ37!AZ z8yBzX{sw%h*PlH|t<9t~zTX-#=hOT3CoM?1d}c_^*W-8~@8ogUpXUMg=gVRqk?!6k zAF$uQQjY!dp~-L14#Dmih&!WK}|%RQTxRgGjlf0e18;yf7Fez;{hms-;ZMraB|DKdw#(0O{^`5x=e0C8S7Kl-t>5A_*R+7H2TSqK?xaqVIFibcpMUJ|ZO6REf#d0uE% zT`d#4(v;NZ7cFul$W{xa*m1jT6!Jlu;&GghDi4OkTabguZ?3vOSe-Ee?<~CkGQ2bp4N?3_I*U2dE>xDyW6}c`2oRpoBQFK zcp$wICH$CygOC~wI9G9dBu%RgEpS_r!#mk$Q9NvKMm_$)ycqpLX8MLz>BOgn{t<37 z{Sv`LMV#8A8?Z}(BERSj_z8XfwKcj}ONxL(8AMrn7!r?Zg|`t@#QQ2;K3e5x79~G~ z8ETQN=1mH<2h|ut#5xhuDpK5M^x|wkQm@JZt3{pt!hknJ;n=WoMYZVNrGWK!%Jk_; z5ih?Z)~NI+Qo%r}*$g|OPa1tYvg8wL*NkCD9RZuLHa4w#AsG}PIOnWxWw1Bmvem1R zD0y9a&Q6-Kg9D)jp3MAu{t;7cE*D>gPXuWABM!KSCPh_~5pYhRG`_iCnW#^@;O1e@uwWmAG=)*ch2a=U|ET^B}nw$M&$@<+=Gg*A0?5!^T;K5GC{j*Eq!6SA?B&R|ywU1*p@)(}h)`qNS zA^D2t)NM+NVP#G%%_$$!_-3jBh9?~wU|B0 ze?ZS+cycq_n%gSN$mEp}5mgO5VXNnRz~B}m<66NSB7+Fj{0LHhBNC}CssG5GK0vYt zT8`^2w{x)#qcDffrb8CtDjl*8hHlN#hzA`CNU7zfP%Fg!7FJR|3A~dZWAY(vP;T*F z9NLATtia=kloo4eRRGu?^huGZII!}gK?;ch5iV5)=%Xx`etXJYj=)JJz_)I(LI7__ zu^+_1U~v5JzW^sJBu@S(7uiCvlk!bA<`@|IQ4bAMR05%cE6oibm{%bgQHZAwX7{I-bRr=Vg0X4VfWFY*vM;Sxxd{RiV+l;4v% zD=JqW2wc8kL@imyzVjSAETc}q??{-=%`DJ4_^S;+hYSqJyu5#F{X5NjcHw^&a-Gp| zKi@mivP&Xs^%lEC^iFhPiDk9bqeP9KAYpZad_)MVM~#vdy+uoix+Fx2wt8LBqIdBn z{N(lj#sAfuIdf*tbDw+X%*;K{T>8K1v;Zj8mlJ}h1sQ{U#p5GelVIDlxQmhS@VDju z`iNoh^th>lJ-KbiMmK2+_m(lp*r_olPRa#>7=RR7$XI0{M<64(TG>=DuUP;TuhuJ2^A@yD{~=3NrWFCKP6w!;4La zjG0)|g8>gddj?>lMoE^u^~+eVN!fVh-G`q{E)Pod5_nA4D4pIV`Wk(cs~XP$LF<2SKNxMTTb z(asGnZmoxmA&d-JzGAYx#MZn8GB_uSml#AgV|Ev^-c0fQM3)}%Ljd>)gKyqBU2+>1 z4)i^|huMGn-c>j~JQN+oU++G$-3rChn@ewu#__lgOpgO_eqy2Lg}GjeXMhGQE9j`h zCFw}r!Vk%v3@i3bv&ju*T1IbVzNbkpn`=|z#=o~xoiJb@B0+d9-72waMUKv2hBuzu zrLbIz%Ya-AY@Wr#@|A){64@2?Q#?<$*7|8B?<->$$%s1aRc=3`|{R`&Gda`9gpUNrxS`QFnG-jIvy} zYO^(*A6Lr*@^T9IY}W)^oW!1Q4NK^NcPcYcOl`&uLrlA)-BZn)&${~5*C5}o7|pW= z-QrcEXG@OPu%_-z==;S*$DFNP!~_Yiz5xm=0LrZG84f_i3G=ITc!4m5NUV)I znI4#l_7J-yM0Y9xHE-tmQe2w!mG8q@0s$N@wi26jgPFaVQu?|EH9lkT%=@MuSin-; z057bVt9xgA4s*4pvU4d=!x!IXi)IiYQWhS)=Yd4)fShKv`ifB9cJk)b-g4&UpqB@& zIU3o!U+a_)fKQYhPu=C;wDm10^WF&s;jyB>=ge^h0b=92&HYWY^moi3ya-Tv6+fY` zUX#ZHF#lYe$h15rfhSP~$(c6=Q%2zb1`wr=?1?l|R=oE~-eP-hblkA@^Qmo~vVVf) zD>vLBT`#oQejjiPB+!JeehmpDo0IFg52qu-IWb5YEjgi&ZBoFCJISK-f4pm}aI_w^ zi@!X==j5j_Po=AI<%# zwIFAAMw&P8nH={!VJ$!@fnBPy_;@0K%$63P@WHrnu0Gzang+HEtz9c03UDfLWb8=;YQ<+D7m7kFiw+ z=VVtB9)HW8PjY~S4(Q-2SUj7fe*=-ZR3igzed$}yB;jFUo(r4K;{mb;h+rL=me0C6 z-oEQwG1x}LoBgFVgU?WCk#m$X?-3H*QG+uK60gSjJSa(wh1E9Btxbcz${p3a;bZju zG?l6Zty$03R!tx)r4)eJ`PpHNJn$*69w&B6!Gr*He8P!6=OY6RjOgEE-2bHk9=K*}Ih?#pyZnhA0F{IM ze6{~UT7&8%1{t|G>)6YMf6?@a8q>6i^4rQlzBd-=k`Y}b;qFiWGYTnvYN52x-%$cV z6{&&~uEUiciN5!a^kl-}TtK6~O*i3{g15K<;)H&qK-Hbep%Lw_(*KzA)2g?w!RN2T zGxg$qCRbkVU$bquM;bYk+dZ?iDi5E?Ds__8w$p zOg8m*9Q1e$JM4ME-$zud?GDSkO5VaY=Y|+n{8)6-u5RBF8|zP|eL_VC?121xz#cWI z*D6``&+iEf`RFrUUJ|R_!D1pH$|fHWK}Cu`a6_&x=t4!;Wo)mWoXj_E>-w=Pi*G)y zGcw%=n(J_;ewJIQR0?T+bDPcqdJze`&tz`7FV4y> zr6|ORNj_N$ePye~gJWb?p(@UXcLe?uH{}GpU|Z}}Gf64c=iFMElNmKqiOk>7x%o_|i%JJYY>gW$U`_^Z1A4g8>J0XK zKuQ~S)7fG}VVjN5mRIKyJ)}n%z+#4|5LPsV$pLVl5NK-`k}(bRB0?U|hG6J*Ctq!zF3OkaZ3ig(zc3Mh<(gy+{}`)eqrEuqiTSK{5kV;uDLCtM~& zsIb>NZbM4jIO3lUp0ltk2xx@63~l)CPd3#eX_&%hD7uQ_>yP%=6rSYTQtOjL;p|Yv z^K|AStulz zBDPq*+NuPp?|KOW;$kzG5(HNQ*cX}}U9}-?b0!-{*^vL;Jso1so2Zwv_ zmBuSdivqJNpyMDU7&7356;hA^G2l7Ydq?r=KbZFbLGKgvV()CkF$O|q7*U5~pi)#Q zxGUFKUE8{UDE;h;W$%%9A*)hE!y_~-5d3i9VU56qlQrrP@iK~TmX5sNiZDl$nvhhs^LJ5wb` z6kE=u2r8_xO74Q_oH+_2$1GjUa@L(I1d8%odZ4PBmZ9xUh!}kPkE&ke4JzjR>r7yF z@wAMLqu9@WBvKpW%{W>`T?=*8#2aXfx%s~7#ugeOT^bPqF%gD$J{HSgK=3v|9PJ;7 zN?x}ml|vaT7<%0HlKwUMgO}J zQ~IW0U*-z>sYqrcgmhv37N==K>6C5O<4NM zjWcZu$HrQfE#63V`yQLh6pmjBs~ANCe9KnmE0NphqXF||lC~0&9DiMov#5$|?(ozz zeaf-_h7gob<6VM^5Z`=HWEKj}mx}&_B0TRsqQ-%(OL61CHfo|wH6o%IPv%%Zp{?xa zL_mGH-q2al>=pE)n)Fq(3A549^`qfmApFJ9&}S{D3)f~3lFJY-?zy{uKj!umt+X$K zelvLW9?xkn4{djjeiRx!I4OH;_bJVF+hF*ei9ySJM{VYV^a;x^`jM1ZaTm5LiW9sG zZ2TeO;^jf6KDp!k3Z&+b{OBf!eCx)qmY*k2L9b0xJf-3mWX(9@-3twV=13cz-$0Kk zPM0B5e;K{J{)wq_{C=li!On>&>dUF8IO@w{sp?Vv8Fn(;4Oo5af)vsl+~&P&BFC{- zM>znS?&85nT`{GjF`wRQ?$HH}h~jXbUU=i5cW;g&N|Q&HeBkEddnuhnLSEuBGDD;R ze!;P>cgA{8R3cB7NPpi`+O(NL`$~iGbCH}{g0y*TG)5dx-7f-M%Ke#6^0?yu)sX%= zXt#b+b>ez#PLk9y$;flVTPIx$h+s{T-)5gJgQ@T{-YQ=t4L}yM>Wq|#({B!ql1*)2 zp&OPL#YQiea;1$}--|2U~*!qe39AB#MOb*IAUIToC)!$>f^oU}!2#uGj zC-uKh6z)wV(Xz)g8+I%x?tL$QV5wh?2p}on?$17b(AjDTw0X{!6Z^&afI1p&8@T=AW; z%11O-aHYaW;zK8^IyRvbcYsDN{cF-s&1Y>9q+yei*YYIu+1Y2)*dy6fTqhG%7<9nX zjCW_Kp8BIww^l4_7Qx`X!uvT^?SKXVb1};|NYfe;>>HR)=$Q zd{?&4XXWTW-P!q}E%7zV|C=~lqPkh_LkadDHxU}EKd}|6-0;8Qf*DZvPgd*!gV;iX zoHUb)42rw9gapC$Z9{#?+w`Js!lC3A(xo7=wjFL?V891d{2c@oUiwJ>j6FsHR3t>( zvlhmekwJ58InMEJk)Fa^&f-a2kpT%_9iJYTd9tL@O0+){Gpss<8tt~TesOdDM$mYlFJlt2_B|F31>PiV6Q1qGO2ck`x25zxhHW|_KmJ^_I&eq+`eIM#S~C1q6=Nc8 z4?gFNdhYA|e9M_ld$Q!sDVn0jWWlDEUFJ2Hj+Vr&*7Q{08VKMSknBn4V^lvuAK4Yd zXB*NrdthLE=Btrz8Obn??XkOORATd*`QyU91im!!%EjWSXP&q9mXfz2sCqf1kC8Bj zk3y~$!q8LEUzR>X_~&Nv%FaI#Xgn441X`9ZW_6>Cbp1}Yrx5>KzfA@H&rQU{H1=LQ zOoHO-*5r3~G$K78($eZpI#7Z6jN2E6KB{yqELDB$hIG)1+w-c_F1lfhOT?)wITd%!RwN&@CJd@AhK6;ik~jz%}l# zTJNmNm3M{i!UH80o5n=BaDuE(bsNGTGF`!)4!TJoi`VzV-eq~SjF;7{Nn)$+Hy!4d ze`5EBrb(W%%1P`v3x3;givG?u5V!i)o}0KFJU_4&d{r1}o>Y1n7zU-TYYY!o3>V8s zUbfl;^G_}@{&JzpV${g~5$emt)1Ecx&<2lE)6;YNv8Ju%%*|1YRVUq`41G$Z$4^Iu z6Tl>P2qd10s|>e!dW#1h>B7S7*sv_>Y0cTEv9k7VW}{2&9iW^vTBx;?lgAf@ ze%=Tisk#Gx9SlbG$R=A7;;ZBaY3CG(D+A z@*3Mu<8$(NZi^jslL-(ziv#&1%A@Cg(`Ik%X}?kjKW6cIv&|>&a?MM&Z~eUo(FTSxz(f z)+<>Oke`tQT{6+LH$0qx9sS~&+{EgKRhe|pME(Qr@b?MU$yzy;prQHMEl<5 zifu_4>3%P*Pps&H14WD|C>U?NSRd*Zqvur?jL2*()RUoCfXxr2{mAN7mSG7U8uI=b z%;&`$JIJ&C<#Fn*?A@{zu)Pg{B#^BpoqTMHAF#Rpm`TuLylQ7Tc zWsN!e)WD++zdUmN&b!YwgsW(`h*CZuhfBN=Q(%9-@0IhvRzvhh8X5TB@wl32=SCg9 zLHq6m3Nk7fj_gzKHChwe85QG`EJo-ZKm6oc?Hq28Ybws*(4yK@TdzmGE;4#P*9lRWn-lhJ`+vk2+qzVgop7S~@Q&ms5HU{b*WgxytcW}fbLCby6VN_zQ=zRpfO&$B0 z>wyPz7zd%YB1+Dj-#@>;S0JhApNWLSM$;>cVJ!y1AE?91HlhI?;I0X0jlk+LrIcXr zQ|y1@@fQU zJ$^$x>s3tflK8wBbg5WQLA=(FKDDerXor{I&dob!?f1V=CzN7m*vtsCE8SZErVzSf zzd56)&FWAU*`B>yzc$u=S6dKB@{KJSj%zKcMu<8>PGi4YXU?!i8)`8`Jos zsF;Q9z@LI1Nx8-+D{CKpI|~Xh55FclxfUv{CU(XfPMa|k5=~vdJ~W~+J&2Se`@nee zoBaD)QTm(ShLtBlb?>1lq1JcPeQG!LGF z$7EWhF4d~o#Xsl{#`OXnsk~yTo8d{-=Z`ndKOdQvb~kg;G4``>o=S@%JR4Vn3su&b zJUPb3l)ZbHQ{?hEX?Xf`r+^gs;?km>+Eri3X*J=>GKSCT!FZZtRoKB|%SVH$VETv@ z=ZD`WMj5w>7%ytBx^CUWsbEEla)L-%9aom!6O2p`e{|l}eBao#(xsFsVLdN)o}08% z{XL@i`hTuRKPgLNelN#Xt!qd>5k-Adti|B~Vii^eO}~25`aem_NP<`ID?hs<<%Xu+ zxe6eJcae+9jY1V6(V9JtTmN?p4!XFFoO7=+{gTnQ_}VwrqM~@4kXv(MrDhj${$gjt zGOMcBcw^RU$Etyh0Q3Dw_wqombk3n77e9_DHEq3U z-l;FD#>>-o!wX(Vsrh^eZ@b$!WSKJ?~YKEb7P&-a>ZpNM?!%<%C+fC^g~v zd80noqoWw>WR=5SBEGh}VDZ=|et5^4)9;6XhF|P2Wp+Pq#)z@FQ#EYk~yJ&GpKzrq|ae{5mOXPi#WO$cpT7q|Q4q dNcXFY+u?21-&CGq{}r>pruk4;tpZ^a{yzX<_u>Em literal 0 HcmV?d00001 diff --git a/examples/Prioritized_DQN/rom_files b/examples/Prioritized_DQN/rom_files new file mode 120000 index 0000000..c1c50b9 --- /dev/null +++ b/examples/Prioritized_DQN/rom_files @@ -0,0 +1 @@ +../DQN_variant/rom_files \ No newline at end of file diff --git a/examples/Prioritized_DQN/train.py b/examples/Prioritized_DQN/train.py new file mode 100644 index 0000000..2d08bf1 --- /dev/null +++ b/examples/Prioritized_DQN/train.py @@ -0,0 +1,218 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import pickle +from collections import deque +from datetime import datetime + +import gym +import numpy as np +import paddle.fluid as fluid +from tqdm import tqdm + +import parl +from atari_agent import AtariAgent +from atari_model import AtariModel +from parl.utils import logger, summary +from per_alg import PrioritizedDoubleDQN, PrioritizedDQN +from proportional_per import ProportionalPER +from utils import get_player + +MEMORY_SIZE = 1e6 +MEMORY_WARMUP_SIZE = MEMORY_SIZE +IMAGE_SIZE = (84, 84) +CONTEXT_LEN = 4 +FRAME_SKIP = 4 +UPDATE_FREQ = 4 +GAMMA = 0.99 +LEARNING_RATE = 0.00025 / 4 + + +def beta_adder(init_beta, step_size=0.0001): + beta = init_beta + step_size = step_size + + def adder(): + nonlocal beta, step_size + beta += step_size + return min(beta, 1) + + return adder + + +def process_transitions(transitions): + transitions = np.array(transitions) + batch_obs = np.stack(transitions[:, 0].copy()) + batch_act = transitions[:, 1].copy() + batch_reward = transitions[:, 2].copy() + batch_next_obs = np.expand_dims(np.stack(transitions[:, 3]), axis=1) + batch_next_obs = np.concatenate([batch_obs, batch_next_obs], + axis=1)[:, 1:, :, :].copy() + batch_terminal = transitions[:, 4].copy() + batch = (batch_obs, batch_act, batch_reward, batch_next_obs, + batch_terminal) + return batch + + +def run_episode(env, agent, per, mem=None, warmup=False, train=False): + total_reward = 0 + all_cost = [] + traj = deque(maxlen=CONTEXT_LEN) + obs = env.reset() + for _ in range(CONTEXT_LEN - 1): + traj.append(np.zeros(obs.shape)) + steps = 0 + if warmup: + decay_exploration = False + else: + decay_exploration = True + while True: + steps += 1 + traj.append(obs) + context = np.stack(traj, axis=0) + action = agent.sample(context, decay_exploration=decay_exploration) + next_obs, reward, terminal, _ = env.step(action) + transition = [obs, action, reward, next_obs, terminal] + if warmup: + mem.append(transition) + if train: + per.store(transition) + if steps % UPDATE_FREQ == 0: + beta = get_beta() + transitions, idxs, sample_weights = per.sample(beta=beta) + batch = process_transitions(transitions) + + cost, delta = agent.learn(*batch, sample_weights) + all_cost.append(cost) + per.update(idxs, delta) + + total_reward += reward + obs = next_obs + if terminal: + break + + return total_reward, steps, np.mean(all_cost) + + +def run_evaluate_episode(env, agent): + obs = env.reset() + total_reward = 0 + while True: + action = agent.predict(obs) + obs, reward, isOver, info = env.step(action) + total_reward += reward + if isOver: + break + return total_reward + + +def main(): + # Prepare environments + env = get_player( + args.rom, image_size=IMAGE_SIZE, train=True, frame_skip=FRAME_SKIP) + test_env = get_player( + args.rom, + image_size=IMAGE_SIZE, + frame_skip=FRAME_SKIP, + context_len=CONTEXT_LEN) + + # Init Prioritized Replay Memory + per = ProportionalPER(alpha=0.6, seg_num=args.batch_size, size=MEMORY_SIZE) + + # Prepare PARL agent + act_dim = env.action_space.n + model = AtariModel(act_dim) + if args.alg == 'ddqn': + algorithm = PrioritizedDoubleDQN( + model, act_dim=act_dim, gamma=GAMMA, lr=LEARNING_RATE) + elif args.alg == 'dqn': + algorithm = PrioritizedDQN( + model, act_dim=act_dim, gamma=GAMMA, lr=LEARNING_RATE) + agent = AtariAgent(algorithm, act_dim=act_dim, update_freq=UPDATE_FREQ) + + # Replay memory warmup + total_step = 0 + with tqdm(total=MEMORY_SIZE, desc='[Replay Memory Warm Up]') as pbar: + mem = [] + while total_step < MEMORY_WARMUP_SIZE: + total_reward, steps, _ = run_episode( + env, agent, per, mem=mem, warmup=True) + total_step += steps + pbar.update(steps) + per.elements.from_list(mem[:int(MEMORY_WARMUP_SIZE)]) + + env_name = args.rom.split('/')[-1].split('.')[0] + + test_flag = 0 + total_steps = 0 + pbar = tqdm(total=args.train_total_steps) + while total_steps < args.train_total_steps: + # start epoch + total_reward, steps, loss = run_episode(env, agent, per, train=True) + total_steps += steps + pbar.set_description('[train]exploration:{}'.format(agent.exploration)) + summary.add_scalar('{}/score'.format(env_name), total_reward, + total_steps) + summary.add_scalar('{}/loss'.format(env_name), loss, + total_steps) # mean of total loss + summary.add_scalar('{}/exploration'.format(env_name), + agent.exploration, total_steps) + pbar.update(steps) + + if total_steps // args.test_every_steps >= test_flag: + while total_steps // args.test_every_steps >= test_flag: + test_flag += 1 + pbar.write("testing") + test_rewards = [] + for _ in tqdm(range(3), desc='eval agent'): + eval_reward = run_evaluate_episode(test_env, agent) + test_rewards.append(eval_reward) + eval_reward = np.mean(test_rewards) + logger.info( + "eval_agent done, (steps, eval_reward): ({}, {})".format( + total_steps, eval_reward)) + summary.add_scalar('{}/eval'.format(env_name), eval_reward, + total_steps) + + pbar.close() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--rom', help='path of the rom of the atari game', required=True) + parser.add_argument( + '--batch_size', type=int, default=32, help='batch size for training') + parser.add_argument( + '--alg', + type=str, + default="ddqn", + help='dqn or ddqn, training algorithm to use.') + parser.add_argument( + '--train_total_steps', + type=int, + default=int(1e7), + help='maximum environmental steps of games') + parser.add_argument( + '--test_every_steps', + type=int, + default=100000, + help='the step interval between two consecutive evaluations') + args = parser.parse_args() + assert args.alg in ['dqn','ddqn'], \ + 'used algorithm should be dqn or ddqn (double dqn)' + get_beta = beta_adder(init_beta=0.5) + main() diff --git a/examples/Prioritized_DQN/utils.py b/examples/Prioritized_DQN/utils.py new file mode 120000 index 0000000..04c590e --- /dev/null +++ b/examples/Prioritized_DQN/utils.py @@ -0,0 +1 @@ +../DQN_variant/utils.py \ No newline at end of file diff --git a/parl/algorithms/fluid/ddqn.py b/parl/algorithms/fluid/ddqn.py index 7f59c9e..1cc45ce 100644 --- a/parl/algorithms/fluid/ddqn.py +++ b/parl/algorithms/fluid/ddqn.py @@ -70,26 +70,14 @@ class DDQN(Algorithm): pred_action_value = layers.reduce_sum( layers.elementwise_mul(action_onehot, pred_value), dim=1) - # choose acc. to behavior network + # calculate the target q value next_action_value = self.model.value(next_obs) greedy_action = layers.argmax(next_action_value, axis=-1) - - # calculate the target q value with target network - batch_size = layers.cast(layers.shape(greedy_action)[0], dtype='int64') - range_tmp = layers.range( - start=0, end=batch_size, step=1, dtype='int64') * self.act_dim - a_indices = range_tmp + greedy_action - a_indices = layers.cast(a_indices, dtype='int32') + greedy_action = layers.unsqueeze(greedy_action, axes=[1]) + greedy_action_onehot = layers.one_hot(greedy_action, self.act_dim) next_pred_value = self.target_model.value(next_obs) - next_pred_value = layers.reshape( - next_pred_value, shape=[ - -1, - ]) - max_v = layers.gather(next_pred_value, a_indices) - max_v = layers.reshape( - max_v, shape=[ - -1, - ]) + max_v = layers.reduce_sum( + greedy_action_onehot * next_pred_value, dim=1) max_v.stop_gradient = True target = reward + ( -- GitLab