Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ae3bb16d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ae3bb16d
编写于
3月 26, 2020
作者:
D
danleifeng
提交者:
GitHub
3月 26, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add MaskAucCalculator in paddlebox (#23157)
* add maskauc in paddlebox; test=develop
上级
6af480ca
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
40 addition
and
3 deletion
+40
-3
paddle/fluid/framework/fleet/box_wrapper.h
paddle/fluid/framework/fleet/box_wrapper.h
+40
-3
未找到文件。
paddle/fluid/framework/fleet/box_wrapper.h
浏览文件 @
ae3bb16d
...
...
@@ -413,6 +413,38 @@ class BoxWrapper {
std
::
vector
<
std
::
pair
<
int
,
int
>>
cmatch_rank_v
;
std
::
string
cmatch_rank_varname_
;
};
class
MaskMetricMsg
:
public
MetricMsg
{
public:
MaskMetricMsg
(
const
std
::
string
&
label_varname
,
const
std
::
string
&
pred_varname
,
int
is_join
,
const
std
::
string
&
mask_varname
,
int
bucket_size
=
1000000
)
{
label_varname_
=
label_varname
;
pred_varname_
=
pred_varname
;
mask_varname_
=
mask_varname
;
is_join_
=
is_join
;
calculator
=
new
BasicAucCalculator
();
calculator
->
init
(
bucket_size
);
}
virtual
~
MaskMetricMsg
()
{}
void
add_data
(
const
Scope
*
exe_scope
)
override
{
std
::
vector
<
int64_t
>
label_data
;
get_data
<
int64_t
>
(
exe_scope
,
label_varname_
,
&
label_data
);
std
::
vector
<
float
>
pred_data
;
get_data
<
float
>
(
exe_scope
,
pred_varname_
,
&
pred_data
);
std
::
vector
<
int64_t
>
mask_data
;
get_data
<
int64_t
>
(
exe_scope
,
mask_varname_
,
&
mask_data
);
auto
cal
=
GetCalculator
();
auto
batch_size
=
label_data
.
size
();
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
if
(
mask_data
[
i
]
==
1
)
{
cal
->
add_data
(
pred_data
[
i
],
label_data
[
i
]);
}
}
}
protected:
std
::
string
mask_varname_
;
};
const
std
::
vector
<
std
::
string
>&
GetMetricNameList
()
const
{
return
metric_name_list_
;
}
...
...
@@ -423,7 +455,8 @@ class BoxWrapper {
void
InitMetric
(
const
std
::
string
&
method
,
const
std
::
string
&
name
,
const
std
::
string
&
label_varname
,
const
std
::
string
&
pred_varname
,
const
std
::
string
&
cmatch_rank_varname
,
bool
is_join
,
const
std
::
string
&
cmatch_rank_varname
,
const
std
::
string
&
mask_varname
,
bool
is_join
,
const
std
::
string
&
cmatch_rank_group
,
int
bucket_size
=
1000000
)
{
if
(
method
==
"AucCalculator"
)
{
...
...
@@ -439,10 +472,14 @@ class BoxWrapper {
name
,
new
CmatchRankMetricMsg
(
label_varname
,
pred_varname
,
is_join
?
1
:
0
,
cmatch_rank_group
,
cmatch_rank_varname
,
bucket_size
));
}
else
if
(
method
==
"MaskAucCalculator"
)
{
metric_lists_
.
emplace
(
name
,
new
MaskMetricMsg
(
label_varname
,
pred_varname
,
is_join
?
1
:
0
,
mask_varname
,
bucket_size
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"PaddleBox only support AucCalculator, MultiTaskAucCalculator
and
"
"CmatchRankAucCalculator"
));
"PaddleBox only support AucCalculator, MultiTaskAucCalculator "
"CmatchRankAucCalculator
and MaskAucCalculator
"
));
}
metric_name_list_
.
emplace_back
(
name
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录