Commit 37596000 authored by dongdaxiang

add doc string for downpour.py and distribute_lookup_table.py

Parent 854ee964
distribute_lookup_table.py:

@@ -16,31 +16,51 @@ LOOKUP_TABLE_TYPE = "lookup_table"
 def find_distributed_lookup_table_inputs(program, table_name):
+    """
+    Find the input variables of the distributed lookup table in a program.
+    We only support one distributed table now.
+
+    Args:
+        program(Program): the given program in which to locate the distributed lookup table
+        table_name(str): the table name found beforehand
+
+    Returns:
+        inputs
+    """
     local_vars = program.current_block().vars
     inputs = []
     for op in program.global_block().ops:
         if op.type == LOOKUP_TABLE_TYPE:
             if table_name == op.input("W")[0]:
-                inputs.extend(
-                    [local_vars[name] for name in op.input("Ids")])
+                inputs.extend([local_vars[name] for name in op.input("Ids")])
     return inputs

 def find_distributed_lookup_table_outputs(program, table_name):
+    """
+    Find the output variables of the distributed lookup table in a program.
+    We only support one distributed table now.
+
+    Args:
+        program(Program): the given program in which to locate the distributed lookup table
+        table_name(str): the table name found beforehand
+
+    Returns:
+        outputs
+    """
     local_vars = program.current_block().vars
     outputs = []
     for op in program.global_block().ops:
         if op.type == LOOKUP_TABLE_TYPE:
             if table_name == op.input("W")[0]:
-                outputs.extend(
-                    [local_vars[name] for name in op.output("Out")])
+                outputs.extend([local_vars[name] for name in op.output("Out")])
     return outputs

 def find_distributed_lookup_table(program):
     """
     Find the distributed lookup table in a program.
     We only support one distributed table now.
-    :param program:
-    :return: table_name or None
+
+    Args:
+        program(Program): the given program in which to locate the distributed lookup table
+
+    Returns:
+        table_name or None
     """
     table_name = None
...
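
For reference, here is a minimal usage sketch of the three helpers above. This driver code is not part of the commit; it assumes a fluid Program (here named `program`) that already contains a lookup_table op:

    from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
    from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_inputs
    from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_outputs

    # Locate the (single) distributed lookup table, if the program has one.
    table_name = find_distributed_lookup_table(program)  # None when absent
    if table_name is not None:
        # Collect the "Ids" input and "Out" output variables of every
        # lookup_table op whose "W" matches the table name.
        ids_vars = find_distributed_lookup_table_inputs(program, table_name)
        out_vars = find_distributed_lookup_table_outputs(program, table_name)
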
downpour.py:

@@ -20,6 +20,7 @@ from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_i
 from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_outputs
 from google.protobuf import text_format

+
 class DownpourSGD(object):
""" """
Distributed optimizer of downpour stochastic gradient descent Distributed optimizer of downpour stochastic gradient descent
...@@ -35,6 +36,7 @@ class DownpourSGD(object): ...@@ -35,6 +36,7 @@ class DownpourSGD(object):
downpour_sgd = fluid.distributed.DownpourSGD(learning_rate=0.2) downpour_sgd = fluid.distributed.DownpourSGD(learning_rate=0.2)
downpour_sgd.minimize(cost) downpour_sgd.minimize(cost)
""" """
def __init__(self, learning_rate=0.001, window=1): def __init__(self, learning_rate=0.001, window=1):
# todo(guru4elephant): add more optimizers here as argument # todo(guru4elephant): add more optimizers here as argument
# todo(guru4elephant): make learning_rate as a variable # todo(guru4elephant): make learning_rate as a variable
...
@@ -42,10 +44,30 @@ class DownpourSGD(object):
         self.window_ = window
         self.type = "downpour"

-    def minimize(self, loss, startup_program=None,
-                 parameter_list=None, no_grad_set=None):
-        params_grads = sorted(append_backward(
-            loss, parameter_list, no_grad_set), key=lambda x:x[0].name)
+    def minimize(self,
+                 loss,
+                 startup_program=None,
+                 parameter_list=None,
+                 no_grad_set=None):
+        """
+        DownpourSGD is a distributed optimizer, so calling minimize()
+        generates both the backward operators and the optimization
+        operators.
+
+        Args:
+            loss(Variable): the loss variable defined by the user
+            startup_program(Program): the startup program defined by the user
+            parameter_list(str list): parameter names defined by the user
+            no_grad_set(set): a set of variables that do not need
+                gradient computation
+
+        Returns:
+            [ps_param, worker_skipped_ops]
+            ps_param: the parameter server protobuf desc
+            worker_skipped_ops: names of operators that should be
+                skipped during execution
+        """
+        params_grads = sorted(
+            append_backward(loss, parameter_list, no_grad_set),
+            key=lambda x: x[0].name)
         table_name = find_distributed_lookup_table(loss.block.program)
         prefetch_slots = find_distributed_lookup_table_inputs(
             loss.block.program, table_name)
...
@@ -67,12 +89,12 @@ class DownpourSGD(object):
             grads.append(i[1])
         server.add_sparse_table(sparse_table_index, self.learning_rate_,
                                 prefetch_slots, prefetch_slots_emb)
-        server.add_dense_table(dense_table_index, self.learning_rate_,
-                               params, grads)
+        server.add_dense_table(dense_table_index, self.learning_rate_, params,
+                               grads)
         worker.add_sparse_table(sparse_table_index, self.learning_rate_,
                                 prefetch_slots, prefetch_slots_emb)
-        worker.add_dense_table(dense_table_index, self.learning_rate_,
-                               params, grads)
+        worker.add_dense_table(dense_table_index, self.learning_rate_, params,
+                               grads)
         ps_param = pslib.PSParameter()
         ps_param.server_param.CopyFrom(server.get_desc())
         ps_param.trainer_param.CopyFrom(worker.get_desc())
...
@@ -80,5 +102,4 @@ class DownpourSGD(object):
         # currently only support lookup_table
         worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
         ps_param.trainer_param.skip_op.extend(worker_skipped_ops)
-        ps_param_str = text_format.MessageToString(ps_param)
         return [ps_param, worker_skipped_ops]
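
Putting the pieces together, here is a hedged end-to-end sketch based on the class docstring's example and minimize()'s documented return value; `cost` is a hypothetical loss Variable that the user is assumed to have built beforehand:

    import paddle.fluid as fluid
    from google.protobuf import text_format

    # Build the distributed optimizer; minimize() generates the backward
    # and optimization operators and configures the server/worker tables.
    downpour_sgd = fluid.distributed.DownpourSGD(learning_rate=0.2)
    ps_param, worker_skipped_ops = downpour_sgd.minimize(cost)

    # ps_param is a pslib.PSParameter protobuf holding the server and
    # trainer table descriptions; serialize it explicitly if a textual
    # desc is needed (this commit removed the unused in-method conversion).
    ps_param_str = text_format.MessageToString(ps_param)

    # worker_skipped_ops == ["lookup_table", "lookup_table_grad"]: these
    # operators are handled through the parameter server, so workers skip
    # them during execution.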