提交 fbcdb29d 编写于 作者: Q Qiao Longfei

fix import issue

上级 866d6bfe
...@@ -18,7 +18,7 @@ from collections import defaultdict ...@@ -18,7 +18,7 @@ from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program from paddle.fluid.framework import Program, Variable, name_scope, default_main_program
import paddle.fluid.transpiler.details.distribute_lookuptable_utils as distribute_lookuptable_utils from paddle.fluid.transpiler.details.distribute_lookuptable_utils import find_distributed_lookup_table
from . import framework from . import framework
from . import layers from . import layers
...@@ -40,6 +40,30 @@ __all__ = [ ...@@ -40,6 +40,30 @@ __all__ = [
] ]
def _process_distribute_lookuptable(program, param_grads, learning_rate):
table_name = find_distributed_lookup_table(program)
table_param = None
table_grad = None
new_param_grads = []
for p, g in param_grads:
if p.name == table_name:
if table_param is not None:
raise RuntimeError(
"multi dist table var found, only support one now!")
table_param = p
table_grad = g
else:
new_param_grads.append((p, g))
sgd_op = None
if table_param is not None:
with table_param.block.program._optimized_guard(
[table_param, table_grad]), framework.name_scope("optimizer"):
sgd_optimizer = SGD(learning_rate)
sgd_op = sgd_optimizer._append_optimize_op(table_param.block, (
table_param, table_grad))
return new_param_grads, (table_param, table_grad), sgd_op
class Optimizer(object): class Optimizer(object):
"""Optimizer Base class. """Optimizer Base class.
...@@ -263,7 +287,7 @@ class Optimizer(object): ...@@ -263,7 +287,7 @@ class Optimizer(object):
params_grads = sorted(params_grads, key=lambda x: x[0].name) params_grads = sorted(params_grads, key=lambda x: x[0].name)
params_grads, table_param_and_grad, table_optimize_op = \ params_grads, table_param_and_grad, table_optimize_op = \
distribute_lookuptable_utils.process_distribute_lookuptable(loss.block.program, params_grads, self._learning_rate) _process_distribute_lookuptable(loss.block.program, params_grads, self._learning_rate)
params_grads = append_gradient_clip_ops(params_grads) params_grads = append_gradient_clip_ops(params_grads)
...@@ -273,6 +297,7 @@ class Optimizer(object): ...@@ -273,6 +297,7 @@ class Optimizer(object):
optimize_ops = self._create_optimization_pass(params_grads, loss, optimize_ops = self._create_optimization_pass(params_grads, loss,
startup_program) startup_program)
if table_optimize_op is not None:
optimize_ops.append(table_optimize_op) optimize_ops.append(table_optimize_op)
params_grads.append(table_param_and_grad) params_grads.append(table_param_and_grad)
return optimize_ops, params_grads return optimize_ops, params_grads
......
...@@ -40,27 +40,3 @@ def find_distributed_lookup_table(program): ...@@ -40,27 +40,3 @@ def find_distributed_lookup_table(program):
assert op.input("W")[0] != table_name assert op.input("W")[0] != table_name
return table_name return table_name
def process_distribute_lookuptable(program, param_grads, learning_rate):
table_name = find_distributed_lookup_table(program)
table_param = None
table_grad = None
new_param_grads = []
for p, g in param_grads:
if p.name == table_name:
if table_param is not None:
raise RuntimeError(
"multi dist table var found, only support one now!")
table_param = p
table_grad = g
else:
new_param_grads.append((p, g))
sgd_op = None
if table_param is not None:
with table_param.block.program._optimized_guard(
[table_param, table_grad]), framework.name_scope("optimizer"):
sgd_optimizer = optimizer.SGD(learning_rate)
sgd_op = sgd_optimizer._append_optimize_op(table_param.block, (
table_param, table_grad))
return new_param_grads, (table_param, table_grad), sgd_op
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册