未验证 提交 2ce51d13 编写于 作者: Y Yancey 提交者: GitHub

Merge pull request #8280 from Yancey1989/return_all_opt_ops

Optimize return all optimize and related ops
......@@ -740,6 +740,9 @@ class Block(object):
raise e
self.desc.remove_op(start, end + 1)
def slice_ops(self, start, end):
return list(self.ops)[start:end]
def prepend_op(self, *args, **kwargs):
op_desc = self.desc.prepend_op()
op = Operator(self, op_desc, *args, **kwargs)
......
......@@ -190,6 +190,8 @@ class Optimizer(object):
# Create any accumulators
program = loss.block.program
with program_guard(program, startup_program):
global_block = framework.default_main_program().global_block()
start = len(global_block.ops)
self.helper = LayerHelper(self.__class__.__name__)
self._create_accumulators(loss.block,
[p[0] for p in parameters_and_grads])
......@@ -203,19 +205,14 @@ class Optimizer(object):
param_and_grad)
optimize_ops.append(optimize_op)
# Returned list of ops can include more ops in addition
# to optimization ops
return_ops = optimize_ops
# Get custom finish ops for subclasses
# FIXME: Need to fix this once we figure out how to handle dependencies
finish_ops = self._finish_update(loss.block)
if finish_ops is not None:
return_ops += finish_ops
self._finish_update(loss.block)
if self._global_step is not None:
return_ops.append(self._increment_global_step(loss.block))
return return_ops
self._increment_global_step(loss.block)
end = len(global_block.ops)
return global_block.slice_ops(start, end)
def minimize(self,
loss,
......
......@@ -42,9 +42,9 @@ class TestOptimizer(unittest.TestCase):
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.01)
opts, _ = sgd_optimizer.minimize(mean_out, init_program)
self.assertEqual(len(opts), 1)
sgd_op = opts[0]
self.assertEqual(sgd_op.type, "sgd")
self.assertEqual(len(opts), 3)
self.assertEqual([op.type for op in opts],
["fill_constant", "elementwise_mul", "sgd"])
def test_sgd_optimizer_with_global_step(self):
init_program = framework.Program()
......@@ -72,11 +72,10 @@ class TestOptimizer(unittest.TestCase):
sgd_optimizer = optimizer.SGDOptimizer(
learning_rate=learning_rate, global_step=global_step)
opts, _ = sgd_optimizer.minimize(mean_out, init_program)
self.assertEqual(len(opts), 2)
sgd_op = opts[0]
self.assertEqual(sgd_op.type, "sgd")
increment_op = opts[1]
self.assertEqual(increment_op.type, "increment")
self.assertEqual(len(opts), 4)
self.assertEqual(
[op.type for op in opts],
["fill_constant", "elementwise_mul", "sgd", "increment"])
# Check init_program
init_ops = init_program.global_block().ops
......@@ -121,9 +120,10 @@ class TestMomentumOptimizer(unittest.TestCase):
self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
opts = momentum_optimizer.create_optimization_pass(
params_grads, mul_out, init_program)
self.assertEqual(len(opts), 1)
sgd_op = opts[0]
self.assertEqual(sgd_op.type, "momentum")
self.assertEqual(len(opts), 3)
sgd_op = opts[-1]
self.assertEqual([op.type for op in opts],
["fill_constant", "elementwise_mul", "momentum"])
self.assertFalse(sgd_op.attr('use_nesterov'))
# Check accumulators
......@@ -170,9 +170,10 @@ class TestMomentumOptimizer(unittest.TestCase):
self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
opts = momentum_optimizer.create_optimization_pass(
params_grads, mul_out, init_program)
self.assertEqual(len(opts), 1)
sgd_op = opts[0]
self.assertEqual(sgd_op.type, "momentum")
self.assertEqual(len(opts), 3)
sgd_op = opts[-1]
self.assertEqual([op.type for op in opts],
["fill_constant", "elementwise_mul", "momentum"])
self.assertTrue(sgd_op.attr('use_nesterov'))
# Check accumulators
......@@ -228,9 +229,9 @@ class TestAdagradOptimizer(unittest.TestCase):
self.assertEqual(len(adagrad_optimizer.get_accumulators()), 0)
opts = adagrad_optimizer.create_optimization_pass(params_grads, mul_out,
init_program)
self.assertEqual(len(opts), 1)
adagrad_op = opts[0]
self.assertEqual(adagrad_op.type, "adagrad")
self.assertEqual(len(opts), 3)
self.assertEqual([op.type for op in opts],
["fill_constant", "elementwise_mul", "adagrad"])
# Check accumulators
accumulators = adagrad_optimizer.get_accumulators()
......@@ -288,9 +289,10 @@ class TestAdamOptimizer(unittest.TestCase):
self.assertEqual(len(adam_optimizer.get_accumulators()), 0)
opts = adam_optimizer.create_optimization_pass(params_grads, mul_out,
init_program)
self.assertEqual(len(opts), 3)
adam_op = opts[0]
self.assertEqual(adam_op.type, "adam")
self.assertEqual(len(opts), 5)
self.assertEqual(
[op.type for op in opts],
["fill_constant", "elementwise_mul", "adam", "scale", "scale"])
# Check accumulators
accumulators = adam_optimizer.get_accumulators()
......@@ -350,9 +352,10 @@ class TestAdamaxOptimizer(unittest.TestCase):
self.assertEqual(len(adamax_optimizer.get_accumulators()), 0)
opts = adamax_optimizer.create_optimization_pass(params_grads, mul_out,
init_program)
self.assertEqual(len(opts), 2)
adam_op = opts[0]
self.assertEqual(adam_op.type, "adamax")
self.assertEqual(len(opts), 4)
self.assertEqual(
[op.type for op in opts],
["fill_constant", "elementwise_mul", "adamax", "scale"])
# Check accumulators
accumulators = adamax_optimizer.get_accumulators()
......@@ -409,9 +412,10 @@ class TestDecayedAdagradOptimizer(unittest.TestCase):
self.assertEqual(len(decayed_adagrad_optimizer.get_accumulators()), 0)
opts = decayed_adagrad_optimizer.create_optimization_pass(
params_grads, mul_out, init_program)
self.assertEqual(len(opts), 1)
decayed_adagrad_op = opts[0]
self.assertEqual(decayed_adagrad_op.type, "decayed_adagrad")
self.assertEqual(len(opts), 3)
self.assertEqual(
[op.type for op in opts],
["fill_constant", "elementwise_mul", "decayed_adagrad"])
# Check accumulators
accumulators = decayed_adagrad_optimizer.get_accumulators()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册