提交 7c0196d4 编写于 作者: M malin10

Bug fix: remove `multi_output_channe_` typos, rename `dump_tree()` to `dump()`, and pass output parameters (`trace_back` ids, `sample` results) as pointers instead of mutable references in the TDM dataset/tree code.

上级 ef24bd78
......@@ -24,6 +24,7 @@
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/fleet/tree_wrapper.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/platform/monitor.h"
#include "paddle/fluid/platform/timer.h"
......@@ -358,7 +359,7 @@ template <typename T>
// Dump the TDM tree registered under `name` to `tree_path`.
// Per-node feature values (presumably pulled from parameter table
// `table_id`, `fea_value_dim` floats each — confirm against
// Tree::dump_tree) are written alongside the node ids.
template <typename T>
void DatasetImpl<T>::TDMDump(std::string name, const uint64_t table_id,
                             int fea_value_dim, const std::string tree_path) {
  auto tree_ptr = TreeWrapper::GetInstance();
  // The wrapper entry point was renamed dump_tree() -> dump(); the stale
  // pre-rename call has been dropped so the tree is dumped exactly once.
  tree_ptr->dump(name, table_id, fea_value_dim, tree_path);
}
// do sample
......@@ -378,7 +379,7 @@ void DatasetImpl<T>::TDMSample(const uint16_t sample_slot,
if (!multi_output_channel_[i] || multi_output_channel_[i]->Size() == 0) {
continue;
}
multi_output_channe_[i]->ReadAll(data[i]);
multi_output_channel_[i]->ReadAll(data[i]);
}
} else {
input_channel_->Close();
......@@ -388,15 +389,17 @@ void DatasetImpl<T>::TDMSample(const uint16_t sample_slot,
}
auto tree_ptr = TreeWrapper::GetInstance();
auto fleet_ptr = FleetWrapper::GetInstance();
for (auto i = 0; i < data.size(); i++) {
std::vector<T> tmp_results;
tree_ptr->sample(sample_slot, type_slot, data[i], tmp_results);
tree_ptr->sample(sample_slot, type_slot, data[i], &tmp_results);
sample_results.push_back(tmp_results);
}
auto output_channel_num = multi_output_channel_.size();
for (auto i = 0; i < sample_results.size(); i++) {
auto output_idx = fleet_ptr->LocalRandomEngine()() % output_channel_num;
multi_output_channe_[i]->Write(std::move(sample_results[i]))
multi_output_channel_[output_idx]->Write(std::move(sample_results[i]));
}
data.clear();
......
......@@ -45,6 +45,12 @@ class Dataset {
public:
Dataset() {}
virtual ~Dataset() {}
// Load TDM trees into the process-wide TreeWrapper.
// `config` is presumably a list of (name, tree_path) pairs — TODO confirm
// against the DatasetImpl implementation.
virtual void InitTDMTree(
const std::vector<std::pair<std::string, std::string>> config) = 0;
// TDM negative sampling over in-memory records: `sample_slot` holds the
// leaf-node feasign to trace back to the root; `type_slot`'s feasign is
// offset by the traced node's height.
virtual void TDMSample(const uint16_t sample_slot,
const uint64_t type_slot) = 0;
// Dump tree `name` (node feature values of dimension `fea_value_dim`,
// pulled from table `table_id`) to `tree_path`.
virtual void TDMDump(std::string name, const uint64_t table_id,
int fea_value_dim, const std::string tree_path) = 0;
// set file list
virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
......@@ -162,13 +168,11 @@ class DatasetImpl : public Dataset {
// Load TDM trees described by `config` (presumably (name, tree_path)
// pairs — TODO confirm) into the shared TreeWrapper.
virtual void InitTDMTree(
const std::vector<std::pair<std::string, std::string>> config);
// Trace each record's leaf feasign (slot `sample_slot`) back to the root
// and emit one sampled record per ancestor; `type_slot` is height-offset.
virtual void TDMSample(const uint16_t sample_slot, const uint64_t type_slot);
// Dump tree `name` with feature values from table `table_id` to `tree_path`.
virtual void TDMDump(std::string name, const uint64_t table_id,
int fea_value_dim, const std::string tree_path);
virtual void SetFileList(const std::vector<std::string>& filelist);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num);
virtual void SetFleetSendBatchSize(int64_t size);
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
......@@ -6,6 +20,7 @@
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/fleet/tree_wrapper.h"
namespace paddle {
namespace framework {
......@@ -117,7 +132,7 @@ int Tree::dump_tree(const uint64_t table_id, int fea_value_dim,
std::shared_ptr<FILE> fp =
paddle::framework::fs_open(tree_path, "w", &ret, "");
std::vector<uint64_t> fea_keys, std::vector<float*> pull_result_ptr;
std::vector<uint64_t> fea_keys, std::vector<float *> pull_result_ptr;
fea_keys.reserve(_total_node_num);
pull_result_ptr.reserve(_total_node_num);
......@@ -167,7 +182,7 @@ int Tree::dump_tree(const uint64_t table_id, int fea_value_dim,
}
bool Tree::trace_back(uint64_t id,
std::vector<std::pair<uint64_t, uint32_t>>& ids) {
std::vector<std::pair<uint64_t, uint32_t>>* ids) {
ids.clear();
std::unordered_map<uint64_t, Node*>::iterator find_it =
_leaf_node_map.find(id);
......@@ -178,11 +193,11 @@ bool Tree::trace_back(uint64_t id,
Node* node = find_it->second;
while (node != NULL) {
height++;
ids.emplace_back(node->id, 0);
ids->emplace_back(node->id, 0);
node = node->parent_node;
}
for (auto& id : ids) {
id.second = height--;
for (auto& pair_id : *ids) {
pair_id.second = height--;
}
}
return true;
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
......@@ -11,8 +26,8 @@ namespace paddle {
namespace framework {
// A TDM tree node: linked upward via `parent_node` and downward via
// `sub_nodes`. Default-constructed as an orphan root candidate
// (no parent, id 0, height 0).
struct Node {
Node() : parent_node(NULL), id(0), height(0) {}
~Node() {}
std::vector<Node*> sub_nodes;
// uint32_t sub_node_num;
Node* parent_node;
......@@ -34,8 +49,7 @@ class Tree {
void print_tree();
// Dump every node (feature values of dimension `fea_value_dim`, pulled
// from table `table_id`) to the file at `tree_path`; returns a status code.
int dump_tree(const uint64_t table_id, int fea_value_dim,
const std::string tree_path);
// Sampling: trace back from a leaf node to the root. On success returns
// true and fills `ids` with (node_id, height) pairs, root last with the
// largest height; returns false when `id` is not a known leaf.
bool trace_back(uint64_t id, std::vector<std::pair<uint64_t, uint32_t>>* ids);
int load(std::string path);
Node* get_node();
size_t get_total_node_num();
......@@ -75,8 +89,8 @@ class TreeWrapper {
// Idempotent: a tree already registered under `name` is kept as-is.
if (tree_map.find(name) != tree_map.end()) {
  return;
}
// TreePtr is a shared_ptr (see make_shared below); the stale raw-`new`
// lines from before the fix have been dropped.
TreePtr tree = std::make_shared<Tree>();
tree->load(tree_path);
tree_map.insert(std::pair<std::string, TreePtr>{name, tree});
}
......@@ -89,32 +103,39 @@ class TreeWrapper {
}
void sample(const uint16_t sample_slot, const uint64_t type_slot,
std::vector<Record>& src_datas,
std::vector<Record>& sample_results) {
sample_results.clear();
const std::vector<Record>& src_datas,
std::vector<Record>* sample_results) {
sample_results->clear();
auto debug_idx = 0;
for (auto& data : src_datas) {
if (debug_idx == 0) {
VLOG(0) << "src record";
data.Print();
}
uint64_t sample_feasign_idx = -1, type_feasign_idx = -1;
for (auto i = 0; i < data.uint64_feasigns_.size(); i++) {
for (uint64_t i = 0; i < data.uint64_feasigns_.size(); i++) {
if (data.uint64_feasigns_[i].slot() == sample_slot) {
sample_feasign_idx = i;
}
if (data.uint64_feasigns_.slot() == type_slot) {
if (data.uint64_feasigns_[i].slot() == type_slot) {
type_feasign_idx = i;
}
}
if (sample_feasign_idx > 0) {
std::vector<std::pair<uint64_t, uint32_t>> trace_ids;
for (auto name : tree_map) {
bool in_tree = tree_map.at(name)->trace_back(
for (std::unordered_map<std::string, TreePtr>::iterator ite =
tree_map.begin();
ite != tree_map.end(); ite++) {
bool in_tree = ite->second->trace_back(
data.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_,
trace_ids);
&trace_ids);
if (in_tree) {
break;
} else {
PADDLE_ENFORCE_EQ(trace_ids.size(), 0, "");
}
}
for (auto i = 0; i < trace_ids.size(); i++) {
for (uint64_t i = 0; i < trace_ids.size(); i++) {
Record instance(data);
instance.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_ =
trace_ids[i].first;
......@@ -122,9 +143,14 @@ class TreeWrapper {
instance.uint64_feasigns_[type_feasign_idx]
.sign()
.uint64_feasign_ += trace_ids[i].second * 100;
sample_results.push_back(instance);
if (debug_idx == 0) {
VLOG(0) << "sample results:" << i;
instance.Print();
}
sample_results->push_back(instance);
}
}
debug_idx += 1;
}
return;
}
......
......@@ -608,6 +608,15 @@ class InMemoryDataset(DatasetBase):
self.dataset.generate_local_tables_unlock(
table_id, fea_dim, read_thread_num, consume_thread_num, shard_num)
def init_tdm_tree(self, configs):
    """
    Build the TDM tree(s) on the underlying C++ dataset.

    Args:
        configs: presumably a list of (name, tree_path) string pairs
            forwarded to the C++ ``InitTDMTree`` — TODO confirm.
    """
    self.dataset.init_tdm_tree(configs)
def tdm_sample(self, sample_slot, type_slot):
    """
    Run TDM sampling over the records held in memory.

    Args:
        sample_slot (int): slot id whose feasign is the leaf node to
            trace back from.
        type_slot (int): slot id whose feasign gets offset by the traced
            node's height.
    """
    self.dataset.tdm_sample(sample_slot, type_slot)
def tdm_dump(self, name, table_id, fea_value_dim, tree_path):
    """
    Dump the named TDM tree to a file.

    Args:
        name (str): tree name registered via ``init_tdm_tree``.
        table_id (int): parameter table holding per-node feature values.
        fea_value_dim (int): dimension of each node's feature value.
        tree_path (str): output file path.
    """
    self.dataset.tdm_dump(name, table_id, fea_value_dim, tree_path)
def load_into_memory(self):
"""
Load data into memory
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册