Commit 7c880522 authored by Yancey1989

Optimize return all opt ops

Parent ca5dc46a
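With this change, create_optimization_pass() (and therefore minimize()) no longer assembles its return value from optimize_ops plus whatever _finish_update() hands back; it records the length of the global block before and after the pass and returns the slice of ops appended in between. An illustrative sketch of what a caller now sees, assuming the SGD-with-global-step setup from the updated test further down (sgd_optimizer, mean_out, init_program are defined there):

# Illustrative only: the names below come from the updated test further down.
opts, _ = sgd_optimizer.minimize(mean_out, init_program)
# The returned slice now covers every op appended during the pass, e.g.:
print([op.type for op in opts])
# ['fill_constant', 'elementwise_mul', 'sgd', 'increment']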
@@ -740,6 +740,9 @@ class Block(object):
                 raise e
         self.desc.remove_op(start, end + 1)
 
+    def slice_ops(self, start, end):
+        return list(self.ops)[start:end]
+
     def prepend_op(self, *args, **kwargs):
         op_desc = self.desc.prepend_op()
         op = Operator(self, op_desc, *args, **kwargs)
...
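slice_ops(start, end) is just a list slice over the block's ops, so a caller can note the op count before appending ops and then slice out exactly the ones added afterwards. A minimal stand-alone sketch of that pattern (plain Python lists stand in for Block.ops here; the op type names are taken from the tests below):

ops = ["mul", "mean"]                                # ops already in the block

start = len(ops)                                     # index where new ops will begin
ops += ["fill_constant", "elementwise_mul", "sgd"]   # ops appended by the pass
end = len(ops)

new_ops = ops[start:end]                             # what slice_ops(start, end) returns
assert new_ops == ["fill_constant", "elementwise_mul", "sgd"]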
@@ -190,6 +190,8 @@ class Optimizer(object):
         # Create any accumulators
         program = loss.block.program
         with program_guard(program, startup_program):
+            global_block = framework.default_main_program().global_block()
+            start = len(global_block.ops)
             self.helper = LayerHelper(self.__class__.__name__)
             self._create_accumulators(loss.block,
                                       [p[0] for p in parameters_and_grads])
@@ -203,19 +205,14 @@ class Optimizer(object):
                                                            param_and_grad)
                     optimize_ops.append(optimize_op)
 
-            # Returned list of ops can include more ops in addition
-            # to optimization ops
-            return_ops = optimize_ops
-
             # Get custom finish ops for subclasses
             # FIXME: Need to fix this once we figure out how to handle dependencies
-            finish_ops = self._finish_update(loss.block)
-            if finish_ops is not None:
-                return_ops += finish_ops
+            self._finish_update(loss.block)
 
             if self._global_step is not None:
-                return_ops.append(self._increment_global_step(loss.block))
-            return return_ops
+                self._increment_global_step(loss.block)
+            end = len(global_block.ops)
+            return global_block.slice_ops(start, end)
 
     def minimize(self,
                  loss,
...
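Because the return value is now a slice of the global block, ops that were previously dropped, apparently the fill_constant and elementwise_mul ops that set up the learning rate, are included as well, and _finish_update() / _increment_global_step() no longer need to report their ops back to the caller. The updated tests below reflect this; for the Adam case, the difference looks roughly like this (an illustrative comparison, not code from the patch):

# Return value of create_optimization_pass() for the Adam test, before and after:
old_return = ["adam", "scale", "scale"]              # optimize op + finish ops only
new_return = ["fill_constant", "elementwise_mul",    # learning-rate setup ops,
              "adam", "scale", "scale"]              # now captured by the slice
assert len(old_return) == 3 and len(new_return) == 5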
@@ -42,9 +42,9 @@ class TestOptimizer(unittest.TestCase):
             type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
         sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.01)
         opts, _ = sgd_optimizer.minimize(mean_out, init_program)
-        self.assertEqual(len(opts), 1)
-        sgd_op = opts[0]
-        self.assertEqual(sgd_op.type, "sgd")
+        self.assertEqual(len(opts), 3)
+        self.assertEqual([op.type for op in opts],
+                         ["fill_constant", "elementwise_mul", "sgd"])
 
     def test_sgd_optimizer_with_global_step(self):
         init_program = framework.Program()
@@ -72,11 +72,10 @@ class TestOptimizer(unittest.TestCase):
         sgd_optimizer = optimizer.SGDOptimizer(
             learning_rate=learning_rate, global_step=global_step)
         opts, _ = sgd_optimizer.minimize(mean_out, init_program)
-        self.assertEqual(len(opts), 2)
-        sgd_op = opts[0]
-        self.assertEqual(sgd_op.type, "sgd")
-        increment_op = opts[1]
-        self.assertEqual(increment_op.type, "increment")
+        self.assertEqual(len(opts), 4)
+        self.assertEqual(
+            [op.type for op in opts],
+            ["fill_constant", "elementwise_mul", "sgd", "increment"])
 
         # Check init_program
         init_ops = init_program.global_block().ops
@@ -121,9 +120,10 @@ class TestMomentumOptimizer(unittest.TestCase):
         self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
         opts = momentum_optimizer.create_optimization_pass(
             params_grads, mul_out, init_program)
-        self.assertEqual(len(opts), 1)
-        sgd_op = opts[0]
-        self.assertEqual(sgd_op.type, "momentum")
+        self.assertEqual(len(opts), 3)
+        sgd_op = opts[-1]
+        self.assertEqual([op.type for op in opts],
+                         ["fill_constant", "elementwise_mul", "momentum"])
         self.assertFalse(sgd_op.attr('use_nesterov'))
 
         # Check accumulators
@@ -170,9 +170,10 @@ class TestMomentumOptimizer(unittest.TestCase):
         self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
         opts = momentum_optimizer.create_optimization_pass(
             params_grads, mul_out, init_program)
-        self.assertEqual(len(opts), 1)
-        sgd_op = opts[0]
-        self.assertEqual(sgd_op.type, "momentum")
+        self.assertEqual(len(opts), 3)
+        sgd_op = opts[-1]
+        self.assertEqual([op.type for op in opts],
+                         ["fill_constant", "elementwise_mul", "momentum"])
         self.assertTrue(sgd_op.attr('use_nesterov'))
 
         # Check accumulators
@@ -228,9 +229,9 @@ class TestAdagradOptimizer(unittest.TestCase):
         self.assertEqual(len(adagrad_optimizer.get_accumulators()), 0)
         opts = adagrad_optimizer.create_optimization_pass(params_grads, mul_out,
                                                           init_program)
-        self.assertEqual(len(opts), 1)
-        adagrad_op = opts[0]
-        self.assertEqual(adagrad_op.type, "adagrad")
+        self.assertEqual(len(opts), 3)
+        self.assertEqual([op.type for op in opts],
+                         ["fill_constant", "elementwise_mul", "adagrad"])
 
         # Check accumulators
         accumulators = adagrad_optimizer.get_accumulators()
@@ -288,9 +289,10 @@ class TestAdamOptimizer(unittest.TestCase):
         self.assertEqual(len(adam_optimizer.get_accumulators()), 0)
         opts = adam_optimizer.create_optimization_pass(params_grads, mul_out,
                                                        init_program)
-        self.assertEqual(len(opts), 3)
-        adam_op = opts[0]
-        self.assertEqual(adam_op.type, "adam")
+        self.assertEqual(len(opts), 5)
+        self.assertEqual(
+            [op.type for op in opts],
+            ["fill_constant", "elementwise_mul", "adam", "scale", "scale"])
 
         # Check accumulators
         accumulators = adam_optimizer.get_accumulators()
@@ -350,9 +352,10 @@ class TestAdamaxOptimizer(unittest.TestCase):
         self.assertEqual(len(adamax_optimizer.get_accumulators()), 0)
         opts = adamax_optimizer.create_optimization_pass(params_grads, mul_out,
                                                          init_program)
-        self.assertEqual(len(opts), 2)
-        adam_op = opts[0]
-        self.assertEqual(adam_op.type, "adamax")
+        self.assertEqual(len(opts), 4)
+        self.assertEqual(
+            [op.type for op in opts],
+            ["fill_constant", "elementwise_mul", "adamax", "scale"])
 
         # Check accumulators
         accumulators = adamax_optimizer.get_accumulators()
@@ -409,9 +412,10 @@ class TestDecayedAdagradOptimizer(unittest.TestCase):
         self.assertEqual(len(decayed_adagrad_optimizer.get_accumulators()), 0)
         opts = decayed_adagrad_optimizer.create_optimization_pass(
             params_grads, mul_out, init_program)
-        self.assertEqual(len(opts), 1)
-        decayed_adagrad_op = opts[0]
-        self.assertEqual(decayed_adagrad_op.type, "decayed_adagrad")
+        self.assertEqual(len(opts), 3)
+        self.assertEqual(
+            [op.type for op in opts],
+            ["fill_constant", "elementwise_mul", "decayed_adagrad"])
 
         # Check accumulators
         accumulators = decayed_adagrad_optimizer.get_accumulators()
...