// data_set.h
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *     Unless required by applicable law or agreed to in writing, software
 *     distributed under the License is distributed on an "AS IS" BASIS,
 *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *     See the License for the specific language governing permissions and
 *     limitations under the License. */

#pragma once

#include <ThreadPool.h>
#include <cstdint>
#include <fstream>
#include <memory>
#include <mutex>  // NOLINT
#include <set>
#include <string>
#include <thread>  // NOLINT
#include <unordered_set>
#include <utility>
#include <vector>

#ifdef PADDLE_WITH_GLOO
#include <gloo/broadcast.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif

#include "paddle/fluid/framework/data_feed.h"

namespace paddle {
namespace framework {

X
xjqbest 已提交
37 38 39 40 41 42 43 44 45 46
// Dataset is a abstract class, which defines user interfaces
// Example Usage:
//    Dataset* dataset = DatasetFactory::CreateDataset("InMemoryDataset")
//    dataset->SetFileList(std::vector<std::string>{"a.txt", "b.txt"})
//    dataset->SetThreadNum(1)
//    dataset->CreateReaders();
//    dataset->SetDataFeedDesc(your_data_feed_desc);
//    dataset->LoadIntoMemory();
//    dataset->SetTrainerNum(2);
//    dataset->GlobalShuffle();
47 48
class Dataset {
 public:
49 50
  Dataset() {}
  virtual ~Dataset() {}
W
wangzhen38 已提交
51 52 53 54 55 56 57
  // do sample
  virtual void TDMSample(const std::string tree_name,
                         const std::string tree_path,
                         const std::vector<uint16_t> tdm_layer_counts,
                         const uint16_t start_sample_layer,
                         const bool with_hierachy, const uint16_t seed_,
                         const uint16_t sample_slot) {}
58
  // set file list
59
  virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
60
  // set readers' num
61
  virtual void SetThreadNum(int thread_num) = 0;
62
  // set workers' num
63
  virtual void SetTrainerNum(int trainer_num) = 0;
X
xjqbest 已提交
64 65
  // set fleet send batch size
  virtual void SetFleetSendBatchSize(int64_t size) = 0;
66
  // set fs name and ugi
67 68
  virtual void SetHdfsConfig(const std::string& fs_name,
                             const std::string& fs_ugi) = 0;
69 70
  // set customized download command, such as using afs api
  virtual void SetDownloadCmd(const std::string& download_cmd) = 0;
71 72
  // set data fedd desc, which contains:
  //   data feed name, batch size, slots
73
  virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
J
jiaqi 已提交
74 75
  // set channel num
  virtual void SetChannelNum(int channel_num) = 0;
76 77 78
  // set parse ins id
  virtual void SetParseInsId(bool parse_ins_id) = 0;
  virtual void SetParseContent(bool parse_content) = 0;
79 80
  virtual void SetParseLogKey(bool parse_logkey) = 0;
  virtual void SetEnablePvMerge(bool enable_pv_merge) = 0;
81
  virtual bool EnablePvMerge() = 0;
82
  virtual void SetMergeBySid(bool is_merge) = 0;
83
  // set merge by ins id
84
  virtual void SetMergeByInsId(int merge_size) = 0;
85
  virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0;
86 87
  // set fea eval mode
  virtual void SetFeaEval(bool fea_eval, int record_candidate_size) = 0;
88
  // get file list
89
  virtual const std::vector<std::string>& GetFileList() = 0;
90
  // get thread num
91
  virtual int GetThreadNum() = 0;
92
  // get worker num
93
  virtual int GetTrainerNum() = 0;
X
xjqbest 已提交
94 95
  // get fleet send batch size
  virtual int64_t GetFleetSendBatchSize() = 0;
X
xjqbest 已提交
96 97
  // get hdfs config
  virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
98 99
  // get download cmd
  virtual std::string GetDownloadCmd() = 0;
100
  // get data fedd desc
101
  virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
J
jiaqi 已提交
102 103
  // get channel num
  virtual int GetChannelNum() = 0;
104 105
  // get readers, the reader num depend both on thread num
  // and filelist size
J
jiaqi 已提交
106 107 108
  virtual std::vector<paddle::framework::DataFeed*> GetReaders() = 0;
  // create input channel and output channel
  virtual void CreateChannel() = 0;
109 110 111
  // register message handler between workers
  virtual void RegisterClientToClientMsgHandler() = 0;
  // load all data into memory
112
  virtual void LoadIntoMemory() = 0;
J
jiaqi 已提交
113 114 115 116
  // load all data into memory in async mode
  virtual void PreLoadIntoMemory() = 0;
  // wait async load done
  virtual void WaitPreLoadDone() = 0;
117 118 119
  // release all memory data
  virtual void ReleaseMemory() = 0;
  // local shuffle data
120
  virtual void LocalShuffle() = 0;
121
  // global shuffle data
122
  virtual void GlobalShuffle(int thread_num = -1) = 0;
123
  virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) = 0;
124
  // create readers
125
  virtual void CreateReaders() = 0;
126
  // destroy readers
127
  virtual void DestroyReaders() = 0;
128 129
  // get memory data size
  virtual int64_t GetMemoryDataSize() = 0;
130 131
  // get memory data size in input_pv_channel_
  virtual int64_t GetPvDataSize() = 0;
132 133
  // get shuffle data size
  virtual int64_t GetShuffleDataSize() = 0;
134 135
  // merge by ins id
  virtual void MergeByInsId() = 0;
136 137 138 139 140 141
  // merge pv instance
  virtual void PreprocessInstance() = 0;
  // divide pv instance
  virtual void PostprocessInstance() = 0;
  // only for untest
  virtual void SetCurrentPhase(int current_phase) = 0;
142 143 144 145 146
  virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
                                         int read_thread_num,
                                         int consume_thread_num,
                                         int shard_num) = 0;
  virtual void ClearLocalTables() = 0;
147 148 149 150 151 152
  // create preload readers
  virtual void CreatePreLoadReaders() = 0;
  // destroy preload readers after prelaod done
  virtual void DestroyPreLoadReaders() = 0;
  // set preload thread num
  virtual void SetPreLoadThreadNum(int thread_num) = 0;
H
hutuxian 已提交
153 154 155
  // seperate train thread and dataset thread
  virtual void DynamicAdjustChannelNum(int channel_num,
                                       bool discard_remaining_ins = false) = 0;
156 157 158
  virtual void DynamicAdjustReadersNum(int thread_num) = 0;
  // set fleet send sleep seconds
  virtual void SetFleetSendSleepSeconds(int seconds) = 0;
159

160 161 162 163 164
 protected:
  virtual int ReceiveFromClient(int msg_type, int client_id,
                                const std::string& msg) = 0;
};

