Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
1fdf8853
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1fdf8853
编写于
12月 25, 2017
作者:
Y
Yang Yu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize adam_op
上级
39ef5736
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
101 addition
and
16 deletion
+101
-16
paddle/operators/adam_op.h
paddle/operators/adam_op.h
+16
-16
paddle/platform/for_range.h
paddle/platform/for_range.h
+85
-0
未找到文件。
paddle/operators/adam_op.h
浏览文件 @
1fdf8853
...
...
@@ -16,7 +16,7 @@ limitations under the License. */
#include <math.h> // for sqrt in CPU and CUDA
#include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h"
#include "paddle/platform/
transform
.h"
#include "paddle/platform/
for_range
.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -36,10 +36,12 @@ struct AdamFunctor {
const
T
*
lr_
;
const
T
*
grad_
;
const
T
*
param_
;
T
*
param_out_
;
AdamFunctor
(
T
beta1
,
T
beta2
,
T
epsilon
,
const
T
*
beta1_pow
,
const
T
*
beta2_pow
,
const
T
*
mom1
,
T
*
mom1_out
,
const
T
*
mom2
,
T
*
mom2_out
,
const
T
*
lr
,
const
T
*
grad
,
const
T
*
param
)
T
*
mom2_out
,
const
T
*
lr
,
const
T
*
grad
,
const
T
*
param
,
T
*
param_out
)
:
beta1_
(
beta1
),
beta2_
(
beta2
),
epsilon_
(
epsilon
),
...
...
@@ -51,11 +53,10 @@ struct AdamFunctor {
moment2_out_
(
mom2_out
),
lr_
(
lr
),
grad_
(
grad
),
param_
(
param
)
{}
param_
(
param
),
param_out_
(
param_out
)
{}
// From param[i] --> param_out[i];
inline
HOSTDEVICE
T
operator
()(
const
T
&
p
)
const
{
size_t
i
=
&
p
-
param_
;
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
const
{
// Merge all memory access together.
T
g
=
grad_
[
i
];
T
mom1
=
moment1_
[
i
];
...
...
@@ -63,17 +64,18 @@ struct AdamFunctor {
T
lr
=
*
lr_
;
T
beta1_pow
=
*
beta1_pow_
;
T
beta2_pow
=
*
beta2_pow_
;
T
p
=
param_
[
i
];
// Calculation
lr
=
lr
*
sqrt
(
1
-
beta2_pow
)
/
(
1
-
beta1_pow
);
lr
*=
sqrt
(
1
-
beta2_pow
)
/
(
1
-
beta1_pow
);
mom1
=
beta1_
*
mom1
+
(
1
-
beta1_
)
*
g
;
mom2
=
beta2_
*
mom2
+
(
1
-
beta2_
)
*
g
*
g
;
T
new_p
=
p
-
lr
*
(
mom1
/
(
sqrt
(
mom2
)
+
epsilon_
));
p
-=
lr
*
(
mom1
/
(
sqrt
(
mom2
)
+
epsilon_
));
// Write back to global memory
moment1_out_
[
i
]
=
mom1
;
moment2_out_
[
i
]
=
mom2
;
return
new_
p
;
param_out_
[
i
]
=
p
;
}
};
...
...
@@ -113,13 +115,11 @@ class AdamOpKernel : public framework::OpKernel<T> {
mom2
.
template
data
<
T
>(),
mom2_out
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
lr
.
template
data
<
T
>(),
grad
.
template
data
<
T
>(),
param
.
template
data
<
T
>());
const
T
*
in_ptr
=
param
.
template
data
<
T
>();
T
*
out_ptr
=
param_out
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
());
platform
::
Transform
<
DeviceContext
>
trans
;
trans
(
static_cast
<
const
DeviceContext
&>
(
ctx
.
device_context
()),
in_ptr
,
in_ptr
+
param_out
.
numel
(),
out_ptr
,
functor
);
param
.
template
data
<
T
>(),
param_out
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()));
platform
::
ForRange
<
DeviceContext
>
for_range
(
static_cast
<
const
DeviceContext
&>
(
ctx
.
device_context
()),
param
.
numel
());
for_range
(
functor
);
}
};
...
...
paddle/platform/for_range.h
0 → 100644
浏览文件 @
1fdf8853
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/platform/device_context.h"
namespace
paddle
{
namespace
platform
{
template
<
typename
DeviceContext
>
struct
ForRange
{
ForRange
(
const
DeviceContext
&
dev_ctx
,
size_t
limit
);
template
<
typename
Function
>
void
operator
()(
Function
func
)
const
;
};
template
<
>
struct
ForRange
<
CPUDeviceContext
>
{
ForRange
(
const
CPUDeviceContext
&
dev_ctx
,
size_t
limit
)
:
limit_
(
limit
)
{}
template
<
typename
Function
>
void
operator
()(
Function
func
)
const
{
for
(
size_t
i
=
0
;
i
<
limit_
;
++
i
)
{
func
(
i
);
}
}
size_t
limit_
;
};
#ifdef __NVCC__
template
<
typename
Function
>
__global__
static
void
ForRangeElemwiseOpGridIsOne
(
Function
func
)
{
size_t
idx
=
static_cast
<
size_t
>
(
threadIdx
.
x
);
func
(
idx
);
}
template
<
typename
Function
>
__global__
static
void
ForRangeElemwiseOp
(
Function
func
,
int
limit
)
{
size_t
idx
=
static_cast
<
size_t
>
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
);
if
(
idx
<
limit
)
{
func
(
idx
);
}
}
template
<
>
struct
ForRange
<
CUDADeviceContext
>
{
ForRange
(
const
CUDADeviceContext
&
dev_ctx
,
size_t
limit
)
:
dev_ctx_
(
dev_ctx
),
limit_
(
static_cast
<
int
>
(
limit
))
{}
template
<
typename
Function
>
inline
void
operator
()(
Function
func
)
const
{
constexpr
size_t
num_threads
=
1024
;
int
block_size
=
limit_
<=
num_threads
?
limit_
:
num_threads
;
int
grid_size
=
(
limit_
+
num_threads
-
1
)
/
num_threads
;
if
(
grid_size
==
1
)
{
ForRangeElemwiseOpGridIsOne
<<<
1
,
block_size
,
0
,
dev_ctx_
.
stream
()
>>>
(
func
);
}
else
{
ForRangeElemwiseOp
<<<
grid_size
,
block_size
,
0
,
dev_ctx_
.
stream
()
>>>
(
func
,
limit_
);
}
}
const
CUDADeviceContext
&
dev_ctx_
;
int
limit_
;
};
#endif
}
// namespace platform
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录