提交 03fd5f6b 编写于 作者: Z Zhaolong Xing 提交者: GitHub

Merge pull request #2686 from qingqing01/row_conv_fix

Fix bug for flowers dataset and row_conv.
......@@ -2082,10 +2082,10 @@ class MaxOutLayer(LayerBase):
class RowConvLayer(LayerBase):
def __init__(self, name, inputs, context_length, **xargs):
super(RowConvLayer, self).__init__(
name, 'maxout', 0, inputs=inputs, **xargs)
name, 'row_conv', 0, inputs=inputs, **xargs)
config_assert(
len(self.inputs) == 1,
'TransLayer must have one and only one input')
'row convolution layer must have one and only one input.')
input_layer = self.get_input_layer(0)
row_conv_conf = self.config.inputs[0].row_conv_conf
row_conv_conf.context_length = context_length
......
......@@ -7,7 +7,7 @@ layers {
}
layers {
name: "__row_conv_layer_0__"
type: "maxout"
type: "row_conv"
size: 2560
active_type: "relu"
inputs {
......
......@@ -30,6 +30,7 @@ http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.
"""
import cPickle
import itertools
import functools
from common import download
import tarfile
import scipy.io as scio
......@@ -54,21 +55,26 @@ TEST_FLAG = 'trnid'
VALID_FLAG = 'valid'
def default_mapper(sample):
def default_mapper(is_train, sample):
'''
map image bytes data to type needed by model input layer
'''
img, label = sample
img = load_image_bytes(img)
img = simple_transform(img, 256, 224, True)
img = simple_transform(
img, 256, 224, is_train, mean=[103.94, 116.78, 123.68])
return img.flatten().astype('float32'), label
train_mapper = functools.partial(default_mapper, True)
test_mapper = functools.partial(default_mapper, False)
def reader_creator(data_file,
label_file,
setid_file,
dataset_name,
mapper=default_mapper,
mapper,
buffered_size=1024,
use_xmap=True):
'''
......@@ -118,7 +124,7 @@ def reader_creator(data_file,
return map_readers(mapper, reader)
def train(mapper=default_mapper, buffered_size=1024, use_xmap=True):
def train(mapper=train_mapper, buffered_size=1024, use_xmap=True):
'''
Create flowers training set reader.
It returns a reader, each sample in the reader is
......@@ -141,7 +147,7 @@ def train(mapper=default_mapper, buffered_size=1024, use_xmap=True):
buffered_size, use_xmap)
def test(mapper=default_mapper, buffered_size=1024, use_xmap=True):
def test(mapper=test_mapper, buffered_size=1024, use_xmap=True):
'''
Create flowers test set reader.
It returns a reader, each sample in the reader is
......@@ -164,7 +170,7 @@ def test(mapper=default_mapper, buffered_size=1024, use_xmap=True):
buffered_size, use_xmap)
def valid(mapper=default_mapper, buffered_size=1024, use_xmap=True):
def valid(mapper=test_mapper, buffered_size=1024, use_xmap=True):
'''
Create flowers validation set reader.
It returns a reader, each sample in the reader is
......
......@@ -262,7 +262,12 @@ def left_right_flip(im):
return im[:, ::-1, :]
def simple_transform(im, resize_size, crop_size, is_train, is_color=True):
def simple_transform(im,
resize_size,
crop_size,
is_train,
is_color=True,
mean=None):
"""
Simply data argumentation for training. These operations include
resizing, croping and flipping.
......@@ -288,7 +293,19 @@ def simple_transform(im, resize_size, crop_size, is_train, is_color=True):
im = left_right_flip(im)
else:
im = center_crop(im, crop_size)
im = to_chw(im)
if len(im.shape) == 3:
im = to_chw(im)
im = im.astype('float32')
if mean is not None:
mean = np.array(mean, dtype=np.float32)
# mean value, may be one value per channel
if mean.ndim == 1:
mean = mean[:, np.newaxis, np.newaxis]
else:
# elementwise mean
assert len(mean.shape) == len(im)
im -= mean
return im
......@@ -297,7 +314,8 @@ def load_and_transform(filename,
resize_size,
crop_size,
is_train,
is_color=True):
is_color=True,
mean=None):
"""
Load image from the input file `filename` and transform image for
data argumentation. Please refer to the `simple_transform` interface
......@@ -318,5 +336,5 @@ def load_and_transform(filename,
:type is_train: bool
"""
im = load_image(filename)
im = simple_transform(im, resize_size, crop_size, is_train, is_color)
im = simple_transform(im, resize_size, crop_size, is_train, is_color, mean)
return im
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册