提交 206b36cc 编写于 作者: S shijianning

fix pylint warnings

上级 a06694ab
......@@ -41,23 +41,22 @@ from config import ConfigYOLOV3ResNet18
def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
"""Set learning rate."""
lr_each_step = []
lr = learning_rate
for i in range(global_step):
if steps:
lr_each_step.append(lr * (decay_rate ** (i // decay_step)))
lr_each_step.append(learning_rate * (decay_rate ** (i // decay_step)))
else:
lr_each_step.append(lr * (decay_rate ** (i / decay_step)))
lr_each_step.append(learning_rate * (decay_rate ** (i / decay_step)))
lr_each_step = np.array(lr_each_step).astype(np.float32)
lr_each_step = lr_each_step[start_step:]
return lr_each_step
def init_net_param(net, init='ones'):
"""Init the parameters in net."""
def init_net_param(net, init_value='ones'):
"""Init:wq the parameters in net."""
params = net.trainable_params()
for p in params:
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
p.set_parameter_data(initializer(init, p.data.shape(), p.data.dtype()))
p.set_parameter_data(initializer(init_value, p.data.shape(), p.data.dtype()))
if __name__ == '__main__':
......
......@@ -15,9 +15,9 @@
"""format transform function"""
import _akg
def refine_reduce_axis(input, axis):
def refine_reduce_axis(input_content, axis):
"""make reduce axis legal."""
shape = get_shape(input)
shape = get_shape(input_content)
if axis is None:
axis = [i for i in range(len(shape))]
elif isinstance(axis, int):
......
......@@ -55,7 +55,7 @@ class ClipGradients(nn.Cell):
grads,
clip_type,
clip_value):
if clip_type != 0 and clip_type != 1:
if clip_type not in (0, 1):
return grads
new_grads = ()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册