data_set.h 15.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *     Unless required by applicable law or agreed to in writing, software
 *     distributed under the License is distributed on an "AS IS" BASIS,
 *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *     See the License for the specific language governing permissions and
 *     limitations under the License. */

#pragma once

17
#include <ThreadPool.h>
18 19 20
#include <fstream>
#include <memory>
#include <mutex>  // NOLINT
21
#include <set>
22 23
#include <string>
#include <thread>  // NOLINT
24
#include <unordered_set>
X
xjqbest 已提交
25
#include <utility>
26
#include <vector>
Y
yaoxuefeng 已提交
27 28 29 30
#ifdef PADDLE_WITH_GLOO
#include <gloo/broadcast.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
31 32 33 34 35 36

#include "paddle/fluid/framework/data_feed.h"

namespace paddle {
namespace framework {

X
xjqbest 已提交
37 38 39 40 41 42 43 44 45 46
// Dataset is a abstract class, which defines user interfaces
// Example Usage:
//    Dataset* dataset = DatasetFactory::CreateDataset("InMemoryDataset")
//    dataset->SetFileList(std::vector<std::string>{"a.txt", "b.txt"})
//    dataset->SetThreadNum(1)
//    dataset->CreateReaders();
//    dataset->SetDataFeedDesc(your_data_feed_desc);
//    dataset->LoadIntoMemory();
//    dataset->SetTrainerNum(2);
//    dataset->GlobalShuffle();
47 48
class Dataset {
 public:
49 50
  Dataset() {}
  virtual ~Dataset() {}
W
wangzhen38 已提交
51 52 53 54 55 56 57
  // do sample
  virtual void TDMSample(const std::string tree_name,
                         const std::string tree_path,
                         const std::vector<uint16_t> tdm_layer_counts,
                         const uint16_t start_sample_layer,
                         const bool with_hierachy, const uint16_t seed_,
                         const uint16_t sample_slot) {}
58
  // set file list
59
  virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
60
  // set readers' num
61
  virtual void SetThreadNum(int thread_num) = 0;
62
  // set workers' num
63
  virtual void SetTrainerNum(int trainer_num) = 0;
X
xjqbest 已提交
64 65
  // set fleet send batch size
  virtual void SetFleetSendBatchSize(int64_t size) = 0;
T
Thunderbrook 已提交
66
  virtual void ReleaseMemoryFun() = 0;
67
  // set fs name and ugi
68 69
  virtual void SetHdfsConfig(const std::string& fs_name,
                             const std::string& fs_ugi) = 0;
70 71
  // set customized download command, such as using afs api
  virtual void SetDownloadCmd(const std::string& download_cmd) = 0;
72 73
  // set data fedd desc, which contains:
  //   data feed name, batch size, slots
74
  virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
J
jiaqi 已提交
75 76
  // set channel num
  virtual void SetChannelNum(int channel_num) = 0;
77 78 79
  // set parse ins id
  virtual void SetParseInsId(bool parse_ins_id) = 0;
  virtual void SetParseContent(bool parse_content) = 0;
80 81
  virtual void SetParseLogKey(bool parse_logkey) = 0;
  virtual void SetEnablePvMerge(bool enable_pv_merge) = 0;
82
  virtual bool EnablePvMerge() = 0;
83
  virtual void SetMergeBySid(bool is_merge) = 0;
84
  virtual void SetShuffleByUid(bool enable_shuffle_uid) = 0;
85
  // set merge by ins id
86
  virtual void SetMergeByInsId(int merge_size) = 0;
87
  virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0;
88 89
  // set fea eval mode
  virtual void SetFeaEval(bool fea_eval, int record_candidate_size) = 0;
90
  // get file list
91
  virtual const std::vector<std::string>& GetFileList() = 0;
92
  // get thread num
93
  virtual int GetThreadNum() = 0;
94
  // get worker num
95
  virtual int GetTrainerNum() = 0;
X
xjqbest 已提交
96 97
  // get fleet send batch size
  virtual int64_t GetFleetSendBatchSize() = 0;
X
xjqbest 已提交
98 99
  // get hdfs config
  virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
100 101
  // get download cmd
  virtual std::string GetDownloadCmd() = 0;
102
  // get data fedd desc
103
  virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
J
jiaqi 已提交
104 105
  // get channel num
  virtual int GetChannelNum() = 0;
106 107
  // get readers, the reader num depend both on thread num
  // and filelist size
J
jiaqi 已提交
108 109 110
  virtual std::vector<paddle::framework::DataFeed*> GetReaders() = 0;
  // create input channel and output channel
  virtual void CreateChannel() = 0;
111 112 113
  // register message handler between workers
  virtual void RegisterClientToClientMsgHandler() = 0;
  // load all data into memory
114
  virtual void LoadIntoMemory() = 0;
J
jiaqi 已提交
115 116 117 118
  // load all data into memory in async mode
  virtual void PreLoadIntoMemory() = 0;
  // wait async load done
  virtual void WaitPreLoadDone() = 0;
119 120 121
  // release all memory data
  virtual void ReleaseMemory() = 0;
  // local shuffle data
122
  virtual void LocalShuffle() = 0;
123
  // global shuffle data
124
  virtual void GlobalShuffle(int thread_num = -1) = 0;
125
  virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) = 0;
126
  // create readers
127
  virtual void CreateReaders() = 0;
128
  // destroy readers
129
  virtual void DestroyReaders() = 0;
130 131
  // get memory data size
  virtual int64_t GetMemoryDataSize() = 0;
132 133
  // get memory data size in input_pv_channel_
  virtual int64_t GetPvDataSize() = 0;
134 135
  // get shuffle data size
  virtual int64_t GetShuffleDataSize() = 0;
136 137
  // merge by ins id
  virtual void MergeByInsId() = 0;
138 139 140 141 142 143
  // merge pv instance
  virtual void PreprocessInstance() = 0;
  // divide pv instance
  virtual void PostprocessInstance() = 0;
  // only for untest
  virtual void SetCurrentPhase(int current_phase) = 0;
144 145 146 147 148
  virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
                                         int read_thread_num,
                                         int consume_thread_num,
                                         int shard_num) = 0;
  virtual void ClearLocalTables() = 0;
149 150 151 152 153 154
  // create preload readers
  virtual void CreatePreLoadReaders() = 0;
  // destroy preload readers after prelaod done
  virtual void DestroyPreLoadReaders() = 0;
  // set preload thread num
  virtual void SetPreLoadThreadNum(int thread_num) = 0;
H
hutuxian 已提交
155 156 157
  // seperate train thread and dataset thread
  virtual void DynamicAdjustChannelNum(int channel_num,
                                       bool discard_remaining_ins = false) = 0;
158 159 160
  virtual void DynamicAdjustReadersNum(int thread_num) = 0;
  // set fleet send sleep seconds
  virtual void SetFleetSendSleepSeconds(int seconds) = 0;
161

162 163 164 165 166
 protected:
  virtual int ReceiveFromClient(int msg_type, int client_id,
                                const std::string& msg) = 0;
};

