/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *     Unless required by applicable law or agreed to in writing, software
 *     distributed under the License is distributed on an "AS IS" BASIS,
 *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *     See the License for the specific language governing permissions and
 *     limitations under the License. */

#pragma once

#include <ThreadPool.h>
#include <fstream>
#include <memory>
#include <mutex>  // NOLINT
#include <set>
#include <string>
#include <thread>  // NOLINT
#include <unordered_set>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_GLOO
#include <gloo/broadcast.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif

#include "paddle/fluid/framework/data_feed.h"

namespace paddle {
namespace framework {

// Dataset is a abstract class, which defines user interfaces
// Example Usage:
//    Dataset* dataset = DatasetFactory::CreateDataset("InMemoryDataset")
//    dataset->SetFileList(std::vector<std::string>{"a.txt", "b.txt"})
//    dataset->SetThreadNum(1)
//    dataset->CreateReaders();
//    dataset->SetDataFeedDesc(your_data_feed_desc);
//    dataset->LoadIntoMemory();
//    dataset->SetTrainerNum(2);
//    dataset->GlobalShuffle();
47 48
class Dataset {
 public:
49 50
  Dataset() {}
  virtual ~Dataset() {}
W
wangzhen38 已提交
51 52 53 54 55 56 57
  // do sample
  virtual void TDMSample(const std::string tree_name,
                         const std::string tree_path,
                         const std::vector<uint16_t> tdm_layer_counts,
                         const uint16_t start_sample_layer,
                         const bool with_hierachy, const uint16_t seed_,
                         const uint16_t sample_slot) {}
58
  // set file list
59
  virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
60
  // set readers' num
61
  virtual void SetThreadNum(int thread_num) = 0;
62
  // set workers' num
63
  virtual void SetTrainerNum(int trainer_num) = 0;
X
xjqbest 已提交
64 65
  // set fleet send batch size
  virtual void SetFleetSendBatchSize(int64_t size) = 0;
T
Thunderbrook 已提交
66
  virtual void ReleaseMemoryFun() = 0;
67
  // set fs name and ugi
68 69
  virtual void SetHdfsConfig(const std::string& fs_name,
                             const std::string& fs_ugi) = 0;
70 71
  // set customized download command, such as using afs api
  virtual void SetDownloadCmd(const std::string& download_cmd) = 0;
72 73
  // set data fedd desc, which contains:
  //   data feed name, batch size, slots
74
  virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
J
jiaqi 已提交
75 76
  // set channel num
  virtual void SetChannelNum(int channel_num) = 0;
77 78 79
  // set parse ins id
  virtual void SetParseInsId(bool parse_ins_id) = 0;
  virtual void SetParseContent(bool parse_content) = 0;
80 81
  virtual void SetParseLogKey(bool parse_logkey) = 0;
  virtual void SetEnablePvMerge(bool enable_pv_merge) = 0;
82
  virtual bool EnablePvMerge() = 0;
83
  virtual void SetMergeBySid(bool is_merge) = 0;
84
  // set merge by ins id
85
  virtual void SetMergeByInsId(int merge_size) = 0;
86
  virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0;
87 88
  // set fea eval mode
  virtual void SetFeaEval(bool fea_eval, int record_candidate_size) = 0;
89
  // get file list
90
  virtual const std::vector<std::string>& GetFileList() = 0;
91
  // get thread num
92
  virtual int GetThreadNum() = 0;
93
  // get worker num
94
  virtual int GetTrainerNum() = 0;
X
xjqbest 已提交
95 96
  // get fleet send batch size
  virtual int64_t GetFleetSendBatchSize() = 0;
X
xjqbest 已提交
97 98
  // get hdfs config
  virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
99 100
  // get download cmd
  virtual std::string GetDownloadCmd() = 0;
101
  // get data fedd desc
102
  virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
J
jiaqi 已提交
103 104
  // get channel num
  virtual int GetChannelNum() = 0;
105 106
  // get readers, the reader num depend both on thread num
  // and filelist size
J
jiaqi 已提交
107 108 109
  virtual std::vector<paddle::framework::DataFeed*> GetReaders() = 0;
  // create input channel and output channel
  virtual void CreateChannel() = 0;
110 111 112
  // register message handler between workers
  virtual void RegisterClientToClientMsgHandler() = 0;
  // load all data into memory
113
  virtual void LoadIntoMemory() = 0;
J
jiaqi 已提交
114 115 116 117
  // load all data into memory in async mode
  virtual void PreLoadIntoMemory() = 0;
  // wait async load done
  virtual void WaitPreLoadDone() = 0;
118 119 120
  // release all memory data
  virtual void ReleaseMemory() = 0;
  // local shuffle data
121
  virtual void LocalShuffle() = 0;
122
  // global shuffle data
123
  virtual void GlobalShuffle(int thread_num = -1) = 0;
124
  virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) = 0;
125
  // create readers
126
  virtual void CreateReaders() = 0;
127
  // destroy readers
128
  virtual void DestroyReaders() = 0;
129 130
  // get memory data size
  virtual int64_t GetMemoryDataSize() = 0;
131 132
  // get memory data size in input_pv_channel_
  virtual int64_t GetPvDataSize() = 0;
133 134
  // get shuffle data size
  virtual int64_t GetShuffleDataSize() = 0;
135 136
  // merge by ins id
  virtual void MergeByInsId() = 0;
137 138 139 140 141 142
  // merge pv instance
  virtual void PreprocessInstance() = 0;
  // divide pv instance
  virtual void PostprocessInstance() = 0;
  // only for untest
  virtual void SetCurrentPhase(int current_phase) = 0;
143 144 145 146 147
  virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
                                         int read_thread_num,
                                         int consume_thread_num,
                                         int shard_num) = 0;
  virtual void ClearLocalTables() = 0;
148 149 150 151 152 153
  // create preload readers
  virtual void CreatePreLoadReaders() = 0;
  // destroy preload readers after prelaod done
  virtual void DestroyPreLoadReaders() = 0;
  // set preload thread num
  virtual void SetPreLoadThreadNum(int thread_num) = 0;
H
hutuxian 已提交
154 155 156
  // seperate train thread and dataset thread
  virtual void DynamicAdjustChannelNum(int channel_num,
                                       bool discard_remaining_ins = false) = 0;
157 158 159
  virtual void DynamicAdjustReadersNum(int thread_num) = 0;
  // set fleet send sleep seconds
  virtual void SetFleetSendSleepSeconds(int seconds) = 0;
160

161 162 163 164 165
 protected:
  virtual int ReceiveFromClient(int msg_type, int client_id,
                                const std::string& msg) = 0;
};

