Commit 198fbdfb (unverified)
Repository: BaiXuePrincess / Paddle (fork of PaddlePaddle / Paddle)
Authored Jan 07, 2021 by 123malin; committed via GitHub on Jan 07, 2021
Parent commit: 6a19e41f

Add Lookahead and ModelAverage Optimizer (#30004)

* test=develop, add model_average and lookahead
Showing 9 changed files with 1,203 additions and 2 deletions (+1203 −2).
paddle/fluid/pybind/op_function_generator.cc              +3    −0
python/paddle/__init__.py                                 +1    −0
python/paddle/fluid/tests/unittests/test_lookahead.py     +146  −0
python/paddle/fluid/tests/unittests/test_modelaverage.py  +209  −0
python/paddle/incubate/__init__.py                        +4    −2
python/paddle/incubate/optimizer/__init__.py              +18   −0
python/paddle/incubate/optimizer/lookahead.py             +296  −0
python/paddle/incubate/optimizer/modelaverage.py          +525  −0
python/setup.py.in                                        +1    −0
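The commit wires two new optimizers, paddle.incubate.optimizer.LookAhead and paddle.incubate.optimizer.ModelAverage, into the incubate namespace. Before the per-file diffs, here is a minimal usage sketch in dynamic graph mode, assembled from the docstrings and unit tests added below; the toy model, the random input, and the choice to drive both optimizers in one loop are illustrative assumptions rather than part of the change:

import paddle
import paddle.nn as nn

# Toy model and input, purely for illustration.
linear = nn.Linear(10, 1)
x = paddle.rand([4, 10])

# LookAhead wraps an inner optimizer; every k steps the slow weights are
# pulled toward the fast weights by the interpolation factor alpha.
sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters())
lookahead = paddle.incubate.optimizer.LookAhead(sgd, alpha=0.2, k=5)

# ModelAverage keeps running sums of recent parameter values and can swap
# the averaged weights in for evaluation.
model_average = paddle.incubate.optimizer.ModelAverage(
    0.15,
    parameters=linear.parameters(),
    min_average_window=2,
    max_average_window=10)

for _ in range(10):
    loss = paddle.mean(linear(x))
    loss.backward()
    lookahead.step()          # inner SGD step plus periodic slow-param sync
    model_average.step()      # update the parameter-sum accumulators
    lookahead.clear_grad()
    model_average.clear_grad()

# Evaluate with the averaged parameters, then restore the trained ones.
with model_average.apply(need_restore=False):
    eval_loss = paddle.mean(linear(x))
model_average.restore()

LookAhead delegates each per-step update to its inner optimizer and only synchronizes the slow parameters every k steps, while ModelAverage does not change training itself; it only accumulates statistics and substitutes the averaged weights on demand.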
paddle/fluid/pybind/op_function_generator.cc

@@ -104,6 +104,9 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
     {"sgd", {"ParamOut"}},
     {"adam",
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
+    {"average_accumulates",
+     {"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates",
+      "out_old_num_accumulates", "out_num_updates"}},
     {"momentum", {"ParamOut", "VelocityOut"}},
     {"batch_norm", {"MeanOut", "VarianceOut"}},
     {"sync_batch_norm", {"MeanOut", "VarianceOut"}},
python/paddle/__init__.py

@@ -43,6 +43,7 @@ import paddle.optimizer
 import paddle.metric
 import paddle.device
 import paddle.regularizer
+import paddle.incubate
 # TODO: define alias in tensor and framework directory
python/paddle/fluid/tests/unittests/test_lookahead.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
from paddle.fluid import core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
import paddle
import paddle.nn as nn

LOOKAHEAD_K = 5
LOOKAHEAD_ALPHA = 0.2
SGD_LR = 1.0


class TestLookAhead(unittest.TestCase):
    def test_lookahead_static(self):
        paddle.enable_static()
        place = fluid.CPUPlace()
        shape = [2, 3, 8, 8]
        exe = fluid.Executor(place)
        train_program = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(train_program, startup):
            with fluid.unique_name.guard():
                data = fluid.data(name='X', shape=[None, 1], dtype='float32')
                hidden = fluid.layers.fc(input=data, size=10)
                loss = fluid.layers.mean(hidden)

                optimizer = paddle.optimizer.SGD(learning_rate=SGD_LR)
                lookahead = paddle.incubate.optimizer.LookAhead(
                    optimizer, alpha=LOOKAHEAD_ALPHA, k=LOOKAHEAD_K)
                lookahead.minimize(loss)

        exe.run(startup)
        slow_param = None
        fast_param = None
        for i in range(10):
            if (i + 1) % LOOKAHEAD_K == 0:
                slow_param = slow_param + LOOKAHEAD_ALPHA * (fast_param -
                                                             slow_param)
            x = np.random.random(size=(10, 1)).astype('float32')
            latest_b, b_grad = exe.run(
                program=train_program,
                feed={'X': x},
                fetch_list=['fc_0.b_0', 'fc_0.b_0@GRAD'])
            if i == 0:
                slow_param = latest_b
            if (i + 1) % LOOKAHEAD_K == 0:
                self.assertAlmostEqual(
                    slow_param.all(), latest_b.all(), delta=5e-3)
            fast_param = latest_b - SGD_LR * b_grad

    def test_look_ahead_dygraph(self):
        BATCH_SIZE = 16
        BATCH_NUM = 4
        EPOCH_NUM = 4

        IMAGE_SIZE = 784
        CLASS_NUM = 10

        # define a random dataset
        class RandomDataset(paddle.io.Dataset):
            def __init__(self, num_samples):
                self.num_samples = num_samples

            def __getitem__(self, idx):
                image = np.random.random([IMAGE_SIZE]).astype('float32')
                label = np.random.randint(0, CLASS_NUM - 1,
                                          (1, )).astype('int64')
                return image, label

            def __len__(self):
                return self.num_samples

        class LinearNet(nn.Layer):
            def __init__(self):
                super(LinearNet, self).__init__()
                self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
                self.bias = self._linear.bias

            @paddle.jit.to_static
            def forward(self, x):
                return self._linear(x)

        def train(layer, loader, loss_fn, opt):
            idx = 0
            slow_param = None
            fast_param = None
            for epoch_id in range(EPOCH_NUM):
                for batch_id, (image, label) in enumerate(loader()):
                    idx += 1
                    out = layer(image)
                    loss = loss_fn(out, label)
                    loss.backward()
                    fast_param = layer.bias.numpy() - SGD_LR * layer.bias.grad
                    opt.step()
                    if idx == 1:
                        slow_param = fast_param
                    if idx % LOOKAHEAD_K == 0:
                        slow_param = slow_param + LOOKAHEAD_ALPHA * (
                            fast_param - slow_param)
                        self.assertAlmostEqual(
                            np.mean(slow_param),
                            np.mean(layer.bias.numpy()),
                            delta=5e-3)
                    opt.clear_grad()

        layer = LinearNet()
        loss_fn = nn.CrossEntropyLoss()
        optimizer = paddle.optimizer.SGD(
            learning_rate=SGD_LR, parameters=layer.parameters())
        lookahead = paddle.incubate.optimizer.LookAhead(
            optimizer, alpha=LOOKAHEAD_ALPHA, k=LOOKAHEAD_K)

        # create data loader
        dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
        loader = paddle.io.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            drop_last=True,
            num_workers=2)

        train(layer, loader, loss_fn, lookahead)


if __name__ == "__main__":
    unittest.main()
python/paddle/fluid/tests/unittests/test_modelaverage.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
from paddle.fluid import core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
import paddle
import paddle.nn as nn


class TestModelAverage(unittest.TestCase):
    def test_model_average_static(self):
        paddle.enable_static()
        place = fluid.CPUPlace()
        shape = [2, 3, 8, 8]
        exe = fluid.Executor(place)
        train_program = fluid.Program()
        startup = fluid.Program()
        test_program = fluid.Program()
        with fluid.program_guard(train_program, startup):
            with fluid.unique_name.guard():
                data = fluid.data(name='X', shape=[None, 1], dtype='float32')
                hidden = fluid.layers.fc(input=data, size=10)
                loss = fluid.layers.mean(hidden)
                test_program = train_program.clone()
                optimizer = paddle.optimizer.Momentum(
                    learning_rate=0.2, momentum=0.1)

                optimizer.minimize(loss)
                # build ModelAverage optimizer
                model_average = paddle.incubate.optimizer.ModelAverage(
                    0.15, min_average_window=2, max_average_window=10)

        exe.run(startup)
        for i in range(10):
            x = np.random.random(size=(10, 1)).astype('float32')
            (latest_b, sum_1, sum_2, sum_3, num_accumulates,
             old_num_accumulates, num_updates) = exe.run(
                 program=train_program,
                 feed={'X': x},
                 fetch_list=[
                     'fc_0.b_0', 'fc_0.b_0_sum_1_0', 'fc_0.b_0_sum_2_0',
                     'fc_0.b_0_sum_3_0', 'fc_0.b_0_num_accumulates_0',
                     'fc_0.b_0_old_num_accumulates_0', 'fc_0.b_0_num_updates_0'
                 ])
        self.assertTrue(
            np.equal(sum_1, np.zeros(shape=[10], dtype='float32')).all())
        self.assertTrue(
            np.equal(sum_2, np.zeros(shape=[10], dtype='float32')).all())
        self.assertTrue(
            np.equal(num_accumulates, np.array([0], dtype='int64')).all())
        self.assertTrue(
            np.equal(old_num_accumulates, np.array([2], dtype='int64')).all())
        self.assertTrue(
            np.equal(num_updates, np.array([10], dtype='int64')).all())

        average_b = (sum_1 + sum_2 + sum_3) / (
            num_accumulates + old_num_accumulates)
        # apply ModelAverage
        with model_average.apply(exe):
            x = np.random.random(size=(10, 1)).astype('float32')
            outs, b = exe.run(program=test_program,
                              feed={'X': x},
                              fetch_list=[loss.name, 'fc_0.b_0'])
            self.assertAlmostEqual(np.mean(average_b), np.mean(b))

        x = np.random.random(size=(10, 1)).astype('float32')
        outs, b = exe.run(program=test_program,
                          feed={'X': x},
                          fetch_list=[loss.name, 'fc_0.b_0'])
        self.assertAlmostEqual(np.mean(latest_b), np.mean(b))

    def test_model_average_dygraph(self):
        BATCH_SIZE = 16
        BATCH_NUM = 4
        EPOCH_NUM = 4

        IMAGE_SIZE = 784
        CLASS_NUM = 10

        # define a random dataset
        class RandomDataset(paddle.io.Dataset):
            def __init__(self, num_samples):
                self.num_samples = num_samples

            def __getitem__(self, idx):
                image = np.random.random([IMAGE_SIZE]).astype('float32')
                label = np.random.randint(0, CLASS_NUM - 1,
                                          (1, )).astype('int64')
                return image, label

            def __len__(self):
                return self.num_samples

        class LinearNet(nn.Layer):
            def __init__(self):
                super(LinearNet, self).__init__()
                self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
                self.bias = self._linear.bias

            @paddle.jit.to_static
            def forward(self, x):
                return self._linear(x)

        def train(layer, loader, loss_fn, opt, model_average):
            for epoch_id in range(EPOCH_NUM):
                for batch_id, (image, label) in enumerate(loader()):
                    out = layer(image)
                    loss = loss_fn(out, label)
                    loss.backward()
                    opt.step()
                    model_average.step()
                    opt.clear_grad()
                    model_average.clear_grad()
                    # print("Train Epoch {} batch {}: loss = {}, bias = {}".format(
                    #     epoch_id, batch_id, np.mean(loss.numpy()), layer.bias.numpy()))
            sum_1 = model_average._get_accumulator('sum_1', layer.bias)
            sum_2 = model_average._get_accumulator('sum_2', layer.bias)
            sum_3 = model_average._get_accumulator('sum_3', layer.bias)
            num_accumulates = model_average._get_accumulator('num_accumulates',
                                                             layer.bias)
            old_num_accumulates = model_average._get_accumulator(
                'old_num_accumulates', layer.bias)
            num_updates = model_average._get_accumulator('num_updates',
                                                         layer.bias)

            return ((sum_1 + sum_2 + sum_3) /
                    (num_accumulates + old_num_accumulates)).numpy()

        def evaluate(layer, loader, loss_fn, check_param):
            for batch_id, (image, label) in enumerate(loader()):
                out = layer(image)
                loss = loss_fn(out, label)
                loss.backward()
                self.assertAlmostEqual(
                    np.mean(layer.bias.numpy()),
                    np.mean(check_param),
                    delta=5e-3)
                # print("Evaluate batch {}: loss = {}, bias = {}".format(
                #     batch_id, np.mean(loss.numpy()), layer.bias.numpy()))

        # create network
        layer = LinearNet()
        loss_fn = nn.CrossEntropyLoss()
        optimizer = paddle.optimizer.Momentum(
            learning_rate=0.2, momentum=0.1, parameters=layer.parameters())
        # build ModelAverage optimizer
        model_average = paddle.incubate.optimizer.ModelAverage(
            0.15,
            parameters=layer.parameters(),
            min_average_window=2,
            max_average_window=10)
        # create data loader
        dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
        loader = paddle.io.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            drop_last=True,
            num_workers=2)
        eval_loader = paddle.io.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            drop_last=True,
            num_workers=1)
        # train
        check_param = train(layer, loader, loss_fn, optimizer, model_average)
        # print(check_param)
        with model_average.apply(need_restore=False):
            evaluate(layer, eval_loader, loss_fn, check_param)

        check_param = (model_average._get_accumulator('restore',
                                                      layer.bias)).numpy()
        # print(check_param)
        # print("\nEvaluate With Restored Paramters")
        model_average.restore()
        evaluate(layer, eval_loader, loss_fn, check_param)


if __name__ == "__main__":
    unittest.main()
python/paddle/incubate/__init__.py

@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from . import optimizer
+from ..fluid.contrib import reader
+
 __all__ = []
 __all__ += ["reader"]
-
-from ..fluid.contrib import reader
+__all__ += optimizer.__all__
python/paddle/incubate/optimizer/__init__.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .lookahead import LookAhead
from .modelaverage import ModelAverage

__all__ = ['LookAhead', 'ModelAverage']
python/paddle/incubate/optimizer/lookahead.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.optimizer import Optimizer
from paddle.fluid import core, framework, layers, unique_name
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard
from paddle.fluid.layer_helper import LayerHelper
import paddle
import numpy as np
from paddle.fluid.dygraph import base as imperative_base

__all__ = ["LookAhead"]


class LookAhead(Optimizer):
    r"""
    This implements the Lookahead optimizer of the
    paper : https://arxiv.org/abs/1907.08610.

    Lookahead keeps two sets of params: the fast_params and
    the slow_params. inner_optimizer update fast_params every
    training step. Lookahead updates the slow_params and fast_params
    every k training steps as follows:

    .. math::

        slow\_param_t &= slow\_param_{t-1} + \\alpha * (fast\_param_{t-1} - slow\_param_{t-1})

        fast\_param_t &= slow\_param_t

    Args:
        inner_optimizer (Optimizer): The optimizer that update fast params step by step.
        alpha (float, optinal): The learning rate of Lookahead. The default value is 0.5.
        k (int, optinal): The slow params is updated every k steps. The default value is 5.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle
            import paddle.nn as nn

            BATCH_SIZE = 16
            BATCH_NUM = 4
            EPOCH_NUM = 4
            IMAGE_SIZE = 784
            CLASS_NUM = 10

            # define a random dataset
            class RandomDataset(paddle.io.Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([IMAGE_SIZE]).astype('float32')
                    label = np.random.randint(0, CLASS_NUM - 1,
                                              (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
                    self.bias = self._linear.bias

                @paddle.jit.to_static
                def forward(self, x):
                    return self._linear(x)

            def train(layer, loader, loss_fn, opt):
                for epoch_id in range(EPOCH_NUM):
                    for batch_id, (image, label) in enumerate(loader()):
                        out = layer(image)
                        loss = loss_fn(out, label)
                        loss.backward()
                        opt.step()
                        opt.clear_grad()
                        print("Train Epoch {} batch {}: loss = {}".format(
                            epoch_id, batch_id, np.mean(loss.numpy())))

            layer = LinearNet()
            loss_fn = nn.CrossEntropyLoss()
            optimizer = paddle.optimizer.SGD(learning_rate=0.1, parameters=layer.parameters())
            lookahead = paddle.incubate.optimizer.LookAhead(optimizer, alpha=0.2, k=5)

            # create data loader
            dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
            loader = paddle.io.DataLoader(
                dataset,
                batch_size=BATCH_SIZE,
                shuffle=True,
                drop_last=True,
                num_workers=2)

            train(layer, loader, loss_fn, lookahead)

    """
    _slow_str = "slow"

    def __init__(self, inner_optimizer, alpha=0.5, k=5, name=None):
        assert (inner_optimizer is not None), "inner optimizer can not be None"
        assert (
            0.0 <= alpha <= 1.0
        ), "alpha should be larger or equal to 0.0, and less or equal than 1.0"
        assert (isinstance(k, int) and k > 0), "k should be a positive integer"

        self.inner_optimizer = inner_optimizer
        if self.inner_optimizer._parameter_list is None:
            parameters = framework.default_main_program().global_block(
            ).all_parameters()
        else:
            parameters = self.inner_optimizer._parameter_list

        super(LookAhead, self).__init__(
            learning_rate=alpha,
            parameters=parameters,
            weight_decay=None,
            grad_clip=None,
            name=name)

        self.alpha = alpha
        self.k = k
        self.type = "lookahead"
        self.helper = LayerHelper(self.__class__.__name__)
        self._global_step_var = None
        self._k_var = None

    @framework.dygraph_only
    @imperative_base.no_grad
    def step(self):
        """
        Execute the optimizer and update parameters once.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                inp = paddle.to_tensor(np.random.random([1, 10]).astype('float32'))
                linear = paddle.nn.Linear(10, 1)
                out = linear(inp)
                loss = paddle.mean(out)
                sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters())
                lookahead = paddle.incubate.optimizer.LookAhead(sgd, alpha=0.2, k=5)
                loss.backward()
                lookahead.step()
                lookahead.clear_grad()

        """
        self.inner_optimizer.step()

        params_grads = []
        for param in self._parameter_list:
            if not param.trainable:
                continue
            if param._grad_ivar() is not None:
                grad_var = param._grad_ivar()
                params_grads.append((param, grad_var))

        self._apply_optimize(
            loss=None, startup_program=None, params_grads=params_grads)

    def _create_accumulators(self, block, parameters):
        assert isinstance(block, framework.Block)

        for p in parameters:
            self._add_accumulator(self._slow_str, p)

    def _append_optimize_op(self, block, param_and_grad):
        if self._global_step_var is None:
            self._global_step_var = layers.create_global_var(
                name=unique_name.generate("lookahead_step"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True)

        self.helper.append_op(
            type='increment',
            inputs={'X': [self._global_step_var]},
            outputs={'Out': [self._global_step_var]},
            attrs={'step': 1.0})

        one_var = paddle.ones(shape=[1], dtype='int32', name='lookahead_ones')
        zero_var = paddle.zeros(
            shape=[1], dtype='int32', name='lookahead_zeros')
        k_var = layers.create_global_var(
            name=unique_name.generate("lookahead_k"),
            shape=[1],
            value=self.k,
            dtype='int32',
            persistable=True)

        mod = paddle.remainder(self._global_step_var, k_var)

        cond_1 = paddle.equal(self._global_step_var, one_var)
        cond_1 = paddle.cast(cond_1, dtype='float32')

        cond_2 = paddle.equal(mod, zero_var)
        cond_2 = paddle.cast(cond_2, dtype='float32')

        slow_var = self._get_accumulator(self._slow_str, param_and_grad[0])

        tmp_var = cond_1 * param_and_grad[0] + (1 - cond_1) * slow_var
        paddle.assign(tmp_var, slow_var)

        tmp_var = self.alpha * param_and_grad[0] + (1.0 - self.alpha) * slow_var
        tmp_var_1 = cond_2 * tmp_var + (1 - cond_2) * param_and_grad[0]
        paddle.assign(tmp_var_1, param_and_grad[0])

        tmp_var_1 = cond_2 * tmp_var + (1 - cond_2) * slow_var
        paddle.assign(tmp_var_1, slow_var)

    @imperative_base.no_grad
    def minimize(self,
                 loss,
                 startup_program=None,
                 parameters=None,
                 no_grad_set=None):
        """
        Add operations to minimize ``loss`` by updating ``parameters``.

        Args:
            loss (Tensor): A ``Tensor`` containing the value to minimize.
            startup_program (Program, optional): :ref:`api_fluid_Program` for
                initializing parameters in ``parameters``. The default value
                is None, at this time :ref:`api_fluid_default_startup_program` will be used.
            parameters (list, optional): List of ``Tensor`` or ``Tensor.name`` to update
                to minimize ``loss``. The default value is None, at this time all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Tensor`` or ``Tensor.name`` that don't need
                to be updated. The default value is None.

        Returns:
            tuple: tuple (optimize_ops, params_grads), A list of operators appended
            by minimize and a list of (param, grad) tensor pairs, param is
            ``Parameter``, grad is the gradient value corresponding to the parameter.
            In static graph mode, the returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
            indicate program pruning. If so, the program will be pruned by ``feed`` and
            ``fetch_list`` before run, see details in ``Executor``.

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                inp = paddle.to_tensor(np.random.random([1, 10]).astype('float32'))
                linear = paddle.nn.Linear(10, 1)
                out = linear(inp)
                loss = paddle.mean(out)
                sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters())
                lookahead = paddle.incubate.optimizer.LookAhead(sgd, alpha=0.2, k=5)
                loss.backward()
                lookahead.minimize(loss)
                lookahead.clear_grad()

        """
        assert isinstance(loss, Variable), "The loss should be an Tensor."

        parameter_list = parameters if parameters \
            else self._parameter_list

        # Apply inner optimizer to the main_program
        optimize_ops, params_grads = self.inner_optimizer.minimize(
            loss,
            startup_program=startup_program,
            parameters=parameters,
            no_grad_set=no_grad_set)

        _ = self._apply_optimize(
            loss, startup_program=startup_program, params_grads=params_grads)

        return optimize_ops, params_grads
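The slow/fast interplay implemented by _append_optimize_op above can be restated in a few lines of plain NumPy. This is only an illustration of the update rule from the class docstring, not code from the diff:

import numpy as np

alpha, k = 0.5, 5
fast = np.array([1.0])     # parameter updated by the inner optimizer every step
slow = fast.copy()         # the extra copy that LookAhead maintains

for step in range(1, 21):
    fast -= 0.1 * np.sign(fast)              # stand-in for an inner SGD update
    if step % k == 0:
        slow = slow + alpha * (fast - slow)  # slow_param_t = slow_param_{t-1} + alpha * (fast - slow)
        fast = slow.copy()                   # fast_param_t = slow_param_t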
python/paddle/incubate/optimizer/modelaverage.py (new file, mode 100644, +525 lines)

This diff is collapsed in the captured page and is not reproduced here.
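Although the implementation is collapsed here, the quantity it maintains can be read off test_modelaverage.py above: the averaged parameter is (sum_1 + sum_2 + sum_3) / (num_accumulates + old_num_accumulates). A toy restatement of that running-average idea, with hypothetical names and without the windowed bookkeeping of the real file:

import numpy as np

# Pretend these are the values one parameter took over the last few optimizer steps.
recent_values = [np.random.random([10]).astype('float32') for _ in range(6)]

running_sum = np.zeros([10], dtype='float32')
num_accumulates = 0
for value in recent_values:
    running_sum += value      # what the sum accumulators collect step by step
    num_accumulates += 1

# The averaged parameter that model_average.apply() temporarily swaps in.
average_param = running_sum / num_accumulates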
python/setup.py.in

@@ -143,6 +143,7 @@ packages=['paddle',
           'paddle.reader',
           'paddle.distributed',
           'paddle.incubate',
+          'paddle.incubate.optimizer',
           'paddle.distributed.fleet',
           'paddle.distributed.fleet.base',
           'paddle.distributed.fleet.meta_optimizers',