提交 1847c180 编写于 作者: D Dang Qingqing

Merge branch 'develop' of https://github.com/PaddlePaddle/models into ssd_pl_exe

.DS_Store .DS_Store
*.pyc *.pyc
.*~
...@@ -18,19 +18,19 @@ This tool is used to convert a Caffe model to Fluid model ...@@ -18,19 +18,19 @@ This tool is used to convert a Caffe model to Fluid model
### Tested models ### Tested models
- Lenet on mnist dataset - Lenet
- ResNets:(ResNet-50, ResNet-101, ResNet-152) - ResNets:(ResNet-50, ResNet-101, ResNet-152)
model addr: `https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777`_ [model addr](https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777)
- GoogleNet: - GoogleNet:
model addr: `https://gist.github.com/jimmie33/7ea9f8ac0da259866b854460f4526034`_ [model addr](https://gist.github.com/jimmie33/7ea9f8ac0da259866b854460f4526034)
- VGG: - VGG:
model addr: `https://gist.github.com/ksimonyan/211839e770f7b538e2d8`_ [model addr](https://gist.github.com/ksimonyan/211839e770f7b538e2d8)
- AlexNet: - AlexNet:
model addr: `https://github.com/BVLC/caffe/tree/master/models/bvlc_alexnet`_ [model addr](https://github.com/BVLC/caffe/tree/master/models/bvlc_alexnet)
### Notes ### Notes
Some of this code comes from here: https://github.com/ethereon/caffe-tensorflow Some of this code comes from here: https://github.com/ethereon/caffe-tensorflow
#!/usr/bin/python
#
#a tool to compare tensors in two files or two directories
#
import sys
import os
def walk_dir(rootdir):
    """Yield every file found under ``rootdir`` as a path relative to it.

    Files that live directly in ``rootdir`` are yielded as their bare name;
    files in sub-directories keep their relative directory part (e.g.
    ``sub/x.npy``), so joining the yielded value onto a root directory
    always gives back a real file location.

    Args:
        rootdir: directory to traverse recursively.

    Yields:
        str: path of each file, relative to ``rootdir``.
    """
    for subdir, _dirs, files in os.walk(rootdir):
        for name in files:
            # Yielding only ``name`` would drop the sub-directory part and
            # make nested files indistinguishable from top-level ones.
            yield os.path.relpath(os.path.join(subdir, name), rootdir)
def calc_diff(f1, f2):
    """Compute the element-wise difference between two saved tensors.

    Both files are loaded with ``np.load`` and flattened, so only the
    number of elements has to match, not the original shapes.

    Args:
        f1: path of the first ``.npy`` file.
        f2: path of the second ``.npy`` file.

    Returns:
        tuple: ``(max_df, sq_df)`` where ``max_df`` is the largest absolute
        difference and ``sq_df`` is the mean of the squared differences.
        ``(-1.0, -1.0)`` is returned when the difference could not be
        computed; callers must treat negative values as a failure.

    Raises:
        AssertionError: if the two tensors hold a different number of
            elements.
    """
    import numpy as np

    d1 = np.load(f1).flatten()
    d2 = np.load(f2).flatten()

    # After flatten() both arrays are 1-D, so ``size`` is the element count
    # (this replaces a reduce() over the shape, which is not a builtin on
    # Python 3).
    if d1.size != d2.size:
        print(d1.shape)
        print(d2.shape)
    assert (d1.size == d2.size), "their shape is not consistent"

    try:
        df = np.abs(d1 - d2)
        max_df = np.max(df)
        sq_df = np.mean(df * df)
        return max_df, sq_df
    except Exception as e:
        # Keep the historical sentinel return value, but say why it happened
        # instead of hiding the error completely.
        print('failed to calculate the difference of %s and %s: %s' %
              (f1, f2, e))
        return -1.0, -1.0
def compare(path1, path2):
    """Compare the tensors stored in two ``.npy`` files or two directories.

    When both paths end with ``.npy`` they are compared directly.
    Otherwise every ``.npy`` file found under ``path2`` is compared with the
    file of the same relative name under ``path1``.

    Args:
        path1: first ``.npy`` file or directory of ``.npy`` files.
        path2: second ``.npy`` file or directory of ``.npy`` files.

    Returns:
        int: ``1`` if one of the paths does not exist, ``0`` when every
        comparison passed.

    Raises:
        AssertionError: if a difference could not be computed, or if the
            max absolute difference is not below ``1e-5``, or the mean
            squared difference is not below ``1e-10``.
    """

    def check(passed, message):
        # Raised explicitly (rather than via ``assert``) so the checks are
        # still enforced when Python runs with -O.
        if not passed:
            raise AssertionError(message)

    def diff(f1, f2):
        max_df, sq_df = calc_diff(f1, f2)
        print('compare %s <=> %s with result[max_df:%.4e, sq_df:%.4e]' %
              (f1, f2, max_df, sq_df))
        # calc_diff signals "could not compute" with negative values; those
        # would otherwise slip through the upper-bound checks below.
        check(not (max_df < 0 or sq_df < 0),
              'failed to calculate the difference of %s and %s' % (f1, f2))
        check(max_df < 1e-5,
              'max_df is too large with value[%.6e]' % (max_df))
        check(sq_df < 1e-10,
              'sq_df is too large with value[%.6e]' % (sq_df))

    for path in (path1, path2):
        if not os.path.exists(path):
            print('not found %s' % (path))
            return 1

    # Two single files: compare them directly.
    if path1.endswith('.npy') and path2.endswith('.npy'):
        diff(path1, path2)
        return 0

    # Two directories: path2 drives the list of files to check.
    for f in walk_dir(path2):
        if not f.endswith('.npy'):
            continue
        diff(os.path.join(path1, f), os.path.join(path2, f))

    print('all checking succeed to pass')
    return 0
if __name__ == "__main__":
    # Command line entry point.
    #   no argument   -> compare the default lenet result directories
    #   two arguments -> compare the two given files/directories
    #   anything else -> print the usage and fail
    cli_args = sys.argv[1:]
    if not cli_args:
        path1 = 'lenet.tf/results'
        path2 = 'lenet.paddle/results'
    elif len(cli_args) == 2:
        path1, path2 = cli_args
    else:
        print('usage:')
        print(' %s [path1] [path2]' % (sys.argv[0]))
        # sys.exit is used instead of the bare ``exit`` helper, which is
        # only injected by the ``site`` module and may be unavailable.
        sys.exit(1)

    print('compare inner result in %s %s' % (path1, path2))
    sys.exit(compare(path1, path2))
#!/bin/bash
#
#function:
# a tool used to check the difference of models' results generated by caffe model and paddle model
#
#howto:
# bash diff.sh resnet50 #when this has been finished, you can get the difference in precision
#
#notes:
# 0, in order to infer using caffe, we need pycaffe installed
# 1, prepare your caffe model in 'models.caffe/', eg: 'models.caffe/resnet101/resnet101.[prototxt|caffemodel]'
# 2, converted paddle model will be in 'models'
# 3, results of layers will be stored in 'results/${model_name}.[paddle|caffe]'
# 4, only the last layer will be checked by default
# Name of the model to check; overridden by the first positional argument.
model_name="resnet50"
results_root="results/"

if [[ -n "$1" ]]; then
    if [[ "$1" = "-h" ]]; then
        echo "usage:"
        echo " bash $0 [model_name]"
        echo " eg:bash $0 resnet50"
        exit 0
    fi
    model_name="$1"
fi

mkdir -p "$results_root"

model_prototxt="models.caffe/$model_name/${model_name}.prototxt"
model_caffemodel="models.caffe/${model_name}/${model_name}.caffemodel"

#1, dump layers' results from paddle
# run.sh is expected to leave its output in 'results.paddle'; stale copies
# are removed first so that an old run can not be mistaken for a new one.
paddle_results="$results_root/${model_name}.paddle"
rm -rf "$paddle_results"
rm -rf "results.paddle"
bash run.sh "$model_name" "./models.caffe/$model_name" "./models/$model_name"
if [[ $? -ne 0 ]] || [[ ! -e "results.paddle" ]]; then
    echo "not found paddle's results, maybe failed to convert"
    exit 1
fi
mv results.paddle "$paddle_results"

#2, dump layers' results from caffe
# the input data dumped by the paddle run (data.npy) is fed to caffe
caffe_results="$results_root/${model_name}.caffe"
rm -rf "$caffe_results"
rm -rf "results.caffe"
cfpython ./infer.py caffe "$model_prototxt" "$model_caffemodel" "$paddle_results/data.npy"
if [[ $? -ne 0 ]] || [[ ! -e "results.caffe" ]]; then
    echo "not found caffe's results, maybe failed to do inference with caffe"
    exit 1
fi
mv results.caffe "$caffe_results"

#3, extract layer names
grep name "$model_prototxt" | perl -ne 'if(/^\s*name:\s+\"([^\"]+)/){ print $1."\n";}' >.layer_names

#4, compare one by one
# only the last extracted layer name is checked (see note 4 in the header)
for i in $(tail -n1 ".layer_names"); do
    echo "process $i"
    python compare.py "$caffe_results/${i}.npy" "$paddle_results/${i}.npy"
done
...@@ -10,8 +10,11 @@ import os ...@@ -10,8 +10,11 @@ import os
import sys import sys
import inspect import inspect
import numpy as np import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
def import_fluid():
    """Import ``paddle.fluid`` on demand and return the module.

    The import is done inside the function so that paddle is only required
    by the code paths that actually call this helper.
    """
    import paddle.fluid as fluid

    return fluid
def load_data(imgfile, shape): def load_data(imgfile, shape):
...@@ -52,8 +55,10 @@ def build_model(net_file, net_name): ...@@ -52,8 +55,10 @@ def build_model(net_file, net_name):
print(e) print(e)
return None return None
input_name = 'data' fluid = import_fluid()
input_shape = MyNet.input_shapes()[input_name] inputs_dict = MyNet.input_shapes()
input_name = inputs_dict.keys()[0]
input_shape = inputs_dict[input_name]
images = fluid.layers.data(name='image', shape=input_shape, dtype='float32') images = fluid.layers.data(name='image', shape=input_shape, dtype='float32')
#label = fluid.layers.data(name='label', shape=[1], dtype='int64') #label = fluid.layers.data(name='label', shape=[1], dtype='int64')
...@@ -64,7 +69,7 @@ def build_model(net_file, net_name): ...@@ -64,7 +69,7 @@ def build_model(net_file, net_name):
def dump_results(results, names, root): def dump_results(results, names, root):
if os.path.exists(root) is False: if os.path.exists(root) is False:
os.path.mkdir(root) os.mkdir(root)
for i in range(len(names)): for i in range(len(names)):
n = names[i] n = names[i]
...@@ -73,9 +78,12 @@ def dump_results(results, names, root): ...@@ -73,9 +78,12 @@ def dump_results(results, names, root):
np.save(filename + '.npy', res) np.save(filename + '.npy', res)
def infer(net_file, net_name, model_file, imgfile, debug=False): def infer(net_file, net_name, model_file, imgfile, debug=True):
""" do inference using a model which consist 'xxx.py' and 'xxx.npy' """ do inference using a model which consist 'xxx.py' and 'xxx.npy'
""" """
fluid = import_fluid()
#1, build model #1, build model
net, input_shape = build_model(net_file, net_name) net, input_shape = build_model(net_file, net_name)
prediction = net.get_output() prediction = net.get_output()
...@@ -109,34 +117,79 @@ def infer(net_file, net_name, model_file, imgfile, debug=False): ...@@ -109,34 +117,79 @@ def infer(net_file, net_name, model_file, imgfile, debug=False):
fetch_list=fetch_list_var) fetch_list=fetch_list_var)
if debug is True: if debug is True:
dump_path = 'results.layers' dump_path = 'results.paddle'
dump_results(results, fetch_list_name, dump_path) dump_results(results, fetch_list_name, dump_path)
print('all results dumped to [%s]' % (dump_path)) print('all result of layers dumped to [%s]' % (dump_path))
else: else:
result = results[0] result = results[0]
print('predicted class:', np.argmax(result)) print('predicted class:', np.argmax(result))
return 0
def caffe_infer(prototxt, caffemodel, datafile):
    """Run inference with pycaffe (for debugging) and dump every blob.

    All intermediate results are dumped to the directory 'results.caffe'.

    Args:
        prototxt: path of the caffe network definition.
        caffemodel: path of the caffe weights.
        datafile: input data; loaded with ``np.load`` when its name contains
            '.npy', otherwise passed to ``load_data`` together with the
            input shape.

    Returns:
        int: always 0.
    """
    import caffe

    net = caffe.Net(prototxt, caffemodel, caffe.TEST)

    # The first blob is taken as the network input.
    # list() keeps this working where keys() is not subscriptable.
    input_layer = list(net.blobs.keys())[0]
    print('got name of input layer is:%s' % (input_layer))
    input_shape = list(net.blobs[input_layer].data.shape[1:])

    if '.npy' in datafile:
        np_images = np.load(datafile)
    else:
        np_images = load_data(datafile, input_shape)

    inputs = {input_layer: np_images}
    net.forward_all(**inputs)

    suffix = '_output'
    results = []
    names = []
    for k, v in net.blobs.items():
        # Drop a trailing '_output' only. str.rstrip('_output') would strip
        # any trailing run of the characters '_', 'o', 'u', 't', 'p' and
        # mangle names such as 'fc_top'.
        if k.endswith(suffix):
            k = k[:-len(suffix)]
        # '/' can not be part of a file name
        k = k.replace('/', '_')
        names.append(k)
        results.append(v.data.copy())

    dump_path = 'results.caffe'
    dump_results(results, names, dump_path)
    print('all result of layers dumped to [%s]' % (dump_path))
    return 0
if __name__ == "__main__": if __name__ == "__main__":
""" maybe more convenient to use 'run.sh' to call this tool """ maybe more convenient to use 'run.sh' to call this tool
""" """
net_file = 'models/resnet50/resnet50.py' net_file = 'models/resnet50/resnet50.py'
weight_file = 'models/resnet50/resnet50.npy' weight_file = 'models/resnet50/resnet50.npy'
imgfile = 'data/65.jpeg' datafile = 'data/65.jpeg'
net_name = 'ResNet50' net_name = 'ResNet50'
argc = len(sys.argv) argc = len(sys.argv)
if argc == 5: if sys.argv[1] == 'caffe':
if len(sys.argv) != 5:
print('usage:')
print('\tpython %s caffe [prototxt] [caffemodel] [datafile]' %
(sys.argv[0]))
sys.exit(1)
prototxt = sys.argv[2]
caffemodel = sys.argv[3]
datafile = sys.argv[4]
sys.exit(caffe_infer(prototxt, caffemodel, datafile))
elif argc == 5:
net_file = sys.argv[1] net_file = sys.argv[1]
weight_file = sys.argv[2] weight_file = sys.argv[2]
imgfile = sys.argv[3] datafile = sys.argv[3]
net_name = sys.argv[4] net_name = sys.argv[4]
elif argc > 1: elif argc > 1:
print('usage:') print('usage:')
print('\tpython %s [net_file] [weight_file] [imgfile] [net_name]' % print('\tpython %s [net_file] [weight_file] [datafile] [net_name]' %
(sys.argv[0])) (sys.argv[0]))
print('\teg:python %s %s %s %s %s' % (sys.argv[0], net_file, print('\teg:python %s %s %s %s %s' % (sys.argv[0], net_file,
weight_file, imgfile, net_name)) weight_file, datafile, net_name))
sys.exit(1) sys.exit(1)
infer(net_file, net_name, weight_file, imgfile) infer(net_file, net_name, weight_file, datafile)
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#function: #function:
# a tool used to: # a tool used to:
# 1, convert a caffe model # 1, convert a caffe model
# 2, do inference using this model # 2, do inference(only in fluid) using this model
# #
#usage: #usage:
# bash run.sh resnet50 ./models.caffe/resnet50 ./models/resnet50 # bash run.sh resnet50 ./models.caffe/resnet50 ./models/resnet50
...@@ -65,7 +65,12 @@ if [[ -z $only_convert ]];then ...@@ -65,7 +65,12 @@ if [[ -z $only_convert ]];then
PYTHON=`which python` PYTHON=`which python`
fi fi
imgfile="data/65.jpeg" imgfile="data/65.jpeg"
net_name=`grep "name" $proto_file | head -n1 | perl -ne 'if(/\"([^\"]+)\"/){ print $1."\n";}'` #FIX ME:
# only look the first line in prototxt file for the name of this network, maybe not correct
net_name=`grep "name" $proto_file | head -n1 | perl -ne 'if(/^\s*name\s*:\s*\"([^\"]+)\"/){ print $1."\n";}'`
if [[ -z $net_name ]];then
net_name="MyNet"
fi
$PYTHON ./infer.py $net_file $weight_file $imgfile $net_name $PYTHON ./infer.py $net_file $weight_file $imgfile $net_name
ret=$? ret=$?
fi fi
......
...@@ -52,7 +52,10 @@ class Graph(object): ...@@ -52,7 +52,10 @@ class Graph(object):
def __init__(self, nodes=None, name=None): def __init__(self, nodes=None, name=None):
self.nodes = nodes or [] self.nodes = nodes or []
self.node_lut = {node.name: node for node in self.nodes} self.node_lut = {node.name: node for node in self.nodes}
self.name = name if name is None or name == '':
self.name = 'MyNet'
else:
self.name = name
def add_node(self, node): def add_node(self, node):
self.nodes.append(node) self.nodes.append(node)
......
...@@ -4,7 +4,7 @@ import numpy as np ...@@ -4,7 +4,7 @@ import numpy as np
def import_fluid(): def import_fluid():
import paddle.v2.fluid as fluid import paddle.fluid as fluid
return fluid return fluid
...@@ -64,7 +64,7 @@ class Network(object): ...@@ -64,7 +64,7 @@ class Network(object):
if os.path.isdir(data_path): if os.path.isdir(data_path):
assert (exe is not None), \ assert (exe is not None), \
'must provide a executor to load fluid model' 'must provide a executor to load fluid model'
fluid.io.load_persistables_if_exist(executor=exe, dirname=data_path) fluid.io.load_persistables(executor=exe, dirname=data_path)
return True return True
#load model from a npy file #load model from a npy file
...@@ -161,56 +161,28 @@ class Network(object): ...@@ -161,56 +161,28 @@ class Network(object):
output = fluid.layers.relu(x=input) output = fluid.layers.relu(x=input)
return output return output
def _adjust_pad_if_needed(self, i_hw, k_hw, s_hw, p_hw):
#adjust the padding if needed
i_h, i_w = i_hw
k_h, k_w = k_hw
s_h, s_w = s_hw
p_h, p_w = p_hw
def is_consistent(i, k, s, p):
o = i + 2 * p - k
if o % s == 0:
return True
else:
return False
real_p_h = 0
real_p_w = 0
if is_consistent(i_h, k_h, s_h, p_h) is False:
real_p_h = int(k_h / 2)
if is_consistent(i_w, k_w, s_w, p_w) is False:
real_p_w = int(k_w / 2)
return [real_p_h, real_p_w]
def pool(self, pool_type, input, k_h, k_w, s_h, s_w, name, padding): def pool(self, pool_type, input, k_h, k_w, s_h, s_w, name, padding):
# Get the number of channels in the input # Get the number of channels in the input
in_hw = input.shape[2:] in_hw = input.shape[2:]
k_hw = [k_h, k_w] k_hw = [k_h, k_w]
s_hw = [s_h, s_w] s_hw = [s_h, s_w]
if padding is None:
#fix bug about the difference between conv and pool
#more info: https://github.com/BVLC/caffe/issues/1318
padding = self._adjust_pad_if_needed(in_hw, k_hw, s_hw, [0, 0])
fluid = import_fluid() fluid = import_fluid()
output = fluid.layers.pool2d( output = fluid.layers.pool2d(
input=input, input=input,
pool_size=k_hw, pool_size=k_hw,
pool_stride=s_hw, pool_stride=s_hw,
pool_padding=padding, pool_padding=padding,
ceil_mode=True,
pool_type=pool_type) pool_type=pool_type)
return output return output
@layer @layer
def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=None): def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=[0, 0]):
return self.pool('max', input, k_h, k_w, s_h, s_w, name, padding) return self.pool('max', input, k_h, k_w, s_h, s_w, name, padding)
@layer @layer
def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=None): def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=[0, 0]):
return self.pool('avg', input, k_h, k_w, s_h, s_w, name, padding) return self.pool('avg', input, k_h, k_w, s_h, s_w, name, padding)
@layer @layer
...@@ -258,7 +230,12 @@ class Network(object): ...@@ -258,7 +230,12 @@ class Network(object):
return output return output
@layer @layer
def batch_normalization(self, input, name, scale_offset=True, relu=False): def batch_normalization(self,
input,
name,
scale_offset=True,
eps=1e-5,
relu=False):
# NOTE: Currently, only inference is supported # NOTE: Currently, only inference is supported
fluid = import_fluid() fluid = import_fluid()
prefix = name + '_' prefix = name + '_'
...@@ -276,7 +253,7 @@ class Network(object): ...@@ -276,7 +253,7 @@ class Network(object):
bias_attr=bias_attr, bias_attr=bias_attr,
moving_mean_name=mean_name, moving_mean_name=mean_name,
moving_variance_name=variance_name, moving_variance_name=variance_name,
epsilon=1e-5, epsilon=eps,
act='relu' if relu is True else None) act='relu' if relu is True else None)
return output return output
......
...@@ -142,7 +142,13 @@ class TensorFlowMapper(NodeMapper): ...@@ -142,7 +142,13 @@ class TensorFlowMapper(NodeMapper):
def map_batch_norm(self, node): def map_batch_norm(self, node):
scale_offset = len(node.data) == 4 scale_offset = len(node.data) == 4
kwargs = {} if scale_offset else {'scale_offset': False}
#this default value comes from caffe's param in batch_norm
default_eps = 1e-5
kwargs = {'scale_offset': scale_offset}
if node.parameters.eps != default_eps:
kwargs['eps'] = node.parameters.eps
return MaybeActivated( return MaybeActivated(
node, default=False)('batch_normalization', **kwargs) node, default=False)('batch_normalization', **kwargs)
...@@ -236,7 +242,7 @@ class TensorFlowEmitter(object): ...@@ -236,7 +242,7 @@ class TensorFlowEmitter(object):
func_def = self.statement('@classmethod') func_def = self.statement('@classmethod')
func_def += self.statement('def convert(cls, npy_model, fluid_path):') func_def += self.statement('def convert(cls, npy_model, fluid_path):')
self.indent() self.indent()
func_def += self.statement('import paddle.v2.fluid as fluid') func_def += self.statement('fluid = import_fluid()')
for l in codes: for l in codes:
func_def += self.statement(l) func_def += self.statement(l)
return '\n' + func_def return '\n' + func_def
......
...@@ -43,21 +43,16 @@ class InferTaskConfig(object): ...@@ -43,21 +43,16 @@ class InferTaskConfig(object):
class ModelHyperParams(object): class ModelHyperParams(object):
# Dictionary size for source and target language. This model directly uses # This model directly uses paddle.dataset.wmt16 in which <bos>, <eos> and
# paddle.dataset.wmt16 in which <bos>, <eos> and <unk> token has # <unk> token has already been added. As for the <pad> token, any token
# already been added, but the <pad> token is not added. Transformer requires # included in dict can be used to pad, since the paddings' loss will be
# sequences in a mini-batch are padded to have the same length. A <pad> token is # masked out and make no effect on parameter gradients.
# added into the original dictionary in paddle.dateset.wmt16.
# size of source word dictionary. # size of source word dictionary.
src_vocab_size = 10000 src_vocab_size = 10000
# index for <pad> token in source language.
src_pad_idx = src_vocab_size
# size of target word dictionary # size of target word dictionary
trg_vocab_size = 10000 trg_vocab_size = 10000
# index for <pad> token in target language.
trg_pad_idx = trg_vocab_size
# index for <bos> token # index for <bos> token
bos_idx = 0 bos_idx = 0
...@@ -66,11 +61,10 @@ class ModelHyperParams(object): ...@@ -66,11 +61,10 @@ class ModelHyperParams(object):
# index for <unk> token # index for <unk> token
unk_idx = 2 unk_idx = 2
# position value corresponding to the <pad> token. # max length of sequences.
pos_pad_idx = 0 # The size of position encoding table should at least plus 1, since the
# sinusoid position encoding starts from 1 and 0 can be used as the padding
# max length of sequences. It should plus 1 to include position # token for position encoding.
# padding token for position encoding.
max_length = 50 max_length = 50
# the dimension for word embeddings, which is also the last dimension of # the dimension for word embeddings, which is also the last dimension of
......
...@@ -41,7 +41,7 @@ def translate_batch(exe, ...@@ -41,7 +41,7 @@ def translate_batch(exe,
src_pad_idx, src_pad_idx,
n_head, n_head,
is_target=False, is_target=False,
return_pos=True, is_label=False,
return_attn_bias=True, return_attn_bias=True,
return_max_len=False) return_max_len=False)
# Append the data shape input to reshape the output of embedding layer. # Append the data shape input to reshape the output of embedding layer.
...@@ -250,22 +250,20 @@ def main(): ...@@ -250,22 +250,20 @@ def main():
encoder_program = fluid.Program() encoder_program = fluid.Program()
with fluid.program_guard(main_program=encoder_program): with fluid.program_guard(main_program=encoder_program):
enc_output = encoder( enc_output = encoder(
ModelHyperParams.src_vocab_size + 1, ModelHyperParams.src_vocab_size, ModelHyperParams.max_length + 1,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer, ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.n_head, ModelHyperParams.d_key, ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_value, ModelHyperParams.d_model, ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout, ModelHyperParams.dropout)
ModelHyperParams.src_pad_idx, ModelHyperParams.pos_pad_idx)
decoder_program = fluid.Program() decoder_program = fluid.Program()
with fluid.program_guard(main_program=decoder_program): with fluid.program_guard(main_program=decoder_program):
predict = decoder( predict = decoder(
ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer, ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.n_head, ModelHyperParams.d_key, ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_value, ModelHyperParams.d_model, ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout, ModelHyperParams.dropout)
ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)
# Load model parameters of encoder and decoder separately from the saved # Load model parameters of encoder and decoder separately from the saved
# transformer model. # transformer model.
...@@ -301,9 +299,6 @@ def main(): ...@@ -301,9 +299,6 @@ def main():
trg_idx2word = paddle.dataset.wmt16.get_dict( trg_idx2word = paddle.dataset.wmt16.get_dict(
"de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True) "de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
# Append the <pad> token since the dict provided by dataset.wmt16 does
# not include it.
trg_idx2word[ModelHyperParams.trg_pad_idx] = "<pad>"
def post_process_seq(seq, def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx, bos_idx=ModelHyperParams.bos_idx,
...@@ -327,19 +322,22 @@ def main(): ...@@ -327,19 +322,22 @@ def main():
for batch_id, data in enumerate(test_data()): for batch_id, data in enumerate(test_data()):
batch_seqs, batch_scores = translate_batch( batch_seqs, batch_scores = translate_batch(
exe, [item[0] for item in data], exe,
[item[0] for item in data],
encoder_program, encoder_program,
encoder_input_data_names, [enc_output.name], encoder_input_data_names,
[enc_output.name],
decoder_program, decoder_program,
decoder_input_data_names, [predict.name], decoder_input_data_names,
[predict.name],
InferTaskConfig.beam_size, InferTaskConfig.beam_size,
InferTaskConfig.max_length, InferTaskConfig.max_length,
InferTaskConfig.n_best, InferTaskConfig.n_best,
len(data), len(data),
ModelHyperParams.n_head, ModelHyperParams.n_head,
ModelHyperParams.d_model, ModelHyperParams.d_model,
ModelHyperParams.src_pad_idx, ModelHyperParams.eos_idx, # Use eos_idx to pad.
ModelHyperParams.trg_pad_idx, ModelHyperParams.eos_idx, # Use eos_idx to pad.
ModelHyperParams.bos_idx, ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
ModelHyperParams.unk_idx, ModelHyperParams.unk_idx,
......
...@@ -199,10 +199,8 @@ def prepare_encoder(src_word, ...@@ -199,10 +199,8 @@ def prepare_encoder(src_word,
src_pos, src_pos,
src_vocab_size, src_vocab_size,
src_emb_dim, src_emb_dim,
src_pad_idx,
src_max_len, src_max_len,
dropout_rate=0., dropout_rate=0.,
pos_pad_idx=0,
src_data_shape=None, src_data_shape=None,
pos_enc_param_name=None): pos_enc_param_name=None):
"""Add word embeddings and position encodings. """Add word embeddings and position encodings.
...@@ -214,12 +212,10 @@ def prepare_encoder(src_word, ...@@ -214,12 +212,10 @@ def prepare_encoder(src_word,
src_word_emb = layers.embedding( src_word_emb = layers.embedding(
src_word, src_word,
size=[src_vocab_size, src_emb_dim], size=[src_vocab_size, src_emb_dim],
padding_idx=src_pad_idx,
param_attr=fluid.initializer.Normal(0., 1.)) param_attr=fluid.initializer.Normal(0., 1.))
src_pos_enc = layers.embedding( src_pos_enc = layers.embedding(
src_pos, src_pos,
size=[src_max_len, src_emb_dim], size=[src_max_len, src_emb_dim],
padding_idx=pos_pad_idx,
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name=pos_enc_param_name, trainable=False)) name=pos_enc_param_name, trainable=False))
enc_input = src_word_emb + src_pos_enc enc_input = src_word_emb + src_pos_enc
...@@ -480,12 +476,16 @@ def make_inputs(input_data_names, ...@@ -480,12 +476,16 @@ def make_inputs(input_data_names,
append_batch_size=False) append_batch_size=False)
input_layers += [slf_attn_post_softmax_shape] input_layers += [slf_attn_post_softmax_shape]
if src_attn_shape_flag: if src_attn_shape_flag:
# This shape input is used to reshape before softmax in encoder-decoder
# attention.
src_attn_pre_softmax_shape = layers.data( src_attn_pre_softmax_shape = layers.data(
name=input_data_names[len(input_layers)], name=input_data_names[len(input_layers)],
shape=[2], shape=[2],
dtype="int32", dtype="int32",
append_batch_size=False) append_batch_size=False)
input_layers += [src_attn_pre_softmax_shape] input_layers += [src_attn_pre_softmax_shape]
# This shape input is used to reshape after softmax in encoder-decoder
# attention.
src_attn_post_softmax_shape = layers.data( src_attn_post_softmax_shape = layers.data(
name=input_data_names[len(input_layers)], name=input_data_names[len(input_layers)],
shape=[4], shape=[4],
...@@ -516,10 +516,7 @@ def transformer( ...@@ -516,10 +516,7 @@ def transformer(
d_value, d_value,
d_model, d_model,
d_inner_hid, d_inner_hid,
dropout_rate, dropout_rate, ):
src_pad_idx,
trg_pad_idx,
pos_pad_idx, ):
enc_inputs = make_inputs( enc_inputs = make_inputs(
encoder_input_data_names, encoder_input_data_names,
n_head, n_head,
...@@ -543,8 +540,6 @@ def transformer( ...@@ -543,8 +540,6 @@ def transformer(
d_model, d_model,
d_inner_hid, d_inner_hid,
dropout_rate, dropout_rate,
src_pad_idx,
pos_pad_idx,
enc_inputs, ) enc_inputs, )
dec_inputs = make_inputs( dec_inputs = make_inputs(
...@@ -570,8 +565,6 @@ def transformer( ...@@ -570,8 +565,6 @@ def transformer(
d_model, d_model,
d_inner_hid, d_inner_hid,
dropout_rate, dropout_rate,
trg_pad_idx,
pos_pad_idx,
dec_inputs, dec_inputs,
enc_output, ) enc_output, )
...@@ -606,8 +599,6 @@ def wrap_encoder(src_vocab_size, ...@@ -606,8 +599,6 @@ def wrap_encoder(src_vocab_size,
d_model, d_model,
d_inner_hid, d_inner_hid,
dropout_rate, dropout_rate,
src_pad_idx,
pos_pad_idx,
enc_inputs=None): enc_inputs=None):
""" """
The wrapper assembles together all needed layers for the encoder. The wrapper assembles together all needed layers for the encoder.
...@@ -637,10 +628,8 @@ def wrap_encoder(src_vocab_size, ...@@ -637,10 +628,8 @@ def wrap_encoder(src_vocab_size,
src_pos, src_pos,
src_vocab_size, src_vocab_size,
d_model, d_model,
src_pad_idx,
max_length, max_length,
dropout_rate, dropout_rate,
pos_pad_idx,
src_data_shape, ) src_data_shape, )
enc_output = encoder( enc_output = encoder(
enc_input, enc_input,
...@@ -666,8 +655,6 @@ def wrap_decoder(trg_vocab_size, ...@@ -666,8 +655,6 @@ def wrap_decoder(trg_vocab_size,
d_model, d_model,
d_inner_hid, d_inner_hid,
dropout_rate, dropout_rate,
trg_pad_idx,
pos_pad_idx,
dec_inputs=None, dec_inputs=None,
enc_output=None): enc_output=None):
""" """
...@@ -701,10 +688,8 @@ def wrap_decoder(trg_vocab_size, ...@@ -701,10 +688,8 @@ def wrap_decoder(trg_vocab_size,
trg_pos, trg_pos,
trg_vocab_size, trg_vocab_size,
d_model, d_model,
trg_pad_idx,
max_length, max_length,
dropout_rate, dropout_rate,
pos_pad_idx,
trg_data_shape, ) trg_data_shape, )
dec_output = decoder( dec_output = decoder(
dec_input, dec_input,
......
...@@ -15,7 +15,7 @@ def pad_batch_data(insts, ...@@ -15,7 +15,7 @@ def pad_batch_data(insts,
pad_idx, pad_idx,
n_head, n_head,
is_target=False, is_target=False,
return_pos=True, is_label=False,
return_attn_bias=True, return_attn_bias=True,
return_max_len=True): return_max_len=True):
""" """
...@@ -24,14 +24,20 @@ def pad_batch_data(insts, ...@@ -24,14 +24,20 @@ def pad_batch_data(insts,
""" """
return_list = [] return_list = []
max_len = max(len(inst) for inst in insts) max_len = max(len(inst) for inst in insts)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data = np.array( inst_data = np.array(
[inst + [pad_idx] * (max_len - len(inst)) for inst in insts]) [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
return_list += [inst_data.astype("int64").reshape([-1, 1])] return_list += [inst_data.astype("int64").reshape([-1, 1])]
if return_pos: if is_label: # label weight
inst_pos = np.array([[ inst_weight = np.array(
pos_i + 1 if w_i != pad_idx else 0 for pos_i, w_i in enumerate(inst) [[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
] for inst in inst_data]) return_list += [inst_weight.astype("float32").reshape([-1, 1])]
else: # position data
inst_pos = np.array([
range(1, len(inst) + 1) + [0] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, 1])] return_list += [inst_pos.astype("int64").reshape([-1, 1])]
if return_attn_bias: if return_attn_bias:
if is_target: if is_target:
...@@ -84,9 +90,14 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx, ...@@ -84,9 +90,14 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
trg_src_attn_post_softmax_shape = np.array( trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32") trg_src_attn_bias.shape, dtype="int32")
lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, n_head, lbl_word, lbl_weight = pad_batch_data(
False, False, False, False) [inst[2] for inst in insts],
lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1]) trg_pad_idx,
n_head,
is_target=False,
is_label=True,
return_attn_bias=False,
return_max_len=False)
input_dict = dict( input_dict = dict(
zip(input_data_names, [ zip(input_data_names, [
...@@ -105,13 +116,11 @@ def main(): ...@@ -105,13 +116,11 @@ def main():
exe = fluid.Executor(place) exe = fluid.Executor(place)
sum_cost, avg_cost, predict, token_num = transformer( sum_cost, avg_cost, predict, token_num = transformer(
ModelHyperParams.src_vocab_size + 1, ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1, ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_layer, ModelHyperParams.n_head, ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_key, ModelHyperParams.d_value, ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid, ModelHyperParams.d_inner_hid, ModelHyperParams.dropout)
ModelHyperParams.dropout, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)
lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model, lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
TrainTaskConfig.warmup_steps, place, TrainTaskConfig.warmup_steps, place,
...@@ -145,8 +154,8 @@ def main(): ...@@ -145,8 +154,8 @@ def main():
for batch_id, data in enumerate(val_data()): for batch_id, data in enumerate(val_data()):
data_input = prepare_batch_input( data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] + data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.src_pad_idx, label_data_names, ModelHyperParams.eos_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head, ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model) ModelHyperParams.d_model)
test_sum_cost, test_token_num = exe.run( test_sum_cost, test_token_num = exe.run(
test_program, test_program,
...@@ -171,10 +180,12 @@ def main(): ...@@ -171,10 +180,12 @@ def main():
for pass_id in xrange(TrainTaskConfig.pass_num): for pass_id in xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time() pass_start_time = time.time()
for batch_id, data in enumerate(train_data()): for batch_id, data in enumerate(train_data()):
if len(data) != TrainTaskConfig.batch_size:
continue
data_input = prepare_batch_input( data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] + data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.src_pad_idx, label_data_names, ModelHyperParams.eos_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head, ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model) ModelHyperParams.d_model)
lr_scheduler.update_learning_rate(data_input) lr_scheduler.update_learning_rate(data_input)
outs = exe.run(fluid.framework.default_main_program(), outs = exe.run(fluid.framework.default_main_program(),
......
...@@ -30,15 +30,11 @@ class PolicyGradient: ...@@ -30,15 +30,11 @@ class PolicyGradient:
acts = fluid.layers.data(name='acts', shape=[1], dtype='int64') acts = fluid.layers.data(name='acts', shape=[1], dtype='int64')
vt = fluid.layers.data(name='vt', shape=[1], dtype='float32') vt = fluid.layers.data(name='vt', shape=[1], dtype='float32')
# fc1 # fc1
fc1 = fluid.layers.fc( fc1 = fluid.layers.fc(input=obs, size=10, act="tanh") # tanh activation
input=obs,
size=10,
act="tanh" # tanh activation
)
# fc2 # fc2
all_act_prob = fluid.layers.fc(input=fc1, all_act_prob = fluid.layers.fc(input=fc1,
size=self.n_actions, size=self.n_actions,
act="softmax") act="softmax")
self.inferece_program = fluid.defaul_main_program().clone() self.inferece_program = fluid.defaul_main_program().clone()
# to maximize total reward (log_p * R) is to minimize -(log_p * R) # to maximize total reward (log_p * R) is to minimize -(log_p * R)
neg_log_prob = fluid.layers.cross_entropy( neg_log_prob = fluid.layers.cross_entropy(
...@@ -53,10 +49,9 @@ class PolicyGradient: ...@@ -53,10 +49,9 @@ class PolicyGradient:
self.exe.run(fluid.default_startup_program()) self.exe.run(fluid.default_startup_program())
def choose_action(self, observation): def choose_action(self, observation):
prob_weights = self.exe.run( prob_weights = self.exe.run(self.inferece_program,
self.inferece_program, feed={"obs": observation[np.newaxis, :]},
feed={"obs": observation[np.newaxis, :]}, fetch_list=[self.all_act_prob])
fetch_list=[self.all_act_prob])
prob_weights = np.array(prob_weights[0]) prob_weights = np.array(prob_weights[0])
action = np.random.choice( action = np.random.choice(
range(prob_weights.shape[1]), range(prob_weights.shape[1]),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册