未验证 提交 362dae55 编写于 作者: Y Yan Chunwei 提交者: GitHub

fix image one sample bug (#156)

上级 9f0872be
import numpy as np
import mxnet as mx
import logging import logging
import mxnet as mx import mxnet as mx
...@@ -22,6 +24,8 @@ with logger.mode("train"): ...@@ -22,6 +24,8 @@ with logger.mode("train"):
# scalar0 is used to record scalar metrics while MXNet is training. We will record accuracy. # scalar0 is used to record scalar metrics while MXNet is training. We will record accuracy.
# In the visualization, we can see the accuracy is increasing as more training steps happen. # In the visualization, we can see the accuracy is increasing as more training steps happen.
scalar0 = logger.scalar("scalars/scalar0") scalar0 = logger.scalar("scalars/scalar0")
image0 = logger.image("images/image0", 1)
histogram0 = logger.histogram("histogram/histogram0", num_buckets=100)
# Record training steps # Record training steps
cnt_step = 0 cnt_step = 0
...@@ -42,6 +46,19 @@ def add_scalar(): ...@@ -42,6 +46,19 @@ def add_scalar():
cnt_step += 1 cnt_step += 1
return _callback return _callback
def add_image_histogram():
def _callback(iter_no, sym, arg, aux):
image0.start_sampling()
weight = arg['fullyconnected1_weight'].asnumpy()
shape = [100, 50]
data = weight.flatten()
image0.add_sample(shape, list(data))
histogram0.add_record(iter_no, list(data))
image0.finish_sampling()
return _callback
# Start to build CNN in MXNet, train MNIST dataset. For more info, check MXNet's official website: # Start to build CNN in MXNet, train MNIST dataset. For more info, check MXNet's official website:
# https://mxnet.incubator.apache.org/tutorials/python/mnist.html # https://mxnet.incubator.apache.org/tutorials/python/mnist.html
...@@ -81,7 +98,8 @@ lenet_model.fit(train_iter, ...@@ -81,7 +98,8 @@ lenet_model.fit(train_iter,
eval_metric='acc', eval_metric='acc',
# integrate our customized callback method # integrate our customized callback method
batch_end_callback=[add_scalar()], batch_end_callback=[add_scalar()],
num_epoch=2) epoch_end_callback=[add_image_histogram()],
num_epoch=5)
test_iter = mx.io.NDArrayIter(mnist['test_data'], None, batch_size) test_iter = mx.io.NDArrayIter(mnist['test_data'], None, batch_size)
prob = lenet_model.predict(test_iter) prob = lenet_model.predict(test_iter)
......
...@@ -140,7 +140,7 @@ class LogWriter(object): ...@@ -140,7 +140,7 @@ class LogWriter(object):
} }
return type2scalar[type](tag) return type2scalar[type](tag)
def image(self, tag, num_samples, step_cycle): def image(self, tag, num_samples, step_cycle=1):
""" """
Create an image writer that used to write image data. Create an image writer that used to write image data.
""" """
......
import pprint import pprint
import re import re
import sys
import time import time
import urllib import urllib
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
...@@ -131,6 +132,7 @@ def get_invididual_image(storage, mode, tag, step_index, max_size=80): ...@@ -131,6 +132,7 @@ def get_invididual_image(storage, mode, tag, step_index, max_size=80):
with storage.mode(mode) as reader: with storage.mode(mode) as reader:
res = re.search(r".*/([0-9]+$)", tag) res = re.search(r".*/([0-9]+$)", tag)
# remove suffix '/x' # remove suffix '/x'
offset = 0
if res: if res:
offset = int(res.groups()[0]) offset = int(res.groups()[0])
tag = tag[:tag.rfind('/')] tag = tag[:tag.rfind('/')]
...@@ -206,4 +208,6 @@ def retry(ntimes, function, time2sleep, *args, **kwargs): ...@@ -206,4 +208,6 @@ def retry(ntimes, function, time2sleep, *args, **kwargs):
try: try:
return function(*args, **kwargs) return function(*args, **kwargs)
except: except:
error_info = '\n'.join(map(str, sys.exc_info()))
logger.error("Unexpected error: %s" % error_info)
time.sleep(time2sleep) time.sleep(time2sleep)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册