提交 0a317161 编写于 作者: X xixiaoyao

release 0.3

上级 061b9e9a
../../paddlepalm/
\ No newline at end of file
/home/zhangyiming/yiming/v02-exe/pretrain
\ No newline at end of file
# coding=utf-8
import paddlepalm as palm
import json
# Demo: fine-tune an ERNIE backbone with a 4-way classification head through
# the paddlepalm Trainer API, then save the result.
if __name__ == '__main__':
# Hyper-parameters and input paths.
max_seqlen = 512
batch_size = 32
# NOTE(review): batch_size is immediately overwritten; only the value 3 is used below.
batch_size = 3
num_epochs = 2
lr = 1e-3
vocab_path = './pretrain/ernie/vocab.txt'
train_file = './data/cls4mrqa/train.tsv'
# NOTE(review): the file object returned by open() is never closed explicitly.
config = json.load(open('./pretrain/ernie/ernie_config.json'))
print(config)
# Two backbone instances are built from the same config: one with the default
# phase and one with phase='pred'.
ernie = palm.backbone.ERNIE.from_config(config)
pred_ernie = palm.backbone.ERNIE.from_config(config, phase='pred')
# cls_reader2 = palm.reader.cls(train_file_topic, vocab_path, batch_size, max_seqlen)
# cls_reader3 = palm.reader.cls(train_file_subj, vocab_path, batch_size, max_seqlen)
# topic_trainer = palm.Trainer('topic_cls', cls_reader2, cls)
# subj_trainer = palm.Trainer('subj_cls', cls_reader3, cls)
# Create the reader for this classification task; its parameters control the
# dataset input format, the number of files, the preprocessing rules, etc.
cls_reader = palm.reader.ClassifyReader(vocab_path, batch_size, max_seqlen)
print(cls_reader.outputs_attr)
# Different backbones require different features from the task reader. For a
# classification task the basic input features are token_ids and label_ids, but
# BERT additionally needs position, segment, input_mask, etc. extracted from the
# input. After registering, the reader automatically supplies the fields the
# backbone requires.
cls_reader.register_with(ernie)
print(cls_reader.outputs_attr)
# Create a task head (classification, matching, machine reading comprehension,
# ...). Each head has required/optional parameters specific to its task. Note
# that heads are decoupled from readers: a combination is valid as long as the
# reader provides the dataset-side fields the head depends on.
cls_head = palm.head.Classify(4, 1024, 0.1)
cls_pred_head = palm.head.Classify(4, 1024, 0.1, phase='pred')
# Build a trainer from the reader and the head. A trainer represents one
# training task: it tracks the training progress and the task's key
# information, performs validity checks, and controls how the task's model is
# saved and loaded.
trainer = palm.Trainer('senti_cls', cls_reader, cls_head, save_predict_model=True, \
pred_head=cls_pred_head, save_path='./output')
# match4mrqa.reuse_head_with(mrc4mrqa)
# data_vars = cls_reader.build()
# output_vars = ernie.build(data_vars)
# cls_head.build({'backbone': output_vars, 'reader': data_vars})
loss_var = trainer.build_forward(ernie, pred_ernie)
# controller.build_forward()
# Error! a head/backbone can be only build once! Try NOT to call build_forward method for any Trainer!
print(trainer.num_examples)
# NOTE(review): num_epochs and batch_size are hard-coded here (2 and 32) instead
# of using the variables above; batch_size is 3 at this point, so the step count
# computed below does not match the batch size given to load_data -- confirm
# which value is intended.
trainer.load_data(train_file, 'csv', num_epochs=2, batch_size=32)
print(trainer.num_examples)
# Total number of steps, and 10% of it passed to the scheduler as warmup steps.
n_steps = trainer.num_examples * num_epochs // batch_size
warmup_steps = int(0.1 * n_steps)
print(warmup_steps)
sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)
adam = palm.optimizer.Adam(loss_var, lr, sched)
trainer.build_backward(optimizer=adam, weight_decay=0.001, \
use_ema=True, ema_decay=0.999)
# Random initialisation runs first; load_pretrain is called afterwards with the
# pretrained parameter directory.
trainer.random_init_params()
trainer.load_pretrain('../../pretrain_model/ernie/params')
# trainer.train_one_step()
# trainer.train_one_epoch()
trainer.train()
trainer.save()
# NOTE(review): the statements below read like an API design sketch rather than
# runnable code -- they contain placeholders and names that are not defined in
# the visible source. Confirm whether this section should be kept.
# NOTE(review): `vocab` is not defined in the visible source (the script above
# defines `vocab_path`).
match_reader = palm.reader.match(train_file, vocab, \
max_seqlen, file_format='csv', tokenizer='wordpiece', \
lang='en', shuffle_train=True)
mrc_reader = palm.reader.mrc(train_file, phase='train')
mlm_reader = palm.reader.mlm(train_file, phase='train')
# NOTE(review): the next line is an incomplete expression and is a syntax error.
palm.reader.
# NOTE(review): the variable names do not match the task types they are bound
# to (`match` <- tasktype.cls, `mrc` <- tasktype.match) -- presumably a
# copy/paste slip; verify.
match = palm.tasktype.cls(num_classes=4)
mrc = palm.tasktype.match(learning_strategy='pairwise')
mlm = palm.tasktype.mlm()
mlm.print()
bb_flags = palm.load_json('./pretrain/ernie/ernie_config.json')
# NOTE(review): 'xx' and `xxx` are placeholders.
bb = palm.backbone.ernie(bb_flags['xx'], xxx)
bb.print()
# NOTE(review): `match_tt` is not defined in the visible source, and both tasks
# below are created with the same name, reader and task type.
match4mrqa = palm.Task('match4mrqa', match_reader, match_tt)
mrc4mrqa = palm.Task('match4mrqa', match_reader, match_tt)
# match4mrqa.reuse_with(mrc4mrqa)
# NOTE(review): `mrqa` and `mlm4mrqa` are not defined in the visible source.
controller = palm.Controller([mrqa, match4mrqa, mlm4mrqa])
# controller = palm.Controller([mrqa, match4mrqa, mlm4mrqa])
loss = controller.build_forward(bb, mask_task=[])
......
import downloader
from mtl_controller import Controller
import sys
from paddlepalm.mtl_controller import Controller
from paddlepalm.task_instance import Task
sys.path.append('paddlepalm')
del interface
del task_instance
del default_settings
del utils
del mtl_controller
\ No newline at end of file
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import tarfile
import shutil
try:
from urllib.request import urlopen # Python 3
except ImportError:
from urllib2 import urlopen # Python 2
import ssl
# Names exported by `from <module> import *`.
__all__ = ["download", "ls"]
# for https
# NOTE(review): this swaps the module-wide default HTTPS context factory for the
# unverified one, so HTTPS requests that rely on the default context (including
# urlopen) no longer validate server certificates for the whole process.
# Confirm this is intended; the downloaded archives are later unpacked.
ssl._create_default_https_context = ssl._create_unverified_context
# Download registry: item category -> {scope name -> archive URL}.
# A URL of None means there is nothing to fetch for that entry (_download
# returns immediately for it). Every category carries a 'utils' entry, which
# the listing helpers hide.
_items = {
'pretrain': {'ernie-en-uncased-large': 'https://ernie.bj.bcebos.com/ERNIE_Large_en_stable-2.0.0.tar.gz',
'bert-en-uncased-large': 'https://bert-models.bj.bcebos.com/uncased_L-24_H-1024_A-16.tar.gz',
'bert-en-uncased-base': 'https://bert-models.bj.bcebos.com/uncased_L-12_H-768_A-12.tar.gz',
'utils': None},
'reader': {'utils': None},
'backbone': {'utils': None},
'tasktype': {'utils': None},
}
def _download(item, scope, path, silent=False):
    """Fetch one registry entry and, for pretrained models, unpack and convert it.

    Args:
        item: top-level key of ``_items`` (for example ``'pretrain'``).
        scope: key inside ``_items[item]`` naming the resource to fetch.
        path: base directory; files are stored under ``path/item/scope``.
        silent: when true, no progress output is printed.

    Returns:
        None. The function returns immediately when the registry entry has no
        URL (``None``).
    """
    data_url = _items[item][scope]
    if data_url is None:
        return
    if not silent:
        print('Downloading {}: {} from {}...'.format(item, scope, data_url))
    data_dir = path + '/' + item + '/' + scope
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    data_name = data_url.split('/')[-1]
    filename = data_dir + '/' + data_name

    # print process
    def _chunk_report(bytes_so_far, total_size):
        # Without a known, non-zero total there is no percentage to show.
        if silent or not total_size:
            return
        percent = min(float(bytes_so_far) / float(total_size), 1.0)
        print('\r>> Downloading... {:.1%}'.format(percent), end="")

    # copy to local
    def _chunk_read(response, url, chunk_size=16 * 1024, report_hook=None):
        # ``get`` is available on the header object returned by both urllib2
        # (Python 2) and urllib.request (Python 3); ``getheader`` only exists
        # on the Python 2 one.
        content_length = response.info().get('Content-Length')
        total_size = int(content_length.strip()) if content_length else 0
        bytes_so_far = 0
        with open(filename, "wb") as f:
            while True:
                chunk = response.read(chunk_size)
                if not chunk:
                    break
                f.write(chunk)
                bytes_so_far += len(chunk)
                if report_hook:
                    report_hook(bytes_so_far, total_size)
        return bytes_so_far

    response = urlopen(data_url)
    try:
        _chunk_read(response, data_url, report_hook=_chunk_report)
    finally:
        # Release the connection even if the transfer fails half-way.
        response.close()

    if not silent:
        print(' done!')

    if item == 'pretrain':
        if not silent:
            print ('Extracting {}...'.format(data_name), end=" ")
        if os.path.exists(filename):
            # NOTE(review): extractall() trusts the member paths stored in the
            # archive; a crafted tarball can write outside data_dir. The
            # archive comes from the network, so consider validating members.
            with tarfile.open(filename, 'r') as tar:
                tar.extractall(path=data_dir)
            os.remove(filename)
        if scope.startswith('bert'):
            # For 'bert*' scopes the content is expected in a sub-directory
            # named after the archive (text before the first dot); move its
            # entries up into data_dir and drop the emptied directory.
            source_path = data_dir + '/' + data_name.split('.')[0]
            for entry in os.listdir(source_path):
                shutil.move(os.path.join(source_path, entry), data_dir)
            os.removedirs(source_path)
        if not silent:
            print ('done!')
        if not silent:
            print ('Converting params...', end=" ")
        _convert(data_dir, silent)
        if not silent:
            print ('done!')
