/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <paddle/utils/PythonUtil.h>
#include "DataFormat.pb.h"
#include "DataProvider.h"

#include <map>
#include <string>
#include <vector>

namespace paddle {

class PyDataProvider : public DataProvider {
public:
27 28
  PyDataProvider(const DataConfig& config,
                 bool useGpu,
Z
zhangjinchao01 已提交
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
                 bool loadDataAll = true);

  virtual void reset();

  // Note this size includes the sequences which are skipped because they
  // are longer than the batch size
  virtual int64_t getSize() {
    LOG(FATAL) << "Not implement yet";
    return -1;
  }
  virtual void shuffle();

  virtual int64_t getNextBatchInternal(int64_t size, DataBatch* batch);

protected:
  struct ProtoSlot;
  // return false if each each sample is one sequence, i.e., independent
  // of other samples.
  inline bool iidData() const { return isIID_; }

  void parseHeaderData(const std::string& headerData);
  void fillDenseSlot(ProtoSlot& slot, char*& data, const char* dataEnd);
51 52
  void fillSparseNonValueSlot(ProtoSlot& slot,
                              char*& data,
Z
zhangjinchao01 已提交
53 54 55 56 57
                              const char* dataEnd);
  void fillSparseValueSlot(ProtoSlot& slot, char*& data, const char* dataEnd);
  void fillIndexSlot(ProtoSlot& slot, char*& data, const char* dataEnd);
  void fillStringSlot(ProtoSlot& slot, char*& data, const char* dataEnd);
  void fillSlotsByStr(const std::string& samples);
58 59
  void handleDenseSlot(ProtoSlot& slot,
                       size_t slotIndex,
Z
zhangjinchao01 已提交
60
                       std::vector<Argument>& cpuArguments);
61 62
  void handleSparseNonValueSlot(ProtoSlot& slot,
                                size_t slotIndex,
Z
zhangjinchao01 已提交
63
                                std::vector<Argument>& cpuArguments);
64 65
  void handleSparseValueSlot(ProtoSlot& slot,
                             size_t slotIndex,
Z
zhangjinchao01 已提交
66
                             std::vector<Argument>& cpuArguments);
67 68
  void handleIndexSlot(ProtoSlot& slot,
                       size_t slotIndex,
Z
zhangjinchao01 已提交
69
                       std::vector<Argument>& cpuArguments);
70 71
  void handleStringSlot(ProtoSlot& slot,
                        size_t slotIndex,
Z
zhangjinchao01 已提交
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
                        std::vector<Argument>& cpuArguments);
  void resetSlots();
  void loadData(const std::vector<std::string>& fileList);

protected:
  struct ProtoSlot {
    SlotDef::SlotType type;
    int dim;
    unsigned int sampleNum;
    unsigned int sequenceNum;
    unsigned int subSequenceNum;
    // Store the data of index type slot
    std::vector<int> indexData;
    // Store the data of dense type slot
    std::vector<real> denseData;
    // Store the data of sparseNonValue type slot
    std::vector<sparse_non_value_t> sparseNonValueData;
    // Store the data of sparseValue type slot
    std::vector<sparse_float_value_t> sparseFloatValueData;
    // Used to store the index of each sample in slot values
    std::vector<int64_t> indices;
    // The starting position of each sequence in samples
    // The last element should be the number of samples
    // If empty, each sample is one sequence.
    std::vector<size_t> sequenceStartPositions;
    // The index id of sequences in slot
    std::vector<int64_t> sampleSequenceIdVec;
    // The starting position of each subsequence in samples
    // The last element should be the number of subsequence
    // If empty, each sequence of sample has no subsequence.
    std::vector<size_t> subSequenceStartPositions;
    // Store the data of string type slot
    std::vector<std::string> strData;
  };
  std::vector<ProtoSlot> slots_;

  PyObjectPtr classInstance_;
  unsigned int batchSize_;
  unsigned int slotNum_;
  // if use sequence, isIID_ equals false, otherwise it is true.
  bool isIID_;
  // The name of python module name
  std::string pyModuleName_;
  // The name of python class name
  std::string pyClassName_;
  // User args set in config
  std::map<std::string, std::string> pyUserArgs_;

  ThreadLocalD<DataBatch> cpuBatch_;
  ThreadLocalD<DataBatch> gpuBatch_;
};

}  // namespace paddle