提交 03fd5f6b 编写于 作者: Z Zhaolong Xing 提交者: GitHub

Merge pull request #2686 from qingqing01/row_conv_fix

Fix bug for flowers dataset and row_conv.
...@@ -2082,10 +2082,10 @@ class MaxOutLayer(LayerBase): ...@@ -2082,10 +2082,10 @@ class MaxOutLayer(LayerBase):
class RowConvLayer(LayerBase): class RowConvLayer(LayerBase):
def __init__(self, name, inputs, context_length, **xargs): def __init__(self, name, inputs, context_length, **xargs):
super(RowConvLayer, self).__init__( super(RowConvLayer, self).__init__(
name, 'maxout', 0, inputs=inputs, **xargs) name, 'row_conv', 0, inputs=inputs, **xargs)
config_assert( config_assert(
len(self.inputs) == 1, len(self.inputs) == 1,
'TransLayer must have one and only one input') 'row convolution layer must have one and only one input.')
input_layer = self.get_input_layer(0) input_layer = self.get_input_layer(0)
row_conv_conf = self.config.inputs[0].row_conv_conf row_conv_conf = self.config.inputs[0].row_conv_conf
row_conv_conf.context_length = context_length row_conv_conf.context_length = context_length
......
...@@ -7,7 +7,7 @@ layers { ...@@ -7,7 +7,7 @@ layers {
} }
layers { layers {
name: "__row_conv_layer_0__" name: "__row_conv_layer_0__"
type: "maxout" type: "row_conv"
size: 2560 size: 2560
active_type: "relu" active_type: "relu"
inputs { inputs {
......
...@@ -30,6 +30,7 @@ http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}. ...@@ -30,6 +30,7 @@ http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.
""" """
import cPickle import cPickle
import itertools import itertools
import functools
from common import download from common import download
import tarfile import tarfile
import scipy.io as scio import scipy.io as scio
...@@ -54,21 +55,26 @@ TEST_FLAG = 'trnid' ...@@ -54,21 +55,26 @@ TEST_FLAG = 'trnid'
VALID_FLAG = 'valid' VALID_FLAG = 'valid'
def default_mapper(sample): def default_mapper(is_train, sample):
''' '''
map image bytes data to type needed by model input layer map image bytes data to type needed by model input layer
''' '''
img, label = sample img, label = sample
img = load_image_bytes(img) img = load_image_bytes(img)
img = simple_transform(img, 256, 224, True) img = simple_transform(
img, 256, 224, is_train, mean=[103.94, 116.78, 123.68])
return img.flatten().astype('float32'), label return img.flatten().astype('float32'), label
train_mapper = functools.partial(default_mapper, True)
test_mapper = functools.partial(default_mapper, False)
def reader_creator(data_file, def reader_creator(data_file,
label_file, label_file,
setid_file, setid_file,
dataset_name, dataset_name,
mapper=default_mapper, mapper,
buffered_size=1024, buffered_size=1024,
use_xmap=True): use_xmap=True):
''' '''
...@@ -118,7 +124,7 @@ def reader_creator(data_file, ...@@ -118,7 +124,7 @@ def reader_creator(data_file,
return map_readers(mapper, reader) return map_readers(mapper, reader)
def train(mapper=default_mapper, buffered_size=1024, use_xmap=True): def train(mapper=train_mapper, buffered_size=1024, use_xmap=True):
''' '''
Create flowers training set reader. Create flowers training set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
...@@ -141,7 +147,7 @@ def train(mapper=default_mapper, buffered_size=1024, use_xmap=True): ...@@ -141,7 +147,7 @@ def train(mapper=default_mapper, buffered_size=1024, use_xmap=True):
buffered_size, use_xmap) buffered_size, use_xmap)
def test(mapper=default_mapper, buffered_size=1024, use_xmap=True): def test(mapper=test_mapper, buffered_size=1024, use_xmap=True):
''' '''
Create flowers test set reader. Create flowers test set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
...@@ -164,7 +170,7 @@ def test(mapper=default_mapper, buffered_size=1024, use_xmap=True): ...@@ -164,7 +170,7 @@ def test(mapper=default_mapper, buffered_size=1024, use_xmap=True):
buffered_size, use_xmap) buffered_size, use_xmap)
def valid(mapper=default_mapper, buffered_size=1024, use_xmap=True): def valid(mapper=test_mapper, buffered_size=1024, use_xmap=True):
''' '''
Create flowers validation set reader. Create flowers validation set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
......
...@@ -262,7 +262,12 @@ def left_right_flip(im): ...@@ -262,7 +262,12 @@ def left_right_flip(im):
return im[:, ::-1, :] return im[:, ::-1, :]
def simple_transform(im, resize_size, crop_size, is_train, is_color=True): def simple_transform(im,
resize_size,
crop_size,
is_train,
is_color=True,
mean=None):
""" """
Simply data argumentation for training. These operations include Simply data argumentation for training. These operations include
resizing, croping and flipping. resizing, croping and flipping.
...@@ -288,7 +293,19 @@ def simple_transform(im, resize_size, crop_size, is_train, is_color=True): ...@@ -288,7 +293,19 @@ def simple_transform(im, resize_size, crop_size, is_train, is_color=True):
im = left_right_flip(im) im = left_right_flip(im)
else: else:
im = center_crop(im, crop_size) im = center_crop(im, crop_size)
im = to_chw(im) if len(im.shape) == 3:
im = to_chw(im)
im = im.astype('float32')
if mean is not None:
mean = np.array(mean, dtype=np.float32)
# mean value, may be one value per channel
if mean.ndim == 1:
mean = mean[:, np.newaxis, np.newaxis]
else:
# elementwise mean
assert len(mean.shape) == len(im)
im -= mean
return im return im
...@@ -297,7 +314,8 @@ def load_and_transform(filename, ...@@ -297,7 +314,8 @@ def load_and_transform(filename,
resize_size, resize_size,
crop_size, crop_size,
is_train, is_train,
is_color=True): is_color=True,
mean=None):
""" """
Load image from the input file `filename` and transform image for Load image from the input file `filename` and transform image for
data argumentation. Please refer to the `simple_transform` interface data argumentation. Please refer to the `simple_transform` interface
...@@ -318,5 +336,5 @@ def load_and_transform(filename, ...@@ -318,5 +336,5 @@ def load_and_transform(filename,
:type is_train: bool :type is_train: bool
""" """
im = load_image(filename) im = load_image(filename)
im = simple_transform(im, resize_size, crop_size, is_train, is_color) im = simple_transform(im, resize_size, crop_size, is_train, is_color, mean)
return im return im
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册