def _convert(path, silent=False):
    """Pack the files under ``path/params`` into a single tar archive.

    The existing ``params`` directory is renamed to ``params1``; a new
    ``params`` directory is created holding ``__palmmodel__`` (a tar archive in
    which every file is stored as ``__paddlepalm_<file name>``) and
    ``__palminfo__`` (the stored names written one after another). The source
    files and the ``params1`` directory are removed afterwards.

    Args:
        path: directory that contains the ``params`` sub-directory.
        silent: when true, the 'already converted.' notice is not printed.

    Returns:
        None. Nothing is done when ``params/__palminfo__`` already exists or
        when there is no ``params`` directory.
    """
    if os.path.isfile(path + '/params/__palminfo__'):
        if not silent:
            print ('already converted.')
        return
    if not os.path.exists(path + '/params/'):
        return

    os.rename(path + '/params/', path + '/params1/')
    os.mkdir(path + '/params/')
    # Context managers make sure both handles are closed even when adding a
    # file to the archive fails part-way through.
    with tarfile.open(path + '/params/' + '__palmmodel__', 'w') as tar_model:
        with open(path + '/params/' + '__palminfo__', 'w') as tar_info:
            for root, dirs, files in os.walk(path + '/params1/'):
                for name in files:
                    src_file = os.path.join(root, name)
                    tar_model.add(src_file, '__paddlepalm_' + name)
                    # NOTE(review): names are written back-to-back without a
                    # separator; kept as-is because the reader of this file is
                    # not visible here.
                    tar_info.write('__paddlepalm_' + name)
                    os.remove(src_file)
    os.removedirs(path + '/params1/')
