data_set.h 9.5 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *     Unless required by applicable law or agreed to in writing, software
 *     distributed under the License is distributed on an "AS IS" BASIS,
 *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *     See the License for the specific language governing permissions and
 *     limitations under the License. */

#pragma once

#include <fstream>
#include <memory>
#include <mutex>  // NOLINT
20
#include <set>
X
xiexionghang 已提交
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
#include <string>
#include <thread>  // NOLINT
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_feed.h"

namespace paddle {
namespace framework {

// Dataset is an abstract class that defines the user-facing interface.
// Example Usage:
//    Dataset* dataset = DatasetFactory::CreateDataset("InMemoryDataset");
//    dataset->SetFileList(std::vector<std::string>{"a.txt", "b.txt"});
//    dataset->SetThreadNum(1);
//    dataset->CreateReaders();
//    dataset->SetDataFeedDesc(your_data_feed_desc);
//    dataset->LoadIntoMemory();
//    dataset->SetTrainerNum(2);
//    dataset->GlobalShuffle();
class Dataset {
 public:
  Dataset() = default;
  virtual ~Dataset() = default;

  // ---- configuration setters ----

  // Sets the list of input files.
  virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
  // Sets the number of readers.
  virtual void SetThreadNum(int thread_num) = 0;
  // Sets the number of workers (trainers).
  virtual void SetTrainerNum(int trainer_num) = 0;
  // Sets the fleet send batch size.
  virtual void SetFleetSendBatchSize(int64_t size) = 0;
  // Sets the file system name and ugi.
  virtual void SetHdfsConfig(const std::string& fs_name,
                             const std::string& fs_ugi) = 0;
  // Sets the data feed desc from its string form. The desc contains:
  //   data feed name, batch size, slots
  virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
  // Sets the number of channels.
  virtual void SetChannelNum(int channel_num) = 0;
  // Sets whether the ins id is parsed.
  virtual void SetParseInsId(bool parse_ins_id) = 0;
  // Sets the parse-content flag.
  virtual void SetParseContent(bool parse_content) = 0;
  // Configures merge by ins id.
  virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list,
                               bool erase_duplicate_feas, int min_merge_size,
                               bool keep_unmerged_ins) = 0;
  // Sets the fea eval mode and the record candidate size.
  virtual void SetFeaEval(bool fea_eval, int record_candidate_size) = 0;

  // ---- configuration getters ----

  // Returns the file list.
  virtual const std::vector<std::string>& GetFileList() = 0;
  // Returns the thread num.
  virtual int GetThreadNum() = 0;
  // Returns the worker (trainer) num.
  virtual int GetTrainerNum() = 0;
  // Returns the fleet send batch size.
  virtual int64_t GetFleetSendBatchSize() = 0;
  // Returns the hdfs config as a (fs_name, fs_ugi) pair.
  virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
  // Returns the data feed desc.
  virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
  // Returns the channel num.
  virtual int GetChannelNum() = 0;
  // Returns the readers. The number of readers depends on both the thread num
  // and the file list size.
  virtual std::vector<paddle::framework::DataFeed*> GetReaders() = 0;

  // ---- data loading and shuffling ----

  // Creates the input channel and the output channel.
  virtual void CreateChannel() = 0;
  // Registers the message handler used between workers.
  virtual void RegisterClientToClientMsgHandler() = 0;
  // Loads all data into memory.
  virtual void LoadIntoMemory() = 0;
  // Loads all data into memory in async mode.
  virtual void PreLoadIntoMemory() = 0;
  // Waits until the async load is done.
  virtual void WaitPreLoadDone() = 0;
  // Releases all in-memory data.
  virtual void ReleaseMemory() = 0;
  // Shuffles data locally.
  virtual void LocalShuffle() = 0;
  // Shuffles data globally.
  // NOTE(review): the meaning of the default thread_num of -1 is defined by
  // the implementation and is not visible in this header.
  virtual void GlobalShuffle(int thread_num = -1) = 0;
  // Slots shuffle: shuffles the slots named in slots_to_replace.
  virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) = 0;
  // Part of slots shuffle; output is written to *result.
  virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
                             std::vector<Record>* result) = 0;

  // ---- reader management and statistics ----

  // Creates the readers.
  virtual void CreateReaders() = 0;
  // Destroys the readers.
  virtual void DestroyReaders() = 0;
  // Returns the memory data size.
  virtual int64_t GetMemoryDataSize() = 0;
  // Returns the shuffle data size.
  virtual int64_t GetShuffleDataSize() = 0;
  // Merges by ins id.
  virtual void MergeByInsId() = 0;
  // Creates the preload readers.
  virtual void CreatePreLoadReaders() = 0;
  // Destroys the preload readers after preload is done.
  virtual void DestroyPreLoadReaders() = 0;
  // Sets the preload thread num.
  virtual void SetPreLoadThreadNum(int thread_num) = 0;
  // Separate the train threads from the dataset threads.
  virtual void DynamicAdjustChannelNum(int channel_num) = 0;
  virtual void DynamicAdjustReadersNum(int thread_num) = 0;
  // Sets the fleet send sleep seconds.
  virtual void SetFleetSendSleepSeconds(int seconds) = 0;

 protected:
  // Receives a message of type msg_type from the client client_id.
  virtual int ReceiveFromClient(int msg_type, int client_id,
                                const std::string& msg) = 0;
};

