提交 815925c5 编写于 作者: C chenxuyi 提交者: Meiyim

upgrade propeller

上级 5e23ec05
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Propeller"""
from __future__ import print_function from __future__ import print_function
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
......
...@@ -11,3 +11,6 @@ ...@@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
doc
"""
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Basic Dataset API"""
from __future__ import print_function from __future__ import print_function
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
...@@ -39,7 +40,7 @@ __all__ = ['Dataset'] ...@@ -39,7 +40,7 @@ __all__ = ['Dataset']
@contextmanager @contextmanager
def open_file(filename, format=None): def _open_file(filename, format=None):
if format is None: if format is None:
fd = open(filename, 'rb') fd = open(filename, 'rb')
elif format == 'GZIP': elif format == 'GZIP':
...@@ -50,9 +51,9 @@ def open_file(filename, format=None): ...@@ -50,9 +51,9 @@ def open_file(filename, format=None):
fd.close() fd.close()
def open_record(filename): def _open_record(filename):
def gen(): def _gen():
with open_file(filename, format='GZIP') as f: with _open_file(filename, format='GZIP') as f:
while True: while True:
data = f.read(struct.calcsize('i')) data = f.read(struct.calcsize('i'))
if not len(data): if not len(data):
...@@ -61,11 +62,11 @@ def open_record(filename): ...@@ -61,11 +62,11 @@ def open_record(filename):
data = f.read(l) data = f.read(l)
yield data yield data
return gen return _gen
def shuffle_func(dataset, buffer_size): def _shuffle_func(dataset, buffer_size):
def gen(): def _gen():
buf = [] buf = []
iterable = dataset() iterable = dataset()
try: try:
...@@ -82,11 +83,11 @@ def shuffle_func(dataset, buffer_size): ...@@ -82,11 +83,11 @@ def shuffle_func(dataset, buffer_size):
for i in buf: for i in buf:
yield i yield i
return gen return _gen
def interleave_func(iterable, map_fn, cycle_length, block_length): def _interleave_func(iterable, map_fn, cycle_length, block_length):
def gen(): def _gen():
ls = itertools.tee(iterable(), cycle_length) ls = itertools.tee(iterable(), cycle_length)
buf = [] buf = []
for i, j in enumerate(ls): for i, j in enumerate(ls):
...@@ -99,11 +100,11 @@ def interleave_func(iterable, map_fn, cycle_length, block_length): ...@@ -99,11 +100,11 @@ def interleave_func(iterable, map_fn, cycle_length, block_length):
for ii in (i for i in tup if i is not None): for ii in (i for i in tup if i is not None):
yield ii yield ii
return gen return _gen
def repeat_func(dataset, n): def _repeat_func(dataset, n):
def gen(): def _gen():
iterable = dataset() iterable = dataset()
if n >= 0: if n >= 0:
ret = itertools.chain(*itertools.tee(iterable, n)) ret = itertools.chain(*itertools.tee(iterable, n))
...@@ -113,11 +114,11 @@ def repeat_func(dataset, n): ...@@ -113,11 +114,11 @@ def repeat_func(dataset, n):
for i in ret: for i in ret:
yield i yield i
return gen return _gen
def filter_func(dataset, fn): def _filter_func(dataset, fn):
def gen(): def _gen():
for i in dataset(): for i in dataset():
if isinstance(i, tuple) or isinstance(i, list): if isinstance(i, tuple) or isinstance(i, list):
if fn(*i) is True: if fn(*i) is True:
...@@ -126,41 +127,41 @@ def filter_func(dataset, fn): ...@@ -126,41 +127,41 @@ def filter_func(dataset, fn):
if fn(i) is True: if fn(i) is True:
yield i yield i
return gen return _gen
def map_func(dataset, fn): def _map_func(dataset, fn):
def gen(): def _gen():
for i in dataset(): for i in dataset():
if isinstance(i, tuple) or isinstance(i, list): if isinstance(i, tuple) or isinstance(i, list):
yield fn(*i) yield fn(*i)
else: else:
yield fn(i) yield fn(i)
return gen return _gen
def shard_func(dataset, num_shards, index): def _shard_func(dataset, num_shards, index):
def gen(): def _gen():
iterable = dataset() iterable = dataset()
ret = itertools.islice(iterable, index, None, num_shards) ret = itertools.islice(iterable, index, None, num_shards)
for i in ret: for i in ret:
yield i yield i
return gen return _gen
def take_func(dataset, count): def _take_func(dataset, count):
def gen(): def _gen():
iterable = dataset() iterable = dataset()
ret = itertools.islice(iterable, count) ret = itertools.islice(iterable, count)
for i in ret: for i in ret:
yield i yield i
return gen return _gen
def buffered_func(dataset, size): def _buffered_func(dataset, size):
""" """
Creates a buffered data reader. Creates a buffered data reader.
...@@ -176,21 +177,21 @@ def buffered_func(dataset, size): ...@@ -176,21 +177,21 @@ def buffered_func(dataset, size):
:returns: the buffered data reader. :returns: the buffered data reader.
""" """
class EndSignal(): class _EndSignal(object):
pass pass
end = EndSignal() end = _EndSignal()
def read_worker(r, q): def _read_worker(r, q):
for d in r: for d in r:
q.put(d) q.put(d)
q.put(end) q.put(end)
def data_reader(): def _data_reader():
r = dataset() r = dataset()
q = multiprocessing.Queue(maxsize=size) q = multiprocessing.Queue(maxsize=size)
t = multiprocessing.Process( t = multiprocessing.Process(
target=read_worker, args=( target=_read_worker, args=(
r, r,
q, )) q, ))
t.daemon = True t.daemon = True
...@@ -200,14 +201,14 @@ def buffered_func(dataset, size): ...@@ -200,14 +201,14 @@ def buffered_func(dataset, size):
yield e yield e
e = q.get() e = q.get()
return data_reader return _data_reader
def padded_batch_func(dataset, batch_size, pad_value=0, max_seqlen=None): def _padded_batch_func(dataset, batch_size, pad_value=0, max_seqlen=None):
if not isinstance(batch_size, int): if not isinstance(batch_size, int):
raise ValueError('unknown batch_size: %s' % repr(batch_size)) raise ValueError('unknown batch_size: %s' % repr(batch_size))
def gen(): def _gen():
iterable = dataset() iterable = dataset()
pad_value_t = pad_value pad_value_t = pad_value
while True: while True:
...@@ -226,71 +227,86 @@ def padded_batch_func(dataset, batch_size, pad_value=0, max_seqlen=None): ...@@ -226,71 +227,86 @@ def padded_batch_func(dataset, batch_size, pad_value=0, max_seqlen=None):
if (not np.isscalar(elem)) and elem.shape != (): if (not np.isscalar(elem)) and elem.shape != ():
max_len = max(map(len, max_len = max(map(len,
e)) if max_seqlen is None else max_seqlen e)) if max_seqlen is None else max_seqlen
e = map(lambda i: np.pad(i, [0, max_len - len(i)], 'constant', constant_values=pv) if max_len >= len(i) else i[: max_len], e)
def _fn(i):
if max_len >= len(i):
return np.pad(i, [0, max_len - len(i)],
'constant',
constant_values=pv)
else:
return i[:max_len]
e = map(_fn, e)
padded.append(np.stack(list(e))) padded.append(np.stack(list(e)))
yield padded yield padded
return gen return _gen
class Dataset(object): class Dataset(object):
"""Python Wrapper for PyReader"""
@classmethod @classmethod
def from_generator_func(cls, gen, data_shapes=None, data_types=None): def from_generator_func(cls, _gen, data_shapes=None, data_types=None):
if not inspect.isgeneratorfunction(gen): """doc"""
raise ValueError('expect generator function, got %s' % repr(gen)) if not inspect.isgeneratorfunction(_gen):
raise ValueError('expect generator function, got %s' % repr(_gen))
def wrapper(): #compat to py3.7 def _wrapper(): #compat to py3.7
try: try:
for item in gen(): for item in _gen():
yield item yield item
except RuntimeError as e: except RuntimeError as e:
if str(e) != 'generator raised StopIteration': if str(e) != 'generator raised StopIteration':
raise e raise e
ret = cls() ret = cls()
ret.generator = wrapper ret.generator = _wrapper
ret.data_shapes = data_shapes ret.data_shapes = data_shapes
ret.data_types = data_types ret.data_types = data_types
return ret return ret
@classmethod @classmethod
def from_file(cls, filename, format=None): def from_file(cls, filename, format=None):
"""doc"""
if os.path.getsize(filename) == 0: if os.path.getsize(filename) == 0:
raise RuntimeError('%s is empty' % filename) raise RuntimeError('%s is empty' % filename)
def gen(): def _gen():
with open_file(filename, format) as f: with _open_file(filename, format) as f:
for line in f: for line in f:
yield line yield line
ret = cls() ret = cls()
ret.generator = gen ret.generator = _gen
ret.data_shapes = [] ret.data_shapes = []
ret.data_types = str ret.data_types = str
return ret return ret
@classmethod @classmethod
def from_record_file(cls, filename): def from_record_file(cls, filename):
"""doc"""
if os.path.getsize(filename) == 0: if os.path.getsize(filename) == 0:
raise RuntimeError('%s is empty' % filename) raise RuntimeError('%s is empty' % filename)
gen = open_record(filename) _gen = _open_record(filename)
ret = cls() ret = cls()
ret.generator = gen ret.generator = _gen
ret.data_shapes = [] ret.data_shapes = []
ret.data_types = str ret.data_types = str
return ret return ret
@classmethod @classmethod
def from_list(cls, ls): def from_list(cls, ls):
"""doc"""
if not isinstance(ls, list): if not isinstance(ls, list):
raise ValueError('expect list, got %s' % repr(ls)) raise ValueError('expect list, got %s' % repr(ls))
def gen(): def _gen():
for i in ls: for i in ls:
yield i yield i
ret = cls() ret = cls()
ret.generator = gen ret.generator = _gen
ret.data_shapes = [] ret.data_shapes = []
ret.data_types = str ret.data_types = str
return ret return ret
...@@ -339,6 +355,7 @@ class Dataset(object): ...@@ -339,6 +355,7 @@ class Dataset(object):
@property @property
def data_shapes(self): def data_shapes(self):
"""doc"""
if self._data_shapes is None: if self._data_shapes is None:
self._infer_shapes_and_types() self._infer_shapes_and_types()
return self._data_shapes return self._data_shapes
...@@ -347,10 +364,12 @@ class Dataset(object): ...@@ -347,10 +364,12 @@ class Dataset(object):
@data_shapes.setter @data_shapes.setter
def data_shapes(self, val): def data_shapes(self, val):
"""doc"""
self._data_shapes = val self._data_shapes = val
@property @property
def data_types(self): def data_types(self):
"""doc"""
if self._data_types is None: if self._data_types is None:
self._infer_shapes_and_types() self._infer_shapes_and_types()
return self._data_types return self._data_types
...@@ -359,9 +378,11 @@ class Dataset(object): ...@@ -359,9 +378,11 @@ class Dataset(object):
@data_types.setter @data_types.setter
def data_types(self, val): def data_types(self, val):
"""doc"""
self._data_types = val self._data_types = val
def apply(self, transform_func): def apply(self, transform_func):
"""apply transform func to datasets"""
#input_shapes = transform_func.input_shapes #input_shapes = transform_func.input_shapes
#input_types = transform_func.input_types #input_types = transform_func.input_types
#data_shapes = transform_func.data_shapes #data_shapes = transform_func.data_shapes
...@@ -377,46 +398,55 @@ class Dataset(object): ...@@ -377,46 +398,55 @@ class Dataset(object):
return ret return ret
def shuffle(self, buffer_size): def shuffle(self, buffer_size):
func = functools.partial(shuffle_func, buffer_size=buffer_size) """doc"""
func = functools.partial(_shuffle_func, buffer_size=buffer_size)
return self.apply(func) return self.apply(func)
def repeat(self, n=-1): def repeat(self, n=-1):
func = functools.partial(repeat_func, n=n) """doc"""
func = functools.partial(_repeat_func, n=n)
return self.apply(func) return self.apply(func)
def map(self, fn): def map(self, fn):
func = functools.partial(map_func, fn=fn) """doc"""
func = functools.partial(_map_func, fn=fn)
return self.apply(func) return self.apply(func)
def filter(self, fn): def filter(self, fn):
func = functools.partial(filter_func, fn=fn) """doc"""
func = functools.partial(_filter_func, fn=fn)
return self.apply(func) return self.apply(func)
def shard(self, num_shards, index): def shard(self, num_shards, index):
"""doc"""
func = functools.partial( func = functools.partial(
shard_func, num_shards=num_shards, index=index) _shard_func, num_shards=num_shards, index=index)
return self.apply(func) return self.apply(func)
def interleave(self, map_fn, cycle_length, block_length): def interleave(self, map_fn, cycle_length, block_length):
"""doc"""
func = functools.partial( func = functools.partial(
interleave_func, _interleave_func,
map_fn=map_fn, map_fn=map_fn,
cycle_length=cycle_length, cycle_length=cycle_length,
block_length=block_length) block_length=block_length)
return self.apply(func) return self.apply(func)
def padded_batch(self, batch_size, pad_value=0, max_seqlen=None): def padded_batch(self, batch_size, pad_value=0, max_seqlen=None):
"""doc"""
func = functools.partial( func = functools.partial(
padded_batch_func, _padded_batch_func,
batch_size=batch_size, batch_size=batch_size,
pad_value=pad_value, pad_value=pad_value,
max_seqlen=max_seqlen) max_seqlen=max_seqlen)
return self.apply(func) return self.apply(func)
def take(self, count=1): def take(self, count=1):
func = functools.partial(take_func, count=count) """doc"""
func = functools.partial(_take_func, count=count)
return self.apply(func) return self.apply(func)
def buffered(self, size=10): def buffered(self, size=10):
func = functools.partial(buffered_func, size=size) """doc"""
func = functools.partial(_buffered_func, size=size)
return self.apply(func) return self.apply(func)
...@@ -11,6 +11,9 @@ ...@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
doc
"""
from __future__ import print_function from __future__ import print_function
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
......
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""global collections"""
from __future__ import print_function from __future__ import print_function
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
...@@ -43,13 +45,16 @@ class Collections(object): ...@@ -43,13 +45,16 @@ class Collections(object):
_global_collection = None _global_collection = None
def add(self, key, val): def add(self, key, val):
"""doc"""
self.col.setdefault(key, []).append(val) self.col.setdefault(key, []).append(val)
def get(self, key): def get(self, key):
"""doc"""
return self.col.get(key, None) return self.col.get(key, None)
def default_collection(): def default_collection():
"""return global collection"""
global _global_collection global _global_collection
if _global_collection is None: if _global_collection is None:
_global_collection = Collections() _global_collection = Collections()
......
...@@ -11,6 +11,9 @@ ...@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
doc
"""
from __future__ import print_function from __future__ import print_function
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""FeatureColumns and many Column"""
from __future__ import print_function from __future__ import print_function
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
...@@ -30,7 +31,7 @@ import numpy as np ...@@ -30,7 +31,7 @@ import numpy as np
from glob import glob from glob import glob
from propeller.paddle.train import distribution from propeller.paddle.train import distribution
from propeller.data.functional import interleave_func from propeller.data.functional import _interleave_func
from propeller.paddle.data.functional import Dataset from propeller.paddle.data.functional import Dataset
from propeller.paddle.data import example_pb2, feature_pb2 from propeller.paddle.data import example_pb2, feature_pb2
...@@ -43,35 +44,47 @@ __all__ = [ ...@@ -43,35 +44,47 @@ __all__ = [
def basic_tokenizer(sen): def basic_tokenizer(sen):
"""doc"""
seg = sen.split(b' ') seg = sen.split(b' ')
seg = filter(lambda i: i != b' ', seg) seg = filter(lambda i: i != b' ', seg)
return seg return seg
class Column(): class Column(object):
"""doc"""
def __init__(self, name): def __init__(self, name):
"""doc"""
pass pass
def raw_to_proto(self, raw): def raw_to_proto(self, raw):
"""doc"""
return feature_pb2.Feature() return feature_pb2.Feature()
@property @property
def output_shapes(self): def output_shapes(self):
"""doc"""
pass pass
@property @property
def output_types(self): def output_types(self):
"""doc"""
pass pass
def proto_to_instance(self, proto): def proto_to_instance(self, proto):
"""doc"""
raise NotImplementedError() raise NotImplementedError()
def raw_to_instance(self, raw): def raw_to_instance(self, raw):
"""doc"""
raise NotImplementedError() raise NotImplementedError()
class LabelColumn(Column): class LabelColumn(Column):
"""doc"""
def __init__(self, name, vocab_dict=None, vocab_file=None): def __init__(self, name, vocab_dict=None, vocab_file=None):
"""doc"""
self.name = name self.name = name
self.vocab = None self.vocab = None
if vocab_file: if vocab_file:
...@@ -84,13 +97,16 @@ class LabelColumn(Column): ...@@ -84,13 +97,16 @@ class LabelColumn(Column):
@property @property
def output_shapes(self): def output_shapes(self):
"""doc"""
return [1] return [1]
@property @property
def output_types(self): def output_types(self):
"""doc"""
return 'int64' return 'int64'
def raw_to_proto(self, raw): def raw_to_proto(self, raw):
"""doc"""
if self.vocab is None: if self.vocab is None:
ids = [int(raw)] ids = [int(raw)]
else: else:
...@@ -99,10 +115,12 @@ class LabelColumn(Column): ...@@ -99,10 +115,12 @@ class LabelColumn(Column):
return fe return fe
def proto_to_instance(self, feature): def proto_to_instance(self, feature):
"""doc"""
ret = np.array(feature.int64_list.value[0], dtype=np.int64) ret = np.array(feature.int64_list.value[0], dtype=np.int64)
return ret return ret
def raw_to_instance(self, raw): def raw_to_instance(self, raw):
"""doc"""
if self.vocab is None: if self.vocab is None:
ids = int(raw) ids = int(raw)
else: else:
...@@ -111,6 +129,8 @@ class LabelColumn(Column): ...@@ -111,6 +129,8 @@ class LabelColumn(Column):
class TextColumn(Column): class TextColumn(Column):
"""doc"""
def __init__(self, def __init__(self,
name, name,
unk_id, unk_id,
...@@ -132,63 +152,75 @@ class TextColumn(Column): ...@@ -132,63 +152,75 @@ class TextColumn(Column):
@property @property
def output_shapes(self): def output_shapes(self):
"""doc"""
return [-1] return [-1]
@property @property
def output_types(self): def output_types(self):
"""doc"""
return 'int64' return 'int64'
def raw_to_proto(self, raw): def raw_to_proto(self, raw):
"""doc"""
ids = [self.vocab.get(s, self.unk_id) for s in self.tokenizer(raw)] ids = [self.vocab.get(s, self.unk_id) for s in self.tokenizer(raw)]
fe = feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=ids)) fe = feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=ids))
return fe return fe
def proto_to_instance(self, feature): def proto_to_instance(self, feature):
"""doc"""
ret = np.array(feature.int64_list.value, dtype=np.int64) ret = np.array(feature.int64_list.value, dtype=np.int64)
return ret return ret
def raw_to_instance(self, raw): def raw_to_instance(self, raw):
"""doc"""
ids = [self.vocab.get(s, self.unk_id) for s in self.tokenizer(raw)] ids = [self.vocab.get(s, self.unk_id) for s in self.tokenizer(raw)]
return np.array(ids, dtype=np.int64) return np.array(ids, dtype=np.int64)
class TextIDColumn(Column): class TextIDColumn(Column):
"""doc"""
def __init__(self, name): def __init__(self, name):
"""doc"""
self.name = name self.name = name
@property @property
def output_shapes(self): def output_shapes(self):
"""doc"""
return [-1] return [-1]
@property @property
def output_types(self): def output_types(self):
"""doc"""
return 'int64' return 'int64'
def raw_to_proto(self, raw): def raw_to_proto(self, raw):
"""doc"""
ids = [int(s) for s in raw.split(b' ')] ids = [int(s) for s in raw.split(b' ')]
fe = feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=ids)) fe = feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=ids))
return fe return fe
def proto_to_instance(self, feature): def proto_to_instance(self, feature):
"""doc"""
ret = np.array(feature.int64_list.value, dtype=np.int64) ret = np.array(feature.int64_list.value, dtype=np.int64)
return ret return ret
def raw_to_instance(self, raw): def raw_to_instance(self, raw):
"""doc"""
ret = np.array([int(i) for i in raw.split(b' ')], dtype=np.int64) ret = np.array([int(i) for i in raw.split(b' ')], dtype=np.int64)
return ret return ret
class FeatureColumns(object): def _list_files(raw_dir):
def __init__(self, columns, pad_id=0): return [os.path.join(raw_dir, p) for p in os.listdir(raw_dir)]
self._columns = columns
def raw_files(self, raw_dir): class FeatureColumns(object):
return [os.path.join(raw_dir, p) for p in os.listdir(raw_dir)] """A Dataset Factory object"""
def gz_files(self, gz_dir): def __init__(self, columns):
return None if gz_dir is None else [ """doc"""
os.path.join(gz_dir, p) for p in os.listdir(gz_dir) self._columns = columns
]
def _make_gz_dataset(self, raw_dir, gz_dir): def _make_gz_dataset(self, raw_dir, gz_dir):
assert raw_dir or gz_dir, 'data_dir not specified when using gz mode' assert raw_dir or gz_dir, 'data_dir not specified when using gz mode'
...@@ -237,7 +269,7 @@ class FeatureColumns(object): ...@@ -237,7 +269,7 @@ class FeatureColumns(object):
if shuffle: if shuffle:
dataset = dataset.shuffle(buffer_size=len(gz_files)) dataset = dataset.shuffle(buffer_size=len(gz_files))
fn = partial( fn = partial(
interleave_func, _interleave_func,
map_fn=lambda filename: Dataset.from_record_file(filename), map_fn=lambda filename: Dataset.from_record_file(filename),
cycle_length=len(gz_files), cycle_length=len(gz_files),
block_length=1) block_length=1)
...@@ -271,7 +303,7 @@ class FeatureColumns(object): ...@@ -271,7 +303,7 @@ class FeatureColumns(object):
dataset = dataset.shuffle(buffer_size=len(data_files)) dataset = dataset.shuffle(buffer_size=len(data_files))
fn = partial( fn = partial(
interleave_func, _interleave_func,
map_fn=lambda filename: Dataset.from_file(filename), map_fn=lambda filename: Dataset.from_file(filename),
cycle_length=len(data_files), cycle_length=len(data_files),
block_length=1) block_length=1)
...@@ -294,9 +326,9 @@ class FeatureColumns(object): ...@@ -294,9 +326,9 @@ class FeatureColumns(object):
def _read_stdin_dataset(self, encoding='utf8', shuffle=False, **kwargs): def _read_stdin_dataset(self, encoding='utf8', shuffle=False, **kwargs):
log.info('reading raw files stdin') log.info('reading raw files stdin')
def gen(): def _gen():
if six.PY3: if six.PY3:
source = sys.stdin.buffer source = sys.stdin.buffer
else: else:
source = sys.stdin source = sys.stdin
while True: while True:
...@@ -305,12 +337,12 @@ class FeatureColumns(object): ...@@ -305,12 +337,12 @@ class FeatureColumns(object):
break break
yield line, yield line,
dataset = Dataset.from_generator_func(gen) dataset = Dataset.from_generator_func(_gen)
if shuffle: if shuffle:
dataset = dataset.shuffle(buffer_size=1000) dataset = dataset.shuffle(buffer_size=1000)
def _parse_stdin(record_str): def _parse_stdin(record_str):
'''function that takes python_str as input''' """function that takes python_str as input"""
features = record_str.strip(b'\n').split(b'\t') features = record_str.strip(b'\n').split(b'\t')
ret = [ ret = [
column.raw_to_instance(feature) column.raw_to_instance(feature)
...@@ -346,13 +378,17 @@ class FeatureColumns(object): ...@@ -346,13 +378,17 @@ class FeatureColumns(object):
gz_dir=None, gz_dir=None,
data_file=None, data_file=None,
**kwargs): **kwargs):
"""
build `Dataset` from `data_dir` or `data_file`
if `use_gz`, will try to convert data_files to gz format and save to `gz_dir`, if `gz_dir` not given, will create one.
"""
if use_gz: if use_gz:
gz_dir = self._make_gz_dataset(data_dir, gz_dir) gz_dir = self._make_gz_dataset(data_dir, gz_dir)
gz_files = self.gz_files(gz_dir) gz_files = _list_files(gz_dir) if gz_dir is not None else gz_dir
ds = self._read_gz_dataset(gz_files, **kwargs) ds = self._read_gz_dataset(gz_files, **kwargs)
else: else:
if data_dir is not None: if data_dir is not None:
data_files = self.raw_files(data_dir) data_files = _list_files(data_dir)
elif data_file is not None: elif data_file is not None:
data_files = [data_file] data_files = [data_file]
else: else:
...@@ -362,6 +398,7 @@ class FeatureColumns(object): ...@@ -362,6 +398,7 @@ class FeatureColumns(object):
return ds return ds
def build_dataset_from_stdin(self, name, **kwargs): def build_dataset_from_stdin(self, name, **kwargs):
"""doc"""
ds = self._read_stdin_dataset(**kwargs) ds = self._read_stdin_dataset(**kwargs)
ds.name = name ds.name = name
return ds return ds
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Pyreader based Dataset"""
import sys import sys
import numpy as np import numpy as np
...@@ -25,7 +26,10 @@ log = logging.getLogger(__name__) ...@@ -25,7 +26,10 @@ log = logging.getLogger(__name__)
class Dataset(DatasetBase): class Dataset(DatasetBase):
"""Pyreader based Dataset"""
def placeholders(self): def placeholders(self):
"""doc"""
if self.name is None: if self.name is None:
raise ValueError('can not get feature from unnamed Dataset') raise ValueError('can not get feature from unnamed Dataset')
...@@ -41,7 +45,7 @@ class Dataset(DatasetBase): ...@@ -41,7 +45,7 @@ class Dataset(DatasetBase):
return ret return ret
def features(self): def features(self):
'''start point of net building. call this in a program scope''' """start point of net building. call this in a program scope"""
if self.name is None: if self.name is None:
raise ValueError('can not get feature from unnamed Dataset') raise ValueError('can not get feature from unnamed Dataset')
...@@ -51,9 +55,13 @@ class Dataset(DatasetBase): ...@@ -51,9 +55,13 @@ class Dataset(DatasetBase):
(repr(self._data_shapes), repr(self._data_types))) (repr(self._data_shapes), repr(self._data_types)))
return self.placeholders() return self.placeholders()
def start(self, places=F.cuda_places()): def start(self, places=None):
"""start Pyreader"""
if places is None:
places = F.cuda_places() if F.core.is_compiled_with_cuda(
) else F.cpu_places()
#assert self.pyreader is not None, 'use Dataset.features to build net first, then start dataset' #assert self.pyreader is not None, 'use Dataset.features to build net first, then start dataset'
def gen(): def _gen():
try: try:
for idx, i in enumerate(self.generator()): for idx, i in enumerate(self.generator()):
yield i yield i
...@@ -63,5 +71,5 @@ class Dataset(DatasetBase): ...@@ -63,5 +71,5 @@ class Dataset(DatasetBase):
r = F.io.PyReader( r = F.io.PyReader(
feed_list=self.placeholders(), capacity=50, iterable=True) feed_list=self.placeholders(), capacity=50, iterable=True)
r.decorate_batch_generator(gen, places=places) r.decorate_batch_generator(_gen, places=places)
return r() return r()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import logging
import six
import asyncio
import threading
import grpc
from propeller.service import interface_pb2
from propeller.service import interface_pb2_grpc
import propeller.paddle.service.utils as serv_utils
from concurrent.futures import ThreadPoolExecutor
import paddle.fluid as F
from time import sleep, time
log = logging.getLogger(__name__)
def profile(msg):
def decfn(fn):
def retfn(*args, **kwargs):
start = time()
ret = fn(*args, **kwargs)
end = time()
log.debug('%s timecost: %.5f' % (msg, end - start))
return ret
return retfn
return decfn
def serve(model_dir, host, num_concurrent=None):
    """Run a blocking gRPC inference service for a saved paddle model.

    Args:
        model_dir: directory holding an exported paddle inference model,
            passed to `F.core.AnalysisConfig`.
        host: address string for `server.add_insecure_port`,
            e.g. '0.0.0.0:8334'.
        num_concurrent: number of worker threads (one predictor per thread);
            defaults to the number of visible CUDA devices.

    Blocks forever (until KeyboardInterrupt). Python 3 only.
    """
    if six.PY2:
        raise RuntimeError('propeller service work in python3 only')
    # Default concurrency: one worker per CUDA device.
    num_worker = len(F.cuda_places(
    )) if num_concurrent is None else num_concurrent
    pool = ThreadPoolExecutor(num_worker)

    class Predictor(object):
        # One paddle predictor bound to a single GPU card.
        def __init__(self, did):
            log.debug('create predictor on card %d' % did)
            config = F.core.AnalysisConfig(model_dir)
            # 5000: initial GPU memory pool size in MB for this predictor
            # on card `did` -- TODO confirm the unit against paddle docs.
            config.enable_use_gpu(5000, did)
            self._predictor = F.core.create_paddle_predictor(config)

        @profile('paddle')
        def __call__(self, args):
            # Rename inputs positionally; the exported model presumably
            # feeds by slot order, not by name -- verify against exporter.
            for i, a in enumerate(args):
                a.name = 'placeholder_%d' % i
            res = self._predictor.run(args)
            return res

    # Maps worker thread -> its lazily created Predictor (thread affinity
    # gives each pool thread a dedicated GPU card).
    predictor_context = {}

    class InferenceService(interface_pb2_grpc.InferenceServicer):
        @profile('service')
        def Infer(self, request, context):
            """gRPC handler: deserialize slots, run inference, reserialize."""
            try:
                slots = request.slots
                current_thread = threading.current_thread()
                log.debug('%d slots received dispatch to thread %s' %
                          (len(slots), current_thread))
                if current_thread not in predictor_context:
                    # NOTE(review): relies on ThreadPoolExecutor's private
                    # `_threads` set to derive a stable card id per thread;
                    # fragile across Python versions.
                    did = list(pool._threads).index(current_thread)
                    log.debug('spawning worker thread %d' % did)
                    predictor = Predictor(did)
                    predictor_context[current_thread] = predictor
                else:
                    predictor = predictor_context[current_thread]
                slots = [serv_utils.slot_to_paddlearray(s) for s in slots]
                ret = predictor(slots)
                response = [serv_utils.paddlearray_to_slot(r) for r in ret]
            except Exception as e:
                # Log with traceback, then propagate so gRPC reports the error.
                log.exception(e)
                raise e
            return interface_pb2.Slots(slots=response)

    # The same thread pool serves gRPC requests and hosts the predictors.
    server = grpc.server(pool)
    interface_pb2_grpc.add_InferenceServicer_to_server(InferenceService(),
                                                       server)
    server.add_insecure_port(host)
    server.start()
    log.info('server started on %s...' % host)
    try:
        # grpc server runs in background threads; keep the main thread alive.
        while True:
            sleep(100000)
    except KeyboardInterrupt as e:
        pass
    log.info('server stoped...')
