// data_set.h
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *     Unless required by applicable law or agreed to in writing, software
 *     distributed under the License is distributed on an "AS IS" BASIS,
 *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *     See the License for the specific language governing permissions and
 *     limitations under the License. */

#pragma once

#include <ThreadPool.h>
#include <cstdint>
#include <fstream>
#include <memory>
#include <mutex>  // NOLINT
#include <set>
#include <string>
#include <thread>  // NOLINT
#include <unordered_set>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_GLOO
#include <gloo/broadcast.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif

#include "paddle/fluid/framework/data_feed.h"

namespace paddle {
namespace framework {

X
xjqbest 已提交
37 38 39 40 41 42 43 44 45 46
// Dataset is a abstract class, which defines user interfaces
// Example Usage:
//    Dataset* dataset = DatasetFactory::CreateDataset("InMemoryDataset")
//    dataset->SetFileList(std::vector<std::string>{"a.txt", "b.txt"})
//    dataset->SetThreadNum(1)
//    dataset->CreateReaders();
//    dataset->SetDataFeedDesc(your_data_feed_desc);
//    dataset->LoadIntoMemory();
//    dataset->SetTrainerNum(2);
//    dataset->GlobalShuffle();
47 48
class Dataset {
 public:
49 50
  Dataset() {}
  virtual ~Dataset() {}
51
  // set file list
52
  virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
53
  // set readers' num
54
  virtual void SetThreadNum(int thread_num) = 0;
55
  // set workers' num
56
  virtual void SetTrainerNum(int trainer_num) = 0;
X
xjqbest 已提交
57 58
  // set fleet send batch size
  virtual void SetFleetSendBatchSize(int64_t size) = 0;
59
  // set fs name and ugi
60 61
  virtual void SetHdfsConfig(const std::string& fs_name,
                             const std::string& fs_ugi) = 0;
62 63
  // set customized download command, such as using afs api
  virtual void SetDownloadCmd(const std::string& download_cmd) = 0;
64 65
  // set data fedd desc, which contains:
  //   data feed name, batch size, slots
66
  virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
J
jiaqi 已提交
67 68
  // set channel num
  virtual void SetChannelNum(int channel_num) = 0;
69 70 71
  // set parse ins id
  virtual void SetParseInsId(bool parse_ins_id) = 0;
  virtual void SetParseContent(bool parse_content) = 0;
72 73
  virtual void SetParseLogKey(bool parse_logkey) = 0;
  virtual void SetEnablePvMerge(bool enable_pv_merge) = 0;
74
  virtual bool EnablePvMerge() = 0;
75
  virtual void SetMergeBySid(bool is_merge) = 0;
76
  // set merge by ins id
77
  virtual void SetMergeByInsId(int merge_size) = 0;
78
  virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0;
79 80
  // set fea eval mode
  virtual void SetFeaEval(bool fea_eval, int record_candidate_size) = 0;
81
  // get file list
82
  virtual const std::vector<std::string>& GetFileList() = 0;
83
  // get thread num
84
  virtual int GetThreadNum() = 0;
85
  // get worker num
86
  virtual int GetTrainerNum() = 0;
X
xjqbest 已提交
87 88
  // get fleet send batch size
  virtual int64_t GetFleetSendBatchSize() = 0;
X
xjqbest 已提交
89 90
  // get hdfs config
  virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
91 92
  // get download cmd
  virtual std::string GetDownloadCmd() = 0;
93
  // get data fedd desc
94
  virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
J
jiaqi 已提交
95 96
  // get channel num
  virtual int GetChannelNum() = 0;
97 98
  // get readers, the reader num depend both on thread num
  // and filelist size
J
jiaqi 已提交
99 100 101
  virtual std::vector<paddle::framework::DataFeed*> GetReaders() = 0;
  // create input channel and output channel
  virtual void CreateChannel() = 0;
102 103 104
  // register message handler between workers
  virtual void RegisterClientToClientMsgHandler() = 0;
  // load all data into memory
105
  virtual void LoadIntoMemory() = 0;
J
jiaqi 已提交
106 107 108 109
  // load all data into memory in async mode
  virtual void PreLoadIntoMemory() = 0;
  // wait async load done
  virtual void WaitPreLoadDone() = 0;
110 111 112
  // release all memory data
  virtual void ReleaseMemory() = 0;
  // local shuffle data
113
  virtual void LocalShuffle() = 0;
114
  // global shuffle data
115
  virtual void GlobalShuffle(int thread_num = -1) = 0;
116
  virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) = 0;
117
  // create readers
118
  virtual void CreateReaders() = 0;
119
  // destroy readers
120
  virtual void DestroyReaders() = 0;
121 122
  // get memory data size
  virtual int64_t GetMemoryDataSize() = 0;
123 124
  // get memory data size in input_pv_channel_
  virtual int64_t GetPvDataSize() = 0;
125 126
  // get shuffle data size
  virtual int64_t GetShuffleDataSize() = 0;
127 128
  // merge by ins id
  virtual void MergeByInsId() = 0;
129 130 131 132 133 134
  // merge pv instance
  virtual void PreprocessInstance() = 0;
  // divide pv instance
  virtual void PostprocessInstance() = 0;
  // only for untest
  virtual void SetCurrentPhase(int current_phase) = 0;
135 136 137 138 139
  virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
                                         int read_thread_num,
                                         int consume_thread_num,
                                         int shard_num) = 0;
  virtual void ClearLocalTables() = 0;
140 141 142 143 144 145
  // create preload readers
  virtual void CreatePreLoadReaders() = 0;
  // destroy preload readers after prelaod done
  virtual void DestroyPreLoadReaders() = 0;
  // set preload thread num
  virtual void SetPreLoadThreadNum(int thread_num) = 0;
H
hutuxian 已提交
146 147 148
  // seperate train thread and dataset thread
  virtual void DynamicAdjustChannelNum(int channel_num,
                                       bool discard_remaining_ins = false) = 0;
149 150 151
  virtual void DynamicAdjustReadersNum(int thread_num) = 0;
  // set fleet send sleep seconds
  virtual void SetFleetSendSleepSeconds(int seconds) = 0;
152

153 154 155 156 157
 protected:
  virtual int ReceiveFromClient(int msg_type, int client_id,
                                const std::string& msg) = 0;
};

