提交 cbb53219 编写于 作者: Q qiaolongfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into rnn

# Python Data Reader Design Doc # Python Data Reader Design Doc
Paddle reads data from data reader during training. It will be passed into `paddle.train` as a parameter. At training and testing time, PaddlePaddle programs need to read data. To ease the users' work to write data reading code, we define that
- A *reader* is a function that reads data (from file, network, random number generator, etc) and yields data items.
- A *reader creator* is a function that returns a reader function.
- A *reader* decorator is a function, which accepts one or more readers, and returns a reader.
and provide frequently used reader creators and reader decorators.
## Data Reader Interface ## Data Reader Interface
Data reader is a function with no parameter that creates a iterable (anything can be used in `for x in iterable`): Indeed, *data reader* doesn't have to be a function that reads and yields data items. It can be any function with no parameter that creates a iterable (anything can be used in `for x in iterable`):
``` ```
iterable = data_reader() iterable = data_reader()
``` ```
Element produced for the iterable should be a **single** entry of data, **not** a mini batch. That entry of data could be a single item, or a tuple of items. Item should be of [supported type](http://www.paddlepaddle.org/doc/ui/data_provider/pydataprovider2.html?highlight=dense_vector#input-types) (e.g., numpy 1d array of float32, int, list of int) Element produced from the iterable should be a **single** entry of data, **not** a mini batch. That entry of data could be a single item, or a tuple of items. Item should be of [supported type](http://www.paddlepaddle.org/doc/ui/data_provider/pydataprovider2.html?highlight=dense_vector#input-types) (e.g., numpy 1d array of float32, int, list of int)
An example implementation for single item data reader: An example implementation for single item data reader creator:
```python ```python
def data_reader_fake_image(): def reader_creator_random_image(width, height):
while True: def reader():
yield numpy.random.uniform(-1, 1, size=20*20) while True:
yield numpy.random.uniform(-1, 1, size=width*height)
return reader
``` ```
An example implementation for multiple item data reader: An example implementation for multiple item data reader creator:
```python ```python
def data_reader_fake_image_and_label(): def reader_creator_random_imageand_label(widht, height, label):
while True: def reader():
yield numpy.random.uniform(-1, 1, size=20*20), False while True:
yield numpy.random.uniform(-1, 1, size=width*height), label
return reader
``` ```
## Usage ## Usage
...@@ -41,9 +51,9 @@ label_layer = paddle.layer.data("label", ...) ...@@ -41,9 +51,9 @@ label_layer = paddle.layer.data("label", ...)
paddle.train(paddle.dataset.mnist, {"image":0, "label":1}, 128, 10, ...) paddle.train(paddle.dataset.mnist, {"image":0, "label":1}, 128, 10, ...)
``` ```
## Data Reader Decorators ## Data Reader Decorator
Data reader decorators takes a single or multiple data reader, returns a new data reader. It is similar to a [python decorator](https://wiki.python.org/moin/PythonDecorators), but it does not use `@` syntax. *Data reader decorator* takes a single or multiple data reader, returns a new data reader. It is similar to a [python decorator](https://wiki.python.org/moin/PythonDecorators), but it does not use `@` syntax.
Since we have a strict interface for data readers (no parameter, return a single data item). Data reader can be used flexiable via data reader decorators. Following are a few examples: Since we have a strict interface for data readers (no parameter, return a single data item). Data reader can be used flexiable via data reader decorators. Following are a few examples:
...@@ -61,23 +71,27 @@ buffered_reader = paddle.reader.buffered(paddle.dataset.mnist, 100) ...@@ -61,23 +71,27 @@ buffered_reader = paddle.reader.buffered(paddle.dataset.mnist, 100)
### Compose Multiple Data Readers ### Compose Multiple Data Readers
For example, we want to use a source of real images (reusing mnist dataset), and a source of fake images as input for [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661). For example, we want to use a source of real images (reusing mnist dataset), and a source of random images as input for [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661).
We can do: We can do:
```python ```python
def data_reader_fake_image(): def reader_creator_random_image(width, height):
while True: def reader():
yield numpy.random.uniform(-1, 1, size=20*20) while True:
yield numpy.random.uniform(-1, 1, size=width*height)
def data_reader_bool(t): return reader
while True:
yield t def reader_creator_bool(t):
def reader:
true_reader = lambda : data_reader_bool(True) while True:
false_reader = lambda : data_reader_bool(False) yield t
return reader
reader = paddle.reader.combine(paddle.dataset.mnist, data_reader_fake_image, true_reader, false_reader)
true_reader = reader_creator_bool(True)
false_reader = reader_creator_bool(False)
reader = paddle.reader.compose(paddle.dataset.mnist, data_reader_creator_random_image(20, 20), true_reader, false_reader)
# Skipped 1 because paddle.dataset.mnist produces two items per data entry. # Skipped 1 because paddle.dataset.mnist produces two items per data entry.
# And we don't care second item at this time. # And we don't care second item at this time.
paddle.train(reader, {"true_image":0, "fake_image": 2, "true_label": 3, "false_label": 4}, ...) paddle.train(reader, {"true_image":0, "fake_image": 2, "true_label": 3, "false_label": 4}, ...)
...@@ -98,29 +112,31 @@ reader = paddle.reader.shuffle(paddle.dataset.mnist, 512) ...@@ -98,29 +112,31 @@ reader = paddle.reader.shuffle(paddle.dataset.mnist, 512)
If a mini batch is returned, data reader need to take care of batch size. But batch size is a concept for training, it makes more sense for user to specify batch size as a parameter for `train`. If a mini batch is returned, data reader need to take care of batch size. But batch size is a concept for training, it makes more sense for user to specify batch size as a parameter for `train`.
Practically, always return a single entry make reusing existing data reader much easier (e.g., if existing data reader return not a single entry but 3 entries, training code will be more complex because it need to handle cases like batch size 2). Practically, always return a single entry make reusing existing data readers much easier (e.g., if existing reader return not a single entry but 3 entries, training code will be more complex because it need to handle cases like batch size 2).
### Why use a dictionary but not a list to provide mapping? ### Why use a dictionary but not a list to provide mapping?
We decided to use dictionary (`{"image":0, "label":1}`) instead of list (`["image", "label"]`) is because that user can easily resue item (e.g., using `{"image_a":0, "image_b":0, "label":1}`) or skip item (e.g., using `{"image_a":0, "label":2}`). We decided to use dictionary (`{"image":0, "label":1}`) instead of list (`["image", "label"]`) is because that user can easily resue item (e.g., using `{"image_a":0, "image_b":0, "label":1}`) or skip item (e.g., using `{"image_a":0, "label":2}`).
### How to create custom data reader ### How to create custom data reader creator
```python ```python
def image_reader(image_path, label_path, n): def image_reader_creator(image_path, label_path, n):
f = open(image_path) def reader():
l = open(label_path) f = open(image_path)
images = numpy.fromfile( l = open(label_path)
f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28)).astype('float32') images = numpy.fromfile(
images = images / 255.0 * 2.0 - 1.0 f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28)).astype('float32')
labels = numpy.fromfile(l, 'ubyte', count=n).astype("int") images = images / 255.0 * 2.0 - 1.0
for i in xrange(n): labels = numpy.fromfile(l, 'ubyte', count=n).astype("int")
yield images[i, :], labels[i] # a single entry of data is created each time for i in xrange(n):
f.close() yield images[i, :], labels[i] # a single entry of data is created each time
l.close() f.close()
l.close()
# use python lambda to change image_reader into a function with no parameters. return reader
reader = lambda : image_reader("/path/to/image_file", "/path/to/label_file", 1024)
# images_reader_creator creates a reader
reader = image_reader_creator("/path/to/image_file", "/path/to/label_file", 1024)
paddle.train(reader, {"image":0, "label":1}, ...) paddle.train(reader, {"image":0, "label":1}, ...)
``` ```
...@@ -129,7 +145,7 @@ paddle.train(reader, {"image":0, "label":1}, ...) ...@@ -129,7 +145,7 @@ paddle.train(reader, {"image":0, "label":1}, ...)
An example implementation of paddle.train could be: An example implementation of paddle.train could be:
```python ```python
def minibatch_decorater(reader, minibatch_size): def make_minibatch(reader, minibatch_size):
def ret(): def ret():
r = reader() r = reader()
buf = [r.next() for x in xrange(minibatch_size)] buf = [r.next() for x in xrange(minibatch_size)]
...@@ -140,6 +156,6 @@ def minibatch_decorater(reader, minibatch_size): ...@@ -140,6 +156,6 @@ def minibatch_decorater(reader, minibatch_size):
def train(reader, mapping, batch_size, total_pass): def train(reader, mapping, batch_size, total_pass):
for pass_idx in range(total_pass): for pass_idx in range(total_pass):
for mini_batch in minibatch_decorater(reader): # this loop will never end in online learning. for mini_batch in make_minibatch(reader): # this loop will never end in online learning.
do_forward_backward(mini_batch, mapping) do_forward_backward(mini_batch, mapping)
``` ```
...@@ -68,7 +68,7 @@ class TestMatrix(unittest.TestCase): ...@@ -68,7 +68,7 @@ class TestMatrix(unittest.TestCase):
def test_numpyCpu(self): def test_numpyCpu(self):
numpy_mat = np.matrix([[1, 2], [3, 4], [5, 6]], dtype="float32") numpy_mat = np.matrix([[1, 2], [3, 4], [5, 6]], dtype="float32")
m = swig_paddle.Matrix.createCpuDenseFromNumpy(numpy_mat, copy=False) m = swig_paddle.Matrix.createCpuDenseFromNumpy(numpy_mat, False)
self.assertEqual((int(m.getHeight()), int(m.getWidth())), self.assertEqual((int(m.getHeight()), int(m.getWidth())),
numpy_mat.shape) numpy_mat.shape)
......
...@@ -43,7 +43,7 @@ class TestIVector(unittest.TestCase): ...@@ -43,7 +43,7 @@ class TestIVector(unittest.TestCase):
def test_cpu_numpy(self): def test_cpu_numpy(self):
vec = np.array([1, 3, 4, 65, 78, 1, 4], dtype="int32") vec = np.array([1, 3, 4, 65, 78, 1, 4], dtype="int32")
iv = swig_paddle.IVector.createCpuVectorFromNumpy(vec, copy=False) iv = swig_paddle.IVector.createCpuVectorFromNumpy(vec, False)
self.assertEqual(vec.shape[0], int(iv.__len__())) self.assertEqual(vec.shape[0], int(iv.__len__()))
vec[4] = 832 vec[4] = 832
for i in xrange(len(iv)): for i in xrange(len(iv)):
...@@ -106,7 +106,7 @@ class TestVector(unittest.TestCase): ...@@ -106,7 +106,7 @@ class TestVector(unittest.TestCase):
def testCpuNumpy(self): def testCpuNumpy(self):
numpy_arr = np.array([1.2, 2.3, 3.4, 4.5], dtype="float32") numpy_arr = np.array([1.2, 2.3, 3.4, 4.5], dtype="float32")
vec = swig_paddle.Vector.createCpuVectorFromNumpy(numpy_arr, copy=False) vec = swig_paddle.Vector.createCpuVectorFromNumpy(numpy_arr, False)
assert isinstance(vec, swig_paddle.Vector) assert isinstance(vec, swig_paddle.Vector)
numpy_arr[0] = 0.1 numpy_arr[0] = 0.1
for n, v in zip(numpy_arr, vec): for n, v in zip(numpy_arr, vec):
......
...@@ -69,19 +69,6 @@ extern void hl_sequence_softmax_forward(real* A_d, ...@@ -69,19 +69,6 @@ extern void hl_sequence_softmax_forward(real* A_d,
const int* index, const int* index,
int numSequence); int numSequence);
/**
* @brief Matrix classification error.
*
* @param[in] A_d input matrix (M x N).
* @param[in] B_d input vector (M x 1).
* @param[out] C_d output vector (M x 1).
* @param[in] dimM matrix height.
* @param[in] dimN matrix width.
*
*/
extern void hl_matrix_classification_error(
real* A_d, int* B_d, real* C_d, int dimM, int dimN);
/** /**
* @brief Matrix cross entropy. * @brief Matrix cross entropy.
* *
......
...@@ -58,4 +58,30 @@ extern void hl_sparse_matrix_top_k(real* topVal, ...@@ -58,4 +58,30 @@ extern void hl_sparse_matrix_top_k(real* topVal,
int beamSize, int beamSize,
int numSamples); int numSamples);
#endif /* HL_TOP_K_H_ */ /**
* @brief Matrix classification error.
*
* @param[out] topVal top k element.
* @param[in] ldv leading dimension of topVal.
* @param[out] topIds top k index.
* @param[in] src input value.
* @param[in] lds leading dimension of src.
* @param[in] dim width of input value.
* @param[in] topkSize size of top k element.
* @param[in] numSamples height of input value.
* @param[in] label ground truth label.
* @param[out] recResult top-k classification error.
*
*/
extern void hl_matrix_classification_error(real* topVal,
int ldv,
int* topIds,
real* src,
int lds,
int dim,
int topkSize,
int numSamples,
int* label,
real* recResult);
#endif // HL_TOP_K_H_
...@@ -35,8 +35,16 @@ inline void hl_sequence_softmax_forward(real* A_d, ...@@ -35,8 +35,16 @@ inline void hl_sequence_softmax_forward(real* A_d,
inline void hl_matrix_softmax_derivative( inline void hl_matrix_softmax_derivative(
real* grad_d, real* output_d, real* sftmaxSum_d, int dimM, int dimN) {} real* grad_d, real* output_d, real* sftmaxSum_d, int dimM, int dimN) {}
inline void hl_matrix_classification_error( inline void hl_matrix_classification_error(real* topVal,
real* A_d, int* B_d, real* C_d, int dimM, int dimN) {} int ldv,
int* topIds,
real* src,
int lds,
int dim,
int topkSize,
int numSamples,
int* label,
real* recResult) {}
inline void hl_matrix_cross_entropy( inline void hl_matrix_cross_entropy(
real* A_d, real* C_d, int* label_d, int dimM, int dimN) {} real* A_d, real* C_d, int* label_d, int dimM, int dimN) {}
......
...@@ -265,59 +265,6 @@ void hl_matrix_softmax_derivative(real *grad_d, ...@@ -265,59 +265,6 @@ void hl_matrix_softmax_derivative(real *grad_d,
CHECK_SYNC("hl_matrix_softmax_derivative failed"); CHECK_SYNC("hl_matrix_softmax_derivative failed");
} }
template<int blockSize>
__global__ void KeMatrixClassificationError(real* in_A,
int* in_B,
real* out_C,
int dimN) {
__shared__ real max_s[blockSize];
__shared__ int max_l[blockSize];
const int tid = threadIdx.x;
const int rowId = blockIdx.x;
max_s[tid] = -1e30f;
in_A += rowId * dimN;
real tmp;
for (int colId = tid; colId < dimN; colId += blockSize) {
tmp = in_A[colId];
if (max_s[tid] < tmp) {
max_s[tid] = tmp;
max_l[tid] = colId;
}
}
__syncthreads();
for (int stride = blockSize/2; stride > 0; stride = stride/2) {
if (tid < stride) {
if (max_s[tid] < max_s[tid + stride]) {
max_s[tid] = max_s[tid + stride];
max_l[tid] = max_l[tid + stride];
}
}
__syncthreads();
}
__syncthreads();
if (tid == 0) {
out_C[rowId] = (max_l[0] == in_B[rowId] ? 0 : 1.0f);
}
}
void hl_matrix_classification_error(real* A_d,
int* B_d,
real* C_d,
int dimM,
int dimN) {
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(B_d);
CHECK_NOTNULL(C_d);
// each sample is calculated by one block
KeMatrixClassificationError<1024><<< dimM, 1024, 0, STREAM_DEFAULT >>>
(A_d, B_d, C_d, dimN);
CHECK_SYNC("hl_matrix_classification_error");
}
__global__ void KeMatrixMultiBinaryCrossEntropy(real* output, __global__ void KeMatrixMultiBinaryCrossEntropy(real* output,
real* entropy, real* entropy,
int* row, int* row,
......
...@@ -384,3 +384,81 @@ void hl_sparse_matrix_top_k(real* topVal, int ldv, ...@@ -384,3 +384,81 @@ void hl_sparse_matrix_top_k(real* topVal, int ldv,
CHECK_SYNC("hl_sparse_matrix_top_k failed"); CHECK_SYNC("hl_sparse_matrix_top_k failed");
} }
/**
* Each block compute one sample.
* In a block:
* 1. every thread get top maxLength value;
* 2. merge to shTopK, block reduce and get max value;
* 3. go to the second setp, until one thread's topK value is null;
* 4. go to the first setp, until get the topK value.
*/
template<int maxLength, int blockSize>
__global__ void KeMatrixTopKClassificationError(real* topVal, int ldv,
int * topIds,
real* src, int lds,
int dim,
int beamSize,
int* label,
real* recResult) {
__shared__ Pair shTopK[blockSize];
__shared__ int maxId[blockSize / 2];
const int tid = threadIdx.x;
const int warp = threadIdx.x / 32;
src += blockIdx.x * lds;
topVal += blockIdx.x * ldv;
topIds += blockIdx.x * beamSize;
Pair topK[maxLength]; // NOLINT
int beam = maxLength;
Pair max;
bool isEmpty = false;
bool firstStep = true;
int topkSize = beamSize;
for (int k = 0; k < maxLength; k++) {
topK[k].set(-HL_FLOAT_MAX, -1);
}
while (beamSize) {
threadGetTopK<maxLength, blockSize>
(topK, beam, beamSize, src, firstStep, isEmpty, max, dim, tid);
shTopK[tid] = topK[0];
blockReduce<maxLength, blockSize>
(shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp);
}
__syncthreads();
if (tid == 0) {
for (int i = 0; i < topkSize; i++) {
if (*--topIds == label[blockIdx.x]) {
recResult[blockIdx.x] = 0;
break;
}
recResult[blockIdx.x] = 1.0f;
}
}
}
void hl_matrix_classification_error(real* topVal, int ldv,
int* topIds,
real* src, int lds,
int dim,
int topkSize,
int numSamples,
int* label,
real* recResult) {
CHECK_NOTNULL(topVal);
CHECK_NOTNULL(topIds);
CHECK_NOTNULL(src);
if (topkSize > dim) topkSize = dim;
dim3 threads(256, 1);
dim3 grid(numSamples, 1);
KeMatrixTopKClassificationError<5, 256>
<<< grid, threads, 0, STREAM_DEFAULT >>>
(topVal, ldv, topIds, src, lds, dim, topkSize, label, recResult);
CHECK_SYNC("hl_matrix_top_k classification error failed");
}
...@@ -54,22 +54,26 @@ DYNAMIC_LOAD_WARPCTC_WRAP(get_workspace_size) ...@@ -54,22 +54,26 @@ DYNAMIC_LOAD_WARPCTC_WRAP(get_workspace_size)
#define WARPCTC_GET_VERSION dynload::get_warpctc_version #define WARPCTC_GET_VERSION dynload::get_warpctc_version
#define WARPCTC_GET_STATUS_STRING dynload::ctcGetStatusString #define WARPCTC_GET_STATUS_STRING dynload::ctcGetStatusString
static int g_warpctcVersion = -1;
#ifndef PADDLE_TYPE_DOUBLE #ifndef PADDLE_TYPE_DOUBLE
#define WARPCTC_COMPUTE_LOSS dynload::compute_ctc_loss #define WARPCTC_COMPUTE_LOSS dynload::compute_ctc_loss
#define WARPCTC_GET_WORKSPACE_SIZE dynload::get_workspace_size #define WARPCTC_GET_WORKSPACE_SIZE dynload::get_workspace_size
#else #else
#define WARPCTC_LOG_FATAL \ hl_warpctc_status_t fatal(...) {
LOG(FATAL) << "warp-ctc [version " << g_warpctcVersion \ LOG(FATAL) << "warp-ctc [version " << g_warpctcVersion
<< "] Error: not support double precision." << "] Error: not support double precision.";
#define WARPCTC_COMPUTE_LOSS(...) WARPCTC_LOG_FATAL(__VA_ARGS__) // both of get_warpctc_version() and get_workspace_size() return an ctcStatus
#define WARPCTC_GET_WORKSPACE_SIZE(...) WARPCTC_LOG_FATAL(__VA_ARGS__) // type value
return CTC_STATUS_EXECUTION_FAILED;
}
#define WARPCTC_COMPUTE_LOSS fatal
#define WARPCTC_GET_WORKSPACE_SIZE fatal
#endif #endif
/** /**
* Check build-in warp-ctc function using glog and it also * Check build-in warp-ctc function using glog and it also
* support << operator for more details error info. * support << operator for more details error info.
*/ */
static int g_warpctcVersion = -1;
#define CHECK_WARPCTC(warpctcStat) \ #define CHECK_WARPCTC(warpctcStat) \
CHECK_EQ(CTC_STATUS_SUCCESS, warpctcStat) \ CHECK_EQ(CTC_STATUS_SUCCESS, warpctcStat) \
<< "warp-ctc [version " << g_warpctcVersion \ << "warp-ctc [version " << g_warpctcVersion \
......
...@@ -39,6 +39,14 @@ void Evaluator::eval(const NeuralNetwork& nn) { ...@@ -39,6 +39,14 @@ void Evaluator::eval(const NeuralNetwork& nn) {
*/ */
class ClassificationErrorEvaluator : public Evaluator { class ClassificationErrorEvaluator : public Evaluator {
public: public:
/*
ClassificationErrorEvaluator() : totalScore2_(0) {}
virtual void start() {
Evaluator::start();
totalScore2_ = 0;
} */
virtual void updateSamplesNum(const std::vector<Argument>& arguments) { virtual void updateSamplesNum(const std::vector<Argument>& arguments) {
if (3 == arguments.size()) { if (3 == arguments.size()) {
numSamples_ += arguments[2].value->getSum(); numSamples_ += arguments[2].value->getSum();
...@@ -76,9 +84,11 @@ public: ...@@ -76,9 +84,11 @@ public:
1, 1,
/* trans= */ false, /* trans= */ false,
useGpu(arguments[0].deviceId)); useGpu(arguments[0].deviceId));
errorMat->zeroMem(); errorMat->zeroMem();
if (label != nullptr) { if (label != nullptr) {
errorMat->classificationError(*output, *label); errorMat->classificationError(*output, *label, config_.top_k());
} else if (dynamic_cast<CpuSparseMatrix*>(multiBinaryLabel.get()) || } else if (dynamic_cast<CpuSparseMatrix*>(multiBinaryLabel.get()) ||
dynamic_cast<GpuSparseMatrix*>(multiBinaryLabel.get())) { dynamic_cast<GpuSparseMatrix*>(multiBinaryLabel.get())) {
errorMat->classificationErrorMulti( errorMat->classificationErrorMulti(
...@@ -94,6 +104,16 @@ public: ...@@ -94,6 +104,16 @@ public:
return errorMat; return errorMat;
} }
void printStats(std::ostream& os) const {
if (config_.top_k() == 1) {
os << config_.name() << "="
<< (numSamples_ ? totalScore_ / numSamples_ : 0);
} else {
os << " top_" << config_.top_k()
<< "_error=" << (numSamples_ ? totalScore_ / numSamples_ : 0);
}
}
virtual real evalImp(std::vector<Argument>& arguments) { virtual real evalImp(std::vector<Argument>& arguments) {
MatrixPtr errorMat = calcError(arguments); MatrixPtr errorMat = calcError(arguments);
return errorMat->getSum(); return errorMat->getSum();
......
...@@ -311,6 +311,7 @@ public: ...@@ -311,6 +311,7 @@ public:
return *output->second; return *output->second;
} else { } else {
LOG(FATAL) << "No specific output " << str; LOG(FATAL) << "No specific output " << str;
return *((Argument*)nullptr);
} }
} }
} }
......
...@@ -129,6 +129,7 @@ void testEvaluatorAll(TestConfig testConf, ...@@ -129,6 +129,7 @@ void testEvaluatorAll(TestConfig testConf,
TEST(Evaluator, classification_error) { TEST(Evaluator, classification_error) {
TestConfig config; TestConfig config;
config.evaluatorConfig.set_type("classification_error"); config.evaluatorConfig.set_type("classification_error");
config.evaluatorConfig.set_top_k(5);
config.inputDefs.push_back({INPUT_DATA, "output", 50}); config.inputDefs.push_back({INPUT_DATA, "output", 50});
config.inputDefs.push_back({INPUT_LABEL, "label", 50}); config.inputDefs.push_back({INPUT_LABEL, "label", 50});
......
...@@ -732,6 +732,7 @@ void GpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) { ...@@ -732,6 +732,7 @@ void GpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
size_t beam = maxVal.getWidth(); size_t beam = maxVal.getWidth();
CHECK_EQ(maxIds.getSize(), numSamples * beam); CHECK_EQ(maxIds.getSize(), numSamples * beam);
CHECK_EQ(maxVal.getHeight(), numSamples); CHECK_EQ(maxVal.getHeight(), numSamples);
CHECK_EQ(maxVal.getWidth(), beam);
hl_matrix_top_k(maxVal.getData(), hl_matrix_top_k(maxVal.getData(),
maxVal.getStride(), maxVal.getStride(),
...@@ -792,19 +793,32 @@ void GpuMatrix::maxoutBackward(Matrix& a, ...@@ -792,19 +793,32 @@ void GpuMatrix::maxoutBackward(Matrix& a,
} }
/*calulate the error of classification */ /*calulate the error of classification */
void GpuMatrix::classificationError(Matrix& output, IVector& label) { void GpuMatrix::classificationError(Matrix& output,
auto output_ptr = dynamic_cast<const GpuMatrix*>(&output); IVector& label,
auto label_ptr = dynamic_cast<const GpuIVector*>(&label); size_t topkSize) {
CHECK(output_ptr && label_ptr) << "Invalid argument pointer"; auto gpuOutput = dynamic_cast<GpuMatrix*>(&output);
auto gpuLabel = dynamic_cast<GpuIVector*>(&label);
CHECK(height_ == output_ptr->height_ && width_ == 1) size_t numSamples = this->getHeight();
GpuMatrixPtr gpuTopVal = std::make_shared<GpuMatrix>(numSamples, topkSize);
GpuIVectorPtr gpuTopIds = std::make_shared<GpuIVector>(numSamples * topkSize);
CHECK(gpuOutput && gpuLabel) << "Invalid argument pointer";
CHECK(gpuTopVal && gpuTopIds) << "Allocate GPU memory failed";
CHECK(gpuLabel->getSize() == numSamples) << "Vector size is not equal";
CHECK(numSamples == gpuOutput->getHeight() && this->getWidth() == 1)
<< "Matrix dimensions are not equal"; << "Matrix dimensions are not equal";
hl_matrix_classification_error((real*)output_ptr->data_, size_t dim = gpuOutput->getWidth();
(int*)label_ptr->getData(), hl_matrix_classification_error(gpuTopVal->getData(),
data_, gpuTopVal->getStride(),
height_, gpuTopIds->getData(),
output_ptr->width_); gpuOutput->getData(),
gpuOutput->getStride(),
dim,
topkSize,
numSamples,
gpuLabel->getData(),
this->getData());
} }
/* copy -log(output[i * width + label]) to this->data[i] */ /* copy -log(output[i * width + label]) to this->data[i] */
...@@ -3039,7 +3053,7 @@ void CpuMatrix::rowMax(Matrix& max) { ...@@ -3039,7 +3053,7 @@ void CpuMatrix::rowMax(Matrix& max) {
max.maxRows(*this); max.maxRows(*this);
} }
/* get beam size of max ids and values */ /* Get the top k elements of each row of this matrix */
void CpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) { void CpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
CHECK(isContiguous()); CHECK(isContiguous());
CHECK(!maxIds.useGpu() && !maxVal.useGpu()) << "Matrix type are not equal"; CHECK(!maxIds.useGpu() && !maxVal.useGpu()) << "Matrix type are not equal";
...@@ -3047,6 +3061,7 @@ void CpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) { ...@@ -3047,6 +3061,7 @@ void CpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
size_t beam = maxVal.getWidth(); size_t beam = maxVal.getWidth();
CHECK_EQ(maxIds.getSize(), numSamples * beam); CHECK_EQ(maxIds.getSize(), numSamples * beam);
CHECK_EQ(maxVal.getHeight(), numSamples); CHECK_EQ(maxVal.getHeight(), numSamples);
CHECK_EQ(maxVal.getWidth(), beam);
real* a = getData(); real* a = getData();
int* s = maxIds.getData(); int* s = maxIds.getData();
...@@ -3198,32 +3213,39 @@ void CpuMatrix::rowNormalizeL1(Matrix& out) { ...@@ -3198,32 +3213,39 @@ void CpuMatrix::rowNormalizeL1(Matrix& out) {
} }
/* calulate classification error */ /* calulate classification error */
void CpuMatrix::classificationError(Matrix& output, IVector& label) { void CpuMatrix::classificationError(Matrix& output,
CHECK(dynamic_cast<const CpuMatrix*>(&output)); IVector& label,
CHECK(dynamic_cast<const CpuIVector*>(&label)); size_t topkSize) {
size_t numSamples = this->getHeight();
auto cpuOutput = dynamic_cast<CpuMatrix*>(&output);
auto cpuLabel = dynamic_cast<CpuIVector*>(&label);
IVectorPtr cpuTopIds = std::make_shared<CpuIVector>(numSamples * topkSize);
MatrixPtr cpuTopVal = std::make_shared<CpuMatrix>(numSamples, topkSize);
CHECK(cpuOutput && cpuLabel) << "Invalid argument pointer";
CHECK(cpuTopIds && cpuTopVal) << "Allocate cpu memory failed";
CHECK(cpuLabel->getSize() == numSamples) << "Vector size is not equal";
CHECK(cpuOutput->getHeight() == numSamples && this->getWidth() == 1)
<< "Matrix dimensions are not equal";
CHECK_EQ(getWidth(), (size_t)1); // top k matrix classification
size_t numSamples = getHeight(); cpuOutput->rowMax(*cpuTopIds, *cpuTopVal);
CHECK_EQ(label.getSize(), numSamples);
CHECK_EQ(output.getHeight(), numSamples);
size_t dim = output.getWidth(); size_t dim = cpuOutput->getWidth();
real* out = output.getData(); real* result = this->getData();
int* lbl = label.getData(); int* ids = cpuTopIds->getData();
real maxData = 0.0; int* lbl = cpuLabel->getData();
int maxIndex = -1;
for (size_t i = 0; i < numSamples; ++i) { for (size_t i = 0; i < numSamples; ++i) {
CHECK_GE(lbl[i], 0); CHECK_GE(lbl[i], 0);
CHECK_LT((size_t)lbl[i], dim); CHECK_LT((size_t)lbl[i], dim);
maxData = out[i * dim];
maxIndex = 0; for (size_t j = 0; j < topkSize; ++j) {
for (size_t j = 0; j < dim; ++j) { if (ids[j + i * topkSize] == lbl[i]) {
if (maxData < out[i * dim + j]) { result[i] = 0;
maxIndex = j; break;
maxData = out[i * dim + j];
} }
result[i] = 1.0f;
} }
getData()[i] = (maxIndex != lbl[i]);
} }
} }
......
...@@ -836,8 +836,11 @@ public: ...@@ -836,8 +836,11 @@ public:
* output[i] = 1 if row i is an error. * output[i] = 1 if row i is an error.
* *
* output[i] = 0 if row i is correct. * output[i] = 0 if row i is correct.
*
*/ */
virtual void classificationError(Matrix& output, IVector& label) { virtual void classificationError(Matrix& output,
IVector& label,
size_t topkSize = 1) {
LOG(FATAL) << "Not implemented"; LOG(FATAL) << "Not implemented";
} }
...@@ -1314,7 +1317,7 @@ public: ...@@ -1314,7 +1317,7 @@ public:
void check(std::ostream& os, Matrix& refMat, bool printDiff = true); void check(std::ostream& os, Matrix& refMat, bool printDiff = true);
void randomizeUniform(); void randomizeUniform();
void classificationError(Matrix& output, IVector& label); void classificationError(Matrix& output, IVector& label, size_t topkSize = 1);
void convExpand(Matrix& feature, void convExpand(Matrix& feature,
int feaImgHeight, int feaImgHeight,
...@@ -1739,7 +1742,7 @@ public: ...@@ -1739,7 +1742,7 @@ public:
void randomizeUniform(); void randomizeUniform();
void classificationError(Matrix& output, IVector& label); void classificationError(Matrix& output, IVector& label, size_t topkSize = 1);
void addByBitCode(size_t numClasses, const IVector& codes, const Matrix& vec); void addByBitCode(size_t numClasses, const IVector& codes, const Matrix& vec);
......
...@@ -764,7 +764,7 @@ TEST(Matrix, paramReluBackwardDiff) { ...@@ -764,7 +764,7 @@ TEST(Matrix, paramReluBackwardDiff) {
} }
} }
void testClassificationError(int numSamples, int dim) { void testClassificationError(int numSamples, int dim, int topkSize) {
MatrixPtr cpuError = std::make_shared<CpuMatrix>(numSamples, 1); MatrixPtr cpuError = std::make_shared<CpuMatrix>(numSamples, 1);
MatrixPtr gpuError = std::make_shared<GpuMatrix>(numSamples, 1); MatrixPtr gpuError = std::make_shared<GpuMatrix>(numSamples, 1);
MatrixPtr cpuOutput = std::make_shared<CpuMatrix>(numSamples, dim); MatrixPtr cpuOutput = std::make_shared<CpuMatrix>(numSamples, dim);
...@@ -777,17 +777,22 @@ void testClassificationError(int numSamples, int dim) { ...@@ -777,17 +777,22 @@ void testClassificationError(int numSamples, int dim) {
gpuOutput->copyFrom(*cpuOutput); gpuOutput->copyFrom(*cpuOutput);
gpuLabel->copyFrom(*cpuLabel); gpuLabel->copyFrom(*cpuLabel);
cpuError->classificationError(*cpuOutput, *cpuLabel); cpuError->classificationError(*cpuOutput, *cpuLabel, topkSize);
gpuError->classificationError(*gpuOutput, *gpuLabel); gpuError->classificationError(*gpuOutput, *gpuLabel, topkSize);
TensorCheckEqual(*cpuError, *gpuError); TensorCheckEqual(*cpuError, *gpuError);
} }
TEST(Matrix, classificationError) { TEST(Matrix, classificationError) {
for (auto numSamples : {1, 10, 100, 1000, 70000}) { for (auto numSamples : {1, 5, 31, 90, 150, 300}) {
for (auto dim : {1, 10, 100, 1000}) { for (auto dim :
VLOG(3) << " numSamples=" << numSamples << " dim=" << dim; {1, 5, 8, 10, 15, 64, 80, 120, 256, 300, 1280, 5120, 50000}) {
testClassificationError(numSamples, dim); for (auto topkSize : {1, 5, 10, 20, 40, (int)rand() % dim + 1}) {
if (topkSize > dim) continue;
VLOG(3) << " sample= " << numSamples << " topkSize= " << topkSize
<< " dim= " << dim;
testClassificationError(numSamples, dim, topkSize);
}
} }
} }
} }
......
...@@ -375,10 +375,6 @@ bool Parameter::load(const std::string& filename) { ...@@ -375,10 +375,6 @@ bool Parameter::load(const std::string& filename) {
std::ifstream fs(filename, std::ios_base::binary); std::ifstream fs(filename, std::ios_base::binary);
if (!fs) { if (!fs) {
LOG(INFO) << "missing parameters [" << filename << "] while loading model."; LOG(INFO) << "missing parameters [" << filename << "] while loading model.";
if (isStatic()) {
LOG(FATAL) << getName() << " is static but missing, not allowed.";
return false;
}
if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) { if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) {
LOG(FATAL) << getName() << " missing, not allowed."; LOG(FATAL) << getName() << " missing, not allowed.";
return false; return false;
......
...@@ -55,6 +55,9 @@ elif is_osx == True: ...@@ -55,6 +55,9 @@ elif is_osx == True:
include_dirs = [np.get_include(), "../"] # include numpy and paddle. include_dirs = [np.get_include(), "../"] # include numpy and paddle.
os.environ["CC"] = "@CMAKE_C_COMPILER@"
os.environ["CXX"] = "@CMAKE_CXX_COMPILER@"
setup(name="py_paddle", setup(name="py_paddle",
version="@PADDLE_VERSION@", version="@PADDLE_VERSION@",
ext_modules=[ ext_modules=[
......
...@@ -475,6 +475,10 @@ message EvaluatorConfig { ...@@ -475,6 +475,10 @@ message EvaluatorConfig {
// Used by ChunkEvaluator // Used by ChunkEvaluator
// chunk of these types are not counted // chunk of these types are not counted
repeated int32 excluded_chunk_types = 12; repeated int32 excluded_chunk_types = 12;
// Used by ClassificationErrorEvaluator
// top # classification error
optional int32 top_k = 13 [default = 1];
} }
message LinkConfig { message LinkConfig {
......
...@@ -12,25 +12,158 @@ ...@@ -12,25 +12,158 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
__all__ = ['buffered'] __all__ = [
'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
'ComposeNotAligned'
]
from Queue import Queue from Queue import Queue
from threading import Thread from threading import Thread
import itertools
import random
def map_readers(func, *readers):
"""
Creates a data reader that outputs return value of function using
output of each data readers as arguments.
:param func: function to use.
:param *readers: readers whose outputs will be used as arguments of func.
:returns: the created data reader.
"""
def reader():
rs = []
for r in readers:
rs.append(r())
for e in itertools.imap(func, *rs):
yield e
return reader
def shuffle(reader, buf_size):
"""
Creates a data reader whose data output is suffled.
Output from the iterator that created by original reader will be
buffered into shuffle buffer, and then shuffled. The size of shuffle buffer
is determined by argument buf_size.
:param reader: the original reader whose output will be shuffled.
:param buf_size: shuffle buffer size.
:returns:the new reader whose output is shuffled.
"""
def data_reader():
buf = []
for e in reader():
buf.append(e)
if len(buf) >= buf_size:
random.shuffle(buf)
for b in buf:
yield b
buf = []
if len(buf) > 0:
random.shuffle(buf)
for b in buf:
yield b
return data_reader
def chain(*readers):
"""
Creates a data reader whose output is the outputs of input data
readers chained together.
If input readers output following data entries:
[0, 0, 0]
[1, 1, 1]
[2, 2, 2]
The chained reader will output:
[0, 0, 0, 1, 1, 1, 2, 2, 2]
:param readers: input readers.
:returns: the new data reader.
"""
def reader():
rs = []
for r in readers:
rs.append(r())
for e in itertools.chain(*rs):
yield e
return reader
class ComposeNotAligned(ValueError):
pass
def compose(*readers, **kwargs):
"""
Creates a data reader whose output is the combination of input readers.
If input readers output following data entries:
(1, 2) 3 (4, 5)
The composed reader will output:
(1, 2, 3, 4, 5)
:*readers: readers that will be composed together.
:check_alignment: if True, will check if input readers are aligned
correctly. If False, will not check alignment and trailing outputs
will be discarded. Defaults to True.
:returns: the new data reader.
:raises ComposeNotAligned: outputs of readers are not aligned.
Will not raise when check_alignment is set to False.
"""
check_alignment = kwargs.pop('check_alignment', True)
def make_tuple(x):
if isinstance(x, tuple):
return x
else:
return (x, )
def reader():
rs = []
for r in readers:
rs.append(r())
if not check_alignment:
for outputs in itertools.izip(*rs):
yield sum(map(make_tuple, outputs), ())
else:
for outputs in itertools.izip_longest(*rs):
for o in outputs:
if o is None:
# None will be not be present if compose is aligned
raise ComposeNotAligned(
"outputs of readers are not aligned.")
yield sum(map(make_tuple, outputs), ())
return reader
def buffered(reader, size): def buffered(reader, size):
"""Creates a buffered data reader. """
Creates a buffered data reader.
The buffered data reader will read and save data entries into a buffer. The buffered data reader will read and save data entries into a
Reading from the buffered data reader will proceed as long as the buffer buffer. Reading from the buffered data reader will proceed as long
is not empty. as the buffer is not empty.
Args: :param reader: the data reader to read from.
reader: the data reader to read from. :param size: max buffer size.
size: max buffer size.
Returns: :returns: the buffered data reader.
The buffered data reader.
""" """
class EndSignal(): class EndSignal():
...@@ -43,7 +176,7 @@ def buffered(reader, size): ...@@ -43,7 +176,7 @@ def buffered(reader, size):
q.put(d) q.put(d)
q.put(end) q.put(end)
def create_reader(): def data_reader():
r = reader() r = reader()
q = Queue(maxsize=size) q = Queue(maxsize=size)
t = Thread( t = Thread(
...@@ -57,4 +190,4 @@ def buffered(reader, size): ...@@ -57,4 +190,4 @@ def buffered(reader, size):
yield e yield e
e = q.get() e = q.get()
return create_reader return data_reader
...@@ -16,16 +16,36 @@ import paddle.reader ...@@ -16,16 +16,36 @@ import paddle.reader
import time import time
def reader_10(dur): def reader_creator_10(dur):
for i in range(10): def reader():
time.sleep(dur) for i in range(10):
yield i # this invocation helps testing paddle.reader.buffer
time.sleep(dur)
yield i
return reader
class TestMap(unittest.TestCase):
def test_map(self):
d = {"h": 0, "i": 1}
def tokenize(x):
return d[x]
def read():
yield "h"
yield "i"
r = paddle.reader.map_readers(tokenize, read)
for i, e in enumerate(r()):
self.assertEqual(e, i)
class TestBuffered(unittest.TestCase): class TestBuffered(unittest.TestCase):
def test_read(self): def test_read(self):
for size in range(20): for size in range(20):
b = paddle.reader.buffered(lambda: reader_10(0), size) b = paddle.reader.buffered(reader_creator_10(0), size)
c = 0 c = 0
for i in b(): for i in b():
self.assertEqual(i, c) self.assertEqual(i, c)
...@@ -34,7 +54,7 @@ class TestBuffered(unittest.TestCase): ...@@ -34,7 +54,7 @@ class TestBuffered(unittest.TestCase):
def test_buffering(self): def test_buffering(self):
# read have 30ms delay. # read have 30ms delay.
b = paddle.reader.buffered(lambda: reader_10(0.03), 10) b = paddle.reader.buffered(reader_creator_10(0.03), 10)
last_time = time.time() last_time = time.time()
for idx, i in enumerate(b()): for idx, i in enumerate(b()):
elapsed_time = time.time() - last_time elapsed_time = time.time() - last_time
...@@ -42,9 +62,63 @@ class TestBuffered(unittest.TestCase): ...@@ -42,9 +62,63 @@ class TestBuffered(unittest.TestCase):
time.sleep(0.3) time.sleep(0.3)
else: else:
# read time should be short, meaning already buffered. # read time should be short, meaning already buffered.
self.assertLess(elapsed_time, 0.01) self.assertLess(elapsed_time, 0.05)
last_time = time.time() last_time = time.time()
class TestCompose(unittest.TestCase):
def test_compse(self):
reader = paddle.reader.compose(
reader_creator_10(0), reader_creator_10(0))
for idx, e in enumerate(reader()):
self.assertEqual(e, (idx, idx))
def test_compose_not_aligned(self):
total = 0
reader = paddle.reader.compose(
paddle.reader.chain(reader_creator_10(0), reader_creator_10(0)),
reader_creator_10(0))
with self.assertRaises(paddle.reader.ComposeNotAligned):
for e in reader():
total += 1
# expecting 10, not 20
self.assertEqual(total, 10)
def test_compose_not_aligned_no_check(self):
total = 0
reader = paddle.reader.compose(
paddle.reader.chain(reader_creator_10(0), reader_creator_10(0)),
reader_creator_10(0),
check_alignment=False)
for e in reader():
total += 1
# expecting 10, not 20
self.assertEqual(total, 10)
class TestChain(unittest.TestCase):
def test_chain(self):
c = paddle.reader.chain(reader_creator_10(0), reader_creator_10(0))
idx = 0
for e in c():
self.assertEqual(e, idx % 10)
idx += 1
self.assertEqual(idx, 20)
class TestShuffle(unittest.TestCase):
def test_shuffle(self):
case = [(0, True), (1, True), (10, False), (100, False)]
a = reader_creator_10(0)
for size, checkEq in case:
s = paddle.reader.shuffle(a, size)
total = 0
for idx, e in enumerate(s()):
if checkEq:
self.assertEqual(idx, e)
total += 1
self.assertEqual(total, 10)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -1253,6 +1253,7 @@ def Evaluator( ...@@ -1253,6 +1253,7 @@ def Evaluator(
dict_file=None, dict_file=None,
result_file=None, result_file=None,
num_results=None, num_results=None,
top_k=None,
delimited=None, delimited=None,
excluded_chunk_types=None, ): excluded_chunk_types=None, ):
evaluator = g_config.model_config.evaluators.add() evaluator = g_config.model_config.evaluators.add()
...@@ -1280,6 +1281,8 @@ def Evaluator( ...@@ -1280,6 +1281,8 @@ def Evaluator(
evaluator.result_file = result_file evaluator.result_file = result_file
if num_results is not None: if num_results is not None:
evaluator.num_results = num_results evaluator.num_results = num_results
if top_k is not None:
evaluator.top_k = top_k
if delimited is not None: if delimited is not None:
evaluator.delimited = delimited evaluator.delimited = delimited
......
...@@ -71,6 +71,7 @@ def evaluator_base( ...@@ -71,6 +71,7 @@ def evaluator_base(
result_file=None, result_file=None,
num_results=None, num_results=None,
delimited=None, delimited=None,
top_k=None,
excluded_chunk_types=None, ): excluded_chunk_types=None, ):
""" """
Evaluator will evaluate the network status while training/testing. Evaluator will evaluate the network status while training/testing.
...@@ -104,12 +105,15 @@ def evaluator_base( ...@@ -104,12 +105,15 @@ def evaluator_base(
:param weight: An input layer which is a weight for each sample. :param weight: An input layer which is a weight for each sample.
Each evaluator may calculate differently to use this weight. Each evaluator may calculate differently to use this weight.
:type weight: LayerOutput. :type weight: LayerOutput.
:param top_k: number k in top-k error rate
:type top_k: int
""" """
# inputs type assertions. # inputs type assertions.
assert classification_threshold is None or isinstance( assert classification_threshold is None or isinstance(
classification_threshold, float) classification_threshold, float)
assert positive_label is None or isinstance(positive_label, int) assert positive_label is None or isinstance(positive_label, int)
assert num_results is None or isinstance(num_results, int) assert num_results is None or isinstance(num_results, int)
assert top_k is None or isinstance(top_k, int)
if not isinstance(input, list): if not isinstance(input, list):
input = [input] input = [input]
...@@ -130,6 +134,8 @@ def evaluator_base( ...@@ -130,6 +134,8 @@ def evaluator_base(
dict_file=dict_file, dict_file=dict_file,
result_file=result_file, result_file=result_file,
delimited=delimited, delimited=delimited,
num_results=num_results,
top_k=top_k,
excluded_chunk_types=excluded_chunk_types, ) excluded_chunk_types=excluded_chunk_types, )
...@@ -139,6 +145,7 @@ def classification_error_evaluator(input, ...@@ -139,6 +145,7 @@ def classification_error_evaluator(input,
label, label,
name=None, name=None,
weight=None, weight=None,
top_k=None,
threshold=None): threshold=None):
""" """
Classification Error Evaluator. It will print error rate for classification. Classification Error Evaluator. It will print error rate for classification.
...@@ -167,6 +174,8 @@ def classification_error_evaluator(input, ...@@ -167,6 +174,8 @@ def classification_error_evaluator(input,
then means not set weight. The larger weight it is, the more then means not set weight. The larger weight it is, the more
important this sample is. important this sample is.
:type weight: LayerOutput :type weight: LayerOutput
:param top_k: number k in top-k error rate
:type top_k: int
:param threshold: The classification threshold. :param threshold: The classification threshold.
:type threshold: float :type threshold: float
:return: None. :return: None.
...@@ -178,6 +187,7 @@ def classification_error_evaluator(input, ...@@ -178,6 +187,7 @@ def classification_error_evaluator(input,
input=input, input=input,
label=label, label=label,
weight=weight, weight=weight,
top_k=top_k,
classification_threshold=threshold, ) classification_threshold=threshold, )
......
...@@ -2870,8 +2870,8 @@ def gru_step_layer(input, ...@@ -2870,8 +2870,8 @@ def gru_step_layer(input,
:param name: :param name:
:param gate_act: :param gate_act:
:param bias_attr: :param bias_attr:
:param param_attr: the parameter_attribute for transforming the output_mem :param param_attr: the parameter_attribute for transforming the output_mem
from previous step. from previous step.
:param layer_attr: :param layer_attr:
:return: LayerOutput object. :return: LayerOutput object.
:rtype: LayerOutput :rtype: LayerOutput
...@@ -2882,10 +2882,10 @@ def gru_step_layer(input, ...@@ -2882,10 +2882,10 @@ def gru_step_layer(input,
Layer( Layer(
name=name, name=name,
type=LayerType.GRU_STEP_LAYER, type=LayerType.GRU_STEP_LAYER,
# The parameter here is for transforming the output_mem. The input has # The parameter here is for transforming the output_mem. The input has
# already been transformed outside this module so it does not need # already been transformed outside this module so it does not need
# parameter associated with it. # parameter associated with it.
# The parameter here is instead grouped with input is due to # The parameter here is instead grouped with input is due to
# backward model compatibility. # backward model compatibility.
inputs=[Input(input.name, **param_attr.attr), output_mem.name], inputs=[Input(input.name, **param_attr.attr), output_mem.name],
bias=ParamAttr.to_bias(bias_attr), bias=ParamAttr.to_bias(bias_attr),
...@@ -3536,6 +3536,7 @@ def classification_cost(input, ...@@ -3536,6 +3536,7 @@ def classification_cost(input,
label, label,
weight=None, weight=None,
name=None, name=None,
top_k=None,
evaluator=classification_error_evaluator, evaluator=classification_error_evaluator,
layer_attr=None): layer_attr=None):
""" """
...@@ -3550,6 +3551,8 @@ def classification_cost(input, ...@@ -3550,6 +3551,8 @@ def classification_cost(input,
:param weight: The weight affects the cost, namely the scale of cost. :param weight: The weight affects the cost, namely the scale of cost.
It is an optional argument. It is an optional argument.
:type weight: LayerOutput :type weight: LayerOutput
:param top_k: number k in top-k error rate
:type top_k: int
:param evaluator: Evaluator method. :param evaluator: Evaluator method.
:param layer_attr: layer's extra attribute. :param layer_attr: layer's extra attribute.
:type layer_attr: ExtraLayerAttribute :type layer_attr: ExtraLayerAttribute
...@@ -3577,7 +3580,7 @@ def classification_cost(input, ...@@ -3577,7 +3580,7 @@ def classification_cost(input,
assert isinstance(e.for_classification, bool) assert isinstance(e.for_classification, bool)
assert e.for_classification assert e.for_classification
e(name=e.__name__, input=input, label=label, weight=weight) e(name=e.__name__, input=input, label=label, weight=weight, top_k=top_k)
if not isinstance(evaluator, collections.Sequence): if not isinstance(evaluator, collections.Sequence):
evaluator = [evaluator] evaluator = [evaluator]
......
...@@ -18,11 +18,12 @@ import parameters ...@@ -18,11 +18,12 @@ import parameters
import trainer import trainer
import event import event
import data_type import data_type
import attr
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
__all__ = [ __all__ = [
'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer', 'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer',
'event', 'data_type' 'event', 'data_type', 'attr'
] ]
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers.attrs import *
__all__ = [
"Param",
"Extra",
]
Param = ParameterAttribute
Extra = ExtraLayerAttribute
...@@ -75,10 +75,14 @@ from paddle.trainer_config_helpers.default_decorators import wrap_name_default ...@@ -75,10 +75,14 @@ from paddle.trainer_config_helpers.default_decorators import wrap_name_default
import activation import activation
import data_type import data_type
import activation
import attr
__all__ = [ __all__ = [
'parse_network', 'data', 'fc', 'max_id', 'classification_cost', 'parse_network', 'data', 'fc', 'max_id', 'classification_cost',
'cross_entropy_cost' 'cross_entropy_cost', 'cross_entropy_with_selfnorm_cost', 'regression_cost',
'multi_binary_label_cross_entropy_cost', 'rank_cost', 'lambda_cost',
'sum_cost', 'huber_cost'
] ]
...@@ -145,7 +149,8 @@ def __convert_to_v2__(method_name, name_prefix, parent_names): ...@@ -145,7 +149,8 @@ def __convert_to_v2__(method_name, name_prefix, parent_names):
parent_layers = dict() parent_layers = dict()
other_kwargs = dict() other_kwargs = dict()
for pname in parent_names: for pname in parent_names:
parent_layers[pname] = kwargs[pname] if kwargs.has_key(pname):
parent_layers[pname] = kwargs[pname]
for key in kwargs.keys(): for key in kwargs.keys():
if key not in parent_names: if key not in parent_names:
...@@ -213,11 +218,15 @@ class MemoryV2(Layer): ...@@ -213,11 +218,15 @@ class MemoryV2(Layer):
data = DataLayerV2 data = DataLayerV2
fc = __convert_to_v2__('fc_layer', name_prefix='fc', parent_names=['input']) fc = __convert_to_v2__('fc_layer', name_prefix='fc', parent_names=['input'])
max_id = __convert_to_v2__( max_id = __convert_to_v2__(
'maxid_layer', name_prefix='maxid_layer', parent_names=['input']) 'maxid_layer', name_prefix='maxid', parent_names=['input'])
classification_cost = __convert_to_v2__( classification_cost = __convert_to_v2__(
'classification_cost', 'classification_cost',
name_prefix='classification_cost', name_prefix='classification_cost',
parent_names=['input', 'label']) parent_names=['input', 'label', 'weight'])
regression_cost = __convert_to_v2__(
'regression_cost',
name_prefix='regression_cost',
parent_names=['input', 'label', 'weight'])
cross_entropy_cost = __convert_to_v2__( cross_entropy_cost = __convert_to_v2__(
'cross_entropy', 'cross_entropy',
name_prefix='cross_entropy', name_prefix='cross_entropy',
...@@ -230,14 +239,48 @@ recurrent_group = __convert_to_v2__( ...@@ -230,14 +239,48 @@ recurrent_group = __convert_to_v2__(
'recurrent_group', name_prefix='recurrent_layer', parent_names=['input']) 'recurrent_group', name_prefix='recurrent_layer', parent_names=['input'])
memory = MemoryV2 memory = MemoryV2
cross_entropy_with_selfnorm_cost = __convert_to_v2__(
'cross_entropy_with_selfnorm',
name_prefix='cross_entropy_with_selfnorm',
parent_names=['input', 'label'])
multi_binary_label_cross_entropy_cost = __convert_to_v2__(
'multi_binary_label_cross_entropy',
name_prefix='multi_binary_label_cross_entropy',
parent_names=['input', 'label'])
rank_cost = __convert_to_v2__(
'rank_cost',
name_prefix='rank_cost',
parent_names=['left', 'right', 'label', 'weight'])
lambda_cost = __convert_to_v2__(
'lambda_cost', name_prefix='lambda_cost', parent_names=['input', 'score'])
sum_cost = __convert_to_v2__(
'sum_cost', name_prefix='sum_cost', parent_names=['input'])
huber_cost = __convert_to_v2__(
'huber_cost', name_prefix='huber_cost', parent_names=['input', 'label'])
if __name__ == '__main__': if __name__ == '__main__':
pixel = data(name='pixel', type=data_type.dense_vector(784)) pixel = data(name='pixel', type=data_type.dense_vector(784))
label = data(name='label', type=data_type.integer_value(10)) label = data(name='label', type=data_type.integer_value(10))
hidden = fc(input=pixel, size=100, act=conf_helps.SigmoidActivation()) weight = data(name='weight', type=data_type.dense_vector(10))
inference = fc(input=hidden, size=10, act=conf_helps.SoftmaxActivation()) score = data(name='score', type=data_type.dense_vector(1))
hidden = fc(input=pixel,
size=100,
act=activation.Sigmoid(),
param_attr=attr.Param(name='hidden'))
inference = fc(input=hidden, size=10, act=activation.Softmax())
maxid = max_id(input=inference) maxid = max_id(input=inference)
cost1 = classification_cost(input=inference, label=label) cost1 = classification_cost(input=inference, label=label)
cost2 = cross_entropy_cost(input=inference, label=label) cost2 = classification_cost(input=inference, label=label, weight=weight)
cost3 = cross_entropy_cost(input=inference, label=label)
cost4 = cross_entropy_with_selfnorm_cost(input=inference, label=label)
cost5 = regression_cost(input=inference, label=label)
cost6 = regression_cost(input=inference, label=label, weight=weight)
cost7 = multi_binary_label_cross_entropy_cost(input=inference, label=label)
cost8 = rank_cost(left=score, right=score, label=score)
cost9 = lambda_cost(input=inference, score=score)
cost10 = sum_cost(input=inference)
cost11 = huber_cost(input=score, label=label)
mem = memory(name="rnn_state", size=10) mem = memory(name="rnn_state", size=10)
...@@ -246,6 +289,11 @@ if __name__ == '__main__': ...@@ -246,6 +289,11 @@ if __name__ == '__main__':
# print parse_network(cost1, cost2) # print parse_network(cost1, cost2)
# print parse_network(cost2) # print parse_network(cost2)
# print parse_network(inference, maxid) # print parse_network(inference, maxid)
print parse_network(cost1, cost2)
print parse_network(cost3, cost4)
print parse_network(cost5, cost6)
print parse_network(cost7, cost8, cost9, cost10, cost11)
print parse_network(inference, maxid)
dict_dim = 10 dict_dim = 10
word_dim = 8 word_dim = 8
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册