# Manual smoke test: run this module directly to serve a local ERNIE model.
# NOTE(review): the model path and endpoint below are hardcoded to a
# developer machine -- adjust before use.
if __name__ == '__main__':
    from propeller import log
    log.setLevel(logging.DEBUG)
    serve(
        '/home/work/chenxuyi/playground/grpc_play/ernie2.0/',
        '10.255.138.19:8334',
        num_concurrent=3)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import struct
from propeller.service import interface_pb2
from propeller.service import interface_pb2_grpc
import paddle.fluid.core as core
def slot_to_paddlearray(slot):
    """Deserialize a protobuf `Slot` message into a paddle `PaddleTensor`.

    Args:
        slot: an `interface_pb2.Slot` carrying a dtype tag, dims and
            raw packed bytes in `slot.data`.

    Returns:
        A `core.PaddleTensor` with shape, dtype and buffer filled in.

    Raises:
        RuntimeError: if the slot carries an unsupported dtype tag.
    """
    # Map the protobuf dtype tag to (struct format char, paddle dtype).
    if slot.type == interface_pb2.Slot.FP32:
        type_str = 'f'
        dtype = core.PaddleDType.FLOAT32
    elif slot.type == interface_pb2.Slot.INT32:
        type_str = 'i'
        dtype = core.PaddleDType.INT32
    elif slot.type == interface_pb2.Slot.INT64:
        type_str = 'q'
        dtype = core.PaddleDType.INT64
    else:
        # bug fix: message previously read 'know type %s'
        raise RuntimeError('unknown type %s' % slot.type)
    ret = core.PaddleTensor()
    ret.shape = slot.dims
    ret.dtype = dtype
    # Element count = total byte length / per-element size.
    num = len(slot.data) // struct.calcsize(type_str)
    arr = struct.unpack('%d%s' % (num, type_str), slot.data)
    ret.data = core.PaddleBuf(arr)
    return ret
