未验证 提交 a8c3a902 编写于 作者: 1 123malin 提交者: GitHub

tree-based-model (#31696)

* add index_dataset and index_sampler for tree-based model
上级 825d4957
......@@ -14,6 +14,7 @@ endif()
add_subdirectory(table)
add_subdirectory(service)
add_subdirectory(test)
add_subdirectory(index_dataset)
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
......
proto_library(index_dataset_proto SRCS index_dataset.proto)
cc_library(index_wrapper SRCS index_wrapper.cc DEPS index_dataset_proto)
cc_library(index_sampler SRCS index_sampler.cc DEPS index_wrapper)
if(WITH_PYTHON)
py_proto_compile(index_dataset_py_proto SRCS index_dataset.proto)
endif()
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package paddle.distributed;
message IndexNode {
required uint64 id = 1;
required bool is_leaf = 2;
required float probability = 3;
}
message TreeMeta {
required int32 height = 1;
required int32 branch = 2;
}
message KVItem {
required bytes key = 1;
required bytes value = 2;
}
\ No newline at end of file
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/index_dataset/index_sampler.h"
#include "paddle/fluid/operators/math/sampler.h"
namespace paddle {
namespace distributed {
using Sampler = paddle::operators::math::Sampler;
std::vector<std::vector<uint64_t>> LayerWiseSampler::sample(
const std::vector<std::vector<uint64_t>>& user_inputs,
const std::vector<uint64_t>& target_ids, bool with_hierarchy) {
auto input_num = target_ids.size();
auto user_feature_num = user_inputs[0].size();
std::vector<std::vector<uint64_t>> outputs(
input_num * layer_counts_sum_,
std::vector<uint64_t>(user_feature_num + 2));
auto max_layer = tree_->Height();
std::vector<Sampler*> sampler_vec(max_layer - start_sample_layer_);
std::vector<std::vector<IndexNode>> layer_ids(max_layer -
start_sample_layer_);
auto layer_index = max_layer - 1;
size_t idx = 0;
while (layer_index >= start_sample_layer_) {
auto layer_codes = tree_->GetLayerCodes(layer_index);
layer_ids[idx] = tree_->GetNodes(layer_codes);
sampler_vec[idx] = new paddle::operators::math::UniformSampler(
layer_ids[idx].size() - 1, seed_);
layer_index--;
idx++;
}
idx = 0;
for (size_t i = 0; i < input_num; i++) {
auto travel_codes =
tree_->GetTravelCodes(target_ids[i], start_sample_layer_);
auto travel_path = tree_->GetNodes(travel_codes);
for (size_t j = 0; j < travel_path.size(); j++) {
// user
if (j > 0 && with_hierarchy) {
auto ancestor_codes =
tree_->GetAncestorCodes(user_inputs[i], max_layer - j - 1);
auto hierarchical_user = tree_->GetNodes(ancestor_codes);
for (int idx_offset = 0; idx_offset <= layer_counts_[j]; idx_offset++) {
for (size_t k = 0; k < user_feature_num; k++) {
outputs[idx + idx_offset][k] = hierarchical_user[k].id();
}
}
} else {
for (int idx_offset = 0; idx_offset <= layer_counts_[j]; idx_offset++) {
for (size_t k = 0; k < user_feature_num; k++) {
outputs[idx + idx_offset][k] = user_inputs[i][k];
}
}
}
// sampler ++
outputs[idx][user_feature_num] = travel_path[j].id();
outputs[idx][user_feature_num + 1] = 1.0;
idx += 1;
for (int idx_offset = 0; idx_offset < layer_counts_[j]; idx_offset++) {
int sample_res = 0;
do {
sample_res = sampler_vec[j]->Sample();
} while (layer_ids[j][sample_res].id() == travel_path[j].id());
outputs[idx + idx_offset][user_feature_num] =
layer_ids[j][sample_res].id();
outputs[idx + idx_offset][user_feature_num + 1] = 0;
}
idx += layer_counts_[j];
}
}
for (size_t i = 0; i < sampler_vec.size(); i++) {
delete sampler_vec[i];
}
return outputs;
}
} // end namespace distributed
} // end namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace distributed {
class IndexSampler {
public:
virtual ~IndexSampler() {}
IndexSampler() {}
template <typename T>
static std::shared_ptr<IndexSampler> Init(const std::string& name) {
std::shared_ptr<IndexSampler> instance = nullptr;
instance.reset(new T(name));
return instance;
}
virtual void init_layerwise_conf(const std::vector<int>& layer_sample_counts,
int start_sample_layer = 1, int seed = 0) {}
virtual void init_beamsearch_conf(const int64_t k) {}
virtual std::vector<std::vector<uint64_t>> sample(
const std::vector<std::vector<uint64_t>>& user_inputs,
const std::vector<uint64_t>& input_targets,
bool with_hierarchy = false) = 0;
};
class LayerWiseSampler : public IndexSampler {
public:
virtual ~LayerWiseSampler() {}
explicit LayerWiseSampler(const std::string& name) {
tree_ = IndexWrapper::GetInstance()->get_tree_index(name);
}
void init_layerwise_conf(const std::vector<int>& layer_sample_counts,
int start_sample_layer, int seed) override {
seed_ = seed;
start_sample_layer_ = start_sample_layer;
PADDLE_ENFORCE_GT(
start_sample_layer_, 0,
paddle::platform::errors::InvalidArgument(
"start sampler layer = [%d], it should greater than 0.",
start_sample_layer_));
PADDLE_ENFORCE_LT(start_sample_layer_, tree_->Height(),
paddle::platform::errors::InvalidArgument(
"start sampler layer = [%d], it should less than "
"max_layer, which is [%d].",
start_sample_layer_, tree_->Height()));
size_t i = 0;
layer_counts_sum_ = 0;
layer_counts_.clear();
int cur_layer = start_sample_layer_;
while (cur_layer < tree_->Height()) {
int layer_sample_num = 1;
if (i < layer_sample_counts.size()) {
layer_sample_num = layer_sample_counts[i];
}
layer_counts_sum_ += layer_sample_num + 1;
layer_counts_.push_back(layer_sample_num);
VLOG(3) << "[INFO] level " << cur_layer
<< " sample_layer_counts.push_back: " << layer_sample_num;
cur_layer += 1;
i += 1;
}
reverse(layer_counts_.begin(), layer_counts_.end());
VLOG(3) << "sample counts sum: " << layer_counts_sum_;
}
std::vector<std::vector<uint64_t>> sample(
const std::vector<std::vector<uint64_t>>& user_inputs,
const std::vector<uint64_t>& target_ids, bool with_hierarchy) override;
private:
std::vector<int> layer_counts_;
int64_t layer_counts_sum_{0};
std::shared_ptr<TreeIndex> tree_{nullptr};
int seed_{0};
int start_sample_layer_{1};
};
} // end namespace distributed
} // end namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/io/fs.h"
#include <boost/algorithm/string.hpp>
#include <boost/lexical_cast.hpp>
#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
namespace paddle {
namespace distributed {
std::shared_ptr<IndexWrapper> IndexWrapper::s_instance_(nullptr);
int TreeIndex::Load(const std::string filename) {
int err_no;
auto fp = paddle::framework::fs_open_read(filename, &err_no, "");
PADDLE_ENFORCE_NE(
fp, nullptr,
platform::errors::InvalidArgument(
"Open file %s failed. Please check whether the file exists.",
filename));
int num = 0;
max_id_ = 0;
fake_node_.set_id(0);
fake_node_.set_is_leaf(false);
fake_node_.set_probability(0.0);
max_code_ = 0;
size_t ret = fread(&num, sizeof(num), 1, fp.get());
while (ret == 1 && num > 0) {
std::string content(num, '\0');
size_t read_num =
fread(const_cast<char*>(content.data()), 1, num, fp.get());
PADDLE_ENFORCE_EQ(
read_num, static_cast<size_t>(num),
platform::errors::InvalidArgument(
"Read from file: %s failed. Valid Format is "
"an integer representing the length of the following string, "
"and the string itself.We got an iteger[% d], "
"but the following string's length is [%d].",
filename, num, read_num));
KVItem item;
PADDLE_ENFORCE_EQ(
item.ParseFromString(content), true,
platform::errors::InvalidArgument("Parse from file: %s failed. It's "
"content can't be parsed by KVItem.",
filename));
if (item.key() == ".tree_meta") {
meta_.ParseFromString(item.value());
} else {
auto code = boost::lexical_cast<uint64_t>(item.key());
IndexNode node;
node.ParseFromString(item.value());
PADDLE_ENFORCE_NE(node.id(), 0,
platform::errors::InvalidArgument(
"Node'id should not be equel to zero."));
if (node.is_leaf()) {
id_codes_map_[node.id()] = code;
}
data_[code] = node;
if (node.id() > max_id_) {
max_id_ = node.id();
}
if (code > max_code_) {
max_code_ = code;
}
}
ret = fread(&num, sizeof(num), 1, fp.get());
}
total_nodes_num_ = data_.size();
max_code_ += 1;
return 0;
}
std::vector<IndexNode> TreeIndex::GetNodes(const std::vector<uint64_t>& codes) {
std::vector<IndexNode> nodes;
nodes.reserve(codes.size());
for (size_t i = 0; i < codes.size(); i++) {
if (CheckIsValid(codes[i])) {
nodes.push_back(data_.at(codes[i]));
} else {
nodes.push_back(fake_node_);
}
}
return nodes;
}
std::vector<uint64_t> TreeIndex::GetLayerCodes(int level) {
uint64_t level_num = static_cast<uint64_t>(std::pow(meta_.branch(), level));
uint64_t level_offset = level_num - 1;
std::vector<uint64_t> res;
res.reserve(level_num);
for (uint64_t i = 0; i < level_num; i++) {
auto code = level_offset + i;
if (CheckIsValid(code)) {
res.push_back(code);
}
}
return res;
}
std::vector<uint64_t> TreeIndex::GetAncestorCodes(
const std::vector<uint64_t>& ids, int level) {
std::vector<uint64_t> res;
res.reserve(ids.size());
int cur_level;
for (size_t i = 0; i < ids.size(); i++) {
if (id_codes_map_.find(ids[i]) == id_codes_map_.end()) {
res.push_back(max_code_);
} else {
auto code = id_codes_map_.at(ids[i]);
cur_level = meta_.height() - 1;
while (level >= 0 && cur_level > level) {
code = (code - 1) / meta_.branch();
cur_level--;
}
res.push_back(code);
}
}
return res;
}
std::vector<uint64_t> TreeIndex::GetChildrenCodes(uint64_t ancestor,
int level) {
auto level_code_num = static_cast<uint64_t>(std::pow(meta_.branch(), level));
auto code_min = level_code_num - 1;
auto code_max = meta_.branch() * level_code_num - 1;
std::vector<uint64_t> parent;
parent.push_back(ancestor);
std::vector<uint64_t> res;
size_t p_idx = 0;
while (true) {
size_t p_size = parent.size();
for (; p_idx < p_size; p_idx++) {
for (int i = 0; i < meta_.branch(); i++) {
auto code = parent[p_idx] * meta_.branch() + i + 1;
if (data_.find(code) != data_.end()) parent.push_back(code);
}
}
if ((code_min <= parent[p_idx]) && (parent[p_idx] < code_max)) {
break;
}
}
return std::vector<uint64_t>(parent.begin() + p_idx, parent.end());
}
std::vector<uint64_t> TreeIndex::GetTravelCodes(uint64_t id, int start_level) {
std::vector<uint64_t> res;
PADDLE_ENFORCE_NE(id_codes_map_.find(id), id_codes_map_.end(),
paddle::platform::errors::InvalidArgument(
"id = %d doesn't exist in Tree.", id));
auto code = id_codes_map_.at(id);
int level = meta_.height() - 1;
while (level >= start_level) {
res.push_back(code);
code = (code - 1) / meta_.branch();
level--;
}
return res;
}
std::vector<IndexNode> TreeIndex::GetAllLeafs() {
std::vector<IndexNode> res;
res.reserve(id_codes_map_.size());
for (auto& ite : id_codes_map_) {
auto code = ite.second;
res.push_back(data_.at(code));
}
return res;
}
} // end namespace distributed
} // end namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/index_dataset/index_dataset.pb.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace distributed {
class Index {
public:
Index() {}
~Index() {}
};
class TreeIndex : public Index {
public:
TreeIndex() {}
~TreeIndex() {}
int Height() { return meta_.height(); }
int Branch() { return meta_.branch(); }
uint64_t TotalNodeNums() { return total_nodes_num_; }
uint64_t EmbSize() { return max_id_ + 1; }
int Load(const std::string path);
inline bool CheckIsValid(int code) {
if (data_.find(code) != data_.end()) {
return true;
} else {
return false;
}
}
std::vector<IndexNode> GetNodes(const std::vector<uint64_t>& codes);
std::vector<uint64_t> GetLayerCodes(int level);
std::vector<uint64_t> GetAncestorCodes(const std::vector<uint64_t>& ids,
int level);
std::vector<uint64_t> GetChildrenCodes(uint64_t ancestor, int level);
std::vector<uint64_t> GetTravelCodes(uint64_t id, int start_level);
std::vector<IndexNode> GetAllLeafs();
std::unordered_map<uint64_t, IndexNode> data_;
std::unordered_map<uint64_t, uint64_t> id_codes_map_;
uint64_t total_nodes_num_;
TreeMeta meta_;
uint64_t max_id_;
uint64_t max_code_;
IndexNode fake_node_;
};
using TreePtr = std::shared_ptr<TreeIndex>;
class IndexWrapper {
public:
virtual ~IndexWrapper() {}
IndexWrapper() {}
void clear_tree() { tree_map.clear(); }
TreePtr get_tree_index(const std::string name) {
PADDLE_ENFORCE_NE(tree_map.find(name), tree_map.end(),
paddle::platform::errors::InvalidArgument(
"tree [%s] doesn't exist. Please insert it firstly "
"by API[\' insert_tree_index \'].",
name));
return tree_map[name];
}
void insert_tree_index(const std::string name, const std::string tree_path) {
if (tree_map.find(name) != tree_map.end()) {
VLOG(0) << "Tree " << name << " has already existed.";
return;
}
TreePtr tree = std::make_shared<TreeIndex>();
int ret = tree->Load(tree_path);
PADDLE_ENFORCE_EQ(ret, 0, paddle::platform::errors::InvalidArgument(
"Load tree[%s] from path[%s] failed. Please "
"check whether the file exists.",
name, tree_path));
tree_map.insert(std::pair<std::string, TreePtr>{name, tree});
}
static std::shared_ptr<IndexWrapper> GetInstancePtr() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::distributed::IndexWrapper());
}
return s_instance_;
}
static IndexWrapper* GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::distributed::IndexWrapper());
}
return s_instance_.get();
}
private:
static std::shared_ptr<IndexWrapper> s_instance_;
std::unordered_map<std::string, TreePtr> tree_map;
};
} // end namespace distributed
} // end namespace paddle
......@@ -191,13 +191,15 @@ if(WITH_PYTHON)
py_proto_compile(distributed_strategy_py_proto SRCS distributed_strategy.proto)
#Generate an empty \
#__init__.py to make framework_py_proto as a valid python module.
add_custom_target(fleet_proto_init ALL
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/distributed/fleet/proto
COMMAND ${CMAKE_COMMAND} -E touch ${PADDLE_BINARY_DIR}/python/paddle/distributed/fleet/proto/__init__.py
)
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init trainer_py_proto distributed_strategy_py_proto)
add_dependencies(framework_py_proto framework_py_proto_init trainer_py_proto distributed_strategy_py_proto fleet_proto_init)
if (NOT WIN32)
add_custom_command(TARGET framework_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/distributed/fleet/proto
COMMAND ${CMAKE_COMMAND} -E touch ${PADDLE_BINARY_DIR}/python/paddle/distributed/fleet/proto/__init__.py
COMMAND cp *.py ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/
COMMAND cp distributed_strategy_*.py ${PADDLE_BINARY_DIR}/python/paddle/distributed/fleet/proto
COMMENT "Copy generated python proto into directory paddle/fluid/proto."
......@@ -207,8 +209,6 @@ if(WITH_PYTHON)
string(REPLACE "/" "\\" fleet_proto_dstpath "${PADDLE_BINARY_DIR}/python/paddle/distributed/fleet/proto/")
add_custom_command(TARGET framework_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/distributed/fleet/proto
COMMAND ${CMAKE_COMMAND} -E touch ${PADDLE_BINARY_DIR}/python/paddle/distributed/fleet/proto/__init__.py
COMMAND copy /Y *.py ${proto_dstpath}
COMMAND copy /Y distributed_strategy_*.py ${fleet_proto_dstpath}
COMMENT "Copy generated python proto into directory paddle/fluid/proto."
......@@ -217,6 +217,12 @@ if(WITH_PYTHON)
endif(NOT WIN32)
endif()
if (WITH_PSCORE)
add_custom_target(index_dataset_proto_init ALL DEPENDS fleet_proto_init index_dataset_py_proto
COMMAND cp ${PADDLE_BINARY_DIR}/paddle/fluid/distributed/index_dataset/index_dataset_*.py ${PADDLE_BINARY_DIR}/python/paddle/distributed/fleet/proto
COMMENT "Copy generated python proto into directory paddle/distributed/fleet/proto.")
endif(WITH_PSCORE)
cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)
......
......@@ -76,7 +76,7 @@ endif (WITH_CRYPTO)
if (WITH_PSCORE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=type-limits -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result")
set_source_files_properties(fleet_py.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
list(APPEND PYBIND_DEPS fleet communicator)
list(APPEND PYBIND_DEPS fleet communicator index_wrapper index_sampler)
list(APPEND PYBIND_SRCS fleet_py.cc)
endif()
......
......@@ -30,6 +30,8 @@ limitations under the License. */
#include "paddle/fluid/distributed/communicator_common.h"
#include "paddle/fluid/distributed/fleet.h"
#include "paddle/fluid/distributed/index_dataset/index_sampler.h"
#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
#include "paddle/fluid/distributed/service/communicator.h"
#include "paddle/fluid/distributed/service/env.h"
#include "paddle/fluid/distributed/service/graph_brpc_client.h"
......@@ -212,5 +214,76 @@ void BindGraphPyClient(py::module* m) {
.def("bind_local_server", &GraphPyClient::bind_local_server);
}
using paddle::distributed::TreeIndex;
using paddle::distributed::IndexWrapper;
using paddle::distributed::IndexNode;
void BindIndexNode(py::module* m) {
py::class_<IndexNode>(*m, "IndexNode")
.def(py::init<>())
.def("id", [](IndexNode& self) { return self.id(); })
.def("is_leaf", [](IndexNode& self) { return self.is_leaf(); })
.def("probability", [](IndexNode& self) { return self.probability(); });
}
void BindTreeIndex(py::module* m) {
py::class_<TreeIndex, std::shared_ptr<TreeIndex>>(*m, "TreeIndex")
.def(py::init([](const std::string name, const std::string path) {
auto index_wrapper = IndexWrapper::GetInstancePtr();
index_wrapper->insert_tree_index(name, path);
return index_wrapper->get_tree_index(name);
}))
.def("height", [](TreeIndex& self) { return self.Height(); })
.def("branch", [](TreeIndex& self) { return self.Branch(); })
.def("total_node_nums",
[](TreeIndex& self) { return self.TotalNodeNums(); })
.def("emb_size", [](TreeIndex& self) { return self.EmbSize(); })
.def("get_all_leafs", [](TreeIndex& self) { return self.GetAllLeafs(); })
.def("get_nodes",
[](TreeIndex& self, const std::vector<uint64_t>& codes) {
return self.GetNodes(codes);
})
.def("get_layer_codes",
[](TreeIndex& self, int level) { return self.GetLayerCodes(level); })
.def("get_ancestor_codes",
[](TreeIndex& self, const std::vector<uint64_t>& ids, int level) {
return self.GetAncestorCodes(ids, level);
})
.def("get_children_codes",
[](TreeIndex& self, uint64_t ancestor, int level) {
return self.GetChildrenCodes(ancestor, level);
})
.def("get_travel_codes",
[](TreeIndex& self, uint64_t id, int start_level) {
return self.GetTravelCodes(id, start_level);
});
}
void BindIndexWrapper(py::module* m) {
py::class_<IndexWrapper, std::shared_ptr<IndexWrapper>>(*m, "IndexWrapper")
.def(py::init([]() { return IndexWrapper::GetInstancePtr(); }))
.def("insert_tree_index", &IndexWrapper::insert_tree_index)
.def("get_tree_index", &IndexWrapper::get_tree_index)
.def("clear_tree", &IndexWrapper::clear_tree);
}
using paddle::distributed::IndexSampler;
using paddle::distributed::LayerWiseSampler;
void BindIndexSampler(py::module* m) {
py::class_<IndexSampler, std::shared_ptr<IndexSampler>>(*m, "IndexSampler")
.def(py::init([](const std::string& mode, const std::string& name) {
if (mode == "by_layerwise") {
return IndexSampler::Init<LayerWiseSampler>(name);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported IndexSampler Type!"));
}
}))
.def("init_layerwise_conf", &IndexSampler::init_layerwise_conf)
.def("init_beamsearch_conf", &IndexSampler::init_beamsearch_conf)
.def("sample", &IndexSampler::sample);
}
} // end namespace pybind
} // namespace paddle
......@@ -32,5 +32,9 @@ void BindGraphPyService(py::module* m);
void BindGraphPyFeatureNode(py::module* m);
void BindGraphPyServer(py::module* m);
void BindGraphPyClient(py::module* m);
void BindIndexNode(py::module* m);
void BindTreeIndex(py::module* m);
void BindIndexWrapper(py::module* m);
void BindIndexSampler(py::module* m);
} // namespace pybind
} // namespace paddle
......@@ -3092,6 +3092,11 @@ All parameter, weight, gradient are variables in Paddle.
BindGraphPyService(&m);
BindGraphPyServer(&m);
BindGraphPyClient(&m);
BindIndexNode(&m);
BindTreeIndex(&m);
BindIndexWrapper(&m);
BindIndexSampler(&m);
#endif
}
} // namespace pybind
......
......@@ -12,3 +12,4 @@
# See the License for the specific language governing permissions and
from .dataset import *
from .index_dataset import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import core
class Index(object):
def __init__(self, name):
self._name = name
class TreeIndex(Index):
def __init__(self, name, path):
super(TreeIndex, self).__init__(name)
self._wrapper = core.IndexWrapper()
self._wrapper.insert_tree_index(name, path)
self._tree = self._wrapper.get_tree_index(name)
self._height = self._tree.height()
self._branch = self._tree.branch()
self._total_node_nums = self._tree.total_node_nums()
self._emb_size = self._tree.emb_size()
self._layerwise_sampler = None
def height(self):
return self._height
def branch(self):
return self._branch
def total_node_nums(self):
return self._total_node_nums
def emb_size(self):
return self._emb_size
def get_all_leafs(self):
return self._tree.get_all_leafs()
def get_nodes(self, codes):
return self._tree.get_nodes(codes)
def get_layer_codes(self, level):
return self._tree.get_layer_codes(level)
def get_travel_codes(self, id, start_level=0):
return self._tree.get_travel_codes(id, start_level)
def get_ancestor_codes(self, ids, level):
return self._tree.get_ancestor_codes(ids, level)
def get_children_codes(self, ancestor, level):
return self._tree.get_children_codes(ancestor, level)
def get_travel_path(self, child, ancestor):
res = []
while (child > ancestor):
res.append(child)
child = int((child - 1) / self._branch)
return res
def get_pi_relation(self, ids, level):
codes = self.get_ancestor_codes(ids, level)
return dict(zip(ids, codes))
def init_layerwise_sampler(self,
layer_sample_counts,
start_sample_layer=1,
seed=0):
assert self._layerwise_sampler is None
self._layerwise_sampler = core.IndexSampler("by_layerwise", self._name)
self._layerwise_sampler.init_layerwise_conf(layer_sample_counts,
start_sample_layer, seed)
def layerwise_sample(self, user_input, index_input, with_hierarchy=False):
if self._layerwise_sampler is None:
raise ValueError("please init layerwise_sampler first.")
return self._layerwise_sampler.sample(user_input, index_input,
with_hierarchy)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from paddle.dataset.common import download, DATA_HOME
from paddle.distributed.fleet.dataset import TreeIndex
class TestTreeIndex(unittest.TestCase):
def test_tree_index(self):
path = download(
"https://paddlerec.bj.bcebos.com/tree-based/data/demo_tree.pb",
"tree_index_unittest", "cadec20089f5a8a44d320e117d9f9f1a")
tree = TreeIndex("demo", path)
height = tree.height()
branch = tree.branch()
self.assertTrue(height == 14)
self.assertTrue(branch == 2)
self.assertEqual(tree.total_node_nums(), 15581)
self.assertEqual(tree.emb_size(), 5171136)
# get_layer_codes
layer_node_ids = []
layer_node_codes = []
for i in range(tree.height()):
layer_node_codes.append(tree.get_layer_codes(i))
layer_node_ids.append(
[node.id() for node in tree.get_nodes(layer_node_codes[-1])])
all_leaf_ids = [node.id() for node in tree.get_all_leafs()]
self.assertEqual(sum(all_leaf_ids), sum(layer_node_ids[-1]))
# get_travel
travel_codes = tree.get_travel_codes(all_leaf_ids[0])
travel_ids = [node.id() for node in tree.get_nodes(travel_codes)]
for i in range(height):
self.assertIn(travel_ids[i], layer_node_ids[height - 1 - i])
self.assertIn(travel_codes[i], layer_node_codes[height - 1 - i])
# get_ancestor
ancestor_codes = tree.get_ancestor_codes([all_leaf_ids[0]], height - 2)
ancestor_ids = [node.id() for node in tree.get_nodes(ancestor_codes)]
self.assertEqual(ancestor_ids[0], travel_ids[1])
self.assertEqual(ancestor_codes[0], travel_codes[1])
# get_pi_relation
pi_relation = tree.get_pi_relation([all_leaf_ids[0]], height - 2)
self.assertEqual(pi_relation[all_leaf_ids[0]], ancestor_codes[0])
# get_travel_path
travel_path_codes = tree.get_travel_path(travel_codes[0],
travel_codes[-1])
travel_path_ids = [
node.id() for node in tree.get_nodes(travel_path_codes)
]
self.assertEquals(travel_path_ids + [travel_ids[-1]], travel_ids)
self.assertEquals(travel_path_codes + [travel_codes[-1]], travel_codes)
# get_children
children_codes = tree.get_children_codes(travel_codes[1], height - 1)
children_ids = [node.id() for node in tree.get_nodes(children_codes)]
self.assertIn(all_leaf_ids[0], children_ids)
class TestIndexSampler(unittest.TestCase):
def test_layerwise_sampler(self):
path = download(
"https://paddlerec.bj.bcebos.com/tree-based/data/demo_tree.pb",
"tree_index_unittest", "cadec20089f5a8a44d320e117d9f9f1a")
tree = TreeIndex("demo", path)
layer_nodes = []
for i in range(tree.height()):
layer_codes = tree.get_layer_codes(i)
layer_nodes.append(
[node.id() for node in tree.get_nodes(layer_codes)])
sample_num = range(1, 10000)
start_sample_layer = 1
seed = 0
sample_layers = tree.height() - start_sample_layer
sample_num = sample_num[:sample_layers]
layer_sample_counts = list(sample_num) + [1] * (sample_layers -
len(sample_num))
total_sample_num = sum(layer_sample_counts) + len(layer_sample_counts)
tree.init_layerwise_sampler(sample_num, start_sample_layer, seed)
ids = [315757, 838060, 1251533, 403522, 2473624, 3321007]
parent_path = {}
for i in range(len(ids)):
tmp = tree.get_travel_codes(ids[i], start_sample_layer)
parent_path[ids[i]] = [node.id() for node in tree.get_nodes(tmp)]
# check sample res with_hierarchy = False
sample_res = tree.layerwise_sample(
[[315757, 838060], [1251533, 403522]], [2473624, 3321007], False)
idx = 0
layer = tree.height() - 1
for i in range(len(layer_sample_counts)):
for j in range(layer_sample_counts[0 - (i + 1)] + 1):
self.assertTrue(sample_res[idx + j][0] == 315757)
self.assertTrue(sample_res[idx + j][1] == 838060)
self.assertTrue(sample_res[idx + j][2] in layer_nodes[layer])
if j == 0:
self.assertTrue(sample_res[idx + j][3] == 1)
self.assertTrue(
sample_res[idx + j][2] == parent_path[2473624][i])
else:
self.assertTrue(sample_res[idx + j][3] == 0)
self.assertTrue(
sample_res[idx + j][2] != parent_path[2473624][i])
idx += layer_sample_counts[0 - (i + 1)] + 1
layer -= 1
self.assertTrue(idx == total_sample_num)
layer = tree.height() - 1
for i in range(len(layer_sample_counts)):
for j in range(layer_sample_counts[0 - (i + 1)] + 1):
self.assertTrue(sample_res[idx + j][0] == 1251533)
self.assertTrue(sample_res[idx + j][1] == 403522)
self.assertTrue(sample_res[idx + j][2] in layer_nodes[layer])
if j == 0:
self.assertTrue(sample_res[idx + j][3] == 1)
self.assertTrue(
sample_res[idx + j][2] == parent_path[3321007][i])
else:
self.assertTrue(sample_res[idx + j][3] == 0)
self.assertTrue(
sample_res[idx + j][2] != parent_path[3321007][i])
idx += layer_sample_counts[0 - (i + 1)] + 1
layer -= 1
self.assertTrue(idx == total_sample_num * 2)
# check sample res with_hierarchy = True
sample_res_with_hierarchy = tree.layerwise_sample(
[[315757, 838060], [1251533, 403522]], [2473624, 3321007], True)
idx = 0
layer = tree.height() - 1
for i in range(len(layer_sample_counts)):
for j in range(layer_sample_counts[0 - (i + 1)] + 1):
self.assertTrue(sample_res_with_hierarchy[idx + j][0] ==
parent_path[315757][i])
self.assertTrue(sample_res_with_hierarchy[idx + j][1] ==
parent_path[838060][i])
self.assertTrue(
sample_res_with_hierarchy[idx + j][2] in layer_nodes[layer])
if j == 0:
self.assertTrue(sample_res_with_hierarchy[idx + j][3] == 1)
self.assertTrue(sample_res_with_hierarchy[idx + j][2] ==
parent_path[2473624][i])
else:
self.assertTrue(sample_res_with_hierarchy[idx + j][3] == 0)
self.assertTrue(sample_res_with_hierarchy[idx + j][2] !=
parent_path[2473624][i])
idx += layer_sample_counts[0 - (i + 1)] + 1
layer -= 1
self.assertTrue(idx == total_sample_num)
layer = tree.height() - 1
for i in range(len(layer_sample_counts)):
for j in range(layer_sample_counts[0 - (i + 1)] + 1):
self.assertTrue(sample_res_with_hierarchy[idx + j][0] ==
parent_path[1251533][i])
self.assertTrue(sample_res_with_hierarchy[idx + j][1] ==
parent_path[403522][i])
self.assertTrue(
sample_res_with_hierarchy[idx + j][2] in layer_nodes[layer])
if j == 0:
self.assertTrue(sample_res_with_hierarchy[idx + j][3] == 1)
self.assertTrue(sample_res_with_hierarchy[idx + j][2] ==
parent_path[3321007][i])
else:
self.assertTrue(sample_res_with_hierarchy[idx + j][3] == 0)
self.assertTrue(sample_res_with_hierarchy[idx + j][2] !=
parent_path[3321007][i])
idx += layer_sample_counts[0 - (i + 1)] + 1
layer -= 1
self.assertTrue(idx == 2 * total_sample_num)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册