Commit d2803c0f authored by Yong Zheng Xin, committed by François Chollet

PEP8 fixed line length examples/ (#10724)

* PEP8 fixed line length examples/

* revert addition_rnn

* Fix PEP8 in babi_memnn
Parent d18c5645
examples/addition_rnn.py
@@ -24,7 +24,7 @@ Four digits reversed:
 Five digits reversed:
 + One layer LSTM (128 HN), 550k training examples = 99% train/test accuracy in 30 epochs
-'''
+''' # noqa
 from __future__ import print_function
 from keras.models import Sequential
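The revert keeps addition_rnn.py's long docstring lines and silences the checker with `# noqa` instead of rewrapping them. Checkers such as flake8 skip any physical line that carries this marker. A minimal sketch of the mechanism (the URL is an illustrative placeholder, not from the diff):

# The trailing "# noqa" marker asks style checkers (e.g. flake8) to
# skip this physical line, so its length is not reported as E501.
DATASET_URL = 'https://example.com/a/very/long/path/that/easily/exceeds/the/79/character/limit.tar.gz'  # noqa
print(DATASET_URL)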
examples/babi_memnn.py
@@ -17,7 +17,8 @@ from __future__ import print_function

 from keras.models import Sequential, Model
 from keras.layers.embeddings import Embedding
-from keras.layers import Input, Activation, Dense, Permute, Dropout, add, dot, concatenate
+from keras.layers import Input, Activation, Dense, Permute, Dropout
+from keras.layers import add, dot, concatenate
 from keras.layers import LSTM
 from keras.utils.data_utils import get_file
 from keras.preprocessing.sequence import pad_sequences
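Splitting one overlong `from ... import` into two statements is the simplest E501 fix for imports, and both statements populate the same namespace. A parenthesized form would also work, since brackets allow implicit line continuation; a sketch, not part of the diff:

# Equivalent single-statement form: parentheses let the name list
# wrap across lines without backslashes.
from keras.layers import (Input, Activation, Dense, Permute, Dropout,
                          add, dot, concatenate)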
@@ -79,7 +80,8 @@ def get_stories(f, only_supporting=False, max_length=None):
     '''
     data = parse_stories(f.readlines(), only_supporting=only_supporting)
     flatten = lambda data: reduce(lambda x, y: x + y, data)
-    data = [(flatten(story), q, answer) for story, q, answer in data if not max_length or len(flatten(story)) < max_length]
+    data = [(flatten(story), q, answer) for story, q, answer in data
+            if not max_length or len(flatten(story)) < max_length]
     return data
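The comprehension fix uses the same implicit-continuation rule: inside `[]`, the `for` and `if` clauses may move to their own lines with no backslash. A small self-contained sketch of the pattern:

# Inside brackets a statement may span several physical lines,
# so the filter clause gets its own line under 79 characters.
words = ['alpha', 'beta', 'gamma', 'delta']
short = [w.upper() for w in words
         if len(w) < 5]
print(short)  # ['BETA']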
@@ -94,19 +96,24 @@ def vectorize_stories(data):
             np.array(answers))

 try:
-    path = get_file('babi-tasks-v1-2.tar.gz', origin='https://s3.amazonaws.com/text-datasets/babi_tasks_1-20_v1-2.tar.gz')
+    path = get_file('babi-tasks-v1-2.tar.gz',
+                    origin='https://s3.amazonaws.com/text-datasets/'
+                           'babi_tasks_1-20_v1-2.tar.gz')
 except:
     print('Error downloading dataset, please download it manually:\n'
-          '$ wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz\n'
+          '$ wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2'
+          '.tar.gz\n'
          '$ mv tasks_1-20_v1-2.tar.gz ~/.keras/datasets/babi-tasks-v1-2.tar.gz')
     raise

 challenges = {
     # QA1 with 10,000 samples
-    'single_supporting_fact_10k': 'tasks_1-20_v1-2/en-10k/qa1_single-supporting-fact_{}.txt',
+    'single_supporting_fact_10k': 'tasks_1-20_v1-2/en-10k/qa1_'
+                                  'single-supporting-fact_{}.txt',
     # QA2 with 10,000 samples
-    'two_supporting_facts_10k': 'tasks_1-20_v1-2/en-10k/qa2_two-supporting-facts_{}.txt',
+    'two_supporting_facts_10k': 'tasks_1-20_v1-2/en-10k/qa2_'
+                                'two-supporting-facts_{}.txt',
 }
 challenge_type = 'single_supporting_fact_10k'
 challenge = challenges[challenge_type]
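The rewrapped `origin=` argument and the challenge paths rely on adjacent string literal concatenation: Python joins side-by-side literals at compile time, so a long URL splits cleanly with no `+`. A quick sketch:

# Adjacent string literals form a single compile-time constant;
# only the source code is split, not the resulting string.
url = ('https://s3.amazonaws.com/text-datasets/'
       'babi_tasks_1-20_v1-2.tar.gz')
print(len(url))  # length of the one unbroken URL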
examples/babi_rnn.py
@@ -123,7 +123,8 @@ def get_stories(f, only_supporting=False, max_length=None):
     '''
     data = parse_stories(f.readlines(), only_supporting=only_supporting)
     flatten = lambda data: reduce(lambda x, y: x + y, data)
-    data = [(flatten(story), q, answer) for story, q, answer in data if not max_length or len(flatten(story)) < max_length]
+    data = [(flatten(story), q, answer) for story, q, answer in data
+            if not max_length or len(flatten(story)) < max_length]
     return data
@@ -140,7 +141,8 @@ def vectorize_stories(data, word_idx, story_maxlen, query_maxlen):
         xs.append(x)
         xqs.append(xq)
         ys.append(y)
-    return pad_sequences(xs, maxlen=story_maxlen), pad_sequences(xqs, maxlen=query_maxlen), np.array(ys)
+    return (pad_sequences(xs, maxlen=story_maxlen),
+            pad_sequences(xqs, maxlen=query_maxlen), np.array(ys))

 RNN = recurrent.LSTM
 EMBED_HIDDEN_SIZE = 50
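Parenthesizing the returned tuple changes nothing semantically (the commas build the tuple either way) but lets it wrap across lines. A minimal sketch:

def span(xs):
    # The parentheses exist only so the tuple can wrap; the commas,
    # not the parentheses, are what build the tuple.
    return (min(xs),
            max(xs), sum(xs))

print(span([1, 2, 3]))  # (1, 3, 6)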
@@ -154,10 +156,13 @@ print('RNN / Embed / Sent / Query = {}, {}, {}, {}'.format(RNN,
                                                            QUERY_HIDDEN_SIZE))

 try:
-    path = get_file('babi-tasks-v1-2.tar.gz', origin='https://s3.amazonaws.com/text-datasets/babi_tasks_1-20_v1-2.tar.gz')
+    path = get_file('babi-tasks-v1-2.tar.gz',
+                    origin='https://s3.amazonaws.com/text-datasets/'
+                           'babi_tasks_1-20_v1-2.tar.gz')
 except:
     print('Error downloading dataset, please download it manually:\n'
-          '$ wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz\n'
+          '$ wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2'
+          '.tar.gz\n'
           '$ mv tasks_1-20_v1-2.tar.gz ~/.keras/datasets/babi-tasks-v1-2.tar.gz')
     raise
examples/cifar10_cnn.py
@@ -85,19 +85,26 @@ else:
         zca_whitening=False,  # apply ZCA whitening
         zca_epsilon=1e-06,  # epsilon for ZCA whitening
         rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
-        width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
-        height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
+        # randomly shift images horizontally (fraction of total width)
+        width_shift_range=0.1,
+        # randomly shift images vertically (fraction of total height)
+        height_shift_range=0.1,
         shear_range=0.,  # set range for random shear
         zoom_range=0.,  # set range for random zoom
         channel_shift_range=0.,  # set range for random channel shifts
-        fill_mode='nearest',  # set mode for filling points outside the input boundaries
+        # set mode for filling points outside the input boundaries
+        fill_mode='nearest',
         cval=0.,  # value used for fill_mode = "constant"
         horizontal_flip=True,  # randomly flip images
         vertical_flip=False,  # randomly flip images
-        rescale=None,  # set rescaling factor (applied before any other transformation)
-        preprocessing_function=None,  # set function that will be applied on each input
-        data_format=None,  # image data format, either "channels_first" or "channels_last"
-        validation_split=0.0)  # fraction of images reserved for validation (strictly between 0 and 1)
+        # set rescaling factor (applied before any other transformation)
+        rescale=None,
+        # set function that will be applied on each input
+        preprocessing_function=None,
+        # image data format, either "channels_first" or "channels_last"
+        data_format=None,
+        # fraction of images reserved for validation (strictly between 0 and 1)
+        validation_split=0.0)

     # Compute quantities required for feature-wise normalization
     # (std, mean, and principal components if ZCA whitening is applied).
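When an inline comment pushes a keyword argument past 79 characters, the diff moves the comment onto its own line above the argument, keeping both intact. The same pattern on a generic call; `configure` and its parameters are hypothetical, for illustration only:

# Hypothetical helper standing in for any call with long keyword
# arguments; only the comment placement is the point here.
def configure(**kwargs):
    return kwargs

settings = configure(
    # fraction of images reserved for validation (strictly between 0 and 1)
    validation_split=0.1,
    rescale=None,  # short comments can stay inline
)
print(settings)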
examples/cifar10_cnn_capsule.py
@@ -211,14 +211,19 @@ else:
         shear_range=0.,  # set range for random shear
         zoom_range=0.,  # set range for random zoom
         channel_shift_range=0.,  # set range for random channel shifts
-        fill_mode='nearest',  # set mode for filling points outside the input boundaries
+        # set mode for filling points outside the input boundaries
+        fill_mode='nearest',
         cval=0.,  # value used for fill_mode = "constant"
         horizontal_flip=True,  # randomly flip images
         vertical_flip=False,  # randomly flip images
-        rescale=None,  # set rescaling factor (applied before any other transformation)
-        preprocessing_function=None,  # set function that will be applied on each input
-        data_format=None,  # image data format, either "channels_first" or "channels_last"
-        validation_split=0.0)  # fraction of images reserved for validation (strictly between 0 and 1)
+        # set rescaling factor (applied before any other transformation)
+        rescale=None,
+        # set function that will be applied on each input
+        preprocessing_function=None,
+        # image data format, either "channels_first" or "channels_last"
+        data_format=None,
+        # fraction of images reserved for validation (strictly between 0 and 1)
+        validation_split=0.0)

     # Compute quantities required for feature-wise normalization
     # (std, mean, and principal components if ZCA whitening is applied).
pytest.ini
@@ -14,11 +14,6 @@ norecursedirs= build
 pep8ignore=* E402 \
            * E731 \
            examples/addition_rnn.py E501 \
-           examples/babi_memnn.py E501 \
-           examples/babi_rnn.py E501 \
-           examples/cifar10_cnn.py E501 \
-           examples/cifar10_cnn_capsule.py E501 \
-           examples/conv_filter_visualization.py E501 \
            examples/deep_dream.py E501 \
            examples/image_ocr.py E501 \
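Each `pep8ignore` entry removed here re-enables the E501 (line too long) check for that file, so the pep8 test run now enforces the 79-character limit on the freshly wrapped examples. The check itself is simple; a toy stand-in in plain Python (the target path is illustrative):

# Toy version of the E501 check: report physical lines whose length
# exceeds the default 79-character limit.
MAX_LINE = 79

def long_lines(path):
    # Report (line number, length) for every line over the limit.
    with open(path) as f:
        return [(n, len(line.rstrip('\n')))
                for n, line in enumerate(f, start=1)
                if len(line.rstrip('\n')) > MAX_LINE]

for n, width in long_lines('examples/babi_memnn.py'):
    print('line {}: E501 line too long ({} > {})'.format(n, width, MAX_LINE))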