提交 18f0af05 编写于 作者: “liuxiao”

pylint clean

上级 c8f69f5d
...@@ -167,7 +167,7 @@ class BertAttentionMask(nn.Cell): ...@@ -167,7 +167,7 @@ class BertAttentionMask(nn.Cell):
super(BertAttentionMask, self).__init__() super(BertAttentionMask, self).__init__()
self.has_attention_mask = has_attention_mask self.has_attention_mask = has_attention_mask
self.multiply_data = Tensor([-1000.0, ], dtype=dtype) self.multiply_data = Tensor([-1000.0,], dtype=dtype)
self.multiply = P.Mul() self.multiply = P.Mul()
if self.has_attention_mask: if self.has_attention_mask:
...@@ -198,7 +198,7 @@ class BertAttentionMaskBackward(nn.Cell): ...@@ -198,7 +198,7 @@ class BertAttentionMaskBackward(nn.Cell):
dtype=mstype.float32): dtype=mstype.float32):
super(BertAttentionMaskBackward, self).__init__() super(BertAttentionMaskBackward, self).__init__()
self.has_attention_mask = has_attention_mask self.has_attention_mask = has_attention_mask
self.multiply_data = Tensor([-1000.0, ], dtype=dtype) self.multiply_data = Tensor([-1000.0,], dtype=dtype)
self.multiply = P.Mul() self.multiply = P.Mul()
self.attention_mask = Tensor(np.ones(shape=attention_mask_shape).astype(np.float32)) self.attention_mask = Tensor(np.ones(shape=attention_mask_shape).astype(np.float32))
if self.has_attention_mask: if self.has_attention_mask:
......
...@@ -136,7 +136,7 @@ def test_LSTM(): ...@@ -136,7 +136,7 @@ def test_LSTM():
train_network.set_train() train_network.set_train()
train_features = Tensor(np.ones([64, max_len]).astype(np.int32)) train_features = Tensor(np.ones([64, max_len]).astype(np.int32))
train_labels = Tensor(np.ones([64, ]).astype(np.int32)[0:64]) train_labels = Tensor(np.ones([64,]).astype(np.int32)[0:64])
losses = [] losses = []
for epoch in range(num_epochs): for epoch in range(num_epochs):
loss = train_network(train_features, train_labels) loss = train_network(train_features, train_labels)
......
...@@ -34,7 +34,7 @@ ndarr = np.ones((2, 3)) ...@@ -34,7 +34,7 @@ ndarr = np.ones((2, 3))
def test_tensor_flatten(): def test_tensor_flatten():
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
lst = [1, 2, 3, 4, ] lst = [1, 2, 3, 4,]
tensor_list = ms.Tensor(lst, ms.float32) tensor_list = ms.Tensor(lst, ms.float32)
tensor_list = tensor_list.Flatten() tensor_list = tensor_list.Flatten()
print(tensor_list) print(tensor_list)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册