From 4724c44b19d13882651c823af48f5303209a57cb Mon Sep 17 00:00:00 2001 From: interface_xiongtete <1144722582@qq.com> Date: Sat, 17 Sep 2022 10:35:13 +0800 Subject: [PATCH] =?UTF-8?q?ResNetRS(=E8=BF=81=E7=A7=BB=E5=AD=A6=E4=B9=A0)?= =?UTF-8?q?=E5=AE=9D=E5=8F=AF=E6=A2=A6=E5=9B=BE=E5=83=8F=E8=AF=86=E5=88=AB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 + .../ResNetRS50-test.py" | 77 +++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 "\345\233\276\345\203\217\350\257\206\345\210\253/\345\256\235\345\217\257\346\242\246\350\257\206\345\210\253/ResNetRS50-test.py" diff --git a/.gitignore b/.gitignore index 0f4a9a6..3ca6777 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,5 @@ /经典网络/data.zip /图像识别/花朵识别/flowers.zip /图像识别/花朵识别/model/ +/图像识别/宝可梦识别/data/ +/图像识别/宝可梦识别/model/ diff --git "a/\345\233\276\345\203\217\350\257\206\345\210\253/\345\256\235\345\217\257\346\242\246\350\257\206\345\210\253/ResNetRS50-test.py" "b/\345\233\276\345\203\217\350\257\206\345\210\253/\345\256\235\345\217\257\346\242\246\350\257\206\345\210\253/ResNetRS50-test.py" new file mode 100644 index 0000000..50da5bb --- /dev/null +++ "b/\345\233\276\345\203\217\350\257\206\345\210\253/\345\256\235\345\217\257\346\242\246\350\257\206\345\210\253/ResNetRS50-test.py" @@ -0,0 +1,77 @@ +import tensorflow as tf +import numpy as np +import os +import matplotlib.pyplot as plt +from tensorflow.keras import layers +from tensorflow.keras.models import load_model + +# from PIL import Image + +model = load_model('model/ResNetRS50-pokeman.h5') +model.summary() + +# 类别总数 +dataset_dir = 'data/train' +classes = [] +for filename in os.listdir(dataset_dir): + classes.append(filename) +# print('classes:',classes) + +# 预测单张图片 +def predict_single_image(img_path): + # string类型的tensor + img = tf.io.read_file(img_path) + # 将jpg格式转换为tensor + img = tf.image.decode_jpeg(img, channels=3) + # 数据归一化 + img = tf.image.convert_image_dtype(img, dtype=tf.float32) + # resize + img = tf.image.resize(img, size=[224, 224]) + # 扩充一个维度 + img = np.expand_dims(img, axis=0) + + # 预测:结果是二维的 + test_result = model.predict(img) + # print('test_result:', test_result) + # 转化为一维 + result = np.squeeze(test_result) + # print('转化后result:', result) + + # 找到概率值最大的索引 + predict_class = np.argmax(result) + # print('概率值最大的索引:', predict_class) + + # 返回类别和所属类别的概率 + return classes[int(predict_class)], result[predict_class] + +# 对整个文件夹的图片进行预测 +def predict_directory(file_path): + classes_pred=[] + classes_true=[] + probs=[] + for file in os.listdir(file_path): + # 测试图片完整路径 + file_dir=os.path.join(file_path,file) + # 打印文件路径 + print(file_dir) + # 传入文件路径进行预测 + preds,prob=predict_single_image(file_dir) + + # 取出图片的真实标签(这里直接将文件夹名称作为真实标签值了) + # label_true=file.split('_')[0].title() + label_true = file_dir.split('\\')[0].split('/')[-1] + # 保存真实值和预测值结果 + classes_true.append(label_true) + classes_pred.append(preds) + probs.append(prob) + return classes_pred,classes_true,probs + +# img_path = 'Gemstones/train/Almandine/almandine_0.jpg' +# classes, prob = predict_single_image(img_path) +# print(classes, prob) + +file_path= 'data/test/bulbasaur' +classes_pred,classes_true,probs=predict_directory(file_path) +print(classes_pred) +print(classes_true) +print(probs) \ No newline at end of file -- GitLab