From ecc59dcdd3c0eb1466fb5fdc51d8a87abeb69f40 Mon Sep 17 00:00:00 2001 From: malin10 Date: Tue, 8 Sep 2020 18:23:12 +0800 Subject: [PATCH] add dataset interface for tdm_tree --- paddle/fluid/framework/data_set.cc | 69 ++++++++++++++++++++++++++++++ paddle/fluid/framework/data_set.h | 11 ++++- paddle/fluid/pybind/data_set_py.cc | 7 +++ 3 files changed, 86 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 94934629e28..6882922b3ed 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -13,10 +13,12 @@ * limitations under the License. */ #include "paddle/fluid/framework/data_set.h" + #include #include #include #include + #include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/message.h" #include "google/protobuf/text_format.h" @@ -341,6 +343,73 @@ void DatasetImpl::ReleaseMemory() { STAT_SUB(STAT_total_feasign_num_in_mem, total_fea_num_); } +template +void DatasetImpl::InitTDMTree( + const std::vector> config) { + auto tree_ptr = TreeWrapper::GetInstance(); + for (auto& iter : config) { + tree_ptr->insert(iter.first, iter.second); + } + return; +} + +// do dump +template +void DatasetImpl::TDMDump(std::string name, const uint64_t table_id, + int fea_value_dim, const std::string tree_path) { + auto tree_ptr = TreeWrapper::GetInstance(); + tree_ptr->dump_tree(name, table_id, fea_value_dim, tree_path); +} + +// do sample +template +void DatasetImpl::TDMSample(const uint16_t sample_slot, + const uint64_t type_slot) { + VLOG(0) << "DatasetImpl::Sample() begin"; + platform::Timer timeline; + timeline.Start(); + + std::vector> data; + std::vector> sample_results; + if (!input_channel_ || input_channel_->Size() == 0) { + for (size_t i = 0; i < multi_output_channel_.size(); ++i) { + std::vector tmp_data; + data.push_back(tmp_data); + if (!multi_output_channel_[i] || multi_output_channel_[i]->Size() == 0) { + continue; + } + multi_output_channe_[i]->ReadAll(data[i]); + } + } else { + input_channel_->Close(); + std::vector tmp_data; + data.push_back(tmp_data); + input_channel_->ReadAll(data[data.size() - 1]); + } + + auto tree_ptr = TreeWrapper::GetInstance(); + for (auto i = 0; i < data.size(); i++) { + std::vector tmp_results; + tree_ptr->sample(sample_slot, type_slot, data[i], tmp_results); + sample_results.push_back(tmp_results); + } + + for (auto i = 0; i < sample_results.size(); i++) { + auto output_idx = fleet_ptr->LocalRandomEngine()() % output_channel_num; + multi_output_channe_[i]->Write(std::move(sample_results[i])) + } + + data.clear(); + sample_results.clear(); + data.shrink_to_fit(); + sample_results.shrink_to_fit(); + + timeline.Pause(); + VLOG(0) << "DatasetImpl::Sample() end, cost time=" << timeline.ElapsedSec() + << " seconds"; + return; +} + // do local shuffle template void DatasetImpl::LocalShuffle() { diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index 462f6771a01..e94a603c089 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -15,6 +15,7 @@ #pragma once #include + #include #include #include // NOLINT @@ -159,7 +160,15 @@ class DatasetImpl : public Dataset { DatasetImpl(); virtual ~DatasetImpl() {} - virtual void SetFileList(const std::vector& filelist); + virtual void InitTDMTree( + const std::vector> config); + virtual void TDMSample(std::string name, const uint64_t table_id, + int fea_value_dim, const std::string tree_path); + virtual void TDMDump(std::string name, const uint64_t table_id, + int fea_value_dim, const std::string tree_path); + + virtual void virtual void SetFileList( + const std::vector& filelist); virtual void SetThreadNum(int thread_num); virtual void SetTrainerNum(int trainer_num); virtual void SetFleetSendBatchSize(int64_t size); diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc index 7a32d8729fc..fb9c39d849a 100644 --- a/paddle/fluid/pybind/data_set_py.cc +++ b/paddle/fluid/pybind/data_set_py.cc @@ -24,6 +24,7 @@ limitations under the License. */ #include #include #include + #include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/text_format.h" #include "paddle/fluid/framework/async_executor.h" @@ -202,6 +203,12 @@ void BindDataset(py::module *m) { .def(py::init([](const std::string &name = "MultiSlotDataset") { return framework::DatasetFactory::CreateDataset(name); })) + .def("init_tdm_tree", &framework::Dataset::InitTDMTree, + py::call_guard()) + .def("tdm_sample", &framework::Dataset::TDMSample, + py::call_guard()) + .def("tdm_dump", &framework::Dataset::TDMDump, + py::call_guard()) .def("set_filelist", &framework::Dataset::SetFileList, py::call_guard()) .def("set_thread_num", &framework::Dataset::SetThreadNum, -- GitLab