提交 7955e3fc 编写于 作者: L liaogang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into macosx

...@@ -4,6 +4,7 @@ cache: ...@@ -4,6 +4,7 @@ cache:
- $HOME/third_party - $HOME/third_party
- $HOME/.ccache - $HOME/.ccache
- $HOME/.cache/pip - $HOME/.cache/pip
- $HOME/Library/Caches/Homebrew
sudo: required sudo: required
dist: trusty dist: trusty
os: os:
...@@ -54,7 +55,9 @@ before_install: ...@@ -54,7 +55,9 @@ before_install:
fi fi
- if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then paddle/scripts/travis/before_install.osx.sh; fi - if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then paddle/scripts/travis/before_install.osx.sh; fi
- if [[ "$JOB" == "PRE_COMMIT" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi - if [[ "$JOB" == "PRE_COMMIT" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi
- pip install numpy wheel protobuf sphinx recommonmark sphinx_rtd_theme virtualenv pre-commit requests==2.9.2 LinkChecker # Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python
# protobuf version.
- pip install numpy wheel 'protobuf==3.1' sphinx recommonmark sphinx_rtd_theme virtualenv pre-commit requests==2.9.2 LinkChecker
script: script:
- paddle/scripts/travis/main.sh - paddle/scripts/travis/main.sh
notifications: notifications:
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
set(CBLAS_FOUND OFF) set(CBLAS_FOUND OFF)
## Find MKL First. ## Find MKL First.
set(MKL_ROOT $ENV{MKLROOT} CACHE PATH "Folder contains MKL") set(INTEL_ROOT "/opt/intel" CACHE PATH "Folder contains intel libs")
set(MKL_ROOT ${INTEL_ROOT}/mkl CACHE PATH "Folder contains MKL")
find_path(MKL_INCLUDE_DIR mkl.h PATHS find_path(MKL_INCLUDE_DIR mkl.h PATHS
${MKL_ROOT}/include) ${MKL_ROOT}/include)
......
...@@ -6,25 +6,15 @@ passed to C++ side of Paddle. ...@@ -6,25 +6,15 @@ passed to C++ side of Paddle.
The user api could be simpler and carefully designed. The user api could be simpler and carefully designed.
""" """
import py_paddle.swig_paddle as api
from py_paddle import DataProviderConverter
import paddle.trainer.PyDataProvider2 as dp
import numpy as np
import random import random
from mnist_util import read_from_mnist
from paddle.trainer_config_helpers import *
import paddle.v2
import numpy as np
import paddle.v2 as paddle_v2
import py_paddle.swig_paddle as api
from paddle.trainer_config_helpers import *
from py_paddle import DataProviderConverter
def network_config(): from mnist_util import read_from_mnist
imgs = data_layer(name='pixel', size=784)
hidden1 = fc_layer(input=imgs, size=200)
hidden2 = fc_layer(input=hidden1, size=200)
inference = fc_layer(input=hidden2, size=10, act=SoftmaxActivation())
cost = classification_cost(
input=inference, label=data_layer(
name='label', size=10))
outputs(cost)
def init_parameter(network): def init_parameter(network):
...@@ -67,7 +57,7 @@ def input_order_converter(generator): ...@@ -67,7 +57,7 @@ def input_order_converter(generator):
def main(): def main():
api.initPaddle("-use_gpu=false", "-trainer_count=4") # use 4 cpu cores api.initPaddle("-use_gpu=false", "-trainer_count=4") # use 4 cpu cores
optimizer = paddle.v2.optimizer.Adam( optimizer = paddle_v2.optimizer.Adam(
learning_rate=1e-4, learning_rate=1e-4,
batch_size=1000, batch_size=1000,
model_average=ModelAverage(average_window=0.5), model_average=ModelAverage(average_window=0.5),
...@@ -79,8 +69,20 @@ def main(): ...@@ -79,8 +69,20 @@ def main():
updater = optimizer.create_local_updater() updater = optimizer.create_local_updater()
assert isinstance(updater, api.ParameterUpdater) assert isinstance(updater, api.ParameterUpdater)
# define network
images = paddle_v2.layer.data(
name='pixel', type=paddle_v2.data_type.dense_vector(784))
label = paddle_v2.layer.data(
name='label', type=paddle_v2.data_type.integer_value(10))
hidden1 = paddle_v2.layer.fc(input=images, size=200)
hidden2 = paddle_v2.layer.fc(input=hidden1, size=200)
inference = paddle_v2.layer.fc(input=hidden2,
size=10,
act=paddle_v2.activation.Softmax())
cost = paddle_v2.layer.classification_cost(input=inference, label=label)
# Create Simple Gradient Machine. # Create Simple Gradient Machine.
model_config = parse_network_config(network_config) model_config = paddle_v2.layer.parse_network(cost)
m = api.GradientMachine.createFromConfigProto(model_config, m = api.GradientMachine.createFromConfigProto(model_config,
api.CREATE_MODE_NORMAL, api.CREATE_MODE_NORMAL,
optimizer.enable_types()) optimizer.enable_types())
...@@ -97,8 +99,7 @@ def main(): ...@@ -97,8 +99,7 @@ def main():
# DataProvider Converter is a utility convert Python Object to Paddle C++ # DataProvider Converter is a utility convert Python Object to Paddle C++
# Input. The input format is as same as Paddle's DataProvider. # Input. The input format is as same as Paddle's DataProvider.
converter = DataProviderConverter( converter = DataProviderConverter(input_types=[images.type, label.type])
input_types=[dp.dense_vector(784), dp.integer_value(10)])
train_file = './data/raw_data/train' train_file = './data/raw_data/train'
test_file = './data/raw_data/t10k' test_file = './data/raw_data/t10k'
......
import numpy
import paddle.v2 as paddle
import mnist_util
def train_reader():
    """Yield MNIST training samples one at a time.

    Reads ``./data/raw_data/train`` through ``mnist_util.read_from_mnist``
    and re-yields every item that generator produces, in order.
    """
    for sample in mnist_util.read_from_mnist('./data/raw_data/train'):
        yield sample
def main():
paddle.init(use_gpu=False, trainer_count=1)
# define network topology
images = paddle.layer.data(
name='pixel', type=paddle.data_type.dense_vector(784))
label = paddle.layer.data(
name='label', type=paddle.data_type.integer_value(10))
hidden1 = paddle.layer.fc(input=images, size=200)
hidden2 = paddle.layer.fc(input=hidden1, size=200)
inference = paddle.layer.fc(input=hidden2,
size=10,
act=paddle.activation.Softmax())
cost = paddle.layer.classification_cost(input=inference, label=label)
parameters = paddle.parameters.create(cost)
for param_name in parameters.keys():
array = parameters.get(param_name)
array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape)
parameters.set(parameter_name=param_name, value=array)
adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01)
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
para = parameters.get('___fc_2__.w0')
print "Pass %d, Batch %d, Cost %f, Weight Mean Of Fc 2 is %f" % (
event.pass_id, event.batch_id, event.cost, para.mean())
else:
pass
trainer = paddle.trainer.SGD(update_equation=adam_optimizer)
trainer.train(train_data_reader=train_reader,
topology=cost,
parameters=parameters,
event_handler=event_handler,
batch_size=32, # batch size should be refactor in Data reader
data_types={ # data_types will be removed, It should be in
# network topology
'pixel': images.type,
'label': label.type
})
# Run the training demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
...@@ -32,4 +32,6 @@ def process(settings, file_name): ...@@ -32,4 +32,6 @@ def process(settings, file_name):
word_slot = [ word_slot = [
settings.word_dict[w] for w in words if w in settings.word_dict settings.word_dict[w] for w in words if w in settings.word_dict
] ]
if not word_slot:
continue
yield word_slot, label yield word_slot, label
...@@ -138,7 +138,11 @@ def main(): ...@@ -138,7 +138,11 @@ def main():
batch = [] batch = []
for line in sys.stdin: for line in sys.stdin:
batch.append([predict.get_index(line)]) words = predict.get_index(line)
if words:
batch.append([words])
else:
print('All the words in [%s] are not in the dictionary.' % line)
if len(batch) == batch_size: if len(batch) == batch_size:
predict.batch_predict(batch) predict.batch_predict(batch)
batch = [] batch = []
......
...@@ -279,6 +279,12 @@ concat_layer ...@@ -279,6 +279,12 @@ concat_layer
:members: concat_layer :members: concat_layer
:noindex: :noindex:
seq_concat_layer
----------------
.. automodule:: paddle.trainer_config_helpers.layers
:members: seq_concat_layer
:noindex:
Reshaping Layers Reshaping Layers
================ ================
...@@ -302,6 +308,12 @@ repeat_layer ...@@ -302,6 +308,12 @@ repeat_layer
:members: repeat_layer :members: repeat_layer
:noindex: :noindex:
seq_reshape_layer
-----------------
.. automodule:: paddle.trainer_config_helpers.layers
:members: seq_reshape_layer
:noindex:
Math Layers Math Layers
=========== ===========
......
# PaddlePaddle Design Doc
## Ingredients
Our design principle is to start from the essence: how can we
allow users to express and solve their problems as neural networks?
Some essential concepts that our API has to provide include:
1. A *topology* is an expression of *layers*.
1. A layer could be any kind of computation, including *cost*.
1. Some layers have parameters, some don't. Most costs don't have
parameters.
1. In some topologies, layers share parameters. For
example,
[the network for training a ranking model](https://github.com/PaddlePaddle/Paddle/issues/1311#issuecomment-279121850).
1. At programming time, users specify topologies and possible sharing
of parameters. PaddlePaddle can figure out and create parameters
required (and possibly shared) by one or more topologies.
## Starting from Examples
As a summary of
[our discussion](https://github.com/PaddlePaddle/Paddle/issues/1315),
let us present two examples here:
### Example 1. Sharing Parameters between Layers
We use
the
[3-branch ranking](https://github.com/PaddlePaddle/Paddle/issues/1311#issuecomment-279121850) model
in this example. For your convenience, I copy-and-paste the model's
topology as follows:
```
A -> f -\
Q -> f --> cost
B -> f -/
```
The following program trains the topology including the cost, and then
uses the sub-network in the trained topology for inference:
```python
def f(in):
e = paddle.layer.embedding(in, parameter_name="embedding")
o = paddle.layer.softmax(e, parameter_name="semantic")
return o
# Create 3 topologies (subnets), they share parameters because all
# corresponding layers have the same parameter names.
fA = f(paddle.layer.data(input_name="A"))
fB = f(paddle.layer.data(input_name="B"))
fQ = f(paddle.layer.data(input_name="Q"))
topology = paddle.layer.less_than(
paddle.layer.cross_entropy(fA, fQ),
paddle.layer.cross_entropy(fB, fQ))
# Derive parameters required in topology and create them in model.
parameters = paddle.parameters.create(topology)
# Estimate parameters used in topology from data.
paddle.train(topology, parameters, reader=read_ranking_model_data)
# Inference using fA (or fB or fQ, as they share their parameters).
[testA, testB, testQ] = read_ranking_model_data()
print "The semantic-vector of testA: ", paddle.infer(fA, parameters, testA)
```
### Example 2. Sharing Parameters between "Models"
We use [GAN](https://github.com/PaddlePaddle/book/tree/develop/gan) in
this example. In the following example program, `d0` and `d1`
correspond to the two networks in the following figure:
<img src="https://github.com/wangyang59/book/raw/00036f4b0da5225041a6824587c1a01cf20159b1/gan/image/gan_ig.png" width=400 />
```python
def G(in):
# over-simplified example as G has only one layer:
return paddle.layer.fc(in, parameter_name="G")
def D(in):
# again, over-simplified:
return paddle.layer.fc(in, parameter_name="D")
# Construct the first topology, which contains both D and G.
# By learning this topology, we update parameters of G.
d0 = paddle.layer.should_be_false(D(G(paddle.layer.data())))
# Construct a second topology d1, which contains only D. By
# training this topology, we update parameters of D. Note
# that d1 share parameters with d0.
d1 = paddle.layer.should_be_true(D(paddle.layer.data()))
# Create parameters from a list of multiple topologies (models) for
# the chance to share parameters between these topologies.
parameters = paddle.parameters.create([d0, d1])
# Iterative training of GAN.
for ...:
train(d0, parameters, reader=read_from_rng, immutable_parameters={"D"})
train(d1, parameters, reader=read_from_realistic_images)
# Use d1 for inference:
print "D thinks a batch of images are realistic ", infer(d1, parameters, read_mnist_images)
```
### Summarization
The above two programs reveal some important design concerns:
1. Users describe a topology as an expression of layers. Every layer
has a *parameter name*. If the users don't specify it explicitly, it's automatically generated as a unique name. By
specifying the parameter name, users can specify the sharing of
parameters between layers and even between topologies.
1. `paddle.parameters.create` figures out parameters required by one
or more topologies from parameter names of layers. It creates these
parameters and returns a `ParameterSet` object, which is in essence
a map from *parameter names* to *parameters*.
1. At training and inference time, `paddle.train` and `paddle.infer`
require both a topology and the parameter set that holds the parameters of that topology. There are some reasons:
1. This prevents users from forgetting to call
`paddle.parameters.create`.
1. `paddle.train` needs to know which parameter set to update.
1. Users could load another (pre-trained) parameter set and use it
with a topology in `paddle.infer`.
1. By specifying the `immutable_parameters` parameter of
`paddle.train`, we can forbid the update of these parameters.
## Reader
Not all programming frameworks allow users to define I/O functions.
An example is Google MapReduce, which can only read from text,
SSTable, and RecordIO files. Hadoop MapReduce allows users to define
readers and writers by deriving from base classes `Reader` and
`Writer`. The former is less flexible but also less error-prone. We
decide to provide the flexibility to users to define their readers.
There are some open questions here:
1. **Should a reader return a Python dictionary?**
1. **How to map multiple outputs from a reader to multiple data layers?**
1. **How to easily compose some existing readers to read more data and
feed a topology with more data layers?**
## Training
The recommended way to train a model is to call `paddle.train`,
which simply calls `paddle.trainer.Default`, a global variable of
type `paddle.trainer.SGD`. Equivalently, we can do
```python
opt = paddle.trainer.SGD(..., paddle.updater.Adam(...))
opt.train(topology, parameters, reader=read, ...)
```
### Updater
Please be aware that a trainer can accept an updater as its data
member, where an updater is a class derived from
`paddle.trainer.Updater`. This is to make it easier to customize
trainers, as discussed
[here](https://github.com/PaddlePaddle/Paddle/issues/1319).
### Event Handler
`paddle.train` and `paddle.trainer.XXX.train` take an optional
parameter `event_handler`, which should be either `None` or a function
that handles the following events:
1. BeginTraining
1. EndTraining
1. BeginIteration
1. EndIteration
1. BeginPass
1. EndPass
where EndPass is sent if and only if the reader yields
`end_pass=True`.
An example as follows:
```python
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
print paddle.test(...)
paddle.train(topology, parameters, reader, event_handler)
```
If we are writing a PaddlePaddle program in and for IPython/Jupyter,
we can use matplotlib in the event handler to plot a curve of
cost/error versus iterations, as shown
[here](https://blog.dominodatalab.com/interactive-dashboards-in-jupyter/).
### Distributed Training
If a user wants to do distributed training on a cluster, s/he should
call `paddle.dist_train` and provide access tokens to the cluster as
a parameter.
For example, if the user has a TLS certificate that allows him/her to
access a Kubernetes cluster, s/he should be able to call
```python
paddle.dist_train(model,
trainer=paddle.trainer.SGD(...,
paddle.updater.Adam(...)),
reader=read,
k8s_user="yi",
k8s_token="kube_cluster_tls.pem",
k8s_job="hello",
num_parameter_servers=15)
```
The pseudocode of `paddle.dist_train` is as follows:
```python
def dist_train(topology, parameters, trainer, reader, ...):
if os.getenv("KUBERNETES_SERVICE_HOST") == None:
image_name = k8s_user + '/' + k8s_job
docker_build(image_name)
docker_push()
kube_ctrl_start_job(image_name, k8s_user, k8s_token)
else:
rank = kube_list_containers_in_job_and_return_current_containers_rank()
if rank == 0:
master()
elif rank < 15:
parameter_server()
else:
trainer.train(model, reader=read)
```
Please be aware that if a process is running on the Kubernetes
cluster, it will have some environment variables pre-defined.
If `dist_train` doesn't see these environment variables, it knows
that it's running on users' personal computer, and it should work as a
*launcher*. Otherwise, it knows that it's running on the cluster and
needs to figure out its role as either the master, or a trainer, or a
parameter server.
# Python Data Reader Design Doc
At training and testing time, PaddlePaddle programs need to read data. To ease the users' work to write data reading code, we define that
- A *reader* is a function that reads data (from file, network, random number generator, etc) and yields data items.
- A *reader creator* is a function that returns a reader function.
- A *reader decorator* is a function, which accepts one or more readers, and returns a reader.
and provide frequently used reader creators and reader decorators.
## Data Reader Interface
Indeed, a *data reader* doesn't have to be a function that reads and yields data items. It can be any function with no parameters that creates an iterable (anything that can be used in `for x in iterable`):
```
iterable = data_reader()
```
Each element produced from the iterable should be a **single** entry of data, **not** a mini batch. That entry of data could be a single item, or a tuple of items. Each item should be of a [supported type](http://www.paddlepaddle.org/doc/ui/data_provider/pydataprovider2.html?highlight=dense_vector#input-types) (e.g., numpy 1d array of float32, int, list of int).
An example implementation for single item data reader creator:
```python
def reader_creator_random_image(width, height):
def reader():
while True:
yield numpy.random.uniform(-1, 1, size=width*height)
return reader
```
An example implementation for multiple item data reader creator:
```python
def reader_creator_random_image_and_label(width, height, label):
def reader():
while True:
yield numpy.random.uniform(-1, 1, size=width*height), label
return reader
```
## Usage
The data reader, the mapping from item(s) read to data layers, the batch size, and the total number of passes are passed into `paddle.train`:
```python
# two data layers are created:
image_layer = paddle.layer.data("image", ...)
label_layer = paddle.layer.data("label", ...)
# ...
paddle.train(paddle.dataset.mnist, {"image":0, "label":1}, 128, 10, ...)
```
## Data Reader Decorator
A *data reader decorator* takes one or more data readers and returns a new data reader. It is similar to a [python decorator](https://wiki.python.org/moin/PythonDecorators), but it does not use `@` syntax.
Since we have a strict interface for data readers (no parameters, return a single data item), data readers can be used flexibly via data reader decorators. The following are a few examples:
### Prefetch Data
Since reading data may take time and training cannot proceed without data, it is generally a good idea to prefetch data.
Use `paddle.reader.buffered` to prefetch data:
```python
buffered_reader = paddle.reader.buffered(paddle.dataset.mnist, 100)
```
`buffered_reader` will try to buffer (prefetch) `100` data entries.
### Compose Multiple Data Readers
For example, we want to use a source of real images (reusing mnist dataset), and a source of random images as input for [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661).
We can do:
```python
def reader_creator_random_image(width, height):
def reader():
while True:
yield numpy.random.uniform(-1, 1, size=width*height)
return reader
def reader_creator_bool(t):
def reader():
while True:
yield t
return reader
true_reader = reader_creator_bool(True)
false_reader = reader_creator_bool(False)
reader = paddle.reader.compose(paddle.dataset.mnist, reader_creator_random_image(20, 20), true_reader, false_reader)
# Skipped 1 because paddle.dataset.mnist produces two items per data entry,
# and we don't care about the second item at this time.
paddle.train(reader, {"true_image":0, "fake_image": 2, "true_label": 3, "false_label": 4}, ...)
```
### Shuffle
Given a shuffle buffer size `n`, `paddle.reader.shuffle` will return a data reader that buffers `n` data entries and shuffles them before a data entry is read.
Example:
```python
reader = paddle.reader.shuffle(paddle.dataset.mnist, 512)
```
## Q & A
### Why return only a single entry, but not a mini batch?
If a mini batch is returned, the data reader needs to take care of the batch size. But batch size is a concept for training, so it makes more sense for the user to specify the batch size as a parameter for `train`.
Practically, always returning a single entry makes reusing existing data readers much easier (e.g., if an existing reader returns not a single entry but 3 entries, the training code will be more complex because it needs to handle cases like batch size 2).
### Why use a dictionary but not a list to provide mapping?
We decided to use a dictionary (`{"image":0, "label":1}`) instead of a list (`["image", "label"]`) because the user can easily reuse an item (e.g., using `{"image_a":0, "image_b":0, "label":1}`) or skip an item (e.g., using `{"image_a":0, "label":2}`).
### How to create custom data reader creator
```python
def image_reader_creator(image_path, label_path, n):
def reader():
f = open(image_path)
l = open(label_path)
images = numpy.fromfile(
f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28)).astype('float32')
images = images / 255.0 * 2.0 - 1.0
labels = numpy.fromfile(l, 'ubyte', count=n).astype("int")
for i in xrange(n):
yield images[i, :], labels[i] # a single entry of data is created each time
f.close()
l.close()
return reader
# image_reader_creator creates a reader
reader = image_reader_creator("/path/to/image_file", "/path/to/label_file", 1024)
paddle.train(reader, {"image":0, "label":1}, ...)
```
### How is `paddle.train` implemented
An example implementation of paddle.train could be:
```python
def make_minibatch(reader, minibatch_size):
def ret():
r = reader()
buf = [r.next() for x in xrange(minibatch_size)]
while len(buf) > 0:
yield buf
buf = [r.next() for x in xrange(minibatch_size)]
return ret
def train(reader, mapping, batch_size, total_pass):
for pass_idx in range(total_pass):
for mini_batch in make_minibatch(reader, batch_size)():  # this loop will never end in online learning.
do_forward_backward(mini_batch, mapping)
```
...@@ -27,3 +27,18 @@ std::string Evaluator::toString() { ...@@ -27,3 +27,18 @@ std::string Evaluator::toString() {
m->rawPtr->printStats(sout); m->rawPtr->printStats(sout);
return sout.str(); return sout.str();
} }
// Returns, by value, the list of names that the wrapped evaluator
// (m->rawPtr) appends to the vector handed to its getNames().
std::vector<std::string> Evaluator::getNames() const {
  std::vector<std::string> names;
  m->rawPtr->getNames(&names);
  return names;
}
double Evaluator::getValue(const std::string name) const {
paddle::Error err;
double v = m->rawPtr->getValue(name, &err);
if (err) {
throw std::runtime_error(err.msg());
}
return v;
}
...@@ -900,6 +900,10 @@ public: ...@@ -900,6 +900,10 @@ public:
*/ */
std::string toString(); std::string toString();
std::vector<std::string> getNames() const;
double getValue(const std::string name) const;
private: private:
EvaluatorPrivate* m; EvaluatorPrivate* m;
......
...@@ -68,7 +68,7 @@ class TestMatrix(unittest.TestCase): ...@@ -68,7 +68,7 @@ class TestMatrix(unittest.TestCase):
def test_numpyCpu(self): def test_numpyCpu(self):
numpy_mat = np.matrix([[1, 2], [3, 4], [5, 6]], dtype="float32") numpy_mat = np.matrix([[1, 2], [3, 4], [5, 6]], dtype="float32")
m = swig_paddle.Matrix.createCpuDenseFromNumpy(numpy_mat, copy=False) m = swig_paddle.Matrix.createCpuDenseFromNumpy(numpy_mat, False)
self.assertEqual((int(m.getHeight()), int(m.getWidth())), self.assertEqual((int(m.getHeight()), int(m.getWidth())),
numpy_mat.shape) numpy_mat.shape)
......
...@@ -89,9 +89,14 @@ def main(): ...@@ -89,9 +89,14 @@ def main():
except Exception as e: except Exception as e:
print e print e
ev = m.makeEvaluator()
ev.start()
m.forwardBackward(inArgs, outArgs, swig_paddle.PASS_TRAIN, m.forwardBackward(inArgs, outArgs, swig_paddle.PASS_TRAIN,
update_callback) update_callback)
m.eval(ev)
ev.finish()
for name in ev.getNames():
print name, ev.getValue(name)
for optimizer in optimizers: for optimizer in optimizers:
optimizer.finishBatch() optimizer.finishBatch()
......
...@@ -43,7 +43,7 @@ class TestIVector(unittest.TestCase): ...@@ -43,7 +43,7 @@ class TestIVector(unittest.TestCase):
def test_cpu_numpy(self): def test_cpu_numpy(self):
vec = np.array([1, 3, 4, 65, 78, 1, 4], dtype="int32") vec = np.array([1, 3, 4, 65, 78, 1, 4], dtype="int32")
iv = swig_paddle.IVector.createCpuVectorFromNumpy(vec, copy=False) iv = swig_paddle.IVector.createCpuVectorFromNumpy(vec, False)
self.assertEqual(vec.shape[0], int(iv.__len__())) self.assertEqual(vec.shape[0], int(iv.__len__()))
vec[4] = 832 vec[4] = 832
for i in xrange(len(iv)): for i in xrange(len(iv)):
...@@ -106,7 +106,7 @@ class TestVector(unittest.TestCase): ...@@ -106,7 +106,7 @@ class TestVector(unittest.TestCase):
def testCpuNumpy(self): def testCpuNumpy(self):
numpy_arr = np.array([1.2, 2.3, 3.4, 4.5], dtype="float32") numpy_arr = np.array([1.2, 2.3, 3.4, 4.5], dtype="float32")
vec = swig_paddle.Vector.createCpuVectorFromNumpy(numpy_arr, copy=False) vec = swig_paddle.Vector.createCpuVectorFromNumpy(numpy_arr, False)
assert isinstance(vec, swig_paddle.Vector) assert isinstance(vec, swig_paddle.Vector)
numpy_arr[0] = 0.1 numpy_arr[0] = 0.1
for n, v in zip(numpy_arr, vec): for n, v in zip(numpy_arr, vec):
......
...@@ -69,19 +69,6 @@ extern void hl_sequence_softmax_forward(real* A_d, ...@@ -69,19 +69,6 @@ extern void hl_sequence_softmax_forward(real* A_d,
const int* index, const int* index,
int numSequence); int numSequence);
/**
* @brief Matrix classification error.
*
* @param[in] A_d input matrix (M x N).
* @param[in] B_d input vector (M x 1).
* @param[out] C_d output vector (M x 1).
* @param[in] dimM matrix height.
* @param[in] dimN matrix width.
*
*/
extern void hl_matrix_classification_error(
real* A_d, int* B_d, real* C_d, int dimM, int dimN);
/** /**
* @brief Matrix cross entropy. * @brief Matrix cross entropy.
* *
...@@ -188,48 +175,6 @@ extern void hl_param_relu_backward_diff(real* grad_o, ...@@ -188,48 +175,6 @@ extern void hl_param_relu_backward_diff(real* grad_o,
int width, int width,
int height, int height,
int partial_sum); int partial_sum);
/**
* @brief cos sim forward
*
* @param[out] output output data
* @param[in] input1 input1 data(matrix)
* @param[in] input2 input2 data(matrix or vector)
* @param[in] width matrix width
* @param[in] input1_height input1_height
* @param[in] input2_height input2_height
* @param[in] scale scale factor
*/
extern void hl_cossim(real* output,
real* input1,
real* input2,
int width,
int input1_height,
int input2_height,
real scale);
/**
* @brief cos sim derivate
*
* @param[in] grad output grad
* @param[in] output output data
* @param[in] prevOutX input1 data
* @param[in] prevOutY input2 data
* @param[out] prevGradX input1 grad
* @param[out] prevGradY input2 grad
* @param[in] width matrix width
* @param[in] input1_height input1 height
* @param[in] input2_height input2 height
* @param[in] scale scale factor
*/
extern void hl_cossim_derivative(real* grad,
real* output,
real* prevOutX,
real* prevOutY,
real* prevGradX,
real* prevGradY,
int width,
int input1_height,
int input2_height,
real scale);
/** /**
* @brief Matrix addition: A_d[i][j] += scale * B_d[j/channel]. * @brief Matrix addition: A_d[i][j] += scale * B_d[j/channel].
......
...@@ -58,4 +58,30 @@ extern void hl_sparse_matrix_top_k(real* topVal, ...@@ -58,4 +58,30 @@ extern void hl_sparse_matrix_top_k(real* topVal,
int beamSize, int beamSize,
int numSamples); int numSamples);
#endif /* HL_TOP_K_H_ */ /**
* @brief Matrix classification error.
*
* @param[out] topVal top k element.
* @param[in] ldv leading dimension of topVal.
* @param[out] topIds top k index.
* @param[in] src input value.
* @param[in] lds leading dimension of src.
* @param[in] dim width of input value.
* @param[in] topkSize size of top k element.
* @param[in] numSamples height of input value.
* @param[in] label ground truth label.
* @param[out] recResult top-k classification error.
*
*/
extern void hl_matrix_classification_error(real* topVal,
int ldv,
int* topIds,
real* src,
int lds,
int dim,
int topkSize,
int numSamples,
int* label,
real* recResult);
#endif // HL_TOP_K_H_
...@@ -35,8 +35,16 @@ inline void hl_sequence_softmax_forward(real* A_d, ...@@ -35,8 +35,16 @@ inline void hl_sequence_softmax_forward(real* A_d,
inline void hl_matrix_softmax_derivative( inline void hl_matrix_softmax_derivative(
real* grad_d, real* output_d, real* sftmaxSum_d, int dimM, int dimN) {} real* grad_d, real* output_d, real* sftmaxSum_d, int dimM, int dimN) {}
inline void hl_matrix_classification_error( inline void hl_matrix_classification_error(real* topVal,
real* A_d, int* B_d, real* C_d, int dimM, int dimN) {} int ldv,
int* topIds,
real* src,
int lds,
int dim,
int topkSize,
int numSamples,
int* label,
real* recResult) {}
inline void hl_matrix_cross_entropy( inline void hl_matrix_cross_entropy(
real* A_d, real* C_d, int* label_d, int dimM, int dimN) {} real* A_d, real* C_d, int* label_d, int dimM, int dimN) {}
...@@ -74,25 +82,6 @@ inline void hl_param_relu_backward_diff(real* grad_o, ...@@ -74,25 +82,6 @@ inline void hl_param_relu_backward_diff(real* grad_o,
int height, int height,
int partial_sum) {} int partial_sum) {}
inline void hl_cossim(real* output,
real* input1,
real* input2,
int width,
int input1_height,
int input2_height,
real scale) {}
inline void hl_cossim_derivative(real* grad,
real* output,
real* prevOutX,
real* prevOutY,
real* prevGradX,
real* prevGradY,
int width,
int input1_height,
int input2_height,
real scale) {}
inline void hl_matrix_add_shared_bias(real* A_d, inline void hl_matrix_add_shared_bias(real* A_d,
real* B_d, real* B_d,
const int channel, const int channel,
......
...@@ -265,59 +265,6 @@ void hl_matrix_softmax_derivative(real *grad_d, ...@@ -265,59 +265,6 @@ void hl_matrix_softmax_derivative(real *grad_d,
CHECK_SYNC("hl_matrix_softmax_derivative failed"); CHECK_SYNC("hl_matrix_softmax_derivative failed");
} }
/**
 * @brief Per-row classification error kernel.
 *
 * Block blockIdx.x handles row blockIdx.x of in_A (rows are dimN elements
 * apart). The threads of the block cooperatively find the column holding the
 * largest value of that row; thread 0 then writes 0 to out_C[row] when that
 * column equals in_B[row], and 1 otherwise.
 *
 * @tparam blockSize  length of the shared scratch arrays and stride of the
 *                    column scan. The arrays are indexed by threadIdx.x, so
 *                    the launch must not use more than blockSize threads.
 *                    NOTE(review): the halving reduction presumably expects
 *                    blockSize to be a power of two -- confirm.
 * @param[in]  in_A   input values, read as in_A[row * dimN + col].
 * @param[in]  in_B   one int per row, compared against the arg-max column.
 * @param[out] out_C  one value per row: 0 on match, 1 on mismatch.
 * @param[in]  dimN   number of columns scanned per row.
 */
template<int blockSize>
__global__ void KeMatrixClassificationError(real* in_A,
                                            int* in_B,
                                            real* out_C,
                                            int dimN) {
  // Per-thread running maximum (max_s) and the column it came from (max_l).
  __shared__ real max_s[blockSize];
  __shared__ int max_l[blockSize];
  const int tid = threadIdx.x;
  const int rowId = blockIdx.x;
  // Sentinel "minus infinity"; max_l[tid] is deliberately left unset until a
  // larger value is seen.
  // NOTE(review): if no element of a row exceeds -1e30f, max_l[0] is read
  // uninitialized at the end -- confirm this cannot happen for real inputs.
  max_s[tid] = -1e30f;
  // Advance to the start of this block's row.
  in_A += rowId * dimN;
  real tmp;
  // Each thread scans columns tid, tid + blockSize, tid + 2 * blockSize, ...
  // Strict '<' keeps the first (lowest-index) column on ties.
  for (int colId = tid; colId < dimN; colId += blockSize) {
    tmp = in_A[colId];
    if (max_s[tid] < tmp) {
      max_s[tid] = tmp;
      max_l[tid] = colId;
    }
  }
  // All per-thread partial maxima must be visible before they are combined.
  __syncthreads();
  // Pairwise reduction in shared memory: on every pass the lower half of the
  // active slots absorbs the upper half, carrying the column index along.
  for (int stride = blockSize/2; stride > 0; stride = stride/2) {
    if (tid < stride) {
      if (max_s[tid] < max_s[tid + stride]) {
        max_s[tid] = max_s[tid + stride];
        max_l[tid] = max_l[tid + stride];
      }
    }
    __syncthreads();
  }
  // NOTE(review): the loop above already ends with a barrier whenever it runs
  // at least once, so this extra barrier looks redundant.
  __syncthreads();
  // Slot 0 now holds the row's arg-max column; compare it with in_B[rowId].
  if (tid == 0) {
    out_C[rowId] = (max_l[0] == in_B[rowId] ? 0 : 1.0f);
  }
}
/**
 * @brief Host wrapper that launches KeMatrixClassificationError.
 *
 * Launches dimM blocks of 1024 threads each on STREAM_DEFAULT, with no
 * dynamic shared memory, so that each block processes one row of A_d.
 *
 * @param[in]  A_d   input values, passed to the kernel as in_A; must not be
 *                   null.
 * @param[in]  B_d   per-row ints, passed as in_B; must not be null.
 * @param[out] C_d   per-row result, passed as out_C; must not be null.
 * @param[in]  dimM  number of blocks launched (one per row).
 * @param[in]  dimN  number of columns per row, forwarded to the kernel.
 */
void hl_matrix_classification_error(real* A_d,
                                    int* B_d,
                                    real* C_d,
                                    int dimM,
                                    int dimN) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(B_d);
  CHECK_NOTNULL(C_d);
  // each sample is calculated by one block
  // The template argument and the launch block size are both 1024, which
  // keeps threadIdx.x inside the kernel's shared arrays.
  KeMatrixClassificationError<1024><<< dimM, 1024, 0, STREAM_DEFAULT >>>
    (A_d, B_d, C_d, dimN);
  CHECK_SYNC("hl_matrix_classification_error");
}
__global__ void KeMatrixMultiBinaryCrossEntropy(real* output, __global__ void KeMatrixMultiBinaryCrossEntropy(real* output,
real* entropy, real* entropy,
int* row, int* row,
...@@ -584,177 +531,6 @@ void hl_param_relu_backward_diff(real* grad_o, ...@@ -584,177 +531,6 @@ void hl_param_relu_backward_diff(real* grad_o,
CHECK_SYNC("hl_param_relu_backward_diff failed"); CHECK_SYNC("hl_param_relu_backward_diff failed");
} }
/**
 * @brief Scaled cosine-similarity kernel, one output value per blockIdx.y.
 *
 * For row ty = blockIdx.y the block accumulates sum(x*x), sum(y*y) and
 * sum(x*y) over `width` elements of input1 and input2, and thread 0 writes
 * output[ty] = scale * xy / (sqrt(xx) * sqrt(yy)).
 *
 * @tparam blockSize      length of the shared accumulators and stride of the
 *                        element scan; the arrays are indexed by threadIdx.x.
 * @param[out] output         one value per row, written at output[ty].
 * @param[in]  input1         read as input1[ty * width + index].
 * @param[in]  input2         read with the same row offset when
 *                            input2_height > 1; otherwise the data at the
 *                            start of input2 is reused for every row.
 * @param[in]  width          number of elements accumulated per row.
 * @param[in]  input1_height  not referenced inside the kernel body.
 * @param[in]  input2_height  only tested against 1, see input2.
 * @param[in]  scale          multiplier applied to the result.
 *
 * NOTE(review): there is no guard against xx or yy being zero, so an all-zero
 * row presumably yields a division by zero -- confirm callers avoid this.
 */
template<int blockSize>
__global__ void KeCosSim(real* output,
                         real* input1,
                         real* input2,
                         int width,
                         int input1_height,
                         int input2_height,
                         real scale) {
  const int ty = blockIdx.y;
  int tid = threadIdx.x;
  // Per-thread partial sums of x*x, y*y and x*y.
  __shared__ real xx[blockSize];
  __shared__ real yy[blockSize];
  __shared__ real xy[blockSize];
  xx[tid] = 0.0;
  yy[tid] = 0.0;
  xy[tid] = 0.0;
  __syncthreads();
  // Move both inputs to this block's row; input2 stays at its first row when
  // input2_height is not greater than 1.
  input1 += ty * width;
  if (input2_height > 1) {
    input2 += ty * width;
  }
  // Each thread handles elements tid, tid + blockSize, tid + 2 * blockSize, ...
  for (int index = tid; index < width; index += blockSize) {
    real x = input1[index];
    real y = input2[index];
    xx[tid] += x * x;
    yy[tid] += y * y;
    xy[tid] += x * y;
  }
  // Partial sums must be complete before they are combined.
  __syncthreads();
  // Pairwise reduction in shared memory; slot 0 ends up with the totals.
  for (int s = blockSize / 2; s > 0; s >>= 1) {
    if (tid < s) {
      xx[tid] += xx[tid + s];
      yy[tid] += yy[tid + s];
      xy[tid] += xy[tid + s];
    }
    __syncthreads();
  }
  if (tid == 0) {
    output[ty] = scale * xy[0] / (sqrt(xx[0]) * sqrt(yy[0]));
  }
}
/**
 * Host launcher for KeCosSim.
 *
 * Launches a 1 x input1_height grid of 256-thread blocks on STREAM_DEFAULT,
 * so each block handles one row of input1.
 *
 * \param output         device buffer receiving one value per row.
 * \param input1         first device input.
 * \param input2         second device input.
 * \param width          number of columns per row.
 * \param input1_height  number of rows of input1 (grid size in y).
 * \param input2_height  number of rows of input2, forwarded to the kernel.
 * \param scale          factor applied to every result.
 */
void hl_cossim(real* output,
               real* input1,
               real* input2,
               int width,
               int input1_height,
               int input2_height,
               real scale) {
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(input1);
  CHECK_NOTNULL(input2);

  const int threadsPerBlock = 256;
  dim3 rows(1, input1_height);
  dim3 lanes(threadsPerBlock, 1);
  KeCosSim<threadsPerBlock><<<rows, lanes, 0, STREAM_DEFAULT>>>(
      output, input1, input2, width, input1_height, input2_height, scale);

  CHECK_SYNC("hl_cossim failed");
}
// Backward pass of the row-wise cosine similarity. Gradients are accumulated
// (+=) into prevGradX and prevGradY; nothing is overwritten.
//
// The block with blockIdx.y == ty handles row ty of prevOutX / prevGradX.
// When input2_height > 1 it also uses row ty of prevOutY / prevGradY;
// otherwise those pointers are not advanced and every block works on their
// first row.
//
// The shared arrays are sized by blockSize and indexed by threadIdx.x, so the
// launch must use blockSize threads in x. input1_height is not read in the
// kernel body.
// NOTE(review): there is no guard for xx[0] == 0 or yy[0] == 0; both branches
// below divide by these sums.
template<int blockSize>
__global__ void KeCosSimDerivative(real* grad,
real* output,
real* prevOutX,
real* prevOutY,
real* prevGradX,
real* prevGradY,
int width,
int input1_height,
int input2_height,
real scale) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
// Per-thread partial sums of x*x, y*y and x*y.
__shared__ real xx[blockSize];
__shared__ real yy[blockSize];
__shared__ real xy[blockSize];
xx[tid] = 0.0;
yy[tid] = 0.0;
xy[tid] = 0.0;
__syncthreads();
prevOutX += ty * width;
prevGradX += ty * width;
if (input2_height > 1) {
prevOutY += ty * width;
prevGradY += ty * width;
}
// Recompute the dot product and squared norms of the two forward inputs.
for (int index = tid; index < width; index += blockSize) {
real x = prevOutX[index];
real y = prevOutY[index];
xx[tid] += x * x;
yy[tid] += y * y;
xy[tid] += x * y;
}
__syncthreads();
// Halving-stride reduction: the block totals end up in slot 0.
for (int s = blockSize / 2; s > 0; s >>= 1) {
if (tid < s) {
xx[tid] += xx[tid + s];
yy[tid] += yy[tid + s];
xy[tid] += xy[tid + s];
}
__syncthreads();
}
// Zero dot product: the general branch would divide by xy[0], so this branch
// uses scale / (|x| * |y|) times the other input instead.
if (xy[0] == 0) {
real reciprocal = 1.0 / (sqrt(xx[0]) * sqrt(yy[0]));
for (int index = tid; index < width; index += blockSize) {
prevGradX[index] +=
scale * grad[ty] * prevOutY[index] * reciprocal;
if (input2_height > 1) {
prevGradY[index] +=
scale * grad[ty] * prevOutX[index] * reciprocal;
} else {
// prevGradY was not advanced, so every block of the grid adds into the same
// row; the atomic add keeps those concurrent updates from being lost.
paddle::paddleAtomicAdd(prevGradY + index,
scale * grad[ty] * prevOutX[index] * reciprocal);
}
}
} else {
// General case, expressed through the forward result output[ty]:
//   dX += out * grad * (y / <x, y> - x / <x, x>)
//   dY += out * grad * (x / <x, y> - y / <y, y>)
real reciprocalXY = 1.0 / xy[0];
real reciprocalSquareSumX = 1.0 / xx[0];
real reciprocalSquareSumY = 1.0 / yy[0];
for (int index = tid; index < width; index += blockSize) {
prevGradX[index] += output[ty] * grad[ty] *
(prevOutY[index] * reciprocalXY -
prevOutX[index] * reciprocalSquareSumX);
if (input2_height > 1) {
prevGradY[index] += output[ty] * grad[ty] *
(prevOutX[index] * reciprocalXY -
prevOutY[index] * reciprocalSquareSumY);
} else {
// Shared prevGradY row again: accumulate atomically across blocks.
paddle::paddleAtomicAdd(prevGradY + index, output[ty] * grad[ty] *
(prevOutX[index] * reciprocalXY -
prevOutY[index] * reciprocalSquareSumY));
}
}
}
}
/**
 * Host launcher for KeCosSimDerivative.
 *
 * Launches a 1 x input1_height grid of 256-thread blocks on STREAM_DEFAULT,
 * so each block handles one row of prevOutX / prevGradX.
 *
 * \param grad           device gradient of the loss w.r.t. the output.
 * \param output         device forward output.
 * \param prevOutX       first forward input.
 * \param prevOutY       second forward input.
 * \param prevGradX      gradient buffer of the first input (accumulated into).
 * \param prevGradY      gradient buffer of the second input (accumulated into).
 * \param width          number of columns per row.
 * \param input1_height  number of rows of the first input (grid size in y).
 * \param input2_height  number of rows of the second input, forwarded to the
 *                       kernel.
 * \param scale          scale factor used by the forward pass.
 */
void hl_cossim_derivative(real* grad,
                          real* output,
                          real* prevOutX,
                          real* prevOutY,
                          real* prevGradX,
                          real* prevGradY,
                          int width,
                          int input1_height,
                          int input2_height,
                          real scale) {
  CHECK_NOTNULL(grad);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(prevOutX);
  CHECK_NOTNULL(prevOutY);
  CHECK_NOTNULL(prevGradX);
  CHECK_NOTNULL(prevGradY);

  const int blockSize = 256;
  dim3 threads(blockSize, 1);
  dim3 grid(1, input1_height);
  KeCosSimDerivative<blockSize><<<grid, threads, 0, STREAM_DEFAULT>>>
    (grad, output, prevOutX, prevOutY, prevGradX, prevGradY, width,
        input1_height, input2_height, scale);

  // The message now spells the function name correctly ("derivative", not
  // "derivate") so a failure can be grepped back to this launcher.
  CHECK_SYNC("hl_cossim_derivative failed");
}
__global__ void KeMatrixAddSharedBias(real* A, __global__ void KeMatrixAddSharedBias(real* A,
real* B, real* B,
const int channel, const int channel,
......
...@@ -384,3 +384,81 @@ void hl_sparse_matrix_top_k(real* topVal, int ldv, ...@@ -384,3 +384,81 @@ void hl_sparse_matrix_top_k(real* topVal, int ldv,
CHECK_SYNC("hl_sparse_matrix_top_k failed"); CHECK_SYNC("hl_sparse_matrix_top_k failed");
} }
/**
 * Top-k classification error. Each block computes one sample
 * (blockIdx.x selects the row of src and the entries of label / recResult).
 *
 * In a block:
 * 1. every thread gets its top maxLength values;
 * 2. merge to shTopK, block reduce and get the max value;
 * 3. go to the second step, until one thread's topK value is null;
 * 4. go to the first step, until the topK values are obtained.
 *
 * Afterwards thread 0 walks backwards over the topkSize ids stored through
 * topIds and writes recResult[blockIdx.x] = 0 if one of them equals
 * label[blockIdx.x], and 1 otherwise.
 *
 * NOTE(review): topVal / topIds are handed to blockReduce by address and are
 * then read backwards with `*--topIds`; presumably blockReduce advances them
 * past each stored entry. Confirm against blockReduce.
 *
 * \param topVal     output top-k values, ldv elements per sample.
 * \param ldv        stride between samples in topVal.
 * \param topIds     output top-k ids, beamSize elements per sample.
 * \param src        input scores, lds elements per sample.
 * \param lds        stride between samples in src.
 * \param dim        forwarded to threadGetTopK.
 * \param beamSize   number of top entries requested (k).
 * \param label      expected id per sample.
 * \param recResult  per-sample result (0 = label within top-k, 1 = not).
 */
template<int maxLength, int blockSize>
__global__ void KeMatrixTopKClassificationError(real* topVal, int ldv,
                                                int * topIds,
                                                real* src, int lds,
                                                int dim,
                                                int beamSize,
                                                int* label,
                                                real* recResult) {
  __shared__ Pair shTopK[blockSize];
  __shared__ int maxId[blockSize / 2];
  const int tid = threadIdx.x;
  const int warp = threadIdx.x / 32;
  src += blockIdx.x * lds;
  topVal += blockIdx.x * ldv;
  topIds += blockIdx.x * beamSize;

  Pair topK[maxLength]; // NOLINT
  int beam = maxLength;
  Pair max;
  bool isEmpty = false;
  bool firstStep = true;
  // beamSize is consumed by the loop below; remember the requested k.
  int topkSize = beamSize;

  for (int k = 0; k < maxLength; k++) {
    topK[k].set(-HL_FLOAT_MAX, -1);
  }

  while (beamSize) {
    threadGetTopK<maxLength, blockSize>
      (topK, beam, beamSize, src, firstStep, isEmpty, max, dim, tid);

    shTopK[tid] = topK[0];
    blockReduce<maxLength, blockSize>
      (shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp);
  }

  __syncthreads();
  if (tid == 0) {
    // Start from "error" so the result slot is always written, including the
    // topkSize == 0 case in which the loop body never runs.
    real miss = 1.0f;
    for (int i = 0; i < topkSize; i++) {
      if (*--topIds == label[blockIdx.x]) {
        miss = 0;
        break;
      }
    }
    recResult[blockIdx.x] = miss;
  }
}
/**
 * Host launcher for KeMatrixTopKClassificationError.
 *
 * Launches numSamples blocks of 256 threads on STREAM_DEFAULT, one block per
 * sample. topkSize is clamped to dim before the launch.
 *
 * \param topVal      device buffer for the top-k values.
 * \param ldv         stride between samples in topVal.
 * \param topIds      device buffer for the top-k ids.
 * \param src         device input scores.
 * \param lds         stride between samples in src.
 * \param dim         number of candidates per sample.
 * \param topkSize    requested k.
 * \param numSamples  number of samples (grid size in x).
 * \param label       device buffer with the expected id per sample.
 * \param recResult   device buffer receiving the per-sample error (0 or 1).
 */
void hl_matrix_classification_error(real* topVal, int ldv,
                                    int* topIds,
                                    real* src, int lds,
                                    int dim,
                                    int topkSize,
                                    int numSamples,
                                    int* label,
                                    real* recResult) {
  CHECK_NOTNULL(topVal);
  CHECK_NOTNULL(topIds);
  CHECK_NOTNULL(src);
  // The kernel dereferences these two as well, so validate them like the
  // other device pointers instead of faulting inside the kernel.
  CHECK_NOTNULL(label);
  CHECK_NOTNULL(recResult);

  if (topkSize > dim) topkSize = dim;

  dim3 threads(256, 1);
  dim3 grid(numSamples, 1);
  KeMatrixTopKClassificationError<5, 256>
    <<< grid, threads, 0, STREAM_DEFAULT >>>
    (topVal, ldv, topIds, src, lds, dim, topkSize, label, recResult);

  CHECK_SYNC("hl_matrix_top_k classification error failed");
}
...@@ -54,22 +54,26 @@ DYNAMIC_LOAD_WARPCTC_WRAP(get_workspace_size) ...@@ -54,22 +54,26 @@ DYNAMIC_LOAD_WARPCTC_WRAP(get_workspace_size)
#define WARPCTC_GET_VERSION dynload::get_warpctc_version #define WARPCTC_GET_VERSION dynload::get_warpctc_version
#define WARPCTC_GET_STATUS_STRING dynload::ctcGetStatusString #define WARPCTC_GET_STATUS_STRING dynload::ctcGetStatusString
static int g_warpctcVersion = -1;
#ifndef PADDLE_TYPE_DOUBLE #ifndef PADDLE_TYPE_DOUBLE
#define WARPCTC_COMPUTE_LOSS dynload::compute_ctc_loss #define WARPCTC_COMPUTE_LOSS dynload::compute_ctc_loss
#define WARPCTC_GET_WORKSPACE_SIZE dynload::get_workspace_size #define WARPCTC_GET_WORKSPACE_SIZE dynload::get_workspace_size
#else #else
#define WARPCTC_LOG_FATAL \ hl_warpctc_status_t fatal(...) {
LOG(FATAL) << "warp-ctc [version " << g_warpctcVersion \ LOG(FATAL) << "warp-ctc [version " << g_warpctcVersion
<< "] Error: not support double precision." << "] Error: not support double precision.";
#define WARPCTC_COMPUTE_LOSS(...) WARPCTC_LOG_FATAL(__VA_ARGS__) // both of get_warpctc_version() and get_workspace_size() return an ctcStatus
#define WARPCTC_GET_WORKSPACE_SIZE(...) WARPCTC_LOG_FATAL(__VA_ARGS__) // type value
return CTC_STATUS_EXECUTION_FAILED;
}
#define WARPCTC_COMPUTE_LOSS fatal
#define WARPCTC_GET_WORKSPACE_SIZE fatal
#endif #endif
/** /**
* Check build-in warp-ctc function using glog and it also * Check build-in warp-ctc function using glog and it also
* support << operator for more details error info. * support << operator for more details error info.
*/ */
static int g_warpctcVersion = -1;
#define CHECK_WARPCTC(warpctcStat) \ #define CHECK_WARPCTC(warpctcStat) \
CHECK_EQ(CTC_STATUS_SUCCESS, warpctcStat) \ CHECK_EQ(CTC_STATUS_SUCCESS, warpctcStat) \
<< "warp-ctc [version " << g_warpctcVersion \ << "warp-ctc [version " << g_warpctcVersion \
......
...@@ -190,7 +190,7 @@ public: ...@@ -190,7 +190,7 @@ public:
: BufferArg(VALUE_TYPE_INT32, shape, argType) { : BufferArg(VALUE_TYPE_INT32, shape, argType) {
bufferType_ = TENSOR_SEQUENCE_ID; bufferType_ = TENSOR_SEQUENCE_ID;
CHECK_EQ(shape_.ndims(), 1UL); CHECK_EQ(shape_.ndims(), 1UL);
CHECK_GT(shape_[0], 1UL); CHECK_GE(shape_[0], 1UL);
numSeqs_ = shape_[0] - 1; numSeqs_ = shape_[0] - 1;
} }
...@@ -226,7 +226,8 @@ public: ...@@ -226,7 +226,8 @@ public:
SequenceArg(ValueType valueType, SequenceArg(ValueType valueType,
const TensorShape& shape, const TensorShape& shape,
ArgType argType = UNSPECIFIED) ArgType argType = UNSPECIFIED)
: BufferArg(valueType, shape, argType), startPositions_(TensorShape()) { : BufferArg(valueType, shape, argType),
startPositions_(TensorShape({shape[0]})) {
bufferType_ = TENSOR_SEQUENCE_DATA; bufferType_ = TENSOR_SEQUENCE_DATA;
} }
......
...@@ -27,6 +27,7 @@ if(WITH_TESTING) ...@@ -27,6 +27,7 @@ if(WITH_TESTING)
add_simple_unittest(ContextProjectionOpTest) add_simple_unittest(ContextProjectionOpTest)
add_simple_unittest(PadOpTest) add_simple_unittest(PadOpTest)
add_simple_unittest(MulOpTest) add_simple_unittest(MulOpTest)
add_simple_unittest(CosSimOpTest)
endif() endif()
endif() endif()
......
...@@ -108,26 +108,23 @@ public: ...@@ -108,26 +108,23 @@ public:
} }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK(1 == inputs.size() || 2 == inputs.size()); CHECK(1UL == inputs.size() || 2UL == inputs.size());
CHECK_EQ((size_t)1, outputs.size()); CHECK_EQ(1UL, outputs.size());
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg()) CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
<< "SequenceArg required here"; << "SequenceArg required here";
const auto val_seqs = dynamic_cast<const SequenceArg&>(inputs[0]); const auto val_seqs = dynamic_cast<const SequenceArg&>(inputs[0]);
auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]); auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]);
CHECK(out_seq.data() && val_seqs.data() && val_seqs.getSequenceId().data()); CHECK(out_seq.data() && val_seqs.data() && val_seqs.getSequenceId().data());
CHECK_EQ(out_seq.shape().ndims(), (size_t)2); CHECK_EQ(out_seq.shape().ndims(), 2UL);
CHECK_EQ(val_seqs.shape().ndims(), (size_t)2); CHECK_EQ(val_seqs.shape().ndims(), 2UL);
CHECK_EQ(val_seqs.getSequenceId().shape().ndims(), (size_t)1);
if (2 == inputs.size()) {
CHECK_EQ(inputs[1].shape().ndims(), (size_t)2);
}
/// dim of output = dim of input * context_length /// dim of output = dim of input * context_length
CHECK_EQ(out_seq.shape()[1], val_seqs.shape()[1] * context_length_); CHECK_EQ(out_seq.shape()[1], val_seqs.shape()[1] * context_length_);
/// input and output has the same batch_size /// input and output has the same batch_size
CHECK_EQ(val_seqs.shape()[0], out_seq.shape()[0]); CHECK_EQ(val_seqs.shape()[0], out_seq.shape()[0]);
/// dim of input == dim of weight if (2UL == inputs.size()) {
if (2 == inputs.size()) { CHECK_EQ(inputs[1].shape().ndims(), 2UL);
/// dim of input == dim of weight
CHECK_EQ(val_seqs.shape()[1], inputs[1].shape()[1]); CHECK_EQ(val_seqs.shape()[1], inputs[1].shape()[1]);
} }
...@@ -135,10 +132,11 @@ public: ...@@ -135,10 +132,11 @@ public:
auto out_mat = out_seq.matrix<Device>(); auto out_mat = out_seq.matrix<Device>();
const auto in_mat = val_seqs.matrix<Device>(); const auto in_mat = val_seqs.matrix<Device>();
const auto w_mat = const auto w_mat =
(2 == inputs.size()) (2UL == inputs.size() && inputs[1].data())
? inputs[1].matrix<Device>() ? inputs[1].matrix<Device>()
: typename Tensor<real, Device>::Matrix(nullptr, 0, 0); : typename Tensor<real, Device>::Matrix(nullptr, 0, 0);
const auto seq_vec = val_seqs.getSequenceId().vector<int, Device>(); const auto seq_vec = val_seqs.getSequenceId().vector<int, Device>();
ContextProjectionForward<Device>(out_mat, ContextProjectionForward<Device>(out_mat,
in_mat, in_mat,
w_mat, w_mat,
...@@ -235,36 +233,40 @@ public: ...@@ -235,36 +233,40 @@ public:
} }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)1, inputs.size()); CHECK_EQ(1UL, inputs.size());
CHECK_EQ((size_t)2, outputs.size()); CHECK(1UL == outputs.size() || 2UL == outputs.size());
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg()) CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
<< "SequenceArg required here"; << "SequenceArg required here";
const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]); const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]); auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]);
CHECK(in_seq.data() && in_seq.getSequenceId().data()); CHECK(in_seq.data() && in_seq.getSequenceId().data());
CHECK_EQ(in_seq.shape().ndims(), (size_t)2); CHECK_EQ(in_seq.shape().ndims(), 2UL);
CHECK_EQ(in_seq.getSequenceId().shape().ndims(), (size_t)1); CHECK_EQ(out_seq.shape().ndims(), 2UL);
CHECK_EQ(out_seq.shape().ndims(), (size_t)2); CHECK_EQ(out_seq.getSequenceId().shape().ndims(), 1UL);
CHECK_EQ(out_seq.getSequenceId().shape().ndims(), (size_t)1);
CHECK_EQ(outputs[1].shape().ndims(), (size_t)2);
/// dim of input grad == dim of weight
CHECK_EQ(out_seq.shape()[1], outputs[1].shape()[1]);
/// input and output grad has the same batch_size /// input and output grad has the same batch_size
CHECK_EQ(out_seq.shape()[0], in_seq.shape()[0]); CHECK_EQ(out_seq.shape()[0], in_seq.shape()[0]);
/// dim of output grad = dim of input grad * context_length /// dim of output grad = dim of input grad * context_length
CHECK_EQ(in_seq.shape()[1], out_seq.shape()[1] * context_length_); CHECK_EQ(in_seq.shape()[1], out_seq.shape()[1] * context_length_);
CHECK_EQ(out_seq.getArgType(), ADD_TO); CHECK_EQ(out_seq.getArgType(), ADD_TO);
CHECK_EQ(outputs[1].getArgType(), ADD_TO);
if (2UL == outputs.size()) {
CHECK_EQ(outputs[1].shape().ndims(), 2UL);
/// dim of input grad == dim of weight
CHECK_EQ(out_seq.shape()[1], outputs[1].shape()[1]);
CHECK_EQ(outputs[1].getArgType(), ADD_TO);
}
const auto seq_vec = in_seq.getSequenceId().vector<int, Device>(); const auto seq_vec = in_seq.getSequenceId().vector<int, Device>();
const auto out_grad_mat = in_seq.matrix<Device>(); const auto out_grad_mat = in_seq.matrix<Device>();
auto in_grad_mat = auto in_grad_mat =
!out_seq.data() ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0) !out_seq.data() ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
: out_seq.matrix<Device>(); : out_seq.matrix<Device>();
auto w_grad_mat = !outputs[1].data() auto w_grad_mat =
? typename Tensor<real, Device>::Matrix(nullptr, 0, 0) (2UL == outputs.size() && outputs[1].data())
: outputs[1].matrix<Device>(); ? outputs[1].matrix<Device>()
: typename Tensor<real, Device>::Matrix(nullptr, 0, 0);
ContextProjectionBackward<Device>(out_grad_mat, ContextProjectionBackward<Device>(out_grad_mat,
in_grad_mat, in_grad_mat,
w_grad_mat, w_grad_mat,
...@@ -304,17 +306,17 @@ public: ...@@ -304,17 +306,17 @@ public:
} }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(1, static_cast<int>(inputs.size())); CHECK_EQ(1UL, inputs.size());
CHECK_EQ(1, static_cast<int>(outputs.size())); CHECK_EQ(1UL, outputs.size());
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg()) CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
<< "SequenceArg required here"; << "SequenceArg required here";
const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]); const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
const auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]); const auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]);
CHECK(in_seq.data() && out_seq.data() && in_seq.getSequenceId().data()); CHECK(in_seq.data() && out_seq.data() && in_seq.getSequenceId().data());
CHECK_EQ(static_cast<int>(out_seq.shape().ndims()), 2); CHECK_EQ(out_seq.shape().ndims(), 2UL);
CHECK_EQ(static_cast<int>(in_seq.shape().ndims()), 2); CHECK_EQ(in_seq.shape().ndims(), 2UL);
CHECK_EQ(static_cast<int>(in_seq.getSequenceId().shape().ndims()), 1); CHECK_EQ(in_seq.getSequenceId().shape().ndims(), 1UL);
/// output layer grad dim == input layer grad dim * context_length_ /// output layer grad dim == input layer grad dim * context_length_
CHECK_EQ(in_seq.shape().ndims(), out_seq.shape().ndims() * context_length_); CHECK_EQ(in_seq.shape().ndims(), out_seq.shape().ndims() * context_length_);
/// input and output has the same batch_size /// input and output has the same batch_size
...@@ -355,14 +357,14 @@ public: ...@@ -355,14 +357,14 @@ public:
} }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(1, static_cast<int>(inputs.size())); CHECK_EQ(1UL, inputs.size());
CHECK_EQ(1, static_cast<int>(outputs.size())); CHECK_EQ(1UL, outputs.size());
CHECK(inputs[0].isSequenceArg()) << "SequenceArg required here"; CHECK(inputs[0].isSequenceArg()) << "SequenceArg required here";
const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]); const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
CHECK(in_seq.data() && in_seq.getSequenceId().data() && outputs[0].data()); CHECK(in_seq.data() && in_seq.getSequenceId().data() && outputs[0].data());
CHECK_EQ(static_cast<int>(outputs[0].shape().ndims()), 2); CHECK_EQ(outputs[0].shape().ndims(), 2UL);
CHECK_EQ(static_cast<int>(in_seq.shape().ndims()), 2); CHECK_EQ(in_seq.shape().ndims(), 2UL);
CHECK_EQ(static_cast<int>(in_seq.getSequenceId().shape().ndims()), 1); CHECK_EQ(in_seq.getSequenceId().shape().ndims(), 1UL);
CHECK_EQ(in_seq.shape()[0], outputs[0].shape()[0]); CHECK_EQ(in_seq.shape()[0], outputs[0].shape()[0]);
/// output layer grad dim == weight dim * context_length_ /// output layer grad dim == weight dim * context_length_
CHECK_EQ(in_seq.shape()[1], outputs[0].shape()[1] * context_length_); CHECK_EQ(in_seq.shape()[1], outputs[0].shape()[1] * context_length_);
......
...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "Function.h" #include "Function.h"
namespace paddle { namespace paddle {
......
...@@ -28,55 +28,26 @@ void testMatrixProjectionForward(int context_start, ...@@ -28,55 +28,26 @@ void testMatrixProjectionForward(int context_start,
std::max(0, (int)(context_start + context_length - 1)); std::max(0, (int)(context_start + context_length - 1));
if (pad == 0) is_padding = false; if (pad == 0) is_padding = false;
FunctionCompare compare("ContextProjectionForward", FunctionCompare test("ContextProjectionForward",
FuncConfig() FuncConfig()
.set("context_length", context_length) .set("context_length", context_length)
.set("context_start", context_start) .set("context_start", context_start)
.set("begin_pad", std::max(0, -context_start))); .set("begin_pad", std::max(0, -context_start)));
CpuMatrix cpu_in(batch_size, input_dim); // prepare input arguments
cpu_in.randomizeUniform(); test.addSequence(SequenceIdArg(TensorShape{batch_size}));
GpuMatrix gpu_in(batch_size, input_dim); test.addInputs(
gpu_in.copyFrom(cpu_in); SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batch_size, input_dim}));
auto cpu_weight = if (is_padding) { // weight
is_padding ? std::make_shared<CpuMatrix>(pad, input_dim) : nullptr; test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{pad, input_dim}));
auto gpu_weight =
is_padding ? std::make_shared<GpuMatrix>(pad, input_dim) : nullptr;
if (is_padding) {
cpu_weight->randomizeUniform();
gpu_weight->copyFrom(*cpu_weight);
} }
IVectorPtr cpu_seq; test.addOutputs(
generateSequenceStartPositions(batch_size, cpu_seq); SequenceArg(VALUE_TYPE_FLOAT,
IVectorPtr gpu_seq = IVector::create(cpu_seq->getSize(), true); TensorShape{batch_size, input_dim * context_length}),
gpu_seq->copyFrom(*cpu_seq); ADD_TO);
CpuMatrix cpu_out(batch_size, input_dim * context_length);
GpuMatrix gpu_out(batch_size, input_dim * context_length);
cpu_out.randomizeUniform();
gpu_out.copyFrom(cpu_out);
BufferArgs cpu_inputs;
BufferArgs cpu_outputs;
cpu_inputs.addArg(cpu_in, *cpu_seq);
if (cpu_weight) {
cpu_inputs.addArg(*cpu_weight, *cpu_seq);
}
cpu_outputs.addArg(cpu_out, *cpu_seq, ADD_TO);
compare.getCpuFunction()->calc(cpu_inputs, cpu_outputs);
BufferArgs gpu_inputs; // run Function
BufferArgs gpu_outputs; test.run();
gpu_inputs.addArg(gpu_in, *gpu_seq);
if (gpu_weight) {
gpu_inputs.addArg(*gpu_weight, *gpu_seq);
}
gpu_outputs.addArg(gpu_out, *gpu_seq, ADD_TO);
compare.getGpuFunction()->calc(gpu_inputs, gpu_outputs);
autotest::TensorCheckEqual(cpu_out, gpu_out);
} }
void testMatrixProjectionBackward(int context_start, void testMatrixProjectionBackward(int context_start,
...@@ -88,63 +59,31 @@ void testMatrixProjectionBackward(int context_start, ...@@ -88,63 +59,31 @@ void testMatrixProjectionBackward(int context_start,
std::max(0, (int)(context_start + context_length - 1)); std::max(0, (int)(context_start + context_length - 1));
if (pad == 0) is_padding = false; if (pad == 0) is_padding = false;
FunctionCompare compare("ContextProjectionBackward", FunctionCompare test("ContextProjectionBackward",
FuncConfig() FuncConfig()
.set("context_length", context_length) .set("context_length", context_length)
.set("context_start", context_start) .set("context_start", context_start)
.set("begin_pad", std::max(0, -context_start)) .set("begin_pad", std::max(0, -context_start))
.set("is_padding", is_padding) .set("is_padding", is_padding)
.set("total_pad", pad)); .set("total_pad", pad));
CpuMatrix cpu_in_grad(batch_size, input_dim); // prepare input arguments
cpu_in_grad.randomizeUniform(); test.addSequence(SequenceIdArg(TensorShape{batch_size}));
GpuMatrix gpu_in_grad(batch_size, input_dim); test.addInputs(SequenceArg(
gpu_in_grad.copyFrom(cpu_in_grad); VALUE_TYPE_FLOAT, TensorShape{batch_size, input_dim * context_length}));
test.addOutputs(
CpuMatrix cpu_out_grad(batch_size, input_dim * context_length); SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batch_size, input_dim}),
cpu_out_grad.randomizeUniform(); ADD_TO);
GpuMatrix gpu_out_grad(batch_size, input_dim * context_length); if (is_padding) { // weight
gpu_out_grad.copyFrom(cpu_out_grad); test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{pad, input_dim}),
ADD_TO);
IVectorPtr cpu_seq;
generateSequenceStartPositions(batch_size, cpu_seq);
IVectorPtr gpu_seq = IVector::create(cpu_seq->getSize(), true);
gpu_seq->copyFrom(*cpu_seq);
auto cpu_w_grad =
is_padding ? std::make_shared<CpuMatrix>(pad, input_dim) : nullptr;
auto gpu_w_grad =
is_padding ? std::make_shared<GpuMatrix>(pad, input_dim) : nullptr;
if (is_padding) {
cpu_w_grad->randomizeUniform();
gpu_w_grad->copyFrom(*cpu_w_grad);
} }
BufferArgs cpu_inputs; // run Function
BufferArgs cpu_outputs; test.run();
cpu_inputs.addArg(cpu_out_grad, *cpu_seq);
cpu_outputs.addArg(cpu_in_grad, *cpu_seq, ADD_TO);
cpu_outputs.addArg(
cpu_w_grad ? *cpu_w_grad : CpuMatrix(nullptr, 0, input_dim), ADD_TO);
compare.getCpuFunction()->calc(cpu_inputs, cpu_outputs);
BufferArgs gpu_inputs;
BufferArgs gpu_outputs;
gpu_inputs.addArg(gpu_out_grad, *gpu_seq);
gpu_outputs.addArg(gpu_in_grad, *gpu_seq, ADD_TO);
gpu_outputs.addArg(
gpu_w_grad ? *gpu_w_grad : GpuMatrix(nullptr, 0, input_dim), ADD_TO);
compare.getGpuFunction()->calc(gpu_inputs, gpu_outputs);
autotest::TensorCheckErr(cpu_in_grad, gpu_in_grad);
if (is_padding) {
autotest::TensorCheckErr(*cpu_w_grad, *gpu_w_grad);
}
} }
TEST(ContextProjection, projection) { TEST(ContextProjection, Projection) {
for (auto context_start : {-5, -3, -1, 0, 3}) { for (auto context_start : {-5, -3, -1, 0, 3}) {
for (auto context_length : {1, 2, 5, 7}) { for (auto context_length : {1, 2, 5, 7}) {
for (auto trainable_padding : {false, true}) { for (auto trainable_padding : {false, true}) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "CosSimOp.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/Vector.h"
namespace paddle {
/**
 * CPU specialization of the scaled cosine similarity.
 *
 * For every row i of in1_mat:
 *   out[i] = scale * <x_i, y_i> / (sqrt(<x_i, x_i>) * sqrt(<y_i, y_i>))
 * where y_i is row i of in2_mat, or its only row when in2_mat has one row.
 *
 * \param out_mat  output column, one value per row of out_mat.
 * \param in1_mat  first input; its width is the vector dimension.
 * \param in2_mat  second input, with either 1 row or out_mat.getHeight() rows.
 * \param scale    factor applied to every similarity.
 *
 * Aborts via CHECK when a data pointer is null, when in2_mat has an
 * unsupported row count, or when a row has a non-positive squared norm.
 */
template <>
void CosSimForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat,
                                    const CpuMatrix& in1_mat,
                                    const CpuMatrix& in2_mat,
                                    real scale) {
  CHECK(out_mat.getData() && in1_mat.getData() && in2_mat.getData());
  const size_t num_samples = out_mat.getHeight();
  const size_t dim = in1_mat.getWidth();

  real* result = out_mat.getData();
  const real* lhs_base = in1_mat.getData();
  const real* rhs_base = in2_mat.getData();

  /// the second operand is either broadcast (one row) or matched row by row
  CHECK(in2_mat.getHeight() == 1LU || in2_mat.getHeight() == num_samples);
  const size_t rhs_stride = (in2_mat.getHeight() == 1LU) ? 0 : dim;

  for (size_t row = 0; row < num_samples; ++row) {
    const real* lhs = lhs_base + row * dim;
    const real* rhs = rhs_base + row * rhs_stride;

    real norm2_lhs = 0;
    real norm2_rhs = 0;
    real dot = 0;
    for (size_t col = 0; col < dim; ++col) {
      norm2_lhs += lhs[col] * lhs[col];
      norm2_rhs += rhs[col] * rhs[col];
      dot += lhs[col] * rhs[col];
    }

    CHECK(norm2_lhs > 0 && norm2_rhs > 0);
    result[row] =
        scale * dot / (std::sqrt(norm2_lhs) * std::sqrt(norm2_rhs));
  }
}
/**
 * Cosine Similarity
 * for each row i,
 *   out[i] = scale * cos(input1[i], input2[i])
 *          = scale * <input1[i], input2[i]>/sqrt(|input1[i]|^2 * |input2[i]|^2)
 * when input2 only has one row, then for each row i,
 *   out[i] = scale * cos(input1[i], input2[0])
 *
 * \param inputs[0] input matrix 1, size: nSamples * dim.
 * \param inputs[1] input matrix 2, size: n2 * dim (n2 == 1 or n2 == nSamples).
 * \param outputs[0] output matrix, size : nSamples * 1.
 *
 * The "scale" entry of the FuncConfig is read in init() and passed to
 * CosSimForward<Device> on every calc().
 */
template <DeviceType Device>
class CosSimForwardFunc : public FunctionBase {
// Cache the scale factor from the function configuration.
void init(const FuncConfig& config) override {
scale_ = config.get<real>("scale");
}
// Validate argument count, rank and shapes, then run the device kernel.
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(inputs.size(), 2UL);
CHECK_EQ(outputs.size(), 1UL);
// All three arguments must be 2-D.
CHECK_EQ(inputs[0].shape().ndims(), 2UL);
CHECK_EQ(inputs[1].shape().ndims(), 2UL);
CHECK_EQ(outputs[0].shape().ndims(), 2UL);
// Same number of rows for input 1 and the output, same width for both
// inputs, and a single output column.
CHECK_EQ(inputs[0].shape()[0], outputs[0].shape()[0]);
CHECK_EQ(inputs[0].shape()[1], inputs[1].shape()[1]);
CHECK_EQ(outputs[0].shape()[1], 1UL);
CHECK(outputs[0].data() && inputs[0].data() && inputs[1].data());
// The output is overwritten, not accumulated into.
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
auto out_mat = outputs[0].matrix<Device>();
const auto in1_mat = inputs[0].matrix<Device>();
const auto in2_mat = inputs[1].matrix<Device>();
CosSimForward<Device>(out_mat, in1_mat, in2_mat, scale_);
}
private:
// Multiplier applied to every cosine value; set in init().
real scale_;
};
/**
 * CPU specialization of the cosine similarity gradient.
 *
 * Gradients are accumulated (+=) into in1_grad and in2_grad. When in2_val has
 * a single row, every sample adds into that one row of in2_grad.
 *
 * \param out_grad  gradient of the loss w.r.t. the output, one value per row.
 * \param out_val   forward output value, one value per row.
 * \param in1_val   forward input 1; its width is the vector dimension.
 * \param in2_val   forward input 2, with 1 row or out_grad.getHeight() rows.
 * \param in1_grad  gradient buffer of input 1.
 * \param in2_grad  gradient buffer of input 2, same row count as in2_val.
 * \param scale     scale factor used by the forward pass.
 *
 * Aborts via CHECK on null data, GPU matrices, mismatching row counts, or a
 * row with a non-positive squared norm.
 */
template <>
void CosSimBackward<DEVICE_TYPE_CPU>(const CpuMatrix& out_grad,
                                     const CpuMatrix& out_val,
                                     const CpuMatrix& in1_val,
                                     const CpuMatrix& in2_val,
                                     CpuMatrix& in1_grad,
                                     CpuMatrix& in2_grad,
                                     real scale) {
  CHECK(out_grad.getData() && out_val.getData() && in1_val.getData() &&
        in2_val.getData() && in1_grad.getData() && in2_grad.getData());
  CHECK_EQ(out_val.useGpu_, false) << "Matrix type are GPU, CPU required";

  const real* grad = out_grad.getData();
  const real* out = out_val.getData();
  const real* x_base = in1_val.getData();
  const real* y_base = in2_val.getData();
  real* dx_base = in1_grad.getData();
  real* dy_base = in2_grad.getData();

  const size_t num_samples = out_grad.getHeight();
  const size_t dim = in1_val.getWidth();
  CHECK_EQ(in2_val.getHeight(), in2_grad.getHeight());
  CHECK(in2_val.getHeight() == 1LU || in2_val.getHeight() == num_samples);
  /// stride 0 keeps every sample on the single (broadcast) row of input 2
  const size_t y_stride = (in2_val.getHeight() == 1LU) ? 0 : dim;

  for (size_t row = 0; row < num_samples; ++row) {
    const real* x = x_base + row * dim;
    const real* y = y_base + row * y_stride;
    real* dx = dx_base + row * dim;
    real* dy = dy_base + row * y_stride;

    real norm2_x = 0;
    real norm2_y = 0;
    real dot = 0;
    for (size_t col = 0; col < dim; ++col) {
      norm2_x += x[col] * x[col];
      norm2_y += y[col] * y[col];
      dot += x[col] * y[col];
    }
    CHECK(norm2_x > 0 && norm2_y > 0);

    if (dot == 0) {
      /// zero dot product: the general formula would divide by it
      real inv_norms = 1.0f / (std::sqrt(norm2_x) * std::sqrt(norm2_y));
      for (size_t col = 0; col < dim; ++col) {
        dx[col] += scale * grad[row] * y[col] * inv_norms;
        dy[col] += scale * grad[row] * x[col] * inv_norms;
      }
    } else {
      real inv_dot = 1.0f / dot;
      real inv_norm2_x = 1.0f / norm2_x;
      real inv_norm2_y = 1.0f / norm2_y;
      for (size_t col = 0; col < dim; ++col) {
        dx[col] +=
            out[row] * grad[row] * (y[col] * inv_dot - x[col] * inv_norm2_x);
        dy[col] +=
            out[row] * grad[row] * (x[col] * inv_dot - y[col] * inv_norm2_y);
      }
    }
  }
}
/**
 * Cosine Similarity backward Derivative
 *
 * \param outputs[0] forward input grad 1, size: nSamples * dim.
 * \param outputs[1] forward input grad 2,
 *                   size: n2 * dim (n2 == 1 or n2 == nSamples).
 *
 * \param inputs[0] backward loss output grad, size : nSamples * 1.
 * \param inputs[1] forward output value, size: nSamples * 1.
 * \param inputs[2] forward input value 1, size: nSamples * dim.
 * \param inputs[3] forward input value 2,
 *                  size: n2 * dim (n2 == 1 or n2 == nSamples).
 *
 * Both outputs must be ADD_TO arguments: gradients are accumulated into them.
 * The "scale" entry of the FuncConfig is read in init() and passed to
 * CosSimBackward<Device> on every calc().
 */
template <DeviceType Device>
class CosSimBackwardFunc : public FunctionBase {
  /// Cache the scale factor from the function configuration.
  void init(const FuncConfig& config) override {
    scale_ = config.get<real>("scale");
  }

  /// Validate argument count and shapes, then run the device kernel.
  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(inputs.size(), 4UL);
    CHECK_EQ(outputs.size(), 2UL);
    /// dim of out_grad and out_val == 1, column vector
    CHECK_EQ(inputs[0].shape()[1], 1UL);
    CHECK_EQ(inputs[1].shape()[1], 1UL);
    /// nSamples of out_grad == out_val == in_val1 == in_grad1
    CHECK_EQ(inputs[1].shape()[0], inputs[0].shape()[0]);
    /// in_val1 (inputs[2]) is compared here; the previous check compared
    /// inputs[0] with itself and therefore never failed.
    CHECK_EQ(inputs[2].shape()[0], inputs[0].shape()[0]);
    CHECK_EQ(outputs[0].shape()[0], inputs[0].shape()[0]);
    /// dim of in1_val1 == in_val2 == in_grad1 == in_grad2
    CHECK_EQ(inputs[3].shape()[1], inputs[2].shape()[1]);
    CHECK_EQ(outputs[0].shape()[1], inputs[2].shape()[1]);
    CHECK_EQ(outputs[1].shape()[1], inputs[2].shape()[1]);

    CHECK(inputs[0].data() && inputs[1].data() && inputs[2].data() &&
          inputs[3].data() && outputs[0].data() && outputs[1].data());

    CHECK_EQ(outputs[0].getArgType(), ADD_TO);
    CHECK_EQ(outputs[1].getArgType(), ADD_TO);

    const auto out_grad = inputs[0].matrix<Device>();
    const auto out_val = inputs[1].matrix<Device>();
    const auto in1_val = inputs[2].matrix<Device>();
    const auto in2_val = inputs[3].matrix<Device>();
    auto in1_grad = outputs[0].matrix<Device>();
    auto in2_grad = outputs[1].matrix<Device>();

    CosSimBackward<Device>(
        out_grad, out_val, in1_val, in2_val, in1_grad, in2_grad, scale_);
  }

private:
  /// Scale factor of the forward pass; set in init().
  real scale_;
};
// Register the forward and backward functors for the CPU device. The GPU
// registrations are compiled only when PADDLE_ONLY_CPU is not defined.
REGISTER_TYPED_FUNC(CosSimForward, CPU, CosSimForwardFunc);
REGISTER_TYPED_FUNC(CosSimBackward, CPU, CosSimBackwardFunc);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(CosSimForward, GPU, CosSimForwardFunc);
REGISTER_TYPED_FUNC(CosSimBackward, GPU, CosSimBackwardFunc);
#endif
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
namespace paddle {
/**
 * \brief Cosine Similarity Forward.
 * for each row i,
 * out[i] = scale * cos(in1[i], in2[i])
 *        = scale * \sum_j (in1[i][j] * in2[i][j]) /
 *          sqrt(\sum_j (in1[i][j]^2) * \sum_j (in2[i][j]^2))
 *
 * \param[out] output  output value.
 * \param[in]  input1  first input value.
 * \param[in]  input2  second input value.
 * \param[in]  scale   factor multiplied into the cosine value.
 *
 */
template <DeviceType Device>
void CosSimForward(typename Tensor<real, Device>::Matrix& output,
const typename Tensor<real, Device>::Matrix& input1,
const typename Tensor<real, Device>::Matrix& input2,
real scale);
/**
 * \brief Cosine Similarity Backward for Derivative.
 *
 * \param[in]     out_grad   backward loss output grad.
 * \param[in]     out_value  forward output value.
 * \param[in]     in1_value  forward input value 1.
 * \param[in]     in2_value  forward input value 2.
 * \param[in,out] in1_grad   gradient with respect to forward input 1.
 * \param[in,out] in2_grad   gradient with respect to forward input 2.
 * \param[in]     scale      factor that was multiplied into the cosine value.
 *
 */
template <DeviceType Device>
void CosSimBackward(const typename Tensor<real, Device>::Matrix& out_grad,
const typename Tensor<real, Device>::Matrix& out_value,
const typename Tensor<real, Device>::Matrix& in1_value,
const typename Tensor<real, Device>::Matrix& in2_value,
typename Tensor<real, Device>::Matrix& in1_grad,
typename Tensor<real, Device>::Matrix& in2_grad,
real scale);
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "hl_device_functions.cuh"
#include "CosSimOp.h"
namespace paddle {
// Computes output[row] = scale * <x, y> / (|x| * |y|) for one row per block.
// The row is selected by blockIdx.y; threads of the block split the columns
// and combine their partial sums through shared memory.
// Expects the block to be launched with block_size threads in x (the launcher
// in this file does so); the shared arrays are indexed by threadIdx.x.
// Note: input1_height is not read inside the kernel.
template<int block_size>
__global__ void KeCosSim(real* output,
const real* input1,
const real* input2,
int width,
int input1_height,
int input2_height,
real scale) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
// Per-thread partial sums of x*x, y*y and x*y.
__shared__ real xx[block_size];
__shared__ real yy[block_size];
__shared__ real xy[block_size];
xx[tid] = 0.0;
yy[tid] = 0.0;
xy[tid] = 0.0;
__syncthreads();
// Move to this block's row. input2 is only advanced when it has more than
// one row; otherwise its first row is reused for every row of input1.
input1 += ty * width;
if (input2_height > 1) {
input2 += ty * width;
}
// Each thread visits columns tid, tid + block_size, tid + 2 * block_size, ...
for (int index = tid; index < width; index += block_size) {
real x = input1[index];
real y = input2[index];
xx[tid] += x * x;
yy[tid] += y * y;
xy[tid] += x * y;
}
__syncthreads();
// Pairwise reduction with a halving stride; slot 0 ends up holding the sum
// of all slots only when block_size is a power of two.
for (int s = block_size / 2; s > 0; s >>= 1) {
if (tid < s) {
xx[tid] += xx[tid + s];
yy[tid] += yy[tid + s];
xy[tid] += xy[tid + s];
}
__syncthreads();
}
// Thread 0 writes the result. There is no guard here against xx[0] or
// yy[0] being zero.
if (tid == 0) {
output[ty] = scale * xy[0] / (sqrt(xx[0]) * sqrt(yy[0]));
}
}
// Host-side launcher for KeCosSim: checks the pointers, launches one block
// per row of input1 on STREAM_DEFAULT and reports launch errors.
void hlCossim(real* output,
              const real* input1,
              const real* input2,
              size_t width,
              size_t input1_height,
              size_t input2_height,
              real scale) {
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(input1);
  CHECK_NOTNULL(input2);

  // Threads per block; also the compile-time size of the kernel's shared
  // reduction buffers.
  const int kThreadsPerBlock = 256;
  dim3 block(kThreadsPerBlock, 1);
  // Rows are mapped onto the y dimension of the grid.
  dim3 rows(1, input1_height);

  KeCosSim<kThreadsPerBlock><<<rows, block, 0, STREAM_DEFAULT>>>(
      output, input1, input2, width, input1_height, input2_height, scale);
  CHECK_SYNC("hlCossim failed");
}
/**
 * GPU specialization of the cosine-similarity forward pass.
 *
 * Writes one value per row of in1_mat into out_mat. The width is taken from
 * in1_mat, and both input heights are forwarded so the kernel can reuse a
 * single-row in2_mat for every row of in1_mat.
 */
template <>
void CosSimForward<DEVICE_TYPE_GPU>(GpuMatrix& out_mat,
                                    const GpuMatrix& in1_mat,
                                    const GpuMatrix& in2_mat,
                                    real scale) {
  CHECK(out_mat.getData() && in1_mat.getData() && in2_mat.getData());
  // The kernel writes through out_mat as well, so it is included in the
  // device check, as CosSimBackward does for all of its arguments.
  CHECK(out_mat.useGpu_ == true && in1_mat.useGpu_ == true &&
        in2_mat.useGpu_ == true)
      << "Matrix type are not GPU";

  size_t dim = in1_mat.getWidth();
  real* out = out_mat.getData();
  const real* x = in1_mat.getData();
  const real* y = in2_mat.getData();
  hlCossim(out, x, y, dim, in1_mat.getHeight(), in2_mat.getHeight(), scale);
}
// Accumulates the cosine-similarity gradients for one row per block.
// The row is selected by blockIdx.y. The block first recomputes the sums of
// x*x, y*y and x*y for its row, then every thread adds its columns' share of
// the gradient into prev_grad_x and prev_grad_y.
// Expects the block to be launched with block_size threads in x (the launcher
// in this file does so); the shared arrays are indexed by threadIdx.x.
// Note: input1_height is not read inside the kernel.
template<int block_size>
__global__ void KeCosSimDerivative(const real* grad,
const real* output,
const real* prev_out_x,
const real* prev_out_y,
real* prev_grad_x,
real* prev_grad_y,
size_t width,
size_t input1_height,
size_t input2_height,
real scale) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
// Per-thread partial sums of x*x, y*y and x*y.
__shared__ real xx[block_size];
__shared__ real yy[block_size];
__shared__ real xy[block_size];
xx[tid] = 0.0;
yy[tid] = 0.0;
xy[tid] = 0.0;
__syncthreads();
// Move to this block's row. The y pointers are only advanced when input 2
// has more than one row; otherwise every block reads and updates the same
// single row of y.
prev_out_x += ty * width;
prev_grad_x += ty * width;
if (input2_height > 1) {
prev_out_y += ty * width;
prev_grad_y += ty * width;
}
// Each thread visits columns tid, tid + block_size, tid + 2 * block_size, ...
for (int index = tid; index < width; index += block_size) {
real x = prev_out_x[index];
real y = prev_out_y[index];
xx[tid] += x * x;
yy[tid] += y * y;
xy[tid] += x * y;
}
__syncthreads();
// Pairwise reduction with a halving stride; slot 0 ends up holding the sum
// of all slots only when block_size is a power of two.
for (int s = block_size / 2; s > 0; s >>= 1) {
if (tid < s) {
xx[tid] += xx[tid + s];
yy[tid] += yy[tid + s];
xy[tid] += xy[tid + s];
}
__syncthreads();
}
// Zero dot product: the general formula below would divide by xy[0], so
// this branch uses only the 1 / (|x| * |y|) factor instead.
if (xy[0] == 0) {
real reciprocal = 1.0 / (sqrt(xx[0]) * sqrt(yy[0]));
for (int index = tid; index < width; index += block_size) {
prev_grad_x[index] +=
scale * grad[ty] * prev_out_y[index] * reciprocal;
if (input2_height > 1) {
prev_grad_y[index] +=
scale * grad[ty] * prev_out_x[index] * reciprocal;
} else {
// prev_grad_y was not offset by row here, so blocks for different rows
// add into the same addresses; the add is done atomically.
paddle::paddleAtomicAdd(prev_grad_y + index,
scale * grad[ty] * prev_out_x[index] * reciprocal);
}
}
} else {
// General case: output * grad * (other / xy - self / self_square_sum).
real reciprocalXY = 1.0 / xy[0];
real reciprocalSquareSumX = 1.0 / xx[0];
real reciprocalSquareSumY = 1.0 / yy[0];
for (int index = tid; index < width; index += block_size) {
prev_grad_x[index] += output[ty] * grad[ty] *
(prev_out_y[index] * reciprocalXY -
prev_out_x[index] * reciprocalSquareSumX);
if (input2_height > 1) {
prev_grad_y[index] += output[ty] * grad[ty] *
(prev_out_x[index] * reciprocalXY -
prev_out_y[index] * reciprocalSquareSumY);
} else {
// Shared single row of y: accumulate atomically across blocks.
paddle::paddleAtomicAdd(prev_grad_y + index, output[ty] * grad[ty] *
(prev_out_x[index] * reciprocalXY -
prev_out_y[index] * reciprocalSquareSumY));
}
}
}
}
// Host-side launcher for KeCosSimDerivative: checks the pointers, launches
// one block per row of input 1 on STREAM_DEFAULT and reports launch errors.
void hlCossimDerivative(const real* grad,
                        const real* output,
                        const real* prev_out_x,
                        const real* prev_out_y,
                        real* prev_grad_x,
                        real* prev_grad_y,
                        size_t width,
                        size_t input1_height,
                        size_t input2_height,
                        real scale) {
  CHECK_NOTNULL(grad);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(prev_out_x);
  CHECK_NOTNULL(prev_out_y);
  CHECK_NOTNULL(prev_grad_x);
  CHECK_NOTNULL(prev_grad_y);

  // Threads per block; also the compile-time size of the kernel's shared
  // reduction buffers.
  const int kThreadsPerBlock = 256;
  dim3 block(kThreadsPerBlock, 1);
  // Rows are mapped onto the y dimension of the grid.
  dim3 rows(1, input1_height);

  KeCosSimDerivative<kThreadsPerBlock><<<rows, block, 0, STREAM_DEFAULT>>>(
      grad, output, prev_out_x, prev_out_y, prev_grad_x, prev_grad_y, width,
      input1_height, input2_height, scale);
  CHECK_SYNC("hlCossimDerivate failed");
}
/**
 * GPU specialization of the cosine-similarity backward pass.
 *
 * Verifies that every matrix has data and lives on the GPU, then hands the
 * raw buffers to hlCossimDerivative. The width is taken from in1_val; both
 * input heights are forwarded unchanged.
 */
template <>
void CosSimBackward<DEVICE_TYPE_GPU>(const GpuMatrix& out_grad,
                                     const GpuMatrix& out_val,
                                     const GpuMatrix& in1_val,
                                     const GpuMatrix& in2_val,
                                     GpuMatrix& in1_grad,
                                     GpuMatrix& in2_grad,
                                     real scale) {
  CHECK(out_grad.getData() && out_val.getData() && in1_val.getData() &&
        in2_val.getData() && in1_grad.getData() && in2_grad.getData());
  CHECK(out_grad.useGpu_ && out_val.useGpu_ && in1_val.useGpu_
        && in2_val.useGpu_ && in1_grad.useGpu_ && in2_grad.useGpu_)
      << "Matrix types are not equally GPU";

  hlCossimDerivative(/* grad        = */ out_grad.getData(),
                     /* output      = */ out_val.getData(),
                     /* prev_out_x  = */ in1_val.getData(),
                     /* prev_out_y  = */ in2_val.getData(),
                     /* prev_grad_x = */ in1_grad.getData(),
                     /* prev_grad_y = */ in2_grad.getData(),
                     /* width       = */ in1_val.getWidth(),
                     in1_val.getHeight(),
                     in2_val.getHeight(),
                     scale);
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "FunctionTest.h"
#include "paddle/math/Matrix.h"
using namespace paddle; // NOLINT
void testCosSimForward(size_t height_x,
size_t height_y,
size_t width,
real scale) {
FunctionCompare test("CosSimForward", FuncConfig().set("scale", scale));
// prepare input arguments
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width}));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width}));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1}),
ASSIGN_TO);
// run Function
test.run();
}
void testCosSimBackward(size_t height_x,
size_t height_y,
size_t width,
real scale) {
FunctionCompare test("CosSimBackward", FuncConfig().set("scale", scale));
// prepare input arguments
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1}));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1}));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width}));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width}));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width}),
ADD_TO);
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width}),
ADD_TO);
// run Function
test.run();
}
// Sweeps row counts, widths and scales; the second input either has a single
// row or as many rows as the first.
TEST(Matrix, cosSim) {
  for (auto rowsX : {10, 100, 1000}) {
    for (auto rowsY : {1, rowsX}) {
      for (auto cols : {10, 100, 1000}) {
        for (auto factor : {1.0, 2.0}) {
          testCosSimForward(rowsX, rowsY, cols, factor);
          testCosSimBackward(rowsX, rowsY, cols, factor);
        }
      }
    }
  }
}
...@@ -69,6 +69,54 @@ public: ...@@ -69,6 +69,54 @@ public:
gpuMemory_.back()->getBuf(), input.valueType(), input.shape())); gpuMemory_.back()->getBuf(), input.valueType(), input.shape()));
} }
// assume one copy of sequence is shared by different SequenceArgs
void addSequence(const SequenceIdArg& input) {
CHECK_EQ(input.shape().ndims(), 1UL);
size_t batchSize = input.shape()[0];
size_t numSeqs = batchSize / 10 + 1;
size_t sizeId = (numSeqs + 1) * sizeOfValuType(VALUE_TYPE_INT32);
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(sizeId));
gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(sizeId));
cpuSeq_ = std::make_shared<SequenceIdArg>(cpuMemory_.back()->getBuf(),
TensorShape{numSeqs + 1});
gpuSeq_ = std::make_shared<SequenceIdArg>(gpuMemory_.back()->getBuf(),
TensorShape{numSeqs + 1});
/// init sequence Id
initArg(*cpuSeq_, batchSize);
// todo(tianbing), delete it
CHECK_EQ(cpuSeq_->shape().getElements(), cpuSeq_->numSeqs() + 1);
CpuIVector cpuSeq(cpuSeq_->shape().getElements(), (int*)cpuSeq_->data());
GpuIVector gpuSeq(gpuSeq_->shape().getElements(), (int*)gpuSeq_->data());
gpuSeq.copyFrom(cpuSeq);
}
void addInputs(const SequenceArg& input) {
CHECK_EQ(input.shape().ndims(), 2UL);
size_t batchSize = input.shape()[0];
if (!cpuSeq_ || !gpuSeq_) { // sequence not exist
addSequence(SequenceIdArg(TensorShape{batchSize}));
}
size_t size =
input.shape().getElements() * sizeOfValuType(input.valueType());
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size));
gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(size));
/// SequenceArg
cpuInputs_.emplace_back(
std::make_shared<SequenceArg>(cpuMemory_.back()->getBuf(),
input.valueType(),
input.shape(),
*cpuSeq_));
gpuInputs_.emplace_back(
std::make_shared<SequenceArg>(gpuMemory_.back()->getBuf(),
input.valueType(),
input.shape(),
*gpuSeq_));
}
// output need only contains shape, do not contains data. // output need only contains shape, do not contains data.
void addOutputs(const BufferArg& output, ArgType argType = ASSIGN_TO) { void addOutputs(const BufferArg& output, ArgType argType = ASSIGN_TO) {
size_t size = size_t size =
...@@ -116,24 +164,31 @@ public: ...@@ -116,24 +164,31 @@ public:
std::make_shared<SparseMatrixArg>(*gpuSparse_, argType)); std::make_shared<SparseMatrixArg>(*gpuSparse_, argType));
} }
void addInputs(const SequenceArg& input) { void addOutputs(const SequenceArg& output, ArgType argType = ASSIGN_TO) {
size_t batchSize = input.shape()[0]; CHECK_EQ(output.shape().ndims(), 2UL);
size_t numSeqs = batchSize / 10 + 1; size_t batchSize = output.shape()[0];
size_t sizeId = (numSeqs + 1) * sizeOfValuType(VALUE_TYPE_INT32);
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(sizeId));
gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(sizeId));
TensorShape seqsId({numSeqs + 1});
// void* cpuBuffer = cpuMemory_.back()->getBuf();
// void* gpuBuffer = gpuMemory_.back()->getBuf();
if (!cpuSeq_ || !gpuSeq_) { // sequence not exist
addSequence(SequenceIdArg(TensorShape{batchSize}));
}
size_t size = size_t size =
input.shape().getElements() * sizeOfValuType(input.valueType()); output.shape().getElements() * sizeOfValuType(output.valueType());
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size)); cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size));
gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(size)); gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(size));
// TODO: need be implemented. /// SequenceArg
cpuOutputs_.emplace_back(
std::make_shared<SequenceArg>(cpuMemory_.back()->getBuf(),
output.valueType(),
output.shape(),
*cpuSeq_,
argType));
gpuOutputs_.emplace_back(
std::make_shared<SequenceArg>(gpuMemory_.back()->getBuf(),
output.valueType(),
output.shape(),
*gpuSeq_,
argType));
} }
void addInputs(const SparseMatrixArg& input) { void addInputs(const SparseMatrixArg& input) {
...@@ -193,14 +248,44 @@ public: ...@@ -193,14 +248,44 @@ public:
std::shared_ptr<FunctionBase> getGpuFunction() const { return gpuFunc_; } std::shared_ptr<FunctionBase> getGpuFunction() const { return gpuFunc_; }
protected: protected:
// Fills a CPU argument with uniform random values between 0.001 and 1.
// GPU arguments are not initialized here; they are copied from the CPU side.
void initArg(BufferArg& arg) {
  const auto count = arg.shape().getElements();
  real* values = (real*)arg.data();
  CpuVector view(count, values);
  view.uniform(0.001, 1);
}
// Fills the value buffer of a sequence argument with uniform random values
// between 0.001 and 1; the sequence ids are left untouched.
void initArg(SequenceArg& arg) {
  /// init only matrix
  const auto count = arg.shape().getElements();
  real* values = (real*)arg.data();
  CpuVector view(count, values);
  view.uniform(0.001, 1);
}
// Fills a sequence-id buffer with random sequence start positions.
// buf[i] receives the start of sequence i for i in [0, numSeqs), and
// buf[numSeqs] is set to batchSize as the closing position.
void initArg(SequenceIdArg& arg, size_t batchSize) {
size_t numSeqs = arg.numSeqs();
int* buf = reinterpret_cast<int*>(arg.data());
int pos = 0;
// Upper bound passed to the random length draw: twice the average length.
size_t maxLen = 2 * batchSize / numSeqs;
for (int i = 0; i < (int)numSeqs; ++i) {
// The second bound shrinks as pos grows; presumably it keeps at least one
// element for each remaining sequence - TODO confirm against the range
// returned by uniformRandom.
int len = 1 + uniformRandom(std::min<int64_t>(
maxLen, batchSize - pos - numSeqs + i));
buf[i] = pos;
pos += len;
VLOG(1) << " len=" << len;
}
buf[numSeqs] = batchSize;
}
void initInputs() { void initInputs() {
for (size_t i = 0; i < cpuInputs_.size(); i++) { for (size_t i = 0; i < cpuInputs_.size(); i++) {
if (cpuInputs_[i]->isSparseArg()) { if (cpuInputs_[i]->isSparseArg()) {
continue; /// sparse matrix already init continue; /// sparse matrix already init
} }
initArg(*cpuInputs_[i]); if (cpuInputs_[i]->isSequenceArg()) {
initArg(dynamic_cast<SequenceArg&>(*cpuInputs_[i]));
} else {
initArg(*cpuInputs_[i]);
}
// TODO: Need a BufferCopy used to copy from one BufferArg to another. // TODO: Need a BufferCopy used to copy from one BufferArg to another.
CpuVector cpuVector(cpuInputs_[i]->shape().getElements(), CpuVector cpuVector(cpuInputs_[i]->shape().getElements(),
(real*)cpuInputs_[i]->data()); (real*)cpuInputs_[i]->data());
...@@ -217,7 +302,11 @@ protected: ...@@ -217,7 +302,11 @@ protected:
continue; /// sparse matrix already init continue; /// sparse matrix already init
} }
initArg(*cpuOutputs_[i]); if (cpuOutputs_[i]->isSequenceArg()) {
initArg(dynamic_cast<SequenceArg&>(*cpuOutputs_[i]));
} else {
initArg(*cpuOutputs_[i]);
}
// TODO: Need a BufferCopy used to copy from one BufferArg to another. // TODO: Need a BufferCopy used to copy from one BufferArg to another.
CpuVector cpuVector(cpuOutputs_[i]->shape().getElements(), CpuVector cpuVector(cpuOutputs_[i]->shape().getElements(),
...@@ -241,28 +330,6 @@ protected: ...@@ -241,28 +330,6 @@ protected:
} }
} }
// Fills a CPU argument with uniform random values between 0.001 and 1.
// GPU arguments are not initialized here; they are copied from the CPU side.
void initArg(BufferArg& arg) {
  const auto count = arg.shape().getElements();
  real* values = (real*)arg.data();
  CpuVector view(count, values);
  view.uniform(0.001, 1);
}
// Fills a sequence-id buffer with random sequence start positions.
// buf[i] receives the start of sequence i for i in [0, numSeqs), and
// buf[numSeqs] is set to batchSize as the closing position.
void initArg(SequenceIdArg& arg, size_t batchSize) {
size_t numSeqs = arg.numSeqs();
int* buf = reinterpret_cast<int*>(arg.data());
int pos = 0;
// Upper bound passed to the random length draw: twice the average length.
size_t maxLen = 2 * batchSize / numSeqs;
for (int i = 0; i < (int)numSeqs; ++i) {
// The second bound shrinks as pos grows; presumably it keeps at least one
// element for each remaining sequence - TODO confirm against the range
// returned by uniformRandom.
int len = uniformRandom(
std::min<int64_t>(maxLen, batchSize - pos - numSeqs + i)) +
1;
buf[i] = pos;
pos += len;
VLOG(1) << " len=" << len;
}
buf[numSeqs] = batchSize;
}
protected: protected:
std::shared_ptr<FunctionBase> cpuFunc_; std::shared_ptr<FunctionBase> cpuFunc_;
std::shared_ptr<FunctionBase> gpuFunc_; std::shared_ptr<FunctionBase> gpuFunc_;
...@@ -274,6 +341,8 @@ protected: ...@@ -274,6 +341,8 @@ protected:
std::vector<BufferArgPtr> gpuOutputs_; std::vector<BufferArgPtr> gpuOutputs_;
std::shared_ptr<CpuSparseMatrix> cpuSparse_; std::shared_ptr<CpuSparseMatrix> cpuSparse_;
std::shared_ptr<GpuSparseMatrix> gpuSparse_; std::shared_ptr<GpuSparseMatrix> gpuSparse_;
std::shared_ptr<SequenceIdArg> cpuSeq_;
std::shared_ptr<SequenceIdArg> gpuSeq_;
}; };
} // namespace paddle } // namespace paddle
...@@ -60,7 +60,7 @@ TEST(MulOp, DDDMatrixMul) { ...@@ -60,7 +60,7 @@ TEST(MulOp, DDDMatrixMul) {
if (transa && transb) { if (transa && transb) {
continue; continue;
} }
VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ') VLOG(3) << std::setiosflags(std::ios::left) << std::setfill(' ')
<< " transa=" << transa << " transb=" << transb << " transa=" << transa << " transb=" << transb
<< " dimM=" << std::setw(5) << dimM << " dimM=" << std::setw(5) << dimM
<< " dimN=" << std::setw(5) << dimN << " dimN=" << std::setw(5) << dimN
...@@ -104,7 +104,7 @@ TEST(MuLOp, DSparseDMul) { ...@@ -104,7 +104,7 @@ TEST(MuLOp, DSparseDMul) {
for (const auto dimK : {3, 10}) { for (const auto dimK : {3, 10}) {
for (const auto nnz : {3, 10}) { for (const auto nnz : {3, 10}) {
for (const auto FORMAT : {SPARSE_CSR}) { for (const auto FORMAT : {SPARSE_CSR}) {
VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ') VLOG(3) << std::setiosflags(std::ios::left) << std::setfill(' ')
<< " dimM=" << std::setw(5) << dimM << " dimM=" << std::setw(5) << dimM
<< " dimN=" << std::setw(5) << dimN << " dimN=" << std::setw(5) << dimN
<< " dimK=" << std::setw(5) << dimK << " dimK=" << std::setw(5) << dimK
...@@ -150,7 +150,7 @@ TEST(MulOp, DDSparseMul) { ...@@ -150,7 +150,7 @@ TEST(MulOp, DDSparseMul) {
for (const auto dimK : {3, 10}) { for (const auto dimK : {3, 10}) {
for (const auto nnz : {3, 10}) { for (const auto nnz : {3, 10}) {
for (const auto FORMAT : {SPARSE_CSR, SPARSE_CSC}) { for (const auto FORMAT : {SPARSE_CSR, SPARSE_CSC}) {
VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ') VLOG(3) << std::setiosflags(std::ios::left) << std::setfill(' ')
<< " dimM=" << std::setw(5) << dimM << " dimM=" << std::setw(5) << dimM
<< " dimN=" << std::setw(5) << dimN << " dimN=" << std::setw(5) << dimN
<< " dimK=" << std::setw(5) << dimK << " dimK=" << std::setw(5) << dimK
...@@ -197,7 +197,7 @@ TEST(MulOp, SparseDDMul) { ...@@ -197,7 +197,7 @@ TEST(MulOp, SparseDDMul) {
for (const auto dimK : {3, 10}) { for (const auto dimK : {3, 10}) {
for (const auto nnz : {3, 10}) { for (const auto nnz : {3, 10}) {
for (const auto FORMAT : {SPARSE_CSC, SPARSE_CSR}) { for (const auto FORMAT : {SPARSE_CSC, SPARSE_CSR}) {
VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ') VLOG(3) << std::setiosflags(std::ios::left) << std::setfill(' ')
<< " dimM=" << std::setw(5) << dimM << " dimM=" << std::setw(5) << dimM
<< " dimN=" << std::setw(5) << dimN << " dimN=" << std::setw(5) << dimN
<< " dimK=" << std::setw(5) << dimK << " dimK=" << std::setw(5) << dimK
......
...@@ -647,7 +647,7 @@ public: ...@@ -647,7 +647,7 @@ public:
DataBatch& gpuBatch = *batch; DataBatch& gpuBatch = *batch;
std::vector<Argument>& gpuArguments = gpuBatch.getStreams(); std::vector<Argument>& gpuArguments = gpuBatch.getStreams();
gpuArguments.resize(cpuArguments.size()); gpuArguments.resize(cpuArguments.size());
gpuBatch.setSize(size); gpuBatch.setSize(bsize);
for (size_t i = 0; i < headers_.size(); ++i) { for (size_t i = 0; i < headers_.size(); ++i) {
gpuArguments[i].resizeAndCopyFrom( gpuArguments[i].resizeAndCopyFrom(
cpuArguments[i], useGpu_, HPPL_STREAM_1); cpuArguments[i], useGpu_, HPPL_STREAM_1);
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
/** /**
* calculate sequence-to-sequence edit distance * calculate sequence-to-sequence edit distance
*/ */
class CTCErrorEvaluator : public Evaluator { class CTCErrorEvaluator : public NotGetableEvaluator {
private: private:
MatrixPtr outActivations_; MatrixPtr outActivations_;
int numTimes_, numClasses_, numSequences_, blank_; int numTimes_, numClasses_, numSequences_, blank_;
......
...@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/gserver/evaluators/Evaluator.h" #include "paddle/gserver/evaluators/Evaluator.h"
#include "paddle/utils/Stat.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h" #include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "paddle/utils/Stat.h"
#include "paddle/utils/StringUtil.h"
DECLARE_int32(trainer_id); DECLARE_int32(trainer_id);
...@@ -39,6 +39,14 @@ void Evaluator::eval(const NeuralNetwork& nn) { ...@@ -39,6 +39,14 @@ void Evaluator::eval(const NeuralNetwork& nn) {
*/ */
class ClassificationErrorEvaluator : public Evaluator { class ClassificationErrorEvaluator : public Evaluator {
public: public:
/*
ClassificationErrorEvaluator() : totalScore2_(0) {}
virtual void start() {
Evaluator::start();
totalScore2_ = 0;
} */
virtual void updateSamplesNum(const std::vector<Argument>& arguments) { virtual void updateSamplesNum(const std::vector<Argument>& arguments) {
if (3 == arguments.size()) { if (3 == arguments.size()) {
numSamples_ += arguments[2].value->getSum(); numSamples_ += arguments[2].value->getSum();
...@@ -76,9 +84,11 @@ public: ...@@ -76,9 +84,11 @@ public:
1, 1,
/* trans= */ false, /* trans= */ false,
useGpu(arguments[0].deviceId)); useGpu(arguments[0].deviceId));
errorMat->zeroMem(); errorMat->zeroMem();
if (label != nullptr) { if (label != nullptr) {
errorMat->classificationError(*output, *label); errorMat->classificationError(*output, *label, config_.top_k());
} else if (dynamic_cast<CpuSparseMatrix*>(multiBinaryLabel.get()) || } else if (dynamic_cast<CpuSparseMatrix*>(multiBinaryLabel.get()) ||
dynamic_cast<GpuSparseMatrix*>(multiBinaryLabel.get())) { dynamic_cast<GpuSparseMatrix*>(multiBinaryLabel.get())) {
errorMat->classificationErrorMulti( errorMat->classificationErrorMulti(
...@@ -94,6 +104,16 @@ public: ...@@ -94,6 +104,16 @@ public:
return errorMat; return errorMat;
} }
// Prints the accumulated error rate (totalScore_ / numSamples_, or 0 when no
// samples were counted). With top_k == 1 the label is the evaluator name;
// otherwise it is " top_<k>_error".
void printStats(std::ostream& os) const {
  const auto errorRate = numSamples_ ? totalScore_ / numSamples_ : 0;
  if (config_.top_k() != 1) {
    os << " top_" << config_.top_k() << "_error=" << errorRate;
    return;
  }
  os << config_.name() << "=" << errorRate;
}
virtual real evalImp(std::vector<Argument>& arguments) { virtual real evalImp(std::vector<Argument>& arguments) {
MatrixPtr errorMat = calcError(arguments); MatrixPtr errorMat = calcError(arguments);
return errorMat->getSum(); return errorMat->getSum();
...@@ -102,6 +122,10 @@ public: ...@@ -102,6 +122,10 @@ public:
virtual void distributeEval(ParameterClient2* client) { virtual void distributeEval(ParameterClient2* client) {
mergeResultsOfAllClients(client); mergeResultsOfAllClients(client);
} }
// Evaluator interface
protected:
// Type name reported for this evaluator.
std::string getTypeImpl() const { return "classification_error"; }
}; };
/** /**
...@@ -140,6 +164,10 @@ public: ...@@ -140,6 +164,10 @@ public:
virtual void distributeEval(ParameterClient2* client) { virtual void distributeEval(ParameterClient2* client) {
mergeResultsOfAllClients(client); mergeResultsOfAllClients(client);
} }
// Evaluator interface
protected:
// Type name reported for this evaluator.
std::string getTypeImpl() const { return "seq_classification_error"; }
}; };
REGISTER_EVALUATOR(seq_classification_error, REGISTER_EVALUATOR(seq_classification_error,
SequenceClassificationErrorEvaluator); SequenceClassificationErrorEvaluator);
...@@ -230,6 +258,10 @@ public: ...@@ -230,6 +258,10 @@ public:
private: private:
IVectorPtr cpuLabel_; IVectorPtr cpuLabel_;
MatrixPtr cpuWeight_; MatrixPtr cpuWeight_;
// Evaluator interface
protected:
// Type name reported for this evaluator.
std::string getTypeImpl() const { return "sum"; }
}; };
/** /**
* @brief column sum Evaluator * @brief column sum Evaluator
...@@ -337,10 +369,18 @@ public: ...@@ -337,10 +369,18 @@ public:
} }
private: private:
// Default constructor with an empty body and no initializer list, so it does
// not assign colIdx_ or colNum_.
// NOTE(review): appears to sit in the private section - confirm that no
// caller is expected to default-construct this evaluator.
ColumnSumEvaluator() {}
int32_t colIdx_; int32_t colIdx_;
size_t colNum_; size_t colNum_;
MatrixPtr sum_; /* cpu matrix */ MatrixPtr sum_; /* cpu matrix */
// Evaluator interface
protected:
// Type name reported for this evaluator; a column index of -1 selects the
// "last-column-sum" name.
std::string getTypeImpl() const {
  return colIdx_ == -1 ? "last-column-sum" : "column-sum";
}
}; };
void AucEvaluator::start() { void AucEvaluator::start() {
...@@ -449,6 +489,16 @@ double AucEvaluator::calcAuc() const { ...@@ -449,6 +489,16 @@ double AucEvaluator::calcAuc() const {
} }
} }
// The evaluator's value is the AUC computed by calcAuc().
real AucEvaluator::getValueImpl() const { return calcAuc(); }
// Type name reported for this evaluator; a column index of -1 selects the
// "last-column-auc" name.
std::string AucEvaluator::getTypeImpl() const {
  return colIdx_ == -1 ? "last-column-auc" : "auc";
}
// class RankAucEvaluator // class RankAucEvaluator
REGISTER_EVALUATOR(rankauc, RankAucEvaluator); REGISTER_EVALUATOR(rankauc, RankAucEvaluator);
...@@ -528,12 +578,15 @@ double RankAucEvaluator::calcRankAuc(real* outputData, ...@@ -528,12 +578,15 @@ double RankAucEvaluator::calcRankAuc(real* outputData,
: aucTmp / (clickSum * noClickSum); : aucTmp / (clickSum * noClickSum);
} }
// Type name reported for this evaluator.
std::string RankAucEvaluator::getTypeImpl() const { return "rankauc"; }
// class PrecisionRecallEvaluator // class PrecisionRecallEvaluator
REGISTER_EVALUATOR(precision_recall, PrecisionRecallEvaluator); REGISTER_EVALUATOR(precision_recall, PrecisionRecallEvaluator);
void PrecisionRecallEvaluator::start() { void PrecisionRecallEvaluator::start() {
Evaluator::start(); Evaluator::start();
statsInfo_.clear(); statsInfo_.clear();
values_.clear();
} }
real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) { real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) {
...@@ -594,52 +647,23 @@ real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) { ...@@ -594,52 +647,23 @@ real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) {
} }
void PrecisionRecallEvaluator::printStats(std::ostream& os) const { void PrecisionRecallEvaluator::printStats(std::ostream& os) const {
int label = config_.positive_label(); PrintStatsInfo info;
if (label != -1) { bool containMacroMicroInfo = getStatsInfo(&info);
CHECK(label >= 0 && label < (int)statsInfo_.size()) os << "positive_label=" << config_.positive_label()
<< "positive_label [" << label << "] should be in range [0, " << " precision=" << info.precision << " recall=" << info.recall
<< statsInfo_.size() << ")"; << " F1-score=" << info.f1;
double precision = if (containMacroMicroInfo) {
calcPrecision(statsInfo_[label].TP, statsInfo_[label].FP); os << "macro-average-precision=" << info.macroAvgPrecision
double recall = calcRecall(statsInfo_[label].TP, statsInfo_[label].FN); << " macro-average-recall=" << info.macroAvgRecall
os << "positive_label=" << label << " precision=" << precision << " macro-average-F1-score=" << info.macroAvgF1Score;
<< " recall=" << recall if (!isMultiBinaryLabel_) {
<< " F1-score=" << calcF1Score(precision, recall); // precision and recall are equal in this case
return; os << " micro-average-precision=" << info.microAvgPrecision;
} } else {
os << " micro-average-precision=" << info.microAvgPrecision
// micro average method: precision = (TP1+TP2)/(TP1+FP1+TP2+FP2) << " micro-average-recall=" << info.microAvgRecall
// macro average method: precision = (precision1+precision2)/2 << " micro-average-F1-score=" << info.microAvgF1Score;
double microTotalTP = 0; }
double microTotalFP = 0;
double microTotalFN = 0;
double macroAvgPrecision = 0;
double macroAvgRecall = 0;
size_t numLabels = statsInfo_.size();
for (size_t i = 0; i < numLabels; ++i) {
microTotalTP += statsInfo_[i].TP;
microTotalFP += statsInfo_[i].FP;
microTotalFN += statsInfo_[i].FN;
macroAvgPrecision += calcPrecision(statsInfo_[i].TP, statsInfo_[i].FP);
macroAvgRecall += calcRecall(statsInfo_[i].TP, statsInfo_[i].FN);
}
macroAvgPrecision /= numLabels;
macroAvgRecall /= numLabels;
double macroAvgF1Score = calcF1Score(macroAvgPrecision, macroAvgRecall);
os << "macro-average-precision=" << macroAvgPrecision
<< " macro-average-recall=" << macroAvgRecall
<< " macro-average-F1-score=" << macroAvgF1Score;
double microAvgPrecision = calcPrecision(microTotalTP, microTotalFP);
double microAvgRecall = calcPrecision(microTotalTP, microTotalFN);
double microAvgF1Score = calcF1Score(microAvgPrecision, microAvgRecall);
if (!isMultiBinaryLabel_) {
// precision and recall are equal in this case
os << " micro-average-precision=" << microAvgPrecision;
} else {
os << " micro-average-precision=" << microAvgPrecision
<< " micro-average-recall=" << microAvgRecall
<< " micro-average-F1-score=" << microAvgF1Score;
} }
} }
...@@ -721,6 +745,60 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output, ...@@ -721,6 +745,60 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output,
} }
} }
// Fills values_ with the named statistics from getStatsInfo(), unless it
// already holds entries. The macro/micro averages are only stored when
// getStatsInfo() reports that it computed them.
void PrecisionRecallEvaluator::storeLocalValues() const {
  if (this->values_.size() != 0) {
    return;  // already populated
  }

  PrintStatsInfo info;
  bool containMacroMicroInfo = getStatsInfo(&info);
  values_["precision"] = info.precision;
  // Key was misspelled "recal"; every other recall key here and the label
  // printed by printStats use "recall".
  values_["recall"] = info.recall;
  values_["F1-score"] = info.f1;
  if (!containMacroMicroInfo) {
    return;
  }

  values_["macro-average-precision"] = info.macroAvgPrecision;
  values_["macro-average-recall"] = info.macroAvgRecall;
  values_["macro-average-F1-score"] = info.macroAvgF1Score;
  values_["micro-average-precision"] = info.microAvgPrecision;
  if (isMultiBinaryLabel_) {
    // Without multi-binary labels precision and recall are equal, so only
    // the precision entry above is stored in that case.
    values_["micro-average-recall"] = info.microAvgRecall;
    values_["micro-average-F1-score"] = info.microAvgF1Score;
  }
}
void PrecisionRecallEvaluator::getNames(std::vector<std::string>* names) {
this->storeLocalValues();
names->reserve(this->values_.size());
for (auto it = this->values_.begin(); it != this->values_.end(); ++it) {
names->push_back(this->config_.name() + "." + it->first);
}
}
// Looks up one statistic by name. Only the part of `name` after the last '.'
// is used as the key into values_. On an unknown key, *err is set and 0 is
// returned; on success *err is left untouched.
real PrecisionRecallEvaluator::getValue(const std::string& name,
                                        Error* err) const {
  this->storeLocalValues();
  std::vector<std::string> buffers;
  paddle::str::split(name, '.', &buffers);
  // If split produced no pieces (e.g. for an empty name), indexing the last
  // element would be out of range; report it as an unknown key instead.
  if (buffers.empty()) {
    *err = Error("No such key %s", name.c_str());
    return .0f;
  }
  auto it = this->values_.find(buffers.back());
  if (it == this->values_.end()) {  // not found
    *err = Error("No such key %s", name.c_str());
    return .0f;
  }
  return it->second;
}
std::string PrecisionRecallEvaluator::getType(const std::string& name,
Error* err) const {
this->getValue(name, err);
if (!err->isOK()) {
return "";
}
return "precision_recall";
}
void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) { void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) {
size_t size = 4 * statsInfo_.size(); size_t size = 4 * statsInfo_.size();
double* buf = new double[size]; double* buf = new double[size];
...@@ -740,6 +818,47 @@ void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) { ...@@ -740,6 +818,47 @@ void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) {
delete[] buf; delete[] buf;
} }
/**
 * @brief Compute the summary statistics from the accumulated per-label
 * counters (statsInfo_).
 *
 * When config_.positive_label() is not -1, only info->precision,
 * info->recall and info->f1 are written (for that label) and the function
 * returns false. Otherwise only the macro-/micro-average fields are written
 * and the function returns true. Fields not mentioned for a branch are left
 * as the caller passed them in.
 *
 * @param info [out] receives the statistics.
 * @return true if @p info holds macro/micro averages, false if it holds the
 *         statistics of the configured positive label.
 */
bool PrecisionRecallEvaluator::getStatsInfo(
    PrecisionRecallEvaluator::PrintStatsInfo* info) const {
  const int label = config_.positive_label();
  if (label != -1) {
    CHECK(label >= 0 && label < static_cast<int>(statsInfo_.size()))
        << "positive_label [" << label << "] should be in range [0, "
        << statsInfo_.size() << ")";
    const auto& stats = statsInfo_[label];
    info->precision = calcPrecision(stats.TP, stats.FP);
    info->recall = calcRecall(stats.TP, stats.FN);
    info->f1 = calcF1Score(info->precision, info->recall);
    return false;
  }

  // micro average method: precision = (TP1+TP2)/(TP1+FP1+TP2+FP2)
  // macro average method: precision = (precision1+precision2)/2
  double microTotalTP = 0;
  double microTotalFP = 0;
  double microTotalFN = 0;
  info->macroAvgPrecision = 0;
  info->macroAvgRecall = 0;
  const size_t numLabels = statsInfo_.size();
  for (size_t i = 0; i < numLabels; ++i) {
    microTotalTP += statsInfo_[i].TP;
    microTotalFP += statsInfo_[i].FP;
    microTotalFN += statsInfo_[i].FN;
    info->macroAvgPrecision +=
        calcPrecision(statsInfo_[i].TP, statsInfo_[i].FP);
    info->macroAvgRecall += calcRecall(statsInfo_[i].TP, statsInfo_[i].FN);
  }
  // NOTE(review): with zero labels these two divisions are 0 / 0.
  info->macroAvgPrecision /= numLabels;
  info->macroAvgRecall /= numLabels;
  info->macroAvgF1Score =
      calcF1Score(info->macroAvgPrecision, info->macroAvgRecall);

  info->microAvgPrecision = calcPrecision(microTotalTP, microTotalFP);
  // Recall is TP against FN, so use the recall helper here, exactly as the
  // per-label and macro paths above do.
  info->microAvgRecall = calcRecall(microTotalTP, microTotalFN);
  info->microAvgF1Score =
      calcF1Score(info->microAvgPrecision, info->microAvgRecall);
  return true;
}
REGISTER_EVALUATOR(pnpair, PnpairEvaluator); REGISTER_EVALUATOR(pnpair, PnpairEvaluator);
void PnpairEvaluator::start() { void PnpairEvaluator::start() {
Evaluator::start(); Evaluator::start();
...@@ -864,56 +983,35 @@ void PnpairEvaluator::calc(std::vector<PredictionResult>& predictArray) { ...@@ -864,56 +983,35 @@ void PnpairEvaluator::calc(std::vector<PredictionResult>& predictArray) {
<< " calc total special pair: " << special; << " calc total special pair: " << special;
} }
// Type string reported for the positive/negative pair evaluator.
std::string PnpairEvaluator::getTypeImpl() const {
  return "pnpair";
}
ClassRegistrar<Evaluator> Evaluator::registrar_; ClassRegistrar<Evaluator> Evaluator::registrar_;
Evaluator* Evaluator::create(const EvaluatorConfig& config) { Evaluator* Evaluator::create(const EvaluatorConfig& config) {
Evaluator* evaluator = nullptr; Evaluator* evaluator = registrar_.createByType(config.type());
if (config.type() == "classification_error") {
evaluator = new ClassificationErrorEvaluator();
} else if (config.type() == "sum") {
evaluator = new SumEvaluator();
} else if (config.type() == "last-column-sum") {
evaluator = new ColumnSumEvaluator(-1);
} else if (config.type() == "last-column-auc") {
evaluator = new AucEvaluator(-1);
} else {
evaluator = registrar_.createByType(config.type());
}
evaluator->init(config); evaluator->init(config);
return evaluator; return evaluator;
} }
// Register evaluator classes under the type names used in the evaluator
// config. (REGISTER_EVALUATOR is defined elsewhere; presumably it registers a
// factory that default-constructs the class -- confirm.)
REGISTER_EVALUATOR(classification_error, ClassificationErrorEvaluator);
REGISTER_EVALUATOR(sum, SumEvaluator);
// "last-column-sum" and "last-column-auc" are registered by hand, with
// factories that pass -1 to the constructor.
// NOTE(review): going by the type names, -1 presumably selects the last
// column -- confirm against the ColumnSumEvaluator / AucEvaluator constructors.
static InitFunction __reg_type_auc_sum__([]() {
Evaluator::registrar_.registerClass(
"last-column-sum", [] { return new ColumnSumEvaluator(-1); });
Evaluator::registrar_.registerClass("last-column-auc",
[] { return new AucEvaluator(-1); });
});
/** /**
* @brief print value of each layer. * @brief print value of each layer.
* *
* The config file api is value_printer_evaluator. * The config file api is value_printer_evaluator.
*/ */
class ValuePrinter : public Evaluator { class ValuePrinter : public NotGetableEvaluator {
public: public:
ValuePrinter() {}
virtual void eval(const NeuralNetwork& nn) { virtual void eval(const NeuralNetwork& nn) {
for (const std::string& name : config_.input_layers()) { for (const std::string& name : config_.input_layers()) {
const Argument& argu = nn.getLayer(name)->getOutput(); nn.getLayer(name)->getOutput().printValueString(LOG(INFO),
if (argu.value) { "layer=" + name + " ");
std::ostringstream os;
argu.value->print(os);
LOG(INFO) << "layer=" << name << " value matrix:\n" << os.str();
}
if (argu.ids) {
std::ostringstream os;
argu.ids->print(os, argu.ids->getSize());
LOG(INFO) << "layer=" << name << " ids vector:\n" << os.str();
}
if (auto startPos = argu.sequenceStartPositions) {
std::ostringstream os;
startPos->getVector(false)->print(os, startPos->getSize());
LOG(INFO) << "layer=" << name << " sequence pos vector:\n" << os.str();
}
if (auto subStartPos = argu.subSequenceStartPositions) {
std::ostringstream os;
subStartPos->getVector(false)->print(os, subStartPos->getSize());
LOG(INFO) << "layer=" << name << " sub-sequence pos vector:\n"
<< os.str();
}
} }
} }
...@@ -922,15 +1020,14 @@ public: ...@@ -922,15 +1020,14 @@ public:
virtual real evalImp(std::vector<Argument>& arguments) { return 0; } virtual real evalImp(std::vector<Argument>& arguments) { return 0; }
}; };
REGISTER_EVALUATOR(value_printer, ValuePrinter); REGISTER_EVALUATOR(value_printer, ValuePrinter);
/** /**
* @brief print gradient of each layer. * @brief print gradient of each layer.
* *
* The config file api is gradient_printer_evaluator. * The config file api is gradient_printer_evaluator.
*/ */
class GradientPrinter : public Evaluator { class GradientPrinter : public NotGetableEvaluator {
public: public:
GradientPrinter() {}
virtual void eval(const NeuralNetwork& nn) { virtual void eval(const NeuralNetwork& nn) {
for (const std::string& name : config_.input_layers()) { for (const std::string& name : config_.input_layers()) {
const Argument& argu = nn.getLayer(name)->getOutput(); const Argument& argu = nn.getLayer(name)->getOutput();
...@@ -939,11 +1036,6 @@ public: ...@@ -939,11 +1036,6 @@ public:
argu.grad->print(os); argu.grad->print(os);
LOG(INFO) << "layer=" << name << " grad matrix:\n" << os.str(); LOG(INFO) << "layer=" << name << " grad matrix:\n" << os.str();
} }
if (auto startPos = argu.sequenceStartPositions) {
std::ostringstream os;
startPos->getVector(false)->print(os, startPos->getSize());
LOG(INFO) << "layer=" << name << " sequence pos vector:\n" << os.str();
}
} }
} }
...@@ -957,7 +1049,7 @@ REGISTER_EVALUATOR(gradient_printer, GradientPrinter); ...@@ -957,7 +1049,7 @@ REGISTER_EVALUATOR(gradient_printer, GradientPrinter);
* *
* The config file api is maxid_printer_evaluator. * The config file api is maxid_printer_evaluator.
*/ */
class MaxIdPrinter : public Evaluator { class MaxIdPrinter : public NotGetableEvaluator {
private: private:
IVectorPtr maxIds_; IVectorPtr maxIds_;
MatrixPtr maxValues_; MatrixPtr maxValues_;
...@@ -999,7 +1091,7 @@ REGISTER_EVALUATOR(max_id_printer, MaxIdPrinter); ...@@ -999,7 +1091,7 @@ REGISTER_EVALUATOR(max_id_printer, MaxIdPrinter);
* *
* The config file api is maxframe_printer_evaluator. * The config file api is maxframe_printer_evaluator.
*/ */
class MaxFramePrinter : public Evaluator { class MaxFramePrinter : public NotGetableEvaluator {
private: private:
IVectorPtr maxIds_; IVectorPtr maxIds_;
MatrixPtr maxValues_; MatrixPtr maxValues_;
...@@ -1086,7 +1178,7 @@ REGISTER_EVALUATOR(max_frame_printer, MaxFramePrinter); ...@@ -1086,7 +1178,7 @@ REGISTER_EVALUATOR(max_frame_printer, MaxFramePrinter);
* The config file api is seqtext_printer_evaluator. * The config file api is seqtext_printer_evaluator.
* *
*/ */
class SequenceTextPrinter : public Evaluator { class SequenceTextPrinter : public NotGetableEvaluator {
private: private:
/// dict_file, which contains a list of tokens /// dict_file, which contains a list of tokens
std::vector<std::string> dict_; std::vector<std::string> dict_;
...@@ -1253,4 +1345,6 @@ public: ...@@ -1253,4 +1345,6 @@ public:
}; };
REGISTER_EVALUATOR(classification_error_printer, ClassificationErrorPrinter); REGISTER_EVALUATOR(classification_error_printer, ClassificationErrorPrinter);
// Type string reported for the dummy evaluator.
std::string DummyEvaluator::getTypeImpl() const {
  return "dummy";
}
} // namespace paddle } // namespace paddle
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/parameter/Argument.h" #include "paddle/parameter/Argument.h"
#include "paddle/pserver/ParameterClient2.h" #include "paddle/pserver/ParameterClient2.h"
#include "paddle/utils/ClassRegistrar.h" #include "paddle/utils/ClassRegistrar.h"
#include "paddle/utils/Error.h"
namespace paddle { namespace paddle {
...@@ -117,12 +118,105 @@ public: ...@@ -117,12 +118,105 @@ public:
static ClassRegistrar<Evaluator> registrar_; static ClassRegistrar<Evaluator> registrar_;
/**
 * @brief Append the field names exposed by this evaluator to @p names.
 *
 * For an evaluator with several fields, each name has the form
 * `evaluator_name.field`. For example, the PrecisionRecallEvaluator reports
 * `precision_recall_evaluator.precision`,
 * `precision_recall_evaluator.recal` (sic -- that is the key it stores), etc.
 *
 * A combined evaluator reports the names of every evaluator it contains.
 *
 * This default implementation reports exactly one name: the evaluator's
 * configured name.
 *
 * @param names [out] receives the field names; entries are appended.
 * @note Implementations must never clear @p names.
 */
virtual void getNames(std::vector<std::string>* names) {
  const std::string& ownName = config_.name();
  names->push_back(ownName);
}
/**
 * @brief Return the current value (metric) of one field.
 *
 * This default implementation knows a single field whose name equals the
 * evaluator's configured name, and returns getValueImpl() for it.
 *
 * @param name the field name to look up.
 * @param err  [out] set to an error when @p name is not the configured name.
 *             Left untouched on success. A nullptr is tolerated and simply
 *             receives no error, matching getType().
 *
 * @return the metric, or 0 when @p name is unknown.
 */
virtual real getValue(const std::string& name, Error* err) const {
  if (name != config_.name()) {
    if (err != nullptr) {
      *err = Error("no such name of evaluator %s", name.c_str());
    }
    return .0f;
  }
  return this->getValueImpl();
}
/**
 * @brief Return the evaluator type for a field name.
 *
 * The type is the kind of evaluator as a string, such as 'auc' or
 * 'precision_recall'. In a combined evaluator different names may map to
 * different types, because each name is served by one of the contained
 * evaluators.
 *
 * This default implementation accepts only the evaluator's configured name
 * and returns getTypeImpl() for it.
 *
 * @param name the field name to look up.
 * @param err  [out] set to an error when @p name is unknown. nullptr means
 *             the caller does not care, so nothing is written in that case.
 * @return the evaluator type string, or an empty string when @p name is
 *         unknown.
 */
virtual std::string getType(const std::string& name, Error* err) const {
  if (name != config_.name()) {
    // Honor the documented contract: a null err must not be written through.
    if (err != nullptr) {
      *err = Error("no such name of evaluator %s", name.c_str());
    }
    return std::string();
  }
  return this->getTypeImpl();
}
protected:
/**
 * @brief Simplest hook for supplying the getValue() result.
 *
 * An evaluator that has a single field and never reports an error only needs
 * to implement this method. The default is the average score,
 * totalScore_ / numSamples_, or 0 when numSamples_ is 0.
 *
 * @return the evaluation result (metric).
 */
virtual real getValueImpl() const {
  if (numSamples_ != .0) {
    return totalScore_ / numSamples_;
  }
  return .0;
}
/**
 * @brief Simplest hook for supplying the getType() result.
 *
 * An evaluator that does not combine other evaluators only needs to return
 * its own type here. The default type is "base".
 *
 * @return the evaluator type.
 */
virtual std::string getTypeImpl() const {
  return "base";
}
protected: protected:
EvaluatorConfig config_; EvaluatorConfig config_;
double numSamples_; double numSamples_;
double totalScore_; double totalScore_;
}; };
/**
* @brief The NotGetableEvaluator class is the base class of evaluator that
* cannot get value in runtime. The most NotGetableEvaluator is Printer
* Evaluator, which is only used to debug network configuration.
*/
class NotGetableEvaluator : public Evaluator {
// Evaluator interface
public:
void getNames(std::vector<std::string>* names) {}
real getValue(const std::string& name, Error* err) const {
*err = Error("Not implemented");
return .0f;
}
std::string getType(const std::string& name, Error* err) const {
*err = Error("Not implemented");
return "";
}
};
class DummyEvaluator : public Evaluator { class DummyEvaluator : public Evaluator {
public: public:
DummyEvaluator() {} DummyEvaluator() {}
...@@ -135,6 +229,10 @@ public: ...@@ -135,6 +229,10 @@ public:
} }
virtual void finish() {} virtual void finish() {}
virtual void printStats(std::ostream&) const {} virtual void printStats(std::ostream&) const {}
// Evaluator interface
protected:
std::string getTypeImpl() const;
}; };
/** /**
* @brief evaluate AUC using colIdx-th column as prediction. * @brief evaluate AUC using colIdx-th column as prediction.
...@@ -191,6 +289,11 @@ private: ...@@ -191,6 +289,11 @@ private:
} }
double calcAuc() const; double calcAuc() const;
// Evaluator interface
protected:
real getValueImpl() const;
std::string getTypeImpl() const;
}; };
/** /**
...@@ -223,6 +326,10 @@ private: ...@@ -223,6 +326,10 @@ private:
real* clickData, real* clickData,
real* pvData, real* pvData,
size_t size); size_t size);
// Evaluator interface
protected:
std::string getTypeImpl() const;
}; };
/** /**
* @brief precision, recall and f1 score Evaluator * @brief precision, recall and f1 score Evaluator
...@@ -272,6 +379,20 @@ private: ...@@ -272,6 +379,20 @@ private:
IVectorPtr cpuLabel_; IVectorPtr cpuLabel_;
MatrixPtr cpuWeight_; MatrixPtr cpuWeight_;
// Summary statistics produced by getStatsInfo().
//
// getStatsInfo() fills either the per-label fields or the macro/micro-average
// fields, never both. Every field therefore defaults to 0 so that the part
// not filled in holds a defined value instead of an indeterminate one.
struct PrintStatsInfo {
  // Statistics of the configured positive label.
  double precision = 0;
  double recall = 0;
  double f1 = 0;
  // Macro averages: mean of the per-label values.
  double macroAvgPrecision = 0;
  double macroAvgRecall = 0;
  double macroAvgF1Score = 0;
  // Micro averages: computed from the summed per-label counters.
  double microAvgPrecision = 0;
  double microAvgRecall = 0;
  double microAvgF1Score = 0;
};
bool getStatsInfo(PrintStatsInfo* info) const;
void calcStatsInfo(const MatrixPtr& output, void calcStatsInfo(const MatrixPtr& output,
const IVectorPtr& label, const IVectorPtr& label,
const MatrixPtr& weight); const MatrixPtr& weight);
...@@ -303,6 +424,15 @@ private: ...@@ -303,6 +424,15 @@ private:
return 0; return 0;
} }
} }
mutable std::unordered_map<std::string, real> values_;
void storeLocalValues() const;
// Evaluator interface
public:
void getNames(std::vector<std::string>* names);
real getValue(const std::string& name, Error* err) const;
std::string getType(const std::string& name, Error* err) const;
}; };
/* /*
...@@ -349,8 +479,7 @@ public: ...@@ -349,8 +479,7 @@ public:
virtual void finish() { calc(predictArray_); } virtual void finish() { calc(predictArray_); }
virtual void printStats(std::ostream& os) const { virtual void printStats(std::ostream& os) const {
os << " pos/neg" os << " pos/neg=" << this->getValueImpl();
<< "=" << pairArray_[0] / ((pairArray_[1] <= 0) ? 1.0 : pairArray_[1]);
} }
virtual void distributeEval(ParameterClient2* client) { virtual void distributeEval(ParameterClient2* client) {
...@@ -366,6 +495,13 @@ private: ...@@ -366,6 +495,13 @@ private:
IVectorPtr cpuLabel_; IVectorPtr cpuLabel_;
IVectorPtr cpuInfo_; IVectorPtr cpuInfo_;
MatrixPtr cpuWeight_; MatrixPtr cpuWeight_;
// Evaluator interface
protected:
// Ratio pairArray_[0] / pairArray_[1]. A denominator that is not positive is
// replaced by 1.0.
real getValueImpl() const {
  const auto denominator = (pairArray_[1] <= 0) ? 1.0 : pairArray_[1];
  return pairArray_[0] / denominator;
}
std::string getTypeImpl() const;
}; };
} // namespace paddle } // namespace paddle
...@@ -306,7 +306,6 @@ void NeuralNetwork::onPassEnd() { ...@@ -306,7 +306,6 @@ void NeuralNetwork::onPassEnd() {
class CombinedEvaluator : public Evaluator { class CombinedEvaluator : public Evaluator {
public: public:
CombinedEvaluator() {}
void addEvaluator(std::unique_ptr<Evaluator>&& evaluator) { void addEvaluator(std::unique_ptr<Evaluator>&& evaluator) {
evaluators_.emplace_back(std::move(evaluator)); evaluators_.emplace_back(std::move(evaluator));
} }
...@@ -346,6 +345,55 @@ public: ...@@ -346,6 +345,55 @@ public:
protected: protected:
std::vector<std::unique_ptr<Evaluator>> evaluators_; std::vector<std::unique_ptr<Evaluator>> evaluators_;
// Evaluator interface
public:
/**
 * @brief Collect the names of every contained evaluator.
 * @param names [out] each contained evaluator's getNames() is called on it
 *              in turn.
 */
void getNames(std::vector<std::string>* names) {
  for (auto it = evaluators_.begin(); it != evaluators_.end(); ++it) {
    (*it)->getNames(names);
  }
}
/**
 * @brief Return the value of a field served by one of the contained
 * evaluators. The first evaluator that reports @p name answers the query.
 */
real getValue(const std::string& name, Error* err) const {
  auto fetchValue = [&name, err](const std::unique_ptr<Evaluator>& eval) {
    return eval->getValue(name, err);
  };
  return this->getMethodHelper<real>(name, err, fetchValue);
}
/**
 * @brief Return the type of a field served by one of the contained
 * evaluators. The first evaluator that reports @p name answers the query.
 */
std::string getType(const std::string& name, Error* err) const {
  auto fetchType = [&name, err](const std::unique_ptr<Evaluator>& eval) {
    return eval->getType(name, err);
  };
  return this->getMethodHelper<std::string>(name, err, fetchType);
}
private:
/**
 * @brief Dispatch a per-field query to the contained evaluator that owns the
 * field.
 *
 * Each contained evaluator is asked for its names in order; the first one
 * whose names contain @p name is passed to @p callback and its result is
 * returned.
 *
 * @param name     the field name to look up.
 * @param err      [out] set to "No such key ..." when no contained evaluator
 *                 reports @p name. Nothing is written here when it is nullptr.
 * @param callback performs the actual query on the matching evaluator.
 * @return the callback's result, or a value-initialized T when no evaluator
 *         matches.
 */
template <typename T>
T getMethodHelper(const std::string& name,
                  Error* err,
                  const std::function<T(const std::unique_ptr<Evaluator>&)>&
                      callback) const {
  for (auto& eval : evaluators_) {
    std::vector<std::string> names;
    eval->getNames(&names);
    if (std::find(names.begin(), names.end(), name) != names.end()) {
      return callback(eval);
    }
  }
  if (err != nullptr) {
    *err = Error("No such key %s", name.c_str());
  }
  return T();
}
}; };
Evaluator* NeuralNetwork::makeEvaluator() const { Evaluator* NeuralNetwork::makeEvaluator() const {
......
...@@ -155,7 +155,8 @@ protected: ...@@ -155,7 +155,8 @@ protected:
public: public:
explicit BootBiasLayer(const LayerConfig& config) : Layer(config) {} explicit BootBiasLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) { bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
if (!Layer::init(layerMap, parameterMap)) return false; if (!Layer::init(layerMap, parameterMap)) return false;
if (biasParameter_) { if (biasParameter_) {
...@@ -174,7 +175,7 @@ public: ...@@ -174,7 +175,7 @@ public:
} }
} }
virtual void forward(PassType passType) { void forward(PassType passType) override {
if (biases_) { if (biases_) {
MatrixPtr outV = getOutputValue(); MatrixPtr outV = getOutputValue();
outV->addBias(*(biases_->getW()), 1); outV->addBias(*(biases_->getW()), 1);
...@@ -182,7 +183,7 @@ public: ...@@ -182,7 +183,7 @@ public:
} }
} }
virtual void backward(const UpdateCallback& callback) { void backward(const UpdateCallback& callback) override {
if (biases_) { if (biases_) {
backwardActivation(); backwardActivation();
biases_->getWGrad()->collectBias(*getOutputGrad(), 1); biases_->getWGrad()->collectBias(*getOutputGrad(), 1);
......
...@@ -44,19 +44,20 @@ public: ...@@ -44,19 +44,20 @@ public:
/** /**
* Initialization of AddtoLayer. * Initialization of AddtoLayer.
*/ */
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
/** /**
* Forward propagation. * Forward propagation.
* @note There is no weight matrix for each input, * @note There is no weight matrix for each input,
* because it just a simple add operation. * because it just a simple add operation.
*/ */
void forward(PassType passType); void forward(PassType passType) override;
/** /**
* Backward propagation. * Backward propagation.
*/ */
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -35,7 +35,8 @@ public: ...@@ -35,7 +35,8 @@ public:
~AgentLayer() {} ~AgentLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
// if *numSamples* set, // if *numSamples* set,
// real layer output will only use first *numSamples* rows // real layer output will only use first *numSamples* rows
...@@ -44,8 +45,8 @@ public: ...@@ -44,8 +45,8 @@ public:
numSamples_ = numSamples; numSamples_ = numSamples;
} }
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) {} void backward(const UpdateCallback& callback = nullptr) override {}
}; };
/** /**
...@@ -56,8 +57,8 @@ public: ...@@ -56,8 +57,8 @@ public:
explicit SequenceAgentLayer(const LayerConfig& config) : AgentLayer(config) {} explicit SequenceAgentLayer(const LayerConfig& config) : AgentLayer(config) {}
~SequenceAgentLayer() {} ~SequenceAgentLayer() {}
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) {} void backward(const UpdateCallback& callback = nullptr) override {}
}; };
/** /**
...@@ -78,7 +79,8 @@ public: ...@@ -78,7 +79,8 @@ public:
virtual ~GatherAgentLayer() {} virtual ~GatherAgentLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
// call before addRealLayer // call before addRealLayer
void copyIdAndSequenceInfo(const Argument& input, void copyIdAndSequenceInfo(const Argument& input,
...@@ -88,8 +90,8 @@ public: ...@@ -88,8 +90,8 @@ public:
// add one real layer, can call many times // add one real layer, can call many times
void addRealLayer(LayerPtr layer) { realLayers_.push_back(layer); } void addRealLayer(LayerPtr layer) { realLayers_.push_back(layer); }
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback); void backward(const UpdateCallback& callback) override;
}; };
/** /**
...@@ -133,7 +135,8 @@ public: ...@@ -133,7 +135,8 @@ public:
virtual ~ScatterAgentLayer() {} virtual ~ScatterAgentLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
/** /**
* @brief set real layer in generation * @brief set real layer in generation
...@@ -182,8 +185,8 @@ public: ...@@ -182,8 +185,8 @@ public:
numSequences_ = numSequences; numSequences_ = numSequences;
} }
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback); void backward(const UpdateCallback& callback) override;
}; };
/** /**
......
...@@ -38,12 +38,11 @@ public: ...@@ -38,12 +38,11 @@ public:
explicit AverageLayer(const LayerConfig& config) explicit AverageLayer(const LayerConfig& config)
: SequencePoolLayer(config) {} : SequencePoolLayer(config) {}
~AverageLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
protected: protected:
MatrixPtr outMtx_; MatrixPtr outMtx_;
......
...@@ -52,7 +52,8 @@ public: ...@@ -52,7 +52,8 @@ public:
*/ */
static Layer* create(const LayerConfig& config); static Layer* create(const LayerConfig& config);
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
/** /**
* @brief Calculate feature map size. Some input uses frameHeight and * @brief Calculate feature map size. Some input uses frameHeight and
......
...@@ -33,9 +33,10 @@ public: ...@@ -33,9 +33,10 @@ public:
~BatchNormalizationLayer() {} ~BatchNormalizationLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
void forward(PassType passType); const ParameterMap& parameterMap) override;
void backward(const UpdateCallback& callback = nullptr); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
protected: protected:
/// Epsilon value used in the batch normalization formula. /// Epsilon value used in the batch normalization formula.
...@@ -58,7 +59,7 @@ protected: ...@@ -58,7 +59,7 @@ protected:
/// to batch, channels* imagePixels. /// to batch, channels* imagePixels.
void shrinkMat(const MatrixPtr& in, MatrixPtr& out); void shrinkMat(const MatrixPtr& in, MatrixPtr& out);
void onPassEnd() { firstTest_ = true; } void onPassEnd() override { firstTest_ = true; }
MatrixPtr tmpMat_, tmpGrad_; MatrixPtr tmpMat_, tmpGrad_;
MatrixPtr expandedIn_, expandedOut_; MatrixPtr expandedIn_, expandedOut_;
......
...@@ -38,9 +38,10 @@ public: ...@@ -38,9 +38,10 @@ public:
virtual ~BilinearInterpLayer() {} virtual ~BilinearInterpLayer() {}
size_t getSize(); size_t getSize();
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
void forward(PassType passType); const ParameterMap& parameterMap) override;
void backward(const UpdateCallback& callback = nullptr); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -58,10 +58,11 @@ public: ...@@ -58,10 +58,11 @@ public:
~BlockExpandLayer() {} ~BlockExpandLayer() {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
virtual void forward(PassType passType); void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -32,9 +32,10 @@ namespace paddle { ...@@ -32,9 +32,10 @@ namespace paddle {
class CRFDecodingLayer : public CRFLayer { class CRFDecodingLayer : public CRFLayer {
public: public:
explicit CRFDecodingLayer(const LayerConfig& config) : CRFLayer(config) {} explicit CRFDecodingLayer(const LayerConfig& config) : CRFLayer(config) {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
virtual void forward(PassType passType); const ParameterMap& parameterMap) override;
virtual void backward(const UpdateCallback& callback); void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
protected: protected:
std::unique_ptr<LinearChainCRF> crf_; std::unique_ptr<LinearChainCRF> crf_;
......
...@@ -29,9 +29,10 @@ namespace paddle { ...@@ -29,9 +29,10 @@ namespace paddle {
class CRFLayer : public Layer { class CRFLayer : public Layer {
public: public:
explicit CRFLayer(const LayerConfig& config) : Layer(config) {} explicit CRFLayer(const LayerConfig& config) : Layer(config) {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
virtual void forward(PassType passType); const ParameterMap& parameterMap) override;
virtual void backward(const UpdateCallback& callback); void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
protected: protected:
size_t numClasses_; size_t numClasses_;
......
...@@ -22,10 +22,11 @@ namespace paddle { ...@@ -22,10 +22,11 @@ namespace paddle {
class CTCLayer : public Layer { class CTCLayer : public Layer {
public: public:
explicit CTCLayer(const LayerConfig& config) : Layer(config) {} explicit CTCLayer(const LayerConfig& config) : Layer(config) {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
virtual void forward(PassType passType); const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void forwardImp(const Argument& softmaxSeqs, const Argument& labelSeqs); void forwardImp(const Argument& softmaxSeqs, const Argument& labelSeqs);
virtual void backward(const UpdateCallback& callback); void backward(const UpdateCallback& callback) override;
void backwardImp(const UpdateCallback& callback, void backwardImp(const UpdateCallback& callback,
const Argument& softmaxSeqs, const Argument& softmaxSeqs,
const Argument& labelSeqs); const Argument& labelSeqs);
......
...@@ -28,10 +28,11 @@ public: ...@@ -28,10 +28,11 @@ public:
~ConcatenateLayer() {} ~ConcatenateLayer() {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
virtual void forward(PassType passType); void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(concat, ConcatenateLayer); REGISTER_LAYER(concat, ConcatenateLayer);
...@@ -101,10 +102,11 @@ public: ...@@ -101,10 +102,11 @@ public:
~ConcatenateLayer2() {} ~ConcatenateLayer2() {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
virtual void forward(PassType passType); void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
protected: protected:
std::vector<std::unique_ptr<Projection>> projections_; std::vector<std::unique_ptr<Projection>> projections_;
......
...@@ -80,7 +80,8 @@ protected: ...@@ -80,7 +80,8 @@ protected:
public: public:
explicit ConvBaseLayer(const LayerConfig& config) : Layer(config) {} explicit ConvBaseLayer(const LayerConfig& config) : Layer(config) {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
/** /**
* imgSizeH_ and imgSizeW_ will be set according to the previous input layers * imgSizeH_ and imgSizeW_ will be set according to the previous input layers
......
...@@ -47,10 +47,11 @@ public: ...@@ -47,10 +47,11 @@ public:
~ConvShiftLayer() {} ~ConvShiftLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(conv_shift, ConvShiftLayer); REGISTER_LAYER(conv_shift, ConvShiftLayer);
......
...@@ -49,10 +49,11 @@ public: ...@@ -49,10 +49,11 @@ public:
~ConvexCombinationLayer() {} ~ConvexCombinationLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(convex_comb, ConvexCombinationLayer); REGISTER_LAYER(convex_comb, ConvexCombinationLayer);
......
...@@ -26,15 +26,23 @@ bool CosSimLayer::init(const LayerMap& layerMap, ...@@ -26,15 +26,23 @@ bool CosSimLayer::init(const LayerMap& layerMap,
Layer::init(layerMap, parameterMap); Layer::init(layerMap, parameterMap);
CHECK_EQ(inputLayers_.size(), 2LU); CHECK_EQ(inputLayers_.size(), 2LU);
createFunction(forward_,
"CosSimForward",
FuncConfig().set("scale", (real)config_.cos_scale()));
createFunction(backward_,
"CosSimBackward",
FuncConfig().set("scale", (real)config_.cos_scale()));
return true; return true;
} }
void CosSimLayer::forward(PassType passType) { void CosSimLayer::forward(PassType passType) {
Layer::forward(passType); Layer::forward(passType);
/* malloc memory for the output_ if necessary */ /* malloc memory for the output_ if necessary */
int batchSize = getInputValue(0)->getHeight(); int batchSize = getInputValue(0)->getHeight();
int size = getSize(); int size = getSize();
CHECK_EQ(forward_.size(), 1) << "Only one forward function needed";
{ {
REGISTER_TIMER_INFO("CosFwResetTimer", getName().c_str()); REGISTER_TIMER_INFO("CosFwResetTimer", getName().c_str());
...@@ -42,26 +50,43 @@ void CosSimLayer::forward(PassType passType) { ...@@ -42,26 +50,43 @@ void CosSimLayer::forward(PassType passType) {
} }
MatrixPtr outV = getOutputValue(); MatrixPtr outV = getOutputValue();
/* activation */ { /* activation */ {
REGISTER_TIMER_INFO("CosFwAtvTimer", getName().c_str()); REGISTER_TIMER_INFO("CosFwAtvTimer", getName().c_str());
MatrixPtr prevOut1 = getInputValue(0); MatrixPtr prevOut1 = getInputValue(0);
MatrixPtr prevOut2 = getInputValue(1); MatrixPtr prevOut2 = getInputValue(1);
outV->cosSim(*prevOut1, *prevOut2, config_.cos_scale());
CHECK(outV && prevOut1 && prevOut2);
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*prevOut1);
inputs.addArg(*prevOut2);
outputs.addArg(*outV, ASSIGN_TO);
forward_[0]->calc(inputs, outputs);
} }
} }
void CosSimLayer::backward(const UpdateCallback& callback) { void CosSimLayer::backward(const UpdateCallback& callback) {
/* activation */ { /* activation */ {
REGISTER_TIMER_INFO("CosBpAtvTimer", getName().c_str()); REGISTER_TIMER_INFO("CosBpAtvTimer", getName().c_str());
MatrixPtr outG = this->getOutputGrad(); CHECK_EQ(backward_.size(), 1) << "Only one backward function needed";
outG->cosSimDerivative(*this->getOutputValue(), const auto outG = this->getOutputGrad();
*getInputValue(0), const auto outV = this->getOutputValue();
*getInputValue(1), const auto inV1 = this->getInputValue(0);
*getInputGrad(0), const auto inV2 = this->getInputValue(1);
*getInputGrad(1), auto inG1 = this->getInputGrad(0);
config_.cos_scale()); auto inG2 = this->getInputGrad(1);
CHECK(outG && outV && inV1 && inV2 && inG1 && inG2);
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*outG);
inputs.addArg(*outV);
inputs.addArg(*inV1);
inputs.addArg(*inV2);
outputs.addArg(*inG1, ADD_TO);
outputs.addArg(*inG2, ADD_TO);
backward_[0]->calc(inputs, outputs);
} }
} }
......
...@@ -28,7 +28,7 @@ namespace paddle { ...@@ -28,7 +28,7 @@ namespace paddle {
* *
* - Input1: A vector (batchSize * dataDim) * * - Input1: A vector (batchSize * dataDim) *
* - Input2: A vector (batchSize * dataDim) or (1 * dataDim) * * - Input2: A vector (batchSize * dataDim) or (1 * dataDim) *
* - Output: A vector (dataDim * 1) * - Output: A vector (batchSize * 1)
* *
* The config file api is cos_sim. * The config file api is cos_sim.
*/ */
...@@ -38,10 +38,11 @@ public: ...@@ -38,10 +38,11 @@ public:
~CosSimLayer() {} ~CosSimLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -18,7 +18,6 @@ limitations under the License. */ ...@@ -18,7 +18,6 @@ limitations under the License. */
#include "paddle/utils/Stat.h" #include "paddle/utils/Stat.h"
namespace paddle { namespace paddle {
/** /**
* @brief A layer for computing cosine similarity between a vector * @brief A layer for computing cosine similarity between a vector
* and each row of a matrix * and each row of a matrix
...@@ -46,10 +45,11 @@ public: ...@@ -46,10 +45,11 @@ public:
~CosSimVecMatLayer() {} ~CosSimVecMatLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(cos_vm, CosSimVecMatLayer); REGISTER_LAYER(cos_vm, CosSimVecMatLayer);
...@@ -97,11 +97,22 @@ bool CosSimVecMatLayer::init(const LayerMap& layerMap, ...@@ -97,11 +97,22 @@ bool CosSimVecMatLayer::init(const LayerMap& layerMap,
dataDim, dataDim,
/* trans= */ false, /* trans= */ false,
useGpu_); useGpu_);
CHECK(tmpRow0 && tmpRow1 && tmpRow2 && tmpRow3 && tmpMtx0 && tmpMtx1);
createFunction(forward_,
"CosSimForward",
FuncConfig().set("scale", (real)config_.cos_scale()));
createFunction(backward_,
"CosSimBackward",
FuncConfig().set("scale", (real)config_.cos_scale()));
return true; return true;
} }
void CosSimVecMatLayer::forward(PassType passType) { void CosSimVecMatLayer::forward(PassType passType) {
Layer::forward(passType); Layer::forward(passType);
CHECK_EQ(forward_.size(), 1) << "Only one forward function needed";
MatrixPtr inV0 = getInputValue(0); MatrixPtr inV0 = getInputValue(0);
MatrixPtr inV1 = getInputValue(1); MatrixPtr inV1 = getInputValue(1);
...@@ -117,17 +128,25 @@ void CosSimVecMatLayer::forward(PassType passType) { ...@@ -117,17 +128,25 @@ void CosSimVecMatLayer::forward(PassType passType) {
} }
MatrixPtr outV = getOutputValue(); MatrixPtr outV = getOutputValue();
CHECK(outV && inV0 && inV1);
REGISTER_TIMER_INFO("FwCosVMTimer", getName().c_str()); REGISTER_TIMER_INFO("FwCosVMTimer", getName().c_str());
for (size_t i = 0; i < batchSize; i++) { for (size_t i = 0; i < batchSize; i++) {
tmpRow0->setData(inV0->rowBuf(i)); tmpRow0->setData(inV0->rowBuf(i));
tmpMtx0->setData(inV1->rowBuf(i)); tmpMtx0->setData(inV1->rowBuf(i));
tmpRow2->setData(outV->rowBuf(i)); tmpRow2->setData(outV->rowBuf(i));
tmpRow2->cosSim(*(tmpMtx0), *(tmpRow0), config_.cos_scale());
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*tmpMtx0);
inputs.addArg(*tmpRow0);
outputs.addArg(*tmpRow2, ASSIGN_TO);
forward_[0]->calc(inputs, outputs);
} }
} }
void CosSimVecMatLayer::backward(const UpdateCallback& callback) { void CosSimVecMatLayer::backward(const UpdateCallback& callback) {
CHECK_EQ(backward_.size(), 1) << "Only one forward function needed";
MatrixPtr inV0 = getInputValue(0); MatrixPtr inV0 = getInputValue(0);
MatrixPtr inV1 = getInputValue(1); MatrixPtr inV1 = getInputValue(1);
MatrixPtr inG0 = getInputGrad(0); MatrixPtr inG0 = getInputGrad(0);
...@@ -136,27 +155,27 @@ void CosSimVecMatLayer::backward(const UpdateCallback& callback) { ...@@ -136,27 +155,27 @@ void CosSimVecMatLayer::backward(const UpdateCallback& callback) {
MatrixPtr outG = getOutputGrad(); MatrixPtr outG = getOutputGrad();
size_t batchSize = inV0->getHeight(); size_t batchSize = inV0->getHeight();
CHECK(inV0 && inV1 && inG0 && inG1 && outV && outG);
REGISTER_TIMER_INFO("BwCosVMTimer", getName().c_str()); REGISTER_TIMER_INFO("BwCosVMTimer", getName().c_str());
if (inG0 && inG1) { for (size_t i = 0; i < batchSize; i++) {
for (size_t i = 0; i < batchSize; i++) { tmpRow0->setData(inV0->rowBuf(i));
tmpRow0->setData(inV0->rowBuf(i)); tmpRow1->setData(inG0->rowBuf(i));
tmpRow1->setData(inG0->rowBuf(i)); tmpMtx0->setData(inV1->rowBuf(i));
tmpMtx0->setData(inV1->rowBuf(i)); tmpMtx1->setData(inG1->rowBuf(i));
tmpMtx1->setData(inG1->rowBuf(i)); tmpRow2->setData(outV->rowBuf(i));
tmpRow2->setData(outV->rowBuf(i)); tmpRow3->setData(outG->rowBuf(i));
tmpRow3->setData(outG->rowBuf(i));
BufferArgs inputs;
tmpRow3->cosSimDerivative(*(tmpRow2), BufferArgs outputs;
*(tmpMtx0), inputs.addArg(*tmpRow3);
*(tmpRow0), inputs.addArg(*tmpRow2);
*(tmpMtx1), inputs.addArg(*tmpMtx0);
*(tmpRow1), inputs.addArg(*tmpRow0);
config_.cos_scale()); outputs.addArg(*tmpMtx1, ADD_TO);
} outputs.addArg(*tmpRow1, ADD_TO);
} else {
CHECK(!inG0 || !inG1) << "Not supported"; backward_[0]->calc(inputs, outputs);
} }
} }
......
...@@ -367,8 +367,6 @@ void LambdaCost::backward(const UpdateCallback& callback) { ...@@ -367,8 +367,6 @@ void LambdaCost::backward(const UpdateCallback& callback) {
getInputGrad(0)->add(*marginGrad_); getInputGrad(0)->add(*marginGrad_);
} }
void LambdaCost::onPassEnd() {}
void LambdaCost::calcGrad(const real* outputScore, void LambdaCost::calcGrad(const real* outputScore,
const real* score, const real* score,
real* gradData, real* gradData,
...@@ -611,14 +609,15 @@ class SumCostLayer : public Layer { ...@@ -611,14 +609,15 @@ class SumCostLayer : public Layer {
public: public:
explicit SumCostLayer(const LayerConfig& config) : Layer(config) {} explicit SumCostLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) { bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
bool ret = Layer::init(layerMap, parameterMap); bool ret = Layer::init(layerMap, parameterMap);
if (!ret) return ret; if (!ret) return ret;
CHECK_EQ(inputLayers_.size(), 1UL); CHECK_EQ(inputLayers_.size(), 1UL);
return true; return true;
} }
virtual void forward(PassType passType) { void forward(PassType passType) override {
Layer::forward(passType); Layer::forward(passType);
const MatrixPtr& input = getInputValue(0); const MatrixPtr& input = getInputValue(0);
...@@ -629,7 +628,7 @@ public: ...@@ -629,7 +628,7 @@ public:
output_.value->sumRows(*input, /* scaleSum= */ 1, /* scaleDest= */ 0); output_.value->sumRows(*input, /* scaleSum= */ 1, /* scaleDest= */ 0);
} }
virtual void backward(const UpdateCallback& callback = nullptr) { void backward(const UpdateCallback& callback = nullptr) override {
getInputGrad(0)->add((real)1); getInputGrad(0)->add((real)1);
} }
}; };
......
...@@ -32,15 +32,16 @@ class CostLayer : public Layer { ...@@ -32,15 +32,16 @@ class CostLayer : public Layer {
public: public:
explicit CostLayer(const LayerConfig& config) : Layer(config) {} explicit CostLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
LayerPtr getOutputLayer() { return inputLayers_[0]; } LayerPtr getOutputLayer() { return inputLayers_[0]; }
LayerPtr getLabelLayer() { return inputLayers_[1]; } LayerPtr getLabelLayer() { return inputLayers_[1]; }
virtual void forward(PassType passType); void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
virtual void forwardImp(Matrix& outputValue, virtual void forwardImp(Matrix& outputValue,
Argument& label, Argument& label,
...@@ -68,11 +69,14 @@ public: ...@@ -68,11 +69,14 @@ public:
explicit MultiClassCrossEntropy(const LayerConfig& config) explicit MultiClassCrossEntropy(const LayerConfig& config)
: CostLayer(config) {} : CostLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost); void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad); void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
}; };
/** /**
...@@ -95,11 +99,14 @@ public: ...@@ -95,11 +99,14 @@ public:
explicit MultiClassCrossEntropyWithSelfNorm(const LayerConfig& config) explicit MultiClassCrossEntropyWithSelfNorm(const LayerConfig& config)
: CostLayer(config) {} : CostLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost); void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad); void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
protected: protected:
MatrixPtr sftMaxSum_; MatrixPtr sftMaxSum_;
...@@ -117,11 +124,14 @@ public: ...@@ -117,11 +124,14 @@ public:
explicit SoftBinaryClassCrossEntropy(const LayerConfig& config) explicit SoftBinaryClassCrossEntropy(const LayerConfig& config)
: CostLayer(config) {} : CostLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost); void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad); void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
protected: protected:
MatrixPtr targetPerDim_; MatrixPtr targetPerDim_;
...@@ -139,11 +149,14 @@ public: ...@@ -139,11 +149,14 @@ public:
explicit SumOfSquaresCostLayer(const LayerConfig& config) explicit SumOfSquaresCostLayer(const LayerConfig& config)
: CostLayer(config) {} : CostLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost); void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad); void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
}; };
/** /**
...@@ -162,17 +175,18 @@ class RankingCost : public Layer { ...@@ -162,17 +175,18 @@ class RankingCost : public Layer {
public: public:
explicit RankingCost(const LayerConfig& config) : Layer(config) {} explicit RankingCost(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
LayerPtr getOutputLayer(size_t i) { return inputLayers_[i]; } LayerPtr getOutputLayer(size_t i) { return inputLayers_[i]; }
LayerPtr getLabelLayer() { return inputLayers_[2]; } LayerPtr getLabelLayer() { return inputLayers_[2]; }
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
void onPassEnd(); void onPassEnd() override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost) { void forwardImp(Matrix& output, Argument& label, Matrix& cost) {
(void)output; (void)output;
...@@ -214,17 +228,16 @@ class LambdaCost : public Layer { ...@@ -214,17 +228,16 @@ class LambdaCost : public Layer {
public: public:
explicit LambdaCost(const LayerConfig& config) : Layer(config) {} explicit LambdaCost(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
LayerPtr getOutputLayer() { return inputLayers_[0]; } LayerPtr getOutputLayer() { return inputLayers_[0]; }
LayerPtr getScoreLayer() { return inputLayers_[1]; } LayerPtr getScoreLayer() { return inputLayers_[1]; }
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
void onPassEnd();
real calcNDCG(const real* outputScore, const real* score, int size); real calcNDCG(const real* outputScore, const real* score, int size);
void calcGrad(const real* outputScore, void calcGrad(const real* outputScore,
...@@ -256,11 +269,14 @@ public: ...@@ -256,11 +269,14 @@ public:
explicit MultiBinaryLabelCrossEntropy(const LayerConfig& config) explicit MultiBinaryLabelCrossEntropy(const LayerConfig& config)
: CostLayer(config) {} : CostLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost); void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad); void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
}; };
/** /**
...@@ -282,13 +298,16 @@ class HuberTwoClass : public CostLayer { ...@@ -282,13 +298,16 @@ class HuberTwoClass : public CostLayer {
public: public:
explicit HuberTwoClass(const LayerConfig& config) : CostLayer(config) {} explicit HuberTwoClass(const LayerConfig& config) : CostLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost); void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void forwardImpIn(Matrix& output, Argument& label, Matrix& cost); void forwardImpIn(Matrix& output, Argument& label, Matrix& cost);
void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad); void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
void backwardImpIn(Matrix& outputValue, Argument& label, Matrix& outputGrad); void backwardImpIn(Matrix& outputValue, Argument& label, Matrix& outputGrad);
}; };
......
...@@ -35,14 +35,15 @@ public: ...@@ -35,14 +35,15 @@ public:
~CudnnBatchNormLayer(); ~CudnnBatchNormLayer();
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
/** /**
* reshape tensor of ioDesc_. * reshape tensor of ioDesc_.
*/ */
void reshape(int batchSize); void reshape(int batchSize);
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
protected: protected:
/** /**
......
...@@ -45,9 +45,10 @@ public: ...@@ -45,9 +45,10 @@ public:
~CudnnConvLayer(); ~CudnnConvLayer();
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
void forward(PassType passType); const ParameterMap& parameterMap) override;
void backward(const UpdateCallback& callback); void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
void addBiases(); void addBiases();
void bpropBiases(); void bpropBiases();
}; };
......
...@@ -45,7 +45,8 @@ public: ...@@ -45,7 +45,8 @@ public:
hl_pooling_mode_t* mode = nullptr); hl_pooling_mode_t* mode = nullptr);
explicit CudnnPoolLayer(const LayerConfig& config); explicit CudnnPoolLayer(const LayerConfig& config);
~CudnnPoolLayer(); ~CudnnPoolLayer();
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
/** /**
* Reshape input and output tensor descriptor. * Reshape input and output tensor descriptor.
...@@ -53,8 +54,8 @@ public: ...@@ -53,8 +54,8 @@ public:
* So reshaping is needed. * So reshaping is needed.
*/ */
void reshape(int batchSize); void reshape(int batchSize);
virtual void forward(PassType passType); void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -33,13 +33,13 @@ public: ...@@ -33,13 +33,13 @@ public:
/** /**
* Prefetch sparse matrix/ids only. * Prefetch sparse matrix/ids only.
*/ */
void prefetch() { output_ = data_; } void prefetch() override { output_ = data_; }
/** /**
* Forward propagation. Copy data_ (value, in, grad, ids, cpuSequenceDims, * Forward propagation. Copy data_ (value, in, grad, ids, cpuSequenceDims,
* sequenceStartPositions, subSequenceStartPositions, strs) to output_. * sequenceStartPositions, subSequenceStartPositions, strs) to output_.
*/ */
virtual void forward(PassType passType) { void forward(PassType passType) override {
Layer::forward(passType); Layer::forward(passType);
copyDataToOutput(output_); copyDataToOutput(output_);
if (FLAGS_show_layer_stat) { if (FLAGS_show_layer_stat) {
...@@ -50,9 +50,9 @@ public: ...@@ -50,9 +50,9 @@ public:
/** /**
* Data layer's backward propagation do nothing. * Data layer's backward propagation do nothing.
*/ */
virtual void backward(const UpdateCallback& callback) { (void)callback; } void backward(const UpdateCallback& callback) override { (void)callback; }
virtual void copyOutputToOtherDevice() { void copyOutputToOtherDevice() override {
for (size_t i = 0; i != outputOtherDevice_.size(); i++) { for (size_t i = 0; i != outputOtherDevice_.size(); i++) {
copyDataToOutput(outputOtherDevice_[i]); copyDataToOutput(outputOtherDevice_[i]);
} }
......
...@@ -44,10 +44,11 @@ public: ...@@ -44,10 +44,11 @@ public:
~DataNormLayer() {} ~DataNormLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
protected: protected:
int mode_; int mode_;
......
...@@ -27,14 +27,14 @@ class EosIdCheckLayer : public Layer { ...@@ -27,14 +27,14 @@ class EosIdCheckLayer : public Layer {
public: public:
explicit EosIdCheckLayer(const LayerConfig& config) : Layer(config) {} explicit EosIdCheckLayer(const LayerConfig& config) : Layer(config) {}
virtual bool init(const LayerMap& layerMap, bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) { const ParameterMap& parameterMap) override {
bool ret = Layer::init(layerMap, parameterMap); bool ret = Layer::init(layerMap, parameterMap);
CHECK_EQ(1UL, inputLayers_.size()); CHECK_EQ(1UL, inputLayers_.size());
return ret; return ret;
} }
virtual void forward(PassType passType) { void forward(PassType passType) override {
Layer::forward(passType); Layer::forward(passType);
const Argument& input = getInput(0); const Argument& input = getInput(0);
...@@ -42,7 +42,7 @@ public: ...@@ -42,7 +42,7 @@ public:
output_.ids->isEqualTo(*input.ids, config_.eos_id()); output_.ids->isEqualTo(*input.ids, config_.eos_id());
} }
virtual void backward(const UpdateCallback& callback) {} void backward(const UpdateCallback& callback) override {}
}; };
REGISTER_LAYER(eos_id, EosIdCheckLayer); REGISTER_LAYER(eos_id, EosIdCheckLayer);
......
...@@ -48,7 +48,8 @@ public: ...@@ -48,7 +48,8 @@ public:
~ExpandConvBaseLayer() {} ~ExpandConvBaseLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
size_t getOutputSize(); size_t getOutputSize();
/** /**
......
...@@ -35,10 +35,11 @@ public: ...@@ -35,10 +35,11 @@ public:
~ExpandConvLayer() {} ~ExpandConvLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback); void backward(const UpdateCallback& callback) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -34,10 +34,11 @@ public: ...@@ -34,10 +34,11 @@ public:
~ExpandConvTransLayer() {} ~ExpandConvTransLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback); void backward(const UpdateCallback& callback) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -53,10 +53,11 @@ public: ...@@ -53,10 +53,11 @@ public:
~ExpandLayer() {} ~ExpandLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -46,10 +46,11 @@ public: ...@@ -46,10 +46,11 @@ public:
~FeatureMapExpandLayer() {} ~FeatureMapExpandLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(featmap_expand, FeatureMapExpandLayer); REGISTER_LAYER(featmap_expand, FeatureMapExpandLayer);
......
...@@ -36,13 +36,14 @@ public: ...@@ -36,13 +36,14 @@ public:
explicit FullyConnectedLayer(const LayerConfig& config) : Layer(config) {} explicit FullyConnectedLayer(const LayerConfig& config) : Layer(config) {}
~FullyConnectedLayer() {} ~FullyConnectedLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
Weight& getWeight(int idx) { return *weights_[idx]; } Weight& getWeight(int idx) { return *weights_[idx]; }
void prefetch(); void prefetch() override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -50,17 +50,18 @@ class GatedRecurrentLayer : public Layer, public GruCompute { ...@@ -50,17 +50,18 @@ class GatedRecurrentLayer : public Layer, public GruCompute {
public: public:
explicit GatedRecurrentLayer(const LayerConfig& config) : Layer(config) {} explicit GatedRecurrentLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback); void backward(const UpdateCallback& callback) override;
void resetState(); void resetState() override;
void setState(LayerStatePtr state); void setState(LayerStatePtr state) override;
LayerStatePtr getState(); LayerStatePtr getState() override;
protected: protected:
void forwardSequence(int batchSize, void forwardSequence(int batchSize,
......
...@@ -22,17 +22,18 @@ public: ...@@ -22,17 +22,18 @@ public:
~GetOutputLayer() {} ~GetOutputLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) { bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
if (!Layer::init(layerMap, parameterMap)) return false; if (!Layer::init(layerMap, parameterMap)) return false;
CHECK_EQ(1U, inputLayers_.size()); CHECK_EQ(1U, inputLayers_.size());
CHECK_NE(inputArgument_[0], ""); CHECK_NE(inputArgument_[0], "");
return true; return true;
} }
void forward(PassType passType) { void forward(PassType passType) override {
output_ = getPrev(0)->getOutput(inputArgument_[0]); output_ = getPrev(0)->getOutput(inputArgument_[0]);
} }
void backward(const UpdateCallback& callback = nullptr) {} void backward(const UpdateCallback& callback = nullptr) override {}
}; };
REGISTER_LAYER(get_output, GetOutputLayer); REGISTER_LAYER(get_output, GetOutputLayer);
......
...@@ -55,10 +55,11 @@ public: ...@@ -55,10 +55,11 @@ public:
~GruStepLayer() {} ~GruStepLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(gru_step, GruStepLayer); REGISTER_LAYER(gru_step, GruStepLayer);
......
...@@ -61,9 +61,10 @@ class HierarchicalSigmoidLayer : public Layer { ...@@ -61,9 +61,10 @@ class HierarchicalSigmoidLayer : public Layer {
public: public:
explicit HierarchicalSigmoidLayer(const LayerConfig& config) explicit HierarchicalSigmoidLayer(const LayerConfig& config)
: Layer(config) {} : Layer(config) {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
virtual void forward(PassType passType); const ParameterMap& parameterMap) override;
virtual void backward(const UpdateCallback& callback); void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
protected: protected:
/** /**
......
...@@ -43,10 +43,11 @@ public: ...@@ -43,10 +43,11 @@ public:
~InterpolationLayer() {} ~InterpolationLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(interpolation, InterpolationLayer); REGISTER_LAYER(interpolation, InterpolationLayer);
......
...@@ -311,6 +311,7 @@ public: ...@@ -311,6 +311,7 @@ public:
return *output->second; return *output->second;
} else { } else {
LOG(FATAL) << "No specific output " << str; LOG(FATAL) << "No specific output " << str;
return *((Argument*)nullptr);
} }
} }
} }
......
...@@ -74,17 +74,18 @@ class LstmLayer : public Layer, public LstmCompute { ...@@ -74,17 +74,18 @@ class LstmLayer : public Layer, public LstmCompute {
public: public:
explicit LstmLayer(const LayerConfig &config) : Layer(config) {} explicit LstmLayer(const LayerConfig &config) : Layer(config) {}
bool init(const LayerMap &layerMap, const ParameterMap &parameterMap); bool init(const LayerMap &layerMap,
const ParameterMap &parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback &callback); void backward(const UpdateCallback &callback) override;
void resetState(); void resetState() override;
void setState(LayerStatePtr state); void setState(LayerStatePtr state) override;
LayerStatePtr getState(); LayerStatePtr getState() override;
protected: protected:
/** /**
......
...@@ -35,10 +35,11 @@ public: ...@@ -35,10 +35,11 @@ public:
~LstmStepLayer() {} ~LstmStepLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(lstm_step, LstmStepLayer); REGISTER_LAYER(lstm_step, LstmStepLayer);
......
...@@ -181,11 +181,12 @@ class MDLstmLayer : public LstmLayer { ...@@ -181,11 +181,12 @@ class MDLstmLayer : public LstmLayer {
public: public:
explicit MDLstmLayer(const LayerConfig& config) : LstmLayer(config) {} explicit MDLstmLayer(const LayerConfig& config) : LstmLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback); void backward(const UpdateCallback& callback) override;
protected: protected:
void forwardOneSequence(int start, CoordIterator& coordIter); void forwardOneSequence(int start, CoordIterator& coordIter);
......
...@@ -30,8 +30,8 @@ private: ...@@ -30,8 +30,8 @@ private:
public: public:
explicit MaxIdLayer(const LayerConfig& config) : Layer(config) {} explicit MaxIdLayer(const LayerConfig& config) : Layer(config) {}
virtual bool init(const LayerMap& layerMap, bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) { const ParameterMap& parameterMap) override {
bool ret = Layer::init(layerMap, parameterMap); bool ret = Layer::init(layerMap, parameterMap);
CHECK_EQ(1UL, inputLayers_.size()); CHECK_EQ(1UL, inputLayers_.size());
...@@ -40,7 +40,7 @@ public: ...@@ -40,7 +40,7 @@ public:
return ret; return ret;
} }
virtual void forward(PassType passType) { void forward(PassType passType) override {
Layer::forward(passType); Layer::forward(passType);
const Argument& input = getInput(0); const Argument& input = getInput(0);
size_t batchSize = input.getBatchSize(); size_t batchSize = input.getBatchSize();
...@@ -54,7 +54,7 @@ public: ...@@ -54,7 +54,7 @@ public:
input.value->rowMax(*output_.ids, *output_.in); input.value->rowMax(*output_.ids, *output_.in);
} }
virtual void backward(const UpdateCallback& callback) {} void backward(const UpdateCallback& callback) override {}
}; };
REGISTER_LAYER(maxid, MaxIdLayer); REGISTER_LAYER(maxid, MaxIdLayer);
......
...@@ -42,14 +42,13 @@ protected: ...@@ -42,14 +42,13 @@ protected:
public: public:
explicit MaxLayer(const LayerConfig& config) : SequencePoolLayer(config) {} explicit MaxLayer(const LayerConfig& config) : SequencePoolLayer(config) {}
~MaxLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) {
return SequencePoolLayer::init(layerMap, parameterMap); return SequencePoolLayer::init(layerMap, parameterMap);
} }
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -45,10 +45,11 @@ public: ...@@ -45,10 +45,11 @@ public:
explicit MaxOutLayer(const LayerConfig& config) : Layer(config) {} explicit MaxOutLayer(const LayerConfig& config) : Layer(config) {}
virtual ~MaxOutLayer() {} virtual ~MaxOutLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -35,21 +35,22 @@ public: ...@@ -35,21 +35,22 @@ public:
~MixedLayer() {} ~MixedLayer() {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
virtual void prefetch(); void prefetch() override;
virtual void forward(PassType passType); void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
virtual void resetState(); void resetState() override;
/** /**
* setState() should be called after getState(). * setState() should be called after getState().
* Argument state consists of all projections states. * Argument state consists of all projections states.
*/ */
virtual void setState(LayerStatePtr state); void setState(LayerStatePtr state) override;
/** /**
* Return state which consists of all projections states. * Return state which consists of all projections states.
*/ */
virtual LayerStatePtr getState(); LayerStatePtr getState() override;
protected: protected:
std::vector<std::unique_ptr<Projection>> projections_; std::vector<std::unique_ptr<Projection>> projections_;
......
...@@ -69,10 +69,11 @@ public: ...@@ -69,10 +69,11 @@ public:
~MultiplexLayer() {} ~MultiplexLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
private: private:
/** /**
......
...@@ -61,7 +61,8 @@ public: ...@@ -61,7 +61,8 @@ public:
rand_(0, config.num_classes() - 1), rand_(0, config.num_classes() - 1),
prepared_(false) {} prepared_(false) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) { bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
/* Initialize the basic parent class */ /* Initialize the basic parent class */
Layer::init(layerMap, parameterMap); Layer::init(layerMap, parameterMap);
...@@ -146,7 +147,7 @@ public: ...@@ -146,7 +147,7 @@ public:
prepared_ = true; prepared_ = true;
} }
void prefetch() { void prefetch() override {
prepareSamples(); prepareSamples();
IVector::resizeOrCreate(labelIds_, samples_.size(), useGpu_); IVector::resizeOrCreate(labelIds_, samples_.size(), useGpu_);
int* ids = labelIds_->getData(); int* ids = labelIds_->getData();
...@@ -163,7 +164,7 @@ public: ...@@ -163,7 +164,7 @@ public:
} }
} }
void forward(PassType passType) { void forward(PassType passType) override {
Layer::forward(passType); Layer::forward(passType);
CHECK(!useGpu_) << "GPU is not supported"; CHECK(!useGpu_) << "GPU is not supported";
...@@ -199,7 +200,7 @@ public: ...@@ -199,7 +200,7 @@ public:
forwardCost(); forwardCost();
} }
void backward(const UpdateCallback& callback) { void backward(const UpdateCallback& callback) override {
Matrix::resizeOrCreate(sampleOut_.grad, Matrix::resizeOrCreate(sampleOut_.grad,
1, 1,
samples_.size(), samples_.size(),
......
...@@ -30,7 +30,8 @@ class NormLayer : public Layer { ...@@ -30,7 +30,8 @@ class NormLayer : public Layer {
public: public:
explicit NormLayer(const LayerConfig& config) : Layer(config) {} explicit NormLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) { bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
Layer::init(layerMap, parameterMap); Layer::init(layerMap, parameterMap);
return true; return true;
} }
...@@ -56,9 +57,10 @@ protected: ...@@ -56,9 +57,10 @@ protected:
public: public:
explicit ResponseNormLayer(const LayerConfig& config) : NormLayer(config) {} explicit ResponseNormLayer(const LayerConfig& config) : NormLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
void forward(PassType passType) { LOG(FATAL) << "Not implemented"; } const ParameterMap& parameterMap) override;
void backward(const UpdateCallback& callback = nullptr) { void forward(PassType passType) override { LOG(FATAL) << "Not implemented"; }
void backward(const UpdateCallback& callback = nullptr) override {
LOG(FATAL) << "Not implemented"; LOG(FATAL) << "Not implemented";
} }
}; };
......
...@@ -36,9 +36,10 @@ public: ...@@ -36,9 +36,10 @@ public:
size_t getSize(); size_t getSize();
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
void forward(PassType passType); const ParameterMap& parameterMap) override;
void backward(const UpdateCallback& callback = nullptr); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
protected: protected:
TensorShape shape_; TensorShape shape_;
......
...@@ -38,10 +38,11 @@ public: ...@@ -38,10 +38,11 @@ public:
~OuterProdLayer() {} ~OuterProdLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(out_prod, OuterProdLayer); REGISTER_LAYER(out_prod, OuterProdLayer);
......
...@@ -29,9 +29,10 @@ public: ...@@ -29,9 +29,10 @@ public:
~PadLayer() {} ~PadLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
void forward(PassType passType); const ParameterMap& parameterMap) override;
void backward(const UpdateCallback& callback = nullptr); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
protected: protected:
void setOutDims(const size_t batchSize); void setOutDims(const size_t batchSize);
......
...@@ -56,9 +56,10 @@ public: ...@@ -56,9 +56,10 @@ public:
~ParameterReluLayer() {} ~ParameterReluLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -46,7 +46,8 @@ public: ...@@ -46,7 +46,8 @@ public:
*/ */
static Layer* create(const LayerConfig& config); static Layer* create(const LayerConfig& config);
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -40,7 +40,7 @@ public: ...@@ -40,7 +40,7 @@ public:
size_t getSize(); size_t getSize();
virtual void forward(PassType passType); void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -40,10 +40,11 @@ public: ...@@ -40,10 +40,11 @@ public:
~PowerLayer() {} ~PowerLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(power, PowerLayer); REGISTER_LAYER(power, PowerLayer);
......
...@@ -19,38 +19,17 @@ namespace paddle { ...@@ -19,38 +19,17 @@ namespace paddle {
class PrintLayer : public Layer { class PrintLayer : public Layer {
public: public:
explicit PrintLayer(const LayerConfig& config) : Layer(config) {} explicit PrintLayer(const LayerConfig& config) : Layer(config) {}
void forward(PassType passType);
void backward(const UpdateCallback& callback) {}
};
void PrintLayer::forward(PassType passType) { void forward(PassType passType) override {
Layer::forward(passType); Layer::forward(passType);
for (size_t i = 0; i != inputLayers_.size(); ++i) { for (size_t i = 0; i != inputLayers_.size(); ++i) {
const auto& argu = getInput(i); getInput(i).printValueString(LOG(INFO),
const std::string& name = inputLayers_[i]->getName(); "layer=" + inputLayers_[i]->getName() + " ");
if (argu.value) {
std::ostringstream os;
argu.value->print(os);
LOG(INFO) << "layer=" << name << " value matrix:\n" << os.str();
}
if (argu.ids) {
std::ostringstream os;
argu.ids->print(os, argu.ids->getSize());
LOG(INFO) << "layer=" << name << " ids vector:\n" << os.str();
}
if (auto startPos = argu.sequenceStartPositions) {
std::ostringstream os;
startPos->getVector(false)->print(os, startPos->getSize());
LOG(INFO) << "layer=" << name << " sequence pos vector:\n" << os.str();
}
if (auto subStartPos = argu.subSequenceStartPositions) {
std::ostringstream os;
subStartPos->getVector(false)->print(os, subStartPos->getSize());
LOG(INFO) << "layer=" << name << " sub-sequence pos vector:\n"
<< os.str();
} }
} }
}
void backward(const UpdateCallback& callback) override {}
};
REGISTER_LAYER(print, PrintLayer); REGISTER_LAYER(print, PrintLayer);
......
...@@ -30,10 +30,11 @@ namespace paddle { ...@@ -30,10 +30,11 @@ namespace paddle {
class PriorBoxLayer : public Layer { class PriorBoxLayer : public Layer {
public: public:
explicit PriorBoxLayer(const LayerConfig& config) : Layer(config) {} explicit PriorBoxLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback) {} void backward(const UpdateCallback& callback) override {}
protected: protected:
int numPriors_; int numPriors_;
......
...@@ -45,17 +45,18 @@ class RecurrentLayer : public Layer { ...@@ -45,17 +45,18 @@ class RecurrentLayer : public Layer {
public: public:
explicit RecurrentLayer(const LayerConfig& config) : Layer(config) {} explicit RecurrentLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback); void backward(const UpdateCallback& callback) override;
void resetState(); void resetState() override;
void setState(LayerStatePtr state); void setState(LayerStatePtr state) override;
LayerStatePtr getState(); LayerStatePtr getState() override;
protected: protected:
/** /**
......
...@@ -33,15 +33,15 @@ public: ...@@ -33,15 +33,15 @@ public:
void initSubNetwork(NeuralNetwork* rootNetwork, void initSubNetwork(NeuralNetwork* rootNetwork,
const ModelConfig& config, const ModelConfig& config,
const std::vector<ParameterType>& parameterTypes, const std::vector<ParameterType>& parameterTypes,
bool useGpu); bool useGpu) override;
void forward(PassType passType) { void forward(PassType passType) override {
REGISTER_TIMER_INFO("RecurrentGroupFwTime", getName().c_str()); REGISTER_TIMER_INFO("RecurrentGroupFwTime", getName().c_str());
const std::vector<Argument> inArgs; const std::vector<Argument> inArgs;
std::vector<Argument> outArgs; std::vector<Argument> outArgs;
network_->forward(inArgs, &outArgs, passType); network_->forward(inArgs, &outArgs, passType);
} }
void backward(const UpdateCallback& callback) { void backward(const UpdateCallback& callback) override {
REGISTER_TIMER_INFO("RecurrentGroupBwTime", getName().c_str()); REGISTER_TIMER_INFO("RecurrentGroupBwTime", getName().c_str());
network_->backward(nullptr); network_->backward(nullptr);
...@@ -53,7 +53,8 @@ public: ...@@ -53,7 +53,8 @@ public:
/** /**
* @see Layer.accessSubNetwork * @see Layer.accessSubNetwork
*/ */
void accessSubNetwork(const std::function<void(NeuralNetwork&)>& callback) { void accessSubNetwork(
const std::function<void(NeuralNetwork&)>& callback) override {
callback(*network_); callback(*network_);
} }
......
...@@ -20,18 +20,19 @@ namespace paddle { ...@@ -20,18 +20,19 @@ namespace paddle {
/** /**
* @brief A layer for resizing a minibatch matrix h*w to h'*w' * @brief A layer for resizing a minibatch matrix h*w to h'*w'
* @note * @note
* origin matrix height * witdth) * origin matrix height * width)
* resize matrix: (height * width / size) * size * resize matrix: (height * width / size) * size
*/ */
class ResizeLayer : public Layer { class ResizeLayer : public Layer {
public: public:
explicit ResizeLayer(const LayerConfig& config) : Layer(config) {} explicit ResizeLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback); void backward(const UpdateCallback& callback) override;
}; };
REGISTER_LAYER(resize, ResizeLayer); REGISTER_LAYER(resize, ResizeLayer);
......
...@@ -35,8 +35,8 @@ public: ...@@ -35,8 +35,8 @@ public:
explicit SamplingIdLayer(const LayerConfig& config) explicit SamplingIdLayer(const LayerConfig& config)
: Layer(config), rand1_(0, 1) {} : Layer(config), rand1_(0, 1) {}
virtual bool init(const LayerMap& layerMap, bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) { const ParameterMap& parameterMap) override {
bool ret = Layer::init(layerMap, parameterMap); bool ret = Layer::init(layerMap, parameterMap);
CHECK_EQ(1UL, inputLayers_.size()); CHECK_EQ(1UL, inputLayers_.size());
if (useGpu_) { if (useGpu_) {
...@@ -48,7 +48,7 @@ public: ...@@ -48,7 +48,7 @@ public:
return ret; return ret;
} }
void forward(PassType passType) { void forward(PassType passType) override {
Layer::forward(passType); Layer::forward(passType);
if (useGpu_) { if (useGpu_) {
for (size_t i = 0; i < inputLayers_.size(); i++) { for (size_t i = 0; i < inputLayers_.size(); i++) {
...@@ -83,7 +83,7 @@ public: ...@@ -83,7 +83,7 @@ public:
output_.ids->copyFrom(ids.data(), batchSize); output_.ids->copyFrom(ids.data(), batchSize);
} }
virtual void backward(const UpdateCallback& callback) {} void backward(const UpdateCallback& callback) override {}
}; };
REGISTER_LAYER(sampling_id, SamplingIdLayer); REGISTER_LAYER(sampling_id, SamplingIdLayer);
......
...@@ -37,10 +37,11 @@ public: ...@@ -37,10 +37,11 @@ public:
~ScalingLayer() {} ~ScalingLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(scaling, ScalingLayer); REGISTER_LAYER(scaling, ScalingLayer);
......
...@@ -65,9 +65,10 @@ public: ...@@ -65,9 +65,10 @@ public:
: Layer(config), selCols_(nullptr) {} : Layer(config), selCols_(nullptr) {}
~SelectiveFullyConnectedLayer() {} ~SelectiveFullyConnectedLayer() {}
void prefetch(); void prefetch() override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
Weight& getWeight(int idx) { return *weights_[idx]; } Weight& getWeight(int idx) { return *weights_[idx]; }
...@@ -90,8 +91,8 @@ public: ...@@ -90,8 +91,8 @@ public:
void fillSelectiveData( void fillSelectiveData(
const std::shared_ptr<std::vector<std::pair<int*, size_t>>>& candidates); const std::shared_ptr<std::vector<std::pair<int*, size_t>>>& candidates);
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
private: private:
/** /**
......
...@@ -21,9 +21,11 @@ namespace paddle { ...@@ -21,9 +21,11 @@ namespace paddle {
/** /**
* A layer for concatenating the first sequence with the second sequence * A layer for concatenating the first sequence with the second sequence
* following the first * Input: two sequences each containing the same number of instances
* Input: two sequences each containing some instances * seq1 = [a1, a2, ..., an]
* seq2 = [b1, b2, ..., bn]
* Output: a concatenated sequence of the two input sequences * Output: a concatenated sequence of the two input sequences
* out = [a1, b1, a2, b2, ..., an, bn]
*/ */
class SequenceConcatLayer : public Layer { class SequenceConcatLayer : public Layer {
...@@ -35,10 +37,11 @@ public: ...@@ -35,10 +37,11 @@ public:
~SequenceConcatLayer() {} ~SequenceConcatLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(seqconcat, SequenceConcatLayer); REGISTER_LAYER(seqconcat, SequenceConcatLayer);
...@@ -167,13 +170,17 @@ void SequenceConcatLayer::backward(const UpdateCallback& callback) { ...@@ -167,13 +170,17 @@ void SequenceConcatLayer::backward(const UpdateCallback& callback) {
size_t rightNumIns = 0; size_t rightNumIns = 0;
for (size_t seqId = 0; seqId < numSequences1; ++seqId) { for (size_t seqId = 0; seqId < numSequences1; ++seqId) {
leftNumIns = starts1[seqId + 1] - starts1[seqId]; leftNumIns = starts1[seqId + 1] - starts1[seqId];
inputGrad1->subMatrix(starts1[seqId], leftNumIns) if (inputGrad1) {
->add(*(outputGrad->subMatrix(offset, leftNumIns))); inputGrad1->subMatrix(starts1[seqId], leftNumIns)
->add(*(outputGrad->subMatrix(offset, leftNumIns)));
}
offset += leftNumIns; offset += leftNumIns;
rightNumIns = starts2[seqId + 1] - starts2[seqId]; rightNumIns = starts2[seqId + 1] - starts2[seqId];
inputGrad2->subMatrix(starts2[seqId], rightNumIns) if (inputGrad2) {
->add(*(outputGrad->subMatrix(offset, rightNumIns))); inputGrad2->subMatrix(starts2[seqId], rightNumIns)
->add(*(outputGrad->subMatrix(offset, rightNumIns)));
}
offset += rightNumIns; offset += rightNumIns;
} }
} }
......
...@@ -42,12 +42,11 @@ public: ...@@ -42,12 +42,11 @@ public:
explicit SequenceLastInstanceLayer(const LayerConfig& config) explicit SequenceLastInstanceLayer(const LayerConfig& config)
: SequencePoolLayer(config) {} : SequencePoolLayer(config) {}
~SequenceLastInstanceLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
}; };
REGISTER_LAYER(seqlastins, SequenceLastInstanceLayer); REGISTER_LAYER(seqlastins, SequenceLastInstanceLayer);
......
...@@ -46,12 +46,11 @@ protected: ...@@ -46,12 +46,11 @@ protected:
public: public:
explicit SequencePoolLayer(const LayerConfig& config) : Layer(config) {} explicit SequencePoolLayer(const LayerConfig& config) : Layer(config) {}
virtual ~SequencePoolLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
}; };
} // namespace paddle } // namespace paddle
...@@ -20,9 +20,12 @@ limitations under the License. */ ...@@ -20,9 +20,12 @@ limitations under the License. */
namespace paddle { namespace paddle {
/** /**
* A layer for reshaping the sequence * A layer for reshaping the sequence. Assume the input sequence has
* Input: a sequence * T instances, the dimension of each instance is M, and the input
* Output: a sequence * reshape_dim is N, then the output sequence has T*M/N instances,
* the dimension of each instance is N.
*
* Note that T*M/N must be an integer.
*/ */
class SequenceReshapeLayer : public Layer { class SequenceReshapeLayer : public Layer {
...@@ -34,12 +37,11 @@ protected: ...@@ -34,12 +37,11 @@ protected:
public: public:
explicit SequenceReshapeLayer(const LayerConfig& config) : Layer(config) {} explicit SequenceReshapeLayer(const LayerConfig& config) : Layer(config) {}
~SequenceReshapeLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
}; };
REGISTER_LAYER(seqreshape, SequenceReshapeLayer); REGISTER_LAYER(seqreshape, SequenceReshapeLayer);
......
...@@ -39,12 +39,11 @@ class SlopeInterceptLayer : public Layer { ...@@ -39,12 +39,11 @@ class SlopeInterceptLayer : public Layer {
public: public:
explicit SlopeInterceptLayer(const LayerConfig& config) : Layer(config) {} explicit SlopeInterceptLayer(const LayerConfig& config) : Layer(config) {}
~SlopeInterceptLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
}; };
REGISTER_LAYER(slope_intercept, SlopeInterceptLayer); REGISTER_LAYER(slope_intercept, SlopeInterceptLayer);
......
...@@ -43,9 +43,8 @@ protected: ...@@ -43,9 +43,8 @@ protected:
public: public:
explicit SpatialPyramidPoolLayer(const LayerConfig& config) : Layer(config) {} explicit SpatialPyramidPoolLayer(const LayerConfig& config) : Layer(config) {}
~SpatialPyramidPoolLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
ProjectionConfig getConfig(size_t sizeX_, ProjectionConfig getConfig(size_t sizeX_,
size_t sizeY_, size_t sizeY_,
...@@ -54,7 +53,7 @@ public: ...@@ -54,7 +53,7 @@ public:
std::string& poolType_); std::string& poolType_);
size_t getSize(); size_t getSize();
virtual void forward(PassType passType); void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -35,12 +35,11 @@ protected: ...@@ -35,12 +35,11 @@ protected:
public: public:
explicit SubSequenceLayer(const LayerConfig& config) : Layer(config) {} explicit SubSequenceLayer(const LayerConfig& config) : Layer(config) {}
~SubSequenceLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
}; };
REGISTER_LAYER(subseq, SubSequenceLayer); REGISTER_LAYER(subseq, SubSequenceLayer);
......
...@@ -41,12 +41,11 @@ protected: ...@@ -41,12 +41,11 @@ protected:
public: public:
explicit SumToOneNormLayer(const LayerConfig& config) : Layer(config) {} explicit SumToOneNormLayer(const LayerConfig& config) : Layer(config) {}
~SumToOneNormLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
}; };
REGISTER_LAYER(sum_to_one_norm, SumToOneNormLayer); REGISTER_LAYER(sum_to_one_norm, SumToOneNormLayer);
......
...@@ -44,13 +44,12 @@ protected: ...@@ -44,13 +44,12 @@ protected:
public: public:
explicit TensorLayer(const LayerConfig& config) : Layer(config) {} explicit TensorLayer(const LayerConfig& config) : Layer(config) {}
~TensorLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
Weight& getWeight(int idx) { return *weights_[idx]; } Weight& getWeight(int idx) { return *weights_[idx]; }
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -32,9 +32,10 @@ class TransLayer : public Layer { ...@@ -32,9 +32,10 @@ class TransLayer : public Layer {
public: public:
explicit TransLayer(const LayerConfig& config) : Layer(config) {} explicit TransLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -26,7 +26,8 @@ class ValidationLayer : public Layer { ...@@ -26,7 +26,8 @@ class ValidationLayer : public Layer {
public: public:
explicit ValidationLayer(const LayerConfig& config) : Layer(config) {} explicit ValidationLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
LayerPtr getOutputLayer() { return inputLayers_[0]; } LayerPtr getOutputLayer() { return inputLayers_[0]; }
...@@ -37,13 +38,13 @@ public: ...@@ -37,13 +38,13 @@ public:
return inputLayers_[2]; return inputLayers_[2];
} }
virtual void forward(PassType passType); void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
virtual void validationImp(MatrixPtr outputValue, IVectorPtr label) = 0; virtual void validationImp(MatrixPtr outputValue, IVectorPtr label) = 0;
virtual void onPassEnd() = 0; void onPassEnd() override = 0;
}; };
/* /*
...@@ -57,11 +58,12 @@ public: ...@@ -57,11 +58,12 @@ public:
cpuLabel_(nullptr), cpuLabel_(nullptr),
cpuWeight_(nullptr) {} cpuWeight_(nullptr) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void validationImp(MatrixPtr outputValue, IVectorPtr label); void validationImp(MatrixPtr outputValue, IVectorPtr label) override;
void onPassEnd(); void onPassEnd() override;
struct PredictionResult { struct PredictionResult {
PredictionResult(real __out, int __label) : out(__out), label(__label) {} PredictionResult(real __out, int __label) : out(__out), label(__label) {}
...@@ -86,11 +88,12 @@ public: ...@@ -86,11 +88,12 @@ public:
explicit PnpairValidation(const LayerConfig& config) explicit PnpairValidation(const LayerConfig& config)
: ValidationLayer(config) {} : ValidationLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void validationImp(MatrixPtr outputValue, IVectorPtr label); void validationImp(MatrixPtr outputValue, IVectorPtr label) override;
void onPassEnd(); void onPassEnd() override;
private: private:
bool passBegin_; bool passBegin_;
......
...@@ -30,9 +30,10 @@ public: ...@@ -30,9 +30,10 @@ public:
explicit WarpCTCLayer(const LayerConfig& config) : Layer(config) {} explicit WarpCTCLayer(const LayerConfig& config) : Layer(config) {}
~WarpCTCLayer() {} ~WarpCTCLayer() {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
virtual void forward(PassType passType); const ParameterMap& parameterMap) override;
virtual void backward(const UpdateCallback& callback); void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
protected: protected:
/** /**
......
...@@ -110,6 +110,18 @@ void testEvaluator(TestConfig testConf, ...@@ -110,6 +110,18 @@ void testEvaluator(TestConfig testConf,
testEvaluator->finish(); testEvaluator->finish();
LOG(INFO) << *testEvaluator; LOG(INFO) << *testEvaluator;
std::vector<std::string> names;
testEvaluator->getNames(&names);
paddle::Error err;
for (auto& name : names) {
auto value = testEvaluator->getValue(name, &err);
ASSERT_TRUE(err.isOK());
LOG(INFO) << name << " " << value;
auto tp = testEvaluator->getType(name, &err);
ASSERT_TRUE(err.isOK());
ASSERT_EQ(testConf.evaluatorConfig.type(), tp);
}
double totalScore2 = 0.0; double totalScore2 = 0.0;
if (testConf.testAccumulate) { if (testConf.testAccumulate) {
testEvaluator->start(); testEvaluator->start();
...@@ -129,6 +141,7 @@ void testEvaluatorAll(TestConfig testConf, ...@@ -129,6 +141,7 @@ void testEvaluatorAll(TestConfig testConf,
TEST(Evaluator, classification_error) { TEST(Evaluator, classification_error) {
TestConfig config; TestConfig config;
config.evaluatorConfig.set_type("classification_error"); config.evaluatorConfig.set_type("classification_error");
config.evaluatorConfig.set_top_k(5);
config.inputDefs.push_back({INPUT_DATA, "output", 50}); config.inputDefs.push_back({INPUT_DATA, "output", 50});
config.inputDefs.push_back({INPUT_LABEL, "label", 50}); config.inputDefs.push_back({INPUT_LABEL, "label", 50});
......
...@@ -732,6 +732,7 @@ void GpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) { ...@@ -732,6 +732,7 @@ void GpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
size_t beam = maxVal.getWidth(); size_t beam = maxVal.getWidth();
CHECK_EQ(maxIds.getSize(), numSamples * beam); CHECK_EQ(maxIds.getSize(), numSamples * beam);
CHECK_EQ(maxVal.getHeight(), numSamples); CHECK_EQ(maxVal.getHeight(), numSamples);
CHECK_EQ(maxVal.getWidth(), beam);
hl_matrix_top_k(maxVal.getData(), hl_matrix_top_k(maxVal.getData(),
maxVal.getStride(), maxVal.getStride(),
...@@ -792,19 +793,32 @@ void GpuMatrix::maxoutBackward(Matrix& a, ...@@ -792,19 +793,32 @@ void GpuMatrix::maxoutBackward(Matrix& a,
} }
/* calculate the error of classification */ /* calculate the error of classification */
void GpuMatrix::classificationError(Matrix& output, IVector& label) { void GpuMatrix::classificationError(Matrix& output,
auto output_ptr = dynamic_cast<const GpuMatrix*>(&output); IVector& label,
auto label_ptr = dynamic_cast<const GpuIVector*>(&label); size_t topkSize) {
CHECK(output_ptr && label_ptr) << "Invalid argument pointer"; auto gpuOutput = dynamic_cast<GpuMatrix*>(&output);
auto gpuLabel = dynamic_cast<GpuIVector*>(&label);
CHECK(height_ == output_ptr->height_ && width_ == 1) size_t numSamples = this->getHeight();
GpuMatrixPtr gpuTopVal = std::make_shared<GpuMatrix>(numSamples, topkSize);
GpuIVectorPtr gpuTopIds = std::make_shared<GpuIVector>(numSamples * topkSize);
CHECK(gpuOutput && gpuLabel) << "Invalid argument pointer";
CHECK(gpuTopVal && gpuTopIds) << "Allocate GPU memory failed";
CHECK(gpuLabel->getSize() == numSamples) << "Vector size is not equal";
CHECK(numSamples == gpuOutput->getHeight() && this->getWidth() == 1)
<< "Matrix dimensions are not equal"; << "Matrix dimensions are not equal";
hl_matrix_classification_error((real*)output_ptr->data_, size_t dim = gpuOutput->getWidth();
(int*)label_ptr->getData(), hl_matrix_classification_error(gpuTopVal->getData(),
data_, gpuTopVal->getStride(),
height_, gpuTopIds->getData(),
output_ptr->width_); gpuOutput->getData(),
gpuOutput->getStride(),
dim,
topkSize,
numSamples,
gpuLabel->getData(),
this->getData());
} }
/* copy -log(output[i * width + label]) to this->data[i] */ /* copy -log(output[i * width + label]) to this->data[i] */
...@@ -941,59 +955,6 @@ void GpuMatrix::softreluDerivative(Matrix& output) { ...@@ -941,59 +955,6 @@ void GpuMatrix::softreluDerivative(Matrix& output) {
void GpuMatrix::scaledTanh(Matrix& output, real p1, real p2) { void GpuMatrix::scaledTanh(Matrix& output, real p1, real p2) {
BaseMatrix::scaledTanh(output, p1, p2); BaseMatrix::scaledTanh(output, p1, p2);
} }
/**
 * Row-wise cosine similarity on the GPU.
 *
 * `this` must be a column vector (width 1) with one entry per row of
 * `output1`; `output1` and `output2` must have the same width. The actual
 * computation is delegated to hl_cossim, which receives the heights of both
 * inputs together with `scale`.
 */
void GpuMatrix::cosSim(Matrix& output1, Matrix& output2, real scale) {
  CHECK(output1.useGpu_ == true && output2.useGpu_ == true)
      << "Matrix type are not equal";

  // Shape validation: one result per row of output1, equal feature widths.
  const size_t rowCount = getHeight();
  const size_t featureDim = output1.getWidth();
  CHECK_EQ(getWidth(), 1UL);
  CHECK_EQ(output1.getHeight(), rowCount);
  CHECK_EQ(output1.getWidth(), output2.getWidth());

  // output2's height is not checked here; it is handed to the kernel as-is.
  hl_cossim(getData(),
            output1.getData(),
            output2.getData(),
            featureDim,
            output1.getHeight(),
            output2.getHeight(),
            scale);
}
/**
 * Backward pass of the row-wise cosine similarity on the GPU.
 *
 * `this` holds the incoming gradient and `output` the forward result; both
 * must be column vectors (width 1) with the same height. prevOut1/prevOut2
 * are the forward inputs and prevGrad1/prevGrad2 the gradient buffers passed
 * to the kernel; all four must share the same width. Every matrix must live
 * on the GPU. The computation itself is delegated to hl_cossim_derivative.
 */
void GpuMatrix::cosSimDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2,
real scale) {
// All operands must be GPU matrices.
CHECK(output.useGpu_ == true && prevOut1.useGpu_ == true &&
prevOut2.useGpu_ == true && prevGrad1.useGpu_ == true &&
prevGrad2.useGpu_ == true)
<< "Matrix type are not equal";
// Gradient and forward output are column vectors with one entry per sample.
CHECK_EQ(getWidth(), 1UL);
CHECK_EQ(output.getWidth(), 1UL);
size_t numSamples = getHeight();
CHECK_EQ(output.getHeight(), numSamples);
CHECK_EQ(prevOut1.getHeight(), numSamples);
CHECK_EQ(prevGrad1.getHeight(), numSamples);
// All input/gradient matrices share the feature dimension of prevOut1.
size_t dim = prevOut1.getWidth();
CHECK_EQ(prevOut2.getWidth(), dim);
CHECK_EQ(prevGrad1.getWidth(), dim);
CHECK_EQ(prevGrad2.getWidth(), dim);
real* grad = getData();
real* out = output.getData();
real* prevOutX = prevOut1.getData();
real* prevOutY = prevOut2.getData();
real* prevGradX = prevGrad1.getData();
real* prevGradY = prevGrad2.getData();
// NOTE(review): the heights of prevOut2/prevGrad2 are not validated here;
// prevOut2's height is forwarded to the kernel unchanged.
hl_cossim_derivative(grad,
out,
prevOutX,
prevOutY,
prevGradX,
prevGradY,
dim,
prevOut1.getHeight(),
prevOut2.getHeight(),
scale);
}
void GpuMatrix::randomizeUniform() { void GpuMatrix::randomizeUniform() {
CHECK(isContiguous()); CHECK(isContiguous());
...@@ -3092,7 +3053,7 @@ void CpuMatrix::rowMax(Matrix& max) { ...@@ -3092,7 +3053,7 @@ void CpuMatrix::rowMax(Matrix& max) {
max.maxRows(*this); max.maxRows(*this);
} }
/* get beam size of max ids and values */ /* Get the top k elements of each row of this matrix */
void CpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) { void CpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
CHECK(isContiguous()); CHECK(isContiguous());
CHECK(!maxIds.useGpu() && !maxVal.useGpu()) << "Matrix type are not equal"; CHECK(!maxIds.useGpu() && !maxVal.useGpu()) << "Matrix type are not equal";
...@@ -3100,6 +3061,7 @@ void CpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) { ...@@ -3100,6 +3061,7 @@ void CpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
size_t beam = maxVal.getWidth(); size_t beam = maxVal.getWidth();
CHECK_EQ(maxIds.getSize(), numSamples * beam); CHECK_EQ(maxIds.getSize(), numSamples * beam);
CHECK_EQ(maxVal.getHeight(), numSamples); CHECK_EQ(maxVal.getHeight(), numSamples);
CHECK_EQ(maxVal.getWidth(), beam);
real* a = getData(); real* a = getData();
int* s = maxIds.getData(); int* s = maxIds.getData();
...@@ -3251,32 +3213,39 @@ void CpuMatrix::rowNormalizeL1(Matrix& out) { ...@@ -3251,32 +3213,39 @@ void CpuMatrix::rowNormalizeL1(Matrix& out) {
} }
/* calculate classification error */ /* calculate classification error */
void CpuMatrix::classificationError(Matrix& output, IVector& label) { void CpuMatrix::classificationError(Matrix& output,
CHECK(dynamic_cast<const CpuMatrix*>(&output)); IVector& label,
CHECK(dynamic_cast<const CpuIVector*>(&label)); size_t topkSize) {
size_t numSamples = this->getHeight();
auto cpuOutput = dynamic_cast<CpuMatrix*>(&output);
auto cpuLabel = dynamic_cast<CpuIVector*>(&label);
IVectorPtr cpuTopIds = std::make_shared<CpuIVector>(numSamples * topkSize);
MatrixPtr cpuTopVal = std::make_shared<CpuMatrix>(numSamples, topkSize);
CHECK(cpuOutput && cpuLabel) << "Invalid argument pointer";
CHECK(cpuTopIds && cpuTopVal) << "Allocate cpu memory failed";
CHECK(cpuLabel->getSize() == numSamples) << "Vector size is not equal";
CHECK(cpuOutput->getHeight() == numSamples && this->getWidth() == 1)
<< "Matrix dimensions are not equal";
CHECK_EQ(getWidth(), (size_t)1); // top k matrix classification
size_t numSamples = getHeight(); cpuOutput->rowMax(*cpuTopIds, *cpuTopVal);
CHECK_EQ(label.getSize(), numSamples);
CHECK_EQ(output.getHeight(), numSamples);
size_t dim = output.getWidth(); size_t dim = cpuOutput->getWidth();
real* out = output.getData(); real* result = this->getData();
int* lbl = label.getData(); int* ids = cpuTopIds->getData();
real maxData = 0.0; int* lbl = cpuLabel->getData();
int maxIndex = -1;
for (size_t i = 0; i < numSamples; ++i) { for (size_t i = 0; i < numSamples; ++i) {
CHECK_GE(lbl[i], 0); CHECK_GE(lbl[i], 0);
CHECK_LT((size_t)lbl[i], dim); CHECK_LT((size_t)lbl[i], dim);
maxData = out[i * dim];
maxIndex = 0; for (size_t j = 0; j < topkSize; ++j) {
for (size_t j = 0; j < dim; ++j) { if (ids[j + i * topkSize] == lbl[i]) {
if (maxData < out[i * dim + j]) { result[i] = 0;
maxIndex = j; break;
maxData = out[i * dim + j];
} }
result[i] = 1.0f;
} }
getData()[i] = (maxIndex != lbl[i]);
} }
} }
...@@ -3470,105 +3439,6 @@ void CpuMatrix::softmaxDerivative(Matrix& output, Matrix& sftmaxSum) { ...@@ -3470,105 +3439,6 @@ void CpuMatrix::softmaxDerivative(Matrix& output, Matrix& sftmaxSum) {
} }
} }
/**
 * Row-wise cosine similarity on the CPU.
 *
 * For each row i: this[i] = scale * <x_i, y_i> / (|x_i| * |y_i|), where x_i
 * is row i of output1 and y_i is row i of output2. If output2 has exactly one
 * row, that row is reused for every i. `this` must have width 1 and as many
 * rows as output1; output1 and output2 must have equal width. A row with zero
 * squared norm in either input triggers a CHECK failure.
 */
void CpuMatrix::cosSim(Matrix& output1, Matrix& output2, real scale) {
size_t numSamples = getHeight();
size_t dim = output1.getWidth();
CHECK_EQ(getWidth(), 1UL);
CHECK_EQ(output1.getHeight(), numSamples);
CHECK_EQ(output1.getWidth(), output2.getWidth());
real* out = getData();
const real* x = output1.getData();
const real* y = output2.getData();
// Stride for y: 0 keeps pointing at the single row of output2 (broadcast),
// otherwise y advances one row per sample like x.
size_t yInc = dim;
if (output2.getHeight() == 1LU) {
yInc = 0;
} else {
CHECK_EQ(output2.getHeight(), numSamples);
}
for (size_t i = 0; i < numSamples; ++i, x += dim, y += yInc) {
// Accumulate |x|^2, |y|^2 and the dot product in a single pass.
real squareSumX = 0;
real squareSumY = 0;
real xy = 0;
for (size_t j = 0; j < dim; ++j) {
squareSumX += _square(x[j]);
squareSumY += _square(y[j]);
xy += x[j] * y[j];
}
// Guards the division below against zero-norm rows.
CHECK(squareSumX > 0 && squareSumY > 0);
out[i] = scale * xy / (std::sqrt(squareSumX) * std::sqrt(squareSumY));
}
}
/**
 * Backward pass of the row-wise cosine similarity on the CPU.
 *
 * `this` is the incoming gradient (width 1) and `output` the forward result
 * (width 1). Gradients are accumulated (+=) into prevGrad1 and prevGrad2.
 * prevOut2 may have a single row, in which case it (and prevGrad2) is shared
 * by every sample. A row with zero squared norm triggers a CHECK failure.
 */
void CpuMatrix::cosSimDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2,
real scale) {
// NOTE(review): only `output` is checked for being a CPU matrix here.
CHECK(output.useGpu_ == false) << "Matrix type are not equal";
CHECK_EQ(getWidth(), 1UL);
CHECK_EQ(output.getWidth(), 1UL);
size_t numSamples = getHeight();
CHECK_EQ(output.getHeight(), numSamples);
CHECK_EQ(prevOut1.getHeight(), numSamples);
CHECK_EQ(prevGrad1.getHeight(), numSamples);
size_t dim = prevOut1.getWidth();
CHECK_EQ(prevOut2.getWidth(), dim);
CHECK_EQ(prevGrad1.getWidth(), dim);
CHECK_EQ(prevGrad2.getWidth(), dim);
const real* grad = getData();
const real* out = output.getData();
const real* prevOutX = prevOut1.getData();
const real* prevOutY = prevOut2.getData();
real* prevGradX = prevGrad1.getData();
real* prevGradY = prevGrad2.getData();
// Stride for the Y side: 0 when prevOut2 is a single shared row, so both
// prevOutY and prevGradY stay on that row for every sample.
size_t yInc = dim;
if (prevOut2.getHeight() == 1LU) {
yInc = 0;
CHECK_EQ(prevGrad2.getHeight(), 1LU);
} else {
CHECK_EQ(prevOut2.getHeight(), numSamples);
CHECK_EQ(prevGrad2.getHeight(), numSamples);
}
for (size_t i = 0; i < numSamples; ++i,
prevOutX += dim,
prevOutY += yInc,
prevGradX += dim,
prevGradY += yInc) {
// Recompute |x|^2, |y|^2 and <x, y> for this sample.
real squareSumX = 0;
real squareSumY = 0;
real xy = 0;
for (size_t j = 0; j < dim; ++j) {
squareSumX += _square(prevOutX[j]);
squareSumY += _square(prevOutY[j]);
xy += prevOutX[j] * prevOutY[j];
}
CHECK(squareSumX > 0 && squareSumY > 0);
if (xy == 0) {
// Dot product is exactly zero: the general branch would divide by xy, so
// use the form that only needs 1 / (|x| * |y|) and the explicit scale.
real reciprocal = 1.0f / (std::sqrt(squareSumX) * std::sqrt(squareSumY));
for (size_t j = 0; j < dim; ++j) {
prevGradX[j] += scale * grad[i] * prevOutY[j] * reciprocal;
prevGradY[j] += scale * grad[i] * prevOutX[j] * reciprocal;
}
} else {
// General case, expressed through the forward output out[i]
// (this branch does not use `scale` directly).
real reciprocalXY = 1.0f / xy;
real reciprocalSquareSumX = 1.0f / squareSumX;
real reciprocalSquareSumY = 1.0f / squareSumY;
for (size_t j = 0; j < dim; ++j) {
prevGradX[j] += out[i] * grad[i] * (prevOutY[j] * reciprocalXY -
prevOutX[j] * reciprocalSquareSumX);
prevGradY[j] += out[i] * grad[i] * (prevOutX[j] * reciprocalXY -
prevOutY[j] * reciprocalSquareSumY);
}
}
}
}
void CpuMatrix::sumOfSquares(Matrix& output, Matrix& label) { void CpuMatrix::sumOfSquares(Matrix& output, Matrix& label) {
CHECK(output.useGpu_ == false && label.useGpu_ == false) CHECK(output.useGpu_ == false && label.useGpu_ == false)
<< "Matrix type are not equal"; << "Matrix type are not equal";
......
...@@ -799,26 +799,6 @@ public: ...@@ -799,26 +799,6 @@ public:
LOG(FATAL) << "Not implemented"; LOG(FATAL) << "Not implemented";
} }
/**
* cosine similarity, for each row i,
* this[i] = cos(output1[i], output2[i])
*
* output2 can only have one row, then for each row i,
* this[i] = cos(output1[i], output2[0])
*/
virtual void cosSim(Matrix& output1, Matrix& output2, real scale = 1.0f) {
LOG(FATAL) << "Not implemented";
}
virtual void cosSimDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2,
real scale = 1.0f) {
LOG(FATAL) << "Not implemented";
}
/// print out the values of elements to os /// print out the values of elements to os
virtual void print(std::ostream& os) const { virtual void print(std::ostream& os) const {
LOG(FATAL) << "Not implemented"; LOG(FATAL) << "Not implemented";
...@@ -856,8 +836,11 @@ public: ...@@ -856,8 +836,11 @@ public:
* output[i] = 1 if row i is an error. * output[i] = 1 if row i is an error.
* *
* output[i] = 0 if row i is correct. * output[i] = 0 if row i is correct.
*
*/ */
virtual void classificationError(Matrix& output, IVector& label) { virtual void classificationError(Matrix& output,
IVector& label,
size_t topkSize = 1) {
LOG(FATAL) << "Not implemented"; LOG(FATAL) << "Not implemented";
} }
...@@ -1324,14 +1307,6 @@ public: ...@@ -1324,14 +1307,6 @@ public:
void softreluDerivative(Matrix& output); void softreluDerivative(Matrix& output);
void scaledTanh(Matrix& output, real p1, real p2); void scaledTanh(Matrix& output, real p1, real p2);
void cosSim(Matrix& output1, Matrix& output2, real scale);
void cosSimDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2,
real scale);
virtual void print(std::ostream& os) const; virtual void print(std::ostream& os) const;
virtual void print(std::ostream& os, size_t height, size_t width) const; virtual void print(std::ostream& os, size_t height, size_t width) const;
...@@ -1342,7 +1317,7 @@ public: ...@@ -1342,7 +1317,7 @@ public:
void check(std::ostream& os, Matrix& refMat, bool printDiff = true); void check(std::ostream& os, Matrix& refMat, bool printDiff = true);
void randomizeUniform(); void randomizeUniform();
void classificationError(Matrix& output, IVector& label); void classificationError(Matrix& output, IVector& label, size_t topkSize = 1);
void convExpand(Matrix& feature, void convExpand(Matrix& feature,
int feaImgHeight, int feaImgHeight,
...@@ -1752,14 +1727,6 @@ public: ...@@ -1752,14 +1727,6 @@ public:
void softreluDerivative(Matrix& output); void softreluDerivative(Matrix& output);
void scaledTanh(Matrix& output, real p1, real p2); void scaledTanh(Matrix& output, real p1, real p2);
void cosSim(Matrix& output1, Matrix& output2, real scale);
void cosSimDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2,
real scale);
void print(std::ostream& os) const; void print(std::ostream& os) const;
void print(std::ostream& os, size_t height, size_t width) const; void print(std::ostream& os, size_t height, size_t width) const;
void printOneRow(std::ostream& os, size_t idx) const; void printOneRow(std::ostream& os, size_t idx) const;
...@@ -1775,7 +1742,7 @@ public: ...@@ -1775,7 +1742,7 @@ public:
void randomizeUniform(); void randomizeUniform();
void classificationError(Matrix& output, IVector& label); void classificationError(Matrix& output, IVector& label, size_t topkSize = 1);
void addByBitCode(size_t numClasses, const IVector& codes, const Matrix& vec); void addByBitCode(size_t numClasses, const IVector& codes, const Matrix& vec);
......
...@@ -181,28 +181,6 @@ TEST(Matrix, copyByRowIndex) { ...@@ -181,28 +181,6 @@ TEST(Matrix, copyByRowIndex) {
} }
} }
// Exercises Matrix::cosSim through AutoCompare with a (heightX x width) first
// argument and a (heightY x width) second argument; the result matrix is
// heightX x 1.
// NOTE(review): AutoCompare presumably runs the member function on both the
// CPU and GPU implementations and compares the results -- confirm against its
// definition.
void testCosSim(int heightX, int heightY, int width, real scale) {
AutoCompare test(heightX, 1);
CpuMatrix arg1(heightX, width);
CpuMatrix arg2(heightY, width);
arg1.randomizeUniform();
arg2.randomizeUniform();
// Shift the second argument by -0.5 after the uniform fill.
arg2.add(-0.5);
test.cmpWithArg(&Matrix::cosSim, arg1, arg2, scale);
}
// Sweeps testCosSim over several heights, widths and scales. heightY is either
// 1 (single shared row for the second argument) or equal to heightX.
TEST(Matrix, cosSim) {
for (auto heightX : {10, 100, 1000}) {
for (auto heightY : {1, heightX}) {
for (auto width : {10, 100, 1000}) {
for (auto scale : {1.0, 2.0}) {
testCosSim(heightX, heightY, width, scale);
}
}
}
}
}
void testParamReluForward(int height, int width, int w_height, int w_width) { void testParamReluForward(int height, int width, int w_height, int w_width) {
AutoCompare test(height, width); AutoCompare test(height, width);
CpuMatrix arg1(height, width); CpuMatrix arg1(height, width);
......
...@@ -720,61 +720,6 @@ TEST(Matrix, sequenceAvgForward) { ...@@ -720,61 +720,6 @@ TEST(Matrix, sequenceAvgForward) {
} }
} }
// Runs cosSimDerivative on identical CPU and GPU data and checks that the two
// accumulated input gradients (prevGradX, prevGradY) agree.
void testCosSimDerivate(int heightX, int heightY, int width, real scale) {
// CPU-side operands, filled with random values.
MatrixPtr prevOutX = CpuMatrix::create(heightX, width, false, false);
MatrixPtr prevOutY = CpuMatrix::create(heightY, width, false, false);
MatrixPtr grad = CpuMatrix::create(heightX, 1, false, false);
MatrixPtr output = CpuMatrix::create(heightX, 1, false, false);
MatrixPtr prevGradX = CpuMatrix::create(heightX, width, false, false);
MatrixPtr prevGradY = CpuMatrix::create(heightY, width, false, false);
prevOutX->randomizeUniform();
prevOutY->randomizeUniform();
grad->randomizeUniform();
output->randomizeUniform();
prevGradX->randomizeUniform();
prevGradY->randomizeUniform();
// GPU-side mirrors, copied from the CPU data before either side is modified
// so both implementations start from the same state.
MatrixPtr prevOutXGpu = GpuMatrix::create(heightX, width, false, true);
MatrixPtr prevOutYGpu = GpuMatrix::create(heightY, width, false, true);
MatrixPtr gradGpu = GpuMatrix::create(heightX, 1, false, true);
MatrixPtr outputGpu = GpuMatrix::create(heightX, 1, false, true);
MatrixPtr prevGradXGpu = GpuMatrix::create(heightX, width, false, true);
MatrixPtr prevGradYGpu = GpuMatrix::create(heightY, width, false, true);
prevOutXGpu->copyFrom(*prevOutX);
prevOutYGpu->copyFrom(*prevOutY);
gradGpu->copyFrom(*grad);
outputGpu->copyFrom(*output);
prevGradXGpu->copyFrom(*prevGradX);
prevGradYGpu->copyFrom(*prevGradY);
grad->cosSimDerivative(
*output, *prevOutX, *prevOutY, *prevGradX, *prevGradY, scale);
gradGpu->cosSimDerivative(*outputGpu,
*prevOutXGpu,
*prevOutYGpu,
*prevGradXGpu,
*prevGradYGpu,
scale);
// Only the gradient buffers are compared.
TensorCheckErr(*prevGradX, *prevGradXGpu);
TensorCheckErr(*prevGradY, *prevGradYGpu);
}
// Sweeps testCosSimDerivate over several heights, widths and scales. heightY
// is either 1 (single shared row for the second input) or equal to heightX.
TEST(Matrix, cosSimDerivate) {
for (auto heightX : {1, 10, 100}) {
for (auto heightY : {1, heightX}) {
for (auto width : {1, 10, 100}) {
for (auto scale : {1.0, 2.0}) {
testCosSimDerivate(heightX, heightY, width, scale);
}
}
}
}
}
void testParamReluBackwardDiff(int height, void testParamReluBackwardDiff(int height,
int width, int width,
int w_height, int w_height,
...@@ -819,7 +764,7 @@ TEST(Matrix, paramReluBackwardDiff) { ...@@ -819,7 +764,7 @@ TEST(Matrix, paramReluBackwardDiff) {
} }
} }
void testClassificationError(int numSamples, int dim) { void testClassificationError(int numSamples, int dim, int topkSize) {
MatrixPtr cpuError = std::make_shared<CpuMatrix>(numSamples, 1); MatrixPtr cpuError = std::make_shared<CpuMatrix>(numSamples, 1);
MatrixPtr gpuError = std::make_shared<GpuMatrix>(numSamples, 1); MatrixPtr gpuError = std::make_shared<GpuMatrix>(numSamples, 1);
MatrixPtr cpuOutput = std::make_shared<CpuMatrix>(numSamples, dim); MatrixPtr cpuOutput = std::make_shared<CpuMatrix>(numSamples, dim);
...@@ -832,17 +777,22 @@ void testClassificationError(int numSamples, int dim) { ...@@ -832,17 +777,22 @@ void testClassificationError(int numSamples, int dim) {
gpuOutput->copyFrom(*cpuOutput); gpuOutput->copyFrom(*cpuOutput);
gpuLabel->copyFrom(*cpuLabel); gpuLabel->copyFrom(*cpuLabel);
cpuError->classificationError(*cpuOutput, *cpuLabel); cpuError->classificationError(*cpuOutput, *cpuLabel, topkSize);
gpuError->classificationError(*gpuOutput, *gpuLabel); gpuError->classificationError(*gpuOutput, *gpuLabel, topkSize);
TensorCheckEqual(*cpuError, *gpuError); TensorCheckEqual(*cpuError, *gpuError);
} }
TEST(Matrix, classificationError) { TEST(Matrix, classificationError) {
for (auto numSamples : {1, 10, 100, 1000, 70000}) { for (auto numSamples : {1, 5, 31, 90, 150, 300}) {
for (auto dim : {1, 10, 100, 1000}) { for (auto dim :
VLOG(3) << " numSamples=" << numSamples << " dim=" << dim; {1, 5, 8, 10, 15, 64, 80, 120, 256, 300, 1280, 5120, 50000}) {
testClassificationError(numSamples, dim); for (auto topkSize : {1, 5, 10, 20, 40, (int)rand() % dim + 1}) {
if (topkSize > dim) continue;
VLOG(3) << " sample= " << numSamples << " topkSize= " << topkSize
<< " dim= " << dim;
testClassificationError(numSamples, dim, topkSize);
}
} }
} }
} }
......
...@@ -602,6 +602,44 @@ void Argument::degradeSequence(const Argument& input, bool useGpu) { ...@@ -602,6 +602,44 @@ void Argument::degradeSequence(const Argument& input, bool useGpu) {
tgtBuf[numSequences] = numSubSequences; tgtBuf[numSequences] = numSubSequences;
} }
/**
 * Render every populated field of this Argument as text.
 *
 * For each non-null member the printed representation is stored in `out`
 * under a fixed key: "value", "ids", "sequence pos" and "sub-sequence pos".
 * Keys already present in `out` are left untouched (insert semantics).
 */
void Argument::getValueString(
    std::unordered_map<std::string, std::string>* out) const {
  // Small helper: file the rendered text of one field under its key.
  auto store = [out](const char* key, const std::ostringstream& text) {
    out->insert({key, text.str()});
  };

  if (value) {
    std::ostringstream valueText;
    value->print(valueText);
    store("value", valueText);
  }

  if (ids) {
    std::ostringstream idsText;
    ids->print(idsText, ids->getSize());
    store("ids", idsText);
  }

  if (sequenceStartPositions) {
    std::ostringstream seqText;
    // Positions are printed from the vector obtained with getVector(false).
    sequenceStartPositions->getVector(false)->print(
        seqText, sequenceStartPositions->getSize());
    store("sequence pos", seqText);
  }

  if (subSequenceStartPositions) {
    std::ostringstream subSeqText;
    subSequenceStartPositions->getVector(false)->print(
        subSeqText, subSequenceStartPositions->getSize());
    store("sub-sequence pos", subSeqText);
  }
}
/**
 * Print the textual fields of this Argument to `stream`.
 *
 * The fields are collected with getValueString() and written in the fixed
 * order "value", "ids", "sequence pos", "sub-sequence pos". Fields that are
 * not populated are skipped. Each printed field is emitted as
 * `<prefix><field>:\n<text>`.
 *
 * @param stream output stream to write to.
 * @param prefix text written in front of each field name.
 */
void Argument::printValueString(std::ostream& stream,
                                const std::string& prefix) const {
  std::unordered_map<std::string, std::string> out;
  getValueString(&out);
  // The names here must match the keys inserted by getValueString() exactly.
  // The ids entry is stored under "ids"; looking it up as "id" never matched,
  // so the ids field was silently dropped from the output.
  for (auto field : {"value", "ids", "sequence pos", "sub-sequence pos"}) {
    auto it = out.find(field);
    if (it != out.end()) {
      stream << prefix << field << ":\n" << it->second;
    }
  }
}
void Argument::subArgFrom(const Argument& input, void Argument::subArgFrom(const Argument& input,
size_t offset, size_t offset,
size_t height, size_t height,
......
...@@ -297,6 +297,23 @@ struct Argument { ...@@ -297,6 +297,23 @@ struct Argument {
sequence has sub-sequence degrades to a sequence. sequence has sub-sequence degrades to a sequence.
*/ */
void degradeSequence(const Argument& input, bool useGpu); void degradeSequence(const Argument& input, bool useGpu);
/**
* @brief getValueString will return the argument's output in string. There
* are several kinds of output. The keys of output dictionary are 'value',
* 'ids', 'sequence pos', 'sub-sequence pos'.
* @param out [out]: the return values.
*/
void getValueString(std::unordered_map<std::string, std::string>* out) const;
/**
* @brief printValueString will print the argument's output in order of
* 'value', 'id', 'sequence pos', 'sub-sequence pos'.
* @param stream: Output stream
* @param prefix: line prefix for printing.
*/
void printValueString(std::ostream& stream,
const std::string& prefix = "") const;
}; };
} // namespace paddle } // namespace paddle
...@@ -375,10 +375,6 @@ bool Parameter::load(const std::string& filename) { ...@@ -375,10 +375,6 @@ bool Parameter::load(const std::string& filename) {
std::ifstream fs(filename, std::ios_base::binary); std::ifstream fs(filename, std::ios_base::binary);
if (!fs) { if (!fs) {
LOG(INFO) << "missing parameters [" << filename << "] while loading model."; LOG(INFO) << "missing parameters [" << filename << "] while loading model.";
if (isStatic()) {
LOG(FATAL) << getName() << " is static but missing, not allowed.";
return false;
}
if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) { if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) {
LOG(FATAL) << getName() << " missing, not allowed."; LOG(FATAL) << getName() << " missing, not allowed.";
return false; return false;
......
...@@ -10,28 +10,30 @@ RUN apt-get update && \ ...@@ -10,28 +10,30 @@ RUN apt-get update && \
apt-get install -y wget unzip tar xz-utils bzip2 gzip coreutils && \ apt-get install -y wget unzip tar xz-utils bzip2 gzip coreutils && \
apt-get install -y curl sed grep graphviz libjpeg-dev zlib1g-dev && \ apt-get install -y curl sed grep graphviz libjpeg-dev zlib1g-dev && \
apt-get install -y python-numpy python-matplotlib gcc g++ gfortran && \ apt-get install -y python-numpy python-matplotlib gcc g++ gfortran && \
apt-get install -y automake clang-3.8 llvm-3.8 libclang-3.8-dev && \ apt-get install -y automake && \
apt-get clean -y apt-get clean -y
RUN pip install --upgrade pip && \ RUN pip install --upgrade pip && \
pip install -U protobuf && \ pip install -U "protobuf==3.1.0" && \
pip install -U wheel pillow BeautifulSoup && \ pip install -U wheel pillow BeautifulSoup && \
pip install -U docopt PyYAML sphinx && \ pip install -U docopt PyYAML sphinx && \
pip install -U sphinx_rtd_theme recommonmark jupyter pip install -U sphinx_rtd_theme recommonmark jupyter
RUN curl -sSL https://cmake.org/files/v3.4/cmake-3.4.1.tar.gz | tar -xz && \ RUN curl -sSL https://cmake.org/files/v3.4/cmake-3.4.1.tar.gz | tar -xz && \
cd cmake-3.4.1 && ./bootstrap && make -j4 && make install && \ cd cmake-3.4.1 && ./bootstrap && make -j `nproc` && make install && \
cd .. && rm -rf cmake-3.4.1 cd .. && rm -rf cmake-3.4.1
ARG BUILD_WOBOQ
ARG BUILD_AND_INSTALL ARG BUILD_AND_INSTALL
ARG WITH_AVX ARG WITH_AVX
ARG WITH_DOC ARG WITH_DOC
ARG WITH_STYLE_CHECK ARG WITH_STYLE_CHECK
ENV BUILD_WOBOQ=${BUILD_WOBOQ:-OFF}
ENV BUILD_AND_INSTALL=${BUILD_AND_INSTALL:-OFF} ENV BUILD_AND_INSTALL=${BUILD_AND_INSTALL:-OFF}
ENV WITH_GPU=OFF ENV WITH_GPU=OFF
ENV WITH_AVX=${WITH_AVX:-ON} ENV WITH_AVX=${WITH_AVX:-ON}
ENV WITH_DOC=${WITH_DOC:-ON} ENV WITH_DOC=${WITH_DOC:-OFF}
ENV WITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF} ENV WITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF}
RUN mkdir /paddle RUN mkdir /paddle
......
...@@ -10,28 +10,30 @@ RUN apt-get update && \ ...@@ -10,28 +10,30 @@ RUN apt-get update && \
apt-get install -y wget unzip tar xz-utils bzip2 gzip coreutils && \ apt-get install -y wget unzip tar xz-utils bzip2 gzip coreutils && \
apt-get install -y curl sed grep graphviz libjpeg-dev zlib1g-dev && \ apt-get install -y curl sed grep graphviz libjpeg-dev zlib1g-dev && \
apt-get install -y python-numpy python-matplotlib gcc g++ gfortran && \ apt-get install -y python-numpy python-matplotlib gcc g++ gfortran && \
apt-get install -y automake clang-3.8 llvm-3.8 libclang-3.8-dev && \ apt-get install -y automake && \
apt-get clean -y apt-get clean -y
RUN pip install --upgrade pip && \ RUN pip install --upgrade pip && \
pip install -U protobuf && \ pip install -U "protobuf==3.1.0" && \
pip install -U wheel pillow BeautifulSoup && \ pip install -U wheel pillow BeautifulSoup && \
pip install -U docopt PyYAML sphinx && \ pip install -U docopt PyYAML sphinx && \
pip install -U sphinx_rtd_theme recommonmark jupyter pip install -U sphinx_rtd_theme recommonmark jupyter
RUN curl -sSL https://cmake.org/files/v3.4/cmake-3.4.1.tar.gz | tar -xz && \ RUN curl -sSL https://cmake.org/files/v3.4/cmake-3.4.1.tar.gz | tar -xz && \
cd cmake-3.4.1 && ./bootstrap && make -j4 && make install && \ cd cmake-3.4.1 && ./bootstrap && make -j `nproc` && make install && \
cd .. && rm -rf cmake-3.4.1 cd .. && rm -rf cmake-3.4.1
ARG BUILD_WOBOQ
ARG BUILD_AND_INSTALL ARG BUILD_AND_INSTALL
ARG WITH_AVX ARG WITH_AVX
ARG WITH_DOC ARG WITH_DOC
ARG WITH_STYLE_CHECK ARG WITH_STYLE_CHECK
ENV BUILD_WOBOQ=${BUILD_WOBOQ:-OFF}
ENV BUILD_AND_INSTALL=${BUILD_AND_INSTALL:-OFF} ENV BUILD_AND_INSTALL=${BUILD_AND_INSTALL:-OFF}
ENV WITH_GPU=ON ENV WITH_GPU=ON
ENV WITH_AVX=${WITH_AVX:-ON} ENV WITH_AVX=${WITH_AVX:-ON}
ENV WITH_DOC=${WITH_DOC:-ON} ENV WITH_DOC=${WITH_DOC:-OFF}
ENV WITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF} ENV WITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF}
RUN mkdir /paddle RUN mkdir /paddle
......
...@@ -11,7 +11,7 @@ set -e ...@@ -11,7 +11,7 @@ set -e
# If Dockerfile.* sets BUILD_AND_INSTALL to 'ON', it would have copied # If Dockerfile.* sets BUILD_AND_INSTALL to 'ON', it would have copied
# source tree to /paddle, and this script should build it into # source tree to /paddle, and this script should build it into
# /paddle/build. # /paddle/build.
if [[ ${BUILD_AND_INSTALL:-ON} == 'ON' ]]; then if [[ ${BUILD_AND_INSTALL:-OFF} == 'ON' ]]; then
if [[ ${WITH_GPU:-OFF} == 'ON' ]]; then if [[ ${WITH_GPU:-OFF} == 'ON' ]]; then
ln -s /usr/lib/x86_64-linux-gnu/libcudnn.so /usr/lib/libcudnn.so ln -s /usr/lib/x86_64-linux-gnu/libcudnn.so /usr/lib/libcudnn.so
fi fi
...@@ -19,7 +19,7 @@ if [[ ${BUILD_AND_INSTALL:-ON} == 'ON' ]]; then ...@@ -19,7 +19,7 @@ if [[ ${BUILD_AND_INSTALL:-ON} == 'ON' ]]; then
mkdir -p /paddle/build # -p means no error if exists mkdir -p /paddle/build # -p means no error if exists
cd /paddle/build cd /paddle/build
cmake .. \ cmake .. \
-DWITH_DOC=ON \ -DWITH_DOC=${WITH_DOC:-OFF} \
-DWITH_GPU=${WITH_GPU:-OFF} \ -DWITH_GPU=${WITH_GPU:-OFF} \
-DWITH_AVX=${WITH_AVX:-OFF} \ -DWITH_AVX=${WITH_AVX:-OFF} \
-DWITH_SWIG_PY=ON \ -DWITH_SWIG_PY=ON \
...@@ -29,28 +29,32 @@ if [[ ${BUILD_AND_INSTALL:-ON} == 'ON' ]]; then ...@@ -29,28 +29,32 @@ if [[ ${BUILD_AND_INSTALL:-ON} == 'ON' ]]; then
make -j `nproc` make -j `nproc`
make install make install
# Install woboq_codebrowser. if [[ ${BUILD_WOBOQ:-OFF} == 'ON' ]]; then
git clone https://github.com/woboq/woboq_codebrowser /woboq apt-get install -y clang-3.8 llvm-3.8 libclang-3.8-dev
cd /woboq # Install woboq_codebrowser.
cmake -DLLVM_CONFIG_EXECUTABLE=/usr/bin/llvm-config-3.8 \ git clone https://github.com/woboq/woboq_codebrowser /woboq
-DCMAKE_BUILD_TYPE=Release \ cd /woboq
. cmake -DLLVM_CONFIG_EXECUTABLE=/usr/bin/llvm-config-3.8 \
make -DCMAKE_BUILD_TYPE=Release \
.
export WOBOQ_OUT=/usr/share/nginx/html/paddle make
export BUILD_DIR=/paddle/build
mkdir -p $WOBOQ_OUT export WOBOQ_OUT=/usr/share/nginx/html/paddle
cp -rv /woboq/data $WOBOQ_OUT/../data export BUILD_DIR=/paddle/build
/woboq/generator/codebrowser_generator \ mkdir -p $WOBOQ_OUT
-b /paddle/build \ cp -rv /woboq/data $WOBOQ_OUT/../data
-a \ /woboq/generator/codebrowser_generator \
-o $WOBOQ_OUT \ -b /paddle/build \
-p paddle:/paddle -a \
/woboq/indexgenerator/codebrowser_indexgenerator $WOBOQ_OUT -o $WOBOQ_OUT \
cd /woboq -p paddle:/paddle
make clean /woboq/indexgenerator/codebrowser_indexgenerator $WOBOQ_OUT
cd /woboq
pip install /usr/local/opt/paddle/share/wheels/*.whl make clean
fi
pip install /usr/local/opt/paddle/share/wheels/py_paddle*linux*.whl
pip install /usr/local/opt/paddle/share/wheels/paddle*.whl
paddle version paddle version
fi fi
......
...@@ -55,6 +55,9 @@ elif is_osx == True: ...@@ -55,6 +55,9 @@ elif is_osx == True:
include_dirs = [np.get_include(), "../"] # include numpy and paddle. include_dirs = [np.get_include(), "../"] # include numpy and paddle.
os.environ["CC"] = "@CMAKE_C_COMPILER@"
os.environ["CXX"] = "@CMAKE_CXX_COMPILER@"
setup(name="py_paddle", setup(name="py_paddle",
version="@PADDLE_VERSION@", version="@PADDLE_VERSION@",
ext_modules=[ ext_modules=[
......
...@@ -37,10 +37,10 @@ namespace paddle { ...@@ -37,10 +37,10 @@ namespace paddle {
* *
* Error __must_check bar() { * Error __must_check bar() {
* // do something. * // do something.
* Status s = foo(); // invoke other method return status. * Error err = foo(); // invoke other method return status.
* if (!s) return s; * if (err) return err;
* // do something else. * // do something else.
* return Status(); * return Error();
* } * }
* @endcode{cpp} * @endcode{cpp}
* *
...@@ -53,8 +53,8 @@ namespace paddle { ...@@ -53,8 +53,8 @@ namespace paddle {
* *
* int foo(Error* error) { * int foo(Error* error) {
* // Do something. * // Do something.
* Error s = bar(); * Error err = bar();
* if (!s) { * if (err) {
* *error = s; * *error = err;
* return 0; * return 0;
* } * }
...@@ -68,10 +68,10 @@ namespace paddle { ...@@ -68,10 +68,10 @@ namespace paddle {
* } * }
* *
* Error foobar() { * Error foobar() {
* Error s; * Error err;
* // do something. * // do something.
* foo(&s); * foo(&err);
* if (!s) return s; * if (err) return err;
* } * }
* @endcode{cpp} * @endcode{cpp}
* *
...@@ -112,16 +112,22 @@ public: ...@@ -112,16 +112,22 @@ public:
} }
/** /**
* @brief operator bool, return True if there is no error. * @brief operator bool, return True if there is an error.
*/ */
operator bool() const { return msg_ == nullptr; } operator bool() const { return !this->isOK(); }
/**
* @brief isOK return True if there is no error.
* @return True if no error.
*/
bool isOK() const { return msg_ == nullptr; }
/** /**
* @brief check this status by glog. * @brief check this status by glog.
* @note It is a temp method used during cleaning Paddle code. It will be * @note It is a temp method used during cleaning Paddle code. It will be
* removed later. * removed later.
*/ */
void check() const { CHECK(*this) << msg(); } void check() const { CHECK(this->isOK()) << msg(); }
private: private:
std::shared_ptr<std::string> msg_; std::shared_ptr<std::string> msg_;
......
...@@ -289,6 +289,7 @@ void mkDir(const char* filename) { ...@@ -289,6 +289,7 @@ void mkDir(const char* filename) {
void mkDirRecursively(const char* dir) { void mkDirRecursively(const char* dir) {
struct stat sb; struct stat sb;
if (*dir == 0) return; // empty string
if (!stat(dir, &sb)) return; if (!stat(dir, &sb)) return;
mkDirRecursively(path::dirname(dir).c_str()); mkDirRecursively(path::dirname(dir).c_str());
......
...@@ -18,17 +18,17 @@ limitations under the License. */ ...@@ -18,17 +18,17 @@ limitations under the License. */
TEST(Error, testAll) { TEST(Error, testAll) {
paddle::Error error; paddle::Error error;
ASSERT_TRUE(error);
error = paddle::Error("I'm the error");
ASSERT_FALSE(error); ASSERT_FALSE(error);
error = paddle::Error("I'm the error");
ASSERT_TRUE(error);
ASSERT_STREQ("I'm the error", error.msg()); ASSERT_STREQ("I'm the error", error.msg());
error = paddle::Error("error2"); error = paddle::Error("error2");
ASSERT_FALSE(error); ASSERT_TRUE(error);
ASSERT_STREQ("error2", error.msg()); ASSERT_STREQ("error2", error.msg());
int i = 3; int i = 3;
auto error3 = paddle::Error("error%d", i); auto error3 = paddle::Error("error%d", i);
ASSERT_FALSE(error3); ASSERT_TRUE(error3);
ASSERT_STREQ("error3", error3.msg()); ASSERT_STREQ("error3", error3.msg());
} }
...@@ -475,6 +475,10 @@ message EvaluatorConfig { ...@@ -475,6 +475,10 @@ message EvaluatorConfig {
// Used by ChunkEvaluator // Used by ChunkEvaluator
// chunk of these types are not counted // chunk of these types are not counted
repeated int32 excluded_chunk_types = 12; repeated int32 excluded_chunk_types = 12;
// Used by ClassificationErrorEvaluator
// top # classification error
optional int32 top_k = 13 [default = 1];
} }
message LinkConfig { message LinkConfig {
......
...@@ -24,6 +24,8 @@ add_custom_target(paddle_python ALL DEPENDS ...@@ -24,6 +24,8 @@ add_custom_target(paddle_python ALL DEPENDS
${OUTPUT_DIR}/.timestamp) ${OUTPUT_DIR}/.timestamp)
add_subdirectory(paddle/trainer_config_helpers/tests) add_subdirectory(paddle/trainer_config_helpers/tests)
add_subdirectory(paddle/reader/tests)
add_subdirectory(paddle/v2/tests)
install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/dist/ install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/dist/
DESTINATION opt/paddle/share/wheels DESTINATION opt/paddle/share/wheels
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# It would be too lengthy to require our users to prefix decorators with `decorator`.
# For example, we want the following line
#
# r = paddle.reader.decorator.buffered(paddle.reader.creator.text("hello.txt"))
#
# to be a shorter version:
#
# r = paddle.reader.buffered(paddle.reader.creator.text("hello.txt"))
from decorator import *
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = [
'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
'ComposeNotAligned'
]
from Queue import Queue
from threading import Thread
import itertools
import random
def map_readers(func, *readers):
    """
    Creates a data reader that outputs the return value of func, called with
    one entry taken from each of the given data readers as its arguments.

    The created reader stops as soon as any of the input readers is
    exhausted.

    :param func: function to use. It is called with one positional argument
                 per reader.
    :param *readers: readers whose outputs will be used as arguments of func.
    :returns: the created data reader.
    """

    # itertools.imap only exists on Python 2. On Python 3 the builtin map is
    # already lazy and stops at the shortest input, i.e. the same contract.
    lazy_map = getattr(itertools, 'imap', map)

    def reader():
        # Instantiate every input reader up front, in argument order.
        iterators = [r() for r in readers]
        for e in lazy_map(func, *iterators):
            yield e

    return reader
def shuffle(reader, buf_size):
    """
    Creates a data reader whose data output is shuffled.

    Entries produced by the original reader are collected into a shuffle
    buffer. Whenever the buffer holds at least buf_size entries it is
    shuffled and emitted, then emptied. Entries left over when the original
    reader is exhausted are shuffled and emitted as a final, smaller batch.

    :param reader: the original reader whose output will be shuffled.
    :param buf_size: shuffle buffer size.
    :returns: the new reader whose output is shuffled.
    """

    def data_reader():
        pending = []
        for entry in reader():
            pending.append(entry)
            if len(pending) < buf_size:
                # buffer not full yet, keep collecting
                continue
            random.shuffle(pending)
            for item in pending:
                yield item
            pending = []

        # flush the partially filled buffer, if any
        if pending:
            random.shuffle(pending)
            for item in pending:
                yield item

    return data_reader
def chain(*readers):
    """
    Creates a data reader that emits everything from the first input reader,
    then everything from the second one, and so on.

    If input readers output following data entries:

        [0, 0, 0]
        [1, 1, 1]
        [2, 2, 2]

    The chained reader will output:

        [0, 0, 0, 1, 1, 1, 2, 2, 2]

    :param readers: input readers.
    :returns: the new data reader.
    """

    def reader():
        # every input reader is instantiated before the first entry is read
        iterators = [make_iterator() for make_iterator in readers]
        for entry in itertools.chain(*iterators):
            yield entry

    return reader
class ComposeNotAligned(ValueError):
    """Raised by compose() when the outputs of its readers are not aligned."""
def compose(*readers, **kwargs):
    """
    Creates a data reader whose output is the combination of input readers.

    If input readers output following data entries:

        (1, 2)    3    (4, 5)

    The composed reader will output:

        (1, 2, 3, 4, 5)

    Entries that are tuples are spliced into the result, any other entry
    (including None) contributes a single element.

    :*readers: readers that will be composed together.
    :check_alignment: if True, will check if input readers are aligned
        correctly. If False, will not check alignment and trailing outputs
        will be discarded. Defaults to True. This is the only keyword
        argument that is read; other keyword arguments are ignored.
    :returns: the new data reader.
    :raises ComposeNotAligned: outputs of readers are not aligned.
        Will not raise when check_alignment is set to False.
    """
    check_alignment = kwargs.pop('check_alignment', True)

    # izip / izip_longest are the Python 2 names, on Python 3 the lazy
    # equivalents are the builtin zip and itertools.zip_longest.
    lazy_zip = getattr(itertools, 'izip', zip)
    lazy_zip_longest = (getattr(itertools, 'izip_longest', None) or
                        itertools.zip_longest)

    # Private fill marker for exhausted readers. Using None here would
    # misreport a reader that legitimately yields None as "not aligned".
    missing = object()

    def make_tuple(x):
        return x if isinstance(x, tuple) else (x, )

    def reader():
        rs = [r() for r in readers]
        if not check_alignment:
            # zip stops at the shortest reader; trailing outputs are dropped
            for outputs in lazy_zip(*rs):
                yield sum(map(make_tuple, outputs), ())
        else:
            for outputs in lazy_zip_longest(*rs, fillvalue=missing):
                # the marker only shows up once some reader ran out of data
                # while another one still had entries
                if any(o is missing for o in outputs):
                    raise ComposeNotAligned(
                        "outputs of readers are not aligned.")
                yield sum(map(make_tuple, outputs), ())

    return reader
def buffered(reader, size):
    """
    Creates a buffered data reader.

    The buffered data reader starts a daemon thread that reads entries from
    the original reader and saves them into a queue. Reading from the
    buffered data reader takes entries out of that queue, in order.

    An exception raised by the original reader while it is being iterated
    is re-raised in the consumer, after the entries read before the failure
    have been delivered.

    :param reader: the data reader to read from.
    :param size: max buffer size (passed to the queue as its maxsize).
    :returns: the buffered data reader.
    """

    class EndSignal(object):
        """Marks the regular end of the stream."""

    class WorkerFailure(object):
        """Carries an exception from the worker thread to the consumer."""

        def __init__(self, error):
            self.error = error

    end = EndSignal()

    def read_worker(r, q):
        try:
            for d in r:
                q.put(d)
        except Exception as error:
            # Without this the consumer would wait forever on q.get(),
            # because the end marker would never be queued.
            q.put(WorkerFailure(error))
        else:
            q.put(end)

    def data_reader():
        r = reader()
        q = Queue(maxsize=size)
        t = Thread(target=read_worker, args=(r, q))
        t.daemon = True
        t.start()
        while True:
            e = q.get()
            # Identity check: comparing with != would invoke the entry's own
            # comparison, which may not return a plain bool or may raise.
            if e is end:
                break
            if isinstance(e, WorkerFailure):
                raise e.error
            yield e

    return data_reader
# Register the reader decorator unit test with CTest. The test script is
# launched through .set_python_path.sh (given -d ${PROJ_ROOT}/python/) and
# runs with python/paddle as its working directory.
add_test(NAME reader_decorator_test
COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/
${PYTHON_EXECUTABLE} ${PROJ_ROOT}/python/paddle/reader/tests/decorator_test.py
WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle)
# Copyright PaddlePaddle contributors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.reader
import time
def reader_creator_10(dur):
    """
    Creates a reader that yields the integers 0 to 9, sleeping for dur
    seconds before each of them.
    """

    def reader():
        for value in range(10):
            # the delay helps testing paddle.reader.buffered
            time.sleep(dur)
            yield value

    return reader
class TestMap(unittest.TestCase):
    """Checks paddle.reader.map_readers with a single input reader."""

    def test_map(self):
        vocabulary = {"h": 0, "i": 1}

        def tokenize(token):
            return vocabulary[token]

        def read():
            yield "h"
            yield "i"

        mapped = paddle.reader.map_readers(tokenize, read)
        # "h" and "i" map to 0 and 1, i.e. to their position in the stream
        for position, token_id in enumerate(mapped()):
            self.assertEqual(token_id, position)
class TestBuffered(unittest.TestCase):
    """Checks paddle.reader.buffered for ordering and for prefetching."""

    def test_read(self):
        # every buffer size from 0 to 19 must deliver 0..9 in order
        for size in range(20):
            buffered_reader = paddle.reader.buffered(
                reader_creator_10(0), size)
            expected = 0
            for value in buffered_reader():
                self.assertEqual(value, expected)
                expected += 1
            self.assertEqual(expected, 10)

    def test_buffering(self):
        # each read from the underlying reader has a 30ms delay.
        buffered_reader = paddle.reader.buffered(reader_creator_10(0.03), 10)
        previous = time.time()
        for value in buffered_reader():
            waited = time.time() - previous
            if value != 0:
                # read time should be short, meaning already buffered.
                self.assertLess(waited, 0.05)
            else:
                # pause after the first entry so the remaining ones can be
                # read into the buffer in the meantime
                time.sleep(0.3)
            previous = time.time()
class TestCompose(unittest.TestCase):
    """Checks paddle.reader.compose with aligned and misaligned readers."""

    def test_compose(self):
        reader = paddle.reader.compose(
            reader_creator_10(0), reader_creator_10(0))
        for idx, e in enumerate(reader()):
            self.assertEqual(e, (idx, idx))

    def test_compose_not_aligned(self):
        # the first input has 20 entries, the second one only 10
        total = 0
        reader = paddle.reader.compose(
            paddle.reader.chain(reader_creator_10(0), reader_creator_10(0)),
            reader_creator_10(0))
        with self.assertRaises(paddle.reader.ComposeNotAligned):
            for _ in reader():
                total += 1
        # expecting 10, not 20
        self.assertEqual(total, 10)

    def test_compose_not_aligned_no_check(self):
        # same inputs as above, but the alignment check is switched off
        total = 0
        reader = paddle.reader.compose(
            paddle.reader.chain(reader_creator_10(0), reader_creator_10(0)),
            reader_creator_10(0),
            check_alignment=False)
        for _ in reader():
            total += 1
        # expecting 10, not 20
        self.assertEqual(total, 10)
class TestChain(unittest.TestCase):
    """Checks paddle.reader.chain with two ten-entry readers."""

    def test_chain(self):
        chained = paddle.reader.chain(
            reader_creator_10(0), reader_creator_10(0))
        count = 0
        for value in chained():
            # 0..9 from the first reader, then 0..9 again from the second
            self.assertEqual(value, count % 10)
            count += 1
        self.assertEqual(count, 20)
class TestShuffle(unittest.TestCase):
    """Checks paddle.reader.shuffle for several buffer sizes."""

    def test_shuffle(self):
        # (buffer size, whether the original order must be preserved)
        cases = [(0, True), (1, True), (10, False), (100, False)]
        source = reader_creator_10(0)
        for buf_size, keeps_order in cases:
            shuffled = paddle.reader.shuffle(source, buf_size)
            count = 0
            for position, value in enumerate(shuffled()):
                if keeps_order:
                    self.assertEqual(position, value)
                count += 1
            # no entry may be lost, whatever the buffer size
            self.assertEqual(count, 10)
if __name__ == '__main__':
unittest.main()
...@@ -893,11 +893,11 @@ class MaxOut(Cfg): ...@@ -893,11 +893,11 @@ class MaxOut(Cfg):
self.add_keys(locals()) self.add_keys(locals())
def DataBase(async_load_data=False, def create_data_config_proto(async_load_data=False,
constant_slots=None, constant_slots=None,
data_ratio=1, data_ratio=1,
is_main_data=True, is_main_data=True,
usage_ratio=None): usage_ratio=None):
# default: all sub dataproviders are treat as "main data". # default: all sub dataproviders are treat as "main data".
# see proto/DataConfig.proto for is_main_data # see proto/DataConfig.proto for is_main_data
data_config = DataConfig() data_config = DataConfig()
...@@ -923,7 +923,7 @@ def SimpleData(files=None, ...@@ -923,7 +923,7 @@ def SimpleData(files=None,
context_len=None, context_len=None,
buffer_capacity=None, buffer_capacity=None,
**xargs): **xargs):
data_config = DataBase(**xargs) data_config = create_data_config_proto(**xargs)
data_config.type = 'simple' data_config.type = 'simple'
data_config.files = files data_config.files = files
data_config.feat_dim = feat_dim data_config.feat_dim = feat_dim
...@@ -945,7 +945,7 @@ def PyData(files=None, ...@@ -945,7 +945,7 @@ def PyData(files=None,
constant_slots=None, constant_slots=None,
load_thread_num=None, load_thread_num=None,
**xargs): **xargs):
data_config = DataBase(**xargs) data_config = create_data_config_proto(**xargs)
data_config.type = 'py' data_config.type = 'py'
if load_data_module in g_py_module_name_list: if load_data_module in g_py_module_name_list:
...@@ -996,7 +996,7 @@ def ProtoData(files=None, ...@@ -996,7 +996,7 @@ def ProtoData(files=None,
constant_slots=None, constant_slots=None,
load_thread_num=None, load_thread_num=None,
**xargs): **xargs):
data_config = DataBase(**xargs) data_config = create_data_config_proto(**xargs)
if type is None: if type is None:
data_config.type = 'proto' data_config.type = 'proto'
else: else:
...@@ -1035,7 +1035,7 @@ def Data(type, ...@@ -1035,7 +1035,7 @@ def Data(type,
buffer_capacity=None, buffer_capacity=None,
**xargs): **xargs):
data_config = DataBase(**xargs) data_config = create_data_config_proto(**xargs)
data_config.type = type data_config.type = type
data_config.files = files data_config.files = files
data_config.feat_dim = feat_dim data_config.feat_dim = feat_dim
...@@ -1253,6 +1253,7 @@ def Evaluator( ...@@ -1253,6 +1253,7 @@ def Evaluator(
dict_file=None, dict_file=None,
result_file=None, result_file=None,
num_results=None, num_results=None,
top_k=None,
delimited=None, delimited=None,
excluded_chunk_types=None, ): excluded_chunk_types=None, ):
evaluator = g_config.model_config.evaluators.add() evaluator = g_config.model_config.evaluators.add()
...@@ -1280,6 +1281,8 @@ def Evaluator( ...@@ -1280,6 +1281,8 @@ def Evaluator(
evaluator.result_file = result_file evaluator.result_file = result_file
if num_results is not None: if num_results is not None:
evaluator.num_results = num_results evaluator.num_results = num_results
if top_k is not None:
evaluator.top_k = top_k
if delimited is not None: if delimited is not None:
evaluator.delimited = delimited evaluator.delimited = delimited
......
...@@ -58,8 +58,8 @@ def define_py_data_source(file_list, ...@@ -58,8 +58,8 @@ def define_py_data_source(file_list,
:param obj: python object name. May be a function name if using :param obj: python object name. May be a function name if using
PyDataProviderWrapper. PyDataProviderWrapper.
:type obj: basestring :type obj: basestring
:param args: The best practice is using dict to pass arguments into :param args: The best practice is using dict to pass arguments into
DataProvider, and use :code:`@init_hook_wrapper` to DataProvider, and use :code:`@init_hook_wrapper` to
receive arguments. receive arguments.
:type args: string or picklable object :type args: string or picklable object
:param async: Load Data asynchronously or not. :param async: Load Data asynchronously or not.
...@@ -98,7 +98,7 @@ def define_py_data_sources(train_list, ...@@ -98,7 +98,7 @@ def define_py_data_sources(train_list,
The annotation is almost the same as define_py_data_sources2, except that The annotation is almost the same as define_py_data_sources2, except that
it can specific train_async and data_cls. it can specific train_async and data_cls.
:param data_cls: :param data_cls:
:param train_list: Train list name. :param train_list: Train list name.
:type train_list: basestring :type train_list: basestring
:param test_list: Test list name. :param test_list: Test list name.
...@@ -111,8 +111,8 @@ def define_py_data_sources(train_list, ...@@ -111,8 +111,8 @@ def define_py_data_sources(train_list,
a tuple or list to this argument. a tuple or list to this argument.
:type obj: basestring or tuple or list :type obj: basestring or tuple or list
:param args: The best practice is using dict() to pass arguments into :param args: The best practice is using dict() to pass arguments into
DataProvider, and use :code:`@init_hook_wrapper` to receive DataProvider, and use :code:`@init_hook_wrapper` to receive
arguments. If train and test is different, then pass a tuple arguments. If train and test is different, then pass a tuple
or list to this argument. or list to this argument.
:type args: string or picklable object or list or tuple. :type args: string or picklable object or list or tuple.
:param train_async: Is training data load asynchronously or not. :param train_async: Is training data load asynchronously or not.
...@@ -163,12 +163,12 @@ def define_py_data_sources2(train_list, test_list, module, obj, args=None): ...@@ -163,12 +163,12 @@ def define_py_data_sources2(train_list, test_list, module, obj, args=None):
.. code-block:: python .. code-block:: python
define_py_data_sources2(train_list="train.list", define_py_data_sources2(train_list="train.list",
test_list="test.list", test_list="test.list",
module="data_provider" module="data_provider"
# if train/test use different configurations, # if train/test use different configurations,
# obj=["process_train", "process_test"] # obj=["process_train", "process_test"]
obj="process", obj="process",
args={"dictionary": dict_name}) args={"dictionary": dict_name})
The related data provider can refer to :ref:`api_pydataprovider2_sequential_model` . The related data provider can refer to :ref:`api_pydataprovider2_sequential_model` .
...@@ -185,8 +185,8 @@ def define_py_data_sources2(train_list, test_list, module, obj, args=None): ...@@ -185,8 +185,8 @@ def define_py_data_sources2(train_list, test_list, module, obj, args=None):
a tuple or list to this argument. a tuple or list to this argument.
:type obj: basestring or tuple or list :type obj: basestring or tuple or list
:param args: The best practice is using dict() to pass arguments into :param args: The best practice is using dict() to pass arguments into
DataProvider, and use :code:`@init_hook_wrapper` to receive DataProvider, and use :code:`@init_hook_wrapper` to receive
arguments. If train and test is different, then pass a tuple arguments. If train and test is different, then pass a tuple
or list to this argument. or list to this argument.
:type args: string or picklable object or list or tuple. :type args: string or picklable object or list or tuple.
:return: None :return: None
...@@ -195,13 +195,13 @@ def define_py_data_sources2(train_list, test_list, module, obj, args=None): ...@@ -195,13 +195,13 @@ def define_py_data_sources2(train_list, test_list, module, obj, args=None):
def py_data2(files, load_data_module, load_data_object, load_data_args, def py_data2(files, load_data_module, load_data_object, load_data_args,
**kwargs): **kwargs):
data = DataBase() data = create_data_config_proto()
data.type = 'py2' data.type = 'py2'
data.files = files data.files = files
data.load_data_module = load_data_module data.load_data_module = load_data_module
data.load_data_object = load_data_object data.load_data_object = load_data_object
data.load_data_args = load_data_args data.load_data_args = load_data_args
data.async_load_data = True data.async_load_data = False
return data return data
define_py_data_sources( define_py_data_sources(
......
...@@ -71,6 +71,7 @@ def evaluator_base( ...@@ -71,6 +71,7 @@ def evaluator_base(
result_file=None, result_file=None,
num_results=None, num_results=None,
delimited=None, delimited=None,
top_k=None,
excluded_chunk_types=None, ): excluded_chunk_types=None, ):
""" """
Evaluator will evaluate the network status while training/testing. Evaluator will evaluate the network status while training/testing.
...@@ -104,12 +105,15 @@ def evaluator_base( ...@@ -104,12 +105,15 @@ def evaluator_base(
:param weight: An input layer which is a weight for each sample. :param weight: An input layer which is a weight for each sample.
Each evaluator may calculate differently to use this weight. Each evaluator may calculate differently to use this weight.
:type weight: LayerOutput. :type weight: LayerOutput.
:param top_k: number k in top-k error rate
:type top_k: int
""" """
# inputs type assertions. # inputs type assertions.
assert classification_threshold is None or isinstance( assert classification_threshold is None or isinstance(
classification_threshold, float) classification_threshold, float)
assert positive_label is None or isinstance(positive_label, int) assert positive_label is None or isinstance(positive_label, int)
assert num_results is None or isinstance(num_results, int) assert num_results is None or isinstance(num_results, int)
assert top_k is None or isinstance(top_k, int)
if not isinstance(input, list): if not isinstance(input, list):
input = [input] input = [input]
...@@ -130,6 +134,8 @@ def evaluator_base( ...@@ -130,6 +134,8 @@ def evaluator_base(
dict_file=dict_file, dict_file=dict_file,
result_file=result_file, result_file=result_file,
delimited=delimited, delimited=delimited,
num_results=num_results,
top_k=top_k,
excluded_chunk_types=excluded_chunk_types, ) excluded_chunk_types=excluded_chunk_types, )
...@@ -139,6 +145,7 @@ def classification_error_evaluator(input, ...@@ -139,6 +145,7 @@ def classification_error_evaluator(input,
label, label,
name=None, name=None,
weight=None, weight=None,
top_k=None,
threshold=None): threshold=None):
""" """
Classification Error Evaluator. It will print error rate for classification. Classification Error Evaluator. It will print error rate for classification.
...@@ -167,6 +174,8 @@ def classification_error_evaluator(input, ...@@ -167,6 +174,8 @@ def classification_error_evaluator(input,
then means not set weight. The larger weight it is, the more then means not set weight. The larger weight it is, the more
important this sample is. important this sample is.
:type weight: LayerOutput :type weight: LayerOutput
:param top_k: number k in top-k error rate
:type top_k: int
:param threshold: The classification threshold. :param threshold: The classification threshold.
:type threshold: float :type threshold: float
:return: None. :return: None.
...@@ -178,6 +187,7 @@ def classification_error_evaluator(input, ...@@ -178,6 +187,7 @@ def classification_error_evaluator(input,
input=input, input=input,
label=label, label=label,
weight=weight, weight=weight,
top_k=top_k,
classification_threshold=threshold, ) classification_threshold=threshold, )
......
...@@ -37,6 +37,7 @@ __all__ = [ ...@@ -37,6 +37,7 @@ __all__ = [
"dotmul_projection", "dotmul_projection",
"dotmul_operator", "dotmul_operator",
"repeat_layer", "repeat_layer",
"seq_reshape_layer",
"table_projection", "table_projection",
"mixed_layer", "mixed_layer",
"data_layer", "data_layer",
...@@ -59,6 +60,7 @@ __all__ = [ ...@@ -59,6 +60,7 @@ __all__ = [
'img_cmrnorm_layer', 'img_cmrnorm_layer',
'addto_layer', 'addto_layer',
'concat_layer', 'concat_layer',
'seq_concat_layer',
'lstm_step_layer', 'lstm_step_layer',
'recurrent_group', 'recurrent_group',
'memory', 'memory',
...@@ -124,6 +126,7 @@ class LayerType(object): ...@@ -124,6 +126,7 @@ class LayerType(object):
GRUMEMORY = "gated_recurrent" GRUMEMORY = "gated_recurrent"
SEQUENCE_LAST_INSTANCE = "seqlastins" SEQUENCE_LAST_INSTANCE = "seqlastins"
SEQUENCE_FIRST_INSTANCE = "seqfirstins" SEQUENCE_FIRST_INSTANCE = "seqfirstins"
SEQUENCE_RESHAPE = "seqreshape"
POOLING_MAX = "max" POOLING_MAX = "max"
POOLING_AVG = 'average' POOLING_AVG = 'average'
FC_LAYER = "fc" FC_LAYER = "fc"
...@@ -144,6 +147,7 @@ class LayerType(object): ...@@ -144,6 +147,7 @@ class LayerType(object):
CONCAT_LAYER = 'concat' CONCAT_LAYER = 'concat'
CONCAT_PROJ_LAYER = 'concat2' CONCAT_PROJ_LAYER = 'concat2'
SEQUENCE_CONCAT_LAYER = 'seqconcat'
LSTM_STEP_LAYER = 'lstm_step' LSTM_STEP_LAYER = 'lstm_step'
GRU_STEP_LAYER = 'gru_step' GRU_STEP_LAYER = 'gru_step'
...@@ -1448,6 +1452,61 @@ def repeat_layer(input, num_repeats, name=None, layer_attr=None): ...@@ -1448,6 +1452,61 @@ def repeat_layer(input, num_repeats, name=None, layer_attr=None):
parents=[input]) parents=[input])
@wrap_name_default("seqreshape")
@wrap_act_default(act=IdentityActivation())
@wrap_bias_attr_default(has_bias=False)
@layer_support()
def seq_reshape_layer(input,
                      reshape_size,
                      act=None,
                      name=None,
                      layer_attr=None,
                      bias_attr=None):
    """
    A layer for reshaping the sequence. Assume the input sequence has T
    instances, the dimension of each instance is M, and the input
    reshape_size is N, then the output sequence has T*M/N instances, the
    dimension of each instance is N.

    Note that T*M/N must be an integer.

    The example usage is:

    .. code-block:: python

       reshape = seq_reshape_layer(input=layer, reshape_size=4)

    :param input: Input layer.
    :type input: LayerOutput
    :param reshape_size: the size of reshaped sequence.
    :type reshape_size: int
    :param name: Layer name.
    :type name: basestring
    :param act: Activation type.
    :type act: BaseActivation
    :param layer_attr: extra layer attributes.
    :type layer_attr: ExtraLayerAttribute.
    :param bias_attr: The Bias Attribute. If no bias, then pass False or
                      something not type of ParameterAttribute. None will get a
                      default Bias.
    :type bias_attr: ParameterAttribute or None or bool
    :return: LayerOutput object.
    :rtype: LayerOutput
    """
    # NOTE(review): `act` is accepted (and defaulted by the decorator) but is
    # not forwarded to Layer or LayerOutput below -- confirm this is intended.
    layer_type = LayerType.SEQUENCE_RESHAPE
    bias = ParamAttr.to_bias(bias_attr)
    extra_kwargs = ExtraAttr.to_kwargs(layer_attr)

    # register the layer in the network configuration
    Layer(
        name=name,
        type=layer_type,
        size=reshape_size,
        inputs=[input.name],
        bias=bias,
        **extra_kwargs)

    # handle that other layer helpers use to refer to this layer
    return LayerOutput(
        name=name,
        layer_type=layer_type,
        parents=[input],
        size=reshape_size)
@wrap_name_default() @wrap_name_default()
@layer_support() @layer_support()
def interpolation_layer(input, weight, name=None, layer_attr=None): def interpolation_layer(input, weight, name=None, layer_attr=None):
...@@ -2570,6 +2629,63 @@ def concat_layer(input, act=None, name=None, layer_attr=None, bias_attr=None): ...@@ -2570,6 +2629,63 @@ def concat_layer(input, act=None, name=None, layer_attr=None, bias_attr=None):
size=sz) size=sz)
@wrap_name_default("seqconcat")
@wrap_act_default(act=IdentityActivation())
@wrap_bias_attr_default(has_bias=False)
@layer_support()
def seq_concat_layer(a, b, act=None, name=None, layer_attr=None,
                     bias_attr=None):
    """
    Concat sequence a with sequence b.

    Inputs:
      - a = [a1, a2, ..., an]
      - b = [b1, b2, ..., bn]
      - Note that the length of a and b should be the same.

    Output: [a1, b1, a2, b2, ..., an, bn]

    NOTE(review): the output order above is carried over from the existing
    documentation; confirm it against the layer implementation.

    The example usage is:

    .. code-block:: python

       concat = seq_concat_layer(a=layer1, b=layer2)

    :param name: Layer name.
    :type name: basestring
    :param a: input sequence layer
    :type a: LayerOutput
    :param b: input sequence layer
    :type b: LayerOutput
    :param act: Activation type.
    :type act: BaseActivation
    :param layer_attr: Extra Layer Attribute.
    :type layer_attr: ExtraLayerAttribute
    :param bias_attr: The Bias Attribute. If no bias, then pass False or
                      something not type of ParameterAttribute. None will get a
                      default Bias.
    :type bias_attr: ParameterAttribute or None or bool
    :return: LayerOutput object.
    :rtype: LayerOutput
    """
    # both inputs must be layer outputs of the same size
    assert isinstance(a, LayerOutput)
    assert isinstance(b, LayerOutput)
    assert a.size == b.size

    layer_type = LayerType.SEQUENCE_CONCAT_LAYER
    input_names = [a.name, b.name]
    active_type = act.name
    bias = ParamAttr.to_bias(bias_attr)
    extra_kwargs = ExtraLayerAttribute.to_kwargs(layer_attr)

    # register the layer in the network configuration
    Layer(
        name=name,
        type=layer_type,
        inputs=input_names,
        active_type=active_type,
        bias=bias,
        **extra_kwargs)

    # the output keeps the size of the inputs
    return LayerOutput(
        name,
        layer_type=layer_type,
        parents=[a, b],
        activation=act,
        size=a.size)
def memory(name, def memory(name,
size, size,
is_seq=False, is_seq=False,
...@@ -2754,8 +2870,8 @@ def gru_step_layer(input, ...@@ -2754,8 +2870,8 @@ def gru_step_layer(input,
:param name: :param name:
:param gate_act: :param gate_act:
:param bias_attr: :param bias_attr:
:param param_attr: the parameter_attribute for transforming the output_mem :param param_attr: the parameter_attribute for transforming the output_mem
from previous step. from previous step.
:param layer_attr: :param layer_attr:
:return: LayerOutput object. :return: LayerOutput object.
:rtype: LayerOutput :rtype: LayerOutput
...@@ -2766,10 +2882,10 @@ def gru_step_layer(input, ...@@ -2766,10 +2882,10 @@ def gru_step_layer(input,
Layer( Layer(
name=name, name=name,
type=LayerType.GRU_STEP_LAYER, type=LayerType.GRU_STEP_LAYER,
# The parameter here is for transforming the output_mem. The input has # The parameter here is for transforming the output_mem. The input has
# already been transformed outside this module so it does not need # already been transformed outside this module so it does not need
# parameter associated with it. # parameter associated with it.
# The parameter here is instead grouped with input is due to # The parameter here is instead grouped with input is due to
# backward model compatibility. # backward model compatibility.
inputs=[Input(input.name, **param_attr.attr), output_mem.name], inputs=[Input(input.name, **param_attr.attr), output_mem.name],
bias=ParamAttr.to_bias(bias_attr), bias=ParamAttr.to_bias(bias_attr),
...@@ -3420,6 +3536,7 @@ def classification_cost(input, ...@@ -3420,6 +3536,7 @@ def classification_cost(input,
label, label,
weight=None, weight=None,
name=None, name=None,
top_k=None,
evaluator=classification_error_evaluator, evaluator=classification_error_evaluator,
layer_attr=None): layer_attr=None):
""" """
...@@ -3434,6 +3551,8 @@ def classification_cost(input, ...@@ -3434,6 +3551,8 @@ def classification_cost(input,
:param weight: The weight affects the cost, namely the scale of cost. :param weight: The weight affects the cost, namely the scale of cost.
It is an optional argument. It is an optional argument.
:type weight: LayerOutput :type weight: LayerOutput
:param top_k: number k in top-k error rate
:type top_k: int
:param evaluator: Evaluator method. :param evaluator: Evaluator method.
:param layer_attr: layer's extra attribute. :param layer_attr: layer's extra attribute.
:type layer_attr: ExtraLayerAttribute :type layer_attr: ExtraLayerAttribute
...@@ -3461,7 +3580,7 @@ def classification_cost(input, ...@@ -3461,7 +3580,7 @@ def classification_cost(input,
assert isinstance(e.for_classification, bool) assert isinstance(e.for_classification, bool)
assert e.for_classification assert e.for_classification
e(name=e.__name__, input=input, label=label, weight=weight) e(name=e.__name__, input=input, label=label, weight=weight, top_k=top_k)
if not isinstance(evaluator, collections.Sequence): if not isinstance(evaluator, collections.Sequence):
evaluator = [evaluator] evaluator = [evaluator]
...@@ -3677,26 +3796,27 @@ def pad_layer(input, ...@@ -3677,26 +3796,27 @@ def pad_layer(input,
For example, For example,
.. code-block:: .. code-block:: python
input(2,2,2,3) = [ input(2,2,2,3) = [
[ [[1,2,3], [3,4,5]], [ [[1,2,3], [3,4,5]],
[[2,3,5], [1,6,7]] ], [[2,3,5], [1,6,7]] ],
[ [[4,3,1], [1,8,7]], [ [[4,3,1], [1,8,7]],
[[3,8,9], [2,3,5]] ] [[3,8,9], [2,3,5]] ]
] ]
pad_c=[1,1], pad_h=[0,0], pad_w=[0,0] pad_c=[1,1], pad_h=[0,0], pad_w=[0,0]
output(2,4,2,3) = [
[ [[0,0,0], [0,0,0]], output(2,4,2,3) = [
[[1,2,3], [3,4,5]], [ [[0,0,0], [0,0,0]],
[[2,3,5], [1,6,7]], [[1,2,3], [3,4,5]],
[[0,0,0], [0,0,0]] ], [[2,3,5], [1,6,7]],
[ [[0,0,0], [0,0,0]], [[0,0,0], [0,0,0]] ],
[[4,3,1], [1,8,7]], [ [[0,0,0], [0,0,0]],
[[3,8,9], [2,3,5]], [[4,3,1], [1,8,7]],
[[0,0,0], [0,0,0]] ] [[3,8,9], [2,3,5]],
] [[0,0,0], [0,0,0]] ]
]
The simply usage is: The simply usage is:
...@@ -4191,13 +4311,7 @@ def block_expand_layer(input, ...@@ -4191,13 +4311,7 @@ def block_expand_layer(input,
@wrap_name_default() @wrap_name_default()
@layer_support() @layer_support()
def maxout_layer(input, def maxout_layer(input, groups, num_channels=None, name=None, layer_attr=None):
groups,
num_channels=None,
size_x=None,
size_y=None,
name=None,
layer_attr=None):
""" """
A layer to do max out on conv layer output. A layer to do max out on conv layer output.
- Input: output of a conv layer. - Input: output of a conv layer.
...@@ -4227,12 +4341,6 @@ def maxout_layer(input, ...@@ -4227,12 +4341,6 @@ def maxout_layer(input,
:type num_channels: int|None :type num_channels: int|None
:param groups: The group number of input layer. :param groups: The group number of input layer.
:type groups: int :type groups: int
:param size_x: conv output width. If None will be set
automatically from previous output.
:type size_x: int|None
:param size_y: conv output height. If None will be set
automatically from previous output.
:type size_y: int|None
:param name: The name of this layer, which can not specify. :param name: The name of this layer, which can not specify.
:type name: None|basestring. :type name: None|basestring.
:param layer_attr: Extra Layer attribute. :param layer_attr: Extra Layer attribute.
......
...@@ -4,6 +4,7 @@ test_sequence_pooling test_lstmemory_layer test_grumemory_layer ...@@ -4,6 +4,7 @@ test_sequence_pooling test_lstmemory_layer test_grumemory_layer
last_first_seq test_expand_layer test_ntm_layers test_hsigmoid last_first_seq test_expand_layer test_ntm_layers test_hsigmoid
img_layers img_trans_layers util_layers simple_rnn_layers unused_layers test_cost_layers img_layers img_trans_layers util_layers simple_rnn_layers unused_layers test_cost_layers
test_rnn_group shared_fc shared_lstm shared_gru test_cost_layers_with_weight test_rnn_group shared_fc shared_lstm shared_gru test_cost_layers_with_weight
test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops) test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops
test_seq_concat_reshape)
export whole_configs=(test_split_datasource) export whole_configs=(test_split_datasource)
type: "nn"
layers {
name: "data1"
type: "data"
size: 30
active_type: ""
}
layers {
name: "data2"
type: "data"
size: 30
active_type: ""
}
layers {
name: "__seqconcat_0__"
type: "seqconcat"
size: 30
active_type: ""
inputs {
input_layer_name: "data1"
}
inputs {
input_layer_name: "data2"
}
}
layers {
name: "__seqreshape_0__"
type: "seqreshape"
size: 5
active_type: "linear"
inputs {
input_layer_name: "data1"
}
}
input_layer_names: "data1"
input_layer_names: "data2"
output_layer_names: "__seqconcat_0__"
output_layer_names: "__seqreshape_0__"
sub_models {
name: "root"
layer_names: "data1"
layer_names: "data2"
layer_names: "__seqconcat_0__"
layer_names: "__seqreshape_0__"
input_layer_names: "data1"
input_layer_names: "data2"
output_layer_names: "__seqconcat_0__"
output_layer_names: "__seqreshape_0__"
is_recurrent_layer_group: false
}
...@@ -19,7 +19,7 @@ model_config { ...@@ -19,7 +19,7 @@ model_config {
data_config { data_config {
type: "py2" type: "py2"
files: "train.list" files: "train.list"
async_load_data: true async_load_data: false
for_test: false for_test: false
load_data_module: "a" load_data_module: "a"
load_data_object: "c" load_data_object: "c"
...@@ -58,7 +58,7 @@ opt_config { ...@@ -58,7 +58,7 @@ opt_config {
test_data_config { test_data_config {
type: "py2" type: "py2"
files: "test.list" files: "test.list"
async_load_data: true async_load_data: false
for_test: true for_test: true
load_data_module: "b" load_data_module: "b"
load_data_object: "d" load_data_object: "d"
......
from paddle.trainer_config_helpers import *
# Config-parser test case covering seq_concat_layer and seq_reshape_layer.
settings(learning_rate=1e-5, batch_size=1000)

# Two data layers of identical width.
din1 = data_layer(size=30, name='data1')
din2 = data_layer(size=30, name='data2')

# The list literal is evaluated left to right, so the layers are created in
# the same order as with successive append() calls.
opts = [
    seq_concat_layer(a=din1, b=din2),
    seq_reshape_layer(input=din1, reshape_size=5),
]

outputs(opts)
...@@ -11,7 +11,25 @@ ...@@ -11,7 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import optimizer import optimizer
import layer
import activation
import parameters
import trainer
import event
import data_type
import attr
import py_paddle.swig_paddle as api
__all__ = [
'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer',
'event', 'data_type', 'attr'
]
def init(**kwargs):
    """
    Initialize Paddle from keyword arguments.

    Every ``key=value`` pair is rendered as a ``--key=value`` flag and all
    flags are handed to ``api.initPaddle``.

    :param kwargs: flag names mapped to their values. Values are converted
                   with ``str``.
    :return: Nothing
    """
    flags = ['--%s=%s' % (name, str(kwargs[name])) for name in kwargs]
    api.initPaddle(*flags)
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers.activations import *
# Names exported by this module.
__all__ = [
    "Base", "Tanh", "Sigmoid", "Softmax", "Identity", "Linear",
    'SequenceSoftmax', "Exp", "Relu", "BRelu", "SoftRelu", "STanh", "Abs",
    "Square", "Log"
]

# Short aliases for the *Activation classes brought in by the star import
# from paddle.trainer_config_helpers.activations.
Base = BaseActivation
Tanh = TanhActivation
Sigmoid = SigmoidActivation
Softmax = SoftmaxActivation
SequenceSoftmax = SequenceSoftmaxActivation
Identity = IdentityActivation
# Linear is a second name for Identity (both refer to IdentityActivation).
Linear = Identity
Relu = ReluActivation
BRelu = BReluActivation
SoftRelu = SoftReluActivation
STanh = STanhActivation
Abs = AbsActivation
Square = SquareActivation
Exp = ExpActivation
Log = LogActivation
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers.attrs import *
# Names exported by this module.
__all__ = [
    "Param",
    "Extra",
]

# Short aliases for the attribute classes brought in by the star import from
# paddle.trainer_config_helpers.attrs.
Param = ParameterAttribute
Extra = ExtraLayerAttribute
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer.PyDataProvider2 import \
InputType, dense_vector, sparse_binary_vector,\
sparse_vector, integer_value
__all__ = [
'InputType', 'dense_vector', 'sparse_binary_vector', 'sparse_vector',
'integer_value'
]
"""
All training events.
There are:
* BeginTraining
* EndTraining
* BeginIteration
* EndIteration
* BeginPass
* EndPass
TODO(yuyang18): Complete it!
"""
__all__ = ['EndIteration']
class EndIteration(object):
    """
    Event raised when the training of one batch has completed.

    The instance simply carries the three values it was constructed with.
    """

    def __init__(self, pass_id, batch_id, cost):
        """
        :param pass_id: id of the pass the finished batch belongs to.
        :param batch_id: id of the finished batch.
        :param cost: cost reported for the finished batch.
        """
        self.pass_id, self.batch_id, self.cost = pass_id, batch_id, cost
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Before this new package paddle.v2.layer, users would need to use functions
in paddle.trainer_config_helpers.layers to configure networks.
The Old Way:
=========
This old way requires that the creation of a network be defined in a Python
function, say network_config, and that this Python function being passed to
paddle.trainer_config_helpers.parse_network_config for the creation of
protobuf message description of this network.
```python
def network_config():
img = paddle.trainer_config_helpers.data_layer(name="pixel", size=784)
inference = paddle.trainer_config_helpers.fc_layer(
input=img,
size=10,
act=paddle.trainer_config_helpers.SoftmaxActivation())
cost = paddle.trainer_config_helpers.classification_cost(
input=inference,
label=paddle.trainer_config_helpers.data_layer(name="label", size=10))
proto_desc = parse_network_config(network_config)
```
When parse_network_config executes network_config, those layer definition
functions like data_layer and fc_layer would change some Python global variables,
so that after the execution, parse_network_config could collect information from
these global variables and generates the protobuf message.
The New Way:
=========
In this PR, we define a function in paddle.v2.layer which creates a Python
class for each layer creation function in paddle.trainer_config_helpers.layers.
Users can create a network as follows:
```python
img = paddle.v2.layer.data(
    name="pixel", type=paddle.v2.data_type.dense_vector(784))
inference = paddle.v2.layer.fc(
    input=img, size=10, act=paddle.v2.activation.Softmax())
cost = paddle.v2.layer.classification_cost(
    input=inference,
    label=paddle.v2.layer.data(
        name="label", type=paddle.v2.data_type.integer_value(10)))
parameters = paddle.v2.parameters.create(cost)
```
This new way doesn't require those invocations to layer definition functions
to be in a Python function but could be anywhere.
Also, the creation of a protobuf message is hidden in the invocation of
paddle.v2.parameters.create, no longer exposed to users.
"""
import collections
import paddle.trainer_config_helpers as conf_helps
from paddle.trainer_config_helpers.config_parser_utils import \
parse_network_config as __parse__
from paddle.trainer_config_helpers.default_decorators import wrap_name_default
import data_type
import activation
import attr
__all__ = [
'parse_network', 'data', 'fc', 'max_id', 'classification_cost',
'cross_entropy_cost', 'cross_entropy_with_selfnorm_cost', 'regression_cost',
'multi_binary_label_cross_entropy_cost', 'rank_cost', 'lambda_cost',
'sum_cost', 'huber_cost'
]
def parse_network(*outputs):
    """
    Parse all output layers and then generate a model config proto.

    :param outputs: the v2 layers that are the outputs of the network.
    :return: the result of ``parse_network_config`` applied to a function
             that declares the given outputs.
    """

    def __declare_outputs__():
        # A single context is shared by all outputs, so a layer reachable
        # from several outputs is converted only once.
        shared_context = dict()
        v1_outputs = []
        for out_layer in outputs:
            v1_outputs.append(out_layer.to_proto(context=shared_context))
        conf_helps.outputs(v1_outputs)

    return __parse__(__declare_outputs__)
class Layer(object):
    """
    Base class of all v2 layers.

    A layer only records its name and its parent layers. The actual v1 layer
    is created lazily by :meth:`to_proto`, which converts the parents first
    and then delegates to :meth:`to_proto_impl`.
    """

    def __init__(self, name, parent_layers):
        """
        :param name: name of this layer. It is also the cache key used by
                     :meth:`to_proto`.
        :type name: basestring
        :param parent_layers: maps an argument name to a parent layer or to
                              a sequence of parent layers.
        :type parent_layers: dict
        """
        assert isinstance(parent_layers, dict)
        assert isinstance(name, basestring)
        self.name = name
        self.__parent_layers__ = parent_layers

    def to_proto(self, context):
        """
        Convert this layer, and first all of its parents, to v1 layers.

        :param context: cache of already converted layers, keyed by layer
                        name. It is updated in place.
        :type context: dict
        :return: the value produced by :meth:`to_proto_impl` for this layer,
                 taken from ``context`` when it was converted before.
        """
        # A layer that is already in the cache had its parents converted when
        # it was first seen. Returning here avoids walking the whole ancestor
        # graph again, which is exponential for diamond-shaped topologies.
        if self.name in context:
            return context[self.name]

        kwargs = dict()
        for layer_name in self.__parent_layers__:
            parent = self.__parent_layers__[layer_name]
            if not isinstance(parent, collections.Sequence):
                v1_layer = parent.to_proto(context=context)
            else:
                # A list is built eagerly so every parent is converted now.
                v1_layer = [each.to_proto(context=context) for each in parent]
            kwargs[layer_name] = v1_layer

        # Checked again because converting the parents may have filled in an
        # entry with the same name.
        if self.name not in context:
            context[self.name] = self.to_proto_impl(**kwargs)
        return context[self.name]

    def to_proto_impl(self, **kwargs):
        """
        Create the v1 layer. Must be implemented by subclasses.

        :param kwargs: converted parent layers, keyed by argument name.
        :raise NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
def __convert_to_v2__(method_name, name_prefix, parent_names):
    """
    Build a v2 layer class that wraps a layer function of
    paddle.trainer_config_helpers.

    :param method_name: name of the function in ``conf_helps`` that creates
                        the v1 layer.
    :type method_name: basestring
    :param name_prefix: prefix handed to ``wrap_name_default`` to decorate
                        the constructor. ``None`` leaves the constructor
                        undecorated.
    :type name_prefix: basestring|None
    :param parent_names: constructor keyword arguments that hold parent
                         layers. All other keyword arguments are forwarded
                         unchanged to the v1 layer function.
    :type parent_names: list
    :return: the generated class, a subclass of Layer.
    """
    if name_prefix is not None:
        wrapper = wrap_name_default(name_prefix=name_prefix)
    else:
        wrapper = None

    class V2LayerImpl(Layer):
        def __init__(self, name=None, **kwargs):
            # Parents are collected in the order of parent_names.
            parent_layers = dict()
            for pname in parent_names:
                # `in` replaces the deprecated dict.has_key().
                if pname in kwargs:
                    parent_layers[pname] = kwargs[pname]

            other_kwargs = dict()
            for key in kwargs:
                if key not in parent_names:
                    other_kwargs[key] = kwargs[key]

            super(V2LayerImpl, self).__init__(name, parent_layers)
            self.__other_kwargs__ = other_kwargs

        if wrapper is not None:
            __init__ = wrapper(__init__)

        def to_proto_impl(self, **kwargs):
            # Converted parents first, then the plain keyword arguments,
            # which win when a key occurs in both.
            args = dict(kwargs)
            args.update(self.__other_kwargs__)
            return getattr(conf_helps, method_name)(name=self.name, **args)

    return V2LayerImpl
"""
Some layers need special configuration and cannot be converted with
__convert_to_v2__, so they are implemented below as dedicated v2 layer classes.
"""
class DataLayerV2(Layer):
    """
    The v2 data layer. It has no parent layers and takes its size from the
    ``dim`` attribute of the given input type.
    """

    def __init__(self, name, type, **kwargs):
        """
        :param name: name of the data layer.
        :param type: input type of the data.
        :type type: data_type.InputType
        :param kwargs: extra keyword arguments forwarded to ``data_layer``.
        """
        assert isinstance(type, data_type.InputType)
        super(DataLayerV2, self).__init__(name=name, parent_layers=dict())
        self.type = type
        self.__method_name__ = 'data_layer'
        self.__kwargs__ = kwargs

    def to_proto_impl(self, **kwargs):
        """
        Create the v1 data layer.

        :param kwargs: converted parent layers. Always empty here because
                       the layer is created without parents.
        :return: the value returned by the v1 layer function.
        """
        # Later updates override earlier ones: size, then kwargs, then the
        # keyword arguments stored by the constructor.
        args = {'size': self.type.dim}
        args.update(kwargs)
        args.update(self.__kwargs__)
        layer_func = getattr(conf_helps, self.__method_name__)
        return layer_func(name=self.name, **args)
# Public v2 layer classes.
#
# `data` is the hand-written DataLayerV2. Every other name is generated by
# __convert_to_v2__ from the function of the same `method_name` in
# paddle.trainer_config_helpers. `parent_names` lists the keyword arguments
# that are treated as parent layers.
data = DataLayerV2
fc = __convert_to_v2__('fc_layer', name_prefix='fc', parent_names=['input'])
max_id = __convert_to_v2__(
    'maxid_layer', name_prefix='maxid', parent_names=['input'])

# Cost layers.
classification_cost = __convert_to_v2__(
    'classification_cost',
    name_prefix='classification_cost',
    parent_names=['input', 'label', 'weight'])
regression_cost = __convert_to_v2__(
    'regression_cost',
    name_prefix='regression_cost',
    parent_names=['input', 'label', 'weight'])
cross_entropy_cost = __convert_to_v2__(
    'cross_entropy',
    name_prefix='cross_entropy',
    parent_names=['input', 'label'])
cross_entropy_with_selfnorm_cost = __convert_to_v2__(
    'cross_entropy_with_selfnorm',
    name_prefix='cross_entropy_with_selfnorm',
    parent_names=['input', 'label'])
multi_binary_label_cross_entropy_cost = __convert_to_v2__(
    'multi_binary_label_cross_entropy',
    name_prefix='multi_binary_label_cross_entropy',
    parent_names=['input', 'label'])
rank_cost = __convert_to_v2__(
    'rank_cost',
    name_prefix='rank_cost',
    parent_names=['left', 'right', 'label', 'weight'])
lambda_cost = __convert_to_v2__(
    'lambda_cost', name_prefix='lambda_cost', parent_names=['input', 'score'])
sum_cost = __convert_to_v2__(
    'sum_cost', name_prefix='sum_cost', parent_names=['input'])
huber_cost = __convert_to_v2__(
    'huber_cost', name_prefix='huber_cost', parent_names=['input', 'label'])
if __name__ == '__main__':
    # Demo: build a small network with every exported layer and print the
    # parsed configuration of several output combinations.
    pixel = data(name='pixel', type=data_type.dense_vector(784))
    label = data(name='label', type=data_type.integer_value(10))
    weight = data(name='weight', type=data_type.dense_vector(10))
    score = data(name='score', type=data_type.dense_vector(1))

    hidden = fc(
        input=pixel,
        size=100,
        act=activation.Sigmoid(),
        param_attr=attr.Param(name='hidden'))
    inference = fc(input=hidden, size=10, act=activation.Softmax())
    maxid = max_id(input=inference)

    cost1 = classification_cost(input=inference, label=label)
    cost2 = classification_cost(input=inference, label=label, weight=weight)
    cost3 = cross_entropy_cost(input=inference, label=label)
    cost4 = cross_entropy_with_selfnorm_cost(input=inference, label=label)
    cost5 = regression_cost(input=inference, label=label)
    cost6 = regression_cost(input=inference, label=label, weight=weight)
    cost7 = multi_binary_label_cross_entropy_cost(input=inference, label=label)
    cost8 = rank_cost(left=score, right=score, label=score)
    cost9 = lambda_cost(input=inference, score=score)
    cost10 = sum_cost(input=inference)
    cost11 = huber_cost(input=score, label=label)

    # Each tuple is parsed and printed as one network, in this order.
    output_groups = (
        (cost1, cost2),
        (cost3, cost4),
        (cost5, cost6),
        (cost7, cost8, cost9, cost10, cost11),
        (inference, maxid), )
    for output_group in output_groups:
        print(parse_network(*output_group))
import numpy as np
from . import layer as v2_layer
import py_paddle.swig_paddle as api
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
__all__ = ['Parameters', 'create']
def create(*layers):
    """
    Create a parameter pool from layers. The layers are parsed into a model
    config and every parameter configuration found there is added to the
    pool.

    :param layers: the output layers that describe the network.
    :type layers: v2_layer.Layer
    :return: the parameter pool for the network.
    :rtype: Parameters
    :raise ValueError: if any argument is not a v2_layer.Layer.
    """
    if not all(isinstance(each, v2_layer.Layer) for each in layers):
        raise ValueError(
            'create must pass a topologies which type is paddle.layer.Layer')

    parsed = v2_layer.parse_network(*layers)
    pool = Parameters()
    for param_conf in parsed.parameters:
        pool.__append_config__(param_conf)
    return pool
class Parameters(object):
    """
    Parameters is a dictionary-like pool of Paddle's parameters. The key is
    the name of a parameter. The value is a plain :code:`numpy.ndarray` .

    Basic usage is

    ..  code-block:: python

        data = paddle.layers.data(...)
        ...
        out = paddle.layers.fc(...)

        parameters = paddle.parameters.create(out)
        parameter_names = parameters.names()
        fc_mat = parameters.get('fc')
        print fc_mat
    """

    def __init__(self):
        # parameter name -> ParameterConfig
        self.__param_conf__ = dict()
        # gradient machines this pool has been attached to
        self.__gradient_machines__ = []
        # (name, value) pairs assigned before any gradient machine existed,
        # kept in assignment order
        self.__tmp_params__ = []

    def __append_config__(self, param_conf):
        """
        Append a parameter configuration. It is used to initialize Parameters
        and should be invoked only in paddle.parameters.create

        :param param_conf: The parameter configuration in protobuf
        :type param_conf: ParameterConfig
        :return: Nothing
        :raise ValueError: if param_conf is not a ParameterConfig or its name
                           is already registered.
        """
        if not isinstance(param_conf, ParameterConfig):
            raise ValueError("param_conf must be paddle.proto.ParameterConfig")

        if param_conf.name in self.__param_conf__:
            raise ValueError("duplicated parameter %s" % param_conf.name)

        self.__param_conf__[param_conf.name] = param_conf

    def keys(self):
        """
        keys are the names of each parameter.

        :return: list of parameter name
        :rtype: list
        """
        return self.__param_conf__.keys()

    def names(self):
        """
        names of each parameter.

        :return: list of parameter name
        :rtype: list
        """
        return self.keys()

    def has_key(self, key):
        """
        has_key returns true if there is a parameter whose name == key

        :param key: Parameter name
        :type key: basestring
        :return: True if contains such key
        """
        # Direct dict membership; no intermediate key list is needed.
        return key in self.__param_conf__

    def __iter__(self):
        """
        Return an iterator of parameter name. It is used by `for loop`
        or `in` operator.

        ..  code-block:: python

            parameters = paddle.parameters.create(...)
            if "fc_param" in parameters:
                print 'OK'

        :return: an iterator of parameter name
        :rtype: iterator
        """
        return iter(self.__param_conf__)

    def __getitem__(self, key):
        """
        Get parameter by parameter name. It uses Python dict syntax.

        :note: When a gradient machine is attached, the value is always
               copied from the C++ side. Before that, the most recent value
               assigned with __setitem__ is returned as a copy; if none was
               assigned, a new array of the parameter's shape with
               uninitialized content is returned.

        :param key: Parameter name
        :type key: basestring
        :return: parameter value
        :rtype: np.ndarray
        :raise ValueError: if key is not a string or no such parameter exists.
        """
        shape = self.get_shape(key)

        if len(self.__gradient_machines__) == 0:
            # Values set before training must be readable again. The list is
            # searched backwards so that the latest assignment wins.
            for name, value in reversed(self.__tmp_params__):
                if name == key:
                    return value.copy()
            # create new parameter in python numpy.
            return np.ndarray(shape=shape, dtype=np.float32)

        # The value is read from the first attached gradient machine only.
        gradient_machine = self.__gradient_machines__[0]
        param = __get_parameter_in_gradient_machine__(gradient_machine, key)
        # for simplify implementation now, we always copy from C++
        assert isinstance(param, api.Parameter)
        val = param.getBuf(api.PARAMETER_VALUE)
        assert isinstance(val, api.Vector)
        return val.copyToNumpyArray()

    def get_shape(self, key):
        """
        get shape of the parameter.

        :param key: parameter name
        :type key: basestring
        :return: parameter's shape
        :rtype: tuple
        :raise ValueError: if key is not a string or no such parameter exists.
        """
        if not isinstance(key, basestring):
            raise ValueError("parameter name should be string")
        if not self.has_key(key):
            raise ValueError("No such parameter %s" % key)
        conf = self.__param_conf__[key]
        return tuple(map(int, conf.dims))

    def __setitem__(self, key, value):
        """
        Set parameter by parameter name & value. It uses Python dict syntax.

        :note: When gradient machines are attached, the value is copied to
               every one of them. Otherwise it is kept in this pool until a
               gradient machine is appended.

        :param key: Parameter name
        :type key: basestring
        :param value: Parameter matrix.
        :type value: np.ndarray
        :return: Nothing
        :raise ValueError: if value is not an ndarray, its shape does not
                           match the parameter, or key is invalid.
        """
        if not isinstance(value, np.ndarray):
            raise ValueError("Must return ndarray")
        value = value.astype(dtype=np.float32)
        shape = self.get_shape(key)
        if value.shape != shape:
            raise ValueError("Value shape mismatch, expect %s, should %s" %
                             (shape, value.shape))

        if len(self.__gradient_machines__) == 0:
            self.__tmp_params__.append((key, value))
        else:
            for each_gradient_machine in self.__gradient_machines__:
                __copy_parameter_to_gradient_machine__(each_gradient_machine,
                                                       key, value)

    def get(self, parameter_name):
        """
        Get parameter by parameter name. Same as ``parameters[name]``.

        :param parameter_name: parameter name
        :type parameter_name: basestring
        :return: The parameter matrix.
        :rtype: np.ndarray
        """
        return self.__getitem__(key=parameter_name)

    def set(self, parameter_name, value):
        """
        Set parameter by parameter name & matrix. Same as
        ``parameters[name] = value``.

        :param parameter_name: parameter name
        :type parameter_name: basestring
        :param value: parameter matrix
        :type value: np.ndarray
        :return: Nothing.
        """
        self.__setitem__(key=parameter_name, value=value)

    def append_gradient_machine(self, gradient_machine):
        """
        append gradient machine to parameters. This method is used internally
        in Trainer.train.

        Values assigned before any gradient machine was attached are copied
        into the new machine in assignment order.

        :param gradient_machine: Paddle C++ GradientMachine object.
        :type gradient_machine: api.GradientMachine
        :return: Nothing
        :raise ValueError: if gradient_machine is not an api.GradientMachine.
        """
        if not isinstance(gradient_machine, api.GradientMachine):
            raise ValueError("gradient_machine should be api.GradientMachine")

        for name, val in self.__tmp_params__:
            try:
                __copy_parameter_to_gradient_machine__(gradient_machine, name,
                                                       val)
            except ValueError:
                # If no such parameter in gradient machine, then don't copy
                pass

        self.__gradient_machines__.append(gradient_machine)
def __get_parameter_in_gradient_machine__(gradient_machine, name):
    """
    Find the parameter with the given name in a gradient machine.

    :param gradient_machine: the gradient machine to search.
    :type gradient_machine: api.GradientMachine
    :param name: parameter name.
    :return: the only parameter whose ``getName()`` equals name.
    :rtype: api.Parameter
    :raise ValueError: if no parameter, or more than one parameter, has the
                       given name.
    """
    # A list comprehension always yields a list, so len() below is valid.
    matches = [
        param for param in gradient_machine.getParameters()
        if param.getName() == name
    ]

    if len(matches) == 0:
        raise ValueError("No such parameter")
    elif len(matches) > 1:
        raise ValueError("Unexpected branch")
    else:
        return matches[0]
def __copy_parameter_to_gradient_machine__(gradient_machine, name, arr):
    """
    Copy a python ndarray into the value buffer of a parameter of the
    gradient machine. The array is flattened before it is copied.

    :param gradient_machine: the gradient machine that owns the parameter.
    :type gradient_machine: api.GradientMachine
    :param name: parameter name.
    :param arr: the new parameter value.
    :type arr: np.ndarray
    :return: Nothing
    """
    target = __get_parameter_in_gradient_machine__(gradient_machine, name)
    value_buf = target.getBuf(api.PARAMETER_VALUE)
    assert isinstance(value_buf, api.Vector)
    flat = arr.flatten()
    value_buf.copyFromNumpyArray(flat)
# Register the paddle.v2.layer unit test with CTest. The test script is run
# by ${PYTHON_EXECUTABLE} through .set_python_path.sh from the python/paddle
# directory.
# NOTE(review): "-d ${PROJ_ROOT}/python/" presumably adds that directory to
# the Python path; confirm against .set_python_path.sh.
add_test(NAME test_v2_layer
  COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/
  ${PYTHON_EXECUTABLE} ${PROJ_ROOT}/python/paddle/v2/tests/test_layer.py
  WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle)
# Copyright PaddlePaddle contributors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import difflib
import unittest
import paddle.trainer_config_helpers as conf_helps
import paddle.v2.activation as activation
import paddle.v2.attr as attr
import paddle.v2.data_type as data_type
import paddle.v2.layer as layer
from paddle.trainer_config_helpers.config_parser_utils import \
parse_network_config as parse_network
# Module-level fixtures shared by the test cases below: four data layers and
# a two-layer fully connected network whose output is `inference`.
pixel = layer.data(name='pixel', type=data_type.dense_vector(784))
label = layer.data(name='label', type=data_type.integer_value(10))
weight = layer.data(name='weight', type=data_type.dense_vector(10))
score = layer.data(name='score', type=data_type.dense_vector(1))
# The hidden layer uses an explicitly named parameter ('hidden').
hidden = layer.fc(input=pixel,
                  size=100,
                  act=activation.Sigmoid(),
                  param_attr=attr.Param(name='hidden'))
inference = layer.fc(input=hidden, size=10, act=activation.Softmax())
class CostLayerTest(unittest.TestCase):
    """Construct every v2 cost layer and parse a network of two of them."""

    def test_cost_layer(self):
        # Classification costs, without and with a sample weight.
        cost1 = layer.classification_cost(input=inference, label=label)
        cost2 = layer.classification_cost(
            input=inference, label=label, weight=weight)

        # Cross entropy costs.
        cost3 = layer.cross_entropy_cost(input=inference, label=label)
        cost4 = layer.cross_entropy_with_selfnorm_cost(
            input=inference, label=label)

        # Regression costs, without and with a sample weight.
        cost5 = layer.regression_cost(input=inference, label=label)
        cost6 = layer.regression_cost(
            input=inference, label=label, weight=weight)

        # Remaining cost layers.
        cost7 = layer.multi_binary_label_cross_entropy_cost(
            input=inference, label=label)
        cost8 = layer.rank_cost(left=score, right=score, label=score)
        cost9 = layer.lambda_cost(input=inference, score=score)
        cost10 = layer.sum_cost(input=inference)
        cost11 = layer.huber_cost(input=score, label=label)

        # The names in the layer module are printed before and after parsing.
        print(dir(layer))
        layer.parse_network(cost1, cost2)
        print(dir(layer))

        # Only the first pair is parsed; the calls below are disabled.
        #print layer.parse_network(cost3, cost4)
        #print layer.parse_network(cost5, cost6)
        #print layer.parse_network(cost7, cost8, cost9, cost10, cost11)
# Run the test cases when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
import collections
import py_paddle.swig_paddle as api
from paddle.proto.ModelConfig_pb2 import ModelConfig
from py_paddle import DataProviderConverter
from . import event as v2_event
from . import layer as v2_layer
from . import optimizer as v2_optimizer
from . import parameters as v2_parameters
__all__ = ['ITrainer', 'SGD']
def default_event_handler(event):
    """
    Default event handler. It is meant to print some log and save the model;
    for now it ignores the event and does nothing.

    TODO(yuyang18): Complete it!

    :param event: the training event.
    :return: None
    """
    return None
class ITrainer(object):
    """
    The interface of Trainer. The only exposed method is `train`, which
    every concrete trainer has to override.
    """

    def train(self,
              train_data_reader,
              topology,
              parameters,
              test_data_reader=None,
              event_handler=None):
        """
        Train method. Not implemented in the interface.

        :param train_data_reader: reader of the training data.
        :param topology: the network to train.
        :param parameters: the parameters of the network.
        :param test_data_reader: reader of the test data, optional.
        :param event_handler: handler of training events, optional.
        :return: Nothing
        :raise NotImplementedError: always.
        """
        raise NotImplementedError()
class SGD(ITrainer):
    """
    Simple SGD Trainer. It trains a topology with a local gradient machine
    and the updater created by the given optimizer.
    """

    def __init__(self, update_equation):
        """
        :param update_equation: The optimizer object.
        :type update_equation: v2_optimizer.Optimizer
        :raise ValueError: if update_equation is not a v2_optimizer.Optimizer.
        """
        if not isinstance(update_equation, v2_optimizer.Optimizer):
            raise ValueError("update equation parameter must be "
                             "paddle.v2.optimizer.Optimizer")
        self.__optimizer__ = update_equation

    def train(self,
              train_data_reader,
              topology,
              parameters,
              num_passes=1,
              test_data_reader=None,
              event_handler=None,
              batch_size=32,
              data_types=None):
        """
        Training method. Will train num_passes of input data.

        :param train_data_reader: a function that returns an iterator over
                                  the training samples. Every sample is
                                  indexed by input layer name.
        :param topology: Network Topology. It is handed to
                         v2_layer.parse_network.
        :param parameters: The parameter pools.
        :param num_passes: The total train passes.
        :param test_data_reader: optional reader of the test data. It is only
                                 validated here, not used for training.
        :param event_handler: Event handler. A method will be invoked when
                              event occurred. default_event_handler is used
                              when it is None.
        :type event_handler: (BaseEvent) => None
        :param batch_size: Not important, will be removed after data refactor.
        :param data_types: Maps every input layer name of the topology to its
                           input type. Not important, will be removed after
                           data refactor.
        :return: Nothing
        :raise ValueError: if an argument has the wrong type or data_types
                           lacks an input layer of the topology.
        """
        if event_handler is None:
            event_handler = default_event_handler

        topology = v2_layer.parse_network(topology)

        # The arguments are passed explicitly instead of through **locals(),
        # so the check does not depend on which local names exist here.
        __check_train_args__(
            train_data_reader=train_data_reader,
            topology=topology,
            parameters=parameters,
            test_data_reader=test_data_reader,
            event_handler=event_handler)

        # Validate data_types before any gradient machine is created or
        # started, so that a bad argument leaves nothing half initialized.
        # None is treated as an empty mapping instead of failing with a
        # TypeError in the membership test.
        if data_types is None:
            data_types = dict()
        data_types_lists = []
        for each in topology.input_layer_names:
            if each not in data_types:
                raise ValueError(
                    "data_types has no entry for input layer %s" % each)
            data_types_lists.append(data_types[each])

        gm = api.GradientMachine.createFromConfigProto(
            topology, api.CREATE_MODE_NORMAL, self.__optimizer__.enable_types())
        assert isinstance(gm, api.GradientMachine)
        parameters.append_gradient_machine(gm)

        updater = self.__optimizer__.create_local_updater()
        updater.init(gm)

        gm.start()
        out_args = api.Arguments.createArguments(0)

        converter = DataProviderConverter(input_types=data_types_lists)

        for pass_id in xrange(num_passes):
            updater.startPass()
            for batch_id, data_batch in enumerate(
                    __data_reader_to_batch__(train_data_reader, batch_size,
                                             topology)):
                pass_type = updater.startBatch(len(data_batch))
                gm.forwardBackward(converter(data_batch), out_args, pass_type)
                for each_param in gm.getParameters():
                    updater.update(each_param)

                # Get cost. We use numpy to calculate total cost for this
                # batch.
                cost_vec = out_args.getSlotValue(0)
                cost_vec = cost_vec.copyToNumpyMat()
                cost = cost_vec.sum() / len(data_batch)
                updater.finishBatch(cost)
                event_handler(
                    v2_event.EndIteration(
                        pass_id=pass_id, batch_id=batch_id, cost=cost))

            updater.finishPass()
        gm.finish()
def __data_reader_to_batch__(reader, batch_size, topology):
    """
    Turn a data reader into a generator of batches whose samples are ordered
    like the input layers of the topology.

    This function is not important, and will be removed when data refactored.

    :param reader: a function that returns an iterable of samples. Every
                   sample is indexed by input layer name.
    :param batch_size: number of samples per batch.
    :param topology: object whose ``input_layer_names`` gives the order of
                     the values in each sample.
    :return: a generator of batches, each a list of reordered samples.
    """

    def input_reorder(func):
        # One list per sample, following the order of the input layers.
        for sample in func():
            yield [sample[name] for name in topology.input_layer_names]

    return __generator_to_batch__(input_reorder(reader), batch_size=batch_size)
def __generator_to_batch__(generator, batch_size):
    """
    Group the items of a generator into lists of batch_size items. The last
    list may be shorter; an empty list is never yielded.

    This function is not important, and will be removed when data refactored.

    :param generator: iterable of items.
    :param batch_size: number of items per batch.
    :return: a generator of lists of items.
    """
    batch = []
    for sample in generator:
        batch.append(sample)
        if len(batch) != batch_size:
            continue
        yield batch
        batch = []
    # Flush the incomplete trailing batch, if there is one.
    if batch:
        yield batch
def __check_train_args__(train_data_reader, topology, parameters,
                         test_data_reader, event_handler, **kwargs):
    """
    Check train function's argument types.

    :param train_data_reader: must be callable and return an iterator.
    :param topology: must be a ModelConfig.
    :param parameters: must be a v2_parameters.Parameters.
    :param test_data_reader: None, or callable and returning an iterator.
    :param event_handler: must be callable.
    :param kwargs: ignored.
    :raise ValueError: on the first argument that fails its check.
    """

    def returns_iterator(reader):
        # The reader is called once to inspect what it returns.
        return callable(reader) and isinstance(reader(), collections.Iterator)

    if not returns_iterator(train_data_reader):
        raise ValueError('train_data_reader should be a function, '
                         'which can return a iterator')

    if test_data_reader is not None and not returns_iterator(
            test_data_reader):
        raise ValueError('test_data_reader should be a function, which can '
                         'return a iterator')

    if not isinstance(topology, ModelConfig):
        raise ValueError('topology should be a model config')

    if not isinstance(parameters, v2_parameters.Parameters):
        raise ValueError('parameters should be a parameter pool')

    if not callable(event_handler):
        raise ValueError('event handler should be a function')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册