data_set.h 15.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *     Unless required by applicable law or agreed to in writing, software
 *     distributed under the License is distributed on an "AS IS" BASIS,
 *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *     See the License for the specific language governing permissions and
 *     limitations under the License. */

#pragma once

17
#include <ThreadPool.h>
18 19 20
#include <fstream>
#include <memory>
#include <mutex>  // NOLINT
21
#include <set>
22 23
#include <string>
#include <thread>  // NOLINT
24
#include <unordered_set>
X
xjqbest 已提交
25
#include <utility>
26
#include <vector>
Y
yaoxuefeng 已提交
27 28 29 30
#ifdef PADDLE_WITH_GLOO
#include <gloo/broadcast.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
31 32 33 34 35 36

#include "paddle/fluid/framework/data_feed.h"

namespace paddle {
namespace framework {

X
xjqbest 已提交
37 38 39 40 41 42 43 44 45 46
// Dataset is a abstract class, which defines user interfaces
// Example Usage:
//    Dataset* dataset = DatasetFactory::CreateDataset("InMemoryDataset")
//    dataset->SetFileList(std::vector<std::string>{"a.txt", "b.txt"})
//    dataset->SetThreadNum(1)
//    dataset->CreateReaders();
//    dataset->SetDataFeedDesc(your_data_feed_desc);
//    dataset->LoadIntoMemory();
//    dataset->SetTrainerNum(2);
//    dataset->GlobalShuffle();
47 48
class Dataset {
 public:
49 50
  Dataset() {}
  virtual ~Dataset() {}
W
wangzhen38 已提交
51 52 53 54 55 56 57
  // do sample
  virtual void TDMSample(const std::string tree_name,
                         const std::string tree_path,
                         const std::vector<uint16_t> tdm_layer_counts,
                         const uint16_t start_sample_layer,
                         const bool with_hierachy, const uint16_t seed_,
                         const uint16_t sample_slot) {}
58
  // set file list
59
  virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
60
  // set readers' num
61
  virtual void SetThreadNum(int thread_num) = 0;
62
  // set workers' num
63
  virtual void SetTrainerNum(int trainer_num) = 0;
X
xjqbest 已提交
64 65
  // set fleet send batch size
  virtual void SetFleetSendBatchSize(int64_t size) = 0;
T
Thunderbrook 已提交
66
  virtual void ReleaseMemoryFun() = 0;
67
  // set fs name and ugi
68 69
  virtual void SetHdfsConfig(const std::string& fs_name,
                             const std::string& fs_ugi) = 0;
70 71
  // set customized download command, such as using afs api
  virtual void SetDownloadCmd(const std::string& download_cmd) = 0;
72 73
  // set data fedd desc, which contains:
  //   data feed name, batch size, slots
74
  virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
J
jiaqi 已提交
75 76
  // set channel num
  virtual void SetChannelNum(int channel_num) = 0;
77 78 79
  // set parse ins id
  virtual void SetParseInsId(bool parse_ins_id) = 0;
  virtual void SetParseContent(bool parse_content) = 0;
80 81
  virtual void SetParseLogKey(bool parse_logkey) = 0;
  virtual void SetEnablePvMerge(bool enable_pv_merge) = 0;
82
  virtual bool EnablePvMerge() = 0;
83
  virtual void SetMergeBySid(bool is_merge) = 0;
84
  virtual void SetShuffleByUid(bool enable_shuffle_uid) = 0;
85
  // set merge by ins id
86
  virtual void SetMergeByInsId(int merge_size) = 0;
87
  virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0;
88 89
  // set fea eval mode
  virtual void SetFeaEval(bool fea_eval, int record_candidate_size) = 0;
90
  // get file list
91
  virtual const std::vector<std::string>& GetFileList() = 0;
92
  // get thread num
93
  virtual int GetThreadNum() = 0;
94
  // get worker num
95
  virtual int GetTrainerNum() = 0;
X
xjqbest 已提交
96 97
  // get fleet send batch size
  virtual int64_t GetFleetSendBatchSize() = 0;
X
xjqbest 已提交
98 99
  // get hdfs config
  virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
100 101
  // get download cmd
  virtual std::string GetDownloadCmd() = 0;
102
  // get data fedd desc
103
  virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
J
jiaqi 已提交
104 105
  // get channel num
  virtual int GetChannelNum() = 0;
106 107
  // get readers, the reader num depend both on thread num
  // and filelist size
J
jiaqi 已提交
108 109 110
  virtual std::vector<paddle::framework::DataFeed*> GetReaders() = 0;
  // create input channel and output channel
  virtual void CreateChannel() = 0;
111 112 113
  // register message handler between workers
  virtual void RegisterClientToClientMsgHandler() = 0;
  // load all data into memory
114
  virtual void LoadIntoMemory() = 0;
J
jiaqi 已提交
115 116 117 118
  // load all data into memory in async mode
  virtual void PreLoadIntoMemory() = 0;
  // wait async load done
  virtual void WaitPreLoadDone() = 0;
119 120 121
  // release all memory data
  virtual void ReleaseMemory() = 0;
  // local shuffle data
122
  virtual void LocalShuffle() = 0;
123
  // global shuffle data
124
  virtual void GlobalShuffle(int thread_num = -1) = 0;
125
  virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) = 0;
126
  // create readers
127
  virtual void CreateReaders() = 0;
128
  // destroy readers
129
  virtual void DestroyReaders() = 0;
130 131
  // get memory data size
  virtual int64_t GetMemoryDataSize() = 0;
132 133
  // get memory data size in input_pv_channel_
  virtual int64_t GetPvDataSize() = 0;
134 135
  // get shuffle data size
  virtual int64_t GetShuffleDataSize() = 0;
136 137
  // merge by ins id
  virtual void MergeByInsId() = 0;
138 139 140 141 142 143
  // merge pv instance
  virtual void PreprocessInstance() = 0;
  // divide pv instance
  virtual void PostprocessInstance() = 0;
  // only for untest
  virtual void SetCurrentPhase(int current_phase) = 0;
144 145 146 147 148
  virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
                                         int read_thread_num,
                                         int consume_thread_num,
                                         int shard_num) = 0;
  virtual void ClearLocalTables() = 0;
149 150 151 152 153 154
  // create preload readers
  virtual void CreatePreLoadReaders() = 0;
  // destroy preload readers after prelaod done
  virtual void DestroyPreLoadReaders() = 0;
  // set preload thread num
  virtual void SetPreLoadThreadNum(int thread_num) = 0;
Y
yaoxuefeng 已提交
155
  // seperate train thread and dataset thread
H
hutuxian 已提交
156 157
  virtual void DynamicAdjustChannelNum(int channel_num,
                                       bool discard_remaining_ins = false) = 0;
158 159 160
  virtual void DynamicAdjustReadersNum(int thread_num) = 0;
  // set fleet send sleep seconds
  virtual void SetFleetSendSleepSeconds(int seconds) = 0;
161

Y
yaoxuefeng 已提交
162 163
  virtual std::vector<std::string> GetSlots() = 0;

164 165 166 167 168
 protected:
  virtual int ReceiveFromClient(int msg_type, int client_id,
                                const std::string& msg) = 0;
};

