Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ebd14743
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ebd14743
编写于
2月 07, 2022
作者:
J
jakpiase
提交者:
GitHub
2月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added Adam FP32 JIT assembly kernel (#39158)
* Added adam kernel * CI rerun
上级
e15e4ed0
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
422 addition
and
20 deletion
+422
-20
paddle/fluid/operators/jit/gen/CMakeLists.txt
paddle/fluid/operators/jit/gen/CMakeLists.txt
+1
-0
paddle/fluid/operators/jit/gen/adam.cc
paddle/fluid/operators/jit/gen/adam.cc
+153
-0
paddle/fluid/operators/jit/gen/adam.h
paddle/fluid/operators/jit/gen/adam.h
+75
-0
paddle/fluid/operators/jit/gen/jitcode.h
paddle/fluid/operators/jit/gen/jitcode.h
+1
-0
paddle/fluid/operators/jit/helper.cc
paddle/fluid/operators/jit/helper.cc
+1
-0
paddle/fluid/operators/jit/helper.h
paddle/fluid/operators/jit/helper.h
+5
-0
paddle/fluid/operators/jit/kernel_base.h
paddle/fluid/operators/jit/kernel_base.h
+18
-2
paddle/fluid/operators/jit/kernel_key.cc
paddle/fluid/operators/jit/kernel_key.cc
+5
-0
paddle/fluid/operators/jit/refer/CMakeLists.txt
paddle/fluid/operators/jit/refer/CMakeLists.txt
+1
-0
paddle/fluid/operators/jit/refer/refer.cc
paddle/fluid/operators/jit/refer/refer.cc
+1
-0
paddle/fluid/operators/jit/refer/refer.h
paddle/fluid/operators/jit/refer/refer.h
+14
-0
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+87
-6
paddle/fluid/operators/optimizers/adam_op.h
paddle/fluid/operators/optimizers/adam_op.h
+47
-8
python/paddle/fluid/tests/unittests/test_adam_op.py
python/paddle/fluid/tests/unittests/test_adam_op.py
+13
-4
未找到文件。
paddle/fluid/operators/jit/gen/CMakeLists.txt
浏览文件 @
ebd14743
...
@@ -32,5 +32,6 @@ USE_JITKERNEL_GEN(kSeqPool)
...
@@ -32,5 +32,6 @@ USE_JITKERNEL_GEN(kSeqPool)
USE_JITKERNEL_GEN
(
kHMax
)
USE_JITKERNEL_GEN
(
kHMax
)
USE_JITKERNEL_GEN
(
kHSum
)
USE_JITKERNEL_GEN
(
kHSum
)
USE_JITKERNEL_GEN
(
kEmbSeqPool
)
USE_JITKERNEL_GEN
(
kEmbSeqPool
)
USE_JITKERNEL_GEN
(
kAdam
)
USE_JITKERNEL_GEN
(
kSgd
)
USE_JITKERNEL_GEN
(
kSgd
)
USE_JITKERNEL_GEN
(
kVBroadcast
)
USE_JITKERNEL_GEN
(
kVBroadcast
)
paddle/fluid/operators/jit/gen/adam.cc
0 → 100644
浏览文件 @
ebd14743
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/adam.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
gen
{
void
AdamJitCode
::
loadArgs
()
{
static
constexpr
int32_t
one_as_float
=
0x3f800000
;
static
constexpr
int32_t
mask_all_ones
=
0xFFFFFFFF
;
static
constexpr
int64_t
mask_8_divisible
=
0xFFFFFFFFFFFFFFF8
;
static
constexpr
int64_t
abi_pushes_offset
=
num_g_abi_regs
*
8
;
mov
(
reg_mom2_out_ptr
,
ptr
[
rsp
+
(
abi_pushes_offset
+
8
)]);
mov
(
reg_param_out_ptr
,
ptr
[
rsp
+
(
abi_pushes_offset
+
16
)]);
mov
(
eax
,
one_as_float
);
movd
(
xmm_one
,
eax
);
vbroadcastss
(
ymm_one
,
xmm_one
);
// 1
vbroadcastss
(
ymm_beta1
,
xmm_beta1
);
// beta1
vbroadcastss
(
ymm_beta2
,
xmm_beta2
);
// beta2
vbroadcastss
(
ymm_lr
,
xmm_lr
);
// -lr
vbroadcastss
(
ymm_eps
,
xmm_eps
);
// eps
vsubps
(
ymm_one_sub_beta1
,
ymm_one
,
ymm_beta1
);
// 1 - beta1
vsubps
(
ymm_one_sub_beta2
,
ymm_one
,
ymm_beta2
);
// 1 - beta2
mov
(
reg_numel_without_tail
,
reg_numel
);
and_
(
reg_numel_without_tail
,
mask_8_divisible
);
// make it 8-divisible
shl
(
reg_numel_without_tail
,
2
);
// * 4 to treat it as float offset
shl
(
reg_numel
,
2
);
mov
(
eax
,
mask_all_ones
);
kmovw
(
k1
,
eax
);
xor_
(
reg_offset
,
reg_offset
);
}
void
AdamJitCode
::
setTailOpmask
()
{
mov
(
r13
,
rcx
);
mov
(
rcx
,
reg_numel
);
sub
(
rcx
,
reg_offset
);
// get tail numel as float size
shr
(
rcx
,
2
);
// as elements
mov
(
r14
,
1
);
shl
(
r14
,
cl
);
// 2 ^ elements
dec
(
r14
);
// 2 ^ elements - 1, so numel first bits are set to 1
kmovw
(
k1
,
r14d
);
mov
(
rcx
,
r13
);
}
void
AdamJitCode
::
mainCode
()
{
// load grad
vmovups
(
ymm7
|
k1
,
ptr
[
reg_grad_ptr
+
reg_offset
]);
// beta1 * mom1 + (1 - beta1) * g
vmulps
(
ymm8
|
k1
,
ymm_one_sub_beta1
,
ymm7
);
vfmadd231ps
(
ymm8
|
k1
,
ymm_beta1
,
ptr
[
reg_mom1_ptr
+
reg_offset
]);
// beta2 * mom2 + (1 - beta2) * g * g
vmulps
(
ymm7
|
k1
,
ymm7
,
ymm7
);
vmulps
(
ymm7
|
k1
,
ymm_one_sub_beta2
,
ymm7
);
vfmadd231ps
(
ymm7
|
k1
,
ymm1
,
ptr
[
reg_mom2_ptr
+
reg_offset
]);
// store mom1 and mom2
vmovups
(
ptr
[
reg_mom1_out_ptr
+
reg_offset
]
|
k1
,
ymm8
);
vmovups
(
ptr
[
reg_mom2_out_ptr
+
reg_offset
]
|
k1
,
ymm7
);
// sqrt(mom2) + eps
vsqrtps
(
ymm7
|
k1
,
ymm7
);
vaddps
(
ymm7
|
k1
,
ymm7
,
ymm3
);
// p + (-lr) * (mom1 / sqrt(mom2) + eps)
vdivps
(
ymm7
|
k1
,
ymm8
,
ymm7
);
vfmadd213ps
(
ymm7
|
k1
,
ymm2
,
ptr
[
reg_param_ptr
+
reg_offset
]);
// store p
vmovups
(
ptr
[
reg_param_out_ptr
+
reg_offset
]
|
k1
,
ymm7
);
}
void
AdamJitCode
::
genCode
()
{
static
constexpr
int64_t
main_loop_elems_size
=
8
*
sizeof
(
float
);
// 8 floats in YMM
static
constexpr
int64_t
offset_increment
=
main_loop_elems_size
;
preCode
();
loadArgs
();
cmp
(
reg_numel
,
main_loop_elems_size
);
jl
(
"process_tail"
);
L
(
"main_loop"
);
{
mainCode
();
add
(
reg_offset
,
offset_increment
);
cmp
(
reg_numel_without_tail
,
reg_offset
);
jg
(
"main_loop"
);
}
cmp
(
reg_numel
,
reg_offset
);
je
(
"end"
);
L
(
"process_tail"
);
{
setTailOpmask
();
mainCode
();
}
L
(
"end"
);
postCode
();
}
class
AdamCreator
:
public
JitCodeCreator
<
adam_attr_t
>
{
public:
bool
CanBeUsed
(
const
adam_attr_t
&
attr
)
const
override
{
return
platform
::
MayIUse
(
platform
::
avx512f
);
}
size_t
CodeSize
(
const
adam_attr_t
&
attr
)
const
override
{
return
96
+
32
*
8
;
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
adam_attr_t
&
attr
)
const
override
{
return
make_unique
<
AdamJitCode
>
(
attr
,
CodeSize
(
attr
));
}
};
}
// namespace gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
namespace
gen
=
paddle
::
operators
::
jit
::
gen
;
REGISTER_JITKERNEL_GEN
(
kAdam
,
gen
::
AdamCreator
);
paddle/fluid/operators/jit/gen/adam.h
0 → 100644
浏览文件 @
ebd14743
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
gen
{
class
AdamJitCode
:
public
JitCode
{
public:
explicit
AdamJitCode
(
const
adam_attr_t
&
attr
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
)
{
this
->
genCode
();
}
DECLARE_JIT_CODE
(
AdamJitCode
);
void
genCode
()
override
;
void
loadArgs
();
void
setTailOpmask
();
void
mainCode
();
private:
reg64_t
reg_numel
{
abi_param1
};
reg64_t
reg_grad_ptr
{
abi_param2
};
reg64_t
reg_mom1_ptr
{
abi_param3
};
reg64_t
reg_mom2_ptr
{
abi_param4
};
reg64_t
reg_param_ptr
{
abi_param5
};
reg64_t
reg_mom1_out_ptr
{
abi_param6
};
xmm_t
xmm_beta1
=
xmm_t
(
0
);
xmm_t
xmm_beta2
=
xmm_t
(
1
);
xmm_t
xmm_lr
=
xmm_t
(
2
);
xmm_t
xmm_eps
=
xmm_t
(
3
);
xmm_t
xmm_one_sub_beta1
=
xmm_t
(
4
);
xmm_t
xmm_one_sub_beta2
=
xmm_t
(
5
);
xmm_t
xmm_one
=
xmm_t
(
6
);
ymm_t
ymm_beta1
=
ymm_t
(
0
);
ymm_t
ymm_beta2
=
ymm_t
(
1
);
ymm_t
ymm_lr
=
ymm_t
(
2
);
ymm_t
ymm_eps
=
ymm_t
(
3
);
ymm_t
ymm_one_sub_beta1
=
ymm_t
(
4
);
ymm_t
ymm_one_sub_beta2
=
ymm_t
(
5
);
ymm_t
ymm_one
=
ymm_t
(
6
);
reg64_t
reg_mom2_out_ptr
{
r10
};
reg64_t
reg_param_out_ptr
{
r11
};
reg64_t
reg_numel_without_tail
{
r12
};
reg64_t
reg_offset
{
rax
};
};
}
// namespace gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/gen/jitcode.h
浏览文件 @
ebd14743
...
@@ -45,6 +45,7 @@ using reg32_t = const Xbyak::Reg32;
...
@@ -45,6 +45,7 @@ using reg32_t = const Xbyak::Reg32;
using
xmm_t
=
const
Xbyak
::
Xmm
;
using
xmm_t
=
const
Xbyak
::
Xmm
;
using
ymm_t
=
const
Xbyak
::
Ymm
;
using
ymm_t
=
const
Xbyak
::
Ymm
;
using
zmm_t
=
const
Xbyak
::
Zmm
;
using
zmm_t
=
const
Xbyak
::
Zmm
;
using
opmask_t
=
const
Xbyak
::
Opmask
;
using
Label
=
Xbyak
::
Label
;
using
Label
=
Xbyak
::
Label
;
typedef
enum
{
typedef
enum
{
...
...
paddle/fluid/operators/jit/helper.cc
浏览文件 @
ebd14743
...
@@ -58,6 +58,7 @@ const char* to_string(KernelType kt) {
...
@@ -58,6 +58,7 @@ const char* to_string(KernelType kt) {
ONE_CASE
(
kSeqPool
);
ONE_CASE
(
kSeqPool
);
ONE_CASE
(
kMatMul
);
ONE_CASE
(
kMatMul
);
ONE_CASE
(
kHMax
);
ONE_CASE
(
kHMax
);
ONE_CASE
(
kAdam
);
ONE_CASE
(
kHSum
);
ONE_CASE
(
kHSum
);
ONE_CASE
(
kStrideASum
);
ONE_CASE
(
kStrideASum
);
ONE_CASE
(
kSoftmax
);
ONE_CASE
(
kSoftmax
);
...
...
paddle/fluid/operators/jit/helper.h
浏览文件 @
ebd14743
...
@@ -275,6 +275,11 @@ inline std::ostream& operator<<(std::ostream& os,
...
@@ -275,6 +275,11 @@ inline std::ostream& operator<<(std::ostream& os,
return
os
;
return
os
;
}
}
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
adam_attr_t
&
attr
)
{
os
<<
"beta1["
<<
attr
.
beta1
<<
"],beta2["
<<
attr
.
beta2
<<
"]"
;
return
os
;
}
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
sgd_attr_t
&
attr
)
{
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
sgd_attr_t
&
attr
)
{
os
<<
"param_height["
<<
attr
.
param_height
<<
"],param_width["
os
<<
"param_height["
<<
attr
.
param_height
<<
"],param_width["
<<
attr
.
param_width
<<
"],grad_height["
<<
attr
.
grad_height
<<
attr
.
param_width
<<
"],grad_height["
<<
attr
.
grad_height
...
...
paddle/fluid/operators/jit/kernel_base.h
浏览文件 @
ebd14743
...
@@ -24,8 +24,9 @@ namespace jit {
...
@@ -24,8 +24,9 @@ namespace jit {
typedef
enum
{
typedef
enum
{
kNone
=
0
,
kNone
=
0
,
// sort by alphabet
// sort by alphabet
kCRFDecoding
=
1
,
kAdam
=
1
,
kEmbSeqPool
=
2
,
kCRFDecoding
,
kEmbSeqPool
,
kGRUH1
,
kGRUH1
,
kGRUHtPart1
,
kGRUHtPart1
,
kGRUHtPart2
,
kGRUHtPart2
,
...
@@ -269,6 +270,21 @@ struct SgdTuple {
...
@@ -269,6 +270,21 @@ struct SgdTuple {
const
sgd_attr_t
*
);
const
sgd_attr_t
*
);
};
};
typedef
struct
adam_attr_s
{
float
beta1
,
beta2
;
adam_attr_s
()
=
default
;
explicit
adam_attr_s
(
float
beta1
,
float
beta2
)
:
beta1
(
beta1
),
beta2
(
beta2
)
{}
}
adam_attr_t
;
template
<
typename
T
>
struct
AdamTuple
{
static
constexpr
KernelType
kernel_type
=
kAdam
;
typedef
T
data_type
;
typedef
adam_attr_t
attr_type
;
typedef
void
(
*
func_type
)(
T
,
T
,
T
,
T
,
int64_t
,
const
T
*
,
const
T
*
,
const
T
*
,
const
T
*
,
T
*
,
T
*
,
T
*
);
};
typedef
struct
matmul_attr_s
{
typedef
struct
matmul_attr_s
{
int
m
,
n
,
k
;
int
m
,
n
,
k
;
void
*
packed_weight
{
nullptr
};
void
*
packed_weight
{
nullptr
};
...
...
paddle/fluid/operators/jit/kernel_key.cc
浏览文件 @
ebd14743
...
@@ -63,6 +63,11 @@ int64_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) {
...
@@ -63,6 +63,11 @@ int64_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) {
return
attr
.
grad_width
;
return
attr
.
grad_width
;
}
}
template
<
>
int64_t
JitCodeKey
<
adam_attr_t
>
(
const
adam_attr_t
&
attr
)
{
return
static_cast
<
int64_t
>
(
attr
.
beta1
+
attr
.
beta2
);
}
}
// namespace jit
}
// namespace jit
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/jit/refer/CMakeLists.txt
浏览文件 @
ebd14743
...
@@ -36,5 +36,6 @@ USE_JITKERNEL_REFER(kHMax)
...
@@ -36,5 +36,6 @@ USE_JITKERNEL_REFER(kHMax)
USE_JITKERNEL_REFER
(
kStrideASum
)
USE_JITKERNEL_REFER
(
kStrideASum
)
USE_JITKERNEL_REFER
(
kSoftmax
)
USE_JITKERNEL_REFER
(
kSoftmax
)
USE_JITKERNEL_REFER
(
kEmbSeqPool
)
USE_JITKERNEL_REFER
(
kEmbSeqPool
)
USE_JITKERNEL_REFER
(
kAdam
)
USE_JITKERNEL_REFER
(
kSgd
)
USE_JITKERNEL_REFER
(
kSgd
)
USE_JITKERNEL_REFER
(
kVBroadcast
)
USE_JITKERNEL_REFER
(
kVBroadcast
)
paddle/fluid/operators/jit/refer/refer.cc
浏览文件 @
ebd14743
...
@@ -55,6 +55,7 @@ REGISTER_REFER_KERNEL(HSum);
...
@@ -55,6 +55,7 @@ REGISTER_REFER_KERNEL(HSum);
REGISTER_REFER_KERNEL
(
StrideASum
);
REGISTER_REFER_KERNEL
(
StrideASum
);
REGISTER_REFER_KERNEL
(
Softmax
);
REGISTER_REFER_KERNEL
(
Softmax
);
REGISTER_REFER_KERNEL
(
EmbSeqPool
);
REGISTER_REFER_KERNEL
(
EmbSeqPool
);
REGISTER_REFER_KERNEL
(
Adam
);
REGISTER_REFER_KERNEL
(
Sgd
);
REGISTER_REFER_KERNEL
(
Sgd
);
REGISTER_REFER_KERNEL
(
VBroadcast
);
REGISTER_REFER_KERNEL
(
VBroadcast
);
...
...
paddle/fluid/operators/jit/refer/refer.h
浏览文件 @
ebd14743
...
@@ -552,6 +552,19 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
...
@@ -552,6 +552,19 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
}
}
}
}
template
<
typename
T
>
void
Adam
(
T
beta1
,
T
beta2
,
T
lr
,
T
eps
,
int64_t
numel
,
const
T
*
grad_ptr
,
const
T
*
mom1_ptr
,
const
T
*
mom2_ptr
,
const
T
*
param_ptr
,
T
*
mom1_out_ptr
,
T
*
mom2_out_ptr
,
T
*
param_out_ptr
)
{
for
(
int
i
=
0
;
i
<
numel
;
++
i
)
{
mom1_out_ptr
[
i
]
=
beta1
*
mom1_ptr
[
i
]
+
(
1
-
beta1
)
*
grad_ptr
[
i
];
mom2_out_ptr
[
i
]
=
beta2
*
mom2_ptr
[
i
]
+
(
1
-
beta2
)
*
grad_ptr
[
i
]
*
grad_ptr
[
i
];
param_out_ptr
[
i
]
=
param_ptr
[
i
]
+
lr
*
(
mom1_out_ptr
[
i
]
/
(
sqrt
(
mom2_out_ptr
[
i
])
+
eps
));
}
}
#define DECLARE_REFER_KERNEL(name) \
#define DECLARE_REFER_KERNEL(name) \
template <typename T> \
template <typename T> \
class name##Kernel : public ReferKernel<name##Tuple<T>> { \
class name##Kernel : public ReferKernel<name##Tuple<T>> { \
...
@@ -603,6 +616,7 @@ DECLARE_REFER_KERNEL(SeqPool);
...
@@ -603,6 +616,7 @@ DECLARE_REFER_KERNEL(SeqPool);
DECLARE_REFER_KERNEL
(
MatMul
);
DECLARE_REFER_KERNEL
(
MatMul
);
DECLARE_REFER_KERNEL
(
Softmax
);
DECLARE_REFER_KERNEL
(
Softmax
);
DECLARE_REFER_KERNEL
(
EmbSeqPool
);
DECLARE_REFER_KERNEL
(
EmbSeqPool
);
DECLARE_REFER_KERNEL
(
Adam
);
DECLARE_REFER_KERNEL
(
Sgd
);
DECLARE_REFER_KERNEL
(
Sgd
);
DECLARE_REFER_KERNEL
(
VBroadcast
);
DECLARE_REFER_KERNEL
(
VBroadcast
);
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
ebd14743
...
@@ -841,6 +841,72 @@ void TestKernelStrideScal() {
...
@@ -841,6 +841,72 @@ void TestKernelStrideScal() {
}
}
}
}
template
<
typename
KernelTuple
,
typename
PlaceType
>
void
TestKernelAdam
()
{
using
T
=
typename
KernelTuple
::
data_type
;
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
const
T
lr
=
0.1
;
const
T
beta1
=
0.99
;
const
T
beta2
=
0.95
;
const
T
beta1_pow
=
beta1
*
beta1
;
const
T
beta2_pow
=
beta2
*
beta2
;
const
T
epsilon
=
0.000001
;
const
int64_t
numel
=
123
;
T
learning_rate
=
lr
*
(
sqrt
(
1
-
beta2_pow
)
/
(
1
-
beta1_pow
));
T
eps
=
epsilon
*
sqrt
(
1
-
beta2_pow
);
std
::
vector
<
T
>
param
(
numel
);
std
::
vector
<
T
>
grad
(
numel
);
std
::
vector
<
T
>
mom1
(
numel
);
std
::
vector
<
T
>
mom2
(
numel
);
std
::
vector
<
T
>
param_out
(
param
.
size
());
std
::
vector
<
T
>
mom1_out
(
mom1
.
size
());
std
::
vector
<
T
>
mom2_out
(
mom2
.
size
());
RandomVec
<
T
>
(
numel
,
param
.
data
(),
0.5
f
);
RandomVec
<
T
>
(
numel
,
grad
.
data
(),
0.5
f
);
RandomVec
<
T
>
(
numel
,
mom1
.
data
(),
0.5
f
);
RandomVec
<
T
>
(
numel
,
mom2
.
data
(),
0.5
f
);
auto
ref
=
jit
::
GetReferFunc
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
jit
::
adam_attr_t
attr
(
beta1
,
beta2
);
ref
(
beta1
,
beta2
,
-
learning_rate
,
eps
,
numel
,
grad
.
data
(),
mom1
.
data
(),
mom2
.
data
(),
param
.
data
(),
mom1_out
.
data
(),
mom2_out
.
data
(),
param_out
.
data
());
auto
verifier
=
[](
const
typename
KernelTuple
::
func_type
tgt
,
T
beta1
,
T
beta2
,
T
lr
,
T
eps
,
int64_t
numel
,
const
std
::
vector
<
T
>&
grad
,
const
std
::
vector
<
T
>&
mom1
,
const
std
::
vector
<
T
>&
mom2
,
const
std
::
vector
<
T
>&
param
,
const
std
::
vector
<
T
>&
ref_mom1_out
,
const
std
::
vector
<
T
>&
ref_mom2_out
,
const
std
::
vector
<
T
>&
ref_param_out
)
{
EXPECT_TRUE
(
tgt
!=
nullptr
);
EXPECT_EQ
(
param
.
size
(),
static_cast
<
size_t
>
(
numel
));
EXPECT_EQ
(
grad
.
size
(),
static_cast
<
size_t
>
(
numel
));
EXPECT_EQ
(
mom1
.
size
(),
static_cast
<
size_t
>
(
numel
));
EXPECT_EQ
(
mom2
.
size
(),
static_cast
<
size_t
>
(
numel
));
std
::
vector
<
T
>
jit_mom1_out
(
ref_mom1_out
.
size
());
std
::
vector
<
T
>
jit_mom2_out
(
ref_mom2_out
.
size
());
std
::
vector
<
T
>
jit_param_out
(
ref_param_out
.
size
());
tgt
(
beta1
,
beta2
,
-
lr
,
eps
,
numel
,
grad
.
data
(),
mom1
.
data
(),
mom2
.
data
(),
param
.
data
(),
jit_mom1_out
.
data
(),
jit_mom2_out
.
data
(),
jit_param_out
.
data
());
ExpectEQ
<
T
>
(
ref_mom1_out
.
data
(),
jit_mom1_out
.
data
(),
numel
);
ExpectEQ
<
T
>
(
ref_mom2_out
.
data
(),
jit_mom2_out
.
data
(),
numel
);
ExpectEQ
<
T
>
(
ref_param_out
.
data
(),
jit_param_out
.
data
(),
numel
);
};
TestAllImpls
<
KernelTuple
,
PlaceType
>
(
attr
,
verifier
,
beta1
,
beta2
,
learning_rate
,
eps
,
numel
,
grad
,
mom1
,
mom2
,
param
,
mom1_out
,
mom2_out
,
param_out
);
}
template
<
typename
KernelTuple
,
typename
PlaceType
>
template
<
typename
KernelTuple
,
typename
PlaceType
>
void
TestKernelSgd
()
{
void
TestKernelSgd
()
{
using
T
=
typename
KernelTuple
::
data_type
;
using
T
=
typename
KernelTuple
::
data_type
;
...
@@ -980,7 +1046,7 @@ TEST(JITKernel_pool, jitcreator) {
...
@@ -980,7 +1046,7 @@ TEST(JITKernel_pool, jitcreator) {
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_EQ
(
jitcreators
.
size
(),
0UL
);
EXPECT_EQ
(
jitcreators
.
size
(),
0UL
);
#else
#else
EXPECT_EQ
(
jitcreators
.
size
(),
2
5
UL
);
EXPECT_EQ
(
jitcreators
.
size
(),
2
6
UL
);
#endif
#endif
}
}
...
@@ -1014,7 +1080,7 @@ TEST(JITKernel_pool, more) {
...
@@ -1014,7 +1080,7 @@ TEST(JITKernel_pool, more) {
TEST
(
JITKernel_pool
,
refer
)
{
TEST
(
JITKernel_pool
,
refer
)
{
const
auto
&
kers
=
jit
::
ReferKernelPool
::
Instance
().
AllKernels
();
const
auto
&
kers
=
jit
::
ReferKernelPool
::
Instance
().
AllKernels
();
EXPECT_EQ
(
kers
.
size
(),
3
1
UL
);
EXPECT_EQ
(
kers
.
size
(),
3
2
UL
);
}
}
// test helper
// test helper
...
@@ -1147,9 +1213,10 @@ TEST(JITKernel_helper, attr) {
...
@@ -1147,9 +1213,10 @@ TEST(JITKernel_helper, attr) {
<<
jit
::
to_string
(
jit
::
kVExp
)
<<
jit
::
to_string
(
jit
::
kVIdentity
)
<<
jit
::
to_string
(
jit
::
kVExp
)
<<
jit
::
to_string
(
jit
::
kVIdentity
)
<<
jit
::
to_string
(
jit
::
kVMul
)
<<
jit
::
to_string
(
jit
::
kVRelu
)
<<
jit
::
to_string
(
jit
::
kVMul
)
<<
jit
::
to_string
(
jit
::
kVRelu
)
<<
jit
::
to_string
(
jit
::
kVScal
)
<<
jit
::
to_string
(
jit
::
kSgd
)
<<
jit
::
to_string
(
jit
::
kVScal
)
<<
jit
::
to_string
(
jit
::
kSgd
)
<<
jit
::
to_string
(
jit
::
kVSigmoid
)
<<
jit
::
to_string
(
jit
::
kVSquare
)
<<
jit
::
to_string
(
jit
::
kAdam
)
<<
jit
::
to_string
(
jit
::
kVSigmoid
)
<<
jit
::
to_string
(
jit
::
kVSub
)
<<
jit
::
to_string
(
jit
::
kVTanh
);
<<
jit
::
to_string
(
jit
::
kVSquare
)
<<
jit
::
to_string
(
jit
::
kVSub
)
EXPECT_EQ
(
out
.
str
().
size
(),
234UL
);
<<
jit
::
to_string
(
jit
::
kVTanh
);
EXPECT_EQ
(
out
.
str
().
size
(),
239UL
);
// SeqPoolTypes
// SeqPoolTypes
out
.
str
(
""
);
out
.
str
(
""
);
...
@@ -1296,6 +1363,19 @@ TEST(JITKernel_key, emb_seq_pool) {
...
@@ -1296,6 +1363,19 @@ TEST(JITKernel_key, emb_seq_pool) {
EXPECT_TRUE
(
key4
!=
key5
);
EXPECT_TRUE
(
key4
!=
key5
);
}
}
TEST
(
JITKernel_key
,
adam
)
{
jit
::
adam_attr_t
attr1
(
0.4
f
,
0.9
f
);
jit
::
adam_attr_t
attr2
(
0.4
f
,
0.9
f
);
jit
::
adam_attr_t
attr3
(
0.1
f
,
0.3
f
);
auto
key1
=
jit
::
JitCodeKey
<
jit
::
adam_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
adam_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
adam_attr_t
>
(
attr3
);
EXPECT_TRUE
(
key1
==
key2
);
EXPECT_TRUE
(
key2
!=
key3
);
}
TEST
(
JITKernel_key
,
sgd
)
{
TEST
(
JITKernel_key
,
sgd
)
{
jit
::
sgd_attr_t
attr1
(
1
,
2
,
3
,
4
,
5
);
jit
::
sgd_attr_t
attr1
(
1
,
2
,
3
,
4
,
5
);
jit
::
sgd_attr_t
attr2
(
1
,
2
,
3
,
4
,
5
);
jit
::
sgd_attr_t
attr2
(
1
,
2
,
3
,
4
,
5
);
...
@@ -1316,7 +1396,7 @@ TEST(JITKernel_key, sgd) {
...
@@ -1316,7 +1396,7 @@ TEST(JITKernel_key, sgd) {
EXPECT_TRUE
(
key4
!=
key5
);
EXPECT_TRUE
(
key4
!=
key5
);
}
}
// test kerne
r
ls
// test kernels
#define TestKernelVMul TestKernelXYZN
#define TestKernelVMul TestKernelXYZN
#define TestKernelVAdd TestKernelXYZN
#define TestKernelVAdd TestKernelXYZN
#define TestKernelVAddRelu TestKernelXYZN
#define TestKernelVAddRelu TestKernelXYZN
...
@@ -1383,6 +1463,7 @@ TEST_CPU_KERNEL(SeqPool);
...
@@ -1383,6 +1463,7 @@ TEST_CPU_KERNEL(SeqPool);
TEST_CPU_KERNEL
(
EmbSeqPool
);
TEST_CPU_KERNEL
(
EmbSeqPool
);
TEST_CPU_KERNEL
(
MatMul
);
TEST_CPU_KERNEL
(
MatMul
);
TEST_CPU_KERNEL
(
Softmax
);
TEST_CPU_KERNEL
(
Softmax
);
TEST_CPU_KERNEL
(
Adam
);
TEST_CPU_KERNEL
(
Sgd
);
TEST_CPU_KERNEL
(
Sgd
);
TEST_CPU_KERNEL
(
VBroadcast
);
TEST_CPU_KERNEL
(
VBroadcast
);
...
...
paddle/fluid/operators/optimizers/adam_op.h
浏览文件 @
ebd14743
...
@@ -20,9 +20,11 @@ limitations under the License. */
...
@@ -20,9 +20,11 @@ limitations under the License. */
#include <vector>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -506,21 +508,58 @@ class AdamOpKernel : public framework::OpKernel<T> {
...
@@ -506,21 +508,58 @@ class AdamOpKernel : public framework::OpKernel<T> {
beta2_pow_out
->
numel
()));
beta2_pow_out
->
numel
()));
if
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
())
{
if
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
*
grad
=
ctx
.
Input
<
LoDTensor
>
(
"Grad"
);
T
beta1_p
=
beta1_pow
->
data
<
T
>
()[
0
];
T
beta2_p
=
beta2_pow
->
data
<
T
>
()[
0
];
AdamFunctor
<
T
,
CPUAdam
>
functor
(
beta1
,
beta2
,
epsilon
,
beta1_pow
->
data
<
T
>
(),
beta2_pow
->
data
<
T
>
(),
mom1
->
data
<
T
>
(),
mom1_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
mom2
->
data
<
T
>
(),
mom2_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
lr
->
data
<
T
>
(),
grad
->
data
<
T
>
(),
param
->
data
<
T
>
(),
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
functor
(
param
->
numel
());
if
(
!
use_global_beta_pow
)
{
if
(
!
use_global_beta_pow
)
{
beta1_pow_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
())[
0
]
=
beta1_pow_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
())[
0
]
=
beta1
*
beta1_pow
->
data
<
T
>
()[
0
];
beta1
*
beta1_pow
->
data
<
T
>
()[
0
];
beta2_pow_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
())[
0
]
=
beta2_pow_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
())[
0
]
=
beta2
*
beta2_pow
->
data
<
T
>
()[
0
];
beta2
*
beta2_pow
->
data
<
T
>
()[
0
];
}
}
auto
*
grad
=
ctx
.
Input
<
LoDTensor
>
(
"Grad"
);
T
*
param_out_ptr
=
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
mom1_out_ptr
=
mom1_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
mom2_out_ptr
=
mom2_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
learning_rate
=
lr
->
data
<
T
>
()[
0
]
*
(
sqrt
(
1
-
beta2_p
)
/
(
1
-
beta1_p
));
T
eps
=
epsilon
*
sqrt
(
1
-
beta2_p
);
jit
::
adam_attr_t
attr
(
beta1
,
beta2
);
int64_t
numel
=
param
->
numel
();
const
T
*
param_ptr
=
param
->
data
<
T
>
();
const
T
*
mom1_ptr
=
mom1
->
data
<
T
>
();
const
T
*
mom2_ptr
=
mom2
->
data
<
T
>
();
const
T
*
grad_ptr
=
grad
->
data
<
T
>
();
auto
adam
=
jit
::
KernelFuncs
<
jit
::
AdamTuple
<
T
>
,
platform
::
CPUPlace
>::
Cache
().
At
(
attr
);
static
constexpr
int64_t
chunk_size
=
512
;
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for
(
int64_t
i
=
0
;
i
<
numel
/
chunk_size
;
++
i
)
{
const
int64_t
offset
=
i
*
chunk_size
;
adam
(
beta1
,
beta2
,
-
learning_rate
,
eps
,
chunk_size
,
grad_ptr
+
offset
,
mom1_ptr
+
offset
,
mom2_ptr
+
offset
,
param_ptr
+
offset
,
mom1_out_ptr
+
offset
,
mom2_out_ptr
+
offset
,
param_out_ptr
+
offset
);
}
if
(
numel
%
chunk_size
!=
0
)
{
const
int64_t
offset
=
(
numel
/
chunk_size
)
*
chunk_size
;
const
int64_t
tail_numel
=
numel
%
chunk_size
;
adam
(
beta1
,
beta2
,
-
learning_rate
,
eps
,
tail_numel
,
grad_ptr
+
offset
,
mom1_ptr
+
offset
,
mom2_ptr
+
offset
,
param_ptr
+
offset
,
mom1_out_ptr
+
offset
,
mom2_out_ptr
+
offset
,
param_out_ptr
+
offset
);
}
}
else
if
(
grad_var
->
IsType
<
pten
::
SelectedRows
>
())
{
}
else
if
(
grad_var
->
IsType
<
pten
::
SelectedRows
>
())
{
auto
*
grad
=
ctx
.
Input
<
pten
::
SelectedRows
>
(
"Grad"
);
auto
*
grad
=
ctx
.
Input
<
pten
::
SelectedRows
>
(
"Grad"
);
if
(
grad
->
rows
().
size
()
==
0
)
{
if
(
grad
->
rows
().
size
()
==
0
)
{
...
...
python/paddle/fluid/tests/unittests/test_adam_op.py
浏览文件 @
ebd14743
...
@@ -69,15 +69,19 @@ class TestAdamOp1(OpTest):
...
@@ -69,15 +69,19 @@ class TestAdamOp1(OpTest):
class
TestAdamOp2
(
OpTest
):
class
TestAdamOp2
(
OpTest
):
def
set_shape
(
self
):
self
.
shape
=
(
102
,
105
)
def
setUp
(
self
):
def
setUp
(
self
):
'''Test Adam Op with supplied attributes
'''Test Adam Op with supplied attributes
'''
'''
self
.
op_type
=
"adam"
self
.
op_type
=
"adam"
param
=
np
.
random
.
uniform
(
-
1
,
1
,
(
102
,
105
)).
astype
(
"float32"
)
self
.
set_shape
()
grad
=
np
.
random
.
uniform
(
-
1
,
1
,
(
102
,
105
)).
astype
(
"float32"
)
param
=
np
.
random
.
uniform
(
-
1
,
1
,
self
.
shape
).
astype
(
"float32"
)
moment1
=
np
.
random
.
uniform
(
-
1
,
1
,
(
102
,
105
)).
astype
(
"float32"
)
grad
=
np
.
random
.
uniform
(
-
1
,
1
,
self
.
shape
).
astype
(
"float32"
)
moment1
=
np
.
random
.
uniform
(
-
1
,
1
,
self
.
shape
).
astype
(
"float32"
)
# The second moment is positive
# The second moment is positive
moment2
=
np
.
random
.
random
(
(
102
,
105
)
).
astype
(
"float32"
)
moment2
=
np
.
random
.
random
(
self
.
shape
).
astype
(
"float32"
)
learning_rate
=
0.001
learning_rate
=
0.001
beta1
=
0.9
beta1
=
0.9
...
@@ -113,6 +117,11 @@ class TestAdamOp2(OpTest):
...
@@ -113,6 +117,11 @@ class TestAdamOp2(OpTest):
self
.
check_output
()
self
.
check_output
()
class
TestAdamOnlyTailOp
(
TestAdamOp2
):
def
set_shape
(
self
):
self
.
shape
=
(
3
)
class
TestAdamOpMultipleSteps
(
OpTest
):
class
TestAdamOpMultipleSteps
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
'''Test Adam Operator with supplied attributes
'''Test Adam Operator with supplied attributes
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录