From 7c0196d466ffdfa34e7eb644fb9e70e5f9629638 Mon Sep 17 00:00:00 2001 From: malin10 Date: Wed, 9 Sep 2020 15:19:06 +0800 Subject: [PATCH] bug fix --- paddle/fluid/framework/data_set.cc | 11 ++-- paddle/fluid/framework/data_set.h | 12 ++-- paddle/fluid/framework/fleet/tree_wrapper.cc | 25 +++++++-- paddle/fluid/framework/fleet/tree_wrapper.h | 58 ++++++++++++++------ python/paddle/fluid/dataset.py | 9 +++ 5 files changed, 86 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 6882922b3ed..6ded7992f0e 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -24,6 +24,7 @@ #include "google/protobuf/text_format.h" #include "paddle/fluid/framework/data_feed_factory.h" #include "paddle/fluid/framework/fleet/fleet_wrapper.h" +#include "paddle/fluid/framework/fleet/tree_wrapper.h" #include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/platform/monitor.h" #include "paddle/fluid/platform/timer.h" @@ -358,7 +359,7 @@ template void DatasetImpl::TDMDump(std::string name, const uint64_t table_id, int fea_value_dim, const std::string tree_path) { auto tree_ptr = TreeWrapper::GetInstance(); - tree_ptr->dump_tree(name, table_id, fea_value_dim, tree_path); + tree_ptr->dump(name, table_id, fea_value_dim, tree_path); } // do sample @@ -378,7 +379,7 @@ void DatasetImpl::TDMSample(const uint16_t sample_slot, if (!multi_output_channel_[i] || multi_output_channel_[i]->Size() == 0) { continue; } - multi_output_channe_[i]->ReadAll(data[i]); + multi_output_channel_[i]->ReadAll(data[i]); } } else { input_channel_->Close(); @@ -388,15 +389,17 @@ void DatasetImpl::TDMSample(const uint16_t sample_slot, } auto tree_ptr = TreeWrapper::GetInstance(); + auto fleet_ptr = FleetWrapper::GetInstance(); for (auto i = 0; i < data.size(); i++) { std::vector tmp_results; - tree_ptr->sample(sample_slot, type_slot, data[i], tmp_results); + tree_ptr->sample(sample_slot, type_slot, data[i], &tmp_results); sample_results.push_back(tmp_results); } + auto output_channel_num = multi_output_channel_.size(); for (auto i = 0; i < sample_results.size(); i++) { auto output_idx = fleet_ptr->LocalRandomEngine()() % output_channel_num; - multi_output_channe_[i]->Write(std::move(sample_results[i])) + multi_output_channel_[output_idx]->Write(std::move(sample_results[i])); } data.clear(); diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index e94a603c089..2b4844b2696 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -45,6 +45,12 @@ class Dataset { public: Dataset() {} virtual ~Dataset() {} + virtual void InitTDMTree( + const std::vector> config) = 0; + virtual void TDMSample(const uint16_t sample_slot, + const uint64_t type_slot) = 0; + virtual void TDMDump(std::string name, const uint64_t table_id, + int fea_value_dim, const std::string tree_path) = 0; // set file list virtual void SetFileList(const std::vector& filelist) = 0; // set readers' num @@ -162,13 +168,11 @@ class DatasetImpl : public Dataset { virtual void InitTDMTree( const std::vector> config); - virtual void TDMSample(std::string name, const uint64_t table_id, - int fea_value_dim, const std::string tree_path); + virtual void TDMSample(const uint16_t sample_slot, const uint64_t type_slot); virtual void TDMDump(std::string name, const uint64_t table_id, int fea_value_dim, const std::string tree_path); - virtual void virtual void SetFileList( - const std::vector& filelist); + virtual void SetFileList(const std::vector& filelist); virtual void SetThreadNum(int thread_num); virtual void SetTrainerNum(int trainer_num); virtual void SetFleetSendBatchSize(int64_t size); diff --git a/paddle/fluid/framework/fleet/tree_wrapper.cc b/paddle/fluid/framework/fleet/tree_wrapper.cc index e02ae46d401..a7efe04e4ac 100644 --- a/paddle/fluid/framework/fleet/tree_wrapper.cc +++ b/paddle/fluid/framework/fleet/tree_wrapper.cc @@ -1,3 +1,17 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #pragma once #include #include @@ -6,6 +20,7 @@ #include #include "paddle/fluid/framework/data_feed.h" +#include "paddle/fluid/framework/fleet/tree_wrapper.h" namespace paddle { namespace framework { @@ -117,7 +132,7 @@ int Tree::dump_tree(const uint64_t table_id, int fea_value_dim, std::shared_ptr fp = paddle::framework::fs_open(tree_path, "w", &ret, ""); - std::vector fea_keys, std::vector pull_result_ptr; + std::vector fea_keys, std::vector pull_result_ptr; fea_keys.reserve(_total_node_num); pull_result_ptr.reserve(_total_node_num); @@ -167,7 +182,7 @@ int Tree::dump_tree(const uint64_t table_id, int fea_value_dim, } bool Tree::trace_back(uint64_t id, - std::vector>& ids) { + std::vector>* ids) { ids.clear(); std::unordered_map::iterator find_it = _leaf_node_map.find(id); @@ -178,11 +193,11 @@ bool Tree::trace_back(uint64_t id, Node* node = find_it->second; while (node != NULL) { height++; - ids.emplace_back(node->id, 0); + ids->emplace_back(node->id, 0); node = node->parent_node; } - for (auto& id : ids) { - id.second = height--; + for (auto& pair_id : *ids) { + pair_id.second = height--; } } return true; diff --git a/paddle/fluid/framework/fleet/tree_wrapper.h b/paddle/fluid/framework/fleet/tree_wrapper.h index c48ee69de6c..4cca988f6f2 100644 --- a/paddle/fluid/framework/fleet/tree_wrapper.h +++ b/paddle/fluid/framework/fleet/tree_wrapper.h @@ -1,8 +1,23 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #pragma once #include #include #include #include +#include #include #include "paddle/fluid/framework/data_feed.h" @@ -11,8 +26,8 @@ namespace paddle { namespace framework { struct Node { - Node::Node() : parent_node(NULL), id(0), height(0) {} - ~Node(){}; + Node() : parent_node(NULL), id(0), height(0) {} + ~Node() {} std::vector sub_nodes; // uint32_t sub_node_num; Node* parent_node; @@ -34,8 +49,7 @@ class Tree { void print_tree(); int dump_tree(const uint64_t table_id, int fea_value_dim, const std::string tree_path); - //采样:从叶节点回溯到根节点 - void trace_back(uint64_t id, std::vector>& ids); + bool trace_back(uint64_t id, std::vector>* ids); int load(std::string path); Node* get_node(); size_t get_total_node_num(); @@ -75,8 +89,8 @@ class TreeWrapper { if (tree_map.find(name) != tree_map.end()) { return; } - TreePtr tree = new Tree(); - tree.load(tree_path); + TreePtr tree = std::make_shared(); + tree->load(tree_path); tree_map.insert(std::pair{name, tree}); } @@ -89,32 +103,39 @@ class TreeWrapper { } void sample(const uint16_t sample_slot, const uint64_t type_slot, - std::vector& src_datas, - std::vector& sample_results) { - sample_results.clear(); + const std::vector& src_datas, + std::vector* sample_results) { + sample_results->clear(); + auto debug_idx = 0; for (auto& data : src_datas) { + if (debug_idx == 0) { + VLOG(0) << "src record"; + data.Print(); + } uint64_t sample_feasign_idx = -1, type_feasign_idx = -1; - for (auto i = 0; i < data.uint64_feasigns_.size(); i++) { + for (uint64_t i = 0; i < data.uint64_feasigns_.size(); i++) { if (data.uint64_feasigns_[i].slot() == sample_slot) { sample_feasign_idx = i; } - if (data.uint64_feasigns_.slot() == type_slot) { + if (data.uint64_feasigns_[i].slot() == type_slot) { type_feasign_idx = i; } } if (sample_feasign_idx > 0) { std::vector> trace_ids; - for (auto name : tree_map) { - bool in_tree = tree_map.at(name)->trace_back( + for (std::unordered_map::iterator ite = + tree_map.begin(); + ite != tree_map.end(); ite++) { + bool in_tree = ite->second->trace_back( data.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_, - trace_ids); + &trace_ids); if (in_tree) { break; } else { PADDLE_ENFORCE_EQ(trace_ids.size(), 0, ""); } } - for (auto i = 0; i < trace_ids.size(); i++) { + for (uint64_t i = 0; i < trace_ids.size(); i++) { Record instance(data); instance.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_ = trace_ids[i].first; @@ -122,9 +143,14 @@ class TreeWrapper { instance.uint64_feasigns_[type_feasign_idx] .sign() .uint64_feasign_ += trace_ids[i].second * 100; - sample_results.push_back(instance); + if (debug_idx == 0) { + VLOG(0) << "sample results:" << i; + instance.Print(); + } + sample_results->push_back(instance); } } + debug_idx += 1; } return; } diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py index 87b1ce2511e..00e8968aaa3 100644 --- a/python/paddle/fluid/dataset.py +++ b/python/paddle/fluid/dataset.py @@ -608,6 +608,15 @@ class InMemoryDataset(DatasetBase): self.dataset.generate_local_tables_unlock( table_id, fea_dim, read_thread_num, consume_thread_num, shard_num) + def init_tdm_tree(self, configs): + self.dataset.init_tdm_tree(configs) + + def tdm_sample(self, sample_slot, type_slot): + self.dataset.tdm_sample(sample_slot, type_slot) + + def tdm_dump(self, name, table_id, fea_value_dim, tree_path): + self.dataset.tdm_dump(name, table_id, fea_value_dim, tree_path) + def load_into_memory(self): """ Load data into memory -- GitLab