Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ea4bdca8
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ea4bdca8
编写于
12月 25, 2017
作者:
Y
Yu Yang
提交者:
GitHub
12月 25, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #6967 from reyoung/feature/optimize_adam_speed
Use for_range to rewrite adam
上级
ea5d6eae
45372842
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
178 addition
and
39 deletion
+178
-39
paddle/operators/adam_op.h
paddle/operators/adam_op.h
+93
-39
paddle/platform/for_range.h
paddle/platform/for_range.h
+85
-0
未找到文件。
paddle/operators/adam_op.h
浏览文件 @
ea4bdca8
...
...
@@ -13,59 +13,113 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include
"paddle/framework/eigen.h"
#include
<math.h> // for sqrt in CPU and CUDA
#include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h"
#include "paddle/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
AdamFunctor
{
T
beta1_
;
T
beta2_
;
T
epsilon_
;
const
T
*
beta1_pow_
;
const
T
*
beta2_pow_
;
const
T
*
moment1_
;
T
*
moment1_out_
;
const
T
*
moment2_
;
T
*
moment2_out_
;
const
T
*
lr_
;
const
T
*
grad_
;
const
T
*
param_
;
T
*
param_out_
;
AdamFunctor
(
T
beta1
,
T
beta2
,
T
epsilon
,
const
T
*
beta1_pow
,
const
T
*
beta2_pow
,
const
T
*
mom1
,
T
*
mom1_out
,
const
T
*
mom2
,
T
*
mom2_out
,
const
T
*
lr
,
const
T
*
grad
,
const
T
*
param
,
T
*
param_out
)
:
beta1_
(
beta1
),
beta2_
(
beta2
),
epsilon_
(
epsilon
),
beta1_pow_
(
beta1_pow
),
beta2_pow_
(
beta2_pow
),
moment1_
(
mom1
),
moment1_out_
(
mom1_out
),
moment2_
(
mom2
),
moment2_out_
(
mom2_out
),
lr_
(
lr
),
grad_
(
grad
),
param_
(
param
),
param_out_
(
param_out
)
{}
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
const
{
// Merge all memory access together.
T
g
=
grad_
[
i
];
T
mom1
=
moment1_
[
i
];
T
mom2
=
moment2_
[
i
];
T
lr
=
*
lr_
;
T
beta1_pow
=
*
beta1_pow_
;
T
beta2_pow
=
*
beta2_pow_
;
T
p
=
param_
[
i
];
// Calculation
lr
*=
sqrt
(
1
-
beta2_pow
)
/
(
1
-
beta1_pow
);
mom1
=
beta1_
*
mom1
+
(
1
-
beta1_
)
*
g
;
mom2
=
beta2_
*
mom2
+
(
1
-
beta2_
)
*
g
*
g
;
p
-=
lr
*
(
mom1
/
(
sqrt
(
mom2
)
+
epsilon_
));
// Write back to global memory
moment1_out_
[
i
]
=
mom1
;
moment2_out_
[
i
]
=
mom2
;
param_out_
[
i
]
=
p
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
AdamOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
param_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
moment1_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Moment1Out"
);
auto
moment2_out_tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Moment2Out"
);
param_out_tensor
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
moment1_out_tensor
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
moment2_out_tensor
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
using
paddle
::
framework
::
LoDTensor
;
using
paddle
::
operators
::
detail
::
Ref
;
T
beta1
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"beta1"
));
T
beta2
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"beta2"
));
T
epsilon
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"epsilon"
));
auto
&
param
=
Ref
(
ctx
.
Input
<
LoDTensor
>
(
"Param"
),
"Must set Param"
);
auto
&
grad
=
Ref
(
ctx
.
Input
<
LoDTensor
>
(
"Grad"
),
"Must set Grad"
);
auto
&
mom1
=
Ref
(
ctx
.
Input
<
LoDTensor
>
(
"Moment1"
),
"Must set Moment1"
);
auto
&
mom2
=
Ref
(
ctx
.
Input
<
LoDTensor
>
(
"Moment2"
),
"Must set Moment2"
);
auto
&
lr
=
Ref
(
ctx
.
Input
<
LoDTensor
>
(
"LearningRate"
),
"Must set LearningRate"
);
auto
&
beta1_pow
=
Ref
(
ctx
.
Input
<
LoDTensor
>
(
"Beta1Pow"
),
"Must set Beta1Pow"
);
auto
&
beta2_pow
=
Ref
(
ctx
.
Input
<
LoDTensor
>
(
"Beta2Pow"
),
"Must set Beta2Pow"
);
auto
&
param_out
=
Ref
(
ctx
.
Output
<
LoDTensor
>
(
"ParamOut"
),
"Must set ParamOut"
);
auto
&
mom1_out
=
Ref
(
ctx
.
Output
<
LoDTensor
>
(
"Moment1Out"
),
"Must set Moment1Out"
);
auto
&
mom2_out
=
Ref
(
ctx
.
Output
<
LoDTensor
>
(
"Moment2Out"
),
"Must set Moment1Out"
);
auto
param
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
));
auto
grad
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
));
auto
moment1
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Moment1"
));
auto
moment2
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Moment2"
));
auto
lr
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
));
auto
beta1_pow
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Beta1Pow"
));
auto
beta2_pow
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Beta2Pow"
));
auto
param_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param_out_tensor
);
auto
moment1_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
moment1_out_tensor
);
auto
moment2_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
moment2_out_tensor
);
auto
*
place
=
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
moment1_out
.
device
(
*
place
)
=
beta1
*
moment1
+
(
1
-
beta1
)
*
grad
;
moment2_out
.
device
(
*
place
)
=
beta2
*
moment2
+
(
1
-
beta2
)
*
grad
.
square
();
// All of these are tensors of 1 element
auto
lr_t
=
lr
*
(
1
-
beta2_pow
).
sqrt
()
/
(
1
-
beta1_pow
);
// Eigen does not support automatic broadcast
// Get dimensions of moment vector to broadcast lr_t
Eigen
::
DSizes
<
int
,
1
>
m_dsize
(
moment1_out_tensor
->
numel
());
param_out
.
device
(
*
place
)
=
param
-
lr_t
.
broadcast
(
m_dsize
)
*
(
moment1_out
/
(
moment2_out
.
sqrt
()
+
epsilon
));
AdamFunctor
<
T
>
functor
(
beta1
,
beta2
,
epsilon
,
beta1_pow
.
template
data
<
T
>(),
beta2_pow
.
template
data
<
T
>(),
mom1
.
template
data
<
T
>(),
mom1_out
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
mom2
.
template
data
<
T
>(),
mom2_out
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
lr
.
template
data
<
T
>(),
grad
.
template
data
<
T
>(),
param
.
template
data
<
T
>(),
param_out
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()));
platform
::
ForRange
<
DeviceContext
>
for_range
(
static_cast
<
const
DeviceContext
&>
(
ctx
.
device_context
()),
param
.
numel
());
for_range
(
functor
);
}
};
...
...
paddle/platform/for_range.h
0 → 100644
浏览文件 @
ea4bdca8
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/platform/device_context.h"
namespace
paddle
{
namespace
platform
{
template
<
typename
DeviceContext
>
struct
ForRange
{
ForRange
(
const
DeviceContext
&
dev_ctx
,
size_t
limit
);
template
<
typename
Function
>
void
operator
()(
Function
func
)
const
;
};
template
<
>
struct
ForRange
<
CPUDeviceContext
>
{
ForRange
(
const
CPUDeviceContext
&
dev_ctx
,
size_t
limit
)
:
limit_
(
limit
)
{}
template
<
typename
Function
>
void
operator
()(
Function
func
)
const
{
for
(
size_t
i
=
0
;
i
<
limit_
;
++
i
)
{
func
(
i
);
}
}
size_t
limit_
;
};
#ifdef __NVCC__
template
<
typename
Function
>
__global__
static
void
ForRangeElemwiseOpGridIsOne
(
Function
func
)
{
size_t
idx
=
static_cast
<
size_t
>
(
threadIdx
.
x
);
func
(
idx
);
}
template
<
typename
Function
>
__global__
static
void
ForRangeElemwiseOp
(
Function
func
,
int
limit
)
{
size_t
idx
=
static_cast
<
size_t
>
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
);
if
(
idx
<
limit
)
{
func
(
idx
);
}
}
template
<
>
struct
ForRange
<
CUDADeviceContext
>
{
ForRange
(
const
CUDADeviceContext
&
dev_ctx
,
size_t
limit
)
:
dev_ctx_
(
dev_ctx
),
limit_
(
static_cast
<
int
>
(
limit
))
{}
template
<
typename
Function
>
inline
void
operator
()(
Function
func
)
const
{
constexpr
size_t
num_threads
=
1024
;
int
block_size
=
limit_
<=
num_threads
?
limit_
:
num_threads
;
int
grid_size
=
(
limit_
+
num_threads
-
1
)
/
num_threads
;
if
(
grid_size
==
1
)
{
ForRangeElemwiseOpGridIsOne
<<<
1
,
block_size
,
0
,
dev_ctx_
.
stream
()
>>>
(
func
);
}
else
{
ForRangeElemwiseOp
<<<
grid_size
,
block_size
,
0
,
dev_ctx_
.
stream
()
>>>
(
func
,
limit_
);
}
}
const
CUDADeviceContext
&
dev_ctx_
;
int
limit_
;
};
#endif
}
// namespace platform
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录