Commit bf089579 authored by xixiaoyao

fix load pretrain

Parent af31077d
......@@ -62,7 +62,7 @@ if __name__ == '__main__':
use_ema=True, ema_decay=0.999)
trainer.random_init_params()
- trainer.load_pretrain('../../pretrain_model/ernie/params')
+ trainer.load_pretrain('pretrain/ernie/params')
# trainer.train_one_step()
# trainer.train_one_epoch()
......
from paddle import fluid
import os
import multiprocessing
gpu_dev_count = int(fluid.core.get_cuda_device_count())
cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
from reader import yield_pieces, data_feeder

# The lines below come from the reader module imported above.
from . import gpu_dev_count, cpu_dev_count
import Queue  # this module is named `queue` in Python 3
from threading import Thread

dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count
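
# Hedged note (illustration only): dev_count is the number of replicas one
# global batch is split across. With no visible GPU it falls back to the
# CPU_NUM environment variable, e.g. (script name hypothetical)
#   CPU_NUM=4 python run_demo.py    # -> dev_count == 4 on a CPU-only machine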
def yield_pieces(data, distribute_strategy, batch_size):
    """Split one global batch into per-device pieces.

    Args:
        data: a dict or list holding the fields of one global batch.
        distribute_strategy: per-field strategy with the same structure as
            `data`; supports 's'/'split', 'c'/'copy' and 'u'/'unstack'.
        batch_size: size of the global batch; must be a multiple of dev_count.
    """
    assert batch_size % dev_count == 0, "batch_size must be an integer multiple of dev_count."
    print('data in yield pieces')
    print(len(data))

    assert type(data) == type(distribute_strategy), [type(data), type(distribute_strategy)]
    assert len(data) == len(distribute_strategy), [len(data), len(distribute_strategy)]
    if isinstance(data, dict):
        keys = list(data.keys())
        data_list = [data[i] for i in keys]
        ds_list = [distribute_strategy[i] for i in keys]
    else:
        assert isinstance(data, list), "the input data must be a list or dict containing multiple tensors."
        data_list = data
        ds_list = distribute_strategy

    # walk through the global batch in dev_count strides, yielding one piece per device
    stride = batch_size // dev_count
    p = stride
    # while p < len(data_list) + stride:
    while p <= batch_size:
        temp = []
        for d, s in zip(data_list, ds_list):
            s = s.strip().lower()
            if s == 's' or s == 'split':
                if p - stride >= len(d):
                    print('WARNING: no more examples to feed empty devices')
                    temp = []
                    return
                temp.append(d[p - stride:p])
            elif s == 'u' or s == 'unstack':
                assert len(d) <= dev_count, 'Tensor size on dim 0 must be less than or equal to dev_count when unstack is applied.'
                if p // stride > len(d):
                    print('WARNING: no more examples to feed empty devices')
                    return
                temp.append(d[p // stride - 1])
            elif s == 'c' or s == 'copy':
                temp.append(d)
            else:
                raise NotImplementedError()
        p += stride
        if type(data) == dict:
            yield dict(zip(*[keys, temp]))
        else:
            print('yielded pieces')
            print(len(temp))
            yield temp
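
# Usage sketch (illustration only, not from the repo): how one global batch is
# divided, assuming two devices are visible so that dev_count == 2. The field
# names and values below are hypothetical; the block is left commented out
# because dev_count depends on the local machine.
#
# import numpy as np
#
# batch = {
#     'token_ids': np.zeros([4, 128], dtype='int64'),  # 's': split along dim 0
#     'task_id': 3,                                     # 'c': copied to every device
#     'lr': np.array([5e-5, 5e-5]),                     # 'u': one row per device
# }
# strategy = {'token_ids': 's', 'task_id': 'c', 'lr': 'u'}
# for piece in yield_pieces(batch, strategy, batch_size=4):
#     pass  # two pieces are yielded, each holding the fields for one device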
def data_feeder(reader, postprocess_fn=None, prefetch_steps=2):
    if postprocess_fn is None:
        def postprocess_fn(batch):
            return batch

    def worker(reader, dev_count, queue):
        # group every dev_count batches into one step and push them to the queue
        dev_batches = []
        for index, data in enumerate(reader()):
            if len(dev_batches) < dev_count:
                dev_batches.append(data)
            if len(dev_batches) == dev_count:
                queue.put((dev_batches, 0))
                dev_batches = []
        # For prediction on the remaining batches, pad up to the number of
        # devices; the padded samples are removed from the prediction outputs.
        if len(dev_batches) > 0:
            num_pad = dev_count - len(dev_batches)
            for i in range(len(dev_batches), dev_count):
                dev_batches.append(dev_batches[-1])
            queue.put((dev_batches, num_pad))
        queue.put(None)

    queue = Queue.Queue(dev_count * prefetch_steps)
    p = Thread(
        target=worker, args=(reader, dev_count, queue))
    p.daemon = True
    p.start()

    while True:
        ret = queue.get()
        queue.task_done()
        if ret is not None:
            batches, num_pad = ret
            batch_buf = []
            flag_buf = []
            for idx, batch in enumerate(batches):
                # flag = num_pad == 0
                flag = idx - len(batches) < -num_pad
                # if num_pad > 0:
                #     num_pad -= 1
                batch = postprocess_fn(batch)
                batch_buf.append(batch)
                flag_buf.append(flag)
            yield batch_buf, flag_buf
        else:
            break
    queue.join()
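
# Usage sketch (illustration only, not from the repo): data_feeder prefetches
# batches on a background thread, groups dev_count of them per step, and pads
# the final step by repeating the last batch; the flags mark real batches (True)
# versus padded duplicates (False). The toy reader is hypothetical and the
# comment assumes dev_count == 2.
#
# def toy_reader():
#     for i in range(5):  # five batches -> the third step carries one padded copy
#         yield {'step': i}
#
# for dev_batches, flags in data_feeder(toy_reader):
#     real = [b for b, keep in zip(dev_batches, flags) if keep]
#     # `real` drops the duplicates appended for otherwise idle devices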
......@@ -18,7 +18,8 @@ import os
import json
from paddle import fluid
import paddlepalm.utils.basic_helper as helper
- from paddlepalm.utils import reader_helper
+ from paddlepalm.utils import reader_helper, saver
+ from paddlepalm.distribute import gpu_dev_count
# from paddlepalm.default_settings import *
DEBUG=False
......@@ -79,7 +80,7 @@ class Trainer(object):
self._pred_fetch_name_list = []
self._pred_fetch_var_list = []
- self._exe = fluid.Executor(fluid.CPUPlace())
+ self._exe = None
self._save_protocol = {
'input_names': 'self._pred_input_name_list',
......@@ -256,8 +257,17 @@ class Trainer(object):
return iterator_fn
def random_init_params(self):
-     helper.build_executor()
+     on_gpu = gpu_dev_count > 0
+     self._exe = helper.build_executor(on_gpu)
+
+ def load_pretrain(self, model_path):
+     # load a pretrained model (or checkpoint)
+     assert self._exe is not None, "You need to call random_init_params() before loading pretrained models."
+     saver.init_pretraining_params(
+         self._exe,
+         model_path,
+         main_program=self._train_init_prog)
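
# Usage sketch (illustration only): after this change the executor is created
# inside random_init_params(), so that call has to come before load_pretrain(),
# as in the demo script above (same path as used there):
#
# trainer.random_init_params()
# trainer.load_pretrain('pretrain/ernie/params')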
def _build_head(self, net_inputs, phase, scope=""):
if phase == 'train':
......
......@@ -3,6 +3,7 @@ import os
import json
import yaml
from config_helper import PDConfig
from paddle import fluid
def get_basename(f):
    return os.path.splitext(f)[0]
......
......@@ -55,7 +55,7 @@ def init_pretraining_params(exe,
print("Loading pretraining parameters from {}...".format(
pretraining_params_path))
- with tarfile.open(os.path.join(pretraining_params_path, '__palmmodel__'), 'r:') as f:
+ with tarfile.open(os.path.join(pretraining_params_path, '__palmmodel__'), 'r') as f:
      f.extractall(os.path.join(pretraining_params_path, '.temp'))
log_path = os.path.join(pretraining_params_path, '__palmmodel__')
......
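
# Hedged illustration (not from the repo): the two tarfile modes above behave
# differently. 'r:' opens only an uncompressed tar, while 'r' auto-detects
# compression, so it also accepts gzip/bz2 archives. File names below are
# hypothetical.
#
# import tarfile
# with tarfile.open('__palmmodel__', 'r') as f:   # plain or compressed tar
#     f.extractall('.temp')
# tarfile.open('__palmmodel__', 'r:')             # tarfile.ReadError if the archive is compressed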