X
xjqbest 已提交
158 159
// DatasetImpl is the implementation of Dataset,
// it holds memory data if user calls load_into_memory
160
template <typename T>
161 162 163 164
class DatasetImpl : public Dataset {
 public:
  DatasetImpl();
  virtual ~DatasetImpl() {}
165 166 167 168

  virtual void SetFileList(const std::vector<std::string>& filelist);
  virtual void SetThreadNum(int thread_num);
  virtual void SetTrainerNum(int trainer_num);
X
xjqbest 已提交
169
  virtual void SetFleetSendBatchSize(int64_t size);
170 171
  virtual void SetHdfsConfig(const std::string& fs_name,
                             const std::string& fs_ugi);
172
  virtual void SetDownloadCmd(const std::string& download_cmd);
173
  virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
J
jiaqi 已提交
174
  virtual void SetChannelNum(int channel_num);
175 176
  virtual void SetParseInsId(bool parse_ins_id);
  virtual void SetParseContent(bool parse_content);
177 178 179 180
  virtual void SetParseLogKey(bool parse_logkey);
  virtual void SetEnablePvMerge(bool enable_pv_merge);
  virtual void SetMergeBySid(bool is_merge);

181
  virtual void SetMergeByInsId(int merge_size);
182
  virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns);
183
  virtual void SetFeaEval(bool fea_eval, int record_candidate_size);
D
dongdaxiang 已提交
184 185 186
  virtual const std::vector<std::string>& GetFileList() { return filelist_; }
  virtual int GetThreadNum() { return thread_num_; }
  virtual int GetTrainerNum() { return trainer_num_; }
H
hutuxian 已提交
187
  virtual Channel<T> GetInputChannel() { return input_channel_; }
188 189 190
  virtual void SetInputChannel(const Channel<T>& input_channel) {
    input_channel_ = input_channel;
  }
X
xjqbest 已提交
191
  virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; }
X
xjqbest 已提交
192 193 194
  virtual std::pair<std::string, std::string> GetHdfsConfig() {
    return std::make_pair(fs_name_, fs_ugi_);
  }
195
  virtual std::string GetDownloadCmd();
196 197 198
  virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
    return data_feed_desc_;
  }
J
jiaqi 已提交
199
  virtual int GetChannelNum() { return channel_num_; }
