Unverified · Commit 434f7b9c authored by Yibing Liu, committed by GitHub

Fix the global_step & continuous applying error in EMA (#22090)

* Fix the global_step & continuous applying error in EMA

test=develop

* Fix for step 0 & add unit test, test=develop
Parent 5de6a191
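Note (not part of the commit): the diff below does two things. The apply program used to build its own step counter via learning_rate_scheduler._decay_step_counter(), so the count advanced every time apply() ran; it now reads a dedicated persistable counter, "@EMA_STEP_COUNTER@", that only update() increments. And because the bias correction divides by 1 - decay**step, a step count of 0 would divide by zero, so the corrected value is now assigned only inside a Switch case guarded by global_step > 0. A minimal NumPy sketch of that bias-corrected EMA (illustration only, not the Paddle implementation):

# Illustration only (plain NumPy, not the Paddle implementation): bias-corrected
# EMA as computed by the fixed apply program, with the step-0 guard.
import numpy as np

def bias_corrected_ema(values, decay=0.999):
    ema = np.zeros_like(values[0])
    step = 0
    for v in values:                        # what update() accumulates per batch
        ema = decay * ema + (1.0 - decay) * v
        step += 1
    if step == 0:                           # mirrors switch.case(global_step > 0)
        return ema                          # dividing by (1 - decay**0) == 0 would fail
    return ema / (1.0 - decay ** step)      # bias correction: ema / (1 - decay**step)

params = [np.random.rand(3, 4).astype('float32') for _ in range(6)]
print(bias_corrected_ema(params))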
@@ -3086,7 +3086,7 @@ class ExponentialMovingAverage(object):
             optimizer = fluid.optimizer.Adam(learning_rate=0.001)
             optimizer.minimize(cost)

-            global_steps = fluid.layers.learning_rate_scheduler._decay_step_counter()
+            global_steps = fluid.layers.autoincreased_step_counter()
             ema = fluid.optimizer.ExponentialMovingAverage(0.999, thres_steps=global_steps)
             ema.update()
@@ -3124,6 +3124,7 @@ class ExponentialMovingAverage(object):
         self._name = name if name is not None else ''
         self._decay_var = self._get_ema_decay()

+        self._step_counter_name = "@EMA_STEP_COUNTER@"
         self._params_tmps = []
         for param in default_main_program().global_block().all_parameters():
             if param.do_model_average != False:
@@ -3144,14 +3145,16 @@ class ExponentialMovingAverage(object):
         self.apply_program = Program()
         block = self.apply_program.global_block()
         with program_guard(main_program=self.apply_program):
-            decay_pow = self._get_decay_pow(block)
+            decay_pow, global_step = self._get_decay_pow(block)
             for param, tmp in self._params_tmps:
                 param = block._clone_variable(param)
                 tmp = block._clone_variable(tmp)
                 ema = block._clone_variable(self._ema_vars[param.name])
                 layers.assign(input=param, output=tmp)
                 # bias correction
-                ema = ema / (1.0 - decay_pow)
+                with layers.control_flow.Switch() as switch:
+                    with switch.case(global_step > 0):
+                        layers.assign(output=ema, input=ema / (1.0 - decay_pow))
                 layers.assign(input=ema, output=param)

         self.restore_program = Program()
@@ -3184,10 +3187,16 @@ class ExponentialMovingAverage(object):
         return decay_var

     def _get_decay_pow(self, block):
-        global_steps = layers.learning_rate_scheduler._decay_step_counter()
+        global_step = layers.create_global_var(
+            name=self._step_counter_name,
+            shape=[1],
+            value=0,
+            dtype='int64',
+            persistable=True)
+        global_step = layers.cast(global_step, "float32")
         decay_var = block._clone_variable(self._decay_var)
-        decay_pow_acc = layers.elementwise_pow(decay_var, global_steps + 1)
-        return decay_pow_acc
+        decay_pow_acc = layers.elementwise_pow(decay_var, global_step)
+        return decay_pow_acc, global_step

     def _create_ema_vars(self, param):
         param_ema = layers.create_global_var(
@@ -3204,6 +3213,8 @@ class ExponentialMovingAverage(object):
         Update Exponential Moving Average. Should only call this method in
         train program.
         """
+        global_step = layers.autoincreased_step_counter(
+            counter_name=self._step_counter_name)
         param_master_emas = []
         for param, tmp in self._params_tmps:
             with param.block.program._optimized_guard(
New file added by this commit: a unit test for ExponentialMovingAverage.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle.fluid as fluid


class TestExponentialMovingAverage(unittest.TestCase):
    def setUp(self):
        self._places = [fluid.CPUPlace()]
        if fluid.core.is_compiled_with_cuda():
            self._places.append(fluid.CUDAPlace(0))
        self._ema_decay = 0.999
        self._param_name = "fc.weight"

        self._train_program = fluid.Program()
        self._startup_prog = fluid.Program()
        with fluid.program_guard(self._train_program, self._startup_prog):
            with fluid.unique_name.guard():
                data = fluid.data(name='x', shape=[-1, 5], dtype='float32')
                hidden = fluid.layers.fc(input=data,
                                         size=10,
                                         param_attr=self._param_name)
                cost = fluid.layers.mean(hidden)

                self._test_program = fluid.default_main_program().clone(
                    for_test=True)

                optimizer = fluid.optimizer.Adam(learning_rate=0.001)
                optimizer.minimize(cost)

                self._ema = fluid.optimizer.ExponentialMovingAverage(
                    self._ema_decay)
                self._ema.update()

    def train(self, place):
        exe = fluid.Executor(place)
        exe.run(self._startup_prog)

        params = []
        for pass_id in range(2):
            for batch_id in range(3):
                data = np.random.random(size=(10, 5)).astype('float32')
                tmp_param = np.array(fluid.global_scope().find_var(
                    self._param_name).get_tensor())
                exe.run(program=self._train_program, feed={'x': data})
                tmp_param = np.array(fluid.global_scope().find_var(
                    self._param_name).get_tensor())
                params.append(tmp_param)

        with self._ema.apply(exe):
            final_ema = np.array(fluid.global_scope().find_var(
                self._param_name).get_tensor())
            data = np.random.random(size=(10, 5)).astype('float32')
            exe.run(program=self._test_program, feed={'x': data})
        return params, final_ema

    def test_check_ema(self):
        for place in self._places:
            params, final_ema = self.train(place)

            manu_ema = np.zeros_like(final_ema)
            if len(params) > 0:
                for param in params:
                    manu_ema = self._ema_decay * manu_ema + (
                        1 - self._ema_decay) * param
                manu_ema = manu_ema / (1.0 - self._ema_decay**len(params))
            self.assertTrue(np.allclose(manu_ema, final_ema))


if __name__ == '__main__':
    unittest.main()
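The test above exercises the fix end to end: it trains for six batches, records each post-step value of fc.weight, and asserts that the weight seen inside ema.apply(exe) equals the manually computed bias-corrected average. Separately, a minimal sketch of the shared-counter mechanism the fix relies on; the program layout here is illustrative, not taken from the commit:

# Sketch under stated assumptions: shows only the named persistable counter,
# not the EMA itself. autoincreased_step_counter prepends an increment op, so
# each run of train_prog advances "@EMA_STEP_COUNTER@" by one.
import numpy as np
import paddle.fluid as fluid

counter_name = "@EMA_STEP_COUNTER@"

train_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
    step = fluid.layers.autoincreased_step_counter(counter_name=counter_name)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_prog)
for _ in range(3):
    exe.run(train_prog)

# Persistable variables live in the global scope, so a separately built apply
# program (or this direct lookup) sees the accumulated count: here, [3].
count = np.array(fluid.global_scope().find_var(counter_name).get_tensor())
print(count)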