From a5686f4709a5040a94946fb6167c9dfa636e70b3 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Tue, 13 Sep 2022 17:05:52 +0530
Subject: [PATCH] Fix data path

---
 labml_nn/neox/checkpoint.py | 24 +++++++++++++++++-------
 1 file changed, 17 insertions(+), 7 deletions(-)

diff --git a/labml_nn/neox/checkpoint.py b/labml_nn/neox/checkpoint.py
index 3cc36703..28915359 100644
--- a/labml_nn/neox/checkpoint.py
+++ b/labml_nn/neox/checkpoint.py
@@ -8,7 +8,8 @@ summary: >
 # GPT-NeoX Checkpoints
 """
 
-from typing import Dict, Union, Tuple
+from pathlib import Path
+from typing import Dict, Union, Tuple, Optional
 
 import torch
 from torch import nn
@@ -19,12 +20,21 @@ from labml.utils.download import download_file
 # Parent url
 CHECKPOINTS_URL = 'https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/'
 
+
+_CHECKPOINTS_DOWNLOAD_PATH: Optional[Path] = None
+
+
 # Download path
+def get_checkpoints_download_path():
+    global _CHECKPOINTS_DOWNLOAD_PATH
+
+    if _CHECKPOINTS_DOWNLOAD_PATH is not None:
+        return _CHECKPOINTS_DOWNLOAD_PATH
 
-CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
-if not CHECKPOINTS_DOWNLOAD_PATH.exists():
-    CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
-inspect(neox_checkpoint_path=CHECKPOINTS_DOWNLOAD_PATH)
+    _CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
+    if not _CHECKPOINTS_DOWNLOAD_PATH.exists():
+        _CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
+    inspect(neox_checkpoint_path=_CHECKPOINTS_DOWNLOAD_PATH)
 
 
 def get_files_to_download(n_layers: int = 44):
@@ -65,7 +75,7 @@ def download(n_layers: int = 44):
         # Log
         logger.log(['Downloading ', (f'{i + 1 :3d}/{len(files)}', Text.meta), ': ', (f, Text.value)])
         # Download
-        download_file(CHECKPOINTS_URL + f, CHECKPOINTS_DOWNLOAD_PATH / f)
+        download_file(CHECKPOINTS_URL + f, get_checkpoints_download_path() / f)
 
 
 def load_checkpoint_files(files: Tuple[str, str]):
@@ -75,7 +85,7 @@ def load_checkpoint_files(files: Tuple[str, str]):
     :param files: pair of files to load
     :return: the loaded parameter tensors
     """
-    checkpoint_path = CHECKPOINTS_DOWNLOAD_PATH / 'global_step150000'
+    checkpoint_path = get_checkpoints_download_path() / 'global_step150000'
     with monit.section('Load checkpoint'):
         data = [torch.load(checkpoint_path / f) for f in files]
 
-- 
GitLab
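
For readers outside the labml_nn codebase: the patch replaces an import-time path computation with a getter that resolves and caches the download path on first call, so merely importing the module no longer touches lab.get_data_path(); only download() and load_checkpoint_files() do. Below is a minimal standalone sketch of that lazy-initialization pattern under stated assumptions; the names get_download_path, _DOWNLOAD_PATH, and the ./data base directory are hypothetical stand-ins for the patch's get_checkpoints_download_path, _CHECKPOINTS_DOWNLOAD_PATH, and lab.get_data_path(), and this is an illustration of the pattern, not the library's code.

    from pathlib import Path
    from typing import Optional

    # Module-level cache, filled in on first use.
    _DOWNLOAD_PATH: Optional[Path] = None


    def get_download_path(base: Path) -> Path:
        """Resolve the checkpoint directory once and reuse the cached result.

        Prefers <base>/neox_fast/slim_weights; falls back to <base>/neox/slim_weights.
        """
        global _DOWNLOAD_PATH

        # Subsequent calls return the cached path without touching the filesystem.
        if _DOWNLOAD_PATH is not None:
            return _DOWNLOAD_PATH

        # First call: pick the fast location if it exists, otherwise the default one.
        _DOWNLOAD_PATH = base / 'neox_fast' / 'slim_weights'
        if not _DOWNLOAD_PATH.exists():
            _DOWNLOAD_PATH = base / 'neox' / 'slim_weights'

        return _DOWNLOAD_PATH


    if __name__ == '__main__':
        # Hypothetical base directory standing in for lab.get_data_path().
        print(get_download_path(Path('./data')))

The module-level Optional[Path] plus an early return is the same caching idiom the patch uses; callers such as download_file(CHECKPOINTS_URL + f, get_checkpoints_download_path() / f) then always go through the getter instead of a module-level constant.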