diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py index fe39f0bd23f78e1a9d61f708dc880d9853b7a5f9..ea1caa7dd9653a2cc2860ace736fe3d25a3767e0 100644 --- a/demo/mnist/api_train.py +++ b/demo/mnist/api_train.py @@ -9,7 +9,6 @@ The user api could be simpler and carefully designed. import random import numpy as np -import paddle.trainer.PyDataProvider2 as dp import paddle.v2 as paddle_v2 import py_paddle.swig_paddle as api from paddle.trainer_config_helpers import * @@ -71,8 +70,10 @@ def main(): assert isinstance(updater, api.ParameterUpdater) # define network - images = paddle_v2.layer.data(name='pixel', size=784) - label = paddle_v2.layer.data(name='label', size=10) + images = paddle_v2.layer.data( + name='pixel', type=paddle_v2.data_type.dense_vector(784)) + label = paddle_v2.layer.data( + name='label', type=paddle_v2.data_type.integer_value(10)) hidden1 = paddle_v2.layer.fc(input=images, size=200) hidden2 = paddle_v2.layer.fc(input=hidden1, size=200) inference = paddle_v2.layer.fc(input=hidden2, @@ -98,8 +99,7 @@ def main(): # DataProvider Converter is a utility convert Python Object to Paddle C++ # Input. The input format is as same as Paddle's DataProvider. - converter = DataProviderConverter( - input_types=[dp.dense_vector(784), dp.integer_value(10)]) + converter = DataProviderConverter(input_types=[images.type, label.type]) train_file = './data/raw_data/train' test_file = './data/raw_data/t10k' diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index b5cc74ce67dfc8e1afa65bd52f5ec600260032ce..6fc01ce58be57c77144c6558d039430b22d3a746 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -1,6 +1,5 @@ import numpy import paddle.v2 as paddle -from paddle.trainer.PyDataProvider2 import dense_vector, integer_value import mnist_util @@ -16,8 +15,10 @@ def main(): paddle.init(use_gpu=False, trainer_count=1) # define network topology - images = paddle.layer.data(name='pixel', size=784) - label = paddle.layer.data(name='label', size=10) + images = paddle.layer.data( + name='pixel', type=paddle.data_type.dense_vector(784)) + label = paddle.layer.data( + name='label', type=paddle.data_type.integer_value(10)) hidden1 = paddle.layer.fc(input=images, size=200) hidden2 = paddle.layer.fc(input=hidden1, size=200) inference = paddle.layer.fc(input=hidden2, @@ -51,8 +52,8 @@ def main(): batch_size=32, # batch size should be refactor in Data reader data_types={ # data_types will be removed, It should be in # network topology - 'pixel': dense_vector(784), - 'label': integer_value(10) + 'pixel': images.type, + 'label': label.type }) diff --git a/demo/sentiment/dataprovider.py b/demo/sentiment/dataprovider.py index 00f72cecacb454a0dd1184fa2098be4543007de7..4b7f5d0e504aef3884a04cbed8c16503a4079772 100755 --- a/demo/sentiment/dataprovider.py +++ b/demo/sentiment/dataprovider.py @@ -32,4 +32,6 @@ def process(settings, file_name): word_slot = [ settings.word_dict[w] for w in words if w in settings.word_dict ] + if not word_slot: + continue yield word_slot, label diff --git a/demo/sentiment/predict.py b/demo/sentiment/predict.py index 8ec490f64691924013200a3d0038d39aa834b038..64c78e0d6b9297e7a321a4f070517593b0bfe332 100755 --- a/demo/sentiment/predict.py +++ b/demo/sentiment/predict.py @@ -138,7 +138,11 @@ def main(): batch = [] for line in sys.stdin: - batch.append([predict.get_index(line)]) + words = predict.get_index(line) + if words: + batch.append([words]) + else: + print('All the words in [%s] are not in the dictionary.' % line) if len(batch) == batch_size: predict.batch_predict(batch) batch = [] diff --git a/doc/api/trainer_config_helpers/layers.rst b/doc/api/trainer_config_helpers/layers.rst index 8b0e553eacc932bc59062103ac6e6ac4245d03cb..2793d6afd9565eb461c8657b838b146fe1992b20 100644 --- a/doc/api/trainer_config_helpers/layers.rst +++ b/doc/api/trainer_config_helpers/layers.rst @@ -279,6 +279,12 @@ concat_layer :members: concat_layer :noindex: +seq_concat_layer +---------------- +.. automodule:: paddle.trainer_config_helpers.layers + :members: seq_concat_layer + :noindex: + Reshaping Layers ================ @@ -302,6 +308,12 @@ repeat_layer :members: repeat_layer :noindex: +seq_reshape_layer +----------------- +.. automodule:: paddle.trainer_config_helpers.layers + :members: seq_reshape_layer + :noindex: + Math Layers =========== diff --git a/doc/design/reader/README.md b/doc/design/reader/README.md index 8f7abf12f733542734efe91111f365a34aa4b15b..17d52b9e20b8130688028092421f4b33f44763ac 100644 --- a/doc/design/reader/README.md +++ b/doc/design/reader/README.md @@ -1,30 +1,40 @@ # Python Data Reader Design Doc -Paddle reads data from data reader during training. It will be passed into `paddle.train` as a parameter. +At training and testing time, PaddlePaddle programs need to read data. To ease the users' work to write data reading code, we define that + +- A *reader* is a function that reads data (from file, network, random number generator, etc) and yields data items. +- A *reader creator* is a function that returns a reader function. +- A *reader* decorator is a function, which accepts one or more readers, and returns a reader. + +and provide frequently used reader creators and reader decorators. ## Data Reader Interface -Data reader is a function with no parameter that creates a iterable (anything can be used in `for x in iterable`): +Indeed, *data reader* doesn't have to be a function that reads and yields data items. It can be any function with no parameter that creates a iterable (anything can be used in `for x in iterable`): ``` iterable = data_reader() ``` -Element produced for the iterable should be a **single** entry of data, **not** a mini batch. That entry of data could be a single item, or a tuple of items. Item should be of [supported type](http://www.paddlepaddle.org/doc/ui/data_provider/pydataprovider2.html?highlight=dense_vector#input-types) (e.g., numpy 1d array of float32, int, list of int) +Element produced from the iterable should be a **single** entry of data, **not** a mini batch. That entry of data could be a single item, or a tuple of items. Item should be of [supported type](http://www.paddlepaddle.org/doc/ui/data_provider/pydataprovider2.html?highlight=dense_vector#input-types) (e.g., numpy 1d array of float32, int, list of int) -An example implementation for single item data reader: +An example implementation for single item data reader creator: ```python -def data_reader_fake_image(): - while True: - yield numpy.random.uniform(-1, 1, size=20*20) +def reader_creator_random_image(width, height): + def reader(): + while True: + yield numpy.random.uniform(-1, 1, size=width*height) + return reader ``` -An example implementation for multiple item data reader: +An example implementation for multiple item data reader creator: ```python -def data_reader_fake_image_and_label(): - while True: - yield numpy.random.uniform(-1, 1, size=20*20), False +def reader_creator_random_imageand_label(widht, height, label): + def reader(): + while True: + yield numpy.random.uniform(-1, 1, size=width*height), label + return reader ``` ## Usage @@ -41,9 +51,9 @@ label_layer = paddle.layer.data("label", ...) paddle.train(paddle.dataset.mnist, {"image":0, "label":1}, 128, 10, ...) ``` -## Data Reader Decorators +## Data Reader Decorator -Data reader decorators takes a single or multiple data reader, returns a new data reader. It is similar to a [python decorator](https://wiki.python.org/moin/PythonDecorators), but it does not use `@` syntax. +*Data reader decorator* takes a single or multiple data reader, returns a new data reader. It is similar to a [python decorator](https://wiki.python.org/moin/PythonDecorators), but it does not use `@` syntax. Since we have a strict interface for data readers (no parameter, return a single data item). Data reader can be used flexiable via data reader decorators. Following are a few examples: @@ -61,23 +71,27 @@ buffered_reader = paddle.reader.buffered(paddle.dataset.mnist, 100) ### Compose Multiple Data Readers -For example, we want to use a source of real images (reusing mnist dataset), and a source of fake images as input for [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661). +For example, we want to use a source of real images (reusing mnist dataset), and a source of random images as input for [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661). We can do: ```python -def data_reader_fake_image(): - while True: - yield numpy.random.uniform(-1, 1, size=20*20) - -def data_reader_bool(t): - while True: - yield t - -true_reader = lambda : data_reader_bool(True) -false_reader = lambda : data_reader_bool(False) - -reader = paddle.reader.combine(paddle.dataset.mnist, data_reader_fake_image, true_reader, false_reader) +def reader_creator_random_image(width, height): + def reader(): + while True: + yield numpy.random.uniform(-1, 1, size=width*height) + return reader + +def reader_creator_bool(t): + def reader: + while True: + yield t + return reader + +true_reader = reader_creator_bool(True) +false_reader = reader_creator_bool(False) + +reader = paddle.reader.compose(paddle.dataset.mnist, data_reader_creator_random_image(20, 20), true_reader, false_reader) # Skipped 1 because paddle.dataset.mnist produces two items per data entry. # And we don't care second item at this time. paddle.train(reader, {"true_image":0, "fake_image": 2, "true_label": 3, "false_label": 4}, ...) @@ -98,28 +112,31 @@ reader = paddle.reader.shuffle(paddle.dataset.mnist, 512) If a mini batch is returned, data reader need to take care of batch size. But batch size is a concept for training, it makes more sense for user to specify batch size as a parameter for `train`. -Practically, always return a single entry make reusing existing data reader much easier (e.g., if existing data reader return not a single entry but 3 entries, training code will be more complex because it need to handle cases like batch size 2). +Practically, always return a single entry make reusing existing data readers much easier (e.g., if existing reader return not a single entry but 3 entries, training code will be more complex because it need to handle cases like batch size 2). ### Why use a dictionary but not a list to provide mapping? We decided to use dictionary (`{"image":0, "label":1}`) instead of list (`["image", "label"]`) is because that user can easily resue item (e.g., using `{"image_a":0, "image_b":0, "label":1}`) or skip item (e.g., using `{"image_a":0, "label":2}`). -### How to create custom data reader +### How to create custom data reader creator ```python -def image_reader(image_path, label_path): - f = open(image_path) - l = open(label_path) - images = numpy.fromfile( - f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28)).astype('float32') - images = images / 255.0 * 2.0 - 1.0 - labels = numpy.fromfile(l, 'ubyte', count=n).astype("int") - for i in xrange(n): - yield images[i, :], labels[i] # a single entry of data is created each time - f.close() - -# use python lambda to change image_reader into a function with no parameters. -reader = lambda : image_reader("/path/to/image_file", "/path/to/label_file") +def image_reader_creator(image_path, label_path, n): + def reader(): + f = open(image_path) + l = open(label_path) + images = numpy.fromfile( + f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28)).astype('float32') + images = images / 255.0 * 2.0 - 1.0 + labels = numpy.fromfile(l, 'ubyte', count=n).astype("int") + for i in xrange(n): + yield images[i, :], labels[i] # a single entry of data is created each time + f.close() + l.close() + return reader + +# images_reader_creator creates a reader +reader = image_reader_creator("/path/to/image_file", "/path/to/label_file", 1024) paddle.train(reader, {"image":0, "label":1}, ...) ``` @@ -128,7 +145,7 @@ paddle.train(reader, {"image":0, "label":1}, ...) An example implementation of paddle.train could be: ```python -def minibatch_decorater(reader, minibatch_size): +def make_minibatch(reader, minibatch_size): def ret(): r = reader() buf = [r.next() for x in xrange(minibatch_size)] @@ -139,6 +156,6 @@ def minibatch_decorater(reader, minibatch_size): def train(reader, mapping, batch_size, total_pass): for pass_idx in range(total_pass): - for mini_batch in minibatch_decorater(reader): # this loop will never end in online learning. + for mini_batch in make_minibatch(reader): # this loop will never end in online learning. do_forward_backward(mini_batch, mapping) ``` diff --git a/paddle/api/test/testMatrix.py b/paddle/api/test/testMatrix.py index 37666bdccc9aedfe8f8079124129aad2ade53a43..f08fbf3ccdf5d7c0a5c739868b1bcb516146c23d 100644 --- a/paddle/api/test/testMatrix.py +++ b/paddle/api/test/testMatrix.py @@ -68,7 +68,7 @@ class TestMatrix(unittest.TestCase): def test_numpyCpu(self): numpy_mat = np.matrix([[1, 2], [3, 4], [5, 6]], dtype="float32") - m = swig_paddle.Matrix.createCpuDenseFromNumpy(numpy_mat, copy=False) + m = swig_paddle.Matrix.createCpuDenseFromNumpy(numpy_mat, False) self.assertEqual((int(m.getHeight()), int(m.getWidth())), numpy_mat.shape) diff --git a/paddle/api/test/testVector.py b/paddle/api/test/testVector.py index 1ab095c1d3d0d2c84d2d2f95a03f172b901de209..6339cf8542607bdda99eb9ccaa8b06480f144b78 100644 --- a/paddle/api/test/testVector.py +++ b/paddle/api/test/testVector.py @@ -43,7 +43,7 @@ class TestIVector(unittest.TestCase): def test_cpu_numpy(self): vec = np.array([1, 3, 4, 65, 78, 1, 4], dtype="int32") - iv = swig_paddle.IVector.createCpuVectorFromNumpy(vec, copy=False) + iv = swig_paddle.IVector.createCpuVectorFromNumpy(vec, False) self.assertEqual(vec.shape[0], int(iv.__len__())) vec[4] = 832 for i in xrange(len(iv)): @@ -106,7 +106,7 @@ class TestVector(unittest.TestCase): def testCpuNumpy(self): numpy_arr = np.array([1.2, 2.3, 3.4, 4.5], dtype="float32") - vec = swig_paddle.Vector.createCpuVectorFromNumpy(numpy_arr, copy=False) + vec = swig_paddle.Vector.createCpuVectorFromNumpy(numpy_arr, False) assert isinstance(vec, swig_paddle.Vector) numpy_arr[0] = 0.1 for n, v in zip(numpy_arr, vec): diff --git a/paddle/cuda/include/hl_matrix.h b/paddle/cuda/include/hl_matrix.h index 6f21b82afdc6cdde785fdd8f13eef17a0fdd6324..eb454c59c1e58cf2b4817b4cb3230b9d75e320ac 100644 --- a/paddle/cuda/include/hl_matrix.h +++ b/paddle/cuda/include/hl_matrix.h @@ -69,19 +69,6 @@ extern void hl_sequence_softmax_forward(real* A_d, const int* index, int numSequence); -/** - * @brief Matrix classification error. - * - * @param[in] A_d input matrix (M x N). - * @param[in] B_d input vector (M x 1). - * @param[out] C_d output vector (M x 1). - * @param[in] dimM matrix height. - * @param[in] dimN matrix width. - * - */ -extern void hl_matrix_classification_error( - real* A_d, int* B_d, real* C_d, int dimM, int dimN); - /** * @brief Matrix cross entropy. * diff --git a/paddle/cuda/include/hl_top_k.h b/paddle/cuda/include/hl_top_k.h index 77949ed295a6eaf7cc535853e53bef066ffac37c..79ae0d0e741de06e622454ccd220e2c749d795b3 100644 --- a/paddle/cuda/include/hl_top_k.h +++ b/paddle/cuda/include/hl_top_k.h @@ -58,4 +58,30 @@ extern void hl_sparse_matrix_top_k(real* topVal, int beamSize, int numSamples); -#endif /* HL_TOP_K_H_ */ +/** + * @brief Matrix classification error. + * + * @param[out] topVal top k element. + * @param[in] ldv leading dimension of topVal. + * @param[out] topIds top k index. + * @param[in] src input value. + * @param[in] lds leading dimension of src. + * @param[in] dim width of input value. + * @param[in] topkSize size of top k element. + * @param[in] numSamples height of input value. + * @param[in] label ground truth label. + * @param[out] recResult top-k classification error. + * + */ +extern void hl_matrix_classification_error(real* topVal, + int ldv, + int* topIds, + real* src, + int lds, + int dim, + int topkSize, + int numSamples, + int* label, + real* recResult); + +#endif // HL_TOP_K_H_ diff --git a/paddle/cuda/include/stub/hl_matrix_stub.h b/paddle/cuda/include/stub/hl_matrix_stub.h index f4e6461cdcf198637b2c96fee88d1de2766aaf18..127cb7e27983e8ff2c1ff6ef5108b5f8c5bd6ca5 100644 --- a/paddle/cuda/include/stub/hl_matrix_stub.h +++ b/paddle/cuda/include/stub/hl_matrix_stub.h @@ -35,8 +35,16 @@ inline void hl_sequence_softmax_forward(real* A_d, inline void hl_matrix_softmax_derivative( real* grad_d, real* output_d, real* sftmaxSum_d, int dimM, int dimN) {} -inline void hl_matrix_classification_error( - real* A_d, int* B_d, real* C_d, int dimM, int dimN) {} +inline void hl_matrix_classification_error(real* topVal, + int ldv, + int* topIds, + real* src, + int lds, + int dim, + int topkSize, + int numSamples, + int* label, + real* recResult) {} inline void hl_matrix_cross_entropy( real* A_d, real* C_d, int* label_d, int dimM, int dimN) {} diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu index 96c07d9c3b7a37daa9198fd7ea66b7d811600348..9bcc7fb7de44b2211db450fb164655f7947dcad9 100644 --- a/paddle/cuda/src/hl_cuda_matrix.cu +++ b/paddle/cuda/src/hl_cuda_matrix.cu @@ -265,59 +265,6 @@ void hl_matrix_softmax_derivative(real *grad_d, CHECK_SYNC("hl_matrix_softmax_derivative failed"); } -template -__global__ void KeMatrixClassificationError(real* in_A, - int* in_B, - real* out_C, - int dimN) { - __shared__ real max_s[blockSize]; - __shared__ int max_l[blockSize]; - const int tid = threadIdx.x; - const int rowId = blockIdx.x; - - max_s[tid] = -1e30f; - in_A += rowId * dimN; - real tmp; - for (int colId = tid; colId < dimN; colId += blockSize) { - tmp = in_A[colId]; - if (max_s[tid] < tmp) { - max_s[tid] = tmp; - max_l[tid] = colId; - } - } - __syncthreads(); - - for (int stride = blockSize/2; stride > 0; stride = stride/2) { - if (tid < stride) { - if (max_s[tid] < max_s[tid + stride]) { - max_s[tid] = max_s[tid + stride]; - max_l[tid] = max_l[tid + stride]; - } - } - __syncthreads(); - } - __syncthreads(); - - if (tid == 0) { - out_C[rowId] = (max_l[0] == in_B[rowId] ? 0 : 1.0f); - } -} - -void hl_matrix_classification_error(real* A_d, - int* B_d, - real* C_d, - int dimM, - int dimN) { - CHECK_NOTNULL(A_d); - CHECK_NOTNULL(B_d); - CHECK_NOTNULL(C_d); - - // each sample is calculated by one block - KeMatrixClassificationError<1024><<< dimM, 1024, 0, STREAM_DEFAULT >>> - (A_d, B_d, C_d, dimN); - CHECK_SYNC("hl_matrix_classification_error"); -} - __global__ void KeMatrixMultiBinaryCrossEntropy(real* output, real* entropy, int* row, diff --git a/paddle/cuda/src/hl_top_k.cu b/paddle/cuda/src/hl_top_k.cu index f0ef0cc3c51f9e7935dc3c40f630e4d70960802a..4f0bbfcf4e3aa51dd06acf254af65c62098a1df7 100644 --- a/paddle/cuda/src/hl_top_k.cu +++ b/paddle/cuda/src/hl_top_k.cu @@ -384,3 +384,81 @@ void hl_sparse_matrix_top_k(real* topVal, int ldv, CHECK_SYNC("hl_sparse_matrix_top_k failed"); } +/** + * Each block compute one sample. + * In a block: + * 1. every thread get top maxLength value; + * 2. merge to shTopK, block reduce and get max value; + * 3. go to the second setp, until one thread's topK value is null; + * 4. go to the first setp, until get the topK value. + */ +template +__global__ void KeMatrixTopKClassificationError(real* topVal, int ldv, + int * topIds, + real* src, int lds, + int dim, + int beamSize, + int* label, + real* recResult) { + __shared__ Pair shTopK[blockSize]; + __shared__ int maxId[blockSize / 2]; + const int tid = threadIdx.x; + const int warp = threadIdx.x / 32; + src += blockIdx.x * lds; + topVal += blockIdx.x * ldv; + topIds += blockIdx.x * beamSize; + + Pair topK[maxLength]; // NOLINT + int beam = maxLength; + Pair max; + bool isEmpty = false; + bool firstStep = true; + int topkSize = beamSize; + + for (int k = 0; k < maxLength; k++) { + topK[k].set(-HL_FLOAT_MAX, -1); + } + + while (beamSize) { + threadGetTopK + (topK, beam, beamSize, src, firstStep, isEmpty, max, dim, tid); + + shTopK[tid] = topK[0]; + blockReduce + (shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp); + } + + __syncthreads(); + if (tid == 0) { + for (int i = 0; i < topkSize; i++) { + if (*--topIds == label[blockIdx.x]) { + recResult[blockIdx.x] = 0; + break; + } + recResult[blockIdx.x] = 1.0f; + } + } +} + +void hl_matrix_classification_error(real* topVal, int ldv, + int* topIds, + real* src, int lds, + int dim, + int topkSize, + int numSamples, + int* label, + real* recResult) { + CHECK_NOTNULL(topVal); + CHECK_NOTNULL(topIds); + CHECK_NOTNULL(src); + + if (topkSize > dim) topkSize = dim; + + dim3 threads(256, 1); + dim3 grid(numSamples, 1); + KeMatrixTopKClassificationError<5, 256> + <<< grid, threads, 0, STREAM_DEFAULT >>> + (topVal, ldv, topIds, src, lds, dim, topkSize, label, recResult); + + CHECK_SYNC("hl_matrix_top_k classification error failed"); +} diff --git a/paddle/gserver/evaluators/Evaluator.cpp b/paddle/gserver/evaluators/Evaluator.cpp index 6c1c2f62be273cf4855621d28da74befec8f259f..5911a9ec59a5b89702c341a02bfbac1eeb105551 100644 --- a/paddle/gserver/evaluators/Evaluator.cpp +++ b/paddle/gserver/evaluators/Evaluator.cpp @@ -39,6 +39,14 @@ void Evaluator::eval(const NeuralNetwork& nn) { */ class ClassificationErrorEvaluator : public Evaluator { public: + /* + ClassificationErrorEvaluator() : totalScore2_(0) {} + + virtual void start() { + Evaluator::start(); + totalScore2_ = 0; + } */ + virtual void updateSamplesNum(const std::vector& arguments) { if (3 == arguments.size()) { numSamples_ += arguments[2].value->getSum(); @@ -76,9 +84,11 @@ public: 1, /* trans= */ false, useGpu(arguments[0].deviceId)); + errorMat->zeroMem(); + if (label != nullptr) { - errorMat->classificationError(*output, *label); + errorMat->classificationError(*output, *label, config_.top_k()); } else if (dynamic_cast(multiBinaryLabel.get()) || dynamic_cast(multiBinaryLabel.get())) { errorMat->classificationErrorMulti( @@ -94,6 +104,16 @@ public: return errorMat; } + void printStats(std::ostream& os) const { + if (config_.top_k() == 1) { + os << config_.name() << "=" + << (numSamples_ ? totalScore_ / numSamples_ : 0); + } else { + os << " top_" << config_.top_k() + << "_error=" << (numSamples_ ? totalScore_ / numSamples_ : 0); + } + } + virtual real evalImp(std::vector& arguments) { MatrixPtr errorMat = calcError(arguments); return errorMat->getSum(); diff --git a/paddle/gserver/layers/Layer.h b/paddle/gserver/layers/Layer.h index 6dfd48fb96618102b71e9f6de79a348dc7f62647..7c4bea072157aac17787afab184b51c09ff656f2 100644 --- a/paddle/gserver/layers/Layer.h +++ b/paddle/gserver/layers/Layer.h @@ -311,6 +311,7 @@ public: return *output->second; } else { LOG(FATAL) << "No specific output " << str; + return *((Argument*)nullptr); } } } diff --git a/paddle/gserver/layers/SequenceConcatLayer.cpp b/paddle/gserver/layers/SequenceConcatLayer.cpp index 599706eb419ede72dbd6f4c8c74e57f5f9965388..4b24d8f0c852e1bdc887d4ee1465b9ad05d210bb 100644 --- a/paddle/gserver/layers/SequenceConcatLayer.cpp +++ b/paddle/gserver/layers/SequenceConcatLayer.cpp @@ -21,9 +21,11 @@ namespace paddle { /** * A layer for concatenating the first sequence with the second sequence - * following the first - * Input: two sequences each containing some instances + * Input: two sequences each containing the same number of instances + * seq1 = [a1, a2, ..., an] + * seq2 = [b1, b2, ..., bn] * Output: a concatenated sequence of the two input sequences + * out = [a1, b1, a2, b2, ..., an, bn] */ class SequenceConcatLayer : public Layer { diff --git a/paddle/gserver/layers/SequenceReshapeLayer.cpp b/paddle/gserver/layers/SequenceReshapeLayer.cpp index 66f49159087ab9e2c83b1d74e9b4d9bfe4f49e79..433592953b220eda4db4634124a57a2074cef4c0 100644 --- a/paddle/gserver/layers/SequenceReshapeLayer.cpp +++ b/paddle/gserver/layers/SequenceReshapeLayer.cpp @@ -20,9 +20,12 @@ limitations under the License. */ namespace paddle { /** - * A layer for reshaping the sequence - * Input: a sequence - * Output: a sequence + * A layer for reshaping the sequence. Assume the input sequence has + * T instances, the dimension of each instance is M, and the input + * reshape_dim is N, then the output sequence has T*M/N instances, + * the dimension of each instance is N. + * + * Note that T*M/N must be an integer. */ class SequenceReshapeLayer : public Layer { diff --git a/paddle/gserver/tests/test_Evaluator.cpp b/paddle/gserver/tests/test_Evaluator.cpp index 07f486b1f4511ba210256b5a21021e8ca0265eb8..4f5fdbb37ce024e18b8d39c5dda74c69bf82166a 100644 --- a/paddle/gserver/tests/test_Evaluator.cpp +++ b/paddle/gserver/tests/test_Evaluator.cpp @@ -141,6 +141,7 @@ void testEvaluatorAll(TestConfig testConf, TEST(Evaluator, classification_error) { TestConfig config; config.evaluatorConfig.set_type("classification_error"); + config.evaluatorConfig.set_top_k(5); config.inputDefs.push_back({INPUT_DATA, "output", 50}); config.inputDefs.push_back({INPUT_LABEL, "label", 50}); diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 1964b2f8bfaebc49fe3073e03c949a8a9c3e385a..07450bfb0ef709840f7e8253e87c227276529a2a 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -732,6 +732,7 @@ void GpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) { size_t beam = maxVal.getWidth(); CHECK_EQ(maxIds.getSize(), numSamples * beam); CHECK_EQ(maxVal.getHeight(), numSamples); + CHECK_EQ(maxVal.getWidth(), beam); hl_matrix_top_k(maxVal.getData(), maxVal.getStride(), @@ -792,19 +793,32 @@ void GpuMatrix::maxoutBackward(Matrix& a, } /*calulate the error of classification */ -void GpuMatrix::classificationError(Matrix& output, IVector& label) { - auto output_ptr = dynamic_cast(&output); - auto label_ptr = dynamic_cast(&label); - CHECK(output_ptr && label_ptr) << "Invalid argument pointer"; - - CHECK(height_ == output_ptr->height_ && width_ == 1) +void GpuMatrix::classificationError(Matrix& output, + IVector& label, + size_t topkSize) { + auto gpuOutput = dynamic_cast(&output); + auto gpuLabel = dynamic_cast(&label); + size_t numSamples = this->getHeight(); + GpuMatrixPtr gpuTopVal = std::make_shared(numSamples, topkSize); + GpuIVectorPtr gpuTopIds = std::make_shared(numSamples * topkSize); + + CHECK(gpuOutput && gpuLabel) << "Invalid argument pointer"; + CHECK(gpuTopVal && gpuTopIds) << "Allocate GPU memory failed"; + CHECK(gpuLabel->getSize() == numSamples) << "Vector size is not equal"; + CHECK(numSamples == gpuOutput->getHeight() && this->getWidth() == 1) << "Matrix dimensions are not equal"; - hl_matrix_classification_error((real*)output_ptr->data_, - (int*)label_ptr->getData(), - data_, - height_, - output_ptr->width_); + size_t dim = gpuOutput->getWidth(); + hl_matrix_classification_error(gpuTopVal->getData(), + gpuTopVal->getStride(), + gpuTopIds->getData(), + gpuOutput->getData(), + gpuOutput->getStride(), + dim, + topkSize, + numSamples, + gpuLabel->getData(), + this->getData()); } /* copy -log(output[i * width + label]) to this->data[i] */ @@ -3039,7 +3053,7 @@ void CpuMatrix::rowMax(Matrix& max) { max.maxRows(*this); } -/* get beam size of max ids and values */ +/* Get the top k elements of each row of this matrix */ void CpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) { CHECK(isContiguous()); CHECK(!maxIds.useGpu() && !maxVal.useGpu()) << "Matrix type are not equal"; @@ -3047,6 +3061,7 @@ void CpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) { size_t beam = maxVal.getWidth(); CHECK_EQ(maxIds.getSize(), numSamples * beam); CHECK_EQ(maxVal.getHeight(), numSamples); + CHECK_EQ(maxVal.getWidth(), beam); real* a = getData(); int* s = maxIds.getData(); @@ -3198,32 +3213,39 @@ void CpuMatrix::rowNormalizeL1(Matrix& out) { } /* calulate classification error */ -void CpuMatrix::classificationError(Matrix& output, IVector& label) { - CHECK(dynamic_cast(&output)); - CHECK(dynamic_cast(&label)); +void CpuMatrix::classificationError(Matrix& output, + IVector& label, + size_t topkSize) { + size_t numSamples = this->getHeight(); + auto cpuOutput = dynamic_cast(&output); + auto cpuLabel = dynamic_cast(&label); + IVectorPtr cpuTopIds = std::make_shared(numSamples * topkSize); + MatrixPtr cpuTopVal = std::make_shared(numSamples, topkSize); + + CHECK(cpuOutput && cpuLabel) << "Invalid argument pointer"; + CHECK(cpuTopIds && cpuTopVal) << "Allocate cpu memory failed"; + CHECK(cpuLabel->getSize() == numSamples) << "Vector size is not equal"; + CHECK(cpuOutput->getHeight() == numSamples && this->getWidth() == 1) + << "Matrix dimensions are not equal"; - CHECK_EQ(getWidth(), (size_t)1); - size_t numSamples = getHeight(); - CHECK_EQ(label.getSize(), numSamples); - CHECK_EQ(output.getHeight(), numSamples); + // top k matrix classification + cpuOutput->rowMax(*cpuTopIds, *cpuTopVal); - size_t dim = output.getWidth(); - real* out = output.getData(); - int* lbl = label.getData(); - real maxData = 0.0; - int maxIndex = -1; + size_t dim = cpuOutput->getWidth(); + real* result = this->getData(); + int* ids = cpuTopIds->getData(); + int* lbl = cpuLabel->getData(); for (size_t i = 0; i < numSamples; ++i) { CHECK_GE(lbl[i], 0); CHECK_LT((size_t)lbl[i], dim); - maxData = out[i * dim]; - maxIndex = 0; - for (size_t j = 0; j < dim; ++j) { - if (maxData < out[i * dim + j]) { - maxIndex = j; - maxData = out[i * dim + j]; + + for (size_t j = 0; j < topkSize; ++j) { + if (ids[j + i * topkSize] == lbl[i]) { + result[i] = 0; + break; } + result[i] = 1.0f; } - getData()[i] = (maxIndex != lbl[i]); } } diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index ea4bbb86b057b526c5ea294b2cd835aef65de58d..d0ba2e93feabfcc11ac1d261bc40c9c6973a8c29 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -836,8 +836,11 @@ public: * output[i] = 1 if row i is an error. * * output[i] = 0 if row i is correct. + * */ - virtual void classificationError(Matrix& output, IVector& label) { + virtual void classificationError(Matrix& output, + IVector& label, + size_t topkSize = 1) { LOG(FATAL) << "Not implemented"; } @@ -1314,7 +1317,7 @@ public: void check(std::ostream& os, Matrix& refMat, bool printDiff = true); void randomizeUniform(); - void classificationError(Matrix& output, IVector& label); + void classificationError(Matrix& output, IVector& label, size_t topkSize = 1); void convExpand(Matrix& feature, int feaImgHeight, @@ -1739,7 +1742,7 @@ public: void randomizeUniform(); - void classificationError(Matrix& output, IVector& label); + void classificationError(Matrix& output, IVector& label, size_t topkSize = 1); void addByBitCode(size_t numClasses, const IVector& codes, const Matrix& vec); diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 6caaea443c1df756bfeb775154e8a90400cc3211..08b64c1bb6f5d359a2d2164e723a76c5360168ee 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -764,7 +764,7 @@ TEST(Matrix, paramReluBackwardDiff) { } } -void testClassificationError(int numSamples, int dim) { +void testClassificationError(int numSamples, int dim, int topkSize) { MatrixPtr cpuError = std::make_shared(numSamples, 1); MatrixPtr gpuError = std::make_shared(numSamples, 1); MatrixPtr cpuOutput = std::make_shared(numSamples, dim); @@ -777,17 +777,22 @@ void testClassificationError(int numSamples, int dim) { gpuOutput->copyFrom(*cpuOutput); gpuLabel->copyFrom(*cpuLabel); - cpuError->classificationError(*cpuOutput, *cpuLabel); - gpuError->classificationError(*gpuOutput, *gpuLabel); + cpuError->classificationError(*cpuOutput, *cpuLabel, topkSize); + gpuError->classificationError(*gpuOutput, *gpuLabel, topkSize); TensorCheckEqual(*cpuError, *gpuError); } TEST(Matrix, classificationError) { - for (auto numSamples : {1, 10, 100, 1000, 70000}) { - for (auto dim : {1, 10, 100, 1000}) { - VLOG(3) << " numSamples=" << numSamples << " dim=" << dim; - testClassificationError(numSamples, dim); + for (auto numSamples : {1, 5, 31, 90, 150, 300}) { + for (auto dim : + {1, 5, 8, 10, 15, 64, 80, 120, 256, 300, 1280, 5120, 50000}) { + for (auto topkSize : {1, 5, 10, 20, 40, (int)rand() % dim + 1}) { + if (topkSize > dim) continue; + VLOG(3) << " sample= " << numSamples << " topkSize= " << topkSize + << " dim= " << dim; + testClassificationError(numSamples, dim, topkSize); + } } } } diff --git a/paddle/parameter/Parameter.cpp b/paddle/parameter/Parameter.cpp index 29d6e20dc16968cdda3e79b66b0c81aaaf303ef4..1ccded818796798105a889df978618688b56ed36 100644 --- a/paddle/parameter/Parameter.cpp +++ b/paddle/parameter/Parameter.cpp @@ -375,10 +375,6 @@ bool Parameter::load(const std::string& filename) { std::ifstream fs(filename, std::ios_base::binary); if (!fs) { LOG(INFO) << "missing parameters [" << filename << "] while loading model."; - if (isStatic()) { - LOG(FATAL) << getName() << " is static but missing, not allowed."; - return false; - } if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) { LOG(FATAL) << getName() << " missing, not allowed."; return false; diff --git a/paddle/setup.py.in b/paddle/setup.py.in index c79666bc81b6f343f166422697cd3901ce8ff441..38621af065913c9edd44958e9fb767c983c00dbb 100644 --- a/paddle/setup.py.in +++ b/paddle/setup.py.in @@ -55,6 +55,9 @@ elif is_osx == True: include_dirs = [np.get_include(), "../"] # include numpy and paddle. +os.environ["CC"] = "@CMAKE_C_COMPILER@" +os.environ["CXX"] = "@CMAKE_CXX_COMPILER@" + setup(name="py_paddle", version="@PADDLE_VERSION@", ext_modules=[ diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index be4634d5103c0f219389823d132b1977963017e1..65d5d50277b665e7c355202d6e8043f656ae92f1 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -475,6 +475,10 @@ message EvaluatorConfig { // Used by ChunkEvaluator // chunk of these types are not counted repeated int32 excluded_chunk_types = 12; + + // Used by ClassificationErrorEvaluator + // top # classification error + optional int32 top_k = 13 [default = 1]; } message LinkConfig { diff --git a/python/paddle/reader/decorator.py b/python/paddle/reader/decorator.py index f0ddb0ff812b15ede21e6965c7c8857f12716fa0..5fc799e61dab954b9993321c4e816f2f5abce448 100644 --- a/python/paddle/reader/decorator.py +++ b/python/paddle/reader/decorator.py @@ -12,18 +12,135 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ['buffered'] +__all__ = ['buffered', 'compose', 'chain', 'shuffle', 'ComposeNotAligned'] from Queue import Queue from threading import Thread +import itertools +import random + + +def shuffle(reader, buf_size): + """Creates a data reader whose data output is suffled. + + Output from the iterator that created by original reader will be + buffered into shuffle buffer, and then shuffled. The size of shuffle buffer + is determined by argument buf_size. + + Args: + reader: the original reader whose output will be + shuffled. + buf_size: shuffle buffer size. + + Returns: + the new reader whose output is shuffled. + """ + + def data_reader(): + buf = [] + for e in reader(): + buf.append(e) + if len(buf) >= buf_size: + random.shuffle(buf) + for b in buf: + yield b + buf = [] + + if len(buf) > 0: + random.shuffle(buf) + for b in buf: + yield b + + return data_reader + + +def chain(*readers): + """Creates a data reader whose output is the outputs of input data + readers chained together. + + If input readers output following data entries: + [0, 0, 0] + [1, 1, 1] + [2, 2, 2] + The chained reader will output: + [0, 0, 0, 1, 1, 1, 2, 2, 2] + + Args: + readers: input readers. + + Returns: + the new data reader. + """ + + def reader(): + rs = [] + for r in readers: + rs.append(r()) + + for e in itertools.chain(*rs): + yield e + + return reader + + +class ComposeNotAligned(ValueError): + pass + + +def compose(*readers, **kwargs): + """Creates a data reader whose output is the combination of input readers. + + If input readers output following data entries: + (1, 2) 3 (4, 5) + The composed reader will output: + (1, 2, 3, 4, 5) + + Args: + *readers: readers that will be composed together. + check_alignment: If True, will check if input readers are aligned + correctly. If False, will not check alignment and trailing outputs + will be discarded. Defaults to True. + + Returns: + the new data reader. + + Raises: + ComposeNotAligned: outputs of readers are not aligned. + Will not raise when check_alignment is set to False. + """ + check_alignment = kwargs.pop('check_alignment', True) + + def make_tuple(x): + if isinstance(x, tuple): + return x + else: + return (x, ) + + def reader(): + rs = [] + for r in readers: + rs.append(r()) + if not check_alignment: + for outputs in itertools.izip(*rs): + yield sum(map(make_tuple, outputs), ()) + else: + for outputs in itertools.izip_longest(*rs): + for o in outputs: + if o is None: + # None will be not be present if compose is aligned + raise ComposeNotAligned( + "outputs of readers are not aligned.") + yield sum(map(make_tuple, outputs), ()) + + return reader def buffered(reader, size): """Creates a buffered data reader. - The buffered data reader will read and save data entries into a buffer. - Reading from the buffered data reader will proceed as long as the buffer - is not empty. + The buffered data reader will read and save data entries into a + buffer. Reading from the buffered data reader will proceed as long + as the buffer is not empty. Args: reader: the data reader to read from. @@ -43,7 +160,7 @@ def buffered(reader, size): q.put(d) q.put(end) - def create_reader(): + def data_reader(): r = reader() q = Queue(maxsize=size) t = Thread( @@ -57,4 +174,4 @@ def buffered(reader, size): yield e e = q.get() - return create_reader + return data_reader diff --git a/python/paddle/reader/tests/decorator_test.py b/python/paddle/reader/tests/decorator_test.py index 879d1d9c1d0e0650d347b5c44e36771a0c15390e..46eec44158cee5f8c70a0e6197e856e485a7d40c 100644 --- a/python/paddle/reader/tests/decorator_test.py +++ b/python/paddle/reader/tests/decorator_test.py @@ -16,16 +16,20 @@ import paddle.reader import time -def reader_10(dur): - for i in range(10): - time.sleep(dur) - yield i +def reader_creator_10(dur): + def reader(): + for i in range(10): + # this invocation helps testing paddle.reader.buffer + time.sleep(dur) + yield i + + return reader class TestBuffered(unittest.TestCase): def test_read(self): for size in range(20): - b = paddle.reader.buffered(lambda: reader_10(0), size) + b = paddle.reader.buffered(reader_creator_10(0), size) c = 0 for i in b(): self.assertEqual(i, c) @@ -34,7 +38,7 @@ class TestBuffered(unittest.TestCase): def test_buffering(self): # read have 30ms delay. - b = paddle.reader.buffered(lambda: reader_10(0.03), 10) + b = paddle.reader.buffered(reader_creator_10(0.03), 10) last_time = time.time() for idx, i in enumerate(b()): elapsed_time = time.time() - last_time @@ -42,9 +46,63 @@ class TestBuffered(unittest.TestCase): time.sleep(0.3) else: # read time should be short, meaning already buffered. - self.assertLess(elapsed_time, 0.01) + self.assertLess(elapsed_time, 0.05) last_time = time.time() +class TestCompose(unittest.TestCase): + def test_compse(self): + reader = paddle.reader.compose( + reader_creator_10(0), reader_creator_10(0)) + for idx, e in enumerate(reader()): + self.assertEqual(e, (idx, idx)) + + def test_compose_not_aligned(self): + total = 0 + reader = paddle.reader.compose( + paddle.reader.chain(reader_creator_10(0), reader_creator_10(0)), + reader_creator_10(0)) + with self.assertRaises(paddle.reader.ComposeNotAligned): + for e in reader(): + total += 1 + # expecting 10, not 20 + self.assertEqual(total, 10) + + def test_compose_not_aligned_no_check(self): + total = 0 + reader = paddle.reader.compose( + paddle.reader.chain(reader_creator_10(0), reader_creator_10(0)), + reader_creator_10(0), + check_alignment=False) + for e in reader(): + total += 1 + # expecting 10, not 20 + self.assertEqual(total, 10) + + +class TestChain(unittest.TestCase): + def test_chain(self): + c = paddle.reader.chain(reader_creator_10(0), reader_creator_10(0)) + idx = 0 + for e in c(): + self.assertEqual(e, idx % 10) + idx += 1 + self.assertEqual(idx, 20) + + +class TestShuffle(unittest.TestCase): + def test_shuffle(self): + case = [(0, True), (1, True), (10, False), (100, False)] + a = reader_creator_10(0) + for size, checkEq in case: + s = paddle.reader.shuffle(a, size) + total = 0 + for idx, e in enumerate(s()): + if checkEq: + self.assertEqual(idx, e) + total += 1 + self.assertEqual(total, 10) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index d403a6029a3e9d4c41b80a2206397dcdfe780026..da937152ee0ce788309690c7b718943bb21b5a76 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1253,6 +1253,7 @@ def Evaluator( dict_file=None, result_file=None, num_results=None, + top_k=None, delimited=None, excluded_chunk_types=None, ): evaluator = g_config.model_config.evaluators.add() @@ -1280,6 +1281,8 @@ def Evaluator( evaluator.result_file = result_file if num_results is not None: evaluator.num_results = num_results + if top_k is not None: + evaluator.top_k = top_k if delimited is not None: evaluator.delimited = delimited diff --git a/python/paddle/trainer_config_helpers/evaluators.py b/python/paddle/trainer_config_helpers/evaluators.py index bd247ea9af9d8dfb2d476cdc62638bd65c11add5..567521ee9dbadb7a2502cfb9972ef0940e1e410a 100644 --- a/python/paddle/trainer_config_helpers/evaluators.py +++ b/python/paddle/trainer_config_helpers/evaluators.py @@ -71,6 +71,7 @@ def evaluator_base( result_file=None, num_results=None, delimited=None, + top_k=None, excluded_chunk_types=None, ): """ Evaluator will evaluate the network status while training/testing. @@ -104,12 +105,15 @@ def evaluator_base( :param weight: An input layer which is a weight for each sample. Each evaluator may calculate differently to use this weight. :type weight: LayerOutput. + :param top_k: number k in top-k error rate + :type top_k: int """ # inputs type assertions. assert classification_threshold is None or isinstance( classification_threshold, float) assert positive_label is None or isinstance(positive_label, int) assert num_results is None or isinstance(num_results, int) + assert top_k is None or isinstance(top_k, int) if not isinstance(input, list): input = [input] @@ -130,6 +134,8 @@ def evaluator_base( dict_file=dict_file, result_file=result_file, delimited=delimited, + num_results=num_results, + top_k=top_k, excluded_chunk_types=excluded_chunk_types, ) @@ -139,6 +145,7 @@ def classification_error_evaluator(input, label, name=None, weight=None, + top_k=None, threshold=None): """ Classification Error Evaluator. It will print error rate for classification. @@ -167,6 +174,8 @@ def classification_error_evaluator(input, then means not set weight. The larger weight it is, the more important this sample is. :type weight: LayerOutput + :param top_k: number k in top-k error rate + :type top_k: int :param threshold: The classification threshold. :type threshold: float :return: None. @@ -178,6 +187,7 @@ def classification_error_evaluator(input, input=input, label=label, weight=weight, + top_k=top_k, classification_threshold=threshold, ) diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 1fdc4c462363712e8b5b4dee10d0aaa26f4deffa..00aef80691fba05be543beadf22acde7d28c5e8e 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -37,6 +37,7 @@ __all__ = [ "dotmul_projection", "dotmul_operator", "repeat_layer", + "seq_reshape_layer", "table_projection", "mixed_layer", "data_layer", @@ -59,6 +60,7 @@ __all__ = [ 'img_cmrnorm_layer', 'addto_layer', 'concat_layer', + 'seq_concat_layer', 'lstm_step_layer', 'recurrent_group', 'memory', @@ -124,6 +126,7 @@ class LayerType(object): GRUMEMORY = "gated_recurrent" SEQUENCE_LAST_INSTANCE = "seqlastins" SEQUENCE_FIRST_INSTANCE = "seqfirstins" + SEQUENCE_RESHAPE = "seqreshape" POOLING_MAX = "max" POOLING_AVG = 'average' FC_LAYER = "fc" @@ -144,6 +147,7 @@ class LayerType(object): CONCAT_LAYER = 'concat' CONCAT_PROJ_LAYER = 'concat2' + SEQUENCE_CONCAT_LAYER = 'seqconcat' LSTM_STEP_LAYER = 'lstm_step' GRU_STEP_LAYER = 'gru_step' @@ -1448,6 +1452,61 @@ def repeat_layer(input, num_repeats, name=None, layer_attr=None): parents=[input]) +@wrap_name_default("seqreshape") +@wrap_act_default(act=IdentityActivation()) +@wrap_bias_attr_default(has_bias=False) +@layer_support() +def seq_reshape_layer(input, + reshape_size, + act=None, + name=None, + layer_attr=None, + bias_attr=None): + """ + A layer for reshaping the sequence. Assume the input sequence has T instances, + the dimension of each instance is M, and the input reshape_size is N, then the + output sequence has T*M/N instances, the dimension of each instance is N. + + Note that T*M/N must be an integer. + + The example usage is: + + .. code-block:: python + + reshape = seq_reshape_layer(input=layer, reshape_size=4) + + :param input: Input layer. + :type input: LayerOutput + :param reshape_size: the size of reshaped sequence. + :type reshape_size: int + :param name: Layer name. + :type name: basestring + :param act: Activation type. + :type act: BaseActivation + :param layer_attr: extra layer attributes. + :type layer_attr: ExtraLayerAttribute. + :param bias_attr: The Bias Attribute. If no bias, then pass False or + something not type of ParameterAttribute. None will get a + default Bias. + :type bias_attr: ParameterAttribute or None or bool + :return: LayerOutput object. + :rtype: LayerOutput + """ + + Layer( + inputs=[input.name], + name=name, + size=reshape_size, + type=LayerType.SEQUENCE_RESHAPE, + bias=ParamAttr.to_bias(bias_attr), + **ExtraAttr.to_kwargs(layer_attr)) + return LayerOutput( + name=name, + size=reshape_size, + layer_type=LayerType.SEQUENCE_RESHAPE, + parents=[input]) + + @wrap_name_default() @layer_support() def interpolation_layer(input, weight, name=None, layer_attr=None): @@ -2570,6 +2629,63 @@ def concat_layer(input, act=None, name=None, layer_attr=None, bias_attr=None): size=sz) +@wrap_name_default("seqconcat") +@wrap_act_default(act=IdentityActivation()) +@wrap_bias_attr_default(has_bias=False) +@layer_support() +def seq_concat_layer(a, b, act=None, name=None, layer_attr=None, + bias_attr=None): + """ + Concat sequence a with sequence b. + + Inputs: + - a = [a1, a2, ..., an] + - b = [b1, b2, ..., bn] + - Note that the length of a and b should be the same. + + Output: [a1, b1, a2, b2, ..., an, bn] + + The example usage is: + + .. code-block:: python + + concat = seq_concat_layer(a=layer1, b=layer2) + + :param name: Layer name. + :type name: basestring + :param a: input sequence layer + :type a: LayerOutput + :param b: input sequence layer + :type b: LayerOutput + :param act: Activation type. + :type act: BaseActivation + :param layer_attr: Extra Layer Attribute. + :type layer_attr: ExtraLayerAttribute + :param bias_attr: The Bias Attribute. If no bias, then pass False or + something not type of ParameterAttribute. None will get a + default Bias. + :type bias_attr: ParameterAttribute or None or bool + :return: LayerOutput object. + :rtype: LayerOutput + """ + assert isinstance(a, LayerOutput) and isinstance(b, LayerOutput) + assert a.size == b.size + Layer( + name=name, + type=LayerType.SEQUENCE_CONCAT_LAYER, + inputs=[a.name, b.name], + active_type=act.name, + bias=ParamAttr.to_bias(bias_attr), + **ExtraLayerAttribute.to_kwargs(layer_attr)) + + return LayerOutput( + name, + layer_type=LayerType.SEQUENCE_CONCAT_LAYER, + parents=[a, b], + activation=act, + size=a.size) + + def memory(name, size, is_seq=False, @@ -2754,8 +2870,8 @@ def gru_step_layer(input, :param name: :param gate_act: :param bias_attr: - :param param_attr: the parameter_attribute for transforming the output_mem - from previous step. + :param param_attr: the parameter_attribute for transforming the output_mem + from previous step. :param layer_attr: :return: LayerOutput object. :rtype: LayerOutput @@ -2766,10 +2882,10 @@ def gru_step_layer(input, Layer( name=name, type=LayerType.GRU_STEP_LAYER, - # The parameter here is for transforming the output_mem. The input has - # already been transformed outside this module so it does not need - # parameter associated with it. - # The parameter here is instead grouped with input is due to + # The parameter here is for transforming the output_mem. The input has + # already been transformed outside this module so it does not need + # parameter associated with it. + # The parameter here is instead grouped with input is due to # backward model compatibility. inputs=[Input(input.name, **param_attr.attr), output_mem.name], bias=ParamAttr.to_bias(bias_attr), @@ -3420,6 +3536,7 @@ def classification_cost(input, label, weight=None, name=None, + top_k=None, evaluator=classification_error_evaluator, layer_attr=None): """ @@ -3434,6 +3551,8 @@ def classification_cost(input, :param weight: The weight affects the cost, namely the scale of cost. It is an optional argument. :type weight: LayerOutput + :param top_k: number k in top-k error rate + :type top_k: int :param evaluator: Evaluator method. :param layer_attr: layer's extra attribute. :type layer_attr: ExtraLayerAttribute @@ -3461,7 +3580,7 @@ def classification_cost(input, assert isinstance(e.for_classification, bool) assert e.for_classification - e(name=e.__name__, input=input, label=label, weight=weight) + e(name=e.__name__, input=input, label=label, weight=weight, top_k=top_k) if not isinstance(evaluator, collections.Sequence): evaluator = [evaluator] diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh index ea46b557a26ce638742facda3eb6aa2feb4b2563..c9178e3c6a46a2d663ec368569e529e780b76a6f 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh @@ -4,6 +4,7 @@ test_sequence_pooling test_lstmemory_layer test_grumemory_layer last_first_seq test_expand_layer test_ntm_layers test_hsigmoid img_layers img_trans_layers util_layers simple_rnn_layers unused_layers test_cost_layers test_rnn_group shared_fc shared_lstm shared_gru test_cost_layers_with_weight -test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops) +test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops +test_seq_concat_reshape) export whole_configs=(test_split_datasource) diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_seq_concat_reshape.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_seq_concat_reshape.protostr new file mode 100644 index 0000000000000000000000000000000000000000..91284b4fb32fcfdbf6b9e7384ffe080574b78821 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_seq_concat_reshape.protostr @@ -0,0 +1,51 @@ +type: "nn" +layers { + name: "data1" + type: "data" + size: 30 + active_type: "" +} +layers { + name: "data2" + type: "data" + size: 30 + active_type: "" +} +layers { + name: "__seqconcat_0__" + type: "seqconcat" + size: 30 + active_type: "" + inputs { + input_layer_name: "data1" + } + inputs { + input_layer_name: "data2" + } +} +layers { + name: "__seqreshape_0__" + type: "seqreshape" + size: 5 + active_type: "linear" + inputs { + input_layer_name: "data1" + } +} +input_layer_names: "data1" +input_layer_names: "data2" +output_layer_names: "__seqconcat_0__" +output_layer_names: "__seqreshape_0__" +sub_models { + name: "root" + layer_names: "data1" + layer_names: "data2" + layer_names: "__seqconcat_0__" + layer_names: "__seqreshape_0__" + input_layer_names: "data1" + input_layer_names: "data2" + output_layer_names: "__seqconcat_0__" + output_layer_names: "__seqreshape_0__" + is_recurrent_layer_group: false +} + diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_seq_concat_reshape.py b/python/paddle/trainer_config_helpers/tests/configs/test_seq_concat_reshape.py new file mode 100644 index 0000000000000000000000000000000000000000..5c161ba805fb301e8feb8702ad61a8341df40e3f --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_seq_concat_reshape.py @@ -0,0 +1,12 @@ +from paddle.trainer_config_helpers import * + +settings(batch_size=1000, learning_rate=1e-5) + +din1 = data_layer(name='data1', size=30) +din2 = data_layer(name='data2', size=30) + +opts = [] +opts.append(seq_concat_layer(a=din1, b=din2)) +opts.append(seq_reshape_layer(input=din1, reshape_size=5)) + +outputs(opts) diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py index bc064a21ae150256752156f7ace56438321d5ba7..30d0b2a398bd0e39895daf9b1421ec736ab8da83 100644 --- a/python/paddle/v2/__init__.py +++ b/python/paddle/v2/__init__.py @@ -17,10 +17,12 @@ import activation import parameters import trainer import event +import data_type import py_paddle.swig_paddle as api __all__ = [ - 'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer', 'event' + 'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer', + 'event', 'data_type' ] diff --git a/python/paddle/v2/data_type.py b/python/paddle/v2/data_type.py new file mode 100644 index 0000000000000000000000000000000000000000..5b01ba4cd4866cf7b355fc0a6a667409cf9c4419 --- /dev/null +++ b/python/paddle/v2/data_type.py @@ -0,0 +1,22 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer.PyDataProvider2 import \ + InputType, dense_vector, sparse_binary_vector,\ + sparse_vector, integer_value + +__all__ = [ + 'InputType', 'dense_vector', 'sparse_binary_vector', 'sparse_vector', + 'integer_value' +] diff --git a/python/paddle/v2/layer.py b/python/paddle/v2/layer.py index 0ce4ecd569aa1dd9ad27c65775d235b969a52905..cd6dd5110a44b2bb9f5bf3f0a0d789ec9a2290fc 100644 --- a/python/paddle/v2/layer.py +++ b/python/paddle/v2/layer.py @@ -66,15 +66,20 @@ Also, the creation of a protobuf message is hidden in the invocation of paddle.v2.parameters.create, no longer exposed to users. """ +import collections + import paddle.trainer_config_helpers as conf_helps from paddle.trainer_config_helpers.config_parser_utils import \ parse_network_config as __parse__ from paddle.trainer_config_helpers.default_decorators import wrap_name_default -import collections + +import data_type __all__ = [ 'parse_network', 'data', 'fc', 'max_id', 'classification_cost', - 'cross_entropy_cost' + 'cross_entropy_cost', 'cross_entropy_with_selfnorm_cost', 'regression_cost', + 'multi_binary_label_cross_entropy_cost', 'rank_cost', 'lambda_cost', + 'sum_cost', 'huber_cost' ] @@ -134,7 +139,8 @@ def __convert_to_v2__(method_name, name_prefix, parent_names): parent_layers = dict() other_kwargs = dict() for pname in parent_names: - parent_layers[pname] = kwargs[pname] + if kwargs.has_key(pname): + parent_layers[pname] = kwargs[pname] for key in kwargs.keys(): if key not in parent_names: @@ -157,30 +163,90 @@ def __convert_to_v2__(method_name, name_prefix, parent_names): return V2LayerImpl -data = __convert_to_v2__('data_layer', None, []) +""" +Some layer may need some special config, and can not use __convert_to_v2__ to convert. +So we also need to implement some special LayerV2. +""" + + +class DataLayerV2(Layer): + def __init__(self, name, type, **kwargs): + assert isinstance(type, data_type.InputType) + + self.type = type + self.__method_name__ = 'data_layer' + self.__kwargs__ = kwargs + + super(DataLayerV2, self).__init__(name=name, parent_layers=dict()) + + def to_proto_impl(self, **kwargs): + args = dict() + args['size'] = self.type.dim + for each in kwargs: + args[each] = kwargs[each] + for each in self.__kwargs__: + args[each] = self.__kwargs__[each] + return getattr(conf_helps, self.__method_name__)(name=self.name, **args) + + +data = DataLayerV2 fc = __convert_to_v2__('fc_layer', name_prefix='fc', parent_names=['input']) max_id = __convert_to_v2__( - 'maxid_layer', name_prefix='maxid_layer', parent_names=['input']) + 'maxid_layer', name_prefix='maxid', parent_names=['input']) classification_cost = __convert_to_v2__( 'classification_cost', name_prefix='classification_cost', - parent_names=['input', 'label']) + parent_names=['input', 'label', 'weight']) +regression_cost = __convert_to_v2__( + 'regression_cost', + name_prefix='regression_cost', + parent_names=['input', 'label', 'weight']) cross_entropy_cost = __convert_to_v2__( 'cross_entropy', name_prefix='cross_entropy', parent_names=['input', 'label']) +cross_entropy_with_selfnorm_cost = __convert_to_v2__( + 'cross_entropy_with_selfnorm', + name_prefix='cross_entropy_with_selfnorm', + parent_names=['input', 'label']) +multi_binary_label_cross_entropy_cost = __convert_to_v2__( + 'multi_binary_label_cross_entropy', + name_prefix='multi_binary_label_cross_entropy', + parent_names=['input', 'label']) +rank_cost = __convert_to_v2__( + 'rank_cost', + name_prefix='rank_cost', + parent_names=['left', 'right', 'label', 'weight']) +lambda_cost = __convert_to_v2__( + 'lambda_cost', name_prefix='lambda_cost', parent_names=['input', 'score']) +sum_cost = __convert_to_v2__( + 'sum_cost', name_prefix='sum_cost', parent_names=['input']) +huber_cost = __convert_to_v2__( + 'huber_cost', name_prefix='huber_cost', parent_names=['input', 'label']) if __name__ == '__main__': - pixel = data(name='pixel', size=784) - label = data(name='label', size=10) + pixel = data(name='pixel', type=data_type.dense_vector(784)) + label = data(name='label', type=data_type.integer_value(10)) + weight = data(name='weight', type=data_type.dense_vector(10)) + score = data(name='score', type=data_type.dense_vector(1)) + hidden = fc(input=pixel, size=100, act=conf_helps.SigmoidActivation()) inference = fc(input=hidden, size=10, act=conf_helps.SoftmaxActivation()) maxid = max_id(input=inference) cost1 = classification_cost(input=inference, label=label) - cost2 = cross_entropy_cost(input=inference, label=label) + cost2 = classification_cost(input=inference, label=label, weight=weight) + cost3 = cross_entropy_cost(input=inference, label=label) + cost4 = cross_entropy_with_selfnorm_cost(input=inference, label=label) + cost5 = regression_cost(input=inference, label=label) + cost6 = regression_cost(input=inference, label=label, weight=weight) + cost7 = multi_binary_label_cross_entropy_cost(input=inference, label=label) + cost8 = rank_cost(left=score, right=score, label=score) + cost9 = lambda_cost(input=inference, score=score) + cost10 = sum_cost(input=inference) + cost11 = huber_cost(input=score, label=label) - print parse_network(cost1) - print parse_network(cost2) print parse_network(cost1, cost2) - print parse_network(cost2) + print parse_network(cost3, cost4) + print parse_network(cost5, cost6) + print parse_network(cost7, cost8, cost9, cost10, cost11) print parse_network(inference, maxid)