Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b4474fb4
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b4474fb4
编写于
3年前
作者:
R
Roc
提交者:
GitHub
3年前
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[NPU]Adamw skip update for npu (#34897)
上级
1ef21855
develop
Ligoml-patch-1
ZHUI-patch-1
add_some_yaml_config
ascendrelease
cherry_undefined_var
cp_2.4_fix_numpy
delete_disable_iterable_dataset_unittest
delete_fix_retry_ci
delete_fix_undefined_var
delete_revert-34910-spinlocks_for_allocator
delete_revert-35069-revert-34910-spinlocks_for_allocator
delete_revert-36057-dev/read_flags_in_ut
dingjiaweiww-patch-1
disable_iterable_dataset_unittest
dy2static
enable_eager_model_test
final_state_gen_python_c
final_state_intermediate
fix-numpy-issue
fix_concat_slice
fix_dlpack_for
fix_npu_ci
fix_op_flops
fix_retry_ci
fix_rnn_docs
fix_tensor_type
fix_undefined_var
fix_var_stop_gradient_error
incubate/frl_train_eval
incubate/infrt
inplace_addto
layer_norm
make_flag_adding_easier
matmul_double_grad
move_embedding_to_phi
move_histogram_to_pten
move_sgd_to_phi
move_slice_to_pten
move_temporal_shift_to_phi
move_yolo_box_to_phi
npu_fix_alloc
preln_ernie
prv-md-even-more
prv-onednn-2.5
prv-reshape-mkldnn-ut2
pten_tensor_refactor
release/2.2
release/2.3
release/2.3-fc-ernie-fix
release/2.4
revert-34406-add_copy_from_tensor
revert-34910-spinlocks_for_allocator
revert-35069-revert-34910-spinlocks_for_allocator
revert-36057-dev/read_flags_in_ut
revert-36201-refine_fast_threaded_ssa_graph_executor
revert-36985-add_license
revert-37318-refactor_dygraph_to_eager
revert-37926-eager_coreops_500
revert-37956-revert-37727-pylayer_support_tuple
revert-38100-mingdong
revert-38301-allocation_rearrange_pr
revert-38703-numpy_bf16_package_reupload
revert-38732-remove_useless_header_in_elementwise_mul_grad
revert-38959-Reduce_Grad
revert-39143-adjust_empty
revert-39227-move_trace_op_to_pten
revert-39268-dev/remove_concat_fluid_kernel
revert-40170-support_partial_grad
revert-41056-revert-40727-move_some_activaion_to_phi
revert-41065-revert-40993-mv_ele_floordiv_pow
revert-41068-revert-40790-phi_new
revert-41944-smaller_inference_api_test
revert-42149-do-not-reset-default-stream-for-stream-safe-cuda-allocator
revert-43155-fix_ut_tempfile
revert-43882-revert-41944-smaller_inference_api_test
revert-45808-phi/simplify_size_op
revert-46827-deform_comment
revert-47325-remove_cudnn_hardcode
revert-47645-add_npu_storage_dims
revert-48815-set_free_when_no_cache_hit_default_value_true
revert-49654-prim_api_gen
revert-49763-fix_static_composite_gen
support-0D-sort
support_weight_transpose
test_for_Filtetfiles
zhiqiu-patch-1
v2.4.1
v2.4.0
v2.4.0-rc0
v2.3.2
v2.3.1
v2.3.0
v2.3.0-rc0
v2.2.2
v2.2.1
v2.2.0
v2.2.0-rc0
v2.2.0-bak0
无相关合并请求
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
645 addition
and
15 deletion
+645
-15
paddle/fluid/operators/optimizers/adam_op.cc
paddle/fluid/operators/optimizers/adam_op.cc
+20
-0
paddle/fluid/operators/optimizers/adam_op_npu.cc
paddle/fluid/operators/optimizers/adam_op_npu.cc
+76
-0
paddle/fluid/operators/optimizers/adamw_op.cc
paddle/fluid/operators/optimizers/adamw_op.cc
+20
-0
paddle/fluid/operators/optimizers/adamw_op.h
paddle/fluid/operators/optimizers/adamw_op.h
+105
-0
python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
...distributed/fleet/meta_optimizers/sharding/fp16_helper.py
+2
-3
python/paddle/fluid/contrib/mixed_precision/decorator.py
python/paddle/fluid/contrib/mixed_precision/decorator.py
+10
-4
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+2
-6
python/paddle/fluid/tests/unittests/npu/test_adamw_op_npu.py
python/paddle/fluid/tests/unittests/npu/test_adamw_op_npu.py
+250
-0
python/paddle/fluid/tests/unittests/test_adam_op.py
python/paddle/fluid/tests/unittests/test_adam_op.py
+39
-0
python/paddle/optimizer/adamw.py
python/paddle/optimizer/adamw.py
+121
-2
未找到文件。
paddle/fluid/operators/optimizers/adam_op.cc
浏览文件 @
b4474fb4
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/optimizers/adam_op.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/optimizers/adamw_op.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -230,11 +231,30 @@ $$
)DOC"
);
}
};
class
AdamWOpMaker
:
public
AdamOpMaker
{
public:
void
Make
()
{
AdamOpMaker
::
Make
();
AddAttr
<
float
>
(
"coeff"
,
"(float, default 0.01) "
"coeff of the weight decay"
)
.
SetDefault
(
0.01
f
);
AddAttr
<
bool
>
(
"with_decay"
,
"(bool, default false) "
"whether to do weight decay"
)
.
SetDefault
(
false
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
adam
,
ops
::
AdamOp
,
ops
::
AdamOpMaker
);
REGISTER_OP_WITHOUT_GRADIENT
(
adamw
,
ops
::
AdamWOp
,
ops
::
AdamWOpMaker
);
REGISTER_OP_CPU_KERNEL
(
adam
,
ops
::
AdamOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
AdamOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
...
...
This diff is collapsed.
Click to expand it.
paddle/fluid/operators/optimizers/adam_op_npu.cc
浏览文件 @
b4474fb4
...
...
@@ -225,6 +225,79 @@ class AdamNPUKernel : public framework::OpKernel<T> {
}
};
/**
 * NPU kernel of the `adamw` operator.
 *
 * Reads the optional "SkipUpdate" input (a single bool) and the "with_decay"
 * attribute. When the update is not skipped and decay is requested, Param is
 * multiplied in place by (1 - LearningRate * coeff). In every case
 * AdamNPUKernel::Compute() is then run on the same execution context.
 *
 * Raises InvalidArgument when SkipUpdate does not hold exactly one element or
 * Param is not a LoDTensor, and Unimplemented when MasterParam is supplied.
 *
 * @tparam T element type used for ParamOut.
 */
template <typename T>
class AdamWNPUKernel : public AdamNPUKernel<platform::NPUDeviceContext, T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    VLOG(3) << "NPU AdamW Kernel";

    bool skip_update = false;
    if (ctx.HasInput("SkipUpdate")) {
      VLOG(3) << "Has SkipUpdate";
      auto* skip_update_tensor = ctx.Input<framework::Tensor>("SkipUpdate");
      PADDLE_ENFORCE_EQ(skip_update_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(SkipUpdate) size must be 1, but get %d",
                            skip_update_tensor->numel()));
      // Copy the flag into a host-side vector so it can be read here.
      std::vector<bool> skip_update_vec;
      TensorToVector(*skip_update_tensor, ctx.device_context(),
                     &skip_update_vec);
      skip_update = skip_update_vec[0];
    }
    VLOG(3) << "Skip update: " << skip_update;

    bool with_decay = ctx.Attr<bool>("with_decay");
    if (!skip_update && with_decay) {
      // Reject the unsupported configuration before any NPU op is launched,
      // so no work is queued for a step that is going to fail anyway.
      if (ctx.HasInput("MasterParam")) {
        PADDLE_THROW(platform::errors::Unimplemented(
            "Master Param is not supported on npu"));
      }

      float coeff = ctx.Attr<float>("coeff");
      auto* lr = ctx.Input<LoDTensor>("LearningRate");

      auto place = ctx.GetPlace();
      auto stream =
          ctx.template device_context<paddle::platform::NPUDeviceContext>()
              .stream();

      // One-element FP32 scratch tensors used to build the decay factor.
      Tensor one(framework::proto::VarType::FP32);
      Tensor decay(framework::proto::VarType::FP32);
      Tensor tmp(framework::proto::VarType::FP32);
      tmp.mutable_data<float>({1}, place);
      one.mutable_data<float>({1}, place);
      decay.mutable_data<float>({1}, place);
      FillNpuTensorWithConstant<float>(&one, 1.0f);

      // tmp = LearningRate * coeff
      framework::NPUAttributeMap attr_input = {{"value", coeff}};
      const auto& runner1 = NpuOpRunner("Muls", {*lr}, {tmp}, attr_input);
      runner1.Run(stream);

      // decay = 1 - tmp
      const auto& runner2 = NpuOpRunner("Sub", {one, tmp}, {decay}, {});
      runner2.Run(stream);

      auto* param_out = ctx.Output<LoDTensor>("ParamOut");
      param_out->mutable_data<T>(ctx.GetPlace());

      const auto* param_var = ctx.InputVar("Param");
      PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
                        platform::errors::InvalidArgument(
                            "The Var(%s)'s type should be LoDTensor, "
                            "but the received is %s",
                            ctx.InputNames("Param").front(),
                            framework::ToTypeName(param_var->Type())));
      auto* param = ctx.Input<LoDTensor>("Param");

      // Param = Param * decay, written back into the input tensor itself
      // (hence the const_cast) so the base kernel sees the decayed value.
      // NOTE(review): `decay` is FP32 while Param may be float16 in the
      // float16 instantiation; confirm the NPU "Mul" op accepts that mix.
      const auto& runner =
          NpuOpRunner("Mul", {*param, decay},
                      {*const_cast<framework::LoDTensor*>(param)}, {});
      runner.Run(stream);
    }

    // Run the ordinary adam update (it is handed the same ctx, including
    // SkipUpdate when present).
    AdamNPUKernel<platform::NPUDeviceContext, T>::Compute(ctx);
  }
};
}
// namespace operators
}
// namespace paddle
...
...
@@ -234,3 +307,6 @@ REGISTER_OP_NPU_KERNEL(
adam
,
ops
::
AdamNPUKernel
<
paddle
::
platform
::
NPUDeviceContext
,
float
>
,
ops
::
AdamNPUKernel
<
paddle
::
platform
::
NPUDeviceContext
,
paddle
::
platform
::
float16
>
);
// Register the NPU kernels of the adamw operator: one instantiation of
// AdamWNPUKernel for float and one for paddle::platform::float16.
REGISTER_OP_NPU_KERNEL(adamw, ops::AdamWNPUKernel<float>,
                       ops::AdamWNPUKernel<paddle::platform::float16>);
