diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index ec8bed45dc4833f47e67895a5e316f13bcae5261..e0ee9955b8cadbd329758f7f21e216859ddb8176 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -18,7 +18,7 @@ from collections import defaultdict
 from contextlib import contextmanager
 
 from paddle.fluid.framework import Program, Variable, name_scope, default_main_program
-import paddle.fluid.transpiler.details.distribute_lookuptable_utils as distribute_lookuptable_utils
+from paddle.fluid.transpiler.details.distribute_lookuptable_utils import find_distributed_lookup_table
 
 from . import framework
 from . import layers
@@ -40,6 +40,30 @@ __all__ = [
 ]
 
 
+def _process_distribute_lookuptable(program, param_grads, learning_rate):
+    table_name = find_distributed_lookup_table(program)
+    table_param = None
+    table_grad = None
+    new_param_grads = []
+    for p, g in param_grads:
+        if p.name == table_name:
+            if table_param is not None:
+                raise RuntimeError(
+                    "multi dist table var found, only support one now!")
+            table_param = p
+            table_grad = g
+        else:
+            new_param_grads.append((p, g))
+    sgd_op = None
+    if table_param is not None:
+        with table_param.block.program._optimized_guard(
+            [table_param, table_grad]), framework.name_scope("optimizer"):
+            sgd_optimizer = SGD(learning_rate)
+            sgd_op = sgd_optimizer._append_optimize_op(table_param.block, (
+                table_param, table_grad))
+    return new_param_grads, (table_param, table_grad), sgd_op
+
+
 class Optimizer(object):
     """Optimizer Base class.
 
@@ -263,7 +287,7 @@ class Optimizer(object):
         params_grads = sorted(params_grads, key=lambda x: x[0].name)
 
         params_grads, table_param_and_grad, table_optimize_op = \
-            distribute_lookuptable_utils.process_distribute_lookuptable(loss.block.program, params_grads, self._learning_rate)
+            _process_distribute_lookuptable(loss.block.program, params_grads, self._learning_rate)
 
         params_grads = append_gradient_clip_ops(params_grads)
 
@@ -273,8 +297,9 @@ class Optimizer(object):
 
         optimize_ops = self._create_optimization_pass(params_grads, loss,
                                                       startup_program)
-        optimize_ops.append(table_optimize_op)
-        params_grads.append(table_param_and_grad)
+        if table_optimize_op is not None:
+            optimize_ops.append(table_optimize_op)
+            params_grads.append(table_param_and_grad)
 
         return optimize_ops, params_grads
 
diff --git a/python/paddle/fluid/transpiler/details/distribute_lookuptable_utils.py b/python/paddle/fluid/transpiler/details/distribute_lookuptable_utils.py
index ab1b551a2eeb4af6a09748816476e679b55052a9..bc4a9e7a4e9df9c778e560d58aa5c6ff70165710 100644
--- a/python/paddle/fluid/transpiler/details/distribute_lookuptable_utils.py
+++ b/python/paddle/fluid/transpiler/details/distribute_lookuptable_utils.py
@@ -40,27 +40,3 @@ def find_distributed_lookup_table(program):
             assert op.input("W")[0] != table_name
 
     return table_name
-
-
-def process_distribute_lookuptable(program, param_grads, learning_rate):
-    table_name = find_distributed_lookup_table(program)
-    table_param = None
-    table_grad = None
-    new_param_grads = []
-    for p, g in param_grads:
-        if p.name == table_name:
-            if table_param is not None:
-                raise RuntimeError(
-                    "multi dist table var found, only support one now!")
-            table_param = p
-            table_grad = g
-        else:
-            new_param_grads.append((p, g))
-    sgd_op = None
-    if table_param is not None:
-        with table_param.block.program._optimized_guard(
-            [table_param, table_grad]), framework.name_scope("optimizer"):
-            sgd_optimizer = optimizer.SGD(learning_rate)
-            sgd_op = sgd_optimizer._append_optimize_op(table_param.block, (
-                table_param, table_grad))
-    return new_param_grads, (table_param, table_grad), sgd_op
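Note: the helper moved into optimizer.py keeps the same partitioning logic as the removed process_distribute_lookuptable; only its home module and the way SGD is referenced change. The snippet below is a minimal, self-contained sketch of that partitioning pattern (split out the single (param, grad) pair whose parameter name matches the distributed lookup table, reject duplicates). Param and split_table_param are hypothetical stand-ins introduced only for this illustration and are not part of Paddle.

    # Standalone sketch, not Paddle code: mirrors the split performed by
    # _process_distribute_lookuptable in the diff above.
    from collections import namedtuple

    # Hypothetical stand-in for a framework Variable that carries a name.
    Param = namedtuple("Param", ["name"])


    def split_table_param(param_grads, table_name):
        table_pair = None
        regular = []
        for p, g in param_grads:
            if p.name == table_name:
                if table_pair is not None:
                    # mirrors the RuntimeError raised in the diff above
                    raise RuntimeError(
                        "multi dist table var found, only support one now!")
                table_pair = (p, g)
            else:
                regular.append((p, g))
        return regular, table_pair


    if __name__ == "__main__":
        # Example parameter/gradient names are made up for this sketch.
        pairs = [(Param("fc_0.w_0"), Param("fc_0.w_0@GRAD")),
                 (Param("dist_table.w_0"), Param("dist_table.w_0@GRAD"))]
        regular, table_pair = split_table_param(pairs, "dist_table.w_0")
        assert len(regular) == 1
        assert table_pair[0].name == "dist_table.w_0"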