Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
别团等shy哥发育
Tensorflow Deep Learning
提交
4724c44b
T
Tensorflow Deep Learning
项目概览
别团等shy哥发育
/
Tensorflow Deep Learning
9 个月 前同步成功
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
Tensorflow Deep Learning
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
4724c44b
编写于
9月 17, 2022
作者:
别团等shy哥发育
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
ResNetRS(迁移学习)宝可梦图像识别
上级
a4084395
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
80 addition
and
0 deletion
+80
-0
.gitignore
.gitignore
+2
-0
图像识别/宝可梦识别/ResNetRS50-test.py
图像识别/宝可梦识别/ResNetRS50-test.py
+78
-0
未找到文件。
.gitignore
浏览文件 @
4724c44b
...
...
@@ -45,3 +45,5 @@
/经典网络/data.zip
/图像识别/花朵识别/flowers.zip
/图像识别/花朵识别/model/
/图像识别/宝可梦识别/data/
/图像识别/宝可梦识别/model/
图像识别/宝可梦识别/ResNetRS50-test.py
0 → 100644
浏览文件 @
4724c44b
import
tensorflow
as
tf
import
numpy
as
np
import
os
import
matplotlib.pyplot
as
plt
from
tensorflow.keras
import
layers
from
tensorflow.keras.models
import
load_model
# from PIL import Image
model
=
load_model
(
'model/ResNetRS50-pokeman.h5'
)
model
.
summary
()
# 类别总数
dataset_dir
=
'data/train'
classes
=
[]
for
filename
in
os
.
listdir
(
dataset_dir
):
classes
.
append
(
filename
)
# print('classes:',classes)
# 预测单张图片
def
predict_single_image
(
img_path
):
# string类型的tensor
img
=
tf
.
io
.
read_file
(
img_path
)
# 将jpg格式转换为tensor
img
=
tf
.
image
.
decode_jpeg
(
img
,
channels
=
3
)
# 数据归一化
img
=
tf
.
image
.
convert_image_dtype
(
img
,
dtype
=
tf
.
float32
)
# resize
img
=
tf
.
image
.
resize
(
img
,
size
=
[
224
,
224
])
# 扩充一个维度
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
# 预测:结果是二维的
test_result
=
model
.
predict
(
img
)
# print('test_result:', test_result)
# 转化为一维
result
=
np
.
squeeze
(
test_result
)
# print('转化后result:', result)
# 找到概率值最大的索引
predict_class
=
np
.
argmax
(
result
)
# print('概率值最大的索引:', predict_class)
# 返回类别和所属类别的概率
return
classes
[
int
(
predict_class
)],
result
[
predict_class
]
# 对整个文件夹的图片进行预测
def
predict_directory
(
file_path
):
classes_pred
=
[]
classes_true
=
[]
probs
=
[]
for
file
in
os
.
listdir
(
file_path
):
# 测试图片完整路径
file_dir
=
os
.
path
.
join
(
file_path
,
file
)
# 打印文件路径
print
(
file_dir
)
# 传入文件路径进行预测
preds
,
prob
=
predict_single_image
(
file_dir
)
# 取出图片的真实标签(这里直接将文件夹名称作为真实标签值了)
# label_true=file.split('_')[0].title()
label_true
=
file_dir
.
split
(
'
\\
'
)[
0
].
split
(
'/'
)[
-
1
]
# 保存真实值和预测值结果
classes_true
.
append
(
label_true
)
classes_pred
.
append
(
preds
)
probs
.
append
(
prob
)
return
classes_pred
,
classes_true
,
probs
# img_path = 'Gemstones/train/Almandine/almandine_0.jpg'
# classes, prob = predict_single_image(img_path)
# print(classes, prob)
file_path
=
'data/test/bulbasaur'
classes_pred
,
classes_true
,
probs
=
predict_directory
(
file_path
)
print
(
classes_pred
)
print
(
classes_true
)
print
(
probs
)
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录