未验证 提交 b9e7b276 编写于 作者: N Nicky Chan 提交者: GitHub

Simplify Paddle demo by showing vgg only and add comments (#425)

* Simplify Paddle demo by showing vgg only and add comments

* update python format

* fix format issue
上级 2642aab1
......@@ -6,7 +6,7 @@ there are several demos for different platforms.
## PaddlePaddle
Locates in `./paddle`.
This is a visualization for `resnet` on `cifar10` dataset, we visualize the CONV parameters,
This is a visualization for `vgg` on `cifar10` dataset, we visualize the CONV parameters,
and there are some interesting patterns.
## PyTorch GAN
......
......@@ -15,8 +15,6 @@
from __future__ import print_function
import sys
import numpy as np
from visualdl import LogWriter
......@@ -26,10 +24,13 @@ import paddle.v2.fluid.framework as framework
from paddle.v2.fluid.initializer import NormalInitializer
from paddle.v2.fluid.param_attr import ParamAttr
# create VisualDL logger and directory
logdir = "./tmp"
logwriter = LogWriter(logdir, sync_cycle=10)
# create 'train' run
with logwriter.mode("train") as writer:
# create 'loss' scalar tag to keep track of loss function
loss_scalar = writer.scalar("loss")
with logwriter.mode("train") as writer:
......@@ -37,53 +38,13 @@ with logwriter.mode("train") as writer:
num_samples = 4
with logwriter.mode("train") as writer:
conv_image = writer.image("conv_image", num_samples, 1)
conv_image = writer.image("conv_image", num_samples,
1) # show 4 samples for every 1 step
input_image = writer.image("input_image", num_samples, 1)
with logwriter.mode("train") as writer:
param1_histgram = writer.histogram("param1", 100)
def resnet_cifar10(input, depth=32):
def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'):
tmp = fluid.layers.conv2d(
input=input,
filter_size=filter_size,
num_filters=ch_out,
stride=stride,
padding=padding,
act=None,
bias_attr=False)
return fluid.layers.batch_norm(input=tmp, act=act)
def shortcut(input, ch_in, ch_out, stride):
if ch_in != ch_out:
return conv_bn_layer(input, ch_out, 1, stride, 0, None)
else:
return input
def basicblock(input, ch_in, ch_out, stride):
tmp = conv_bn_layer(input, ch_out, 3, stride, 1)
tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None)
short = shortcut(input, ch_in, ch_out, stride)
return fluid.layers.elementwise_add(x=tmp, y=short, act='relu')
def layer_warp(block_func, input, ch_in, ch_out, count, stride):
tmp = block_func(input, ch_in, ch_out, stride)
for i in range(1, count):
tmp = block_func(tmp, ch_out, ch_out, 1)
return tmp
assert (depth - 2) % 6 == 0
n = (depth - 2) / 6
conv1 = conv_bn_layer(
input=input, ch_out=16, filter_size=3, stride=1, padding=1)
res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
res2 = layer_warp(basicblock, res1, 16, 32, n, 2)
res3 = layer_warp(basicblock, res2, 32, 64, n, 2)
pool = fluid.layers.pool2d(
input=res3, pool_size=8, pool_type='avg', pool_stride=1)
return pool
param1_histgram = writer.histogram(
"param1", 100) # 100 buckets, e.g 100 data sets in a histograms
def vgg16_bn_drop(input):
......@@ -119,18 +80,7 @@ data_shape = [3, 32, 32]
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
net_type = "vgg"
if len(sys.argv) >= 2:
net_type = sys.argv[1]
if net_type == "vgg":
print("train vgg net")
net, conv1 = vgg16_bn_drop(images)
elif net_type == "resnet":
print("train resnet")
net = resnet_cifar10(images, 32)
else:
raise ValueError("%s network is not supported" % net_type)
net, conv1 = vgg16_bn_drop(images)
predict = fluid.layers.fc(
input=net,
......@@ -173,6 +123,9 @@ for pass_id in range(PASS_NUM):
fetch_list=[avg_cost, conv1, param1_var] + accuracy.metrics)
pass_acc = accuracy.eval(exe)
# all code below is for VisualDL
# start picking sample from beginning
if sample_num == 0:
input_image.start_sampling()
conv_image.start_sampling()
......@@ -183,21 +136,26 @@ for pass_id in range(PASS_NUM):
idx = idx1
if idx != -1:
image_data = data[0][0]
# reshape the image to 32x32 and 3 channels
input_image_data = np.transpose(
image_data.reshape(data_shape), axes=[1, 2, 0])
# add sample to VisualDL Image Writer to view input image
input_image.set_sample(idx, input_image_data.shape,
input_image_data.flatten())
conv_image_data = conv1_out[0][0]
# add sample to view conv image
conv_image.set_sample(idx, conv_image_data.shape,
conv_image_data.flatten())
sample_num += 1
# when we have enough samples, call finish sampling()
if sample_num % num_samples == 0:
input_image.finish_sampling()
conv_image.finish_sampling()
sample_num = 0
# add record for loss and accuracy to scalar
loss_scalar.add_record(step, loss)
acc_scalar.add_record(step, acc)
param1_histgram.add_record(step, param1.flatten())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册