PyDataProvider.h 4.2 KB
Newer Older
Z
zhangjinchao01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */


#pragma once

#include <paddle/utils/PythonUtil.h>
#include "DataFormat.pb.h"
#include "DataProvider.h"

#include <map>
#include <string>
#include <vector>

namespace paddle {

// DataProvider subclass that keeps a Python object handle (classInstance_)
// together with the Python module name, class name and user arguments it
// was configured with. Only declarations (plus two trivial inline bodies)
// live in this header; everything else is defined in the implementation file.
class PyDataProvider : public DataProvider {
public:
  PyDataProvider(const DataConfig& config,
                 bool useGpu,
                 bool loadDataAll = true);

  virtual void reset();

  // Size query is not implemented for this provider: the call reports
  // through LOG(FATAL), and the statement after it returns -1.
  // (Intended contract, per the original note: the size includes the
  // sequences which are skipped because they are longer than the batch size.)
  virtual int64_t getSize() {
    LOG(FATAL) << "Not implement yet";
    return -1;
  }

  virtual void shuffle();

  virtual int64_t getNextBatchInternal(int64_t size, DataBatch* batch);

protected:
  // Per-slot storage. Which of the data vectors is meaningful depends on the
  // slot type recorded in `type`.
  struct ProtoSlot {
    SlotDef::SlotType type;
    int dim;
    unsigned int sampleNum;
    unsigned int sequenceNum;
    unsigned int subSequenceNum;

    // Data of an index type slot.
    std::vector<int> indexData;
    // Data of a dense type slot.
    std::vector<real> denseData;
    // Data of a sparseNonValue type slot.
    std::vector<sparse_non_value_t> sparseNonValueData;
    // Data of a sparseValue type slot.
    std::vector<sparse_float_value_t> sparseFloatValueData;
    // Index of each sample inside the slot values.
    std::vector<int64_t> indices;

    // Start position of every sequence, counted in samples.
    // The last element should be the number of samples.
    // Empty means each sample is one sequence.
    std::vector<size_t> sequenceStartPositions;
    // Index id of the sequences in the slot.
    std::vector<int64_t> sampleSequenceIdVec;
    // Start position of every subsequence, counted in samples.
    // The last element should be the number of subsequences.
    // Empty means the sequences have no subsequences.
    std::vector<size_t> subSequenceStartPositions;

    // Data of a string type slot.
    std::vector<std::string> strData;
  };

  // Accessor for isIID_: true when sequences are not used, false when they
  // are (see the comment on isIID_ below).
  bool iidData() const { return isIID_; }

  void parseHeaderData(const std::string& headerData);

  // One fill helper per slot type. Each receives the slot, a char pointer
  // passed by reference, and a `dataEnd` pointer (presumably the end of the
  // readable buffer -- confirm against the implementation).
  void fillDenseSlot(ProtoSlot& slot, char*& data, const char* dataEnd);
  void fillSparseNonValueSlot(ProtoSlot& slot,
                              char*& data,
                              const char* dataEnd);
  void fillSparseValueSlot(ProtoSlot& slot, char*& data, const char* dataEnd);
  void fillIndexSlot(ProtoSlot& slot, char*& data, const char* dataEnd);
  void fillStringSlot(ProtoSlot& slot, char*& data, const char* dataEnd);
  void fillSlotsByStr(const std::string& samples);

  // One handle helper per slot type. Each receives the slot, its index and
  // the vector of CPU arguments by non-const reference.
  void handleDenseSlot(ProtoSlot& slot,
                       size_t slotIndex,
                       std::vector<Argument>& cpuArguments);
  void handleSparseNonValueSlot(ProtoSlot& slot,
                                size_t slotIndex,
                                std::vector<Argument>& cpuArguments);
  void handleSparseValueSlot(ProtoSlot& slot,
                             size_t slotIndex,
                             std::vector<Argument>& cpuArguments);
  void handleIndexSlot(ProtoSlot& slot,
                       size_t slotIndex,
                       std::vector<Argument>& cpuArguments);
  void handleStringSlot(ProtoSlot& slot,
                        size_t slotIndex,
                        std::vector<Argument>& cpuArguments);

  void resetSlots();
  void loadData(const std::vector<std::string>& fileList);

  // ---- state ----
  std::vector<ProtoSlot> slots_;

  PyObjectPtr classInstance_;
  unsigned int batchSize_;
  unsigned int slotNum_;
  // False when sequences are used, true otherwise.
  bool isIID_;
  // Name of the python module.
  std::string pyModuleName_;
  // Name of the python class.
  std::string pyClassName_;
  // User args set in config.
  std::map<std::string, std::string> pyUserArgs_;

  ThreadLocalD<DataBatch> cpuBatch_;
  ThreadLocalD<DataBatch> gpuBatch_;
};

}  // namespace paddle