未验证 提交 679a4c28 编写于 作者: W whs 提交者: GitHub

Fix lost of learning rate variable in distillatoin when using lr decay. (#16471)

test=develop
上级 57dc3c19
......@@ -13,7 +13,7 @@
# limitations under the License.
from ..core.strategy import Strategy
from ....framework import Program, program_guard
from ....framework import Program, Variable, program_guard
from .... import Executor
import logging
......@@ -74,8 +74,17 @@ class DistillationStrategy(Strategy):
startup_program = Program()
with program_guard(graph.program, startup_program):
context.distiller_optimizer._name = 'distillation_optimizer'
context.distiller_optimizer.minimize(
graph.var(graph.out_nodes['loss'])._var)
# The learning rate variable may be created in other program.
# Update information in optimizer to make
# learning rate variable being accessible in current program.
optimizer = context.distiller_optimizer
if isinstance(optimizer._learning_rate, Variable):
optimizer._learning_rate_map[
graph.program] = optimizer._learning_rate
optimizer.minimize(graph.var(graph.out_nodes['loss'])._var)
exe = Executor(context.place)
exe.run(startup_program, scope=context.scope)
......
......@@ -402,6 +402,12 @@ class GraphWrapper(object):
elif 'cost' in graph.out_nodes:
target_name = graph.out_nodes['cost']
target = graph.var(target_name)._var
# The learning rate variable may be created in other program.
# Update information in optimizer to make
# learning rate variable being accessible in current program.
if isinstance(optimizer._learning_rate, Variable):
optimizer._learning_rate_map[
graph.program] = optimizer._learning_rate
optimizer.minimize(target, no_grad_set=no_grad_var_names)
exe = Executor(place)
......
......@@ -41,9 +41,11 @@ class TestDistillationStrategy(unittest.TestCase):
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
optimizer = fluid.optimizer.Momentum(
momentum=0.9,
learning_rate=0.01,
learning_rate=fluid.layers.piecewise_decay(
boundaries=[5, 10], values=[0.01, 0.001, 0.0001]),
regularization=fluid.regularizer.L2Decay(4e-5))
place = fluid.CUDAPlace(0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册