未验证 提交 a5686f47 编写于 作者: V Varuna Jayasiri 提交者: GitHub

Fix data path

上级 7d1550dd
......@@ -8,7 +8,8 @@ summary: >
# GPT-NeoX Checkpoints
"""
from typing import Dict, Union, Tuple
from pathlib import Path
from typing import Dict, Union, Tuple, Optional
import torch
from torch import nn
......@@ -19,12 +20,21 @@ from labml.utils.download import download_file
# Parent url
CHECKPOINTS_URL = 'https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/'
_CHECKPOINTS_DOWNLOAD_PATH: Optional[Path] = None
# Download path
def get_checkpoints_download_path():
global _CHECKPOINTS_DOWNLOAD_PATH
if _CHECKPOINTS_DOWNLOAD_PATH is not None:
return _CHECKPOINTS_DOWNLOAD_PATH
CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
if not CHECKPOINTS_DOWNLOAD_PATH.exists():
CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
inspect(neox_checkpoint_path=CHECKPOINTS_DOWNLOAD_PATH)
_CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
if not _CHECKPOINTS_DOWNLOAD_PATH.exists():
_CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
inspect(neox_checkpoint_path=_CHECKPOINTS_DOWNLOAD_PATH)
def get_files_to_download(n_layers: int = 44):
......@@ -65,7 +75,7 @@ def download(n_layers: int = 44):
# Log
logger.log(['Downloading ', (f'{i + 1 :3d}/{len(files)}', Text.meta), ': ', (f, Text.value)])
# Download
download_file(CHECKPOINTS_URL + f, CHECKPOINTS_DOWNLOAD_PATH / f)
download_file(CHECKPOINTS_URL + f, get_checkpoints_download_path() / f)
def load_checkpoint_files(files: Tuple[str, str]):
......@@ -75,7 +85,7 @@ def load_checkpoint_files(files: Tuple[str, str]):
:param files: pair of files to load
:return: the loaded parameter tensors
"""
checkpoint_path = CHECKPOINTS_DOWNLOAD_PATH / 'global_step150000'
checkpoint_path = get_checkpoints_download_path() / 'global_step150000'
with monit.section('Load checkpoint'):
data = [torch.load(checkpoint_path / f) for f in files]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册