未验证 提交 3e2a3488 编写于 作者: M mapingshuo 提交者: GitHub

add string variable support for RecomputeOptimizer (#25728)

上级 0f623ad7
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
import six
import logging import logging
from collections import defaultdict from collections import defaultdict
...@@ -4554,6 +4555,17 @@ class RecomputeOptimizer(Optimizer): ...@@ -4554,6 +4555,17 @@ class RecomputeOptimizer(Optimizer):
self._learning_rate_map = self._optimizer._learning_rate_map self._learning_rate_map = self._optimizer._learning_rate_map
def _set_checkpoints(self, checkpoints): def _set_checkpoints(self, checkpoints):
"""
Args:
checkpoints (list): List of Variable or string
"""
assert isinstance(
checkpoints, list
), "_checkpoints should be a list of Variable or a list of String"
for ckpt in checkpoints:
assert (
isinstance(ckpt, six.string_types) or isinstance(ckpt, Variable)
), "_checkpoints should be a list of Variable or a list of String"
self._checkpoints = checkpoints self._checkpoints = checkpoints
def load(self, stat_dict): def load(self, stat_dict):
...@@ -4690,6 +4702,8 @@ class RecomputeOptimizer(Optimizer): ...@@ -4690,6 +4702,8 @@ class RecomputeOptimizer(Optimizer):
no_grad_set=None) no_grad_set=None)
print("Finished backward") print("Finished backward")
""" """
assert (self._checkpoints is not None
), "You should call _set_checkpoints first"
if framework.in_dygraph_mode(): if framework.in_dygraph_mode():
raise NotImplementedError( raise NotImplementedError(
...@@ -4698,11 +4712,15 @@ class RecomputeOptimizer(Optimizer): ...@@ -4698,11 +4712,15 @@ class RecomputeOptimizer(Optimizer):
self._dtype = loss.dtype self._dtype = loss.dtype
program = loss.block.program program = loss.block.program
with program_guard(program, startup_program): with program_guard(program, startup_program):
checkpoint_vars = []
for ckpt in self._checkpoints:
if isinstance(ckpt, Variable):
checkpoint_vars.append(ckpt)
else:
checkpoint_vars.append(loss.block.var(ckpt))
params_grads = append_backward( params_grads = append_backward(
loss, loss, parameter_list, no_grad_set, checkpoints=checkpoint_vars)
parameter_list,
no_grad_set,
checkpoints=self._checkpoints)
# Note: since we can't use all_reduce_op now, # Note: since we can't use all_reduce_op now,
# dgc_op should be the last op of one grad. # dgc_op should be the last op of one grad.
if hasattr(self._optimizer, "_append_dgc_ops"): if hasattr(self._optimizer, "_append_dgc_ops"):
......
...@@ -77,6 +77,7 @@ class TestDGCMomentumOptimizer(unittest.TestCase): ...@@ -77,6 +77,7 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
if use_recompute: if use_recompute:
dgc_momentum_optimizer = optimizer.RecomputeOptimizer( dgc_momentum_optimizer = optimizer.RecomputeOptimizer(
dgc_momentum_optimizer) dgc_momentum_optimizer)
dgc_momentum_optimizer._set_checkpoints([])
dgc_momentum_optimizer.get_accumulators = dgc_momentum_optimizer._optimizer.get_accumulators dgc_momentum_optimizer.get_accumulators = dgc_momentum_optimizer._optimizer.get_accumulators
dgc_momentum_optimizer.get_velocity_str = dgc_momentum_optimizer._optimizer.get_velocity_str dgc_momentum_optimizer.get_velocity_str = dgc_momentum_optimizer._optimizer.get_velocity_str
......
...@@ -714,6 +714,23 @@ class TestRecomputeOptimizer(unittest.TestCase): ...@@ -714,6 +714,23 @@ class TestRecomputeOptimizer(unittest.TestCase):
"elementwise_add_grad", "mul_grad", "sgd", "sgd", "sgd" "elementwise_add_grad", "mul_grad", "sgd", "sgd", "sgd"
]) ])
def test_str_checkpoints(self):
mul_out, b1_out, b2_out, mean_out = self.net()
self.assertEqual(len(mean_out.block.ops), 4)
self.assertEqual([op.type for op in mean_out.block.ops],
["mul", "elementwise_add", "elementwise_add", "mean"])
sgd_optimizer = optimizer.SGD(learning_rate=1.0)
recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer)
recompute_optimizer._set_checkpoints([b1_out.name])
opts, params_grads = recompute_optimizer.minimize(mean_out)
self.assertEqual(len(mean_out.block.ops), 13)
self.assertEqual([op.type for op in mean_out.block.ops], [
"mul", "elementwise_add", "elementwise_add", "mean",
"fill_constant", "mean_grad", "elementwise_add_grad", "mul",
"elementwise_add_grad", "mul_grad", "sgd", "sgd", "sgd"
])
def test_multi_checkpoint(self): def test_multi_checkpoint(self):
mul_out, b1_out, b2_out, mean_out = self.net() mul_out, b1_out, b2_out, mean_out = self.net()
self.assertEqual(len(mean_out.block.ops), 4) self.assertEqual(len(mean_out.block.ops), 4)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册