提交 569a9d39 编写于 作者: W wuzewu

Update image classification module

上级 e6de8757
...@@ -21,8 +21,8 @@ from efficientnetb0_small_imagenet.efficientnet import EfficientNetB0_small ...@@ -21,8 +21,8 @@ from efficientnetb0_small_imagenet.efficientnet import EfficientNetB0_small
@moduleinfo( @moduleinfo(
name="efficientnetb0_small_imagenet", name="efficientnetb0_small_imagenet",
type="CV/image_classification", type="CV/image_classification",
author="baidu-vis", author="paddlepaddle",
author_email="", author_email="paddle-dev@baidu.com",
summary= summary=
"ResNet18vd is a image classfication model, this module is trained with imagenet datasets.", "ResNet18vd is a image classfication model, this module is trained with imagenet datasets.",
version="1.0.0") version="1.0.0")
......
...@@ -49,7 +49,7 @@ def postprocess(data_out, label_list, top_k): ...@@ -49,7 +49,7 @@ def postprocess(data_out, label_list, top_k):
output_i = {} output_i = {}
indexs = np.argsort(result_i)[::-1][0:top_k] indexs = np.argsort(result_i)[::-1][0:top_k]
for index in indexs: for index in indexs:
label = label_list[index] label = label_list[index].split(',')[0]
output_i[label] = float(result_i[index]) output_i[label] = float(result_i[index])
output.append(output_i) output.append(output_i)
return output return output
...@@ -21,15 +21,15 @@ from fix_resnext101_32x48d_wsl_imagenet.resnext101_wsl import Fix_ResNeXt101_32x ...@@ -21,15 +21,15 @@ from fix_resnext101_32x48d_wsl_imagenet.resnext101_wsl import Fix_ResNeXt101_32x
@moduleinfo( @moduleinfo(
name="fix_resnext101_32x48d_wsl_imagenet", name="fix_resnext101_32x48d_wsl_imagenet",
type="CV/image_classification", type="CV/image_classification",
author="baidu-vis", author="paddlepaddle",
author_email="", author_email="paddle-dev@baidu.com",
summary= summary=
"fix_resnext101_32x48d_wsl is a image classfication model, this module is trained with imagenet datasets.", "fix_resnext101_32x48d_wsl is a image classfication model, this module is trained with imagenet datasets.",
version="1.0.0") version="1.0.0")
class FixResnext10132x48dwslImagenet(hub.Module): class FixResnext10132x48dwslImagenet(hub.Module):
def _initialize(self): def _initialize(self):
self.default_pretrained_model_path = os.path.join( self.default_pretrained_model_path = os.path.join(
self.directory, "fix_resnext101_32x48d_wsl_imagenet_model") self.directory, "model")
label_file = os.path.join(self.directory, "label_list.txt") label_file = os.path.join(self.directory, "label_list.txt")
with open(label_file, 'r', encoding='utf-8') as file: with open(label_file, 'r', encoding='utf-8') as file:
self.label_list = file.read().split("\n")[:-1] self.label_list = file.read().split("\n")[:-1]
......
...@@ -49,7 +49,7 @@ def postprocess(data_out, label_list, top_k): ...@@ -49,7 +49,7 @@ def postprocess(data_out, label_list, top_k):
output_i = {} output_i = {}
indexs = np.argsort(result_i)[::-1][0:top_k] indexs = np.argsort(result_i)[::-1][0:top_k]
for index in indexs: for index in indexs:
label = label_list[index] label = label_list[index].split(',')[0]
output_i[label] = float(result_i[index]) output_i[label] = float(result_i[index])
output.append(output_i) output.append(output_i)
return output return output
...@@ -21,8 +21,8 @@ from res2net101_vd_26w_4s_imagenet.res2net_vd import Res2Net101_vd_26w_4s ...@@ -21,8 +21,8 @@ from res2net101_vd_26w_4s_imagenet.res2net_vd import Res2Net101_vd_26w_4s
@moduleinfo( @moduleinfo(
name="res2net101_vd_26w_4s_imagenet", name="res2net101_vd_26w_4s_imagenet",
type="CV/image_classification", type="CV/image_classification",
author="baidu-vis", author="paddlepaddle",
author_email="", author_email="paddle-dev@baidu.com",
summary= summary=
"res2net101_vd_26w_4s is a image classfication model, this module is trained with imagenet datasets.", "res2net101_vd_26w_4s is a image classfication model, this module is trained with imagenet datasets.",
version="1.0.0") version="1.0.0")
......
...@@ -49,7 +49,7 @@ def postprocess(data_out, label_list, top_k): ...@@ -49,7 +49,7 @@ def postprocess(data_out, label_list, top_k):
output_i = {} output_i = {}
indexs = np.argsort(result_i)[::-1][0:top_k] indexs = np.argsort(result_i)[::-1][0:top_k]
for index in indexs: for index in indexs:
label = label_list[index] label = label_list[index].split(',')[0]
output_i[label] = float(result_i[index]) output_i[label] = float(result_i[index])
output.append(output_i) output.append(output_i)
return output return output
...@@ -21,8 +21,8 @@ from resnet18_vd_imagenet.resnet_vd import ResNet18_vd ...@@ -21,8 +21,8 @@ from resnet18_vd_imagenet.resnet_vd import ResNet18_vd
@moduleinfo( @moduleinfo(
name="resnet18_vd_imagenet", name="resnet18_vd_imagenet",
type="CV/image_classification", type="CV/image_classification",
author="baidu-vis", author="paddlepaddle",
author_email="", author_email="paddle-dev@baidu.com",
summary= summary=
"ResNet18vd is a image classfication model, this module is trained with imagenet datasets.", "ResNet18vd is a image classfication model, this module is trained with imagenet datasets.",
version="1.0.0") version="1.0.0")
......
...@@ -49,7 +49,7 @@ def postprocess(data_out, label_list, top_k): ...@@ -49,7 +49,7 @@ def postprocess(data_out, label_list, top_k):
output_i = {} output_i = {}
indexs = np.argsort(result_i)[::-1][0:top_k] indexs = np.argsort(result_i)[::-1][0:top_k]
for index in indexs: for index in indexs:
label = label_list[index] label = label_list[index].split(',')[0]
output_i[label] = float(result_i[index]) output_i[label] = float(result_i[index])
output.append(output_i) output.append(output_i)
return output return output
...@@ -21,8 +21,8 @@ from se_resnet18_vd_imagenet.se_resnet import SE_ResNet18_vd ...@@ -21,8 +21,8 @@ from se_resnet18_vd_imagenet.se_resnet import SE_ResNet18_vd
@moduleinfo( @moduleinfo(
name="se_resnet18_vd_imagenet", name="se_resnet18_vd_imagenet",
type="CV/image_classification", type="CV/image_classification",
author="baidu-vis", author="paddlepaddle",
author_email="", author_email="paddle-dev@baidu.com",
summary= summary=
"SE_ResNet18_vd is a image classfication model, this module is trained with imagenet datasets.", "SE_ResNet18_vd is a image classfication model, this module is trained with imagenet datasets.",
version="1.0.0") version="1.0.0")
......
...@@ -49,7 +49,7 @@ def postprocess(data_out, label_list, top_k): ...@@ -49,7 +49,7 @@ def postprocess(data_out, label_list, top_k):
output_i = {} output_i = {}
indexs = np.argsort(result_i)[::-1][0:top_k] indexs = np.argsort(result_i)[::-1][0:top_k]
for index in indexs: for index in indexs:
label = label_list[index] label = label_list[index].split(',')[0]
output_i[label] = float(result_i[index]) output_i[label] = float(result_i[index])
output.append(output_i) output.append(output_i)
return output return output
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册