Commit 660e9ea6, author: chenxuyi

+ propeller

Parent: cce162f8
[简体中文](./README.md)|English
# Introducing Propeller
This document introduces Propeller, a high-level Paddle API for general machine learning. Propeller encapsulates the following actions:
- training
- evaluation
- prediction
- export serving
Propeller provides the following benefits:
- You can run Propeller-based models on a local host or in a distributed multi-server environment without changing your model. You can also run them on CPUs or GPUs without recoding your model.
- Propeller simplifies sharing implementations between model developers.
- Propeller handles auxiliary work for you (logging, hot start, and so on).
- Propeller builds the Program and PyReader for you.
- Propeller provides a safe distributed training loop that controls how and when to:
  - build the Program
  - initialize variables
  - create checkpoint files and recover from failures
  - save visualizable results
## Install
```script
cd propeller && pip install .
```
## Getting Started
```python
# Define the model
class BowModel(propeller.Model):
    def __init__(self, config, mode):
        self.embedding = Embedding(config['emb_size'], config['vocab_size'])
        self.fc1 = FC(config['hidden_size'])
        self.fc2 = FC(config['hidden_size'])

    def forward(self, features):
        q, t = features
        q_emb = softsign(self.embedding(q))
        t_emb = softsign(self.embedding(t))
        q_emb = self.fc1(q_emb)
        t_emb = self.fc2(t_emb)
        prediction = dot(q_emb, t_emb)
        return prediction

    def loss(self, predictions, label):
        return sigmoid_cross_entropy_with_logits(predictions, label)

    def backward(self, loss):
        opt = AdamOptimizer(1.e-3)
        opt.minimize(loss)

    def metrics(self, predictions, label):
        auc = atarshi.metrics.Auc(predictions, label)
        return {'auc': auc}

# Hyperparameters come from files, command-line arguments, or environment variables
run_config = propeller.parse_runconfig(args)
hparams = propeller.parse_hparam(args)

# Define the data
# `FeatureColumns` helps you organize training/evaluation files.
feature_column = propeller.data.FeatureColumns(columns=[
    propeller.data.TextColumn('query', vocab='./vocab'),
    propeller.data.TextColumn('title', vocab='./vocab'),
    propeller.data.LabelColumn('label'),
])
train_ds = feature_column.build_dataset(data_dir='./data', shuffle=True, repeat=True)
eval_ds = feature_column.build_dataset(data_dir='./data', shuffle=False, repeat=False)

# Start training!
propeller.train_and_eval(BowModel, hparams, run_config, train_ds, eval_ds)
```
For more details, see example/toy/.
## Main Features
1. train_and_eval
Given a user-specified `propeller.Model` class, Propeller initializes the model in two modes (TRAIN and EVAL)
and then performs training and evaluation.
2. FeatureColumns
`FeatureColumns` is used to organize training data. With customizable `Column` properties, it can adapt to many ML tasks (NLP, CV, ...).
`FeatureColumns` also does the preprocessing for you (tokenization, vocabulary lookup, serialization, batching, etc.).
3. Dataset
`FeatureColumns` generates a `Dataset`, or you can call `propeller.Dataset.from_generator_func` to build your own `Dataset`.
4. Summary
To trace a tensor histogram during training, simply call:
```python
propeller.summary.histogram('loss', tensor)
```
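The idea behind a generator-function-backed `Dataset` can be sketched in plain Python. This is a minimal illustration, not Propeller's implementation: `TinyDataset` and `numbers` are hypothetical names introduced here, and only the standard library is used. Each transformation wraps the previous generator function in a new one, so a pipeline can be iterated more than once.

```python
import itertools


class TinyDataset(object):
    """Hypothetical stand-in for a generator-backed dataset (not Propeller's class)."""

    def __init__(self, generator_func):
        # A zero-argument callable that returns a fresh iterator on each call.
        self.generator_func = generator_func

    def __iter__(self):
        return iter(self.generator_func())

    def map(self, fn):
        source = self.generator_func
        return TinyDataset(lambda: (fn(item) for item in source()))

    def take(self, count):
        source = self.generator_func
        return TinyDataset(lambda: itertools.islice(source(), count))

    def repeat(self, n):
        source = self.generator_func
        # Re-invoke the generator function for every pass instead of caching items.
        return TinyDataset(
            lambda: itertools.chain.from_iterable(source() for _ in range(n)))


def numbers():
    for i in range(3):
        yield i


ds = TinyDataset(numbers).map(lambda x: x * 10).repeat(2).take(5)
result = list(ds)
print(result)  # [0, 10, 20, 0, 10]
```

Because every step stores a callable rather than a live iterator, `list(ds)` returns the same items each time it is called.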
## Contributing
1. This project is in the alpha stage; any contribution is welcome. Feel free to create a PR.
简体中文|[English](./README.en.md)
# Introducing paddle-propeller
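This document introduces Propeller, a high-level Paddle API that greatly simplifies machine learning programming. Propeller encapsulates the following actions:
- training
- evaluation
- prediction
- exporting for serving (deployment)
Propeller provides the following benefits:
- You can run Propeller-based models on a local host or in a distributed multi-server environment without changing your model. You can also run them on CPUs or GPUs without recoding your model.
- Propeller simplifies sharing implementations between model developers.
- You only need to focus on the model implementation and the data input, not on auxiliary code (saving, hot start, logging, etc.).
- Propeller builds the Program and PyReader for you.
- Propeller provides a safe distributed training loop that controls how and when to:
  - build the Program
  - initialize variables
  - handle exceptions
  - create checkpoint files and recover from failures
  - save visualizable summary results
## Install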
cd propeller && pip install .
## Getting Started
```python
# Define the training model
class BowModel(propeller.Model):
    def __init__(self, config, mode):
        self.embedding = Embedding(config['emb_size'], config['vocab_size'])
        self.fc1 = FC(config['hidden_size'])
        self.fc2 = FC(config['hidden_size'])

    def forward(self, features):
        q, t = features
        q_emb = softsign(self.embedding(q))
        t_emb = softsign(self.embedding(t))
        q_emb = self.fc1(q_emb)
        t_emb = self.fc2(t_emb)
        prediction = dot(q_emb, t_emb)
        return prediction

    def loss(self, predictions, label):
        return sigmoid_cross_entropy_with_logits(predictions, label)

    def backward(self, loss):
        opt = AdamOptimizer(1.e-3)
        opt.minimize(loss)

    def metrics(self, predictions, label):
        auc = atarshi.metrics.Auc(predictions, label)
        return {'auc': auc}

# Hyperparameters can come from files, environment variables, or the command line
run_config = propeller.parse_runconfig(args)
hparams = propeller.parse_hparam(args)

# Define the data:
# `FeatureColumns` manages training and prediction files and serializes them to binary automatically.
feature_column = propeller.data.FeatureColumns(columns=[
    propeller.data.TextColumn('query', vocab='./vocab'),
    propeller.data.TextColumn('title', vocab='./vocab'),
    propeller.data.LabelColumn('label'),
])
train_ds = feature_column.build_dataset(data_dir='./data', shuffle=True, repeat=True)
eval_ds = feature_column.build_dataset(data_dir='./data', shuffle=False, repeat=False)

# Start training!
propeller.train_and_eval(BowModel, hparams, run_config, train_ds, eval_ds)
```
For more details, see example/toy/.
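The binary files mentioned in the comments above are read in this commit by `open_record`, which expects a gzip stream of records, each stored as a native `int` length followed by that many payload bytes. The sketch below illustrates that layout with the standard library only. The writer side (`write_records`) is an assumption made for illustration, since the function that produces these files is not shown here; `read_records` and the file name are likewise hypothetical.

```python
import gzip
import os
import struct
import tempfile


def write_records(path, records):
    """Write each bytes record as an int length header followed by the payload."""
    with gzip.open(path, 'wb') as f:
        for rec in records:
            f.write(struct.pack('i', len(rec)))
            f.write(rec)


def read_records(path):
    """Yield the payloads back in the order they were written."""
    header_size = struct.calcsize('i')
    with gzip.open(path, 'rb') as f:
        while True:
            header = f.read(header_size)
            if not header:
                return  # clean end of file
            (length,) = struct.unpack('i', header)
            yield f.read(length)


path = os.path.join(tempfile.mkdtemp(), 'part-0.gz')
write_records(path, [b'hello', b'', b'propeller'])
records = list(read_records(path))
```

A length prefix lets records contain arbitrary bytes, including newlines, which a line-oriented text format could not hold.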
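## Main Components
1. train_and_eval
Given the user-provided `propeller.Model` class, it instantiates the model in two modes: 1. TRAIN mode 2. EVAL mode.
It then starts training and runs evaluation alongside it.
2. FeatureColumns
`FeatureColumns` manages the training data. Custom `Column` types let it adapt to many ML tasks (NLP, CV, ...).
`FeatureColumns` automatically preprocesses the provided training data in batches (tokenization, vocabulary lookup, etc.), serializes it to binary, and generates the dataset used for training.
3. Dataset
`FeatureColumns` generates a `Dataset`, or you can call `propeller.Dataset.from_generator_func` to build your own `Dataset` and combine it with methods such as shuffle, interleave, padded_batch, and repeat to meet custom needs.
4. Summary
To log and trace certain parameters during training, simply call: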
```python
propeller.summary.histogram('loss', tensor)
```
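Of the `Dataset` methods listed above, `padded_batch` is the one most specific to variable-length inputs such as token IDs. The following is a pure-Python sketch of the idea, assuming list inputs; the helper name `padded_batches` is hypothetical, and the `padded_batch_func` in this commit works on NumPy arrays and additionally supports a `max_seqlen` cap.

```python
import itertools


def padded_batches(sequences, batch_size, pad_value=0):
    """Group variable-length lists into batches padded to the longest member."""
    iterator = iter(sequences)
    while True:
        batch = list(itertools.islice(iterator, batch_size))
        if not batch:
            return
        max_len = max(len(seq) for seq in batch)
        yield [seq + [pad_value] * (max_len - len(seq)) for seq in batch]


token_ids = [[1, 2, 3], [4], [5, 6], [7, 8, 9, 10], [11]]
batches = list(padded_batches(token_ids, batch_size=2))
```

Padding per batch rather than to a global maximum keeps batches of short sequences small; the last batch may hold fewer than `batch_size` items.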
## Contributing
1. This project is at an early stage; contributions are welcome!
2. Functional programming is welcome.
## TODO
1. Automatic inference of dataset output_types/output_shapes
2. Automatic hyperparameter search
3. propeller server
4. ...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import os
import sys
import logging
import six
from time import time
__version__ = '0.1'
log = logging.getLogger(__name__)
stream_hdl = logging.StreamHandler(stream=sys.stderr)
formatter = logging.Formatter(
    fmt='[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]:\t%(message)s'
)
try:
    from colorlog import ColoredFormatter
    fancy_formatter = ColoredFormatter(
        fmt='%(log_color)s[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]:\t%(message)s'
    )
    stream_hdl.setFormatter(fancy_formatter)
except ImportError:
    stream_hdl.setFormatter(formatter)
log.setLevel(logging.INFO)
log.addHandler(stream_hdl)
log.propagate = False
from propeller.types import *
from propeller.util import ArgumentParser, parse_hparam, parse_runconfig, parse_file
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
import logging
import os
import itertools
import random
import inspect
import multiprocessing
from contextlib import contextmanager
import gzip
import struct
import functools
import six
from six.moves import zip, map, filter
import numpy as np
from propeller.util import map_structure
log = logging.getLogger(__name__)
__all__ = ['Dataset']
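@contextmanager
def open_file(filename, format=None):
    if format is None:
        fd = open(filename, 'rb')
    elif format == 'GZIP':
        fd = gzip.open(filename, 'rb')
    else:
        raise ValueError('unknown file format %s' % format)
    try:
        yield fd
    finally:
        # Close the file even if the body of the `with` block raises.
        fd.close()


def open_record(filename):
    def gen():
        with open_file(filename, format='GZIP') as f:
            while True:
                # Each record is an int length header followed by the payload.
                data = f.read(struct.calcsize('i'))
                if not len(data):
                    # `return` ends the generator; raising StopIteration here
                    # turns into a RuntimeError under PEP 479 (Python 3.7+).
                    return
                l, = struct.unpack('i', data)
                data = f.read(l)
                yield data

    return gen

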
def shuffle_func(dataset, buffer_size):
    def gen():
        buf = []
        iterable = dataset()
        try:
            while len(buf) < buffer_size:
                buf.append(next(iterable))
            while 1:
                i = random.randint(0, buffer_size - 1)
                n = next(iterable)
                yield buf[i]
                buf[i] = n
        except StopIteration:
            if len(buf):
                random.shuffle(buf)
                for i in buf:
                    yield i

    return gen


def interleave_func(iterable, map_fn, cycle_length, block_length):
    def gen():
        ls = itertools.tee(iterable(), cycle_length)
        buf = []
        for i, j in enumerate(ls):
            j = itertools.islice(j, i, None, cycle_length)
            j = map(map_fn, j)
            j = (jjj for jj in j for jjj in jj)  # flatten
            buf.append(j)

        for tup in six.moves.zip_longest(*buf):
            for ii in (i for i in tup if i is not None):
                yield ii

    return gen


def repeat_func(dataset, n):
    def gen():
        iterable = dataset()
        if n >= 0:
            ret = itertools.chain(*itertools.tee(iterable, n))
        else:
            ret = itertools.cycle(iterable)

        for i in ret:
            yield i

    return gen


def filter_func(dataset, fn):
    def gen():
        for i in dataset():
            if isinstance(i, tuple) or isinstance(i, list):
                if fn(*i) is True:
                    yield i
            else:
                if fn(i) is True:
                    yield i

    return gen


def map_func(dataset, fn):
    def gen():
        for i in dataset():
            if isinstance(i, tuple) or isinstance(i, list):
                yield fn(*i)
            else:
                yield fn(i)

    return gen


def shard_func(dataset, num_shards, index):
    def gen():
        iterable = dataset()
        ret = itertools.islice(iterable, index, None, num_shards)
        for i in ret:
            yield i

    return gen


def take_func(dataset, count):
    def gen():
        iterable = dataset()
        ret = itertools.islice(iterable, count)
        for i in ret:
            yield i

    return gen


class _EndSignal(object):
    """Sentinel put on the queue once the reader is exhausted.

    Defined at module level so that `multiprocessing.Queue` can pickle it.
    """
    pass


def buffered_func(dataset, size):
    """
    Creates a buffered data reader.

    The buffered data reader will read and save data entries into a
    buffer. Reading from the buffered data reader will proceed as long
    as the buffer is not empty.

    :param dataset: the data reader to read from.
    :type dataset: callable
    :param size: max buffer size.
    :type size: int
    :returns: the buffered data reader.
    """

    def read_worker(r, q):
        for d in r:
            q.put(d)
        q.put(_EndSignal())

    def data_reader():
        r = dataset()
        q = multiprocessing.Queue(maxsize=size)
        t = multiprocessing.Process(target=read_worker, args=(r, q))
        t.daemon = True
        t.start()
        e = q.get()
        # The sentinel arrives as a copy after crossing the process boundary,
        # so check its type instead of comparing against one instance.
        while not isinstance(e, _EndSignal):
            yield e
            e = q.get()

    return data_reader


def padded_batch_func(dataset, batch_size, pad_value=0, max_seqlen=None):
    if not isinstance(batch_size, int):
        raise ValueError('unknown batch_size: %s' % repr(batch_size))

    def gen():
        iterable = dataset()
        pad_value_t = pad_value
        while True:
            buf = list(itertools.islice(iterable, batch_size))
            if not len(buf):
                # `return` instead of `raise StopIteration` (PEP 479).
                return
            buf = list(zip(*buf))  # transpose
            if type(pad_value_t) not in [list, tuple]:
                pad_value_t = [pad_value_t] * len(buf)
            padded = []
            assert len(buf) == len(
                pad_value_t), 'pad_value [%d] != element size[%d]' % (
                    len(pad_value_t), len(buf))
            for e, pv in zip(buf, pad_value_t):
                elem = e[0]
                if (not np.isscalar(elem)) and elem.shape != ():
                    max_len = max(map(len,
                                      e)) if max_seqlen is None else max_seqlen
                    # Pad short sequences with `pv`; truncate longer ones.
                    e = map(lambda i: np.pad(i, [0, max_len - len(i)],
                                             'constant', constant_values=pv)
                            if max_len >= len(i) else i[:max_len], e)
                padded.append(np.stack(list(e)))
            yield padded

    return gen


class Dataset(object):
    @classmethod
    def from_generator_func(cls, gen, data_shapes=None, data_types=None):
        if not inspect.isgeneratorfunction(gen):
            raise ValueError('expect generator function, got %s' % repr(gen))

        def wrapper():  # compat to py3.7
            try:
                for item in gen():
                    yield item
            except RuntimeError as e:
                if str(e) != 'generator raised StopIteration':
                    raise e

        ret = cls()
        ret.generator = wrapper
        ret.data_shapes = data_shapes
        ret.data_types = data_types
        return ret

    @classmethod
    def from_file(cls, filename, format=None):
        if os.path.getsize(filename) == 0:
            raise RuntimeError('%s is empty' % filename)

        def gen():
            with open_file(filename, format) as f:
                for line in f:
                    yield line

        ret = cls()
        ret.generator = gen
        ret.data_shapes = []
        ret.data_types = str
        return ret

    @classmethod
    def from_record_file(cls, filename):
        if os.path.getsize(filename) == 0:
            raise RuntimeError('%s is empty' % filename)
        gen = open_record(filename)
        ret = cls()
        ret.generator = gen
        ret.data_shapes = []
        ret.data_types = str
        return ret

    @classmethod
    def from_list(cls, ls):
        if not isinstance(ls, list):
            raise ValueError('expect list, got %s' % repr(ls))

        def gen():
            for i in ls:
                yield i

        ret = cls()
        ret.generator = gen
        ret.data_shapes = []
        ret.data_types = str
        return ret

    def __init__(self):
        self.name = None
        self._data_shapes = None
        self._data_types = None
        self.generator = None
        self.pyreader = None

    def __repr__(self):
        return 'Dataset: name: %s, data_shapes %s, data_types %s' % (
            self.name, self._data_shapes, self._data_types)

    def __eq__(self, other):
        return self.name == other.name and \
               self._data_shapes == other._data_shapes and \
               self._data_types == other._data_types

    def __iter__(self):
        return self.generator()

    #def __call__(self):
    #    return self.generator()

    def _infer_shapes_and_types(self):
        if self.generator is not None and self.name is not None:
            log.info('Try to infer data shapes & types from generator')
            first_value = next(self.generator())
            shapes, types = [], []
            for v in first_value:
                if not isinstance(v, np.ndarray):
                    raise ValueError(
                        'dataset generator should use numpy elements, got %s' %
                        first_value)
                shapes.append(v.shape)
                types.append(v.dtype.name)
            self._data_shapes = shapes
            self._data_types = types
            log.info('Dataset `%s` has data_shapes: %s data_types: %s' %
                     (self.name, repr(shapes), repr(types)))
        else:
            raise ValueError(
                'Try to infer data shapes or types from incomplete Dataset')

    @property
    def data_shapes(self):
        if self._data_shapes is None:
            self._infer_shapes_and_types()
            return self._data_shapes
        else:
            return self._data_shapes

    @data_shapes.setter
    def data_shapes(self, val):
        self._data_shapes = val

    @property
    def data_types(self):
        if self._data_types is None:
            self._infer_shapes_and_types()
            return self._data_types
        else:
            return self._data_types

    @data_types.setter
    def data_types(self, val):
        self._data_types = val

    def apply(self, transform_func):
        #input_shapes = transform_func.input_shapes
        #input_types = transform_func.input_types
        #data_shapes = transform_func.data_shapes
        #data_types = transform_func.data_types
        #assert input_shapes == self._data_shapes
        #assert input_types = self._data_types
        ret_gen = transform_func(self.generator)
        ret = type(self).from_generator_func(ret_gen)
        if self.name is not None:
            ret.name = self.name
        #ret.data_shapes = data_shapes
        #ret.data_types = data_types
        return ret

    def shuffle(self, buffer_size):
        func = functools.partial(shuffle_func, buffer_size=buffer_size)
        return self.apply(func)

    def repeat(self, n=-1):
        func = functools.partial(repeat_func, n=n)
        return self.apply(func)

    def map(self, fn):
        func = functools.partial(map_func, fn=fn)
        return self.apply(func)

    def filter(self, fn):
        func = functools.partial(filter_func, fn=fn)
        return self.apply(func)

    def shard(self, num_shards, index):
        func = functools.partial(
            shard_func, num_shards=num_shards, index=index)
        return self.apply(func)

    def interleave(self, map_fn, cycle_length, block_length):
        func = functools.partial(
            interleave_func,
            map_fn=map_fn,
            cycle_length=cycle_length,
            block_length=block_length)
        return self.apply(func)

    def padded_batch(self, batch_size, pad_value=0, max_seqlen=None):
        func = functools.partial(
            padded_batch_func,
            batch_size=batch_size,
            pad_value=pad_value,
            max_seqlen=max_seqlen)
        return self.apply(func)

    def take(self, count=1):
        func = functools.partial(take_func, count=count)
        return self.apply(func)

    def buffered(self, size=10):
        func = functools.partial(buffered_func, size=size)
        return self.apply(func)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import six
from propeller.types import *
from propeller.util import ArgumentParser, parse_hparam, parse_runconfig, parse_file
from propeller.paddle import data
from propeller.paddle import train
from propeller.paddle.train import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
_global_collection = None
class Key(object):
    """predefine collection keys"""
    SUMMARY_SCALAR = 1
    SUMMARY_HISTOGRAM = 2
    SKIP_OPTIMIZE = 3


class Collections(object):
    """global collections to record everything"""

    def __init__(self):
        self.col = {}

    def __enter__(self):
        global _global_collection
        _global_collection = self
        return self

    def __exit__(self, err_type, err_value, trace):
        global _global_collection
        _global_collection = None

    def add(self, key, val):
        self.col.setdefault(key, []).append(val)

    def get(self, key):
        return self.col.get(key, None)


def default_collection():
    global _global_collection
    if _global_collection is None:
        _global_collection = Collections()
    return _global_collection
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
from propeller.paddle.data.functional import *
from propeller.paddle.data.feature_column import *
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Protocol messages for describing input data Examples for machine learning
// model training or inference.
syntax = "proto3";
import "propeller/paddle/data/feature.proto";
package propeller;
message Example {
  Features features = 1;
};
message SequenceExample {
  Features context = 1;
  FeatureLists feature_lists = 2;
};
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: propeller/paddle/data/example.proto
import sys
_b = sys.version_info[0] < 3 and (lambda x: x) or (
lambda x: x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from propeller.paddle.data import feature_pb2 as propeller_dot_paddle_dot_data_dot_feature__pb2
DESCRIPTOR = _descriptor.FileDescriptor(
name='propeller/paddle/data/example.proto',
package='propeller',
syntax='proto3',
serialized_options=None,
serialized_pb=_b(
'\n#propeller/paddle/data/example.proto\x12\tpropeller\x1a#propeller/paddle/data/feature.proto\"0\n\x07\x45xample\x12%\n\x08\x66\x65\x61tures\x18\x01 \x01(\x0b\x32\x13.propeller.Features\"g\n\x0fSequenceExample\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.propeller.Features\x12.\n\rfeature_lists\x18\x02 \x01(\x0b\x32\x17.propeller.FeatureListsb\x06proto3'
),
dependencies=[
propeller_dot_paddle_dot_data_dot_feature__pb2.DESCRIPTOR,
])
_EXAMPLE = _descriptor.Descriptor(
name='Example',
full_name='propeller.Example',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='features',
full_name='propeller.Example.features',
index=0,
number=1,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=87,
serialized_end=135, )
_SEQUENCEEXAMPLE = _descriptor.Descriptor(
name='SequenceExample',
full_name='propeller.SequenceExample',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='context',
full_name='propeller.SequenceExample.context',
index=0,
number=1,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='feature_lists',
full_name='propeller.SequenceExample.feature_lists',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=137,
serialized_end=240, )
_EXAMPLE.fields_by_name[
'features'].message_type = propeller_dot_paddle_dot_data_dot_feature__pb2._FEATURES
_SEQUENCEEXAMPLE.fields_by_name[
'context'].message_type = propeller_dot_paddle_dot_data_dot_feature__pb2._FEATURES
_SEQUENCEEXAMPLE.fields_by_name[
'feature_lists'].message_type = propeller_dot_paddle_dot_data_dot_feature__pb2._FEATURELISTS
DESCRIPTOR.message_types_by_name['Example'] = _EXAMPLE
DESCRIPTOR.message_types_by_name['SequenceExample'] = _SEQUENCEEXAMPLE
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
Example = _reflection.GeneratedProtocolMessageType(
'Example',
(_message.Message, ),
dict(
DESCRIPTOR=_EXAMPLE,
__module__='propeller.paddle.data.example_pb2'
# @@protoc_insertion_point(class_scope:propeller.Example)
))
_sym_db.RegisterMessage(Example)
SequenceExample = _reflection.GeneratedProtocolMessageType(
'SequenceExample',
(_message.Message, ),
dict(
DESCRIPTOR=_SEQUENCEEXAMPLE,
__module__='propeller.paddle.data.example_pb2'
# @@protoc_insertion_point(class_scope:propeller.SequenceExample)
))
_sym_db.RegisterMessage(SequenceExample)
# @@protoc_insertion_point(module_scope)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package propeller;
message BytesList {
  repeated bytes value = 1;
}
message FloatList {
  repeated float value = 1 [packed = true];
}
message Int64List {
  repeated int64 value = 1 [packed = true];
}
message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};
message Features {
  map<string, Feature> feature = 1;
};
message FeatureList {
  repeated Feature feature = 1;
};
message FeatureLists {
  map<string, FeatureList> feature_list = 1;
};
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import os
import sys
import struct
from six.moves import zip, map
import itertools
import gzip
from functools import partial
import multiprocessing
import six
import logging
import numpy as np
from glob import glob
from propeller.paddle.train import distribution
from propeller.data.functional import interleave_func
from propeller.paddle.data.functional import Dataset
from propeller.paddle.data import example_pb2, feature_pb2
log = logging.getLogger(__name__)
__all__ = [
'FeatureColumns', 'TextColumn', 'TextIDColumn', 'LabelColumn',
'basic_tokenizer', 'Column'
]
def basic_tokenizer(sen):
    seg = sen.split(b' ')
    seg = filter(lambda i: i != b' ', seg)
    return seg


class Column():
    def __init__(self, name):
        pass

    def raw_to_proto(self, raw):
        return feature_pb2.Feature()

    @property
    def output_shapes(self):
        pass

    @property
    def output_types(self):
        pass

    def proto_to_instance(self, proto):
        raise NotImplementedError()

    def raw_to_instance(self, raw):
        raise NotImplementedError()


class LabelColumn(Column):
    def __init__(self, name, vocab_dict=None, vocab_file=None):
        self.name = name
        self.vocab = None
        if vocab_file:
            self.vocab = {
                j.strip(): i
                for i, j in enumerate(open(vocab_file, 'rb').readlines())
            }
        if vocab_dict:
            self.vocab = vocab_dict

    @property
    def output_shapes(self):
        return [1]

    @property
    def output_types(self):
        return 'int64'

    def raw_to_proto(self, raw):
        if self.vocab is None:
            ids = [int(raw)]
        else:
            ids = [self.vocab[raw]]
        fe = feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=ids))
        return fe

    def proto_to_instance(self, feature):
        ret = np.array(feature.int64_list.value[0], dtype=np.int64)
        return ret

    def raw_to_instance(self, raw):
        if self.vocab is None:
            ids = int(raw)
        else:
            ids = self.vocab[raw]
        return ids


class TextColumn(Column):
    def __init__(self,
                 name,
                 unk_id,
                 vocab_file=None,
                 vocab_dict=None,
                 tokenizer=basic_tokenizer):
        self.name = name
        self.tokenizer = tokenizer
        self.unk_id = unk_id
        if not (vocab_file or vocab_dict):
            raise ValueError('at least specify vocab_file or vocab_dict')
        if vocab_file:
            self.vocab = {
                j.strip(): i
                for i, j in enumerate(open(vocab_file, 'rb').readlines())
            }
        if vocab_dict:
            self.vocab = vocab_dict

    @property
    def output_shapes(self):
        return [-1]

    @property
    def output_types(self):
        return 'int64'

    def raw_to_proto(self, raw):
        ids = [self.vocab.get(s, self.unk_id) for s in self.tokenizer(raw)]
        fe = feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=ids))
        return fe

    def proto_to_instance(self, feature):
        ret = np.array(feature.int64_list.value, dtype=np.int64)
        return ret

    def raw_to_instance(self, raw):
        ids = [self.vocab.get(s, self.unk_id) for s in self.tokenizer(raw)]
        return np.array(ids, dtype=np.int64)


class TextIDColumn(Column):
    def __init__(self, name):
        self.name = name

    @property
    def output_shapes(self):
        return [-1]

    @property
    def output_types(self):
        return 'int64'

    def raw_to_proto(self, raw):
        ids = [int(s) for s in raw.split(b' ')]
        fe = feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=ids))
        return fe

    def proto_to_instance(self, feature):
        ret = np.array(feature.int64_list.value, dtype=np.int64)
        return ret

    def raw_to_instance(self, raw):
        ret = np.array([int(i) for i in raw.split(b' ')], dtype=np.int64)
        return ret


class FeatureColumns(object):
    def __init__(self, columns, pad_id=0):
        self._columns = columns

    def raw_files(self, raw_dir):
        return [os.path.join(raw_dir, p) for p in os.listdir(raw_dir)]

    def gz_files(self, gz_dir):
        return None if gz_dir is None else [
            os.path.join(gz_dir, p) for p in os.listdir(gz_dir)
        ]

    def _make_gz_dataset(self, raw_dir, gz_dir):
        assert raw_dir or gz_dir, 'data_dir not specified when using gz mode'
        if raw_dir is not None:
            assert os.path.exists(raw_dir), 'raw_dir not exists: %s' % raw_dir
            raw_file = os.listdir(raw_dir)
        if gz_dir is None:
            gz_dir = '%s_gz' % raw_dir.rstrip('/')

        if not os.path.exists(gz_dir):
            os.mkdir(gz_dir)

        if raw_dir is not None:
            if len(raw_file) != 0:
                log.debug('try making gz')
                pool = multiprocessing.Pool()
                args = [(os.path.join(raw_dir, f), os.path.join(gz_dir, f),
                         self._columns, b'\t') for f in raw_file]
                pool.map(_make_gz, args)
                pool.terminate()
            else:
                assert len(
                    os.listdir(gz_dir)
                ) != 0, 'cant find gz file or raw-txt file at [%s] and [%s]' % (
                    raw_dir, gz_dir)
        return gz_dir

    def _read_gz_dataset(self,
                         gz_files,
                         shuffle=False,
                         repeat=True,
                         shard=False,
                         **kwargs):
        if len(gz_files) == 0:
            raise ValueError('reading gz from empty file list: %s' % gz_files)
        log.info('reading gz from %s' % '\n'.join(gz_files))
        dataset = Dataset.from_list(gz_files)
        if repeat:
            dataset = dataset.repeat()

        if shard and distribution.status.mode == distribution.DistributionMode.NCCL:
            log.info('Apply dataset sharding in distribution env')
            # Shard the dataset built above (the original referenced an
            # undefined `train_ds` here).
            dataset = dataset.shard(distribution.status.num_replica,
                                    distribution.status.replica_id)

        if shuffle:
            dataset = dataset.shuffle(buffer_size=len(gz_files))
        fn = partial(
            interleave_func,
            map_fn=lambda filename: Dataset.from_record_file(filename),
            cycle_length=len(gz_files),
            block_length=1)
        dataset = dataset.apply(fn)
        if shuffle:
            dataset = dataset.shuffle(buffer_size=1000)

        def _parse_gz(record_str):  # function that takes python_str as input
            ex = example_pb2.Example()
            ex.ParseFromString(record_str)
            ret = []
            fea_dict = ex.features.feature
            for c in self._columns:
                ins = c.proto_to_instance(fea_dict[c.name])
                ret.append(ins)
            return ret

        dataset = dataset.map(_parse_gz)
        return dataset

def _read_txt_dataset(self,
data_files,
shuffle=False,
repeat=True,
**kwargs):
log.info('reading raw files from %s' % '\n'.join(data_files))
dataset = Dataset.from_list(data_files)
if repeat:
dataset = dataset.repeat()
if shuffle:
dataset = dataset.shuffle(buffer_size=len(data_files))
fn = partial(
interleave_func,
map_fn=lambda filename: Dataset.from_file(filename),
cycle_length=len(data_files),
block_length=1)
dataset = dataset.apply(fn)
if shuffle:
dataset = dataset.shuffle(buffer_size=1000)
def _parse_txt_file(
record_str): # function that takes python_str as input
features = record_str.strip(b'\n').split(b'\t')
ret = [
column.raw_to_instance(feature)
for feature, column in zip(features, self._columns)
]
return ret
dataset = dataset.map(_parse_txt_file)
return dataset
def _read_stdin_dataset(self, encoding='utf8', shuffle=False, **kwargs):
log.info('reading raw files from stdin')
def gen():
if six.PY3:
source = sys.stdin.buffer
else:
source = sys.stdin
while True:
line = source.readline()
if len(line) == 0:
break
yield line,
dataset = Dataset.from_generator_func(gen)
if shuffle:
dataset = dataset.shuffle(buffer_size=1000)
def _parse_stdin(record_str):
'''function that takes python_str as input'''
features = record_str.strip(b'\n').split(b'\t')
ret = [
column.raw_to_instance(feature)
for feature, column in zip(features, self._columns)
]
return ret
dataset = dataset.map(_parse_stdin)
return dataset
def _prepare_dataset(self,
dataset,
map_func_before_batch=None,
map_func_after_batch=None,
shuffle_buffer_size=None,
batch_size=1,
pad_id=0,
prefetch=None,
**kwargs):
if map_func_before_batch is not None:
dataset = dataset.map(map_func_before_batch)
if batch_size:
dataset = dataset.padded_batch(batch_size, pad_id)
if map_func_after_batch is not None:
dataset = dataset.map(map_func_after_batch)
return dataset
def build_dataset(self,
name,
use_gz=True,
data_dir=None,
gz_dir=None,
data_file=None,
**kwargs):
if use_gz:
gz_dir = self._make_gz_dataset(data_dir, gz_dir)
gz_files = self.gz_files(gz_dir)
ds = self._read_gz_dataset(gz_files, **kwargs)
else:
if data_dir is not None:
data_files = self.raw_files(data_dir)
elif data_file is not None:
data_files = [data_file]
else:
raise ValueError('data_dir or data_file not specified')
ds = self._read_txt_dataset(data_files, **kwargs)
ds.name = name
return ds
def build_dataset_from_stdin(self, name, **kwargs):
ds = self._read_stdin_dataset(**kwargs)
ds.name = name
return ds
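The text path in `FeatureColumns` (`_read_txt_dataset` followed by `_prepare_dataset`) splits each line on tabs, converts every field with the matching column's `raw_to_instance`, and then batches with `padded_batch(batch_size, pad_id)`. The block below is a minimal pure-Python sketch of that flow, not Propeller's implementation. `parse_line` and `padded_batch` are hypothetical helpers, plain lists stand in for NumPy arrays, and padding each column to the longest sequence in the batch is an assumption about what `Dataset.padded_batch` does.

```python
def parse_line(line, num_columns):
    """Split one tab-separated record into per-column lists of ints."""
    fields = line.strip(b'\n').split(b'\t')
    if len(fields) != num_columns:
        raise ValueError('expected %d columns, got %d' %
                         (num_columns, len(fields)))
    # Each field holds space-separated integer ids, as in raw_to_instance.
    return [[int(tok) for tok in field.split(b' ')] for field in fields]


def padded_batch(examples, batch_size, pad_id=0):
    """Group examples into batches, padding each column to the batch maximum."""
    for start in range(0, len(examples), batch_size):
        chunk = examples[start:start + batch_size]
        batch = []
        for column in zip(*chunk):  # regroup the chunk column by column
            width = max(len(seq) for seq in column)
            batch.append(
                [seq + [pad_id] * (width - len(seq)) for seq in column])
        yield batch


lines = [b'1 2 3\t7\n', b'4 5\t8\n', b'6\t9\n']
examples = [parse_line(line, num_columns=2) for line in lines]
batches = list(padded_batch(examples, batch_size=2))
```

Here the second example is padded to length 3 in the first column, and the last, smaller batch is emitted as-is rather than dropped; whether Propeller drops incomplete batches is not shown in the source.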
def _make_gz(args):
try:
from_file, to_file, columns, sep = args
if os.path.exists(to_file):
return
with open(from_file, 'rb') as fin, gzip.open(to_file, 'wb') as fout:
log.debug('making gz %s => %s' % (from_file, to_file))
for i, line in enumerate(fin):
line = line.strip(b'\n').split(sep)
#if i % 10000 == 0:
# log.debug('making gz %s => %s [%d]' % (from_file, to_file, i))
if len(line) != len(columns):
log.error('column count mismatch at %s, got %d, expected %d' %
(from_file, len(line), len(columns)))
continue
features = {}
for l, c in zip(line, columns):
features[c.name] = c.raw_to_proto(l)
example = example_pb2.Example(features=feature_pb2.Features(
feature=features))
serialized = example.SerializeToString()
l = len(serialized)
data = struct.pack('i%ds' % l, l, serialized)
fout.write(data)
log.debug('done making gz %s => %s' % (from_file, to_file))
except Exception as e:
log.exception(e)
raise e
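`_make_gz` writes every serialized example as a native-endian `int` length followed by the payload bytes, all inside a gzip stream. The following sketch shows that framing and one way to read it back using only the standard library. `write_records` and `read_records` are hypothetical names; the reader that Propeller actually uses (`Dataset.from_record_file`) is not shown in this section, so the read side is an assumption based on the write format.

```python
import gzip
import io
import struct


def write_records(fileobj, payloads):
    """Write each payload as <int length><payload bytes>."""
    for payload in payloads:
        n = len(payload)
        fileobj.write(struct.pack('i%ds' % n, n, payload))


def read_records(fileobj):
    """Yield payloads back from a length-prefixed stream."""
    header_size = struct.calcsize('i')
    while True:
        header = fileobj.read(header_size)
        if len(header) < header_size:
            break  # end of stream
        (n, ) = struct.unpack('i', header)
        yield fileobj.read(n)


# Round-trip through an in-memory gzip stream instead of a real file.
buf = io.BytesIO()
with gzip.GzipFile(fileobj=buf, mode='wb') as fout:
    write_records(fout, [b'hello', b'', b'propeller'])
buf.seek(0)
with gzip.GzipFile(fileobj=buf, mode='rb') as fin:
    records = list(read_records(fin))
```

Because the length prefix uses the native `int` layout, files written this way are only portable between machines with the same byte order.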
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: propeller/paddle/data/feature.proto
import sys
_b = sys.version_info[0] < 3 and (lambda x: x) or (
lambda x: x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='propeller/paddle/data/feature.proto',
package='propeller',
syntax='proto3',
serialized_options=None,
serialized_pb=_b(
'\n#propeller/paddle/data/feature.proto\x12\tpropeller\"\x1a\n\tBytesList\x12\r\n\x05value\x18\x01 \x03(\x0c\"\x1e\n\tFloatList\x12\x11\n\x05value\x18\x01 \x03(\x02\x42\x02\x10\x01\"\x1e\n\tInt64List\x12\x11\n\x05value\x18\x01 \x03(\x03\x42\x02\x10\x01\"\x95\x01\n\x07\x46\x65\x61ture\x12*\n\nbytes_list\x18\x01 \x01(\x0b\x32\x14.propeller.BytesListH\x00\x12*\n\nfloat_list\x18\x02 \x01(\x0b\x32\x14.propeller.FloatListH\x00\x12*\n\nint64_list\x18\x03 \x01(\x0b\x32\x14.propeller.Int64ListH\x00\x42\x06\n\x04kind\"\x81\x01\n\x08\x46\x65\x61tures\x12\x31\n\x07\x66\x65\x61ture\x18\x01 \x03(\x0b\x32 .propeller.Features.FeatureEntry\x1a\x42\n\x0c\x46\x65\x61tureEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.propeller.Feature:\x02\x38\x01\"2\n\x0b\x46\x65\x61tureList\x12#\n\x07\x66\x65\x61ture\x18\x01 \x03(\x0b\x32\x12.propeller.Feature\"\x9a\x01\n\x0c\x46\x65\x61tureLists\x12>\n\x0c\x66\x65\x61ture_list\x18\x01 \x03(\x0b\x32(.propeller.FeatureLists.FeatureListEntry\x1aJ\n\x10\x46\x65\x61tureListEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.propeller.FeatureList:\x02\x38\x01\x62\x06proto3'
))
_BYTESLIST = _descriptor.Descriptor(
name='BytesList',
full_name='propeller.BytesList',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.BytesList.value',
index=0,
number=1,
type=12,
cpp_type=9,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=50,
serialized_end=76, )
_FLOATLIST = _descriptor.Descriptor(
name='FloatList',
full_name='propeller.FloatList',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.FloatList.value',
index=0,
number=1,
type=2,
cpp_type=6,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=_b('\020\001'),
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=78,
serialized_end=108, )
_INT64LIST = _descriptor.Descriptor(
name='Int64List',
full_name='propeller.Int64List',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.Int64List.value',
index=0,
number=1,
type=3,
cpp_type=2,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=_b('\020\001'),
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=110,
serialized_end=140, )
_FEATURE = _descriptor.Descriptor(
name='Feature',
full_name='propeller.Feature',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='bytes_list',
full_name='propeller.Feature.bytes_list',
index=0,
number=1,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='float_list',
full_name='propeller.Feature.float_list',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='int64_list',
full_name='propeller.Feature.int64_list',
index=2,
number=3,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
_descriptor.OneofDescriptor(
name='kind',
full_name='propeller.Feature.kind',
index=0,
containing_type=None,
fields=[]),
],
serialized_start=143,
serialized_end=292, )
_FEATURES_FEATUREENTRY = _descriptor.Descriptor(
name='FeatureEntry',
full_name='propeller.Features.FeatureEntry',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='propeller.Features.FeatureEntry.key',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.Features.FeatureEntry.value',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=_b('8\001'),
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=358,
serialized_end=424, )
_FEATURES = _descriptor.Descriptor(
name='Features',
full_name='propeller.Features',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='feature',
full_name='propeller.Features.feature',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[_FEATURES_FEATUREENTRY, ],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=295,
serialized_end=424, )
_FEATURELIST = _descriptor.Descriptor(
name='FeatureList',
full_name='propeller.FeatureList',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='feature',
full_name='propeller.FeatureList.feature',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=426,
serialized_end=476, )
_FEATURELISTS_FEATURELISTENTRY = _descriptor.Descriptor(
name='FeatureListEntry',
full_name='propeller.FeatureLists.FeatureListEntry',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='propeller.FeatureLists.FeatureListEntry.key',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.FeatureLists.FeatureListEntry.value',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=_b('8\001'),
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=559,
serialized_end=633, )
_FEATURELISTS = _descriptor.Descriptor(
name='FeatureLists',
full_name='propeller.FeatureLists',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='feature_list',
full_name='propeller.FeatureLists.feature_list',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[_FEATURELISTS_FEATURELISTENTRY, ],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=479,
serialized_end=633, )
_FEATURE.fields_by_name['bytes_list'].message_type = _BYTESLIST
_FEATURE.fields_by_name['float_list'].message_type = _FLOATLIST
_FEATURE.fields_by_name['int64_list'].message_type = _INT64LIST
_FEATURE.oneofs_by_name['kind'].fields.append(_FEATURE.fields_by_name[
'bytes_list'])
_FEATURE.fields_by_name[
'bytes_list'].containing_oneof = _FEATURE.oneofs_by_name['kind']
_FEATURE.oneofs_by_name['kind'].fields.append(_FEATURE.fields_by_name[
'float_list'])
_FEATURE.fields_by_name[
'float_list'].containing_oneof = _FEATURE.oneofs_by_name['kind']
_FEATURE.oneofs_by_name['kind'].fields.append(_FEATURE.fields_by_name[
'int64_list'])
_FEATURE.fields_by_name[
'int64_list'].containing_oneof = _FEATURE.oneofs_by_name['kind']
_FEATURES_FEATUREENTRY.fields_by_name['value'].message_type = _FEATURE
_FEATURES_FEATUREENTRY.containing_type = _FEATURES
_FEATURES.fields_by_name['feature'].message_type = _FEATURES_FEATUREENTRY
_FEATURELIST.fields_by_name['feature'].message_type = _FEATURE
_FEATURELISTS_FEATURELISTENTRY.fields_by_name[
'value'].message_type = _FEATURELIST
_FEATURELISTS_FEATURELISTENTRY.containing_type = _FEATURELISTS
_FEATURELISTS.fields_by_name[
'feature_list'].message_type = _FEATURELISTS_FEATURELISTENTRY
DESCRIPTOR.message_types_by_name['BytesList'] = _BYTESLIST
DESCRIPTOR.message_types_by_name['FloatList'] = _FLOATLIST
DESCRIPTOR.message_types_by_name['Int64List'] = _INT64LIST
DESCRIPTOR.message_types_by_name['Feature'] = _FEATURE
DESCRIPTOR.message_types_by_name['Features'] = _FEATURES
DESCRIPTOR.message_types_by_name['FeatureList'] = _FEATURELIST
DESCRIPTOR.message_types_by_name['FeatureLists'] = _FEATURELISTS
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
BytesList = _reflection.GeneratedProtocolMessageType(
'BytesList',
(_message.Message, ),
dict(
DESCRIPTOR=_BYTESLIST,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.BytesList)
))
_sym_db.RegisterMessage(BytesList)
FloatList = _reflection.GeneratedProtocolMessageType(
'FloatList',
(_message.Message, ),
dict(
DESCRIPTOR=_FLOATLIST,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.FloatList)
))
_sym_db.RegisterMessage(FloatList)
Int64List = _reflection.GeneratedProtocolMessageType(
'Int64List',
(_message.Message, ),
dict(
DESCRIPTOR=_INT64LIST,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.Int64List)
))
_sym_db.RegisterMessage(Int64List)
Feature = _reflection.GeneratedProtocolMessageType(
'Feature',
(_message.Message, ),
dict(
DESCRIPTOR=_FEATURE,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.Feature)
))
_sym_db.RegisterMessage(Feature)
Features = _reflection.GeneratedProtocolMessageType(
'Features',
(_message.Message, ),
dict(
FeatureEntry=_reflection.GeneratedProtocolMessageType(
'FeatureEntry',
(_message.Message, ),
dict(
DESCRIPTOR=_FEATURES_FEATUREENTRY,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.Features.FeatureEntry)
)),
DESCRIPTOR=_FEATURES,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.Features)
))
_sym_db.RegisterMessage(Features)
_sym_db.RegisterMessage(Features.FeatureEntry)
FeatureList = _reflection.GeneratedProtocolMessageType(
'FeatureList',
(_message.Message, ),
dict(
DESCRIPTOR=_FEATURELIST,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.FeatureList)
))
_sym_db.RegisterMessage(FeatureList)
FeatureLists = _reflection.GeneratedProtocolMessageType(
'FeatureLists',
(_message.Message, ),
dict(
FeatureListEntry=_reflection.GeneratedProtocolMessageType(
'FeatureListEntry',
(_message.Message, ),
dict(
DESCRIPTOR=_FEATURELISTS_FEATURELISTENTRY,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.FeatureLists.FeatureListEntry)
)),
DESCRIPTOR=_FEATURELISTS,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.FeatureLists)
))
_sym_db.RegisterMessage(FeatureLists)
_sym_db.RegisterMessage(FeatureLists.FeatureListEntry)
_FLOATLIST.fields_by_name['value']._options = None
_INT64LIST.fields_by_name['value']._options = None
_FEATURES_FEATUREENTRY._options = None
_FEATURELISTS_FEATURELISTENTRY._options = None
# @@protoc_insertion_point(module_scope)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
import logging
import paddle.fluid as F
import paddle.fluid.layers as L
from propeller.data.functional import Dataset as DatasetBase
log = logging.getLogger(__name__)
class Dataset(DatasetBase):
def placeholders(self):
if self.name is None:
raise ValueError('can not get feature from unnamed Dataset')
ret = []
for i, (shape,
types) in enumerate(zip(self.data_shapes, self.data_types)):
ret.append(
L.data(
'%s_placeholder_%d' % (self.name, i),
shape=shape,
append_batch_size=False,
dtype=types))
return ret
def features(self):
'''start point of net building. call this in a program scope'''
if self.name is None:
raise ValueError('can not get feature from unnamed Dataset')
if len(self.data_shapes) != len(self.data_types):
raise ValueError(
'Dataset shapes and types do not match: shapes: %s types: %s' %
(repr(self._data_shapes), repr(self._data_types)))
return self.placeholders()
def start(self, places=F.cuda_places()):
#assert self.pyreader is not None, 'use Dataset.features to build net first, then start dataset'
def gen():
try:
for idx, i in enumerate(self.generator()):
yield i
except Exception as e:
log.exception(e)
raise e
r = F.io.PyReader(
feed_list=self.placeholders(), capacity=50, iterable=True)
r.decorate_batch_generator(gen, places=places)
return r()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import logging
import six
import asyncio
import threading
import grpc
from propeller.service import interface_pb2
from propeller.service import interface_pb2_grpc
import propeller.paddle.service.utils as serv_utils
from concurrent.futures import ThreadPoolExecutor
import paddle.fluid as F
from time import sleep, time
log = logging.getLogger(__name__)
def profile(msg):
def decfn(fn):
def retfn(*args, **kwargs):
start = time()
ret = fn(*args, **kwargs)
end = time()
log.debug('%s timecost: %.5f' % (msg, end - start))
return ret
return retfn
return decfn
def serve(model_dir, host, num_concurrent=None):
if six.PY2:
raise RuntimeError('propeller service works in python3 only')
num_worker = len(F.cuda_places(
)) if num_concurrent is None else num_concurrent
pool = ThreadPoolExecutor(num_worker)
class Predictor(object):
def __init__(self, did):
log.debug('create predictor on card %d' % did)
config = F.core.AnalysisConfig(model_dir)
config.enable_use_gpu(5000, did)
self._predictor = F.core.create_paddle_predictor(config)
@profile('paddle')
def __call__(self, args):
for i, a in enumerate(args):
a.name = 'placeholder_%d' % i
res = self._predictor.run(args)
return res
predictor_context = {}
class InferenceService(interface_pb2_grpc.InferenceServicer):
@profile('service')
def Infer(self, request, context):
try:
slots = request.slots
current_thread = threading.current_thread()
log.debug('%d slots received dispatch to thread %s' %
(len(slots), current_thread))
if current_thread not in predictor_context:
did = list(pool._threads).index(current_thread)
log.debug('spawning worker thread %d' % did)
predictor = Predictor(did)
predictor_context[current_thread] = predictor
else:
predictor = predictor_context[current_thread]
slots = [serv_utils.slot_to_paddlearray(s) for s in slots]
ret = predictor(slots)
response = [serv_utils.paddlearray_to_slot(r) for r in ret]
except Exception as e:
log.exception(e)
raise e
return interface_pb2.Slots(slots=response)
server = grpc.server(pool)
interface_pb2_grpc.add_InferenceServicer_to_server(InferenceService(),
server)
server.add_insecure_port(host)
server.start()
log.info('server started on %s...' % host)
try:
while True:
sleep(100000)
except KeyboardInterrupt as e:
pass
log.info('server stopped...')
if __name__ == '__main__':
from propeller import log
log.setLevel(logging.DEBUG)
serve(
'/home/work/chenxuyi/playground/grpc_play/ernie2.0/',
'10.255.138.19:8334',
num_concurrent=3)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import struct
from propeller.service import interface_pb2
from propeller.service import interface_pb2_grpc
import paddle.fluid.core as core
def slot_to_paddlearray(slot):
if slot.type == interface_pb2.Slot.FP32:
type_str = 'f'
dtype = core.PaddleDType.FLOAT32
elif slot.type == interface_pb2.Slot.INT32:
type_str = 'i'
dtype = core.PaddleDType.INT32
elif slot.type == interface_pb2.Slot.INT64:
type_str = 'q'
dtype = core.PaddleDType.INT64
else:
raise RuntimeError('unknown type %s' % slot.type)
ret = core.PaddleTensor()
ret.shape = slot.dims
ret.dtype = dtype
num = len(slot.data) // struct.calcsize(type_str)
arr = struct.unpack('%d%s' % (num, type_str), slot.data)
ret.data = core.PaddleBuf(arr)
return ret
def paddlearray_to_slot(arr):
if arr.dtype == core.PaddleDType.FLOAT32:
dtype = interface_pb2.Slot.FP32
type_str = 'f'
arr_data = arr.data.float_data()
elif arr.dtype == core.PaddleDType.INT32:
dtype = interface_pb2.Slot.INT32
type_str = 'i'
arr_data = arr.data.int32_data()
elif arr.dtype == core.PaddleDType.INT64:
dtype = interface_pb2.Slot.INT64
type_str = 'q'
arr_data = arr.data.int64_data()
else:
raise RuntimeError('unknown type %s' % arr.dtype)
data = struct.pack('%d%s' % (len(arr_data), type_str), *arr_data)
pb = interface_pb2.Slot(type=dtype, dims=list(arr.shape), data=data)
return pb
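Both converters above move tensor data as raw bytes: `struct.pack` flattens the values with a per-type format character, and the receiving side recovers the element count from the byte length and `struct.calcsize`. The sketch below isolates that round trip without Paddle or the generated protobuf classes. The string tags in `_FORMATS` and the helper names are hypothetical; only the format characters `f`, `i` and `q` come from the source.

```python
import struct

# Hypothetical type tags mapped to the format characters used above.
_FORMATS = {'fp32': 'f', 'int32': 'i', 'int64': 'q'}


def pack_values(values, kind):
    """Serialize a flat list of numbers into raw bytes."""
    type_str = _FORMATS[kind]
    return struct.pack('%d%s' % (len(values), type_str), *values)


def unpack_values(data, kind):
    """Recover the flat list; the count is inferred from the byte length."""
    type_str = _FORMATS[kind]
    num = len(data) // struct.calcsize(type_str)
    return list(struct.unpack('%d%s' % (num, type_str), data))


ids = [101, 2023, 102, 0]
payload = pack_values(ids, 'int64')
restored = unpack_values(payload, 'int64')
```

The shape travels separately (the `dims` field in the source), so the byte payload only needs to carry the flattened values.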
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
import paddle.fluid as F
from propeller.paddle.collection import default_collection, Key
def scalar(name, tensor):
if not isinstance(tensor, F.framework.Variable):
raise ValueError('expect paddle Variable, got %s' % repr(tensor))
tensor.persistable = True
default_collection().add(Key.SUMMARY_SCALAR, (name, tensor))
def histogram(name, tensor):
if not isinstance(tensor, F.framework.Variable):
raise ValueError('expect paddle Variable, got %s' % repr(tensor))
tensor.persistable = True
default_collection().add(Key.SUMMARY_HISTOGRAM, (name, tensor))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import os
import sys
import logging
from time import time
log = logging.getLogger(__name__)
from propeller.paddle.train.monitored_executor import *
from propeller.paddle.train.trainer import *
from propeller.paddle.train.hooks import *
from propeller.train.model import Model
from propeller.paddle.train import exporter
from propeller.paddle.train import distribution
from propeller.paddle.train import metrics
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import functools
import six
import logging
from time import sleep
import paddle.fluid as F
import paddle.fluid.layers as L
log = logging.getLogger(__name__)
import propeller.util
__all__ = ['init_distribuition_env', 'status']
status = None
class DistributionMode(object):
LOCAL = 0
NCCL = 1
class DistributionStatus(object):
def __init__(self, config):
if config is None:
self._mode = DistributionMode.LOCAL
self._env = None
self._this = None
else:
try:
self._mode = DistributionMode.NCCL
cluster = config['cluster']
task = config['task']['type']
idx = int(config['task']['index'])
self._this = cluster[task][idx]
self._env = cluster['chief'] + cluster['worker']
if len(set(self._env)) != len(self._env):
raise ValueError('duplicate host in dis_config %s' %
config)
except KeyError as e:
raise ValueError(
'PROPELLER_DISCONFIG wrong: %s not found in %s' %
(e, repr(config)))
@property
def mode(self):
return self._mode
@property
def num_replica(self):
if self._mode == DistributionMode.LOCAL:
return 1
elif self._mode == DistributionMode.NCCL:
return len(self._env)
else:
raise ValueError('Got unknown distribution mode %s' %
repr(self._mode))
@property
def replica_id(self):
if self._mode == DistributionMode.LOCAL:
return 0
elif self._mode == DistributionMode.NCCL:
return self._env.index(self._this)
else:
raise ValueError('Got unknown distribution mode %s' %
repr(self._mode))
@property
def is_master(self):
if self._mode == DistributionMode.LOCAL:
return True
elif self._mode == DistributionMode.NCCL:
return self.replica_id == 0
else:
raise ValueError('got unknown distribution mode %s' %
repr(self._mode))
dis_config = propeller.util._get_dict_from_environ_or_json_or_file(
None, 'PROPELLER_DISCONFIG')
status = DistributionStatus(dis_config)
def run_on_master(func):
"""run `func` only on the master replica; other replicas skip it and return 0"""
@functools.wraps(func)
def f(*arg, **kwargs):
"""f"""
if status is None:
raise ValueError('distribution mode unknown at this point')
if status.mode == DistributionMode.LOCAL:
r = func(*arg, **kwargs)
elif status.mode == DistributionMode.NCCL:
if status.is_master:
r = func(*arg, **kwargs)
else:
r = 0 # skip function
#MPI.COMM_WORLD.Barrier()
return r
return f
def init_distribuition_env(program):
if status.mode == DistributionMode.LOCAL:
log.info('Initializing local training')
elif status.mode == DistributionMode.NCCL:
config = F.DistributeTranspilerConfig()
config.mode = "nccl2"
F.DistributeTranspiler(config=config).transpile(
status.replica_id,
trainers=','.join(status._env),
current_endpoint=status._this,
program=program.train_program,
startup_program=program.startup_program)
log.info('Initializing distributed training with config %s' %
(repr(dis_config)))
if status.is_master:
sleep(30)
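`DistributionStatus` reads a config with a `cluster` section (lists of endpoints under `chief` and `worker`) and a `task` section (`type` and `index`), then derives the replica id from the position of the current endpoint in `chief + worker`. The sketch below reproduces that lookup on a plain dictionary. The endpoints are made-up examples, `resolve_replica` is a hypothetical helper, and passing the config as JSON is an assumption suggested by the helper name `_get_dict_from_environ_or_json_or_file`.

```python
import json

# Hypothetical PROPELLER_DISCONFIG-style payload; the endpoints are made up.
raw = json.dumps({
    'cluster': {
        'chief': ['host-a:6170'],
        'worker': ['host-b:6170', 'host-c:6170'],
    },
    'task': {
        'type': 'worker',
        'index': 1
    },
})


def resolve_replica(config):
    """Return (replica_id, num_replica, is_master) for this process."""
    cluster = config['cluster']
    task = config['task']
    this = cluster[task['type']][int(task['index'])]
    env = cluster['chief'] + cluster['worker']
    if len(set(env)) != len(env):
        raise ValueError('duplicate host in config: %r' % env)
    replica_id = env.index(this)
    return replica_id, len(env), replica_id == 0


replica_id, num_replica, is_master = resolve_replica(json.loads(raw))
```

Since the chief endpoints come first in the combined list, the first chief is always replica 0 and therefore the master.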
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
import os
import itertools
import six
import abc
import logging
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
from propeller.paddle.train import Saver
from propeller.types import InferenceSpec
log = logging.getLogger(__name__)
@six.add_metaclass(abc.ABCMeta)
class Exporter():
@abc.abstractmethod
def export(self, exe, program, eval_model_spec, eval_result, state):
raise NotImplementedError()
class BestExporter(Exporter):
def __init__(self, export_dir, cmp_fn):
self._export_dir = export_dir
self._best = None
self.cmp_fn = cmp_fn
def export(self, exe, program, eval_model_spec, eval_result, state):
log.debug('New evaluate result: %s \nold: %s' %
(repr(eval_result), repr(self._best)))
if self._best is None or self.cmp_fn(old=self._best, new=eval_result):
log.debug('[Best Exporter]: export to %s' % self._export_dir)
eval_program = program.train_program
# FIXME: all eval datasets share the same names/types/shapes for now, so every eval program is the same
saver = Saver(
self._export_dir,
exe,
program=eval_program,
max_ckpt_to_keep=1)
saver.save(state)
self._best = eval_result
else:
log.debug('[Best Exporter]: skip step %s' % state.gstep)
class BestInferenceModelExporter(Exporter):
def __init__(self, export_dir, cmp_fn):
self._export_dir = export_dir
self._best = None
self.cmp_fn = cmp_fn
def export(self, exe, program, eval_model_spec, eval_result, state):
log.debug('New evaluate result: %s \nold: %s' %
(repr(eval_result), repr(self._best)))
if self._best is None or self.cmp_fn(old=self._best, new=eval_result):
log.debug('[Best Exporter]: export to %s' % self._export_dir)
if eval_model_spec.inference_spec is None:
raise ValueError("model_fn didn't return InferenceSpec")
inf_sepc_dict = eval_model_spec.inference_spec
if not isinstance(inf_sepc_dict, dict):
inf_sepc_dict = {'inference': inf_sepc_dict}
for inf_sepc_name, inf_sepc in six.iteritems(inf_sepc_dict):
if not isinstance(inf_sepc, InferenceSpec):
raise ValueError('unknown inference spec type: %s' % repr(inf_sepc))
save_dir = os.path.join(self._export_dir, inf_sepc_name)
log.debug('[Best Exporter]: save inference model: "%s" to %s' %
(inf_sepc_name, save_dir))
feed_var = [i.name for i in inf_sepc.inputs]
fetch_var = inf_sepc.outputs
eval_program = program.train_program
startup_prog = F.Program()
F.io.save_inference_model(
save_dir,
feed_var,
fetch_var,
exe,
main_program=eval_program)
self._best = eval_result
else:
log.debug('[Best Exporter]: skip step %s' % state.gstep)
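Both exporters call `cmp_fn(old=..., new=...)` and export only when it returns a truthy value. The sketch below shows one possible `cmp_fn`, assuming the evaluation result is a dict keyed by metric name. `make_cmp_fn` and `BestTracker` are hypothetical names used for illustration; `BestTracker` mirrors only the bookkeeping, not the saving.

```python
def make_cmp_fn(key, higher_is_better=True):
    """Return cmp_fn(old, new) that is True when `new` beats `old` on `key`."""

    def cmp_fn(old, new):
        if higher_is_better:
            return new[key] > old[key]
        return new[key] < old[key]

    return cmp_fn


class BestTracker(object):
    """Hypothetical stand-in that keeps only the 'best so far' bookkeeping."""

    def __init__(self, cmp_fn):
        self._best = None
        self.cmp_fn = cmp_fn
        self.exported = []  # steps at which an export would have happened

    def export(self, eval_result, step):
        if self._best is None or self.cmp_fn(old=self._best, new=eval_result):
            self.exported.append(step)
            self._best = eval_result


tracker = BestTracker(make_cmp_fn('acc'))
for step, acc in [(100, 0.71), (200, 0.69), (300, 0.75)]:
    tracker.export({'acc': acc}, step)
# tracker.exported is now [100, 300]: step 200 did not improve on 0.71
```

For a metric where lower is better (a loss, for instance), `make_cmp_fn('loss', higher_is_better=False)` would flip the comparison.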
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
import six
import os
import itertools
import numpy as np
import logging
import paddle.fluid as F
import paddle.fluid.layers as L
from propeller import util
from propeller.paddle.train import distribution
from propeller.paddle.train.metrics import Metrics
__all__ = [
'RunHook', 'TqdmProgressBarHook', 'TqdmNotebookProgressBarHook',
'CheckpointSaverHook', 'LoggingHook', 'StopAtStepHook', 'EvalHook'
]
log = logging.getLogger(__name__)
class RunHook(object):
def __init__(self):
pass
def before_train(self):
pass
def before_run(self, state):
return []
def after_run(self, res_list, state):
pass
def should_stop(self, state):
return False
def after_train(self):
pass
class TqdmProgressBarHook(RunHook):
def __init__(self, max_steps, desc=None):
self.tqdm = None
import tqdm
from propeller import log as main_log
hdl = main_log.handlers[0]
class TqdmLogginHandler(logging.Handler):
def emit(self, record):
try:
msg = self.format(record)
tqdm.tqdm.write(msg, file=sys.stderr)
self.flush()
except (KeyboardInterrupt, SystemExit):
raise
except Exception:
self.handleError(record)
tqdm_hdl = TqdmLogginHandler()
tqdm_hdl.setFormatter(hdl.formatter)
main_log.removeHandler(hdl)
main_log.addHandler(tqdm_hdl)
self.tqdm = tqdm.tqdm(total=max_steps, desc=desc)
def before_run(self, state):
self.tqdm.n = state.gstep
return []
def __del__(self):
if self.tqdm:
self.tqdm.close()
class TqdmNotebookProgressBarHook(RunHook):
def __init__(self, max_steps, desc=None):
self.tqdm = None
import tqdm
from propeller import log as main_log
hdl = main_log.handlers[0]
class TqdmLogginHandler(logging.Handler):
def emit(self, record):
try:
msg = self.format(record)
tqdm.tqdm.write(msg, file=sys.stderr)
self.flush()
except (KeyboardInterrupt, SystemExit):
raise
except Exception:
self.handleError(record)
tqdm_hdl = TqdmLogginHandler()
tqdm_hdl.setFormatter(hdl.formatter)
main_log.removeHandler(hdl)
main_log.addHandler(tqdm_hdl)
self.tqdm = tqdm.tqdm_notebook(total=max_steps, desc=desc)
def before_run(self, state):
self.tqdm.n = state.gstep
self.tqdm.refresh()
return []
def __del__(self):
if self.tqdm:
self.tqdm.close()
class LoggingHook(RunHook):
def __init__(self,
loss,
per_step=10,
skip_step=100,
summary_writer=None,
summary_record=None):
if per_step is None or skip_step is None:
raise ValueError('wrong step argument, per_step: %s skip_step: %s' %
(repr(per_step), repr(skip_step)))
self.loss = loss
self.per_step = per_step
self.skip_step = skip_step
self.summary_record = summary_record
self.writer = summary_writer
self.last_state = None
def before_train(self):
if self.summary_record:
if self.summary_record.scalar:
self.s_name, self.s_tolog = zip(*self.summary_record.scalar)
else:
self.s_name, self.s_tolog = [], []
if self.summary_record.histogram:
self.h_name, self.h_tolog = zip(*self.summary_record.histogram)
else:
self.h_name, self.h_tolog = [], []
def before_run(self, state):
if state.gstep % self.per_step == 0 and state.step > self.skip_step:
ret = [self.loss]
if self.summary_record:
ret += self.s_tolog
ret += self.h_tolog
return ret
else:
return []
def after_run(self, res_list, state):
if state.gstep % self.per_step == 0 and state.step > self.skip_step:
if not self.summary_record:
return
loss = float(res_list[0])
s_np = res_list[1:1 + len(self.s_name)]
h_np = res_list[1 + len(self.s_name):1 + len(self.s_name) + len(
self.h_name)]
if self.last_state is not None:
speed = (state.gstep - self.last_state.gstep) / (
state.time - self.last_state.time)
else:
speed = -1.
self.last_state = state
# log to tensorboard
if self.writer is not None:
self.writer.add_scalar('loss', loss, state.gstep)
for name, t in zip(self.s_name, s_np):
if np.isnan(t).any():
log.warning('Nan summary: %s, skip' % name)
else:
self.writer.add_scalar(name, t, state.gstep)
for name, t in zip(self.h_name, h_np):
if np.isnan(t).any():
log.warning('Nan summary: %s, skip' % name)
else:
self.writer.add_histogram(name, t, state.gstep)
if speed > 0.:
self.writer.add_scalar('global_step', speed, state.gstep)
# log to stdout
log.debug('\t'.join([
'step: %d' % state.gstep,
'steps/sec: %.5f' % speed,
'loss: %.5f' % loss,
'' if self.summary_record is None else ' '.join(
map(lambda t: '%s:%s' % t, zip(self.s_name, s_np))),
]))
class StopAtStepHook(RunHook):
def __init__(self, stop_global_step, stop_step):
self._stop_gstep = stop_global_step
self._stop_step = stop_step
def should_stop(self, state):
if (self._stop_gstep and state.gstep >= self._stop_gstep) or \
(self._stop_step and state.step >= self._stop_step):
log.info('StopAtStepHook called stop')
return True
else:
return False
class EvalHook(RunHook):
"""hook this on a eval Executor"""
def __init__(self, metrics, summary_writer=None):
self.writer = summary_writer
self._result = None
if not isinstance(metrics, dict):
raise ValueError('metrics should be dict, got %s' % repr(metrics))
for k, m in six.iteritems(metrics):
if not isinstance(m, Metrics):
raise ValueError(
'metrics %s should be instance of propeller.Metrics, got %s'
% (k, repr(m)))
if len(metrics):
self.names = list(metrics.keys())
self.metrics = list(metrics.values())
else:
self.names, self.metrics = [], []
def before_train(self):
for m in self.metrics:
m.reset()
def before_run(self, state):
ls = [m.tensor for m in self.metrics]
for i in ls:
if not (isinstance(i, list) or isinstance(i, tuple)):
raise ValueError(
'metrics should return tuple or list of tensors, got %s' %
repr(i))
for ii in i:
if not isinstance(ii, F.framework.Variable):
raise ValueError(
'metrics tensor should be a paddle Variable, got %s of type %s'
% (repr(ii), type(ii)))
ls_flt, self.schema = util.flatten(ls)
#log.debug(ls_flt)
return ls_flt
def after_run(self, res_list, state):
res = util.unflatten(res_list, self.schema)
for r, m in zip(res, self.metrics):
m.update(r)
@property
def result(self):
return self._result
def after_train(self):
printable = []
self._result = {}
for n, m in zip(self.names, self.metrics):
val = m.eval()
self._result[n] = val
return self.result
class CheckpointSaverHook(RunHook):
def __init__(self, saver, per_step=10, skip_step=100):
self.saver = saver
self.per_step = per_step
self.skip_step = skip_step
def after_run(self, res_list, state):
if state.gstep % self.per_step == 0 and \
state.step > self.skip_step:
self.saver.save(state)
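The hook protocol (`before_train`, `before_run`, `after_run`, `should_stop`, `after_train`) can be exercised without Paddle. The sketch below is illustrative only: `FakeState`, `EveryNStepsHook`, `StopAt` and `toy_loop` are hypothetical names, and `toy_loop` only imitates the order in which a training loop would call the hooks. It reuses the same `per_step` / `skip_step` condition as `LoggingHook` and `CheckpointSaverHook`.

```python
class RunHook(object):
    """Trimmed copy of the hook interface."""

    def before_train(self):
        pass

    def before_run(self, state):
        return []

    def after_run(self, res_list, state):
        pass

    def should_stop(self, state):
        return False

    def after_train(self):
        pass


class FakeState(object):
    """Hypothetical stand-in for RunState; only carries step counters."""

    def __init__(self, step):
        self.step = step
        self.gstep = step


class EveryNStepsHook(RunHook):
    """Records the global step every `per_step` steps once past `skip_step`."""

    def __init__(self, per_step, skip_step):
        self.per_step = per_step
        self.skip_step = skip_step
        self.fired = []

    def after_run(self, res_list, state):
        if state.gstep % self.per_step == 0 and state.step > self.skip_step:
            self.fired.append(state.gstep)


class StopAt(RunHook):
    """Asks the loop to stop once `stop_step` local steps have run."""

    def __init__(self, stop_step):
        self.stop_step = stop_step

    def should_stop(self, state):
        return state.step >= self.stop_step


def toy_loop(hooks, max_iter=1000):
    """Toy driver: calls hooks in the order a training loop would."""
    for h in hooks:
        h.before_train()
    step = 0
    while step < max_iter:
        state = FakeState(step)
        for h in hooks:
            h.before_run(state)
        for h in hooks:
            h.after_run([], state)
        if any(h.should_stop(state) for h in hooks):
            break
        step += 1
    for h in hooks:
        h.after_train()
    return step


every = EveryNStepsHook(per_step=10, skip_step=25)
last_step = toy_loop([every, StopAt(50)])
# every.fired is [30, 40, 50]: steps 10 and 20 fall inside the skip window
```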
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import numpy as np
import itertools
import logging
import paddle.fluid as F
import paddle.fluid.layers as L
import sklearn.metrics
log = logging.getLogger(__name__)
__all__ = [
'Metrics', 'F1', 'Recall', 'Precision', 'Mrr', 'Mean', 'Acc', 'ChunkF1',
'RecallAtPrecision'
]
class Metrics(object):
def __init__(self):
self.saver = []
@property
def tensor(self):
pass
def update(self, *args):
pass
def eval(self):
pass
class Mean(Metrics):
def __init__(self, t):
self.t = t
self.reset()
def reset(self):
self.saver = np.array([])
@property
def tensor(self):
self.t.persistable = True
return self.t,
def update(self, args):
t, = args
t = t.reshape([-1])
self.saver = np.concatenate([self.saver, t])
def eval(self):
return self.saver.mean()
class Ppl(Mean):
def eval(self):
return np.exp(self.saver.mean())
class Acc(Mean):
def __init__(self, label, pred):
self.eq = L.equal(pred, label)
self.reset()
@property
def tensor(self):
self.eq.persistable = True
return self.eq,
class MSE(Mean):
def __init__(self, label, pred):
diff = pred - label
self.mse = diff * diff
self.reset()
@property
def tensor(self):
self.mse.persistable = True
return self.mse,
class Cosine(Mean):
def __init__(self, label, pred):
self.cos = L.cos_sim(label, pred)
self.reset()
@property
def tensor(self):
self.cos.persistable = True
return self.cos,
class Precision(Metrics):
def __init__(self, label, pred):
self.label = label
self.pred = pred
self.reset()
def reset(self):
self.label_saver = np.array([], dtype=np.bool_)
self.pred_saver = np.array([], dtype=np.bool_)
@property
def tensor(self):
self.label.persistable = True
self.pred.persistable = True
return self.label, self.pred
def update(self, args):
label, pred = args
label = label.reshape([-1]).astype(np.bool_)
pred = pred.reshape([-1]).astype(np.bool_)
if label.shape != pred.shape:
raise ValueError(
'Metrics precision: input shapes do not match: label:%s pred:%s' %
(label.shape, pred.shape))
self.label_saver = np.concatenate([self.label_saver, label])
self.pred_saver = np.concatenate([self.pred_saver, pred])
def eval(self):
tp = (self.label_saver & self.pred_saver).astype(np.int64).sum()
p = self.pred_saver.astype(np.int64).sum()  # predicted positives
return tp / p
class Recall(Precision):
def eval(self):
tp = (self.label_saver & self.pred_saver).astype(np.int64).sum()
p = (self.label_saver).astype(np.int64).sum()
return tp / p
class F1(Precision):
def eval(self):
tp = (self.label_saver & self.pred_saver).astype(np.int64).sum()
t = self.label_saver.astype(np.int64).sum()  # actual positives
p = self.pred_saver.astype(np.int64).sum()  # predicted positives
precision = tp / (p + 1.e-6)
recall = tp / (t + 1.e-6)
return 2 * precision * recall / (precision + recall + 1.e-6)
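The definitions behind `Precision`, `Recall` and `F1` can be checked with a small worked example. This is a plain-Python sketch, not the classes' code: `precision_recall_f1` is a hypothetical helper, and it guards against empty denominators with explicit checks instead of the `1.e-6` epsilon used in `F1`.

```python
def precision_recall_f1(labels, preds):
    """Compute precision, recall and F1 for binary labels and predictions."""
    tp = sum(1 for l, p in zip(labels, preds) if l and p)  # true positives
    predicted_pos = sum(1 for p in preds if p)
    actual_pos = sum(1 for l in labels if l)
    precision = tp / float(predicted_pos) if predicted_pos else 0.0
    recall = tp / float(actual_pos) if actual_pos else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1


labels = [1, 1, 1, 0, 0, 0, 0, 1]
preds = [1, 0, 1, 1, 0, 0, 0, 0]
precision, recall, f1 = precision_recall_f1(labels, preds)
# tp = 2, predicted positives = 3, actual positives = 4
# precision = 2/3, recall = 1/2, f1 = 4/7
```

Precision divides by the number of predicted positives and recall by the number of actual positives; F1 is symmetric in the two, so swapping them changes precision and recall individually but not F1.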
class Auc(Metrics):
def __init__(self, label, pred):
self.pred = pred
self.label = label
self.reset()
def reset(self):
self.pred_saver = np.array([], dtype=np.float32)
self.label_saver = np.array([], dtype=np.bool_)
@property
def tensor(self):
self.pred.persistable = True
self.label.persistable = True
return [self.pred, self.label]
def update(self, args):
pred, label = args
pred = pred.reshape([-1]).astype(np.float32)
label = label.reshape([-1]).astype(np.bool_)
self.pred_saver = np.concatenate([self.pred_saver, pred])
self.label_saver = np.concatenate([self.label_saver, label])
def eval(self):
fpr, tpr, thresholds = sklearn.metrics.roc_curve(
self.label_saver.astype(np.int64), self.pred_saver)
auc = sklearn.metrics.auc(fpr, tpr)
return auc
class RecallAtPrecision(Auc):
def __init__(self, label, pred, precision=0.9):
super(RecallAtPrecision, self).__init__(label, pred)
self.precision = precision
def eval(self):
self.pred_saver = self.pred_saver.reshape(
[self.label_saver.size, -1])[:, -1]
precision, recall, thresholds = sklearn.metrics.precision_recall_curve(
self.label_saver, self.pred_saver)
for p, r in zip(precision, recall):
if p > self.precision:
return r
class PrecisionAtThreshold(Auc):
def __init__(self, label, pred, threshold=0.5):
super(PrecisionAtThreshold, self).__init__(label, pred)
self.threshold = threshold
def eval(self):
infered = self.pred_saver > self.threshold
correct_num = np.array(infered & self.label_saver).sum()
infer_num = infered.sum()
return correct_num / (infer_num + 1.e-6)
class Mrr(Metrics):
def __init__(self, qid, label, pred):
self.qid = qid
self.label = label
self.pred = pred
self.reset()
def reset(self):
self.qid_saver = np.array([], dtype=np.int64)
self.label_saver = np.array([], dtype=np.int64)
self.pred_saver = np.array([], dtype=np.float32)
@property
def tensor(self):
self.qid.persistable = True
self.label.persistable = True
self.pred.persistable = True
return [self.qid, self.label, self.pred]
def update(self, args):
qid, label, pred = args
if not (qid.shape[0] == label.shape[0] == pred.shape[0]):
raise ValueError(
'Mrr dimension mismatch: qid[%s] label[%s], pred[%s]' %
(qid.shape, label.shape, pred.shape))
self.qid_saver = np.concatenate(
[self.qid_saver, qid.reshape([-1]).astype(np.int64)])
self.label_saver = np.concatenate(
[self.label_saver, label.reshape([-1]).astype(np.int64)])
self.pred_saver = np.concatenate(
[self.pred_saver, pred.reshape([-1]).astype(np.float32)])
def eval(self):
def key_func(tup):
return tup[0]
def calc_func(tup):
ranks = [
1. / (rank + 1.)
for rank, (_, l, p) in enumerate(
sorted(
tup, key=lambda t: t[2], reverse=True)) if l != 0
]
ranks = ranks[0]
return ranks
mrr_for_qid = [
calc_func(tup)
for _, tup in itertools.groupby(
sorted(
zip(self.qid_saver, self.label_saver, self.pred_saver),
key=key_func),
key=key_func)
]
mrr = np.float32(sum(mrr_for_qid) / len(mrr_for_qid))
return mrr
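The grouping-and-ranking logic of `Mrr.eval` can be reproduced in plain Python. The helper below is a sketch with a hypothetical name, `mean_reciprocal_rank`, and it makes one assumption that differs from the class: a query with no relevant item contributes `0` instead of raising an `IndexError`.

```python
import itertools


def mean_reciprocal_rank(qids, labels, preds):
    """Average over queries of 1 / rank of the first relevant item."""

    def key_func(row):
        return row[0]

    # groupby only groups adjacent rows, so sort by query id first
    rows = sorted(zip(qids, labels, preds), key=key_func)
    reciprocal_ranks = []
    for _qid, group in itertools.groupby(rows, key=key_func):
        ranked = sorted(group, key=lambda row: row[2], reverse=True)
        rr = 0.0  # assumption: no relevant item for this query counts as 0
        for rank, (_q, label, _score) in enumerate(ranked):
            if label != 0:
                rr = 1.0 / (rank + 1.0)
                break
        reciprocal_ranks.append(rr)
    return sum(reciprocal_ranks) / len(reciprocal_ranks)


qids = [1, 1, 1, 2, 2]
labels = [0, 1, 0, 1, 0]
preds = [0.9, 0.8, 0.1, 0.7, 0.3]
mrr = mean_reciprocal_rank(qids, labels, preds)
# query 1: first relevant item at rank 2 -> 0.5; query 2: rank 1 -> 1.0
# mrr = (0.5 + 1.0) / 2 = 0.75
```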
class ChunkF1(Metrics):
def __init__(self, label, pred, seqlen, num_label):
self.label = label
self.pred = pred
self.seqlen = seqlen
self.null_index = num_label - 1
self.label_cnt = 0
self.pred_cnt = 0
self.correct_cnt = 0
def _extract_bio_chunk(self, seq):
chunks = []
cur_chunk = None
for index in range(len(seq)):
tag = seq[index]
tag_type = tag // 2
tag_pos = tag % 2
if tag == self.null_index:
if cur_chunk is not None:
chunks.append(cur_chunk)
cur_chunk = None
continue
if tag_pos == 0:
if cur_chunk is not None:
chunks.append(cur_chunk)
cur_chunk = {}
cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
else:
if cur_chunk is None:
cur_chunk = {
"st": index,
"en": index + 1,
"type": tag_type
}
continue
if cur_chunk["type"] == tag_type:
cur_chunk["en"] = index + 1
else:
chunks.append(cur_chunk)
cur_chunk = {
"st": index,
"en": index + 1,
"type": tag_type
}
if cur_chunk is not None:
chunks.append(cur_chunk)
return chunks
def reset(self):
self.label_cnt = 0
self.pred_cnt = 0
self.correct_cnt = 0
@property
def tensor(self):
self.pred.persistable = True
self.label.persistable = True
self.seqlen.persistable = True
return [self.pred, self.label, self.seqlen]
def update(self, args):
pred, label, seqlen = args
pred = pred.reshape([-1]).astype(np.int32).tolist()
label = label.reshape([-1]).astype(np.int32).tolist()
seqlen = seqlen.reshape([-1]).astype(np.int32).tolist()
max_len = 0
for l in seqlen:
max_len = max(max_len, l)
for i in range(len(seqlen)):
seq_st = i * max_len + 1
seq_en = seq_st + (seqlen[i] - 2)
pred_chunks = self._extract_bio_chunk(pred[seq_st:seq_en])
label_chunks = self._extract_bio_chunk(label[seq_st:seq_en])
self.pred_cnt += len(pred_chunks)
self.label_cnt += len(label_chunks)
pred_index = 0
label_index = 0
while label_index < len(label_chunks) and pred_index < len(
pred_chunks):
if pred_chunks[pred_index]['st'] < label_chunks[label_index][
'st']:
pred_index += 1
elif pred_chunks[pred_index]['st'] > label_chunks[label_index][
'st']:
label_index += 1
else:
if pred_chunks[pred_index]['en'] == label_chunks[label_index]['en'] \
and pred_chunks[pred_index]['type'] == label_chunks[label_index]['type']:
self.correct_cnt += 1
pred_index += 1
label_index += 1
def eval(self):
if self.pred_cnt == 0:
precision = 0.0
else:
precision = 1.0 * self.correct_cnt / self.pred_cnt
if self.label_cnt == 0:
recall = 0.0
else:
recall = 1.0 * self.correct_cnt / self.label_cnt
if self.correct_cnt == 0:
f1 = 0.0
else:
f1 = 2 * precision * recall / (precision + recall)
return np.float32(f1)
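`ChunkF1` relies on a tag encoding in which `tag // 2` is the chunk type, `tag % 2` is the position (`0` begins a chunk, `1` continues it), and the last label index means "outside". The sketch below restates `_extract_bio_chunk` as a standalone function returning `(start, end, type)` tuples. `extract_bio_chunks` and `chunk_f1` are hypothetical helpers, and `chunk_f1` counts exact matches with a set intersection instead of the two-pointer scan in `update`.

```python
def extract_bio_chunks(seq, null_index):
    """Return (start, end, type) chunks from a BIO-encoded tag sequence."""
    chunks = []
    cur = None  # [start, end, type] of the chunk being built
    for index, tag in enumerate(seq):
        if tag == null_index:  # "outside" tag closes any open chunk
            if cur is not None:
                chunks.append(tuple(cur))
                cur = None
            continue
        tag_type, tag_pos = tag // 2, tag % 2
        if tag_pos == 0 or cur is None or cur[2] != tag_type:
            # a "begin" tag, or a "continue" tag that cannot extend the chunk
            if cur is not None:
                chunks.append(tuple(cur))
            cur = [index, index + 1, tag_type]
        else:
            cur[1] = index + 1  # same-type "continue" tag extends the chunk
    if cur is not None:
        chunks.append(tuple(cur))
    return chunks


def chunk_f1(label_seq, pred_seq, null_index):
    """F1 over chunks that match exactly in start, end and type."""
    gold = extract_bio_chunks(label_seq, null_index)
    pred = extract_bio_chunks(pred_seq, null_index)
    correct = len(set(gold) & set(pred))
    if correct == 0:
        return 0.0
    precision = float(correct) / len(pred)
    recall = float(correct) / len(gold)
    return 2 * precision * recall / (precision + recall)


# 5 labels: 0/1 = begin/continue type 0, 2/3 = begin/continue type 1, 4 = outside
gold_tags = [0, 1, 4, 2, 3, 3, 0]
pred_tags = [0, 1, 4, 2, 3, 4, 0]
gold_chunks = extract_bio_chunks(gold_tags, null_index=4)
score = chunk_f1(gold_tags, pred_tags, null_index=4)
# two of three chunks match exactly, so precision = recall = F1 = 2/3
```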
class PNRatio(Metrics):
def __init__(self, qid, label, pred):
self.qid = qid
self.label = label
self.pred = pred
self.saver = {}
def reset(self):
self.saver = {}
@property
def tensor(self):
self.qid.persistable = True
self.label.persistable = True
self.pred.persistable = True
return [self.qid, self.label, self.pred]
def update(self, args):
qid, label, pred = args
if not (qid.shape[0] == label.shape[0] == pred.shape[0]):
raise ValueError('dimension mismatch: qid[%s] label[%s], pred[%s]'
% (qid.shape, label.shape, pred.shape))
qid = qid.reshape([-1]).tolist()
label = label.reshape([-1]).tolist()
pred = pred.reshape([-1]).tolist()
assert len(qid) == len(label) == len(pred)
for q, l, p in zip(qid, label, pred):
if q not in self.saver:
self.saver[q] = []
self.saver[q].append((l, p))
def eval(self):
p = 0
n = 0
for qid, outputs in self.saver.items():
for i in range(0, len(outputs)):
l1, p1 = outputs[i]
for j in range(i + 1, len(outputs)):
l2, p2 = outputs[j]
if l1 > l2:
if p1 > p2:
p += 1
elif p1 < p2:
n += 1
elif l1 < l2:
if p1 < p2:
p += 1
elif p1 > p2:
n += 1
pn = p / n if n > 0 else 0.0
return np.float32(pn)
class BinaryPNRatio(PNRatio):
def __init__(self, qid, label, pred):
super(BinaryPNRatio, self).__init__(qid, label, pred)
def eval(self):
p = 0
n = 0
for qid, outputs in self.saver.items():
pos_set = []
neg_set = []
for label, score in outputs:
if label == 1:
pos_set.append(score)
else:
neg_set.append(score)
for ps in pos_set:
for ns in neg_set:
if ps > ns:
p += 1
elif ps < ns:
n += 1
else:
continue
pn = p / n if n > 0 else 0.0
return np.float32(pn)
class PrecisionAtK(Metrics):
def __init__(self, qid, label, pred, k=1):
self.qid = qid
self.label = label
self.pred = pred
self.k = k
self.saver = {}
def reset(self):
self.saver = {}
@property
def tensor(self):
self.qid.persistable = True
self.label.persistable = True
self.pred.persistable = True
return [self.qid, self.label, self.pred]
def update(self, args):
qid, label, pred = args
if not (qid.shape[0] == label.shape[0] == pred.shape[0]):
raise ValueError('dimension mismatch: qid[%s] label[%s], pred[%s]'
% (qid.shape, label.shape, pred.shape))
qid = qid.reshape([-1]).tolist()
label = label.reshape([-1]).tolist()
pred = pred.reshape([-1]).tolist()
assert len(qid) == len(label) == len(pred)
for q, l, p in zip(qid, label, pred):
if q not in self.saver:
self.saver[q] = []
self.saver[q].append((l, p))
def eval(self):
right = 0
total = 0
for v in self.saver.values():
v = sorted(v, key=lambda x: x[1], reverse=True)
k = min(self.k, len(v))
for i in range(k):
if v[i][0] == 1:
right += 1
break
total += 1
return np.float32(1.0 * right / total)
#class SemanticRecallMetrics(Metrics):
# def __init__(self, qid, vec, type_id):
# self.qid = qid
# self.vec = vec
# self.type_id = type_id
# self.reset()
#
# def reset(self):
# self.saver = []
#
# @property
# def tensor(self):
# return [self.qid, self.vec, self.type_id]
#
# def update(self, args):
# qid, vec, type_id = args
# self.saver.append((qid, vec, type_id))
#
# def eval(self):
# dic = {}
# for qid, vec, type_id in self.saver():
# dic.setdefault(i, {}).setdefault(k, []).append(vec)
#
# for qid in dic:
# assert len(dic[qid]) == 3
# qvec = np.arrray(dic[qid][0])
# assert len(qvec) == 1
# ptvec = np.array(dic[qid][1])
# ntvec = np.array(dic[qid][2])
#
# np.matmul(qvec, np.transpose(ptvec))
# np.matmul(qvec, np.transpose(ntvec))
#
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import os
import json
from functools import reduce
import six
from time import time
import shutil
import logging
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
from propeller import util
from propeller.types import StopException, ProgramPair
from propeller.paddle.train import hooks
from . import distribution
log = logging.getLogger(__name__)
__all__ = ['MonitoredExecutor', 'Saver']
class RunState(object):
@classmethod
def from_str(cls, s):
j = json.loads(s)
ret = RunState()
ret._gstep = j['global_step']
ret._time = j['time']
ret._step = 0
return ret
def __init__(self):
self._gstep = 0
self._step = 0
self._time = time()
@property
def gstep(self):
return self._gstep
@property
def step(self):
return self._step
@property
def time(self):
return self._time
def __repr__(self):
return repr({'global_step': self._gstep, 'time': self._time})
def serialize(self):
return json.dumps({'global_step': self._gstep, 'time': self._time})
def next(self):
ret = RunState()
ret._gstep = self._gstep + 1
ret._step = self._step + 1
ret._time = time()
return ret
class Saver(object):
def __init__(self,
save_dir,
exe,
program,
save_prefix='model',
max_ckpt_to_keep=None):
if exe is not None:
assert isinstance(
exe, F.Executor
), 'expect normal executor to save, got executor of type %s' % repr(
type(exe))
self._exe = exe
self._program = program
self._save_dir = save_dir
self._save_prefix = save_prefix
self._max_ckpt_to_keep = 10 if max_ckpt_to_keep is None else max_ckpt_to_keep
self.ckpt_info_path = os.path.join(save_dir, 'ckpt_info')
if os.path.exists(self.ckpt_info_path):
self.ckpt_list = [
p.strip() for p in open(self.ckpt_info_path).readlines()
]
log.debug('ckpt_list in this Saver: %s' % (self.ckpt_list))
else:
self.ckpt_list = []
@property
def last_ckpt(self):
return self.ckpt_list[-1] if len(self.ckpt_list) else None
def save(self, state):
save_name = '%s_%d' % (self._save_prefix, state.gstep)
save_dir = os.path.join(self._save_dir, save_name)
tmp_dir = os.path.join(self._save_dir, 'tmp')
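# remove leftovers independently, so a missing save_dir does not leave a stale tmp_dir behind
shutil.rmtree(save_dir, ignore_errors=True)
shutil.rmtree(tmp_dir, ignore_errors=True)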
log.debug('saving step %d to %s' % (state.gstep, save_dir))
F.io.save_persistables(self._exe, tmp_dir, self._program)
shutil.move(tmp_dir, save_dir)
meta = state.serialize()
open(os.path.join(save_dir, 'meta'), 'w').write(meta)
self.ckpt_list.append(save_name)
if len(self.ckpt_list) > self._max_ckpt_to_keep:
ckpt_to_keep = self.ckpt_list[-self._max_ckpt_to_keep:]
ckpt_to_remove = set(self.ckpt_list) - set(ckpt_to_keep)
self.ckpt_list = ckpt_to_keep
for ckpt in ckpt_to_remove:
ckpt_dir = os.path.join(self._save_dir, ckpt)
if os.path.exists(ckpt_dir):
shutil.rmtree(ckpt_dir)
log.debug('No. of ckpt exceed %d, clean up: %s' %
(self._max_ckpt_to_keep, ckpt_dir))
open(self.ckpt_info_path, 'w').write('\n'.join(self.ckpt_list))
def restore(self, ckpt=-1):
if not isinstance(ckpt, (int, ) + six.string_types):
raise ValueError('ckpt type not understood %s' % repr(ckpt))
if isinstance(ckpt, int):
try:
ckpt = self.ckpt_list[ckpt]
except IndexError:
raise ValueError('invalid restore ckpt number %d' % ckpt)
if isinstance(ckpt, six.string_types):
try:
ckpt = self.ckpt_list.index(ckpt)
except ValueError:
raise ValueError('ckpt: %s not in ckpt list: %s' %
(ckpt, self.ckpt_list))
path = os.path.join(self._save_dir, self.ckpt_list[ckpt])
meta_file = os.path.join(path, 'meta')
if not os.path.exists(meta_file):
raise RuntimeError('meta not found in restore dir: %s' % path)
state = RunState.from_str(open(meta_file).read())
log.info('restore from ckpt %s, ckpt-status: %s' % (path, repr(state)))
def fn(v):
vpath = os.path.join(path, v.name)
if F.io.is_persistable(v):
if os.path.exists(vpath):
return True
else:
log.warning('var %s not found in checkpoint, ignored' %
v.name)
return False
F.io.load_vars(
self._exe, path, main_program=self._program, predicate=fn)
return state
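The retention rule in `Saver.save` (append the new checkpoint, keep the most recent `max_ckpt_to_keep`, delete the rest) can be shown without touching the filesystem. `rotate_checkpoints` is a hypothetical helper written for this sketch, and it assumes `max_to_keep >= 1`.

```python
def rotate_checkpoints(ckpt_list, new_ckpt, max_to_keep):
    """Return (kept, removed) after appending `new_ckpt` to `ckpt_list`."""
    ckpts = list(ckpt_list) + [new_ckpt]
    if len(ckpts) <= max_to_keep:
        return ckpts, []
    kept = ckpts[-max_to_keep:]  # newest checkpoints are at the end
    removed = [c for c in ckpts if c not in kept]
    return kept, removed


kept = []
removed_log = []
for gstep in (100, 200, 300, 400):
    kept, removed = rotate_checkpoints(kept, 'model_%d' % gstep, max_to_keep=2)
    removed_log.extend(removed)
# kept is ['model_300', 'model_400']; removed_log is ['model_100', 'model_200']
```

In the real class the removed names map to directories under `save_dir`, and the kept list is written back to the `ckpt_info` file.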
class MonitoredExecutor(object):
"""A wrapper handling the train loop"""
def __init__(
self,
executor,
program,
loss=None, #must set in train
state=None,
run_config=None, #none if not load
run_hooks=[],
warm_start_setting=None):
if isinstance(executor, F.ParallelExecutor):
raise ValueError('ParallelExecutor is deprecated, use Executor')
if not isinstance(executor, F.Executor):
raise ValueError('expect paddle.fluid.Executor, got %s' % repr(type(executor)))
self._exe = executor
self._hooks = run_hooks
self._state = RunState()  # may be overwritten when a checkpoint is restored
self._program = program
self._loss = loss
self._warm_start_setting = warm_start_setting
self._saver = None # will set in prepare
self.result = None # will set after train
if run_config is not None:
self._model_dir = run_config.model_dir
self._save_dir = run_config.model_dir
self._save_steps = run_config.save_steps
self._skip_steps = run_config.skip_steps if run_config.skip_steps else 100
self._save_prefix = 'model'
self._max_ckpt = run_config.max_ckpt
@property
def state(self):
return self._state
def init_or_restore_variables(self):
# The order of these two steps matters
# 1. init train
F.Executor(F.cuda_places()[0]).run(self._program.startup_program)
# 2. restore param
if self._warm_start_setting is not None:
if not os.path.exists(self._warm_start_setting.from_dir):
raise ValueError('warm start dir not exists: %s' %
self._warm_start_setting.from_dir)
log.info("warm start from %s" % self._warm_start_setting.from_dir)
if self._warm_start_setting.predicate_fn is not None:
def fn(v):
ret = self._warm_start_setting.predicate_fn(v)
if ret:
log.info('warm start: %s' % v.name)
return ret
F.io.load_vars(
F.Executor(F.cuda_places()[0]),
self._warm_start_setting.from_dir,
main_program=self._program.train_program,
predicate=fn)
else:
raise NotImplementedError()
self._saver = Saver(
self._model_dir,
F.Executor(F.cuda_places()[0]),
program=self._program.train_program,
max_ckpt_to_keep=self._max_ckpt)
if self._saver.last_ckpt is not None:
self._state = self._saver.restore()
def freeze(self):
if self._loss is None:
log.debug('will not freeze a program without loss')
return
if isinstance(self._program.train_program, F.compiler.CompiledProgram):
log.debug('program has already been built')
return
exec_strategy = F.ExecutionStrategy()
exec_strategy.num_threads = 4 #2 for fp32 4 for fp16
exec_strategy.use_experimental_executor = True
exec_strategy.num_iteration_per_drop_scope = 10  # important: iterations between cleanups of temporary variables
build_strategy = F.BuildStrategy()
build_strategy.remove_unnecessary_lock = False
#build_strategy.fuse_broadcast_ops = True
build_strategy.num_trainers = distribution.status.num_replica
build_strategy.trainer_id = distribution.status.replica_id
build_strategy.memory_optimize = True
log.info('replica id %d of %d' % (distribution.status.replica_id,
distribution.status.num_replica))
program = F.CompiledProgram(
self._program.train_program).with_data_parallel(
loss_name=self._loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
self._program = ProgramPair(
train_program=program,
startup_program=self._program.startup_program)
def __enter__(self):
log.debug('freezing program')
self.freeze()
log.debug('done freezing')
log.info('********** Start Loop ************')
# TODO init
self.result = None
for h in self._hooks:
log.debug('train loop has hook %s' % h)
h.before_train()
return self
def run(self, fetch_list=[], *args, **kwargs):
#log.debug('Executor running step %d' % self._state.gstep)
if self._hooks:
fetch_list = [fetch_list]
for h in self._hooks:
#log.debug('calling hook.before_run %s' % h)
fetch = h.before_run(self._state)
fetch_list.append(fetch)
fetch_list_len = map(len, fetch_list)
fetch_list, schema = util.flatten(fetch_list)
fetch_list = [
f.name if not isinstance(f, six.string_types) else f
for f in fetch_list
]
#if len(set(fetch_list)) != len(fetch_list):
# log.error('unexpected duplicate tensors in fetch list %s' % fetch_list)
res = self._exe.run(self._program.train_program,
fetch_list=fetch_list,
*args,
**kwargs)
res = [self.merge_result(r) for r in res]
#log.debug(res)
res = util.unflatten(res, schema)
ret, res = res[0], res[1:]
for r, h in zip(res, self._hooks):
#log.debug('calling hook.after_run')
h.after_run(r, self._state)
if any(map(lambda i: i.should_stop(self._state), self._hooks)):
raise StopException('hook call stop')
else:
ret = self._exe.run(self._program.train_program,
fetch_list=fetch_list,
*args,
**kwargs)
self._state = self._state.next()
return ret
def __exit__(self, err_type, err_value, trace):
if (err_type is None) or isinstance(err_value, (
F.core.EOFException, StopException, KeyboardInterrupt)):
try:
log.info('********** Stop Loop ************')
self.result = []
for h in self._hooks:
self.result.append(h.after_train())
except Exception as e:
log.exception('error occurred after loop %s' % repr(e))
else:
log.info('********** Interrupt Loop ************')
log.exception('error occurred during loop %s: %s' %
(err_type, err_value))
def merge_result(self, ls):
dev_count = len(self._program.train_program._places) if isinstance(
self._program.train_program, F.compiler.CompiledProgram) else 1
if dev_count == 1:
return ls
else:
shape = (-1, ls.shape[0] // dev_count) + ls.shape[1:]
ret = np.reshape(ls, shape).mean(axis=0)
return ret
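`MonitoredExecutor.run` merges the caller's fetch list with one fetch list per hook, runs everything in a single executor call, then splits the results back so each hook sees only what it asked for. The sketch below shows that round trip with hypothetical one-level `flatten` / `unflatten` helpers. The real `propeller.util` functions are not shown here and may handle deeper nesting, so treat these as an assumption about their behavior.

```python
def flatten(nested):
    """Flatten a list of lists; schema records the length of each group."""
    flat = [item for group in nested for item in group]
    schema = [len(group) for group in nested]
    return flat, schema


def unflatten(flat, schema):
    """Inverse of flatten: split `flat` into groups of the recorded lengths."""
    groups, start = [], 0
    for length in schema:
        groups.append(flat[start:start + length])
        start += length
    return groups


user_fetch = ['loss']
hook_fetches = [[], ['loss', 'lr'], ['acc']]  # one list per hook
flat, schema = flatten([user_fetch] + hook_fetches)
results = [name.upper() for name in flat]  # pretend this is executor output
groups = unflatten(results, schema)
user_result, per_hook = groups[0], groups[1:]
# per_hook[i] lines up with hook i, including hooks that fetched nothing
```

Keeping an empty group for hooks that return `[]` from `before_run` is what keeps `zip(res, self._hooks)` aligned.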
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import asyncio
import threading
import math
import zmq
import zmq.asyncio
import numpy as np
from propeller import log
import propeller.service.utils as serv_utils
class InferenceBaseClient(object):
    def __init__(self, address):
        self.context = zmq.Context()
        self.address = address
        self.socket = self.context.socket(zmq.REQ)
        self.socket.connect(address)
        log.info("Connecting to server... %s" % address)

    def __call__(self, *args):
        for arg in args:
            if not isinstance(arg, np.ndarray):
                raise ValueError('expect ndarray slot data, got %s' %
                                 repr(arg))
        request = serv_utils.nparray_list_serialize(args)
        self.socket.send(request)
        reply = self.socket.recv()
        ret = serv_utils.nparray_list_deserialize(reply)
        return ret
class InferenceClient(InferenceBaseClient):
    def __init__(self, address, batch_size=128, num_coroutine=10, timeout=10.):
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        context = zmq.asyncio.Context()
        self.socket_pool = [
            context.socket(zmq.REQ) for _ in range(num_coroutine)
        ]
        log.info("Connecting to server... %s" % address)
        for socket in self.socket_pool:
            socket.connect(address)
        self.num_coroutine = num_coroutine
        self.batch_size = batch_size
        # timeout is given in seconds; zmq poll expects milliseconds
        self.timeout = int(timeout * 1000)

    #yapf: disable
    def __call__(self, *args):
        for arg in args:
            if not isinstance(arg, np.ndarray):
                raise ValueError('expect ndarray slot data, got %s' %
                                 repr(arg))
        num_tasks = int(math.ceil(1. * args[0].shape[0] / self.batch_size))
        rets = [None] * num_tasks

        async def get(coroutine_idx=0, num_coroutine=1):
            # Coroutine i sends batches i, i + num_coroutine, ... on its own socket.
            socket = self.socket_pool[coroutine_idx]
            while coroutine_idx < num_tasks:
                begin = coroutine_idx * self.batch_size
                end = (coroutine_idx + 1) * self.batch_size
                arr_list = [arg[begin:end] for arg in args]
                request = serv_utils.nparray_list_serialize(arr_list)
                try:
                    await socket.send(request)
                    await socket.poll(self.timeout, zmq.POLLIN)
                    reply = await socket.recv(zmq.NOBLOCK)
                    ret = serv_utils.nparray_list_deserialize(reply)
                except Exception as e:
                    log.exception(e)
                    ret = None
                rets[coroutine_idx] = ret
                coroutine_idx += num_coroutine

        futures = [
            get(i, self.num_coroutine) for i in range(self.num_coroutine)
        ]

        async def run_all():
            # asyncio.wait() no longer accepts bare coroutines on Python 3.11+,
            # so gather them inside a coroutine driven by the client's loop.
            await asyncio.gather(*futures)

        self.loop.run_until_complete(run_all())
        for r in rets:
            if r is None:
                raise RuntimeError('Client call failed')
        return [np.concatenate(col, 0) for col in zip(*rets)]
    #yapf: enable
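The batching scheme in `InferenceClient.__call__` can be tried without a server. The following is a minimal sketch, assuming a stand-in `fake_infer` coroutine in place of the ZeroMQ round trip; `fake_infer` and `batched_call` are hypothetical names, not part of propeller. It shows how each coroutine strides over the batch indices and how results are reassembled in input order.

```python
import asyncio
import math


async def fake_infer(batch):
    # Stand-in for the socket send/recv round trip (assumption for this sketch).
    await asyncio.sleep(0)
    return [x * 2 for x in batch]


def batched_call(data, batch_size=3, num_coroutine=2):
    num_tasks = math.ceil(len(data) / batch_size)
    rets = [None] * num_tasks

    async def worker(idx):
        # Worker idx handles batches idx, idx + num_coroutine, ...
        while idx < num_tasks:
            begin, end = idx * batch_size, (idx + 1) * batch_size
            rets[idx] = await fake_infer(data[begin:end])
            idx += num_coroutine

    async def run_all():
        await asyncio.gather(*(worker(i) for i in range(num_coroutine)))

    asyncio.run(run_all())
    # Each result is stored at its batch index, so order is preserved.
    return [y for ret in rets for y in ret]


result = batched_call(list(range(8)))
```

Because every batch writes into its own slot of `rets`, the output order does not depend on which coroutine finishes first.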
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package interface;

service Inference {
  rpc Infer(Slots) returns (Slots){}
}

message Slots {
  repeated Slot slots = 1;
}

message Slot {
  enum Type {
    // Pod Types
    BOOL = 0;
    INT16 = 1;
    INT32 = 2;
    INT64 = 3;
    FP16 = 4;
    FP32 = 5;
    FP64 = 6;
    // Tensor<size_t> is used in C++.
    SIZE_T = 19;
    UINT8 = 20;
    INT8 = 21;
  }
  Type type = 1;
  repeated int64 dims = 2;  // [UNK, 640, 480] is saved as [-1, 640, 480]
  bytes data = 3;
}
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: interface.proto
import sys
_b = sys.version_info[0] < 3 and (lambda x: x) or (
lambda x: x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='interface.proto',
package='interface',
syntax='proto3',
serialized_options=None,
serialized_pb=_b(
'\n\x0finterface.proto\x12\tinterface\"\'\n\x05Slots\x12\x1e\n\x05slots\x18\x01 \x03(\x0b\x32\x0f.interface.Slot\"\xb8\x01\n\x04Slot\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.interface.Slot.Type\x12\x0c\n\x04\x64ims\x18\x02 \x03(\x03\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"p\n\x04Type\x12\x08\n\x04\x42OOL\x10\x00\x12\t\n\x05INT16\x10\x01\x12\t\n\x05INT32\x10\x02\x12\t\n\x05INT64\x10\x03\x12\x08\n\x04\x46P16\x10\x04\x12\x08\n\x04\x46P32\x10\x05\x12\x08\n\x04\x46P64\x10\x06\x12\n\n\x06SIZE_T\x10\x13\x12\t\n\x05UINT8\x10\x14\x12\x08\n\x04INT8\x10\x15\x32:\n\tInference\x12-\n\x05Infer\x12\x10.interface.Slots\x1a\x10.interface.Slots\"\x00\x62\x06proto3'
))
_SLOT_TYPE = _descriptor.EnumDescriptor(
name='Type',
full_name='interface.Slot.Type',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='BOOL', index=0, number=0, serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT16',
index=1,
number=1,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT32',
index=2,
number=2,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT64',
index=3,
number=3,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FP16', index=4, number=4, serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FP32', index=5, number=5, serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FP64', index=6, number=6, serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='SIZE_T',
index=7,
number=19,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='UINT8',
index=8,
number=20,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT8',
index=9,
number=21,
serialized_options=None,
type=None),
],
containing_type=None,
serialized_options=None,
serialized_start=144,
serialized_end=256, )
_sym_db.RegisterEnumDescriptor(_SLOT_TYPE)
_SLOTS = _descriptor.Descriptor(
name='Slots',
full_name='interface.Slots',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='slots',
full_name='interface.Slots.slots',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=30,
serialized_end=69, )
_SLOT = _descriptor.Descriptor(
name='Slot',
full_name='interface.Slot',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type',
full_name='interface.Slot.type',
index=0,
number=1,
type=14,
cpp_type=8,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='dims',
full_name='interface.Slot.dims',
index=1,
number=2,
type=3,
cpp_type=2,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='data',
full_name='interface.Slot.data',
index=2,
number=3,
type=12,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b(""),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[_SLOT_TYPE, ],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=72,
serialized_end=256, )
_SLOTS.fields_by_name['slots'].message_type = _SLOT
_SLOT.fields_by_name['type'].enum_type = _SLOT_TYPE
_SLOT_TYPE.containing_type = _SLOT
DESCRIPTOR.message_types_by_name['Slots'] = _SLOTS
DESCRIPTOR.message_types_by_name['Slot'] = _SLOT
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
Slots = _reflection.GeneratedProtocolMessageType(
'Slots',
(_message.Message, ),
{
'DESCRIPTOR': _SLOTS,
'__module__': 'interface_pb2'
# @@protoc_insertion_point(class_scope:interface.Slots)
})
_sym_db.RegisterMessage(Slots)
Slot = _reflection.GeneratedProtocolMessageType(
'Slot',
(_message.Message, ),
{
'DESCRIPTOR': _SLOT,
'__module__': 'interface_pb2'
# @@protoc_insertion_point(class_scope:interface.Slot)
})
_sym_db.RegisterMessage(Slot)
_INFERENCE = _descriptor.ServiceDescriptor(
name='Inference',
full_name='interface.Inference',
file=DESCRIPTOR,
index=0,
serialized_options=None,
serialized_start=258,
serialized_end=316,
methods=[
_descriptor.MethodDescriptor(
name='Infer',
full_name='interface.Inference.Infer',
index=0,
containing_service=None,
input_type=_SLOTS,
output_type=_SLOTS,
serialized_options=None, ),
])
_sym_db.RegisterServiceDescriptor(_INFERENCE)
DESCRIPTOR.services_by_name['Inference'] = _INFERENCE
# @@protoc_insertion_point(module_scope)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import os
import logging
import six
from time import sleep, time
import multiprocessing
import zmq
""" Never import paddle.fluid in the main process, nor any module that would import fluid.
"""
log = logging.getLogger(__name__)
def profile(msg):
    def decfn(fn):
        def retfn(*args, **kwargs):
            start = time()
            ret = fn(*args, **kwargs)
            end = time()
            log.debug('%s timecost: %.5f' % (msg, end - start))
            return ret

        return retfn

    return decfn


class Predictor(object):
    def __init__(self, model_dir, device_idx=0):
        import paddle.fluid as F
        log.debug('create predictor on card %d' % device_idx)
        config = F.core.AnalysisConfig(model_dir)
        config.enable_use_gpu(5000, device_idx)
        self._predictor = F.core.create_paddle_predictor(config)

    @profile('paddle')
    def __call__(self, args):
        for i, a in enumerate(args):
            a.name = 'placeholder_%d' % i
        res = self._predictor.run(args)
        return res
def run_worker(model_dir, device_idx, endpoint="ipc://worker.ipc"):
    try:
        log.debug("run_worker %s" % device_idx)
        # Expose only this worker's card, so it is always device 0 below.
        os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv(
            "CUDA_VISIBLE_DEVICES").split(",")[device_idx]
        log.debug('cuda_env %s' % os.environ["CUDA_VISIBLE_DEVICES"])
        import paddle.fluid as F
        from propeller.service import interface_pb2
        import propeller.service.utils as serv_utils
        context = zmq.Context()
        socket = context.socket(zmq.REP)
        socket.connect(endpoint)
        #socket.bind(endpoint)
        log.debug("Predictor building %s" % device_idx)
        predictor = Predictor(model_dir, 0)
        log.debug("Predictor %s" % device_idx)
    except Exception as e:
        log.exception(e)
        # Without a socket and predictor the serving loop cannot run.
        return

    while True:
        # Wait for next request from client
        try:
            message = socket.recv()
            log.debug("get message %s" % device_idx)
            slots = interface_pb2.Slots()
            slots.ParseFromString(message)
            pts = [serv_utils.slot_to_paddlearray(s) for s in slots.slots]
            ret = predictor(pts)
            slots = interface_pb2.Slots(
                slots=[serv_utils.paddlearray_to_slot(r) for r in ret])
            socket.send(slots.SerializeToString())
        except Exception as e:
            log.exception(e)
            # Exceptions have no .message on Python 3, and send() needs bytes.
            socket.send(str(e).encode('utf-8'))
class InferencePredictor(object):
    def __init__(self, backend_addr, model_dir, n_devices=1):
        self.backend_addr = backend_addr
        self.model_dir = model_dir
        self.n_devices = n_devices
        self.children = []

    def start(self):
        for device_idx in range(self.n_devices):
            p = multiprocessing.Process(
                target=run_worker,
                args=(self.model_dir, device_idx, self.backend_addr))
            p.start()
            self.children.append(p)
        return self

    def join(self):
        for p in self.children:
            p.join()

    def term(self):
        for p in self.children:
            log.debug("terminating children %s" % repr(p))
            p.terminate()


class InferenceProxy(object):
    def __init__(self):
        self.backend = None
        self.frontend = None

    def listen(self, frontend_addr, backend_addr):
        log.info("InferenceProxy starting...")
        try:
            context = zmq.Context(1)
            # Socket facing clients
            self.frontend = context.socket(zmq.ROUTER)
            self.frontend.bind(frontend_addr)
            # Socket facing services
            self.backend = context.socket(zmq.DEALER)
            self.backend.bind(backend_addr)
            log.info("Queue init done")
            zmq.device(zmq.QUEUE, self.frontend, self.backend)
        except Exception as e:
            log.exception(e)
            log.info("Bringing down zmq device")
        finally:
            log.debug('terminating proxy')
            if self.frontend is not None:
                self.frontend.close()
            if self.backend is not None:
                self.backend.close()
            context.term()


class InferenceServer(object):
    def __init__(self, model_dir, n_devices):
        self.model_dir = model_dir
        self.n_devices = n_devices

    def listen(self, port):
        frontend_addr = "tcp://*:%s" % port
        backend_addr = "ipc://backend.ipc"
        predictor = InferencePredictor(backend_addr, self.model_dir,
                                       self.n_devices).start()
        try:
            proxy = InferenceProxy()
            proxy.listen(frontend_addr, backend_addr)
            predictor.join()
        except KeyboardInterrupt:
            log.debug('terminating server')
            predictor.term()
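`run_worker` narrows `CUDA_VISIBLE_DEVICES` to a single card before importing `paddle.fluid`, which is why each `Predictor` is created on device 0 inside its own process. A minimal sketch of that selection step follows, with an explicit check for an unset variable added as an assumption; `pick_visible_device` is a hypothetical helper, not part of propeller.

```python
def pick_visible_device(env_value, worker_idx):
    """Return the single device id that worker `worker_idx` should expose."""
    devices = [d.strip() for d in (env_value or '').split(',') if d.strip()]
    if not devices:
        raise ValueError('CUDA_VISIBLE_DEVICES is empty or unset')
    if not 0 <= worker_idx < len(devices):
        raise IndexError('no device for worker %d' % worker_idx)
    return devices[worker_idx]


# Three workers sharing the cards listed in the parent's environment.
assignments = [pick_visible_device('2,5,7', i) for i in range(3)]
```

Each worker would then write its assignment back to `os.environ["CUDA_VISIBLE_DEVICES"]` before any CUDA-aware import, as the original code does.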
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import struct
from propeller.service import interface_pb2
def slot_to_numpy(slot):
    if slot.type == interface_pb2.Slot.FP32:
        dtype = np.float32
        type_str = 'f'
    elif slot.type == interface_pb2.Slot.INT32:
        type_str = 'i'
        dtype = np.int32
    elif slot.type == interface_pb2.Slot.INT64:
        dtype = np.int64
        type_str = 'q'
    else:
        raise RuntimeError('unknown type %s' % slot.type)
    # slot.data holds raw native-order bytes; recover the element count first
    num = len(slot.data) // struct.calcsize(type_str)
    arr = struct.unpack('%d%s' % (num, type_str), slot.data)
    shape = slot.dims
    ret = np.array(arr, dtype=dtype).reshape(shape)
    return ret


def numpy_to_slot(arr):
    if arr.dtype == np.float32:
        dtype = interface_pb2.Slot.FP32
    elif arr.dtype == np.int32:
        dtype = interface_pb2.Slot.INT32
    elif arr.dtype == np.int64:
        dtype = interface_pb2.Slot.INT64
    else:
        raise RuntimeError('unknown type %s' % arr.dtype)
    pb = interface_pb2.Slot(
        type=dtype, dims=list(arr.shape), data=arr.tobytes())
    return pb
def slot_to_paddlearray(slot):
    import paddle.fluid.core as core
    if slot.type == interface_pb2.Slot.FP32:
        type_str = 'f'
        dtype = core.PaddleDType.FLOAT32
    elif slot.type == interface_pb2.Slot.INT32:
        type_str = 'i'
        dtype = core.PaddleDType.INT32
    elif slot.type == interface_pb2.Slot.INT64:
        type_str = 'q'
        dtype = core.PaddleDType.INT64
    else:
        raise RuntimeError('unknown type %s' % slot.type)
    ret = core.PaddleTensor()
    ret.shape = slot.dims
    ret.dtype = dtype
    num = len(slot.data) // struct.calcsize(type_str)
    arr = struct.unpack('%d%s' % (num, type_str), slot.data)
    ret.data = core.PaddleBuf(arr)
    return ret


def paddlearray_to_slot(arr):
    import paddle.fluid.core as core
    if arr.dtype == core.PaddleDType.FLOAT32:
        dtype = interface_pb2.Slot.FP32
        type_str = 'f'
        arr_data = arr.data.float_data()
    elif arr.dtype == core.PaddleDType.INT32:
        dtype = interface_pb2.Slot.INT32
        type_str = 'i'
        arr_data = arr.data.int32_data()
    elif arr.dtype == core.PaddleDType.INT64:
        dtype = interface_pb2.Slot.INT64
        type_str = 'q'
        arr_data = arr.data.int64_data()
    else:
        raise RuntimeError('unknown type %s' % arr.dtype)
    data = struct.pack('%d%s' % (len(arr_data), type_str), *arr_data)
    pb = interface_pb2.Slot(type=dtype, dims=list(arr.shape), data=data)
    return pb
def nparray_list_serialize(arr_list):
    slot_list = [numpy_to_slot(arr) for arr in arr_list]
    slots = interface_pb2.Slots(slots=slot_list)
    return slots.SerializeToString()


def nparray_list_deserialize(string):
    slots = interface_pb2.Slots()
    slots.ParseFromString(string)
    return [slot_to_numpy(slot) for slot in slots.slots]
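The byte-level round trip used by `numpy_to_slot` and `slot_to_numpy` can be checked without protobuf. The sketch below replaces the `Slot` message with a plain dict, which is an assumption made only for illustration; `encode`, `decode` and `_FORMATS` are hypothetical names. It keeps the same idea: `tobytes()` on the way out, `struct.unpack` plus `reshape` on the way back.

```python
import struct

import numpy as np

# Hypothetical lookup table mirroring the three dtypes handled above.
_FORMATS = {'float32': 'f', 'int32': 'i', 'int64': 'q'}


def encode(arr):
    name = arr.dtype.name
    if name not in _FORMATS:
        raise RuntimeError('unknown type %s' % name)
    # A dict stands in for interface_pb2.Slot: type, dims and raw bytes.
    return {'type': name, 'dims': list(arr.shape), 'data': arr.tobytes()}


def decode(slot):
    type_str = _FORMATS[slot['type']]
    num = len(slot['data']) // struct.calcsize(type_str)
    values = struct.unpack('%d%s' % (num, type_str), slot['data'])
    return np.array(values, dtype=slot['type']).reshape(slot['dims'])


original = np.arange(6, dtype=np.int64).reshape(2, 3)
restored = decode(encode(original))
```

Both `tobytes()` and the unprefixed `struct` format use the machine's native byte order, so the round trip is only guaranteed between hosts that share it.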
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import sys
import os
import argparse
import logging
import logging.handlers
from propeller.service.server import InferenceServer
from propeller import log
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model_dir', type=str, required=True)
    parser.add_argument('-p', '--port', type=int, required=True)
    parser.add_argument('-v', '--verbose', action='store_true')
    args = parser.parse_args()
    if args.verbose:
        log.setLevel(logging.DEBUG)
    # One worker process is started per card listed in CUDA_VISIBLE_DEVICES.
    n_devices = len(os.getenv("CUDA_VISIBLE_DEVICES").split(","))
    server = InferenceServer(args.model_dir, n_devices)
    log.info('propeller server listening on port %d' % args.port)
    server.listen(args.port)