未验证 提交 3007a7f3 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Merge pull request #33 from wangxiao1021/downloader

fix downloader
*.pyc
__pycache__
pretrain_model
pretrain
output_model
build
dist
......
......@@ -110,28 +110,22 @@ paddlepalm框架的运行原理图如图所示
### 预训练模型
#### 下载
我们提供了BERT、ERNIE等主干网络的相关预训练模型。为了加速模型收敛,获得更佳的测试集表现,我们强烈建议用户在多任务学习时尽量在预训练模型的基础上进行(而不是从参数随机初始化开始)。用户可通过运行`script/download_pretrain_models <model_name>`下载需要的预训练模型,例如,下载预训练BERT模型(uncased large)的命令如下
我们提供了BERT、ERNIE等主干网络的相关预训练模型。为了加速模型收敛,获得更佳的测试集表现,我们强烈建议用户在多任务学习时尽量在预训练模型的基础上进行(而不是从参数随机初始化开始)。用户可以查看可供下载的预训练模型:
```shell
bash script/download_pretrain_backbone.sh bert
python download_models.py ls pretrain
```
脚本会自动在**当前文件夹**中创建一个pretrain_model目录(注:运行DEMO时,需保证pretrain_model文件夹在PALM项目目录下),并在其中创建bert子目录,里面存放预训练模型(`params`文件夹内)、相关的网络参数(`bert_config.json`)和字典(`vocab.txt`)。除了BERT模型,脚本还提供了ERNIE预训练模型(uncased large)的一键下载,将`<model_name>`改成`ernie`即可。全部可用的预训练模型列表见[paddlenlp/lark](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/PaddleLARK)
#### 转换
注意,预训练模型不能直接被框架使用。我们提供了转换脚本可以将其转换成paddlepalm的模型格式。如下,通过运行`script/convert_params.sh`可将预训练模型bert转换成框架的模型格式。
用户可通过运行`python download_models.py download <model_name>`下载需要的预训练模型,例如,下载预训练BERT模型(uncased large)的命令如下:
```shell
bash script/convert_params.sh pretrain_model/bert/params
python download_models.py download bert-en-uncased-large
```
注意,以下恢复操作在执行后述DEMO流程中**无需执行**
若用户需将转换成的paddlepalm模型恢复为原始的预训练模型,可以运行`script/recover_params.sh`进行恢复。
此外,用户也可通过运行`python download_models.py download all`下载已提供的所有预训练模型。
脚本会自动在**当前文件夹**中创建一个pretrain目录(注:运行DEMO时,需保证pretrain文件夹在PALM项目目录下),并在其中创建bert子目录,里面存放预训练模型(`params`文件夹内)、相关的网络参数(`bert_config.json`)和字典(`vocab.txt`)。除了BERT模型,脚本还提供了ERNIE预训练模型(uncased large)的一键下载,将`<model_name>`改成`ernie-en-uncased-large`即可。全部可用的预训练模型列表见[paddlenlp/lark](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/PaddleLARK)
```shell
bash script/recover_params.sh pretrain_model/bert/params
```
## 三个DEMO入门PaddlePALM
......@@ -169,7 +163,7 @@ max_seq_len: 512
max_query_len: 64
doc_stride: 128 # 在MRQA数据集中,存在较长的文档,因此我们这里使用滑动窗口处理样本,滑动步长设置为128
do_lower_case: True
vocab_path: "../../pretrain_model/bert/vocab.txt"
vocab_path: "../../pretrain/bert-en-uncased-large/vocab.txt"
```
更详细的任务实例配置方法(为任务实例选择合适的reader、paradigm和backbone)可参考[这里](#readerbackbone与paradigm的选择)
......@@ -184,7 +178,7 @@ task_instance: "mrqa"
save_path: "output_model/firstrun"
backbone: "bert"
backbone_config_path: "../../pretrain_model/bert/bert_config.json"
backbone_config_path: "../../pretrain/bert-en-uncased-large/bert_config.json"
optimizer: "adam"
learning_rate: 3e-5
......@@ -210,7 +204,7 @@ import paddlepalm as palm
if __name__ == '__main__':
controller = palm.Controller('config.yaml')
controller.load_pretrain('../../pretrain_model/bert/params')
controller.load_pretrain('../../pretrain/bert-en-uncased-large/params')
controller.train()
```
......@@ -277,9 +271,9 @@ target_tag: 1,0,0
save_path: "output_model/secondrun"
backbone: "ernie"
backbone_config_path: "../../pretrain_model/ernie/ernie_config.json"
backbone_config_path: "../../pretrain/ernie-en-uncased-large/ernie_config.json"
vocab_path: "../../pretrain_model/ernie/vocab.txt"
vocab_path: "../../pretrain/ernie-en-uncased-large/vocab.txt"
do_lower_case: True
max_seq_len: 512 # 写入全局配置文件的参数会被自动广播到各个任务实例
......@@ -314,7 +308,7 @@ import paddlepalm as palm
if __name__ == '__main__':
controller = palm.Controller('config.yaml', task_dir='tasks')
controller.load_pretrain('../../pretrain_model/ernie/params')
controller.load_pretrain('../../pretrain/ernie-en-uncased-large/params')
controller.train()
```
......@@ -406,9 +400,9 @@ task_reuse_tag: 0, 0, 1, 1, 0, 2
save_path: "output_model/secondrun"
backbone: "ernie"
backbone_config_path: "../../pretrain_model/ernie/ernie_config.json"
backbone_config_path: "../../pretrain/ernie-en-uncased-large/ernie_config.json"
vocab_path: "../../pretrain_model/ernie/vocab.txt"
vocab_path: "../../pretrain/ernie-en-uncased-large/vocab.txt"
do_lower_case: True
max_seq_len: 512 # 写入全局配置文件的参数会被自动广播到各个任务实例
......@@ -428,7 +422,7 @@ import paddlepalm as palm
if __name__ == '__main__':
controller = palm.Controller('config.yaml', task_dir='tasks')
controller.load_pretrain('../../pretrain_model/ernie/params')
controller.load_pretrain('../../pretrain/ernie-en-uncased-large/params')
controller.train()
```
......
......@@ -2,8 +2,8 @@ task_instance: "mrqa"
save_path: "output_model/firstrun"
backbone: "bert"
backbone_config_path: "../../pretrain_model/bert/bert_config.json"
backbone: "bert-en-uncased-large"
backbone_config_path: "../../pretrain/bert-en-uncased-large/bert_config.json"
batch_size: 4
num_epochs: 2
......
......@@ -2,7 +2,7 @@ train_file: data/mrqa/train.json
reader: mrc
paradigm: mrc
vocab_path: "../../pretrain_model/bert/vocab.txt"
vocab_path: "../../pretrain/bert-en-uncased-large/vocab.txt"
do_lower_case: True
max_seq_len: 512
doc_stride: 128
......
......@@ -2,6 +2,6 @@ import paddlepalm as palm
if __name__ == '__main__':
controller = palm.Controller('config.yaml')
controller.load_pretrain('../../pretrain_model/bert/params')
controller.load_pretrain('../../pretrain/bert-en-uncased-large/params')
controller.train()
......@@ -4,15 +4,15 @@ mix_ratio: 1.0, 0.5, 0.5
save_path: "output_model/secondrun"
backbone: "ernie"
backbone_config_path: "../../pretrain_model/ernie/ernie_config.json"
backbone: "ernie-en-uncased-large"
backbone_config_path: "../../pretrain/ernie-en-uncased-large/ernie_config.json"
vocab_path: "../../pretrain_model/ernie/vocab.txt"
vocab_path: "../../pretrain/ernie-en-uncased-large/vocab.txt"
do_lower_case: True
max_seq_len: 512
batch_size: 4
num_epochs: 2
num_epochs: 0.1
optimizer: "adam"
learning_rate: 3e-5
warmup_proportion: 0.1
......
......@@ -2,7 +2,7 @@ import paddlepalm as palm
if __name__ == '__main__':
controller = palm.Controller('config.yaml', task_dir='tasks')
controller.load_pretrain('../../pretrain_model/ernie/params')
controller.load_pretrain('../../pretrain/ernie-en-uncased-large/params')
controller.train()
controller = palm.Controller(config='config.yaml', task_dir='tasks', for_train=False)
......
......@@ -4,10 +4,10 @@ task_reuse_tag: 0,0,1,1,0,2
save_path: "output_model/thirdrun"
backbone: "ernie"
backbone_config_path: "../../pretrain_model/ernie/ernie_config.json"
backbone: "ernie-en-uncased-large"
backbone_config_path: "../../pretrain/ernie-en-uncased-large/ernie_config.json"
vocab_path: "../../pretrain_model/ernie/vocab.txt"
vocab_path: "../../pretrain/ernie-en-uncased-large/vocab.txt"
do_lower_case: True
max_seq_len: 512
......
......@@ -2,6 +2,6 @@ import paddlepalm as palm
if __name__ == '__main__':
controller = palm.Controller('config.yaml', task_dir='tasks')
controller.load_pretrain('../../pretrain_model/ernie/params')
controller.load_pretrain('../../pretrain/ernie-en-uncased-large/params')
controller.train()
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddlepalm as palm
import sys
if(sys.argv[1] == 'ls'):
palm.downloader.ls(sys.argv[2])
if(sys.argv[1] == 'download'):
palm.downloader.download('pretrain', sys.argv[2])
import downloader
from mtl_controller import Controller
import sys
from paddlepalm.mtl_controller import Controller
sys.path.append('paddlepalm')
del interface
del task_instance
del default_settings
del utils
del mtl_controller
\ No newline at end of file
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import tarfile
import shutil
try:
from urllib.request import urlopen # Python 3
except ImportError:
from urllib2 import urlopen # Python 2
import ssl
__all__ = ["download", "ls"]
# for https
ssl._create_default_https_context = ssl._create_unverified_context
_items = {
'pretrain': {'ernie-en-uncased-large': 'https://ernie.bj.bcebos.com/ERNIE_Large_en_stable-2.0.0.tar.gz',
'bert-en-uncased-large': 'https://bert-models.bj.bcebos.com/uncased_L-24_H-1024_A-16.tar.gz',
'utils': None},
'reader': {'utils': None},
'backbone': {'utils': None},
'tasktype': {'utils': None},
}
def lll():
pass
def _download(item, scope, path, silent=False):
data_url = _items[item][scope]
if data_url == None:
return
if not silent:
print('Downloading {}: {} from {}...'.format(item, scope, data_url))
data_dir = path + '/' + item + '/' + scope
if not os.path.exists(data_dir):
os.makedirs(os.path.join(data_dir))
data_name = data_url.split('/')[-1]
filename = data_dir + '/' + data_name
# print process
def _chunk_report(bytes_so_far, total_size):
percent = float(bytes_so_far) / float(total_size)
if percent > 1:
percent = 1
if not silent:
print('\r>> Downloading... {:.1%}'.format(percent), end = "")
# copy to local
def _chunk_read(response, url, chunk_size = 16 * 1024, report_hook = None):
total_size = response.info().getheader('Content-Length').strip()
total_size = int(total_size)
bytes_so_far = 0
with open("%s" % filename, "wb") as f:
while 1:
chunk = response.read(chunk_size)
f.write(chunk)
f.flush()
bytes_so_far += len(chunk)
if not chunk:
break
if report_hook:
report_hook(bytes_so_far, total_size)
return bytes_so_far
response = urlopen(data_url)
_chunk_read(response, data_url, report_hook=_chunk_report)
if not silent:
print(' done!')
if item == 'pretrain':
if not silent:
print ('Extracting {}...'.format(data_name), end=" ")
if os.path.exists(filename):
tar = tarfile.open(filename, 'r')
tar.extractall(path = data_dir)
tar.close()
os.remove(filename)
if scope == 'bert-en-uncased-large':
source_path = data_dir + '/' + data_name.split('.')[0]
fileList = os.listdir(source_path)
for file in fileList:
filePath = os.path.join(source_path, file)
shutil.move(filePath, data_dir)
os.removedirs(source_path)
if not silent:
print ('done!')
if not silent:
print ('Converting params...', end=" ")
_convert(data_dir, silent)
if not silent:
print ('done!')
def _convert(path, silent=False):
if os.path.isfile(path + '/params/__palminfo__'):
if not silent:
print ('already converted.')
else:
if os.path.exists(path + '/params/'):
os.rename(path + '/params/', path + '/params1/')
os.mkdir(path + '/params/')
tar_model = tarfile.open(path + '/params/' + '__palmmodel__', 'w:gz')
tar_info = open(path + '/params/'+ '__palminfo__', 'w')
for root, dirs, files in os.walk(path + '/params1/'):
for file in files:
src_file = os.path.join(root, file)
tar_model.add(src_file, '__paddlepalm_' + file)
tar_info.write('__paddlepalm_' + file)
os.remove(src_file)
tar_model.close()
tar_info.close()
os.removedirs(path + '/params1/')
def download(item, scope='all', path='.'):
item = item.lower()
scope = scope.lower()
assert item in _items, '{} is not found. Support list: {}'.format(item, list(_items.keys()))
if _items[item]['utils'] is not None:
_download(item, 'utils', path, silent=True)
if scope != 'all':
assert scope in _items[item], '{} is not found. Support items: {}'.format(item, list(_items.keys()))
_download(item, scope, path)
else:
for s in _items[item].keys():
_download(item, s, path)
def _ls(item, scope, l = 10):
if scope != 'all':
assert scope in _items[item], '{} is not found. Support scopes: {}'.format(scope, list(_items[item].keys()))
print ('{} ==> {}'.format(item, scope))
else:
for s in _items[item].keys():
if s == 'utils':
continue
print ('{} ==> {}'.format(item.ljust(l), s))
def ls(item='all', scope='all'):
if scope == 'utils':
return
print ('Download list:')
if item != 'all':
assert item in _items, '{} is not found. Support scopes: {}'.format(item, list(_items.keys()))
_ls(item, scope)
else:
l = max(map(len, _items.keys()))
for i in _items.keys():
_ls(i, scope, l)
from __future__ import print_function
import os
items = {
'pretrain': {'ernie-en-uncased-large': 'http://xxxxx',
'xxx': 'xxx',
'utils': None}
'reader': {'cls': 'xxx',
'xxx': 'xxx',
'utils': 'xxx'}
'backbone': {xxx}
'tasktype': {xxx}
}
def download(item, scope='all', path='.'):
item = item.lower()
scope = scope.lower()
assert item in items, '{} is not found. Support list: {}'.format(item, list(items.keys()))
if not os.path.exists(path, item):
os.makedirs(os.path.join(path, item))
def _download(item, scope, silent=False):
if not silent:
print('downloading {}: {} from {}...'.format(item, scope, items[item][scope]), end='')
urllib.downloadxxx(items[item][scope])
if not silent:
print('done!')
if items['utils'] is not None:
_download(item, 'utils', silent=True)
if scope != 'all':
assert scope in items[item], '{} is not found. Support scopes: {}'.format(item, list(items[item].keys()))
_download(item, scope)
else:
for s in items[item].keys():
_download(item, s)
def ls(item=None, scope='all'):
pass
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from _downloader import *
from __future__ import print_function
import os
import tarfile
import shutil
try:
from urllib.request import urlopen # Python 3
except ImportError:
from urllib2 import urlopen # Python 2
# for https
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
_items = {
'pretrain': {'ernie-en-uncased-large': 'https://ernie.bj.bcebos.com/ERNIE_Large_en_stable-2.0.0.tar.gz',
'bert-en-uncased-large': 'https://bert-models.bj.bcebos.com/uncased_L-24_H-1024_A-16.tar.gz',
'utils': None},
'reader': {'utils': None},
'backbone': {'utils': None},
'tasktype': {'utils': None},
}
def _download(item, scope, path, silent=False):
if not silent:
print('Downloading {}: {} from {}...'.format(item, scope, _items[item][scope]))
data_url = _items[item][scope]
data_dir = path + '/' + item + '/' + scope
if not os.path.exists(data_dir):
os.makedirs(os.path.join(data_dir))
data_name = data_url.split('/')[-1]
filename = data_dir + '/' + data_name
# print process
def _chunk_report(bytes_so_far, total_size):
percent = float(bytes_so_far) / float(total_size)
if percent > 1:
percent = 1
if not silent:
print('\r>> Downloading... {:.1%}'.format(percent), end = "")
# copy to local
def _chunk_read(response, url, chunk_size = 16 * 1024, report_hook = None):
total_size = response.info().getheader('Content-Length').strip()
total_size = int(total_size)
bytes_so_far = 0
with open("%s" % filename, "wb") as f:
while 1:
chunk = response.read(chunk_size)
f.write(chunk)
f.flush()
bytes_so_far += len(chunk)
if not chunk:
break
if report_hook:
report_hook(bytes_so_far, total_size)
return bytes_so_far
response = urlopen(data_url)
_chunk_read(response, data_url, report_hook=_chunk_report)
if not silent:
print(' done!')
if item == 'pretrain':
if not silent:
print ('Extracting {}...'.format(data_name), end=" ")
if os.path.exists(filename):
tar = tarfile.open(filename, 'r')
tar.extractall(path = data_dir)
tar.close()
os.remove(filename)
if scope == 'bert-en-uncased-large':
source_path = data_dir + '/' + data_name.split('.')[0]
fileList = os.listdir(source_path)
for file in fileList:
filePath = os.path.join(source_path, file)
shutil.move(filePath, data_dir)
os.removedirs(source_path)
if not silent:
print ('done!')
if not silent:
print ('Converting params...', end=" ")
_convert(data_dir, silent)
if not silent:
print ('done!')
def _convert(path, silent=False):
if os.path.isfile(path + '/params/__palminfo__'):
if not silent:
print ('already converted.')
else:
if os.path.exists(path + '/params/'):
os.rename(path + '/params/', path + '/params1/')
os.mkdir(path + '/params/')
tar_model = tarfile.open(path + '/params/' + '__palmmodel__', 'w:gz')
tar_info = open(path + '/params/'+ '__palminfo__', 'w')
for root, dirs, files in os.walk(path + '/params1/'):
for file in files:
src_file = os.path.join(root, file)
tar_model.add(src_file, '__paddlepalm_' + file)
tar_info.write('__paddlepalm_' + file)
os.remove(src_file)
tar_model.close()
tar_info.close()
os.removedirs(path + '/params1/')
def download(item, scope='all', path='.'):
item = item.lower()
scope = scope.lower()
assert item in _items, '{} is not found. Support list: {}'.format(item, list(_items.keys()))
if _items[item]['utils'] is not None:
_download(item, 'utils', path, silent=True)
if scope != 'all':
assert scope in _items[item], '{} is not found. Support scopes: {}'.format(item, list(_items[item].keys()))
_download(item, scope, path)
else:
for s in _items[item].keys():
_download(item, s, path)
def ls(item=None, scope='all'):
pass
......@@ -33,7 +33,7 @@ from paddlepalm.utils.print_helper import print_dict
from paddlepalm.utils.reader_helper import create_net_inputs, create_iterator_fn, create_joint_iterator_fn, merge_input_attrs
from paddlepalm.default_settings import *
from paddlepalm.task_instance import TaskInstance, check_instances
from task_instance import TaskInstance, check_instances
DEBUG=False
VERBOSE=0
......@@ -234,7 +234,7 @@ class Controller(object):
bb_conf = _merge_conf(mtl_conf, bb_conf)
else:
bb_conf = mtl_conf
print_dict(bb_conf, title='backbone configuration'.format(instname))
print_dict(bb_conf, title = 'backbone configuration'.format(instname))
bb_name = mtl_conf['backbone']
bb_mod = importlib.import_module(BACKBONE_DIR + '.' + bb_name)
......@@ -522,15 +522,15 @@ class Controller(object):
inst.reader['pred'] = pred_reader
return pred_prog
def load_pretrain(self, pretrain_model_path=None):
def load_pretrain(self, pretrain_path=None):
# load pretrain model (or ckpt)
if pretrain_model_path is None:
assert 'pretrain_model_path' in self.main_conf, "pretrain_model_path NOT set."
pretrain_model_path = self.main_conf['pretrain_model_path']
if pretrain_path is None:
assert 'pretrain_path' in self.main_conf, "pretrain_path NOT set."
pretrain_path = self.main_conf['pretrain_path']
init_pretraining_params(
self.exe,
pretrain_model_path,
pretrain_path,
main_program=fluid.default_startup_program())
......@@ -673,6 +673,7 @@ if __name__ == '__main__':
__all__ = ["Controller"]
......@@ -55,7 +55,7 @@ def init_pretraining_params(exe,
print("Loading pretraining parameters from {}...".format(
pretraining_params_path))
with tarfile.open(os.path.join(pretraining_params_path, '__palmmodel__'), 'r:') as f:
with tarfile.open(os.path.join(pretraining_params_path, '__palmmodel__'), 'r:gz') as f:
f.extractall(os.path.join(pretraining_params_path, '.temp'))
log_path = os.path.join(pretraining_params_path, '__palmmodel__')
......
#!/bin/sh
if [[ $# != 1 ]]; then
echo "usage: bash convert_params.sh <params_dir>"
exit 1
fi
if [[ -f $1/__palminfo__ ]]; then
echo "already converted."
exit 0
fi
echo "converting..."
if [[ -d $1/params ]]; then
cd $1/params
else
cd $1
fi
mkdir .palm.backup
for file in $(ls *)
do cp $file .palm.backup; mv $file "__paddlepalm_"$file
done
tar -cf __rawmodel__ .palm.backup/*
rm .palm.backup/*
mv __rawmodel__ .palm.backup
# find . ! -name '__rawmodel__' -exec rm {} +
tar -cf __palmmodel__ __paddlepalm_*
touch __palminfo__
ls __paddlepalm_* > __palminfo__
rm __paddlepalm_*
cd - >/dev/null
echo "done!"
#!/bin/bash
set -e
if [[ $# != 1 ]]; then
echo "Usage: bash download_pretrain.sh <bert|ernie>"
exit 1
fi
if [[ $1 == 'bert' ]]; then
name="bert"
link="https://bert-models.bj.bcebos.com/uncased_L-24_H-1024_A-16.tar.gz"
packname="uncased_L-24_H-1024_A-16.tar.gz"
dirname="uncased_L-24_H-1024_A-16"
elif [[ $1 == 'ernie' ]]; then
name="ernie"
link="https://ernie.bj.bcebos.com/ERNIE_Large_en_stable-2.0.0.tar.gz"
packname="ERNIE_Large_en_stable-2.0.0.tar.gz"
else
echo "$1 is currently not supported."
exit 1
fi
if [[ ! -d pretrain_model ]]; then
mkdir pretrain_model
fi
cd pretrain_model
mkdir $name
cd $name
echo "downloading ${name}..."
wget --no-check-certificate $link
echo "decompressing..."
tar -zxf $packname
rm -rf $packname
if [[ $dirname != "" ]]; then
mv $dirname/* .
rm -rf $dirname
fi
cd ../..
#!/bin/sh
if [[ $# != 1 ]]; then
echo "usage: bash recover_params.sh <params_dir>"
exit 1
fi
if [[ ! -d $1 ]]; then
echo "$1 not found."
exit 1
fi
if [[ ! -f $1/__palmmodel__ ]]; then
echo "paddlepalm model not found."
exit 1
fi
echo "recovering..."
if [[ -d $1/params ]]; then
cd $1/params
else
cd $1
fi
rm __palm*
mv .palm.backup/__rawmodel__ .
rm -rf .palm.backup
tar -xf __rawmodel__
mv .palm.backup/* .
rm __rawmodel__
rm -rf .palm.backup
cd - >/dev/null
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册