提交 0c70f34c 编写于 作者: D dangqingqing

Fix bug for flowers dataset and row_conv.

上级 c5dc0b73
......@@ -2082,10 +2082,10 @@ class MaxOutLayer(LayerBase):
class RowConvLayer(LayerBase):
def __init__(self, name, inputs, context_length, **xargs):
super(RowConvLayer, self).__init__(
name, 'maxout', 0, inputs=inputs, **xargs)
name, 'row_conv', 0, inputs=inputs, **xargs)
config_assert(
len(self.inputs) == 1,
'TransLayer must have one and only one input')
'row convolution layer must have one and only one input.')
input_layer = self.get_input_layer(0)
row_conv_conf = self.config.inputs[0].row_conv_conf
row_conv_conf.context_length = context_length
......
......@@ -7,7 +7,7 @@ layers {
}
layers {
name: "__row_conv_layer_0__"
type: "maxout"
type: "row_conv"
size: 2560
active_type: "relu"
inputs {
......
......@@ -30,6 +30,7 @@ http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.
"""
import cPickle
import itertools
import functools
from common import download
import tarfile
import scipy.io as scio
......@@ -54,21 +55,25 @@ TEST_FLAG = 'trnid'
VALID_FLAG = 'valid'
def default_mapper(sample):
def default_mapper(is_train, sample):
'''
map image bytes data to type needed by model input layer
'''
img, label = sample
img = load_image_bytes(img)
img = simple_transform(img, 256, 224, True)
img = simple_transform(img, 256, 224, is_train)
return img.flatten().astype('float32'), label
train_mapper = functools.partial(default_mapper, True)
test_mapper = functools.partial(default_mapper, False)
def reader_creator(data_file,
label_file,
setid_file,
dataset_name,
mapper=default_mapper,
mapper,
buffered_size=1024,
use_xmap=True):
'''
......@@ -118,7 +123,7 @@ def reader_creator(data_file,
return map_readers(mapper, reader)
def train(mapper=default_mapper, buffered_size=1024, use_xmap=True):
def train(mapper=train_mapper, buffered_size=1024, use_xmap=True):
'''
Create flowers training set reader.
It returns a reader, each sample in the reader is
......@@ -141,7 +146,7 @@ def train(mapper=default_mapper, buffered_size=1024, use_xmap=True):
buffered_size, use_xmap)
def test(mapper=default_mapper, buffered_size=1024, use_xmap=True):
def test(mapper=test_mapper, buffered_size=1024, use_xmap=True):
'''
Create flowers test set reader.
It returns a reader, each sample in the reader is
......@@ -164,7 +169,7 @@ def test(mapper=default_mapper, buffered_size=1024, use_xmap=True):
buffered_size, use_xmap)
def valid(mapper=default_mapper, buffered_size=1024, use_xmap=True):
def valid(mapper=test_mapper, buffered_size=1024, use_xmap=True):
'''
Create flowers validation set reader.
It returns a reader, each sample in the reader is
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册