This diff is collapsed.
Click to expand it.
paddle/fluid/operators/optimizers/adamw_op.cc
0 → 100644
浏览文件 @
b4474fb4
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/fluid/operators/optimizers/adamw_op.h>
// Short alias for the operators namespace, used by the registration below.
namespace ops = paddle::operators;
// Register the CPU kernels of the adamw operator: AdamWOpKernel on
// CPUDeviceContext, instantiated for float and for double.
REGISTER_OP_CPU_KERNEL(
    adamw, ops::AdamWOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::AdamWOpKernel<paddle::platform::CPUDeviceContext, double>);
This diff is collapsed.
Click to expand it.
paddle/fluid/operators/optimizers/adamw_op.h
0 → 100644
浏览文件 @
b4474fb4
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <paddle/fluid/operators/optimizers/adam_op.h>
namespace
paddle
{
namespace
operators
{
// Operator class of adamw: derives from AdamOp and inherits its constructors
// without adding any member of its own.
class AdamWOp : public AdamOp {
  using AdamOp::AdamOp;
};
// Tag type selecting the CPU flavour of AdamWFunctor.
struct CPUAdamW;

// Primary template; only the specialisations are defined.
template <typename T, typename Flavour>
class AdamWFunctor;

/**
 * CPU weight-decay functor: scales a parameter buffer in place by
 * (1 - learning_rate * coeff).
 *
 * The functor keeps copies of the two scalars and a raw, non-owning pointer
 * to the parameter data supplied at construction.
 */
template <typename T>
class AdamWFunctor<T, CPUAdamW> {
 public:
  AdamWFunctor(const float& coeff, const float& learning_rate, T* param)
      : coeff_(coeff), learning_rate_(learning_rate), param_(param) {}

  // Applies the decay to the first `numel` elements behind the stored pointer.
  inline HOSTDEVICE void operator()(size_t numel) const {
    // View the raw buffer as a 1 x numel Eigen array; no data is copied.
    Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> weights{
        param_, static_cast<Eigen::Index>(numel)};
    const float keep_ratio = 1.0f - learning_rate_ * coeff_;
    weights *= keep_ratio;
  }

 private:
  const float coeff_;
  const float learning_rate_;
  T* param_;
};
/**
 * CPU kernel of the `adamw` operator.
 *
 * Validates that Param is a LoDTensor, reads the optional "SkipUpdate" input
 * (a single bool) and the "with_decay" attribute. When the update is not
 * skipped and decay is requested, the parameter (MasterParam if that input is
 * present, otherwise Param) is scaled in place by
 * (1 - LearningRate * coeff) through AdamWFunctor. In every case
 * AdamOpKernel::Compute() is then run on the same execution context.
 *
 * Raises InvalidArgument when Param is not a LoDTensor or SkipUpdate does not
 * hold exactly one element.
 */
template <typename DeviceContext, typename T>
class AdamWOpKernel : public AdamOpKernel<DeviceContext, T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* param_var = ctx.InputVar("Param");
    PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
                      platform::errors::InvalidArgument(
                          "The Var(%s)'s type should be LoDTensor, "
                          "but the received is %s",
                          ctx.InputNames("Param").front(),
                          framework::ToTypeName(param_var->Type())));

    using paddle::framework::LoDTensor;

    bool skip_update = false;
    if (ctx.HasInput("SkipUpdate")) {
      VLOG(3) << "Has SkipUpdate";
      auto* skip_update_tensor = ctx.Input<framework::Tensor>("SkipUpdate");
      PADDLE_ENFORCE_EQ(skip_update_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(SkipUpdate) size must be 1, but get %d",
                            skip_update_tensor->numel()));
      // Copy the flag into a host-side vector so it can be read here.
      std::vector<bool> skip_update_vec;
      TensorToVector(*skip_update_tensor, ctx.device_context(),
                     &skip_update_vec);
      skip_update = skip_update_vec[0];
    }
    VLOG(3) << "Skip update: " << skip_update;

    const bool with_decay = ctx.Attr<bool>("with_decay");
    if (!skip_update && with_decay) {
      const float coeff = ctx.Attr<float>("coeff");
      const auto* lr = ctx.Input<LoDTensor>("LearningRate");

      // Decay MasterParam when it is supplied, otherwise Param itself.
      // TODO(liupeng): master
      const char* decayed_input =
          ctx.HasInput("MasterParam") ? "MasterParam" : "Param";
      // The decay is applied in place on the input tensor, hence the
      // const_cast.
      auto* param = const_cast<LoDTensor*>(ctx.Input<LoDTensor>(decayed_input));

      // AdamWFunctor(coeff, learning_rate, param): the learning rate is passed
      // by value, read from the first element of LearningRate.
      AdamWFunctor<T, CPUAdamW> functor(coeff, *lr->data<float>(),
                                        param->data<T>());
      functor(param->numel());
    }

    // Single exit path: the ordinary adam update always runs last.
    AdamOpKernel<DeviceContext, T>::Compute(ctx);
  }
};
}
// namespace operators
}
// namespace paddle
This diff is collapsed.
Click to expand it.
python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
浏览文件 @
b4474fb4
...
...
@@ -197,7 +197,6 @@ class FP16Utils(object):
if
op
.
type
==
"update_loss_scaling"
:
update_loss_scaling_op_idx
=
idx
inf_var_name
=
op
.
desc
.
input
(
'FoundInfinite'
)[
0
]
op
.
_rename_input
(
inf_var_name
,
inf_var_name
+
"@GLOBAL_WORLD"
)
break
# not use amp
...
...
@@ -246,10 +245,10 @@ class FP16Utils(object):
update_loss_scaling_op_idx
,
type
=
'cast'
,
inputs
=
{
'X'
:
inf_var_int32
},
outputs
=
{
'Out'
:
inf_var
_global
},
outputs
=
{
'Out'
:
inf_var
},
attrs
=
{
"in_dtype"
:
inf_var_int32
.
dtype
,
"out_dtype"
:
inf_var
_global
.
dtype
,
"out_dtype"
:
inf_var
.
dtype
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
update_loss_scaling_op_idx
+=
1
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/contrib/mixed_precision/decorator.py
浏览文件 @
b4474fb4
...
...
@@ -399,12 +399,18 @@ class OptimizerWithMixedPrecision(object):
self
.
_decr_ratio
,
name
=
"update_loss_scaling"
)
# Pass found_inf to adam, to skip update for not only param, but also momentum and beta_pow
if
isinstance
(
self
.
_optimizer
,
paddle
.
fluid
.
optimizer
.
Adam
):
# With fleet, optimizers are nested and the real optimizer set by user is the inner most one.
real_optimizer
=
self
.
_optimizer
while
hasattr
(
real_optimizer
,
"inner_opt"
):
real_optimizer
=
real_optimizer
.
inner_opt
if
isinstance
(
real_optimizer
,
(
paddle
.
fluid
.
optimizer
.
Adam
,
paddle
.
optimizer
.
AdamW
)):
# NOTE(zhiqiu): Since found_inf needs to be on cpu in adam op, we
# copy it in advance to avoid multiple time copies.
found_inf
=
paddle
.
tensor
.
creation
.
_memcpy
(
found_inf
,
paddle
.
CPUPlace
())
self
.
_optimizer
.
_set_auxiliary_var
(
'found_inf'
,
found_inf
)
with
self
.
_train_program
.
_optimized_guard
([]):
found_inf
=
paddle
.
tensor
.
creation
.
_memcpy
(
found_inf
,
paddle
.
CPUPlace
())
real_optimizer
.
_set_auxiliary_var
(
'found_inf'
,
found_inf
)
optimize_ops
=
self
.
_optimizer
.
apply_gradients
(
params_grads
)
return
optimize_ops
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/optimizer.py
浏览文件 @
b4474fb4
...
...
@@ -4661,12 +4661,8 @@ class PipelineOptimizer(object):
op
.
_set_attr
(
self
.
_op_device_key
,
f
"
{
self
.
_device
}
:all"
)
else
:
other_known_ops
=
[
'update_loss_scaling'
,
'reduce_any'
,
'concat'
,
'sum'
,
'check_finite_and_unscale'
,
'alloc_float_status'
,
'update_loss_scaling'
,
'reduce_any'
,
'concat'
,
'sum'
,
'check_finite_and_unscale'
,
'alloc_float_status'
,
'memcpy'
]
assert
op
.
type
in
other_known_ops
,
"For other ops without "
\
"op_device set, they must be one of {}, but it "
\
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/tests/unittests/npu/test_adamw_op_npu.py
0 → 100644
浏览文件 @
b4474fb4
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
unittest
import
sys
sys
.
path
.
append
(
".."
)
from
op_test
import
OpTest
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
test_adam_op
import
adamw_step
paddle
.
enable_static
()
SEED
=
2021
class TestAdamW(OpTest):
    """Runs the adamw op on NPUPlace(0) with weight decay enabled and compares
    it against the numpy reference ``adamw_step``."""

    def setUp(self):
        self.set_npu()
        self.place = paddle.NPUPlace(0)
        self.op_type = "adamw"

        shape = (105, 102)
        # The draw order (param, grad, moment1, moment2) is kept so the random
        # stream is consumed exactly as before.
        param = np.random.uniform(-1, 1, shape).astype("float32")
        grad = np.random.uniform(-1, 1, shape).astype("float32")
        moment1 = np.random.uniform(-1, 1, shape).astype("float32")
        # The second moment is positive
        moment2 = np.random.random(shape).astype("float32")

        learning_rate, epsilon = 0.5, 1e-4
        beta1, beta2 = 0.78, 0.836
        beta1_pow, beta2_pow = beta1**10, beta2**10

        def as_f32(value):
            # Wrap a python scalar into a one-element float32 array.
            return np.array([value]).astype("float32")

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment1': moment1,
            'Moment2': moment2,
            'LearningRate': as_f32(learning_rate),
            'Beta1Pow': as_f32(beta1_pow),
            'Beta2Pow': as_f32(beta2_pow),
        }

        self.attrs = {
            'epsilon': epsilon,
            'beta1': beta1,
            'beta2': beta2,
            "coeff": 0.9,
            "with_decay": True,
        }

        # Expected tensors come from the numpy reference implementation.
        expected = adamw_step(self.inputs, self.attrs)
        param_out, moment1_out, moment2_out = expected

        self.outputs = {
            'Moment1Out': moment1_out,
            'Moment2Out': moment2_out,
            'ParamOut': param_out,
            'Beta1PowOut': as_f32(beta1_pow) * beta1,
            'Beta2PowOut': as_f32(beta2_pow) * beta2,
        }

    def set_npu(self):
        # Flag the whole test class as an NPU test.
        self.__class__.use_npu = True

    def init_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output_with_place(self.place, atol=1e-5)
