diff --git a/CMakeLists.txt b/CMakeLists.txt index 4ba29d6bbcc4acf9538973562df55b823e6428ef..6aeef23330170a80348ad54e57dcfde9b9ba08eb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,6 +16,8 @@ cmake_minimum_required(VERSION 3.0) set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) +set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG") +set(CMAKE_C_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG") include(system) diff --git a/doc/design/evaluator.md b/doc/design/evaluator.md index a62d75ffef14962aec8c7587e172d78dfe0cb4be..11cc129d56905a9ee666da92fbe6f8559c6d325a 100644 --- a/doc/design/evaluator.md +++ b/doc/design/evaluator.md @@ -1,22 +1,22 @@ ## Evaluator Design -### The Problem +### Problem Statement -During training or serving, we provide the evaluation function to measure the model performance, e.g., accuracy, precision. In the operator based framework design, the data go through the network pipeline batch by batch. As a result, inside the operator, we only can calculate one minibatch metrics. We need to provide a mechanism to calculate the metrics for each N pass/batch the user wanted. +During training or inference, we provide an evaluation function to measure the model performance, such as accuracy and precision. In the operator-based framework design, the data passes through the network pipeline batch by batch. As a result, inside an operator, we can only calculate the metrics for one minibatch. Thus, we need to provide a mechanism to calculate the metrics over every N batches or passes, as the user specifies. ### Evaluator Design -Currently, every operation is expressed in the graph. we divide the evaluator process into three steps. +Currently, every operation is expressed in the graph. We divide the evaluator process into three steps. 1. Initialize the metric state and add it into the block. -2. Calculate the statistic of the metric state in every mini-batch. The single operator is only responsible for calculating necessary statistics for one mini-batch. For example, accuracy operator only calculate a minibatch data if run once. +2. Calculate the metrics of interest for every mini-batch. A single evaluator operator is only responsible for calculating the necessary statistics for one mini-batch. For example, the accuracy operator only calculates the accuracy of one minibatch per run. 3. Merge the mini-batch statistics to form the evaluation result for multiple mini-batches. When it comes to distributed training/Multi-GPU training, aggregate the value from different devices. ### Implementation -This design is shown in python API. -Each metric operator need to caculate the metric statistic and return the batch aware states, Python side responsible for accumulate the states for each pass. +This design is shown in the Python API. +Each metric operator needs to calculate its metric statistics and return the batch-aware states. The Python side is responsible for accumulating the states across each pass. 
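A minimal sketch of this accumulation pattern (the class and method names below are illustrative, not the actual fluid API):

```python
class AccuracyEvaluator(object):
    """Hypothetical example of the three-step evaluator pattern."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Step 1: initialize the metric state.
        self.num_correct = 0
        self.num_samples = 0

    def update(self, minibatch_correct, minibatch_size):
        # Step 2: accumulate the statistics of one mini-batch.
        self.num_correct += minibatch_correct
        self.num_samples += minibatch_size

    def eval(self):
        # Step 3: merge the mini-batch statistics into the pass-level metric.
        return float(self.num_correct) / max(self.num_samples, 1)
```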
```python diff --git a/paddle/gserver/tests/CMakeLists.txt b/paddle/gserver/tests/CMakeLists.txt index 24e6cae8e69557c42ed5d437edce101709ca3983..b578a906c2027a1169a0098b93f8d0742920f99d 100644 --- a/paddle/gserver/tests/CMakeLists.txt +++ b/paddle/gserver/tests/CMakeLists.txt @@ -1,5 +1,4 @@ # gserver package unittests - add_simple_unittest(test_LinearChainCRF) add_simple_unittest(test_RecurrentLayer) @@ -29,6 +28,26 @@ gserver_test(test_KmaxSeqScore) gserver_test(test_Expand) gserver_test(test_MaxPoolingWithMaskOutput) +set(PYTHON_PATH + ${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d + ${PADDLE_SOURCE_DIR}/python/:${PADDLE_SOURCE_DIR}/paddle/gserver/tests) +function(gserver_test_with_python TARGET) + add_unittest_without_exec(${TARGET} ${TARGET}.cpp) + add_test(NAME ${TARGET} + COMMAND ${PYTHON_PATH} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET} + WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/) +endfunction() + +gserver_test_with_python(test_PyDataProvider2) +if(WITH_PYTHON) + gserver_test_with_python(test_PyDataProvider) +endif() +if(NOT MOBILE_INFERENCE) + gserver_test_with_python(test_CompareTwoNets) + # TODO(yuyang18): There is a bug in test_RecurrentGradientMachine; I will fix it. + gserver_test_with_python(test_RecurrentGradientMachine) +endif() + ########## test_MKLDNN layers and activations ########## if(WITH_MKLDNN) add_unittest_without_exec(test_MKLDNN @@ -36,86 +55,43 @@ if(WITH_MKLDNN) MKLDNNTester.cpp LayerGradUtil.cpp) add_test(NAME test_MKLDNN - COMMAND .set_python_path.sh -d ${PADDLE_SOURCE_DIR}/python - ${CMAKE_CURRENT_BINARY_DIR}/test_MKLDNN + COMMAND ${PYTHON_PATH} ${CMAKE_CURRENT_BINARY_DIR}/test_MKLDNN WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) endif() -############## test_PyDataProvider ######################## -if(WITH_PYTHON) - add_unittest_without_exec(test_PyDataProvider - test_PyDataProvider.cpp) - - add_test(NAME test_PyDataProvider - COMMAND .set_python_path.sh -d ./gserver/tests:${PADDLE_SOURCE_DIR}/python/ ${CMAKE_CURRENT_BINARY_DIR}/test_PyDataProvider - WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) -endif() - ############### test_WarpCTCLayer ####################### if(NOT WITH_DOUBLE AND NOT MOBILE_INFERENCE) add_unittest_without_exec(test_WarpCTCLayer test_WarpCTCLayer.cpp) - add_test(NAME test_WarpCTCLayer COMMAND ${CMAKE_CURRENT_BINARY_DIR}/test_WarpCTCLayer --warpctc_dir=${WARPCTC_LIB_DIR} WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) endif() if(NOT MOBILE_INFERENCE) - ################## test_Evaluator ####################### + ################## test_Evaluator ############# add_unittest(test_Evaluator test_Evaluator.cpp) - ############### test_RecurrentGradientMachine ############### - # TODO(yuyang18): There is some bug in test_RecurrentGradientMachine - # I will fix it. 
- add_unittest_without_exec(test_RecurrentGradientMachine - test_RecurrentGradientMachine.cpp) - add_test(NAME test_RecurrentGradientMachine - COMMAND .set_python_path.sh -d - ${PADDLE_SOURCE_DIR}/python:${PADDLE_SOURCE_DIR}/paddle/gserver/tests - ${CMAKE_CURRENT_BINARY_DIR}/test_RecurrentGradientMachine - WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) - - ############### test_NetworkCompare ############### + ########### test_NetworkCompare ############### add_unittest_without_exec(test_NetworkCompare test_NetworkCompare.cpp) if(WITH_GPU) - add_test(NAME test_NetworkCompare - COMMAND .set_python_path.sh -d ${PADDLE_SOURCE_DIR}/python ${CMAKE_CURRENT_BINARY_DIR}/test_NetworkCompare --use_gpu=true - WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) + set(use_gpu true) else() - add_test(NAME test_NetworkCompare - COMMAND .set_python_path.sh -d ${PADDLE_SOURCE_DIR}/python ${CMAKE_CURRENT_BINARY_DIR}/test_NetworkCompare --use_gpu=false - WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) + set(use_gpu false) endif() + add_test(NAME test_NetworkCompare + COMMAND ${PYTHON_PATH} ${CMAKE_CURRENT_BINARY_DIR}/test_NetworkCompare --use_gpu=${use_gpu} + WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) - ################# test_CompareSparse ################## + ############ test_CompareSparse ################ add_unittest_without_exec(test_CompareSparse test_CompareSparse.cpp) if(NOT ON_TRAVIS) add_test(NAME test_CompareSparse - COMMAND ${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d - ${PADDLE_SOURCE_DIR}/python:${PADDLE_SOURCE_DIR}/paddle/gserver/tests - ./.set_port.sh -p port -n 6 - ${CMAKE_CURRENT_BINARY_DIR}/test_CompareSparse + COMMAND ${PYTHON_PATH} ./.set_port.sh -p port -n 6 + ${CMAKE_CURRENT_BINARY_DIR}/test_CompareSparse WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/) endif() - - ################ test_CompareTwoNets ###################### - add_unittest_without_exec(test_CompareTwoNets - test_CompareTwoNets.cpp) - add_test(NAME test_CompareTwoNets - COMMAND ${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d - ${PADDLE_SOURCE_DIR}/python:${PADDLE_SOURCE_DIR}/paddle/gserver/tests - ${CMAKE_CURRENT_BINARY_DIR}/test_CompareTwoNets - WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/) endif() - -################ test_PyDataProvider2 ###################### -add_unittest_without_exec(test_PyDataProvider2 - test_PyDataProvider2.cpp) -add_test(NAME test_PyDataProvider2 - COMMAND .set_python_path.sh -d ${PADDLE_SOURCE_DIR}/paddle/gserver/tests:${PADDLE_SOURCE_DIR}/python ${CMAKE_CURRENT_BINARY_DIR}/test_PyDataProvider2 - WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle -) diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc index 4cbb60f3fdab968e8c36d4fbad55fd3efc7b1d0d..fa8e5f2da8d3ba7fc218f3c836270139fc6c4882 100644 --- a/paddle/operators/lstm_op.cc +++ b/paddle/operators/lstm_op.cc @@ -181,7 +181,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Long-Short Term Memory (LSTM) Operator. -The defalut implementation is diagonal/peephole connection +The default implementation is diagonal/peephole connection (https://arxiv.org/pdf/1402.1128.pdf), the formula is as follows: $$ @@ -198,27 +198,27 @@ c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c_t} \\ h_t = o_t \odot act_h(c_t) $$ -where the W terms denote weight matrices (e.g. \f$W_{xi}\f$ is the matrix -of weights from the input gate to the input), \f$W_{ic}, W_{fc}, W_{oc}\f$ +where the W terms denote weight matrices (e.g. 
$W_{xi}$ is the matrix +of weights from the input gate to the input), $W_{ic}, W_{fc}, W_{oc}$ are diagonal weight matrices for peephole connections. In our implementation, we use vectors to represent these diagonal weight matrices. The b terms -denote bias vectors (\f$b_i\f$ is the input gate bias vector), \f$\sigma\f$ +denote bias vectors ($b_i$ is the input gate bias vector), $\sigma$ is the non-linear activation, such as the logistic sigmoid function, and -\f$i, f, o\f$ and \f$c\f$ are the input gate, forget gate, output gate, +$i, f, o$ and $c$ are the input gate, forget gate, output gate, and cell activation vectors, respectively, all of which have the same size as -the cell output activation vector \f$h\f$. +the cell output activation vector $h$. -The \f$\odot\f$ is the element-wise product of the vectors. \f$act_g\f$ and \f$act_h\f$ +The $\odot$ is the element-wise product of the vectors. $act_g$ and $act_h$ are the cell input and cell output activation functions and `tanh` is usually -used for them. \f$\tilde{c_t}\f$ is also called candidate hidden state, +used for them. $\tilde{c_t}$ is also called the candidate hidden state, which is computed based on the current input and the previous hidden state. -Set `use_peepholes` False to disable peephole connection -(http://www.bioinf.jku.at/publications/older/2604.pdf). The formula -is omitted here. +Set `use_peepholes` to False to disable the peephole connection. The formula +is omitted here; please refer to the paper +http://www.bioinf.jku.at/publications/older/2604.pdf for details. -Note that these \f$W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}\f$ -operations on the input \f$x_{t}\f$ are NOT included in this operator. +Note that these $W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}$ +operations on the input $x_{t}$ are NOT included in this operator. Users can choose to use a fully-connected operator before the LSTM operator. )DOC"); diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h index 721aa42c92f2926aabbc13d0a9027b2b4e573225..a78f548aafbcdb3834a1ed56f9b31143fae37386 100644 --- a/paddle/operators/lstm_op.h +++ b/paddle/operators/lstm_op.h @@ -73,15 +73,15 @@ class LSTMKernel : public framework::OpKernel { T* bias_data = const_cast(bias->data()); // the code style in LstmMetaValue will be updated later. - lstm_value.checkIg = bias_data + 4 * frame_size; - lstm_value.checkFg = lstm_value.checkIg + frame_size; - lstm_value.checkOg = lstm_value.checkFg + frame_size; + lstm_value.check_ig = bias_data + 4 * frame_size; + lstm_value.check_fg = lstm_value.check_ig + frame_size; + lstm_value.check_og = lstm_value.check_fg + frame_size; } else { - lstm_value.checkIg = nullptr; - lstm_value.checkFg = nullptr; - lstm_value.checkOg = nullptr; + lstm_value.check_ig = nullptr; + lstm_value.check_fg = nullptr; + lstm_value.check_og = nullptr; } - lstm_value.prevStateValue = nullptr; + lstm_value.prev_state_value = nullptr; Tensor ordered_c0; const size_t* order = batch_gate->lod()[2].data(); if (cell_t0) { @@ -90,7 +90,7 @@ class LSTMKernel : public framework::OpKernel { // to reorder. ReorderInitState(device_ctx, *cell_t0, order, &ordered_c0, true); - lstm_value.prevStateValue = ordered_c0.data(); + lstm_value.prev_state_value = ordered_c0.data(); } // Use the local variable as here. 
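To keep the pervasive renames readable, here is a minimal NumPy sketch of the peephole LSTM step that these kernels compute, using the new snake_case field names (`sigmoid` and `tanh` stand in for the configurable active_gate/active_node/active_state; this is an orientation aid, not the kernel itself):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(gate_value, prev_state, check_ig, check_fg, check_og):
    # gate_value packs the four pre-activations [in, ig, fg, og], each of
    # length frame_size, matching the layout of LstmMetaValue.gate_value.
    value_in, value_ig, value_fg, value_og = np.split(gate_value, 4)
    value_in = np.tanh(value_in)                          # active_node
    value_ig = sigmoid(value_ig + prev_state * check_ig)  # peephole, input gate
    value_fg = sigmoid(value_fg + prev_state * check_fg)  # peephole, forget gate
    state = value_in * value_ig + prev_state * value_fg
    value_og = sigmoid(value_og + state * check_og)       # peephole, output gate
    state_atv = np.tanh(state)                            # active_state
    output = value_og * state_atv
    return output, state, state_atv
```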
@@ -140,14 +140,14 @@ class LSTMKernel : public framework::OpKernel { static_cast(1.0)); } - lstm_value.gateValue = gate_t.data(); - lstm_value.outputValue = out_t.data(); - lstm_value.stateValue = cell_t.data(); - lstm_value.stateActiveValue = cell_pre_act_t.data(); + lstm_value.gate_value = gate_t.data(); + lstm_value.output_value = out_t.data(); + lstm_value.state_value = cell_t.data(); + lstm_value.state_active_value = cell_pre_act_t.data(); math::LstmUnitFunctor::compute(device_ctx, lstm_value, frame_size, cur_batch_size, gate_act, cell_act, cand_act); - lstm_value.prevStateValue = lstm_value.stateValue; + lstm_value.prev_state_value = lstm_value.state_value; } math::Batch2LoDTensorFunctor to_seq; @@ -214,13 +214,13 @@ class LSTMGradKernel : public framework::OpKernel { math::LstmMetaValue lstm_value; if (bias && ctx.Attr("use_peepholes")) { T* bias_data = const_cast(bias->data()); - lstm_value.checkIg = bias_data + 4 * frame_size; - lstm_value.checkFg = lstm_value.checkIg + frame_size; - lstm_value.checkOg = lstm_value.checkFg + frame_size; + lstm_value.check_ig = bias_data + 4 * frame_size; + lstm_value.check_fg = lstm_value.check_ig + frame_size; + lstm_value.check_og = lstm_value.check_fg + frame_size; } else { - lstm_value.checkIg = nullptr; - lstm_value.checkFg = nullptr; - lstm_value.checkOg = nullptr; + lstm_value.check_ig = nullptr; + lstm_value.check_fg = nullptr; + lstm_value.check_og = nullptr; } math::LstmMetaGrad lstm_grad; @@ -231,13 +231,13 @@ class LSTMGradKernel : public framework::OpKernel { } if (bias && bias_g && ctx.Attr("use_peepholes")) { T* bias_g_data = bias_g->data(); - lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size; - lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size; - lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size; + lstm_grad.check_ig_grad = bias_g_data + 4 * frame_size; + lstm_grad.check_fg_grad = lstm_grad.check_ig_grad + frame_size; + lstm_grad.check_og_grad = lstm_grad.check_fg_grad + frame_size; } else { - lstm_grad.checkIgGrad = nullptr; - lstm_grad.checkFgGrad = nullptr; - lstm_grad.checkOgGrad = nullptr; + lstm_grad.check_ig_grad = nullptr; + lstm_grad.check_fg_grad = nullptr; + lstm_grad.check_og_grad = nullptr; } math::LoDTensor2BatchFunctor to_batch; @@ -276,26 +276,26 @@ class LSTMGradKernel : public framework::OpKernel { Tensor gate = batch_gate->Slice(bstart, bend); Tensor cell = batch_cell.Slice(bstart, bend); Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend); - lstm_value.gateValue = gate.data(); - lstm_value.stateValue = cell.data(); - lstm_value.stateActiveValue = cell_pre_act.data(); + lstm_value.gate_value = gate.data(); + lstm_value.state_value = cell.data(); + lstm_value.state_active_value = cell_pre_act.data(); Tensor out_g = batch_hidden_g.Slice(bstart, bend); Tensor gate_g = batch_gate_g.Slice(bstart, bend); Tensor cell_g = batch_cell_g.Slice(bstart, bend); - lstm_grad.stateGrad = cell_g.data(); - lstm_grad.gateGrad = gate_g.data(); - lstm_grad.outputGrad = out_g.data(); + lstm_grad.state_grad = cell_g.data(); + lstm_grad.gate_grad = gate_g.data(); + lstm_grad.output_grad = out_g.data(); if (n > 0) { int bstart_pre = static_cast(batch_starts[n - 1]); Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart); Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart); - lstm_value.prevStateValue = cell_pre.data(); - lstm_grad.prevStateGrad = cell_pre_g.data(); + lstm_value.prev_state_value = cell_pre.data(); + lstm_grad.prev_state_grad = cell_pre_g.data(); } else { - lstm_value.prevStateValue = 
c0 ? ordered_c0.data() : nullptr; - lstm_grad.prevStateGrad = c0_g ? ordered_c0_g.data() : nullptr; + lstm_value.prev_state_value = c0 ? ordered_c0.data() : nullptr; + lstm_grad.prev_state_grad = c0_g ? ordered_c0_g.data() : nullptr; } int cur_batch_size = bend - bstart; diff --git a/paddle/operators/math/detail/lstm_cpu_kernel.h b/paddle/operators/math/detail/lstm_cpu_kernel.h index fc3ad0ce58aa1552ef7e717fb529c2d454b4895a..a734ad31eea4816e952641bad73776d93d8c8d34 100644 --- a/paddle/operators/math/detail/lstm_cpu_kernel.h +++ b/paddle/operators/math/detail/lstm_cpu_kernel.h @@ -26,278 +26,284 @@ namespace detail { template void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, - int frameSize, + int frame_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - T rValueIn; - T rValueIg; - T rValueFg; - T rValueOg; - T rCheckI; - T rCheckF; - T rCheckO; - T rState; - T rPrevState = 0; - T rStateAtv; - T rOut; - - T *valueIn = value.gateValue; - T *valueIg = value.gateValue + frameSize; - T *valueFg = value.gateValue + frameSize * 2; - T *valueOg = value.gateValue + frameSize * 3; - - for (int i = 0; i < frameSize; i++) { - rValueIn = valueIn[i]; - rValueIg = valueIg[i]; - rValueFg = valueFg[i]; - rValueOg = valueOg[i]; - rCheckI = value.checkIg ? value.checkIg[i] : 0; - rCheckF = value.checkFg ? value.checkFg[i] : 0; - rCheckO = value.checkOg ? value.checkOg[i] : 0; - - if (value.prevStateValue) { - rPrevState = value.prevStateValue[i]; + T r_value_in; + T r_value_ig; + T r_value_fg; + T r_value_og; + T r_checkI; + T r_checkF; + T r_checkO; + T r_state; + T r_prev_state = 0; + T r_state_atv; + T r_out; + + T *value_in = value.gate_value; + T *value_ig = value.gate_value + frame_size; + T *value_fg = value.gate_value + frame_size * 2; + T *value_og = value.gate_value + frame_size * 3; + + for (int i = 0; i < frame_size; i++) { + r_value_in = value_in[i]; + r_value_ig = value_ig[i]; + r_value_fg = value_fg[i]; + r_value_og = value_og[i]; + r_checkI = value.check_ig ? value.check_ig[i] : 0; + r_checkF = value.check_fg ? value.check_fg[i] : 0; + r_checkO = value.check_og ? 
value.check_og[i] : 0; + + if (value.prev_state_value) { + r_prev_state = value.prev_state_value[i]; } - op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, - rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state); - - valueIn[i] = rValueIn; - valueIg[i] = rValueIg; - valueFg[i] = rValueFg; - valueOg[i] = rValueOg; - value.stateValue[i] = rState; - value.stateActiveValue[i] = rStateAtv; - value.outputValue[i] = rOut; + op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state, + r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node, + active_gate, active_state); + + value_in[i] = r_value_in; + value_ig[i] = r_value_ig; + value_fg[i] = r_value_fg; + value_og[i] = r_value_og; + value.state_value[i] = r_state; + value.state_active_value[i] = r_state_atv; + value.output_value[i] = r_out; } } template void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, - LstmMetaGrad grad, int frameSize, + LstmMetaGrad grad, int frame_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - T rValueIn; - T rValueIg; - T rValueFg; - T rValueOg; - T rGradIn; - T rGradIg; - T rGradFg; - T rGradOg; - T rPrevState = 0; - T rPrevStateGrad; - T rState; - T rStateGrad; - T rStateAtv; - T rOutputGrad; - T rCheckI; - T rCheckF; - T rCheckO; - T rCheckIGrad; - T rCheckFGrad; - T rCheckOGrad; - - T *valueIn = value.gateValue; - T *valueIg = value.gateValue + frameSize; - T *valueFg = value.gateValue + frameSize * 2; - T *valueOg = value.gateValue + frameSize * 3; - T *gradIn = grad.gateGrad; - T *gradIg = grad.gateGrad + frameSize; - T *gradFg = grad.gateGrad + frameSize * 2; - T *gradOg = grad.gateGrad + frameSize * 3; - - for (int i = 0; i < frameSize; i++) { - rValueIn = valueIn[i]; - rValueIg = valueIg[i]; - rValueFg = valueFg[i]; - rValueOg = valueOg[i]; - rCheckI = value.checkIg ? value.checkIg[i] : 0; - rCheckF = value.checkFg ? value.checkFg[i] : 0; - rCheckO = value.checkOg ? value.checkOg[i] : 0; - rState = value.stateValue[i]; - rStateAtv = value.stateActiveValue[i]; - rOutputGrad = grad.outputGrad[i]; - rStateGrad = grad.stateGrad[i]; - if (value.prevStateValue) { - rPrevState = value.prevStateValue[i]; + T r_value_in; + T r_value_ig; + T r_value_fg; + T r_value_og; + T r_grad_in; + T r_grad_ig; + T r_grad_fg; + T r_grad_og; + T r_prev_state = 0; + T r_prev_state_grad; + T r_state; + T r_state_grad; + T r_state_atv; + T r_output_grad; + T r_checkI; + T r_checkF; + T r_checkO; + T r_checkIGrad; + T r_checkFGrad; + T r_checkOGrad; + + T *value_in = value.gate_value; + T *value_ig = value.gate_value + frame_size; + T *value_fg = value.gate_value + frame_size * 2; + T *value_og = value.gate_value + frame_size * 3; + T *grad_in = grad.gate_grad; + T *grad_ig = grad.gate_grad + frame_size; + T *grad_fg = grad.gate_grad + frame_size * 2; + T *grad_og = grad.gate_grad + frame_size * 3; + + for (int i = 0; i < frame_size; i++) { + r_value_in = value_in[i]; + r_value_ig = value_ig[i]; + r_value_fg = value_fg[i]; + r_value_og = value_og[i]; + r_checkI = value.check_ig ? value.check_ig[i] : 0; + r_checkF = value.check_fg ? value.check_fg[i] : 0; + r_checkO = value.check_og ? 
value.check_og[i] : 0; + r_state = value.state_value[i]; + r_state_atv = value.state_active_value[i]; + r_output_grad = grad.output_grad[i]; + r_state_grad = grad.state_grad[i]; + if (value.prev_state_value) { + r_prev_state = value.prev_state_value[i]; } - op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, - rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, - rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, - rCheckOGrad, active_node, active_gate, active_state); - - gradIn[i] = rGradIn; - gradIg[i] = rGradIg; - gradFg[i] = rGradFg; - gradOg[i] = rGradOg; - grad.stateGrad[i] = rStateGrad; - - if (grad.prevStateGrad) grad.prevStateGrad[i] = rPrevStateGrad; - if (value.prevStateValue) { - if (grad.checkIgGrad) grad.checkIgGrad[i] += rCheckIGrad; - if (grad.checkFgGrad) grad.checkFgGrad[i] += rCheckFGrad; + op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig, + r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state, + r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO, + r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate, + active_state); + + grad_in[i] = r_grad_in; + grad_ig[i] = r_grad_ig; + grad_fg[i] = r_grad_fg; + grad_og[i] = r_grad_og; + grad.state_grad[i] = r_state_grad; + + if (grad.prev_state_grad) grad.prev_state_grad[i] = r_prev_state_grad; + if (value.prev_state_value) { + if (grad.check_ig_grad) grad.check_ig_grad[i] += r_checkIGrad; + if (grad.check_fg_grad) grad.check_fg_grad[i] += r_checkFGrad; } - if (grad.checkOgGrad) grad.checkOgGrad[i] += rCheckOGrad; + if (grad.check_og_grad) grad.check_og_grad[i] += r_checkOGrad; } } template -void avx_lstm_forward_one_sequence(Op op, LstmMetaValue value, int frameSize, +void avx_lstm_forward_one_sequence(Op op, LstmMetaValue value, + int frame_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { #ifdef __AVX__ - __m256 rValueIn; - __m256 rValueIg; - __m256 rValueFg; - __m256 rValueOg; - __m256 rCheckI = _mm256_set1_ps(0.0f); - __m256 rCheckF = _mm256_set1_ps(0.0f); - __m256 rCheckO = _mm256_set1_ps(0.0f); - __m256 rState; - __m256 rPrevState = _mm256_set1_ps(0.0f); - __m256 rStateAtv; - __m256 rOut; - - __m256 *valueIn = (__m256 *)value.gateValue; - __m256 *valueIg = (__m256 *)(value.gateValue + frameSize); - __m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2); - __m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3); - - for (int i = 0; i < frameSize / 8; i++) { - rValueIn = valueIn[i]; - rValueIg = valueIg[i]; - rValueFg = valueFg[i]; - rValueOg = valueOg[i]; - if (value.checkIg) { - rCheckI = ((__m256 *)value.checkIg)[i]; - rCheckF = ((__m256 *)value.checkFg)[i]; - rCheckO = ((__m256 *)value.checkOg)[i]; + __m256 r_value_in; + __m256 r_value_ig; + __m256 r_value_fg; + __m256 r_value_og; + __m256 r_checkI = _mm256_set1_ps(0.0f); + __m256 r_checkF = _mm256_set1_ps(0.0f); + __m256 r_checkO = _mm256_set1_ps(0.0f); + __m256 r_state; + __m256 r_prev_state = _mm256_set1_ps(0.0f); + __m256 r_state_atv; + __m256 r_out; + + __m256 *value_in = (__m256 *)value.gate_value; + __m256 *value_ig = (__m256 *)(value.gate_value + frame_size); + __m256 *value_fg = (__m256 *)(value.gate_value + frame_size * 2); + __m256 *value_og = (__m256 *)(value.gate_value + frame_size * 3); + + for (int i = 0; i < frame_size / 8; i++) { + r_value_in = value_in[i]; + r_value_ig = value_ig[i]; + r_value_fg = value_fg[i]; + r_value_og = value_og[i]; + if (value.check_ig) { + r_checkI = 
((__m256 *)value.check_ig)[i]; + r_checkF = ((__m256 *)value.check_fg)[i]; + r_checkO = ((__m256 *)value.check_og)[i]; } - if (value.prevStateValue) { - rPrevState = ((__m256 *)value.prevStateValue)[i]; + if (value.prev_state_value) { + r_prev_state = ((__m256 *)value.prev_state_value)[i]; } - op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, - rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state); - - valueIn[i] = rValueIn; - valueIg[i] = rValueIg; - valueFg[i] = rValueFg; - valueOg[i] = rValueOg; - ((__m256 *)value.stateValue)[i] = rState; - ((__m256 *)value.stateActiveValue)[i] = rStateAtv; - ((__m256 *)value.outputValue)[i] = rOut; + op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state, + r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node, + active_gate, active_state); + + value_in[i] = r_value_in; + value_ig[i] = r_value_ig; + value_fg[i] = r_value_fg; + value_og[i] = r_value_og; + ((__m256 *)value.state_value)[i] = r_state; + ((__m256 *)value.state_active_value)[i] = r_state_atv; + ((__m256 *)value.output_value)[i] = r_out; } #endif } template void avx_lstm_backward_one_sequence(Op op, LstmMetaValue value, - LstmMetaGrad grad, int frameSize, + LstmMetaGrad grad, int frame_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { #ifdef __AVX__ - __m256 rValueIn; - __m256 rValueIg; - __m256 rValueFg; - __m256 rValueOg; - __m256 rGradIn; - __m256 rGradIg; - __m256 rGradFg; - __m256 rGradOg; - __m256 rPrevState = _mm256_set1_ps(0.0f); - __m256 rPrevStateGrad; - __m256 rStateGrad; - __m256 rState; - __m256 rStateAtv; - __m256 rOutputGrad; - __m256 rCheckI = _mm256_set1_ps(0.0f); - __m256 rCheckF = _mm256_set1_ps(0.0f); - __m256 rCheckO = _mm256_set1_ps(0.0f); - __m256 rCheckIGrad; - __m256 rCheckFGrad; - __m256 rCheckOGrad; - - __m256 *valueIn = (__m256 *)value.gateValue; - __m256 *valueIg = (__m256 *)(value.gateValue + frameSize); - __m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2); - __m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3); - __m256 *gradIn = (__m256 *)grad.gateGrad; - __m256 *gradIg = (__m256 *)(grad.gateGrad + frameSize); - __m256 *gradFg = (__m256 *)(grad.gateGrad + frameSize * 2); - __m256 *gradOg = (__m256 *)(grad.gateGrad + frameSize * 3); - - for (int i = 0; i < frameSize / 8; i++) { - rValueIn = valueIn[i]; - rValueIg = valueIg[i]; - rValueFg = valueFg[i]; - rValueOg = valueOg[i]; - if (value.checkIg) { - rCheckI = ((__m256 *)value.checkIg)[i]; - rCheckF = ((__m256 *)value.checkFg)[i]; - rCheckO = ((__m256 *)value.checkOg)[i]; + __m256 r_value_in; + __m256 r_value_ig; + __m256 r_value_fg; + __m256 r_value_og; + __m256 r_grad_in; + __m256 r_grad_ig; + __m256 r_grad_fg; + __m256 r_grad_og; + __m256 r_prev_state = _mm256_set1_ps(0.0f); + __m256 r_prev_state_grad; + __m256 r_state_grad; + __m256 r_state; + __m256 r_state_atv; + __m256 r_output_grad; + __m256 r_checkI = _mm256_set1_ps(0.0f); + __m256 r_checkF = _mm256_set1_ps(0.0f); + __m256 r_checkO = _mm256_set1_ps(0.0f); + __m256 r_checkIGrad; + __m256 r_checkFGrad; + __m256 r_checkOGrad; + + __m256 *value_in = (__m256 *)value.gate_value; + __m256 *value_ig = (__m256 *)(value.gate_value + frame_size); + __m256 *value_fg = (__m256 *)(value.gate_value + frame_size * 2); + __m256 *value_og = (__m256 *)(value.gate_value + frame_size * 3); + __m256 *grad_in = (__m256 *)grad.gate_grad; + __m256 *grad_ig = (__m256 *)(grad.gate_grad + frame_size); + __m256 *grad_fg = (__m256 
*)(grad.gate_grad + frame_size * 2); + __m256 *grad_og = (__m256 *)(grad.gate_grad + frame_size * 3); + + for (int i = 0; i < frame_size / 8; i++) { + r_value_in = value_in[i]; + r_value_ig = value_ig[i]; + r_value_fg = value_fg[i]; + r_value_og = value_og[i]; + if (value.check_ig) { + r_checkI = ((__m256 *)value.check_ig)[i]; + r_checkF = ((__m256 *)value.check_fg)[i]; + r_checkO = ((__m256 *)value.check_og)[i]; } - rState = ((__m256 *)value.stateValue)[i]; - rStateAtv = ((__m256 *)value.stateActiveValue)[i]; - rOutputGrad = ((__m256 *)grad.outputGrad)[i]; - rStateGrad = ((__m256 *)grad.stateGrad)[i]; - if (value.prevStateValue) { - rPrevState = ((__m256 *)value.prevStateValue)[i]; + r_state = ((__m256 *)value.state_value)[i]; + r_state_atv = ((__m256 *)value.state_active_value)[i]; + r_output_grad = ((__m256 *)grad.output_grad)[i]; + r_state_grad = ((__m256 *)grad.state_grad)[i]; + if (value.prev_state_value) { + r_prev_state = ((__m256 *)value.prev_state_value)[i]; } - op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, - rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, - rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, - rCheckOGrad, active_node, active_gate, active_state); - - gradIn[i] = rGradIn; - gradIg[i] = rGradIg; - gradFg[i] = rGradFg; - gradOg[i] = rGradOg; - ((__m256 *)grad.stateGrad)[i] = rStateGrad; - - if (grad.prevStateGrad) ((__m256 *)grad.prevStateGrad)[i] = rPrevStateGrad; - if (value.prevStateValue) { - if (grad.checkIgGrad) ((__m256 *)grad.checkIgGrad)[i] += rCheckIGrad; - if (grad.checkFgGrad) ((__m256 *)grad.checkFgGrad)[i] += rCheckFGrad; + op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig, + r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state, + r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO, + r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate, + active_state); + + grad_in[i] = r_grad_in; + grad_ig[i] = r_grad_ig; + grad_fg[i] = r_grad_fg; + grad_og[i] = r_grad_og; + ((__m256 *)grad.state_grad)[i] = r_state_grad; + + if (grad.prev_state_grad) + ((__m256 *)grad.prev_state_grad)[i] = r_prev_state_grad; + if (value.prev_state_value) { + if (grad.check_ig_grad) ((__m256 *)grad.check_ig_grad)[i] += r_checkIGrad; + if (grad.check_fg_grad) ((__m256 *)grad.check_fg_grad)[i] += r_checkFGrad; } - if (grad.checkOgGrad) ((__m256 *)grad.checkOgGrad)[i] += rCheckOGrad; + if (grad.check_og_grad) ((__m256 *)grad.check_og_grad)[i] += r_checkOGrad; } #endif } template -void cpu_lstm_forward(Op op, LstmMetaValue value, int frameSize, +void cpu_lstm_forward(Op op, LstmMetaValue value, int frame_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same::value)) { - avx_lstm_forward_one_sequence(op, value, frameSize, active_node, + if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same::value)) { + avx_lstm_forward_one_sequence(op, value, frame_size, active_node, active_gate, active_state); } else { - naive_lstm_forward_one_sequence(op, value, frameSize, active_node, + naive_lstm_forward_one_sequence(op, value, frame_size, active_node, active_gate, active_state); } } template void cpu_lstm_backward(Op op, LstmMetaValue value, LstmMetaGrad grad, - int frameSize, activation_mode_t active_node, + int frame_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - if (Op::avx && !(frameSize & (8 - 1)) && 
(std::is_same::value)) { - avx_lstm_backward_one_sequence(op, value, grad, frameSize, active_node, + if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same::value)) { + avx_lstm_backward_one_sequence(op, value, grad, frame_size, active_node, active_gate, active_state); } else { - naive_lstm_backward_one_sequence(op, value, grad, frameSize, active_node, - active_gate, active_state); + naive_lstm_backward_one_sequence(op, value, grad, frame_size, + active_node, active_gate, active_state); } } diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h index d138bbe411f69929a14ad19af3e84824ac7a5d58..91bfedea53a2600156c9025f6ff3615d695a712b 100644 --- a/paddle/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/operators/math/detail/lstm_gpu_kernel.h @@ -26,189 +26,192 @@ namespace math { namespace detail { /* - * threads(framePerBlock, batchPerBlock) - * grid(frameBlocks, batchBlocks) + * threads(frame_per_block, batch_per_block) + * grid(frame_blocks, batch_blocks) */ -template -__global__ void KeLstmForward(Op op, LstmMetaValue value, int frameSize, - int batchSize, activation_mode_t active_node, +template +__global__ void KeLstmForward(Op op, LstmMetaValue value, int frame_size, + int batch_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; - if (frameIdx >= frameSize) return; - - int batchIdx = 0; - if (isBatch) { - batchIdx = blockIdx.y * blockDim.y + threadIdx.y; - if (batchIdx >= batchSize) return; - value.gateValue += batchIdx * frameSize * 4; - value.outputValue += batchIdx * frameSize; - value.stateValue += batchIdx * frameSize; - value.stateActiveValue += batchIdx * frameSize; + const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (frame_idx >= frame_size) return; + + int batch_idx = 0; + if (is_batch) { + batch_idx = blockIdx.y * blockDim.y + threadIdx.y; + if (batch_idx >= batch_size) return; + value.gate_value += batch_idx * frame_size * 4; + value.output_value += batch_idx * frame_size; + value.state_value += batch_idx * frame_size; + value.state_active_value += batch_idx * frame_size; } - T rState; - T rPrevState = 0; - T rStateAtv; - T rOut; - T rValueIn; - T rValueIg; - T rValueFg; - T rValueOg; - - T rCheckI = value.checkIg ? value.checkIg[frameIdx] : 0; - T rCheckF = value.checkFg ? value.checkFg[frameIdx] : 0; - T rCheckO = value.checkOg ? value.checkOg[frameIdx] : 0; - - rValueIn = value.gateValue[frameIdx]; - rValueIg = value.gateValue[frameIdx + frameSize]; - rValueFg = value.gateValue[frameIdx + frameSize * 2]; - rValueOg = value.gateValue[frameIdx + frameSize * 3]; - - if (value.prevStateValue) { - if (isBatch) value.prevStateValue += batchIdx * frameSize; - rPrevState = value.prevStateValue[frameIdx]; + T r_state; + T r_prev_state = 0; + T r_state_atv; + T r_out; + T r_value_in; + T r_value_ig; + T r_value_fg; + T r_value_og; + + T r_checkI = value.check_ig ? value.check_ig[frame_idx] : 0; + T r_checkF = value.check_fg ? value.check_fg[frame_idx] : 0; + T r_checkO = value.check_og ? 
value.check_og[frame_idx] : 0; + + r_value_in = value.gate_value[frame_idx]; + r_value_ig = value.gate_value[frame_idx + frame_size]; + r_value_fg = value.gate_value[frame_idx + frame_size * 2]; + r_value_og = value.gate_value[frame_idx + frame_size * 3]; + + if (value.prev_state_value) { + if (is_batch) value.prev_state_value += batch_idx * frame_size; + r_prev_state = value.prev_state_value[frame_idx]; } - op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, - rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state); + op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state, + r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node, active_gate, + active_state); - value.gateValue[frameIdx] = rValueIn; - value.gateValue[frameIdx + frameSize] = rValueIg; - value.gateValue[frameIdx + frameSize * 2] = rValueFg; - value.gateValue[frameIdx + frameSize * 3] = rValueOg; + value.gate_value[frame_idx] = r_value_in; + value.gate_value[frame_idx + frame_size] = r_value_ig; + value.gate_value[frame_idx + frame_size * 2] = r_value_fg; + value.gate_value[frame_idx + frame_size * 3] = r_value_og; - value.stateValue[frameIdx] = rState; - value.stateActiveValue[frameIdx] = rStateAtv; - value.outputValue[frameIdx] = rOut; + value.state_value[frame_idx] = r_state; + value.state_active_value[frame_idx] = r_state_atv; + value.output_value[frame_idx] = r_out; } /* - * threads(framePerBlock, batchPerBlock) - * grid(frameBlocks, batchBlocks) + * threads(frame_per_block, batch_per_block) + * grid(frame_blocks, batch_blocks) */ -template +template __global__ void KeLstmBackward(Op op, LstmMetaValue value, - LstmMetaGrad grad, int frameSize, - int batchSize, activation_mode_t active_node, + LstmMetaGrad grad, int frame_size, + int batch_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; - if (frameIdx >= frameSize) return; - - int batchIdx = 0; - if (isBatch) { - batchIdx = blockIdx.y * blockDim.y + threadIdx.y; - if (batchIdx >= batchSize) return; - value.gateValue += batchIdx * frameSize * 4; - value.stateValue += batchIdx * frameSize; - value.stateActiveValue += batchIdx * frameSize; - grad.gateGrad += batchIdx * frameSize * 4; - grad.stateGrad += batchIdx * frameSize; - grad.outputGrad += batchIdx * frameSize; + const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (frame_idx >= frame_size) return; + + int batch_idx = 0; + if (is_batch) { + batch_idx = blockIdx.y * blockDim.y + threadIdx.y; + if (batch_idx >= batch_size) return; + value.gate_value += batch_idx * frame_size * 4; + value.state_value += batch_idx * frame_size; + value.state_active_value += batch_idx * frame_size; + grad.gate_grad += batch_idx * frame_size * 4; + grad.state_grad += batch_idx * frame_size; + grad.output_grad += batch_idx * frame_size; } - T rValueIn; - T rValueIg; - T rValueFg; - T rValueOg; - T rGradIn; - T rGradIg; - T rGradFg; - T rGradOg; - T rPrevState = 0; - T rPrevStateGrad; - T rState; - T rStateGrad; - T rStateAtv; - T rOutputGrad; - T rCheckI = value.checkIg ? value.checkIg[frameIdx] : 0; - T rCheckF = value.checkFg ? value.checkFg[frameIdx] : 0; - T rCheckO = value.checkOg ? 
value.checkOg[frameIdx] : 0; - - T rCheckIGrad; - T rCheckFGrad; - T rCheckOGrad; - - rValueIn = value.gateValue[frameIdx]; - rValueIg = value.gateValue[frameIdx + frameSize]; - rValueFg = value.gateValue[frameIdx + frameSize * 2]; - rValueOg = value.gateValue[frameIdx + frameSize * 3]; - rState = value.stateValue[frameIdx]; - rStateAtv = value.stateActiveValue[frameIdx]; - rOutputGrad = grad.outputGrad[frameIdx]; - rStateGrad = grad.stateGrad[frameIdx]; - - if (value.prevStateValue) { - if (isBatch) value.prevStateValue += batchIdx * frameSize; - rPrevState = value.prevStateValue[frameIdx]; + T r_value_in; + T r_value_ig; + T r_value_fg; + T r_value_og; + T r_grad_in; + T r_grad_ig; + T r_grad_fg; + T r_grad_og; + T r_prev_state = 0; + T r_prev_state_grad; + T r_state; + T r_state_grad; + T r_state_atv; + T r_output_grad; + T r_checkI = value.check_ig ? value.check_ig[frame_idx] : 0; + T r_checkF = value.check_fg ? value.check_fg[frame_idx] : 0; + T r_checkO = value.check_og ? value.check_og[frame_idx] : 0; + + T r_checkIGrad; + T r_checkFGrad; + T r_checkOGrad; + + r_value_in = value.gate_value[frame_idx]; + r_value_ig = value.gate_value[frame_idx + frame_size]; + r_value_fg = value.gate_value[frame_idx + frame_size * 2]; + r_value_og = value.gate_value[frame_idx + frame_size * 3]; + r_state = value.state_value[frame_idx]; + r_state_atv = value.state_active_value[frame_idx]; + r_output_grad = grad.output_grad[frame_idx]; + r_state_grad = grad.state_grad[frame_idx]; + + if (value.prev_state_value) { + if (is_batch) value.prev_state_value += batch_idx * frame_size; + r_prev_state = value.prev_state_value[frame_idx]; } - op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, - rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, - rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad, - active_node, active_gate, active_state); - - grad.gateGrad[frameIdx] = rGradIn; - grad.gateGrad[frameIdx + frameSize] = rGradIg; - grad.gateGrad[frameIdx + frameSize * 2] = rGradFg; - grad.gateGrad[frameIdx + frameSize * 3] = rGradOg; - grad.stateGrad[frameIdx] = rStateGrad; - if (grad.prevStateGrad) { - if (isBatch) grad.prevStateGrad += batchIdx * frameSize; - grad.prevStateGrad[frameIdx] = rPrevStateGrad; + op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig, + r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state, + r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO, + r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate, + active_state); + + grad.gate_grad[frame_idx] = r_grad_in; + grad.gate_grad[frame_idx + frame_size] = r_grad_ig; + grad.gate_grad[frame_idx + frame_size * 2] = r_grad_fg; + grad.gate_grad[frame_idx + frame_size * 3] = r_grad_og; + grad.state_grad[frame_idx] = r_state_grad; + if (grad.prev_state_grad) { + if (is_batch) grad.prev_state_grad += batch_idx * frame_size; + grad.prev_state_grad[frame_idx] = r_prev_state_grad; } - if (isBatch) { - if (value.prevStateValue) { - if (grad.checkIgGrad) - paddle::platform::CudaAtomicAdd(grad.checkIgGrad + frameIdx, - rCheckIGrad); - if (grad.checkFgGrad) - paddle::platform::CudaAtomicAdd(grad.checkFgGrad + frameIdx, - rCheckFGrad); + if (is_batch) { + if (value.prev_state_value) { + if (grad.check_ig_grad) + paddle::platform::CudaAtomicAdd(grad.check_ig_grad + frame_idx, + r_checkIGrad); + if (grad.check_fg_grad) + paddle::platform::CudaAtomicAdd(grad.check_fg_grad + frame_idx, + r_checkFGrad); } - if (grad.checkOgGrad) - 
paddle::platform::CudaAtomicAdd(grad.checkOgGrad + frameIdx, rCheckOGrad); + if (grad.check_og_grad) + paddle::platform::CudaAtomicAdd(grad.check_og_grad + frame_idx, + r_checkOGrad); } else { - if (value.prevStateValue) { - if (grad.checkIgGrad) grad.checkIgGrad[frameIdx] += rCheckIGrad; - if (grad.checkFgGrad) grad.checkFgGrad[frameIdx] += rCheckFGrad; + if (value.prev_state_value) { + if (grad.check_ig_grad) grad.check_ig_grad[frame_idx] += r_checkIGrad; + if (grad.check_fg_grad) grad.check_fg_grad[frame_idx] += r_checkFGrad; } - if (grad.checkOgGrad) grad.checkOgGrad[frameIdx] += rCheckOGrad; + if (grad.check_og_grad) grad.check_og_grad[frame_idx] += r_checkOGrad; } } template void gpu_lstm_forward(const platform::DeviceContext& context, Op op, - LstmMetaValue value, int frameSize, int batchSize, + LstmMetaValue value, int frame_size, int batch_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { dim3 threads; dim3 grid; - if (batchSize == 1) { - int framePerBlock = frameSize <= 1024 ? frameSize : 1024; - int frameBlocks = (frameSize + 1024 - 1) / 1024; - threads = dim3(framePerBlock, 1); - grid = dim3(frameBlocks, 1); + if (batch_size == 1) { + int frame_per_block = frame_size <= 1024 ? frame_size : 1024; + int frame_blocks = (frame_size + 1024 - 1) / 1024; + threads = dim3(frame_per_block, 1); + grid = dim3(frame_blocks, 1); } else { - /* framePerBlock = 32 batchPerBlock = 32 */ + /* frame_per_block = 32 batch_per_block = 32 */ threads = dim3(32, 32); - grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); + grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32); } auto stream = reinterpret_cast(context).stream(); - if (batchSize == 1) { + if (batch_size == 1) { KeLstmForward<<>>( - op, value, frameSize, batchSize, active_node, active_gate, + /* is_batch= */ false><<>>( + op, value, frame_size, batch_size, active_node, active_gate, active_state); } else { KeLstmForward<<>>( - op, value, frameSize, batchSize, active_node, active_gate, + /* is_batch= */ true><<>>( + op, value, frame_size, batch_size, active_node, active_gate, active_state); } } @@ -216,34 +219,34 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, template void gpu_lstm_backward(const platform::DeviceContext& context, Op op, LstmMetaValue value, LstmMetaGrad grad, - int frameSize, int batchSize, + int frame_size, int batch_size, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { dim3 threads; dim3 grid; - if (batchSize == 1) { - int framePerBlock = frameSize <= 1024 ? frameSize : 1024; - int frameBlocks = (frameSize + 1024 - 1) / 1024; - threads = dim3(framePerBlock, 1); - grid = dim3(frameBlocks, 1); + if (batch_size == 1) { + int frame_per_block = frame_size <= 1024 ? 
frame_size : 1024; + int frame_blocks = (frame_size + 1024 - 1) / 1024; + threads = dim3(frame_per_block, 1); + grid = dim3(frame_blocks, 1); } else { - /* framePerBlock = 32 batchPerBlock = 16 */ + /* frame_per_block = 32 batch_per_block = 16 */ threads = dim3(32, 16); - grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 16 - 1) / 16); + grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 16 - 1) / 16); } auto stream = reinterpret_cast(context).stream(); - if (batchSize == 1) { + if (batch_size == 1) { KeLstmBackward<<>>( - op, value, grad, frameSize, batchSize, active_node, active_gate, + /* is_batch= */ false><<>>( + op, value, grad, frame_size, batch_size, active_node, active_gate, active_state); } else { KeLstmBackward<<>>( - op, value, grad, frameSize, batchSize, active_node, active_gate, + /* is_batch= */ true><<>>( + op, value, grad, frame_size, batch_size, active_node, active_gate, active_state); } } diff --git a/paddle/operators/math/detail/lstm_kernel.h b/paddle/operators/math/detail/lstm_kernel.h index 9daaf91981a8e0252374f528f0e063111bd32675..78f9a249a3d5d413452952edf990975c02f1a369 100644 --- a/paddle/operators/math/detail/lstm_kernel.h +++ b/paddle/operators/math/detail/lstm_kernel.h @@ -27,19 +27,19 @@ namespace forward { template class lstm { public: - HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, - T &prevState, T &state, T &stateAtv, T &output, + HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og, + T &prev_state, T &state, T &state_atv, T &output, T &checkI, T &checkF, T &checkO, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - valueIn = activation(valueIn, active_node); - valueIg = activation(valueIg + prevState * checkI, active_gate); - valueFg = activation(valueFg + prevState * checkF, active_gate); - state = valueIn * valueIg + prevState * valueFg; - valueOg = activation(valueOg + state * checkO, active_gate); - stateAtv = activation(state, active_state); - output = valueOg * stateAtv; + value_in = activation(value_in, active_node); + value_ig = activation(value_ig + prev_state * checkI, active_gate); + value_fg = activation(value_fg + prev_state * checkF, active_gate); + state = value_in * value_ig + prev_state * value_fg; + value_og = activation(value_og + state * checkO, active_gate); + state_atv = activation(state, active_state); + output = value_og * state_atv; } #ifndef __NVCC__ #ifndef __AVX__ // If not compiled with AVX instructs. 
Disable AVX by default @@ -48,24 +48,27 @@ class lstm { // Only float support AVX optimization static const bool avx = std::is_same::value; - HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg, - __m256 &valueOg, __m256 &prevState, __m256 &state, - __m256 &stateAtv, __m256 &output, __m256 &checkI, + HOSTDEVICE void operator()(__m256 &value_in, __m256 &value_ig, + __m256 &value_fg, __m256 &value_og, + __m256 &prev_state, __m256 &state, + __m256 &state_atv, __m256 &output, __m256 &checkI, __m256 &checkF, __m256 &checkO, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - valueIn = activation(valueIn, active_node); - valueIg = activation( - _mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)), active_gate); - valueFg = activation( - _mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)), active_gate); - state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg), - _mm256_mul_ps(prevState, valueFg)); - valueOg = activation(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO)), - active_gate); - stateAtv = activation(state, active_state); - output = _mm256_mul_ps(valueOg, stateAtv); + value_in = activation(value_in, active_node); + value_ig = + activation(_mm256_add_ps(value_ig, _mm256_mul_ps(prev_state, checkI)), + active_gate); + value_fg = + activation(_mm256_add_ps(value_fg, _mm256_mul_ps(prev_state, checkF)), + active_gate); + state = _mm256_add_ps(_mm256_mul_ps(value_in, value_ig), + _mm256_mul_ps(prev_state, value_fg)); + value_og = activation(_mm256_add_ps(value_og, _mm256_mul_ps(state, checkO)), + active_gate); + state_atv = activation(state, active_state); + output = _mm256_mul_ps(value_og, state_atv); } #endif #endif @@ -78,25 +81,26 @@ namespace backward { template class lstm { public: - HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, - T &gradIn, T &gradIg, T &gradFg, T &gradOg, - T &prevState, T &prevStateGrad, T &state, - T &stateGrad, T &stateAtv, T &outputGrad, + HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og, + T &grad_in, T &grad_ig, T &grad_fg, T &grad_og, + T &prev_state, T &prev_state_grad, T &state, + T &state_grad, T &state_atv, T &output_grad, T &checkI, T &checkF, T &checkO, T &checkIGrad, T &checkFGrad, T &checkOGrad, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - gradOg = activation(outputGrad * stateAtv, valueOg, active_gate); - stateGrad += activation(outputGrad * valueOg, stateAtv, active_state) + - gradOg * checkO; - gradIn = activation(stateGrad * valueIg, valueIn, active_node); - gradIg = activation(stateGrad * valueIn, valueIg, active_gate); - gradFg = activation(stateGrad * prevState, valueFg, active_gate); - prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg; - checkIGrad = gradIg * prevState; - checkFGrad = gradFg * prevState; - checkOGrad = gradOg * state; + grad_og = activation(output_grad * state_atv, value_og, active_gate); + state_grad += activation(output_grad * value_og, state_atv, active_state) + + grad_og * checkO; + grad_in = activation(state_grad * value_ig, value_in, active_node); + grad_ig = activation(state_grad * value_in, value_ig, active_gate); + grad_fg = activation(state_grad * prev_state, value_fg, active_gate); + prev_state_grad = + grad_ig * checkI + grad_fg * checkF + state_grad * value_fg; + checkIGrad = grad_ig * prev_state; + checkFGrad = grad_fg * prev_state; + checkOGrad = grad_og * state; } #ifndef __NVCC__ #ifndef __AVX__ // If 
not compiled with AVX instructs. Disable AVX by default @@ -105,32 +109,32 @@ class lstm { // Only float support AVX optimization static const bool avx = std::is_same::value; HOSTDEVICE void operator()( - __m256 &valueIn, __m256 &valueIg, __m256 &valueFg, __m256 &valueOg, - __m256 &gradIn, __m256 &gradIg, __m256 &gradFg, __m256 &gradOg, - __m256 &prevState, __m256 &prevStateGrad, __m256 &state, - __m256 &stateGrad, __m256 &stateAtv, __m256 &outputGrad, __m256 &checkI, - __m256 &checkF, __m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad, - __m256 &checkOGrad, activation_mode_t active_node, + __m256 &value_in, __m256 &value_ig, __m256 &value_fg, __m256 &value_og, + __m256 &grad_in, __m256 &grad_ig, __m256 &grad_fg, __m256 &grad_og, + __m256 &prev_state, __m256 &prev_state_grad, __m256 &state, + __m256 &state_grad, __m256 &state_atv, __m256 &output_grad, + __m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad, + __m256 &checkFGrad, __m256 &checkOGrad, activation_mode_t active_node, activation_mode_t active_gate, activation_mode_t active_state) { - gradOg = - activation(_mm256_mul_ps(outputGrad, stateAtv), valueOg, active_gate); - stateGrad = _mm256_add_ps( - activation(_mm256_mul_ps(outputGrad, valueOg), stateAtv, active_state), - stateGrad); - stateGrad = _mm256_add_ps(_mm256_mul_ps(gradOg, checkO), stateGrad); - gradIn = - activation(_mm256_mul_ps(stateGrad, valueIg), valueIn, active_node); - gradIg = - activation(_mm256_mul_ps(stateGrad, valueIn), valueIg, active_gate); - gradFg = - activation(_mm256_mul_ps(stateGrad, prevState), valueFg, active_gate); - prevStateGrad = _mm256_add_ps(_mm256_mul_ps(gradIg, checkI), - _mm256_mul_ps(gradFg, checkF)); - prevStateGrad = - _mm256_add_ps(_mm256_mul_ps(stateGrad, valueFg), prevStateGrad); - checkIGrad = _mm256_mul_ps(gradIg, prevState); - checkFGrad = _mm256_mul_ps(gradFg, prevState); - checkOGrad = _mm256_mul_ps(gradOg, state); + grad_og = activation(_mm256_mul_ps(output_grad, state_atv), value_og, + active_gate); + state_grad = _mm256_add_ps(activation(_mm256_mul_ps(output_grad, value_og), + state_atv, active_state), + state_grad); + state_grad = _mm256_add_ps(_mm256_mul_ps(grad_og, checkO), state_grad); + grad_in = + activation(_mm256_mul_ps(state_grad, value_ig), value_in, active_node); + grad_ig = + activation(_mm256_mul_ps(state_grad, value_in), value_ig, active_gate); + grad_fg = activation(_mm256_mul_ps(state_grad, prev_state), value_fg, + active_gate); + prev_state_grad = _mm256_add_ps(_mm256_mul_ps(grad_ig, checkI), + _mm256_mul_ps(grad_fg, checkF)); + prev_state_grad = + _mm256_add_ps(_mm256_mul_ps(state_grad, value_fg), prev_state_grad); + checkIGrad = _mm256_mul_ps(grad_ig, prev_state); + checkFGrad = _mm256_mul_ps(grad_fg, prev_state); + checkOGrad = _mm256_mul_ps(grad_og, state); } #endif #endif diff --git a/paddle/operators/math/lstm_compute.cc b/paddle/operators/math/lstm_compute.cc index 0febf8e3b70111d12f858cf6259a2801a42d9a90..ad3a59bcdbe65a4107db4e56a0870794914a0fd8 100644 --- a/paddle/operators/math/lstm_compute.cc +++ b/paddle/operators/math/lstm_compute.cc @@ -30,12 +30,12 @@ struct LstmUnitFunctor { detail::cpu_lstm_forward(detail::forward::lstm(), value, frame_size, ActiveType(cand_act), ActiveType(gate_act), ActiveType(cell_act)); - value.gateValue += frame_size * 4; - value.stateValue += frame_size; - value.stateActiveValue += frame_size; - value.outputValue += frame_size; - if (value.prevStateValue) { - value.prevStateValue += frame_size; + value.gate_value += frame_size * 4; + value.state_value += 
frame_size; + value.state_active_value += frame_size; + value.output_value += frame_size; + if (value.prev_state_value) { + value.prev_state_value += frame_size; } } } @@ -53,20 +53,20 @@ struct LstmUnitGradFunctor { frame_size, ActiveType(cand_act), ActiveType(gate_act), ActiveType(cell_act)); - value.gateValue += frame_size * 4; - value.stateValue += frame_size; - value.stateActiveValue += frame_size; - value.outputValue += frame_size; - if (value.prevStateValue) { - value.prevStateValue += frame_size; + value.gate_value += frame_size * 4; + value.state_value += frame_size; + value.state_active_value += frame_size; + value.output_value += frame_size; + if (value.prev_state_value) { + value.prev_state_value += frame_size; } - grad.gateGrad += frame_size * 4; - grad.stateGrad += frame_size; - grad.stateActiveGrad += frame_size; - grad.outputGrad += frame_size; - if (grad.prevStateGrad) { - grad.prevStateGrad += frame_size; + grad.gate_grad += frame_size * 4; + grad.state_grad += frame_size; + grad.state_active_grad += frame_size; + grad.output_grad += frame_size; + if (grad.prev_state_grad) { + grad.prev_state_grad += frame_size; } } } diff --git a/paddle/operators/math/lstm_compute.h b/paddle/operators/math/lstm_compute.h index 28d2c6fd3b0d8143da90c37f241072e37397f98b..9652399d4c149c5c186a0780f2b8bf2294bad978 100644 --- a/paddle/operators/math/lstm_compute.h +++ b/paddle/operators/math/lstm_compute.h @@ -31,26 +31,26 @@ typedef enum { template struct LstmMetaValue { - T *gateValue; - T *prevStateValue; - T *stateValue; - T *stateActiveValue; - T *outputValue; - T *checkIg; - T *checkFg; - T *checkOg; + T *gate_value; + T *prev_state_value; + T *state_value; + T *state_active_value; + T *output_value; + T *check_ig; + T *check_fg; + T *check_og; }; template struct LstmMetaGrad { - T *gateGrad; - T *prevStateGrad; - T *stateGrad; - T *stateActiveGrad; - T *outputGrad; - T *checkIgGrad; - T *checkFgGrad; - T *checkOgGrad; + T *gate_grad; + T *prev_state_grad; + T *state_grad; + T *state_active_grad; + T *output_grad; + T *check_ig_grad; + T *check_fg_grad; + T *check_og_grad; }; inline activation_mode_t ActiveType(const std::string &type) { diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index a54dc0d9fdb3c30391b01966ad493540c8ad1375..fd55f410d3f0fee418e7efffa927e46c38d23a07 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -5,4 +5,6 @@ if(WITH_PYTHON) ${GLOB_OP_LIB}) endif(WITH_PYTHON) -cc_binary(print_operators_doc SRCS print_operators_doc.cc DEPS ${GLOB_OP_LIB}) +if(WITH_DOC) + cc_binary(print_operators_doc SRCS print_operators_doc.cc DEPS ${GLOB_OP_LIB}) +endif(WITH_DOC) diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index fbd0b6b07876451ad973eb98bbff822a2a58db43..0f889e6853a48b77e2b75c06554632ad14aefa48 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -185,7 +185,14 @@ EOF ${DOCKERFILE_GPU_ENV} ADD go/cmd/pserver/pserver /usr/bin/ ADD go/cmd/master/master /usr/bin/ - ADD paddle/pybind/print_operators_doc /usr/bin/ +EOF + + if [[ ${WITH_DOC:-OFF} == 'ON' ]]; then + cat >> /paddle/build/Dockerfile <> /paddle/build/Dockerfile < 0: + t.set_lod(self.lod) + return t + + +class DataFeeder(object): + def __init__(self, feed_list, place): + self.feed_dtypes = [] + self.feed_names = [] + self.feed_shapes = [] + self.feed_lod_level = [] + for each_var in feed_list: + if not isinstance(each_var, Variable): + raise TypeError("Feed list should contain a list of variable") + 
self.feed_dtypes.append(each_var.dtype) + self.feed_names.append(each_var.name) + shape = each_var.shape + batch_size_dim = -1 + for i, s in enumerate(shape): + if s < 0: + batch_size_dim = i + break + if batch_size_dim == -1: + raise ValueError("Variable {0} must have a batch size dimension".format( + each_var.name)) + self.feed_lod_level.append(each_var.lod_level) + self.feed_shapes.append(shape) + + self.place = place + + def feed(self, iterable): + converter = [] + for lod_level, shape, dtype in six.zip( + self.feed_lod_level, self.feed_shapes, self.feed_dtypes): + converter.append( + DataToLoDTensorConverter( + place=self.place, + lod_level=lod_level, + shape=shape, + dtype=dtype)) + + for each_sample in iterable: + for each_converter, each_slot in six.zip(converter, each_sample): + each_converter.feed(each_slot) + ret_dict = {} + for each_name, each_converter in six.zip(self.feed_names, converter): + ret_dict[each_name] = each_converter.done() + return ret_dict diff --git a/python/paddle/v2/fluid/tests/book/test_fit_a_line.py b/python/paddle/v2/fluid/tests/book/test_fit_a_line.py index 9f98493adb21a03b8efde0f88c490e77c9d303e7..fbf46ac6cba8fa4981cc8a6e8f5434a510c52d7d 100644 --- a/python/paddle/v2/fluid/tests/book/test_fit_a_line.py +++ b/python/paddle/v2/fluid/tests/book/test_fit_a_line.py @@ -22,6 +22,7 @@ train_reader = paddle.batch( batch_size=BATCH_SIZE) place = fluid.CPUPlace() +feeder = fluid.DataFeeder(place=place, feed_list=[x, y]) exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) @@ -31,12 +32,8 @@ for pass_id in range(PASS_NUM): fluid.io.save_persistables(exe, "./fit_a_line.model/") fluid.io.load_persistables(exe, "./fit_a_line.model/") for data in train_reader(): - x_data = np.array(map(lambda _: _[0], data)).astype("float32") - y_data = np.array(map(lambda _: _[1], data)).astype("float32") avg_loss_value, = exe.run(fluid.default_main_program(), - feed={'x': x_data, - 'y': y_data}, + feed=feeder.feed(data), fetch_list=[avg_cost]) if avg_loss_value[0] < 10.0: diff --git a/python/paddle/v2/fluid/tests/book/test_image_classification_train.py b/python/paddle/v2/fluid/tests/book/test_image_classification_train.py index 0f0cc5b5406ef51ac3504a95ea716056ae8730af..4e71b6f345ea7a1e6d29bc4ad810bc5b5f99d456 100644 --- a/python/paddle/v2/fluid/tests/book/test_image_classification_train.py +++ b/python/paddle/v2/fluid/tests/book/test_image_classification_train.py @@ -113,23 +113,14 @@ train_reader = paddle.batch( place = fluid.CPUPlace() exe = fluid.Executor(place) - +feeder = fluid.DataFeeder(place=place, feed_list=[images, label]) exe.run(fluid.default_startup_program()) for pass_id in range(PASS_NUM): accuracy.reset(exe) for data in train_reader(): - img_data = np.array(map(lambda x: x[0].reshape(data_shape), - data)).astype("float32") - y_data = np.array(map(lambda x: x[1], data)).astype("int64") - batch_size = 1 - for i in y_data.shape: - batch_size = batch_size * i - y_data = y_data.reshape([batch_size, 1]) loss, acc = exe.run(fluid.default_main_program(), - feed={"pixel": img_data, - "label": y_data}, + feed=feeder.feed(data), fetch_list=[avg_cost] + accuracy.metrics) pass_acc = accuracy.eval(exe) print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str( pass_acc))
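Note: the hand-rolled conversions deleted in the two hunks above are exactly the boilerplate that `feeder.feed(data)` absorbs. A rough sketch of the equivalence for the image-classification test, assuming the CIFAR-style `data_shape` of `(3, 32, 32)` used there (the helper name `manual_feed` is made up for illustration):

```python
import numpy as np


# Roughly what the deleted code did by hand, and what feeder.feed(data)
# now encapsulates: split a minibatch of (image, label) tuples into
# name-keyed arrays matching the feed variables "pixel" and "label".
def manual_feed(data, data_shape=(3, 32, 32)):
    img_data = np.array(
        [np.array(x[0]).reshape(data_shape) for x in data]).astype("float32")
    y_data = np.array([x[1] for x in data]).astype("int64").reshape([-1, 1])
    return {"pixel": img_data, "label": y_data}
```

On top of this, `DataFeeder` wraps every slot in a `LoDTensor` placed on the target device, which is why the explicit `fluid.LoDTensor()`/`set()` calls also disappear from the MLP test below.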
diff --git a/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py b/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py index bcd6f4d6bc66fd01406332bd1d6d7a5c4b0ddb5a..0494c7cdcaa8b55e79b4365b3ed389ca26974f36 --- a/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py +++ b/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py @@ -28,17 +28,9 @@ def load_parameter(file_name, h, w): return np.fromfile(f, dtype=np.float32).reshape(h, w) -def db_lstm(): +def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark, + **ignored): # 8 features - word = fluid.layers.data(name='word_data', shape=[1], dtype='int64') - predicate = fluid.layers.data(name='verb_data', shape=[1], dtype='int64') - ctx_n2 = fluid.layers.data(name='ctx_n2_data', shape=[1], dtype='int64') - ctx_n1 = fluid.layers.data(name='ctx_n1_data', shape=[1], dtype='int64') - ctx_0 = fluid.layers.data(name='ctx_0_data', shape=[1], dtype='int64') - ctx_p1 = fluid.layers.data(name='ctx_p1_data', shape=[1], dtype='int64') - ctx_p2 = fluid.layers.data(name='ctx_p2_data', shape=[1], dtype='int64') - mark = fluid.layers.data(name='mark_data', shape=[1], dtype='int64') predicate_embedding = fluid.layers.embedding( input=predicate, size=[pred_len, word_dim], @@ -120,8 +112,25 @@ def to_lodtensor(data, place): def main(): # define network topology - feature_out = db_lstm() - target = fluid.layers.data(name='target', shape=[1], dtype='int64') + word = fluid.layers.data( + name='word_data', shape=[1], dtype='int64', lod_level=1) + predicate = fluid.layers.data( + name='verb_data', shape=[1], dtype='int64', lod_level=1) + ctx_n2 = fluid.layers.data( + name='ctx_n2_data', shape=[1], dtype='int64', lod_level=1) + ctx_n1 = fluid.layers.data( + name='ctx_n1_data', shape=[1], dtype='int64', lod_level=1) + ctx_0 = fluid.layers.data( + name='ctx_0_data', shape=[1], dtype='int64', lod_level=1) + ctx_p1 = fluid.layers.data( + name='ctx_p1_data', shape=[1], dtype='int64', lod_level=1) + ctx_p2 = fluid.layers.data( + name='ctx_p2_data', shape=[1], dtype='int64', lod_level=1) + mark = fluid.layers.data( + name='mark_data', shape=[1], dtype='int64', lod_level=1) + feature_out = db_lstm(**locals()) + target = fluid.layers.data( + name='target', shape=[1], dtype='int64', lod_level=1) crf_cost = fluid.layers.linear_chain_crf( input=feature_out, label=target, @@ -139,6 +148,11 @@ def main(): paddle.dataset.conll05.test(), buf_size=8192), batch_size=BATCH_SIZE) place = fluid.CPUPlace() + feeder = fluid.DataFeeder( + feed_list=[ + word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, predicate, mark, target + ], + place=place) exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) @@ -150,28 +164,8 @@ batch_id = 0 for pass_id in xrange(PASS_NUM): for data in train_data(): - word_data = to_lodtensor(map(lambda x: x[0], data), place) - ctx_n2_data = to_lodtensor(map(lambda x: x[1], data), place) - ctx_n1_data = to_lodtensor(map(lambda x: x[2], data), place) - ctx_0_data = to_lodtensor(map(lambda x: x[3], data), place) - ctx_p1_data = to_lodtensor(map(lambda x: x[4], data), place) - ctx_p2_data = to_lodtensor(map(lambda x: x[5], data), place) - verb_data = to_lodtensor(map(lambda x: x[6], data), place) - mark_data = to_lodtensor(map(lambda x: x[7], data), place) - target = to_lodtensor(map(lambda x: x[8], data), place) outs = exe.run(fluid.default_main_program(), - feed={ - 'word_data': word_data, - 'ctx_n2_data': ctx_n2_data, - 'ctx_n1_data': ctx_n1_data, - 'ctx_0_data': ctx_0_data, - 'ctx_p1_data': ctx_p1_data, - 'ctx_p2_data': ctx_p2_data, - 'verb_data': verb_data, - 'mark_data': mark_data, - 'target': target - }, + feed=feeder.feed(data), fetch_list=[avg_cost]) avg_cost_val = np.array(outs[0])
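A side note on the `db_lstm(**locals())` call above: it forwards every local variable of `main()` as a keyword argument, and the new `**ignored` parameter lets `db_lstm` bind only the eight features it names while silently dropping the rest (for example `target`). A toy sketch of the idiom, with invented names:

```python
# Toy illustration of the **locals() / **ignored idiom; all names invented.
def build_net(word, mark, **ignored):
    # Binds word and mark; anything else lands in ignored.
    return "net({0}, {1})".format(word, mark)


def main():
    word = "word_var"
    mark = "mark_var"
    target = "target_var"  # not needed by build_net; swallowed by **ignored
    print(build_net(**locals()))


main()
```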
diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py index ba686b56f8603834c12f5ed24e0ef7308c78585d..35bf8da924dc76475df9bd5e6a4c04f4d204426a 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py @@ -37,20 +37,14 @@ train_reader = paddle.batch( place = fluid.CPUPlace() exe = fluid.Executor(place) - +feeder = fluid.DataFeeder(feed_list=[images, label], place=place) exe.run(fluid.default_startup_program()) for pass_id in range(PASS_NUM): accuracy.reset(exe) for data in train_reader(): - img_data = np.array(map(lambda x: x[0].reshape([1, 28, 28]), - data)).astype("float32") - y_data = np.array(map(lambda x: x[1], data)).astype("int64") - y_data = y_data.reshape([BATCH_SIZE, 1]) loss, acc = exe.run(fluid.default_main_program(), - feed={"pixel": img_data, - "label": y_data}, + feed=feeder.feed(data), fetch_list=[avg_cost] + accuracy.metrics) pass_acc = accuracy.eval(exe) print("pass_id=" + str(pass_id) + " acc=" + str(acc) + " pass_acc=" + str(pass_acc)) diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py index fa18965aac667c0829b9e6ee56ece585564f9060..4dc2c50e1c963a189b727f0a7edcb6886abd9038 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py @@ -48,40 +48,22 @@ test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) place = fluid.CPUPlace() exe = fluid.Executor(place) - +feeder = fluid.DataFeeder(feed_list=[image, label], place=place) exe.run(fluid.default_startup_program()) PASS_NUM = 100 for pass_id in range(PASS_NUM): accuracy.reset(exe) for data in train_reader(): - x_data = np.array(map(lambda x: x[0], data)).astype("float32") - y_data = np.array(map(lambda x: x[1], data)).astype("int64") - y_data = np.expand_dims(y_data, axis=1) - - tensor_x = fluid.LoDTensor() - tensor_x.set(x_data, place) - - tensor_y = fluid.LoDTensor() - tensor_y.set(y_data, place) - - outs = exe.run(fluid.default_main_program(), - feed={'x': tensor_x, - 'y': tensor_y}, - fetch_list=[avg_cost] + accuracy.metrics) - out = np.array(outs[0]) - acc = np.array(outs[1]) + out, acc = exe.run(fluid.default_main_program(), + feed=feeder.feed(data), + fetch_list=[avg_cost] + accuracy.metrics) pass_acc = accuracy.eval(exe) test_accuracy.reset(exe) for data in test_reader(): - x_data = np.array(map(lambda x: x[0], data)).astype("float32") - y_data = np.array(map(lambda x: x[1], data)).astype("int64") - y_data = np.expand_dims(y_data, axis=1) out, acc = exe.run(inference_program, - feed={'x': x_data, - 'y': y_data}, + feed=feeder.feed(data), fetch_list=[avg_cost] + test_accuracy.metrics) test_pass_acc = test_accuracy.eval(exe) diff --git a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py index be875a952b7086ee64984525d70ffd3f1ecb5fae..f103358edca9bbd2e28c99afd249f97b1d8069ae 100644 --- a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py +++ b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py @@ -4,10 +4,8 @@ import paddle.v2 as paddle import paddle.v2.fluid as fluid -def convolution_net(input_dim, class_dim=2, emb_dim=32, hid_dim=32): - data = fluid.layers.data(name="words", shape=[1], dtype="int64") - label = fluid.layers.data(name="label", shape=[1], dtype="int64") - +def convolution_net(data, label, input_dim, 
class_dim=2, emb_dim=32, + hid_dim=32): emb = fluid.layers.embedding(input=data, size=[input_dim, emb_dim]) conv_3 = fluid.nets.sequence_conv_pool( input=emb, @@ -55,8 +53,11 @@ def main(): dict_dim = len(word_dict) class_dim = 2 + data = fluid.layers.data( + name="words", shape=[1], dtype="int64", lod_level=1) + label = fluid.layers.data(name="label", shape=[1], dtype="int64") cost, accuracy, acc_out = convolution_net( - input_dim=dict_dim, class_dim=class_dim) + data, label, input_dim=dict_dim, class_dim=class_dim) train_data = paddle.batch( paddle.reader.shuffle( @@ -64,25 +65,16 @@ def main(): batch_size=BATCH_SIZE) place = fluid.CPUPlace() exe = fluid.Executor(place) + feeder = fluid.DataFeeder(feed_list=[data, label], place=place) exe.run(fluid.default_startup_program()) for pass_id in xrange(PASS_NUM): accuracy.reset(exe) for data in train_data(): - tensor_words = to_lodtensor(map(lambda x: x[0], data), place) - - label = np.array(map(lambda x: x[1], data)).astype("int64") - label = label.reshape([BATCH_SIZE, 1]) - - tensor_label = fluid.LoDTensor() - tensor_label.set(label, place) - - cost_val, acc_val = exe.run( - fluid.default_main_program(), - feed={"words": tensor_words, - "label": tensor_label}, - fetch_list=[cost, acc_out]) + cost_val, acc_val = exe.run(fluid.default_main_program(), + feed=feeder.feed(data), + fetch_list=[cost, acc_out]) pass_acc = accuracy.eval(exe) print("cost=" + str(cost_val) + " acc=" + str(acc_val) + " pass_acc=" + str(pass_acc)) diff --git a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_dynamic_lstm.py b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_dynamic_lstm.py index 094a3cdcda12eaee351476e99a388c44b3c81cd6..cd28f04b8574778316d70e7d8a03026f807c3e52 100644 --- a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_dynamic_lstm.py +++ b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_dynamic_lstm.py @@ -3,14 +3,14 @@ import paddle.v2 as paddle import paddle.v2.fluid as fluid -def stacked_lstm_net(input_dim, +def stacked_lstm_net(data, + label, + input_dim, class_dim=2, emb_dim=128, hid_dim=512, stacked_num=3): assert stacked_num % 2 == 1 - data = fluid.layers.data(name="words", shape=[1], dtype="int64") - label = fluid.layers.data(name="label", shape=[1], dtype="int64") emb = fluid.layers.embedding(input=data, size=[input_dim, emb_dim]) # add bias attr @@ -65,8 +65,11 @@ def main(): dict_dim = len(word_dict) class_dim = 2 + data = fluid.layers.data( + name="words", shape=[1], dtype="int64", lod_level=1) + label = fluid.layers.data(name="label", shape=[1], dtype="int64") cost, accuracy, acc_out = stacked_lstm_net( - input_dim=dict_dim, class_dim=class_dim) + data, label, input_dim=dict_dim, class_dim=class_dim) train_data = paddle.batch( paddle.reader.shuffle( @@ -74,25 +77,16 @@ def main(): batch_size=BATCH_SIZE) place = fluid.CPUPlace() exe = fluid.Executor(place) + feeder = fluid.DataFeeder(feed_list=[data, label], place=place) exe.run(fluid.default_startup_program()) for pass_id in xrange(PASS_NUM): accuracy.reset(exe) for data in train_data(): - tensor_words = to_lodtensor(map(lambda x: x[0], data), place) - - label = np.array(map(lambda x: x[1], data)).astype("int64") - label = label.reshape([BATCH_SIZE, 1]) - - tensor_label = fluid.LoDTensor() - tensor_label.set(label, place) - - cost_val, acc_val = exe.run( - fluid.default_main_program(), - feed={"words": tensor_words, - "label": tensor_label}, - fetch_list=[cost, acc_out]) + cost_val, acc_val = exe.run(fluid.default_main_program(), + 
feed=feeder.feed(data), + fetch_list=[cost, acc_out]) pass_acc = accuracy.eval(exe) print("cost=" + str(cost_val) + " acc=" + str(acc_val) + " pass_acc=" + str(pass_acc)) diff --git a/python/paddle/v2/fluid/tests/book/test_word2vec.py b/python/paddle/v2/fluid/tests/book/test_word2vec.py index 1b441e15c72c85c3d44c39b0f685a88db2304eef..8b928ff9eed41f8945c749058b4177fd023452ba 100644 --- a/python/paddle/v2/fluid/tests/book/test_word2vec.py +++ b/python/paddle/v2/fluid/tests/book/test_word2vec.py @@ -57,23 +57,16 @@ train_reader = paddle.batch( place = fluid.CPUPlace() exe = fluid.Executor(place) +feeder = fluid.DataFeeder( + feed_list=[first_word, second_word, third_word, forth_word, next_word], + place=place) exe.run(fluid.default_startup_program()) for pass_id in range(PASS_NUM): for data in train_reader(): - input_data = [[data_idx[idx] for data_idx in data] for idx in xrange(5)] - input_data = map(lambda x: np.array(x).astype("int64"), input_data) - input_data = map(lambda x: np.expand_dims(x, axis=1), input_data) avg_cost_np = exe.run(fluid.default_main_program(), - feed={ - 'firstw': input_data[0], - 'secondw': input_data[1], - 'thirdw': input_data[2], - 'forthw': input_data[3], - 'nextw': input_data[4] - }, + feed=feeder.feed(data), fetch_list=[avg_cost]) if avg_cost_np[0] < 5.0: exit(0) # if the average cost is less than 5.0, we think our code is good. diff --git a/python/paddle/v2/fluid/tests/test_data_feeder.py b/python/paddle/v2/fluid/tests/test_data_feeder.py new file mode 100644 index 0000000000000000000000000000000000000000..454969320321b72342803f507f0054f79f276669 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_data_feeder.py @@ -0,0 +1,13 @@ +import paddle.v2.fluid as fluid + + +def test_converter(): + img = fluid.layers.data(name='image', shape=[1, 28, 28]) + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + feeder = fluid.DataFeeder([img, label], fluid.CPUPlace()) + result = feeder.feed([[[0] * 784, [9]], [[1] * 784, [1]]]) + print(result) + + +if __name__ == '__main__': + test_converter()
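Putting the pieces together, a minimal usage sketch of the new feeder for sequence (LoD) inputs, mirroring what the sentiment and SRL tests now do; the word ids and the batch below are invented for illustration:

```python
import paddle.v2.fluid as fluid

# A variable-length word-id sequence (lod_level=1) and a per-sample label.
words = fluid.layers.data(name='words', shape=[1], dtype='int64', lod_level=1)
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

feeder = fluid.DataFeeder(feed_list=[words, label], place=fluid.CPUPlace())

# One minibatch: each sample is (sequence, label). Sequences may differ in
# length; the converter records the lengths as LoD information, so no
# to_lodtensor helper is needed.
minibatch = [([1, 2, 3], [0]), ([4, 5], [1])]

# feed() returns {'words': LoDTensor, 'label': LoDTensor}, ready to be passed
# as the `feed` argument of Executor.run(), as in the book tests above.
result = feeder.feed(minibatch)
print(result)
```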