提交 fe63dc1d 编写于 作者: G guosheng

Merge branch 'develop' of https://github.com/PaddlePaddle/models into refine-transformer-logit

.DS_Store .DS_Store
*.pyc *.pyc
.*~
...@@ -18,19 +18,19 @@ This tool is used to convert a Caffe model to Fluid model ...@@ -18,19 +18,19 @@ This tool is used to convert a Caffe model to Fluid model
### Tested models ### Tested models
- Lenet on mnist dataset - Lenet
- ResNets:(ResNet-50, ResNet-101, ResNet-152) - ResNets:(ResNet-50, ResNet-101, ResNet-152)
model addr: `https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777`_ [model addr](https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777)
- GoogleNet: - GoogleNet:
model addr: `https://gist.github.com/jimmie33/7ea9f8ac0da259866b854460f4526034`_ [model addr](https://gist.github.com/jimmie33/7ea9f8ac0da259866b854460f4526034)
- VGG: - VGG:
model addr: `https://gist.github.com/ksimonyan/211839e770f7b538e2d8`_ [model addr](https://gist.github.com/ksimonyan/211839e770f7b538e2d8)
- AlexNet: - AlexNet:
model addr: `https://github.com/BVLC/caffe/tree/master/models/bvlc_alexnet`_ [model addr](https://github.com/BVLC/caffe/tree/master/models/bvlc_alexnet)
### Notes ### Notes
Some of this code come from here: https://github.com/ethereon/caffe-tensorflow Some of this code come from here: https://github.com/ethereon/caffe-tensorflow
#!/usr/bin/python
#
#a tool to compare tensors in two files or two directories
#
import sys
import os
def walk_dir(rootdir):
for subdir, dirs, files in os.walk(rootdir):
for file in files:
yield file
def calc_diff(f1, f2):
import numpy as np
d1 = np.load(f1).flatten()
d2 = np.load(f2).flatten()
d1_num = reduce(lambda x, y: x * y, d1.shape)
d2_num = reduce(lambda x, y: x * y, d2.shape)
if d1_num != d2_num:
print d1.shape
print d2.shape
assert (d1_num == d2_num), "their shape is not consistent"
try:
df = np.abs(d1 - d2)
max_df = np.max(df)
sq_df = np.mean(df * df)
return max_df, sq_df
except Exception as e:
return -1.0, -1.0
def compare(path1, path2):
def diff(f1, f2):
max_df, sq_df = calc_diff(f1, f2)
print('compare %s <=> %s with result[max_df:%.4e, sq_df:%.4e]' %
(f1, f2, max_df, sq_df))
assert (max_df < 1e-5), \
'max_df is too large with value[%.6e]' % (max_df)
assert (sq_df < 1e-10), \
'sq_df is too large with value[%.6e]' % (sq_df)
if os.path.exists(path1) is False:
print('not found %s' % (path1))
return 1
elif os.path.exists(path2) is False:
print('not found %s' % (path2))
return 1
if path1.find('.npy') > 0 and path2.find('.npy') > 0:
diff(path1, path2)
return
for f in walk_dir(path2):
if f.find('.npy') < 0:
continue
f1 = os.path.join(path1, f)
f2 = os.path.join(path2, f)
diff(f1, f2)
print('all checking succeed to pass')
return 0
if __name__ == "__main__":
if len(sys.argv) == 1:
path1 = 'lenet.tf/results'
path2 = 'lenet.paddle/results'
elif len(sys.argv) == 3:
path1 = sys.argv[1]
path2 = sys.argv[2]
else:
print('usage:')
print(' %s [path1] [path2]' % (sys.argv[0]))
exit(1)
print('compare inner result in %s %s' % (path1, path2))
exit(compare(path1, path2))
#!/bin/bash
#
#function:
# a tool used to check the difference of models' results generated by caffe model and paddle model
#
#howto:
# bash diff.sh resnet50 #when this has been finished, you can get the difference in precision
#
#notes:
# 0, in order to infer using caffe, we need pycaffe installed
# 1, prepare your caffe model in 'models.caffe/', eg: 'model.caffe/resnet101/resnet101.[prototxt|caffemodel]'
# 2, converted paddle model will be in 'models'
# 3, results of layers will be stored in 'results/${model_name}.[paddle|caffe]'
# 4, only the last layer will be checked by default
model_name="resnet50"
results_root="results/"
if [[ -n $1 ]];then
if [ $1 = "-h" ];then
echo "usage:"
echo " bash $0 [model_name]"
echo " eg:bash $0 resnet50"
exit 0
fi
model_name=$1
fi
mkdir -p $results_root
model_prototxt="models.caffe/$model_name/${model_name}.prototxt"
model_caffemodel="models.caffe/${model_name}/${model_name}.caffemodel"
#1, dump layers' results from paddle
paddle_results="$results_root/${model_name}.paddle"
rm -rf $paddle_results
rm -rf "results.paddle"
bash run.sh $model_name ./models.caffe/$model_name ./models/$model_name
if [[ $? -ne 0 ]] || [[ ! -e "results.paddle" ]];then
echo "not found paddle's results, maybe failed to convert"
exit 1
fi
mv results.paddle $paddle_results
#2, dump layers' results from caffe
caffe_results="$results_root/${model_name}.caffe"
rm -rf $caffe_results
rm -rf "results.caffe"
cfpython ./infer.py caffe $model_prototxt $model_caffemodel $paddle_results/data.npy
if [[ $? -ne 0 ]] || [[ ! -e "results.caffe" ]];then
echo "not found caffe's results, maybe failed to do inference with caffe"
exit 1
fi
mv results.caffe $caffe_results
#3, extract layer names
cat $model_prototxt | grep name | perl -ne 'if(/^\s*name:\s+\"([^\"]+)/){ print $1."\n";}' >.layer_names
#4, compare one by one
for i in $(cat ".layer_names" | tail -n1);do
echo "process $i"
python compare.py $caffe_results/${i}.npy $paddle_results/${i}.npy
done
...@@ -10,8 +10,11 @@ import os ...@@ -10,8 +10,11 @@ import os
import sys import sys
import inspect import inspect
import numpy as np import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
def import_fluid():
import paddle.fluid as fluid
return fluid
def load_data(imgfile, shape): def load_data(imgfile, shape):
...@@ -52,8 +55,10 @@ def build_model(net_file, net_name): ...@@ -52,8 +55,10 @@ def build_model(net_file, net_name):
print(e) print(e)
return None return None
input_name = 'data' fluid = import_fluid()
input_shape = MyNet.input_shapes()[input_name] inputs_dict = MyNet.input_shapes()
input_name = inputs_dict.keys()[0]
input_shape = inputs_dict[input_name]
images = fluid.layers.data(name='image', shape=input_shape, dtype='float32') images = fluid.layers.data(name='image', shape=input_shape, dtype='float32')
#label = fluid.layers.data(name='label', shape=[1], dtype='int64') #label = fluid.layers.data(name='label', shape=[1], dtype='int64')
...@@ -64,7 +69,7 @@ def build_model(net_file, net_name): ...@@ -64,7 +69,7 @@ def build_model(net_file, net_name):
def dump_results(results, names, root): def dump_results(results, names, root):
if os.path.exists(root) is False: if os.path.exists(root) is False:
os.path.mkdir(root) os.mkdir(root)
for i in range(len(names)): for i in range(len(names)):
n = names[i] n = names[i]
...@@ -73,9 +78,12 @@ def dump_results(results, names, root): ...@@ -73,9 +78,12 @@ def dump_results(results, names, root):
np.save(filename + '.npy', res) np.save(filename + '.npy', res)
def infer(net_file, net_name, model_file, imgfile, debug=False): def infer(net_file, net_name, model_file, imgfile, debug=True):
""" do inference using a model which consist 'xxx.py' and 'xxx.npy' """ do inference using a model which consist 'xxx.py' and 'xxx.npy'
""" """
fluid = import_fluid()
#1, build model #1, build model
net, input_shape = build_model(net_file, net_name) net, input_shape = build_model(net_file, net_name)
prediction = net.get_output() prediction = net.get_output()
...@@ -109,34 +117,79 @@ def infer(net_file, net_name, model_file, imgfile, debug=False): ...@@ -109,34 +117,79 @@ def infer(net_file, net_name, model_file, imgfile, debug=False):
fetch_list=fetch_list_var) fetch_list=fetch_list_var)
if debug is True: if debug is True:
dump_path = 'results.layers' dump_path = 'results.paddle'
dump_results(results, fetch_list_name, dump_path) dump_results(results, fetch_list_name, dump_path)
print('all results dumped to [%s]' % (dump_path)) print('all result of layers dumped to [%s]' % (dump_path))
else: else:
result = results[0] result = results[0]
print('predicted class:', np.argmax(result)) print('predicted class:', np.argmax(result))
return 0
def caffe_infer(prototxt, caffemodel, datafile):
""" do inference using pycaffe for debug,
all intermediate results will be dumpped to 'results.caffe'
"""
import caffe
net = caffe.Net(prototxt, caffemodel, caffe.TEST)
input_layer = net.blobs.keys()[0]
print('got name of input layer is:%s' % (input_layer))
input_shape = list(net.blobs[input_layer].data.shape[1:])
if '.npy' in datafile:
np_images = np.load(datafile)
else:
np_images = load_data(datafile, input_shape)
inputs = {input_layer: np_images}
net.forward_all(**inputs)
results = []
names = []
for k, v in net.blobs.items():
k = k.rstrip('_output')
k = k.replace('/', '_')
names.append(k)
results.append(v.data.copy())
dump_path = 'results.caffe'
dump_results(results, names, dump_path)
print('all result of layers dumped to [%s]' % (dump_path))
return 0
if __name__ == "__main__": if __name__ == "__main__":
""" maybe more convenient to use 'run.sh' to call this tool """ maybe more convenient to use 'run.sh' to call this tool
""" """
net_file = 'models/resnet50/resnet50.py' net_file = 'models/resnet50/resnet50.py'
weight_file = 'models/resnet50/resnet50.npy' weight_file = 'models/resnet50/resnet50.npy'
imgfile = 'data/65.jpeg' datafile = 'data/65.jpeg'
net_name = 'ResNet50' net_name = 'ResNet50'
argc = len(sys.argv) argc = len(sys.argv)
if argc == 5: if sys.argv[1] == 'caffe':
if len(sys.argv) != 5:
print('usage:')
print('\tpython %s caffe [prototxt] [caffemodel] [datafile]' %
(sys.argv[0]))
sys.exit(1)
prototxt = sys.argv[2]
caffemodel = sys.argv[3]
datafile = sys.argv[4]
sys.exit(caffe_infer(prototxt, caffemodel, datafile))
elif argc == 5:
net_file = sys.argv[1] net_file = sys.argv[1]
weight_file = sys.argv[2] weight_file = sys.argv[2]
imgfile = sys.argv[3] datafile = sys.argv[3]
net_name = sys.argv[4] net_name = sys.argv[4]
elif argc > 1: elif argc > 1:
print('usage:') print('usage:')
print('\tpython %s [net_file] [weight_file] [imgfile] [net_name]' % print('\tpython %s [net_file] [weight_file] [datafile] [net_name]' %
(sys.argv[0])) (sys.argv[0]))
print('\teg:python %s %s %s %s %s' % (sys.argv[0], net_file, print('\teg:python %s %s %s %s %s' % (sys.argv[0], net_file,
weight_file, imgfile, net_name)) weight_file, datafile, net_name))
sys.exit(1) sys.exit(1)
infer(net_file, net_name, weight_file, imgfile) infer(net_file, net_name, weight_file, datafile)
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#function: #function:
# a tool used to: # a tool used to:
# 1, convert a caffe model # 1, convert a caffe model
# 2, do inference using this model # 2, do inference(only in fluid) using this model
# #
#usage: #usage:
# bash run.sh resnet50 ./models.caffe/resnet50 ./models/resnet50 # bash run.sh resnet50 ./models.caffe/resnet50 ./models/resnet50
...@@ -65,7 +65,12 @@ if [[ -z $only_convert ]];then ...@@ -65,7 +65,12 @@ if [[ -z $only_convert ]];then
PYTHON=`which python` PYTHON=`which python`
fi fi
imgfile="data/65.jpeg" imgfile="data/65.jpeg"
net_name=`grep "name" $proto_file | head -n1 | perl -ne 'if(/\"([^\"]+)\"/){ print $1."\n";}'` #FIX ME:
# only look the first line in prototxt file for the name of this network, maybe not correct
net_name=`grep "name" $proto_file | head -n1 | perl -ne 'if(/^\s*name\s*:\s*\"([^\"]+)\"/){ print $1."\n";}'`
if [[ -z $net_name ]];then
net_name="MyNet"
fi
$PYTHON ./infer.py $net_file $weight_file $imgfile $net_name $PYTHON ./infer.py $net_file $weight_file $imgfile $net_name
ret=$? ret=$?
fi fi
......
...@@ -52,7 +52,10 @@ class Graph(object): ...@@ -52,7 +52,10 @@ class Graph(object):
def __init__(self, nodes=None, name=None): def __init__(self, nodes=None, name=None):
self.nodes = nodes or [] self.nodes = nodes or []
self.node_lut = {node.name: node for node in self.nodes} self.node_lut = {node.name: node for node in self.nodes}
self.name = name if name is None or name == '':
self.name = 'MyNet'
else:
self.name = name
def add_node(self, node): def add_node(self, node):
self.nodes.append(node) self.nodes.append(node)
......
...@@ -4,7 +4,7 @@ import numpy as np ...@@ -4,7 +4,7 @@ import numpy as np
def import_fluid(): def import_fluid():
import paddle.v2.fluid as fluid import paddle.fluid as fluid
return fluid return fluid
...@@ -64,7 +64,7 @@ class Network(object): ...@@ -64,7 +64,7 @@ class Network(object):
if os.path.isdir(data_path): if os.path.isdir(data_path):
assert (exe is not None), \ assert (exe is not None), \
'must provide a executor to load fluid model' 'must provide a executor to load fluid model'
fluid.io.load_persistables_if_exist(executor=exe, dirname=data_path) fluid.io.load_persistables(executor=exe, dirname=data_path)
return True return True
#load model from a npy file #load model from a npy file
...@@ -161,56 +161,28 @@ class Network(object): ...@@ -161,56 +161,28 @@ class Network(object):
output = fluid.layers.relu(x=input) output = fluid.layers.relu(x=input)
return output return output
def _adjust_pad_if_needed(self, i_hw, k_hw, s_hw, p_hw):
#adjust the padding if needed
i_h, i_w = i_hw
k_h, k_w = k_hw
s_h, s_w = s_hw
p_h, p_w = p_hw
def is_consistent(i, k, s, p):
o = i + 2 * p - k
if o % s == 0:
return True
else:
return False
real_p_h = 0
real_p_w = 0
if is_consistent(i_h, k_h, s_h, p_h) is False:
real_p_h = int(k_h / 2)
if is_consistent(i_w, k_w, s_w, p_w) is False:
real_p_w = int(k_w / 2)
return [real_p_h, real_p_w]
def pool(self, pool_type, input, k_h, k_w, s_h, s_w, name, padding): def pool(self, pool_type, input, k_h, k_w, s_h, s_w, name, padding):
# Get the number of channels in the input # Get the number of channels in the input
in_hw = input.shape[2:] in_hw = input.shape[2:]
k_hw = [k_h, k_w] k_hw = [k_h, k_w]
s_hw = [s_h, s_w] s_hw = [s_h, s_w]
if padding is None:
#fix bug about the difference between conv and pool
#more info: https://github.com/BVLC/caffe/issues/1318
padding = self._adjust_pad_if_needed(in_hw, k_hw, s_hw, [0, 0])
fluid = import_fluid() fluid = import_fluid()
output = fluid.layers.pool2d( output = fluid.layers.pool2d(
input=input, input=input,
pool_size=k_hw, pool_size=k_hw,
pool_stride=s_hw, pool_stride=s_hw,
pool_padding=padding, pool_padding=padding,
ceil_mode=True,
pool_type=pool_type) pool_type=pool_type)
return output return output
@layer @layer
def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=None): def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=[0, 0]):
return self.pool('max', input, k_h, k_w, s_h, s_w, name, padding) return self.pool('max', input, k_h, k_w, s_h, s_w, name, padding)
@layer @layer
def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=None): def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=[0, 0]):
return self.pool('avg', input, k_h, k_w, s_h, s_w, name, padding) return self.pool('avg', input, k_h, k_w, s_h, s_w, name, padding)
@layer @layer
...@@ -258,7 +230,12 @@ class Network(object): ...@@ -258,7 +230,12 @@ class Network(object):
return output return output
@layer @layer
def batch_normalization(self, input, name, scale_offset=True, relu=False): def batch_normalization(self,
input,
name,
scale_offset=True,
eps=1e-5,
relu=False):
# NOTE: Currently, only inference is supported # NOTE: Currently, only inference is supported
fluid = import_fluid() fluid = import_fluid()
prefix = name + '_' prefix = name + '_'
...@@ -276,7 +253,7 @@ class Network(object): ...@@ -276,7 +253,7 @@ class Network(object):
bias_attr=bias_attr, bias_attr=bias_attr,
moving_mean_name=mean_name, moving_mean_name=mean_name,
moving_variance_name=variance_name, moving_variance_name=variance_name,
epsilon=1e-5, epsilon=eps,
act='relu' if relu is True else None) act='relu' if relu is True else None)
return output return output
......
...@@ -142,7 +142,13 @@ class TensorFlowMapper(NodeMapper): ...@@ -142,7 +142,13 @@ class TensorFlowMapper(NodeMapper):
def map_batch_norm(self, node): def map_batch_norm(self, node):
scale_offset = len(node.data) == 4 scale_offset = len(node.data) == 4
kwargs = {} if scale_offset else {'scale_offset': False}
#this default value comes from caffe's param in batch_norm
default_eps = 1e-5
kwargs = {'scale_offset': scale_offset}
if node.parameters.eps != default_eps:
kwargs['eps'] = node.parameters.eps
return MaybeActivated( return MaybeActivated(
node, default=False)('batch_normalization', **kwargs) node, default=False)('batch_normalization', **kwargs)
...@@ -236,7 +242,7 @@ class TensorFlowEmitter(object): ...@@ -236,7 +242,7 @@ class TensorFlowEmitter(object):
func_def = self.statement('@classmethod') func_def = self.statement('@classmethod')
func_def += self.statement('def convert(cls, npy_model, fluid_path):') func_def += self.statement('def convert(cls, npy_model, fluid_path):')
self.indent() self.indent()
func_def += self.statement('import paddle.v2.fluid as fluid') func_def += self.statement('fluid = import_fluid()')
for l in codes: for l in codes:
func_def += self.statement(l) func_def += self.statement(l)
return '\n' + func_def return '\n' + func_def
......
import os
import numpy as np
import time
import sys
import paddle.v2 as paddle import paddle.v2 as paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import reader
def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1, def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1,
...@@ -124,164 +119,3 @@ def SE_ResNeXt(input, class_dim, infer=False, layers=50): ...@@ -124,164 +119,3 @@ def SE_ResNeXt(input, class_dim, infer=False, layers=50):
drop = pool drop = pool
out = fluid.layers.fc(input=drop, size=class_dim, act='softmax') out = fluid.layers.fc(input=drop, size=class_dim, act='softmax')
return out return out
def train(learning_rate,
batch_size,
num_passes,
init_model=None,
model_save_dir='model',
parallel=True,
use_nccl=True,
lr_strategy=None,
layers=50):
class_dim = 1000
image_shape = [3, 224, 224]
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
if parallel:
places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places, use_nccl=use_nccl)
with pd.do():
image_ = pd.read_input(image)
label_ = pd.read_input(label)
out = SE_ResNeXt(input=image_, class_dim=class_dim, layers=layers)
cost = fluid.layers.cross_entropy(input=out, label=label_)
avg_cost = fluid.layers.mean(x=cost)
acc_top1 = fluid.layers.accuracy(input=out, label=label_, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label_, k=5)
pd.write_output(avg_cost)
pd.write_output(acc_top1)
pd.write_output(acc_top5)
avg_cost, acc_top1, acc_top5 = pd()
avg_cost = fluid.layers.mean(x=avg_cost)
acc_top1 = fluid.layers.mean(x=acc_top1)
acc_top5 = fluid.layers.mean(x=acc_top5)
else:
out = SE_ResNeXt(input=image, class_dim=class_dim, layers=layers)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
if lr_strategy is None:
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
else:
bd = lr_strategy["bd"]
lr = lr_strategy["lr"]
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
opts = optimizer.minimize(avg_cost)
fluid.memory_optimize(fluid.default_main_program())
inference_program = fluid.default_main_program().clone()
with fluid.program_guard(inference_program):
inference_program = fluid.io.get_inference_program(
[avg_cost, acc_top1, acc_top5])
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if init_model is not None:
fluid.io.load_persistables(exe, init_model)
train_reader = paddle.batch(reader.train(), batch_size=batch_size)
test_reader = paddle.batch(reader.test(), batch_size=batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
for pass_id in range(num_passes):
train_info = [[], [], []]
test_info = [[], [], []]
for batch_id, data in enumerate(train_reader()):
t1 = time.time()
loss, acc1, acc5 = exe.run(
fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost, acc_top1, acc_top5])
t2 = time.time()
period = t2 - t1
train_info[0].append(loss[0])
train_info[1].append(acc1[0])
train_info[2].append(acc5[0])
if batch_id % 10 == 0:
print("Pass {0}, trainbatch {1}, loss {2}, \
acc1 {3}, acc5 {4} time {5}"
.format(pass_id, \
batch_id, loss[0], acc1[0], acc5[0], \
"%2.2f sec" % period))
sys.stdout.flush()
train_loss = np.array(train_info[0]).mean()
train_acc1 = np.array(train_info[1]).mean()
train_acc5 = np.array(train_info[2]).mean()
for data in test_reader():
t1 = time.time()
loss, acc1, acc5 = exe.run(
inference_program,
feed=feeder.feed(data),
fetch_list=[avg_cost, acc_top1, acc_top5])
t2 = time.time()
period = t2 - t1
test_info[0].append(loss[0])
test_info[1].append(acc1[0])
test_info[2].append(acc5[0])
if batch_id % 10 == 0:
print("Pass {0},testbatch {1},loss {2}, \
acc1 {3},acc5 {4},time {5}"
.format(pass_id, \
batch_id, loss[0], acc1[0], acc5[0], \
"%2.2f sec" % period))
sys.stdout.flush()
test_loss = np.array(test_info[0]).mean()
test_acc1 = np.array(test_info[1]).mean()
test_acc5 = np.array(test_info[2]).mean()
print("End pass {0}, train_loss {1}, train_acc1 {2}, train_acc5 {3}, \
test_loss {4}, test_acc1 {5}, test_acc5 {6}"
.format(pass_id, \
train_loss, train_acc1, train_acc5, test_loss, test_acc1, \
test_acc5))
sys.stdout.flush()
model_path = os.path.join(model_save_dir, str(pass_id))
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path)
if __name__ == '__main__':
epoch_points = [30, 60, 90]
total_images = 1281167
batch_size = 256
step = int(total_images / batch_size + 1)
bd = [e * step for e in epoch_points]
lr = [0.1, 0.01, 0.001, 0.0001]
lr_strategy = {"bd": bd, "lr": lr}
use_nccl = True
# layers: 50, 152
layers = 50
train(
learning_rate=0.1,
batch_size=batch_size,
num_passes=120,
init_model=None,
parallel=True,
use_nccl=True,
lr_strategy=lr_strategy,
layers=layers)
import os
import numpy as np
import time
import sys
import paddle.v2 as paddle
import paddle.fluid as fluid
from se_resnext import SE_ResNeXt
import reader
import argparse
import functools
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 256, "Minibatch size.")
add_arg('num_layers', int, 50, "How many layers for SE-ResNeXt model.")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('parallel_exe', bool, True, "Whether to use ParallelExecutor to train or not.")
def train_paralle_do(args,
learning_rate,
batch_size,
num_passes,
init_model=None,
model_save_dir='model',
parallel=True,
use_nccl=True,
lr_strategy=None,
layers=50):
class_dim = 1000
image_shape = [3, 224, 224]
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
if parallel:
places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places, use_nccl=use_nccl)
with pd.do():
image_ = pd.read_input(image)
label_ = pd.read_input(label)
out = SE_ResNeXt(input=image_, class_dim=class_dim, layers=layers)
cost = fluid.layers.cross_entropy(input=out, label=label_)
avg_cost = fluid.layers.mean(x=cost)
acc_top1 = fluid.layers.accuracy(input=out, label=label_, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label_, k=5)
pd.write_output(avg_cost)
pd.write_output(acc_top1)
pd.write_output(acc_top5)
avg_cost, acc_top1, acc_top5 = pd()
avg_cost = fluid.layers.mean(x=avg_cost)
acc_top1 = fluid.layers.mean(x=acc_top1)
acc_top5 = fluid.layers.mean(x=acc_top5)
else:
out = SE_ResNeXt(input=image, class_dim=class_dim, layers=layers)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
if lr_strategy is None:
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
else:
bd = lr_strategy["bd"]
lr = lr_strategy["lr"]
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
inference_program = fluid.default_main_program().clone(for_test=True)
opts = optimizer.minimize(avg_cost)
if args.with_mem_opt:
fluid.memory_optimize(fluid.default_main_program())
fluid.memory_optimize(inference_program)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if init_model is not None:
fluid.io.load_persistables(exe, init_model)
train_reader = paddle.batch(reader.train(), batch_size=batch_size)
test_reader = paddle.batch(reader.test(), batch_size=batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
for pass_id in range(num_passes):
train_info = [[], [], []]
test_info = [[], [], []]
for batch_id, data in enumerate(train_reader()):
t1 = time.time()
loss, acc1, acc5 = exe.run(
fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost, acc_top1, acc_top5])
t2 = time.time()
period = t2 - t1
train_info[0].append(loss[0])
train_info[1].append(acc1[0])
train_info[2].append(acc5[0])
if batch_id % 10 == 0:
print("Pass {0}, trainbatch {1}, loss {2}, \
acc1 {3}, acc5 {4} time {5}"
.format(pass_id, \
batch_id, loss[0], acc1[0], acc5[0], \
"%2.2f sec" % period))
sys.stdout.flush()
train_loss = np.array(train_info[0]).mean()
train_acc1 = np.array(train_info[1]).mean()
train_acc5 = np.array(train_info[2]).mean()
for data in test_reader():
t1 = time.time()
loss, acc1, acc5 = exe.run(
inference_program,
feed=feeder.feed(data),
fetch_list=[avg_cost, acc_top1, acc_top5])
t2 = time.time()
period = t2 - t1
test_info[0].append(loss[0])
test_info[1].append(acc1[0])
test_info[2].append(acc5[0])
if batch_id % 10 == 0:
print("Pass {0},testbatch {1},loss {2}, \
acc1 {3},acc5 {4},time {5}"
.format(pass_id, \
batch_id, loss[0], acc1[0], acc5[0], \
"%2.2f sec" % period))
sys.stdout.flush()
test_loss = np.array(test_info[0]).mean()
test_acc1 = np.array(test_info[1]).mean()
test_acc5 = np.array(test_info[2]).mean()
print("End pass {0}, train_loss {1}, train_acc1 {2}, train_acc5 {3}, \
test_loss {4}, test_acc1 {5}, test_acc5 {6}"
.format(pass_id, \
train_loss, train_acc1, train_acc5, test_loss, test_acc1, \
test_acc5))
sys.stdout.flush()
model_path = os.path.join(model_save_dir, str(pass_id))
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path)
def train_parallel_exe(args,
learning_rate,
batch_size,
num_passes,
init_model=None,
model_save_dir='model',
parallel=True,
use_nccl=True,
lr_strategy=None,
layers=50):
class_dim = 1000
image_shape = [3, 224, 224]
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
out = SE_ResNeXt(input=image, class_dim=class_dim, layers=layers)
cost = fluid.layers.cross_entropy(input=out, label=label)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
avg_cost = fluid.layers.mean(x=cost)
test_program = fluid.default_main_program().clone(for_test=True)
if lr_strategy is None:
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
else:
bd = lr_strategy["bd"]
lr = lr_strategy["lr"]
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
opts = optimizer.minimize(avg_cost)
if args.with_mem_opt:
fluid.memory_optimize(fluid.default_main_program())
fluid.memory_optimize(test_program)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if init_model is not None:
fluid.io.load_persistables(exe, init_model)
train_reader = paddle.batch(reader.train(), batch_size=batch_size)
test_reader = paddle.batch(reader.test(), batch_size=batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
test_exe = fluid.ParallelExecutor(
use_cuda=True,
main_program=test_program,
share_vars_from=train_exe)
fetch_list = [avg_cost.name, acc_top1.name, acc_top5.name]
for pass_id in range(num_passes):
train_info = [[], [], []]
test_info = [[], [], []]
for batch_id, data in enumerate(train_reader()):
t1 = time.time()
loss, acc1, acc5 = train_exe.run(
fetch_list,
feed_dict=feeder.feed(data))
t2 = time.time()
period = t2 - t1
loss = np.mean(np.array(loss))
acc1 = np.mean(np.array(acc1))
acc5 = np.mean(np.array(acc5))
train_info[0].append(loss)
train_info[1].append(acc1)
train_info[2].append(acc5)
if batch_id % 10 == 0:
print("Pass {0}, trainbatch {1}, loss {2}, \
acc1 {3}, acc5 {4} time {5}"
.format(pass_id, \
batch_id, loss, acc1, acc5, \
"%2.2f sec" % period))
sys.stdout.flush()
train_loss = np.array(train_info[0]).mean()
train_acc1 = np.array(train_info[1]).mean()
train_acc5 = np.array(train_info[2]).mean()
for data in test_reader():
t1 = time.time()
loss, acc1, acc5 = test_exe.run(
fetch_list,
feed_dict=feeder.feed(data))
t2 = time.time()
period = t2 - t1
loss = np.mean(np.array(loss))
acc1 = np.mean(np.array(acc1))
acc5 = np.mean(np.array(acc5))
test_info[0].append(loss)
test_info[1].append(acc1)
test_info[2].append(acc5)
if batch_id % 10 == 0:
print("Pass {0},testbatch {1},loss {2}, \
acc1 {3},acc5 {4},time {5}"
.format(pass_id, \
batch_id, loss, acc1, acc5, \
"%2.2f sec" % period))
sys.stdout.flush()
test_loss = np.array(test_info[0]).mean()
test_acc1 = np.array(test_info[1]).mean()
test_acc5 = np.array(test_info[2]).mean()
print("End pass {0}, train_loss {1}, train_acc1 {2}, train_acc5 {3}, \
test_loss {4}, test_acc1 {5}, test_acc5 {6}"
.format(pass_id, \
train_loss, train_acc1, train_acc5, test_loss, test_acc1, \
test_acc5))
sys.stdout.flush()
model_path = os.path.join(model_save_dir, str(pass_id))
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path)
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
epoch_points = [30, 60, 90]
total_images = 1281167
batch_size = args.batch_size
step = int(total_images / batch_size + 1)
bd = [e * step for e in epoch_points]
lr = [0.1, 0.01, 0.001, 0.0001]
lr_strategy = {"bd": bd, "lr": lr}
use_nccl = True
# layers: 50, 152
layers = args.num_layers
method = train_parallel_exe if args.parallel_exe else train_parallel_do
method(args,
learning_rate=0.1,
batch_size=batch_size,
num_passes=120,
init_model=None,
parallel=True,
use_nccl=True,
lr_strategy=lr_strategy,
layers=layers)
"""Contains common utility functions."""
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import distutils.util
import numpy as np
from paddle.fluid import core
def print_arguments(args):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
print("----------- Configuration Arguments -----------")
for arg, value in sorted(vars(args).iteritems()):
print("%s: %s" % (arg, value))
print("------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
"""Add argparse's argument.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
add_argument("name", str, "Jonh", "User name.", parser)
args = parser.parse_args()
"""
type = distutils.util.strtobool if type == bool else type
argparser.add_argument(
"--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
...@@ -30,15 +30,12 @@ class PolicyGradient: ...@@ -30,15 +30,12 @@ class PolicyGradient:
acts = fluid.layers.data(name='acts', shape=[1], dtype='int64') acts = fluid.layers.data(name='acts', shape=[1], dtype='int64')
vt = fluid.layers.data(name='vt', shape=[1], dtype='float32') vt = fluid.layers.data(name='vt', shape=[1], dtype='float32')
# fc1 # fc1
fc1 = fluid.layers.fc( fc1 = fluid.layers.fc(input=obs, size=10, act="tanh") # tanh activation
input=obs,
size=10,
act="tanh" # tanh activation
)
# fc2 # fc2
self.all_act_prob = fluid.layers.fc(input=fc1, all_act_prob = fluid.layers.fc(input=fc1,
size=self.n_actions, size=self.n_actions,
act="softmax") act="softmax")
self.inferece_program = fluid.defaul_main_program().clone()
# to maximize total reward (log_p * R) is to minimize -(log_p * R) # to maximize total reward (log_p * R) is to minimize -(log_p * R)
neg_log_prob = fluid.layers.cross_entropy( neg_log_prob = fluid.layers.cross_entropy(
input=self.all_act_prob, input=self.all_act_prob,
...@@ -52,10 +49,9 @@ class PolicyGradient: ...@@ -52,10 +49,9 @@ class PolicyGradient:
self.exe.run(fluid.default_startup_program()) self.exe.run(fluid.default_startup_program())
def choose_action(self, observation): def choose_action(self, observation):
prob_weights = self.exe.run( prob_weights = self.exe.run(self.inferece_program,
fluid.default_main_program().prune(self.all_act_prob), feed={"obs": observation[np.newaxis, :]},
feed={"obs": observation[np.newaxis, :]}, fetch_list=[self.all_act_prob])
fetch_list=[self.all_act_prob])
prob_weights = np.array(prob_weights[0]) prob_weights = np.array(prob_weights[0])
action = np.random.choice( action = np.random.choice(
range(prob_weights.shape[1]), range(prob_weights.shape[1]),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册