未验证 提交 362dae55 编写于 作者: Y Yan Chunwei 提交者: GitHub

fix image one sample bug (#156)

上级 9f0872be
import numpy as np
import mxnet as mx
import logging
import mxnet as mx
......@@ -22,6 +24,8 @@ with logger.mode("train"):
# scalar0 is used to record scalar metrics while MXNet is training. We will record accuracy.
# In the visualization, we can see the accuracy is increasing as more training steps happen.
scalar0 = logger.scalar("scalars/scalar0")
image0 = logger.image("images/image0", 1)
histogram0 = logger.histogram("histogram/histogram0", num_buckets=100)
# Record training steps
cnt_step = 0
......@@ -42,6 +46,19 @@ def add_scalar():
cnt_step += 1
return _callback
def add_image_histogram():
def _callback(iter_no, sym, arg, aux):
image0.start_sampling()
weight = arg['fullyconnected1_weight'].asnumpy()
shape = [100, 50]
data = weight.flatten()
image0.add_sample(shape, list(data))
histogram0.add_record(iter_no, list(data))
image0.finish_sampling()
return _callback
# Start to build CNN in MXNet, train MNIST dataset. For more info, check MXNet's official website:
# https://mxnet.incubator.apache.org/tutorials/python/mnist.html
......@@ -81,7 +98,8 @@ lenet_model.fit(train_iter,
eval_metric='acc',
# integrate our customized callback method
batch_end_callback=[add_scalar()],
num_epoch=2)
epoch_end_callback=[add_image_histogram()],
num_epoch=5)
test_iter = mx.io.NDArrayIter(mnist['test_data'], None, batch_size)
prob = lenet_model.predict(test_iter)
......
......@@ -140,7 +140,7 @@ class LogWriter(object):
}
return type2scalar[type](tag)
def image(self, tag, num_samples, step_cycle):
def image(self, tag, num_samples, step_cycle=1):
"""
Create an image writer that used to write image data.
"""
......
import pprint
import re
import sys
import time
import urllib
from tempfile import NamedTemporaryFile
......@@ -131,6 +132,7 @@ def get_invididual_image(storage, mode, tag, step_index, max_size=80):
with storage.mode(mode) as reader:
res = re.search(r".*/([0-9]+$)", tag)
# remove suffix '/x'
offset = 0
if res:
offset = int(res.groups()[0])
tag = tag[:tag.rfind('/')]
......@@ -206,4 +208,6 @@ def retry(ntimes, function, time2sleep, *args, **kwargs):
try:
return function(*args, **kwargs)
except:
error_info = '\n'.join(map(str, sys.exc_info()))
logger.error("Unexpected error: %s" % error_info)
time.sleep(time2sleep)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册