/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include "DataFormat.pb.h" #include "DataProvider.h" #include namespace paddle { class PyDataProvider : public DataProvider { public: PyDataProvider(const DataConfig& config, bool useGpu, bool loadDataAll = true); virtual void reset(); // Note this size includes the sequences which are skipped because they // are longer than the batch size virtual int64_t getSize() { LOG(FATAL) << "Not implement yet"; return -1; } virtual void shuffle(); virtual int64_t getNextBatchInternal(int64_t size, DataBatch* batch); protected: struct ProtoSlot; // return false if each each sample is one sequence, i.e., independent // of other samples. inline bool iidData() const { return isIID_; } void parseHeaderData(const std::string& headerData); void fillDenseSlot(ProtoSlot& slot, char*& data, const char* dataEnd); void fillSparseNonValueSlot(ProtoSlot& slot, char*& data, const char* dataEnd); void fillSparseValueSlot(ProtoSlot& slot, char*& data, const char* dataEnd); void fillIndexSlot(ProtoSlot& slot, char*& data, const char* dataEnd); void fillStringSlot(ProtoSlot& slot, char*& data, const char* dataEnd); void fillSlotsByStr(const std::string& samples); void handleDenseSlot(ProtoSlot& slot, size_t slotIndex, std::vector& cpuArguments); void handleSparseNonValueSlot(ProtoSlot& slot, size_t slotIndex, std::vector& cpuArguments); void handleSparseValueSlot(ProtoSlot& slot, size_t slotIndex, std::vector& cpuArguments); void handleIndexSlot(ProtoSlot& slot, size_t slotIndex, std::vector& cpuArguments); void handleStringSlot(ProtoSlot& slot, size_t slotIndex, std::vector& cpuArguments); void resetSlots(); void loadData(const std::vector& fileList); protected: struct ProtoSlot { SlotDef::SlotType type; int dim; unsigned int sampleNum; unsigned int sequenceNum; unsigned int subSequenceNum; // Store the data of index type slot std::vector indexData; // Store the data of dense type slot std::vector denseData; // Store the data of sparseNonValue type slot std::vector sparseNonValueData; // Store the data of sparseValue type slot std::vector sparseFloatValueData; // Used to store the index of each sample in slot values std::vector indices; // The starting position of each sequence in samples // The last element should be the number of samples // If empty, each sample is one sequence. std::vector sequenceStartPositions; // The index id of sequences in slot std::vector sampleSequenceIdVec; // The starting position of each subsequence in samples // The last element should be the number of subsequence // If empty, each sequence of sample has no subsequence. std::vector subSequenceStartPositions; // Store the data of string type slot std::vector strData; }; std::vector slots_; PyObjectPtr classInstance_; unsigned int batchSize_; unsigned int slotNum_; // if use sequence, isIID_ equals false, otherwise it is true. bool isIID_; // The name of python module name std::string pyModuleName_; // The name of python class name std::string pyClassName_; // User args set in config std::map pyUserArgs_; ThreadLocalD cpuBatch_; ThreadLocalD gpuBatch_; }; } // namespace paddle