提交 7c0196d4 编写于 作者: M malin10

bug fix

上级 ef24bd78
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "google/protobuf/text_format.h" #include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/data_feed_factory.h" #include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h" #include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/fleet/tree_wrapper.h"
#include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/platform/monitor.h" #include "paddle/fluid/platform/monitor.h"
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
...@@ -358,7 +359,7 @@ template <typename T> ...@@ -358,7 +359,7 @@ template <typename T>
void DatasetImpl<T>::TDMDump(std::string name, const uint64_t table_id, void DatasetImpl<T>::TDMDump(std::string name, const uint64_t table_id,
int fea_value_dim, const std::string tree_path) { int fea_value_dim, const std::string tree_path) {
auto tree_ptr = TreeWrapper::GetInstance(); auto tree_ptr = TreeWrapper::GetInstance();
tree_ptr->dump_tree(name, table_id, fea_value_dim, tree_path); tree_ptr->dump(name, table_id, fea_value_dim, tree_path);
} }
// do sample // do sample
...@@ -378,7 +379,7 @@ void DatasetImpl<T>::TDMSample(const uint16_t sample_slot, ...@@ -378,7 +379,7 @@ void DatasetImpl<T>::TDMSample(const uint16_t sample_slot,
if (!multi_output_channel_[i] || multi_output_channel_[i]->Size() == 0) { if (!multi_output_channel_[i] || multi_output_channel_[i]->Size() == 0) {
continue; continue;
} }
multi_output_channe_[i]->ReadAll(data[i]); multi_output_channel_[i]->ReadAll(data[i]);
} }
} else { } else {
input_channel_->Close(); input_channel_->Close();
...@@ -388,15 +389,17 @@ void DatasetImpl<T>::TDMSample(const uint16_t sample_slot, ...@@ -388,15 +389,17 @@ void DatasetImpl<T>::TDMSample(const uint16_t sample_slot,
} }
auto tree_ptr = TreeWrapper::GetInstance(); auto tree_ptr = TreeWrapper::GetInstance();
auto fleet_ptr = FleetWrapper::GetInstance();
for (auto i = 0; i < data.size(); i++) { for (auto i = 0; i < data.size(); i++) {
std::vector<T> tmp_results; std::vector<T> tmp_results;
tree_ptr->sample(sample_slot, type_slot, data[i], tmp_results); tree_ptr->sample(sample_slot, type_slot, data[i], &tmp_results);
sample_results.push_back(tmp_results); sample_results.push_back(tmp_results);
} }
auto output_channel_num = multi_output_channel_.size();
for (auto i = 0; i < sample_results.size(); i++) { for (auto i = 0; i < sample_results.size(); i++) {
auto output_idx = fleet_ptr->LocalRandomEngine()() % output_channel_num; auto output_idx = fleet_ptr->LocalRandomEngine()() % output_channel_num;
multi_output_channe_[i]->Write(std::move(sample_results[i])) multi_output_channel_[output_idx]->Write(std::move(sample_results[i]));
} }
data.clear(); data.clear();
......
...@@ -45,6 +45,12 @@ class Dataset { ...@@ -45,6 +45,12 @@ class Dataset {
public: public:
Dataset() {} Dataset() {}
virtual ~Dataset() {} virtual ~Dataset() {}
virtual void InitTDMTree(
const std::vector<std::pair<std::string, std::string>> config) = 0;
virtual void TDMSample(const uint16_t sample_slot,
const uint64_t type_slot) = 0;
virtual void TDMDump(std::string name, const uint64_t table_id,
int fea_value_dim, const std::string tree_path) = 0;
// set file list // set file list
virtual void SetFileList(const std::vector<std::string>& filelist) = 0; virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
// set readers' num // set readers' num
...@@ -162,13 +168,11 @@ class DatasetImpl : public Dataset { ...@@ -162,13 +168,11 @@ class DatasetImpl : public Dataset {
virtual void InitTDMTree( virtual void InitTDMTree(
const std::vector<std::pair<std::string, std::string>> config); const std::vector<std::pair<std::string, std::string>> config);
virtual void TDMSample(std::string name, const uint64_t table_id, virtual void TDMSample(const uint16_t sample_slot, const uint64_t type_slot);
int fea_value_dim, const std::string tree_path);
virtual void TDMDump(std::string name, const uint64_t table_id, virtual void TDMDump(std::string name, const uint64_t table_id,
int fea_value_dim, const std::string tree_path); int fea_value_dim, const std::string tree_path);
virtual void virtual void SetFileList( virtual void SetFileList(const std::vector<std::string>& filelist);
const std::vector<std::string>& filelist);
virtual void SetThreadNum(int thread_num); virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num); virtual void SetTrainerNum(int trainer_num);
virtual void SetFleetSendBatchSize(int64_t size); virtual void SetFleetSendBatchSize(int64_t size);
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once #pragma once
#include <memory> #include <memory>
#include <string> #include <string>
...@@ -6,6 +20,7 @@ ...@@ -6,6 +20,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/fleet/tree_wrapper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -117,7 +132,7 @@ int Tree::dump_tree(const uint64_t table_id, int fea_value_dim, ...@@ -117,7 +132,7 @@ int Tree::dump_tree(const uint64_t table_id, int fea_value_dim,
std::shared_ptr<FILE> fp = std::shared_ptr<FILE> fp =
paddle::framework::fs_open(tree_path, "w", &ret, ""); paddle::framework::fs_open(tree_path, "w", &ret, "");
std::vector<uint64_t> fea_keys, std::vector<float*> pull_result_ptr; std::vector<uint64_t> fea_keys, std::vector<float *> pull_result_ptr;
fea_keys.reserve(_total_node_num); fea_keys.reserve(_total_node_num);
pull_result_ptr.reserve(_total_node_num); pull_result_ptr.reserve(_total_node_num);
...@@ -167,7 +182,7 @@ int Tree::dump_tree(const uint64_t table_id, int fea_value_dim, ...@@ -167,7 +182,7 @@ int Tree::dump_tree(const uint64_t table_id, int fea_value_dim,
} }
bool Tree::trace_back(uint64_t id, bool Tree::trace_back(uint64_t id,
std::vector<std::pair<uint64_t, uint32_t>>& ids) { std::vector<std::pair<uint64_t, uint32_t>>* ids) {
ids.clear(); ids.clear();
std::unordered_map<uint64_t, Node*>::iterator find_it = std::unordered_map<uint64_t, Node*>::iterator find_it =
_leaf_node_map.find(id); _leaf_node_map.find(id);
...@@ -178,11 +193,11 @@ bool Tree::trace_back(uint64_t id, ...@@ -178,11 +193,11 @@ bool Tree::trace_back(uint64_t id,
Node* node = find_it->second; Node* node = find_it->second;
while (node != NULL) { while (node != NULL) {
height++; height++;
ids.emplace_back(node->id, 0); ids->emplace_back(node->id, 0);
node = node->parent_node; node = node->parent_node;
} }
for (auto& id : ids) { for (auto& pair_id : *ids) {
id.second = height--; pair_id.second = height--;
} }
} }
return true; return true;
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once #pragma once
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_feed.h"
...@@ -11,8 +26,8 @@ namespace paddle { ...@@ -11,8 +26,8 @@ namespace paddle {
namespace framework { namespace framework {
struct Node { struct Node {
Node::Node() : parent_node(NULL), id(0), height(0) {} Node() : parent_node(NULL), id(0), height(0) {}
~Node(){}; ~Node() {}
std::vector<Node*> sub_nodes; std::vector<Node*> sub_nodes;
// uint32_t sub_node_num; // uint32_t sub_node_num;
Node* parent_node; Node* parent_node;
...@@ -34,8 +49,7 @@ class Tree { ...@@ -34,8 +49,7 @@ class Tree {
void print_tree(); void print_tree();
int dump_tree(const uint64_t table_id, int fea_value_dim, int dump_tree(const uint64_t table_id, int fea_value_dim,
const std::string tree_path); const std::string tree_path);
//采样:从叶节点回溯到根节点 bool trace_back(uint64_t id, std::vector<std::pair<uint64_t, uint32_t>>* ids);
void trace_back(uint64_t id, std::vector<std::pair<uint64_t, uint32_t>>& ids);
int load(std::string path); int load(std::string path);
Node* get_node(); Node* get_node();
size_t get_total_node_num(); size_t get_total_node_num();
...@@ -75,8 +89,8 @@ class TreeWrapper { ...@@ -75,8 +89,8 @@ class TreeWrapper {
if (tree_map.find(name) != tree_map.end()) { if (tree_map.find(name) != tree_map.end()) {
return; return;
} }
TreePtr tree = new Tree(); TreePtr tree = std::make_shared<Tree>();
tree.load(tree_path); tree->load(tree_path);
tree_map.insert(std::pair<std::string, TreePtr>{name, tree}); tree_map.insert(std::pair<std::string, TreePtr>{name, tree});
} }
...@@ -89,32 +103,39 @@ class TreeWrapper { ...@@ -89,32 +103,39 @@ class TreeWrapper {
} }
void sample(const uint16_t sample_slot, const uint64_t type_slot, void sample(const uint16_t sample_slot, const uint64_t type_slot,
std::vector<Record>& src_datas, const std::vector<Record>& src_datas,
std::vector<Record>& sample_results) { std::vector<Record>* sample_results) {
sample_results.clear(); sample_results->clear();
auto debug_idx = 0;
for (auto& data : src_datas) { for (auto& data : src_datas) {
if (debug_idx == 0) {
VLOG(0) << "src record";
data.Print();
}
uint64_t sample_feasign_idx = -1, type_feasign_idx = -1; uint64_t sample_feasign_idx = -1, type_feasign_idx = -1;
for (auto i = 0; i < data.uint64_feasigns_.size(); i++) { for (uint64_t i = 0; i < data.uint64_feasigns_.size(); i++) {
if (data.uint64_feasigns_[i].slot() == sample_slot) { if (data.uint64_feasigns_[i].slot() == sample_slot) {
sample_feasign_idx = i; sample_feasign_idx = i;
} }
if (data.uint64_feasigns_.slot() == type_slot) { if (data.uint64_feasigns_[i].slot() == type_slot) {
type_feasign_idx = i; type_feasign_idx = i;
} }
} }
if (sample_feasign_idx > 0) { if (sample_feasign_idx > 0) {
std::vector<std::pair<uint64_t, uint32_t>> trace_ids; std::vector<std::pair<uint64_t, uint32_t>> trace_ids;
for (auto name : tree_map) { for (std::unordered_map<std::string, TreePtr>::iterator ite =
bool in_tree = tree_map.at(name)->trace_back( tree_map.begin();
ite != tree_map.end(); ite++) {
bool in_tree = ite->second->trace_back(
data.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_, data.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_,
trace_ids); &trace_ids);
if (in_tree) { if (in_tree) {
break; break;
} else { } else {
PADDLE_ENFORCE_EQ(trace_ids.size(), 0, ""); PADDLE_ENFORCE_EQ(trace_ids.size(), 0, "");
} }
} }
for (auto i = 0; i < trace_ids.size(); i++) { for (uint64_t i = 0; i < trace_ids.size(); i++) {
Record instance(data); Record instance(data);
instance.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_ = instance.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_ =
trace_ids[i].first; trace_ids[i].first;
...@@ -122,9 +143,14 @@ class TreeWrapper { ...@@ -122,9 +143,14 @@ class TreeWrapper {
instance.uint64_feasigns_[type_feasign_idx] instance.uint64_feasigns_[type_feasign_idx]
.sign() .sign()
.uint64_feasign_ += trace_ids[i].second * 100; .uint64_feasign_ += trace_ids[i].second * 100;
sample_results.push_back(instance); if (debug_idx == 0) {
VLOG(0) << "sample results:" << i;
instance.Print();
}
sample_results->push_back(instance);
} }
} }
debug_idx += 1;
} }
return; return;
} }
......
...@@ -608,6 +608,15 @@ class InMemoryDataset(DatasetBase): ...@@ -608,6 +608,15 @@ class InMemoryDataset(DatasetBase):
self.dataset.generate_local_tables_unlock( self.dataset.generate_local_tables_unlock(
table_id, fea_dim, read_thread_num, consume_thread_num, shard_num) table_id, fea_dim, read_thread_num, consume_thread_num, shard_num)
def init_tdm_tree(self, configs):
self.dataset.init_tdm_tree(configs)
def tdm_sample(self, sample_slot, type_slot):
self.dataset.tdm_sample(sample_slot, type_slot)
def tdm_dump(self, name, table_id, fea_value_dim, tree_path):
self.dataset.tdm_dump(name, table_id, fea_value_dim, tree_path)
def load_into_memory(self): def load_into_memory(self):
""" """
Load data into memory Load data into memory
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册