data_set.h 7.2 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *     Unless required by applicable law or agreed to in writing, software
 *     distributed under the License is distributed on an "AS IS" BASIS,
 *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *     See the License for the specific language governing permissions and
 *     limitations under the License. */

#pragma once

#include <fstream>
#include <memory>
#include <mutex>  // NOLINT
#include <string>
#include <thread>  // NOLINT
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_feed.h"

namespace paddle {
namespace framework {

// Dataset is an abstract class which defines the user-facing interface.
// Every operation is pure virtual; concrete behaviour is supplied by a
// subclass.
// Example usage:
//    Dataset* dataset = DatasetFactory::CreateDataset("InMemoryDataset");
//    dataset->SetFileList(std::vector<std::string>{"a.txt", "b.txt"});
//    dataset->SetThreadNum(1);
//    dataset->CreateReaders();
//    dataset->SetDataFeedDesc(your_data_feed_desc);
//    dataset->LoadIntoMemory();
//    dataset->SetTrainerNum(2);
//    dataset->GlobalShuffle();
class Dataset {
 public:
  Dataset() = default;
  virtual ~Dataset() = default;

  // ---------------------------- setters ----------------------------

  // Sets the list of input files.
  virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
  // Sets the number of readers.
  virtual void SetThreadNum(int thread_num) = 0;
  // Sets the number of workers (trainers).
  virtual void SetTrainerNum(int trainer_num) = 0;
  // Sets the fleet send batch size.
  virtual void SetFleetSendBatchSize(int64_t size) = 0;
  // Sets the file system name and ugi.
  virtual void SetHdfsConfig(const std::string& fs_name,
                             const std::string& fs_ugi) = 0;
  // Sets the data feed desc, which contains:
  //   data feed name, batch size, slots
  virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
  // Sets the number of channels.
  virtual void SetChannelNum(int channel_num) = 0;
  // Configures merging by instance id.
  virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list,
                               bool erase_duplicate_feas, int min_merge_size,
                               bool keep_unmerged_ins) = 0;

  // ---------------------------- getters ----------------------------

  // Returns the file list.
  virtual const std::vector<std::string>& GetFileList() = 0;
  // Returns the thread number.
  virtual int GetThreadNum() = 0;
  // Returns the worker (trainer) number.
  virtual int GetTrainerNum() = 0;
  // Returns the fleet send batch size.
  virtual int64_t GetFleetSendBatchSize() = 0;
  // Returns the hdfs config as a (fs name, fs ugi) pair.
  virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
  // Returns the data feed desc.
  virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
  // Returns the channel number.
  virtual int GetChannelNum() = 0;
  // Returns the readers; the reader number depends on both the thread
  // number and the file list size.
  virtual std::vector<paddle::framework::DataFeed*> GetReaders() = 0;

  // --------------------------- operations ---------------------------

  // Creates the input channel and the output channel.
  virtual void CreateChannel() = 0;
  // Registers the message handler used between workers.
  virtual void RegisterClientToClientMsgHandler() = 0;
  // Loads all data into memory.
  virtual void LoadIntoMemory() = 0;
  // Loads all data into memory in async mode.
  virtual void PreLoadIntoMemory() = 0;
  // Waits until the async load is done.
  virtual void WaitPreLoadDone() = 0;
  // Releases all in-memory data.
  virtual void ReleaseMemory() = 0;
  // Shuffles data locally.
  virtual void LocalShuffle() = 0;
  // Shuffles data globally.
  virtual void GlobalShuffle() = 0;
  // Creates the readers.
  virtual void CreateReaders() = 0;
  // Destroys the readers.
  virtual void DestroyReaders() = 0;
  // Returns the memory data size.
  virtual int64_t GetMemoryDataSize() = 0;
  // Returns the shuffle data size.
  virtual int64_t GetShuffleDataSize() = 0;
  // Merges by instance id.
  virtual void MergeByInsId() = 0;

 protected:
  // Handler for a message received from another client; takes the message
  // type, the sending client's id and the message payload.
  virtual int ReceiveFromClient(int msg_type, int client_id,
                                const std::string& msg) = 0;
};

// DatasetImpl is the implementation of Dataset,
// it holds memory data if user calls load_into_memory.
//
// T is the element type carried by the channels. Only the trivial getters
// and the empty default MergeByInsId() are defined inline here; every other
// member is only declared in this header.
//
// All overriding members are marked `override` so that any signature drift
// from the Dataset interface is reported at compile time instead of silently
// introducing a new, unrelated virtual function.
template <typename T>
class DatasetImpl : public Dataset {
 public:
  DatasetImpl();
  ~DatasetImpl() override {}