class TestAdamOpWithSkipUpdate(OpTest):
    """Runs the adamw op on NPUPlace(0) with SkipUpdate=True and expects every
    output to equal the corresponding input."""

    def setUp(self):
        self.set_npu()
        self.place = paddle.NPUPlace(0)
        self.op_type = "adamw"

        shape = (102, 105)
        # The draw order (param, grad, moment1, moment2) is kept so the random
        # stream is consumed exactly as before.
        param = np.random.uniform(-1, 1, shape).astype("float32")
        grad = np.random.uniform(-1, 1, shape).astype("float32")
        moment1 = np.random.uniform(-1, 1, shape).astype("float32")
        # The second moment is positive
        moment2 = np.random.random(shape).astype("float32")

        learning_rate, epsilon = 0.004, 1e-4
        beta1, beta2 = 0.78, 0.836
        beta1_pow, beta2_pow = beta1**10, beta2**10

        def as_f32(value):
            # Wrap a python scalar into a one-element float32 array.
            return np.array([value]).astype("float32")

        beta1_pow_in = as_f32(beta1_pow)
        beta2_pow_in = as_f32(beta2_pow)

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment1': moment1,
            'Moment2': moment2,
            'LearningRate': as_f32(learning_rate),
            'Beta1Pow': beta1_pow_in,
            'Beta2Pow': beta2_pow_in,
            'Beta1Tensor': as_f32(beta1),
            'Beta2Tensor': as_f32(beta2),
            'EpsilonTensor': as_f32(epsilon),
            "SkipUpdate": np.array([True]).astype("bool"),
        }

        self.attrs = {'epsilon': epsilon, "coeff": 0.02, "with_decay": True}

        # With the update skipped, outputs are the very same arrays as inputs.
        self.outputs = {
            'Moment1Out': moment1,
            'Moment2Out': moment2,
            'ParamOut': param,
            'Beta1PowOut': beta1_pow_in,
            'Beta2PowOut': beta2_pow_in,
        }

    def set_npu(self):
        # Flag the whole test class as an NPU test.
        self.__class__.use_npu = True

    def init_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output_with_place(self.place, atol=1e-5)
