Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
fcde2b27
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fcde2b27
编写于
12月 17, 2018
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add ForRangeIn
上级
cf526462
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
60 addition
and
2 deletion
+60
-2
paddle/fluid/operators/optimizers/adam_op.h
paddle/fluid/operators/optimizers/adam_op.h
+5
-2
paddle/fluid/platform/for_range.h
paddle/fluid/platform/for_range.h
+55
-0
未找到文件。
paddle/fluid/operators/optimizers/adam_op.h
浏览文件 @
fcde2b27
...
...
@@ -359,14 +359,17 @@ class AdamOpKernel : public framework::OpKernel<T> {
param_out
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
rows
,
row_numel
,
grad_merge
.
rows
().
size
(),
lazy_mode
);
if
(
lazy_mode
)
{
std
::
vector
<
int64_t
>
id_vector
;
size_t
row_count
=
grad_merge
.
rows
().
size
();
for
(
size_t
row_index
=
0
;
row_index
<
row_count
;
++
row_index
)
{
for
(
size_t
offset
=
0
;
offset
<
row_numel
;
++
offset
)
{
size_t
i
=
rows
[
row_index
]
*
row_numel
+
offset
;
T
g
=
grad_data
[
row_index
*
row_numel
+
offset
];
functor
.
adam_update
(
i
,
g
);
id_vector
.
push_back
(
i
);
}
}
platform
::
ForRangeIn
<
DeviceContext
>
for_range_in
(
static_cast
<
const
DeviceContext
&>
(
ctx
.
device_context
()),
id_vector
);
for_range_in
(
functor
);
}
else
{
platform
::
ForRange
<
DeviceContext
>
for_range
(
static_cast
<
const
DeviceContext
&>
(
ctx
.
device_context
()),
...
...
paddle/fluid/platform/for_range.h
浏览文件 @
fcde2b27
...
...
@@ -13,11 +13,38 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/platform/device_context.h"
namespace
paddle
{
namespace
platform
{
template
<
typename
DeviceContext
>
struct
ForRangeIn
{
ForRangeIn
(
const
DeviceContext
&
dev_ctx
,
std
::
vector
<
int64_t
>
range
);
template
<
typename
Function
>
void
operator
()(
Function
func
)
const
;
};
template
<
>
struct
ForRangeIn
<
CPUDeviceContext
>
{
ForRangeIn
(
const
CPUDeviceContext
&
dev_ctx
,
std
::
vector
<
int64_t
>
range
)
:
range_
(
range
)
{}
template
<
typename
Function
>
void
operator
()(
Function
func
)
const
{
for
(
auto
i
:
range_
)
{
func
(
i
);
}
}
std
::
vector
<
int64_t
>
range_
;
};
template
<
typename
DeviceContext
>
struct
ForRange
{
ForRange
(
const
DeviceContext
&
dev_ctx
,
size_t
limit
);
...
...
@@ -79,6 +106,34 @@ struct ForRange<CUDADeviceContext> {
int
limit_
;
};
template
<
typename
T
,
typename
Function
>
__global__
static
void
ForRangeInElemwiseOp
(
Function
func
,
T
*
vector
,
int
vector_size
)
{
size_t
idx
=
static_cast
<
size_t
>
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
);
if
(
idx
<
vector_size
)
{
func
(
vector
[
idx
]);
}
}
template
<
>
struct
ForRangeIn
<
CUDADeviceContext
>
{
ForRange
(
const
CUDADeviceContext
&
dev_ctx
,
std
::
vector
<
int64_t
>
range
)
:
dev_ctx_
(
dev_ctx
),
range_
(
range
)
{}
template
<
typename
Function
>
inline
void
operator
()(
Function
func
)
const
{
constexpr
int
num_threads
=
1024
;
int
block_size
=
range_
.
size
()
<=
num_threads
?
limit_
:
num_threads
;
int
grid_size
=
(
range_
.
size
()
+
num_threads
-
1
)
/
num_threads
;
ForRangeInElemwiseOp
<<<
grid_size
,
block_size
,
0
,
dev_ctx_
.
stream
()
>>>
(
func
,
range_
.
data
(),
range_
.
size
());
}
const
CUDADeviceContext
&
dev_ctx_
;
framework
::
Vector
<
int64_t
>
range_
;
};
#endif
}
// namespace platform
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录