// DatasetImpl is the implementation of Dataset,
// it holds memory data if user calls load_into_memory
168
template <typename T>
169 170 171
class DatasetImpl : public Dataset {
 public:
  DatasetImpl();
T
Thunderbrook 已提交
172 173 174 175 176
  virtual ~DatasetImpl() {
    if (release_thread_ != nullptr) {
      release_thread_->join();
    }
  }
177
  virtual void SetFileList(const std::vector<std::string>& filelist);
T
Thunderbrook 已提交
178
  virtual void ReleaseMemoryFun();
179 180
  virtual void SetThreadNum(int thread_num);
  virtual void SetTrainerNum(int trainer_num);
X
xjqbest 已提交
181
  virtual void SetFleetSendBatchSize(int64_t size);
182 183
  virtual void SetHdfsConfig(const std::string& fs_name,
                             const std::string& fs_ugi);
184
  virtual void SetDownloadCmd(const std::string& download_cmd);
185
  virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
J
jiaqi 已提交
186
  virtual void SetChannelNum(int channel_num);
187 188
  virtual void SetParseInsId(bool parse_ins_id);
  virtual void SetParseContent(bool parse_content);
189 190 191 192
  virtual void SetParseLogKey(bool parse_logkey);
  virtual void SetEnablePvMerge(bool enable_pv_merge);
  virtual void SetMergeBySid(bool is_merge);

193
  virtual void SetMergeByInsId(int merge_size);
194
  virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns);
195
  virtual void SetFeaEval(bool fea_eval, int record_candidate_size);
D
dongdaxiang 已提交
196 197 198
  virtual const std::vector<std::string>& GetFileList() { return filelist_; }
  virtual int GetThreadNum() { return thread_num_; }
  virtual int GetTrainerNum() { return trainer_num_; }
H
hutuxian 已提交
199
  virtual Channel<T> GetInputChannel() { return input_channel_; }
200 201 202
  virtual void SetInputChannel(const Channel<T>& input_channel) {
    input_channel_ = input_channel;
  }
X
xjqbest 已提交
203
  virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; }
X
xjqbest 已提交
204 205 206
  virtual std::pair<std::string, std::string> GetHdfsConfig() {
    return std::make_pair(fs_name_, fs_ugi_);
  }
207
  virtual std::string GetDownloadCmd();
208 209 210
  virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
    return data_feed_desc_;
  }
J
jiaqi 已提交
211
  virtual int GetChannelNum() { return channel_num_; }
