diff --git a/doc_cn/demo/quick_start/index.md b/doc_cn/demo/quick_start/index.md
index 34cd4a840e442c0ff2559bd8627a14d5388a0971..aa6b66ca8c02411016420bf9d99c5e1b4e3cefdd 100644
--- a/doc_cn/demo/quick_start/index.md
+++ b/doc_cn/demo/quick_start/index.md
@@ -4,7 +4,7 @@
## 安装(Install)
-首先请参考安装教程安装PaddlePaddle。
+首先请参考安装教程安装PaddlePaddle。
## 使用概述(Overview)
diff --git a/paddle/cuda/src/hl_cuda_cublas.cc b/paddle/cuda/src/hl_cuda_cublas.cc
index 445279fa01034cc0805c3dbd2e3cb1b269607661..dc109487ded20f91c3081ebde8bb50834c362bcf 100644
--- a/paddle/cuda/src/hl_cuda_cublas.cc
+++ b/paddle/cuda/src/hl_cuda_cublas.cc
@@ -217,7 +217,7 @@ void hl_matrix_mul(real *A_d, hl_trans_op_t transa,
} else {
LOG(FATAL) << "parameter transa error!";
}
- CHECK_EQ(stat, CUBLAS_STATUS_SUCCESS);
+ CHECK_EQ(stat, CUBLAS_STATUS_SUCCESS) << hl_cublas_get_error_string(stat);
CHECK_SYNC("hl_matrix_mul failed");
}
@@ -266,7 +266,7 @@ void hl_matrix_mul_vector(real *A_d, hl_trans_op_t trans,
LOG(FATAL) << "parameter transa error!";
}
- CHECK_EQ(stat, CUBLAS_STATUS_SUCCESS);
+ CHECK_EQ(stat, CUBLAS_STATUS_SUCCESS) << hl_cublas_get_error_string(stat);
CHECK_SYNC("hl_matrix_mul_vector");
}
diff --git a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp
index bee82faa5fca8bb82848b862a239258a8165ce7b..fc38bca3c403b2855ad873e5cc06539d10a941cf 100644
--- a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp
+++ b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp
@@ -497,20 +497,21 @@ void RecurrentGradientMachine::forward(const std::vector& inArgs,
int idSize = 0;
// connect in_links
for (size_t j = 0; j < inFrameLines_.size(); ++j) {
+ Info& info = info_[shareInlinkInfo ? 0 : j];
// idSize denotes the sum number of tokens in each length i
- idSize = info_[j].idIndex[i + 1] - info_[j].idIndex[i];
+ idSize = info.idIndex[i + 1] - info.idIndex[i];
InFrameLine inFrameLine = inFrameLines_[j];
auto scatterAgent =
dynamic_cast(inFrameLine.agents[i].get());
scatterAgent->setRealLayerAndOutput(inFrameLine.inLayer,
- inFrameLine.outArg, info_[j].allIds,
- info_[j].idIndex[i], idSize);
+ inFrameLine.outArg, info.allIds,
+ info.idIndex[i], idSize);
if (hasSubseq) {
// size: the length of subsequence
int size =
- info_[j].seqStartPosIndex[i + 1] - info_[j].seqStartPosIndex[i];
- scatterAgent->setSequenceStartPositions(info_[j].sequenceStartPositions,
- info_[j].seqStartPosIndex[i],
+ info.seqStartPosIndex[i + 1] - info.seqStartPosIndex[i];
+ scatterAgent->setSequenceStartPositions(info.sequenceStartPositions,
+ info.seqStartPosIndex[i],
size);
}
}
@@ -744,16 +745,24 @@ void RecurrentGradientMachine::selectRowsOneTime(LayerPtr layer,
const IVectorPtr& allIds,
Argument* arg,
PassType passType) {
- const MatrixPtr& realV = layer->getOutputValue();
- int height = realV->getHeight();
- int width = realV->getWidth();
- Matrix::resizeOrCreate(arg->value, height, width, /* trans */ false, useGpu_);
- arg->value->zeroMem();
- arg->value->selectRows(*realV, *allIds);
- if (passType != PASS_TEST) {
- Matrix::resizeOrCreate(arg->grad, height, width, /* trans */ false,
- useGpu_);
- arg->grad->zeroMem();
+ Argument& src = layer->getOutput();
+ if (src.value) {
+ const MatrixPtr& realV = src.value;
+ int height = realV->getHeight();
+ int width = realV->getWidth();
+ Matrix::resizeOrCreate(
+ arg->value, height, width, /* trans */ false, useGpu_);
+ arg->value->zeroMem();
+ arg->value->selectRows(*realV, *allIds);
+ if (passType != PASS_TEST) {
+ Matrix::resizeOrCreate(arg->grad, height, width, /* trans */ false,
+ useGpu_);
+ arg->grad->zeroMem();
+ }
+ }
+ if (src.ids) {
+ IVector::resizeOrCreate(arg->ids, src.ids->getSize(), useGpu_);
+ arg->ids->selectFrom(*src.ids, *allIds);
}
}
diff --git a/paddle/gserver/layers/AgentLayer.cpp b/paddle/gserver/layers/AgentLayer.cpp
index c1bef18ed38af8393b044f184364dfbd7e9e6bbb..056e9568852ac93552413334be1960e9c17525d4 100644
--- a/paddle/gserver/layers/AgentLayer.cpp
+++ b/paddle/gserver/layers/AgentLayer.cpp
@@ -139,15 +139,16 @@ void ScatterAgentLayer::forward(PassType passType) {
Layer::forward(passType);
CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId());
- if (realLayer_->getOutput().ids) { // ids scatter
- IVector::resizeOrCreate(output_.ids, ids_->getSize(), useGpu_);
- output_.ids->selectFrom(*realLayer_->getOutput().ids, *ids_);
- } else { // value scatter
- int width = this->getSize();
- if (realOutArg_.value) {
- output_.subArgFrom(realOutArg_, /* offset */ idIndex_ * width, idSize_,
- width, useGpu_);
- } else { // used in generation
+ int width = this->getSize();
+ if (realOutArg_.value || realOutArg_.ids) {
+ output_.subArgFrom(realOutArg_, /* offset */ idIndex_, idSize_,
+ width, useGpu_);
+ } else { // used in generation
+ if (realLayer_->getOutput().ids) {
+ IVector::resizeOrCreate(output_.ids, ids_->getSize(), useGpu_);
+ output_.ids->selectFrom(*realLayer_->getOutput().ids, *ids_);
+ }
+ if (realLayer_->getOutput().value) {
int height = ids_->getSize();
resetOutput(height, width);
@@ -213,18 +214,17 @@ void SequenceGatherAgentLayer::forward(PassType passType) {
void SequenceScatterAgentLayer::forward(PassType passType) {
Layer::forward(passType);
CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId());
- CHECK(!realLayer_->getOutput().ids) << "Not supported";
const Argument& input = realLayer_->getOutput();
- CHECK_EQ(input.value->getWidth(), this->getSize());
+ CHECK_EQ(realLayer_->getSize(), this->getSize());
int width = this->getSize();
AsyncGpuBlock asyncGpuBlock;
REGISTER_TIMER_INFO("SequenceAgentLayerForward", getName().c_str());
- if (realOutArg_.value) {
+ if (realOutArg_.value || realOutArg_.ids) {
CHECK(realOutArg_.sequenceStartPositions);
- output_.subArgFrom(realOutArg_, /* offset */ idIndex_ * width, idSize_,
+ output_.subArgFrom(realOutArg_, /* offset */ idIndex_, idSize_,
width, useGpu_, /* trans */ false, /* seqFlag */ true,
/* seqStart */ seqStartPosIndex_,
/* seqSize */ numSequences_);
diff --git a/paddle/gserver/layers/CRFLayer.h b/paddle/gserver/layers/CRFLayer.h
index c6ba8e7c965a3957b022319bdd7f6c4c012153c7..58902a0d3b7e4cad67dac94be10c35ebbf83b001 100644
--- a/paddle/gserver/layers/CRFLayer.h
+++ b/paddle/gserver/layers/CRFLayer.h
@@ -25,7 +25,7 @@ namespace paddle {
/**
* A layer for calculating the cost of sequential conditional random field
* model.
- * See LinearChainCRF.h for the detail of the CRF formulation.
+ * See class LinearChainCRF for the detail of the CRF formulation.
*/
class CRFLayer : public Layer {
public:
diff --git a/paddle/gserver/layers/LinearChainCRF.h b/paddle/gserver/layers/LinearChainCRF.h
index 3bde1aa415ce9b330d75bcf07cd79b88b527c5ed..c33c83b25987e1b944a84d960cf6539cff1b872f 100644
--- a/paddle/gserver/layers/LinearChainCRF.h
+++ b/paddle/gserver/layers/LinearChainCRF.h
@@ -21,39 +21,39 @@ namespace paddle {
class LinearChainCRF {
public:
- /*
- The size of para and grad must be (numClasses + 2) * numClasses.
- The first numClasses values of para are for starting weights (a).
- The next numClasses values of para are for ending weights (b),
- The remaning values are for transition weights (w).
-
- The probability of a state sequence s of length L is defined as:
- P(s) = (1/Z) exp(a_{s_1} + b_{s_L}
- + \sum_{l=1}^L x_{s_l}
- + \sum_{l=2}^L w_{s_{l-1},s_l})
- where Z is a normalization value so that the sum of P(s) over all possible
- sequences is 1, and x is the input feature to the CRF.
+ /**
+ * The size of para and grad must be \f$(numClasses + 2) * numClasses\f$.
+ * The first numClasses values of para are for starting weights (\f$a\f$).
+ * The next numClasses values of para are for ending weights (\f$b\f$),
+ * The remaning values are for transition weights (\f$w\f$).
+ *
+ * The probability of a state sequence s of length \f$L\f$ is defined as:
+ * \f$P(s) = (1/Z) exp(a_{s_1} + b_{s_L}
+ * + \sum_{l=1}^L x_{s_l}
+ * + \sum_{l=2}^L w_{s_{l-1},s_l})\f$
+ * where \f$Z\f$ is a normalization value so that the sum of \f$P(s)\f$ over all possible
+ * sequences is \f$1\f$, and \f$x\f$ is the input feature to the CRF.
*/
LinearChainCRF(int numClasses, real* para, real* grad);
- /*
- Calculate the negative log likelihood of s given x.
- The size of x must be length * numClasses. Each consecutive numClasses
- values are the features for one time step.
+ /**
+ * Calculate the negative log likelihood of s given x.
+ * The size of x must be length * numClasses. Each consecutive numClasses
+ * values are the features for one time step.
*/
real forward(real* x, int* s, int length);
- /*
- Calculate the gradient with respect to x, a, b, and w.
- The gradient of x will be stored in dx.
- backward() can only be called after a corresponding call to forward() with
- the same x, s and length.
- NOTE: The gradient is added to dx and grad (provided at constructor).
+ /**
+ * Calculate the gradient with respect to x, a, b, and w.
+ * The gradient of x will be stored in dx.
+ * backward() can only be called after a corresponding call to forward() with
+ * the same x, s and length.
+ * @note The gradient is added to dx and grad (provided at constructor).
*/
void backward(real* x, real* dx, int* s, int length);
- /*
- Find the most probable sequence given x. The result will be stored in s.
+ /**
+ * Find the most probable sequence given x. The result will be stored in s.
*/
void decode(real* x, int* s, int length);
diff --git a/paddle/gserver/tests/CMakeLists.txt b/paddle/gserver/tests/CMakeLists.txt
index 129f10fac114d41f7c016e1fc22f311ee78cbfa5..ff2abf76973174ac2a437830b234f4c9937c08ed 100644
--- a/paddle/gserver/tests/CMakeLists.txt
+++ b/paddle/gserver/tests/CMakeLists.txt
@@ -56,7 +56,6 @@ add_test(NAME test_RecurrentGradientMachine
COMMAND .set_python_path.sh -d
${PROJ_ROOT}/python:${PROJ_ROOT}/paddle/gserver/tests
${CMAKE_CURRENT_BINARY_DIR}/test_RecurrentGradientMachine
- --use_gpu=false
WORKING_DIRECTORY ${PROJ_ROOT}/paddle)
add_unittest_without_exec(test_NetworkCompare
diff --git a/paddle/gserver/tests/sequence_nest_rnn_multi_input.conf b/paddle/gserver/tests/sequence_nest_rnn_multi_input.conf
new file mode 100644
index 0000000000000000000000000000000000000000..e01b3f8e7aa5c4c14c64c2843b0f6f82817972a1
--- /dev/null
+++ b/paddle/gserver/tests/sequence_nest_rnn_multi_input.conf
@@ -0,0 +1,77 @@
+#edit-mode: -*- python -*-
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.trainer_config_helpers import *
+
+######################## data source ################################
+define_py_data_sources2(train_list='gserver/tests/Sequence/dummy.list',
+ test_list=None,
+ module='rnn_data_provider',
+ obj='process_subseq')
+
+
+settings(batch_size=2, learning_rate=0.01)
+######################## network configure ################################
+dict_dim = 10
+word_dim = 8
+hidden_dim = 8
+label_dim = 3
+
+data = data_layer(name="word", size=dict_dim)
+
+emb = embedding_layer(input=data, size=word_dim)
+
+# This hierachical RNN is designed to be equivalent to the simple RNN in
+# sequence_rnn.conf
+
+def outer_step(wid, x):
+ outer_mem = memory(name="outer_rnn_state", size=hidden_dim)
+ def inner_step(y, wid):
+ z = embedding_layer(input=wid, size=word_dim)
+ inner_mem = memory(name="inner_rnn_state",
+ size=hidden_dim,
+ boot_layer=outer_mem)
+ out = fc_layer(input=[y, z, inner_mem],
+ size=hidden_dim,
+ act=TanhActivation(),
+ bias_attr=True,
+ name="inner_rnn_state")
+ return out
+
+ inner_rnn_output = recurrent_group(
+ step=inner_step,
+ name="inner",
+ input=[x, wid])
+ last = last_seq(input=inner_rnn_output, name="outer_rnn_state")
+
+ # "return last" should also work. But currently RecurrentGradientMachine
+ # does not handle it correctly. Current implementation requires that
+ # all the out links are from sequences. However, it does not report error
+ # when the out links are not sequences.
+ return inner_rnn_output
+
+out = recurrent_group(
+ name="outer",
+ step=outer_step,
+ input=[SubsequenceInput(data), SubsequenceInput(emb)])
+
+rep = last_seq(input=out)
+prob = fc_layer(size=label_dim,
+ input=rep,
+ act=SoftmaxActivation(),
+ bias_attr=True)
+
+outputs(classification_cost(input=prob,
+ label=data_layer(name="label", size=label_dim)))
diff --git a/paddle/gserver/tests/sequence_rnn_multi_input.conf b/paddle/gserver/tests/sequence_rnn_multi_input.conf
new file mode 100644
index 0000000000000000000000000000000000000000..968621cab59be9296ae5ee962a3a359fff59e022
--- /dev/null
+++ b/paddle/gserver/tests/sequence_rnn_multi_input.conf
@@ -0,0 +1,58 @@
+#edit-mode: -*- python -*-
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.trainer_config_helpers import *
+
+######################## data source ################################
+define_py_data_sources2(train_list='gserver/tests/Sequence/dummy.list',
+ test_list=None,
+ module='rnn_data_provider',
+ obj='process_seq')
+
+
+settings(batch_size=2, learning_rate=0.01)
+######################## network configure ################################
+dict_dim = 10
+word_dim = 8
+hidden_dim = 8
+label_dim = 3
+
+data = data_layer(name="word", size=dict_dim)
+
+emb = embedding_layer(input=data, size=word_dim)
+
+def step(y, wid):
+ z = embedding_layer(input=wid, size=word_dim)
+ mem = memory(name="rnn_state", size=hidden_dim)
+ out = fc_layer(input=[y, z, mem],
+ size=hidden_dim,
+ act=TanhActivation(),
+ bias_attr=True,
+ name="rnn_state")
+ return out
+
+out = recurrent_group(
+ name="rnn",
+ step=step,
+ input=[emb, data])
+
+rep = last_seq(input=out)
+prob = fc_layer(size=label_dim,
+ input=rep,
+ act=SoftmaxActivation(),
+ bias_attr=True)
+
+outputs(classification_cost(input=prob,
+ label=data_layer(name="label", size=label_dim)))
diff --git a/paddle/gserver/tests/test_RecurrentGradientMachine.cpp b/paddle/gserver/tests/test_RecurrentGradientMachine.cpp
index b73fdd18abf35858a366552120e69c8a039a4726..550df0a31844ece80aa3f2d976f46a84cef9b35f 100644
--- a/paddle/gserver/tests/test_RecurrentGradientMachine.cpp
+++ b/paddle/gserver/tests/test_RecurrentGradientMachine.cpp
@@ -92,7 +92,11 @@ void CalCost(const string& conf, const string& dir, real* cost,
rmDir(dir.c_str());
}
-void test(const string& conf1, const string& conf2, double eps) {
+void test(const string& conf1, const string& conf2, double eps, bool useGpu) {
+ if (!paddle::version::isWithGpu() && useGpu) {
+ return;
+ }
+ FLAGS_use_gpu = useGpu;
int num_passes = 5;
real* cost1 = new real[num_passes];
const string dir1 = "gserver/tests/t1";
@@ -113,17 +117,28 @@ void test(const string& conf1, const string& conf2, double eps) {
}
TEST(RecurrentGradientMachine, HasSubSequence) {
- test("gserver/tests/sequence_layer_group.conf",
- "gserver/tests/sequence_nest_layer_group.conf",
- 1e-5);
+ for (bool useGpu : {false, true}) {
+ test("gserver/tests/sequence_layer_group.conf",
+ "gserver/tests/sequence_nest_layer_group.conf",
+ 1e-5, useGpu);
+ }
}
TEST(RecurrentGradientMachine, rnn) {
- test("gserver/tests/sequence_rnn.conf",
- "gserver/tests/sequence_nest_rnn.conf",
- 0);
+ for (bool useGpu : {false, true}) {
+ test("gserver/tests/sequence_rnn.conf",
+ "gserver/tests/sequence_nest_rnn.conf",
+ 1e-6, useGpu);
+ }
}
+TEST(RecurrentGradientMachine, rnn_multi_input) {
+ for (bool useGpu : {false, true}) {
+ test("gserver/tests/sequence_rnn_multi_input.conf",
+ "gserver/tests/sequence_nest_rnn_multi_input.conf",
+ 1e-6, useGpu);
+ }
+}
int main(int argc, char** argv) {
if (paddle::version::isWithPyDataProvider()) {
diff --git a/paddle/parameter/Argument.cpp b/paddle/parameter/Argument.cpp
index 0ca56b29b39b317d01d80631e332ba02356a613d..42c74661d2b2cebe0c2f5f14d0970ab2f1fec866 100644
--- a/paddle/parameter/Argument.cpp
+++ b/paddle/parameter/Argument.cpp
@@ -554,11 +554,16 @@ void Argument::degradeSequence(const Argument& input, bool useGpu) {
void Argument::subArgFrom(const Argument& input, size_t offset, size_t height,
size_t width, bool useGpu, bool trans, bool seqFlag,
size_t seqStart, size_t seqSize) {
- value = Matrix::create(input.value->getData() + offset, height, width, trans,
- useGpu);
+ if (input.value) {
+ value = Matrix::create(input.value->getData() + offset * width,
+ height, width, trans, useGpu);
+ }
+ if (input.ids) {
+ ids = IVector::create(input.ids->getData() + offset, height, useGpu);
+ }
if (input.grad) {
- grad = Matrix::create(input.grad->getData() + offset, height, width, trans,
- useGpu);
+ grad = Matrix::create(input.grad->getData() + offset * width,
+ height, width, trans, useGpu);
}
if (seqFlag) {
sequenceStartPositions = std::make_shared(
diff --git a/paddle/parameter/Argument.h b/paddle/parameter/Argument.h
index 81cd117fc45cfa34da0810b01c5a710d9ce5950b..81ff9029bc4c8fca7adbabd7ae65caf7ac2f3c2a 100644
--- a/paddle/parameter/Argument.h
+++ b/paddle/parameter/Argument.h
@@ -177,11 +177,11 @@ struct Argument {
}
/**
- * @brief (value, grad, sequenceStartPositions) of output are subset of
+ * @brief (value, ids, grad, sequenceStartPositions) of output are subset of
* input. Note that, output share the same memory of input.
*
* @param input[in] input
- * @param offset[in] offset of input.value
+ * @param offset[in] offset in terms of rows
* @param height[in] height of output.value
* @param width[in] width of output.value
* @param useGpu[in]
diff --git a/paddle/trainer/ThreadParameterUpdater.cpp b/paddle/trainer/ThreadParameterUpdater.cpp
index 65d827787ee78fe7a572869d7115c7abe27304a6..91f7f4d29df938e88a0e8c54b7046194c7adfb35 100644
--- a/paddle/trainer/ThreadParameterUpdater.cpp
+++ b/paddle/trainer/ThreadParameterUpdater.cpp
@@ -141,7 +141,7 @@ void SgdThreadUpdater::traverse(GetTraverseCallback getTraverseCallback) {
} else if (hasCpuPara) {
getGlobalSyncThreadPool()->exec(cpuTraverse);
} else if (hasGpuPara) {
- cpuTraverse(0, 0);
+ gpuTraverse(0, 0);
}
}
diff --git a/paddle/trainer/TrainerInternal.cpp b/paddle/trainer/TrainerInternal.cpp
index 76b6b9bc3ee38341a67d7ec111196e28a28a0e9b..6029a4b2c1d0a0c04058bbd979523f26b72b5a5e 100644
--- a/paddle/trainer/TrainerInternal.cpp
+++ b/paddle/trainer/TrainerInternal.cpp
@@ -101,6 +101,7 @@ void TrainerInternal::trainOneBatch(int64_t batchId,
// it
//! to ParameterHook.
auto& grad = para->getBuf(PARAMETER_GRADIENT);
+ SetDevice device(para->getDeviceId());
paraStats[para->getID()].avgAbsGrad = grad->getAbsSum() / para->getSize();
paraStats[para->getID()].maxAbsGrad = grad->getAbsMax();
}
diff --git a/paddle/trainer/tests/sample_trainer_config_parallel.conf b/paddle/trainer/tests/sample_trainer_config_parallel.conf
index 3563fede1c18262369074c38ed8b2dc76fc863f5..e35a1f26dad2f81f70fe31f9f8c921606ea8461b 100644
--- a/paddle/trainer/tests/sample_trainer_config_parallel.conf
+++ b/paddle/trainer/tests/sample_trainer_config_parallel.conf
@@ -13,137 +13,74 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-#Todo(luotao02) This config is only used for unitest. It is out of date now, and will be updated later.
+from paddle.trainer_config_helpers import *
-TrainData(
- SimpleData(
- files = "trainer/tests/sample_filelist.txt",
- feat_dim = 3,
- context_len = 0,
- buffer_capacity = 1000000,
- )
-)
+TrainData(SimpleData(
+ files = "trainer/tests/sample_filelist.txt",
+ feat_dim = 3,
+ context_len = 0,
+ buffer_capacity = 1000000))
-TestData(
- SimpleData(
- files = "trainer/tests/sample_filelist.txt",
- feat_dim = 3,
- context_len = 0,
- buffer_capacity = 1000000,
- )
-)
+TestData(SimpleData(
+ files = "trainer/tests/sample_filelist.txt",
+ feat_dim = 3,
+ context_len = 0,
+ buffer_capacity = 1000000))
-Settings(
- algorithm = "sgd",
- num_batches_per_send_parameter = 1,
- num_batches_per_get_parameter = 1,
- batch_size = 100,
- learning_rate = 0.001,
- learning_rate_decay_a = 1e-5,
- learning_rate_decay_b = 0.5,
-)
+settings(batch_size = 100)
-default_initial_std(0.2)
# Output layer, label layer, cost layer, preferably set to the same environment.
output_device = 0
-model_type("nn")
-
# Input Layer does not need to specify the device number.
-Layer(
- name = "input",
- type = "data",
- size = 3,
-)
+data = data_layer(name='input', size=3)
# Calculate in the CPU.
-Layer(
- name = "layer1_1",
- type = "fc",
- size = 5,
- active_type = "sigmoid",
- device = -1,
- inputs = "input",
-)
+fc1 = fc_layer(input=data, size=5,
+ bias_attr=True,
+ layer_attr=ExtraAttr(device=-1),
+ act=SigmoidActivation())
# Calculate in the GPU 0.
-Layer(
- name = "layer2_1",
- type = "fc",
- size = 10,
- active_type = "sigmoid",
- device = 0,
- inputs = "layer1_1",
-)
+fc2 = fc_layer(input=fc1, size=10,
+ bias_attr=True,
+ layer_attr=ExtraAttr(device=0),
+ act=SigmoidActivation())
# Calculate in the GPU 1.
-Layer(
- name = "layer2_2",
- type = "fc",
- size = 10,
- active_type = "sigmoid",
- device = 1,
- inputs = "layer1_1",
-)
+fc3 = fc_layer(input=fc1, size=10,
+ bias_attr=True,
+ layer_attr=ExtraAttr(device=1),
+ act=SigmoidActivation())
# Calculate in the GPU 0.
-Layer(
- name = "layer3_1",
- type = "fc",
- size = 10,
- device = 0,
- active_type = "sigmoid",
- inputs = ["layer2_1", "layer2_2"],
-)
+fc4 = fc_layer(input=[fc2,fc3], size=10,
+ bias_attr=True,
+ layer_attr=ExtraAttr(device=0),
+ act=SigmoidActivation())
# Calculate in the GPU 1.
-Layer(
- name = "layer3_2",
- type = "fc",
- size = 10,
- device = 1,
- active_type = "sigmoid",
- inputs = ["layer2_1", "layer2_2"],
-)
-
+fc5 = fc_layer(input=[fc2,fc3], size=10,
+ bias_attr=True,
+ layer_attr=ExtraAttr(device=1),
+ act=SigmoidActivation())
-Layer(
- name = "output",
- type = "fc",
- size = 10,
- device = output_device,
- active_type = "sigmoid",
- inputs = ["layer3_1", "layer3_2"],
-)
+output = fc_layer(input=[fc4,fc5], size=10,
+ bias_attr=True,
+ layer_attr=ExtraAttr(device=output_device),
+ act=SoftmaxActivation())
if get_config_arg('with_cost', bool, True):
# This is for training the neural network.
# We need to have another data layer for label
# and a layer for calculating cost
- Layer(
- name = "label",
- type = "data",
- device = output_device,
- size = 1,
- )
-
- Layer(
- name = "cost",
- type = "multi-class-cross-entropy",
- device = output_device,
- inputs = ["output", "label"],
- )
-
- Evaluator(
- name = "error",
- type = "classification_error",
- inputs = ["output", "label"])
-
- Inputs("input", "label")
- Outputs("cost")
-
+ lbl = data_layer(name='label', size=1,
+ layer_attr=ExtraAttr(device=output_device))
+
+ outputs(classification_cost(input=output,
+ label=lbl,
+ layer_attr=ExtraAttr(device=output_device)))
else:
# This is for prediction where we don't have label
# and don't need to calculate cost
- Inputs("input")
- Outputs("output")
+ outputs(output)
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index a57e9065c6f980b0338bd4ed0a91160fa7bed94f..1f55298f24f0742203adf6b332f86193d3ffc732 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -1279,7 +1279,7 @@ class LayerBase(object):
size,
dims=None,
sparse = None,
- format = "csr"):
+ format = None):
if dims is None:
# TODO(yuyang18): print warning and callstack here!
dims = list()
@@ -2074,7 +2074,7 @@ class MaxLayer(LayerBase):
active_type='linear',
device=None,
bias=False,
- output_max_index=False):
+ output_max_index=None):
super(MaxLayer, self).__init__(name, 'max', 0, inputs=inputs, device=device)
config_assert(len(self.inputs) == 1, 'MaxLayer must have 1 input')
self.config.trans_type = trans_type
@@ -2083,7 +2083,8 @@ class MaxLayer(LayerBase):
input_layer = self.get_input_layer(input_index)
self.set_layer_size(input_layer.size)
self.create_bias_parameter(bias, self.config.size)
- self.config.output_max_index=output_max_index
+ if output_max_index is not None:
+ self.config.output_max_index = output_max_index
@config_layer('maxid')
@@ -2440,7 +2441,7 @@ class MixedLayer(LayerBase):
inputs,
size=0,
bias=True,
- error_clipping_threshold=0.0,
+ error_clipping_threshold=None,
**xargs):
config_assert(inputs, 'inputs cannot be empty')
super(MixedLayer, self).__init__(
@@ -2510,7 +2511,8 @@ class MixedLayer(LayerBase):
self.create_bias_parameter(bias, self.config.size)
- self.config.error_clipping_threshold = error_clipping_threshold
+ if error_clipping_threshold is not None:
+ self.config.error_clipping_threshold = error_clipping_threshold
# like MixedLayer, but no bias parameter
@config_func
diff --git a/python/paddle/trainer_config_helpers/activations.py b/python/paddle/trainer_config_helpers/activations.py
index 85534675199e7627f9753e5d233f5208b14decfd..292014519374eabbe55c61daa73692814a52aac2 100644
--- a/python/paddle/trainer_config_helpers/activations.py
+++ b/python/paddle/trainer_config_helpers/activations.py
@@ -15,8 +15,10 @@
__all__ = ["TanhActivation", "SigmoidActivation",
"SoftmaxActivation", "IdentityActivation", "LinearActivation",
'SequenceSoftmaxActivation', 'ExpActivation',
- "ReluActivation", "BReluActivation", "SoftReluActivation", "STanhActivation",
- "AbsActivation", "SquareActivation", "BaseActivation"]
+ "ReluActivation", "BReluActivation", "SoftReluActivation",
+ "STanhActivation",
+ "AbsActivation", "SquareActivation",
+ "BaseActivation"]
class BaseActivation(object):
@@ -36,6 +38,9 @@ class BaseActivation(object):
self.name = name
self.support_hppl = support_hppl
+ def __repr__(self):
+ return self.name
+
class TanhActivation(BaseActivation):
"""
diff --git a/python/paddle/trainer_config_helpers/attrs.py b/python/paddle/trainer_config_helpers/attrs.py
index 7b0a398d191722421c3604961e80057bccbf0eb1..d26344124733246c67790025fb186c6b350c3947 100644
--- a/python/paddle/trainer_config_helpers/attrs.py
+++ b/python/paddle/trainer_config_helpers/attrs.py
@@ -17,6 +17,42 @@ __all__ = ['ParamAttr', 'ExtraAttr', 'ParameterAttribute',
'ExtraLayerAttribute']
+def convert_and_compare(x, Type):
+ """
+ Convert x to be the same type as Type and then convert back to
+ check whether there is a loss of information
+ :param x: object to be checked
+ :param Type: target type to check x over
+
+ """
+ return type(x)(Type(x))==x
+
+def is_compatible_with(x, Type):
+ """
+ Check if x has a type compatible with Type
+ :param x: object to be checked
+ :param Type: target type to check x over
+
+ """
+ if type(x) == Type:
+ return True
+ try:
+ if float == Type or int == Type:
+ # avoid those types that can be converted to float/int but not very
+ # meaningful and could potentially lead to error
+ # i.e., str and bool typed value should not be used for initializing float/int variable
+ if not isinstance(x, str) and not isinstance(x, bool):
+ return convert_and_compare(x, Type)
+ elif bool == Type:
+ # should not use string type to initialize bool variable
+ if not isinstance(x, str):
+ return convert_and_compare(x, Type)
+ else:
+ return False
+ except:
+ return False
+
+
class ParameterAttribute(object):
"""
Parameter Attributes object. To fine-tuning network training process, user
@@ -65,14 +101,18 @@ class ParameterAttribute(object):
elif initial_std is None and initial_mean is None and initial_max \
is None and initial_min is None:
self.attr = {'initial_smart': True}
- elif isinstance(initial_std, float) or isinstance(initial_mean, float):
+ elif is_compatible_with(initial_std, float) or \
+ is_compatible_with(initial_mean, float):
self.attr = dict()
if initial_std is not None:
self.attr['initial_std'] = initial_std
if initial_mean is not None:
self.attr['initial_mean'] = initial_mean
self.attr['initial_strategy'] = 0 # Gauss Random
- elif isinstance(initial_max, float) and isinstance(initial_min, float):
+ elif is_compatible_with(initial_max, float) and \
+ is_compatible_with(initial_min, float):
+ initial_max = initial_max
+ initial_min = initial_min
assert initial_min < initial_max
initial_mean = (initial_max + initial_min) / 2
initial_std = initial_mean - initial_min
@@ -83,16 +123,16 @@ class ParameterAttribute(object):
else:
raise RuntimeError("Unexpected branch.")
- if not is_static and isinstance(l1_rate, float):
+ if not is_static and is_compatible_with(l1_rate, float):
self.attr['decay_rate_l1'] = l1_rate
- if not is_static and isinstance(l2_rate, float):
+ if not is_static and is_compatible_with(l2_rate, float):
self.attr['decay_rate'] = l2_rate
- if not is_static and isinstance(learning_rate, float):
+ if not is_static and is_compatible_with(learning_rate, float):
self.attr['learning_rate'] = learning_rate
- if not is_static and isinstance(momentum, float):
+ if not is_static and is_compatible_with(momentum, float):
self.attr['momentum'] = momentum
if name is not None:
@@ -134,12 +174,16 @@ class ExtraLayerAttribute(object):
The dropout rate is the zero rate of this mask. The
details of what dropout is please refer to `here
`_
+ JMLRdropout.pdf>`_.
:type drop_rate: float
-
+ :param device: device ID of layer. device=-1, use CPU. device>0, use GPU.
+ The details allocation in parallel_nn please refer to `here
+ `_.
+ :type device: int
"""
- def __init__(self, error_clipping_threshold=None, drop_rate=None):
+ def __init__(self, error_clipping_threshold=None, drop_rate=None, device=None):
self.attr = dict()
if isinstance(error_clipping_threshold, float):
assert error_clipping_threshold > 0
@@ -149,6 +193,9 @@ class ExtraLayerAttribute(object):
assert drop_rate > 0
self.attr["drop_rate"] = drop_rate
+ if isinstance(device, int):
+ self.attr["device"] = device
+
def check(self, layer_name):
for key in self.attr:
if not hasattr(self, 'can_%s' % key) or \
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 8b7cabf2fad507b15c820ffa44f29f44e44f407e..b28dd02b70946f6e4c2aaf90cb2a7058da0f79dc 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -13,6 +13,7 @@
# limitations under the License.
import functools
+import collections
from paddle.trainer.config_parser import *
from .activations import LinearActivation, SigmoidActivation, TanhActivation, \
@@ -21,6 +22,7 @@ from .evaluators import *
from .poolings import MaxPooling, AvgPooling, BasePoolingType
from .attrs import *
from .default_decorators import *
+
try:
import cPickle as pickle
except ImportError:
@@ -51,7 +53,8 @@ __all__ = ["full_matrix_projection", "AggregateLevel", "ExpandLevel",
'cross_entropy_with_selfnorm', 'cross_entropy',
'multi_binary_label_cross_entropy',
'rank_cost', 'lambda_cost', 'huber_cost',
- 'block_expand_layer', 'out_prod_layer', 'print_layer'
+ # 'block_expand_layer', # TODO(yuyang18): this layer is not correct
+ 'out_prod_layer', 'print_layer'
]
@@ -165,11 +168,12 @@ class LayerOutput(object):
:param activation: Layer Activation.
:type activation: BaseActivation.
:param parents: Layer's parents.
- :type parents: list|tuple
+ :type parents: list|tuple|collection.Sequence
"""
def __init__(self, name, layer_type, parents=None, activation=None,
- num_filters=None, img_norm_type=None, size=None, outputs=None):
+ num_filters=None, img_norm_type=None, size=None, outputs=None,
+ reverse=None):
assert isinstance(name, basestring)
assert isinstance(layer_type, basestring)
assert LayerType.is_layer_type(layer_type)
@@ -185,6 +189,7 @@ class LayerOutput(object):
if outputs is None:
outputs = ['default']
self.outputs = outputs
+ self.reverse = reverse
def __repr__(self):
"""
@@ -201,32 +206,16 @@ class LayerOutput(object):
ERROR_CLIPPING = 'error_clipping_threshold'
DROPOUT = 'drop_rate'
-
-
-def check_input(input):
- """
- Check input is a LayerOutput or list of LayerOutput or tuple of LayerOutput
- if is a LayerOutput,
-
- :param input: The input layer. Could be a list/tuple of input layer.
- :type input: LayerOutput|list|tuple
- :return: list of LayerOutput
- :rtype: list of LayerOutput
- """
-
- if isinstance(input, LayerOutput):
- return [LayerOutput]
- assert isinstance(input, list)
- for inp in input:
- assert isinstance(inp, LayerOutput)
- return list(input)
+DEVICE = 'device'
def layer_support(*attrs):
+ attrs_list = list(attrs)
+ attrs_list.append(DEVICE)
def decorator(method):
@functools.wraps(method)
def wrapper(*args, **kwargs):
- for attr in attrs:
+ for attr in attrs_list:
for each in args:
if isinstance(each, ExtraLayerAttribute):
setattr(each, '_'.join(['can', attr]), True)
@@ -289,6 +278,43 @@ def full_matrix_projection(input, size=0, param_attr=None):
return proj
+@wrap_param_attr_default()
+def trans_full_matrix_projection(input, size=0, param_attr=None):
+ """
+ Different from full_matrix_projection, this projection performs matrix
+ multiplication, using transpose of weight.
+
+ .. math::
+ out.row[i] += in.row[i] * w^\mathrm{T}
+
+ :math:`w^\mathrm{T}` means transpose of weight.
+ The simply usage is:
+
+ .. code-block:: python
+
+ proj = trans_full_matrix_projection(input=layer,
+ size=100,
+ param_attr=ParamAttr(
+ name='_proj',
+ initial_mean=0.0,
+ initial_std=0.01))
+
+ :param input: input layer
+ :type input: LayerOutput
+ :param size: The parameter size. Means the width of parameter.
+ :type size: int
+ :param param_attr: Parameter config, None if use default.
+ :type param_attr: ParameterAttribute
+ :return: A TransposedFullMatrixProjection Object.
+ :rtype: TransposedFullMatrixProjection
+ """
+ proj = TransposedFullMatrixProjection(input_layer_name=input.name,
+ size=size,
+ **param_attr.attr)
+ proj.origin = input
+ return proj
+
+
@wrap_param_attr_default()
def table_projection(input, size=0, param_attr=None):
"""
@@ -366,7 +392,7 @@ def identity_projection(input, offset=None):
Note that both of two projections should not have any parameter.
:param input: Input Layer.
- :type input: LayerOutput.
+ :type input: LayerOutput
:param offset: Offset, None if use default.
:type offset: int
:return: A IdentityProjection or IdentityOffsetProjection Object
@@ -409,10 +435,11 @@ def dotmul_projection(input, param_attr=None):
proj = DotMulProjection(input_layer_name=input.name,
size=input.size,
**param_attr.attr)
- proj.origin = input
+ proj.origin = input
return proj
-def dotmul_operator(x, y, scale=1):
+
+def dotmul_operator(a=None, b=None, scale=1, **kwargs):
"""
DotMulOperator takes two inputs and performs element-wise multiplication:
@@ -428,22 +455,31 @@ def dotmul_operator(x, y, scale=1):
op = dotmul_operator(x=layer1, y=layer2, scale=0.5)
- :param x: Input layer1
- :type x: LayerOutput
- :param y: Input layer2
- :type y: LayerOutput
+ :param a: Input layer1
+ :type a: LayerOutput
+ :param b: Input layer2
+ :type b: LayerOutput
:param scale: config scalar, default value is one.
:type scale: float
:return: A DotMulOperator Object.
:rtype: DotMulOperator
"""
- assert isinstance(x, LayerOutput)
- assert isinstance(y, LayerOutput)
- op = DotMulOperator(input_layer_names=[x.name, y.name],
+ if 'x' in kwargs or 'y' in kwargs:
+ logger.warning('x and y arguments for dotmul_operator is deprecated. '
+ 'Please use a and b as parameter.')
+ a = kwargs.get('x', a) # For Backward capacity.
+ b = kwargs.get('y', b)
+ assert isinstance(a, LayerOutput)
+ assert isinstance(b, LayerOutput)
+ if a.size is not None and b.size is not None:
+ assert a.size == b.size
+
+ op = DotMulOperator(input_layer_names=[a.name, b.name],
scale=scale)
- op.origin = [x, y]
+ op.origin = [a, b]
return op
+
@wrap_bias_attr_default(['padding_attr'])
def context_projection(input, context_len, context_start=None,
padding_attr=False):
@@ -612,7 +648,7 @@ def mixed_layer(size=0, input=None, name=None, act=None, bias_attr=False,
else:
with mixed_layer(name=name, size=size, act=act, bias_attr=bias_attr,
layer_attr=layer_attr) as m:
- if isinstance(input, list) or isinstance(input, tuple):
+ if isinstance(input, collections.Sequence):
for each in input:
m += each
else:
@@ -722,23 +758,19 @@ def fc_layer(input, size, act=None, name=None,
"""
if isinstance(input, LayerOutput):
input = [input]
- assert not isinstance(param_attr, list)
+ assert not isinstance(param_attr, collections.Sequence)
param_attr = [param_attr]
else:
- if isinstance(param_attr, list) or isinstance(param_attr, tuple):
+ if isinstance(param_attr, collections.Sequence):
assert len(input) == len(param_attr)
else:
param_attr = [copy.deepcopy(param_attr) for _ in range(len(input))]
- assert isinstance(input, list)
-
- def __idx_to_input__(i):
- attr = param_attr[i]
- assert isinstance(attr, ParameterAttribute)
- return Input(input[i].name, **attr.attr)
+ assert isinstance(input, collections.Sequence)
Layer(
- inputs=map(__idx_to_input__, range(len(input))),
+ inputs=[Input(ipt.name, **attr.attr) for ipt, attr in zip(
+ input, param_attr)],
name=name,
type=LayerType.FC_LAYER,
size=size,
@@ -759,16 +791,20 @@ def print_layer(input, name=None):
:type name: basestring
:param input: The input layer. Could be a list/tuple of input layer.
:type input: LayerOutput|list|tuple
- :return: No return
+ :return: LayerOutput
"""
- check_input(input)
+ if isinstance(input, LayerOutput):
+ input = [input]
+ assert isinstance(input, collections.Sequence) # list or tuple
+ for each in input:
+ assert isinstance(each, LayerOutput)
Layer(
name=name,
type=LayerType.PRINT_LAYER,
inputs=[l.name for l in input],
)
- LayerOutput(name, LayerType.PRINT_LAYER, input)
+ # this layer don't return anything, can not be input of other layer.
@wrap_name_default("seq_pooling")
@@ -807,8 +843,13 @@ def pooling_layer(input, pooling_type=None, name=None, bias_attr=None,
:rtype: LayerType
"""
extra_dict = dict()
+ # noinspection PyUnresolvedReferences
if isinstance(pooling_type, AvgPooling):
extra_dict['average_strategy'] = pooling_type.strategy
+ elif isinstance(pooling_type, MaxPooling) and \
+ pooling_type.output_max_index is not None:
+ assert isinstance(pooling_type.output_max_index, bool)
+ extra_dict['output_max_index'] = pooling_type.output_max_index
extra_dict.update(ExtraLayerAttribute.to_kwargs(layer_attr))
Layer(
@@ -832,7 +873,7 @@ def pooling_layer(input, pooling_type=None, name=None, bias_attr=None,
@wrap_name_default("lstmemory")
@layer_support(DROPOUT)
def lstmemory(input, name=None, reverse=False, act=None,
- gate_act=None,
+ gate_act=None, size=None,
state_act=None, bias_attr=None, param_attr=None,
layer_attr=None):
"""
@@ -897,6 +938,16 @@ def lstmemory(input, name=None, reverse=False, act=None,
assert gate_act.support_hppl
assert state_act.support_hppl
assert act.support_hppl
+ assert input.size is not None and input.size % 4 == 0
+ if size is not None:
+ if input.size / 4 == size:
+ plog = logger.warning
+ else:
+ plog = logger.fatal
+
+ plog("NOTE: The lstmemory layer[%s]'s size is set by previous input "
+ "layer. The lstm size should be equal with input layer size/4. The"
+ " size which is set explicitly will be ignored." % name)
Layer(name=name,
type=LayerType.LSTMEMORY,
@@ -908,8 +959,9 @@ def lstmemory(input, name=None, reverse=False, act=None,
inputs=[Input(input.name, **param_attr.attr)],
**ExtraLayerAttribute.to_kwargs(layer_attr))
- return LayerOutput(name, LayerType.LSTMEMORY, [input],
- size=input.size / 4 if input.size is not None else None)
+ return LayerOutput(name, LayerType.LSTMEMORY, [input], size=input.size / 4,
+ reverse=reverse)
+
@wrap_bias_attr_default()
@wrap_param_attr_default()
@@ -919,7 +971,7 @@ def lstmemory(input, name=None, reverse=False, act=None,
@wrap_name_default("gru")
@layer_support(DROPOUT)
def grumemory(input, name=None, reverse=False, act=None,
- gate_act=None,
+ gate_act=None, size=None,
bias_attr=None, param_attr=None,
layer_attr=None):
"""
@@ -977,7 +1029,7 @@ def grumemory(input, name=None, reverse=False, act=None,
:type name: None|basestring
:param input: input layer.
:type input: LayerOutput.
- :param reverse: Wether sequence process is reversed or not.
+ :param reverse: Whether sequence process is reversed or not.
:type reverse: bool
:param act: activation type, TanhActivation by default. This activation
affects the :math:`{\\tilde{h_t}}`.
@@ -993,12 +1045,23 @@ def grumemory(input, name=None, reverse=False, act=None,
:type param_attr: ParameterAttribute|None|False
:param layer_attr: Extra Layer attribute
:type layer_attr: ExtraLayerAttribute|None
+ :param size: Stub parameter of size, but actually not used. If set this size
+ will get a warning.
+ :type size: None
:return: LayerOutput object.
:rtype: LayerOutput
"""
-
assert act.support_hppl
assert gate_act.support_hppl
+ assert input.size is not None and input.size % 3 == 0
+ if size is not None:
+ if input.size / 3 == size:
+ plog = logger.warning
+ else:
+ plog = logger.fatal
+ plog("NOTE: the gru memory layer's size is set by previous input layer,"
+ " and should be input size / 3. Set size explicitly will be "
+ "ignored.")
Layer(name=name,
type=LayerType.GRUMEMORY,
@@ -1010,8 +1073,9 @@ def grumemory(input, name=None, reverse=False, act=None,
**ExtraLayerAttribute.to_kwargs(layer_attr)
)
- return LayerOutput(name, LayerType.GRUMEMORY, [input],
- size=input.size / 3 if input.size is not None else None)
+ return LayerOutput(name, LayerType.GRUMEMORY, [input], size=input.size / 3,
+ reverse=reverse)
+
@wrap_name_default()
@layer_support()
@@ -1030,6 +1094,12 @@ def last_seq(input, name=None, agg_level=AggregateLevel.EACH_TIMESTEP,
:return: LayerOutput object.
:rtype: LayerOutput
"""
+ if input.reverse is not None and input.reverse:
+ logger.warning("You are getting the last instance of a sequence that"
+ " is a output of a REVERSED layer. There is no time"
+ " series information at all. Maybe you want to use"
+ " first_seq instead.")
+
Layer(
name=name,
type=LayerType.SEQUENCE_LAST_INSTANCE,
@@ -1058,6 +1128,13 @@ def first_seq(input, name=None, agg_level=AggregateLevel.EACH_TIMESTEP,
:return: LayerOutput object.
:rtype: LayerOutput
"""
+
+ if input.reverse is not None and not input.reverse:
+ logger.warning('You are getting the first instance for a time series,'
+ ' and it is a normal recurrent layer output. There is no'
+ ' time series information at all. Maybe you want to use'
+ ' last_seq instead.')
+
Layer(
name=name,
type=LayerType.SEQUENCE_FIRST_INSTANCE,
@@ -1073,6 +1150,7 @@ class ExpandLevel(object):
FROM_TIMESTEP = AggregateLevel.EACH_TIMESTEP
FROM_SEQUENCE = AggregateLevel.EACH_SEQUENCE
+
@wrap_name_default()
@layer_support()
def expand_layer(input, expand_as,
@@ -1123,7 +1201,6 @@ def expand_layer(input, expand_as,
parents=[input, expand_as])
-
@wrap_name_default()
@layer_support()
def interpolation_layer(input, weight, name=None, layer_attr=None):
@@ -1155,10 +1232,15 @@ def interpolation_layer(input, weight, name=None, layer_attr=None):
:return: LayerOutput object.
:rtype: LayerOutput
"""
- assert isinstance(input, list) or isinstance(input, tuple)
+ assert isinstance(input, collections.Sequence)
assert len(input) == 2
- assert input[0].size == input[1].size
- assert weight.size == 1
+ assert isinstance(input[0], LayerOutput) and isinstance(input[1],
+ LayerOutput)
+ if input[0].size is not None and input[1].size is not None:
+ assert input[0].size == input[1].size
+ assert isinstance(weight, LayerOutput)
+ if weight.size is not None:
+ assert weight.size == 1
Layer(
name=name,
type=LayerType.INTERPOLATION_LAYER,
@@ -1200,11 +1282,13 @@ def power_layer(input, weight, name=None, layer_attr=None):
:return: LayerOutput object.
:rtype: LayerOutput
"""
- assert weight.size == 1
+ assert isinstance(input, LayerOutput) and isinstance(weight, LayerOutput)
+ if weight.size is not None:
+ assert weight.size == 1
Layer(
name=name,
type=LayerType.POWER_LAYER,
- inputs=[input.name, weight.name],
+ inputs=[weight.name, input.name],
**ExtraAttr.to_kwargs(layer_attr)
)
return LayerOutput(name, LayerType.POWER_LAYER,
@@ -1243,7 +1327,9 @@ def scaling_layer(input, weight, name=None, layer_attr=None):
:return: LayerOutput object.
:rtype: LayerOutput
"""
- assert weight.size == 1
+ assert isinstance(weight, LayerOutput) and isinstance(input, LayerOutput)
+ if weight.size is not None:
+ assert weight.size == 1
Layer(
name=name,
type=LayerType.SCALING_LAYER,
@@ -1322,6 +1408,7 @@ def cos_sim(a, b, scale=5, size=1, name=None, layer_attr=None):
:return: LayerOutput object.
:rtype: LayerOutput
"""
+ assert isinstance(a, LayerOutput) and isinstance(b, LayerOutput)
if size == 1:
Layer(
name=name,
@@ -1331,6 +1418,8 @@ def cos_sim(a, b, scale=5, size=1, name=None, layer_attr=None):
**ExtraLayerAttribute.to_kwargs(layer_attr)
)
else:
+ if a.size is not None and b.size is not None:
+ assert size == b.size / a.size
Layer(
name=name,
type=LayerType.COSINE_SIM_VEC,
@@ -1341,11 +1430,13 @@ def cos_sim(a, b, scale=5, size=1, name=None, layer_attr=None):
)
return LayerOutput(name, LayerType.COSINE_SIM, parents=[a, b])
+
@wrap_name_default()
@wrap_bias_attr_default(has_bias=True)
+@wrap_param_attr_default()
@layer_support()
def hsigmoid(input, label, num_classes, name=None, bias_attr=None,
- layer_attr=None):
+ param_attr=None, layer_attr=None):
"""
Organize the classes into a binary tree. At each node, a sigmoid function
is used to calculate the probability of belonging to the right branch.
@@ -1379,15 +1470,23 @@ def hsigmoid(input, label, num_classes, name=None, bias_attr=None,
"""
if isinstance(input, LayerOutput):
input = [input]
- assert isinstance(input, list) or isinstance(input, tuple)
+ if not isinstance(param_attr, collections.Sequence):
+ param_attr = [param_attr]
+ else:
+ if not isinstance(param_attr, collections.Sequence):
+ param_attr = [param_attr] * len(input)
+ else:
+ assert len(param_attr) == len(input)
+
+ assert isinstance(input, collections.Sequence)
assert isinstance(label, LayerOutput)
assert label.layer_type == LayerType.DATA
ipts_for_layer = []
parents = []
- for each_input in input:
+ for each_input, each_param_attr in zip(input, param_attr):
assert isinstance(each_input, LayerOutput)
- ipts_for_layer.append(each_input.name)
+ ipts_for_layer.append(Input(each_input.name, **each_param_attr.attr))
parents.append(each_input)
ipts_for_layer.append(label.name)
parents.append(label)
@@ -1402,6 +1501,7 @@ def hsigmoid(input, label, num_classes, name=None, bias_attr=None,
)
return LayerOutput(name, LayerType.HSIGMOID, parents=parents)
+
@wrap_name_default("conv")
@wrap_param_attr_default()
@wrap_bias_attr_default()
@@ -1435,23 +1535,26 @@ def img_conv_layer(input, filter_size, num_filters,
:type name: basestring
:param input: Layer Input.
:type input: LayerOutput
- :param filter_size: The x dimension of a filter kernel.
- :type filter_size: int
+ :param filter_size: The x dimension of a filter kernel. Or input a tuple for
+ two image dimension.
+ :type filter_size: int|tuple|list
:param filter_size_y: The y dimension of a filter kernel. Since PaddlePaddle
currently supports rectangular filters, the filter's
shape will be (filter_size, filter_size_y).
- :type filter_size_y: int
+ :type filter_size_y: int|None
:param num_filters: Each filter group's number of filter
:param act: Activation type. Default is tanh
:type act: BaseActivation
:param groups: Group size of filters.
:type groups: int
- :param stride: The x dimension of the stride.
- :type stride: int
+ :param stride: The x dimension of the stride. Or input a tuple for two image
+ dimension.
+ :type stride: int|tuple|list
:param stride_y: The y dimension of the stride.
:type stride_y: int
- :param padding: The x dimension of the padding.
- :type padding: int
+ :param padding: The x dimension of the padding. Or input a tuple for two
+ image dimension
+ :type padding: int|tuple|list
:param padding_y: The y dimension of the padding.
:type padding_y: int
:param bias_attr: Convolution bias attribute. None means default bias.
@@ -1472,13 +1575,30 @@ def img_conv_layer(input, filter_size, num_filters,
if num_channels is None:
assert input.num_filters is not None
num_channels = input.num_filters
+
if filter_size_y is None:
- filter_size_y = filter_size
+ if isinstance(filter_size, collections.Sequence):
+ assert len(filter_size) == 2
+ filter_size, filter_size_y = filter_size
+ else:
+ filter_size_y = filter_size
+
if stride_y is None:
- stride_y = stride
+ if isinstance(stride, collections.Sequence):
+ assert len(stride) == 2
+ stride, stride_y = stride
+ else:
+ stride_y = stride
+
if padding_y is None:
- padding_y = padding
- if param_attr.attr.get('initial_smart') == True: # special initial for conv layers.
+ if isinstance(padding, collections.Sequence):
+ assert len(padding) == 2
+ padding, padding_y = padding
+ else:
+ padding_y = padding
+
+ if param_attr.attr.get('initial_smart'):
+ # special initial for conv layers.
init_w = (2.0 / (filter_size ** 2 * num_channels)) ** 0.5
param_attr.attr["initial_mean"] = 0.0
param_attr.attr["initial_std"] = init_w
@@ -1489,8 +1609,9 @@ def img_conv_layer(input, filter_size, num_filters,
inputs=Input(input.name, conv=Conv(
filter_size=filter_size, padding=padding, stride=stride,
channels=num_channels, groups=groups,
- filter_size_y=filter_size_y, padding_y=padding_y, stride_y=stride_y),
- **param_attr.attr),
+ filter_size_y=filter_size_y, padding_y=padding_y,
+ stride_y=stride_y),
+ **param_attr.attr),
active_type=act.name,
num_filters=num_filters,
bias=ParamAttr.to_bias(bias_attr),
@@ -1550,7 +1671,7 @@ def img_pool_layer(input, pool_size, name=None,
type=LayerType.POOL_LAYER,
inputs=[Input(input.name,
pool=Pool(
- pool_type=pool_type.name + '-projection',
+ pool_type=''.join([pool_type.name, '-projection']),
channels=num_channels,
size_x=pool_size,
start=start,
@@ -1604,7 +1725,6 @@ def img_cmrnorm_layer(input, size, scale=0.0128, power=0.75,
:type power: float
:param num_channels: input layer's filers number or channels. If
num_channels is None, it will be set automatically.
- :param blocked: namely normalize in number of blocked feature maps.
:param layer_attr: Extra Layer Attribute.
:type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
@@ -1657,7 +1777,7 @@ def batch_norm_layer(input, act=None, name=None, num_channels=None,
batch_norm for CPU. Otherwise, select batch norm
type based on the specified type. If you use cudnn_batch_norm,
we suggested you use latest version, such as v5.1.
- :type type: None|string, None or "batch_norm" or "cudnn_batch_norm"
+ :type batch_norm_type: None|string, None or "batch_norm" or "cudnn_batch_norm"
:param act: Activation Type. Better be relu. Because batch
normalization will normalize input near zero.
:type act: BaseActivation
@@ -1818,7 +1938,7 @@ def addto_layer(input, act=None, name=None, bias_attr=None,
if isinstance(input, LayerOutput):
input = [input]
- assert isinstance(input, list) or isinstance(input, tuple)
+ assert isinstance(input, collections.Sequence)
ipts_for_layer = []
for each_input in input:
assert isinstance(each_input, LayerOutput)
@@ -1832,7 +1952,7 @@ def addto_layer(input, act=None, name=None, bias_attr=None,
active_type=act.name,
**ExtraLayerAttribute.to_kwargs(layer_attr)
)
- assert isinstance(input, list) or isinstance(input, tuple)
+
return LayerOutput(name, LayerType.ADDTO_LAYER, parents=input,
activation=act, num_filters=num_filters)
@@ -1848,7 +1968,7 @@ def concat_layer(input, act=None, name=None, layer_attr=None):
:param name: Layer name.
:type name: basestring
:param input: input layers or projections
- :type input: list|tuple
+ :type input: list|tuple|collection.Sequence
:param act: Activation type.
:type act: BaseActivation
:param layer_attr: Extra Layer Attribute.
@@ -1862,10 +1982,10 @@ def concat_layer(input, act=None, name=None, layer_attr=None):
elif isinstance(input, Projection):
input = [input]
else:
- assert isinstance(input, list) or isinstance(input, tuple)
+ assert isinstance(input, collections.Sequence)
def __is_type__(o, tp):
- if not isinstance(o, list) and not isinstance(o, tuple):
+ if not isinstance(o, collections.Sequence):
if o == tp:
return True
elif len(o.__bases__) == 0:
@@ -2147,28 +2267,51 @@ def get_output_layer(input, arg_name, name=None, layer_attr=None):
@wrap_param_attr_default()
@layer_support()
def recurrent_layer(input, act=None, bias_attr=None,
- param_attr=None, name=None, layer_attr=None):
+ param_attr=None, name=None, reverse=False, layer_attr=None):
"""
- TODO(yuyang18): Add docs
+ Simple recurrent unit layer. It is just a fully connect layer through both
+ time and neural network.
- :param input:
- :param size:
- :param act:
- :param bias_attr:
- :param param_attr:
- :param name:
- :param layer_attr:
+ For each sequence [start, end] it performs the following computation\:
+
+ .. math::
+
+ out_{i} = act(in_{i}) \\ \\ \\text{for} \\ i = start \\\\
+ out_{i} = act(in_{i} + out_{i-1} * W) \\ \\ \\text{for} \\ start < i <= end
+
+ If reversed is true, the order is reversed\:
+
+ .. math::
+
+ out_{i} = act(in_{i}) \\ \\ \\text{for} \\ i = end \\\\
+ out_{i} = act(in_{i} + out_{i+1} * W) \\ \\ \\text{for} \\ start <= i < end
+
+
+ :param input: Input Layer
+ :type input: LayerOutput
+ :param act: activation.
+ :type act: BaseActivation
+ :param bias_attr: bias attribute.
+ :type bias_attr: ParameterAttribute
+ :param param_attr: parameter attribute.
+ :type param_attr: ParameterAttribute
+ :param name: name of the layer
+ :type name: basestring
+ :param layer_attr: Layer Attribute.
+ :type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
+ :rtype: LayerOutput
"""
Layer(name=name,
type=LayerType.RECURRENT_LAYER,
inputs=Input(input.name, **param_attr.attr),
active_type=act.name,
- size=input.size,
bias=ParamAttr.to_bias(bias_attr),
+ reversed=reverse,
**ExtraAttr.to_kwargs(layer_attr))
return LayerOutput(name=name, layer_type=LayerType.RECURRENT_LAYER,
- parents=[input], size=input.size, activation=act)
+ parents=[input], size=input.size, activation=act,
+ reverse=reverse)
class StaticInput(object):
@@ -2176,6 +2319,7 @@ class StaticInput(object):
StaticInput is only used in recurrent_group which defines a read-only memory
that can be a sequence or non-sequence.
"""
+
def __init__(self, input, is_seq=False, size=None):
assert isinstance(input, LayerOutput)
self.input = input
@@ -2195,6 +2339,7 @@ class SubsequenceInput(object):
input = SubsequenceInput(layer)
"""
+
def __init__(self, input):
assert isinstance(input, LayerOutput)
assert input.size is not None
@@ -2267,7 +2412,7 @@ def recurrent_group(step, input, reverse=False, name=None):
if is_single_input(input):
input = [input]
- assert isinstance(input, list) or isinstance(input, tuple)
+ assert isinstance(input, collections.Sequence)
def is_in_links(x):
return isinstance(x, LayerOutput) or isinstance(x, SubsequenceInput)
@@ -2311,6 +2456,7 @@ def recurrent_group(step, input, reverse=False, name=None):
for ot in layer_outs:
assert isinstance(ot, LayerOutput)
+ ot.reverse = reverse
if contains_sub_seq[0]:
RecurrentLayerGroupSetOutLink(Link(ot.name, has_subseq=True))
else:
@@ -2323,6 +2469,7 @@ def recurrent_group(step, input, reverse=False, name=None):
else:
return layer_outs
+
class BaseGeneratedInput(object):
def __init__(self):
self.bos_id = None
@@ -2351,6 +2498,7 @@ class GeneratedInput(BaseGeneratedInput):
return trg_emb
def __init__(self, size, embedding_name, embedding_size):
+ super(GeneratedInput, self).__init__()
self.size = size
self.embedding_name = embedding_name
self.embedding_size = embedding_size
@@ -2387,6 +2535,7 @@ def maxid_layer(input, name=None, layer_attr=None):
layer_type=LayerType.MAXID_LAYER,
parents=[input])
+
@wrap_name_default()
def out_prod_layer(input1, input2, name=None, layer_attr=None):
"""
@@ -2419,7 +2568,8 @@ def out_prod_layer(input1, input2, name=None, layer_attr=None):
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(name=name,
layer_type=LayerType.OUT_PROD_LAYER,
- parents=[input1,input2])
+ parents=[input1, input2])
+
@wrap_name_default()
def eos_layer(input, eos_id, name=None, layer_attr=None):
@@ -2472,14 +2622,14 @@ def beam_search(step, input, bos_id, eos_id, beam_size,
def rnn_step(input):
last_time_step_output = memory(name='rnn', size=512)
- with mixed_layer(size=512) as simple_rnn:
+ with mixed_layer(size=512, name='rnn') as simple_rnn:
simple_rnn += full_matrix_projection(input)
simple_rnn += last_time_step_output
return simple_rnn
beam_gen = beam_search(name="decoder",
step=rnn_step,
- input=[StaticInput("encoder_last")],
+ input=[StaticInput(encoder_last)],
bos_id=0,
eos_id=1,
beam_size=5,
@@ -2493,18 +2643,18 @@ def beam_search(step, input, bos_id, eos_id, beam_size,
:param name: Name of the recurrent unit that generates sequences.
:type name: base string
:param step: A callable function that defines the calculation in a time
- step, and it is appled to sequences with arbitrary length by
+ step, and it is applied to sequences with arbitrary length by
sharing a same set of weights.
You can refer to the first parameter of recurrent_group, or
demo/seqToseq/seqToseq_net.py for more details.
:type step: callable
:param input: Input data for the recurrent unit
- :type input: StaticInput|GeneratedInput
+ :type input: list
:param bos_id: Index of the start symbol in the dictionary. The start symbol
is a special token for NLP task, which indicates the
beginning of a sequence. In the generation task, the start
- symbol is ensential, since it is used to initialize the RNN
+ symbol is essential, since it is used to initialize the RNN
internal state.
:type bos_id: int
:param eos_id: Index of the end symbol in the dictionary. The end symbol is
@@ -2513,6 +2663,8 @@ def beam_search(step, input, bos_id, eos_id, beam_size,
symbol is generated, or a pre-defined max iteration number
is exceeded.
:type eos_id: int
+ :param max_length: Max generated sequence length.
+ :type max_length: int
:param beam_size: Beam search for sequence generation is an iterative search
algorithm. To maintain tractability, every iteration only
only stores a predetermined number, called the beam_size,
@@ -2553,8 +2705,8 @@ def beam_search(step, input, bos_id, eos_id, beam_size,
real_input = []
for i, each_input in enumerate(input):
# print type(each_input)
- assert isinstance(each_input, StaticInput) or isinstance(each_input,
- BaseGeneratedInput)
+ assert isinstance(each_input, StaticInput) or isinstance(
+ each_input, BaseGeneratedInput)
if isinstance(each_input, BaseGeneratedInput):
assert generated_input_index == -1
generated_input_index = i
@@ -2625,9 +2777,11 @@ def regression_cost(input, label, cost='square_error', name=None):
@wrap_name_default("cost")
+@layer_support()
def classification_cost(input, label, name=None,
cost="multi-class-cross-entropy",
- evaluator=classification_error_evaluator):
+ evaluator=classification_error_evaluator,
+ layer_attr=None):
"""
classification cost Layer.
@@ -2640,13 +2794,16 @@ def classification_cost(input, label, name=None,
:param cost: cost method.
:type cost: basestring
:param evaluator: Evaluator method.
+ :param layer_attr: layer's extra attribute.
+ :type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert input.layer_type != LayerType.DATA
assert isinstance(input.activation, SoftmaxActivation)
assert label.layer_type == LayerType.DATA
- Layer(name=name, type=cost, inputs=[Input(input.name), Input(label.name)])
+ Layer(name=name, type=cost, inputs=[Input(input.name), Input(label.name)],
+ **ExtraLayerAttribute.to_kwargs(layer_attr))
def __add_evaluator__(e):
assert callable(e)
@@ -2659,7 +2816,7 @@ def classification_cost(input, label, name=None,
e(name=e.__name__, input=input, label=label)
- if not isinstance(evaluator, list) and not isinstance(evaluator, tuple):
+ if not isinstance(evaluator, collections.Sequence):
evaluator = [evaluator]
for each_evaluator in evaluator:
@@ -2667,8 +2824,9 @@ def classification_cost(input, label, name=None,
return LayerOutput(name, LayerType.COST, parents=[input, label])
+
def conv_operator(img, filter, filter_size, num_filters,
- num_channel=None, stride=1, padding=0, groups=1,
+ num_channel=None, stride=1, padding=0,
filter_size_y=None, stride_y=None, padding_y=None):
"""
Different from img_conv_layer, conv_op is an Operator, which can be used
@@ -2682,7 +2840,7 @@ def conv_operator(img, filter, filter_size, num_filters,
op = conv_operator(img=input1,
filter=input2,
- filter_size=3.0,
+ filter_size=3,
num_filters=64,
num_channels=64)
@@ -2696,8 +2854,8 @@ def conv_operator(img, filter, filter_size, num_filters,
PaddlePaddle now supports rectangular filters,
the filter's shape can be (filter_size, filter_size_y).
:type filter_size_y: int
- :param num_filter: channel of output data.
- :type num_filter: int
+ :param num_filters: channel of output data.
+ :type num_filters: int
:param num_channel: channel of input data.
:type num_channel: int
:param stride: The x dimension of the stride.
@@ -2717,8 +2875,16 @@ def conv_operator(img, filter, filter_size, num_filters,
stride_y = stride
if padding_y is None:
padding_y = padding
+
+ if num_channel is None:
+ num_channel = img.num_filters
+
+ assert isinstance(filter, LayerOutput)
+ if filter.size is not None:
+ filter.size = filter_size * filter_size_y * num_filters * num_channel
+
op = ConvOperator(input_layer_names=[img.name, filter.name],
- num_filters = num_filter,
+ num_filters=num_filters,
conv_conf=Conv(filter_size=filter_size,
padding=padding,
stride=stride,
@@ -2726,13 +2892,13 @@ def conv_operator(img, filter, filter_size, num_filters,
filter_size_y=filter_size_y,
padding_y=padding_y,
stride_y=stride_y,
- groups=groups))
+ groups=1))
op.origin = [img, filter]
return op
@wrap_name_default()
-def conv_shift_layer(input, name=None):
+def conv_shift_layer(a, b, name=None):
"""
This layer performs cyclic convolution for two input. For example:
- a[in]: contains M elements.
@@ -2744,68 +2910,77 @@ def conv_shift_layer(input, name=None):
c[i] = \sum_{j=-(N-1)/2}^{(N-1)/2}a_{i+j} * b_{j}
In this formular:
- - a's index is computed modulo M.
- - b's index is computed modulo N.
+ - a's index is computed modulo M. When it is negative, then get item from
+ the right side (which is the end of array) to the left.
+ - b's index is computed modulo N. When it is negative, then get item from
+ the right size (which is the end of array) to the left.
The example usage is:
.. code-block:: python
- conv_shift = conv_shif_layer(input=[layer1, layer2])
+ conv_shift = conv_shift_layer(input=[layer1, layer2])
:param name: layer name
:type name: basestring
- :param input: Input layer.
- :type input: LayerOutput|list|tuple.
+ :param a: Input layer a.
+ :type a: LayerOutput
+ :param b: input layer b
+ :type b: LayerOutput
:return: LayerOutput object.
:rtype: LayerOutput
"""
- assert isinstance(input, list) or isinstance(input, tuple)
+ assert isinstance(a, LayerOutput) and isinstance(b, LayerOutput)
+ assert b.size is None or b.size % 2 == 1 # size of b must be odd.
Layer(
name=name,
type=LayerType.CONV_SHIFT_LAYER,
- inputs=[x.name for x in input],
+ inputs=[a.name, b.name],
)
- return LayerOutput(name, LayerType.CONV_SHIFT_LAYER, parents=input)
+ return LayerOutput(name, LayerType.CONV_SHIFT_LAYER, parents=[a, b],
+ size=a.size)
@wrap_name_default()
@wrap_param_attr_default()
@wrap_bias_attr_default()
+@wrap_act_default(act=LinearActivation())
@layer_support(ERROR_CLIPPING, DROPOUT)
-def tensor_layer(input, size, act=None, name=None,
+def tensor_layer(a, b, size, act=None, name=None,
param_attr=None, bias_attr=None, layer_attr=None):
"""
This layer performs tensor operation for two input.
For example, each sample:
.. math::
- y_{i} = x_{1} * W_{i} * {x_{2}^\mathrm{T}}, i=0,1,...,K-1
+ y_{i} = a * W_{i} * {b^\mathrm{T}}, i=0,1,...,K-1
In this formular:
- - :math:`x_{1}`: the first input contains M elements.
- - :math:`x_{2}`: the second input contains N elements.
+ - :math:`a`: the first input contains M elements.
+ - :math:`b`: the second input contains N elements.
- :math:`y_{i}`: the i-th element of y.
- :math:`W_{i}`: the i-th learned weight, shape if [M, N]
- - :math:`{x_{2}}^\mathrm{T}`: the transpose of :math:`x_{2}`.
+ - :math:`b^\mathrm{T}`: the transpose of :math:`b_{2}`.
The simple usage is:
.. code-block:: python
- tensor = tensor_layer(input=[layer1, layer2])
+ tensor = tensor_layer(a=layer1, b=layer2, size=1000)
:param name: layer name
:type name: basestring
- :param input: Input layer.
- :type input: LayerOutput|list|tuple.
+ :param a: Input layer a.
+ :type a: LayerOutput
+ :param b: input layer b.
+ :type b: LayerOutput
:param size: the layer dimension.
:type size: int.
:param act: Activation Type. Default is tanh.
:type act: BaseActivation
:param param_attr: The Parameter Attribute.
- :type param_attr: ParameterAttribute|list
+ :type param_attr: ParameterAttribute
:param bias_attr: The Bias Attribute. If no bias, then pass False or
something not type of ParameterAttribute. None will get a
default Bias.
@@ -2815,65 +2990,26 @@ def tensor_layer(input, size, act=None, name=None,
:return: LayerOutput object.
:rtype: LayerOutput
"""
- assert isinstance(input, list) or isinstance(input, tuple)
- assert len(input) == 2
+ assert isinstance(a, LayerOutput) and isinstance(b, LayerOutput)
Layer(
name=name,
size=size,
type=LayerType.TENSOR_LAYER,
active_type=act.name,
bias=ParamAttr.to_bias(bias_attr),
- inputs=[Input(input[0].name, **param_attr.attr),
- Input(input[1].name)],
+ inputs=[Input(a.name, **param_attr.attr),
+ Input(b.name)],
**ExtraLayerAttribute.to_kwargs(layer_attr)
)
- return LayerOutput(name, LayerType.TENSOR_LAYER, parents=input,
+ return LayerOutput(name, LayerType.TENSOR_LAYER, parents=[a, b],
activation=act, size=size)
-@wrap_param_attr_default()
-def trans_full_matrix_projection(input, size=0, param_attr=None):
- """
- Different from full_matrix_projection, this projection performs matrix
- multiplication, using transpose of weight.
-
- .. math::
- out.row[i] += in.row[i] * w^\mathrm{T}
-
- :math:`w^\mathrm{T}` means transpose of weight.
- The simply usage is:
-
- .. code-block:: python
-
- proj = trans_full_matrix_projection(input=layer,
- size=100,
- param_attr=ParamAttr(
- name='_proj',
- initial_mean=0.0,
- initial_std=0.01))
-
- :param input: input layer
- :type input: LayerOutput
- :param size: The parameter size. Means the width of parameter.
- :type size: int
- :param param_attr: Parameter config, None if use default.
- :type param_attr: ParameterAttribute
- :return: A TransposedFullMatrixProjection Object.
- :rtype: TransposedFullMatrixProjection
- """
- proj = TransposedFullMatrixProjection(input_layer_name=input.name,
- size=size,
- **param_attr.attr)
- proj.origin = input
- proj.origin.projection = "trans_matrix"
- return proj
-
-
@wrap_name_default()
@wrap_param_attr_default()
@wrap_bias_attr_default()
@wrap_act_default()
-def selective_fc_layer(input, size, act=None, name=None,
+def selective_fc_layer(input, select, size, act=None, name=None,
pass_generation=False,
has_selected_colums=True,
mul_ratio=0.02,
@@ -2888,12 +3024,15 @@ def selective_fc_layer(input, size, act=None, name=None,
.. code-block:: python
- sel_fc = selective_fc_layer(input=input, 128, act=TanhActivation())
+ sel_fc = selective_fc_layer(input=input, size=128, act=TanhActivation())
:param name: The Layer Name.
:type name: basestring
:param input: The input layer.
:type input: LayerOutput|list|tuple
+ :param select: The select layer. The output of select layer should be a
+ sparse binary matrix, and treat as the mask of selective fc.
+ :type select: LayerOutput
:param size: The layer dimension.
:type size: int
:param act: Activation Type. Default is tanh.
@@ -2911,33 +3050,33 @@ def selective_fc_layer(input, size, act=None, name=None,
"""
if isinstance(input, LayerOutput):
input = [input]
- assert not isinstance(param_attr, list)
+ assert not isinstance(param_attr, collections.Sequence)
param_attr = [param_attr]
else:
- if isinstance(param_attr, list) or isinstance(param_attr, tuple):
+ if isinstance(param_attr, collections.Sequence):
assert len(input) == len(param_attr)
else:
param_attr = [copy.deepcopy(param_attr) for _ in range(len(input))]
- assert isinstance(input, list)
-
- def __idx_to_input__(i):
- attr = param_attr[i]
- assert isinstance(attr, ParameterAttribute)
- return Input(input[i].name, **attr.attr)
-
+ assert isinstance(input, collections.Sequence)
+ assert isinstance(select, LayerOutput)
+ if select.size is not None:
+ assert select.size == size
Layer(
- inputs=map(__idx_to_input__, range(len(input))),
+ inputs=[Input(ipt.name, **attr.attr) for ipt, attr in zip(
+ input, param_attr)] + [select.name],
name=name,
type=LayerType.SEL_FC_LAYER,
size=size,
+ bias=ParameterAttribute.to_bias(bias_attr),
active_type=act.name,
selective_fc_pass_generation=pass_generation,
has_selected_colums=has_selected_colums,
selective_fc_full_mul_ratio=mul_ratio,
**ExtraLayerAttribute.to_kwargs(layer_attr)
)
- return LayerOutput(name, LayerType.SEL_FC_LAYER, input, activation=act,
+ return LayerOutput(name, LayerType.SEL_FC_LAYER, list(input) + [select],
+ activation=act,
size=size)
@@ -3005,7 +3144,7 @@ def slope_intercept_layer(input, name=None, slope=1.0, intercept=0.0):
@wrap_name_default()
-def linear_comb_layer(weights, vectors, size, name=None):
+def linear_comb_layer(weights, vectors, size=None, name=None):
"""
A layer for weighted sum of vectors takes two inputs.
- Input: size of weights is M
@@ -3035,11 +3174,13 @@ def linear_comb_layer(weights, vectors, size, name=None):
.. code-block:: python
- linear_comb = linear_comb_layer(weighs=weight, vectors=vectors,
+ linear_comb = linear_comb_layer(weights=weight, vectors=vectors,
size=elem_dim)
- :param input: The input layers.
- :type input: LayerOutput
+ :param weights: The weight layer.
+ :type weights: LayerOutput
+ :param vectors: The vector layer.
+ :type vectors: LayerOutput
:param size: the dimension of this layer.
:type size: int
:param name: The Layer Name.
@@ -3047,7 +3188,13 @@ def linear_comb_layer(weights, vectors, size, name=None):
:return: LayerOutput object.
:rtype: LayerOutput
"""
-
+ assert isinstance(weights, LayerOutput) and isinstance(vectors, LayerOutput)
+ if vectors.size is not None and weights.size is not None:
+ assert vectors.size % weights.size == 0
+ if size is None:
+ size = vectors.size / weights.size
+ else:
+ assert size == vectors.size / weights.size
Layer(
name=name,
type=LayerType.LINEAR_COMBINATION_LAYER,
@@ -3057,8 +3204,10 @@ def linear_comb_layer(weights, vectors, size, name=None):
return LayerOutput(name, LayerType.LINEAR_COMBINATION_LAYER,
[weights, vectors], size=size)
+
convex_comb_layer = linear_comb_layer
+
@wrap_name_default()
def block_expand_layer(input,
channel=0,
@@ -3120,22 +3269,22 @@ def block_expand_layer(input,
"""
Layer(name=name,
input=Input(input.name,
- block_expand=BlockExpand(channel=channel,
+ block_expand=BlockExpand(channels=channel,
block_x=block_x,
block_y=block_y,
stride_x=stride_x,
stride_y=stride_y,
padding_x=padding_x,
padding_y=padding_y)
- ),
+ ),
type=LayerType.BLOCK_EXPAND,
- )
+ )
+
+ return LayerOutput(name, LayerType.BLOCK_EXPAND, parents=[input])
- return LayerOutput(name, LayerType.BLOCK_EXPAND,
- parents=[input], size=size)
@wrap_name_default()
-def ctc_layer(input, label, size, name=None, norm_by_times=False):
+def ctc_layer(input, label, size=None, name=None, norm_by_times=False):
"""
Connectionist Temporal Classification (CTC) is designed for temporal
classication task. That is, for sequence labeling problems where the
@@ -3143,7 +3292,8 @@ def ctc_layer(input, label, size, name=None, norm_by_times=False):
More details can be found by referring to `Connectionist Temporal
Classification: Labelling Unsegmented Sequence Data with Recurrent
- Neural Networks `_
+ Neural Networks `_
Note:
Considering the 'blank' label needed by CTC, you need to use
@@ -3161,14 +3311,14 @@ def ctc_layer(input, label, size, name=None, norm_by_times=False):
size=9055,
norm_by_times=True)
- :param input: The input layers.
+ :param input: The input layer.
:type input: LayerOutput
:param label: The data layer of label with variable length.
:type label: LayerOutput
:param size: category numbers + 1.
:type size: int
- :param name: The name of this layer, which can not specify.
- :type name: string|None
+ :param name: The name of this layer
+ :type name: basestring|None
:param norm_by_times: Whether to normalization by times. False by default.
:type norm_by_times: bool
:return: LayerOutput object.
@@ -3176,18 +3326,24 @@ def ctc_layer(input, label, size, name=None, norm_by_times=False):
"""
assert isinstance(input, LayerOutput)
assert isinstance(label, LayerOutput)
+ if label.size is not None:
+ if size is not None:
+ assert size == label.size + 1
+ else:
+ size = label.size + 1
Layer(
- name = name,
- type = LayerType.CTC_LAYER,
- size = size,
- norm_by_times = norm_by_times,
- inputs = [input.name, label.name]
+ name=name,
+ type=LayerType.CTC_LAYER,
+ size=size,
+ norm_by_times=norm_by_times,
+ inputs=[input.name, label.name]
)
return LayerOutput(name, LayerType.CTC_LAYER, [input, label], size=size)
+
@wrap_name_default()
@wrap_param_attr_default()
-def crf_layer(input, label, size, weight=None, param_attr=None, name=None):
+def crf_layer(input, label, size=None, weight=None, param_attr=None, name=None):
"""
A layer for calculating the cost of sequential conditional random
field model.
@@ -3203,7 +3359,7 @@ def crf_layer(input, label, size, weight=None, param_attr=None, name=None):
:param input: The first input layer is the feature.
:type input: LayerOutput
:param label: The second input layer is label.
- :type input: LayerOutput
+ :type label: LayerOutput
:param size: The category number.
:type size: int
:param weight: The third layer is "weight" of each sample, which is an
@@ -3219,6 +3375,12 @@ def crf_layer(input, label, size, weight=None, param_attr=None, name=None):
assert isinstance(input, LayerOutput)
assert isinstance(label, LayerOutput)
assert weight is None or isinstance(weight, LayerOutput)
+ if input.size is not None and label.size is not None:
+ assert input.size == label.size
+ if size is None:
+ size = input.size
+ else:
+ assert size == input.size
ipts = [Input(input.name, **param_attr.attr),
Input(label.name)]
@@ -3226,16 +3388,17 @@ def crf_layer(input, label, size, weight=None, param_attr=None, name=None):
ipts.append(Input(weight.name))
Layer(
- name = name,
- type = LayerType.CRF_LAYER,
- size = size,
- inputs = ipts,
+ name=name,
+ type=LayerType.CRF_LAYER,
+ size=size,
+ inputs=ipts,
)
parents = [input, label]
if weight is not None:
parents.append(weight)
return LayerOutput(name, LayerType.CRF_LAYER, parents, size=size)
+
@wrap_name_default()
@wrap_param_attr_default()
def crf_decoding_layer(input, size, label=None, param_attr=None, name=None):
@@ -3268,24 +3431,28 @@ def crf_decoding_layer(input, size, label=None, param_attr=None, name=None):
ipts.append(Input(label.name))
Layer(
- name = name,
- type = LayerType.CRF_DECODING_LAYER,
- size = size,
- inputs = ipts,
+ name=name,
+ type=LayerType.CRF_DECODING_LAYER,
+ size=size,
+ inputs=ipts,
)
parents = [input]
if label is not None:
parents.append(label)
return LayerOutput(name, LayerType.CRF_DECODING_LAYER, parents, size=size)
+
"""
following are cost Layers.
"""
+
+
@wrap_name_default()
-def rank_cost(left, right, lable, weight=None, name=None, coeff=1.0):
+def rank_cost(left, right, label, weight=None, name=None, coeff=1.0):
"""
A cost Layer for learning to rank using gradient descent. Details can refer
- to `papers `_.
+ to `papers `_.
This layer contains at least three inputs. The weight is an optional
argument, which affects the cost.
@@ -3342,12 +3509,13 @@ def rank_cost(left, right, lable, weight=None, name=None, coeff=1.0):
type=LayerType.RANK_COST,
inputs=ipts,
coeff=coeff,
- )
+ )
return LayerOutput(name, LayerType.RANK_COST, parents=parents)
+
@wrap_name_default()
-def lambda_cost(input, score, NDCG_num=5, max_sort_size=-1, coeff=1.0):
+def lambda_cost(input, score, name, NDCG_num=5, max_sort_size=-1):
"""
lambdaCost for lambdaRank LTR approach.
@@ -3360,9 +3528,7 @@ def lambda_cost(input, score, NDCG_num=5, max_sort_size=-1, coeff=1.0):
NDCG_num=8,
max_sort_size=-1)
- :param input: The 1st input. Samples of the same query should be loaded
- as sequence. User should provided socres for each sample.
- The score should be the 2nd input of this layer.
+ :param input: Samples of the same query should be loaded as sequence.
:type input: LayerOutput
:param score: The 2nd input. Score of each sample.
:type input: LayerOutput
@@ -3380,21 +3546,22 @@ def lambda_cost(input, score, NDCG_num=5, max_sort_size=-1, coeff=1.0):
:type max_sort_size: int
:param name: The name of this layers. It is not necessary.
:type name: None|basestring
- :param coeff: The coefficient affects the gradient in the backward.
- :type coeff: float
:return: LayerOutput object.
:rtype: LayerOutput
"""
+ assert isinstance(input, LayerOutput) and isinstance(score, LayerOutput)
+ if score.size is not None:
+ assert score.size == 1
Layer(name=name,
type=LayerType.LAMBDA_COST,
inputs=[input.name, score.name],
NDCG_num=NDCG_num,
- max_sort_size=max_sort_size,
- coeff=coeff,
- )
+ max_sort_size=max_sort_size
+ )
return LayerOutput(name, LayerType.LAMBDA_COST, parents=[input, score])
+
@wrap_name_default()
def cross_entropy(input, label, name=None, coeff=1.0):
"""
@@ -3422,9 +3589,10 @@ def cross_entropy(input, label, name=None, coeff=1.0):
type=LayerType.CROSS_ENTROPY,
inputs=[input.name, label.name],
coeff=coeff,
- )
+ )
return LayerOutput(name, LayerType.CROSS_ENTROPY, parents=[input, label])
+
@wrap_name_default()
def cross_entropy_with_selfnorm(input, label, name=None, coeff=1.0,
softmax_selfnorm_alpha=0.1):
@@ -3455,12 +3623,13 @@ def cross_entropy_with_selfnorm(input, label, name=None, coeff=1.0,
inputs=[input.name, label.name],
coeff=coeff,
softmax_selfnorm_alpha=softmax_selfnorm_alpha,
- )
+ )
return LayerOutput(name,
LayerType.CROSS_ENTROPY_WITH_SELFNORM,
parents=[input, label])
+
@wrap_name_default()
def huber_cost(input, label, name=None, coeff=1.0):
"""
@@ -3474,8 +3643,6 @@ def huber_cost(input, label, name=None, coeff=1.0):
:type input: LayerOutput.
:param label: The input label.
:type input: LayerOutput.
- :param type: The type of cost.
- :type type: basestring.
:param name: The name of this layers. It is not necessary.
:type name: None|basestring.
:param coeff: The coefficient affects the gradient in the backward.
@@ -3483,14 +3650,17 @@ def huber_cost(input, label, name=None, coeff=1.0):
:return: LayerOutput object.
:rtype: LayerOutput.
"""
-
+ assert isinstance(input, LayerOutput)
+ if input.size is not None:
+ assert input.size == 1
Layer(name=name,
type=LayerType.HUBER,
inputs=[input.name, label.name],
coeff=coeff,
- )
+ )
return LayerOutput(name, LayerType.HUBER, parents=[input, label])
+
@wrap_name_default()
def multi_binary_label_cross_entropy(input, label, name=None, coeff=1.0):
"""
@@ -3514,15 +3684,16 @@ def multi_binary_label_cross_entropy(input, label, name=None, coeff=1.0):
:rtype: LayerOutput
"""
- if not isinstance(input.act, SigmoidActivation):
+ if input.activation is None or \
+ not isinstance(input.activation, SigmoidActivation):
logger.log(logging.WARN,
"%s is not recommend for batch normalization's activation, "
- "maybe the relu is better" % act.name)
+ "maybe the relu is better" % repr(input.activation))
Layer(name=name,
type=LayerType.MULTI_BIN_LABEL_CROSS_ENTROPY,
inputs=[input.name, label.name],
coeff=coeff,
- )
+ )
return LayerOutput(name, LayerType.MULTI_BIN_LABEL_CROSS_ENTROPY,
parents=[input, label])
diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py
index 93c4261cc48c5ef16ac8bd8612513bd8ebfd04bc..e59e93acbe33ab354e820fbc0a34069399bf6f86 100644
--- a/python/paddle/trainer_config_helpers/networks.py
+++ b/python/paddle/trainer_config_helpers/networks.py
@@ -616,7 +616,7 @@ def lstmemory_group(input, size=None, name=None,
cell states, or hidden states in every time step are accessible to for the
user. This is especially useful in attention model. If you do not need to
access to the internal states of the lstm, but merely use its outputs,
- it is recommanded to use the lstmemory, which is relatively faster than
+ it is recommended to use the lstmemory, which is relatively faster than
lstmemory_group.
NOTE: In PaddlePaddle's implementation, the following input-to-hidden
@@ -1052,7 +1052,7 @@ def dropout_layer(input, dropout_rate, name=None):
layer_attr=ExtraAttr(drop_rate=dropout_rate))
-def outputs(layers):
+def outputs(layers, *args):
"""
Declare the end of network. Currently it will only calculate the
input/output order of network. It will calculate the predict network or
@@ -1089,9 +1089,12 @@ def outputs(layers):
if isinstance(layers, LayerOutput):
layers = [layers]
+ if len(args) != 0:
+ layers.extend(args)
+
assert len(layers) > 0
if len(layers) != 1:
- logger.warning("EndOfNetwork routine try to calculate network's"
+ logger.warning("`outputs` routine try to calculate network's"
" inputs and outputs order. It might not work well."
"Please see follow log carefully.")
inputs = []
diff --git a/python/paddle/trainer_config_helpers/poolings.py b/python/paddle/trainer_config_helpers/poolings.py
index 5e06d82005841c9dad899399695a48d8c42280b0..d627daab0c496d4fa465c0d3afda3cec2b98c3f9 100644
--- a/python/paddle/trainer_config_helpers/poolings.py
+++ b/python/paddle/trainer_config_helpers/poolings.py
@@ -47,9 +47,14 @@ class MaxPooling(BasePoolingType):
.. math::
max(samples\\_of\\_a\\_sequence)
+
+ :param output_max_index: True if output sequence max index instead of max
+ value. None means use default value in proto.
+ :type output_max_index: bool|None
"""
- def __init__(self):
+ def __init__(self, output_max_index=None):
BasePoolingType.__init__(self, "max")
+ self.output_max_index = output_max_index
class AvgPooling(BasePoolingType):
diff --git a/python/paddle/trainer_config_helpers/tests/CMakeLists.txt b/python/paddle/trainer_config_helpers/tests/CMakeLists.txt
index 611fb855a8c9ad6679167105dd737c995b23c209..cf52b06bfea06d6a88b5f11558627936a74bd0b9 100644
--- a/python/paddle/trainer_config_helpers/tests/CMakeLists.txt
+++ b/python/paddle/trainer_config_helpers/tests/CMakeLists.txt
@@ -3,3 +3,8 @@ add_test(NAME layers_test
COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/
python ${PROJ_ROOT}/python/paddle/trainer_config_helpers/tests/layers_test.py
WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle)
+
+add_test(NAME test_layerHelpers
+ COMMAND
+ ${PROJ_ROOT}/python/paddle/trainer_config_helpers/tests/configs/run_tests.sh
+)
diff --git a/python/paddle/trainer_config_helpers/tests/configs/.gitignore b/python/paddle/trainer_config_helpers/tests/configs/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..52378fe7a486589352182ef4da6186365daf4bde
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/.gitignore
@@ -0,0 +1 @@
+*protostr
diff --git a/python/paddle/trainer_config_helpers/tests/configs/check.md5 b/python/paddle/trainer_config_helpers/tests/configs/check.md5
new file mode 100644
index 0000000000000000000000000000000000000000..29928b6f7b4239a0240b9fc035b6e1568427a9aa
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/check.md5
@@ -0,0 +1,17 @@
+7e6919d17562516e9a1d9a88de1fb3b9 img_layers.protostr
+a5d9259ff1fd7ca23d0ef090052cb1f2 last_first_seq.protostr
+9c038249ec8ff719753a746cdb04c026 layer_activations.protostr
+5913f87b39cee3b2701fa158270aca26 projections.protostr
+6b39e34beea8dfb782bee9bd3dea9eb5 simple_rnn_layers.protostr
+0fc1409600f1a3301da994ab9d28b0bf test_cost_layers.protostr
+144bc6d3a509de74115fa623741797ed test_expand_layer.protostr
+2378518bdb71e8c6e888b1842923df58 test_fc.protostr
+8bb44e1e5072d0c261572307e7672bda test_grumemory_layer.protostr
+1f3510672dce7a9ed25317fc58579ac7 test_hsigmoid.protostr
+d350bd91a0dc13e854b1364c3d9339c6 test_lstmemory_layer.protostr
+251a948ba41c1071afcd3d9cf9c233f7 test_ntm_layers.protostr
+e6ff04e70aea27c7b06d808cc49c9497 test_print_layer.protostr
+2a75dd33b640c49a8821c2da6e574577 test_rnn_group.protostr
+67d6fde3afb54f389d0ce4ff14726fe1 test_sequence_pooling.protostr
+f586a548ef4350ba1ed47a81859a64cb unused_layers.protostr
+8122477f4f65244580cec09edc590041 util_layers.protostr
diff --git a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh
new file mode 100755
index 0000000000000000000000000000000000000000..fc2acbd41ed906f04c7a3d88c5e0b01a0779bd7b
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+
+set -e
+cd `dirname $0`
+export PYTHONPATH=$PWD/../../../../
+
+configs=(test_fc layer_activations projections test_print_layer
+test_sequence_pooling test_lstmemory_layer test_grumemory_layer
+last_first_seq test_expand_layer test_ntm_layers test_hsigmoid
+img_layers util_layers simple_rnn_layers unused_layers test_cost_layers
+test_rnn_group)
+
+
+for conf in ${configs[*]}
+do
+ echo "Generating " $conf
+ python -m paddle.utils.dump_config $conf.py > $conf.protostr
+done
diff --git a/python/paddle/trainer_config_helpers/tests/configs/img_layers.py b/python/paddle/trainer_config_helpers/tests/configs/img_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c8ba8be846e5d943a5b1f034e2dabaaf001cede
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/img_layers.py
@@ -0,0 +1,20 @@
+from paddle.trainer_config_helpers import *
+
+settings(
+ learning_rate=1e-3,
+ batch_size=1000
+)
+
+img = data_layer(name='image', size=256*256)
+
+img_conv = img_conv_layer(input=img, num_channels=1, num_filters=64,
+ filter_size=(32, 64), padding=(1, 0), stride=(1, 1),
+ act=LinearActivation())
+img_bn = batch_norm_layer(input=img_conv, act=ReluActivation())
+
+img_norm = img_cmrnorm_layer(input=img_bn, size=32)
+
+img_pool = img_pool_layer(input=img_conv, pool_size=32, pool_type=MaxPooling())
+
+
+outputs(img_pool, img_norm)
\ No newline at end of file
diff --git a/python/paddle/trainer_config_helpers/tests/configs/last_first_seq.py b/python/paddle/trainer_config_helpers/tests/configs/last_first_seq.py
new file mode 100644
index 0000000000000000000000000000000000000000..d54a1c49fd3fdf9eb8a675dd94561e6da5b310bc
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/last_first_seq.py
@@ -0,0 +1,26 @@
+from paddle.trainer_config_helpers import *
+
+settings(
+ batch_size=1000,
+ learning_rate=1e-5
+)
+
+din = data_layer(name='data', size=30)
+
+seq_op = [
+ first_seq,
+ last_seq
+]
+
+agg_level = [
+ AggregateLevel.EACH_SEQUENCE,
+ AggregateLevel.EACH_TIMESTEP
+]
+
+opts = []
+
+for op in seq_op:
+ for al in agg_level:
+ opts.append(op(input=din, agg_level=al))
+
+outputs(opts)
\ No newline at end of file
diff --git a/python/paddle/trainer_config_helpers/tests/configs/layer_activations.py b/python/paddle/trainer_config_helpers/tests/configs/layer_activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba10dc78e1e3bf382c14c62d542256c957c1fdf5
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/layer_activations.py
@@ -0,0 +1,21 @@
+'''
+Test all activations.
+'''
+
+from paddle.trainer_config_helpers import *
+
+settings(
+ learning_rate=1e-4,
+ batch_size=1000
+)
+
+din = data_layer(name='input', size=100)
+
+acts = [
+ TanhActivation, SigmoidActivation, SoftmaxActivation, IdentityActivation,
+ LinearActivation, ExpActivation, ReluActivation, BReluActivation,
+ SoftReluActivation, STanhActivation, AbsActivation, SquareActivation]
+
+outputs(
+ [fc_layer(input=din, size=100, act=act(), name="layer_%d" % i) for i, act in
+ enumerate(acts)])
diff --git a/python/paddle/trainer_config_helpers/tests/configs/projections.py b/python/paddle/trainer_config_helpers/tests/configs/projections.py
new file mode 100644
index 0000000000000000000000000000000000000000..4066c5bc6e0f06e43b1c4d13020c092babdaea91
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/projections.py
@@ -0,0 +1,47 @@
+'''
+Test mixed layer, projections and operators.
+'''
+from paddle.trainer_config_helpers import *
+
+settings(
+ batch_size=1000,
+ learning_rate=1e-4
+)
+
+din = data_layer(name='test', size=100)
+
+din = embedding_layer(input=din, size=256)
+
+with mixed_layer(size=100) as m1:
+ m1 += full_matrix_projection(input=din)
+
+with mixed_layer(size=100) as m2:
+ m2 += table_projection(input=m1)
+
+with mixed_layer(size=100) as m3:
+ m3 += identity_projection(input=m2)
+
+with mixed_layer(size=100) as m4:
+ m4 += dotmul_projection(input=m3)
+
+with mixed_layer() as m5:
+ m5 += context_projection(input=m4, context_len=3)
+
+with mixed_layer() as m6:
+ m6 += dotmul_operator(a=m3, b=m4)
+
+img = data_layer(name='img', size=32*32)
+flt = data_layer(name='filter', size=3*3*1*64)
+
+with mixed_layer() as m7:
+ m7 += conv_operator(img=img, filter=flt, num_filters=64,
+ num_channel=1, filter_size=3)
+
+end = mixed_layer(input=[full_matrix_projection(input=m5),
+ trans_full_matrix_projection(input=m6),
+ full_matrix_projection(input=m7)],
+ size=100,
+ layer_attr=ExtraAttr(drop_rate=0.5,
+ error_clipping_threshold=40))
+
+outputs(end)
diff --git a/python/paddle/trainer_config_helpers/tests/configs/run_tests.sh b/python/paddle/trainer_config_helpers/tests/configs/run_tests.sh
new file mode 100755
index 0000000000000000000000000000000000000000..78114ce32b019cde7a028acde4d281cf6b3dac8e
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/run_tests.sh
@@ -0,0 +1,5 @@
+#!/bin/bash
+cd `dirname $0`
+set -e
+./generate_protostr.sh
+md5sum -c check.md5
diff --git a/python/paddle/trainer_config_helpers/tests/configs/simple_rnn_layers.py b/python/paddle/trainer_config_helpers/tests/configs/simple_rnn_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..87c2a85cf92dde19cc78d14d5d212940f11546f9
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/simple_rnn_layers.py
@@ -0,0 +1,36 @@
+from paddle.trainer_config_helpers import *
+
+settings(
+ batch_size=1000,
+ learning_rate=1e-4
+)
+
+din = data_layer(name='data', size=200)
+
+hidden = fc_layer(input=din, size=200, act=SigmoidActivation())
+
+rnn = recurrent_layer(input=hidden, act=SigmoidActivation())
+
+rnn2 = recurrent_layer(input=hidden, act=SigmoidActivation(), reverse=True)
+
+lstm1_param = fc_layer(input=hidden, size=200*4, act=LinearActivation(),
+ bias_attr=False)
+
+lstm1 = lstmemory(input=lstm1_param, act=SigmoidActivation())
+
+lstm2_param = fc_layer(input=hidden, size=200*4, act=LinearActivation(),
+ bias_attr=False)
+
+lstm2 = lstmemory(input=lstm2_param, act=SigmoidActivation(), reverse=True)
+
+gru1_param = fc_layer(input=hidden, size=200*3, act=LinearActivation(),
+ bias_attr=False)
+gru1 = grumemory(input=gru1_param, act=SigmoidActivation())
+
+gru2_param = fc_layer(input=hidden, size=200*3, act=LinearActivation(),
+ bias_attr=False)
+gru2 = grumemory(input=gru2_param, act=SigmoidActivation(), reverse=True)
+
+outputs(last_seq(input=rnn), first_seq(input=rnn2),
+ last_seq(input=lstm1), first_seq(input=lstm2),
+ last_seq(input=gru1), first_seq(gru2))
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_cost_layers.py b/python/paddle/trainer_config_helpers/tests/configs/test_cost_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..64b45f4ded10b09ec4a7e77499e2d7b21215f430
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_cost_layers.py
@@ -0,0 +1,26 @@
+from paddle.trainer_config_helpers import *
+
+settings(
+ learning_rate=1e-4,
+ batch_size=1000
+)
+
+seq_in = data_layer(name='input', size=200)
+labels = data_layer(name='labels', size=5000)
+
+probs = data_layer(name='probs', size=10)
+xe_label = data_layer(name='xe-label', size=10)
+
+outputs(ctc_layer(input=seq_in, label=labels),
+ crf_layer(input=fc_layer(input=seq_in, size=4),
+ label=data_layer(name='crf_label', size=4)),
+ rank_cost(left=data_layer(name='left', size=1),
+ right=data_layer(name='right', size=1),
+ label=data_layer(name='label', size=1)),
+ lambda_cost(input=data_layer(name='list_feature', size=100),
+ score=data_layer(name='list_scores', size=1)),
+ cross_entropy(input=probs, label=xe_label),
+ cross_entropy_with_selfnorm(input=probs, label=xe_label),
+ huber_cost(input=data_layer(name='huber_probs', size=1),
+ label=data_layer(name='huber_label', size=1)),
+ multi_binary_label_cross_entropy(input=probs, label=xe_label))
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_expand_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_expand_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9c841ab277e10fe3bbb6751d002af62862e9237
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_expand_layer.py
@@ -0,0 +1,14 @@
+from paddle.trainer_config_helpers import *
+
+settings(
+ batch_size=1000,
+ learning_rate=1e-5
+)
+
+din = data_layer(name='data', size=30)
+data_seq = data_layer(name='data_seq', size=30)
+
+outputs(expand_layer(input=din, expand_as=data_seq,
+ expand_level=ExpandLevel.FROM_SEQUENCE),
+ expand_layer(input=din, expand_as=data_seq,
+ expand_level=ExpandLevel.FROM_TIMESTEP))
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_fc.py b/python/paddle/trainer_config_helpers/tests/configs/test_fc.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6d033f291d2c60086a0c6e7de2005c4acfbbc03
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_fc.py
@@ -0,0 +1,20 @@
+from paddle.trainer_config_helpers import *
+
+settings(
+ batch_size=1000,
+ learning_rate=1e-5
+)
+
+din = data_layer(name='data', size=100)
+
+trans = trans_layer(input=din)
+
+hidden = fc_layer(input=trans, size=100,
+ bias_attr=False)
+
+mask = data_layer(name='mask', size=100)
+
+hidden_sel = selective_fc_layer(input=din, select=mask, size=100,
+ act=SigmoidActivation())
+
+outputs(hidden, hidden_sel)
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_grumemory_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_grumemory_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d9fd9df5179c7006afd369879301e823e11bdb5
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_grumemory_layer.py
@@ -0,0 +1,11 @@
+from paddle.trainer_config_helpers import *
+
+settings(
+ batch_size=1000,
+ learning_rate=1e-4
+)
+
+din = data_layer(name='data', size=120)
+
+outputs(grumemory(input=din, size=40, reverse=True, gate_act=TanhActivation(),
+ act=SigmoidActivation()))
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_hsigmoid.py b/python/paddle/trainer_config_helpers/tests/configs/test_hsigmoid.py
new file mode 100644
index 0000000000000000000000000000000000000000..46069074ded56098e3fb995dda0ad360fc897900
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_hsigmoid.py
@@ -0,0 +1,11 @@
+from paddle.trainer_config_helpers import *
+
+settings(
+ learning_rate=1e-4,
+ batch_size=1000
+)
+
+din = data_layer(name='data', size=100)
+label = data_layer(name='label', size=10)
+
+outputs(hsigmoid(input=din, label=label, num_classes=10))
\ No newline at end of file
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_lstmemory_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_lstmemory_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..56304addb17b233229a0bf717378833195e3188f
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_lstmemory_layer.py
@@ -0,0 +1,11 @@
+from paddle.trainer_config_helpers import *
+
+settings(
+ batch_size=1000,
+ learning_rate=1e-5
+)
+
+din = data_layer(name='data', size=128)
+
+outputs(lstmemory(input=din, reverse=True, gate_act=TanhActivation(),
+ act=TanhActivation(), size=32))
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_ntm_layers.py b/python/paddle/trainer_config_helpers/tests/configs/test_ntm_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d8e1fdc6b598d3d0e29d3834c804e5a6976710b
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_ntm_layers.py
@@ -0,0 +1,23 @@
+from paddle.trainer_config_helpers import *
+
+settings(
+ batch_size=1000,
+ learning_rate=1e-5
+)
+
+weight = data_layer(name='w', size=1)
+a = data_layer(name='a', size=100)
+b = data_layer(name='b', size=100)
+c = data_layer(name='c', size=200)
+d = data_layer(name='d', size=31)
+
+outputs(interpolation_layer(input=[a, b], weight=weight),
+ power_layer(input=a, weight=weight),
+ scaling_layer(input=a, weight=weight),
+ cos_sim(a=a, b=b),
+ cos_sim(a=a, b=c, size=2),
+ sum_to_one_norm_layer(input=a),
+ conv_shift_layer(a=a, b=d),
+ tensor_layer(a=a, b=b, size=1000),
+ slope_intercept_layer(input=a, slope=0.7, intercept=0.9),
+ linear_comb_layer(weights=b, vectors=c))
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_print_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_print_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6b2661c7b9e8563d3816def9490e22962e3f7cb
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_print_layer.py
@@ -0,0 +1,12 @@
+from paddle.trainer_config_helpers import *
+
+settings(
+ learning_rate=1e-4,
+ batch_size=1000
+)
+
+din = data_layer(name='input', size=100)
+
+print_layer(input=din)
+
+outputs(din)
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_rnn_group.py b/python/paddle/trainer_config_helpers/tests/configs/test_rnn_group.py
new file mode 100644
index 0000000000000000000000000000000000000000..53f5c5d2499f9e6693d40b05354af83614fde8f5
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_rnn_group.py
@@ -0,0 +1,35 @@
+from paddle.trainer_config_helpers import *
+
+settings(
+ learning_rate=1e-4,
+ batch_size=1000
+)
+
+seq = data_layer(name='seq_input', size=100)
+sub_seq = data_layer(name='sub_seq_input', size=100)
+lbl = data_layer(name='label', size=1)
+
+
+def generate_rnn_simple(name):
+ def rnn_simple(s):
+ m = memory(name=name, size=200)
+ fc = fc_layer(input=[s, m], size=200, name=name)
+ return fc
+
+ return rnn_simple
+
+
+with mixed_layer() as lstm_param: # test lstm unit, rnn group
+ lstm_param += full_matrix_projection(input=seq, size=100 * 4)
+
+with mixed_layer() as gru_param:
+ gru_param += full_matrix_projection(input=seq, size=100 * 3)
+
+outputs(last_seq(input=recurrent_group(step=generate_rnn_simple('rnn_forward'),
+ input=seq)),
+ first_seq(input=recurrent_group(step=generate_rnn_simple('rnn_back'),
+ input=seq, reverse=True)),
+ last_seq(input=recurrent_group(step=generate_rnn_simple(
+ 'rnn_subseq_forward'), input=SubsequenceInput(input=sub_seq))),
+ last_seq(input=lstmemory_group(input=lstm_param, size=100)),
+ last_seq(input=gru_group(input=gru_param, size=100)))
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_sequence_pooling.py b/python/paddle/trainer_config_helpers/tests/configs/test_sequence_pooling.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e24164b5578cfe2abdb2a0ad889d5fd0c3f6e57
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_sequence_pooling.py
@@ -0,0 +1,30 @@
+from paddle.trainer_config_helpers import *
+
+settings(
+ learning_rate=1e-4,
+ batch_size=1000
+)
+
+din = data_layer(name='dat_in', size=100)
+
+POOL_TYPE = [
+ MaxPooling,
+ AvgPooling,
+ SumPooling
+]
+
+AGG_LEVEL = [
+ AggregateLevel.EACH_SEQUENCE,
+ AggregateLevel.EACH_TIMESTEP
+]
+
+opts = []
+
+for pt in POOL_TYPE:
+ for al in AGG_LEVEL:
+ opts.append(pooling_layer(input=din, agg_level=al, pooling_type=pt()))
+
+opts.append(pooling_layer(input=din,
+ pooling_type=MaxPooling(output_max_index=True)))
+
+outputs(opts)
diff --git a/python/paddle/trainer_config_helpers/tests/configs/unused_layers.py b/python/paddle/trainer_config_helpers/tests/configs/unused_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6a3d09a4315ac3be63b8fecdb0fa359de834df0
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/unused_layers.py
@@ -0,0 +1,14 @@
+from paddle.trainer_config_helpers import *
+settings(
+ batch_size=1000,
+ learning_rate=1e-4
+)
+
+probs = data_layer(name='probs', size=100)
+
+outputs(
+ sampling_id_layer(input=probs), # It seems not support training
+
+ # It seems this layer is not correct, and should be rewrite.
+ # block_expand_layer(input=probs, channel=1, block_x=1, block_y=3),
+)
\ No newline at end of file
diff --git a/python/paddle/trainer_config_helpers/tests/configs/util_layers.py b/python/paddle/trainer_config_helpers/tests/configs/util_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..aadb3f3f5e7997a56a73724b2858346e5ae17179
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/util_layers.py
@@ -0,0 +1,15 @@
+from paddle.trainer_config_helpers import *
+
+settings(learning_rate=1e-4, batch_size=1000)
+
+a = data_layer(name='a', size=10)
+b = data_layer(name='b', size=10)
+
+result = addto_layer(input=[a, b])
+concat1 = concat_layer(input=[a, b])
+concat2 = concat_layer(input=[
+ identity_projection(input=a),
+ identity_projection(input=b)
+])
+
+outputs(result, concat1, concat2)
\ No newline at end of file
diff --git a/python/paddle/trainer_config_helpers/tests/layers_test_config.py b/python/paddle/trainer_config_helpers/tests/layers_test_config.py
index 27b22ecb701c52ab2a0a1f5f95d7b07186fbbb58..faaab9107d8fbf11afafb722075acbe986efe9fd 100644
--- a/python/paddle/trainer_config_helpers/tests/layers_test_config.py
+++ b/python/paddle/trainer_config_helpers/tests/layers_test_config.py
@@ -23,6 +23,15 @@ z = out_prod_layer(input1=x, input2=y)
x1 = fc_layer(input=x, size=5)
y1 = fc_layer(input=y, size=5)
+
+z1 = mixed_layer(act=LinearActivation(),
+ input=[conv_operator(img=x1,
+ filter=y1,
+ filter_size=1,
+ num_filters=5,
+ num_channel=5,
+ stride=1)])
+
y2 = fc_layer(input=y, size=15)
cos1 = cos_sim(a=x1, b=y1)
@@ -30,7 +39,7 @@ cos3 = cos_sim(a=x1, b=y2, size=3)
linear_comb = linear_comb_layer(weights=x1, vectors=y2, size=3)
-out = fc_layer(input=[cos1, cos3, linear_comb, z],
+out = fc_layer(input=[cos1, cos3, linear_comb, z, z1],
size=num_classes,
act=SoftmaxActivation())
@@ -38,11 +47,21 @@ print_layer(input=[out])
outputs(classification_cost(out, data_layer(name="label", size=num_classes)))
-dotmul = mixed_layer(input=[dotmul_operator(x=x1, y=y1),
- dotmul_projection(input=y1)])
+dotmul = mixed_layer(input=[dotmul_operator(a=x1, b=x1),
+ dotmul_projection(input=y1)])
+
+proj_with_attr_init = mixed_layer(input=full_matrix_projection(input=y1,
+ param_attr=ParamAttr(learning_rate = 0,
+ initial_mean = 0,
+ initial_std = 0)),
+ bias_attr = ParamAttr(initial_mean=0, initial_std=0, learning_rate=0),
+ act = LinearActivation(),
+ size = 5,
+ name='proj_with_attr_init')
+
# for ctc
-tmp = fc_layer(input=[x1, dotmul],
+tmp = fc_layer(input=[x1, dotmul, proj_with_attr_init],
size=num_classes + 1,
act=SoftmaxActivation())
ctc = ctc_layer(input=tmp,