提交 1795a00f 编写于 作者: C chenguowei01

update unet.py

上级 c79749d9
......@@ -65,7 +65,7 @@ class UNet(object):
if self.env_info['place'] == 'cpu':
self.places = fluid.CPUPlace()
else:
self.places = fluid.CUDAPlaces()
self.places = fluid.CUDAPlace(0)
def build_model(self):
self.model = nets.UNet(self.num_classes, self.upsample_mode)
......@@ -219,6 +219,7 @@ class UNet(object):
tuple (metrics, eval_details):当return_details为True时,增加返回dict (eval_details),
包含关键字:'confusion_matrix',表示评估的混淆矩阵。
"""
self.model.eval()
self.arrange_transform(transforms=eval_dataset.transforms, mode='train')
total_steps = math.ceil(eval_dataset.num_samples * 1.0 / batch_size)
conf_mat = ConfusionMatrix(self.num_classes, streaming=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册