Commit fd3adf58 authored by: D dongdaxiang

add distributed optimizer factory

Parent f6128777
@@ -14,18 +14,21 @@
__all__ = ["DistributedAdam"]
import ps_pb2 as pslib
import paddle.fluid as fluid
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_inputs
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_outputs
from google.protobuf import text_format
from .node import DownpourWorker, DownpourServer
class DistributedOptimizerImplBase(object):
def __init__(self):
pass
def __init__(self, optimizer):
self.optimizer_ = optimizer
self.learning_rate_ = optimizer._learning_rate
self.regularization_ = optimizer.regularization
def minimize(self,
optimizer,
losses,
startup_program=None,
parameter_list=None,
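Assembled from the hunk above, the refactored base class reads roughly as follows; the collapsed tail of the minimize signature and its body are assumptions (a sketch, not taken from the diff):

class DistributedOptimizerImplBase(object):
    def __init__(self, optimizer):
        # keep the wrapped single-machine optimizer and mirror its
        # hyper-parameters for the distributed implementation
        self.optimizer_ = optimizer
        self.learning_rate_ = optimizer._learning_rate
        self.regularization_ = optimizer.regularization

    def minimize(self,
                 losses,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        # concrete subclasses such as DistributedAdam override this;
        # pass is a placeholder for the collapsed body
        pass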
@@ -34,11 +37,11 @@ class DistributedAdam(DistributedOptimizerImplBase):
class DistributedAdam(DistributedOptimizerImplBase):
def __init__(self):
def __init__(self, optimizer):
# todo(guru4elephant): add more optimizers here as argument
# todo(guru4elephant): make learning_rate as a variable
self.learning_rate_ = learning_rate
self.window_ = window
super(DistributedAdam, self).__init__(optimizer)
self.window_ = 1
self.type = "downpour"
self.data_norm_name = [
".batch_size", ".batch_square_sum", ".batch_sum",
@@ -46,8 +49,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
]
def minimize(self,
optimizer,
loss,
losses,
startup_program=None,
parameter_list=None,
no_grad_set=None):
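This mirrors the base-class change: the wrapped optimizer moves from minimize() into the constructor, and the first positional argument becomes a list of losses. A hypothetical call site after this commit might look like the sketch below (the toy network and variable names are placeholders, not taken from the diff; a real fleet program would also contain a distributed lookup table, which minimize() searches for in the following hunks):

import paddle.fluid as fluid

# toy network standing in for a real fleet program
x = fluid.layers.data(name="x", shape=[13], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
cost = fluid.layers.square_error_cost(input=fluid.layers.fc(input=x, size=1), label=y)
avg_cost = fluid.layers.mean(cost)

adam = fluid.optimizer.Adam(learning_rate=0.001)
dist_adam = DistributedAdam(adam)                    # optimizer now passed at construction time
opt_ops, grads_and_weights = dist_adam.minimize([avg_cost])  # losses is now a list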
@@ -64,8 +66,8 @@ class DistributedAdam(DistributedOptimizerImplBase):
Returns:
[optimize_ops, grads_and_weights]
"""
if not isinstance(loss, list):
loss = [loss]
if not isinstance(losses, list):
losses = [losses]
table_name = find_distributed_lookup_table(losses[0].block.program)
prefetch_slots = find_distributed_lookup_table_inputs(
@@ -92,8 +94,8 @@ class DistributedAdam(DistributedOptimizerImplBase):
program_config.pull_sparse_table_id.extend([sparse_table_index])
program_config.push_sparse_table_id.extend([sparse_table_index])
params_grads = sorted(
append_backward(losses[loss_index], parameter_list,
no_grad_set),
fluid.backward.append_backward(losses[loss_index],
parameter_list, no_grad_set),
key=lambda x: x[0].name)
param_grads_list.append(params_grads)
params = []
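Per the last hunk, gradients are produced with fluid.backward.append_backward and then sorted by parameter name, presumably so every trainer builds the same (param, grad) ordering. A minimal standalone sketch of that step, assuming a single loss variable named loss and default parameter_list / no_grad_set:

import paddle.fluid as fluid

# loss is assumed to be a loss Variable from the current program;
# append_backward returns a list of (parameter, gradient) pairs, and
# sorting by parameter name makes the order deterministic across workers.
params_grads = fluid.backward.append_backward(loss)
params_grads = sorted(params_grads, key=lambda x: x[0].name)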