未验证 提交 198fbdfb 编写于 作者: 1 123malin 提交者: GitHub

Add Lookahead and ModelAverage Optimizer (#30004)

* test=develop, add model_average and lookahead
上级 6a19e41f
......@@ -104,6 +104,9 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"sgd", {"ParamOut"}},
{"adam",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
{"average_accumulates",
{"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates",
"out_old_num_accumulates", "out_num_updates"}},
{"momentum", {"ParamOut", "VelocityOut"}},
{"batch_norm", {"MeanOut", "VarianceOut"}},
{"sync_batch_norm", {"MeanOut", "VarianceOut"}},
......
......@@ -43,6 +43,7 @@ import paddle.optimizer
import paddle.metric
import paddle.device
import paddle.regularizer
import paddle.incubate
# TODO: define alias in tensor and framework directory
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from paddle.fluid import core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
import paddle
import paddle.nn as nn
LOOKAHEAD_K = 5
LOOKAHEAD_ALPHA = 0.2
SGD_LR = 1.0
class TestLookAhead(unittest.TestCase):
def test_lookahead_static(self):
paddle.enable_static()
place = fluid.CPUPlace()
shape = [2, 3, 8, 8]
exe = fluid.Executor(place)
train_program = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(train_program, startup):
with fluid.unique_name.guard():
data = fluid.data(name='X', shape=[None, 1], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10)
loss = fluid.layers.mean(hidden)
optimizer = paddle.optimizer.SGD(learning_rate=SGD_LR)
lookahead = paddle.incubate.optimizer.LookAhead(
optimizer, alpha=LOOKAHEAD_ALPHA, k=LOOKAHEAD_K)
lookahead.minimize(loss)
exe.run(startup)
slow_param = None
fast_param = None
for i in range(10):
if (i + 1) % LOOKAHEAD_K == 0:
slow_param = slow_param + LOOKAHEAD_ALPHA * (fast_param -
slow_param)
x = np.random.random(size=(10, 1)).astype('float32')
latest_b, b_grad = exe.run(program=train_program,
feed={'X': x},
fetch_list=[
'fc_0.b_0',
'fc_0.b_0@GRAD',
])
if i == 0:
slow_param = latest_b
if (i + 1) % LOOKAHEAD_K == 0:
self.assertAlmostEqual(
slow_param.all(), latest_b.all(), delta=5e-3)
fast_param = latest_b - SGD_LR * b_grad
def test_look_ahead_dygraph(self):
BATCH_SIZE = 16
BATCH_NUM = 4
EPOCH_NUM = 4
IMAGE_SIZE = 784
CLASS_NUM = 10
# define a random dataset
class RandomDataset(paddle.io.Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __getitem__(self, idx):
image = np.random.random([IMAGE_SIZE]).astype('float32')
label = np.random.randint(0, CLASS_NUM - 1,
(1, )).astype('int64')
return image, label
def __len__(self):
return self.num_samples
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
self.bias = self._linear.bias
@paddle.jit.to_static
def forward(self, x):
return self._linear(x)
def train(layer, loader, loss_fn, opt):
idx = 0
slow_param = None
fast_param = None
for epoch_id in range(EPOCH_NUM):
for batch_id, (image, label) in enumerate(loader()):
idx += 1
out = layer(image)
loss = loss_fn(out, label)
loss.backward()
fast_param = layer.bias.numpy() - SGD_LR * layer.bias.grad
opt.step()
if idx == 1:
slow_param = fast_param
if idx % LOOKAHEAD_K == 0:
slow_param = slow_param + LOOKAHEAD_ALPHA * (
fast_param - slow_param)
self.assertAlmostEqual(
np.mean(slow_param),
np.mean(layer.bias.numpy()),
delta=5e-3)
opt.clear_grad()
layer = LinearNet()
loss_fn = nn.CrossEntropyLoss()
optimizer = paddle.optimizer.SGD(learning_rate=SGD_LR,
parameters=layer.parameters())
lookahead = paddle.incubate.optimizer.LookAhead(
optimizer, alpha=LOOKAHEAD_ALPHA, k=LOOKAHEAD_K)
# create data loader
dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
loader = paddle.io.DataLoader(
dataset,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True,
num_workers=2)
train(layer, loader, loss_fn, lookahead)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from paddle.fluid import core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
import paddle
import paddle.nn as nn
class TestModelAverage(unittest.TestCase):
def test_model_average_static(self):
paddle.enable_static()
place = fluid.CPUPlace()
shape = [2, 3, 8, 8]
exe = fluid.Executor(place)
train_program = fluid.Program()
startup = fluid.Program()
test_program = fluid.Program()
with fluid.program_guard(train_program, startup):
with fluid.unique_name.guard():
data = fluid.data(name='X', shape=[None, 1], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10)
loss = fluid.layers.mean(hidden)
test_program = train_program.clone()
optimizer = paddle.optimizer.Momentum(
learning_rate=0.2, momentum=0.1)
optimizer.minimize(loss)
# build ModelAverage optimizer
model_average = paddle.incubate.optimizer.ModelAverage(
0.15, min_average_window=2, max_average_window=10)
exe.run(startup)
for i in range(10):
x = np.random.random(size=(10, 1)).astype('float32')
latest_b, sum_1, sum_2, sum_3, num_accumulates, old_num_accumulates, num_updates = exe.run(
program=train_program,
feed={'X': x},
fetch_list=[
'fc_0.b_0', 'fc_0.b_0_sum_1_0', 'fc_0.b_0_sum_2_0',
'fc_0.b_0_sum_3_0', 'fc_0.b_0_num_accumulates_0',
'fc_0.b_0_old_num_accumulates_0', 'fc_0.b_0_num_updates_0'
])
self.assertTrue(
np.equal(
sum_1, np.zeros(
shape=[10], dtype='float32')).all())
self.assertTrue(
np.equal(
sum_2, np.zeros(
shape=[10], dtype='float32')).all())
self.assertTrue(
np.equal(
num_accumulates, np.array(
[0], dtype='int64')).all())
self.assertTrue(
np.equal(
old_num_accumulates, np.array(
[2], dtype='int64')).all())
self.assertTrue(
np.equal(
num_updates, np.array(
[10], dtype='int64')).all())
average_b = (sum_1 + sum_2 + sum_3) / (
num_accumulates + old_num_accumulates)
# apply ModelAverage
with model_average.apply(exe):
x = np.random.random(size=(10, 1)).astype('float32')
outs, b = exe.run(program=test_program,
feed={'X': x},
fetch_list=[loss.name, 'fc_0.b_0'])
self.assertAlmostEqual(np.mean(average_b), np.mean(b))
x = np.random.random(size=(10, 1)).astype('float32')
outs, b = exe.run(program=test_program,
feed={'X': x},
fetch_list=[loss.name, 'fc_0.b_0'])
self.assertAlmostEqual(np.mean(latest_b), np.mean(b))
def test_model_average_dygraph(self):
BATCH_SIZE = 16
BATCH_NUM = 4
EPOCH_NUM = 4
IMAGE_SIZE = 784
CLASS_NUM = 10
# define a random dataset
class RandomDataset(paddle.io.Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __getitem__(self, idx):
image = np.random.random([IMAGE_SIZE]).astype('float32')
label = np.random.randint(0, CLASS_NUM - 1,
(1, )).astype('int64')
return image, label
def __len__(self):
return self.num_samples
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
self.bias = self._linear.bias
@paddle.jit.to_static
def forward(self, x):
return self._linear(x)
def train(layer, loader, loss_fn, opt, model_average):
for epoch_id in range(EPOCH_NUM):
for batch_id, (image, label) in enumerate(loader()):
out = layer(image)
loss = loss_fn(out, label)
loss.backward()
opt.step()
model_average.step()
opt.clear_grad()
model_average.clear_grad()
# print("Train Epoch {} batch {}: loss = {}, bias = {}".format(
# epoch_id, batch_id, np.mean(loss.numpy()), layer.bias.numpy()))
sum_1 = model_average._get_accumulator('sum_1', layer.bias)
sum_2 = model_average._get_accumulator('sum_2', layer.bias)
sum_3 = model_average._get_accumulator('sum_3', layer.bias)
num_accumulates = model_average._get_accumulator('num_accumulates',
layer.bias)
old_num_accumulates = model_average._get_accumulator(
'old_num_accumulates', layer.bias)
num_updates = model_average._get_accumulator('num_updates',
layer.bias)
return ((sum_1 + sum_2 + sum_3) /
(num_accumulates + old_num_accumulates)).numpy()
def evaluate(layer, loader, loss_fn, check_param):
for batch_id, (image, label) in enumerate(loader()):
out = layer(image)
loss = loss_fn(out, label)
loss.backward()
self.assertAlmostEqual(
np.mean(layer.bias.numpy()),
np.mean(check_param),
delta=5e-3)
# print("Evaluate batch {}: loss = {}, bias = {}".format(
# batch_id, np.mean(loss.numpy()), layer.bias.numpy()))
# create network
layer = LinearNet()
loss_fn = nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Momentum(
learning_rate=0.2, momentum=0.1, parameters=layer.parameters())
# build ModelAverage optimizer
model_average = paddle.incubate.optimizer.ModelAverage(
0.15,
parameters=layer.parameters(),
min_average_window=2,
max_average_window=10)
# create data loader
dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
loader = paddle.io.DataLoader(
dataset,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True,
num_workers=2)
eval_loader = paddle.io.DataLoader(
dataset,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True,
num_workers=1)
# train
check_param = train(layer, loader, loss_fn, optimizer, model_average)
# print(check_param)
with model_average.apply(need_restore=False):
evaluate(layer, eval_loader, loss_fn, check_param)
check_param = (model_average._get_accumulator('restore',
layer.bias)).numpy()
# print(check_param)
# print("\nEvaluate With Restored Paramters")
model_average.restore()
evaluate(layer, eval_loader, loss_fn, check_param)
if __name__ == "__main__":
unittest.main()
......@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import optimizer
from ..fluid.contrib import reader
__all__ = []
__all__ += ["reader"]
from ..fluid.contrib import reader
__all__ += optimizer.__all__
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .lookahead import LookAhead
from .modelaverage import ModelAverage
__all__ = ['LookAhead', 'ModelAverage']
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.optimizer import Optimizer
from paddle.fluid import core, framework, layers, unique_name
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard
from paddle.fluid.layer_helper import LayerHelper
import paddle
import numpy as np
from paddle.fluid.dygraph import base as imperative_base
__all__ = ["LookAhead"]
class LookAhead(Optimizer):
r"""
This implements the Lookahead optimizer of the
paper : https://arxiv.org/abs/1907.08610.
Lookahead keeps two sets of params: the fast_params and
the slow_params. inner_optimizer update fast_params every
training step. Lookahead updates the slow_params and fast_params
every k training steps as follows:
.. math::
slow\_param_t &= slow\_param_{t-1} + \\alpha * (fast\_param_{t-1} - slow\_param_{t-1})
fast\_param_t &= slow\_param_t
Args:
inner_optimizer (Optimizer): The optimizer that update fast params step by step.
alpha (float, optinal): The learning rate of Lookahead. The default value is 0.5.
k (int, optinal): The slow params is updated every k steps. The default value is 5.
name (str, optional): Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`.
The default value is None.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.nn as nn
BATCH_SIZE = 16
BATCH_NUM = 4
EPOCH_NUM = 4
IMAGE_SIZE = 784
CLASS_NUM = 10
# define a random dataset
class RandomDataset(paddle.io.Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __getitem__(self, idx):
image = np.random.random([IMAGE_SIZE]).astype('float32')
label = np.random.randint(0, CLASS_NUM - 1,
(1, )).astype('int64')
return image, label
def __len__(self):
return self.num_samples
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
self.bias = self._linear.bias
@paddle.jit.to_static
def forward(self, x):
return self._linear(x)
def train(layer, loader, loss_fn, opt):
for epoch_id in range(EPOCH_NUM):
for batch_id, (image, label) in enumerate(loader()):
out = layer(image)
loss = loss_fn(out, label)
loss.backward()
opt.step()
opt.clear_grad()
print("Train Epoch {} batch {}: loss = {}".format(
epoch_id, batch_id, np.mean(loss.numpy())))
layer = LinearNet()
loss_fn = nn.CrossEntropyLoss()
optimizer = paddle.optimizer.SGD(learning_rate=0.1, parameters=layer.parameters())
lookahead = paddle.incubate.optimizer.LookAhead(optimizer, alpha=0.2, k=5)
# create data loader
dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
loader = paddle.io.DataLoader(
dataset,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True,
num_workers=2)
train(layer, loader, loss_fn, lookahead)
"""
_slow_str = "slow"
def __init__(self, inner_optimizer, alpha=0.5, k=5, name=None):
assert (inner_optimizer is not None), "inner optimizer can not be None"
assert (
0.0 <= alpha <= 1.0
), "alpha should be larger or equal to 0.0, and less or equal than 1.0"
assert (isinstance(k, int) and k > 0), "k should be a positive integer"
self.inner_optimizer = inner_optimizer
if self.inner_optimizer._parameter_list is None:
parameters = framework.default_main_program().global_block(
).all_parameters()
else:
parameters = self.inner_optimizer._parameter_list
super(LookAhead, self).__init__(
learning_rate=alpha,
parameters=parameters,
weight_decay=None,
grad_clip=None,
name=name)
self.alpha = alpha
self.k = k
self.type = "lookahead"
self.helper = LayerHelper(self.__class__.__name__)
self._global_step_var = None
self._k_var = None
@framework.dygraph_only
@imperative_base.no_grad
def step(self):
"""
Execute the optimizer and update parameters once.
Returns:
None
Examples:
.. code-block:: python
import paddle
import numpy as np
inp = paddle.to_tensor(np.random.random([1, 10]).astype('float32'))
linear = paddle.nn.Linear(10, 1)
out = linear(inp)
loss = paddle.mean(out)
sgd = paddle.optimizer.SGD(learning_rate=0.1,parameters=linear.parameters())
lookahead = paddle.incubate.optimizer.LookAhead(sgd, alpha=0.2, k=5)
loss.backward()
lookahead.step()
lookahead.clear_grad()
"""
self.inner_optimizer.step()
params_grads = []
for param in self._parameter_list:
if not param.trainable:
continue
if param._grad_ivar() is not None:
grad_var = param._grad_ivar()
params_grads.append((param, grad_var))
self._apply_optimize(
loss=None, startup_program=None, params_grads=params_grads)
def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block)
for p in parameters:
self._add_accumulator(self._slow_str, p)
def _append_optimize_op(self, block, param_and_grad):
if self._global_step_var is None:
self._global_step_var = layers.create_global_var(
name=unique_name.generate("lookahead_step"),
shape=[1],
value=0,
dtype='int32',
persistable=True)
self.helper.append_op(
type='increment',
inputs={'X': [self._global_step_var]},
outputs={'Out': [self._global_step_var]},
attrs={'step': 1.0})
one_var = paddle.ones(shape=[1], dtype='int32', name='lookahead_ones')
zero_var = paddle.zeros(
shape=[1], dtype='int32', name='lookahead_zeros')
k_var = layers.create_global_var(
name=unique_name.generate("lookahead_k"),
shape=[1],
value=self.k,
dtype='int32',
persistable=True)
mod = paddle.remainder(self._global_step_var, k_var)
cond_1 = paddle.equal(self._global_step_var, one_var)
cond_1 = paddle.cast(cond_1, dtype='float32')
cond_2 = paddle.equal(mod, zero_var)
cond_2 = paddle.cast(cond_2, dtype='float32')
slow_var = self._get_accumulator(self._slow_str, param_and_grad[0])
tmp_var = cond_1 * param_and_grad[0] + (1 - cond_1) * slow_var
paddle.assign(tmp_var, slow_var)
tmp_var = self.alpha * param_and_grad[0] + (1.0 - self.alpha) * slow_var
tmp_var_1 = cond_2 * tmp_var + (1 - cond_2) * param_and_grad[0]
paddle.assign(tmp_var_1, param_and_grad[0])
tmp_var_1 = cond_2 * tmp_var + (1 - cond_2) * slow_var
paddle.assign(tmp_var_1, slow_var)
@imperative_base.no_grad
def minimize(self,
loss,
startup_program=None,
parameters=None,
no_grad_set=None):
"""
Add operations to minimize ``loss`` by updating ``parameters``.
Args:
loss (Tensor): A ``Tensor`` containing the value to minimize.
startup_program (Program, optional): :ref:`api_fluid_Program` for
initializing parameters in ``parameters``. The default value
is None, at this time :ref:`api_fluid_default_startup_program` will be used.
parameters (list, optional): List of ``Tensor`` or ``Tensor.name`` to update
to minimize ``loss``. The default value is None, at this time all parameters
will be updated.
no_grad_set (set, optional): Set of ``Tensor`` or ``Tensor.name`` that don't need
to be updated. The default value is None.
Returns:
tuple: tuple (optimize_ops, params_grads), A list of operators appended
by minimize and a list of (param, grad) tensor pairs, param is
``Parameter``, grad is the gradient value corresponding to the parameter.
In static graph mode, the returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
indicate program pruning. If so, the program will be pruned by ``feed`` and
``fetch_list`` before run, see details in ``Executor``.
Examples:
.. code-block:: python
import paddle
import numpy as np
inp = paddle.to_tensor(np.random.random([1, 10]).astype('float32'))
linear = paddle.nn.Linear(10, 1)
out = linear(inp)
loss = paddle.mean(out)
sgd = paddle.optimizer.SGD(learning_rate=0.1,parameters=linear.parameters())
lookahead = paddle.incubate.optimizer.LookAhead(sgd, alpha=0.2, k=5)
loss.backward()
lookahead.minimize(loss)
lookahead.clear_grad()
"""
assert isinstance(loss, Variable), "The loss should be an Tensor."
parameter_list = parameters if parameters \
else self._parameter_list
# Apply inner optimizer to the main_program
optimize_ops, params_grads = self.inner_optimizer.minimize(
loss,
startup_program=startup_program,
parameters=parameters,
no_grad_set=no_grad_set)
_ = self._apply_optimize(
loss, startup_program=startup_program, params_grads=params_grads)
return optimize_ops, params_grads
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.optimizer import Optimizer
from paddle.fluid import core, framework, layers
from paddle.fluid.framework import Program, Variable
from paddle.fluid.layer_helper import LayerHelper
import paddle
import numpy as np
from paddle.fluid.dygraph import base as imperative_base
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
__all__ = ["ModelAverage"]
class ModelAverage(Optimizer):
r"""
The ModelAverage optimizer accumulates specific continuous historical
parameters during training. The accumulated historical range can be controlled
by the passed ``average_window_rate`` argument. The averaged ``Parameter`` are
used in the prediction, which usually can improve the accuracy of the prediction.
Accumulate the average of the ``Parameter`` in the sliding window, the result will be saved
in a temporary variable, can be applied to the current model's ``Parameter`` by calling
the ``apply()`` method, and the current model ``Parameter`` can be restored by calling
the ``restore()`` method.
The window size for calculating the average is determined by ``average_window_rate``,
``min_average_window``, ``max_average_window`` and the current ``Parameter`` update times (num_updates).
When the cumulative times (num_accumulates) is greater than the specific window
threshold (average_window), the accumulated ``Parameter`` temporary variable is set to 0.0.
The following example will help to understand the role of these arguments:
::
if num_accumulates >= min_average_window and num_accumulates >= min(max_average_window, num_updates * average_window_rate):
num_accumulates = 0
In the above conditional judgment statement, ``num_accumulates`` indicates the current
accumulated number, which can be abstractly understood as the length of the cumulative window.
The length of the window must be at least the length set by the ``min_average_window`` argument,
and cannot exceed the length specified by the ``max_average_window`` argument or
``num_updates * average_window_rate``, where ``num_updates`` indicates the current ``Parameter``
update times, ``average_window_rate`` is a coefficient that calculates the length of the window.
Args:
average_window_rate (float): The calculate ratio of the window length relative to ``Parameter`` update times.
parameters (list, optional): List of ``Tensor`` names to update to minimize ``loss``. \
This parameter is required in dygraph mode. \
The default value is None in static mode, at this time all parameters will be updated.
min_average_window (int, optional): the minimum size of average window length. The default value is 10000.
max_average_window (int, optional): The maximum size of average window length. The default value is 10000.
name (str, optional): Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`.
The default value is None.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
BATCH_SIZE = 16
BATCH_NUM = 4
EPOCH_NUM = 4
IMAGE_SIZE = 784
CLASS_NUM = 10
# define a random dataset
class RandomDataset(paddle.io.Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __getitem__(self, idx):
image = np.random.random([IMAGE_SIZE]).astype('float32')
label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
return image, label
def __len__(self):
return self.num_samples
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
self.bias = self._linear.bias
@paddle.jit.to_static
def forward(self, x):
return self._linear(x)
def train(layer, loader, loss_fn, opt, model_average):
for epoch_id in range(EPOCH_NUM):
for batch_id, (image, label) in enumerate(loader()):
out = layer(image)
loss = loss_fn(out, label)
loss.backward()
opt.step()
model_average.step()
opt.clear_grad()
model_average.clear_grad()
print("Train Epoch {} batch {}: loss = {}, bias = {}".format(
epoch_id, batch_id, np.mean(loss.numpy()), layer.bias.numpy()))
def evaluate(layer, loader, loss_fn):
for batch_id, (image, label) in enumerate(loader()):
out = layer(image)
loss = loss_fn(out, label)
loss.backward()
print("Evaluate batch {}: loss = {}, bias = {}".format(
batch_id, np.mean(loss.numpy()), layer.bias.numpy()))
# create network
layer = LinearNet()
loss_fn = nn.CrossEntropyLoss()
optimizer = opt.Momentum(learning_rate=0.2, momentum=0.1, parameters=layer.parameters())
model_average = paddle.incubate.optimizer.ModelAverage(0.15,
parameters=layer.parameters(),
min_average_window=2,
max_average_window=10)
# create data loader
dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
loader = paddle.io.DataLoader(dataset,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True,
num_workers=2)
# create data loader
eval_loader = paddle.io.DataLoader(dataset,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True,
num_workers=1)
# train
train(layer, loader, loss_fn, optimizer, model_average)
print("\nEvaluate With ModelAverage")
with model_average.apply(need_restore=False):
evaluate(layer, eval_loader, loss_fn)
print("\nEvaluate With Restored Paramters")
model_average.restore()
evaluate(layer, eval_loader, loss_fn)
"""
def __init__(self,
average_window_rate,
parameters=None,
min_average_window=10000,
max_average_window=10000,
name=None):
super(ModelAverage, self).__init__(
learning_rate=0.0,
parameters=parameters,
weight_decay=None,
grad_clip=None,
name=name)
self.helper = LayerHelper(self.__class__.__name__)
self.average_window = average_window_rate
self.min_average_window = min_average_window
self.max_average_window = max_average_window
self.type = "average_accumulates"
if not framework.in_dygraph_mode():
global_block = framework.default_main_program().global_block()
all_parameters = parameters if parameters else global_block.all_parameters(
)
self._create_accumulators(global_block, all_parameters)
for param in all_parameters:
self._append_optimize_op(global_block, [param, None])
self.apply_program = Program()
block = self.apply_program.global_block()
with framework.program_guard(main_program=self.apply_program):
for param in all_parameters:
self._add_average_apply_op(block, param)
self.restore_program = Program()
block = self.restore_program.global_block()
with framework.program_guard(main_program=self.restore_program):
for param in all_parameters:
self._add_average_restore_op(block, param)
def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block)
for param in parameters:
self._add_accumulator('sum_1', param)
self._add_accumulator('sum_2', param)
self._add_accumulator('sum_3', param)
self._add_accumulator('restore', param)
self._add_accumulator(
'num_accumulates', param, dtype='int64', shape=[1])
self._add_accumulator(
'old_num_accumulates', param, dtype='int64', shape=[1])
self._add_accumulator(
'num_updates', param, dtype='int64', shape=[1])
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
sum_1 = self._get_accumulator('sum_1', param_and_grad[0])
sum_2 = self._get_accumulator('sum_2', param_and_grad[0])
sum_3 = self._get_accumulator('sum_3', param_and_grad[0])
num_accumulates = self._get_accumulator('num_accumulates',
param_and_grad[0])
old_num_accumulates = self._get_accumulator('old_num_accumulates',
param_and_grad[0])
num_updates = self._get_accumulator('num_updates', param_and_grad[0])
if framework.in_dygraph_mode():
_, _, _, _, _, _ = core.ops.average_accumulates(
param_and_grad[0], sum_1, sum_2, sum_3, num_accumulates,
old_num_accumulates, num_updates, sum_1, sum_2, sum_3,
num_accumulates, old_num_accumulates, num_updates,
'average_window', self.average_window, 'min_average_window',
self.min_average_window, 'max_average_window',
self.max_average_window)
return None
block = framework.default_main_program().global_block()
attrs = {
"average_window": self.average_window,
"min_average_window": self.min_average_window,
"max_average_window": self.max_average_window,
}
inputs = {
"param": param_and_grad[0],
"in_sum_1": sum_1,
"in_sum_2": sum_2,
"in_sum_3": sum_3,
"in_num_accumulates": num_accumulates,
"in_old_num_accumulates": old_num_accumulates,
"in_num_updates": num_updates
}
outputs = {
"out_sum_1": sum_1,
"out_sum_2": sum_2,
"out_sum_3": sum_3,
"out_num_accumulates": num_accumulates,
"out_old_num_accumulates": old_num_accumulates,
"out_num_updates": num_updates,
}
average_accumulates_op = block.append_op(
type=self.type,
inputs=inputs,
outputs=outputs,
attrs=attrs,
stop_gradient=True)
return average_accumulates_op
@imperative_base.no_grad
def minimize(self,
loss,
startup_program=None,
parameters=None,
no_grad_set=None):
"""
Add operations to minimize ``loss`` by updating ``parameters``.
Args:
loss (Tensor): A ``Tensor`` containing the value to minimize.
startup_program (Program, optional): :ref:`api_fluid_Program` for
initializing parameters in ``parameters``. The default value
is None, at this time :ref:`api_fluid_default_startup_program` will be used.
parameters (list, optional): List of ``Tensor`` or ``Tensor.name`` to update
to minimize ``loss``. The default value is None, at this time all parameters
will be updated.
no_grad_set (set, optional): Set of ``Tensor`` or ``Tensor.name`` that don't need
to be updated. The default value is None.
Returns:
tuple: tuple (optimize_ops, params_grads), A list of operators appended
by minimize and a list of (param, grad) tensor pairs, param is
``Parameter``, grad is the gradient value corresponding to the parameter.
In static graph mode, the returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
indicate program pruning. If so, the program will be pruned by ``feed`` and
``fetch_list`` before run, see details in ``Executor``.
Examples:
.. code-block:: python
import paddle
import numpy as np
inp = paddle.to_tensor(np.random.random([1, 10]).astype('float32'))
linear = paddle.nn.Linear(10, 1)
out = linear(inp)
loss = paddle.mean(out)
loss.backward()
sgd = paddle.optimizer.SGD(learning_rate=0.1,parameters=linear.parameters())
sgd.minimize(loss)
modelaverage = paddle.incubate.optimizer.ModelAverage(0.15,
parameters=linear.parameters(),
min_average_window=2,
max_average_window=4)
modelaverage.minimize(loss)
sgd.clear_grad()
modelaverage.clear_grad()
"""
if framework.in_dygraph_mode():
self.step()
@framework.dygraph_only
@imperative_base.no_grad
def step(self):
"""
Execute the optimizer and update parameters once.
Returns:
None
Examples:
.. code-block:: python
import paddle
import numpy as np
inp = paddle.to_tensor(np.random.random([1, 10]).astype('float32'))
linear = paddle.nn.Linear(10, 1)
out = linear(inp)
loss = paddle.mean(out)
sgd = paddle.optimizer.SGD(learning_rate=0.1,parameters=linear.parameters())
modelaverage = paddle.incubate.optimizer.ModelAverage(0.15,
parameters=linear.parameters(),
min_average_window=2,
max_average_window=4)
loss.backward()
sgd.step()
modelaverage.step()
sgd.clear_grad()
modelaverage.clear_grad()
"""
params_grads = []
for param in self._parameter_list:
if not param.trainable:
continue
if param._grad_ivar() is not None:
grad_var = param._grad_ivar()
params_grads.append((param, grad_var))
block = framework.default_main_program().global_block()
self._create_accumulators(block, self._parameter_list)
for param_and_grad in params_grads:
self._append_optimize_op(block, param_and_grad)
@signature_safe_contextmanager
@imperative_base.no_grad
def apply(self, executor=None, need_restore=True):
"""
Apply the average of the cumulative ``Parameter`` to the parameters of the current model.
Args:
executor(Executor): The network executor in static-graph mode. The default value is None in dygraph mode.
need_restore(bool): Restore flag variable, if set to True, the network will restore
the parameters of the network to the default value, if set to False,
it will not be restored. The default value is True.
Examples:
.. code-block:: python
import paddle
import numpy as np
inp = paddle.to_tensor(np.random.random([1, 10]).astype('float32'))
linear = paddle.nn.Linear(10, 1)
out = linear(inp)
loss = paddle.mean(out)
loss.backward()
sgd = paddle.optimizer.SGD(learning_rate=0.1,parameters=linear.parameters())
modelaverage = paddle.incubate.optimizer.ModelAverage(0.15,
parameters=linear.parameters(),
min_average_window=2,
max_average_window=4)
sgd.step()
modelaverage.step()
with modelaverage.apply():
for param in linear.parameters():
print(param)
for param in linear.parameters():
print(param)
"""
if framework.in_dygraph_mode():
for param in self._parameter_list:
num_accumulates = self._get_accumulator('num_accumulates',
param)
old_num_accumulates = self._get_accumulator(
'old_num_accumulates', param)
num_updates = self._get_accumulator('num_updates', param)
sum_1 = self._get_accumulator('sum_1', param)
sum_2 = self._get_accumulator('sum_2', param)
sum_3 = self._get_accumulator('sum_3', param)
param_restore = self._get_accumulator('restore', param)
paddle.assign(param, param_restore)
total_param = sum_1 + sum_2 + sum_3
total_accumulates = num_accumulates + old_num_accumulates
total_param = paddle.cast(total_param, dtype='float32')
total_accumulates = paddle.cast(
total_accumulates, dtype='float32')
average_param = total_param / total_accumulates
paddle.assign(average_param, param)
try:
yield
finally:
if need_restore:
self.restore()
return
if executor is None:
raise RuntimeError(
"Executor should not be None in static graph mode.")
executor.run(self.apply_program)
try:
yield
finally:
if need_restore:
self.restore(executor)
@imperative_base.no_grad
def restore(self, executor=None):
"""
Restore ``Parameter`` values of current model.
Args:
executor(Executor): The network executor in static-graph mode. The default value is None in dygraph mode
Examples:
.. code-block:: python
import paddle
import numpy as np
inp = paddle.to_tensor(np.random.random([1, 10]).astype('float32'))
linear = paddle.nn.Linear(10, 1)
out = linear(inp)
loss = paddle.mean(out)
loss.backward()
sgd = paddle.optimizer.SGD(learning_rate=0.1,parameters=linear.parameters())
modelaverage = paddle.incubate.optimizer.ModelAverage(0.15,
parameters=linear.parameters(),
min_average_window=2,
max_average_window=4)
sgd.step()
modelaverage.step()
with modelaverage.apply(need_restore=False):
for param in linear.parameters():
print(param)
for param in linear.parameters():
print(param)
modelaverage.restore()
for param in linear.parameters():
print(param)
"""
if framework.in_dygraph_mode():
for param in self._parameter_list:
param_restore = self._get_accumulator('restore', param)
paddle.assign(param_restore, param)
return
if executor is None:
raise RuntimeError(
"Executor should not be None in static graph mode.")
executor.run(self.restore_program)
def _add_average_apply_op(self, block, param):
param = block._clone_variable(param)
grad = block._clone_variable(self._get_accumulator('restore', param))
sum_1 = block._clone_variable(self._get_accumulator('sum_1', param))
sum_2 = block._clone_variable(self._get_accumulator('sum_2', param))
sum_3 = block._clone_variable(self._get_accumulator('sum_3', param))
num_accumulates = block._clone_variable(
self._get_accumulator('num_accumulates', param))
old_num_accumulates = block._clone_variable(
self._get_accumulator('old_num_accumulates', param))
num_updates = block._clone_variable(
self._get_accumulator('num_updates', param))
# backup param value to grad
layers.assign(input=param, output=grad)
# param = (sum_1 + sum_2 + sum_3) / (num_accumulates + old_num_accumulates)
tmp = layers.sum(x=[num_accumulates, old_num_accumulates])
sum = layers.sum(x=[sum_1, sum_2, sum_3])
tmp = layers.cast(
x=tmp, dtype='float32' if self._dtype == None else self._dtype)
sum = layers.cast(
x=sum, dtype='float32' if self._dtype == None else self._dtype)
layers.ops._elementwise_div(x=sum, y=tmp, out=param)
def _add_average_restore_op(self, block, param):
param = block._clone_variable(param)
grad = block._clone_variable(self._get_accumulator('restore', param))
layers.assign(input=grad, output=param)
......@@ -143,6 +143,7 @@ packages=['paddle',
'paddle.reader',
'paddle.distributed',
'paddle.incubate',
'paddle.incubate.optimizer',
'paddle.distributed.fleet',
'paddle.distributed.fleet.base',
'paddle.distributed.fleet.meta_optimizers',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册