diff --git a/.gitignore b/.gitignore index 0f4a9a6a4b5edfc84c9bf439486bc3122229d4f4..3ca6777fe130eebac83b6b10c34322687e58f3aa 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,5 @@ /经典网络/data.zip /图像识别/花朵识别/flowers.zip /图像识别/花朵识别/model/ +/图像识别/宝可梦识别/data/ +/图像识别/宝可梦识别/model/ diff --git "a/\345\233\276\345\203\217\350\257\206\345\210\253/\345\256\235\345\217\257\346\242\246\350\257\206\345\210\253/ResNetRS50-test.py" "b/\345\233\276\345\203\217\350\257\206\345\210\253/\345\256\235\345\217\257\346\242\246\350\257\206\345\210\253/ResNetRS50-test.py" new file mode 100644 index 0000000000000000000000000000000000000000..50da5bb3407c2427ed7be1b0a3393b0de13d50e1 --- /dev/null +++ "b/\345\233\276\345\203\217\350\257\206\345\210\253/\345\256\235\345\217\257\346\242\246\350\257\206\345\210\253/ResNetRS50-test.py" @@ -0,0 +1,77 @@ +import tensorflow as tf +import numpy as np +import os +import matplotlib.pyplot as plt +from tensorflow.keras import layers +from tensorflow.keras.models import load_model + +# from PIL import Image + +model = load_model('model/ResNetRS50-pokeman.h5') +model.summary() + +# 类别总数 +dataset_dir = 'data/train' +classes = [] +for filename in os.listdir(dataset_dir): + classes.append(filename) +# print('classes:',classes) + +# 预测单张图片 +def predict_single_image(img_path): + # string类型的tensor + img = tf.io.read_file(img_path) + # 将jpg格式转换为tensor + img = tf.image.decode_jpeg(img, channels=3) + # 数据归一化 + img = tf.image.convert_image_dtype(img, dtype=tf.float32) + # resize + img = tf.image.resize(img, size=[224, 224]) + # 扩充一个维度 + img = np.expand_dims(img, axis=0) + + # 预测:结果是二维的 + test_result = model.predict(img) + # print('test_result:', test_result) + # 转化为一维 + result = np.squeeze(test_result) + # print('转化后result:', result) + + # 找到概率值最大的索引 + predict_class = np.argmax(result) + # print('概率值最大的索引:', predict_class) + + # 返回类别和所属类别的概率 + return classes[int(predict_class)], result[predict_class] + +# 对整个文件夹的图片进行预测 +def predict_directory(file_path): + classes_pred=[] + classes_true=[] + probs=[] + for file in os.listdir(file_path): + # 测试图片完整路径 + file_dir=os.path.join(file_path,file) + # 打印文件路径 + print(file_dir) + # 传入文件路径进行预测 + preds,prob=predict_single_image(file_dir) + + # 取出图片的真实标签(这里直接将文件夹名称作为真实标签值了) + # label_true=file.split('_')[0].title() + label_true = file_dir.split('\\')[0].split('/')[-1] + # 保存真实值和预测值结果 + classes_true.append(label_true) + classes_pred.append(preds) + probs.append(prob) + return classes_pred,classes_true,probs + +# img_path = 'Gemstones/train/Almandine/almandine_0.jpg' +# classes, prob = predict_single_image(img_path) +# print(classes, prob) + +file_path= 'data/test/bulbasaur' +classes_pred,classes_true,probs=predict_directory(file_path) +print(classes_pred) +print(classes_true) +print(probs) \ No newline at end of file