未验证 提交 f43a57b2 编写于 作者: W whs 提交者: GitHub

Fix unitest (#809)

上级 c785d0ea
...@@ -131,7 +131,8 @@ class PruningPlan(): ...@@ -131,7 +131,8 @@ class PruningPlan():
backup_name = var_tmp.name.replace(".", "_") + "_backup" backup_name = var_tmp.name.replace(".", "_") + "_backup"
if backup_name not in sub_layer._buffers: if backup_name not in sub_layer._buffers:
sub_layer.register_buffer( sub_layer.register_buffer(
backup_name, paddle.to_tensor(var_tmp.value().get_tensor())) backup_name,
paddle.to_tensor(np.array(var_tmp.value().get_tensor())))
_logger.debug("Backup values of {} into buffers.".format( _logger.debug("Backup values of {} into buffers.".format(
var_tmp.name)) var_tmp.name))
......
...@@ -55,11 +55,8 @@ class TestSoftLabelLoss(StaticCase): ...@@ -55,11 +55,8 @@ class TestSoftLabelLoss(StaticCase):
for op in block.ops: for op in block.ops:
loss_ops.append(op.type) loss_ops.append(op.type)
self.assertTrue(set(merged_ops).difference(set(loss_ops)) == set()) self.assertTrue(set(merged_ops).difference(set(loss_ops)) == set())
self.assertTrue( self.assertTrue({'cross_entropy', 'softmax', 'reduce_mean'}.issubset(
set(loss_ops).difference(set(merged_ops)) == { set(loss_ops).difference(set(merged_ops))))
'cross_entropy', 'softmax', 'reduce_mean', 'fill_constant',
'elementwise_div'
})
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册