Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ef24bd78
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ef24bd78
编写于
9月 08, 2020
作者:
M
malin10
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add tdm_tree
上级
ecc59dcd
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
335 addition
and
0 deletion
+335
-0
paddle/fluid/framework/fleet/tree_wrapper.cc
paddle/fluid/framework/fleet/tree_wrapper.cc
+195
-0
paddle/fluid/framework/fleet/tree_wrapper.h
paddle/fluid/framework/fleet/tree_wrapper.h
+140
-0
未找到文件。
paddle/fluid/framework/fleet/tree_wrapper.cc
0 → 100644
浏览文件 @
ef24bd78
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
namespace
paddle
{
namespace
framework
{
int
Tree
::
load
(
std
::
string
path
,
std
::
string
tree_pipe_command_
)
{
uint64_t
linenum
=
0
;
size_t
idx
=
0
;
std
::
vector
<
std
::
string
>
lines
;
std
::
vector
<
std
::
string
>
strs
;
std
::
vector
<
std
::
string
>
items
;
int
err_no
;
std
::
shared_ptr
<
FILE
>
fp_
=
fs_open_read
(
path
,
&
err_no
,
tree_pipe_command_
);
string
::
LineFileReader
reader
;
while
(
reader
.
getline
(
&*
(
fp_
.
get
())))
{
line
=
std
::
string
(
reader
.
get
());
strs
.
clear
();
boost
::
split
(
strs
,
line
,
boost
::
is_any_of
(
"
\t
"
));
if
(
0
==
linenum
)
{
_total_node_num
=
boost
::
lexical_cast
<
size_t
>
(
strs
[
0
]);
_nodes
=
new
Node
[
_total_node_num
];
if
(
strs
.
size
()
>
1
)
{
_tree_height
=
boost
::
lexical_cast
<
int16_t
>
(
strs
[
1
]);
}
++
linenum
;
continue
;
}
if
(
strs
.
size
()
<
4
)
{
LOG
(
WARNING
)
<<
"each line must has more than field"
;
return
-
1
;
}
Node
&
node
=
_nodes
[
idx
];
// id
node
.
id
=
boost
::
lexical_cast
<
uint64_t
>
(
strs
[
0
]);
// embedding
items
.
clear
();
if
(
!
strs
[
1
].
empty
())
{
boost
::
split
(
items
,
strs
[
1
],
boost
::
is_any_of
(
" "
));
for
(
size_t
i
=
0
;
i
!=
items
.
size
();
++
i
)
{
node
.
embedding
.
emplace_back
(
boost
::
lexical_cast
<
float
>
(
items
[
i
]));
}
}
// parent
items
.
clear
();
if
(
!
strs
[
2
].
empty
())
{
node
.
parent_node
=
_nodes
+
boost
::
lexical_cast
<
int
>
(
strs
[
2
]);
}
// child
items
.
clear
();
if
(
!
strs
[
3
].
empty
())
{
boost
::
split
(
items
,
strs
[
3
],
boost
::
is_any_of
(
" "
));
// node.sub_nodes = new Node*[items.size()];
for
(
size_t
i
=
0
;
i
!=
items
.
size
();
++
i
)
{
node
.
sub_nodes
.
push_back
(
_nodes
+
boost
::
lexical_cast
<
int
>
(
items
[
i
]));
// node.sub_nodes[i] = _nodes + boost::lexical_cast<int>(items[i]);
}
// node.sub_node_num = items.size();
}
else
{
//没有孩子节点,当前节点是叶节点
_leaf_node_map
[
node
.
id
]
=
&
node
;
// node.sub_node_num = 0;
}
if
(
strs
.
size
()
>
4
)
{
node
.
height
=
boost
::
lexical_cast
<
int16_t
>
(
strs
[
4
]);
}
++
idx
;
++
linenum
;
}
_head
=
_nodes
+
_total_node_num
-
1
;
LOG
(
INFO
)
<<
"all lines:"
<<
linenum
<<
", all tree nodes:"
<<
idx
;
return
0
;
}
void
Tree
::
print_tree
()
{
/*
std::queue<Node*> q;
if (_head) {
q.push(_head);
}
while (!q.empty()) {
const Node* node = q.front();
q.pop();
std::cout << "node_id: " << node->id << std::endl;
std::cout << "node_embedding: ";
for (int i = 0; i != node->embedding.size(); ++i) {
std::cout << node->embedding[i] << " ";
}
std::cout << std::endl;
if (node->parent_node) {
std::cout << "parent_idx: " << node->parent_node - _nodes <<
std::endl;
}
if (node->sub_node_num > 0) {
for (int i = 0; i != node->sub_node_num; ++i) {
std::cout << "child_idx" << i << ": " << node->sub_nodes[i] - _nodes
<< std::endl;
}
}
std::cout << "-------------------------------------" << std::endl;
for (int i = 0; i != node->sub_node_num; ++i) {
Node* tmp_node = node->sub_nodes[i];
q.push(tmp_node);
}
}
*/
}
int
Tree
::
dump_tree
(
const
uint64_t
table_id
,
int
fea_value_dim
,
const
std
::
string
tree_path
)
{
int
ret
;
std
::
shared_ptr
<
FILE
>
fp
=
paddle
::
framework
::
fs_open
(
tree_path
,
"w"
,
&
ret
,
""
);
std
::
vector
<
uint64_t
>
fea_keys
,
std
::
vector
<
float
*>
pull_result_ptr
;
fea_keys
.
reserve
(
_total_node_num
);
pull_result_ptr
.
reserve
(
_total_node_num
);
for
(
size_t
i
=
0
;
i
!=
_total_node_num
;
++
i
)
{
_nodes
[
i
].
embedding
.
resize
(
fea_value_dim
);
fea_key
.
push_back
(
_nodes
[
i
].
id
);
pull_result_ptr
.
push_back
(
_nodes
[
i
].
embedding
.
data
());
}
std
::
string
first_line
=
boost
::
lexical_cast
<
std
::
string
>
(
_total_node_num
)
+
"
\t
"
+
boost
::
lexical_cast
<
std
::
string
>
(
_tree_height
);
fwrite
(
first_line
.
c_str
(),
first_line
.
length
(),
1
,
&*
fp
);
std
::
string
line_break_str
(
"
\n
"
);
std
::
string
line
(
""
);
for
(
size_t
i
=
0
;
i
!=
_total_node_num
;
++
i
)
{
line
=
line_break_str
;
const
Node
&
node
=
_nodes
[
i
];
line
+=
boost
::
lexical_cast
<
std
::
string
>
(
node
.
id
)
+
"
\t
"
;
if
(
!
node
.
embedding
.
empty
())
{
for
(
size_t
j
=
0
;
j
!=
node
.
embedding
.
size
()
-
1
;
++
j
)
{
line
+=
boost
::
lexical_cast
<
std
::
string
>
(
node
.
embedding
[
j
])
+
" "
;
}
line
+=
boost
::
lexical_cast
<
std
::
string
>
(
node
.
embedding
[
node
.
embedding
.
size
()
-
1
]);
}
else
{
LOG
(
WARNING
)
<<
"node_idx["
<<
i
<<
"], id["
<<
node
.
id
<<
"] "
<<
"has no embeddings"
;
}
line
+=
"
\t
"
;
if
(
node
.
parent_node
)
{
line
+=
boost
::
lexical_cast
<
std
::
string
>
(
node
.
parent_node
-
_nodes
);
}
line
+=
"
\t
"
;
if
(
node
.
sub_nodes
.
size
()
>
0
)
{
for
(
uint32_t
j
=
0
;
j
<
node
.
sub_nodes
.
size
()
-
1
;
++
j
)
{
line
+=
boost
::
lexical_cast
<
std
::
string
>
(
node
.
sub_nodes
[
j
]
-
_nodes
)
+
" "
;
}
line
+=
boost
::
lexical_cast
<
std
::
string
>
(
node
.
sub_nodes
[
node
.
sub_nodes
.
size
()
-
1
]
-
_nodes
);
}
line
+=
"
\t
"
+
boost
::
lexical_cast
<
std
::
string
>
(
node
.
height
);
fwrite
(
line
.
c_str
(),
line
.
length
(),
1
,
&*
fp
);
}
return
0
;
}
bool
Tree
::
trace_back
(
uint64_t
id
,
std
::
vector
<
std
::
pair
<
uint64_t
,
uint32_t
>>&
ids
)
{
ids
.
clear
();
std
::
unordered_map
<
uint64_t
,
Node
*>::
iterator
find_it
=
_leaf_node_map
.
find
(
id
);
if
(
find_it
==
_leaf_node_map
.
end
())
{
return
false
;
}
else
{
uint32_t
height
=
0
;
Node
*
node
=
find_it
->
second
;
while
(
node
!=
NULL
)
{
height
++
;
ids
.
emplace_back
(
node
->
id
,
0
);
node
=
node
->
parent_node
;
}
for
(
auto
&
id
:
ids
)
{
id
.
second
=
height
--
;
}
}
return
true
;
}
Node
*
Tree
::
get_node
()
{
return
_nodes
;
}
size_t
Tree
::
get_total_node_num
()
{
return
_total_node_num
;
}
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/fleet/tree_wrapper.h
0 → 100644
浏览文件 @
ef24bd78
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
namespace
paddle
{
namespace
framework
{
struct
Node
{
Node
::
Node
()
:
parent_node
(
NULL
),
id
(
0
),
height
(
0
)
{}
~
Node
(){};
std
::
vector
<
Node
*>
sub_nodes
;
// uint32_t sub_node_num;
Node
*
parent_node
;
uint64_t
id
;
std
::
vector
<
float
>
embedding
;
int16_t
height
;
//层级
};
class
Tree
{
public:
Tree
()
:
_nodes
(
NULL
),
_head
(
NULL
)
{}
~
Tree
()
{
if
(
_nodes
)
{
delete
[]
_nodes
;
_nodes
=
NULL
;
}
}
void
print_tree
();
int
dump_tree
(
const
uint64_t
table_id
,
int
fea_value_dim
,
const
std
::
string
tree_path
);
//采样:从叶节点回溯到根节点
void
trace_back
(
uint64_t
id
,
std
::
vector
<
std
::
pair
<
uint64_t
,
uint32_t
>>&
ids
);
int
load
(
std
::
string
path
);
Node
*
get_node
();
size_t
get_total_node_num
();
private:
// tree data info
Node
*
_nodes
{
nullptr
};
// head pointer
Node
*
_head
{
nullptr
};
// total number of nodes
size_t
_total_node_num
{
0
};
// leaf node map
std
::
unordered_map
<
uint64_t
,
Node
*>
_leaf_node_map
;
// version
std
::
string
_version
{
""
};
//树的高度
int16_t
_tree_height
{
0
};
};
using
TreePtr
=
std
::
shared_ptr
<
Tree
>
;
class
TreeWrapper
{
public:
virtual
~
TreeWrapper
()
{}
TreeWrapper
()
{}
// TreeWrapper singleton
static
std
::
shared_ptr
<
TreeWrapper
>
GetInstance
()
{
if
(
NULL
==
s_instance_
)
{
s_instance_
.
reset
(
new
paddle
::
framework
::
TreeWrapper
());
}
return
s_instance_
;
}
void
clear
()
{
tree_map
.
clear
();
}
void
insert
(
std
::
string
name
,
std
::
string
tree_path
)
{
if
(
tree_map
.
find
(
name
)
!=
tree_map
.
end
())
{
return
;
}
TreePtr
tree
=
new
Tree
();
tree
.
load
(
tree_path
);
tree_map
.
insert
(
std
::
pair
<
std
::
string
,
TreePtr
>
{
name
,
tree
});
}
void
dump
(
std
::
string
name
,
const
uint64_t
table_id
,
int
fea_value_dim
,
const
std
::
string
tree_path
)
{
if
(
tree_map
.
find
(
name
)
==
tree_map
.
end
())
{
return
;
}
tree_map
.
at
(
name
)
->
dump_tree
(
table_id
,
fea_value_dim
,
tree_path
);
}
void
sample
(
const
uint16_t
sample_slot
,
const
uint64_t
type_slot
,
std
::
vector
<
Record
>&
src_datas
,
std
::
vector
<
Record
>&
sample_results
)
{
sample_results
.
clear
();
for
(
auto
&
data
:
src_datas
)
{
uint64_t
sample_feasign_idx
=
-
1
,
type_feasign_idx
=
-
1
;
for
(
auto
i
=
0
;
i
<
data
.
uint64_feasigns_
.
size
();
i
++
)
{
if
(
data
.
uint64_feasigns_
[
i
].
slot
()
==
sample_slot
)
{
sample_feasign_idx
=
i
;
}
if
(
data
.
uint64_feasigns_
.
slot
()
==
type_slot
)
{
type_feasign_idx
=
i
;
}
}
if
(
sample_feasign_idx
>
0
)
{
std
::
vector
<
std
::
pair
<
uint64_t
,
uint32_t
>>
trace_ids
;
for
(
auto
name
:
tree_map
)
{
bool
in_tree
=
tree_map
.
at
(
name
)
->
trace_back
(
data
.
uint64_feasigns_
[
sample_feasign_idx
].
sign
().
uint64_feasign_
,
trace_ids
);
if
(
in_tree
)
{
break
;
}
else
{
PADDLE_ENFORCE_EQ
(
trace_ids
.
size
(),
0
,
""
);
}
}
for
(
auto
i
=
0
;
i
<
trace_ids
.
size
();
i
++
)
{
Record
instance
(
data
);
instance
.
uint64_feasigns_
[
sample_feasign_idx
].
sign
().
uint64_feasign_
=
trace_ids
[
i
].
first
;
if
(
type_feasign_idx
>
0
)
instance
.
uint64_feasigns_
[
type_feasign_idx
]
.
sign
()
.
uint64_feasign_
+=
trace_ids
[
i
].
second
*
100
;
sample_results
.
push_back
(
instance
);
}
}
}
return
;
}
public:
std
::
unordered_map
<
std
::
string
,
TreePtr
>
tree_map
;
private:
static
std
::
shared_ptr
<
TreeWrapper
>
s_instance_
;
};
}
// end namespace framework
}
// end namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录