提交 3c59cde2 编写于 作者: D dayhaha

add normalization in predict.py

上级 e28ad8fe
...@@ -46,6 +46,7 @@ class Prediction(): ...@@ -46,6 +46,7 @@ class Prediction():
self.network.loadParameters(model_dir) self.network.loadParameters(model_dir)
self.images, self.labels = read_data(data_dir, "t10k") self.images, self.labels = read_data(data_dir, "t10k")
self.images = self.images / 255.0 * 2.0 - 1.0 # normalized to [-1,1]
slots = [dense_vector(28 * 28)] slots = [dense_vector(28 * 28)]
self.converter = DataProviderConverter(slots) self.converter = DataProviderConverter(slots)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册