diff --git a/paddle/fluid/distributed/index_dataset/CMakeLists.txt b/paddle/fluid/distributed/index_dataset/CMakeLists.txt index a30488494a52bcfea61476caeb1ab08e3e6781a1..6edb9834d49eb752a88bacf3cf7d0cd72f081100 100644 --- a/paddle/fluid/distributed/index_dataset/CMakeLists.txt +++ b/paddle/fluid/distributed/index_dataset/CMakeLists.txt @@ -1,7 +1,10 @@ proto_library(index_dataset_proto SRCS index_dataset.proto) cc_library(index_wrapper SRCS index_wrapper.cc DEPS index_dataset_proto fs) -cc_library(index_sampler SRCS index_sampler.cc DEPS index_wrapper) - +if(WITH_MKLDNN) + cc_library(index_sampler SRCS index_sampler.cc DEPS xxhash index_wrapper mkldnn) +else() + cc_library(index_sampler SRCS index_sampler.cc DEPS xxhash index_wrapper) +endif() if(WITH_PYTHON) py_proto_compile(index_dataset_py_proto SRCS index_dataset.proto) endif() diff --git a/paddle/fluid/distributed/index_dataset/index_dataset.proto b/paddle/fluid/distributed/index_dataset/index_dataset.proto index 1b4ee313671ad503b9e46dbe9e34d4a69d0cfc4d..1f0df0df5c2cf690305f368cb12a3ad82c977810 100644 --- a/paddle/fluid/distributed/index_dataset/index_dataset.proto +++ b/paddle/fluid/distributed/index_dataset/index_dataset.proto @@ -19,6 +19,7 @@ message IndexNode { required uint64 id = 1; required bool is_leaf = 2; required float probability = 3; + optional string item_name = 4; } message TreeMeta { @@ -29,4 +30,4 @@ message TreeMeta { message KVItem { required bytes key = 1; required bytes value = 2; -} \ No newline at end of file +} diff --git a/paddle/fluid/distributed/index_dataset/index_sampler.cc b/paddle/fluid/distributed/index_dataset/index_sampler.cc index 3e573bbdd2de97130a109ddb583a724cf363c6be..306d11d333dae680cb1623bff64093a1a8e35493 100644 --- a/paddle/fluid/distributed/index_dataset/index_sampler.cc +++ b/paddle/fluid/distributed/index_dataset/index_sampler.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/distributed/index_dataset/index_sampler.h" +#include "paddle/fluid/framework/data_feed.h" namespace paddle { namespace distributed { @@ -69,6 +70,67 @@ std::vector> LayerWiseSampler::sample( } return outputs; } +void LayerWiseSampler::sample_from_dataset( + const uint16_t sample_slot, + std::vector* src_datas, + std::vector* sample_results) { + sample_results->clear(); + for (auto& data : *src_datas) { + VLOG(1) << "src data size = " << src_datas->size(); + VLOG(1) << "float data size = " << data.float_feasigns_.size(); + // data.Print(); + uint64_t start_idx = sample_results->size(); + VLOG(1) << "before sample, sample_results.size = " << start_idx; + uint64_t sample_feasign_idx = -1; + bool sample_sign = false; + for (unsigned int i = 0; i < data.uint64_feasigns_.size(); i++) { + VLOG(1) << "slot" << i << " = " << data.uint64_feasigns_[i].slot(); + if (data.uint64_feasigns_[i].slot() == sample_slot) { + sample_sign = true; + sample_feasign_idx = i; + } + if (sample_sign) break; + } + + VLOG(1) << "sample_feasign_idx: " << sample_feasign_idx; + if (sample_sign) { + auto target_id = + data.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_; + auto travel_codes = tree_->GetTravelCodes(target_id, start_sample_layer_); + auto travel_path = tree_->GetNodes(travel_codes); + for (unsigned int j = 0; j < travel_path.size(); j++) { + paddle::framework::Record instance(data); + instance.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_ = + travel_path[j].id(); + sample_results->push_back(instance); + for (int idx_offset = 0; idx_offset < layer_counts_[j]; idx_offset++) { + int sample_res = 0; + do { + sample_res = sampler_vec_[j]->Sample(); + } while (layer_ids_[j][sample_res].id() == travel_path[j].id()); + paddle::framework::Record instance(data); + instance.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_ = + layer_ids_[j][sample_res].id(); + VLOG(1) << "layer id :" << layer_ids_[j][sample_res].id(); + // sample_feasign_idx + 1 == label's id + instance.uint64_feasigns_[sample_feasign_idx + 1] + .sign() + .uint64_feasign_ = 0; + sample_results->push_back(instance); + } + VLOG(1) << "layer end!!!!!!!!!!!!!!!!!!"; + } + } + } + VLOG(1) << "after sample, sample_results.size = " << sample_results->size(); + return; +} + +std::vector float2int(std::vector tmp) { + std::vector tmp_int; + for (auto i : tmp) tmp_int.push_back(uint64_t(i)); + return tmp_int; +} } // end namespace distributed } // end namespace paddle diff --git a/paddle/fluid/distributed/index_dataset/index_sampler.h b/paddle/fluid/distributed/index_dataset/index_sampler.h index 8813421446a21c1379ca872952fe8b367d0724ca..02806b814c20097306cf00cebbcd04e21718462e 100644 --- a/paddle/fluid/distributed/index_dataset/index_sampler.h +++ b/paddle/fluid/distributed/index_dataset/index_sampler.h @@ -15,6 +15,7 @@ #pragma once #include #include "paddle/fluid/distributed/index_dataset/index_wrapper.h" +#include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/operators/math/sampler.h" #include "paddle/fluid/platform/enforce.h" @@ -34,13 +35,19 @@ class IndexSampler { return instance; } - virtual void init_layerwise_conf(const std::vector& layer_sample_counts, - int start_sample_layer = 1, int seed = 0) {} + virtual void init_layerwise_conf( + const std::vector& layer_sample_counts, + uint16_t start_sample_layer = 1, uint16_t seed = 0) {} virtual void init_beamsearch_conf(const int64_t k) {} virtual std::vector> sample( const std::vector>& user_inputs, const std::vector& input_targets, bool with_hierarchy = false) = 0; + + virtual void sample_from_dataset( + const uint16_t sample_slot, + std::vector* src_datas, + std::vector* sample_results) = 0; }; class LayerWiseSampler : public IndexSampler { @@ -50,8 +57,9 @@ class LayerWiseSampler : public IndexSampler { tree_ = IndexWrapper::GetInstance()->get_tree_index(name); } - void init_layerwise_conf(const std::vector& layer_sample_counts, - int start_sample_layer, int seed) override { + void init_layerwise_conf(const std::vector& layer_sample_counts, + uint16_t start_sample_layer, + uint16_t seed) override { seed_ = seed; start_sample_layer_ = start_sample_layer; @@ -106,6 +114,11 @@ class LayerWiseSampler : public IndexSampler { const std::vector>& user_inputs, const std::vector& target_ids, bool with_hierarchy) override; + void sample_from_dataset( + const uint16_t sample_slot, + std::vector* src_datas, + std::vector* sample_results) override; + private: std::vector layer_counts_; int64_t layer_counts_sum_{0}; diff --git a/paddle/fluid/distributed/index_dataset/index_wrapper.cc b/paddle/fluid/distributed/index_dataset/index_wrapper.cc index 7a9691f3602e2622c6adc6ddbeb1a1507a174f70..27aa890f7600fba20ef7a9b535d368fa28714972 100644 --- a/paddle/fluid/distributed/index_dataset/index_wrapper.cc +++ b/paddle/fluid/distributed/index_dataset/index_wrapper.cc @@ -66,9 +66,10 @@ int TreeIndex::Load(const std::string filename) { auto code = std::stoull(item.key()); IndexNode node; node.ParseFromString(item.value()); - PADDLE_ENFORCE_NE(node.id(), 0, - platform::errors::InvalidArgument( - "Node'id should not be equel to zero.")); + + // PADDLE_ENFORCE_NE(node.id(), 0, + // platform::errors::InvalidArgument( + // "Node'id should not be equel to zero.")); if (node.is_leaf()) { id_codes_map_[node.id()] = code; } diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 3b9c99b1a5a1e737c741858e4a148b43ebf289bd..639b7bad1bf35293648fd4091d1e4e502145c99f 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -285,7 +285,7 @@ if(WITH_DISTRIBUTE) data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry - device_context scope framework_proto trainer_desc_proto glog fs shell + device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer lod_rank_table feed_fetch_method collective_helper ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper data_feed_proto timer monitor @@ -307,6 +307,7 @@ if(WITH_DISTRIBUTE) downpour_worker.cc downpour_worker_opt.cc pull_dense_worker.cc section_worker.cc heter_section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog + index_sampler index_wrapper sampler index_dataset_proto lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper lodtensor_printer feed_fetch_method graph_to_program_pass variable_helper timer monitor heter_service_proto fleet heter_server brpc fleet_executor) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index ca5e27dac3acb3e90744d811123e70cc612016c0..c511526c3159d6ddc3a7e78c22ded064fb2223f6 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -14,7 +14,11 @@ #include "paddle/fluid/framework/data_set.h" #include "google/protobuf/text_format.h" +#if (defined PADDLE_WITH_DISTRIBUTE) && (defined PADDLE_WITH_PSCORE) +#include "paddle/fluid/distributed/index_dataset/index_sampler.h" +#endif #include "paddle/fluid/framework/data_feed_factory.h" +#include "paddle/fluid/framework/fleet/fleet_wrapper.h" #include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/platform/monitor.h" #include "paddle/fluid/platform/timer.h" @@ -547,6 +551,83 @@ void DatasetImpl::LocalShuffle() { << timeline.ElapsedSec() << " seconds"; } +// do tdm sample +void MultiSlotDataset::TDMSample(const std::string tree_name, + const std::string tree_path, + const std::vector tdm_layer_counts, + const uint16_t start_sample_layer, + const bool with_hierachy, const uint16_t seed_, + const uint16_t sample_slot) { +#if (defined PADDLE_WITH_DISTRIBUTE) && (defined PADDLE_WITH_PSCORE) + // init tdm tree + auto wrapper_ptr = paddle::distributed::IndexWrapper::GetInstance(); + wrapper_ptr->insert_tree_index(tree_name, tree_path); + auto tree_ptr = wrapper_ptr->get_tree_index(tree_name); + auto _layer_wise_sample = paddle::distributed::LayerWiseSampler(tree_name); + _layer_wise_sample.init_layerwise_conf(tdm_layer_counts, start_sample_layer, + seed_); + + VLOG(0) << "DatasetImpl::Sample() begin"; + platform::Timer timeline; + timeline.Start(); + + std::vector> data; + std::vector> sample_results; + if (!input_channel_ || input_channel_->Size() == 0) { + for (size_t i = 0; i < multi_output_channel_.size(); ++i) { + std::vector tmp_data; + data.push_back(tmp_data); + if (!multi_output_channel_[i] || multi_output_channel_[i]->Size() == 0) { + continue; + } + multi_output_channel_[i]->Close(); + multi_output_channel_[i]->ReadAll(data[i]); + } + } else { + input_channel_->Close(); + std::vector tmp_data; + data.push_back(tmp_data); + input_channel_->ReadAll(data[data.size() - 1]); + } + + VLOG(1) << "finish read src data, data.size = " << data.size() + << "; details: "; + auto fleet_ptr = FleetWrapper::GetInstance(); + for (unsigned int i = 0; i < data.size(); i++) { + VLOG(1) << "data[" << i << "]: size = " << data[i].size(); + std::vector tmp_results; + _layer_wise_sample.sample_from_dataset(sample_slot, &data[i], &tmp_results); + VLOG(1) << "sample_results(" << sample_slot << ") = " << tmp_results.size(); + VLOG(0) << "start to put sample in vector!"; + // sample_results.push_back(tmp_results); + for (unsigned int j = 0; j < tmp_results.size(); j++) { + std::vector tmp_vec; + tmp_vec.emplace_back(tmp_results[j]); + sample_results.emplace_back(tmp_vec); + } + VLOG(0) << "finish to put sample in vector!"; + } + + auto output_channel_num = multi_output_channel_.size(); + for (unsigned int i = 0; i < sample_results.size(); i++) { + auto output_idx = fleet_ptr->LocalRandomEngine()() % output_channel_num; + multi_output_channel_[output_idx]->Open(); + // vector? + multi_output_channel_[output_idx]->Write(std::move(sample_results[i])); + } + + data.clear(); + sample_results.clear(); + data.shrink_to_fit(); + sample_results.shrink_to_fit(); + + timeline.Pause(); + VLOG(0) << "DatasetImpl::Sample() end, cost time=" << timeline.ElapsedSec() + << " seconds"; +#endif + return; +} + void MultiSlotDataset::GlobalShuffle(int thread_num) { VLOG(3) << "MultiSlotDataset::GlobalShuffle() begin"; platform::Timer timeline; diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index 981fb694e0fec902ac07c9edbb1b4456d723c918..b41f701548f3f17d80a8a3f5e3b3ac05c80786e7 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -48,6 +48,13 @@ class Dataset { public: Dataset() {} virtual ~Dataset() {} + // do sample + virtual void TDMSample(const std::string tree_name, + const std::string tree_path, + const std::vector tdm_layer_counts, + const uint16_t start_sample_layer, + const bool with_hierachy, const uint16_t seed_, + const uint16_t sample_slot) {} // set file list virtual void SetFileList(const std::vector& filelist) = 0; // set readers' num @@ -162,7 +169,6 @@ class DatasetImpl : public Dataset { public: DatasetImpl(); virtual ~DatasetImpl() {} - virtual void SetFileList(const std::vector& filelist); virtual void SetThreadNum(int thread_num); virtual void SetTrainerNum(int trainer_num); @@ -311,6 +317,12 @@ class DatasetImpl : public Dataset { class MultiSlotDataset : public DatasetImpl { public: MultiSlotDataset() {} + virtual void TDMSample(const std::string tree_name, + const std::string tree_path, + const std::vector tdm_layer_counts, + const uint16_t start_sample_layer, + const bool with_hierachy, const uint16_t seed_, + const uint16_t sample_slot); virtual void MergeByInsId(); virtual void PreprocessInstance(); virtual void PostprocessInstance(); diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc index 7a32d8729fc6ca48a00f319d5c2fd278d3736288..562047a0c0c431e2a5ca605466231c958c76c6e5 100644 --- a/paddle/fluid/pybind/data_set_py.cc +++ b/paddle/fluid/pybind/data_set_py.cc @@ -202,6 +202,8 @@ void BindDataset(py::module *m) { .def(py::init([](const std::string &name = "MultiSlotDataset") { return framework::DatasetFactory::CreateDataset(name); })) + .def("tdm_sample", &framework::Dataset::TDMSample, + py::call_guard()) .def("set_filelist", &framework::Dataset::SetFileList, py::call_guard()) .def("set_thread_num", &framework::Dataset::SetThreadNum, diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc index 3f3eec345cb616c37f84cdc0ddf628d9350e5b87..f81bbd69a015f7f3b26aaa445e9f3da7006bad60 100644 --- a/paddle/fluid/pybind/fleet_py.cc +++ b/paddle/fluid/pybind/fleet_py.cc @@ -315,6 +315,5 @@ void BindIndexSampler(py::module* m) { .def("init_beamsearch_conf", &IndexSampler::init_beamsearch_conf) .def("sample", &IndexSampler::sample); } - } // end namespace pybind } // namespace paddle diff --git a/python/paddle/distributed/fleet/dataset/dataset.py b/python/paddle/distributed/fleet/dataset/dataset.py index e231ac55e679a27718d56ca2f1a14adb360a85c1..84d1f21d1542f6234f9bcc5a3bd2db10597d36d3 100644 --- a/python/paddle/distributed/fleet/dataset/dataset.py +++ b/python/paddle/distributed/fleet/dataset/dataset.py @@ -784,6 +784,12 @@ class InMemoryDataset(DatasetBase): if self.use_ps_gpu and core._is_compiled_with_heterps(): self.psgpu.set_date(year, month, day) + def tdm_sample(self, tree_name, tree_path, tdm_layer_counts, + start_sample_layer, with_hierachy, seed, id_slot): + self.dataset.tdm_sample(tree_name, tree_path, tdm_layer_counts, + start_sample_layer, with_hierachy, seed, + id_slot) + def load_into_memory(self, is_shuffle=False): """ :api_attr: Static Graph diff --git a/python/paddle/fluid/tests/unittests/test_dist_tree_index.py b/python/paddle/fluid/tests/unittests/test_dist_tree_index.py index feb52b18dad3d81edbb49fa045aff0ea88ec8be6..6ea15319204f2e0fd3ff5949b0af999559a5d6a1 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_tree_index.py +++ b/python/paddle/fluid/tests/unittests/test_dist_tree_index.py @@ -15,21 +15,45 @@ import unittest from paddle.dataset.common import download, DATA_HOME from paddle.distributed.fleet.dataset import TreeIndex +import paddle.fluid as fluid +import paddle.fluid.core as core +import paddle +paddle.enable_static() + + +def create_feeds(): + user_input = fluid.layers.data( + name="item_id", shape=[1], dtype="int64", lod_level=1) + + item = fluid.layers.data( + name="unit_id", shape=[1], dtype="int64", lod_level=1) + + label = fluid.layers.data( + name="label", shape=[1], dtype="int64", lod_level=1) + labels = fluid.layers.data( + name="labels", shape=[1], dtype="int64", lod_level=1) + + feed_list = [user_input, item, label, labels] + return feed_list class TestTreeIndex(unittest.TestCase): def test_tree_index(self): path = download( - "https://paddlerec.bj.bcebos.com/tree-based/data/demo_tree.pb", + "https://paddlerec.bj.bcebos.com/tree-based/data/mini_tree.pb", + "tree_index_unittest", "e2ba4561c2e9432b532df40546390efa") + ''' + path = download( + "https://paddlerec.bj.bcebos.com/tree-based/data/mini_tree.pb", "tree_index_unittest", "cadec20089f5a8a44d320e117d9f9f1a") - + ''' tree = TreeIndex("demo", path) height = tree.height() branch = tree.branch() - self.assertTrue(height == 14) + self.assertTrue(height == 5) self.assertTrue(branch == 2) - self.assertEqual(tree.total_node_nums(), 15581) - self.assertEqual(tree.emb_size(), 5171136) + self.assertEqual(tree.total_node_nums(), 25) + self.assertEqual(tree.emb_size(), 30) # get_layer_codes layer_node_ids = [] @@ -80,118 +104,48 @@ class TestTreeIndex(unittest.TestCase): class TestIndexSampler(unittest.TestCase): def test_layerwise_sampler(self): path = download( - "https://paddlerec.bj.bcebos.com/tree-based/data/demo_tree.pb", - "tree_index_unittest", "cadec20089f5a8a44d320e117d9f9f1a") - - tree = TreeIndex("demo", path) - - layer_nodes = [] - for i in range(tree.height()): - layer_codes = tree.get_layer_codes(i) - layer_nodes.append( - [node.id() for node in tree.get_nodes(layer_codes)]) - - sample_num = range(1, 10000) - start_sample_layer = 1 - seed = 0 - sample_layers = tree.height() - start_sample_layer - sample_num = sample_num[:sample_layers] - layer_sample_counts = list(sample_num) + [1] * (sample_layers - - len(sample_num)) - total_sample_num = sum(layer_sample_counts) + len(layer_sample_counts) - tree.init_layerwise_sampler(sample_num, start_sample_layer, seed) - - ids = [315757, 838060, 1251533, 403522, 2473624, 3321007] - parent_path = {} - for i in range(len(ids)): - tmp = tree.get_travel_codes(ids[i], start_sample_layer) - parent_path[ids[i]] = [node.id() for node in tree.get_nodes(tmp)] - - # check sample res with_hierarchy = False - sample_res = tree.layerwise_sample( - [[315757, 838060], [1251533, 403522]], [2473624, 3321007], False) - idx = 0 - layer = tree.height() - 1 - for i in range(len(layer_sample_counts)): - for j in range(layer_sample_counts[0 - (i + 1)] + 1): - self.assertTrue(sample_res[idx + j][0] == 315757) - self.assertTrue(sample_res[idx + j][1] == 838060) - self.assertTrue(sample_res[idx + j][2] in layer_nodes[layer]) - if j == 0: - self.assertTrue(sample_res[idx + j][3] == 1) - self.assertTrue( - sample_res[idx + j][2] == parent_path[2473624][i]) - else: - self.assertTrue(sample_res[idx + j][3] == 0) - self.assertTrue( - sample_res[idx + j][2] != parent_path[2473624][i]) - idx += layer_sample_counts[0 - (i + 1)] + 1 - layer -= 1 - self.assertTrue(idx == total_sample_num) - layer = tree.height() - 1 - for i in range(len(layer_sample_counts)): - for j in range(layer_sample_counts[0 - (i + 1)] + 1): - self.assertTrue(sample_res[idx + j][0] == 1251533) - self.assertTrue(sample_res[idx + j][1] == 403522) - self.assertTrue(sample_res[idx + j][2] in layer_nodes[layer]) - if j == 0: - self.assertTrue(sample_res[idx + j][3] == 1) - self.assertTrue( - sample_res[idx + j][2] == parent_path[3321007][i]) - else: - self.assertTrue(sample_res[idx + j][3] == 0) - self.assertTrue( - sample_res[idx + j][2] != parent_path[3321007][i]) - idx += layer_sample_counts[0 - (i + 1)] + 1 - layer -= 1 - self.assertTrue(idx == total_sample_num * 2) - - # check sample res with_hierarchy = True - sample_res_with_hierarchy = tree.layerwise_sample( - [[315757, 838060], [1251533, 403522]], [2473624, 3321007], True) - idx = 0 - layer = tree.height() - 1 - for i in range(len(layer_sample_counts)): - for j in range(layer_sample_counts[0 - (i + 1)] + 1): - self.assertTrue(sample_res_with_hierarchy[idx + j][0] == - parent_path[315757][i]) - self.assertTrue(sample_res_with_hierarchy[idx + j][1] == - parent_path[838060][i]) - self.assertTrue( - sample_res_with_hierarchy[idx + j][2] in layer_nodes[layer]) - if j == 0: - self.assertTrue(sample_res_with_hierarchy[idx + j][3] == 1) - self.assertTrue(sample_res_with_hierarchy[idx + j][2] == - parent_path[2473624][i]) - else: - self.assertTrue(sample_res_with_hierarchy[idx + j][3] == 0) - self.assertTrue(sample_res_with_hierarchy[idx + j][2] != - parent_path[2473624][i]) - - idx += layer_sample_counts[0 - (i + 1)] + 1 - layer -= 1 - self.assertTrue(idx == total_sample_num) - layer = tree.height() - 1 - for i in range(len(layer_sample_counts)): - for j in range(layer_sample_counts[0 - (i + 1)] + 1): - self.assertTrue(sample_res_with_hierarchy[idx + j][0] == - parent_path[1251533][i]) - self.assertTrue(sample_res_with_hierarchy[idx + j][1] == - parent_path[403522][i]) - self.assertTrue( - sample_res_with_hierarchy[idx + j][2] in layer_nodes[layer]) - if j == 0: - self.assertTrue(sample_res_with_hierarchy[idx + j][3] == 1) - self.assertTrue(sample_res_with_hierarchy[idx + j][2] == - parent_path[3321007][i]) - else: - self.assertTrue(sample_res_with_hierarchy[idx + j][3] == 0) - self.assertTrue(sample_res_with_hierarchy[idx + j][2] != - parent_path[3321007][i]) - - idx += layer_sample_counts[0 - (i + 1)] + 1 - layer -= 1 - self.assertTrue(idx == 2 * total_sample_num) + "https://paddlerec.bj.bcebos.com/tree-based/data/mini_tree.pb", + "tree_index_unittest", "e2ba4561c2e9432b532df40546390efa") + + tdm_layer_counts = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + #tree = TreeIndex("demo", path) + file_name = "test_in_memory_dataset_tdm_sample_run.txt" + with open(file_name, "w") as f: + #data = "29 d 29 d 29 29 29 29 29 29 29 29 29 29 29 29\n" + data = "1 1 1 15 15 15\n" + data += "1 1 1 15 15 15\n" + f.write(data) + + slots = ["slot1", "slot2", "slot3"] + slots_vars = [] + for slot in slots: + var = fluid.layers.data(name=slot, shape=[1], dtype="int64") + slots_vars.append(var) + + dataset = paddle.distributed.InMemoryDataset() + dataset.init( + batch_size=1, + pipe_command="cat", + download_cmd="cat", + use_var=slots_vars) + dataset.set_filelist([file_name]) + #dataset.update_settings(pipe_command="cat") + #dataset._init_distributed_settings( + # parse_ins_id=True, + # parse_content=True, + # fea_eval=True, + # candidate_size=10000) + + dataset.load_into_memory() + dataset.tdm_sample( + 'demo', + tree_path=path, + tdm_layer_counts=tdm_layer_counts, + start_sample_layer=1, + with_hierachy=False, + seed=0, + id_slot=2) + self.assertTrue(dataset.get_shuffle_data_size() == 8) if __name__ == '__main__':