提交 284adcc7 编写于 作者: X xjqbest 提交者: dongdaxiang

fix bug

上级 a34fe624
...@@ -68,16 +68,7 @@ class DownpourSGD(DeviceWorker): ...@@ -68,16 +68,7 @@ class DownpourSGD(DeviceWorker):
# TODO(guru4elephant): hard code here, need to improve # TODO(guru4elephant): hard code here, need to improve
sparse_table.label_var_name = "click" sparse_table.label_var_name = "click"
dense_table = downpour.dense_table.add() dense_table_set = set()
dense_table.table_id = \
self.fleet_desc_.trainer_param.dense_table[0].table_id
dense_table.dense_value_name.extend(
self.fleet_desc_.trainer_param.dense_table[0].dense_variable_name)
dense_table.dense_grad_name.extend(
self.fleet_desc_.trainer_param.dense_table[
0].dense_gradient_variable_name)
downpour.skip_ops.extend(self.fleet_desc_.trainer_param.skip_op)
program_id = str(id(self.program_)) program_id = str(id(self.program_))
if self.program_ == None: if self.program_ == None:
print("program of current device worker is not configured") print("program of current device worker is not configured")
...@@ -95,10 +86,22 @@ class DownpourSGD(DeviceWorker): ...@@ -95,10 +86,22 @@ class DownpourSGD(DeviceWorker):
pc.push_dense_table_id.extend([i]) pc.push_dense_table_id.extend([i])
for i in program_configs[program_id]["pull_sparse"]: for i in program_configs[program_id]["pull_sparse"]:
pc.pull_sparse_table_id.extend([i]) pc.pull_sparse_table_id.extend([i])
dense_table_set.add(i)
for i in program_configs[program_id]["pull_dense"]: for i in program_configs[program_id]["pull_dense"]:
pc.pull_dense_table_id.extend([i]) pc.pull_dense_table_id.extend([i])
dense_table_set.add(i)
break break
for i in self.fleet_desc_.trainer_param.dense_table:
if i.table_id in dense_table_set:
dense_table = downpour.dense_table.add()
dense_table.table_id = i.table_id
dense_table.dense_value_name.extend(
i.dense_variable_name)
dense_table.dense_grad_name.extend(
i.dense_gradient_variable_name)
downpour.skip_ops.extend(self.fleet_desc_.trainer_param.skip_op)
class DeviceWorkerFactory(object): class DeviceWorkerFactory(object):
def create_device_worker(self, worker_type): def create_device_worker(self, worker_type):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册