212
  virtual bool EnablePvMerge() { return enable_pv_merge_; }
J
jiaqi 已提交
213 214
  virtual std::vector<paddle::framework::DataFeed*> GetReaders();
  virtual void CreateChannel();
215
  virtual void RegisterClientToClientMsgHandler();
216
  virtual void LoadIntoMemory();
J
jiaqi 已提交
217 218
  virtual void PreLoadIntoMemory();
  virtual void WaitPreLoadDone();
219
  virtual void ReleaseMemory();
220
  virtual void LocalShuffle();
Y
yaoxuefeng 已提交
221
  virtual void GlobalShuffle(int thread_num = -1) {}
222
  virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) {}
223 224 225
  virtual const std::vector<T>& GetSlotsOriginalData() {
    return slots_shuffle_original_data_;
  }
226
  virtual void CreateReaders();
227
  virtual void DestroyReaders();
228
  virtual int64_t GetMemoryDataSize();
229
  virtual int64_t GetPvDataSize();
230
  virtual int64_t GetShuffleDataSize();
231
  virtual void MergeByInsId() {}
232 233 234
  virtual void PreprocessInstance() {}
  virtual void PostprocessInstance() {}
  virtual void SetCurrentPhase(int current_phase) {}
235 236 237 238 239
  virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
                                         int read_thread_num,
                                         int consume_thread_num,
                                         int shard_num) {}
  virtual void ClearLocalTables() {}
240 241 242
  virtual void CreatePreLoadReaders();
  virtual void DestroyPreLoadReaders();
  virtual void SetPreLoadThreadNum(int thread_num);
H
hutuxian 已提交
243 244
  virtual void DynamicAdjustChannelNum(int channel_num,
                                       bool discard_remaining_ins = false);
245 246
  virtual void DynamicAdjustReadersNum(int thread_num);
  virtual void SetFleetSendSleepSeconds(int seconds);
Y
yaoxuefeng 已提交
247 248 249 250 251
  /* for enable_heterps_
  virtual void EnableHeterps(bool enable_heterps) {
    enable_heterps_ = enable_heterps;
  }
  */
D
dongdaxiang 已提交
252

T
Thunderbrook 已提交
253 254 255 256 257 258 259 260 261 262 263 264 265 266
  std::vector<paddle::framework::Channel<T>>& GetMultiOutputChannel() {
    return multi_output_channel_;
  }

  std::vector<paddle::framework::Channel<T>>& GetCurOutputChannel() {
    if (cur_channel_ == 0) {
      return multi_output_channel_;
    } else {
      return multi_consume_channel_;
    }
  }

  Channel<T>& GetInputChannelRef() { return input_channel_; }

267
 protected:
D
dongdaxiang 已提交
268
  virtual int ReceiveFromClient(int msg_type, int client_id,
Y
yaoxuefeng 已提交
269 270 271 272
                                const std::string& msg) {
    // TODO(yaoxuefeng) for SlotRecordDataset
    return -1;
  }
273
  std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
274
  std::vector<std::shared_ptr<paddle::framework::DataFeed>> preload_readers_;
J
jiaqi 已提交
275
  paddle::framework::Channel<T> input_channel_;
276 277 278 279
  paddle::framework::Channel<PvInstance> input_pv_channel_;
  std::vector<paddle::framework::Channel<PvInstance>> multi_pv_output_;
  std::vector<paddle::framework::Channel<PvInstance>> multi_pv_consume_;

J
jiaqi 已提交
280 281 282
  int channel_num_;
  std::vector<paddle::framework::Channel<T>> multi_output_channel_;
  std::vector<paddle::framework::Channel<T>> multi_consume_channel_;
283
  std::vector<std::unordered_set<uint64_t>> local_tables_;
J
jiaqi 已提交
284 285 286 287
  // when read ins, we put ins from one channel to the other,
  // and when finish reading, we set cur_channel = 1 - cur_channel,
  // so if cur_channel=0, all data are in output_channel, else consume_channel
  int cur_channel_;
288 289
  std::vector<T> slots_shuffle_original_data_;
  RecordCandidateList slots_shuffle_rclist_;
290
  int thread_num_;
291
  int pull_sparse_to_local_thread_num_;
292 293
  paddle::framework::DataFeedDesc data_feed_desc_;
  int trainer_num_;
294 295
  std::vector<std::string> filelist_;
  size_t file_idx_;
H
hutuxian 已提交
296
  uint64_t total_fea_num_;
297
  std::mutex mutex_for_pick_file_;
H
hutuxian 已提交
298
  std::mutex mutex_for_fea_num_;
