提交 569a9d39 编写于 作者: W wuzewu

Update image classification module

上级 e6de8757
......@@ -21,8 +21,8 @@ from efficientnetb0_small_imagenet.efficientnet import EfficientNetB0_small
@moduleinfo(
name="efficientnetb0_small_imagenet",
type="CV/image_classification",
author="baidu-vis",
author_email="",
author="paddlepaddle",
author_email="paddle-dev@baidu.com",
summary=
"ResNet18vd is a image classfication model, this module is trained with imagenet datasets.",
version="1.0.0")
......
......@@ -49,7 +49,7 @@ def postprocess(data_out, label_list, top_k):
output_i = {}
indexs = np.argsort(result_i)[::-1][0:top_k]
for index in indexs:
label = label_list[index]
label = label_list[index].split(',')[0]
output_i[label] = float(result_i[index])
output.append(output_i)
return output
......@@ -21,15 +21,15 @@ from fix_resnext101_32x48d_wsl_imagenet.resnext101_wsl import Fix_ResNeXt101_32x
@moduleinfo(
name="fix_resnext101_32x48d_wsl_imagenet",
type="CV/image_classification",
author="baidu-vis",
author_email="",
author="paddlepaddle",
author_email="paddle-dev@baidu.com",
summary=
"fix_resnext101_32x48d_wsl is a image classfication model, this module is trained with imagenet datasets.",
version="1.0.0")
class FixResnext10132x48dwslImagenet(hub.Module):
def _initialize(self):
self.default_pretrained_model_path = os.path.join(
self.directory, "fix_resnext101_32x48d_wsl_imagenet_model")
self.directory, "model")
label_file = os.path.join(self.directory, "label_list.txt")
with open(label_file, 'r', encoding='utf-8') as file:
self.label_list = file.read().split("\n")[:-1]
......
......@@ -49,7 +49,7 @@ def postprocess(data_out, label_list, top_k):
output_i = {}
indexs = np.argsort(result_i)[::-1][0:top_k]
for index in indexs:
label = label_list[index]
label = label_list[index].split(',')[0]
output_i[label] = float(result_i[index])
output.append(output_i)
return output
......@@ -21,8 +21,8 @@ from res2net101_vd_26w_4s_imagenet.res2net_vd import Res2Net101_vd_26w_4s
@moduleinfo(
name="res2net101_vd_26w_4s_imagenet",
type="CV/image_classification",
author="baidu-vis",
author_email="",
author="paddlepaddle",
author_email="paddle-dev@baidu.com",
summary=
"res2net101_vd_26w_4s is a image classfication model, this module is trained with imagenet datasets.",
version="1.0.0")
......
......@@ -49,7 +49,7 @@ def postprocess(data_out, label_list, top_k):
output_i = {}
indexs = np.argsort(result_i)[::-1][0:top_k]
for index in indexs:
label = label_list[index]
label = label_list[index].split(',')[0]
output_i[label] = float(result_i[index])
output.append(output_i)
return output
......@@ -21,8 +21,8 @@ from resnet18_vd_imagenet.resnet_vd import ResNet18_vd
@moduleinfo(
name="resnet18_vd_imagenet",
type="CV/image_classification",
author="baidu-vis",
author_email="",
author="paddlepaddle",
author_email="paddle-dev@baidu.com",
summary=
"ResNet18vd is a image classfication model, this module is trained with imagenet datasets.",
version="1.0.0")
......
......@@ -49,7 +49,7 @@ def postprocess(data_out, label_list, top_k):
output_i = {}
indexs = np.argsort(result_i)[::-1][0:top_k]
for index in indexs:
label = label_list[index]
label = label_list[index].split(',')[0]
output_i[label] = float(result_i[index])
output.append(output_i)
return output
......@@ -21,8 +21,8 @@ from se_resnet18_vd_imagenet.se_resnet import SE_ResNet18_vd
@moduleinfo(
name="se_resnet18_vd_imagenet",
type="CV/image_classification",
author="baidu-vis",
author_email="",
author="paddlepaddle",
author_email="paddle-dev@baidu.com",
summary=
"SE_ResNet18_vd is a image classfication model, this module is trained with imagenet datasets.",
version="1.0.0")
......
......@@ -49,7 +49,7 @@ def postprocess(data_out, label_list, top_k):
output_i = {}
indexs = np.argsort(result_i)[::-1][0:top_k]
for index in indexs:
label = label_list[index]
label = label_list[index].split(',')[0]
output_i[label] = float(result_i[index])
output.append(output_i)
return output
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册