X
xjqbest 已提交
165 166
// DatasetImpl is the implementation of Dataset,
// it holds memory data if user calls load_into_memory
167
template <typename T>
168 169 170 171
class DatasetImpl : public Dataset {
 public:
  DatasetImpl();
  virtual ~DatasetImpl() {}
172 173 174
  virtual void SetFileList(const std::vector<std::string>& filelist);
  virtual void SetThreadNum(int thread_num);
  virtual void SetTrainerNum(int trainer_num);
X
xjqbest 已提交
175
  virtual void SetFleetSendBatchSize(int64_t size);
176 177
  virtual void SetHdfsConfig(const std::string& fs_name,
                             const std::string& fs_ugi);
178
  virtual void SetDownloadCmd(const std::string& download_cmd);
179
  virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
J
jiaqi 已提交
180
  virtual void SetChannelNum(int channel_num);
181 182
  virtual void SetParseInsId(bool parse_ins_id);
  virtual void SetParseContent(bool parse_content);
183 184 185 186
  virtual void SetParseLogKey(bool parse_logkey);
  virtual void SetEnablePvMerge(bool enable_pv_merge);
  virtual void SetMergeBySid(bool is_merge);

187
  virtual void SetMergeByInsId(int merge_size);
188
  virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns);
189
  virtual void SetFeaEval(bool fea_eval, int record_candidate_size);
D
dongdaxiang 已提交
190 191 192
  virtual const std::vector<std::string>& GetFileList() { return filelist_; }
  virtual int GetThreadNum() { return thread_num_; }
  virtual int GetTrainerNum() { return trainer_num_; }
H
hutuxian 已提交
193
  virtual Channel<T> GetInputChannel() { return input_channel_; }
194 195 196
  virtual void SetInputChannel(const Channel<T>& input_channel) {
    input_channel_ = input_channel;
  }
X
xjqbest 已提交
197
  virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; }
X
xjqbest 已提交
198 199 200
  virtual std::pair<std::string, std::string> GetHdfsConfig() {
    return std::make_pair(fs_name_, fs_ugi_);
  }
201
  virtual std::string GetDownloadCmd();
202 203 204
  virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
    return data_feed_desc_;
  }
J
jiaqi 已提交
205
  virtual int GetChannelNum() { return channel_num_; }
