Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
b6f61faf
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
b6f61faf
编写于
9月 18, 2018
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix adam
上级
2d898491
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
68 addition
and
25 deletion
+68
-25
paddle/fluid/operators/adam_op.h
paddle/fluid/operators/adam_op.h
+49
-23
paddle/fluid/operators/math/selected_rows_functor.cc
paddle/fluid/operators/math/selected_rows_functor.cc
+8
-1
paddle/fluid/operators/math/selected_rows_functor.cu
paddle/fluid/operators/math/selected_rows_functor.cu
+8
-1
paddle/fluid/operators/math/selected_rows_functor.h
paddle/fluid/operators/math/selected_rows_functor.h
+3
-0
未找到文件。
paddle/fluid/operators/adam_op.h
浏览文件 @
b6f61faf
...
...
@@ -174,12 +174,13 @@ struct SparseAdamFunctor {
const
int64_t
*
rows_
;
int64_t
row_numel_
;
int64_t
row_count_
;
SparseAdamFunctor
(
T
beta1
,
T
beta2
,
T
epsilon
,
const
T
*
beta1_pow
,
const
T
*
beta2_pow
,
const
T
*
mom1
,
T
*
mom1_out
,
const
T
*
mom2
,
T
*
mom2_out
,
const
T
*
lr
,
const
T
*
grad
,
const
T
*
param
,
T
*
param_out
,
const
int64_t
*
rows
,
int64_t
row_numel
)
int64_t
row_numel
,
int64_t
row_count
)
:
beta1_
(
beta1
),
beta2_
(
beta2
),
epsilon_
(
epsilon
),
...
...
@@ -194,28 +195,47 @@ struct SparseAdamFunctor {
param_
(
param
),
param_out_
(
param_out
),
rows_
(
rows
),
row_numel_
(
row_numel
)
{}
row_numel_
(
row_numel
),
row_count_
(
row_count
)
{}
inline
HOSTDEVICE
int64_t
BinarySearchInRows
(
int64_t
row
)
const
{
int64_t
beg
=
0
,
end
=
row_count_
-
1
;
while
(
beg
<=
end
)
{
auto
mid
=
((
beg
+
end
)
>>
1
);
if
(
rows_
[
mid
]
==
row
)
return
mid
;
else
if
(
rows_
[
mid
]
<
row
)
beg
=
mid
+
1
;
else
end
=
mid
-
1
;
}
return
-
1
;
}
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
const
{
int64_t
row
=
i
/
row_numel_
;
auto
row_idx
=
BinarySearchInRows
(
row
);
T
g
=
row_idx
>=
0
?
grad_
[
row_idx
*
row_numel_
+
i
%
row_numel_
]
:
0
;
// The following code is the same as dense
T
mom1
=
moment1_
[
i
];
T
mom2
=
moment2_
[
i
];
T
lr
=
*
lr_
;
T
beta1_pow
=
*
beta1_pow_
;
T
beta2_pow
=
*
beta2_pow_
;
for
(
int64_t
j
=
0
;
j
<
row_numel_
;
++
j
)
{
T
g
=
grad_
[
i
*
row_numel_
+
j
];
T
mom1
=
moment1_
[
rows_
[
i
]
*
row_numel_
+
j
];
T
mom2
=
moment2_
[
rows_
[
i
]
*
row_numel_
+
j
];
T
lr
=
*
lr_
;
T
p
=
param_
[
rows_
[
i
]
*
row_numel_
+
j
];
T
p
=
param_
[
i
];
// Calculation
lr
*=
sqrt
(
1
-
beta2_pow
)
/
(
1
-
beta1_pow
);
mom1
=
beta1_
*
mom1
+
(
1
-
beta1_
)
*
g
;
mom2
=
beta2_
*
mom2
+
(
1
-
beta2_
)
*
g
*
g
;
p
-=
lr
*
(
mom1
/
(
sqrt
(
mom2
)
+
epsilon_
));
moment1_out_
[
rows_
[
i
]
*
row_numel_
+
j
]
=
mom1
;
moment2_out_
[
rows_
[
i
]
*
row_numel_
+
j
]
=
mom2
;
param_out_
[
rows_
[
i
]
*
row_numel_
+
j
]
=
p
;
}
// for col id
// Write back to global memory
moment1_out_
[
i
]
=
mom1
;
moment2_out_
[
i
]
=
mom2
;
param_out_
[
i
]
=
p
;
}
};
...
...
@@ -287,9 +307,14 @@ class AdamOpKernel : public framework::OpKernel<T> {
return
;
}
// merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor
scatter
::
MergeAdd
<
DeviceContext
,
T
>
merge_func
;
auto
grad_merge
=
merge_func
(
ctx
.
template
device_context
<
DeviceContext
>(),
grad
);
auto
&
grad_merge
=
*
(
ctx
.
scope
()
.
NewScope
()
.
Var
(
"sparse_adam_grad_merge"
)
->
GetMutable
<
framework
::
SelectedRows
>
());
merge_func
(
ctx
.
template
device_context
<
DeviceContext
>(),
grad
,
&
grad_merge
);
auto
&
grad_tensor
=
grad_merge
.
value
();
const
T
*
grad_data
=
grad_tensor
.
template
data
<
T
>();
int64_t
*
rows
=
nullptr
;
...
...
@@ -314,10 +339,11 @@ class AdamOpKernel : public framework::OpKernel<T> {
mom2
.
template
data
<
T
>(),
mom2_out
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
lr
.
template
data
<
T
>(),
grad_data
,
param
.
template
data
<
T
>(),
param_out
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
rows
,
row_numel
);
param_out
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
rows
,
row_numel
,
grad_merge
.
rows
().
size
());
platform
::
ForRange
<
DeviceContext
>
for_range
(
static_cast
<
const
DeviceContext
&>
(
ctx
.
device_context
()),
grad_merge
.
rows
().
size
());
param
.
numel
());
for_range
(
functor
);
}
else
{
PADDLE_THROW
(
"Variable type not supported by adam_op"
);
...
...
paddle/fluid/operators/math/selected_rows_functor.cc
浏览文件 @
b6f61faf
...
...
@@ -199,6 +199,14 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
framework
::
SelectedRows
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
)
{
framework
::
SelectedRows
out
;
(
*
this
)(
context
,
input
,
&
out
);
return
out
;
}
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
,
framework
::
SelectedRows
*
output
)
{
framework
::
SelectedRows
&
out
=
*
output
;
auto
input_rows
=
input
.
rows
();
std
::
set
<
int64_t
>
row_set
(
input_rows
.
begin
(),
input_rows
.
end
());
std
::
vector
<
int64_t
>
merge_rows
(
row_set
.
begin
(),
row_set
.
end
());
...
...
@@ -223,7 +231,6 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
out_data
[
out_i
*
input_width
+
j
]
+=
input_data
[
i
*
input_width
+
j
];
}
}
return
out
;
}
};
...
...
paddle/fluid/operators/math/selected_rows_functor.cu
浏览文件 @
b6f61faf
...
...
@@ -262,6 +262,14 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
framework
::
SelectedRows
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
)
{
framework
::
SelectedRows
out
;
(
*
this
)(
context
,
input
,
&
out
);
return
out
;
}
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
,
framework
::
SelectedRows
*
output
)
{
framework
::
SelectedRows
&
out
=
*
output
;
framework
::
Vector
<
int64_t
>
input_rows
(
input
.
rows
());
std
::
set
<
int64_t
>
row_set
(
input_rows
.
begin
(),
input_rows
.
end
());
std
::
vector
<
int64_t
>
merge_rows
(
row_set
.
begin
(),
row_set
.
end
());
...
...
@@ -292,7 +300,6 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
input_data
,
input_rows
.
CUDAData
(
context
.
GetPlace
()),
out_data
,
out
.
mutable_rows
()
->
CUDAMutableData
(
context
.
GetPlace
()),
out
.
rows
().
size
(),
input_width
);
return
out
;
}
};
...
...
paddle/fluid/operators/math/selected_rows_functor.h
浏览文件 @
b6f61faf
...
...
@@ -65,6 +65,9 @@ struct MergeAdd {
// the input SelectedRows object.
framework
::
SelectedRows
operator
()(
const
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
);
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
,
framework
::
SelectedRows
*
output
);
};
template
<
typename
DeviceContext
,
typename
T
>
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录