X
xjqbest 已提交
167 168
// DatasetImpl is the implementation of Dataset,
// it holds memory data if user calls load_into_memory
169
template <typename T>
170 171 172
class DatasetImpl : public Dataset {
 public:
  DatasetImpl();
T
Thunderbrook 已提交
173 174 175 176 177
  virtual ~DatasetImpl() {
    if (release_thread_ != nullptr) {
      release_thread_->join();
    }
  }
178
  virtual void SetFileList(const std::vector<std::string>& filelist);
T
Thunderbrook 已提交
179
  virtual void ReleaseMemoryFun();
180 181
  virtual void SetThreadNum(int thread_num);
  virtual void SetTrainerNum(int trainer_num);
X
xjqbest 已提交
182
  virtual void SetFleetSendBatchSize(int64_t size);
183 184
  virtual void SetHdfsConfig(const std::string& fs_name,
                             const std::string& fs_ugi);
185
  virtual void SetDownloadCmd(const std::string& download_cmd);
186
  virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
J
jiaqi 已提交
187
  virtual void SetChannelNum(int channel_num);
188 189
  virtual void SetParseInsId(bool parse_ins_id);
  virtual void SetParseContent(bool parse_content);
190 191 192
  virtual void SetParseLogKey(bool parse_logkey);
  virtual void SetEnablePvMerge(bool enable_pv_merge);
  virtual void SetMergeBySid(bool is_merge);
193
  virtual void SetShuffleByUid(bool enable_shuffle_uid);
194

195
  virtual void SetMergeByInsId(int merge_size);
196
  virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns);
197
  virtual void SetFeaEval(bool fea_eval, int record_candidate_size);
D
dongdaxiang 已提交
198 199 200
  virtual const std::vector<std::string>& GetFileList() { return filelist_; }
  virtual int GetThreadNum() { return thread_num_; }
  virtual int GetTrainerNum() { return trainer_num_; }
H
hutuxian 已提交
201
  virtual Channel<T> GetInputChannel() { return input_channel_; }
202 203 204
  virtual void SetInputChannel(const Channel<T>& input_channel) {
    input_channel_ = input_channel;
  }
X
xjqbest 已提交
205
  virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; }