200
  virtual bool EnablePvMerge() { return enable_pv_merge_; }
J
jiaqi 已提交
201 202
  virtual std::vector<paddle::framework::DataFeed*> GetReaders();
  virtual void CreateChannel();
203
  virtual void RegisterClientToClientMsgHandler();
204
  virtual void LoadIntoMemory();
J
jiaqi 已提交
205 206
  virtual void PreLoadIntoMemory();
  virtual void WaitPreLoadDone();
207
  virtual void ReleaseMemory();
208
  virtual void LocalShuffle();
Y
yaoxuefeng 已提交
209
  virtual void GlobalShuffle(int thread_num = -1) {}
210
  virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) {}
211 212 213
  virtual const std::vector<T>& GetSlotsOriginalData() {
    return slots_shuffle_original_data_;
  }
214
  virtual void CreateReaders();
215
  virtual void DestroyReaders();
216
  virtual int64_t GetMemoryDataSize();
217
  virtual int64_t GetPvDataSize();
218
  virtual int64_t GetShuffleDataSize();
219
  virtual void MergeByInsId() {}
220 221 222
  virtual void PreprocessInstance() {}
  virtual void PostprocessInstance() {}
  virtual void SetCurrentPhase(int current_phase) {}
223 224 225 226 227
  virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
                                         int read_thread_num,
                                         int consume_thread_num,
                                         int shard_num) {}
  virtual void ClearLocalTables() {}
228 229 230
  virtual void CreatePreLoadReaders();
  virtual void DestroyPreLoadReaders();
  virtual void SetPreLoadThreadNum(int thread_num);
H
hutuxian 已提交
231 232
  virtual void DynamicAdjustChannelNum(int channel_num,
                                       bool discard_remaining_ins = false);
233 234
  virtual void DynamicAdjustReadersNum(int thread_num);
  virtual void SetFleetSendSleepSeconds(int seconds);
Y
yaoxuefeng 已提交
235 236 237 238 239
  /* for enable_heterps_
  virtual void EnableHeterps(bool enable_heterps) {
    enable_heterps_ = enable_heterps;
  }
  */
D
dongdaxiang 已提交
240

T
Thunderbrook 已提交
241 242 243 244 245 246 247 248 249 250 251 252 253 254
  std::vector<paddle::framework::Channel<T>>& GetMultiOutputChannel() {
    return multi_output_channel_;
  }

  std::vector<paddle::framework::Channel<T>>& GetCurOutputChannel() {
    if (cur_channel_ == 0) {
      return multi_output_channel_;
    } else {
      return multi_consume_channel_;
    }
  }

  Channel<T>& GetInputChannelRef() { return input_channel_; }

255
 protected:
D
dongdaxiang 已提交
256
  virtual int ReceiveFromClient(int msg_type, int client_id,
Y
yaoxuefeng 已提交
257 258 259 260
                                const std::string& msg) {
    // TODO(yaoxuefeng) for SlotRecordDataset
    return -1;
  }
261
  std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
262
  std::vector<std::shared_ptr<paddle::framework::DataFeed>> preload_readers_;
J
jiaqi 已提交
263
  paddle::framework::Channel<T> input_channel_;
264 265 266 267
  paddle::framework::Channel<PvInstance> input_pv_channel_;
  std::vector<paddle::framework::Channel<PvInstance>> multi_pv_output_;
  std::vector<paddle::framework::Channel<PvInstance>> multi_pv_consume_;

J
jiaqi 已提交
268 269 270
  int channel_num_;
  std::vector<paddle::framework::Channel<T>> multi_output_channel_;
  std::vector<paddle::framework::Channel<T>> multi_consume_channel_;
271
  std::vector<std::unordered_set<uint64_t>> local_tables_;
J
jiaqi 已提交
272 273 274 275
  // when read ins, we put ins from one channel to the other,
  // and when finish reading, we set cur_channel = 1 - cur_channel,
  // so if cur_channel=0, all data are in output_channel, else consume_channel
  int cur_channel_;
276 277
  std::vector<T> slots_shuffle_original_data_;
  RecordCandidateList slots_shuffle_rclist_;
278
  int thread_num_;
279
  int pull_sparse_to_local_thread_num_;
280 281
  paddle::framework::DataFeedDesc data_feed_desc_;
  int trainer_num_;