206
  virtual bool EnablePvMerge() { return enable_pv_merge_; }
J
jiaqi 已提交
207 208
  virtual std::vector<paddle::framework::DataFeed*> GetReaders();
  virtual void CreateChannel();
209
  virtual void RegisterClientToClientMsgHandler();
210
  virtual void LoadIntoMemory();
J
jiaqi 已提交
211 212
  virtual void PreLoadIntoMemory();
  virtual void WaitPreLoadDone();
213
  virtual void ReleaseMemory();
214
  virtual void LocalShuffle();
Y
yaoxuefeng 已提交
215
  virtual void GlobalShuffle(int thread_num = -1) {}
216
  virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) {}
217 218 219
  virtual const std::vector<T>& GetSlotsOriginalData() {
    return slots_shuffle_original_data_;
  }
220
  virtual void CreateReaders();
221
  virtual void DestroyReaders();
222
  virtual int64_t GetMemoryDataSize();
223
  virtual int64_t GetPvDataSize();
224
  virtual int64_t GetShuffleDataSize();
225
  virtual void MergeByInsId() {}
226 227 228
  virtual void PreprocessInstance() {}
  virtual void PostprocessInstance() {}
  virtual void SetCurrentPhase(int current_phase) {}
229 230 231 232 233
  virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
                                         int read_thread_num,
                                         int consume_thread_num,
                                         int shard_num) {}
  virtual void ClearLocalTables() {}
234 235 236
  virtual void CreatePreLoadReaders();
  virtual void DestroyPreLoadReaders();
  virtual void SetPreLoadThreadNum(int thread_num);
H
hutuxian 已提交
237 238
  virtual void DynamicAdjustChannelNum(int channel_num,
                                       bool discard_remaining_ins = false);
239 240
  virtual void DynamicAdjustReadersNum(int thread_num);
  virtual void SetFleetSendSleepSeconds(int seconds);
Y
yaoxuefeng 已提交
241 242 243 244 245
  /* for enable_heterps_
  virtual void EnableHeterps(bool enable_heterps) {
    enable_heterps_ = enable_heterps;
  }
  */
D
dongdaxiang 已提交
246

T
Thunderbrook 已提交
247 248 249 250 251 252 253 254 255 256 257 258 259 260
  std::vector<paddle::framework::Channel<T>>& GetMultiOutputChannel() {
    return multi_output_channel_;
  }

  std::vector<paddle::framework::Channel<T>>& GetCurOutputChannel() {
    if (cur_channel_ == 0) {
      return multi_output_channel_;
    } else {
      return multi_consume_channel_;
    }
  }

  Channel<T>& GetInputChannelRef() { return input_channel_; }

261
 protected:
D
dongdaxiang 已提交
262
  virtual int ReceiveFromClient(int msg_type, int client_id,
Y
yaoxuefeng 已提交
263 264 265 266
                                const std::string& msg) {
    // TODO(yaoxuefeng) for SlotRecordDataset
    return -1;
  }
267
  std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
268
  std::vector<std::shared_ptr<paddle::framework::DataFeed>> preload_readers_;
J
jiaqi 已提交
269
  paddle::framework::Channel<T> input_channel_;
270 271 272 273
  paddle::framework::Channel<PvInstance> input_pv_channel_;
  std::vector<paddle::framework::Channel<PvInstance>> multi_pv_output_;
  std::vector<paddle::framework::Channel<PvInstance>> multi_pv_consume_;

J
jiaqi 已提交
274 275 276
  int channel_num_;
  std::vector<paddle::framework::Channel<T>> multi_output_channel_;
  std::vector<paddle::framework::Channel<T>> multi_consume_channel_;
277
  std::vector<std::unordered_set<uint64_t>> local_tables_;
J
jiaqi 已提交
278 279 280 281
  // when read ins, we put ins from one channel to the other,
  // and when finish reading, we set cur_channel = 1 - cur_channel,
  // so if cur_channel=0, all data are in output_channel, else consume_channel
  int cur_channel_;
282 283
  std::vector<T> slots_shuffle_original_data_;
  RecordCandidateList slots_shuffle_rclist_;
284
  int thread_num_;
285
  int pull_sparse_to_local_thread_num_;
286 287
  paddle::framework::DataFeedDesc data_feed_desc_;
  int trainer_num_;
288 289
  std::vector<std::string> filelist_;
  size_t file_idx_;
H
hutuxian 已提交
290
  uint64_t total_fea_num_;
291
  std::mutex mutex_for_pick_file_;
H
hutuxian 已提交
292
  std::mutex mutex_for_fea_num_;