X
xjqbest 已提交
169 170
// DatasetImpl is the implementation of Dataset,
// it holds memory data if user calls load_into_memory
171
template <typename T>
172 173 174
class DatasetImpl : public Dataset {
 public:
  DatasetImpl();
T
Thunderbrook 已提交
175 176 177 178 179
  virtual ~DatasetImpl() {
    if (release_thread_ != nullptr) {
      release_thread_->join();
    }
  }
180
  virtual void SetFileList(const std::vector<std::string>& filelist);
T
Thunderbrook 已提交
181
  virtual void ReleaseMemoryFun();
182 183
  virtual void SetThreadNum(int thread_num);
  virtual void SetTrainerNum(int trainer_num);
X
xjqbest 已提交
184
  virtual void SetFleetSendBatchSize(int64_t size);
185 186
  virtual void SetHdfsConfig(const std::string& fs_name,
                             const std::string& fs_ugi);
187
  virtual void SetDownloadCmd(const std::string& download_cmd);
188
  virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
J
jiaqi 已提交
189
  virtual void SetChannelNum(int channel_num);
190 191
  virtual void SetParseInsId(bool parse_ins_id);
  virtual void SetParseContent(bool parse_content);
192 193 194
  virtual void SetParseLogKey(bool parse_logkey);
  virtual void SetEnablePvMerge(bool enable_pv_merge);
  virtual void SetMergeBySid(bool is_merge);
195
  virtual void SetShuffleByUid(bool enable_shuffle_uid);
196

197
  virtual void SetMergeByInsId(int merge_size);
198
  virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns);
199
  virtual void SetFeaEval(bool fea_eval, int record_candidate_size);
D
dongdaxiang 已提交
200 201 202
  virtual const std::vector<std::string>& GetFileList() { return filelist_; }
  virtual int GetThreadNum() { return thread_num_; }
  virtual int GetTrainerNum() { return trainer_num_; }
H
hutuxian 已提交
203
  virtual Channel<T> GetInputChannel() { return input_channel_; }
204 205 206
  virtual void SetInputChannel(const Channel<T>& input_channel) {
    input_channel_ = input_channel;
  }
X
xjqbest 已提交
207
  virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; }
X
xjqbest 已提交
208 209 210
  virtual std::pair<std::string, std::string> GetHdfsConfig() {
    return std::make_pair(fs_name_, fs_ugi_);
  }
