diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 354ca563eb3be4c036ebb601c0278c6f41fe1c9f..4bfe69e548c59faff9d8123dc73d37a111257d32 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -61,12 +61,12 @@ paddle.fluid.layers.cos_sim ArgSpec(args=['X', 'Y'], varargs=None, keywords=None paddle.fluid.layers.cross_entropy ArgSpec(args=['input', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100)) paddle.fluid.layers.square_error_cost ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.chunk_eval ArgSpec(args=['input', 'label', 'chunk_scheme', 'num_chunk_types', 'excluded_chunk_types'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None)) +paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None, None)) paddle.fluid.layers.conv2d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)) paddle.fluid.layers.conv3d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)) paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type'], varargs=None, keywords=None, defaults=None) -paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn'], varargs=None, keywords=None, defaults=(None, None, False)) -paddle.fluid.layers.softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(None, None, True, None)) +paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None)) +paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None)) paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None)) paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None)) paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False)) @@ -95,8 +95,8 @@ paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_ti paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None)) -paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples'], varargs=None, keywords=None, defaults=(None, None, None, None)) -paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None)) +paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None)) +paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None)) paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None) diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index 6698efd1fa773127a84b4bcb28f57f4226dd7ae2..b5ccaac2e683417e04b35e3d870db73443e088b1 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -19,8 +19,18 @@ cc_library(paddle_fluid_origin DEPS ${fluid_modules} paddle_fluid_api) add_subdirectory(api) +set(STATIC_INFERENCE_APIS paddle_fluid_api paddle_inference_api analysis_predictor) +set(SHARED_INFERENCE_SRCS + io.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api_impl.cc + ${CMAKE_CURRENT_SOURCE_DIR}/api/analysis_predictor.cc ) +if (WITH_GPU AND TENSORRT_FOUND) + set(STATIC_INFERENCE_APIS ${STATIC_INFERENCE_APIS} paddle_inference_tensorrt_subgraph_engine) + set(SHARED_INFERENCE_SRCS ${SHARED_INFERENCE_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/api/api_tensorrt_subgraph_engine.cc) +endif() + # Create static library -cc_library(paddle_fluid DEPS ${fluid_modules} paddle_fluid_api paddle_inference_api analysis_predictor) +cc_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} ) + if(NOT APPLE) # TODO(liuyiqu: Temporarily disable the link flag because it is not support on Mac. set(LINK_FLAGS "-Wl,--retain-symbols-file ${CMAKE_CURRENT_SOURCE_DIR}/paddle_fluid.sym") @@ -28,9 +38,7 @@ if(NOT APPLE) endif() # Create shared library -cc_library(paddle_fluid_shared SHARED - SRCS io.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api_impl.cc - ${CMAKE_CURRENT_SOURCE_DIR}/api/analysis_predictor.cc +cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS} DEPS ${fluid_modules} paddle_fluid_api) set_target_properties(paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid) diff --git a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt index d4e6bb3e4a4ceb361ccd35121d0ecf84a764243e..ec8471ef960a2fc44af23c52be09cd678fab3f70 100644 --- a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt +++ b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt @@ -3,6 +3,7 @@ project(cpp_inference_demo CXX C) option(WITH_MKL "Compile demo with MKL/OpenBlas support, default use MKL." ON) option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." OFF) option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." ON) +option(USE_TENSORRT "Compile demo with TensorRT." OFF) macro(safe_set_static_flag) foreach(flag_var @@ -60,6 +61,13 @@ endif(NOT WIN32) include_directories("${PADDLE_LIB}/third_party/boost") include_directories("${PADDLE_LIB}/third_party/eigen3") +if (NOT WIN32) + if (USE_TENSORRT AND WITH_GPU) + include_directories("${TENSORRT_INCLUDE_DIR}") + link_directories("${TENSORRT_LIB_DIR}") + endif() +endif(NOT WIN32) + if (NOT WIN32) link_directories("${PADDLE_LIB}/third_party/install/snappy/lib") link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib") @@ -112,6 +120,10 @@ endif(NOT WIN32) if(WITH_GPU) if(NOT WIN32) + if (USE_TENSORRT) + set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer${CMAKE_STATIC_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX}) + endif() set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX}) else() set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX} ) diff --git a/paddle/fluid/inference/api/demo_ci/run.sh b/paddle/fluid/inference/api/demo_ci/run.sh index 0f7d541c5edfc62e80cf50f83b491f06dcb42644..76238070cda725463820e0e0834f2594382c8e57 100755 --- a/paddle/fluid/inference/api/demo_ci/run.sh +++ b/paddle/fluid/inference/api/demo_ci/run.sh @@ -2,6 +2,12 @@ set -x PADDLE_ROOT=$1 TURN_ON_MKL=$2 # use MKL or Openblas TEST_GPU_CPU=$3 # test both GPU/CPU mode or only CPU mode +DATA_DIR=$4 # dataset +TENSORRT_INCLUDE_DIR=$5 # TensorRT header file dir, defalut to /usr/local/TensorRT/include +TENSORRT_LIB_DIR=$6 # TensorRT lib file dir, default to /usr/local/TensorRT/lib + +cd `dirname $0` +current_dir=`pwd` if [ $2 == ON ]; then # You can export yourself if move the install path MKL_LIB=${PADDLE_ROOT}/build/fluid_install_dir/third_party/install/mklml/lib @@ -13,6 +19,11 @@ else use_gpu_list='false' fi +USE_TENSORRT=OFF +if [ [-d"$TENSORRT_INCLUDE_DIR"] -a [-d"$TENSORRT_LIB_DIR"] ]; then + USE_TENSORRT=ON +fi + PREFIX=inference-vis-demos%2F URL_ROOT=http://paddlemodels.cdn.bcebos.com/${PREFIX} @@ -29,15 +40,15 @@ function download() { fi cd .. } -mkdir -p data -cd data +mkdir -p $DATA_DIR +cd $DATA_DIR vis_demo_list='se_resnext50 ocr mobilenet' for vis_demo_name in $vis_demo_list; do download $vis_demo_name done -cd .. # compile and test the demo +cd $current_dir mkdir -p build cd build @@ -73,9 +84,9 @@ for WITH_STATIC_LIB in ON OFF; do for use_gpu in $use_gpu_list; do for vis_demo_name in $vis_demo_list; do ./vis_demo \ - --modeldir=../data/$vis_demo_name/model \ - --data=../data/$vis_demo_name/data.txt \ - --refer=../data/$vis_demo_name/result.txt \ + --modeldir=$DATA_DIR/$vis_demo_name/model \ + --data=$DATA_DIR/$vis_demo_name/data.txt \ + --refer=$DATA_DIR/$vis_demo_name/result.txt \ --use_gpu=$use_gpu if [ $? -ne 0 ]; then echo "vis demo $vis_demo_name runs fail." @@ -83,5 +94,25 @@ for WITH_STATIC_LIB in ON OFF; do fi done done + + # --------tensorrt mobilenet------ + if [ $USE_TENSORRT == ON -a $TEST_GPU_CPU == ON ]; then + rm -rf * + cmake .. -DPADDLE_LIB=${PADDLE_ROOT}/build/fluid_install_dir/ \ + -DWITH_MKL=$TURN_ON_MKL \ + -DDEMO_NAME=vis_demo \ + -DWITH_GPU=$TEST_GPU_CPU \ + -DWITH_STATIC_LIB=$WITH_STATIC_LIB \ + -DUSE_TENSORRT=$USE_TENSORRT \ + -DTENSORRT_INCLUDE_DIR=$TENSORRT_INCLUDE_DIR \ + -DTENSORRT_LIB_DIR=$TENSORRT_LIB_DIR + make -j + ./vis_demo \ + --modeldir=$DATA_DIR/mobilenet/model \ + --data=$DATA_DIR/mobilenet/data.txt \ + --refer=$DATA_DIR/mobilenet/result.txt \ + --use_gpu=true \ + --use_trt=true + fi done set +x diff --git a/paddle/fluid/inference/api/demo_ci/vis_demo.cc b/paddle/fluid/inference/api/demo_ci/vis_demo.cc index 3800d49b34738d5a272033d75cb415ae9ad1fb8f..e6caf6c4b682b376ff3775559eca66068184e6db 100644 --- a/paddle/fluid/inference/api/demo_ci/vis_demo.cc +++ b/paddle/fluid/inference/api/demo_ci/vis_demo.cc @@ -33,6 +33,7 @@ DEFINE_string( "path of data; each line is a record, format is " "'\t predictor; + if (!use_trt) { + NativeConfig config; + config.param_file = FLAGS_modeldir + "/__params__"; + config.prog_file = FLAGS_modeldir + "/__model__"; + config.use_gpu = use_gpu; + config.device = 0; + if (FLAGS_use_gpu) { + config.fraction_of_gpu_memory = 0.1; // set by yourself + } + + VLOG(3) << "init predictor"; + predictor = + CreatePaddlePredictor(config); + } else { + paddle::contrib::MixedRTConfig config; + config.param_file = FLAGS_modeldir + "/__params__"; + config.prog_file = FLAGS_modeldir + "/__model__"; + config.use_gpu = true; + config.device = 0; + config.max_batch_size = 1; config.fraction_of_gpu_memory = 0.1; // set by yourself + predictor = CreatePaddlePredictor(config); } - VLOG(3) << "init predictor"; - auto predictor = - CreatePaddlePredictor(config); - VLOG(3) << "begin to process data"; // Just a single batch of data. std::string line; @@ -131,7 +144,7 @@ void Main(bool use_gpu) { VLOG(3) << "run executor"; std::vector output; - predictor->Run({input}, &output); + predictor->Run({input}, &output, 1); VLOG(3) << "output.size " << output.size(); auto& tensor = output.front(); @@ -146,9 +159,12 @@ void Main(bool use_gpu) { int main(int argc, char** argv) { google::ParseCommandLineFlags(&argc, &argv, true); - paddle::demo::Main(false /* use_gpu*/); - if (FLAGS_use_gpu) { - paddle::demo::Main(true /*use_gpu*/); + if (FLAGS_use_gpu && FLAGS_use_trt) { + paddle::demo::Main(true /*use_gpu*/, true); + } else if (FLAGS_use_gpu) { + paddle::demo::Main(true /*use_gpu*/, false); + } else { + paddle::demo::Main(false /*use_gpu*/, false /*use_tensorrt*/); } return 0; } diff --git a/paddle/fluid/operators/adadelta_op.cc b/paddle/fluid/operators/adadelta_op.cc index d1970515f58969948b1d2db5847e4344112f77f9..89a7a49e0fa8427826f5d91274912a68f2316b61 100644 --- a/paddle/fluid/operators/adadelta_op.cc +++ b/paddle/fluid/operators/adadelta_op.cc @@ -18,6 +18,7 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; + class AdadeltaOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -31,6 +32,16 @@ class AdadeltaOp : public framework::OperatorWithKernel { "Input(AvgSquaredGrad) of AdadeltaOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("AvgSquaredUpdate"), "Input(AvgSquaredUpdate) of AdadeltaOp should not be null."); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Param").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front()); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Grad").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front()); PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), "Output(ParamOut) of AdadeltaOp should not be null."); @@ -56,6 +67,7 @@ class AdadeltaOp : public framework::OperatorWithKernel { ctx->SetOutputDim("AvgSquaredGradOut", param_dim); ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim); } + framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = diff --git a/paddle/fluid/operators/adadelta_op.h b/paddle/fluid/operators/adadelta_op.h index 822458daf663d99bbb38d99205f51163a0df4c4d..6c616aa03d9809e9b7725a700c7edd5ff5d6dc42 100644 --- a/paddle/fluid/operators/adadelta_op.h +++ b/paddle/fluid/operators/adadelta_op.h @@ -23,6 +23,17 @@ template class AdadeltaOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE(param_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Param").front(), param_var->Type().name()); + const auto* grad_var = ctx.InputVar("Grad"); + PADDLE_ENFORCE(grad_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Grad").front(), grad_var->Type().name()); + auto param_out_tensor = ctx.Output("ParamOut"); auto avg_squared_grad_out_tensor = ctx.Output("AvgSquaredGradOut"); diff --git a/paddle/fluid/operators/adagrad_op.h b/paddle/fluid/operators/adagrad_op.h index df520fcc898ff5514927dbdd845ecaecdcf3c147..0a16ce00f71586ef55007c3753e024be29d0ed56 100644 --- a/paddle/fluid/operators/adagrad_op.h +++ b/paddle/fluid/operators/adagrad_op.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" @@ -21,25 +22,31 @@ namespace operators { template struct SparseAdagradFunctor { - void operator()(const DeviceContext& context, - const framework::SelectedRows& grad, - const framework::Tensor& learning_rate, T epsilon, - framework::Tensor* moment, framework::Tensor* param); + void operator()(const DeviceContext &context, + const framework::SelectedRows &grad, + const framework::Tensor &learning_rate, T epsilon, + framework::Tensor *moment, framework::Tensor *param); }; template class AdagradOpKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* param_out_tensor = ctx.Output("ParamOut"); - auto* moment_out_tensor = ctx.Output("MomentOut"); + void Compute(const framework::ExecutionContext &ctx) const override { + const auto *param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE(param_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Param").front(), param_var->Type().name()); + + auto *param_out_tensor = ctx.Output("ParamOut"); + auto *moment_out_tensor = ctx.Output("MomentOut"); param_out_tensor->mutable_data(ctx.GetPlace()); moment_out_tensor->mutable_data(ctx.GetPlace()); T epsilon = static_cast(ctx.Attr("epsilon")); - auto* grad_var = ctx.InputVar("Grad"); + auto *grad_var = ctx.InputVar("Grad"); if (grad_var->IsType()) { auto param = framework::EigenVector::Flatten( *ctx.Input("Param")); @@ -47,16 +54,16 @@ class AdagradOpKernel : public framework::OpKernel { *ctx.Input("Grad")); auto moment = framework::EigenVector::Flatten( *ctx.Input("Moment")); - auto* learning_rate = ctx.Input("LearningRate"); + auto *learning_rate = ctx.Input("LearningRate"); auto param_out = framework::EigenVector::Flatten(*param_out_tensor); auto moment_out = framework::EigenVector::Flatten(*moment_out_tensor); - auto* place = ctx.template device_context().eigen_device(); + auto *place = ctx.template device_context().eigen_device(); moment_out.device(*place) = moment + grad * grad; Eigen::DSizes m_dsize(moment_out_tensor->numel()); if (platform::is_cpu_place(ctx.GetPlace())) { - auto* lr = learning_rate->data(); + auto *lr = learning_rate->data(); param_out.device(*place) = param - lr[0] * grad / (moment_out.sqrt() + epsilon); } else { @@ -66,10 +73,10 @@ class AdagradOpKernel : public framework::OpKernel { lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon); } } else if (grad_var->IsType()) { - auto* param_tensor = ctx.Input("Param"); + auto *param_tensor = ctx.Input("Param"); PADDLE_ENFORCE_EQ(param_tensor, param_out_tensor); - auto* moment_tensor = ctx.Input("Moment"); + auto *moment_tensor = ctx.Input("Moment"); PADDLE_ENFORCE_EQ(moment_tensor, moment_out_tensor); SparseAdagradFunctor functor; diff --git a/paddle/fluid/operators/adam_op.h b/paddle/fluid/operators/adam_op.h index 4cb1f3a80e95bdda79e6451dc3cc87e899b11779..3455d1ee54e8e6e498d0b0e6932ec099af9c0b30 100644 --- a/paddle/fluid/operators/adam_op.h +++ b/paddle/fluid/operators/adam_op.h @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" +#include "paddle/fluid/operators/math/algorithm.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/for_range.h" @@ -199,23 +200,9 @@ struct SparseAdamFunctor { row_numel_(row_numel), row_count_(row_count) {} - inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const { - int64_t beg = 0, end = row_count_ - 1; - while (beg <= end) { - auto mid = ((beg + end) >> 1); - if (rows_[mid] == row) - return mid; - else if (rows_[mid] < row) - beg = mid + 1; - else - end = mid - 1; - } - return -1; - } - inline HOSTDEVICE void operator()(size_t i) const { - int64_t row = i / row_numel_; - auto row_idx = BinarySearchInRows(row); + auto row_idx = + math::BinarySearch(rows_, row_count_, i / row_numel_); T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0; // The following code is the same as dense @@ -244,6 +231,12 @@ template class AdamOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE(param_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Param").front(), param_var->Type().name()); + using paddle::framework::LoDTensor; using paddle::operators::detail::Ref; diff --git a/paddle/fluid/operators/adamax_op.cc b/paddle/fluid/operators/adamax_op.cc index 32062574bcf71ff96e451eaa6865b6bbfc3b1c80..d4aa4d338a2379adf985ba7f89b528bc402eda06 100644 --- a/paddle/fluid/operators/adamax_op.cc +++ b/paddle/fluid/operators/adamax_op.cc @@ -35,6 +35,16 @@ class AdamaxOp : public framework::OperatorWithKernel { "Input(LearningRate) of AdamaxOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"), "Input(Beta1Pow) of AdamaxOp should not be null."); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Param").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front()); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Grad").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front()); PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), "Output(ParamOut) of AdamaxOp should not be null."); diff --git a/paddle/fluid/operators/adamax_op.h b/paddle/fluid/operators/adamax_op.h index de644676fd9c3fabdbf01d2fd9c69858c2627ed3..7137fbd9651b4523f6d1609a0595b30758aa40df 100644 --- a/paddle/fluid/operators/adamax_op.h +++ b/paddle/fluid/operators/adamax_op.h @@ -23,6 +23,17 @@ template class AdamaxOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE(param_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Param").front(), param_var->Type().name()); + const auto* grad_var = ctx.InputVar("Grad"); + PADDLE_ENFORCE(grad_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Grad").front(), grad_var->Type().name()); + auto param_out_tensor = ctx.Output("ParamOut"); auto moment_out_tensor = ctx.Output("MomentOut"); auto inf_norm_out_tensor = ctx.Output("InfNormOut"); diff --git a/paddle/fluid/operators/decayed_adagrad_op.cc b/paddle/fluid/operators/decayed_adagrad_op.cc index c0f2b49a04d9e88502c4b63bca493cd2b7ad1c5c..d73ae9e2721b388212cb6efa354eb4b480df9cad 100644 --- a/paddle/fluid/operators/decayed_adagrad_op.cc +++ b/paddle/fluid/operators/decayed_adagrad_op.cc @@ -32,6 +32,16 @@ class DecayedAdagradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE( ctx->HasInput("LearningRate"), "Input(LearningRate) of DecayedAdagradOp should not be null."); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Param").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front()); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Grad").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front()); PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), "Output(ParamOut) of DecayedAdagradOp should not be null."); diff --git a/paddle/fluid/operators/decayed_adagrad_op.h b/paddle/fluid/operators/decayed_adagrad_op.h index a46af078e0c6b4bf1faca0570b6a97b026864f13..5df43d33ef9f720fd20d57c53ff37cc85440b24e 100644 --- a/paddle/fluid/operators/decayed_adagrad_op.h +++ b/paddle/fluid/operators/decayed_adagrad_op.h @@ -23,6 +23,17 @@ template class DecayedAdagradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE(param_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Param").front(), param_var->Type().name()); + const auto* grad_var = ctx.InputVar("Grad"); + PADDLE_ENFORCE(grad_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Grad").front(), grad_var->Type().name()); + auto param_out_tensor = ctx.Output("ParamOut"); auto moment_out_tensor = ctx.Output("MomentOut"); diff --git a/paddle/fluid/operators/ftrl_op.cc b/paddle/fluid/operators/ftrl_op.cc index 70ba25c213046cc934f46be067080d5fdbb42f9e..b77e12d6508eb07ae137b313ca91eac951afbcbe 100644 --- a/paddle/fluid/operators/ftrl_op.cc +++ b/paddle/fluid/operators/ftrl_op.cc @@ -34,6 +34,16 @@ class FTRLOp : public framework::OperatorWithKernel { "Input(Grad) of FTRL should not be null."); PADDLE_ENFORCE(ctx->HasInput("LearningRate"), "Input(LearningRate) of FTRL should not be null."); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Param").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front()); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Grad").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front()); PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), "Output(ParamOut) of FTRL should not be null."); diff --git a/paddle/fluid/operators/ftrl_op.h b/paddle/fluid/operators/ftrl_op.h index 6f821e7e9944214fc5ebdf6bc7db8789b8ada6b9..8f812c9a037bfac8c1e29e32a5ad5b077c8153d1 100644 --- a/paddle/fluid/operators/ftrl_op.h +++ b/paddle/fluid/operators/ftrl_op.h @@ -28,6 +28,17 @@ template class FTRLOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE(param_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Param").front(), param_var->Type().name()); + const auto* grad_var = ctx.InputVar("Grad"); + PADDLE_ENFORCE(grad_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Grad").front(), grad_var->Type().name()); + auto* param_out = ctx.Output("ParamOut"); auto* sq_accum_out = ctx.Output("SquaredAccumOut"); auto* lin_accum_out = ctx.Output("LinearAccumOut"); diff --git a/paddle/fluid/operators/math/algorithm.h b/paddle/fluid/operators/math/algorithm.h new file mode 100644 index 0000000000000000000000000000000000000000..262469beea7449eb5820b86de1ac4f790a833e79 --- /dev/null +++ b/paddle/fluid/operators/math/algorithm.h @@ -0,0 +1,44 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include // for int64_t +#include + +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace operators { +namespace math { + +template +HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) { + int64_t beg = 0, end = num - 1; + while (beg <= end) { + auto mid = ((beg + end) >> 1); + if (x[mid] == val) + return mid; + else if (x[mid] < val) + beg = mid + 1; + else + end = mid - 1; + } + return -1; +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/momentum_op.cc b/paddle/fluid/operators/momentum_op.cc index 5f43c5810812260c4384349bdb709716c9a182f5..12b916fcebd425bd4a03d920f947829098a924a1 100644 --- a/paddle/fluid/operators/momentum_op.cc +++ b/paddle/fluid/operators/momentum_op.cc @@ -24,7 +24,7 @@ class MomentumOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Param"), "Input(param) of Momentum should not be null."); PADDLE_ENFORCE(ctx->HasInput("Grad"), @@ -33,6 +33,11 @@ class MomentumOp : public framework::OperatorWithKernel { "Input(velocity) of Momentum should not be null."); PADDLE_ENFORCE(ctx->HasInput("LearningRate"), "Input(LearningRate) of Momentum should not be null."); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Param").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front()); PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), "Output(ParamOut) of Momentum should not be null."); @@ -40,12 +45,15 @@ class MomentumOp : public framework::OperatorWithKernel { "Output(VelocityOut) of Momentum should not be null."); auto param_dim = ctx->GetInputDim("Param"); - PADDLE_ENFORCE_EQ( - param_dim, ctx->GetInputDim("Grad"), - "Param and Grad input of MomentumOp should have the same dimension."); - PADDLE_ENFORCE_EQ( - param_dim, ctx->GetInputDim("Velocity"), - "Param and Velocity of MomentumOp should have the same dimension."); + if (ctx->GetInputsVarType("Grad")[0] == + framework::proto::VarType::LOD_TENSOR) { + PADDLE_ENFORCE_EQ( + param_dim, ctx->GetInputDim("Grad"), + "Param and Grad input of MomentumOp should have the same dimension."); + PADDLE_ENFORCE_EQ( + param_dim, ctx->GetInputDim("Velocity"), + "Param and Velocity of MomentumOp should have the same dimension."); + } PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1, "Learning_rate should be a scalar"); @@ -53,13 +61,34 @@ class MomentumOp : public framework::OperatorWithKernel { ctx->SetOutputDim("VelocityOut", param_dim); } framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - auto input_data_type = - framework::ToDataType(ctx.Input("Param")->type()); + const framework::ExecutionContext& ctx) const override { + auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; +class MomentumOpInferVarType : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + auto input_var = op_desc.Input("Param")[0]; + for (auto& out_var : op_desc.Output("ParamOut")) { + if (block->FindRecursiveOrCreateVar(input_var).GetType() == + framework::proto::VarType::SELECTED_ROWS) { + block->FindRecursiveOrCreateVar(out_var).SetType( + framework::proto::VarType::SELECTED_ROWS); + } else if (block->FindRecursiveOrCreateVar(input_var).GetType() == + framework::proto::VarType::LOD_TENSOR) { + block->FindRecursiveOrCreateVar(out_var).SetType( + framework::proto::VarType::LOD_TENSOR); + } else { + PADDLE_THROW( + "Only support LodTensor and SelectedRows, Unexpected Input Type."); + } + } + } +}; + class MomentumOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -110,6 +139,9 @@ $$ } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(momentum, ops::MomentumOp, ops::MomentumOpMaker); -REGISTER_OP_CPU_KERNEL(momentum, ops::MomentumOpKernel, - ops::MomentumOpKernel); +REGISTER_OPERATOR(momentum, ops::MomentumOp, ops::MomentumOpMaker, + paddle::framework::EmptyGradOpMaker, + ops::MomentumOpInferVarType); +REGISTER_OP_CPU_KERNEL( + momentum, ops::MomentumOpKernel, + ops::MomentumOpKernel); diff --git a/paddle/fluid/operators/momentum_op.cu b/paddle/fluid/operators/momentum_op.cu index a3932db1f3a50305d585cd3d5e86fa1b527df78b..b68fec34d43f0dee834f1045f192d5c6089d9356 100644 --- a/paddle/fluid/operators/momentum_op.cu +++ b/paddle/fluid/operators/momentum_op.cu @@ -15,65 +15,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/momentum_op.h" -namespace paddle { -namespace operators { - -template -__global__ void MomentumKernel(const T* p, const T* g, const T* v, - const T* learning_rate, const T mu, - const int64_t num, bool use_nesterov, T* p_out, - T* v_out) { - T lr = learning_rate[0]; - if (use_nesterov) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; - i += blockDim.x * gridDim.x) { - T g_val = g[i]; - T v_new = v[i] * mu + g_val; - v_out[i] = v_new; - p_out[i] = p[i] - (g_val + v_new * mu) * lr; - } - } else { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; - i += blockDim.x * gridDim.x) { - T v_new = v[i] * mu + g[i]; - v_out[i] = v_new; - p_out[i] = p[i] - lr * v_new; - } - } -} - -template -class MomentumOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto param_out = ctx.Output("ParamOut"); - auto velocity_out = ctx.Output("VelocityOut"); - auto param = ctx.Input("Param"); - auto velocity = ctx.Input("Velocity"); - auto grad = ctx.Input("Grad"); - auto learning_rate = ctx.Input("LearningRate"); - - T* p_out = param_out->mutable_data(ctx.GetPlace()); - T* v_out = velocity_out->mutable_data(ctx.GetPlace()); - - T mu = static_cast(ctx.Attr("mu")); - bool use_nesterov = ctx.Attr("use_nesterov"); - - auto* p = param->data(); - auto* v = velocity->data(); - auto* g = grad->data(); - auto* lr = learning_rate->data(); - - int block = 512; - int grid = (param->numel() + block - 1) / block; - MomentumKernel<<>>( - p, g, v, lr, mu, param->numel(), use_nesterov, p_out, v_out); - } -}; - -} // namespace operators -} // namespace paddle - namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(momentum, ops::MomentumOpCUDAKernel, - ops::MomentumOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL( + momentum, ops::MomentumOpKernel, + ops::MomentumOpKernel); diff --git a/paddle/fluid/operators/momentum_op.h b/paddle/fluid/operators/momentum_op.h index 264726040fb566a52b8c0cdee0a1524197d2a675..6b4d00f56ca06c402c07ecf770a390e88ae3edf1 100644 --- a/paddle/fluid/operators/momentum_op.h +++ b/paddle/fluid/operators/momentum_op.h @@ -13,29 +13,48 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/algorithm.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/for_range.h" namespace paddle { namespace operators { -template -class MomentumOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto param_out = ctx.Output("ParamOut"); - auto velocity_out = ctx.Output("VelocityOut"); - auto param = ctx.Input("Param"); - auto velocity = ctx.Input("Velocity"); - auto grad = ctx.Input("Grad"); - auto learning_rate = ctx.Input("LearningRate"); +using framework::Tensor; +using framework::SelectedRows; +struct NoNesterov; +struct UseNesterov; - param_out->mutable_data(ctx.GetPlace()); - velocity_out->mutable_data(ctx.GetPlace()); +template +class CPUDenseMomentumFunctor { + private: + const Tensor* param; + const Tensor* grad; + const Tensor* velocity; + const Tensor* learning_rate; + const T mu; + const T use_nesterov; + Tensor* param_out; + Tensor* velocity_out; - T mu = static_cast(ctx.Attr("mu")); - bool use_nesterov = ctx.Attr("use_nesterov"); + public: + CPUDenseMomentumFunctor(const Tensor* param, const Tensor* grad, + const Tensor* velocity, const Tensor* learning_rate, + const T mu, const bool use_nesterov, + Tensor* param_out, Tensor* velocity_out) + : param(param), + grad(grad), + velocity(velocity), + learning_rate(learning_rate), + mu(mu), + use_nesterov(use_nesterov), + param_out(param_out), + velocity_out(velocity_out) {} + inline void operator()() { auto p_out = framework::EigenVector::Flatten(*param_out); auto v_out = framework::EigenVector::Flatten(*velocity_out); @@ -53,5 +72,283 @@ class MomentumOpKernel : public framework::OpKernel { } }; +template +class DenseMomentumFunctor; + +// NOTE(dzh) for performance. +// avoid if/else in inside kernel, implement GPU UseNesterov/NoNesterov as two +// functor. +template +class DenseMomentumFunctor { + private: + const T* p_; + const T* g_; + const T* v_; + const T* lr_; + const T mu_; + const int64_t num_; + T* p_out_; + T* v_out_; + + public: + DenseMomentumFunctor(const T* p, const T* g, const T* v, + const T* learning_rate, const T mu, const int64_t num, + T* p_out, T* v_out) + : p_(p), + g_(g), + v_(v), + lr_(learning_rate), + mu_(mu), + num_(num), + p_out_(p_out), + v_out_(v_out) {} + inline HOSTDEVICE void operator()(size_t i) const { + // put memory access in register + const T p = p_[i]; + const T g = g_[i]; + const T lr = lr_[0]; + const T v = v_[i]; + T v_out = v * mu_ + g; + T p_out = p - (g + v_out * mu_) * lr; + // write reigster to memory + v_out_[i] = v_out; + p_out_[i] = p_out; + } +}; + +template +class DenseMomentumFunctor { + private: + const T* p_; + const T* g_; + const T* v_; + const T* lr_; + const T mu_; + const int64_t num_; + T* p_out_; + T* v_out_; + + public: + DenseMomentumFunctor(const T* p, const T* g, const T* v, + const T* learning_rate, const T mu, const int64_t num, + T* p_out, T* v_out) + : p_(p), + g_(g), + v_(v), + lr_(learning_rate), + mu_(mu), + num_(num), + p_out_(p_out), + v_out_(v_out) {} + inline HOSTDEVICE void operator()(size_t i) const { + // put memory access in register + const T p = p_[i]; + const T g = g_[i]; + const T lr = lr_[0]; + const T v = v_[i]; + T v_out = v * mu_ + g; + T p_out = p - lr * v_out; + // write reigster to memory + v_out_[i] = v_out; + p_out_[i] = p_out; + } +}; + +template +class SparseMomentumFunctor; + +template +class SparseMomentumFunctor { + private: + const T* p_; + const T* g_; + const T* v_; + const T* lr_; + const T mu_; + const int64_t* rows_; + const int64_t row_numel_; + const int64_t row_height_; + T* p_out_; + T* v_out_; + + public: + SparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr, + const T mu, const int64_t* rows, int64_t row_numel, + int64_t row_height, T* p_out, T* v_out) + : p_(p), + g_(g), + v_(v), + lr_(lr), + mu_(mu), + rows_(rows), + row_numel_(row_numel), + row_height_(row_height), + p_out_(p_out), + v_out_(v_out) {} + + inline HOSTDEVICE void operator()(size_t i) { + auto row_idx = + math::BinarySearch(rows_, row_height_, i / row_numel_); + T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0; + // put memory access in register + const T p = p_[i]; + const T lr = lr_[0]; + const T v = v_[i]; + T v_out = v * mu_ + g; + T p_out = p - (g + v_out * mu_) * lr; + // write reigster to memory + v_out_[i] = v_out; + p_out_[i] = p_out; + } +}; + +template +class SparseMomentumFunctor { + private: + const T* p_; + const T* g_; + const T* v_; + const T* lr_; + const T mu_; + const int64_t* rows_; + const int64_t row_numel_; + const int64_t row_height_; + T* p_out_; + T* v_out_; + + public: + SparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr, + const T mu, const int64_t* rows, int64_t row_numel, + int64_t row_height, T* p_out, T* v_out) + : p_(p), + g_(g), + v_(v), + lr_(lr), + mu_(mu), + rows_(rows), + row_numel_(row_numel), + row_height_(row_height), + p_out_(p_out), + v_out_(v_out) {} + + inline HOSTDEVICE void operator()(size_t i) { + auto row_idx = + math::BinarySearch(rows_, row_height_, i / row_numel_); + T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0; + // put memory access in register + const T p = p_[i]; + const T lr = lr_[0]; + const T v = v_[i]; + T v_out = v * mu_ + g; + T p_out = p - v_out * lr; + // write reigster to memory + v_out_[i] = v_out; + p_out_[i] = p_out; + } +}; + +template +class MomentumOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + T mu = static_cast(ctx.Attr("mu")); + bool use_nesterov = ctx.Attr("use_nesterov"); + + auto learning_rate = ctx.Input("LearningRate"); + auto param = ctx.Input("Param"); + auto param_out = ctx.Output("ParamOut"); + auto* velocity = ctx.Input("Velocity"); + auto velocity_out = ctx.Output("VelocityOut"); + param_out->mutable_data(ctx.GetPlace()); + velocity_out->mutable_data(ctx.GetPlace()); + + auto* grad_var = ctx.InputVar("Grad"); + if (grad_var->IsType()) { + auto grad = ctx.Input("Grad"); + if (platform::is_cpu_place(ctx.GetPlace())) { + CPUDenseMomentumFunctor functor(param, grad, velocity, learning_rate, + mu, use_nesterov, param_out, + velocity_out); + functor(); + } else if (platform::is_gpu_place(ctx.GetPlace())) { + platform::ForRange for_range( + static_cast(ctx.device_context()), + param->numel()); + if (use_nesterov) { + DenseMomentumFunctor functor( + param->data(), grad->data(), velocity->data(), + learning_rate->data(), mu, param->numel(), + param_out->mutable_data(ctx.GetPlace()), + velocity_out->mutable_data(ctx.GetPlace())); + for_range(functor); + + } else { + DenseMomentumFunctor functor( + param->data(), grad->data(), velocity->data(), + learning_rate->data(), mu, param->numel(), + param_out->mutable_data(ctx.GetPlace()), + velocity_out->mutable_data(ctx.GetPlace())); + for_range(functor); + } + } + + } else if (grad_var->IsType()) { + // sparse update embedding with selectedrows + auto grad = ctx.Input("Grad"); + + // sparse update maybe empty. + if (grad->rows().size() == 0) { + VLOG(3) << "Grad SelectedRows contains no data!"; + return; + } + auto* merged_grad = const_cast(ctx.scope()) + .Var() + ->GetMutable(); + math::scatter::MergeAdd merge_func; + merge_func(ctx.template device_context(), *grad, + merged_grad); + + const int64_t* rows = nullptr; +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(ctx.GetPlace())) { + rows = merged_grad->rows().CUDAData(ctx.GetPlace()); + } else { +#endif + rows = merged_grad->rows().data(); +#ifdef PADDLE_WITH_CUDA + } +#endif + int64_t row_numel = + merged_grad->value().numel() / merged_grad->rows().size(); + platform::ForRange for_range( + static_cast(ctx.device_context()), + param->numel()); + if (use_nesterov) { + SparseMomentumFunctor functor( + param->data(), merged_grad->value().data(), + velocity->data(), learning_rate->data(), mu, rows, row_numel, + static_cast(merged_grad->rows().size()), + param_out->mutable_data(ctx.GetPlace()), + velocity_out->mutable_data(ctx.GetPlace())); + for_range(functor); + + } else { + SparseMomentumFunctor functor( + param->data(), merged_grad->value().data(), + velocity->data(), learning_rate->data(), mu, rows, row_numel, + static_cast(merged_grad->rows().size()), + param_out->mutable_data(ctx.GetPlace()), + velocity_out->mutable_data(ctx.GetPlace())); + for_range(functor); + } + } else { + PADDLE_THROW( + string::Sprintf("MomentumOp only supports LoDTensor or SelectedRows " + "gradient, but the received Variable Type is %s", + grad_var->Type().name())); + } + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/rmsprop_op.cc b/paddle/fluid/operators/rmsprop_op.cc index 2f773f222e50a440801b06a4fd997bf237b34772..f06f87e61d3a4d1fc8b864b9dd84e697fb12a006 100644 --- a/paddle/fluid/operators/rmsprop_op.cc +++ b/paddle/fluid/operators/rmsprop_op.cc @@ -32,6 +32,11 @@ class RmspropOp : public framework::OperatorWithKernel { "Input(Grad) of RmspropOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Moment"), "Input(Moment) of RmspropOp should not be null."); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Param").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front()); PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), "Output(param_out) of RmspropOp should not be null."); diff --git a/paddle/fluid/operators/rmsprop_op.h b/paddle/fluid/operators/rmsprop_op.h index 25ed32c5ebb2ff5be962ac1e3e38c970623d705c..2a1527a3d97cf33dc77160a1890eae29e77129d7 100644 --- a/paddle/fluid/operators/rmsprop_op.h +++ b/paddle/fluid/operators/rmsprop_op.h @@ -13,66 +13,259 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/algorithm.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/for_range.h" namespace paddle { namespace operators { -using Tensor = framework::Tensor; template using EigenVector = framework::EigenVector; +template +struct DenseRmspropGradFunctor { + inline explicit DenseRmspropGradFunctor(const T *grad) : grad_(grad) {} + + HOSTDEVICE inline T operator()(int64_t idx) const { return grad_[idx]; } + + const T *grad_; +}; + +template +struct SparseRmspropGradFunctor { + inline SparseRmspropGradFunctor(const T *grad, const int64_t *rows, + int64_t row_numel, int64_t row_count) + : grad_(grad), + rows_(rows), + row_numel_(row_numel), + row_count_(row_count) {} + + HOSTDEVICE inline T operator()(int64_t idx) const { + auto row_idx = math::BinarySearch(rows_, row_count_, idx / row_numel_); + return row_idx >= 0 ? grad_[row_idx * row_numel_ + idx % row_numel_] : 0; + } + + const T *grad_; + const int64_t *rows_; + int64_t row_numel_; + int64_t row_count_; +}; + +template +struct UncenteredRmspropFunctor { + UncenteredRmspropFunctor(T *param, T *ms, T *mom, const T *lr, T rho, + T epsilon, T momentum, + const GradFunctor &grad_functor) + : param_(param), + ms_(ms), + mom_(mom), + lr_(lr), + rho_(rho), + epsilon_(epsilon), + momentum_(momentum), + grad_functor_(grad_functor) {} + + HOSTDEVICE inline void operator()(int64_t idx) const { + T g = grad_functor_(idx); + T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g; + T mom_out = momentum_ * mom_[idx] + lr_[0] * g / sqrt(ms_out + epsilon_); + param_[idx] -= mom_out; + ms_[idx] = ms_out; + mom_[idx] = mom_out; + } + + T *param_; + T *ms_; + T *mom_; + const T *lr_; + T rho_; + T epsilon_; + T momentum_; + GradFunctor grad_functor_; +}; + +template +struct CenteredRmspropFunctor { + CenteredRmspropFunctor(T *param, T *ms, T *mom, T *mean_grad, const T *lr, + T rho, T epsilon, T momentum, + const GradFunctor &grad_functor) + : param_(param), + ms_(ms), + mom_(mom), + mean_grad_(mean_grad), + lr_(lr), + rho_(rho), + epsilon_(epsilon), + momentum_(momentum), + grad_functor_(grad_functor) {} + + HOSTDEVICE inline void operator()(int64_t idx) const { + T g = grad_functor_(idx); + T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g; + T mg_out = rho_ * mean_grad_[idx] + (1 - rho_) * g; + T mom_out = momentum_ * mom_[idx] + + lr_[0] * g / sqrt(ms_out - mg_out * mg_out + epsilon_); + param_[idx] -= mom_out; + ms_[idx] = ms_out; + mom_[idx] = mom_out; + mean_grad_[idx] = mg_out; + } + + T *param_; + T *ms_; + T *mom_; + T *mean_grad_; + const T *lr_; + T rho_; + T epsilon_; + T momentum_; + GradFunctor grad_functor_; +}; + template class RmspropOpKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* param_out = ctx.Output("ParamOut"); - auto* moment_out = ctx.Output("MomentOut"); - auto* mean_square_out = ctx.Output("MeanSquareOut"); + void Compute(const framework::ExecutionContext &ctx) const override { + using LoDTensor = framework::LoDTensor; + const auto *param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE(param_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Param").front(), param_var->Type().name()); + auto *grad_var = ctx.InputVar("Grad"); + auto *param_out = ctx.Output("ParamOut"); + auto *moment_out = ctx.Output("MomentOut"); + auto *mean_square_out = ctx.Output("MeanSquareOut"); - auto grad = ctx.Input("Grad"); + auto epsilon = static_cast(ctx.Attr("epsilon")); + auto rho = static_cast(ctx.Attr("decay")); + auto momentum = static_cast(ctx.Attr("momentum")); + bool centered = ctx.Attr("centered"); - param_out->mutable_data(ctx.GetPlace()); - moment_out->mutable_data(ctx.GetPlace()); - mean_square_out->mutable_data(ctx.GetPlace()); + auto &p_tensor = *ctx.Input("Param"); + auto &ms_tensor = *ctx.Input("MeanSquare"); + auto &lr_tensor = *ctx.Input("LearningRate"); + auto &mom_tensor = *ctx.Input("Moment"); - float epsilon = ctx.Attr("epsilon"); - float rho = ctx.Attr("decay"); - float momentum = ctx.Attr("momentum"); - bool centered = ctx.Attr("centered"); + PADDLE_ENFORCE_EQ(&p_tensor, param_out, + "Param and ParamOut must be the same Tensor"); + PADDLE_ENFORCE_EQ(&mom_tensor, moment_out, + "Moment and MomentOut must be the same Tensor"); + PADDLE_ENFORCE_EQ(&ms_tensor, mean_square_out, + "MeanSquare and MeanSquareOut must be the same Tensor"); + + auto &dev_ctx = ctx.template device_context(); + size_t limit = static_cast(ms_tensor.numel()); + + if (grad_var->IsType()) { + auto &grad_tensor = grad_var->Get(); + + if (std::is_same::value) { + auto &place = + *ctx.template device_context().eigen_device(); + auto lr_value = lr_tensor.data()[0]; + + auto p = EigenVector::Flatten(p_tensor); + auto ms = EigenVector::Flatten(ms_tensor); + auto g = EigenVector::Flatten(grad_tensor); + auto mom = EigenVector::Flatten(mom_tensor); + + auto p_out = EigenVector::Flatten(*param_out); + auto mom_out = EigenVector::Flatten(*moment_out); + auto ms_out = EigenVector::Flatten(*mean_square_out); + + ms_out.device(place) = rho * ms + (1 - rho) * g * g; + if (centered) { + auto &mg_tensor = *ctx.Input("MeanGrad"); + auto mg = EigenVector::Flatten(mg_tensor); + auto *mean_grad_out = ctx.Output("MeanGradOut"); + PADDLE_ENFORCE(&mg_tensor, mean_grad_out, + "MeanGrad and MeanGradOut must be the same Tensor"); + auto mg_out = EigenVector::Flatten(*mean_grad_out); + + mg_out.device(place) = rho * mg + (1 - rho) * g; + mom_out.device(place) = + momentum * mom + + lr_value * g / (ms_out - mg_out.square() + epsilon).sqrt(); + } else { + mom_out.device(place) = + momentum * mom + lr_value * g / (ms_out + epsilon).sqrt(); + } + p_out.device(place) = p - mom_out; + } else { + DenseRmspropGradFunctor grad_func(grad_tensor.data()); + platform::ForRange for_range(dev_ctx, limit); + if (centered) { + auto &mg_tensor = *ctx.Input("MeanGrad"); + auto *mean_grad_out = ctx.Output("MeanGradOut"); + PADDLE_ENFORCE(&mg_tensor, mean_grad_out, + "MeanGrad and MeanGradOut must be the same Tensor"); + for_range(CenteredRmspropFunctor>( + param_out->mutable_data(ctx.GetPlace()), + mean_square_out->mutable_data(ctx.GetPlace()), + moment_out->mutable_data(ctx.GetPlace()), + mean_grad_out->mutable_data(ctx.GetPlace()), + lr_tensor.data(), rho, epsilon, momentum, grad_func)); + } else { + for_range(UncenteredRmspropFunctor>( + param_out->mutable_data(ctx.GetPlace()), + mean_square_out->mutable_data(ctx.GetPlace()), + moment_out->mutable_data(ctx.GetPlace()), lr_tensor.data(), + rho, epsilon, momentum, grad_func)); + } + } + } else if (grad_var->IsType()) { + auto &grad = grad_var->Get(); + auto *merged_grad = const_cast(ctx.scope()) + .Var() + ->GetMutable(); + + math::scatter::MergeAdd merge_func; + merge_func(dev_ctx, grad, merged_grad); + + platform::ForRange for_range(dev_ctx, limit); + const int64_t *rows; +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(ctx.GetPlace())) { + rows = merged_grad->rows().CUDAData(ctx.GetPlace()); + } else { +#endif + rows = merged_grad->rows().data(); +#ifdef PADDLE_WITH_CUDA + } +#endif + auto &merged_tensor = merged_grad->value(); + int64_t row_count = merged_grad->rows().size(); + int64_t row_numel = merged_tensor.numel() / row_count; + SparseRmspropGradFunctor grad_func(merged_tensor.data(), rows, + row_numel, row_count); - auto p = EigenVector::Flatten(*ctx.Input("Param")); - auto ms = EigenVector::Flatten(*ctx.Input("MeanSquare")); - auto lr = EigenVector::Flatten(*ctx.Input("LearningRate")); - auto g = EigenVector::Flatten(*grad); - auto mom = EigenVector::Flatten(*ctx.Input("Moment")); - - auto p_out = EigenVector::Flatten(*param_out); - auto mom_out = EigenVector::Flatten(*moment_out); - auto ms_out = EigenVector::Flatten(*mean_square_out); - auto& place = *ctx.template device_context().eigen_device(); - - Eigen::DSizes grad_dsize(static_cast(grad->numel())); - - ms_out.device(place) = rho * ms + (1 - rho) * g * g; - if (centered) { - auto mg = EigenVector::Flatten(*ctx.Input("MeanGrad")); - auto* mean_grad_out = ctx.Output("MeanGradOut"); - mean_grad_out->mutable_data(ctx.GetPlace()); - auto mg_out = EigenVector::Flatten(*mean_grad_out); - - mg_out.device(place) = rho * mg + (1 - rho) * g; - mom_out.device(place) = momentum * mom + - lr.broadcast(grad_dsize) * g / - (ms_out - mg_out.square() + epsilon).sqrt(); + if (centered) { + auto &mg_tensor = *ctx.Input("MeanGrad"); + auto *mean_grad_out = ctx.Output("MeanGradOut"); + PADDLE_ENFORCE(&mg_tensor, mean_grad_out, + "MeanGrad and MeanGradOut must be the same Tensor"); + for_range(CenteredRmspropFunctor>( + param_out->mutable_data(ctx.GetPlace()), + mean_square_out->mutable_data(ctx.GetPlace()), + moment_out->mutable_data(ctx.GetPlace()), + mean_grad_out->mutable_data(ctx.GetPlace()), lr_tensor.data(), + rho, epsilon, momentum, grad_func)); + } else { + for_range(UncenteredRmspropFunctor>( + param_out->mutable_data(ctx.GetPlace()), + mean_square_out->mutable_data(ctx.GetPlace()), + moment_out->mutable_data(ctx.GetPlace()), lr_tensor.data(), + rho, epsilon, momentum, grad_func)); + } } else { - mom_out.device(place) = - momentum * mom + - lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt(); + PADDLE_THROW("RMSProp only supports LoDTensor or SelectedRows gradient"); } - p_out.device(place) = p - mom_out; } }; diff --git a/paddle/fluid/operators/sgd_op.cc b/paddle/fluid/operators/sgd_op.cc index fef230e42d07a5ed73b7a7a6ab682694675bb9d2..411a126bc8e2b3a8d25f436489c13970568ccae4 100644 --- a/paddle/fluid/operators/sgd_op.cc +++ b/paddle/fluid/operators/sgd_op.cc @@ -21,7 +21,7 @@ class SGDOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Param"), "Input(Param) of SGDOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Grad"), @@ -42,7 +42,7 @@ class SGDOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); return framework::OpKernelType(data_type, ctx.device_context()); } @@ -50,17 +50,20 @@ class SGDOp : public framework::OperatorWithKernel { class SGDOpInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - auto input_var = op_desc.Input("Param")[0]; - for (auto& out_var : op_desc.Output("ParamOut")) { - if (block->FindRecursiveOrCreateVar(input_var).GetType() == - framework::proto::VarType::SELECTED_ROWS) { - block->FindRecursiveOrCreateVar(out_var).SetType( - framework::proto::VarType::SELECTED_ROWS); - } else { - block->FindRecursiveOrCreateVar(out_var).SetType( - framework::proto::VarType::LOD_TENSOR); + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + auto input_var_n = op_desc.Input("Param")[0]; + auto in_var_type = block->FindRecursiveOrCreateVar(input_var_n).GetType(); + PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS || + in_var_type == framework::proto::VarType::LOD_TENSOR, + "The input Var's type should be LoDtensor or SelectedRows," + " but the received var(%s)'s type is %s", + input_var_n, in_var_type); + + for (auto &out_var_n : op_desc.Output("ParamOut")) { + auto &out_var = block->FindRecursiveOrCreateVar(out_var_n); + if (out_var.GetType() != in_var_type) { + out_var.SetType(in_var_type); } } } diff --git a/paddle/fluid/operators/sgd_op.cu b/paddle/fluid/operators/sgd_op.cu index 4722be7a666d3e8f3c25c9499f88ddda835f60e3..5ddbac4c8176b6da519aa84490b6303b65214479 100644 --- a/paddle/fluid/operators/sgd_op.cu +++ b/paddle/fluid/operators/sgd_op.cu @@ -57,6 +57,12 @@ template class SGDOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE(param_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Param").front(), param_var->Type().name()); + auto* param = ctx.Input("Param"); auto* param_out = ctx.Output("ParamOut"); auto* learning_rate = ctx.Input("LearningRate"); diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc index 763bb403588d13c15271d26b09813dddf3a5dd8c..aa907595cb7cf165974caa69fe8eb0370471732d 100644 --- a/paddle/fluid/operators/uniform_random_op.cc +++ b/paddle/fluid/operators/uniform_random_op.cc @@ -23,14 +23,14 @@ namespace operators { template class CPUUniformRandomKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override { - framework::Tensor* tensor = nullptr; + void Compute(const framework::ExecutionContext &ctx) const override { + framework::Tensor *tensor = nullptr; auto out_var = ctx.OutputVar("Out"); if (out_var->IsType()) { tensor = out_var->GetMutable(); } else if (out_var->IsType()) { auto shape = ctx.Attr>("shape"); - auto* selected_rows = out_var->GetMutable(); + auto *selected_rows = out_var->GetMutable(); tensor = selected_rows->mutable_value(); tensor->Resize(framework::make_ddim(shape)); selected_rows->mutable_rows()->reserve(shape[0]); @@ -39,7 +39,7 @@ class CPUUniformRandomKernel : public framework::OpKernel { "uniform_random_op's output only" "supports SelectedRows and LoDTensor"); } - T* data = tensor->mutable_data(ctx.GetPlace()); + T *data = tensor->mutable_data(ctx.GetPlace()); unsigned int seed = static_cast(ctx.Attr("seed")); std::minstd_rand engine; if (seed == 0) { @@ -60,14 +60,14 @@ class UniformRandomOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of UniformRandomOp should not be null."); PADDLE_ENFORCE( ctx->Attrs().Get("min") < ctx->Attrs().Get("max"), "uniform_random's min must less then max"); - auto& shape = ctx->Attrs().Get>("shape"); + auto &shape = ctx->Attrs().Get>("shape"); std::vector temp; temp.reserve(shape.size()); for (auto dim : shape) { @@ -78,7 +78,7 @@ class UniformRandomOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( static_cast(ctx.Attr("dtype")), ctx.GetPlace()); @@ -112,17 +112,17 @@ uniform distribution. The random result is in set [min, max]. class UniformRandomOpVarTypeInference : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { auto out_var_name = op_desc.Output("Out").front(); - if (block->FindRecursiveOrCreateVar(out_var_name).GetType() == - framework::proto::VarType::SELECTED_ROWS) { - block->FindRecursiveOrCreateVar(out_var_name) - .SetType(framework::proto::VarType::SELECTED_ROWS); - } else { - block->FindRecursiveOrCreateVar(out_var_name) - .SetType(framework::proto::VarType::LOD_TENSOR); + auto var_data_type = static_cast( + boost::get(op_desc.GetAttr("dtype"))); + + auto out_var = block->FindRecursiveOrCreateVar(out_var_name); + if (out_var.GetType() != framework::proto::VarType::SELECTED_ROWS) { + out_var.SetType(framework::proto::VarType::LOD_TENSOR); } + out_var.SetDataType(var_data_type); } }; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 15268aebe4df5ac4038727338b133cbd0fca2acd..942b6de090267a4321aa0a14b02ec70fb218a834 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -156,7 +156,50 @@ PYBIND11_PLUGIN(core) { .def("_get_double_element", TensorGetElement) .def("_dtype", [](Tensor &self) { return ToDataType(self.type()); }); - py::class_(m, "LoDTensor") + py::class_(m, "LoDTensor", R"DOC( + LoDTensor is a Tensor with optional LoD information. + + np.array(lod_tensor) can convert LoDTensor to numpy array. + lod_tensor.lod() can retrieve the LoD information. + + LoD is short for Level of Details and is usually used for varied sequence + length. You can skip the following comment if you don't need optional LoD. + + For example: + A LoDTensor X can look like the example below. It contains 2 sequences. + The first has length 2 and the second has length 3, as described by x.lod. + + The first tensor dimension 5=2+3 is calculated from LoD if it's available. + It means the total number of sequence element. In X, each element has 2 + columns, hence [5, 2]. + + x.lod = [[2, 3]] + x.data = [[1, 2], [3, 4], // seq 1 + [5, 6], [7, 8], [9, 10]] // seq 2 + x.shape = [5, 2] + + LoD can have multiple levels (for example, a paragraph can have multiple + sentences and a sentence can have multiple words). In the following + LodTensor Y, the lod_level is 2. It means there are 2 sequence, the + first sequence length is 2 (has 2 sub-sequences), the second one's + length is 1. The first sequence's 2 sub-sequences have length 2 and 2, + respectively. And the second sequence's 1 sub-sequence has length 3. + + y.lod = [[2 1], [2 2 3]] + y.shape = [2+2+3, ...] + + Note: + In above description, LoD is length-based. In Paddle internal + implementation, lod is offset-based. Hence, internally, + y.lod is represented as [[0, 2, 3], [0, 2, 4, 7]] (length-based + equivlent would be [[2-0, 3-2], [2-0, 4-2, 7-4]]). + + Sometimes LoD is called recursive_sequence_length to be more + self-explanatory. In this case, it must be length-based. Due to history + reasons. when LoD is called lod in public API, it might be offset-based. + Users should be careful about it. + + )DOC") .def_buffer( [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); }) .def("__init__", @@ -596,26 +639,58 @@ All parameter, weight, gradient are variables in Paddle. // -- python binds for parallel executor. py::class_ pe(m, "ParallelExecutor"); - py::class_ exec_strategy(pe, "ExecutionStrategy"); + py::class_ exec_strategy(pe, "ExecutionStrategy", R"DOC( + ExecutionStrategy allows the user to more preciously control how to run + the program in ParallelExecutor by setting the property. + + Examples: + .. code-block:: python + + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.num_threads = 4 + + train_exe = fluid.ParallelExecutor(use_cuda=True, + loss_name=loss.name, + exec_strategy=exec_strategy) + + train_loss, = train_exe.run([loss.name], feed=feed_dict) + + )DOC"); + exec_strategy.def(py::init()) .def_property( "num_threads", [](const ExecutionStrategy &self) { return self.num_threads_; }, [](ExecutionStrategy &self, size_t num_threads) { self.num_threads_ = num_threads; - }) + }, + R"DOC(The type is INT, num_threads represents the size of thread pool that + used to run the operators of the current program in ParallelExecutor. + If :math:`num\_threads=1`, all the operators will execute one by one, + but the order maybe difference between iterations. + If it is not set, it will be set in ParallelExecutor according to the + device type and device count, for GPU, :math:`num\_threads=device\_count*4`, for CPU, + :math:`num\_threads=CPU\_NUM*4`, the explanation of:math:`CPU\_NUM` is in ParallelExecutor. + if it is not set, ParallelExecutor will get the cpu count by calling + `multiprocessing.cpu_count()`. Default 0.)DOC") .def_property( "use_cuda", [](const ExecutionStrategy &self) { return self.use_cuda_; }, [](ExecutionStrategy &self, bool use_cuda) { self.use_cuda_ = use_cuda; - }) + }) // FIXME(chengduo): Doesn't add doc for 'use_cuda', use_cuda may + // make user confuse, because ParallelExecutor has a parameter named + // 'use_cuda' too, in current implementation, ParallelExecutor's + // 'use_cuda' will rewrite ExecutionStrategy's 'use_cuda'. .def_property( "allow_op_delay", [](const ExecutionStrategy &self) { return self.allow_op_delay_; }, [](ExecutionStrategy &self, bool allow_op_delay) { self.allow_op_delay_ = allow_op_delay; - }) + }, + R"DOC(The type is BOOL, allow_op_delay represents whether to delay the + communication operators to run, it may make the execution faster. + Note that in some models, allow_op_delay may cause program hang. Default False.)DOC") .def_property( "num_iteration_per_drop_scope", [](const ExecutionStrategy &self) { @@ -623,7 +698,19 @@ All parameter, weight, gradient are variables in Paddle. }, [](ExecutionStrategy &self, size_t num_iteration_per_drop_scope) { self.num_iteration_per_drop_scope_ = num_iteration_per_drop_scope; - }); + }, + R"DOC(The type is INT, num_iteration_per_drop_scope indicates how + many iterations to clean up the temp variables which + is generated during execution. It may make the execution faster, + because the temp variable's shape maybe the same between two iterations. Default 100. + + NOTES: + 1. If you fetch data when calling the 'run', the ParallelExecutor + will clean up the temp variables at the end of the current iteration. + 2. In some NLP model, it may cause the GPU memory is insufficient, + in this case, you should reduce `num_iteration_per_drop_scope`. + )DOC"); + exec_strategy.def_property( "use_experimental_executor", [](const ExecutionStrategy &self) { @@ -634,7 +721,22 @@ All parameter, weight, gradient are variables in Paddle. : ExecutionStrategy::kDefault; }); - py::class_ build_strategy(pe, "BuildStrategy"); + py::class_ build_strategy(pe, "BuildStrategy", R"DOC( + BuildStrategy allows the user to more preciously control how to + build the SSA Graph in ParallelExecutor by setting the property. + + Examples: + .. code-block:: python + + build_strategy = fluid.BuildStrategy() + build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce + + train_exe = fluid.ParallelExecutor(use_cuda=True, + loss_name=loss.name, + build_strategy=build_strategy) + + train_loss, = train_exe.run([loss.name], feed=feed_dict) +)DOC"); py::enum_(build_strategy, "ReduceStrategy") .value("Reduce", BuildStrategy::ReduceStrategy::kReduce) @@ -652,31 +754,51 @@ All parameter, weight, gradient are variables in Paddle. [](const BuildStrategy &self) { return self.reduce_; }, [](BuildStrategy &self, BuildStrategy::ReduceStrategy strategy) { self.reduce_ = strategy; - }) + }, + R"DOC(The type is STR, there are two reduce strategies in ParallelExecutor, + 'AllReduce' and 'Reduce'. If you want that all the parameters' + optimization are done on all devices independently, you should choose 'AllReduce'; + if you choose 'Reduce', all the parameters' optimization will be evenly distributed + to different devices, and then broadcast the optimized parameter to other devices. + In some models, `Reduce` is faster. Default 'AllReduce'. )DOC") .def_property( "gradient_scale_strategy", [](const BuildStrategy &self) { return self.gradient_scale_; }, [](BuildStrategy &self, BuildStrategy::GradientScaleStrategy strategy) { self.gradient_scale_ = strategy; - }) + }, + R"DOC(The type is STR, there are three ways of defining :math:`loss@grad` in + ParallelExecutor, 'CoeffNumDevice', 'One' and 'Customized'. By default, + ParallelExecutor sets the :math:`loss@grad` according to the number of devices. + If you want to customize :math:`loss@grad`, you can choose 'Customized'. + Default 'CoeffNumDevice'.)DOC") .def_property( "debug_graphviz_path", [](const BuildStrategy &self) { return self.debug_graphviz_path_; }, [](BuildStrategy &self, const std::string &path) { self.debug_graphviz_path_ = path; - }) + }, + R"DOC(The type is STR, debug_graphviz_path indicate the path that + writing the SSA Graph to file in the form of graphviz, you. + It is useful for debugging. Default "")DOC") .def_property( "enable_data_balance", [](const BuildStrategy &self) { return self.enable_data_balance_; }, - [](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; }) - .def_property("fuse_elewise_add_act_ops", - [](const BuildStrategy &self) { - return self.fuse_elewise_add_act_ops_; - }, - [](BuildStrategy &self, bool b) { - self.fuse_elewise_add_act_ops_ = b; - }); + [](BuildStrategy &self, bool b) { + self.enable_data_balance_ = b; + }) // FIXME(chengudo): enable_data_balance seems not important + .def_property( + "fuse_elewise_add_act_ops", + [](const BuildStrategy &self) { + return self.fuse_elewise_add_act_ops_; + }, + [](BuildStrategy &self, bool b) { + self.fuse_elewise_add_act_ops_ = b; + }, + R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether + to fuse elementwise_add_op and activation_op, + it may make the execution faster. Default False)DOC"); pe.def(py::init &, const std::unordered_set &, diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index bb288384e8501caf1cb6f9cd186cc09985a5de55..8a5807d5a7801c1e84bdd2673d282a374641dfee 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -654,11 +654,21 @@ function gen_fluid_inference_lib() { if [[ ${WITH_C_API:-OFF} == "OFF" && ${WITH_INFERENCE:-ON} == "ON" ]] ; then cat <