  void SetFileList(const std::vector<std::string>& filelist) override;
  void SetThreadNum(int thread_num) override;
  void SetTrainerNum(int trainer_num) override;
  void SetFleetSendBatchSize(int64_t size) override;
  void SetHdfsConfig(const std::string& fs_name,
                     const std::string& fs_ugi) override;
  void SetDataFeedDesc(const std::string& data_feed_desc_str) override;
  void SetChannelNum(int channel_num) override;
  void SetMergeByInsId(const std::vector<std::string>& merge_slot_list,
                       bool erase_duplicate_feas, int min_merge_size,
                       bool keep_unmerged_ins) override;

  // Returns a reference to the stored file list.
  const std::vector<std::string>& GetFileList() override { return filelist_; }
  // Returns the stored thread number.
  int GetThreadNum() override { return thread_num_; }
  // Returns the stored trainer number.
  int GetTrainerNum() override { return trainer_num_; }
  // Returns the stored fleet send batch size.
  int64_t GetFleetSendBatchSize() override { return fleet_send_batch_size_; }
  // Returns a copy of the stored (fs_name_, fs_ugi_) pair.
  std::pair<std::string, std::string> GetHdfsConfig() override {
    return std::make_pair(fs_name_, fs_ugi_);
  }
  // Returns a reference to the stored data feed desc.
  const paddle::framework::DataFeedDesc& GetDataFeedDesc() override {
    return data_feed_desc_;
  }
  // Returns the stored channel number.
  int GetChannelNum() override { return channel_num_; }
  std::vector<paddle::framework::DataFeed*> GetReaders() override;
  void CreateChannel() override;
  void RegisterClientToClientMsgHandler() override;
  void LoadIntoMemory() override;
  void PreLoadIntoMemory() override;
  void WaitPreLoadDone() override;
  void ReleaseMemory() override;
  void LocalShuffle() override;
  void GlobalShuffle() override;
  void CreateReaders() override;
  void DestroyReaders() override;
  int64_t GetMemoryDataSize() override;
  int64_t GetShuffleDataSize() override;
  // Default implementation does nothing; subclasses may override it.
  void MergeByInsId() override {}

 protected:
  int ReceiveFromClient(int msg_type, int client_id,
                        const std::string& msg) override;
  // Readers held by this dataset.
  std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
  paddle::framework::Channel<T> input_channel_;
  int channel_num_;
  std::vector<paddle::framework::Channel<T>> multi_output_channel_;
  std::vector<paddle::framework::Channel<T>> multi_consume_channel_;
  // when read ins, we put ins from one channel to the other,
  // and when finish reading, we set cur_channel = 1 - cur_channel,
  // so if cur_channel=0, all data are in output_channel, else consume_channel
  int cur_channel_;
  int thread_num_;
  paddle::framework::DataFeedDesc data_feed_desc_;
  int trainer_num_;
  std::vector<std::string> filelist_;
  // NOTE(review): presumably the index of the next file to pick from
  // filelist_, guarded by mutex_for_pick_file_ -- confirm in the .cc file.
  size_t file_idx_;
  std::mutex mutex_for_pick_file_;
  std::string fs_name_;
  std::string fs_ugi_;
  int64_t fleet_send_batch_size_;
  int64_t fleet_send_sleep_seconds_;
  // NOTE(review): presumably the threads started by PreLoadIntoMemory() and
  // joined by WaitPreLoadDone() -- confirm in the .cc file. The destructor
  // here does not join them.
  std::vector<std::thread> preload_threads_;
  // NOTE(review): the fields below appear to mirror the arguments of
  // SetMergeByInsId() -- confirm in the .cc file.
  bool merge_by_insid_;
  bool erase_duplicate_feas_;
  bool keep_unmerged_ins_;
  int min_merge_size_;
  std::vector<std::string> merge_slots_list_;
};

// use std::vector<MultiSlotType> or Record as data type
//
// MultiSlotDataset is DatasetImpl instantiated with Record. It replaces the
// empty default MergeByInsId() with its own implementation, which is only
// declared in this header. `override` makes the compiler verify that the
// declarations below really override the base-class virtuals.
class MultiSlotDataset : public DatasetImpl<Record> {
 public:
  MultiSlotDataset() {}
  void MergeByInsId() override;
  ~MultiSlotDataset() override {}
};

}  // end namespace framework
}  // end namespace paddle