# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module provides a wrapper (decorator) that turns a data-processing
function into a PyDataProvider class.

Usage examples are shown in the PaddlePaddle PyDataProvider documentation.
"""

import struct
import array
import random
import gc
import logging
import pstats
import sys
import numpy
import functools

__all__ = [
    'DenseSlot', 'SlotType', 'SparseNonValueSlot', 'StringSlot',
    'SparseValueSlot', 'IndexSlot', 'PoolSize', 'GeneralPyDataProvider',
    'provider', 'init_hook_wrapper'
]

try:
    # Only needed in profile mode; prefer cProfile when it is available.
    # Most Python builds contain cProfile; cProfile and profile expose the
    # same interface.
    # ref: https://docs.python.org/2/library/profile.html#introduction-to-the-profilers
    import cProfile as profile
except ImportError:
    import profile

try:
    import cPickle as pickle
except ImportError:
    import pickle

import io

# ---------------------------------------------------------------------------
# Python 2 / Python 3 compatibility helpers.
# ---------------------------------------------------------------------------
try:
    _range = xrange  # Python 2: lazy range, avoids building a list per call.
except NameError:  # Python 3
    _range = range

try:
    _string_types = basestring  # Python 2
except NameError:  # Python 3
    _string_types = (str, bytes)


def _to_bytes(typecode, values):
    """
    Pack ``values`` into a C array of ``typecode`` and return its raw bytes.

    ``array.tostring`` was removed in Python 3.9 and ``array.tobytes`` does
    not exist in Python 2, so the modern name is tried first.
    """
    arr = array.array(typecode, values)
    try:
        return arr.tobytes()
    except AttributeError:  # Python 2
        return arr.tostring()


class SlotType(object):  # Just a hint for user.
    pass


class DenseSlot(SlotType):
    """
    Dense Slot Type: Each item is the value of a Dense Vector.

    Its yield format for :code:`provider` is:

    - **NonSeq**: [float, float, ... ]
    - **Seq**: [[float, float, ...], [float, float ....], ... ]
    - **SubSeq**: [[[float, float, ...], [float ....], ...] ,
      [[float, float, ...], [float ....], ...] , ...]
    """

    def __init__(self, dim):
        """
        :param dim: slot dimension
        :type dim: int
        """
        self.dim = dim
        self.type = 0


class SparseNonValueSlot(SlotType):
    """
    Sparse NonValue Slot Type: Each item is the id of a Sparse Vector.

    Its yield format for :code:`provider` is:

    - **NonSeq**: [int, int, ...]
    - **Seq**: [[int, int, ...], [int, int, ...], ... ]
    - **SubSeq**: [[[int, int, ...], [int, ....], ...] ,
      [[int, int, ...], [int, ....], ...] , ...]
    """

    def __init__(self, dim):
        """
        :param dim: slot dimension
        :type dim: int
        """
        self.dim = dim
        self.type = 1


class SparseValueSlot(SlotType):
    """
    Sparse Value Slot Type: Each item is the id and value of a Sparse Vector.

    Its yield format for :code:`provider` is:

    - **NonSeq**: [(int, float), (int, float), ... ]
    - **Seq**: [[(int,float), (int, float), ... ],
      [(int, float), (int, float), ...], ... ]
    - **SubSeq**: [[[(int,float), ...], [(int, float), ....], ...] ,
      [[(int,float), ...], [(int, float), ....], ...] , ...]
    """

    def __init__(self, dim):
        """
        :param dim: slot dimension.
        :type dim: int
        """
        self.dim = dim
        self.type = 2


class IndexSlot(SlotType):
    """
    Index Value Slot Type: Each item is the id of Label.

    Its yield format for :code:`provider` is:

    - **NonSeq**: int
    - **Seq**: [int, int, ....]
    - **SubSeq**: [[int, int, ...], [int, int, ...], ... ]
    """

    def __init__(self, dim):
        """
        :param dim: slot dimension
        :type dim: int
        """
        self.dim = dim
        self.type = 3


class StringSlot(SlotType):
    """
    String Value Slot Type: Each item is a string for printout,
    can be used in DataLayer too.

    Its yield format for :code:`provider` is:

    - **NonSeq**: string
    - **Seq**: [string, string, ....]
    - **SubSeq**: [[string, string, ...], [string, string, ...], ... ]
    """

    def __init__(self, dim):
        """
        :param dim: slot dimension (packed into the header as an int).
        :type dim: int
        """
        self.dim = dim
        self.type = 6


class SparseNonValueHandler(object):
    """
    Private Class, Use for converting python object to paddle string.
    """

    def __init__(self):
        self.offsets = []
        self.value = []
        self.offset_count = 0

    def __call__(self, ele):
        """
        Invoked once per sparse data item.

        Records the start offset of this item, then collects its elements.

        :param ele: list of sparse data, maybe non-value [ idx, ... ] or
                    value [ (idx, val), ... ]
        :type ele: list
        """
        self.offsets.append(self.offset_count)
        self.offset_count += len(ele)
        self.processElement(ele)

    def processElement(self, ele):
        """
        Process for element list. See __call__ for more document.
        """
        self.value += ele

    def done(self, data_stream, int_packer):
        """
        Dump data to stream: the int offsets, the total element count and
        the int ids.

        :param data_stream: Output Stream.
        :param int_packer: A struct.Struct("i") object
        """
        data_stream.write(_to_bytes("i", self.offsets))
        data_stream.write(int_packer.pack(self.offset_count))
        data_stream.write(_to_bytes("i", self.value))


class SparseValueHandler(SparseNonValueHandler):
    """
    Private class, use for converting python obj to paddle string.
    """

    def __init__(self):
        SparseNonValueHandler.__init__(self)
        self.weight = []

    def processElement(self, ele):
        """Split (idx, weight) pairs into the id list and weight list."""
        for idx, w in ele:
            self.value.append(idx)
            self.weight.append(w)

    def done(self, data_stream, int_packer):
        """
        Dump the non-value part, then the element count again followed by
        the float weights.
        """
        SparseNonValueHandler.done(self, data_stream, int_packer)
        data_stream.write(int_packer.pack(self.offset_count))
        data_stream.write(_to_bytes("f", self.weight))


class StringHandler(object):
    """
    Private Class, Use for converting python object to paddle string.
    """

    def __init__(self, data_stream, int_packer):
        self.data_stream = data_stream
        self.int_packer = int_packer

    def __call__(self, ele):
        """
        Invoked once per string item: writes its byte length, then its bytes.

        Text (unicode) strings are UTF-8 encoded first; byte strings
        (``str`` on Python 2) are written unchanged.

        :param ele: string data
        :type ele: str
        """
        if not isinstance(ele, bytes):
            ele = ele.encode("utf-8")
        self.data_stream.write(self.int_packer.pack(len(ele)))
        self.data_stream.write(ele)


class GeneralPyDataProvider:
    def __init__(self, *file_list, **kwargs):
        """
        :param file_list: input file_list; each entry is passed to
                          :code:`generateData`.
        :param kwargs: unused by this base class.

        The subclass must set :code:`self.logger` (and, for batching,
        :code:`self.max_pool_size` / :code:`self.should_shuffle`) before
        calling this constructor.
        """
        del kwargs  # unused
        # NOTE(review): disables Python's cyclic GC process-wide, presumably
        # to speed up building large sample pools -- confirm intent.
        gc.disable()
        assert isinstance(self.logger, logging.Logger)
        self.use_seq_flag = hasattr(self, "use_seq_flag") and self.use_seq_flag
        self.slots_num = len(self.getSlots())
        self.file_list = list(file_list)
        self.generators = self._createGenerators()
        self.int_packer = struct.Struct("i")
        self.head_packer = struct.Struct("ii")
        self.float_packer = struct.Struct("f")
        self.shuffler = lambda *args, **kwargs: None
        self.data_pool = []
        self.has_subseq = []
        self.has_checked = False

        self.debug = hasattr(self, "debug") and self.debug

        if hasattr(self, "profile_filename") and isinstance(
                self.profile_filename, str):
            self.profile_count = 0
            self.is_profile = True
        else:
            self.is_profile = False

        if not hasattr(self, "file_count") or not isinstance(self.file_count,
                                                             int):
            # No limit: read every file when the pool is first filled.
            self.file_count = sys.maxsize

        if not hasattr(self, "can_over_batch_size"):
            self.can_over_batch_size = True
        elif not self.can_over_batch_size:
            self.logger.warning(
                "User should ensure every data size is not larger than batch"
                " size when can_over_batch_size = False")

        self.data_pool_idx = 0

    def _createGenerators(self):
        """Return a fresh list with one data generator per input file."""
        return [self.generateData(fn) for fn in self.file_list]

    def _unifySlotFormat(self, ret_list):
        """
        Return the samples in ``ret_list`` as a list of per-sample slot lists.

        With several slots every sample already is a list of slot data. With
        a single slot the sample *is* the slot data, so each sample is
        wrapped individually.
        """
        if self.slots_num == 1:
            return [[sample] for sample in ret_list]
        return ret_list

    def reset(self):
        """Reset all data in provider."""

        self.logger.debug("reset dataprovider.")
        self.generators = self._createGenerators()
        self.shuffler = lambda *args, **kwargs: None
        self.data_pool = []
        self.data_pool_idx = 0
        if self.file_count != 0:
            # File-count mode: the pool capacity is re-derived from the first
            # file_count files on the next fillPool().
            self.max_pool_size = 0

        # When profiling, each pass dumps the previous pass's profile result.
        if self.is_profile:
            if hasattr(self, "profiler") and isinstance(self.profiler,
                                                        profile.Profile):
                self.profiler.disable()
                fn = "%s_%d" % (self.profile_filename, self.profile_count)
                sortby = "cumulative"
                with open(fn, "w") as f:
                    pstats.Stats(
                        self.profiler,
                        stream=f).sort_stats(sortby).print_stats()
                self.logger.info("saving profile to file %s", fn)
                self.profile_count += 1
            self.logger.info("resetting profile")
            self.profiler = profile.Profile()
            self.profiler.enable()

    def shuffle(self):
        """ shuffle data"""
        if not self.should_shuffle:
            return
        else:
            self.logger.debug("shuffling data.")
            random.shuffle(self.generators)
            self.shuffler = random.shuffle

    def getSlots(self):
        """
        :return : return a list of SlotType
        :rtype: list
        """
        return []

    def generateData(self, fn):
        """
        :param fn: file name
        :return: a generator to yield data one by one.
        """
        raise NotImplementedError

    def calculateDataBatchSize(self, data):
        """
        :param data: One sample which yield by generateData
        :type data: list
        :return: The batch size that the data contribute.
        :rtype: int
        """
        return 1

    def getHeader(self):
        """
        Return the paddle header: (slots_num, use_seq_flag) followed by one
        (type, dim) int pair per slot.
        """
        ret = self.head_packer.pack(self.slots_num, self.use_seq_flag)
        for obj in self.getSlots():
            ret += self.head_packer.pack(obj.type, obj.dim)
        return ret

    def getHeaderNative(self):
        """Return (use_seq_flag, slots) without binary packing."""
        return self.use_seq_flag, self.getSlots()

    def getNextBatchNative(self, batch_size):
        """Return the raw list of samples for the next batch."""
        ret_list = []
        self.__prepareData(batch_size, ret_list)
        return ret_list

    def getNextBatch(self, batch_size):
        """
        :param batch_size: the batch_size approximately return.
        :return: return paddle pyDataProvider format, just see documents.
        :rtype: str

        The result is the int batch size followed by the data section of
        every slot, then all sequence-offset sections, then all
        sub-sequence-offset sections. An empty batch is just the int 0.

        NOTE: If can_over_batch_size is True, the return batch_size >= input
        batch_size. Otherwise, the return batch_size < input batch_size, BUT
        USER MUST ENSURE THAT each data's batch size is less than input
        batch_size.
        """
        ret_list = []
        current_batch_size = self.__prepareData(batch_size, ret_list)
        # create unified format for ret_list with different slots_num
        ret_list = self._unifySlotFormat(ret_list)

        if current_batch_size == 0:
            return self.int_packer.pack(current_batch_size)
        data_stream = io.BytesIO()
        seq_stream = io.BytesIO()
        subseq_stream = io.BytesIO()

        def convertDataImpl(idx, data_callback):
            """
            This method will handle sequence in return data. invoke
            data_callback one by one.

            :param idx: the slot index.
            :param data_callback: a callback, which type is (each sample) =>
                                  None; None means "pack as raw ints".
            """
            indices = 0
            slot_sample_num = len(ret_list)
            if self.use_seq_flag:
                slot_sample_num = 0
                if self.has_subseq[idx]:  # has sub-sequence
                    slot_subseq_num = 0
                    for dat in ret_list:
                        dat = dat[idx]
                        slot_subseq_num += len(dat)
                        for sub_dat in dat:
                            slot_sample_num += len(sub_dat)
                    subseq_stream.write(self.int_packer.pack(slot_subseq_num))
                else:
                    for dat in ret_list:
                        dat = dat[idx]
                        slot_sample_num += len(dat)
                seq_stream.write(self.int_packer.pack(len(ret_list)))
            data_stream.write(self.int_packer.pack(slot_sample_num))

            for dat in ret_list:
                dat = dat[idx]
                if self.use_seq_flag:
                    # Start offset of this sequence within the slot's data.
                    seq_stream.write(self.int_packer.pack(indices))
                    if self.has_subseq[idx]:  # has sub-sequence
                        for sub_dat in dat:
                            writeDataStream(sub_dat, data_callback)
                            subseq_stream.write(self.int_packer.pack(indices))
                            indices += len(sub_dat)
                    else:
                        writeDataStream(dat, data_callback)
                        indices += len(dat)
                else:
                    writeDataStream(dat, data_callback)

        def writeDataStream(dat, data_callback):
            """Write one sample (or one sequence) of a slot."""
            if self.use_seq_flag > 0:
                if data_callback is None:  # Special for index slot
                    data_stream.write(_to_bytes("i", dat))
                else:
                    for ele in dat:
                        data_callback(ele)
            else:
                if data_callback is None:  # Special for index slot
                    data_stream.write(self.int_packer.pack(dat))
                else:
                    data_callback(dat)

        try:
            slots = self.getSlots()
            for i in range(self.slots_num):
                slot = slots[i]
                # According to the data_type, each slot data will be converted
                # to binary
                if isinstance(slot, DenseSlot):
                    convertDataImpl(i, lambda e: data_stream.write(
                        _to_bytes("f", e)))
                elif isinstance(slot, SparseNonValueSlot):
                    handler = SparseNonValueHandler()
                    convertDataImpl(i, handler)
                    handler.done(data_stream, self.int_packer)
                elif isinstance(slot, SparseValueSlot):
                    handler = SparseValueHandler()
                    convertDataImpl(i, handler)
                    handler.done(data_stream, self.int_packer)
                elif isinstance(slot, IndexSlot):
                    convertDataImpl(i, None)
                elif isinstance(slot, StringSlot):
                    handler = StringHandler(data_stream, self.int_packer)
                    convertDataImpl(i, handler)
                else:
                    raise RuntimeError("The data_type must be 0/1/2/3/6")
            return b"".join([
                self.int_packer.pack(current_batch_size),
                data_stream.getvalue(), seq_stream.getvalue(),
                subseq_stream.getvalue()
            ])
        finally:
            data_stream.close()
            seq_stream.close()
            subseq_stream.close()

    def hasSubseq(self, ret_list):
        """
        Decide, per slot, whether it carries sub-sequences and append 1/0 to
        self.has_subseq. Only the first sample is inspected.

        :param ret_list: a non-empty list of samples.
        """
        # create unified format for ret_list with different slots_num
        ret_list = self._unifySlotFormat(ret_list)

        # decide whether slot has sub-sequence using its first sample
        slots = self.getSlots()
        for i in range(self.slots_num):
            slot = slots[i]
            dat = ret_list[0][i][0]
            if isinstance(slot, (IndexSlot, StringSlot)):
                # Elements are scalars, so a nested list means sub-sequence.
                if isinstance(dat, (list, numpy.ndarray)):
                    self.has_subseq.append(1)  # has_subseq = True
                    continue
            elif isinstance(dat[0], (list, numpy.ndarray)):
                # Elements are vectors, so one more nesting level is needed.
                self.has_subseq.append(1)  # has_subseq = True
                continue

            self.has_subseq.append(0)  # has_subseq = False

    def checkOrder(self):
        """
        Ensure every slot with sub-sequences precedes all slots without.

        :raises RuntimeError: if a sub-sequence slot follows a plain one.
        """
        first_noSubseq_slot = self.slots_num
        last_subseq_slot = -1
        for i in range(self.slots_num):
            if not self.has_subseq[i]:
                first_noSubseq_slot = i
                break
        for i in range(self.slots_num):
            if self.has_subseq[i]:
                last_subseq_slot = i
        if first_noSubseq_slot < last_subseq_slot:
            raise RuntimeError(
                "slot hasSubseq must put before than slot without subseq")
        self.has_checked = True

    def __prepareData(self, batch_size, ret_list):
        """
        Move samples from the pool into ``ret_list`` until ``batch_size`` is
        reached or the data is exhausted.

        :return: the accumulated batch size of the samples appended.
        """
        current_batch_size = 0
        could_exit = False
        while not could_exit:
            if len(self.data_pool) == 0:
                self.data_pool_idx = 0
                self.fillPool()
            if len(self.data_pool) != 0:
                for idx in _range(self.data_pool_idx, len(self.data_pool)):
                    current_batch_size += self.calculateDataBatchSize(
                        self.data_pool[idx])
                    if current_batch_size >= batch_size:
                        could_exit = True
                        break
                if current_batch_size > batch_size and not self.can_over_batch_size:  # if cannot over batch size
                    current_batch_size -= self.calculateDataBatchSize(
                        self.data_pool[idx])
                    idx -= 1

                ret_list += self.data_pool[self.data_pool_idx:idx + 1]

                # for speed reason, just shift left index, not delete data
                # actually.
                self.data_pool_idx = idx + 1

                if self.data_pool_idx == len(self.data_pool):
                    self.data_pool = []
            else:
                break
        # compute self.has_subseq and checkOrder only once, on the first
        # non-empty batch (an empty batch has no sample to inspect).
        if self.use_seq_flag and not self.has_checked and ret_list:
            self.hasSubseq(ret_list)
            self.checkOrder()
        return current_batch_size

    def fillPool(self):
        """
        Fill the pool.

        If max_pool_size is 0, read up to file_count whole files into the
        pool and use the resulting sample count as max_pool_size. Otherwise,
        stream samples from the remaining files until the pool holds
        max_pool_size samples or the files are exhausted.
        """
        if self.max_pool_size == 0:
            num_files = min(self.file_count, len(self.generators))
            for i in _range(num_files):
                self.data_pool += list(self.generators[i])
            self.generators = self.generators[num_files:]
            self.max_pool_size = len(self.data_pool)
        else:
            while len(self.data_pool) < self.max_pool_size and len(
                    self.generators) != 0:
                try:
                    self.data_pool.append(next(self.generators[0]))
                except StopIteration:
                    self.generators.pop(0)
        self.shuffler(self.data_pool)


class PoolSize(object):
    """Max number of sample which contains in provider."""

    def __init__(self, pool_size):
        self.size = pool_size


def default_init_hook(cls, *args, **kwargs):
    """ default hook, do nothing """
    del cls, args, kwargs


def provider(slots=None,
             use_seq=False,
             should_shuffle=True,
             pool_size=1,
             can_over_batch_size=True,
             calc_batch_size=lambda data: 1,
             debug=False,
             init_hook=default_init_hook,
             profile_filename=None):
    r"""
    The decorator for PyDataProvider. User should use this to create Provider
    class. User should only concern how to read sample from file.

    So the basic usage is:

    .. code-block:: python

        @provider(some data provider config here...)
        def process(obj, file_name):
            while not at end of file_name:
                sample = readOneSampleFromFile(file_name)
                yield sample.

    The configuration of data provider should be setup by:

    :param init_hook: A callback invoked when the PyDataProvider instance is
                      created. Its parameters are (obj, \*args, \*\*kwargs).

                      - **obj**: the data provider instance, which holds
                        global objects in obj.xxxxx and is passed to the
                        process function.

                        1. **obj.slots**: a list of SlotType Object. Can be
                           set in init. For example, obj.slots =
                           [DenseSlot(9), IndexSlot(2)].
                        2. **obj.logger**: a logger object. User can invoke
                           obj.logger.info(), obj.logger.fatal(), etc.

                      - **args** and **kwargs**: the data provider __init__
                        parameters. For example, load_data_args will be
                        found in \*\*kwargs; to receive it from
                        trainer_config, it is recommended to use
                        init_hook_wrapper.
    :type init_hook: callable

    :param pool_size:
                      - **int** (default 1): each pass starts by reading
                        pool_size whole files into memory; the resulting
                        sample count then bounds the pool for the rest of
                        the pass.
                      - **PoolSize**: it will read at most PoolSize.size
                        samples to memory.
    :type pool_size: int | PoolSize

    :param slots: Specify the SlotTypes, can also be set in init_hook. It has
                  two formats:

                  - A list of SlotType objects. For example, slots =
                    [DenseSlot(9), IndexSlot(2)].
                  - A method return a list of SlotTypes, and the parameter of
                    method is (obj, \*file_list, \*\*kwargs).
    :type slots: list | callable

    :param use_seq: False if use no sequence (Default). True if use sequence:

                    - If sequence has **no sub-sequence**: Each slot will
                      return a list of data. This list is one sequence.
                      So the return format likes
                      [[a0, a1, a2], [b1, b2, b3, b4], [c1]].
                    - If sequence has **sub-sequence**: Each slot will return
                      a nested-list of data. This list contains several
                      sub-lists, each sub-list is one sub-sequence.
                      So the return format likes
                      [[[a0, a1, a2], [a4, a5]], [[b1, b2, b3, b4], [b5, b6]],
                      [[c1], [c2]]].
    :type use_seq: bool

    :param should_shuffle: True if data should shuffle.
    :type should_shuffle: bool

    :param calc_batch_size: The method calculate each data's batch size.

                            - Default is the batch size of one sample.
                            - User can customize by **lambda** function. For
                              example, :code:`calc_batch_size = lambda data :
                              len(data)` means calculating the token number
                              of a sequence data.
    :type calc_batch_size: callable

    :param can_over_batch_size: Whether :code:`actual batch size >= input
                                batch size`

                                - **True** (>=): getNextBatch method can
                                  return more data (Default).
                                - **False** (<): user must ensure that each
                                  data's batch size < input batch size.
    :type can_over_batch_size: bool

    :param debug: True if enable debug logger and some debug check. Default
                  is False.
    :type debug: bool

    :param profile_filename: None if disable profile (Default). Otherwise,
                             the data provider will dump profile result when
                             reset. And the dump filename is
                             **<profile_filename>_<N>**, where N counts the
                             dumped profiles starting from 0.
    :type profile_filename: None | Str
    """

    def _wrapper(handler):
        class Cls(GeneralPyDataProvider):
            """ Real PyDataProvider Class. """

            def __init__(self, *file_list, **kwargs):
                logging.basicConfig(
                    format="[%(levelname)s %(asctime)s %(filename)s:%(lineno)s]"
                    " %(message)s")

                self.logger = logging.getLogger("")
                if debug:
                    self.logger.setLevel(logging.DEBUG)
                    self.logger.debug("Running pydataprovider in debug mode.")
                else:
                    self.logger.setLevel(logging.INFO)

                init_hook(self, *file_list, **kwargs)
                if callable(slots):
                    self.slots = slots(self, *file_list, **kwargs)
                elif slots is not None:
                    self.slots = slots

                if isinstance(pool_size, int):
                    # File-count mode: the capacity is derived in fillPool().
                    self.max_pool_size = 0
                    self.file_count = pool_size
                elif isinstance(pool_size, PoolSize):
                    self.max_pool_size = pool_size.size
                    self.file_count = 0
                else:
                    raise RuntimeError(
                        "pool_size must be an int or a PoolSize, got %r" %
                        (pool_size, ))
                self.can_over_batch_size = can_over_batch_size
                self.debug = debug
                self.profile_filename = profile_filename
                self.use_seq_flag = use_seq
                self.should_shuffle = should_shuffle
                GeneralPyDataProvider.__init__(self, *file_list, **kwargs)

            def getSlots(self):
                return self.slots

            def generateData(self, f):
                return handler(self, f)

            def calculateDataBatchSize(self, data):
                return calc_batch_size(data)

        return Cls

    return _wrapper


def init_hook_wrapper(func):
    """
    Wrap a method for PyDataProviderWrapper's init_hook. This method can
    receive parameter from trainer_config's load_data_args. The
    load_data_args must pass a pickle.dumps() value, and dump a map as keyword
    args. The wrapped method :code:`func` will receive them as keyword args.

    So an example usage is:

    .. code-block:: python

        @init_hook_wrapper
        def hook(obj, dictionary, file_list, **kwargs):
            obj.dictionary = dictionary
            obj.slots = [IndexSlot(len(obj.dictionary)),
                         IndexSlot(len(open(file_list[0], "r").readlines()))]

    A missing or empty load_data_args is treated as no extra arguments.

    :param func: init_hook function
    :type func: callable
    :return: wrapped method, can be passed into @provider.
    """

    @functools.wraps(func)
    def wrapper(obj, *file_list, **kwargs):
        args = kwargs.get("load_data_args") or dict()
        if isinstance(args, _string_types):
            # NOTE(security): pickle.loads can execute arbitrary code; the
            # load_data_args string must come from a trusted trainer config.
            args = pickle.loads(args)
        # Copy so a dict passed by the caller is never mutated.
        args = dict(args)
        args['file_list'] = file_list
        func(obj=obj, **args)

    return wrapper