diff --git a/README.md b/README.md index 24202f9381d83f76cb3fb711b470502ae31521ca..9b5b78985ea9a1084336d081330fd8861d7f8edb 100644 --- a/README.md +++ b/README.md @@ -10,11 +10,12 @@ * PyTorch >= 1.5.1 ## 数据集 -采用"Stanford Dogs Dataset"数据集官方地址:http://vision.stanford.edu/aditya86/ImageNetDogs/,且分为训练和测试两部分。 -本文将该数据集的标注文件改为xml格式,可以通过运行read_datasests.py,可以对数据的标注信息进行解析可视化。 +* 采用"Stanford Dogs Dataset"数据集官方地址:http://vision.stanford.edu/aditya86/ImageNetDogs/,且分为训练和测试两部分。 +* 本文将该数据集的标注文件更改为xml格式,[数据集下载地址(百度网盘 Password: ks87 )](https://pan.baidu.com/s/1tT0wF4N2I9p5JDfCwtM1CQ) +* 通过运行read_datasests.py,可以对数据的标注信息进行解析可视化。 ## 预训练模型 -* [预训练模型下载地址: ~ +* [预训练模型下载地址(百度网盘 Password: ks87 )](https://pan.baidu.com/s/1tT0wF4N2I9p5JDfCwtM1CQ) ## 项目使用方法 diff --git a/inference.py b/inference.py index a86ed150ffd1256f1775e74f543111ce3cf8b747..8bb2f82f1e39961a2f4b453b0d15d44b719837f8 100644 --- a/inference.py +++ b/inference.py @@ -41,7 +41,7 @@ def get_xml_msg(path): if __name__ == "__main__": parser = argparse.ArgumentParser(description=' Project Classification Test') - parser.add_argument('--test_model', type=str, default = './model_exp/2021-02-09_06-32-32/model_epoch-627.pth', + parser.add_argument('--test_model', type=str, default = './model_exp/2021-02-09_06-32-32/resnet50_epoch-627.pth', help = 'test_model') # 模型路径 parser.add_argument('--model', type=str, default = 'resnet_50', help = 'model : resnet_18,resnet_34,resnet_50,resnet_101,resnet_152') # 模型类型 @@ -177,7 +177,7 @@ if __name__ == "__main__": dict_r[doc] += 1 cv2.destroyAllWindows() - # Top1 的 每类预测精确度。 + # Top1 的每类预测精确度。 print('\n-----------------------------------------------\n') acc_list = [] for idx,doc in enumerate(sorted(os.listdir(ops.test_path), key=lambda x:int(x.split('-')[0]), reverse=False)): diff --git a/read_datasets.py b/read_datasets.py index ba6fd637628845e335ec8d2ccd2aedc1b77d0768..30958f686f18bd0fe92a5b1604640f82904a5e95 100644 --- a/read_datasets.py +++ b/read_datasets.py @@ -1,7 +1,7 @@ #-*-coding:utf-8-*- # date:2020-02-08 -# Author: Eric.Lee -## function: read datasets label files +# author: Eric.Lee +# function: read datasets label files import os import cv2 @@ -56,7 +56,7 @@ if __name__ == "__main__": cv2.putText(img_, ('index : ' + str(idx)), (5,img_.shape[0]-5),cv2.FONT_HERSHEY_PLAIN, 1.8, (255, 255, 0), 6) cv2.putText(img_, ('index : ' + str(idx)), (5,img_.shape[0]-5),cv2.FONT_HERSHEY_PLAIN, 1.8, (255, 60, 255), 2) - xml_ = path + doc_+"/"+f_.strip('.jpg').strip('.png')+'.xml' + xml_ = path + doc_+"/"+f_.replace(".jpg",".xml").replace(".png",".xml") list_x = get_xml_msg(xml_)# 获取 xml 文件 的 object