未验证 提交 cdc5896f 编写于 作者: A Ainavo 提交者: GitHub

remove unnecessary generator set and dict (#51845)

上级 4dbf3a8f
......@@ -117,7 +117,7 @@ class TunableSpace:
{"class_name": v.__class__.__name__, "state": v.get_state()}
for v in self._variables.values()
],
"values": dict((k, v) for (k, v) in self.values.items()),
"values": {k: v for (k, v) in self.values.items()},
}
@classmethod
......@@ -126,7 +126,7 @@ class TunableSpace:
for v in state["variables"]:
v = _deserialize_tunable_variable(v)
ts._variables[v.name] = v
ts._values = dict((k, v) for (k, v) in state["values"].items())
ts._values = {k: v for (k, v) in state["values"].items()}
return ts
......
......@@ -88,7 +88,7 @@ class Choice(TunableVariable):
def __init__(self, name, values, default=None):
super().__init__(name=name, default=default)
types = set(type(v) for v in values)
types = {type(v) for v in values}
if len(types) > 1:
raise TypeError(
"Choice can contain only one type of value, but found values: {} with types: {}.".format(
......
......@@ -433,9 +433,9 @@ class PipelineLayer(nn.Layer):
return
layers_desc = self._layers_desc
shared_layer_names = set(
shared_layer_names = {
s.layer_name for s in layers_desc if isinstance(s, SharedLayerDesc)
)
}
for key in shared_layer_names:
shared_layers = []
for idx, layer in enumerate(layers_desc):
......@@ -445,9 +445,9 @@ class PipelineLayer(nn.Layer):
):
shared_layers.append(idx)
shared_stages = set(
shared_stages = {
self.get_stage_from_index(idx) for idx in shared_layers
)
}
self._dp_degree = self._topo.get_dim('data')
self._mp_degree = self._topo.get_dim('model')
self._sharding_degree = self._topo.get_dim('sharding')
......
......@@ -194,7 +194,7 @@ class TrainerRuntimeConfig:
'communicator_send_queue_size'
] = num_threads
return dict((key, str(self.runtime_configs[key])) for key in need_keys)
return {key: str(self.runtime_configs[key]) for key in need_keys}
def get_lr_ops(program):
......
......@@ -128,7 +128,7 @@ class TrainerRuntimeConfig:
'communicator_send_queue_size'
] = num_threads
return dict((key, str(self.runtime_configs[key])) for key in need_keys)
return {key: str(self.runtime_configs[key]) for key in need_keys}
def display(self, configs):
raw0, raw1, length = 45, 5, 50
......
......@@ -357,7 +357,7 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
assert quant_acc1 - int8_acc1 <= threshold
def _strings_from_csv(self, string):
return set(s.strip() for s in string.split(','))
return {s.strip() for s in string.split(',')}
def _ints_from_csv(self, string):
return set(map(int, string.split(',')))
......
......@@ -294,7 +294,7 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
assert quant_acc - int8_acc <= threshold
def _strings_from_csv(self, string):
return set(s.strip() for s in string.split(','))
return {s.strip() for s in string.split(',')}
def _ints_from_csv(self, string):
return set(map(int, string.split(',')))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册