211
  virtual std::string GetDownloadCmd();
212 213 214
  virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
    return data_feed_desc_;
  }
J
jiaqi 已提交
215
  virtual int GetChannelNum() { return channel_num_; }
216
  virtual bool EnablePvMerge() { return enable_pv_merge_; }
J
jiaqi 已提交
217 218
  virtual std::vector<paddle::framework::DataFeed*> GetReaders();
  virtual void CreateChannel();
219
  virtual void RegisterClientToClientMsgHandler();
220
  virtual void LoadIntoMemory();
J
jiaqi 已提交
221 222
  virtual void PreLoadIntoMemory();
  virtual void WaitPreLoadDone();
223
  virtual void ReleaseMemory();
224
  virtual void LocalShuffle();
Y
yaoxuefeng 已提交
225
  virtual void GlobalShuffle(int thread_num = -1) {}
226
  virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) {}
227 228 229
  virtual const std::vector<T>& GetSlotsOriginalData() {
    return slots_shuffle_original_data_;
  }
230
  virtual void CreateReaders();
231
  virtual void DestroyReaders();
232
  virtual int64_t GetMemoryDataSize();
233
  virtual int64_t GetPvDataSize();
234
  virtual int64_t GetShuffleDataSize();
235
  virtual void MergeByInsId() {}
236 237 238
  virtual void PreprocessInstance() {}
  virtual void PostprocessInstance() {}
  virtual void SetCurrentPhase(int current_phase) {}
239 240 241 242 243
  virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
                                         int read_thread_num,
                                         int consume_thread_num,
                                         int shard_num) {}
  virtual void ClearLocalTables() {}
244 245 246
  virtual void CreatePreLoadReaders();
  virtual void DestroyPreLoadReaders();
  virtual void SetPreLoadThreadNum(int thread_num);
H
hutuxian 已提交
247 248
  virtual void DynamicAdjustChannelNum(int channel_num,
                                       bool discard_remaining_ins = false);
249 250
  virtual void DynamicAdjustReadersNum(int thread_num);
  virtual void SetFleetSendSleepSeconds(int seconds);
Y
yaoxuefeng 已提交
251
  virtual std::vector<std::string> GetSlots();
Y
yaoxuefeng 已提交
252 253 254 255 256
  /* for enable_heterps_
  virtual void EnableHeterps(bool enable_heterps) {
    enable_heterps_ = enable_heterps;
  }
  */
D
dongdaxiang 已提交
257

T
Thunderbrook 已提交
258 259 260 261 262 263 264 265 266 267 268 269 270 271
  std::vector<paddle::framework::Channel<T>>& GetMultiOutputChannel() {
    return multi_output_channel_;
  }

  std::vector<paddle::framework::Channel<T>>& GetCurOutputChannel() {
    if (cur_channel_ == 0) {
      return multi_output_channel_;
    } else {
      return multi_consume_channel_;
    }
  }

  Channel<T>& GetInputChannelRef() { return input_channel_; }

272
 protected:
D
dongdaxiang 已提交
273
  virtual int ReceiveFromClient(int msg_type, int client_id,
Y
yaoxuefeng 已提交
274 275 276 277
                                const std::string& msg) {
    // TODO(yaoxuefeng) for SlotRecordDataset
    return -1;
  }
278
  std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
279
  std::vector<std::shared_ptr<paddle::framework::DataFeed>> preload_readers_;
J
jiaqi 已提交
280
  paddle::framework::Channel<T> input_channel_;
281 282 283 284
  paddle::framework::Channel<PvInstance> input_pv_channel_;
  std::vector<paddle::framework::Channel<PvInstance>> multi_pv_output_;
  std::vector<paddle::framework::Channel<PvInstance>> multi_pv_consume_;

J
jiaqi 已提交
285 286 287
  int channel_num_;
  std::vector<paddle::framework::Channel<T>> multi_output_channel_;
  std::vector<paddle::framework::Channel<T>> multi_consume_channel_;
288
  std::vector<std::unordered_set<uint64_t>> local_tables_;
J
jiaqi 已提交
289 290 291 292
  // when read ins, we put ins from one channel to the other,
  // and when finish reading, we set cur_channel = 1 - cur_channel,
  // so if cur_channel=0, all data are in output_channel, else consume_channel
  int cur_channel_;
293 294
  std::vector<T> slots_shuffle_original_data_;
  RecordCandidateList slots_shuffle_rclist_;
295
  int thread_num_;
296
  int pull_sparse_to_local_thread_num_;
297 298
  paddle::framework::DataFeedDesc data_feed_desc_;
  int trainer_num_;
299 300
  std::vector<std::string> filelist_;
  size_t file_idx_;
H
hutuxian 已提交
301
  uint64_t total_fea_num_;
