From c583fd34acc9e02362fd2ddd4bf7adb53d8321e6 Mon Sep 17 00:00:00 2001 From: dongdaxiang Date: Mon, 3 Dec 2018 09:53:24 +0800 Subject: [PATCH] add downpour sgd wrapper for pslib --- python/paddle/fluid/distributed/downpour.py | 34 ++++++++++++ python/paddle/fluid/distributed/node.py | 61 +++++++++++++++++++++ 2 files changed, 95 insertions(+) create mode 100644 python/paddle/fluid/distributed/downpour.py create mode 100644 python/paddle/fluid/distributed/node.py diff --git a/python/paddle/fluid/distributed/downpour.py b/python/paddle/fluid/distributed/downpour.py new file mode 100644 index 000000000..523f68666 --- /dev/null +++ b/python/paddle/fluid/distributed/downpour.py @@ -0,0 +1,34 @@ +import paddle.fluid as fluid +import pslib_pb2 as pslib +from .node import DownpourServer +from .node import DownpourWorker +from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table + +class DownpourSGD(object): + def __init__(self, optimizer=opt, learning_rate=0.001, window=1): + # todo(guru4elephant): if optimizer is not None, will warning here + self.learning_rate_ = opt.learning_rate + self.window_ = window + + def minimize(self, loss, startup_program=None, + parameter_list=None, no_grad_set=None, + prefetch_slots=None, prefetch_slots_emb=None): + params_grads = sorted(append_backward(loss), key=lambda x:x[0].name) + table_name = fluid_distributed_lookup_table(loss.block.program) + server = DownpourServer() + worker = DownpourWorker() + server.add_sparse_table(0, learning_rate, + prefetch_slots, prefetch_slots_emb) + server.add_dense_table(1, learning_rate, params, grads) + worker.add_sparse_table(0, learning_rate, + prefetch_slots, prefetch_slots_emb) + worker.add_dense_table(1, learning_rate, params, grads) + + ps_param = pslib.PSParameter() + ps_param.server_param.CopyFrom(server.get_desc()) + ps_param.worker_param.CopyFrom(worker.get_desc()) + worker_skipped_ops = ["lookup_table", "lookup_table_grad"] + + return [solver_desc, parallel_desc] + + diff --git a/python/paddle/fluid/distributed/node.py b/python/paddle/fluid/distributed/node.py new file mode 100644 index 000000000..fc62d7220 --- /dev/null +++ b/python/paddle/fluid/distributed/node.py @@ -0,0 +1,61 @@ +import paddle.fluid as fluid +import pslib_pb2 as pslib + +class Server(object): + def __init__(self): + pass + + +class Worker(object): + def __init__(self): + pass + + +class DownpourServer(Server): + def __init__(self): + self.server_ = pslib.ServerParameter().downpour_server_param + + def add_sparse_table(self, table_id, learning_rate, + slot_key, slot_value_var, slot_grad_var): + table = self.server_.downpour_table_param.add() + table.table_id = table_id + table.type = PS_SPARSE_TABLE + table.accessor.accessor_class = "DownpourFeatureValueAccessor" + table.accessor.dense_sgd_param.adam.learning_rate = learning_rate + table.accessor.fea_dim = slot_value_var[0].shape[1] + + def add_dense_table(self, table_id, learning_rate, + param_var, grad_var): + table = self.server_.downpour_table_param.add() + table.table_id = table_id + table.type = PS_DENSE_TABLE + table.accessor.accessor_class = "DownpourDenseValueAccessor" + table.accessor.sparse_sgd_param.learning_rate = learning_rate + table.accessor.fea_dim = reduce(lambda x, y: x.shape, 1 for x in param_var) + + def get_desc(self): + return self.server_ + + +class DownpourWorker(Worker): + def __init__(self, window): + self.window = window + self.worker_ = pslib.WorkerParameter().downpour_worker_param + self.worker_.pull_dense_per_batch = window + self.worker_.push_dense_per_batch = window + + def add_sparse_table(self, table_id, + slot_keys, slot_value_vars, slot_grad_vars): + table = self.worker_.sparse_table.add() + table.table_id = table_id + table.slot.extend(slot_keys) + self.worker_.extend([grad.name for grad in slot_grad_vars]) + + def add_dense_table(self, table_id, param_vars, grad_vars): + table = self.worker_.dense_table.add() + table.table_id = table_id + table.dense_variable_name.extend([p.name for p in param_vars]) + table.dense_gradient_variable_name.extend([g.name for g in grad_vars]) + + def get_desc(self): + return self.worker_ -- GitLab