class TestAdamOpWithoutDecay(OpTest):
    """Runs the adamw op on NPUPlace(0) with with_decay=False and
    SkipUpdate=True and expects every output to equal the corresponding input.

    NOTE(review): because SkipUpdate is True here, the expected outputs are the
    untouched inputs; the with_decay=False update path is presumably not what
    this case exercises -- confirm whether SkipUpdate was meant to be set.
    """

    def setUp(self):
        self.set_npu()
        self.place = paddle.NPUPlace(0)
        self.op_type = "adamw"

        shape = (102, 105)
        # The draw order (param, grad, moment1, moment2) is kept so the random
        # stream is consumed exactly as before.
        param = np.random.uniform(-1, 1, shape).astype("float32")
        grad = np.random.uniform(-1, 1, shape).astype("float32")
        moment1 = np.random.uniform(-1, 1, shape).astype("float32")
        # The second moment is positive
        moment2 = np.random.random(shape).astype("float32")

        learning_rate, epsilon = 0.004, 1e-4
        beta1, beta2 = 0.78, 0.836
        beta1_pow, beta2_pow = beta1**10, beta2**10

        def as_f32(value):
            # Wrap a python scalar into a one-element float32 array.
            return np.array([value]).astype("float32")

        beta1_pow_in = as_f32(beta1_pow)
        beta2_pow_in = as_f32(beta2_pow)

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment1': moment1,
            'Moment2': moment2,
            'LearningRate': as_f32(learning_rate),
            'Beta1Pow': beta1_pow_in,
            'Beta2Pow': beta2_pow_in,
            'Beta1Tensor': as_f32(beta1),
            'Beta2Tensor': as_f32(beta2),
            'EpsilonTensor': as_f32(epsilon),
            "SkipUpdate": np.array([True]).astype("bool"),
        }

        self.attrs = {'epsilon': epsilon, "coeff": 0.02, "with_decay": False}

        # Outputs are the very same arrays as the inputs.
        self.outputs = {
            'Moment1Out': moment1,
            'Moment2Out': moment2,
            'ParamOut': param,
            'Beta1PowOut': beta1_pow_in,
            'Beta2PowOut': beta2_pow_in,
        }

    def set_npu(self):
        # Flag the whole test class as an NPU test.
        self.__class__.use_npu = True

    def init_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output_with_place(self.place, atol=1e-5)