302
  std::mutex mutex_for_pick_file_;
H
hutuxian 已提交
303
  std::mutex mutex_for_fea_num_;
X
xjqbest 已提交
304 305
  std::string fs_name_;
  std::string fs_ugi_;
X
xjqbest 已提交
306
  int64_t fleet_send_batch_size_;
J
jiaqi 已提交
307 308
  int64_t fleet_send_sleep_seconds_;
  std::vector<std::thread> preload_threads_;
T
Thunderbrook 已提交
309
  std::thread* release_thread_ = nullptr;
310
  bool merge_by_insid_;
311 312
  bool parse_ins_id_;
  bool parse_content_;
313 314
  bool parse_logkey_;
  bool merge_by_sid_;
315 316
  bool shuffle_by_uid_;
  bool parse_uid_;
317 318
  bool enable_pv_merge_;  // True means to merge pv
  int current_phase_;     // 1 join, 0 update
319
  size_t merge_size_;
320
  bool slots_shuffle_fea_eval_ = false;
321
  bool gen_uni_feasigns_ = false;
322
  int preload_thread_num_;
323 324
  std::mutex global_index_mutex_;
  int64_t global_index_ = 0;
325
  std::vector<std::shared_ptr<ThreadPool>> consume_task_pool_;
326
  std::vector<T> input_records_;  // only for paddleboxdatafeed
Y
yaoxuefeng 已提交
327
  std::vector<std::string> use_slots_;
Y
yaoxuefeng 已提交
328
  bool enable_heterps_ = false;
329 330
};

331
// use std::vector<MultiSlotType> or Record as data type
J
jiaqi 已提交
332
class MultiSlotDataset : public DatasetImpl<Record> {
333 334
 public:
  MultiSlotDataset() {}
W
wangzhen38 已提交
335 336 337 338 339 340
  virtual void TDMSample(const std::string tree_name,
                         const std::string tree_path,
                         const std::vector<uint16_t> tdm_layer_counts,
                         const uint16_t start_sample_layer,
                         const bool with_hierachy, const uint16_t seed_,
                         const uint16_t sample_slot);
341
  virtual void MergeByInsId();
342 343 344
  virtual void PreprocessInstance();
  virtual void PostprocessInstance();
  virtual void SetCurrentPhase(int current_phase);
345 346 347 348 349 350 351 352 353 354
  virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
                                         int read_thread_num,
                                         int consume_thread_num, int shard_num);
  virtual void ClearLocalTables() {
    for (auto& t : local_tables_) {
      t.clear();
      std::unordered_set<uint64_t>().swap(t);
    }
    std::vector<std::unordered_set<uint64_t>>().swap(local_tables_);
  }
355 356 357
  virtual void PreprocessChannel(
      const std::set<std::string>& slots_to_replace,
      std::unordered_set<uint16_t>& index_slot);  // NOLINT
358
  virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace);
359 360 361
  virtual void GetRandomData(
      const std::unordered_set<uint16_t>& slots_to_replace,
      std::vector<Record>* result);
362
  virtual ~MultiSlotDataset() {}
Y
yaoxuefeng 已提交
363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388
  virtual void GlobalShuffle(int thread_num = -1);
  virtual void DynamicAdjustReadersNum(int thread_num);
  virtual void PrepareTrain();

 protected:
  virtual int ReceiveFromClient(int msg_type, int client_id,
                                const std::string& msg);
};
// Dataset specialization whose data type is SlotRecord.
class SlotRecordDataset : public DatasetImpl<SlotRecord> {
 public:
  // NOTE(review): the result of SlotRecordPool() is discarded; presumably the
  // call is made for its side effect (e.g. creating the pool up front) --
  // confirm against its definition.
  SlotRecordDataset() { SlotRecordPool(); }
  ~SlotRecordDataset() override {}
  // create input channel
  void CreateChannel() override;
  // create readers
  void CreateReaders() override;
  // release memory
  void ReleaseMemory() override;
  void GlobalShuffle(int thread_num = -1) override;
  // discard_remaining_ins defaults to false, matching the declarations in
  // Dataset and DatasetImpl, so a one-argument call behaves the same
  // whichever static type it is made through.
  void DynamicAdjustChannelNum(int channel_num,
                               bool discard_remaining_ins = false) override;
  virtual void PrepareTrain();
  void DynamicAdjustReadersNum(int thread_num) override;

 protected:
  // NOTE(review): this declares a second member that hides
  // DatasetImpl::enable_heterps_ (initialized to false). Code in this class
  // sees this one (true); code in DatasetImpl still sees its own (false).
  // Confirm the split is intentional before relying on either value.
  bool enable_heterps_ = true;
};

D
dongdaxiang 已提交
391 392
}  // end namespace framework
}  // end namespace paddle