def paddlearray_to_slot(arr):
    """Serialize a paddle `PaddleTensor` into a protobuf `Slot` message.

    Args:
        arr: a `core.PaddleTensor` (or equivalent) exposing `dtype`,
            `shape` and a typed `data` buffer.

    Returns:
        An `interface_pb2.Slot` with dtype tag, dims and packed bytes.

    Raises:
        RuntimeError: if the tensor has an unsupported dtype.
    """
    # Map the paddle dtype to (protobuf tag, struct format char, accessor).
    if arr.dtype == core.PaddleDType.FLOAT32:
        dtype = interface_pb2.Slot.FP32
        type_str = 'f'
        arr_data = arr.data.float_data()
    elif arr.dtype == core.PaddleDType.INT32:
        dtype = interface_pb2.Slot.INT32
        type_str = 'i'
        arr_data = arr.data.int32_data()
    elif arr.dtype == core.PaddleDType.INT64:
        dtype = interface_pb2.Slot.INT64
        type_str = 'q'
        arr_data = arr.data.int64_data()
    else:
        # bug fix: message previously read 'know type %s'
        raise RuntimeError('unknown type %s' % arr.dtype)
    # Pack all elements into one contiguous byte string.
    data = struct.pack('%d%s' % (len(arr_data), type_str), *arr_data)
    pb = interface_pb2.Slot(type=dtype, dims=list(arr.shape), data=data)
    return pb
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""record summary tensor in a collection scope"""
from __future__ import print_function from __future__ import print_function
from __future__ import absolute_import from __future__ import absolute_import
...@@ -23,6 +24,7 @@ from propeller.paddle.collection import default_collection, Key ...@@ -23,6 +24,7 @@ from propeller.paddle.collection import default_collection, Key
def scalar(name, tensor): def scalar(name, tensor):
"""scalar summary"""
if not isinstance(tensor, F.framework.Variable): if not isinstance(tensor, F.framework.Variable):
raise ValueError('expect paddle Variable, got %s' % repr(tensor)) raise ValueError('expect paddle Variable, got %s' % repr(tensor))
tensor.persistable = True tensor.persistable = True
...@@ -30,6 +32,7 @@ def scalar(name, tensor): ...@@ -30,6 +32,7 @@ def scalar(name, tensor):
def histogram(name, tensor): def histogram(name, tensor):
"""histogram summary"""
if not isinstance(tensor, F.framework.Variable): if not isinstance(tensor, F.framework.Variable):
raise ValueError('expect paddle Variable, got %s' % repr(tensor)) raise ValueError('expect paddle Variable, got %s' % repr(tensor))
tensor.persistable = True tensor.persistable = True
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Propeller training"""
from __future__ import print_function from __future__ import print_function
from __future__ import absolute_import from __future__ import absolute_import
......
...@@ -128,6 +128,7 @@ def init_distribuition_env(program): ...@@ -128,6 +128,7 @@ def init_distribuition_env(program):
elif status.mode == DistributionMode.NCCL: elif status.mode == DistributionMode.NCCL:
config = F.DistributeTranspilerConfig() config = F.DistributeTranspilerConfig()
config.mode = "nccl2" config.mode = "nccl2"
config.nccl_comm_num = 1
F.DistributeTranspiler(config=config).transpile( F.DistributeTranspiler(config=config).transpile(
status.replica_id, status.replica_id,
trainers=','.join(status._env), trainers=','.join(status._env),
......
...@@ -11,6 +11,9 @@ ...@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
exporters
"""
from __future__ import print_function from __future__ import print_function
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
...@@ -19,6 +22,7 @@ import sys ...@@ -19,6 +22,7 @@ import sys
import os import os
import itertools import itertools
import six import six
import inspect
import abc import abc
import logging import logging
...@@ -28,24 +32,36 @@ import paddle.fluid.layers as L ...@@ -28,24 +32,36 @@ import paddle.fluid.layers as L
from propeller.paddle.train import Saver from propeller.paddle.train import Saver
from propeller.types import InferenceSpec from propeller.types import InferenceSpec
from propeller.train.model import Model
from propeller.paddle.train.trainer import _build_net
from propeller.paddle.train.trainer import _build_model_fn
from propeller.types import RunMode
from propeller.types import ProgramPair
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@six.add_metaclass(abc.ABCMeta) @six.add_metaclass(abc.ABCMeta)
class Exporter(): class Exporter(object):
"""base exporter"""
@abc.abstractmethod @abc.abstractmethod
def export(self, exe, program, eval_result, state): def export(self, exe, program, eval_result, state):
"""export"""
raise NotImplementedError() raise NotImplementedError()
class BestExporter(Exporter): class BestExporter(Exporter):
"""export saved model accordingto `cmp_fn`"""
def __init__(self, export_dir, cmp_fn): def __init__(self, export_dir, cmp_fn):
"""doc"""
self._export_dir = export_dir self._export_dir = export_dir
self._best = None self._best = None
self.cmp_fn = cmp_fn self.cmp_fn = cmp_fn
def export(self, exe, program, eval_model_spec, eval_result, state): def export(self, exe, program, eval_model_spec, eval_result, state):
"""doc"""
log.debug('New evaluate result: %s \nold: %s' % log.debug('New evaluate result: %s \nold: %s' %
(repr(eval_result), repr(self._best))) (repr(eval_result), repr(self._best)))
if self._best is None or self.cmp_fn(old=self._best, new=eval_result): if self._best is None or self.cmp_fn(old=self._best, new=eval_result):
...@@ -65,40 +81,85 @@ class BestExporter(Exporter): ...@@ -65,40 +81,85 @@ class BestExporter(Exporter):
class BestInferenceModelExporter(Exporter): class BestInferenceModelExporter(Exporter):
def __init__(self, export_dir, cmp_fn): """export inference model accordingto `cmp_fn`"""
def __init__(self,
export_dir,
cmp_fn,
model_class_or_model_fn=None,
hparams=None,
dataset=None):
"""doc"""
self._export_dir = export_dir self._export_dir = export_dir
self._best = None self._best = None
self.cmp_fn = cmp_fn self.cmp_fn = cmp_fn
self.model_class_or_model_fn = model_class_or_model_fn
self.hparams = hparams
self.dataset = dataset
def export(self, exe, program, eval_model_spec, eval_result, state): def export(self, exe, program, eval_model_spec, eval_result, state):
"""doc"""
if self.model_class_or_model_fn is not None and self.hparams is not None \
and self.dataset is not None:
log.info('Building program by user defined model function')
if issubclass(self.model_class_or_model_fn, Model):
_model_fn = _build_model_fn(self.model_class_or_model_fn)
elif inspect.isfunction(self.model_class_or_model_fn):
_model_fn = self.model_class_or_model_fn
else:
raise ValueError('unknown model %s' %
self.model_class_or_model_fn)
# build net
infer_program = F.Program()
startup_prog = F.Program()
with F.program_guard(infer_program, startup_prog):
#share var with Train net
with F.unique_name.guard():
log.info('Building Infer Graph')
infer_fea = self.dataset.features()
# run_config is None
self.model_spec = _build_net(_model_fn, infer_fea,
RunMode.PREDICT, self.hparams,
None)
log.info('Done')
infer_program = infer_program.clone(for_test=True)
self.program = ProgramPair(
train_program=infer_program, startup_program=startup_prog)
else:
self.program = program
self.model_spec = eval_model_spec
log.debug('New evaluate result: %s \nold: %s' % log.debug('New evaluate result: %s \nold: %s' %
(repr(eval_result), repr(self._best))) (repr(eval_result), repr(self._best)))
if self._best is None or self.cmp_fn(old=self._best, new=eval_result): if self._best is None or self.cmp_fn(old=self._best, new=eval_result):
log.debug('[Best Exporter]: export to %s' % self._export_dir) log.debug('[Best Exporter]: export to %s' % self._export_dir)
if eval_model_spec.inference_spec is None: if self.model_spec.inference_spec is None:
raise ValueError('model_fn didnt return InferenceSpec') raise ValueError('model_fn didnt return InferenceSpec')
inf_sepc_dict = eval_model_spec.inference_spec inf_spec_dict = self.model_spec.inference_spec
if not isinstance(inf_sepc_dict, dict): if not isinstance(inf_spec_dict, dict):
inf_sepc_dict = {'inference': inf_sepc_dict} inf_spec_dict = {'inference': inf_spec_dict}
for inf_sepc_name, inf_sepc in six.iteritems(inf_sepc_dict): for inf_spec_name, inf_spec in six.iteritems(inf_spec_dict):
if not isinstance(inf_sepc, InferenceSpec): if not isinstance(inf_spec, InferenceSpec):
raise ValueError('unknown inference spec type: %s' % inf_sepc) raise ValueError('unknow inference spec type: %s' %
inf_spec)
save_dir = os.path.join(self._export_dir, inf_sepc_name) save_dir = os.path.join(self._export_dir, inf_spec_name)
log.debug('[Best Exporter]: save inference model: "%s" to %s' % log.debug('[Best Exporter]: save inference model: "%s" to %s' %
(inf_sepc_name, save_dir)) (inf_spec_name, save_dir))
feed_var = [i.name for i in inf_sepc.inputs] feed_var = [i.name for i in inf_spec.inputs]
fetch_var = inf_sepc.outputs fetch_var = inf_spec.outputs
eval_program = program.train_program infer_program = self.program.train_program
startup_prog = F.Program() startup_prog = F.Program()
F.io.save_inference_model( F.io.save_inference_model(
save_dir, save_dir,
feed_var, feed_var,
fetch_var, fetch_var,
exe, exe,
main_program=eval_program) main_program=infer_program)
self._best = eval_result self._best = eval_result
else: else:
log.debug('[Best Exporter]: skip step %s' % state.gstep) log.debug('[Best Exporter]: skip step %s' % state.gstep)
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""train hooks"""
from __future__ import print_function from __future__ import print_function
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
...@@ -38,44 +39,56 @@ log = logging.getLogger(__name__) ...@@ -38,44 +39,56 @@ log = logging.getLogger(__name__)
class RunHook(object): class RunHook(object):
"""RunHook Base class"""
def __init__(self): def __init__(self):
"""doc"""
pass pass
def before_train(self): def before_train(self, program):
"""doc"""
pass pass
def before_run(self, state): def before_run(self, state):
"""doc"""
return [] return []
def after_run(self, res_list, state): def after_run(self, res_list, state):
"""doc"""
pass pass
def should_stop(self, state): def should_stop(self, state):
"""doc"""
return False return False
def after_train(self): def after_train(self):
"""doc"""
pass pass
class TqdmProgressBarHook(RunHook): class TqdmProgressBarHook(RunHook):
"""show a progress bar when training"""
def __init__(self, max_steps, desc=None): def __init__(self, max_steps, desc=None):
"""doc"""
self.tqdm = None self.tqdm = None
import tqdm import tqdm
from propeller import log as main_log from propeller import log as main_log
hdl = main_log.handlers[0] hdl = main_log.handlers[0]
class TqdmLogginHandler(logging.Handler): class _TqdmLogginHandler(logging.Handler):
def emit(self, record): def emit(self, record):
"""doc"""
try: try:
msg = self.format(record) msg = self.format(record)
tqdm.tqdm.write(msg, file=sys.stderr) tqdm.tqdm.write(msg, file=sys.stderr)
self.flush() self.flush()
except (KeyboardInterrupt, SystemExit): except (KeyboardInterrupt, SystemExit) as e:
raise raise e
except: except:
self.handleError(record) self.handleError(record)
tqdm_hdl = TqdmLogginHandler() tqdm_hdl = _TqdmLogginHandler()
tqdm_hdl.setFormatter(hdl.formatter) tqdm_hdl.setFormatter(hdl.formatter)
main_log.removeHandler(hdl) main_log.removeHandler(hdl)
main_log.addHandler(tqdm_hdl) main_log.addHandler(tqdm_hdl)
...@@ -91,46 +104,55 @@ class TqdmProgressBarHook(RunHook): ...@@ -91,46 +104,55 @@ class TqdmProgressBarHook(RunHook):
class TqdmNotebookProgressBarHook(RunHook): class TqdmNotebookProgressBarHook(RunHook):
"""show a progress bar when training"""
def __init__(self, max_steps, desc=None): def __init__(self, max_steps, desc=None):
"""doc"""
self.tqdm = None self.tqdm = None
import tqdm import tqdm
from propeller import log as main_log from propeller import log as main_log
hdl = main_log.handlers[0] hdl = main_log.handlers[0]
class TqdmLogginHandler(logging.Handler): class _TqdmLogginHandler(logging.Handler):
def emit(self, record): def emit(self, record):
"""doc"""
try: try:
msg = self.format(record) msg = self.format(record)
tqdm.tqdm.write(msg, file=sys.stderr) tqdm.tqdm.write(msg, file=sys.stderr)
self.flush() self.flush()
except (KeyboardInterrupt, SystemExit): except (KeyboardInterrupt, SystemExit) as e:
raise raise e
except: except:
self.handleError(record) self.handleError(record)
tqdm_hdl = TqdmLogginHandler() tqdm_hdl = _TqdmLogginHandler()
tqdm_hdl.setFormatter(hdl.formatter) tqdm_hdl.setFormatter(hdl.formatter)
main_log.removeHandler(hdl) main_log.removeHandler(hdl)
main_log.addHandler(tqdm_hdl) main_log.addHandler(tqdm_hdl)
self.tqdm = tqdm.tqdm_notebook(total=max_steps, desc=None) self.tqdm = tqdm.tqdm_notebook(total=max_steps, desc=None)
def before_run(self, state): def before_run(self, state):
"""doc"""
self.tqdm.n = state.gstep self.tqdm.n = state.gstep
self.tqdm.refresh() self.tqdm.refresh()
return [] return []
def __del__(self): def __del__(self):
"""doc"""
if self.tqdm: if self.tqdm:
self.tqdm.close() self.tqdm.close()
class LoggingHook(RunHook): class LoggingHook(RunHook):
"""log tensor in to screan and tensorboard"""
def __init__(self, def __init__(self,
loss, loss,
per_step=10, per_step=10,
skip_step=100, skip_step=100,
summary_writer=None, summary_writer=None,
summary_record=None): summary_record=None):
"""doc"""
if per_step is None or skip_step is None: if per_step is None or skip_step is None:
raise ValueError('wrong step argument, per step: %d skip_step %d' % raise ValueError('wrong step argument, per step: %d skip_step %d' %
(per_step, skip_step)) (per_step, skip_step))
...@@ -141,7 +163,8 @@ class LoggingHook(RunHook): ...@@ -141,7 +163,8 @@ class LoggingHook(RunHook):
self.writer = summary_writer self.writer = summary_writer
self.last_state = None self.last_state = None
def before_train(self): def before_train(self, program):
"""doc"""
if self.summary_record: if self.summary_record:
if self.summary_record.scalar: if self.summary_record.scalar:
self.s_name, self.s_tolog = zip(*self.summary_record.scalar) self.s_name, self.s_tolog = zip(*self.summary_record.scalar)
...@@ -154,6 +177,7 @@ class LoggingHook(RunHook): ...@@ -154,6 +177,7 @@ class LoggingHook(RunHook):
self.h_name, self.h_tolog = [], [] self.h_name, self.h_tolog = [], []
def before_run(self, state): def before_run(self, state):
"""doc"""
if state.gstep % self.per_step == 0 and state.step > self.skip_step: if state.gstep % self.per_step == 0 and state.step > self.skip_step:
ret = [self.loss] ret = [self.loss]
if self.summary_record: if self.summary_record:
...@@ -164,6 +188,7 @@ class LoggingHook(RunHook): ...@@ -164,6 +188,7 @@ class LoggingHook(RunHook):
return [] return []
def after_run(self, res_list, state): def after_run(self, res_list, state):
"""doc"""
if state.gstep % self.per_step == 0 and state.step > self.skip_step: if state.gstep % self.per_step == 0 and state.step > self.skip_step:
if not self.summary_record: if not self.summary_record:
return return
...@@ -209,11 +234,15 @@ class LoggingHook(RunHook): ...@@ -209,11 +234,15 @@ class LoggingHook(RunHook):
class StopAtStepHook(RunHook): class StopAtStepHook(RunHook):
"""stop training at some step"""
def __init__(self, stop_global_step, stop_step): def __init__(self, stop_global_step, stop_step):
"""doc"""
self._stop_gstep = stop_global_step self._stop_gstep = stop_global_step
self._stop_step = stop_step self._stop_step = stop_step
def should_stop(self, state): def should_stop(self, state):
"""doc"""
if (self._stop_gstep and state.gstep >= self._stop_gstep) or \ if (self._stop_gstep and state.gstep >= self._stop_gstep) or \
(self._stop_step and state.step >= self._stop_step): (self._stop_step and state.step >= self._stop_step):
log.info('StopAtStepHook called stop') log.info('StopAtStepHook called stop')
...@@ -226,6 +255,7 @@ class EvalHook(RunHook): ...@@ -226,6 +255,7 @@ class EvalHook(RunHook):
"""hook this on a eval Executor""" """hook this on a eval Executor"""
def __init__(self, metrics, summary_writer=None): def __init__(self, metrics, summary_writer=None):
"""doc"""
self.writer = summary_writer self.writer = summary_writer
self._result = None self._result = None
...@@ -244,11 +274,13 @@ class EvalHook(RunHook): ...@@ -244,11 +274,13 @@ class EvalHook(RunHook):
else: else:
self.names, self.metrics = [], [] self.names, self.metrics = [], []
def before_train(self): def before_train(self, program):
"""doc"""
for m in self.metrics: for m in self.metrics:
m.reset() m.reset()
def before_run(self, state): def before_run(self, state):
"""doc"""
ls = [m.tensor for m in self.metrics] ls = [m.tensor for m in self.metrics]
for i in ls: for i in ls:
if not (isinstance(i, list) or isinstance(i, tuple)): if not (isinstance(i, list) or isinstance(i, tuple)):
...@@ -265,15 +297,18 @@ class EvalHook(RunHook): ...@@ -265,15 +297,18 @@ class EvalHook(RunHook):
return ls_flt return ls_flt
def after_run(self, res_list, state): def after_run(self, res_list, state):
"""doc"""
res = util.unflatten(res_list, self.schema) res = util.unflatten(res_list, self.schema)
for r, m in zip(res, self.metrics): for r, m in zip(res, self.metrics):
m.update(r) m.update(r)
@property @property
def result(self): def result(self):
"""doc"""
return self._result return self._result
def after_train(self): def after_train(self):
"""doc"""
printable = [] printable = []
self._result = {} self._result = {}
for n, m in zip(self.names, self.metrics): for n, m in zip(self.names, self.metrics):
...@@ -284,12 +319,16 @@ class EvalHook(RunHook): ...@@ -284,12 +319,16 @@ class EvalHook(RunHook):
class CheckpointSaverHook(RunHook): class CheckpointSaverHook(RunHook):
"""Save checkpoint every n step"""
def __init__(self, saver, per_step=10, skip_step=100): def __init__(self, saver, per_step=10, skip_step=100):
"""doc"""
self.saver = saver self.saver = saver
self.per_step = per_step self.per_step = per_step
self.skip_step = skip_step self.skip_step = skip_step
def after_run(self, res_list, state): def after_run(self, res_list, state):
"""doc"""
if state.gstep % self.per_step == 0 and \ if state.gstep % self.per_step == 0 and \
state.step > self.skip_step: state.step > self.skip_step:
self.saver.save(state) self.saver.save(state)
...@@ -11,9 +11,12 @@ ...@@ -11,9 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""predefined metrics"""
import sys import sys
import os import os
import six
import numpy as np import numpy as np
import itertools import itertools
import logging import logging
...@@ -31,98 +34,132 @@ __all__ = [ ...@@ -31,98 +34,132 @@ __all__ = [
class Metrics(object): class Metrics(object):
"""Metrics base class"""
def __init__(self): def __init__(self):
"""doc"""
self.saver = [] self.saver = []
@property @property
def tensor(self): def tensor(self):
"""doc"""
pass pass
def update(self, *args): def update(self, *args):
"""doc"""
pass pass
def eval(self): def eval(self):
"""doc"""
pass pass
class Mean(Metrics): class Mean(Metrics):
"""doc"""
def __init__(self, t): def __init__(self, t):
"""doc"""
self.t = t self.t = t
self.reset() self.reset()
def reset(self): def reset(self):
"""doc"""
self.saver = np.array([]) self.saver = np.array([])
@property @property
def tensor(self): def tensor(self):
"""doc"""
self.t.persistable = True self.t.persistable = True
return self.t, return self.t,
def update(self, args): def update(self, args):
"""doc"""
t, = args t, = args
t = t.reshape([-1]) t = t.reshape([-1])
self.saver = np.concatenate([self.saver, t]) self.saver = np.concatenate([self.saver, t])
def eval(self): def eval(self):
"""doc"""
return self.saver.mean() return self.saver.mean()
class Ppl(Mean): class Ppl(Mean):
"""doc"""
def eval(self): def eval(self):
"""doc"""
return np.exp(self.saver.mean()) return np.exp(self.saver.mean())
class Acc(Mean): class Acc(Mean):
"""doc"""
def __init__(self, label, pred): def __init__(self, label, pred):
"""doc"""
self.eq = L.equal(pred, label) self.eq = L.equal(pred, label)
self.reset() self.reset()
@property @property
def tensor(self): def tensor(self):
"""doc"""
self.eq.persistable = True self.eq.persistable = True
return self.eq, return self.eq,
class MSE(Mean): class MSE(Mean):
"""doc"""
def __init__(self, label, pred): def __init__(self, label, pred):
"""doc"""
diff = pred - label diff = pred - label
self.mse = diff * diff self.mse = diff * diff
self.reset() self.reset()
@property @property
def tensor(self): def tensor(self):
"""doc"""
self.mse.persistable = True self.mse.persistable = True
return self.mse, return self.mse,
class Cosine(Mean): class Cosine(Mean):
"""doc"""
def __init__(self, label, pred): def __init__(self, label, pred):
"""doc"""
self.cos = L.cos_sim(label, pred) self.cos = L.cos_sim(label, pred)
self.reset() self.reset()
@property @property
def tensor(self): def tensor(self):
"""doc"""
self.cos.persistable = True self.cos.persistable = True
return self.cos, return self.cos,
class Precision(Metrics): class Precision(Metrics):
"""doc"""
def __init__(self, label, pred): def __init__(self, label, pred):
"""doc"""
self.label = label self.label = label
self.pred = pred self.pred = pred
self.reset() self.reset()
def reset(self): def reset(self):
"""doc"""
self.label_saver = np.array([], dtype=np.bool) self.label_saver = np.array([], dtype=np.bool)
self.pred_saver = np.array([], dtype=np.bool) self.pred_saver = np.array([], dtype=np.bool)
@property @property
def tensor(self): def tensor(self):
"""doc"""
self.label.persistable = True self.label.persistable = True
self.pred.persistable = True self.pred.persistable = True
return self.label, self.pred return self.label, self.pred
def update(self, args): def update(self, args):
"""doc"""
label, pred = args label, pred = args
label = label.reshape([-1]).astype(np.bool) label = label.reshape([-1]).astype(np.bool)
pred = pred.reshape([-1]).astype(np.bool) pred = pred.reshape([-1]).astype(np.bool)
...@@ -134,20 +171,27 @@ class Precision(Metrics): ...@@ -134,20 +171,27 @@ class Precision(Metrics):
self.pred_saver = np.concatenate([self.pred_saver, pred]) self.pred_saver = np.concatenate([self.pred_saver, pred])
def eval(self): def eval(self):
"""doc"""
tp = (self.label_saver & self.pred_saver).astype(np.int64).sum() tp = (self.label_saver & self.pred_saver).astype(np.int64).sum()
t = self.label_saver.astype(np.int64).sum() t = self.label_saver.astype(np.int64).sum()
return tp / t return tp / t
class Recall(Precision): class Recall(Precision):
"""doc"""
def eval(self): def eval(self):
"""doc"""
tp = (self.label_saver & self.pred_saver).astype(np.int64).sum() tp = (self.label_saver & self.pred_saver).astype(np.int64).sum()
p = (self.label_saver).astype(np.int64).sum() p = (self.label_saver).astype(np.int64).sum()
return tp / p return tp / p
class F1(Precision): class F1(Precision):
"""doc"""
def eval(self): def eval(self):
"""doc"""
tp = (self.label_saver & self.pred_saver).astype(np.int64).sum() tp = (self.label_saver & self.pred_saver).astype(np.int64).sum()
t = self.label_saver.astype(np.int64).sum() t = self.label_saver.astype(np.int64).sum()
p = self.pred_saver.astype(np.int64).sum() p = self.pred_saver.astype(np.int64).sum()
...@@ -157,22 +201,28 @@ class F1(Precision): ...@@ -157,22 +201,28 @@ class F1(Precision):
class Auc(Metrics): class Auc(Metrics):
"""doc"""
def __init__(self, label, pred): def __init__(self, label, pred):
"""doc"""
self.pred = pred self.pred = pred
self.label = label self.label = label
self.reset() self.reset()
def reset(self): def reset(self):
"""doc"""
self.pred_saver = np.array([], dtype=np.float32) self.pred_saver = np.array([], dtype=np.float32)
self.label_saver = np.array([], dtype=np.bool) self.label_saver = np.array([], dtype=np.bool)
@property @property
def tensor(self): def tensor(self):
"""doc"""
self.pred.persistable = True self.pred.persistable = True
self.label.persistable = True self.label.persistable = True
return [self.pred, self.label] return [self.pred, self.label]
def update(self, args): def update(self, args):
"""doc"""
pred, label = args pred, label = args
pred = pred.reshape([-1]).astype(np.float32) pred = pred.reshape([-1]).astype(np.float32)
label = label.reshape([-1]).astype(np.bool) label = label.reshape([-1]).astype(np.bool)
...@@ -180,6 +230,7 @@ class Auc(Metrics): ...@@ -180,6 +230,7 @@ class Auc(Metrics):
self.label_saver = np.concatenate([self.label_saver, label]) self.label_saver = np.concatenate([self.label_saver, label])
def eval(self): def eval(self):
"""doc"""
fpr, tpr, thresholds = sklearn.metrics.roc_curve( fpr, tpr, thresholds = sklearn.metrics.roc_curve(
self.label_saver.astype(np.int64), self.pred_saver) self.label_saver.astype(np.int64), self.pred_saver)
auc = sklearn.metrics.auc(fpr, tpr) auc = sklearn.metrics.auc(fpr, tpr)
...@@ -187,11 +238,15 @@ class Auc(Metrics): ...@@ -187,11 +238,15 @@ class Auc(Metrics):
class RecallAtPrecision(Auc): class RecallAtPrecision(Auc):
"""doc"""
def __init__(self, label, pred, precision=0.9): def __init__(self, label, pred, precision=0.9):
"""doc"""
super(RecallAtPrecision, self).__init__(label, pred) super(RecallAtPrecision, self).__init__(label, pred)
self.precision = precision self.precision = precision
def eval(self): def eval(self):
"""doc"""
self.pred_saver = self.pred_saver.reshape( self.pred_saver = self.pred_saver.reshape(
[self.label_saver.size, -1])[:, -1] [self.label_saver.size, -1])[:, -1]
precision, recall, thresholds = sklearn.metrics.precision_recall_curve( precision, recall, thresholds = sklearn.metrics.precision_recall_curve(
...@@ -202,11 +257,15 @@ class RecallAtPrecision(Auc): ...@@ -202,11 +257,15 @@ class RecallAtPrecision(Auc):
class PrecisionAtThreshold(Auc): class PrecisionAtThreshold(Auc):
"""doc"""
def __init__(self, label, pred, threshold=0.5): def __init__(self, label, pred, threshold=0.5):
"""doc"""
super().__init__(label, pred) super().__init__(label, pred)
self.threshold = threshold self.threshold = threshold
def eval(self): def eval(self):
"""doc"""
infered = self.pred_saver > self.threshold infered = self.pred_saver > self.threshold
correct_num = np.array(infered & self.label_saver).sum() correct_num = np.array(infered & self.label_saver).sum()
infer_num = infered.sum() infer_num = infered.sum()
...@@ -214,25 +273,31 @@ class PrecisionAtThreshold(Auc): ...@@ -214,25 +273,31 @@ class PrecisionAtThreshold(Auc):
class Mrr(Metrics): class Mrr(Metrics):
"""doc"""
def __init__(self, qid, label, pred): def __init__(self, qid, label, pred):
"""doc"""
self.qid = qid self.qid = qid
self.label = label self.label = label
self.pred = pred self.pred = pred
self.reset() self.reset()
def reset(self): def reset(self):
"""doc"""
self.qid_saver = np.array([], dtype=np.int64) self.qid_saver = np.array([], dtype=np.int64)
self.label_saver = np.array([], dtype=np.int64) self.label_saver = np.array([], dtype=np.int64)
self.pred_saver = np.array([], dtype=np.float32) self.pred_saver = np.array([], dtype=np.float32)
@property @property
def tensor(self): def tensor(self):
"""doc"""
self.qid.persistable = True self.qid.persistable = True
self.label.persistable = True self.label.persistable = True
self.pred.persistable = True self.pred.persistable = True
return [self.qid, self.label, self.pred] return [self.qid, self.label, self.pred]
def update(self, args): def update(self, args):
"""doc"""
qid, label, pred = args qid, label, pred = args
if not (qid.shape[0] == label.shape[0] == pred.shape[0]): if not (qid.shape[0] == label.shape[0] == pred.shape[0]):
raise ValueError( raise ValueError(
...@@ -246,10 +311,12 @@ class Mrr(Metrics): ...@@ -246,10 +311,12 @@ class Mrr(Metrics):
[self.pred_saver, pred.reshape([-1]).astype(np.float32)]) [self.pred_saver, pred.reshape([-1]).astype(np.float32)])
def eval(self): def eval(self):
def key_func(tup): """doc"""
def _key_func(tup):
return tup[0] return tup[0]
def calc_func(tup): def _calc_func(tup):
ranks = [ ranks = [
1. / (rank + 1.) 1. / (rank + 1.)
for rank, (_, l, p) in enumerate( for rank, (_, l, p) in enumerate(
...@@ -262,19 +329,22 @@ class Mrr(Metrics): ...@@ -262,19 +329,22 @@ class Mrr(Metrics):
return 0. return 0.
mrr_for_qid = [ mrr_for_qid = [
calc_func(tup) _calc_func(tup)
for _, tup in itertools.groupby( for _, tup in itertools.groupby(
sorted( sorted(
zip(self.qid_saver, self.label_saver, self.pred_saver), zip(self.qid_saver, self.label_saver, self.pred_saver),
key=key_func), key=_key_func),
key=key_func) key=_key_func)
] ]
mrr = np.float32(sum(mrr_for_qid) / len(mrr_for_qid)) mrr = np.float32(sum(mrr_for_qid) / len(mrr_for_qid))
return mrr return mrr
class ChunkF1(Metrics): class ChunkF1(Metrics):
"""doc"""
def __init__(self, label, pred, seqlen, num_label): def __init__(self, label, pred, seqlen, num_label):
"""doc"""
self.label = label self.label = label
self.pred = pred self.pred = pred
self.seqlen = seqlen self.seqlen = seqlen
...@@ -327,18 +397,21 @@ class ChunkF1(Metrics): ...@@ -327,18 +397,21 @@ class ChunkF1(Metrics):
return chunks return chunks
def reset(self): def reset(self):
"""doc"""
self.label_cnt = 0 self.label_cnt = 0
self.pred_cnt = 0 self.pred_cnt = 0
self.correct_cnt = 0 self.correct_cnt = 0
@property @property
def tensor(self): def tensor(self):
"""doc"""
self.pred.persistable = True self.pred.persistable = True
self.label.persistable = True self.label.persistable = True
self.seqlen.persistable = True self.seqlen.persistable = True
return [self.pred, self.label, self.seqlen] return [self.pred, self.label, self.seqlen]
def update(self, args): def update(self, args):
"""doc"""
pred, label, seqlen = args pred, label, seqlen = args
pred = pred.reshape([-1]).astype(np.int32).tolist() pred = pred.reshape([-1]).astype(np.int32).tolist()
label = label.reshape([-1]).astype(np.int32).tolist() label = label.reshape([-1]).astype(np.int32).tolist()
...@@ -374,6 +447,7 @@ class ChunkF1(Metrics): ...@@ -374,6 +447,7 @@ class ChunkF1(Metrics):
label_index += 1 label_index += 1
def eval(self): def eval(self):
"""doc"""
if self.pred_cnt == 0: if self.pred_cnt == 0:
precision = 0.0 precision = 0.0
else: else:
...@@ -393,23 +467,29 @@ class ChunkF1(Metrics): ...@@ -393,23 +467,29 @@ class ChunkF1(Metrics):
class PNRatio(Metrics): class PNRatio(Metrics):
"""doc"""
def __init__(self, qid, label, pred): def __init__(self, qid, label, pred):
"""doc"""
self.qid = qid self.qid = qid
self.label = label self.label = label
self.pred = pred self.pred = pred
self.saver = {} self.saver = {}
def reset(self): def reset(self):
"""doc"""
self.saver = {} self.saver = {}
@property @property
def tensor(self): def tensor(self):
"""doc"""
self.qid.persistable = True self.qid.persistable = True
self.label.persistable = True self.label.persistable = True
self.pred.persistable = True self.pred.persistable = True
return [self.qid, self.label, self.pred] return [self.qid, self.label, self.pred]
def update(self, args): def update(self, args):
"""doc"""
qid, label, pred = args qid, label, pred = args
if not (qid.shape[0] == label.shape[0] == pred.shape[0]): if not (qid.shape[0] == label.shape[0] == pred.shape[0]):
raise ValueError('dimention not match: qid[%s] label[%s], pred[%s]' raise ValueError('dimention not match: qid[%s] label[%s], pred[%s]'
...@@ -424,6 +504,7 @@ class PNRatio(Metrics): ...@@ -424,6 +504,7 @@ class PNRatio(Metrics):
self.saver[q].append((l, p)) self.saver[q].append((l, p))
def eval(self): def eval(self):
"""doc"""
p = 0 p = 0
n = 0 n = 0
for qid, outputs in self.saver.items(): for qid, outputs in self.saver.items():
...@@ -446,10 +527,14 @@ class PNRatio(Metrics): ...@@ -446,10 +527,14 @@ class PNRatio(Metrics):
class BinaryPNRatio(PNRatio): class BinaryPNRatio(PNRatio):
"""doc"""
def __init__(self, qid, label, pred): def __init__(self, qid, label, pred):
"""doc"""
super(BinaryPNRatio, self).__init__(qid, label, pred) super(BinaryPNRatio, self).__init__(qid, label, pred)
def eval(self): def eval(self):
"""doc"""
p = 0 p = 0
n = 0 n = 0
for qid, outputs in self.saver.items(): for qid, outputs in self.saver.items():
...@@ -474,7 +559,10 @@ class BinaryPNRatio(PNRatio): ...@@ -474,7 +559,10 @@ class BinaryPNRatio(PNRatio):
class PrecisionAtK(Metrics): class PrecisionAtK(Metrics):
"""doc"""
def __init__(self, qid, label, pred, k=1): def __init__(self, qid, label, pred, k=1):
"""doc"""
self.qid = qid self.qid = qid
self.label = label self.label = label
self.pred = pred self.pred = pred
...@@ -482,16 +570,19 @@ class PrecisionAtK(Metrics): ...@@ -482,16 +570,19 @@ class PrecisionAtK(Metrics):
self.saver = {} self.saver = {}
def reset(self): def reset(self):
"""doc"""
self.saver = {} self.saver = {}
@property @property
def tensor(self): def tensor(self):
"""doc"""
self.qid.persistable = True self.qid.persistable = True
self.label.persistable = True self.label.persistable = True
self.pred.persistable = True self.pred.persistable = True
return [self.qid, self.label, self.pred] return [self.qid, self.label, self.pred]
def update(self, args): def update(self, args):
"""doc"""
qid, label, pred = args qid, label, pred = args
if not (qid.shape[0] == label.shape[0] == pred.shape[0]): if not (qid.shape[0] == label.shape[0] == pred.shape[0]):
raise ValueError('dimention not match: qid[%s] label[%s], pred[%s]' raise ValueError('dimention not match: qid[%s] label[%s], pred[%s]'
...@@ -507,6 +598,7 @@ class PrecisionAtK(Metrics): ...@@ -507,6 +598,7 @@ class PrecisionAtK(Metrics):
self.saver[q].append((l, p)) self.saver[q].append((l, p))
def eval(self): def eval(self):
"""doc"""
right = 0 right = 0
total = 0 total = 0
for v in self.saver.values(): for v in self.saver.values():
......
...@@ -11,6 +11,10 @@ ...@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
doc
"""
from __future__ import print_function from __future__ import print_function
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
...@@ -37,9 +41,17 @@ log = logging.getLogger(__name__) ...@@ -37,9 +41,17 @@ log = logging.getLogger(__name__)
__all__ = ['MonitoredExecutor', 'Saver'] __all__ = ['MonitoredExecutor', 'Saver']
def _get_one_place():
return F.cuda_places()[0] if F.core.is_compiled_with_cuda(
) else F.cpu_places()[0]
class RunState(object): class RunState(object):
"""serializable Run state object"""
@classmethod @classmethod
def from_str(cls, s): def from_str(cls, s):
"""doc"""
j = json.loads(s) j = json.loads(s)
ret = RunState() ret = RunState()
ret._gstep = j['global_step'] ret._gstep = j['global_step']
...@@ -48,29 +60,36 @@ class RunState(object): ...@@ -48,29 +60,36 @@ class RunState(object):
return ret return ret
def __init__(self): def __init__(self):
"""doc"""
self._gstep = 0 self._gstep = 0
self._step = 0 self._step = 0
self._time = time() self._time = time()
@property @property
def gstep(self): def gstep(self):
"""doc"""
return self._gstep return self._gstep
@property @property
def step(self): def step(self):
"""doc"""
return self._step return self._step
@property @property
def time(self): def time(self):
"""doc"""
return self._time return self._time
def __repr__(self): def __repr__(self):
"""doc"""
return repr({'global_step': self._gstep, 'time': self._time}) return repr({'global_step': self._gstep, 'time': self._time})
def serialize(self): def serialize(self):
"""doc"""
return json.dumps({'global_step': self._gstep, 'time': self._time}) return json.dumps({'global_step': self._gstep, 'time': self._time})
def next(self): def next(self):
"""doc"""
ret = RunState() ret = RunState()
ret._gstep = self._gstep + 1 ret._gstep = self._gstep + 1
ret._step = self._step + 1 ret._step = self._step + 1
...@@ -79,12 +98,15 @@ class RunState(object): ...@@ -79,12 +98,15 @@ class RunState(object):
class Saver(object): class Saver(object):
"""checkpoint saver and manager"""
def __init__(self, def __init__(self,
save_dir, save_dir,
exe, exe,
program, program,
save_prefix='model', save_prefix='model',
max_ckpt_to_keep=None): max_ckpt_to_keep=None):
"""doc"""
if exe is not None: if exe is not None:
assert isinstance( assert isinstance(
exe, F.Executor exe, F.Executor
...@@ -108,9 +130,11 @@ class Saver(object): ...@@ -108,9 +130,11 @@ class Saver(object):
@property @property
def last_ckpt(self): def last_ckpt(self):
"""doc"""
return self.ckpt_list[-1] if len(self.ckpt_list) else None return self.ckpt_list[-1] if len(self.ckpt_list) else None
def save(self, state): def save(self, state):
"""doc"""
save_name = '%s_%d' % (self._save_prefix, state.gstep) save_name = '%s_%d' % (self._save_prefix, state.gstep)
save_dir = os.path.join(self._save_dir, save_name) save_dir = os.path.join(self._save_dir, save_name)
tmp_dir = os.path.join(self._save_dir, 'tmp') tmp_dir = os.path.join(self._save_dir, 'tmp')
...@@ -139,28 +163,26 @@ class Saver(object): ...@@ -139,28 +163,26 @@ class Saver(object):
open(self.ckpt_info_path, 'w').write('\n'.join(self.ckpt_list)) open(self.ckpt_info_path, 'w').write('\n'.join(self.ckpt_list))
def restore(self, ckpt=-1): def restore(self, ckpt=-1):
if not isinstance(ckpt, (int, ) + six.string_types): """doc"""
raise ValueError('ckpt type not understood %s' % repr(ckpt))
if isinstance(ckpt, int): if isinstance(ckpt, int):
try: try:
ckpt = self.ckpt_list[ckpt] path = os.path.join(self._save_dir, self.ckpt_list[ckpt])
except IndexError: except IndexError:
raise ValueError('invalid restore ckpt number %d' % ckpt) raise ValueError('invalid restore ckpt number %d' % ckpt)
if isinstance(ckpt, six.string_types): elif isinstance(ckpt, six.string_types):
try: if not os.path.exists(ckpt):
ckpt = self.ckpt_list.index(ckpt) raise ValueError('ckpt: %s not found' % ckpt)
except ValueError: path = ckpt
raise ValueError('ckpt: %s not in ckpt list: %s' % else:
(ckpt, self.ckpt_list)) raise ValueError('ckpt type not understood %s' % repr(ckpt))
path = os.path.join(self._save_dir, self.ckpt_list[ckpt])
meta_file = os.path.join(path, 'meta') meta_file = os.path.join(path, 'meta')
if not os.path.exists(meta_file): if not os.path.exists(meta_file):
raise RuntimeError('meta not found in restore dir: %s' % path) raise RuntimeError('meta not found in restore dir: %s' % path)
state = RunState.from_str(open(meta_file).read()) state = RunState.from_str(open(meta_file).read())
log.info('restore from ckpt %s, ckpt-status: %s' % (path, repr(state))) log.info('restore from ckpt %s, ckpt-status: %s' % (path, repr(state)))
def fn(v): def _fn(v):
vpath = os.path.join(path, v.name) vpath = os.path.join(path, v.name)
if F.io.is_persistable(v): if F.io.is_persistable(v):
if os.path.exists(vpath): if os.path.exists(vpath):
...@@ -171,12 +193,12 @@ class Saver(object): ...@@ -171,12 +193,12 @@ class Saver(object):
return False return False
F.io.load_vars( F.io.load_vars(
self._exe, path, main_program=self._program, predicate=fn) self._exe, path, main_program=self._program, predicate=_fn)
return state return state
class MonitoredExecutor(object): class MonitoredExecutor(object):
"""A wrapper handling the train loop""" """An Executor wrapper handling the train loop"""
def __init__( def __init__(
self, self,
...@@ -209,13 +231,18 @@ class MonitoredExecutor(object): ...@@ -209,13 +231,18 @@ class MonitoredExecutor(object):
@property @property
def state(self): def state(self):
"""doc"""
return self._state return self._state
def init_or_restore_variables(self): def init_or_restore_variables(self, ckpt=-1):
"""
init vars or restore vars from model_dir
call before train
"""
# The order of this 2 steps really matters # The order of this 2 steps really matters
# 1. init train # 1. init train
F.Executor(F.cuda_places()[0]).run(self._program.startup_program) F.Executor(_get_one_place()).run(self._program.startup_program)
# 2. restore param # 2. restore param
if self._warm_start_setting is not None: if self._warm_start_setting is not None:
if not os.path.exists(self._warm_start_setting.from_dir): if not os.path.exists(self._warm_start_setting.from_dir):
...@@ -224,29 +251,34 @@ class MonitoredExecutor(object): ...@@ -224,29 +251,34 @@ class MonitoredExecutor(object):
log.info("warm start from %s" % self._warm_start_setting.from_dir) log.info("warm start from %s" % self._warm_start_setting.from_dir)
if self._warm_start_setting.predicate_fn is not None: if self._warm_start_setting.predicate_fn is not None:
def fn(v): def _fn(v):
ret = self._warm_start_setting.predicate_fn(v) ret = self._warm_start_setting.predicate_fn(v)
if ret: if ret:
log.info('warm start: %s' % v.name) log.info('warm start: %s' % v.name)
return ret return ret
F.io.load_vars( F.io.load_vars(
F.Executor(F.cuda_places()[0]), F.Executor(_get_one_place()),
self._warm_start_setting.from_dir, self._warm_start_setting.from_dir,
main_program=self._program.train_program, main_program=self._program.train_program,
predicate=fn) predicate=_fn)
else: else:
raise NotImplementedError() raise NotImplementedError()
self._saver = Saver( self._saver = Saver(
self._model_dir, self._model_dir,
F.Executor(F.cuda_places()[0]), F.Executor(_get_one_place()),
program=self._program.train_program, program=self._program.train_program,
max_ckpt_to_keep=self._max_ckpt) max_ckpt_to_keep=self._max_ckpt)
if self._saver.last_ckpt is not None: if self._saver.last_ckpt is not None:
self._state = self._saver.restore() self._state = self._saver.restore(ckpt)
def freeze(self): def _freeze(self):
"""
call before enter train loop
convert program to compiled program
will do nothing if loss is None i.e. not in train mode
"""
if self._loss is None: if self._loss is None:
log.debug('will not freeze a program without loss') log.debug('will not freeze a program without loss')
return return
...@@ -278,8 +310,16 @@ class MonitoredExecutor(object): ...@@ -278,8 +310,16 @@ class MonitoredExecutor(object):
startup_program=self._program.startup_program) startup_program=self._program.startup_program)
def __enter__(self): def __enter__(self):
"""
prepapre before enter train loop
"""
if F.core.is_compiled_with_cuda():
log.info('propeller runs in CUDA mode')
else:
log.info('propeller runs in CPU mode')
log.debug('freezing program') log.debug('freezing program')
self.freeze() self._freeze()
log.debug('done freezing') log.debug('done freezing')
log.info('********** Start Loop ************') log.info('********** Start Loop ************')
# TODO init # TODO init
...@@ -287,10 +327,13 @@ class MonitoredExecutor(object): ...@@ -287,10 +327,13 @@ class MonitoredExecutor(object):
self.result = None self.result = None
for h in self._hooks: for h in self._hooks:
log.debug('train loop has hook %s' % h) log.debug('train loop has hook %s' % h)
h.before_train() h.before_train(self._program)
return self return self
def run(self, fetch_list=[], *args, **kwargs): def run(self, fetch_list=[], *args, **kwargs):
"""
wrapper for Executor.run
"""
#log.debug('Executor running step %d' % self._state.gstep) #log.debug('Executor running step %d' % self._state.gstep)
if self._hooks: if self._hooks:
fetch_list = [fetch_list] fetch_list = [fetch_list]
...@@ -306,11 +349,12 @@ class MonitoredExecutor(object): ...@@ -306,11 +349,12 @@ class MonitoredExecutor(object):
] ]
#if len(set(fetch_list)) != len(fetch_list): #if len(set(fetch_list)) != len(fetch_list):
# log.error('strange shit happend when fetch list has idetity tensors %s' % fetch_list) # log.error('strange shit happend when fetch list has idetity tensors %s' % fetch_list)
#log.debug(fetch_list)
res = self._exe.run(self._program.train_program, res = self._exe.run(self._program.train_program,
fetch_list=fetch_list, fetch_list=fetch_list,
*args, *args,
**kwargs) **kwargs)
res = [self.merge_result(r) for r in res] res = [self._merge_result(r) for r in res]
#log.debug(res) #log.debug(res)
res = util.unflatten(res, schema) res = util.unflatten(res, schema)
...@@ -330,6 +374,9 @@ class MonitoredExecutor(object): ...@@ -330,6 +374,9 @@ class MonitoredExecutor(object):
return ret return ret
def __exit__(self, err_type, err_value, trace): def __exit__(self, err_type, err_value, trace):
"""
clean up things and report hook result when exit train loop
"""
if (err_type is None) or isinstance(err_value, ( if (err_type is None) or isinstance(err_value, (
F.core.EOFException, StopException, KeyboardInterrupt)): F.core.EOFException, StopException, KeyboardInterrupt)):
try: try:
...@@ -344,7 +391,10 @@ class MonitoredExecutor(object): ...@@ -344,7 +391,10 @@ class MonitoredExecutor(object):
log.exception('error occur during loop %s: %s' % log.exception('error occur during loop %s: %s' %
(err_type, err_value)) (err_type, err_value))
def merge_result(self, ls): def _merge_result(self, ls):
"""
merge results from multi gpu cards
"""
dev_count = len(self._program.train_program._places) if isinstance( dev_count = len(self._program.train_program._places) if isinstance(
self._program.train_program, F.compiler.CompiledProgram) else 1 self._program.train_program, F.compiler.CompiledProgram) else 1
if dev_count == 1: if dev_count == 1:
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""common ML train and eval procedure"""
from __future__ import print_function from __future__ import print_function
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
...@@ -28,7 +29,8 @@ from time import time ...@@ -28,7 +29,8 @@ from time import time
import paddle.fluid as F import paddle.fluid as F
import paddle.fluid.layers as L import paddle.fluid.layers as L
from propeller.types import RunMode, StopException, SummaryRecord, StopException, ModelSpec, InferenceSpec, ProgramPair, RunConfig from propeller.types import RunMode, StopException, SummaryRecord, StopException
from propeller.types import ModelSpec, InferenceSpec, ProgramPair, RunConfig
from propeller.paddle import summary, collection from propeller.paddle import summary, collection
from propeller.paddle.data.functional import Dataset from propeller.paddle.data.functional import Dataset
from propeller.paddle.train import distribution from propeller.paddle.train import distribution
...@@ -43,7 +45,7 @@ log = logging.getLogger(__name__) ...@@ -43,7 +45,7 @@ log = logging.getLogger(__name__)
__all__ = ['train_and_eval', 'Learner'] __all__ = ['train_and_eval', 'Learner']
def get_summary_writer(path): def _get_summary_writer(path):
summary_writer = None summary_writer = None
try: try:
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
...@@ -54,7 +56,12 @@ def get_summary_writer(path): ...@@ -54,7 +56,12 @@ def get_summary_writer(path):
return summary_writer return summary_writer
def log_eval_result(name, eval_result, swriter, state): def _get_one_place():
return F.cuda_places()[0] if F.core.is_compiled_with_cuda(
) else F.cpu_places()[0]
def _log_eval_result(name, eval_result, swriter, state):
log.debug(eval_result) log.debug(eval_result)
printable = [] printable = []
for n, val in six.iteritems(eval_result): for n, val in six.iteritems(eval_result):
...@@ -71,7 +78,7 @@ def log_eval_result(name, eval_result, swriter, state): ...@@ -71,7 +78,7 @@ def log_eval_result(name, eval_result, swriter, state):
log.info('******************************') log.info('******************************')
def build_net(model_fn, features, mode, params, run_config): def _build_net(model_fn, features, mode, params, run_config):
model_spec = model_fn( model_spec = model_fn(
features=features, mode=mode, params=params, run_config=run_config) features=features, mode=mode, params=params, run_config=run_config)
...@@ -97,12 +104,14 @@ def build_net(model_fn, features, mode, params, run_config): ...@@ -97,12 +104,14 @@ def build_net(model_fn, features, mode, params, run_config):
class Learner(object): class Learner(object):
"""A Learner can train / eval / predict on a Dataset"""
def __init__(self, def __init__(self,
model_class_or_model_fn, model_class_or_model_fn,
run_config, run_config,
params=None, params=None,
warm_start_setting=None): warm_start_setting=None):
''' """
model_class_or_model_fn(callable|propeller.train.Model): `model_class_or_model_fn` be specified in 2 ways: model_class_or_model_fn(callable|propeller.train.Model): `model_class_or_model_fn` be specified in 2 ways:
1. subclass of propeller.train.Model which implements: 1. subclass of propeller.train.Model which implements:
1. \_\_init\_\_ (hyper_param, mode, run_config) 1. \_\_init\_\_ (hyper_param, mode, run_config)
...@@ -121,58 +130,23 @@ class Learner(object): ...@@ -121,58 +130,23 @@ class Learner(object):
params: any python object, will pass to your `model_fn` or `propeller.train.Model` params: any python object, will pass to your `model_fn` or `propeller.train.Model`
run_config (propeller.RunConfig): run_config.max_steps should not be None. run_config (propeller.RunConfig): run_config.max_steps should not be None.
warm_start_setting (propeller.WarmStartSetting): Optional. warm start variable will overwrite model variable. warm_start_setting (propeller.WarmStartSetting): Optional. warm start variable will overwrite model variable.
''' """
if run_config.model_dir is None: if run_config.model_dir is None:
raise ValueError('model_dir should specified in run_config') raise ValueError('model_dir should specified in run_config')
if issubclass(model_class_or_model_fn, Model): if issubclass(model_class_or_model_fn, Model):
_model_fn = _build_model_fn(model_class_or_model_fn)
def model_fn(features, mode, params, run_config):
if mode != RunMode.PREDICT:
fea, label = features[:-1], features[-1]
else:
fea = features
model = model_class_or_model_fn(
params, mode, run_config=run_config)
pred = model.forward(fea)
if isinstance(pred, F.framework.Variable):
prediction = [pred]
else:
prediction = pred
if mode == RunMode.TRAIN:
loss = model.loss(pred, label)
model.backward(loss)
return ModelSpec(
loss=loss, predictions=prediction, mode=mode)
elif mode == RunMode.EVAL:
loss = model.loss(pred, label)
me = model.metrics(pred, label)
inf_spec = InferenceSpec(inputs=fea, outputs=prediction)
if 'loss' not in me:
me['loss'] = metrics.Mean(loss)
return ModelSpec(
loss=loss,
predictions=prediction,
metrics=me,
mode=mode,
inference_spec=inf_spec)
elif mode == RunMode.PREDICT:
return ModelSpec(predictions=prediction, mode=mode)
else:
raise RuntimeError('unknown run mode %s' % mode)
elif inspect.isfunction(model_class_or_model_fn): elif inspect.isfunction(model_class_or_model_fn):
model_fn = model_class_or_model_fn _model_fn = model_class_or_model_fn
else: else:
raise ValueError('unknown model %s' % model_class_or_model_fn) raise ValueError('unknown model %s' % model_class_or_model_fn)
self.model_fn = model_fn self.model_fn = _model_fn
self.params = params self.params = params
self.run_config = run_config self.run_config = run_config
self.warm_start_setting = warm_start_setting self.warm_start_setting = warm_start_setting
def build_for_train(self, train_dataset): def _build_for_train(self, train_dataset):
train_dataset.name = 'train' train_dataset.name = 'train'
train_program = F.Program() train_program = F.Program()
startup_prog = F.Program() startup_prog = F.Program()
...@@ -181,8 +155,8 @@ class Learner(object): ...@@ -181,8 +155,8 @@ class Learner(object):
with collection.Collections() as collections: with collection.Collections() as collections:
log.info('Building Train Graph...') log.info('Building Train Graph...')
fea = train_dataset.features() fea = train_dataset.features()
model_spec = build_net(self.model_fn, fea, RunMode.TRAIN, model_spec = _build_net(self.model_fn, fea, RunMode.TRAIN,
self.params, self.run_config) self.params, self.run_config)
log.info('Building Train Graph: Done') log.info('Building Train Graph: Done')
scalars = collections.get(collection.Key.SUMMARY_SCALAR) scalars = collections.get(collection.Key.SUMMARY_SCALAR)
...@@ -208,7 +182,7 @@ class Learner(object): ...@@ -208,7 +182,7 @@ class Learner(object):
train_program=train_program, train_program=train_program,
startup_program=startup_prog), model_spec, summary_record startup_program=startup_prog), model_spec, summary_record
def build_for_eval(self, ds): def _build_for_eval(self, ds):
ds.name = 'eval' ds.name = 'eval'
program = F.Program() program = F.Program()
startup_prog = F.Program() startup_prog = F.Program()
...@@ -217,8 +191,8 @@ class Learner(object): ...@@ -217,8 +191,8 @@ class Learner(object):
with F.unique_name.guard(): with F.unique_name.guard():
log.info('Building Eval Graph') log.info('Building Eval Graph')
fea = ds.features() fea = ds.features()
model_spec = build_net(self.model_fn, fea, RunMode.EVAL, model_spec = _build_net(self.model_fn, fea, RunMode.EVAL,
self.params, self.run_config) self.params, self.run_config)
log.info('Done') log.info('Done')
program = program.clone(for_test=True) program = program.clone(for_test=True)
log.info( log.info(
...@@ -227,7 +201,7 @@ class Learner(object): ...@@ -227,7 +201,7 @@ class Learner(object):
return ProgramPair( return ProgramPair(
train_program=program, startup_program=startup_prog), model_spec train_program=program, startup_program=startup_prog), model_spec
def build_for_predict(self, ds): def _build_for_predict(self, ds):
ds.name = 'predict' ds.name = 'predict'
program = F.Program() program = F.Program()
startup_prog = F.Program() startup_prog = F.Program()
...@@ -236,8 +210,8 @@ class Learner(object): ...@@ -236,8 +210,8 @@ class Learner(object):
with F.unique_name.guard(): with F.unique_name.guard():
log.info('Building Predict Graph') log.info('Building Predict Graph')
fea = ds.features() fea = ds.features()
model_spec = build_net(self.model_fn, fea, RunMode.PREDICT, model_spec = _build_net(self.model_fn, fea, RunMode.PREDICT,
self.params, self.run_config) self.params, self.run_config)
log.info('Done') log.info('Done')
program = program.clone(for_test=True) program = program.clone(for_test=True)
...@@ -249,11 +223,12 @@ class Learner(object): ...@@ -249,11 +223,12 @@ class Learner(object):
train_program=program, startup_program=startup_prog), model_spec train_program=program, startup_program=startup_prog), model_spec
def train(self, train_ds, train_hooks=[]): def train(self, train_ds, train_hooks=[]):
"""train on a `Dataset`"""
if not isinstance(train_ds, Dataset): if not isinstance(train_ds, Dataset):
raise ValueError('expect dataset to be instance of Dataset, got %s' raise ValueError('expect dataset to be instance of Dataset, got %s'
% repr(train_ds)) % repr(train_ds))
train_program, model_spec, summary_record = self.build_for_train( train_program, model_spec, summary_record = self._build_for_train(
train_ds) train_ds)
train_run_hooks = [ train_run_hooks = [
hooks.StopAtStepHook(self.run_config.max_steps, hooks.StopAtStepHook(self.run_config.max_steps,
...@@ -261,13 +236,16 @@ class Learner(object): ...@@ -261,13 +236,16 @@ class Learner(object):
hooks.LoggingHook( hooks.LoggingHook(
model_spec.loss, model_spec.loss,
summary_record=summary_record, summary_record=summary_record,
summary_writer=get_summary_writer( summary_writer=_get_summary_writer(
os.path.join(self.run_config.model_dir, 'train_history')), os.path.join(self.run_config.model_dir, 'train_history')),
per_step=self.run_config.log_steps, per_step=self.run_config.log_steps,
skip_step=self.run_config.skip_steps), skip_step=self.run_config.skip_steps),
] ]
if model_spec.train_hooks is not None:
train_run_hooks.extend(model_spec.train_hooks)
train_run_hooks.extend(train_hooks) train_run_hooks.extend(train_hooks)
train_executor = F.Executor(F.cuda_places()[0])
train_executor = F.Executor(_get_one_place())
mon_exe = MonitoredExecutor( mon_exe = MonitoredExecutor(
train_executor, train_executor,
...@@ -297,24 +275,29 @@ class Learner(object): ...@@ -297,24 +275,29 @@ class Learner(object):
return mon_exe.result return mon_exe.result
def evaluate(self, eval_dataset, eval_hooks=[]): def evaluate(self, eval_dataset, eval_hooks=[]):
"""eval on a `Dataset`"""
if not isinstance(eval_dataset, Dataset): if not isinstance(eval_dataset, Dataset):
raise ValueError('expect dataset to be instance of Dataset, got %s' raise ValueError('expect dataset to be instance of Dataset, got %s'
% repr(eval_dataset)) % repr(eval_dataset))
program, model_spec = self.build_for_eval(eval_dataset) program, model_spec = self._build_for_eval(eval_dataset)
single_card_place = F.cuda_places()[0] single_card_place = _get_one_place()
eval_executor = F.Executor(single_card_place) eval_executor = F.Executor(single_card_place)
eval_hooks = [ eval_run_hooks = [
hooks.StopAtStepHook(self.run_config.eval_max_steps, hooks.StopAtStepHook(self.run_config.eval_max_steps,
self.run_config.eval_max_steps), self.run_config.eval_max_steps),
hooks.EvalHook(model_spec.metrics, ) hooks.EvalHook(model_spec.metrics, )
] ]
if model_spec.eval_hooks is not None:
eval_run_hooks.extend(model_spec.eval_hooks)
eval_run_hooks.extend(eval_hooks)
mon_exe = MonitoredExecutor( mon_exe = MonitoredExecutor(
eval_executor, eval_executor,
program, program,
run_config=self.run_config, run_config=self.run_config,
run_hooks=eval_hooks) run_hooks=eval_run_hooks)
mon_exe.init_or_restore_variables() mon_exe.init_or_restore_variables()
try: try:
...@@ -326,32 +309,43 @@ class Learner(object): ...@@ -326,32 +309,43 @@ class Learner(object):
_, eval_result = mon_exe.result _, eval_result = mon_exe.result
summary_writer = get_summary_writer( summary_writer = _get_summary_writer(
os.path.join(self.run_config.model_dir, 'eval_history')) os.path.join(self.run_config.model_dir, 'eval_history'))
log_eval_result('eval', eval_result, summary_writer, mon_exe.state) _log_eval_result('eval', eval_result, summary_writer, mon_exe.state)
return mon_exe.result return mon_exe.result
def predict(self, predict_dataset, ckpt=None, steps=-1, split_batch=True): def predict(self,
''' predict_dataset,
ckpt=-1,
ckpt_path=None,
steps=-1,
split_batch=True):
"""
Perform predictoin Perform predictoin
will call `model_fn` and initiate user-specifed model in `propeller.RunMode.PREDICT` mode will call `model_fn` and initiate user-specifed model in `propeller.RunMode.PREDICT` mode
Args: Args:
infer_dataset (propeller.data.Dataset): should not `shuffle` or `repeat` infer_dataset (propeller.data.Dataset): should not `shuffle` or `repeat`
steps (int): steps to predict, if -1 is specifed, will stop when `StopException` is raised in `infer_dataset` steps (int): steps to predict, if None is specifed,
will stop when `StopException` is raised in `infer_dataset`
ckpt_path (None|str): Path of a specific checkpoint to predict.
If None, the latest checkpoint in model_dir is used.
If there are no checkpoints in model_dir,
prediction is run with newly initialized Variables instead of ones restored from checkpoint.
ckpt (int): deprecated args
split_batch (bool): if True, prediction of each example in a batch is returned. split_batch (bool): if True, prediction of each example in a batch is returned.
Yields: Yields:
Evaluated values of predictions tensors. Evaluated values of predictions tensors.
''' """
if not isinstance(predict_dataset, Dataset): if not isinstance(predict_dataset, Dataset):
raise ValueError('expect dataset to be instance of Dataset, got %s' raise ValueError('expect dataset to be instance of Dataset, got %s'
% repr(predict_dataset)) % repr(predict_dataset))
program, model_spec = self.build_for_predict(predict_dataset) program, model_spec = self._build_for_predict(predict_dataset)
single_card_place = F.cuda_places()[0] single_card_place = _get_one_place()
executor = F.Executor(single_card_place) executor = F.Executor(single_card_place)
pred_run_config = RunConfig( pred_run_config = RunConfig(
run_steps=steps if steps == -1 else None, run_steps=steps if steps == -1 else None,
...@@ -360,11 +354,12 @@ class Learner(object): ...@@ -360,11 +354,12 @@ class Learner(object):
executor, executor,
program, program,
run_config=pred_run_config, ) run_config=pred_run_config, )
mon_exe.init_or_restore_variables() mon_exe.init_or_restore_variables(ckpt
if ckpt_path is None else ckpt_path)
try: try:
with mon_exe: with mon_exe:
log.info('Runining predict from dir: %s' % repr(mon_exe.state)) log.info('Runining predict from dir: %s' % repr(mon_exe.state))
single_card_place = F.cuda_places()[0] single_card_place = _get_one_place()
for data in predict_dataset.start(places=[single_card_place]): for data in predict_dataset.start(places=[single_card_place]):
res = mon_exe.run(fetch_list=model_spec.predictions, res = mon_exe.run(fetch_list=model_spec.predictions,
feed=data) feed=data)
...@@ -379,7 +374,7 @@ class Learner(object): ...@@ -379,7 +374,7 @@ class Learner(object):
pass pass
def train_and_eval(_shit=None, def train_and_eval(_placeholder=None,
model_class_or_model_fn=None, model_class_or_model_fn=None,
params=None, params=None,
run_config=None, run_config=None,
...@@ -389,36 +384,27 @@ def train_and_eval(_shit=None, ...@@ -389,36 +384,27 @@ def train_and_eval(_shit=None,
train_hooks=[], train_hooks=[],
eval_hooks=[], eval_hooks=[],
exporters=[]): exporters=[]):
''' """
Perform train and evaluate procesure. Perform train and evaluate procesure.
will call `model_fn` and initiate user-specifed model in `propeller.RunMode.PREDICT` mode will call `model_fn` and initiate user-specifed model in `propeller.RunMode.PREDICT` mode
Args: Args:
model_class_or_model_fn(callable|propeller.train.Model): `model_class_or_model_fn` be specified in 2 ways: model_class_or_model_fn(callable|propeller.train.Model): `model_class_or_model_fn` be specified in 2 ways:
1. subclass of propeller.train.Model which implements: 1. subclass of propeller.train.Model
1. \_\_init\_\_ (hyper_param, mode, run_config) 2. a model_fn takes following args: 1. features; 2. param; 3. mode; 4. run_config(optional)
2. forward (features) => (prediction)
3. backword (loss) => None
4. loss (predictoin) => (loss)
5. metrics (optional) (prediction) => (dict of propeller.Metrics)
2. a model_fn takes following args:
1. features
2. param
3. mode
4. run_config(optional)
and returns a `propeller.ModelSpec` and returns a `propeller.ModelSpec`
params: any python object, will pass to your `model_fn` or `propeller.train.Model` params: any python object, will pass to your `model_fn` or `propeller.train.Model`
run_config (propeller.RunConfig): run_config.max_steps should not be None. run_config (propeller.RunConfig): run_config.max_steps should not be None.
train_dataset (propeller.paddle.data.Dataset): training will stop if global_step > run_config.max_steps. train_dataset (propeller.paddle.data.Dataset): training will stop if global_step > run_config.max_steps.
eval_dataset (propeller.paddle.data.Dataset|dict): Optional, if Dict of propeller.data.Dataset were specified, will perform evluatation on every evaluation sets and report results. eval_dataset (propeller.paddle.data.Dataset|dict): Optional, if Dict of propeller.data.Dataset were specified,
will perform evluatation on every evaluation sets and report results.
warm_start_setting (propeller.WarmStartSetting): Optional. warm start variable will overwrite model variable. warm_start_setting (propeller.WarmStartSetting): Optional. warm start variable will overwrite model variable.
train_hooks (list of propeller.paddle.train.RunHook): Optional. train_hooks (list of propeller.paddle.train.RunHook): Optional.
eval_hooks (list of propeller.paddle.train.RunHook): Optional. eval_hooks (list of propeller.paddle.train.RunHook): Optional.
exporters (list of propeller.paddle.train.Exporter): Optional. exporters (list of propeller.paddle.train.Exporter): Optional.
''' """
if _shit is not None: if _placeholder is not None:
raise ValueError('specify keyword args to this function') raise ValueError('specify keyword args to this function')
if model_class_or_model_fn is None or params is None or run_config is None or train_dataset is None: if model_class_or_model_fn is None or params is None or run_config is None or train_dataset is None:
raise ValueError( raise ValueError(
...@@ -454,13 +440,13 @@ def train_and_eval(_shit=None, ...@@ -454,13 +440,13 @@ def train_and_eval(_shit=None,
params, params,
warm_start_setting=warm_start_setting) warm_start_setting=warm_start_setting)
class EvalHookOnTrainLoop(hooks.RunHook): class _EvalHookOnTrainLoop(hooks.RunHook):
def __init__(self): def __init__(self):
self.program, self.model_spec = est.build_for_eval( self.program, self.model_spec = est._build_for_eval(
list(eval_dataset.values())[ list(eval_dataset.values())[
0]) #eval_datasets must have same output shapes 0]) #eval_datasets must have same output shapes
self.summary_writers = { self.summary_writers = {
ds_name: get_summary_writer( ds_name: _get_summary_writer(
os.path.join( os.path.join(
os.path.join(run_config.model_dir, 'eval_history'), os.path.join(run_config.model_dir, 'eval_history'),
ds_name)) ds_name))
...@@ -468,6 +454,7 @@ def train_and_eval(_shit=None, ...@@ -468,6 +454,7 @@ def train_and_eval(_shit=None,
} }
def after_run(self, _, state): def after_run(self, _, state):
"""doc"""
if state.step > run_config.skip_steps and state.gstep % run_config.eval_steps == 0: if state.step > run_config.skip_steps and state.gstep % run_config.eval_steps == 0:
eval_results = {} eval_results = {}
for name, ds in six.iteritems(eval_dataset): for name, ds in six.iteritems(eval_dataset):
...@@ -478,7 +465,7 @@ def train_and_eval(_shit=None, ...@@ -478,7 +465,7 @@ def train_and_eval(_shit=None,
self.model_spec.metrics, self.model_spec.metrics,
summary_writer=self.summary_writers[name], ) summary_writer=self.summary_writers[name], )
] ]
single_card_place = F.cuda_places()[0] single_card_place = _get_one_place()
eval_executor = F.Executor(single_card_place) eval_executor = F.Executor(single_card_place)
mon_exe = MonitoredExecutor( mon_exe = MonitoredExecutor(
eval_executor, eval_executor,
...@@ -495,8 +482,8 @@ def train_and_eval(_shit=None, ...@@ -495,8 +482,8 @@ def train_and_eval(_shit=None,
eval_res = hook_results[ eval_res = hook_results[
1] # hook_results: [StopAtStepHook, EvalHook, ...] 1] # hook_results: [StopAtStepHook, EvalHook, ...]
eval_results[name] = eval_res eval_results[name] = eval_res
log_eval_result(name, eval_res, self.summary_writers[name], _log_eval_result(name, eval_res,
state) self.summary_writers[name], state)
for exporter in exporters: for exporter in exporters:
exporter.export(eval_executor, self.program, exporter.export(eval_executor, self.program,
self.model_spec, eval_results, state) self.model_spec, eval_results, state)
...@@ -505,6 +492,46 @@ def train_and_eval(_shit=None, ...@@ -505,6 +492,46 @@ def train_and_eval(_shit=None,
return eval_results return eval_results
if distribution.status.is_master: if distribution.status.is_master:
train_hooks.append(EvalHookOnTrainLoop()) train_hooks.append(_EvalHookOnTrainLoop())
res = est.train(train_dataset, train_hooks=train_hooks) res = est.train(train_dataset, train_hooks=train_hooks)
return res return res
def _build_model_fn(model_class):
    """Adapt a ``Model`` subclass into a ``model_fn`` callable.

    The returned callable instantiates ``model_class`` and assembles a
    ``ModelSpec`` appropriate for the requested run mode, so that a
    class-style model can be driven by the function-style training loop.

    Args:
        model_class: subclass implementing ``forward``/``loss``/``backward``
            and (optionally) ``metrics``.

    Returns:
        callable: ``_model_fn(features, mode, params, run_config)`` returning
        a ``ModelSpec``.
    """

    def _model_fn(features, mode, params, run_config):
        """Build the graph for one run mode and wrap outputs in a ModelSpec."""
        # Convention: in TRAIN/EVAL the last feature slot carries the label.
        if mode == RunMode.PREDICT:
            fea = features
        else:
            fea, label = features[:-1], features[-1]

        model = model_class(params, mode, run_config=run_config)
        pred = model.forward(fea)
        # Normalize to a list of prediction tensors.
        prediction = [pred] if isinstance(pred, F.framework.Variable) else pred

        if mode == RunMode.TRAIN:
            loss = model.loss(pred, label)
            model.backward(loss)
            return ModelSpec(loss=loss, predictions=prediction, mode=mode)
        if mode == RunMode.EVAL:
            loss = model.loss(pred, label)
            me = model.metrics(pred, label)
            inf_spec = InferenceSpec(inputs=fea, outputs=prediction)
            # Always report a loss metric, even if the model defines none.
            if 'loss' not in me:
                me['loss'] = metrics.Mean(loss)
            return ModelSpec(
                loss=loss,
                predictions=prediction,
                metrics=me,
                mode=mode,
                inference_spec=inf_spec)
        if mode == RunMode.PREDICT:
            inf_spec = InferenceSpec(inputs=fea, outputs=prediction)
            return ModelSpec(
                predictions=prediction, mode=mode, inference_spec=inf_spec)
        raise RuntimeError('unknown run mode %s' % mode)

    return _model_fn
...@@ -11,3 +11,4 @@ ...@@ -11,3 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""server"""
...@@ -11,6 +11,9 @@ ...@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Never Never Never import paddle.fluid in main process, or any module would import fluid.
"""
from __future__ import division from __future__ import division
from __future__ import absolute_import from __future__ import absolute_import
...@@ -24,27 +27,27 @@ from time import sleep, time ...@@ -24,27 +27,27 @@ from time import sleep, time
import multiprocessing import multiprocessing
import zmq import zmq
""" Never Never Never import paddle.fluid in main process, or any module would import fluid.
"""
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def profile(msg): def _profile(msg):
def decfn(fn): def _decfn(fn):
def retfn(*args, **kwargs): def _retfn(*args, **kwargs):
start = time() start = time()
ret = fn(*args, **kwargs) ret = fn(*args, **kwargs)
end = time() end = time()
log.debug('%s timecost: %.5f' % (msg, end - start)) log.debug('%s timecost: %.5f' % (msg, end - start))
return ret return ret
return retfn return _retfn
return decfn return _decfn
class Predictor(object): class Predictor(object):
"""paddle predictor wrapper"""
def __init__(self, model_dir, device_idx=0): def __init__(self, model_dir, device_idx=0):
import paddle.fluid as F import paddle.fluid as F
log.debug('create predictor on card %d' % device_idx) log.debug('create predictor on card %d' % device_idx)
...@@ -52,7 +55,7 @@ class Predictor(object): ...@@ -52,7 +55,7 @@ class Predictor(object):
config.enable_use_gpu(5000, device_idx) config.enable_use_gpu(5000, device_idx)
self._predictor = F.core.create_paddle_predictor(config) self._predictor = F.core.create_paddle_predictor(config)
@profile('paddle') @_profile('paddle')
def __call__(self, args): def __call__(self, args):
for i, a in enumerate(args): for i, a in enumerate(args):
a.name = 'placeholder_%d' % i a.name = 'placeholder_%d' % i
...@@ -61,6 +64,7 @@ class Predictor(object): ...@@ -61,6 +64,7 @@ class Predictor(object):
def run_worker(model_dir, device_idx, endpoint="ipc://worker.ipc"): def run_worker(model_dir, device_idx, endpoint="ipc://worker.ipc"):
"""worker process entrence"""
try: try:
log.debug("run_worker %s" % device_idx) log.debug("run_worker %s" % device_idx)
os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv( os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv(
...@@ -97,6 +101,8 @@ def run_worker(model_dir, device_idx, endpoint="ipc://worker.ipc"): ...@@ -97,6 +101,8 @@ def run_worker(model_dir, device_idx, endpoint="ipc://worker.ipc"):
class InferencePredictor(object): class InferencePredictor(object):
"""control Predictor for multi gpu card"""
def __init__(self, backend_addr, model_dir, n_devices=1): def __init__(self, backend_addr, model_dir, n_devices=1):
self.backend_addr = backend_addr self.backend_addr = backend_addr
self.model_dir = model_dir self.model_dir = model_dir
...@@ -104,6 +110,7 @@ class InferencePredictor(object): ...@@ -104,6 +110,7 @@ class InferencePredictor(object):
self.children = [] self.children = []
def start(self): def start(self):
"""doc"""
for device_idx in range(self.n_devices): for device_idx in range(self.n_devices):
p = multiprocessing.Process( p = multiprocessing.Process(
target=run_worker, target=run_worker,
...@@ -113,21 +120,27 @@ class InferencePredictor(object): ...@@ -113,21 +120,27 @@ class InferencePredictor(object):
return self return self
def join(self): def join(self):
"""doc"""
for p in self.children: for p in self.children:
p.join() p.join()
def term(self): def term(self):
"""doc"""
for p in self.children: for p in self.children:
log.debug("terminating children %s" % repr(p)) log.debug("terminating children %s" % repr(p))
p.terminate() p.terminate()
class InferenceProxy(object): class InferenceProxy(object):
"""zmq proxy"""
def __init__(self): def __init__(self):
"""doc"""
self.backend = None self.backend = None
self.frontend = None self.frontend = None
def listen(self, frontend_addr, backend_addr): def listen(self, frontend_addr, backend_addr):
"""doc"""
log.info("InferenceProxy starting...") log.info("InferenceProxy starting...")
try: try:
context = zmq.Context(1) context = zmq.Context(1)
...@@ -152,11 +165,15 @@ class InferenceProxy(object): ...@@ -152,11 +165,15 @@ class InferenceProxy(object):
class InferenceServer(object): class InferenceServer(object):
"""start InferencePredictor and InferenceProxy"""
def __init__(self, model_dir, n_devices): def __init__(self, model_dir, n_devices):
"""doc"""
self.model_dir = model_dir self.model_dir = model_dir
self.n_devices = n_devices self.n_devices = n_devices
def listen(self, port): def listen(self, port):
"""doc"""
frontend_addr = "tcp://*:%s" % port frontend_addr = "tcp://*:%s" % port
backend_addr = "ipc://backend.ipc" backend_addr = "ipc://backend.ipc"
predictor = InferencePredictor(backend_addr, self.model_dir, predictor = InferencePredictor(backend_addr, self.model_dir,
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""utils for server"""
from __future__ import division from __future__ import division
from __future__ import absolute_import from __future__ import absolute_import
...@@ -26,6 +27,7 @@ from propeller.service import interface_pb2 ...@@ -26,6 +27,7 @@ from propeller.service import interface_pb2
def slot_to_numpy(slot): def slot_to_numpy(slot):
"""doc"""
if slot.type == interface_pb2.Slot.FP32: if slot.type == interface_pb2.Slot.FP32:
dtype = np.float32 dtype = np.float32
type_str = 'f' type_str = 'f'
...@@ -45,6 +47,7 @@ def slot_to_numpy(slot): ...@@ -45,6 +47,7 @@ def slot_to_numpy(slot):
def numpy_to_slot(arr): def numpy_to_slot(arr):
"""doc"""
if arr.dtype == np.float32: if arr.dtype == np.float32:
dtype = interface_pb2.Slot.FP32 dtype = interface_pb2.Slot.FP32
elif arr.dtype == np.int32: elif arr.dtype == np.int32:
...@@ -59,25 +62,30 @@ def numpy_to_slot(arr): ...@@ -59,25 +62,30 @@ def numpy_to_slot(arr):
def slot_to_paddlearray(slot): def slot_to_paddlearray(slot):
"""doc"""
import paddle.fluid.core as core import paddle.fluid.core as core
if slot.type == interface_pb2.Slot.FP32: if slot.type == interface_pb2.Slot.FP32:
dtype = np.float32
type_str = 'f' type_str = 'f'
dtype = core.PaddleDType.FLOAT32
elif slot.type == interface_pb2.Slot.INT32: elif slot.type == interface_pb2.Slot.INT32:
dtype = np.int32
type_str = 'i' type_str = 'i'
dtype = core.PaddleDType.INT32
elif slot.type == interface_pb2.Slot.INT64: elif slot.type == interface_pb2.Slot.INT64:
dtype = np.int64
type_str = 'q' type_str = 'q'
dtype = core.PaddleDType.INT64
else: else:
raise RuntimeError('know type %s' % slot.type) raise RuntimeError('know type %s' % slot.type)
ret = core.PaddleTensor()
ret.shape = slot.dims
ret.dtype = dtype
num = len(slot.data) // struct.calcsize(type_str) num = len(slot.data) // struct.calcsize(type_str)
arr = struct.unpack('%d%s' % (num, type_str), slot.data) arr = struct.unpack('%d%s' % (num, type_str), slot.data)
ret = core.PaddleTensor(np.array(arr, dtype=dtype).reshape(slot.dims)) ret.data = core.PaddleBuf(arr)
return ret return ret
def paddlearray_to_slot(arr): def paddlearray_to_slot(arr):
"""doc"""
import paddle.fluid.core as core import paddle.fluid.core as core
if arr.dtype == core.PaddleDType.FLOAT32: if arr.dtype == core.PaddleDType.FLOAT32:
dtype = interface_pb2.Slot.FP32 dtype = interface_pb2.Slot.FP32
...@@ -99,12 +107,14 @@ def paddlearray_to_slot(arr): ...@@ -99,12 +107,14 @@ def paddlearray_to_slot(arr):
def nparray_list_serialize(arr_list): def nparray_list_serialize(arr_list):
"""doc"""
slot_list = [numpy_to_slot(arr) for arr in arr_list] slot_list = [numpy_to_slot(arr) for arr in arr_list]
slots = interface_pb2.Slots(slots=slot_list) slots = interface_pb2.Slots(slots=slot_list)
return slots.SerializeToString() return slots.SerializeToString()
def nparray_list_deserialize(string): def nparray_list_deserialize(string):
"""doc"""
slots = interface_pb2.Slots() slots = interface_pb2.Slots()
slots.ParseFromString(string) slots.ParseFromString(string)
return [slot_to_numpy(slot) for slot in slots.slots] return [slot_to_numpy(slot) for slot in slots.slots]
...@@ -72,6 +72,10 @@ def parse(filename): ...@@ -72,6 +72,10 @@ def parse(filename):
elif proto.data_type == framework_pb2.VarType.INT8: elif proto.data_type == framework_pb2.VarType.INT8:
arr = np.array( arr = np.array(
gen_arr(f.read(), 'B'), dtype=np.int8).reshape(proto.dims) gen_arr(f.read(), 'B'), dtype=np.int8).reshape(proto.dims)
elif proto.data_type == framework_pb2.VarType.FP16:
arr = np.array(
gen_arr(f.read(), 'H'),
dtype=np.uint16).view(np.float16).reshape(proto.dims)
else: else:
raise RuntimeError('Unknown dtype %s' % proto.data_type) raise RuntimeError('Unknown dtype %s' % proto.data_type)
......
...@@ -11,3 +11,6 @@ ...@@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
doc
"""
...@@ -11,6 +11,10 @@ ...@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Model template
"""
from __future__ import print_function from __future__ import print_function
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
...@@ -26,7 +30,11 @@ import numpy as np ...@@ -26,7 +30,11 @@ import numpy as np
@six.add_metaclass(abc.ABCMeta) @six.add_metaclass(abc.ABCMeta)
class Model(): class Model(object):
"""
template
"""
def __init__(self, config, mode): def __init__(self, config, mode):
""" """
Args: Args:
...@@ -39,9 +47,9 @@ class Model(): ...@@ -39,9 +47,9 @@ class Model():
def forward(self, features): def forward(self, features):
""" """
Args: Args:
features (list of Tensor): depends on your Dataset.output_shapes features (list of Tensor): inputs features that depends on your Dataset.output_shapes
Returns: Returns:
return (Tensor): return (Tensor): prediction
""" """
pass pass
...@@ -53,8 +61,6 @@ class Model(): ...@@ -53,8 +61,6 @@ class Model():
label (Tensor): depends on your Dataset.output_shapes label (Tensor): depends on your Dataset.output_shapes
Returns: Returns:
return (paddle scalar): loss return (paddle scalar): loss
""" """
pass pass
......
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Basic types"""
from __future__ import print_function from __future__ import print_function
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
...@@ -21,12 +23,15 @@ from collections import namedtuple ...@@ -21,12 +23,15 @@ from collections import namedtuple
class RunMode(object): class RunMode(object):
"""model_fn will be called in 3 modes"""
TRAIN = 1 TRAIN = 1
PREDICT = 2 PREDICT = 2
EVAL = 3 EVAL = 3
class HParams(object): class HParams(object):
"""Hyper paramerter"""
def __init__(self, **kwargs): def __init__(self, **kwargs):
for k, v in kwargs.items(): for k, v in kwargs.items():
self.__dict__[k] = v self.__dict__[k] = v
...@@ -45,30 +50,36 @@ class HParams(object): ...@@ -45,30 +50,36 @@ class HParams(object):
def __setitem__(self, key, val): def __setitem__(self, key, val):
self.__dict__[key] = val self.__dict__[key] = val
@staticmethod @classmethod
def from_json(self, json_str): def from_json(cls, json_str):
"""doc"""
d = json.loads(json_str) d = json.loads(json_str)
if type(d) != dict: if type(d) != dict:
raise ValueError('json object must be dict.') raise ValueError('json object must be dict.')
return HParams.from_dict(d) return HParams.from_dict(d)
def get(self, key, default=None): def get(self, key, default=None):
"""doc"""
return self.__dict__.get(key, default) return self.__dict__.get(key, default)
@staticmethod @classmethod
def from_dict(self, d): def from_dict(cls, d):
"""doc"""
if type(d) != dict: if type(d) != dict:
raise ValueError('input must be dict.') raise ValueError('input must be dict.')
hp = HParams(**d) hp = HParams(**d)
return hp return hp
def to_json(self): def to_json(self):
"""doc"""
return json.dumps(self.__dict__) return json.dumps(self.__dict__)
def to_dict(self): def to_dict(self):
"""doc"""
return self.__dict__ return self.__dict__
def join(self, other): def join(self, other):
"""doc"""
if not isinstance(other, HParams): if not isinstance(other, HParams):
raise ValueError('input must be HParams instance.') raise ValueError('input must be HParams instance.')
self.__dict__.update(**other.__dict__) self.__dict__.update(**other.__dict__)
...@@ -95,9 +106,12 @@ ModelSpec = namedtuple('ModelSpec', [ ...@@ -95,9 +106,12 @@ ModelSpec = namedtuple('ModelSpec', [
'metrics', 'metrics',
'mode', 'mode',
'inference_spec', 'inference_spec',
'train_hooks',
'eval_hooks',
]) ])
ModelSpec.__new__.__defaults__ = (None, ) * len(ModelSpec._fields) ModelSpec.__new__.__defaults__ = (None, ) * len(ModelSpec._fields)
class StopException(Exception): class StopException(Exception):
"""doc"""
pass pass
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""global utils"""
from __future__ import print_function from __future__ import print_function
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
...@@ -31,6 +32,7 @@ log = logging.getLogger(__name__) ...@@ -31,6 +32,7 @@ log = logging.getLogger(__name__)
def ArgumentParser(name): def ArgumentParser(name):
"""predefined argparser"""
parser = argparse.ArgumentParser('propeller model') parser = argparse.ArgumentParser('propeller model')
parser.add_argument('--run_config', type=str, default='') parser.add_argument('--run_config', type=str, default='')
parser.add_argument( parser.add_argument(
...@@ -59,6 +61,7 @@ def _get_dict_from_environ_or_json_or_file(args, env_name): ...@@ -59,6 +61,7 @@ def _get_dict_from_environ_or_json_or_file(args, env_name):
def parse_file(filename): def parse_file(filename):
"""useless api"""
d = _get_dict_from_environ_or_json_or_file(filename, None) d = _get_dict_from_environ_or_json_or_file(filename, None)
if d is None: if d is None:
raise ValueError('file(%s) not found' % filename) raise ValueError('file(%s) not found' % filename)
...@@ -66,6 +69,7 @@ def parse_file(filename): ...@@ -66,6 +69,7 @@ def parse_file(filename):
def parse_runconfig(args=None): def parse_runconfig(args=None):
"""get run_config from env or file"""
d = _get_dict_from_environ_or_json_or_file(args.run_config, d = _get_dict_from_environ_or_json_or_file(args.run_config,
'PROPELLER_RUNCONFIG') 'PROPELLER_RUNCONFIG')
if d is None: if d is None:
...@@ -74,6 +78,7 @@ def parse_runconfig(args=None): ...@@ -74,6 +78,7 @@ def parse_runconfig(args=None):
def parse_hparam(args=None): def parse_hparam(args=None):
"""get hparam from env or file"""
if args is not None: if args is not None:
hparam_strs = reduce(list.__add__, args.hparam) hparam_strs = reduce(list.__add__, args.hparam)
else: else:
...@@ -91,6 +96,7 @@ def parse_hparam(args=None): ...@@ -91,6 +96,7 @@ def parse_hparam(args=None):
def flatten(s): def flatten(s):
"""doc"""
assert is_struture(s) assert is_struture(s)
schema = [len(ss) for ss in s] schema = [len(ss) for ss in s]
flt = list(itertools.chain(*s)) flt = list(itertools.chain(*s))
...@@ -98,6 +104,7 @@ def flatten(s): ...@@ -98,6 +104,7 @@ def flatten(s):
def unflatten(structure, schema): def unflatten(structure, schema):
"""doc"""
start = 0 start = 0
res = [] res = []
for _range in schema: for _range in schema:
...@@ -107,10 +114,12 @@ def unflatten(structure, schema): ...@@ -107,10 +114,12 @@ def unflatten(structure, schema):
def is_struture(s): def is_struture(s):
"""doc"""
return isinstance(s, list) or isinstance(s, tuple) return isinstance(s, list) or isinstance(s, tuple)
def map_structure(func, s): def map_structure(func, s):
"""same sa tf.map_structure"""
if isinstance(s, list) or isinstance(s, tuple): if isinstance(s, list) or isinstance(s, tuple):
return [map_structure(func, ss) for ss in s] return [map_structure(func, ss) for ss in s]
elif isinstance(s, dict): elif isinstance(s, dict):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册