class TestNet(unittest.TestCase):
    """Trains a small two-layer network with AdamW on NPU and on CPU and checks
    that predictions and loss agree within rtol=1e-3."""

    def _test(self, run_npu=True):
        """Build and train the network for 100 iterations.

        :param run_npu: run on NPUPlace(0) when True, on CPUPlace otherwise
        :return tuple: prediction and loss fetched in the last iteration
        """
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        for prog in (main_prog, startup_prog):
            prog.random_seed = SEED
        np.random.seed(SEED)

        batch_shape = (32, 32)
        a_np = np.random.random(size=batch_shape).astype('float32')
        b_np = np.random.random(size=batch_shape).astype('float32')
        label_np = np.random.randint(2, size=(32, 1)).astype('int64')

        with paddle.static.program_guard(main_prog, startup_prog):
            a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
            b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
            label = paddle.static.data(
                name="label", shape=[32, 1], dtype='int64')

            summed = paddle.add(a, b)
            z = paddle.pow(summed, 2.0)

            fc_1 = fluid.layers.fc(input=z, size=128)
            prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')

            cost = fluid.layers.cross_entropy(input=prediction, label=label)
            loss = fluid.layers.reduce_mean(cost)
            adam = paddle.optimizer.AdamW(
                learning_rate=0.01, weight_decay=0.02)
            adam.minimize(loss)

        place = paddle.NPUPlace(0) if run_npu else paddle.CPUPlace()

        exe = paddle.static.Executor(place)
        exe.run(startup_prog)

        print("Start run on {}".format(place))
        # The same batch is fed on every iteration.
        feed = {"a": a_np, "b": b_np, "label": label_np}
        for epoch in range(100):
            pred_res, loss_res = exe.run(
                main_prog, feed=feed, fetch_list=[prediction, loss])
            if epoch % 10 != 0:
                continue
            print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
                epoch, pred_res[0], loss_res))

        return pred_res, loss_res

    def test_npu(self):
        npu_pred, npu_loss = self._test(True)
        cpu_pred, cpu_loss = self._test(False)

        for npu_value, cpu_value in ((npu_pred, cpu_pred),
                                     (npu_loss, cpu_loss)):
            self.assertTrue(np.allclose(npu_value, cpu_value, rtol=1e-3))
