未验证 提交 2922aa67 编写于 作者: A Ainavo 提交者: GitHub

[CodeStyple][B011] replace assert false with raise AssertionError (#51935)

* replace assert false with AssertionError

* 修改配置文件多余的部分
上级 40115c7e
...@@ -829,9 +829,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -829,9 +829,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
backward_input_pos, backward_input_pos,
] ]
else: else:
assert ( raise AssertionError(
False f"Cannot find {backward_input_name} in forward position map"
), f"Cannot find {backward_input_name} in forward position map" )
for backward_output in backward_returns_list: for backward_output in backward_returns_list:
backward_output_name = backward_output[0] backward_output_name = backward_output[0]
......
...@@ -58,7 +58,9 @@ atype_to_parsing_function = { ...@@ -58,7 +58,9 @@ atype_to_parsing_function = {
def FindParsingFunctionFromAttributeType(atype): def FindParsingFunctionFromAttributeType(atype):
if atype not in atype_to_parsing_function.keys(): if atype not in atype_to_parsing_function.keys():
assert False, f"Unable to find {atype} in atype_to_parsing_function." raise AssertionError(
f"Unable to find {atype} in atype_to_parsing_function."
)
return atype_to_parsing_function[atype] return atype_to_parsing_function[atype]
......
...@@ -32,6 +32,11 @@ select = [ ...@@ -32,6 +32,11 @@ select = [
# Pyflakes # Pyflakes
"F401", "F401",
# Comprehensions
"C400",
"C401",
"C402",
# Pyupgrade # Pyupgrade
"UP001", "UP001",
"UP003", "UP003",
...@@ -62,6 +67,7 @@ select = [ ...@@ -62,6 +67,7 @@ select = [
# Bugbear # Bugbear
"B009", "B009",
"B010", "B010",
"B011",
] ]
unfixable = [ unfixable = [
"NPY001" "NPY001"
......
...@@ -108,7 +108,7 @@ def reader_creator(filename, word_idx, n, data_type): ...@@ -108,7 +108,7 @@ def reader_creator(filename, word_idx, n, data_type):
continue continue
yield src_seq, trg_seq yield src_seq, trg_seq
else: else:
assert False, 'Unknown data type' raise AssertionError('Unknown data type')
return reader return reader
......
...@@ -436,9 +436,11 @@ def build_dp_costs( ...@@ -436,9 +436,11 @@ def build_dp_costs(
elif var_name in dist_attr.outputs_dist_attrs: elif var_name in dist_attr.outputs_dist_attrs:
dims_mapping = dist_attr.get_output_dims_mapping(var_name) dims_mapping = dist_attr.get_output_dims_mapping(var_name)
else: else:
assert False, "cannot find dims_mapping for {} in {}".format( raise AssertionError(
"cannot find dims_mapping for {} in {}".format(
var_name, dist_attr var_name, dist_attr
) )
)
# dims_mapping = ( # dims_mapping = (
# dist_attr.get_input_dims_mapping(var_name) # dist_attr.get_input_dims_mapping(var_name)
......
...@@ -974,9 +974,9 @@ class DistributedContext: ...@@ -974,9 +974,9 @@ class DistributedContext:
def validate_dist_attr_for_program(self): def validate_dist_attr_for_program(self):
if not self._is_initialized: if not self._is_initialized:
assert ( raise AssertionError(
False "Program must be initialized before validating its distributed attributes"
), "Program must be initialized before validating its distributed attributes" )
for block in self.serial_main_program.blocks: for block in self.serial_main_program.blocks:
for tensor in block.vars.values(): for tensor in block.vars.values():
dist_tensor = self.get_dist_tensor_for_program(tensor) dist_tensor = self.get_dist_tensor_for_program(tensor)
...@@ -988,14 +988,14 @@ class DistributedContext: ...@@ -988,14 +988,14 @@ class DistributedContext:
if (dist_tensor is not None) and ( if (dist_tensor is not None) and (
not dist_tensor.validate_dist_attr() not dist_tensor.validate_dist_attr()
): ):
assert ( raise AssertionError(
False "Tensor {} (id: {}, original_id: {}) has a wrong distributed attributes {}.".format(
), "Tensor {} (id: {}, original_id: {}) has a wrong distributed attributes {}.".format(
dist_tensor.serial_tensor.name, dist_tensor.serial_tensor.name,
dist_tensor.serial_tensor.desc.id(), dist_tensor.serial_tensor.desc.id(),
dist_tensor.serial_tensor.desc.original_id(), dist_tensor.serial_tensor.desc.original_id(),
dist_tensor.dist_attr, dist_tensor.dist_attr,
) )
)
for op in block.ops: for op in block.ops:
dist_op = self.get_dist_op_for_program(op) dist_op = self.get_dist_op_for_program(op)
assert ( assert (
...@@ -1004,14 +1004,14 @@ class DistributedContext: ...@@ -1004,14 +1004,14 @@ class DistributedContext:
dist_op.serial_op.type dist_op.serial_op.type
) )
if (dist_op is not None) and (not dist_op.validate_dist_attr()): if (dist_op is not None) and (not dist_op.validate_dist_attr()):
assert ( raise AssertionError(
False "Operator {} (id: {}, original_id: {}) has a wrong distributed attributes {} .".format(
), "Operator {} (id: {}, original_id: {}) has a wrong distributed attributes {} .".format(
dist_op.serial_op.type, dist_op.serial_op.type,
dist_op.serial_op.desc.id(), dist_op.serial_op.desc.id(),
dist_op.serial_op.desc.original_id(), dist_op.serial_op.desc.original_id(),
dist_op.dist_attr, dist_op.dist_attr,
) )
)
return True return True
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
......
...@@ -186,7 +186,9 @@ def register_distributed_operator_impl(op_type, dist_impl): ...@@ -186,7 +186,9 @@ def register_distributed_operator_impl(op_type, dist_impl):
dist_impl.type = op_type dist_impl.type = op_type
dist_op_impl_container.register_impl(dist_impl) dist_op_impl_container.register_impl(dist_impl)
else: else:
assert False, "Must register distributed operator registry first." raise AssertionError(
"Must register distributed operator registry first."
)
def find_compatible_distributed_operator_impls(dist_op, fwd=True, partial=True): def find_compatible_distributed_operator_impls(dist_op, fwd=True, partial=True):
......
...@@ -115,8 +115,8 @@ class ProcessGroup: ...@@ -115,8 +115,8 @@ class ProcessGroup:
if global_rank in self.ranks: if global_rank in self.ranks:
return self.ranks.index(global_rank) return self.ranks.index(global_rank)
else: else:
assert False, "Rank {} doesn't belong to this group".format( raise AssertionError(
global_rank "Rank {} doesn't belong to this group".format(global_rank)
) )
def is_instantiate(self): def is_instantiate(self):
...@@ -149,7 +149,7 @@ class ProcessGroup: ...@@ -149,7 +149,7 @@ class ProcessGroup:
ring_id ring_id
) )
else: else:
assert False, "No CUDA device found" raise AssertionError('No CUDA device found')
# TODO(shenliang03): This is a temporary solution to solve the problem of # TODO(shenliang03): This is a temporary solution to solve the problem of
# hang caused by cross-creation of new_group # hang caused by cross-creation of new_group
......
...@@ -1790,7 +1790,9 @@ def set_dist_op_desc_original_id(dist_op_desc, op_desc, dist_context): ...@@ -1790,7 +1790,9 @@ def set_dist_op_desc_original_id(dist_op_desc, op_desc, dist_context):
return return
# Third, print error infomation if we cannot find the original id # Third, print error infomation if we cannot find the original id
else: else:
assert False, "Cannot find the original id in the distributed context" raise AssertionError(
"Cannot find the original id in the distributed context"
)
def to_list(value): def to_list(value):
......
...@@ -304,7 +304,7 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout): ...@@ -304,7 +304,7 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
ring_id ring_id
) )
else: else:
assert False, "no cuda device found" raise AssertionError("no cuda device found")
else: else:
return gp return gp
......
...@@ -416,7 +416,9 @@ class Fleet: ...@@ -416,7 +416,9 @@ class Fleet:
if not order: if not order:
order = ['dp', 'pp', 'sharding', 'mp'] order = ['dp', 'pp', 'sharding', 'mp']
if order[:].sort() != list(d_hybrid_degree.keys())[:].sort(): if order[:].sort() != list(d_hybrid_degree.keys())[:].sort():
assert False, "The order of hybrid_config setting is incorrect." raise AssertionError(
'The order of hybrid_config setting is incorrect.'
)
hybrid_group_names = [] hybrid_group_names = []
dims = [] dims = []
......
...@@ -953,11 +953,11 @@ def get_device_proc_info(args): ...@@ -953,11 +953,11 @@ def get_device_proc_info(args):
else: else:
devices_per_proc = [x for x in range(0, args.nproc_per_node)] devices_per_proc = [x for x in range(0, args.nproc_per_node)]
else: else:
assert ( raise AssertionError(
False "Can't support device_mode:{}, support only cpu|gpu|xpu now.".format(
), "Can't support device_mode:{}, support only cpu|gpu|xpu now.".format(
device_mode device_mode
) )
)
return (device_mode, devices_per_proc) return (device_mode, devices_per_proc)
......
...@@ -116,10 +116,9 @@ class AscendIRParser: ...@@ -116,10 +116,9 @@ class AscendIRParser:
) )
op_parser.apply(op) op_parser.apply(op)
else: else:
assert ( raise AssertionError(
False 'Op[%s] has not been registered, so we have to skip it'
), "Op[%s] has not been registered, so we have to skip it" % ( % op.type
op.type
) )
def _parse_program( def _parse_program(
......
...@@ -515,7 +515,7 @@ class SumParser(AscendParserBase): ...@@ -515,7 +515,7 @@ class SumParser(AscendParserBase):
def _apply(self): def _apply(self):
len_list = len(self.op.input_arg_names) len_list = len(self.op.input_arg_names)
if len_list < 2: if len_list < 2:
assert False, "the size of input list must large or equal 2" raise AssertionError("the size of input list must large or equal 2")
x = self._get_ge_input(self.op.input_arg_names[0]) x = self._get_ge_input(self.op.input_arg_names[0])
y = self._get_ge_input(self.op.input_arg_names[1]) y = self._get_ge_input(self.op.input_arg_names[1])
sum = ( sum = (
...@@ -643,7 +643,7 @@ class MatMulParser(AscendParserBase): ...@@ -643,7 +643,7 @@ class MatMulParser(AscendParserBase):
.set_attr_bool("transpose_x2", transpose_y) .set_attr_bool("transpose_x2", transpose_y)
) )
else: else:
assert False, "not support" raise AssertionError("not support")
return [matmul], [[0]] return [matmul], [[0]]
...@@ -681,7 +681,7 @@ class MulParser(AscendParserBase): ...@@ -681,7 +681,7 @@ class MulParser(AscendParserBase):
.set_input("x2", y, 0) .set_input("x2", y, 0)
) )
else: else:
assert False, "not support" raise AssertionError("not support")
else: else:
if len(shape_x1) == 3 and len(shape_x2) == 2: if len(shape_x1) == 3 and len(shape_x2) == 2:
assert x_num_col_dims == 2, "only support 2" assert x_num_col_dims == 2, "only support 2"
...@@ -729,7 +729,7 @@ class MulParser(AscendParserBase): ...@@ -729,7 +729,7 @@ class MulParser(AscendParserBase):
.set_attr_vec_int32("perm", [1, 2, 0]) .set_attr_vec_int32("perm", [1, 2, 0])
) )
else: else:
assert False, "not support" raise AssertionError("not support")
return [matmul], [[0]] return [matmul], [[0]]
......
...@@ -107,7 +107,9 @@ class DGCMomentumOptimizer(Optimizer): ...@@ -107,7 +107,9 @@ class DGCMomentumOptimizer(Optimizer):
elif isinstance(regularization, L2Decay): elif isinstance(regularization, L2Decay):
regular_type = 2 regular_type = 2
else: else:
assert False, 'regularization must be None|L1Decay|L2Deacy' raise AssertionError(
"regularization must be None|L1Decay|L2Deacy"
)
return regular_type, regular_coeff return regular_type, regular_coeff
def _is_use_dgc(self, param_var, grad_var): def _is_use_dgc(self, param_var, grad_var):
......
...@@ -105,7 +105,7 @@ def check(use_cuda): ...@@ -105,7 +105,7 @@ def check(use_cuda):
if __name__ == '__main__': if __name__ == '__main__':
try: try:
check(use_cuda=False) check(use_cuda=False)
assert False raise AssertionError()
except Exception as e: except Exception as e:
print(e) print(e)
print(type(e)) print(type(e))
...@@ -114,7 +114,7 @@ if __name__ == '__main__': ...@@ -114,7 +114,7 @@ if __name__ == '__main__':
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
try: try:
check(use_cuda=True) check(use_cuda=True)
assert False raise AssertionError()
except Exception as e: except Exception as e:
print(e) print(e)
print(type(e)) print(type(e))
......
...@@ -96,7 +96,7 @@ def run_check(): ...@@ -96,7 +96,7 @@ def run_check():
if paddle.is_compiled_with_cuda(): if paddle.is_compiled_with_cuda():
try: try:
check(use_cuda=True) check(use_cuda=True)
assert False raise AssertionError()
except Exception as e: except Exception as e:
print(e) print(e)
print(type(e)) print(type(e))
...@@ -105,7 +105,7 @@ def run_check(): ...@@ -105,7 +105,7 @@ def run_check():
assert type(e) == OSError or type(e) == RuntimeError assert type(e) == OSError or type(e) == RuntimeError
try: try:
check(use_cuda=False) check(use_cuda=False)
assert False raise AssertionError()
except Exception as e: except Exception as e:
print(e) print(e)
print(type(e)) print(type(e))
......
...@@ -462,7 +462,7 @@ class PassAutoScanTest(AutoScanTest): ...@@ -462,7 +462,7 @@ class PassAutoScanTest(AutoScanTest):
min_success_num, successful_ran_programs min_success_num, successful_ran_programs
) )
) )
assert False raise AssertionError()
used_time = time.time() - start_time used_time = time.time() - start_time
if max_duration > 0 and used_time > max_duration: if max_duration > 0 and used_time > max_duration:
logging.error( logging.error(
...@@ -470,7 +470,7 @@ class PassAutoScanTest(AutoScanTest): ...@@ -470,7 +470,7 @@ class PassAutoScanTest(AutoScanTest):
max_duration max_duration
) )
) )
assert False raise AssertionError()
def run_test(self, quant=False, prog_configs=None): def run_test(self, quant=False, prog_configs=None):
status = True status = True
......
...@@ -180,7 +180,7 @@ def matmul_head2(X, Y, head_number=1): ...@@ -180,7 +180,7 @@ def matmul_head2(X, Y, head_number=1):
z.append(np.matmul(x[i], y[i])) z.append(np.matmul(x[i], y[i]))
Z = np.concatenate((z), axis=2) Z = np.concatenate((z), axis=2)
else: else:
assert False, "ERROR: Not supported dimension!" raise AssertionError("ERROR: Not supported dimension!")
return Z return Z
......
...@@ -99,7 +99,7 @@ class TestNanInfCheckResult(unittest.TestCase): ...@@ -99,7 +99,7 @@ class TestNanInfCheckResult(unittest.TestCase):
out = paddle.log(x) out = paddle.log(x)
sys.stdout.flush() sys.stdout.flush()
if add_assert: if add_assert:
assert False raise AssertionError()
except Exception as e: except Exception as e:
# Cannot catch the log in CUDA kernel. # Cannot catch the log in CUDA kernel.
err_str_list = ( err_str_list = (
......
...@@ -301,7 +301,9 @@ class CollectiveOptimizer(DistributedOptimizer): ...@@ -301,7 +301,9 @@ class CollectiveOptimizer(DistributedOptimizer):
def _check_condition(self, name, **kwargs): def _check_condition(self, name, **kwargs):
for k, v in kwargs.items(): for k, v in kwargs.items():
if v is True: if v is True:
assert False, "you can't use %s and %s together" % (name, k) raise AssertionError(
"you can't use %s and %s together" % (name, k)
)
def _check_collective_mode(self, main_program, optimizer, strategy): def _check_collective_mode(self, main_program, optimizer, strategy):
""" """
......
...@@ -384,13 +384,11 @@ class MoELayer(nn.Layer): ...@@ -384,13 +384,11 @@ class MoELayer(nn.Layer):
group=self.group, group=self.group,
) )
else: else:
assert ( raise AssertionError(
False "We only support naive gate, gshard gate and switch gate, but you choose {} gate.".format(
), "We only support naive gate, \
gshard gate and switch gate, \
but you choose {} gate.".format(
str(gate) str(gate)
) )
)
elif isinstance(gate, NaiveGate): elif isinstance(gate, NaiveGate):
self.top_k = gate.top_k self.top_k = gate.top_k
elif isinstance(gate, BaseGate): elif isinstance(gate, BaseGate):
......
...@@ -87,7 +87,7 @@ def _unpack_by_structure_paddle(target, structure): ...@@ -87,7 +87,7 @@ def _unpack_by_structure_paddle(target, structure):
if isinstance(ele, list): if isinstance(ele, list):
ret.append(unpack_by_structure(target[idx], ele)) ret.append(unpack_by_structure(target[idx], ele))
continue continue
assert False, "structure element must be 1 or list" raise AssertionError("structure element must be 1 or list")
return ret return ret
......
...@@ -1317,9 +1317,9 @@ class Transformer(Layer): ...@@ -1317,9 +1317,9 @@ class Transformer(Layer):
encoder_bias_attr = [bias_attr[0], bias_attr[-1]] encoder_bias_attr = [bias_attr[0], bias_attr[-1]]
decoder_bias_attr = bias_attr decoder_bias_attr = bias_attr
else: else:
assert ( raise AssertionError(
False "length of bias_attr should be 1 or 2 or 3 when it is a list/tuple"
), "length of bias_attr should be 1 or 2 or 3 when it is a list/tuple" )
else: else:
encoder_bias_attr = bias_attr encoder_bias_attr = bias_attr
decoder_bias_attr = bias_attr decoder_bias_attr = bias_attr
...@@ -1339,9 +1339,9 @@ class Transformer(Layer): ...@@ -1339,9 +1339,9 @@ class Transformer(Layer):
encoder_weight_attr = [weight_attr[0], weight_attr[-1]] encoder_weight_attr = [weight_attr[0], weight_attr[-1]]
decoder_weight_attr = weight_attr decoder_weight_attr = weight_attr
else: else:
assert ( raise AssertionError(
False "length of weight_attr should be 1 or 2 or 3 when it is a list/tuple"
), "length of weight_attr should be 1 or 2 or 3 when it is a list/tuple" )
else: else:
encoder_weight_attr = weight_attr encoder_weight_attr = weight_attr
decoder_weight_attr = weight_attr decoder_weight_attr = weight_attr
......
...@@ -168,7 +168,7 @@ class Imikolov(Dataset): ...@@ -168,7 +168,7 @@ class Imikolov(Dataset):
continue continue
self.data.append((src_seq, trg_seq)) self.data.append((src_seq, trg_seq))
else: else:
assert False, 'Unknow data type' raise AssertionError('Unknow data type')
def __getitem__(self, idx): def __getitem__(self, idx):
return tuple([np.array(d) for d in self.data[idx]]) return tuple([np.array(d) for d in self.data[idx]])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册