Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
fd152289
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fd152289
编写于
12月 17, 2018
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
clean for range in test=develop
上级
1141db81
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
6 addition
and
60 deletion
+6
-60
paddle/fluid/operators/optimizers/adam_op.h
paddle/fluid/operators/optimizers/adam_op.h
+6
-8
paddle/fluid/platform/for_range.h
paddle/fluid/platform/for_range.h
+0
-52
未找到文件。
paddle/fluid/operators/optimizers/adam_op.h
浏览文件 @
fd152289
...
...
@@ -227,8 +227,10 @@ struct SparseAdamFunctor {
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
const
{
auto
row_idx
=
math
::
BinarySearch
<
int64_t
>
(
rows_
,
row_count_
,
i
/
row_numel_
);
T
g
=
row_idx
>=
0
?
grad_
[
row_idx
*
row_numel_
+
i
%
row_numel_
]
:
0
;
adam_update
(
i
,
g
);
if
(
!
(
lazy_mode_
&&
row_idx
<
0
))
{
T
g
=
row_idx
>=
0
?
grad_
[
row_idx
*
row_numel_
+
i
%
row_numel_
]
:
0
;
adam_update
(
i
,
g
);
}
}
};
...
...
@@ -359,19 +361,15 @@ class AdamOpKernel : public framework::OpKernel<T> {
param_out
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
rows
,
row_numel
,
grad_merge
.
rows
().
size
(),
lazy_mode
);
VLOG
(
3
)
<<
"lazy_mode :"
<<
lazy_mode
;
if
(
lazy_mode
)
{
std
::
vector
<
int64_t
>
id_vector
;
if
(
lazy_mode
&&
platform
::
is_cpu_place
(
ctx
.
GetPlace
()))
{
size_t
row_count
=
grad_merge
.
rows
().
size
();
std
::
vector
<
int64_t
>
cpu_rows
(
grad_merge
.
rows
());
for
(
size_t
row_index
=
0
;
row_index
<
row_count
;
++
row_index
)
{
for
(
size_t
offset
=
0
;
offset
<
row_numel
;
++
offset
)
{
size_t
i
=
cpu_rows
[
row_index
]
*
row_numel
+
offset
;
id_vector
.
push_back
(
i
);
functor
.
adam_update
(
i
,
grad_data
[
row_index
*
row_numel
+
offset
]
);
}
}
platform
::
ForRangeIn
<
DeviceContext
>
for_range_in
(
static_cast
<
const
DeviceContext
&>
(
ctx
.
device_context
()),
id_vector
);
for_range_in
(
functor
);
}
else
{
platform
::
ForRange
<
DeviceContext
>
for_range
(
static_cast
<
const
DeviceContext
&>
(
ctx
.
device_context
()),
...
...
paddle/fluid/platform/for_range.h
浏览文件 @
fd152289
...
...
@@ -22,29 +22,6 @@ limitations under the License. */
namespace
paddle
{
namespace
platform
{
template
<
typename
DeviceContext
>
struct
ForRangeIn
{
ForRangeIn
(
const
DeviceContext
&
dev_ctx
,
std
::
vector
<
int64_t
>
range
);
template
<
typename
Function
>
void
operator
()(
Function
func
)
const
;
};
template
<
>
struct
ForRangeIn
<
CPUDeviceContext
>
{
ForRangeIn
(
const
CPUDeviceContext
&
dev_ctx
,
std
::
vector
<
int64_t
>
range
)
:
range_
(
range
)
{}
template
<
typename
Function
>
void
operator
()(
Function
func
)
const
{
for
(
auto
i
:
range_
)
{
func
(
i
);
}
}
std
::
vector
<
int64_t
>
range_
;
};
template
<
typename
DeviceContext
>
struct
ForRange
{
ForRange
(
const
DeviceContext
&
dev_ctx
,
size_t
limit
);
...
...
@@ -106,35 +83,6 @@ struct ForRange<CUDADeviceContext> {
int
limit_
;
};
template
<
typename
T
,
typename
Function
>
__global__
static
void
ForRangeInElemwiseOp
(
Function
func
,
T
*
vector
,
int
vector_size
)
{
size_t
idx
=
static_cast
<
size_t
>
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
);
if
(
idx
<
vector_size
)
{
func
(
vector
[
idx
]);
}
}
template
<
>
struct
ForRangeIn
<
CUDADeviceContext
>
{
ForRangeIn
(
const
CUDADeviceContext
&
dev_ctx
,
std
::
vector
<
int64_t
>
range
)
:
dev_ctx_
(
dev_ctx
),
range_
(
range
)
{}
template
<
typename
Function
>
inline
void
operator
()(
Function
func
)
const
{
constexpr
int
num_threads
=
1024
;
int
range_size
=
range_
.
size
();
int
block_size
=
range_size
<=
num_threads
?
range_size
:
num_threads
;
int
grid_size
=
(
range_
.
size
()
+
num_threads
-
1
)
/
num_threads
;
ForRangeInElemwiseOp
<<<
grid_size
,
block_size
,
0
,
dev_ctx_
.
stream
()
>>>
(
func
,
range_
.
CUDAData
(
dev_ctx_
.
GetPlace
()),
range_size
);
}
const
CUDADeviceContext
&
dev_ctx_
;
framework
::
Vector
<
int64_t
>
range_
;
};
#endif
}
// namespace platform
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录