X
xjqbest 已提交
299 300
  std::string fs_name_;
  std::string fs_ugi_;
X
xjqbest 已提交
301
  int64_t fleet_send_batch_size_;
J
jiaqi 已提交
302 303
  int64_t fleet_send_sleep_seconds_;
  std::vector<std::thread> preload_threads_;
T
Thunderbrook 已提交
304
  std::thread* release_thread_ = nullptr;
305
  bool merge_by_insid_;
306 307
  bool parse_ins_id_;
  bool parse_content_;
308 309 310 311
  bool parse_logkey_;
  bool merge_by_sid_;
  bool enable_pv_merge_;  // True means to merge pv
  int current_phase_;     // 1 join, 0 update
312
  size_t merge_size_;
313
  bool slots_shuffle_fea_eval_ = false;
314
  bool gen_uni_feasigns_ = false;
315
  int preload_thread_num_;
316 317
  std::mutex global_index_mutex_;
  int64_t global_index_ = 0;
318
  std::vector<std::shared_ptr<ThreadPool>> consume_task_pool_;
319
  std::vector<T> input_records_;  // only for paddleboxdatafeed
Y
yaoxuefeng 已提交
320
  bool enable_heterps_ = false;
321 322
};

// use std::vector<MultiSlotType> or Record as data type
J
jiaqi 已提交
324
class MultiSlotDataset : public DatasetImpl<Record> {
325 326
 public:
  MultiSlotDataset() {}
W
wangzhen38 已提交
327 328 329 330 331 332
  virtual void TDMSample(const std::string tree_name,
                         const std::string tree_path,
                         const std::vector<uint16_t> tdm_layer_counts,
                         const uint16_t start_sample_layer,
                         const bool with_hierachy, const uint16_t seed_,
                         const uint16_t sample_slot);
333
  virtual void MergeByInsId();
334 335 336
  virtual void PreprocessInstance();
  virtual void PostprocessInstance();
  virtual void SetCurrentPhase(int current_phase);
337 338 339 340 341 342 343 344 345 346
  virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
                                         int read_thread_num,
                                         int consume_thread_num, int shard_num);
  virtual void ClearLocalTables() {
    for (auto& t : local_tables_) {
      t.clear();
      std::unordered_set<uint64_t>().swap(t);
    }
    std::vector<std::unordered_set<uint64_t>>().swap(local_tables_);
  }
347 348 349
  virtual void PreprocessChannel(
      const std::set<std::string>& slots_to_replace,
      std::unordered_set<uint16_t>& index_slot);  // NOLINT
350
  virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace);
351 352 353
  virtual void GetRandomData(
      const std::unordered_set<uint16_t>& slots_to_replace,
      std::vector<Record>* result);
354
  virtual ~MultiSlotDataset() {}
Y
yaoxuefeng 已提交
355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380
  virtual void GlobalShuffle(int thread_num = -1);
  virtual void DynamicAdjustReadersNum(int thread_num);
  virtual void PrepareTrain();

 protected:
  virtual int ReceiveFromClient(int msg_type, int client_id,
                                const std::string& msg);
};
// SlotRecordDataset: a DatasetImpl that uses SlotRecord as its in-memory
// data type.
class SlotRecordDataset : public DatasetImpl<SlotRecord> {
 public:
  // SlotRecordPool() is called only for its side effect. Presumably it
  // ensures the shared SlotRecord pool exists; TODO confirm.
  SlotRecordDataset() { SlotRecordPool(); }
  virtual ~SlotRecordDataset() {}
  // create input channel
  virtual void CreateChannel();
  // create readers
  virtual void CreateReaders();
  // release memory
  virtual void ReleaseMemory();
  virtual void GlobalShuffle(int thread_num = -1);
  // The default argument repeats Dataset::DynamicAdjustChannelNum's. Default
  // arguments of virtual functions are bound by the static type, so without
  // it a call through a SlotRecordDataset pointer or reference could not
  // omit the second argument, although a call through Dataset* can.
  virtual void DynamicAdjustChannelNum(int channel_num,
                                       bool discard_remaining_ins = false);
  virtual void PrepareTrain();
  virtual void DynamicAdjustReadersNum(int thread_num);

 protected:
  // NOTE(review): this declares a *new* member that hides
  // DatasetImpl<SlotRecord>::enable_heterps_ (declared there as false).
  // SlotRecordDataset's own member functions see `true`, while code in
  // DatasetImpl<SlotRecord> that reads enable_heterps_ still sees the base
  // member. Confirm which behavior is intended before consolidating.
  bool enable_heterps_ = true;
};

}  // end namespace framework
}  // end namespace paddle