Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
7f1e3126
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7f1e3126
编写于
9月 20, 2018
作者:
Z
Zeng Jinle
提交者:
GitHub
9月 20, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #13456 from sneaxiy/refine_sparse_adam
Fix sparse Adam and Gradient clip of SelectedRows
上级
b7588751
a29b4227
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
103 addition
and
42 deletion
+103
-42
paddle/fluid/operators/adam_op.h
paddle/fluid/operators/adam_op.h
+49
-23
paddle/fluid/operators/clip_op.h
paddle/fluid/operators/clip_op.h
+32
-11
paddle/fluid/operators/math/selected_rows_functor.cc
paddle/fluid/operators/math/selected_rows_functor.cc
+8
-1
paddle/fluid/operators/math/selected_rows_functor.cu
paddle/fluid/operators/math/selected_rows_functor.cu
+11
-7
paddle/fluid/operators/math/selected_rows_functor.h
paddle/fluid/operators/math/selected_rows_functor.h
+3
-0
未找到文件。
paddle/fluid/operators/adam_op.h
浏览文件 @
7f1e3126
...
...
@@ -174,12 +174,13 @@ struct SparseAdamFunctor {
const
int64_t
*
rows_
;
int64_t
row_numel_
;
int64_t
row_count_
;
SparseAdamFunctor
(
T
beta1
,
T
beta2
,
T
epsilon
,
const
T
*
beta1_pow
,
const
T
*
beta2_pow
,
const
T
*
mom1
,
T
*
mom1_out
,
const
T
*
mom2
,
T
*
mom2_out
,
const
T
*
lr
,
const
T
*
grad
,
const
T
*
param
,
T
*
param_out
,
const
int64_t
*
rows
,
int64_t
row_numel
)
int64_t
row_numel
,
int64_t
row_count
)
:
beta1_
(
beta1
),
beta2_
(
beta2
),
epsilon_
(
epsilon
),
...
...
@@ -194,28 +195,47 @@ struct SparseAdamFunctor {
param_
(
param
),
param_out_
(
param_out
),
rows_
(
rows
),
row_numel_
(
row_numel
)
{}
row_numel_
(
row_numel
),
row_count_
(
row_count
)
{}
inline
HOSTDEVICE
int64_t
BinarySearchInRows
(
int64_t
row
)
const
{
int64_t
beg
=
0
,
end
=
row_count_
-
1
;
while
(
beg
<=
end
)
{
auto
mid
=
((
beg
+
end
)
>>
1
);
if
(
rows_
[
mid
]
==
row
)
return
mid
;
else
if
(
rows_
[
mid
]
<
row
)
beg
=
mid
+
1
;
else
end
=
mid
-
1
;
}
return
-
1
;
}
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
const
{
int64_t
row
=
i
/
row_numel_
;
auto
row_idx
=
BinarySearchInRows
(
row
);
T
g
=
row_idx
>=
0
?
grad_
[
row_idx
*
row_numel_
+
i
%
row_numel_
]
:
0
;
// The following code is the same as dense
T
mom1
=
moment1_
[
i
];
T
mom2
=
moment2_
[
i
];
T
lr
=
*
lr_
;
T
beta1_pow
=
*
beta1_pow_
;
T
beta2_pow
=
*
beta2_pow_
;
for
(
int64_t
j
=
0
;
j
<
row_numel_
;
++
j
)
{
T
g
=
grad_
[
i
*
row_numel_
+
j
];
T
mom1
=
moment1_
[
rows_
[
i
]
*
row_numel_
+
j
];
T
mom2
=
moment2_
[
rows_
[
i
]
*
row_numel_
+
j
];
T
lr
=
*
lr_
;
T
p
=
param_
[
rows_
[
i
]
*
row_numel_
+
j
];
T
p
=
param_
[
i
];
// Calculation
lr
*=
sqrt
(
1
-
beta2_pow
)
/
(
1
-
beta1_pow
);
mom1
=
beta1_
*
mom1
+
(
1
-
beta1_
)
*
g
;
mom2
=
beta2_
*
mom2
+
(
1
-
beta2_
)
*
g
*
g
;
p
-=
lr
*
(
mom1
/
(
sqrt
(
mom2
)
+
epsilon_
));
moment1_out_
[
rows_
[
i
]
*
row_numel_
+
j
]
=
mom1
;
moment2_out_
[
rows_
[
i
]
*
row_numel_
+
j
]
=
mom2
;
param_out_
[
rows_
[
i
]
*
row_numel_
+
j
]
=
p
;
}
// for col id
// Write back to global memory
moment1_out_
[
i
]
=
mom1
;
moment2_out_
[
i
]
=
mom2
;
param_out_
[
i
]
=
p
;
}
};
...
...
@@ -287,9 +307,14 @@ class AdamOpKernel : public framework::OpKernel<T> {
return
;
}
// merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor
scatter
::
MergeAdd
<
DeviceContext
,
T
>
merge_func
;
auto
grad_merge
=
merge_func
(
ctx
.
template
device_context
<
DeviceContext
>(),
grad
);
auto
&
grad_merge
=
*
(
ctx
.
scope
()
.
NewScope
()
.
Var
(
"sparse_adam_grad_merge"
)
->
GetMutable
<
framework
::
SelectedRows
>
());
merge_func
(
ctx
.
template
device_context
<
DeviceContext
>(),
grad
,
&
grad_merge
);
auto
&
grad_tensor
=
grad_merge
.
value
();
const
T
*
grad_data
=
grad_tensor
.
template
data
<
T
>();
int64_t
*
rows
=
nullptr
;
...
...
@@ -314,10 +339,11 @@ class AdamOpKernel : public framework::OpKernel<T> {
mom2
.
template
data
<
T
>(),
mom2_out
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
lr
.
template
data
<
T
>(),
grad_data
,
param
.
template
data
<
T
>(),
param_out
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
rows
,
row_numel
);
param_out
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
rows
,
row_numel
,
grad_merge
.
rows
().
size
());
platform
::
ForRange
<
DeviceContext
>
for_range
(
static_cast
<
const
DeviceContext
&>
(
ctx
.
device_context
()),
grad_merge
.
rows
().
size
());
param
.
numel
());
for_range
(
functor
);
}
else
{
PADDLE_THROW
(
"Variable type not supported by adam_op"
);
...
...
paddle/fluid/operators/clip_op.h
浏览文件 @
7f1e3126
...
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/transform.h"
namespace
paddle
{
...
...
@@ -61,14 +62,32 @@ class ClipKernel : public framework::OpKernel<T> {
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
max
=
context
.
Attr
<
T
>
(
"max"
);
auto
min
=
context
.
Attr
<
T
>
(
"min"
);
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
Tensor
>
(
"Out"
);
auto
*
x_var
=
context
.
InputVar
(
"X"
);
if
(
x_var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
*
x
=
context
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
const
T
*
x_data
=
x
->
data
<
T
>
();
int64_t
numel
=
x
->
numel
();
Transform
<
DeviceContext
>
trans
;
trans
(
context
.
template
device_context
<
DeviceContext
>(),
x_data
,
x_data
+
numel
,
out_data
,
ClipFunctor
<
T
>
(
min
,
max
));
}
else
if
(
x_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
*
x
=
context
.
Input
<
framework
::
SelectedRows
>
(
"X"
);
auto
*
out
=
context
.
Output
<
framework
::
SelectedRows
>
(
"Out"
);
PADDLE_ENFORCE_NE
(
x
,
out
,
"Inplace clip is not allowed when x is SelectedRows"
);
math
::
scatter
::
MergeAdd
<
DeviceContext
,
T
>
merge_func
;
merge_func
(
context
.
template
device_context
<
DeviceContext
>(),
*
x
,
out
);
auto
*
out_tensor
=
out
->
mutable_value
();
auto
*
out_data
=
out_tensor
->
data
<
T
>
();
int64_t
numel
=
out_tensor
->
numel
();
Transform
<
DeviceContext
>
trans
;
trans
(
context
.
template
device_context
<
DeviceContext
>(),
out_data
,
out_data
+
numel
,
out_data
,
ClipFunctor
<
T
>
(
min
,
max
));
}
else
{
PADDLE_THROW
(
"ClipOp only supports LoDTensor and SelectedRows"
);
}
}
};
...
...
@@ -78,10 +97,12 @@ class ClipGradKernel : public framework::OpKernel<T> {
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
max
=
context
.
Attr
<
T
>
(
"max"
);
auto
min
=
context
.
Attr
<
T
>
(
"min"
);
auto
*
d_out
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_x
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_out
=
context
.
Input
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_x
=
context
.
Output
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
if
(
d_x
!=
nullptr
)
{
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
x
=
context
.
Input
<
framework
::
LoD
Tensor
>
(
"X"
);
int64_t
numel
=
d_out
->
numel
();
auto
*
d_x_data
=
d_x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
const
T
*
d_out_data
=
d_out
->
data
<
T
>
();
...
...
paddle/fluid/operators/math/selected_rows_functor.cc
浏览文件 @
7f1e3126
...
...
@@ -199,6 +199,14 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
framework
::
SelectedRows
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
)
{
framework
::
SelectedRows
out
;
(
*
this
)(
context
,
input
,
&
out
);
return
out
;
}
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
,
framework
::
SelectedRows
*
output
)
{
framework
::
SelectedRows
&
out
=
*
output
;
auto
input_rows
=
input
.
rows
();
std
::
set
<
int64_t
>
row_set
(
input_rows
.
begin
(),
input_rows
.
end
());
std
::
vector
<
int64_t
>
merge_rows
(
row_set
.
begin
(),
row_set
.
end
());
...
...
@@ -223,7 +231,6 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
out_data
[
out_i
*
input_width
+
j
]
+=
input_data
[
i
*
input_width
+
j
];
}
}
return
out
;
}
};
...
...
paddle/fluid/operators/math/selected_rows_functor.cu
浏览文件 @
7f1e3126
...
...
@@ -234,7 +234,7 @@ template <typename T, int block_size>
__global__
void
MergeAddKernel
(
const
T
*
input
,
const
int64_t
*
input_rows
,
T
*
out
,
const
int64_t
*
out_rows
,
size_t
out_rows_size
,
int64_t
row_numel
)
{
const
int
ty
=
blockIdx
.
y
;
const
int
ty
=
blockIdx
.
x
;
int
tid
=
threadIdx
.
x
;
__shared__
size_t
out_idx
;
...
...
@@ -260,6 +260,14 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
framework
::
SelectedRows
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
)
{
framework
::
SelectedRows
out
;
(
*
this
)(
context
,
input
,
&
out
);
return
out
;
}
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
,
framework
::
SelectedRows
*
output
)
{
framework
::
SelectedRows
&
out
=
*
output
;
framework
::
Vector
<
int64_t
>
input_rows
(
input
.
rows
());
std
::
set
<
int64_t
>
row_set
(
input_rows
.
begin
(),
input_rows
.
end
());
std
::
vector
<
int64_t
>
merge_rows
(
row_set
.
begin
(),
row_set
.
end
());
...
...
@@ -281,16 +289,12 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
const
int
block_size
=
256
;
dim3
threads
(
block_size
,
1
);
dim3
grid1
(
1
,
input_rows
.
size
()
);
dim3
grid1
(
input_rows
.
size
(),
1
);
MergeAddKernel
<
T
,
256
><<<
grid1
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
MergeAddKernel
<
T
,
256
><<<
grid1
,
threads
,
0
,
context
.
stream
()
>>>
(
input_data
,
input_rows
.
CUDAData
(
context
.
GetPlace
()),
out_data
,
out
.
mutable_rows
()
->
CUDAMutableData
(
context
.
GetPlace
()),
out
.
rows
().
size
(),
input_width
);
return
out
;
}
};
...
...
paddle/fluid/operators/math/selected_rows_functor.h
浏览文件 @
7f1e3126
...
...
@@ -65,6 +65,9 @@ struct MergeAdd {
// the input SelectedRows object.
framework
::
SelectedRows
operator
()(
const
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
);
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
SelectedRows
&
input
,
framework
::
SelectedRows
*
output
);
};
template
<
typename
DeviceContext
,
typename
T
>
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录