From 6a672c8088662fcc5f7ca81f0094ee532c8eaaf1 Mon Sep 17 00:00:00 2001
From: LI Yunxiang <39279048+Banmahhhh@users.noreply.github.com>
Date: Tue, 14 Jan 2020 11:23:26 +0800
Subject: [PATCH] add offline q learning (#193)

* add offline q learning

* Update README.md

* update

* yapf
---
 examples/offline-Q-learning/README.md        |  37 +++++
 examples/offline-Q-learning/atari.py         |   1 +
 examples/offline-Q-learning/atari_agent.py   | 137 +++++++++++++++++
 examples/offline-Q-learning/atari_model.py   |  50 ++++++
 examples/offline-Q-learning/atari_wrapper.py |   1 +
 examples/offline-Q-learning/dqn.py           | 111 +++++++++++++
 examples/offline-Q-learning/parallel_run.py  | 133 ++++++++++++++++
 examples/offline-Q-learning/replay_memory.py | 154 +++++++++++++++++++
 examples/offline-Q-learning/rom_files        |   1 +
 examples/offline-Q-learning/utils.py         |   1 +
 10 files changed, 626 insertions(+)
 create mode 100644 examples/offline-Q-learning/README.md
 create mode 120000 examples/offline-Q-learning/atari.py
 create mode 100644 examples/offline-Q-learning/atari_agent.py
 create mode 100644 examples/offline-Q-learning/atari_model.py
 create mode 120000 examples/offline-Q-learning/atari_wrapper.py
 create mode 100644 examples/offline-Q-learning/dqn.py
 create mode 100644 examples/offline-Q-learning/parallel_run.py
 create mode 100644 examples/offline-Q-learning/replay_memory.py
 create mode 120000 examples/offline-Q-learning/rom_files
 create mode 120000 examples/offline-Q-learning/utils.py

diff --git a/examples/offline-Q-learning/README.md b/examples/offline-Q-learning/README.md
new file mode 100644
index 0000000..f442f5d
--- /dev/null
+++ b/examples/offline-Q-learning/README.md
@@ -0,0 +1,37 @@
+## Parallel Training with PARL
+
+Use `parl.compile` to train the model in parallel. When training offline, or when the dataset is too large to train on a single GPU, we can use parallel computing to accelerate training.
+```python
+# Set CUDA_VISIBLE_DEVICES to select which GPUs to train on
+
+import parl
+import paddle.fluid as fluid
+
+learn_program = fluid.Program()
+with fluid.program_guard(learn_program):
+    # Define your learn program and training loss
+    pass
+
+learn_program = parl.compile(learn_program, loss=training_loss)
+# Pass the training loss to parl.compile, which distributes the model and data across the GPUs.
+```
+
+## Demonstration
+
+We provide a demonstration of offline Q-learning with parallel execution, in which we separate the procedure of collecting data from that of training the model. First, we collect data by interacting with the environment and save it to a replay memory file; then we fit and evaluate the Q-network with the collected data. Repeating these two steps gradually improves performance.
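+
+The sketch below condenses these two steps into code. It is only an illustration that mirrors `parallel_run.py` (added in this patch); in practice, run that script as shown in "How to Run".
+```python
+# Illustrative sketch condensed from parallel_run.py; constants and file names
+# follow that script.
+import numpy as np
+from parl.utils import get_gpu_count
+
+from atari_agent import AtariAgent
+from atari_model import AtariModel
+from dqn import DQN
+from replay_memory import ReplayMemory, Experience
+from utils import get_player
+
+CONTEXT_LEN, MEMORY_SIZE, BATCH_SIZE = 4, int(1e6), 64
+gpu_num = get_gpu_count()
+
+env = get_player('rom_files/pong.bin', image_size=(84, 84), train=True, frame_skip=4)
+act_dim = env.action_space.n
+rpm = ReplayMemory(MEMORY_SIZE, (84, 84), CONTEXT_LEN,
+                   load_file=True, file_path='memory.npz')
+algorithm = DQN(AtariModel(act_dim), act_dim=act_dim, gamma=0.99, lr=3e-4 * gpu_num)
+agent = AtariAgent(algorithm, act_dim=act_dim, total_step=int(1e6))
+
+# Step 1: collect transitions by interacting with the environment, then save them.
+state = env.reset()
+for _ in range(MEMORY_SIZE):
+    context = np.stack(rpm.recent_state() + [state], axis=0)
+    action = agent.sample(context)
+    next_state, reward, isOver, _ = env.step(action)
+    rpm.append(Experience(state, action, reward, isOver))
+    state = next_state
+rpm.save_memory()
+
+# Step 2: train offline. Each batch carries BATCH_SIZE * gpu_num samples so that
+# the learn program compiled by parl.compile can split it across the visible GPUs.
+all_state, action, reward, isOver = rpm.sample_batch(BATCH_SIZE * gpu_num)
+cost = agent.learn(all_state[:, :CONTEXT_LEN], action, reward,
+                   all_state[:, 1:], isOver)
+```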
+ +### Dependencies: ++ [paddlepaddle>=1.5.1](https://github.com/PaddlePaddle/Paddle) ++ [parl](https://github.com/PaddlePaddle/PARL) ++ gym ++ tqdm ++ atari-py + +### How to Run: +```shell +# Collect training data +python parallel_run.py --rom rom_files/pong.bin + +# Train the model offline with multi-GPU +python parallel_run.py --rom rom_files/pong.bin --train +``` diff --git a/examples/offline-Q-learning/atari.py b/examples/offline-Q-learning/atari.py new file mode 120000 index 0000000..11909eb --- /dev/null +++ b/examples/offline-Q-learning/atari.py @@ -0,0 +1 @@ +../DQN/atari.py \ No newline at end of file diff --git a/examples/offline-Q-learning/atari_agent.py b/examples/offline-Q-learning/atari_agent.py new file mode 100644 index 0000000..abdc8b5 --- /dev/null +++ b/examples/offline-Q-learning/atari_agent.py @@ -0,0 +1,137 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle.fluid as fluid +import parl +from parl import layers + +IMAGE_SIZE = (84, 84) +CONTEXT_LEN = 4 + + +class AtariAgent(parl.Agent): + def __init__(self, algorithm, act_dim, total_step): + super(AtariAgent, self).__init__(algorithm) + assert isinstance(act_dim, int) + self.act_dim = act_dim + self.exploration = 1.1 + self.global_step = 0 + self.update_target_steps = 10000 // 4 + + def build_program(self): + self.pred_program = fluid.Program() + self.learn_program = fluid.Program() + self.supervised_eval_program = fluid.Program() + + with fluid.program_guard(self.pred_program): + obs = layers.data( + name='obs', + shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]], + dtype='float32') + self.value = self.alg.predict(obs) + + with fluid.program_guard(self.learn_program): + obs = layers.data( + name='obs', + shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]], + dtype='float32') + action = layers.data(name='act', shape=[1], dtype='int32') + reward = layers.data(name='reward', shape=[], dtype='float32') + next_obs = layers.data( + name='next_obs', + shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]], + dtype='float32') + terminal = layers.data(name='terminal', shape=[], dtype='bool') + self.cost = self.alg.learn(obs, action, reward, next_obs, terminal) + + # use parl.compile to distribute data and model to GPUs + self.learn_program = parl.compile(self.learn_program, loss=self.cost) + + with fluid.program_guard(self.supervised_eval_program): + obs = layers.data( + name='obs', + shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]], + dtype='float32') + action = layers.data(name='act', shape=[1], dtype='int32') + reward = layers.data(name='reward', shape=[], dtype='float32') + next_obs = layers.data( + name='next_obs', + shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]], + dtype='float32') + terminal = layers.data(name='terminal', shape=[], dtype='bool') + self.supervised_cost = self.alg.supervised_eval( + obs, action, reward, next_obs, terminal) + + def sample(self, obs): + sample = np.random.random() + if sample < self.exploration: + 
act = np.random.randint(self.act_dim) + else: + if np.random.random() < 0.01: + act = np.random.randint(self.act_dim) + else: + obs = np.expand_dims(obs, axis=0) + pred_Q = self.fluid_executor.run( + self.pred_program, + feed={'obs': obs.astype('float32')}, + fetch_list=[self.value])[0] + pred_Q = np.squeeze(pred_Q, axis=0) + act = np.argmax(pred_Q) + self.exploration = max(0.1, self.exploration - 1e-6) + return act + + def predict(self, obs): + obs = np.expand_dims(obs, axis=0) + pred_Q = self.fluid_executor.run( + self.pred_program, + feed={'obs': obs.astype('float32')}, + fetch_list=[self.value])[0] + pred_Q = np.squeeze(pred_Q, axis=0) + act = np.argmax(pred_Q) + return act + + def learn(self, obs, act, reward, next_obs, terminal): + if self.global_step % self.update_target_steps == 0: + self.alg.sync_target() + self.global_step += 1 + + act = np.expand_dims(act, -1) + reward = np.clip(reward, -1, 1) + feed = { + 'obs': obs.astype('float32'), + 'act': act.astype('int32'), + 'reward': reward, + 'next_obs': next_obs.astype('float32'), + 'terminal': terminal + } + cost = self.fluid_executor.run( + self.learn_program, feed=feed, fetch_list=[self.cost])[0] + return cost + + def supervised_eval(self, obs, act, reward, next_obs, terminal): + act = np.expand_dims(act, -1) + reward = np.clip(reward, -1, 1) + feed = { + 'obs': obs.astype('float32'), + 'act': act.astype('int32'), + 'reward': reward, + 'next_obs': next_obs.astype('float32'), + 'terminal': terminal + } + cost = self.fluid_executor.run( + self.supervised_eval_program, + feed=feed, + fetch_list=[self.supervised_cost])[0] + return cost diff --git a/examples/offline-Q-learning/atari_model.py b/examples/offline-Q-learning/atari_model.py new file mode 100644 index 0000000..5a4bdbd --- /dev/null +++ b/examples/offline-Q-learning/atari_model.py @@ -0,0 +1,50 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
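+
+# The Q-network of this example: four convolutional layers, 2x2 max pooling
+# after the first three, and a final fully connected layer that outputs one
+# Q-value per action.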
+ +import paddle.fluid as fluid +import parl +from parl import layers + + +class AtariModel(parl.Model): + def __init__(self, act_dim): + self.act_dim = act_dim + + self.conv1 = layers.conv2d( + num_filters=32, filter_size=5, stride=1, padding=2, act='relu') + self.conv2 = layers.conv2d( + num_filters=32, filter_size=5, stride=1, padding=2, act='relu') + self.conv3 = layers.conv2d( + num_filters=64, filter_size=4, stride=1, padding=1, act='relu') + self.conv4 = layers.conv2d( + num_filters=64, filter_size=3, stride=1, padding=1, act='relu') + + self.fc1 = layers.fc(size=act_dim) + + def value(self, obs): + obs = obs / 255.0 + out = self.conv1(obs) + out = layers.pool2d( + input=out, pool_size=2, pool_stride=2, pool_type='max') + out = self.conv2(out) + out = layers.pool2d( + input=out, pool_size=2, pool_stride=2, pool_type='max') + out = self.conv3(out) + out = layers.pool2d( + input=out, pool_size=2, pool_stride=2, pool_type='max') + out = self.conv4(out) + out = layers.flatten(out, axis=1) + + Q = self.fc1(out) + return Q diff --git a/examples/offline-Q-learning/atari_wrapper.py b/examples/offline-Q-learning/atari_wrapper.py new file mode 120000 index 0000000..e58186a --- /dev/null +++ b/examples/offline-Q-learning/atari_wrapper.py @@ -0,0 +1 @@ +../DQN/atari_wrapper.py \ No newline at end of file diff --git a/examples/offline-Q-learning/dqn.py b/examples/offline-Q-learning/dqn.py new file mode 100644 index 0000000..feedf7d --- /dev/null +++ b/examples/offline-Q-learning/dqn.py @@ -0,0 +1,111 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +warnings.simplefilter('default') + +import copy +import paddle.fluid as fluid +from parl.core.fluid.algorithm import Algorithm +from parl.core.fluid import layers +from parl.utils.deprecation import deprecated + +__all__ = ['DQN'] + + +class DQN(Algorithm): + def __init__(self, + model, + hyperparas=None, + act_dim=None, + gamma=None, + lr=None): + """ DQN algorithm + + Args: + model (parl.Model): model defining forward network of Q function + hyperparas (dict): (deprecated) dict of hyper parameters. + act_dim (int): dimension of the action space + gamma (float): discounted factor for reward computation. + lr (float): learning rate. 
+ """ + self.model = model + self.target_model = copy.deepcopy(model) + + if hyperparas is not None: + warnings.warn( + "the `hyperparas` argument of `__init__` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.", + DeprecationWarning, + stacklevel=2) + self.act_dim = hyperparas['action_dim'] + self.gamma = hyperparas['gamma'] + else: + assert isinstance(act_dim, int) + assert isinstance(gamma, float) + assert isinstance(lr, float) + self.act_dim = act_dim + self.gamma = gamma + self.lr = lr + + def predict(self, obs): + """ use value model self.model to predict the action value + """ + return self.model.value(obs) + + def learn(self, obs, action, reward, next_obs, terminal): + """ update value model self.model with DQN algorithm + """ + + cost = self.cal_bellman_residual(obs, action, reward, next_obs, + terminal) + optimizer = fluid.optimizer.Adam(learning_rate=self.lr, epsilon=1e-3) + optimizer.minimize(cost) + return cost + + def supervised_eval(self, obs, action, reward, next_obs, terminal): + """ Calculate squared Bellman residual with test dataset. The operations are the same as learn method above, + except backpropagation. + """ + cost = self.cal_bellman_residual(obs, action, reward, next_obs, + terminal) + cost.stop_gradient = True + return cost + + def cal_bellman_residual(self, obs, action, reward, next_obs, terminal): + """ use self.model to get squared Bellman residual with fed data + """ + pred_value = self.model.value(obs) + next_pred_value = self.target_model.value(next_obs) + best_v = layers.reduce_max(next_pred_value, dim=1) + best_v.stop_gradient = True + target = reward + ( + 1.0 - layers.cast(terminal, dtype='float32')) * self.gamma * best_v + + action_onehot = layers.one_hot(action, self.act_dim) + action_onehot = layers.cast(action_onehot, dtype='float32') + pred_action_value = layers.reduce_sum( + layers.elementwise_mul(action_onehot, pred_value), dim=1) + cost = layers.square_error_cost(pred_action_value, target) + cost = layers.reduce_mean(cost) + return cost + + def sync_target(self, gpu_id=None): + """ sync weights of self.model to self.target_model + """ + if gpu_id is not None: + warnings.warn( + "the `gpu_id` argument of `sync_target` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.", + DeprecationWarning, + stacklevel=2) + self.model.sync_weights_to(self.target_model) diff --git a/examples/offline-Q-learning/parallel_run.py b/examples/offline-Q-learning/parallel_run.py new file mode 100644 index 0000000..3416f8c --- /dev/null +++ b/examples/offline-Q-learning/parallel_run.py @@ -0,0 +1,133 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import gym +import numpy as np +import os +import time +from tqdm import tqdm + +import parl +import paddle.fluid as fluid +from parl.utils import get_gpu_count +from parl.utils import tensorboard, logger + +from dqn import DQN # slight changes from parl.algorithms.DQN +from atari_agent import AtariAgent +from atari_model import AtariModel +from replay_memory import ReplayMemory, Experience +from utils import get_player + +MEMORY_SIZE = int(1e6) +MEMORY_WARMUP_SIZE = MEMORY_SIZE // 20 +IMAGE_SIZE = (84, 84) +CONTEXT_LEN = 4 +FRAME_SKIP = 4 +UPDATE_FREQ = 4 +GAMMA = 0.99 +LEARNING_RATE = 3e-4 + +gpu_num = get_gpu_count() + + +def run_train_step(agent, rpm): + for step in range(args.train_total_steps): + # use the first 80% data to train + batch_all_state, batch_action, batch_reward, batch_isOver = rpm.sample_batch( + args.batch_size * gpu_num) + batch_state = batch_all_state[:, :CONTEXT_LEN, :, :] + batch_next_state = batch_all_state[:, 1:, :, :] + cost = agent.learn(batch_state, batch_action, batch_reward, + batch_next_state, batch_isOver) + + if step % 100 == 0: + # use the last 20% data to evaluate + batch_all_state, batch_action, batch_reward, batch_isOver = rpm.sample_test_batch( + args.batch_size) + batch_state = batch_all_state[:, :CONTEXT_LEN, :, :] + batch_next_state = batch_all_state[:, 1:, :, :] + eval_cost = agent.supervised_eval(batch_state, batch_action, + batch_reward, batch_next_state, + batch_isOver) + logger.info( + "train step {}, train costs are {}, eval cost is {}.".format( + step, cost, eval_cost)) + + +def collect_exp(env, rpm, agent): + state = env.reset() + # collect data to fulfill replay memory + for i in tqdm(range(MEMORY_SIZE)): + context = rpm.recent_state() + context.append(state) + context = np.stack(context, axis=0) + action = agent.sample(context) + + next_state, reward, isOver, _ = env.step(action) + rpm.append(Experience(state, action, reward, isOver)) + state = next_state + + +def main(): + env = get_player( + args.rom, image_size=IMAGE_SIZE, train=True, frame_skip=FRAME_SKIP) + file_path = "memory.npz" + rpm = ReplayMemory( + MEMORY_SIZE, + IMAGE_SIZE, + CONTEXT_LEN, + load_file=True, # load replay memory data from file + file_path=file_path) + act_dim = env.action_space.n + + model = AtariModel(act_dim) + algorithm = DQN( + model, act_dim=act_dim, gamma=GAMMA, lr=LEARNING_RATE * gpu_num) + agent = AtariAgent( + algorithm, act_dim=act_dim, total_step=args.train_total_steps) + if os.path.isfile('./model.ckpt'): + logger.info("load model from file") + agent.restore('./model.ckpt') + + if args.train: + logger.info("train with memory data") + run_train_step(agent, rpm) + logger.info("finish training. 
Save the model.") + agent.save('./model.ckpt') + else: + logger.info("collect experience") + collect_exp(env, rpm, agent) + rpm.save_memory() + logger.info("finish collecting, save successfully") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--rom', help='path of the rom of the atari game', required=True) + parser.add_argument( + '--batch_size', type=int, default=64, help='batch size for each GPU') + parser.add_argument( + '--train', + action="store_true", + help='update the value function (default: False)') + parser.add_argument( + '--train_total_steps', + type=int, + default=int(1e6), + help='maximum environmental steps of games') + + args = parser.parse_args() + main() diff --git a/examples/offline-Q-learning/replay_memory.py b/examples/offline-Q-learning/replay_memory.py new file mode 100644 index 0000000..2296ea9 --- /dev/null +++ b/examples/offline-Q-learning/replay_memory.py @@ -0,0 +1,154 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import copy +import os +from collections import deque, namedtuple +from parl.utils import logger + +Experience = namedtuple('Experience', ['state', 'action', 'reward', 'isOver']) + + +class ReplayMemory(object): + def __init__(self, + max_size, + state_shape, + context_len, + load_file=False, + file_path=None): + self.max_size = int(max_size) + self.state_shape = state_shape + self.context_len = int(context_len) + + self.file_path = file_path + if load_file and os.path.isfile(file_path): + logger.info("load memory from file" + self.file_path) + self.load_memory() + logger.info("memory size is {}".format(self._curr_size)) + else: + self.state = np.zeros( + (self.max_size, ) + state_shape, dtype='uint8') + self.action = np.zeros((self.max_size, ), dtype='int32') + self.reward = np.zeros((self.max_size, ), dtype='float32') + self.isOver = np.zeros((self.max_size, ), dtype='bool') + + self._curr_size = 0 + self._curr_pos = 0 + self._context = deque(maxlen=context_len - 1) + + def append(self, exp): + """append a new experience into replay memory + """ + if self._curr_size < self.max_size: + self._assign(self._curr_pos, exp) + self._curr_size += 1 + else: + self._assign(self._curr_pos, exp) + self._curr_pos = (self._curr_pos + 1) % self.max_size + if exp.isOver: + self._context.clear() + else: + self._context.append(exp) + + def recent_state(self): + """ maintain recent state for training""" + lst = list(self._context) + states = [np.zeros(self.state_shape, dtype='uint8')] * \ + (self._context.maxlen - len(lst)) + states.extend([k.state for k in lst]) + return states + + def sample(self, idx): + """ return state, action, reward, isOver, + note that some frames in state may be generated from last episode, + they should be removed from state + """ + state = np.zeros( + (self.context_len + 1, ) + self.state_shape, dtype=np.uint8) + state_idx = np.arange(idx, + idx + self.context_len + 1) % self._curr_size + + # confirm that 
no frame was generated from last episode + has_last_episode = False + for k in range(self.context_len - 2, -1, -1): + to_check_idx = state_idx[k] + if self.isOver[to_check_idx]: + has_last_episode = True + state_idx = state_idx[k + 1:] + state[k + 1:] = self.state[state_idx] + break + + if not has_last_episode: + state = self.state[state_idx] + + real_idx = (idx + self.context_len - 1) % self._curr_size + action = self.action[real_idx] + reward = self.reward[real_idx] + isOver = self.isOver[real_idx] + return state, reward, action, isOver + + def __len__(self): + return self._curr_size + + def size(self): + return self._curr_size + + def _assign(self, pos, exp): + self.state[pos] = exp.state + self.reward[pos] = exp.reward + self.action[pos] = exp.action + self.isOver[pos] = exp.isOver + + def sample_batch(self, batch_size): + """sample a batch from replay memory for training + """ + batch_idx = np.random.randint( + int(self.max_size * 0.8) - self.context_len - 1, size=batch_size) + # batch_idx = (self._curr_pos + batch_idx) % self._curr_size + batch_exp = [self.sample(i) for i in batch_idx] + return self._process_batch(batch_exp) + + def sample_test_batch(self, batch_size): + batch_idx = np.random.randint( + int(self.max_size * 0.2) - self.context_len - 1, + size=batch_size) + int(self.max_size * 0.8) + # batch_idx = (self._curr_pos + batch_idx) % self._curr_size + batch_exp = [self.sample(i) for i in batch_idx] + return self._process_batch(batch_exp) + + def _process_batch(self, batch_exp): + state = np.asarray([e[0] for e in batch_exp], dtype='uint8') + reward = np.asarray([e[1] for e in batch_exp], dtype='float32') + action = np.asarray([e[2] for e in batch_exp], dtype='int8') + isOver = np.asarray([e[3] for e in batch_exp], dtype='bool') + return [state, action, reward, isOver] + + def save_memory(self): + save_data = [ + self.state, self.reward, self.action, self.isOver, self._curr_size, + self._curr_pos, self._context + ] + np.savez(self.file_path, *save_data) + + def load_memory(self): + container = np.load(self.file_path, allow_pickle=True) + [ + self.state, self.reward, self.action, self.isOver, self._curr_size, + self._curr_pos, self._context + ] = [container[key] for key in container] + self._curr_size = self._curr_size.astype(int) + self._curr_pos = self._curr_pos.astype(int) + self._context = deque([Experience(*row) for row in self._context], + maxlen=self.context_len - 1) diff --git a/examples/offline-Q-learning/rom_files b/examples/offline-Q-learning/rom_files new file mode 120000 index 0000000..966a894 --- /dev/null +++ b/examples/offline-Q-learning/rom_files @@ -0,0 +1 @@ +../DQN/rom_files/ \ No newline at end of file diff --git a/examples/offline-Q-learning/utils.py b/examples/offline-Q-learning/utils.py new file mode 120000 index 0000000..721338d --- /dev/null +++ b/examples/offline-Q-learning/utils.py @@ -0,0 +1 @@ +../DQN/utils.py \ No newline at end of file -- GitLab