Commit 37596000 authored by dongdaxiang

add doc string for downpour.py and distribute_lookup_table.py

Parent 854ee964
distribute_lookup_table.py
@@ -16,31 +16,51 @@ LOOKUP_TABLE_TYPE = "lookup_table"
def find_distributed_lookup_table_inputs(program, table_name):
"""
Find input variable of distribute lookup table in program.
We only support one distribute table now.
Args:
program(Program): given program, locate distributed lookup table
table_name(str): given table name that is found beforehand
Returns:
inputs
"""
    local_vars = program.current_block().vars
    inputs = []
    for op in program.global_block().ops:
        if op.type == LOOKUP_TABLE_TYPE:
            if table_name == op.input("W")[0]:
                inputs.extend([local_vars[name] for name in op.input("Ids")])
    return inputs

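A minimal usage sketch (not part of this commit; the program, the "emb_w" table name, and the layer calls are illustrative assumptions):

# Hypothetical sketch: build a program holding one distributed embedding,
# then collect the "Ids" inputs of its lookup_table op.
import paddle.fluid as fluid
from paddle.fluid.distribute_lookup_table import \
    find_distributed_lookup_table_inputs

program = fluid.Program()
with fluid.program_guard(program):
    ids = fluid.layers.data(name="ids", shape=[1], dtype="int64")
    emb = fluid.layers.embedding(
        input=ids,
        size=[100000, 64],
        is_distributed=True,  # assumption: marks the table as distributed
        param_attr=fluid.ParamAttr(name="emb_w"))

# "emb_w" plays the role of the table name found beforehand.
inputs = find_distributed_lookup_table_inputs(program, "emb_w")
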
def find_distributed_lookup_table_outputs(program, table_name):
"""
Find output variable of distribute lookup table in program.
We only support one distribute table now.
Args:
program(Program): given program, locate distributed lookup table
table_name(str): given table name that is found beforehand
Returns:
outputs
"""
    local_vars = program.current_block().vars
    outputs = []
    for op in program.global_block().ops:
        if op.type == LOOKUP_TABLE_TYPE:
            if table_name == op.input("W")[0]:
                outputs.extend([local_vars[name] for name in op.output("Out")])
    return outputs

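The outputs helper mirrors the inputs helper; continuing the same hypothetical sketch:

# Continuing the sketch: collect the "Out" variables of the same op.
outputs = find_distributed_lookup_table_outputs(program, "emb_w")
# DownpourSGD.minimize later consumes these as the prefetch slots and
# their corresponding embedding outputs when configuring sparse tables.
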
def find_distributed_lookup_table(program):
"""
Find distribute lookup table in program.
We only support one distribute table now.
:param program:
:return: table_name or None
Args:
program(Program): given program, locate distributed lookup table
Returns:
table_name or None
"""
table_name = None
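The rest of this function sits outside the diff context shown here; usage is a one-liner, again continuing the hypothetical sketch:

# Continuing the sketch: recover the single table's name, or None.
table_name = find_distributed_lookup_table(program)  # "emb_w" in the sketch
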
downpour.py
@@ -20,6 +20,7 @@ from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_inputs
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_outputs
from google.protobuf import text_format


class DownpourSGD(object):
"""
Distributed optimizer of downpour stochastic gradient descent
@@ -35,17 +36,38 @@ class DownpourSGD(object):
        downpour_sgd = fluid.distributed.DownpourSGD(learning_rate=0.2)
        downpour_sgd.minimize(cost)
    """
    def __init__(self, learning_rate=0.001, window=1):
        # TODO(guru4elephant): add more optimizers here as arguments
        # TODO(guru4elephant): make learning_rate a variable
        self.learning_rate_ = learning_rate
        self.window_ = window
        self.type = "downpour"
def minimize(self, loss, startup_program=None,
parameter_list=None, no_grad_set=None):
params_grads = sorted(append_backward(
loss, parameter_list, no_grad_set), key=lambda x:x[0].name)
def minimize(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
"""
DownpounSGD is a distributed optimizer so
that user can call minimize to generate backward
operators and optimization operators within minmize function
Args:
loss(Variable): loss variable defined by user
startup_program(Program): startup program that defined by user
parameter_list(str list): parameter names defined by users
no_grad_set(set): a set of variables that is defined by users
so that these variables do not need gradient computation
Returns:
[ps_param, worker_skipped_ops]
ps_param: parameter server protobuf desc
worker_skipped_ops: operator names that need
to be skipped during execution
"""
        params_grads = sorted(
            append_backward(loss, parameter_list, no_grad_set),
            key=lambda x: x[0].name)
        table_name = find_distributed_lookup_table(loss.block.program)
        prefetch_slots = find_distributed_lookup_table_inputs(
            loss.block.program, table_name)
@@ -67,12 +89,12 @@ class DownpourSGD(object):
            grads.append(i[1])
        server.add_sparse_table(sparse_table_index, self.learning_rate_,
                                prefetch_slots, prefetch_slots_emb)
        server.add_dense_table(dense_table_index, self.learning_rate_, params,
                               grads)
        worker.add_sparse_table(sparse_table_index, self.learning_rate_,
                                prefetch_slots, prefetch_slots_emb)
        worker.add_dense_table(dense_table_index, self.learning_rate_, params,
                               grads)
        ps_param = pslib.PSParameter()
        ps_param.server_param.CopyFrom(server.get_desc())
        ps_param.trainer_param.CopyFrom(worker.get_desc())
@@ -80,5 +102,4 @@ class DownpourSGD(object):
        # currently only support lookup_table
        worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
        ps_param.trainer_param.skip_op.extend(worker_skipped_ops)
        return [ps_param, worker_skipped_ops]
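
Putting the pieces together, a hedged end-to-end sketch of driving this optimizer, continuing the hypothetical program from the earlier sketches (the network layers here are illustrative assumptions, following the class docstring's example):

# Hypothetical driver, continuing the program built in the earlier sketches.
import paddle.fluid as fluid
from google.protobuf import text_format

with fluid.program_guard(program):
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")
    fc = fluid.layers.fc(input=emb, size=2, act="softmax")
    cost = fluid.layers.mean(
        fluid.layers.cross_entropy(input=fc, label=label))

downpour_sgd = fluid.distributed.DownpourSGD(learning_rate=0.2)
ps_param, worker_skipped_ops = downpour_sgd.minimize(cost)

# ps_param is a pslib.PSParameter protobuf desc for the parameter server;
# it can be serialized with protobuf's text format if a string is needed.
ps_param_str = text_format.MessageToString(ps_param)
# worker_skipped_ops ("lookup_table", "lookup_table_grad") are the op types
# the worker must skip, since the distributed table serves them remotely.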