提交 c797aded 编写于 作者: Q qiaolongfei

append table grad to grad_to_block_id

上级 9f4b66f6
...@@ -207,6 +207,7 @@ static void AsyncUpdateThread( ...@@ -207,6 +207,7 @@ static void AsyncUpdateThread(
while (!exit_flag) { while (!exit_flag) {
const detail::ReceivedMessage v = queue->Pop(); const detail::ReceivedMessage v = queue->Pop();
auto recv_var_name = v.first; auto recv_var_name = v.first;
VLOG(4) << "async update " << recv_var_name;
auto var = v.second->GetVar(); auto var = v.second->GetVar();
if (var == nullptr) { if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << recv_var_name; LOG(ERROR) << "Can not find server side var: " << recv_var_name;
......
...@@ -476,7 +476,7 @@ class DistributeTranspiler: ...@@ -476,7 +476,7 @@ class DistributeTranspiler:
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
pserver_index = self.pserver_endpoints.index(endpoint) pserver_index = self.pserver_endpoints.index(endpoint)
table_opt_block = self._create_table_optimize_block( table_opt_block = self._create_table_optimize_block(
pserver_index, pserver_program, pre_block_idx) pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
prefetch_block = self._create_prefetch_block( prefetch_block = self._create_prefetch_block(
pserver_index, pserver_program, table_opt_block) pserver_index, pserver_program, table_opt_block)
...@@ -688,7 +688,7 @@ class DistributeTranspiler: ...@@ -688,7 +688,7 @@ class DistributeTranspiler:
return prefetch_block return prefetch_block
def _create_table_optimize_block(self, pserver_index, pserver_program, def _create_table_optimize_block(self, pserver_index, pserver_program,
pre_block_idx): pre_block_idx, grad_to_block_id):
def _clone_var(block, var, persistable=True): def _clone_var(block, var, persistable=True):
assert isinstance(var, Variable) assert isinstance(var, Variable)
return block.create_var( return block.create_var(
...@@ -743,10 +743,13 @@ class DistributeTranspiler: ...@@ -743,10 +743,13 @@ class DistributeTranspiler:
outputs={"Out": [grad_var]}) outputs={"Out": [grad_var]})
else: else:
# in async_mode, the table gradient also needs to be split across the parameter servers # in async_mode, the table gradient also needs to be split across the parameter servers
old_name = grad_var.name origin_grad_name = grad_var.name
new_name = old_name + ".pserver_" + str(pserver_index) splited_grad_name = self.table_grad_list[pserver_index].name
grad_var = pserver_program.global_block().rename_var(old_name, if not splited_grad_name.startswith(origin_grad_name):
new_name) raise ValueError("origin_grad_var: " + splited_grad_name +
" grad_var:" + grad_var.name)
grad_var = pserver_program.global_block().rename_var(
origin_grad_name, splited_grad_name)
lr_var = pserver_program.global_block().vars[table_opt_op.input( lr_var = pserver_program.global_block().vars[table_opt_op.input(
"LearningRate")[0]] "LearningRate")[0]]
...@@ -762,6 +765,9 @@ class DistributeTranspiler: ...@@ -762,6 +765,9 @@ class DistributeTranspiler:
outputs=outputs, outputs=outputs,
attrs=table_opt_op.attrs) attrs=table_opt_op.attrs)
# add the table parameter gradient and its block id to grad_to_block_id
grad_to_block_id.append(grad_var.name + ":" + str(table_opt_block.idx))
return table_opt_block return table_opt_block
# ====================== private transpiler functions ===================== # ====================== private transpiler functions =====================
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册