未验证 提交 3a36200a 编写于 作者: O OuYang Yu 提交者: GitHub

refactor ParallelDescSymbol (#3774)

上级 a7ab7ec1
......@@ -143,7 +143,7 @@ def ReplaceDeviceTag(parallel_desc_symbol, device_tag, builder=None):
parallel_conf.device_name.append(device_name)
if builder is None:
return symbol_util.ParallelDescSymbol(
parallel_desc_symbol.symbol_id, parallel_conf, device_tag
parallel_desc_symbol.symbol_id, parallel_conf
)
else:
return builder.GetParallelDescSymbol(parallel_conf)
......@@ -160,7 +160,7 @@ def RandomParallelIdPerMachine(parallel_desc_symbol, device_tag=None, builder=No
parallel_conf.device_name.append("%s:%s" % (machine_id, dev_id))
if builder is None:
return symbol_util.ParallelDescSymbol(
parallel_desc_symbol.symbol_id, parallel_conf, device_tag
parallel_desc_symbol.symbol_id, parallel_conf
)
else:
return builder.GetParallelDescSymbol(parallel_conf)
......@@ -34,9 +34,9 @@ class Symbol(object):
class ParallelDescSymbol(Symbol):
def __init__(self, symbol_id, parallel_conf, device_tag):
def __init__(self, symbol_id, parallel_conf):
Symbol.__init__(self, symbol_id, parallel_conf)
self.device_tag_ = device_tag
self.device_tag_ = parallel_conf.device_tag
self.machine_id2device_id_list_ = MakeMachineId2DeviceIdList(parallel_conf)
sub_parallel_nums = [len(v) for k, v in self.machine_id2device_id_list_.items()]
self.parallel_num_ = functools.reduce(lambda a, b: a + b, sub_parallel_nums, 0)
......
......@@ -421,14 +421,13 @@ class InstructionsBuilder(object):
return symbol
def GetParallelDescSymbol(self, parallel_conf):
device_tag = parallel_conf.device_tag
serialized_parallel_conf = parallel_conf.SerializeToString()
if symbol_storage.HasSymbol4SerializedParallelConf(serialized_parallel_conf):
return symbol_storage.GetSymbol4SerializedParallelConf(
serialized_parallel_conf
)
symbol_id = self._NewSymbolId4ParallelConf(parallel_conf)
symbol = symbol_util.ParallelDescSymbol(symbol_id, parallel_conf, device_tag)
symbol = symbol_util.ParallelDescSymbol(symbol_id, parallel_conf)
symbol_storage.SetSymbol4Id(symbol_id, symbol)
symbol_storage.SetSymbol4SerializedParallelConf(
serialized_parallel_conf, symbol
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册