提交 ecc59dcd 编写于 作者: M malin10

add dataset interface for tdm_tree

上级 8df5b4d6
......@@ -13,10 +13,12 @@
* limitations under the License. */
#include "paddle/fluid/framework/data_set.h"
#include <algorithm>
#include <random>
#include <unordered_map>
#include <unordered_set>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
......@@ -341,6 +343,73 @@ void DatasetImpl<T>::ReleaseMemory() {
STAT_SUB(STAT_total_feasign_num_in_mem, total_fea_num_);
}
template <typename T>
void DatasetImpl<T>::InitTDMTree(
const std::vector<std::pair<std::string, std::string>> config) {
auto tree_ptr = TreeWrapper::GetInstance();
for (auto& iter : config) {
tree_ptr->insert(iter.first, iter.second);
}
return;
}
// do dump
template <typename T>
void DatasetImpl<T>::TDMDump(std::string name, const uint64_t table_id,
int fea_value_dim, const std::string tree_path) {
auto tree_ptr = TreeWrapper::GetInstance();
tree_ptr->dump_tree(name, table_id, fea_value_dim, tree_path);
}
// do sample
template <typename T>
void DatasetImpl<T>::TDMSample(const uint16_t sample_slot,
const uint64_t type_slot) {
VLOG(0) << "DatasetImpl<T>::Sample() begin";
platform::Timer timeline;
timeline.Start();
std::vector<std::vector<T>> data;
std::vector<std::vector<T>> sample_results;
if (!input_channel_ || input_channel_->Size() == 0) {
for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
std::vector<T> tmp_data;
data.push_back(tmp_data);
if (!multi_output_channel_[i] || multi_output_channel_[i]->Size() == 0) {
continue;
}
multi_output_channe_[i]->ReadAll(data[i]);
}
} else {
input_channel_->Close();
std::vector<T> tmp_data;
data.push_back(tmp_data);
input_channel_->ReadAll(data[data.size() - 1]);
}
auto tree_ptr = TreeWrapper::GetInstance();
for (auto i = 0; i < data.size(); i++) {
std::vector<T> tmp_results;
tree_ptr->sample(sample_slot, type_slot, data[i], tmp_results);
sample_results.push_back(tmp_results);
}
for (auto i = 0; i < sample_results.size(); i++) {
auto output_idx = fleet_ptr->LocalRandomEngine()() % output_channel_num;
multi_output_channe_[i]->Write(std::move(sample_results[i]))
}
data.clear();
sample_results.clear();
data.shrink_to_fit();
sample_results.shrink_to_fit();
timeline.Pause();
VLOG(0) << "DatasetImpl<T>::Sample() end, cost time=" << timeline.ElapsedSec()
<< " seconds";
return;
}
// do local shuffle
template <typename T>
void DatasetImpl<T>::LocalShuffle() {
......
......@@ -15,6 +15,7 @@
#pragma once
#include <ThreadPool.h>
#include <fstream>
#include <memory>
#include <mutex> // NOLINT
......@@ -159,7 +160,15 @@ class DatasetImpl : public Dataset {
DatasetImpl();
virtual ~DatasetImpl() {}
virtual void SetFileList(const std::vector<std::string>& filelist);
virtual void InitTDMTree(
const std::vector<std::pair<std::string, std::string>> config);
virtual void TDMSample(std::string name, const uint64_t table_id,
int fea_value_dim, const std::string tree_path);
virtual void TDMDump(std::string name, const uint64_t table_id,
int fea_value_dim, const std::string tree_path);
virtual void virtual void SetFileList(
const std::vector<std::string>& filelist);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num);
virtual void SetFleetSendBatchSize(int64_t size);
......
......@@ -24,6 +24,7 @@ limitations under the License. */
#include <unordered_map>
#include <utility>
#include <vector>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/async_executor.h"
......@@ -202,6 +203,12 @@ void BindDataset(py::module *m) {
.def(py::init([](const std::string &name = "MultiSlotDataset") {
return framework::DatasetFactory::CreateDataset(name);
}))
.def("init_tdm_tree", &framework::Dataset::InitTDMTree,
py::call_guard<py::gil_scoped_release>())
.def("tdm_sample", &framework::Dataset::TDMSample,
py::call_guard<py::gil_scoped_release>())
.def("tdm_dump", &framework::Dataset::TDMDump,
py::call_guard<py::gil_scoped_release>())
.def("set_filelist", &framework::Dataset::SetFileList,
py::call_guard<py::gil_scoped_release>())
.def("set_thread_num", &framework::Dataset::SetThreadNum,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册