Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a8c3a902
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a8c3a902
编写于
4月 15, 2021
作者:
1
123malin
提交者:
GitHub
4月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
tree-based-model (#31696)
* add index_dataset and index_sampler for tree-based model
上级
825d4957
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
932 addition
and
6 deletion
+932
-6
paddle/fluid/distributed/CMakeLists.txt
paddle/fluid/distributed/CMakeLists.txt
+1
-0
paddle/fluid/distributed/index_dataset/CMakeLists.txt
paddle/fluid/distributed/index_dataset/CMakeLists.txt
+7
-0
paddle/fluid/distributed/index_dataset/index_dataset.proto
paddle/fluid/distributed/index_dataset/index_dataset.proto
+32
-0
paddle/fluid/distributed/index_dataset/index_sampler.cc
paddle/fluid/distributed/index_dataset/index_sampler.cc
+95
-0
paddle/fluid/distributed/index_dataset/index_sampler.h
paddle/fluid/distributed/index_dataset/index_sampler.h
+100
-0
paddle/fluid/distributed/index_dataset/index_wrapper.cc
paddle/fluid/distributed/index_dataset/index_wrapper.cc
+196
-0
paddle/fluid/distributed/index_dataset/index_wrapper.h
paddle/fluid/distributed/index_dataset/index_wrapper.h
+120
-0
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+11
-5
paddle/fluid/pybind/CMakeLists.txt
paddle/fluid/pybind/CMakeLists.txt
+1
-1
paddle/fluid/pybind/fleet_py.cc
paddle/fluid/pybind/fleet_py.cc
+73
-0
paddle/fluid/pybind/fleet_py.h
paddle/fluid/pybind/fleet_py.h
+4
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+5
-0
python/paddle/distributed/fleet/dataset/__init__.py
python/paddle/distributed/fleet/dataset/__init__.py
+1
-0
python/paddle/distributed/fleet/dataset/index_dataset.py
python/paddle/distributed/fleet/dataset/index_dataset.py
+88
-0
python/paddle/fluid/tests/unittests/test_dist_tree_index.py
python/paddle/fluid/tests/unittests/test_dist_tree_index.py
+198
-0
未找到文件。
paddle/fluid/distributed/CMakeLists.txt
浏览文件 @
a8c3a902
...
...
@@ -14,6 +14,7 @@ endif()
add_subdirectory
(
table
)
add_subdirectory
(
service
)
add_subdirectory
(
test
)
add_subdirectory
(
index_dataset
)
get_property
(
RPC_DEPS GLOBAL PROPERTY RPC_DEPS
)
...
...
paddle/fluid/distributed/index_dataset/CMakeLists.txt
0 → 100644
浏览文件 @
a8c3a902
proto_library
(
index_dataset_proto SRCS index_dataset.proto
)
cc_library
(
index_wrapper SRCS index_wrapper.cc DEPS index_dataset_proto
)
cc_library
(
index_sampler SRCS index_sampler.cc DEPS index_wrapper
)
if
(
WITH_PYTHON
)
py_proto_compile
(
index_dataset_py_proto SRCS index_dataset.proto
)
endif
()
paddle/fluid/distributed/index_dataset/index_dataset.proto
0 → 100644
浏览文件 @
a8c3a902
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax
=
"proto2"
;
package
paddle
.
distributed
;
message
IndexNode
{
required
uint64
id
=
1
;
required
bool
is_leaf
=
2
;
required
float
probability
=
3
;
}
message
TreeMeta
{
required
int32
height
=
1
;
required
int32
branch
=
2
;
}
message
KVItem
{
required
bytes
key
=
1
;
required
bytes
value
=
2
;
}
\ No newline at end of file
paddle/fluid/distributed/index_dataset/index_sampler.cc
0 → 100644
浏览文件 @
a8c3a902
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/index_dataset/index_sampler.h"
#include "paddle/fluid/operators/math/sampler.h"
namespace
paddle
{
namespace
distributed
{
using
Sampler
=
paddle
::
operators
::
math
::
Sampler
;
std
::
vector
<
std
::
vector
<
uint64_t
>>
LayerWiseSampler
::
sample
(
const
std
::
vector
<
std
::
vector
<
uint64_t
>>&
user_inputs
,
const
std
::
vector
<
uint64_t
>&
target_ids
,
bool
with_hierarchy
)
{
auto
input_num
=
target_ids
.
size
();
auto
user_feature_num
=
user_inputs
[
0
].
size
();
std
::
vector
<
std
::
vector
<
uint64_t
>>
outputs
(
input_num
*
layer_counts_sum_
,
std
::
vector
<
uint64_t
>
(
user_feature_num
+
2
));
auto
max_layer
=
tree_
->
Height
();
std
::
vector
<
Sampler
*>
sampler_vec
(
max_layer
-
start_sample_layer_
);
std
::
vector
<
std
::
vector
<
IndexNode
>>
layer_ids
(
max_layer
-
start_sample_layer_
);
auto
layer_index
=
max_layer
-
1
;
size_t
idx
=
0
;
while
(
layer_index
>=
start_sample_layer_
)
{
auto
layer_codes
=
tree_
->
GetLayerCodes
(
layer_index
);
layer_ids
[
idx
]
=
tree_
->
GetNodes
(
layer_codes
);
sampler_vec
[
idx
]
=
new
paddle
::
operators
::
math
::
UniformSampler
(
layer_ids
[
idx
].
size
()
-
1
,
seed_
);
layer_index
--
;
idx
++
;
}
idx
=
0
;
for
(
size_t
i
=
0
;
i
<
input_num
;
i
++
)
{
auto
travel_codes
=
tree_
->
GetTravelCodes
(
target_ids
[
i
],
start_sample_layer_
);
auto
travel_path
=
tree_
->
GetNodes
(
travel_codes
);
for
(
size_t
j
=
0
;
j
<
travel_path
.
size
();
j
++
)
{
// user
if
(
j
>
0
&&
with_hierarchy
)
{
auto
ancestor_codes
=
tree_
->
GetAncestorCodes
(
user_inputs
[
i
],
max_layer
-
j
-
1
);
auto
hierarchical_user
=
tree_
->
GetNodes
(
ancestor_codes
);
for
(
int
idx_offset
=
0
;
idx_offset
<=
layer_counts_
[
j
];
idx_offset
++
)
{
for
(
size_t
k
=
0
;
k
<
user_feature_num
;
k
++
)
{
outputs
[
idx
+
idx_offset
][
k
]
=
hierarchical_user
[
k
].
id
();
}
}
}
else
{
for
(
int
idx_offset
=
0
;
idx_offset
<=
layer_counts_
[
j
];
idx_offset
++
)
{
for
(
size_t
k
=
0
;
k
<
user_feature_num
;
k
++
)
{
outputs
[
idx
+
idx_offset
][
k
]
=
user_inputs
[
i
][
k
];
}
}
}
// sampler ++
outputs
[
idx
][
user_feature_num
]
=
travel_path
[
j
].
id
();
outputs
[
idx
][
user_feature_num
+
1
]
=
1.0
;
idx
+=
1
;
for
(
int
idx_offset
=
0
;
idx_offset
<
layer_counts_
[
j
];
idx_offset
++
)
{
int
sample_res
=
0
;
do
{
sample_res
=
sampler_vec
[
j
]
->
Sample
();
}
while
(
layer_ids
[
j
][
sample_res
].
id
()
==
travel_path
[
j
].
id
());
outputs
[
idx
+
idx_offset
][
user_feature_num
]
=
layer_ids
[
j
][
sample_res
].
id
();
outputs
[
idx
+
idx_offset
][
user_feature_num
+
1
]
=
0
;
}
idx
+=
layer_counts_
[
j
];
}
}
for
(
size_t
i
=
0
;
i
<
sampler_vec
.
size
();
i
++
)
{
delete
sampler_vec
[
i
];
}
return
outputs
;
}
}
// end namespace distributed
}
// end namespace paddle
paddle/fluid/distributed/index_dataset/index_sampler.h
0 → 100644
浏览文件 @
a8c3a902
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
distributed
{
class
IndexSampler
{
public:
virtual
~
IndexSampler
()
{}
IndexSampler
()
{}
template
<
typename
T
>
static
std
::
shared_ptr
<
IndexSampler
>
Init
(
const
std
::
string
&
name
)
{
std
::
shared_ptr
<
IndexSampler
>
instance
=
nullptr
;
instance
.
reset
(
new
T
(
name
));
return
instance
;
}
virtual
void
init_layerwise_conf
(
const
std
::
vector
<
int
>&
layer_sample_counts
,
int
start_sample_layer
=
1
,
int
seed
=
0
)
{}
virtual
void
init_beamsearch_conf
(
const
int64_t
k
)
{}
virtual
std
::
vector
<
std
::
vector
<
uint64_t
>>
sample
(
const
std
::
vector
<
std
::
vector
<
uint64_t
>>&
user_inputs
,
const
std
::
vector
<
uint64_t
>&
input_targets
,
bool
with_hierarchy
=
false
)
=
0
;
};
class
LayerWiseSampler
:
public
IndexSampler
{
public:
virtual
~
LayerWiseSampler
()
{}
explicit
LayerWiseSampler
(
const
std
::
string
&
name
)
{
tree_
=
IndexWrapper
::
GetInstance
()
->
get_tree_index
(
name
);
}
void
init_layerwise_conf
(
const
std
::
vector
<
int
>&
layer_sample_counts
,
int
start_sample_layer
,
int
seed
)
override
{
seed_
=
seed
;
start_sample_layer_
=
start_sample_layer
;
PADDLE_ENFORCE_GT
(
start_sample_layer_
,
0
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"start sampler layer = [%d], it should greater than 0."
,
start_sample_layer_
));
PADDLE_ENFORCE_LT
(
start_sample_layer_
,
tree_
->
Height
(),
paddle
::
platform
::
errors
::
InvalidArgument
(
"start sampler layer = [%d], it should less than "
"max_layer, which is [%d]."
,
start_sample_layer_
,
tree_
->
Height
()));
size_t
i
=
0
;
layer_counts_sum_
=
0
;
layer_counts_
.
clear
();
int
cur_layer
=
start_sample_layer_
;
while
(
cur_layer
<
tree_
->
Height
())
{
int
layer_sample_num
=
1
;
if
(
i
<
layer_sample_counts
.
size
())
{
layer_sample_num
=
layer_sample_counts
[
i
];
}
layer_counts_sum_
+=
layer_sample_num
+
1
;
layer_counts_
.
push_back
(
layer_sample_num
);
VLOG
(
3
)
<<
"[INFO] level "
<<
cur_layer
<<
" sample_layer_counts.push_back: "
<<
layer_sample_num
;
cur_layer
+=
1
;
i
+=
1
;
}
reverse
(
layer_counts_
.
begin
(),
layer_counts_
.
end
());
VLOG
(
3
)
<<
"sample counts sum: "
<<
layer_counts_sum_
;
}
std
::
vector
<
std
::
vector
<
uint64_t
>>
sample
(
const
std
::
vector
<
std
::
vector
<
uint64_t
>>&
user_inputs
,
const
std
::
vector
<
uint64_t
>&
target_ids
,
bool
with_hierarchy
)
override
;
private:
std
::
vector
<
int
>
layer_counts_
;
int64_t
layer_counts_sum_
{
0
};
std
::
shared_ptr
<
TreeIndex
>
tree_
{
nullptr
};
int
seed_
{
0
};
int
start_sample_layer_
{
1
};
};
}
// end namespace distributed
}
// end namespace paddle
paddle/fluid/distributed/index_dataset/index_wrapper.cc
0 → 100644
浏览文件 @
a8c3a902
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/io/fs.h"
#include <boost/algorithm/string.hpp>
#include <boost/lexical_cast.hpp>
#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
namespace
paddle
{
namespace
distributed
{
std
::
shared_ptr
<
IndexWrapper
>
IndexWrapper
::
s_instance_
(
nullptr
);
int
TreeIndex
::
Load
(
const
std
::
string
filename
)
{
int
err_no
;
auto
fp
=
paddle
::
framework
::
fs_open_read
(
filename
,
&
err_no
,
""
);
PADDLE_ENFORCE_NE
(
fp
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
"Open file %s failed. Please check whether the file exists."
,
filename
));
int
num
=
0
;
max_id_
=
0
;
fake_node_
.
set_id
(
0
);
fake_node_
.
set_is_leaf
(
false
);
fake_node_
.
set_probability
(
0.0
);
max_code_
=
0
;
size_t
ret
=
fread
(
&
num
,
sizeof
(
num
),
1
,
fp
.
get
());
while
(
ret
==
1
&&
num
>
0
)
{
std
::
string
content
(
num
,
'\0'
);
size_t
read_num
=
fread
(
const_cast
<
char
*>
(
content
.
data
()),
1
,
num
,
fp
.
get
());
PADDLE_ENFORCE_EQ
(
read_num
,
static_cast
<
size_t
>
(
num
),
platform
::
errors
::
InvalidArgument
(
"Read from file: %s failed. Valid Format is "
"an integer representing the length of the following string, "
"and the string itself.We got an iteger[% d], "
"but the following string's length is [%d]."
,
filename
,
num
,
read_num
));
KVItem
item
;
PADDLE_ENFORCE_EQ
(
item
.
ParseFromString
(
content
),
true
,
platform
::
errors
::
InvalidArgument
(
"Parse from file: %s failed. It's "
"content can't be parsed by KVItem."
,
filename
));
if
(
item
.
key
()
==
".tree_meta"
)
{
meta_
.
ParseFromString
(
item
.
value
());
}
else
{
auto
code
=
boost
::
lexical_cast
<
uint64_t
>
(
item
.
key
());
IndexNode
node
;
node
.
ParseFromString
(
item
.
value
());
PADDLE_ENFORCE_NE
(
node
.
id
(),
0
,
platform
::
errors
::
InvalidArgument
(
"Node'id should not be equel to zero."
));
if
(
node
.
is_leaf
())
{
id_codes_map_
[
node
.
id
()]
=
code
;
}
data_
[
code
]
=
node
;
if
(
node
.
id
()
>
max_id_
)
{
max_id_
=
node
.
id
();
}
if
(
code
>
max_code_
)
{
max_code_
=
code
;
}
}
ret
=
fread
(
&
num
,
sizeof
(
num
),
1
,
fp
.
get
());
}
total_nodes_num_
=
data_
.
size
();
max_code_
+=
1
;
return
0
;
}
std
::
vector
<
IndexNode
>
TreeIndex
::
GetNodes
(
const
std
::
vector
<
uint64_t
>&
codes
)
{
std
::
vector
<
IndexNode
>
nodes
;
nodes
.
reserve
(
codes
.
size
());
for
(
size_t
i
=
0
;
i
<
codes
.
size
();
i
++
)
{
if
(
CheckIsValid
(
codes
[
i
]))
{
nodes
.
push_back
(
data_
.
at
(
codes
[
i
]));
}
else
{
nodes
.
push_back
(
fake_node_
);
}
}
return
nodes
;
}
std
::
vector
<
uint64_t
>
TreeIndex
::
GetLayerCodes
(
int
level
)
{
uint64_t
level_num
=
static_cast
<
uint64_t
>
(
std
::
pow
(
meta_
.
branch
(),
level
));
uint64_t
level_offset
=
level_num
-
1
;
std
::
vector
<
uint64_t
>
res
;
res
.
reserve
(
level_num
);
for
(
uint64_t
i
=
0
;
i
<
level_num
;
i
++
)
{
auto
code
=
level_offset
+
i
;
if
(
CheckIsValid
(
code
))
{
res
.
push_back
(
code
);
}
}
return
res
;
}
std
::
vector
<
uint64_t
>
TreeIndex
::
GetAncestorCodes
(
const
std
::
vector
<
uint64_t
>&
ids
,
int
level
)
{
std
::
vector
<
uint64_t
>
res
;
res
.
reserve
(
ids
.
size
());
int
cur_level
;
for
(
size_t
i
=
0
;
i
<
ids
.
size
();
i
++
)
{
if
(
id_codes_map_
.
find
(
ids
[
i
])
==
id_codes_map_
.
end
())
{
res
.
push_back
(
max_code_
);
}
else
{
auto
code
=
id_codes_map_
.
at
(
ids
[
i
]);
cur_level
=
meta_
.
height
()
-
1
;
while
(
level
>=
0
&&
cur_level
>
level
)
{
code
=
(
code
-
1
)
/
meta_
.
branch
();
cur_level
--
;
}
res
.
push_back
(
code
);
}
}
return
res
;
}
std
::
vector
<
uint64_t
>
TreeIndex
::
GetChildrenCodes
(
uint64_t
ancestor
,
int
level
)
{
auto
level_code_num
=
static_cast
<
uint64_t
>
(
std
::
pow
(
meta_
.
branch
(),
level
));
auto
code_min
=
level_code_num
-
1
;
auto
code_max
=
meta_
.
branch
()
*
level_code_num
-
1
;
std
::
vector
<
uint64_t
>
parent
;
parent
.
push_back
(
ancestor
);
std
::
vector
<
uint64_t
>
res
;
size_t
p_idx
=
0
;
while
(
true
)
{
size_t
p_size
=
parent
.
size
();
for
(;
p_idx
<
p_size
;
p_idx
++
)
{
for
(
int
i
=
0
;
i
<
meta_
.
branch
();
i
++
)
{
auto
code
=
parent
[
p_idx
]
*
meta_
.
branch
()
+
i
+
1
;
if
(
data_
.
find
(
code
)
!=
data_
.
end
())
parent
.
push_back
(
code
);
}
}
if
((
code_min
<=
parent
[
p_idx
])
&&
(
parent
[
p_idx
]
<
code_max
))
{
break
;
}
}
return
std
::
vector
<
uint64_t
>
(
parent
.
begin
()
+
p_idx
,
parent
.
end
());
}
std
::
vector
<
uint64_t
>
TreeIndex
::
GetTravelCodes
(
uint64_t
id
,
int
start_level
)
{
std
::
vector
<
uint64_t
>
res
;
PADDLE_ENFORCE_NE
(
id_codes_map_
.
find
(
id
),
id_codes_map_
.
end
(),
paddle
::
platform
::
errors
::
InvalidArgument
(
"id = %d doesn't exist in Tree."
,
id
));
auto
code
=
id_codes_map_
.
at
(
id
);
int
level
=
meta_
.
height
()
-
1
;
while
(
level
>=
start_level
)
{
res
.
push_back
(
code
);
code
=
(
code
-
1
)
/
meta_
.
branch
();
level
--
;
}
return
res
;
}
std
::
vector
<
IndexNode
>
TreeIndex
::
GetAllLeafs
()
{
std
::
vector
<
IndexNode
>
res
;
res
.
reserve
(
id_codes_map_
.
size
());
for
(
auto
&
ite
:
id_codes_map_
)
{
auto
code
=
ite
.
second
;
res
.
push_back
(
data_
.
at
(
code
));
}
return
res
;
}
}
// end namespace distributed
}
// end namespace paddle
paddle/fluid/distributed/index_dataset/index_wrapper.h
0 → 100644
浏览文件 @
a8c3a902
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/index_dataset/index_dataset.pb.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
distributed
{
class
Index
{
public:
Index
()
{}
~
Index
()
{}
};
class
TreeIndex
:
public
Index
{
public:
TreeIndex
()
{}
~
TreeIndex
()
{}
int
Height
()
{
return
meta_
.
height
();
}
int
Branch
()
{
return
meta_
.
branch
();
}
uint64_t
TotalNodeNums
()
{
return
total_nodes_num_
;
}
uint64_t
EmbSize
()
{
return
max_id_
+
1
;
}
int
Load
(
const
std
::
string
path
);
inline
bool
CheckIsValid
(
int
code
)
{
if
(
data_
.
find
(
code
)
!=
data_
.
end
())
{
return
true
;
}
else
{
return
false
;
}
}
std
::
vector
<
IndexNode
>
GetNodes
(
const
std
::
vector
<
uint64_t
>&
codes
);
std
::
vector
<
uint64_t
>
GetLayerCodes
(
int
level
);
std
::
vector
<
uint64_t
>
GetAncestorCodes
(
const
std
::
vector
<
uint64_t
>&
ids
,
int
level
);
std
::
vector
<
uint64_t
>
GetChildrenCodes
(
uint64_t
ancestor
,
int
level
);
std
::
vector
<
uint64_t
>
GetTravelCodes
(
uint64_t
id
,
int
start_level
);
std
::
vector
<
IndexNode
>
GetAllLeafs
();
std
::
unordered_map
<
uint64_t
,
IndexNode
>
data_
;
std
::
unordered_map
<
uint64_t
,
uint64_t
>
id_codes_map_
;
uint64_t
total_nodes_num_
;
TreeMeta
meta_
;
uint64_t
max_id_
;
uint64_t
max_code_
;
IndexNode
fake_node_
;
};
using
TreePtr
=
std
::
shared_ptr
<
TreeIndex
>
;
class
IndexWrapper
{
public:
virtual
~
IndexWrapper
()
{}
IndexWrapper
()
{}
void
clear_tree
()
{
tree_map
.
clear
();
}
TreePtr
get_tree_index
(
const
std
::
string
name
)
{
PADDLE_ENFORCE_NE
(
tree_map
.
find
(
name
),
tree_map
.
end
(),
paddle
::
platform
::
errors
::
InvalidArgument
(
"tree [%s] doesn't exist. Please insert it firstly "
"by API[
\'
insert_tree_index
\'
]."
,
name
));
return
tree_map
[
name
];
}
void
insert_tree_index
(
const
std
::
string
name
,
const
std
::
string
tree_path
)
{
if
(
tree_map
.
find
(
name
)
!=
tree_map
.
end
())
{
VLOG
(
0
)
<<
"Tree "
<<
name
<<
" has already existed."
;
return
;
}
TreePtr
tree
=
std
::
make_shared
<
TreeIndex
>
();
int
ret
=
tree
->
Load
(
tree_path
);
PADDLE_ENFORCE_EQ
(
ret
,
0
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"Load tree[%s] from path[%s] failed. Please "
"check whether the file exists."
,
name
,
tree_path
));
tree_map
.
insert
(
std
::
pair
<
std
::
string
,
TreePtr
>
{
name
,
tree
});
}
static
std
::
shared_ptr
<
IndexWrapper
>
GetInstancePtr
()
{
if
(
NULL
==
s_instance_
)
{
s_instance_
.
reset
(
new
paddle
::
distributed
::
IndexWrapper
());
}
return
s_instance_
;
}
static
IndexWrapper
*
GetInstance
()
{
if
(
NULL
==
s_instance_
)
{
s_instance_
.
reset
(
new
paddle
::
distributed
::
IndexWrapper
());
}
return
s_instance_
.
get
();
}
private:
static
std
::
shared_ptr
<
IndexWrapper
>
s_instance_
;
std
::
unordered_map
<
std
::
string
,
TreePtr
>
tree_map
;
};
}
// end namespace distributed
}
// end namespace paddle
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
a8c3a902
...
...
@@ -191,13 +191,15 @@ if(WITH_PYTHON)
py_proto_compile
(
distributed_strategy_py_proto SRCS distributed_strategy.proto
)
#Generate an empty \
#__init__.py to make framework_py_proto as a valid python module.
add_custom_target
(
fleet_proto_init ALL
COMMAND
${
CMAKE_COMMAND
}
-E make_directory
${
PADDLE_BINARY_DIR
}
/python/paddle/distributed/fleet/proto
COMMAND
${
CMAKE_COMMAND
}
-E touch
${
PADDLE_BINARY_DIR
}
/python/paddle/distributed/fleet/proto/__init__.py
)
add_custom_target
(
framework_py_proto_init ALL COMMAND
${
CMAKE_COMMAND
}
-E touch __init__.py
)
add_dependencies
(
framework_py_proto framework_py_proto_init trainer_py_proto distributed_strategy_py_proto
)
add_dependencies
(
framework_py_proto framework_py_proto_init trainer_py_proto distributed_strategy_py_proto
fleet_proto_init
)
if
(
NOT WIN32
)
add_custom_command
(
TARGET framework_py_proto POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E make_directory
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/proto
COMMAND
${
CMAKE_COMMAND
}
-E make_directory
${
PADDLE_BINARY_DIR
}
/python/paddle/distributed/fleet/proto
COMMAND
${
CMAKE_COMMAND
}
-E touch
${
PADDLE_BINARY_DIR
}
/python/paddle/distributed/fleet/proto/__init__.py
COMMAND cp *.py
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/proto/
COMMAND cp distributed_strategy_*.py
${
PADDLE_BINARY_DIR
}
/python/paddle/distributed/fleet/proto
COMMENT
"Copy generated python proto into directory paddle/fluid/proto."
...
...
@@ -207,8 +209,6 @@ if(WITH_PYTHON)
string
(
REPLACE
"/"
"
\\
"
fleet_proto_dstpath
"
${
PADDLE_BINARY_DIR
}
/python/paddle/distributed/fleet/proto/"
)
add_custom_command
(
TARGET framework_py_proto POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E make_directory
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/proto
COMMAND
${
CMAKE_COMMAND
}
-E make_directory
${
PADDLE_BINARY_DIR
}
/python/paddle/distributed/fleet/proto
COMMAND
${
CMAKE_COMMAND
}
-E touch
${
PADDLE_BINARY_DIR
}
/python/paddle/distributed/fleet/proto/__init__.py
COMMAND copy /Y *.py
${
proto_dstpath
}
COMMAND copy /Y distributed_strategy_*.py
${
fleet_proto_dstpath
}
COMMENT
"Copy generated python proto into directory paddle/fluid/proto."
...
...
@@ -217,6 +217,12 @@ if(WITH_PYTHON)
endif
(
NOT WIN32
)
endif
()
if
(
WITH_PSCORE
)
add_custom_target
(
index_dataset_proto_init ALL DEPENDS fleet_proto_init index_dataset_py_proto
COMMAND cp
${
PADDLE_BINARY_DIR
}
/paddle/fluid/distributed/index_dataset/index_dataset_*.py
${
PADDLE_BINARY_DIR
}
/python/paddle/distributed/fleet/proto
COMMENT
"Copy generated python proto into directory paddle/distributed/fleet/proto."
)
endif
(
WITH_PSCORE
)
cc_library
(
lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor
)
cc_library
(
feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog
)
...
...
paddle/fluid/pybind/CMakeLists.txt
浏览文件 @
a8c3a902
...
...
@@ -76,7 +76,7 @@ endif (WITH_CRYPTO)
if
(
WITH_PSCORE
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=type-limits -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result"
)
set_source_files_properties
(
fleet_py.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
list
(
APPEND PYBIND_DEPS fleet communicator
)
list
(
APPEND PYBIND_DEPS fleet communicator
index_wrapper index_sampler
)
list
(
APPEND PYBIND_SRCS fleet_py.cc
)
endif
()
...
...
paddle/fluid/pybind/fleet_py.cc
浏览文件 @
a8c3a902
...
...
@@ -30,6 +30,8 @@ limitations under the License. */
#include "paddle/fluid/distributed/communicator_common.h"
#include "paddle/fluid/distributed/fleet.h"
#include "paddle/fluid/distributed/index_dataset/index_sampler.h"
#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
#include "paddle/fluid/distributed/service/communicator.h"
#include "paddle/fluid/distributed/service/env.h"
#include "paddle/fluid/distributed/service/graph_brpc_client.h"
...
...
@@ -212,5 +214,76 @@ void BindGraphPyClient(py::module* m) {
.
def
(
"bind_local_server"
,
&
GraphPyClient
::
bind_local_server
);
}
using
paddle
::
distributed
::
TreeIndex
;
using
paddle
::
distributed
::
IndexWrapper
;
using
paddle
::
distributed
::
IndexNode
;
void
BindIndexNode
(
py
::
module
*
m
)
{
py
::
class_
<
IndexNode
>
(
*
m
,
"IndexNode"
)
.
def
(
py
::
init
<>
())
.
def
(
"id"
,
[](
IndexNode
&
self
)
{
return
self
.
id
();
})
.
def
(
"is_leaf"
,
[](
IndexNode
&
self
)
{
return
self
.
is_leaf
();
})
.
def
(
"probability"
,
[](
IndexNode
&
self
)
{
return
self
.
probability
();
});
}
void
BindTreeIndex
(
py
::
module
*
m
)
{
py
::
class_
<
TreeIndex
,
std
::
shared_ptr
<
TreeIndex
>>
(
*
m
,
"TreeIndex"
)
.
def
(
py
::
init
([](
const
std
::
string
name
,
const
std
::
string
path
)
{
auto
index_wrapper
=
IndexWrapper
::
GetInstancePtr
();
index_wrapper
->
insert_tree_index
(
name
,
path
);
return
index_wrapper
->
get_tree_index
(
name
);
}))
.
def
(
"height"
,
[](
TreeIndex
&
self
)
{
return
self
.
Height
();
})
.
def
(
"branch"
,
[](
TreeIndex
&
self
)
{
return
self
.
Branch
();
})
.
def
(
"total_node_nums"
,
[](
TreeIndex
&
self
)
{
return
self
.
TotalNodeNums
();
})
.
def
(
"emb_size"
,
[](
TreeIndex
&
self
)
{
return
self
.
EmbSize
();
})
.
def
(
"get_all_leafs"
,
[](
TreeIndex
&
self
)
{
return
self
.
GetAllLeafs
();
})
.
def
(
"get_nodes"
,
[](
TreeIndex
&
self
,
const
std
::
vector
<
uint64_t
>&
codes
)
{
return
self
.
GetNodes
(
codes
);
})
.
def
(
"get_layer_codes"
,
[](
TreeIndex
&
self
,
int
level
)
{
return
self
.
GetLayerCodes
(
level
);
})
.
def
(
"get_ancestor_codes"
,
[](
TreeIndex
&
self
,
const
std
::
vector
<
uint64_t
>&
ids
,
int
level
)
{
return
self
.
GetAncestorCodes
(
ids
,
level
);
})
.
def
(
"get_children_codes"
,
[](
TreeIndex
&
self
,
uint64_t
ancestor
,
int
level
)
{
return
self
.
GetChildrenCodes
(
ancestor
,
level
);
})
.
def
(
"get_travel_codes"
,
[](
TreeIndex
&
self
,
uint64_t
id
,
int
start_level
)
{
return
self
.
GetTravelCodes
(
id
,
start_level
);
});
}
void
BindIndexWrapper
(
py
::
module
*
m
)
{
py
::
class_
<
IndexWrapper
,
std
::
shared_ptr
<
IndexWrapper
>>
(
*
m
,
"IndexWrapper"
)
.
def
(
py
::
init
([]()
{
return
IndexWrapper
::
GetInstancePtr
();
}))
.
def
(
"insert_tree_index"
,
&
IndexWrapper
::
insert_tree_index
)
.
def
(
"get_tree_index"
,
&
IndexWrapper
::
get_tree_index
)
.
def
(
"clear_tree"
,
&
IndexWrapper
::
clear_tree
);
}
using
paddle
::
distributed
::
IndexSampler
;
using
paddle
::
distributed
::
LayerWiseSampler
;
void
BindIndexSampler
(
py
::
module
*
m
)
{
py
::
class_
<
IndexSampler
,
std
::
shared_ptr
<
IndexSampler
>>
(
*
m
,
"IndexSampler"
)
.
def
(
py
::
init
([](
const
std
::
string
&
mode
,
const
std
::
string
&
name
)
{
if
(
mode
==
"by_layerwise"
)
{
return
IndexSampler
::
Init
<
LayerWiseSampler
>
(
name
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unsupported IndexSampler Type!"
));
}
}))
.
def
(
"init_layerwise_conf"
,
&
IndexSampler
::
init_layerwise_conf
)
.
def
(
"init_beamsearch_conf"
,
&
IndexSampler
::
init_beamsearch_conf
)
.
def
(
"sample"
,
&
IndexSampler
::
sample
);
}
}
// end namespace pybind
}
// namespace paddle
paddle/fluid/pybind/fleet_py.h
浏览文件 @
a8c3a902
...
...
@@ -32,5 +32,9 @@ void BindGraphPyService(py::module* m);
void
BindGraphPyFeatureNode
(
py
::
module
*
m
);
void
BindGraphPyServer
(
py
::
module
*
m
);
void
BindGraphPyClient
(
py
::
module
*
m
);
void
BindIndexNode
(
py
::
module
*
m
);
void
BindTreeIndex
(
py
::
module
*
m
);
void
BindIndexWrapper
(
py
::
module
*
m
);
void
BindIndexSampler
(
py
::
module
*
m
);
}
// namespace pybind
}
// namespace paddle
paddle/fluid/pybind/pybind.cc
浏览文件 @
a8c3a902
...
...
@@ -3092,6 +3092,11 @@ All parameter, weight, gradient are variables in Paddle.
BindGraphPyService
(
&
m
);
BindGraphPyServer
(
&
m
);
BindGraphPyClient
(
&
m
);
BindIndexNode
(
&
m
);
BindTreeIndex
(
&
m
);
BindIndexWrapper
(
&
m
);
BindIndexSampler
(
&
m
);
#endif
}
}
// namespace pybind
...
...
python/paddle/distributed/fleet/dataset/__init__.py
浏览文件 @
a8c3a902
...
...
@@ -12,3 +12,4 @@
# See the License for the specific language governing permissions and
from
.dataset
import
*
from
.index_dataset
import
*
python/paddle/distributed/fleet/dataset/index_dataset.py
0 → 100644
浏览文件 @
a8c3a902
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
paddle.fluid
import
core
class
Index
(
object
):
def
__init__
(
self
,
name
):
self
.
_name
=
name
class
TreeIndex
(
Index
):
def
__init__
(
self
,
name
,
path
):
super
(
TreeIndex
,
self
).
__init__
(
name
)
self
.
_wrapper
=
core
.
IndexWrapper
()
self
.
_wrapper
.
insert_tree_index
(
name
,
path
)
self
.
_tree
=
self
.
_wrapper
.
get_tree_index
(
name
)
self
.
_height
=
self
.
_tree
.
height
()
self
.
_branch
=
self
.
_tree
.
branch
()
self
.
_total_node_nums
=
self
.
_tree
.
total_node_nums
()
self
.
_emb_size
=
self
.
_tree
.
emb_size
()
self
.
_layerwise_sampler
=
None
def
height
(
self
):
return
self
.
_height
def
branch
(
self
):
return
self
.
_branch
def
total_node_nums
(
self
):
return
self
.
_total_node_nums
def
emb_size
(
self
):
return
self
.
_emb_size
def
get_all_leafs
(
self
):
return
self
.
_tree
.
get_all_leafs
()
def
get_nodes
(
self
,
codes
):
return
self
.
_tree
.
get_nodes
(
codes
)
def
get_layer_codes
(
self
,
level
):
return
self
.
_tree
.
get_layer_codes
(
level
)
def
get_travel_codes
(
self
,
id
,
start_level
=
0
):
return
self
.
_tree
.
get_travel_codes
(
id
,
start_level
)
def
get_ancestor_codes
(
self
,
ids
,
level
):
return
self
.
_tree
.
get_ancestor_codes
(
ids
,
level
)
def
get_children_codes
(
self
,
ancestor
,
level
):
return
self
.
_tree
.
get_children_codes
(
ancestor
,
level
)
def
get_travel_path
(
self
,
child
,
ancestor
):
res
=
[]
while
(
child
>
ancestor
):
res
.
append
(
child
)
child
=
int
((
child
-
1
)
/
self
.
_branch
)
return
res
def
get_pi_relation
(
self
,
ids
,
level
):
codes
=
self
.
get_ancestor_codes
(
ids
,
level
)
return
dict
(
zip
(
ids
,
codes
))
def
init_layerwise_sampler
(
self
,
layer_sample_counts
,
start_sample_layer
=
1
,
seed
=
0
):
assert
self
.
_layerwise_sampler
is
None
self
.
_layerwise_sampler
=
core
.
IndexSampler
(
"by_layerwise"
,
self
.
_name
)
self
.
_layerwise_sampler
.
init_layerwise_conf
(
layer_sample_counts
,
start_sample_layer
,
seed
)
def
layerwise_sample
(
self
,
user_input
,
index_input
,
with_hierarchy
=
False
):
if
self
.
_layerwise_sampler
is
None
:
raise
ValueError
(
"please init layerwise_sampler first."
)
return
self
.
_layerwise_sampler
.
sample
(
user_input
,
index_input
,
with_hierarchy
)
python/paddle/fluid/tests/unittests/test_dist_tree_index.py
0 → 100644
浏览文件 @
a8c3a902
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
from
paddle.dataset.common
import
download
,
DATA_HOME
from
paddle.distributed.fleet.dataset
import
TreeIndex
class
TestTreeIndex
(
unittest
.
TestCase
):
def
test_tree_index
(
self
):
path
=
download
(
"https://paddlerec.bj.bcebos.com/tree-based/data/demo_tree.pb"
,
"tree_index_unittest"
,
"cadec20089f5a8a44d320e117d9f9f1a"
)
tree
=
TreeIndex
(
"demo"
,
path
)
height
=
tree
.
height
()
branch
=
tree
.
branch
()
self
.
assertTrue
(
height
==
14
)
self
.
assertTrue
(
branch
==
2
)
self
.
assertEqual
(
tree
.
total_node_nums
(),
15581
)
self
.
assertEqual
(
tree
.
emb_size
(),
5171136
)
# get_layer_codes
layer_node_ids
=
[]
layer_node_codes
=
[]
for
i
in
range
(
tree
.
height
()):
layer_node_codes
.
append
(
tree
.
get_layer_codes
(
i
))
layer_node_ids
.
append
(
[
node
.
id
()
for
node
in
tree
.
get_nodes
(
layer_node_codes
[
-
1
])])
all_leaf_ids
=
[
node
.
id
()
for
node
in
tree
.
get_all_leafs
()]
self
.
assertEqual
(
sum
(
all_leaf_ids
),
sum
(
layer_node_ids
[
-
1
]))
# get_travel
travel_codes
=
tree
.
get_travel_codes
(
all_leaf_ids
[
0
])
travel_ids
=
[
node
.
id
()
for
node
in
tree
.
get_nodes
(
travel_codes
)]
for
i
in
range
(
height
):
self
.
assertIn
(
travel_ids
[
i
],
layer_node_ids
[
height
-
1
-
i
])
self
.
assertIn
(
travel_codes
[
i
],
layer_node_codes
[
height
-
1
-
i
])
# get_ancestor
ancestor_codes
=
tree
.
get_ancestor_codes
([
all_leaf_ids
[
0
]],
height
-
2
)
ancestor_ids
=
[
node
.
id
()
for
node
in
tree
.
get_nodes
(
ancestor_codes
)]
self
.
assertEqual
(
ancestor_ids
[
0
],
travel_ids
[
1
])
self
.
assertEqual
(
ancestor_codes
[
0
],
travel_codes
[
1
])
# get_pi_relation
pi_relation
=
tree
.
get_pi_relation
([
all_leaf_ids
[
0
]],
height
-
2
)
self
.
assertEqual
(
pi_relation
[
all_leaf_ids
[
0
]],
ancestor_codes
[
0
])
# get_travel_path
travel_path_codes
=
tree
.
get_travel_path
(
travel_codes
[
0
],
travel_codes
[
-
1
])
travel_path_ids
=
[
node
.
id
()
for
node
in
tree
.
get_nodes
(
travel_path_codes
)
]
self
.
assertEquals
(
travel_path_ids
+
[
travel_ids
[
-
1
]],
travel_ids
)
self
.
assertEquals
(
travel_path_codes
+
[
travel_codes
[
-
1
]],
travel_codes
)
# get_children
children_codes
=
tree
.
get_children_codes
(
travel_codes
[
1
],
height
-
1
)
children_ids
=
[
node
.
id
()
for
node
in
tree
.
get_nodes
(
children_codes
)]
self
.
assertIn
(
all_leaf_ids
[
0
],
children_ids
)
class
TestIndexSampler
(
unittest
.
TestCase
):
def
test_layerwise_sampler
(
self
):
path
=
download
(
"https://paddlerec.bj.bcebos.com/tree-based/data/demo_tree.pb"
,
"tree_index_unittest"
,
"cadec20089f5a8a44d320e117d9f9f1a"
)
tree
=
TreeIndex
(
"demo"
,
path
)
layer_nodes
=
[]
for
i
in
range
(
tree
.
height
()):
layer_codes
=
tree
.
get_layer_codes
(
i
)
layer_nodes
.
append
(
[
node
.
id
()
for
node
in
tree
.
get_nodes
(
layer_codes
)])
sample_num
=
range
(
1
,
10000
)
start_sample_layer
=
1
seed
=
0
sample_layers
=
tree
.
height
()
-
start_sample_layer
sample_num
=
sample_num
[:
sample_layers
]
layer_sample_counts
=
list
(
sample_num
)
+
[
1
]
*
(
sample_layers
-
len
(
sample_num
))
total_sample_num
=
sum
(
layer_sample_counts
)
+
len
(
layer_sample_counts
)
tree
.
init_layerwise_sampler
(
sample_num
,
start_sample_layer
,
seed
)
ids
=
[
315757
,
838060
,
1251533
,
403522
,
2473624
,
3321007
]
parent_path
=
{}
for
i
in
range
(
len
(
ids
)):
tmp
=
tree
.
get_travel_codes
(
ids
[
i
],
start_sample_layer
)
parent_path
[
ids
[
i
]]
=
[
node
.
id
()
for
node
in
tree
.
get_nodes
(
tmp
)]
# check sample res with_hierarchy = False
sample_res
=
tree
.
layerwise_sample
(
[[
315757
,
838060
],
[
1251533
,
403522
]],
[
2473624
,
3321007
],
False
)
idx
=
0
layer
=
tree
.
height
()
-
1
for
i
in
range
(
len
(
layer_sample_counts
)):
for
j
in
range
(
layer_sample_counts
[
0
-
(
i
+
1
)]
+
1
):
self
.
assertTrue
(
sample_res
[
idx
+
j
][
0
]
==
315757
)
self
.
assertTrue
(
sample_res
[
idx
+
j
][
1
]
==
838060
)
self
.
assertTrue
(
sample_res
[
idx
+
j
][
2
]
in
layer_nodes
[
layer
])
if
j
==
0
:
self
.
assertTrue
(
sample_res
[
idx
+
j
][
3
]
==
1
)
self
.
assertTrue
(
sample_res
[
idx
+
j
][
2
]
==
parent_path
[
2473624
][
i
])
else
:
self
.
assertTrue
(
sample_res
[
idx
+
j
][
3
]
==
0
)
self
.
assertTrue
(
sample_res
[
idx
+
j
][
2
]
!=
parent_path
[
2473624
][
i
])
idx
+=
layer_sample_counts
[
0
-
(
i
+
1
)]
+
1
layer
-=
1
self
.
assertTrue
(
idx
==
total_sample_num
)
layer
=
tree
.
height
()
-
1
for
i
in
range
(
len
(
layer_sample_counts
)):
for
j
in
range
(
layer_sample_counts
[
0
-
(
i
+
1
)]
+
1
):
self
.
assertTrue
(
sample_res
[
idx
+
j
][
0
]
==
1251533
)
self
.
assertTrue
(
sample_res
[
idx
+
j
][
1
]
==
403522
)
self
.
assertTrue
(
sample_res
[
idx
+
j
][
2
]
in
layer_nodes
[
layer
])
if
j
==
0
:
self
.
assertTrue
(
sample_res
[
idx
+
j
][
3
]
==
1
)
self
.
assertTrue
(
sample_res
[
idx
+
j
][
2
]
==
parent_path
[
3321007
][
i
])
else
:
self
.
assertTrue
(
sample_res
[
idx
+
j
][
3
]
==
0
)
self
.
assertTrue
(
sample_res
[
idx
+
j
][
2
]
!=
parent_path
[
3321007
][
i
])
idx
+=
layer_sample_counts
[
0
-
(
i
+
1
)]
+
1
layer
-=
1
self
.
assertTrue
(
idx
==
total_sample_num
*
2
)
# check sample res with_hierarchy = True
sample_res_with_hierarchy
=
tree
.
layerwise_sample
(
[[
315757
,
838060
],
[
1251533
,
403522
]],
[
2473624
,
3321007
],
True
)
idx
=
0
layer
=
tree
.
height
()
-
1
for
i
in
range
(
len
(
layer_sample_counts
)):
for
j
in
range
(
layer_sample_counts
[
0
-
(
i
+
1
)]
+
1
):
self
.
assertTrue
(
sample_res_with_hierarchy
[
idx
+
j
][
0
]
==
parent_path
[
315757
][
i
])
self
.
assertTrue
(
sample_res_with_hierarchy
[
idx
+
j
][
1
]
==
parent_path
[
838060
][
i
])
self
.
assertTrue
(
sample_res_with_hierarchy
[
idx
+
j
][
2
]
in
layer_nodes
[
layer
])
if
j
==
0
:
self
.
assertTrue
(
sample_res_with_hierarchy
[
idx
+
j
][
3
]
==
1
)
self
.
assertTrue
(
sample_res_with_hierarchy
[
idx
+
j
][
2
]
==
parent_path
[
2473624
][
i
])
else
:
self
.
assertTrue
(
sample_res_with_hierarchy
[
idx
+
j
][
3
]
==
0
)
self
.
assertTrue
(
sample_res_with_hierarchy
[
idx
+
j
][
2
]
!=
parent_path
[
2473624
][
i
])
idx
+=
layer_sample_counts
[
0
-
(
i
+
1
)]
+
1
layer
-=
1
self
.
assertTrue
(
idx
==
total_sample_num
)
layer
=
tree
.
height
()
-
1
for
i
in
range
(
len
(
layer_sample_counts
)):
for
j
in
range
(
layer_sample_counts
[
0
-
(
i
+
1
)]
+
1
):
self
.
assertTrue
(
sample_res_with_hierarchy
[
idx
+
j
][
0
]
==
parent_path
[
1251533
][
i
])
self
.
assertTrue
(
sample_res_with_hierarchy
[
idx
+
j
][
1
]
==
parent_path
[
403522
][
i
])
self
.
assertTrue
(
sample_res_with_hierarchy
[
idx
+
j
][
2
]
in
layer_nodes
[
layer
])
if
j
==
0
:
self
.
assertTrue
(
sample_res_with_hierarchy
[
idx
+
j
][
3
]
==
1
)
self
.
assertTrue
(
sample_res_with_hierarchy
[
idx
+
j
][
2
]
==
parent_path
[
3321007
][
i
])
else
:
self
.
assertTrue
(
sample_res_with_hierarchy
[
idx
+
j
][
3
]
==
0
)
self
.
assertTrue
(
sample_res_with_hierarchy
[
idx
+
j
][
2
]
!=
parent_path
[
3321007
][
i
])
idx
+=
layer_sample_counts
[
0
-
(
i
+
1
)]
+
1
layer
-=
1
self
.
assertTrue
(
idx
==
2
*
total_sample_num
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录