X
xjqbest 已提交
206 207 208
  virtual std::pair<std::string, std::string> GetHdfsConfig() {
    return std::make_pair(fs_name_, fs_ugi_);
  }
209
  virtual std::string GetDownloadCmd();
210 211 212
  virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
    return data_feed_desc_;
  }
J
jiaqi 已提交
213
  virtual int GetChannelNum() { return channel_num_; }
214
  virtual bool EnablePvMerge() { return enable_pv_merge_; }
J
jiaqi 已提交
215 216
  virtual std::vector<paddle::framework::DataFeed*> GetReaders();
  virtual void CreateChannel();
217
  virtual void RegisterClientToClientMsgHandler();
218
  virtual void LoadIntoMemory();
J
jiaqi 已提交
219 220
  virtual void PreLoadIntoMemory();
  virtual void WaitPreLoadDone();
221
  virtual void ReleaseMemory();
222
  virtual void LocalShuffle();
Y
yaoxuefeng 已提交
223
  virtual void GlobalShuffle(int thread_num = -1) {}
224
  virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) {}
225 226 227
  virtual const std::vector<T>& GetSlotsOriginalData() {
    return slots_shuffle_original_data_;
  }
228
  virtual void CreateReaders();
229
  virtual void DestroyReaders();
230
  virtual int64_t GetMemoryDataSize();
231
  virtual int64_t GetPvDataSize();
232
  virtual int64_t GetShuffleDataSize();
233
  virtual void MergeByInsId() {}
234 235 236
  virtual void PreprocessInstance() {}
  virtual void PostprocessInstance() {}
  virtual void SetCurrentPhase(int current_phase) {}
237 238 239 240 241
  virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
                                         int read_thread_num,
                                         int consume_thread_num,
                                         int shard_num) {}
  virtual void ClearLocalTables() {}
242 243 244
  virtual void CreatePreLoadReaders();
  virtual void DestroyPreLoadReaders();
  virtual void SetPreLoadThreadNum(int thread_num);
H
hutuxian 已提交
245 246
  virtual void DynamicAdjustChannelNum(int channel_num,
                                       bool discard_remaining_ins = false);
247 248
  virtual void DynamicAdjustReadersNum(int thread_num);
  virtual void SetFleetSendSleepSeconds(int seconds);
Y
yaoxuefeng 已提交
249 250 251 252 253
  /* for enable_heterps_
  virtual void EnableHeterps(bool enable_heterps) {
    enable_heterps_ = enable_heterps;
  }
  */
D
dongdaxiang 已提交
254

T
Thunderbrook 已提交
255 256 257 258 259 260 261 262 263 264 265 266 267 268
  std::vector<paddle::framework::Channel<T>>& GetMultiOutputChannel() {
    return multi_output_channel_;
  }

  std::vector<paddle::framework::Channel<T>>& GetCurOutputChannel() {
    if (cur_channel_ == 0) {
      return multi_output_channel_;
    } else {
      return multi_consume_channel_;
    }
  }

  Channel<T>& GetInputChannelRef() { return input_channel_; }

269
 protected:
D
dongdaxiang 已提交
270
  virtual int ReceiveFromClient(int msg_type, int client_id,
Y
yaoxuefeng 已提交
271 272 273 274
                                const std::string& msg) {
    // TODO(yaoxuefeng) for SlotRecordDataset
    return -1;
  }
275
  std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
276
  std::vector<std::shared_ptr<paddle::framework::DataFeed>> preload_readers_;
J
jiaqi 已提交
277
  paddle::framework::Channel<T> input_channel_;
278 279 280 281
  paddle::framework::Channel<PvInstance> input_pv_channel_;
  std::vector<paddle::framework::Channel<PvInstance>> multi_pv_output_;
  std::vector<paddle::framework::Channel<PvInstance>> multi_pv_consume_;

J
jiaqi 已提交
282 283 284
  int channel_num_;
  std::vector<paddle::framework::Channel<T>> multi_output_channel_;
  std::vector<paddle::framework::Channel<T>> multi_consume_channel_;
285
  std::vector<std::unordered_set<uint64_t>> local_tables_;
J
jiaqi 已提交
286 287 288 289
  // when read ins, we put ins from one channel to the other,
  // and when finish reading, we set cur_channel = 1 - cur_channel,
  // so if cur_channel=0, all data are in output_channel, else consume_channel
  int cur_channel_;
290 291
  std::vector<T> slots_shuffle_original_data_;
  RecordCandidateList slots_shuffle_rclist_;
292
  int thread_num_;
293
  int pull_sparse_to_local_thread_num_;
294 295
  paddle::framework::DataFeedDesc data_feed_desc_;
  int trainer_num_;
296 297
  std::vector<std::string> filelist_;
  size_t file_idx_;
