未验证 提交 4cebdec8 编写于 作者: D dzhwinter 提交者: GitHub

"migration from benchmark repo to paddle main repo" (#9762)

上级 31464f34
此差异已折叠。
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import time
import numpy as np
import tensorflow as tf
import paddle.v2 as paddle
DTYPE = tf.float32
def parse_args():
parser = argparse.ArgumentParser("mnist model benchmark.")
parser.add_argument(
'--batch_size', type=int, default=128, help='The minibatch size.')
parser.add_argument(
'--iterations', type=int, default=35, help='The number of minibatches.')
parser.add_argument(
'--pass_num', type=int, default=5, help='The number of passes.')
parser.add_argument(
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help='The device type.')
args = parser.parse_args()
return args
def run_benchmark(args):
def weight_variable(dtype, shape):
initial = tf.truncated_normal(shape, stddev=0.1, dtype=dtype)
return tf.Variable(initial)
def bias_variable(dtype, shape):
initial = tf.constant(0.1, shape=shape, dtype=dtype)
return tf.Variable(initial)
device = '/cpu:0' if args.device == 'CPU' else '/device:GPU:0'
with tf.device(device):
images = tf.placeholder(DTYPE, shape=(None, 28, 28, 1))
labels = tf.placeholder(tf.int64, shape=(None, ))
# conv1, relu, pool1
conv1_weights = weight_variable(DTYPE, [5, 5, 1, 20])
conv1_bias = bias_variable(DTYPE, [20])
conv1 = tf.nn.conv2d(
images, conv1_weights, strides=[1, 1, 1, 1], padding="VALID")
relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_bias))
pool1 = tf.nn.max_pool(
relu1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")
# conv2, relu, pool2
conv2_weights = weight_variable(DTYPE, [5, 5, 20, 50])
conv2_bias = bias_variable(DTYPE, [50])
conv2 = tf.nn.conv2d(
pool1, conv2_weights, strides=[1, 1, 1, 1], padding="VALID")
relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_bias))
pool2 = tf.nn.max_pool(
relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")
# FC
pool_shape = pool2.get_shape().as_list()
hidden_dim = reduce(lambda a, b: a * b, pool_shape[1:], 1)
reshape = tf.reshape(pool2, shape=(tf.shape(pool2)[0], hidden_dim))
fc_weights = weight_variable(DTYPE, [hidden_dim, 10])
fc_bias = bias_variable(DTYPE, [10])
logits = tf.matmul(reshape, fc_weights) + fc_bias
# Get prediction
prediction = tf.nn.softmax(logits)
# Loss
one_hot_labels = tf.one_hot(labels, depth=10)
cost = -tf.reduce_sum(tf.log(prediction) * one_hot_labels, [1])
avg_cost = tf.reduce_mean(cost)
# Get accuracy
correct = tf.equal(tf.argmax(prediction, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
# metrics, g_accuracy
with tf.variable_scope("reset_metrics_accuracy_scope") as scope:
g_accuracy = tf.metrics.accuracy(
labels, tf.argmax(
prediction, axis=1))
vars = tf.contrib.framework.get_variables(
scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
g_accuracy_reset_op = tf.variables_initializer(vars)
# Optimizer
opt = tf.train.AdamOptimizer(
learning_rate=0.001, beta1=0.9, beta2=0.999)
train_op = opt.minimize(avg_cost)
# train_op = tf.train.AdamOptimizer(1e-4).minimize(avg_cost)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=args.batch_size)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=args.batch_size)
def eval_test():
sess.run(g_accuracy_reset_op)
for batch_id, data in enumerate(test_reader()):
images_data = np.array(
map(lambda x: np.transpose(x[0].reshape([1, 28, 28]), axes=[1,2,0]), data)).astype("float32")
labels_data = np.array(map(lambda x: x[1], data)).astype("int64")
loss, acc, g_acc = sess.run(
[avg_cost, accuracy, g_accuracy],
feed_dict={images: images_data,
labels: labels_data})
return g_acc[1]
config = tf.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
init_g = tf.global_variables_initializer()
init_l = tf.local_variables_initializer()
sess.run(init_g)
sess.run(init_l)
for pass_id in range(args.pass_num):
sess.run(g_accuracy_reset_op)
pass_start = time.time()
for batch_id, data in enumerate(train_reader()):
images_data = np.array(
map(lambda x: np.transpose(x[0].reshape([1, 28, 28]), axes=[1,2,0]), data)).astype("float32")
labels_data = np.array(map(lambda x: x[1], data)).astype(
"int64")
start = time.time()
_, loss, acc, g_acc = sess.run(
[train_op, avg_cost, accuracy, g_accuracy],
feed_dict={images: images_data,
labels: labels_data})
end = time.time()
print("pass=%d, batch=%d, loss=%f, error=%f, elapse=%f" %
(pass_id, batch_id, loss, 1 - acc, (end - start) / 1000))
pass_end = time.time()
test_avg_acc = eval_test()
print(
"pass=%d, training_avg_accuracy=%f, test_avg_acc=%f, elapse=%f"
% (pass_id, g_acc[1], test_avg_acc,
(pass_end - pass_start) / 1000))
def print_arguments(args):
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
run_benchmark(args)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
based on https://github.com/tensorflow/models/blob/master/official/resnet/resnet_model.py
Get help: python resnet.py --help
See performance on flowers: python resnet.py
Train on cifar10: python resnet.py --data=cifar10 --with_test
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import time
import numpy as np
import paddle.v2 as paddle
import tensorflow as tf
DTYPE = tf.float32
def parse_args():
parser = argparse.ArgumentParser('Convolution model benchmark.')
parser.add_argument(
'--model',
type=str,
choices=['resnet'],
default='resnet',
help='The model architecture.')
parser.add_argument(
'--batch_size', type=int, default=32, help='The minibatch size.')
parser.add_argument(
'--use_fake_data',
action='store_true',
help='use real data or fake data')
parser.add_argument(
'--skip_batch_num',
type=int,
default=5,
help='The first num of minibatch num to skip, for better performance test'
)
parser.add_argument(
'--iterations',
type=int,
default=105,
help='The number of minibatches.')
parser.add_argument(
'--pass_num', type=int, default=300, help='The number of passes.')
parser.add_argument(
'--order',
type=str,
default='NHWC',
choices=['NCHW', 'NHWC'],
help='The data order, now only support NCHW.')
parser.add_argument(
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help='The device type.')
parser.add_argument(
'--data',
type=str,
default='flowers102',
choices=['flowers102', 'cifar10'],
help='The kinds of data.')
parser.add_argument(
'--infer_only', action='store_true', help='If set, run forward only.')
parser.add_argument(
'--use_cprof', action='store_true', help='If set, use cProfile.')
parser.add_argument(
'--with_test',
action='store_true',
help='If set, test the testset during training.')
parser.add_argument(
'--use_nvprof',
action='store_true',
help='If set, use nvprof for CUDA.')
args = parser.parse_args()
return args
def print_arguments(args):
vars(args)['use_nvprof'] = (vars(args)['use_nvprof'] and
vars(args)['device'] == 'GPU')
vars(args)['iterations'] = vars(args)['pass_num'] * 1000 if vars(args)[
'with_test'] else vars(args)['iterations']
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
def fixed_padding(inputs, kernel_size, data_format):
"""Pads the input along the spatial dimensions independently of input size.
Args:
inputs: A tensor of size [batch, channels, height_in, width_in] or
[batch, height_in, width_in, channels] depending on data_format.
kernel_size: The kernel to be used in the conv2d or max_pool2d operation.
Should be a positive integer.
data_format: The input format ('channels_last' or 'channels_first').
Returns:
A tensor with the same format as the input with the data either intact
(if kernel_size == 1) or padded (if kernel_size > 1).
"""
pad_total = kernel_size - 1
pad_beg = pad_total // 2
pad_end = pad_total - pad_beg
if data_format == 'channels_first':
padded_inputs = tf.pad(inputs, [[0, 0], [0, 0], [pad_beg, pad_end],
[pad_beg, pad_end]])
else:
padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end],
[pad_beg, pad_end], [0, 0]])
return padded_inputs
def conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format):
"""Strided 2-D convolution with explicit padding."""
# The padding is consistent and is based only on `kernel_size`, not on the
# dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).
# This is consistent with PaddlePaddle.
# In addition, the calculation for output size in TensorFlow can refer:
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/common_shape_fns.cc
if strides > 1:
inputs = fixed_padding(inputs, kernel_size, data_format)
return tf.layers.conv2d(
inputs=inputs,
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=('SAME' if strides == 1 else 'VALID'),
use_bias=False,
kernel_initializer=tf.variance_scaling_initializer(),
data_format=data_format)
def conv_bn(inputs,
filters,
kernel_size,
strides,
is_training,
data_format,
act=True):
# def conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format):
# set fused=True for a significant performance boost. See
# https://www.tensorflow.org/performance/performance_guide#common_fused_ops
inputs = conv2d_fixed_padding(
inputs=inputs,
filters=filters,
kernel_size=kernel_size,
strides=strides,
data_format=data_format)
inputs = tf.layers.batch_normalization(
inputs=inputs,
axis=1 if data_format == 'channels_first' else 3,
momentum=0.9,
epsilon=1e-05,
center=True,
scale=True,
training=is_training,
fused=True)
if act:
inputs = tf.nn.relu(inputs)
return inputs
def basicblock(inputs, filters, is_training, projection_shortcut, strides,
data_format):
shortcut = inputs
if projection_shortcut is not None:
shortcut = projection_shortcut(inputs)
inputs = conv_bn(inputs, filters, 3, strides, is_training, data_format)
inputs = conv_bn(inputs, filters, 3, 1, is_training, data_format, act=False)
inputs = inputs + shortcut
inputs = tf.nn.relu(inputs)
return inputs
def bottleneck(inputs, filters, is_training, projection_shortcut, strides,
data_format):
shortcut = inputs
if projection_shortcut is not None:
shortcut = projection_shortcut(inputs)
inputs = conv_bn(inputs, filters, 1, strides, is_training, data_format)
inputs = conv_bn(inputs, filters, 3, 1, is_training, data_format, act=False)
inputs = conv_bn(
inputs, filters * 4, 1, 1, is_training, data_format, act=False)
inputs = inputs + shortcut
inputs = tf.nn.relu(inputs)
return inputs
def block_layer(inputs, filters, block_fn, blocks, strides, is_training, name,
data_format):
# Bottleneck blocks end with 4x the number of filters as they start with
filters_out = 4 * filters if block_fn is bottleneck else filters
def projection_shortcut(inputs):
return conv2d_fixed_padding(
inputs=inputs,
filters=filters_out,
kernel_size=1,
strides=strides,
data_format=data_format)
# Only the first block per block_layer uses projection_shortcut and strides
inputs = block_fn(inputs, filters, is_training, projection_shortcut,
strides, data_format)
for _ in range(1, blocks):
inputs = block_fn(inputs, filters, is_training, None, 1, data_format)
return tf.identity(inputs, name)
def resnet_imagenet(depth, class_dim, data_format):
"""Returns the ResNet model for a given size and number of output classes."""
def resnet_generator(block_fn,
layers,
num_classes,
data_format='channels_last'):
if data_format is None:
data_format = ('channels_first'
if tf.test.is_built_with_cuda() else 'channels_last')
def model(inputs, is_training):
"""Constructs the ResNet model given the inputs."""
if data_format == 'channels_first':
# Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
# This provides a large performance boost on GPU. See
# https://www.tensorflow.org/performance/performance_guide#data_formats
inputs = tf.transpose(inputs, [0, 3, 1, 2])
inputs = conv_bn(inputs, 64, 7, 2, is_training, data_format)
inputs = tf.identity(inputs, 'initial_conv')
inputs = tf.layers.max_pooling2d(
inputs=inputs,
pool_size=3,
strides=2,
padding='SAME',
data_format=data_format)
inputs = tf.identity(inputs, 'initial_max_pool')
inputs = block_layer(inputs, 64, block_fn, layers[0], 1,
is_training, 'block_layer1', data_format)
inputs = block_layer(inputs, 128, block_fn, layers[1], 2,
is_training, 'block_layer2', data_format)
inputs = block_layer(inputs, 256, block_fn, layers[2], 2,
is_training, 'block_layer3', data_format)
inputs = block_layer(inputs, 512, block_fn, layers[3], 2,
is_training, 'block_layer4', data_format)
inputs = tf.layers.average_pooling2d(
inputs=inputs,
pool_size=7,
strides=1,
padding='VALID',
data_format=data_format)
inputs = tf.identity(inputs, 'final_avg_pool')
inputs = tf.reshape(inputs,
[-1, 512 if block_fn is basicblock else 2048])
inputs = tf.layers.dense(inputs=inputs, units=num_classes)
inputs = tf.identity(inputs, 'final_dense')
return inputs
return model
model_params = {
18: {
'block': basicblock,
'layers': [2, 2, 2, 2]
},
34: {
'block': basicblock,
'layers': [3, 4, 6, 3]
},
50: {
'block': bottleneck,
'layers': [3, 4, 6, 3]
},
101: {
'block': bottleneck,
'layers': [3, 4, 23, 3]
},
152: {
'block': bottleneck,
'layers': [3, 8, 36, 3]
},
200: {
'block': bottleneck,
'layers': [3, 24, 36, 3]
}
}
if depth not in model_params:
raise ValueError('Not a valid depth:', depth)
params = model_params[depth]
return resnet_generator(params['block'], params['layers'], class_dim,
data_format)
def resnet_cifar10(depth, num_classes, data_format):
if depth % 6 != 2:
raise ValueError('depth must be 6n + 2:', depth)
num_blocks = (depth - 2) // 6
if data_format is None:
data_format = ('channels_first'
if tf.test.is_built_with_cuda() else 'channels_last')
def model(inputs, is_training):
inputs = conv_bn(inputs, 16, 3, 1, is_training, data_format)
inputs = tf.identity(inputs, 'initial_conv')
inputs = block_layer(inputs, 16, basicblock, num_blocks, 1, is_training,
'block_layer1', data_format)
inputs = block_layer(inputs, 32, basicblock, num_blocks, 2, is_training,
'block_layer2', data_format)
inputs = block_layer(inputs, 64, basicblock, num_blocks, 2, is_training,
'block_layer3', data_format)
inputs = tf.layers.average_pooling2d(
inputs=inputs,
pool_size=8,
strides=1,
padding='VALID',
data_format=data_format)
inputs = tf.identity(inputs, 'final_avg_pool')
inputs = tf.reshape(inputs, [-1, 64])
inputs = tf.layers.dense(inputs=inputs, units=num_classes)
inputs = tf.identity(inputs, 'final_dense')
return inputs
return model
def run_benchmark(args, data_format='channels_last', device='/cpu:0'):
"""Our model_fn for ResNet to be used with our Estimator."""
class_dim = 1000
dshape = (None, 224, 224, 3)
pdshape = (3, 224, 224)
if args.data == 'flowers102':
class_dim = 102
dshape = (None, 224, 224, 3)
pdshape = (3, 224, 224)
elif args.data == 'cifar10':
class_dim = 10
dshape = (None, 32, 32, 3)
pdshape = (3, 32, 32)
with tf.device(device):
images = tf.placeholder(DTYPE, shape=dshape)
labels = tf.placeholder(tf.int64, shape=(None, ))
is_training = tf.placeholder('bool')
onehot_labels = tf.one_hot(labels, depth=class_dim)
network = resnet_cifar10(
32, class_dim,
data_format) if args.data == 'cifar10' else resnet_imagenet(
50, class_dim, data_format)
logits = network(inputs=images, is_training=is_training)
cross_entropy = tf.losses.softmax_cross_entropy(
logits=logits, onehot_labels=onehot_labels)
avg_cost = tf.reduce_mean(cross_entropy)
correct = tf.equal(tf.argmax(logits, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
lr = 0.1 if args.data == 'cifar10' else 0.01
optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.9)
# Batch norm requires update_ops to be added as a train_op dependency.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(avg_cost)
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.cifar.train10()
if args.data == 'cifar10' else paddle.dataset.flowers.train(),
buf_size=5120),
batch_size=args.batch_size)
test_reader = paddle.batch(
paddle.dataset.cifar.test10()
if args.data == 'cifar10' else paddle.dataset.flowers.test(),
batch_size=100)
def test():
test_accs = []
for batch_id, data in enumerate(test_reader()):
test_images = np.array(
map(lambda x: np.transpose(x[0].reshape(pdshape),
axes=[1, 2, 0]), data)).astype("float32")
test_labels = np.array(map(lambda x: x[1], data)).astype('int64')
test_accs.append(
accuracy.eval(feed_dict={
images: test_images,
labels: test_labels,
is_training: False
}))
print("Pass = %d, Train performance = %f imgs/s, Test accuracy = %f\n" %
(pass_id, num_samples / train_elapsed, np.mean(test_accs)))
config = tf.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
init_g = tf.global_variables_initializer()
init_l = tf.local_variables_initializer()
sess.run(init_g)
sess.run(init_l)
if args.use_fake_data:
data = train_reader().next()
images_data = np.array(
map(lambda x: np.transpose(x[0].reshape(pdshape),
axes=[1, 2, 0]), data)).astype("float32")
labels_data = np.array(map(lambda x: x[1], data)).astype('int64')
iters, num_samples, start_time = 0, 0, 0.0
for pass_id in range(args.pass_num):
if iters == args.iterations:
break
train_accs = []
train_losses = []
for batch_id, data in enumerate(train_reader()):
if iters == args.skip_batch_num:
start_time = time.time()
num_samples = 0
if iters == args.iterations:
break
if not args.use_fake_data:
images_data = np.array(
map(lambda x: np.transpose(x[0].reshape(pdshape),
axes=[1, 2, 0]), data)).astype("float32")
labels_data = np.array(map(lambda x: x[1], data)).astype(
'int64')
_, loss, acc = sess.run([train_op, avg_cost, accuracy],
feed_dict={
images: images_data,
labels: labels_data,
is_training: True
})
iters += 1
train_accs.append(acc)
train_losses.append(loss)
num_samples += len(data)
print("Pass=%d, Iter=%d, Loss=%f, Accuray=%f\n" %
(pass_id, iters, loss, acc))
train_elapsed = time.time() - start_time
print("Pass=%d, Loss=%f, Accuray=%f\n" %
(pass_id, np.mean(train_losses), np.mean(train_accs)))
# evaluation
if args.with_test:
test()
if not args.with_test:
duration = time.time() - start_time
examples_per_sec = num_samples / duration
sec_per_batch = duration / (iters - args.skip_batch_num)
print('Total examples: %d, total time: %.5f' %
(num_samples, duration))
print('%.5f examples/sec, %.5f sec/batch' %
(examples_per_sec, sec_per_batch))
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
if tf.test.is_built_with_cuda():
device = '/device:GPU:0'
if args.order == 'NHWC':
data_format = 'channels_last'
else:
data_format = 'channels_first'
else:
device = '/cpu:0'
if args.order == 'NHWC':
data_format = 'channels_last'
else:
raise ValueError('Only support NHWC order in CPU mode')
run_benchmark(args, data_format, device)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import argparse
import time
import tensorflow as tf
import paddle.v2 as paddle
def parse_args():
parser = argparse.ArgumentParser("LSTM model benchmark.")
parser.add_argument(
'--batch_size',
type=int,
default=32,
help='The sequence number of a batch data. (default: %(default)d)')
parser.add_argument(
'--stacked_num',
type=int,
default=5,
help='Number of lstm layers to stack. (default: %(default)d)')
parser.add_argument(
'--embedding_dim',
type=int,
default=512,
help='Dimension of embedding table. (default: %(default)d)')
parser.add_argument(
'--hidden_dim',
type=int,
default=512,
help='Hidden size of lstm unit. (default: %(default)d)')
parser.add_argument(
'--pass_num',
type=int,
default=10,
help='Epoch number to train. (default: %(default)d)')
parser.add_argument(
'--learning_rate',
type=float,
default=0.0002,
help='Learning rate used to train. (default: %(default)f)')
parser.add_argument(
'--infer_only', action='store_true', help='If set, run forward only.')
args = parser.parse_args()
return args
def print_arguments(args):
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
def dynamic_lstm_model(dict_size,
embedding_dim,
hidden_dim,
stacked_num,
class_num=2,
is_train=True):
word_idx = tf.placeholder(tf.int64, shape=[None, None])
sequence_length = tf.placeholder(tf.int64, shape=[None, ])
embedding_weights = tf.get_variable('word_embeddings',
[dict_size, embedding_dim])
embedding = tf.nn.embedding_lookup(embedding_weights, word_idx)
lstm_cell = tf.nn.rnn_cell.LSTMCell(
num_units=hidden_dim, use_peepholes=False)
stacked_cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * stacked_num)
# final_state [LSTMTuple(c, h), LSTMTuple(c, h) ...] total stacked_num LSTMTuples
_, final_state = tf.nn.dynamic_rnn(
cell=stacked_cell,
inputs=embedding,
dtype=tf.float32,
sequence_length=sequence_length)
w = tf.Variable(
tf.truncated_normal([hidden_dim, class_num]), dtype=tf.float32)
bias = tf.Variable(
tf.constant(
value=0.0, shape=[class_num], dtype=tf.float32))
prediction = tf.matmul(final_state[-1][1], w) + bias
if not is_train:
return (word_idx, sequence_length), tf.nn.softmax(prediction)
label = tf.placeholder(tf.int64, shape=[None, ])
loss = tf.nn.softmax_cross_entropy_with_logits(
labels=tf.one_hot(label, 2), logits=prediction)
avg_loss = tf.reduce_mean(loss)
correct_count = tf.equal(tf.argmax(prediction, 1), label)
acc = tf.reduce_mean(tf.cast(correct_count, tf.float32))
with tf.variable_scope("reset_metrics_accuracy_scope") as scope:
g_acc = tf.metrics.accuracy(label, tf.argmax(prediction, axis=1))
vars = tf.contrib.framework.get_variables(
scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
reset_op = tf.variables_initializer(vars)
return (word_idx, sequence_length, label), avg_loss, acc, g_acc, reset_op
def padding_data(data, padding_size, value):
data = data + [value] * padding_size
return data[:padding_size]
def train(args):
word_dict = paddle.dataset.imdb.word_dict()
dict_size = len(word_dict)
feeding_list, avg_loss, acc, g_acc, reset_op = dynamic_lstm_model(
dict_size, args.embedding_dim, args.hidden_dim, args.stacked_num)
adam_optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
train_op = adam_optimizer.minimize(avg_loss)
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.imdb.train(word_dict), buf_size=25000),
batch_size=args.batch_size)
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.imdb.test(word_dict), buf_size=25000),
batch_size=args.batch_size)
def do_validation(sess):
sess.run(reset_op)
for batch_id, data in enumerate(test_reader()):
word_idx = map(lambda x: x[0], data)
sequence_length = np.array(
[len(seq) for seq in word_idx]).astype('int64')
maxlen = np.max(sequence_length)
word_idx = [padding_data(seq, maxlen, 0) for seq in word_idx]
word_idx = np.array(word_idx).astype('int64')
label = np.array(map(lambda x: x[1], data)).astype('int64')
_, loss, fetch_acc, fetch_g_acc = sess.run(
[train_op, avg_loss, acc, g_acc],
feed_dict={
feeding_list[0]: word_idx,
feeding_list[1]: sequence_length,
feeding_list[2]: label
})
return fetch_g_acc[1]
config = tf.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
init_g = tf.global_variables_initializer()
init_l = tf.local_variables_initializer()
sess.run(init_l)
sess.run(init_g)
for pass_id in xrange(args.pass_num):
# clear accuracy local variable
sess.run(reset_op)
pass_start_time = time.time()
words_seen = 0
for batch_id, data in enumerate(train_reader()):
word_idx = map(lambda x: x[0], data)
sequence_length = np.array(
[len(seq) for seq in word_idx]).astype('int64')
words_seen += np.sum(sequence_length)
maxlen = np.max(sequence_length)
word_idx = [padding_data(seq, maxlen, 0) for seq in word_idx]
word_idx = np.array(word_idx).astype('int64')
label = np.array(map(lambda x: x[1], data)).astype('int64')
_, loss, fetch_acc, fetch_g_acc = sess.run(
[train_op, avg_loss, acc, g_acc],
feed_dict={
feeding_list[0]: word_idx,
feeding_list[1]: sequence_length,
feeding_list[2]: label
})
print("pass_id=%d, batch_id=%d, loss: %f, acc: %f, avg_acc: %f"
% (pass_id, batch_id, loss, fetch_acc, fetch_g_acc[1]))
pass_end_time = time.time()
time_consumed = pass_end_time - pass_start_time
words_per_sec = words_seen / time_consumed
test_acc = do_validation(sess)
print("pass_id=%d, test_acc: %f, words/s: %f, sec/pass: %f" %
(pass_id, test_acc, words_per_sec, time_consumed))
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
if args.infer_only:
pass
else:
train(args)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VGG16 benchmark in TensorFlow"""
import tensorflow as tf
import paddle.v2 as paddle
import numpy as np
import argparse
import time
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--batch_size', type=int, default=128, help="Batch size for training.")
parser.add_argument(
'--skip_batch_num',
type=int,
default=5,
help='The first num of minibatch num to skip, for better performance test')
parser.add_argument(
'--iterations', type=int, default=80, help='The number of minibatches.')
parser.add_argument(
'--learning_rate',
type=float,
default=1e-3,
help="Learning rate for training.")
parser.add_argument('--num_passes', type=int, default=50, help="No. of passes.")
parser.add_argument(
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help="The device type.")
parser.add_argument(
'--data_format',
type=str,
default='NHWC',
choices=['NCHW', 'NHWC'],
help='The data order, NCHW=[batch, channels, height, width].'
'Only support NHWC right now.')
parser.add_argument(
'--data_set',
type=str,
default='cifar10',
choices=['cifar10', 'flowers'],
help='Optional dataset for benchmark.')
args = parser.parse_args()
class VGG16Model(object):
def __init__(self):
self.parameters = []
def batch_norm_relu(self, inputs, is_training):
"""Performs a batch normalization followed by a ReLU."""
# We set fused=True for a significant speed boost. See
# https://www.tensorflow.org/speed/speed_guide#common_fused_ops
inputs = tf.layers.batch_normalization(
inputs=inputs,
axis=1 if args.data_format == 'NCHW' else -1,
momentum=0.9,
epsilon=1e-05,
center=True,
scale=True,
training=is_training,
fused=True)
inputs = tf.nn.relu(inputs)
return inputs
def conv_bn_layer(self,
name,
images,
kernel_shape,
is_training,
drop_rate=0.0):
with tf.name_scope(name) as scope:
kernel = tf.Variable(
tf.truncated_normal(
kernel_shape, dtype=tf.float32, stddev=1e-1),
name='weights')
conv = tf.nn.conv2d(
images,
kernel, [1, 1, 1, 1],
data_format=args.data_format,
padding='SAME')
biases = tf.Variable(
tf.constant(
0.0, shape=[kernel_shape[-1]], dtype=tf.float32),
trainable=True,
name='biases')
out = tf.nn.bias_add(conv, biases)
out = self.batch_norm_relu(out, is_training)
out = tf.layers.dropout(out, rate=drop_rate, training=is_training)
return out
def fc_layer(self, name, inputs, shape):
with tf.name_scope(name) as scope:
fc_w = tf.Variable(
tf.truncated_normal(
shape, dtype=tf.float32, stddev=1e-1),
name='weights')
fc_b = tf.Variable(
tf.constant(
0.0, shape=[shape[-1]], dtype=tf.float32),
trainable=True,
name='biases')
out = tf.nn.bias_add(tf.matmul(inputs, fc_w), fc_b)
return out
def network(self, images, class_dim, is_training):
""" VGG16 model structure.
TODO(kuke): enable this network to support the 'NCHW' data format
"""
# conv1
conv1_1 = self.conv_bn_layer(
'conv1_1', images, [3, 3, 3, 64], is_training, drop_rate=0.3)
conv1_2 = self.conv_bn_layer(
'conv1_2', conv1_1, [3, 3, 64, 64], is_training, drop_rate=0.0)
# pool1
pool1 = tf.nn.max_pool(
conv1_2,
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding='SAME',
name='pool1')
# conv2
conv2_1 = self.conv_bn_layer(
'conv2_1', pool1, [3, 3, 64, 128], is_training, drop_rate=0.4)
conv2_2 = self.conv_bn_layer(
'conv2_2', conv2_1, [3, 3, 128, 128], is_training, drop_rate=0.0)
# pool2
pool2 = tf.nn.max_pool(
conv2_2,
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding='SAME',
name='pool2')
# conv3
conv3_1 = self.conv_bn_layer(
'conv3_1', pool2, [3, 3, 128, 256], is_training, drop_rate=0.4)
conv3_2 = self.conv_bn_layer(
'conv3_2', conv3_1, [3, 3, 256, 256], is_training, drop_rate=0.4)
conv3_3 = self.conv_bn_layer(
'conv3_3', conv3_2, [3, 3, 256, 256], is_training, drop_rate=0.0)
# pool3
pool3 = tf.nn.max_pool(
conv3_3,
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding='SAME',
name='pool3')
# conv4
conv4_1 = self.conv_bn_layer(
'conv4_1', pool3, [3, 3, 256, 512], is_training, drop_rate=0.4)
conv4_2 = self.conv_bn_layer(
'conv4_2', conv4_1, [3, 3, 512, 512], is_training, drop_rate=0.4)
conv4_3 = self.conv_bn_layer(
'conv4_3', conv4_2, [3, 3, 512, 512], is_training, drop_rate=0.0)
# pool4
pool4 = tf.nn.max_pool(
conv4_3,
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding='SAME',
name='pool4')
# conv5
conv5_1 = self.conv_bn_layer(
'conv5_1', pool4, [3, 3, 512, 512], is_training, drop_rate=0.4)
conv5_2 = self.conv_bn_layer(
'conv5_2', conv5_1, [3, 3, 512, 512], is_training, drop_rate=0.4)
conv5_3 = self.conv_bn_layer(
'conv5_3', conv5_2, [3, 3, 512, 512], is_training, drop_rate=0.0)
# pool5
pool5 = tf.nn.max_pool(
conv5_3,
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding='SAME',
name='pool4')
# flatten
shape = int(np.prod(pool5.get_shape()[1:]))
pool5_flat = tf.reshape(pool5, [-1, shape])
# fc1
drop = tf.layers.dropout(pool5_flat, rate=0.5, training=is_training)
fc1 = self.fc_layer('fc1', drop, [shape, 512])
# fc2
bn = self.batch_norm_relu(fc1, is_training)
drop = tf.layers.dropout(bn, rate=0.5, training=is_training)
fc2 = self.fc_layer('fc2', drop, [512, 512])
fc3 = self.fc_layer('fc3', fc2, [512, class_dim])
return fc3
def run_benchmark():
"""Run benchmark on cifar10 or flowers."""
if args.data_set == "cifar10":
class_dim = 10
raw_shape = (3, 32, 32)
dat_shape = (None, 32, 32, 3) if args.data_format == 'NHWC' else (
None, 3, 32, 32)
else:
class_dim = 102
raw_shape = (3, 224, 224)
dat_shape = (None, 224, 224, 3) if args.data_format == 'NHWC' else (
None, 3, 224, 224)
device = '/cpu:0' if args.device == 'CPU' else '/device:GPU:0'
with tf.device(device):
images = tf.placeholder(tf.float32, shape=dat_shape)
labels = tf.placeholder(tf.int64, shape=(None, ))
is_training = tf.placeholder('bool')
onehot_labels = tf.one_hot(labels, depth=class_dim)
vgg16 = VGG16Model()
logits = vgg16.network(images, class_dim, is_training)
loss = tf.losses.softmax_cross_entropy(
onehot_labels=onehot_labels, logits=logits)
avg_loss = tf.reduce_mean(loss)
correct = tf.equal(tf.argmax(logits, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(avg_loss)
# data reader
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.cifar.train10()
if args.data_set == 'cifar10' else paddle.dataset.flowers.train(),
buf_size=5120),
batch_size=args.batch_size)
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.cifar.test10()
if args.data_set == 'cifar10' else paddle.dataset.flowers.test(),
buf_size=5120),
batch_size=args.batch_size)
# test
def test():
test_accs = []
for batch_id, data in enumerate(test_reader()):
test_images = np.array(
map(lambda x: np.transpose(x[0].reshape(raw_shape),
axes=[1, 2, 0]) if args.data_format == 'NHWC' else x[0], data)).astype("float32")
test_labels = np.array(map(lambda x: x[1], data)).astype('int64')
test_accs.append(
accuracy.eval(feed_dict={
images: test_images,
labels: test_labels,
is_training: False
}))
return np.mean(test_accs)
config = tf.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
init_g = tf.global_variables_initializer()
init_l = tf.local_variables_initializer()
sess.run(init_g)
sess.run(init_l)
iters, num_samples, start_time = 0, 0, time.time()
for pass_id in range(args.num_passes):
# train
num_samples = 0
start_time = time.time()
for batch_id, data in enumerate(train_reader()):
if iters == args.skip_batch_num:
start_time = time.time()
num_samples = 0
if iters == args.iterations:
break
train_images = np.array(
map(lambda x: np.transpose(x[0].reshape(raw_shape),
axes=[1, 2, 0]) if args.data_format == 'NHWC' else x[0], data)).astype("float32")
train_labels = np.array(map(lambda x: x[1], data)).astype(
'int64')
_, loss, acc = sess.run([train_op, avg_loss, accuracy],
feed_dict={
images: train_images,
labels: train_labels,
is_training: True
})
iters += 1
num_samples += len(data)
print("Pass = %d, Iters = %d, Loss = %f, Accuracy = %f" %
(pass_id, iters, loss, acc))
train_elapsed = time.time() - start_time
# test
pass_test_acc = test()
print("Pass = %d, Train speed = %f imgs/s, Test accuracy = %f\n" %
(pass_id, num_samples / train_elapsed, pass_test_acc))
def print_arguments():
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
if __name__ == '__main__':
print_arguments()
run_benchmark()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册