if
__name__
==
'__main__'
:
unittest
.
main
()
This diff is collapsed.
Click to expand it.
python/paddle/fluid/tests/unittests/test_adam_op.py
浏览文件 @
b4474fb4
...
...
@@ -215,6 +215,45 @@ def adam_step(inputs, attributes):
return
param_out
,
moment1_out
,
moment2_out
def adamw_step(inputs, attributes):
    '''
    Simulate one step of the adamw optimizer
    :param inputs: dict of inputs; reads 'Param', 'Grad', 'Moment1',
        'Moment2', 'LearningRate', 'Beta1Pow', 'Beta2Pow' and, when the
        matching attribute is absent, 'Beta1Tensor' / 'Beta2Tensor'
    :param attributes: dict of attributes; reads 'epsilon', optional
        'with_decay' (default False), 'coeff' (only when with_decay is true)
        and optional 'beta1' / 'beta2'
    :return tuple: tuple of output param, moment1, moment2
    '''
    param = inputs['Param']
    grad = inputs['Grad']
    moment1 = inputs['Moment1']
    moment2 = inputs['Moment2']
    lr = inputs['LearningRate']
    beta1_pow = inputs['Beta1Pow']
    beta2_pow = inputs['Beta2Pow']

    epsilon = attributes['epsilon']

    if attributes.get("with_decay", False):
        # 'coeff' is looked up only when decay is applied, so a plain adam
        # attribute dict (no 'coeff') works too. The product is a new array,
        # so the caller's 'Param' is left untouched.
        param = param * (1.0 - lr * attributes["coeff"])

    # Attributes take precedence over the tensor inputs for beta1 / beta2.
    if 'beta1' in attributes:
        beta1 = attributes['beta1']
    else:
        beta1 = inputs['Beta1Tensor'][0]
    if 'beta2' in attributes:
        beta2 = attributes['beta2']
    else:
        beta2 = inputs['Beta2Tensor'][0]

    moment1_out = beta1 * moment1 + (1 - beta1) * grad
    moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
    # Bias-corrected step size.
    lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
    param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon))
    return param_out, moment1_out, moment2_out