// DatasetImpl is the implementation of Dataset,
// it holds memory data if user calls load_into_memory
template <typename T>
class DatasetImpl : public Dataset {
 public:
  DatasetImpl();
  virtual ~DatasetImpl() {}

  virtual void SetFileList(const std::vector<std::string>& filelist);
  virtual void SetThreadNum(int thread_num);
  virtual void SetTrainerNum(int trainer_num);
  virtual void SetFleetSendBatchSize(int64_t size);
  virtual void SetHdfsConfig(const std::string& fs_name,
                             const std::string& fs_ugi);
  virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
  virtual void SetChannelNum(int channel_num);
150 151
  virtual void SetParseInsId(bool parse_ins_id);
  virtual void SetParseContent(bool parse_content);
X
xiexionghang 已提交
152 153 154 155
  virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list,
                               bool erase_duplicate_feas, int min_merge_size,
                               bool keep_unmerged_ins);

156
  virtual void SetFeaEval(bool fea_eval, int record_candidate_size);
X
xiexionghang 已提交
157 158 159
  virtual const std::vector<std::string>& GetFileList() { return filelist_; }
  virtual int GetThreadNum() { return thread_num_; }
  virtual int GetTrainerNum() { return trainer_num_; }
160
  virtual Channel<T> GetInputChannel() { return input_channel_; }
X
xiexionghang 已提交
161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176
  virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; }
  virtual std::pair<std::string, std::string> GetHdfsConfig() {
    return std::make_pair(fs_name_, fs_ugi_);
  }
  virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
    return data_feed_desc_;
  }
  virtual int GetChannelNum() { return channel_num_; }
  virtual std::vector<paddle::framework::DataFeed*> GetReaders();
  virtual void CreateChannel();
  virtual void RegisterClientToClientMsgHandler();
  virtual void LoadIntoMemory();
  virtual void PreLoadIntoMemory();
  virtual void WaitPreLoadDone();
  virtual void ReleaseMemory();
  virtual void LocalShuffle();
177 178 179 180
  virtual void GlobalShuffle(int thread_num = -1);
  virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) {}
  virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
                             std::vector<Record>* result) {}
X
xiexionghang 已提交
181 182 183 184 185
  virtual void CreateReaders();
  virtual void DestroyReaders();
  virtual int64_t GetMemoryDataSize();
  virtual int64_t GetShuffleDataSize();
  virtual void MergeByInsId() {}
186 187 188 189 190 191
  virtual void CreatePreLoadReaders();
  virtual void DestroyPreLoadReaders();
  virtual void SetPreLoadThreadNum(int thread_num);
  virtual void DynamicAdjustChannelNum(int channel_num);
  virtual void DynamicAdjustReadersNum(int thread_num);
  virtual void SetFleetSendSleepSeconds(int seconds);
X
xiexionghang 已提交
192 193 194 195 196

 protected:
  virtual int ReceiveFromClient(int msg_type, int client_id,
                                const std::string& msg);
  std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
197
  std::vector<std::shared_ptr<paddle::framework::DataFeed>> preload_readers_;
X
xiexionghang 已提交
198 199 200 201 202 203 204 205
  paddle::framework::Channel<T> input_channel_;
  int channel_num_;
  std::vector<paddle::framework::Channel<T>> multi_output_channel_;
  std::vector<paddle::framework::Channel<T>> multi_consume_channel_;
  // when read ins, we put ins from one channel to the other,
  // and when finish reading, we set cur_channel = 1 - cur_channel,
  // so if cur_channel=0, all data are in output_channel, else consume_channel
  int cur_channel_;
206 207
  std::vector<T> slots_shuffle_original_data_;
  RecordCandidateList slots_shuffle_rclist_;
X
xiexionghang 已提交
208 209 210 211 212 213 214 215 216 217 218 219
  int thread_num_;
  paddle::framework::DataFeedDesc data_feed_desc_;
  int trainer_num_;
  std::vector<std::string> filelist_;
  size_t file_idx_;
  std::mutex mutex_for_pick_file_;
  std::string fs_name_;
  std::string fs_ugi_;
  int64_t fleet_send_batch_size_;
  int64_t fleet_send_sleep_seconds_;
  std::vector<std::thread> preload_threads_;
  bool merge_by_insid_;
220 221
  bool parse_ins_id_;
  bool parse_content_;
X
xiexionghang 已提交
222 223 224 225
  bool erase_duplicate_feas_;
  bool keep_unmerged_ins_;
  int min_merge_size_;
  std::vector<std::string> merge_slots_list_;
226 227 228 229
  bool slots_shuffle_fea_eval_ = false;
  int preload_thread_num_;
  std::mutex global_index_mutex_;
  int64_t global_index_ = 0;
X
xiexionghang 已提交
230 231 232 233 234 235 236
};

// use std::vector<MultiSlotType> or Record as data type
// MultiSlotDataset is DatasetImpl instantiated with Record. It declares its
// own MergeByInsId, SlotsShuffle and GetRandomData in place of the empty
// versions in DatasetImpl; their definitions are outside this header.
class MultiSlotDataset : public DatasetImpl<Record> {
 public:
  MultiSlotDataset() {}
  // `override` makes the compiler check these against the base-class
  // signatures.
  void MergeByInsId() override;
  void SlotsShuffle(const std::set<std::string>& slots_to_replace) override;
  void GetRandomData(const std::set<uint16_t>& slots_to_replace,
                     std::vector<Record>* result) override;
  ~MultiSlotDataset() override {}
};

}  // end namespace framework
}  // end namespace paddle