H
hutuxian 已提交
298
  uint64_t total_fea_num_;
299
  std::mutex mutex_for_pick_file_;
H
hutuxian 已提交
300
  std::mutex mutex_for_fea_num_;
X
xjqbest 已提交
301 302
  std::string fs_name_;
  std::string fs_ugi_;
X
xjqbest 已提交
303
  int64_t fleet_send_batch_size_;
J
jiaqi 已提交
304 305
  int64_t fleet_send_sleep_seconds_;
  std::vector<std::thread> preload_threads_;
T
Thunderbrook 已提交
306
  std::thread* release_thread_ = nullptr;
307
  bool merge_by_insid_;
308 309
  bool parse_ins_id_;
  bool parse_content_;
310 311
  bool parse_logkey_;
  bool merge_by_sid_;
312 313
  bool shuffle_by_uid_;
  bool parse_uid_;
314 315
  bool enable_pv_merge_;  // True means to merge pv
  int current_phase_;     // 1 join, 0 update
316
  size_t merge_size_;
317
  bool slots_shuffle_fea_eval_ = false;
318
  bool gen_uni_feasigns_ = false;
319
  int preload_thread_num_;
320 321
  std::mutex global_index_mutex_;
  int64_t global_index_ = 0;
322
  std::vector<std::shared_ptr<ThreadPool>> consume_task_pool_;
323
  std::vector<T> input_records_;  // only for paddleboxdatafeed
Y
yaoxuefeng 已提交
324
  bool enable_heterps_ = false;
325 326
};

327
// use std::vector<MultiSlotType> or Record as data type
J
jiaqi 已提交
328
class MultiSlotDataset : public DatasetImpl<Record> {
329 330
 public:
  MultiSlotDataset() {}
W
wangzhen38 已提交
331 332 333 334 335 336
  virtual void TDMSample(const std::string tree_name,
                         const std::string tree_path,
                         const std::vector<uint16_t> tdm_layer_counts,
                         const uint16_t start_sample_layer,
                         const bool with_hierachy, const uint16_t seed_,
                         const uint16_t sample_slot);
337
  virtual void MergeByInsId();
338 339 340
  virtual void PreprocessInstance();
  virtual void PostprocessInstance();
  virtual void SetCurrentPhase(int current_phase);
341 342 343 344 345 346 347 348 349 350
  virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
                                         int read_thread_num,
                                         int consume_thread_num, int shard_num);
  virtual void ClearLocalTables() {
    for (auto& t : local_tables_) {
      t.clear();
      std::unordered_set<uint64_t>().swap(t);
    }
    std::vector<std::unordered_set<uint64_t>>().swap(local_tables_);
  }
351 352 353
  virtual void PreprocessChannel(
      const std::set<std::string>& slots_to_replace,
      std::unordered_set<uint16_t>& index_slot);  // NOLINT
354
  virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace);
355 356 357
  virtual void GetRandomData(
      const std::unordered_set<uint16_t>& slots_to_replace,
      std::vector<Record>* result);
358
  virtual ~MultiSlotDataset() {}
Y
yaoxuefeng 已提交
359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384
  virtual void GlobalShuffle(int thread_num = -1);
  virtual void DynamicAdjustReadersNum(int thread_num);
  virtual void PrepareTrain();

 protected:
  virtual int ReceiveFromClient(int msg_type, int client_id,
                                const std::string& msg);
};
// Dataset specialisation that uses SlotRecord as its data type.
class SlotRecordDataset : public DatasetImpl<SlotRecord> {
 public:
  // NOTE(review): SlotRecordPool() is called only for its side effect,
  // presumably to force creation of the shared pool before any record is
  // allocated -- confirm against its definition.
  SlotRecordDataset() { SlotRecordPool(); }
  ~SlotRecordDataset() override {}
  // create input channel
  void CreateChannel() override;
  // create readers
  void CreateReaders() override;
  // release memory
  void ReleaseMemory() override;
  void GlobalShuffle(int thread_num = -1) override;
  // The default argument mirrors the one on Dataset/DatasetImpl so that a
  // one-argument call compiles regardless of the static type used.
  void DynamicAdjustChannelNum(int channel_num,
                               bool discard_remaining_ins = false) override;
  virtual void PrepareTrain();
  void DynamicAdjustReadersNum(int thread_num) override;

 protected:
  // NOTE(review): this declares a second member that hides
  // DatasetImpl<SlotRecord>::enable_heterps_ (initialised to false there).
  // Code in the base class keeps reading the base member, code in this
  // class reads this one -- confirm the split is intentional.
  bool enable_heterps_ = true;
};

D
dongdaxiang 已提交
387 388
}  // end namespace framework
}  // end namespace paddle