提交 c073bb3b 编写于 作者: T tangwei12

code style

上级 298588f8
...@@ -382,8 +382,7 @@ class Operator(object): ...@@ -382,8 +382,7 @@ class Operator(object):
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv', 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv',
'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine', 'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine',
'ncclInit', 'channel_create', 'channel_close', 'channel_send', 'ncclInit', 'channel_create', 'channel_close', 'channel_send',
'channel_recv', 'select', 'checkpoint_notify' 'channel_recv', 'select', 'checkpoint_notify', 'gen_nccl_id'
, 'gen_nccl_id'
} }
def __init__(self, def __init__(self,
...@@ -1022,7 +1021,7 @@ class Block(object): ...@@ -1022,7 +1021,7 @@ class Block(object):
name=var.name, persistable=var.persistable, type=var.type) name=var.name, persistable=var.persistable, type=var.type)
elif var.type == core.VarDesc.VarType.RAW: elif var.type == core.VarDesc.VarType.RAW:
ret_var = self.create_var( ret_var = self.create_var(
name=var.name, persistable=var.persistable, type=var.type) name=var.name, persistable=var.persistable, type=var.type)
elif var.type == core.VarDesc.VarType.SELECTED_ROWS: elif var.type == core.VarDesc.VarType.SELECTED_ROWS:
ret_var = self.create_var( ret_var = self.create_var(
name=var.name, name=var.name,
...@@ -1465,7 +1464,7 @@ def get_var(name, program=None): ...@@ -1465,7 +1464,7 @@ def get_var(name, program=None):
Args: Args:
name(str): name of the variable name(str): name of the variable
program(Program|None): program object. program(Program|None): program object.
If None, default_global_program() will be used. If None, default_global_program() will be used.
Returns: Returns:
Variable Variable
......
...@@ -865,16 +865,17 @@ class DistributeTranspiler: ...@@ -865,16 +865,17 @@ class DistributeTranspiler:
""" """
import os import os
pserver_program.global_block().create_var(name="loopup_table_path", persistable=True, type=core.VarDesc.VarType.RAW) pserver_program.global_block().create_var(
name="loopup_table_path",
persistable=True,
type=core.VarDesc.VarType.RAW)
checkpoint_save_block = pserver_program.create_block(pre_block_idx) checkpoint_save_block = pserver_program.create_block(pre_block_idx)
checkpoint_save_block.append_op( checkpoint_save_block.append_op(
type='save', type='save',
inputs={'X': [self.table_name]}, inputs={'X': [self.table_name]},
outputs={}, outputs={},
attrs={ attrs={'file_path': self.table_name})
'file_path': self.table_name
})
return checkpoint_save_block.idx return checkpoint_save_block.idx
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册