X
xjqbest 已提交
293 294
  std::string fs_name_;
  std::string fs_ugi_;
X
xjqbest 已提交
295
  int64_t fleet_send_batch_size_;
J
jiaqi 已提交
296 297
  int64_t fleet_send_sleep_seconds_;
  std::vector<std::thread> preload_threads_;
298
  bool merge_by_insid_;
299 300
  bool parse_ins_id_;
  bool parse_content_;
301 302 303 304
  bool parse_logkey_;
  bool merge_by_sid_;
  bool enable_pv_merge_;  // True means to merge pv
  int current_phase_;     // 1 join, 0 update
305
  size_t merge_size_;
306
  bool slots_shuffle_fea_eval_ = false;
307
  bool gen_uni_feasigns_ = false;
308
  int preload_thread_num_;
309 310
  std::mutex global_index_mutex_;
  int64_t global_index_ = 0;
311
  std::vector<std::shared_ptr<ThreadPool>> consume_task_pool_;
312
  std::vector<T> input_records_;  // only for paddleboxdatafeed
Y
yaoxuefeng 已提交
313
  bool enable_heterps_ = false;
314 315
};

316
// use std::vector<MultiSlotType> or Record as data type
J
jiaqi 已提交
317
class MultiSlotDataset : public DatasetImpl<Record> {
318 319
 public:
  MultiSlotDataset() {}
W
wangzhen38 已提交
320 321 322 323 324 325
  virtual void TDMSample(const std::string tree_name,
                         const std::string tree_path,
                         const std::vector<uint16_t> tdm_layer_counts,
                         const uint16_t start_sample_layer,
                         const bool with_hierachy, const uint16_t seed_,
                         const uint16_t sample_slot);
326
  virtual void MergeByInsId();
327 328 329
  virtual void PreprocessInstance();
  virtual void PostprocessInstance();
  virtual void SetCurrentPhase(int current_phase);
330 331 332 333 334 335 336 337 338 339
  virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
                                         int read_thread_num,
                                         int consume_thread_num, int shard_num);
  virtual void ClearLocalTables() {
    for (auto& t : local_tables_) {
      t.clear();
      std::unordered_set<uint64_t>().swap(t);
    }
    std::vector<std::unordered_set<uint64_t>>().swap(local_tables_);
  }
340 341 342
  virtual void PreprocessChannel(
      const std::set<std::string>& slots_to_replace,
      std::unordered_set<uint16_t>& index_slot);  // NOLINT
343
  virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace);
344 345 346
  virtual void GetRandomData(
      const std::unordered_set<uint16_t>& slots_to_replace,
      std::vector<Record>* result);
347
  virtual ~MultiSlotDataset() {}
Y
yaoxuefeng 已提交
348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373
  virtual void GlobalShuffle(int thread_num = -1);
  virtual void DynamicAdjustReadersNum(int thread_num);
  virtual void PrepareTrain();

 protected:
  virtual int ReceiveFromClient(int msg_type, int client_id,
                                const std::string& msg);
};
// SlotRecordDataset: a DatasetImpl whose in-memory element type is
// SlotRecord.
class SlotRecordDataset : public DatasetImpl<SlotRecord> {
 public:
  // SlotRecordPool() is called only for its side effect (its result is
  // discarded). NOTE(review): presumably this makes sure the record pool
  // exists before any reader uses it — confirm.
  SlotRecordDataset() { SlotRecordPool(); }
  virtual ~SlotRecordDataset() {}
  // create input channel
  virtual void CreateChannel();
  // create readers
  virtual void CreateReaders();
  // release memory
  virtual void ReleaseMemory();
  virtual void GlobalShuffle(int thread_num = -1);
  // Declares the same default as Dataset/DatasetImpl. Default arguments of
  // virtual functions are bound by the static type of the call expression,
  // so without it a call through a SlotRecordDataset& could not omit
  // `discard_remaining_ins` while a call through the base class could.
  virtual void DynamicAdjustChannelNum(int channel_num,
                                       bool discard_remaining_ins = false);
  virtual void PrepareTrain();
  virtual void DynamicAdjustReadersNum(int thread_num);

 protected:
  // NOTE(review): this declares a NEW member that hides
  // DatasetImpl<SlotRecord>::enable_heterps_ (initialized to false) rather
  // than setting it. Methods of this class see true; inherited DatasetImpl
  // methods still see false. Confirm which behavior is intended before
  // consolidating into the base member.
  bool enable_heterps_ = true;
};

}  // end namespace framework
}  // end namespace paddle