Unverified · Commit 0aad25c0 authored by: K kinghuin, committed by: GitHub

add ernie_gen (#853)

Parent: abfc18b6
## Overview
ERNIE-GEN is a pre-training/fine-tuning framework for generation tasks. It is the first to add a span-by-span generation task to pre-training, so that the model generates a semantically complete fragment at each step. During both pre-training and fine-tuning, an infilling generation mechanism and a noise-aware mechanism are used to mitigate the exposure bias problem. In addition, ERNIE-GEN adopts a multi-fragment, multi-granularity target-text sampling strategy that strengthens the correlation between source and target texts and enhances the interaction between encoder and decoder. The ernie_gen module supports fine-tuning, so a module for a specific scenario can be produced quickly.
<p align="center">
<img src="https://paddlehub.bj.bcebos.com/resources/multi-flow-attention.png" hspace='10'/> <br />
</p>
For more details, refer to the paper [ERNIE-GEN: An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation](https://arxiv.org/abs/2001.11314)
## API
```python
def finetune(
    train_path,
    dev_path=None,
    save_dir="ernie_gen_result",
    init_ckpt_path=None,
    use_gpu=True,
    max_steps=500,
    batch_size=8,
    max_encode_len=15,
    max_decode_len=15,
    learning_rate=5e-5,
    warmup_proportion=0.1,
    weight_decay=0.1,
    noise_prob=0,
    label_smooth=0,
    beam_width=5,
    length_penalty=1.0,
    log_interval=100,
    save_interval=200,
):
```
Fine-tuning API.
**Parameters**
* train_path(str): Path to the training set. Each line should follow the format "index\tinput text\tlabel", e.g. "1\t床前明月光\t疑是地上霜" (a sample file is shown below).
* dev_path(str): Path to the validation set, using the same format, e.g. "1\t举头望明月\t低头思故乡".
* save_dir(str): Directory where the model is saved and validation-set predictions are written.
* init_ckpt_path(str): Path from which to load initial model parameters; enables incremental training.
* use_gpu(bool): Whether to use the GPU.
* max_steps(int): Maximum number of training steps.
* batch_size(int): Batch size during training.
* max_encode_len(int): Maximum encoding length.
* max_decode_len(int): Maximum decoding length.
* learning_rate(float): Learning rate.
* warmup_proportion(float): Proportion of steps used for learning-rate warmup.
* weight_decay(float): Weight decay.
* noise_prob(float): Noise probability; see the ERNIE-GEN paper for details.
* label_smooth(float): Label-smoothing weight.
* beam_width(int): Beam size used for validation-set prediction.
* length_penalty(float): Length-penalty weight used for validation-set prediction.
* log_interval(int): Number of training steps between log prints.
* save_interval(int): Number of training steps between model saves; the validation set is predicted after each save.
**Return**
* result(dict): Run results, containing 2 keys:
```
last_save_path(str): Path of the model saved when training finished.
last_ppl(float): Perplexity of the model when training finished.
```
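For reference, a training or validation file in the format described above might look like the following (a hypothetical two-line snippet reusing the examples from the parameter descriptions; fields are separated by tab characters):
```
1	床前明月光	疑是地上霜
2	举头望明月	低头思故乡
```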
```python
def export(
    params_path,
    module_name,
    author,
    version="1.0.0",
    summary="",
    author_email="",
    export_path=".")
```
Module export API: packages the trained parameters into a hub module in one step.
**Parameters**
* params_path(str): Path to the model parameters.
* module_name(str): Module name, e.g. "ernie_gen_couplet".
* author(str): Author name.
* version(str): Version number.
* summary(str): English summary of the module.
* author_email(str): Author's email address.
* export_path(str): Export path of the module.
**Code Example**
```python
import paddlehub as hub
module = hub.Module(name="ernie_gen")
result = module.finetune(
    train_path='test_data/train.txt',
    dev_path='test_data/dev.txt',
    max_steps=300,
    batch_size=2)
module.export(params_path=result['last_save_path'], module_name="ernie_gen_test", author="test")
```
## Usage
Once the module has been exported, install it with `hub install $module_name`; the custom module can then be invoked in the following 2 ways:
1. Command-line prediction
```shell
$ hub run $module_name --input_text="Input text" --use_gpu True --beam_width 5
```
2. API prediction
```python
import paddlehub as hub
module = hub.Module(name="$module_name")
test_texts = ["Input text 1", "Input text 2"]
# generate takes 3 arguments: texts is the list of input texts, use_gpu sets whether to use the GPU, and beam_width sets the beam-search width.
results = module.generate(texts=test_texts, use_gpu=True, beam_width=5)
for result in results:
    print(result)
```
**NOTE**: `$module_name` above is the module_name specified in export.
You can also package the $module_name folder into a tar.gz archive and contact the PaddleHub team to upload it to the PaddleHub module repository, so that more users can install your module with a single command. PaddleHub warmly welcomes such contributions to help the open-source community grow.
## View the Code
https://github.com/PaddlePaddle/ERNIE/blob/repro/ernie-gen/
### Dependencies
paddlepaddle >= 1.8.2
paddlehub >= 1.7.0
## Release History
* 1.0.0
  Initial release
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from collections import namedtuple
import paddle.fluid as F
import paddle.fluid.layers as L
import paddle.fluid.dygraph as D
import numpy as np
from paddlehub.common.logger import logger
def gen_bias(encoder_inputs, decoder_inputs, step):
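    """Build the attention bias used for infilling decoding: every decoder
    position may attend to the whole encoder sequence and to the `step`
    previously generated tokens, while attention within the current decode
    block remains causal (lower triangular)."""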
decoder_bsz, decoder_seqlen = decoder_inputs.shape[:2]
attn_bias = L.reshape(
L.range(0, decoder_seqlen, 1, dtype='float32') + 1, [1, -1, 1])
decoder_bias = L.cast(
(L.matmul(attn_bias, 1. / attn_bias, transpose_y=True) >= 1.),
'float32') # [1, 1, decoderlen, decoderlen]
encoder_bias = L.unsqueeze(
L.cast(L.ones_like(encoder_inputs), 'float32'),
[1]) # [bsz, 1, encoderlen]
encoder_bias = L.expand(
encoder_bias, [1, decoder_seqlen, 1]) # [bsz,decoderlen, encoderlen]
decoder_bias = L.expand(
decoder_bias, [decoder_bsz, 1, 1]) # [bsz, decoderlen, decoderlen]
if step > 0:
bias = L.concat([
encoder_bias,
L.ones([decoder_bsz, decoder_seqlen, step], 'float32'), decoder_bias
], -1)
else:
bias = L.concat([encoder_bias, decoder_bias], -1)
return bias
@D.no_grad
def greedy_search_infilling(model,
q_ids,
q_sids,
sos_id,
eos_id,
attn_id,
max_encode_len=640,
max_decode_len=100,
tgt_type_id=3):
model.eval()
_, logits, info = model(q_ids, q_sids)
gen_ids = L.argmax(logits, -1)
d_batch, d_seqlen = q_ids.shape
seqlen = L.reduce_sum(L.cast(q_ids != 0, 'int64'), 1, keep_dim=True)
logger.debug(seqlen.numpy())
logger.debug(d_seqlen)
has_stopped = np.zeros([d_batch], dtype=np.bool)
gen_seq_len = np.zeros([d_batch], dtype=np.int64)
output_ids = []
past_cache = info['caches']
cls_ids = L.ones([d_batch], dtype='int64') * sos_id
attn_ids = L.ones([d_batch], dtype='int64') * attn_id
ids = L.stack([cls_ids, attn_ids], -1)
for step in range(max_decode_len):
logger.debug('decode step %d' % step)
bias = gen_bias(q_ids, ids, step)
pos_ids = D.to_variable(
np.tile(np.array([[step, step + 1]], dtype=np.int64), [d_batch, 1]))
pos_ids += seqlen
_, logits, info = model(
ids,
L.ones_like(ids) * tgt_type_id,
pos_ids=pos_ids,
attn_bias=bias,
past_cache=past_cache)
gen_ids = L.argmax(logits, -1)
past_cached_k, past_cached_v = past_cache
cached_k, cached_v = info['caches']
cached_k = [
L.concat([pk, k[:, :1, :]], 1)
for pk, k in zip(past_cached_k, cached_k)
] # concat cached
cached_v = [
L.concat([pv, v[:, :1, :]], 1)
for pv, v in zip(past_cached_v, cached_v)
]
past_cache = (cached_k, cached_v)
gen_ids = gen_ids[:, 1]
ids = L.stack([gen_ids, attn_ids], 1)
gen_ids = gen_ids.numpy()
has_stopped |= (gen_ids == eos_id).astype(np.bool)
gen_seq_len += (1 - has_stopped.astype(np.int64))
output_ids.append(gen_ids.tolist())
if has_stopped.all():
break
output_ids = np.array(output_ids).transpose([1, 0])
return output_ids
BeamSearchState = namedtuple('BeamSearchState',
['log_probs', 'lengths', 'finished'])
BeamSearchOutput = namedtuple('BeamSearchOutput',
['scores', 'predicted_ids', 'beam_parent_ids'])
def log_softmax(x):
e_x = np.exp(x - np.max(x))
return np.log(e_x / e_x.sum())
def mask_prob(p, onehot_eos, finished):
is_finished = L.cast(L.reshape(finished, [-1, 1]) != 0, 'float32')
p = is_finished * (1. - L.cast(onehot_eos, 'float32')) * -9999. + (
1. - is_finished) * p
return p
def hyp_score(log_probs, length, length_penalty):
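    # GNMT-style length penalty: score = log_probs / (((5 + length) / 6) ** length_penalty).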
lp = L.pow((5. + L.cast(length, 'float32')) / 6., length_penalty)
return log_probs / lp
def beam_search_step(state, logits, eos_id, beam_width, is_first_step,
length_penalty):
"""logits.shape == [B*W, V]"""
_, vocab_size = logits.shape
bsz, beam_width = state.log_probs.shape
onehot_eos = L.cast(
F.one_hot(L.ones([1], 'int64') * eos_id, vocab_size), 'int64') # [1, V]
probs = L.log(L.softmax(logits)) # [B*W, V]
probs = mask_prob(probs, onehot_eos, state.finished) # [B*W, V]
allprobs = L.reshape(state.log_probs, [-1, 1]) + probs # [B*W, V]
not_finished = 1 - L.reshape(state.finished, [-1, 1]) # [B*W,1]
not_eos = 1 - onehot_eos
length_to_add = not_finished * not_eos # [B*W,V]
alllen = L.reshape(state.lengths, [-1, 1]) + length_to_add
allprobs = L.reshape(allprobs, [-1, beam_width * vocab_size])
alllen = L.reshape(alllen, [-1, beam_width * vocab_size])
allscore = hyp_score(allprobs, alllen, length_penalty)
if is_first_step:
allscore = L.reshape(
allscore,
            [bsz, beam_width, -1])[:, 0, :]  # first step: only consider beam 0
scores, idx = L.topk(allscore, k=beam_width) # [B, W]
next_beam_id = idx // vocab_size # [B, W]
next_word_id = idx % vocab_size
gather_idx = L.concat([L.where(idx != -1)[:, :1],
L.reshape(idx, [-1, 1])], 1)
next_probs = L.reshape(L.gather_nd(allprobs, gather_idx), idx.shape)
next_len = L.reshape(L.gather_nd(alllen, gather_idx), idx.shape)
gather_idx = L.concat(
[L.where(next_beam_id != -1)[:, :1],
L.reshape(next_beam_id, [-1, 1])], 1)
next_finished = L.reshape(
L.gather_nd(state.finished, gather_idx), state.finished.shape
) # [gather new beam state according to new beam id]
next_finished += L.cast(next_word_id == eos_id, 'int64')
next_finished = L.cast(next_finished > 0, 'int64')
next_state = BeamSearchState(
log_probs=next_probs, lengths=next_len, finished=next_finished)
output = BeamSearchOutput(
scores=scores, predicted_ids=next_word_id, beam_parent_ids=next_beam_id)
return output, next_state
@D.no_grad
def beam_search_infilling(model,
q_ids,
q_sids,
sos_id,
eos_id,
attn_id,
max_encode_len=640,
max_decode_len=100,
beam_width=5,
tgt_type_id=3,
length_penalty=1.0):
model.eval()
_, __, info = model(q_ids, q_sids)
d_batch, d_seqlen = q_ids.shape
state = BeamSearchState(
log_probs=L.zeros([d_batch, beam_width], 'float32'),
lengths=L.zeros([d_batch, beam_width], 'int64'),
finished=L.zeros([d_batch, beam_width], 'int64'))
outputs = []
def reorder_(t, parent_id):
"""reorder cache according to parent beam id"""
gather_idx = L.where(parent_id != -1)[:, 0] * beam_width + L.reshape(
parent_id, [-1])
t = L.gather(t, gather_idx)
return t
def tile_(t, times):
_shapes = list(t.shape[1:])
ret = L.reshape(
L.expand(L.unsqueeze(t, [1]), [
1,
times,
] + [
1,
] * len(_shapes)), [
-1,
] + _shapes)
return ret
cached_k, cached_v = info['caches']
cached_k = [tile_(k, beam_width) for k in cached_k]
cached_v = [tile_(v, beam_width) for v in cached_v]
past_cache = (cached_k, cached_v)
q_ids = tile_(q_ids, beam_width)
seqlen = L.reduce_sum(L.cast(q_ids != 0, 'int64'), 1, keep_dim=True)
cls_ids = L.ones([d_batch * beam_width], dtype='int64') * sos_id
attn_ids = L.ones([d_batch * beam_width], dtype='int64') * attn_id # SOS
ids = L.stack([cls_ids, attn_ids], -1)
for step in range(max_decode_len):
bias = gen_bias(q_ids, ids, step)
pos_ids = D.to_variable(
np.tile(
np.array([[step, step + 1]], dtype=np.int64),
[d_batch * beam_width, 1]))
pos_ids += seqlen
_, logits, info = model(
ids,
L.ones_like(ids) * tgt_type_id,
pos_ids=pos_ids,
attn_bias=bias,
past_cache=past_cache)
output, state = beam_search_step(
state,
logits[:, 1],
eos_id=eos_id,
beam_width=beam_width,
is_first_step=(step == 0),
length_penalty=length_penalty)
outputs.append(output)
past_cached_k, past_cached_v = past_cache
cached_k, cached_v = info['caches']
cached_k = [
reorder_(L.concat([pk, k[:, :1, :]], 1), output.beam_parent_ids)
for pk, k in zip(past_cached_k, cached_k)
] # concat cached
cached_v = [
reorder_(L.concat([pv, v[:, :1, :]], 1), output.beam_parent_ids)
for pv, v in zip(past_cached_v, cached_v)
]
past_cache = (cached_k, cached_v)
pred_ids_flatten = L.reshape(output.predicted_ids,
[d_batch * beam_width])
ids = L.stack([pred_ids_flatten, attn_ids], 1)
if state.finished.numpy().all():
break
final_ids = L.stack([o.predicted_ids for o in outputs], 0)
final_parent_ids = L.stack([o.beam_parent_ids for o in outputs], 0)
final_ids = L.gather_tree(final_ids, final_parent_ids)[:, :,
0] # pick best beam
final_ids = L.transpose(L.reshape(final_ids, [-1, d_batch * 1]), [1, 0])
return final_ids
en_patten = re.compile(r'^[a-zA-Z0-9]*$')
def post_process(token):
if token.startswith('##'):
ret = token[2:]
else:
if en_patten.match(token):
ret = ' ' + token
else:
ret = token
return ret
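# Illustrative behaviour of post_process:
#   '##ing' -> 'ing'      (wordpiece suffix: the '##' marker is stripped)
#   'hello' -> ' hello'   (standalone English/number token gets a leading space)
#   '月'    -> '月'        (other tokens, e.g. Chinese characters, are returned unchanged)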
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Propeller"""
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import os
import sys
import logging
import six
from time import time
__version__ = '0.2'
log = logging.getLogger(__name__)
stream_hdl = logging.StreamHandler(stream=sys.stderr)
formatter = logging.Formatter(
fmt='[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]:\t%(message)s'
)
try:
from colorlog import ColoredFormatter
fancy_formatter = ColoredFormatter(
fmt=
'%(log_color)s[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]:\t%(message)s'
)
stream_hdl.setFormatter(fancy_formatter)
except ImportError:
stream_hdl.setFormatter(formatter)
log.setLevel(logging.INFO)
log.addHandler(stream_hdl)
log.propagate = False
from ernie_gen.propeller.types import *
from ernie_gen.propeller.util import ArgumentParser, parse_hparam, parse_runconfig, parse_file
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
doc
"""
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Basic Dataset API"""
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
import logging
import os
import itertools
import random
import inspect
import multiprocessing
from contextlib import contextmanager
import gzip
import struct
import functools
import six
from six.moves import zip, map, filter
import numpy as np
from ernie_gen.propeller.util import map_structure
log = logging.getLogger(__name__)
__all__ = ['Dataset']
@contextmanager
def _open_file(filename, format=None):
if format is None:
fd = open(filename, 'rb')
elif format == 'GZIP':
fd = gzip.open(filename, 'rb')
else:
        raise ValueError('unknown file format %s' % format)
yield fd
fd.close()
def _open_record(filename):
def _gen():
with _open_file(filename, format='GZIP') as f:
while True:
data = f.read(struct.calcsize('i'))
if not len(data):
                    return  # end of file; PEP 479 forbids raising StopIteration inside a generator
l, = struct.unpack('i', data)
data = f.read(l)
yield data
return _gen
def _shuffle_func(dataset, buffer_size):
def _gen():
buf = []
iterable = dataset()
try:
while len(buf) < buffer_size:
buf.append(next(iterable))
while 1:
i = random.randint(0, buffer_size - 1)
n = next(iterable)
yield buf[i]
buf[i] = n
except StopIteration:
if len(buf):
random.shuffle(buf)
for i in buf:
yield i
return _gen
def _interleave_func(iterable, map_fn, cycle_length, block_length):
def _gen():
ls = itertools.tee(iterable(), cycle_length)
buf = []
for i, j in enumerate(ls):
j = itertools.islice(j, i, None, cycle_length)
j = map(map_fn, j)
j = (jjj for jj in j for jjj in jj) #flatten
buf.append(j)
for tup in six.moves.zip_longest(*buf):
for ii in (i for i in tup if i is not None):
yield ii
return _gen
def _repeat_func(dataset, n):
def _gen():
iterable = dataset()
if n >= 0:
ret = itertools.chain(*itertools.tee(iterable, n))
else:
ret = itertools.cycle(iterable)
for i in ret:
yield i
return _gen
def _filter_func(dataset, fn):
def _gen():
for i in dataset():
if isinstance(i, tuple) or isinstance(i, list):
if fn(*i) is True:
yield i
else:
if fn(i) is True:
yield i
return _gen
def _map_func(dataset, fn):
def _gen():
for i in dataset():
if isinstance(i, tuple) or isinstance(i, list):
yield fn(*i)
else:
yield fn(i)
return _gen
def _shard_func(dataset, num_shards, index):
def _gen():
iterable = dataset()
ret = itertools.islice(iterable, index, None, num_shards)
for i in ret:
yield i
return _gen
def _take_func(dataset, count):
def _gen():
iterable = dataset()
ret = itertools.islice(iterable, count)
for i in ret:
yield i
return _gen
def _chain_func(dataset, dataset2):
def _gen():
iterable = dataset()
iterable2 = dataset2()
ret = itertools.chain(iterable, iterable2)
for i in ret:
yield i
return _gen
def _buffered_func(dataset, size):
"""
Creates a buffered data reader.
The buffered data reader will read and save data entries into a
buffer. Reading from the buffered data reader will proceed as long
as the buffer is not empty.
:param reader: the data reader to read from.
:type reader: callable
:param size: max buffer size.
:type size: int
:returns: the buffered data reader.
"""
class _EndSignal(object):
pass
end = _EndSignal()
def _read_worker(r, q):
for d in r:
q.put(d)
q.put(end)
def _data_reader():
r = dataset()
q = multiprocessing.Queue(maxsize=size)
t = multiprocessing.Process(
target=_read_worker, args=(
r,
q,
))
t.daemon = True
t.start()
e = q.get()
while e != end:
yield e
e = q.get()
return _data_reader
def _batch_func(dataset, batch_size):
def _gen():
iterable = dataset()
while True:
buf = list(itertools.islice(iterable, batch_size))
if not len(buf):
                return  # no more elements; PEP 479 forbids raising StopIteration inside a generator
buf = list(zip(*buf)) # transpose
buf = [np.stack(b) for b in buf]
yield buf
return _gen
def _padded_batch_func(dataset, batch_size, pad_value=0, max_seqlen=None):
if not isinstance(batch_size, int):
raise ValueError('unknown batch_size: %s' % repr(batch_size))
def _gen():
iterable = dataset()
pad_value_t = pad_value
while True:
buf = list(itertools.islice(iterable, batch_size))
if not len(buf):
                return  # no more elements; PEP 479 forbids raising StopIteration inside a generator
buf = list(zip(*buf)) # transpose
if type(pad_value_t) not in [list, tuple]:
pad_value_t = [pad_value_t] * len(buf)
padded = []
assert len(buf) == len(
pad_value_t), 'pad_value [%d] != element size[%d]' % (
len(pad_value_t), len(buf))
for e, pv in zip(buf, pad_value_t):
elem = e[0]
if (not np.isscalar(elem)) and elem.shape != ():
max_len = max(map(len,
e)) if max_seqlen is None else max_seqlen
def _fn(i):
if max_len >= len(i):
return np.pad(
i, [0, max_len - len(i)],
'constant',
constant_values=pv)
else:
return i[:max_len]
e = map(_fn, e)
padded.append(np.stack(list(e)))
yield padded
return _gen
class Dataset(object):
"""Python Wrapper for PyReader"""
@classmethod
def from_generator_func(cls, _gen, data_shapes=None, data_types=None):
"""doc"""
if not inspect.isgeneratorfunction(_gen):
raise ValueError('expect generator function, got %s' % repr(_gen))
def _wrapper(): #compat to py3.7
try:
for item in _gen():
yield item
except RuntimeError as e:
if str(e) != 'generator raised StopIteration':
raise e
ret = cls()
ret.generator = _wrapper
ret.data_shapes = data_shapes
ret.data_types = data_types
return ret
@classmethod
def from_file(cls, filename, format=None):
"""doc"""
if os.path.getsize(filename) == 0:
raise RuntimeError('%s is empty' % filename)
def _gen():
with _open_file(filename, format) as f:
for line in f:
yield line
ret = cls()
ret.generator = _gen
ret.data_shapes = []
ret.data_types = str
return ret
@classmethod
def from_record_file(cls, filename):
"""doc"""
if os.path.getsize(filename) == 0:
raise RuntimeError('%s is empty' % filename)
_gen = _open_record(filename)
ret = cls()
ret.generator = _gen
ret.data_shapes = []
ret.data_types = str
return ret
@classmethod
def from_list(cls, ls):
"""doc"""
if not isinstance(ls, list):
raise ValueError('expect list, got %s' % repr(ls))
def _gen():
for i in ls:
yield i
ret = cls()
ret.generator = _gen
ret.data_shapes = []
ret.data_types = str
return ret
def __init__(self):
self.name = None
self._data_shapes = None
self._data_types = None
self.generator = None
self.pyreader = None
def __repr__(self):
return 'Dataset: name: %s, data_shapes %s, data_types %s' % (
self.name, self._data_shapes, self._data_types)
def __eq__(self, other):
return self.name == other.name and \
self._data_shapes == other._data_shapes and \
self._data_types == other._data_types
def __iter__(self):
return self.generator()
#def __call__(self):
# return self.generator()
def _infer_shapes_and_types(self):
if self.generator is not None and self.name is not None:
log.info('Try to infer data shapes & types from generator')
first_value = next(self.generator())
shapes, types = [], []
for v in first_value:
if not isinstance(v, np.ndarray):
raise ValueError(
'dataset generator should use numpy elements, got %s' %
first_value)
shapes.append(v.shape)
types.append(v.dtype.name)
self._data_shapes = shapes
self._data_types = types
log.info('Dataset `%s` has data_shapes: %s data_types: %s' %
(self.name, repr(shapes), repr(types)))
else:
raise ValueError(
'Try to infer data shapes or types from incomplete Dataset')
@property
def data_shapes(self):
"""doc"""
if self._data_shapes is None:
self._infer_shapes_and_types()
return self._data_shapes
else:
return self._data_shapes
@data_shapes.setter
def data_shapes(self, val):
"""doc"""
self._data_shapes = val
@property
def data_types(self):
"""doc"""
if self._data_types is None:
self._infer_shapes_and_types()
return self._data_types
else:
return self._data_types
@data_types.setter
def data_types(self, val):
"""doc"""
self._data_types = val
def apply(self, transform_func):
"""apply transform func to datasets"""
#input_shapes = transform_func.input_shapes
#input_types = transform_func.input_types
#data_shapes = transform_func.data_shapes
#data_types = transform_func.data_types
#assert input_shapes == self._data_shapes
#assert input_types = self._data_types
ret_gen = transform_func(self.generator)
ret = type(self).from_generator_func(ret_gen)
if self.name is not None:
ret.name = self.name
#ret.data_shapes = data_shapes
#ret.data_types = data_types
return ret
def shuffle(self, buffer_size):
"""doc"""
func = functools.partial(_shuffle_func, buffer_size=buffer_size)
return self.apply(func)
def repeat(self, n=-1):
"""doc"""
func = functools.partial(_repeat_func, n=n)
return self.apply(func)
def map(self, fn):
"""doc"""
func = functools.partial(_map_func, fn=fn)
return self.apply(func)
def filter(self, fn):
"""doc"""
func = functools.partial(_filter_func, fn=fn)
return self.apply(func)
def shard(self, num_shards, index):
"""doc"""
func = functools.partial(
_shard_func, num_shards=num_shards, index=index)
return self.apply(func)
def interleave(self, map_fn, cycle_length, block_length):
"""doc"""
func = functools.partial(
_interleave_func,
map_fn=map_fn,
cycle_length=cycle_length,
block_length=block_length)
return self.apply(func)
def batch(self, batch_size):
func = functools.partial(_batch_func, batch_size=batch_size)
return self.apply(func)
def padded_batch(self, batch_size, pad_value=0, max_seqlen=None):
"""doc"""
func = functools.partial(
_padded_batch_func,
batch_size=batch_size,
pad_value=pad_value,
max_seqlen=max_seqlen)
return self.apply(func)
def take(self, count=1):
"""doc"""
func = functools.partial(_take_func, count=count)
return self.apply(func)
def buffered(self, size=10):
"""doc"""
func = functools.partial(_buffered_func, size=size)
return self.apply(func)
def chain(self, other):
func = functools.partial(_chain_func, dataset2=other.generator)
return self.apply(func)
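# A minimal, illustrative usage sketch of the Dataset API above (the sample
# values and the transform chain are hypothetical):
#
#   ds = Dataset.from_list([b'1 2 3', b'4 5'])
#   ds = ds.map(lambda line: (np.array([int(t) for t in line.split()]),))
#   ds = ds.padded_batch(batch_size=2, pad_value=0)
#   for batch in ds:  # each batch is a list of padded np.ndarray features
#       print(batch)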
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
doc
"""
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import six
import logging
log = logging.getLogger(__name__)
def enable_textone():
try:
import textone
except ImportError:
log.fatal('enable textone failed: textone not found!')
raise
global textone_enabled
log.info('textone enabled')
from ernie_gen.propeller.paddle.train.monitored_executor import MonitoredExecutor, TextoneTrainer
if TextoneTrainer is None:
raise RuntimeError('enable textone failed: textone not found!')
MonitoredExecutor.saver_class = TextoneTrainer
from ernie_gen.propeller.types import *
from ernie_gen.propeller.util import ArgumentParser, parse_hparam, parse_runconfig, parse_file
from ernie_gen.propeller.paddle import data
from ernie_gen.propeller.paddle import train
from ernie_gen.propeller.paddle.train import *
import paddle
paddle_version = [int(i) for i in paddle.__version__.split('.')]
if paddle_version[:2] < [1, 7]:  # compare (major, minor) so that paddle 2.x also passes
raise RuntimeError(
'propeller 0.2 requires paddle 1.7+, got %s' % paddle.__version__)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""global collections"""
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
_global_collection = None
class Key(object):
"""predefine collection keys"""
SUMMARY_SCALAR = 1
SUMMARY_HISTOGRAM = 2
SKIP_OPTIMIZE = 3
class Collections(object):
"""global collections to record everything"""
def __init__(self):
self.col = {}
def __enter__(self):
global _global_collection
_global_collection = self
return self
def __exit__(self, err_type, err_value, trace):
global _global_collection
_global_collection = None
def add(self, key, val):
"""doc"""
self.col.setdefault(key, []).append(val)
def get(self, key):
"""doc"""
return self.col.get(key, None)
def default_collection():
"""return global collection"""
global _global_collection
if _global_collection is None:
_global_collection = Collections()
return _global_collection
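# A minimal, illustrative usage sketch of the collection registry above
# (`loss_var` is a hypothetical value):
#
#   with Collections() as col:
#       default_collection().add(Key.SUMMARY_SCALAR, ('loss', loss_var))
#       scalars = col.get(Key.SUMMARY_SCALAR)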
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
doc
"""
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
from ernie_gen.propeller.paddle.data.functional import *
from ernie_gen.propeller.paddle.data.feature_column import *
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Protocol messages for describing input data Examples for machine learning
// model training or inference.
syntax = "proto3";
import "propeller/paddle/data/feature.proto";
package propeller;
message Example {
Features features = 1;
};
message SequenceExample {
Features context = 1;
FeatureLists feature_lists = 2;
};
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: propeller/paddle/data/example.proto
import sys
_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from ernie_gen.propeller.paddle.data import feature_pb2 as propeller_dot_paddle_dot_data_dot_feature__pb2
DESCRIPTOR = _descriptor.FileDescriptor(
name='propeller/paddle/data/example.proto',
package='propeller',
syntax='proto3',
serialized_options=None,
serialized_pb=_b(
'\n#propeller/paddle/data/example.proto\x12\tpropeller\x1a#propeller/paddle/data/feature.proto\"0\n\x07\x45xample\x12%\n\x08\x66\x65\x61tures\x18\x01 \x01(\x0b\x32\x13.propeller.Features\"g\n\x0fSequenceExample\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.propeller.Features\x12.\n\rfeature_lists\x18\x02 \x01(\x0b\x32\x17.propeller.FeatureListsb\x06proto3'
),
dependencies=[
propeller_dot_paddle_dot_data_dot_feature__pb2.DESCRIPTOR,
])
_EXAMPLE = _descriptor.Descriptor(
name='Example',
full_name='propeller.Example',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='features',
full_name='propeller.Example.features',
index=0,
number=1,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=87,
serialized_end=135,
)
_SEQUENCEEXAMPLE = _descriptor.Descriptor(
name='SequenceExample',
full_name='propeller.SequenceExample',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='context',
full_name='propeller.SequenceExample.context',
index=0,
number=1,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='feature_lists',
full_name='propeller.SequenceExample.feature_lists',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=137,
serialized_end=240,
)
_EXAMPLE.fields_by_name[
'features'].message_type = propeller_dot_paddle_dot_data_dot_feature__pb2._FEATURES
_SEQUENCEEXAMPLE.fields_by_name[
'context'].message_type = propeller_dot_paddle_dot_data_dot_feature__pb2._FEATURES
_SEQUENCEEXAMPLE.fields_by_name[
'feature_lists'].message_type = propeller_dot_paddle_dot_data_dot_feature__pb2._FEATURELISTS
DESCRIPTOR.message_types_by_name['Example'] = _EXAMPLE
DESCRIPTOR.message_types_by_name['SequenceExample'] = _SEQUENCEEXAMPLE
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
Example = _reflection.GeneratedProtocolMessageType(
'Example',
(_message.Message, ),
dict(
DESCRIPTOR=_EXAMPLE,
__module__='propeller.paddle.data.example_pb2'
# @@protoc_insertion_point(class_scope:propeller.Example)
))
_sym_db.RegisterMessage(Example)
SequenceExample = _reflection.GeneratedProtocolMessageType(
'SequenceExample',
(_message.Message, ),
dict(
DESCRIPTOR=_SEQUENCEEXAMPLE,
__module__='propeller.paddle.data.example_pb2'
# @@protoc_insertion_point(class_scope:propeller.SequenceExample)
))
_sym_db.RegisterMessage(SequenceExample)
# @@protoc_insertion_point(module_scope)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package propeller;
message BytesList {
repeated bytes value = 1;
}
message FloatList {
repeated float value = 1 [packed = true];
}
message Int64List {
repeated int64 value = 1 [packed = true];
}
message Feature {
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
message Features {
map<string, Feature> feature = 1;
};
message FeatureList {
repeated Feature feature = 1;
};
message FeatureLists {
map<string, FeatureList> feature_list = 1;
};
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FeatureColumns and many Column"""
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import os
import sys
import struct
from six.moves import zip, map
import itertools
import gzip
from functools import partial
import six
import logging
import numpy as np
from glob import glob
from ernie_gen.propeller.paddle.train import distribution
from ernie_gen.propeller.data.functional import _interleave_func
from ernie_gen.propeller.paddle.data.functional import Dataset
from ernie_gen.propeller.paddle.data import example_pb2, feature_pb2
import multiprocessing
log = logging.getLogger(__name__)
__all__ = [
'FeatureColumns', 'TextColumn', 'TextIDColumn', 'LabelColumn',
'RawBytesColumn', 'basic_tokenizer', 'Column'
]
def basic_tokenizer(sen):
"""doc"""
seg = sen.split(b' ')
seg = filter(lambda i: i != b' ', seg)
return seg
class Column(object):
"""doc"""
def __init__(self, name):
"""doc"""
pass
def raw_to_proto(self, raw):
"""doc"""
return feature_pb2.Feature()
@property
def output_shapes(self):
"""doc"""
pass
@property
def output_types(self):
"""doc"""
pass
def proto_to_instance(self, proto):
"""doc"""
raise NotImplementedError()
def raw_to_instance(self, raw):
"""doc"""
raise NotImplementedError()
class LabelColumn(Column):
"""doc"""
def __init__(self, name, vocab_dict=None, vocab_file=None):
"""doc"""
self.name = name
self.vocab = None
if vocab_file:
self.vocab = {
j.strip(): i
for i, j in enumerate(open(vocab_file, 'rb').readlines())
}
if vocab_dict:
self.vocab = vocab_dict
@property
def output_shapes(self):
"""doc"""
return [1]
@property
def output_types(self):
"""doc"""
return 'int64'
def raw_to_proto(self, raw):
"""doc"""
if self.vocab is None:
ids = [int(raw)]
else:
ids = [self.vocab[raw]]
fe = feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=ids))
return fe
def proto_to_instance(self, feature):
"""doc"""
ret = np.array(feature.int64_list.value[0], dtype=np.int64)
return ret
def raw_to_instance(self, raw):
"""doc"""
if self.vocab is None:
ids = int(raw)
else:
ids = self.vocab[raw]
return ids
class TextColumn(Column):
"""doc"""
def __init__(self,
name,
unk_id,
vocab_file=None,
vocab_dict=None,
tokenizer=basic_tokenizer):
self.name = name
self.tokenizer = tokenizer
self.unk_id = unk_id
if not (vocab_file or vocab_dict):
raise ValueError('at least specify vocab_file or vocab_dict')
if vocab_file:
self.vocab = {
j.strip(): i
for i, j in enumerate(open(vocab_file, 'rb').readlines())
}
if vocab_dict:
self.vocab = vocab_dict
@property
def output_shapes(self):
"""doc"""
return [-1]
@property
def output_types(self):
"""doc"""
return 'int64'
def raw_to_proto(self, raw):
"""doc"""
ids = [
s if isinstance(s, int) else self.vocab.get(s, self.unk_id)
for s in self.tokenizer(raw)
]
fe = feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=ids))
return fe
def proto_to_instance(self, feature):
"""doc"""
ret = np.array(feature.int64_list.value, dtype=np.int64)
return ret
def raw_to_instance(self, raw):
"""doc"""
ids = [
s if isinstance(s, int) else self.vocab.get(s, self.unk_id)
for s in self.tokenizer(raw)
]
return np.array(ids, dtype=np.int64)
class RawBytesColumn(Column):
def __init__(self, name):
self.name = name
@property
def output_shapes(self):
"""doc"""
return [-1]
@property
def output_types(self):
"""doc"""
return 'bytes'
# def raw_to_proto(self, raw):
# """doc"""
# fe = feature_pb2.Feature(bytes_list=BytesList(value=[raw]))
# return fe
def proto_to_instance(self, feature):
"""doc"""
ret = feature.bytes_list.value[
0] #np.array(feature.int64_list.value, dtype=np.int64)
return ret
def raw_to_instance(self, raw):
"""doc"""
return raw
class TextIDColumn(Column):
"""doc"""
def __init__(self, name):
"""doc"""
self.name = name
@property
def output_shapes(self):
"""doc"""
return [-1]
@property
def output_types(self):
"""doc"""
return 'int64'
def raw_to_proto(self, raw):
"""doc"""
ids = [int(s) for s in raw.split(b' ')]
fe = feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=ids))
return fe
def proto_to_instance(self, feature):
"""doc"""
ret = np.array(feature.int64_list.value, dtype=np.int64)
return ret
def raw_to_instance(self, raw):
"""doc"""
ret = np.array([int(i) for i in raw.split(b' ')], dtype=np.int64)
return ret
def _list_files(raw_dir):
return [os.path.join(raw_dir, p) for p in os.listdir(raw_dir)]
_columns = None
def _init_worker(col):
global _columns
_columns = col
def _worker_entrence(args):
args = (_columns, ) + args
return _make_gz(args)
class FeatureColumns(object):
"""A Dataset Factory object"""
def __init__(self, columns):
"""doc"""
self._columns = columns
def _make_gz_dataset(self, raw_dir, gz_dir):
assert raw_dir or gz_dir, 'data_dir not specified when using gz mode'
if raw_dir is not None:
assert os.path.exists(raw_dir), 'raw_dir not exists: %s' % raw_dir
raw_file = os.listdir(raw_dir)
if gz_dir is None:
gz_dir = '%s_gz' % raw_dir.rstrip('/')
if not os.path.exists(gz_dir):
os.mkdir(gz_dir)
if raw_dir is not None:
if len(raw_file) != 0:
log.debug('try making gz')
pool = multiprocessing.Pool(
initializer=_init_worker, initargs=(self._columns, ))
args = [(os.path.join(raw_dir, f), os.path.join(gz_dir, f),
b'\t') for f in raw_file]
pool.map(_worker_entrence, args)
pool.close()
pool.join()
else:
assert len(
os.listdir(gz_dir)
) != 0, 'cant find gz file or raw-txt file at [%s] and [%s]' % (
raw_dir, gz_dir)
return gz_dir
def _read_gz_dataset(self,
gz_files,
shuffle=False,
repeat=True,
shard=False,
**kwargs):
if len(gz_files) == 0:
raise ValueError('reading gz from empty file list: %s' % gz_files)
log.info('reading gz from %s' % '\n'.join(gz_files))
dataset = Dataset.from_list(gz_files)
if repeat:
dataset = dataset.repeat()
# if shard and distribution.status.mode == distribution.DistributionMode.NCCL:
# log.info('Apply dataset sharding in distribution env')
# train_ds = train_ds.shard(distribution.status.num_replica,
# distribution.status.replica_id)
if shuffle:
dataset = dataset.shuffle(buffer_size=len(gz_files))
fn = partial(
_interleave_func,
map_fn=lambda filename: Dataset.from_record_file(filename),
cycle_length=len(gz_files),
block_length=1)
dataset = dataset.apply(fn)
if shuffle:
dataset = dataset.shuffle(buffer_size=1000)
def _parse_gz(record_str): # function that takes python_str as input
ex = example_pb2.Example()
ex.ParseFromString(record_str)
ret = []
fea_dict = ex.features.feature
for c in self._columns:
ins = c.proto_to_instance(fea_dict[c.name])
ret.append(ins)
return ret
dataset = dataset.map(_parse_gz)
return dataset
def _read_txt_dataset(self,
data_files,
shuffle=False,
repeat=True,
**kwargs):
log.info('reading raw files from %s' % '\n'.join(data_files))
dataset = Dataset.from_list(data_files)
if repeat:
dataset = dataset.repeat()
if shuffle:
dataset = dataset.shuffle(buffer_size=len(data_files))
fn = partial(
_interleave_func,
map_fn=lambda filename: Dataset.from_file(filename),
cycle_length=len(data_files),
block_length=1)
dataset = dataset.apply(fn)
if shuffle:
dataset = dataset.shuffle(buffer_size=1000)
def _parse_txt_file(
record_str): # function that takes python_str as input
features = record_str.strip(b'\n').split(b'\t')
ret = [
column.raw_to_instance(feature)
for feature, column in zip(features, self._columns)
]
return ret
dataset = dataset.map(_parse_txt_file)
return dataset
def _read_stdin_dataset(self, encoding='utf8', shuffle=False, **kwargs):
log.info('reading raw files stdin')
def _gen():
if six.PY3:
source = sys.stdin.buffer
else:
source = sys.stdin
while True:
line = source.readline()
if len(line) == 0:
break
yield line,
dataset = Dataset.from_generator_func(_gen)
if shuffle:
dataset = dataset.shuffle(buffer_size=1000)
def _parse_stdin(record_str):
"""function that takes python_str as input"""
features = record_str.strip(b'\n').split(b'\t')
ret = [
column.raw_to_instance(feature)
for feature, column in zip(features, self._columns)
]
return ret
dataset = dataset.map(_parse_stdin)
return dataset
def _prepare_dataset(self,
dataset,
map_func_before_batch=None,
map_func_after_batch=None,
shuffle_buffer_size=None,
batch_size=1,
pad_id=0,
prefetch=None,
**kwargs):
if map_func_before_batch is not None:
dataset = dataset.map(map_func_before_batch)
if batch_size:
dataset = dataset.padded_batch(batch_size, pad_id)
if map_func_after_batch is not None:
dataset = dataset.map(map_func_after_batch)
return dataset
def build_dataset(self,
name,
use_gz=True,
data_dir=None,
gz_dir=None,
data_file=None,
**kwargs):
"""
build `Dataset` from `data_dir` or `data_file`
if `use_gz`, will try to convert data_files to gz format and save to `gz_dir`, if `gz_dir` not given, will create one.
"""
if use_gz:
gz_dir = self._make_gz_dataset(data_dir, gz_dir)
gz_files = _list_files(gz_dir) if gz_dir is not None else gz_dir
ds = self._read_gz_dataset(gz_files, **kwargs)
else:
if data_dir is not None:
data_files = _list_files(data_dir)
elif data_file is not None:
data_files = [data_file]
else:
raise ValueError('data_dir or data_files not specified')
ds = self._read_txt_dataset(data_files, **kwargs)
ds.name = name
return ds
def build_dataset_from_stdin(self, name, **kwargs):
"""doc"""
ds = self._read_stdin_dataset(**kwargs)
ds.name = name
return ds
def _make_gz(args):
try:
columns, from_file, to_file, sep = args
if os.path.exists(to_file):
return
with open(from_file, 'rb') as fin, gzip.open(to_file, 'wb') as fout:
log.debug('making gz %s => %s' % (from_file, to_file))
for i, line in enumerate(fin):
line = line.strip(b'\n').split(sep)
#if i % 10000 == 0:
# log.debug('making gz %s => %s [%d]' % (from_file, to_file, i))
if len(line) != len(columns):
log.error('columns not match at %s, got %d, expect %d' %
(from_file, len(line), len(columns)))
continue
features = {}
for l, c in zip(line, columns):
features[c.name] = c.raw_to_proto(l)
example = example_pb2.Example(
features=feature_pb2.Features(feature=features))
serialized = example.SerializeToString()
l = len(serialized)
data = struct.pack('i%ds' % l, l, serialized)
fout.write(data)
log.debug('done making gz %s => %s' % (from_file, to_file))
except Exception as e:
log.exception(e)
raise e
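# A minimal, illustrative sketch of FeatureColumns (the column layout and the
# file name are hypothetical). Each line of the text file is expected to hold
# tab-separated fields matching the column list, in order:
#
#   columns = FeatureColumns([
#       TextIDColumn('src_ids'),   # e.g. b'1 2 3 4'
#       LabelColumn('label'),      # e.g. b'0'
#   ])
#   ds = columns.build_dataset(
#       'train', use_gz=False, data_file='train.tsv',
#       shuffle=True, repeat=False)
#   ds = ds.padded_batch(batch_size=8)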
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: propeller/paddle/data/feature.proto
import sys
_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='propeller/paddle/data/feature.proto',
package='propeller',
syntax='proto3',
serialized_options=None,
serialized_pb=_b(
'\n#propeller/paddle/data/feature.proto\x12\tpropeller\"\x1a\n\tBytesList\x12\r\n\x05value\x18\x01 \x03(\x0c\"\x1e\n\tFloatList\x12\x11\n\x05value\x18\x01 \x03(\x02\x42\x02\x10\x01\"\x1e\n\tInt64List\x12\x11\n\x05value\x18\x01 \x03(\x03\x42\x02\x10\x01\"\x95\x01\n\x07\x46\x65\x61ture\x12*\n\nbytes_list\x18\x01 \x01(\x0b\x32\x14.propeller.BytesListH\x00\x12*\n\nfloat_list\x18\x02 \x01(\x0b\x32\x14.propeller.FloatListH\x00\x12*\n\nint64_list\x18\x03 \x01(\x0b\x32\x14.propeller.Int64ListH\x00\x42\x06\n\x04kind\"\x81\x01\n\x08\x46\x65\x61tures\x12\x31\n\x07\x66\x65\x61ture\x18\x01 \x03(\x0b\x32 .propeller.Features.FeatureEntry\x1a\x42\n\x0c\x46\x65\x61tureEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.propeller.Feature:\x02\x38\x01\"2\n\x0b\x46\x65\x61tureList\x12#\n\x07\x66\x65\x61ture\x18\x01 \x03(\x0b\x32\x12.propeller.Feature\"\x9a\x01\n\x0c\x46\x65\x61tureLists\x12>\n\x0c\x66\x65\x61ture_list\x18\x01 \x03(\x0b\x32(.propeller.FeatureLists.FeatureListEntry\x1aJ\n\x10\x46\x65\x61tureListEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.propeller.FeatureList:\x02\x38\x01\x62\x06proto3'
))
_BYTESLIST = _descriptor.Descriptor(
name='BytesList',
full_name='propeller.BytesList',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.BytesList.value',
index=0,
number=1,
type=12,
cpp_type=9,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=50,
serialized_end=76,
)
_FLOATLIST = _descriptor.Descriptor(
name='FloatList',
full_name='propeller.FloatList',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.FloatList.value',
index=0,
number=1,
type=2,
cpp_type=6,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=_b('\020\001'),
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=78,
serialized_end=108,
)
_INT64LIST = _descriptor.Descriptor(
name='Int64List',
full_name='propeller.Int64List',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.Int64List.value',
index=0,
number=1,
type=3,
cpp_type=2,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=_b('\020\001'),
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=110,
serialized_end=140,
)
_FEATURE = _descriptor.Descriptor(
name='Feature',
full_name='propeller.Feature',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='bytes_list',
full_name='propeller.Feature.bytes_list',
index=0,
number=1,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='float_list',
full_name='propeller.Feature.float_list',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='int64_list',
full_name='propeller.Feature.int64_list',
index=2,
number=3,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
_descriptor.OneofDescriptor(
name='kind',
full_name='propeller.Feature.kind',
index=0,
containing_type=None,
fields=[]),
],
serialized_start=143,
serialized_end=292,
)
_FEATURES_FEATUREENTRY = _descriptor.Descriptor(
name='FeatureEntry',
full_name='propeller.Features.FeatureEntry',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='propeller.Features.FeatureEntry.key',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.Features.FeatureEntry.value',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=_b('8\001'),
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=358,
serialized_end=424,
)
_FEATURES = _descriptor.Descriptor(
name='Features',
full_name='propeller.Features',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='feature',
full_name='propeller.Features.feature',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[
_FEATURES_FEATUREENTRY,
],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=295,
serialized_end=424,
)
_FEATURELIST = _descriptor.Descriptor(
name='FeatureList',
full_name='propeller.FeatureList',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='feature',
full_name='propeller.FeatureList.feature',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=426,
serialized_end=476,
)
_FEATURELISTS_FEATURELISTENTRY = _descriptor.Descriptor(
name='FeatureListEntry',
full_name='propeller.FeatureLists.FeatureListEntry',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='propeller.FeatureLists.FeatureListEntry.key',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.FeatureLists.FeatureListEntry.value',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=_b('8\001'),
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=559,
serialized_end=633,
)
_FEATURELISTS = _descriptor.Descriptor(
name='FeatureLists',
full_name='propeller.FeatureLists',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='feature_list',
full_name='propeller.FeatureLists.feature_list',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[
_FEATURELISTS_FEATURELISTENTRY,
],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=479,
serialized_end=633,
)
_FEATURE.fields_by_name['bytes_list'].message_type = _BYTESLIST
_FEATURE.fields_by_name['float_list'].message_type = _FLOATLIST
_FEATURE.fields_by_name['int64_list'].message_type = _INT64LIST
_FEATURE.oneofs_by_name['kind'].fields.append(
_FEATURE.fields_by_name['bytes_list'])
_FEATURE.fields_by_name[
'bytes_list'].containing_oneof = _FEATURE.oneofs_by_name['kind']
_FEATURE.oneofs_by_name['kind'].fields.append(
_FEATURE.fields_by_name['float_list'])
_FEATURE.fields_by_name[
'float_list'].containing_oneof = _FEATURE.oneofs_by_name['kind']
_FEATURE.oneofs_by_name['kind'].fields.append(
_FEATURE.fields_by_name['int64_list'])
_FEATURE.fields_by_name[
'int64_list'].containing_oneof = _FEATURE.oneofs_by_name['kind']
_FEATURES_FEATUREENTRY.fields_by_name['value'].message_type = _FEATURE
_FEATURES_FEATUREENTRY.containing_type = _FEATURES
_FEATURES.fields_by_name['feature'].message_type = _FEATURES_FEATUREENTRY
_FEATURELIST.fields_by_name['feature'].message_type = _FEATURE
_FEATURELISTS_FEATURELISTENTRY.fields_by_name[
'value'].message_type = _FEATURELIST
_FEATURELISTS_FEATURELISTENTRY.containing_type = _FEATURELISTS
_FEATURELISTS.fields_by_name[
'feature_list'].message_type = _FEATURELISTS_FEATURELISTENTRY
DESCRIPTOR.message_types_by_name['BytesList'] = _BYTESLIST
DESCRIPTOR.message_types_by_name['FloatList'] = _FLOATLIST
DESCRIPTOR.message_types_by_name['Int64List'] = _INT64LIST
DESCRIPTOR.message_types_by_name['Feature'] = _FEATURE
DESCRIPTOR.message_types_by_name['Features'] = _FEATURES
DESCRIPTOR.message_types_by_name['FeatureList'] = _FEATURELIST
DESCRIPTOR.message_types_by_name['FeatureLists'] = _FEATURELISTS
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
BytesList = _reflection.GeneratedProtocolMessageType(
'BytesList',
(_message.Message, ),
dict(
DESCRIPTOR=_BYTESLIST,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.BytesList)
))
_sym_db.RegisterMessage(BytesList)
FloatList = _reflection.GeneratedProtocolMessageType(
'FloatList',
(_message.Message, ),
dict(
DESCRIPTOR=_FLOATLIST,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.FloatList)
))
_sym_db.RegisterMessage(FloatList)
Int64List = _reflection.GeneratedProtocolMessageType(
'Int64List',
(_message.Message, ),
dict(
DESCRIPTOR=_INT64LIST,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.Int64List)
))
_sym_db.RegisterMessage(Int64List)
Feature = _reflection.GeneratedProtocolMessageType(
'Feature',
(_message.Message, ),
dict(
DESCRIPTOR=_FEATURE,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.Feature)
))
_sym_db.RegisterMessage(Feature)
Features = _reflection.GeneratedProtocolMessageType(
'Features',
(_message.Message, ),
dict(
FeatureEntry=_reflection.GeneratedProtocolMessageType(
'FeatureEntry',
(_message.Message, ),
dict(
DESCRIPTOR=_FEATURES_FEATUREENTRY,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.Features.FeatureEntry)
)),
DESCRIPTOR=_FEATURES,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.Features)
))
_sym_db.RegisterMessage(Features)
_sym_db.RegisterMessage(Features.FeatureEntry)
FeatureList = _reflection.GeneratedProtocolMessageType(
'FeatureList',
(_message.Message, ),
dict(
DESCRIPTOR=_FEATURELIST,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.FeatureList)
))
_sym_db.RegisterMessage(FeatureList)
FeatureLists = _reflection.GeneratedProtocolMessageType(
'FeatureLists',
(_message.Message, ),
dict(
FeatureListEntry=_reflection.GeneratedProtocolMessageType(
'FeatureListEntry',
(_message.Message, ),
dict(
DESCRIPTOR=_FEATURELISTS_FEATURELISTENTRY,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.FeatureLists.FeatureListEntry)
)),
DESCRIPTOR=_FEATURELISTS,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.FeatureLists)
))
_sym_db.RegisterMessage(FeatureLists)
_sym_db.RegisterMessage(FeatureLists.FeatureListEntry)
_FLOATLIST.fields_by_name['value']._options = None
_INT64LIST.fields_by_name['value']._options = None
_FEATURES_FEATUREENTRY._options = None
_FEATURELISTS_FEATURELISTENTRY._options = None
# @@protoc_insertion_point(module_scope)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pyreader based Dataset"""
import sys
import numpy as np
import logging
import paddle.fluid as F
import paddle.fluid.layers as L
from ernie_gen.propeller.data.functional import Dataset as DatasetBase
log = logging.getLogger(__name__)
class Dataset(DatasetBase):
"""Pyreader based Dataset"""
def placeholders(self):
"""doc"""
if self.name is None:
            raise ValueError('cannot get features from an unnamed Dataset')
ret = []
for i, (shape, types) in enumerate(
zip(self.data_shapes, self.data_types)):
ret.append(
L.data(
'%s_placeholder_%d' % (self.name, i),
shape=shape,
append_batch_size=False,
dtype=types))
return ret
def features(self):
"""start point of net building. call this in a program scope"""
if self.name is None:
            raise ValueError('cannot get features from an unnamed Dataset')
if len(self.data_shapes) != len(self.data_types):
raise ValueError(
                'Dataset shapes and types do not match: shapes:%s types:%s' % (repr(
                    self._data_shapes), repr(self._data_types)))
return self.placeholders()
def start(self, places=None):
"""start Pyreader"""
if places is None:
places = F.cuda_places() if F.core.is_compiled_with_cuda(
) else F.cpu_places()
#assert self.pyreader is not None, 'use Dataset.features to build net first, then start dataset'
def _gen():
try:
for idx, i in enumerate(self.generator()):
yield i
except Exception as e:
log.exception(e)
raise e
r = F.io.PyReader(
feed_list=self.placeholders(),
capacity=50,
iterable=True,
return_list=F.in_dygraph_mode())
r.decorate_batch_generator(_gen, places=places)
return r()
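# Illustrative usage sketch: how the pyreader-based Dataset above is typically
# consumed. `ds` stands for a Dataset instance whose `name`, `data_shapes`,
# `data_types` and `generator` were configured by the propeller data pipeline;
# the construction API itself lives elsewhere, so the sketch is kept commented.
#
#   main_prog, startup_prog = F.Program(), F.Program()
#   with F.program_guard(main_prog, startup_prog):
#       feats = ds.features()            # named placeholders built via L.data
#       # ... build the network on top of `feats` ...
#   exe = F.Executor(F.cpu_places()[0])
#   exe.run(startup_prog)
#   for batch in ds.start():             # iterable PyReader yields feed data
#       exe.run(main_prog, feed=batch)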
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""record summary tensor in a collection scope"""
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
import paddle.fluid as F
from ernie_gen.propeller.paddle.collection import default_collection, Key
def scalar(name, tensor):
"""scalar summary"""
if not isinstance(tensor, F.framework.Variable):
raise ValueError('expect paddle Variable, got %s' % repr(tensor))
default_collection().add(Key.SUMMARY_SCALAR, (name, tensor))
def histogram(name, tensor):
"""histogram summary"""
if not isinstance(tensor, F.framework.Variable):
raise ValueError('expect paddle Variable, got %s' % repr(tensor))
default_collection().add(Key.SUMMARY_HISTOGRAM, (name, tensor))
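# Illustrative usage sketch: scalar()/histogram() are meant to be called while
# the model graph is being built; `demo_loss` below is a stand-in variable
# created only for demonstration.
if __name__ == '__main__':
    demo_prog, demo_startup = F.Program(), F.Program()
    with F.program_guard(demo_prog, demo_startup):
        demo_loss = F.layers.data(
            'demo_loss', shape=[1], dtype='float32', append_batch_size=False)
        scalar('loss', demo_loss)          # recorded under Key.SUMMARY_SCALAR
        histogram('loss_hist', demo_loss)  # recorded under Key.SUMMARY_HISTOGRAM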
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Propeller training"""
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import os
import sys
import logging
from time import time
log = logging.getLogger(__name__)
from ernie_gen.propeller.paddle.train.monitored_executor import *
from ernie_gen.propeller.paddle.train.trainer import *
from ernie_gen.propeller.paddle.train.hooks import *
from ernie_gen.propeller.train.model import Model
from ernie_gen.propeller.paddle.train import exporter
from ernie_gen.propeller.paddle.train import distribution
from ernie_gen.propeller.paddle.train import metrics
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import functools
import six
import os
import logging
from time import sleep
import paddle.fluid as F
import paddle.fluid.layers as L
log = logging.getLogger(__name__)
import ernie_gen.propeller.util
__all__ = ['init_distribuition_env', 'status']
status = None
class DistributionMode(object):
LOCAL = 0
NCCL = 1
class DistributionStatus(object):
def __init__(self, config):
if config is None:
self._mode = DistributionMode.LOCAL
self._env = None
self._this = None
else:
try:
self._mode = DistributionMode.NCCL
cluster = config['cluster']
task = config['task']['type']
idx = int(config['task']['index'])
self._this = cluster[task][idx]
self._env = cluster['chief'] + cluster.get('worker', [])
if len(set(self._env)) != len(self._env):
raise ValueError('duplicate host in dis_config %s' % config)
except KeyError as e:
                raise ValueError('malformed PROPELLER_DISCONFIG: %s not found in %s'
% (e, repr(config)))
@property
def mode(self):
return self._mode
@property
def num_replica(self):
if self._mode == DistributionMode.LOCAL:
return 1
elif self._mode == DistributionMode.NCCL:
return len(self._env)
else:
raise ValueError(
                'Got unknown distribution mode %s' % repr(self._mode))
@property
def replica_id(self):
if self._mode == DistributionMode.LOCAL:
return 0
elif self._mode == DistributionMode.NCCL:
return self._env.index(self._this)
else:
raise ValueError(
                'Got unknown distribution mode %s' % repr(self._mode))
@property
def is_master(self):
if self._mode == DistributionMode.LOCAL:
return True
elif self._mode == DistributionMode.NCCL:
return self.replica_id == 0
else:
raise ValueError(
                'got unknown distribution mode %s' % repr(self._mode))
def _get_paddlestype_disconfig():
env = os.environ.copy()
if not ('PADDLE_TRAINER_ID' in env and 'PADDLE_CURRENT_ENDPOINT' in env and
'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ENDPOINTS' in env):
return None
else:
ip_port_list = env['PADDLE_TRAINER_ENDPOINTS'].split(',')
assert len(ip_port_list) == int(env['PADDLE_TRAINERS_NUM'])
ip_port_self = env['PADDLE_CURRENT_ENDPOINT']
world = {"chief": [ip_port_list[0]]}
for ip_port in ip_port_list[1:]:
world.setdefault('worker', []).append(ip_port)
self_index = ip_port_list.index(ip_port_self)
self_type = 'chief' if self_index == 0 else 'worker'
if self_type == 'worker':
self_index -= 1
env_dict = {
'cluster': world,
'task': {
'type': self_type,
'index': self_index
}
}
return env_dict
dis_config = ernie_gen.propeller.util._get_dict_from_environ_or_json_or_file(
None, 'PROPELLER_DISCONFIG')
if dis_config is None:
    log.debug('no PROPELLER_DISCONFIG found, trying paddle-style setting')
dis_config = _get_paddlestype_disconfig()
if dis_config is None:
        log.debug('no paddle-style setting found')
status = DistributionStatus(dis_config)
def run_on_master(func):
"""skip function in distribution env"""
@functools.wraps(func)
def f(*arg, **kwargs):
"""f"""
if status is None:
            raise ValueError('distribution mode unknown at this point')
if status.mode == DistributionMode.LOCAL:
r = func(*arg, **kwargs)
elif status.mode == DistributionMode.NCCL:
if status.is_master:
r = func(*arg, **kwargs)
else:
r = 0 # skip function
#MPI.COMM_WORLD.Barrier()
return r
return f
def init_distribuition_env(program):
if status.mode == DistributionMode.LOCAL:
log.info('Initializing local training')
elif status.mode == DistributionMode.NCCL:
config = F.DistributeTranspilerConfig()
config.mode = "nccl2"
config.nccl_comm_num = 1
F.DistributeTranspiler(config=config).transpile(
status.replica_id,
trainers=','.join(status._env),
current_endpoint=status._this,
program=program.train_program,
startup_program=program.startup_program)
log.info('Initializing distribution training with config %s' %
(repr(dis_config)))
if status.is_master:
sleep(30)
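# Illustrative configuration sketch: the structure DistributionStatus expects
# to find in PROPELLER_DISCONFIG (or built from the paddle launcher variables
# by _get_paddlestype_disconfig). The hosts below are made-up examples.
_EXAMPLE_DISCONFIG = {
    'cluster': {
        'chief': ['10.0.0.1:8000'],
        'worker': ['10.0.0.2:8000', '10.0.0.3:8000'],
    },
    'task': {
        'type': 'worker',  # this process is cluster['worker'][0]
        'index': 0,
    },
}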
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
exporters
"""
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
import os
import itertools
import six
import inspect
import abc
import logging
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
from ernie_gen.propeller.util import map_structure
from ernie_gen.propeller.paddle.train import Saver
from ernie_gen.propeller.types import InferenceSpec
from ernie_gen.propeller.train.model import Model
from ernie_gen.propeller.paddle.train.trainer import _build_net
from ernie_gen.propeller.paddle.train.trainer import _build_model_fn
from ernie_gen.propeller.types import RunMode
from ernie_gen.propeller.types import ProgramPair
log = logging.getLogger(__name__)
@six.add_metaclass(abc.ABCMeta)
class Exporter(object):
"""base exporter"""
@abc.abstractmethod
def export(self, exe, program, eval_result, state):
"""export"""
raise NotImplementedError()
class BestExporter(Exporter):
"""export saved model accordingto `cmp_fn`"""
def __init__(self, export_dir, cmp_fn):
"""doc"""
self._export_dir = export_dir
self._best = None
self.cmp_fn = cmp_fn
def export(self, exe, program, eval_model_spec, eval_result, state):
"""doc"""
log.debug('New evaluate result: %s \nold: %s' % (repr(eval_result),
repr(self._best)))
if self._best is None and state['best_model'] is not None:
self._best = state['best_model']
log.debug('restoring best state %s' % repr(self._best))
if self._best is None or self.cmp_fn(old=self._best, new=eval_result):
log.debug('[Best Exporter]: export to %s' % self._export_dir)
eval_program = program.train_program
            # FIXME: all eval datasets have the same name/types/shapes for now, so every eval program is the same
saver = Saver(
self._export_dir, exe, program=program, max_ckpt_to_keep=1)
saver.save(state)
eval_result = map_structure(float, eval_result)
self._best = eval_result
state['best_model'] = eval_result
else:
log.debug('[Best Exporter]: skip step %s' % state.gstep)
class BestInferenceModelExporter(Exporter):
"""export inference model accordingto `cmp_fn`"""
def __init__(self,
export_dir,
cmp_fn,
model_class_or_model_fn=None,
hparams=None,
dataset=None):
"""doc"""
self._export_dir = export_dir
self._best = None
self.cmp_fn = cmp_fn
self.model_class_or_model_fn = model_class_or_model_fn
self.hparams = hparams
self.dataset = dataset
def export(self, exe, program, eval_model_spec, eval_result, state):
"""doc"""
if self.model_class_or_model_fn is not None and self.hparams is not None \
and self.dataset is not None:
log.info('Building program by user defined model function')
if issubclass(self.model_class_or_model_fn, Model):
_model_fn = _build_model_fn(self.model_class_or_model_fn)
elif inspect.isfunction(self.model_class_or_model_fn):
_model_fn = self.model_class_or_model_fn
else:
raise ValueError(
'unknown model %s' % self.model_class_or_model_fn)
# build net
infer_program = F.Program()
startup_prog = F.Program()
with F.program_guard(infer_program, startup_prog):
#share var with Train net
with F.unique_name.guard():
log.info('Building Infer Graph')
infer_fea = self.dataset.features()
# run_config is None
self.model_spec = _build_net(_model_fn, infer_fea,
RunMode.PREDICT, self.hparams,
None)
log.info('Done')
infer_program = infer_program.clone(for_test=True)
self.program = ProgramPair(
train_program=infer_program, startup_program=startup_prog)
else:
self.program = program
self.model_spec = eval_model_spec
if self._best is None and state['best_inf_model'] is not None:
self._best = state['best_inf_model']
log.debug('restoring best state %s' % repr(self._best))
log.debug('New evaluate result: %s \nold: %s' % (repr(eval_result),
repr(self._best)))
if self._best is None or self.cmp_fn(old=self._best, new=eval_result):
log.debug('[Best Exporter]: export to %s' % self._export_dir)
if self.model_spec.inference_spec is None:
                raise ValueError('model_fn did not return an InferenceSpec')
inf_spec_dict = self.model_spec.inference_spec
if not isinstance(inf_spec_dict, dict):
inf_spec_dict = {'inference': inf_spec_dict}
for inf_spec_name, inf_spec in six.iteritems(inf_spec_dict):
if not isinstance(inf_spec, InferenceSpec):
raise ValueError(
                        'unknown inference spec type: %s' % inf_spec)
save_dir = os.path.join(self._export_dir, inf_spec_name)
log.debug('[Best Exporter]: save inference model: "%s" to %s' %
(inf_spec_name, save_dir))
feed_var = [i.name for i in inf_spec.inputs]
fetch_var = inf_spec.outputs
infer_program = self.program.train_program
startup_prog = F.Program()
F.io.save_inference_model(
save_dir,
feed_var,
fetch_var,
exe,
main_program=infer_program)
eval_result = map_structure(float, eval_result)
state['best_inf_model'] = eval_result
self._best = eval_result
else:
log.debug('[Best Exporter]: skip step %s' % state.gstep)
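# Illustrative usage sketch: BestExporter decides whether to export by calling
# cmp_fn(old=<previous best eval result>, new=<current eval result>). The
# metric key 'acc' below is only an assumption about what a model_fn reports.
if __name__ == '__main__':
    def _acc_improved(old, new):
        return float(new['acc']) > float(old['acc'])
    demo_exporter = BestExporter(export_dir='./best_model', cmp_fn=_acc_improved)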
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""train hooks"""
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
import six
import os
import itertools
import numpy as np
import logging
import paddle.fluid as F
import paddle.fluid.layers as L
from ernie_gen.propeller import util
from ernie_gen.propeller.paddle.train import distribution
from ernie_gen.propeller.paddle.train.metrics import Metrics
__all__ = [
'RunHook', 'TqdmProgressBarHook', 'TqdmNotebookProgressBarHook',
'CheckpointSaverHook', 'LoggingHook', 'StopAtStepHook', 'EvalHook'
]
log = logging.getLogger(__name__)
class RunHook(object):
"""RunHook Base class"""
def __init__(self):
"""doc"""
pass
def before_train(self, program):
"""doc"""
pass
def before_run(self, state):
"""doc"""
return []
def after_run(self, res_list, state):
"""doc"""
pass
def should_stop(self, state):
"""doc"""
return False
def after_train(self):
"""doc"""
pass
class TqdmProgressBarHook(RunHook):
"""show a progress bar when training"""
def __init__(self, max_steps, desc=None):
"""doc"""
self.tqdm = None
import tqdm
from propeller import log as main_log
hdl = main_log.handlers[0]
        class _TqdmLoggingHandler(logging.Handler):
def emit(self, record):
"""doc"""
try:
msg = self.format(record)
tqdm.tqdm.write(msg, file=sys.stderr)
self.flush()
except (KeyboardInterrupt, SystemExit) as e:
raise e
except:
self.handleError(record)
        tqdm_hdl = _TqdmLoggingHandler()
tqdm_hdl.setFormatter(hdl.formatter)
main_log.removeHandler(hdl)
main_log.addHandler(tqdm_hdl)
        self.tqdm = tqdm.tqdm(total=max_steps, desc=desc)
def before_run(self, state):
self.tqdm.n = state.gstep
return []
def __del__(self):
if self.tqdm:
self.tqdm.close()
class TqdmNotebookProgressBarHook(RunHook):
"""show a progress bar when training"""
def __init__(self, max_steps, desc=None):
"""doc"""
self.tqdm = None
import tqdm
from propeller import log as main_log
hdl = main_log.handlers[0]
        class _TqdmLoggingHandler(logging.Handler):
def emit(self, record):
"""doc"""
try:
msg = self.format(record)
tqdm.tqdm.write(msg, file=sys.stderr)
self.flush()
except (KeyboardInterrupt, SystemExit) as e:
raise e
except:
self.handleError(record)
        tqdm_hdl = _TqdmLoggingHandler()
tqdm_hdl.setFormatter(hdl.formatter)
main_log.removeHandler(hdl)
main_log.addHandler(tqdm_hdl)
        self.tqdm = tqdm.tqdm_notebook(total=max_steps, desc=desc)
def before_run(self, state):
"""doc"""
self.tqdm.n = state.gstep
self.tqdm.refresh()
return []
def __del__(self):
"""doc"""
if self.tqdm:
self.tqdm.close()
class LoggingHook(RunHook):
"""log tensor in to screan and tensorboard"""
def __init__(self,
loss,
per_step=10,
skip_step=100,
summary_writer=None,
summary_record=None):
"""doc"""
if per_step is None or skip_step is None:
            raise ValueError('wrong step argument, per_step: %s skip_step: %s' %
(per_step, skip_step))
self.loss = loss
self.per_step = per_step
self.skip_step = skip_step
self.summary_record = summary_record
self.writer = summary_writer
self.last_state = None
def before_train(self, program):
"""doc"""
if self.summary_record:
if self.summary_record.scalar:
self.s_name, self.s_tolog = zip(*self.summary_record.scalar)
else:
self.s_name, self.s_tolog = [], []
if self.summary_record.histogram:
self.h_name, self.h_tolog = zip(*self.summary_record.histogram)
else:
self.h_name, self.h_tolog = [], []
def before_run(self, state):
"""doc"""
if state.gstep % self.per_step == 0 and state.step > self.skip_step:
ret = [self.loss]
if self.summary_record:
ret += self.s_tolog
ret += self.h_tolog
return ret
else:
return []
def after_run(self, res_list, state):
"""doc"""
if state.gstep % self.per_step == 0 and state.step > self.skip_step:
if not self.summary_record:
return
loss = float(res_list[0])
s_np = res_list[1:1 + len(self.s_name)]
h_np = res_list[1 + len(self.s_name):1 + len(self.s_name) +
len(self.h_name)]
if self.last_state is not None:
speed = (state.gstep - self.last_state.gstep) / (
state.time - self.last_state.time)
else:
speed = -1.
self.last_state = state
# log to tensorboard
if self.writer is not None:
self.writer.add_scalar('loss', loss, state.gstep)
for name, t in zip(self.s_name, s_np):
if np.isnan(t).any():
log.warning('Nan summary: %s, skip' % name)
else:
self.writer.add_scalar(name, t, state.gstep)
for name, t in zip(self.h_name, h_np):
if np.isnan(t).any():
log.warning('Nan summary: %s, skip' % name)
else:
self.writer.add_histogram(name, t, state.gstep)
if speed > 0.:
self.writer.add_scalar('global_step', speed, state.gstep)
# log to stdout
log.debug('\t'.join([
'step: %d' % state.gstep,
'steps/sec: %.5f' % speed,
'loss: %.5f' % loss,
'' if self.summary_record is None else ' '.join(
map(lambda t: '%s:%s' % t, zip(self.s_name, s_np))),
]))
class StopAtStepHook(RunHook):
"""stop training at some step"""
def __init__(self, stop_global_step, stop_step):
"""doc"""
self._stop_gstep = stop_global_step
self._stop_step = stop_step
def should_stop(self, state):
"""doc"""
if (self._stop_gstep and state.gstep >= self._stop_gstep) or \
(self._stop_step and state.step >= self._stop_step):
log.info('StopAtStepHook called stop')
return True
else:
return False
class EvalHook(RunHook):
"""hook this on a eval Executor"""
def __init__(self, metrics, summary_writer=None):
"""doc"""
self.writer = summary_writer
self._result = None
if not isinstance(metrics, dict):
raise ValueError('metrics should be dict, got %s' % repr(metrics))
for k, m in six.iteritems(metrics):
if not isinstance(m, Metrics):
raise ValueError(
'metrics %s should be instance of propeller.Metrics, got %s'
% (k, repr(m)))
if len(metrics):
self.names = list(metrics.keys())
self.metrics = list(metrics.values())
else:
self.names, self.metrics = [], []
def before_train(self, program):
"""doc"""
for m in self.metrics:
m.reset()
def before_run(self, state):
"""doc"""
ls = [m.tensor for m in self.metrics]
for i in ls:
if not (isinstance(i, list) or isinstance(i, tuple)):
raise ValueError(
'metrics should return tuple or list of tensors, got %s' %
repr(i))
for ii in i:
if not isinstance(ii, F.framework.Variable):
raise ValueError(
                        'metrics tensor should be a paddle Variable, got %s of type %s'
% (repr(ii), type(ii)))
ls_flt, self.schema = util.flatten(ls)
#log.debug(ls_flt)
return ls_flt
def after_run(self, res_list, state):
"""doc"""
res = util.unflatten(res_list, self.schema)
for r, m in zip(res, self.metrics):
m.update(r)
@property
def result(self):
"""doc"""
return self._result
def after_train(self):
"""doc"""
printable = []
self._result = {}
for n, m in zip(self.names, self.metrics):
val = m.eval()
self._result[n] = val
return self.result
class CheckpointSaverHook(RunHook):
"""Save checkpoint every n step"""
def __init__(self, saver, per_step=10, skip_step=100):
"""doc"""
self.saver = saver
self.per_step = per_step
self.skip_step = skip_step
def after_run(self, res_list, state):
"""doc"""
if state.gstep % self.per_step == 0 and \
state.step > self.skip_step:
self.saver.save(state)
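# Illustrative usage sketch: hooks are composed into a list and handed to the
# training executor (see MonitoredExecutor elsewhere in propeller). `loss_var`
# and `ckpt_saver` stand in for the loss Variable and Saver built by the
# trainer; None is used here only so the sketch constructs cleanly.
if __name__ == '__main__':
    loss_var, ckpt_saver = None, None
    demo_hooks = [
        LoggingHook(loss_var, per_step=10, skip_step=100),
        CheckpointSaverHook(ckpt_saver, per_step=1000, skip_step=100),
        StopAtStepHook(stop_global_step=100000, stop_step=None),
    ]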
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""predefined metrics"""
import sys
import os
import six
import numpy as np
import itertools
import logging
import paddle.fluid as F
import paddle.fluid.layers as L
import sklearn.metrics
log = logging.getLogger(__name__)
__all__ = [
'Metrics', 'F1', 'Recall', 'Precision', 'Mrr', 'Mean', 'Acc', 'ChunkF1',
'RecallAtPrecision'
]
class Metrics(object):
"""Metrics base class"""
def __init__(self):
"""doc"""
self.saver = []
@property
def tensor(self):
"""doc"""
pass
def update(self, *args):
"""doc"""
pass
def eval(self):
"""doc"""
pass
class Mean(Metrics):
"""doc"""
def __init__(self, t):
"""doc"""
self.t = t
self.reset()
def reset(self):
"""doc"""
self.saver = np.array([])
@property
def tensor(self):
"""doc"""
return self.t,
def update(self, args):
"""doc"""
t, = args
t = t.reshape([-1])
self.saver = np.concatenate([self.saver, t])
def eval(self):
"""doc"""
return self.saver.mean()
class Ppl(Mean):
"""doc"""
def eval(self):
"""doc"""
return np.exp(self.saver.mean())
class Acc(Mean):
"""doc"""
def __init__(self, label, pred):
"""doc"""
if label.shape != pred.shape:
raise ValueError(
'expect label shape == pred shape, got: label.shape=%s, pred.shape = %s'
% (repr(label), repr(pred)))
self.eq = L.equal(pred, label)
self.reset()
@property
def tensor(self):
"""doc"""
return self.eq,
class MSE(Mean):
"""doc"""
def __init__(self, label, pred):
"""doc"""
if label.shape != pred.shape:
raise ValueError(
'expect label shape == pred shape, got: label.shape=%s, pred.shape = %s'
% (repr(label), repr(pred)))
diff = pred - label
self.mse = diff * diff
self.reset()
@property
def tensor(self):
"""doc"""
return self.mse,
class Cosine(Mean):
"""doc"""
def __init__(self, label, pred):
"""doc"""
if label.shape != pred.shape:
raise ValueError(
'expect label shape == pred shape, got: label.shape=%s, pred.shape = %s'
% (repr(label), repr(pred)))
self.cos = L.cos_sim(label, pred)
self.reset()
@property
def tensor(self):
"""doc"""
return self.cos,
class MacroF1(Metrics):
"""doc"""
def __init__(self, label, pred):
"""doc"""
if label.shape != pred.shape:
raise ValueError(
'expect label shape == pred shape, got: label.shape=%s, pred.shape = %s'
% (repr(label), repr(pred)))
self.label = label
self.pred = pred
self.reset()
def reset(self):
"""doc"""
self.label_saver = np.array([], dtype=np.bool)
self.pred_saver = np.array([], dtype=np.bool)
@property
def tensor(self):
"""doc"""
return self.label, self.pred
def update(self, args):
"""doc"""
label, pred = args
label = label.reshape([-1]).astype(np.bool)
pred = pred.reshape([-1]).astype(np.bool)
if label.shape != pred.shape:
raise ValueError(
                'MacroF1 update: input shapes do not match: label:%s pred:%s' %
                (label, pred))
self.label_saver = np.concatenate([self.label_saver, label])
self.pred_saver = np.concatenate([self.pred_saver, pred])
def eval(self):
"""doc"""
return sklearn.metrics.f1_score(
self.label_saver, self.pred_saver, average='macro')
class Precision(Metrics):
"""doc"""
def __init__(self, label, pred):
"""doc"""
if label.shape != pred.shape:
raise ValueError(
'expect label shape == pred shape, got: label.shape=%s, pred.shape = %s'
% (repr(label), repr(pred)))
self.label = label
self.pred = pred
self.reset()
def reset(self):
"""doc"""
self.label_saver = np.array([], dtype=np.bool)
self.pred_saver = np.array([], dtype=np.bool)
@property
def tensor(self):
"""doc"""
return self.label, self.pred
def update(self, args):
"""doc"""
label, pred = args
label = label.reshape([-1]).astype(np.bool)
pred = pred.reshape([-1]).astype(np.bool)
if label.shape != pred.shape:
raise ValueError(
                'Metrics update: input shapes do not match: label:%s pred:%s' %
                (label, pred))
self.label_saver = np.concatenate([self.label_saver, label])
self.pred_saver = np.concatenate([self.pred_saver, pred])
def eval(self):
"""doc"""
tp = (self.label_saver & self.pred_saver).astype(np.int64).sum()
p = self.pred_saver.astype(np.int64).sum()
return tp / p
class Recall(Precision):
"""doc"""
def eval(self):
"""doc"""
tp = (self.label_saver & self.pred_saver).astype(np.int64).sum()
t = (self.label_saver).astype(np.int64).sum()
return tp / t
class F1(Precision):
"""doc"""
def eval(self):
"""doc"""
tp = (self.label_saver & self.pred_saver).astype(np.int64).sum()
t = self.label_saver.astype(np.int64).sum()
p = self.pred_saver.astype(np.int64).sum()
precision = tp / (p + 1.e-6)
recall = tp / (t + 1.e-6)
return 2 * precision * recall / (precision + recall + 1.e-6)
class Auc(Metrics):
"""doc"""
def __init__(self, label, pred):
"""doc"""
if label.shape != pred.shape:
raise ValueError(
'expect label shape == pred shape, got: label.shape=%s, pred.shape = %s'
% (repr(label), repr(pred)))
self.pred = pred
self.label = label
self.reset()
def reset(self):
"""doc"""
self.pred_saver = np.array([], dtype=np.float32)
self.label_saver = np.array([], dtype=np.bool)
@property
def tensor(self):
"""doc"""
return [self.pred, self.label]
def update(self, args):
"""doc"""
pred, label = args
pred = pred.reshape([-1]).astype(np.float32)
label = label.reshape([-1]).astype(np.bool)
self.pred_saver = np.concatenate([self.pred_saver, pred])
self.label_saver = np.concatenate([self.label_saver, label])
def eval(self):
"""doc"""
fpr, tpr, thresholds = sklearn.metrics.roc_curve(
self.label_saver.astype(np.int64), self.pred_saver)
auc = sklearn.metrics.auc(fpr, tpr)
return auc
class RecallAtPrecision(Auc):
"""doc"""
def __init__(self, label, pred, precision=0.9):
"""doc"""
super(RecallAtPrecision, self).__init__(label, pred)
self.precision = precision
def eval(self):
"""doc"""
self.pred_saver = self.pred_saver.reshape([self.label_saver.size,
-1])[:, -1]
precision, recall, thresholds = sklearn.metrics.precision_recall_curve(
self.label_saver, self.pred_saver)
for p, r in zip(precision, recall):
if p > self.precision:
return r
class PrecisionAtThreshold(Auc):
"""doc"""
def __init__(self, label, pred, threshold=0.5):
"""doc"""
super().__init__(label, pred)
self.threshold = threshold
def eval(self):
"""doc"""
infered = self.pred_saver > self.threshold
correct_num = np.array(infered & self.label_saver).sum()
infer_num = infered.sum()
return correct_num / (infer_num + 1.e-6)
class Mrr(Metrics):
"""doc"""
def __init__(self, qid, label, pred):
"""doc"""
if label.shape != pred.shape:
raise ValueError(
'expect label shape == pred shape, got: label.shape=%s, pred.shape = %s'
% (repr(label), repr(pred)))
self.qid = qid
self.label = label
self.pred = pred
self.reset()
def reset(self):
"""doc"""
self.qid_saver = np.array([], dtype=np.int64)
self.label_saver = np.array([], dtype=np.int64)
self.pred_saver = np.array([], dtype=np.float32)
@property
def tensor(self):
"""doc"""
return [self.qid, self.label, self.pred]
def update(self, args):
"""doc"""
qid, label, pred = args
if not (qid.shape[0] == label.shape[0] == pred.shape[0]):
raise ValueError(
                'Mrr dimension mismatch: qid[%s] label[%s], pred[%s]' %
(qid.shape, label.shape, pred.shape))
self.qid_saver = np.concatenate(
[self.qid_saver, qid.reshape([-1]).astype(np.int64)])
self.label_saver = np.concatenate(
[self.label_saver,
label.reshape([-1]).astype(np.int64)])
self.pred_saver = np.concatenate(
[self.pred_saver,
pred.reshape([-1]).astype(np.float32)])
def eval(self):
"""doc"""
def _key_func(tup):
return tup[0]
def _calc_func(tup):
ranks = [
1. / (rank + 1.) for rank, (_, l, p) in enumerate(
sorted(tup, key=lambda t: t[2], reverse=True)) if l != 0
]
if len(ranks):
return ranks[0]
else:
return 0.
mrr_for_qid = [
_calc_func(tup) for _, tup in itertools.groupby(
sorted(
zip(self.qid_saver, self.label_saver, self.pred_saver),
key=_key_func),
key=_key_func)
]
mrr = np.float32(sum(mrr_for_qid) / len(mrr_for_qid))
return mrr
class ChunkF1(Metrics):
"""doc"""
def __init__(self, label, pred, seqlen, num_label):
"""doc"""
self.label = label
self.pred = pred
self.seqlen = seqlen
self.null_index = num_label - 1
self.label_cnt = 0
self.pred_cnt = 0
self.correct_cnt = 0
def _extract_bio_chunk(self, seq):
chunks = []
cur_chunk = None
for index in range(len(seq)):
tag = seq[index]
tag_type = tag // 2
tag_pos = tag % 2
if tag == self.null_index:
if cur_chunk is not None:
chunks.append(cur_chunk)
cur_chunk = None
continue
if tag_pos == 0:
if cur_chunk is not None:
chunks.append(cur_chunk)
cur_chunk = {}
cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
else:
if cur_chunk is None:
cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
continue
if cur_chunk["type"] == tag_type:
cur_chunk["en"] = index + 1
else:
chunks.append(cur_chunk)
cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
if cur_chunk is not None:
chunks.append(cur_chunk)
return chunks
def reset(self):
"""doc"""
self.label_cnt = 0
self.pred_cnt = 0
self.correct_cnt = 0
@property
def tensor(self):
"""doc"""
return [self.pred, self.label, self.seqlen]
def update(self, args):
"""doc"""
pred, label, seqlen = args
pred = pred.reshape([-1]).astype(np.int32).tolist()
label = label.reshape([-1]).astype(np.int32).tolist()
seqlen = seqlen.reshape([-1]).astype(np.int32).tolist()
max_len = 0
for l in seqlen:
max_len = max(max_len, l)
for i in range(len(seqlen)):
seq_st = i * max_len + 1
seq_en = seq_st + (seqlen[i] - 2)
pred_chunks = self._extract_bio_chunk(pred[seq_st:seq_en])
label_chunks = self._extract_bio_chunk(label[seq_st:seq_en])
self.pred_cnt += len(pred_chunks)
self.label_cnt += len(label_chunks)
pred_index = 0
label_index = 0
while label_index < len(label_chunks) and pred_index < len(
pred_chunks):
if pred_chunks[pred_index]['st'] < label_chunks[label_index][
'st']:
pred_index += 1
elif pred_chunks[pred_index]['st'] > label_chunks[label_index][
'st']:
label_index += 1
else:
if pred_chunks[pred_index]['en'] == label_chunks[label_index]['en'] \
and pred_chunks[pred_index]['type'] == label_chunks[label_index]['type']:
self.correct_cnt += 1
pred_index += 1
label_index += 1
def eval(self):
"""doc"""
if self.pred_cnt == 0:
precision = 0.0
else:
precision = 1.0 * self.correct_cnt / self.pred_cnt
if self.label_cnt == 0:
recall = 0.0
else:
recall = 1.0 * self.correct_cnt / self.label_cnt
if self.correct_cnt == 0:
f1 = 0.0
else:
f1 = 2 * precision * recall / (precision + recall)
return np.float32(f1)
class PNRatio(Metrics):
"""doc"""
def __init__(self, qid, label, pred):
"""doc"""
if label.shape != pred.shape:
raise ValueError(
'expect label shape == pred shape, got: label.shape=%s, pred.shape = %s'
% (repr(label), repr(pred)))
self.qid = qid
self.label = label
self.pred = pred
self.saver = {}
def reset(self):
"""doc"""
self.saver = {}
@property
def tensor(self):
"""doc"""
return [self.qid, self.label, self.pred]
def update(self, args):
"""doc"""
qid, label, pred = args
if not (qid.shape[0] == label.shape[0] == pred.shape[0]):
            raise ValueError('dimension mismatch: qid[%s] label[%s], pred[%s]'
% (qid.shape, label.shape, pred.shape))
qid = qid.reshape([-1]).tolist()
label = label.reshape([-1]).tolist()
pred = pred.reshape([-1]).tolist()
assert len(qid) == len(label) == len(pred)
for q, l, p in zip(qid, label, pred):
if q not in self.saver:
self.saver[q] = []
self.saver[q].append((l, p))
def eval(self):
"""doc"""
p = 0
n = 0
for qid, outputs in self.saver.items():
for i in range(0, len(outputs)):
l1, p1 = outputs[i]
for j in range(i + 1, len(outputs)):
l2, p2 = outputs[j]
if l1 > l2:
if p1 > p2:
p += 1
elif p1 < p2:
n += 1
elif l1 < l2:
if p1 < p2:
p += 1
elif p1 > p2:
n += 1
pn = p / n if n > 0 else 0.0
return np.float32(pn)
class BinaryPNRatio(PNRatio):
"""doc"""
def __init__(self, qid, label, pred):
"""doc"""
super(BinaryPNRatio, self).__init__(qid, label, pred)
def eval(self):
"""doc"""
p = 0
n = 0
for qid, outputs in self.saver.items():
pos_set = []
neg_set = []
for label, score in outputs:
if label == 1:
pos_set.append(score)
else:
neg_set.append(score)
for ps in pos_set:
for ns in neg_set:
if ps > ns:
p += 1
elif ps < ns:
n += 1
else:
continue
pn = p / n if n > 0 else 0.0
return np.float32(pn)
class PrecisionAtK(Metrics):
"""doc"""
def __init__(self, qid, label, pred, k=1):
"""doc"""
if label.shape != pred.shape:
raise ValueError(
'expect label shape == pred shape, got: label.shape=%s, pred.shape = %s'
% (repr(label), repr(pred)))
self.qid = qid
self.label = label
self.pred = pred
self.k = k
self.saver = {}
def reset(self):
"""doc"""
self.saver = {}
@property
def tensor(self):
"""doc"""
return [self.qid, self.label, self.pred]
def update(self, args):
"""doc"""
qid, label, pred = args
if not (qid.shape[0] == label.shape[0] == pred.shape[0]):
            raise ValueError('dimension mismatch: qid[%s] label[%s], pred[%s]'
% (qid.shape, label.shape, pred.shape))
qid = qid.reshape([-1]).tolist()
label = label.reshape([-1]).tolist()
pred = pred.reshape([-1]).tolist()
assert len(qid) == len(label) == len(pred)
for q, l, p in zip(qid, label, pred):
if q not in self.saver:
self.saver[q] = []
self.saver[q].append((l, p))
def eval(self):
"""doc"""
right = 0
total = 0
for v in self.saver.values():
v = sorted(v, key=lambda x: x[1], reverse=True)
k = min(self.k, len(v))
for i in range(k):
if v[i][0] == 1:
right += 1
break
total += 1
return np.float32(1.0 * right / total)
#class SemanticRecallMetrics(Metrics):
# def __init__(self, qid, vec, type_id):
# self.qid = qid
# self.vec = vec
# self.type_id = type_id
# self.reset()
#
# def reset(self):
# self.saver = []
#
# @property
# def tensor(self):
# return [self.qid, self.vec, self.type_id]
#
# def update(self, args):
# qid, vec, type_id = args
# self.saver.append((qid, vec, type_id))
#
# def eval(self):
# dic = {}
# for qid, vec, type_id in self.saver():
# dic.setdefault(i, {}).setdefault(k, []).append(vec)
#
# for qid in dic:
# assert len(dic[qid]) == 3
# qvec = np.arrray(dic[qid][0])
# assert len(qvec) == 1
# ptvec = np.array(dic[qid][1])
# ntvec = np.array(dic[qid][2])
#
# np.matmul(qvec, np.transpose(ptvec))
# np.matmul(qvec, np.transpose(ntvec))
#
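# Illustrative usage sketch: a Metrics object wraps graph Variables (exposed
# through `tensor` for the eval fetch list) and accumulates the fetched numpy
# values with update()/eval(). Shapes and values below are arbitrary.
if __name__ == '__main__':
    demo_prog, demo_startup = F.Program(), F.Program()
    with F.program_guard(demo_prog, demo_startup):
        demo_label = L.data(
            'demo_label', shape=[-1, 1], dtype='int64', append_batch_size=False)
        demo_pred = L.data(
            'demo_pred', shape=[-1, 1], dtype='int64', append_batch_size=False)
        demo_f1 = F1(demo_label, demo_pred)  # demo_f1.tensor goes in the fetch list
    # after each eval step, feed the fetched numpy arrays back into the metric
    demo_f1.update((np.array([1, 0, 1, 1]), np.array([1, 0, 0, 1])))
    log.info('demo F1: %s' % demo_f1.eval())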
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
doc
"""
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import os
import sys
import json
from functools import reduce
import six
from time import time
import shutil
import logging
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
from ernie_gen.propeller import util
from ernie_gen.propeller.types import StopException, ProgramPair, WarmStartSetting, TextoneWarmStartSetting
from ernie_gen.propeller.paddle.train import hooks
from . import distribution
log = logging.getLogger(__name__)
__all__ = ['MonitoredExecutor', 'Saver']
def _get_one_place():
return F.cuda_places()[0] if F.core.is_compiled_with_cuda(
) else F.cpu_places()[0]
class RunState(object):
"""serializable Run state object"""
@classmethod
def from_dict(cls, d):
d['step'] = 0
r = RunState()
r.__dict__ = d
return r
@classmethod
def from_str(cls, s):
"""doc"""
j = json.loads(s)
return cls.from_dict(j)
def __init__(self):
"""doc"""
self.__dict__ = {'gstep': 0, 'step': 0, 'time': time()}
@property
def gstep(self):
"""doc"""
return self.__dict__.get(
'gstep', self.__dict__.get('global_step')) # backward compatibility
@property
def step(self):
"""doc"""
return self.__dict__['step']
def __setitem__(self, k, v):
self.__dict__[k] = v
def __getitem__(self, k):
return self.__dict__.get(k, None)
@property
def time(self):
"""doc"""
return self.__dict__['time']
def state_dict(self):
return self.__dict__
def __repr__(self):
"""doc"""
return repr(self.state_dict())
def serialize(self):
"""doc"""
return json.dumps(self.state_dict())
def next(self):
"""doc"""
newd = dict(
self.__dict__,
gstep=self.gstep + 1,
step=self.step + 1,
time=time())
ret = RunState()
ret.__dict__ = newd
return ret
class Saver(object):
"""checkpoint saver and manager"""
def __init__(self,
save_dir,
exe,
program,
save_prefix='model',
max_ckpt_to_keep=None):
"""doc"""
assert isinstance(
exe, F.Executor
), 'expect normal executor to save, got executor of type %s' % repr(
type(exe))
self._exe = exe
self._program = program
self._save_dir = save_dir
self._save_prefix = save_prefix
self._max_ckpt_to_keep = 10 if max_ckpt_to_keep is None else max_ckpt_to_keep
self.ckpt_info_path = os.path.join(save_dir, 'ckpt_info')
if os.path.exists(self.ckpt_info_path):
self.ckpt_list = [
p.strip() for p in open(self.ckpt_info_path).readlines()
]
log.debug('ckpt_list in this Saver: %s' % (self.ckpt_list))
else:
self.ckpt_list = []
@property
def last_ckpt(self):
"""doc"""
return self.ckpt_list[-1] if len(self.ckpt_list) else None
def _save_program(self, dir):
F.io.save_persistables(self._exe, dir, self._program.train_program)
def _load_program(self, dir, predicate_fn=None):
if predicate_fn is None:
def _fn(v):
vpath = os.path.join(dir, v.name)
if F.io.is_persistable(v):
if os.path.exists(vpath):
return True
else:
log.warning(
'var %s not found in checkpoint, ignored' % v.name)
return False
predicate_fn = _fn
try:
F.io.load_vars(
self._exe,
dir,
main_program=self._program.train_program,
predicate=predicate_fn)
except F.core.EnforceNotMet as e:
log.exception(e)
raise RuntimeError(
'can not load model from %s, is this a textone checkpoint?' %
dir)
def save(self, state):
"""doc"""
save_name = '%s_%d' % (self._save_prefix, state.gstep)
save_dir = os.path.join(self._save_dir, save_name)
tmp_dir = os.path.join(self._save_dir, 'tmp')
try:
shutil.rmtree(save_dir)
shutil.rmtree(tmp_dir)
except OSError:
pass
log.debug('saving step %d to %s' % (state.gstep, save_dir))
self._save_program(tmp_dir)
shutil.move(tmp_dir, save_dir)
meta = state.serialize()
open(os.path.join(save_dir, 'meta'), 'w').write(meta)
self.ckpt_list.append(save_name)
if len(self.ckpt_list) > self._max_ckpt_to_keep:
ckpt_to_keep = self.ckpt_list[-self._max_ckpt_to_keep:]
ckpt_to_remove = set(self.ckpt_list) - set(ckpt_to_keep)
self.ckpt_list = ckpt_to_keep
for ckpt in ckpt_to_remove:
ckpt_dir = os.path.join(self._save_dir, ckpt)
if os.path.exists(ckpt_dir):
shutil.rmtree(ckpt_dir)
log.debug('No. of ckpt exceed %d, clean up: %s' %
(self._max_ckpt_to_keep, ckpt_dir))
open(self.ckpt_info_path, 'w').write('\n'.join(self.ckpt_list))
def restore(self, ckpt=-1):
"""doc"""
if isinstance(ckpt, int):
try:
path = os.path.join(self._save_dir, self.ckpt_list[ckpt])
except IndexError:
raise ValueError('invalid restore ckpt number %d' % ckpt)
elif isinstance(ckpt, six.string_types):
if not os.path.exists(ckpt):
raise ValueError('ckpt: %s not found' % ckpt)
path = ckpt
else:
raise ValueError('ckpt type not understood %s' % repr(ckpt))
meta_file = os.path.join(path, 'meta')
if not os.path.exists(meta_file):
raise RuntimeError('meta not found in restore dir: %s' % path)
state = RunState.from_str(open(meta_file).read())
log.info('restore from ckpt %s, ckpt-status: %s' % (path, repr(state)))
self._load_program(path)
return state
class SaverV2(Saver):
def _save_program(self, dir):
save_path = os.path.join(dir, 'ckpt')
F.save(self._program.train_program, save_path)
def _load_program(self, dir, predicate_fn=None):
try:
save_path = os.path.join(dir, 'ckpt')
F.load(
self._program.train_program,
save_path,
)
except F.core.EnforceNotMet as e:
log.exception(e)
raise RuntimeError(
'can not load model from %s, is this a textone checkpoint?' %
dir)
TextoneTrainer = None
class MonitoredExecutor(object):
"""An Executor wrapper handling the train loop"""
saver_class = SaverV2 # will change if textone enabled
def __init__(
self,
executor,
program,
loss=None, #must set in train
state=None,
run_config=None, #none if not load
run_hooks=[],
warm_start_setting=None):
if not isinstance(executor, F.Executor):
raise ValueError('PE is no longer supported')
if isinstance(executor, F.ParallelExecutor):
            raise ValueError('ParallelExecutor is deprecated, use Executor')
if not isinstance(program, ProgramPair):
raise ValueError('Expect ProgramPair, got %r' % type(program))
self._exe = executor
self._hooks = run_hooks
self._state = RunState() # might be overwrite in freeze
self._program = program
self._loss = loss
self._warm_start_setting = warm_start_setting
self._saver = None # will set in prepare
self.result = None # will set after train
if run_config is not None:
self._model_dir = run_config.model_dir
self._save_dir = run_config.model_dir
self._save_steps = run_config.save_steps
self._skip_steps = run_config.skip_steps if run_config.skip_steps else 100
self._save_prefix = 'model'
self._max_ckpt = run_config.max_ckpt
@property
def state(self):
"""doc"""
return self._state
def init_or_restore_variables(self, ckpt=-1):
"""
init vars or restore vars from model_dir
call before train
"""
        # The order of these 2 steps really matters
# 1. init train
F.Executor(_get_one_place()).run(self._program.startup_program)
# 2. restore param
self._saver = self.saver_class(
self._model_dir,
F.Executor(_get_one_place()),
program=self._program,
max_ckpt_to_keep=self._max_ckpt)
if self._warm_start_setting is not None:
if not os.path.exists(self._warm_start_setting.from_dir):
raise ValueError('warm start dir not exists: %s' %
self._warm_start_setting.from_dir)
if isinstance(self._warm_start_setting, WarmStartSetting):
log.info(
"warm start from %s" % self._warm_start_setting.from_dir)
log.info(self._saver)
if (not type(self._saver) is Saver) and (
not type(self._saver) is SaverV2):
raise ValueError(
                        'try to warm start from standard dir, but textone enabled'
)
if self._warm_start_setting.predicate_fn is not None:
def _fn(v):
ret = self._warm_start_setting.predicate_fn(v)
if ret:
log.info('warm start: %s' % v.name)
return ret
try:
F.io.load_vars(
self._exe,
self._warm_start_setting.from_dir,
main_program=self._program.train_program,
predicate=_fn)
except F.core.EnforceNotMet as e:
log.exception(e)
raise RuntimeError(
'can not load model from %s, is this a textone checkpoint?'
                            % self._warm_start_setting.from_dir)
else:
raise NotImplementedError()
elif isinstance(self._warm_start_setting, TextoneWarmStartSetting):
if not type(self._saver) is TextoneTrainer:
raise ValueError(
'try to warm start from textone pretrain dir, but textone not enabled'
)
log.info("[texone] warm start from %s" %
self._warm_start_setting.from_dir)
self._saver._load_pretrained(self._warm_start_setting.from_dir)
else:
raise ValueError(
                'expect _warm_start_setting to be TextoneWarmStartSetting or WarmStartSetting, got %s'
% repr(self._warm_start_setting))
if self._saver.last_ckpt is not None:
self._state = self._saver.restore(ckpt)
def _freeze(self):
"""
        call before entering the train loop
convert program to compiled program
will do nothing if loss is None i.e. not in train mode
"""
if self._loss is None:
log.debug('will not freeze a program without loss')
return
if isinstance(self._program.train_program, F.compiler.CompiledProgram):
log.debug('program has already been built')
return
exec_strategy = F.ExecutionStrategy()
exec_strategy.num_threads = 4 #2 for fp32 4 for fp16
exec_strategy.use_experimental_executor = True
        exec_strategy.num_iteration_per_drop_scope = 10  # important for performance
build_strategy = F.BuildStrategy()
build_strategy.remove_unnecessary_lock = False
#build_strategy.fuse_broadcast_ops = True
build_strategy.num_trainers = distribution.status.num_replica
build_strategy.trainer_id = distribution.status.replica_id
build_strategy.memory_optimize = True
log.info('replica id %d of %d' % (distribution.status.replica_id,
distribution.status.num_replica))
program = F.CompiledProgram(
self._program.train_program).with_data_parallel(
loss_name=self._loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
self._program = ProgramPair(
train_program=program,
startup_program=self._program.startup_program)
def __enter__(self):
"""
        prepare before entering the train loop
"""
if F.core.is_compiled_with_cuda():
log.info('propeller runs in CUDA mode')
else:
log.info('propeller runs in CPU mode')
#log.debug('freezing program')
self._freeze()
#log.debug('done freezing')
log.info('********** Start Loop ************')
# TODO init
self.result = None
for h in self._hooks:
log.debug('train loop has hook %s' % h)
h.before_train(self._program)
return self
def run(self, fetch_list=[], *args, **kwargs):
"""
wrapper for Executor.run
"""
#log.debug('Executor running step %d' % self._state.gstep)
if self._hooks:
fetch_list = [fetch_list]
for h in self._hooks:
#log.debug('calling hook.before_run %s' % h)
fetch = h.before_run(self._state)
fetch_list.append(fetch)
fetch_list_len = map(len, fetch_list)
fetch_list, schema = util.flatten(fetch_list)
fetch_list = [
f.name if not isinstance(f, six.string_types) else f
for f in fetch_list
]
#if len(set(fetch_list)) != len(fetch_list):
        #    log.error('unexpected result when fetch list contains identical tensors %s' % fetch_list)
#log.debug(fetch_list)
res = self._exe.run(
self._program.train_program,
fetch_list=fetch_list,
*args,
**kwargs)
res = [self._merge_result(r) for r in res]
#log.debug(res)
res = util.unflatten(res, schema)
ret, res = res[0], res[1:]
for r, h in zip(res, self._hooks):
#log.debug('calling hook.after_run')
h.after_run(r, self._state)
if any(map(lambda i: i.should_stop(self._state), self._hooks)):
raise StopException('hook call stop')
else:
ret = self._exe.run(
self._program.train_program,
fetch_list=fetch_list,
*args,
**kwargs)
self._state = self._state.next()
return ret
def __exit__(self, err_type, err_value, trace):
"""
        clean up and report hook results when exiting the train loop
"""
if (err_type is None) or isinstance(
err_value,
(F.core.EOFException, StopException, KeyboardInterrupt)):
try:
log.info('********** Stop Loop ************')
self.result = []
for h in self._hooks:
self.result.append(h.after_train())
except Exception as e:
log.exception('error occur after loop %s' % repr(e))
else:
            log.info('********** Interrupt Loop ************')
log.exception(
'error occur during loop %s: %s' % (err_type, err_value))
def _merge_result(self, ls):
"""
merge results from multi gpu cards
"""
dev_count = len(self._program.train_program._places) if isinstance(
self._program.train_program, F.compiler.CompiledProgram) else 1
if dev_count == 1:
return ls
else:
shape = (-1, ls.shape[0] // dev_count) + ls.shape[1:]
ret = np.reshape(ls, shape).mean(axis=0)
return ret
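# Illustrative usage sketch: how a trainer drives MonitoredExecutor. The
# program pair, loss Variable, run_config and data reader are stand-ins for
# objects produced by the surrounding trainer code, so the sketch is commented.
#
#   exe = F.Executor(_get_one_place())
#   mon_exe = MonitoredExecutor(
#       exe,
#       ProgramPair(train_program=train_prog, startup_program=startup_prog),
#       loss=loss_var,
#       run_config=run_config,
#       run_hooks=[hooks.StopAtStepHook(stop_global_step=10000, stop_step=None)])
#   mon_exe.init_or_restore_variables()
#   with mon_exe:                      # calls before_train on every hook
#       for batch in train_reader():
#           mon_exe.run(feed=batch)    # hooks may raise StopException to end the loop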
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""server"""
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package interface;
service Inference {
rpc Infer(Slots) returns (Slots){}
}
message Slots {
repeated Slot slots = 1;
}
message Slot {
enum Type {
// Pod Types
BOOL = 0;
INT16 = 1;
INT32 = 2;
INT64 = 3;
FP16 = 4;
FP32 = 5;
FP64 = 6;
// Tensor<size_t> is used in C++.
SIZE_T = 19;
UINT8 = 20;
INT8 = 21;
}
Type type = 1;
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
bytes data = 3;
}
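// Illustrative sketch: how a float32 tensor maps onto the Slot message above.
// `interface_pb2` stands for the Python module protoc would generate from this
// file; its exact import path inside the repo is an assumption.
//
//   import numpy as np
//   import interface_pb2
//   arr = np.random.rand(2, 640, 480).astype('float32')
//   slot = interface_pb2.Slot(
//       type=interface_pb2.Slot.FP32,
//       dims=list(arr.shape),   # a variable batch dim would be stored as -1
//       data=arr.tobytes())
//   request = interface_pb2.Slots(slots=[slot])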
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
doc
"""