def download(item, scope='all', path='.'):
    """Download one scope, or every scope, of a registry category.

    Args:
        item: category name, a key of ``_items``; matched case-insensitively.
        scope: name of the resource inside the category, or ``'all'`` to fetch
            every resource of the category; matched case-insensitively.
        path: base directory the files are stored under.

    Raises:
        AssertionError: if ``item`` or ``scope`` is not in the registry.
    """
    item = item.lower()
    scope = scope.lower()
    assert item in _items, '{} is not found. Support list: {}'.format(item, list(_items.keys()))

    # The category's shared 'utils' entry, when it has a URL, is always fetched
    # first and silently.
    if _items[item]['utils'] is not None:
        _download(item, 'utils', path, silent=True)

    if scope != 'all':
        assert scope in _items[item], '{} is not found. Support scopes: {}'.format(scope, list(_items[item].keys()))
        _download(item, scope, path)
    else:
        for s in _items[item].keys():
            if s == 'utils':
                # Already handled above; fetching it again would only repeat
                # the same download.
                continue
            _download(item, s, path)
def _ls(item, scope, l = 10):
    """Print the scopes of one registry category.

    With ``scope == 'all'`` every scope of ``item`` except ``'utils'`` is
    printed with an arrow prefix; otherwise ``scope`` must exist in the
    category and is printed on its own. ``l`` is accepted but not used.

    Raises:
        AssertionError: if a specific ``scope`` is not part of ``item``.
    """
    scopes = _items[item]
    if scope == 'all':
        for name in scopes.keys():
            if name != 'utils':
                print (' => '+name)
        return
    assert scope in scopes, '{} is not found. Support scopes: {}'.format(scope, list(scopes.keys()))
    print ('{}'.format(scope))
