提交 cd965f25 编写于 作者: C chenguowei01

update predict function

上级 520736d0
...@@ -400,14 +400,19 @@ class DeepLabv3p(BaseAPI): ...@@ -400,14 +400,19 @@ class DeepLabv3p(BaseAPI):
fetch_list=list(self.test_outputs.values())) fetch_list=list(self.test_outputs.values()))
pred = result[0] pred = result[0]
pred = np.squeeze(pred).astype('uint8') pred = np.squeeze(pred).astype('uint8')
logit = result[1]
logit = np.squeeze(logit)
logit = np.transpose(logit, (1, 2, 0))
for info in im_info[::-1]: for info in im_info[::-1]:
if info[0] == 'resize': if info[0] == 'resize':
w, h = info[1][1], info[1][0] w, h = info[1][1], info[1][0]
pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST) pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
logit = cv2.resize(logit, (w, h), cv2.INTER_LINEAR)
elif info[0] == 'padding': elif info[0] == 'padding':
w, h = info[1][1], info[1][0] w, h = info[1][1], info[1][0]
pred = pred[0:h, 0:w] pred = pred[0:h, 0:w]
logit = logit[0:h, 0:w, :]
else: else:
raise Exception("Unexpected info '{}' in im_info".format( raise Exception("Unexpected info '{}' in im_info".format(
info[0])) info[0]))
return {'label_map': pred, 'score_map': result[1]} return {'label_map': pred, 'score_map': logit}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册