282 283
  std::vector<std::string> filelist_;
  size_t file_idx_;
H
hutuxian 已提交
284
  uint64_t total_fea_num_;
285
  std::mutex mutex_for_pick_file_;
H
hutuxian 已提交
286
  std::mutex mutex_for_fea_num_;
X
xjqbest 已提交
287 288
  std::string fs_name_;
  std::string fs_ugi_;
X
xjqbest 已提交
289
  int64_t fleet_send_batch_size_;
J
jiaqi 已提交
290 291
  int64_t fleet_send_sleep_seconds_;
  std::vector<std::thread> preload_threads_;
292
  bool merge_by_insid_;
293 294
  bool parse_ins_id_;
  bool parse_content_;
295 296 297 298
  bool parse_logkey_;
  bool merge_by_sid_;
  bool enable_pv_merge_;  // True means to merge pv
  int current_phase_;     // 1 join, 0 update
299
  size_t merge_size_;
300
  bool slots_shuffle_fea_eval_ = false;
301
  bool gen_uni_feasigns_ = false;
302
  int preload_thread_num_;
303 304
  std::mutex global_index_mutex_;
  int64_t global_index_ = 0;
305
  std::vector<std::shared_ptr<ThreadPool>> consume_task_pool_;
306
  std::vector<T> input_records_;  // only for paddleboxdatafeed
Y
yaoxuefeng 已提交
307
  bool enable_heterps_ = false;
308 309
};

310
// use std::vector<MultiSlotType> or Record as data type
J
jiaqi 已提交
311
class MultiSlotDataset : public DatasetImpl<Record> {
312 313
 public:
  MultiSlotDataset() {}
314
  virtual void MergeByInsId();
315 316 317
  virtual void PreprocessInstance();
  virtual void PostprocessInstance();
  virtual void SetCurrentPhase(int current_phase);
318 319 320 321 322 323 324 325 326 327
  virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
                                         int read_thread_num,
                                         int consume_thread_num, int shard_num);
  virtual void ClearLocalTables() {
    for (auto& t : local_tables_) {
      t.clear();
      std::unordered_set<uint64_t>().swap(t);
    }
    std::vector<std::unordered_set<uint64_t>>().swap(local_tables_);
  }
328 329 330
  virtual void PreprocessChannel(
      const std::set<std::string>& slots_to_replace,
      std::unordered_set<uint16_t>& index_slot);  // NOLINT
331
  virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace);
332 333 334
  virtual void GetRandomData(
      const std::unordered_set<uint16_t>& slots_to_replace,
      std::vector<Record>* result);
335
  virtual ~MultiSlotDataset() {}
Y
yaoxuefeng 已提交
336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361
  virtual void GlobalShuffle(int thread_num = -1);
  virtual void DynamicAdjustReadersNum(int thread_num);
  virtual void PrepareTrain();

 protected:
  virtual int ReceiveFromClient(int msg_type, int client_id,
                                const std::string& msg);
};
// SlotRecordDataset specializes DatasetImpl for SlotRecord.
class SlotRecordDataset : public DatasetImpl<SlotRecord> {
 public:
  // calls SlotRecordPool() and discards the result
  // NOTE(review): presumably this forces the record pool to be created
  // before any record is allocated -- confirm against SlotRecordPool().
  SlotRecordDataset() { SlotRecordPool(); }
  virtual ~SlotRecordDataset() {}
  // create input channel
  void CreateChannel() override;
  // create readers
  void CreateReaders() override;
  // release memory
  void ReleaseMemory() override;
  void GlobalShuffle(int thread_num = -1) override;
  // the default for discard_remaining_ins mirrors the one declared in
  // Dataset/DatasetImpl, so one-argument calls also compile when made
  // through a SlotRecordDataset
  void DynamicAdjustChannelNum(int channel_num,
                               bool discard_remaining_ins = false) override;
  virtual void PrepareTrain();
  void DynamicAdjustReadersNum(int thread_num) override;

 protected:
  // NOTE(review): this member shadows DatasetImpl<SlotRecord>::enable_heterps_
  // (initialized to false). Code in the base class reads the base member,
  // code in this class reads this one -- confirm that the two distinct flags
  // are intentional.
  bool enable_heterps_ = true;
};

}  // end namespace framework
}  // end namespace paddle