def ls(item='all', scope='all'):
    """Print the downloadable resources of one category or of all categories.

    Args:
        item: a key of ``_items``, or ``'all'`` to walk every category.
        scope: a scope name, or ``'all'``; ``'utils'`` prints nothing at all.

    Raises:
        AssertionError: if ``item`` is neither ``'all'`` nor a known category.
    """
    if scope == 'utils':
        return
    if item == 'all':
        # Longest category name; handed to _ls as its third argument.
        widest = max(len(key) for key in _items.keys())
        for key in _items.keys():
            print ('Available {} items: '.format(key))
            _ls(key, scope, widest)
        return
    assert item in _items, '{} is not found. Support scopes: {}'.format(item, list(_items.keys()))
    print ('Available {} items:'.format(item))
    _ls(item, scope)
......@@ -52,9 +52,9 @@ class Model(backbone):
@property
def inputs_attr(self):
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32']}
@property
......@@ -73,7 +73,7 @@ class Model(backbone):
self._emb_dtype = 'float32'
# padding id in vocabulary must be set to 0
emb_out = fluid.layers.embedding(
emb_out = fluid.embedding(
input=src_ids,
size=[self._voc_size, self._emb_size],
dtype=self._emb_dtype,
......@@ -84,14 +84,14 @@ class Model(backbone):
# fluid.global_scope().find_var('backbone-word_embedding').get_tensor()
embedding_table = fluid.default_main_program().global_block().var(scope_name+self._word_emb_name)
position_emb_out = fluid.layers.embedding(
position_emb_out = fluid.embedding(
input=pos_ids,
size=[self._max_position_seq_len, self._emb_size],
dtype=self._emb_dtype,
param_attr=fluid.ParamAttr(
name=scope_name+self._pos_emb_name, initializer=self._param_initializer))
sent_emb_out = fluid.layers.embedding(
sent_emb_out = fluid.embedding(
sent_ids,
size=[self._sent_types, self._emb_size],
dtype=self._emb_dtype,
......
......@@ -62,11 +62,11 @@ class Model(backbone):
@property
def inputs_attr(self):
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"task_ids": [[-1,-1, 1], 'int64']}
"task_ids": [[-1,-1], 'int64']}
@property
def outputs_attr(self):
......@@ -85,7 +85,7 @@ class Model(backbone):
task_ids = inputs['task_ids']
# padding id in vocabulary must be set to 0
emb_out = fluid.layers.embedding(
emb_out = fluid.embedding(
input=src_ids,
size=[self._voc_size, self._emb_size],
dtype=self._emb_dtype,
......@@ -96,14 +96,14 @@ class Model(backbone):
# fluid.global_scope().find_var('backbone-word_embedding').get_tensor()
embedding_table = fluid.default_main_program().global_block().var(scope_name+self._word_emb_name)
position_emb_out = fluid.layers.embedding(
position_emb_out = fluid.embedding(
input=pos_ids,
size=[self._max_position_seq_len, self._emb_size],
dtype=self._emb_dtype,
param_attr=fluid.ParamAttr(
name=scope_name+self._pos_emb_name, initializer=self._param_initializer))
sent_emb_out = fluid.layers.embedding(
sent_emb_out = fluid.embedding(
sent_ids,
size=[self._sent_types, self._emb_size],
dtype=self._emb_dtype,
......@@ -113,7 +113,7 @@ class Model(backbone):
emb_out = emb_out + position_emb_out
emb_out = emb_out + sent_emb_out
task_emb_out = fluid.layers.embedding(
task_emb_out = fluid.embedding(
task_ids,
size=[self._task_types, self._emb_size],
dtype=self._emb_dtype,
......
from __future__ import print_function
import os
items = {
'pretrain': {'ernie-en-uncased-large': 'http://xxxxx',
'xxx': 'xxx',
'utils': None}
'reader': {'cls': 'xxx',
'xxx': 'xxx',
'utils': 'xxx'}
'backbone': {xxx}
'tasktype': {xxx}
}
def download(item, scope='all', path='.'):
    """Download one scope, or every scope, of a category listed in ``items``.

    Args:
        item: category name, a key of ``items``; matched case-insensitively.
        scope: resource name inside the category, or ``'all'`` for every
            resource; matched case-insensitively.
        path: base directory; ``path/item`` is created when missing.

    Raises:
        AssertionError: if ``item`` or ``scope`` is not listed in ``items``.
    """
    item = item.lower()
    scope = scope.lower()
    assert item in items, '{} is not found. Support list: {}'.format(item, list(items.keys()))

    # os.path.exists takes a single path, so join the pieces first.
    target_dir = os.path.join(path, item)
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    def _download(item, scope, silent=False):
        if not silent:
            print('downloading {}: {} from {}...'.format(item, scope, items[item][scope]), end='')
        # TODO: `urllib.downloadxxx` is a placeholder carried over unchanged;
        # the real fetch call (and the urllib import) still has to be filled in.
        urllib.downloadxxx(items[item][scope])
        if not silent:
            print('done!')

    # 'utils' lives inside each category, not at the top level of `items`.
    if items[item].get('utils') is not None:
        _download(item, 'utils', silent=True)
    if scope != 'all':
        # Report the scope that was not found (not the category).
        assert scope in items[item], '{} is not found. Support scopes: {}'.format(scope, list(items[item].keys()))
        _download(item, scope)
    else:
        for s in items[item].keys():
            _download(item, s)
def ls(item=None, scope='all'):
    """Placeholder for listing available resources; does nothing yet."""
    return None
from _downloader import *
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册