def
adam_step_sparse
(
inputs
,
attributes
,
height
,
rows
,
row_numel
,
np_grad
,
lazy_mode
):
'''
...
...
This diff is collapsed.
Click to expand it.
python/paddle/optimizer/adamw.py
浏览文件 @
b4474fb4
...
...
@@ -16,9 +16,12 @@ from .optimizer import Optimizer
from
.adam
import
Adam
from
..fluid
import
core
from
..fluid
import
framework
from
..fluid.framework
import
Variable
from
..fluid.dygraph
import
base
as
imperative_base
import
paddle
_C_ops
=
core
.
ops
__all__
=
[]
...
...
@@ -173,6 +176,23 @@ class AdamW(Adam):
multi_precision
=
multi_precision
)
self
.
_default_dict
=
{
'coeff'
:
coeff
}
self
.
type
=
"adamw"
# now the adamw op doesn't support cuda
if
core
.
is_compiled_with_cuda
():
self
.
type
=
"adam"
# Use _auxiliary_vars together with _set_auxiliary_var/_get_auxiliary_var to achieve that.
self
.
_auxiliary_vars
=
dict
()
def _set_auxiliary_var(self, key, val):
    # Store `val` under `key` in this optimizer's `_auxiliary_vars` dict,
    # replacing any value already stored under that key.
    self._auxiliary_vars[key] = val
def
_get_auxiliary_var
(
self
,
key
):
if
key
in
self
.
_auxiliary_vars
:
return
self
.
_auxiliary_vars
[
key
]
else
:
return
None
def
_append_decoupled_weight_decay
(
self
,
block
,
param_and_grad
):
"""
Add decoupled weight decay op.
...
...
@@ -228,8 +248,107 @@ class AdamW(Adam):
paddle
.
fluid
.
layers
.
assign
(
input
=
scaled_param
,
output
=
param
)
def _append_optimize_op(self, block, param_and_grad):
    """
    Append the optimize op for one (param, grad) pair.

    When Paddle is not compiled with NPU support, the decoupled weight decay
    ops are appended first and the parent class's ``_append_optimize_op`` is
    used. When it is compiled with NPU support, a single op of type
    ``self.type`` is appended that receives ``with_decay`` / ``coeff`` as
    attributes and, if an auxiliary ``found_inf`` variable has been set, the
    ``SkipUpdate`` input.

    :param block: the block the op is appended to
    :param param_and_grad: (param, grad) pair, or a dict that is first
        converted through ``_update_param_group``
    :return: the result of the parent implementation on non-NPU builds; on
        NPU builds ``None`` in dygraph mode, otherwise the appended op
    """
    if not core.is_compiled_with_npu():
        self._append_decoupled_weight_decay(block, param_and_grad)
        return super(AdamW, self)._append_optimize_op(block, param_and_grad)

    assert isinstance(block, framework.Block)
    if isinstance(param_and_grad, dict):
        param_and_grad = self._update_param_group(param_and_grad)
    param, grad = param_and_grad

    # Whether we should do weight decay for the parameter.
    with_decay = True
    if self._apply_decay_param_fun is not None \
            and not self._apply_decay_param_fun(param.name):
        with_decay = False

    moment1 = self._get_accumulator(self._moment1_acc_str,
                                    param_and_grad[0])
    moment2 = self._get_accumulator(self._moment2_acc_str,
                                    param_and_grad[0])
    beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
                                          param_and_grad[0])
    beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
                                          param_and_grad[0])
    # A master weight is only looked up for FP16 params in multi-precision
    # mode.
    find_master = self._multi_precision and param_and_grad[
        0].dtype == core.VarDesc.VarType.FP16
    master_weight = (self._master_weights[param_and_grad[0].name]
                     if find_master else None)
    lr = self._create_param_lr(param_and_grad)

    # create the adam optimize op
    if framework.in_dygraph_mode():
        # NOTE(review): this branch calls `_C_ops.adam` and passes neither
        # `coeff` nor `with_decay`, so no weight decay is applied on this
        # path -- confirm that this is intended.
        _beta1 = self._beta1 if not isinstance(
            self._beta1, Variable) else self._beta1.numpy().item(0)
        _beta2 = self._beta2 if not isinstance(
            self._beta2, Variable) else self._beta2.numpy().item(0)
        _, _, _, _, _ = _C_ops.adam(
            param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
            beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1,
            moment2, beta1_pow_acc, beta2_pow_acc, 'epsilon', self._epsilon,
            'lazy_mode', self._lazy_mode, 'min_row_size_to_use_multithread',
            1000, 'beta1', _beta1, 'beta2', _beta2)

        return None

    inputs = {
        "Param": [param_and_grad[0]],
        "Grad": [param_and_grad[1]],
        "LearningRate": [lr],
        "Moment1": [moment1],
        "Moment2": [moment2],
        "Beta1Pow": [beta1_pow_acc],
        "Beta2Pow": [beta2_pow_acc],
    }

    # Pass found_inf to adamw, to skip update for not only param, but also
    # momentum and beta_pow. Presence is tested with `is not None` so the
    # decision never depends on the truth value of a graph variable.
    found_inf = self._get_auxiliary_var('found_inf')
    if found_inf is not None:
        inputs['SkipUpdate'] = found_inf

    outputs = {
        "ParamOut": [param_and_grad[0]],
        "Moment1Out": [moment1],
        "Moment2Out": [moment2],
        "Beta1PowOut": [beta1_pow_acc],
        "Beta2PowOut": [beta2_pow_acc],
    }
    attrs = {
        "lazy_mode": self._lazy_mode,
        "min_row_size_to_use_multithread": 1000,
        "multi_precision": find_master,
        "with_decay": with_decay,
        "coeff": self._coeff,
    }

    # beta1 / beta2 / epsilon go in as tensor inputs when they are Variables
    # and as plain attributes otherwise.
    if isinstance(self._beta1, Variable):
        inputs['Beta1Tensor'] = self._beta1
    else:
        attrs['beta1'] = self._beta1
    if isinstance(self._beta2, Variable):
        inputs['Beta2Tensor'] = self._beta2
    else:
        attrs['beta2'] = self._beta2
    if isinstance(self._epsilon, Variable):
        inputs['EpsilonTensor'] = self._epsilon
    else:
        attrs['epsilon'] = self._epsilon

    if find_master:
        inputs["MasterParam"] = master_weight
        outputs["MasterParamOut"] = master_weight

    adamw_op = block.append_op(
        type=self.type,
        inputs=inputs,
        outputs=outputs,
        attrs=attrs,
        stop_gradient=True)

    return adamw_op
def
_create_optimization_pass
(
self
,
parameters_and_grads
):
optimize_ops
=
super
(
...
...
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
反馈
建议
客服
返回
顶部