提交 7293c821 编写于 作者: Y Yu Yang

Merge branch 'feature/clean_mnist_v2' into feature/tester

...@@ -57,7 +57,7 @@ before_install: ...@@ -57,7 +57,7 @@ before_install:
- if [[ "$JOB" == "PRE_COMMIT" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi - if [[ "$JOB" == "PRE_COMMIT" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi
# Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python # Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python
# protobuf version. # protobuf version.
- pip install numpy wheel 'protobuf==3.1' sphinx recommonmark sphinx_rtd_theme virtualenv pre-commit requests==2.9.2 LinkChecker 'scikit-learn>=0.18.0' 'scipy>=0.18.0' - pip install numpy wheel 'protobuf==3.1' sphinx recommonmark sphinx_rtd_theme virtualenv pre-commit requests==2.9.2 LinkChecker
script: script:
- paddle/scripts/travis/main.sh - paddle/scripts/travis/main.sh
notifications: notifications:
......
import numpy
import paddle.v2 as paddle import paddle.v2 as paddle
...@@ -41,7 +40,7 @@ def main(): ...@@ -41,7 +40,7 @@ def main():
trainer.train( trainer.train(
reader=paddle.reader.batched( reader=paddle.reader.batched(
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.mnist.train_creator(), buf_size=8192), paddle.dataset.mnist.train(), buf_size=8192),
batch_size=32), batch_size=32),
event_handler=event_handler) event_handler=event_handler)
......
...@@ -18,6 +18,7 @@ import parameters ...@@ -18,6 +18,7 @@ import parameters
import trainer import trainer
import event import event
import data_type import data_type
import topology
import data_feeder import data_feeder
from . import dataset from . import dataset
from . import reader from . import reader
...@@ -27,7 +28,8 @@ import py_paddle.swig_paddle as api ...@@ -27,7 +28,8 @@ import py_paddle.swig_paddle as api
__all__ = [ __all__ = [
'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer', 'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer',
'event', 'data_type', 'attr', 'pooling', 'data_feeder', 'dataset', 'reader' 'event', 'data_type', 'attr', 'pooling', 'data_feeder', 'dataset', 'reader',
'topology'
] ]
......
...@@ -23,7 +23,7 @@ class DataFeeder(DataProviderConverter): ...@@ -23,7 +23,7 @@ class DataFeeder(DataProviderConverter):
""" """
DataFeeder converts the data returned by paddle.reader into a data structure DataFeeder converts the data returned by paddle.reader into a data structure
of Arguments which is defined in the API. The paddle.reader usually returns of Arguments which is defined in the API. The paddle.reader usually returns
a list of mini-batch data entries. Each data entry in the list is one sampe. a list of mini-batch data entries. Each data entry in the list is one sample.
Each sample is a list or a tuple with one feature or multiple features. Each sample is a list or a tuple with one feature or multiple features.
DataFeeder converts this mini-batch data entries into Arguments in order DataFeeder converts this mini-batch data entries into Arguments in order
to feed it to C++ interface. to feed it to C++ interface.
......
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
# limitations under the License. # limitations under the License.
from paddle.trainer.PyDataProvider2 import \ from paddle.trainer.PyDataProvider2 import \
InputType, dense_vector, sparse_binary_vector,\ InputType, DataType, dense_vector, sparse_binary_vector,\
sparse_vector, integer_value, integer_value_sequence sparse_vector, integer_value, integer_value_sequence
__all__ = [ __all__ = [
'InputType', 'dense_vector', 'sparse_binary_vector', 'sparse_vector', 'InputType', 'DataType', 'dense_vector', 'sparse_binary_vector',
'integer_value', 'integer_value_sequence' 'sparse_vector', 'integer_value', 'integer_value_sequence'
] ]
""" """
CIFAR Dataset. CIFAR dataset: https://www.cs.toronto.edu/~kriz/cifar.html
URL: https://www.cs.toronto.edu/~kriz/cifar.html
the default train_creator, test_creator used for CIFAR-10 dataset.
""" """
import cPickle import cPickle
import itertools import itertools
import tarfile
import numpy import numpy
import paddle.v2.dataset.common
import tarfile
from common import download __all__ = ['train100', 'test100', 'train10', 'test10']
__all__ = [
'cifar_100_train_creator', 'cifar_100_test_creator', 'train_creator',
'test_creator'
]
CIFAR10_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' URL_PREFIX = 'https://www.cs.toronto.edu/~kriz/'
CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a' CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
CIFAR100_URL = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' CIFAR100_URL = URL_PREFIX + 'cifar-100-python.tar.gz'
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85' CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
def __read_batch__(filename, sub_name): def reader_creator(filename, sub_name):
def reader(): def read_batch(batch):
def __read_one_batch_impl__(batch):
data = batch['data'] data = batch['data']
labels = batch.get('labels', batch.get('fine_labels', None)) labels = batch.get('labels', batch.get('fine_labels', None))
assert labels is not None assert labels is not None
for sample, label in itertools.izip(data, labels): for sample, label in itertools.izip(data, labels):
yield (sample / 255.0).astype(numpy.float32), int(label) yield (sample / 255.0).astype(numpy.float32), int(label)
def reader():
with tarfile.open(filename, mode='r') as f: with tarfile.open(filename, mode='r') as f:
names = (each_item.name for each_item in f names = (each_item.name for each_item in f
if sub_name in each_item.name) if sub_name in each_item.name)
for name in names: for name in names:
batch = cPickle.load(f.extractfile(name)) batch = cPickle.load(f.extractfile(name))
for item in __read_one_batch_impl__(batch): for item in read_batch(batch):
yield item yield item
return reader return reader
def cifar_100_train_creator(): def train100():
fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5) return reader_creator(
return __read_batch__(fn, 'train') paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
'train')
def cifar_100_test_creator():
fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5)
return __read_batch__(fn, 'test')
def train_creator():
"""
Default train reader creator. Use CIFAR-10 dataset.
"""
fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5)
return __read_batch__(fn, 'data_batch')
def test_creator(): def test100():
""" return reader_creator(
Default test reader creator. Use CIFAR-10 dataset. paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
""" 'test')
fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5)
return __read_batch__(fn, 'test_batch')
def unittest(): def train10():
for _ in train_creator()(): return reader_creator(
pass paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
for _ in test_creator()(): 'data_batch')
pass
if __name__ == '__main__': def test10():
unittest() return reader_creator(
paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
'test_batch')
...@@ -27,7 +27,6 @@ def download(url, module_name, md5sum): ...@@ -27,7 +27,6 @@ def download(url, module_name, md5sum):
filename = os.path.join(dirname, url.split('/')[-1]) filename = os.path.join(dirname, url.split('/')[-1])
if not (os.path.exists(filename) and md5file(filename) == md5sum): if not (os.path.exists(filename) and md5file(filename) == md5sum):
# If file doesn't exist or MD5 doesn't match, then download.
r = requests.get(url, stream=True) r = requests.get(url, stream=True)
with open(filename, 'w') as f: with open(filename, 'w') as f:
shutil.copyfileobj(r.raw, f) shutil.copyfileobj(r.raw, f)
......
"""
MNIST dataset.
"""
import paddle.v2.dataset.common import paddle.v2.dataset.common
import subprocess import subprocess
import numpy import numpy
import platform
__all__ = ['train', 'test'] __all__ = ['train', 'test']
URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/' URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz' TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6' TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6'
TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz' TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
...@@ -18,12 +20,19 @@ TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432' ...@@ -18,12 +20,19 @@ TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432'
def reader_creator(image_filename, label_filename, buffer_size): def reader_creator(image_filename, label_filename, buffer_size):
def reader(): def reader():
if platform.system() == 'Darwin':
zcat_cmd = 'gzcat'
elif platform.system() == 'Linux':
zcat_cmd = 'zcat'
else:
raise NotImplementedError()
# According to http://stackoverflow.com/a/38061619/724872, we # According to http://stackoverflow.com/a/38061619/724872, we
# cannot use standard package gzip here. # cannot use standard package gzip here.
m = subprocess.Popen(["zcat", image_filename], stdout=subprocess.PIPE) m = subprocess.Popen([zcat_cmd, image_filename], stdout=subprocess.PIPE)
m.stdout.read(16) # skip some magic bytes m.stdout.read(16) # skip some magic bytes
l = subprocess.Popen(["zcat", label_filename], stdout=subprocess.PIPE) l = subprocess.Popen([zcat_cmd, label_filename], stdout=subprocess.PIPE)
l.stdout.read(8) # skip some magic bytes l.stdout.read(8) # skip some magic bytes
while True: while True:
...@@ -40,12 +49,12 @@ def reader_creator(image_filename, label_filename, buffer_size): ...@@ -40,12 +49,12 @@ def reader_creator(image_filename, label_filename, buffer_size):
images = images / 255.0 * 2.0 - 1.0 images = images / 255.0 * 2.0 - 1.0
for i in xrange(buffer_size): for i in xrange(buffer_size):
yield images[i, :], labels[i] yield images[i, :], int(labels[i])
m.terminate() m.terminate()
l.terminate() l.terminate()
return reader() return reader
def train(): def train():
......
import paddle.v2.dataset.cifar
import unittest
class TestCIFAR(unittest.TestCase):
def check_reader(self, reader):
sum = 0
label = 0
for l in reader():
self.assertEqual(l[0].size, 3072)
if l[1] > label:
label = l[1]
sum += 1
return sum, label
def test_test10(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.cifar.test10())
self.assertEqual(instances, 10000)
self.assertEqual(max_label_value, 9)
def test_train10(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.cifar.train10())
self.assertEqual(instances, 50000)
self.assertEqual(max_label_value, 9)
def test_test100(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.cifar.test100())
self.assertEqual(instances, 10000)
self.assertEqual(max_label_value, 99)
def test_train100(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.cifar.train100())
self.assertEqual(instances, 50000)
self.assertEqual(max_label_value, 99)
if __name__ == '__main__':
unittest.main()
...@@ -5,21 +5,25 @@ import unittest ...@@ -5,21 +5,25 @@ import unittest
class TestMNIST(unittest.TestCase): class TestMNIST(unittest.TestCase):
def check_reader(self, reader): def check_reader(self, reader):
sum = 0 sum = 0
for l in reader: label = 0
for l in reader():
self.assertEqual(l[0].size, 784) self.assertEqual(l[0].size, 784)
self.assertEqual(l[1].size, 1) if l[1] > label:
self.assertLess(l[1], 10) label = l[1]
self.assertGreaterEqual(l[1], 0)
sum += 1 sum += 1
return sum return sum, label
def test_train(self): def test_train(self):
self.assertEqual( instances, max_label_value = self.check_reader(
self.check_reader(paddle.v2.dataset.mnist.train()), 60000) paddle.v2.dataset.mnist.train())
self.assertEqual(instances, 60000)
self.assertEqual(max_label_value, 9)
def test_test(self): def test_test(self):
self.assertEqual( instances, max_label_value = self.check_reader(
self.check_reader(paddle.v2.dataset.mnist.test()), 10000) paddle.v2.dataset.mnist.test())
self.assertEqual(instances, 10000)
self.assertEqual(max_label_value, 9)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -284,6 +284,7 @@ def mixed(size=0, ...@@ -284,6 +284,7 @@ def mixed(size=0,
return MixedLayerV2(size, input, name, act, bias_attr, layer_attr) return MixedLayerV2(size, input, name, act, bias_attr, layer_attr)
LayerV2 = Layer
data = DataLayerV2 data = DataLayerV2
AggregateLevel = conf_helps.layers.AggregateLevel AggregateLevel = conf_helps.layers.AggregateLevel
ExpandLevel = conf_helps.layers.ExpandLevel ExpandLevel = conf_helps.layers.ExpandLevel
......
import numpy as np import numpy as np
from . import layer as v2_layer
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
from paddle.proto.ParameterConfig_pb2 import ParameterConfig from paddle.proto.ParameterConfig_pb2 import ParameterConfig
from topology import Topology
__all__ = ['Parameters', 'create'] __all__ = ['Parameters', 'create']
def create(*layers): def create(layers):
""" """
Create parameter pool by layers. In paddle, layer can be represent a Create parameter pool by topology.
model config.
:param layers: :param layers:
:return: :return:
""" """
for layer in layers: topology = Topology(layers)
if not isinstance(layer, v2_layer.Layer):
raise ValueError(
'create must pass a topologies which type is paddle.layer.Layer')
model_config = v2_layer.parse_network(*layers)
pool = Parameters() pool = Parameters()
for param in model_config.parameters: for param in topology.proto().parameters:
pool.__append_config__(param) pool.__append_config__(param)
return pool return pool
...@@ -224,6 +219,7 @@ class Parameters(object): ...@@ -224,6 +219,7 @@ class Parameters(object):
except ValueError: except ValueError:
# If no such parameter in gradient machine, then don't copy # If no such parameter in gradient machine, then don't copy
pass pass
self.__gradient_machines__.append(gradient_machine) self.__gradient_machines__.append(gradient_machine)
......
...@@ -2,5 +2,11 @@ add_test(NAME test_v2_layer ...@@ -2,5 +2,11 @@ add_test(NAME test_v2_layer
COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/ COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/
${PYTHON_EXECUTABLE} ${PROJ_ROOT}/python/paddle/v2/tests/test_layer.py ${PYTHON_EXECUTABLE} ${PROJ_ROOT}/python/paddle/v2/tests/test_layer.py
WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle) WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle)
add_test(NAME test_v2_api add_test(NAME test_v2_api
COMMAND bash ${PROJ_ROOT}/python/paddle/v2/tests/run_tests.sh ${PYTHON_EXECUTABLE}) COMMAND bash ${PROJ_ROOT}/python/paddle/v2/tests/run_tests.sh ${PYTHON_EXECUTABLE})
add_test(NAME topology_test
COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/
${PYTHON_EXECUTABLE} ${PROJ_ROOT}/python/paddle/v2/tests/test_topology.py
WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle)
# Copyright PaddlePaddle contributors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.v2.layer as layer
import paddle.v2.topology as topology
import paddle.v2.data_type as data_type
import paddle.trainer_config_helpers as conf_helps
class TestTopology(unittest.TestCase):
def test_data_type(self):
pixel = layer.data(name='pixel', type=data_type.dense_vector(784))
label = layer.data(name='label', type=data_type.integer_value(10))
hidden = layer.fc(input=pixel,
size=100,
act=conf_helps.SigmoidActivation())
inference = layer.fc(input=hidden,
size=10,
act=conf_helps.SoftmaxActivation())
cost = layer.classification_cost(input=inference, label=label)
topo = topology.Topology(cost)
data_types = topo.data_type()
self.assertEqual(len(data_types), 2)
pixel_data_type = filter(lambda type: type[0] == "pixel", data_types)
self.assertEqual(len(pixel_data_type), 1)
pixel_data_type = pixel_data_type[0]
self.assertEqual(pixel_data_type[1].type, data_type.DataType.Dense)
self.assertEqual(pixel_data_type[1].dim, 784)
label_data_type = filter(lambda type: type[0] == "label", data_types)
self.assertEqual(len(label_data_type), 1)
label_data_type = label_data_type[0]
self.assertEqual(label_data_type[1].type, data_type.DataType.Index)
self.assertEqual(label_data_type[1].dim, 10)
def test_get_layer(self):
pixel = layer.data(name='pixel', type=data_type.dense_vector(784))
label = layer.data(name='label', type=data_type.integer_value(10))
hidden = layer.fc(input=pixel,
size=100,
act=conf_helps.SigmoidActivation())
inference = layer.fc(input=hidden,
size=10,
act=conf_helps.SoftmaxActivation())
cost = layer.classification_cost(input=inference, label=label)
topo = topology.Topology(cost)
pixel_layer = topo.get_layer("pixel")
label_layer = topo.get_layer("label")
self.assertEqual(pixel_layer, pixel)
self.assertEqual(label_layer, label)
def test_parse(self):
pixel = layer.data(name='pixel', type=data_type.dense_vector(784))
label = layer.data(name='label', type=data_type.integer_value(10))
hidden = layer.fc(input=pixel,
size=100,
act=conf_helps.SigmoidActivation())
inference = layer.fc(input=hidden,
size=10,
act=conf_helps.SoftmaxActivation())
maxid = layer.max_id(input=inference)
cost1 = layer.classification_cost(input=inference, label=label)
cost2 = layer.cross_entropy_cost(input=inference, label=label)
topology.Topology(cost2).proto()
topology.Topology([cost1]).proto()
topology.Topology([cost1, cost2]).proto()
topology.Topology([inference, maxid]).proto()
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from paddle.proto.ModelConfig_pb2 import ModelConfig
import layer as v2_layer
__all__ = ['Topology']
def __bfs_travel__(callback, *layers):
for each_layer in layers:
__break__ = callback(each_layer)
if __break__:
return
__bfs_travel__(callback, *each_layer.__parent_layers__.values())
class Topology(object):
"""
Topology is used to store the information about all layers
and network configs.
"""
def __init__(self, layers):
if not isinstance(layers, collections.Sequence):
__check_layer_type__(layers)
layers = [layers]
for layer in layers:
__check_layer_type__(layer)
self.layers = layers
self.__model_config__ = v2_layer.parse_network(*layers)
assert isinstance(self.__model_config__, ModelConfig)
def proto(self):
return self.__model_config__
def get_layer(self, name):
"""
get v2.Layer Class instance by layer name
:param name:
:return:
"""
result_layer = [None]
def __impl__(l):
if l.name == name:
result_layer[0] = l
return True # break
return False
__bfs_travel__(__impl__, *self.layers)
if result_layer[0] is None:
raise ValueError("No such layer %s" % name)
return result_layer[0]
def data_layers(self):
"""
get all data layer
:return:
"""
data_layers = dict()
def __impl__(l):
if isinstance(l, v2_layer.DataLayerV2):
data_layers[l.name] = l
__bfs_travel__(__impl__, *self.layers)
return data_layers
def data_type(self):
"""
get data_type from proto, such as:
[('image', dense_vector(768)), ('label', integer_value(10))]
"""
data_layers = self.data_layers()
return [(nm, data_layers[nm].type)
for nm in self.proto().input_layer_names]
def __check_layer_type__(layer):
if not isinstance(layer, v2_layer.LayerV2):
raise ValueError('layer should have type paddle.layer.Layer')
import collections import collections
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
from paddle.proto.ModelConfig_pb2 import ModelConfig
from data_feeder import DataFeeder
from data_feeder import DataFeeder
from topology import Topology
from . import event as v2_event from . import event as v2_event
from . import layer as v2_layer
from . import optimizer as v2_optimizer from . import optimizer as v2_optimizer
from . import parameters as v2_parameters from . import parameters as v2_parameters
...@@ -23,13 +22,6 @@ def default_event_handler(event): ...@@ -23,13 +22,6 @@ def default_event_handler(event):
pass pass
def __bfs_travel_topology__(callback, *topologies):
for each_layer in topologies:
callback(each_layer)
__bfs_travel_topology__(callback,
*each_layer.__parent_layers__.values())
class ITrainer(object): class ITrainer(object):
""" """
The interface of Trainer. The only exposed method is `train`. The interface of Trainer. The only exposed method is `train`.
...@@ -50,40 +42,26 @@ class ITrainer(object): ...@@ -50,40 +42,26 @@ class ITrainer(object):
class SGD(ITrainer): class SGD(ITrainer):
def __init__(self, topology, parameters, update_equation): def __init__(self, cost, parameters, update_equation):
""" """
Simple SGD Trainer. Simple SGD Trainer.
:param update_equation: The optimizer object. :param update_equation: The optimizer object.
:type update_equation: v2_optimizer.Optimizer :type update_equation: v2_optimizer.Optimizer
""" """
if not isinstance(parameters, v2_parameters.Parameters): if not isinstance(parameters, v2_parameters.Parameters):
raise TypeError('parameters should be parameters') raise TypeError('parameters should be parameters')
if not isinstance(update_equation, v2_optimizer.Optimizer): if not isinstance(update_equation, v2_optimizer.Optimizer):
raise TypeError("update equation parameter must be " raise TypeError("update equation parameter must be "
"paddle.v2.optimizer.Optimizer") "paddle.v2.optimizer.Optimizer")
topology = Topology(cost)
self.__optimizer__ = update_equation self.__optimizer__ = update_equation
self.__topology__ = topology self.__topology__ = topology
self.__parameters__ = parameters self.__parameters__ = parameters
self.__topology_in_proto__ = v2_layer.parse_network(topology) self.__topology_in_proto__ = topology.proto()
data_types = dict() self.__data_types__ = topology.data_layers()
def __travel__(l):
if hasattr(l, 'type'):
data_types[l.name] = l.type
if not isinstance(topology, collections.Sequence):
topology = [topology]
__bfs_travel_topology__(__travel__, *topology)
self.__data_types__ = [
(iname, data_types[iname])
for iname in self.__topology_in_proto__.input_layer_names
]
if not isinstance(self.__topology_in_proto__, ModelConfig):
raise TypeError('topology should be a model config')
gm = api.GradientMachine.createFromConfigProto( gm = api.GradientMachine.createFromConfigProto(
self.__topology_in_proto__, api.CREATE_MODE_NORMAL, self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
self.__optimizer__.enable_types()) self.__optimizer__.enable_types())
...@@ -103,7 +81,6 @@ class SGD(ITrainer): ...@@ -103,7 +81,6 @@ class SGD(ITrainer):
:param event_handler: Event handler. A method will be invoked when event :param event_handler: Event handler. A method will be invoked when event
occurred. occurred.
:type event_handler: (BaseEvent) => None :type event_handler: (BaseEvent) => None
:param data_types: Not important, will be removed after data refactor.
:return: :return:
""" """
if event_handler is None: if event_handler is None:
...@@ -113,6 +90,7 @@ class SGD(ITrainer): ...@@ -113,6 +90,7 @@ class SGD(ITrainer):
reader_dict = self.default_reader_dict() reader_dict = self.default_reader_dict()
__check_train_args__(**locals()) __check_train_args__(**locals())
updater = self.__optimizer__.create_local_updater() updater = self.__optimizer__.create_local_updater()
updater.init(self.__gradient_machine__) updater.init(self.__gradient_machine__)
...@@ -192,6 +170,5 @@ def __check_train_args__(reader, event_handler, **kwargs): ...@@ -192,6 +170,5 @@ def __check_train_args__(reader, event_handler, **kwargs):
if not callable(reader) or not isinstance(reader(), collections.Iterator): if not callable(reader) or not isinstance(reader(), collections.Iterator):
raise TypeError('train_data_reader should be a function, ' raise TypeError('train_data_reader should be a function, '
'which can return a iterator') 'which can return a iterator')
if not callable(event_handler): if not callable(event_handler):
raise TypeError('event handler should be a function') raise TypeError('event handler should be a function')
...@@ -15,9 +15,5 @@ setup(name='paddle', ...@@ -15,9 +15,5 @@ setup(name='paddle',
packages=packages, packages=packages,
package_dir={ package_dir={
'': '${CMAKE_CURRENT_SOURCE_DIR}' '': '${CMAKE_CURRENT_SOURCE_DIR}'
}, }
install_requires = [
'scikit-learn>=0.18.0',
'scipy>=0.18.0',
]
) )
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册